# smartheal_ai_processor.py
# Verbose, instrumented version — preserves public class/function names
# Turn on deep logging: export LOGLEVEL=DEBUG SMARTHEAL_DEBUG=1
import os
import logging
from datetime import datetime
from typing import Optional, Dict, List, Tuple

# ---- Environment defaults (do NOT globally hint CUDA here) ----
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
LOGLEVEL = os.getenv("LOGLEVEL", "INFO").upper()
SMARTHEAL_DEBUG = os.getenv("SMARTHEAL_DEBUG", "0") == "1"

import cv2
import numpy as np
from PIL import Image
from PIL.ExifTags import TAGS
import spaces

# --- Logging config ---
logging.basicConfig(
    level=getattr(logging, LOGLEVEL, logging.INFO),
    format="%(asctime)s - %(levelname)s - %(message)s",
)


def _log_kv(prefix: str, kv: Dict):
    logging.debug(prefix + " | " + " | ".join(f"{k}={v}" for k, v in kv.items()))
# ---- Paths / constants ----
UPLOADS_DIR = "uploads"
os.makedirs(UPLOADS_DIR, exist_ok=True)
HF_TOKEN = os.getenv("HF_TOKEN", None)
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model_fixed.h5"  # optional
GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
DATASET_ID = "SmartHeal/wound-image-uploads"
DEFAULT_PX_PER_CM = 38.0
PX_PER_CM_MIN, PX_PER_CM_MAX = 5.0, 1200.0

# Segmentation preprocessing knobs
SEG_EXPECTS_RGB = os.getenv("SEG_EXPECTS_RGB", "1") == "1"  # most TF models trained on RGB
SEG_NORM = os.getenv("SEG_NORM", "0to1")  # "0to1" | "imagenet"
SEG_THRESH = float(os.getenv("SEG_THRESH", "0.5"))
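# Example override (illustrative values only; match how your segmentation model was trained):
# export SEG_EXPECTS_RGB=1 SEG_NORM=imagenet SEG_THRESH=0.4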
models_cache: Dict[str, object] = {}
knowledge_base_cache: Dict[str, object] = {}
# ---------- Utilities to prevent CUDA in main process ----------
from contextlib import contextmanager


@contextmanager
def _no_cuda_env():
    """
    Mask GPUs so any library imported/constructed in the main process
    cannot see CUDA (required for Spaces Stateless GPU).
    """
    prev = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    try:
        yield
    finally:
        if prev is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = prev
# ---------- Lazy imports (wrapped where needed) ----------
def _import_ultralytics():
    # Prevent Ultralytics from probing CUDA on import
    with _no_cuda_env():
        from ultralytics import YOLO
    return YOLO


def _import_tf_loader():
    import tensorflow as tf
    tf.config.set_visible_devices([], "GPU")
    from tensorflow.keras.models import load_model
    return load_model


def _import_hf_cls():
    from transformers import pipeline
    return pipeline


def _import_embeddings():
    from langchain_community.embeddings import HuggingFaceEmbeddings
    return HuggingFaceEmbeddings


def _import_langchain_pdf():
    from langchain_community.document_loaders import PyPDFLoader
    return PyPDFLoader


def _import_langchain_faiss():
    from langchain_community.vectorstores import FAISS
    return FAISS


def _import_hf_hub():
    from huggingface_hub import HfApi, HfFolder
    return HfApi, HfFolder
# ---------- SmartHeal prompts (system + user prefix) ----------
SMARTHEAL_SYSTEM_PROMPT = """\
You are SmartHeal Clinical Assistant, a wound-care decision-support system.
You analyze wound photographs and brief patient context to produce careful,
specific, guideline-informed recommendations WITHOUT diagnosing. You always:
- Use the measurements calculated by the vision pipeline as ground truth.
- Prefer concise, actionable steps tailored to exudate level, infection risk, and pain.
- Flag uncertainties and red flags that need escalation to a clinician.
- Avoid contraindicated advice; do not infer unseen comorbidities.
- Keep under 300 words and use the requested headings exactly.
- Tone: professional, clear, and conservative; no definitive medical claims.
- Safety: remind the user to seek clinician review for changes or red flags.
"""

SMARTHEAL_USER_PREFIX = """\
Patient: {patient_info}
Visual findings: type={wound_type}, size={length_cm}x{breadth_cm} cm, area={area_cm2} cm^2,
detection_conf={det_conf:.2f}, calibration={px_per_cm} px/cm.
Guideline context (snippets you can draw principles from; do not quote at length):
{guideline_context}
Write a structured answer with these headings exactly:
1. Clinical Summary (max 4 bullet points)
2. Likely Stage/Type (if uncertain, say 'uncertain')
3. Treatment Plan (specific dressing choices and frequency based on exudate/infection risk)
4. Red Flags (what to escalate and when)
5. Follow-up Cadence (days)
6. Notes (assumptions/uncertainties)
Keep to 220–300 words. Do NOT provide diagnosis. Avoid contraindicated advice.
"""
# Assumed from the docstring below: this must execute in the Spaces (ZeroGPU) GPU worker,
# which is what the `spaces.GPU` decorator provides.
@spaces.GPU
def _vlm_infer_gpu(messages, model_id: str, max_new_tokens: int, token: Optional[str]):
    """
    Runs entirely inside a Spaces GPU worker. It's the ONLY place we allow CUDA init.
    Safe for:
      - CUDA device selection (no 'Invalid device id')
      - BF16/FP16 choice via compute capability
      - LLaVA processors with patch_size=None
      - Processors WITHOUT a chat template (fallback to plain/LLaVA-style prompt)
    """
    import logging
    import torch
    from typing import Optional, List
    from transformers import (
        AutoProcessor,
        AutoModelForVision2Seq,
        StoppingCriteria,
        StoppingCriteriaList,
    )

    # -------- Device & dtype (robust) --------
    def _pick_device_and_dtype():
        if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
            logging.warning("CUDA not available; using CPU.")
            return "cpu", torch.float32
        idx = 0
        try:
            torch.cuda.set_device(idx)
        except Exception as e:
            logging.warning(f"torch.cuda.set_device({idx}) failed: {e}; falling back to CPU.")
            return "cpu", torch.float32
        device = f"cuda:{idx}"
        try:
            props = torch.cuda.get_device_properties(idx)
            cc = props.major * 10 + props.minor
            dtype = torch.bfloat16 if cc >= 80 else torch.float16
        except Exception as e:
            logging.warning(f"Could not query CUDA props: {e}; defaulting to float16.")
            dtype = torch.float16
        return device, dtype

    device, torch_dtype = _pick_device_and_dtype()

    # -------- Load model & processor --------
    model = AutoModelForVision2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        token=token,
    ).to(device)
    model.eval()
    processor = AutoProcessor.from_pretrained(
        model_id, trust_remote_code=True, token=token
    )
    # -------- Extract image & text --------
    image_obj = None
    text_prompt = ""
    for m in messages:
        if m.get("role") == "user":
            for c in m.get("content", []):
                if c.get("type") == "image":
                    image_obj = c.get("image")
                elif c.get("type") == "text":
                    text_prompt = c.get("text", "")
            break
    if image_obj is None:
        raise ValueError("No image found in messages for VLM inference.")

    # -------- Normalize image to PIL --------
    from PIL import Image
    import numpy as np

    def _to_pil(x):
        if isinstance(x, Image.Image):
            return x.convert("RGB")
        if isinstance(x, str):
            return Image.open(x).convert("RGB")
        if isinstance(x, np.ndarray):
            if x.ndim == 2:
                x = np.stack([x] * 3, axis=-1)
            if x.dtype != np.uint8:
                x = x.astype(np.uint8)
            return Image.fromarray(x, "RGB")
        if hasattr(x, "read"):
            return Image.open(x).convert("RGB")
        raise TypeError(f"Unsupported image type: {type(x)}")

    image_pil = _to_pil(image_obj)

    # -------- Ensure patch_size for LLaVA processors --------
    def _ensure_patch_size(proc, mdl):
        ps = getattr(proc, "patch_size", None)
        if not ps:
            candidates = [
                getattr(getattr(mdl, "vision_tower", None), "config", None),
                getattr(mdl.config, "vision_config", None),
                getattr(proc, "image_processor", None),
                getattr(getattr(proc, "image_processor", None), "config", None),
            ]
            for obj in candidates:
                if obj is None:
                    continue
                maybe = getattr(obj, "patch_size", None)
                if maybe:
                    ps = int(maybe)
                    break
        if not ps:
            ps = 14  # safe default for ViT-L/14-style
        try:
            setattr(proc, "patch_size", ps)
        except Exception:
            pass
        return ps

    _ensure_patch_size(processor, model)
    # -------- Build text (chat-template only if it truly exists) --------
    # Some processors expose apply_chat_template but tokenizer has no template → ValueError. Guard it.
    tokenizer = getattr(processor, "tokenizer", None)
    has_template = bool(getattr(tokenizer, "chat_template", None))
    used_chat_template = False

    def _looks_like_llava():
        name = processor.__class__.__name__.lower()
        mid = (model_id or "").lower()
        return ("llava" in name) or ("llava" in mid)

    if hasattr(processor, "apply_chat_template") and has_template:
        try:
            chat = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image_pil},
                    {"type": "text", "text": text_prompt or "Describe the image."},
                ],
            }]
            text_for_model = processor.apply_chat_template(
                chat, add_generation_prompt=True, tokenize=False
            )
            used_chat_template = True
        except Exception as e:
            logging.info(f"No usable chat template ({e}); falling back to plain prompt.")
            text_for_model = (
                f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
                if _looks_like_llava() else (text_prompt or "Describe the image.")
            )
    else:
        text_for_model = (
            f"USER: <image>\n{text_prompt or 'Describe the image.'}\nASSISTANT:"
            if _looks_like_llava() else (text_prompt or "Describe the image.")
        )

    # -------- Tokenize --------
    inputs = processor(
        text=[text_for_model],
        images=[image_pil],
        return_tensors="pt",
        padding=True,
    ).to(device)
    # -------- Stopping criteria --------
    class EosTokenCriteria(StoppingCriteria):
        def __init__(self, eos_token_ids: List[int]):
            import torch as _t
            self.eos = _t.tensor(eos_token_ids, dtype=_t.long)

        def __call__(self, input_ids, scores, **kwargs) -> bool:
            import torch as _t
            last_tok = input_ids[:, -1]
            return _t.isin(last_tok, self.eos.to(last_tok.device)).any().item()

    eos_ids: List[int] = []
    if tokenizer is not None:
        for attr in ("eos_token_id", "eot_token_id"):
            v = getattr(tokenizer, attr, None)
            if v is None:
                continue
            eos_ids.extend([v] if isinstance(v, int) else list(v))
    if not eos_ids:
        cfg = getattr(model, "generation_config", None)
        if cfg and getattr(cfg, "eos_token_id", None) is not None:
            eos_ids = [cfg.eos_token_id]
        else:
            eos_ids = [2]
    stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_ids)])

    if tokenizer is not None and getattr(tokenizer, "pad_token_id", None) is None:
        try:
            tokenizer.pad_token_id = eos_ids[0]
        except Exception:
            pass

    # -------- Generate --------
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens or 256),
        do_sample=False,
        stopping_criteria=stopping_criteria,
        eos_token_id=eos_ids[0] if eos_ids else None,
        pad_token_id=getattr(tokenizer, "pad_token_id", None) if tokenizer else None,
    )
    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # -------- Decode --------
    seq = out[0]
    if "input_ids" in inputs:
        cut = inputs["input_ids"].shape[-1]
        seq = seq[cut:]
    if tokenizer is not None:
        text_out = tokenizer.decode(seq, skip_special_tokens=True)
    elif hasattr(processor, "batch_decode"):
        text_out = processor.batch_decode(seq.unsqueeze(0), skip_special_tokens=True)[0]
    else:
        text_out = str(seq.tolist())
    return text_out.strip()
def generate_medgemma_report(
    patient_info: str,
    visual_results: Dict,
    guideline_context: str,
    image_pil: Image.Image,
    max_new_tokens: Optional[int] = None,
) -> str:
    """
    MedGemma replacement using a vision-language model.
    Loads & runs ONLY inside a GPU worker to satisfy Stateless GPU constraints.
    """
    if os.getenv("SMARTHEAL_ENABLE_VLM", "1") != "1":
        return "⚠️ VLM disabled"
    model_id = os.getenv("SMARTHEAL_VLM_MODEL", "bczhou/tiny-llava-v1-hf")
    max_new_tokens = max_new_tokens or int(os.getenv("SMARTHEAL_VLM_MAX_TOKENS", "600"))
    uprompt = SMARTHEAL_USER_PREFIX.format(
        patient_info=patient_info,
        wound_type=visual_results.get("wound_type", "Unknown"),
        length_cm=visual_results.get("length_cm", 0),
        breadth_cm=visual_results.get("breadth_cm", 0),
        area_cm2=visual_results.get("surface_area_cm2", 0),
        det_conf=float(visual_results.get("detection_confidence", 0.0)),
        px_per_cm=visual_results.get("px_per_cm", "?"),
        guideline_context=(guideline_context or "")[:900],
    )
    # The `messages` structure is passed to the verified `_vlm_infer_gpu` function
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SMARTHEAL_SYSTEM_PROMPT}]},
        {"role": "user", "content": [
            {"type": "image", "image": image_pil},
            {"type": "text", "text": uprompt},
        ]},
    ]
    try:
        return _vlm_infer_gpu(messages, model_id, max_new_tokens, HF_TOKEN)
    except Exception as e:
        logging.error(f"VLM call failed: {e}", exc_info=True)
        return f"⚠️ VLM error: {e}"
# ---------- Initialize CPU models ----------
def load_yolo_model():
    YOLO = _import_ultralytics()
    # Construct model with CUDA masked to avoid auto-selecting cuda:0
    with _no_cuda_env():
        model = YOLO(YOLO_MODEL_PATH)
    return model


def load_segmentation_model():
    import tensorflow as tf
    load_model = _import_tf_loader()
    return load_model(SEG_MODEL_PATH, compile=False, custom_objects={'InputLayer': tf.keras.layers.InputLayer})


def load_classification_pipeline():
    pipe = _import_hf_cls()
    return pipe("image-classification", model="Hemg/Wound-classification", token=HF_TOKEN, device="cpu")


def load_embedding_model():
    Emb = _import_embeddings()
    return Emb(model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": "cpu"})


def initialize_cpu_models() -> None:
    if HF_TOKEN:
        try:
            HfApi, HfFolder = _import_hf_hub()
            HfFolder.save_token(HF_TOKEN)
            logging.info("✅ HF token set")
        except Exception as e:
            logging.warning(f"HF token save failed: {e}")
    if "det" not in models_cache:
        try:
            models_cache["det"] = load_yolo_model()
            logging.info("✅ YOLO loaded (CPU; CUDA masked in main)")
        except Exception as e:
            logging.error(f"YOLO load failed: {e}")
    if "seg" not in models_cache:
        try:
            if os.path.exists(SEG_MODEL_PATH):
                models_cache["seg"] = load_segmentation_model()
                m = models_cache["seg"]
                ishape = getattr(m, "input_shape", None)
                oshape = getattr(m, "output_shape", None)
                logging.info(f"✅ Segmentation model loaded (CPU) | input_shape={ishape} output_shape={oshape}")
            else:
                models_cache["seg"] = None
                logging.warning("Segmentation model file missing; skipping.")
        except Exception as e:
            models_cache["seg"] = None
            logging.warning(f"Segmentation unavailable: {e}")
    if "cls" not in models_cache:
        try:
            models_cache["cls"] = load_classification_pipeline()
            logging.info("✅ Classifier loaded (CPU)")
        except Exception as e:
            models_cache["cls"] = None
            logging.warning(f"Classifier unavailable: {e}")
    if "embedding_model" not in models_cache:
        try:
            models_cache["embedding_model"] = load_embedding_model()
            logging.info("✅ Embeddings loaded (CPU)")
        except Exception as e:
            models_cache["embedding_model"] = None
            logging.warning(f"Embeddings unavailable: {e}")
def setup_knowledge_base() -> None:
    if "vector_store" in knowledge_base_cache:
        return
    docs: List = []
    try:
        PyPDFLoader = _import_langchain_pdf()
        for pdf in GUIDELINE_PDFS:
            if os.path.exists(pdf):
                try:
                    docs.extend(PyPDFLoader(pdf).load())
                    logging.info(f"Loaded PDF: {pdf}")
                except Exception as e:
                    logging.warning(f"PDF load failed ({pdf}): {e}")
    except Exception as e:
        logging.warning(f"LangChain PDF loader unavailable: {e}")
    if docs and models_cache.get("embedding_model"):
        try:
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            FAISS = _import_langchain_faiss()
            chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
            knowledge_base_cache["vector_store"] = FAISS.from_documents(chunks, models_cache["embedding_model"])
            logging.info(f"✅ Knowledge base ready ({len(chunks)} chunks)")
        except Exception as e:
            knowledge_base_cache["vector_store"] = None
            logging.warning(f"KB build failed: {e}")
    else:
        knowledge_base_cache["vector_store"] = None
        logging.warning("KB disabled (no docs or embeddings).")


initialize_cpu_models()
setup_knowledge_base()
# ---------- Calibration helpers ----------
def _exif_to_dict(pil_img: Image.Image) -> Dict[str, object]:
    out = {}
    try:
        exif = pil_img.getexif()
        if not exif:
            return out
        for k, v in exif.items():
            tag = TAGS.get(k, k)
            out[tag] = v
    except Exception:
        pass
    return out


def _to_float(val) -> Optional[float]:
    try:
        if val is None:
            return None
        if isinstance(val, tuple) and len(val) == 2:
            num, den = float(val[0]), (float(val[1]) if float(val[1]) != 0 else 1.0)
            return num / den
        return float(val)
    except Exception:
        return None


def _estimate_sensor_width_mm(f_mm: Optional[float], f35: Optional[float]) -> Optional[float]:
    if f_mm and f35 and f35 > 0:
        return 36.0 * f_mm / f35
    return None


def estimate_px_per_cm_from_exif(pil_img: Image.Image, default_px_per_cm: float = DEFAULT_PX_PER_CM) -> Tuple[float, Dict]:
    meta = {"used": "default", "f_mm": None, "f35": None, "sensor_w_mm": None, "distance_m": None}
    try:
        exif = _exif_to_dict(pil_img)
        f_mm = _to_float(exif.get("FocalLength"))
        f35 = _to_float(exif.get("FocalLengthIn35mmFilm") or exif.get("FocalLengthIn35mm"))
        subj_dist_m = _to_float(exif.get("SubjectDistance"))
        sensor_w_mm = _estimate_sensor_width_mm(f_mm, f35)
        meta.update({"f_mm": f_mm, "f35": f35, "sensor_w_mm": sensor_w_mm, "distance_m": subj_dist_m})
        if f_mm and sensor_w_mm and subj_dist_m and subj_dist_m > 0:
            w_px = pil_img.width
            field_w_mm = sensor_w_mm * (subj_dist_m * 1000.0) / f_mm
            field_w_cm = field_w_mm / 10.0
            px_per_cm = w_px / max(field_w_cm, 1e-6)
            px_per_cm = float(np.clip(px_per_cm, PX_PER_CM_MIN, PX_PER_CM_MAX))
            meta["used"] = "exif"
            return px_per_cm, meta
        return float(default_px_per_cm), meta
    except Exception:
        return float(default_px_per_cm), meta
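# Worked example of the pinhole estimate above (illustrative numbers, not from a real photo):
# f_mm=4.2 and f35=26 → sensor_w_mm ≈ 36*4.2/26 ≈ 5.8 mm; at SubjectDistance=0.30 m the field of
# view is ≈ 5.8*300/4.2 ≈ 415 mm ≈ 41.5 cm, so a 4000 px wide frame gives ≈ 96 px/cm.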
# ---------- Segmentation helpers ----------
def _imagenet_norm(arr: np.ndarray) -> np.ndarray:
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    return (arr.astype(np.float32) - mean) / std


def _preprocess_for_seg(bgr_roi: np.ndarray, target_hw: Tuple[int, int]) -> np.ndarray:
    H, W = target_hw
    resized = cv2.resize(bgr_roi, (W, H), interpolation=cv2.INTER_LINEAR)
    if SEG_EXPECTS_RGB:
        resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    if SEG_NORM.lower() == "imagenet":
        x = _imagenet_norm(resized)
    else:
        x = resized.astype(np.float32) / 255.0
    x = np.expand_dims(x, axis=0)  # (1,H,W,3)
    return x


def _to_prob(pred: np.ndarray) -> np.ndarray:
    p = np.squeeze(pred)
    pmin, pmax = float(p.min()), float(p.max())
    if pmax > 1.0 or pmin < 0.0:
        p = 1.0 / (1.0 + np.exp(-p))
    return p.astype(np.float32)
# ---- Adaptive threshold + GrabCut grow ----
def _adaptive_prob_threshold(p: np.ndarray) -> float:
    """
    Choose a threshold that avoids tiny blobs while not swallowing skin.
    Try Otsu and the 90th percentile, clamp to [0.25, 0.65], pick by area heuristic.
    """
    p01 = np.clip(p.astype(np.float32), 0, 1)
    p255 = (p01 * 255).astype(np.uint8)
    ret_otsu, _ = cv2.threshold(p255, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    thr_otsu = float(np.clip(ret_otsu / 255.0, 0.25, 0.65))
    thr_pctl = float(np.clip(np.percentile(p01, 90), 0.25, 0.65))

    def area_frac(thr: float) -> float:
        return float((p01 >= thr).sum()) / float(p01.size)

    af_otsu = area_frac(thr_otsu)
    af_pctl = area_frac(thr_pctl)

    def score(af: float) -> float:
        target_low, target_high = 0.03, 0.10
        if af < target_low:
            return abs(af - target_low) * 3.0
        if af > target_high:
            return abs(af - target_high) * 1.5
        return 0.0

    return thr_otsu if score(af_otsu) <= score(af_pctl) else thr_pctl
def _grabcut_refine(bgr: np.ndarray, seed01: np.ndarray, iters: int = 3) -> np.ndarray:
    """Grow from a confident core into low-contrast margins."""
    h, w = bgr.shape[:2]
    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    seed_dil = cv2.dilate(seed01, k, iterations=1)
    # Dilated ring is only "probably" foreground; the confident core is definite foreground.
    gc[seed_dil.astype(bool)] = cv2.GC_PR_FGD
    gc[seed01.astype(bool)] = cv2.GC_FGD
    # First/last row and first/last column are definite background.
    gc[0, :], gc[-1, :], gc[:, 0], gc[:, -1] = cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD, cv2.GC_BGD
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    cv2.grabCut(bgr, gc, None, bgdModel, fgdModel, iters, cv2.GC_INIT_WITH_MASK)
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 1, 0).astype(np.uint8)
def _fill_holes(mask01: np.ndarray) -> np.ndarray:
    h, w = mask01.shape[:2]
    ff = np.zeros((h + 2, w + 2), np.uint8)
    m = (mask01 * 255).astype(np.uint8).copy()
    cv2.floodFill(m, ff, (0, 0), 255)
    m_inv = cv2.bitwise_not(m)
    out = ((mask01 * 255) | m_inv) // 255
    return out.astype(np.uint8)


def _clean_mask(mask01: np.ndarray) -> np.ndarray:
    """Open → Close → Fill holes → Largest component (no dilation)."""
    mask01 = (mask01 > 0).astype(np.uint8)
    k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask01 = cv2.morphologyEx(mask01, cv2.MORPH_OPEN, k3, iterations=1)
    mask01 = cv2.morphologyEx(mask01, cv2.MORPH_CLOSE, k5, iterations=1)
    mask01 = _fill_holes(mask01)
    # Keep largest component only
    num, labels, stats, _ = cv2.connectedComponentsWithStats(mask01, 8)
    if num > 1:
        areas = stats[1:, cv2.CC_STAT_AREA]
        if areas.size:
            largest_idx = 1 + int(np.argmax(areas))
            mask01 = (labels == largest_idx).astype(np.uint8)
    return (mask01 > 0).astype(np.uint8)


# Global last debug dict (per-process)
_last_seg_debug: Dict[str, object] = {}
def segment_wound(image_bgr: np.ndarray, ts: str, out_dir: str) -> Tuple[np.ndarray, Dict[str, object]]:
    """
    TF model → adaptive threshold on prob → GrabCut grow → cleanup.
    Fallback: KMeans-Lab.
    Returns (mask_uint8_0_255, debug_dict)
    """
    debug = {"used": None, "reason": None, "positive_fraction": 0.0,
             "thr": None, "heatmap_path": None, "roi_seen_by_model": None}
    seg_model = models_cache.get("seg", None)
    # --- Model path ---
    if seg_model is not None:
        try:
            ishape = getattr(seg_model, "input_shape", None)
            if not ishape or len(ishape) < 4:
                raise ValueError(f"Bad seg input_shape: {ishape}")
            th, tw = int(ishape[1]), int(ishape[2])
            x = _preprocess_for_seg(image_bgr, (th, tw))
            roi_seen_path = None
            if SMARTHEAL_DEBUG:
                roi_seen_path = os.path.join(out_dir, f"roi_for_seg_{ts}.png")
                cv2.imwrite(roi_seen_path, image_bgr)
            pred = seg_model.predict(x, verbose=0)
            if isinstance(pred, (list, tuple)):
                pred = pred[0]
            p = _to_prob(pred)
            p = cv2.resize(p, (image_bgr.shape[1], image_bgr.shape[0]), interpolation=cv2.INTER_LINEAR)
            heatmap_path = None
            if SMARTHEAL_DEBUG:
                hm = (np.clip(p, 0, 1) * 255).astype(np.uint8)
                heat = cv2.applyColorMap(hm, cv2.COLORMAP_JET)
                heatmap_path = os.path.join(out_dir, f"seg_pred_heatmap_{ts}.png")
                cv2.imwrite(heatmap_path, heat)
            thr = _adaptive_prob_threshold(p)
            core01 = (p >= thr).astype(np.uint8)
            core_frac = float(core01.sum()) / float(core01.size)
            if core_frac < 0.005:
                thr2 = max(thr - 0.10, 0.15)
                core01 = (p >= thr2).astype(np.uint8)
                thr = thr2
                core_frac = float(core01.sum()) / float(core01.size)
            if core01.any():
                gc01 = _grabcut_refine(image_bgr, core01, iters=3)
                mask01 = _clean_mask(gc01)
            else:
                mask01 = np.zeros(core01.shape, np.uint8)
            pos_frac = float(mask01.sum()) / float(mask01.size)
            logging.info(f"SegModel USED | thr={float(thr):.2f} core_frac={core_frac:.4f} final_frac={pos_frac:.4f}")
            debug.update({
                "used": "tf_model",
                "reason": "ok",
                "positive_fraction": pos_frac,
                "thr": float(thr),
                "heatmap_path": heatmap_path,
                "roi_seen_by_model": roi_seen_path,
            })
            return (mask01 * 255).astype(np.uint8), debug
        except Exception as e:
            logging.warning(f"⚠️ Segmentation model failed → fallback. Reason: {e}")
            debug.update({"used": "fallback_kmeans", "reason": f"model_failed: {e}"})
    # --- Fallback: KMeans in Lab (reddest cluster as wound) ---
    Z = image_bgr.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
    centers_u8 = centers.astype(np.uint8).reshape(1, 2, 3)
    centers_lab = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2LAB)[0]
    wound_idx = int(np.argmax(centers_lab[:, 1]))  # maximize a* (red)
    mask01 = (labels.reshape(image_bgr.shape[:2]) == wound_idx).astype(np.uint8)
    mask01 = _clean_mask(mask01)
    pos_frac = float(mask01.sum()) / float(mask01.size)
    logging.info(f"KMeans USED | final_frac={pos_frac:.4f}")
    debug.update({
        "used": "fallback_kmeans",
        "reason": debug.get("reason") or "no_model",
        "positive_fraction": pos_frac,
        "thr": None,
    })
    return (mask01 * 255).astype(np.uint8), debug
# ---------- Measurement + overlay helpers ----------
def largest_component_mask(binary01: np.ndarray, min_area_px: int = 50) -> np.ndarray:
    num, labels, stats, _ = cv2.connectedComponentsWithStats(binary01.astype(np.uint8), connectivity=8)
    if num <= 1:
        return binary01.astype(np.uint8)
    areas = stats[1:, cv2.CC_STAT_AREA]
    if areas.size == 0 or areas.max() < min_area_px:
        return binary01.astype(np.uint8)
    largest_idx = 1 + int(np.argmax(areas))
    return (labels == largest_idx).astype(np.uint8)


def measure_min_area_rect(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, float, Tuple]:
    contours, _ = cv2.findContours(mask01.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return 0.0, 0.0, (None, None)
    cnt = max(contours, key=cv2.contourArea)
    rect = cv2.minAreaRect(cnt)
    (w_px, h_px) = rect[1]
    length_px, breadth_px = (max(w_px, h_px), min(w_px, h_px))
    length_cm = round(length_px / max(px_per_cm, 1e-6), 2)
    breadth_cm = round(breadth_px / max(px_per_cm, 1e-6), 2)
    box = cv2.boxPoints(rect).astype(int)
    return length_cm, breadth_cm, (box, rect[0])


def area_cm2_from_contour(mask01: np.ndarray, px_per_cm: float) -> Tuple[float, Optional[np.ndarray]]:
    """Area from largest polygon (sub-pixel); returns (area_cm2, contour)."""
    m = (mask01 > 0).astype(np.uint8)
    contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return 0.0, None
    cnt = max(contours, key=cv2.contourArea)
    poly_area_px2 = float(cv2.contourArea(cnt))
    area_cm2 = round(poly_area_px2 / (max(px_per_cm, 1e-6) ** 2), 2)
    return area_cm2, cnt


def clamp_area_with_minrect(cnt: np.ndarray, px_per_cm: float, area_cm2_poly: float) -> float:
    rect = cv2.minAreaRect(cnt)
    (w_px, h_px) = rect[1]
    rect_area_px2 = float(max(w_px, 0.0) * max(h_px, 0.0))
    rect_area_cm2 = rect_area_px2 / (max(px_per_cm, 1e-6) ** 2)
    return round(min(area_cm2_poly, rect_area_cm2 * 1.05), 2)
def draw_measurement_overlay(
    base_bgr: np.ndarray,
    mask01: np.ndarray,
    rect_box: np.ndarray,
    length_cm: float,
    breadth_cm: float,
    thickness: int = 2
) -> np.ndarray:
    """
    1) Strong red mask overlay + white contour
    2) Min-area rectangle
    3) Double-headed arrows labeled Length/Width
    """
    overlay = base_bgr.copy()
    # Mask tint
    mask255 = (mask01 * 255).astype(np.uint8)
    mask3 = cv2.merge([mask255, mask255, mask255])
    red = np.zeros_like(overlay)
    red[:] = (0, 0, 255)
    alpha = 0.55
    tinted = cv2.addWeighted(overlay, 1 - alpha, red, alpha, 0)
    overlay = np.where(mask3 > 0, tinted, overlay)
    # Contour
    cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if cnts:
        cv2.drawContours(overlay, cnts, -1, (255, 255, 255), 2)
    if rect_box is not None:
        cv2.polylines(overlay, [rect_box], True, (255, 255, 255), thickness)
        pts = rect_box.reshape(-1, 2)

        def midpoint(a, b):
            return (int((a[0] + b[0]) / 2), int((a[1] + b[1]) / 2))

        e = [np.linalg.norm(pts[i] - pts[(i + 1) % 4]) for i in range(4)]
        long_edge_idx = int(np.argmax(e))
        mids = [midpoint(pts[i], pts[(i + 1) % 4]) for i in range(4)]
        long_pair = (long_edge_idx, (long_edge_idx + 2) % 4)
        short_pair = ((long_edge_idx + 1) % 4, (long_edge_idx + 3) % 4)

        def draw_double_arrow(img, p1, p2):
            cv2.arrowedLine(img, p1, p2, (0, 0, 0), thickness + 2, tipLength=0.05)
            cv2.arrowedLine(img, p2, p1, (0, 0, 0), thickness + 2, tipLength=0.05)
            cv2.arrowedLine(img, p1, p2, (255, 255, 255), thickness, tipLength=0.05)
            cv2.arrowedLine(img, p2, p1, (255, 255, 255), thickness, tipLength=0.05)

        def put_label(text, anchor):
            org = (anchor[0] + 6, anchor[1] - 6)
            cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 4, cv2.LINE_AA)
            cv2.putText(overlay, text, org, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
        # The arrow spanning the long dimension connects the midpoints of the two SHORT edges
        # (midpoints of opposite long edges are separated by the breadth), so label accordingly.
        draw_double_arrow(overlay, mids[short_pair[0]], mids[short_pair[1]])
        draw_double_arrow(overlay, mids[long_pair[0]], mids[long_pair[1]])
        put_label(f"Length: {length_cm:.2f} cm", mids[short_pair[0]])
        put_label(f"Width: {breadth_cm:.2f} cm", mids[long_pair[0]])
    return overlay
# ---------- AI PROCESSOR ----------
class AIProcessor:
    def __init__(self):
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    def _ensure_analysis_dir(self) -> str:
        out_dir = os.path.join(self.uploads_dir, "analysis")
        os.makedirs(out_dir, exist_ok=True)
        return out_dir
    def perform_visual_analysis(self, image_pil: Image.Image) -> Dict:
        """
        YOLO detect → crop ROI → segment_wound(ROI) → clean mask →
        minAreaRect measurement (cm) using EXIF px/cm → save outputs.
        """
        try:
            px_per_cm, exif_meta = estimate_px_per_cm_from_exif(image_pil, DEFAULT_PX_PER_CM)
            # Guardrails for calibration to avoid huge area blow-ups
            px_per_cm = float(np.clip(px_per_cm, 20.0, 350.0))
            if (exif_meta or {}).get("used") != "exif":
                logging.warning(f"Calibration fallback used: px_per_cm={px_per_cm:.2f} (default). Prefer ruler/Aruco for accuracy.")
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            # --- Detection ---
            det_model = self.models_cache.get("det")
            if det_model is None:
                raise RuntimeError("YOLO model not loaded")
            # Force CPU inference and avoid CUDA touch
            results = det_model.predict(image_cv, verbose=False, device="cpu")
            if (not results) or (not getattr(results[0], "boxes", None)) or (len(results[0].boxes) == 0):
                try:
                    import gradio as gr
                    raise gr.Error("No wound could be detected.")
                except Exception:
                    raise RuntimeError("No wound could be detected.")
            box = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
            x1, y1, x2, y2 = [int(v) for v in box]
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(image_cv.shape[1], x2), min(image_cv.shape[0], y2)
            roi = image_cv[y1:y2, x1:x2].copy()
            if roi.size == 0:
                try:
                    import gradio as gr
                    raise gr.Error("Detected ROI is empty.")
                except Exception:
                    raise RuntimeError("Detected ROI is empty.")
            out_dir = self._ensure_analysis_dir()
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")

            # --- Segmentation (model-first + KMeans fallback) ---
            mask_u8_255, seg_debug = segment_wound(roi, ts, out_dir)
            mask01 = (mask_u8_255 > 127).astype(np.uint8)
            if mask01.any():
                mask01 = _clean_mask(mask01)
                logging.debug(f"Mask postproc: px_after={int(mask01.sum())}")

            # --- Measurement (accurate & conservative) ---
            if mask01.any():
                length_cm, breadth_cm, (box_pts, _) = measure_min_area_rect(mask01, px_per_cm)
                area_poly_cm2, largest_cnt = area_cm2_from_contour(mask01, px_per_cm)
                if largest_cnt is not None:
                    surface_area_cm2 = clamp_area_with_minrect(largest_cnt, px_per_cm, area_poly_cm2)
                else:
                    surface_area_cm2 = area_poly_cm2
                anno_roi = draw_measurement_overlay(roi, mask01, box_pts, length_cm, breadth_cm)
                segmentation_empty = False
            else:
                # Fallback if seg failed: use ROI dimensions
                h_px = max(0, y2 - y1)
                w_px = max(0, x2 - x1)
                length_cm = round(max(h_px, w_px) / px_per_cm, 2)
                breadth_cm = round(min(h_px, w_px) / px_per_cm, 2)
                surface_area_cm2 = round((h_px * w_px) / (px_per_cm ** 2), 2)
                anno_roi = roi.copy()
                cv2.rectangle(anno_roi, (2, 2), (anno_roi.shape[1] - 3, anno_roi.shape[0] - 3), (0, 0, 255), 3)
                cv2.line(anno_roi, (0, 0), (anno_roi.shape[1] - 1, anno_roi.shape[0] - 1), (0, 0, 255), 2)
                cv2.line(anno_roi, (anno_roi.shape[1] - 1, 0), (0, anno_roi.shape[0] - 1), (0, 0, 255), 2)
                box_pts = None
                segmentation_empty = True

            # --- Save visualizations ---
            original_path = os.path.join(out_dir, f"original_{ts}.png")
            cv2.imwrite(original_path, image_cv)
            det_vis = image_cv.copy()
            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
            detection_path = os.path.join(out_dir, f"detection_{ts}.png")
            cv2.imwrite(detection_path, det_vis)
            roi_mask_path = os.path.join(out_dir, f"roi_mask_{ts}.png")
            cv2.imwrite(roi_mask_path, (mask01 * 255).astype(np.uint8))
            # ROI overlay (mask tint + contour, without arrows)
            mask255 = (mask01 * 255).astype(np.uint8)
            mask3 = cv2.merge([mask255, mask255, mask255])
            red = np.zeros_like(roi)
            red[:] = (0, 0, 255)
            alpha = 0.55
            tinted = cv2.addWeighted(roi, 1 - alpha, red, alpha, 0)
            if mask255.any():
                roi_overlay = np.where(mask3 > 0, tinted, roi)
                cnts, _ = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(roi_overlay, cnts, -1, (255, 255, 255), 2)
            else:
                roi_overlay = anno_roi
            seg_full = image_cv.copy()
            seg_full[y1:y2, x1:x2] = roi_overlay
            segmentation_path = os.path.join(out_dir, f"segmentation_{ts}.png")
            cv2.imwrite(segmentation_path, seg_full)
            segmentation_roi_path = os.path.join(out_dir, f"segmentation_roi_{ts}.png")
            cv2.imwrite(segmentation_roi_path, roi_overlay)
            # Annotated (mask + arrows + labels) in full-frame
            anno_full = image_cv.copy()
            anno_full[y1:y2, x1:x2] = anno_roi
            annotated_seg_path = os.path.join(out_dir, f"segmentation_annotated_{ts}.png")
            cv2.imwrite(annotated_seg_path, anno_full)

            # --- Optional classification ---
            wound_type = "Unknown"
            cls_pipe = self.models_cache.get("cls")
            if cls_pipe is not None:
                try:
                    preds = cls_pipe(Image.fromarray(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)))
                    if preds:
                        wound_type = max(preds, key=lambda x: x.get("score", 0)).get("label", "Unknown")
                except Exception as e:
                    logging.warning(f"Classification failed: {e}")

            # Log end-of-seg summary
            seg_summary = {
                "seg_used": seg_debug.get("used"),
                "seg_reason": seg_debug.get("reason"),
                "positive_fraction": round(float(seg_debug.get("positive_fraction", 0.0)), 6),
                "threshold": seg_debug.get("thr"),
                "segmentation_empty": segmentation_empty,
                "exif_px_per_cm": round(px_per_cm, 3),
            }
            _log_kv("SEG_SUMMARY", seg_summary)
            return {
                "wound_type": wound_type,
                "length_cm": length_cm,
                "breadth_cm": breadth_cm,
                "surface_area_cm2": surface_area_cm2,
                "px_per_cm": round(px_per_cm, 2),
                "calibration_meta": exif_meta,
                "detection_confidence": (
                    float(results[0].boxes.conf[0].cpu().item())
                    if getattr(results[0].boxes, "conf", None) is not None else 0.0
                ),
                "detection_image_path": detection_path,
                "segmentation_image_path": annotated_seg_path,
                "segmentation_annotated_path": annotated_seg_path,
                "segmentation_roi_path": segmentation_roi_path,
                "roi_mask_path": roi_mask_path,
                "segmentation_empty": segmentation_empty,
                "segmentation_debug": seg_debug,
                "original_image_path": original_path,
            }
        except Exception as e:
            logging.error(f"Visual analysis failed: {e}", exc_info=True)
            raise
    # ---------- Knowledge base + reporting ----------
    def query_guidelines(self, query: str) -> str:
        try:
            vs = self.knowledge_base_cache.get("vector_store")
            if not vs:
                return "Knowledge base is not available."
            retriever = vs.as_retriever(search_kwargs={"k": 5})
            # Modern API (avoid get_relevant_documents deprecation)
            docs = retriever.invoke(query)
            lines: List[str] = []
            for d in docs:
                src = (d.metadata or {}).get("source", "N/A")
                txt = (d.page_content or "")[:300]
                lines.append(f"Source: {src}\nContent: {txt}...")
            return "\n\n".join(lines) if lines else "No relevant guideline snippets found."
        except Exception as e:
            logging.warning(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"

    def _generate_fallback_report(self, patient_info: str, visual_results: Dict, guideline_context: str) -> str:
        return f"""# 🩺 SmartHeal AI - Comprehensive Wound Analysis Report

## 📋 Patient Information
{patient_info}

## 🔍 Visual Analysis Results
- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.1%}
- **Calibration**: {visual_results.get('px_per_cm', '?')} px/cm ({(visual_results.get('calibration_meta') or {}).get('used', 'default')})

## 📊 Analysis Images
- **Original**: {visual_results.get('original_image_path', 'N/A')}
- **Detection**: {visual_results.get('detection_image_path', 'N/A')}
- **Segmentation**: {visual_results.get('segmentation_image_path', 'N/A')}
- **Annotated**: {visual_results.get('segmentation_annotated_path', 'N/A')}

## 🎯 Clinical Summary
Automated analysis provides quantitative measurements; verify via clinical examination.

## 💊 Recommendations
- Cleanse wound gently; select dressing per exudate/infection risk
- Debride necrotic tissue if indicated (clinical decision)
- Document with serial photos and measurements

## 📅 Monitoring
- Daily in week 1, then every 2–3 days (or as indicated)
- Weekly progress review

## 📚 Guideline Context
{(guideline_context or '')[:800]}{"..." if guideline_context and len(guideline_context) > 800 else ''}

**Disclaimer:** Automated, for decision support only. Verify clinically.
"""
    def generate_final_report(
        self,
        patient_info: str,
        visual_results: Dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        try:
            report = generate_medgemma_report(
                patient_info, visual_results, guideline_context, image_pil, max_new_tokens
            )
            if report and report.strip() and not report.startswith(("⚠️", "❌")):
                return report
            logging.warning("VLM unavailable/invalid; using fallback.")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
        except Exception as e:
            logging.error(f"Report generation failed: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def save_and_commit_image(self, image_pil: Image.Image) -> str:
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{ts}.png"
            path = os.path.join(self.uploads_dir, filename)
            image_pil.convert("RGB").save(path)
            logging.info(f"✅ Image saved locally: {path}")
            if HF_TOKEN and DATASET_ID:
                try:
                    HfApi, HfFolder = _import_hf_hub()
                    HfFolder.save_token(HF_TOKEN)
                    api = HfApi()
                    api.upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=DATASET_ID,
                        repo_type="dataset",
                        token=HF_TOKEN,
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("✅ Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")
            return path
        except Exception as e:
            logging.error(f"Failed to save/commit image: {e}")
            return ""
    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: Dict) -> Dict:
        try:
            saved_path = self.save_and_commit_image(image_pil)
            visual_results = self.perform_visual_analysis(image_pil)
            pi = questionnaire_data or {}
            patient_info = (
                f"Age: {pi.get('age', 'N/A')}, "
                f"Diabetic: {pi.get('diabetic', 'N/A')}, "
                f"Allergies: {pi.get('allergies', 'N/A')}, "
                f"Date of Wound: {pi.get('date_of_injury', 'N/A')}, "
                f"Professional Care: {pi.get('professional_care', 'N/A')}, "
                f"Oozing/Bleeding: {pi.get('oozing_bleeding', 'N/A')}, "
                f"Infection: {pi.get('infection', 'N/A')}, "
                f"Moisture: {pi.get('moisture', 'N/A')}"
            )
            query = (
                f"best practices for managing a {visual_results.get('wound_type', 'Unknown')} "
                f"with moisture '{pi.get('moisture', 'unknown')}' and infection '{pi.get('infection', 'unknown')}' "
                f"in a diabetic status '{pi.get('diabetic', 'unknown')}'"
            )
            guideline_context = self.query_guidelines(query)
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            return {
                "success": True,
                "visual_analysis": visual_results,
                "report": report,
                "saved_image_path": saved_path,
                "guideline_context": (guideline_context or "")[:500] + (
                    "..." if guideline_context and len(guideline_context) > 500 else ""
                ),
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }

    def analyze_wound(self, image, questionnaire_data: Dict) -> Dict:
        try:
            if isinstance(image, str):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                image_pil = Image.open(image)
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")
            return self.full_analysis_pipeline(image_pil, questionnaire_data or {})
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                "success": False,
                "error": str(e),
                "visual_analysis": {},
                "report": f"Analysis initialization failed: {str(e)}",
                "saved_image_path": None,
                "guideline_context": "",
            }
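

# ---------------------------------------------------------------------------
# Minimal local smoke test (illustrative sketch, not part of the Space's UI).
# It assumes a sample image at "samples/wound.jpg" (a hypothetical path) and
# exercises the public entry point AIProcessor.analyze_wound end to end.
if __name__ == "__main__":
    processor = AIProcessor()
    demo_questionnaire = {
        "age": 54,
        "diabetic": "yes",
        "allergies": "none",
        "moisture": "moderate",
        "infection": "no",
    }
    result = processor.analyze_wound("samples/wound.jpg", demo_questionnaire)
    print("success:", result.get("success"))
    print("measurements:", {k: result.get("visual_analysis", {}).get(k)
                            for k in ("length_cm", "breadth_cm", "surface_area_cm2")})
    print(result.get("report", "")[:500])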