SmartHeal commited on
Commit
7e7d8ff
Β·
verified Β·
1 Parent(s): 862d7cb

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +287 -232
src/ai_processor.py CHANGED
@@ -6,6 +6,7 @@ from PIL import Image
6
  from datetime import datetime
7
  import gradio as gr
8
  import spaces
 
9
 
10
  from huggingface_hub import HfApi, HfFolder
11
  from langchain_community.document_loaders import PyPDFLoader
@@ -23,9 +24,9 @@ if not os.path.exists(UPLOADS_DIR):
23
  logging.info(f"Created uploads directory: {UPLOADS_DIR}")
24
 
25
  HF_TOKEN = os.getenv("HF_TOKEN")
26
- YOLO_MODEL_PATH = "src/best.pt"
27
- SEG_MODEL_PATH = "src/segmentation_model.h5"
28
- GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
29
  DATASET_ID = "SmartHeal/wound-image-uploads"
30
  MAX_NEW_TOKENS = 2048
31
  PIXELS_PER_CM = 38
@@ -136,38 +137,21 @@ def generate_medgemma_report(
136
  patient_info,
137
  visual_results,
138
  guideline_context,
139
- detection_image_path,
140
- segmentation_image_path,
141
  max_new_tokens=None,
142
  ):
143
- """GPU-only function for MedGemma report generation."""
144
- # Import GPU libraries ONLY here
145
- import torch
146
  from transformers import pipeline
147
- from PIL import Image
148
-
149
- default_system_prompt = (
150
- "You are a world-class medical AI assistant specializing in wound care "
151
- "with expertise in wound assessment and treatment. Provide concise, "
152
- "evidence-based medical assessments focusing on: (1) Precise wound "
153
- "classification based on tissue type and appearance, (2) Specific "
154
- "treatment recommendations with exact product names or interventions when "
155
- "appropriate, (3) Objective evaluation of healing progression or deterioration "
156
- "indicators, and (4) Clear follow-up timelines. Avoid general statements and "
157
- "prioritize actionable insights based on the visual analysis measurements and "
158
- "patient context."
159
- )
160
 
161
- # Lazy-load MedGemma pipeline on GPU
162
  if not hasattr(generate_medgemma_report, "_pipe"):
163
  try:
164
  generate_medgemma_report._pipe = pipeline(
165
  "image-text-to-text",
166
  model="google/medgemma-4b-it",
167
- device="cuda",
168
  torch_dtype=torch.bfloat16,
169
- offload_folder="offload",
170
- token=HF_TOKEN,
171
  )
172
  logging.info("βœ… MedGemma pipeline loaded on GPU")
173
  except Exception as e:
