# MedGemma 1.5 - Multiple Myeloma Risk Assessment (Module 2)
## 🧬 Model Overview
This repository contains a PEFT/LoRA adapter fine-tuned on MedGemma 1.5 4B-IT. It is designed to extract relevant lab values (e.g., CRAB panels, Beta-2 Microglobulin) and behavioural markers from raw clinical text, and to produce a concise risk profile for Multiple Myeloma patients. The adapter was fine-tuned on clinically validated data extracted from the MIMIC-IV and MIMIC-IV-Note datasets.

It was developed as part of a broader agentic AI orchestration system, where it acts as Module 2: it operates alongside downstream vision and progression modules to feed structured clinical data into a RAG-enabled Gradio dashboard.
## 🚀 Associated Code Repository
The complete source code, including data preparation, adapter training, evaluation workflows, and the full agent orchestration pipeline implementing the proposed Mixture of Adapters, is publicly available on GitHub: here
## Base Model Dependency
This is an adapter model. It requires the base weights from Google's MedGemma 1.5 4B-IT.
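Because the base weights are gated behind Google's terms of use, `from_pretrained` can only download them from an authenticated Hugging Face session. A minimal environment setup, assuming you have already accepted the terms on the base model page (package list is illustrative):

```shell
# Install the inference dependencies used in the example below
pip install transformers peft bitsandbytes accelerate pillow

# Authenticate so the gated MedGemma base weights can be downloaded
huggingface-cli login
```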
## ⚠️ License and Terms of Use
- LoRA Adapter Weights: The adapter weights and associated code in this repository are open-sourced under the Apache 2.0 license.
- Base Model: To use this adapter, you must agree to the Google Health AI Developer Foundations Terms of Use to access the underlying MedGemma 1.5 weights.
- Clinical Disclaimer: This model is for educational and research purposes only. It is not a medical device, is not intended for clinical use, and should not be used to diagnose, treat, or offer medical advice for any disease or condition.
## 💻 How to Use
Because this model uses the MedGemma architecture, it is recommended to load the base model in 4-bit NF4 quantization and to pass a dummy image tensor, which stabilizes the cross-attention vision layers during text-only generation.
```python
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel
from PIL import Image
import torch

# 1. Load the base model in 4-bit NF4
model_id = "google/medgemma-1.5-4b-it"
processor = AutoProcessor.from_pretrained(model_id)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quant_config,
)

# 2. Load this LoRA adapter on top of the base model
model = PeftModel.from_pretrained(base_model, "shrishSVaidya/medgemma-1.5-mm-risk-module2")

# 3. Format the prompt and create a dummy image to stabilize the vision layers
prompt = "Assess the Multiple Myeloma risk profile for this <AGE>-year-old <GENDER> patient on this clinical data: [INSERT DATA]"
messages = [{"role": "user", "content": prompt}]
formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
dummy_image = Image.new("RGB", (224, 224), color="black")
inputs = processor(
    text=formatted_prompt,
    images=dummy_image,
    return_tensors="pt",
    padding=True,
).to(model.device)
inputs.pop("token_type_ids", None)

# 4. Generate deterministically and decode only the newly generated tokens
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=400, do_sample=False)
input_length = inputs["input_ids"].shape[1]
print(processor.decode(outputs[0, input_length:], skip_special_tokens=True))
```
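The prompt above uses `<AGE>` and `<GENDER>` placeholders. A minimal helper to fill them from patient fields might look like the sketch below; the `<DATA>` placeholder, the function name, and the field names are illustrative assumptions, not part of the adapter's API:

```python
# Illustrative prompt-templating helper. The <AGE>/<GENDER> placeholders
# mirror the prompt shown above; <DATA> and the argument names are
# assumptions made for this sketch.
TEMPLATE = (
    "Assess the Multiple Myeloma risk profile for this <AGE>-year-old "
    "<GENDER> patient on this clinical data: <DATA>"
)

def build_prompt(age: int, gender: str, clinical_data: str) -> str:
    # Substitute each placeholder with the patient-specific value
    return (
        TEMPLATE.replace("<AGE>", str(age))
        .replace("<GENDER>", gender)
        .replace("<DATA>", clinical_data)
    )

print(build_prompt(68, "male", "Ca 11.2 mg/dL; Cr 2.1 mg/dL; Hgb 9.8 g/dL; B2M 6.1 mg/L"))
```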