"""Wound image analysis backend.

At import time this module loads the CPU-side models (YOLO detector, Keras
segmentation model, Hugging Face image classifier, sentence-embedding model)
and builds a FAISS index over the guideline PDFs.  It exposes
``generate_medgemma_report`` (decorated with ``spaces.GPU``) and the
``AIProcessor`` class, which chains detection, segmentation, classification,
guideline retrieval and report generation.
"""
import os
import logging
import re

# NOTE(review): third-party import order is kept exactly as in the original;
# presumably `spaces` has to be imported before CUDA-touching libraries — confirm.
import cv2
import numpy as np
from PIL import Image
from datetime import datetime
import gradio as gr
import spaces
from huggingface_hub import HfApi, HfFolder
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# =============== LOGGING SETUP ===============
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# =============== CONFIGURATION ===============
UPLOADS_DIR = "uploads"
if not os.path.exists(UPLOADS_DIR):
    # exist_ok guards against another process creating the directory between
    # the check above and this call.
    os.makedirs(UPLOADS_DIR, exist_ok=True)
    logging.info(f"Created uploads directory: {UPLOADS_DIR}")

HF_TOKEN = os.getenv("HF_TOKEN")
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model.h5"
GUIDELINE_PDFS = ["src/eHealth in Wound Care.pdf", "src/IWGDF Guideline.pdf", "src/evaluation.pdf"]
DATASET_ID = "SmartHeal/wound-image-uploads"
MAX_NEW_TOKENS = 2048
PIXELS_PER_CM = 38  # fixed pixel-to-centimetre scale used for all measurements

# Prefix of the string generate_medgemma_report() returns when inference
# fails; generate_final_report() uses it to recognise a failed generation.
_REPORT_ERROR_PREFIX = "❌ An error occurred"

# =============== GLOBAL CACHES ===============
models_cache = {}
knowledge_base_cache = {}


# =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
def load_yolo_model(yolo_model_path):
    """Load the YOLO model at ``yolo_model_path``.

    ``ultralytics`` is imported lazily, inside the function, so importing
    this module does not pull it in.
    """
    from ultralytics import YOLO
    return YOLO(yolo_model_path)


def load_segmentation_model(seg_model_path):
    """Load the Keras segmentation model at ``seg_model_path`` (uncompiled).

    TensorFlow is imported lazily and its visible GPU list is emptied first,
    so the model runs on CPU.
    """
    import tensorflow as tf
    tf.config.set_visible_devices([], 'GPU')  # Force CPU for TensorFlow
    from tensorflow.keras.models import load_model
    return load_model(seg_model_path, compile=False)


def load_classification_pipeline(hf_token):
    """Create the ``Hemg/Wound-classification`` image-classification pipeline on CPU.

    Args:
        hf_token: Hugging Face access token passed to ``transformers.pipeline``.
    """
    from transformers import pipeline
    return pipeline(
        "image-classification",
        model="Hemg/Wound-classification",
        token=hf_token,
        device="cpu",
    )


def load_embedding_model():
    """Create the CPU sentence-embedding model used for the knowledge base."""
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )


# =============== MODEL INITIALIZATION ===============
def initialize_cpu_models():
    """Populate ``models_cache`` with every CPU model that is not cached yet.

    Keys written: ``"det"``, ``"seg"``, ``"cls"``, ``"embedding_model"``.
    Each load is best-effort: a failure is logged and the key is simply left
    out, so callers must handle missing entries.
    """
    if HF_TOKEN:
        HfFolder.save_token(HF_TOKEN)
        logging.info("✅ HuggingFace token set")

    if "det" not in models_cache:
        try:
            models_cache["det"] = load_yolo_model(YOLO_MODEL_PATH)
            logging.info("✅ YOLO model loaded (CPU only)")
        except Exception as e:
            logging.error(f"YOLO load failed: {e}")

    if "seg" not in models_cache:
        try:
            models_cache["seg"] = load_segmentation_model(SEG_MODEL_PATH)
            logging.info("✅ Segmentation model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Segmentation model not available: {e}")

    if "cls" not in models_cache:
        try:
            models_cache["cls"] = load_classification_pipeline(HF_TOKEN)
            logging.info("✅ Classification pipeline loaded (CPU)")
        except Exception as e:
            logging.warning(f"Classification pipeline not available: {e}")

    if "embedding_model" not in models_cache:
        try:
            models_cache["embedding_model"] = load_embedding_model()
            logging.info("✅ Embedding model loaded (CPU)")
        except Exception as e:
            logging.warning(f"Embedding model not available: {e}")


