Yixuan committed on
Commit d234621 · 0 Parent(s):

update readme

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +42 -0
  3. .gitignore +3 -0
  4. LICENSE +21 -0
  5. README-copy.md +308 -0
  6. README.md +420 -0
  7. assets/images/logo-text-2.png +3 -0
  8. assets/images/logo-text.png +3 -0
  9. assets/images/logo.png +3 -0
  10. diffsynth/__init__.py +6 -0
  11. diffsynth/configs/__init__.py +0 -0
  12. diffsynth/configs/model_config.py +778 -0
  13. diffsynth/controlnets/__init__.py +2 -0
  14. diffsynth/controlnets/controlnet_unit.py +91 -0
  15. diffsynth/controlnets/processors.py +62 -0
  16. diffsynth/data/__init__.py +1 -0
  17. diffsynth/data/simple_text_image.py +41 -0
  18. diffsynth/data/video.py +148 -0
  19. diffsynth/extensions/ESRGAN/__init__.py +137 -0
  20. diffsynth/extensions/FastBlend/__init__.py +63 -0
  21. diffsynth/extensions/FastBlend/api.py +397 -0
  22. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  23. diffsynth/extensions/FastBlend/data.py +146 -0
  24. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  25. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  26. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  27. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  28. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  29. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  30. diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py +1 -0
  31. diffsynth/extensions/ImageQualityMetric/BLIP/blip.py +77 -0
  32. diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py +44 -0
  33. diffsynth/extensions/ImageQualityMetric/BLIP/med.py +947 -0
  34. diffsynth/extensions/ImageQualityMetric/BLIP/vit.py +301 -0
  35. diffsynth/extensions/ImageQualityMetric/__init__.py +148 -0
  36. diffsynth/extensions/ImageQualityMetric/aesthetic.py +148 -0
  37. diffsynth/extensions/ImageQualityMetric/clip.py +97 -0
  38. diffsynth/extensions/ImageQualityMetric/config.py +23 -0
  39. diffsynth/extensions/ImageQualityMetric/hps.py +118 -0
  40. diffsynth/extensions/ImageQualityMetric/imagereward.py +212 -0
  41. diffsynth/extensions/ImageQualityMetric/mps.py +129 -0
  42. diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py +14 -0
  43. diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py +458 -0
  44. diffsynth/extensions/ImageQualityMetric/open_clip/constants.py +2 -0
  45. diffsynth/extensions/ImageQualityMetric/open_clip/factory.py +433 -0
  46. diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py +0 -0
  47. diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py +45 -0
  48. diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py +176 -0
  49. diffsynth/extensions/ImageQualityMetric/open_clip/loss.py +270 -0
  50. diffsynth/extensions/ImageQualityMetric/open_clip/model.py +461 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes ADDED
@@ -0,0 +1,42 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/images/logo-text.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/images/logo.png filter=lfs diff=lfs merge=lfs -text
38
+ diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
39
+ diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
40
+ examples/output_videos/output_moe_framepack_sliding.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ logo-text-2.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/images/logo-text-2.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ # Ignore all checkpoint files
2
+ *.ckpt
3
+ *.ckpt.*
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Yixuan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README-copy.md ADDED
@@ -0,0 +1,308 @@
1
+ # ReCamMaster: Camera-Controlled Generative Rendering from A Single Video (ICCV'25 Oral, Best Paper Finalist)
2
+
3
+ <div align="center">
4
+ <div align="center" style="margin-top: 0px; margin-bottom: 0px;">
5
+ <img src=https://github.com/user-attachments/assets/81ccf80e-f4b6-4a3d-b47a-e9c2ce14e34f width="30%"/>
6
+ </div>
7
+
8
+ ### [<a href="https://arxiv.org/abs/2503.11647" target="_blank">arXiv</a>] [<a href="https://jianhongbai.github.io/ReCamMaster/" target="_blank">Project Page</a>] [<a href="https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset" target="_blank">Dataset</a>]
9
+ _**[Jianhong Bai<sup>1*</sup>](https://jianhongbai.github.io/), [Menghan Xia<sup>2†</sup>](https://menghanxia.github.io/), [Xiao Fu<sup>3</sup>](https://fuxiao0719.github.io/), [Xintao Wang<sup>2</sup>](https://xinntao.github.io/), [Lianrui Mu<sup>1</sup>](https://scholar.google.com/citations?user=dCik-2YAAAAJ&hl=en), [Jinwen Cao<sup>2</sup>](https://openreview.net/profile?id=~Jinwen_Cao1), <br>[Zuozhu Liu<sup>1</sup>](https://person.zju.edu.cn/en/lzz), [Haoji Hu<sup>1†</sup>](https://person.zju.edu.cn/en/huhaoji), [Xiang Bai<sup>4</sup>](https://scholar.google.com/citations?user=UeltiQ4AAAAJ&hl=en), [Pengfei Wan<sup>2</sup>](https://scholar.google.com/citations?user=P6MraaYAAAAJ&hl=en), [Di Zhang<sup>2</sup>](https://openreview.net/profile?id=~Di_ZHANG3)**_
10
+ <br>
11
+ (*Work done during an internship at KwaiVGI, Kuaishou Technology †corresponding authors)
12
+
13
+ <sup>1</sup>Zhejiang University, <sup>2</sup>Kuaishou Technology, <sup>3</sup>CUHK, <sup>4</sup>HUST.
14
+
15
+ </div>
16
+
17
+ **Important Note:** This open-source repository is intended to provide a reference implementation. Due to the difference in the underlying T2V model's performance, the open-source version may not achieve the same performance as the model in our paper. If you'd like to use the best version of ReCamMaster, please upload your video to [this link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog). Additionally, we are working on developing an online trial website. Please stay tuned to updates on the [Kling website](https://app.klingai.com/global/).
18
+
19
+ ## 🔥 Updates
20
+ - __[2025.04.15]__: Please feel free to explore our related work, [SynCamMaster](https://github.com/KwaiVGI/SynCamMaster).
21
+ - __[2025.04.09]__: Release the [training and inference code](https://github.com/KwaiVGI/ReCamMaster?tab=readme-ov-file#%EF%B8%8F-code-recammaster--wan21-inference--training), [model checkpoint](https://huggingface.co/KwaiVGI/ReCamMaster-Wan2.1/blob/main/step20000.ckpt).
22
+ - __[2025.03.31]__: Release the [MultiCamVideo Dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset).
23
+ - __[2025.03.31]__: We have sent the inference results to the first 1000 trial users.
24
+ - __[2025.03.17]__: Release the [project page](https://jianhongbai.github.io/ReCamMaster/) and the [try out link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog).
25
+
26
+
27
+
28
+
29
+ ## 📖 Introduction
30
+
31
+ **TL;DR:** We propose ReCamMaster to re-capture in-the-wild videos with novel camera trajectories, achieved through our proposed simple-and-effective video conditioning scheme. We also release a multi-camera synchronized video [dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset) rendered with Unreal Engine 5. <br>
32
+
33
+ https://github.com/user-attachments/assets/52455e86-1adb-458d-bc37-4540a65a60d4
34
+
35
+ ## 🚀 Trial: Try ReCamMaster with Your Own Videos
36
+
37
+ **Update:** We are actively processing the videos uploaded by users. So far, we have sent the inference results to the email addresses of the first **1500** testers. You should receive an email titled "Inference Results of ReCamMaster" from either [email protected] or [email protected]. Please also check your spam folder, and let us know if you haven't received the email after a long time. If you enjoyed the videos we created, please consider giving us a star 🌟.
38
+
39
+ **You can try out our ReCamMaster by uploading your own video to [this link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog), which will generate a video with camera movements along a new trajectory.** We will send the mp4 file generated by ReCamMaster to your inbox as soon as possible. For camera movement trajectories, we offer 10 basic camera trajectories as follows:
40
+
41
+ | Index | Basic Trajectory |
42
+ |-------------------|-----------------------------|
43
+ | 1 | Pan Right |
44
+ | 2 | Pan Left |
45
+ | 3 | Tilt Up |
46
+ | 4 | Tilt Down |
47
+ | 5 | Zoom In |
48
+ | 6 | Zoom Out |
49
+ | 7 | Translate Up (with rotation) |
50
+ | 8 | Translate Down (with rotation) |
51
+ | 9 | Arc Left (with rotation) |
52
+ | 10 | Arc Right (with rotation) |
53
+
54
+ If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [[email protected]](mailto:[email protected]). We can assist you with batch inference of our model.
55
+
56
+ ## ⚙️ Code: ReCamMaster + Wan2.1 (Inference & Training)
57
+ The model utilized in our paper is an internally developed T2V model, not [Wan2.1](https://github.com/Wan-Video/Wan2.1). Due to company policy restrictions, we are unable to open-source the model used in the paper. Consequently, we migrated ReCamMaster to Wan2.1 to validate the effectiveness of our method. Due to differences in the underlying T2V model, you may not achieve the same results as demonstrated in the demo.
58
+ ### Inference
59
+ Step 1: Set up the environment
60
+
61
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) requires Rust and Cargo to compile extensions. You can install them using the following command:
62
+ ```shell
63
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
64
+ . "$HOME/.cargo/env"
65
+ ```
66
+
67
+ Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):
68
+ ```shell
69
+ git clone https://github.com/KwaiVGI/ReCamMaster.git
70
+ cd ReCamMaster
71
+ pip install -e .
72
+ ```
73
+
74
+ Step 2: Download the pretrained checkpoints
75
+ 1. Download the pre-trained Wan2.1 models
76
+
77
+ ```shell
78
+ cd ReCamMaster
79
+ python download_wan2.1.py
80
+ ```
81
+ 2. Download the pre-trained ReCamMaster checkpoint
82
+
83
+ Please download from [huggingface](https://huggingface.co/KwaiVGI/ReCamMaster-Wan2.1/blob/main/step20000.ckpt) and place it in ```models/ReCamMaster/checkpoints```.
84
+
85
+ Step 3: Test the example videos
86
+ ```shell
87
+ python inference_recammaster.py --cam_type 1
88
+ ```
89
+
90
+ Step 4: Test your own videos
91
+
92
+ If you want to test your own videos, you need to prepare your test data following the structure of the ```example_test_data``` folder. This includes N mp4 videos, each with at least 81 frames, and a ```metadata.csv``` file that stores their paths and corresponding captions. You can refer to the [Prompt Extension section](https://github.com/Wan-Video/Wan2.1?tab=readme-ov-file#2-using-prompt-extension) in Wan2.1 for guidance on preparing video captions.
93
+
94
+ ```shell
95
+ python inference_recammaster.py --cam_type 1 --dataset_path path/to/your/data
96
+ ```
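+
+ As a rough illustration, the snippet below writes a minimal ```metadata.csv``` for a folder of test videos. The column names used here (```file_name```, ```text```) are an assumption for this sketch only; copy the exact column names from ```example_test_data/metadata.csv```.
+
+ ```python
+ # Sketch: build metadata.csv for a folder of test videos.
+ # NOTE: the column names below are assumptions; mirror example_test_data/metadata.csv.
+ from pathlib import Path
+ import pandas as pd
+
+ data_dir = Path("path/to/your/data")
+ rows = [
+     {"file_name": p.name, "text": "a caption describing this video"}
+     for p in sorted(data_dir.glob("*.mp4"))
+ ]
+ pd.DataFrame(rows).to_csv(data_dir / "metadata.csv", index=False)
+ ```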
97
+
98
+ We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
99
+
100
+ | cam_type | Trajectory |
101
+ |-------------------|-----------------------------|
102
+ | 1 | Pan Right |
103
+ | 2 | Pan Left |
104
+ | 3 | Tilt Up |
105
+ | 4 | Tilt Down |
106
+ | 5 | Zoom In |
107
+ | 6 | Zoom Out |
108
+ | 7 | Translate Up (with rotation) |
109
+ | 8 | Translate Down (with rotation) |
110
+ | 9 | Arc Left (with rotation) |
111
+ | 10 | Arc Right (with rotation) |
112
+
113
+ ### Training
114
+
115
+ Step 1: Set up the environment
116
+
117
+ ```shell
118
+ pip install lightning pandas websockets
119
+ ```
120
+
121
+ Step 2: Prepare the training dataset
122
+
123
+ 1. Download the [MultiCamVideo dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset).
124
+
125
+ 2. Extract VAE features
126
+
127
+ ```shell
128
+ CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task data_process --dataset_path path/to/the/MultiCamVideo/Dataset --output_path ./models --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" --tiled --num_frames 81 --height 480 --width 832 --dataloader_num_workers 2
129
+ ```
130
+
131
+ 3. Generate Captions for Each Video
132
+
133
+ You can use video caption tools like [LLaVA](https://github.com/haotian-liu/LLaVA) to generate captions for each video and store them in the ```metadata.csv``` file.
134
+
135
+ Step 3: Training
136
+ ```shell
137
+ CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task train --dataset_path recam_train_data --output_path ./models/train --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" --steps_per_epoch 8000 --max_epochs 100 --learning_rate 1e-4 --accumulate_grad_batches 1 --use_gradient_checkpointing --dataloader_num_workers 4
138
+ ```
139
+ We did not explore the optimal set of hyper-parameters; the model is trained with a batch size of 1 on each GPU. You may achieve better performance by adjusting hyper-parameters such as the learning rate and by increasing the batch size.
140
+
141
+ Step 4: Test the model
142
+
143
+ ```shell
144
+ python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
145
+ ```
146
+
147
+ ## 📷 Dataset: MultiCamVideo Dataset
148
+ ### 1. Dataset Introduction
149
+
150
+ **TL;DR:** The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories. The MultiCamVideo Dataset can be valuable in fields such as camera-controlled video generation, synchronized video production, and 3D/4D reconstruction. If you are looking for synchronized videos captured with stationary cameras, please explore our [SynCamVideo Dataset](https://github.com/KwaiVGI/SynCamMaster).
151
+
152
+ https://github.com/user-attachments/assets/6fa25bcf-1136-43be-8110-b527638874d4
153
+
154
+ The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories.
155
+ It consists of 13.6K different dynamic scenes, each captured by 10 cameras, resulting in a total of 136K videos. Each dynamic scene is composed of four elements: {3D environment, character, animation, camera}. Specifically, we use animation to drive the character,
156
+ and position the animated character within the 3D environment. Then, time-synchronized cameras are set up to move along predefined trajectories to render the multi-camera video data.
157
+ <p align="center">
158
+ <img src="https://github.com/user-attachments/assets/107c9607-e99b-4493-b715-3e194fcb3933" alt="Example Image" width="70%">
159
+ </p>
160
+
161
+ **3D Environment:** We collect 37 high-quality 3D environment assets from [Fab](https://www.fab.com). To minimize the domain gap between rendered data and real-world videos, we primarily select visually realistic 3D scenes, while choosing a few stylized or surreal 3D scenes as a supplement. To ensure data diversity, the selected scenes cover a variety of indoor and outdoor settings, such as city streets, shopping malls, cafes, office rooms, and the countryside.
162
+
163
+ **Character:** We collect 66 different human 3D models as characters from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com).
164
+
165
+ **Animation:** We collect 93 different animations from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com), including common actions such as waving, dancing, and cheering. We use these animations to drive the collected characters and create diverse datasets through various combinations.
166
+
167
+ **Camera:** To ensure camera movements are diverse and closely resemble real-world distributions, we create a wide range of camera trajectories and parameters to cover various situations. We achieve this by designing rules to batch-generate random camera starting positions and movement trajectories:
168
+
169
+ 1. Camera Starting Position.
170
+
171
+ We take the character's position as the center of a hemisphere with a radius of {3m, 5m, 7m, 10m} based on the size of the 3D scene and randomly sample within this range as the camera's starting point, ensuring the closest distance to the character is greater than 0.5m and the pitch angle is within 45 degrees.
172
+
173
+ 2. Camera Trajectories.
174
+
175
+ - **Pan & Tilt**:
176
+ The camera rotation angles are randomly sampled, with pan angles ranging from 5 to 45 degrees and tilt angles from 5 to 30 degrees; the direction is randomly chosen as left/right (pan) or up/down (tilt).
177
+
178
+ - **Basic Translation**:
179
+ The camera translates along the positive and negative directions of the xyz axes, with movement distances randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$.
180
+
181
+ - **Basic Arc Trajectory**:
182
+ The camera moves along an arc, with rotation angles randomly selected within the range of 15 to 75 degrees.
183
+
184
+ - **Random Trajectories**:
185
+ 1-3 points are sampled in space, and the camera moves from the initial position through these points as the movement trajectory, with the total movement distance randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$. The polyline is smoothed to make the movement more natural.
186
+
187
+ - **Static Camera**:
188
+ The camera does not translate or rotate during shooting, maintaining a fixed position.
189
+
190
+ 3. Camera Movement Speed.
191
+
192
+ To further enhance the diversity of trajectories, 50% of the training data uses constant-speed camera trajectories, while the other 50% uses variable-speed trajectories generated by nonlinear functions. Consider a camera trajectory with a total of $f$ frames, starting at location $L_{start}$ and ending at position $L_{end}$. The location at the $i$-th frame is given by:
193
+ ```math
194
+ L_i = L_{start} + (L_{end} - L_{start}) \cdot \left( \frac{1 - \exp(-a \cdot i/f)}{1 - \exp(-a)} \right),
195
+ ```
196
+ where $a$ is an adjustable parameter to control the trajectory speed. When $a > 0$, the trajectory starts fast and then slows down; when $a < 0$, the trajectory starts slow and then speeds up. The larger the absolute value of $a$, the more drastic the change (see the short sketch after this list).
197
+
198
+ 4. Camera Parameters.
199
+
200
+ We use four sets of camera parameters: {focal=18mm, aperture=10}, {focal=24mm, aperture=5}, {focal=35mm, aperture=2.4}, and {focal=50mm, aperture=2.4}.
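+
+ The following is a small, illustrative sketch of the variable-speed interpolation described in item 3 above; it reproduces the formula only, not the actual data-generation code.
+
+ ```python
+ # Illustrative only: variable-speed interpolation between L_start and L_end.
+ import numpy as np
+
+ def camera_positions(L_start, L_end, f, a):
+     """Camera position at frames i = 1..f following the ease curve controlled by a.
+
+     a > 0: fast start, slow end; a < 0: slow start, fast end (a must be non-zero).
+     """
+     L_start, L_end = np.asarray(L_start, float), np.asarray(L_end, float)
+     i = np.arange(1, f + 1)
+     w = (1 - np.exp(-a * i / f)) / (1 - np.exp(-a))  # grows monotonically to 1 at i = f
+     return L_start + (L_end - L_start) * w[:, None]
+
+ positions = camera_positions([0, 0, 0], [1, 0, 2], f=81, a=3.0)  # (81, 3) array
+ ```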
201
+
202
+ ### 2. Statistics and Configurations
203
+
204
+ Dataset Statistics:
205
+
206
+ | Number of Dynamic Scenes | Camera per Scene | Total Videos |
207
+ |:------------------------:|:----------------:|:------------:|
208
+ | 13,600 | 10 | 136,000 |
209
+
210
+ Video Configurations:
211
+
212
+ | Resolution | Frame Number | FPS |
213
+ |:-----------:|:------------:|:------------------------:|
214
+ | 1280x1280 | 81 | 15 |
215
+
216
+ Note: You can use 'center crop' to adjust the video's aspect ratio to fit your video generation model, such as 16:9, 9:16, 4:3, or 3:4.
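+
+ As an illustration, a per-frame center crop of the 1280x1280 frames to a target aspect ratio could look like the sketch below (Pillow-based; not part of the dataset tooling, and the file name is only a placeholder).
+
+ ```python
+ # Sketch: center-crop a square 1280x1280 frame to a target aspect ratio (e.g. 16:9 -> 1280x720).
+ from PIL import Image
+
+ def center_crop(img: Image.Image, aspect_w: int, aspect_h: int) -> Image.Image:
+     w, h = img.size
+     target_w = min(w, int(h * aspect_w / aspect_h))
+     target_h = min(h, int(w * aspect_h / aspect_w))
+     left, top = (w - target_w) // 2, (h - target_h) // 2
+     return img.crop((left, top, left + target_w, top + target_h))
+
+ frame = Image.open("frame.png")         # placeholder: one frame extracted from cam01.mp4
+ frame_16_9 = center_crop(frame, 16, 9)  # 1280x1280 -> 1280x720
+ ```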
217
+
218
+ Camera Configurations:
219
+
220
+ | Focal Length | Aperture | Sensor Height | Sensor Width |
221
+ |:-----------------------:|:------------------:|:-------------:|:------------:|
222
+ | 18mm, 24mm, 35mm, 50mm | 10.0, 5.0, 2.4 | 23.76mm | 23.76mm |
223
+
224
+
225
+
226
+ ### 3. File Structure
227
+ ```
228
+ MultiCamVideo-Dataset
229
+ ├── train
230
+ │ ├── f18_aperture10
231
+ │ │ ├── scene1 # one dynamic scene
232
+ │ │ │ ├── videos
233
+ │ │ │ │ ├── cam01.mp4 # synchronized 81-frame videos at 1280x1280 resolution
234
+ │ │ │ │ ├── cam02.mp4
235
+ │ │ │ │ ├── ...
236
+ │ │ │ │ └── cam10.mp4
237
+ │ │ │ └── cameras
238
+ │ │ │ └── camera_extrinsics.json # 81-frame camera extrinsics of the 10 cameras
239
+ │ │ ├── ...
240
+ │ │ └── scene3400
241
+ │ ├── f24_aperture5
242
+ │ │ ├── scene1
243
+ │ │ ├── ...
244
+ │ │ └── scene3400
245
+ │ ├── f35_aperture2.4
246
+ │ │ ├── scene1
247
+ │ │ ├── ...
248
+ │ │ └── scene3400
249
+ │ └── f50_aperture2.4
250
+ │ ├── scene1
251
+ │ ├── ...
252
+ │ └── scene3400
253
+ └── val
254
+ └── 10basic_trajectories
255
+ ├── videos
256
+ │ ├── cam01.mp4 # example videos corresponding to the validation cameras
257
+ │ ├── cam02.mp4
258
+ │ ├── ...
259
+ │ └── cam10.mp4
260
+ └── cameras
261
+ └── camera_extrinsics.json # 10 different trajectories for validation
262
+ ```
263
+
264
+ ### 4. Useful Scripts
265
+ - Data Extraction
266
+ ```bash
267
+ cat MultiCamVideo-Dataset.part* > MultiCamVideo-Dataset.tar.gz
268
+ tar -xzvf MultiCamVideo-Dataset.tar.gz
269
+ ```
270
+ - Camera Visualization
271
+ ```bash
272
+ python vis_cam.py
273
+ ```
274
+
275
+ The visualization script is modified from [CameraCtrl](https://github.com/hehao13/CameraCtrl/blob/main/tools/visualize_trajectory.py), thanks for their inspiring work.
276
+
277
+ <p align="center">
278
+ <img src="https://github.com/user-attachments/assets/f9cf342d-2fb3-40ef-a7be-edafb5775004" alt="Example Image" width="40%">
279
+ </p>
280
+
281
+ ## 🤗 Awesome Related Works
282
+ Feel free to explore these outstanding related works, including but not limited to:
283
+
284
+ [GCD](https://gcd.cs.columbia.edu/): GCD synthesizes large-angle novel viewpoints of 4D dynamic scenes from a monocular video.
285
+
286
+ [ReCapture](https://generative-video-camera-controls.github.io/): a method for generating new videos with novel camera trajectories from a single user-provided video.
287
+
288
+ [Trajectory Attention](https://xizaoqu.github.io/trajattn/): Trajectory Attention facilitates various tasks like camera motion control on images and videos, and video editing.
289
+
290
+ [GS-DiT](https://wkbian.github.io/Projects/GS-DiT/): GS-DiT provides 4D video control for a single monocular video.
291
+
292
+ [Diffusion as Shader](https://igl-hkust.github.io/das/): a versatile video generation control model for various tasks.
293
+
294
+ [TrajectoryCrafter](https://trajectorycrafter.github.io/): TrajectoryCrafter achieves high-fidelity novel views generation from casually captured monocular video.
295
+
296
+ [GEN3C](https://research.nvidia.com/labs/toronto-ai/GEN3C/): a generative video model with precise Camera Control and temporal 3D Consistency.
297
+
298
+ ## 🌟 Citation
299
+
300
+ Please leave us a star 🌟 and cite our paper if you find our work helpful.
301
+ ```
302
+ @article{bai2025recammaster,
303
+ title={ReCamMaster: Camera-Controlled Generative Rendering from A Single Video},
304
+ author={Bai, Jianhong and Xia, Menghan and Fu, Xiao and Wang, Xintao and Mu, Lianrui and Cao, Jinwen and Liu, Zuozhu and Hu, Haoji and Bai, Xiang and Wan, Pengfei and others},
305
+ journal={arXiv preprint arXiv:2503.11647},
306
+ year={2025}
307
+ }
308
+ ```
README.md ADDED
@@ -0,0 +1,420 @@
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - video-generation
5
+ - diffusion
6
+ - world-model
7
+ library_name: diffusion
8
+ model-index:
9
+ - name: Astra
10
+ results: []
11
+ ---
12
+
13
+ # Astra 🌏: General Interactive World Model with Autoregressive Denoising
14
+
15
+ <div align="center">
16
+
17
+ <div style="margin-top: 0; margin-bottom: -20px;">
18
+ <img src="./assets/images/logo-text-2.png" width="50%" />
19
+ </div>
20
+
21
+ <h3 style="margin-top: 0;">
22
+ 📄
23
+ [<a href="https://arxiv.org/abs/2512.08931" target="_blank">arXiv</a>]
24
+ &nbsp;&nbsp;
25
+ 🏠
26
+ [<a href="https://eternalevan.github.io/Astra-project/" target="_blank">Project Page</a>]
27
+ &nbsp;&nbsp;
28
+ 🖥️
29
+ [<a href="https://huggingface.co/EvanEternal/Astra" target="_blank">Github</a>]
30
+ </h3>
31
+
32
+ </div>
33
+
34
+
35
+ <!-- <div align="center">
36
+ <a href="https://hunyuan.tencent.com/video/zh?tabIndex=0" target="_blank"><img src=https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage height=22px></a>
37
+ <a href=https://huggingface.co/tencent/HunyuanVideo-1.5 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg height=22px></a>
38
+ <a href=https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5 target="_blank"><img src= https://img.shields.io/badge/Page-bb8a2e.svg?logo=github height=22px></a>
39
+ <a href="https://arxiv.org/pdf/2511.18870" target="_blank"><img src=https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv height=22px></a>
40
+ <a href=https://x.com/TencentHunyuan target="_blank"><img src=https://img.shields.io/badge/Hunyuan-black.svg?logo=x height=22px></a>
41
+ <a href="https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/assets/HunyuanVideo_1_5_Prompt_Handbook_EN.md" target="_blank"><img src=https://img.shields.io/badge/📚-PromptHandBook-blue.svg?logo=book height=22px></a> <br/>
42
+ <a href="./ComfyUI/README.md" target="_blank"><img src=https://img.shields.io/badge/ComfyUI-blue.svg?logo=book height=22px></a>
43
+ <a href="https://github.com/ModelTC/LightX2V" target="_blank"><img src=https://img.shields.io/badge/LightX2V-yellow.svg?logo=book height=22px></a>
44
+ <a href="https://tusi.cn/models/933574988890423836" target="_blank"><img src=https://img.shields.io/badge/吐司-purple.svg?logo=book height=22px></a>
45
+ <a href="https://tensor.art/models/933574988890423836" target="_blank"><img src=https://img.shields.io/badge/TensorArt-cyan.svg?logo=book height=22px></a>
46
+
47
+ </div> -->
48
+
49
+ <div align="center">
50
+
51
+ **[Yixuan Zhu<sup>1</sup>](https://eternalevan.github.io/), [Jiaqi Feng<sup>1</sup>](https://github.com/Aurora-edu/), [Wenzhao Zheng<sup>1 †</sup>](https://wzzheng.net), [Yuan Gao<sup>2</sup>](https://openreview.net/profile?id=~Yuan_Gao32), [Xin Tao<sup>2</sup>](https://www.xtao.website), [Pengfei Wan<sup>2</sup>](https://scholar.google.com/citations?user=P6MraaYAAAAJ&hl=en), [Jie Zhou <sup>1</sup>](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1), [Jiwen Lu<sup>1</sup>](https://ivg.au.tsinghua.edu.cn/Jiwen_Lu/)**
52
+ <!-- <br> -->
53
+ († Project leader)
54
+
55
+ <sup>1</sup>Tsinghua University, <sup>2</sup>Kuaishou Technology.
56
+ </div>
57
+
58
+
59
+ ## 📖 Introduction
60
+
61
+ **TL;DR:** Astra is an **interactive world model** that delivers realistic long-horizon video rollouts under a wide range of scenarios and action inputs.
62
+
63
+ **Astra** is an **interactive**, action-driven world model that predicts long-horizon future videos across diverse real-world scenarios. Built on an autoregressive diffusion transformer with temporal causal attention, Astra supports **streaming prediction** while preserving strong temporal coherence. Astra introduces **noise-augmented history memory** to stabilize long rollouts, an **action-aware adapter** for precise control signals, and a **mixture of action experts** to route heterogeneous action modalities. Through these key innovations, Astra delivers consistent, controllable, and high-fidelity video futures for applications such as autonomous driving, robot manipulation, and camera motion.
64
+
65
+ <!-- <div align="center">
66
+ <img src="./assets/images/pipeline.png" alt="Astra Pipeline" width="90%">
67
+ </div> -->
68
+
69
+ ## Gallery
70
+
71
+ ### Astra+Wan2.1
72
+
73
+ <table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
74
+ <tr>
75
+ <td>
76
+ <video src="https://github.com/user-attachments/assets/715a5b66-3966-4923-aa00-02315fb07761"
77
+ style="width:100%; object-fit:cover;"
78
+ controls autoplay loop muted></video>
79
+ </td>
80
+ <td>
81
+ <video src="https://github.com/user-attachments/assets/c7156c4d-d51d-493c-995e-5113c3d49abb"
82
+ style="width:100%; object-fit:cover;"
83
+ controls autoplay loop muted></video>
84
+ </td>
85
+ </tr>
86
+
87
+ <tr>
88
+ <td>
89
+ <video src="https://github.com/user-attachments/assets/d899d704-c706-4e64-a24b-eea174d2173d"
90
+ style="width:100%; height:180px; object-fit:cover;"
91
+ controls autoplay loop muted></video>
92
+ </td>
93
+ <td>
94
+ <video src="https://github.com/user-attachments/assets/c1d8beb2-3102-468a-8019-624d89fba125"
95
+ style="width:100%; height:180px; object-fit:cover;"
96
+ controls autoplay loop muted></video>
97
+ </td>
98
+ </tr>
99
+ </table>
100
+
101
+ <!-- ## 🚀 Trail: Try ReCamMaster with Your Own Videos
102
+
103
+ **Update:** We are actively processing the videos uploaded by users. So far, we have sent the inference results to the email addresses of the first **1180** testers. You should receive an email titled "Inference Results of ReCamMaster" from either [email protected] or [email protected]. Please also check your spam folder, and let us know if you haven't received the email after a long time. If you enjoyed the videos we created, please consider giving us a star 🌟.
104
+
105
+ **You can try out our ReCamMaster by uploading your own video to [this link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog), which will generate a video with camera movements along a new trajectory.** We will send the mp4 file generated by ReCamMaster to your inbox as soon as possible. For camera movement trajectories, we offer 10 basic camera trajectories as follows:
106
+
107
+ | Index | Basic Trajectory |
108
+ |-------------------|-----------------------------|
109
+ | 1 | Pan Right |
110
+ | 2 | Pan Left |
111
+ | 3 | Tilt Up |
112
+ | 4 | Tilt Down |
113
+ | 5 | Zoom In |
114
+ | 6 | Zoom Out |
115
+ | 7 | Translate Up (with rotation) |
116
+ | 8 | Translate Down (with rotation) |
117
+ | 9 | Arc Left (with rotation) |
118
+ | 10 | Arc Right (with rotation) |
119
+
120
+ If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [[email protected]](mailto:[email protected]). We can assist you with batch inference of our model. -->
121
+
122
+ ## 🔥 Updates
123
+ - __[2025.11.17]__: Release the [project page](https://eternalevan.github.io/Astra-project/).
124
+ - __[2025.12.09]__: Release the inference code and model checkpoint.
125
+
126
+ ## 🎯 TODO List
127
+
128
+ - [ ] **Release full inference pipelines** for additional scenarios:
129
+ - [ ] 🚗 Autonomous driving
130
+ - [ ] 🤖 Robotic manipulation
131
+ - [ ] 🛸 Drone navigation / exploration
132
+
133
+
134
+ - [ ] **Open-source training scripts**:
135
+ - [ ] ⬆️ Action-conditioned autoregressive denoising training
136
+ - [ ] 🔄 Multi-scenario joint training pipeline
137
+
138
+ - [ ] **Release dataset preprocessing tools**
139
+
140
+ - [ ] **Provide unified evaluation toolkit**
141
+
142
+ ## ⚙️ Run Astra (Inference)
143
+ Astra is built upon [Wan2.1-1.3B](https://github.com/Wan-Video/Wan2.1), a diffusion-based video generation model. We provide inference scripts to help you quickly generate videos from images and action inputs. Follow the steps below:
144
+
145
+ ### Inference
146
+ Step 1: Set up the environment
147
+
148
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) requires Rust and Cargo to compile extensions. You can install them using the following command:
149
+ ```shell
150
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
151
+ . "$HOME/.cargo/env"
152
+ ```
153
+
154
+ Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):
155
+ ```shell
156
+ git clone https://github.com/EternalEvan/Astra.git
157
+ cd Astra
158
+ pip install -e .
159
+ ```
160
+
161
+ Step 2: Download the pretrained checkpoints
162
+ 1. Download the pre-trained Wan2.1 models
163
+
164
+ ```shell
165
+ cd script
166
+ python download_wan2.1.py
167
+ ```
168
+ 2. Download the pre-trained Astra checkpoint
169
+
170
+ Please download from [huggingface](https://huggingface.co/EvanEternal/Astra/blob/main/models/Astra/checkpoints/diffusion_pytorch_model.ckpt) and place it in ```models/Astra/checkpoints```.
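+
+ One possible way to fetch the checkpoint programmatically is via ```huggingface_hub``` (a sketch; the in-repo file path is taken from the link above, and a recent version of ```huggingface_hub``` with ```local_dir``` support is assumed):
+
+ ```python
+ # Sketch: download the Astra checkpoint into ./models/Astra/checkpoints/.
+ from huggingface_hub import hf_hub_download
+
+ ckpt_path = hf_hub_download(
+     repo_id="EvanEternal/Astra",
+     filename="models/Astra/checkpoints/diffusion_pytorch_model.ckpt",
+     local_dir=".",  # the filename already contains the target sub-path
+ )
+ print(ckpt_path)
+ ```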
171
+
172
+ Step 3: Test the example image
173
+ ```shell
174
+ python infer_demo.py \
175
+ --dit_path ../models/Astra/checkpoints/diffusion_pytorch_model.ckpt \
176
+ --wan_model_path ../models/Wan-AI/Wan2.1-T2V-1.3B \
177
+ --condition_image ../examples/condition_images/garden_1.png \
178
+ --cam_type 4 \
179
+ --prompt "A sunlit European street lined with historic buildings and vibrant greenery creates a warm, charming, and inviting atmosphere. The scene shows a picturesque open square paved with red bricks, surrounded by classic narrow townhouses featuring tall windows, gabled roofs, and dark-painted facades. On the right side, a lush arrangement of potted plants and blooming flowers adds rich color and texture to the foreground. A vintage-style streetlamp stands prominently near the center-right, contributing to the timeless character of the street. Mature trees frame the background, their leaves glowing in the warm afternoon sunlight. Bicycles are visible along the edges of the buildings, reinforcing the urban yet leisurely feel. The sky is bright blue with scattered clouds, and soft sun flares enter the frame from the left, enhancing the scene’s inviting, peaceful mood." \
180
+ --output_path ../examples/output_videos/output_moe_framepack_sliding.mp4 \
181
+ ```
182
+
183
+ Step 4: Test your own images
184
+
185
+ To test with your own custom images, you need to prepare the target images and their corresponding text prompts. **We recommend that the size of the input images is close to 832×480 (width × height)**, which is consistent with the resolution of the generated video and can help achieve better video generation effects. For prompts generation, you can refer to the [Prompt Extension section](https://github.com/Wan-Video/Wan2.1?tab=readme-ov-file#2-using-prompt-extension) in Wan2.1 for guidance on crafting the captions.
186
+
187
+ ```shell
188
+ python infer_demo.py \
189
+ --dit_path path/to/your/dit_ckpt \
190
+ --wan_model_path path/to/your/Wan2.1-T2V-1.3B \
191
+ --condition_image path/to/your/image \
192
+ --cam_type your_cam_type \
193
+ --prompt your_prompt \
194
+ --output_path path/to/your/output_video
195
+ ```
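+
+ If your image has a different size, one simple way to bring it close to the recommended 832×480 is a resize plus center crop, e.g. with Pillow (a sketch, not part of the official pipeline; the file names are placeholders):
+
+ ```python
+ # Sketch: fit an arbitrary image to 832x480 (width x height) by resize + center crop.
+ from PIL import Image, ImageOps
+
+ img = Image.open("path/to/your/image.png").convert("RGB")
+ img_832x480 = ImageOps.fit(img, (832, 480), method=Image.LANCZOS)
+ img_832x480.save("condition_image_832x480.png")
+ ```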
196
+
197
+ We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
198
+
199
+ | cam_type | Trajectory |
200
+ |:-----------:|-----------------------------|
201
+ | 1 | Move Forward (Straight) |
202
+ | 2 | Rotate Left In Place |
203
+ | 3 | Rotate Right In Place |
204
+ | 4 | Move Forward + Rotate Left |
205
+ | 5 | Move Forward + Rotate Right |
206
+ | 6 | S-shaped Trajectory |
207
+ | 7 | Rotate Left → Rotate Right |
208
+
209
+
210
+ ## Future Work 🚀
211
+
212
+ Looking ahead, we plan to further enhance Astra in several directions:
213
+
214
+ - **Training with Wan-2.2:** Upgrade our model using the latest Wan-2.2 framework to release a more powerful version with improved generation quality.
215
+ - **3D Spatial Consistency:** Explore techniques to better preserve 3D consistency across frames for more coherent and realistic video generation.
216
+ - **Long-Term Memory:** Incorporate mechanisms for long-term memory, enabling the model to handle extended temporal dependencies and complex action sequences.
217
+
218
+ These directions aim to push Astra towards more robust and interactive video world modeling.
219
+
220
+ <!-- ### Training
221
+
222
+ Step 1: Set up the environment
223
+
224
+ ```shell
225
+ pip install lightning pandas websockets
226
+ ```
227
+
228
+ Step 2: Prepare the training dataset
229
+
230
+ 1. Download the [MultiCamVideo dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset).
231
+
232
+ 2. Extract VAE features
233
+
234
+ ```shell
235
+ CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task data_process --dataset_path path/to/the/MultiCamVideo/Dataset --output_path ./models --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" --tiled --num_frames 81 --height 480 --width 832 --dataloader_num_workers 2
236
+ ```
237
+
238
+ 3. Generate Captions for Each Video
239
+
240
+ You can use video caption tools like [LLaVA](https://github.com/haotian-liu/LLaVA) to generate captions for each video and store them in the ```metadata.csv``` file.
241
+
242
+ Step 3: Training
243
+ ```shell
244
+ CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task train --dataset_path recam_train_data --output_path ./models/train --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" --steps_per_epoch 8000 --max_epochs 100 --learning_rate 1e-4 --accumulate_grad_batches 1 --use_gradient_checkpointing --dataloader_num_workers 4
245
+ ```
246
+ We do not explore the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by adjusting hyper-parameters such as the learning rate and increasing the batch size.
247
+
248
+ Step 4: Test the model
249
+
250
+ ```shell
251
+ python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
252
+ ``` -->
253
+
254
+ <!-- ## 📷 Dataset: MultiCamVideo Dataset
255
+ ### 1. Dataset Introduction
256
+
257
+ **TL;DR:** The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories. The MultiCamVideo Dataset can be valuable in fields such as camera-controlled video generation, synchronized video production, and 3D/4D reconstruction.
258
+
259
+ https://github.com/user-attachments/assets/6fa25bcf-1136-43be-8110-b527638874d4
260
+
261
+ The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories.
262
+ It consists of 13.6K different dynamic scenes, each captured by 10 cameras, resulting in a total of 136K videos. Each dynamic scene is composed of four elements: {3D environment, character, animation, camera}. Specifically, we use animation to drive the character,
263
+ and position the animated character within the 3D environment. Then, Time-synchronized cameras are set up to move along predefined trajectories to render the multi-camera video data.
264
+ <p align="center">
265
+ <img src="https://github.com/user-attachments/assets/107c9607-e99b-4493-b715-3e194fcb3933" alt="Example Image" width="70%">
266
+ </p>
267
+
268
+ **3D Environment:** We collect 37 high-quality 3D environments assets from [Fab](https://www.fab.com). To minimize the domain gap between rendered data and real-world videos, we primarily select visually realistic 3D scenes, while choosing a few stylized or surreal 3D scenes as a supplement. To ensure data diversity, the selected scenes cover a variety of indoor and outdoor settings, such as city streets, shopping malls, cafes, office rooms, and the countryside.
269
+
270
+ **Character:** We collect 66 different human 3D models as characters from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com).
271
+
272
+ **Animation:** We collect 93 different animations from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com), including common actions such as waving, dancing, and cheering. We use these animations to drive the collected characters and create diverse datasets through various combinations.
273
+
274
+ **Camera:** To ensure camera movements are diverse and closely resemble real-world distributions, we create a wide range of camera trajectories and parameters to cover various situations. To achieve this by designing rules to batch-generate random camera starting positions and movement trajectories:
275
+
276
+ 1. Camera Starting Position.
277
+
278
+ We take the character's position as the center of a hemisphere with a radius of {3m, 5m, 7m, 10m} based on the size of the 3D scene and randomly sample within this range as the camera's starting point, ensuring the closest distance to the character is greater than 0.5m and the pitch angle is within 45 degrees.
279
+
280
+ 2. Camera Trajectories.
281
+
282
+ - **Pan & Tilt**:
283
+ The camera rotation angles are randomly selected within the range, with pan angles ranging from 5 to 45 degrees and tilt angles ranging from 5 to 30 degrees, with directions randomly chosen left/right or up/down.
284
+
285
+ - **Basic Translation**:
286
+ The camera translates along the positive and negative directions of the xyz axes, with movement distances randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$.
287
+
288
+ - **Basic Arc Trajectory**:
289
+ The camera moves along an arc, with rotation angles randomly selected within the range of 15 to 75 degrees.
290
+
291
+ - **Random Trajectories**:
292
+ 1-3 points are sampled in space, and the camera moves from the initial position through these points as the movement trajectory, with the total movement distance randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$. The polyline is smoothed to make the movement more natural.
293
+
294
+ - **Static Camera**:
295
+ The camera does not translate or rotate during shooting, maintaining a fixed position.
296
+
297
+ 3. Camera Movement Speed.
298
+
299
+ To further enhance the diversity of trajectories, 50% of the training data uses constant-speed camera trajectories, while the other 50% uses variable-speed trajectories generated by nonlinear functions. Consider a camera trajectory with a total of $f$ frames, starting at location $L_{start}$ and ending at position $L_{end}$. The location at the $i$-th frame is given by:
300
+ ```math
301
+ L_i = L_{start} + (L_{end} - L_{start}) \cdot \left( \frac{1 - \exp(-a \cdot i/f)}{1 - \exp(-a)} \right),
302
+ ```
303
+ where $a$ is an adjustable parameter to control the trajectory speed. When $a > 0$, the trajectory starts fast and then slows down; when $a < 0$, the trajectory starts slow and then speeds up. The larger the absolute value of $a$, the more drastic the change.
304
+
305
+ 4. Camera Parameters.
306
+
307
+ We chose four set of camera parameters: {focal=18mm, aperture=10}, {focal=24mm, aperture=5}, {focal=35mm, aperture=2.4} and {focal=50mm, aperture=2.4}.
308
+
309
+ ### 2. Statistics and Configurations
310
+
311
+ Dataset Statistics:
312
+
313
+ | Number of Dynamic Scenes | Camera per Scene | Total Videos |
314
+ |:------------------------:|:----------------:|:------------:|
315
+ | 13,600 | 10 | 136,000 |
316
+
317
+ Video Configurations:
318
+
319
+ | Resolution | Frame Number | FPS |
320
+ |:-----------:|:------------:|:------------------------:|
321
+ | 1280x1280 | 81 | 15 |
322
+
323
+ Note: You can use 'center crop' to adjust the video's aspect ratio to fit your video generation model, such as 16:9, 9:16, 4:3, or 3:4.
324
+
325
+ Camera Configurations:
326
+
327
+ | Focal Length | Aperture | Sensor Height | Sensor Width |
328
+ |:-----------------------:|:------------------:|:-------------:|:------------:|
329
+ | 18mm, 24mm, 35mm, 50mm | 10.0, 5.0, 2.4 | 23.76mm | 23.76mm |
330
+
331
+
332
+
333
+ ### 3. File Structure
334
+ ```
335
+ MultiCamVideo-Dataset
336
+ ├── train
337
+ │ ├── f18_aperture10
338
+ │ │ ├── scene1 # one dynamic scene
339
+ │ │ │ ├── videos
340
+ │ │ │ │ ├── cam01.mp4 # synchronized 81-frame videos at 1280x1280 resolution
341
+ │ │ │ │ ├── cam02.mp4
342
+ │ │ │ │ ├── ...
343
+ │ │ │ │ └── cam10.mp4
344
+ │ │ │ └── cameras
345
+ │ │ │ └── camera_extrinsics.json # 81-frame camera extrinsics of the 10 cameras
346
+ │ │ ├── ...
347
+ │ │ └── scene3400
348
+ │ ├── f24_aperture5
349
+ │ │ ├── scene1
350
+ │ │ ├── ...
351
+ │ │ └── scene3400
352
+ │ ├── f35_aperture2.4
353
+ │ │ ├── scene1
354
+ │ │ ├── ...
355
+ │ │ └── scene3400
356
+ │ └── f50_aperture2.4
357
+ │ ├── scene1
358
+ │ ├── ...
359
+ │ └── scene3400
360
+ └── val
361
+ └── 10basic_trajectories
362
+ ├── videos
363
+ │ ├── cam01.mp4 # example videos corresponding to the validation cameras
364
+ │ ├── cam02.mp4
365
+ │ ├── ...
366
+ │ └── cam10.mp4
367
+ └── cameras
368
+ └── camera_extrinsics.json # 10 different trajectories for validation
369
+ ```
370
+
371
+ ### 3. Useful scripts
372
+ - Data Extraction
373
+ ```bash
374
+ cat MultiCamVideo-Dataset.part* > MultiCamVideo-Dataset.tar.gz
375
+ tar -xzvf MultiCamVideo-Dataset.tar.gz
376
+ ```
377
+ - Camera Visualization
378
+ ```python
379
+ python vis_cam.py
380
+ ```
381
+
382
+ The visualization script is modified from [CameraCtrl](https://github.com/hehao13/CameraCtrl/blob/main/tools/visualize_trajectory.py), thanks for their inspiring work.
383
+
384
+ <p align="center">
385
+ <img src="https://github.com/user-attachments/assets/f9cf342d-2fb3-40ef-a7be-edafb5775004" alt="Example Image" width="40%">
386
+ </p> -->
387
+
388
+ ## 🤗 Awesome Related Works
389
+ Feel free to explore these outstanding related works, including but not limited to:
390
+
391
+ [ReCamMaster](https://github.com/KlingTeam/ReCamMaster): ReCamMaster re-captures in-the-wild videos with novel camera trajectories.
392
+
393
+ [GCD](https://gcd.cs.columbia.edu/): GCD synthesizes large-angle novel viewpoints of 4D dynamic scenes from a monocular video.
394
+
395
+ [ReCapture](https://generative-video-camera-controls.github.io/): a method for generating new videos with novel camera trajectories from a single user-provided video.
396
+
397
+ [Trajectory Attention](https://xizaoqu.github.io/trajattn/): Trajectory Attention facilitates various tasks like camera motion control on images and videos, and video editing.
398
+
399
+ [GS-DiT](https://wkbian.github.io/Projects/GS-DiT/): GS-DiT provides 4D video control for a single monocular video.
400
+
401
+ [Diffusion as Shader](https://igl-hkust.github.io/das/): a versatile video generation control model for various tasks.
402
+
403
+ [TrajectoryCrafter](https://trajectorycrafter.github.io/): TrajectoryCrafter achieves high-fidelity novel views generation from casually captured monocular video.
404
+
405
+ [GEN3C](https://research.nvidia.com/labs/toronto-ai/GEN3C/): a generative video model with precise Camera Control and temporal 3D Consistency.
406
+
407
+ ## 🌟 Citation
408
+
409
+ Please leave us a star 🌟 and cite our paper if you find our work helpful.
410
+ ```
411
+ @misc{zhu2025astrageneralinteractiveworld,
412
+ title={Astra: General Interactive World Model with Autoregressive Denoising},
413
+ author={Yixuan Zhu and Jiaqi Feng and Wenzhao Zheng and Yuan Gao and Xin Tao and Pengfei Wan and Jie Zhou and Jiwen Lu},
414
+ year={2025},
415
+ eprint={2512.08931},
416
+ archivePrefix={arXiv},
417
+ primaryClass={cs.CV},
418
+ url={https://arxiv.org/abs/2512.08931},
419
+ }
420
+ ```
assets/images/logo-text-2.png ADDED

Git LFS Details

  • SHA256: c993f6517368ee0189d3609943645f92d3e5c368674df0df3ae6135a8edb0559
  • Pointer size: 131 Bytes
  • Size of remote file: 195 kB
assets/images/logo-text.png ADDED

Git LFS Details

  • SHA256: 68a4ab9c6f5df9557de493522f97883d7ea93be960e6e66dfa92942595eb69bb
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB
assets/images/logo.png ADDED

Git LFS Details

  • SHA256: 6097e3e32506f8e1e9705756a39bc4f9a863af8183c2017e1f1ab102bad62253
  • Pointer size: 131 Bytes
  • Size of remote file: 169 kB
diffsynth/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .data import *
2
+ from .models import *
3
+ from .prompters import *
4
+ from .schedulers import *
5
+ from .pipelines import *
6
+ from .controlnets import *
diffsynth/configs/__init__.py ADDED
File without changes
diffsynth/configs/model_config.py ADDED
@@ -0,0 +1,778 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+ from ..models.sd_text_encoder import SDTextEncoder
4
+ from ..models.sd_unet import SDUNet
5
+ from ..models.sd_vae_encoder import SDVAEEncoder
6
+ from ..models.sd_vae_decoder import SDVAEDecoder
7
+
8
+ from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
9
+ from ..models.sdxl_unet import SDXLUNet
10
+ from ..models.sdxl_vae_decoder import SDXLVAEDecoder
11
+ from ..models.sdxl_vae_encoder import SDXLVAEEncoder
12
+
13
+ from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
14
+ from ..models.sd3_dit import SD3DiT
15
+ from ..models.sd3_vae_decoder import SD3VAEDecoder
16
+ from ..models.sd3_vae_encoder import SD3VAEEncoder
17
+
18
+ from ..models.sd_controlnet import SDControlNet
19
+ from ..models.sdxl_controlnet import SDXLControlNetUnion
20
+
21
+ from ..models.sd_motion import SDMotionModel
22
+ from ..models.sdxl_motion import SDXLMotionModel
23
+
24
+ from ..models.svd_image_encoder import SVDImageEncoder
25
+ from ..models.svd_unet import SVDUNet
26
+ from ..models.svd_vae_decoder import SVDVAEDecoder
27
+ from ..models.svd_vae_encoder import SVDVAEEncoder
28
+
29
+ from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
30
+ from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
31
+
32
+ from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
33
+ from ..models.hunyuan_dit import HunyuanDiT
34
+
35
+ from ..models.flux_dit import FluxDiT
36
+ from ..models.flux_text_encoder import FluxTextEncoder2
37
+ from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
38
+ from ..models.flux_controlnet import FluxControlNet
39
+ from ..models.flux_ipadapter import FluxIpAdapter
40
+
41
+ from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
42
+ from ..models.cog_dit import CogDiT
43
+
44
+ from ..models.omnigen import OmniGenTransformer
45
+
46
+ from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
47
+ from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
48
+
49
+ from ..extensions.RIFE import IFNet
50
+ from ..extensions.ESRGAN import RRDBNet
51
+
52
+ from ..models.hunyuan_video_dit import HunyuanVideoDiT
53
+
54
+ from ..models.stepvideo_vae import StepVideoVAE
55
+ from ..models.stepvideo_dit import StepVideoModel
56
+
57
+ from ..models.wan_video_dit import WanModel
58
+ from ..models.wan_video_dit_recam_future import WanModelFuture
59
+ from ..models.wan_video_text_encoder import WanTextEncoder
60
+ from ..models.wan_video_image_encoder import WanImageEncoder
61
+ from ..models.wan_video_vae import WanVideoVAE
62
+
63
+
64
+ model_loader_configs = [
65
+ # These configs are provided for detecting model type automatically.
66
+ # The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
67
+ (None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
68
+ (None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
69
+ (None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
70
+ (None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
71
+ (None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
72
+ (None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
73
+ (None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
74
+ (None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
75
+ (None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
76
+ (None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
77
+ (None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
78
+ (None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
79
+ (None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
80
+ (None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
81
+ (None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
82
+ (None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
83
+ (None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
84
+ (None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
85
+ (None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
86
+ (None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
87
+ (None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
88
+ (None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
89
+ (None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
90
+ (None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
91
+ (None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
92
+ (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
93
+ (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
94
+ (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
95
+ (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
96
+ (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
97
+ (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
98
+ (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
99
+ (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
100
+ (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
101
+ (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
102
+ (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
103
+ (None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
104
+ (None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
105
+ (None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
106
+ (None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
107
+ (None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
108
+ (None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
109
+ (None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
110
+ (None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
111
+ (None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
112
+ (None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
113
+ (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
114
+ (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
115
+ (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
116
+ (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
117
+ (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
118
+ (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
119
+ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
120
+ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
121
+ (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
122
+ (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
123
+ (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
124
+ (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
125
+ (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
126
+ ]
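
A minimal sketch of how an entry in model_loader_configs might be matched against a freshly loaded checkpoint. The hashing helper below is an assumption for illustration (DiffSynth's real detection logic lives elsewhere in the package); it only shows the idea of recognizing a model family from its state-dict keys.

import hashlib

def hash_state_dict_keys(state_dict, with_shape=True):
    # Hash the sorted parameter names, optionally together with tensor shapes,
    # so a checkpoint can be recognized without inspecting its weights.
    keys = [
        f"{name}:{tuple(tensor.shape)}" if with_shape else name
        for name, tensor in state_dict.items()
    ]
    return hashlib.md5(",".join(sorted(keys)).encode("utf-8")).hexdigest()

def detect_model_type(state_dict):
    # Compare both hash variants against every registered config above.
    keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
    keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
    for known_hash, known_hash_with_shape, model_names, model_classes, resource in model_loader_configs:
        if keys_hash == known_hash or keys_hash_with_shape == known_hash_with_shape:
            return model_names, model_classes, resource
    return None
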
127
+ huggingface_model_loader_configs = [
128
+ # These configs are provided for detecting model type automatically.
129
+ # The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
130
+ ("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
131
+ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
132
+ ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
133
+ ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
134
+ # ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
135
+ ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
136
+ ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
137
+ ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
138
+ ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
139
+ ("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
140
+ ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
141
+ ]
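
For the Hugging Face-style folders, detection keys off the architecture string in config.json rather than a state-dict hash. A hedged sketch of how that lookup could work; resolve_architecture and its dynamic import are illustrative, not the package's actual loader.

import importlib, json, os

def resolve_architecture(model_folder):
    # A standard Hugging Face folder carries its class name in config.json.
    with open(os.path.join(model_folder, "config.json")) as f:
        architecture = json.load(f)["architectures"][0]
    for arch, lib, model_name, redirected_architecture in huggingface_model_loader_configs:
        if arch == architecture:
            # Fall back to the original architecture name when no redirect is given.
            class_name = redirected_architecture or arch
            return model_name, getattr(importlib.import_module(lib), class_name)
    raise ValueError(f"Unsupported architecture: {architecture}")
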
142
+ patch_model_loader_configs = [
143
+ # These configs are provided for detecting model type automatically.
144
+ # The format is (state_dict_keys_hash_with_shape, model_names, model_classes, extra_kwargs)
145
+ ("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
146
+ ]
147
+
148
+ preset_models_on_huggingface = {
149
+ "HunyuanDiT": [
150
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
151
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
152
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
153
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
154
+ ],
155
+ "stable-video-diffusion-img2vid-xt": [
156
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
157
+ ],
158
+ "ExVideo-SVD-128f-v1": [
159
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
160
+ ],
161
+ # Stable Diffusion
162
+ "StableDiffusion_v15": [
163
+ ("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
164
+ ],
165
+ "DreamShaper_8": [
166
+ ("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
167
+ ],
168
+ # Textual Inversion
169
+ "TextualInversion_VeryBadImageNegative_v1.3": [
170
+ ("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
171
+ ],
172
+ # Stable Diffusion XL
173
+ "StableDiffusionXL_v1": [
174
+ ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
175
+ ],
176
+ "BluePencilXL_v200": [
177
+ ("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
178
+ ],
179
+ "StableDiffusionXL_Turbo": [
180
+ ("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
181
+ ],
182
+ # Stable Diffusion 3
183
+ "StableDiffusion3": [
184
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
185
+ ],
186
+ "StableDiffusion3_without_T5": [
187
+ ("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
188
+ ],
189
+ # ControlNet
190
+ "ControlNet_v11f1p_sd15_depth": [
191
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
192
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
193
+ ],
194
+ "ControlNet_v11p_sd15_softedge": [
195
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
196
+ ("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
197
+ ],
198
+ "ControlNet_v11f1e_sd15_tile": [
199
+ ("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
200
+ ],
201
+ "ControlNet_v11p_sd15_lineart": [
202
+ ("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
203
+ ("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
204
+ ("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
205
+ ],
206
+ "ControlNet_union_sdxl_promax": [
207
+ ("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
208
+ ("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
209
+ ],
210
+ # AnimateDiff
211
+ "AnimateDiff_v2": [
212
+ ("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
213
+ ],
214
+ "AnimateDiff_xl_beta": [
215
+ ("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
216
+ ],
217
+
218
+ # Qwen Prompt
219
+ "QwenPrompt": [
220
+ ("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
221
+ ("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
222
+ ("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
223
+ ("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
224
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
225
+ ("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
226
+ ("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
227
+ ("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
228
+ ],
229
+ # Beautiful Prompt
230
+ "BeautifulPrompt": [
231
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
232
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
233
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
234
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
235
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
236
+ ("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
237
+ ],
238
+ # Omost prompt
239
+ "OmostPrompt":[
240
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
241
+ ("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
242
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
243
+ ("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
244
+ ("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
245
+ ("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
246
+ ("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
247
+ ("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
248
+ ],
249
+ # Translator
250
+ "opus-mt-zh-en": [
251
+ ("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
252
+ ("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
253
+ ("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
254
+ ("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
255
+ ("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
256
+ ("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
257
+ ("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
258
+ ("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
259
+ ],
260
+ # IP-Adapter
261
+ "IP-Adapter-SD": [
262
+ ("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
263
+ ("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
264
+ ],
265
+ "IP-Adapter-SDXL": [
266
+ ("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
267
+ ("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
268
+ ],
269
+ "SDXL-vae-fp16-fix": [
270
+ ("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
271
+ ],
272
+ # Kolors
273
+ "Kolors": [
274
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
275
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
276
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
277
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
278
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
279
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
280
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
281
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
282
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
283
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
284
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
285
+ ],
286
+ # FLUX
287
+ "FLUX.1-dev": [
288
+ ("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
289
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
290
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
291
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
292
+ ("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
293
+ ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
294
+ ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
295
+ ],
296
+ "InstantX/FLUX.1-dev-IP-Adapter": {
297
+ "file_list": [
298
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
299
+ ("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
300
+ ("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
301
+ ],
302
+ "load_path": [
303
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
304
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
305
+ ],
306
+ },
307
+ # RIFE
308
+ "RIFE": [
309
+ ("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
310
+ ],
311
+ # CogVideo
312
+ "CogVideoX-5B": [
313
+ ("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
314
+ ("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
315
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
316
+ ("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
317
+ ("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
318
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
319
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
320
+ ("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
321
+ ("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
322
+ ],
323
+ # Stable Diffusion 3.5
324
+ "StableDiffusion3.5-large": [
325
+ ("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
326
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
327
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
328
+ ("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
329
+ ],
330
+ }
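
Each preset entry above is a (repo_id, path_in_repo, local_dir) triple. A small sketch of fetching one preset with huggingface_hub; the loop and helper name are assumptions, while hf_hub_download itself is the standard API.

from huggingface_hub import hf_hub_download

def fetch_preset_from_huggingface(preset_id):
    entry = preset_models_on_huggingface[preset_id]
    # Dict-shaped presets keep their download triples under "file_list".
    file_list = entry["file_list"] if isinstance(entry, dict) else entry
    for repo_id, path_in_repo, local_dir in file_list:
        hf_hub_download(repo_id=repo_id, filename=path_in_repo, local_dir=local_dir)

# Example: fetch_preset_from_huggingface("StableDiffusion_v15")
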
331
+ preset_models_on_modelscope = {
332
+ # Hunyuan DiT
333
+ "HunyuanDiT": [
334
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
335
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
336
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
337
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
338
+ ],
339
+ # Stable Video Diffusion
340
+ "stable-video-diffusion-img2vid-xt": [
341
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
342
+ ],
343
+ # ExVideo
344
+ "ExVideo-SVD-128f-v1": [
345
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
346
+ ],
347
+ "ExVideo-CogVideoX-LoRA-129f-v1": [
348
+ ("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
349
+ ],
350
+ # Stable Diffusion
351
+ "StableDiffusion_v15": [
352
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
353
+ ],
354
+ "DreamShaper_8": [
355
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
356
+ ],
357
+ "AingDiffusion_v12": [
358
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
359
+ ],
360
+ "Flat2DAnimerge_v45Sharp": [
361
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
362
+ ],
363
+ # Textual Inversion
364
+ "TextualInversion_VeryBadImageNegative_v1.3": [
365
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
366
+ ],
367
+ # Stable Diffusion XL
368
+ "StableDiffusionXL_v1": [
369
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
370
+ ],
371
+ "BluePencilXL_v200": [
372
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
373
+ ],
374
+ "StableDiffusionXL_Turbo": [
375
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
376
+ ],
377
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
378
+ ("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
379
+ ],
380
+ # Stable Diffusion 3
381
+ "StableDiffusion3": [
382
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
383
+ ],
384
+ "StableDiffusion3_without_T5": [
385
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
386
+ ],
387
+ # ControlNet
388
+ "ControlNet_v11f1p_sd15_depth": [
389
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
390
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
391
+ ],
392
+ "ControlNet_v11p_sd15_softedge": [
393
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
394
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
395
+ ],
396
+ "ControlNet_v11f1e_sd15_tile": [
397
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
398
+ ],
399
+ "ControlNet_v11p_sd15_lineart": [
400
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
401
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
402
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
403
+ ],
404
+ "ControlNet_union_sdxl_promax": [
405
+ ("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
406
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
407
+ ],
408
+ "Annotators:Depth": [
409
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
410
+ ],
411
+ "Annotators:Softedge": [
412
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
413
+ ],
414
+ "Annotators:Lineart": [
415
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
416
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
417
+ ],
418
+ "Annotators:Normal": [
419
+ ("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
420
+ ],
421
+ "Annotators:Openpose": [
422
+ ("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
423
+ ("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
424
+ ("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
425
+ ],
426
+ # AnimateDiff
427
+ "AnimateDiff_v2": [
428
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
429
+ ],
430
+ "AnimateDiff_xl_beta": [
431
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
432
+ ],
433
+ # RIFE
434
+ "RIFE": [
435
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
436
+ ],
437
+ # Qwen Prompt
438
+ "QwenPrompt": {
439
+ "file_list": [
440
+ ("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
441
+ ("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
442
+ ("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
443
+ ("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
444
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
445
+ ("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
446
+ ("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
447
+ ("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
448
+ ],
449
+ "load_path": [
450
+ "models/QwenPrompt/qwen2-1.5b-instruct",
451
+ ],
452
+ },
453
+ # Beautiful Prompt
454
+ "BeautifulPrompt": {
455
+ "file_list": [
456
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
457
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
458
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
459
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
460
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
461
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
462
+ ],
463
+ "load_path": [
464
+ "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
465
+ ],
466
+ },
467
+ # Omost prompt
468
+ "OmostPrompt": {
469
+ "file_list": [
470
+ ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
471
+ ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
472
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
473
+ ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
474
+ ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
475
+ ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
476
+ ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
477
+ ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
478
+ ],
479
+ "load_path": [
480
+ "models/OmostPrompt/omost-llama-3-8b-4bits",
481
+ ],
482
+ },
483
+ # Translator
484
+ "opus-mt-zh-en": {
485
+ "file_list": [
486
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
487
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
488
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
489
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
490
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
491
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
492
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
493
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
494
+ ],
495
+ "load_path": [
496
+ "models/translator/opus-mt-zh-en",
497
+ ],
498
+ },
499
+ # IP-Adapter
500
+ "IP-Adapter-SD": [
501
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
502
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
503
+ ],
504
+ "IP-Adapter-SDXL": [
505
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
506
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
507
+ ],
508
+ # Kolors
509
+ "Kolors": {
510
+ "file_list": [
511
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
512
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
513
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
514
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
515
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
516
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
517
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
518
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
519
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
520
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
521
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
522
+ ],
523
+ "load_path": [
524
+ "models/kolors/Kolors/text_encoder",
525
+ "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
526
+ "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
527
+ ],
528
+ },
529
+ "SDXL-vae-fp16-fix": [
530
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
531
+ ],
532
+ # FLUX
533
+ "FLUX.1-dev": {
534
+ "file_list": [
535
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
536
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
537
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
538
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
539
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
540
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
541
+ ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
542
+ ],
543
+ "load_path": [
544
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
545
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
546
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
547
+ "models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
548
+ ],
549
+ },
550
+ "FLUX.1-schnell": {
551
+ "file_list": [
552
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
553
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
554
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
555
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
556
+ ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
557
+ ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
558
+ ("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
559
+ ],
560
+ "load_path": [
561
+ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
562
+ "models/FLUX/FLUX.1-dev/text_encoder_2",
563
+ "models/FLUX/FLUX.1-dev/ae.safetensors",
564
+ "models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
565
+ ],
566
+ },
567
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
568
+ ("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
569
+ ],
570
+ "jasperai/Flux.1-dev-Controlnet-Depth": [
571
+ ("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
572
+ ],
573
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
574
+ ("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
575
+ ],
576
+ "jasperai/Flux.1-dev-Controlnet-Upscaler": [
577
+ ("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
578
+ ],
579
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
580
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
581
+ ],
582
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
583
+ ("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
584
+ ],
585
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
586
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
587
+ ],
588
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
589
+ ("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
590
+ ],
591
+ "InstantX/FLUX.1-dev-IP-Adapter": {
592
+ "file_list": [
593
+ ("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
594
+ ("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
595
+ ("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
596
+ ],
597
+ "load_path": [
598
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
599
+ "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
600
+ ],
601
+ },
602
+ # ESRGAN
603
+ "ESRGAN_x4": [
604
+ ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
605
+ ],
606
+ # RIFE
607
+ "RIFE": [
608
+ ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
609
+ ],
610
+ # Omnigen
611
+ "OmniGen-v1": {
612
+ "file_list": [
613
+ ("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
614
+ ("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
615
+ ("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
616
+ ("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
617
+ ("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
618
+ ("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
619
+ ],
620
+ "load_path": [
621
+ "models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
622
+ "models/OmniGen/OmniGen-v1/model.safetensors",
623
+ ]
624
+ },
625
+ # CogVideo
626
+ "CogVideoX-5B": {
627
+ "file_list": [
628
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
629
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
630
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
631
+ ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
632
+ ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
633
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
634
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
635
+ ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
636
+ ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
637
+ ],
638
+ "load_path": [
639
+ "models/CogVideo/CogVideoX-5b/text_encoder",
640
+ "models/CogVideo/CogVideoX-5b/transformer",
641
+ "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
642
+ ],
643
+ },
644
+ # Stable Diffusion 3.5
645
+ "StableDiffusion3.5-large": [
646
+ ("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
647
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
648
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
649
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
650
+ ],
651
+ "StableDiffusion3.5-medium": [
652
+ ("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
653
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
654
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
655
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
656
+ ],
657
+ "StableDiffusion3.5-large-turbo": [
658
+ ("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
659
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
660
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
661
+ ("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
662
+ ],
663
+ "HunyuanVideo":{
664
+ "file_list": [
665
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
666
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
667
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
668
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
669
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
670
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
671
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
672
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
673
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
674
+ ],
675
+ "load_path": [
676
+ "models/HunyuanVideo/text_encoder/model.safetensors",
677
+ "models/HunyuanVideo/text_encoder_2",
678
+ "models/HunyuanVideo/vae/pytorch_model.pt",
679
+ "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
680
+ ],
681
+ },
682
+ "HunyuanVideoI2V":{
683
+ "file_list": [
684
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
685
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
686
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
687
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
688
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
689
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
690
+ ("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
691
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
692
+ ("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
693
+ ],
694
+ "load_path": [
695
+ "models/HunyuanVideoI2V/text_encoder/model.safetensors",
696
+ "models/HunyuanVideoI2V/text_encoder_2",
697
+ "models/HunyuanVideoI2V/vae/pytorch_model.pt",
698
+ "models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
699
+ ],
700
+ },
701
+ "HunyuanVideo-fp8":{
702
+ "file_list": [
703
+ ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
704
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
705
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
706
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
707
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
708
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
709
+ ("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
710
+ ("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
711
+ ("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
712
+ ],
713
+ "load_path": [
714
+ "models/HunyuanVideo/text_encoder/model.safetensors",
715
+ "models/HunyuanVideo/text_encoder_2",
716
+ "models/HunyuanVideo/vae/pytorch_model.pt",
717
+ "models/HunyuanVideo/transformers/model.fp8.safetensors"
718
+ ],
719
+ },
720
+ }
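
The ModelScope table mixes two shapes: a bare list of download triples, or a dict that also records the load_path entries to hand to the model manager after downloading. A sketch that normalizes both shapes (the helper name is illustrative):

def resolve_preset(preset_id):
    entry = preset_models_on_modelscope[preset_id]
    if isinstance(entry, dict):
        file_list, load_path = entry["file_list"], entry["load_path"]
    else:
        # For plain lists, the local target directories double as load paths.
        file_list = entry
        load_path = sorted({local_dir for _, _, local_dir in entry})
    return file_list, load_path

# resolve_preset("Kolors") returns 11 download triples and 3 load paths.
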
721
+ Preset_model_id: TypeAlias = Literal[
722
+ "HunyuanDiT",
723
+ "stable-video-diffusion-img2vid-xt",
724
+ "ExVideo-SVD-128f-v1",
725
+ "ExVideo-CogVideoX-LoRA-129f-v1",
726
+ "StableDiffusion_v15",
727
+ "DreamShaper_8",
728
+ "AingDiffusion_v12",
729
+ "Flat2DAnimerge_v45Sharp",
730
+ "TextualInversion_VeryBadImageNegative_v1.3",
731
+ "StableDiffusionXL_v1",
732
+ "BluePencilXL_v200",
733
+ "StableDiffusionXL_Turbo",
734
+ "ControlNet_v11f1p_sd15_depth",
735
+ "ControlNet_v11p_sd15_softedge",
736
+ "ControlNet_v11f1e_sd15_tile",
737
+ "ControlNet_v11p_sd15_lineart",
738
+ "AnimateDiff_v2",
739
+ "AnimateDiff_xl_beta",
740
+ "RIFE",
741
+ "BeautifulPrompt",
742
+ "opus-mt-zh-en",
743
+ "IP-Adapter-SD",
744
+ "IP-Adapter-SDXL",
745
+ "StableDiffusion3",
746
+ "StableDiffusion3_without_T5",
747
+ "Kolors",
748
+ "SDXL-vae-fp16-fix",
749
+ "ControlNet_union_sdxl_promax",
750
+ "FLUX.1-dev",
751
+ "FLUX.1-schnell",
752
+ "InstantX/FLUX.1-dev-Controlnet-Union-alpha",
753
+ "jasperai/Flux.1-dev-Controlnet-Depth",
754
+ "jasperai/Flux.1-dev-Controlnet-Surface-Normals",
755
+ "jasperai/Flux.1-dev-Controlnet-Upscaler",
756
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
757
+ "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
758
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
759
+ "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
760
+ "InstantX/FLUX.1-dev-IP-Adapter",
761
+ "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
762
+ "QwenPrompt",
763
+ "OmostPrompt",
764
+ "ESRGAN_x4",
765
+ "RIFE",
766
+ "OmniGen-v1",
767
+ "CogVideoX-5B",
768
+ "Annotators:Depth",
769
+ "Annotators:Softedge",
770
+ "Annotators:Lineart",
771
+ "Annotators:Normal",
772
+ "Annotators:Openpose",
773
+ "StableDiffusion3.5-large",
774
+ "StableDiffusion3.5-medium",
775
+ "HunyuanVideo",
776
+ "HunyuanVideo-fp8",
777
+ "HunyuanVideoI2V",
778
+ ]
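
Preset_model_id being a Literal alias means annotated call sites get their preset names checked statically; the same names can also be recovered at runtime. A brief illustrative sketch (validate_presets is not part of the package):

from typing import get_args

def validate_presets(model_id_list):
    valid_names = set(get_args(Preset_model_id))
    for model_id in model_id_list:
        if model_id not in valid_names:
            raise ValueError(f"Unknown preset: {model_id}")

validate_presets(["StableDiffusion_v15", "RIFE"])  # passes silently
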
diffsynth/controlnets/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
2
+ from .processors import Annotator
diffsynth/controlnets/controlnet_unit.py ADDED
@@ -0,0 +1,91 @@
1
+ import torch
2
+ import numpy as np
3
+ from .processors import Processor_id
4
+
5
+
6
+ class ControlNetConfigUnit:
7
+ def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
8
+ self.processor_id = processor_id
9
+ self.model_path = model_path
10
+ self.scale = scale
11
+ self.skip_processor = skip_processor
12
+
13
+
14
+ class ControlNetUnit:
15
+ def __init__(self, processor, model, scale=1.0):
16
+ self.processor = processor
17
+ self.model = model
18
+ self.scale = scale
19
+
20
+
21
+ class MultiControlNetManager:
22
+ def __init__(self, controlnet_units=[]):
23
+ self.processors = [unit.processor for unit in controlnet_units]
24
+ self.models = [unit.model for unit in controlnet_units]
25
+ self.scales = [unit.scale for unit in controlnet_units]
26
+
27
+ def cpu(self):
28
+ for model in self.models:
29
+ model.cpu()
30
+
31
+ def to(self, device):
32
+ for model in self.models:
33
+ model.to(device)
34
+ for processor in self.processors:
35
+ processor.to(device)
36
+
37
+ def process_image(self, image, processor_id=None):
38
+ if processor_id is None:
39
+ processed_image = [processor(image) for processor in self.processors]
40
+ else:
41
+ processed_image = [self.processors[processor_id](image)]
42
+ processed_image = torch.concat([
43
+ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
44
+ for image_ in processed_image
45
+ ], dim=0)
46
+ return processed_image
47
+
48
+ def __call__(
49
+ self,
50
+ sample, timestep, encoder_hidden_states, conditionings,
51
+ tiled=False, tile_size=64, tile_stride=32, **kwargs
52
+ ):
53
+ res_stack = None
54
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
55
+ res_stack_ = model(
56
+ sample, timestep, encoder_hidden_states, conditioning, **kwargs,
57
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
58
+ processor_id=processor.processor_id
59
+ )
60
+ res_stack_ = [res * scale for res in res_stack_]
61
+ if res_stack is None:
62
+ res_stack = res_stack_
63
+ else:
64
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
65
+ return res_stack
66
+
67
+
68
+ class FluxMultiControlNetManager(MultiControlNetManager):
69
+ def __init__(self, controlnet_units=[]):
70
+ super().__init__(controlnet_units=controlnet_units)
71
+
72
+ def process_image(self, image, processor_id=None):
73
+ if processor_id is None:
74
+ processed_image = [processor(image) for processor in self.processors]
75
+ else:
76
+ processed_image = [self.processors[processor_id](image)]
77
+ return processed_image
78
+
79
+ def __call__(self, conditionings, **kwargs):
80
+ res_stack, single_res_stack = None, None
81
+ for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
82
+ res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
83
+ res_stack_ = [res * scale for res in res_stack_]
84
+ single_res_stack_ = [res * scale for res in single_res_stack_]
85
+ if res_stack is None:
86
+ res_stack = res_stack_
87
+ single_res_stack = single_res_stack_
88
+ else:
89
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
90
+ single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
91
+ return res_stack, single_res_stack
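
A hedged sketch of how the unit/manager pair composes: every unit contributes residuals weighted by its scale, and the manager sums them element-wise. DummyControlNet and DummyProcessor are stand-ins for real models, used only to make the accumulation visible.

import torch
from diffsynth.controlnets import ControlNetUnit, MultiControlNetManager

class DummyControlNet:
    # Mimics a ControlNet: returns a stack of residual tensors.
    def __call__(self, sample, timestep, encoder_hidden_states, conditioning, **kwargs):
        return [torch.ones(1, 4, 8, 8), torch.ones(1, 4, 8, 8)]

class DummyProcessor:
    processor_id = "tile"

units = [
    ControlNetUnit(DummyProcessor(), DummyControlNet(), scale=0.8),
    ControlNetUnit(DummyProcessor(), DummyControlNet(), scale=0.2),
]
manager = MultiControlNetManager(units)
res_stack = manager(
    sample=None, timestep=None, encoder_hidden_states=None,
    conditionings=[None, None],
)
print(res_stack[0].mean())  # tensor(1.) because the two scales sum to 1.0
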
diffsynth/controlnets/processors.py ADDED
@@ -0,0 +1,62 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+
3
+
4
+ Processor_id: TypeAlias = Literal[
5
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
6
+ ]
7
+
8
+ class Annotator:
9
+ def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
10
+ if not skip_processor:
11
+ if processor_id == "canny":
12
+ from controlnet_aux.processor import CannyDetector
13
+ self.processor = CannyDetector()
14
+ elif processor_id == "depth":
15
+ from controlnet_aux.processor import MidasDetector
16
+ self.processor = MidasDetector.from_pretrained(model_path).to(device)
17
+ elif processor_id == "softedge":
18
+ from controlnet_aux.processor import HEDdetector
19
+ self.processor = HEDdetector.from_pretrained(model_path).to(device)
20
+ elif processor_id == "lineart":
21
+ from controlnet_aux.processor import LineartDetector
22
+ self.processor = LineartDetector.from_pretrained(model_path).to(device)
23
+ elif processor_id == "lineart_anime":
24
+ from controlnet_aux.processor import LineartAnimeDetector
25
+ self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
26
+ elif processor_id == "openpose":
27
+ from controlnet_aux.processor import OpenposeDetector
28
+ self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
29
+ elif processor_id == "normal":
30
+ from controlnet_aux.processor import NormalBaeDetector
31
+ self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
32
+ elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
33
+ self.processor = None
34
+ else:
35
+ raise ValueError(f"Unsupported processor_id: {processor_id}")
36
+ else:
37
+ self.processor = None
38
+
39
+ self.processor_id = processor_id
40
+ self.detect_resolution = detect_resolution
41
+
42
+ def to(self, device):
43
+ if hasattr(self.processor, "model") and hasattr(self.processor.model, "to"):
44
+
45
+ self.processor.model.to(device)
46
+
47
+ def __call__(self, image, mask=None):
48
+ width, height = image.size
49
+ if self.processor_id == "openpose":
50
+ kwargs = {
51
+ "include_body": True,
52
+ "include_hand": True,
53
+ "include_face": True
54
+ }
55
+ else:
56
+ kwargs = {}
57
+ if self.processor is not None:
58
+ detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
59
+ image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
60
+ image = image.resize((width, height))
61
+ return image
62
+
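
A short usage sketch for Annotator, assuming controlnet_aux is installed; the input path is hypothetical. The "canny" processor needs no pretrained weights, so it runs without downloading models/Annotators.

from PIL import Image
from diffsynth.controlnets import Annotator

annotator = Annotator(processor_id="canny", device="cpu")
image = Image.open("input.png").convert("RGB")  # hypothetical input image
control_image = annotator(image)                # returned at the input size
control_image.save("canny_map.png")
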
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .video import VideoData, save_video, save_frames
diffsynth/data/simple_text_image.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch, os, torchvision
2
+ from torchvision import transforms
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
+
7
+
8
+ class TextImageDataset(torch.utils.data.Dataset):
9
+ def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
10
+ self.steps_per_epoch = steps_per_epoch
11
+ metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
12
+ self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
13
+ self.text = metadata["text"].to_list()
14
+ self.height = height
15
+ self.width = width
16
+ self.image_processor = transforms.Compose(
17
+ [
18
+ transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
19
+ transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize([0.5], [0.5]),
22
+ ]
23
+ )
24
+
25
+
26
+ def __getitem__(self, index):
27
+ data_id = torch.randint(0, len(self.path), (1,))[0]
28
+ data_id = (data_id + index) % len(self.path) # For fixed seed.
29
+ text = self.text[data_id]
30
+ image = Image.open(self.path[data_id]).convert("RGB")
31
+ target_height, target_width = self.height, self.width
32
+ width, height = image.size
33
+ scale = max(target_width / width, target_height / height)
34
+ shape = [round(height * scale), round(width * scale)]
35
+ image = torchvision.transforms.functional.resize(image, shape, interpolation=transforms.InterpolationMode.BILINEAR)
36
+ image = self.image_processor(image)
37
+ return {"text": text, "image": image}
38
+
39
+
40
+ def __len__(self):
41
+ return self.steps_per_epoch
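
A usage sketch, assuming the dataset folder contains train/metadata.csv with file_name and text columns next to the images; the path below is hypothetical.

import torch
from diffsynth.data.simple_text_image import TextImageDataset

dataset = TextImageDataset("data/my_dataset", steps_per_epoch=1000, height=512, width=512)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([4, 3, 512, 512])
print(batch["text"][:2])     # two caption strings
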
diffsynth/data/video.py ADDED
@@ -0,0 +1,148 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+
6
+
7
+ class LowMemoryVideo:
8
+ def __init__(self, file_name):
9
+ self.reader = imageio.get_reader(file_name)
10
+
11
+ def __len__(self):
12
+ return self.reader.count_frames()
13
+
14
+ def __getitem__(self, item):
15
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16
+
17
+ def __del__(self):
18
+ self.reader.close()
19
+
20
+
21
+ def split_file_name(file_name):
22
+ result = []
23
+ number = -1
24
+ for i in file_name:
25
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
26
+ if number == -1:
27
+ number = 0
28
+ number = number*10 + ord(i) - ord("0")
29
+ else:
30
+ if number != -1:
31
+ result.append(number)
32
+ number = -1
33
+ result.append(i)
34
+ if number != -1:
35
+ result.append(number)
36
+ result = tuple(result)
37
+ return result
38
+
39
+
40
+ def search_for_images(folder):
41
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
42
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
43
+ file_list = [i[1] for i in sorted(file_list)]
44
+ file_list = [os.path.join(folder, i) for i in file_list]
45
+ return file_list
46
+
47
+
48
+ class LowMemoryImageFolder:
49
+ def __init__(self, folder, file_list=None):
50
+ if file_list is None:
51
+ self.file_list = search_for_images(folder)
52
+ else:
53
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
54
+
55
+ def __len__(self):
56
+ return len(self.file_list)
57
+
58
+ def __getitem__(self, item):
59
+ return Image.open(self.file_list[item]).convert("RGB")
60
+
61
+ def __del__(self):
62
+ pass
63
+
64
+
65
+ def crop_and_resize(image, height, width):
66
+ image = np.array(image)
67
+ image_height, image_width, _ = image.shape
68
+ if image_height / image_width < height / width:
69
+ cropped_width = int(image_height / height * width)
70
+ left = (image_width - cropped_width) // 2
71
+ image = image[:, left: left+cropped_width]
72
+ image = Image.fromarray(image).resize((width, height))
73
+ else:
74
+ cropped_height = int(image_width / width * height)
75
+ top = (image_height - cropped_height) // 2
76
+ image = image[top: top+cropped_height, :]
77
+ image = Image.fromarray(image).resize((width, height))
78
+ return image
79
+
80
+
81
+ class VideoData:
82
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
83
+ if video_file is not None:
84
+ self.data_type = "video"
85
+ self.data = LowMemoryVideo(video_file, **kwargs)
86
+ elif image_folder is not None:
87
+ self.data_type = "images"
88
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
89
+ else:
90
+ raise ValueError("Cannot open video or image folder")
91
+ self.length = None
92
+ self.set_shape(height, width)
93
+
94
+ def raw_data(self):
95
+ frames = []
96
+ for i in range(self.__len__()):
97
+ frames.append(self.__getitem__(i))
98
+ return frames
99
+
100
+ def set_length(self, length):
101
+ self.length = length
102
+
103
+ def set_shape(self, height, width):
104
+ self.height = height
105
+ self.width = width
106
+
107
+ def __len__(self):
108
+ if self.length is None:
109
+ return len(self.data)
110
+ else:
111
+ return self.length
112
+
113
+ def shape(self):
114
+ if self.height is not None and self.width is not None:
115
+ return self.height, self.width
116
+ else:
117
+ height, width, _ = self.__getitem__(0).shape
118
+ return height, width
119
+
120
+ def __getitem__(self, item):
121
+ frame = self.data.__getitem__(item)
122
+ width, height = frame.size
123
+ if self.height is not None and self.width is not None:
124
+ if self.height != height or self.width != width:
125
+ frame = crop_and_resize(frame, self.height, self.width)
126
+ return frame
127
+
128
+ def __del__(self):
129
+ pass
130
+
131
+ def save_images(self, folder):
132
+ os.makedirs(folder, exist_ok=True)
133
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
134
+ frame = self.__getitem__(i)
135
+ frame.save(os.path.join(folder, f"{i}.png"))
136
+
137
+
138
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
139
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
140
+ for frame in tqdm(frames, desc="Saving video"):
141
+ frame = np.array(frame)
142
+ writer.append_data(frame)
143
+ writer.close()
144
+
145
+ def save_frames(frames, save_path):
146
+ os.makedirs(save_path, exist_ok=True)
147
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
148
+ frame.save(os.path.join(save_path, f"{i}.png"))
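
A short, hypothetical usage sketch of the `VideoData` interface defined above; the file names are placeholders.

```python
# Hypothetical usage sketch for VideoData / save_video; file names are placeholders.
from diffsynth.data.video import VideoData, save_video

video = VideoData(video_file="input.mp4", height=512, width=512)  # frames are center-cropped and resized
frames = [video[i] for i in range(len(video))]                    # list of RGB PIL.Image frames

save_video(frames, "output.mp4", fps=30, quality=9)
video.save_images("frames_out")  # additionally dump each frame as <index>.png
```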
diffsynth/extensions/ESRGAN/__init__.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from einops import repeat
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+
7
+ class ResidualDenseBlock(torch.nn.Module):
8
+
9
+ def __init__(self, num_feat=64, num_grow_ch=32):
10
+ super(ResidualDenseBlock, self).__init__()
11
+ self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
12
+ self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
13
+ self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
14
+ self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
15
+ self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
16
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
17
+
18
+ def forward(self, x):
19
+ x1 = self.lrelu(self.conv1(x))
20
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
21
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
22
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
23
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
24
+ return x5 * 0.2 + x
25
+
26
+
27
+ class RRDB(torch.nn.Module):
28
+
29
+ def __init__(self, num_feat, num_grow_ch=32):
30
+ super(RRDB, self).__init__()
31
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
32
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
33
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
34
+
35
+ def forward(self, x):
36
+ out = self.rdb1(x)
37
+ out = self.rdb2(out)
38
+ out = self.rdb3(out)
39
+ return out * 0.2 + x
40
+
41
+
42
+ class RRDBNet(torch.nn.Module):
43
+
44
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
45
+ super(RRDBNet, self).__init__()
46
+ self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
47
+ self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
48
+ self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
49
+ # upsample
50
+ self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
51
+ self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
52
+ self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
53
+ self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
54
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
55
+
56
+ def forward(self, x):
57
+ feat = x
58
+ feat = self.conv_first(feat)
59
+ body_feat = self.conv_body(self.body(feat))
60
+ feat = feat + body_feat
61
+ # upsample
62
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
63
+ feat = self.lrelu(self.conv_up1(feat))
64
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
65
+ feat = self.lrelu(self.conv_up2(feat))
66
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
67
+ return out
68
+
69
+ @staticmethod
70
+ def state_dict_converter():
71
+ return RRDBNetStateDictConverter()
72
+
73
+
74
+ class RRDBNetStateDictConverter:
75
+ def __init__(self):
76
+ pass
77
+
78
+ def from_diffusers(self, state_dict):
79
+ return state_dict, {"upcast_to_float32": True}
80
+
81
+ def from_civitai(self, state_dict):
82
+ return state_dict, {"upcast_to_float32": True}
83
+
84
+
85
+ class ESRGAN(torch.nn.Module):
86
+ def __init__(self, model):
87
+ super().__init__()
88
+ self.model = model
89
+
90
+ @staticmethod
91
+ def from_model_manager(model_manager):
92
+ return ESRGAN(model_manager.fetch_model("esrgan"))
93
+
94
+ def process_image(self, image):
95
+ image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
96
+ return image
97
+
98
+ def process_images(self, images):
99
+ images = [self.process_image(image) for image in images]
100
+ images = torch.stack(images)
101
+ return images
102
+
103
+ def decode_images(self, images):
104
+ images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
105
+ images = [Image.fromarray(image) for image in images]
106
+ return images
107
+
108
+ @torch.no_grad()
109
+ def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
110
+ if not isinstance(images, list):
111
+ images = [images]
112
+ is_single_image = True
113
+ else:
114
+ is_single_image = False
115
+
116
+ # Preprocess
117
+ input_tensor = self.process_images(images)
118
+
119
+ # Interpolate
120
+ output_tensor = []
121
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
122
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
123
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
124
+ batch_input_tensor = batch_input_tensor.to(
125
+ device=self.model.conv_first.weight.device,
126
+ dtype=self.model.conv_first.weight.dtype)
127
+ batch_output_tensor = self.model(batch_input_tensor)
128
+ output_tensor.append(batch_output_tensor.cpu())
129
+
130
+ # Output
131
+ output_tensor = torch.concat(output_tensor, dim=0)
132
+
133
+ # To images
134
+ output_images = self.decode_images(output_tensor)
135
+ if is_single_image:
136
+ output_images = output_images[0]
137
+ return output_images
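
A hypothetical sketch of using this ESRGAN wrapper directly; the checkpoint path is an assumption and the weights must match `RRDBNet`'s layout (in the full pipeline the model would normally come from the model manager via `ESRGAN.from_model_manager`).

```python
# Hypothetical usage sketch; the checkpoint path is an assumption.
import torch
from PIL import Image
from diffsynth.extensions.ESRGAN import RRDBNet, ESRGAN

model = RRDBNet()
state_dict = torch.load("esrgan.pth", map_location="cpu")  # placeholder checkpoint
model.load_state_dict(state_dict, strict=False)
model.eval()

upscaler = ESRGAN(model)
image = Image.open("input.png").convert("RGB")
upscaled = upscaler.upscale(image)  # 4x resolution: two nearest-neighbor 2x stages plus convolutions
upscaled.save("output.png")
```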
diffsynth/extensions/FastBlend/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ from .runners.fast import TableManager, PyramidPatchMatcher
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cupy as cp
5
+
6
+
7
+ class FastBlendSmoother:
8
+ def __init__(self):
9
+ self.batch_size = 8
10
+ self.window_size = 64
11
+ self.ebsynth_config = {
12
+ "minimum_patch_size": 5,
13
+ "threads_per_block": 8,
14
+ "num_iter": 5,
15
+ "gpu_id": 0,
16
+ "guide_weight": 10.0,
17
+ "initialize": "identity",
18
+ "tracking_window_size": 0,
19
+ }
20
+
21
+ @staticmethod
22
+ def from_model_manager(model_manager):
23
+ # TODO: fetch GPU ID from model_manager
24
+ return FastBlendSmoother()
25
+
26
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
27
+ frames_guide = [np.array(frame) for frame in frames_guide]
28
+ frames_style = [np.array(frame) for frame in frames_style]
29
+ table_manager = TableManager()
30
+ patch_match_engine = PyramidPatchMatcher(
31
+ image_height=frames_style[0].shape[0],
32
+ image_width=frames_style[0].shape[1],
33
+ channel=3,
34
+ **ebsynth_config
35
+ )
36
+ # left part
37
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
38
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
39
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
40
+ # right part
41
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
42
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
43
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
44
+ # merge
45
+ frames = []
46
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
47
+ weight_m = -1  # the style frame itself is counted once in each table, so subtract one copy
48
+ weight = weight_l + weight_m + weight_r
49
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
50
+ frames.append(frame)
51
+ frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
52
+ return frames
53
+
54
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
55
+ frames = self.run(
56
+ original_frames, rendered_frames,
57
+ self.batch_size, self.window_size, self.ebsynth_config
58
+ )
59
+ mempool = cp.get_default_memory_pool()
60
+ pinned_mempool = cp.get_default_pinned_memory_pool()
61
+ mempool.free_all_blocks()
62
+ pinned_mempool.free_all_blocks()
63
+ return frames
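
A hypothetical end-to-end sketch of `FastBlendSmoother` (requires a CUDA device with cupy installed); the file names and parameters are assumptions.

```python
# Hypothetical usage sketch for FastBlendSmoother; file names and parameters are assumptions.
from diffsynth.data.video import VideoData, save_video
from diffsynth.extensions.FastBlend import FastBlendSmoother

guide = VideoData(video_file="guide.mp4").raw_data()    # original frames that define the motion
styled = VideoData(video_file="styled.mp4").raw_data()  # per-frame stylized frames to smooth

smoother = FastBlendSmoother()
smoother.window_size = 32  # blend each frame with up to 32 neighbors on each side
smooth_frames = smoother(styled, original_frames=guide)

save_video(smooth_frames, "smoothed.mp4", fps=30)
```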
diffsynth/extensions/FastBlend/api.py ADDED
@@ -0,0 +1,397 @@
1
+ from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
2
+ from .data import VideoData, get_video_fps, save_video, search_for_images
3
+ import os
4
+ import gradio as gr
5
+
6
+
7
+ def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
8
+ frames_guide = VideoData(video_guide, video_guide_folder)
9
+ frames_style = VideoData(video_style, video_style_folder)
10
+ message = ""
11
+ if len(frames_guide) < len(frames_style):
12
+ message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
13
+ frames_style.set_length(len(frames_guide))
14
+ elif len(frames_guide) > len(frames_style):
15
+ message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
16
+ frames_guide.set_length(len(frames_style))
17
+ height_guide, width_guide = frames_guide.shape()
18
+ height_style, width_style = frames_style.shape()
19
+ if height_guide != height_style or width_guide != width_style:
20
+ message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
21
+ frames_style.set_shape(height_guide, width_guide)
22
+ return frames_guide, frames_style, message
23
+
24
+
25
+ def smooth_video(
26
+ video_guide,
27
+ video_guide_folder,
28
+ video_style,
29
+ video_style_folder,
30
+ mode,
31
+ window_size,
32
+ batch_size,
33
+ tracking_window_size,
34
+ output_path,
35
+ fps,
36
+ minimum_patch_size,
37
+ num_iter,
38
+ guide_weight,
39
+ initialize,
40
+ progress = None,
41
+ ):
42
+ # input
43
+ frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
44
+ if len(message) > 0:
45
+ print(message)
46
+ # output
47
+ if output_path == "":
48
+ if video_style is None:
49
+ output_path = os.path.join(video_style_folder, "output")
50
+ else:
51
+ output_path = os.path.join(os.path.split(video_style)[0], "output")
52
+ os.makedirs(output_path, exist_ok=True)
53
+ print("No valid output_path. Your video will be saved here:", output_path)
54
+ elif not os.path.exists(output_path):
55
+ os.makedirs(output_path, exist_ok=True)
56
+ print("Your video will be saved here:", output_path)
57
+ frames_path = os.path.join(output_path, "frames")
58
+ video_path = os.path.join(output_path, "video.mp4")
59
+ os.makedirs(frames_path, exist_ok=True)
60
+ # process
61
+ if mode == "Fast" or mode == "Balanced":
62
+ tracking_window_size = 0
63
+ ebsynth_config = {
64
+ "minimum_patch_size": minimum_patch_size,
65
+ "threads_per_block": 8,
66
+ "num_iter": num_iter,
67
+ "gpu_id": 0,
68
+ "guide_weight": guide_weight,
69
+ "initialize": initialize,
70
+ "tracking_window_size": tracking_window_size,
71
+ }
72
+ if mode == "Fast":
73
+ FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
74
+ elif mode == "Balanced":
75
+ BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
76
+ elif mode == "Accurate":
77
+ AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
78
+ # output
79
+ try:
80
+ fps = int(fps)
81
+ except:
82
+ fps = get_video_fps(video_style) if video_style is not None else 30
83
+ print("Fps:", fps)
84
+ print("Saving video...")
85
+ video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
86
+ print("Success!")
87
+ print("Your frames are here:", frames_path)
88
+ print("Your video is here:", video_path)
89
+ return output_path, fps, video_path
90
+
91
+
92
+ class KeyFrameMatcher:
93
+ def __init__(self):
94
+ pass
95
+
96
+ def extract_number_from_filename(self, file_name):
97
+ result = []
98
+ number = -1
99
+ for i in file_name:
100
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
101
+ if number == -1:
102
+ number = 0
103
+ number = number*10 + ord(i) - ord("0")
104
+ else:
105
+ if number != -1:
106
+ result.append(number)
107
+ number = -1
108
+ if number != -1:
109
+ result.append(number)
110
+ result = tuple(result)
111
+ return result
112
+
113
+ def extract_number_from_filenames(self, file_names):
114
+ numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
115
+ min_length = min(len(i) for i in numbers)
116
+ for i in range(min_length-1, -1, -1):
117
+ if len(set(number[i] for number in numbers))==len(file_names):
118
+ return [number[i] for number in numbers]
119
+ return list(range(len(file_names)))
120
+
121
+ def match_using_filename(self, file_names_a, file_names_b):
122
+ file_names_b_set = set(file_names_b)
123
+ matched_file_name = []
124
+ for file_name in file_names_a:
125
+ if file_name not in file_names_b_set:
126
+ matched_file_name.append(None)
127
+ else:
128
+ matched_file_name.append(file_name)
129
+ return matched_file_name
130
+
131
+ def match_using_numbers(self, file_names_a, file_names_b):
132
+ numbers_a = self.extract_number_from_filenames(file_names_a)
133
+ numbers_b = self.extract_number_from_filenames(file_names_b)
134
+ numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
135
+ matched_file_name = []
136
+ for number in numbers_a:
137
+ if number in numbers_b_dict:
138
+ matched_file_name.append(numbers_b_dict[number])
139
+ else:
140
+ matched_file_name.append(None)
141
+ return matched_file_name
142
+
143
+ def match_filenames(self, file_names_a, file_names_b):
144
+ matched_file_name = self.match_using_filename(file_names_a, file_names_b)
145
+ if sum([i is not None for i in matched_file_name]) > 0:
146
+ return matched_file_name
147
+ matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
148
+ return matched_file_name
149
+
150
+
151
+ def detect_frames(frames_path, keyframes_path):
152
+ if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
153
+ return "Please input the directory of guide video and rendered frames"
154
+ elif not os.path.exists(frames_path):
155
+ return "Please input the directory of guide video"
156
+ elif not os.path.exists(keyframes_path):
157
+ return "Please input the directory of rendered frames"
158
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
159
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
160
+ if len(frames)==0:
161
+ return f"No images detected in {frames_path}"
162
+ if len(keyframes)==0:
163
+ return f"No images detected in {keyframes_path}"
164
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
165
+ max_filename_length = max([len(i) for i in frames])
166
+ if sum([i is not None for i in matched_keyframes])==0:
167
+ message = ""
168
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
169
+ message += frame + " " * (max_filename_length - len(frame) + 1)
170
+ message += "--> No matched keyframes\n"
171
+ else:
172
+ message = ""
173
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
174
+ message += frame + " " * (max_filename_length - len(frame) + 1)
175
+ if matched_keyframe is None:
176
+ message += "--> [to be rendered]\n"
177
+ else:
178
+ message += f"--> {matched_keyframe}\n"
179
+ return message
180
+
181
+
182
+ def check_input_for_interpolating(frames_path, keyframes_path):
183
+ # search for images
184
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
185
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
186
+ # match frames
187
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
188
+ file_list = [file_name for file_name in matched_keyframes if file_name is not None]
189
+ index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
190
+ frames_guide = VideoData(None, frames_path)
191
+ frames_style = VideoData(None, keyframes_path, file_list=file_list)
192
+ # match shape
193
+ message = ""
194
+ height_guide, width_guide = frames_guide.shape()
195
+ height_style, width_style = frames_style.shape()
196
+ if height_guide != height_style or width_guide != width_style:
197
+ message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
198
+ frames_style.set_shape(height_guide, width_guide)
199
+ return frames_guide, frames_style, index_style, message
200
+
201
+
202
+ def interpolate_video(
203
+ frames_path,
204
+ keyframes_path,
205
+ output_path,
206
+ fps,
207
+ batch_size,
208
+ tracking_window_size,
209
+ minimum_patch_size,
210
+ num_iter,
211
+ guide_weight,
212
+ initialize,
213
+ progress = None,
214
+ ):
215
+ # input
216
+ frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
217
+ if len(message) > 0:
218
+ print(message)
219
+ # output
220
+ if output_path == "":
221
+ output_path = os.path.join(keyframes_path, "output")
222
+ os.makedirs(output_path, exist_ok=True)
223
+ print("No valid output_path. Your video will be saved here:", output_path)
224
+ elif not os.path.exists(output_path):
225
+ os.makedirs(output_path, exist_ok=True)
226
+ print("Your video will be saved here:", output_path)
227
+ output_frames_path = os.path.join(output_path, "frames")
228
+ output_video_path = os.path.join(output_path, "video.mp4")
229
+ os.makedirs(output_frames_path, exist_ok=True)
230
+ # process
231
+ ebsynth_config = {
232
+ "minimum_patch_size": minimum_patch_size,
233
+ "threads_per_block": 8,
234
+ "num_iter": num_iter,
235
+ "gpu_id": 0,
236
+ "guide_weight": guide_weight,
237
+ "initialize": initialize,
238
+ "tracking_window_size": tracking_window_size
239
+ }
240
+ if len(index_style)==1:
241
+ InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
242
+ else:
243
+ InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
244
+ try:
245
+ fps = int(fps)
246
+ except:
247
+ fps = 30
248
+ print("Fps:", fps)
249
+ print("Saving video...")
250
+ video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
251
+ print("Success!")
252
+ print("Your frames are here:", output_frames_path)
253
+ print("Your video is here:", video_path)
254
+ return output_path, fps, video_path
255
+
256
+
257
+ def on_ui_tabs():
258
+ with gr.Blocks(analytics_enabled=False) as ui_component:
259
+ with gr.Tab("Blend"):
260
+ gr.Markdown("""
261
+ # Blend
262
+
263
+ Given a guide video and a style video, this algorithm makes the style video temporally consistent (fluent) according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see an example. Note that this extension doesn't support long videos; please use short clips (e.g., a few seconds). The algorithm is mainly designed for 512*512 resolution; please use a larger `Minimum patch size` for higher resolutions.
264
+ """)
265
+ with gr.Row():
266
+ with gr.Column():
267
+ with gr.Tab("Guide video"):
268
+ video_guide = gr.Video(label="Guide video")
269
+ with gr.Tab("Guide video (images format)"):
270
+ video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
271
+ with gr.Column():
272
+ with gr.Tab("Style video"):
273
+ video_style = gr.Video(label="Style video")
274
+ with gr.Tab("Style video (images format)"):
275
+ video_style_folder = gr.Textbox(label="Style video (images format)", value="")
276
+ with gr.Column():
277
+ output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
278
+ fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
279
+ video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
280
+ btn = gr.Button(value="Blend")
281
+ with gr.Row():
282
+ with gr.Column():
283
+ gr.Markdown("# Settings")
284
+ mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
285
+ window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
286
+ batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
287
+ tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
288
+ gr.Markdown("## Advanced Settings")
289
+ minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
290
+ num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
291
+ guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
292
+ initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
293
+ with gr.Column():
294
+ gr.Markdown("""
295
+ # Reference
296
+
297
+ * Output directory: the directory to save the video.
298
+ * Inference mode
299
+
300
+ |Mode|Time|Memory|Quality|Frame by frame output|Description|
301
+ |-|-|-|-|-|-|
302
+ |Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires a lot of RAM but is fast.|
303
+ |Balanced|■■|■|■■|Yes|Blend the frames naively.|
304
+ |Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. Performance is best when [batch size] >= [sliding window size] * 2 + 1.|
305
+
306
+ * Sliding window size: our algorithm blends the frames within a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window makes the video more fluent but can introduce blur.
307
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
308
+ * Tracking window size (only for accurate mode): the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
309
+ * Advanced settings
310
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
311
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
312
+ * Guide weight: a parameter that determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
313
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
314
+ """)
315
+ btn.click(
316
+ smooth_video,
317
+ inputs=[
318
+ video_guide,
319
+ video_guide_folder,
320
+ video_style,
321
+ video_style_folder,
322
+ mode,
323
+ window_size,
324
+ batch_size,
325
+ tracking_window_size,
326
+ output_path,
327
+ fps,
328
+ minimum_patch_size,
329
+ num_iter,
330
+ guide_weight,
331
+ initialize
332
+ ],
333
+ outputs=[output_path, fps, video_output]
334
+ )
335
+ with gr.Tab("Interpolate"):
336
+ gr.Markdown("""
337
+ # Interpolate
338
+
339
+ Given a guide video and some rendered keyframes, this algorithm renders the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see an example. The algorithm is experimental and has only been tested at 512*512 resolution.
340
+ """)
341
+ with gr.Row():
342
+ with gr.Column():
343
+ with gr.Row():
344
+ with gr.Column():
345
+ video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
346
+ with gr.Column():
347
+ rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
348
+ with gr.Row():
349
+ detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
350
+ video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
351
+ rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
352
+ with gr.Column():
353
+ output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
354
+ fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
355
+ video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
356
+ btn_ = gr.Button(value="Interpolate")
357
+ with gr.Row():
358
+ with gr.Column():
359
+ gr.Markdown("# Settings")
360
+ batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
361
+ tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
362
+ gr.Markdown("## Advanced Settings")
363
+ minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
364
+ num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
365
+ guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
366
+ initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
367
+ with gr.Column():
368
+ gr.Markdown("""
369
+ # Reference
370
+
371
+ * Output directory: the directory to save the video.
372
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
373
+ * Tracking window size: the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
374
+ * Advanced settings
375
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
376
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
377
+ * Guide weight: a parameter that determines how strongly the motion features of the guide video are applied to the style video. (Default: 10)
378
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
379
+ """)
380
+ btn_.click(
381
+ interpolate_video,
382
+ inputs=[
383
+ video_guide_folder_,
384
+ rendered_keyframes_,
385
+ output_path_,
386
+ fps_,
387
+ batch_size_,
388
+ tracking_window_size_,
389
+ minimum_patch_size_,
390
+ num_iter_,
391
+ guide_weight_,
392
+ initialize_,
393
+ ],
394
+ outputs=[output_path_, fps_, video_output_]
395
+ )
396
+
397
+ return [(ui_component, "FastBlend", "FastBlend_ui")]
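
`on_ui_tabs` follows the WebUI extension convention of returning `(component, title, elem_id)` tuples; below is a hypothetical sketch of serving the same interface as a standalone Gradio app.

```python
# Hypothetical sketch: run the FastBlend UI outside of a WebUI host.
from diffsynth.extensions.FastBlend.api import on_ui_tabs

blocks, title, elem_id = on_ui_tabs()[0]
blocks.launch()  # serves the Blend / Interpolate tabs locally
```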
diffsynth/extensions/FastBlend/cupy_kernels.py ADDED
@@ -0,0 +1,119 @@
1
+ import cupy as cp
2
+
3
+ remapping_kernel = cp.RawKernel(r'''
4
+ extern "C" __global__
5
+ void remap(
6
+ const int height,
7
+ const int width,
8
+ const int channel,
9
+ const int patch_size,
10
+ const int pad_size,
11
+ const float* source_style,
12
+ const int* nnf,
13
+ float* target_style
14
+ ) {
15
+ const int r = (patch_size - 1) / 2;
16
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
17
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
18
+ if (x >= height or y >= width) return;
19
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
20
+ const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
21
+ const int min_px = x < r ? -x : -r;
22
+ const int max_px = x + r > height - 1 ? height - 1 - x : r;
23
+ const int min_py = y < r ? -y : -r;
24
+ const int max_py = y + r > width - 1 ? width - 1 - y : r;
25
+ int num = 0;
26
+ for (int px = min_px; px <= max_px; px++){
27
+ for (int py = min_py; py <= max_py; py++){
28
+ const int nid = (x + px) * width + y + py;
29
+ const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
30
+ const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
31
+ if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
32
+ const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
33
+ num++;
34
+ for (int c = 0; c < channel; c++){
35
+ target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
36
+ }
37
+ }
38
+ }
39
+ for (int c = 0; c < channel; c++){
40
+ target_style[z + pid * channel + c] /= num;
41
+ }
42
+ }
43
+ ''', 'remap')
44
+
45
+
46
+ patch_error_kernel = cp.RawKernel(r'''
47
+ extern "C" __global__
48
+ void patch_error(
49
+ const int height,
50
+ const int width,
51
+ const int channel,
52
+ const int patch_size,
53
+ const int pad_size,
54
+ const float* source,
55
+ const int* nnf,
56
+ const float* target,
57
+ float* error
58
+ ) {
59
+ const int r = (patch_size - 1) / 2;
60
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
61
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
62
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
63
+ if (x >= height or y >= width) return;
64
+ const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
65
+ const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
66
+ float e = 0;
67
+ for (int px = -r; px <= r; px++){
68
+ for (int py = -r; py <= r; py++){
69
+ const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
70
+ const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
71
+ for (int c = 0; c < channel; c++){
72
+ const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
73
+ e += diff * diff;
74
+ }
75
+ }
76
+ }
77
+ error[blockIdx.z * height * width + x * width + y] = e;
78
+ }
79
+ ''', 'patch_error')
80
+
81
+
82
+ pairwise_patch_error_kernel = cp.RawKernel(r'''
83
+ extern "C" __global__
84
+ void pairwise_patch_error(
85
+ const int height,
86
+ const int width,
87
+ const int channel,
88
+ const int patch_size,
89
+ const int pad_size,
90
+ const float* source_a,
91
+ const int* nnf_a,
92
+ const float* source_b,
93
+ const int* nnf_b,
94
+ float* error
95
+ ) {
96
+ const int r = (patch_size - 1) / 2;
97
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
98
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
99
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
100
+ if (x >= height or y >= width) return;
101
+ const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
102
+ const int x_a = nnf_a[z_nnf + 0];
103
+ const int y_a = nnf_a[z_nnf + 1];
104
+ const int x_b = nnf_b[z_nnf + 0];
105
+ const int y_b = nnf_b[z_nnf + 1];
106
+ float e = 0;
107
+ for (int px = -r; px <= r; px++){
108
+ for (int py = -r; py <= r; py++){
109
+ const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
110
+ const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
111
+ for (int c = 0; c < channel; c++){
112
+ const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
113
+ e += diff * diff;
114
+ }
115
+ }
116
+ }
117
+ error[blockIdx.z * height * width + x * width + y] = e;
118
+ }
119
+ ''', 'pairwise_patch_error')
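
To illustrate the launch convention these kernels expect (2D thread blocks, a batch dimension on the grid, padded float32 images, and an int32 NNF), here is a hypothetical sketch that runs `remap` with an identity nearest-neighbor field; the sizes are arbitrary assumptions.

```python
# Hypothetical sketch: launch the `remap` kernel with an identity NNF (sizes are assumptions).
import cupy as cp
from diffsynth.extensions.FastBlend.cupy_kernels import remapping_kernel

height, width, channel = 64, 64, 3
patch_size, pad_size, threads = 5, 2, 8

source = cp.random.rand(1, height + 2 * pad_size, width + 2 * pad_size, channel).astype(cp.float32)
nnf = cp.stack([
    cp.repeat(cp.arange(height), width).reshape(height, width),
    cp.tile(cp.arange(width), height).reshape(height, width),
], axis=2).astype(cp.int32)[None]  # identity field, shape (1, H, W, 2)
target = cp.zeros_like(source)

grid = ((height + threads - 1) // threads, (width + threads - 1) // threads, 1)  # last axis = batch
remapping_kernel(grid, (threads, threads), (height, width, channel, patch_size, pad_size, source, nnf, target))
# With an identity NNF each output pixel averages copies of itself, so `target` stays close to `source`.
```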
diffsynth/extensions/FastBlend/data.py ADDED
@@ -0,0 +1,146 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def read_video(file_name):
7
+ reader = imageio.get_reader(file_name)
8
+ video = []
9
+ for frame in reader:
10
+ frame = np.array(frame)
11
+ video.append(frame)
12
+ reader.close()
13
+ return video
14
+
15
+
16
+ def get_video_fps(file_name):
17
+ reader = imageio.get_reader(file_name)
18
+ fps = reader.get_meta_data()["fps"]
19
+ reader.close()
20
+ return fps
21
+
22
+
23
+ def save_video(frames_path, video_path, num_frames, fps):
24
+ writer = imageio.get_writer(video_path, fps=fps, quality=9)
25
+ for i in range(num_frames):
26
+ frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
27
+ writer.append_data(frame)
28
+ writer.close()
29
+ return video_path
30
+
31
+
32
+ class LowMemoryVideo:
33
+ def __init__(self, file_name):
34
+ self.reader = imageio.get_reader(file_name)
35
+
36
+ def __len__(self):
37
+ return self.reader.count_frames()
38
+
39
+ def __getitem__(self, item):
40
+ return np.array(self.reader.get_data(item))
41
+
42
+ def __del__(self):
43
+ self.reader.close()
44
+
45
+
46
+ def split_file_name(file_name):
47
+ result = []
48
+ number = -1
49
+ for i in file_name:
50
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
51
+ if number == -1:
52
+ number = 0
53
+ number = number*10 + ord(i) - ord("0")
54
+ else:
55
+ if number != -1:
56
+ result.append(number)
57
+ number = -1
58
+ result.append(i)
59
+ if number != -1:
60
+ result.append(number)
61
+ result = tuple(result)
62
+ return result
63
+
64
+
65
+ def search_for_images(folder):
66
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
67
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
68
+ file_list = [i[1] for i in sorted(file_list)]
69
+ file_list = [os.path.join(folder, i) for i in file_list]
70
+ return file_list
71
+
72
+
73
+ def read_images(folder):
74
+ file_list = search_for_images(folder)
75
+ frames = [np.array(Image.open(i)) for i in file_list]
76
+ return frames
77
+
78
+
79
+ class LowMemoryImageFolder:
80
+ def __init__(self, folder, file_list=None):
81
+ if file_list is None:
82
+ self.file_list = search_for_images(folder)
83
+ else:
84
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
85
+
86
+ def __len__(self):
87
+ return len(self.file_list)
88
+
89
+ def __getitem__(self, item):
90
+ return np.array(Image.open(self.file_list[item]))
91
+
92
+ def __del__(self):
93
+ pass
94
+
95
+
96
+ class VideoData:
97
+ def __init__(self, video_file, image_folder, **kwargs):
98
+ if video_file is not None:
99
+ self.data_type = "video"
100
+ self.data = LowMemoryVideo(video_file, **kwargs)
101
+ elif image_folder is not None:
102
+ self.data_type = "images"
103
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
104
+ else:
105
+ raise ValueError("Cannot open video or image folder")
106
+ self.length = None
107
+ self.height = None
108
+ self.width = None
109
+
110
+ def raw_data(self):
111
+ frames = []
112
+ for i in range(self.__len__()):
113
+ frames.append(self.__getitem__(i))
114
+ return frames
115
+
116
+ def set_length(self, length):
117
+ self.length = length
118
+
119
+ def set_shape(self, height, width):
120
+ self.height = height
121
+ self.width = width
122
+
123
+ def __len__(self):
124
+ if self.length is None:
125
+ return len(self.data)
126
+ else:
127
+ return self.length
128
+
129
+ def shape(self):
130
+ if self.height is not None and self.width is not None:
131
+ return self.height, self.width
132
+ else:
133
+ height, width, _ = self.__getitem__(0).shape
134
+ return height, width
135
+
136
+ def __getitem__(self, item):
137
+ frame = self.data.__getitem__(item)
138
+ height, width, _ = frame.shape
139
+ if self.height is not None and self.width is not None:
140
+ if self.height != height or self.width != width:
141
+ frame = Image.fromarray(frame).resize((self.width, self.height))
142
+ frame = np.array(frame)
143
+ return frame
144
+
145
+ def __del__(self):
146
+ pass
diffsynth/extensions/FastBlend/patch_match.py ADDED
@@ -0,0 +1,298 @@
1
+ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
2
+ import numpy as np
3
+ import cupy as cp
4
+ import cv2
5
+
6
+
7
+ class PatchMatcher:
8
+ def __init__(
9
+ self, height, width, channel, minimum_patch_size,
10
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
11
+ random_search_steps=3, random_search_range=4,
12
+ use_mean_target_style=False, use_pairwise_patch_error=False,
13
+ tracking_window_size=0
14
+ ):
15
+ self.height = height
16
+ self.width = width
17
+ self.channel = channel
18
+ self.minimum_patch_size = minimum_patch_size
19
+ self.threads_per_block = threads_per_block
20
+ self.num_iter = num_iter
21
+ self.gpu_id = gpu_id
22
+ self.guide_weight = guide_weight
23
+ self.random_search_steps = random_search_steps
24
+ self.random_search_range = random_search_range
25
+ self.use_mean_target_style = use_mean_target_style
26
+ self.use_pairwise_patch_error = use_pairwise_patch_error
27
+ self.tracking_window_size = tracking_window_size
28
+
29
+ self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
30
+ self.pad_size = self.patch_size_list[0] // 2
31
+ self.grid = (
32
+ (height + threads_per_block - 1) // threads_per_block,
33
+ (width + threads_per_block - 1) // threads_per_block
34
+ )
35
+ self.block = (threads_per_block, threads_per_block)
36
+
37
+ def pad_image(self, image):
38
+ return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
39
+
40
+ def unpad_image(self, image):
41
+ return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
42
+
43
+ def apply_nnf_to_image(self, nnf, source):
44
+ batch_size = source.shape[0]
45
+ target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
46
+ remapping_kernel(
47
+ self.grid + (batch_size,),
48
+ self.block,
49
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
50
+ )
51
+ return target
52
+
53
+ def get_patch_error(self, source, nnf, target):
54
+ batch_size = source.shape[0]
55
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
56
+ patch_error_kernel(
57
+ self.grid + (batch_size,),
58
+ self.block,
59
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
60
+ )
61
+ return error
62
+
63
+ def get_pairwise_patch_error(self, source, nnf):
64
+ batch_size = source.shape[0]//2
65
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
66
+ source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
67
+ source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
68
+ pairwise_patch_error_kernel(
69
+ self.grid + (batch_size,),
70
+ self.block,
71
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
72
+ )
73
+ error = error.repeat(2, axis=0)
74
+ return error
75
+
76
+ def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
77
+ error_guide = self.get_patch_error(source_guide, nnf, target_guide)
78
+ if self.use_mean_target_style:
79
+ target_style = self.apply_nnf_to_image(nnf, source_style)
80
+ target_style = target_style.mean(axis=0, keepdims=True)
81
+ target_style = target_style.repeat(source_guide.shape[0], axis=0)
82
+ if self.use_pairwise_patch_error:
83
+ error_style = self.get_pairwise_patch_error(source_style, nnf)
84
+ else:
85
+ error_style = self.get_patch_error(source_style, nnf, target_style)
86
+ error = error_guide * self.guide_weight + error_style
87
+ return error
88
+
89
+ def clamp_bound(self, nnf):
90
+ nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
91
+ nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
92
+ return nnf
93
+
94
+ def random_step(self, nnf, r):
95
+ batch_size = nnf.shape[0]
96
+ step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
97
+ upd_nnf = self.clamp_bound(nnf + step)
98
+ return upd_nnf
99
+
100
+ def neighboor_step(self, nnf, d):
101
+ if d==0:
102
+ upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
103
+ upd_nnf[:, :, :, 0] += 1
104
+ elif d==1:
105
+ upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
106
+ upd_nnf[:, :, :, 1] += 1
107
+ elif d==2:
108
+ upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
109
+ upd_nnf[:, :, :, 0] -= 1
110
+ elif d==3:
111
+ upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
112
+ upd_nnf[:, :, :, 1] -= 1
113
+ upd_nnf = self.clamp_bound(upd_nnf)
114
+ return upd_nnf
115
+
116
+ def shift_nnf(self, nnf, d):
117
+ if d>0:
118
+ d = min(nnf.shape[0], d)
119
+ upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
120
+ else:
121
+ d = max(-nnf.shape[0], d)
122
+ upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
123
+ return upd_nnf
124
+
125
+ def track_step(self, nnf, d):
126
+ if self.use_pairwise_patch_error:
127
+ upd_nnf = cp.zeros_like(nnf)
128
+ upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
129
+ upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
130
+ else:
131
+ upd_nnf = self.shift_nnf(nnf, d)
132
+ return upd_nnf
133
+
134
+ def C(self, n, m):
135
+ # not used
136
+ c = 1
137
+ for i in range(1, n+1):
138
+ c *= i
139
+ for i in range(1, m+1):
140
+ c //= i
141
+ for i in range(1, n-m+1):
142
+ c //= i
143
+ return c
144
+
145
+ def bezier_step(self, nnf, r):
146
+ # not used
147
+ n = r * 2 - 1
148
+ upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
149
+ for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
150
+ if d>0:
151
+ ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
152
+ elif d<0:
153
+ ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
154
+ upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
155
+ upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
156
+ return upd_nnf
157
+
158
+ def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
159
+ upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
160
+ upd_idx = (upd_err < err)
161
+ nnf[upd_idx] = upd_nnf[upd_idx]
162
+ err[upd_idx] = upd_err[upd_idx]
163
+ return nnf, err
164
+
165
+ def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
166
+ for d in cp.random.permutation(4):
167
+ upd_nnf = self.neighboor_step(nnf, d)
168
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
169
+ return nnf, err
170
+
171
+ def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
172
+ for i in range(self.random_search_steps):
173
+ upd_nnf = self.random_step(nnf, self.random_search_range)
174
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
175
+ return nnf, err
176
+
177
+ def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
178
+ for d in range(1, self.tracking_window_size + 1):
179
+ upd_nnf = self.track_step(nnf, d)
180
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
181
+ upd_nnf = self.track_step(nnf, -d)
182
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
183
+ return nnf, err
184
+
185
+ def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
186
+ nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
187
+ nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
188
+ nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
189
+ return nnf, err
190
+
191
+ def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
192
+ with cp.cuda.Device(self.gpu_id):
193
+ source_guide = self.pad_image(source_guide)
194
+ target_guide = self.pad_image(target_guide)
195
+ source_style = self.pad_image(source_style)
196
+ for it in range(self.num_iter):
197
+ self.patch_size = self.patch_size_list[it]
198
+ target_style = self.apply_nnf_to_image(nnf, source_style)
199
+ err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
200
+ nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
201
+ target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
202
+ return nnf, target_style
203
+
204
+
205
+ class PyramidPatchMatcher:
206
+ def __init__(
207
+ self, image_height, image_width, channel, minimum_patch_size,
208
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
209
+ use_mean_target_style=False, use_pairwise_patch_error=False,
210
+ tracking_window_size=0,
211
+ initialize="identity"
212
+ ):
213
+ maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
214
+ self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
215
+ self.pyramid_heights = []
216
+ self.pyramid_widths = []
217
+ self.patch_matchers = []
218
+ self.minimum_patch_size = minimum_patch_size
219
+ self.num_iter = num_iter
220
+ self.gpu_id = gpu_id
221
+ self.initialize = initialize
222
+ for level in range(self.pyramid_level):
223
+ height = image_height//(2**(self.pyramid_level - 1 - level))
224
+ width = image_width//(2**(self.pyramid_level - 1 - level))
225
+ self.pyramid_heights.append(height)
226
+ self.pyramid_widths.append(width)
227
+ self.patch_matchers.append(PatchMatcher(
228
+ height, width, channel, minimum_patch_size=minimum_patch_size,
229
+ threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
230
+ use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
231
+ tracking_window_size=tracking_window_size
232
+ ))
233
+
234
+ def resample_image(self, images, level):
235
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
236
+ images = images.get()
237
+ images_resample = []
238
+ for image in images:
239
+ image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
240
+ images_resample.append(image_resample)
241
+ images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
242
+ return images_resample
243
+
244
+ def initialize_nnf(self, batch_size):
245
+ if self.initialize == "random":
246
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
247
+ nnf = cp.stack([
248
+ cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
249
+ cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
250
+ ], axis=3)
251
+ elif self.initialize == "identity":
252
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
253
+ nnf = cp.stack([
254
+ cp.repeat(cp.arange(height), width).reshape(height, width),
255
+ cp.tile(cp.arange(width), height).reshape(height, width)
256
+ ], axis=2)
257
+ nnf = cp.stack([nnf] * batch_size)
258
+ else:
259
+ raise NotImplementedError()
260
+ return nnf
261
+
262
+ def update_nnf(self, nnf, level):
263
+ # upscale
264
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
265
+ nnf[:,[i for i in range(nnf.shape[1]) if i&1],:,0] += 1
266
+ nnf[:,:,[i for i in range(nnf.shape[2]) if i&1],1] += 1
267
+ # check whether this level is exactly 2x the previous one
268
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
269
+ if height != nnf.shape[1] or width != nnf.shape[2]:
270
+ nnf = nnf.get().astype(np.float32)
271
+ nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
272
+ nnf = cp.array(np.stack(nnf), dtype=cp.int32)
273
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
274
+ return nnf
275
+
276
+ def apply_nnf_to_image(self, nnf, image):
277
+ with cp.cuda.Device(self.gpu_id):
278
+ image = self.patch_matchers[-1].pad_image(image)
279
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
280
+ return image
281
+
282
+ def estimate_nnf(self, source_guide, target_guide, source_style):
283
+ with cp.cuda.Device(self.gpu_id):
284
+ if not isinstance(source_guide, cp.ndarray):
285
+ source_guide = cp.array(source_guide, dtype=cp.float32)
286
+ if not isinstance(target_guide, cp.ndarray):
287
+ target_guide = cp.array(target_guide, dtype=cp.float32)
288
+ if not isinstance(source_style, cp.ndarray):
289
+ source_style = cp.array(source_style, dtype=cp.float32)
290
+ for level in range(self.pyramid_level):
291
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
292
+ source_guide_ = self.resample_image(source_guide, level)
293
+ target_guide_ = self.resample_image(target_guide, level)
294
+ source_style_ = self.resample_image(source_style, level)
295
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
296
+ source_guide_, target_guide_, source_style_, nnf
297
+ )
298
+ return nnf.get(), target_style.get()
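
A hypothetical sketch of calling `PyramidPatchMatcher` on its own (requires a CUDA device with cupy); the random frames stand in for real guide/style images of matching size.

```python
# Hypothetical sketch: remap one stylized frame onto a new guide frame (random data as placeholders).
import numpy as np
from diffsynth.extensions.FastBlend.patch_match import PyramidPatchMatcher

height, width = 512, 512
source_guide = np.random.rand(1, height, width, 3).astype(np.float32) * 255  # e.g. original frame i
target_guide = np.random.rand(1, height, width, 3).astype(np.float32) * 255  # e.g. original frame j
source_style = np.random.rand(1, height, width, 3).astype(np.float32) * 255  # stylized frame i

matcher = PyramidPatchMatcher(image_height=height, image_width=width, channel=3, minimum_patch_size=5)
nnf, target_style = matcher.estimate_nnf(source_guide, target_guide, source_style)
# nnf: (1, H, W, 2) nearest-neighbor field; target_style: stylized frame i remapped onto frame j's layout
```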
diffsynth/extensions/FastBlend/runners/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .accurate import AccurateModeRunner
2
+ from .fast import FastModeRunner
3
+ from .balanced import BalancedModeRunner
4
+ from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class AccurateModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ use_mean_target_style=True,
18
+ **ebsynth_config
19
+ )
20
+ # run
21
+ n = len(frames_style)
22
+ for target in tqdm(range(n), desc=desc):
23
+ l, r = max(target - window_size, 0), min(target + window_size + 1, n)
24
+ remapped_frames = []
25
+ for i in range(l, r, batch_size):
26
+ j = min(i + batch_size, r)
27
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
28
+ target_guide = np.stack([frames_guide[target]] * (j - i))
29
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
30
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
31
+ remapped_frames.append(target_style)
32
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
33
+ frame = frame.clip(0, 255).astype("uint8")
34
+ if save_path is not None:
35
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/balanced.py ADDED
@@ -0,0 +1,46 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class BalancedModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ **ebsynth_config
18
+ )
19
+ # tasks
20
+ n = len(frames_style)
21
+ tasks = []
22
+ for target in range(n):
23
+ for source in range(target - window_size, target + window_size + 1):
24
+ if source >= 0 and source < n and source != target:
25
+ tasks.append((source, target))
26
+ # run
27
+ frames = [(None, 1) for i in range(n)]
28
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
29
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
30
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
31
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
32
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
33
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
34
+ for (source, target), result in zip(tasks_batch, target_style):
35
+ frame, weight = frames[target]
36
+ if frame is None:
37
+ frame = frames_style[target]
38
+ frames[target] = (
39
+ frame * (weight / (weight + 1)) + result / (weight + 1),
40
+ weight + 1
41
+ )
42
+ if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
43
+ frame = frame.clip(0, 255).astype("uint8")
44
+ if save_path is not None:
45
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
46
+ frames[target] = (None, 1)
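BalancedModeRunner streams (source, target) pairs through the matcher in fixed-size batches and folds each remapped result into a per-target running mean; frame * (weight / (weight + 1)) + result / (weight + 1) is the standard incremental-average update, and a frame is written out once every contribution in its window has arrived. A tiny sketch (illustrative only) checking that the incremental form equals the batch mean:

import numpy as np

contributions = [10.0, 20.0, 60.0]
frame, weight = contributions[0], 1            # first contribution seeds the accumulator
for result in contributions[1:]:
    frame = frame * (weight / (weight + 1)) + result / (weight + 1)
    weight += 1
assert np.isclose(frame, np.mean(contributions))   # both give 30.0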
diffsynth/extensions/FastBlend/runners/fast.py ADDED
@@ -0,0 +1,141 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import functools, os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class TableManager:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def task_list(self, n):
13
+ tasks = []
14
+ max_level = 1
15
+ while (1<<max_level)<=n:
16
+ max_level += 1
17
+ for i in range(n):
18
+ j = i
19
+ for level in range(max_level):
20
+ if i&(1<<level):
21
+ continue
22
+ j |= 1<<level
23
+ if j>=n:
24
+ break
25
+ meta_data = {
26
+ "source": i,
27
+ "target": j,
28
+ "level": level + 1
29
+ }
30
+ tasks.append(meta_data)
31
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
32
+ return tasks
33
+
34
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
35
+ n = len(frames_guide)
36
+ tasks = self.task_list(n)
37
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
38
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
39
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
40
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
41
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
42
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
43
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
44
+ for task, result in zip(tasks_batch, target_style):
45
+ target, level = task["target"], task["level"]
46
+ if len(remapping_table[target])==level:
47
+ remapping_table[target].append((result, 1))
48
+ else:
49
+ frame, weight = remapping_table[target][level]
50
+ remapping_table[target][level] = (
51
+ frame * (weight / (weight + 1)) + result / (weight + 1),
52
+ weight + 1
53
+ )
54
+ return remapping_table
55
+
56
+ def remapping_table_to_blending_table(self, table):
57
+ for i in range(len(table)):
58
+ for j in range(1, len(table[i])):
59
+ frame_1, weight_1 = table[i][j-1]
60
+ frame_2, weight_2 = table[i][j]
61
+ frame = (frame_1 + frame_2) / 2
62
+ weight = weight_1 + weight_2
63
+ table[i][j] = (frame, weight)
64
+ return table
65
+
66
+ def tree_query(self, leftbound, rightbound):
67
+ node_list = []
68
+ node_index = rightbound
69
+ while node_index>=leftbound:
70
+ node_level = 0
71
+ while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
72
+ node_level += 1
73
+ node_list.append((node_index, node_level))
74
+ node_index -= 1<<node_level
75
+ return node_list
76
+
77
+ def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
78
+ n = len(blending_table)
79
+ tasks = []
80
+ frames_result = []
81
+ for target in range(n):
82
+ node_list = self.tree_query(max(target-window_size, 0), target)
83
+ for source, level in node_list:
84
+ if source!=target:
85
+ meta_data = {
86
+ "source": source,
87
+ "target": target,
88
+ "level": level
89
+ }
90
+ tasks.append(meta_data)
91
+ else:
92
+ frames_result.append(blending_table[target][level])
93
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
94
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
95
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
96
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
97
+ source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
98
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
99
+ for task, frame_2 in zip(tasks_batch, target_style):
100
+ source, target, level = task["source"], task["target"], task["level"]
101
+ frame_1, weight_1 = frames_result[target]
102
+ weight_2 = blending_table[source][level][1]
103
+ weight = weight_1 + weight_2
104
+ frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
105
+ frames_result[target] = (frame, weight)
106
+ return frames_result
107
+
108
+
109
+ class FastModeRunner:
110
+ def __init__(self):
111
+ pass
112
+
113
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
114
+ frames_guide = frames_guide.raw_data()
115
+ frames_style = frames_style.raw_data()
116
+ table_manager = TableManager()
117
+ patch_match_engine = PyramidPatchMatcher(
118
+ image_height=frames_style[0].shape[0],
119
+ image_width=frames_style[0].shape[1],
120
+ channel=3,
121
+ **ebsynth_config
122
+ )
123
+ # left part
124
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
125
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
126
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
127
+ # right part
128
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
129
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
130
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
131
+ # merge
132
+ frames = []
133
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
134
+ weight_m = -1
135
+ weight = weight_l + weight_m + weight_r
136
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
137
+ frames.append(frame)
138
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
139
+ if save_path is not None:
140
+ for target, frame in enumerate(frames):
141
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
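FastModeRunner avoids remapping every (source, target) pair: task_list builds power-of-two partial blends in a binary-indexed (lowbit) layout, tree_query then covers any window with O(log window) of those blocks, and process_window_sum combines them by their stored weights. The forward and reversed passes handle the left and right halves of each window, and weight_m = -1 in the merge removes the centre frame that both passes counted. A standalone sketch (illustrative only) of the tree_query decomposition:

def tree_query(leftbound, rightbound):
    # Cover [leftbound, rightbound] with blocks of the form (index - 2**level, index].
    node_list = []
    node_index = rightbound
    while node_index >= leftbound:
        node_level = 0
        while (1 << node_level) & node_index and node_index - (1 << (node_level + 1)) + 1 >= leftbound:
            node_level += 1
        node_list.append((node_index, node_level))
        node_index -= 1 << node_level
    return node_list

print(tree_query(3, 10))
# [(10, 0), (9, 1), (7, 2), (3, 0)] -> blocks {10}, {8, 9}, {4..7}, {3}: 4 lookups cover 8 frames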
diffsynth/extensions/FastBlend/runners/interpolation.py ADDED
@@ -0,0 +1,121 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class InterpolationModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def get_index_dict(self, index_style):
13
+ index_dict = {}
14
+ for i, index in enumerate(index_style):
15
+ index_dict[index] = i
16
+ return index_dict
17
+
18
+ def get_weight(self, l, m, r):
19
+ weight_l, weight_r = abs(m - r), abs(m - l)
20
+ if weight_l + weight_r == 0:
21
+ weight_l, weight_r = 0.5, 0.5
22
+ else:
23
+ weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
24
+ return weight_l, weight_r
25
+
26
+ def get_task_group(self, index_style, n):
27
+ task_group = []
28
+ index_style = sorted(index_style)
29
+ # first frame
30
+ if index_style[0]>0:
31
+ tasks = []
32
+ for m in range(index_style[0]):
33
+ tasks.append((index_style[0], m, index_style[0]))
34
+ task_group.append(tasks)
35
+ # middle frames
36
+ for l, r in zip(index_style[:-1], index_style[1:]):
37
+ tasks = []
38
+ for m in range(l, r):
39
+ tasks.append((l, m, r))
40
+ task_group.append(tasks)
41
+ # last frame
42
+ tasks = []
43
+ for m in range(index_style[-1], n):
44
+ tasks.append((index_style[-1], m, index_style[-1]))
45
+ task_group.append(tasks)
46
+ return task_group
47
+
48
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
49
+ patch_match_engine = PyramidPatchMatcher(
50
+ image_height=frames_style[0].shape[0],
51
+ image_width=frames_style[0].shape[1],
52
+ channel=3,
53
+ use_mean_target_style=False,
54
+ use_pairwise_patch_error=True,
55
+ **ebsynth_config
56
+ )
57
+ # task
58
+ index_dict = self.get_index_dict(index_style)
59
+ task_group = self.get_task_group(index_style, len(frames_guide))
60
+ # run
61
+ for tasks in task_group:
62
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
63
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
64
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
65
+ source_guide, target_guide, source_style = [], [], []
66
+ for l, m, r in tasks_batch:
67
+ # l -> m
68
+ source_guide.append(frames_guide[l])
69
+ target_guide.append(frames_guide[m])
70
+ source_style.append(frames_style[index_dict[l]])
71
+ # r -> m
72
+ source_guide.append(frames_guide[r])
73
+ target_guide.append(frames_guide[m])
74
+ source_style.append(frames_style[index_dict[r]])
75
+ source_guide = np.stack(source_guide)
76
+ target_guide = np.stack(target_guide)
77
+ source_style = np.stack(source_style)
78
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
79
+ if save_path is not None:
80
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
81
+ weight_l, weight_r = self.get_weight(l, m, r)
82
+ frame = frame_l * weight_l + frame_r * weight_r
83
+ frame = frame.clip(0, 255).astype("uint8")
84
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85
+
86
+
87
+ class InterpolationModeSingleFrameRunner:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
92
+ # check input
93
+ tracking_window_size = ebsynth_config["tracking_window_size"]
94
+ if tracking_window_size * 2 >= batch_size:
95
+ raise ValueError("batch_size should be larger than track_window_size * 2")
96
+ frame_style = frames_style[0]
97
+ frame_guide = frames_guide[index_style[0]]
98
+ patch_match_engine = PyramidPatchMatcher(
99
+ image_height=frame_style.shape[0],
100
+ image_width=frame_style.shape[1],
101
+ channel=3,
102
+ **ebsynth_config
103
+ )
104
+ # run
105
+ frame_id, n = 0, len(frames_guide)
106
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
107
+ if i + batch_size > n:
108
+ l, r = max(n - batch_size, 0), n
109
+ else:
110
+ l, r = i, i + batch_size
111
+ source_guide = np.stack([frame_guide] * (r-l))
112
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
113
+ source_style = np.stack([frame_style] * (r-l))
114
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
115
+ for i, frame in zip(range(l, r), target_style):
116
+ if i==frame_id:
117
+ frame = frame.clip(0, 255).astype("uint8")
118
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
119
+ frame_id += 1
120
+ if r < n and r-frame_id <= tracking_window_size:
121
+ break
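InterpolationModeRunner renders each non-keyframe m from its two surrounding keyframes l and r (both are remapped onto m in the same batch, interleaved as the even/odd entries of target_style) and blends them with inverse-distance weights, so the nearer keyframe dominates; before the first and after the last keyframe only one side exists, l equals r, and the weights collapse to 0.5/0.5. The weighting on its own (illustrative sketch only):

def get_weight(l, m, r):
    weight_l, weight_r = abs(m - r), abs(m - l)   # weight of one side ~ distance to the other side
    total = weight_l + weight_r
    if total == 0:                                # l == m == r: a keyframe itself
        return 0.5, 0.5
    return weight_l / total, weight_r / total

print(get_weight(0, 1, 4))  # (0.75, 0.25): frame 1 leans on keyframe 0
print(get_weight(0, 3, 4))  # (0.25, 0.75): frame 3 leans on keyframe 4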
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .blip_pretrain import *
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py ADDED
@@ -0,0 +1,77 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ import torch
9
+ import os
10
+ from urllib.parse import urlparse
11
+ from timm.models.hub import download_cached_file
12
+ from transformers import BertTokenizer
13
+ from .vit import VisionTransformer, interpolate_pos_embed
14
+
15
+
16
+ def default_bert():
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
19
+ model_path = os.path.join(project_root, 'models', 'QualityMetric')
20
+ return os.path.join(model_path, "bert-base-uncased")
21
+
22
+
23
+ def init_tokenizer(bert_model_path):
24
+ tokenizer = BertTokenizer.from_pretrained(bert_model_path)
25
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
26
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
27
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
28
+ return tokenizer
29
+
30
+
31
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
32
+
33
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
34
+ if vit=='base':
35
+ vision_width = 768
36
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
37
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
38
+ drop_path_rate=0 or drop_path_rate
39
+ )
40
+ elif vit=='large':
41
+ vision_width = 1024
42
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
43
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
44
+ drop_path_rate=0.1 or drop_path_rate
45
+ )
46
+ return visual_encoder, vision_width
47
+
48
+
49
+ def is_url(url_or_filename):
50
+ parsed = urlparse(url_or_filename)
51
+ return parsed.scheme in ("http", "https")
52
+
53
+ def load_checkpoint(model,url_or_filename):
54
+ if is_url(url_or_filename):
55
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
56
+ checkpoint = torch.load(cached_file, map_location='cpu')
57
+ elif os.path.isfile(url_or_filename):
58
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
59
+ else:
60
+ raise RuntimeError('checkpoint url or path is invalid')
61
+
62
+ state_dict = checkpoint['model']
63
+
64
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
65
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
66
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
67
+ model.visual_encoder_m)
68
+ for key in model.state_dict().keys():
69
+ if key in state_dict.keys():
70
+ if state_dict[key].shape!=model.state_dict()[key].shape:
71
+ print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
72
+ del state_dict[key]
73
+
74
+ msg = model.load_state_dict(state_dict,strict=False)
75
+ print('load checkpoint from %s'%url_or_filename)
76
+ return model,msg
77
+
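create_vit builds the ViT backbone shared by the quality metrics ('base' gives width 768 and depth 12, 'large' gives width 1024 and depth 24), and load_checkpoint accepts either a URL or a local path, re-interpolates the positional embeddings, and drops shape-mismatched tensors before a non-strict load. A hedged usage sketch for the backbone (illustrative only; the import path follows this diff's layout, and the output shape is the expected one for a 224px input with 16px patches):

import torch
from diffsynth.extensions.ImageQualityMetric.BLIP.blip import create_vit

visual_encoder, vision_width = create_vit('base', image_size=224)
images = torch.randn(2, 3, 224, 224)        # batch of two RGB images
features = visual_encoder(images)           # expected (2, 197, 768): [CLS] + 14x14 patch tokens
print(features.shape, vision_width)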
diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py ADDED
@@ -0,0 +1,44 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ '''
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ from torch import nn
9
+ import os
10
+ from .med import BertConfig, BertModel
11
+ from .blip import create_vit, init_tokenizer
12
+
13
+ class BLIP_Pretrain(nn.Module):
14
+ def __init__(self,
15
+ med_config = "med_config.json",
16
+ image_size = 224,
17
+ vit = 'base',
18
+ vit_grad_ckpt = False,
19
+ vit_ckpt_layer = 0,
20
+ embed_dim = 256,
21
+ queue_size = 57600,
22
+ momentum = 0.995,
23
+ bert_model_path = ""
24
+ ):
25
+ """
26
+ Args:
27
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
28
+ image_size (int): input image size
29
+ vit (str): model size of vision transformer
30
+ """
31
+ super().__init__()
32
+
33
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
34
+
35
+ self.tokenizer = init_tokenizer(bert_model_path)
36
+ encoder_config = BertConfig.from_json_file(med_config)
37
+ encoder_config.encoder_width = vision_width
38
+ self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
39
+
40
+ text_width = self.text_encoder.config.hidden_size
41
+
42
+ self.vision_proj = nn.Linear(vision_width, embed_dim)
43
+ self.text_proj = nn.Linear(text_width, embed_dim)
44
+
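This trimmed BLIP_Pretrain keeps only what the reward and quality metrics need — the ViT image encoder, a BERT text encoder configured from a med_config JSON (with encoder_width set to the vision width so cross-attention lines up), and two projection heads into a shared embedding space; the queue_size and momentum arguments of the original BLIP remain in the signature but are unused in the code above. A construction sketch (illustrative only; both paths are placeholders for assets the metrics load separately, not files in this commit):

from diffsynth.extensions.ImageQualityMetric.BLIP.blip_pretrain import BLIP_Pretrain

model = BLIP_Pretrain(
    med_config="path/to/med_config.json",         # BertConfig JSON (placeholder path)
    image_size=224,
    vit="large",                                  # vision width 1024
    embed_dim=256,
    bert_model_path="path/to/bert-base-uncased",  # local tokenizer directory (placeholder path)
)
# model.visual_encoder / model.text_encoder extract features; vision_proj / text_proj
# map them into the shared 256-dim space used by the downstream scorers.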
diffsynth/extensions/ImageQualityMetric/BLIP/med.py ADDED
@@ -0,0 +1,947 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on huggingface code base
4
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
5
+ '''
6
+
7
+ import math
8
+ from typing import Tuple
9
+
10
+ import torch
11
+ from torch import Tensor, device, nn
12
+ import torch.utils.checkpoint
13
+ from torch import nn
14
+ from torch.nn import CrossEntropyLoss
15
+
16
+ from transformers.activations import ACT2FN
17
+ from transformers.file_utils import (
18
+ ModelOutput,
19
+ )
20
+ from transformers.modeling_outputs import (
21
+ BaseModelOutputWithPastAndCrossAttentions,
22
+ BaseModelOutputWithPoolingAndCrossAttentions,
23
+ CausalLMOutputWithCrossAttentions,
24
+ MaskedLMOutput,
25
+ MultipleChoiceModelOutput,
26
+ NextSentencePredictorOutput,
27
+ QuestionAnsweringModelOutput,
28
+ SequenceClassifierOutput,
29
+ TokenClassifierOutput,
30
+ )
31
+ from transformers.modeling_utils import (
32
+ PreTrainedModel,
33
+ apply_chunking_to_forward,
34
+ find_pruneable_heads_and_indices,
35
+ prune_linear_layer,
36
+ )
37
+ from transformers.utils import logging
38
+ from transformers.models.bert.configuration_bert import BertConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class BertEmbeddings(nn.Module):
45
+ """Construct the embeddings from word and position embeddings."""
46
+
47
+ def __init__(self, config):
48
+ super().__init__()
49
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
50
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
51
+
52
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
53
+ # any TensorFlow checkpoint file
54
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
55
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
56
+
57
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
58
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
59
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
60
+
61
+ self.config = config
62
+
63
+ def forward(
64
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
65
+ ):
66
+ if input_ids is not None:
67
+ input_shape = input_ids.size()
68
+ else:
69
+ input_shape = inputs_embeds.size()[:-1]
70
+
71
+ seq_length = input_shape[1]
72
+
73
+ if position_ids is None:
74
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
75
+
76
+ if inputs_embeds is None:
77
+ inputs_embeds = self.word_embeddings(input_ids)
78
+
79
+ embeddings = inputs_embeds
80
+
81
+ if self.position_embedding_type == "absolute":
82
+ position_embeddings = self.position_embeddings(position_ids)
83
+ embeddings += position_embeddings
84
+ embeddings = self.LayerNorm(embeddings)
85
+ embeddings = self.dropout(embeddings)
86
+ return embeddings
87
+
88
+
89
+ class BertSelfAttention(nn.Module):
90
+ def __init__(self, config, is_cross_attention):
91
+ super().__init__()
92
+ self.config = config
93
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
94
+ raise ValueError(
95
+ "The hidden size (%d) is not a multiple of the number of attention "
96
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
97
+ )
98
+
99
+ self.num_attention_heads = config.num_attention_heads
100
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
101
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
102
+
103
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
104
+ if is_cross_attention:
105
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
106
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
107
+ else:
108
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
109
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
110
+
111
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
112
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
113
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
114
+ self.max_position_embeddings = config.max_position_embeddings
115
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
116
+ self.save_attention = False
117
+
118
+ def save_attn_gradients(self, attn_gradients):
119
+ self.attn_gradients = attn_gradients
120
+
121
+ def get_attn_gradients(self):
122
+ return self.attn_gradients
123
+
124
+ def save_attention_map(self, attention_map):
125
+ self.attention_map = attention_map
126
+
127
+ def get_attention_map(self):
128
+ return self.attention_map
129
+
130
+ def transpose_for_scores(self, x):
131
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
132
+ x = x.view(*new_x_shape)
133
+ return x.permute(0, 2, 1, 3)
134
+
135
+ def forward(
136
+ self,
137
+ hidden_states,
138
+ attention_mask=None,
139
+ head_mask=None,
140
+ encoder_hidden_states=None,
141
+ encoder_attention_mask=None,
142
+ past_key_value=None,
143
+ output_attentions=False,
144
+ ):
145
+ mixed_query_layer = self.query(hidden_states)
146
+
147
+ # If this is instantiated as a cross-attention module, the keys
148
+ # and values come from an encoder; the attention mask needs to be
149
+ # such that the encoder's padding tokens are not attended to.
150
+ is_cross_attention = encoder_hidden_states is not None
151
+
152
+ if is_cross_attention:
153
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
154
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
155
+ attention_mask = encoder_attention_mask
156
+ elif past_key_value is not None:
157
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
158
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
159
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
160
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
161
+ else:
162
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
163
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
164
+
165
+ query_layer = self.transpose_for_scores(mixed_query_layer)
166
+
167
+ past_key_value = (key_layer, value_layer)
168
+
169
+ # Take the dot product between "query" and "key" to get the raw attention scores.
170
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
171
+
172
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
173
+ seq_length = hidden_states.size()[1]
174
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
175
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
176
+ distance = position_ids_l - position_ids_r
177
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
178
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
179
+
180
+ if self.position_embedding_type == "relative_key":
181
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
182
+ attention_scores = attention_scores + relative_position_scores
183
+ elif self.position_embedding_type == "relative_key_query":
184
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
185
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
186
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
187
+
188
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
189
+ if attention_mask is not None:
190
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
191
+ attention_scores = attention_scores + attention_mask
192
+
193
+ # Normalize the attention scores to probabilities.
194
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
195
+
196
+ if is_cross_attention and self.save_attention:
197
+ self.save_attention_map(attention_probs)
198
+ attention_probs.register_hook(self.save_attn_gradients)
199
+
200
+ # This is actually dropping out entire tokens to attend to, which might
201
+ # seem a bit unusual, but is taken from the original Transformer paper.
202
+ attention_probs_dropped = self.dropout(attention_probs)
203
+
204
+ # Mask heads if we want to
205
+ if head_mask is not None:
206
+ attention_probs_dropped = attention_probs_dropped * head_mask
207
+
208
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
209
+
210
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
211
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
212
+ context_layer = context_layer.view(*new_context_layer_shape)
213
+
214
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
215
+
216
+ outputs = outputs + (past_key_value,)
217
+ return outputs
218
+
219
+
220
+ class BertSelfOutput(nn.Module):
221
+ def __init__(self, config):
222
+ super().__init__()
223
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
224
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
225
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
226
+
227
+ def forward(self, hidden_states, input_tensor):
228
+ hidden_states = self.dense(hidden_states)
229
+ hidden_states = self.dropout(hidden_states)
230
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
231
+ return hidden_states
232
+
233
+
234
+ class BertAttention(nn.Module):
235
+ def __init__(self, config, is_cross_attention=False):
236
+ super().__init__()
237
+ self.self = BertSelfAttention(config, is_cross_attention)
238
+ self.output = BertSelfOutput(config)
239
+ self.pruned_heads = set()
240
+
241
+ def prune_heads(self, heads):
242
+ if len(heads) == 0:
243
+ return
244
+ heads, index = find_pruneable_heads_and_indices(
245
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
246
+ )
247
+
248
+ # Prune linear layers
249
+ self.self.query = prune_linear_layer(self.self.query, index)
250
+ self.self.key = prune_linear_layer(self.self.key, index)
251
+ self.self.value = prune_linear_layer(self.self.value, index)
252
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
253
+
254
+ # Update hyper params and store pruned heads
255
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
256
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
257
+ self.pruned_heads = self.pruned_heads.union(heads)
258
+
259
+ def forward(
260
+ self,
261
+ hidden_states,
262
+ attention_mask=None,
263
+ head_mask=None,
264
+ encoder_hidden_states=None,
265
+ encoder_attention_mask=None,
266
+ past_key_value=None,
267
+ output_attentions=False,
268
+ ):
269
+ self_outputs = self.self(
270
+ hidden_states,
271
+ attention_mask,
272
+ head_mask,
273
+ encoder_hidden_states,
274
+ encoder_attention_mask,
275
+ past_key_value,
276
+ output_attentions,
277
+ )
278
+ attention_output = self.output(self_outputs[0], hidden_states)
279
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
280
+ return outputs
281
+
282
+
283
+ class BertIntermediate(nn.Module):
284
+ def __init__(self, config):
285
+ super().__init__()
286
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
287
+ if isinstance(config.hidden_act, str):
288
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
289
+ else:
290
+ self.intermediate_act_fn = config.hidden_act
291
+
292
+ def forward(self, hidden_states):
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.intermediate_act_fn(hidden_states)
295
+ return hidden_states
296
+
297
+
298
+ class BertOutput(nn.Module):
299
+ def __init__(self, config):
300
+ super().__init__()
301
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
302
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
303
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
304
+
305
+ def forward(self, hidden_states, input_tensor):
306
+ hidden_states = self.dense(hidden_states)
307
+ hidden_states = self.dropout(hidden_states)
308
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
309
+ return hidden_states
310
+
311
+
312
+ class BertLayer(nn.Module):
313
+ def __init__(self, config, layer_num):
314
+ super().__init__()
315
+ self.config = config
316
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
317
+ self.seq_len_dim = 1
318
+ self.attention = BertAttention(config)
319
+ self.layer_num = layer_num
320
+ if self.config.add_cross_attention:
321
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
322
+ self.intermediate = BertIntermediate(config)
323
+ self.output = BertOutput(config)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states,
328
+ attention_mask=None,
329
+ head_mask=None,
330
+ encoder_hidden_states=None,
331
+ encoder_attention_mask=None,
332
+ past_key_value=None,
333
+ output_attentions=False,
334
+ mode=None,
335
+ ):
336
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
337
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
338
+ self_attention_outputs = self.attention(
339
+ hidden_states,
340
+ attention_mask,
341
+ head_mask,
342
+ output_attentions=output_attentions,
343
+ past_key_value=self_attn_past_key_value,
344
+ )
345
+ attention_output = self_attention_outputs[0]
346
+
347
+ outputs = self_attention_outputs[1:-1]
348
+ present_key_value = self_attention_outputs[-1]
349
+
350
+ if mode=='multimodal':
351
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
352
+
353
+ cross_attention_outputs = self.crossattention(
354
+ attention_output,
355
+ attention_mask,
356
+ head_mask,
357
+ encoder_hidden_states,
358
+ encoder_attention_mask,
359
+ output_attentions=output_attentions,
360
+ )
361
+ attention_output = cross_attention_outputs[0]
362
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
363
+ layer_output = apply_chunking_to_forward(
364
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
365
+ )
366
+ outputs = (layer_output,) + outputs
367
+
368
+ outputs = outputs + (present_key_value,)
369
+
370
+ return outputs
371
+
372
+ def feed_forward_chunk(self, attention_output):
373
+ intermediate_output = self.intermediate(attention_output)
374
+ layer_output = self.output(intermediate_output, attention_output)
375
+ return layer_output
376
+
377
+
378
+ class BertEncoder(nn.Module):
379
+ def __init__(self, config):
380
+ super().__init__()
381
+ self.config = config
382
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
383
+ self.gradient_checkpointing = False
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states,
388
+ attention_mask=None,
389
+ head_mask=None,
390
+ encoder_hidden_states=None,
391
+ encoder_attention_mask=None,
392
+ past_key_values=None,
393
+ use_cache=None,
394
+ output_attentions=False,
395
+ output_hidden_states=False,
396
+ return_dict=True,
397
+ mode='multimodal',
398
+ ):
399
+ all_hidden_states = () if output_hidden_states else None
400
+ all_self_attentions = () if output_attentions else None
401
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
402
+
403
+ next_decoder_cache = () if use_cache else None
404
+
405
+ for i in range(self.config.num_hidden_layers):
406
+ layer_module = self.layer[i]
407
+ if output_hidden_states:
408
+ all_hidden_states = all_hidden_states + (hidden_states,)
409
+
410
+ layer_head_mask = head_mask[i] if head_mask is not None else None
411
+ past_key_value = past_key_values[i] if past_key_values is not None else None
412
+
413
+ if self.gradient_checkpointing and self.training:
414
+
415
+ if use_cache:
416
+ logger.warn(
417
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
418
+ )
419
+ use_cache = False
420
+
421
+ def create_custom_forward(module):
422
+ def custom_forward(*inputs):
423
+ return module(*inputs, past_key_value, output_attentions)
424
+
425
+ return custom_forward
426
+
427
+ layer_outputs = torch.utils.checkpoint.checkpoint(
428
+ create_custom_forward(layer_module),
429
+ hidden_states,
430
+ attention_mask,
431
+ layer_head_mask,
432
+ encoder_hidden_states,
433
+ encoder_attention_mask,
434
+ mode=mode,
435
+ )
436
+ else:
437
+ layer_outputs = layer_module(
438
+ hidden_states,
439
+ attention_mask,
440
+ layer_head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ past_key_value,
444
+ output_attentions,
445
+ mode=mode,
446
+ )
447
+
448
+ hidden_states = layer_outputs[0]
449
+ if use_cache:
450
+ next_decoder_cache += (layer_outputs[-1],)
451
+ if output_attentions:
452
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
453
+
454
+ if output_hidden_states:
455
+ all_hidden_states = all_hidden_states + (hidden_states,)
456
+
457
+ if not return_dict:
458
+ return tuple(
459
+ v
460
+ for v in [
461
+ hidden_states,
462
+ next_decoder_cache,
463
+ all_hidden_states,
464
+ all_self_attentions,
465
+ all_cross_attentions,
466
+ ]
467
+ if v is not None
468
+ )
469
+ return BaseModelOutputWithPastAndCrossAttentions(
470
+ last_hidden_state=hidden_states,
471
+ past_key_values=next_decoder_cache,
472
+ hidden_states=all_hidden_states,
473
+ attentions=all_self_attentions,
474
+ cross_attentions=all_cross_attentions,
475
+ )
476
+
477
+
478
+ class BertPooler(nn.Module):
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
482
+ self.activation = nn.Tanh()
483
+
484
+ def forward(self, hidden_states):
485
+ # We "pool" the model by simply taking the hidden state corresponding
486
+ # to the first token.
487
+ first_token_tensor = hidden_states[:, 0]
488
+ pooled_output = self.dense(first_token_tensor)
489
+ pooled_output = self.activation(pooled_output)
490
+ return pooled_output
491
+
492
+
493
+ class BertPredictionHeadTransform(nn.Module):
494
+ def __init__(self, config):
495
+ super().__init__()
496
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
497
+ if isinstance(config.hidden_act, str):
498
+ self.transform_act_fn = ACT2FN[config.hidden_act]
499
+ else:
500
+ self.transform_act_fn = config.hidden_act
501
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
502
+
503
+ def forward(self, hidden_states):
504
+ hidden_states = self.dense(hidden_states)
505
+ hidden_states = self.transform_act_fn(hidden_states)
506
+ hidden_states = self.LayerNorm(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertLMPredictionHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.transform = BertPredictionHeadTransform(config)
514
+
515
+ # The output weights are the same as the input embeddings, but there is
516
+ # an output-only bias for each token.
517
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
518
+
519
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
520
+
521
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
522
+ self.decoder.bias = self.bias
523
+
524
+ def forward(self, hidden_states):
525
+ hidden_states = self.transform(hidden_states)
526
+ hidden_states = self.decoder(hidden_states)
527
+ return hidden_states
528
+
529
+
530
+ class BertOnlyMLMHead(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+
535
+ def forward(self, sequence_output):
536
+ prediction_scores = self.predictions(sequence_output)
537
+ return prediction_scores
538
+
539
+
540
+ class BertPreTrainedModel(PreTrainedModel):
541
+ """
542
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
543
+ models.
544
+ """
545
+
546
+ config_class = BertConfig
547
+ base_model_prefix = "bert"
548
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
549
+
550
+ def _init_weights(self, module):
551
+ """ Initialize the weights """
552
+ if isinstance(module, (nn.Linear, nn.Embedding)):
553
+ # Slightly different from the TF version which uses truncated_normal for initialization
554
+ # cf https://github.com/pytorch/pytorch/pull/5617
555
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
556
+ elif isinstance(module, nn.LayerNorm):
557
+ module.bias.data.zero_()
558
+ module.weight.data.fill_(1.0)
559
+ if isinstance(module, nn.Linear) and module.bias is not None:
560
+ module.bias.data.zero_()
561
+
562
+
563
+ class BertModel(BertPreTrainedModel):
564
+ """
565
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
566
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
567
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
568
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
569
+ To be used as a decoder, the model needs to be initialized with both the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
570
+ input to the forward pass.
571
+ """
572
+
573
+ def __init__(self, config, add_pooling_layer=True):
574
+ super().__init__(config)
575
+ self.config = config
576
+
577
+ self.embeddings = BertEmbeddings(config)
578
+
579
+ self.encoder = BertEncoder(config)
580
+
581
+ self.pooler = BertPooler(config) if add_pooling_layer else None
582
+
583
+ self.init_weights()
584
+
585
+
586
+ def get_input_embeddings(self):
587
+ return self.embeddings.word_embeddings
588
+
589
+ def set_input_embeddings(self, value):
590
+ self.embeddings.word_embeddings = value
591
+
592
+ def _prune_heads(self, heads_to_prune):
593
+ """
594
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
595
+ class PreTrainedModel
596
+ """
597
+ for layer, heads in heads_to_prune.items():
598
+ self.encoder.layer[layer].attention.prune_heads(heads)
599
+
600
+
601
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
602
+ """
603
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
604
+
605
+ Arguments:
606
+ attention_mask (:obj:`torch.Tensor`):
607
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
608
+ input_shape (:obj:`Tuple[int]`):
609
+ The shape of the input to the model.
610
+ device: (:obj:`torch.device`):
611
+ The device of the input to the model.
612
+
613
+ Returns:
614
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
615
+ """
616
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
617
+ # ourselves in which case we just need to make it broadcastable to all heads.
618
+ if attention_mask.dim() == 3:
619
+ extended_attention_mask = attention_mask[:, None, :, :]
620
+ elif attention_mask.dim() == 2:
621
+ # Provided a padding mask of dimensions [batch_size, seq_length]
622
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
623
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
624
+ if is_decoder:
625
+ batch_size, seq_length = input_shape
626
+
627
+ seq_ids = torch.arange(seq_length, device=device)
628
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
629
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
630
+ # causal and attention masks must have same type with pytorch version < 1.3
631
+ causal_mask = causal_mask.to(attention_mask.dtype)
632
+
633
+ if causal_mask.shape[1] < attention_mask.shape[1]:
634
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
635
+ causal_mask = torch.cat(
636
+ [
637
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
638
+ causal_mask,
639
+ ],
640
+ axis=-1,
641
+ )
642
+
643
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
644
+ else:
645
+ extended_attention_mask = attention_mask[:, None, None, :]
646
+ else:
647
+ raise ValueError(
648
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
649
+ input_shape, attention_mask.shape
650
+ )
651
+ )
652
+
653
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654
+ # masked positions, this operation will create a tensor which is 0.0 for
655
+ # positions we want to attend and -10000.0 for masked positions.
656
+ # Since we are adding it to the raw scores before the softmax, this is
657
+ # effectively the same as removing these entirely.
658
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
659
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
660
+ return extended_attention_mask
661
+
662
+ def forward(
663
+ self,
664
+ input_ids=None,
665
+ attention_mask=None,
666
+ position_ids=None,
667
+ head_mask=None,
668
+ inputs_embeds=None,
669
+ encoder_embeds=None,
670
+ encoder_hidden_states=None,
671
+ encoder_attention_mask=None,
672
+ past_key_values=None,
673
+ use_cache=None,
674
+ output_attentions=None,
675
+ output_hidden_states=None,
676
+ return_dict=None,
677
+ is_decoder=False,
678
+ mode='multimodal',
679
+ ):
680
+ r"""
681
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
682
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
683
+ the model is configured as a decoder.
684
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
685
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
686
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
687
+ - 1 for tokens that are **not masked**,
688
+ - 0 for tokens that are **masked**.
689
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
690
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
691
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
692
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
693
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
694
+ use_cache (:obj:`bool`, `optional`):
695
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
696
+ decoding (see :obj:`past_key_values`).
697
+ """
698
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
699
+ output_hidden_states = (
700
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
701
+ )
702
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
703
+
704
+ if is_decoder:
705
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
706
+ else:
707
+ use_cache = False
708
+
709
+ if input_ids is not None and inputs_embeds is not None:
710
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
711
+ elif input_ids is not None:
712
+ input_shape = input_ids.size()
713
+ batch_size, seq_length = input_shape
714
+ device = input_ids.device
715
+ elif inputs_embeds is not None:
716
+ input_shape = inputs_embeds.size()[:-1]
717
+ batch_size, seq_length = input_shape
718
+ device = inputs_embeds.device
719
+ elif encoder_embeds is not None:
720
+ input_shape = encoder_embeds.size()[:-1]
721
+ batch_size, seq_length = input_shape
722
+ device = encoder_embeds.device
723
+ else:
724
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
725
+
726
+ # past_key_values_length
727
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
728
+
729
+ if attention_mask is None:
730
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
731
+
732
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
733
+ # ourselves in which case we just need to make it broadcastable to all heads.
734
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
735
+ device, is_decoder)
736
+
737
+ # If a 2D or 3D attention mask is provided for the cross-attention
738
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
739
+ if encoder_hidden_states is not None:
740
+ if type(encoder_hidden_states) == list:
741
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
742
+ else:
743
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
744
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
745
+
746
+ if type(encoder_attention_mask) == list:
747
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
748
+ elif encoder_attention_mask is None:
749
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
750
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
751
+ else:
752
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
753
+ else:
754
+ encoder_extended_attention_mask = None
755
+
756
+ # Prepare head mask if needed
757
+ # 1.0 in head_mask indicate we keep the head
758
+ # attention_probs has shape bsz x n_heads x N x N
759
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
760
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
761
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
762
+
763
+ if encoder_embeds is None:
764
+ embedding_output = self.embeddings(
765
+ input_ids=input_ids,
766
+ position_ids=position_ids,
767
+ inputs_embeds=inputs_embeds,
768
+ past_key_values_length=past_key_values_length,
769
+ )
770
+ else:
771
+ embedding_output = encoder_embeds
772
+
773
+ encoder_outputs = self.encoder(
774
+ embedding_output,
775
+ attention_mask=extended_attention_mask,
776
+ head_mask=head_mask,
777
+ encoder_hidden_states=encoder_hidden_states,
778
+ encoder_attention_mask=encoder_extended_attention_mask,
779
+ past_key_values=past_key_values,
780
+ use_cache=use_cache,
781
+ output_attentions=output_attentions,
782
+ output_hidden_states=output_hidden_states,
783
+ return_dict=return_dict,
784
+ mode=mode,
785
+ )
786
+ sequence_output = encoder_outputs[0]
787
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
788
+
789
+ if not return_dict:
790
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
791
+
792
+ return BaseModelOutputWithPoolingAndCrossAttentions(
793
+ last_hidden_state=sequence_output,
794
+ pooler_output=pooled_output,
795
+ past_key_values=encoder_outputs.past_key_values,
796
+ hidden_states=encoder_outputs.hidden_states,
797
+ attentions=encoder_outputs.attentions,
798
+ cross_attentions=encoder_outputs.cross_attentions,
799
+ )
800
+
801
+
802
+
803
+ class BertLMHeadModel(BertPreTrainedModel):
804
+
805
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
806
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
807
+
808
+ def __init__(self, config):
809
+ super().__init__(config)
810
+
811
+ self.bert = BertModel(config, add_pooling_layer=False)
812
+ self.cls = BertOnlyMLMHead(config)
813
+
814
+ self.init_weights()
815
+
816
+ def get_output_embeddings(self):
817
+ return self.cls.predictions.decoder
818
+
819
+ def set_output_embeddings(self, new_embeddings):
820
+ self.cls.predictions.decoder = new_embeddings
821
+
822
+ def forward(
823
+ self,
824
+ input_ids=None,
825
+ attention_mask=None,
826
+ position_ids=None,
827
+ head_mask=None,
828
+ inputs_embeds=None,
829
+ encoder_hidden_states=None,
830
+ encoder_attention_mask=None,
831
+ labels=None,
832
+ past_key_values=None,
833
+ use_cache=None,
834
+ output_attentions=None,
835
+ output_hidden_states=None,
836
+ return_dict=None,
837
+ return_logits=False,
838
+ is_decoder=True,
839
+ reduction='mean',
840
+ mode='multimodal',
841
+ ):
842
+ r"""
843
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
844
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
845
+ the model is configured as a decoder.
846
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
847
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
848
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
849
+ - 1 for tokens that are **not masked**,
850
+ - 0 for tokens that are **masked**.
851
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
852
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
853
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
854
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
855
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
856
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
857
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
858
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
859
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
860
+ use_cache (:obj:`bool`, `optional`):
861
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
862
+ decoding (see :obj:`past_key_values`).
863
+ Returns:
864
+ Example::
865
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
866
+ >>> import torch
867
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
868
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
869
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
870
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
871
+ >>> outputs = model(**inputs)
872
+ >>> prediction_logits = outputs.logits
873
+ """
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+ if labels is not None:
876
+ use_cache = False
877
+
878
+ outputs = self.bert(
879
+ input_ids,
880
+ attention_mask=attention_mask,
881
+ position_ids=position_ids,
882
+ head_mask=head_mask,
883
+ inputs_embeds=inputs_embeds,
884
+ encoder_hidden_states=encoder_hidden_states,
885
+ encoder_attention_mask=encoder_attention_mask,
886
+ past_key_values=past_key_values,
887
+ use_cache=use_cache,
888
+ output_attentions=output_attentions,
889
+ output_hidden_states=output_hidden_states,
890
+ return_dict=return_dict,
891
+ is_decoder=is_decoder,
892
+ mode=mode,
893
+ )
894
+
895
+ sequence_output = outputs[0]
896
+ prediction_scores = self.cls(sequence_output)
897
+
898
+ if return_logits:
899
+ return prediction_scores[:, :-1, :].contiguous()
900
+
901
+ lm_loss = None
902
+ if labels is not None:
903
+ # we are doing next-token prediction; shift prediction scores and input ids by one
904
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
905
+ labels = labels[:, 1:].contiguous()
906
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
907
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
908
+ if reduction=='none':
909
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
910
+
911
+ if not return_dict:
912
+ output = (prediction_scores,) + outputs[2:]
913
+ return ((lm_loss,) + output) if lm_loss is not None else output
914
+
915
+ return CausalLMOutputWithCrossAttentions(
916
+ loss=lm_loss,
917
+ logits=prediction_scores,
918
+ past_key_values=outputs.past_key_values,
919
+ hidden_states=outputs.hidden_states,
920
+ attentions=outputs.attentions,
921
+ cross_attentions=outputs.cross_attentions,
922
+ )
923
+
924
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
925
+ input_shape = input_ids.shape
926
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
927
+ if attention_mask is None:
928
+ attention_mask = input_ids.new_ones(input_shape)
929
+
930
+ # cut decoder_input_ids if past is used
931
+ if past is not None:
932
+ input_ids = input_ids[:, -1:]
933
+
934
+ return {
935
+ "input_ids": input_ids,
936
+ "attention_mask": attention_mask,
937
+ "past_key_values": past,
938
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
939
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
940
+ "is_decoder": True,
941
+ }
942
+
943
+ def _reorder_cache(self, past, beam_idx):
944
+ reordered_past = ()
945
+ for layer_past in past:
946
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
947
+ return reordered_past
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py ADDED
@@ -0,0 +1,301 @@
1
+ '''
2
+ * Adapted from BLIP (https://github.com/salesforce/BLIP)
3
+ * Based on timm code base
4
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
+ '''
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from functools import partial
11
+
12
+ from timm.models.vision_transformer import _cfg, PatchEmbed
13
+ from timm.models.registry import register_model
14
+ from timm.models.layers import trunc_normal_, DropPath
15
+ from timm.models.helpers import named_apply, adapt_input_conv
16
+
17
+ # from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
18
+
19
+ class Mlp(nn.Module):
20
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
21
+ """
22
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.fc1 = nn.Linear(in_features, hidden_features)
27
+ self.act = act_layer()
28
+ self.fc2 = nn.Linear(hidden_features, out_features)
29
+ self.drop = nn.Dropout(drop)
30
+
31
+ def forward(self, x):
32
+ x = self.fc1(x)
33
+ x = self.act(x)
34
+ x = self.drop(x)
35
+ x = self.fc2(x)
36
+ x = self.drop(x)
37
+ return x
38
+
39
+
40
+ class Attention(nn.Module):
41
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
42
+ super().__init__()
43
+ self.num_heads = num_heads
44
+ head_dim = dim // num_heads
45
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
46
+ self.scale = qk_scale or head_dim ** -0.5
47
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
+ self.attn_drop = nn.Dropout(attn_drop)
49
+ self.proj = nn.Linear(dim, dim)
50
+ self.proj_drop = nn.Dropout(proj_drop)
51
+ self.attn_gradients = None
52
+ self.attention_map = None
53
+
54
+ def save_attn_gradients(self, attn_gradients):
55
+ self.attn_gradients = attn_gradients
56
+
57
+ def get_attn_gradients(self):
58
+ return self.attn_gradients
59
+
60
+ def save_attention_map(self, attention_map):
61
+ self.attention_map = attention_map
62
+
63
+ def get_attention_map(self):
64
+ return self.attention_map
65
+
66
+ def forward(self, x, register_hook=False):
67
+ B, N, C = x.shape
68
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
69
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
70
+
71
+ attn = (q @ k.transpose(-2, -1)) * self.scale
72
+ attn = attn.softmax(dim=-1)
73
+ attn = self.attn_drop(attn)
74
+
75
+ if register_hook:
76
+ self.save_attention_map(attn)
77
+ attn.register_hook(self.save_attn_gradients)
78
+
79
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
80
+ x = self.proj(x)
81
+ x = self.proj_drop(x)
82
+ return x
83
+
84
+
85
+ class Block(nn.Module):
86
+
87
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
88
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
93
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
94
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
95
+ self.norm2 = norm_layer(dim)
96
+ mlp_hidden_dim = int(dim * mlp_ratio)
97
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
98
+
99
+ # if use_grad_checkpointing:
100
+ # self.attn = checkpoint_wrapper(self.attn)
101
+ # self.mlp = checkpoint_wrapper(self.mlp)
102
+
103
+ def forward(self, x, register_hook=False):
104
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
105
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
106
+ return x
107
+
108
+
109
+ class VisionTransformer(nn.Module):
110
+ """ Vision Transformer
111
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
112
+ https://arxiv.org/abs/2010.11929
113
+ """
114
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
115
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
116
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
117
+ use_grad_checkpointing=False, ckpt_layer=0):
118
+ """
119
+ Args:
120
+ img_size (int, tuple): input image size
121
+ patch_size (int, tuple): patch size
122
+ in_chans (int): number of input channels
123
+ num_classes (int): number of classes for classification head
124
+ embed_dim (int): embedding dimension
125
+ depth (int): depth of transformer
126
+ num_heads (int): number of attention heads
127
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
128
+ qkv_bias (bool): enable bias for qkv if True
129
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
130
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
131
+ drop_rate (float): dropout rate
132
+ attn_drop_rate (float): attention dropout rate
133
+ drop_path_rate (float): stochastic depth rate
134
+ norm_layer: (nn.Module): normalization layer
135
+ """
136
+ super().__init__()
137
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
138
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
139
+
140
+ self.patch_embed = PatchEmbed(
141
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
142
+
143
+ num_patches = self.patch_embed.num_patches
144
+
145
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
146
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
147
+ self.pos_drop = nn.Dropout(p=drop_rate)
148
+
149
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
150
+ self.blocks = nn.ModuleList([
151
+ Block(
152
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
153
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
154
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
155
+ )
156
+ for i in range(depth)])
157
+ self.norm = norm_layer(embed_dim)
158
+
159
+ trunc_normal_(self.pos_embed, std=.02)
160
+ trunc_normal_(self.cls_token, std=.02)
161
+ self.apply(self._init_weights)
162
+
163
+ def _init_weights(self, m):
164
+ if isinstance(m, nn.Linear):
165
+ trunc_normal_(m.weight, std=.02)
166
+ if isinstance(m, nn.Linear) and m.bias is not None:
167
+ nn.init.constant_(m.bias, 0)
168
+ elif isinstance(m, nn.LayerNorm):
169
+ nn.init.constant_(m.bias, 0)
170
+ nn.init.constant_(m.weight, 1.0)
171
+
172
+ @torch.jit.ignore
173
+ def no_weight_decay(self):
174
+ return {'pos_embed', 'cls_token'}
175
+
176
+ def forward(self, x, register_blk=-1):
177
+ B = x.shape[0]
178
+ x = self.patch_embed(x)
179
+
180
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
181
+ x = torch.cat((cls_tokens, x), dim=1)
182
+
183
+ x = x + self.pos_embed[:,:x.size(1),:]
184
+ x = self.pos_drop(x)
185
+
186
+ for i,blk in enumerate(self.blocks):
187
+ x = blk(x, register_blk==i)
188
+ x = self.norm(x)
189
+
190
+ return x
191
+
192
+ @torch.jit.ignore()
193
+ def load_pretrained(self, checkpoint_path, prefix=''):
194
+ _load_weights(self, checkpoint_path, prefix)
195
+
196
+
197
+ @torch.no_grad()
198
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
199
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
200
+ """
201
+ import numpy as np
202
+
203
+ def _n2p(w, t=True):
204
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
205
+ w = w.flatten()
206
+ if t:
207
+ if w.ndim == 4:
208
+ w = w.transpose([3, 2, 0, 1])
209
+ elif w.ndim == 3:
210
+ w = w.transpose([2, 0, 1])
211
+ elif w.ndim == 2:
212
+ w = w.transpose([1, 0])
213
+ return torch.from_numpy(w)
214
+
215
+ w = np.load(checkpoint_path)
216
+ if not prefix and 'opt/target/embedding/kernel' in w:
217
+ prefix = 'opt/target/'
218
+
219
+ if hasattr(model.patch_embed, 'backbone'):
220
+ # hybrid
221
+ backbone = model.patch_embed.backbone
222
+ stem_only = not hasattr(backbone, 'stem')
223
+ stem = backbone if stem_only else backbone.stem
224
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
225
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
226
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
227
+ if not stem_only:
228
+ for i, stage in enumerate(backbone.stages):
229
+ for j, block in enumerate(stage.blocks):
230
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
231
+ for r in range(3):
232
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
233
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
234
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
235
+ if block.downsample is not None:
236
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
237
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
238
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
239
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
240
+ else:
241
+ embed_conv_w = adapt_input_conv(
242
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
243
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
244
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
245
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
246
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
247
+ if pos_embed_w.shape != model.pos_embed.shape:
248
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
249
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
250
+ model.pos_embed.copy_(pos_embed_w)
251
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
252
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
253
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
254
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
255
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
256
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
257
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
258
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
259
+ for i, block in enumerate(model.blocks.children()):
260
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
261
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
262
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
263
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
264
+ block.attn.qkv.weight.copy_(torch.cat([
265
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
266
+ block.attn.qkv.bias.copy_(torch.cat([
267
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
268
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
269
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
270
+ for r in range(2):
271
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
272
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
273
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
274
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
275
+
276
+
277
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
278
+ # interpolate position embedding
279
+ embedding_size = pos_embed_checkpoint.shape[-1]
280
+ num_patches = visual_encoder.patch_embed.num_patches
281
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
282
+ # height (== width) for the checkpoint position embedding
283
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
284
+ # height (== width) for the new position embedding
285
+ new_size = int(num_patches ** 0.5)
286
+
287
+ if orig_size!=new_size:
288
+ # class_token and dist_token are kept unchanged
289
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
290
+ # only the position tokens are interpolated
291
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
292
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
293
+ pos_tokens = torch.nn.functional.interpolate(
294
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
295
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
296
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
297
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
298
+
299
+ return new_pos_embed
300
+ else:
301
+ return pos_embed_checkpoint
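
A minimal usage sketch for interpolate_pos_embed, assuming a plain ViT state dict that stores its position embedding under the key pos_embed and a hypothetical checkpoint file name: a 224-px position-embedding grid is resized to match a model built for 384-px inputs before the weights are loaded.

import torch
from diffsynth.extensions.ImageQualityMetric.BLIP.vit import VisionTransformer, interpolate_pos_embed

# Model built for 384x384 inputs; the checkpoint below is assumed to come from 224x224 training.
vit = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
state_dict = torch.load("vit_base_224.pth", map_location="cpu")  # hypothetical file

# Bicubically resize the patch position embeddings (extra tokens such as CLS stay untouched),
# then load the adjusted weights.
state_dict["pos_embed"] = interpolate_pos_embed(state_dict["pos_embed"], vit)
vit.load_state_dict(state_dict, strict=False)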
diffsynth/extensions/ImageQualityMetric/__init__.py ADDED
@@ -0,0 +1,148 @@
1
+ from modelscope import snapshot_download
2
+ from typing_extensions import Literal, TypeAlias
3
+ import os
4
+ from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
5
+ from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
6
+ from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
7
+ from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
8
+ from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
9
+ from diffsynth.extensions.ImageQualityMetric.mps import MPScore
10
+
11
+
12
+ preference_model_id: TypeAlias = Literal[
13
+ "ImageReward",
14
+ "Aesthetic",
15
+ "PickScore",
16
+ "CLIP",
17
+ "HPSv2",
18
+ "HPSv2.1",
19
+ "MPS",
20
+ ]
21
+ model_dict = {
22
+ "ImageReward": {
23
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
24
+ "allow_file_pattern": [
25
+ "ImageReward/ImageReward.safetensors",
26
+ "ImageReward/med_config.json",
27
+ "bert-base-uncased/config.json",
28
+ "bert-base-uncased/model.safetensors",
29
+ "bert-base-uncased/tokenizer.json",
30
+ "bert-base-uncased/tokenizer_config.json",
31
+ "bert-base-uncased/vocab.txt",
32
+ ],
33
+ "load_path": {
34
+ "imagereward": "ImageReward/ImageReward.safetensors",
35
+ "med_config": "ImageReward/med_config.json",
36
+ "bert_model_path": "bert-base-uncased",
37
+ },
38
+ "model_class": ImageRewardScore
39
+ },
40
+ "Aesthetic": {
41
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
42
+ "allow_file_pattern": [
43
+ "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
44
+ "clip-vit-large-patch14/config.json",
45
+ "clip-vit-large-patch14/merges.txt",
46
+ "clip-vit-large-patch14/model.safetensors",
47
+ "clip-vit-large-patch14/preprocessor_config.json",
48
+ "clip-vit-large-patch14/special_tokens_map.json",
49
+ "clip-vit-large-patch14/tokenizer.json",
50
+ "clip-vit-large-patch14/tokenizer_config.json",
51
+ "clip-vit-large-patch14/vocab.json",
52
+ ],
53
+ "load_path": {
54
+ "aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
55
+ "clip-large": "clip-vit-large-patch14",
56
+ },
57
+ "model_class": AestheticScore
58
+ },
59
+ "PickScore": {
60
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
61
+ "allow_file_pattern": [
62
+ "PickScore_v1/*",
63
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
64
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
65
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
66
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
67
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
68
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
69
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
70
+ ],
71
+ "load_path": {
72
+ "pickscore": "PickScore_v1",
73
+ "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
74
+ },
75
+ "model_class": PickScore
76
+ },
77
+ "CLIP": {
78
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
79
+ "allow_file_pattern": [
80
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
81
+ "bpe_simple_vocab_16e6.txt.gz",
82
+ ],
83
+ "load_path": {
84
+ "open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
85
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
86
+ },
87
+ "model_class": CLIPScore
88
+ },
89
+ "HPSv2": {
90
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
91
+ "allow_file_pattern": [
92
+ "HPS_v2/HPS_v2_compressed.safetensors",
93
+ "bpe_simple_vocab_16e6.txt.gz",
94
+ ],
95
+ "load_path": {
96
+ "hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
97
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
98
+ },
99
+ "model_class": HPScore_v2,
100
+ "extra_kwargs": {"model_version": "v2"}
101
+ },
102
+ "HPSv2.1": {
103
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
104
+ "allow_file_pattern": [
105
+ "HPS_v2/HPS_v2.1_compressed.safetensors",
106
+ "bpe_simple_vocab_16e6.txt.gz",
107
+ ],
108
+ "load_path": {
109
+ "hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
110
+ "open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
111
+ },
112
+ "model_class": HPScore_v2,
113
+ "extra_kwargs": {"model_version": "v21"}
114
+ },
115
+ "MPS": {
116
+ "model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
117
+ "allow_file_pattern": [
118
+ "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
119
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
120
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
121
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
122
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
123
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
124
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
125
+ "CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
126
+ ],
127
+ "load_path": {
128
+ "mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
129
+ "clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
130
+ },
131
+ "model_class": MPScore
132
+ },
133
+ }
134
+
135
+
136
+ def download_preference_model(model_name: preference_model_id, cache_dir="models"):
137
+ metadata = model_dict[model_name]
138
+ snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
139
+ load_path = metadata["load_path"]
140
+ load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
141
+ return load_path
142
+
143
+
144
+ def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
145
+ model_class = model_dict[model_name]["model_class"]
146
+ extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
147
+ preference_model = model_class(device=device, path=path, **extra_kwargs)
148
+ return preference_model
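
A short end-to-end sketch of the two helpers above, assuming the default models cache directory, a CUDA device, and a hypothetical image path; the CLIP entry is used here because its score(images, prompt) interface appears later in this diff.

from diffsynth.extensions.ImageQualityMetric import download_preference_model, load_preference_model

# Download only the files listed for the chosen metric and get back a dict of local paths.
path = download_preference_model("CLIP", cache_dir="models")

# Build the scorer and rate one image against a prompt (higher = better match).
clip_scorer = load_preference_model("CLIP", device="cuda", path=path)
scores = clip_scorer.score("example.jpg", prompt="a cat wearing a spacesuit")  # hypothetical image
print(scores)  # a list with one float per image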
diffsynth/extensions/ImageQualityMetric/aesthetic.py ADDED
@@ -0,0 +1,148 @@
1
+ from typing import List, Optional
2
+ from PIL import Image
3
+ import torch
4
+ from transformers import AutoProcessor, AutoModel
5
+ from safetensors.torch import load_file
6
+ import os
7
+ from typing import Union, List
8
+ from .config import MODEL_PATHS
9
+
10
+ class MLP(torch.nn.Module):
11
+ def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
12
+ super().__init__()
13
+ self.input_size = input_size
14
+ self.xcol = xcol
15
+ self.ycol = ycol
16
+ self.layers = torch.nn.Sequential(
17
+ torch.nn.Linear(self.input_size, 1024),
18
+ #torch.nn.ReLU(),
19
+ torch.nn.Dropout(0.2),
20
+ torch.nn.Linear(1024, 128),
21
+ #torch.nn.ReLU(),
22
+ torch.nn.Dropout(0.2),
23
+ torch.nn.Linear(128, 64),
24
+ #torch.nn.ReLU(),
25
+ torch.nn.Dropout(0.1),
26
+ torch.nn.Linear(64, 16),
27
+ #torch.nn.ReLU(),
28
+ torch.nn.Linear(16, 1),
29
+ )
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ return self.layers(x)
33
+
34
+ def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
35
+ x = batch[self.xcol]
36
+ y = batch[self.ycol].reshape(-1, 1)
37
+ x_hat = self.layers(x)
38
+ loss = torch.nn.functional.mse_loss(x_hat, y)
39
+ return loss
40
+
41
+ def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
42
+ x = batch[self.xcol]
43
+ y = batch[self.ycol].reshape(-1, 1)
44
+ x_hat = self.layers(x)
45
+ loss = torch.nn.functional.mse_loss(x_hat, y)
46
+ return loss
47
+
48
+ def configure_optimizers(self) -> torch.optim.Optimizer:
49
+ return torch.optim.Adam(self.parameters(), lr=1e-3)
50
+
51
+
52
+ class AestheticScore(torch.nn.Module):
53
+ def __init__(self, device: torch.device, path: dict = MODEL_PATHS):
54
+ super().__init__()
55
+ self.device = device
56
+ self.aes_model_path = path.get("aesthetic_predictor")
57
+ # Load the MLP model
58
+ self.model = MLP(768)
59
+ try:
60
+ if self.aes_model_path.endswith(".safetensors"):
61
+ state_dict = load_file(self.aes_model_path)
62
+ else:
63
+ state_dict = torch.load(self.aes_model_path)
64
+ self.model.load_state_dict(state_dict)
65
+ except Exception as e:
66
+ raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
67
+
68
+ self.model.to(device)
69
+ self.model.eval()
70
+
71
+ # Load the CLIP model and processor
72
+ clip_model_name = path.get('clip-large')
73
+ self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
74
+ self.processor = AutoProcessor.from_pretrained(clip_model_name)
75
+
76
+ def _calculate_score(self, image: torch.Tensor) -> float:
77
+ """Calculate the aesthetic score for a single image.
78
+
79
+ Args:
80
+ image (torch.Tensor): The processed image tensor.
81
+
82
+ Returns:
83
+ float: The aesthetic score.
84
+ """
85
+ with torch.no_grad():
86
+ # Get image embeddings
87
+ image_embs = self.model2.get_image_features(image)
88
+ image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
89
+
90
+ # Compute score
91
+ score = self.model(image_embs).cpu().flatten().item()
92
+
93
+ return score
94
+
95
+ @torch.no_grad()
96
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
97
+ """Score the images based on their aesthetic quality.
98
+
99
+ Args:
100
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
101
+
102
+ Returns:
103
+ List[float]: List of scores for the images.
104
+ """
105
+ try:
106
+ if isinstance(images, (str, Image.Image)):
107
+ # Single image
108
+ if isinstance(images, str):
109
+ pil_image = Image.open(images)
110
+ else:
111
+ pil_image = images
112
+
113
+ # Prepare image inputs
114
+ image_inputs = self.processor(
115
+ images=pil_image,
116
+ padding=True,
117
+ truncation=True,
118
+ max_length=77,
119
+ return_tensors="pt",
120
+ ).to(self.device)
121
+
122
+ return [self._calculate_score(image_inputs["pixel_values"])]
123
+ elif isinstance(images, list):
124
+ # Multiple images
125
+ scores = []
126
+ for one_image in images:
127
+ if isinstance(one_image, str):
128
+ pil_image = Image.open(one_image)
129
+ elif isinstance(one_image, Image.Image):
130
+ pil_image = one_image
131
+ else:
132
+ raise TypeError("The type of parameter images is illegal.")
133
+
134
+ # Prepare image inputs
135
+ image_inputs = self.processor(
136
+ images=pil_image,
137
+ padding=True,
138
+ truncation=True,
139
+ max_length=77,
140
+ return_tensors="pt",
141
+ ).to(self.device)
142
+
143
+ scores.append(self._calculate_score(image_inputs["pixel_values"]))
144
+ return scores
145
+ else:
146
+ raise TypeError("The type of parameter images is illegal.")
147
+ except Exception as e:
148
+ raise RuntimeError(f"Error in scoring images: {e}")
diffsynth/extensions/ImageQualityMetric/clip.py ADDED
@@ -0,0 +1,97 @@
1
+ from typing import List, Union
2
+ from PIL import Image
3
+ import torch
4
+ from .open_clip import create_model_and_transforms, get_tokenizer
5
+ from .config import MODEL_PATHS
6
+
7
+ class CLIPScore(torch.nn.Module):
8
+ def __init__(self, device: torch.device, path: dict = MODEL_PATHS):
9
+ super().__init__()
10
+ """Initialize the CLIPScore with a model and tokenizer.
11
+
12
+ Args:
13
+ device (torch.device): The device to load the model on.
14
+ """
15
+ self.device = device
16
+
17
+ # Create model and transforms
18
+ self.model, _, self.preprocess_val = create_model_and_transforms(
19
+ "ViT-H-14",
20
+ # "laion2B-s32B-b79K",
21
+ pretrained=path.get("open_clip"),
22
+ precision="amp",
23
+ device=device,
24
+ jit=False,
25
+ force_quick_gelu=False,
26
+ force_custom_text=False,
27
+ force_patch_dropout=False,
28
+ force_image_size=None,
29
+ pretrained_image=False,
30
+ image_mean=None,
31
+ image_std=None,
32
+ light_augmentation=True,
33
+ aug_cfg={},
34
+ output_dict=True,
35
+ with_score_predictor=False,
36
+ with_region_predictor=False,
37
+ )
38
+
39
+ # Initialize tokenizer
40
+ self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
41
+ self.model = self.model.to(device)
42
+ self.model.eval()
43
+
44
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
45
+ """Calculate the CLIP score for a single image and prompt.
46
+
47
+ Args:
48
+ image (torch.Tensor): The processed image tensor.
49
+ prompt (str): The prompt text.
50
+
51
+ Returns:
52
+ float: The CLIP score.
53
+ """
54
+ with torch.no_grad():
55
+ # Process the prompt
56
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
57
+
58
+ # Calculate the CLIP score
59
+ outputs = self.model(image, text)
60
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
61
+ logits_per_image = image_features @ text_features.T
62
+ clip_score = torch.diagonal(logits_per_image).cpu().numpy()
63
+
64
+ return clip_score[0].item()
65
+
66
+ @torch.no_grad()
67
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
68
+ """Score the images based on the prompt.
69
+
70
+ Args:
71
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
72
+ prompt (str): The prompt text.
73
+
74
+ Returns:
75
+ List[float]: List of CLIP scores for the images.
76
+ """
77
+ if isinstance(images, (str, Image.Image)):
78
+ # Single image
79
+ if isinstance(images, str):
80
+ image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
81
+ else:
82
+ image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
83
+ return [self._calculate_score(image, prompt)]
84
+ elif isinstance(images, list):
85
+ # Multiple images
86
+ scores = []
87
+ for one_images in images:
88
+ if isinstance(one_images, str):
89
+ image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
90
+ elif isinstance(one_images, Image.Image):
91
+ image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
92
+ else:
93
+ raise TypeError("The type of parameter images is illegal.")
94
+ scores.append(self._calculate_score(image, prompt))
95
+ return scores
96
+ else:
97
+ raise TypeError("The type of parameter images is illegal.")
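
The score returned above is the diagonal of image_features @ text_features.T, an image-text similarity rather than a calibrated probability; with L2-normalized features (the usual open_clip convention, assumed here) it falls roughly in [-1, 1], so it is best used for relative comparisons. A sketch with explicit paths (hypothetical, following the layout used elsewhere in this extension); note that the tokenizer BPE file is looked up under the open_clip_bpe key, which the default MODEL_PATHS in config.py does not define.

import torch
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore

# Hypothetical local paths to the open_clip weights and BPE vocabulary.
paths = {
    "open_clip": "models/QualityMetric/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
    "open_clip_bpe": "models/QualityMetric/bpe_simple_vocab_16e6.txt.gz",
}
scorer = CLIPScore(device=torch.device("cuda"), path=paths)

# Relative comparison: which caption describes the image better?
image = "example.jpg"  # hypothetical image path
for caption in ["a golden retriever on a beach", "a bowl of fruit on a table"]:
    print(caption, scorer.score(image, caption))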
diffsynth/extensions/ImageQualityMetric/config.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+
3
+ current_dir = os.path.dirname(os.path.abspath(__file__))
4
+ project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
5
+ model_path = os.path.join(project_root, 'models', 'QualityMetric')
6
+
7
+
8
+ def get_model_path(model_name):
9
+ return os.path.join(model_path, model_name)
10
+
11
+
12
+ MODEL_PATHS = {
13
+ "aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
14
+ "open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
15
+ "hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
16
+ "hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
17
+ "imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
18
+ "med_config": get_model_path("ImageReward/med_config.json"),
19
+ "clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
20
+ "clip-large": get_model_path("clip-vit-large-patch14"),
21
+ "mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
22
+ "pickscore": get_model_path("PickScore_v1")
23
+ }
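
The paths above resolve against <project_root>/models/QualityMetric, and MODEL_PATHS is the default path argument of the scorers in this extension, so a scorer can be constructed without explicit paths once that directory is populated. A sketch, assuming the weights are already in place and a hypothetical image path.

from diffsynth.extensions.ImageQualityMetric.config import MODEL_PATHS, get_model_path
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore

# Both resolve to locations under <project_root>/models/QualityMetric/...
print(get_model_path("HPS_v2/HPS_v2_compressed.safetensors"))
print(MODEL_PATHS["clip-large"])

# With the directory populated, the default MODEL_PATHS is enough.
scorer = AestheticScore(device="cuda")
print(scorer.score("example.jpg"))  # the prompt argument is ignored by the aesthetic predictor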
diffsynth/extensions/ImageQualityMetric/hps.py ADDED
@@ -0,0 +1,118 @@
1
+ from typing import List, Union
2
+ from PIL import Image
3
+ import torch
4
+ from .open_clip import create_model_and_transforms, get_tokenizer
5
+ from safetensors.torch import load_file
6
+ import os
7
+ from .config import MODEL_PATHS
8
+
9
+ class HPScore_v2(torch.nn.Module):
10
+ def __init__(self, device: torch.device, path: dict = MODEL_PATHS, model_version: str = "v2"):
11
+ super().__init__()
12
+ """Initialize the Selector with a model and tokenizer.
13
+
14
+ Args:
15
+ device (torch.device): The device to load the model on.
16
+ model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
17
+ """
18
+ self.device = device
19
+
20
+ if model_version == "v2":
21
+ safetensors_path = path.get("hpsv2")
22
+ elif model_version == "v21":
23
+ safetensors_path = path.get("hpsv2.1")
24
+ else:
25
+ raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
26
+
27
+ # Create model and transforms
28
+ model, _, self.preprocess_val = create_model_and_transforms(
29
+ "ViT-H-14",
30
+ # "laion2B-s32B-b79K",
31
+ pretrained=path.get("open_clip"),
32
+ precision="amp",
33
+ device=device,
34
+ jit=False,
35
+ force_quick_gelu=False,
36
+ force_custom_text=False,
37
+ force_patch_dropout=False,
38
+ force_image_size=None,
39
+ pretrained_image=False,
40
+ image_mean=None,
41
+ image_std=None,
42
+ light_augmentation=True,
43
+ aug_cfg={},
44
+ output_dict=True,
45
+ with_score_predictor=False,
46
+ with_region_predictor=False,
47
+ )
48
+
49
+ # Load model weights
50
+ try:
51
+ state_dict = load_file(safetensors_path)
52
+ model.load_state_dict(state_dict)
53
+ except Exception as e:
54
+ raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
55
+
56
+ # Initialize tokenizer and model
57
+ self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
58
+ model = model.to(device)
59
+ model.eval()
60
+ self.model = model
61
+
62
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
63
+ """Calculate the HPS score for a single image and prompt.
64
+
65
+ Args:
66
+ image (torch.Tensor): The processed image tensor.
67
+ prompt (str): The prompt text.
68
+
69
+ Returns:
70
+ float: The HPS score.
71
+ """
72
+ with torch.no_grad():
73
+ # Process the prompt
74
+ text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
75
+
76
+ # Calculate the HPS score
77
+ outputs = self.model(image, text)
78
+ image_features, text_features = outputs["image_features"], outputs["text_features"]
79
+ logits_per_image = image_features @ text_features.T
80
+ hps_score = torch.diagonal(logits_per_image).cpu().numpy()
81
+
82
+ return hps_score[0].item()
83
+
84
+ @torch.no_grad()
85
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
86
+ """Score the images based on the prompt.
87
+
88
+ Args:
89
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
90
+ prompt (str): The prompt text.
91
+
92
+ Returns:
93
+ List[float]: List of HPS scores for the images.
94
+ """
95
+ try:
96
+ if isinstance(images, (str, Image.Image)):
97
+ # Single image
98
+ if isinstance(images, str):
99
+ image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
100
+ else:
101
+ image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
102
+ return [self._calculate_score(image, prompt)]
103
+ elif isinstance(images, list):
104
+ # Multiple images
105
+ scores = []
106
+ for one_images in images:
107
+ if isinstance(one_images, str):
108
+ image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
109
+ elif isinstance(one_images, Image.Image):
110
+ image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
111
+ else:
112
+ raise TypeError("The type of parameter images is illegal.")
113
+ scores.append(self._calculate_score(image, prompt))
114
+ return scores
115
+ else:
116
+ raise TypeError("The type of parameter images is illegal.")
117
+ except Exception as e:
118
+ raise RuntimeError(f"Error in scoring images: {e}")
diffsynth/extensions/ImageQualityMetric/imagereward.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from typing import List, Union
5
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
6
+ from .BLIP.blip_pretrain import BLIP_Pretrain
7
+ from torchvision.transforms import InterpolationMode
8
+ from safetensors.torch import load_file
9
+ from .config import MODEL_PATHS
10
+ BICUBIC = InterpolationMode.BICUBIC
11
+
12
+ def _convert_image_to_rgb(image):
13
+ return image.convert("RGB")
14
+
15
+ def _transform(n_px):
16
+ return Compose([
17
+ Resize(n_px, interpolation=BICUBIC),
18
+ CenterCrop(n_px),
19
+ _convert_image_to_rgb,
20
+ ToTensor(),
21
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
22
+ ])
23
+
24
+ class MLP(torch.nn.Module):
25
+ def __init__(self, input_size):
26
+ super().__init__()
27
+ self.input_size = input_size
28
+
29
+ self.layers = torch.nn.Sequential(
30
+ torch.nn.Linear(self.input_size, 1024),
31
+ #nn.ReLU(),
32
+ torch.nn.Dropout(0.2),
33
+ torch.nn.Linear(1024, 128),
34
+ #nn.ReLU(),
35
+ torch.nn.Dropout(0.2),
36
+ torch.nn.Linear(128, 64),
37
+ #nn.ReLU(),
38
+ torch.nn.Dropout(0.1),
39
+ torch.nn.Linear(64, 16),
40
+ #nn.ReLU(),
41
+ torch.nn.Linear(16, 1)
42
+ )
43
+
44
+ # initial MLP param
45
+ for name, param in self.layers.named_parameters():
46
+ if 'weight' in name:
47
+ torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
48
+ if 'bias' in name:
49
+ torch.nn.init.constant_(param, val=0)
50
+
51
+ def forward(self, input):
52
+ return self.layers(input)
53
+
54
+ class ImageReward(torch.nn.Module):
55
+ def __init__(self, med_config, device='cpu', bert_model_path=""):
56
+ super().__init__()
57
+ self.device = device
58
+
59
+ self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
60
+ self.preprocess = _transform(224)
61
+ self.mlp = MLP(768)
62
+
63
+ self.mean = 0.16717362830052426
64
+ self.std = 1.0333394966054072
65
+
66
+ def score_grad(self, prompt_ids, prompt_attention_mask, image):
67
+ """Calculate the score with gradient for a single image and prompt.
68
+
69
+ Args:
70
+ prompt_ids (torch.Tensor): Tokenized prompt IDs.
71
+ prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
72
+ image (torch.Tensor): The processed image tensor.
73
+
74
+ Returns:
75
+ torch.Tensor: The reward score.
76
+ """
77
+ image_embeds = self.blip.visual_encoder(image)
78
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
79
+ text_output = self.blip.text_encoder(
80
+ prompt_ids,
81
+ attention_mask=prompt_attention_mask,
82
+ encoder_hidden_states=image_embeds,
83
+ encoder_attention_mask=image_atts,
84
+ return_dict=True,
85
+ )
86
+ txt_features = text_output.last_hidden_state[:, 0, :]
87
+ rewards = self.mlp(txt_features)
88
+ rewards = (rewards - self.mean) / self.std
89
+ return rewards
90
+
91
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
92
+ """Score the images based on the prompt.
93
+
94
+ Args:
95
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
96
+ prompt (str): The prompt text.
97
+
98
+ Returns:
99
+ List[float]: List of scores for the images.
100
+ """
101
+ if isinstance(images, (str, Image.Image)):
102
+ # Single image
103
+ if isinstance(images, str):
104
+ pil_image = Image.open(images)
105
+ else:
106
+ pil_image = images
107
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
108
+ return [self._calculate_score(prompt, image).item()]
109
+ elif isinstance(images, list):
110
+ # Multiple images
111
+ scores = []
112
+ for one_image in images:
113
+ if isinstance(one_image, str):
114
+ pil_image = Image.open(one_image)
115
+ elif isinstance(one_image, Image.Image):
116
+ pil_image = one_image
117
+ else:
118
+ raise TypeError("The type of parameter images is illegal.")
119
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
120
+ scores.append(self._calculate_score(prompt, image).item())
121
+ return scores
122
+ else:
123
+ raise TypeError("The type of parameter images is illegal.")
124
+
125
+ def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
126
+ """Calculate the score for a single image and prompt.
127
+
128
+ Args:
129
+ prompt (str): The prompt text.
130
+ image (torch.Tensor): The processed image tensor.
131
+
132
+ Returns:
133
+ torch.Tensor: The reward score.
134
+ """
135
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
136
+ image_embeds = self.blip.visual_encoder(image)
137
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
138
+ text_output = self.blip.text_encoder(
139
+ text_input.input_ids,
140
+ attention_mask=text_input.attention_mask,
141
+ encoder_hidden_states=image_embeds,
142
+ encoder_attention_mask=image_atts,
143
+ return_dict=True,
144
+ )
145
+ txt_features = text_output.last_hidden_state[:, 0, :].float()
146
+ rewards = self.mlp(txt_features)
147
+ rewards = (rewards - self.mean) / self.std
148
+ return rewards
149
+
150
+ def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
151
+ """Rank the images based on the prompt.
152
+
153
+ Args:
154
+ prompt (str): The prompt text.
155
+ generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
156
+
157
+ Returns:
158
+ tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
159
+ """
160
+ text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
161
+ txt_set = []
162
+ for generation in generations_list:
163
+ if isinstance(generation, str):
164
+ pil_image = Image.open(generation)
165
+ elif isinstance(generation, Image.Image):
166
+ pil_image = generation
167
+ else:
168
+ raise TypeError("The type of parameter generations_list is illegal.")
169
+ image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
170
+ image_embeds = self.blip.visual_encoder(image)
171
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
172
+ text_output = self.blip.text_encoder(
173
+ text_input.input_ids,
174
+ attention_mask=text_input.attention_mask,
175
+ encoder_hidden_states=image_embeds,
176
+ encoder_attention_mask=image_atts,
177
+ return_dict=True,
178
+ )
179
+ txt_set.append(text_output.last_hidden_state[:, 0, :])
180
+ txt_features = torch.cat(txt_set, 0).float()
181
+ rewards = self.mlp(txt_features)
182
+ rewards = (rewards - self.mean) / self.std
183
+ rewards = torch.squeeze(rewards)
184
+ _, rank = torch.sort(rewards, dim=0, descending=True)
185
+ _, indices = torch.sort(rank, dim=0)
186
+ indices = indices + 1
187
+ return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
188
+
189
+
190
+ class ImageRewardScore(torch.nn.Module):
191
+ def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS):
192
+ super().__init__()
193
+ self.device = device if isinstance(device, torch.device) else torch.device(device)
194
+ model_path = path.get("imagereward")
195
+ med_config = path.get("med_config")
196
+ state_dict = load_file(model_path)
197
+ self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
198
+ self.model.load_state_dict(state_dict, strict=False)
199
+ self.model.eval()
200
+
201
+ @torch.no_grad()
202
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
203
+ """Score the images based on the prompt.
204
+
205
+ Args:
206
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
207
+ prompt (str): The prompt text.
208
+
209
+ Returns:
210
+ List[float]: List of scores for the images.
211
+ """
212
+ return self.model.score(images, prompt)
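
Besides per-image scores, the underlying ImageReward module (reachable through the .model attribute of ImageRewardScore) exposes inference_rank, which returns 1-based ranks together with the standardized rewards. A sketch with hypothetical candidate files.

from diffsynth.extensions.ImageQualityMetric import download_preference_model, load_preference_model

path = download_preference_model("ImageReward", cache_dir="models")
reward = load_preference_model("ImageReward", device="cuda", path=path)

ranks, rewards = reward.model.inference_rank(
    "a bowl of ramen, studio lighting",
    ["candidate_0.png", "candidate_1.png", "candidate_2.png"],  # hypothetical files
)
best = ranks.index(1)  # ranks are 1-based, 1 = highest reward
print(best, rewards[best])  # rewards are standardized as (raw - mean) / std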
diffsynth/extensions/ImageQualityMetric/mps.py ADDED
@@ -0,0 +1,129 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ from tqdm.auto import tqdm
6
+ from transformers import CLIPFeatureExtractor, CLIPImageProcessor
7
+ from transformers import CLIPConfig
8
+ from dataclasses import dataclass
9
+ from transformers import CLIPModel as HFCLIPModel
10
+ from safetensors.torch import load_file
11
+ from torch import nn, einsum
12
+
13
+ from .trainer.models.base_model import BaseModelConfig
14
+
15
+ from transformers import CLIPConfig
16
+ from transformers import AutoProcessor, AutoModel, AutoTokenizer
17
+ from typing import Any, Optional, Tuple, Union, List
18
+ import torch
19
+
20
+ from .trainer.models.cross_modeling import Cross_model
21
+ from .trainer.models import clip_model
22
+ import torch.nn.functional as F
23
+ import gc
24
+ import json
25
+ from .config import MODEL_PATHS
26
+
27
+ class MPScore(torch.nn.Module):
28
+ def __init__(self, device: Union[str, torch.device], path: dict = MODEL_PATHS, condition: str = 'overall'):
29
+ super().__init__()
30
+ """Initialize the MPSModel with a processor, tokenizer, and model.
31
+
32
+ Args:
33
+ device (Union[str, torch.device]): The device to load the model on.
34
+ """
35
+ self.device = device
36
+ processor_name_or_path = path.get("clip")
37
+ self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
38
+ self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
39
+ self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
40
+ state_dict = load_file(path.get("mps"))
41
+ self.model.load_state_dict(state_dict, strict=False)
42
+ self.model.to(device)
43
+ self.condition = condition
44
+
45
+ def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
46
+ """Calculate the reward score for a single image and prompt.
47
+
48
+ Args:
49
+ image (torch.Tensor): The processed image tensor.
50
+ prompt (str): The prompt text.
51
+
52
+ Returns:
53
+ float: The reward score.
54
+ """
55
+ def _tokenize(caption):
56
+ input_ids = self.tokenizer(
57
+ caption,
58
+ max_length=self.tokenizer.model_max_length,
59
+ padding="max_length",
60
+ truncation=True,
61
+ return_tensors="pt"
62
+ ).input_ids
63
+ return input_ids
64
+
65
+ text_input = _tokenize(prompt).to(self.device)
66
+ if self.condition == 'overall':
67
+ condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
68
+ elif self.condition == 'aesthetics':
69
+ condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
70
+ elif self.condition == 'quality':
71
+ condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
72
+ elif self.condition == 'semantic':
73
+ condition_prompt = 'quantity, attributes, position, number, location'
74
+ else:
75
+ raise ValueError(
76
+ f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
77
+ condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
78
+
79
+ with torch.no_grad():
80
+ text_f, text_features = self.model.model.get_text_features(text_input)
81
+
82
+ image_f = self.model.model.get_image_features(image.half())
83
+ condition_f, _ = self.model.model.get_text_features(condition_batch)
84
+
85
+ sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
86
+ sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
87
+ sim_text_condition = sim_text_condition / sim_text_condition.max()
88
+ mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
89
+ mask = mask.repeat(1, image_f.shape[1], 1)
90
+ image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
91
+
92
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
93
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
94
+ image_score = self.model.logit_scale.exp() * text_features @ image_features.T
95
+
96
+ return image_score[0].cpu().numpy().item()
97
+
98
+ @torch.no_grad()
99
+ def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
100
+ """Score the images based on the prompt.
101
+
102
+ Args:
103
+ images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
104
+ prompt (str): The prompt text.
105
+
106
+ Returns:
107
+ List[float]: List of reward scores for the images.
108
+ """
109
+ if isinstance(images, (str, Image.Image)):
110
+ # Single image
111
+ if isinstance(images, str):
112
+ image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
113
+ else:
114
+ image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
115
+ return [self._calculate_score(image, prompt)]
116
+ elif isinstance(images, list):
117
+ # Multiple images
118
+ scores = []
119
+ for one_images in images:
120
+ if isinstance(one_images, str):
121
+ image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
122
+ elif isinstance(one_images, Image.Image):
123
+ image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
124
+ else:
125
+ raise TypeError("The type of parameter images is illegal.")
126
+ scores.append(self._calculate_score(image, prompt))
127
+ return scores
128
+ else:
129
+ raise TypeError("The type of parameter images is illegal.")
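
The condition argument selects which aspect prompt gates the cross-attention: 'overall', 'aesthetics', 'quality', or 'semantic'. A sketch comparing two conditions on the same image, assuming the default weight layout from config.py and a hypothetical image path.

from diffsynth.extensions.ImageQualityMetric.mps import MPScore
from diffsynth.extensions.ImageQualityMetric.config import MODEL_PATHS

prompt = "an isometric voxel render of a castle"
overall = MPScore(device="cuda", path=MODEL_PATHS, condition="overall")
aesthetics = MPScore(device="cuda", path=MODEL_PATHS, condition="aesthetics")

# Same image, two different preference aspects.
print(overall.score("example.png", prompt))
print(aesthetics.score("example.png", prompt))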
diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .coca_model import CoCa
2
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
3
+ from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
4
+ from .factory import list_models, add_model_config, get_model_config, load_checkpoint
5
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
6
+ from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
7
+ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
8
+ from .openai import load_openai_model, list_openai_models
9
+ from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
10
+ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
11
+ from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
12
+ from .tokenizer import SimpleTokenizer
13
+ from .transform import image_transform, AugmentationCfg
14
+ from .utils import freeze_batch_norm_2d
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py ADDED
@@ -0,0 +1,458 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ from dataclasses import dataclass
8
+
9
+ from .transformer import (
10
+ LayerNormFp32,
11
+ LayerNorm,
12
+ QuickGELU,
13
+ MultimodalTransformer,
14
+ )
15
+ from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
16
+
17
+ try:
18
+ from transformers import (
19
+ BeamSearchScorer,
20
+ LogitsProcessorList,
21
+ TopPLogitsWarper,
22
+ TopKLogitsWarper,
23
+ RepetitionPenaltyLogitsProcessor,
24
+ MinLengthLogitsProcessor,
25
+ MaxLengthCriteria,
26
+ StoppingCriteriaList
27
+ )
28
+
29
+ GENERATION_TYPES = {
30
+ "top_k": TopKLogitsWarper,
31
+ "top_p": TopPLogitsWarper,
32
+ "beam_search": "beam_search"
33
+ }
34
+ _has_transformers = True
35
+ except ImportError as e:
36
+ GENERATION_TYPES = {
37
+ "top_k": None,
38
+ "top_p": None,
39
+ "beam_search": "beam_search"
40
+ }
41
+ _has_transformers = False
42
+
43
+
44
+ @dataclass
45
+ class MultimodalCfg(CLIPTextCfg):
46
+ mlp_ratio: int = 4
47
+ dim_head: int = 64
48
+ heads: int = 8
49
+ n_queries: int = 256
50
+ attn_pooler_heads: int = 8
51
+
52
+
53
+ def _build_text_decoder_tower(
54
+ embed_dim,
55
+ multimodal_cfg,
56
+ quick_gelu: bool = False,
57
+ cast_dtype: Optional[torch.dtype] = None,
58
+ ):
59
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
60
+ act_layer = QuickGELU if quick_gelu else nn.GELU
61
+ norm_layer = (
62
+ LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
63
+ )
64
+
65
+ decoder = MultimodalTransformer(
66
+ context_length=multimodal_cfg.context_length,
67
+ width=multimodal_cfg.width,
68
+ heads=multimodal_cfg.heads,
69
+ layers=multimodal_cfg.layers,
70
+ ls_init_value=multimodal_cfg.ls_init_value,
71
+ output_dim=embed_dim,
72
+ act_layer=act_layer,
73
+ norm_layer=norm_layer,
74
+ )
75
+
76
+ return decoder
77
+
78
+
79
+ class CoCa(nn.Module):
80
+ def __init__(
81
+ self,
82
+ embed_dim,
83
+ multimodal_cfg: MultimodalCfg,
84
+ text_cfg: CLIPTextCfg,
85
+ vision_cfg: CLIPVisionCfg,
86
+ quick_gelu: bool = False,
87
+ cast_dtype: Optional[torch.dtype] = None,
88
+ pad_id: int = 0,
89
+ ):
90
+ super().__init__()
91
+ multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
92
+ text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
93
+ vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
94
+
95
+ self.text = _build_text_tower(
96
+ embed_dim=embed_dim,
97
+ text_cfg=text_cfg,
98
+ quick_gelu=quick_gelu,
99
+ cast_dtype=cast_dtype,
100
+ )
101
+
102
+ vocab_size = (
103
+ text_cfg.vocab_size # for hf models
104
+ if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
105
+ else text_cfg.vocab_size
106
+ )
107
+
108
+ self.visual = _build_vision_tower(
109
+ embed_dim=embed_dim,
110
+ vision_cfg=vision_cfg,
111
+ quick_gelu=quick_gelu,
112
+ cast_dtype=cast_dtype,
113
+ )
114
+
115
+ self.text_decoder = _build_text_decoder_tower(
116
+ vocab_size,
117
+ multimodal_cfg=multimodal_cfg,
118
+ quick_gelu=quick_gelu,
119
+ cast_dtype=cast_dtype,
120
+ )
121
+
122
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
123
+ self.pad_id = pad_id
124
+
125
+ @torch.jit.ignore
126
+ def set_grad_checkpointing(self, enable=True):
127
+ self.visual.set_grad_checkpointing(enable)
128
+ self.text.set_grad_checkpointing(enable)
129
+ self.text_decoder.set_grad_checkpointing(enable)
130
+
131
+ def _encode_image(self, images, normalize=True):
132
+ image_latent, tokens_embs = self.visual(images)
133
+ image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
134
+ return image_latent, tokens_embs
135
+
136
+ def _encode_text(self, text, normalize=True, embed_cls=True):
137
+ text = text[:, :-1] if embed_cls else text # make space for CLS token
138
+ text_latent, token_emb = self.text(text)
139
+ text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
140
+ return text_latent, token_emb
141
+
142
+ def encode_image(self, images, normalize=True):
143
+ image_latent, _ = self._encode_image(images, normalize=normalize)
144
+ return image_latent
145
+
146
+ def encode_text(self, text, normalize=True, embed_cls=True):
147
+ text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
148
+ return text_latent
149
+
150
+ def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
151
+ text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
152
+ if image_latent is None or image_embs is None:
153
+ image_latent, image_embs = self._encode_image(image)
154
+
155
+ # TODO: add assertion to avoid bugs?
156
+ labels = text[:, -token_embs.shape[1]:]
157
+
158
+ logits = self.text_decoder(image_embs, token_embs)
159
+ return {
160
+ "image_features": image_latent,
161
+ "text_features": text_latent,
162
+ "logits": logits,
163
+ "labels": labels,
164
+ "logit_scale": self.logit_scale.exp()
165
+ }
166
+
167
+ def generate(
168
+ self,
169
+ image,
170
+ text=None,
171
+ seq_len=30,
172
+ max_seq_len=77,
173
+ temperature=1.,
174
+ generation_type="beam_search",
175
+ top_p=0.1, # keep tokens in the 1 - top_p quantile
176
+ top_k=1, # keeps the top_k most probable tokens
177
+ pad_token_id=None,
178
+ eos_token_id=None,
179
+ sot_token_id=None,
180
+ num_beams=6,
181
+ num_beam_groups=3,
182
+ min_seq_len=5,
183
+ stopping_criteria=None,
184
+ repetition_penalty=1.0,
185
+ fixed_output_length=False # if True output.shape == (batch_size, seq_len)
186
+ ):
187
+ # taking many ideas and components from HuggingFace GenerationMixin
188
+ # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
189
+ assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
190
+ assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
191
+
192
+ with torch.no_grad():
193
+ sot_token_id = 49406 if sot_token_id is None else sot_token_id
194
+ eos_token_id = 49407 if eos_token_id is None else eos_token_id
195
+ pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
196
+ logit_processor = LogitsProcessorList(
197
+ [
198
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
199
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
200
+ ]
201
+ )
202
+
203
+ if stopping_criteria is None:
204
+ stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
205
+
206
+ stopping_criteria = StoppingCriteriaList(
207
+ stopping_criteria
208
+ )
209
+
210
+ device = image.device
211
+
212
+ if generation_type == "beam_search":
213
+ output = self._generate_beamsearch(
214
+ image_inputs = image,
215
+ pad_token_id=pad_token_id,
216
+ eos_token_id=eos_token_id,
217
+ sot_token_id=sot_token_id,
218
+ num_beams=num_beams,
219
+ num_beam_groups=num_beam_groups,
220
+ min_seq_len=min_seq_len,
221
+ stopping_criteria=stopping_criteria,
222
+ logit_processor=logit_processor,
223
+ )
224
+ if fixed_output_length and output.shape[1] < seq_len:
225
+ return torch.cat(
226
+ (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
227
+ dim=1
228
+ )
229
+ return output
230
+
231
+ elif generation_type == "top_p":
232
+ logit_warper = GENERATION_TYPES[generation_type](top_p)
233
+ elif generation_type == "top_k":
234
+ logit_warper = GENERATION_TYPES[generation_type](top_k)
235
+ else:
236
+ raise ValueError(
237
+ f"generation_type has to be one of "
238
+ f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
239
+ )
240
+
241
+ image_latent, image_embs = self._encode_image(image)
242
+
243
+ if text is None:
244
+ text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
245
+
246
+ was_training = self.training
247
+ num_dims = len(text.shape)
248
+
249
+ if num_dims == 1:
250
+ text = text[None, :]
251
+
252
+ cur_len = text.shape[1]
253
+ self.eval()
254
+ out = text
255
+
256
+ while True:
257
+ x = out[:, -max_seq_len:]
258
+ cur_len = x.shape[1]
259
+ logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
260
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
261
+ sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
262
+
263
+ if mask.all():
264
+ if not fixed_output_length:
265
+ break
266
+ else:
267
+ logits = logits[~mask, :]
268
+ filtered_logits = logit_processor(x[~mask, :], logits)
269
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
270
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
271
+
272
+ if (cur_len + 1 == seq_len):
273
+ sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
274
+ else:
275
+ sample[~mask, :] = torch.multinomial(probs, 1)
276
+
277
+ out = torch.cat((out, sample), dim=-1)
278
+
279
+ cur_len += 1
280
+
281
+ if stopping_criteria(out, None):
282
+ break
283
+
284
+ if num_dims == 1:
285
+ out = out.squeeze(0)
286
+
287
+ self.train(was_training)
288
+ return out
289
+
290
+ def _generate_beamsearch(
291
+ self,
292
+ image_inputs,
293
+ pad_token_id=None,
294
+ eos_token_id=None,
295
+ sot_token_id=None,
296
+ num_beams=6,
297
+ num_beam_groups=3,
298
+ min_seq_len=5,
299
+ stopping_criteria=None,
300
+ logit_processor=None,
301
+ logit_warper=None,
302
+ ):
303
+ device = image_inputs.device
304
+ batch_size = image_inputs.shape[0]
305
+ image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
306
+ image_latent, image_embs = self._encode_image(image_inputs)
307
+
308
+ input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
309
+ input_ids = input_ids * sot_token_id
310
+ beam_scorer = BeamSearchScorer(
311
+ batch_size=batch_size,
312
+ num_beams=num_beams,
313
+ device=device,
314
+ num_beam_groups=num_beam_groups,
315
+ )
316
+ # instantiate logits processors
317
+ logits_processor = (
318
+ LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
319
+ if logit_processor is None
320
+ else logit_processor
321
+ )
322
+
323
+ batch_size = len(beam_scorer._beam_hyps)
324
+ num_beams = beam_scorer.num_beams
325
+ num_beam_groups = beam_scorer.num_beam_groups
326
+ num_sub_beams = num_beams // num_beam_groups
327
+ batch_beam_size, cur_len = input_ids.shape
328
+ beam_indices = None
329
+
330
+ if num_beams * batch_size != batch_beam_size:
331
+ raise ValueError(
332
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
333
+ )
334
+
335
+ beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
336
+ # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
337
+ # the same group don't produce the same tokens every time.
338
+ beam_scores[:, ::num_sub_beams] = 0
339
+ beam_scores = beam_scores.view((batch_size * num_beams,))
340
+
341
+ while True:
342
+
343
+ # predicted tokens in cur_len step
344
+ current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
345
+
346
+ # indices which will form the beams in the next time step
347
+ reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
348
+
349
+ # do one decoder step on all beams of all sentences in batch
350
+ model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
351
+ outputs = self(
352
+ model_inputs['images'],
353
+ model_inputs['text'],
354
+ embed_cls=False,
355
+ image_latent=image_latent,
356
+ image_embs=image_embs
357
+ )
358
+
359
+ for beam_group_idx in range(num_beam_groups):
360
+ group_start_idx = beam_group_idx * num_sub_beams
361
+ group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
362
+ group_size = group_end_idx - group_start_idx
363
+
364
+ # indices of beams of current group among all sentences in batch
365
+ batch_group_indices = []
366
+
367
+ for batch_idx in range(batch_size):
368
+ batch_group_indices.extend(
369
+ [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
370
+ )
371
+ group_input_ids = input_ids[batch_group_indices]
372
+
373
+ # select outputs of beams of current group only
374
+ next_token_logits = outputs['logits'][batch_group_indices, -1, :]
375
+ vocab_size = next_token_logits.shape[-1]
376
+
377
+ next_token_scores_processed = logits_processor(
378
+ group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
379
+ )
380
+ next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
381
+ next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
382
+
383
+ # reshape for beam search
384
+ next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
385
+
386
+ next_token_scores, next_tokens = torch.topk(
387
+ next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
388
+ )
389
+
390
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
391
+ next_tokens = next_tokens % vocab_size
392
+
393
+ # stateless
394
+ process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
395
+ beam_outputs = beam_scorer.process(
396
+ group_input_ids,
397
+ next_token_scores,
398
+ next_tokens,
399
+ next_indices,
400
+ pad_token_id=pad_token_id,
401
+ eos_token_id=eos_token_id,
402
+ beam_indices=process_beam_indices,
403
+ )
404
+ beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
405
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
406
+ beam_idx = beam_outputs["next_beam_indices"]
407
+
408
+ input_ids[batch_group_indices] = group_input_ids[beam_idx]
409
+ group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
410
+ current_tokens[batch_group_indices] = group_input_ids[:, -1]
411
+
412
+ # (beam_idx // group_size) -> batch_idx
413
+ # (beam_idx % group_size) -> offset of idx inside the group
414
+ reordering_indices[batch_group_indices] = (
415
+ num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
416
+ )
417
+
418
+ input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
419
+
420
+ # increase cur_len
421
+ cur_len = cur_len + 1
422
+ if beam_scorer.is_done or stopping_criteria(input_ids, None):
423
+ break
424
+
425
+ final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
426
+ sequence_outputs = beam_scorer.finalize(
427
+ input_ids,
428
+ beam_scores,
429
+ next_tokens,
430
+ next_indices,
431
+ pad_token_id=pad_token_id,
432
+ eos_token_id=eos_token_id,
433
+ max_length=stopping_criteria.max_length,
434
+ beam_indices=final_beam_indices,
435
+ )
436
+ return sequence_outputs['sequences']
437
+
438
+
439
+ def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
440
+ if past:
441
+ input_ids = input_ids[:, -1].unsqueeze(-1)
442
+
443
+ attention_mask = kwargs.get("attention_mask", None)
444
+ position_ids = kwargs.get("position_ids", None)
445
+
446
+ if attention_mask is not None and position_ids is None:
447
+ # create position_ids on the fly for batch generation
448
+ position_ids = attention_mask.long().cumsum(-1) - 1
449
+ position_ids.masked_fill_(attention_mask == 0, 1)
450
+ else:
451
+ position_ids = None
452
+ return {
453
+ "text": input_ids,
454
+ "images": image_inputs,
455
+ "past_key_values": past,
456
+ "position_ids": position_ids,
457
+ "attention_mask": attention_mask,
458
+ }
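
The `generate` method above drives an autoregressive decoding loop built on HuggingFace logits processors and warpers. Below is a minimal, standalone sketch of just that filtering/sampling step applied to a dummy logit tensor; the prefix tokens, penalty values and `top_k` are illustrative, and `transformers` must be installed.

```python
# Minimal sketch of the filtering/sampling step used inside CoCa.generate().
# All concrete values (prefix tokens, penalties, top_k) are illustrative.
import torch
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TopKLogitsWarper,
)

vocab_size, eos_token_id = 49408, 49407          # CLIP BPE defaults used in generate()
input_ids = torch.tensor([[49406, 320, 1125]])   # sot token plus a short fake prefix
logits = torch.randn(1, vocab_size)              # fake next-token logits from the decoder

processors = LogitsProcessorList([
    MinLengthLogitsProcessor(5, eos_token_id),    # forbid EOS before 5 tokens
    RepetitionPenaltyLogitsProcessor(1.2),        # discourage repeating prefix tokens
])
warper = TopKLogitsWarper(top_k=1)                # "top_k" generation_type; k=1 is greedy

filtered = warper(input_ids, processors(input_ids, logits))
probs = torch.softmax(filtered / 1.0, dim=-1)     # temperature = 1.0
next_token = torch.multinomial(probs, 1)          # same sampling call as in generate()
print(next_token.shape)                           # torch.Size([1, 1])
```

In `generate()` the same processors run on the full running sequence at every step, with the next-token logits coming from `self(image, x, ...)["logits"][:, -1]`.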
diffsynth/extensions/ImageQualityMetric/open_clip/constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
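
These two tuples are the per-channel RGB statistics used to normalise inputs for CLIP-style models; `factory.py` falls back to them for `model.visual.image_mean` / `image_std`. A tiny sketch of how they are typically applied, assuming torchvision is available:

```python
# Normalise a [0, 1] image tensor with the OpenAI CLIP statistics defined above.
import torch
from torchvision.transforms import Normalize

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

normalize = Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)
fake_image = torch.rand(3, 224, 224)              # values in [0, 1], as produced by ToTensor()
print(normalize(fake_image).mean().item())        # roughly centred around zero
```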
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py ADDED
@@ -0,0 +1,433 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ import re
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+ # from turtle import forward
9
+ from typing import Any, Dict, Optional, Tuple, Union
10
+
11
+ import torch
12
+
13
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
14
+ from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
15
+ resize_pos_embed, get_cast_dtype
16
+ from .coca_model import CoCa
17
+ from .loss import ClipLoss, DistillClipLoss, CoCaLoss
18
+ from .openai import load_openai_model
19
+ from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
20
+ from .transform import image_transform, AugmentationCfg
21
+ from .tokenizer import HFTokenizer, SimpleTokenizer
22
+
23
+
24
+ HF_HUB_PREFIX = 'hf-hub:'
25
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
26
+ _MODEL_CONFIGS = {} # dictionary (model_name: config) of model architecture configs
27
+
28
+
29
+ def _natural_key(string_):
30
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
31
+
32
+
33
+ def _rescan_model_configs():
34
+ global _MODEL_CONFIGS
35
+
36
+ config_ext = ('.json',)
37
+ config_files = []
38
+ for config_path in _MODEL_CONFIG_PATHS:
39
+ if config_path.is_file() and config_path.suffix in config_ext:
40
+ config_files.append(config_path)
41
+ elif config_path.is_dir():
42
+ for ext in config_ext:
43
+ config_files.extend(config_path.glob(f'*{ext}'))
44
+
45
+ for cf in config_files:
46
+ with open(cf, 'r') as f:
47
+ model_cfg = json.load(f)
48
+ if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
49
+ _MODEL_CONFIGS[cf.stem] = model_cfg
50
+
51
+ _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
52
+
53
+
54
+ _rescan_model_configs() # initial populate of model config registry
55
+
56
+
57
+ def list_models():
58
+ """ enumerate available model architectures based on config files """
59
+ return list(_MODEL_CONFIGS.keys())
60
+
61
+
62
+ def add_model_config(path):
63
+ """ add model config path or file and update registry """
64
+ if not isinstance(path, Path):
65
+ path = Path(path)
66
+ _MODEL_CONFIG_PATHS.append(path)
67
+ _rescan_model_configs()
68
+
69
+
70
+ def get_model_config(model_name):
71
+ if model_name in _MODEL_CONFIGS:
72
+ return deepcopy(_MODEL_CONFIGS[model_name])
73
+ else:
74
+ return None
75
+
76
+
77
+ def get_tokenizer(model_name, open_clip_bpe_path=None):
78
+ if model_name.startswith(HF_HUB_PREFIX):
79
+ tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
80
+ else:
81
+ config = get_model_config(model_name)
82
+ tokenizer = HFTokenizer(
83
+ config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
84
+ return tokenizer
85
+
86
+
87
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
88
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
89
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
90
+ state_dict = checkpoint['state_dict']
91
+ else:
92
+ state_dict = checkpoint
93
+ if next(iter(state_dict.items()))[0].startswith('module'):
94
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
95
+ return state_dict
96
+
97
+
98
+ def load_checkpoint(model, checkpoint_path, strict=True):
99
+ state_dict = load_state_dict(checkpoint_path)
100
+ # detect old format and make compatible with new format
101
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
102
+ state_dict = convert_to_custom_text_state_dict(state_dict)
103
+ resize_pos_embed(state_dict, model)
104
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
105
+ return incompatible_keys
106
+
107
+
108
+ def create_model(
109
+ model_name: str,
110
+ pretrained: Optional[str] = None,
111
+ precision: str = 'fp32',
112
+ device: Union[str, torch.device] = 'cpu',
113
+ jit: bool = False,
114
+ force_quick_gelu: bool = False,
115
+ force_custom_text: bool = False,
116
+ force_patch_dropout: Optional[float] = None,
117
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
118
+ pretrained_image: bool = False,
119
+ pretrained_hf: bool = True,
120
+ cache_dir: Optional[str] = None,
121
+ output_dict: Optional[bool] = None,
122
+ require_pretrained: bool = False,
123
+ ):
124
+ has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
125
+ if has_hf_hub_prefix:
126
+ model_id = model_name[len(HF_HUB_PREFIX):]
127
+ checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
128
+ config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
129
+
130
+ with open(config_path, 'r', encoding='utf-8') as f:
131
+ config = json.load(f)
132
+ pretrained_cfg = config['preprocess_cfg']
133
+ model_cfg = config['model_cfg']
134
+ else:
135
+ model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
136
+ checkpoint_path = None
137
+ pretrained_cfg = {}
138
+ model_cfg = None
139
+
140
+ if isinstance(device, str):
141
+ device = torch.device(device)
142
+
143
+ if pretrained and pretrained.lower() == 'openai':
144
+ logging.info(f'Loading pretrained {model_name} from OpenAI.')
145
+ model = load_openai_model(
146
+ model_name,
147
+ precision=precision,
148
+ device=device,
149
+ jit=jit,
150
+ cache_dir=cache_dir,
151
+ )
152
+
153
+ # to always output dict even if it is clip
154
+ if output_dict and hasattr(model, "output_dict"):
155
+ model.output_dict = True
156
+ else:
157
+ model_cfg = model_cfg or get_model_config(model_name)
158
+ if model_cfg is not None:
159
+ logging.info(f'Loaded {model_name} model config.')
160
+ else:
161
+ logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
162
+ raise RuntimeError(f'Model config for {model_name} not found.')
163
+
164
+ if force_quick_gelu:
165
+ # override for use of QuickGELU on non-OpenAI transformer models
166
+ model_cfg["quick_gelu"] = True
167
+
168
+ if force_patch_dropout is not None:
169
+ # override the default patch dropout value
170
+ model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
171
+
172
+ if force_image_size is not None:
173
+ # override model config's image size
174
+ model_cfg["vision_cfg"]["image_size"] = force_image_size
175
+
176
+ if pretrained_image:
177
+ if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
178
+ # pretrained weight loading for timm models set via vision_cfg
179
+ model_cfg['vision_cfg']['timm_model_pretrained'] = True
180
+ else:
181
+ assert False, 'pretrained image towers currently only supported for timm models'
182
+
183
+ cast_dtype = get_cast_dtype(precision)
184
+ is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
185
+ custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
186
+
187
+ if custom_text:
188
+ if is_hf_model:
189
+ model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
190
+ if "coca" in model_name:
191
+ model = CoCa(**model_cfg, cast_dtype=cast_dtype)
192
+ else:
193
+ model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
194
+ else:
195
+ model = CLIP(**model_cfg, cast_dtype=cast_dtype)
196
+
197
+ pretrained_loaded = False
198
+ if pretrained:
199
+ checkpoint_path = ''
200
+ pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
201
+ if pretrained_cfg:
202
+ checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
203
+ elif os.path.exists(pretrained):
204
+ checkpoint_path = pretrained
205
+
206
+ if checkpoint_path:
207
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
208
+ load_checkpoint(model, checkpoint_path)
209
+ else:
210
+ error_str = (
211
+ f'Pretrained weights ({pretrained}) not found for model {model_name}.'
212
+ f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
213
+ logging.warning(error_str)
214
+ raise RuntimeError(error_str)
215
+ pretrained_loaded = True
216
+ elif has_hf_hub_prefix:
217
+ logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
218
+ load_checkpoint(model, checkpoint_path)
219
+ pretrained_loaded = True
220
+
221
+ if require_pretrained and not pretrained_loaded:
222
+ # callers of create_model_from_pretrained always expect pretrained weights
223
+ raise RuntimeError(
224
+ f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
225
+
226
+ model.to(device=device)
227
+ if precision in ("fp16", "bf16"):
228
+ convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
229
+
230
+ # set image / mean metadata from pretrained_cfg if available, or use default
231
+ model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
232
+ model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
233
+
234
+ # to always output dict even if it is clip
235
+ if output_dict and hasattr(model, "output_dict"):
236
+ model.output_dict = True
237
+
238
+ if jit:
239
+ model = torch.jit.script(model)
240
+
241
+ return model
242
+
243
+
244
+ def create_loss(args):
245
+ if args.distill:
246
+ return DistillClipLoss(
247
+ local_loss=args.local_loss,
248
+ gather_with_grad=args.gather_with_grad,
249
+ cache_labels=True,
250
+ rank=args.rank,
251
+ world_size=args.world_size,
252
+ use_horovod=args.horovod,
253
+ )
254
+ elif "coca" in args.model.lower():
255
+ return CoCaLoss(
256
+ caption_loss_weight=args.coca_caption_loss_weight,
257
+ clip_loss_weight=args.coca_contrastive_loss_weight,
258
+ local_loss=args.local_loss,
259
+ gather_with_grad=args.gather_with_grad,
260
+ cache_labels=True,
261
+ rank=args.rank,
262
+ world_size=args.world_size,
263
+ use_horovod=args.horovod,
264
+ )
265
+ return ClipLoss(
266
+ local_loss=args.local_loss,
267
+ gather_with_grad=args.gather_with_grad,
268
+ cache_labels=True,
269
+ rank=args.rank,
270
+ world_size=args.world_size,
271
+ use_horovod=args.horovod,
272
+ )
273
+
274
+ class MLP(torch.nn.Module):
275
+ def __init__(self, input_size):
276
+ super().__init__()
277
+ self.input_size = input_size
278
+ self.layers = torch.nn.Sequential(
279
+ torch.nn.Linear(self.input_size, 1024),
280
+ torch.nn.Dropout(0.2),
281
+ torch.nn.Linear(1024, 128),
282
+ torch.nn.Dropout(0.2),
283
+ torch.nn.Linear(128, 64),
284
+ torch.nn.Dropout(0.1),
285
+ torch.nn.Linear(64, 16),
286
+ torch.nn.Linear(16, 1)
287
+ )
288
+
289
+ def forward(self, x):
290
+ return self.layers(x)
291
+
292
+ # class semantic_head(torch.nn.Module):
293
+ # def __init__(self, input_size):
294
+ # super().__init__()
295
+ # self.input_size = input_size # for ViT-L-14 is 1024
296
+ # self.seg_head = torch.nn.Sequential(
297
+ # torch.nn.Linear(input_size, 128),
298
+ # torch.nn.Dropout(0.2),
299
+ # torch.nn.Linear(128, 64),
300
+ # torch.nn.Dropout(0.1),
301
+ # torch.nn.Linear(64, 16),
302
+ # torch.nn.Linear(16, 1),
303
+ # )
304
+ # self.sigmoid = torch.nn.Sigmoid()
305
+
306
+ # def forward(self, x):
307
+ # return self.sigmoid(self.seg_head(x))
308
+
309
+ def create_model_and_transforms(
310
+ model_name: str,
311
+ pretrained: Optional[str] = None,
312
+ precision: str = 'fp32',
313
+ device: Union[str, torch.device] = 'cpu',
314
+ jit: bool = False,
315
+ force_quick_gelu: bool = False,
316
+ force_custom_text: bool = False,
317
+ force_patch_dropout: Optional[float] = None,
318
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
319
+ pretrained_image: bool = False,
320
+ pretrained_hf: bool = True,
321
+ image_mean: Optional[Tuple[float, ...]] = None,
322
+ image_std: Optional[Tuple[float, ...]] = None,
323
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
324
+ cache_dir: Optional[str] = None,
325
+ light_augmentation = False,
326
+ output_dict: Optional[bool] = None,
327
+ with_score_predictor: bool = False,
328
+ with_region_predictor: bool = False
329
+ ):
330
+ model = create_model(
331
+ model_name,
332
+ pretrained,
333
+ precision=precision,
334
+ device=device,
335
+ jit=jit,
336
+ force_quick_gelu=force_quick_gelu,
337
+ force_custom_text=force_custom_text,
338
+ force_patch_dropout=force_patch_dropout,
339
+ force_image_size=force_image_size,
340
+ pretrained_image=pretrained_image,
341
+ pretrained_hf=pretrained_hf,
342
+ cache_dir=cache_dir,
343
+ output_dict=output_dict,
344
+ )
345
+
346
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
347
+ image_std = image_std or getattr(model.visual, 'image_std', None)
348
+
349
+ if with_score_predictor:
350
+ model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
351
+
352
+ if with_region_predictor:
353
+ # model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
354
+ model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
355
+ # preprocess_train = image_transform_region(
356
+ # model.visual.image_size,
357
+ # is_train=True,
358
+ # mean=image_mean,
359
+ # std=image_std
360
+ # )
361
+ # preprocess_val = image_transform_region(
362
+ # model.visual.image_size,
363
+ # is_train=False,
364
+ # mean=image_mean,
365
+ # std=image_std
366
+ # )
367
+
368
+ if light_augmentation:
369
+ preprocess_val = image_transform(
370
+ model.visual.image_size,
371
+ is_train=False,
372
+ mean=image_mean,
373
+ std=image_std,
374
+ resize_longest_max=True,
375
+ )
376
+ preprocess_train = preprocess_val
377
+ else:
378
+ preprocess_train = image_transform(
379
+ model.visual.image_size,
380
+ is_train=True,
381
+ mean=image_mean,
382
+ std=image_std
383
+ )
384
+ preprocess_val = image_transform(
385
+ model.visual.image_size,
386
+ is_train=False,
387
+ mean=image_mean,
388
+ std=image_std
389
+ )
390
+
391
+ return model, preprocess_train, preprocess_val
392
+
393
+
394
+ def create_model_from_pretrained(
395
+ model_name: str,
396
+ pretrained: Optional[str] = None,
397
+ precision: str = 'fp32',
398
+ device: Union[str, torch.device] = 'cpu',
399
+ jit: bool = False,
400
+ force_quick_gelu: bool = False,
401
+ force_custom_text: bool = False,
402
+ force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
403
+ return_transform: bool = True,
404
+ image_mean: Optional[Tuple[float, ...]] = None,
405
+ image_std: Optional[Tuple[float, ...]] = None,
406
+ cache_dir: Optional[str] = None,
407
+ ):
408
+ model = create_model(
409
+ model_name,
410
+ pretrained,
411
+ precision=precision,
412
+ device=device,
413
+ jit=jit,
414
+ force_quick_gelu=force_quick_gelu,
415
+ force_custom_text=force_custom_text,
416
+ force_image_size=force_image_size,
417
+ cache_dir=cache_dir,
418
+ require_pretrained=True,
419
+ )
420
+
421
+ if not return_transform:
422
+ return model
423
+
424
+ image_mean = image_mean or getattr(model.visual, 'image_mean', None)
425
+ image_std = image_std or getattr(model.visual, 'image_std', None)
426
+ preprocess = image_transform(
427
+ model.visual.image_size,
428
+ is_train=False,
429
+ mean=image_mean,
430
+ std=image_std,
431
+ )
432
+
433
+ return model, preprocess
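
A hedged usage sketch for the factory functions above. The model name `ViT-B-32` is an assumption (the corresponding JSON must exist under the vendored `model_configs/` directory), and `pretrained=None` keeps the weights randomly initialised so nothing is downloaded; the shapes in the comments assume a 512-dimensional embedding.

```python
import torch
from PIL import Image
from diffsynth.extensions.ImageQualityMetric.open_clip.factory import create_model_and_transforms

# Build a randomly initialised CLIP plus its train/val preprocessing transforms.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    "ViT-B-32",        # resolved through get_model_config() / model_configs/*.json
    pretrained=None,   # pass a pretrained tag or checkpoint path to load real weights
    precision="fp32",
    device="cpu",
)

image = preprocess_val(Image.new("RGB", (256, 256))).unsqueeze(0)  # dummy image batch
tokens = torch.randint(0, model.vocab_size, (1, 77))               # dummy token ids

with torch.no_grad():
    image_features = model.encode_image(image, normalize=True)
    text_features = model.encode_text(tokens, normalize=True)
print(image_features.shape, text_features.shape)  # e.g. torch.Size([1, 512]) twice
```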
diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py ADDED
File without changes
diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py ADDED
@@ -0,0 +1,45 @@
1
+ # HF architecture dict:
2
+ arch_dict = {
3
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4
+ "roberta": {
5
+ "config_names": {
6
+ "context_length": "max_position_embeddings",
7
+ "vocab_size": "vocab_size",
8
+ "width": "hidden_size",
9
+ "heads": "num_attention_heads",
10
+ "layers": "num_hidden_layers",
11
+ "layer_attr": "layer",
12
+ "token_embeddings_attr": "embeddings"
13
+ },
14
+ "pooler": "mean_pooler",
15
+ },
16
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17
+ "xlm-roberta": {
18
+ "config_names": {
19
+ "context_length": "max_position_embeddings",
20
+ "vocab_size": "vocab_size",
21
+ "width": "hidden_size",
22
+ "heads": "num_attention_heads",
23
+ "layers": "num_hidden_layers",
24
+ "layer_attr": "layer",
25
+ "token_embeddings_attr": "embeddings"
26
+ },
27
+ "pooler": "mean_pooler",
28
+ },
29
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30
+ "mt5": {
31
+ "config_names": {
32
+ # unlimited seqlen
33
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35
+ "context_length": "",
36
+ "vocab_size": "vocab_size",
37
+ "width": "d_model",
38
+ "heads": "num_heads",
39
+ "layers": "num_layers",
40
+ "layer_attr": "block",
41
+ "token_embeddings_attr": "embed_tokens"
42
+ },
43
+ "pooler": "mean_pooler",
44
+ },
45
+ }
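
`arch_dict` is a translation table from open_clip's generic text-tower config keys to the attribute names each HuggingFace architecture uses. A small sketch of how it is meant to be read; `roberta-base` is an assumed model id, and fetching its config needs network access or a local HF cache.

```python
from transformers import AutoConfig
from diffsynth.extensions.ImageQualityMetric.open_clip.hf_configs import arch_dict

config = AutoConfig.from_pretrained("roberta-base")
names = arch_dict[config.model_type]["config_names"]

width = getattr(config, names["width"])    # hidden_size         -> 768
heads = getattr(config, names["heads"])    # num_attention_heads -> 12
layers = getattr(config, names["layers"])  # num_hidden_layers   -> 12
print(width, heads, layers)
```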
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py ADDED
@@ -0,0 +1,176 @@
1
+ """ huggingface model adapter
2
+
3
+ Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4
+ """
5
+
6
+ import re
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import TensorType
11
+
12
+ try:
13
+ import transformers
14
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16
+ BaseModelOutputWithPoolingAndCrossAttentions
17
+ except ImportError as e:
18
+ transformers = None
19
+
20
+
21
+ class BaseModelOutput:
22
+ pass
23
+
24
+
25
+ class PretrainedConfig:
26
+ pass
27
+
28
+ from .hf_configs import arch_dict
29
+
30
+
31
+ # utils
32
+ def _camel2snake(s):
33
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
34
+
35
+
36
+ # TODO: ?last - for gpt-like models
37
+ _POOLERS = {}
38
+
39
+
40
+ def register_pooler(cls):
41
+ """Decorator registering pooler class"""
42
+ _POOLERS[_camel2snake(cls.__name__)] = cls
43
+ return cls
44
+
45
+
46
+ @register_pooler
47
+ class MeanPooler(nn.Module):
48
+ """Mean pooling"""
49
+
50
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
51
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
52
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
53
+
54
+
55
+ @register_pooler
56
+ class MaxPooler(nn.Module):
57
+ """Max pooling"""
58
+
59
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
60
+ masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
61
+ return masked_output.max(1).values
62
+
63
+
64
+ @register_pooler
65
+ class ClsPooler(nn.Module):
66
+ """CLS token pooling"""
67
+
68
+ def __init__(self, use_pooler_output=True):
69
+ super().__init__()
70
+ self.cls_token_position = 0
71
+ self.use_pooler_output = use_pooler_output
72
+
73
+ def forward(self, x: BaseModelOutput, attention_mask: TensorType):
74
+ if (self.use_pooler_output and
75
+ isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
76
+ (x.pooler_output is not None)
77
+ ):
78
+ return x.pooler_output
79
+
80
+ return x.last_hidden_state[:, self.cls_token_position, :]
81
+
82
+
83
+ class HFTextEncoder(nn.Module):
84
+ """HuggingFace model adapter"""
85
+ output_tokens: torch.jit.Final[bool]
86
+
87
+ def __init__(
88
+ self,
89
+ model_name_or_path: str,
90
+ output_dim: int,
91
+ config: PretrainedConfig = None,
92
+ pooler_type: str = None,
93
+ proj: str = None,
94
+ pretrained: bool = True,
95
+ output_tokens: bool = False,
96
+ ):
97
+ super().__init__()
98
+ self.output_tokens = output_tokens
99
+ self.output_dim = output_dim
100
+
101
+ # TODO: find better way to get this information
102
+ uses_transformer_pooler = (pooler_type == "cls_pooler")
103
+
104
+ if transformers is None:
105
+ raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
106
+ if config is None:
107
+ self.config = AutoConfig.from_pretrained(model_name_or_path)
108
+ create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
109
+ AutoModel.from_config, self.config)
110
+ # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
111
+ if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
112
+ self.transformer = create_func(model_args)
113
+ self.transformer = self.transformer.encoder
114
+ else:
115
+ self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
116
+ else:
117
+ self.config = config
118
+ self.transformer = AutoModel.from_config(config)
119
+ if pooler_type is None: # get default arch pooler
120
+ pooler_type = (arch_dict[self.config.model_type]["pooler"])
121
+
122
+ self.pooler = _POOLERS[pooler_type]()
123
+
124
+ d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
125
+ if (d_model == output_dim) and (proj is None): # do we always need a proj?
126
+ self.proj = nn.Identity()
127
+ elif proj == 'linear':
128
+ self.proj = nn.Linear(d_model, output_dim, bias=False)
129
+ elif proj == 'mlp':
130
+ hidden_size = (d_model + output_dim) // 2
131
+ self.proj = nn.Sequential(
132
+ nn.Linear(d_model, hidden_size, bias=False),
133
+ nn.GELU(),
134
+ nn.Linear(hidden_size, output_dim, bias=False),
135
+ )
136
+
137
+ def forward(self, x: TensorType):
138
+ attn_mask = (x != self.config.pad_token_id).long()
139
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
140
+ pooled_out = self.pooler(out, attn_mask)
141
+ projected = self.proj(pooled_out)
142
+
143
+ seq_len = out.last_hidden_state.shape[1]
144
+ tokens = (
145
+ out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
146
+ if type(self.pooler) == ClsPooler
147
+ else out.last_hidden_state
148
+ )
149
+
150
+ if self.output_tokens:
151
+ return projected, tokens
152
+ return projected
153
+
154
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
155
+ if not unlocked_layers: # full freezing
156
+ for n, p in self.transformer.named_parameters():
157
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
158
+ return
159
+
160
+ encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
161
+ layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
162
+ print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
163
+ embeddings = getattr(
164
+ self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
165
+ modules = [embeddings, *layer_list][:-unlocked_layers]
166
+ # freeze layers
167
+ for module in modules:
168
+ for n, p in module.named_parameters():
169
+ p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
170
+
171
+ @torch.jit.ignore
172
+ def set_grad_checkpointing(self, enable=True):
173
+ self.transformer.gradient_checkpointing_enable()
174
+
175
+ def init_parameters(self):
176
+ pass
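
A hedged sketch of wrapping a HuggingFace encoder as a CLIP text tower with `HFTextEncoder`. `roberta-base` is an assumed model id; fetching its config requires network access (or a local cache), and `pretrained=False` keeps the weights random so no full checkpoint is downloaded.

```python
import torch
from diffsynth.extensions.ImageQualityMetric.open_clip.hf_model import HFTextEncoder

encoder = HFTextEncoder(
    "roberta-base",            # assumed model id
    output_dim=512,            # CLIP embedding size; projects hidden_size 768 -> 512
    proj="mlp",
    pooler_type="mean_pooler",
    pretrained=False,          # random weights; True loads the published checkpoint
)

token_ids = torch.randint(5, 1000, (2, 16))   # fake ids, avoiding RoBERTa's pad id (1)
with torch.no_grad():
    text_features = encoder(token_ids)
print(text_features.shape)                    # torch.Size([2, 512])
```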
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py ADDED
@@ -0,0 +1,270 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.utils.rnn import pad_sequence
5
+
6
+ try:
7
+ import torch.distributed.nn
8
+ from torch import distributed as dist
9
+
10
+ has_distributed = True
11
+ except ImportError:
12
+ has_distributed = False
13
+
14
+ try:
15
+ import horovod.torch as hvd
16
+ except ImportError:
17
+ hvd = None
18
+
19
+
20
+ def gather_features(
21
+ image_features,
22
+ text_features,
23
+ local_loss=False,
24
+ gather_with_grad=False,
25
+ rank=0,
26
+ world_size=1,
27
+ use_horovod=False
28
+ ):
29
+ assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
30
+ if use_horovod:
31
+ assert hvd is not None, 'Please install horovod'
32
+ if gather_with_grad:
33
+ all_image_features = hvd.allgather(image_features)
34
+ all_text_features = hvd.allgather(text_features)
35
+ else:
36
+ with torch.no_grad():
37
+ all_image_features = hvd.allgather(image_features)
38
+ all_text_features = hvd.allgather(text_features)
39
+ if not local_loss:
40
+ # ensure grads for local rank when all_* features don't have a gradient
41
+ gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
42
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
43
+ gathered_image_features[rank] = image_features
44
+ gathered_text_features[rank] = text_features
45
+ all_image_features = torch.cat(gathered_image_features, dim=0)
46
+ all_text_features = torch.cat(gathered_text_features, dim=0)
47
+ else:
48
+ # We gather tensors from all gpus
49
+ if gather_with_grad:
50
+ all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
51
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
52
+ else:
53
+ gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
54
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
55
+ dist.all_gather(gathered_image_features, image_features)
56
+ dist.all_gather(gathered_text_features, text_features)
57
+ if not local_loss:
58
+ # ensure grads for local rank when all_* features don't have a gradient
59
+ gathered_image_features[rank] = image_features
60
+ gathered_text_features[rank] = text_features
61
+ all_image_features = torch.cat(gathered_image_features, dim=0)
62
+ all_text_features = torch.cat(gathered_text_features, dim=0)
63
+
64
+ return all_image_features, all_text_features
65
+
66
+
67
+ class ClipLoss(nn.Module):
68
+
69
+ def __init__(
70
+ self,
71
+ local_loss=False,
72
+ gather_with_grad=False,
73
+ cache_labels=False,
74
+ rank=0,
75
+ world_size=1,
76
+ use_horovod=False,
77
+ ):
78
+ super().__init__()
79
+ self.local_loss = local_loss
80
+ self.gather_with_grad = gather_with_grad
81
+ self.cache_labels = cache_labels
82
+ self.rank = rank
83
+ self.world_size = world_size
84
+ self.use_horovod = use_horovod
85
+
86
+ # cache state
87
+ self.prev_num_logits = 0
88
+ self.labels = {}
89
+
90
+ def get_ground_truth(self, device, num_logits) -> torch.Tensor:
91
+ # calculate ground-truth and cache if enabled
92
+ if self.prev_num_logits != num_logits or device not in self.labels:
93
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
94
+ if self.world_size > 1 and self.local_loss:
95
+ labels = labels + num_logits * self.rank
96
+ if self.cache_labels:
97
+ self.labels[device] = labels
98
+ self.prev_num_logits = num_logits
99
+ else:
100
+ labels = self.labels[device]
101
+ return labels
102
+
103
+ def get_logits(self, image_features, text_features, logit_scale):
104
+ if self.world_size > 1:
105
+ all_image_features, all_text_features = gather_features(
106
+ image_features, text_features,
107
+ self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
108
+
109
+ if self.local_loss:
110
+ logits_per_image = logit_scale * image_features @ all_text_features.T
111
+ logits_per_text = logit_scale * text_features @ all_image_features.T
112
+ else:
113
+ logits_per_image = logit_scale * all_image_features @ all_text_features.T
114
+ logits_per_text = logits_per_image.T
115
+ else:
116
+ logits_per_image = logit_scale * image_features @ text_features.T
117
+ logits_per_text = logit_scale * text_features @ image_features.T
118
+
119
+ return logits_per_image, logits_per_text
120
+
121
+ def forward(self, image_features, text_features, logit_scale, output_dict=False):
122
+ device = image_features.device
123
+ logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
124
+
125
+ labels = self.get_ground_truth(device, logits_per_image.shape[0])
126
+
127
+ total_loss = (
128
+ F.cross_entropy(logits_per_image, labels) +
129
+ F.cross_entropy(logits_per_text, labels)
130
+ ) / 2
131
+ return total_loss
132
+
133
+ class PreferenceLoss(nn.Module):
134
+
135
+ def forward(self, logits_per_image, num_images, labels):
136
+
137
+ paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
138
+ paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
139
+
140
+ ce_loss = F.cross_entropy(paired_logits, labels)
141
+ return ce_loss
142
+
143
+ class HPSLoss(nn.Module):
144
+
145
+ def forward(self, text_logits, labels):
146
+
147
+ device = text_logits.device
148
+ text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
149
+ label_0, label_1 = labels.chunk(2, dim=-1)
150
+
151
+ index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
152
+ text_0_logits = text_0_logits[index, index]
153
+ text_1_logits = text_1_logits[index, index]
154
+ text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
155
+ text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
156
+ text_1_labels = text_0_labels + 1
157
+
158
+ text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
159
+ text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
160
+
161
+ text_loss = label_0 * text_0_loss + label_1 * text_1_loss
162
+
163
+ # absolute_example_weight = 1 / num_per_prompt
164
+ # denominator = absolute_example_weight.sum()
165
+ # weight_per_example = absolute_example_weight / denominator
166
+ # text_loss *= weight_per_example
167
+
168
+ text_loss = text_loss.sum()
169
+ return text_loss
170
+
171
+ class RankingLoss(nn.Module):
172
+
173
+ def forward(self, logits_per_image, num_images, labels, margin = 1.0):
174
+ paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
175
+ label_list = [label for label in labels.split(num_images.tolist())]
176
+ # ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
177
+
178
+ paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
179
+ padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
180
+
181
+ # regulized_logits = torch.log(torch.sigmoid(paired_logits))
182
+
183
+ diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
184
+ # diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
185
+ # diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
186
+ diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
187
+ mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
188
+
189
+ loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
190
+ return loss
191
+
192
+ class CoCaLoss(ClipLoss):
193
+ def __init__(
194
+ self,
195
+ caption_loss_weight,
196
+ clip_loss_weight,
197
+ pad_id=0, # pad_token for open_clip custom tokenizer
198
+ local_loss=False,
199
+ gather_with_grad=False,
200
+ cache_labels=False,
201
+ rank=0,
202
+ world_size=1,
203
+ use_horovod=False,
204
+ ):
205
+ super().__init__(
206
+ local_loss=local_loss,
207
+ gather_with_grad=gather_with_grad,
208
+ cache_labels=cache_labels,
209
+ rank=rank,
210
+ world_size=world_size,
211
+ use_horovod=use_horovod
212
+ )
213
+
214
+ self.clip_loss_weight = clip_loss_weight
215
+ self.caption_loss_weight = caption_loss_weight
216
+ self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
217
+
218
+ def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
219
+ clip_loss = super().forward(image_features, text_features, logit_scale)
220
+ clip_loss = self.clip_loss_weight * clip_loss
221
+
222
+ caption_loss = self.caption_loss(
223
+ logits.permute(0, 2, 1),
224
+ labels,
225
+ )
226
+ caption_loss = caption_loss * self.caption_loss_weight
227
+
228
+ if output_dict:
229
+ return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
230
+
231
+ return clip_loss, caption_loss
232
+
233
+
234
+ class DistillClipLoss(ClipLoss):
235
+
236
+ def dist_loss(self, teacher_logits, student_logits):
237
+ return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
238
+
239
+ def forward(
240
+ self,
241
+ image_features,
242
+ text_features,
243
+ logit_scale,
244
+ dist_image_features,
245
+ dist_text_features,
246
+ dist_logit_scale,
247
+ output_dict=False,
248
+ ):
249
+ logits_per_image, logits_per_text = \
250
+ self.get_logits(image_features, text_features, logit_scale)
251
+
252
+ dist_logits_per_image, dist_logits_per_text = \
253
+ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
254
+
255
+ labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
256
+
257
+ contrastive_loss = (
258
+ F.cross_entropy(logits_per_image, labels) +
259
+ F.cross_entropy(logits_per_text, labels)
260
+ ) / 2
261
+
262
+ distill_loss = (
263
+ self.dist_loss(dist_logits_per_image, logits_per_image) +
264
+ self.dist_loss(dist_logits_per_text, logits_per_text)
265
+ ) / 2
266
+
267
+ if output_dict:
268
+ return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
269
+
270
+ return contrastive_loss, distill_loss
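
A minimal single-process sketch of the losses above (`world_size=1`, so no feature gathering happens). Shapes and weights are illustrative; `logit_scale` is passed already exponentiated, matching what the models hand over as `self.logit_scale.exp()`.

```python
import torch
import torch.nn.functional as F
from diffsynth.extensions.ImageQualityMetric.open_clip.loss import ClipLoss, CoCaLoss

batch, dim, vocab, seq = 4, 512, 49408, 12
image_features = F.normalize(torch.randn(batch, dim), dim=-1)
text_features = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale = torch.tensor(1 / 0.07)          # i.e. logit_scale.exp() from the model

clip_loss = ClipLoss()                        # defaults: local_loss=False, world_size=1
print(float(clip_loss(image_features, text_features, logit_scale)))  # roughly log(batch) for random features

# CoCaLoss adds a captioning cross-entropy on the decoder logits.
coca_loss = CoCaLoss(caption_loss_weight=2.0, clip_loss_weight=1.0)
logits = torch.randn(batch, seq, vocab)           # text_decoder outputs
labels = torch.randint(1, vocab, (batch, seq))    # target token ids (0 is the pad/ignore index)
contrastive, caption = coca_loss(image_features, text_features, logits, labels, logit_scale)
print(float(contrastive), float(caption))
```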
diffsynth/extensions/ImageQualityMetric/open_clip/model.py ADDED
@@ -0,0 +1,461 @@
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from torch.utils.checkpoint import checkpoint
15
+
16
+ from .hf_model import HFTextEncoder
17
+ from .modified_resnet import ModifiedResNet
18
+ from .timm_model import TimmModel
19
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
20
+ from .utils import to_2tuple
21
+
22
+
23
+ @dataclass
24
+ class CLIPVisionCfg:
25
+ layers: Union[Tuple[int, int, int, int], int] = 12
26
+ width: int = 768
27
+ head_width: int = 64
28
+ mlp_ratio: float = 4.0
29
+ patch_size: int = 16
30
+ image_size: Union[Tuple[int, int], int] = 224
31
+ ls_init_value: Optional[float] = None # layer scale initial value
32
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
33
+ input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
34
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
35
+ attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
36
+ n_queries: int = 256 # n_queries for attentional pooler
37
+ attn_pooler_heads: int = 8 # n heads for attentional_pooling
38
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
39
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
40
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
41
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
42
+ timm_proj_bias: bool = False # enable bias final projection
43
+ timm_drop: float = 0. # head dropout
44
+ timm_drop_path: Optional[float] = None # backbone stochastic depth
45
+ output_tokens: bool = False
46
+
47
+
48
+ @dataclass
49
+ class CLIPTextCfg:
50
+ context_length: int = 77
51
+ vocab_size: int = 49408
52
+ width: int = 512
53
+ heads: int = 8
54
+ layers: int = 12
55
+ ls_init_value: Optional[float] = None # layer scale initial value
56
+ hf_model_name: str = None
57
+ hf_tokenizer_name: str = None
58
+ hf_model_pretrained: bool = True
59
+ proj: str = 'mlp'
60
+ pooler_type: str = 'mean_pooler'
61
+ embed_cls: bool = False
62
+ pad_id: int = 0
63
+ output_tokens: bool = False
64
+
65
+
66
+ def get_cast_dtype(precision: str):
67
+ cast_dtype = None
68
+ if precision == 'bf16':
69
+ cast_dtype = torch.bfloat16
70
+ elif precision == 'fp16':
71
+ cast_dtype = torch.float16
72
+ return cast_dtype
73
+
74
+
75
+ def _build_vision_tower(
76
+ embed_dim: int,
77
+ vision_cfg: CLIPVisionCfg,
78
+ quick_gelu: bool = False,
79
+ cast_dtype: Optional[torch.dtype] = None
80
+ ):
81
+ if isinstance(vision_cfg, dict):
82
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
83
+
84
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
85
+ # memory efficient in recent PyTorch releases (>= 1.10).
86
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
87
+ act_layer = QuickGELU if quick_gelu else nn.GELU
88
+
89
+ if vision_cfg.timm_model_name:
90
+ visual = TimmModel(
91
+ vision_cfg.timm_model_name,
92
+ pretrained=vision_cfg.timm_model_pretrained,
93
+ pool=vision_cfg.timm_pool,
94
+ proj=vision_cfg.timm_proj,
95
+ proj_bias=vision_cfg.timm_proj_bias,
96
+ drop=vision_cfg.timm_drop,
97
+ drop_path=vision_cfg.timm_drop_path,
98
+ embed_dim=embed_dim,
99
+ image_size=vision_cfg.image_size,
100
+ )
101
+ act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
102
+ elif isinstance(vision_cfg.layers, (tuple, list)):
103
+ vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
104
+ visual = ModifiedResNet(
105
+ layers=vision_cfg.layers,
106
+ output_dim=embed_dim,
107
+ heads=vision_heads,
108
+ image_size=vision_cfg.image_size,
109
+ width=vision_cfg.width,
110
+ )
111
+ else:
112
+ vision_heads = vision_cfg.width // vision_cfg.head_width
113
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
114
+ visual = VisionTransformer(
115
+ image_size=vision_cfg.image_size,
116
+ patch_size=vision_cfg.patch_size,
117
+ width=vision_cfg.width,
118
+ layers=vision_cfg.layers,
119
+ heads=vision_heads,
120
+ mlp_ratio=vision_cfg.mlp_ratio,
121
+ ls_init_value=vision_cfg.ls_init_value,
122
+ patch_dropout=vision_cfg.patch_dropout,
123
+ input_patchnorm=vision_cfg.input_patchnorm,
124
+ global_average_pool=vision_cfg.global_average_pool,
125
+ attentional_pool=vision_cfg.attentional_pool,
126
+ n_queries=vision_cfg.n_queries,
127
+ attn_pooler_heads=vision_cfg.attn_pooler_heads,
128
+ output_tokens=vision_cfg.output_tokens,
129
+ output_dim=embed_dim,
130
+ act_layer=act_layer,
131
+ norm_layer=norm_layer,
132
+ )
133
+
134
+ return visual
135
+
136
+
137
+ def _build_text_tower(
138
+ embed_dim: int,
139
+ text_cfg: CLIPTextCfg,
140
+ quick_gelu: bool = False,
141
+ cast_dtype: Optional[torch.dtype] = None,
142
+ ):
143
+ if isinstance(text_cfg, dict):
144
+ text_cfg = CLIPTextCfg(**text_cfg)
145
+
146
+ if text_cfg.hf_model_name:
147
+ text = HFTextEncoder(
148
+ text_cfg.hf_model_name,
149
+ output_dim=embed_dim,
150
+ proj=text_cfg.proj,
151
+ pooler_type=text_cfg.pooler_type,
152
+ pretrained=text_cfg.hf_model_pretrained,
153
+ output_tokens=text_cfg.output_tokens,
154
+ )
155
+ else:
156
+ act_layer = QuickGELU if quick_gelu else nn.GELU
157
+ norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
158
+
159
+ text = TextTransformer(
160
+ context_length=text_cfg.context_length,
161
+ vocab_size=text_cfg.vocab_size,
162
+ width=text_cfg.width,
163
+ heads=text_cfg.heads,
164
+ layers=text_cfg.layers,
165
+ ls_init_value=text_cfg.ls_init_value,
166
+ output_dim=embed_dim,
167
+ embed_cls=text_cfg.embed_cls,
168
+ output_tokens=text_cfg.output_tokens,
169
+ pad_id=text_cfg.pad_id,
170
+ act_layer=act_layer,
171
+ norm_layer=norm_layer,
172
+ )
173
+ return text
174
+
175
+
176
+ class CLIP(nn.Module):
+     output_dict: torch.jit.Final[bool]
+
+     def __init__(
+             self,
+             embed_dim: int,
+             vision_cfg: CLIPVisionCfg,
+             text_cfg: CLIPTextCfg,
+             quick_gelu: bool = False,
+             cast_dtype: Optional[torch.dtype] = None,
+             output_dict: bool = False,
+     ):
+         super().__init__()
+         self.output_dict = output_dict
+         self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+
+         text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+         self.transformer = text.transformer
+         self.vocab_size = text.vocab_size
+         self.token_embedding = text.token_embedding
+         self.positional_embedding = text.positional_embedding
+         self.ln_final = text.ln_final
+         self.text_projection = text.text_projection
+         self.register_buffer('attn_mask', text.attn_mask, persistent=False)
+
+         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+     def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+         self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+     def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+         locked_layers = []
+         locked_layers.append(self.token_embedding)
+         self.positional_embedding.requires_grad = False
+         if unlocked_layers > 0:
+             locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
+         else:
+             locked_layers.append(self.transformer)
+             locked_layers.append(self.ln_final)
+             self.text_projection.requires_grad = False
+
+         # freeze layers
+         for module in locked_layers:
+             for n, p in module.named_parameters():
+                 p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.visual.set_grad_checkpointing(enable)
+         self.transformer.grad_checkpointing = enable
+
+     def encode_image(self, image, normalize: bool = False):
+         features = self.visual(image)
+         return F.normalize(features, dim=-1) if normalize else features
+
+     def encode_text(self, text, normalize: bool = False):
+         cast_dtype = self.transformer.get_cast_dtype()
+
+         x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
+
+         x = x + self.positional_embedding.to(cast_dtype)
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = self.transformer(x, attn_mask=self.attn_mask)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+         x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
+         # take features from the eot embedding (eot_token is the highest number in each sequence)
+         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+         return F.normalize(x, dim=-1) if normalize else x
+
+     def forward(self, image, text):
+         image_features = self.encode_image(image, normalize=True)
+         text_features = self.encode_text(text, normalize=True)
+         if self.output_dict:
+             return {
+                 "image_features": image_features,
+                 "text_features": text_features,
+                 "logit_scale": self.logit_scale.exp()
+             }
+         return image_features, text_features, self.logit_scale.exp()
+
+
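# Example (illustrative sketch; all shapes and config sizes are assumptions): running
# the CLIP module end to end and turning the normalized features into similarity logits.
def _example_clip_forward():
    import torch
    model = CLIP(
        embed_dim=512,
        vision_cfg=dict(image_size=224, patch_size=32, width=768, layers=12),
        text_cfg=dict(context_length=77, vocab_size=49408, width=512, heads=8, layers=12),
    )
    images = torch.randn(2, 3, 224, 224)
    tokens = torch.randint(0, 49408, (2, 77))
    with torch.no_grad():
        image_features, text_features, logit_scale = model(images, tokens)
        logits_per_image = logit_scale * image_features @ text_features.t()
    return logits_per_image.shape  # expected: torch.Size([2, 2])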
+ class CustomTextCLIP(nn.Module):
+     output_dict: torch.jit.Final[bool]
+
+     def __init__(
+             self,
+             embed_dim: int,
+             vision_cfg: CLIPVisionCfg,
+             text_cfg: CLIPTextCfg,
+             quick_gelu: bool = False,
+             cast_dtype: Optional[torch.dtype] = None,
+             output_dict: bool = False,
+     ):
+         super().__init__()
+         self.output_dict = output_dict
+         self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
+         self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
+         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+     def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
+         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
+         self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
+
+     def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
+         self.text.lock(unlocked_layers, freeze_layer_norm)
+
+     @torch.jit.ignore
+     def set_grad_checkpointing(self, enable=True):
+         self.visual.set_grad_checkpointing(enable)
+         self.text.set_grad_checkpointing(enable)
+
+     def encode_image(self, image, normalize: bool = False):
+         features = self.visual(image)
+         return F.normalize(features, dim=-1) if normalize else features
+
+     def encode_text(self, text, normalize: bool = False):
+         features = self.text(text)
+         return F.normalize(features, dim=-1) if normalize else features
+
+     def forward(self, image, text):
+         image_features = self.encode_image(image, normalize=True)
+         text_features = self.encode_text(text, normalize=True)
+         if self.output_dict:
+             return {
+                 "image_features": image_features,
+                 "text_features": text_features,
+                 "logit_scale": self.logit_scale.exp()
+             }
+         return image_features, text_features, self.logit_scale.exp()
+
+
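# Example (illustrative sketch; config sizes are assumptions): CustomTextCLIP keeps the
# whole text tower as a single `.text` submodule instead of flattening its pieces onto
# the model, which is what lets an HF encoder (a text_cfg with `hf_model_name`) be used.
def _example_custom_text_clip():
    model = CustomTextCLIP(
        embed_dim=512,
        vision_cfg=dict(image_size=224, patch_size=32, width=768, layers=12),
        text_cfg=dict(context_length=77, vocab_size=49408, width=512, heads=8, layers=12),
    )
    return isinstance(model.text, TextTransformer)  # True for the non-HF config above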
+ def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
+     """Convert applicable model parameters to low-precision (bf16 or fp16)"""
+
+     def _convert_weights(l):
+         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+             l.weight.data = l.weight.data.to(dtype)
+             if l.bias is not None:
+                 l.bias.data = l.bias.data.to(dtype)
+
+         if isinstance(l, (nn.MultiheadAttention, Attention)):
+             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+                 tensor = getattr(l, attr)
+                 if tensor is not None:
+                     tensor.data = tensor.data.to(dtype)
+
+         for name in ["text_projection", "proj"]:
+             if hasattr(l, name):
+                 attr = getattr(l, name)
+                 if attr is not None:
+                     attr.data = attr.data.to(dtype)
+
+     model.apply(_convert_weights)
+
+
+ convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
+
+
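# Example (illustrative sketch; the model config is an assumption): casting a freshly
# built model to fp16 with the helper above. Only the module types handled inside
# _convert_weights are cast; LayerNorm parameters stay in fp32.
def _example_convert_to_fp16():
    import torch
    model = CLIP(
        embed_dim=512,
        vision_cfg=dict(image_size=224, patch_size=32, width=768, layers=12),
        text_cfg=dict(context_length=77, vocab_size=49408, width=512, heads=8, layers=12),
    )
    convert_weights_to_lp(model, dtype=torch.float16)
    return model.visual.conv1.weight.dtype  # expected: torch.float16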
+ # used to maintain checkpoint compatibility
+ def convert_to_custom_text_state_dict(state_dict: dict):
+     if 'text_projection' in state_dict:
+         # old format state_dict, move text tower -> .text
+         new_state_dict = {}
+         for k, v in state_dict.items():
+             if any(k.startswith(p) for p in (
+                 'text_projection',
+                 'positional_embedding',
+                 'token_embedding',
+                 'transformer',
+                 'ln_final',
+             )):
+                 k = 'text.' + k
+             new_state_dict[k] = v
+         return new_state_dict
+     return state_dict
+
+
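# Example (illustrative sketch with dummy tensors): legacy CLIP checkpoints keep the
# text tower at the top level; CustomTextCLIP expects those keys under a `text.` prefix,
# which is exactly the remapping the function above performs.
def _example_custom_text_state_dict():
    import torch
    old_sd = {
        "text_projection": torch.zeros(512, 512),
        "ln_final.weight": torch.ones(512),
        "visual.conv1.weight": torch.zeros(768, 3, 32, 32),
    }
    new_sd = convert_to_custom_text_state_dict(old_sd)
    return sorted(new_sd)  # ['text.ln_final.weight', 'text.text_projection', 'visual.conv1.weight']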
+ def build_model_from_openai_state_dict(
+         state_dict: dict,
+         quick_gelu=True,
+         cast_dtype=torch.float16,
+ ):
+     vit = "visual.proj" in state_dict
+
+     if vit:
+         vision_width = state_dict["visual.conv1.weight"].shape[0]
+         vision_layers = len(
+             [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+         vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+         grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+         image_size = vision_patch_size * grid_size
+     else:
+         counts: list = [
+             len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+         vision_layers = tuple(counts)
+         vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+         output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+         vision_patch_size = None
+         assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+         image_size = output_width * 32
+
+     embed_dim = state_dict["text_projection"].shape[1]
+     context_length = state_dict["positional_embedding"].shape[0]
+     vocab_size = state_dict["token_embedding.weight"].shape[0]
+     transformer_width = state_dict["ln_final.weight"].shape[0]
+     transformer_heads = transformer_width // 64
+     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+     vision_cfg = CLIPVisionCfg(
+         layers=vision_layers,
+         width=vision_width,
+         patch_size=vision_patch_size,
+         image_size=image_size,
+     )
+     text_cfg = CLIPTextCfg(
+         context_length=context_length,
+         vocab_size=vocab_size,
+         width=transformer_width,
+         heads=transformer_heads,
+         layers=transformer_layers,
+     )
+     model = CLIP(
+         embed_dim,
+         vision_cfg=vision_cfg,
+         text_cfg=text_cfg,
+         quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
+         cast_dtype=cast_dtype,
+     )
+
+     for key in ["input_resolution", "context_length", "vocab_size"]:
+         state_dict.pop(key, None)
+
+     convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
+     model.load_state_dict(state_dict)
+     return model.eval()
+
+
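# Example (illustrative sketch; the checkpoint path is a placeholder, not a file shipped
# with this repo): OpenAI releases are TorchScript archives, so one common flow is to
# load them with torch.jit.load and rebuild an open_clip CLIP from their state dict.
def _example_build_from_openai_checkpoint(checkpoint_path="ViT-B-32.pt"):
    import torch
    jit_model = torch.jit.load(checkpoint_path, map_location="cpu")
    model = build_model_from_openai_state_dict(jit_model.state_dict(), cast_dtype=torch.float16)
    return model  # eval-mode CLIP with QuickGELU activations, weights partially in fp16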
+ def trace_model(model, batch_size=256, device=torch.device('cpu')):
+     model.eval()
+     image_size = model.visual.image_size
+     example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
+     example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
+     model = torch.jit.trace_module(
+         model,
+         inputs=dict(
+             forward=(example_images, example_text),
+             encode_text=(example_text,),
+             encode_image=(example_images,)
+         ))
+     model.visual.image_size = image_size
+     return model
+
+
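# Example (illustrative sketch; assumes a CPU fp32 model): trace_model reads
# model.context_length, which the CLIP class above does not set, so this sketch
# attaches it explicitly (77 is an assumption matching CLIPTextCfg's default).
def _example_trace(model, context_length=77):
    import torch
    model.context_length = context_length
    traced = trace_model(model, batch_size=2, device=torch.device("cpu"))
    images = torch.ones(2, 3, traced.visual.image_size, traced.visual.image_size)
    tokens = torch.zeros(2, context_length, dtype=torch.int)
    with torch.no_grad():
        return traced(images, tokens)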
+ def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
+     # Rescale the grid of position embeddings when loading from state_dict
+     old_pos_embed = state_dict.get('visual.positional_embedding', None)
+     if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
+         return
+     grid_size = to_2tuple(model.visual.grid_size)
+     extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
+     new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+     if new_seq_len == old_pos_embed.shape[0]:
+         return
+
+     if extra_tokens:
+         pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+     else:
+         pos_emb_tok, pos_emb_img = None, old_pos_embed
+     old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+     logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
+     pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+     pos_emb_img = F.interpolate(
+         pos_emb_img,
+         size=grid_size,
+         mode=interpolation,
+         antialias=antialias,
+         align_corners=False,
+     )
+     pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+     if pos_emb_tok is not None:
+         new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+     else:
+         new_pos_embed = pos_emb_img
+     state_dict['visual.positional_embedding'] = new_pos_embed
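# Example (illustrative sketch; resolutions and config sizes are assumptions): adapting
# a checkpoint trained at 224 px to a model built at 256 px. resize_pos_embed rewrites
# only the 'visual.positional_embedding' entry, in place, before load_state_dict is called.
def _example_resize_pos_embed(state_dict):
    model = CLIP(
        embed_dim=512,
        vision_cfg=dict(image_size=256, patch_size=32, width=768, layers=12),  # 8x8 grid vs. 7x7 at 224
        text_cfg=dict(context_length=77, vocab_size=49408, width=512, heads=8, layers=12),
    )
    resize_pos_embed(state_dict, model, interpolation='bicubic', antialias=True)
    model.load_state_dict(state_dict, strict=False)
    return model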