@@ -176,52 +160,86 @@ def generate_medgemma_report(
176
 
177
  pipe = generate_medgemma_report._pipe
178
 
179
- # Load the original image that was analyzed
180
- original_image = None
181
- if detection_image_path and os.path.exists(detection_image_path.replace('detection_', 'original_')):
182
- original_image = Image.open(detection_image_path.replace('detection_', 'original_'))
183
- elif segmentation_image_path and os.path.exists(segmentation_image_path.replace('segmentation_', 'original_')):
184
- original_image = Image.open(segmentation_image_path.replace('segmentation_', 'original_'))
185
-
186
- # Compose messages
187
- msgs = [
188
- {"role": "system", "content": [{"type": "text", "text": default_system_prompt}]},
189
- {"role": "user", "content": []},
190
- ]
191
-
192
- # Attach images if available
193
- if original_image:
194
- msgs[1]["content"].append({"type": "image", "image": original_image})
195
- else:
196
- # Fallback to detection or segmentation images
197
- for path in (detection_image_path, segmentation_image_path):
198
- if path and os.path.exists(path):
199
- msgs[1]["content"].append({"type": "image", "image": Image.open(path)})
200
- break
201
-
202
- # Attach text prompt
203
- prompt = f"""## Patient Information
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  {patient_info}
 
 
 
 
 
 
 
 
205
 
206
- ## Visual Analysis Results
207
- - Wound Type: {visual_results.get('wound_type','Unknown')}
208
- - Dimensions: {visual_results.get('length_cm', 0)} x {visual_results.get('breadth_cm', 0)} cm
209
- - Surface Area: {visual_results.get('surface_area_cm2', 0)} cmΒ²
210
- - Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}
211
-
212
- ## Clinical Guidelines Context
213
- {guideline_context[:1500]}...
214
-
215
- Please provide a comprehensive wound care assessment and treatment recommendations based on the image and provided information."""
216
-
217
- msgs[1]["content"].append({"type": "text", "text": prompt})
 
 
218
 
219
  try:
220
- out = pipe(text=msgs, max_new_tokens=max_new_tokens or MAX_NEW_TOKENS, do_sample=False)
221
- return out[0]["generated_text"][-1].get("content", "")
 
 
 
 
 
 
222
  except Exception as e:
223
- logging.error(f"Failed to generate MedGemma report: {e}")
224
- return f"❌ An error occurred: {e}"
225
 
226
  # =============== AI PROCESSOR CLASS ===============
227
  class AIProcessor:
@@ -234,48 +252,40 @@ class AIProcessor:
234
  self.hf_token = HF_TOKEN
235
 
236
  def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
237
- """Performs the full visual analysis pipeline."""
238
  try:
239
  # Convert PIL to OpenCV format
240
  image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
241
 
242
- # YOLO Detection
243
- yolo_model = self.models_cache.get("det")
244
- if yolo_model is None:
245
- raise RuntimeError("YOLO model ('det') not loaded")
246
-
247
- results = yolo_model.predict(image_cv, verbose=False, device="cpu")
248
-
249
- if not results or not results[0].boxes or len(results[0].boxes) == 0:
250
- raise ValueError("No wound detected in the image")
251
-
252
- # Extract bounding box - handle different output formats
253
- boxes_data = results[0].boxes.xyxy.cpu().numpy()
254
-
255
- if len(boxes_data.shape) == 1:
256
- # Single detection case
257
- if len(boxes_data) != 4:
258
- raise ValueError(f"Expected 4 coordinates, got {len(boxes_data)}")
259
- x1, y1, x2, y2 = boxes_data.astype(int)
260
- else:
261
- # Multiple detections - take the first one
262
- if boxes_data.shape[1] != 4:
263
- raise ValueError(f"Expected 4 coordinates per box, got {boxes_data.shape[1]}")
264
- x1, y1, x2, y2 = boxes_data[0].astype(int)
265
-
266
- # Validate coordinates
267
- if x1 >= x2 or y1 >= y2 or x1 < 0 or y1 < 0:
268
- raise ValueError("Invalid bounding box coordinates")
269
-
270
- # Extract wound region
271
- detected_region_cv = image_cv[y1:y2, x1:x2]
272
-
273
- if detected_region_cv.size == 0:
274
- raise ValueError("Detected region is empty")
275
 
276
  # Save detection visualization
277
  det_vis = image_cv.copy()
278
- cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
279
  os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
280
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
281
  det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
@@ -285,101 +295,27 @@ class AIProcessor:
285
  original_path = f"{self.uploads_dir}/analysis/original_{ts}.png"
286
  cv2.imwrite(original_path, image_cv)
287
 
288
- # Segmentation Analysis
289
- length = breadth = area = 0
290
  seg_path = None
291
-
292
- seg_model = self.models_cache.get("seg")
293
- if seg_model is not None:
294
- try:
295
- # Get input shape from model
296
- input_shape = seg_model.input_shape
297
- if len(input_shape) >= 3:
298
- h, w = input_shape[1:3]
299
- else:
300
- h, w = 256, 256 # Default fallback
301
-
302
- # Prepare input for segmentation
303
- resized = cv2.resize(detected_region_cv, (w, h))
304
- normalized_input = np.expand_dims(resized / 255.0, 0)
305
-
306
- # Predict mask
307
- mask_pred = seg_model.predict(normalized_input, verbose=0)
308
-
309
- # Handle different output formats
310
- if len(mask_pred.shape) == 4:
311
- mask_np = (mask_pred[0, :, :, 0] > 0.5).astype(np.uint8)
312
- elif len(mask_pred.shape) == 3:
313
- mask_np = (mask_pred[0, :, :] > 0.5).astype(np.uint8)
314
- else:
315
- raise ValueError(f"Unexpected segmentation output shape: {mask_pred.shape}")
316
-
317
- # Resize mask back to detection region size
318
- mask_resized = cv2.resize(
319
- mask_np * 255,
320
- (detected_region_cv.shape[1], detected_region_cv.shape[0]),
321
- interpolation=cv2.INTER_NEAREST
322
- )
323
- mask_resized = (mask_resized > 127).astype(np.uint8)
324
-
325
- # Create segmentation visualization
326
- overlay = detected_region_cv.copy()
327
- overlay[mask_resized == 1] = [0, 0, 255] # Red overlay for wound area
328
- seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
329
-
330
- seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
331
- cv2.imwrite(seg_path, seg_vis)
332
-
333
- # Calculate measurements
334
- contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
335
- if contours:
336
- # Get the largest contour
337
- largest_contour = max(contours, key=cv2.contourArea)
338
-
339
- # Calculate bounding rectangle
340
- bbox = cv2.boundingRect(largest_contour)
341
- if len(bbox) == 4:
342
- x, y, w_box, h_box = bbox
343
- length = round(h_box / self.px_per_cm, 2)
344
- breadth = round(w_box / self.px_per_cm, 2)
345
- area = round(cv2.contourArea(largest_contour) / (self.px_per_cm ** 2), 2)
346
- else:
347
- logging.warning(f"Unexpected bounding rect format: {bbox}")
348
- else:
349
- logging.info("No contours found in segmentation mask")
350
-
351
- except Exception as seg_error:
352
- logging.error(f"Segmentation processing error: {seg_error}")
353
- seg_path = None
354
-
355
- # Wound Classification
356
- wound_type = "Unknown"
357
- cls_pipeline = self.models_cache.get("cls")
358
- if cls_pipeline is not None:
359
- try:
360
- detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
361
- predictions = cls_pipeline(detected_image_pil)
362
- if predictions and len(predictions) > 0:
363
- best_pred = max(predictions, key=lambda x: x.get("score", 0))
364
- wound_type = best_pred.get("label", "Unknown")
365
- except Exception as cls_error:
366
- logging.warning(f"Classification failed: {cls_error}")
367
-
368
- # Extract confidence score
369
- confidence = 0.0
370
- if results[0].boxes.conf is not None and len(results[0].boxes.conf) > 0:
371
- confidence = float(results[0].boxes.conf[0].cpu().item())
372
-
373
- return {
374
- "wound_type": wound_type,
375
- "length_cm": length,
376
- "breadth_cm": breadth,
377
  "surface_area_cm2": area,
378
- "detection_confidence": confidence,
379
  "detection_image_path": det_path,
380
  "segmentation_image_path": seg_path,
381
  "original_image_path": original_path
382
  }
 
383
 
384
  except Exception as e:
385
  logging.error(f"Visual analysis failed: {e}")
@@ -390,20 +326,11 @@ class AIProcessor:
390
  try:
391
  vector_store = self.knowledge_base_cache.get("vector_store")
392
  if not vector_store:
393
- return "Clinical guidelines unavailable - knowledge base not loaded"
394
 
395
  retriever = vector_store.as_retriever(search_kwargs={"k": 10})
396
  docs = retriever.invoke(query)
397
-
398
- if not docs:
399
- return "No relevant guidelines found for the query"
400
-
401
- context = "\n\n".join([
402
- f"Source: {doc.metadata.get('source', 'Unknown')}, Page: {doc.metadata.get('page', 'N/A')}\n{doc.page_content}"
403
- for doc in docs
404
- ])
405
-
406
- return context
407
 
408
  except Exception as e:
409
  logging.error(f"Guidelines query failed: {e}")
@@ -413,19 +340,16 @@ class AIProcessor:
413
  self, patient_info: str, visual_results: dict, guideline_context: str,
414
  image_pil: Image.Image, max_new_tokens: int = None
415
  ) -> str:
