SmartHeal-Agentic-AI / src /ai_processor.py
SmartHeal's picture
Update src/ai_processor.py
862d7cb verified
raw
history blame
31.1 kB
import os
import logging
import cv2
import numpy as np
from PIL import Image
from datetime import datetime
import gradio as gr
import spaces
from huggingface_hub import HfApi, HfFolder
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# =============== LOGGING SETUP ===============
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# =============== CONFIGURATION ===============
# Directory (relative to the working directory) that receives uploaded images
# and the analysis artefacts derived from them.
UPLOADS_DIR = "uploads"
if not os.path.exists(UPLOADS_DIR):
    # exist_ok=True closes the gap between the existence check and the creation:
    # if another process creates the directory in between, this no longer raises.
    os.makedirs(UPLOADS_DIR, exist_ok=True)
    logging.info(f"Created uploads directory: {UPLOADS_DIR}")

# HuggingFace access token; None when the environment variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Locations of the model weights and of the guideline PDFs that feed the
# knowledge base (paths are relative to the working directory).
YOLO_MODEL_PATH = "src/best.pt"
SEG_MODEL_PATH = "src/segmentation_model.h5"
GUIDELINE_PDFS = [
    "src/eHealth in Wound Care.pdf",
    "src/IWGDF Guideline.pdf",
    "src/evaluation.pdf",
]

# HuggingFace dataset repository that uploaded wound images are committed to.
DATASET_ID = "SmartHeal/wound-image-uploads"

# Default generation budget for the report model when the caller gives none.
MAX_NEW_TOKENS = 2048

# Fixed scale used to convert pixel measurements into centimetres.
# NOTE(review): 38 px/cm is roughly 96 DPI - confirm this matches the real
# capture setup, since every reported dimension depends on it.
PIXELS_PER_CM = 38

# =============== GLOBAL CACHES ===============
# Both dictionaries are filled in place by the initialization functions and
# shared by reference with every AIProcessor instance.
models_cache = {}
knowledge_base_cache = {}
# =============== LAZY LOADING FUNCTIONS (CPU-SAFE) ===============
def load_yolo_model(yolo_model_path):
    """Build the YOLO detector from the weights file at ``yolo_model_path``.

    ``ultralytics`` is imported here rather than at module level; per the
    original author's note this is meant to keep CUDA from being initialized
    merely by importing this module.
    """
    from ultralytics import YOLO

    detector = YOLO(yolo_model_path)
    return detector
def load_segmentation_model(seg_model_path):
    """Load the Keras segmentation model from ``seg_model_path`` for inference.

    TensorFlow is imported lazily and asked to hide every GPU so the model
    runs on CPU. The model is loaded with ``compile=False`` (no optimizer or
    loss state is restored), which is sufficient for ``predict``.

    Args:
        seg_model_path: Path to the saved Keras model file.

    Returns:
        The loaded Keras model.

    Raises:
        Whatever ``tensorflow.keras.models.load_model`` raises for a missing
        or unreadable model file.
    """
    import tensorflow as tf

    try:
        tf.config.set_visible_devices([], 'GPU')  # Force CPU for TensorFlow
    except RuntimeError as e:
        # TensorFlow refuses to change device visibility once its runtime has
        # been initialized (e.g. on a second call). Losing the CPU pin is
        # preferable to losing the segmentation model altogether.
        logging.warning(f"Could not hide GPUs from TensorFlow: {e}")

    from tensorflow.keras.models import load_model

    return load_model(seg_model_path, compile=False)
def load_classification_pipeline(hf_token):
    """Create the wound image-classification pipeline, pinned to the CPU.

    ``transformers`` is imported inside the function so that importing this
    module does not pull it in. ``hf_token`` is forwarded to the hub as the
    access token for the ``Hemg/Wound-classification`` checkpoint.
    """
    from transformers import pipeline as build_pipeline

    pipeline_options = {
        "model": "Hemg/Wound-classification",
        "token": hf_token,
        "device": "cpu",
    }
    return build_pipeline("image-classification", **pipeline_options)
def load_embedding_model():
    """Return the sentence-embedding model used to index the guideline PDFs.

    The MiniLM sentence-transformers checkpoint is wrapped in LangChain's
    ``HuggingFaceEmbeddings`` and placed on the CPU.
    """
    cpu_only = {"device": "cpu"}
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs=cpu_only,
    )
    return embedder