def setup_knowledge_base():
    """Build the FAISS vector store from ``GUIDELINE_PDFS`` once.

    Stores the result under ``knowledge_base_cache["vector_store"]``.  The
    entry is ``None`` when no PDF could be loaded or the embedding model is
    missing from ``models_cache``.  Subsequent calls are no-ops.
    """
    if "vector_store" in knowledge_base_cache:
        return

    docs = []
    for pdf_path in GUIDELINE_PDFS:
        if not os.path.exists(pdf_path):
            continue
        try:
            docs.extend(PyPDFLoader(pdf_path).load())
            logging.info(f"Loaded PDF: {pdf_path}")
        except Exception as e:
            logging.warning(f"Failed to load PDF {pdf_path}: {e}")

    if docs and "embedding_model" in models_cache:
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_documents(docs)
        knowledge_base_cache["vector_store"] = FAISS.from_documents(
            chunks, models_cache["embedding_model"]
        )
        logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
    else:
        knowledge_base_cache["vector_store"] = None
        logging.warning("Knowledge base unavailable")


# Initialize models on app startup
initialize_cpu_models()
setup_knowledge_base()


# =============== IMAGE HELPERS ===============
def _sibling_original_path(path, prefix):
    """Return the ``original_*`` file name that sits next to ``path``.

    Only the file name is rewritten (``prefix`` -> ``"original_"``), so a
    directory component that happens to contain ``prefix`` is left intact.
    Returns ``None`` when ``path`` is empty.
    """
    if not path:
        return None
    directory, filename = os.path.split(path)
    sibling = filename.replace(prefix, 'original_')
    return os.path.join(directory, sibling) if directory else sibling


def _open_rgb_copy(path):
    """Load ``path`` fully into memory as an RGB image and close the file."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image, so it stays usable
        # after the file handle is closed.
        return img.convert("RGB")


# =============== GPU-DECORATED MEDGEMMA FUNCTION ===============
@spaces.GPU(enable_queue=True, duration=120)
def generate_medgemma_report(
    patient_info,
    visual_results,
    guideline_context,
    detection_image_path,
    segmentation_image_path,
    max_new_tokens=None,
):
    """Generate a wound-care report with the MedGemma image-text pipeline.

    Args:
        patient_info: Free-text patient description inserted into the prompt.
        visual_results: Dict of visual-analysis values; the keys read here are
            ``wound_type``, ``length_cm``, ``breadth_cm``, ``surface_area_cm2``,
            ``detection_confidence`` and (optionally) ``original_image_path``.
        guideline_context: Retrieved guideline text; at most the first 1500
            characters go into the prompt.  ``None`` is treated as empty.
        detection_image_path: Path of the detection visualisation, or empty.
        segmentation_image_path: Path of the segmentation visualisation, or empty.
        max_new_tokens: Generation limit; defaults to ``MAX_NEW_TOKENS``.

    Returns:
        The generated text; ``None`` if the pipeline could not be loaded; or a
        string starting with ``"❌ An error occurred"`` if inference failed.
    """
    # Import GPU libraries ONLY here
    import torch
    from transformers import pipeline

    default_system_prompt = (
        "You are a world-class medical AI assistant specializing in wound care "
        "with expertise in wound assessment and treatment. Provide concise, "
        "evidence-based medical assessments focusing on: (1) Precise wound "
        "classification based on tissue type and appearance, (2) Specific "
        "treatment recommendations with exact product names or interventions when "
        "appropriate, (3) Objective evaluation of healing progression or deterioration "
        "indicators, and (4) Clear follow-up timelines. Avoid general statements and "
        "prioritize actionable insights based on the visual analysis measurements and "
        "patient context."
    )

    # Lazy-load MedGemma pipeline on GPU; it is cached on the function object
    # so later calls reuse it.
    if not hasattr(generate_medgemma_report, "_pipe"):
        try:
            generate_medgemma_report._pipe = pipeline(
                "image-text-to-text",
                model="google/medgemma-4b-it",
                device="cuda",
                torch_dtype=torch.bfloat16,
                offload_folder="offload",
                token=HF_TOKEN,
            )
            logging.info("✅ MedGemma pipeline loaded on GPU")
        except Exception as e:
            logging.warning(f"MedGemma pipeline load failed: {e}")
            return None
    pipe = generate_medgemma_report._pipe

    guideline_context = guideline_context or ""

    # Candidate images in priority order: the original photo (explicit path
    # first, then the name derived from the detection/segmentation file), then
    # the detection and segmentation visualisations themselves.
    candidate_paths = (
        visual_results.get("original_image_path"),
        _sibling_original_path(detection_image_path, 'detection_'),
        _sibling_original_path(segmentation_image_path, 'segmentation_'),
        detection_image_path,
        segmentation_image_path,
    )
    prompt_image = None
    for path in candidate_paths:
        if not path or not os.path.exists(path):
            continue
        try:
            prompt_image = _open_rgb_copy(path)
            break
        except OSError as img_error:
            logging.warning(f"Could not open image {path}: {img_error}")

    # Compose messages
    msgs = [
        {"role": "system", "content": [{"type": "text", "text": default_system_prompt}]},
        {"role": "user", "content": []},
    ]
    if prompt_image is not None:
        msgs[1]["content"].append({"type": "image", "image": prompt_image})

    # Mark the excerpt with an ellipsis only when text was actually cut off.
    context_excerpt = guideline_context[:1500]
    if len(guideline_context) > 1500:
        context_excerpt += "..."

    prompt = f"""## Patient Information
{patient_info}

