"""Gradio app: zero-shot medical image classification with MedSigLIP.

The candidate label set is picked per image from JSON files in ``labels/``,
keyed by the modality that ``detect_modality`` reports for the image.
"""

import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple

import gradio as gr
import psutil
import torch
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

from utils.cache_manager import cached_inference, configure_cache
from utils.modality_router import detect_modality

BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"
MODEL_ID = "google/medsiglip-448"
HF_TOKEN = os.getenv("HF_TOKEN")

# Number of predictions returned by the classifier and shown in the UI.
# Kept in one place so the function output and the Label widget cannot drift.
TOP_K = 5

# Label file used when no modality-specific file can be found.
DEFAULT_LABEL_FILE = "general_labels.json"

# Environment-variable values that are interpreted as "enabled".
_TRUTHY_VALUES = frozenset({"1", "true", "yes"})

# Cap intra-op threads at 4, preferring the physical core count and falling
# back to the logical count (or 1) when psutil cannot determine it.
physical_cores = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
torch.set_num_threads(min(physical_cores, 4))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    use_fast=True,
)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    dtype=torch.float32,
).to(device)
model.eval()

# Hand the loaded model and processor to the cache manager before any
# call to cached_inference.
configure_cache(model, processor)

# Modality -> label file to try when "<modality>_labels.json" does not exist.
LABEL_OVERRIDES = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}


@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Load and cache the label list stored in ``LABEL_DIR / file_name``.

    Args:
        file_name: Name of a JSON file inside ``LABEL_DIR``.

    Returns:
        The decoded JSON content. The result is memoised per file name, so
        repeated calls return the *same* object; callers must not mutate it.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    label_path = LABEL_DIR / file_name
    with label_path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Return the candidate labels for the modality detected in *image_path*.

    Lookup order:
      1. ``<modality>_labels.json``
      2. the file named in ``LABEL_OVERRIDES`` for that modality, if any
      3. ``general_labels.json``

    Args:
        image_path: Path of the image, passed to ``detect_modality``.

    Returns:
        The labels as an immutable tuple.

    Raises:
        FileNotFoundError: If even the general label file is missing.
    """
    modality = detect_modality(image_path)
    candidate_path = LABEL_DIR / f"{modality}_labels.json"

    if not candidate_path.exists():
        override = LABEL_OVERRIDES.get(modality)
        if override:
            candidate_path = LABEL_DIR / override

    # Covers both "no override configured" and "override file missing".
    if not candidate_path.exists():
        candidate_path = LABEL_DIR / DEFAULT_LABEL_FILE

    return tuple(load_labels(candidate_path.name))


def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Classify an image and return its highest-scoring labels.

    Args:
        image_path: Path of the uploaded image; may be empty or ``None``
            when nothing was uploaded.

    Returns:
        A mapping of at most ``TOP_K`` labels to their scores as floats,
        ordered from highest to lowest score. Empty when there is no image
        or inference produced no scores.
    """
    if not image_path:
        return {}

    candidate_labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, candidate_labels)

    # Test emptiness explicitly: truth-testing a multi-element NumPy array or
    # torch tensor raises, so `not scores` is only safe for plain sequences.
    # NOTE(review): cached_inference's return type is not visible here.
    if scores is None or len(scores) == 0:
        return {}

    ranked = sorted(
        zip(candidate_labels, scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return {label: float(score) for label, score in ranked[:TOP_K]}


demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="Upload Medical Image"),
    outputs=gr.Label(num_top_classes=TOP_K, label="Top Predictions"),
    title="MedSigLIP Smart Medical Classifier",
    description="Zero-shot model with automatic label filtering for different modalities.",
)


def _env_flag(name: str) -> bool:
    """Return True when environment variable *name* is set to 1/true/yes."""
    return os.getenv(name, "false").lower() in _TRUTHY_VALUES


def main() -> None:
    """Launch the Gradio app using settings read from the environment.

    Environment variables:
        SERVER_NAME: Host to bind (default ``0.0.0.0``).
        SERVER_PORT / PORT: Port to listen on (default ``7860``).
        GRADIO_SHARE: Enable a public share link when truthy.
        GRADIO_QUEUE: Enable request queuing when truthy.

    Raises:
        ValueError: If the configured port is not an integer.
    """
    server_name = os.getenv("SERVER_NAME", "0.0.0.0")
    # `or` chaining also skips variables that are set but empty.
    port_env = os.getenv("SERVER_PORT") or os.getenv("PORT") or "7860"

    app_to_launch = demo.queue() if _env_flag("GRADIO_QUEUE") else demo
    app_to_launch.launch(
        server_name=server_name,
        server_port=int(port_env),
        share=_env_flag("GRADIO_SHARE"),
        show_api=False,
    )


if __name__ == "__main__":
    main()