# =============== MODEL INITIALIZATION ===============
def initialize_cpu_models():
    """Populate ``models_cache`` with every CPU-side model that is still missing.

    The HuggingFace token, when configured, is stored first. Each model is
    then loaded independently: a failure is logged (as an error for the YOLO
    detector, as a warning for the others) and leaves that cache key absent,
    so a later call will try that model again. Keys already present are left
    untouched, which makes repeated calls cheap.
    """
    if HF_TOKEN:
        HfFolder.save_token(HF_TOKEN)
        logging.info("✅ HuggingFace token set")

    # One row per model:
    # (cache key, loader thunk, success message, failure logger, failure prefix)
    load_plan = (
        (
            "det",
            lambda: load_yolo_model(YOLO_MODEL_PATH),
            "✅ YOLO model loaded (CPU only)",
            logging.error,
            "YOLO load failed",
        ),
        (
            "seg",
            lambda: load_segmentation_model(SEG_MODEL_PATH),
            "✅ Segmentation model loaded (CPU)",
            logging.warning,
            "Segmentation model not available",
        ),
        (
            "cls",
            lambda: load_classification_pipeline(HF_TOKEN),
            "✅ Classification pipeline loaded (CPU)",
            logging.warning,
            "Classification pipeline not available",
        ),
        (
            "embedding_model",
            lambda: load_embedding_model(),
            "✅ Embedding model loaded (CPU)",
            logging.warning,
            "Embedding model not available",
        ),
    )

    for cache_key, loader, success_message, report_failure, failure_prefix in load_plan:
        if cache_key in models_cache:
            continue
        try:
            models_cache[cache_key] = loader()
            logging.info(success_message)
        except Exception as exc:
            report_failure(f"{failure_prefix}: {exc}")
def setup_knowledge_base():
    """Load the guideline PDFs and build the FAISS vector store once.

    The result is stored under ``knowledge_base_cache["vector_store"]``. The
    value is ``None`` when no PDF could be read, when the embedding model is
    not in ``models_cache``, or when building the index fails. Once the key
    exists the function returns immediately, so a ``None`` result is not
    retried on later calls.

    The function is best-effort and does not raise for unreadable PDFs or a
    failed index build, because it runs at import time and the rest of the
    pipeline already copes with a missing vector store.
    """
    if "vector_store" in knowledge_base_cache:
        return

    docs = []
    for pdf_path in GUIDELINE_PDFS:
        if not os.path.exists(pdf_path):
            # Make a missing guideline visible instead of skipping it silently.
            logging.warning(f"Guideline PDF not found, skipping: {pdf_path}")
            continue
        try:
            loader = PyPDFLoader(pdf_path)
            docs.extend(loader.load())
            logging.info(f"Loaded PDF: {pdf_path}")
        except Exception as e:
            logging.warning(f"Failed to load PDF {pdf_path}: {e}")

    vector_store = None
    if docs and "embedding_model" in models_cache:
        try:
            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            chunks = splitter.split_documents(docs)
            vector_store = FAISS.from_documents(chunks, models_cache["embedding_model"])
            logging.info(f"✅ Knowledge base ready with {len(chunks)} chunks")
        except Exception as e:
            # An index-build failure (for example PDFs that yield no usable
            # chunks) must not abort module import; degrade to "no guidelines".
            vector_store = None
            logging.warning(f"Knowledge base build failed: {e}")
    else:
        logging.warning("Knowledge base unavailable")

    knowledge_base_cache["vector_store"] = vector_store
