File size: 2,572 Bytes
a9e2b5c
d672951
 
 
8e99010
d672951
1739e75
d672951
 
bafe93c
8e99010
 
 
a9e2b5c
d672951
 
 
1739e75
a9e2b5c
d672951
 
8e99010
 
d672951
 
 
 
1739e75
d672951
 
 
1739e75
 
 
 
8e99010
 
 
 
d672951
 
 
 
 
 
 
 
1739e75
8e99010
 
 
 
 
 
 
 
 
1739e75
8e99010
1739e75
 
d672951
 
 
 
8e99010
 
 
 
d672951
 
8e99010
 
d672951
8e99010
1739e75
a9e2b5c
 
bafe93c
d672951
bafe93c
d672951
 
7b7a932
a9e2b5c
 
d672951
a9e2b5c
7b7a932
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import gradio as gr
from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

from utils.cache_manager import cached_inference
from utils.modality_router import detect_modality


# Directory containing this source file; label files are resolved relative to
# it rather than to the current working directory.
BASE_DIR = Path(__file__).resolve().parent
# Folder holding the JSON label lists (one file per modality plus a general one).
LABEL_DIR = BASE_DIR / "labels"
# Checkpoint identifier passed to both `from_pretrained` calls below.
MODEL_ID = "google/medsiglip-448"


# Access token read from the environment (None when unset) and forwarded to
# `from_pretrained`. NOTE(review): presumably required because the checkpoint
# is gated on the Hub — confirm.
HF_TOKEN = os.getenv("HF_TOKEN")

# Cap PyTorch's CPU intra-op thread pool at a single thread.
# NOTE(review): this limits CPU inference throughput; presumably chosen for a
# small/shared host — confirm before changing.
torch.set_num_threads(1)

# Run on CUDA with float16 weights when a GPU is present, otherwise on CPU
# with float32 weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Processor and model are created once, when this module is imported (this may
# download them), and are shared by every classification request.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForZeroShotImageClassification.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=model_dtype,
).to(device)
# Evaluation mode: inference only, no training-time layer behaviour.
model.eval()


# Fallback label file per modality. `get_candidate_labels` consults this only
# when LABEL_DIR has no "<modality>_labels.json" for the detected modality.
LABEL_OVERRIDES: Dict[str, str] = {
    "xray": "chest_labels.json",
    "mri": "brain_labels.json",
}


@lru_cache(maxsize=None)
def load_labels(file_name: str) -> List[str]:
    """Read the JSON label list stored at ``LABEL_DIR / file_name``.

    The parsed result is memoised per ``file_name``, so each label file is
    read from disk at most once per process. Note that every caller receives
    the same cached object.

    Raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    raw_text = (LABEL_DIR / file_name).read_text(encoding="utf-8")
    return json.loads(raw_text)


def get_candidate_labels(image_path: str) -> Tuple[str, ...]:
    """Return the candidate labels to score ``image_path`` against.

    ``detect_modality`` picks the modality. The first existing file in
    ``LABEL_DIR`` is then used, tried in this order:

    1. ``<modality>_labels.json``
    2. the file that ``LABEL_OVERRIDES`` maps the modality to, if any
    3. ``general_labels.json`` (the last resort, not checked for existence)

    Labels are returned as a tuple so they are immutable and hashable.
    """
    modality = detect_modality(image_path)

    primary_path = LABEL_DIR / f"{modality}_labels.json"
    if primary_path.exists():
        return tuple(load_labels(primary_path.name))

    # The override table is only consulted when the modality has no file of its own.
    alias_name = LABEL_OVERRIDES.get(modality)
    if alias_name:
        alias_path = LABEL_DIR / alias_name
        if alias_path.exists():
            return tuple(load_labels(alias_path.name))

    return tuple(load_labels("general_labels.json"))


def classify_medical_image(image_path: str) -> Dict[str, float]:
    if not image_path:
        return {}

    candidate_labels = get_candidate_labels(image_path)
    scores = cached_inference(image_path, candidate_labels, model, processor)

    if not scores:
        return {}

    results = sorted(zip(candidate_labels, scores), key=lambda x: x[1], reverse=True)
    top_results = results[:5]

    return {label: float(score) for label, score in top_results}


# --- Gradio UI -------------------------------------------------------------
# The image widget passes classify_medical_image a file path
# (type="filepath"). The Label widget renders the returned {label: score}
# dict, showing at most five classes.
_image_input = gr.Image(type="filepath", label="📤 Upload Medical Image")
_label_output = gr.Label(num_top_classes=5, label="🧠 Top Predictions")

demo = gr.Interface(
    fn=classify_medical_image,
    inputs=_image_input,
    outputs=_label_output,
    title="🩻 MedSigLIP Smart Medical Classifier",
    description=(
        "Zero-shot model with automatic label filtering for different modalities."
    ),
    flagging_mode="never",
)


if __name__ == "__main__":
    # Bind to 0.0.0.0 (all IPv4 interfaces) on the fixed port 7860, with
    # show_api disabled.
    _launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": False,
    }
    demo.launch(**_launch_options)