416
- """Generate final report using MedGemma GPU pipeline."""
417
  try:
418
- det_path = visual_results.get("detection_image_path", "")
419
- seg_path = visual_results.get("segmentation_image_path", "")
420
-
421
  report = generate_medgemma_report(
422
- patient_info, visual_results, guideline_context,
423
- det_path, seg_path, max_new_tokens
424
  )
425
 
426
- if report and report.strip():
427
  return report
428
  else:
 
429
  return self._generate_fallback_report(patient_info, visual_results, guideline_context)
430
 
431
  except Exception as e:
@@ -437,41 +361,56 @@ class AIProcessor:
437
  ) -> str:
438
  """Generate fallback report if MedGemma fails."""
439
 
440
- report = f"""# Wound Analysis Report
441
 
442
- ## Patient Information
443
  {patient_info}
444
 
445
- ## Visual Analysis Results
446
  - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
447
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm Γ— {visual_results.get('breadth_cm', 0)} cm
448
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cmΒ²
449
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}
450
 
451
- ## Analysis Images
452
  - **Detection Image**: {visual_results.get('detection_image_path', 'N/A')}
453
  - **Segmentation Image**: {visual_results.get('segmentation_image_path', 'N/A')}
454
 
455
- ## Clinical Guidelines Context
456
  {guideline_context[:1000]}{'...' if len(guideline_context) > 1000 else ''}
457
 
458
- ## Assessment Summary
459
  Based on the automated visual analysis, the wound has been classified as **{visual_results.get('wound_type', 'Unknown')}** with measurable dimensions. The detection confidence indicates the reliability of the automated assessment.
460
 
461
- ## Recommendations
 
 
 
 
 
462
  1. **Clinical Evaluation**: This automated analysis should be supplemented with professional clinical assessment
463
  2. **Documentation**: Regular monitoring and documentation of wound progression is recommended
464
  3. **Treatment Planning**: Develop appropriate treatment protocol based on wound characteristics and patient factors
465
  4. **Follow-up**: Schedule appropriate follow-up intervals based on wound severity and healing progress
466
 
467
- ## Important Notes
468
  - This is an automated analysis and should not replace professional medical judgment
469
  - All measurements are estimates based on computer vision algorithms
470
  - Clinical correlation is essential for proper diagnosis and treatment planning
471
  - Consider patient-specific factors not captured in this automated assessment
472
 
473
- ## Disclaimer
474
- This automated analysis is provided for informational purposes only and does not constitute medical advice. Always consult with qualified healthcare professionals for proper diagnosis and treatment.
 
 
 
 
 
 
 
 
 
 
475
  """