# Initialize models on app startup.
# Both calls run at import time, so importing this module loads the CPU models
# and builds the guideline index as a side effect.
# Order matters: setup_knowledge_base() only builds the vector store when
# initialize_cpu_models() has already stored "embedding_model" in models_cache.
initialize_cpu_models()
setup_knowledge_base()
# =============== GPU-DECORATED MEDGEMMA FUNCTION ===============
@spaces.GPU(enable_queue=True, duration=120)
def generate_medgemma_report(
    patient_info,
    visual_results,
    guideline_context,
    detection_image_path,
    segmentation_image_path,
    max_new_tokens=None,
):
    """Generate the narrative wound report with the MedGemma pipeline on GPU.

    The pipeline is created on first use and cached as an attribute of this
    function. One image is attached to the prompt: the saved original image
    when it can be located next to the detection/segmentation image,
    otherwise the detection or segmentation image itself.

    Args:
        patient_info: Free-text patient description inserted into the prompt.
        visual_results: Dict of analysis results; the keys ``wound_type``,
            ``length_cm``, ``breadth_cm``, ``surface_area_cm2`` and
            ``detection_confidence`` are read, each with a default.
        guideline_context: Retrieved guideline text; only the first 1500
            characters are sent to the model. ``None`` is treated as empty.
        detection_image_path: Path of the saved detection visualization, or
            an empty value.
        segmentation_image_path: Path of the saved segmentation
            visualization, or an empty value.
        max_new_tokens: Generation budget; falls back to ``MAX_NEW_TOKENS``.

    Returns:
        The generated report text, or ``None`` when the pipeline cannot be
        loaded or generation fails. Returning ``None`` (rather than an error
        string) lets callers that test the result for truthiness switch to
        their own fallback report.
    """
    # Import GPU libraries ONLY here
    import torch
    from transformers import pipeline
    from PIL import Image

    default_system_prompt = (
        "You are a world-class medical AI assistant specializing in wound care "
        "with expertise in wound assessment and treatment. Provide concise, "
        "evidence-based medical assessments focusing on: (1) Precise wound "
        "classification based on tissue type and appearance, (2) Specific "
        "treatment recommendations with exact product names or interventions when "
        "appropriate, (3) Objective evaluation of healing progression or deterioration "
        "indicators, and (4) Clear follow-up timelines. Avoid general statements and "
        "prioritize actionable insights based on the visual analysis measurements and "
        "patient context."
    )

    # Lazy-load MedGemma pipeline on GPU
    if not hasattr(generate_medgemma_report, "_pipe"):
        try:
            generate_medgemma_report._pipe = pipeline(
                "image-text-to-text",
                model="google/medgemma-4b-it",
                device="cuda",
                torch_dtype=torch.bfloat16,
                offload_folder="offload",
                token=HF_TOKEN,
            )
            logging.info("✅ MedGemma pipeline loaded on GPU")
        except Exception as e:
            logging.warning(f"MedGemma pipeline load failed: {e}")
            return None
    pipe = generate_medgemma_report._pipe

    # Prefer the untouched original image: it is looked up by swapping the
    # "detection_"/"segmentation_" marker in the derived path for "original_".
    model_image = None
    derived_sources = (
        (detection_image_path, 'detection_'),
        (segmentation_image_path, 'segmentation_'),
    )
    for derived_path, marker in derived_sources:
        if not derived_path:
            continue
        original_candidate = derived_path.replace(marker, 'original_')
        if os.path.exists(original_candidate):
            model_image = Image.open(original_candidate)
            break

    # Fallback to detection or segmentation images
    if model_image is None:
        for derived_path, _marker in derived_sources:
            if derived_path and os.path.exists(derived_path):
                model_image = Image.open(derived_path)
                break

    # Compose messages
    msgs = [
        {"role": "system", "content": [{"type": "text", "text": default_system_prompt}]},
        {"role": "user", "content": []},
    ]

    # Attach the image if one was found
    if model_image is not None:
        msgs[1]["content"].append({"type": "image", "image": model_image})

    # Only mark the guideline excerpt with an ellipsis when text was actually cut.
    guideline_text = guideline_context or ""
    guideline_excerpt = guideline_text[:1500]
    if len(guideline_text) > 1500:
        guideline_excerpt += "..."

    # Attach text prompt
    prompt = f"""## Patient Information
{patient_info}
## Visual Analysis Results
- Wound Type: {visual_results.get('wound_type','Unknown')}
- Dimensions: {visual_results.get('length_cm', 0)} x {visual_results.get('breadth_cm', 0)} cm
- Surface Area: {visual_results.get('surface_area_cm2', 0)} cm²
- Detection Confidence: {visual_results.get('detection_confidence', 0):.2f}
## Clinical Guidelines Context
{guideline_excerpt}
Please provide a comprehensive wound care assessment and treatment recommendations based on the image and provided information."""
    msgs[1]["content"].append({"type": "text", "text": prompt})

    try:
        out = pipe(text=msgs, max_new_tokens=max_new_tokens or MAX_NEW_TOKENS, do_sample=False)
        # The last chat message of the generated conversation carries the reply.
        return out[0]["generated_text"][-1].get("content", "")
    except Exception as e:
        logging.error(f"Failed to generate MedGemma report: {e}")
        # Signal failure with None so the caller's fallback report is used
        # instead of presenting an error string as the medical report.
        return None
