SmartHeal committed
Commit 2baa1e7 · verified · parent: ef69ec1

Update src/ai_processor.py

Files changed (1)
  1. src/ai_processor.py +96 -123
src/ai_processor.py CHANGED
@@ -1,6 +1,10 @@
 # smartheal_ai_processor.py
-# Fully functional: robust segmentation + safe overlays + conditional GPU wrapper.
-# All original class/function names preserved. New helpers are additive.
+# Preserves ALL original class/function names.
+# Changes:
+#   - Adds segment_wound(image) with a KMeans fallback
+#   - perform_visual_analysis() now calls segment_wound() for the mask
+#   - Safe overlay (no mask kwarg in addWeighted)
+#   - Conditional @spaces.GPU to avoid the cudaGetDeviceCount crash
 
 import os
 import time
@@ -8,12 +12,12 @@ import logging
 from datetime import datetime
 from typing import Optional, Dict, List, Tuple
 
-# --- quiet tokenizers fork warning (HF) ---
+# Quiet HF tokenizers fork warning
 os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 
 import cv2
 import numpy as np
-from PIL import Image, ImageOps
+from PIL import Image
 from PIL.ExifTags import TAGS
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -26,8 +30,8 @@ YOLO_MODEL_PATH = "src/best.pt"
 SEG_MODEL_PATH = "src/segmentation_model.h5"  # optional
 GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
 DATASET_ID = "SmartHeal/wound-image-uploads"
-DEFAULT_PX_PER_CM = 38.0  # fallback when we cannot calibrate
-PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0  # sanity bounds
+DEFAULT_PX_PER_CM = 38.0
+PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0
 
 models_cache: Dict[str, object] = {}
 knowledge_base_cache: Dict[str, object] = {}
@@ -39,7 +43,7 @@ def _import_ultralytics():
 
 def _import_tf_loader():
     import tensorflow as tf
-    tf.config.set_visible_devices([], "GPU")  # force CPU for TF to avoid CUDA contention
+    tf.config.set_visible_devices([], "GPU")  # force TF CPU
     from tensorflow.keras.models import load_model
     return load_model
 
@@ -63,8 +67,7 @@ def _import_hf_hub():
     from huggingface_hub import HfApi, HfFolder
     return HfApi, HfFolder
 
-# ---------- Conditional Spaces GPU function ----------
-# Avoid scheduling a GPU worker when CUDA is not available (prevents cudaGetDeviceCount crash)
+# ---------- Conditional Spaces GPU wrapper ----------
 def _cuda_available() -> bool:
     try:
         import torch
@@ -81,7 +84,6 @@ def _generate_medgemma_report_core(
 ) -> str:
     try:
         from transformers import pipeline
-        # Use CPU by default; if CUDA truly available, pipeline can still map automatically
        pipe = pipeline(
             "image-text-to-text",
             model="google/medgemma-4b-it",
@@ -123,8 +125,6 @@ def _generate_medgemma_report_core(
         logging.error(f"❌ MedGemma generation error: {e}")
         return "⚠️ GPU/LLM worker unavailable"
 
-# Preserve the SAME public function name.
-# Only decorate with @spaces.GPU if CUDA is truly available.
 try:
     import spaces
     if _cuda_available():
@@ -145,7 +145,6 @@ try:
             image_pil: Image.Image,
             max_new_tokens: Optional[int] = None,
         ) -> str:
-            # no decorator -> no GPU worker init -> no cudaGetDeviceCount crash
             return _generate_medgemma_report_core(patient_info, visual_results, guideline_context, image_pil, max_new_tokens)
 except Exception:
     def generate_medgemma_report(
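
These two hunks are the conditional-decorator fix: the public name generate_medgemma_report is bound to a @spaces.GPU-decorated function only when CUDA is actually present, so no Spaces GPU worker (and no cudaGetDeviceCount call) is ever scheduled on CPU-only hardware. A minimal sketch of the pattern, assuming only that spaces.GPU acts as a plain decorator; _core stands in for _generate_medgemma_report_core:

    def _core(*args, **kwargs):
        return "report"  # stand-in for the real worker

    def _cuda_ok() -> bool:
        try:
            import torch
            return torch.cuda.is_available()
        except Exception:
            return False

    try:
        import spaces
        if _cuda_ok():
            @spaces.GPU                        # decorated -> Spaces schedules a GPU worker
            def generate(*args, **kwargs):
                return _core(*args, **kwargs)
        else:
            def generate(*args, **kwargs):     # same public name, no decorator,
                return _core(*args, **kwargs)  # so no GPU worker is initialized
    except Exception:                          # `spaces` not installed (local runs)
        def generate(*args, **kwargs):
            return _core(*args, **kwargs)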
@@ -289,7 +288,6 @@ def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float
     f35 = _to_float(exif.get("FocalLengthIn35mmFilm") or exif.get("FocalLengthIn35mm"))
     subj_dist_m = _to_float(exif.get("SubjectDistance"))
     sensor_w_mm = _estimate_sensor_width_mm(f_mm, f35)
-
     meta.update({"f_mm": f_mm, "f35": f35, "sensor_w_mm": sensor_w_mm, "distance_m": subj_dist_m})
 
     if f_mm and sensor_w_mm and subj_dist_m and subj_dist_m > 0:
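
The calibration this hunk touches follows standard pinhole-camera arithmetic: a sensor of width sensor_w_mm behind a lens of focal length f_mm sees a scene sensor_w_mm · distance / f_mm wide at the subject distance, and px_per_cm is the image width divided by that scene width in cm. A self-contained sketch of the math using the module's constants (the helper name and exact clamping behavior are illustrative, not the committed code):

    def px_per_cm_from_optics(img_w_px, f_mm, sensor_w_mm, dist_m,
                              lo=5.0, hi=1200.0, default=38.0):
        if not (f_mm and sensor_w_mm and dist_m and dist_m > 0):
            return default                                   # cannot calibrate
        scene_w_mm = sensor_w_mm * (dist_m * 1000.0) / f_mm  # pinhole model
        px_per_cm = img_w_px / (scene_w_mm / 10.0)           # mm -> cm
        return px_per_cm if lo <= px_per_cm <= hi else default

    # A 4000 px wide photo, 4.2 mm lens, 6.17 mm sensor, shot at 0.30 m:
    # scene ≈ 441 mm wide -> ≈ 90.8 px/cm (inside the 5.0–1200.0 sanity bounds)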
@@ -304,64 +302,65 @@ def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float
     except Exception:
         return float(default_px_per_cm), meta
 
-# ---------- Segmentation helpers (additive; names preserved elsewhere) ----------
-def _get_seg_hw(seg_model) -> Tuple[int, int]:
-    shp = getattr(seg_model, "input_shape", None)
-    if shp and len(shp) >= 4:
-        return int(shp[1]), int(shp[2])
-    # try Keras .inputs shape
-    try:
-        shp = seg_model.inputs[0].shape
-        return int(shp[1]), int(shp[2])
-    except Exception:
-        pass
-    raise ValueError(f"Cannot infer (H,W) from segmentation model input shape: {shp}")
-
-def _to_prob(mask_pred: np.ndarray) -> np.ndarray:
-    m = np.array(mask_pred)
-    # squeeze batch/channel dims
-    while m.ndim > 2:
-        if m.shape[0] == 1:
-            m = np.squeeze(m, axis=0)
-        if m.ndim > 2 and m.shape[-1] == 1:
-            m = np.squeeze(m, axis=-1)
-        if m.ndim == 3 and m.shape[-1] > 1:
-            # pick the most active channel
-            ch = np.argmax(m.reshape(-1, m.shape[-1]).mean(0))
-            m = m[..., ch]
-        if m.ndim <= 2:
-            break
-    m = m.astype("float32")
-    # if looks like logits -> sigmoid
-    if m.max() > 1.5 or m.min() < -0.5:
-        m = 1.0 / (1.0 + np.exp(-m))
-    return np.clip(m, 0.0, 1.0)
-
-def _adaptive_threshold(prob: np.ndarray, hard: float = 0.5) -> np.ndarray:
-    if (prob >= hard).sum() > 0:
-        return (prob >= hard).astype("uint8")
-    # try Otsu
-    m8 = (np.clip(prob, 0, 1) * 255).astype("uint8")
-    try:
-        # we only need the threshold value
-        thr, _ = cv2.threshold(m8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-        return (m8 >= thr).astype("uint8")
-    except Exception:
-        p = float(np.percentile(prob, 99.0))
-        return (prob >= max(0.2, min(0.9, p))).astype("uint8")
+# ---------- Segmentation: model + KMeans fallback ----------
+def segment_wound(image: np.ndarray) -> np.ndarray:
+    """
+    Segments the wound from a preprocessed ROI image, falling back to KMeans if the model fails.
+    Returns a uint8 mask in 0..255 with the same HxW as the input image.
+    """
+    segmentation_model = models_cache.get("seg", None)
+
+    if segmentation_model is not None:
+        try:
+            input_size = getattr(segmentation_model, "input_shape", None)
+            if input_size is None or len(input_size) < 3:
+                raise ValueError(f"Bad seg input_shape: {input_size}")
+            H, W = int(input_size[1]), int(input_size[2])  # (None, H, W, C)
+
+            resized = cv2.resize(image, (W, H))             # cv2 takes (W, H)
+            norm = np.expand_dims(resized / 255.0, axis=0)  # (1, H, W, 3)
+            prediction = segmentation_model.predict(norm, verbose=0)
+
+            # Handle models with multiple outputs
+            if isinstance(prediction, list):
+                prediction = prediction[0]
+            # squeeze batch dim if present
+            prediction = prediction[0] if prediction.ndim >= 3 else prediction
+
+            # prediction can be (H, W, 1) or (H, W)
+            pred2d = prediction.squeeze()
+            mask_prob = cv2.resize(pred2d, (image.shape[1], image.shape[0]))  # back to ROI size
+            mask = (mask_prob >= 0.5).astype(np.uint8) * 255
+            if mask.max() == 0:
+                logging.info("Seg model returned an empty mask at 0.5 — keeping as-is (caller falls back to the detection box).")
+            return mask.astype(np.uint8)
+        except Exception as e:
+            logging.warning(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")
+
+    # --- Fallback: color clustering (KMeans, k=2) ---
+    Z = image.reshape((-1, 3)).astype(np.float32)
+    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
+    _K = 2
+    _, labels, centers = cv2.kmeans(Z, _K, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
+    centers = centers.astype(np.uint8).reshape(1, _K, 3)
+    centers_lab = cv2.cvtColor(centers, cv2.COLOR_BGR2LAB)[0]
+    wound_idx = int(np.argmax(centers_lab[:, 1]))  # reddest cluster (a* channel)
+    mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
+    return mask.astype(np.uint8)
 
-def largest_component_mask(binary: np.ndarray, min_area_px: int = 50) -> np.ndarray:
-    num, labels, stats, _ = cv2.connectedComponentsWithStats(binary.astype(np.uint8), connectivity=8)
+# ---------- Measurement + overlay helpers ----------
+def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
+    num, labels, stats, _ = cv2.connectedComponentsWithStats(binary01.astype(np.uint8), connectivity=8)
     if num <= 1:
-        return binary.astype(np.uint8)
+        return binary01.astype(np.uint8)
     areas = stats[1:, cv2.CC_STAT_AREA]
     if areas.size == 0 or areas.max() < min_area_px:
-        return binary.astype(np.uint8)
+        return binary01.astype(np.uint8)
     largest_idx = 1 + int(np.argmax(areas))
     return (labels == largest_idx).astype(np.uint8)
 
-def measure_min_area_rect(mask: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
-    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
+    contours, _ = cv2.findContours(mask01.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
     if not contours:
         return 0.0, 0.0, (None, None)
     cnt = max(contours, key=cv2.contourArea)
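
A note on the fallback above: in CIELAB the a* axis runs green→red, so np.argmax(centers_lab[:, 1]) selects the redder of the two KMeans clusters as wound tissue. A quick smoke test of that path against the committed module (synthetic BGR data; assumes no Keras model is loaded under models_cache["seg"]):

    import numpy as np, cv2
    roi = np.zeros((64, 64, 3), np.uint8); roi[:] = (180, 160, 150)  # skin-ish BGR
    cv2.circle(roi, (32, 32), 12, (40, 40, 200), -1)                 # reddish blob
    mask = segment_wound(roi)                 # uint8 {0, 255}; blob picked via a*
    print(mask.dtype, int((mask > 0).sum()))  # uint8, ~450 pixels on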
@@ -373,8 +372,8 @@ def measure_min_area_rect(mask: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
     box = cv2.boxPoints(rect).astype(int)
     return length_cm, breadth_cm, (box, rect[0])
 
-def count_area_cm2(mask: np.ndarray, px_per_cm: float) -> float:
-    px_count = float(mask.astype(bool).sum())
+def count_area_cm2(mask01: np.ndarray, px_per_cm: float) -> float:
+    px_count = float(mask01.astype(bool).sum())
     return round(px_count / (max(px_per_cm, 1e-6) ** 2), 2)
 
 def draw_measurement_overlay(
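
A unit check for the renamed measurement helpers (assuming, as the surrounding context suggests, that the unshown middle of measure_min_area_rect divides the minAreaRect side lengths by px_per_cm): at 38 px/cm, a filled 190 × 76 px rectangle should come out near 5 cm × 2 cm, and its area exactly 10.0 cm²:

    import numpy as np
    mask01 = np.zeros((300, 300), np.uint8)
    mask01[50:126, 50:240] = 1                            # 76 rows x 190 cols
    L, B, (box, center) = measure_min_area_rect(mask01, px_per_cm=38.0)
    print(L, B)                          # ≈ 5.0, 2.0 (minAreaRect can be ~1 px short)
    print(count_area_cm2(mask01, 38.0))  # 14440 px / 38² = 10.0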
@@ -386,13 +385,11 @@ def draw_measurement_overlay(
     thickness: int = 2
 ) -> np.ndarray:
     overlay = base_bgr.copy()
-    # safe blend: blend once, then gate with mask (no mask kwarg!)
-    colored = np.zeros_like(base_bgr); colored[:] = (0, 0, 255)
-    blended = cv2.addWeighted(overlay, 1.0, colored, 0.3, 0)
+    red = np.zeros_like(overlay); red[:] = (0, 0, 255)
+    blended = cv2.addWeighted(overlay, 1.0, red, 0.3, 0)
     m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
-    blended_masked = cv2.bitwise_and(blended, m3)
-    bg = cv2.bitwise_and(overlay, cv2.bitwise_not(m3))
-    overlay = cv2.add(bg, blended_masked)
+    overlay = cv2.add(cv2.bitwise_and(overlay, cv2.bitwise_not(m3)),
+                      cv2.bitwise_and(blended, m3))
 
     if rect_box is not None:
         cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
@@ -410,15 +407,14 @@ def draw_measurement_overlay(
         cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
         cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)
 
-    draw_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
-    draw_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
-
     def put_label(text, org):
         cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
         cv2.putText(overlay, text, (org[0] + 4, org[1] - 4),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
 
+    draw_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
+    draw_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
     put_label(f"{length_cm:.2f} cm", mids[long_pair[0]])
     put_label(f"{breadth_cm:.2f} cm", mids[short_pair[0]])
     return overlay
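
The blend-then-gate idiom in both overlay hunks exists because cv2.addWeighted has no mask argument: the whole frame is blended once, then bitwise ops composite it back only where the mask is set (background outside, 30% red inside). A minimal self-contained demo of the same compositing:

    import numpy as np, cv2
    img = np.full((4, 4, 3), 100, np.uint8)
    red = np.zeros_like(img); red[:] = (0, 0, 255)
    blended = cv2.addWeighted(img, 1.0, red, 0.3, 0)       # blend everywhere
    m = np.zeros((4, 4), np.uint8); m[:2, :2] = 1          # 0/1 mask
    m3 = np.dstack([m * 255] * 3).astype("uint8")
    out = cv2.add(cv2.bitwise_and(img, cv2.bitwise_not(m3)),
                  cv2.bitwise_and(blended, m3))            # gate into mask only
    assert (out[0, 0] == blended[0, 0]).all() and (out[3, 3] == img[3, 3]).all()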
@@ -439,24 +435,20 @@ class AIProcessor:
 
     def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
         """
-        Detect → crop ROI → (optional) segment cleanup → largest component →
-        oriented minAreaRect in cm (EXIF-calibrated) → save original/detect/seg/annotated.
+        YOLO detect → crop ROI → segment_wound(ROI) → largest component →
+        minAreaRect measurement (cm) using EXIF px/cm → save outputs.
         """
         try:
-            # --- Auto calibration from EXIF ---
             px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
-
-            # Convert PIL to OpenCV BGR
             image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
 
-            # --- Detection (YOLO) ---
+            # --- Detection ---
             det_model = self.models_cache.get("det")
             if det_model is None:
                 raise RuntimeError("YOLO model not loaded")
-
             results = det_model.predict(image_cv, verbose=False, device="cpu")
             if not results or not getattr(results[0], "boxes", None) or len(results[0].boxes) == 0:
-                import gradio as gr  # local import to keep class name intact if gradio missing
+                import gradio as gr
                 raise gr.Error("No wound could be detected.")
 
             box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
@@ -468,36 +460,21 @@ class AIProcessor:
                 import gradio as gr
                 raise gr.Error("Detected ROI is empty.")
 
-            # --- Segmentation (robust) ---
-            seg_model = self.models_cache.get("seg")
-            mask_roi_01 = None
-            if seg_model is not None:
-                try:
-                    H, W = _get_seg_hw(seg_model)  # robust (H,W)
-                    resized = cv2.resize(roi, (W, H))  # cv2.resize expects (W,H)
-                    pred = seg_model.predict(np.expand_dims(resized / 255.0, 0), verbose=0)
-                    prob = _to_prob(pred)  # (H,W) in [0,1]
-                    binmask = _adaptive_threshold(prob, hard=0.5)
-                    # gentle cleanup + largest component
-                    binmask = cv2.morphologyEx(binmask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1)
-                    binmask = cv2.morphologyEx(binmask, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
-                    binmask = largest_component_mask(binmask, min_area_px=30)
-                    # back to ROI size {0,1}
-                    mask_roi_01 = cv2.resize(binmask, (roi.shape[1], roi.shape[0]), interpolation=cv2.INTER_NEAREST).astype(np.uint8)
-                    logging.info(f"seg prob stats: min={prob.min():.4f}, max={prob.max():.4f}, mean={prob.mean():.4f}; on={(mask_roi_01==1).sum()}")
-                except Exception as e:
-                    logging.warning(f"Segmentation failed: {e}")
-                    mask_roi_01 = None
-            else:
-                logging.info("Skipping segmentation (no model).")
+            # --- Segmentation (model + KMeans fallback) ---
+            mask_u8_255 = segment_wound(roi)  # 0..255
+            # Clean up & keep largest component (in 0/1)
+            mask01 = (mask_u8_255 > 127).astype(np.uint8)
+            mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1)
+            mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=1)
+            mask01 = largest_component_mask(mask01, min_area_px=30)
 
             # --- Measurement ---
-            if mask_roi_01 is not None and mask_roi_01.any():
-                length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask_roi_01, px_per_cm)
-                surface_area_cm2 = count_area_cm2(mask_roi_01, px_per_cm)
-                anno_roi = draw_measurement_overlay(roi, mask_roi_01, box_pts, length_cm, breadth_cm)
+            if mask01.any():
+                length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
+                surface_area_cm2 = count_area_cm2(mask01, px_per_cm)
+                anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
             else:
-                # fallback to detection-box cm
+                # fallback to detection box
                 h_px = max(0, y2 - y1); w_px = max(0, x2 - x1)
                 length_cm = round(h_px / px_per_cm, 2)
                 breadth_cm = round(w_px / px_per_cm, 2)
@@ -518,18 +495,14 @@ class AIProcessor:
 
             segmentation_path = None
             annotated_seg_path = None
-            if mask_roi_01 is not None and mask_roi_01.any():
-                # safe masked blend (no mask kwarg to addWeighted)
+            if mask01.any():
                 seg_full = image_cv.copy()
-                roi_overlay = roi.copy()
-                red = np.zeros_like(roi_overlay); red[:] = (0, 0, 255)
-                blended = cv2.addWeighted(roi_overlay, 1.0, red, 0.3, 0)
-                mask_u8 = (mask_roi_01.astype(np.uint8) * 255)
-                mask3 = cv2.merge([mask_u8, mask_u8, mask_u8])
-                blended_masked = cv2.bitwise_and(blended, mask3)
-                roi_bg = cv2.bitwise_and(roi_overlay, cv2.bitwise_not(mask3))
-                roi_overlay = cv2.add(roi_bg, blended_masked)
-
+                # safe masked blend (no mask kwarg)
+                red = np.zeros_like(roi); red[:] = (0, 0, 255)
+                blended = cv2.addWeighted(roi, 1.0, red, 0.3, 0)
+                m3 = np.dstack([mask01 * 255] * 3).astype("uint8")
+                roi_overlay = cv2.add(cv2.bitwise_and(roi, cv2.bitwise_not(m3)),
+                                      cv2.bitwise_and(blended, m3))
                 seg_full[y1:y2, x1:x2] = roi_overlay
                 segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
                 cv2.imwrite(segmentation_path, seg_full)
@@ -568,7 +541,7 @@ class AIProcessor:
             logging.error(f"Visual analysis failed: {e}", exc_info=True)
             raise
 
-    # ---------- Knowledge base and reporting stay unchanged ----------
+    # ---------- Knowledge base + reporting (unchanged names) ----------
     def query_guidelines(self, query: str) -> str:
         try:
             vs = self.knowledge_base_cache.get("vector_store")
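
End to end, the commit leaves the public surface unchanged. A hedged usage sketch (model loading and the constructor are outside this diff, so both are assumed here):

    from PIL import Image

    proc = AIProcessor()   # assumes the det/seg models are already in the caches
    out = proc.perform_visual_analysis(Image.open("wound_photo.jpg"))
    print(out)             # Dict with the measurements and saved image paths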
 