## Visual Analysis Results
- Wound Type: {visual_results.get('wound_type','Unknown')}
- Dimensions: {visual_results.get('length_cm', 0)} x {visual_results.get('breadth_cm', 0)} cm
- Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
- Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}

## Clinical Guidelines Context
{context_excerpt}

Please provide a comprehensive wound care assessment and treatment recommendations based on the image and provided information."""
    msgs[1]["content"].append({"type": "text", "text": prompt})

    try:
        out = pipe(text=msgs, max_new_tokens=max_new_tokens or MAX_NEW_TOKENS, do_sample=False)
        return out[0]["generated_text"][-1].get("content", "")
    except Exception as e:
        logging.error(f"Failed to generate MedGemma report: {e}")
        return f"{_REPORT_ERROR_PREFIX}: {e}"


# =============== AI PROCESSOR CLASS ===============
class AIProcessor:
    """Runs the wound-analysis pipeline on top of the module-level caches."""

    def __init__(self):
        # References (not copies) to the module-level caches and settings.
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.px_per_cm = PIXELS_PER_CM
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    def perform_visual_analysis(self, image_pil: Image.Image) -> dict:
        """Detect, segment, measure and classify the wound in ``image_pil``.

        Writes ``detection_<ts>.png`` and ``original_<ts>.png`` (and, when
        segmentation succeeds, ``segmentation_<ts>.png``) under
        ``<uploads_dir>/analysis``.

        Returns:
            Dict with ``wound_type``, ``length_cm``, ``breadth_cm``,
            ``surface_area_cm2``, ``detection_confidence``,
            ``detection_image_path``, ``segmentation_image_path`` (``None`` if
            segmentation was unavailable or failed) and ``original_image_path``.

        Raises:
            RuntimeError: The YOLO model is not in the cache.
            ValueError: No wound was detected or the bounding box is invalid.
        """
        try:
            # Normalise to 3-channel RGB first: COLOR_RGB2BGR rejects RGBA,
            # greyscale and palette images.
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            # YOLO Detection
            yolo_model = self.models_cache.get("det")
            if yolo_model is None:
                raise RuntimeError("YOLO model ('det') not loaded")

            results = yolo_model.predict(image_cv, verbose=False, device="cpu")
            if not results or not results[0].boxes or len(results[0].boxes) == 0:
                raise ValueError("No wound detected in the image")

            # Extract bounding box - handle different output formats
            boxes_data = results[0].boxes.xyxy.cpu().numpy()
            if len(boxes_data.shape) == 1:
                # Single detection case
                if len(boxes_data) != 4:
                    raise ValueError(f"Expected 4 coordinates, got {len(boxes_data)}")
                x1, y1, x2, y2 = boxes_data.astype(int)
            else:
                # Multiple detections - take the first one
                if boxes_data.shape[1] != 4:
                    raise ValueError(f"Expected 4 coordinates per box, got {boxes_data.shape[1]}")
                x1, y1, x2, y2 = boxes_data[0].astype(int)

            # Validate coordinates
            if x1 >= x2 or y1 >= y2 or x1 < 0 or y1 < 0:
                raise ValueError("Invalid bounding box coordinates")

            # Extract wound region
            detected_region_cv = image_cv[y1:y2, x1:x2]
            if detected_region_cv.size == 0:
                raise ValueError("Detected region is empty")

            # Save detection visualization
            det_vis = image_cv.copy()
            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
            os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
            cv2.imwrite(det_path, det_vis)

            # Save original image for reference
            original_path = f"{self.uploads_dir}/analysis/original_{ts}.png"
            cv2.imwrite(original_path, image_cv)

            length, breadth, area, seg_path = self._segment_and_measure(detected_region_cv, ts)
            wound_type = self._classify_region(detected_region_cv)

            # Extract confidence score of the first box (the one used above)
            confidence = 0.0
            if results[0].boxes.conf is not None and len(results[0].boxes.conf) > 0:
                confidence = float(results[0].boxes.conf[0].cpu().item())

            return {
                "wound_type": wound_type,
                "length_cm": length,
                "breadth_cm": breadth,
                "surface_area_cm2": area,
                "detection_confidence": confidence,
                "detection_image_path": det_path,
                "segmentation_image_path": seg_path,
                "original_image_path": original_path,
            }
        except Exception as e:
            logging.error(f"Visual analysis failed: {e}")
            raise

    def _segment_and_measure(self, detected_region_cv, ts):
        """Segment the cropped wound region and measure its largest contour.

        Returns ``(length_cm, breadth_cm, area_cm2, seg_path)``.  Measurements
        stay 0 and ``seg_path`` is ``None`` when no segmentation model is
        cached or segmentation fails (the error is logged, not raised).
        """
        length = breadth = area = 0
        seg_path = None

        seg_model = self.models_cache.get("seg")
        if seg_model is None:
            return length, breadth, area, seg_path

        try:
            # Use the model's declared input size; fall back to 256x256 when
            # it is missing or dynamic (None), which cv2.resize cannot handle.
            input_shape = seg_model.input_shape
            h, w = 256, 256
            if len(input_shape) >= 3 and input_shape[1] and input_shape[2]:
                h, w = input_shape[1:3]

            # Prepare input for segmentation
            resized = cv2.resize(detected_region_cv, (w, h))
            normalized_input = np.expand_dims(resized / 255.0, 0)

            # Predict mask
            mask_pred = seg_model.predict(normalized_input, verbose=0)

            # Handle different output formats
            if len(mask_pred.shape) == 4:
                mask_np = (mask_pred[0, :, :, 0] > 0.5).astype(np.uint8)
            elif len(mask_pred.shape) == 3:
                mask_np = (mask_pred[0, :, :] > 0.5).astype(np.uint8)
            else:
                raise ValueError(f"Unexpected segmentation output shape: {mask_pred.shape}")

            # Resize mask back to detection region size
            mask_resized = cv2.resize(
                mask_np * 255,
                (detected_region_cv.shape[1], detected_region_cv.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            mask_resized = (mask_resized > 127).astype(np.uint8)

            # Create segmentation visualization
            overlay = detected_region_cv.copy()
            overlay[mask_resized == 1] = [0, 0, 255]  # Red overlay for wound area
            seg_vis = cv2.addWeighted(detected_region_cv, 0.7, overlay, 0.3, 0)
            seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
            cv2.imwrite(seg_path, seg_vis)

            # Calculate measurements from the largest contour's bounding box
            contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                largest_contour = max(contours, key=cv2.contourArea)
                _, _, w_box, h_box = cv2.boundingRect(largest_contour)
                length = round(h_box / self.px_per_cm, 2)
                breadth = round(w_box / self.px_per_cm, 2)
                area = round(cv2.contourArea(largest_contour) / (self.px_per_cm ** 2), 2)
            else:
                logging.info("No contours found in segmentation mask")
        except Exception as seg_error:
            logging.error(f"Segmentation processing error: {seg_error}")
            seg_path = None

        return length, breadth, area, seg_path

    def _classify_region(self, detected_region_cv) -> str:
        """Return the top classifier label for the region, or ``"Unknown"``."""
        wound_type = "Unknown"
        cls_pipeline = self.models_cache.get("cls")
        if cls_pipeline is None:
            return wound_type

        try:
            detected_image_pil = Image.fromarray(cv2.cvtColor(detected_region_cv, cv2.COLOR_BGR2RGB))
            predictions = cls_pipeline(detected_image_pil)
            if predictions:
                best_pred = max(predictions, key=lambda x: x.get("score", 0))
                wound_type = best_pred.get("label", "Unknown")
        except Exception as cls_error:
            logging.warning(f"Classification failed: {cls_error}")
        return wound_type

    def query_guidelines(self, query: str) -> str:
        """Return the text of the 10 best-matching guideline chunks for ``query``.

        Each chunk is prefixed with its source and page metadata.  Never
        raises: an explanatory message is returned when the knowledge base is
        missing, nothing matches, or retrieval fails.
        """
        try:
            vector_store = self.knowledge_base_cache.get("vector_store")
            if not vector_store:
                return "Clinical guidelines unavailable - knowledge base not loaded"

            retriever = vector_store.as_retriever(search_kwargs={"k": 10})
            docs = retriever.invoke(query)
            if not docs:
                return "No relevant guidelines found for the query"

            return "\n\n".join(
                f"Source: {doc.metadata.get('source', 'Unknown')}, Page: {doc.metadata.get('page', 'N/A')}\n{doc.page_content}"
                for doc in docs
            )
        except Exception as e:
            logging.error(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"

    def generate_final_report(
        self,
        patient_info: str,
        visual_results: dict,
        guideline_context: str,
        image_pil: Image.Image,
        max_new_tokens: int = None,
    ) -> str:
        """Return the MedGemma report, or the template fallback if it fails.

        The fallback is used when MedGemma raises, returns nothing/blank, or
        returns its inference-error string.  ``image_pil`` is accepted for
        interface compatibility but not used here.
        """
        try:
            det_path = visual_results.get("detection_image_path", "")
            seg_path = visual_results.get("segmentation_image_path", "")

            report = generate_medgemma_report(
                patient_info, visual_results, guideline_context,
                det_path, seg_path, max_new_tokens
            )

            # A report that is just the error sentinel is a failed generation,
            # not a usable report.
            if report and report.strip() and not report.startswith(_REPORT_ERROR_PREFIX):
                return report
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
        except Exception as e:
            logging.error(f"MedGemma report generation failed: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _generate_fallback_report(
        self,
        patient_info: str,
        visual_results: dict,
        guideline_context: str
    ) -> str:
        """Build a static Markdown report from the analysis values.

        At most the first 1000 characters of ``guideline_context`` are
        included (``None`` is treated as empty).
        """
        guideline_context = guideline_context or ""
        report = f"""# Wound Analysis Report

