maddigit committed (verified)
Commit ddbdbca · 1 Parent(s): eb2cd25

Upload 27 files
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/figures/architecture.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/figures/CreatiDesign_logo.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/figures/dataset.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/figures/loop_edit.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/figures/motivation.jpg filter=lfs diff=lfs merge=lfs -text
41
+ assets/figures/Qualitative_results.jpg filter=lfs diff=lfs merge=lfs -text
42
+ assets/figures/Quantitative_results.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/figures/teaser.jpg filter=lfs diff=lfs merge=lfs -text
44
+ dataloader/arial.ttf filter=lfs diff=lfs merge=lfs -text
45
+ modules/flux/__pycache__/attention_processor_flux_creatidesign.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,105 @@
1
- ---
2
- title: Layout Crazydesign
3
- emoji: 🖼
4
- colorFrom: purple
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.44.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+
2
+ # <img src='assets/figures/CreatiDesign_logo.png' alt="CreatiDesign Logo" width='24px' /> CreatiDesign
3
+
4
+
5
+ <img src='assets/figures/teaser.jpg' width='100%' />
6
+
7
+ <br>
8
+ <a href="https://arxiv.org/pdf/2505.19114"><img src="https://img.shields.io/static/v1?label=Paper&message=2505.19114&color=red&logo=arxiv"></a>
9
+ <a href="https://huizhang0812.github.io/CreatiDesign/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a>
10
+ <a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_dataset"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
11
+ <a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_benchmark"><img src="https://img.shields.io/badge/🤗_HuggingFace-Benchmark-ffbd45.svg" alt="HuggingFace"></a>
12
+ <a href="https://huggingface.co/HuiZhang0812/CreatiDesign"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
13
+
14
+
15
+
16
+ > <img src='assets/figures/CreatiDesign_logo.png' alt="CreatiDesign Logo" width='15px' /> **CreatiDesign: A Unified Multi-Conditional Diffusion Transformer for Creative Graphic Design**
17
+ > <br>
18
+ > [Hui Zhang](https://huizhang0812.github.io/),
19
+ > [Dexiang Hong](https://scholar.google.com.hk/citations?user=DUNijlcAAAAJ&hl=zh-CN),
20
+ > Maoke Yang,
21
+ > Yutao Cheng,
22
+ > Zhao Zhang,
23
+ > Jie Shao,
24
+ > [Xinglong Wu](https://scholar.google.com/citations?user=LVsp9RQAAAAJ&hl=zh-CN),
25
+ > [Zuxuan Wu](https://zxwu.azurewebsites.net/),
26
+ > and
27
+ > [Yu-Gang Jiang](https://scholar.google.com/citations?user=f3_FP8AAAAAJ)
28
+ > <br>
29
+ > Fudan University & ByteDance Intelligent Creation.
30
+ > <br>
31
+
32
+ ## 🎯 Introduction
33
+ CreatiDesign tackles the challenge of automated graphic design generation, which requires precise control over multiple heterogeneous elements: primary visual elements (product images), secondary visual elements (decorative objects), and textual elements (slogans, titles). It introduces a unified multi-conditional diffusion transformer that integrates these diverse design elements flexibly and harmoniously with minimal architectural modifications.
34
+
35
+
36
+ <img src='assets/figures/motivation.jpg' width='100%' />
37
+
38
+ ## ✨ Key Features
39
+
40
+ - **🎨 Multi-Conditional Image Generation**: Unified architecture supporting image and semantic-layout conditions simultaneously
41
+ - **🎯 Precise Element Control**: Multimodal attention mask mechanism prevents condition interference
42
+ - **🗂️ Graphic Design Datasets**: 400K graphic design samples with multi-condition annotations, constructed by an automatic pipeline
43
+ - **📊 Comprehensive Benchmark**: Rigorous evaluation of multi-subject preservation and semantic layout alignment
44
+ - **✏️ Zero-Shot Editing**: Natural extension to editing tasks without any additional training
45
+
46
+
47
+
48
+ ## Quick Start
49
+ ### Setup
50
+ 1. **Environment setup**
51
+ ```bash
52
+ conda create -n creatidesign python=3.10 -y
53
+ conda activate creatidesign
54
+ conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.1 -c pytorch -c nvidia
55
+ ```
56
+ 2. **Requirements installation**
57
+ ```bash
58
+ pip install -r requirements.txt
59
+ ```
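+
+ As an optional sanity check (not part of the original setup steps), you can confirm that PyTorch and CUDA are visible before proceeding:
+ ```bash
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ ```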
60
+
61
+
62
+ ## Dataset and Benchmark
63
+ ### CreatiDesign Datasets <a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_dataset"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
64
+ Our CreatiDesign dataset contains **400K high-quality graphic design samples** with comprehensive multi-condition annotations, constructed through our fully automated pipeline. The dataset covers diverse design categories including movie posters, product advertisements, brand promotions, and social media content.
65
+
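+ As a minimal sketch of how to pull the data (the `train` split name and streaming usage are assumptions; check the dataset card for the exact schema), the dataset can be loaded with 🤗 `datasets`:
+ ```python
+ from datasets import load_dataset
+
+ # Stream records to avoid downloading all 400K samples up front.
+ ds = load_dataset("HuiZhang0812/CreatiDesign_dataset", split="train", streaming=True)
+ sample = next(iter(ds))
+ print(sample.keys())  # inspect the available fields; see the dataset card for the exact schema
+ ```
+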
66
+ ### CreatiDesign Benchmark <a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_benchmark"><img src="https://img.shields.io/badge/🤗_HuggingFace-Benchmark-ffbd45.svg" alt="HuggingFace"></a>
67
+ Our comprehensive benchmark contains **1,000 carefully curated samples** designed to rigorously evaluate graphic design generation capabilities across multiple dimensions. The benchmark assesses both fine-grained condition adherence and overall visual quality.
68
+
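+ The benchmark is hosted on the Hugging Face Hub and is loaded the same way the evaluation scripts in this repo do (see `dataloader/creatidesign_dataset_benchmark.py`):
+ ```python
+ from datasets import load_dataset
+
+ benchmark = load_dataset("HuiZhang0812/CreatiDesign_benchmark", split="test")
+ print(len(benchmark))  # 1,000 curated design cases
+ ```
+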
69
+ To evaluate the model's graphic design generation capabilities through our benchmark, follow these steps:
70
+
71
+ Generate images:
72
+ ```bash
73
+ python test_creatidesign_benchmark.py
74
+ ```
75
+ Evaluate multi-subject preservation:
76
+ ```bash
77
+ python eval/subject.py
78
+ ```
79
+ Evaluate semantic layout alignment:
80
+ ```bash
81
+ python eval/layout.py
82
+ ```
83
+ Evaluate text rendering accuracy:
+ ```bash
84
+ python eval/text.py
85
+ ```
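+
+ Assuming the generation script writes to the default location used by the evaluation scripts (`outputs/CreatiDesign_benchmark`), the metric files end up roughly as follows (a sketch, not an exhaustive listing):
+ ```
+ outputs/CreatiDesign_benchmark/
+ ├── images/                      # generated .jpg images, one per benchmark case
+ ├── scores.csv                   # multi-subject preservation (CLIP-I / DINO / M-DINO) from eval/subject.py
+ ├── minicpm-vqa.json             # per-image layout VQA results from eval/layout.py
+ ├── minicpm-vqa-score.txt        # aggregated layout scores
+ ├── text_results_per_image.csv   # per-image OCR text metrics from eval/text.py
+ └── text_overall.txt             # aggregated text metrics from eval/text.py
+ ```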
86
+
87
+
88
+ ## Models
89
+ **Multi-Conditional Graphic Design:**
90
+ | Model | Base model | Description |
91
+ | ------------------------------------------------------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------- |
92
+ | <a href="https://huggingface.co/HuiZhang0812/CreatiDesign"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a> | FLUX.1-dev | Model used in the paper |
93
+
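+ A minimal sketch for fetching the weights locally with `huggingface_hub` (how the checkpoint is wired into the FLUX.1-dev pipeline is handled by the repo's scripts, e.g. `test_creatidesign_benchmark.py`):
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Downloads the CreatiDesign repo into the local HF cache and returns its path.
+ local_dir = snapshot_download(repo_id="HuiZhang0812/CreatiDesign")
+ print(local_dir)
+ ```
+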
94
+ ## ✒️ Citation
95
+
96
+ If you find our work useful for your research and applications, please kindly cite using this BibTeX:
97
+
98
+ ```latex
99
+ @article{zhang2025creatidesign,
100
+ title={CreatiDesign: A Unified Multi-Conditional Diffusion Transformer for Creative Graphic Design},
101
+ author={Zhang, Hui and Hong, Dexiang and Yang, Maoke and Chen, Yutao and Zhang, Zhao and Shao, Jie and Wu, Xinglong and Wu, Zuxuan and Jiang, Yu-Gang},
102
+ journal={arXiv preprint arXiv:2505.19114},
103
+ year={2025}
104
+ }
105
+ ```
assets/figures/CreatiDesign_logo.png ADDED

Git LFS Details

  • SHA256: 2ffea0372e673a7381bbf369e6675089f1b0218d744249cb2d921d3267f12e14
  • Pointer size: 131 Bytes
  • Size of remote file: 477 kB
assets/figures/Qualitative_results.jpg ADDED

Git LFS Details

  • SHA256: cc5888193d7aa89cf67dc059323884a13d3bb302fcc61e1138dd349c9a3b1016
  • Pointer size: 132 Bytes
  • Size of remote file: 2.82 MB
assets/figures/Quantitative_results.png ADDED

Git LFS Details

  • SHA256: 62550edea5dc294ea9f57e6647378b7d64f7211518cbe83b7361fdd86ae01f28
  • Pointer size: 131 Bytes
  • Size of remote file: 634 kB
assets/figures/architecture.jpg ADDED

Git LFS Details

  • SHA256: 93acf777f1005b6baaf26627f5c5b1b3c928972d18ed2df3fb8eda96c90145bb
  • Pointer size: 131 Bytes
  • Size of remote file: 872 kB
assets/figures/dataset.jpg ADDED

Git LFS Details

  • SHA256: 8868318d37f5f176fc7845be91bfe2a63eb783d6f17e6c350430cc7843b87619
  • Pointer size: 131 Bytes
  • Size of remote file: 342 kB
assets/figures/loop_edit.jpg ADDED

Git LFS Details

  • SHA256: b080cdbebf31707aa9a5c8877d207bde463842ecbf580c059a5784b670a6da9a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
assets/figures/motivation.jpg ADDED

Git LFS Details

  • SHA256: 8fa691768656ccadb3d26e9b12d9729c8b9f6ce843e4671af7e02a32335dfef7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
assets/figures/teaser.jpg ADDED

Git LFS Details

  • SHA256: 762c885fc465688dc680ea01f564fd041b45938a0c092122e7d3d0331a68a005
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
dataloader/__pycache__/creatidesign_dataset_benchmark.cpython-310.pyc ADDED
Binary file (13.1 kB).
 
dataloader/arial.ttf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35c0f3559d8db569e36c31095b8a60d441643d95f59139de40e23fada819b833
3
+ size 275572
dataloader/creatidesign_dataset_benchmark.py ADDED
@@ -0,0 +1,554 @@
1
+ import os
2
+ import json
3
+ from PIL import Image
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torchvision import transforms
6
+ import torch
7
+ import numpy as np
8
+ import random
9
+ from datasets import load_dataset
10
+ from tqdm import tqdm
11
+ def find_nearest_bucket_size(input_width, input_height, mode="x64", ratio=1):
12
+ buckets = [
13
+ (512, 2048),
14
+ (512, 1984),
15
+ (512, 1920),
16
+ (512, 1856),
17
+ (576, 1792),
18
+ (576, 1728),
19
+ (576, 1664),
20
+ (640, 1600),
21
+ (640, 1536),
22
+ (704, 1472),
23
+ (704, 1408),
24
+ (704, 1344),
25
+ (768, 1344),
26
+ (768, 1280),
27
+ (832, 1216),
28
+ (832, 1152),
29
+ (896, 1152),
30
+ (896, 1088),
31
+ (960, 1088),
32
+ (960, 1024),
33
+ (1024, 1024),
34
+ (1024, 960),
35
+ (1088, 960),
36
+ (1088, 896),
37
+ (1152, 896),
38
+ (1152, 832),
39
+ (1216, 832),
40
+ (1280, 768),
41
+ (1344, 768),
42
+ (1408, 704),
43
+ (1472, 704),
44
+ (1536, 640),
45
+ (1600, 640),
46
+ (1664, 576),
47
+ (1728, 576),
48
+ (1792, 576),
49
+ (1856, 512),
50
+ (1920, 512),
51
+ (1984, 512),
52
+ (2048, 512)
53
+ ]
54
+ aspect_ratios = [w / h for (w, h) in buckets]
55
+
56
+ assert mode in ["x64", "x8"]
57
+ if mode == "x64":
58
+ asp = input_width / input_height
59
+ diff = [abs(ar - asp) for ar in aspect_ratios]
60
+ bucket_id = int(np.argmin(diff))
61
+ gen_width, gen_height = buckets[bucket_id]
62
+ elif mode == "x8":
63
+ max_pixels = 1024 * 1024
64
+ ratio = (max_pixels / (input_width * input_height)) ** (0.5)
65
+ gen_width, gen_height = round(input_width * ratio), round(input_height * ratio)
66
+ gen_width = gen_width - gen_width % 8
67
+ gen_height = gen_height - gen_height % 8
68
+ else:
69
+ raise NotImplementedError
70
+
71
+ return (int(gen_width * ratio), int(gen_height * ratio))
72
+
73
+ def adjust_and_normalize_bboxes(bboxes, orig_width, orig_height):
74
+ # Adjust and normalize bbox
75
+ normalized_bboxes = []
76
+ for bbox in bboxes:
77
+ x1, y1, x2, y2 = bbox
78
+ x1_norm = round(x1 / orig_width,2)
79
+ y1_norm = round(y1 / orig_height,2)
80
+ x2_norm = round(x2 / orig_width,2)
81
+ y2_norm = round(y2 / orig_height,2)
82
+
83
+
84
+ normalized_bboxes.append([x1_norm, y1_norm, x2_norm, y2_norm])
85
+
86
+ return normalized_bboxes
87
+
88
+ def img_transforms(image, height=512, width=512):
89
+ transform = transforms.Compose(
90
+ [
91
+ transforms.Resize(
92
+ (height, width), interpolation=transforms.InterpolationMode.BILINEAR
93
+ ),
94
+ transforms.ToTensor(),
95
+ transforms.Normalize([0.5], [0.5]),
96
+ ]
97
+ )
98
+ image_transformed = transform(image)
99
+ return image_transformed
100
+
101
+ def mask_transforms(mask, height=512, width=512):
102
+ transform = transforms.Compose(
103
+ [
104
+ transforms.Resize(
105
+ (height, width),
106
+ interpolation=transforms.InterpolationMode.NEAREST
107
+ ),
108
+ transforms.ToTensor(),
109
+ ]
110
+ )
111
+ mask_transformed = transform(mask)
112
+ return mask_transformed
113
+
114
+
115
+ class DesignDataset(Dataset):
116
+
117
+ def __init__(
118
+ self,
119
+ dataset_name,
120
+ resolution=512,
121
+ condition_resolution=512,
122
+ condition_resolution_scale_ratio=0.5,
123
+ max_boxes_per_image=10,
124
+ neg_condition_image = 'same',
125
+ background_color = 'gray',
126
+ use_bucket=True,
127
+ box_confidence_th = 0.0
128
+ ):
129
+
130
+
131
+ print(f"Loading dataset from Hugging Face: {dataset_name}")
132
+
133
+ self.dataset = load_dataset(dataset_name, split="test")
134
+ print(f"Loaded {len(self.dataset)} samples")
135
137
+ self.max_boxes_per_image = max_boxes_per_image
138
+ self.resolution = resolution
139
+ self.condition_resolution=condition_resolution
140
+ self.neg_condition_image = neg_condition_image
141
+ self.use_bucket = use_bucket
142
+ self.condition_resolution_scale_ratio=condition_resolution_scale_ratio
143
+ self.box_confidence_th = box_confidence_th
144
+
145
+ if background_color == 'white':
146
+ self.background_color = (255, 255, 255)
147
+ elif background_color == 'black':
148
+ self.background_color = (0, 0, 0)
149
+ elif background_color == 'gray':
150
+ self.background_color = (128, 128, 128)
151
+ else:
152
+ raise ValueError("Invalid background color. Use 'white', 'black', or 'gray'.")
153
+
154
+
155
+ def __len__(self):
156
+ return len(self.dataset)
157
+
158
+ def __getitem__(self, idx):
159
+ sample = self.dataset[idx]
160
+ image_source = sample['original_image']
161
+ subject_image = sample['condition_gray_background']
162
+ subject_mask = sample['subject_mask']
163
+ json_data = json.loads(sample['metadata'])
164
+
165
+ #img info
166
+ img_info = json_data['img_info']
167
+ img_id = img_info['img_id']
168
+ orig_width, orig_height = int(img_info["img_width"]),int(img_info["img_height"])
169
+
170
+ if self.use_bucket:
171
+ target_width, target_height = find_nearest_bucket_size(orig_width,orig_height)
172
+ condition_width = int(target_width * self.condition_resolution_scale_ratio)
173
+ condition_height = int(target_height * self.condition_resolution_scale_ratio)
174
+ else:
175
+ target_width = target_height = self.resolution
176
+ condition_width = condition_height = self.condition_resolution
177
+
178
+
179
+ img_tensor = img_transforms(image_source,height=target_height,width=target_width)
180
+
181
+
182
+ # global caption
183
+ global_caption = json_data['global_caption']
184
+
185
+
186
+ # object_annotations
187
+ object_annotations = json_data['object_annotations']
188
+
189
+ # object bbox list
190
+ objects_bbox = [item['bbox'] for item in object_annotations]
191
+
192
+ # object bbox caption
193
+ objects_caption = [item['bbox_detail_description'] for item in object_annotations]
194
+
195
+ # object bbox score
196
+ objects_bbox_score = [item['score'][0] for item in object_annotations]
197
+
198
+ # text
199
+ text_list = json_data["text_list"]
200
+ txt_bboxs = [item['bbox'] for item in text_list]
201
+ txt_captions = ["text:"+item['text'] for item in text_list]
202
+
203
+ txt_scores = [1.0 for _ in txt_bboxs]
204
+ # combine object and text bounding boxes and descriptions
205
+ objects_bbox.extend(txt_bboxs)
206
+ objects_caption.extend(txt_captions)
207
+ objects_bbox_score.extend(txt_scores)
208
+
209
+ objects_bbox =torch.tensor(adjust_and_normalize_bboxes(objects_bbox,orig_width,orig_height))
210
+
211
+ objects_bbox_score = torch.tensor(objects_bbox_score)
212
+
213
+ boxes_mask = objects_bbox_score > self.box_confidence_th
214
+ objects_bbox_raw = objects_bbox[boxes_mask]
215
+ objects_caption = [object_caption for object_caption, box_mask in zip(objects_caption, boxes_mask) if box_mask]
216
+
217
+
218
+ num_boxes = objects_bbox_raw.shape[0]
219
+ objects_boxes_padded = torch.zeros((self.max_boxes_per_image, 4))
220
+ objects_masks_padded = torch.zeros(self.max_boxes_per_image)
221
+
222
+ objects_caption = objects_caption[:self.max_boxes_per_image]
223
+ objects_boxes_padded[:num_boxes] = objects_bbox_raw[:self.max_boxes_per_image]
224
+ objects_masks_padded[:num_boxes] = 1.
225
+
226
+ # objects_masks_maps
227
+ objects_masks_maps_padded = torch.zeros((self.max_boxes_per_image, target_height, target_width))
228
+ for idx in range(num_boxes):
229
+ x1, y1, x2, y2 = objects_boxes_padded[idx]
230
+
231
+ x1_pixel = int(x1 * target_width)
232
+ y1_pixel = int(y1 * target_height)
233
+ x2_pixel = int(x2 * target_width)
234
+ y2_pixel = int(y2 * target_height)
235
+
236
+
237
+ x1_pixel = max(0, min(x1_pixel, target_width-1))
238
+ y1_pixel = max(0, min(y1_pixel, target_height-1))
239
+ x2_pixel = max(0, min(x2_pixel, target_width-1))
240
+ y2_pixel = max(0, min(y2_pixel, target_height-1))
241
+
242
+ objects_masks_maps_padded[idx, y1_pixel:y2_pixel+1, x1_pixel:x2_pixel+1] = 1.0
243
+
244
+
245
+
246
+ # subject
247
+ original_size_subject_tensor = img_transforms(subject_image,height=target_height,width=target_width)
248
+ subject_tensor = img_transforms(subject_image,height=condition_height,width=condition_width)
249
+ subject_mask_tensor = mask_transforms(subject_mask, height=condition_height,width=condition_width)
250
+
251
+
252
+ if self.neg_condition_image=='black':
253
+ subject_image_black = Image.new('RGB', (orig_width, orig_height), (0, 0, 0))
254
+ subject_image_neg_tensor = img_transforms(subject_image_black,height=condition_height,width=condition_width)
255
+ elif self.neg_condition_image=='white':
256
+ subject_image_white = Image.new('RGB', (orig_width, orig_height), (255, 255, 255))
257
+ subject_image_neg_tensor = img_transforms(subject_image_white,height=condition_height,width=condition_width)
258
+ elif self.neg_condition_image=='gray':
259
+ subject_image_gray = Image.new('RGB', (orig_width, orig_height), (128, 128, 128))
260
+ subject_image_neg_tensor = img_transforms(subject_image_gray,height=condition_height,width=condition_width)
261
+ elif self.neg_condition_image=='same':
262
+ subject_image_neg_tensor = subject_tensor
263
+
264
+
265
+ output = dict(
266
+ id=img_id,
267
+ caption=global_caption,
268
+ objects_boxes=objects_boxes_padded,
269
+ objects_caption=objects_caption,
270
+ objects_masks=objects_masks_padded,
271
+ objects_masks_maps=objects_masks_maps_padded,
272
+ img=img_tensor,
273
+ condition_img_masks_maps = subject_mask_tensor,
274
+ condition_img = subject_tensor,
275
+ original_size_condition_img = original_size_subject_tensor,
276
+ neg_condtion_img = subject_image_neg_tensor,
277
+ img_info = img_info,
278
+ target_width=target_width,
279
+ target_height=target_height,
280
+ )
281
+
282
+ return output
283
+
284
+
285
+ def collate_fn(examples):
286
+
287
+ collated_examples = {}
288
+
289
+ for key in ['id', 'objects_caption', 'caption','img_info','target_width','target_height']:
290
+ collated_examples[key] = [example[key] for example in examples]
291
+
292
+ for key in ['img', 'objects_boxes', 'objects_masks','condition_img','neg_condtion_img','objects_masks_maps','condition_img_masks_maps','original_size_condition_img']:
293
+ collated_examples[key] = torch.stack([example[key] for example in examples]).float()
294
+
295
+ return collated_examples
296
+
297
+
298
+
299
+
300
+ from typing import Dict
301
+
302
+ import numpy as np
303
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
304
+ import random
305
+ def draw_mask(mask, draw, random_color=True):
306
+ """Draws a mask with a specified color on an image.
307
+
308
+ Args:
309
+ mask (np.array): Binary mask as a NumPy array.
310
+ draw (ImageDraw.Draw): ImageDraw object to draw on the image.
311
+ random_color (bool): Whether to use a random color for the mask.
312
+ """
313
+ if random_color:
314
+ color = (
315
+ random.randint(0, 255),
316
+ random.randint(0, 255),
317
+ random.randint(0, 255),
318
+ 153,
319
+ )
320
+ else:
321
+ color = (30, 144, 255, 153)
322
+
323
+ nonzero_coords = np.transpose(np.nonzero(mask))
324
+
325
+ for coord in nonzero_coords:
326
+ draw.point(coord[::-1], fill=color)
327
+
328
+ def visualize_bbox(image_pil: Image,
329
+ result: Dict,
330
+ draw_width: float = 6.0,
331
+ return_mask=True) -> Image:
332
+ """Plot bounding boxes and labels on an image with text wrapping for long descriptions.
333
+
334
+ Args:
335
+ image_pil (PIL.Image): The input image as a PIL Image object.
336
+ result (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing
337
+ the bounding boxes and labels. The keys are:
338
+ - boxes (List[int]): A list of bounding boxes in shape (N, 4), [x1, y1, x2, y2] format.
339
+ - labels (List[str]): A list of labels for each object
340
+ - masks (List[PIL.Image], optional): A list of masks in the format of PIL.Image
341
+
342
+ Returns:
343
+ PIL.Image: The input image with plotted bounding boxes, labels, and masks.
344
+ """
345
+ # Get the bounding boxes and labels from the target dictionary
346
+ boxes = result["boxes"]
347
+ categorys = result["labels"]
348
+ masks = result.get("masks", [])
349
+
350
+ color_list = [(255, 162, 76), (177, 214, 144),
351
+ (13, 146, 244), (249, 84, 84), (54, 186, 152),
352
+ (74, 36, 157), (0, 159, 189),
353
+ (80, 118, 135), (188, 90, 148), (119, 205, 255)]
354
+
355
+ # Use smaller font size to allow more text to be displayed
356
+ font_size = 30 # Reduce font size
357
+ font = ImageFont.truetype("dataloader/arial.ttf", font_size)
358
+
359
+ # Get image dimensions
360
+ img_width, img_height = image_pil.size
361
+
362
+ # Find all unique categories and build a cate2color dictionary
363
+ cate2color = {}
364
+ unique_categorys = sorted(set(categorys))
365
+ for idx, cate in enumerate(unique_categorys):
366
+ cate2color[cate] = color_list[idx % len(color_list)]
367
+
368
+ # Create a PIL ImageDraw object to draw on the input image
369
+ if isinstance(image_pil, np.ndarray):
370
+ image_pil = Image.fromarray(image_pil)
371
+ draw = ImageDraw.Draw(image_pil)
372
+
373
+ # Create a new binary mask image with the same size as the input image
374
+ mask = Image.new("L", image_pil.size, 0)
375
+ # Create a PIL ImageDraw object to draw on the mask image
376
+ mask_draw = ImageDraw.Draw(mask)
377
+
378
+ # Draw boxes, labels, and masks for each box and label in the target dictionary
379
+ for box, category in zip(boxes, categorys):
380
+ # Extract the box coordinates
381
+ x0, y0, x1, y1 = box
382
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
383
+ box_width = x1 - x0
384
+ box_height = y1 - y0
385
+ color = cate2color.get(category, color_list[0]) # Default color
386
+
387
+ # Draw the box outline on the input image
388
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=int(draw_width))
389
+
390
+ # Allow text box to be maximum 2 times the bounding box width, but not exceed image boundaries
391
+ max_text_width = min(box_width * 2, img_width - x0)
392
+
393
+ # Determine the maximum height for text background area
394
+ max_text_height = min(box_height * 2, 200) # Also allow more text display, but limit height
395
+
396
+ # Handle long text based on bounding box width, split text into lines
397
+ lines = []
398
+ words = category.split()
399
+ current_line = words[0]
400
+
401
+ for word in words[1:]:
402
+ # Try to add the next word
403
+ test_line = current_line + " " + word
404
+ # Use textbbox or textlength to check if width fits the maximum text width
405
+ if hasattr(draw, "textbbox"):
406
+ # Use textbbox method
407
+ bbox = draw.textbbox((0, 0), test_line, font=font)
408
+ w = bbox[2] - bbox[0]
409
+ elif hasattr(draw, "textlength"):
410
+ # Use textlength method
411
+ w = draw.textlength(test_line, font=font)
412
+ else:
413
+ # Fallback - estimate width
414
+ w = len(test_line) * (font_size * 0.6) # Estimate average character width
415
+
416
+ if w <= max_text_width - 20: # Leave some margin
417
+ current_line = test_line
418
+ else:
419
+ lines.append(current_line)
420
+ current_line = word
421
+
422
+ lines.append(current_line) # Add the last line
423
+
424
+ # Limit number of lines to prevent overflow
425
+ max_lines = max_text_height // (font_size + 2) # Line height (font size + spacing)
426
+ if len(lines) > max_lines:
427
+ lines = lines[:max_lines-1]
428
+ lines.append("...") # Add ellipsis
429
+
430
+ # Calculate actual required width for each line
431
+ line_widths = []
432
+ for line in lines:
433
+ if hasattr(draw, "textbbox"):
434
+ bbox = draw.textbbox((0, 0), line, font=font)
435
+ line_width = bbox[2] - bbox[0]
436
+ elif hasattr(draw, "textlength"):
437
+ line_width = draw.textlength(line, font=font)
438
+ else:
439
+ line_width = len(line) * (font_size * 0.6) # Estimate width
440
+ line_widths.append(line_width)
441
+
442
+ # Determine actual required width for text box
443
+ if line_widths:
444
+ needed_text_width = max(line_widths) + 10 # Add small margin
445
+ else:
446
+ needed_text_width = 0
447
+
448
+ # Use bounding box width as minimum, only expand when needed
449
+ text_bg_width = max(box_width, min(needed_text_width, max_text_width))
450
+
451
+ # Ensure it doesn't exceed image boundaries
452
+ text_bg_width = min(text_bg_width, img_width - x0)
453
+
454
+ # Calculate text background height
455
+ text_bg_height = len(lines) * (font_size + 2)
456
+
457
+ # Ensure text background doesn't exceed image bottom
458
+ if y0 + text_bg_height > img_height:
459
+ # If it would exceed bottom, adjust text position to above the bounding box bottom
460
+ text_y0 = max(0, y1 - text_bg_height)
461
+ else:
462
+ text_y0 = y0
463
+
464
+ # Draw text background - note RGBA color handling
465
+ if image_pil.mode == "RGBA":
466
+ # For RGBA mode, we can directly use alpha color
467
+ bg_color = (*color, 180) # Semi-transparent background
468
+ else:
469
+ # For RGB mode, we cannot use alpha
470
+ bg_color = color
471
+
472
+ draw.rectangle([x0, text_y0, x0 + text_bg_width, text_y0 + text_bg_height], fill=bg_color)
473
+
474
+ # Draw text
475
+ for i, line in enumerate(lines):
476
+ y_pos = text_y0 + i * (font_size + 2)
477
+ draw.text((x0 + 5, y_pos), line, fill="white", font=font)
478
+
479
+ # Draw the mask on the input image if masks are provided
480
+ if len(masks) > 0 and return_mask:
481
+ size = image_pil.size
482
+ mask_image = Image.new("RGBA", size, color=(0, 0, 0, 0))
483
+ mask_draw = ImageDraw.Draw(mask_image)
484
+ for mask in masks:
485
+ mask = np.array(mask)[:, :, -1]
486
+ draw_mask(mask, mask_draw)
487
+
488
+ image_pil = Image.alpha_composite(image_pil.convert("RGBA"), mask_image).convert("RGB")
489
+
490
+ return image_pil
491
+
492
+ import torchvision.transforms as T
493
+ from PIL import Image, ImageDraw, ImageFont, ImageChops
494
+
495
+ def tensor_to_pil(img_tensor):
496
+ """将tensor转换为PIL图像"""
497
+ img_tensor = img_tensor.cpu()
498
+ # De-normalize from ([0.5], [0.5]) back to the [0, 1] range
499
+ img_tensor = img_tensor * 0.5 + 0.5
500
+ img_tensor = torch.clamp(img_tensor, 0, 1)
501
+ return T.ToPILImage()(img_tensor)
502
+
503
+ def make_image_grid_RGB(images, rows, cols, resize=None):
504
+ """
505
+ Prepares a single grid of images. Useful for visualization purposes.
506
+ """
507
+ assert len(images) == rows * cols
508
+
509
+ if resize is not None:
510
+ images = [img.resize((resize, resize)) for img in images]
511
+
512
+ w, h = images[0].size
513
+ grid = Image.new("RGB", size=(cols * w, rows * h))
514
+
515
+ for i, img in enumerate(images):
516
+ grid.paste(img.convert("RGB"), box=(i % cols * w, i // cols * h))
517
+ return grid
518
+
519
+ if __name__ == "__main__":
520
+ resolution = 1024
521
+ condition_resolution = 512
522
+ neg_condition_image = 'same'
523
+ background_color = 'gray'
524
+ use_bucket = True
525
+ condition_resolution_scale_ratio=0.5
526
+
527
+ benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
528
+
529
+ datasets = DesignDataset(dataset_name=benchmark_repo,
530
+ resolution=resolution,
531
+ condition_resolution=condition_resolution,
532
+ neg_condition_image =neg_condition_image,
533
+ background_color=background_color,
534
+ use_bucket=use_bucket,
535
+ condition_resolution_scale_ratio=condition_resolution_scale_ratio
536
+ )
537
+ test_dataloader = DataLoader(datasets, batch_size=1, shuffle=False, num_workers=1,collate_fn=collate_fn)
538
+
539
+ for i, batch in enumerate(tqdm(test_dataloader)):
540
+ prompts = batch["caption"]
541
+ imgs_id = batch['id']
542
+ objects_boxes = batch["objects_boxes"]
543
+ objects_caption = batch['objects_caption']
544
+ objects_masks = batch['objects_masks']
545
+ condition_img = batch['condition_img']
546
+ neg_condtion_img = batch['neg_condtion_img']
547
+ objects_masks_maps= batch['objects_masks_maps']
548
+ subject_masks_maps = batch['condition_img_masks_maps']
549
+ target_width=batch['target_width'][0]
550
+ target_height=batch['target_height'][0]
551
+
552
+ img_info = batch["img_info"][0]
553
+ filename = img_info["img_id"]+'.jpg'
554
+
eval/layout.py ADDED
@@ -0,0 +1,194 @@
1
+ import os
2
+ import json
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+ from transformers import AutoModel, AutoTokenizer
6
+ import torch
7
+ from datasets import load_dataset
8
+ if __name__ == "__main__":
9
+ model_id ="openbmb/MiniCPM-V-2_6"
10
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True,
11
+ attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
12
+ model = model.eval().cuda()
13
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
14
+
15
+ # evaluation
16
+ benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
17
+ benchmark = load_dataset(benchmark_repo, split="test")
18
+ gen_root = "outputs/CreatiDesign_benchmark/images"
19
+ print("processing:",gen_root)
20
+ save_json_path = gen_root.replace("images", "minicpm-vqa.json")
21
+ temp_root = gen_root.replace("images", "images-perarea")
22
+ os.makedirs(temp_root, exist_ok=True)
23
+
24
+ skipped_files_log = gen_root.replace("images", "skipped_files.log")
25
+ skipped_files = []
26
+ image_stats = {}
27
+
28
+ for case in tqdm(benchmark):
29
+ json_data = json.loads(case["metadata"])
30
+ case_info = json_data["img_info"]
31
+ case_id = case_info["img_id"]
32
+ file_name = f"{case_id}.jpg"
33
+ generated_img_path = os.path.join(gen_root, file_name)
34
+ global_caption = json_data["global_caption"]
35
+ object_annotations = json_data["object_annotations"]
36
+ detial_region_caption_list = [item["bbox_detail_description"] for item in object_annotations]
37
+ region_caption_list = [item["class_name"] for item in object_annotations]
38
+ region_bboxes_list = [item["bbox"] for item in object_annotations]
39
+
40
+ img = Image.open(generated_img_path).convert("RGB")
41
+ width, height = img.size
42
+
43
+ orignal_img_width = json_data["img_info"]["img_width"]
44
+ orignal_img_height = json_data["img_info"]["img_height"]
45
+
46
+ temp_save_root = os.path.join(temp_root, file_name.split('.')[0])
47
+ os.makedirs(temp_save_root, exist_ok=True)
48
+
49
+ bbox_count = len(region_caption_list)
50
+
51
+ # Initialize scores
52
+ img_score_spatial = 0
53
+ img_score_color = 0
54
+ img_score_texture = 0
55
+ img_score_shape = 0
56
+ for i, (bbox,detial_region_caption,region_caption) in enumerate(zip(region_bboxes_list,detial_region_caption_list,region_caption_list)):
57
+ x1, y1, x2, y2= bbox
58
+ x1 = int(x1 / orignal_img_width*width)
59
+ y1 = int(y1 / orignal_img_height*height)
60
+ x2 = int(x2 / orignal_img_width*width)
61
+ y2 = int(y2 / orignal_img_height*height)
62
+
63
+
64
+ cropped_img = img.crop((x1, y1, x2, y2))
65
+
66
+ # save crop img
67
+ description = region_caption.replace('/', '')
68
+ detail_description = detial_region_caption.replace('/', '')
69
+ cropped_img_path = os.path.join(temp_save_root, f'{description}.jpg')
70
+ cropped_img.save(cropped_img_path)
71
+
72
+ # spatial
73
+ question = f'Is the subject "{description}" present in the image? Strictly answer with "Yes" or "No", without any irrelevant words.'
74
+
75
+ msgs = [{'role': 'user', 'content': [cropped_img, question]}]
76
+
77
+ res = model.chat(
78
+ image=None,
79
+ msgs=msgs,
80
+ tokenizer=tokenizer,
81
+ seed=42
82
+ )
83
+
84
+ if "Yes" in res or "yes" in res:
85
+ score_spatial = 1.0
86
+ else:
87
+ score_spatial = 0.0
88
+
89
+ score_color, score_texture,score_shape = 0.0, 0.0, 0.0
90
+ # attribute
91
+ if score_spatial==1.0:
92
+ #color
93
+ question_color = f'Is the subject in "{description}" in the image consistent with the color described in the detailed description: "{detail_description}"? Strictly answer with "Yes" or "No", without any irrelevant words. If the color is not mentioned in the detailed description, the answer is "Yes".'
94
+ msgs_color = [{'role': 'user', 'content': [cropped_img, question_color]}]
95
+
96
+ color_attribute = model.chat(
97
+ image=None,
98
+ msgs=msgs_color,
99
+ tokenizer=tokenizer,
100
+ seed=42
101
+ )
102
+
103
+ if "Yes" in color_attribute or "yes" in color_attribute:
104
+ score_color = 1.0
105
+ # texture
106
+ if score_spatial==1.0:
107
+ question_texture = f'Is the subject in "{description}" in the image consistent with the texture described in the detailed description: "{detail_description}"? Strictly answer with "Yes" or "No", without any irrelevant words. If the texture is not mentioned in the detailed description, the answer is "Yes".'
108
+ msgs_texture = [{'role': 'user', 'content': [cropped_img, question_texture]}]
109
+
110
+ texture_attribute = model.chat(
111
+ image=None,
112
+ msgs=msgs_texture,
113
+ tokenizer=tokenizer,
114
+ seed=42
115
+ )
116
+ if "Yes" in texture_attribute or "yes" in texture_attribute:
117
+ score_texture = 1.0
118
+ #shape
119
+ if score_spatial==1.0:
120
+ question_shape = f'Is the subject in "{description}" in the image consistent with the shape described in the detailed description: "{detail_description}"? Strictly answer with "Yes" or "No", without any irrelevant words. If the shape is not mentioned in the detailed description, the answer is "Yes".'
121
+ msgs_shape = [{'role': 'user', 'content': [cropped_img, question_shape]}]
122
+
123
+ shape_attribute = model.chat(
124
+ image=None,
125
+ msgs=msgs_shape,
126
+ tokenizer=tokenizer,
127
+ seed=42
128
+ )
129
+
130
+ if "Yes" in shape_attribute or "yes" in shape_attribute:
131
+ score_shape = 1.0
132
+
133
+ # Update total scores
134
+ img_score_spatial += score_spatial
135
+ img_score_color += score_color
136
+ img_score_texture += score_texture
137
+ img_score_shape += score_shape
138
+
139
+
140
+ # Store image stats
141
+ image_stats[os.path.basename(file_name)] = {
142
+ "bbox_count": bbox_count,
143
+ "score_spatial": img_score_spatial,
144
+ "score_color": img_score_color,
145
+ "score_texture": img_score_texture,
146
+ "score_shape": img_score_shape,
147
+ }
148
+
149
+ if len(image_stats) % 50 == 0:
150
+ with open(save_json_path, 'w', encoding='utf-8') as json_file:
151
+ json.dump(image_stats, json_file, indent=4)
152
+
153
+ # Save the image_stats dictionary to a JSON file
154
+ with open(save_json_path, 'w', encoding='utf-8') as json_file:
155
+ json.dump(image_stats, json_file, indent=4)
156
+
157
+ print(f"Image statistics saved to {save_json_path}")
158
+
159
+
160
+ score_save_path = save_json_path.replace('minicpm-vqa.json', 'minicpm-vqa-score.txt')
161
+
162
+ # Read the JSON file containing image statistics
163
+ with open(save_json_path, "r") as f:
164
+ json_data = json.load(f)
165
+
166
+ total_num = 0
167
+ total_bbox_num = 0
168
+ total_score_spatial = 0
169
+ total_score_color = 0
170
+ total_score_texture = 0
171
+ total_score_shape = 0
172
+
173
+ miss_match =0
174
+ # Iterate over the JSON data
175
+ for key, value in json_data.items():
176
+
177
+ total_num += value["bbox_count"]
178
+ total_score_spatial +=value["score_spatial"]
179
+ total_score_color +=value["score_color"]
180
+ total_score_texture +=value["score_texture"]
181
+ total_score_shape +=value["score_shape"]
182
+
183
+ if value["bbox_count"]!=value["score_spatial"] or value["bbox_count"]!=value["score_color"] or value["bbox_count"]!=value["score_texture"] or value["bbox_count"]!=value["score_shape"]:
184
+ print(key,value["bbox_count"],value["score_spatial"],value["score_color"],value["score_texture"],value["score_shape"])
185
+ miss_match+=1
186
+
187
+ print(miss_match)
188
+ #save total_score_spatial,total_score_color,total_score_texture,total_score_shape
189
+ with open(score_save_path, "w") as f:
190
+ f.write(f"Total number of bbox: {total_num}\n")
191
+ f.write(f"Total score of spatial: {total_score_spatial}; Average score of spatial: {round(total_score_spatial/total_num,4)}\n")
192
+ f.write(f"Total score of color: {total_score_color}; Average score of color: {round(total_score_color/total_num,4)}\n")
193
+ f.write(f"Total score of texture: {total_score_texture}; Average score of texture: {round(total_score_texture/total_num,4)}\n")
194
+ f.write(f"Total score of shape: {total_score_shape}; Average score of shape: {round(total_score_shape/total_num,4)}\n")
eval/subject.py ADDED
@@ -0,0 +1,233 @@
1
+ import os, sys, json, math, argparse, glob
2
+ from pathlib import Path
3
+ from typing import List
4
+ import torch
5
+ from PIL import Image
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+ from transformers import (
9
+ AutoProcessor, CLIPModel,
10
+ AutoImageProcessor, AutoModel
11
+ )
12
+ from datasets import load_dataset
13
+
14
+ def scale_bbox(bbox, ori_size, target_size):
15
+ x_min, y_min, x_max, y_max = bbox
16
+ ori_width, ori_height = ori_size
17
+ target_width, target_height = target_size
18
+
19
+ width_ratio = target_width / ori_width
20
+ height_ratio = target_height / ori_height
21
+
22
+ scaled_x_min = int(x_min * width_ratio)
23
+ scaled_y_min = int(y_min * height_ratio)
24
+ scaled_x_max = int(x_max * width_ratio)
25
+ scaled_y_max = int(y_max * height_ratio)
26
+
27
+ scaled_x_min = max(0, scaled_x_min)
28
+ scaled_y_min = max(0, scaled_y_min)
29
+ scaled_x_max = min(target_width, scaled_x_max)
30
+ scaled_y_max = min(target_height, scaled_y_max)
31
+
32
+ return [scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max]
33
+
34
+ @torch.no_grad()
35
+ def encode_clip(imgs: List[Image.Image]) -> torch.Tensor:
36
+ features_list = []
37
+ for img in imgs:
38
+ inputs = clip_processor(images=img, return_tensors="pt").to(device)
39
+ image_features = clip_model.get_image_features(**inputs)
40
+
41
+ normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
42
+ features_list.append(normalized_features.squeeze().cpu())
43
+ return torch.stack(features_list)
44
+
45
+ @torch.no_grad()
46
+ def encode_dino(imgs: List[Image.Image]) -> torch.Tensor:
47
+ features_list = []
48
+ for img in imgs:
49
+ inputs = dino_processor(images=img, return_tensors="pt").to(device)
50
+ outputs = dino_model(**inputs)
51
+ image_features = outputs.last_hidden_state.mean(dim=1)
52
+ normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
53
+ features_list.append(normalized_features.squeeze().cpu())
54
+ return torch.stack(features_list)
55
+
56
+ @torch.no_grad()
57
+ def cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
58
+ return (a @ b.T).squeeze()
59
+
60
+ # ------------- Command line arguments -----------------
61
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
62
+ parser.add_argument("--benchmark_repo", type=str, default="HuiZhang0812/CreatiDesign_benchmark",
63
+ help="Root directory for one thousand cases")
64
+ parser.add_argument("--gen_root", type=str, default="outputs/CreatiDesign_benchmark",
65
+ help="Root directory for generated images (should have images/<case_id>.jpg underneath)")
66
+ parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
67
+ parser.add_argument("--outfile", type=str,
68
+ help="Path for result CSV; by default written to gen_root")
69
+ args = parser.parse_args()
70
+
71
+ print("handling:", args.gen_root)
72
+ if args.outfile is None:
73
+ args.outfile = os.path.join(args.gen_root,"scores.csv")
74
+
75
+ # Convert outfile to Path object
76
+ outfile_path = Path(args.outfile)
77
+
78
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
79
+ print(f"[INFO] Using device: {device}")
80
+
81
+ # ------------- Loading models -------------------
82
+ print("[INFO] loading CLIP...")
83
+ clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
84
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
85
+ clip_model.eval()
86
+
87
+ print("[INFO] loading DINOv2...")
88
+ dino_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
89
+ dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
90
+ dino_model.eval()
91
+
92
+ benchmark = load_dataset(args.benchmark_repo, split="test")
93
+
94
+ DEBUG = True
95
+ if DEBUG:
96
+ subject_save_roor = os.path.join(args.gen_root,"subject-eval-visual")
97
+ os.makedirs(subject_save_roor,exist_ok=True)
98
+ records = []
99
+ for case in tqdm(benchmark):
100
+ json_data = json.loads(case["metadata"])
101
+ case_info = json_data["img_info"]
102
+ case_id = case_info["img_id"]
103
+
104
+ # ---------- Read reference subjects ----------
105
+ ref_imgs = case['condition_white_variants']
106
+ if len(ref_imgs) == 0:
107
+ print(f"[WARN] {case_id} has no reference subject, skipping")
108
+ continue
109
+
110
+ # ---------- Read generated image ----------
111
+ gen_path = os.path.join(args.gen_root, "images", f"{case_id}.jpg")
112
+ gen_img = Image.open(gen_path).convert("RGB")
113
+ # Get width and height of generated image
114
+ gen_width, gen_height = gen_img.size
115
+ reg_bbox_id = [item["bbox_idx"] for item in sorted(json_data["subject_annotations"], key=lambda x: x["bbox_idx"])]
116
+ ref_bbox = [item["bbox"] for item in sorted(json_data["subject_annotations"], key=lambda x: x["bbox_idx"])]
117
+ ori_width,ori_height = json_data["img_info"]["img_width"],json_data["img_info"]["img_height"]
118
+ # Extract corresponding images from the generated image
119
+ gen_imgs = []
120
+ for bbox in ref_bbox:
121
+ # Scale the bounding box
122
+ scaled_bbox = scale_bbox(
123
+ bbox,
124
+ (ori_width, ori_height),
125
+ (gen_width, gen_height)
126
+ )
127
+
128
+ # Crop the image area
129
+ x_min, y_min, x_max, y_max = scaled_bbox
130
+ cropped_img = gen_img.crop((x_min, y_min, x_max, y_max))
131
+ gen_imgs.append(cropped_img)
132
+ if DEBUG:
133
+ folder_root = os.path.join(subject_save_roor,case_id)
134
+ os.makedirs(folder_root,exist_ok=True)
135
+ # Save cropped images
136
+ for i, (img, img_id) in enumerate(zip(gen_imgs, reg_bbox_id)):
137
+ img.save(os.path.join(folder_root, f"{img_id}.png"))
138
+
139
+
140
+ # ---------- Features ----------
141
+ ref_clip = encode_clip(ref_imgs) # (n,dim)
142
+ gen_clip = encode_clip(gen_imgs) # (n,dim)
143
+
144
+ ref_dino = encode_dino(ref_imgs) # (n,dim)
145
+ gen_dino = encode_dino(gen_imgs) # (n,dim)
146
+
147
+ # ---------- Similarity ----------
148
+ clip_sims = torch.nn.functional.cosine_similarity(ref_clip, gen_clip)
149
+ dino_sims = torch.nn.functional.cosine_similarity(ref_dino, gen_dino)
150
+
151
+ clip_i = clip_sims.mean().item()
152
+ dino_avg = dino_sims.mean().item()
153
+ m_dino = dino_sims.prod().item()
154
+
155
+ records.append(dict(
156
+ case_id=case_id,
157
+ num_subject=len(ref_imgs),
158
+ clip_i=clip_i,
159
+ dino=dino_avg,
160
+ m_dino=m_dino
161
+ ))
162
+
163
+ # ---------------- Result statistics -----------------
164
+ df = pd.DataFrame(records).sort_values("case_id")
165
+ overall = df[["clip_i","dino","m_dino"]].mean().to_dict()
166
+
167
+ print("\n========== Overall Average ==========")
168
+ for k,v in overall.items():
169
+ print(f"{k:>8}: {v:.6f}")
170
+ print("=====================================\n")
171
+
172
+ # Group by number of subjects
173
+ df_by_subjects = {}
174
+ avg_by_subjects = {}
175
+
176
+ # Create subset for each subject count (1-5)
177
+ for i in range(1, 6):
178
+ # Filter records with subject count = i
179
+ subset = df[df["num_subject"] == i]
180
+
181
+ if len(subset) > 0:
182
+ # Calculate average for this group
183
+ subset_avg = subset[["clip_i", "dino", "m_dino"]].mean().to_dict()
184
+ avg_by_subjects[i] = subset_avg
185
+
186
+ # Create subset with average row
187
+ avg_row = {"case_id": f"average_subject_{i}", "num_subject": i}
188
+ avg_row.update(subset_avg)
189
+
190
+ # Add average row to subset
191
+ subset_with_avg = pd.concat([subset, pd.DataFrame([avg_row])], ignore_index=True)
192
+ df_by_subjects[i] = subset_with_avg
193
+
194
+ # Print average for this group
195
+ print(f"\n=== Subject {i} Average (n={len(subset)}) ===")
196
+ for k, v in subset_avg.items():
197
+ print(f"{k:>8}: {v:.6f}")
198
+
199
+ # Save subset - fixed path handling
200
+ subject_path = outfile_path.parent / f"{outfile_path.stem}_subject{i}_location_prior{outfile_path.suffix}"
201
+ subset_with_avg.to_csv(subject_path, index=False, float_format="%.6f")
202
+ print(f"[INFO] Subject {i} results written to {subject_path}")
203
+
204
+ # Save overall average to CSV - fixed path handling
205
+ overall_df = pd.DataFrame([overall], index=["overall"])
206
+ overall_path = outfile_path.parent / f"{outfile_path.stem}_overall_location_prior{outfile_path.suffix}"
207
+ overall_df.to_csv(overall_path, float_format="%.6f")
208
+ print(f"[INFO] Overall results written to {overall_path}")
209
+
210
+ # Write CSV
211
+ df.to_csv(args.outfile, index=False, float_format="%.6f")
212
+ print(f"[INFO] Written to {args.outfile}")
213
+
214
+ # Create statistics table with averages for all groups
215
+ if avg_by_subjects:
216
+ # Merge averages for each group into one table
217
+ stats_rows = []
218
+ for num_subject, avg_dict in avg_by_subjects.items():
219
+ row = {"num_subject": num_subject}
220
+ row.update(avg_dict)
221
+ stats_rows.append(row)
222
+
223
+ # Add overall average
224
+ overall_row = {"num_subject": "all"}
225
+ overall_row.update(overall)
226
+ stats_rows.append(overall_row)
227
+
228
+ # Create summary statistics table
229
+ stats_df = pd.DataFrame(stats_rows)
230
+ # Fixed path handling
231
+ stats_path = outfile_path.parent / f"{outfile_path.stem}_stats_location_prior{outfile_path.suffix}"
232
+ stats_df.to_csv(stats_path, index=False, float_format="%.6f")
233
+ print(f"[INFO] All group statistics written to {stats_path}")
eval/text.py ADDED
@@ -0,0 +1,184 @@
1
+ import os, json, csv, re, cv2, numpy as np, torch
2
+ from tqdm import tqdm
3
+ from editdistance import eval as edit_distance
4
+ from paddleocr import PaddleOCR
5
+ from datasets import load_dataset
6
+ # -------------------------------------------------------------------
7
+ # Paths
8
+ benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
9
+ benchmark = load_dataset(benchmark_repo, split="test")
10
+ root_gen = "outputs/CreatiDesign_benchmark/images"
11
+
12
+ save_root = root_gen.replace("images", "text_eval") # Output directory
13
+ os.makedirs(save_root, exist_ok=True)
14
+ DEBUG = True
15
+ # -------------------------------------------------------------------
16
+ # 1. OCR initialization (must be det=True)
17
+ ocr = PaddleOCR(det=True, rec=True, cls=False, use_angle_cls=False, lang='en')
18
+
19
+ # -------------------------------------------------------------------
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+
22
+ # -------------------------------------------------------------------
23
+ # 3. Utility functions
24
+
25
+ def spatial_match_iou(det_res, gt_box, gt_text_fmt, iou_thr=0.5):
26
+ best_iou = 0.0
27
+ if det_res is None or len(det_res) == 0:
28
+ return best_iou
29
+
30
+ for item in det_res:
31
+ poly = item[0] # Detection box coordinates
32
+ txt_info = item[1] # Text information tuple
33
+ txt = txt_info[0] # Text content
34
+
35
+ if min_ned_substring(normalize_text(txt), gt_text_fmt) <= 0.7: # When calculating spatial, allow some degree of text error
36
+ iou_val = iou(quad2bbox(poly), gt_box)
37
+ best_iou = max(best_iou, iou_val)
38
+ return best_iou
39
+
40
+ # ① New tool: Minimum NED substring
41
+ def min_ned_substring(pred_fmt: str, tgt_fmt: str) -> float:
42
+ """
43
+ Find a substring in pred_fmt with the same length as tgt_fmt, to minimize normalized edit distance
44
+ Return the minimum value (0 ~ 1)
45
+ """
46
+ Lp, Lg = len(pred_fmt), len(tgt_fmt)
47
+ if Lg == 0:
48
+ return 0.0
49
+ if Lp < Lg: # If prediction string is shorter than target, calculate directly
50
+ return normalized_edit_distance(pred_fmt, tgt_fmt)
51
+
52
+ best = Lg # Maximum possible distance
53
+ for i in range(Lp - Lg + 1):
54
+ sub = pred_fmt[i:i+Lg]
55
+ d = edit_distance(sub, tgt_fmt)
56
+ if d < best:
57
+ best = d
58
+ if best == 0: # Early exit
59
+ break
60
+ return best / Lg # Normalize
61
+
62
+ def normalize_text(txt: str) -> str:
63
+ txt = txt.lower().replace(" ", "")
64
+ return re.sub(r"[^\w\s]", "", txt)
65
+
66
+ def normalized_edit_distance(pred: str, gt: str) -> float:
67
+ if not gt and not pred:
68
+ return 0.0
69
+ return edit_distance(pred, gt) / max(len(gt), len(pred))
70
+
71
+ def iou(boxA, boxB) -> float:
72
+ xA, yA = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
73
+ xB, yB = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
74
+ inter = max(0, xB - xA) * max(0, yB - yA)
75
+ if inter == 0:
76
+ return 0.0
77
+ areaA = (boxA[2]-boxA[0]) * (boxA[3]-boxA[1])
78
+ areaB = (boxB[2]-boxB[0]) * (boxB[3]-boxB[1])
79
+ return inter / (areaA + areaB - inter)
80
+
81
+ def quad2bbox(quad):
82
+ xs = [p[0] for p in quad]; ys = [p[1] for p in quad]
83
+ return [min(xs), min(ys), max(xs), max(ys)]
84
+
85
+ def crop(img, box):
86
+ h, w = img.shape[:2]
87
+ x1,y1,x2,y2 = map(int, box)
88
+ x1, y1 = max(0, x1), max(0, y1)
89
+ x2, y2 = min(w-1, x2), min(h-1, y2)
90
+ if x2 <= x1 or y2 <= y1:
91
+ return np.zeros((1,1,3), np.uint8)
92
+ return img[y1:y2, x1:x2]
93
+
94
+
95
+ # -------------------------------------------------------------------
96
+ # 4. Main loop
97
+ per_img_rows, all_sen_acc, all_ned, all_spatial, text_pairs = [], [], [], [], []
98
+
99
+ for case in tqdm(benchmark):
100
+ json_data = json.loads(case["metadata"])
101
+ case_info = json_data["img_info"]
102
+ case_id = case_info["img_id"]
103
+
104
+ gt_list = json_data["text_list"] # [{'text':..., 'bbox':[x1,y1,x2,y2]}, ...]
105
+ ori_w, ori_h = json_data["img_info"]["img_width"], json_data["img_info"]["img_height"]
106
+
107
+ img_path = os.path.join(root_gen, f"{case_id}.jpg")
108
+
109
+ img = cv2.imread(img_path)
110
+ H, W = img.shape[:2]
111
+ wr, hr = W / ori_w, H / ori_h # GT → Generated image scaling ratio
112
+
113
+ # ---------- 1) Full image OCR ----------
114
+ pred_lines = [] # Save OCR line text
115
+ ocr_res = ocr.ocr(img, cls=False)
116
+ if ocr_res and ocr_res[0]:
117
+ for quad, (txt, conf) in ocr_res[0]:
118
+ pred_lines.append(txt.strip())
119
+
120
+ # Concatenate into full text and normalize
121
+ pred_full_fmt = normalize_text(" ".join(pred_lines))
122
+
123
+ # ==========================================================
124
+ # ③ For each GT sentence, do "substring minimum NED" ---- no longer using IoU
125
+ img_sen_hits, img_neds, img_spatials = [], [], []
126
+
127
+ for t_idx, gt in enumerate(gt_list):
128
+ gt_text_orig = gt["text"].replace("\n", " ").strip()
129
+ gt_text_fmt = normalize_text(gt_text_orig)
130
+
131
+ # ---- Pure text matching ----
132
+ ned = min_ned_substring(pred_full_fmt, gt_text_fmt)
133
+ acc = 1.0 if ned == 0 else 0.0
134
+ img_sen_hits.append(acc)
135
+ img_neds.append(ned)
136
+
137
+ # ---------- Spatial consistency, using IOU ----------
138
+ gt_box = [v*wr if i%2==0 else v*hr for i,v in enumerate(gt["bbox"])]
139
+ det_res = ocr_res[0] if ocr_res else []
140
+ spatial_score = spatial_match_iou(det_res, gt_box, gt_text_fmt)
141
+ img_spatials.append(spatial_score) # Can be used directly or binarized
142
+ crop_box_int = list(map(int, gt_box))
143
+ img_crop = crop(img, crop_box_int)
144
+ if DEBUG:
145
+ # Save cropped image
146
+ img_crop_for_ocr_save_root = os.path.join(save_root, case_id)
147
+ os.makedirs(img_crop_for_ocr_save_root, exist_ok=True)
148
+ safe_text = gt_text_orig.replace('/', '_').replace('\\', '_')
149
+ safe_filename = f"{t_idx}_{safe_text}.jpg"
150
+ cv2.imwrite(os.path.join(img_crop_for_ocr_save_root, safe_filename), img_crop)
151
+
152
+ # --------- Record text pairs ----------
153
+ text_pairs.append({
154
+ "image_id" : case_id,
155
+ "text_id" : t_idx,
156
+ "gt_original" : gt_text_orig,
157
+ "gt_formatted" : gt_text_fmt
158
+ })
159
+
160
+ # ---------- 3) Summarize to image level ----------
161
+ sen_acc = float(np.mean(img_sen_hits))
162
+ ned = float(np.mean(img_neds))
163
+ spatial = float(np.mean(img_spatials))
164
+
165
+ per_img_rows.append([case_id, sen_acc, ned, spatial])
166
+ all_sen_acc.append(sen_acc)
167
+ all_ned.append(ned)
168
+ all_spatial.append(spatial)
169
+
170
+ # -------------------------------------------------------------------
171
+ # 5. Write results
172
+ result_root = root_gen.replace("images","")
173
+ csv_perimg = os.path.join(result_root, "text_results_per_image.csv")
174
+ with open(csv_perimg, "w", newline='', encoding="utf-8") as f:
175
+ w = csv.writer(f); w.writerow(["image_id","sen_acc","ned","score_spatial"]); w.writerows(per_img_rows)
176
+
177
+
178
+ with open(os.path.join(result_root, "text_overall.txt"), "w", encoding="utf-8") as f:
179
+ f.write(f"Images evaluated : {len(per_img_rows)}\n")
180
+ f.write(f"Global Sen ACC : {np.mean(all_sen_acc):.4f}\n")
181
+ f.write(f"Global NED : {np.mean(all_ned):.4f}\n")
182
+ f.write(f"Global Spatial : {np.mean(all_spatial):.4f}\n")
183
+
184
+ print("✓ Done! Results saved to", result_root)
modules/common/__pycache__/lora.cpython-310.pyc ADDED
Binary file (1.17 kB).
 
modules/common/lora.py ADDED
@@ -0,0 +1,26 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class LoRALinearLayer(nn.Module):
5
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
6
+ super().__init__()
7
+
8
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
9
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
10
+ self.network_alpha = network_alpha
11
+ self.rank = rank
12
+
13
+ nn.init.normal_(self.down.weight, std=1 / rank)
14
+ nn.init.zeros_(self.up.weight)
15
+
16
+ def forward(self, hidden_states):
17
+ orig_dtype = hidden_states.dtype
18
+ dtype = self.down.weight.dtype
19
+
20
+ down_hidden_states = self.down(hidden_states.to(dtype))
21
+ up_hidden_states = self.up(down_hidden_states)
22
+
23
+ if self.network_alpha is not None:
24
+ up_hidden_states *= self.network_alpha / self.rank
25
+
26
+ return up_hidden_states.to(orig_dtype)
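A minimal usage sketch of the `LoRALinearLayer` added above (shapes are illustrative). Because `self.up` is zero-initialized, the layer outputs exactly zero before training, so attaching it to a frozen base model leaves the initial behavior unchanged:

```python
import torch
from modules.common.lora import LoRALinearLayer

# Rank-16 LoRA delta for a 3072-dim hidden state (FLUX inner_dim), batch of 2, 77 tokens.
lora = LoRALinearLayer(in_features=3072, out_features=3072, rank=16, network_alpha=16)
x = torch.randn(2, 77, 3072)

delta = lora(x)
print(delta.shape)        # torch.Size([2, 77, 3072])
print(delta.abs().max())  # tensor(0.) -- zero contribution at initialization
```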
modules/flux/__pycache__/attention_processor_flux_creatidesign.cpython-310.pyc ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e534db89ad40a8e61c4c32b8bbeb3084e7d01a83667a66f426dbdfdf93a13936
3
+ size 127465
modules/flux/__pycache__/transformer_flux_creatidesign.cpython-310.pyc ADDED
Binary file (25.8 kB).
 
modules/flux/attention_processor_flux_creatidesign.py ADDED
The diff for this file is too large to render.
 
modules/flux/transformer_flux_creatidesign.py ADDED
@@ -0,0 +1,1004 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention import FeedForward
26
+ from modules.flux.attention_processor_flux_creatidesign import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ DesignFluxAttnProcessor2_0,
30
+ FluxAttnProcessor2_0_NPU,
31
+ FusedFluxAttnProcessor2_0,
32
+ )
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
35
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
36
+ from diffusers.utils.import_utils import is_torch_npu_available
37
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
38
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
39
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
40
+ from modules.semantic_layout.layout_encoder import ObjectLayoutEncoder,ObjectLayoutEncoder_noFourier
41
+ from modules.common.lora import LoRALinearLayer
42
+
43
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
+
45
+
46
+
47
+
48
+
49
+ @maybe_allow_in_graph
50
+ class FluxSingleTransformerBlock(nn.Module):
51
+ r"""
52
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
53
+
54
+ Reference: https://arxiv.org/abs/2403.03206
55
+
56
+ Parameters:
57
+ dim (`int`): The number of channels in the input and output.
58
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
59
+ attention_head_dim (`int`): The number of channels in each head.
60
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
61
+ processing of `context` conditions.
62
+ """
63
+
64
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0, rank=16,network_alpha=16,lora_weight=1.0,attention_type="design"):
65
+ super().__init__()
66
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
67
+
68
+ self.norm = AdaLayerNormZeroSingle(dim)
69
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
70
+ self.act_mlp = nn.GELU(approximate="tanh")
71
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
72
+
73
+
74
+ if is_torch_npu_available():
75
+ processor = FluxAttnProcessor2_0_NPU()
76
+ else:
77
+ processor = DesignFluxAttnProcessor2_0()
78
+ self.attn = Attention(
79
+ query_dim=dim,
80
+ cross_attention_dim=None,
81
+ dim_head=attention_head_dim,
82
+ heads=num_attention_heads,
83
+ out_dim=dim,
84
+ bias=True,
85
+ processor=processor,
86
+ qk_norm="rms_norm",
87
+ eps=1e-6,
88
+ pre_only=True,
89
+ )
90
+
91
+ self.attention_type = attention_type
92
+ self.rank = rank
93
+ self.network_alpha = network_alpha
94
+ self.lora_weight = lora_weight
95
+ if attention_type == "design":
96
+ self.layernorm_subject = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # layernorm for subject
97
+ self.norm_subject_lora = nn.Sequential(
98
+ nn.SiLU(),
99
+ LoRALinearLayer(dim, dim*3, self.rank, self.network_alpha) # lora for adalinear of subject
100
+ )
101
+ self.layernorm_object_bbox = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # layernorm for object
102
+ self.norm_object_lora = nn.Sequential(
103
+ nn.SiLU(),
104
+ LoRALinearLayer(dim, dim*3, self.rank, self.network_alpha) # lora for adalinear of object
105
+ )
106
+ def single_block_adaln_lora_forward(self, x, temb, adaln, adaln_lora, layernorm, lora_weight):
107
+ norm_x, x_gate = adaln(x, emb=temb)
108
+ lora_shift_msa, lora_scale_msa, lora_gate_msa = adaln_lora(temb).chunk(3, dim=1)
109
+ norm_x = norm_x + lora_weight * (layernorm(x)* (1 + lora_scale_msa[:, None]) + lora_shift_msa[:, None])
110
+ x_gate = x_gate + lora_weight * lora_gate_msa
111
+ return norm_x, x_gate
112
+
113
+ def forward(
114
+ self,
115
+ hidden_states: torch.Tensor,
116
+ temb: torch.Tensor,
117
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
118
+ subject_hidden_states = None,
119
+ subject_rotary_emb = None,
120
+ object_bbox_hidden_states = None,
121
+ object_rotary_emb = None,
122
+ design_scale = 1.0,
123
+ attention_mask=None,
124
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
125
+ ) -> torch.Tensor:
126
+ residual = hidden_states
127
+
128
+ # handle hidden_states
129
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
130
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
131
+ #creatidesign
132
+ use_subject = True if self.attention_type == "design" and subject_hidden_states is not None and design_scale!=0.0 else False
133
+ use_object = True if self.attention_type == "design" and object_bbox_hidden_states is not None and design_scale!=0.0 else False
134
+ # handle subejct_hidden_states
135
+ if use_subject:
136
+ residual_subject_hidden_states = subject_hidden_states
137
+ norm_subject_hidden_states, subject_gate = self.single_block_adaln_lora_forward(subject_hidden_states, temb, self.norm, self.norm_subject_lora, self.layernorm_subject, self.lora_weight)
138
+ mlp_subject_hidden_states = self.act_mlp(self.proj_mlp(norm_subject_hidden_states))
139
+ if use_object:
140
+ residual_object_bbox_hidden_states = object_bbox_hidden_states
141
+ norm_object_bbox_hidden_states, object_gate = self.single_block_adaln_lora_forward(object_bbox_hidden_states, temb, self.norm, self.norm_object_lora, self.layernorm_object_bbox, self.lora_weight)
142
+ mlp_object_bbox_hidden_states = self.act_mlp(self.proj_mlp(norm_object_bbox_hidden_states))
143
+ joint_attention_kwargs = joint_attention_kwargs or {}
144
+ attn_output, subject_attn_output, object_attn_output = self.attn(
145
+ hidden_states=norm_hidden_states,
146
+ image_rotary_emb=image_rotary_emb,
147
+ subject_hidden_states=norm_subject_hidden_states,
148
+ subject_rotary_emb=subject_rotary_emb,
149
+ object_bbox_hidden_states=norm_object_bbox_hidden_states,
150
+ object_rotary_emb=object_rotary_emb,
151
+ attention_mask = attention_mask,
152
+ **joint_attention_kwargs,
153
+ )
154
+ # handle hidden states
155
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
156
+ gate = gate.unsqueeze(1)
157
+ hidden_states = gate * self.proj_out(hidden_states)
158
+ hidden_states = residual + hidden_states
159
+ #handle subject_hidden_states
160
+ if use_subject:
161
+ subject_hidden_states = torch.cat([subject_attn_output, mlp_subject_hidden_states], dim=2)
162
+ subject_gate = subject_gate.unsqueeze(1)
163
+ subject_hidden_states = subject_gate * self.proj_out(subject_hidden_states)
164
+ subject_hidden_states = residual_subject_hidden_states + subject_hidden_states
165
+
166
+ #handle object_bbox_hidden_states
167
+ if use_object:
168
+ object_bbox_hidden_states = torch.cat([object_attn_output, mlp_object_bbox_hidden_states], dim=2)
169
+ object_gate = object_gate.unsqueeze(1)
170
+ object_bbox_hidden_states = object_gate * self.proj_out(object_bbox_hidden_states)
171
+ object_bbox_hidden_states = residual_object_bbox_hidden_states + object_bbox_hidden_states
172
+ if hidden_states.dtype == torch.float16:
173
+ hidden_states = hidden_states.clip(-65504, 65504)
174
+
175
+ return hidden_states, subject_hidden_states, object_bbox_hidden_states
176
+
177
+
178
+ @maybe_allow_in_graph
179
+ class FluxTransformerBlock(nn.Module):
180
+ r"""
181
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
182
+
183
+ Reference: https://arxiv.org/abs/2403.03206
184
+
185
+ Args:
186
+ dim (`int`):
187
+ The embedding dimension of the block.
188
+ num_attention_heads (`int`):
189
+ The number of attention heads to use.
190
+ attention_head_dim (`int`):
191
+ The number of dimensions to use for each attention head.
192
+ qk_norm (`str`, defaults to `"rms_norm"`):
193
+ The normalization to use for the query and key tensors.
194
+ eps (`float`, defaults to `1e-6`):
195
+ The epsilon value to use for the normalization.
196
+ """
197
+
198
+ def __init__(
199
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6, rank=16, network_alpha=16, lora_weight=1.0,attention_type="design"
200
+ ):
201
+ super().__init__()
202
+
203
+ self.norm1 = AdaLayerNormZero(dim)
204
+
205
+ self.norm1_context = AdaLayerNormZero(dim)
206
+
207
+ if hasattr(F, "scaled_dot_product_attention"):
208
+ processor = DesignFluxAttnProcessor2_0()
209
+ else:
210
+ raise ValueError(
211
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
212
+ )
213
+ self.attn = Attention(
214
+ query_dim=dim,
215
+ cross_attention_dim=None,
216
+ added_kv_proj_dim=dim,
217
+ dim_head=attention_head_dim,
218
+ heads=num_attention_heads,
219
+ out_dim=dim,
220
+ context_pre_only=False,
221
+ bias=True,
222
+ processor=processor,
223
+ qk_norm=qk_norm,
224
+ eps=eps,
225
+ )
226
+
227
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
228
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
229
+
230
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
231
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
232
+
233
+ # let chunk size default to None
234
+ self._chunk_size = None
235
+ self._chunk_dim = 0
236
+
237
+ # creatidesign
238
+ self.attention_type = attention_type
239
+ self.rank = rank
240
+ self.network_alpha = network_alpha
241
+ self.lora_weight = lora_weight
242
+
243
+ if self.attention_type == "design":
244
+ # lora for handle subject (img branch)
245
+ self.norm1_subject_lora = nn.Sequential(
246
+ nn.SiLU(),
247
+ LoRALinearLayer(dim, dim*6, self.rank, self.network_alpha) # lora for adalinear
248
+ )
249
+ self.layernorm_subject = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # norm for subject
250
+
251
+ # lora for handle object (txt branch)
252
+ self.norm1_object_lora = nn.Sequential(
253
+ nn.SiLU(),
254
+ LoRALinearLayer(dim, dim*6, self.rank, self.network_alpha) # lora for adalinear
255
+ )
256
+ self.layernorm_object = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # norm for object
257
+
258
+ def double_block_adaln_lora_forward(self, x, temb, adaln, adaln_lora, layernorm, lora_weight):
259
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = adaln(x, emb=temb)
260
+ lora_shift_msa, lora_scale_msa, lora_gate_msa, lora_shift_mlp, lora_scale_mlp, lora_gate_mlp = adaln_lora(temb).chunk(6, dim=1)
261
+ norm_x = norm_x + lora_weight * (layernorm(x)* (1 + lora_scale_msa[:, None]) + lora_shift_msa[:, None])
262
+ x_gate_msa = x_gate_msa + lora_weight*lora_gate_msa
263
+ x_shift_mlp = x_shift_mlp + lora_weight*lora_shift_mlp
264
+ x_scale_mlp = x_scale_mlp + lora_weight*lora_scale_mlp
265
+ x_gate_mlp = x_gate_mlp + lora_weight*lora_gate_mlp
266
+ return norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp
267
+ def forward(
268
+ self,
269
+ hidden_states: torch.Tensor,
270
+ encoder_hidden_states: torch.Tensor,
271
+ temb: torch.Tensor,
272
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
273
+ subject_hidden_states = None,
274
+ subject_rotary_emb = None,
275
+ object_bbox_hidden_states = None,
276
+ object_rotary_emb = None,
277
+ design_scale = 1.0,
278
+ attention_mask=None,
279
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
280
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
281
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
282
+
283
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
284
+ encoder_hidden_states, emb=temb
285
+ )
286
+ joint_attention_kwargs = joint_attention_kwargs or {}
287
+
288
+
289
+ use_subject = True if self.attention_type == "design" and subject_hidden_states is not None and design_scale!=0.0 else False
290
+ use_object = True if self.attention_type == "design" and object_bbox_hidden_states is not None and design_scale!=0.0 else False
291
+ if use_subject:
292
+ # subject adalinear
293
+ norm_subject_hidden_states, subject_gate_msa, subject_shift_mlp, subject_scale_mlp, subject_gate_mlp = self.double_block_adaln_lora_forward(
294
+ subject_hidden_states, temb, self.norm1, self.norm1_subject_lora, self.layernorm_subject, self.lora_weight
295
+ )
296
+ if use_object:
297
+ # object adalinear
298
+ norm_object_bbox_hidden_states, object_gate_msa, object_shift_mlp, object_scale_mlp, object_gate_mlp = self.double_block_adaln_lora_forward(
299
+ object_bbox_hidden_states, temb, self.norm1_context, self.norm1_object_lora, self.layernorm_object, self.lora_weight
300
+ )
301
+
302
+
303
+ attn_output, context_attn_output, subject_attn_output, object_attn_output = self.attn(
304
+ hidden_states=norm_hidden_states,
305
+ encoder_hidden_states=norm_encoder_hidden_states,
306
+ image_rotary_emb=image_rotary_emb,
307
+ subject_hidden_states=norm_subject_hidden_states if use_subject else None,
308
+ subject_rotary_emb=subject_rotary_emb if use_subject else None,
309
+ object_bbox_hidden_states=norm_object_bbox_hidden_states if use_object else None,
310
+ object_rotary_emb=object_rotary_emb if use_object else None,
311
+ attention_mask = attention_mask,
312
+ **joint_attention_kwargs,
313
+ )
314
+
315
+ # Process attention outputs for the `hidden_states`.
316
+ attn_output = gate_msa.unsqueeze(1) * attn_output
317
+ hidden_states = hidden_states + attn_output
318
+
319
+ norm_hidden_states = self.norm2(hidden_states)
320
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
321
+
322
+ ff_output = self.ff(norm_hidden_states)
323
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
324
+
325
+ hidden_states = hidden_states + ff_output
326
+
327
+
328
+
329
+ # Process attention outputs for the `encoder_hidden_states`.
330
+
331
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
332
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
333
+
334
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
335
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
336
+
337
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
338
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
339
+
340
+
341
+ # process attention outputs for the `subject_hidden_states`.
342
+ if use_subject:
343
+ subject_attn_output = subject_gate_msa.unsqueeze(1) * subject_attn_output
344
+ subject_hidden_states = subject_hidden_states + subject_attn_output
345
+ norm_subject_hidden_states = self.norm2(subject_hidden_states)
346
+ norm_subject_hidden_states = norm_subject_hidden_states * (1 + subject_scale_mlp[:, None]) + subject_shift_mlp[:, None]
347
+ subject_ff_output = self.ff(norm_subject_hidden_states)
348
+ subject_hidden_states = subject_hidden_states + subject_gate_mlp.unsqueeze(1) * subject_ff_output
349
+
350
+ # process attention outputs for the `object_bbox_hidden_states`.
351
+ if use_object:
352
+ object_attn_output = object_gate_msa.unsqueeze(1) * object_attn_output
353
+ object_bbox_hidden_states = object_bbox_hidden_states + object_attn_output
354
+ norm_object_bbox_hidden_states = self.norm2_context(object_bbox_hidden_states)
355
+ norm_object_bbox_hidden_states = norm_object_bbox_hidden_states * (1 + object_scale_mlp[:, None]) + object_shift_mlp[:, None]
356
+ object_ff_output = self.ff_context(norm_object_bbox_hidden_states)
357
+ object_bbox_hidden_states = object_bbox_hidden_states + object_gate_mlp.unsqueeze(1) * object_ff_output
358
+
359
+ if encoder_hidden_states.dtype == torch.float16:
360
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
361
+
362
+ return encoder_hidden_states, hidden_states, subject_hidden_states, object_bbox_hidden_states
363
+
364
+
365
+ class FluxTransformer2DModel(
366
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
367
+ ):
368
+ """
369
+ The Transformer model introduced in Flux.
370
+
371
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
372
+
373
+ Args:
374
+ patch_size (`int`, defaults to `1`):
375
+ Patch size to turn the input data into small patches.
376
+ in_channels (`int`, defaults to `64`):
377
+ The number of channels in the input.
378
+ out_channels (`int`, *optional*, defaults to `None`):
379
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
380
+ num_layers (`int`, defaults to `19`):
381
+ The number of layers of dual stream DiT blocks to use.
382
+ num_single_layers (`int`, defaults to `38`):
383
+ The number of layers of single stream DiT blocks to use.
384
+ attention_head_dim (`int`, defaults to `128`):
385
+ The number of dimensions to use for each attention head.
386
+ num_attention_heads (`int`, defaults to `24`):
387
+ The number of attention heads to use.
388
+ joint_attention_dim (`int`, defaults to `4096`):
389
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
390
+ `encoder_hidden_states`).
391
+ pooled_projection_dim (`int`, defaults to `768`):
392
+ The number of dimensions to use for the pooled projection.
393
+ guidance_embeds (`bool`, defaults to `False`):
394
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
395
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
396
+ The dimensions to use for the rotary positional embeddings.
397
+ """
398
+
399
+ _supports_gradient_checkpointing = True
400
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
401
+
402
+ @register_to_config
403
+ def __init__(
404
+ self,
405
+ patch_size: int = 1,
406
+ in_channels: int = 64,
407
+ out_channels: Optional[int] = None,
408
+ num_layers: int = 19,
409
+ num_single_layers: int = 38,
410
+ attention_head_dim: int = 128,
411
+ num_attention_heads: int = 24,
412
+ joint_attention_dim: int = 4096,
413
+ pooled_projection_dim: int = 768,
414
+ guidance_embeds: bool = False,
415
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
416
+ attention_type="design",
417
+ max_boxes_token_length=30,
418
+ rank = 16,
419
+ network_alpha = 16,
420
+ lora_weight = 1.0,
421
+ use_attention_mask = True,
422
+ use_objects_masks_maps=True,
423
+ use_subject_masks_maps=True,
424
+ use_layout_encoder=True,
425
+ drop_subject_bg=False,
426
+ gradient_checkpointing=False,
427
+ use_fourier_bbox=True,
428
+ bbox_id_shift=True
429
+ ):
430
+ super().__init__()
431
+ # #creatidesign
432
+ self.attention_type = attention_type
433
+ self.max_boxes_token_length = max_boxes_token_length
434
+ self.rank = rank
435
+ self.network_alpha = network_alpha
436
+ self.lora_weight = lora_weight
437
+ self.use_attention_mask = use_attention_mask
438
+ self.use_objects_masks_maps= use_objects_masks_maps
439
+ self.num_attention_heads=num_attention_heads
440
+ self.use_layout_encoder = use_layout_encoder
441
+ self.use_subject_masks_maps = use_subject_masks_maps
442
+ self.drop_subject_bg = drop_subject_bg
443
+ self.gradient_checkpointing = gradient_checkpointing
444
+ self.use_fourier_bbox = use_fourier_bbox
445
+ self.bbox_id_shift = bbox_id_shift
446
+
447
+
448
+ self.out_channels = out_channels or in_channels
449
+ self.inner_dim = num_attention_heads * attention_head_dim
450
+
451
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
452
+
453
+ text_time_guidance_cls = (
454
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
455
+ )
456
+ self.time_text_embed = text_time_guidance_cls(
457
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
458
+ )
459
+
460
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
461
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
462
+
463
+ self.transformer_blocks = nn.ModuleList(
464
+ [
465
+ FluxTransformerBlock(
466
+ dim=self.inner_dim,
467
+ num_attention_heads=num_attention_heads,
468
+ attention_head_dim=attention_head_dim,
469
+ attention_type=self.attention_type,
470
+ rank=self.rank,
471
+ network_alpha=self.network_alpha,
472
+ lora_weight=self.lora_weight,
473
+ )
474
+ for _ in range(num_layers)
475
+ ]
476
+ )
477
+
478
+ self.single_transformer_blocks = nn.ModuleList(
479
+ [
480
+ FluxSingleTransformerBlock(
481
+ dim=self.inner_dim,
482
+ num_attention_heads=num_attention_heads,
483
+ attention_head_dim=attention_head_dim,
484
+ attention_type=self.attention_type,
485
+ rank=self.rank,
486
+ network_alpha=self.network_alpha,
487
+ lora_weight=self.lora_weight,
488
+ )
489
+ for _ in range(num_single_layers)
490
+ ]
491
+ )
492
+
493
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
494
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
495
+
496
+
497
+ if self.attention_type =="design":
498
+ if self.use_layout_encoder:
499
+ if self.use_fourier_bbox:
500
+ self.object_layout_encoder = ObjectLayoutEncoder(
501
+ positive_len=self.inner_dim, out_dim=self.inner_dim, max_boxes_token_length=self.max_boxes_token_length
502
+ )
503
+ else:
504
+ self.object_layout_encoder = ObjectLayoutEncoder_noFourier(
505
+ in_dim=self.inner_dim, out_dim=self.inner_dim
506
+ )
507
+
508
+
509
+ @property
510
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
511
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
512
+ r"""
513
+ Returns:
514
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
515
+ indexed by its weight name.
516
+ """
517
+ # set recursively
518
+ processors = {}
519
+
520
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
521
+ if hasattr(module, "get_processor"):
522
+ processors[f"{name}.processor"] = module.get_processor()
523
+
524
+ for sub_name, child in module.named_children():
525
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
526
+
527
+ return processors
528
+
529
+ for name, module in self.named_children():
530
+ fn_recursive_add_processors(name, module, processors)
531
+
532
+ return processors
533
+
534
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
535
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
536
+ r"""
537
+ Sets the attention processor to use to compute attention.
538
+
539
+ Parameters:
540
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
541
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
542
+ for **all** `Attention` layers.
543
+
544
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
545
+ processor. This is strongly recommended when setting trainable attention processors.
546
+
547
+ """
548
+ count = len(self.attn_processors.keys())
549
+
550
+ if isinstance(processor, dict) and len(processor) != count:
551
+ raise ValueError(
552
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
553
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
554
+ )
555
+
556
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
557
+ if hasattr(module, "set_processor"):
558
+ if not isinstance(processor, dict):
559
+ module.set_processor(processor)
560
+ else:
561
+ module.set_processor(processor.pop(f"{name}.processor"))
562
+
563
+ for sub_name, child in module.named_children():
564
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
565
+
566
+ for name, module in self.named_children():
567
+ fn_recursive_attn_processor(name, module, processor)
568
+
569
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
570
+ def fuse_qkv_projections(self):
571
+ """
572
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
573
+ are fused. For cross-attention modules, key and value projection matrices are fused.
574
+
575
+ <Tip warning={true}>
576
+
577
+ This API is 🧪 experimental.
578
+
579
+ </Tip>
580
+ """
581
+ self.original_attn_processors = None
582
+
583
+ for _, attn_processor in self.attn_processors.items():
584
+ if "Added" in str(attn_processor.__class__.__name__):
585
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
586
+
587
+ self.original_attn_processors = self.attn_processors
588
+
589
+ for module in self.modules():
590
+ if isinstance(module, Attention):
591
+ module.fuse_projections(fuse=True)
592
+
593
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
594
+
595
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
596
+ def unfuse_qkv_projections(self):
597
+ """Disables the fused QKV projection if enabled.
598
+
599
+ <Tip warning={true}>
600
+
601
+ This API is 🧪 experimental.
602
+
603
+ </Tip>
604
+
605
+ """
606
+ if self.original_attn_processors is not None:
607
+ self.set_attn_processor(self.original_attn_processors)
608
+
609
+ def _set_gradient_checkpointing(self, module, value=False):
610
+ if hasattr(module, "gradient_checkpointing"):
611
+ module.gradient_checkpointing = value
612
+
613
+ def forward(
614
+ self,
615
+ hidden_states: torch.Tensor,
616
+ encoder_hidden_states: torch.Tensor = None,
617
+ pooled_projections: torch.Tensor = None,
618
+ timestep: torch.LongTensor = None,
619
+ img_ids: torch.Tensor = None,
620
+ txt_ids: torch.Tensor = None,
621
+ guidance: torch.Tensor = None,
622
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
623
+ controlnet_block_samples=None,
624
+ controlnet_single_block_samples=None,
625
+ return_dict: bool = True,
626
+ controlnet_blocks_repeat: bool = False,
627
+ design_kwargs: dict | None = None,
628
+ design_scale =1.0
629
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
630
+ """
631
+ The [`FluxTransformer2DModel`] forward method.
632
+
633
+ Args:
634
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
635
+ Input `hidden_states`.
636
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
637
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
638
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
639
+ from the embeddings of input conditions.
640
+ timestep ( `torch.LongTensor`):
641
+ Used to indicate denoising step.
642
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
643
+ A list of tensors that if specified are added to the residuals of transformer blocks.
644
+ joint_attention_kwargs (`dict`, *optional*):
645
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
646
+ `self.processor` in
647
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
648
+ return_dict (`bool`, *optional*, defaults to `True`):
649
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
650
+ tuple.
651
+
652
+ Returns:
653
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
654
+ `tuple` where the first element is the sample tensor.
655
+ """
656
+ if joint_attention_kwargs is not None:
657
+ joint_attention_kwargs = joint_attention_kwargs.copy()
658
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
659
+ else:
660
+ lora_scale = 1.0
661
+
662
+ if USE_PEFT_BACKEND:
663
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
664
+ scale_lora_layers(self, lora_scale)
665
+ else:
666
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
667
+ logger.warning(
668
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
669
+ )
670
+
671
+
672
+ hidden_states = self.x_embedder(hidden_states)
673
+
674
+ timestep = timestep.to(hidden_states.dtype) * 1000
675
+ if guidance is not None:
676
+ guidance = guidance.to(hidden_states.dtype) * 1000
677
+ else:
678
+ guidance = None
679
+
680
+ temb = (
681
+ self.time_text_embed(timestep, pooled_projections)
682
+ if guidance is None
683
+ else self.time_text_embed(timestep, guidance, pooled_projections)
684
+ )
685
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
686
+
687
+ if txt_ids.ndim == 3:
688
+ # logger.warning(
689
+ # "Passing `txt_ids` 3d torch.Tensor is deprecated."
690
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
691
+ # )
692
+ txt_ids = txt_ids[0]
693
+ if img_ids.ndim == 3:
694
+ # logger.warning(
695
+ # "Passing `img_ids` 3d torch.Tensor is deprecated."
696
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
697
+ # )
698
+ img_ids = img_ids[0]
699
+
700
+ attention_mask_batch = None
701
+ # handle design infos
702
+ if self.attention_type=="design" and design_kwargs is not None:
703
+
704
+ # handle objects
705
+ objects_boxes = design_kwargs['object_layout']['objects_boxes'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,10,4]
706
+ objects_bbox_text_embeddings = design_kwargs['object_layout']['bbox_text_embeddings'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,10,512,4096]
707
+ objects_bbox_masks = design_kwargs['object_layout']['bbox_masks'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,10]
708
+ #token Truncation
709
+ objects_bbox_text_embeddings = objects_bbox_text_embeddings[:,:,:self.max_boxes_token_length,:]# [B,10,30,4096]
710
+
711
+ # [B,10,30,4096] -> [B*10,30,4096] -> [B*10,30,3072] -> [B,10,30,3072]
712
+ B, N, S, C = objects_bbox_text_embeddings.shape
713
+ objects_bbox_text_embeddings = objects_bbox_text_embeddings.reshape(-1, S, C) #[B*10,30,4096]
714
+ objects_bbox_text_embeddings = self.context_embedder(objects_bbox_text_embeddings) #[B*10,30,3072]
715
+ objects_bbox_text_embeddings = objects_bbox_text_embeddings.reshape(B, N, S, -1) # [B,10,30,3072]
716
+
717
+ if self.use_layout_encoder:
718
+ if self.use_fourier_bbox:
719
+ object_bbox_hidden_states = self.object_layout_encoder(
720
+ boxes=objects_boxes,
721
+ masks=objects_bbox_masks,
722
+ positive_embeddings=objects_bbox_text_embeddings,
723
+ )# [B,10,30,3072]
724
+ else:
725
+ object_bbox_hidden_states = self.object_layout_encoder(
726
+ positive_embeddings=objects_bbox_text_embeddings,
727
+ )# [B,10,30,3072]
728
+ else:
729
+ object_bbox_hidden_states = objects_bbox_text_embeddings
730
+
731
+ object_bbox_hidden_states = object_bbox_hidden_states.contiguous().view(B, N*S, -1) # [B,300,3072]
732
+
733
+ # bbox_id shift
734
+ if self.bbox_id_shift:
735
+ object_bbox_ids = -1 * torch.ones(object_bbox_hidden_states.shape[0], object_bbox_hidden_states.shape[1], 3).to(device=object_bbox_hidden_states.device, dtype=object_bbox_hidden_states.dtype)
736
+ else:
737
+ object_bbox_ids = torch.zeros(object_bbox_hidden_states.shape[0], object_bbox_hidden_states.shape[1], 3).to(device=object_bbox_hidden_states.device, dtype=object_bbox_hidden_states.dtype)
738
+ if object_bbox_ids.ndim == 3:
739
+ object_bbox_ids = object_bbox_ids[0] #[300,3]
740
+ object_rotary_emb = self.pos_embed(object_bbox_ids)
741
+
742
+
743
+
744
+ # handle subjects
745
+ subject_hidden_states = design_kwargs['subject_contion']['condition_img']
746
+ subject_hidden_states = self.x_embedder(subject_hidden_states)
747
+ subject_ids = design_kwargs['subject_contion']['condition_img_ids']
748
+ if subject_ids.ndim == 3:
749
+ subject_ids = subject_ids[0]
750
+ subject_rotary_emb = self.pos_embed(subject_ids)
751
+
752
+
753
+
754
+ if self.use_attention_mask:
755
+ num_objects = N
756
+ tokens_per_object = self.max_boxes_token_length
757
+ total_object_tokens = object_bbox_hidden_states.shape[1]
758
+ assert total_object_tokens == num_objects * tokens_per_object, "Total object tokens do not match expected value"
759
+ encoder_tokens = encoder_hidden_states.shape[1]
760
+ img_tokens = hidden_states.shape[1]
761
+ subject_tokens = subject_hidden_states.shape[1]
762
+ # Total number of tokens
763
+ total_tokens = total_object_tokens + encoder_tokens + img_tokens + subject_tokens
764
+
765
+ attention_mask_batch = torch.zeros((B,total_tokens, total_tokens), dtype=hidden_states.dtype,device=hidden_states.device)
766
+ img_H, img_W = design_kwargs['object_layout']['img_token_h'], design_kwargs['object_layout']['img_token_w']
767
+ objects_masks_maps = design_kwargs['object_layout']['objects_masks_maps'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,512,512]
768
+ subject_H,subject_W = design_kwargs['subject_contion']['subject_token_h'], design_kwargs['subject_contion']['subject_token_w']
769
+ subject_masks_maps = design_kwargs['subject_contion']['subject_masks_maps'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,512,512]
770
+ for m_idx in range(B):
771
+ # Create the base mask (all False/blocked)
772
+ attention_mask = torch.zeros((total_tokens, total_tokens), dtype=hidden_states.dtype,device=hidden_states.device)
773
+
774
+ # Define token ranges
775
+ o_ranges = [] # Ranges for each object
776
+ start_idx = 0
777
+ for i in range(num_objects):
778
+ end_idx = start_idx + tokens_per_object
779
+ o_ranges.append((start_idx, end_idx))
780
+ start_idx = end_idx
781
+
782
+ encoder_range = (total_object_tokens, total_object_tokens + encoder_tokens)
783
+ img_range = (encoder_range[1], encoder_range[1] + img_tokens)
784
+ subject_range = (img_range[1], img_range[1] + subject_tokens)
785
+
786
+ # Fill in the mask
787
+
788
+ # 1. Object self-attention (diagonal o₁-o₁, o₂-o₂, o₃-o₃)
789
+ for o_start, o_end in o_ranges:
790
+ attention_mask[o_start:o_end, o_start:o_end] = True
791
+
792
+ # 2. Objects to img and img to objetcs
793
+
794
+ if not self.use_objects_masks_maps:
795
+ # all objects can attend to img and img can attend to all objects
796
+ for o_start, o_end in o_ranges:
797
+ attention_mask[o_start:o_end, img_range[0]:img_range[1]] = True
798
+ # img can attend to all
799
+ attention_mask[img_range[0]:img_range[1], :] = True
800
+ else:
801
+ # all objects can only attend to the bbox area (defined by objects_mask) of img
802
+ for idx, (o_start, o_end )in enumerate(o_ranges):
803
+ mask = objects_masks_maps[m_idx][idx]
804
+ mask = torch.nn.functional.interpolate(mask[None, None, :, :], (img_H, img_W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, tokens_per_object) #shape: [img_tokens,tokens_per_object]
805
+
806
+ # objects to img
807
+ attention_mask[o_start:o_end, img_range[0]:img_range[1]] = mask.transpose(-1, -2)
808
+
809
+ # img to objects
810
+ attention_mask[img_range[0]:img_range[1], o_start:o_end] = mask
811
+
812
+
813
+ # img to img
814
+ attention_mask[img_range[0]:img_range[1], img_range[0]:img_range[1]] = True
815
+
816
+ # img to prompt
817
+ attention_mask[img_range[0]:img_range[1], encoder_range[0]:encoder_range[1]] = True
818
+
819
+ # img to subject
820
+ subject_mask = subject_masks_maps[m_idx][0]
821
+
822
+ if not self.use_subject_masks_maps:
823
+ # all img can attend to subject
824
+ attention_mask[img_range[0]:img_range[1], subject_range[0]:subject_range[1]] = True
825
+ else:
826
+ # img can only attend to the bbox area (defined by subject_mask) of subject
827
+
828
+ subject_mask_img = torch.nn.functional.interpolate(subject_mask[None, None, :, :], (img_H, img_W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, subject_tokens) #shape: [img_tokens,subject_tokens]
829
+
830
+ # img to objects
831
+ attention_mask[img_range[0]:img_range[1], subject_range[0]:subject_range[1]] = subject_mask_img
832
+
833
+
834
+
835
+ # 3. prompt to prompt, prompt to img, and prompt to subject
836
+
837
+ # prompt to prompt
838
+ attention_mask[encoder_range[0]:encoder_range[1], encoder_range[0]:encoder_range[1]] = True
839
+ # prompt to img
840
+ attention_mask[encoder_range[0]:encoder_range[1], img_range[0]:img_range[1]] = True
841
+
842
+ # prompt to subject
843
+ if not self.use_subject_masks_maps:
844
+ attention_mask[encoder_range[0]:encoder_range[1], subject_range[0]:subject_range[1]] = True
845
+ else:
846
+ subject_mask_prompt = torch.nn.functional.interpolate(subject_mask[None, None, :, :], (subject_H, subject_W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, encoder_tokens) #shape: [subject_tokens,encoder_tokens]
847
+ attention_mask[encoder_range[0]:encoder_range[1], subject_range[0]:subject_range[1]] = subject_mask_prompt.transpose(-1, -2)
848
+
849
+
850
+ # 4. subject to prompt, subject to img, subject to subject
851
+ # subject to prompt
852
+ if not self.use_subject_masks_maps:
853
+ attention_mask[subject_range[0]:subject_range[1], encoder_range[0]:encoder_range[1]] = True
854
+ else:
855
+ attention_mask[subject_range[0]:subject_range[1], encoder_range[0]:encoder_range[1]] = False
856
+
857
+ # subject to img
858
+ if not self.use_subject_masks_maps:
859
+ attention_mask[subject_range[0]:subject_range[1], img_range[0]:img_range[1]] = True
860
+ else:
861
+ attention_mask[subject_range[0]:subject_range[1], img_range[0]:img_range[1]] = subject_mask_img.transpose(-1, -2)
862
+ # subject to subject
863
+ if not self.use_subject_masks_maps:
864
+ attention_mask[subject_range[0]:subject_range[1], subject_range[0]:subject_range[1]] = True
865
+ else:
866
+ # blcok non-subject region
867
+ if not self.drop_subject_bg:
868
+ attention_mask[subject_range[0]:subject_range[1], subject_range[0]:subject_range[1]] = True
869
+ else:
870
+ attention_mask[subject_range[0]:subject_range[1], subject_range[0]:subject_range[1]] = subject_mask_img
871
+
872
+
873
+ attention_mask_batch[m_idx] = attention_mask
874
+
875
+ attention_mask_batch = attention_mask_batch.unsqueeze(1).to(dtype=torch.bool, device=hidden_states.device)#[B,2860,2860]->[B,1,2860,2860]
876
+
877
+
878
+ ids = torch.cat((txt_ids, img_ids), dim=0)
879
+ image_rotary_emb = self.pos_embed(ids)
880
+
881
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
882
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
883
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
884
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
885
+
886
+
887
+ for index_block, block in enumerate(self.transformer_blocks):
888
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
889
+
890
+ def create_custom_forward(module, return_dict=None):
891
+ def custom_forward(*inputs):
892
+ if return_dict is not None:
893
+ return module(*inputs, return_dict=return_dict)
894
+ else:
895
+ return module(*inputs)
896
+
897
+ return custom_forward
898
+
899
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
900
+ encoder_hidden_states, hidden_states, subject_hidden_states, object_bbox_hidden_states = torch.utils.checkpoint.checkpoint(
901
+ create_custom_forward(block),
902
+ hidden_states,
903
+ encoder_hidden_states,
904
+ temb,
905
+ image_rotary_emb,
906
+ subject_hidden_states,
907
+ subject_rotary_emb,
908
+ object_bbox_hidden_states,
909
+ object_rotary_emb,
910
+ design_scale,
911
+ attention_mask_batch,
912
+ **ckpt_kwargs,
913
+ )
914
+
915
+ else:
916
+ encoder_hidden_states, hidden_states, subject_hidden_states, object_bbox_hidden_states = block(
917
+ hidden_states=hidden_states,
918
+ encoder_hidden_states=encoder_hidden_states,
919
+ temb=temb,
920
+ image_rotary_emb=image_rotary_emb,
921
+ subject_hidden_states=subject_hidden_states,
922
+ subject_rotary_emb=subject_rotary_emb,
923
+ object_bbox_hidden_states=object_bbox_hidden_states,
924
+ object_rotary_emb=object_rotary_emb,
925
+ design_scale = design_scale,
926
+ attention_mask = attention_mask_batch,
927
+ joint_attention_kwargs=joint_attention_kwargs,
928
+ )
929
+
930
+ # controlnet residual
931
+ if controlnet_block_samples is not None:
932
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
933
+ interval_control = int(np.ceil(interval_control))
934
+ # For Xlabs ControlNet.
935
+ if controlnet_blocks_repeat:
936
+ hidden_states = (
937
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
938
+ )
939
+ else:
940
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
941
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
942
+
943
+ for index_block, block in enumerate(self.single_transformer_blocks):
944
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
945
+
946
+ def create_custom_forward(module, return_dict=None):
947
+ def custom_forward(*inputs):
948
+ if return_dict is not None:
949
+ return module(*inputs, return_dict=return_dict)
950
+ else:
951
+ return module(*inputs)
952
+
953
+ return custom_forward
954
+
955
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
956
+ hidden_states, subject_hidden_states, object_bbox_hidden_states = torch.utils.checkpoint.checkpoint(
957
+ create_custom_forward(block),
958
+ hidden_states,
959
+ temb,
960
+ image_rotary_emb,
961
+ subject_hidden_states,
962
+ subject_rotary_emb,
963
+ object_bbox_hidden_states,
964
+ object_rotary_emb,
965
+ design_scale,
966
+ attention_mask_batch,
967
+ **ckpt_kwargs,
968
+ )
969
+
970
+ else:
971
+ hidden_states, subject_hidden_states, object_bbox_hidden_states = block(
972
+ hidden_states=hidden_states,
973
+ temb=temb,
974
+ image_rotary_emb=image_rotary_emb,
975
+ subject_hidden_states=subject_hidden_states,
976
+ subject_rotary_emb=subject_rotary_emb,
977
+ object_bbox_hidden_states=object_bbox_hidden_states,
978
+ object_rotary_emb=object_rotary_emb,
979
+ design_scale=design_scale,
980
+ attention_mask = attention_mask_batch,
981
+ joint_attention_kwargs=joint_attention_kwargs,
982
+ )
983
+
984
+ # controlnet residual
985
+ if controlnet_single_block_samples is not None:
986
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
987
+ interval_control = int(np.ceil(interval_control))
988
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
989
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
990
+ + controlnet_single_block_samples[index_block // interval_control]
991
+ )
992
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
993
+
994
+ hidden_states = self.norm_out(hidden_states, temb)
995
+ output = self.proj_out(hidden_states)
996
+
997
+ if USE_PEFT_BACKEND:
998
+ # remove `lora_scale` from each PEFT layer
999
+ unscale_lora_layers(self, lora_scale)
1000
+
1001
+ if not return_dict:
1002
+ return (output,)
1003
+
1004
+ return Transformer2DModelOutput(sample=output)
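The forward pass above assembles a block-wise attention mask over token groups laid out as [object tokens | prompt tokens | image tokens | subject tokens]. The toy sketch below uses made-up token counts and collapses all objects into one block (in the real mask each of the N objects gets its own diagonal block, and the object↔image and subject↔image entries are further restricted by the resized bbox/subject masks); it only illustrates how such a boolean mask can be built and fed to `scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

# Hypothetical token counts; the model derives the real ones from its inputs.
n_obj, n_txt, n_img, n_subj = 4, 3, 6, 5
total = n_obj + n_txt + n_img + n_subj

obj = slice(0, n_obj)
txt = slice(n_obj, n_obj + n_txt)
img = slice(n_obj + n_txt, n_obj + n_txt + n_img)
subj = slice(n_obj + n_txt + n_img, total)

mask = torch.zeros(total, total, dtype=torch.bool)
mask[obj, obj] = True    # objects attend to themselves
mask[obj, img] = True    # objects attend to the image tokens
mask[img, :] = True      # image tokens attend to everything (unmasked variant)
mask[txt, txt] = True    # prompt attends to itself
mask[txt, img] = True    # prompt attends to the image
mask[subj, img] = True   # subject attends to the image
mask[subj, subj] = True  # subject attends to itself

# SDPA takes a boolean mask broadcastable to [batch, heads, q_len, kv_len]; True = "may attend".
# No row is all-False, so the masked softmax stays finite.
q = k = v = torch.randn(1, 8, total, 64)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[None, None])
print(out.shape)  # torch.Size([1, 8, 18, 64])
```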
modules/semantic_layout/__pycache__/layout_encoder.cpython-310.pyc ADDED
Binary file (4.26 kB).
 
modules/semantic_layout/layout_encoder.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+ import torch.nn as nn
+ from diffusers.models.activations import FP32SiLU  # assumed location; referenced by the "silu_fp32" branch below
3
+ def zero_module(module):
4
+ """
5
+ Zero out the parameters of a module and return it.
6
+ """
7
+ for p in module.parameters():
8
+ p.detach().zero_()
9
+ return module
10
+
11
+ def get_fourier_embeds_from_boundingbox(embed_dim, box):
12
+ """
13
+ Args:
14
+ embed_dim: int
15
+ box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
16
+ Returns:
17
+ [B x N x embed_dim] tensor of positional embeddings
18
+ """
19
+
20
+ batch_size, num_boxes = box.shape[:2]
21
+
22
+ emb = 100 ** (torch.arange(embed_dim) / embed_dim)
23
+ emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
24
+ emb = emb * box.unsqueeze(-1)
25
+
26
+ emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
27
+ emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
28
+
29
+ return emb
30
+
31
+ class PixArtAlphaTextProjection(nn.Module):
32
+ """
33
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
34
+
35
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
36
+ """
37
+
38
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
39
+ super().__init__()
40
+ if out_features is None:
41
+ out_features = hidden_size
42
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
43
+ if act_fn == "gelu_tanh":
44
+ self.act_1 = nn.GELU(approximate="tanh")
45
+ elif act_fn == "silu":
46
+ self.act_1 = nn.SiLU()
47
+ elif act_fn == "silu_fp32":
48
+ self.act_1 = FP32SiLU()
49
+ else:
50
+ raise ValueError(f"Unknown activation function: {act_fn}")
51
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
52
+
53
+ def forward(self, caption):
54
+ hidden_states = self.linear_1(caption)
55
+ hidden_states = self.act_1(hidden_states)
56
+ hidden_states = self.linear_2(hidden_states)
57
+ return hidden_states
58
+
59
+
60
+ class ObjectLayoutEncoder(nn.Module):
61
+ def __init__(self, positive_len, out_dim, fourier_freqs=8 ,max_boxes_token_length=30):
62
+ super().__init__()
63
+ self.positive_len = positive_len
64
+ self.out_dim = out_dim
65
+
66
+ self.fourier_embedder_dim = fourier_freqs
67
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy #64
68
+
69
+ if isinstance(out_dim, tuple):
70
+ out_dim = out_dim[0]
71
+
72
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([max_boxes_token_length, self.positive_len]))
73
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
74
+
75
+
76
+ self.linears = PixArtAlphaTextProjection(in_features=self.positive_len + self.position_dim,hidden_size=out_dim//2,out_features=out_dim, act_fn="silu")
77
+
78
+ def forward(
79
+ self,
80
+ boxes, # [B,10,4]
81
+ masks, # [B,10]
82
+ positive_embeddings, # [B,10,30,3072]
83
+ ):
84
+
85
+ B, N, S, C = positive_embeddings.shape # B: batch_size, N: 10, S: 30, C: 3072
86
+
87
+ positive_embeddings = positive_embeddings.reshape(B*N, S, C) # [B*10,30,3072]
88
+ masks = masks.reshape(B*N, 1, 1) # [B*10,1,1]
89
+
90
+ # Process positional encoding
91
+ xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # [B,10,64]
92
+ xyxy_embedding = xyxy_embedding.reshape(B*N, -1) # [B*10,64]
93
+ xyxy_null = self.null_position_feature.view(1, -1) # [1,64]
94
+
95
+ # Expand positional encoding to match sequence dimension
96
+ xyxy_embedding = xyxy_embedding.unsqueeze(1).expand(-1, S, -1) # [B*10,30,64]
97
+ xyxy_null = xyxy_null.unsqueeze(0).expand(B*N, S, -1) # [B*10,30,64]
98
+
99
+ # Apply mask
100
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null # [B*10,30,64]
101
+
102
+ # Process feature encoding
103
+ positive_null = self.null_positive_feature.view(1, S, -1).expand(B*N, -1, -1) # [B*10,30,3072]
104
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null # [B*10,30,3072]
105
+
106
+ # Concatenate positional encoding and feature encoding
107
+ combined = torch.cat([positive_embeddings, xyxy_embedding], dim=-1) # [B*10,30,3072+64]
108
+
109
+ # Process each box's features independently
110
+ objs = self.linears(combined) # [B*10,30,3072]
111
+
112
+ # Restore original shape
113
+ objs = objs.reshape(B, N, S, -1) # [B,10,30,3072]
114
+
115
+ return objs
116
+
117
+ class ObjectLayoutEncoder_noFourier(nn.Module):
118
+ def __init__(self, in_dim, out_dim):
119
+ super().__init__()
120
+ self.in_dim = in_dim
121
+ self.out_dim = out_dim
122
+
123
+ self.linears = PixArtAlphaTextProjection(in_features=self.in_dim,hidden_size=out_dim//2,out_features=out_dim, act_fn="silu")
124
+
125
+ def forward(
126
+ self,
127
+ positive_embeddings, # [B,10,30,3072]
128
+ ):
129
+
130
+ B, N, S, C = positive_embeddings.shape # B: batch_size, N: 10, S: 30, C: 3072
131
+ positive_embeddings = positive_embeddings.reshape(B*N, S, C) # [B*10,30,3072]
132
+
133
+ # Process each box's features independently
134
+ objs = self.linears(positive_embeddings) # [B*10,30,3072]
135
+
136
+ # Restore original shape
137
+ objs = objs.reshape(B, N, S, -1) # [B,10,30,3072]
138
+
139
+ return objs
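A quick smoke test for the `ObjectLayoutEncoder` defined above, using dummy tensors with the shapes annotated in the module's comments (batch of 2, up to 10 boxes, 30 caption tokens per box, 3072-dim features); values are random placeholders:

```python
import torch
from modules.semantic_layout.layout_encoder import ObjectLayoutEncoder

B, N, S, C = 2, 10, 30, 3072
encoder = ObjectLayoutEncoder(positive_len=C, out_dim=C, max_boxes_token_length=S)

boxes = torch.rand(B, N, 4)         # normalized xyxy boxes
masks = torch.zeros(B, N)           # 1 = real box, 0 = padding slot
masks[:, :3] = 1.0                  # pretend only the first 3 boxes are real
box_text = torch.randn(B, N, S, C)  # per-box caption embeddings (already projected to inner_dim)

objs = encoder(boxes=boxes, masks=masks, positive_embeddings=box_text)
print(objs.shape)  # torch.Size([2, 10, 30, 3072])
```

Padded slots are replaced by the learned null features inside the encoder, so downstream attention can simply mask them out.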
pipeline/__pycache__/pipeline_flux_creatidesign.cpython-310.pyc ADDED
Binary file (32.3 kB).
 
pipeline/pipeline_flux_creatidesign.py ADDED
@@ -0,0 +1,1068 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+ from diffusers.models.transformers import FluxTransformer2DModel
33
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
45
+
46
+
47
+ if is_torch_xla_available():
48
+ import torch_xla.core.xla_model as xm
49
+
50
+ XLA_AVAILABLE = True
51
+ else:
52
+ XLA_AVAILABLE = False
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+ >>> from diffusers import FluxPipeline
62
+
63
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
64
+ >>> pipe.to("cuda")
65
+ >>> prompt = "A cat holding a sign that says hello world"
66
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
67
+ >>> # Refer to the pipeline documentation for more details.
68
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
69
+ >>> image.save("flux.png")
70
+ ```
71
+ """
72
+
73
+
74
+ def calculate_shift(
75
+ image_seq_len,
76
+ base_seq_len: int = 256,
77
+ max_seq_len: int = 4096,
78
+ base_shift: float = 0.5,
79
+ max_shift: float = 1.16,
80
+ ):
81
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
82
+ b = base_shift - m * base_seq_len
83
+ mu = image_seq_len * m + b
84
+ return mu
85
+
86
+
87
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
88
+ def retrieve_timesteps(
89
+ scheduler,
90
+ num_inference_steps: Optional[int] = None,
91
+ device: Optional[Union[str, torch.device]] = None,
92
+ timesteps: Optional[List[int]] = None,
93
+ sigmas: Optional[List[float]] = None,
94
+ **kwargs,
95
+ ):
96
+ r"""
97
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
98
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
99
+
100
+ Args:
101
+ scheduler (`SchedulerMixin`):
102
+ The scheduler to get timesteps from.
103
+ num_inference_steps (`int`):
104
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
105
+ must be `None`.
106
+ device (`str` or `torch.device`, *optional*):
107
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
108
+ timesteps (`List[int]`, *optional*):
109
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
110
+ `num_inference_steps` and `sigmas` must be `None`.
111
+ sigmas (`List[float]`, *optional*):
112
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
113
+ `num_inference_steps` and `timesteps` must be `None`.
114
+
115
+ Returns:
116
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
117
+ second element is the number of inference steps.
118
+ """
119
+ if timesteps is not None and sigmas is not None:
120
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
121
+ if timesteps is not None:
122
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
123
+ if not accepts_timesteps:
124
+ raise ValueError(
125
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
126
+ f" timestep schedules. Please check whether you are using the correct scheduler."
127
+ )
128
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
129
+ timesteps = scheduler.timesteps
130
+ num_inference_steps = len(timesteps)
131
+ elif sigmas is not None:
132
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accept_sigmas:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ else:
142
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ return timesteps, num_inference_steps
145
+
146
+
147
+ class FluxPipeline(
148
+ DiffusionPipeline,
149
+ FluxLoraLoaderMixin,
150
+ FromSingleFileMixin,
151
+ TextualInversionLoaderMixin,
152
+ FluxIPAdapterMixin,
153
+ ):
154
+ r"""
155
+ The Flux pipeline for text-to-image generation.
156
+
157
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
158
+
159
+ Args:
160
+ transformer ([`FluxTransformer2DModel`]):
161
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
162
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
163
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
164
+ vae ([`AutoencoderKL`]):
165
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
166
+ text_encoder ([`CLIPTextModel`]):
167
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
168
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
169
+ text_encoder_2 ([`T5EncoderModel`]):
170
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
171
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
172
+ tokenizer (`CLIPTokenizer`):
173
+ Tokenizer of class
174
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
175
+ tokenizer_2 (`T5TokenizerFast`):
176
+ Second Tokenizer of class
177
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
178
+ """
179
+
180
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
181
+ _optional_components = ["image_encoder", "feature_extractor"]
182
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
183
+
184
+ def __init__(
185
+ self,
186
+ scheduler: FlowMatchEulerDiscreteScheduler,
187
+ vae: AutoencoderKL,
188
+ text_encoder: CLIPTextModel,
189
+ tokenizer: CLIPTokenizer,
190
+ text_encoder_2: T5EncoderModel,
191
+ tokenizer_2: T5TokenizerFast,
192
+ transformer: FluxTransformer2DModel,
193
+ image_encoder: CLIPVisionModelWithProjection = None,
194
+ feature_extractor: CLIPImageProcessor = None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.register_modules(
199
+ vae=vae,
200
+ text_encoder=text_encoder,
201
+ text_encoder_2=text_encoder_2,
202
+ tokenizer=tokenizer,
203
+ tokenizer_2=tokenizer_2,
204
+ transformer=transformer,
205
+ scheduler=scheduler,
206
+ image_encoder=image_encoder,
207
+ feature_extractor=feature_extractor,
208
+ )
209
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
210
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
211
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
212
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
213
+ self.tokenizer_max_length = (
214
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
215
+ )
216
+ self.default_sample_size = 128
217
+
218
+ def _get_t5_prompt_embeds(
219
+ self,
220
+ prompt: Union[str, List[str]] = None,
221
+ num_images_per_prompt: int = 1,
222
+ max_sequence_length: int = 512,
223
+ device: Optional[torch.device] = None,
224
+ dtype: Optional[torch.dtype] = None,
225
+ ):
226
+ device = device or self._execution_device
227
+ dtype = dtype or self.text_encoder.dtype
228
+
229
+ prompt = [prompt] if isinstance(prompt, str) else prompt
230
+ batch_size = len(prompt)
231
+
232
+ if isinstance(self, TextualInversionLoaderMixin):
233
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
234
+
235
+ text_inputs = self.tokenizer_2(
236
+ prompt,
237
+ padding="max_length",
238
+ max_length=max_sequence_length,
239
+ truncation=True,
240
+ return_length=False,
241
+ return_overflowing_tokens=False,
242
+ return_tensors="pt",
243
+ )
244
+ text_input_ids = text_inputs.input_ids
245
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
246
+
247
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
248
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
249
+ logger.warning(
250
+ "The following part of your input was truncated because `max_sequence_length` is set to "
251
+ f" {max_sequence_length} tokens: {removed_text}"
252
+ )
253
+
254
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
255
+
256
+ dtype = self.text_encoder_2.dtype
257
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
258
+
259
+ _, seq_len, _ = prompt_embeds.shape
260
+
261
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
262
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
263
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
264
+
265
+ return prompt_embeds
266
+
267
+ def _get_clip_prompt_embeds(
268
+ self,
269
+ prompt: Union[str, List[str]],
270
+ num_images_per_prompt: int = 1,
271
+ device: Optional[torch.device] = None,
272
+ ):
273
+ device = device or self._execution_device
274
+
275
+ prompt = [prompt] if isinstance(prompt, str) else prompt
276
+ batch_size = len(prompt)
277
+
278
+ if isinstance(self, TextualInversionLoaderMixin):
279
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
280
+
281
+ text_inputs = self.tokenizer(
282
+ prompt,
283
+ padding="max_length",
284
+ max_length=self.tokenizer_max_length,
285
+ truncation=True,
286
+ return_overflowing_tokens=False,
287
+ return_length=False,
288
+ return_tensors="pt",
289
+ )
290
+
291
+ text_input_ids = text_inputs.input_ids
292
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
293
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
294
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
295
+ logger.warning(
296
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
297
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
298
+ )
299
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
300
+
301
+ # Use pooled output of CLIPTextModel
302
+ prompt_embeds = prompt_embeds.pooler_output
303
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
304
+
305
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
306
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
307
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
308
+
309
+ return prompt_embeds
310
+
311
+ def encode_prompt(
312
+ self,
313
+ prompt: Union[str, List[str]],
314
+ prompt_2: Union[str, List[str]],
315
+ device: Optional[torch.device] = None,
316
+ num_images_per_prompt: int = 1,
317
+ prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
319
+ max_sequence_length: int = 512,
320
+ lora_scale: Optional[float] = None,
321
+ ):
322
+ r"""
323
+
324
+ Args:
325
+ prompt (`str` or `List[str]`, *optional*):
326
+ prompt to be encoded
327
+ prompt_2 (`str` or `List[str]`, *optional*):
328
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
329
+ used in all text-encoders
330
+ device: (`torch.device`):
331
+ torch device
332
+ num_images_per_prompt (`int`):
333
+ number of images that should be generated per prompt
334
+ prompt_embeds (`torch.FloatTensor`, *optional*):
335
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
336
+ provided, text embeddings will be generated from `prompt` input argument.
337
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
338
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
339
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
340
+ lora_scale (`float`, *optional*):
341
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
342
+ """
343
+ device = device or self._execution_device
344
+
345
+ # set lora scale so that monkey patched LoRA
346
+ # function of text encoder can correctly access it
347
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
348
+ self._lora_scale = lora_scale
349
+
350
+ # dynamically adjust the LoRA scale
351
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
352
+ scale_lora_layers(self.text_encoder, lora_scale)
353
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
354
+ scale_lora_layers(self.text_encoder_2, lora_scale)
355
+
356
+ prompt = [prompt] if isinstance(prompt, str) else prompt
357
+
358
+ if prompt_embeds is None:
359
+ prompt_2 = prompt_2 or prompt
360
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
361
+
362
+ # We only use the pooled prompt output from the CLIPTextModel
363
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
364
+ prompt=prompt,
365
+ device=device,
366
+ num_images_per_prompt=num_images_per_prompt,
367
+ )
368
+ prompt_embeds = self._get_t5_prompt_embeds(
369
+ prompt=prompt_2,
370
+ num_images_per_prompt=num_images_per_prompt,
371
+ max_sequence_length=max_sequence_length,
372
+ device=device,
373
+ )
374
+
375
+ if self.text_encoder is not None:
376
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
377
+ # Retrieve the original scale by scaling back the LoRA layers
378
+ unscale_lora_layers(self.text_encoder, lora_scale)
379
+
380
+ if self.text_encoder_2 is not None:
381
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
382
+ # Retrieve the original scale by scaling back the LoRA layers
383
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
384
+
385
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
386
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
387
+
388
+ return prompt_embeds, pooled_prompt_embeds, text_ids
389
+
390
+ def encode_image(self, image, device, num_images_per_prompt):
391
+ dtype = next(self.image_encoder.parameters()).dtype
392
+
393
+ if not isinstance(image, torch.Tensor):
394
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
395
+
396
+ image = image.to(device=device, dtype=dtype)
397
+ image_embeds = self.image_encoder(image).image_embeds
398
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
399
+ return image_embeds
400
+
401
+ def prepare_ip_adapter_image_embeds(
402
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
403
+ ):
404
+ image_embeds = []
405
+ if ip_adapter_image_embeds is None:
406
+ if not isinstance(ip_adapter_image, list):
407
+ ip_adapter_image = [ip_adapter_image]
408
+
409
+ if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
410
+ raise ValueError(
411
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
412
+ )
413
+
414
+ for single_ip_adapter_image, image_proj_layer in zip(
415
+ ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
416
+ ):
417
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
418
+
419
+ image_embeds.append(single_image_embeds[None, :])
420
+ else:
421
+ for single_image_embeds in ip_adapter_image_embeds:
422
+ image_embeds.append(single_image_embeds)
423
+
424
+ ip_adapter_image_embeds = []
425
+ for i, single_image_embeds in enumerate(image_embeds):
426
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
427
+ single_image_embeds = single_image_embeds.to(device=device)
428
+ ip_adapter_image_embeds.append(single_image_embeds)
429
+
430
+ return ip_adapter_image_embeds
431
+
432
+ def check_inputs(
433
+ self,
434
+ prompt,
435
+ prompt_2,
436
+ height,
437
+ width,
438
+ negative_prompt=None,
439
+ negative_prompt_2=None,
440
+ prompt_embeds=None,
441
+ negative_prompt_embeds=None,
442
+ pooled_prompt_embeds=None,
443
+ negative_pooled_prompt_embeds=None,
444
+ callback_on_step_end_tensor_inputs=None,
445
+ max_sequence_length=None,
446
+ ):
447
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
448
+ logger.warning(
449
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
450
+ )
451
+
452
+ if callback_on_step_end_tensor_inputs is not None and not all(
453
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
454
+ ):
455
+ raise ValueError(
456
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
457
+ )
458
+
459
+ if prompt is not None and prompt_embeds is not None:
460
+ raise ValueError(
461
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
462
+ " only forward one of the two."
463
+ )
464
+ elif prompt_2 is not None and prompt_embeds is not None:
465
+ raise ValueError(
466
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
467
+ " only forward one of the two."
468
+ )
469
+ elif prompt is None and prompt_embeds is None:
470
+ raise ValueError(
471
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
472
+ )
473
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
474
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
475
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
476
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
477
+
478
+ if negative_prompt is not None and negative_prompt_embeds is not None:
479
+ raise ValueError(
480
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
481
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
482
+ )
483
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
484
+ raise ValueError(
485
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
486
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
487
+ )
488
+
489
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
490
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
491
+ raise ValueError(
492
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
493
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
494
+ f" {negative_prompt_embeds.shape}."
495
+ )
496
+
497
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
498
+ raise ValueError(
499
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
500
+ )
501
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
502
+ raise ValueError(
503
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
504
+ )
505
+
506
+ if max_sequence_length is not None and max_sequence_length > 512:
507
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
508
+
509
+ @staticmethod
510
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype,scale_h=1.0,scale_w=1.0):
511
+ latent_image_ids = torch.zeros(height, width, 3)
512
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]* scale_h
513
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]* scale_w
514
+
515
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
516
+
517
+ latent_image_ids = latent_image_ids.reshape(
518
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
519
+ )
520
+
521
+ return latent_image_ids.to(device=device, dtype=dtype)
522
+
523
+ @staticmethod
524
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
525
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
526
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
527
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
528
+
529
+ return latents
530
+
531
+ @staticmethod
532
+ def _unpack_latents(latents, height, width, vae_scale_factor):
533
+ batch_size, num_patches, channels = latents.shape
534
+
535
+ # VAE applies 8x compression on images but we must also account for packing which requires
536
+ # latent height and width to be divisible by 2.
537
+ height = 2 * (int(height) // (vae_scale_factor * 2))
538
+ width = 2 * (int(width) // (vae_scale_factor * 2))
539
+
540
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
541
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
542
+
543
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
544
+
545
+ return latents
546
+
547
+ def enable_vae_slicing(self):
548
+ r"""
549
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
550
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
551
+ """
552
+ self.vae.enable_slicing()
553
+
554
+ def disable_vae_slicing(self):
555
+ r"""
556
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
557
+ computing decoding in one step.
558
+ """
559
+ self.vae.disable_slicing()
560
+
561
+ def enable_vae_tiling(self):
562
+ r"""
563
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
564
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
565
+ processing larger images.
566
+ """
567
+ self.vae.enable_tiling()
568
+
569
+ def disable_vae_tiling(self):
570
+ r"""
571
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
572
+ computing decoding in one step.
573
+ """
574
+ self.vae.disable_tiling()
575
+
576
+ def prepare_latents(
577
+ self,
578
+ batch_size,
579
+ num_channels_latents,
580
+ height,
581
+ width,
582
+ dtype,
583
+ device,
584
+ generator,
585
+ latents=None,
586
+ ):
587
+ # VAE applies 8x compression on images but we must also account for packing which requires
588
+ # latent height and width to be divisible by 2.
589
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
590
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
591
+
592
+ shape = (batch_size, num_channels_latents, height, width)
593
+
594
+ if latents is not None:
595
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
596
+ return latents.to(device=device, dtype=dtype), latent_image_ids
597
+
598
+ if isinstance(generator, list) and len(generator) != batch_size:
599
+ raise ValueError(
600
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
601
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
602
+ )
603
+
604
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
605
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
606
+
607
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
608
+
609
+ return latents, latent_image_ids
610
+
611
+ @property
612
+ def guidance_scale(self):
613
+ return self._guidance_scale
614
+
615
+ @property
616
+ def joint_attention_kwargs(self):
617
+ return self._joint_attention_kwargs
618
+
619
+ @property
620
+ def num_timesteps(self):
621
+ return self._num_timesteps
622
+
623
+ @property
624
+ def interrupt(self):
625
+ return self._interrupt
626
+
627
+ @torch.no_grad()
628
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
629
+ def __call__(
630
+ self,
631
+ prompt: Union[str, List[str]] = None,
632
+ prompt_2: Optional[Union[str, List[str]]] = None,
633
+ negative_prompt: Union[str, List[str]] = None,
634
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
635
+ true_cfg_scale: float = 3.0,
636
+ height: Optional[int] = None,
637
+ width: Optional[int] = None,
638
+ num_inference_steps: int = 28,
639
+ sigmas: Optional[List[float]] = None,
640
+ guidance_scale: float = 3.5,
641
+ num_images_per_prompt: Optional[int] = 1,
642
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
643
+ latents: Optional[torch.FloatTensor] = None,
644
+ prompt_embeds: Optional[torch.FloatTensor] = None,
645
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
646
+ ip_adapter_image: Optional[PipelineImageInput] = None,
647
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
648
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
649
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
650
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
651
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
652
+ output_type: Optional[str] = "pil",
653
+ return_dict: bool = True,
654
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
655
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
656
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
657
+ max_sequence_length: int = 512,
658
+ objects_boxes=None,
659
+ objects_caption=None,
660
+ objects_masks = None,
661
+ objects_masks_maps=None,
662
+ subject_masks_maps=None,
663
+ condition_img = None,
664
+ neg_condtion_img=None,
665
+ max_boxes_per_image = 10,
666
+ position_delta=[0,-64],
667
+ scale_h=1.0,
668
+ scale_w=1.0,
669
+ use_bucket=False
670
+ ):
671
+ r"""
672
+ Function invoked when calling the pipeline for generation.
673
+
674
+ Args:
675
+ prompt (`str` or `List[str]`, *optional*):
676
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
677
+ instead.
678
+ prompt_2 (`str` or `List[str]`, *optional*):
679
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
680
+ will be used instead.
681
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
682
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
683
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
684
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
685
+ num_inference_steps (`int`, *optional*, defaults to 28):
686
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
687
+ expense of slower inference.
688
+ sigmas (`List[float]`, *optional*):
689
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
690
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
691
+ will be used.
692
+ guidance_scale (`float`, *optional*, defaults to 3.5):
693
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
694
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
695
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
696
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
697
+ usually at the expense of lower image quality.
698
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
699
+ The number of images to generate per prompt.
700
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
701
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
702
+ to make generation deterministic.
703
+ latents (`torch.FloatTensor`, *optional*):
704
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
705
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
706
+ tensor will ge generated by sampling using the supplied random `generator`.
707
+ prompt_embeds (`torch.FloatTensor`, *optional*):
708
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
709
+ provided, text embeddings will be generated from `prompt` input argument.
710
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
711
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
712
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
713
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
714
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
715
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
716
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
717
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
718
+ negative_ip_adapter_image:
719
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
720
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
721
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
722
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
723
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
724
+ output_type (`str`, *optional*, defaults to `"pil"`):
725
+ The output format of the generate image. Choose between
726
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
727
+ return_dict (`bool`, *optional*, defaults to `True`):
728
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
729
+ joint_attention_kwargs (`dict`, *optional*):
730
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
731
+ `self.processor` in
732
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
733
+ callback_on_step_end (`Callable`, *optional*):
734
+ A function that calls at the end of each denoising steps during the inference. The function is called
735
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
736
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
737
+ `callback_on_step_end_tensor_inputs`.
738
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
739
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
740
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
741
+ `._callback_tensor_inputs` attribute of your pipeline class.
742
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
743
+
744
+ Examples:
745
+
746
+ Returns:
747
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
748
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
749
+ images.
750
+ """
751
+
752
+ height = height or self.default_sample_size * self.vae_scale_factor
753
+ width = width or self.default_sample_size * self.vae_scale_factor
754
+
755
+ # 1. Check inputs. Raise error if not correct
756
+ self.check_inputs(
757
+ prompt,
758
+ prompt_2,
759
+ height,
760
+ width,
761
+ negative_prompt=negative_prompt,
762
+ negative_prompt_2=negative_prompt_2,
763
+ prompt_embeds=prompt_embeds,
764
+ negative_prompt_embeds=negative_prompt_embeds,
765
+ pooled_prompt_embeds=pooled_prompt_embeds,
766
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
767
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
768
+ max_sequence_length=max_sequence_length,
769
+ )
770
+
771
+ self._guidance_scale = guidance_scale
772
+ self._joint_attention_kwargs = joint_attention_kwargs
773
+ self._interrupt = False
774
+
775
+ # 2. Define call parameters
776
+ if prompt is not None and isinstance(prompt, str):
777
+ batch_size = 1
778
+ elif prompt is not None and isinstance(prompt, list):
779
+ batch_size = len(prompt)
780
+ else:
781
+ batch_size = prompt_embeds.shape[0]
782
+
783
+ device = self._execution_device
784
+
785
+ lora_scale = (
786
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
787
+ )
788
+ #creatidesign
789
+ negative_prompt = negative_prompt if negative_prompt is not None else [""]*batch_size
790
+
791
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
792
+ (
793
+ prompt_embeds,
794
+ pooled_prompt_embeds,
795
+ text_ids,
796
+ ) = self.encode_prompt(
797
+ prompt=prompt,
798
+ prompt_2=prompt_2,
799
+ prompt_embeds=prompt_embeds,
800
+ pooled_prompt_embeds=pooled_prompt_embeds,
801
+ device=device,
802
+ num_images_per_prompt=num_images_per_prompt,
803
+ max_sequence_length=max_sequence_length,
804
+ lora_scale=lora_scale,
805
+ )
806
+ if do_true_cfg:
807
+ (
808
+ negative_prompt_embeds,
809
+ negative_pooled_prompt_embeds,
810
+ _,
811
+ ) = self.encode_prompt(
812
+ prompt=negative_prompt,
813
+ prompt_2=negative_prompt_2,
814
+ prompt_embeds=negative_prompt_embeds,
815
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
816
+ device=device,
817
+ num_images_per_prompt=num_images_per_prompt,
818
+ max_sequence_length=max_sequence_length,
819
+ lora_scale=lora_scale,
820
+ )
821
+
822
+ # 4. Prepare latent variables
823
+ num_channels_latents = self.transformer.config.in_channels // 4
824
+ latents, latent_image_ids = self.prepare_latents(
825
+ batch_size * num_images_per_prompt,
826
+ num_channels_latents,
827
+ height,
828
+ width,
829
+ prompt_embeds.dtype,
830
+ device,
831
+ generator,
832
+ latents,
833
+ )
834
+
835
+ # 5. Prepare timesteps
836
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
837
+ image_seq_len = latents.shape[1]
838
+ mu = calculate_shift(
839
+ image_seq_len,
840
+ self.scheduler.config.base_image_seq_len,
841
+ self.scheduler.config.max_image_seq_len,
842
+ self.scheduler.config.base_shift,
843
+ self.scheduler.config.max_shift,
844
+ )
845
+ timesteps, num_inference_steps = retrieve_timesteps(
846
+ self.scheduler,
847
+ num_inference_steps,
848
+ device,
849
+ sigmas=sigmas,
850
+ mu=mu,
851
+ )
852
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
853
+ self._num_timesteps = len(timesteps)
854
+
855
+ # handle guidance
856
+ if self.transformer.config.guidance_embeds:
857
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
858
+ guidance = guidance.expand(latents.shape[0])
859
+ else:
860
+ guidance = None
861
+
862
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
863
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
864
+ ):
865
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
866
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
867
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
868
+ ):
869
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
870
+
871
+ if self.joint_attention_kwargs is None:
872
+ self._joint_attention_kwargs = {}
873
+
874
+ image_embeds = None
875
+ negative_image_embeds = None
876
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
877
+ image_embeds = self.prepare_ip_adapter_image_embeds(
878
+ ip_adapter_image,
879
+ ip_adapter_image_embeds,
880
+ device,
881
+ batch_size * num_images_per_prompt,
882
+ )
883
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
884
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
885
+ negative_ip_adapter_image,
886
+ negative_ip_adapter_image_embeds,
887
+ device,
888
+ batch_size * num_images_per_prompt,
889
+ )
890
+
891
+ #creatidesign
892
+ objects_boxes = objects_boxes.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
893
+ objects_masks = objects_masks.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
894
+ objects_masks_maps = objects_masks_maps.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
895
+ subject_masks_maps = subject_masks_maps.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
896
+ N = len(objects_caption[0])
897
+ logger.debug(f"Number of grounded object boxes: {N}")
898
+ bbox_text_embeddings = torch.zeros(
899
+ max_boxes_per_image,max_sequence_length,4096, device=device, dtype=latents.dtype
900
+ )
901
+ if N>0:
902
+ bbox_text_embeddings_temp,_,_ = self.encode_prompt(prompt=objects_caption[0],prompt_2=None,device=device,
903
+ num_images_per_prompt=num_images_per_prompt,
904
+ max_sequence_length=max_sequence_length,)
905
+ bbox_text_embeddings[:N]=bbox_text_embeddings_temp
906
+ bbox_text_embeddings = bbox_text_embeddings.unsqueeze(0).to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0) #[b,10,30,4096]
907
+
908
+ # Convert condition images to latent space
909
+ condition_img = condition_img.to(device=device,dtype=self.vae.dtype).repeat_interleave(batch_size, dim=0)
910
+ condition_img_input = self.vae.encode(condition_img).latent_dist.sample()
911
+ condition_img_input = (condition_img_input - self.vae.config.shift_factor) * self.vae.config.scaling_factor
912
+ condition_img_input = condition_img_input.to(dtype=latents.dtype)
913
+ condition_latent_image_ids = self._prepare_latent_image_ids(
914
+ condition_img_input.shape[0],
915
+ condition_img_input.shape[2] // 2,
916
+ condition_img_input.shape[3] // 2,
917
+ device,
918
+ latents.dtype,
919
+ scale_h = scale_h,
920
+ scale_w = scale_w,
921
+ )
922
+
923
+ # shift condition image ids
924
+
925
+ if use_bucket:
926
+ # offset determined by condition image width and scale
927
+ condition_latent_image_ids[:, 1] += 0 # H dimension unchanged
928
+ condition_latent_image_ids[:, 2] += -1*(condition_img_input.shape[3]*scale_w//2)
929
+ else:
930
+ # shift condition image ids
931
+ condition_latent_image_ids[:, 1] += position_delta[0] # H dimension unchanged
932
+ condition_latent_image_ids[:, 2] += position_delta[1] # W dimension shift left
933
+
934
+ packed_clean_condition_input = self._pack_latents(
935
+ condition_img_input,
936
+ batch_size=condition_img_input.shape[0],
937
+ num_channels_latents=condition_img_input.shape[1],
938
+ height=condition_img_input.shape[2],
939
+ width=condition_img_input.shape[3],
940
+ )
941
+
942
+
943
+ design_kwargs = {
944
+ "object_layout": {"objects_boxes": objects_boxes, "bbox_text_embeddings": bbox_text_embeddings, "bbox_masks": objects_masks,"objects_masks_maps":objects_masks_maps,"img_token_h":(int(height) // (self.vae_scale_factor * 2)), "img_token_w":(int(width) // (self.vae_scale_factor * 2))}, #[b,10,4], [B,10,512,4096],[b,10]
945
+ "subject_contion":{"condition_img":packed_clean_condition_input,"subject_masks_maps":subject_masks_maps,"condition_img_ids":condition_latent_image_ids,"subject_token_h":condition_img_input.shape[2]//2, "subject_token_w":condition_img_input.shape[3]//2}, # [B,4,64,64]
946
+ }
947
+
948
+ neg_objects_masks = torch.zeros_like(objects_masks).to(device=device, dtype=latents.dtype)
949
+
950
+ neg_condtion_img = neg_condtion_img.to(device=device,dtype=self.vae.dtype).repeat_interleave(batch_size, dim=0)
951
+ neg_condtion_img_input = self.vae.encode(neg_condtion_img).latent_dist.sample()
952
+ neg_condtion_img_input = (neg_condtion_img_input - self.vae.config.shift_factor) * self.vae.config.scaling_factor
953
+ neg_condtion_img_input = neg_condtion_img_input.to(dtype=latents.dtype)
954
+ neg_condition_latent_image_ids = self._prepare_latent_image_ids(
955
+ neg_condtion_img_input.shape[0],
956
+ neg_condtion_img_input.shape[2] // 2,
957
+ neg_condtion_img_input.shape[3] // 2,
958
+ device,
959
+ latents.dtype,
960
+ scale_h = scale_h,
961
+ scale_w = scale_w
962
+ )
963
+
964
+ if use_bucket:
965
+ # offset determined by condition image width and scale
966
+ neg_condition_latent_image_ids[:, 1] += 0 # H dimension unchanged
967
+ neg_condition_latent_image_ids[:, 2] += -1*(condition_img_input.shape[3]*scale_w//2)
968
+ else:
969
+ # shift negative condition image ids
970
+ neg_condition_latent_image_ids[:, 1] += position_delta[0] # H dimension shift
971
+ neg_condition_latent_image_ids[:, 2] += position_delta[1] # W dimension shift
972
+
973
+ packed_clean_neg_condtion_input = self._pack_latents(
974
+ neg_condtion_img_input,
975
+ batch_size=neg_condtion_img_input.shape[0],
976
+ num_channels_latents=neg_condtion_img_input.shape[1],
977
+ height=neg_condtion_img_input.shape[2],
978
+ width=neg_condtion_img_input.shape[3],
979
+ )
980
+
981
+ neg_subject_masks_maps = subject_masks_maps
982
+ neg_objects_masks_maps = objects_masks_maps
983
+ neg_design_kwargs = {
984
+ "object_layout": {"objects_boxes": objects_boxes, "bbox_text_embeddings": bbox_text_embeddings, "bbox_masks": neg_objects_masks,"objects_masks_maps":neg_objects_masks_maps,"img_token_h":(int(height) // (self.vae_scale_factor * 2)), "img_token_w":(int(width) // (self.vae_scale_factor * 2))}, #[b,10,4], [B,10,512,4096],[b,10]
985
+ "subject_contion":{"condition_img":packed_clean_neg_condtion_input,"subject_masks_maps":neg_subject_masks_maps,"condition_img_ids":neg_condition_latent_image_ids,"subject_token_h":condition_img_input.shape[2]//2, "subject_token_w":condition_img_input.shape[3]//2}, # [B,4,64,64]
986
+ }
987
+
988
+ # 6. Denoising loop
989
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
990
+ for i, t in enumerate(timesteps):
991
+ if self.interrupt:
992
+ continue
993
+
994
+ if image_embeds is not None:
995
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
996
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
997
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
998
+ noise_pred = self.transformer(
999
+ hidden_states=latents,
1000
+ timestep=timestep / 1000,
1001
+ guidance=guidance,
1002
+ pooled_projections=pooled_prompt_embeds,
1003
+ encoder_hidden_states=prompt_embeds,
1004
+ txt_ids=text_ids,
1005
+ img_ids=latent_image_ids,
1006
+ joint_attention_kwargs=self.joint_attention_kwargs,
1007
+ return_dict=False,
1008
+ design_kwargs = design_kwargs,
1009
+ )[0]
1010
+
1011
+ if do_true_cfg:
1012
+ if negative_image_embeds is not None:
1013
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1014
+ neg_noise_pred = self.transformer(
1015
+ hidden_states=latents,
1016
+ timestep=timestep / 1000,
1017
+ guidance=guidance,
1018
+ pooled_projections=negative_pooled_prompt_embeds,
1019
+ encoder_hidden_states=negative_prompt_embeds,
1020
+ txt_ids=text_ids,
1021
+ img_ids=latent_image_ids,
1022
+ joint_attention_kwargs=self.joint_attention_kwargs,
1023
+ return_dict=False,
1024
+ design_kwargs = neg_design_kwargs,
1025
+ )[0]
1026
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1027
+
1028
+ # compute the previous noisy sample x_t -> x_t-1
1029
+ latents_dtype = latents.dtype
1030
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1031
+
1032
+ if latents.dtype != latents_dtype:
1033
+ if torch.backends.mps.is_available():
1034
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1035
+ latents = latents.to(latents_dtype)
1036
+
1037
+ if callback_on_step_end is not None:
1038
+ callback_kwargs = {}
1039
+ for k in callback_on_step_end_tensor_inputs:
1040
+ callback_kwargs[k] = locals()[k]
1041
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1042
+
1043
+ latents = callback_outputs.pop("latents", latents)
1044
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1045
+
1046
+ # call the callback, if provided
1047
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1048
+ progress_bar.update()
1049
+
1050
+ if XLA_AVAILABLE:
1051
+ xm.mark_step()
1052
+
1053
+ if output_type == "latent":
1054
+ image = latents
1055
+
1056
+ else:
1057
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1058
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1059
+ image = self.vae.decode(latents, return_dict=False)[0]
1060
+ image = self.image_processor.postprocess(image, output_type=output_type)
1061
+
1062
+ # Offload all models
1063
+ self.maybe_free_model_hooks()
1064
+
1065
+ if not return_dict:
1066
+ return (image,)
1067
+
1068
+ return FluxPipelineOutput(images=image)
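
Two helpers above do most of the token bookkeeping: `_pack_latents` folds every 2×2 latent patch into a single token with 4× the channels, and `_unpack_latents` inverts it before VAE decoding; the same convention explains why the condition-image ids are built from `shape[2] // 2` and `shape[3] // 2`. The sketch below mirrors those two static methods on a toy tensor to show they are exact inverses (shapes are arbitrary; this is a re-derivation for illustration, not code copied from the repository).

```python
import torch

B, C, H, W = 1, 16, 8, 8                 # toy latent shape; H and W must be even
x = torch.randn(B, C, H, W)

# pack: each 2x2 latent patch becomes one token of 4*C channels
packed = (
    x.view(B, C, H // 2, 2, W // 2, 2)
     .permute(0, 2, 4, 1, 3, 5)
     .reshape(B, (H // 2) * (W // 2), C * 4)
)                                        # [B, (H/2)*(W/2), 4C]

# unpack: restore the [B, C, H, W] layout
unpacked = (
    packed.view(B, H // 2, W // 2, C, 2, 2)
          .permute(0, 3, 1, 4, 2, 5)
          .reshape(B, C, H, W)
)

assert torch.equal(x, unpacked)          # packing and unpacking round-trip exactly
```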
requirements.txt CHANGED
@@ -1,6 +1,14 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
1
+ diffusers
2
+ accelerate
3
+ transformers
4
+ sentencepiece
5
+ protobuf
6
+ bitsandbytes
7
+ prodigyopt
8
+ opencv-python
9
+ beautifulsoup4
10
+ xformers==0.0.27.post2
11
+ flash-attn
12
+ gradio
13
+
14
+
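
With the updated pin list (xformers fixed to 0.0.27.post2, plus sentencepiece, protobuf, flash-attn, and gradio, among others), a quick post-install sanity check could look like the snippet below. This is only a convenience check, not a file from the repository, and it assumes the packages were installed with `pip install -r requirements.txt` (torch arrives transitively).

```python
# Post-install sanity check (assumes requirements.txt has been installed).
import torch
import diffusers
import transformers
import xformers

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__)
print("transformers:", transformers.__version__)
print("xformers:", xformers.__version__)
```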
test_creatidesign_benchmark.py ADDED
@@ -0,0 +1,210 @@
2
+ import torch
3
+ import os
4
+ from torch.utils.data import DataLoader
5
+ from tqdm import tqdm
6
+ import time
8
+ from dataloader.creatidesign_dataset_benchmark import DesignDataset,visualize_bbox,collate_fn,tensor_to_pil,make_image_grid_RGB
9
+ import numpy as np
10
+ from PIL import Image
11
+ from safetensors.torch import save_file, load_file
12
+ from accelerate import load_checkpoint_and_dispatch
13
+ from modules.flux.transformer_flux_creatidesign import FluxTransformer2DModel
14
+ from pipeline.pipeline_flux_creatidesign import FluxPipeline
15
+ import json
16
+ from huggingface_hub import snapshot_download
17
+ from datasets import load_dataset
18
+
19
+ if __name__ == "__main__":
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+ weight_dtype = torch.bfloat16
22
+ resolution = 1024
23
+ condition_resolution = 512
24
+ neg_condition_image = 'same'
25
+ background_color = 'gray'
26
+ use_bucket = True
27
+ condition_resolution_scale_ratio=0.5
28
+
29
+ benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
30
+
31
+ datasets = DesignDataset(dataset_name=benchmark_repo,
32
+ resolution=resolution,
33
+ condition_resolution=condition_resolution,
34
+ neg_condition_image =neg_condition_image,
35
+ background_color=background_color,
36
+ use_bucket=use_bucket,
37
+ condition_resolution_scale_ratio=condition_resolution_scale_ratio
38
+ )
39
+ test_dataloader = DataLoader(datasets, batch_size=1, shuffle=False, num_workers=4,collate_fn=collate_fn)
40
+
41
+
42
+ model_path = "black-forest-labs/FLUX.1-dev"
43
+
44
+ ckpt_repo = "HuiZhang0812/CreatiDesign" # huggingface repo of ckpt
45
+
46
+ ckpt_path = snapshot_download(
47
+ repo_id=ckpt_repo,
48
+ repo_type="model",
49
+ local_dir="./CreatiDesign_checkpoint",
50
+ local_dir_use_symlinks=False
51
+ )
52
+
53
+ # Load transformer config from checkpoint
54
+ with open(os.path.join(ckpt_path, "transformer", "config.json"), 'r') as f:
55
+ config = json.load(f)
56
+
57
+ transformer = FluxTransformer2DModel(**config)
58
+ transformer = load_checkpoint_and_dispatch(transformer, checkpoint=os.path.join(model_path,"transformer"), device_map=None)
59
+
60
+ # Load lora parameters using safetensors
61
+ state_dict = load_file(os.path.join(ckpt_path, "transformer","model.safetensors"))
62
+
63
+ # Load parameters, allow partial loading
64
+ missing_keys, unexpected_keys = transformer.load_state_dict(state_dict, strict=False)
65
+
66
+ print(f"Loaded parameters: {len(state_dict)}",state_dict.keys())
67
+ print(f"Missing keys: {len(missing_keys)}",missing_keys)
68
+ print(f"Unexpected keys: {len(unexpected_keys)}",unexpected_keys)
69
+
70
+ transformer = transformer.to(dtype=torch.bfloat16)
71
+
72
+ pipe = FluxPipeline.from_pretrained(model_path, transformer=transformer,torch_dtype=torch.bfloat16)
73
+ pipe = pipe.to("cuda")
74
+
75
+ seed=42
76
+ num_samples = 1
77
+ true_cfg_scale=3.5
78
+ guidance_scale=1.0
79
+ if resolution == 512:
80
+ position_delta=[0,-32]
81
+ else:
82
+ position_delta=[0,-64]
83
+ if use_bucket:
84
+ scale_h = 1/condition_resolution_scale_ratio
85
+ scale_w = 1/condition_resolution_scale_ratio
86
+ else:
87
+ scale_h = resolution/condition_resolution
88
+ scale_w = resolution/condition_resolution
89
+
90
+ num_inference_steps = 28
91
+
92
+ # Create save directory based on benchmark directory name
93
+ save_root =os.path.join("outputs",benchmark_repo.split("/")[-1])
94
+ os.makedirs(save_root,exist_ok=True)
95
+
96
+ img_save_root = os.path.join(save_root,"images")
97
+ os.makedirs(img_save_root,exist_ok=True)
98
+
99
+ img_withgt_save_root = os.path.join(save_root,"images_with_gt")
100
+ os.makedirs(img_withgt_save_root,exist_ok=True)
101
+
102
+ total_time = 0
103
+ for i, batch in enumerate(tqdm(test_dataloader)):
104
+ prompts = batch["caption"]
105
+ imgs_id = batch['id']
106
+ objects_boxes = batch["objects_boxes"]
107
+ objects_caption = batch['objects_caption']
108
+ objects_masks = batch['objects_masks']
109
+ condition_img = batch['condition_img']
110
+ neg_condtion_img = batch['neg_condtion_img']
111
+ objects_masks_maps= batch['objects_masks_maps']
112
+ subject_masks_maps = batch['condition_img_masks_maps']
113
+ target_width=batch['target_width'][0]
114
+ target_height=batch['target_height'][0]
115
+
116
+ img_info = batch["img_info"][0]
117
+ filename = img_info["img_id"]+'.jpg'
118
+ start_time = time.time()
119
+ with torch.no_grad():
120
+ images = pipe(prompt=prompts*num_samples,
121
+ generator=torch.Generator(device="cuda").manual_seed(seed),
122
+ num_inference_steps = num_inference_steps,
123
+ objects_boxes=objects_boxes,
124
+ objects_caption=objects_caption,
125
+ objects_masks = objects_masks,
126
+ objects_masks_maps=objects_masks_maps,
127
+ condition_img = condition_img,
128
+ subject_masks_maps = subject_masks_maps,
129
+ neg_condtion_img = neg_condtion_img,
130
+ height= target_height,
131
+ width = target_width,
132
+ true_cfg_scale = true_cfg_scale,
133
+ position_delta=position_delta,
134
+ guidance_scale=guidance_scale,
135
+ scale_h = scale_h,
136
+ scale_w = scale_w,
137
+ use_bucket=use_bucket
138
+ )
139
+ images=images.images
140
+ use_time = time.time() - start_time
141
+ total_time +=use_time
142
+
143
+ make_image_grid_RGB(images, rows=1, cols=num_samples).save(os.path.join(img_save_root,filename))
146
+
147
+ # Process original image and bounding boxes
148
+ ori_image = tensor_to_pil(batch['img'][0])
149
+ orig_width, orig_height = ori_image.size
150
+ normalized_boxes = batch['objects_boxes'][0].cpu().numpy()
151
+ denormalized_boxes = []
152
+ for box in normalized_boxes:
153
+ x1, y1, x2, y2 = box
154
+ denorm_box = [
155
+ x1 * orig_width, # x1
156
+ y1 * orig_height, # y1
157
+ x2 * orig_width, # x2
158
+ y2 * orig_height # y2
159
+ ]
160
+ denormalized_boxes.append(denorm_box)
161
+
162
+ objects_result = {
163
+ "boxes": denormalized_boxes,
164
+ "labels": batch['objects_caption'][0],
165
+ "masks": []
166
+ }
167
+
168
+ # Only keep boxes and captions where mask is 1
169
+ valid_boxes = []
170
+ valid_labels = []
171
+ for box, label, mask in zip(objects_result['boxes'],
172
+ objects_result['labels'],
173
+ batch['objects_masks'][0]):
174
+ if mask:
175
+ valid_boxes.append(box)
176
+ valid_labels.append(label)
177
+
178
+ objects_result['boxes'] = valid_boxes
179
+ objects_result['labels'] = valid_labels
180
+
181
+ ori_image_with_bbox = visualize_bbox(ori_image ,objects_result)
182
+
183
+ # Concatenate images
184
+ total_width = ori_image.width + ori_image.width+ num_samples*ori_image.width
185
+ max_height = ori_image.height
186
+
187
+ # Create a new blank image to hold the concatenated images
188
+ new_image = Image.new('RGB', (total_width, max_height))
189
+
190
+ new_image.paste(ori_image_with_bbox, (0, 0))
191
+
192
+ # Process condition image
193
+ condition_img = tensor_to_pil(batch['original_size_condition_img'][0])
194
+ subject_canvas_with_bbox = visualize_bbox(condition_img ,objects_result)
195
+
196
+ new_image.paste(subject_canvas_with_bbox, (ori_image.width, 0))
197
+
198
+ # Paste generated images
199
+ for j, image in enumerate(images):
200
+
201
+ save_name=os.path.join(img_withgt_save_root,filename)
202
+
203
+ image_with_bbox = visualize_bbox(image ,objects_result)
204
+
205
+ new_image.paste(image_with_bbox, (ori_image.width*(j+2), 0))
206
+
207
+ new_image.save(save_name)
208
+
209
+ print(f"Total inference time: {total_time:.2f} seconds")
210
+ print(f"Average time per image: {total_time/len(test_dataloader):.2f} seconds")