476
  return report
477
 
@@ -510,7 +449,7 @@ This automated analysis is provided for informational purposes only and does not
510
  return ""
511
 
512
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
513
- """Run full analysis pipeline."""
514
  try:
515
  # Save image first
516
  saved_path = self.save_and_commit_image(image_pil)
@@ -520,18 +459,11 @@ This automated analysis is provided for informational purposes only and does not
520
  visual_results = self.perform_visual_analysis(image_pil)
521
  logging.info(f"Visual analysis completed: {visual_results}")
522
 
523
- # Process questionnaire data
524
- patient_info = ", ".join(f"{k}: {v}" for k, v in questionnaire_data.items() if v)
525
- if not patient_info:
526
- patient_info = "No patient information provided"
527
 
528
- # Query guidelines
529
- query = f"wound care treatment for {visual_results.get('wound_type', 'wound')} "
530
- if questionnaire_data.get('diabetic') == 'Yes':
531
- query += "diabetic patient "
532
- if questionnaire_data.get('infection') == 'Yes':
533
- query += "with infection signs "
534
-
535
  guideline_context = self.query_guidelines(query)
536
  logging.info("Guidelines queried successfully")
537
 
@@ -594,7 +526,7 @@ This automated analysis is provided for informational purposes only and does not
594
 
595
  try:
596
  # Age assessment
597
- age = questionnaire_data.get('patient_age', 0)
598
  if isinstance(age, str):
599
  try:
600
  age = int(age)
@@ -722,4 +654,127 @@ This automated analysis is provided for informational purposes only and does not
722
  recommendations.append("Use high-absorption dressings")
723
  recommendations.append("More frequent dressing changes may be needed")
724
 
725
- return recommendations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from datetime import datetime
7
  import gradio as gr
8
  import spaces
9
+ import torch
10
 
11
  from huggingface_hub import HfApi, HfFolder
12
  from langchain_community.document_loaders import PyPDFLoader
 
24
  logging.info(f"Created uploads directory: {UPLOADS_DIR}")
25
 
26
  HF_TOKEN = os.getenv("HF_TOKEN")
27
+ YOLO_MODEL_PATH = "best.pt"
28
+ SEG_MODEL_PATH = "segmentation_model.h5"
29
+ GUIDELINE_PDFS = ["eHealth in Wound Care.pdf", "IWGDF Guideline.pdf", "evaluation.pdf"]
30
  DATASET_ID = "SmartHeal/wound-image-uploads"
31
  MAX_NEW_TOKENS = 2048
32
  PIXELS_PER_CM = 38
 
137
  patient_info,
138
  visual_results,
139
  guideline_context,
140
+ image_pil,
 
141
  max_new_tokens=None,
142
  ):
143
+ """GPU-only function for MedGemma report generation - EXACTLY like working reference."""
 
 
144
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ # Lazy-load MedGemma pipeline on GPU - EXACTLY like working reference
147
  if not hasattr(generate_medgemma_report, "_pipe"):
148
  try:
149
  generate_medgemma_report._pipe = pipeline(
150
  "image-text-to-text",
151
  model="google/medgemma-4b-it",
 
152
  torch_dtype=torch.bfloat16,
153
+ device_map="auto",
154
+ token=HF_TOKEN
155
  )
156
  logging.info("βœ… MedGemma pipeline loaded on GPU")
157
  except Exception as e:
 
160
 
161
  pipe = generate_medgemma_report._pipe
162
 