## Patient Information
{patient_info}

## Visual Analysis Results
- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}

## Analysis Images
- **Detection Image**: {visual_results.get('detection_image_path', 'N/A')}
- **Segmentation Image**: {visual_results.get('segmentation_image_path', 'N/A')}

## Clinical Guidelines Context
{guideline_context[:1000]}{'...' if len(guideline_context) > 1000 else ''}

## Assessment Summary
Based on the automated visual analysis, the wound has been classified as **{visual_results.get('wound_type', 'Unknown')}** with measurable dimensions. The detection confidence indicates the reliability of the automated assessment.

## Recommendations
1. **Clinical Evaluation**: This automated analysis should be supplemented with professional clinical assessment
2. **Documentation**: Regular monitoring and documentation of wound progression is recommended
3. **Treatment Planning**: Develop appropriate treatment protocol based on wound characteristics and patient factors
4. **Follow-up**: Schedule appropriate follow-up intervals based on wound severity and healing progress

## Important Notes
- This is an automated analysis and should not replace professional medical judgment
- All measurements are estimates based on computer vision algorithms
- Clinical correlation is essential for proper diagnosis and treatment planning
- Consider patient-specific factors not captured in this automated assessment

## Disclaimer
This automated analysis is provided for informational purposes only and does not constitute medical advice. Always consult with qualified healthcare professionals for proper diagnosis and treatment.
"""
        return report

    def save_and_commit_image(self, image_pil: Image.Image) -> str:
        """Save the image as ``<uploads_dir>/<timestamp>.png`` and try to upload it.

        The upload to the Hugging Face dataset runs only when both a token and
        a dataset id are configured; an upload failure is logged and ignored.

        Returns:
            The local file path, or ``""`` if the image could not be saved.
        """
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{timestamp}.png"
            path = os.path.join(self.uploads_dir, filename)

            # Save image
            image_pil.convert("RGB").save(path)
            logging.info(f"✅ Image saved locally: {path}")

            # Upload to HuggingFace dataset if configured
            if self.hf_token and self.dataset_id:
                try:
                    api = HfApi()
                    api.upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=self.dataset_id,
                        repo_type="dataset",
                        token=self.hf_token,
                        commit_message=f"Upload wound image: {filename}",
                    )
                    logging.info("✅ Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")

            return path
        except Exception as e:
            logging.error(f"Failed to save image: {e}")
            return ""

    def full_analysis_pipeline(self, image_pil: Image.Image, questionnaire_data: dict) -> dict:
        """Save the image, analyse it, query guidelines and build the report.

        Args:
            image_pil: The wound photograph.
            questionnaire_data: Questionnaire answers; ``None`` is treated as
                an empty questionnaire.

        Returns:
            Dict with ``success``, ``visual_analysis``, ``report``,
            ``saved_image_path`` and ``guideline_context`` (cut to 500
            characters).  On failure ``success`` is ``False`` and an ``error``
            key carries the message; no exception propagates.
        """
        try:
            questionnaire_data = questionnaire_data or {}

            # Save image first
            saved_path = self.save_and_commit_image(image_pil)
            logging.info(f"Image saved: {saved_path}")

            # Perform visual analysis
            visual_results = self.perform_visual_analysis(image_pil)
            logging.info(f"Visual analysis completed: {visual_results}")

            # Process questionnaire data (empty answers are skipped)
            patient_info = ", ".join(f"{k}: {v}" for k, v in questionnaire_data.items() if v)
            if not patient_info:
                patient_info = "No patient information provided"

            # Query guidelines
            query = f"wound care treatment for {visual_results.get('wound_type', 'wound')} "
            if questionnaire_data.get('diabetic') == 'Yes':
                query += "diabetic patient "
            if questionnaire_data.get('infection') == 'Yes':
                query += "with infection signs "

            guideline_context = self.query_guidelines(query)
            logging.info("Guidelines queried successfully")

            # Generate final report
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            logging.info("Report generated successfully")

            if len(guideline_context) > 500:
                context_excerpt = guideline_context[:500] + "..."
            else:
                context_excerpt = guideline_context

            return {
                'success': True,
                'visual_analysis': visual_results,
                'report': report,
                'saved_image_path': saved_path,
                'guideline_context': context_excerpt,
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': "",
            }

    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
        """Main analysis entry point - maintains original function name.

        Args:
            image: A file path, a PIL image, or a NumPy array.
            questionnaire_data: Questionnaire answers.

        Returns:
            The dict produced by ``full_analysis_pipeline``, or a failure dict
            of the same shape when the image cannot be interpreted.
        """
        try:
            # Handle different image input formats
            if isinstance(image, str):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                image_pil = Image.open(image)
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            return self.full_analysis_pipeline(image_pil, questionnaire_data)
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis initialization failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': "",
            }

    @staticmethod
    def _to_number(value, default=0.0):
        """Coerce ``value`` to ``float``; return ``default`` if that is impossible."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
        """Legacy risk assessment function - maintains original function name.

        Scores the questionnaire on age, wound duration, pain level, medical
        history, infection signs and moisture.  Missing, ``None`` or
        non-numeric age/pain values count as 0 instead of failing the whole
        assessment.

        Returns:
            Dict with ``risk_score``, ``risk_level`` (``"Low"``,
            ``"Moderate"``, ``"High"``, ``"Very High"`` or ``"Unknown"`` on
            error), ``risk_factors`` and ``recommendations``.
        """
        risk_factors = []
        risk_score = 0

        try:
            questionnaire_data = questionnaire_data or {}

            # Age assessment
            age = self._to_number(questionnaire_data.get('patient_age', 0))
            if age > 65:
                risk_factors.append("Advanced age (>65)")
                risk_score += 2
            elif age > 50:
                risk_factors.append("Older adult (50-65)")
                risk_score += 1

            # Wound duration assessment: any mention of months/years counts
            # as chronic; weeks count only when the number exceeds 4.
            duration = str(questionnaire_data.get('wound_duration', '')).lower()
            if any(term in duration for term in ('month', 'year')):
                risk_factors.append("Chronic wound (>4 weeks)")
                risk_score += 3
            elif 'week' in duration:
                weeks_match = re.search(r'(\d+)\s*week', duration)
                if weeks_match and int(weeks_match.group(1)) > 4:
                    risk_factors.append("Chronic wound (>4 weeks)")
                    risk_score += 3

            # Pain level assessment
            pain = self._to_number(questionnaire_data.get('pain_level', 0))
            if pain >= 7:
                risk_factors.append("High pain level (≥7/10)")
                risk_score += 2
            elif pain >= 5:
                risk_factors.append("Moderate pain level (5-6/10)")
                risk_score += 1

            # Medical history assessment (plain substring matching)
            medical_history = str(questionnaire_data.get('medical_history', '')).lower()
            diabetic_status = str(questionnaire_data.get('diabetic', '')).lower()

            if 'diabetes' in medical_history or 'yes' in diabetic_status:
                risk_factors.append("Diabetes mellitus")
                risk_score += 3

            if any(term in medical_history for term in ['vascular', 'circulation', 'arterial', 'venous']):
                risk_factors.append("Vascular disease")
                risk_score += 2

            if any(term in medical_history for term in ['immune', 'immunocompromised', 'steroid', 'chemotherapy']):
                risk_factors.append("Immune system compromise")
                risk_score += 2

            if any(term in medical_history for term in ['smoking', 'smoker', 'tobacco']):
                risk_factors.append("Smoking history")
                risk_score += 2

            # Infection signs
            infection_signs = str(questionnaire_data.get('infection', '')).lower()
            if 'yes' in infection_signs:
                risk_factors.append("Signs of infection present")
                risk_score += 3

            # Moisture level
            moisture = str(questionnaire_data.get('moisture', '')).lower()
            if any(term in moisture for term in ['wet', 'heavy', 'excessive']):
                risk_factors.append("Excessive wound exudate")
                risk_score += 1

            # Determine risk level
            if risk_score >= 8:
                risk_level = "Very High"
            elif risk_score >= 6:
                risk_level = "High"
            elif risk_score >= 3:
                risk_level = "Moderate"
            else:
                risk_level = "Low"

            return {
                'risk_score': risk_score,
                'risk_level': risk_level,
                'risk_factors': risk_factors,
                'recommendations': self._get_risk_recommendations(risk_level, risk_factors),
            }
        except Exception as e:
            logging.error(f"Risk assessment error: {e}")
            return {
                'risk_score': 0,
                'risk_level': 'Unknown',
                'risk_factors': [],
                'recommendations': ["Unable to assess risk due to data processing error"],
            }

    def _get_risk_recommendations(self, risk_level: str, risk_factors: list) -> list:
        """Return recommendation strings for the risk level and risk factors.

        Level-based recommendations come first, followed by extras for the
        diabetes, infection and exudate risk factors when present.
        """
        recommendations = []

        if risk_level in ["High", "Very High"]:
            recommendations.extend([
                "Urgent referral to wound care specialist recommended",
                "Consider daily wound monitoring",
                "Implement aggressive wound care protocol",
            ])
        elif risk_level == "Moderate":
            recommendations.extend([
                "Regular wound care follow-up every 2-3 days",
                "Monitor for signs of deterioration",
            ])
        else:
            recommendations.extend([
                "Standard wound care monitoring",
                "Weekly assessment recommended",
            ])

        # Specific recommendations based on risk factors
        if "Diabetes mellitus" in risk_factors:
            recommendations.extend([
                "Strict glycemic control essential",
                "Monitor for diabetic complications",
            ])

        if "Signs of infection present" in risk_factors:
            recommendations.extend([
                "Consider antibiotic therapy",
                "Increase wound cleaning frequency",
            ])

        if "Excessive wound exudate" in risk_factors:
            recommendations.extend([
                "Use high-absorption dressings",
                "More frequent dressing changes may be needed",
            ])

        return recommendations