xray / app.py
fokan's picture
Upload 3 files
8e99010 verified
raw
history blame
2.57 kB
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import gradio as gr
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
from utils.cache_manager import cached_inference
from utils.modality_router import detect_modality
# --- Filesystem layout -------------------------------------------------------
# Label JSON files live in a "labels" folder that sits next to this module.
BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"

# --- Model configuration -----------------------------------------------------
MODEL_ID = "google/medsiglip-448"
# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Restrict torch to a single intra-op CPU thread.
torch.set_num_threads(1)

# Run on the GPU when one is available; half precision is only used there,
# CPU inference stays in float32.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load the processor and the zero-shot classifier once, at import time.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=model_dtype,
)
model = model.to(device)
# Inference only: switch the module to evaluation mode.
model.eval()
# Modality name -> label file to fall back on. get_candidate_labels consults
# this mapping when no "<modality>_labels.json" file exists in LABEL_DIR.
LABEL_OVERRIDES = dict(
    xray="chest_labels.json",
    mri="brain_labels.json",
)
@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Parse the JSON label file ``LABEL_DIR / file_name`` and return its content.

    The result is memoised per ``file_name`` with no size limit, so each file
    is read from disk at most once and repeated calls hand back the same
    object.

    Args:
        file_name: Name of a JSON file inside ``LABEL_DIR``.

    Returns:
        The decoded JSON document (annotated as a list of label strings).

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If the file is not valid JSON.
    """
    raw_text = (LABEL_DIR / file_name).read_text(encoding="utf-8")
    return json.loads(raw_text)
def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Pick the label set that matches the modality detected for an image.

    Label files are tried in this order, and the first one that exists wins:

    1. ``<modality>_labels.json``
    2. the file mapped to the modality in ``LABEL_OVERRIDES`` (if any)
    3. ``general_labels.json``

    Args:
        image_path: Path of the image, forwarded to ``detect_modality``.

    Returns:
        The labels from the selected file, as a tuple.
    """
    modality = detect_modality(image_path)

    # First choice: a label file named after the detected modality.
    modality_file = LABEL_DIR / f"{modality}_labels.json"
    if modality_file.exists():
        return tuple(load_labels(modality_file.name))

    # Second choice: an explicitly configured override for this modality.
    override_name = LABEL_OVERRIDES.get(modality)
    if override_name:
        override_file = LABEL_DIR / override_name
        if override_file.exists():
            return tuple(load_labels(override_file.name))

    # Last resort: the generic label set.
    return tuple(load_labels("general_labels.json"))
def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Score an image against its modality's labels and keep the best five.

    Args:
        image_path: Path of the image to classify. A falsy value (``None`` or
            an empty string) short-circuits to an empty result.

    Returns:
        A dict mapping label to score (as ``float``) for at most five labels,
        inserted from highest to lowest score. Empty when no image path was
        given or when ``cached_inference`` returns a falsy value.
    """
    if not image_path:
        return {}

    labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, labels, model, processor)
    if not scores:
        return {}

    # Scores are paired with labels by position, then ordered best-first.
    ranked = list(zip(labels, scores))
    ranked.sort(key=lambda pair: pair[1], reverse=True)

    top_predictions = {}
    for label, score in ranked[:5]:
        top_predictions[label] = float(score)
    return top_predictions
# Gradio UI: one image upload (passed to the classifier as a file path) and a
# label widget that shows the five best predictions.
demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="📤 Upload Medical Image"),
    outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
    title="🩻 MedSigLIP Smart Medical Classifier",
    description="Zero-shot model with automatic label filtering for different modalities.",
    flagging_mode="never",
)

if __name__ == "__main__":
    # Serve on every network interface at port 7860, without the API page.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
    )