163
+ # Use the EXACT prompt format from the working reference
164
+ prompt = f"""
165
+ 🩺 You are SmartHeal-AI Agent, a world-class wound care AI specialist trained in clinical wound assessment and guideline-based treatment planning.
166
+ Your task is to process the following structured inputs (patient data, wound measurements, clinical guidelines, and image) and perform **clinical reasoning and decision-making** to generate a complete wound care report.
167
+ ---
168
+ πŸ” **YOUR PROCESS β€” FOLLOW STRICTLY:**
169
+ ### Step 1: Clinical Reasoning (Chain-of-Thought)
170
+ Use the provided information to think step-by-step about:
171
+ - Patient's risk factors (e.g. diabetes, age, healing limitations)
172
+ - Wound characteristics (size, tissue appearance, moisture, infection signs)
173
+ - Visual clues from the image (location, granulation, maceration, inflammation, surrounding skin)
174
+ - Clinical guidelines provided β€” selectively choose the ones most relevant to this case
175
+ Do NOT list all guidelines verbatim. Use judgment: apply them where relevant. Explain why or why not.
176
+ Also assess whether this wound appears:
177
+ - Acute vs chronic
178
+ - Surgical vs traumatic
179
+ - Inflammatory vs proliferative healing phase
180
+ ---
181
+ ### Step 2: Structured Clinical Report
182
+ Generate the following report sections using markdown and medical terminology:
183
+ #### **1. Clinical Summary**
184
+ - Describe wound appearance and tissue types (e.g., slough, necrotic, granulating, epithelializing)
185
+ - Include size, wound bed condition, peri-wound skin, and signs of infection or biofilm
186
+ - Mention inferred location (e.g., heel, forefoot) if image allows
187
+ - Summarize patient's systemic risk profile
188
+ #### **2. Medicinal & Dressing Recommendations**
189
+ Based on your analysis:
190
+ - Recommend specific **wound care dressings** (e.g., hydrocolloid, alginate, foam, antimicrobial silver, etc.) suitable to wound moisture level and infection risk
191
+ - Propose **topical or systemic agents** ONLY if relevant β€” include name classes (e.g., antiseptic: povidone iodine, antibiotic ointments, enzymatic debriders)
192
+ - Mention **techniques** (e.g., sharp debridement, NPWT, moisture balance, pressure offloading, dressing frequency)
193
+ - Avoid repeating guidelines β€” **apply them**
194
+ #### **3. Key Risk Factors**
195
+ Explain how the patient's condition (e.g., diabetic, poor circulation, advanced age, poor hygiene) may affect wound healing
196
+ #### **4. Prognosis & Monitoring Advice**
197
+ - Mention how often wound should be reassessed
198
+ - Indicate signs to monitor for deterioration or improvement
199
+ - Include when escalation to specialist is necessary
200
+ #### **5. Disclaimer**
201
+ This is an AI-generated summary based on available data. It is not a substitute for clinical evaluation by a wound care professional.
202
+ **Note:** Every dressing change is a chance for wound reassessment. Always perform a thorough wound evaluation at each dressing change.
203
+ ---
204
+ 🧾 **INPUT DATA**
205
+ **Patient Info:**
206
  {patient_info}
207
+ **Wound Details:**
208
+ - Type: {visual_results['wound_type']}
209
+ - Size: {visual_results['length_cm']} Γ— {visual_results['breadth_cm']} cm
210
+ - Area: {visual_results['surface_area_cm2']} cmΒ²
211
+ **Clinical Guideline Evidence:**
212
+ {guideline_context}
213
+ You may now begin your analysis and generate the two-part report.
214
+ """
215
 
216
+ # Use EXACT message format from working reference
217
+ messages = [
218
+ {
219
+ "role": "system",
220
+ "content": [{"type": "text", "text": "You are a world-class medical AI assistant. Follow the user's instructions precisely to perform a two-step analysis and generate a structured report."}],
221
+ },
222
+ {
223
+ "role": "user",
224
+ "content": [
225
+ {"type": "image", "image": image_pil},
226
+ {"type": "text", "text": prompt},
227
+ ]
228
+ }
229
+ ]
230
 
231
  try:
232
+ output = pipe(
233
+ text=messages,
234
+ max_new_tokens=max_new_tokens or MAX_NEW_TOKENS,
235
+ do_sample=False,
236
+ )
237
+ result = output[0]["generated_text"][-1].get("content", "").strip()
238
+ return result if result else "⚠️ No content generated. Try reducing max tokens or input size."
239
+
240
  except Exception as e:
241
+ logging.error(f"Failed to generate MedGemma report: {e}", exc_info=True)
242
+ return f"❌ An error occurred while generating the report: {e}"
243
 
244
  # =============== AI PROCESSOR CLASS ===============
245
  class AIProcessor:
 
252
  self.hf_token = HF_TOKEN
253
 
254
  def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
255
+ """Performs the full visual analysis pipeline - EXACTLY like working reference."""
256
  try:
257
  # Convert PIL to OpenCV format
258
  image_cv = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
259
 
260
+ # YOLO Detection - EXACTLY like working reference
261
+ results = self.models_cache["det"].predict(image_cv, verbose=False, device="cpu")
262
+ if not results or not results[0].boxes:
263
+ raise ValueError("No wound could be detected.")
264
+
265
+ box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
266
+ detected_region_cv = image_cv[box[1]:box[3], box[0]:box[2]]
267
+
268
+ # Segmentation - EXACTLY like working reference
269
+ input_size = self.models_cache["seg"].input_shape[1:3]
270
+ resized = cv2.resize(detected_region_cv, (input_size[1], input_size[0]))
271
+ mask_pred = self.models_cache["seg"].predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
272
+ mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
273
+
274
+ # Calculate measurements - EXACTLY like working reference
275
+ contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
276
+ length, breadth, area = (0, 0, 0)
277
+ if contours:
278
+ cnt = max(contours, key=cv2.contourArea)
279
+ x, y, w, h = cv2.boundingRect(cnt)
280
+ length, breadth, area = round(h / self.px_per_cm, 2), round(w / self.px_per_cm, 2), round(cv2.contourArea(cnt) / (self.px_per_cm ** 2), 2)
281
+
282
+ # Classification - EXACTLY like working reference
283
+ detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
284
+ wound_type = max(self.models_cache["cls"](detected_image_pil), key=lambda x: x["score"])["label"]
 
 
 
 
 
 
 
 
285
 
