|
|
import json |
|
|
import os |
|
|
from functools import lru_cache |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
import torch |
|
|
import gradio as gr |
|
|
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor |
|
|
|
|
|
from utils.cache_manager import cached_inference |
|
|
from utils.modality_router import detect_modality |
|
|
|
|
|
|
|
|
# Absolute directory of this script; label JSON files live in its "labels" subfolder.
BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR.joinpath("labels")

# Hugging Face Hub repository id of the zero-shot model.
MODEL_ID = "google/medsiglip-448"

# Optional Hub access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
|
# Limit torch's intra-op CPU parallelism to a single thread.
torch.set_num_threads(1)

# Use the GPU when one is visible, running it in half precision;
# the CPU path stays in full float32.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
model_dtype = torch.float16 if _cuda_available else torch.float32
|
|
|
|
|
# Shared Hub options: HF_TOKEN is forwarded for authenticated downloads.
_hub_kwargs = {"token": HF_TOKEN}

# Preprocessor (image + text) paired with the model checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID, **_hub_kwargs)

# Load the weights in the dtype chosen above, place them on the selected
# device, and switch to inference mode.
# NOTE(review): recent transformers releases prefer `dtype=` over the
# deprecated `torch_dtype=`; revisit when pinning the library version.
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID, torch_dtype=model_dtype, **_hub_kwargs
).to(device)
model.eval()
|
|
|
|
|
|
|
|
# Alternate label file per modality, consulted when "<modality>_labels.json"
# is not present in LABEL_DIR.
LABEL_OVERRIDES: Dict[str, str] = dict(
    xray="chest_labels.json",
    mri="brain_labels.json",
)
|
|
|
|
|
|
|
|
@lru_cache(maxsize=None)
def _read_label_file(file_name: str) -> Tuple[str, ...]:
    """Read ``LABEL_DIR / file_name`` once and cache the validated labels as a tuple.

    Raises:
        FileNotFoundError: if the label file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the JSON document is not a list of strings.
    """
    label_path = LABEL_DIR / file_name
    with label_path.open("r", encoding="utf-8") as handle:
        data = json.load(handle)
    # A JSON object (which would iterate as its keys) or non-string entries
    # would otherwise be accepted here and only misbehave much later.
    if not isinstance(data, list) or not all(isinstance(item, str) for item in data):
        raise ValueError(f"Label file {label_path} must contain a JSON list of strings")
    # Cache an immutable value so no caller can alter what later callers see.
    return tuple(data)


def load_labels(file_name: str) -> List[str]:
    """Return the candidate labels stored in the JSON file ``LABEL_DIR / file_name``.

    File contents are cached per ``file_name``. Each call returns a new list,
    so callers may mutate the result freely.

    Raises:
        FileNotFoundError: if the label file does not exist.
        ValueError: if the file is not valid JSON or not a list of strings.
    """
    return list(_read_label_file(file_name))


# Preserve the cache-management hooks that ``@lru_cache`` exposed on load_labels.
load_labels.cache_clear = _read_label_file.cache_clear
load_labels.cache_info = _read_label_file.cache_info
|
|
|
|
|
|
|
|
def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Return the label set matching the detected modality of *image_path*.

    Label files are tried in this order, and the first one present in
    ``LABEL_DIR`` wins:

    1. ``<modality>_labels.json``
    2. the ``LABEL_OVERRIDES`` entry for that modality, if there is one
    3. ``general_labels.json`` (used unconditionally as the last resort)
    """
    modality = detect_modality(image_path)

    primary_path = LABEL_DIR / f"{modality}_labels.json"
    if primary_path.exists():
        return tuple(load_labels(primary_path.name))

    override_name = LABEL_OVERRIDES.get(modality)
    if override_name:
        override_path = LABEL_DIR / override_name
        if override_path.exists():
            return tuple(load_labels(override_path.name))

    return tuple(load_labels("general_labels.json"))
|
|
|
|
|
|
|
|
def classify_medical_image(image_path: str, *, top_k: int = 5) -> Dict[str, float]:
    """Score an uploaded image against the labels for its detected modality.

    Args:
        image_path: Path supplied by the Gradio ``Image(type="filepath")``
            input; empty or ``None`` when nothing was uploaded.
        top_k: Maximum number of labels to return. The default of 5 matches
            ``gr.Label(num_top_classes=5)``. Values <= 0 yield an empty result.

    Returns:
        ``{label: score}`` for the ``top_k`` highest-scoring labels, inserted in
        descending score order. Empty when there is no image or no scores.
    """
    if not image_path:
        return {}

    candidate_labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, candidate_labels, model, processor)

    # Avoid ``not scores``: if the inference helper returns a multi-element
    # tensor/array (its return type is not visible here), truth-testing it
    # raises instead of checking for emptiness.
    if scores is None:
        return {}
    score_values = [float(score) for score in scores]
    if not score_values:
        return {}

    # zip() pairs labels and scores by position and stops at the shorter one.
    ranked = sorted(
        zip(candidate_labels, score_values),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return dict(ranked[: max(top_k, 0)])
|
|
|
|
|
|
|
|
def _build_demo() -> gr.Interface:
    """Assemble the UI: an image upload (as a file path) feeding a top-5 label view."""
    image_input = gr.Image(type="filepath", label="📤 Upload Medical Image")
    label_output = gr.Label(num_top_classes=5, label="🧠 Top Predictions")
    return gr.Interface(
        fn=classify_medical_image,
        inputs=image_input,
        outputs=label_output,
        title="🩻 MedSigLIP Smart Medical Classifier",
        description="Zero-shot model with automatic label filtering for different modalities.",
        flagging_mode="never",
    )


demo = _build_demo()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Listen on every interface at port 7860, with show_api disabled.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, show_api=False)
    demo.launch(**launch_options)
|
|
|