import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import gradio as gr
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

from utils.cache_manager import cached_inference
from utils.modality_router import detect_modality

# Label files are resolved relative to this script, not the working directory.
BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"
MODEL_ID = "google/medsiglip-448"
# Optional Hugging Face access token, read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
# Number of predictions returned by classify_medical_image and shown in the UI.
# Shared by both so the backend and the gr.Label widget cannot drift apart.
TOP_K = 5
# Last-resort label file used when no modality-specific file exists.
DEFAULT_LABEL_FILE = "general_labels.json"

# Limit PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision on GPU, full precision on CPU.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=model_dtype,
).to(device)
model.eval()

# Alternative label files keyed by detected modality. Consulted only when
# "<modality>_labels.json" does not exist in LABEL_DIR.
LABEL_OVERRIDES = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}


@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Read a JSON label file from LABEL_DIR and cache the parsed result.

    Args:
        file_name: Name of the JSON file inside LABEL_DIR.

    Returns:
        The parsed JSON content (expected to be a list of label strings).
        Because of ``lru_cache`` the same object is returned on every call
        with the same ``file_name``, so callers must not mutate it.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    label_path = LABEL_DIR / file_name
    with label_path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Return the candidate labels for the modality detected in ``image_path``.

    Label files are searched in this order, and the first one that exists is
    used:

    1. ``<modality>_labels.json``
    2. the file listed for the modality in ``LABEL_OVERRIDES`` (if any)
    3. ``general_labels.json``

    Args:
        image_path: Path to the uploaded image, passed to ``detect_modality``.

    Returns:
        An immutable tuple of label strings.

    Raises:
        FileNotFoundError: If none of the candidate label files exist.
    """
    modality = detect_modality(image_path)

    search_order = [f"{modality}_labels.json"]
    override = LABEL_OVERRIDES.get(modality)
    if override:
        search_order.append(override)
    search_order.append(DEFAULT_LABEL_FILE)

    for file_name in search_order:
        if (LABEL_DIR / file_name).exists():
            # Copy into a tuple so the lru_cache'd list is never exposed.
            return tuple(load_labels(file_name))

    raise FileNotFoundError(
        f"No label file found in {LABEL_DIR} (tried: {', '.join(search_order)})"
    )


def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Classify a medical image against modality-specific candidate labels.

    Args:
        image_path: Filesystem path of the uploaded image. It may be empty or
            None when nothing has been uploaded.

    Returns:
        A mapping of up to ``TOP_K`` labels to their scores (as Python
        floats), highest score first. Returns ``{}`` if there is no image, no
        candidate labels, or no scores.
    """
    if not image_path:
        return {}

    candidate_labels = get_candidate_labels(image_path)
    if not candidate_labels:
        # An empty label file leaves the model nothing to score.
        return {}

    raw_scores = cached_inference(image_path, candidate_labels, model, processor)
    if raw_scores is None:
        return {}
    # Materialize into a list before the emptiness check. `not raw_scores`
    # raises for multi-element torch tensors / numpy arrays, and a list works
    # uniformly for lists, tuples, tensors, arrays and generators.
    scores = list(raw_scores)
    if not scores:
        return {}

    ranked = sorted(
        zip(candidate_labels, scores),
        key=lambda pair: float(pair[1]),
        reverse=True,
    )
    return {label: float(score) for label, score in ranked[:TOP_K]}


demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="📤 Upload Medical Image"),
    outputs=gr.Label(num_top_classes=TOP_K, label="🧠 Top Predictions"),
    title="🩻 MedSigLIP Smart Medical Classifier",
    description="Zero-shot model with automatic label filtering for different modalities.",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)