---
license: mit
language: en
library_name: transformers
tags:
- unet
- film
- computer-vision
- image-segmentation
- medical-imaging
- pytorch
---

# FILMUnet2D

This model is a 2D U-Net with FiLM (feature-wise linear modulation) conditioning for multi-organ ultrasound segmentation.

## Installation

Make sure you have `transformers`, `torch`, and `Pillow` installed. The usage example below imports `PIL` to load images.

```bash
pip install transformers torch pillow
```

## Usage

You can load the model and processor using the `Auto` classes from `transformers`. Since this repository contains custom code, make sure to pass `trust_remote_code=True`.

```python
import torch
from transformers import AutoModel, AutoImageProcessor
from PIL import Image

# 1. Load model and processor
repo_id = "AImageLab-Zip/US_FiLMUNet" 

processor = AutoImageProcessor.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# 2. Load and preprocess your image
#    The processor handles resizing, letterboxing, and normalization.
image = Image.open("path/to/your/image.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# 3. Prepare conditioning input
#    This should be an integer tensor representing the organ ID.
#    Replace `4` (`fetal` in the Organ IDs mapping) with the ID for your use case.
organ_id = torch.tensor([4])

# 4. Run inference
with torch.no_grad():
    outputs = model(**inputs, organ_id=organ_id)

# 5. Post-process the output to get the final segmentation mask
#    The processor can convert the logits to a binary mask, automatically handling
#    the removal of letterbox padding and resizing to the original image dimensions.
mask = processor.post_process_semantic_segmentation(
    outputs, 
    inputs, 
    threshold=0.7, 
    return_as_pil=True
)[0]

# 6. Save the result
mask.save("output_mask.png")

print("Segmentation mask saved to output_mask.png")
```
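
To make the `threshold` argument concrete, the following is a minimal sketch of how logits are commonly turned into a binary mask. It assumes the threshold is compared against sigmoid probabilities, which this card does not confirm; the function name `logits_to_binary_mask` is hypothetical and is not part of the repository's processor.

```python
import torch


def logits_to_binary_mask(logits: torch.Tensor, threshold: float = 0.7) -> torch.Tensor:
    """Hypothetical helper: sigmoid the logits, then keep pixels above the threshold."""
    # Assumption: the threshold applies to probabilities, not to raw logits.
    probabilities = torch.sigmoid(logits)
    return (probabilities > threshold).to(torch.uint8)


# Toy 2x2 "image" of logits standing in for a model output.
logits = torch.tensor([[-2.0, 0.0], [1.0, 3.0]])
mask = logits_to_binary_mask(logits)
```

Under this assumption, raising the threshold yields a more conservative mask and lowering it yields a more inclusive one.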

### Model Details

- **Architecture:** U-Net with FiLM layers for conditional segmentation.
- **Conditioning:** The model's output is conditioned on an `organ_id` input.
- **Input:** RGB images.
- **Output:** A single-channel segmentation mask.
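
For readers unfamiliar with FiLM, the idea can be sketched as a per-channel scale and shift selected by the conditioning ID. The block below is an illustrative sketch only: the class name `FiLM`, the use of embedding tables to produce the scale and shift, and the sizes chosen are assumptions, not the implementation shipped in this repository.

```python
import torch
from torch import nn


class FiLM(nn.Module):
    """Illustrative FiLM layer: scale and shift each feature channel per sample."""

    def __init__(self, n_conditions: int, n_channels: int):
        super().__init__()
        # One (gamma, beta) pair per channel for every condition ID (assumed design).
        self.gamma = nn.Embedding(n_conditions, n_channels)
        self.beta = nn.Embedding(n_conditions, n_channels)
        # Initialise to the identity transform: gamma = 1, beta = 0.
        nn.init.ones_(self.gamma.weight)
        nn.init.zeros_(self.beta.weight)

    def forward(self, features: torch.Tensor, condition_id: torch.Tensor) -> torch.Tensor:
        # features: (batch, channels, height, width); condition_id: (batch,)
        gamma = self.gamma(condition_id)[:, :, None, None]
        beta = self.beta(condition_id)[:, :, None, None]
        return gamma * features + beta


# Illustrative sizes: 8 condition IDs and 16 feature channels.
film = FiLM(n_conditions=8, n_channels=16)
features = torch.randn(2, 16, 32, 32)
modulated = film(features, torch.tensor([4, 5]))
```

Because the scale and shift depend only on the condition ID, the same convolutional weights can be shared across organs while the feature statistics are adjusted per organ.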

### Configuration

The model configuration can be accessed via `model.config`. Key parameters include:
- `in_channels`: Number of input channels (default: 3).
- `num_classes`: Number of output classes (default: 1).
- `n_organs`: The number of different organs the model was trained to condition on.
- `depth`: The depth of the U-Net.
- `size`: The base number of filters in the first layer.

### Organ IDs

The `organ_id` passed to the model corresponds to the following mapping:

```python
organ_to_class_dict = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}
```
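
If you would rather work with organ names than raw integers, a small lookup helper can guard against typos. This is a sketch built on the mapping above; the names `ORGAN_TO_CLASS` and `organ_id_for` are hypothetical and are not provided by the repository.

```python
# Copy of the mapping documented above.
ORGAN_TO_CLASS = {
    "appendix": 0,
    "breast": 1,
    "breast_luminal": 1,
    "cardiac": 2,
    "thyroid": 3,
    "fetal": 4,
    "kidney": 5,
    "liver": 6,
    "testicle": 7,
}


def organ_id_for(name: str) -> int:
    """Return the integer organ ID for a name, ignoring case and outer whitespace."""
    key = name.strip().lower()
    if key not in ORGAN_TO_CLASS:
        valid = ", ".join(sorted(ORGAN_TO_CLASS))
        raise ValueError(f"Unknown organ {name!r}; expected one of: {valid}")
    return ORGAN_TO_CLASS[key]


kidney_id = organ_id_for("Kidney")
```

The result can then be wrapped for the model, for example `torch.tensor([organ_id_for("kidney")])`. Note that `breast` and `breast_luminal` share the same ID.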

### Alternative Versions

This repository contains multiple versions of the model located in subfolders. You can load a specific version by using the `subfolder` parameter.

#### 4-Stage U-Net

This version has a U-Net depth of 4.

```python
from transformers import AutoModel

model_4_stages = AutoModel.from_pretrained(
    "AImageLab-Zip/US_FiLMUNet", 
    subfolder="unet_4_stages",
    trust_remote_code=True
)
```