---
license: apache-2.0
language: en
library_name: pytorch
tags:
- image-classification
- medical-imaging
- diabetic-retinopathy
- pytorch
- timm
- efficientnet
datasets:
- aptos2019-blindness-detection
widget:
- src: gradcam_visualizations/gradcam_sample_003.png
  example_title: No DR Example
- src: gradcam_visualizations/gradcam_sample_007.png
  example_title: Severe DR Example
---

# Diabetic Retinopathy Grading Model (V2)

This is a multi-task deep learning model trained to classify the severity of Diabetic Retinopathy (DR) from retinal fundus images. It is based on the **EfficientNet-B3** architecture and was specifically optimized to improve the **Quadratic Weighted Kappa (QWK)** score, a clinically relevant metric for ordinal classification tasks such as DR grading.

This model is the second iteration (V2) of a project focused on building a diagnostically "smarter" classifier that is more sensitive to the severe, vision-threatening stages of the disease.

## Model Details

- **Architecture:** `timm/efficientnet_b3` backbone with a custom multi-task head.
- **Input Size:** 512x512 pixels.
- **Output:** A dictionary containing logits for three tasks:
  - `severity`: 5 classes (0: No DR, 1: Mild, 2: Moderate, 3: Severe, 4: Proliferative).
  - `lesions`: 5 classes (multi-label for various lesion types).
  - `regions`: 5 classes (multi-label for affected anatomical regions).
- **Training Objective:** The model was trained on the `severity` task only, with the loss weights of the auxiliary tasks set to zero. The auxiliary heads can still produce outputs for interpretability.

## How to Get Started & Use

The model can be loaded from the Hugging Face Hub for inference.
```bash
# Install required libraries
pip install torch torchvision timm albumentations huggingface-hub numpy pillow opencv-python
```

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download


# Define the model architecture
class MultiTaskDRModel(nn.Module):
    def __init__(self, model_name='efficientnet_b3', num_classes=5,
                 num_lesion_types=5, num_regions=5, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features

        # Channel attention applied to the pooled backbone features
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.feature_dim, self.feature_dim // 8),
            nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim // 8, self.feature_dim),
            nn.Sigmoid()
        )

        self.feature_norm = nn.BatchNorm1d(self.feature_dim)
        self.dropout = nn.Dropout(0.4)

        self.severity_classifier = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(self.feature_dim // 2, num_classes)
        )
        self.lesion_detector = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(self.feature_dim // 4, num_lesion_types)
        )
        self.region_predictor = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(self.feature_dim // 4, num_regions)
        )

    def forward(self, x):
        features = self.backbone.forward_features(x)
        pooled_features = F.adaptive_avg_pool2d(features, 1).flatten(1)
        # Reshape to (B, C, 1, 1) so the attention block's pooling layer accepts it
        attention_weights = self.attention(pooled_features.unsqueeze(-1).unsqueeze(-1))
        features = pooled_features * attention_weights
        features = self.feature_norm(features)
        features = self.dropout(features)
        return {
            'severity': self.severity_classifier(features),
            'lesions': self.lesion_detector(features),
            'regions': self.region_predictor(features),
            'features': features
        }


# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiTaskDRModel()

# Download and load the checkpoint
model_path = hf_hub_download(
    repo_id="dheeren-tejani/DiabeticRetinpathyClassifier",
    filename="best_model_v2.pth"
)
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
print("Model loaded successfully!")


# Preprocessing function
def preprocess_image(image_path):
    transforms = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    image = np.array(Image.open(image_path).convert("RGB"))
    return transforms(image=image)['image'].unsqueeze(0)


# Example inference
def predict_dr_severity(image_path):
    image_tensor = preprocess_image(image_path).to(device)
    with torch.no_grad():
        outputs = model(image_tensor)

    # Get the severity prediction
    severity_probs = torch.softmax(outputs['severity'], dim=1)
    predicted_class = torch.argmax(severity_probs, dim=1).item()
    confidence = severity_probs[0, predicted_class].item()

    severity_labels = {
        0: "No DR",
        1: "Mild DR",
        2: "Moderate DR",
        3: "Severe DR",
        4: "Proliferative DR"
    }
    return {
        'predicted_severity': severity_labels[predicted_class],
        'confidence': confidence,
        'all_probabilities': severity_probs[0].cpu().numpy()
    }


# Example usage
# result = predict_dr_severity("path/to/your/fundus_image.jpg")
# print(f"Predicted: {result['predicted_severity']} (Confidence: {result['confidence']:.3f})")
```

## Training Details

### V2 Improvements

This model (V2) was specifically designed to address the shortcomings of a baseline model (V1) that struggled with severe-stage DR detection:

- **Higher Resolution:** Increased the input size from 224×224 to 512×512 to capture finer pathological details.
- **Class Balancing:** Used a `WeightedRandomSampler` to oversample the rare minority classes (Severe and Proliferative DR).
- **Focal Loss:** Replaced standard cross-entropy with Focal Loss (γ = 2.0) to focus training on hard-to-classify examples.
- **Focused Training:** Set the auxiliary task loss weights to zero, dedicating the model's full capacity to severity classification.

### Hyperparameters

- **Optimizer:** AdamW
- **Learning Rate:** 1e-4
- **Scheduler:** CosineAnnealingWarmRestarts (T_MAX=10)
- **Batch Size:** 16
- **Epochs:** 17 (early stopping)
- **Image Size:** 512×512

## Performance

The model was evaluated on a held-out validation set of 735 images:

| Metric | Score |
|--------|-------|
| **Quadratic Weighted Kappa (QWK)** | **0.796** |
| Accuracy | 65.0% |
| F1-Score (Weighted) | 66.3% |
| F1-Score (Macro) | 53.5% |

### Key Achievement

The V2 model improved QWK by **+0.035** over the V1 baseline (0.761 → 0.796), indicating that it makes "smarter" errors more aligned with clinical judgment, despite a lower overall accuracy. This trade-off prioritizes clinically relevant performance over naive accuracy.

## Limitations

⚠️ **Important Disclaimers:**

- This model was trained on a single public dataset and may not generalize to different clinical settings, camera types, or patient demographics.
- The dataset may contain inherent demographic biases.
- **This is NOT a medical device** and should not be used for actual clinical diagnosis.
- Always consult qualified healthcare professionals for medical decisions.

## Citation

If you use this model in your research, please cite:

```bibtex
@misc{dheerentejani2025dr,
  author       = {Dheeren Tejani},
  title        = {Diabetic Retinopathy Grading Model V2},
  year         = {2025},
  publisher    = {Hugging Face},
  journal      = {Hugging Face Model Hub},
  howpublished = {\url{https://huggingface.co/dheeren-tejani/DiabeticRetinpathyClassifier}},
}
```

## License

This model is released under the Apache 2.0 License.
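
As background on the headline metric: QWK penalizes each misclassification by the squared distance between the predicted and true grade, so mistaking Proliferative DR (4) for Severe DR (3) costs far less than calling it No DR (0). A minimal illustration using scikit-learn's `cohen_kappa_score` (not part of this repository's code; shown only to make the metric concrete):

```python
from sklearn.metrics import cohen_kappa_score

# Toy labels: one example per DR grade (0 = No DR ... 4 = Proliferative).
y_true = [0, 1, 2, 3, 4]

# A near miss (grade 4 predicted as 3) vs. a far miss (grade 4 predicted as 0).
near_miss = cohen_kappa_score(y_true, [0, 1, 2, 3, 3], weights="quadratic")
far_miss = cohen_kappa_score(y_true, [0, 1, 2, 3, 0], weights="quadratic")

# Quadratic weighting penalizes the far miss much more heavily, which is
# why QWK rewards the "clinically smarter" errors described above.
assert near_miss > far_miss
print(f"near-miss QWK: {near_miss:.3f}, far-miss QWK: {far_miss:.3f}")
```

This is why a model can score higher on QWK while losing raw accuracy: what matters is how far off the wrong predictions are, not just how many there are.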
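
For reference, the Focal Loss used in training down-weights well-classified examples by a factor of (1 − p_t)^γ, concentrating the gradient on hard cases such as the rare severe grades. A minimal multi-class sketch with γ = 2.0; the exact implementation used during training may differ (e.g. in reduction or per-class weighting):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Multi-class focal loss: FL = (1 - p_t)^gamma * CE, averaged over the batch."""
    ce = F.cross_entropy(logits, targets, reduction="none")  # per-sample -log(p_t)
    p_t = torch.exp(-ce)                                     # probability of the true class
    return ((1.0 - p_t) ** gamma * ce).mean()

# A confidently correct example contributes much less than under plain
# cross-entropy; with gamma = 0 the focal loss reduces to cross-entropy.
logits = torch.tensor([[2.0, 0.5, 0.1, 0.0, 0.0]])
targets = torch.tensor([0])
print(focal_loss(logits, targets).item())
```

Because (1 − p_t)^γ shrinks toward zero as confidence grows, the abundant easy "No DR" images stop dominating the loss, leaving more gradient signal for the minority severe classes.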