Spaces:
Build error
Build error
Upload 27 files
Browse files- .gitattributes +10 -0
- README.md +105 -12
- assets/figures/CreatiDesign_logo.png +3 -0
- assets/figures/Qualitative_results.jpg +3 -0
- assets/figures/Quantitative_results.png +3 -0
- assets/figures/architecture.jpg +3 -0
- assets/figures/dataset.jpg +3 -0
- assets/figures/loop_edit.jpg +3 -0
- assets/figures/motivation.jpg +3 -0
- assets/figures/teaser.jpg +3 -0
- dataloader/__pycache__/creatidesign_dataset_benchmark.cpython-310.pyc +0 -0
- dataloader/arial.ttf +3 -0
- dataloader/creatidesign_dataset_benchmark.py +554 -0
- eval/layout.py +194 -0
- eval/subject.py +233 -0
- eval/text.py +184 -0
- modules/common/__pycache__/lora.cpython-310.pyc +0 -0
- modules/common/lora.py +26 -0
- modules/flux/__pycache__/attention_processor_flux_creatidesign.cpython-310.pyc +3 -0
- modules/flux/__pycache__/transformer_flux_creatidesign.cpython-310.pyc +0 -0
- modules/flux/attention_processor_flux_creatidesign.py +0 -0
- modules/flux/transformer_flux_creatidesign.py +1004 -0
- modules/semantic_layout/__pycache__/layout_encoder.cpython-310.pyc +0 -0
- modules/semantic_layout/layout_encoder.py +139 -0
- pipeline/__pycache__/pipeline_flux_creatidesign.cpython-310.pyc +0 -0
- pipeline/pipeline_flux_creatidesign.py +1068 -0
- requirements.txt +14 -6
- test_creatidesign_benchmark.py +210 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/figures/architecture.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/figures/CreatiDesign_logo.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/figures/dataset.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/figures/loop_edit.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/figures/motivation.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/figures/Qualitative_results.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/figures/Quantitative_results.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
assets/figures/teaser.jpg filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
dataloader/arial.ttf filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
modules/flux/__pycache__/attention_processor_flux_creatidesign.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,12 +1,105 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# <img src='assets/figures/CreatiDesign_logo.png' alt="CreatiDesign Logo" width='24px' /> CreatiDesign
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
<img src='assets/figures/teaser.jpg' width='100%' />
|
| 6 |
+
|
| 7 |
+
<br>
|
| 8 |
+
<a href="https://arxiv.org/pdf/2505.19114"><img src="https://img.shields.io/static/v1?label=Paper&message=2505.19114&color=red&logo=arxiv"></a>
|
| 9 |
+
<a href="https://huizhang0812.github.io/CreatiDesign/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Github&color=blue&logo=github-pages"></a>
|
| 10 |
+
<a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_dataset"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
|
| 11 |
+
<a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_benchmark"><img src="https://img.shields.io/badge/🤗_HuggingFace-Benchmark-ffbd45.svg" alt="HuggingFace"></a>
|
| 12 |
+
<a href="https://huggingface.co/HuiZhang0812/CreatiDesign"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a>
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
> <img src='assets/figures/CreatiDesign_logo.png' alt="CreatiDesign Logo" width='15px' /> **CreatiDesign: A Unified Multi-Conditional Diffusion Transformer for Creative Graphic Design**
|
| 17 |
+
> <br>
|
| 18 |
+
> [Hui Zhang](https://huizhang0812.github.io/),
|
| 19 |
+
> [Dexiang Hong](https://scholar.google.com.hk/citations?user=DUNijlcAAAAJ&hl=zh-CN),
|
| 20 |
+
> Maoke Yang,
|
| 21 |
+
> Yutao Cheng,
|
| 22 |
+
> Zhao Zhang,
|
| 23 |
+
> Jie Shao,
|
| 24 |
+
> [Xinglong Wu](https://scholar.google.com/citations?user=LVsp9RQAAAAJ&hl=zh-CN),
|
| 25 |
+
> [Zuxuan Wu](https://zxwu.azurewebsites.net/),
|
| 26 |
+
> and
|
| 27 |
+
> [Yu-Gang Jiang](https://scholar.google.com/citations?user=f3_FP8AAAAAJ)
|
| 28 |
+
> <br>
|
| 29 |
+
> Fudan University & ByteDance Intelligent Creation.
|
| 30 |
+
> <br>
|
| 31 |
+
|
| 32 |
+
## 🎯 Introduction
|
| 33 |
+
CreatiDesign tackles the challenge of automated graphic design generation that requires precise control over multiple heterogeneous elements—primary visual elements (product images), secondary visual elements (decorative objects), and textual elements (slogans, titles). CreatiDesign introduces a unified multi-conditional diffusion transformer that achieves flexible and harmonious integration of diverse design elements with minimal architectural modifications.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
<img src='assets/figures/motivation.jpg' width='100%' />
|
| 37 |
+
|
| 38 |
+
## ✨ Key Features
|
| 39 |
+
|
| 40 |
+
- **🎨 Multi-Conditional Image Generation**: Unified architecture supporting images, semantic layouts conditions simultaneously
|
| 41 |
+
- **🎯 Precise Element Control**: Multimodal attention mask mechanism prevents condition interference
|
| 42 |
+
- **🗂️ Graphic Design Datasets**: 400K graphic design samples with multi-condition annotations construced by automatic pipeline
|
| 43 |
+
- **📊 Comprehensive Benchmark**: Rigorous evaluation of multi-subject preservation and semantic layout alignment.
|
| 44 |
+
- **✏️ Zero-Shot Editing**: Natural extension to editing tasks without additional training or retraining
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
## Quick Start
|
| 49 |
+
### Setup
|
| 50 |
+
1. **Environment setup**
|
| 51 |
+
```bash
|
| 52 |
+
conda create -n creatidesign python=3.10 -y
|
| 53 |
+
conda activate creatidesign
|
| 54 |
+
conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
| 55 |
+
```
|
| 56 |
+
2. **Requirements installation**
|
| 57 |
+
```bash
|
| 58 |
+
pip install -r requirements.txt
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
## Dataset and Benchmark
|
| 63 |
+
### CreatiDesign Datasets <a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_dataset"><img src="https://img.shields.io/badge/🤗_HuggingFace-Dataset-ffbd45.svg" alt="HuggingFace"></a>
|
| 64 |
+
Our CreatiDesign dataset contains **400K high-quality graphic design samples** with comprehensive multi-condition annotations, constructed through our fully automated pipeline. The dataset covers diverse design categories including movie posters, product advertisements, brand promotions, and social media content.
|
| 65 |
+
|
| 66 |
+
### CreatiDesign Benchmark <a href="https://huggingface.co/datasets/HuiZhang0812/CreatiDesign_benchmark"><img src="https://img.shields.io/badge/🤗_HuggingFace-Benchmark-ffbd45.svg" alt="HuggingFace"></a>
|
| 67 |
+
Our comprehensive benchmark contains **1,000 carefully curated samples** designed to rigorously evaluate graphic design generation capabilities across multiple dimensions. The benchmark assesses both fine-grained condition adherence and overall visual quality.
|
| 68 |
+
|
| 69 |
+
To evaluate the model's graphic design generation capabilities through our benchmark, follow these steps:
|
| 70 |
+
|
| 71 |
+
Generate images:
|
| 72 |
+
```python
|
| 73 |
+
python test_creatidesign_benchmark.py
|
| 74 |
+
```
|
| 75 |
+
Evaluate multi-subject preservation:
|
| 76 |
+
```python
|
| 77 |
+
python eval/subject.py
|
| 78 |
+
```
|
| 79 |
+
Evaluate semantic layout alignment:
|
| 80 |
+
```python
|
| 81 |
+
python eval/layout.py
|
| 82 |
+
```
|
| 83 |
+
```python
|
| 84 |
+
python eval/text.py
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
## Models
|
| 89 |
+
**Multi-Conditional Graphic Design:**
|
| 90 |
+
| Model | Base model | Description |
|
| 91 |
+
| ------------------------------------------------------------------------------------------------ | -------------- | -------------------------------------------------------------------------------------------------------- |
|
| 92 |
+
| <a href="https://huggingface.co/HuiZhang0812/CreatiDesign"><img src="https://img.shields.io/badge/🤗_HuggingFace-Model-ffbd45.svg" alt="HuggingFace"></a> | FLUX.1-dev | model used in the paper
|
| 93 |
+
|
| 94 |
+
## ✒️ Citation
|
| 95 |
+
|
| 96 |
+
If you find our work useful for your research and applications, please kindly cite using this BibTeX:
|
| 97 |
+
|
| 98 |
+
```latex
|
| 99 |
+
@article{zhang2025creatidesign,
|
| 100 |
+
title={CreatiDesign: A Unified Multi-Conditional Diffusion Transformer for Creative Graphic Design},
|
| 101 |
+
author={Zhang, Hui and Hong, Dexiang and Yang, Maoke and Chen, Yutao and Zhang, Zhao and Shao, Jie and Wu, Xinglong and Wu, Zuxuan and Jiang, Yu-Gang},
|
| 102 |
+
journal={arXiv preprint arXiv:2505.19114},
|
| 103 |
+
year={2025}
|
| 104 |
+
}
|
| 105 |
+
```
|
assets/figures/CreatiDesign_logo.png
ADDED
|
Git LFS Details
|
assets/figures/Qualitative_results.jpg
ADDED
|
Git LFS Details
|
assets/figures/Quantitative_results.png
ADDED
|
Git LFS Details
|
assets/figures/architecture.jpg
ADDED
|
Git LFS Details
|
assets/figures/dataset.jpg
ADDED
|
Git LFS Details
|
assets/figures/loop_edit.jpg
ADDED
|
Git LFS Details
|
assets/figures/motivation.jpg
ADDED
|
Git LFS Details
|
assets/figures/teaser.jpg
ADDED
|
Git LFS Details
|
dataloader/__pycache__/creatidesign_dataset_benchmark.cpython-310.pyc
ADDED
|
Binary file (13.1 kB). View file
|
|
|
dataloader/arial.ttf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:35c0f3559d8db569e36c31095b8a60d441643d95f59139de40e23fada819b833
|
| 3 |
+
size 275572
|
dataloader/creatidesign_dataset_benchmark.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import random
|
| 9 |
+
from datasets import load_dataset
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
def find_nearest_bucket_size(input_width, input_height, mode="x64", ratio=1):
|
| 12 |
+
buckets = [
|
| 13 |
+
(512, 2048),
|
| 14 |
+
(512, 1984),
|
| 15 |
+
(512, 1920),
|
| 16 |
+
(512, 1856),
|
| 17 |
+
(576, 1792),
|
| 18 |
+
(576, 1728),
|
| 19 |
+
(576, 1664),
|
| 20 |
+
(640, 1600),
|
| 21 |
+
(640, 1536),
|
| 22 |
+
(704, 1472),
|
| 23 |
+
(704, 1408),
|
| 24 |
+
(704, 1344),
|
| 25 |
+
(768, 1344),
|
| 26 |
+
(768, 1280),
|
| 27 |
+
(832, 1216),
|
| 28 |
+
(832, 1152),
|
| 29 |
+
(896, 1152),
|
| 30 |
+
(896, 1088),
|
| 31 |
+
(960, 1088),
|
| 32 |
+
(960, 1024),
|
| 33 |
+
(1024, 1024),
|
| 34 |
+
(1024, 960),
|
| 35 |
+
(1088, 960),
|
| 36 |
+
(1088, 896),
|
| 37 |
+
(1152, 896),
|
| 38 |
+
(1152, 832),
|
| 39 |
+
(1216, 832),
|
| 40 |
+
(1280, 768),
|
| 41 |
+
(1344, 768),
|
| 42 |
+
(1408, 704),
|
| 43 |
+
(1472, 704),
|
| 44 |
+
(1536, 640),
|
| 45 |
+
(1600, 640),
|
| 46 |
+
(1664, 576),
|
| 47 |
+
(1728, 576),
|
| 48 |
+
(1792, 576),
|
| 49 |
+
(1856, 512),
|
| 50 |
+
(1920, 512),
|
| 51 |
+
(1984, 512),
|
| 52 |
+
(2048, 512)
|
| 53 |
+
]
|
| 54 |
+
aspect_ratios = [w / h for (w, h) in buckets]
|
| 55 |
+
|
| 56 |
+
assert mode in ["x64", "x8"]
|
| 57 |
+
if mode == "x64":
|
| 58 |
+
asp = input_width / input_height
|
| 59 |
+
diff = [abs(ar - asp) for ar in aspect_ratios]
|
| 60 |
+
bucket_id = int(np.argmin(diff))
|
| 61 |
+
gen_width, gen_height = buckets[bucket_id]
|
| 62 |
+
elif mode == "x8":
|
| 63 |
+
max_pixels = 1024 * 1024
|
| 64 |
+
ratio = (max_pixels / (input_width * input_height)) ** (0.5)
|
| 65 |
+
gen_width, gen_height = round(input_width * ratio), round(input_height * ratio)
|
| 66 |
+
gen_width = gen_width - gen_width % 8
|
| 67 |
+
gen_height = gen_height - gen_height % 8
|
| 68 |
+
else:
|
| 69 |
+
raise NotImplementedError
|
| 70 |
+
|
| 71 |
+
return (int(gen_width * ratio), int(gen_height * ratio))
|
| 72 |
+
|
| 73 |
+
def adjust_and_normalize_bboxes(bboxes, orig_width, orig_height):
|
| 74 |
+
# Adjust and normalize bbox
|
| 75 |
+
normalized_bboxes = []
|
| 76 |
+
for bbox in bboxes:
|
| 77 |
+
x1, y1, x2, y2 = bbox
|
| 78 |
+
x1_norm = round(x1 / orig_width,2)
|
| 79 |
+
y1_norm = round(y1 / orig_height,2)
|
| 80 |
+
x2_norm = round(x2 / orig_width,2)
|
| 81 |
+
y2_norm = round(y2 / orig_height,2)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
normalized_bboxes.append([x1_norm, y1_norm, x2_norm, y2_norm])
|
| 85 |
+
|
| 86 |
+
return normalized_bboxes
|
| 87 |
+
|
| 88 |
+
def img_transforms(image, height=512, width=512):
|
| 89 |
+
transform = transforms.Compose(
|
| 90 |
+
[
|
| 91 |
+
transforms.Resize(
|
| 92 |
+
(height, width), interpolation=transforms.InterpolationMode.BILINEAR
|
| 93 |
+
),
|
| 94 |
+
transforms.ToTensor(),
|
| 95 |
+
transforms.Normalize([0.5], [0.5]),
|
| 96 |
+
]
|
| 97 |
+
)
|
| 98 |
+
image_transformed = transform(image)
|
| 99 |
+
return image_transformed
|
| 100 |
+
|
| 101 |
+
def mask_transforms(mask, height=512, width=512):
|
| 102 |
+
transform = transforms.Compose(
|
| 103 |
+
[
|
| 104 |
+
transforms.Resize(
|
| 105 |
+
(height, width),
|
| 106 |
+
interpolation=transforms.InterpolationMode.NEAREST
|
| 107 |
+
),
|
| 108 |
+
transforms.ToTensor(),
|
| 109 |
+
]
|
| 110 |
+
)
|
| 111 |
+
mask_transformed = transform(mask)
|
| 112 |
+
return mask_transformed
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class DesignDataset(Dataset):
|
| 116 |
+
|
| 117 |
+
def __init__(
|
| 118 |
+
self,
|
| 119 |
+
dataset_name,
|
| 120 |
+
resolution=512,
|
| 121 |
+
condition_resolution=512,
|
| 122 |
+
condition_resolution_scale_ratio=0.5,
|
| 123 |
+
max_boxes_per_image=10,
|
| 124 |
+
neg_condition_image = 'same',
|
| 125 |
+
background_color = 'gray',
|
| 126 |
+
use_bucket=True,
|
| 127 |
+
box_confidence_th = 0.0
|
| 128 |
+
):
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
print(f"Loading dataset from Hugging Face: {dataset_name}")
|
| 132 |
+
|
| 133 |
+
self.dataset = load_dataset(dataset_name, split="test")
|
| 134 |
+
print(f"Loaded {len(self.dataset)} samples")
|
| 135 |
+
from IPython.core.debugger import set_trace
|
| 136 |
+
set_trace()
|
| 137 |
+
self.max_boxes_per_image = max_boxes_per_image
|
| 138 |
+
self.resolution = resolution
|
| 139 |
+
self.condition_resolution=condition_resolution
|
| 140 |
+
self.neg_condition_image = neg_condition_image
|
| 141 |
+
self.use_bucket = use_bucket
|
| 142 |
+
self.condition_resolution_scale_ratio=condition_resolution_scale_ratio
|
| 143 |
+
self.box_confidence_th = box_confidence_th
|
| 144 |
+
|
| 145 |
+
if background_color == 'white':
|
| 146 |
+
self.background_color = (255, 255, 255)
|
| 147 |
+
elif background_color == 'black':
|
| 148 |
+
self.background_color = (0, 0, 0)
|
| 149 |
+
elif background_color == 'gray':
|
| 150 |
+
self.background_color = (128, 128, 128)
|
| 151 |
+
else:
|
| 152 |
+
raise ValueError("Invalid background color. Use 'white' or 'black'.")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def __len__(self):
|
| 156 |
+
return len(self.dataset)
|
| 157 |
+
|
| 158 |
+
def __getitem__(self, idx):
|
| 159 |
+
sample = self.dataset[idx]
|
| 160 |
+
image_source = sample['original_image']
|
| 161 |
+
subject_image = sample['condition_gray_background']
|
| 162 |
+
subject_mask = sample['subject_mask']
|
| 163 |
+
json_data = json.loads(sample['metadata'])
|
| 164 |
+
|
| 165 |
+
#img info
|
| 166 |
+
img_info = json_data['img_info']
|
| 167 |
+
img_id = img_info['img_id']
|
| 168 |
+
orig_width, orig_height = int(img_info["img_width"]),int(img_info["img_height"])
|
| 169 |
+
|
| 170 |
+
if self.use_bucket:
|
| 171 |
+
target_width, target_height = find_nearest_bucket_size(orig_width,orig_height)
|
| 172 |
+
condition_width = int(target_width * self.condition_resolution_scale_ratio)
|
| 173 |
+
condition_height = int(target_height * self.condition_resolution_scale_ratio)
|
| 174 |
+
else:
|
| 175 |
+
target_width = target_height = self.resolution
|
| 176 |
+
condition_width = condition_height = self.condition_resolution
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
img_tensor = img_transforms(image_source,height=target_height,width=target_width)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# global caption
|
| 183 |
+
global_caption = json_data['global_caption']
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# object_annotations
|
| 187 |
+
object_annotations = json_data['object_annotations']
|
| 188 |
+
|
| 189 |
+
# object bbox list
|
| 190 |
+
objects_bbox = [item['bbox'] for item in object_annotations]
|
| 191 |
+
|
| 192 |
+
# object bbox caption
|
| 193 |
+
objects_caption = [item['bbox_detail_description'] for item in object_annotations]
|
| 194 |
+
|
| 195 |
+
# object bbox score
|
| 196 |
+
objects_bbox_score = [item['score'][0] for item in object_annotations]
|
| 197 |
+
|
| 198 |
+
# text
|
| 199 |
+
text_list = json_data["text_list"]
|
| 200 |
+
txt_bboxs = [item['bbox'] for item in text_list]
|
| 201 |
+
txt_captions = ["text:"+item['text'] for item in text_list]
|
| 202 |
+
|
| 203 |
+
txt_scores = [1.0 for _ in txt_bboxs]
|
| 204 |
+
# combine bbox 和 description
|
| 205 |
+
objects_bbox.extend(txt_bboxs)
|
| 206 |
+
objects_caption.extend(txt_captions)
|
| 207 |
+
objects_bbox_score.extend(txt_scores)
|
| 208 |
+
|
| 209 |
+
objects_bbox =torch.tensor(adjust_and_normalize_bboxes(objects_bbox,orig_width,orig_height))
|
| 210 |
+
|
| 211 |
+
objects_bbox_score = torch.tensor(objects_bbox_score)
|
| 212 |
+
|
| 213 |
+
boxes_mask = objects_bbox_score > self.box_confidence_th
|
| 214 |
+
objects_bbox_raw = objects_bbox[boxes_mask]
|
| 215 |
+
objects_caption = [object_caption for object_caption, box_mask in zip(objects_caption, boxes_mask) if box_mask]
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
num_boxes = objects_bbox_raw.shape[0]
|
| 219 |
+
objects_boxes_padded = torch.zeros((self.max_boxes_per_image, 4))
|
| 220 |
+
objects_masks_padded = torch.zeros(self.max_boxes_per_image)
|
| 221 |
+
|
| 222 |
+
objects_caption = objects_caption[:self.max_boxes_per_image]
|
| 223 |
+
objects_boxes_padded[:num_boxes] = objects_bbox_raw[:self.max_boxes_per_image]
|
| 224 |
+
objects_masks_padded[:num_boxes] = 1.
|
| 225 |
+
|
| 226 |
+
# objects_masks_maps
|
| 227 |
+
objects_masks_maps_padded = torch.zeros((self.max_boxes_per_image, target_height, target_width))
|
| 228 |
+
for idx in range(num_boxes):
|
| 229 |
+
x1, y1, x2, y2 = objects_boxes_padded[idx]
|
| 230 |
+
|
| 231 |
+
x1_pixel = int(x1 * target_width)
|
| 232 |
+
y1_pixel = int(y1 * target_height)
|
| 233 |
+
x2_pixel = int(x2 * target_width)
|
| 234 |
+
y2_pixel = int(y2 * target_height)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
x1_pixel = max(0, min(x1_pixel, target_width-1))
|
| 238 |
+
y1_pixel = max(0, min(y1_pixel, target_height-1))
|
| 239 |
+
x2_pixel = max(0, min(x2_pixel, target_width-1))
|
| 240 |
+
y2_pixel = max(0, min(y2_pixel, target_height-1))
|
| 241 |
+
|
| 242 |
+
objects_masks_maps_padded[idx, y1_pixel:y2_pixel+1, x1_pixel:x2_pixel+1] = 1.0
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# subject
|
| 247 |
+
original_size_subject_tensor = img_transforms(subject_image,height=target_height,width=target_width)
|
| 248 |
+
subject_tensor = img_transforms(subject_image,height=condition_height,width=condition_width)
|
| 249 |
+
subject_mask_tensor = mask_transforms(subject_mask, height=condition_height,width=condition_width)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
if self.neg_condition_image=='black':
|
| 253 |
+
subject_image_black = Image.new('RGB', (orig_width, orig_height), (0, 0, 0))
|
| 254 |
+
subject_image_neg_tensor = img_transforms(subject_image_black,height=condition_height,width=condition_width)
|
| 255 |
+
elif self.neg_condition_image=='white':
|
| 256 |
+
subject_image_white = Image.new('RGB', (orig_width, orig_height), (255, 255, 255))
|
| 257 |
+
subject_image_neg_tensor = img_transforms(subject_image_white,height=condition_height,width=condition_width)
|
| 258 |
+
elif self.neg_condition_image=='gray':
|
| 259 |
+
subject_image_gray = Image.new('RGB', (orig_width, orig_height), (128, 128, 128))
|
| 260 |
+
subject_image_neg_tensor = img_transforms(subject_image_gray,height=condition_height,width=condition_width)
|
| 261 |
+
elif self.neg_condition_image=='same':
|
| 262 |
+
subject_image_neg_tensor = subject_tensor
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
output = dict(
|
| 266 |
+
id=img_id,
|
| 267 |
+
caption=global_caption,
|
| 268 |
+
objects_boxes=objects_boxes_padded,
|
| 269 |
+
objects_caption=objects_caption,
|
| 270 |
+
objects_masks=objects_masks_padded,
|
| 271 |
+
objects_masks_maps=objects_masks_maps_padded,
|
| 272 |
+
img=img_tensor,
|
| 273 |
+
condition_img_masks_maps = subject_mask_tensor,
|
| 274 |
+
condition_img = subject_tensor,
|
| 275 |
+
original_size_condition_img = original_size_subject_tensor,
|
| 276 |
+
neg_condtion_img = subject_image_neg_tensor,
|
| 277 |
+
img_info = img_info,
|
| 278 |
+
target_width=target_width,
|
| 279 |
+
target_height=target_height,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
return output
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def collate_fn(examples):
|
| 286 |
+
|
| 287 |
+
collated_examples = {}
|
| 288 |
+
|
| 289 |
+
for key in ['id', 'objects_caption', 'caption','img_info','target_width','target_height']:
|
| 290 |
+
collated_examples[key] = [example[key] for example in examples]
|
| 291 |
+
|
| 292 |
+
for key in ['img', 'objects_boxes', 'objects_masks','condition_img','neg_condtion_img','objects_masks_maps','condition_img_masks_maps','original_size_condition_img']:
|
| 293 |
+
collated_examples[key] = torch.stack([example[key] for example in examples]).float()
|
| 294 |
+
|
| 295 |
+
return collated_examples
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
from typing import Dict
|
| 301 |
+
|
| 302 |
+
import numpy as np
|
| 303 |
+
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
| 304 |
+
import random
|
| 305 |
+
def draw_mask(mask, draw, random_color=True):
|
| 306 |
+
"""Draws a mask with a specified color on an image.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
mask (np.array): Binary mask as a NumPy array.
|
| 310 |
+
draw (ImageDraw.Draw): ImageDraw object to draw on the image.
|
| 311 |
+
random_color (bool): Whether to use a random color for the mask.
|
| 312 |
+
"""
|
| 313 |
+
if random_color:
|
| 314 |
+
color = (
|
| 315 |
+
random.randint(0, 255),
|
| 316 |
+
random.randint(0, 255),
|
| 317 |
+
random.randint(0, 255),
|
| 318 |
+
153,
|
| 319 |
+
)
|
| 320 |
+
else:
|
| 321 |
+
color = (30, 144, 255, 153)
|
| 322 |
+
|
| 323 |
+
nonzero_coords = np.transpose(np.nonzero(mask))
|
| 324 |
+
|
| 325 |
+
for coord in nonzero_coords:
|
| 326 |
+
draw.point(coord[::-1], fill=color)
|
| 327 |
+
|
| 328 |
+
def visualize_bbox(image_pil: Image,
|
| 329 |
+
result: Dict,
|
| 330 |
+
draw_width: float = 6.0,
|
| 331 |
+
return_mask=True) -> Image:
|
| 332 |
+
"""Plot bounding boxes and labels on an image with text wrapping for long descriptions.
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
image_pil (PIL.Image): The input image as a PIL Image object.
|
| 336 |
+
result (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing
|
| 337 |
+
the bounding boxes and labels. The keys are:
|
| 338 |
+
- boxes (List[int]): A list of bounding boxes in shape (N, 4), [x1, y1, x2, y2] format.
|
| 339 |
+
- labels (List[str]): A list of labels for each object
|
| 340 |
+
- masks (List[PIL.Image], optional): A list of masks in the format of PIL.Image
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
PIL.Image: The input image with plotted bounding boxes, labels, and masks.
|
| 344 |
+
"""
|
| 345 |
+
# Get the bounding boxes and labels from the target dictionary
|
| 346 |
+
boxes = result["boxes"]
|
| 347 |
+
categorys = result["labels"]
|
| 348 |
+
masks = result.get("masks", [])
|
| 349 |
+
|
| 350 |
+
color_list = [(255, 162, 76), (177, 214, 144),
|
| 351 |
+
(13, 146, 244), (249, 84, 84), (54, 186, 152),
|
| 352 |
+
(74, 36, 157), (0, 159, 189),
|
| 353 |
+
(80, 118, 135), (188, 90, 148), (119, 205, 255)]
|
| 354 |
+
|
| 355 |
+
# Use smaller font size to allow more text to be displayed
|
| 356 |
+
font_size = 30 # Reduce font size
|
| 357 |
+
font = ImageFont.truetype("dataloader/arial.ttf", font_size)
|
| 358 |
+
|
| 359 |
+
# Get image dimensions
|
| 360 |
+
img_width, img_height = image_pil.size
|
| 361 |
+
|
| 362 |
+
# Find all unique categories and build a cate2color dictionary
|
| 363 |
+
cate2color = {}
|
| 364 |
+
unique_categorys = sorted(set(categorys))
|
| 365 |
+
for idx, cate in enumerate(unique_categorys):
|
| 366 |
+
cate2color[cate] = color_list[idx % len(color_list)]
|
| 367 |
+
|
| 368 |
+
# Create a PIL ImageDraw object to draw on the input image
|
| 369 |
+
if isinstance(image_pil, np.ndarray):
|
| 370 |
+
image_pil = Image.fromarray(image_pil)
|
| 371 |
+
draw = ImageDraw.Draw(image_pil)
|
| 372 |
+
|
| 373 |
+
# Create a new binary mask image with the same size as the input image
|
| 374 |
+
mask = Image.new("L", image_pil.size, 0)
|
| 375 |
+
# Create a PIL ImageDraw object to draw on the mask image
|
| 376 |
+
mask_draw = ImageDraw.Draw(mask)
|
| 377 |
+
|
| 378 |
+
# Draw boxes, labels, and masks for each box and label in the target dictionary
|
| 379 |
+
for box, category in zip(boxes, categorys):
|
| 380 |
+
# Extract the box coordinates
|
| 381 |
+
x0, y0, x1, y1 = box
|
| 382 |
+
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
|
| 383 |
+
box_width = x1 - x0
|
| 384 |
+
box_height = y1 - y0
|
| 385 |
+
color = cate2color.get(category, color_list[0]) # Default color
|
| 386 |
+
|
| 387 |
+
# Draw the box outline on the input image
|
| 388 |
+
draw.rectangle([x0, y0, x1, y1], outline=color, width=int(draw_width))
|
| 389 |
+
|
| 390 |
+
# Allow text box to be maximum 2 times the bounding box width, but not exceed image boundaries
|
| 391 |
+
max_text_width = min(box_width * 2, img_width - x0)
|
| 392 |
+
|
| 393 |
+
# Determine the maximum height for text background area
|
| 394 |
+
max_text_height = min(box_height * 2, 200) # Also allow more text display, but limit height
|
| 395 |
+
|
| 396 |
+
# Handle long text based on bounding box width, split text into lines
|
| 397 |
+
lines = []
|
| 398 |
+
words = category.split()
|
| 399 |
+
current_line = words[0]
|
| 400 |
+
|
| 401 |
+
for word in words[1:]:
|
| 402 |
+
# Try to add the next word
|
| 403 |
+
test_line = current_line + " " + word
|
| 404 |
+
# Use textbbox or textlength to check if width fits the maximum text width
|
| 405 |
+
if hasattr(draw, "textbbox"):
|
| 406 |
+
# Use textbbox method
|
| 407 |
+
bbox = draw.textbbox((0, 0), test_line, font=font)
|
| 408 |
+
w = bbox[2] - bbox[0]
|
| 409 |
+
elif hasattr(draw, "textlength"):
|
| 410 |
+
# Use textlength method
|
| 411 |
+
w = draw.textlength(test_line, font=font)
|
| 412 |
+
else:
|
| 413 |
+
# Fallback - estimate width
|
| 414 |
+
w = len(test_line) * (font_size * 0.6) # Estimate average character width
|
| 415 |
+
|
| 416 |
+
if w <= max_text_width - 20: # Leave some margin
|
| 417 |
+
current_line = test_line
|
| 418 |
+
else:
|
| 419 |
+
lines.append(current_line)
|
| 420 |
+
current_line = word
|
| 421 |
+
|
| 422 |
+
lines.append(current_line) # Add the last line
|
| 423 |
+
|
| 424 |
+
# Limit number of lines to prevent overflow
|
| 425 |
+
max_lines = max_text_height // (font_size + 2) # Line height (font size + spacing)
|
| 426 |
+
if len(lines) > max_lines:
|
| 427 |
+
lines = lines[:max_lines-1]
|
| 428 |
+
lines.append("...") # Add ellipsis
|
| 429 |
+
|
| 430 |
+
# Calculate actual required width for each line
|
| 431 |
+
line_widths = []
|
| 432 |
+
for line in lines:
|
| 433 |
+
if hasattr(draw, "textbbox"):
|
| 434 |
+
bbox = draw.textbbox((0, 0), line, font=font)
|
| 435 |
+
line_width = bbox[2] - bbox[0]
|
| 436 |
+
elif hasattr(draw, "textlength"):
|
| 437 |
+
line_width = draw.textlength(line, font=font)
|
| 438 |
+
else:
|
| 439 |
+
line_width = len(line) * (font_size * 0.6) # Estimate width
|
| 440 |
+
line_widths.append(line_width)
|
| 441 |
+
|
| 442 |
+
# Determine actual required width for text box
|
| 443 |
+
if line_widths:
|
| 444 |
+
needed_text_width = max(line_widths) + 10 # Add small margin
|
| 445 |
+
else:
|
| 446 |
+
needed_text_width = 0
|
| 447 |
+
|
| 448 |
+
# Use bounding box width as minimum, only expand when needed
|
| 449 |
+
text_bg_width = max(box_width, min(needed_text_width, max_text_width))
|
| 450 |
+
|
| 451 |
+
# Ensure it doesn't exceed image boundaries
|
| 452 |
+
text_bg_width = min(text_bg_width, img_width - x0)
|
| 453 |
+
|
| 454 |
+
# Calculate text background height
|
| 455 |
+
text_bg_height = len(lines) * (font_size + 2)
|
| 456 |
+
|
| 457 |
+
# Ensure text background doesn't exceed image bottom
|
| 458 |
+
if y0 + text_bg_height > img_height:
|
| 459 |
+
# If it would exceed bottom, adjust text position to above the bounding box bottom
|
| 460 |
+
text_y0 = max(0, y1 - text_bg_height)
|
| 461 |
+
else:
|
| 462 |
+
text_y0 = y0
|
| 463 |
+
|
| 464 |
+
# Draw text background - note RGBA color handling
|
| 465 |
+
if image_pil.mode == "RGBA":
|
| 466 |
+
# For RGBA mode, we can directly use alpha color
|
| 467 |
+
bg_color = (*color, 180) # Semi-transparent background
|
| 468 |
+
else:
|
| 469 |
+
# For RGB mode, we cannot use alpha
|
| 470 |
+
bg_color = color
|
| 471 |
+
|
| 472 |
+
draw.rectangle([x0, text_y0, x0 + text_bg_width, text_y0 + text_bg_height], fill=bg_color)
|
| 473 |
+
|
| 474 |
+
# Draw text
|
| 475 |
+
for i, line in enumerate(lines):
|
| 476 |
+
y_pos = text_y0 + i * (font_size + 2)
|
| 477 |
+
draw.text((x0 + 5, y_pos), line, fill="white", font=font)
|
| 478 |
+
|
| 479 |
+
# Draw the mask on the input image if masks are provided
|
| 480 |
+
if len(masks) > 0 and return_mask:
|
| 481 |
+
size = image_pil.size
|
| 482 |
+
mask_image = Image.new("RGBA", size, color=(0, 0, 0, 0))
|
| 483 |
+
mask_draw = ImageDraw.Draw(mask_image)
|
| 484 |
+
for mask in masks:
|
| 485 |
+
mask = np.array(mask)[:, :, -1]
|
| 486 |
+
draw_mask(mask, mask_draw)
|
| 487 |
+
|
| 488 |
+
image_pil = Image.alpha_composite(image_pil.convert("RGBA"), mask_image).convert("RGB")
|
| 489 |
+
|
| 490 |
+
return image_pil
|
| 491 |
+
|
| 492 |
+
import torchvision.transforms as T
|
| 493 |
+
from PIL import Image, ImageDraw, ImageFont, ImageChops
|
| 494 |
+
|
| 495 |
+
def tensor_to_pil(img_tensor):
|
| 496 |
+
"""将tensor转换为PIL图像"""
|
| 497 |
+
img_tensor = img_tensor.cpu()
|
| 498 |
+
# 反归一化 ([0.5], [0.5])
|
| 499 |
+
img_tensor = img_tensor * 0.5 + 0.5
|
| 500 |
+
img_tensor = torch.clamp(img_tensor, 0, 1)
|
| 501 |
+
return T.ToPILImage()(img_tensor)
|
| 502 |
+
|
| 503 |
+
def make_image_grid_RGB(images, rows, cols, resize=None):
|
| 504 |
+
"""
|
| 505 |
+
Prepares a single grid of images. Useful for visualization purposes.
|
| 506 |
+
"""
|
| 507 |
+
assert len(images) == rows * cols
|
| 508 |
+
|
| 509 |
+
if resize is not None:
|
| 510 |
+
images = [img.resize((resize, resize)) for img in images]
|
| 511 |
+
|
| 512 |
+
w, h = images[0].size
|
| 513 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
| 514 |
+
|
| 515 |
+
for i, img in enumerate(images):
|
| 516 |
+
grid.paste(img.convert("RGB"), box=(i % cols * w, i // cols * h))
|
| 517 |
+
return grid
|
| 518 |
+
|
| 519 |
+
if __name__ == "__main__":
|
| 520 |
+
resolution = 1024
|
| 521 |
+
condition_resolution = 512
|
| 522 |
+
neg_condition_image = 'same'
|
| 523 |
+
background_color = 'gray'
|
| 524 |
+
use_bucket = True
|
| 525 |
+
condition_resolution_scale_ratio=0.5
|
| 526 |
+
|
| 527 |
+
benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
|
| 528 |
+
|
| 529 |
+
datasets = DesignDataset(dataset_name=benchmark_repo,
|
| 530 |
+
resolution=resolution,
|
| 531 |
+
condition_resolution=condition_resolution,
|
| 532 |
+
neg_condition_image =neg_condition_image,
|
| 533 |
+
background_color=background_color,
|
| 534 |
+
use_bucket=use_bucket,
|
| 535 |
+
condition_resolution_scale_ratio=condition_resolution_scale_ratio
|
| 536 |
+
)
|
| 537 |
+
test_dataloader = DataLoader(datasets, batch_size=1, shuffle=False, num_workers=1,collate_fn=collate_fn)
|
| 538 |
+
|
| 539 |
+
for i, batch in enumerate(tqdm(test_dataloader)):
|
| 540 |
+
prompts = batch["caption"]
|
| 541 |
+
imgs_id = batch['id']
|
| 542 |
+
objects_boxes = batch["objects_boxes"]
|
| 543 |
+
objects_caption = batch['objects_caption']
|
| 544 |
+
objects_masks = batch['objects_masks']
|
| 545 |
+
condition_img = batch['condition_img']
|
| 546 |
+
neg_condtion_img = batch['neg_condtion_img']
|
| 547 |
+
objects_masks_maps= batch['objects_masks_maps']
|
| 548 |
+
subject_masks_maps = batch['condition_img_masks_maps']
|
| 549 |
+
target_width=batch['target_width'][0]
|
| 550 |
+
target_height=batch['target_height'][0]
|
| 551 |
+
|
| 552 |
+
img_info = batch["img_info"][0]
|
| 553 |
+
filename = img_info["img_id"]+'.jpg'
|
| 554 |
+
|
eval/layout.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from transformers import AutoModel, AutoTokenizer
|
| 6 |
+
import torch
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
if __name__ == "__main__":
|
| 9 |
+
model_id ="openbmb/MiniCPM-V-2_6"
|
| 10 |
+
model = AutoModel.from_pretrained(model_id, trust_remote_code=True,
|
| 11 |
+
attn_implementation='sdpa', torch_dtype=torch.bfloat16) # sdpa or flash_attention_2, no eager
|
| 12 |
+
model = model.eval().cuda()
|
| 13 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
| 14 |
+
|
| 15 |
+
# evaluation
|
| 16 |
+
benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
|
| 17 |
+
benchmark = load_dataset(benchmark_repo, split="test")
|
| 18 |
+
gen_root = "outputs/CreatiDesign_benchmark/images"
|
| 19 |
+
print("processing:",gen_root)
|
| 20 |
+
save_json_path = gen_root.replace("images", "minicpm-vqa.json")
|
| 21 |
+
temp_root = gen_root.replace("images", "images-perarea")
|
| 22 |
+
os.makedirs(temp_root, exist_ok=True)
|
| 23 |
+
|
| 24 |
+
skipped_files_log = gen_root.replace("images", "skipped_files.log")
|
| 25 |
+
skipped_files = []
|
| 26 |
+
image_stats = {}
|
| 27 |
+
|
| 28 |
+
for case in tqdm(benchmark):
|
| 29 |
+
json_data = json.loads(case["metadata"])
|
| 30 |
+
case_info = json_data["img_info"]
|
| 31 |
+
case_id = case_info["img_id"]
|
| 32 |
+
file_name = f"{case_id}.jpg"
|
| 33 |
+
generated_img_path = os.path.join(gen_root, file_name)
|
| 34 |
+
global_caption = json_data["global_caption"]
|
| 35 |
+
object_annotations = json_data["object_annotations"]
|
| 36 |
+
detial_region_caption_list = [item["bbox_detail_description"] for item in object_annotations]
|
| 37 |
+
region_caption_list = [item["class_name"] for item in object_annotations]
|
| 38 |
+
region_bboxes_list = [item["bbox"] for item in object_annotations]
|
| 39 |
+
|
| 40 |
+
img = Image.open(generated_img_path).convert("RGB")
|
| 41 |
+
width, height = img.size
|
| 42 |
+
|
| 43 |
+
orignal_img_width = json_data["img_info"]["img_width"]
|
| 44 |
+
orignal_img_height = json_data["img_info"]["img_height"]
|
| 45 |
+
|
| 46 |
+
temp_save_root = os.path.join(temp_root, file_name.split('.')[0])
|
| 47 |
+
os.makedirs(temp_save_root, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
bbox_count = len(region_caption_list)
|
| 50 |
+
|
| 51 |
+
# Initialize scores
|
| 52 |
+
img_score_spatial = 0
|
| 53 |
+
img_score_color = 0
|
| 54 |
+
img_score_texture = 0
|
| 55 |
+
img_score_shape = 0
|
| 56 |
+
for i, (bbox,detial_region_caption,region_caption) in enumerate(zip(region_bboxes_list,detial_region_caption_list,region_caption_list)):
|
| 57 |
+
x1, y1, x2, y2= bbox
|
| 58 |
+
x1 = int(x1 / orignal_img_width*width)
|
| 59 |
+
y1 = int(y1 / orignal_img_height*height)
|
| 60 |
+
x2 = int(x2 / orignal_img_width*width)
|
| 61 |
+
y2 = int(y2 / orignal_img_height*height)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
cropped_img = img.crop((x1, y1, x2, y2))
|
| 65 |
+
|
| 66 |
+
# save crop img
|
| 67 |
+
description = region_caption.replace('/', '')
|
| 68 |
+
detail_description = detial_region_caption.replace('/', '')
|
| 69 |
+
cropped_img_path = os.path.join(temp_save_root, f'{description}.jpg')
|
| 70 |
+
cropped_img.save(cropped_img_path)
|
| 71 |
+
|
| 72 |
+
# spatial
|
| 73 |
+
question = f'Is the subject "{description}" present in the image? Strictly answer with "Yes" or "No", without any irrelevant words.'
|
| 74 |
+
|
| 75 |
+
msgs = [{'role': 'user', 'content': [cropped_img, question]}]
|
| 76 |
+
|
| 77 |
+
res = model.chat(
|
| 78 |
+
image=None,
|
| 79 |
+
msgs=msgs,
|
| 80 |
+
tokenizer=tokenizer,
|
| 81 |
+
seed=42
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
if "Yes" in res or "yes" in res:
|
| 85 |
+
score_spatial = 1.0
|
| 86 |
+
else:
|
| 87 |
+
score_spatial = 0.0
|
| 88 |
+
|
| 89 |
+
score_color, score_texture,score_shape = 0.0, 0.0, 0.0
|
| 90 |
+
# attribute
|
| 91 |
+
if score_spatial==1.0:
|
| 92 |
+
#color
|
| 93 |
+
question_color = f'Is the subject in "{description}" in the image consistent with the color described in the detailed description: "{detail_description}"? Strictly answer with "Yes" or "No", without any irrelevant words. If the color is not mentioned in the detailed description, the answer is "Yes".'
|
| 94 |
+
msgs_color = [{'role': 'user', 'content': [cropped_img, question_color]}]
|
| 95 |
+
|
| 96 |
+
color_attribute = model.chat(
|
| 97 |
+
image=None,
|
| 98 |
+
msgs=msgs_color,
|
| 99 |
+
tokenizer=tokenizer,
|
| 100 |
+
seed=42
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
if "Yes" in color_attribute or "yes" in color_attribute:
|
| 104 |
+
score_color = 1.0
|
| 105 |
+
# texture
|
| 106 |
+
if score_spatial==1.0:
|
| 107 |
+
question_texture = f'Is the subject in "{description}" in the image consistent with the texture described in the detailed description: "{detail_description}"? Strictly answer with "Yes" or "No", without any irrelevant words. If the texture is not mentioned in the detailed description, the answer is "Yes".'
|
| 108 |
+
msgs_texture = [{'role': 'user', 'content': [cropped_img, question_texture]}]
|
| 109 |
+
|
| 110 |
+
texture_attribute = model.chat(
|
| 111 |
+
image=None,
|
| 112 |
+
msgs=msgs_texture,
|
| 113 |
+
tokenizer=tokenizer,
|
| 114 |
+
seed=42
|
| 115 |
+
)
|
| 116 |
+
if "Yes" in texture_attribute or "yes" in texture_attribute:
|
| 117 |
+
score_texture = 1.0
|
| 118 |
+
#shape
|
| 119 |
+
if score_spatial==1.0:
|
| 120 |
+
question_shape = f'Is the subject in "{description}" in the image consistent with the shape described in the detailed description: "{detail_description}"? Strictly answer with "Yes" or "No", without any irrelevant words. If the shape is not mentioned in the detailed description, the answer is "Yes".'
|
| 121 |
+
msgs_shape = [{'role': 'user', 'content': [cropped_img, question_shape]}]
|
| 122 |
+
|
| 123 |
+
shape_attribute = model.chat(
|
| 124 |
+
image=None,
|
| 125 |
+
msgs=msgs_shape,
|
| 126 |
+
tokenizer=tokenizer,
|
| 127 |
+
seed=42
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
if "Yes" in shape_attribute or "yes" in shape_attribute:
|
| 131 |
+
score_shape = 1.0
|
| 132 |
+
|
| 133 |
+
# Update total scores
|
| 134 |
+
img_score_spatial += score_spatial
|
| 135 |
+
img_score_color += score_color
|
| 136 |
+
img_score_texture += score_texture
|
| 137 |
+
img_score_shape += score_shape
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# Store image stats
|
| 141 |
+
image_stats[os.path.basename(file_name)] = {
|
| 142 |
+
"bbox_count": bbox_count,
|
| 143 |
+
"score_spatial": img_score_spatial,
|
| 144 |
+
"score_color": img_score_color,
|
| 145 |
+
"score_texture": img_score_texture,
|
| 146 |
+
"score_shape": img_score_shape,
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
if len(image_stats) % 50 == 0:
|
| 150 |
+
with open(save_json_path, 'w', encoding='utf-8') as json_file:
|
| 151 |
+
json.dump(image_stats, json_file, indent=4)
|
| 152 |
+
|
| 153 |
+
# Save the image_stats dictionary to a JSON file
|
| 154 |
+
with open(save_json_path, 'w', encoding='utf-8') as json_file:
|
| 155 |
+
json.dump(image_stats, json_file, indent=4)
|
| 156 |
+
|
| 157 |
+
print(f"Image statistics saved to {save_json_path}")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
score_save_path = save_json_path.replace('minicpm-vqa.json', 'minicpm-vqa-score.txt')
|
| 161 |
+
|
| 162 |
+
# Read the JSON file containing image statistics
|
| 163 |
+
with open(save_json_path, "r") as f:
|
| 164 |
+
json_data = json.load(f)
|
| 165 |
+
|
| 166 |
+
total_num = 0
|
| 167 |
+
total_bbox_num = 0
|
| 168 |
+
total_score_spatial = 0
|
| 169 |
+
total_score_color = 0
|
| 170 |
+
total_score_texture = 0
|
| 171 |
+
total_score_shape = 0
|
| 172 |
+
|
| 173 |
+
miss_match =0
|
| 174 |
+
# Iterate over the JSON data
|
| 175 |
+
for key, value in json_data.items():
|
| 176 |
+
|
| 177 |
+
total_num += value["bbox_count"]
|
| 178 |
+
total_score_spatial +=value["score_spatial"]
|
| 179 |
+
total_score_color +=value["score_color"]
|
| 180 |
+
total_score_texture +=value["score_texture"]
|
| 181 |
+
total_score_shape +=value["score_shape"]
|
| 182 |
+
|
| 183 |
+
if value["bbox_count"]!=value["score_spatial"] or value["bbox_count"]!=value["score_color"] or value["bbox_count"]!=value["score_texture"] or value["bbox_count"]!=value["score_shape"]:
|
| 184 |
+
print(key,value["bbox_count"],value["score_spatial"],value["score_color"],value["score_texture"],value["score_shape"])
|
| 185 |
+
miss_match+=1
|
| 186 |
+
|
| 187 |
+
print(miss_match)
|
| 188 |
+
#save total_score_spatial,total_score_color,total_score_texture,total_score_shape
|
| 189 |
+
with open(score_save_path, "w") as f:
|
| 190 |
+
f.write(f"Total number of bbox: {total_num}\n")
|
| 191 |
+
f.write(f"Total score of spatial: {total_score_spatial}; Average score of spatial: {round(total_score_spatial/total_num,4)}\n")
|
| 192 |
+
f.write(f"Total score of color: {total_score_color}; Average score of color: {round(total_score_color/total_num,4)}\n")
|
| 193 |
+
f.write(f"Total score of texture: {total_score_texture}; Average score of texture: {round(total_score_texture/total_num,4)}\n")
|
| 194 |
+
f.write(f"Total score of shape: {total_score_shape}; Average score of shape: {round(total_score_shape/total_num,4)}\n")
|
eval/subject.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, json, math, argparse, glob
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import (
|
| 9 |
+
AutoProcessor, CLIPModel,
|
| 10 |
+
AutoImageProcessor, AutoModel
|
| 11 |
+
)
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
|
| 14 |
+
def scale_bbox(bbox, ori_size, target_size):
|
| 15 |
+
x_min, y_min, x_max, y_max = bbox
|
| 16 |
+
ori_width, ori_height = ori_size
|
| 17 |
+
target_width, target_height = target_size
|
| 18 |
+
|
| 19 |
+
width_ratio = target_width / ori_width
|
| 20 |
+
height_ratio = target_height / ori_height
|
| 21 |
+
|
| 22 |
+
scaled_x_min = int(x_min * width_ratio)
|
| 23 |
+
scaled_y_min = int(y_min * height_ratio)
|
| 24 |
+
scaled_x_max = int(x_max * width_ratio)
|
| 25 |
+
scaled_y_max = int(y_max * height_ratio)
|
| 26 |
+
|
| 27 |
+
scaled_x_min = max(0, scaled_x_min)
|
| 28 |
+
scaled_y_min = max(0, scaled_y_min)
|
| 29 |
+
scaled_x_max = min(target_width, scaled_x_max)
|
| 30 |
+
scaled_y_max = min(target_height, scaled_y_max)
|
| 31 |
+
|
| 32 |
+
return [scaled_x_min, scaled_y_min, scaled_x_max, scaled_y_max]
|
| 33 |
+
|
| 34 |
+
@torch.no_grad()
|
| 35 |
+
def encode_clip(imgs: List[Image.Image]) -> torch.Tensor:
|
| 36 |
+
features_list = []
|
| 37 |
+
for img in imgs:
|
| 38 |
+
inputs = clip_processor(images=img, return_tensors="pt").to(device)
|
| 39 |
+
image_features = clip_model.get_image_features(**inputs)
|
| 40 |
+
|
| 41 |
+
normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
|
| 42 |
+
features_list.append(normalized_features.squeeze().cpu())
|
| 43 |
+
return torch.stack(features_list)
|
| 44 |
+
|
| 45 |
+
@torch.no_grad()
|
| 46 |
+
def encode_dino(imgs: List[Image.Image]) -> torch.Tensor:
|
| 47 |
+
features_list = []
|
| 48 |
+
for img in imgs:
|
| 49 |
+
inputs = dino_processor(images=img, return_tensors="pt").to(device)
|
| 50 |
+
outputs = dino_model(**inputs)
|
| 51 |
+
image_features = outputs.last_hidden_state.mean(dim=1)
|
| 52 |
+
normalized_features = image_features / image_features.norm(dim=1, keepdim=True)
|
| 53 |
+
features_list.append(normalized_features.squeeze().cpu())
|
| 54 |
+
return torch.stack(features_list)
|
| 55 |
+
|
| 56 |
+
@torch.no_grad()
|
| 57 |
+
def cosine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
return (a @ b.T).squeeze()
|
| 59 |
+
|
| 60 |
+
# ------------- Command line arguments -----------------
|
| 61 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
| 62 |
+
parser.add_argument("--benchmark_repo", type=str, default="HuiZhang0812/CreatiDesign_benchmark",
|
| 63 |
+
help="Root directory for one thousand cases")
|
| 64 |
+
parser.add_argument("--gen_root", type=str, default="outputs/CreatiDesign_benchmark",
|
| 65 |
+
help="Root directory for generated images (should have images/<case_id>.jpg underneath)")
|
| 66 |
+
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
|
| 67 |
+
parser.add_argument("--outfile", type=str,
|
| 68 |
+
help="Path for result CSV; by default written to gen_root")
|
| 69 |
+
args = parser.parse_args()
|
| 70 |
+
|
| 71 |
+
print("handling:", args.gen_root)
|
| 72 |
+
if args.outfile is None:
|
| 73 |
+
args.outfile = os.path.join(args.gen_root,"scores.csv")
|
| 74 |
+
|
| 75 |
+
# Convert outfile to Path object
|
| 76 |
+
outfile_path = Path(args.outfile)
|
| 77 |
+
|
| 78 |
+
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
| 79 |
+
print(f"[INFO] Using device: {device}")
|
| 80 |
+
|
| 81 |
+
# ------------- Loading models -------------------
|
| 82 |
+
print("[INFO] loading CLIP...")
|
| 83 |
+
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
| 84 |
+
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
|
| 85 |
+
clip_model.eval()
|
| 86 |
+
|
| 87 |
+
print("[INFO] loading DINOv2...")
|
| 88 |
+
dino_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
|
| 89 |
+
dino_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)
|
| 90 |
+
dino_model.eval()
|
| 91 |
+
|
| 92 |
+
benchmark = load_dataset(args.benchmark_repo, split="test")
|
| 93 |
+
|
| 94 |
+
DEBUG = True
|
| 95 |
+
if DEBUG:
|
| 96 |
+
subject_save_roor = os.path.join(args.gen_root,"subject-eval-visual")
|
| 97 |
+
os.makedirs(subject_save_roor,exist_ok=True)
|
| 98 |
+
records = []
|
| 99 |
+
for case in tqdm(benchmark):
|
| 100 |
+
json_data = json.loads(case["metadata"])
|
| 101 |
+
case_info = json_data["img_info"]
|
| 102 |
+
case_id = case_info["img_id"]
|
| 103 |
+
|
| 104 |
+
# ---------- Read reference subjects ----------
|
| 105 |
+
ref_imgs = case['condition_white_variants']
|
| 106 |
+
if len(ref_imgs) == 0:
|
| 107 |
+
print(f"[WARN] {case_id} has no reference subject, skipping")
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
# ---------- Read generated image ----------
|
| 111 |
+
gen_path = os.path.join(args.gen_root, "images", f"{case_id}.jpg")
|
| 112 |
+
gen_img = Image.open(gen_path).convert("RGB")
|
| 113 |
+
# Get width and height of generated image
|
| 114 |
+
gen_width, gen_height = gen_img.size
|
| 115 |
+
reg_bbox_id = [item["bbox_idx"] for item in sorted(json_data["subject_annotations"], key=lambda x: x["bbox_idx"])]
|
| 116 |
+
ref_bbox = [item["bbox"] for item in sorted(json_data["subject_annotations"], key=lambda x: x["bbox_idx"])]
|
| 117 |
+
ori_width,ori_height = json_data["img_info"]["img_width"],json_data["img_info"]["img_height"]
|
| 118 |
+
# Extract corresponding images from the generated image
|
| 119 |
+
gen_imgs = []
|
| 120 |
+
for bbox in ref_bbox:
|
| 121 |
+
# Scale the bounding box
|
| 122 |
+
scaled_bbox = scale_bbox(
|
| 123 |
+
bbox,
|
| 124 |
+
(ori_width, ori_height),
|
| 125 |
+
(gen_width, gen_height)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Crop the image area
|
| 129 |
+
x_min, y_min, x_max, y_max = scaled_bbox
|
| 130 |
+
cropped_img = gen_img.crop((x_min, y_min, x_max, y_max))
|
| 131 |
+
gen_imgs.append(cropped_img)
|
| 132 |
+
if DEBUG:
|
| 133 |
+
folder_root = os.path.join(subject_save_roor,case_id)
|
| 134 |
+
os.makedirs(folder_root,exist_ok=True)
|
| 135 |
+
# Save cropped images
|
| 136 |
+
for i, (img, img_id) in enumerate(zip(gen_imgs, reg_bbox_id)):
|
| 137 |
+
img.save(os.path.join(folder_root, f"{img_id}.png"))
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# ---------- Features ----------
|
| 141 |
+
ref_clip = encode_clip(ref_imgs) # (n,dim)
|
| 142 |
+
gen_clip = encode_clip(gen_imgs) # (n,dim)
|
| 143 |
+
|
| 144 |
+
ref_dino = encode_dino(ref_imgs) # (n,dim)
|
| 145 |
+
gen_dino = encode_dino(gen_imgs) # (n,dim)
|
| 146 |
+
|
| 147 |
+
# ---------- Similarity ----------
|
| 148 |
+
clip_sims = torch.nn.functional.cosine_similarity(ref_clip, gen_clip)
|
| 149 |
+
dino_sims = torch.nn.functional.cosine_similarity(ref_dino, gen_dino)
|
| 150 |
+
|
| 151 |
+
clip_i = clip_sims.mean().item()
|
| 152 |
+
dino_avg = dino_sims.mean().item()
|
| 153 |
+
m_dino = dino_sims.prod().item()
|
| 154 |
+
|
| 155 |
+
records.append(dict(
|
| 156 |
+
case_id=case_id,
|
| 157 |
+
num_subject=len(ref_imgs),
|
| 158 |
+
clip_i=clip_i,
|
| 159 |
+
dino=dino_avg,
|
| 160 |
+
m_dino=m_dino
|
| 161 |
+
))
|
| 162 |
+
|
| 163 |
+
# ---------------- Result statistics -----------------
|
| 164 |
+
df = pd.DataFrame(records).sort_values("case_id")
|
| 165 |
+
overall = df[["clip_i","dino","m_dino"]].mean().to_dict()
|
| 166 |
+
|
| 167 |
+
print("\n========== Overall Average ==========")
|
| 168 |
+
for k,v in overall.items():
|
| 169 |
+
print(f"{k:>8}: {v:.6f}")
|
| 170 |
+
print("=====================================\n")
|
| 171 |
+
|
| 172 |
+
# Group by number of subjects
|
| 173 |
+
df_by_subjects = {}
|
| 174 |
+
avg_by_subjects = {}
|
| 175 |
+
|
| 176 |
+
# Create subset for each subject count (1-5)
|
| 177 |
+
for i in range(1, 6):
|
| 178 |
+
# Filter records with subject count = i
|
| 179 |
+
subset = df[df["num_subject"] == i]
|
| 180 |
+
|
| 181 |
+
if len(subset) > 0:
|
| 182 |
+
# Calculate average for this group
|
| 183 |
+
subset_avg = subset[["clip_i", "dino", "m_dino"]].mean().to_dict()
|
| 184 |
+
avg_by_subjects[i] = subset_avg
|
| 185 |
+
|
| 186 |
+
# Create subset with average row
|
| 187 |
+
avg_row = {"case_id": f"average_subject_{i}", "num_subject": i}
|
| 188 |
+
avg_row.update(subset_avg)
|
| 189 |
+
|
| 190 |
+
# Add average row to subset
|
| 191 |
+
subset_with_avg = pd.concat([subset, pd.DataFrame([avg_row])], ignore_index=True)
|
| 192 |
+
df_by_subjects[i] = subset_with_avg
|
| 193 |
+
|
| 194 |
+
# Print average for this group
|
| 195 |
+
print(f"\n=== Subject {i} Average (n={len(subset)}) ===")
|
| 196 |
+
for k, v in subset_avg.items():
|
| 197 |
+
print(f"{k:>8}: {v:.6f}")
|
| 198 |
+
|
| 199 |
+
# Save subset - fixed path handling
|
| 200 |
+
subject_path = outfile_path.parent / f"{outfile_path.stem}_subject{i}_location_prior{outfile_path.suffix}"
|
| 201 |
+
subset_with_avg.to_csv(subject_path, index=False, float_format="%.6f")
|
| 202 |
+
print(f"[INFO] Subject {i} results written to {subject_path}")
|
| 203 |
+
|
| 204 |
+
# Save overall average to CSV - fixed path handling
|
| 205 |
+
overall_df = pd.DataFrame([overall], index=["overall"])
|
| 206 |
+
overall_path = outfile_path.parent / f"{outfile_path.stem}_overall_location_prior{outfile_path.suffix}"
|
| 207 |
+
overall_df.to_csv(overall_path, float_format="%.6f")
|
| 208 |
+
print(f"[INFO] Overall results written to {overall_path}")
|
| 209 |
+
|
| 210 |
+
# Write CSV
|
| 211 |
+
df.to_csv(args.outfile, index=False, float_format="%.6f")
|
| 212 |
+
print(f"[INFO] Written to {args.outfile}")
|
| 213 |
+
|
| 214 |
+
# Create statistics table with averages for all groups
|
| 215 |
+
if avg_by_subjects:
|
| 216 |
+
# Merge averages for each group into one table
|
| 217 |
+
stats_rows = []
|
| 218 |
+
for num_subject, avg_dict in avg_by_subjects.items():
|
| 219 |
+
row = {"num_subject": num_subject}
|
| 220 |
+
row.update(avg_dict)
|
| 221 |
+
stats_rows.append(row)
|
| 222 |
+
|
| 223 |
+
# Add overall average
|
| 224 |
+
overall_row = {"num_subject": "all"}
|
| 225 |
+
overall_row.update(overall)
|
| 226 |
+
stats_rows.append(overall_row)
|
| 227 |
+
|
| 228 |
+
# Create summary statistics table
|
| 229 |
+
stats_df = pd.DataFrame(stats_rows)
|
| 230 |
+
# Fixed path handling
|
| 231 |
+
stats_path = outfile_path.parent / f"{outfile_path.stem}_stats_location_prior{outfile_path.suffix}"
|
| 232 |
+
stats_df.to_csv(stats_path, index=False, float_format="%.6f")
|
| 233 |
+
print(f"[INFO] All group statistics written to {stats_path}")
|
eval/text.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, json, csv, re, cv2, numpy as np, torch
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from editdistance import eval as edit_distance
|
| 4 |
+
from paddleocr import PaddleOCR
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
# -------------------------------------------------------------------
|
| 7 |
+
# Paths
|
| 8 |
+
benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
|
| 9 |
+
benchmark = load_dataset(benchmark_repo, split="test")
|
| 10 |
+
root_gen = "outputs/CreatiDesign_benchmark/images"
|
| 11 |
+
|
| 12 |
+
save_root = root_gen.replace("images", "text_eval") # Output directory
|
| 13 |
+
os.makedirs(save_root, exist_ok=True)
|
| 14 |
+
DEBUG = True
|
| 15 |
+
# -------------------------------------------------------------------
|
| 16 |
+
# 1. OCR initialization (must be det=True)
|
| 17 |
+
ocr = PaddleOCR(det=True, rec=True, cls=False, use_angle_cls=False, lang='en')
|
| 18 |
+
|
| 19 |
+
# -------------------------------------------------------------------
|
| 20 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
+
|
| 22 |
+
# -------------------------------------------------------------------
|
| 23 |
+
# 3. Utility functions
|
| 24 |
+
|
| 25 |
+
def spatial_match_iou(det_res, gt_box, gt_text_fmt, iou_thr=0.5):
|
| 26 |
+
best_iou = 0.0
|
| 27 |
+
if det_res is None or len(det_res) == 0:
|
| 28 |
+
return best_iou
|
| 29 |
+
|
| 30 |
+
for item in det_res:
|
| 31 |
+
poly = item[0] # Detection box coordinates
|
| 32 |
+
txt_info = item[1] # Text information tuple
|
| 33 |
+
txt = txt_info[0] # Text content
|
| 34 |
+
|
| 35 |
+
if min_ned_substring(normalize_text(txt), gt_text_fmt) <= 0.7: # When calculating spatial, allow some degree of text error
|
| 36 |
+
iou_val = iou(quad2bbox(poly), gt_box)
|
| 37 |
+
best_iou = max(best_iou, iou_val)
|
| 38 |
+
return best_iou
|
| 39 |
+
|
| 40 |
+
# ① New tool: Minimum NED substring
|
| 41 |
+
def min_ned_substring(pred_fmt: str, tgt_fmt: str) -> float:
|
| 42 |
+
"""
|
| 43 |
+
Find a substring in pred_fmt with the same length as tgt_fmt, to minimize normalized edit distance
|
| 44 |
+
Return the minimum value (0 ~ 1)
|
| 45 |
+
"""
|
| 46 |
+
Lp, Lg = len(pred_fmt), len(tgt_fmt)
|
| 47 |
+
if Lg == 0:
|
| 48 |
+
return 0.0
|
| 49 |
+
if Lp < Lg: # If prediction string is shorter than target, calculate directly
|
| 50 |
+
return normalized_edit_distance(pred_fmt, tgt_fmt)
|
| 51 |
+
|
| 52 |
+
best = Lg # Maximum possible distance
|
| 53 |
+
for i in range(Lp - Lg + 1):
|
| 54 |
+
sub = pred_fmt[i:i+Lg]
|
| 55 |
+
d = edit_distance(sub, tgt_fmt)
|
| 56 |
+
if d < best:
|
| 57 |
+
best = d
|
| 58 |
+
if best == 0: # Early exit
|
| 59 |
+
break
|
| 60 |
+
return best / Lg # Normalize
|
| 61 |
+
|
| 62 |
+
def normalize_text(txt: str) -> str:
|
| 63 |
+
txt = txt.lower().replace(" ", "")
|
| 64 |
+
return re.sub(r"[^\w\s]", "", txt)
|
| 65 |
+
|
| 66 |
+
def normalized_edit_distance(pred: str, gt: str) -> float:
|
| 67 |
+
if not gt and not pred:
|
| 68 |
+
return 0.0
|
| 69 |
+
return edit_distance(pred, gt) / max(len(gt), len(pred))
|
| 70 |
+
|
| 71 |
+
def iou(boxA, boxB) -> float:
|
| 72 |
+
xA, yA = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
|
| 73 |
+
xB, yB = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
|
| 74 |
+
inter = max(0, xB - xA) * max(0, yB - yA)
|
| 75 |
+
if inter == 0:
|
| 76 |
+
return 0.0
|
| 77 |
+
areaA = (boxA[2]-boxA[0]) * (boxA[3]-boxA[1])
|
| 78 |
+
areaB = (boxB[2]-boxB[0]) * (boxB[3]-boxB[1])
|
| 79 |
+
return inter / (areaA + areaB - inter)
|
| 80 |
+
|
| 81 |
+
def quad2bbox(quad):
|
| 82 |
+
xs = [p[0] for p in quad]; ys = [p[1] for p in quad]
|
| 83 |
+
return [min(xs), min(ys), max(xs), max(ys)]
|
| 84 |
+
|
| 85 |
+
def crop(img, box):
|
| 86 |
+
h, w = img.shape[:2]
|
| 87 |
+
x1,y1,x2,y2 = map(int, box)
|
| 88 |
+
x1, y1 = max(0, x1), max(0, y1)
|
| 89 |
+
x2, y2 = min(w-1, x2), min(h-1, y2)
|
| 90 |
+
if x2 <= x1 or y2 <= y1:
|
| 91 |
+
return np.zeros((1,1,3), np.uint8)
|
| 92 |
+
return img[y1:y2, x1:x2]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# -------------------------------------------------------------------
|
| 96 |
+
# 4. Main loop
|
| 97 |
+
per_img_rows, all_sen_acc, all_ned, all_spatial, text_pairs = [], [], [], [], []
|
| 98 |
+
|
| 99 |
+
for case in tqdm(benchmark):
|
| 100 |
+
json_data = json.loads(case["metadata"])
|
| 101 |
+
case_info = json_data["img_info"]
|
| 102 |
+
case_id = case_info["img_id"]
|
| 103 |
+
|
| 104 |
+
gt_list = json_data["text_list"] # [{'text':..., 'bbox':[x1,y1,x2,y2]}, ...]
|
| 105 |
+
ori_w, ori_h = json_data["img_info"]["img_width"], json_data["img_info"]["img_height"]
|
| 106 |
+
|
| 107 |
+
img_path = os.path.join(root_gen, f"{case_id}.jpg")
|
| 108 |
+
|
| 109 |
+
img = cv2.imread(img_path)
|
| 110 |
+
H, W = img.shape[:2]
|
| 111 |
+
wr, hr = W / ori_w, H / ori_h # GT → Generated image scaling ratio
|
| 112 |
+
|
| 113 |
+
# ---------- 1) Full image OCR ----------
|
| 114 |
+
pred_lines = [] # Save OCR line text
|
| 115 |
+
ocr_res = ocr.ocr(img, cls=False)
|
| 116 |
+
if ocr_res and ocr_res[0]:
|
| 117 |
+
for quad, (txt, conf) in ocr_res[0]:
|
| 118 |
+
pred_lines.append(txt.strip())
|
| 119 |
+
|
| 120 |
+
# Concatenate into full text and normalize
|
| 121 |
+
pred_full_fmt = normalize_text(" ".join(pred_lines))
|
| 122 |
+
|
| 123 |
+
# ==========================================================
|
| 124 |
+
# ③ For each GT sentence, do "substring minimum NED" ---- no longer using IoU
|
| 125 |
+
img_sen_hits, img_neds, img_spatials = [], [], []
|
| 126 |
+
|
| 127 |
+
for t_idx, gt in enumerate(gt_list):
|
| 128 |
+
gt_text_orig = gt["text"].replace("\n", " ").strip()
|
| 129 |
+
gt_text_fmt = normalize_text(gt_text_orig)
|
| 130 |
+
|
| 131 |
+
# ---- Pure text matching ----
|
| 132 |
+
ned = min_ned_substring(pred_full_fmt, gt_text_fmt)
|
| 133 |
+
acc = 1.0 if ned == 0 else 0.0
|
| 134 |
+
img_sen_hits.append(acc)
|
| 135 |
+
img_neds.append(ned)
|
| 136 |
+
|
| 137 |
+
# ---------- Spatial consistency, using IOU ----------
|
| 138 |
+
gt_box = [v*wr if i%2==0 else v*hr for i,v in enumerate(gt["bbox"])]
|
| 139 |
+
det_res = ocr_res[0] if ocr_res else []
|
| 140 |
+
spatial_score = spatial_match_iou(det_res, gt_box, gt_text_fmt)
|
| 141 |
+
img_spatials.append(spatial_score) # Can be used directly or binarized
|
| 142 |
+
crop_box_int = list(map(int, gt_box))
|
| 143 |
+
img_crop = crop(img, crop_box_int)
|
| 144 |
+
if DEBUG:
|
| 145 |
+
# Save cropped image
|
| 146 |
+
img_crop_for_ocr_save_root = os.path.join(save_root, case_id)
|
| 147 |
+
os.makedirs(img_crop_for_ocr_save_root, exist_ok=True)
|
| 148 |
+
safe_text = gt_text_orig.replace('/', '_').replace('\\', '_')
|
| 149 |
+
safe_filename = f"{t_idx}_{safe_text}.jpg"
|
| 150 |
+
cv2.imwrite(os.path.join(img_crop_for_ocr_save_root, safe_filename), img_crop)
|
| 151 |
+
|
| 152 |
+
# --------- Record text pairs ----------
|
| 153 |
+
text_pairs.append({
|
| 154 |
+
"image_id" : case_id,
|
| 155 |
+
"text_id" : t_idx,
|
| 156 |
+
"gt_original" : gt_text_orig,
|
| 157 |
+
"gt_formatted" : gt_text_fmt
|
| 158 |
+
})
|
| 159 |
+
|
| 160 |
+
# ---------- 3) Summarize to image level ----------
|
| 161 |
+
sen_acc = float(np.mean(img_sen_hits))
|
| 162 |
+
ned = float(np.mean(img_neds))
|
| 163 |
+
spatial = float(np.mean(img_spatials))
|
| 164 |
+
|
| 165 |
+
per_img_rows.append([case_id, sen_acc, ned, spatial])
|
| 166 |
+
all_sen_acc.append(sen_acc)
|
| 167 |
+
all_ned.append(ned)
|
| 168 |
+
all_spatial.append(spatial)
|
| 169 |
+
|
| 170 |
+
# -------------------------------------------------------------------
|
| 171 |
+
# 5. Write results
|
| 172 |
+
result_root = root_gen.replace("images","")
|
| 173 |
+
csv_perimg = os.path.join(result_root, "text_results_per_image.csv")
|
| 174 |
+
with open(csv_perimg, "w", newline='', encoding="utf-8") as f:
|
| 175 |
+
w = csv.writer(f); w.writerow(["image_id","sen_acc","ned","score_spatial"]); w.writerows(per_img_rows)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
with open(os.path.join(result_root, "text_overall.txt"), "w", encoding="utf-8") as f:
|
| 179 |
+
f.write(f"Images evaluated : {len(per_img_rows)}\n")
|
| 180 |
+
f.write(f"Global Sen ACC : {np.mean(all_sen_acc):.4f}\n")
|
| 181 |
+
f.write(f"Global NED : {np.mean(all_ned):.4f}\n")
|
| 182 |
+
f.write(f"Global Spatial : {np.mean(all_spatial):.4f}\n")
|
| 183 |
+
|
| 184 |
+
print("✓ Done! Results saved to", result_root)
|
modules/common/__pycache__/lora.cpython-310.pyc
ADDED
|
Binary file (1.17 kB). View file
|
|
|
modules/common/lora.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class LoRALinearLayer(nn.Module):
|
| 5 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
| 6 |
+
super().__init__()
|
| 7 |
+
|
| 8 |
+
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
| 9 |
+
self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
|
| 10 |
+
self.network_alpha = network_alpha
|
| 11 |
+
self.rank = rank
|
| 12 |
+
|
| 13 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
| 14 |
+
nn.init.zeros_(self.up.weight)
|
| 15 |
+
|
| 16 |
+
def forward(self, hidden_states):
|
| 17 |
+
orig_dtype = hidden_states.dtype
|
| 18 |
+
dtype = self.down.weight.dtype
|
| 19 |
+
|
| 20 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
| 21 |
+
up_hidden_states = self.up(down_hidden_states)
|
| 22 |
+
|
| 23 |
+
if self.network_alpha is not None:
|
| 24 |
+
up_hidden_states *= self.network_alpha / self.rank
|
| 25 |
+
|
| 26 |
+
return up_hidden_states.to(orig_dtype)
|
modules/flux/__pycache__/attention_processor_flux_creatidesign.cpython-310.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e534db89ad40a8e61c4c32b8bbeb3084e7d01a83667a66f426dbdfdf93a13936
|
| 3 |
+
size 127465
|
modules/flux/__pycache__/transformer_flux_creatidesign.cpython-310.pyc
ADDED
|
Binary file (25.8 kB). View file
|
|
|
modules/flux/attention_processor_flux_creatidesign.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modules/flux/transformer_flux_creatidesign.py
ADDED
|
@@ -0,0 +1,1004 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
| 25 |
+
from diffusers.models.attention import FeedForward
|
| 26 |
+
from modules.flux.attention_processor_flux_creatidesign import (
|
| 27 |
+
Attention,
|
| 28 |
+
AttentionProcessor,
|
| 29 |
+
DesignFluxAttnProcessor2_0,
|
| 30 |
+
FluxAttnProcessor2_0_NPU,
|
| 31 |
+
FusedFluxAttnProcessor2_0,
|
| 32 |
+
)
|
| 33 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 34 |
+
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
| 35 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
| 36 |
+
from diffusers.utils.import_utils import is_torch_npu_available
|
| 37 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 38 |
+
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
| 39 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 40 |
+
from modules.semantic_layout.layout_encoder import ObjectLayoutEncoder,ObjectLayoutEncoder_noFourier
|
| 41 |
+
from modules.common.lora import LoRALinearLayer
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@maybe_allow_in_graph
|
| 50 |
+
class FluxSingleTransformerBlock(nn.Module):
|
| 51 |
+
r"""
|
| 52 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
| 53 |
+
|
| 54 |
+
Reference: https://arxiv.org/abs/2403.03206
|
| 55 |
+
|
| 56 |
+
Parameters:
|
| 57 |
+
dim (`int`): The number of channels in the input and output.
|
| 58 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 59 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 60 |
+
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
| 61 |
+
processing of `context` conditions.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0, rank=16,network_alpha=16,lora_weight=1.0,attention_type="design"):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 67 |
+
|
| 68 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
| 69 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
| 70 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
| 71 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if is_torch_npu_available():
|
| 75 |
+
processor = FluxAttnProcessor2_0_NPU()
|
| 76 |
+
else:
|
| 77 |
+
processor = DesignFluxAttnProcessor2_0()
|
| 78 |
+
self.attn = Attention(
|
| 79 |
+
query_dim=dim,
|
| 80 |
+
cross_attention_dim=None,
|
| 81 |
+
dim_head=attention_head_dim,
|
| 82 |
+
heads=num_attention_heads,
|
| 83 |
+
out_dim=dim,
|
| 84 |
+
bias=True,
|
| 85 |
+
processor=processor,
|
| 86 |
+
qk_norm="rms_norm",
|
| 87 |
+
eps=1e-6,
|
| 88 |
+
pre_only=True,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
self.attention_type = attention_type
|
| 92 |
+
self.rank = rank
|
| 93 |
+
self.network_alpha = network_alpha
|
| 94 |
+
self.lora_weight = lora_weight
|
| 95 |
+
if attention_type == "design":
|
| 96 |
+
self.layernorm_subject = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # layernorm for subject
|
| 97 |
+
self.norm_subject_lora = nn.Sequential(
|
| 98 |
+
nn.SiLU(),
|
| 99 |
+
LoRALinearLayer(dim, dim*3, self.rank, self.network_alpha) # lora for adalinear of subject
|
| 100 |
+
)
|
| 101 |
+
self.layernorm_object_bbox = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # layernorm for object
|
| 102 |
+
self.norm_object_lora = nn.Sequential(
|
| 103 |
+
nn.SiLU(),
|
| 104 |
+
LoRALinearLayer(dim, dim*3, self.rank, self.network_alpha) # lora for adalinear of object
|
| 105 |
+
)
|
| 106 |
+
def single_block_adaln_lora_forward(self, x, temb, adaln, adaln_lora, layernorm, lora_weight):
|
| 107 |
+
norm_x, x_gate = adaln(x, emb=temb)
|
| 108 |
+
lora_shift_msa, lora_scale_msa, lora_gate_msa = adaln_lora(temb).chunk(3, dim=1)
|
| 109 |
+
norm_x = norm_x + lora_weight * (layernorm(x)* (1 + lora_scale_msa[:, None]) + lora_shift_msa[:, None])
|
| 110 |
+
x_gate = x_gate + lora_weight * lora_gate_msa
|
| 111 |
+
return norm_x, x_gate
|
| 112 |
+
|
| 113 |
+
def forward(
|
| 114 |
+
self,
|
| 115 |
+
hidden_states: torch.Tensor,
|
| 116 |
+
temb: torch.Tensor,
|
| 117 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 118 |
+
subject_hidden_states = None,
|
| 119 |
+
subject_rotary_emb = None,
|
| 120 |
+
object_bbox_hidden_states = None,
|
| 121 |
+
object_rotary_emb = None,
|
| 122 |
+
design_scale = 1.0,
|
| 123 |
+
attention_mask=None,
|
| 124 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 125 |
+
) -> torch.Tensor:
|
| 126 |
+
residual = hidden_states
|
| 127 |
+
|
| 128 |
+
# handle hidden_states
|
| 129 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
| 130 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
| 131 |
+
#creatidesign
|
| 132 |
+
use_subject = True if self.attention_type == "design" and subject_hidden_states is not None and design_scale!=0.0 else False
|
| 133 |
+
use_object = True if self.attention_type == "design" and object_bbox_hidden_states is not None and design_scale!=0.0 else False
|
| 134 |
+
# handle subejct_hidden_states
|
| 135 |
+
if use_subject:
|
| 136 |
+
residual_subject_hidden_states = subject_hidden_states
|
| 137 |
+
norm_subject_hidden_states, subject_gate = self.single_block_adaln_lora_forward(subject_hidden_states, temb, self.norm, self.norm_subject_lora, self.layernorm_subject, self.lora_weight)
|
| 138 |
+
mlp_subject_hidden_states = self.act_mlp(self.proj_mlp(norm_subject_hidden_states))
|
| 139 |
+
if use_object:
|
| 140 |
+
residual_object_bbox_hidden_states = object_bbox_hidden_states
|
| 141 |
+
norm_object_bbox_hidden_states, object_gate = self.single_block_adaln_lora_forward(object_bbox_hidden_states, temb, self.norm, self.norm_object_lora, self.layernorm_object_bbox, self.lora_weight)
|
| 142 |
+
mlp_object_bbox_hidden_states = self.act_mlp(self.proj_mlp(norm_object_bbox_hidden_states))
|
| 143 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 144 |
+
attn_output, subject_attn_output, object_attn_output = self.attn(
|
| 145 |
+
hidden_states=norm_hidden_states,
|
| 146 |
+
image_rotary_emb=image_rotary_emb,
|
| 147 |
+
subject_hidden_states=norm_subject_hidden_states,
|
| 148 |
+
subject_rotary_emb=subject_rotary_emb,
|
| 149 |
+
object_bbox_hidden_states=norm_object_bbox_hidden_states,
|
| 150 |
+
object_rotary_emb=object_rotary_emb,
|
| 151 |
+
attention_mask = attention_mask,
|
| 152 |
+
**joint_attention_kwargs,
|
| 153 |
+
)
|
| 154 |
+
# handle hidden states
|
| 155 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
| 156 |
+
gate = gate.unsqueeze(1)
|
| 157 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
| 158 |
+
hidden_states = residual + hidden_states
|
| 159 |
+
#handle subject_hidden_states
|
| 160 |
+
if use_subject:
|
| 161 |
+
subject_hidden_states = torch.cat([subject_attn_output, mlp_subject_hidden_states], dim=2)
|
| 162 |
+
subject_gate = subject_gate.unsqueeze(1)
|
| 163 |
+
subject_hidden_states = subject_gate * self.proj_out(subject_hidden_states)
|
| 164 |
+
subject_hidden_states = residual_subject_hidden_states + subject_hidden_states
|
| 165 |
+
|
| 166 |
+
#handle object_bbox_hidden_states
|
| 167 |
+
if use_object:
|
| 168 |
+
object_bbox_hidden_states = torch.cat([object_attn_output, mlp_object_bbox_hidden_states], dim=2)
|
| 169 |
+
object_gate = object_gate.unsqueeze(1)
|
| 170 |
+
object_bbox_hidden_states = object_gate * self.proj_out(object_bbox_hidden_states)
|
| 171 |
+
object_bbox_hidden_states = residual_object_bbox_hidden_states + object_bbox_hidden_states
|
| 172 |
+
if hidden_states.dtype == torch.float16:
|
| 173 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 174 |
+
|
| 175 |
+
return hidden_states, subject_hidden_states, object_bbox_hidden_states
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@maybe_allow_in_graph
|
| 179 |
+
class FluxTransformerBlock(nn.Module):
|
| 180 |
+
r"""
|
| 181 |
+
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
| 182 |
+
|
| 183 |
+
Reference: https://arxiv.org/abs/2403.03206
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
dim (`int`):
|
| 187 |
+
The embedding dimension of the block.
|
| 188 |
+
num_attention_heads (`int`):
|
| 189 |
+
The number of attention heads to use.
|
| 190 |
+
attention_head_dim (`int`):
|
| 191 |
+
The number of dimensions to use for each attention head.
|
| 192 |
+
qk_norm (`str`, defaults to `"rms_norm"`):
|
| 193 |
+
The normalization to use for the query and key tensors.
|
| 194 |
+
eps (`float`, defaults to `1e-6`):
|
| 195 |
+
The epsilon value to use for the normalization.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
def __init__(
|
| 199 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6, rank=16, network_alpha=16, lora_weight=1.0,attention_type="design"
|
| 200 |
+
):
|
| 201 |
+
super().__init__()
|
| 202 |
+
|
| 203 |
+
self.norm1 = AdaLayerNormZero(dim)
|
| 204 |
+
|
| 205 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
| 206 |
+
|
| 207 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
| 208 |
+
processor = DesignFluxAttnProcessor2_0()
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError(
|
| 211 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
| 212 |
+
)
|
| 213 |
+
self.attn = Attention(
|
| 214 |
+
query_dim=dim,
|
| 215 |
+
cross_attention_dim=None,
|
| 216 |
+
added_kv_proj_dim=dim,
|
| 217 |
+
dim_head=attention_head_dim,
|
| 218 |
+
heads=num_attention_heads,
|
| 219 |
+
out_dim=dim,
|
| 220 |
+
context_pre_only=False,
|
| 221 |
+
bias=True,
|
| 222 |
+
processor=processor,
|
| 223 |
+
qk_norm=qk_norm,
|
| 224 |
+
eps=eps,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 228 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 229 |
+
|
| 230 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
| 231 |
+
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
| 232 |
+
|
| 233 |
+
# let chunk size default to None
|
| 234 |
+
self._chunk_size = None
|
| 235 |
+
self._chunk_dim = 0
|
| 236 |
+
|
| 237 |
+
# creatidesign
|
| 238 |
+
self.attention_type = attention_type
|
| 239 |
+
self.rank = rank
|
| 240 |
+
self.network_alpha = network_alpha
|
| 241 |
+
self.lora_weight = lora_weight
|
| 242 |
+
|
| 243 |
+
if self.attention_type == "design":
|
| 244 |
+
# lora for handle subject (img branch)
|
| 245 |
+
self.norm1_subject_lora = nn.Sequential(
|
| 246 |
+
nn.SiLU(),
|
| 247 |
+
LoRALinearLayer(dim, dim*6, self.rank, self.network_alpha) # lora for adalinear
|
| 248 |
+
)
|
| 249 |
+
self.layernorm_subject = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # norm for subject
|
| 250 |
+
|
| 251 |
+
# lora for handle object (txt branch)
|
| 252 |
+
self.norm1_object_lora = nn.Sequential(
|
| 253 |
+
nn.SiLU(),
|
| 254 |
+
LoRALinearLayer(dim, dim*6, self.rank, self.network_alpha) # lora for adalinear
|
| 255 |
+
)
|
| 256 |
+
self.layernorm_object = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) # norm for object
|
| 257 |
+
|
| 258 |
+
def double_block_adaln_lora_forward(self, x, temb, adaln, adaln_lora, layernorm, lora_weight):
|
| 259 |
+
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = adaln(x, emb=temb)
|
| 260 |
+
lora_shift_msa, lora_scale_msa, lora_gate_msa, lora_shift_mlp, lora_scale_mlp, lora_gate_mlp = adaln_lora(temb).chunk(6, dim=1)
|
| 261 |
+
norm_x = norm_x + lora_weight * (layernorm(x)* (1 + lora_scale_msa[:, None]) + lora_shift_msa[:, None])
|
| 262 |
+
x_gate_msa = x_gate_msa + lora_weight*lora_gate_msa
|
| 263 |
+
x_shift_mlp = x_shift_mlp + lora_weight*lora_shift_mlp
|
| 264 |
+
x_scale_mlp = x_scale_mlp + lora_weight*lora_scale_mlp
|
| 265 |
+
x_gate_mlp = x_gate_mlp + lora_weight*lora_gate_mlp
|
| 266 |
+
return norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp
|
| 267 |
+
def forward(
|
| 268 |
+
self,
|
| 269 |
+
hidden_states: torch.Tensor,
|
| 270 |
+
encoder_hidden_states: torch.Tensor,
|
| 271 |
+
temb: torch.Tensor,
|
| 272 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 273 |
+
subject_hidden_states = None,
|
| 274 |
+
subject_rotary_emb = None,
|
| 275 |
+
object_bbox_hidden_states = None,
|
| 276 |
+
object_rotary_emb = None,
|
| 277 |
+
design_scale = 1.0,
|
| 278 |
+
attention_mask=None,
|
| 279 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 280 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 281 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
| 282 |
+
|
| 283 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
| 284 |
+
encoder_hidden_states, emb=temb
|
| 285 |
+
)
|
| 286 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
use_subject = True if self.attention_type == "design" and subject_hidden_states is not None and design_scale!=0.0 else False
|
| 290 |
+
use_object = True if self.attention_type == "design" and object_bbox_hidden_states is not None and design_scale!=0.0 else False
|
| 291 |
+
if use_subject:
|
| 292 |
+
# subject adalinear
|
| 293 |
+
norm_subject_hidden_states, subject_gate_msa, subject_shift_mlp, subject_scale_mlp, subject_gate_mlp = self.double_block_adaln_lora_forward(
|
| 294 |
+
subject_hidden_states, temb, self.norm1, self.norm1_subject_lora, self.layernorm_subject, self.lora_weight
|
| 295 |
+
)
|
| 296 |
+
if use_object:
|
| 297 |
+
# object adalinear
|
| 298 |
+
norm_object_bbox_hidden_states, object_gate_msa, object_shift_mlp, object_scale_mlp, object_gate_mlp = self.double_block_adaln_lora_forward(
|
| 299 |
+
object_bbox_hidden_states, temb, self.norm1_context, self.norm1_object_lora, self.layernorm_object, self.lora_weight
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
attn_output, context_attn_output, subject_attn_output, object_attn_output = self.attn(
|
| 304 |
+
hidden_states=norm_hidden_states,
|
| 305 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 306 |
+
image_rotary_emb=image_rotary_emb,
|
| 307 |
+
subject_hidden_states=norm_subject_hidden_states if use_subject else None,
|
| 308 |
+
subject_rotary_emb=subject_rotary_emb if use_subject else None,
|
| 309 |
+
object_bbox_hidden_states=norm_object_bbox_hidden_states if use_object else None,
|
| 310 |
+
object_rotary_emb=object_rotary_emb if use_object else None,
|
| 311 |
+
attention_mask = attention_mask,
|
| 312 |
+
**joint_attention_kwargs,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# Process attention outputs for the `hidden_states`.
|
| 316 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
| 317 |
+
hidden_states = hidden_states + attn_output
|
| 318 |
+
|
| 319 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 320 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 321 |
+
|
| 322 |
+
ff_output = self.ff(norm_hidden_states)
|
| 323 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
| 324 |
+
|
| 325 |
+
hidden_states = hidden_states + ff_output
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# Process attention outputs for the `encoder_hidden_states`.
|
| 330 |
+
|
| 331 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
| 332 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 333 |
+
|
| 334 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 335 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
| 336 |
+
|
| 337 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 338 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# process attention outputs for the `subject_hidden_states`.
|
| 342 |
+
if use_subject:
|
| 343 |
+
subject_attn_output = subject_gate_msa.unsqueeze(1) * subject_attn_output
|
| 344 |
+
subject_hidden_states = subject_hidden_states + subject_attn_output
|
| 345 |
+
norm_subject_hidden_states = self.norm2(subject_hidden_states)
|
| 346 |
+
norm_subject_hidden_states = norm_subject_hidden_states * (1 + subject_scale_mlp[:, None]) + subject_shift_mlp[:, None]
|
| 347 |
+
subject_ff_output = self.ff(norm_subject_hidden_states)
|
| 348 |
+
subject_hidden_states = subject_hidden_states + subject_gate_mlp.unsqueeze(1) * subject_ff_output
|
| 349 |
+
|
| 350 |
+
# process attention outputs for the `object_bbox_hidden_states`.
|
| 351 |
+
if use_object:
|
| 352 |
+
object_attn_output = object_gate_msa.unsqueeze(1) * object_attn_output
|
| 353 |
+
object_bbox_hidden_states = object_bbox_hidden_states + object_attn_output
|
| 354 |
+
norm_object_bbox_hidden_states = self.norm2_context(object_bbox_hidden_states)
|
| 355 |
+
norm_object_bbox_hidden_states = norm_object_bbox_hidden_states * (1 + object_scale_mlp[:, None]) + object_shift_mlp[:, None]
|
| 356 |
+
object_ff_output = self.ff_context(norm_object_bbox_hidden_states)
|
| 357 |
+
object_bbox_hidden_states = object_bbox_hidden_states + object_gate_mlp.unsqueeze(1) * object_ff_output
|
| 358 |
+
|
| 359 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 360 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 361 |
+
|
| 362 |
+
return encoder_hidden_states, hidden_states, subject_hidden_states, object_bbox_hidden_states
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class FluxTransformer2DModel(
|
| 366 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
| 367 |
+
):
|
| 368 |
+
"""
|
| 369 |
+
The Transformer model introduced in Flux.
|
| 370 |
+
|
| 371 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
patch_size (`int`, defaults to `1`):
|
| 375 |
+
Patch size to turn the input data into small patches.
|
| 376 |
+
in_channels (`int`, defaults to `64`):
|
| 377 |
+
The number of channels in the input.
|
| 378 |
+
out_channels (`int`, *optional*, defaults to `None`):
|
| 379 |
+
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
| 380 |
+
num_layers (`int`, defaults to `19`):
|
| 381 |
+
The number of layers of dual stream DiT blocks to use.
|
| 382 |
+
num_single_layers (`int`, defaults to `38`):
|
| 383 |
+
The number of layers of single stream DiT blocks to use.
|
| 384 |
+
attention_head_dim (`int`, defaults to `128`):
|
| 385 |
+
The number of dimensions to use for each attention head.
|
| 386 |
+
num_attention_heads (`int`, defaults to `24`):
|
| 387 |
+
The number of attention heads to use.
|
| 388 |
+
joint_attention_dim (`int`, defaults to `4096`):
|
| 389 |
+
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
| 390 |
+
`encoder_hidden_states`).
|
| 391 |
+
pooled_projection_dim (`int`, defaults to `768`):
|
| 392 |
+
The number of dimensions to use for the pooled projection.
|
| 393 |
+
guidance_embeds (`bool`, defaults to `False`):
|
| 394 |
+
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
| 395 |
+
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
| 396 |
+
The dimensions to use for the rotary positional embeddings.
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
_supports_gradient_checkpointing = True
|
| 400 |
+
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
| 401 |
+
|
| 402 |
+
@register_to_config
|
| 403 |
+
def __init__(
|
| 404 |
+
self,
|
| 405 |
+
patch_size: int = 1,
|
| 406 |
+
in_channels: int = 64,
|
| 407 |
+
out_channels: Optional[int] = None,
|
| 408 |
+
num_layers: int = 19,
|
| 409 |
+
num_single_layers: int = 38,
|
| 410 |
+
attention_head_dim: int = 128,
|
| 411 |
+
num_attention_heads: int = 24,
|
| 412 |
+
joint_attention_dim: int = 4096,
|
| 413 |
+
pooled_projection_dim: int = 768,
|
| 414 |
+
guidance_embeds: bool = False,
|
| 415 |
+
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
| 416 |
+
attention_type="design",
|
| 417 |
+
max_boxes_token_length=30,
|
| 418 |
+
rank = 16,
|
| 419 |
+
network_alpha = 16,
|
| 420 |
+
lora_weight = 1.0,
|
| 421 |
+
use_attention_mask = True,
|
| 422 |
+
use_objects_masks_maps=True,
|
| 423 |
+
use_subject_masks_maps=True,
|
| 424 |
+
use_layout_encoder=True,
|
| 425 |
+
drop_subject_bg=False,
|
| 426 |
+
gradient_checkpointing=False,
|
| 427 |
+
use_fourier_bbox=True,
|
| 428 |
+
bbox_id_shift=True
|
| 429 |
+
):
|
| 430 |
+
super().__init__()
|
| 431 |
+
# #creatidesign
|
| 432 |
+
self.attention_type = attention_type
|
| 433 |
+
self.max_boxes_token_length = max_boxes_token_length
|
| 434 |
+
self.rank = rank
|
| 435 |
+
self.network_alpha = network_alpha
|
| 436 |
+
self.lora_weight = lora_weight
|
| 437 |
+
self.use_attention_mask = use_attention_mask
|
| 438 |
+
self.use_objects_masks_maps= use_objects_masks_maps
|
| 439 |
+
self.num_attention_heads=num_attention_heads
|
| 440 |
+
self.use_layout_encoder = use_layout_encoder
|
| 441 |
+
self.use_subject_masks_maps = use_subject_masks_maps
|
| 442 |
+
self.drop_subject_bg = drop_subject_bg
|
| 443 |
+
self.gradient_checkpointing = gradient_checkpointing
|
| 444 |
+
self.use_fourier_bbox = use_fourier_bbox
|
| 445 |
+
self.bbox_id_shift = bbox_id_shift
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
self.out_channels = out_channels or in_channels
|
| 449 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 450 |
+
|
| 451 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
| 452 |
+
|
| 453 |
+
text_time_guidance_cls = (
|
| 454 |
+
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
| 455 |
+
)
|
| 456 |
+
self.time_text_embed = text_time_guidance_cls(
|
| 457 |
+
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
| 461 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
| 462 |
+
|
| 463 |
+
self.transformer_blocks = nn.ModuleList(
|
| 464 |
+
[
|
| 465 |
+
FluxTransformerBlock(
|
| 466 |
+
dim=self.inner_dim,
|
| 467 |
+
num_attention_heads=num_attention_heads,
|
| 468 |
+
attention_head_dim=attention_head_dim,
|
| 469 |
+
attention_type=self.attention_type,
|
| 470 |
+
rank=self.rank,
|
| 471 |
+
network_alpha=self.network_alpha,
|
| 472 |
+
lora_weight=self.lora_weight,
|
| 473 |
+
)
|
| 474 |
+
for _ in range(num_layers)
|
| 475 |
+
]
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 479 |
+
[
|
| 480 |
+
FluxSingleTransformerBlock(
|
| 481 |
+
dim=self.inner_dim,
|
| 482 |
+
num_attention_heads=num_attention_heads,
|
| 483 |
+
attention_head_dim=attention_head_dim,
|
| 484 |
+
attention_type=self.attention_type,
|
| 485 |
+
rank=self.rank,
|
| 486 |
+
network_alpha=self.network_alpha,
|
| 487 |
+
lora_weight=self.lora_weight,
|
| 488 |
+
)
|
| 489 |
+
for _ in range(num_single_layers)
|
| 490 |
+
]
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
| 494 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
if self.attention_type =="design":
|
| 498 |
+
if self.use_layout_encoder:
|
| 499 |
+
if self.use_fourier_bbox:
|
| 500 |
+
self.object_layout_encoder = ObjectLayoutEncoder(
|
| 501 |
+
positive_len=self.inner_dim, out_dim=self.inner_dim, max_boxes_token_length=self.max_boxes_token_length
|
| 502 |
+
)
|
| 503 |
+
else:
|
| 504 |
+
self.object_layout_encoder = ObjectLayoutEncoder_noFourier(
|
| 505 |
+
in_dim=self.inner_dim, out_dim=self.inner_dim
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
@property
|
| 510 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 511 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 512 |
+
r"""
|
| 513 |
+
Returns:
|
| 514 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 515 |
+
indexed by its weight name.
|
| 516 |
+
"""
|
| 517 |
+
# set recursively
|
| 518 |
+
processors = {}
|
| 519 |
+
|
| 520 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 521 |
+
if hasattr(module, "get_processor"):
|
| 522 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 523 |
+
|
| 524 |
+
for sub_name, child in module.named_children():
|
| 525 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 526 |
+
|
| 527 |
+
return processors
|
| 528 |
+
|
| 529 |
+
for name, module in self.named_children():
|
| 530 |
+
fn_recursive_add_processors(name, module, processors)
|
| 531 |
+
|
| 532 |
+
return processors
|
| 533 |
+
|
| 534 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 535 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 536 |
+
r"""
|
| 537 |
+
Sets the attention processor to use to compute attention.
|
| 538 |
+
|
| 539 |
+
Parameters:
|
| 540 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 541 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 542 |
+
for **all** `Attention` layers.
|
| 543 |
+
|
| 544 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 545 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 546 |
+
|
| 547 |
+
"""
|
| 548 |
+
count = len(self.attn_processors.keys())
|
| 549 |
+
|
| 550 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 551 |
+
raise ValueError(
|
| 552 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 553 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 557 |
+
if hasattr(module, "set_processor"):
|
| 558 |
+
if not isinstance(processor, dict):
|
| 559 |
+
module.set_processor(processor)
|
| 560 |
+
else:
|
| 561 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 562 |
+
|
| 563 |
+
for sub_name, child in module.named_children():
|
| 564 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 565 |
+
|
| 566 |
+
for name, module in self.named_children():
|
| 567 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 568 |
+
|
| 569 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
| 570 |
+
def fuse_qkv_projections(self):
|
| 571 |
+
"""
|
| 572 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 573 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 574 |
+
|
| 575 |
+
<Tip warning={true}>
|
| 576 |
+
|
| 577 |
+
This API is 🧪 experimental.
|
| 578 |
+
|
| 579 |
+
</Tip>
|
| 580 |
+
"""
|
| 581 |
+
self.original_attn_processors = None
|
| 582 |
+
|
| 583 |
+
for _, attn_processor in self.attn_processors.items():
|
| 584 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 585 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 586 |
+
|
| 587 |
+
self.original_attn_processors = self.attn_processors
|
| 588 |
+
|
| 589 |
+
for module in self.modules():
|
| 590 |
+
if isinstance(module, Attention):
|
| 591 |
+
module.fuse_projections(fuse=True)
|
| 592 |
+
|
| 593 |
+
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
| 594 |
+
|
| 595 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
| 596 |
+
def unfuse_qkv_projections(self):
|
| 597 |
+
"""Disables the fused QKV projection if enabled.
|
| 598 |
+
|
| 599 |
+
<Tip warning={true}>
|
| 600 |
+
|
| 601 |
+
This API is 🧪 experimental.
|
| 602 |
+
|
| 603 |
+
</Tip>
|
| 604 |
+
|
| 605 |
+
"""
|
| 606 |
+
if self.original_attn_processors is not None:
|
| 607 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 608 |
+
|
| 609 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 610 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 611 |
+
module.gradient_checkpointing = value
|
| 612 |
+
|
| 613 |
+
def forward(
|
| 614 |
+
self,
|
| 615 |
+
hidden_states: torch.Tensor,
|
| 616 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 617 |
+
pooled_projections: torch.Tensor = None,
|
| 618 |
+
timestep: torch.LongTensor = None,
|
| 619 |
+
img_ids: torch.Tensor = None,
|
| 620 |
+
txt_ids: torch.Tensor = None,
|
| 621 |
+
guidance: torch.Tensor = None,
|
| 622 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 623 |
+
controlnet_block_samples=None,
|
| 624 |
+
controlnet_single_block_samples=None,
|
| 625 |
+
return_dict: bool = True,
|
| 626 |
+
controlnet_blocks_repeat: bool = False,
|
| 627 |
+
design_kwargs: dict | None = None,
|
| 628 |
+
design_scale =1.0
|
| 629 |
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
| 630 |
+
"""
|
| 631 |
+
The [`FluxTransformer2DModel`] forward method.
|
| 632 |
+
|
| 633 |
+
Args:
|
| 634 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
| 635 |
+
Input `hidden_states`.
|
| 636 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
| 637 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 638 |
+
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
| 639 |
+
from the embeddings of input conditions.
|
| 640 |
+
timestep ( `torch.LongTensor`):
|
| 641 |
+
Used to indicate denoising step.
|
| 642 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
| 643 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 644 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 645 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 646 |
+
`self.processor` in
|
| 647 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 648 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 649 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 650 |
+
tuple.
|
| 651 |
+
|
| 652 |
+
Returns:
|
| 653 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
| 654 |
+
`tuple` where the first element is the sample tensor.
|
| 655 |
+
"""
|
| 656 |
+
if joint_attention_kwargs is not None:
|
| 657 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 658 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 659 |
+
else:
|
| 660 |
+
lora_scale = 1.0
|
| 661 |
+
|
| 662 |
+
if USE_PEFT_BACKEND:
|
| 663 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 664 |
+
scale_lora_layers(self, lora_scale)
|
| 665 |
+
else:
|
| 666 |
+
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
| 667 |
+
logger.warning(
|
| 668 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 673 |
+
|
| 674 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 675 |
+
if guidance is not None:
|
| 676 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 677 |
+
else:
|
| 678 |
+
guidance = None
|
| 679 |
+
|
| 680 |
+
temb = (
|
| 681 |
+
self.time_text_embed(timestep, pooled_projections)
|
| 682 |
+
if guidance is None
|
| 683 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
| 684 |
+
)
|
| 685 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 686 |
+
|
| 687 |
+
if txt_ids.ndim == 3:
|
| 688 |
+
# logger.warning(
|
| 689 |
+
# "Passing `txt_ids` 3d torch.Tensor is deprecated."
|
| 690 |
+
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 691 |
+
# )
|
| 692 |
+
txt_ids = txt_ids[0]
|
| 693 |
+
if img_ids.ndim == 3:
|
| 694 |
+
# logger.warning(
|
| 695 |
+
# "Passing `img_ids` 3d torch.Tensor is deprecated."
|
| 696 |
+
# "Please remove the batch dimension and pass it as a 2d torch Tensor"
|
| 697 |
+
# )
|
| 698 |
+
img_ids = img_ids[0]
|
| 699 |
+
|
| 700 |
+
attention_mask_batch = None
|
| 701 |
+
# handle design infos
|
| 702 |
+
if self.attention_type=="design" and design_kwargs is not None:
|
| 703 |
+
|
| 704 |
+
# handle objects
|
| 705 |
+
objects_boxes = design_kwargs['object_layout']['objects_boxes'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,10,4]
|
| 706 |
+
objects_bbox_text_embeddings = design_kwargs['object_layout']['bbox_text_embeddings'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,10,512,4096]
|
| 707 |
+
objects_bbox_masks = design_kwargs['object_layout']['bbox_masks'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,10]
|
| 708 |
+
#token Truncation
|
| 709 |
+
objects_bbox_text_embeddings = objects_bbox_text_embeddings[:,:,:self.max_boxes_token_length,:]# [B,10,30,4096]
|
| 710 |
+
|
| 711 |
+
# [B,10,30,4096] -> [B*10,30,4096] -> [B*10,30,3072] -> [B,10,30,3072]
|
| 712 |
+
B, N, S, C = objects_bbox_text_embeddings.shape
|
| 713 |
+
objects_bbox_text_embeddings = objects_bbox_text_embeddings.reshape(-1, S, C) #[B*10,30,4096]
|
| 714 |
+
objects_bbox_text_embeddings = self.context_embedder(objects_bbox_text_embeddings) #[B*10,30,3072]
|
| 715 |
+
objects_bbox_text_embeddings = objects_bbox_text_embeddings.reshape(B, N, S, -1) # [B,10,30,3072]
|
| 716 |
+
|
| 717 |
+
if self.use_layout_encoder:
|
| 718 |
+
if self.use_fourier_bbox:
|
| 719 |
+
object_bbox_hidden_states = self.object_layout_encoder(
|
| 720 |
+
boxes=objects_boxes,
|
| 721 |
+
masks=objects_bbox_masks,
|
| 722 |
+
positive_embeddings=objects_bbox_text_embeddings,
|
| 723 |
+
)# [B,10,30,3072]
|
| 724 |
+
else:
|
| 725 |
+
object_bbox_hidden_states = self.object_layout_encoder(
|
| 726 |
+
positive_embeddings=objects_bbox_text_embeddings,
|
| 727 |
+
)# [B,10,30,3072]
|
| 728 |
+
else:
|
| 729 |
+
object_bbox_hidden_states = objects_bbox_text_embeddings
|
| 730 |
+
|
| 731 |
+
object_bbox_hidden_states = object_bbox_hidden_states.contiguous().view(B, N*S, -1) # [B,300,3072]
|
| 732 |
+
|
| 733 |
+
# bbox_id shift
|
| 734 |
+
if self.bbox_id_shift:
|
| 735 |
+
object_bbox_ids = -1 * torch.ones(object_bbox_hidden_states.shape[0], object_bbox_hidden_states.shape[1], 3).to(device=object_bbox_hidden_states.device, dtype=object_bbox_hidden_states.dtype)
|
| 736 |
+
else:
|
| 737 |
+
object_bbox_ids = torch.zeros(object_bbox_hidden_states.shape[0], object_bbox_hidden_states.shape[1], 3).to(device=object_bbox_hidden_states.device, dtype=object_bbox_hidden_states.dtype)
|
| 738 |
+
if object_bbox_ids.ndim == 3:
|
| 739 |
+
object_bbox_ids = object_bbox_ids[0] #[300,3]
|
| 740 |
+
object_rotary_emb = self.pos_embed(object_bbox_ids)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
# handle subjects
|
| 745 |
+
subject_hidden_states = design_kwargs['subject_contion']['condition_img']
|
| 746 |
+
subject_hidden_states = self.x_embedder(subject_hidden_states)
|
| 747 |
+
subject_ids = design_kwargs['subject_contion']['condition_img_ids']
|
| 748 |
+
if subject_ids.ndim == 3:
|
| 749 |
+
subject_ids = subject_ids[0]
|
| 750 |
+
subject_rotary_emb = self.pos_embed(subject_ids)
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
if self.use_attention_mask:
|
| 755 |
+
num_objects = N
|
| 756 |
+
tokens_per_object = self.max_boxes_token_length
|
| 757 |
+
total_object_tokens = object_bbox_hidden_states.shape[1]
|
| 758 |
+
assert total_object_tokens == num_objects * tokens_per_object, "Total object tokens do not match expected value"
|
| 759 |
+
encoder_tokens = encoder_hidden_states.shape[1]
|
| 760 |
+
img_tokens = hidden_states.shape[1]
|
| 761 |
+
subject_tokens = subject_hidden_states.shape[1]
|
| 762 |
+
# Total number of tokens
|
| 763 |
+
total_tokens = total_object_tokens + encoder_tokens + img_tokens + subject_tokens
|
| 764 |
+
|
| 765 |
+
attention_mask_batch = torch.zeros((B,total_tokens, total_tokens), dtype=hidden_states.dtype,device=hidden_states.device)
|
| 766 |
+
img_H, img_W = design_kwargs['object_layout']['img_token_h'], design_kwargs['object_layout']['img_token_w']
|
| 767 |
+
objects_masks_maps = design_kwargs['object_layout']['objects_masks_maps'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,512,512]
|
| 768 |
+
subject_H,subject_W = design_kwargs['subject_contion']['subject_token_h'], design_kwargs['subject_contion']['subject_token_w']
|
| 769 |
+
subject_masks_maps = design_kwargs['subject_contion']['subject_masks_maps'].to(dtype=hidden_states.dtype, device=hidden_states.device) # [B,512,512]
|
| 770 |
+
for m_idx in range(B):
|
| 771 |
+
# Create the base mask (all False/blocked)
|
| 772 |
+
attention_mask = torch.zeros((total_tokens, total_tokens), dtype=hidden_states.dtype,device=hidden_states.device)
|
| 773 |
+
|
| 774 |
+
# Define token ranges
|
| 775 |
+
o_ranges = [] # Ranges for each object
|
| 776 |
+
start_idx = 0
|
| 777 |
+
for i in range(num_objects):
|
| 778 |
+
end_idx = start_idx + tokens_per_object
|
| 779 |
+
o_ranges.append((start_idx, end_idx))
|
| 780 |
+
start_idx = end_idx
|
| 781 |
+
|
| 782 |
+
encoder_range = (total_object_tokens, total_object_tokens + encoder_tokens)
|
| 783 |
+
img_range = (encoder_range[1], encoder_range[1] + img_tokens)
|
| 784 |
+
subject_range = (img_range[1], img_range[1] + subject_tokens)
|
| 785 |
+
|
| 786 |
+
# Fill in the mask
|
| 787 |
+
|
| 788 |
+
# 1. Object self-attention (diagonal o₁-o₁, o₂-o₂, o₃-o₃)
|
| 789 |
+
for o_start, o_end in o_ranges:
|
| 790 |
+
attention_mask[o_start:o_end, o_start:o_end] = True
|
| 791 |
+
|
| 792 |
+
# 2. Objects to img and img to objetcs
|
| 793 |
+
|
| 794 |
+
if not self.use_objects_masks_maps:
|
| 795 |
+
# all objects can attend to img and img can attend to all objects
|
| 796 |
+
for o_start, o_end in o_ranges:
|
| 797 |
+
attention_mask[o_start:o_end, img_range[0]:img_range[1]] = True
|
| 798 |
+
# img can attend to all
|
| 799 |
+
attention_mask[img_range[0]:img_range[1], :] = True
|
| 800 |
+
else:
|
| 801 |
+
# all objects can only attend to the bbox area (defined by objects_mask) of img
|
| 802 |
+
for idx, (o_start, o_end )in enumerate(o_ranges):
|
| 803 |
+
mask = objects_masks_maps[m_idx][idx]
|
| 804 |
+
mask = torch.nn.functional.interpolate(mask[None, None, :, :], (img_H, img_W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, tokens_per_object) #shape: [img_tokens,tokens_per_object]
|
| 805 |
+
|
| 806 |
+
# objects to img
|
| 807 |
+
attention_mask[o_start:o_end, img_range[0]:img_range[1]] = mask.transpose(-1, -2)
|
| 808 |
+
|
| 809 |
+
# img to objects
|
| 810 |
+
attention_mask[img_range[0]:img_range[1], o_start:o_end] = mask
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
# img to img
|
| 814 |
+
attention_mask[img_range[0]:img_range[1], img_range[0]:img_range[1]] = True
|
| 815 |
+
|
| 816 |
+
# img to prompt
|
| 817 |
+
attention_mask[img_range[0]:img_range[1], encoder_range[0]:encoder_range[1]] = True
|
| 818 |
+
|
| 819 |
+
# img to subject
|
| 820 |
+
subject_mask = subject_masks_maps[m_idx][0]
|
| 821 |
+
|
| 822 |
+
if not self.use_subject_masks_maps:
|
| 823 |
+
# all img can attend to subject
|
| 824 |
+
attention_mask[img_range[0]:img_range[1], subject_range[0]:subject_range[1]] = True
|
| 825 |
+
else:
|
| 826 |
+
# img can only attend to the bbox area (defined by subject_mask) of subject
|
| 827 |
+
|
| 828 |
+
subject_mask_img = torch.nn.functional.interpolate(subject_mask[None, None, :, :], (img_H, img_W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, subject_tokens) #shape: [img_tokens,subject_tokens]
|
| 829 |
+
|
| 830 |
+
# img to objects
|
| 831 |
+
attention_mask[img_range[0]:img_range[1], subject_range[0]:subject_range[1]] = subject_mask_img
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
# 3. prompt to prompt, prompt to img, and prompt to subject
|
| 836 |
+
|
| 837 |
+
# prompt to prompt
|
| 838 |
+
attention_mask[encoder_range[0]:encoder_range[1], encoder_range[0]:encoder_range[1]] = True
|
| 839 |
+
# prompt to img
|
| 840 |
+
attention_mask[encoder_range[0]:encoder_range[1], img_range[0]:img_range[1]] = True
|
| 841 |
+
|
| 842 |
+
# prompt to subject
|
| 843 |
+
if not self.use_subject_masks_maps:
|
| 844 |
+
attention_mask[encoder_range[0]:encoder_range[1], subject_range[0]:subject_range[1]] = True
|
| 845 |
+
else:
|
| 846 |
+
subject_mask_prompt = torch.nn.functional.interpolate(subject_mask[None, None, :, :], (subject_H, subject_W), mode='nearest-exact').flatten().unsqueeze(1).repeat(1, encoder_tokens) #shape: [subject_tokens,encoder_tokens]
|
| 847 |
+
attention_mask[encoder_range[0]:encoder_range[1], subject_range[0]:subject_range[1]] = subject_mask_prompt.transpose(-1, -2)
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
# 4. subject to prompt, subject to img, subject to subject
|
| 851 |
+
# subject to prompt
|
| 852 |
+
if not self.use_subject_masks_maps:
|
| 853 |
+
attention_mask[subject_range[0]:subject_range[1], encoder_range[0]:encoder_range[1]] = True
|
| 854 |
+
else:
|
| 855 |
+
attention_mask[subject_range[0]:subject_range[1], encoder_range[0]:encoder_range[1]] = False
|
| 856 |
+
|
| 857 |
+
# subject to img
|
| 858 |
+
if not self.use_subject_masks_maps:
|
| 859 |
+
attention_mask[subject_range[0]:subject_range[1], img_range[0]:img_range[1]] = True
|
| 860 |
+
else:
|
| 861 |
+
attention_mask[subject_range[0]:subject_range[1], img_range[0]:img_range[1]] = subject_mask_img.transpose(-1, -2)
|
| 862 |
+
# subject to subject
|
| 863 |
+
if not self.use_subject_masks_maps:
|
| 864 |
+
attention_mask[subject_range[0]:subject_range[1], subject_range[0]:subject_range[1]] = True
|
| 865 |
+
else:
|
| 866 |
+
# blcok non-subject region
|
| 867 |
+
if not self.drop_subject_bg:
|
| 868 |
+
attention_mask[subject_range[0]:subject_range[1], subject_range[0]:subject_range[1]] = True
|
| 869 |
+
else:
|
| 870 |
+
attention_mask[subject_range[0]:subject_range[1], subject_range[0]:subject_range[1]] = subject_mask_img
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
attention_mask_batch[m_idx] = attention_mask
|
| 874 |
+
|
| 875 |
+
attention_mask_batch = attention_mask_batch.unsqueeze(1).to(dtype=torch.bool, device=hidden_states.device)#[B,2860,2860]->[B,1,2860,2860]
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
| 879 |
+
image_rotary_emb = self.pos_embed(ids)
|
| 880 |
+
|
| 881 |
+
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
| 882 |
+
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
| 883 |
+
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
| 884 |
+
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 888 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 889 |
+
|
| 890 |
+
def create_custom_forward(module, return_dict=None):
|
| 891 |
+
def custom_forward(*inputs):
|
| 892 |
+
if return_dict is not None:
|
| 893 |
+
return module(*inputs, return_dict=return_dict)
|
| 894 |
+
else:
|
| 895 |
+
return module(*inputs)
|
| 896 |
+
|
| 897 |
+
return custom_forward
|
| 898 |
+
|
| 899 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 900 |
+
encoder_hidden_states, hidden_states, subject_hidden_states, object_bbox_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 901 |
+
create_custom_forward(block),
|
| 902 |
+
hidden_states,
|
| 903 |
+
encoder_hidden_states,
|
| 904 |
+
temb,
|
| 905 |
+
image_rotary_emb,
|
| 906 |
+
subject_hidden_states,
|
| 907 |
+
subject_rotary_emb,
|
| 908 |
+
object_bbox_hidden_states,
|
| 909 |
+
object_rotary_emb,
|
| 910 |
+
design_scale,
|
| 911 |
+
attention_mask_batch,
|
| 912 |
+
**ckpt_kwargs,
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
else:
|
| 916 |
+
encoder_hidden_states, hidden_states, subject_hidden_states, object_bbox_hidden_states = block(
|
| 917 |
+
hidden_states=hidden_states,
|
| 918 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 919 |
+
temb=temb,
|
| 920 |
+
image_rotary_emb=image_rotary_emb,
|
| 921 |
+
subject_hidden_states=subject_hidden_states,
|
| 922 |
+
subject_rotary_emb=subject_rotary_emb,
|
| 923 |
+
object_bbox_hidden_states=object_bbox_hidden_states,
|
| 924 |
+
object_rotary_emb=object_rotary_emb,
|
| 925 |
+
design_scale = design_scale,
|
| 926 |
+
attention_mask = attention_mask_batch,
|
| 927 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
# controlnet residual
|
| 931 |
+
if controlnet_block_samples is not None:
|
| 932 |
+
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
| 933 |
+
interval_control = int(np.ceil(interval_control))
|
| 934 |
+
# For Xlabs ControlNet.
|
| 935 |
+
if controlnet_blocks_repeat:
|
| 936 |
+
hidden_states = (
|
| 937 |
+
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
| 938 |
+
)
|
| 939 |
+
else:
|
| 940 |
+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
| 941 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 942 |
+
|
| 943 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 944 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 945 |
+
|
| 946 |
+
def create_custom_forward(module, return_dict=None):
|
| 947 |
+
def custom_forward(*inputs):
|
| 948 |
+
if return_dict is not None:
|
| 949 |
+
return module(*inputs, return_dict=return_dict)
|
| 950 |
+
else:
|
| 951 |
+
return module(*inputs)
|
| 952 |
+
|
| 953 |
+
return custom_forward
|
| 954 |
+
|
| 955 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 956 |
+
hidden_states, subject_hidden_states, object_bbox_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 957 |
+
create_custom_forward(block),
|
| 958 |
+
hidden_states,
|
| 959 |
+
temb,
|
| 960 |
+
image_rotary_emb,
|
| 961 |
+
subject_hidden_states,
|
| 962 |
+
subject_rotary_emb,
|
| 963 |
+
object_bbox_hidden_states,
|
| 964 |
+
object_rotary_emb,
|
| 965 |
+
design_scale,
|
| 966 |
+
attention_mask_batch,
|
| 967 |
+
**ckpt_kwargs,
|
| 968 |
+
)
|
| 969 |
+
|
| 970 |
+
else:
|
| 971 |
+
hidden_states, subject_hidden_states, object_bbox_hidden_states = block(
|
| 972 |
+
hidden_states=hidden_states,
|
| 973 |
+
temb=temb,
|
| 974 |
+
image_rotary_emb=image_rotary_emb,
|
| 975 |
+
subject_hidden_states=subject_hidden_states,
|
| 976 |
+
subject_rotary_emb=subject_rotary_emb,
|
| 977 |
+
object_bbox_hidden_states=object_bbox_hidden_states,
|
| 978 |
+
object_rotary_emb=object_rotary_emb,
|
| 979 |
+
design_scale=design_scale,
|
| 980 |
+
attention_mask = attention_mask_batch,
|
| 981 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 982 |
+
)
|
| 983 |
+
|
| 984 |
+
# controlnet residual
|
| 985 |
+
if controlnet_single_block_samples is not None:
|
| 986 |
+
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
| 987 |
+
interval_control = int(np.ceil(interval_control))
|
| 988 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
| 989 |
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 990 |
+
+ controlnet_single_block_samples[index_block // interval_control]
|
| 991 |
+
)
|
| 992 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
| 993 |
+
|
| 994 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 995 |
+
output = self.proj_out(hidden_states)
|
| 996 |
+
|
| 997 |
+
if USE_PEFT_BACKEND:
|
| 998 |
+
# remove `lora_scale` from each PEFT layer
|
| 999 |
+
unscale_lora_layers(self, lora_scale)
|
| 1000 |
+
|
| 1001 |
+
if not return_dict:
|
| 1002 |
+
return (output,)
|
| 1003 |
+
|
| 1004 |
+
return Transformer2DModelOutput(sample=output)
|
modules/semantic_layout/__pycache__/layout_encoder.cpython-310.pyc
ADDED
|
Binary file (4.26 kB). View file
|
|
|
modules/semantic_layout/layout_encoder.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
def zero_module(module):
|
| 4 |
+
"""
|
| 5 |
+
Zero out the parameters of a module and return it.
|
| 6 |
+
"""
|
| 7 |
+
for p in module.parameters():
|
| 8 |
+
p.detach().zero_()
|
| 9 |
+
return module
|
| 10 |
+
|
| 11 |
+
def get_fourier_embeds_from_boundingbox(embed_dim, box):
|
| 12 |
+
"""
|
| 13 |
+
Args:
|
| 14 |
+
embed_dim: int
|
| 15 |
+
box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
|
| 16 |
+
Returns:
|
| 17 |
+
[B x N x embed_dim] tensor of positional embeddings
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
batch_size, num_boxes = box.shape[:2]
|
| 21 |
+
|
| 22 |
+
emb = 100 ** (torch.arange(embed_dim) / embed_dim)
|
| 23 |
+
emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
|
| 24 |
+
emb = emb * box.unsqueeze(-1)
|
| 25 |
+
|
| 26 |
+
emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
|
| 27 |
+
emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
|
| 28 |
+
|
| 29 |
+
return emb
|
| 30 |
+
|
| 31 |
+
class PixArtAlphaTextProjection(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
| 34 |
+
|
| 35 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
|
| 39 |
+
super().__init__()
|
| 40 |
+
if out_features is None:
|
| 41 |
+
out_features = hidden_size
|
| 42 |
+
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
|
| 43 |
+
if act_fn == "gelu_tanh":
|
| 44 |
+
self.act_1 = nn.GELU(approximate="tanh")
|
| 45 |
+
elif act_fn == "silu":
|
| 46 |
+
self.act_1 = nn.SiLU()
|
| 47 |
+
elif act_fn == "silu_fp32":
|
| 48 |
+
self.act_1 = FP32SiLU()
|
| 49 |
+
else:
|
| 50 |
+
raise ValueError(f"Unknown activation function: {act_fn}")
|
| 51 |
+
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
|
| 52 |
+
|
| 53 |
+
def forward(self, caption):
|
| 54 |
+
hidden_states = self.linear_1(caption)
|
| 55 |
+
hidden_states = self.act_1(hidden_states)
|
| 56 |
+
hidden_states = self.linear_2(hidden_states)
|
| 57 |
+
return hidden_states
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ObjectLayoutEncoder(nn.Module):
|
| 61 |
+
def __init__(self, positive_len, out_dim, fourier_freqs=8 ,max_boxes_token_length=30):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.positive_len = positive_len
|
| 64 |
+
self.out_dim = out_dim
|
| 65 |
+
|
| 66 |
+
self.fourier_embedder_dim = fourier_freqs
|
| 67 |
+
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy #64
|
| 68 |
+
|
| 69 |
+
if isinstance(out_dim, tuple):
|
| 70 |
+
out_dim = out_dim[0]
|
| 71 |
+
|
| 72 |
+
self.null_positive_feature = torch.nn.Parameter(torch.zeros([max_boxes_token_length, self.positive_len]))
|
| 73 |
+
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
self.linears = PixArtAlphaTextProjection(in_features=self.positive_len + self.position_dim,hidden_size=out_dim//2,out_features=out_dim, act_fn="silu")
|
| 77 |
+
|
| 78 |
+
def forward(
|
| 79 |
+
self,
|
| 80 |
+
boxes, # [B,10,4]
|
| 81 |
+
masks, # [B,10]
|
| 82 |
+
positive_embeddings, # [B,10,30,3072]
|
| 83 |
+
):
|
| 84 |
+
|
| 85 |
+
B, N, S, C = positive_embeddings.shape # B: batch_size, N: 10, S: 30, C: 3072
|
| 86 |
+
|
| 87 |
+
positive_embeddings = positive_embeddings.reshape(B*N, S, C) # [B*10,30,3072]
|
| 88 |
+
masks = masks.reshape(B*N, 1, 1) # [B*10,1,1]
|
| 89 |
+
|
| 90 |
+
# Process positional encoding
|
| 91 |
+
xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # [B,10,64]
|
| 92 |
+
xyxy_embedding = xyxy_embedding.reshape(B*N, -1) # [B*10,64]
|
| 93 |
+
xyxy_null = self.null_position_feature.view(1, -1) # [1,64]
|
| 94 |
+
|
| 95 |
+
# Expand positional encoding to match sequence dimension
|
| 96 |
+
xyxy_embedding = xyxy_embedding.unsqueeze(1).expand(-1, S, -1) # [B*10,30,64]
|
| 97 |
+
xyxy_null = xyxy_null.unsqueeze(0).expand(B*N, S, -1) # [B*10,30,64]
|
| 98 |
+
|
| 99 |
+
# Apply mask
|
| 100 |
+
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null # [B*10,30,64]
|
| 101 |
+
|
| 102 |
+
# Process feature encoding
|
| 103 |
+
positive_null = self.null_positive_feature.view(1, S, -1).expand(B*N, -1, -1) # [B*10,30,3072]
|
| 104 |
+
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null # [B*10,30,3072]
|
| 105 |
+
|
| 106 |
+
# Concatenate positional encoding and feature encoding
|
| 107 |
+
combined = torch.cat([positive_embeddings, xyxy_embedding], dim=-1) # [B*10,30,3072+64]
|
| 108 |
+
|
| 109 |
+
# Process each box's features independently
|
| 110 |
+
objs = self.linears(combined) # [B*10,30,3072]
|
| 111 |
+
|
| 112 |
+
# Restore original shape
|
| 113 |
+
objs = objs.reshape(B, N, S, -1) # [B,10,30,3072]
|
| 114 |
+
|
| 115 |
+
return objs
|
| 116 |
+
|
| 117 |
+
class ObjectLayoutEncoder_noFourier(nn.Module):
|
| 118 |
+
def __init__(self, in_dim, out_dim):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.in_dim = in_dim
|
| 121 |
+
self.out_dim = out_dim
|
| 122 |
+
|
| 123 |
+
self.linears = PixArtAlphaTextProjection(in_features=self.in_dim,hidden_size=out_dim//2,out_features=out_dim, act_fn="silu")
|
| 124 |
+
|
| 125 |
+
def forward(
|
| 126 |
+
self,
|
| 127 |
+
positive_embeddings, # [B,10,30,3072]
|
| 128 |
+
):
|
| 129 |
+
|
| 130 |
+
B, N, S, C = positive_embeddings.shape # B: batch_size, N: 10, S: 30, C: 3072
|
| 131 |
+
positive_embeddings = positive_embeddings.reshape(B*N, S, C) # [B*10,30,3072]
|
| 132 |
+
|
| 133 |
+
# Process each box's features independently
|
| 134 |
+
objs = self.linears(positive_embeddings) # [B*10,30,3072]
|
| 135 |
+
|
| 136 |
+
# Restore original shape
|
| 137 |
+
objs = objs.reshape(B, N, S, -1) # [B,10,30,3072]
|
| 138 |
+
|
| 139 |
+
return objs
|
pipeline/__pycache__/pipeline_flux_creatidesign.cpython-310.pyc
ADDED
|
Binary file (32.3 kB). View file
|
|
|
pipeline/pipeline_flux_creatidesign.py
ADDED
|
@@ -0,0 +1,1068 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import (
|
| 21 |
+
CLIPImageProcessor,
|
| 22 |
+
CLIPTextModel,
|
| 23 |
+
CLIPTokenizer,
|
| 24 |
+
CLIPVisionModelWithProjection,
|
| 25 |
+
T5EncoderModel,
|
| 26 |
+
T5TokenizerFast,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 30 |
+
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
| 31 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 32 |
+
from diffusers.models.transformers import FluxTransformer2DModel
|
| 33 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 34 |
+
from diffusers.utils import (
|
| 35 |
+
USE_PEFT_BACKEND,
|
| 36 |
+
is_torch_xla_available,
|
| 37 |
+
logging,
|
| 38 |
+
replace_example_docstring,
|
| 39 |
+
scale_lora_layers,
|
| 40 |
+
unscale_lora_layers,
|
| 41 |
+
)
|
| 42 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 43 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 44 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if is_torch_xla_available():
|
| 48 |
+
import torch_xla.core.xla_model as xm
|
| 49 |
+
|
| 50 |
+
XLA_AVAILABLE = True
|
| 51 |
+
else:
|
| 52 |
+
XLA_AVAILABLE = False
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 56 |
+
|
| 57 |
+
EXAMPLE_DOC_STRING = """
|
| 58 |
+
Examples:
|
| 59 |
+
```py
|
| 60 |
+
>>> import torch
|
| 61 |
+
>>> from diffusers import FluxPipeline
|
| 62 |
+
|
| 63 |
+
>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
| 64 |
+
>>> pipe.to("cuda")
|
| 65 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
| 66 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
| 67 |
+
>>> # Refer to the pipeline documentation for more details.
|
| 68 |
+
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
|
| 69 |
+
>>> image.save("flux.png")
|
| 70 |
+
```
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def calculate_shift(
|
| 75 |
+
image_seq_len,
|
| 76 |
+
base_seq_len: int = 256,
|
| 77 |
+
max_seq_len: int = 4096,
|
| 78 |
+
base_shift: float = 0.5,
|
| 79 |
+
max_shift: float = 1.16,
|
| 80 |
+
):
|
| 81 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 82 |
+
b = base_shift - m * base_seq_len
|
| 83 |
+
mu = image_seq_len * m + b
|
| 84 |
+
return mu
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 88 |
+
def retrieve_timesteps(
|
| 89 |
+
scheduler,
|
| 90 |
+
num_inference_steps: Optional[int] = None,
|
| 91 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 92 |
+
timesteps: Optional[List[int]] = None,
|
| 93 |
+
sigmas: Optional[List[float]] = None,
|
| 94 |
+
**kwargs,
|
| 95 |
+
):
|
| 96 |
+
r"""
|
| 97 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 98 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
scheduler (`SchedulerMixin`):
|
| 102 |
+
The scheduler to get timesteps from.
|
| 103 |
+
num_inference_steps (`int`):
|
| 104 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 105 |
+
must be `None`.
|
| 106 |
+
device (`str` or `torch.device`, *optional*):
|
| 107 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 108 |
+
timesteps (`List[int]`, *optional*):
|
| 109 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 110 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 111 |
+
sigmas (`List[float]`, *optional*):
|
| 112 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 113 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 117 |
+
second element is the number of inference steps.
|
| 118 |
+
"""
|
| 119 |
+
if timesteps is not None and sigmas is not None:
|
| 120 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 121 |
+
if timesteps is not None:
|
| 122 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 123 |
+
if not accepts_timesteps:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 126 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 127 |
+
)
|
| 128 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 129 |
+
timesteps = scheduler.timesteps
|
| 130 |
+
num_inference_steps = len(timesteps)
|
| 131 |
+
elif sigmas is not None:
|
| 132 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 133 |
+
if not accept_sigmas:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 136 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 137 |
+
)
|
| 138 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 139 |
+
timesteps = scheduler.timesteps
|
| 140 |
+
num_inference_steps = len(timesteps)
|
| 141 |
+
else:
|
| 142 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 143 |
+
timesteps = scheduler.timesteps
|
| 144 |
+
return timesteps, num_inference_steps
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class FluxPipeline(
|
| 148 |
+
DiffusionPipeline,
|
| 149 |
+
FluxLoraLoaderMixin,
|
| 150 |
+
FromSingleFileMixin,
|
| 151 |
+
TextualInversionLoaderMixin,
|
| 152 |
+
FluxIPAdapterMixin,
|
| 153 |
+
):
|
| 154 |
+
r"""
|
| 155 |
+
The Flux pipeline for text-to-image generation.
|
| 156 |
+
|
| 157 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
transformer ([`FluxTransformer2DModel`]):
|
| 161 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
| 162 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
| 163 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 164 |
+
vae ([`AutoencoderKL`]):
|
| 165 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 166 |
+
text_encoder ([`CLIPTextModel`]):
|
| 167 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 168 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 169 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
| 170 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 171 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
| 172 |
+
tokenizer (`CLIPTokenizer`):
|
| 173 |
+
Tokenizer of class
|
| 174 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 175 |
+
tokenizer_2 (`T5TokenizerFast`):
|
| 176 |
+
Second Tokenizer of class
|
| 177 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
| 181 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
| 182 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 183 |
+
|
| 184 |
+
def __init__(
|
| 185 |
+
self,
|
| 186 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 187 |
+
vae: AutoencoderKL,
|
| 188 |
+
text_encoder: CLIPTextModel,
|
| 189 |
+
tokenizer: CLIPTokenizer,
|
| 190 |
+
text_encoder_2: T5EncoderModel,
|
| 191 |
+
tokenizer_2: T5TokenizerFast,
|
| 192 |
+
transformer: FluxTransformer2DModel,
|
| 193 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 194 |
+
feature_extractor: CLIPImageProcessor = None,
|
| 195 |
+
):
|
| 196 |
+
super().__init__()
|
| 197 |
+
|
| 198 |
+
self.register_modules(
|
| 199 |
+
vae=vae,
|
| 200 |
+
text_encoder=text_encoder,
|
| 201 |
+
text_encoder_2=text_encoder_2,
|
| 202 |
+
tokenizer=tokenizer,
|
| 203 |
+
tokenizer_2=tokenizer_2,
|
| 204 |
+
transformer=transformer,
|
| 205 |
+
scheduler=scheduler,
|
| 206 |
+
image_encoder=image_encoder,
|
| 207 |
+
feature_extractor=feature_extractor,
|
| 208 |
+
)
|
| 209 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 210 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
| 211 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
| 212 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 213 |
+
self.tokenizer_max_length = (
|
| 214 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
| 215 |
+
)
|
| 216 |
+
self.default_sample_size = 128
|
| 217 |
+
|
| 218 |
+
def _get_t5_prompt_embeds(
|
| 219 |
+
self,
|
| 220 |
+
prompt: Union[str, List[str]] = None,
|
| 221 |
+
num_images_per_prompt: int = 1,
|
| 222 |
+
max_sequence_length: int = 512,
|
| 223 |
+
device: Optional[torch.device] = None,
|
| 224 |
+
dtype: Optional[torch.dtype] = None,
|
| 225 |
+
):
|
| 226 |
+
device = device or self._execution_device
|
| 227 |
+
dtype = dtype or self.text_encoder.dtype
|
| 228 |
+
|
| 229 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 230 |
+
batch_size = len(prompt)
|
| 231 |
+
|
| 232 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 233 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
| 234 |
+
|
| 235 |
+
text_inputs = self.tokenizer_2(
|
| 236 |
+
prompt,
|
| 237 |
+
padding="max_length",
|
| 238 |
+
max_length=max_sequence_length,
|
| 239 |
+
truncation=True,
|
| 240 |
+
return_length=False,
|
| 241 |
+
return_overflowing_tokens=False,
|
| 242 |
+
return_tensors="pt",
|
| 243 |
+
)
|
| 244 |
+
text_input_ids = text_inputs.input_ids
|
| 245 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
| 246 |
+
|
| 247 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 248 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 249 |
+
logger.warning(
|
| 250 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 251 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
| 255 |
+
|
| 256 |
+
dtype = self.text_encoder_2.dtype
|
| 257 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 258 |
+
|
| 259 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 260 |
+
|
| 261 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
| 262 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 263 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 264 |
+
|
| 265 |
+
return prompt_embeds
|
| 266 |
+
|
| 267 |
+
def _get_clip_prompt_embeds(
|
| 268 |
+
self,
|
| 269 |
+
prompt: Union[str, List[str]],
|
| 270 |
+
num_images_per_prompt: int = 1,
|
| 271 |
+
device: Optional[torch.device] = None,
|
| 272 |
+
):
|
| 273 |
+
device = device or self._execution_device
|
| 274 |
+
|
| 275 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 276 |
+
batch_size = len(prompt)
|
| 277 |
+
|
| 278 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
| 279 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
| 280 |
+
|
| 281 |
+
text_inputs = self.tokenizer(
|
| 282 |
+
prompt,
|
| 283 |
+
padding="max_length",
|
| 284 |
+
max_length=self.tokenizer_max_length,
|
| 285 |
+
truncation=True,
|
| 286 |
+
return_overflowing_tokens=False,
|
| 287 |
+
return_length=False,
|
| 288 |
+
return_tensors="pt",
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
text_input_ids = text_inputs.input_ids
|
| 292 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 293 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 294 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
| 295 |
+
logger.warning(
|
| 296 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
| 297 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
| 298 |
+
)
|
| 299 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
| 300 |
+
|
| 301 |
+
# Use pooled output of CLIPTextModel
|
| 302 |
+
prompt_embeds = prompt_embeds.pooler_output
|
| 303 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
| 304 |
+
|
| 305 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 306 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
| 307 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 308 |
+
|
| 309 |
+
return prompt_embeds
|
| 310 |
+
|
| 311 |
+
def encode_prompt(
|
| 312 |
+
self,
|
| 313 |
+
prompt: Union[str, List[str]],
|
| 314 |
+
prompt_2: Union[str, List[str]],
|
| 315 |
+
device: Optional[torch.device] = None,
|
| 316 |
+
num_images_per_prompt: int = 1,
|
| 317 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 318 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 319 |
+
max_sequence_length: int = 512,
|
| 320 |
+
lora_scale: Optional[float] = None,
|
| 321 |
+
):
|
| 322 |
+
r"""
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 326 |
+
prompt to be encoded
|
| 327 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 328 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 329 |
+
used in all text-encoders
|
| 330 |
+
device: (`torch.device`):
|
| 331 |
+
torch device
|
| 332 |
+
num_images_per_prompt (`int`):
|
| 333 |
+
number of images that should be generated per prompt
|
| 334 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 335 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 336 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 337 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 338 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 339 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 340 |
+
lora_scale (`float`, *optional*):
|
| 341 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
| 342 |
+
"""
|
| 343 |
+
device = device or self._execution_device
|
| 344 |
+
|
| 345 |
+
# set lora scale so that monkey patched LoRA
|
| 346 |
+
# function of text encoder can correctly access it
|
| 347 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
| 348 |
+
self._lora_scale = lora_scale
|
| 349 |
+
|
| 350 |
+
# dynamically adjust the LoRA scale
|
| 351 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
| 352 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
| 353 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
| 354 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 355 |
+
|
| 356 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 357 |
+
|
| 358 |
+
if prompt_embeds is None:
|
| 359 |
+
prompt_2 = prompt_2 or prompt
|
| 360 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
| 361 |
+
|
| 362 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
| 363 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 364 |
+
prompt=prompt,
|
| 365 |
+
device=device,
|
| 366 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 367 |
+
)
|
| 368 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 369 |
+
prompt=prompt_2,
|
| 370 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 371 |
+
max_sequence_length=max_sequence_length,
|
| 372 |
+
device=device,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
if self.text_encoder is not None:
|
| 376 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 377 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 378 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 379 |
+
|
| 380 |
+
if self.text_encoder_2 is not None:
|
| 381 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 382 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
| 383 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 384 |
+
|
| 385 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 386 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 387 |
+
|
| 388 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
| 389 |
+
|
| 390 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
| 391 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
| 392 |
+
|
| 393 |
+
if not isinstance(image, torch.Tensor):
|
| 394 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
| 395 |
+
|
| 396 |
+
image = image.to(device=device, dtype=dtype)
|
| 397 |
+
image_embeds = self.image_encoder(image).image_embeds
|
| 398 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
| 399 |
+
return image_embeds
|
| 400 |
+
|
| 401 |
+
def prepare_ip_adapter_image_embeds(
|
| 402 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
| 403 |
+
):
|
| 404 |
+
image_embeds = []
|
| 405 |
+
if ip_adapter_image_embeds is None:
|
| 406 |
+
if not isinstance(ip_adapter_image, list):
|
| 407 |
+
ip_adapter_image = [ip_adapter_image]
|
| 408 |
+
|
| 409 |
+
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
| 415 |
+
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
|
| 416 |
+
):
|
| 417 |
+
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
| 418 |
+
|
| 419 |
+
image_embeds.append(single_image_embeds[None, :])
|
| 420 |
+
else:
|
| 421 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
| 422 |
+
image_embeds.append(single_image_embeds)
|
| 423 |
+
|
| 424 |
+
ip_adapter_image_embeds = []
|
| 425 |
+
for i, single_image_embeds in enumerate(image_embeds):
|
| 426 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
| 427 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
| 428 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
| 429 |
+
|
| 430 |
+
return ip_adapter_image_embeds
|
| 431 |
+
|
| 432 |
+
def check_inputs(
|
| 433 |
+
self,
|
| 434 |
+
prompt,
|
| 435 |
+
prompt_2,
|
| 436 |
+
height,
|
| 437 |
+
width,
|
| 438 |
+
negative_prompt=None,
|
| 439 |
+
negative_prompt_2=None,
|
| 440 |
+
prompt_embeds=None,
|
| 441 |
+
negative_prompt_embeds=None,
|
| 442 |
+
pooled_prompt_embeds=None,
|
| 443 |
+
negative_pooled_prompt_embeds=None,
|
| 444 |
+
callback_on_step_end_tensor_inputs=None,
|
| 445 |
+
max_sequence_length=None,
|
| 446 |
+
):
|
| 447 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
| 448 |
+
logger.warning(
|
| 449 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 453 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 454 |
+
):
|
| 455 |
+
raise ValueError(
|
| 456 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
if prompt is not None and prompt_embeds is not None:
|
| 460 |
+
raise ValueError(
|
| 461 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 462 |
+
" only forward one of the two."
|
| 463 |
+
)
|
| 464 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
| 465 |
+
raise ValueError(
|
| 466 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 467 |
+
" only forward one of the two."
|
| 468 |
+
)
|
| 469 |
+
elif prompt is None and prompt_embeds is None:
|
| 470 |
+
raise ValueError(
|
| 471 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 472 |
+
)
|
| 473 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 474 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 475 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
| 476 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
| 477 |
+
|
| 478 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 479 |
+
raise ValueError(
|
| 480 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 481 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 482 |
+
)
|
| 483 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
| 484 |
+
raise ValueError(
|
| 485 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
| 486 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 490 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 491 |
+
raise ValueError(
|
| 492 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 493 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 494 |
+
f" {negative_prompt_embeds.shape}."
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
| 498 |
+
raise ValueError(
|
| 499 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 500 |
+
)
|
| 501 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
| 502 |
+
raise ValueError(
|
| 503 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
| 507 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 508 |
+
|
| 509 |
+
@staticmethod
|
| 510 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype,scale_h=1.0,scale_w=1.0):
|
| 511 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
| 512 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]* scale_h
|
| 513 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]* scale_w
|
| 514 |
+
|
| 515 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 516 |
+
|
| 517 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 518 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 522 |
+
|
| 523 |
+
@staticmethod
|
| 524 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
| 525 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
| 526 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
| 527 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
| 528 |
+
|
| 529 |
+
return latents
|
| 530 |
+
|
| 531 |
+
@staticmethod
|
| 532 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
| 533 |
+
batch_size, num_patches, channels = latents.shape
|
| 534 |
+
|
| 535 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 536 |
+
# latent height and width to be divisible by 2.
|
| 537 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
| 538 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
| 539 |
+
|
| 540 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
| 541 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
| 542 |
+
|
| 543 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
| 544 |
+
|
| 545 |
+
return latents
|
| 546 |
+
|
| 547 |
+
def enable_vae_slicing(self):
|
| 548 |
+
r"""
|
| 549 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 550 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 551 |
+
"""
|
| 552 |
+
self.vae.enable_slicing()
|
| 553 |
+
|
| 554 |
+
def disable_vae_slicing(self):
|
| 555 |
+
r"""
|
| 556 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
| 557 |
+
computing decoding in one step.
|
| 558 |
+
"""
|
| 559 |
+
self.vae.disable_slicing()
|
| 560 |
+
|
| 561 |
+
def enable_vae_tiling(self):
|
| 562 |
+
r"""
|
| 563 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 564 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 565 |
+
processing larger images.
|
| 566 |
+
"""
|
| 567 |
+
self.vae.enable_tiling()
|
| 568 |
+
|
| 569 |
+
def disable_vae_tiling(self):
|
| 570 |
+
r"""
|
| 571 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
| 572 |
+
computing decoding in one step.
|
| 573 |
+
"""
|
| 574 |
+
self.vae.disable_tiling()
|
| 575 |
+
|
| 576 |
+
def prepare_latents(
|
| 577 |
+
self,
|
| 578 |
+
batch_size,
|
| 579 |
+
num_channels_latents,
|
| 580 |
+
height,
|
| 581 |
+
width,
|
| 582 |
+
dtype,
|
| 583 |
+
device,
|
| 584 |
+
generator,
|
| 585 |
+
latents=None,
|
| 586 |
+
):
|
| 587 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
| 588 |
+
# latent height and width to be divisible by 2.
|
| 589 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 590 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 591 |
+
|
| 592 |
+
shape = (batch_size, num_channels_latents, height, width)
|
| 593 |
+
|
| 594 |
+
if latents is not None:
|
| 595 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 596 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
| 597 |
+
|
| 598 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 599 |
+
raise ValueError(
|
| 600 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 601 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 605 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
| 606 |
+
|
| 607 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
| 608 |
+
|
| 609 |
+
return latents, latent_image_ids
|
| 610 |
+
|
| 611 |
+
@property
|
| 612 |
+
def guidance_scale(self):
|
| 613 |
+
return self._guidance_scale
|
| 614 |
+
|
| 615 |
+
@property
|
| 616 |
+
def joint_attention_kwargs(self):
|
| 617 |
+
return self._joint_attention_kwargs
|
| 618 |
+
|
| 619 |
+
@property
|
| 620 |
+
def num_timesteps(self):
|
| 621 |
+
return self._num_timesteps
|
| 622 |
+
|
| 623 |
+
@property
|
| 624 |
+
def interrupt(self):
|
| 625 |
+
return self._interrupt
|
| 626 |
+
|
| 627 |
+
@torch.no_grad()
|
| 628 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 629 |
+
def __call__(
|
| 630 |
+
self,
|
| 631 |
+
prompt: Union[str, List[str]] = None,
|
| 632 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 633 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 634 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 635 |
+
true_cfg_scale: float = 3.0,
|
| 636 |
+
height: Optional[int] = None,
|
| 637 |
+
width: Optional[int] = None,
|
| 638 |
+
num_inference_steps: int = 28,
|
| 639 |
+
sigmas: Optional[List[float]] = None,
|
| 640 |
+
guidance_scale: float = 3.5,
|
| 641 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 642 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 643 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 644 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 645 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 646 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 647 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 648 |
+
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 649 |
+
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 650 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 651 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 652 |
+
output_type: Optional[str] = "pil",
|
| 653 |
+
return_dict: bool = True,
|
| 654 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 655 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 656 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 657 |
+
max_sequence_length: int = 512,
|
| 658 |
+
objects_boxes=None,
|
| 659 |
+
objects_caption=None,
|
| 660 |
+
objects_masks = None,
|
| 661 |
+
objects_masks_maps=None,
|
| 662 |
+
subject_masks_maps=None,
|
| 663 |
+
condition_img = None,
|
| 664 |
+
neg_condtion_img=None,
|
| 665 |
+
max_boxes_per_image = 10,
|
| 666 |
+
position_delta=[0,-64],
|
| 667 |
+
scale_h=1.0,
|
| 668 |
+
scale_w=1.0,
|
| 669 |
+
use_bucket=False
|
| 670 |
+
):
|
| 671 |
+
r"""
|
| 672 |
+
Function invoked when calling the pipeline for generation.
|
| 673 |
+
|
| 674 |
+
Args:
|
| 675 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 676 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 677 |
+
instead.
|
| 678 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 679 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 680 |
+
will be used instead
|
| 681 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 682 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 683 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 684 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 685 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 686 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 687 |
+
expense of slower inference.
|
| 688 |
+
sigmas (`List[float]`, *optional*):
|
| 689 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 690 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 691 |
+
will be used.
|
| 692 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 693 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 694 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 695 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 696 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 697 |
+
usually at the expense of lower image quality.
|
| 698 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 699 |
+
The number of images to generate per prompt.
|
| 700 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 701 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 702 |
+
to make generation deterministic.
|
| 703 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 704 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 705 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 706 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 707 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 708 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 709 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 710 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 711 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 712 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 713 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 714 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 715 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 716 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 717 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 718 |
+
negative_ip_adapter_image:
|
| 719 |
+
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 720 |
+
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 721 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 722 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
| 723 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 724 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 725 |
+
The output format of the generate image. Choose between
|
| 726 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 727 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 728 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 729 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 730 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 731 |
+
`self.processor` in
|
| 732 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 733 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 734 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 735 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 736 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 737 |
+
`callback_on_step_end_tensor_inputs`.
|
| 738 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 739 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 740 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 741 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 742 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 743 |
+
|
| 744 |
+
Examples:
|
| 745 |
+
|
| 746 |
+
Returns:
|
| 747 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 748 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 749 |
+
images.
|
| 750 |
+
"""
|
| 751 |
+
|
| 752 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 753 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 754 |
+
|
| 755 |
+
# 1. Check inputs. Raise error if not correct
|
| 756 |
+
self.check_inputs(
|
| 757 |
+
prompt,
|
| 758 |
+
prompt_2,
|
| 759 |
+
height,
|
| 760 |
+
width,
|
| 761 |
+
negative_prompt=negative_prompt,
|
| 762 |
+
negative_prompt_2=negative_prompt_2,
|
| 763 |
+
prompt_embeds=prompt_embeds,
|
| 764 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 765 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 766 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 767 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 768 |
+
max_sequence_length=max_sequence_length,
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
self._guidance_scale = guidance_scale
|
| 772 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 773 |
+
self._interrupt = False
|
| 774 |
+
|
| 775 |
+
# 2. Define call parameters
|
| 776 |
+
if prompt is not None and isinstance(prompt, str):
|
| 777 |
+
batch_size = 1
|
| 778 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 779 |
+
batch_size = len(prompt)
|
| 780 |
+
else:
|
| 781 |
+
batch_size = prompt_embeds.shape[0]
|
| 782 |
+
|
| 783 |
+
device = self._execution_device
|
| 784 |
+
|
| 785 |
+
lora_scale = (
|
| 786 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 787 |
+
)
|
| 788 |
+
#creatidesign
|
| 789 |
+
negative_prompt = negative_prompt if negative_prompt is not None else [""]*batch_size
|
| 790 |
+
|
| 791 |
+
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
|
| 792 |
+
(
|
| 793 |
+
prompt_embeds,
|
| 794 |
+
pooled_prompt_embeds,
|
| 795 |
+
text_ids,
|
| 796 |
+
) = self.encode_prompt(
|
| 797 |
+
prompt=prompt,
|
| 798 |
+
prompt_2=prompt_2,
|
| 799 |
+
prompt_embeds=prompt_embeds,
|
| 800 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 801 |
+
device=device,
|
| 802 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 803 |
+
max_sequence_length=max_sequence_length,
|
| 804 |
+
lora_scale=lora_scale,
|
| 805 |
+
)
|
| 806 |
+
if do_true_cfg:
|
| 807 |
+
(
|
| 808 |
+
negative_prompt_embeds,
|
| 809 |
+
negative_pooled_prompt_embeds,
|
| 810 |
+
_,
|
| 811 |
+
) = self.encode_prompt(
|
| 812 |
+
prompt=negative_prompt,
|
| 813 |
+
prompt_2=negative_prompt_2,
|
| 814 |
+
prompt_embeds=negative_prompt_embeds,
|
| 815 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 816 |
+
device=device,
|
| 817 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 818 |
+
max_sequence_length=max_sequence_length,
|
| 819 |
+
lora_scale=lora_scale,
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
# 4. Prepare latent variables
|
| 823 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 824 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 825 |
+
batch_size * num_images_per_prompt,
|
| 826 |
+
num_channels_latents,
|
| 827 |
+
height,
|
| 828 |
+
width,
|
| 829 |
+
prompt_embeds.dtype,
|
| 830 |
+
device,
|
| 831 |
+
generator,
|
| 832 |
+
latents,
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
# 5. Prepare timesteps
|
| 836 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
| 837 |
+
image_seq_len = latents.shape[1]
|
| 838 |
+
mu = calculate_shift(
|
| 839 |
+
image_seq_len,
|
| 840 |
+
self.scheduler.config.base_image_seq_len,
|
| 841 |
+
self.scheduler.config.max_image_seq_len,
|
| 842 |
+
self.scheduler.config.base_shift,
|
| 843 |
+
self.scheduler.config.max_shift,
|
| 844 |
+
)
|
| 845 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 846 |
+
self.scheduler,
|
| 847 |
+
num_inference_steps,
|
| 848 |
+
device,
|
| 849 |
+
sigmas=sigmas,
|
| 850 |
+
mu=mu,
|
| 851 |
+
)
|
| 852 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 853 |
+
self._num_timesteps = len(timesteps)
|
| 854 |
+
|
| 855 |
+
# handle guidance
|
| 856 |
+
if self.transformer.config.guidance_embeds:
|
| 857 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
| 858 |
+
guidance = guidance.expand(latents.shape[0])
|
| 859 |
+
else:
|
| 860 |
+
guidance = None
|
| 861 |
+
|
| 862 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
| 863 |
+
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
| 864 |
+
):
|
| 865 |
+
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 866 |
+
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
| 867 |
+
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
| 868 |
+
):
|
| 869 |
+
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
| 870 |
+
|
| 871 |
+
if self.joint_attention_kwargs is None:
|
| 872 |
+
self._joint_attention_kwargs = {}
|
| 873 |
+
|
| 874 |
+
image_embeds = None
|
| 875 |
+
negative_image_embeds = None
|
| 876 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 877 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 878 |
+
ip_adapter_image,
|
| 879 |
+
ip_adapter_image_embeds,
|
| 880 |
+
device,
|
| 881 |
+
batch_size * num_images_per_prompt,
|
| 882 |
+
)
|
| 883 |
+
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
| 884 |
+
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 885 |
+
negative_ip_adapter_image,
|
| 886 |
+
negative_ip_adapter_image_embeds,
|
| 887 |
+
device,
|
| 888 |
+
batch_size * num_images_per_prompt,
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
#creatidesign
|
| 892 |
+
objects_boxes = objects_boxes.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
|
| 893 |
+
objects_masks = objects_masks.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
|
| 894 |
+
objects_masks_maps = objects_masks_maps.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
|
| 895 |
+
subject_masks_maps = subject_masks_maps.to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0)
|
| 896 |
+
N = len(objects_caption[0])
|
| 897 |
+
print("N",N)
|
| 898 |
+
bbox_text_embeddings = torch.zeros(
|
| 899 |
+
max_boxes_per_image,max_sequence_length,4096, device=device, dtype=latents.dtype
|
| 900 |
+
)
|
| 901 |
+
if N>0:
|
| 902 |
+
bbox_text_embeddings_temp,_,_ = self.encode_prompt(prompt=objects_caption[0],prompt_2=None,device=device,
|
| 903 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 904 |
+
max_sequence_length=max_sequence_length,)
|
| 905 |
+
bbox_text_embeddings[:N]=bbox_text_embeddings_temp
|
| 906 |
+
bbox_text_embeddings = bbox_text_embeddings.unsqueeze(0).to(device=device, dtype=latents.dtype).repeat_interleave(batch_size, dim=0) #[b,10,30,4096]
|
| 907 |
+
|
| 908 |
+
# Convert condition images to latent space
|
| 909 |
+
condition_img = condition_img.to(device=device,dtype=self.vae.dtype).repeat_interleave(batch_size, dim=0)
|
| 910 |
+
condition_img_input = self.vae.encode(condition_img).latent_dist.sample()
|
| 911 |
+
condition_img_input = (condition_img_input - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 912 |
+
condition_img_input = condition_img_input.to(dtype=latents.dtype)
|
| 913 |
+
condition_latent_image_ids = self._prepare_latent_image_ids(
|
| 914 |
+
condition_img_input.shape[0],
|
| 915 |
+
condition_img_input.shape[2] // 2,
|
| 916 |
+
condition_img_input.shape[3] // 2,
|
| 917 |
+
device,
|
| 918 |
+
latents.dtype,
|
| 919 |
+
scale_h = scale_h,
|
| 920 |
+
scale_w = scale_w,
|
| 921 |
+
)
|
| 922 |
+
|
| 923 |
+
# shift condition image ids
|
| 924 |
+
|
| 925 |
+
if use_bucket:
|
| 926 |
+
# offset determined by condition image width and scale
|
| 927 |
+
condition_latent_image_ids[:, 1] += 0 # H dimension unchanged
|
| 928 |
+
condition_latent_image_ids[:, 2] += -1*(condition_img_input.shape[3]*scale_w//2)
|
| 929 |
+
else:
|
| 930 |
+
# shift condition image ids
|
| 931 |
+
condition_latent_image_ids[:, 1] += position_delta[0] # H dimension unchanged
|
| 932 |
+
condition_latent_image_ids[:, 2] += position_delta[1] # W dimension shift left
|
| 933 |
+
|
| 934 |
+
packed_clean_condition_input = self._pack_latents(
|
| 935 |
+
condition_img_input,
|
| 936 |
+
batch_size=condition_img_input.shape[0],
|
| 937 |
+
num_channels_latents=condition_img_input.shape[1],
|
| 938 |
+
height=condition_img_input.shape[2],
|
| 939 |
+
width=condition_img_input.shape[3],
|
| 940 |
+
)
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
design_kwargs = {
|
| 944 |
+
"object_layout": {"objects_boxes": objects_boxes, "bbox_text_embeddings": bbox_text_embeddings, "bbox_masks": objects_masks,"objects_masks_maps":objects_masks_maps,"img_token_h":(int(height) // (self.vae_scale_factor * 2)), "img_token_w":(int(width) // (self.vae_scale_factor * 2))}, #[b,10,4], [B,10,512,4096],[b,10]
|
| 945 |
+
"subject_contion":{"condition_img":packed_clean_condition_input,"subject_masks_maps":subject_masks_maps,"condition_img_ids":condition_latent_image_ids,"subject_token_h":condition_img_input.shape[2]//2, "subject_token_w":condition_img_input.shape[3]//2}, # [B,4,64,64]
|
| 946 |
+
}
|
| 947 |
+
|
| 948 |
+
neg_objects_masks = torch.zeros_like(objects_masks).to(device=device, dtype=latents.dtype)
|
| 949 |
+
|
| 950 |
+
neg_condtion_img = neg_condtion_img.to(device=device,dtype=self.vae.dtype).repeat_interleave(batch_size, dim=0)
|
| 951 |
+
neg_condtion_img_input = self.vae.encode(neg_condtion_img).latent_dist.sample()
|
| 952 |
+
neg_condtion_img_input = (neg_condtion_img_input - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 953 |
+
neg_condtion_img_input = neg_condtion_img_input.to(dtype=latents.dtype)
|
| 954 |
+
neg_condition_latent_image_ids = self._prepare_latent_image_ids(
|
| 955 |
+
neg_condtion_img_input.shape[0],
|
| 956 |
+
neg_condtion_img_input.shape[2] // 2,
|
| 957 |
+
neg_condtion_img_input.shape[3] // 2,
|
| 958 |
+
device,
|
| 959 |
+
latents.dtype,
|
| 960 |
+
scale_h = scale_h,
|
| 961 |
+
scale_w = scale_w
|
| 962 |
+
)
|
| 963 |
+
|
| 964 |
+
if use_bucket:
|
| 965 |
+
# offset determined by condition image width and scale
|
| 966 |
+
neg_condition_latent_image_ids[:, 1] += 0 # H dimension unchanged
|
| 967 |
+
neg_condition_latent_image_ids[:, 2] += -1*(condition_img_input.shape[3]*scale_w//2)
|
| 968 |
+
else:
|
| 969 |
+
# shift negative condition image ids
|
| 970 |
+
neg_condition_latent_image_ids[:, 1] += position_delta[0] # H dimension shift
|
| 971 |
+
neg_condition_latent_image_ids[:, 2] += position_delta[1] # W dimension shift
|
| 972 |
+
|
| 973 |
+
packed_clean_neg_condtion_input = self._pack_latents(
|
| 974 |
+
neg_condtion_img_input,
|
| 975 |
+
batch_size=neg_condtion_img_input.shape[0],
|
| 976 |
+
num_channels_latents=neg_condtion_img_input.shape[1],
|
| 977 |
+
height=neg_condtion_img_input.shape[2],
|
| 978 |
+
width=neg_condtion_img_input.shape[3],
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
neg_subject_masks_maps = subject_masks_maps
|
| 982 |
+
neg_objects_masks_maps = objects_masks_maps
|
| 983 |
+
neg_design_kwargs = {
|
| 984 |
+
"object_layout": {"objects_boxes": objects_boxes, "bbox_text_embeddings": bbox_text_embeddings, "bbox_masks": neg_objects_masks,"objects_masks_maps":neg_objects_masks_maps,"img_token_h":(int(height) // (self.vae_scale_factor * 2)), "img_token_w":(int(width) // (self.vae_scale_factor * 2))}, #[b,10,4], [B,10,512,4096],[b,10]
|
| 985 |
+
"subject_contion":{"condition_img":packed_clean_neg_condtion_input,"subject_masks_maps":neg_subject_masks_maps,"condition_img_ids":neg_condition_latent_image_ids,"subject_token_h":condition_img_input.shape[2]//2, "subject_token_w":condition_img_input.shape[3]//2}, # [B,4,64,64]
|
| 986 |
+
}
|
| 987 |
+
|
| 988 |
+
# 6. Denoising loop
|
| 989 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 990 |
+
for i, t in enumerate(timesteps):
|
| 991 |
+
if self.interrupt:
|
| 992 |
+
continue
|
| 993 |
+
|
| 994 |
+
if image_embeds is not None:
|
| 995 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
| 996 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 997 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
| 998 |
+
noise_pred = self.transformer(
|
| 999 |
+
hidden_states=latents,
|
| 1000 |
+
timestep=timestep / 1000,
|
| 1001 |
+
guidance=guidance,
|
| 1002 |
+
pooled_projections=pooled_prompt_embeds,
|
| 1003 |
+
encoder_hidden_states=prompt_embeds,
|
| 1004 |
+
txt_ids=text_ids,
|
| 1005 |
+
img_ids=latent_image_ids,
|
| 1006 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1007 |
+
return_dict=False,
|
| 1008 |
+
design_kwargs = design_kwargs,
|
| 1009 |
+
)[0]
|
| 1010 |
+
|
| 1011 |
+
if do_true_cfg:
|
| 1012 |
+
if negative_image_embeds is not None:
|
| 1013 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
| 1014 |
+
neg_noise_pred = self.transformer(
|
| 1015 |
+
hidden_states=latents,
|
| 1016 |
+
timestep=timestep / 1000,
|
| 1017 |
+
guidance=guidance,
|
| 1018 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
| 1019 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 1020 |
+
txt_ids=text_ids,
|
| 1021 |
+
img_ids=latent_image_ids,
|
| 1022 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
| 1023 |
+
return_dict=False,
|
| 1024 |
+
design_kwargs = neg_design_kwargs,
|
| 1025 |
+
)[0]
|
| 1026 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
| 1027 |
+
|
| 1028 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1029 |
+
latents_dtype = latents.dtype
|
| 1030 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 1031 |
+
|
| 1032 |
+
if latents.dtype != latents_dtype:
|
| 1033 |
+
if torch.backends.mps.is_available():
|
| 1034 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 1035 |
+
latents = latents.to(latents_dtype)
|
| 1036 |
+
|
| 1037 |
+
if callback_on_step_end is not None:
|
| 1038 |
+
callback_kwargs = {}
|
| 1039 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1040 |
+
callback_kwargs[k] = locals()[k]
|
| 1041 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1042 |
+
|
| 1043 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1044 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1045 |
+
|
| 1046 |
+
# call the callback, if provided
|
| 1047 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1048 |
+
progress_bar.update()
|
| 1049 |
+
|
| 1050 |
+
if XLA_AVAILABLE:
|
| 1051 |
+
xm.mark_step()
|
| 1052 |
+
|
| 1053 |
+
if output_type == "latent":
|
| 1054 |
+
image = latents
|
| 1055 |
+
|
| 1056 |
+
else:
|
| 1057 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 1058 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 1059 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 1060 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 1061 |
+
|
| 1062 |
+
# Offload all models
|
| 1063 |
+
self.maybe_free_model_hooks()
|
| 1064 |
+
|
| 1065 |
+
if not return_dict:
|
| 1066 |
+
return (image,)
|
| 1067 |
+
|
| 1068 |
+
return FluxPipelineOutput(images=image)
|
requirements.txt
CHANGED
|
@@ -1,6 +1,14 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diffusers
|
| 2 |
+
accelerate
|
| 3 |
+
transformers
|
| 4 |
+
sentencepiece
|
| 5 |
+
protobuf
|
| 6 |
+
bitsandbytes
|
| 7 |
+
prodigyopt
|
| 8 |
+
opencv-python
|
| 9 |
+
beautifulsoup4
|
| 10 |
+
xformers==0.0.27.post2
|
| 11 |
+
flash-attn
|
| 12 |
+
gradio
|
| 13 |
+
|
| 14 |
+
|
test_creatidesign_benchmark.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from random import uniform
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import time
|
| 7 |
+
from IPython.core.debugger import set_trace
|
| 8 |
+
from dataloader.creatidesign_dataset_benchmark import DesignDataset,visualize_bbox,collate_fn,tensor_to_pil,make_image_grid_RGB
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from safetensors.torch import save_file, load_file
|
| 12 |
+
from accelerate import load_checkpoint_and_dispatch
|
| 13 |
+
from modules.flux.transformer_flux_creatidesign import FluxTransformer2DModel
|
| 14 |
+
from pipeline.pipeline_flux_creatidesign import FluxPipeline
|
| 15 |
+
import json
|
| 16 |
+
from huggingface_hub import snapshot_download
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
|
| 19 |
+
if __name__ == "__main__":
|
| 20 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
weight_dtype = torch.bfloat16
|
| 22 |
+
resolution = 1024
|
| 23 |
+
condition_resolution = 512
|
| 24 |
+
neg_condition_image = 'same'
|
| 25 |
+
background_color = 'gray'
|
| 26 |
+
use_bucket = True
|
| 27 |
+
condition_resolution_scale_ratio=0.5
|
| 28 |
+
|
| 29 |
+
benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
|
| 30 |
+
|
| 31 |
+
datasets = DesignDataset(dataset_name=benchmark_repo,
|
| 32 |
+
resolution=resolution,
|
| 33 |
+
condition_resolution=condition_resolution,
|
| 34 |
+
neg_condition_image =neg_condition_image,
|
| 35 |
+
background_color=background_color,
|
| 36 |
+
use_bucket=use_bucket,
|
| 37 |
+
condition_resolution_scale_ratio=condition_resolution_scale_ratio
|
| 38 |
+
)
|
| 39 |
+
test_dataloader = DataLoader(datasets, batch_size=1, shuffle=False, num_workers=4,collate_fn=collate_fn)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
model_path = "black-forest-labs/FLUX.1-dev"
|
| 43 |
+
|
| 44 |
+
ckpt_repo = "HuiZhang0812/CreatiDesign" # huggingface repo of ckpt
|
| 45 |
+
|
| 46 |
+
ckpt_path = snapshot_download(
|
| 47 |
+
repo_id=ckpt_repo,
|
| 48 |
+
repo_type="model",
|
| 49 |
+
local_dir="./CreatiDesign_checkpoint",
|
| 50 |
+
local_dir_use_symlinks=False
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Load transformer config from checkpoint
|
| 54 |
+
with open(os.path.join(ckpt_path, "transformer", "config.json"), 'r') as f:
|
| 55 |
+
config = json.load(f)
|
| 56 |
+
|
| 57 |
+
transformer = FluxTransformer2DModel(**config)
|
| 58 |
+
transformer = load_checkpoint_and_dispatch(transformer, checkpoint=os.path.join(model_path,"transformer"), device_map=None)
|
| 59 |
+
|
| 60 |
+
# Load lora parameters using safetensors
|
| 61 |
+
state_dict = load_file(os.path.join(ckpt_path, "transformer","model.safetensors"))
|
| 62 |
+
|
| 63 |
+
# Load parameters, allow partial loading
|
| 64 |
+
missing_keys, unexpected_keys = transformer.load_state_dict(state_dict, strict=False)
|
| 65 |
+
|
| 66 |
+
print(f"Loaded parameters: {len(state_dict)}",state_dict.keys())
|
| 67 |
+
print(f"Missing keys: {len(missing_keys)}",missing_keys)
|
| 68 |
+
print(f"Unexpected keys: {len(unexpected_keys)}",unexpected_keys)
|
| 69 |
+
|
| 70 |
+
transformer = transformer.to(dtype=torch.bfloat16)
|
| 71 |
+
|
| 72 |
+
pipe = FluxPipeline.from_pretrained(model_path, transformer=transformer,torch_dtype=torch.bfloat16)
|
| 73 |
+
pipe = pipe.to("cuda")
|
| 74 |
+
|
| 75 |
+
seed=42
|
| 76 |
+
num_samples = 1
|
| 77 |
+
true_cfg_scale=3.5
|
| 78 |
+
guidance_scale=1.0
|
| 79 |
+
if resolution == 512:
|
| 80 |
+
position_delta=[0,-32]
|
| 81 |
+
else:
|
| 82 |
+
position_delta=[0,-64]
|
| 83 |
+
if use_bucket:
|
| 84 |
+
scale_h = 1/condition_resolution_scale_ratio
|
| 85 |
+
scale_w = 1/condition_resolution_scale_ratio
|
| 86 |
+
else:
|
| 87 |
+
scale_h = resolution/condition_resolution
|
| 88 |
+
scale_w = resolution/condition_resolution
|
| 89 |
+
|
| 90 |
+
num_inference_steps = 28
|
| 91 |
+
|
| 92 |
+
# Create save directory based on benchmark directory name
|
| 93 |
+
save_root =os.path.join("outputs",benchmark_repo.split("/")[-1])
|
| 94 |
+
os.makedirs(save_root,exist_ok=True)
|
| 95 |
+
|
| 96 |
+
img_save_root = os.path.join(save_root,"images")
|
| 97 |
+
os.makedirs(img_save_root,exist_ok=True)
|
| 98 |
+
|
| 99 |
+
img_withgt_save_root = os.path.join(save_root,"images_with_gt")
|
| 100 |
+
os.makedirs(img_withgt_save_root,exist_ok=True)
|
| 101 |
+
|
| 102 |
+
total_time = 0
|
| 103 |
+
for i, batch in enumerate(tqdm(test_dataloader)):
|
| 104 |
+
prompts = batch["caption"]
|
| 105 |
+
imgs_id = batch['id']
|
| 106 |
+
objects_boxes = batch["objects_boxes"]
|
| 107 |
+
objects_caption = batch['objects_caption']
|
| 108 |
+
objects_masks = batch['objects_masks']
|
| 109 |
+
condition_img = batch['condition_img']
|
| 110 |
+
neg_condtion_img = batch['neg_condtion_img']
|
| 111 |
+
objects_masks_maps= batch['objects_masks_maps']
|
| 112 |
+
subject_masks_maps = batch['condition_img_masks_maps']
|
| 113 |
+
target_width=batch['target_width'][0]
|
| 114 |
+
target_height=batch['target_height'][0]
|
| 115 |
+
|
| 116 |
+
img_info = batch["img_info"][0]
|
| 117 |
+
filename = img_info["img_id"]+'.jpg'
|
| 118 |
+
start_time = time.time()
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
images = pipe(prompt=prompts*num_samples,
|
| 121 |
+
generator=torch.Generator(device="cuda").manual_seed(seed),
|
| 122 |
+
num_inference_steps = num_inference_steps,
|
| 123 |
+
objects_boxes=objects_boxes,
|
| 124 |
+
objects_caption=objects_caption,
|
| 125 |
+
objects_masks = objects_masks,
|
| 126 |
+
objects_masks_maps=objects_masks_maps,
|
| 127 |
+
condition_img = condition_img,
|
| 128 |
+
subject_masks_maps = subject_masks_maps,
|
| 129 |
+
neg_condtion_img = neg_condtion_img,
|
| 130 |
+
height= target_height,
|
| 131 |
+
width = target_width,
|
| 132 |
+
true_cfg_scale = true_cfg_scale,
|
| 133 |
+
position_delta=position_delta,
|
| 134 |
+
guidance_scale=guidance_scale,
|
| 135 |
+
scale_h = scale_h,
|
| 136 |
+
scale_w = scale_w,
|
| 137 |
+
use_bucket=use_bucket
|
| 138 |
+
)
|
| 139 |
+
images=images.images
|
| 140 |
+
use_time = time.time() - start_time
|
| 141 |
+
total_time +=use_time
|
| 142 |
+
|
| 143 |
+
make_image_grid_RGB(images, rows=1, cols=num_samples).save(os.path.join(img_save_root,filename))
|
| 144 |
+
use_time = time.time() - start_time
|
| 145 |
+
total_time +=use_time
|
| 146 |
+
|
| 147 |
+
# Process original image and bounding boxes
|
| 148 |
+
ori_image = tensor_to_pil(batch['img'][0])
|
| 149 |
+
orig_width, orig_height = ori_image.size
|
| 150 |
+
normalized_boxes = batch['objects_boxes'][0].cpu().numpy()
|
| 151 |
+
denormalized_boxes = []
|
| 152 |
+
for box in normalized_boxes:
|
| 153 |
+
x1, y1, x2, y2 = box
|
| 154 |
+
denorm_box = [
|
| 155 |
+
x1 * orig_width, # x1
|
| 156 |
+
y1 * orig_height, # y1
|
| 157 |
+
x2 * orig_width, # x2
|
| 158 |
+
y2 * orig_height # y2
|
| 159 |
+
]
|
| 160 |
+
denormalized_boxes.append(denorm_box)
|
| 161 |
+
|
| 162 |
+
objects_result = {
|
| 163 |
+
"boxes": denormalized_boxes,
|
| 164 |
+
"labels": batch['objects_caption'][0],
|
| 165 |
+
"masks": []
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
# Only keep boxes and captions where mask is 1
|
| 169 |
+
valid_boxes = []
|
| 170 |
+
valid_labels = []
|
| 171 |
+
for box, label, mask in zip(objects_result['boxes'],
|
| 172 |
+
objects_result['labels'],
|
| 173 |
+
batch['objects_masks'][0]):
|
| 174 |
+
if mask:
|
| 175 |
+
valid_boxes.append(box)
|
| 176 |
+
valid_labels.append(label)
|
| 177 |
+
|
| 178 |
+
objects_result['boxes'] = valid_boxes
|
| 179 |
+
objects_result['labels'] = valid_labels
|
| 180 |
+
|
| 181 |
+
ori_image_with_bbox = visualize_bbox(ori_image ,objects_result)
|
| 182 |
+
|
| 183 |
+
# Concatenate images
|
| 184 |
+
total_width = ori_image.width + ori_image.width+ num_samples*ori_image.width
|
| 185 |
+
max_height = ori_image.height
|
| 186 |
+
|
| 187 |
+
# Create a new blank image to hold the concatenated images
|
| 188 |
+
new_image = Image.new('RGB', (total_width, max_height))
|
| 189 |
+
|
| 190 |
+
new_image.paste(ori_image_with_bbox, (0, 0))
|
| 191 |
+
|
| 192 |
+
# Process condition image
|
| 193 |
+
condition_img = tensor_to_pil(batch['original_size_condition_img'][0])
|
| 194 |
+
subject_canvas_with_bbox = visualize_bbox(condition_img ,objects_result)
|
| 195 |
+
|
| 196 |
+
new_image.paste(subject_canvas_with_bbox, (ori_image.width, 0))
|
| 197 |
+
|
| 198 |
+
# Paste generated images
|
| 199 |
+
for j, image in enumerate(images):
|
| 200 |
+
|
| 201 |
+
save_name=os.path.join(img_withgt_save_root,filename)
|
| 202 |
+
|
| 203 |
+
image_with_bbox = visualize_bbox(image ,objects_result)
|
| 204 |
+
|
| 205 |
+
new_image.paste(image_with_bbox, (ori_image.width*(j+2), 0))
|
| 206 |
+
|
| 207 |
+
new_image.save(save_name)
|
| 208 |
+
|
| 209 |
+
print(f"Total inference time: {total_time:.2f} seconds")
|
| 210 |
+
print(f"Average time per image: {total_time/len(test_dataloader):.2f} seconds")
|