286
  # Save detection visualization
287
  det_vis = image_cv.copy()
288
+ cv2.rectangle(det_vis, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
289
  os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
290
  ts = datetime.now().strftime("%Y%m%d_%H%M%S")
291
  det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
 
295
  original_path = f"{self.uploads_dir}/analysis/original_{ts}.png"
296
  cv2.imwrite(original_path, image_cv)
297
 
298
+ # Save segmentation visualization if available
 
299
  seg_path = None
300
+ if contours:
301
+ mask_resized = cv2.resize(mask_np * 255, (detected_region_cv.shape[1], detected_region_cv.shape[0]), interpolation=cv2.INTER_NEAREST)
302
+ overlay = detected_region_cv.copy()
303
+ overlay[mask_resized > 127] = [0, 0, 255] # Red overlay for wound area
304
+ seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
305
+ seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
306
+ cv2.imwrite(seg_path, seg_vis)
307
+
308
+ visual_results = {
309
+ "wound_type": wound_type,
310
+ "length_cm": length,
311
+ "breadth_cm": breadth,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  "surface_area_cm2": area,
313
+ "detection_confidence": float(results[0].boxes.conf[0].cpu().item()) if results[0].boxes.conf is not None else 0.0,
314
  "detection_image_path": det_path,
315
  "segmentation_image_path": seg_path,
316
  "original_image_path": original_path
317
  }
318
+ return visual_results
319
 
320
  except Exception as e:
321
  logging.error(f"Visual analysis failed: {e}")
 
326
  try:
327
  vector_store = self.knowledge_base_cache.get("vector_store")
328
  if not vector_store:
329
+ return "Knowledge base is not available."
330
 
331
  retriever = vector_store.as_retriever(search_kwargs={"k": 10})
332
  docs = retriever.invoke(query)
333
+ return "\n\n".join([f"Source: {doc.metadata.get('source', 'N/A')}, Page: {doc.metadata.get('page', 'N/A')}\nContent: {doc.page_content}" for doc in docs])
 
 
 
 
 
 
 
 
 
334
 
335
  except Exception as e:
336
  logging.error(f"Guidelines query failed: {e}")
 
340
  self, patient_info: str, visual_results: dict, guideline_context: str,
341
  image_pil: Image.Image, max_new_tokens: int = None
342
  ) -> str:
343
+ """Generate final report using MedGemma GPU pipeline - EXACTLY like working reference."""
344
  try:
 
 
 
345
  report = generate_medgemma_report(
346
+ patient_info, visual_results, guideline_context, image_pil, max_new_tokens
 
347
  )
348
 
349
+ if report and report.strip() and not report.startswith("❌") and not report.startswith("⚠️"):
350
  return report
351
  else:
352
+ logging.warning("MedGemma returned empty or error response, using fallback")
353
  return self._generate_fallback_report(patient_info, visual_results, guideline_context)
354
 
355
  except Exception as e:
 
361
  ) -> str:
362
  """Generate fallback report if MedGemma fails."""
363
 
364
+ report = f"""# 🩺 SmartHeal AI - Wound Analysis Report
365
 
366
+ ## πŸ“‹ Patient Information
367
  {patient_info}
368
 
369
+ ## πŸ” Visual Analysis Results
370
  - **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
371
  - **Dimensions**: {visual_results.get('length_cm', 0)} cm Γ— {visual_results.get('breadth_cm', 0)} cm
372
  - **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cmΒ²
373
  - **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}
374
 
375
+ ## πŸ“Š Analysis Images
376
  - **Detection Image**: {visual_results.get('detection_image_path', 'N/A')}
377
  - **Segmentation Image**: {visual_results.get('segmentation_image_path', 'N/A')}
378
 
379
+ ## πŸ“š Clinical Guidelines Context
380
  {guideline_context[:1000]}{'...' if len(guideline_context) > 1000 else ''}
381
 
382
+ ## 🎯 Assessment Summary
383
  Based on the automated visual analysis, the wound has been classified as **{visual_results.get('wound_type', 'Unknown')}** with measurable dimensions. The detection confidence indicates the reliability of the automated assessment.
384
 
385
+ ### Clinical Observations
386
+ - **Wound Classification**: {visual_results.get('wound_type', 'Unspecified')}
387
+ - **Approximate Size**: {visual_results.get('length_cm', 0)} Γ— {visual_results.get('breadth_cm', 0)} cm
388
+ - **Calculated Area**: {visual_results.get('surface_area_cm2', 0)} cmΒ²
389
+
390
+ ## πŸ’Š General Recommendations
391
  1. **Clinical Evaluation**: This automated analysis should be supplemented with professional clinical assessment
392
  2. **Documentation**: Regular monitoring and documentation of wound progression is recommended
393
  3. **Treatment Planning**: Develop appropriate treatment protocol based on wound characteristics and patient factors
394
  4. **Follow-up**: Schedule appropriate follow-up intervals based on wound severity and healing progress
395
 
396
+ ## ⚠️ Important Clinical Notes
397
  - This is an automated analysis and should not replace professional medical judgment
398
  - All measurements are estimates based on computer vision algorithms
399
  - Clinical correlation is essential for proper diagnosis and treatment planning
400
  - Consider patient-specific factors not captured in this automated assessment
401
 
402
+ ## πŸ₯ Next Steps
403
+ 1. **Professional Assessment**: Consult with a qualified wound care specialist
404
+ 2. **Comprehensive Evaluation**: Consider patient's overall health status and comorbidities
405
+ 3. **Treatment Protocol**: Develop individualized care plan based on clinical findings
406
+ 4. **Monitoring Plan**: Establish regular assessment schedule
407
+
408
+ ## βš–οΈ Disclaimer
409
+ This automated analysis is provided for informational purposes only and does not constitute medical advice. Always consult with qualified healthcare professionals for proper diagnosis and treatment. This AI-generated report should be used as a supplementary tool alongside professional clinical assessment.
410
+
411
+ ---
412
+ *Generated by SmartHeal AI - Advanced Wound Care Analysis System*
413
+ *Report Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
414
  """
