RobertoBarrosoLuque committed · Commit b32f568 · 1 Parent(s): b5d7c36

Update frontend with different prompts and cleanup

configs/prompt_library.yaml ADDED
@@ -0,0 +1,37 @@
+concise:
+  system: "You are an expert e-commerce fashion catalog assistant specializing in product classification and data management."
+  user: |
+    Analyze this fashion product image for internal catalog management.
+
+    Provide classification and a concise, factual description focusing on:
+    - Product type and key identifying features
+    - Essential attributes (color, style, material if visible)
+
+    Keep the description brief and functional (1-2 sentences maximum).
+
+descriptive:
+  system: "You are an expert e-commerce fashion copywriter who creates engaging, conversion-focused product descriptions."
+  user: |
+    Analyze this fashion product image for our customer-facing website.
+
+    Provide classification and a descriptive product description that:
+    - Highlights key features and visual appeal
+    - Uses vivid, engaging language that attracts shoppers
+    - Emphasizes style and benefits
+    - Stays concise (2-3 sentences maximum)
+
+    Write in an enthusiastic, customer-friendly tone.
+
+explanatory:
+  system: "You are an expert fashion consultant providing comprehensive product information to customer service representatives."
+  user: |
+    Analyze this fashion product image to help customer service agents assist shoppers.
+
+    Provide classification and a detailed, comprehensive description that includes:
+    - Complete product features and construction details
+    - Material composition and quality indicators (if visible)
+    - Styling suggestions and outfit pairing ideas
+    - Appropriate occasions and use cases
+    - Care considerations if applicable
+
+    Use 3-5 sentences. Be thorough and informative to help agents answer any customer questions.
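
For context, a minimal sketch (not part of the commit) of how this file parses: each top-level key is a prompt style, and each style carries a "system" string and a multi-line "user" prompt. The path and the sanity check below are illustrative assumptions.

import yaml
from pathlib import Path

# Assumes the working directory is the repo root; adjust the path if the layout differs.
library = yaml.safe_load(Path("configs/prompt_library.yaml").read_text())

for style, prompts in library.items():           # "concise", "descriptive", "explanatory"
    assert {"system", "user"} <= prompts.keys()  # every style ships both prompts
    print(style, "->", len(prompts["user"].splitlines()), "user-prompt lines")
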
src/app.py CHANGED
@@ -18,15 +18,24 @@ AVAILABLE_MODELS = {
     "Llama Scout": "accounts/fireworks/models/llama4-scout-instruct-basic",
 }
 
-EXAMPLE_IMAGES_DIR = Path("data/examples")
 MAX_CONCURRENT_REQUESTS = 10
 
 FILE_PATH = Path(__file__).parents[1]
 ASSETS_PATH = FILE_PATH / "assets"
 
+# Prompt style display names
+PROMPT_STYLES = {
+    "Data Management": "concise",
+    "Website/Sales": "descriptive",
+    "Customer Support": "explanatory",
+}
+
 
 def analyze_single_image(
-    image_input, model_name: str, api_key: Optional[str] = None
+    image_input,
+    model_name: str,
+    api_key: Optional[str] = None,
+    prompt_style_display: Optional[str] = None,
 ) -> tuple[str, str, str, str]:
     """
     Process a single product image and return classification results
@@ -35,6 +44,7 @@ def analyze_single_image(
         image_input: PIL Image or file path
         model_name: Selected model name
        api_key: Optional API key override
+        prompt_style_display: Display name for prompt style (e.g., "Data Management")
 
     Returns:
         tuple: (master_category, gender, sub_category, description)
@@ -53,8 +63,17 @@ def analyze_single_image(
     if api_key is None:
         api_key = os.getenv("FIREWORKS_API_KEY")
 
+    # Map display name to prompt key
+    prompt_style = (
+        PROMPT_STYLES.get(prompt_style_display) if prompt_style_display else None
+    )
+
     result = analyze_product_image(
-        image_url=img_b64, model=model_id, api_key=api_key, provider="Fireworks"
+        image_url=img_b64,
+        model=model_id,
+        api_key=api_key,
+        provider="Fireworks",
+        prompt_style=prompt_style,
     )
 
     # Format results
@@ -75,7 +94,7 @@ def process_batch_dataset(
     model_name: str,
     api_key: Optional[str] = None,
     max_concurrent: int = MAX_CONCURRENT_REQUESTS,
-) -> tuple[pd.DataFrame, str]:
+) -> tuple[Optional[pd.DataFrame], str]:
     """
     Process uploaded CSV dataset with product images
 
@@ -218,14 +237,19 @@ def create_demo_interface():
                value=list(AVAILABLE_MODELS.keys())[0],
                label="Select Model",
            )
+            prompt_selector = gr.Dropdown(
+                choices=list(PROMPT_STYLES.keys()),
+                value="Website/Sales",
+                label="Description Style",
+            )
            api_key_input = gr.Textbox(
                label="API Key",
                type="password",
            )
 
        with gr.Tabs():
-            with gr.TabItem("📸 Single Image Analysis"):
-                gr.Markdown("### Upload a product image for instant classification")
+            with gr.TabItem("📸 Image Analysis 📸 "):
+                gr.Markdown("### Upload a product image or select from table below")
 
                with gr.Row():
                    # Left column - Input
@@ -265,7 +289,12 @@ def create_demo_interface():
                # Wire up single image analysis
                analyze_btn.click(
                    fn=analyze_single_image,
-                    inputs=[image_input, model_selector, api_key_input],
+                    inputs=[
+                        image_input,
+                        model_selector,
+                        api_key_input,
+                        prompt_selector,
+                    ],
                    outputs=[
                        master_category_output,
                        gender_output,
@@ -281,47 +310,6 @@ def create_demo_interface():
                    outputs=[image_input],
                )
 
-                with gr.Row():
-                    # Left - Upload
-                    with gr.Column(scale=1):
-                        dataset_upload = gr.File(
-                            label="Upload Dataset (CSV)", file_types=[".csv"]
-                        )
-                        concurrent_slider = gr.Slider(
-                            minimum=1,
-                            maximum=50,
-                            value=10,
-                            step=1,
-                            label="Concurrent Requests",
-                            info="Higher = faster but may hit rate limits",
-                        )
-                        process_btn = gr.Button(
-                            "⚡ Process Dataset", variant="primary", size="lg"
-                        )
-
-                    # Right - Results summary
-                    with gr.Column(scale=1):
-                        summary_output = gr.Textbox(
-                            label="Processing Summary", interactive=False, lines=8
-                        )
-
-                # Results dataframe
-                results_dataframe = gr.Dataframe(
-                    label="Classification Results", interactive=False, wrap=True
-                )
-
-                # Wire up batch processing
-                process_btn.click(
-                    fn=process_batch_dataset,
-                    inputs=[
-                        dataset_upload,
-                        model_selector,
-                        api_key_input,
-                        concurrent_slider,
-                    ],
-                    outputs=[results_dataframe, summary_output],
-                )
-
            # Tab 3: Model Evaluation (show uploaded charts)
            with gr.TabItem("📈 Model Performance"):
                gr.Markdown(
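
For reference, a minimal sketch of how the new dropdown value flows into the analysis call: the Gradio dropdown passes a display name, which is mapped to a prompt-library key, and an empty or unknown selection falls back to the default prompts. The helper name resolve_prompt_style below is hypothetical and not part of the commit; it only mirrors the mapping done inside analyze_single_image.

PROMPT_STYLES = {
    "Data Management": "concise",
    "Website/Sales": "descriptive",
    "Customer Support": "explanatory",
}

def resolve_prompt_style(display_name):
    # None means "use the module-level fallback prompts" downstream.
    return PROMPT_STYLES.get(display_name) if display_name else None

print(resolve_prompt_style("Website/Sales"))  # -> "descriptive"
print(resolve_prompt_style(None))             # -> None (defaults apply)
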
src/modules/constants.py ADDED
@@ -0,0 +1,7 @@
+import yaml
+from pathlib import Path
+
+_PATH_TO_CONFIGS = Path(__file__).parents[2] / "configs" / "prompt_library.yaml"
+
+with open(_PATH_TO_CONFIGS, "r") as f:
+    PROMPT_LIBRARY = yaml.safe_load(f)
src/modules/vlm_inference.py CHANGED
@@ -2,9 +2,11 @@ import os
 from openai import OpenAI, AsyncOpenAI
 from pydantic import BaseModel, Field
 from typing import Optional, Literal
+from src.modules.constants import PROMPT_LIBRARY
 
 SYSTEM_PROMPT = """
-You are a fashion product analyst. Classify products and generate detailed descriptions based on images.
+You are an e-commerce fashion catalog assistant.
+Classify products and generate detailed descriptions based on images.
 """
 USER_PROMPT = """
 Analyze this fashion product image and provide:
@@ -73,6 +75,7 @@ def analyze_product_image(
     model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
     api_key: Optional[str] = None,
     provider: str = "Fireworks",
+    prompt_style: Optional[str] = None,
 ) -> ProductClassification:
     """
     Analyze a fashion product image using VLM with structured output
@@ -82,6 +85,7 @@ def analyze_product_image(
        model: Model to use for inference (default: Qwen2.5 VL 72B)
        api_key: Fireworks API key (defaults to FIREWORKS_API_KEY env variable)
        provider: Provider to use for inference (default: Fireworks)
+        prompt_style: Prompt style from library (concise, descriptive, explanatory). Defaults to fallback prompts.
 
     Returns:
        ProductClassification: Structured classification and description
@@ -98,16 +102,24 @@ def analyze_product_image(
     else:
        raise ValueError(f"Unknown provider: {provider}")
 
+    # Get prompts from library or use defaults
+    if prompt_style and prompt_style in PROMPT_LIBRARY:
+        system_prompt = PROMPT_LIBRARY[prompt_style]["system"]
+        user_prompt = PROMPT_LIBRARY[prompt_style]["user"]
+    else:
+        system_prompt = SYSTEM_PROMPT
+        user_prompt = USER_PROMPT
+
     # Call the API with structured output
     completion = client.beta.chat.completions.parse(
        model=model,
        messages=[
-            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
-                    {"type": "text", "text": USER_PROMPT},
+                    {"type": "text", "text": user_prompt},
                ],
            },
        ],
@@ -123,6 +135,7 @@ async def analyze_product_image_async(
     model: str = "accounts/fireworks/models/qwen2p5-vl-72b-instruct",
     api_key: Optional[str] = None,
     provider: str = "Fireworks",
+    prompt_style: Optional[str] = None,
 ) -> ProductClassification:
     """
     Async version of analyze_product_image for concurrent processing
@@ -132,6 +145,7 @@ async def analyze_product_image_async(
        model: Model to use for inference (default: Qwen2.5 VL 72B)
        api_key: API key (defaults to provider-specific env variable)
        provider: Provider to use for inference (default: Fireworks)
+        prompt_style: Prompt style from library (concise, descriptive, explanatory). Defaults to fallback prompts.
 
     Returns:
        ProductClassification: Structured classification and description
@@ -148,16 +162,24 @@ async def analyze_product_image_async(
     else:
        raise ValueError(f"Unknown provider: {provider}")
 
+    # Get prompts from library or use defaults
+    if prompt_style and prompt_style in PROMPT_LIBRARY:
+        system_prompt = PROMPT_LIBRARY[prompt_style]["system"]
+        user_prompt = PROMPT_LIBRARY[prompt_style]["user"]
+    else:
+        system_prompt = SYSTEM_PROMPT
+        user_prompt = USER_PROMPT
+
     # Call the API with structured output
     completion = await client.beta.chat.completions.parse(
        model=model,
        messages=[
-            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
-                    {"type": "text", "text": USER_PROMPT},
+                    {"type": "text", "text": user_prompt},
                ],
            },
        ],
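
A hedged usage sketch of the updated entry point, based only on the signature and docstring shown above: the image must already be a URL (or base64 data URL), FIREWORKS_API_KEY must be set, and prompt_style should be one of the keys in configs/prompt_library.yaml; anything else falls back to the module-level SYSTEM_PROMPT/USER_PROMPT. The image URL below is illustrative, and the shape of the returned ProductClassification fields is not shown in this diff, so the result is just printed.

import os
from src.modules.vlm_inference import analyze_product_image

# Illustrative input; in the app this is a base64-encoded upload.
image_url = "https://example.com/product.jpg"

result = analyze_product_image(
    image_url=image_url,
    api_key=os.getenv("FIREWORKS_API_KEY"),
    provider="Fireworks",
    prompt_style="concise",  # or "descriptive" / "explanatory"; None uses the defaults
)
print(result)  # ProductClassification with the category fields and description
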