# =============== AI PROCESSOR CLASS ===============
class AIProcessor:
    """Wound-analysis pipeline built on the module-level model caches.

    An instance only keeps references to the shared ``models_cache`` and
    ``knowledge_base_cache`` dictionaries and copies of the module
    configuration (pixel scale, uploads directory, dataset id, token); it
    does not load any model itself.
    """

    def __init__(self):
        self.models_cache = models_cache
        self.knowledge_base_cache = knowledge_base_cache
        self.px_per_cm = PIXELS_PER_CM
        self.uploads_dir = UPLOADS_DIR
        self.dataset_id = DATASET_ID
        self.hf_token = HF_TOKEN

    def perform_visual_analysis(self, image_pil: "Image.Image") -> dict:
        """Detect, measure and classify the wound shown in ``image_pil``.

        The first YOLO detection is cropped out of the image; the crop is
        then segmented (to derive length, breadth and area) and classified.
        Segmentation and classification are optional: when their model is
        missing or fails, the measurements stay ``0`` and the type stays
        ``"Unknown"``. The detection overlay and the original image (plus the
        segmentation overlay, when available) are written as PNG files to
        ``<uploads_dir>/analysis``.

        Returns:
            Dict with the keys ``wound_type``, ``length_cm``, ``breadth_cm``,
            ``surface_area_cm2``, ``detection_confidence``,
            ``detection_image_path``, ``segmentation_image_path`` (``None``
            without a usable segmentation) and ``original_image_path``.

        Raises:
            RuntimeError: The YOLO detector is not in the model cache.
            ValueError: No wound was detected, or the detected box is
                malformed or empty.
        """
        try:
            # Convert PIL to OpenCV format. Forcing RGB first keeps grayscale
            # and palette images (2-D arrays) from breaking the RGB->BGR step.
            image_cv = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)

            x1, y1, x2, y2, confidence = self._detect_wound(image_cv)

            # Extract wound region
            detected_region_cv = image_cv[y1:y2, x1:x2]
            if detected_region_cv.size == 0:
                raise ValueError("Detected region is empty")

            # Save detection visualization
            det_vis = image_cv.copy()
            cv2.rectangle(det_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
            os.makedirs(f"{self.uploads_dir}/analysis", exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            det_path = f"{self.uploads_dir}/analysis/detection_{ts}.png"
            cv2.imwrite(det_path, det_vis)

            # Save original image for reference
            original_path = f"{self.uploads_dir}/analysis/original_{ts}.png"
            cv2.imwrite(original_path, image_cv)

            length, breadth, area, seg_path = self._segment_and_measure(detected_region_cv, ts)
            wound_type = self._classify_region(detected_region_cv)

            return {
                "wound_type": wound_type,
                "length_cm": length,
                "breadth_cm": breadth,
                "surface_area_cm2": area,
                "detection_confidence": confidence,
                "detection_image_path": det_path,
                "segmentation_image_path": seg_path,
                "original_image_path": original_path
            }
        except Exception as e:
            logging.error(f"Visual analysis failed: {e}")
            raise

    def _detect_wound(self, image_cv):
        """Run YOLO on a BGR image and return ``(x1, y1, x2, y2, confidence)`` of the first box."""
        yolo_model = self.models_cache.get("det")
        if yolo_model is None:
            raise RuntimeError("YOLO model ('det') not loaded")

        results = yolo_model.predict(image_cv, verbose=False, device="cpu")
        if not results or not results[0].boxes or len(results[0].boxes) == 0:
            raise ValueError("No wound detected in the image")

        # Extract bounding box - handle different output formats
        boxes_data = results[0].boxes.xyxy.cpu().numpy()
        if len(boxes_data.shape) == 1:
            # Single detection case
            if len(boxes_data) != 4:
                raise ValueError(f"Expected 4 coordinates, got {len(boxes_data)}")
            first_box = boxes_data
        else:
            # Multiple detections - take the first one
            if boxes_data.shape[1] != 4:
                raise ValueError(f"Expected 4 coordinates per box, got {boxes_data.shape[1]}")
            first_box = boxes_data[0]

        # Plain Python ints are safe for both slicing and OpenCV drawing calls.
        x1, y1, x2, y2 = (int(value) for value in first_box.astype(int))

        # Validate coordinates
        if x1 >= x2 or y1 >= y2 or x1 < 0 or y1 < 0:
            raise ValueError("Invalid bounding box coordinates")

        # Confidence of the same (first) detection
        confidence = 0.0
        if results[0].boxes.conf is not None and len(results[0].boxes.conf) > 0:
            confidence = float(results[0].boxes.conf[0].cpu().item())

        return x1, y1, x2, y2, confidence

    def _segment_and_measure(self, region_bgr, ts):
        """Segment the cropped wound and return ``(length_cm, breadth_cm, area_cm2, seg_path)``.

        Best-effort: without a segmentation model, or on any processing
        error, the measurements are ``0`` and ``seg_path`` is ``None``.
        """
        length = breadth = area = 0
        seg_path = None

        seg_model = self.models_cache.get("seg")
        if seg_model is None:
            return length, breadth, area, seg_path

        try:
            # Get input shape from model
            input_shape = seg_model.input_shape
            if len(input_shape) >= 3:
                h, w = input_shape[1:3]
            else:
                h, w = 256, 256  # Default fallback
            if h is None or w is None:
                # A dynamic (None) spatial dimension cannot be passed to cv2.resize.
                h, w = 256, 256

            # Prepare input for segmentation
            resized = cv2.resize(region_bgr, (w, h))
            normalized_input = np.expand_dims(resized / 255.0, 0)

            # Predict mask
            mask_pred = seg_model.predict(normalized_input, verbose=0)

            # Handle different output formats
            if len(mask_pred.shape) == 4:
                mask_np = (mask_pred[0, :, :, 0] > 0.5).astype(np.uint8)
            elif len(mask_pred.shape) == 3:
                mask_np = (mask_pred[0, :, :] > 0.5).astype(np.uint8)
            else:
                raise ValueError(f"Unexpected segmentation output shape: {mask_pred.shape}")

            # Resize mask back to detection region size
            mask_resized = cv2.resize(
                mask_np * 255,
                (region_bgr.shape[1], region_bgr.shape[0]),
                interpolation=cv2.INTER_NEAREST
            )
            mask_resized = (mask_resized > 127).astype(np.uint8)

            # Create segmentation visualization
            overlay = region_bgr.copy()
            overlay[mask_resized == 1] = [0, 0, 255]  # Red overlay for wound area
            seg_vis = cv2.addWeighted(region_bgr, 0.7, overlay, 0.3, 0)
            seg_path = f"{self.uploads_dir}/analysis/segmentation_{ts}.png"
            cv2.imwrite(seg_path, seg_vis)

            # Calculate measurements from the largest connected wound area
            contours, _ = cv2.findContours(mask_resized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                largest_contour = max(contours, key=cv2.contourArea)
                _x, _y, w_box, h_box = cv2.boundingRect(largest_contour)
                length = round(h_box / self.px_per_cm, 2)
                breadth = round(w_box / self.px_per_cm, 2)
                area = round(cv2.contourArea(largest_contour) / (self.px_per_cm ** 2), 2)
            else:
                logging.info("No contours found in segmentation mask")
        except Exception as seg_error:
            logging.error(f"Segmentation processing error: {seg_error}")
            seg_path = None

        return length, breadth, area, seg_path

    def _classify_region(self, region_bgr):
        """Return the highest-scoring label for the cropped wound, or ``"Unknown"``."""
        wound_type = "Unknown"
        cls_pipeline = self.models_cache.get("cls")
        if cls_pipeline is None:
            return wound_type

        try:
            detected_image_pil = Image.fromarray(cv2.cvtColor(region_bgr, cv2.COLOR_BGR2RGB))
            predictions = cls_pipeline(detected_image_pil)
            if predictions:
                best_pred = max(predictions, key=lambda pred: pred.get("score", 0))
                wound_type = best_pred.get("label", "Unknown")
        except Exception as cls_error:
            logging.warning(f"Classification failed: {cls_error}")
        return wound_type

    def query_guidelines(self, query: str) -> str:
        """Return guideline passages relevant to ``query`` as one text block.

        Up to 10 passages are retrieved from the cached vector store, each
        prefixed with its source and page metadata. The method never raises:
        when the store is missing, nothing matches, or retrieval fails, an
        explanatory message is returned instead.
        """
        try:
            vector_store = self.knowledge_base_cache.get("vector_store")
            if not vector_store:
                return "Clinical guidelines unavailable - knowledge base not loaded"

            retriever = vector_store.as_retriever(search_kwargs={"k": 10})
            docs = retriever.invoke(query)
            if not docs:
                return "No relevant guidelines found for the query"

            return "\n\n".join(
                f"Source: {doc.metadata.get('source', 'Unknown')}, Page: {doc.metadata.get('page', 'N/A')}\n{doc.page_content}"
                for doc in docs
            )
        except Exception as e:
            logging.error(f"Guidelines query failed: {e}")
            return f"Guidelines query failed: {str(e)}"

    def generate_final_report(
        self, patient_info: str, visual_results: dict, guideline_context: str,
        image_pil: "Image.Image", max_new_tokens: int = None
    ) -> str:
        """Generate the final report with MedGemma, falling back to a template.

        The template report is used when the generator returns nothing (or
        only whitespace) or raises. ``image_pil`` is not used here; the
        generator receives the image paths stored in ``visual_results``. The
        parameter is kept so existing callers keep working.
        """
        try:
            det_path = visual_results.get("detection_image_path", "")
            seg_path = visual_results.get("segmentation_image_path", "")
            report = generate_medgemma_report(
                patient_info, visual_results, guideline_context,
                det_path, seg_path, max_new_tokens
            )
            if report and report.strip():
                return report
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)
        except Exception as e:
            logging.error(f"MedGemma report generation failed: {e}")
            return self._generate_fallback_report(patient_info, visual_results, guideline_context)

    def _generate_fallback_report(
        self, patient_info: str, visual_results: dict, guideline_context: str
    ) -> str:
        """Build a static Markdown report from the analysis results.

        Missing or ``None`` image paths are shown as ``N/A``. The guideline
        context is cut to 1000 characters, with an ellipsis when truncated;
        ``None`` is treated as empty.
        """
        guideline_context = guideline_context or ""
        # ``or`` also covers a key that is present but None (e.g. no segmentation).
        detection_image = visual_results.get('detection_image_path') or 'N/A'
        segmentation_image = visual_results.get('segmentation_image_path') or 'N/A'

        report = f"""# Wound Analysis Report
## Patient Information
{patient_info}
## Visual Analysis Results
- **Wound Type**: {visual_results.get('wound_type', 'Unknown')}
- **Dimensions**: {visual_results.get('length_cm', 0)} cm × {visual_results.get('breadth_cm', 0)} cm
- **Surface Area**: {visual_results.get('surface_area_cm2', 0)} cm²
- **Detection Confidence**: {visual_results.get('detection_confidence', 0):.2f}
## Analysis Images
- **Detection Image**: {detection_image}
- **Segmentation Image**: {segmentation_image}
## Clinical Guidelines Context
{guideline_context[:1000]}{'...' if len(guideline_context) > 1000 else ''}
## Assessment Summary
Based on the automated visual analysis, the wound has been classified as **{visual_results.get('wound_type', 'Unknown')}** with measurable dimensions. The detection confidence indicates the reliability of the automated assessment.
## Recommendations
1. **Clinical Evaluation**: This automated analysis should be supplemented with professional clinical assessment
2. **Documentation**: Regular monitoring and documentation of wound progression is recommended
3. **Treatment Planning**: Develop appropriate treatment protocol based on wound characteristics and patient factors
4. **Follow-up**: Schedule appropriate follow-up intervals based on wound severity and healing progress
## Important Notes
- This is an automated analysis and should not replace professional medical judgment
- All measurements are estimates based on computer vision algorithms
- Clinical correlation is essential for proper diagnosis and treatment planning
- Consider patient-specific factors not captured in this automated assessment
## Disclaimer
This automated analysis is provided for informational purposes only and does not constitute medical advice. Always consult with qualified healthcare professionals for proper diagnosis and treatment.
"""
        return report

    def save_and_commit_image(self, image_pil: "Image.Image") -> str:
        """Save the image as a timestamped PNG and try to upload it to the HF dataset.

        The upload is attempted only when both a token and a dataset id are
        configured; an upload failure is logged and does not affect the
        result.

        Returns:
            The local file path, or an empty string when saving failed.
        """
        try:
            os.makedirs(self.uploads_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"{timestamp}.png"
            path = os.path.join(self.uploads_dir, filename)

            # Save image
            image_pil.convert("RGB").save(path)
            logging.info(f"✅ Image saved locally: {path}")

            # Upload to HuggingFace dataset if configured
            if self.hf_token and self.dataset_id:
                try:
                    api = HfApi()
                    api.upload_file(
                        path_or_fileobj=path,
                        path_in_repo=f"images/{filename}",
                        repo_id=self.dataset_id,
                        repo_type="dataset",
                        token=self.hf_token,
                        commit_message=f"Upload wound image: {filename}"
                    )
                    logging.info("✅ Image committed to HF dataset")
                except Exception as e:
                    logging.warning(f"HF upload failed: {e}")

            return path
        except Exception as e:
            logging.error(f"Failed to save image: {e}")
            return ""

    def full_analysis_pipeline(self, image_pil: "Image.Image", questionnaire_data: dict) -> dict:
        """Run saving, visual analysis, guideline retrieval and report generation.

        Never raises: any failure is reported through a result dict with
        ``success`` set to ``False`` and the error message filled in.

        Returns:
            Dict with the keys ``success``, ``visual_analysis``, ``report``,
            ``saved_image_path`` and ``guideline_context`` (cut to 500
            characters), plus ``error`` on failure.
        """
        try:
            # A missing questionnaire is treated as an empty one.
            questionnaire_data = questionnaire_data or {}

            # Save image first
            saved_path = self.save_and_commit_image(image_pil)
            logging.info(f"Image saved: {saved_path}")

            # Perform visual analysis
            visual_results = self.perform_visual_analysis(image_pil)
            logging.info(f"Visual analysis completed: {visual_results}")

            # Process questionnaire data (entries with empty/falsy values are left out)
            patient_info = ", ".join(f"{k}: {v}" for k, v in questionnaire_data.items() if v)
            if not patient_info:
                patient_info = "No patient information provided"

            # Query guidelines
            query = f"wound care treatment for {visual_results.get('wound_type', 'wound')} "
            if questionnaire_data.get('diabetic') == 'Yes':
                query += "diabetic patient "
            if questionnaire_data.get('infection') == 'Yes':
                query += "with infection signs "
            guideline_context = self.query_guidelines(query)
            logging.info("Guidelines queried successfully")

            # Generate final report
            report = self.generate_final_report(patient_info, visual_results, guideline_context, image_pil)
            logging.info("Report generated successfully")

            if len(guideline_context) > 500:
                context_preview = guideline_context[:500] + "..."
            else:
                context_preview = guideline_context

            return {
                'success': True,
                'visual_analysis': visual_results,
                'report': report,
                'saved_image_path': saved_path,
                'guideline_context': context_preview
            }
        except Exception as e:
            logging.error(f"Pipeline error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': ""
            }

    def analyze_wound(self, image, questionnaire_data: dict) -> dict:
        """Main analysis entry point - maintains original function name.

        ``image`` may be a file path, a PIL image or a NumPy array. Any other
        type, or a path that does not exist, yields a failure result dict
        rather than an exception.
        """
        try:
            # Handle different image input formats
            if isinstance(image, str):
                if not os.path.exists(image):
                    raise ValueError(f"Image file not found: {image}")
                image_pil = Image.open(image)
            elif isinstance(image, Image.Image):
                image_pil = image
            elif isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image)
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            return self.full_analysis_pipeline(image_pil, questionnaire_data)
        except Exception as e:
            logging.error(f"Wound analysis error: {e}")
            return {
                'success': False,
                'error': str(e),
                'visual_analysis': {},
                'report': f"Analysis initialization failed: {str(e)}",
                'saved_image_path': None,
                'guideline_context': ""
            }

    @staticmethod
    def _to_number(value, default=0.0):
        """Coerce a questionnaire value to ``float``; return ``default`` when it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _assess_risk_legacy(self, questionnaire_data: dict) -> dict:
        """Legacy risk assessment function - maintains original function name.

        Scores the questionnaire on age, wound duration, pain level, medical
        history, infection signs and moisture, then maps the total to a risk
        level (>=8 Very High, >=6 High, >=3 Moderate, otherwise Low).
        Non-numeric or missing age/pain values count as 0 instead of aborting
        the assessment.

        Returns:
            Dict with ``risk_score``, ``risk_level``, ``risk_factors`` and
            ``recommendations``. On an unexpected error the level is
            ``"Unknown"`` with a score of 0.
        """
        import re

        risk_factors = []
        risk_score = 0
        try:
            # Age assessment
            age = self._to_number(questionnaire_data.get('patient_age', 0))
            if age > 65:
                risk_factors.append("Advanced age (>65)")
                risk_score += 2
            elif age > 50:
                risk_factors.append("Older adult (50-65)")
                risk_score += 1

            # Wound duration assessment
            duration = str(questionnaire_data.get('wound_duration', '')).lower()
            if any(term in duration for term in ['month', 'months', 'year', 'years']):
                risk_factors.append("Chronic wound (>4 weeks)")
                risk_score += 3
            elif any(term in duration for term in ['week', 'weeks']):
                # Try to extract number of weeks
                weeks_match = re.search(r'(\d+)\s*week', duration)
                if weeks_match and int(weeks_match.group(1)) > 4:
                    risk_factors.append("Chronic wound (>4 weeks)")
                    risk_score += 3

            # Pain level assessment
            pain = self._to_number(questionnaire_data.get('pain_level', 0))
            if pain >= 7:
                risk_factors.append("High pain level (≥7/10)")
                risk_score += 2
            elif pain >= 5:
                risk_factors.append("Moderate pain level (5-6/10)")
                risk_score += 1

            # Medical history assessment
            medical_history = str(questionnaire_data.get('medical_history', '')).lower()
            diabetic_status = str(questionnaire_data.get('diabetic', '')).lower()
            if 'diabetes' in medical_history or 'yes' in diabetic_status:
                risk_factors.append("Diabetes mellitus")
                risk_score += 3
            if any(term in medical_history for term in ['vascular', 'circulation', 'arterial', 'venous']):
                risk_factors.append("Vascular disease")
                risk_score += 2
            if any(term in medical_history for term in ['immune', 'immunocompromised', 'steroid', 'chemotherapy']):
                risk_factors.append("Immune system compromise")
                risk_score += 2
            if any(term in medical_history for term in ['smoking', 'smoker', 'tobacco']):
                risk_factors.append("Smoking history")
                risk_score += 2

            # Infection signs
            infection_signs = str(questionnaire_data.get('infection', '')).lower()
            if 'yes' in infection_signs:
                risk_factors.append("Signs of infection present")
                risk_score += 3

            # Moisture level
            moisture = str(questionnaire_data.get('moisture', '')).lower()
            if any(term in moisture for term in ['wet', 'heavy', 'excessive']):
                risk_factors.append("Excessive wound exudate")
                risk_score += 1

            # Determine risk level
            if risk_score >= 8:
                risk_level = "Very High"
            elif risk_score >= 6:
                risk_level = "High"
            elif risk_score >= 3:
                risk_level = "Moderate"
            else:
                risk_level = "Low"

            return {
                'risk_score': risk_score,
                'risk_level': risk_level,
                'risk_factors': risk_factors,
                'recommendations': self._get_risk_recommendations(risk_level, risk_factors)
            }
        except Exception as e:
            logging.error(f"Risk assessment error: {e}")
            return {
                'risk_score': 0,
                'risk_level': 'Unknown',
                'risk_factors': [],
                'recommendations': ["Unable to assess risk due to data processing error"]
            }

    def _get_risk_recommendations(self, risk_level: str, risk_factors: list) -> list:
        """Generate risk-based recommendations.

        General advice is chosen by ``risk_level``; further items are
        appended for the diabetes, infection and exudate risk factors.
        """
        recommendations = []
        if risk_level in ["High", "Very High"]:
            recommendations.append("Urgent referral to wound care specialist recommended")
            recommendations.append("Consider daily wound monitoring")
            recommendations.append("Implement aggressive wound care protocol")
        elif risk_level == "Moderate":
            recommendations.append("Regular wound care follow-up every 2-3 days")
            recommendations.append("Monitor for signs of deterioration")
        else:
            recommendations.append("Standard wound care monitoring")
            recommendations.append("Weekly assessment recommended")

        # Specific recommendations based on risk factors
        if "Diabetes mellitus" in risk_factors:
            recommendations.append("Strict glycemic control essential")
            recommendations.append("Monitor for diabetic complications")
        if "Signs of infection present" in risk_factors:
            recommendations.append("Consider antibiotic therapy")
            recommendations.append("Increase wound cleaning frequency")
        if "Excessive wound exudate" in risk_factors:
            recommendations.append("Use high-absorption dressings")
            recommendations.append("More frequent dressing changes may be needed")
        return recommendations