415
  return report
416
 
 
449
  return ""
450
 
451
  def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
452
+ """Run full analysis pipeline - EXACTLY like working reference."""
453
  try:
454
  # Save image first
455
  saved_path = self.save_and_commit_image(image_pil)
 
459
  visual_results = self.perform_visual_analysis(image_pil)
460
  logging.info(f"Visual analysis completed: {visual_results}")
461
 
462
+ # Process questionnaire data - EXACTLY like working reference
463
+ patient_info = f"Age: {questionnaire_data.get('age', 'N/A')}, Diabetic: {questionnaire_data.get('diabetic', 'N/A')}, Allergies: {questionnaire_data.get('allergies', 'N/A')}, Date of Wound Sustained: {questionnaire_data.get('date_of_injury', 'N/A')}, Professional Care: {questionnaire_data.get('professional_care', 'N/A')}, Oozing/Bleeding: {questionnaire_data.get('oozing_bleeding', 'N/A')}, Infection: {questionnaire_data.get('infection', 'N/A')}, Moisture: {questionnaire_data.get('moisture', 'N/A')}"
 
 
464
 
465
+ # Query guidelines - EXACTLY like working reference
466
+ query = f"best practices for managing a {visual_results['wound_type']} with moisture level '{questionnaire_data.get('moisture', 'unknown')}' and signs of infection '{questionnaire_data.get('infection', 'unknown')}' in a patient who is diabetic '{questionnaire_data.get('diabetic', 'unknown')}'"
 
 
 
 
 
467
  guideline_context = self.query_guidelines(query)
468
  logging.info("Guidelines queried successfully")
469
 
 
526
 
527
  try:
528
  # Age assessment
529
+ age = questionnaire_data.get('age', 0)
530
  if isinstance(age, str):
531
  try:
532
  age = int(age)
 
654
  recommendations.append("Use high-absorption dressings")
655
  recommendations.append("More frequent dressing changes may be needed")
656
 
