import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import gradio as gr
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
from utils.cache_manager import cached_inference
from utils.modality_router import detect_modality
# --- Module-level configuration and one-time model load -----------------
# Label JSON files live in a "labels/" directory next to this script.
BASE_DIR = Path(__file__).resolve().parent
LABEL_DIR = BASE_DIR / "labels"
# Hugging Face model id; MedSigLIP is gated, so an HF_TOKEN env var is
# required for download (None is passed through if unset).
MODEL_ID = "google/medsiglip-448"
HF_TOKEN = os.getenv("HF_TOKEN")
# Limit intra-op CPU threads — presumably to keep a shared/containerized
# host responsive; TODO confirm this matches the deployment target.
torch.set_num_threads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fp16 on GPU for memory/speed; fp32 on CPU where fp16 is poorly supported.
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=model_dtype,
).to(device)
# Inference-only: disable dropout/batch-norm training behavior.
model.eval()
# Fallback mapping used when "<modality>_labels.json" does not exist:
# these modalities share a label file with another body region.
LABEL_OVERRIDES = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}
@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Load the JSON label list stored under LABEL_DIR, caching per file name."""
    return json.loads((LABEL_DIR / file_name).read_text(encoding="utf-8"))
def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Pick the label set for an image's detected modality.

    Resolution order: "<modality>_labels.json", then the LABEL_OVERRIDES
    entry for that modality, then "general_labels.json" as the catch-all.
    """
    modality = detect_modality(image_path)
    chosen = LABEL_DIR / f"{modality}_labels.json"
    # No modality-specific file: try the configured override, if any.
    if not chosen.exists() and (override := LABEL_OVERRIDES.get(modality)):
        chosen = LABEL_DIR / override
    # Still nothing on disk: fall back to the general-purpose label list.
    if not chosen.exists():
        chosen = LABEL_DIR / "general_labels.json"
    return tuple(load_labels(chosen.name))
def classify_medical_image(image_path: str) -> Dict[str, float]:
    """Zero-shot classify an image and return its top-5 {label: score} pairs.

    Returns an empty dict when no image is provided or inference yields
    no scores.
    """
    if not image_path:
        return {}
    labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, labels, model, processor)
    if not scores:
        return {}
    # Rank label/score pairs by descending score and keep the best five.
    top_five = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)[:5]
    return {name: float(value) for name, value in top_five}
# Gradio UI: a single filepath image input mapped to a label/score output.
demo = gr.Interface(
    fn=classify_medical_image,
    inputs=gr.Image(type="filepath", label="📤 Upload Medical Image"),
    outputs=gr.Label(num_top_classes=5, label="🧠 Top Predictions"),
    title="🩻 MedSigLIP Smart Medical Classifier",
    description="Zero-shot model with automatic label filtering for different modalities.",
    flagging_mode="never",
)

if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces so the app is reachable from outside a
    # container; show_api=False hides the auto-generated API docs page.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)