RobertoBarrosoLuque committed
Commit ff3bc98 · 1 Parent(s): 6739aa3

Add interactive plots

data/evaluation_results.csv ADDED
@@ -0,0 +1,13 @@
+ model,category,accuracy,precision,recall,num_samples
+ Qwen2.5-VL-32B-BASE,masterCategory,0.909,0.9196051103650724,0.909,1000
+ Qwen2.5-VL-32B-BASE,gender,0.546,0.9259626959624715,0.546,1000
+ Qwen2.5-VL-32B-BASE,subCategory,0.432,0.7070035848765855,0.432,1000
+ Qwen2-VL-72B-BASE,masterCategory,0.968968968968969,0.9711267688093789,0.968968968968969,999
+ Qwen2-VL-72B-BASE,gender,0.7607607607607607,0.9354341592843324,0.7607607607607607,999
+ Qwen2-VL-72B-BASE,subCategory,0.34134134134134136,0.6784829173652965,0.34134134134134136,999
+ Qwen2-VL-72B-SFT,masterCategory,0.993993993993994,0.9940108529582213,0.993993993993994,999
+ Qwen2-VL-72B-SFT,gender,0.9169169169169169,0.9144956029794004,0.9169169169169169,999
+ Qwen2-VL-72B-SFT,subCategory,0.9419419419419419,0.9512743495222181,0.9419419419419419,999
+ GPT-5-Mini,masterCategory,0.981,0.9810138759482104,0.981,1000
+ GPT-5-Mini,gender,0.907,0.9260515702929443,0.907,1000
+ GPT-5-Mini,subCategory,0.897,0.944355065421394,0.897,1000
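
For readers who want to inspect these numbers directly, here is a small sketch (not part of the commit) that loads the CSV above and pivots it into a model × category accuracy table; the path is assumed to be data/evaluation_results.csv relative to the repo root:

    import pandas as pd

    # Load the evaluation results added in this commit
    df = pd.read_csv("data/evaluation_results.csv")

    # Pivot to a model x category accuracy table for a quick side-by-side comparison
    accuracy_table = df.pivot(index="model", columns="category", values="accuracy")
    print(accuracy_table.round(3))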
generate_eval_results.py ADDED
@@ -0,0 +1,61 @@
+ """
+ Script to generate evaluation results CSV for all model predictions
+ """
+
+ import pandas as pd
+ from pathlib import Path
+ from src.modules.evals import evaluate_all_categories, extract_metrics
+
+ DATA_PATH = Path(__file__).parent / "data"
+
+ # Load test.csv (ground truth)
+ test_df = pd.read_csv(DATA_PATH / "test.csv")
+
+ # Define model prediction files and their display names
+ model_files = {
+     "Qwen2.5-VL-32B-BASE": "df_pred_FireworksAI_qwen2p5-vl-32b-instruct-ralh0ben.csv",
+     "Qwen2-VL-72B-BASE": "df_pred_FireworksAI_qwen2-vl-72b-BASE-instruct-yaxztv7t.csv",
+     "Qwen2-VL-72B-SFT": "df_pred_FireworksAI_qwen-72b-SFT-fashion-catalog-oueqouqs.csv",
+     "GPT-5-Mini": "df_pred_OpenAI_gpt-5-mini-2025-08-07.csv",
+ }
+
+ # Collect all metrics
+ all_metrics = []
+
+ for model_name, filename in model_files.items():
+     pred_file = DATA_PATH / filename
+
+     if not pred_file.exists():
+         print(f"Warning: {filename} not found, skipping...")
+         continue
+
+     print(f"\nEvaluating {model_name}...")
+     print("=" * 60)
+
+     # Load predictions
+     pred_df = pd.read_csv(pred_file)
+
+     # Evaluate all categories
+     results = evaluate_all_categories(
+         df_ground_truth=test_df,
+         df_predictions=pred_df,
+         id_col="id",
+         categories=["masterCategory", "gender", "subCategory"],
+     )
+
+     # Extract metrics for this model
+     model_metrics = extract_metrics(results, model_name)
+     all_metrics.extend(model_metrics)
+
+ # Create DataFrame with all metrics
+ metrics_df = pd.DataFrame(all_metrics)
+
+ # Save to CSV
+ output_file = DATA_PATH / "evaluation_results.csv"
+ metrics_df.to_csv(output_file, index=False)
+
+ print(f"\n{'=' * 60}")
+ print(f"Evaluation complete! Results saved to: {output_file}")
+ print(f"{'=' * 60}")
+ print("\nSummary:")
+ print(metrics_df.to_string(index=False))
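
For context, this script relies on evaluate_all_categories and extract_metrics from src.modules.evals, which are not part of this commit. A minimal sketch of the shapes the script appears to assume (hypothetical implementations; the real module may differ):

    # Hypothetical sketch of the interfaces generate_eval_results.py assumes.
    # The actual implementations live in src/modules/evals.py and may differ.
    from sklearn.metrics import accuracy_score, precision_score, recall_score

    def evaluate_all_categories(df_ground_truth, df_predictions, id_col="id", categories=()):
        """Assumed to return {category: {"accuracy", "precision", "recall", "num_samples"}}."""
        merged = df_ground_truth.merge(df_predictions, on=id_col, suffixes=("_true", "_pred"))
        results = {}
        for cat in categories:
            y_true, y_pred = merged[f"{cat}_true"], merged[f"{cat}_pred"]
            results[cat] = {
                "accuracy": accuracy_score(y_true, y_pred),
                "precision": precision_score(y_true, y_pred, average="weighted", zero_division=0),
                "recall": recall_score(y_true, y_pred, average="weighted", zero_division=0),
                "num_samples": len(merged),
            }
        return results

    def extract_metrics(results, model_name):
        """Assumed to flatten results into rows matching evaluation_results.csv."""
        return [{"model": model_name, "category": cat, **m} for cat, m in results.items()]

Weighted-average precision/recall over the merged test rows would also explain why recall equals accuracy in the CSV above.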
notebooks/01-eda-and-fine-tuning.ipynb CHANGED
@@ -273,10 +273,47 @@
  "| `--eval-auto-carveout` | Auto validation split | Always include |"
  ]
  },
+ {
+ "cell_type": "markdown",
+ "id": "22",
+ "metadata": {},
+ "source": [
+ "##### Fine tune Qwen 2.5 VL 32B"
+ ]
+ },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "22",
+ "id": "23",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "! firectl -a pyroworks create sftj --base-model accounts/fireworks/models/qwen2p5-vl-32b-instruct --dataset accounts/pyroworks/datasets/fashion-catalog-train --output-model qwen-32b-fashion-catalog --display-name \"Qwen2.5-32b-fashion-catalog\" --epochs 3 --learning-rate 0.0001 --early-stop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "24",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "### Check status of job\n",
+ "! firectl -a pyroworks get sftj j588i1qm"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "25",
+ "metadata": {},
+ "source": [
+ "##### Fine tune Qwen 2.5 VL 72B"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "26",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -286,7 +323,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "23",
+ "id": "27",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -297,7 +334,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "24",
+ "id": "28",
  "metadata": {},
  "outputs": [],
  "source": []
notebooks/02-model-evals.ipynb CHANGED
@@ -57,6 +57,14 @@
  "cell_type": "markdown",
  "id": "3",
  "metadata": {},
+ "source": [
+ "**Note: if using this notebook, make sure to replace \"pyroworks\" with your account name**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4",
+ "metadata": {},
  "source": [
  "#### Run example image through a serverless Qwen VL model to test"
  ]
@@ -64,7 +72,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "4",
+ "id": "5",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -80,7 +88,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "5",
+ "id": "6",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -96,7 +104,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "6",
+ "id": "7",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -105,7 +113,7 @@
  },
  {
  "cell_type": "markdown",
- "id": "7",
+ "id": "8",
  "metadata": {},
  "source": [
  "*Important*: If you are following through this notebook make sure to replace \"pyroworks\" with your account name"
@@ -113,19 +121,80 @@
  },
  {
  "cell_type": "markdown",
- "id": "8",
+ "id": "9",
  "metadata": {},
  "source": [
  "#### Run test set through base OSS model\n",
- "1. Create a deployment for accounts/fireworks/models/qwen2-vl-72b-instruct\n",
+ "1. Create a deployment for the model for faster inference\n",
  "2. Check deployment status\n",
- "3. Run test set through deployment for base model and save results"
+ "3. Run test set through deployment for base model and save results\n",
+ "\n",
+ "NOTE: make sure to delete or scale down the deployment when done to avoid costs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "10",
+ "metadata": {},
+ "source": [
+ "##### Run inference on Qwen 2.5 VL 32B"
  ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "9",
+ "id": "11",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "! firectl create deployment accounts/fireworks/models/qwen2p5-vl-32b-instruct --min-replica-count 1 --max-replica-count 1 --accelerator-type NVIDIA_H100_80GB"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "! firectl -a pyroworks get deployment itmxuke2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "13",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df_predictions_qwen_base_32b = await run_inference_on_dataframe_async(\n",
+ "    df_test,\n",
+ "    model=\"accounts/pyroworks/deployedModels/qwen2p5-vl-32b-instruct-ralh0ben\",\n",
+ "    provider=\"FireworksAI\",\n",
+ "    api_key=FIREWORKS_API_KEY,\n",
+ "    max_concurrent_requests=20,  # Adjust based on rate limits\n",
+ ")\n",
+ "\n",
+ "results_qwen_base_32b = evaluate_all_categories(\n",
+ "    df_ground_truth=df_test,\n",
+ "    df_predictions=df_predictions_qwen_base_32b,\n",
+ "    categories=[\"masterCategory\", \"gender\", \"subCategory\"]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "14",
+ "metadata": {},
+ "source": [
+ "##### Run inference on Qwen 2.5 VL 72B"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "15",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -135,7 +204,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "10",
+ "id": "16",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -145,7 +214,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "11",
+ "id": "17",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -167,7 +236,7 @@
  },
  {
  "cell_type": "markdown",
- "id": "12",
+ "id": "18",
  "metadata": {},
  "source": [
  "#### Run test set through fine tuned FW Qwen model\n",
@@ -179,7 +248,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "13",
+ "id": "19",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -189,7 +258,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "14",
+ "id": "20",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -199,7 +268,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "15",
+ "id": "21",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -221,7 +290,7 @@
  },
  {
  "cell_type": "markdown",
- "id": "16",
+ "id": "22",
  "metadata": {},
  "source": [
  "#### Run test set through closed source model"
@@ -230,7 +299,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "17",
+ "id": "23",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -253,7 +322,7 @@
  },
  {
  "cell_type": "markdown",
- "id": "18",
+ "id": "24",
  "metadata": {},
  "source": [
  "### Compare eval metrics across models"
@@ -262,7 +331,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "19",
+ "id": "25",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -283,7 +352,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "20",
+ "id": "26",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -298,7 +367,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "21",
+ "id": "27",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -335,7 +404,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "22",
+ "id": "28",
  "metadata": {},
  "outputs": [],
  "source": [
@@ -345,7 +414,7 @@
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "23",
+ "id": "29",
  "metadata": {},
  "outputs": [],
  "source": [
src/app.py CHANGED
@@ -11,6 +11,11 @@ from dotenv import load_dotenv
  from src.modules.vlm_inference import analyze_product_image
  from src.modules.data_processing import pil_to_base64
  from src.modules.evals import run_inference_on_dataframe
+ from src.modules.viz import (
+     load_evaluation_data,
+     create_accuracy_plot,
+     create_precision_recall_plot,
+ )

  load_dotenv()

@@ -25,6 +30,7 @@ MAX_CONCURRENT_REQUESTS = 10

  FILE_PATH = Path(__file__).parents[1]
  ASSETS_PATH = FILE_PATH / "assets"
+ DATA_PATH = FILE_PATH / "data"
  _NOTEBOOK_PATH = "https://huggingface.co/spaces/fireworks-ai/catalog-extract/blob/main/notebooks/01-eda-and-fine-tuning.ipynb"

  # Prompt style display names
@@ -56,13 +62,10 @@ def analyze_single_image(
          return "No image provided", "", "", ""

      try:
-         # Convert PIL Image to base64
          img_b64 = pil_to_base64(image_input)

-         # Determine provider from model name
          model_id = AVAILABLE_MODELS[model_name]
          api_key = os.getenv("FIREWORKS_API_KEY")
-         # Map display name to prompt key
          prompt_style = (
              PROMPT_STYLES.get(prompt_style_display) if prompt_style_display else None
          )
@@ -304,7 +307,7 @@ def create_demo_interface():
              outputs=[image_input],
          )

-         # Tab 3: Model Evaluation (show uploaded charts)
+         # Tab 3: Model Evaluation (interactive charts)
          with gr.TabItem("📈 Model Performance"):
              gr.Markdown(
                  """
@@ -316,17 +319,60 @@
                  """
              )

-             # Display uploaded evaluation charts
-             with gr.Row():
-                 gr.Image(
-                     value=str(ASSETS_PATH / "Accuracy.png"),
-                     interactive=False,
-                     show_label=False,
-                 )
-                 gr.Image(
-                     value=str(ASSETS_PATH / "Accuracy-precision-recall.png"),
-                     interactive=False,
-                     show_label=False,
-                 )
+             eval_df = load_evaluation_data(DATA_PATH)
+
+             if eval_df is not None:
+                 all_models = eval_df["model"].unique().tolist()
+                 all_categories = eval_df["category"].unique().tolist()
+
+                 with gr.Row():
+                     model_filter = gr.CheckboxGroup(
+                         choices=all_models,
+                         value=all_models,
+                         label="Select Models to Display",
+                         interactive=True,
+                     )
+                     category_filter = gr.CheckboxGroup(
+                         choices=all_categories,
+                         value=all_categories,
+                         label="Select Categories to Display",
+                         interactive=True,
+                     )
+                 with gr.Row():
+                     accuracy_plot = gr.Plot()
+
+                 with gr.Row():
+                     precision_recall_plot = gr.Plot()
+
+                 def update_plots(selected_models, selected_categories):
+                     acc_fig = create_accuracy_plot(
+                         eval_df, selected_models, selected_categories
+                     )
+                     pr_fig = create_precision_recall_plot(
+                         eval_df, selected_models, selected_categories
+                     )
+                     return acc_fig, pr_fig
+
+                 model_filter.change(
+                     fn=update_plots,
+                     inputs=[model_filter, category_filter],
+                     outputs=[accuracy_plot, precision_recall_plot],
+                 )
+
+                 category_filter.change(
+                     fn=update_plots,
+                     inputs=[model_filter, category_filter],
+                     outputs=[accuracy_plot, precision_recall_plot],
+                 )
+
+                 demo.load(
+                     fn=update_plots,
+                     inputs=[model_filter, category_filter],
+                     outputs=[accuracy_plot, precision_recall_plot],
+                 )
+             else:
+                 gr.Markdown(
+                     "⚠️ Evaluation data not found. Please run `python generate_eval_results.py` first."
+                 )

          with gr.Row():
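
The wiring above follows a standard Gradio pattern: the CheckboxGroup values feed a callback that rebuilds the matplotlib figures, and demo.load renders the initial state. A stripped-down, self-contained sketch of the same pattern (toy data, not the app's actual code):

    import gradio as gr
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    SCORES = {"model-a": 0.9, "model-b": 0.7}  # placeholder data

    def make_plot(selected):
        # Rebuild the bar chart from whichever models are currently checked.
        fig, ax = plt.subplots()
        ax.bar(selected, [SCORES[m] for m in selected])
        ax.set_ylim(0, 1.0)
        return fig

    with gr.Blocks() as demo:
        picker = gr.CheckboxGroup(choices=list(SCORES), value=list(SCORES), label="Models")
        plot = gr.Plot()
        # Re-render whenever the selection changes, and once on page load.
        picker.change(fn=make_plot, inputs=picker, outputs=plot)
        demo.load(fn=make_plot, inputs=picker, outputs=plot)

    if __name__ == "__main__":
        demo.launch()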
src/modules/data_processing.py CHANGED
@@ -36,13 +36,9 @@ def image_to_base64(img_bytes):
  def create_training_example(row):
      """Create a training example with both classification and description tasks"""

-     # Convert image to base64
      img_b64 = image_to_base64(row["image"])
-
-     # Create multi-task prompt combining classification and description
      user_prompt = "Analyze this fashion product image and provide: 1) Master category, 2) Gender, 3) Sub-category, and 4) A detailed description."

-     # Create structured response with all classification info
      assistant_response = f"""
  Master Category: {row['masterCategory']}
  Gender: {row['gender']}
@@ -50,7 +46,6 @@ def create_training_example(row):

  Description: This is a {row['gender'].lower()} {row['subCategory'].lower()} from the {row['masterCategory'].lower()} category."""

-     # Format as OpenAI-compatible messages
      return {
          "messages": [
              {
src/modules/viz.py ADDED
@@ -0,0 +1,224 @@
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import matplotlib
+
+ matplotlib.use("Agg")
+
+
+ def load_evaluation_data(data_path) -> pd.DataFrame:
+     """Load evaluation results from CSV"""
+     eval_file = data_path / "evaluation_results.csv"
+     if eval_file.exists():
+         return pd.read_csv(eval_file)
+     return None
+
+
+ def get_model_style(model_name):
+     """
+     Get color and hatch pattern for a model
+
+     Color scheme:
+     - GPT models: Gray (#808080)
+     - Qwen2.5-VL-32B: Light purple (#9B87E8) - BASE solid, SFT with pattern
+     - Qwen2-VL-72B: Medium blue (#5B7FD8) - BASE solid, SFT with pattern
+
+     Returns:
+         tuple: (color, hatch_pattern)
+     """
+     if "GPT" in model_name or "gpt" in model_name:
+         return "#808080", None
+
+     if "Qwen2.5" in model_name or "qwen2p5" in model_name or "32B" in model_name:
+         if "SFT" in model_name:
+             return "#9B87E8", "///"
+         else:
+             return "#9B87E8", None
+
+     if "Qwen2" in model_name or "72B" in model_name:
+         if "SFT" in model_name:
+             return "#5B7FD8", "///"
+         else:
+             return "#5B7FD8", None
+
+     return "#6B4DB8", None
+
+
+ def create_accuracy_plot(
+     eval_df: pd.DataFrame,
+     selected_models: list = None,
+     selected_categories: list = None,
+ ):
+     """
+     Create bar chart of accuracy by category, colored by model
+
+     Args:
+         eval_df: DataFrame with evaluation results
+         selected_models: List of models to display (None for all)
+         selected_categories: List of categories to display (None for all)
+
+     Returns:
+         matplotlib figure
+     """
+     if eval_df is None:
+         return None
+
+     # Filter data
+     df_filtered = eval_df.copy()
+     if selected_models:
+         df_filtered = df_filtered[df_filtered["model"].isin(selected_models)]
+     if selected_categories:
+         df_filtered = df_filtered[df_filtered["category"].isin(selected_categories)]
+
+     # Create figure
+     fig, ax = plt.subplots(figsize=(12, 6))
+
+     # Get unique categories and models
+     categories = df_filtered["category"].unique()
+     models = df_filtered["model"].unique()
+
+     # Set up bar positions
+     x = range(len(categories))
+     width = 0.8 / len(models)
+
+     for i, model in enumerate(models):
+         model_data = df_filtered[df_filtered["model"] == model]
+         accuracies = [
+             model_data[model_data["category"] == cat]["accuracy"].values[0]
+             for cat in categories
+         ]
+
+         color, hatch = get_model_style(model)
+
+         offset = (i - len(models) / 2) * width + width / 2
+         ax.bar(
+             [xi + offset for xi in x],
+             accuracies,
+             width,
+             label=model,
+             color=color,
+             hatch=hatch,
+             alpha=0.8,
+             edgecolor="white",
+             linewidth=1.2,
+         )
+
+     # Customize plot
+     ax.set_xlabel("Category", fontsize=12, fontweight="bold")
+     ax.set_ylabel("Accuracy", fontsize=12, fontweight="bold")
+     ax.set_title("Model Accuracy by Category", fontsize=14, fontweight="bold")
+     ax.set_xticks(x)
+     ax.set_xticklabels(categories, rotation=0)
+     ax.set_ylim(0, 1.0)
+     ax.legend(loc="lower right", framealpha=0.9)
+     ax.grid(axis="y", alpha=0.3, linestyle="--")
+
+     plt.tight_layout()
+     return fig
+
+
+ def create_precision_recall_plot(
+     eval_df: pd.DataFrame,
+     selected_models: list = None,
+     selected_categories: list = None,
+ ):
+     """
+     Create subplot with precision and recall by category, colored by model
+
+     Args:
+         eval_df: DataFrame with evaluation results
+         selected_models: List of models to display (None for all)
+         selected_categories: List of categories to display (None for all)
+
+     Returns:
+         matplotlib figure
+     """
+     if eval_df is None:
+         return None
+
+     # Filter data
+     df_filtered = eval_df.copy()
+     if selected_models:
+         df_filtered = df_filtered[df_filtered["model"].isin(selected_models)]
+     if selected_categories:
+         df_filtered = df_filtered[df_filtered["category"].isin(selected_categories)]
+
+     # Create figure with subplots
+     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
+
+     # Get unique categories and models
+     categories = df_filtered["category"].unique()
+     models = df_filtered["model"].unique()
+
+     # Set up bar positions
+     x = range(len(categories))
+     width = 0.8 / len(models)
+
+     # Plot precision bars
+     for i, model in enumerate(models):
+         model_data = df_filtered[df_filtered["model"] == model]
+         precisions = [
+             model_data[model_data["category"] == cat]["precision"].values[0]
+             for cat in categories
+         ]
+
+         # Get color and pattern for this model
+         color, hatch = get_model_style(model)
+
+         offset = (i - len(models) / 2) * width + width / 2
+         ax1.bar(
+             [xi + offset for xi in x],
+             precisions,
+             width,
+             label=model,
+             color=color,
+             hatch=hatch,
+             alpha=0.8,
+             edgecolor="white",
+             linewidth=1.2,
+         )
+
+     # Customize precision plot
+     ax1.set_xlabel("Category", fontsize=12, fontweight="bold")
+     ax1.set_ylabel("Precision", fontsize=12, fontweight="bold")
+     ax1.set_title("Model Precision by Category", fontsize=14, fontweight="bold")
+     ax1.set_xticks(x)
+     ax1.set_xticklabels(categories, rotation=0)
+     ax1.set_ylim(0, 1.0)
+     ax1.legend(loc="lower right", framealpha=0.9)
+     ax1.grid(axis="y", alpha=0.3, linestyle="--")
+
+     # Plot recall bars
+     for i, model in enumerate(models):
+         model_data = df_filtered[df_filtered["model"] == model]
+         recalls = [
+             model_data[model_data["category"] == cat]["recall"].values[0]
+             for cat in categories
+         ]
+
+         # Get color and pattern for this model
+         color, hatch = get_model_style(model)
+
+         offset = (i - len(models) / 2) * width + width / 2
+         ax2.bar(
+             [xi + offset for xi in x],
+             recalls,
+             width,
+             label=model,
+             color=color,
+             hatch=hatch,
+             alpha=0.8,
+             edgecolor="white",
+             linewidth=1.2,
+         )
+
+     ax2.set_xlabel("Category", fontsize=12, fontweight="bold")
+     ax2.set_ylabel("Recall", fontsize=12, fontweight="bold")
+     ax2.set_title("Model Recall by Category", fontsize=14, fontweight="bold")
+     ax2.set_xticks(x)
+     ax2.set_xticklabels(categories, rotation=0)
+     ax2.set_ylim(0, 1.0)
+     ax2.legend(loc="lower right", framealpha=0.9)
+     ax2.grid(axis="y", alpha=0.3, linestyle="--")
+
+     plt.tight_layout()
+     return fig
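
Since these helpers return plain matplotlib figures, they can also be exercised outside Gradio. A small sketch (assuming the repo layout above and running from the repo root):

    from pathlib import Path
    from src.modules.viz import load_evaluation_data, create_accuracy_plot, create_precision_recall_plot

    data_path = Path("data")
    eval_df = load_evaluation_data(data_path)  # returns None if evaluation_results.csv is missing

    if eval_df is not None:
        # Passing None for both filters plots every model and category in the CSV.
        create_accuracy_plot(eval_df, None, None).savefig("accuracy.png", dpi=150)
        create_precision_recall_plot(eval_df, None, None).savefig("precision_recall.png", dpi=150)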