657
+ return recommendations
658
+
659
+
660
+ # =============== STANDALONE SAVE AND COMMIT FUNCTION ===============
661
+ def save_and_commit_image(image_to_save):
662
+ """Saves an image locally and commits it to the separate HF Dataset repository - EXACTLY like working reference."""
663
+ if not image_to_save:
664
+ return
665
+
666
+ timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
667
+ filename = f"{timestamp}.png"
668
+ local_save_path = os.path.join(UPLOADS_DIR, filename)
669
+
670
+ image_to_save.convert("RGB").save(local_save_path)
671
+ logging.info(f"βœ… Image saved to temporary local storage: {local_save_path}")
672
+
673
+ if DATASET_ID and HF_TOKEN:
674
+ try:
675
+ api = HfApi()
676
+ repo_path = f"images/{filename}"
677
+
678
+ logging.info(f"Attempting to commit {local_save_path} to DATASET {DATASET_ID}...")
679
+
680
+ api.upload_file(
681
+ path_or_fileobj=local_save_path,
682
+ path_in_repo=repo_path,
683
+ repo_id=DATASET_ID,
684
+ repo_type="dataset",
685
+ commit_message=f"Upload wound image: {filename}"
686
+ )
687
+ logging.info(f"βœ… Image successfully committed to dataset.")
688
+ except Exception as e:
689
+ logging.error(f"❌ FAILED TO COMMIT IMAGE TO DATASET: {e}")
690
+ else:
691
+ logging.warning("DATASET_ID or HF_TOKEN not set. Skipping file commit.")
692
+
693
+ # =============== MAIN ANALYSIS FUNCTION (with @spaces.GPU) - EXACTLY LIKE WORKING REFERENCE ===============
694
+ @spaces.GPU(enable_queue=True, duration=120)
695
+ def analyze(image, age, diabetic, allergies, date_of_injury, professional_care, oozing_bleeding, infection, moisture):
696
+ """Main analysis function with GPU decorator - EXACTLY like working reference."""
697
+ try:
698
+ yield None, None, "⏳ Initializing... Loading AI models..."
699
+
700
+ # Load all models - using global cache
701
+ if "medgemma_pipe" not in models_cache:
702
+ from transformers import pipeline
703
+ models_cache["medgemma_pipe"] = pipeline(
704
+ "image-text-to-text",
705
+ model="google/medgemma-4b-it",
706
+ torch_dtype=torch.bfloat16,
707
+ device_map="auto",
708
+ token=HF_TOKEN
709
+ )
710
+ logging.info("βœ… All models loaded.")
711
+
712
+ yield None, None, "⏳ Setting up knowledge base from guidelines..."
713
+
714
+ # Save image
715
+ save_and_commit_image(image)
716
+
717
+ # Create processor instance
718
+ processor = AIProcessor()
719
+
720
+ yield None, None, "⏳ Performing visual analysis..."
721
+
722
+ # Perform visual analysis - EXACTLY like working reference
723
+ image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
724
+ results = models_cache["det"].predict(image_cv, verbose=False, device="cpu")
725
+ if not results or not results[0].boxes:
726
+ raise ValueError("No wound could be detected.")
727
+
728
+ box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
729
+ detected_region_cv = image_cv[box[1]:box[3], box[0]:box[2]]
730
+
731
+ input_size = models_cache["seg"].input_shape[1:3]
732
+ resized = cv2.resize(detected_region_cv, (input_size[1], input_size[0]))
733
+ mask_pred = models_cache["seg"].predict(np.expand_dims(resized / 255.0, 0), verbose=0)[0]
734
+ mask_np = (mask_pred[:, :, 0] > 0.5).astype(np.uint8)
735
+
736
+ contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
737
+ length, breadth, area = (0, 0, 0)
738
+ if contours:
739
+ cnt = max(contours, key=cv2.contourArea)
740
+ x, y, w, h = cv2.boundingRect(cnt)
741
+ length, breadth, area = round(h / PIXELS_PER_CM, 2), round(w / PIXELS_PER_CM, 2), round(cv2.contourArea(cnt) / (PIXELS_PER_CM ** 2), 2)
742
+
743
+ detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
744
+ wound_type = max(models_cache["cls"](detected_image_pil), key=lambda x: x["score"])["label"]
745
+
746
+ visual_results = {"wound_type": wound_type, "length_cm": length, "breadth_cm": breadth, "surface_area_cm2": area}
747
+
748
+ # Create visualization images
749
+ segmented_mask = Image.fromarray(cv2.resize(mask_np * 255, (detected_region_cv.shape[1], detected_region_cv.shape[0]), interpolation=cv2.INTER_NEAREST))
750
+
751
+ yield detected_image_pil, segmented_mask, f"βœ… Visual analysis complete. Detected: {visual_results['wound_type']}. Querying guidelines..."
752
+
753
+ # Query guidelines
754
+ patient_info = f"Age: {age}, Diabetic: {diabetic}, Allergies: {allergies}, Date of Wound Sustained: {date_of_injury}, Professional Care: {professional_care}, Oozing/Bleeding: {oozing_bleeding}, Infection: {infection}, Moisture: {moisture}"
755
+ query = f"best practices for managing a {visual_results['wound_type']} with moisture level '{moisture}' and signs of infection '{infection}' in a patient who is diabetic '{diabetic}'"
756
+ guideline_context = processor.query_guidelines(query)
757
+
758
+ yield detected_image_pil, segmented_mask, "βœ… Guidelines queried. Generating final report..."
759
+
760
+ # Generate final report using MedGemma
761
+ final_report = generate_medgemma_report(
762
+ patient_info,
763
+ visual_results,
764
+ guideline_context,
765
+ image_pil=image
766
+ )
767
+
768
+ visual_summary = f"""## πŸ“Š Programmatic Visual Analysis
769
+ | Metric | Result |
770
+ | :--- | :--- |
771
+ | **Detected Wound Type** | {visual_results['wound_type']} |
772
+ | **Estimated Dimensions** | {visual_results['length_cm']}cm x {visual_results['breadth_cm']}cm (Area: {visual_results['surface_area_cm2']}cmΒ²) |
773
+ ---
774
+ """
775
+ final_output_text = visual_summary + "## 🩺 MedHeal-AI Clinical Assessment\n" + final_report
776
+ yield detected_image_pil, segmented_mask, final_output_text
777
+
778
+ except Exception as e:
779
+ logging.error(f"An error occurred during analysis: {e}", exc_info=True)
780
+ yield None, None, f"❌ **An error occurred:** {e}"