Yixuan
commited on
Commit
·
d234621
0
Parent(s):
update readme
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- .gitattributes +42 -0
- .gitignore +3 -0
- LICENSE +21 -0
- README-copy.md +308 -0
- README.md +420 -0
- assets/images/logo-text-2.png +3 -0
- assets/images/logo-text.png +3 -0
- assets/images/logo.png +3 -0
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +778 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +91 -0
- diffsynth/controlnets/processors.py +62 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +41 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +137 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py +1 -0
- diffsynth/extensions/ImageQualityMetric/BLIP/blip.py +77 -0
- diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py +44 -0
- diffsynth/extensions/ImageQualityMetric/BLIP/med.py +947 -0
- diffsynth/extensions/ImageQualityMetric/BLIP/vit.py +301 -0
- diffsynth/extensions/ImageQualityMetric/__init__.py +148 -0
- diffsynth/extensions/ImageQualityMetric/aesthetic.py +148 -0
- diffsynth/extensions/ImageQualityMetric/clip.py +97 -0
- diffsynth/extensions/ImageQualityMetric/config.py +23 -0
- diffsynth/extensions/ImageQualityMetric/hps.py +118 -0
- diffsynth/extensions/ImageQualityMetric/imagereward.py +212 -0
- diffsynth/extensions/ImageQualityMetric/mps.py +129 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py +14 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py +458 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/constants.py +2 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/factory.py +433 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py +0 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py +45 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py +176 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/loss.py +270 -0
- diffsynth/extensions/ImageQualityMetric/open_clip/model.py +461 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.gitattributes
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/images/logo-text.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/images/logo.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/output_videos/output_moe_framepack_sliding.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
logo-text-2.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/images/logo-text-2.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ignore all checkpoint files
|
| 2 |
+
*.ckpt
|
| 3 |
+
*.ckpt.*
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Yixuan
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README-copy.md
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReCamMaster: Camera-Controlled Generative Rendering from A Single Video (ICCV'25 Oral, Best Paper Finalist)
|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<div align="center" style="margin-top: 0px; margin-bottom: 0px;">
|
| 5 |
+
<img src=https://github.com/user-attachments/assets/81ccf80e-f4b6-4a3d-b47a-e9c2ce14e34f width="30%"/>
|
| 6 |
+
</div>
|
| 7 |
+
|
| 8 |
+
### [<a href="https://arxiv.org/abs/2503.11647" target="_blank">arXiv</a>] [<a href="https://jianhongbai.github.io/ReCamMaster/" target="_blank">Project Page</a>] [<a href="https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset" target="_blank">Dataset</a>]
|
| 9 |
+
_**[Jianhong Bai<sup>1*</sup>](https://jianhongbai.github.io/), [Menghan Xia<sup>2†</sup>](https://menghanxia.github.io/), [Xiao Fu<sup>3</sup>](https://fuxiao0719.github.io/), [Xintao Wang<sup>2</sup>](https://xinntao.github.io/), [Lianrui Mu<sup>1</sup>](https://scholar.google.com/citations?user=dCik-2YAAAAJ&hl=en), [Jinwen Cao<sup>2</sup>](https://openreview.net/profile?id=~Jinwen_Cao1), <br>[Zuozhu Liu<sup>1</sup>](https://person.zju.edu.cn/en/lzz), [Haoji Hu<sup>1†</sup>](https://person.zju.edu.cn/en/huhaoji), [Xiang Bai<sup>4</sup>](https://scholar.google.com/citations?user=UeltiQ4AAAAJ&hl=en), [Pengfei Wan<sup>2</sup>](https://scholar.google.com/citations?user=P6MraaYAAAAJ&hl=en), [Di Zhang<sup>2</sup>](https://openreview.net/profile?id=~Di_ZHANG3)**_
|
| 10 |
+
<br>
|
| 11 |
+
(*Work done during an internship at KwaiVGI, Kuaishou Technology †corresponding authors)
|
| 12 |
+
|
| 13 |
+
<sup>1</sup>Zhejiang University, <sup>2</sup>Kuaishou Technology, <sup>3</sup>CUHK, <sup>4</sup>HUST.
|
| 14 |
+
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
**Important Note:** This open-source repository is intended to provide a reference implementation. Due to the difference in the underlying T2V model's performance, the open-source version may not achieve the same performance as the model in our paper. If you'd like to use the best version of ReCamMaster, please upload your video to [this link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog). Additionally, we are working on developing an online trial website. Please stay tuned to updates on the [Kling website](https://app.klingai.com/global/).
|
| 18 |
+
|
| 19 |
+
## 🔥 Updates
|
| 20 |
+
- __[2025.04.15]__: Please feel free to explore our related work, [SynCamMaster](https://github.com/KwaiVGI/SynCamMaster).
|
| 21 |
+
- __[2025.04.09]__: Release the [training and inference code](https://github.com/KwaiVGI/ReCamMaster?tab=readme-ov-file#%EF%B8%8F-code-recammaster--wan21-inference--training), [model checkpoint](https://huggingface.co/KwaiVGI/ReCamMaster-Wan2.1/blob/main/step20000.ckpt).
|
| 22 |
+
- __[2025.03.31]__: Release the [MultiCamVideo Dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset).
|
| 23 |
+
- __[2025.03.31]__: We have sent the inference results to the first 1000 trial users.
|
| 24 |
+
- __[2025.03.17]__: Release the [project page](https://jianhongbai.github.io/ReCamMaster/) and the [try out link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog).
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## 📖 Introduction
|
| 30 |
+
|
| 31 |
+
**TL;DR:** We propose ReCamMaster to re-capture in-the-wild videos with novel camera trajectories, achieved through our proposed simple-and-effective video conditioning scheme. We also release a multi-camera synchronized video [dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset) rendered with Unreal Engine 5. <br>
|
| 32 |
+
|
| 33 |
+
https://github.com/user-attachments/assets/52455e86-1adb-458d-bc37-4540a65a60d4
|
| 34 |
+
|
| 35 |
+
## 🚀 Trail: Try ReCamMaster with Your Own Videos
|
| 36 |
+
|
| 37 |
+
**Update:** We are actively processing the videos uploaded by users. So far, we have sent the inference results to the email addresses of the first **1500** testers. You should receive an email titled "Inference Results of ReCamMaster" from either [email protected] or [email protected]. Please also check your spam folder, and let us know if you haven't received the email after a long time. If you enjoyed the videos we created, please consider giving us a star 🌟.
|
| 38 |
+
|
| 39 |
+
**You can try out our ReCamMaster by uploading your own video to [this link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog), which will generate a video with camera movements along a new trajectory.** We will send the mp4 file generated by ReCamMaster to your inbox as soon as possible. For camera movement trajectories, we offer 10 basic camera trajectories as follows:
|
| 40 |
+
|
| 41 |
+
| Index | Basic Trajectory |
|
| 42 |
+
|-------------------|-----------------------------|
|
| 43 |
+
| 1 | Pan Right |
|
| 44 |
+
| 2 | Pan Left |
|
| 45 |
+
| 3 | Tilt Up |
|
| 46 |
+
| 4 | Tilt Down |
|
| 47 |
+
| 5 | Zoom In |
|
| 48 |
+
| 6 | Zoom Out |
|
| 49 |
+
| 7 | Translate Up (with rotation) |
|
| 50 |
+
| 8 | Translate Down (with rotation) |
|
| 51 |
+
| 9 | Arc Left (with rotation) |
|
| 52 |
+
| 10 | Arc Right (with rotation) |
|
| 53 |
+
|
| 54 |
+
If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [[email protected]](mailto:[email protected]). We can assist you with batch inference of our model.
|
| 55 |
+
|
| 56 |
+
## ⚙️ Code: ReCamMaster + Wan2.1 (Inference & Training)
|
| 57 |
+
The model utilized in our paper is an internally developed T2V model, not [Wan2.1](https://github.com/Wan-Video/Wan2.1). Due to company policy restrictions, we are unable to open-source the model used in the paper. Consequently, we migrated ReCamMaster to Wan2.1 to validate the effectiveness of our method. Due to differences in the underlying T2V model, you may not achieve the same results as demonstrated in the demo.
|
| 58 |
+
### Inference
|
| 59 |
+
Step 1: Set up the environment
|
| 60 |
+
|
| 61 |
+
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) requires Rust and Cargo to compile extensions. You can install them using the following command:
|
| 62 |
+
```shell
|
| 63 |
+
curl --proto '=https' --tlsv1.2 -sSf [https://sh.rustup.rs](https://sh.rustup.rs/) | sh
|
| 64 |
+
. "$HOME/.cargo/env"
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):
|
| 68 |
+
```shell
|
| 69 |
+
git clone https://github.com/KwaiVGI/ReCamMaster.git
|
| 70 |
+
cd ReCamMaster
|
| 71 |
+
pip install -e .
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Step 2: Download the pretrained checkpoints
|
| 75 |
+
1. Download the pre-trained Wan2.1 models
|
| 76 |
+
|
| 77 |
+
```shell
|
| 78 |
+
cd ReCamMaster
|
| 79 |
+
python download_wan2.1.py
|
| 80 |
+
```
|
| 81 |
+
2. Download the pre-trained ReCamMaster checkpoint
|
| 82 |
+
|
| 83 |
+
Please download from [huggingface](https://huggingface.co/KwaiVGI/ReCamMaster-Wan2.1/blob/main/step20000.ckpt) and place it in ```models/ReCamMaster/checkpoints```.
|
| 84 |
+
|
| 85 |
+
Step 3: Test the example videos
|
| 86 |
+
```shell
|
| 87 |
+
python inference_recammaster.py --cam_type 1
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
Step 4: Test your own videos
|
| 91 |
+
|
| 92 |
+
If you want to test your own videos, you need to prepare your test data following the structure of the ```example_test_data``` folder. This includes N mp4 videos, each with at least 81 frames, and a ```metadata.csv``` file that stores their paths and corresponding captions. You can refer to the [Prompt Extension section](https://github.com/Wan-Video/Wan2.1?tab=readme-ov-file#2-using-prompt-extension) in Wan2.1 for guidance on preparing video captions.
|
| 93 |
+
|
| 94 |
+
```shell
|
| 95 |
+
python inference_recammaster.py --cam_type 1 --dataset_path path/to/your/data
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
|
| 99 |
+
|
| 100 |
+
| cam_type | Trajectory |
|
| 101 |
+
|-------------------|-----------------------------|
|
| 102 |
+
| 1 | Pan Right |
|
| 103 |
+
| 2 | Pan Left |
|
| 104 |
+
| 3 | Tilt Up |
|
| 105 |
+
| 4 | Tilt Down |
|
| 106 |
+
| 5 | Zoom In |
|
| 107 |
+
| 6 | Zoom Out |
|
| 108 |
+
| 7 | Translate Up (with rotation) |
|
| 109 |
+
| 8 | Translate Down (with rotation) |
|
| 110 |
+
| 9 | Arc Left (with rotation) |
|
| 111 |
+
| 10 | Arc Right (with rotation) |
|
| 112 |
+
|
| 113 |
+
### Training
|
| 114 |
+
|
| 115 |
+
Step 1: Set up the environment
|
| 116 |
+
|
| 117 |
+
```shell
|
| 118 |
+
pip install lightning pandas websockets
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
Step 2: Prepare the training dataset
|
| 122 |
+
|
| 123 |
+
1. Download the [MultiCamVideo dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset).
|
| 124 |
+
|
| 125 |
+
2. Extract VAE features
|
| 126 |
+
|
| 127 |
+
```shell
|
| 128 |
+
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task data_process --dataset_path path/to/the/MultiCamVideo/Dataset --output_path ./models --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" --tiled --num_frames 81 --height 480 --width 832 --dataloader_num_workers 2
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
3. Generate Captions for Each Video
|
| 132 |
+
|
| 133 |
+
You can use video caption tools like [LLaVA](https://github.com/haotian-liu/LLaVA) to generate captions for each video and store them in the ```metadata.csv``` file.
|
| 134 |
+
|
| 135 |
+
Step 3: Training
|
| 136 |
+
```shell
|
| 137 |
+
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task train --dataset_path recam_train_data --output_path ./models/train --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" --steps_per_epoch 8000 --max_epochs 100 --learning_rate 1e-4 --accumulate_grad_batches 1 --use_gradient_checkpointing --dataloader_num_workers 4
|
| 138 |
+
```
|
| 139 |
+
We do not explore the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by adjusting hyper-parameters such as the learning rate and increasing the batch size.
|
| 140 |
+
|
| 141 |
+
Step 4: Test the model
|
| 142 |
+
|
| 143 |
+
```shell
|
| 144 |
+
python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## 📷 Dataset: MultiCamVideo Dataset
|
| 148 |
+
### 1. Dataset Introduction
|
| 149 |
+
|
| 150 |
+
**TL;DR:** The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories. The MultiCamVideo Dataset can be valuable in fields such as camera-controlled video generation, synchronized video production, and 3D/4D reconstruction. If you are looking for synchronized videos captured with stationary cameras, please explore our [SynCamVideo Dataset](https://github.com/KwaiVGI/SynCamMaster).
|
| 151 |
+
|
| 152 |
+
https://github.com/user-attachments/assets/6fa25bcf-1136-43be-8110-b527638874d4
|
| 153 |
+
|
| 154 |
+
The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories.
|
| 155 |
+
It consists of 13.6K different dynamic scenes, each captured by 10 cameras, resulting in a total of 136K videos. Each dynamic scene is composed of four elements: {3D environment, character, animation, camera}. Specifically, we use animation to drive the character,
|
| 156 |
+
and position the animated character within the 3D environment. Then, Time-synchronized cameras are set up to move along predefined trajectories to render the multi-camera video data.
|
| 157 |
+
<p align="center">
|
| 158 |
+
<img src="https://github.com/user-attachments/assets/107c9607-e99b-4493-b715-3e194fcb3933" alt="Example Image" width="70%">
|
| 159 |
+
</p>
|
| 160 |
+
|
| 161 |
+
**3D Environment:** We collect 37 high-quality 3D environments assets from [Fab](https://www.fab.com). To minimize the domain gap between rendered data and real-world videos, we primarily select visually realistic 3D scenes, while choosing a few stylized or surreal 3D scenes as a supplement. To ensure data diversity, the selected scenes cover a variety of indoor and outdoor settings, such as city streets, shopping malls, cafes, office rooms, and the countryside.
|
| 162 |
+
|
| 163 |
+
**Character:** We collect 66 different human 3D models as characters from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com).
|
| 164 |
+
|
| 165 |
+
**Animation:** We collect 93 different animations from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com), including common actions such as waving, dancing, and cheering. We use these animations to drive the collected characters and create diverse datasets through various combinations.
|
| 166 |
+
|
| 167 |
+
**Camera:** To ensure camera movements are diverse and closely resemble real-world distributions, we create a wide range of camera trajectories and parameters to cover various situations. To achieve this by designing rules to batch-generate random camera starting positions and movement trajectories:
|
| 168 |
+
|
| 169 |
+
1. Camera Starting Position.
|
| 170 |
+
|
| 171 |
+
We take the character's position as the center of a hemisphere with a radius of {3m, 5m, 7m, 10m} based on the size of the 3D scene and randomly sample within this range as the camera's starting point, ensuring the closest distance to the character is greater than 0.5m and the pitch angle is within 45 degrees.
|
| 172 |
+
|
| 173 |
+
2. Camera Trajectories.
|
| 174 |
+
|
| 175 |
+
- **Pan & Tilt**:
|
| 176 |
+
The camera rotation angles are randomly selected within the range, with pan angles ranging from 5 to 45 degrees and tilt angles ranging from 5 to 30 degrees, with directions randomly chosen left/right or up/down.
|
| 177 |
+
|
| 178 |
+
- **Basic Translation**:
|
| 179 |
+
The camera translates along the positive and negative directions of the xyz axes, with movement distances randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$.
|
| 180 |
+
|
| 181 |
+
- **Basic Arc Trajectory**:
|
| 182 |
+
The camera moves along an arc, with rotation angles randomly selected within the range of 15 to 75 degrees.
|
| 183 |
+
|
| 184 |
+
- **Random Trajectories**:
|
| 185 |
+
1-3 points are sampled in space, and the camera moves from the initial position through these points as the movement trajectory, with the total movement distance randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$. The polyline is smoothed to make the movement more natural.
|
| 186 |
+
|
| 187 |
+
- **Static Camera**:
|
| 188 |
+
The camera does not translate or rotate during shooting, maintaining a fixed position.
|
| 189 |
+
|
| 190 |
+
3. Camera Movement Speed.
|
| 191 |
+
|
| 192 |
+
To further enhance the diversity of trajectories, 50% of the training data uses constant-speed camera trajectories, while the other 50% uses variable-speed trajectories generated by nonlinear functions. Consider a camera trajectory with a total of $f$ frames, starting at location $L_{start}$ and ending at position $L_{end}$. The location at the $i$-th frame is given by:
|
| 193 |
+
```math
|
| 194 |
+
L_i = L_{start} + (L_{end} - L_{start}) \cdot \left( \frac{1 - \exp(-a \cdot i/f)}{1 - \exp(-a)} \right),
|
| 195 |
+
```
|
| 196 |
+
where $a$ is an adjustable parameter to control the trajectory speed. When $a > 0$, the trajectory starts fast and then slows down; when $a < 0$, the trajectory starts slow and then speeds up. The larger the absolute value of $a$, the more drastic the change.
|
| 197 |
+
|
| 198 |
+
4. Camera Parameters.
|
| 199 |
+
|
| 200 |
+
We chose four set of camera parameters: {focal=18mm, aperture=10}, {focal=24mm, aperture=5}, {focal=35mm, aperture=2.4} and {focal=50mm, aperture=2.4}.
|
| 201 |
+
|
| 202 |
+
### 2. Statistics and Configurations
|
| 203 |
+
|
| 204 |
+
Dataset Statistics:
|
| 205 |
+
|
| 206 |
+
| Number of Dynamic Scenes | Camera per Scene | Total Videos |
|
| 207 |
+
|:------------------------:|:----------------:|:------------:|
|
| 208 |
+
| 13,600 | 10 | 136,000 |
|
| 209 |
+
|
| 210 |
+
Video Configurations:
|
| 211 |
+
|
| 212 |
+
| Resolution | Frame Number | FPS |
|
| 213 |
+
|:-----------:|:------------:|:------------------------:|
|
| 214 |
+
| 1280x1280 | 81 | 15 |
|
| 215 |
+
|
| 216 |
+
Note: You can use 'center crop' to adjust the video's aspect ratio to fit your video generation model, such as 16:9, 9:16, 4:3, or 3:4.
|
| 217 |
+
|
| 218 |
+
Camera Configurations:
|
| 219 |
+
|
| 220 |
+
| Focal Length | Aperture | Sensor Height | Sensor Width |
|
| 221 |
+
|:-----------------------:|:------------------:|:-------------:|:------------:|
|
| 222 |
+
| 18mm, 24mm, 35mm, 50mm | 10.0, 5.0, 2.4 | 23.76mm | 23.76mm |
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
### 3. File Structure
|
| 227 |
+
```
|
| 228 |
+
MultiCamVideo-Dataset
|
| 229 |
+
├── train
|
| 230 |
+
│ ├── f18_aperture10
|
| 231 |
+
│ │ ├── scene1 # one dynamic scene
|
| 232 |
+
│ │ │ ├── videos
|
| 233 |
+
│ │ │ │ ├── cam01.mp4 # synchronized 81-frame videos at 1280x1280 resolution
|
| 234 |
+
│ │ │ │ ├── cam02.mp4
|
| 235 |
+
│ │ │ │ ├── ...
|
| 236 |
+
│ │ │ │ └── cam10.mp4
|
| 237 |
+
│ │ │ └── cameras
|
| 238 |
+
│ │ │ └── camera_extrinsics.json # 81-frame camera extrinsics of the 10 cameras
|
| 239 |
+
│ │ ├── ...
|
| 240 |
+
│ │ └── scene3400
|
| 241 |
+
│ ├── f24_aperture5
|
| 242 |
+
│ │ ├── scene1
|
| 243 |
+
│ │ ├── ...
|
| 244 |
+
│ │ └── scene3400
|
| 245 |
+
│ ├── f35_aperture2.4
|
| 246 |
+
│ │ ├── scene1
|
| 247 |
+
│ │ ├── ...
|
| 248 |
+
│ │ └── scene3400
|
| 249 |
+
│ └── f50_aperture2.4
|
| 250 |
+
│ ├── scene1
|
| 251 |
+
│ ├── ...
|
| 252 |
+
│ └── scene3400
|
| 253 |
+
└── val
|
| 254 |
+
└── 10basic_trajectories
|
| 255 |
+
├── videos
|
| 256 |
+
│ ├── cam01.mp4 # example videos corresponding to the validation cameras
|
| 257 |
+
│ ├── cam02.mp4
|
| 258 |
+
│ ├── ...
|
| 259 |
+
│ └── cam10.mp4
|
| 260 |
+
└── cameras
|
| 261 |
+
└── camera_extrinsics.json # 10 different trajectories for validation
|
| 262 |
+
```
|
| 263 |
+
|
| 264 |
+
### 3. Useful scripts
|
| 265 |
+
- Data Extraction
|
| 266 |
+
```bash
|
| 267 |
+
cat MultiCamVideo-Dataset.part* > MultiCamVideo-Dataset.tar.gz
|
| 268 |
+
tar -xzvf MultiCamVideo-Dataset.tar.gz
|
| 269 |
+
```
|
| 270 |
+
- Camera Visualization
|
| 271 |
+
```python
|
| 272 |
+
python vis_cam.py
|
| 273 |
+
```
|
| 274 |
+
|
| 275 |
+
The visualization script is modified from [CameraCtrl](https://github.com/hehao13/CameraCtrl/blob/main/tools/visualize_trajectory.py), thanks for their inspiring work.
|
| 276 |
+
|
| 277 |
+
<p align="center">
|
| 278 |
+
<img src="https://github.com/user-attachments/assets/f9cf342d-2fb3-40ef-a7be-edafb5775004" alt="Example Image" width="40%">
|
| 279 |
+
</p>
|
| 280 |
+
|
| 281 |
+
## 🤗 Awesome Related Works
|
| 282 |
+
Feel free to explore these outstanding related works, including but not limited to:
|
| 283 |
+
|
| 284 |
+
[GCD](https://gcd.cs.columbia.edu/): GCD synthesizes large-angle novel viewpoints of 4D dynamic scenes from a monocular video.
|
| 285 |
+
|
| 286 |
+
[ReCapture](https://generative-video-camera-controls.github.io/): a method for generating new videos with novel camera trajectories from a single user-provided video.
|
| 287 |
+
|
| 288 |
+
[Trajectory Attention](https://xizaoqu.github.io/trajattn/): Trajectory Attention facilitates various tasks like camera motion control on images and videos, and video editing.
|
| 289 |
+
|
| 290 |
+
[GS-DiT](https://wkbian.github.io/Projects/GS-DiT/): GS-DiT provides 4D video control for a single monocular video.
|
| 291 |
+
|
| 292 |
+
[Diffusion as Shader](https://igl-hkust.github.io/das/): a versatile video generation control model for various tasks.
|
| 293 |
+
|
| 294 |
+
[TrajectoryCrafter](https://trajectorycrafter.github.io/): TrajectoryCrafter achieves high-fidelity novel views generation from casually captured monocular video.
|
| 295 |
+
|
| 296 |
+
[GEN3C](https://research.nvidia.com/labs/toronto-ai/GEN3C/): a generative video model with precise Camera Control and temporal 3D Consistency.
|
| 297 |
+
|
| 298 |
+
## 🌟 Citation
|
| 299 |
+
|
| 300 |
+
Please leave us a star 🌟 and cite our paper if you find our work helpful.
|
| 301 |
+
```
|
| 302 |
+
@article{bai2025recammaster,
|
| 303 |
+
title={ReCamMaster: Camera-Controlled Generative Rendering from A Single Video},
|
| 304 |
+
author={Bai, Jianhong and Xia, Menghan and Fu, Xiao and Wang, Xintao and Mu, Lianrui and Cao, Jinwen and Liu, Zuozhu and Hu, Haoji and Bai, Xiang and Wan, Pengfei and others},
|
| 305 |
+
journal={arXiv preprint arXiv:2503.11647},
|
| 306 |
+
year={2025}
|
| 307 |
+
}
|
| 308 |
+
```
|
README.md
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- video-generation
|
| 5 |
+
- diffusion
|
| 6 |
+
- world-model
|
| 7 |
+
library_name: diffusion
|
| 8 |
+
model-index:
|
| 9 |
+
- name: Astra
|
| 10 |
+
results: []
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# Astra 🌏: General Interactive World Model with Autoregressive Denoising
|
| 14 |
+
|
| 15 |
+
<div align="center">
|
| 16 |
+
|
| 17 |
+
<div style="margin-top: 0; margin-bottom: -20px;">
|
| 18 |
+
<img src="./assets/images/logo-text-2.png" width="50%" />
|
| 19 |
+
</div>
|
| 20 |
+
|
| 21 |
+
<h3 style="margin-top: 0;">
|
| 22 |
+
📄
|
| 23 |
+
[<a href="https://arxiv.org/abs/2512.08931" target="_blank">arXiv</a>]
|
| 24 |
+
|
| 25 |
+
🏠
|
| 26 |
+
[<a href="https://eternalevan.github.io/Astra-project/" target="_blank">Project Page</a>]
|
| 27 |
+
|
| 28 |
+
🖥️
|
| 29 |
+
[<a href="https://huggingface.co/EvanEternal/Astra" target="_blank">Github</a>]
|
| 30 |
+
</h3>
|
| 31 |
+
|
| 32 |
+
</div>
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
<!-- <div align="center">
|
| 36 |
+
<a href="https://hunyuan.tencent.com/video/zh?tabIndex=0" target="_blank"><img src=https://img.shields.io/badge/Official%20Site-333399.svg?logo=homepage height=22px></a>
|
| 37 |
+
<a href=https://huggingface.co/tencent/HunyuanVideo-1.5 target="_blank"><img src=https://img.shields.io/badge/%F0%9F%A4%97%20Models-d96902.svg height=22px></a>
|
| 38 |
+
<a href=https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5 target="_blank"><img src= https://img.shields.io/badge/Page-bb8a2e.svg?logo=github height=22px></a>
|
| 39 |
+
<a href="https://arxiv.org/pdf/2511.18870" target="_blank"><img src=https://img.shields.io/badge/Report-b5212f.svg?logo=arxiv height=22px></a>
|
| 40 |
+
<a href=https://x.com/TencentHunyuan target="_blank"><img src=https://img.shields.io/badge/Hunyuan-black.svg?logo=x height=22px></a>
|
| 41 |
+
<a href="https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/assets/HunyuanVideo_1_5_Prompt_Handbook_EN.md" target="_blank"><img src=https://img.shields.io/badge/📚-PromptHandBook-blue.svg?logo=book height=22px></a> <br/>
|
| 42 |
+
<a href="./ComfyUI/README.md" target="_blank"><img src=https://img.shields.io/badge/ComfyUI-blue.svg?logo=book height=22px></a>
|
| 43 |
+
<a href="https://github.com/ModelTC/LightX2V" target="_blank"><img src=https://img.shields.io/badge/LightX2V-yellow.svg?logo=book height=22px></a>
|
| 44 |
+
<a href="https://tusi.cn/models/933574988890423836" target="_blank"><img src=https://img.shields.io/badge/吐司-purple.svg?logo=book height=22px></a>
|
| 45 |
+
<a href="https://tensor.art/models/933574988890423836" target="_blank"><img src=https://img.shields.io/badge/TensorArt-cyan.svg?logo=book height=22px></a>
|
| 46 |
+
|
| 47 |
+
</div> -->
|
| 48 |
+
|
| 49 |
+
<div align="center">
|
| 50 |
+
|
| 51 |
+
**[Yixuan Zhu<sup>1</sup>](https://eternalevan.github.io/), [Jiaqi Feng<sup>1</sup>](https://github.com/Aurora-edu/), [Wenzhao Zheng<sup>1 †</sup>](https://wzzheng.net), [Yuan Gao<sup>2</sup>](https://openreview.net/profile?id=~Yuan_Gao32), [Xin Tao<sup>2</sup>](https://www.xtao.website), [Pengfei Wan<sup>2</sup>](https://scholar.google.com/citations?user=P6MraaYAAAAJ&hl=en), [Jie Zhou <sup>1</sup>](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1), [Jiwen Lu<sup>1</sup>](https://ivg.au.tsinghua.edu.cn/Jiwen_Lu/)**
|
| 52 |
+
<!-- <br> -->
|
| 53 |
+
(† Project leader)
|
| 54 |
+
|
| 55 |
+
<sup>1</sup>Tsinghua University, <sup>2</sup>Kuaishou Technology.
|
| 56 |
+
</div>
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
## 📖 Introduction
|
| 60 |
+
|
| 61 |
+
**TL;DR:** Astra is an **interactive world model** that delivers realistic long-horizon video rollouts under a wide range of scenarios and action inputs.
|
| 62 |
+
|
| 63 |
+
**Astra** is an **interactive**, action-driven world model that predicts long-horizon future videos across diverse real-world scenarios. Built on an autoregressive diffusion transformer with temporal causal attention, Astra supports **streaming prediction** while preserving strong temporal coherence. Astra introduces **noise-augmented history memory** to stabilize long rollouts, an **action-aware adapter** for precise control signals, and a **mixture of action experts** to route heterogeneous action modalities. Through these key innovations, Astra delivers consistent, controllable, and high-fidelity video futures for applications such as autonomous driving, robot manipulation, and camera motion.
|
| 64 |
+
|
| 65 |
+
<!-- <div align="center">
|
| 66 |
+
<img src="./assets/images/pipeline.png" alt="Astra Pipeline" width="90%">
|
| 67 |
+
</div> -->
|
| 68 |
+
|
| 69 |
+
## Gallery
|
| 70 |
+
|
| 71 |
+
### Astra+Wan2.1
|
| 72 |
+
|
| 73 |
+
<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
|
| 74 |
+
<tr>
|
| 75 |
+
<td>
|
| 76 |
+
<video src="https://github.com/user-attachments/assets/715a5b66-3966-4923-aa00-02315fb07761"
|
| 77 |
+
style="width:100%; object-fit:cover;"
|
| 78 |
+
controls autoplay loop muted></video>
|
| 79 |
+
</td>
|
| 80 |
+
<td>
|
| 81 |
+
<video src="https://github.com/user-attachments/assets/c7156c4d-d51d-493c-995e-5113c3d49abb"
|
| 82 |
+
style="width:100%; object-fit:cover;"
|
| 83 |
+
controls autoplay loop muted></video>
|
| 84 |
+
</td>
|
| 85 |
+
</tr>
|
| 86 |
+
|
| 87 |
+
<tr>
|
| 88 |
+
<td>
|
| 89 |
+
<video src="https://github.com/user-attachments/assets/d899d704-c706-4e64-a24b-eea174d2173d"
|
| 90 |
+
style="width:100%; height:180px; object-fit:cover;"
|
| 91 |
+
controls autoplay loop muted></video>
|
| 92 |
+
</td>
|
| 93 |
+
<td>
|
| 94 |
+
<video src="https://github.com/user-attachments/assets/c1d8beb2-3102-468a-8019-624d89fba125"
|
| 95 |
+
style="width:100%; height:180px; object-fit:cover;"
|
| 96 |
+
controls autoplay loop muted></video>
|
| 97 |
+
</td>
|
| 98 |
+
</tr>
|
| 99 |
+
</table>
|
| 100 |
+
|
| 101 |
+
<!-- ## 🚀 Trail: Try ReCamMaster with Your Own Videos
|
| 102 |
+
|
| 103 |
+
**Update:** We are actively processing the videos uploaded by users. So far, we have sent the inference results to the email addresses of the first **1180** testers. You should receive an email titled "Inference Results of ReCamMaster" from either [email protected] or [email protected]. Please also check your spam folder, and let us know if you haven't received the email after a long time. If you enjoyed the videos we created, please consider giving us a star 🌟.
|
| 104 |
+
|
| 105 |
+
**You can try out our ReCamMaster by uploading your own video to [this link](https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog), which will generate a video with camera movements along a new trajectory.** We will send the mp4 file generated by ReCamMaster to your inbox as soon as possible. For camera movement trajectories, we offer 10 basic camera trajectories as follows:
|
| 106 |
+
|
| 107 |
+
| Index | Basic Trajectory |
|
| 108 |
+
|-------------------|-----------------------------|
|
| 109 |
+
| 1 | Pan Right |
|
| 110 |
+
| 2 | Pan Left |
|
| 111 |
+
| 3 | Tilt Up |
|
| 112 |
+
| 4 | Tilt Down |
|
| 113 |
+
| 5 | Zoom In |
|
| 114 |
+
| 6 | Zoom Out |
|
| 115 |
+
| 7 | Translate Up (with rotation) |
|
| 116 |
+
| 8 | Translate Down (with rotation) |
|
| 117 |
+
| 9 | Arc Left (with rotation) |
|
| 118 |
+
| 10 | Arc Right (with rotation) |
|
| 119 |
+
|
| 120 |
+
If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [[email protected]](mailto:[email protected]). We can assist you with batch inference of our model. -->
|
| 121 |
+
|
| 122 |
+
## 🔥 Updates
|
| 123 |
+
- __[2025.11.17]__: Release the [project page](https://eternalevan.github.io/Astra-project/).
|
| 124 |
+
- __[2025.12.09]__: Release the inference code, model checkpoint.
|
| 125 |
+
|
| 126 |
+
## 🎯 TODO List
|
| 127 |
+
|
| 128 |
+
- [ ] **Release full inference pipelines** for additional scenarios:
|
| 129 |
+
- [ ] 🚗 Autonomous driving
|
| 130 |
+
- [ ] 🤖 Robotic manipulation
|
| 131 |
+
- [ ] 🛸 Drone navigation / exploration
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
- [ ] **Open-source training scripts**:
|
| 135 |
+
- [ ] ⬆️ Action-conditioned autoregressive denoising training
|
| 136 |
+
- [ ] 🔄 Multi-scenario joint training pipeline
|
| 137 |
+
|
| 138 |
+
- [ ] **Release dataset preprocessing tools**
|
| 139 |
+
|
| 140 |
+
- [ ] **Provide unified evaluation toolkit**
|
| 141 |
+
|
| 142 |
+
## ⚙️ Run Astra (Inference)
|
| 143 |
+
Astra is built upon [Wan2.1-1.3B](https://github.com/Wan-Video/Wan2.1), a diffusion-based video generation model. We provide inference scripts to help you quickly generate videos from images and action inputs. Follow the steps below:
|
| 144 |
+
|
| 145 |
+
### Inference
|
| 146 |
+
Step 1: Set up the environment
|
| 147 |
+
|
| 148 |
+
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) requires Rust and Cargo to compile extensions. You can install them using the following command:
|
| 149 |
+
```shell
|
| 150 |
+
curl --proto '=https' --tlsv1.2 -sSf [https://sh.rustup.rs](https://sh.rustup.rs/) | sh
|
| 151 |
+
. "$HOME/.cargo/env"
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
Install [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio):
|
| 155 |
+
```shell
|
| 156 |
+
git clone https://github.com/EternalEvan/Astra.git
|
| 157 |
+
cd Astra
|
| 158 |
+
pip install -e .
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
Step 2: Download the pretrained checkpoints
|
| 162 |
+
1. Download the pre-trained Wan2.1 models
|
| 163 |
+
|
| 164 |
+
```shell
|
| 165 |
+
cd script
|
| 166 |
+
python download_wan2.1.py
|
| 167 |
+
```
|
| 168 |
+
2. Download the pre-trained Astra checkpoint
|
| 169 |
+
|
| 170 |
+
Please download from [huggingface](https://huggingface.co/EvanEternal/Astra/blob/main/models/Astra/checkpoints/diffusion_pytorch_model.ckpt) and place it in ```models/Astra/checkpoints```.
|
| 171 |
+
|
| 172 |
+
Step 3: Test the example image
|
| 173 |
+
```shell
|
| 174 |
+
python infer_demo.py \
|
| 175 |
+
--dit_path ../models/Astra/checkpoints/diffusion_pytorch_model.ckpt \
|
| 176 |
+
--wan_model_path ../models/Wan-AI/Wan2.1-T2V-1.3B \
|
| 177 |
+
--condition_image ../examples/condition_images/garden_1.png \
|
| 178 |
+
--cam_type 4 \
|
| 179 |
+
--prompt "A sunlit European street lined with historic buildings and vibrant greenery creates a warm, charming, and inviting atmosphere. The scene shows a picturesque open square paved with red bricks, surrounded by classic narrow townhouses featuring tall windows, gabled roofs, and dark-painted facades. On the right side, a lush arrangement of potted plants and blooming flowers adds rich color and texture to the foreground. A vintage-style streetlamp stands prominently near the center-right, contributing to the timeless character of the street. Mature trees frame the background, their leaves glowing in the warm afternoon sunlight. Bicycles are visible along the edges of the buildings, reinforcing the urban yet leisurely feel. The sky is bright blue with scattered clouds, and soft sun flares enter the frame from the left, enhancing the scene’s inviting, peaceful mood." \
|
| 180 |
+
--output_path ../examples/output_videos/output_moe_framepack_sliding.mp4 \
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
+
Step 4: Test your own images
|
| 184 |
+
|
| 185 |
+
To test with your own custom images, you need to prepare the target images and their corresponding text prompts. **We recommend that the size of the input images is close to 832×480 (width × height)**, which is consistent with the resolution of the generated video and can help achieve better video generation effects. For prompts generation, you can refer to the [Prompt Extension section](https://github.com/Wan-Video/Wan2.1?tab=readme-ov-file#2-using-prompt-extension) in Wan2.1 for guidance on crafting the captions.
|
| 186 |
+
|
| 187 |
+
```shell
|
| 188 |
+
python infer_demo.py \
|
| 189 |
+
--dit_path path/to/your/dit_ckpt \
|
| 190 |
+
--wan_model_path path/to/your/Wan2.1-T2V-1.3B \
|
| 191 |
+
--condition_image path/to/your/image \
|
| 192 |
+
--cam_type your_cam_type \
|
| 193 |
+
--prompt your_prompt \
|
| 194 |
+
--output_path path/to/your/output_video
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
|
| 198 |
+
|
| 199 |
+
| cam_type | Trajectory |
|
| 200 |
+
|:-----------:|-----------------------------|
|
| 201 |
+
| 1 | Move Forward (Straight) |
|
| 202 |
+
| 2 | Rotate Left In Place |
|
| 203 |
+
| 3 | Rotate Right In Place |
|
| 204 |
+
| 4 | Move Forward + Rotate Left |
|
| 205 |
+
| 5 | Move Forward + Rotate Right |
|
| 206 |
+
| 6 | S-shaped Trajectory |
|
| 207 |
+
| 7 | Rotate Left → Rotate Right |
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
## Future Work 🚀
|
| 211 |
+
|
| 212 |
+
Looking ahead, we plan to further enhance Astra in several directions:
|
| 213 |
+
|
| 214 |
+
- **Training with Wan-2.2:** Upgrade our model using the latest Wan-2.2 framework to release a more powerful version with improved generation quality.
|
| 215 |
+
- **3D Spatial Consistency:** Explore techniques to better preserve 3D consistency across frames for more coherent and realistic video generation.
|
| 216 |
+
- **Long-Term Memory:** Incorporate mechanisms for long-term memory, enabling the model to handle extended temporal dependencies and complex action sequences.
|
| 217 |
+
|
| 218 |
+
These directions aim to push Astra towards more robust and interactive video world modeling.
|
| 219 |
+
|
| 220 |
+
<!-- ### Training
|
| 221 |
+
|
| 222 |
+
Step 1: Set up the environment
|
| 223 |
+
|
| 224 |
+
```shell
|
| 225 |
+
pip install lightning pandas websockets
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
Step 2: Prepare the training dataset
|
| 229 |
+
|
| 230 |
+
1. Download the [MultiCamVideo dataset](https://huggingface.co/datasets/KwaiVGI/MultiCamVideo-Dataset).
|
| 231 |
+
|
| 232 |
+
2. Extract VAE features
|
| 233 |
+
|
| 234 |
+
```shell
|
| 235 |
+
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task data_process --dataset_path path/to/the/MultiCamVideo/Dataset --output_path ./models --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" --tiled --num_frames 81 --height 480 --width 832 --dataloader_num_workers 2
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
3. Generate Captions for Each Video
|
| 239 |
+
|
| 240 |
+
You can use video caption tools like [LLaVA](https://github.com/haotian-liu/LLaVA) to generate captions for each video and store them in the ```metadata.csv``` file.
|
| 241 |
+
|
| 242 |
+
Step 3: Training
|
| 243 |
+
```shell
|
| 244 |
+
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task train --dataset_path recam_train_data --output_path ./models/train --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" --steps_per_epoch 8000 --max_epochs 100 --learning_rate 1e-4 --accumulate_grad_batches 1 --use_gradient_checkpointing --dataloader_num_workers 4
|
| 245 |
+
```
|
| 246 |
+
We do not explore the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by adjusting hyper-parameters such as the learning rate and increasing the batch size.
|
| 247 |
+
|
| 248 |
+
Step 4: Test the model
|
| 249 |
+
|
| 250 |
+
```shell
|
| 251 |
+
python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
|
| 252 |
+
``` -->
|
| 253 |
+
|
| 254 |
+
<!-- ## 📷 Dataset: MultiCamVideo Dataset
|
| 255 |
+
### 1. Dataset Introduction
|
| 256 |
+
|
| 257 |
+
**TL;DR:** The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories. The MultiCamVideo Dataset can be valuable in fields such as camera-controlled video generation, synchronized video production, and 3D/4D reconstruction.
|
| 258 |
+
|
| 259 |
+
https://github.com/user-attachments/assets/6fa25bcf-1136-43be-8110-b527638874d4
|
| 260 |
+
|
| 261 |
+
The MultiCamVideo Dataset is a multi-camera synchronized video dataset rendered using Unreal Engine 5. It includes synchronized multi-camera videos and their corresponding camera trajectories.
|
| 262 |
+
It consists of 13.6K different dynamic scenes, each captured by 10 cameras, resulting in a total of 136K videos. Each dynamic scene is composed of four elements: {3D environment, character, animation, camera}. Specifically, we use animation to drive the character,
|
| 263 |
+
and position the animated character within the 3D environment. Then, Time-synchronized cameras are set up to move along predefined trajectories to render the multi-camera video data.
|
| 264 |
+
<p align="center">
|
| 265 |
+
<img src="https://github.com/user-attachments/assets/107c9607-e99b-4493-b715-3e194fcb3933" alt="Example Image" width="70%">
|
| 266 |
+
</p>
|
| 267 |
+
|
| 268 |
+
**3D Environment:** We collect 37 high-quality 3D environments assets from [Fab](https://www.fab.com). To minimize the domain gap between rendered data and real-world videos, we primarily select visually realistic 3D scenes, while choosing a few stylized or surreal 3D scenes as a supplement. To ensure data diversity, the selected scenes cover a variety of indoor and outdoor settings, such as city streets, shopping malls, cafes, office rooms, and the countryside.
|
| 269 |
+
|
| 270 |
+
**Character:** We collect 66 different human 3D models as characters from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com).
|
| 271 |
+
|
| 272 |
+
**Animation:** We collect 93 different animations from [Fab](https://www.fab.com) and [Mixamo](https://www.mixamo.com), including common actions such as waving, dancing, and cheering. We use these animations to drive the collected characters and create diverse datasets through various combinations.
|
| 273 |
+
|
| 274 |
+
**Camera:** To ensure camera movements are diverse and closely resemble real-world distributions, we create a wide range of camera trajectories and parameters to cover various situations. To achieve this by designing rules to batch-generate random camera starting positions and movement trajectories:
|
| 275 |
+
|
| 276 |
+
1. Camera Starting Position.
|
| 277 |
+
|
| 278 |
+
We take the character's position as the center of a hemisphere with a radius of {3m, 5m, 7m, 10m} based on the size of the 3D scene and randomly sample within this range as the camera's starting point, ensuring the closest distance to the character is greater than 0.5m and the pitch angle is within 45 degrees.
|
| 279 |
+
|
| 280 |
+
2. Camera Trajectories.
|
| 281 |
+
|
| 282 |
+
- **Pan & Tilt**:
|
| 283 |
+
The camera rotation angles are randomly selected within the range, with pan angles ranging from 5 to 45 degrees and tilt angles ranging from 5 to 30 degrees, with directions randomly chosen left/right or up/down.
|
| 284 |
+
|
| 285 |
+
- **Basic Translation**:
|
| 286 |
+
The camera translates along the positive and negative directions of the xyz axes, with movement distances randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$.
|
| 287 |
+
|
| 288 |
+
- **Basic Arc Trajectory**:
|
| 289 |
+
The camera moves along an arc, with rotation angles randomly selected within the range of 15 to 75 degrees.
|
| 290 |
+
|
| 291 |
+
- **Random Trajectories**:
|
| 292 |
+
1-3 points are sampled in space, and the camera moves from the initial position through these points as the movement trajectory, with the total movement distance randomly selected within the range of $[\frac{1}{4}, 1] \times \text{distance2character}$. The polyline is smoothed to make the movement more natural.
|
| 293 |
+
|
| 294 |
+
- **Static Camera**:
|
| 295 |
+
The camera does not translate or rotate during shooting, maintaining a fixed position.
|
| 296 |
+
|
| 297 |
+
3. Camera Movement Speed.
|
| 298 |
+
|
| 299 |
+
To further enhance the diversity of trajectories, 50% of the training data uses constant-speed camera trajectories, while the other 50% uses variable-speed trajectories generated by nonlinear functions. Consider a camera trajectory with a total of $f$ frames, starting at location $L_{start}$ and ending at position $L_{end}$. The location at the $i$-th frame is given by:
|
| 300 |
+
```math
|
| 301 |
+
L_i = L_{start} + (L_{end} - L_{start}) \cdot \left( \frac{1 - \exp(-a \cdot i/f)}{1 - \exp(-a)} \right),
|
| 302 |
+
```
|
| 303 |
+
where $a$ is an adjustable parameter to control the trajectory speed. When $a > 0$, the trajectory starts fast and then slows down; when $a < 0$, the trajectory starts slow and then speeds up. The larger the absolute value of $a$, the more drastic the change.
|
| 304 |
+
|
| 305 |
+
4. Camera Parameters.
|
| 306 |
+
|
| 307 |
+
We chose four set of camera parameters: {focal=18mm, aperture=10}, {focal=24mm, aperture=5}, {focal=35mm, aperture=2.4} and {focal=50mm, aperture=2.4}.
|
| 308 |
+
|
| 309 |
+
### 2. Statistics and Configurations
|
| 310 |
+
|
| 311 |
+
Dataset Statistics:
|
| 312 |
+
|
| 313 |
+
| Number of Dynamic Scenes | Camera per Scene | Total Videos |
|
| 314 |
+
|:------------------------:|:----------------:|:------------:|
|
| 315 |
+
| 13,600 | 10 | 136,000 |
|
| 316 |
+
|
| 317 |
+
Video Configurations:
|
| 318 |
+
|
| 319 |
+
| Resolution | Frame Number | FPS |
|
| 320 |
+
|:-----------:|:------------:|:------------------------:|
|
| 321 |
+
| 1280x1280 | 81 | 15 |
|
| 322 |
+
|
| 323 |
+
Note: You can use 'center crop' to adjust the video's aspect ratio to fit your video generation model, such as 16:9, 9:16, 4:3, or 3:4.
|
| 324 |
+
|
| 325 |
+
Camera Configurations:
|
| 326 |
+
|
| 327 |
+
| Focal Length | Aperture | Sensor Height | Sensor Width |
|
| 328 |
+
|:-----------------------:|:------------------:|:-------------:|:------------:|
|
| 329 |
+
| 18mm, 24mm, 35mm, 50mm | 10.0, 5.0, 2.4 | 23.76mm | 23.76mm |
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
### 3. File Structure
|
| 334 |
+
```
|
| 335 |
+
MultiCamVideo-Dataset
|
| 336 |
+
├── train
|
| 337 |
+
│ ├── f18_aperture10
|
| 338 |
+
│ │ ├── scene1 # one dynamic scene
|
| 339 |
+
│ │ │ ├── videos
|
| 340 |
+
│ │ │ │ ├── cam01.mp4 # synchronized 81-frame videos at 1280x1280 resolution
|
| 341 |
+
│ │ │ │ ├── cam02.mp4
|
| 342 |
+
│ │ │ │ ├── ...
|
| 343 |
+
│ │ │ │ └── cam10.mp4
|
| 344 |
+
│ │ │ └── cameras
|
| 345 |
+
│ │ │ └── camera_extrinsics.json # 81-frame camera extrinsics of the 10 cameras
|
| 346 |
+
│ │ ├── ...
|
| 347 |
+
│ │ └── scene3400
|
| 348 |
+
│ ├── f24_aperture5
|
| 349 |
+
│ │ ├── scene1
|
| 350 |
+
│ │ ├── ...
|
| 351 |
+
│ │ └── scene3400
|
| 352 |
+
│ ├── f35_aperture2.4
|
| 353 |
+
│ │ ├── scene1
|
| 354 |
+
│ │ ├── ...
|
| 355 |
+
│ │ └── scene3400
|
| 356 |
+
│ └── f50_aperture2.4
|
| 357 |
+
│ ├── scene1
|
| 358 |
+
│ ├── ...
|
| 359 |
+
│ └── scene3400
|
| 360 |
+
└── val
|
| 361 |
+
└── 10basic_trajectories
|
| 362 |
+
├── videos
|
| 363 |
+
│ ├── cam01.mp4 # example videos corresponding to the validation cameras
|
| 364 |
+
│ ├── cam02.mp4
|
| 365 |
+
│ ├── ...
|
| 366 |
+
│ └── cam10.mp4
|
| 367 |
+
└── cameras
|
| 368 |
+
└── camera_extrinsics.json # 10 different trajectories for validation
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
### 3. Useful scripts
|
| 372 |
+
- Data Extraction
|
| 373 |
+
```bash
|
| 374 |
+
cat MultiCamVideo-Dataset.part* > MultiCamVideo-Dataset.tar.gz
|
| 375 |
+
tar -xzvf MultiCamVideo-Dataset.tar.gz
|
| 376 |
+
```
|
| 377 |
+
- Camera Visualization
|
| 378 |
+
```python
|
| 379 |
+
python vis_cam.py
|
| 380 |
+
```
|
| 381 |
+
|
| 382 |
+
The visualization script is modified from [CameraCtrl](https://github.com/hehao13/CameraCtrl/blob/main/tools/visualize_trajectory.py), thanks for their inspiring work.
|
| 383 |
+
|
| 384 |
+
<p align="center">
|
| 385 |
+
<img src="https://github.com/user-attachments/assets/f9cf342d-2fb3-40ef-a7be-edafb5775004" alt="Example Image" width="40%">
|
| 386 |
+
</p> -->
|
| 387 |
+
|
| 388 |
+
## 🤗 Awesome Related Works
|
| 389 |
+
Feel free to explore these outstanding related works, including but not limited to:
|
| 390 |
+
|
| 391 |
+
[ReCamMaster](https://github.com/KlingTeam/ReCamMaster): ReCamMaster re-captures in-the-wild videos with novel camera trajectories.
|
| 392 |
+
|
| 393 |
+
[GCD](https://gcd.cs.columbia.edu/): GCD synthesizes large-angle novel viewpoints of 4D dynamic scenes from a monocular video.
|
| 394 |
+
|
| 395 |
+
[ReCapture](https://generative-video-camera-controls.github.io/): a method for generating new videos with novel camera trajectories from a single user-provided video.
|
| 396 |
+
|
| 397 |
+
[Trajectory Attention](https://xizaoqu.github.io/trajattn/): Trajectory Attention facilitates various tasks like camera motion control on images and videos, and video editing.
|
| 398 |
+
|
| 399 |
+
[GS-DiT](https://wkbian.github.io/Projects/GS-DiT/): GS-DiT provides 4D video control for a single monocular video.
|
| 400 |
+
|
| 401 |
+
[Diffusion as Shader](https://igl-hkust.github.io/das/): a versatile video generation control model for various tasks.
|
| 402 |
+
|
| 403 |
+
[TrajectoryCrafter](https://trajectorycrafter.github.io/): TrajectoryCrafter achieves high-fidelity novel views generation from casually captured monocular video.
|
| 404 |
+
|
| 405 |
+
[GEN3C](https://research.nvidia.com/labs/toronto-ai/GEN3C/): a generative video model with precise Camera Control and temporal 3D Consistency.
|
| 406 |
+
|
| 407 |
+
## 🌟 Citation
|
| 408 |
+
|
| 409 |
+
Please leave us a star 🌟 and cite our paper if you find our work helpful.
|
| 410 |
+
```
|
| 411 |
+
@misc{zhu2025astrageneralinteractiveworld,
|
| 412 |
+
title={Astra: General Interactive World Model with Autoregressive Denoising},
|
| 413 |
+
author={Yixuan Zhu and Jiaqi Feng and Wenzhao Zheng and Yuan Gao and Xin Tao and Pengfei Wan and Jie Zhou and Jiwen Lu},
|
| 414 |
+
year={2025},
|
| 415 |
+
eprint={2512.08931},
|
| 416 |
+
archivePrefix={arXiv},
|
| 417 |
+
primaryClass={cs.CV},
|
| 418 |
+
url={https://arxiv.org/abs/2512.08931},
|
| 419 |
+
}
|
| 420 |
+
```
|
assets/images/logo-text-2.png
ADDED
|
Git LFS Details
|
assets/images/logo-text.png
ADDED
|
Git LFS Details
|
assets/images/logo.png
ADDED
|
Git LFS Details
|
diffsynth/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .data import *
|
| 2 |
+
from .models import *
|
| 3 |
+
from .prompters import *
|
| 4 |
+
from .schedulers import *
|
| 5 |
+
from .pipelines import *
|
| 6 |
+
from .controlnets import *
|
diffsynth/configs/__init__.py
ADDED
|
File without changes
|
diffsynth/configs/model_config.py
ADDED
|
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import Literal, TypeAlias
|
| 2 |
+
|
| 3 |
+
from ..models.sd_text_encoder import SDTextEncoder
|
| 4 |
+
from ..models.sd_unet import SDUNet
|
| 5 |
+
from ..models.sd_vae_encoder import SDVAEEncoder
|
| 6 |
+
from ..models.sd_vae_decoder import SDVAEDecoder
|
| 7 |
+
|
| 8 |
+
from ..models.sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
| 9 |
+
from ..models.sdxl_unet import SDXLUNet
|
| 10 |
+
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
| 11 |
+
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
| 12 |
+
|
| 13 |
+
from ..models.sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
| 14 |
+
from ..models.sd3_dit import SD3DiT
|
| 15 |
+
from ..models.sd3_vae_decoder import SD3VAEDecoder
|
| 16 |
+
from ..models.sd3_vae_encoder import SD3VAEEncoder
|
| 17 |
+
|
| 18 |
+
from ..models.sd_controlnet import SDControlNet
|
| 19 |
+
from ..models.sdxl_controlnet import SDXLControlNetUnion
|
| 20 |
+
|
| 21 |
+
from ..models.sd_motion import SDMotionModel
|
| 22 |
+
from ..models.sdxl_motion import SDXLMotionModel
|
| 23 |
+
|
| 24 |
+
from ..models.svd_image_encoder import SVDImageEncoder
|
| 25 |
+
from ..models.svd_unet import SVDUNet
|
| 26 |
+
from ..models.svd_vae_decoder import SVDVAEDecoder
|
| 27 |
+
from ..models.svd_vae_encoder import SVDVAEEncoder
|
| 28 |
+
|
| 29 |
+
from ..models.sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
| 30 |
+
from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
| 31 |
+
|
| 32 |
+
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
| 33 |
+
from ..models.hunyuan_dit import HunyuanDiT
|
| 34 |
+
|
| 35 |
+
from ..models.flux_dit import FluxDiT
|
| 36 |
+
from ..models.flux_text_encoder import FluxTextEncoder2
|
| 37 |
+
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
| 38 |
+
from ..models.flux_controlnet import FluxControlNet
|
| 39 |
+
from ..models.flux_ipadapter import FluxIpAdapter
|
| 40 |
+
|
| 41 |
+
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
|
| 42 |
+
from ..models.cog_dit import CogDiT
|
| 43 |
+
|
| 44 |
+
from ..models.omnigen import OmniGenTransformer
|
| 45 |
+
|
| 46 |
+
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
| 47 |
+
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
| 48 |
+
|
| 49 |
+
from ..extensions.RIFE import IFNet
|
| 50 |
+
from ..extensions.ESRGAN import RRDBNet
|
| 51 |
+
|
| 52 |
+
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
| 53 |
+
|
| 54 |
+
from ..models.stepvideo_vae import StepVideoVAE
|
| 55 |
+
from ..models.stepvideo_dit import StepVideoModel
|
| 56 |
+
|
| 57 |
+
from ..models.wan_video_dit import WanModel
|
| 58 |
+
from ..models.wan_video_dit_recam_future import WanModelFuture
|
| 59 |
+
from ..models.wan_video_text_encoder import WanTextEncoder
|
| 60 |
+
from ..models.wan_video_image_encoder import WanImageEncoder
|
| 61 |
+
from ..models.wan_video_vae import WanVideoVAE
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
model_loader_configs = [
|
| 65 |
+
# These configs are provided for detecting model type automatically.
|
| 66 |
+
# The format is (state_dict_keys_hash, state_dict_keys_hash_with_shape, model_names, model_classes, model_resource)
|
| 67 |
+
(None, "091b0e30e77c76626b3ba62acdf95343", ["sd_controlnet"], [SDControlNet], "civitai"),
|
| 68 |
+
(None, "4a6c8306a27d916dea81263c8c88f450", ["hunyuan_dit_clip_text_encoder"], [HunyuanDiTCLIPTextEncoder], "civitai"),
|
| 69 |
+
(None, "f4aec400fe394297961218c768004521", ["hunyuan_dit"], [HunyuanDiT], "civitai"),
|
| 70 |
+
(None, "9e6e58043a5a2e332803ed42f6ee7181", ["hunyuan_dit_t5_text_encoder"], [HunyuanDiTT5TextEncoder], "civitai"),
|
| 71 |
+
(None, "13115dd45a6e1c39860f91ab073b8a78", ["sdxl_vae_encoder", "sdxl_vae_decoder"], [SDXLVAEEncoder, SDXLVAEDecoder], "diffusers"),
|
| 72 |
+
(None, "d78aa6797382a6d455362358a3295ea9", ["sd_ipadapter_clip_image_encoder"], [IpAdapterCLIPImageEmbedder], "diffusers"),
|
| 73 |
+
(None, "e291636cc15e803186b47404262ef812", ["sd_ipadapter"], [SDIpAdapter], "civitai"),
|
| 74 |
+
(None, "399c81f2f8de8d1843d0127a00f3c224", ["sdxl_ipadapter_clip_image_encoder"], [IpAdapterXLCLIPImageEmbedder], "diffusers"),
|
| 75 |
+
(None, "a64eac9aa0db4b9602213bc0131281c7", ["sdxl_ipadapter"], [SDXLIpAdapter], "civitai"),
|
| 76 |
+
(None, "52817e4fdd89df154f02749ca6f692ac", ["sdxl_unet"], [SDXLUNet], "diffusers"),
|
| 77 |
+
(None, "03343c606f16d834d6411d0902b53636", ["sd_text_encoder", "sd_unet", "sd_vae_decoder", "sd_vae_encoder"], [SDTextEncoder, SDUNet, SDVAEDecoder, SDVAEEncoder], "civitai"),
|
| 78 |
+
(None, "d4ba77a7ece070679b4a987f58f201e9", ["sd_text_encoder"], [SDTextEncoder], "civitai"),
|
| 79 |
+
(None, "d0c89e55c5a57cf3981def0cb1c9e65a", ["sd_vae_decoder", "sd_vae_encoder"], [SDVAEDecoder, SDVAEEncoder], "civitai"),
|
| 80 |
+
(None, "3926bf373b39a67eeafd7901478a47a7", ["sd_unet"], [SDUNet], "civitai"),
|
| 81 |
+
(None, "1e0c39ec176b9007c05f76d52b554a4d", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
| 82 |
+
(None, "d9e0290829ba8d98e28e1a2b1407db4a", ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3, SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
| 83 |
+
(None, "5072d0b24e406b49507abe861cf97691", ["sd3_text_encoder_3"], [SD3TextEncoder3], "civitai"),
|
| 84 |
+
(None, "4cf64a799d04260df438c6f33c9a047e", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"),
|
| 85 |
+
(None, "d9b008a867c498ab12ad24042eff8e3f", ["sdxl_text_encoder", "sdxl_text_encoder_2", "sdxl_unet", "sdxl_vae_decoder", "sdxl_vae_encoder"], [SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder], "civitai"), # SDXL-Turbo
|
| 86 |
+
(None, "025bb7452e531a3853d951d77c63f032", ["sdxl_text_encoder", "sdxl_text_encoder_2"], [SDXLTextEncoder, SDXLTextEncoder2], "civitai"),
|
| 87 |
+
(None, "298997b403a4245c04102c9f36aac348", ["sdxl_unet"], [SDXLUNet], "civitai"),
|
| 88 |
+
(None, "2a07abce74b4bdc696b76254ab474da6", ["svd_image_encoder", "svd_unet", "svd_vae_decoder", "svd_vae_encoder"], [SVDImageEncoder, SVDUNet, SVDVAEDecoder, SVDVAEEncoder], "civitai"),
|
| 89 |
+
(None, "c96a285a6888465f87de22a984d049fb", ["sd_motion_modules"], [SDMotionModel], "civitai"),
|
| 90 |
+
(None, "72907b92caed19bdb2adb89aa4063fe2", ["sdxl_motion_modules"], [SDXLMotionModel], "civitai"),
|
| 91 |
+
(None, "31d2d9614fba60511fc9bf2604aa01f7", ["sdxl_controlnet"], [SDXLControlNetUnion], "diffusers"),
|
| 92 |
+
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
| 93 |
+
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
|
| 94 |
+
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
|
| 95 |
+
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
|
| 96 |
+
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
|
| 97 |
+
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
|
| 98 |
+
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
| 99 |
+
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
| 100 |
+
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
| 101 |
+
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
| 102 |
+
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
| 103 |
+
(None, "61cbcbc7ac11f169c5949223efa960d1", ["omnigen_transformer"], [OmniGenTransformer], "diffusers"),
|
| 104 |
+
(None, "78d18b9101345ff695f312e7e62538c0", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
| 105 |
+
(None, "b001c89139b5f053c715fe772362dd2a", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
| 106 |
+
(None, "52357cb26250681367488a8954c271e8", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
| 107 |
+
(None, "0cfd1740758423a2a854d67c136d1e8c", ["flux_controlnet"], [FluxControlNet], "diffusers"),
|
| 108 |
+
(None, "4daaa66cc656a8fe369908693dad0a35", ["flux_ipadapter"], [FluxIpAdapter], "diffusers"),
|
| 109 |
+
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
| 110 |
+
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
| 111 |
+
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
| 112 |
+
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
| 113 |
+
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
| 114 |
+
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
| 115 |
+
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
| 116 |
+
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
| 117 |
+
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
| 118 |
+
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
|
| 119 |
+
(None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
|
| 120 |
+
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
| 121 |
+
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
| 122 |
+
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
| 123 |
+
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
| 124 |
+
(None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
| 125 |
+
(None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
|
| 126 |
+
]
|
| 127 |
+
huggingface_model_loader_configs = [
|
| 128 |
+
# These configs are provided for detecting model type automatically.
|
| 129 |
+
# The format is (architecture_in_huggingface_config, huggingface_lib, model_name, redirected_architecture)
|
| 130 |
+
("ChatGLMModel", "diffsynth.models.kolors_text_encoder", "kolors_text_encoder", None),
|
| 131 |
+
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
| 132 |
+
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
| 133 |
+
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
| 134 |
+
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
| 135 |
+
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
| 136 |
+
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
| 137 |
+
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
| 138 |
+
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
| 139 |
+
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
|
| 140 |
+
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
| 141 |
+
]
|
| 142 |
+
patch_model_loader_configs = [
|
| 143 |
+
# These configs are provided for detecting model type automatically.
|
| 144 |
+
# The format is (state_dict_keys_hash_with_shape, model_name, model_class, extra_kwargs)
|
| 145 |
+
("9a4ab6869ac9b7d6e31f9854e397c867", ["svd_unet"], [SVDUNet], {"add_positional_conv": 128}),
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
preset_models_on_huggingface = {
|
| 149 |
+
"HunyuanDiT": [
|
| 150 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
| 151 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
| 152 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
| 153 |
+
("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
| 154 |
+
],
|
| 155 |
+
"stable-video-diffusion-img2vid-xt": [
|
| 156 |
+
("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
| 157 |
+
],
|
| 158 |
+
"ExVideo-SVD-128f-v1": [
|
| 159 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
| 160 |
+
],
|
| 161 |
+
# Stable Diffusion
|
| 162 |
+
"StableDiffusion_v15": [
|
| 163 |
+
("benjamin-paine/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
| 164 |
+
],
|
| 165 |
+
"DreamShaper_8": [
|
| 166 |
+
("Yntec/Dreamshaper8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
| 167 |
+
],
|
| 168 |
+
# Textual Inversion
|
| 169 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
| 170 |
+
("gemasai/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
| 171 |
+
],
|
| 172 |
+
# Stable Diffusion XL
|
| 173 |
+
"StableDiffusionXL_v1": [
|
| 174 |
+
("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
| 175 |
+
],
|
| 176 |
+
"BluePencilXL_v200": [
|
| 177 |
+
("frankjoshua/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
| 178 |
+
],
|
| 179 |
+
"StableDiffusionXL_Turbo": [
|
| 180 |
+
("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
| 181 |
+
],
|
| 182 |
+
# Stable Diffusion 3
|
| 183 |
+
"StableDiffusion3": [
|
| 184 |
+
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
| 185 |
+
],
|
| 186 |
+
"StableDiffusion3_without_T5": [
|
| 187 |
+
("stabilityai/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
| 188 |
+
],
|
| 189 |
+
# ControlNet
|
| 190 |
+
"ControlNet_v11f1p_sd15_depth": [
|
| 191 |
+
("lllyasviel/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
| 192 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 193 |
+
],
|
| 194 |
+
"ControlNet_v11p_sd15_softedge": [
|
| 195 |
+
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
| 196 |
+
("lllyasviel/Annotators", "ControlNetHED.pth", "models/Annotators")
|
| 197 |
+
],
|
| 198 |
+
"ControlNet_v11f1e_sd15_tile": [
|
| 199 |
+
("lllyasviel/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
| 200 |
+
],
|
| 201 |
+
"ControlNet_v11p_sd15_lineart": [
|
| 202 |
+
("lllyasviel/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
| 203 |
+
("lllyasviel/Annotators", "sk_model.pth", "models/Annotators"),
|
| 204 |
+
("lllyasviel/Annotators", "sk_model2.pth", "models/Annotators")
|
| 205 |
+
],
|
| 206 |
+
"ControlNet_union_sdxl_promax": [
|
| 207 |
+
("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
| 208 |
+
("lllyasviel/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 209 |
+
],
|
| 210 |
+
# AnimateDiff
|
| 211 |
+
"AnimateDiff_v2": [
|
| 212 |
+
("guoyww/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
| 213 |
+
],
|
| 214 |
+
"AnimateDiff_xl_beta": [
|
| 215 |
+
("guoyww/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
| 216 |
+
],
|
| 217 |
+
|
| 218 |
+
# Qwen Prompt
|
| 219 |
+
"QwenPrompt": [
|
| 220 |
+
("Qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 221 |
+
("Qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 222 |
+
("Qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 223 |
+
("Qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 224 |
+
("Qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 225 |
+
("Qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 226 |
+
("Qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 227 |
+
("Qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 228 |
+
],
|
| 229 |
+
# Beautiful Prompt
|
| 230 |
+
"BeautifulPrompt": [
|
| 231 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 232 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 233 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 234 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 235 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 236 |
+
("alibaba-pai/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 237 |
+
],
|
| 238 |
+
# Omost prompt
|
| 239 |
+
"OmostPrompt":[
|
| 240 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 241 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 242 |
+
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 243 |
+
("lllyasviel/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 244 |
+
("lllyasviel/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 245 |
+
("lllyasviel/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 246 |
+
("lllyasviel/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 247 |
+
("lllyasviel/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 248 |
+
],
|
| 249 |
+
# Translator
|
| 250 |
+
"opus-mt-zh-en": [
|
| 251 |
+
("Helsinki-NLP/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
| 252 |
+
("Helsinki-NLP/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
| 253 |
+
("Helsinki-NLP/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
| 254 |
+
("Helsinki-NLP/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
| 255 |
+
("Helsinki-NLP/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
| 256 |
+
("Helsinki-NLP/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
| 257 |
+
("Helsinki-NLP/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
| 258 |
+
("Helsinki-NLP/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
| 259 |
+
],
|
| 260 |
+
# IP-Adapter
|
| 261 |
+
"IP-Adapter-SD": [
|
| 262 |
+
("h94/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
| 263 |
+
("h94/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
| 264 |
+
],
|
| 265 |
+
"IP-Adapter-SDXL": [
|
| 266 |
+
("h94/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
| 267 |
+
("h94/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
| 268 |
+
],
|
| 269 |
+
"SDXL-vae-fp16-fix": [
|
| 270 |
+
("madebyollin/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
| 271 |
+
],
|
| 272 |
+
# Kolors
|
| 273 |
+
"Kolors": [
|
| 274 |
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
| 275 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
| 276 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 277 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 278 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 279 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 280 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 281 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 282 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 283 |
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
| 284 |
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
| 285 |
+
],
|
| 286 |
+
# FLUX
|
| 287 |
+
"FLUX.1-dev": [
|
| 288 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
| 289 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 290 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 291 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 292 |
+
("black-forest-labs/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 293 |
+
("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 294 |
+
("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 295 |
+
],
|
| 296 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
| 297 |
+
"file_list": [
|
| 298 |
+
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
| 299 |
+
("google/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 300 |
+
("google/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 301 |
+
],
|
| 302 |
+
"load_path": [
|
| 303 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
| 304 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 305 |
+
],
|
| 306 |
+
},
|
| 307 |
+
# RIFE
|
| 308 |
+
"RIFE": [
|
| 309 |
+
("AlexWortega/RIFE", "flownet.pkl", "models/RIFE"),
|
| 310 |
+
],
|
| 311 |
+
# CogVideo
|
| 312 |
+
"CogVideoX-5B": [
|
| 313 |
+
("THUDM/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 314 |
+
("THUDM/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 315 |
+
("THUDM/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 316 |
+
("THUDM/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 317 |
+
("THUDM/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 318 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 319 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 320 |
+
("THUDM/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 321 |
+
("THUDM/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
| 322 |
+
],
|
| 323 |
+
# Stable Diffusion 3.5
|
| 324 |
+
"StableDiffusion3.5-large": [
|
| 325 |
+
("stabilityai/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
| 326 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 327 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 328 |
+
("stabilityai/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 329 |
+
],
|
| 330 |
+
}
|
| 331 |
+
preset_models_on_modelscope = {
|
| 332 |
+
# Hunyuan DiT
|
| 333 |
+
"HunyuanDiT": [
|
| 334 |
+
("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
|
| 335 |
+
("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
|
| 336 |
+
("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
|
| 337 |
+
("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
|
| 338 |
+
],
|
| 339 |
+
# Stable Video Diffusion
|
| 340 |
+
"stable-video-diffusion-img2vid-xt": [
|
| 341 |
+
("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
|
| 342 |
+
],
|
| 343 |
+
# ExVideo
|
| 344 |
+
"ExVideo-SVD-128f-v1": [
|
| 345 |
+
("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
|
| 346 |
+
],
|
| 347 |
+
"ExVideo-CogVideoX-LoRA-129f-v1": [
|
| 348 |
+
("ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1", "ExVideo-CogVideoX-LoRA-129f-v1.safetensors", "models/lora"),
|
| 349 |
+
],
|
| 350 |
+
# Stable Diffusion
|
| 351 |
+
"StableDiffusion_v15": [
|
| 352 |
+
("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
|
| 353 |
+
],
|
| 354 |
+
"DreamShaper_8": [
|
| 355 |
+
("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
|
| 356 |
+
],
|
| 357 |
+
"AingDiffusion_v12": [
|
| 358 |
+
("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
|
| 359 |
+
],
|
| 360 |
+
"Flat2DAnimerge_v45Sharp": [
|
| 361 |
+
("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
|
| 362 |
+
],
|
| 363 |
+
# Textual Inversion
|
| 364 |
+
"TextualInversion_VeryBadImageNegative_v1.3": [
|
| 365 |
+
("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
|
| 366 |
+
],
|
| 367 |
+
# Stable Diffusion XL
|
| 368 |
+
"StableDiffusionXL_v1": [
|
| 369 |
+
("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
|
| 370 |
+
],
|
| 371 |
+
"BluePencilXL_v200": [
|
| 372 |
+
("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
|
| 373 |
+
],
|
| 374 |
+
"StableDiffusionXL_Turbo": [
|
| 375 |
+
("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
|
| 376 |
+
],
|
| 377 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0": [
|
| 378 |
+
("sd_lora/zyd232_ChineseInkStyle_SDXL_v1_0", "zyd232_ChineseInkStyle_SDXL_v1_0.safetensors", "models/lora"),
|
| 379 |
+
],
|
| 380 |
+
# Stable Diffusion 3
|
| 381 |
+
"StableDiffusion3": [
|
| 382 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
|
| 383 |
+
],
|
| 384 |
+
"StableDiffusion3_without_T5": [
|
| 385 |
+
("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
|
| 386 |
+
],
|
| 387 |
+
# ControlNet
|
| 388 |
+
"ControlNet_v11f1p_sd15_depth": [
|
| 389 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
|
| 390 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 391 |
+
],
|
| 392 |
+
"ControlNet_v11p_sd15_softedge": [
|
| 393 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
|
| 394 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
|
| 395 |
+
],
|
| 396 |
+
"ControlNet_v11f1e_sd15_tile": [
|
| 397 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
|
| 398 |
+
],
|
| 399 |
+
"ControlNet_v11p_sd15_lineart": [
|
| 400 |
+
("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
|
| 401 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
| 402 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
|
| 403 |
+
],
|
| 404 |
+
"ControlNet_union_sdxl_promax": [
|
| 405 |
+
("AI-ModelScope/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", "models/ControlNet/controlnet_union"),
|
| 406 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
|
| 407 |
+
],
|
| 408 |
+
"Annotators:Depth": [
|
| 409 |
+
("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators"),
|
| 410 |
+
],
|
| 411 |
+
"Annotators:Softedge": [
|
| 412 |
+
("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators"),
|
| 413 |
+
],
|
| 414 |
+
"Annotators:Lineart": [
|
| 415 |
+
("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
|
| 416 |
+
("sd_lora/Annotators", "sk_model2.pth", "models/Annotators"),
|
| 417 |
+
],
|
| 418 |
+
"Annotators:Normal": [
|
| 419 |
+
("sd_lora/Annotators", "scannet.pt", "models/Annotators"),
|
| 420 |
+
],
|
| 421 |
+
"Annotators:Openpose": [
|
| 422 |
+
("sd_lora/Annotators", "body_pose_model.pth", "models/Annotators"),
|
| 423 |
+
("sd_lora/Annotators", "facenet.pth", "models/Annotators"),
|
| 424 |
+
("sd_lora/Annotators", "hand_pose_model.pth", "models/Annotators"),
|
| 425 |
+
],
|
| 426 |
+
# AnimateDiff
|
| 427 |
+
"AnimateDiff_v2": [
|
| 428 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
|
| 429 |
+
],
|
| 430 |
+
"AnimateDiff_xl_beta": [
|
| 431 |
+
("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
|
| 432 |
+
],
|
| 433 |
+
# RIFE
|
| 434 |
+
"RIFE": [
|
| 435 |
+
("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
|
| 436 |
+
],
|
| 437 |
+
# Qwen Prompt
|
| 438 |
+
"QwenPrompt": {
|
| 439 |
+
"file_list": [
|
| 440 |
+
("qwen/Qwen2-1.5B-Instruct", "config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 441 |
+
("qwen/Qwen2-1.5B-Instruct", "generation_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 442 |
+
("qwen/Qwen2-1.5B-Instruct", "model.safetensors", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 443 |
+
("qwen/Qwen2-1.5B-Instruct", "special_tokens_map.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 444 |
+
("qwen/Qwen2-1.5B-Instruct", "tokenizer.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 445 |
+
("qwen/Qwen2-1.5B-Instruct", "tokenizer_config.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 446 |
+
("qwen/Qwen2-1.5B-Instruct", "merges.txt", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 447 |
+
("qwen/Qwen2-1.5B-Instruct", "vocab.json", "models/QwenPrompt/qwen2-1.5b-instruct"),
|
| 448 |
+
],
|
| 449 |
+
"load_path": [
|
| 450 |
+
"models/QwenPrompt/qwen2-1.5b-instruct",
|
| 451 |
+
],
|
| 452 |
+
},
|
| 453 |
+
# Beautiful Prompt
|
| 454 |
+
"BeautifulPrompt": {
|
| 455 |
+
"file_list": [
|
| 456 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 457 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 458 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 459 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 460 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 461 |
+
("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
|
| 462 |
+
],
|
| 463 |
+
"load_path": [
|
| 464 |
+
"models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd",
|
| 465 |
+
],
|
| 466 |
+
},
|
| 467 |
+
# Omost prompt
|
| 468 |
+
"OmostPrompt": {
|
| 469 |
+
"file_list": [
|
| 470 |
+
("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 471 |
+
("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 472 |
+
("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 473 |
+
("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 474 |
+
("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 475 |
+
("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 476 |
+
("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 477 |
+
("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"),
|
| 478 |
+
],
|
| 479 |
+
"load_path": [
|
| 480 |
+
"models/OmostPrompt/omost-llama-3-8b-4bits",
|
| 481 |
+
],
|
| 482 |
+
},
|
| 483 |
+
# Translator
|
| 484 |
+
"opus-mt-zh-en": {
|
| 485 |
+
"file_list": [
|
| 486 |
+
("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
|
| 487 |
+
("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
|
| 488 |
+
("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
|
| 489 |
+
("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
|
| 490 |
+
("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
|
| 491 |
+
("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
|
| 492 |
+
("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
|
| 493 |
+
("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
|
| 494 |
+
],
|
| 495 |
+
"load_path": [
|
| 496 |
+
"models/translator/opus-mt-zh-en",
|
| 497 |
+
],
|
| 498 |
+
},
|
| 499 |
+
# IP-Adapter
|
| 500 |
+
"IP-Adapter-SD": [
|
| 501 |
+
("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
|
| 502 |
+
("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
|
| 503 |
+
],
|
| 504 |
+
"IP-Adapter-SDXL": [
|
| 505 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
| 506 |
+
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
| 507 |
+
],
|
| 508 |
+
# Kolors
|
| 509 |
+
"Kolors": {
|
| 510 |
+
"file_list": [
|
| 511 |
+
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
| 512 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
| 513 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 514 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 515 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 516 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 517 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 518 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 519 |
+
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
| 520 |
+
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
| 521 |
+
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
| 522 |
+
],
|
| 523 |
+
"load_path": [
|
| 524 |
+
"models/kolors/Kolors/text_encoder",
|
| 525 |
+
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
| 526 |
+
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",
|
| 527 |
+
],
|
| 528 |
+
},
|
| 529 |
+
"SDXL-vae-fp16-fix": [
|
| 530 |
+
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
|
| 531 |
+
],
|
| 532 |
+
# FLUX
|
| 533 |
+
"FLUX.1-dev": {
|
| 534 |
+
"file_list": [
|
| 535 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
| 536 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 537 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 538 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 539 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 540 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 541 |
+
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 542 |
+
],
|
| 543 |
+
"load_path": [
|
| 544 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
| 545 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 546 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
| 547 |
+
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
|
| 548 |
+
],
|
| 549 |
+
},
|
| 550 |
+
"FLUX.1-schnell": {
|
| 551 |
+
"file_list": [
|
| 552 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder"),
|
| 553 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/config.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 554 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00001-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 555 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model-00002-of-00002.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 556 |
+
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
|
| 557 |
+
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
|
| 558 |
+
("AI-ModelScope/FLUX.1-schnell", "flux1-schnell.safetensors", "models/FLUX/FLUX.1-schnell"),
|
| 559 |
+
],
|
| 560 |
+
"load_path": [
|
| 561 |
+
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
| 562 |
+
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
| 563 |
+
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
| 564 |
+
"models/FLUX/FLUX.1-schnell/flux1-schnell.safetensors"
|
| 565 |
+
],
|
| 566 |
+
},
|
| 567 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha": [
|
| 568 |
+
("InstantX/FLUX.1-dev-Controlnet-Union-alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/InstantX/FLUX.1-dev-Controlnet-Union-alpha"),
|
| 569 |
+
],
|
| 570 |
+
"jasperai/Flux.1-dev-Controlnet-Depth": [
|
| 571 |
+
("jasperai/Flux.1-dev-Controlnet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Depth"),
|
| 572 |
+
],
|
| 573 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals": [
|
| 574 |
+
("jasperai/Flux.1-dev-Controlnet-Surface-Normals", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Surface-Normals"),
|
| 575 |
+
],
|
| 576 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler": [
|
| 577 |
+
("jasperai/Flux.1-dev-Controlnet-Upscaler", "diffusion_pytorch_model.safetensors", "models/ControlNet/jasperai/Flux.1-dev-Controlnet-Upscaler"),
|
| 578 |
+
],
|
| 579 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha": [
|
| 580 |
+
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"),
|
| 581 |
+
],
|
| 582 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta": [
|
| 583 |
+
("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", "diffusion_pytorch_model.safetensors", "models/ControlNet/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta"),
|
| 584 |
+
],
|
| 585 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth": [
|
| 586 |
+
("Shakker-Labs/FLUX.1-dev-ControlNet-Depth", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Depth"),
|
| 587 |
+
],
|
| 588 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro": [
|
| 589 |
+
("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", "models/ControlNet/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"),
|
| 590 |
+
],
|
| 591 |
+
"InstantX/FLUX.1-dev-IP-Adapter": {
|
| 592 |
+
"file_list": [
|
| 593 |
+
("InstantX/FLUX.1-dev-IP-Adapter", "ip-adapter.bin", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter"),
|
| 594 |
+
("AI-ModelScope/siglip-so400m-patch14-384", "model.safetensors", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 595 |
+
("AI-ModelScope/siglip-so400m-patch14-384", "config.json", "models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder"),
|
| 596 |
+
],
|
| 597 |
+
"load_path": [
|
| 598 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/ip-adapter.bin",
|
| 599 |
+
"models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder",
|
| 600 |
+
],
|
| 601 |
+
},
|
| 602 |
+
# ESRGAN
|
| 603 |
+
"ESRGAN_x4": [
|
| 604 |
+
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
|
| 605 |
+
],
|
| 606 |
+
# RIFE
|
| 607 |
+
"RIFE": [
|
| 608 |
+
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
|
| 609 |
+
],
|
| 610 |
+
# Omnigen
|
| 611 |
+
"OmniGen-v1": {
|
| 612 |
+
"file_list": [
|
| 613 |
+
("BAAI/OmniGen-v1", "vae/diffusion_pytorch_model.safetensors", "models/OmniGen/OmniGen-v1/vae"),
|
| 614 |
+
("BAAI/OmniGen-v1", "model.safetensors", "models/OmniGen/OmniGen-v1"),
|
| 615 |
+
("BAAI/OmniGen-v1", "config.json", "models/OmniGen/OmniGen-v1"),
|
| 616 |
+
("BAAI/OmniGen-v1", "special_tokens_map.json", "models/OmniGen/OmniGen-v1"),
|
| 617 |
+
("BAAI/OmniGen-v1", "tokenizer_config.json", "models/OmniGen/OmniGen-v1"),
|
| 618 |
+
("BAAI/OmniGen-v1", "tokenizer.json", "models/OmniGen/OmniGen-v1"),
|
| 619 |
+
],
|
| 620 |
+
"load_path": [
|
| 621 |
+
"models/OmniGen/OmniGen-v1/vae/diffusion_pytorch_model.safetensors",
|
| 622 |
+
"models/OmniGen/OmniGen-v1/model.safetensors",
|
| 623 |
+
]
|
| 624 |
+
},
|
| 625 |
+
# CogVideo
|
| 626 |
+
"CogVideoX-5B": {
|
| 627 |
+
"file_list": [
|
| 628 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 629 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 630 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 631 |
+
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
|
| 632 |
+
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 633 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 634 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 635 |
+
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
|
| 636 |
+
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
|
| 637 |
+
],
|
| 638 |
+
"load_path": [
|
| 639 |
+
"models/CogVideo/CogVideoX-5b/text_encoder",
|
| 640 |
+
"models/CogVideo/CogVideoX-5b/transformer",
|
| 641 |
+
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
|
| 642 |
+
],
|
| 643 |
+
},
|
| 644 |
+
# Stable Diffusion 3.5
|
| 645 |
+
"StableDiffusion3.5-large": [
|
| 646 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "sd3.5_large.safetensors", "models/stable_diffusion_3"),
|
| 647 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 648 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 649 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 650 |
+
],
|
| 651 |
+
"StableDiffusion3.5-medium": [
|
| 652 |
+
("AI-ModelScope/stable-diffusion-3.5-medium", "sd3.5_medium.safetensors", "models/stable_diffusion_3"),
|
| 653 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 654 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 655 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 656 |
+
],
|
| 657 |
+
"StableDiffusion3.5-large-turbo": [
|
| 658 |
+
("AI-ModelScope/stable-diffusion-3.5-large-turbo", "sd3.5_large_turbo.safetensors", "models/stable_diffusion_3"),
|
| 659 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_l.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 660 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 661 |
+
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
| 662 |
+
],
|
| 663 |
+
"HunyuanVideo":{
|
| 664 |
+
"file_list": [
|
| 665 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
| 666 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 667 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 668 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 669 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 670 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
| 671 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
| 672 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
| 673 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
| 674 |
+
],
|
| 675 |
+
"load_path": [
|
| 676 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
| 677 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 678 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
| 679 |
+
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
| 680 |
+
],
|
| 681 |
+
},
|
| 682 |
+
"HunyuanVideoI2V":{
|
| 683 |
+
"file_list": [
|
| 684 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
|
| 685 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
| 686 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
| 687 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
| 688 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
|
| 689 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
| 690 |
+
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
|
| 691 |
+
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
|
| 692 |
+
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
|
| 693 |
+
],
|
| 694 |
+
"load_path": [
|
| 695 |
+
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
|
| 696 |
+
"models/HunyuanVideoI2V/text_encoder_2",
|
| 697 |
+
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
|
| 698 |
+
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
|
| 699 |
+
],
|
| 700 |
+
},
|
| 701 |
+
"HunyuanVideo-fp8":{
|
| 702 |
+
"file_list": [
|
| 703 |
+
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
| 704 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 705 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 706 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 707 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
| 708 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
| 709 |
+
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
| 710 |
+
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
| 711 |
+
("DiffSynth-Studio/HunyuanVideo-safetensors", "model.fp8.safetensors", "models/HunyuanVideo/transformers")
|
| 712 |
+
],
|
| 713 |
+
"load_path": [
|
| 714 |
+
"models/HunyuanVideo/text_encoder/model.safetensors",
|
| 715 |
+
"models/HunyuanVideo/text_encoder_2",
|
| 716 |
+
"models/HunyuanVideo/vae/pytorch_model.pt",
|
| 717 |
+
"models/HunyuanVideo/transformers/model.fp8.safetensors"
|
| 718 |
+
],
|
| 719 |
+
},
|
| 720 |
+
}
|
| 721 |
+
Preset_model_id: TypeAlias = Literal[
|
| 722 |
+
"HunyuanDiT",
|
| 723 |
+
"stable-video-diffusion-img2vid-xt",
|
| 724 |
+
"ExVideo-SVD-128f-v1",
|
| 725 |
+
"ExVideo-CogVideoX-LoRA-129f-v1",
|
| 726 |
+
"StableDiffusion_v15",
|
| 727 |
+
"DreamShaper_8",
|
| 728 |
+
"AingDiffusion_v12",
|
| 729 |
+
"Flat2DAnimerge_v45Sharp",
|
| 730 |
+
"TextualInversion_VeryBadImageNegative_v1.3",
|
| 731 |
+
"StableDiffusionXL_v1",
|
| 732 |
+
"BluePencilXL_v200",
|
| 733 |
+
"StableDiffusionXL_Turbo",
|
| 734 |
+
"ControlNet_v11f1p_sd15_depth",
|
| 735 |
+
"ControlNet_v11p_sd15_softedge",
|
| 736 |
+
"ControlNet_v11f1e_sd15_tile",
|
| 737 |
+
"ControlNet_v11p_sd15_lineart",
|
| 738 |
+
"AnimateDiff_v2",
|
| 739 |
+
"AnimateDiff_xl_beta",
|
| 740 |
+
"RIFE",
|
| 741 |
+
"BeautifulPrompt",
|
| 742 |
+
"opus-mt-zh-en",
|
| 743 |
+
"IP-Adapter-SD",
|
| 744 |
+
"IP-Adapter-SDXL",
|
| 745 |
+
"StableDiffusion3",
|
| 746 |
+
"StableDiffusion3_without_T5",
|
| 747 |
+
"Kolors",
|
| 748 |
+
"SDXL-vae-fp16-fix",
|
| 749 |
+
"ControlNet_union_sdxl_promax",
|
| 750 |
+
"FLUX.1-dev",
|
| 751 |
+
"FLUX.1-schnell",
|
| 752 |
+
"InstantX/FLUX.1-dev-Controlnet-Union-alpha",
|
| 753 |
+
"jasperai/Flux.1-dev-Controlnet-Depth",
|
| 754 |
+
"jasperai/Flux.1-dev-Controlnet-Surface-Normals",
|
| 755 |
+
"jasperai/Flux.1-dev-Controlnet-Upscaler",
|
| 756 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha",
|
| 757 |
+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
|
| 758 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Depth",
|
| 759 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
| 760 |
+
"InstantX/FLUX.1-dev-IP-Adapter",
|
| 761 |
+
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
|
| 762 |
+
"QwenPrompt",
|
| 763 |
+
"OmostPrompt",
|
| 764 |
+
"ESRGAN_x4",
|
| 765 |
+
"RIFE",
|
| 766 |
+
"OmniGen-v1",
|
| 767 |
+
"CogVideoX-5B",
|
| 768 |
+
"Annotators:Depth",
|
| 769 |
+
"Annotators:Softedge",
|
| 770 |
+
"Annotators:Lineart",
|
| 771 |
+
"Annotators:Normal",
|
| 772 |
+
"Annotators:Openpose",
|
| 773 |
+
"StableDiffusion3.5-large",
|
| 774 |
+
"StableDiffusion3.5-medium",
|
| 775 |
+
"HunyuanVideo",
|
| 776 |
+
"HunyuanVideo-fp8",
|
| 777 |
+
"HunyuanVideoI2V",
|
| 778 |
+
]
|
diffsynth/controlnets/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager, FluxMultiControlNetManager
|
| 2 |
+
from .processors import Annotator
|
diffsynth/controlnets/controlnet_unit.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from .processors import Processor_id
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ControlNetConfigUnit:
|
| 7 |
+
def __init__(self, processor_id: Processor_id, model_path, scale=1.0, skip_processor=False):
|
| 8 |
+
self.processor_id = processor_id
|
| 9 |
+
self.model_path = model_path
|
| 10 |
+
self.scale = scale
|
| 11 |
+
self.skip_processor = skip_processor
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ControlNetUnit:
|
| 15 |
+
def __init__(self, processor, model, scale=1.0):
|
| 16 |
+
self.processor = processor
|
| 17 |
+
self.model = model
|
| 18 |
+
self.scale = scale
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MultiControlNetManager:
|
| 22 |
+
def __init__(self, controlnet_units=[]):
|
| 23 |
+
self.processors = [unit.processor for unit in controlnet_units]
|
| 24 |
+
self.models = [unit.model for unit in controlnet_units]
|
| 25 |
+
self.scales = [unit.scale for unit in controlnet_units]
|
| 26 |
+
|
| 27 |
+
def cpu(self):
|
| 28 |
+
for model in self.models:
|
| 29 |
+
model.cpu()
|
| 30 |
+
|
| 31 |
+
def to(self, device):
|
| 32 |
+
for model in self.models:
|
| 33 |
+
model.to(device)
|
| 34 |
+
for processor in self.processors:
|
| 35 |
+
processor.to(device)
|
| 36 |
+
|
| 37 |
+
def process_image(self, image, processor_id=None):
|
| 38 |
+
if processor_id is None:
|
| 39 |
+
processed_image = [processor(image) for processor in self.processors]
|
| 40 |
+
else:
|
| 41 |
+
processed_image = [self.processors[processor_id](image)]
|
| 42 |
+
processed_image = torch.concat([
|
| 43 |
+
torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
|
| 44 |
+
for image_ in processed_image
|
| 45 |
+
], dim=0)
|
| 46 |
+
return processed_image
|
| 47 |
+
|
| 48 |
+
def __call__(
|
| 49 |
+
self,
|
| 50 |
+
sample, timestep, encoder_hidden_states, conditionings,
|
| 51 |
+
tiled=False, tile_size=64, tile_stride=32, **kwargs
|
| 52 |
+
):
|
| 53 |
+
res_stack = None
|
| 54 |
+
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
| 55 |
+
res_stack_ = model(
|
| 56 |
+
sample, timestep, encoder_hidden_states, conditioning, **kwargs,
|
| 57 |
+
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
| 58 |
+
processor_id=processor.processor_id
|
| 59 |
+
)
|
| 60 |
+
res_stack_ = [res * scale for res in res_stack_]
|
| 61 |
+
if res_stack is None:
|
| 62 |
+
res_stack = res_stack_
|
| 63 |
+
else:
|
| 64 |
+
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
| 65 |
+
return res_stack
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class FluxMultiControlNetManager(MultiControlNetManager):
|
| 69 |
+
def __init__(self, controlnet_units=[]):
|
| 70 |
+
super().__init__(controlnet_units=controlnet_units)
|
| 71 |
+
|
| 72 |
+
def process_image(self, image, processor_id=None):
|
| 73 |
+
if processor_id is None:
|
| 74 |
+
processed_image = [processor(image) for processor in self.processors]
|
| 75 |
+
else:
|
| 76 |
+
processed_image = [self.processors[processor_id](image)]
|
| 77 |
+
return processed_image
|
| 78 |
+
|
| 79 |
+
def __call__(self, conditionings, **kwargs):
|
| 80 |
+
res_stack, single_res_stack = None, None
|
| 81 |
+
for processor, conditioning, model, scale in zip(self.processors, conditionings, self.models, self.scales):
|
| 82 |
+
res_stack_, single_res_stack_ = model(controlnet_conditioning=conditioning, processor_id=processor.processor_id, **kwargs)
|
| 83 |
+
res_stack_ = [res * scale for res in res_stack_]
|
| 84 |
+
single_res_stack_ = [res * scale for res in single_res_stack_]
|
| 85 |
+
if res_stack is None:
|
| 86 |
+
res_stack = res_stack_
|
| 87 |
+
single_res_stack = single_res_stack_
|
| 88 |
+
else:
|
| 89 |
+
res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
|
| 90 |
+
single_res_stack = [i + j for i, j in zip(single_res_stack, single_res_stack_)]
|
| 91 |
+
return res_stack, single_res_stack
|
diffsynth/controlnets/processors.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing_extensions import Literal, TypeAlias
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
Processor_id: TypeAlias = Literal[
|
| 5 |
+
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
| 6 |
+
]
|
| 7 |
+
|
| 8 |
+
class Annotator:
|
| 9 |
+
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
| 10 |
+
if not skip_processor:
|
| 11 |
+
if processor_id == "canny":
|
| 12 |
+
from controlnet_aux.processor import CannyDetector
|
| 13 |
+
self.processor = CannyDetector()
|
| 14 |
+
elif processor_id == "depth":
|
| 15 |
+
from controlnet_aux.processor import MidasDetector
|
| 16 |
+
self.processor = MidasDetector.from_pretrained(model_path).to(device)
|
| 17 |
+
elif processor_id == "softedge":
|
| 18 |
+
from controlnet_aux.processor import HEDdetector
|
| 19 |
+
self.processor = HEDdetector.from_pretrained(model_path).to(device)
|
| 20 |
+
elif processor_id == "lineart":
|
| 21 |
+
from controlnet_aux.processor import LineartDetector
|
| 22 |
+
self.processor = LineartDetector.from_pretrained(model_path).to(device)
|
| 23 |
+
elif processor_id == "lineart_anime":
|
| 24 |
+
from controlnet_aux.processor import LineartAnimeDetector
|
| 25 |
+
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
|
| 26 |
+
elif processor_id == "openpose":
|
| 27 |
+
from controlnet_aux.processor import OpenposeDetector
|
| 28 |
+
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
|
| 29 |
+
elif processor_id == "normal":
|
| 30 |
+
from controlnet_aux.processor import NormalBaeDetector
|
| 31 |
+
self.processor = NormalBaeDetector.from_pretrained(model_path).to(device)
|
| 32 |
+
elif processor_id == "tile" or processor_id == "none" or processor_id == "inpaint":
|
| 33 |
+
self.processor = None
|
| 34 |
+
else:
|
| 35 |
+
raise ValueError(f"Unsupported processor_id: {processor_id}")
|
| 36 |
+
else:
|
| 37 |
+
self.processor = None
|
| 38 |
+
|
| 39 |
+
self.processor_id = processor_id
|
| 40 |
+
self.detect_resolution = detect_resolution
|
| 41 |
+
|
| 42 |
+
def to(self,device):
|
| 43 |
+
if hasattr(self.processor,"model") and hasattr(self.processor.model,"to"):
|
| 44 |
+
|
| 45 |
+
self.processor.model.to(device)
|
| 46 |
+
|
| 47 |
+
def __call__(self, image, mask=None):
|
| 48 |
+
width, height = image.size
|
| 49 |
+
if self.processor_id == "openpose":
|
| 50 |
+
kwargs = {
|
| 51 |
+
"include_body": True,
|
| 52 |
+
"include_hand": True,
|
| 53 |
+
"include_face": True
|
| 54 |
+
}
|
| 55 |
+
else:
|
| 56 |
+
kwargs = {}
|
| 57 |
+
if self.processor is not None:
|
| 58 |
+
detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
|
| 59 |
+
image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
|
| 60 |
+
image = image.resize((width, height))
|
| 61 |
+
return image
|
| 62 |
+
|
diffsynth/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .video import VideoData, save_video, save_frames
|
diffsynth/data/simple_text_image.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, os, torchvision
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TextImageDataset(torch.utils.data.Dataset):
|
| 9 |
+
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
| 10 |
+
self.steps_per_epoch = steps_per_epoch
|
| 11 |
+
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
| 12 |
+
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
| 13 |
+
self.text = metadata["text"].to_list()
|
| 14 |
+
self.height = height
|
| 15 |
+
self.width = width
|
| 16 |
+
self.image_processor = transforms.Compose(
|
| 17 |
+
[
|
| 18 |
+
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
| 19 |
+
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
| 20 |
+
transforms.ToTensor(),
|
| 21 |
+
transforms.Normalize([0.5], [0.5]),
|
| 22 |
+
]
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def __getitem__(self, index):
|
| 27 |
+
data_id = torch.randint(0, len(self.path), (1,))[0]
|
| 28 |
+
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
| 29 |
+
text = self.text[data_id]
|
| 30 |
+
image = Image.open(self.path[data_id]).convert("RGB")
|
| 31 |
+
target_height, target_width = self.height, self.width
|
| 32 |
+
width, height = image.size
|
| 33 |
+
scale = max(target_width / width, target_height / height)
|
| 34 |
+
shape = [round(height*scale),round(width*scale)]
|
| 35 |
+
image = torchvision.transforms.functional.resize(image,shape,interpolation=transforms.InterpolationMode.BILINEAR)
|
| 36 |
+
image = self.image_processor(image)
|
| 37 |
+
return {"text": text, "image": image}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def __len__(self):
|
| 41 |
+
return self.steps_per_epoch
|
diffsynth/data/video.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import imageio, os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LowMemoryVideo:
|
| 8 |
+
def __init__(self, file_name):
|
| 9 |
+
self.reader = imageio.get_reader(file_name)
|
| 10 |
+
|
| 11 |
+
def __len__(self):
|
| 12 |
+
return self.reader.count_frames()
|
| 13 |
+
|
| 14 |
+
def __getitem__(self, item):
|
| 15 |
+
return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
|
| 16 |
+
|
| 17 |
+
def __del__(self):
|
| 18 |
+
self.reader.close()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def split_file_name(file_name):
|
| 22 |
+
result = []
|
| 23 |
+
number = -1
|
| 24 |
+
for i in file_name:
|
| 25 |
+
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
| 26 |
+
if number == -1:
|
| 27 |
+
number = 0
|
| 28 |
+
number = number*10 + ord(i) - ord("0")
|
| 29 |
+
else:
|
| 30 |
+
if number != -1:
|
| 31 |
+
result.append(number)
|
| 32 |
+
number = -1
|
| 33 |
+
result.append(i)
|
| 34 |
+
if number != -1:
|
| 35 |
+
result.append(number)
|
| 36 |
+
result = tuple(result)
|
| 37 |
+
return result
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def search_for_images(folder):
|
| 41 |
+
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
| 42 |
+
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
| 43 |
+
file_list = [i[1] for i in sorted(file_list)]
|
| 44 |
+
file_list = [os.path.join(folder, i) for i in file_list]
|
| 45 |
+
return file_list
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class LowMemoryImageFolder:
|
| 49 |
+
def __init__(self, folder, file_list=None):
|
| 50 |
+
if file_list is None:
|
| 51 |
+
self.file_list = search_for_images(folder)
|
| 52 |
+
else:
|
| 53 |
+
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
| 54 |
+
|
| 55 |
+
def __len__(self):
|
| 56 |
+
return len(self.file_list)
|
| 57 |
+
|
| 58 |
+
def __getitem__(self, item):
|
| 59 |
+
return Image.open(self.file_list[item]).convert("RGB")
|
| 60 |
+
|
| 61 |
+
def __del__(self):
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def crop_and_resize(image, height, width):
|
| 66 |
+
image = np.array(image)
|
| 67 |
+
image_height, image_width, _ = image.shape
|
| 68 |
+
if image_height / image_width < height / width:
|
| 69 |
+
croped_width = int(image_height / height * width)
|
| 70 |
+
left = (image_width - croped_width) // 2
|
| 71 |
+
image = image[:, left: left+croped_width]
|
| 72 |
+
image = Image.fromarray(image).resize((width, height))
|
| 73 |
+
else:
|
| 74 |
+
croped_height = int(image_width / width * height)
|
| 75 |
+
left = (image_height - croped_height) // 2
|
| 76 |
+
image = image[left: left+croped_height, :]
|
| 77 |
+
image = Image.fromarray(image).resize((width, height))
|
| 78 |
+
return image
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class VideoData:
|
| 82 |
+
def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
|
| 83 |
+
if video_file is not None:
|
| 84 |
+
self.data_type = "video"
|
| 85 |
+
self.data = LowMemoryVideo(video_file, **kwargs)
|
| 86 |
+
elif image_folder is not None:
|
| 87 |
+
self.data_type = "images"
|
| 88 |
+
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError("Cannot open video or image folder")
|
| 91 |
+
self.length = None
|
| 92 |
+
self.set_shape(height, width)
|
| 93 |
+
|
| 94 |
+
def raw_data(self):
|
| 95 |
+
frames = []
|
| 96 |
+
for i in range(self.__len__()):
|
| 97 |
+
frames.append(self.__getitem__(i))
|
| 98 |
+
return frames
|
| 99 |
+
|
| 100 |
+
def set_length(self, length):
|
| 101 |
+
self.length = length
|
| 102 |
+
|
| 103 |
+
def set_shape(self, height, width):
|
| 104 |
+
self.height = height
|
| 105 |
+
self.width = width
|
| 106 |
+
|
| 107 |
+
def __len__(self):
|
| 108 |
+
if self.length is None:
|
| 109 |
+
return len(self.data)
|
| 110 |
+
else:
|
| 111 |
+
return self.length
|
| 112 |
+
|
| 113 |
+
def shape(self):
|
| 114 |
+
if self.height is not None and self.width is not None:
|
| 115 |
+
return self.height, self.width
|
| 116 |
+
else:
|
| 117 |
+
height, width, _ = self.__getitem__(0).shape
|
| 118 |
+
return height, width
|
| 119 |
+
|
| 120 |
+
def __getitem__(self, item):
|
| 121 |
+
frame = self.data.__getitem__(item)
|
| 122 |
+
width, height = frame.size
|
| 123 |
+
if self.height is not None and self.width is not None:
|
| 124 |
+
if self.height != height or self.width != width:
|
| 125 |
+
frame = crop_and_resize(frame, self.height, self.width)
|
| 126 |
+
return frame
|
| 127 |
+
|
| 128 |
+
def __del__(self):
|
| 129 |
+
pass
|
| 130 |
+
|
| 131 |
+
def save_images(self, folder):
|
| 132 |
+
os.makedirs(folder, exist_ok=True)
|
| 133 |
+
for i in tqdm(range(self.__len__()), desc="Saving images"):
|
| 134 |
+
frame = self.__getitem__(i)
|
| 135 |
+
frame.save(os.path.join(folder, f"{i}.png"))
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
| 139 |
+
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
| 140 |
+
for frame in tqdm(frames, desc="Saving video"):
|
| 141 |
+
frame = np.array(frame)
|
| 142 |
+
writer.append_data(frame)
|
| 143 |
+
writer.close()
|
| 144 |
+
|
| 145 |
+
def save_frames(frames, save_path):
|
| 146 |
+
os.makedirs(save_path, exist_ok=True)
|
| 147 |
+
for i, frame in enumerate(tqdm(frames, desc="Saving images")):
|
| 148 |
+
frame.save(os.path.join(save_path, f"{i}.png"))
|
diffsynth/extensions/ESRGAN/__init__.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import repeat
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ResidualDenseBlock(torch.nn.Module):
|
| 8 |
+
|
| 9 |
+
def __init__(self, num_feat=64, num_grow_ch=32):
|
| 10 |
+
super(ResidualDenseBlock, self).__init__()
|
| 11 |
+
self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
| 12 |
+
self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
| 13 |
+
self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
| 14 |
+
self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
| 15 |
+
self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
| 16 |
+
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
x1 = self.lrelu(self.conv1(x))
|
| 20 |
+
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
| 21 |
+
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
| 22 |
+
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
| 23 |
+
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
| 24 |
+
return x5 * 0.2 + x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class RRDB(torch.nn.Module):
|
| 28 |
+
|
| 29 |
+
def __init__(self, num_feat, num_grow_ch=32):
|
| 30 |
+
super(RRDB, self).__init__()
|
| 31 |
+
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
| 32 |
+
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
| 33 |
+
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
| 34 |
+
|
| 35 |
+
def forward(self, x):
|
| 36 |
+
out = self.rdb1(x)
|
| 37 |
+
out = self.rdb2(out)
|
| 38 |
+
out = self.rdb3(out)
|
| 39 |
+
return out * 0.2 + x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RRDBNet(torch.nn.Module):
|
| 43 |
+
|
| 44 |
+
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
|
| 45 |
+
super(RRDBNet, self).__init__()
|
| 46 |
+
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
| 47 |
+
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
|
| 48 |
+
self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
| 49 |
+
# upsample
|
| 50 |
+
self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
| 51 |
+
self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
| 52 |
+
self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
| 53 |
+
self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
| 54 |
+
self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
feat = x
|
| 58 |
+
feat = self.conv_first(feat)
|
| 59 |
+
body_feat = self.conv_body(self.body(feat))
|
| 60 |
+
feat = feat + body_feat
|
| 61 |
+
# upsample
|
| 62 |
+
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
| 63 |
+
feat = self.lrelu(self.conv_up1(feat))
|
| 64 |
+
feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
|
| 65 |
+
feat = self.lrelu(self.conv_up2(feat))
|
| 66 |
+
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
| 67 |
+
return out
|
| 68 |
+
|
| 69 |
+
@staticmethod
|
| 70 |
+
def state_dict_converter():
|
| 71 |
+
return RRDBNetStateDictConverter()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class RRDBNetStateDictConverter:
|
| 75 |
+
def __init__(self):
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
def from_diffusers(self, state_dict):
|
| 79 |
+
return state_dict, {"upcast_to_float32": True}
|
| 80 |
+
|
| 81 |
+
def from_civitai(self, state_dict):
|
| 82 |
+
return state_dict, {"upcast_to_float32": True}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class ESRGAN(torch.nn.Module):
|
| 86 |
+
def __init__(self, model):
|
| 87 |
+
super().__init__()
|
| 88 |
+
self.model = model
|
| 89 |
+
|
| 90 |
+
@staticmethod
|
| 91 |
+
def from_model_manager(model_manager):
|
| 92 |
+
return ESRGAN(model_manager.fetch_model("esrgan"))
|
| 93 |
+
|
| 94 |
+
def process_image(self, image):
|
| 95 |
+
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
|
| 96 |
+
return image
|
| 97 |
+
|
| 98 |
+
def process_images(self, images):
|
| 99 |
+
images = [self.process_image(image) for image in images]
|
| 100 |
+
images = torch.stack(images)
|
| 101 |
+
return images
|
| 102 |
+
|
| 103 |
+
def decode_images(self, images):
|
| 104 |
+
images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
|
| 105 |
+
images = [Image.fromarray(image) for image in images]
|
| 106 |
+
return images
|
| 107 |
+
|
| 108 |
+
@torch.no_grad()
|
| 109 |
+
def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
|
| 110 |
+
if not isinstance(images, list):
|
| 111 |
+
images = [images]
|
| 112 |
+
is_single_image = True
|
| 113 |
+
else:
|
| 114 |
+
is_single_image = False
|
| 115 |
+
|
| 116 |
+
# Preprocess
|
| 117 |
+
input_tensor = self.process_images(images)
|
| 118 |
+
|
| 119 |
+
# Interpolate
|
| 120 |
+
output_tensor = []
|
| 121 |
+
for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
|
| 122 |
+
batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
|
| 123 |
+
batch_input_tensor = input_tensor[batch_id: batch_id_]
|
| 124 |
+
batch_input_tensor = batch_input_tensor.to(
|
| 125 |
+
device=self.model.conv_first.weight.device,
|
| 126 |
+
dtype=self.model.conv_first.weight.dtype)
|
| 127 |
+
batch_output_tensor = self.model(batch_input_tensor)
|
| 128 |
+
output_tensor.append(batch_output_tensor.cpu())
|
| 129 |
+
|
| 130 |
+
# Output
|
| 131 |
+
output_tensor = torch.concat(output_tensor, dim=0)
|
| 132 |
+
|
| 133 |
+
# To images
|
| 134 |
+
output_images = self.decode_images(output_tensor)
|
| 135 |
+
if is_single_image:
|
| 136 |
+
output_images = output_images[0]
|
| 137 |
+
return output_images
|
diffsynth/extensions/FastBlend/__init__.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .runners.fast import TableManager, PyramidPatchMatcher
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cupy as cp
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FastBlendSmoother:
|
| 8 |
+
def __init__(self):
|
| 9 |
+
self.batch_size = 8
|
| 10 |
+
self.window_size = 64
|
| 11 |
+
self.ebsynth_config = {
|
| 12 |
+
"minimum_patch_size": 5,
|
| 13 |
+
"threads_per_block": 8,
|
| 14 |
+
"num_iter": 5,
|
| 15 |
+
"gpu_id": 0,
|
| 16 |
+
"guide_weight": 10.0,
|
| 17 |
+
"initialize": "identity",
|
| 18 |
+
"tracking_window_size": 0,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
@staticmethod
|
| 22 |
+
def from_model_manager(model_manager):
|
| 23 |
+
# TODO: fetch GPU ID from model_manager
|
| 24 |
+
return FastBlendSmoother()
|
| 25 |
+
|
| 26 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
|
| 27 |
+
frames_guide = [np.array(frame) for frame in frames_guide]
|
| 28 |
+
frames_style = [np.array(frame) for frame in frames_style]
|
| 29 |
+
table_manager = TableManager()
|
| 30 |
+
patch_match_engine = PyramidPatchMatcher(
|
| 31 |
+
image_height=frames_style[0].shape[0],
|
| 32 |
+
image_width=frames_style[0].shape[1],
|
| 33 |
+
channel=3,
|
| 34 |
+
**ebsynth_config
|
| 35 |
+
)
|
| 36 |
+
# left part
|
| 37 |
+
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
|
| 38 |
+
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
| 39 |
+
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
|
| 40 |
+
# right part
|
| 41 |
+
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
|
| 42 |
+
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
| 43 |
+
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
|
| 44 |
+
# merge
|
| 45 |
+
frames = []
|
| 46 |
+
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
| 47 |
+
weight_m = -1
|
| 48 |
+
weight = weight_l + weight_m + weight_r
|
| 49 |
+
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
| 50 |
+
frames.append(frame)
|
| 51 |
+
frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
|
| 52 |
+
return frames
|
| 53 |
+
|
| 54 |
+
def __call__(self, rendered_frames, original_frames=None, **kwargs):
|
| 55 |
+
frames = self.run(
|
| 56 |
+
original_frames, rendered_frames,
|
| 57 |
+
self.batch_size, self.window_size, self.ebsynth_config
|
| 58 |
+
)
|
| 59 |
+
mempool = cp.get_default_memory_pool()
|
| 60 |
+
pinned_mempool = cp.get_default_pinned_memory_pool()
|
| 61 |
+
mempool.free_all_blocks()
|
| 62 |
+
pinned_mempool.free_all_blocks()
|
| 63 |
+
return frames
|
diffsynth/extensions/FastBlend/api.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
| 2 |
+
from .data import VideoData, get_video_fps, save_video, search_for_images
|
| 3 |
+
import os
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
|
| 8 |
+
frames_guide = VideoData(video_guide, video_guide_folder)
|
| 9 |
+
frames_style = VideoData(video_style, video_style_folder)
|
| 10 |
+
message = ""
|
| 11 |
+
if len(frames_guide) < len(frames_style):
|
| 12 |
+
message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
|
| 13 |
+
frames_style.set_length(len(frames_guide))
|
| 14 |
+
elif len(frames_guide) > len(frames_style):
|
| 15 |
+
message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
|
| 16 |
+
frames_guide.set_length(len(frames_style))
|
| 17 |
+
height_guide, width_guide = frames_guide.shape()
|
| 18 |
+
height_style, width_style = frames_style.shape()
|
| 19 |
+
if height_guide != height_style or width_guide != width_style:
|
| 20 |
+
message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
|
| 21 |
+
frames_style.set_shape(height_guide, width_guide)
|
| 22 |
+
return frames_guide, frames_style, message
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def smooth_video(
|
| 26 |
+
video_guide,
|
| 27 |
+
video_guide_folder,
|
| 28 |
+
video_style,
|
| 29 |
+
video_style_folder,
|
| 30 |
+
mode,
|
| 31 |
+
window_size,
|
| 32 |
+
batch_size,
|
| 33 |
+
tracking_window_size,
|
| 34 |
+
output_path,
|
| 35 |
+
fps,
|
| 36 |
+
minimum_patch_size,
|
| 37 |
+
num_iter,
|
| 38 |
+
guide_weight,
|
| 39 |
+
initialize,
|
| 40 |
+
progress = None,
|
| 41 |
+
):
|
| 42 |
+
# input
|
| 43 |
+
frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
|
| 44 |
+
if len(message) > 0:
|
| 45 |
+
print(message)
|
| 46 |
+
# output
|
| 47 |
+
if output_path == "":
|
| 48 |
+
if video_style is None:
|
| 49 |
+
output_path = os.path.join(video_style_folder, "output")
|
| 50 |
+
else:
|
| 51 |
+
output_path = os.path.join(os.path.split(video_style)[0], "output")
|
| 52 |
+
os.makedirs(output_path, exist_ok=True)
|
| 53 |
+
print("No valid output_path. Your video will be saved here:", output_path)
|
| 54 |
+
elif not os.path.exists(output_path):
|
| 55 |
+
os.makedirs(output_path, exist_ok=True)
|
| 56 |
+
print("Your video will be saved here:", output_path)
|
| 57 |
+
frames_path = os.path.join(output_path, "frames")
|
| 58 |
+
video_path = os.path.join(output_path, "video.mp4")
|
| 59 |
+
os.makedirs(frames_path, exist_ok=True)
|
| 60 |
+
# process
|
| 61 |
+
if mode == "Fast" or mode == "Balanced":
|
| 62 |
+
tracking_window_size = 0
|
| 63 |
+
ebsynth_config = {
|
| 64 |
+
"minimum_patch_size": minimum_patch_size,
|
| 65 |
+
"threads_per_block": 8,
|
| 66 |
+
"num_iter": num_iter,
|
| 67 |
+
"gpu_id": 0,
|
| 68 |
+
"guide_weight": guide_weight,
|
| 69 |
+
"initialize": initialize,
|
| 70 |
+
"tracking_window_size": tracking_window_size,
|
| 71 |
+
}
|
| 72 |
+
if mode == "Fast":
|
| 73 |
+
FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
| 74 |
+
elif mode == "Balanced":
|
| 75 |
+
BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
| 76 |
+
elif mode == "Accurate":
|
| 77 |
+
AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
|
| 78 |
+
# output
|
| 79 |
+
try:
|
| 80 |
+
fps = int(fps)
|
| 81 |
+
except:
|
| 82 |
+
fps = get_video_fps(video_style) if video_style is not None else 30
|
| 83 |
+
print("Fps:", fps)
|
| 84 |
+
print("Saving video...")
|
| 85 |
+
video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
|
| 86 |
+
print("Success!")
|
| 87 |
+
print("Your frames are here:", frames_path)
|
| 88 |
+
print("Your video is here:", video_path)
|
| 89 |
+
return output_path, fps, video_path
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class KeyFrameMatcher:
|
| 93 |
+
def __init__(self):
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
def extract_number_from_filename(self, file_name):
|
| 97 |
+
result = []
|
| 98 |
+
number = -1
|
| 99 |
+
for i in file_name:
|
| 100 |
+
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
| 101 |
+
if number == -1:
|
| 102 |
+
number = 0
|
| 103 |
+
number = number*10 + ord(i) - ord("0")
|
| 104 |
+
else:
|
| 105 |
+
if number != -1:
|
| 106 |
+
result.append(number)
|
| 107 |
+
number = -1
|
| 108 |
+
if number != -1:
|
| 109 |
+
result.append(number)
|
| 110 |
+
result = tuple(result)
|
| 111 |
+
return result
|
| 112 |
+
|
| 113 |
+
def extract_number_from_filenames(self, file_names):
|
| 114 |
+
numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
|
| 115 |
+
min_length = min(len(i) for i in numbers)
|
| 116 |
+
for i in range(min_length-1, -1, -1):
|
| 117 |
+
if len(set(number[i] for number in numbers))==len(file_names):
|
| 118 |
+
return [number[i] for number in numbers]
|
| 119 |
+
return list(range(len(file_names)))
|
| 120 |
+
|
| 121 |
+
def match_using_filename(self, file_names_a, file_names_b):
|
| 122 |
+
file_names_b_set = set(file_names_b)
|
| 123 |
+
matched_file_name = []
|
| 124 |
+
for file_name in file_names_a:
|
| 125 |
+
if file_name not in file_names_b_set:
|
| 126 |
+
matched_file_name.append(None)
|
| 127 |
+
else:
|
| 128 |
+
matched_file_name.append(file_name)
|
| 129 |
+
return matched_file_name
|
| 130 |
+
|
| 131 |
+
def match_using_numbers(self, file_names_a, file_names_b):
|
| 132 |
+
numbers_a = self.extract_number_from_filenames(file_names_a)
|
| 133 |
+
numbers_b = self.extract_number_from_filenames(file_names_b)
|
| 134 |
+
numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
|
| 135 |
+
matched_file_name = []
|
| 136 |
+
for number in numbers_a:
|
| 137 |
+
if number in numbers_b_dict:
|
| 138 |
+
matched_file_name.append(numbers_b_dict[number])
|
| 139 |
+
else:
|
| 140 |
+
matched_file_name.append(None)
|
| 141 |
+
return matched_file_name
|
| 142 |
+
|
| 143 |
+
def match_filenames(self, file_names_a, file_names_b):
|
| 144 |
+
matched_file_name = self.match_using_filename(file_names_a, file_names_b)
|
| 145 |
+
if sum([i is not None for i in matched_file_name]) > 0:
|
| 146 |
+
return matched_file_name
|
| 147 |
+
matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
|
| 148 |
+
return matched_file_name
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def detect_frames(frames_path, keyframes_path):
|
| 152 |
+
if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
|
| 153 |
+
return "Please input the directory of guide video and rendered frames"
|
| 154 |
+
elif not os.path.exists(frames_path):
|
| 155 |
+
return "Please input the directory of guide video"
|
| 156 |
+
elif not os.path.exists(keyframes_path):
|
| 157 |
+
return "Please input the directory of rendered frames"
|
| 158 |
+
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
| 159 |
+
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
| 160 |
+
if len(frames)==0:
|
| 161 |
+
return f"No images detected in {frames_path}"
|
| 162 |
+
if len(keyframes)==0:
|
| 163 |
+
return f"No images detected in {keyframes_path}"
|
| 164 |
+
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
| 165 |
+
max_filename_length = max([len(i) for i in frames])
|
| 166 |
+
if sum([i is not None for i in matched_keyframes])==0:
|
| 167 |
+
message = ""
|
| 168 |
+
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
| 169 |
+
message += frame + " " * (max_filename_length - len(frame) + 1)
|
| 170 |
+
message += "--> No matched keyframes\n"
|
| 171 |
+
else:
|
| 172 |
+
message = ""
|
| 173 |
+
for frame, matched_keyframe in zip(frames, matched_keyframes):
|
| 174 |
+
message += frame + " " * (max_filename_length - len(frame) + 1)
|
| 175 |
+
if matched_keyframe is None:
|
| 176 |
+
message += "--> [to be rendered]\n"
|
| 177 |
+
else:
|
| 178 |
+
message += f"--> {matched_keyframe}\n"
|
| 179 |
+
return message
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def check_input_for_interpolating(frames_path, keyframes_path):
|
| 183 |
+
# search for images
|
| 184 |
+
frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
|
| 185 |
+
keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
|
| 186 |
+
# match frames
|
| 187 |
+
matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
|
| 188 |
+
file_list = [file_name for file_name in matched_keyframes if file_name is not None]
|
| 189 |
+
index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
|
| 190 |
+
frames_guide = VideoData(None, frames_path)
|
| 191 |
+
frames_style = VideoData(None, keyframes_path, file_list=file_list)
|
| 192 |
+
# match shape
|
| 193 |
+
message = ""
|
| 194 |
+
height_guide, width_guide = frames_guide.shape()
|
| 195 |
+
height_style, width_style = frames_style.shape()
|
| 196 |
+
if height_guide != height_style or width_guide != width_style:
|
| 197 |
+
message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
|
| 198 |
+
frames_style.set_shape(height_guide, width_guide)
|
| 199 |
+
return frames_guide, frames_style, index_style, message
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def interpolate_video(
|
| 203 |
+
frames_path,
|
| 204 |
+
keyframes_path,
|
| 205 |
+
output_path,
|
| 206 |
+
fps,
|
| 207 |
+
batch_size,
|
| 208 |
+
tracking_window_size,
|
| 209 |
+
minimum_patch_size,
|
| 210 |
+
num_iter,
|
| 211 |
+
guide_weight,
|
| 212 |
+
initialize,
|
| 213 |
+
progress = None,
|
| 214 |
+
):
|
| 215 |
+
# input
|
| 216 |
+
frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
|
| 217 |
+
if len(message) > 0:
|
| 218 |
+
print(message)
|
| 219 |
+
# output
|
| 220 |
+
if output_path == "":
|
| 221 |
+
output_path = os.path.join(keyframes_path, "output")
|
| 222 |
+
os.makedirs(output_path, exist_ok=True)
|
| 223 |
+
print("No valid output_path. Your video will be saved here:", output_path)
|
| 224 |
+
elif not os.path.exists(output_path):
|
| 225 |
+
os.makedirs(output_path, exist_ok=True)
|
| 226 |
+
print("Your video will be saved here:", output_path)
|
| 227 |
+
output_frames_path = os.path.join(output_path, "frames")
|
| 228 |
+
output_video_path = os.path.join(output_path, "video.mp4")
|
| 229 |
+
os.makedirs(output_frames_path, exist_ok=True)
|
| 230 |
+
# process
|
| 231 |
+
ebsynth_config = {
|
| 232 |
+
"minimum_patch_size": minimum_patch_size,
|
| 233 |
+
"threads_per_block": 8,
|
| 234 |
+
"num_iter": num_iter,
|
| 235 |
+
"gpu_id": 0,
|
| 236 |
+
"guide_weight": guide_weight,
|
| 237 |
+
"initialize": initialize,
|
| 238 |
+
"tracking_window_size": tracking_window_size
|
| 239 |
+
}
|
| 240 |
+
if len(index_style)==1:
|
| 241 |
+
InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
| 242 |
+
else:
|
| 243 |
+
InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
|
| 244 |
+
try:
|
| 245 |
+
fps = int(fps)
|
| 246 |
+
except:
|
| 247 |
+
fps = 30
|
| 248 |
+
print("Fps:", fps)
|
| 249 |
+
print("Saving video...")
|
| 250 |
+
video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
|
| 251 |
+
print("Success!")
|
| 252 |
+
print("Your frames are here:", output_frames_path)
|
| 253 |
+
print("Your video is here:", video_path)
|
| 254 |
+
return output_path, fps, video_path
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def on_ui_tabs():
|
| 258 |
+
with gr.Blocks(analytics_enabled=False) as ui_component:
|
| 259 |
+
with gr.Tab("Blend"):
|
| 260 |
+
gr.Markdown("""
|
| 261 |
+
# Blend
|
| 262 |
+
|
| 263 |
+
Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
|
| 264 |
+
""")
|
| 265 |
+
with gr.Row():
|
| 266 |
+
with gr.Column():
|
| 267 |
+
with gr.Tab("Guide video"):
|
| 268 |
+
video_guide = gr.Video(label="Guide video")
|
| 269 |
+
with gr.Tab("Guide video (images format)"):
|
| 270 |
+
video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
|
| 271 |
+
with gr.Column():
|
| 272 |
+
with gr.Tab("Style video"):
|
| 273 |
+
video_style = gr.Video(label="Style video")
|
| 274 |
+
with gr.Tab("Style video (images format)"):
|
| 275 |
+
video_style_folder = gr.Textbox(label="Style video (images format)", value="")
|
| 276 |
+
with gr.Column():
|
| 277 |
+
output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
|
| 278 |
+
fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
| 279 |
+
video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
| 280 |
+
btn = gr.Button(value="Blend")
|
| 281 |
+
with gr.Row():
|
| 282 |
+
with gr.Column():
|
| 283 |
+
gr.Markdown("# Settings")
|
| 284 |
+
mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
|
| 285 |
+
window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
|
| 286 |
+
batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
| 287 |
+
tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
| 288 |
+
gr.Markdown("## Advanced Settings")
|
| 289 |
+
minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
|
| 290 |
+
num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
| 291 |
+
guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
| 292 |
+
initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
| 293 |
+
with gr.Column():
|
| 294 |
+
gr.Markdown("""
|
| 295 |
+
# Reference
|
| 296 |
+
|
| 297 |
+
* Output directory: the directory to save the video.
|
| 298 |
+
* Inference mode
|
| 299 |
+
|
| 300 |
+
|Mode|Time|Memory|Quality|Frame by frame output|Description|
|
| 301 |
+
|-|-|-|-|-|-|
|
| 302 |
+
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|
| 303 |
+
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|
| 304 |
+
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
|
| 305 |
+
|
| 306 |
+
* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
|
| 307 |
+
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
| 308 |
+
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
| 309 |
+
* Advanced settings
|
| 310 |
+
* Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
|
| 311 |
+
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
| 312 |
+
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
| 313 |
+
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
| 314 |
+
""")
|
| 315 |
+
btn.click(
|
| 316 |
+
smooth_video,
|
| 317 |
+
inputs=[
|
| 318 |
+
video_guide,
|
| 319 |
+
video_guide_folder,
|
| 320 |
+
video_style,
|
| 321 |
+
video_style_folder,
|
| 322 |
+
mode,
|
| 323 |
+
window_size,
|
| 324 |
+
batch_size,
|
| 325 |
+
tracking_window_size,
|
| 326 |
+
output_path,
|
| 327 |
+
fps,
|
| 328 |
+
minimum_patch_size,
|
| 329 |
+
num_iter,
|
| 330 |
+
guide_weight,
|
| 331 |
+
initialize
|
| 332 |
+
],
|
| 333 |
+
outputs=[output_path, fps, video_output]
|
| 334 |
+
)
|
| 335 |
+
with gr.Tab("Interpolate"):
|
| 336 |
+
gr.Markdown("""
|
| 337 |
+
# Interpolate
|
| 338 |
+
|
| 339 |
+
Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
|
| 340 |
+
""")
|
| 341 |
+
with gr.Row():
|
| 342 |
+
with gr.Column():
|
| 343 |
+
with gr.Row():
|
| 344 |
+
with gr.Column():
|
| 345 |
+
video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
|
| 346 |
+
with gr.Column():
|
| 347 |
+
rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
|
| 348 |
+
with gr.Row():
|
| 349 |
+
detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
|
| 350 |
+
video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
| 351 |
+
rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
|
| 352 |
+
with gr.Column():
|
| 353 |
+
output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
|
| 354 |
+
fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
|
| 355 |
+
video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
|
| 356 |
+
btn_ = gr.Button(value="Interpolate")
|
| 357 |
+
with gr.Row():
|
| 358 |
+
with gr.Column():
|
| 359 |
+
gr.Markdown("# Settings")
|
| 360 |
+
batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
|
| 361 |
+
tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
|
| 362 |
+
gr.Markdown("## Advanced Settings")
|
| 363 |
+
minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
|
| 364 |
+
num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
|
| 365 |
+
guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
|
| 366 |
+
initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
|
| 367 |
+
with gr.Column():
|
| 368 |
+
gr.Markdown("""
|
| 369 |
+
# Reference
|
| 370 |
+
|
| 371 |
+
* Output directory: the directory to save the video.
|
| 372 |
+
* Batch size: a larger batch size makes the program faster but requires more VRAM.
|
| 373 |
+
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
|
| 374 |
+
* Advanced settings
|
| 375 |
+
* Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
|
| 376 |
+
* Number of iterations: the number of iterations of patch matching. (Default: 5)
|
| 377 |
+
* Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
|
| 378 |
+
* NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
|
| 379 |
+
""")
|
| 380 |
+
btn_.click(
|
| 381 |
+
interpolate_video,
|
| 382 |
+
inputs=[
|
| 383 |
+
video_guide_folder_,
|
| 384 |
+
rendered_keyframes_,
|
| 385 |
+
output_path_,
|
| 386 |
+
fps_,
|
| 387 |
+
batch_size_,
|
| 388 |
+
tracking_window_size_,
|
| 389 |
+
minimum_patch_size_,
|
| 390 |
+
num_iter_,
|
| 391 |
+
guide_weight_,
|
| 392 |
+
initialize_,
|
| 393 |
+
],
|
| 394 |
+
outputs=[output_path_, fps_, video_output_]
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
return [(ui_component, "FastBlend", "FastBlend_ui")]
|
diffsynth/extensions/FastBlend/cupy_kernels.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cupy as cp
|
| 2 |
+
|
| 3 |
+
remapping_kernel = cp.RawKernel(r'''
|
| 4 |
+
extern "C" __global__
|
| 5 |
+
void remap(
|
| 6 |
+
const int height,
|
| 7 |
+
const int width,
|
| 8 |
+
const int channel,
|
| 9 |
+
const int patch_size,
|
| 10 |
+
const int pad_size,
|
| 11 |
+
const float* source_style,
|
| 12 |
+
const int* nnf,
|
| 13 |
+
float* target_style
|
| 14 |
+
) {
|
| 15 |
+
const int r = (patch_size - 1) / 2;
|
| 16 |
+
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
| 17 |
+
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
| 18 |
+
if (x >= height or y >= width) return;
|
| 19 |
+
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
| 20 |
+
const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
|
| 21 |
+
const int min_px = x < r ? -x : -r;
|
| 22 |
+
const int max_px = x + r > height - 1 ? height - 1 - x : r;
|
| 23 |
+
const int min_py = y < r ? -y : -r;
|
| 24 |
+
const int max_py = y + r > width - 1 ? width - 1 - y : r;
|
| 25 |
+
int num = 0;
|
| 26 |
+
for (int px = min_px; px <= max_px; px++){
|
| 27 |
+
for (int py = min_py; py <= max_py; py++){
|
| 28 |
+
const int nid = (x + px) * width + y + py;
|
| 29 |
+
const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
|
| 30 |
+
const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
|
| 31 |
+
if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
|
| 32 |
+
const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
|
| 33 |
+
num++;
|
| 34 |
+
for (int c = 0; c < channel; c++){
|
| 35 |
+
target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
|
| 36 |
+
}
|
| 37 |
+
}
|
| 38 |
+
}
|
| 39 |
+
for (int c = 0; c < channel; c++){
|
| 40 |
+
target_style[z + pid * channel + c] /= num;
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
''', 'remap')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
patch_error_kernel = cp.RawKernel(r'''
|
| 47 |
+
extern "C" __global__
|
| 48 |
+
void patch_error(
|
| 49 |
+
const int height,
|
| 50 |
+
const int width,
|
| 51 |
+
const int channel,
|
| 52 |
+
const int patch_size,
|
| 53 |
+
const int pad_size,
|
| 54 |
+
const float* source,
|
| 55 |
+
const int* nnf,
|
| 56 |
+
const float* target,
|
| 57 |
+
float* error
|
| 58 |
+
) {
|
| 59 |
+
const int r = (patch_size - 1) / 2;
|
| 60 |
+
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
| 61 |
+
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
| 62 |
+
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
| 63 |
+
if (x >= height or y >= width) return;
|
| 64 |
+
const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
|
| 65 |
+
const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
|
| 66 |
+
float e = 0;
|
| 67 |
+
for (int px = -r; px <= r; px++){
|
| 68 |
+
for (int py = -r; py <= r; py++){
|
| 69 |
+
const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
|
| 70 |
+
const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
|
| 71 |
+
for (int c = 0; c < channel; c++){
|
| 72 |
+
const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
|
| 73 |
+
e += diff * diff;
|
| 74 |
+
}
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
error[blockIdx.z * height * width + x * width + y] = e;
|
| 78 |
+
}
|
| 79 |
+
''', 'patch_error')
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
pairwise_patch_error_kernel = cp.RawKernel(r'''
|
| 83 |
+
extern "C" __global__
|
| 84 |
+
void pairwise_patch_error(
|
| 85 |
+
const int height,
|
| 86 |
+
const int width,
|
| 87 |
+
const int channel,
|
| 88 |
+
const int patch_size,
|
| 89 |
+
const int pad_size,
|
| 90 |
+
const float* source_a,
|
| 91 |
+
const int* nnf_a,
|
| 92 |
+
const float* source_b,
|
| 93 |
+
const int* nnf_b,
|
| 94 |
+
float* error
|
| 95 |
+
) {
|
| 96 |
+
const int r = (patch_size - 1) / 2;
|
| 97 |
+
const int x = blockDim.x * blockIdx.x + threadIdx.x;
|
| 98 |
+
const int y = blockDim.y * blockIdx.y + threadIdx.y;
|
| 99 |
+
const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
|
| 100 |
+
if (x >= height or y >= width) return;
|
| 101 |
+
const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
|
| 102 |
+
const int x_a = nnf_a[z_nnf + 0];
|
| 103 |
+
const int y_a = nnf_a[z_nnf + 1];
|
| 104 |
+
const int x_b = nnf_b[z_nnf + 0];
|
| 105 |
+
const int y_b = nnf_b[z_nnf + 1];
|
| 106 |
+
float e = 0;
|
| 107 |
+
for (int px = -r; px <= r; px++){
|
| 108 |
+
for (int py = -r; py <= r; py++){
|
| 109 |
+
const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
|
| 110 |
+
const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
|
| 111 |
+
for (int c = 0; c < channel; c++){
|
| 112 |
+
const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
|
| 113 |
+
e += diff * diff;
|
| 114 |
+
}
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
error[blockIdx.z * height * width + x * width + y] = e;
|
| 118 |
+
}
|
| 119 |
+
''', 'pairwise_patch_error')
|
diffsynth/extensions/FastBlend/data.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import imageio, os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def read_video(file_name):
|
| 7 |
+
reader = imageio.get_reader(file_name)
|
| 8 |
+
video = []
|
| 9 |
+
for frame in reader:
|
| 10 |
+
frame = np.array(frame)
|
| 11 |
+
video.append(frame)
|
| 12 |
+
reader.close()
|
| 13 |
+
return video
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_video_fps(file_name):
|
| 17 |
+
reader = imageio.get_reader(file_name)
|
| 18 |
+
fps = reader.get_meta_data()["fps"]
|
| 19 |
+
reader.close()
|
| 20 |
+
return fps
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def save_video(frames_path, video_path, num_frames, fps):
|
| 24 |
+
writer = imageio.get_writer(video_path, fps=fps, quality=9)
|
| 25 |
+
for i in range(num_frames):
|
| 26 |
+
frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
|
| 27 |
+
writer.append_data(frame)
|
| 28 |
+
writer.close()
|
| 29 |
+
return video_path
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class LowMemoryVideo:
|
| 33 |
+
def __init__(self, file_name):
|
| 34 |
+
self.reader = imageio.get_reader(file_name)
|
| 35 |
+
|
| 36 |
+
def __len__(self):
|
| 37 |
+
return self.reader.count_frames()
|
| 38 |
+
|
| 39 |
+
def __getitem__(self, item):
|
| 40 |
+
return np.array(self.reader.get_data(item))
|
| 41 |
+
|
| 42 |
+
def __del__(self):
|
| 43 |
+
self.reader.close()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def split_file_name(file_name):
|
| 47 |
+
result = []
|
| 48 |
+
number = -1
|
| 49 |
+
for i in file_name:
|
| 50 |
+
if ord(i)>=ord("0") and ord(i)<=ord("9"):
|
| 51 |
+
if number == -1:
|
| 52 |
+
number = 0
|
| 53 |
+
number = number*10 + ord(i) - ord("0")
|
| 54 |
+
else:
|
| 55 |
+
if number != -1:
|
| 56 |
+
result.append(number)
|
| 57 |
+
number = -1
|
| 58 |
+
result.append(i)
|
| 59 |
+
if number != -1:
|
| 60 |
+
result.append(number)
|
| 61 |
+
result = tuple(result)
|
| 62 |
+
return result
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def search_for_images(folder):
|
| 66 |
+
file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
|
| 67 |
+
file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
|
| 68 |
+
file_list = [i[1] for i in sorted(file_list)]
|
| 69 |
+
file_list = [os.path.join(folder, i) for i in file_list]
|
| 70 |
+
return file_list
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def read_images(folder):
|
| 74 |
+
file_list = search_for_images(folder)
|
| 75 |
+
frames = [np.array(Image.open(i)) for i in file_list]
|
| 76 |
+
return frames
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class LowMemoryImageFolder:
|
| 80 |
+
def __init__(self, folder, file_list=None):
|
| 81 |
+
if file_list is None:
|
| 82 |
+
self.file_list = search_for_images(folder)
|
| 83 |
+
else:
|
| 84 |
+
self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
|
| 85 |
+
|
| 86 |
+
def __len__(self):
|
| 87 |
+
return len(self.file_list)
|
| 88 |
+
|
| 89 |
+
def __getitem__(self, item):
|
| 90 |
+
return np.array(Image.open(self.file_list[item]))
|
| 91 |
+
|
| 92 |
+
def __del__(self):
|
| 93 |
+
pass
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class VideoData:
|
| 97 |
+
def __init__(self, video_file, image_folder, **kwargs):
|
| 98 |
+
if video_file is not None:
|
| 99 |
+
self.data_type = "video"
|
| 100 |
+
self.data = LowMemoryVideo(video_file, **kwargs)
|
| 101 |
+
elif image_folder is not None:
|
| 102 |
+
self.data_type = "images"
|
| 103 |
+
self.data = LowMemoryImageFolder(image_folder, **kwargs)
|
| 104 |
+
else:
|
| 105 |
+
raise ValueError("Cannot open video or image folder")
|
| 106 |
+
self.length = None
|
| 107 |
+
self.height = None
|
| 108 |
+
self.width = None
|
| 109 |
+
|
| 110 |
+
def raw_data(self):
|
| 111 |
+
frames = []
|
| 112 |
+
for i in range(self.__len__()):
|
| 113 |
+
frames.append(self.__getitem__(i))
|
| 114 |
+
return frames
|
| 115 |
+
|
| 116 |
+
def set_length(self, length):
|
| 117 |
+
self.length = length
|
| 118 |
+
|
| 119 |
+
def set_shape(self, height, width):
|
| 120 |
+
self.height = height
|
| 121 |
+
self.width = width
|
| 122 |
+
|
| 123 |
+
def __len__(self):
|
| 124 |
+
if self.length is None:
|
| 125 |
+
return len(self.data)
|
| 126 |
+
else:
|
| 127 |
+
return self.length
|
| 128 |
+
|
| 129 |
+
def shape(self):
|
| 130 |
+
if self.height is not None and self.width is not None:
|
| 131 |
+
return self.height, self.width
|
| 132 |
+
else:
|
| 133 |
+
height, width, _ = self.__getitem__(0).shape
|
| 134 |
+
return height, width
|
| 135 |
+
|
| 136 |
+
def __getitem__(self, item):
|
| 137 |
+
frame = self.data.__getitem__(item)
|
| 138 |
+
height, width, _ = frame.shape
|
| 139 |
+
if self.height is not None and self.width is not None:
|
| 140 |
+
if self.height != height or self.width != width:
|
| 141 |
+
frame = Image.fromarray(frame).resize((self.width, self.height))
|
| 142 |
+
frame = np.array(frame)
|
| 143 |
+
return frame
|
| 144 |
+
|
| 145 |
+
def __del__(self):
|
| 146 |
+
pass
|
diffsynth/extensions/FastBlend/patch_match.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cupy as cp
|
| 4 |
+
import cv2
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PatchMatcher:
|
| 8 |
+
def __init__(
|
| 9 |
+
self, height, width, channel, minimum_patch_size,
|
| 10 |
+
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
| 11 |
+
random_search_steps=3, random_search_range=4,
|
| 12 |
+
use_mean_target_style=False, use_pairwise_patch_error=False,
|
| 13 |
+
tracking_window_size=0
|
| 14 |
+
):
|
| 15 |
+
self.height = height
|
| 16 |
+
self.width = width
|
| 17 |
+
self.channel = channel
|
| 18 |
+
self.minimum_patch_size = minimum_patch_size
|
| 19 |
+
self.threads_per_block = threads_per_block
|
| 20 |
+
self.num_iter = num_iter
|
| 21 |
+
self.gpu_id = gpu_id
|
| 22 |
+
self.guide_weight = guide_weight
|
| 23 |
+
self.random_search_steps = random_search_steps
|
| 24 |
+
self.random_search_range = random_search_range
|
| 25 |
+
self.use_mean_target_style = use_mean_target_style
|
| 26 |
+
self.use_pairwise_patch_error = use_pairwise_patch_error
|
| 27 |
+
self.tracking_window_size = tracking_window_size
|
| 28 |
+
|
| 29 |
+
self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
|
| 30 |
+
self.pad_size = self.patch_size_list[0] // 2
|
| 31 |
+
self.grid = (
|
| 32 |
+
(height + threads_per_block - 1) // threads_per_block,
|
| 33 |
+
(width + threads_per_block - 1) // threads_per_block
|
| 34 |
+
)
|
| 35 |
+
self.block = (threads_per_block, threads_per_block)
|
| 36 |
+
|
| 37 |
+
def pad_image(self, image):
|
| 38 |
+
return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
|
| 39 |
+
|
| 40 |
+
def unpad_image(self, image):
|
| 41 |
+
return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
|
| 42 |
+
|
| 43 |
+
def apply_nnf_to_image(self, nnf, source):
|
| 44 |
+
batch_size = source.shape[0]
|
| 45 |
+
target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
|
| 46 |
+
remapping_kernel(
|
| 47 |
+
self.grid + (batch_size,),
|
| 48 |
+
self.block,
|
| 49 |
+
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
|
| 50 |
+
)
|
| 51 |
+
return target
|
| 52 |
+
|
| 53 |
+
def get_patch_error(self, source, nnf, target):
|
| 54 |
+
batch_size = source.shape[0]
|
| 55 |
+
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
| 56 |
+
patch_error_kernel(
|
| 57 |
+
self.grid + (batch_size,),
|
| 58 |
+
self.block,
|
| 59 |
+
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
|
| 60 |
+
)
|
| 61 |
+
return error
|
| 62 |
+
|
| 63 |
+
def get_pairwise_patch_error(self, source, nnf):
|
| 64 |
+
batch_size = source.shape[0]//2
|
| 65 |
+
error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
|
| 66 |
+
source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
|
| 67 |
+
source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
|
| 68 |
+
pairwise_patch_error_kernel(
|
| 69 |
+
self.grid + (batch_size,),
|
| 70 |
+
self.block,
|
| 71 |
+
(self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
|
| 72 |
+
)
|
| 73 |
+
error = error.repeat(2, axis=0)
|
| 74 |
+
return error
|
| 75 |
+
|
| 76 |
+
def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
|
| 77 |
+
error_guide = self.get_patch_error(source_guide, nnf, target_guide)
|
| 78 |
+
if self.use_mean_target_style:
|
| 79 |
+
target_style = self.apply_nnf_to_image(nnf, source_style)
|
| 80 |
+
target_style = target_style.mean(axis=0, keepdims=True)
|
| 81 |
+
target_style = target_style.repeat(source_guide.shape[0], axis=0)
|
| 82 |
+
if self.use_pairwise_patch_error:
|
| 83 |
+
error_style = self.get_pairwise_patch_error(source_style, nnf)
|
| 84 |
+
else:
|
| 85 |
+
error_style = self.get_patch_error(source_style, nnf, target_style)
|
| 86 |
+
error = error_guide * self.guide_weight + error_style
|
| 87 |
+
return error
|
| 88 |
+
|
| 89 |
+
def clamp_bound(self, nnf):
|
| 90 |
+
nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
|
| 91 |
+
nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
|
| 92 |
+
return nnf
|
| 93 |
+
|
| 94 |
+
def random_step(self, nnf, r):
|
| 95 |
+
batch_size = nnf.shape[0]
|
| 96 |
+
step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
|
| 97 |
+
upd_nnf = self.clamp_bound(nnf + step)
|
| 98 |
+
return upd_nnf
|
| 99 |
+
|
| 100 |
+
def neighboor_step(self, nnf, d):
|
| 101 |
+
if d==0:
|
| 102 |
+
upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
|
| 103 |
+
upd_nnf[:, :, :, 0] += 1
|
| 104 |
+
elif d==1:
|
| 105 |
+
upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
|
| 106 |
+
upd_nnf[:, :, :, 1] += 1
|
| 107 |
+
elif d==2:
|
| 108 |
+
upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
|
| 109 |
+
upd_nnf[:, :, :, 0] -= 1
|
| 110 |
+
elif d==3:
|
| 111 |
+
upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
|
| 112 |
+
upd_nnf[:, :, :, 1] -= 1
|
| 113 |
+
upd_nnf = self.clamp_bound(upd_nnf)
|
| 114 |
+
return upd_nnf
|
| 115 |
+
|
| 116 |
+
def shift_nnf(self, nnf, d):
|
| 117 |
+
if d>0:
|
| 118 |
+
d = min(nnf.shape[0], d)
|
| 119 |
+
upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
| 120 |
+
else:
|
| 121 |
+
d = max(-nnf.shape[0], d)
|
| 122 |
+
upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
| 123 |
+
return upd_nnf
|
| 124 |
+
|
| 125 |
+
def track_step(self, nnf, d):
|
| 126 |
+
if self.use_pairwise_patch_error:
|
| 127 |
+
upd_nnf = cp.zeros_like(nnf)
|
| 128 |
+
upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
|
| 129 |
+
upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
|
| 130 |
+
else:
|
| 131 |
+
upd_nnf = self.shift_nnf(nnf, d)
|
| 132 |
+
return upd_nnf
|
| 133 |
+
|
| 134 |
+
def C(self, n, m):
|
| 135 |
+
# not used
|
| 136 |
+
c = 1
|
| 137 |
+
for i in range(1, n+1):
|
| 138 |
+
c *= i
|
| 139 |
+
for i in range(1, m+1):
|
| 140 |
+
c //= i
|
| 141 |
+
for i in range(1, n-m+1):
|
| 142 |
+
c //= i
|
| 143 |
+
return c
|
| 144 |
+
|
| 145 |
+
def bezier_step(self, nnf, r):
|
| 146 |
+
# not used
|
| 147 |
+
n = r * 2 - 1
|
| 148 |
+
upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
|
| 149 |
+
for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
|
| 150 |
+
if d>0:
|
| 151 |
+
ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
|
| 152 |
+
elif d<0:
|
| 153 |
+
ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
|
| 154 |
+
upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
|
| 155 |
+
upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
|
| 156 |
+
return upd_nnf
|
| 157 |
+
|
| 158 |
+
def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
|
| 159 |
+
upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
|
| 160 |
+
upd_idx = (upd_err < err)
|
| 161 |
+
nnf[upd_idx] = upd_nnf[upd_idx]
|
| 162 |
+
err[upd_idx] = upd_err[upd_idx]
|
| 163 |
+
return nnf, err
|
| 164 |
+
|
| 165 |
+
def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
| 166 |
+
for d in cp.random.permutation(4):
|
| 167 |
+
upd_nnf = self.neighboor_step(nnf, d)
|
| 168 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
| 169 |
+
return nnf, err
|
| 170 |
+
|
| 171 |
+
def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
| 172 |
+
for i in range(self.random_search_steps):
|
| 173 |
+
upd_nnf = self.random_step(nnf, self.random_search_range)
|
| 174 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
| 175 |
+
return nnf, err
|
| 176 |
+
|
| 177 |
+
def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
| 178 |
+
for d in range(1, self.tracking_window_size + 1):
|
| 179 |
+
upd_nnf = self.track_step(nnf, d)
|
| 180 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
| 181 |
+
upd_nnf = self.track_step(nnf, -d)
|
| 182 |
+
nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
|
| 183 |
+
return nnf, err
|
| 184 |
+
|
| 185 |
+
def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
|
| 186 |
+
nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
|
| 187 |
+
nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
|
| 188 |
+
nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
|
| 189 |
+
return nnf, err
|
| 190 |
+
|
| 191 |
+
def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
|
| 192 |
+
with cp.cuda.Device(self.gpu_id):
|
| 193 |
+
source_guide = self.pad_image(source_guide)
|
| 194 |
+
target_guide = self.pad_image(target_guide)
|
| 195 |
+
source_style = self.pad_image(source_style)
|
| 196 |
+
for it in range(self.num_iter):
|
| 197 |
+
self.patch_size = self.patch_size_list[it]
|
| 198 |
+
target_style = self.apply_nnf_to_image(nnf, source_style)
|
| 199 |
+
err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
|
| 200 |
+
nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
|
| 201 |
+
target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
|
| 202 |
+
return nnf, target_style
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class PyramidPatchMatcher:
|
| 206 |
+
def __init__(
|
| 207 |
+
self, image_height, image_width, channel, minimum_patch_size,
|
| 208 |
+
threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
|
| 209 |
+
use_mean_target_style=False, use_pairwise_patch_error=False,
|
| 210 |
+
tracking_window_size=0,
|
| 211 |
+
initialize="identity"
|
| 212 |
+
):
|
| 213 |
+
maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
|
| 214 |
+
self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
|
| 215 |
+
self.pyramid_heights = []
|
| 216 |
+
self.pyramid_widths = []
|
| 217 |
+
self.patch_matchers = []
|
| 218 |
+
self.minimum_patch_size = minimum_patch_size
|
| 219 |
+
self.num_iter = num_iter
|
| 220 |
+
self.gpu_id = gpu_id
|
| 221 |
+
self.initialize = initialize
|
| 222 |
+
for level in range(self.pyramid_level):
|
| 223 |
+
height = image_height//(2**(self.pyramid_level - 1 - level))
|
| 224 |
+
width = image_width//(2**(self.pyramid_level - 1 - level))
|
| 225 |
+
self.pyramid_heights.append(height)
|
| 226 |
+
self.pyramid_widths.append(width)
|
| 227 |
+
self.patch_matchers.append(PatchMatcher(
|
| 228 |
+
height, width, channel, minimum_patch_size=minimum_patch_size,
|
| 229 |
+
threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
|
| 230 |
+
use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
|
| 231 |
+
tracking_window_size=tracking_window_size
|
| 232 |
+
))
|
| 233 |
+
|
| 234 |
+
def resample_image(self, images, level):
|
| 235 |
+
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
| 236 |
+
images = images.get()
|
| 237 |
+
images_resample = []
|
| 238 |
+
for image in images:
|
| 239 |
+
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
|
| 240 |
+
images_resample.append(image_resample)
|
| 241 |
+
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
|
| 242 |
+
return images_resample
|
| 243 |
+
|
| 244 |
+
def initialize_nnf(self, batch_size):
|
| 245 |
+
if self.initialize == "random":
|
| 246 |
+
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
| 247 |
+
nnf = cp.stack([
|
| 248 |
+
cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
|
| 249 |
+
cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
|
| 250 |
+
], axis=3)
|
| 251 |
+
elif self.initialize == "identity":
|
| 252 |
+
height, width = self.pyramid_heights[0], self.pyramid_widths[0]
|
| 253 |
+
nnf = cp.stack([
|
| 254 |
+
cp.repeat(cp.arange(height), width).reshape(height, width),
|
| 255 |
+
cp.tile(cp.arange(width), height).reshape(height, width)
|
| 256 |
+
], axis=2)
|
| 257 |
+
nnf = cp.stack([nnf] * batch_size)
|
| 258 |
+
else:
|
| 259 |
+
raise NotImplementedError()
|
| 260 |
+
return nnf
|
| 261 |
+
|
| 262 |
+
def update_nnf(self, nnf, level):
|
| 263 |
+
# upscale
|
| 264 |
+
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
| 265 |
+
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
|
| 266 |
+
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
|
| 267 |
+
# check if scale is 2
|
| 268 |
+
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
| 269 |
+
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
| 270 |
+
nnf = nnf.get().astype(np.float32)
|
| 271 |
+
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
|
| 272 |
+
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
|
| 273 |
+
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
| 274 |
+
return nnf
|
| 275 |
+
|
| 276 |
+
def apply_nnf_to_image(self, nnf, image):
|
| 277 |
+
with cp.cuda.Device(self.gpu_id):
|
| 278 |
+
image = self.patch_matchers[-1].pad_image(image)
|
| 279 |
+
image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
|
| 280 |
+
return image
|
| 281 |
+
|
| 282 |
+
def estimate_nnf(self, source_guide, target_guide, source_style):
|
| 283 |
+
with cp.cuda.Device(self.gpu_id):
|
| 284 |
+
if not isinstance(source_guide, cp.ndarray):
|
| 285 |
+
source_guide = cp.array(source_guide, dtype=cp.float32)
|
| 286 |
+
if not isinstance(target_guide, cp.ndarray):
|
| 287 |
+
target_guide = cp.array(target_guide, dtype=cp.float32)
|
| 288 |
+
if not isinstance(source_style, cp.ndarray):
|
| 289 |
+
source_style = cp.array(source_style, dtype=cp.float32)
|
| 290 |
+
for level in range(self.pyramid_level):
|
| 291 |
+
nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
|
| 292 |
+
source_guide_ = self.resample_image(source_guide, level)
|
| 293 |
+
target_guide_ = self.resample_image(target_guide, level)
|
| 294 |
+
source_style_ = self.resample_image(source_style, level)
|
| 295 |
+
nnf, target_style = self.patch_matchers[level].estimate_nnf(
|
| 296 |
+
source_guide_, target_guide_, source_style_, nnf
|
| 297 |
+
)
|
| 298 |
+
return nnf.get(), target_style.get()
|
diffsynth/extensions/FastBlend/runners/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .accurate import AccurateModeRunner
|
| 2 |
+
from .fast import FastModeRunner
|
| 3 |
+
from .balanced import BalancedModeRunner
|
| 4 |
+
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
|
diffsynth/extensions/FastBlend/runners/accurate.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..patch_match import PyramidPatchMatcher
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class AccurateModeRunner:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
|
| 13 |
+
patch_match_engine = PyramidPatchMatcher(
|
| 14 |
+
image_height=frames_style[0].shape[0],
|
| 15 |
+
image_width=frames_style[0].shape[1],
|
| 16 |
+
channel=3,
|
| 17 |
+
use_mean_target_style=True,
|
| 18 |
+
**ebsynth_config
|
| 19 |
+
)
|
| 20 |
+
# run
|
| 21 |
+
n = len(frames_style)
|
| 22 |
+
for target in tqdm(range(n), desc=desc):
|
| 23 |
+
l, r = max(target - window_size, 0), min(target + window_size + 1, n)
|
| 24 |
+
remapped_frames = []
|
| 25 |
+
for i in range(l, r, batch_size):
|
| 26 |
+
j = min(i + batch_size, r)
|
| 27 |
+
source_guide = np.stack([frames_guide[source] for source in range(i, j)])
|
| 28 |
+
target_guide = np.stack([frames_guide[target]] * (j - i))
|
| 29 |
+
source_style = np.stack([frames_style[source] for source in range(i, j)])
|
| 30 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
| 31 |
+
remapped_frames.append(target_style)
|
| 32 |
+
frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
|
| 33 |
+
frame = frame.clip(0, 255).astype("uint8")
|
| 34 |
+
if save_path is not None:
|
| 35 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
diffsynth/extensions/FastBlend/runners/balanced.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..patch_match import PyramidPatchMatcher
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BalancedModeRunner:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
|
| 13 |
+
patch_match_engine = PyramidPatchMatcher(
|
| 14 |
+
image_height=frames_style[0].shape[0],
|
| 15 |
+
image_width=frames_style[0].shape[1],
|
| 16 |
+
channel=3,
|
| 17 |
+
**ebsynth_config
|
| 18 |
+
)
|
| 19 |
+
# tasks
|
| 20 |
+
n = len(frames_style)
|
| 21 |
+
tasks = []
|
| 22 |
+
for target in range(n):
|
| 23 |
+
for source in range(target - window_size, target + window_size + 1):
|
| 24 |
+
if source >= 0 and source < n and source != target:
|
| 25 |
+
tasks.append((source, target))
|
| 26 |
+
# run
|
| 27 |
+
frames = [(None, 1) for i in range(n)]
|
| 28 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
| 29 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
| 30 |
+
source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
|
| 31 |
+
target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
|
| 32 |
+
source_style = np.stack([frames_style[source] for source, target in tasks_batch])
|
| 33 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
| 34 |
+
for (source, target), result in zip(tasks_batch, target_style):
|
| 35 |
+
frame, weight = frames[target]
|
| 36 |
+
if frame is None:
|
| 37 |
+
frame = frames_style[target]
|
| 38 |
+
frames[target] = (
|
| 39 |
+
frame * (weight / (weight + 1)) + result / (weight + 1),
|
| 40 |
+
weight + 1
|
| 41 |
+
)
|
| 42 |
+
if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
|
| 43 |
+
frame = frame.clip(0, 255).astype("uint8")
|
| 44 |
+
if save_path is not None:
|
| 45 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
| 46 |
+
frames[target] = (None, 1)
|
diffsynth/extensions/FastBlend/runners/fast.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..patch_match import PyramidPatchMatcher
|
| 2 |
+
import functools, os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TableManager:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def task_list(self, n):
|
| 13 |
+
tasks = []
|
| 14 |
+
max_level = 1
|
| 15 |
+
while (1<<max_level)<=n:
|
| 16 |
+
max_level += 1
|
| 17 |
+
for i in range(n):
|
| 18 |
+
j = i
|
| 19 |
+
for level in range(max_level):
|
| 20 |
+
if i&(1<<level):
|
| 21 |
+
continue
|
| 22 |
+
j |= 1<<level
|
| 23 |
+
if j>=n:
|
| 24 |
+
break
|
| 25 |
+
meta_data = {
|
| 26 |
+
"source": i,
|
| 27 |
+
"target": j,
|
| 28 |
+
"level": level + 1
|
| 29 |
+
}
|
| 30 |
+
tasks.append(meta_data)
|
| 31 |
+
tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
|
| 32 |
+
return tasks
|
| 33 |
+
|
| 34 |
+
def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
|
| 35 |
+
n = len(frames_guide)
|
| 36 |
+
tasks = self.task_list(n)
|
| 37 |
+
remapping_table = [[(frames_style[i], 1)] for i in range(n)]
|
| 38 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
| 39 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
| 40 |
+
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
| 41 |
+
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
| 42 |
+
source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
|
| 43 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
| 44 |
+
for task, result in zip(tasks_batch, target_style):
|
| 45 |
+
target, level = task["target"], task["level"]
|
| 46 |
+
if len(remapping_table[target])==level:
|
| 47 |
+
remapping_table[target].append((result, 1))
|
| 48 |
+
else:
|
| 49 |
+
frame, weight = remapping_table[target][level]
|
| 50 |
+
remapping_table[target][level] = (
|
| 51 |
+
frame * (weight / (weight + 1)) + result / (weight + 1),
|
| 52 |
+
weight + 1
|
| 53 |
+
)
|
| 54 |
+
return remapping_table
|
| 55 |
+
|
| 56 |
+
def remapping_table_to_blending_table(self, table):
|
| 57 |
+
for i in range(len(table)):
|
| 58 |
+
for j in range(1, len(table[i])):
|
| 59 |
+
frame_1, weight_1 = table[i][j-1]
|
| 60 |
+
frame_2, weight_2 = table[i][j]
|
| 61 |
+
frame = (frame_1 + frame_2) / 2
|
| 62 |
+
weight = weight_1 + weight_2
|
| 63 |
+
table[i][j] = (frame, weight)
|
| 64 |
+
return table
|
| 65 |
+
|
| 66 |
+
def tree_query(self, leftbound, rightbound):
|
| 67 |
+
node_list = []
|
| 68 |
+
node_index = rightbound
|
| 69 |
+
while node_index>=leftbound:
|
| 70 |
+
node_level = 0
|
| 71 |
+
while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
|
| 72 |
+
node_level += 1
|
| 73 |
+
node_list.append((node_index, node_level))
|
| 74 |
+
node_index -= 1<<node_level
|
| 75 |
+
return node_list
|
| 76 |
+
|
| 77 |
+
def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
|
| 78 |
+
n = len(blending_table)
|
| 79 |
+
tasks = []
|
| 80 |
+
frames_result = []
|
| 81 |
+
for target in range(n):
|
| 82 |
+
node_list = self.tree_query(max(target-window_size, 0), target)
|
| 83 |
+
for source, level in node_list:
|
| 84 |
+
if source!=target:
|
| 85 |
+
meta_data = {
|
| 86 |
+
"source": source,
|
| 87 |
+
"target": target,
|
| 88 |
+
"level": level
|
| 89 |
+
}
|
| 90 |
+
tasks.append(meta_data)
|
| 91 |
+
else:
|
| 92 |
+
frames_result.append(blending_table[target][level])
|
| 93 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
|
| 94 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
| 95 |
+
source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
|
| 96 |
+
target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
|
| 97 |
+
source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
|
| 98 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
| 99 |
+
for task, frame_2 in zip(tasks_batch, target_style):
|
| 100 |
+
source, target, level = task["source"], task["target"], task["level"]
|
| 101 |
+
frame_1, weight_1 = frames_result[target]
|
| 102 |
+
weight_2 = blending_table[source][level][1]
|
| 103 |
+
weight = weight_1 + weight_2
|
| 104 |
+
frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
|
| 105 |
+
frames_result[target] = (frame, weight)
|
| 106 |
+
return frames_result
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class FastModeRunner:
|
| 110 |
+
def __init__(self):
|
| 111 |
+
pass
|
| 112 |
+
|
| 113 |
+
def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
|
| 114 |
+
frames_guide = frames_guide.raw_data()
|
| 115 |
+
frames_style = frames_style.raw_data()
|
| 116 |
+
table_manager = TableManager()
|
| 117 |
+
patch_match_engine = PyramidPatchMatcher(
|
| 118 |
+
image_height=frames_style[0].shape[0],
|
| 119 |
+
image_width=frames_style[0].shape[1],
|
| 120 |
+
channel=3,
|
| 121 |
+
**ebsynth_config
|
| 122 |
+
)
|
| 123 |
+
# left part
|
| 124 |
+
table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
|
| 125 |
+
table_l = table_manager.remapping_table_to_blending_table(table_l)
|
| 126 |
+
table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
|
| 127 |
+
# right part
|
| 128 |
+
table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
|
| 129 |
+
table_r = table_manager.remapping_table_to_blending_table(table_r)
|
| 130 |
+
table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
|
| 131 |
+
# merge
|
| 132 |
+
frames = []
|
| 133 |
+
for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
|
| 134 |
+
weight_m = -1
|
| 135 |
+
weight = weight_l + weight_m + weight_r
|
| 136 |
+
frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
|
| 137 |
+
frames.append(frame)
|
| 138 |
+
frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
|
| 139 |
+
if save_path is not None:
|
| 140 |
+
for target, frame in enumerate(frames):
|
| 141 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
|
diffsynth/extensions/FastBlend/runners/interpolation.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..patch_match import PyramidPatchMatcher
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class InterpolationModeRunner:
|
| 9 |
+
def __init__(self):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def get_index_dict(self, index_style):
|
| 13 |
+
index_dict = {}
|
| 14 |
+
for i, index in enumerate(index_style):
|
| 15 |
+
index_dict[index] = i
|
| 16 |
+
return index_dict
|
| 17 |
+
|
| 18 |
+
def get_weight(self, l, m, r):
|
| 19 |
+
weight_l, weight_r = abs(m - r), abs(m - l)
|
| 20 |
+
if weight_l + weight_r == 0:
|
| 21 |
+
weight_l, weight_r = 0.5, 0.5
|
| 22 |
+
else:
|
| 23 |
+
weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
|
| 24 |
+
return weight_l, weight_r
|
| 25 |
+
|
| 26 |
+
def get_task_group(self, index_style, n):
|
| 27 |
+
task_group = []
|
| 28 |
+
index_style = sorted(index_style)
|
| 29 |
+
# first frame
|
| 30 |
+
if index_style[0]>0:
|
| 31 |
+
tasks = []
|
| 32 |
+
for m in range(index_style[0]):
|
| 33 |
+
tasks.append((index_style[0], m, index_style[0]))
|
| 34 |
+
task_group.append(tasks)
|
| 35 |
+
# middle frames
|
| 36 |
+
for l, r in zip(index_style[:-1], index_style[1:]):
|
| 37 |
+
tasks = []
|
| 38 |
+
for m in range(l, r):
|
| 39 |
+
tasks.append((l, m, r))
|
| 40 |
+
task_group.append(tasks)
|
| 41 |
+
# last frame
|
| 42 |
+
tasks = []
|
| 43 |
+
for m in range(index_style[-1], n):
|
| 44 |
+
tasks.append((index_style[-1], m, index_style[-1]))
|
| 45 |
+
task_group.append(tasks)
|
| 46 |
+
return task_group
|
| 47 |
+
|
| 48 |
+
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
| 49 |
+
patch_match_engine = PyramidPatchMatcher(
|
| 50 |
+
image_height=frames_style[0].shape[0],
|
| 51 |
+
image_width=frames_style[0].shape[1],
|
| 52 |
+
channel=3,
|
| 53 |
+
use_mean_target_style=False,
|
| 54 |
+
use_pairwise_patch_error=True,
|
| 55 |
+
**ebsynth_config
|
| 56 |
+
)
|
| 57 |
+
# task
|
| 58 |
+
index_dict = self.get_index_dict(index_style)
|
| 59 |
+
task_group = self.get_task_group(index_style, len(frames_guide))
|
| 60 |
+
# run
|
| 61 |
+
for tasks in task_group:
|
| 62 |
+
index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
|
| 63 |
+
for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
|
| 64 |
+
tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
|
| 65 |
+
source_guide, target_guide, source_style = [], [], []
|
| 66 |
+
for l, m, r in tasks_batch:
|
| 67 |
+
# l -> m
|
| 68 |
+
source_guide.append(frames_guide[l])
|
| 69 |
+
target_guide.append(frames_guide[m])
|
| 70 |
+
source_style.append(frames_style[index_dict[l]])
|
| 71 |
+
# r -> m
|
| 72 |
+
source_guide.append(frames_guide[r])
|
| 73 |
+
target_guide.append(frames_guide[m])
|
| 74 |
+
source_style.append(frames_style[index_dict[r]])
|
| 75 |
+
source_guide = np.stack(source_guide)
|
| 76 |
+
target_guide = np.stack(target_guide)
|
| 77 |
+
source_style = np.stack(source_style)
|
| 78 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
| 79 |
+
if save_path is not None:
|
| 80 |
+
for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
|
| 81 |
+
weight_l, weight_r = self.get_weight(l, m, r)
|
| 82 |
+
frame = frame_l * weight_l + frame_r * weight_r
|
| 83 |
+
frame = frame.clip(0, 255).astype("uint8")
|
| 84 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class InterpolationModeSingleFrameRunner:
|
| 88 |
+
def __init__(self):
|
| 89 |
+
pass
|
| 90 |
+
|
| 91 |
+
def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
|
| 92 |
+
# check input
|
| 93 |
+
tracking_window_size = ebsynth_config["tracking_window_size"]
|
| 94 |
+
if tracking_window_size * 2 >= batch_size:
|
| 95 |
+
raise ValueError("batch_size should be larger than track_window_size * 2")
|
| 96 |
+
frame_style = frames_style[0]
|
| 97 |
+
frame_guide = frames_guide[index_style[0]]
|
| 98 |
+
patch_match_engine = PyramidPatchMatcher(
|
| 99 |
+
image_height=frame_style.shape[0],
|
| 100 |
+
image_width=frame_style.shape[1],
|
| 101 |
+
channel=3,
|
| 102 |
+
**ebsynth_config
|
| 103 |
+
)
|
| 104 |
+
# run
|
| 105 |
+
frame_id, n = 0, len(frames_guide)
|
| 106 |
+
for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
|
| 107 |
+
if i + batch_size > n:
|
| 108 |
+
l, r = max(n - batch_size, 0), n
|
| 109 |
+
else:
|
| 110 |
+
l, r = i, i + batch_size
|
| 111 |
+
source_guide = np.stack([frame_guide] * (r-l))
|
| 112 |
+
target_guide = np.stack([frames_guide[i] for i in range(l, r)])
|
| 113 |
+
source_style = np.stack([frame_style] * (r-l))
|
| 114 |
+
_, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
|
| 115 |
+
for i, frame in zip(range(l, r), target_style):
|
| 116 |
+
if i==frame_id:
|
| 117 |
+
frame = frame.clip(0, 255).astype("uint8")
|
| 118 |
+
Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
|
| 119 |
+
frame_id += 1
|
| 120 |
+
if r < n and r-frame_id <= tracking_window_size:
|
| 121 |
+
break
|
diffsynth/extensions/ImageQualityMetric/BLIP/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .blip_pretrain import *
|
diffsynth/extensions/ImageQualityMetric/BLIP/blip.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
| 3 |
+
'''
|
| 4 |
+
|
| 5 |
+
import warnings
|
| 6 |
+
warnings.filterwarnings("ignore")
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
from timm.models.hub import download_cached_file
|
| 12 |
+
from transformers import BertTokenizer
|
| 13 |
+
from .vit import VisionTransformer, interpolate_pos_embed
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def default_bert():
|
| 17 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 18 |
+
project_root = os.path.abspath(os.path.join(current_dir, '../../../../'))
|
| 19 |
+
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
| 20 |
+
return os.path.join(model_path, "bert-base-uncased")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def init_tokenizer(bert_model_path):
|
| 24 |
+
tokenizer = BertTokenizer.from_pretrained(bert_model_path)
|
| 25 |
+
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
| 26 |
+
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
| 27 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
| 28 |
+
return tokenizer
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
| 32 |
+
|
| 33 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
| 34 |
+
if vit=='base':
|
| 35 |
+
vision_width = 768
|
| 36 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
| 37 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 38 |
+
drop_path_rate=0 or drop_path_rate
|
| 39 |
+
)
|
| 40 |
+
elif vit=='large':
|
| 41 |
+
vision_width = 1024
|
| 42 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
| 43 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 44 |
+
drop_path_rate=0.1 or drop_path_rate
|
| 45 |
+
)
|
| 46 |
+
return visual_encoder, vision_width
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def is_url(url_or_filename):
|
| 50 |
+
parsed = urlparse(url_or_filename)
|
| 51 |
+
return parsed.scheme in ("http", "https")
|
| 52 |
+
|
| 53 |
+
def load_checkpoint(model,url_or_filename):
|
| 54 |
+
if is_url(url_or_filename):
|
| 55 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
| 56 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
| 57 |
+
elif os.path.isfile(url_or_filename):
|
| 58 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
| 59 |
+
else:
|
| 60 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
| 61 |
+
|
| 62 |
+
state_dict = checkpoint['model']
|
| 63 |
+
|
| 64 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
| 65 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
| 66 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
| 67 |
+
model.visual_encoder_m)
|
| 68 |
+
for key in model.state_dict().keys():
|
| 69 |
+
if key in state_dict.keys():
|
| 70 |
+
if state_dict[key].shape!=model.state_dict()[key].shape:
|
| 71 |
+
print(key, ": ", state_dict[key].shape, ', ', model.state_dict()[key].shape)
|
| 72 |
+
del state_dict[key]
|
| 73 |
+
|
| 74 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
| 75 |
+
print('load checkpoint from %s'%url_or_filename)
|
| 76 |
+
return model,msg
|
| 77 |
+
|
diffsynth/extensions/ImageQualityMetric/BLIP/blip_pretrain.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
| 3 |
+
'''
|
| 4 |
+
|
| 5 |
+
import transformers
|
| 6 |
+
transformers.logging.set_verbosity_error()
|
| 7 |
+
|
| 8 |
+
from torch import nn
|
| 9 |
+
import os
|
| 10 |
+
from .med import BertConfig, BertModel
|
| 11 |
+
from .blip import create_vit, init_tokenizer
|
| 12 |
+
|
| 13 |
+
class BLIP_Pretrain(nn.Module):
|
| 14 |
+
def __init__(self,
|
| 15 |
+
med_config = "med_config.json",
|
| 16 |
+
image_size = 224,
|
| 17 |
+
vit = 'base',
|
| 18 |
+
vit_grad_ckpt = False,
|
| 19 |
+
vit_ckpt_layer = 0,
|
| 20 |
+
embed_dim = 256,
|
| 21 |
+
queue_size = 57600,
|
| 22 |
+
momentum = 0.995,
|
| 23 |
+
bert_model_path = ""
|
| 24 |
+
):
|
| 25 |
+
"""
|
| 26 |
+
Args:
|
| 27 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 28 |
+
image_size (int): input image size
|
| 29 |
+
vit (str): model size of vision transformer
|
| 30 |
+
"""
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0)
|
| 34 |
+
|
| 35 |
+
self.tokenizer = init_tokenizer(bert_model_path)
|
| 36 |
+
encoder_config = BertConfig.from_json_file(med_config)
|
| 37 |
+
encoder_config.encoder_width = vision_width
|
| 38 |
+
self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
|
| 39 |
+
|
| 40 |
+
text_width = self.text_encoder.config.hidden_size
|
| 41 |
+
|
| 42 |
+
self.vision_proj = nn.Linear(vision_width, embed_dim)
|
| 43 |
+
self.text_proj = nn.Linear(text_width, embed_dim)
|
| 44 |
+
|
diffsynth/extensions/ImageQualityMetric/BLIP/med.py
ADDED
|
@@ -0,0 +1,947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
| 3 |
+
* Based on huggingface code base
|
| 4 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
| 5 |
+
'''
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor, device, nn
|
| 12 |
+
import torch.utils.checkpoint
|
| 13 |
+
from torch import nn
|
| 14 |
+
from torch.nn import CrossEntropyLoss
|
| 15 |
+
|
| 16 |
+
from transformers.activations import ACT2FN
|
| 17 |
+
from transformers.file_utils import (
|
| 18 |
+
ModelOutput,
|
| 19 |
+
)
|
| 20 |
+
from transformers.modeling_outputs import (
|
| 21 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 22 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 23 |
+
CausalLMOutputWithCrossAttentions,
|
| 24 |
+
MaskedLMOutput,
|
| 25 |
+
MultipleChoiceModelOutput,
|
| 26 |
+
NextSentencePredictorOutput,
|
| 27 |
+
QuestionAnsweringModelOutput,
|
| 28 |
+
SequenceClassifierOutput,
|
| 29 |
+
TokenClassifierOutput,
|
| 30 |
+
)
|
| 31 |
+
from transformers.modeling_utils import (
|
| 32 |
+
PreTrainedModel,
|
| 33 |
+
apply_chunking_to_forward,
|
| 34 |
+
find_pruneable_heads_and_indices,
|
| 35 |
+
prune_linear_layer,
|
| 36 |
+
)
|
| 37 |
+
from transformers.utils import logging
|
| 38 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class BertEmbeddings(nn.Module):
|
| 45 |
+
"""Construct the embeddings from word and position embeddings."""
|
| 46 |
+
|
| 47 |
+
def __init__(self, config):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 50 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 51 |
+
|
| 52 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 53 |
+
# any TensorFlow checkpoint file
|
| 54 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 55 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 56 |
+
|
| 57 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 58 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
| 59 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 60 |
+
|
| 61 |
+
self.config = config
|
| 62 |
+
|
| 63 |
+
def forward(
|
| 64 |
+
self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
| 65 |
+
):
|
| 66 |
+
if input_ids is not None:
|
| 67 |
+
input_shape = input_ids.size()
|
| 68 |
+
else:
|
| 69 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 70 |
+
|
| 71 |
+
seq_length = input_shape[1]
|
| 72 |
+
|
| 73 |
+
if position_ids is None:
|
| 74 |
+
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
| 75 |
+
|
| 76 |
+
if inputs_embeds is None:
|
| 77 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 78 |
+
|
| 79 |
+
embeddings = inputs_embeds
|
| 80 |
+
|
| 81 |
+
if self.position_embedding_type == "absolute":
|
| 82 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 83 |
+
embeddings += position_embeddings
|
| 84 |
+
embeddings = self.LayerNorm(embeddings)
|
| 85 |
+
embeddings = self.dropout(embeddings)
|
| 86 |
+
return embeddings
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class BertSelfAttention(nn.Module):
|
| 90 |
+
def __init__(self, config, is_cross_attention):
|
| 91 |
+
super().__init__()
|
| 92 |
+
self.config = config
|
| 93 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 94 |
+
raise ValueError(
|
| 95 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 96 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
self.num_attention_heads = config.num_attention_heads
|
| 100 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 101 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 102 |
+
|
| 103 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 104 |
+
if is_cross_attention:
|
| 105 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
| 106 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
| 107 |
+
else:
|
| 108 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 109 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 110 |
+
|
| 111 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 112 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
| 113 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 114 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 115 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
| 116 |
+
self.save_attention = False
|
| 117 |
+
|
| 118 |
+
def save_attn_gradients(self, attn_gradients):
|
| 119 |
+
self.attn_gradients = attn_gradients
|
| 120 |
+
|
| 121 |
+
def get_attn_gradients(self):
|
| 122 |
+
return self.attn_gradients
|
| 123 |
+
|
| 124 |
+
def save_attention_map(self, attention_map):
|
| 125 |
+
self.attention_map = attention_map
|
| 126 |
+
|
| 127 |
+
def get_attention_map(self):
|
| 128 |
+
return self.attention_map
|
| 129 |
+
|
| 130 |
+
def transpose_for_scores(self, x):
|
| 131 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 132 |
+
x = x.view(*new_x_shape)
|
| 133 |
+
return x.permute(0, 2, 1, 3)
|
| 134 |
+
|
| 135 |
+
def forward(
|
| 136 |
+
self,
|
| 137 |
+
hidden_states,
|
| 138 |
+
attention_mask=None,
|
| 139 |
+
head_mask=None,
|
| 140 |
+
encoder_hidden_states=None,
|
| 141 |
+
encoder_attention_mask=None,
|
| 142 |
+
past_key_value=None,
|
| 143 |
+
output_attentions=False,
|
| 144 |
+
):
|
| 145 |
+
mixed_query_layer = self.query(hidden_states)
|
| 146 |
+
|
| 147 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 148 |
+
# and values come from an encoder; the attention mask needs to be
|
| 149 |
+
# such that the encoder's padding tokens are not attended to.
|
| 150 |
+
is_cross_attention = encoder_hidden_states is not None
|
| 151 |
+
|
| 152 |
+
if is_cross_attention:
|
| 153 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
| 154 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
| 155 |
+
attention_mask = encoder_attention_mask
|
| 156 |
+
elif past_key_value is not None:
|
| 157 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 158 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 159 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
| 160 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
| 161 |
+
else:
|
| 162 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 163 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 164 |
+
|
| 165 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 166 |
+
|
| 167 |
+
past_key_value = (key_layer, value_layer)
|
| 168 |
+
|
| 169 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 170 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 171 |
+
|
| 172 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
| 173 |
+
seq_length = hidden_states.size()[1]
|
| 174 |
+
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
| 175 |
+
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
| 176 |
+
distance = position_ids_l - position_ids_r
|
| 177 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
| 178 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
| 179 |
+
|
| 180 |
+
if self.position_embedding_type == "relative_key":
|
| 181 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 182 |
+
attention_scores = attention_scores + relative_position_scores
|
| 183 |
+
elif self.position_embedding_type == "relative_key_query":
|
| 184 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
| 185 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
| 186 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
| 187 |
+
|
| 188 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 189 |
+
if attention_mask is not None:
|
| 190 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 191 |
+
attention_scores = attention_scores + attention_mask
|
| 192 |
+
|
| 193 |
+
# Normalize the attention scores to probabilities.
|
| 194 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
| 195 |
+
|
| 196 |
+
if is_cross_attention and self.save_attention:
|
| 197 |
+
self.save_attention_map(attention_probs)
|
| 198 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
| 199 |
+
|
| 200 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 201 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 202 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
| 203 |
+
|
| 204 |
+
# Mask heads if we want to
|
| 205 |
+
if head_mask is not None:
|
| 206 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
| 207 |
+
|
| 208 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
| 209 |
+
|
| 210 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 211 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 212 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 213 |
+
|
| 214 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 215 |
+
|
| 216 |
+
outputs = outputs + (past_key_value,)
|
| 217 |
+
return outputs
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class BertSelfOutput(nn.Module):
|
| 221 |
+
def __init__(self, config):
|
| 222 |
+
super().__init__()
|
| 223 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 224 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 225 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 226 |
+
|
| 227 |
+
def forward(self, hidden_states, input_tensor):
|
| 228 |
+
hidden_states = self.dense(hidden_states)
|
| 229 |
+
hidden_states = self.dropout(hidden_states)
|
| 230 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 231 |
+
return hidden_states
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class BertAttention(nn.Module):
|
| 235 |
+
def __init__(self, config, is_cross_attention=False):
|
| 236 |
+
super().__init__()
|
| 237 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
| 238 |
+
self.output = BertSelfOutput(config)
|
| 239 |
+
self.pruned_heads = set()
|
| 240 |
+
|
| 241 |
+
def prune_heads(self, heads):
|
| 242 |
+
if len(heads) == 0:
|
| 243 |
+
return
|
| 244 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 245 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Prune linear layers
|
| 249 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 250 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 251 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 252 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 253 |
+
|
| 254 |
+
# Update hyper params and store pruned heads
|
| 255 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 256 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 257 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 258 |
+
|
| 259 |
+
def forward(
|
| 260 |
+
self,
|
| 261 |
+
hidden_states,
|
| 262 |
+
attention_mask=None,
|
| 263 |
+
head_mask=None,
|
| 264 |
+
encoder_hidden_states=None,
|
| 265 |
+
encoder_attention_mask=None,
|
| 266 |
+
past_key_value=None,
|
| 267 |
+
output_attentions=False,
|
| 268 |
+
):
|
| 269 |
+
self_outputs = self.self(
|
| 270 |
+
hidden_states,
|
| 271 |
+
attention_mask,
|
| 272 |
+
head_mask,
|
| 273 |
+
encoder_hidden_states,
|
| 274 |
+
encoder_attention_mask,
|
| 275 |
+
past_key_value,
|
| 276 |
+
output_attentions,
|
| 277 |
+
)
|
| 278 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 279 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 280 |
+
return outputs
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class BertIntermediate(nn.Module):
|
| 284 |
+
def __init__(self, config):
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 287 |
+
if isinstance(config.hidden_act, str):
|
| 288 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 289 |
+
else:
|
| 290 |
+
self.intermediate_act_fn = config.hidden_act
|
| 291 |
+
|
| 292 |
+
def forward(self, hidden_states):
|
| 293 |
+
hidden_states = self.dense(hidden_states)
|
| 294 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 295 |
+
return hidden_states
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class BertOutput(nn.Module):
|
| 299 |
+
def __init__(self, config):
|
| 300 |
+
super().__init__()
|
| 301 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 302 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 303 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 304 |
+
|
| 305 |
+
def forward(self, hidden_states, input_tensor):
|
| 306 |
+
hidden_states = self.dense(hidden_states)
|
| 307 |
+
hidden_states = self.dropout(hidden_states)
|
| 308 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 309 |
+
return hidden_states
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class BertLayer(nn.Module):
|
| 313 |
+
def __init__(self, config, layer_num):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.config = config
|
| 316 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 317 |
+
self.seq_len_dim = 1
|
| 318 |
+
self.attention = BertAttention(config)
|
| 319 |
+
self.layer_num = layer_num
|
| 320 |
+
if self.config.add_cross_attention:
|
| 321 |
+
self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
|
| 322 |
+
self.intermediate = BertIntermediate(config)
|
| 323 |
+
self.output = BertOutput(config)
|
| 324 |
+
|
| 325 |
+
def forward(
|
| 326 |
+
self,
|
| 327 |
+
hidden_states,
|
| 328 |
+
attention_mask=None,
|
| 329 |
+
head_mask=None,
|
| 330 |
+
encoder_hidden_states=None,
|
| 331 |
+
encoder_attention_mask=None,
|
| 332 |
+
past_key_value=None,
|
| 333 |
+
output_attentions=False,
|
| 334 |
+
mode=None,
|
| 335 |
+
):
|
| 336 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 337 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
| 338 |
+
self_attention_outputs = self.attention(
|
| 339 |
+
hidden_states,
|
| 340 |
+
attention_mask,
|
| 341 |
+
head_mask,
|
| 342 |
+
output_attentions=output_attentions,
|
| 343 |
+
past_key_value=self_attn_past_key_value,
|
| 344 |
+
)
|
| 345 |
+
attention_output = self_attention_outputs[0]
|
| 346 |
+
|
| 347 |
+
outputs = self_attention_outputs[1:-1]
|
| 348 |
+
present_key_value = self_attention_outputs[-1]
|
| 349 |
+
|
| 350 |
+
if mode=='multimodal':
|
| 351 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
| 352 |
+
|
| 353 |
+
cross_attention_outputs = self.crossattention(
|
| 354 |
+
attention_output,
|
| 355 |
+
attention_mask,
|
| 356 |
+
head_mask,
|
| 357 |
+
encoder_hidden_states,
|
| 358 |
+
encoder_attention_mask,
|
| 359 |
+
output_attentions=output_attentions,
|
| 360 |
+
)
|
| 361 |
+
attention_output = cross_attention_outputs[0]
|
| 362 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
| 363 |
+
layer_output = apply_chunking_to_forward(
|
| 364 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
| 365 |
+
)
|
| 366 |
+
outputs = (layer_output,) + outputs
|
| 367 |
+
|
| 368 |
+
outputs = outputs + (present_key_value,)
|
| 369 |
+
|
| 370 |
+
return outputs
|
| 371 |
+
|
| 372 |
+
def feed_forward_chunk(self, attention_output):
|
| 373 |
+
intermediate_output = self.intermediate(attention_output)
|
| 374 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 375 |
+
return layer_output
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class BertEncoder(nn.Module):
|
| 379 |
+
def __init__(self, config):
|
| 380 |
+
super().__init__()
|
| 381 |
+
self.config = config
|
| 382 |
+
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
|
| 383 |
+
self.gradient_checkpointing = False
|
| 384 |
+
|
| 385 |
+
def forward(
|
| 386 |
+
self,
|
| 387 |
+
hidden_states,
|
| 388 |
+
attention_mask=None,
|
| 389 |
+
head_mask=None,
|
| 390 |
+
encoder_hidden_states=None,
|
| 391 |
+
encoder_attention_mask=None,
|
| 392 |
+
past_key_values=None,
|
| 393 |
+
use_cache=None,
|
| 394 |
+
output_attentions=False,
|
| 395 |
+
output_hidden_states=False,
|
| 396 |
+
return_dict=True,
|
| 397 |
+
mode='multimodal',
|
| 398 |
+
):
|
| 399 |
+
all_hidden_states = () if output_hidden_states else None
|
| 400 |
+
all_self_attentions = () if output_attentions else None
|
| 401 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 402 |
+
|
| 403 |
+
next_decoder_cache = () if use_cache else None
|
| 404 |
+
|
| 405 |
+
for i in range(self.config.num_hidden_layers):
|
| 406 |
+
layer_module = self.layer[i]
|
| 407 |
+
if output_hidden_states:
|
| 408 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 409 |
+
|
| 410 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 411 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
| 412 |
+
|
| 413 |
+
if self.gradient_checkpointing and self.training:
|
| 414 |
+
|
| 415 |
+
if use_cache:
|
| 416 |
+
logger.warn(
|
| 417 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 418 |
+
)
|
| 419 |
+
use_cache = False
|
| 420 |
+
|
| 421 |
+
def create_custom_forward(module):
|
| 422 |
+
def custom_forward(*inputs):
|
| 423 |
+
return module(*inputs, past_key_value, output_attentions)
|
| 424 |
+
|
| 425 |
+
return custom_forward
|
| 426 |
+
|
| 427 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 428 |
+
create_custom_forward(layer_module),
|
| 429 |
+
hidden_states,
|
| 430 |
+
attention_mask,
|
| 431 |
+
layer_head_mask,
|
| 432 |
+
encoder_hidden_states,
|
| 433 |
+
encoder_attention_mask,
|
| 434 |
+
mode=mode,
|
| 435 |
+
)
|
| 436 |
+
else:
|
| 437 |
+
layer_outputs = layer_module(
|
| 438 |
+
hidden_states,
|
| 439 |
+
attention_mask,
|
| 440 |
+
layer_head_mask,
|
| 441 |
+
encoder_hidden_states,
|
| 442 |
+
encoder_attention_mask,
|
| 443 |
+
past_key_value,
|
| 444 |
+
output_attentions,
|
| 445 |
+
mode=mode,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
hidden_states = layer_outputs[0]
|
| 449 |
+
if use_cache:
|
| 450 |
+
next_decoder_cache += (layer_outputs[-1],)
|
| 451 |
+
if output_attentions:
|
| 452 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 453 |
+
|
| 454 |
+
if output_hidden_states:
|
| 455 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 456 |
+
|
| 457 |
+
if not return_dict:
|
| 458 |
+
return tuple(
|
| 459 |
+
v
|
| 460 |
+
for v in [
|
| 461 |
+
hidden_states,
|
| 462 |
+
next_decoder_cache,
|
| 463 |
+
all_hidden_states,
|
| 464 |
+
all_self_attentions,
|
| 465 |
+
all_cross_attentions,
|
| 466 |
+
]
|
| 467 |
+
if v is not None
|
| 468 |
+
)
|
| 469 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 470 |
+
last_hidden_state=hidden_states,
|
| 471 |
+
past_key_values=next_decoder_cache,
|
| 472 |
+
hidden_states=all_hidden_states,
|
| 473 |
+
attentions=all_self_attentions,
|
| 474 |
+
cross_attentions=all_cross_attentions,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class BertPooler(nn.Module):
|
| 479 |
+
def __init__(self, config):
|
| 480 |
+
super().__init__()
|
| 481 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 482 |
+
self.activation = nn.Tanh()
|
| 483 |
+
|
| 484 |
+
def forward(self, hidden_states):
|
| 485 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 486 |
+
# to the first token.
|
| 487 |
+
first_token_tensor = hidden_states[:, 0]
|
| 488 |
+
pooled_output = self.dense(first_token_tensor)
|
| 489 |
+
pooled_output = self.activation(pooled_output)
|
| 490 |
+
return pooled_output
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 494 |
+
def __init__(self, config):
|
| 495 |
+
super().__init__()
|
| 496 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 497 |
+
if isinstance(config.hidden_act, str):
|
| 498 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 499 |
+
else:
|
| 500 |
+
self.transform_act_fn = config.hidden_act
|
| 501 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 502 |
+
|
| 503 |
+
def forward(self, hidden_states):
|
| 504 |
+
hidden_states = self.dense(hidden_states)
|
| 505 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 506 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 507 |
+
return hidden_states
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class BertLMPredictionHead(nn.Module):
|
| 511 |
+
def __init__(self, config):
|
| 512 |
+
super().__init__()
|
| 513 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 514 |
+
|
| 515 |
+
# The output weights are the same as the input embeddings, but there is
|
| 516 |
+
# an output-only bias for each token.
|
| 517 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 518 |
+
|
| 519 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 520 |
+
|
| 521 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 522 |
+
self.decoder.bias = self.bias
|
| 523 |
+
|
| 524 |
+
def forward(self, hidden_states):
|
| 525 |
+
hidden_states = self.transform(hidden_states)
|
| 526 |
+
hidden_states = self.decoder(hidden_states)
|
| 527 |
+
return hidden_states
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class BertOnlyMLMHead(nn.Module):
|
| 531 |
+
def __init__(self, config):
|
| 532 |
+
super().__init__()
|
| 533 |
+
self.predictions = BertLMPredictionHead(config)
|
| 534 |
+
|
| 535 |
+
def forward(self, sequence_output):
|
| 536 |
+
prediction_scores = self.predictions(sequence_output)
|
| 537 |
+
return prediction_scores
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 541 |
+
"""
|
| 542 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 543 |
+
models.
|
| 544 |
+
"""
|
| 545 |
+
|
| 546 |
+
config_class = BertConfig
|
| 547 |
+
base_model_prefix = "bert"
|
| 548 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
| 549 |
+
|
| 550 |
+
def _init_weights(self, module):
|
| 551 |
+
""" Initialize the weights """
|
| 552 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 553 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 554 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 555 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 556 |
+
elif isinstance(module, nn.LayerNorm):
|
| 557 |
+
module.bias.data.zero_()
|
| 558 |
+
module.weight.data.fill_(1.0)
|
| 559 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 560 |
+
module.bias.data.zero_()
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class BertModel(BertPreTrainedModel):
|
| 564 |
+
"""
|
| 565 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
| 566 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
| 567 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
| 568 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 569 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
| 570 |
+
input to the forward pass.
|
| 571 |
+
"""
|
| 572 |
+
|
| 573 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 574 |
+
super().__init__(config)
|
| 575 |
+
self.config = config
|
| 576 |
+
|
| 577 |
+
self.embeddings = BertEmbeddings(config)
|
| 578 |
+
|
| 579 |
+
self.encoder = BertEncoder(config)
|
| 580 |
+
|
| 581 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
| 582 |
+
|
| 583 |
+
self.init_weights()
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def get_input_embeddings(self):
|
| 587 |
+
return self.embeddings.word_embeddings
|
| 588 |
+
|
| 589 |
+
def set_input_embeddings(self, value):
|
| 590 |
+
self.embeddings.word_embeddings = value
|
| 591 |
+
|
| 592 |
+
def _prune_heads(self, heads_to_prune):
|
| 593 |
+
"""
|
| 594 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 595 |
+
class PreTrainedModel
|
| 596 |
+
"""
|
| 597 |
+
for layer, heads in heads_to_prune.items():
|
| 598 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
| 602 |
+
"""
|
| 603 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
| 604 |
+
|
| 605 |
+
Arguments:
|
| 606 |
+
attention_mask (:obj:`torch.Tensor`):
|
| 607 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
| 608 |
+
input_shape (:obj:`Tuple[int]`):
|
| 609 |
+
The shape of the input to the model.
|
| 610 |
+
device: (:obj:`torch.device`):
|
| 611 |
+
The device of the input to the model.
|
| 612 |
+
|
| 613 |
+
Returns:
|
| 614 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
| 615 |
+
"""
|
| 616 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 617 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 618 |
+
if attention_mask.dim() == 3:
|
| 619 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
| 620 |
+
elif attention_mask.dim() == 2:
|
| 621 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
| 622 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
| 623 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 624 |
+
if is_decoder:
|
| 625 |
+
batch_size, seq_length = input_shape
|
| 626 |
+
|
| 627 |
+
seq_ids = torch.arange(seq_length, device=device)
|
| 628 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
| 629 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
| 630 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
| 631 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
| 632 |
+
|
| 633 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
| 634 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
| 635 |
+
causal_mask = torch.cat(
|
| 636 |
+
[
|
| 637 |
+
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
|
| 638 |
+
causal_mask,
|
| 639 |
+
],
|
| 640 |
+
axis=-1,
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
| 644 |
+
else:
|
| 645 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
| 646 |
+
else:
|
| 647 |
+
raise ValueError(
|
| 648 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
| 649 |
+
input_shape, attention_mask.shape
|
| 650 |
+
)
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 654 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 655 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 656 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 657 |
+
# effectively the same as removing these entirely.
|
| 658 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 659 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
| 660 |
+
return extended_attention_mask
|
| 661 |
+
|
| 662 |
+
def forward(
|
| 663 |
+
self,
|
| 664 |
+
input_ids=None,
|
| 665 |
+
attention_mask=None,
|
| 666 |
+
position_ids=None,
|
| 667 |
+
head_mask=None,
|
| 668 |
+
inputs_embeds=None,
|
| 669 |
+
encoder_embeds=None,
|
| 670 |
+
encoder_hidden_states=None,
|
| 671 |
+
encoder_attention_mask=None,
|
| 672 |
+
past_key_values=None,
|
| 673 |
+
use_cache=None,
|
| 674 |
+
output_attentions=None,
|
| 675 |
+
output_hidden_states=None,
|
| 676 |
+
return_dict=None,
|
| 677 |
+
is_decoder=False,
|
| 678 |
+
mode='multimodal',
|
| 679 |
+
):
|
| 680 |
+
r"""
|
| 681 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 682 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 683 |
+
the model is configured as a decoder.
|
| 684 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 685 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 686 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
| 687 |
+
- 1 for tokens that are **not masked**,
|
| 688 |
+
- 0 for tokens that are **masked**.
|
| 689 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 690 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 691 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
| 692 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
| 693 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
| 694 |
+
use_cache (:obj:`bool`, `optional`):
|
| 695 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 696 |
+
decoding (see :obj:`past_key_values`).
|
| 697 |
+
"""
|
| 698 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 699 |
+
output_hidden_states = (
|
| 700 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 701 |
+
)
|
| 702 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 703 |
+
|
| 704 |
+
if is_decoder:
|
| 705 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 706 |
+
else:
|
| 707 |
+
use_cache = False
|
| 708 |
+
|
| 709 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 710 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 711 |
+
elif input_ids is not None:
|
| 712 |
+
input_shape = input_ids.size()
|
| 713 |
+
batch_size, seq_length = input_shape
|
| 714 |
+
device = input_ids.device
|
| 715 |
+
elif inputs_embeds is not None:
|
| 716 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 717 |
+
batch_size, seq_length = input_shape
|
| 718 |
+
device = inputs_embeds.device
|
| 719 |
+
elif encoder_embeds is not None:
|
| 720 |
+
input_shape = encoder_embeds.size()[:-1]
|
| 721 |
+
batch_size, seq_length = input_shape
|
| 722 |
+
device = encoder_embeds.device
|
| 723 |
+
else:
|
| 724 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
| 725 |
+
|
| 726 |
+
# past_key_values_length
|
| 727 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 728 |
+
|
| 729 |
+
if attention_mask is None:
|
| 730 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
| 731 |
+
|
| 732 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 733 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 734 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
| 735 |
+
device, is_decoder)
|
| 736 |
+
|
| 737 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
| 738 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 739 |
+
if encoder_hidden_states is not None:
|
| 740 |
+
if type(encoder_hidden_states) == list:
|
| 741 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
| 742 |
+
else:
|
| 743 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 744 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 745 |
+
|
| 746 |
+
if type(encoder_attention_mask) == list:
|
| 747 |
+
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
| 748 |
+
elif encoder_attention_mask is None:
|
| 749 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 750 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 751 |
+
else:
|
| 752 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 753 |
+
else:
|
| 754 |
+
encoder_extended_attention_mask = None
|
| 755 |
+
|
| 756 |
+
# Prepare head mask if needed
|
| 757 |
+
# 1.0 in head_mask indicate we keep the head
|
| 758 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 759 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 760 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 761 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 762 |
+
|
| 763 |
+
if encoder_embeds is None:
|
| 764 |
+
embedding_output = self.embeddings(
|
| 765 |
+
input_ids=input_ids,
|
| 766 |
+
position_ids=position_ids,
|
| 767 |
+
inputs_embeds=inputs_embeds,
|
| 768 |
+
past_key_values_length=past_key_values_length,
|
| 769 |
+
)
|
| 770 |
+
else:
|
| 771 |
+
embedding_output = encoder_embeds
|
| 772 |
+
|
| 773 |
+
encoder_outputs = self.encoder(
|
| 774 |
+
embedding_output,
|
| 775 |
+
attention_mask=extended_attention_mask,
|
| 776 |
+
head_mask=head_mask,
|
| 777 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 778 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 779 |
+
past_key_values=past_key_values,
|
| 780 |
+
use_cache=use_cache,
|
| 781 |
+
output_attentions=output_attentions,
|
| 782 |
+
output_hidden_states=output_hidden_states,
|
| 783 |
+
return_dict=return_dict,
|
| 784 |
+
mode=mode,
|
| 785 |
+
)
|
| 786 |
+
sequence_output = encoder_outputs[0]
|
| 787 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 788 |
+
|
| 789 |
+
if not return_dict:
|
| 790 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 791 |
+
|
| 792 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
| 793 |
+
last_hidden_state=sequence_output,
|
| 794 |
+
pooler_output=pooled_output,
|
| 795 |
+
past_key_values=encoder_outputs.past_key_values,
|
| 796 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 797 |
+
attentions=encoder_outputs.attentions,
|
| 798 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
| 804 |
+
|
| 805 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
| 806 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
| 807 |
+
|
| 808 |
+
def __init__(self, config):
|
| 809 |
+
super().__init__(config)
|
| 810 |
+
|
| 811 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
| 812 |
+
self.cls = BertOnlyMLMHead(config)
|
| 813 |
+
|
| 814 |
+
self.init_weights()
|
| 815 |
+
|
| 816 |
+
def get_output_embeddings(self):
|
| 817 |
+
return self.cls.predictions.decoder
|
| 818 |
+
|
| 819 |
+
def set_output_embeddings(self, new_embeddings):
|
| 820 |
+
self.cls.predictions.decoder = new_embeddings
|
| 821 |
+
|
| 822 |
+
def forward(
|
| 823 |
+
self,
|
| 824 |
+
input_ids=None,
|
| 825 |
+
attention_mask=None,
|
| 826 |
+
position_ids=None,
|
| 827 |
+
head_mask=None,
|
| 828 |
+
inputs_embeds=None,
|
| 829 |
+
encoder_hidden_states=None,
|
| 830 |
+
encoder_attention_mask=None,
|
| 831 |
+
labels=None,
|
| 832 |
+
past_key_values=None,
|
| 833 |
+
use_cache=None,
|
| 834 |
+
output_attentions=None,
|
| 835 |
+
output_hidden_states=None,
|
| 836 |
+
return_dict=None,
|
| 837 |
+
return_logits=False,
|
| 838 |
+
is_decoder=True,
|
| 839 |
+
reduction='mean',
|
| 840 |
+
mode='multimodal',
|
| 841 |
+
):
|
| 842 |
+
r"""
|
| 843 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 844 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
| 845 |
+
the model is configured as a decoder.
|
| 846 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 847 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
| 848 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
| 849 |
+
- 1 for tokens that are **not masked**,
|
| 850 |
+
- 0 for tokens that are **masked**.
|
| 851 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 852 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
| 853 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
| 854 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
| 855 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
| 856 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
| 857 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
| 858 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
| 859 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
| 860 |
+
use_cache (:obj:`bool`, `optional`):
|
| 861 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 862 |
+
decoding (see :obj:`past_key_values`).
|
| 863 |
+
Returns:
|
| 864 |
+
Example::
|
| 865 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
| 866 |
+
>>> import torch
|
| 867 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
| 868 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
| 869 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
| 870 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 871 |
+
>>> outputs = model(**inputs)
|
| 872 |
+
>>> prediction_logits = outputs.logits
|
| 873 |
+
"""
|
| 874 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 875 |
+
if labels is not None:
|
| 876 |
+
use_cache = False
|
| 877 |
+
|
| 878 |
+
outputs = self.bert(
|
| 879 |
+
input_ids,
|
| 880 |
+
attention_mask=attention_mask,
|
| 881 |
+
position_ids=position_ids,
|
| 882 |
+
head_mask=head_mask,
|
| 883 |
+
inputs_embeds=inputs_embeds,
|
| 884 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 885 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 886 |
+
past_key_values=past_key_values,
|
| 887 |
+
use_cache=use_cache,
|
| 888 |
+
output_attentions=output_attentions,
|
| 889 |
+
output_hidden_states=output_hidden_states,
|
| 890 |
+
return_dict=return_dict,
|
| 891 |
+
is_decoder=is_decoder,
|
| 892 |
+
mode=mode,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
sequence_output = outputs[0]
|
| 896 |
+
prediction_scores = self.cls(sequence_output)
|
| 897 |
+
|
| 898 |
+
if return_logits:
|
| 899 |
+
return prediction_scores[:, :-1, :].contiguous()
|
| 900 |
+
|
| 901 |
+
lm_loss = None
|
| 902 |
+
if labels is not None:
|
| 903 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
| 904 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
| 905 |
+
labels = labels[:, 1:].contiguous()
|
| 906 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
| 907 |
+
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 908 |
+
if reduction=='none':
|
| 909 |
+
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
|
| 910 |
+
|
| 911 |
+
if not return_dict:
|
| 912 |
+
output = (prediction_scores,) + outputs[2:]
|
| 913 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
| 914 |
+
|
| 915 |
+
return CausalLMOutputWithCrossAttentions(
|
| 916 |
+
loss=lm_loss,
|
| 917 |
+
logits=prediction_scores,
|
| 918 |
+
past_key_values=outputs.past_key_values,
|
| 919 |
+
hidden_states=outputs.hidden_states,
|
| 920 |
+
attentions=outputs.attentions,
|
| 921 |
+
cross_attentions=outputs.cross_attentions,
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
| 925 |
+
input_shape = input_ids.shape
|
| 926 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
| 927 |
+
if attention_mask is None:
|
| 928 |
+
attention_mask = input_ids.new_ones(input_shape)
|
| 929 |
+
|
| 930 |
+
# cut decoder_input_ids if past is used
|
| 931 |
+
if past is not None:
|
| 932 |
+
input_ids = input_ids[:, -1:]
|
| 933 |
+
|
| 934 |
+
return {
|
| 935 |
+
"input_ids": input_ids,
|
| 936 |
+
"attention_mask": attention_mask,
|
| 937 |
+
"past_key_values": past,
|
| 938 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
| 939 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
| 940 |
+
"is_decoder": True,
|
| 941 |
+
}
|
| 942 |
+
|
| 943 |
+
def _reorder_cache(self, past, beam_idx):
|
| 944 |
+
reordered_past = ()
|
| 945 |
+
for layer_past in past:
|
| 946 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
| 947 |
+
return reordered_past
|
diffsynth/extensions/ImageQualityMetric/BLIP/vit.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Adapted from BLIP (https://github.com/salesforce/BLIP)
|
| 3 |
+
* Based on timm code base
|
| 4 |
+
* https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 5 |
+
'''
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
from timm.models.vision_transformer import _cfg, PatchEmbed
|
| 13 |
+
from timm.models.registry import register_model
|
| 14 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 15 |
+
from timm.models.helpers import named_apply, adapt_input_conv
|
| 16 |
+
|
| 17 |
+
# from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
| 18 |
+
|
| 19 |
+
class Mlp(nn.Module):
|
| 20 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 23 |
+
super().__init__()
|
| 24 |
+
out_features = out_features or in_features
|
| 25 |
+
hidden_features = hidden_features or in_features
|
| 26 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 27 |
+
self.act = act_layer()
|
| 28 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 29 |
+
self.drop = nn.Dropout(drop)
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
x = self.fc1(x)
|
| 33 |
+
x = self.act(x)
|
| 34 |
+
x = self.drop(x)
|
| 35 |
+
x = self.fc2(x)
|
| 36 |
+
x = self.drop(x)
|
| 37 |
+
return x
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Attention(nn.Module):
|
| 41 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.num_heads = num_heads
|
| 44 |
+
head_dim = dim // num_heads
|
| 45 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
| 46 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 47 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 48 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 49 |
+
self.proj = nn.Linear(dim, dim)
|
| 50 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 51 |
+
self.attn_gradients = None
|
| 52 |
+
self.attention_map = None
|
| 53 |
+
|
| 54 |
+
def save_attn_gradients(self, attn_gradients):
|
| 55 |
+
self.attn_gradients = attn_gradients
|
| 56 |
+
|
| 57 |
+
def get_attn_gradients(self):
|
| 58 |
+
return self.attn_gradients
|
| 59 |
+
|
| 60 |
+
def save_attention_map(self, attention_map):
|
| 61 |
+
self.attention_map = attention_map
|
| 62 |
+
|
| 63 |
+
def get_attention_map(self):
|
| 64 |
+
return self.attention_map
|
| 65 |
+
|
| 66 |
+
def forward(self, x, register_hook=False):
|
| 67 |
+
B, N, C = x.shape
|
| 68 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 69 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 70 |
+
|
| 71 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 72 |
+
attn = attn.softmax(dim=-1)
|
| 73 |
+
attn = self.attn_drop(attn)
|
| 74 |
+
|
| 75 |
+
if register_hook:
|
| 76 |
+
self.save_attention_map(attn)
|
| 77 |
+
attn.register_hook(self.save_attn_gradients)
|
| 78 |
+
|
| 79 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 80 |
+
x = self.proj(x)
|
| 81 |
+
x = self.proj_drop(x)
|
| 82 |
+
return x
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Block(nn.Module):
|
| 86 |
+
|
| 87 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 88 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.norm1 = norm_layer(dim)
|
| 91 |
+
self.attn = Attention(
|
| 92 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 93 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 94 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 95 |
+
self.norm2 = norm_layer(dim)
|
| 96 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 97 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 98 |
+
|
| 99 |
+
# if use_grad_checkpointing:
|
| 100 |
+
# self.attn = checkpoint_wrapper(self.attn)
|
| 101 |
+
# self.mlp = checkpoint_wrapper(self.mlp)
|
| 102 |
+
|
| 103 |
+
def forward(self, x, register_hook=False):
|
| 104 |
+
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
|
| 105 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class VisionTransformer(nn.Module):
|
| 110 |
+
""" Vision Transformer
|
| 111 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
| 112 |
+
https://arxiv.org/abs/2010.11929
|
| 113 |
+
"""
|
| 114 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
| 115 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
| 116 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
|
| 117 |
+
use_grad_checkpointing=False, ckpt_layer=0):
|
| 118 |
+
"""
|
| 119 |
+
Args:
|
| 120 |
+
img_size (int, tuple): input image size
|
| 121 |
+
patch_size (int, tuple): patch size
|
| 122 |
+
in_chans (int): number of input channels
|
| 123 |
+
num_classes (int): number of classes for classification head
|
| 124 |
+
embed_dim (int): embedding dimension
|
| 125 |
+
depth (int): depth of transformer
|
| 126 |
+
num_heads (int): number of attention heads
|
| 127 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 128 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 129 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
| 130 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
| 131 |
+
drop_rate (float): dropout rate
|
| 132 |
+
attn_drop_rate (float): attention dropout rate
|
| 133 |
+
drop_path_rate (float): stochastic depth rate
|
| 134 |
+
norm_layer: (nn.Module): normalization layer
|
| 135 |
+
"""
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 138 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 139 |
+
|
| 140 |
+
self.patch_embed = PatchEmbed(
|
| 141 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 142 |
+
|
| 143 |
+
num_patches = self.patch_embed.num_patches
|
| 144 |
+
|
| 145 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 146 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 147 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
| 148 |
+
|
| 149 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 150 |
+
self.blocks = nn.ModuleList([
|
| 151 |
+
Block(
|
| 152 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 153 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
| 154 |
+
use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
|
| 155 |
+
)
|
| 156 |
+
for i in range(depth)])
|
| 157 |
+
self.norm = norm_layer(embed_dim)
|
| 158 |
+
|
| 159 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 160 |
+
trunc_normal_(self.cls_token, std=.02)
|
| 161 |
+
self.apply(self._init_weights)
|
| 162 |
+
|
| 163 |
+
def _init_weights(self, m):
|
| 164 |
+
if isinstance(m, nn.Linear):
|
| 165 |
+
trunc_normal_(m.weight, std=.02)
|
| 166 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 167 |
+
nn.init.constant_(m.bias, 0)
|
| 168 |
+
elif isinstance(m, nn.LayerNorm):
|
| 169 |
+
nn.init.constant_(m.bias, 0)
|
| 170 |
+
nn.init.constant_(m.weight, 1.0)
|
| 171 |
+
|
| 172 |
+
@torch.jit.ignore
|
| 173 |
+
def no_weight_decay(self):
|
| 174 |
+
return {'pos_embed', 'cls_token'}
|
| 175 |
+
|
| 176 |
+
def forward(self, x, register_blk=-1):
|
| 177 |
+
B = x.shape[0]
|
| 178 |
+
x = self.patch_embed(x)
|
| 179 |
+
|
| 180 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
| 181 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 182 |
+
|
| 183 |
+
x = x + self.pos_embed[:,:x.size(1),:]
|
| 184 |
+
x = self.pos_drop(x)
|
| 185 |
+
|
| 186 |
+
for i,blk in enumerate(self.blocks):
|
| 187 |
+
x = blk(x, register_blk==i)
|
| 188 |
+
x = self.norm(x)
|
| 189 |
+
|
| 190 |
+
return x
|
| 191 |
+
|
| 192 |
+
@torch.jit.ignore()
|
| 193 |
+
def load_pretrained(self, checkpoint_path, prefix=''):
|
| 194 |
+
_load_weights(self, checkpoint_path, prefix)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@torch.no_grad()
|
| 198 |
+
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
|
| 199 |
+
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
|
| 200 |
+
"""
|
| 201 |
+
import numpy as np
|
| 202 |
+
|
| 203 |
+
def _n2p(w, t=True):
|
| 204 |
+
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
|
| 205 |
+
w = w.flatten()
|
| 206 |
+
if t:
|
| 207 |
+
if w.ndim == 4:
|
| 208 |
+
w = w.transpose([3, 2, 0, 1])
|
| 209 |
+
elif w.ndim == 3:
|
| 210 |
+
w = w.transpose([2, 0, 1])
|
| 211 |
+
elif w.ndim == 2:
|
| 212 |
+
w = w.transpose([1, 0])
|
| 213 |
+
return torch.from_numpy(w)
|
| 214 |
+
|
| 215 |
+
w = np.load(checkpoint_path)
|
| 216 |
+
if not prefix and 'opt/target/embedding/kernel' in w:
|
| 217 |
+
prefix = 'opt/target/'
|
| 218 |
+
|
| 219 |
+
if hasattr(model.patch_embed, 'backbone'):
|
| 220 |
+
# hybrid
|
| 221 |
+
backbone = model.patch_embed.backbone
|
| 222 |
+
stem_only = not hasattr(backbone, 'stem')
|
| 223 |
+
stem = backbone if stem_only else backbone.stem
|
| 224 |
+
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
|
| 225 |
+
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
|
| 226 |
+
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
|
| 227 |
+
if not stem_only:
|
| 228 |
+
for i, stage in enumerate(backbone.stages):
|
| 229 |
+
for j, block in enumerate(stage.blocks):
|
| 230 |
+
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
|
| 231 |
+
for r in range(3):
|
| 232 |
+
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
|
| 233 |
+
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
|
| 234 |
+
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
|
| 235 |
+
if block.downsample is not None:
|
| 236 |
+
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
|
| 237 |
+
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
|
| 238 |
+
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
|
| 239 |
+
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
|
| 240 |
+
else:
|
| 241 |
+
embed_conv_w = adapt_input_conv(
|
| 242 |
+
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
|
| 243 |
+
model.patch_embed.proj.weight.copy_(embed_conv_w)
|
| 244 |
+
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
|
| 245 |
+
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
|
| 246 |
+
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
| 247 |
+
if pos_embed_w.shape != model.pos_embed.shape:
|
| 248 |
+
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
| 249 |
+
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
| 250 |
+
model.pos_embed.copy_(pos_embed_w)
|
| 251 |
+
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
| 252 |
+
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
| 253 |
+
# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
|
| 254 |
+
# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
|
| 255 |
+
# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
|
| 256 |
+
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
|
| 257 |
+
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
|
| 258 |
+
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
|
| 259 |
+
for i, block in enumerate(model.blocks.children()):
|
| 260 |
+
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
|
| 261 |
+
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
|
| 262 |
+
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
|
| 263 |
+
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
|
| 264 |
+
block.attn.qkv.weight.copy_(torch.cat([
|
| 265 |
+
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
|
| 266 |
+
block.attn.qkv.bias.copy_(torch.cat([
|
| 267 |
+
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
|
| 268 |
+
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
|
| 269 |
+
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
|
| 270 |
+
for r in range(2):
|
| 271 |
+
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
|
| 272 |
+
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
|
| 273 |
+
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
|
| 274 |
+
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
|
| 278 |
+
# interpolate position embedding
|
| 279 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
| 280 |
+
num_patches = visual_encoder.patch_embed.num_patches
|
| 281 |
+
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
|
| 282 |
+
# height (== width) for the checkpoint position embedding
|
| 283 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 284 |
+
# height (== width) for the new position embedding
|
| 285 |
+
new_size = int(num_patches ** 0.5)
|
| 286 |
+
|
| 287 |
+
if orig_size!=new_size:
|
| 288 |
+
# class_token and dist_token are kept unchanged
|
| 289 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 290 |
+
# only the position tokens are interpolated
|
| 291 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 292 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 293 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 294 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
| 295 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 296 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 297 |
+
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
|
| 298 |
+
|
| 299 |
+
return new_pos_embed
|
| 300 |
+
else:
|
| 301 |
+
return pos_embed_checkpoint
|
diffsynth/extensions/ImageQualityMetric/__init__.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from modelscope import snapshot_download
|
| 2 |
+
from typing_extensions import Literal, TypeAlias
|
| 3 |
+
import os
|
| 4 |
+
from diffsynth.extensions.ImageQualityMetric.aesthetic import AestheticScore
|
| 5 |
+
from diffsynth.extensions.ImageQualityMetric.imagereward import ImageRewardScore
|
| 6 |
+
from diffsynth.extensions.ImageQualityMetric.pickscore import PickScore
|
| 7 |
+
from diffsynth.extensions.ImageQualityMetric.clip import CLIPScore
|
| 8 |
+
from diffsynth.extensions.ImageQualityMetric.hps import HPScore_v2
|
| 9 |
+
from diffsynth.extensions.ImageQualityMetric.mps import MPScore
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
preference_model_id: TypeAlias = Literal[
|
| 13 |
+
"ImageReward",
|
| 14 |
+
"Aesthetic",
|
| 15 |
+
"PickScore",
|
| 16 |
+
"CLIP",
|
| 17 |
+
"HPSv2",
|
| 18 |
+
"HPSv2.1",
|
| 19 |
+
"MPS",
|
| 20 |
+
]
|
| 21 |
+
model_dict = {
|
| 22 |
+
"ImageReward": {
|
| 23 |
+
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
| 24 |
+
"allow_file_pattern": [
|
| 25 |
+
"ImageReward/ImageReward.safetensors",
|
| 26 |
+
"ImageReward/med_config.json",
|
| 27 |
+
"bert-base-uncased/config.json",
|
| 28 |
+
"bert-base-uncased/model.safetensors",
|
| 29 |
+
"bert-base-uncased/tokenizer.json",
|
| 30 |
+
"bert-base-uncased/tokenizer_config.json",
|
| 31 |
+
"bert-base-uncased/vocab.txt",
|
| 32 |
+
],
|
| 33 |
+
"load_path": {
|
| 34 |
+
"imagereward": "ImageReward/ImageReward.safetensors",
|
| 35 |
+
"med_config": "ImageReward/med_config.json",
|
| 36 |
+
"bert_model_path": "bert-base-uncased",
|
| 37 |
+
},
|
| 38 |
+
"model_class": ImageRewardScore
|
| 39 |
+
},
|
| 40 |
+
"Aesthetic": {
|
| 41 |
+
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
| 42 |
+
"allow_file_pattern": [
|
| 43 |
+
"aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
| 44 |
+
"clip-vit-large-patch14/config.json",
|
| 45 |
+
"clip-vit-large-patch14/merges.txt",
|
| 46 |
+
"clip-vit-large-patch14/model.safetensors",
|
| 47 |
+
"clip-vit-large-patch14/preprocessor_config.json",
|
| 48 |
+
"clip-vit-large-patch14/special_tokens_map.json",
|
| 49 |
+
"clip-vit-large-patch14/tokenizer.json",
|
| 50 |
+
"clip-vit-large-patch14/tokenizer_config.json",
|
| 51 |
+
"clip-vit-large-patch14/vocab.json",
|
| 52 |
+
],
|
| 53 |
+
"load_path": {
|
| 54 |
+
"aesthetic_predictor": "aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors",
|
| 55 |
+
"clip-large": "clip-vit-large-patch14",
|
| 56 |
+
},
|
| 57 |
+
"model_class": AestheticScore
|
| 58 |
+
},
|
| 59 |
+
"PickScore": {
|
| 60 |
+
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
| 61 |
+
"allow_file_pattern": [
|
| 62 |
+
"PickScore_v1/*",
|
| 63 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
| 64 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
| 65 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
| 66 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
| 67 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
| 68 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
| 69 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
| 70 |
+
],
|
| 71 |
+
"load_path": {
|
| 72 |
+
"pickscore": "PickScore_v1",
|
| 73 |
+
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
| 74 |
+
},
|
| 75 |
+
"model_class": PickScore
|
| 76 |
+
},
|
| 77 |
+
"CLIP": {
|
| 78 |
+
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
| 79 |
+
"allow_file_pattern": [
|
| 80 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
| 81 |
+
"bpe_simple_vocab_16e6.txt.gz",
|
| 82 |
+
],
|
| 83 |
+
"load_path": {
|
| 84 |
+
"open_clip": "CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
|
| 85 |
+
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
| 86 |
+
},
|
| 87 |
+
"model_class": CLIPScore
|
| 88 |
+
},
|
| 89 |
+
"HPSv2": {
|
| 90 |
+
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
| 91 |
+
"allow_file_pattern": [
|
| 92 |
+
"HPS_v2/HPS_v2_compressed.safetensors",
|
| 93 |
+
"bpe_simple_vocab_16e6.txt.gz",
|
| 94 |
+
],
|
| 95 |
+
"load_path": {
|
| 96 |
+
"hpsv2": "HPS_v2/HPS_v2_compressed.safetensors",
|
| 97 |
+
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
| 98 |
+
},
|
| 99 |
+
"model_class": HPScore_v2,
|
| 100 |
+
"extra_kwargs": {"model_version": "v2"}
|
| 101 |
+
},
|
| 102 |
+
"HPSv2.1": {
|
| 103 |
+
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
| 104 |
+
"allow_file_pattern": [
|
| 105 |
+
"HPS_v2/HPS_v2.1_compressed.safetensors",
|
| 106 |
+
"bpe_simple_vocab_16e6.txt.gz",
|
| 107 |
+
],
|
| 108 |
+
"load_path": {
|
| 109 |
+
"hpsv2.1": "HPS_v2/HPS_v2.1_compressed.safetensors",
|
| 110 |
+
"open_clip_bpe": "bpe_simple_vocab_16e6.txt.gz",
|
| 111 |
+
},
|
| 112 |
+
"model_class": HPScore_v2,
|
| 113 |
+
"extra_kwargs": {"model_version": "v21"}
|
| 114 |
+
},
|
| 115 |
+
"MPS": {
|
| 116 |
+
"model_id": "DiffSynth-Studio/QualityMetric_reward_pretrained",
|
| 117 |
+
"allow_file_pattern": [
|
| 118 |
+
"MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
| 119 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/config.json",
|
| 120 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/merges.txt",
|
| 121 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/preprocessor_config.json",
|
| 122 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/special_tokens_map.json",
|
| 123 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer.json",
|
| 124 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/tokenizer_config.json",
|
| 125 |
+
"CLIP-ViT-H-14-laion2B-s32B-b79K/vocab.json",
|
| 126 |
+
],
|
| 127 |
+
"load_path": {
|
| 128 |
+
"mps": "MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors",
|
| 129 |
+
"clip": "CLIP-ViT-H-14-laion2B-s32B-b79K",
|
| 130 |
+
},
|
| 131 |
+
"model_class": MPScore
|
| 132 |
+
},
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def download_preference_model(model_name: preference_model_id, cache_dir="models"):
|
| 137 |
+
metadata = model_dict[model_name]
|
| 138 |
+
snapshot_download(model_id=metadata["model_id"], allow_file_pattern=metadata["allow_file_pattern"], cache_dir=cache_dir)
|
| 139 |
+
load_path = metadata["load_path"]
|
| 140 |
+
load_path = {key: os.path.join(cache_dir, metadata["model_id"], path) for key, path in load_path.items()}
|
| 141 |
+
return load_path
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def load_preference_model(model_name: preference_model_id, device = "cuda", path = None):
|
| 145 |
+
model_class = model_dict[model_name]["model_class"]
|
| 146 |
+
extra_kwargs = model_dict[model_name].get("extra_kwargs", {})
|
| 147 |
+
preference_model = model_class(device=device, path=path, **extra_kwargs)
|
| 148 |
+
return preference_model
|
diffsynth/extensions/ImageQualityMetric/aesthetic.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import AutoProcessor, AutoModel
|
| 5 |
+
from safetensors.torch import load_file
|
| 6 |
+
import os
|
| 7 |
+
from typing import Union, List
|
| 8 |
+
from .config import MODEL_PATHS
|
| 9 |
+
|
| 10 |
+
class MLP(torch.nn.Module):
|
| 11 |
+
def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.input_size = input_size
|
| 14 |
+
self.xcol = xcol
|
| 15 |
+
self.ycol = ycol
|
| 16 |
+
self.layers = torch.nn.Sequential(
|
| 17 |
+
torch.nn.Linear(self.input_size, 1024),
|
| 18 |
+
#torch.nn.ReLU(),
|
| 19 |
+
torch.nn.Dropout(0.2),
|
| 20 |
+
torch.nn.Linear(1024, 128),
|
| 21 |
+
#torch.nn.ReLU(),
|
| 22 |
+
torch.nn.Dropout(0.2),
|
| 23 |
+
torch.nn.Linear(128, 64),
|
| 24 |
+
#torch.nn.ReLU(),
|
| 25 |
+
torch.nn.Dropout(0.1),
|
| 26 |
+
torch.nn.Linear(64, 16),
|
| 27 |
+
#torch.nn.ReLU(),
|
| 28 |
+
torch.nn.Linear(16, 1),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
return self.layers(x)
|
| 33 |
+
|
| 34 |
+
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
| 35 |
+
x = batch[self.xcol]
|
| 36 |
+
y = batch[self.ycol].reshape(-1, 1)
|
| 37 |
+
x_hat = self.layers(x)
|
| 38 |
+
loss = torch.nn.functional.mse_loss(x_hat, y)
|
| 39 |
+
return loss
|
| 40 |
+
|
| 41 |
+
def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
|
| 42 |
+
x = batch[self.xcol]
|
| 43 |
+
y = batch[self.ycol].reshape(-1, 1)
|
| 44 |
+
x_hat = self.layers(x)
|
| 45 |
+
loss = torch.nn.functional.mse_loss(x_hat, y)
|
| 46 |
+
return loss
|
| 47 |
+
|
| 48 |
+
def configure_optimizers(self) -> torch.optim.Optimizer:
|
| 49 |
+
return torch.optim.Adam(self.parameters(), lr=1e-3)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class AestheticScore(torch.nn.Module):
|
| 53 |
+
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.device = device
|
| 56 |
+
self.aes_model_path = path.get("aesthetic_predictor")
|
| 57 |
+
# Load the MLP model
|
| 58 |
+
self.model = MLP(768)
|
| 59 |
+
try:
|
| 60 |
+
if self.aes_model_path.endswith(".safetensors"):
|
| 61 |
+
state_dict = load_file(self.aes_model_path)
|
| 62 |
+
else:
|
| 63 |
+
state_dict = torch.load(self.aes_model_path)
|
| 64 |
+
self.model.load_state_dict(state_dict)
|
| 65 |
+
except Exception as e:
|
| 66 |
+
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}")
|
| 67 |
+
|
| 68 |
+
self.model.to(device)
|
| 69 |
+
self.model.eval()
|
| 70 |
+
|
| 71 |
+
# Load the CLIP model and processor
|
| 72 |
+
clip_model_name = path.get('clip-large')
|
| 73 |
+
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device)
|
| 74 |
+
self.processor = AutoProcessor.from_pretrained(clip_model_name)
|
| 75 |
+
|
| 76 |
+
def _calculate_score(self, image: torch.Tensor) -> float:
|
| 77 |
+
"""Calculate the aesthetic score for a single image.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
image (torch.Tensor): The processed image tensor.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
float: The aesthetic score.
|
| 84 |
+
"""
|
| 85 |
+
with torch.no_grad():
|
| 86 |
+
# Get image embeddings
|
| 87 |
+
image_embs = self.model2.get_image_features(image)
|
| 88 |
+
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
|
| 89 |
+
|
| 90 |
+
# Compute score
|
| 91 |
+
score = self.model(image_embs).cpu().flatten().item()
|
| 92 |
+
|
| 93 |
+
return score
|
| 94 |
+
|
| 95 |
+
@torch.no_grad()
|
| 96 |
+
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
| 97 |
+
"""Score the images based on their aesthetic quality.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
List[float]: List of scores for the images.
|
| 104 |
+
"""
|
| 105 |
+
try:
|
| 106 |
+
if isinstance(images, (str, Image.Image)):
|
| 107 |
+
# Single image
|
| 108 |
+
if isinstance(images, str):
|
| 109 |
+
pil_image = Image.open(images)
|
| 110 |
+
else:
|
| 111 |
+
pil_image = images
|
| 112 |
+
|
| 113 |
+
# Prepare image inputs
|
| 114 |
+
image_inputs = self.processor(
|
| 115 |
+
images=pil_image,
|
| 116 |
+
padding=True,
|
| 117 |
+
truncation=True,
|
| 118 |
+
max_length=77,
|
| 119 |
+
return_tensors="pt",
|
| 120 |
+
).to(self.device)
|
| 121 |
+
|
| 122 |
+
return [self._calculate_score(image_inputs["pixel_values"])]
|
| 123 |
+
elif isinstance(images, list):
|
| 124 |
+
# Multiple images
|
| 125 |
+
scores = []
|
| 126 |
+
for one_image in images:
|
| 127 |
+
if isinstance(one_image, str):
|
| 128 |
+
pil_image = Image.open(one_image)
|
| 129 |
+
elif isinstance(one_image, Image.Image):
|
| 130 |
+
pil_image = one_image
|
| 131 |
+
else:
|
| 132 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 133 |
+
|
| 134 |
+
# Prepare image inputs
|
| 135 |
+
image_inputs = self.processor(
|
| 136 |
+
images=pil_image,
|
| 137 |
+
padding=True,
|
| 138 |
+
truncation=True,
|
| 139 |
+
max_length=77,
|
| 140 |
+
return_tensors="pt",
|
| 141 |
+
).to(self.device)
|
| 142 |
+
|
| 143 |
+
scores.append(self._calculate_score(image_inputs["pixel_values"]))
|
| 144 |
+
return scores
|
| 145 |
+
else:
|
| 146 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 147 |
+
except Exception as e:
|
| 148 |
+
raise RuntimeError(f"Error in scoring images: {e}")
|
diffsynth/extensions/ImageQualityMetric/clip.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import torch
|
| 4 |
+
from .open_clip import create_model_and_transforms, get_tokenizer
|
| 5 |
+
from .config import MODEL_PATHS
|
| 6 |
+
|
| 7 |
+
class CLIPScore(torch.nn.Module):
|
| 8 |
+
def __init__(self, device: torch.device, path: str = MODEL_PATHS):
|
| 9 |
+
super().__init__()
|
| 10 |
+
"""Initialize the CLIPScore with a model and tokenizer.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
device (torch.device): The device to load the model on.
|
| 14 |
+
"""
|
| 15 |
+
self.device = device
|
| 16 |
+
|
| 17 |
+
# Create model and transforms
|
| 18 |
+
self.model, _, self.preprocess_val = create_model_and_transforms(
|
| 19 |
+
"ViT-H-14",
|
| 20 |
+
# "laion2B-s32B-b79K",
|
| 21 |
+
pretrained=path.get("open_clip"),
|
| 22 |
+
precision="amp",
|
| 23 |
+
device=device,
|
| 24 |
+
jit=False,
|
| 25 |
+
force_quick_gelu=False,
|
| 26 |
+
force_custom_text=False,
|
| 27 |
+
force_patch_dropout=False,
|
| 28 |
+
force_image_size=None,
|
| 29 |
+
pretrained_image=False,
|
| 30 |
+
image_mean=None,
|
| 31 |
+
image_std=None,
|
| 32 |
+
light_augmentation=True,
|
| 33 |
+
aug_cfg={},
|
| 34 |
+
output_dict=True,
|
| 35 |
+
with_score_predictor=False,
|
| 36 |
+
with_region_predictor=False,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Initialize tokenizer
|
| 40 |
+
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
| 41 |
+
self.model = self.model.to(device)
|
| 42 |
+
self.model.eval()
|
| 43 |
+
|
| 44 |
+
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
| 45 |
+
"""Calculate the CLIP score for a single image and prompt.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
image (torch.Tensor): The processed image tensor.
|
| 49 |
+
prompt (str): The prompt text.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
float: The CLIP score.
|
| 53 |
+
"""
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
# Process the prompt
|
| 56 |
+
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
| 57 |
+
|
| 58 |
+
# Calculate the CLIP score
|
| 59 |
+
outputs = self.model(image, text)
|
| 60 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 61 |
+
logits_per_image = image_features @ text_features.T
|
| 62 |
+
clip_score = torch.diagonal(logits_per_image).cpu().numpy()
|
| 63 |
+
|
| 64 |
+
return clip_score[0].item()
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
|
| 67 |
+
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
| 68 |
+
"""Score the images based on the prompt.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
| 72 |
+
prompt (str): The prompt text.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
List[float]: List of CLIP scores for the images.
|
| 76 |
+
"""
|
| 77 |
+
if isinstance(images, (str, Image.Image)):
|
| 78 |
+
# Single image
|
| 79 |
+
if isinstance(images, str):
|
| 80 |
+
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 81 |
+
else:
|
| 82 |
+
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 83 |
+
return [self._calculate_score(image, prompt)]
|
| 84 |
+
elif isinstance(images, list):
|
| 85 |
+
# Multiple images
|
| 86 |
+
scores = []
|
| 87 |
+
for one_images in images:
|
| 88 |
+
if isinstance(one_images, str):
|
| 89 |
+
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 90 |
+
elif isinstance(one_images, Image.Image):
|
| 91 |
+
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 92 |
+
else:
|
| 93 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 94 |
+
scores.append(self._calculate_score(image, prompt))
|
| 95 |
+
return scores
|
| 96 |
+
else:
|
| 97 |
+
raise TypeError("The type of parameter images is illegal.")
|
diffsynth/extensions/ImageQualityMetric/config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
+
project_root = os.path.abspath(os.path.join(current_dir, '../../../'))
|
| 5 |
+
model_path = os.path.join(project_root, 'models', 'QualityMetric')
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_model_path(model_name):
|
| 9 |
+
return os.path.join(model_path, model_name)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
MODEL_PATHS = {
|
| 13 |
+
"aesthetic_predictor": get_model_path("aesthetic-predictor/sac+logos+ava1-l14-linearMSE.safetensors"),
|
| 14 |
+
"open_clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"),
|
| 15 |
+
"hpsv2": get_model_path("HPS_v2/HPS_v2_compressed.safetensors"),
|
| 16 |
+
"hpsv2.1": get_model_path("HPS_v2/HPS_v2.1_compressed.safetensors"),
|
| 17 |
+
"imagereward": get_model_path("ImageReward/ImageReward.safetensors"),
|
| 18 |
+
"med_config": get_model_path("ImageReward/med_config.json"),
|
| 19 |
+
"clip": get_model_path("CLIP-ViT-H-14-laion2B-s32B-b79K"),
|
| 20 |
+
"clip-large": get_model_path("clip-vit-large-patch14"),
|
| 21 |
+
"mps": get_model_path("MPS_overall_checkpoint/MPS_overall_checkpoint_diffsynth.safetensors"),
|
| 22 |
+
"pickscore": get_model_path("PickScore_v1")
|
| 23 |
+
}
|
diffsynth/extensions/ImageQualityMetric/hps.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import torch
|
| 4 |
+
from .open_clip import create_model_and_transforms, get_tokenizer
|
| 5 |
+
from safetensors.torch import load_file
|
| 6 |
+
import os
|
| 7 |
+
from .config import MODEL_PATHS
|
| 8 |
+
|
| 9 |
+
class HPScore_v2(torch.nn.Module):
|
| 10 |
+
def __init__(self, device: torch.device, path: str = MODEL_PATHS, model_version: str = "v2"):
|
| 11 |
+
super().__init__()
|
| 12 |
+
"""Initialize the Selector with a model and tokenizer.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
device (torch.device): The device to load the model on.
|
| 16 |
+
model_version (str): The version of the model to load. Supports "v2" or "v21". Default is "v2".
|
| 17 |
+
"""
|
| 18 |
+
self.device = device
|
| 19 |
+
|
| 20 |
+
if model_version == "v2":
|
| 21 |
+
safetensors_path = path.get("hpsv2")
|
| 22 |
+
elif model_version == "v21":
|
| 23 |
+
safetensors_path = path.get("hpsv2.1")
|
| 24 |
+
else:
|
| 25 |
+
raise ValueError(f"Unsupported model version: {model_version}. Choose 'v2' or 'v21'.")
|
| 26 |
+
|
| 27 |
+
# Create model and transforms
|
| 28 |
+
model, _, self.preprocess_val = create_model_and_transforms(
|
| 29 |
+
"ViT-H-14",
|
| 30 |
+
# "laion2B-s32B-b79K",
|
| 31 |
+
pretrained=path.get("open_clip"),
|
| 32 |
+
precision="amp",
|
| 33 |
+
device=device,
|
| 34 |
+
jit=False,
|
| 35 |
+
force_quick_gelu=False,
|
| 36 |
+
force_custom_text=False,
|
| 37 |
+
force_patch_dropout=False,
|
| 38 |
+
force_image_size=None,
|
| 39 |
+
pretrained_image=False,
|
| 40 |
+
image_mean=None,
|
| 41 |
+
image_std=None,
|
| 42 |
+
light_augmentation=True,
|
| 43 |
+
aug_cfg={},
|
| 44 |
+
output_dict=True,
|
| 45 |
+
with_score_predictor=False,
|
| 46 |
+
with_region_predictor=False,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Load model weights
|
| 50 |
+
try:
|
| 51 |
+
state_dict = load_file(safetensors_path)
|
| 52 |
+
model.load_state_dict(state_dict)
|
| 53 |
+
except Exception as e:
|
| 54 |
+
raise ValueError(f"Error loading model weights from {safetensors_path}: {e}")
|
| 55 |
+
|
| 56 |
+
# Initialize tokenizer and model
|
| 57 |
+
self.tokenizer = get_tokenizer("ViT-H-14", path["open_clip_bpe"])
|
| 58 |
+
model = model.to(device)
|
| 59 |
+
model.eval()
|
| 60 |
+
self.model = model
|
| 61 |
+
|
| 62 |
+
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
| 63 |
+
"""Calculate the HPS score for a single image and prompt.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
image (torch.Tensor): The processed image tensor.
|
| 67 |
+
prompt (str): The prompt text.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
float: The HPS score.
|
| 71 |
+
"""
|
| 72 |
+
with torch.no_grad():
|
| 73 |
+
# Process the prompt
|
| 74 |
+
text = self.tokenizer([prompt]).to(device=self.device, non_blocking=True)
|
| 75 |
+
|
| 76 |
+
# Calculate the HPS score
|
| 77 |
+
outputs = self.model(image, text)
|
| 78 |
+
image_features, text_features = outputs["image_features"], outputs["text_features"]
|
| 79 |
+
logits_per_image = image_features @ text_features.T
|
| 80 |
+
hps_score = torch.diagonal(logits_per_image).cpu().numpy()
|
| 81 |
+
|
| 82 |
+
return hps_score[0].item()
|
| 83 |
+
|
| 84 |
+
@torch.no_grad()
|
| 85 |
+
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
| 86 |
+
"""Score the images based on the prompt.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
| 90 |
+
prompt (str): The prompt text.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
List[float]: List of HPS scores for the images.
|
| 94 |
+
"""
|
| 95 |
+
try:
|
| 96 |
+
if isinstance(images, (str, Image.Image)):
|
| 97 |
+
# Single image
|
| 98 |
+
if isinstance(images, str):
|
| 99 |
+
image = self.preprocess_val(Image.open(images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 100 |
+
else:
|
| 101 |
+
image = self.preprocess_val(images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 102 |
+
return [self._calculate_score(image, prompt)]
|
| 103 |
+
elif isinstance(images, list):
|
| 104 |
+
# Multiple images
|
| 105 |
+
scores = []
|
| 106 |
+
for one_images in images:
|
| 107 |
+
if isinstance(one_images, str):
|
| 108 |
+
image = self.preprocess_val(Image.open(one_images)).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 109 |
+
elif isinstance(one_images, Image.Image):
|
| 110 |
+
image = self.preprocess_val(one_images).unsqueeze(0).to(device=self.device, non_blocking=True)
|
| 111 |
+
else:
|
| 112 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 113 |
+
scores.append(self._calculate_score(image, prompt))
|
| 114 |
+
return scores
|
| 115 |
+
else:
|
| 116 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 117 |
+
except Exception as e:
|
| 118 |
+
raise RuntimeError(f"Error in scoring images: {e}")
|
diffsynth/extensions/ImageQualityMetric/imagereward.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 6 |
+
from .BLIP.blip_pretrain import BLIP_Pretrain
|
| 7 |
+
from torchvision.transforms import InterpolationMode
|
| 8 |
+
from safetensors.torch import load_file
|
| 9 |
+
from .config import MODEL_PATHS
|
| 10 |
+
BICUBIC = InterpolationMode.BICUBIC
|
| 11 |
+
|
| 12 |
+
def _convert_image_to_rgb(image):
|
| 13 |
+
return image.convert("RGB")
|
| 14 |
+
|
| 15 |
+
def _transform(n_px):
|
| 16 |
+
return Compose([
|
| 17 |
+
Resize(n_px, interpolation=BICUBIC),
|
| 18 |
+
CenterCrop(n_px),
|
| 19 |
+
_convert_image_to_rgb,
|
| 20 |
+
ToTensor(),
|
| 21 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
| 22 |
+
])
|
| 23 |
+
|
| 24 |
+
class MLP(torch.nn.Module):
|
| 25 |
+
def __init__(self, input_size):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.input_size = input_size
|
| 28 |
+
|
| 29 |
+
self.layers = torch.nn.Sequential(
|
| 30 |
+
torch.nn.Linear(self.input_size, 1024),
|
| 31 |
+
#nn.ReLU(),
|
| 32 |
+
torch.nn.Dropout(0.2),
|
| 33 |
+
torch.nn.Linear(1024, 128),
|
| 34 |
+
#nn.ReLU(),
|
| 35 |
+
torch.nn.Dropout(0.2),
|
| 36 |
+
torch.nn.Linear(128, 64),
|
| 37 |
+
#nn.ReLU(),
|
| 38 |
+
torch.nn.Dropout(0.1),
|
| 39 |
+
torch.nn.Linear(64, 16),
|
| 40 |
+
#nn.ReLU(),
|
| 41 |
+
torch.nn.Linear(16, 1)
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# initial MLP param
|
| 45 |
+
for name, param in self.layers.named_parameters():
|
| 46 |
+
if 'weight' in name:
|
| 47 |
+
torch.nn.init.normal_(param, mean=0.0, std=1.0/(self.input_size+1))
|
| 48 |
+
if 'bias' in name:
|
| 49 |
+
torch.nn.init.constant_(param, val=0)
|
| 50 |
+
|
| 51 |
+
def forward(self, input):
|
| 52 |
+
return self.layers(input)
|
| 53 |
+
|
| 54 |
+
class ImageReward(torch.nn.Module):
|
| 55 |
+
def __init__(self, med_config, device='cpu', bert_model_path=""):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.device = device
|
| 58 |
+
|
| 59 |
+
self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config, bert_model_path=bert_model_path)
|
| 60 |
+
self.preprocess = _transform(224)
|
| 61 |
+
self.mlp = MLP(768)
|
| 62 |
+
|
| 63 |
+
self.mean = 0.16717362830052426
|
| 64 |
+
self.std = 1.0333394966054072
|
| 65 |
+
|
| 66 |
+
def score_grad(self, prompt_ids, prompt_attention_mask, image):
|
| 67 |
+
"""Calculate the score with gradient for a single image and prompt.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
prompt_ids (torch.Tensor): Tokenized prompt IDs.
|
| 71 |
+
prompt_attention_mask (torch.Tensor): Attention mask for the prompt.
|
| 72 |
+
image (torch.Tensor): The processed image tensor.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
torch.Tensor: The reward score.
|
| 76 |
+
"""
|
| 77 |
+
image_embeds = self.blip.visual_encoder(image)
|
| 78 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
| 79 |
+
text_output = self.blip.text_encoder(
|
| 80 |
+
prompt_ids,
|
| 81 |
+
attention_mask=prompt_attention_mask,
|
| 82 |
+
encoder_hidden_states=image_embeds,
|
| 83 |
+
encoder_attention_mask=image_atts,
|
| 84 |
+
return_dict=True,
|
| 85 |
+
)
|
| 86 |
+
txt_features = text_output.last_hidden_state[:, 0, :]
|
| 87 |
+
rewards = self.mlp(txt_features)
|
| 88 |
+
rewards = (rewards - self.mean) / self.std
|
| 89 |
+
return rewards
|
| 90 |
+
|
| 91 |
+
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]:
|
| 92 |
+
"""Score the images based on the prompt.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
prompt (str): The prompt text.
|
| 96 |
+
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
List[float]: List of scores for the images.
|
| 100 |
+
"""
|
| 101 |
+
if isinstance(images, (str, Image.Image)):
|
| 102 |
+
# Single image
|
| 103 |
+
if isinstance(images, str):
|
| 104 |
+
pil_image = Image.open(images)
|
| 105 |
+
else:
|
| 106 |
+
pil_image = images
|
| 107 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
| 108 |
+
return [self._calculate_score(prompt, image).item()]
|
| 109 |
+
elif isinstance(images, list):
|
| 110 |
+
# Multiple images
|
| 111 |
+
scores = []
|
| 112 |
+
for one_image in images:
|
| 113 |
+
if isinstance(one_image, str):
|
| 114 |
+
pil_image = Image.open(one_image)
|
| 115 |
+
elif isinstance(one_image, Image.Image):
|
| 116 |
+
pil_image = one_image
|
| 117 |
+
else:
|
| 118 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 119 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
| 120 |
+
scores.append(self._calculate_score(prompt, image).item())
|
| 121 |
+
return scores
|
| 122 |
+
else:
|
| 123 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 124 |
+
|
| 125 |
+
def _calculate_score(self, prompt: str, image: torch.Tensor) -> torch.Tensor:
|
| 126 |
+
"""Calculate the score for a single image and prompt.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
prompt (str): The prompt text.
|
| 130 |
+
image (torch.Tensor): The processed image tensor.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
torch.Tensor: The reward score.
|
| 134 |
+
"""
|
| 135 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
| 136 |
+
image_embeds = self.blip.visual_encoder(image)
|
| 137 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
| 138 |
+
text_output = self.blip.text_encoder(
|
| 139 |
+
text_input.input_ids,
|
| 140 |
+
attention_mask=text_input.attention_mask,
|
| 141 |
+
encoder_hidden_states=image_embeds,
|
| 142 |
+
encoder_attention_mask=image_atts,
|
| 143 |
+
return_dict=True,
|
| 144 |
+
)
|
| 145 |
+
txt_features = text_output.last_hidden_state[:, 0, :].float()
|
| 146 |
+
rewards = self.mlp(txt_features)
|
| 147 |
+
rewards = (rewards - self.mean) / self.std
|
| 148 |
+
return rewards
|
| 149 |
+
|
| 150 |
+
def inference_rank(self, prompt: str, generations_list: List[Union[str, Image.Image]]) -> tuple:
|
| 151 |
+
"""Rank the images based on the prompt.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
prompt (str): The prompt text.
|
| 155 |
+
generations_list (List[Union[str, Image.Image]]): List of image paths or PIL images.
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
tuple: (indices, rewards) where indices are the ranks and rewards are the scores.
|
| 159 |
+
"""
|
| 160 |
+
text_input = self.blip.tokenizer(prompt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(self.device)
|
| 161 |
+
txt_set = []
|
| 162 |
+
for generation in generations_list:
|
| 163 |
+
if isinstance(generation, str):
|
| 164 |
+
pil_image = Image.open(generation)
|
| 165 |
+
elif isinstance(generation, Image.Image):
|
| 166 |
+
pil_image = generation
|
| 167 |
+
else:
|
| 168 |
+
raise TypeError("The type of parameter generations_list is illegal.")
|
| 169 |
+
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
| 170 |
+
image_embeds = self.blip.visual_encoder(image)
|
| 171 |
+
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.device)
|
| 172 |
+
text_output = self.blip.text_encoder(
|
| 173 |
+
text_input.input_ids,
|
| 174 |
+
attention_mask=text_input.attention_mask,
|
| 175 |
+
encoder_hidden_states=image_embeds,
|
| 176 |
+
encoder_attention_mask=image_atts,
|
| 177 |
+
return_dict=True,
|
| 178 |
+
)
|
| 179 |
+
txt_set.append(text_output.last_hidden_state[:, 0, :])
|
| 180 |
+
txt_features = torch.cat(txt_set, 0).float()
|
| 181 |
+
rewards = self.mlp(txt_features)
|
| 182 |
+
rewards = (rewards - self.mean) / self.std
|
| 183 |
+
rewards = torch.squeeze(rewards)
|
| 184 |
+
_, rank = torch.sort(rewards, dim=0, descending=True)
|
| 185 |
+
_, indices = torch.sort(rank, dim=0)
|
| 186 |
+
indices = indices + 1
|
| 187 |
+
return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class ImageRewardScore(torch.nn.Module):
|
| 191 |
+
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS):
|
| 192 |
+
super().__init__()
|
| 193 |
+
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
| 194 |
+
model_path = path.get("imagereward")
|
| 195 |
+
med_config = path.get("med_config")
|
| 196 |
+
state_dict = load_file(model_path)
|
| 197 |
+
self.model = ImageReward(device=self.device, med_config=med_config, bert_model_path=path.get("bert_model_path")).to(self.device)
|
| 198 |
+
self.model.load_state_dict(state_dict, strict=False)
|
| 199 |
+
self.model.eval()
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
| 203 |
+
"""Score the images based on the prompt.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
| 207 |
+
prompt (str): The prompt text.
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
List[float]: List of scores for the images.
|
| 211 |
+
"""
|
| 212 |
+
return self.model.score(images, prompt)
|
diffsynth/extensions/ImageQualityMetric/mps.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
from transformers import CLIPFeatureExtractor, CLIPImageProcessor
|
| 7 |
+
from transformers import CLIPConfig
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from transformers import CLIPModel as HFCLIPModel
|
| 10 |
+
from safetensors.torch import load_file
|
| 11 |
+
from torch import nn, einsum
|
| 12 |
+
|
| 13 |
+
from .trainer.models.base_model import BaseModelConfig
|
| 14 |
+
|
| 15 |
+
from transformers import CLIPConfig
|
| 16 |
+
from transformers import AutoProcessor, AutoModel, AutoTokenizer
|
| 17 |
+
from typing import Any, Optional, Tuple, Union, List
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from .trainer.models.cross_modeling import Cross_model
|
| 21 |
+
from .trainer.models import clip_model
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import gc
|
| 24 |
+
import json
|
| 25 |
+
from .config import MODEL_PATHS
|
| 26 |
+
|
| 27 |
+
class MPScore(torch.nn.Module):
|
| 28 |
+
def __init__(self, device: Union[str, torch.device], path: str = MODEL_PATHS, condition: str = 'overall'):
|
| 29 |
+
super().__init__()
|
| 30 |
+
"""Initialize the MPSModel with a processor, tokenizer, and model.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
device (Union[str, torch.device]): The device to load the model on.
|
| 34 |
+
"""
|
| 35 |
+
self.device = device
|
| 36 |
+
processor_name_or_path = path.get("clip")
|
| 37 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(processor_name_or_path)
|
| 38 |
+
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
|
| 39 |
+
self.model = clip_model.CLIPModel(processor_name_or_path, config_file=True)
|
| 40 |
+
state_dict = load_file(path.get("mps"))
|
| 41 |
+
self.model.load_state_dict(state_dict, strict=False)
|
| 42 |
+
self.model.to(device)
|
| 43 |
+
self.condition = condition
|
| 44 |
+
|
| 45 |
+
def _calculate_score(self, image: torch.Tensor, prompt: str) -> float:
|
| 46 |
+
"""Calculate the reward score for a single image and prompt.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
image (torch.Tensor): The processed image tensor.
|
| 50 |
+
prompt (str): The prompt text.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
float: The reward score.
|
| 54 |
+
"""
|
| 55 |
+
def _tokenize(caption):
|
| 56 |
+
input_ids = self.tokenizer(
|
| 57 |
+
caption,
|
| 58 |
+
max_length=self.tokenizer.model_max_length,
|
| 59 |
+
padding="max_length",
|
| 60 |
+
truncation=True,
|
| 61 |
+
return_tensors="pt"
|
| 62 |
+
).input_ids
|
| 63 |
+
return input_ids
|
| 64 |
+
|
| 65 |
+
text_input = _tokenize(prompt).to(self.device)
|
| 66 |
+
if self.condition == 'overall':
|
| 67 |
+
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things'
|
| 68 |
+
elif self.condition == 'aesthetics':
|
| 69 |
+
condition_prompt = 'light, color, clarity, tone, style, ambiance, artistry'
|
| 70 |
+
elif self.condition == 'quality':
|
| 71 |
+
condition_prompt = 'shape, face, hair, hands, limbs, structure, instance, texture'
|
| 72 |
+
elif self.condition == 'semantic':
|
| 73 |
+
condition_prompt = 'quantity, attributes, position, number, location'
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(
|
| 76 |
+
f"Unsupported condition: {self.condition}. Choose 'overall', 'aesthetics', 'quality', or 'semantic'.")
|
| 77 |
+
condition_batch = _tokenize(condition_prompt).repeat(text_input.shape[0], 1).to(self.device)
|
| 78 |
+
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
text_f, text_features = self.model.model.get_text_features(text_input)
|
| 81 |
+
|
| 82 |
+
image_f = self.model.model.get_image_features(image.half())
|
| 83 |
+
condition_f, _ = self.model.model.get_text_features(condition_batch)
|
| 84 |
+
|
| 85 |
+
sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
|
| 86 |
+
sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
|
| 87 |
+
sim_text_condition = sim_text_condition / sim_text_condition.max()
|
| 88 |
+
mask = torch.where(sim_text_condition > 0.3, 0, float('-inf'))
|
| 89 |
+
mask = mask.repeat(1, image_f.shape[1], 1)
|
| 90 |
+
image_features = self.model.cross_model(image_f, text_f, mask.half())[:, 0, :]
|
| 91 |
+
|
| 92 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
| 93 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
| 94 |
+
image_score = self.model.logit_scale.exp() * text_features @ image_features.T
|
| 95 |
+
|
| 96 |
+
return image_score[0].cpu().numpy().item()
|
| 97 |
+
|
| 98 |
+
@torch.no_grad()
|
| 99 |
+
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str) -> List[float]:
|
| 100 |
+
"""Score the images based on the prompt.
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s).
|
| 104 |
+
prompt (str): The prompt text.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
List[float]: List of reward scores for the images.
|
| 108 |
+
"""
|
| 109 |
+
if isinstance(images, (str, Image.Image)):
|
| 110 |
+
# Single image
|
| 111 |
+
if isinstance(images, str):
|
| 112 |
+
image = self.image_processor(Image.open(images), return_tensors="pt")["pixel_values"].to(self.device)
|
| 113 |
+
else:
|
| 114 |
+
image = self.image_processor(images, return_tensors="pt")["pixel_values"].to(self.device)
|
| 115 |
+
return [self._calculate_score(image, prompt)]
|
| 116 |
+
elif isinstance(images, list):
|
| 117 |
+
# Multiple images
|
| 118 |
+
scores = []
|
| 119 |
+
for one_images in images:
|
| 120 |
+
if isinstance(one_images, str):
|
| 121 |
+
image = self.image_processor(Image.open(one_images), return_tensors="pt")["pixel_values"].to(self.device)
|
| 122 |
+
elif isinstance(one_images, Image.Image):
|
| 123 |
+
image = self.image_processor(one_images, return_tensors="pt")["pixel_values"].to(self.device)
|
| 124 |
+
else:
|
| 125 |
+
raise TypeError("The type of parameter images is illegal.")
|
| 126 |
+
scores.append(self._calculate_score(image, prompt))
|
| 127 |
+
return scores
|
| 128 |
+
else:
|
| 129 |
+
raise TypeError("The type of parameter images is illegal.")
|
diffsynth/extensions/ImageQualityMetric/open_clip/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .coca_model import CoCa
|
| 2 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
| 3 |
+
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
|
| 4 |
+
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
| 5 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
| 6 |
+
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
|
| 7 |
+
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
| 8 |
+
from .openai import load_openai_model, list_openai_models
|
| 9 |
+
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
|
| 10 |
+
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
| 11 |
+
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
|
| 12 |
+
from .tokenizer import SimpleTokenizer
|
| 13 |
+
from .transform import image_transform, AugmentationCfg
|
| 14 |
+
from .utils import freeze_batch_norm_2d
|
diffsynth/extensions/ImageQualityMetric/open_clip/coca_model.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
import numpy as np
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
|
| 9 |
+
from .transformer import (
|
| 10 |
+
LayerNormFp32,
|
| 11 |
+
LayerNorm,
|
| 12 |
+
QuickGELU,
|
| 13 |
+
MultimodalTransformer,
|
| 14 |
+
)
|
| 15 |
+
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from transformers import (
|
| 19 |
+
BeamSearchScorer,
|
| 20 |
+
LogitsProcessorList,
|
| 21 |
+
TopPLogitsWarper,
|
| 22 |
+
TopKLogitsWarper,
|
| 23 |
+
RepetitionPenaltyLogitsProcessor,
|
| 24 |
+
MinLengthLogitsProcessor,
|
| 25 |
+
MaxLengthCriteria,
|
| 26 |
+
StoppingCriteriaList
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
GENERATION_TYPES = {
|
| 30 |
+
"top_k": TopKLogitsWarper,
|
| 31 |
+
"top_p": TopPLogitsWarper,
|
| 32 |
+
"beam_search": "beam_search"
|
| 33 |
+
}
|
| 34 |
+
_has_transformers = True
|
| 35 |
+
except ImportError as e:
|
| 36 |
+
GENERATION_TYPES = {
|
| 37 |
+
"top_k": None,
|
| 38 |
+
"top_p": None,
|
| 39 |
+
"beam_search": "beam_search"
|
| 40 |
+
}
|
| 41 |
+
_has_transformers = False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
|
| 45 |
+
class MultimodalCfg(CLIPTextCfg):
|
| 46 |
+
mlp_ratio: int = 4
|
| 47 |
+
dim_head: int = 64
|
| 48 |
+
heads: int = 8
|
| 49 |
+
n_queries: int = 256
|
| 50 |
+
attn_pooler_heads: int = 8
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _build_text_decoder_tower(
|
| 54 |
+
embed_dim,
|
| 55 |
+
multimodal_cfg,
|
| 56 |
+
quick_gelu: bool = False,
|
| 57 |
+
cast_dtype: Optional[torch.dtype] = None,
|
| 58 |
+
):
|
| 59 |
+
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
| 60 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
| 61 |
+
norm_layer = (
|
| 62 |
+
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
decoder = MultimodalTransformer(
|
| 66 |
+
context_length=multimodal_cfg.context_length,
|
| 67 |
+
width=multimodal_cfg.width,
|
| 68 |
+
heads=multimodal_cfg.heads,
|
| 69 |
+
layers=multimodal_cfg.layers,
|
| 70 |
+
ls_init_value=multimodal_cfg.ls_init_value,
|
| 71 |
+
output_dim=embed_dim,
|
| 72 |
+
act_layer=act_layer,
|
| 73 |
+
norm_layer=norm_layer,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
return decoder
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CoCa(nn.Module):
|
| 80 |
+
def __init__(
|
| 81 |
+
self,
|
| 82 |
+
embed_dim,
|
| 83 |
+
multimodal_cfg: MultimodalCfg,
|
| 84 |
+
text_cfg: CLIPTextCfg,
|
| 85 |
+
vision_cfg: CLIPVisionCfg,
|
| 86 |
+
quick_gelu: bool = False,
|
| 87 |
+
cast_dtype: Optional[torch.dtype] = None,
|
| 88 |
+
pad_id: int = 0,
|
| 89 |
+
):
|
| 90 |
+
super().__init__()
|
| 91 |
+
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
|
| 92 |
+
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
|
| 93 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
|
| 94 |
+
|
| 95 |
+
self.text = _build_text_tower(
|
| 96 |
+
embed_dim=embed_dim,
|
| 97 |
+
text_cfg=text_cfg,
|
| 98 |
+
quick_gelu=quick_gelu,
|
| 99 |
+
cast_dtype=cast_dtype,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
vocab_size = (
|
| 103 |
+
text_cfg.vocab_size # for hf models
|
| 104 |
+
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
|
| 105 |
+
else text_cfg.vocab_size
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
self.visual = _build_vision_tower(
|
| 109 |
+
embed_dim=embed_dim,
|
| 110 |
+
vision_cfg=vision_cfg,
|
| 111 |
+
quick_gelu=quick_gelu,
|
| 112 |
+
cast_dtype=cast_dtype,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
self.text_decoder = _build_text_decoder_tower(
|
| 116 |
+
vocab_size,
|
| 117 |
+
multimodal_cfg=multimodal_cfg,
|
| 118 |
+
quick_gelu=quick_gelu,
|
| 119 |
+
cast_dtype=cast_dtype,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 123 |
+
self.pad_id = pad_id
|
| 124 |
+
|
| 125 |
+
@torch.jit.ignore
|
| 126 |
+
def set_grad_checkpointing(self, enable=True):
|
| 127 |
+
self.visual.set_grad_checkpointing(enable)
|
| 128 |
+
self.text.set_grad_checkpointing(enable)
|
| 129 |
+
self.text_decoder.set_grad_checkpointing(enable)
|
| 130 |
+
|
| 131 |
+
def _encode_image(self, images, normalize=True):
|
| 132 |
+
image_latent, tokens_embs = self.visual(images)
|
| 133 |
+
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
|
| 134 |
+
return image_latent, tokens_embs
|
| 135 |
+
|
| 136 |
+
def _encode_text(self, text, normalize=True, embed_cls=True):
|
| 137 |
+
text = text[:, :-1] if embed_cls else text # make space for CLS token
|
| 138 |
+
text_latent, token_emb = self.text(text)
|
| 139 |
+
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
|
| 140 |
+
return text_latent, token_emb
|
| 141 |
+
|
| 142 |
+
def encode_image(self, images, normalize=True):
|
| 143 |
+
image_latent, _ = self._encode_image(images, normalize=normalize)
|
| 144 |
+
return image_latent
|
| 145 |
+
|
| 146 |
+
def encode_text(self, text, normalize=True, embed_cls=True):
|
| 147 |
+
text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
|
| 148 |
+
return text_latent
|
| 149 |
+
|
| 150 |
+
def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
|
| 151 |
+
text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
|
| 152 |
+
if image_latent is None or image_embs is None:
|
| 153 |
+
image_latent, image_embs = self._encode_image(image)
|
| 154 |
+
|
| 155 |
+
# TODO: add assertion to avoid bugs?
|
| 156 |
+
labels = text[:, -token_embs.shape[1]:]
|
| 157 |
+
|
| 158 |
+
logits = self.text_decoder(image_embs, token_embs)
|
| 159 |
+
return {
|
| 160 |
+
"image_features": image_latent,
|
| 161 |
+
"text_features": text_latent,
|
| 162 |
+
"logits": logits,
|
| 163 |
+
"labels": labels,
|
| 164 |
+
"logit_scale": self.logit_scale.exp()
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
def generate(
|
| 168 |
+
self,
|
| 169 |
+
image,
|
| 170 |
+
text=None,
|
| 171 |
+
seq_len=30,
|
| 172 |
+
max_seq_len=77,
|
| 173 |
+
temperature=1.,
|
| 174 |
+
generation_type="beam_search",
|
| 175 |
+
top_p=0.1, # keep tokens in the 1 - top_p quantile
|
| 176 |
+
top_k=1, # keeps the top_k most probable tokens
|
| 177 |
+
pad_token_id=None,
|
| 178 |
+
eos_token_id=None,
|
| 179 |
+
sot_token_id=None,
|
| 180 |
+
num_beams=6,
|
| 181 |
+
num_beam_groups=3,
|
| 182 |
+
min_seq_len=5,
|
| 183 |
+
stopping_criteria=None,
|
| 184 |
+
repetition_penalty=1.0,
|
| 185 |
+
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
|
| 186 |
+
):
|
| 187 |
+
# taking many ideas and components from HuggingFace GenerationMixin
|
| 188 |
+
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
|
| 189 |
+
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
|
| 190 |
+
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
|
| 191 |
+
|
| 192 |
+
with torch.no_grad():
|
| 193 |
+
sot_token_id = 49406 if sot_token_id is None else sot_token_id
|
| 194 |
+
eos_token_id = 49407 if eos_token_id is None else eos_token_id
|
| 195 |
+
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
|
| 196 |
+
logit_processor = LogitsProcessorList(
|
| 197 |
+
[
|
| 198 |
+
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
|
| 199 |
+
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
| 200 |
+
]
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
if stopping_criteria is None:
|
| 204 |
+
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
|
| 205 |
+
|
| 206 |
+
stopping_criteria = StoppingCriteriaList(
|
| 207 |
+
stopping_criteria
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
device = image.device
|
| 211 |
+
|
| 212 |
+
if generation_type == "beam_search":
|
| 213 |
+
output = self._generate_beamsearch(
|
| 214 |
+
image_inputs = image,
|
| 215 |
+
pad_token_id=pad_token_id,
|
| 216 |
+
eos_token_id=eos_token_id,
|
| 217 |
+
sot_token_id=sot_token_id,
|
| 218 |
+
num_beams=num_beams,
|
| 219 |
+
num_beam_groups=num_beam_groups,
|
| 220 |
+
min_seq_len=min_seq_len,
|
| 221 |
+
stopping_criteria=stopping_criteria,
|
| 222 |
+
logit_processor=logit_processor,
|
| 223 |
+
)
|
| 224 |
+
if fixed_output_length and output.shape[1] < seq_len:
|
| 225 |
+
return torch.cat(
|
| 226 |
+
(output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
|
| 227 |
+
dim=1
|
| 228 |
+
)
|
| 229 |
+
return output
|
| 230 |
+
|
| 231 |
+
elif generation_type == "top_p":
|
| 232 |
+
logit_warper = GENERATION_TYPES[generation_type](top_p)
|
| 233 |
+
elif generation_type == "top_k":
|
| 234 |
+
logit_warper = GENERATION_TYPES[generation_type](top_k)
|
| 235 |
+
else:
|
| 236 |
+
raise ValueError(
|
| 237 |
+
f"generation_type has to be one of "
|
| 238 |
+
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
image_latent, image_embs = self._encode_image(image)
|
| 242 |
+
|
| 243 |
+
if text is None:
|
| 244 |
+
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
|
| 245 |
+
|
| 246 |
+
was_training = self.training
|
| 247 |
+
num_dims = len(text.shape)
|
| 248 |
+
|
| 249 |
+
if num_dims == 1:
|
| 250 |
+
text = text[None, :]
|
| 251 |
+
|
| 252 |
+
cur_len = text.shape[1]
|
| 253 |
+
self.eval()
|
| 254 |
+
out = text
|
| 255 |
+
|
| 256 |
+
while True:
|
| 257 |
+
x = out[:, -max_seq_len:]
|
| 258 |
+
cur_len = x.shape[1]
|
| 259 |
+
logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
|
| 260 |
+
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
|
| 261 |
+
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
|
| 262 |
+
|
| 263 |
+
if mask.all():
|
| 264 |
+
if not fixed_output_length:
|
| 265 |
+
break
|
| 266 |
+
else:
|
| 267 |
+
logits = logits[~mask, :]
|
| 268 |
+
filtered_logits = logit_processor(x[~mask, :], logits)
|
| 269 |
+
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
|
| 270 |
+
probs = F.softmax(filtered_logits / temperature, dim=-1)
|
| 271 |
+
|
| 272 |
+
if (cur_len + 1 == seq_len):
|
| 273 |
+
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
|
| 274 |
+
else:
|
| 275 |
+
sample[~mask, :] = torch.multinomial(probs, 1)
|
| 276 |
+
|
| 277 |
+
out = torch.cat((out, sample), dim=-1)
|
| 278 |
+
|
| 279 |
+
cur_len += 1
|
| 280 |
+
|
| 281 |
+
if stopping_criteria(out, None):
|
| 282 |
+
break
|
| 283 |
+
|
| 284 |
+
if num_dims == 1:
|
| 285 |
+
out = out.squeeze(0)
|
| 286 |
+
|
| 287 |
+
self.train(was_training)
|
| 288 |
+
return out
|
| 289 |
+
|
| 290 |
+
def _generate_beamsearch(
|
| 291 |
+
self,
|
| 292 |
+
image_inputs,
|
| 293 |
+
pad_token_id=None,
|
| 294 |
+
eos_token_id=None,
|
| 295 |
+
sot_token_id=None,
|
| 296 |
+
num_beams=6,
|
| 297 |
+
num_beam_groups=3,
|
| 298 |
+
min_seq_len=5,
|
| 299 |
+
stopping_criteria=None,
|
| 300 |
+
logit_processor=None,
|
| 301 |
+
logit_warper=None,
|
| 302 |
+
):
|
| 303 |
+
device = image_inputs.device
|
| 304 |
+
batch_size = image_inputs.shape[0]
|
| 305 |
+
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
|
| 306 |
+
image_latent, image_embs = self._encode_image(image_inputs)
|
| 307 |
+
|
| 308 |
+
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
|
| 309 |
+
input_ids = input_ids * sot_token_id
|
| 310 |
+
beam_scorer = BeamSearchScorer(
|
| 311 |
+
batch_size=batch_size,
|
| 312 |
+
num_beams=num_beams,
|
| 313 |
+
device=device,
|
| 314 |
+
num_beam_groups=num_beam_groups,
|
| 315 |
+
)
|
| 316 |
+
# instantiate logits processors
|
| 317 |
+
logits_processor = (
|
| 318 |
+
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
|
| 319 |
+
if logit_processor is None
|
| 320 |
+
else logit_processor
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
batch_size = len(beam_scorer._beam_hyps)
|
| 324 |
+
num_beams = beam_scorer.num_beams
|
| 325 |
+
num_beam_groups = beam_scorer.num_beam_groups
|
| 326 |
+
num_sub_beams = num_beams // num_beam_groups
|
| 327 |
+
batch_beam_size, cur_len = input_ids.shape
|
| 328 |
+
beam_indices = None
|
| 329 |
+
|
| 330 |
+
if num_beams * batch_size != batch_beam_size:
|
| 331 |
+
raise ValueError(
|
| 332 |
+
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
| 336 |
+
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
| 337 |
+
# the same group don't produce same tokens everytime.
|
| 338 |
+
beam_scores[:, ::num_sub_beams] = 0
|
| 339 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
| 340 |
+
|
| 341 |
+
while True:
|
| 342 |
+
|
| 343 |
+
# predicted tokens in cur_len step
|
| 344 |
+
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
| 345 |
+
|
| 346 |
+
# indices which will form the beams in the next time step
|
| 347 |
+
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
|
| 348 |
+
|
| 349 |
+
# do one decoder step on all beams of all sentences in batch
|
| 350 |
+
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
|
| 351 |
+
outputs = self(
|
| 352 |
+
model_inputs['images'],
|
| 353 |
+
model_inputs['text'],
|
| 354 |
+
embed_cls=False,
|
| 355 |
+
image_latent=image_latent,
|
| 356 |
+
image_embs=image_embs
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
for beam_group_idx in range(num_beam_groups):
|
| 360 |
+
group_start_idx = beam_group_idx * num_sub_beams
|
| 361 |
+
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
| 362 |
+
group_size = group_end_idx - group_start_idx
|
| 363 |
+
|
| 364 |
+
# indices of beams of current group among all sentences in batch
|
| 365 |
+
batch_group_indices = []
|
| 366 |
+
|
| 367 |
+
for batch_idx in range(batch_size):
|
| 368 |
+
batch_group_indices.extend(
|
| 369 |
+
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
|
| 370 |
+
)
|
| 371 |
+
group_input_ids = input_ids[batch_group_indices]
|
| 372 |
+
|
| 373 |
+
# select outputs of beams of currentg group only
|
| 374 |
+
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
|
| 375 |
+
vocab_size = next_token_logits.shape[-1]
|
| 376 |
+
|
| 377 |
+
next_token_scores_processed = logits_processor(
|
| 378 |
+
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
|
| 379 |
+
)
|
| 380 |
+
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
|
| 381 |
+
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
| 382 |
+
|
| 383 |
+
# reshape for beam search
|
| 384 |
+
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
|
| 385 |
+
|
| 386 |
+
next_token_scores, next_tokens = torch.topk(
|
| 387 |
+
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
| 391 |
+
next_tokens = next_tokens % vocab_size
|
| 392 |
+
|
| 393 |
+
# stateless
|
| 394 |
+
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
| 395 |
+
beam_outputs = beam_scorer.process(
|
| 396 |
+
group_input_ids,
|
| 397 |
+
next_token_scores,
|
| 398 |
+
next_tokens,
|
| 399 |
+
next_indices,
|
| 400 |
+
pad_token_id=pad_token_id,
|
| 401 |
+
eos_token_id=eos_token_id,
|
| 402 |
+
beam_indices=process_beam_indices,
|
| 403 |
+
)
|
| 404 |
+
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
| 405 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
| 406 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
| 407 |
+
|
| 408 |
+
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
| 409 |
+
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
| 410 |
+
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
| 411 |
+
|
| 412 |
+
# (beam_idx // group_size) -> batch_idx
|
| 413 |
+
# (beam_idx % group_size) -> offset of idx inside the group
|
| 414 |
+
reordering_indices[batch_group_indices] = (
|
| 415 |
+
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
| 419 |
+
|
| 420 |
+
# increase cur_len
|
| 421 |
+
cur_len = cur_len + 1
|
| 422 |
+
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
| 423 |
+
break
|
| 424 |
+
|
| 425 |
+
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
| 426 |
+
sequence_outputs = beam_scorer.finalize(
|
| 427 |
+
input_ids,
|
| 428 |
+
beam_scores,
|
| 429 |
+
next_tokens,
|
| 430 |
+
next_indices,
|
| 431 |
+
pad_token_id=pad_token_id,
|
| 432 |
+
eos_token_id=eos_token_id,
|
| 433 |
+
max_length=stopping_criteria.max_length,
|
| 434 |
+
beam_indices=final_beam_indices,
|
| 435 |
+
)
|
| 436 |
+
return sequence_outputs['sequences']
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
|
| 440 |
+
if past:
|
| 441 |
+
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 442 |
+
|
| 443 |
+
attention_mask = kwargs.get("attention_mask", None)
|
| 444 |
+
position_ids = kwargs.get("position_ids", None)
|
| 445 |
+
|
| 446 |
+
if attention_mask is not None and position_ids is None:
|
| 447 |
+
# create position_ids on the fly for batch generation
|
| 448 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 449 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 450 |
+
else:
|
| 451 |
+
position_ids = None
|
| 452 |
+
return {
|
| 453 |
+
"text": input_ids,
|
| 454 |
+
"images": image_inputs,
|
| 455 |
+
"past_key_values": past,
|
| 456 |
+
"position_ids": position_ids,
|
| 457 |
+
"attention_mask": attention_mask,
|
| 458 |
+
}
|
diffsynth/extensions/ImageQualityMetric/open_clip/constants.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 2 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
diffsynth/extensions/ImageQualityMetric/open_clip/factory.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
import re
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
# from turtle import forward
|
| 9 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
| 14 |
+
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
| 15 |
+
resize_pos_embed, get_cast_dtype
|
| 16 |
+
from .coca_model import CoCa
|
| 17 |
+
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
|
| 18 |
+
from .openai import load_openai_model
|
| 19 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
|
| 20 |
+
from .transform import image_transform, AugmentationCfg
|
| 21 |
+
from .tokenizer import HFTokenizer, SimpleTokenizer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
HF_HUB_PREFIX = 'hf-hub:'
|
| 25 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
| 26 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _natural_key(string_):
|
| 30 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _rescan_model_configs():
|
| 34 |
+
global _MODEL_CONFIGS
|
| 35 |
+
|
| 36 |
+
config_ext = ('.json',)
|
| 37 |
+
config_files = []
|
| 38 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
| 39 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
| 40 |
+
config_files.append(config_path)
|
| 41 |
+
elif config_path.is_dir():
|
| 42 |
+
for ext in config_ext:
|
| 43 |
+
config_files.extend(config_path.glob(f'*{ext}'))
|
| 44 |
+
|
| 45 |
+
for cf in config_files:
|
| 46 |
+
with open(cf, 'r') as f:
|
| 47 |
+
model_cfg = json.load(f)
|
| 48 |
+
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
| 49 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
| 50 |
+
|
| 51 |
+
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
_rescan_model_configs() # initial populate of model config registry
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def list_models():
|
| 58 |
+
""" enumerate available model architectures based on config files """
|
| 59 |
+
return list(_MODEL_CONFIGS.keys())
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def add_model_config(path):
|
| 63 |
+
""" add model config path or file and update registry """
|
| 64 |
+
if not isinstance(path, Path):
|
| 65 |
+
path = Path(path)
|
| 66 |
+
_MODEL_CONFIG_PATHS.append(path)
|
| 67 |
+
_rescan_model_configs()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_model_config(model_name):
|
| 71 |
+
if model_name in _MODEL_CONFIGS:
|
| 72 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
| 73 |
+
else:
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_tokenizer(model_name, open_clip_bpe_path=None):
|
| 78 |
+
if model_name.startswith(HF_HUB_PREFIX):
|
| 79 |
+
tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
|
| 80 |
+
else:
|
| 81 |
+
config = get_model_config(model_name)
|
| 82 |
+
tokenizer = HFTokenizer(
|
| 83 |
+
config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else SimpleTokenizer(open_clip_bpe_path)
|
| 84 |
+
return tokenizer
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def load_state_dict(checkpoint_path: str, map_location='cpu'):
|
| 88 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
| 89 |
+
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
| 90 |
+
state_dict = checkpoint['state_dict']
|
| 91 |
+
else:
|
| 92 |
+
state_dict = checkpoint
|
| 93 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
| 94 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
| 95 |
+
return state_dict
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
|
| 99 |
+
state_dict = load_state_dict(checkpoint_path)
|
| 100 |
+
# detect old format and make compatible with new format
|
| 101 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
| 102 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
| 103 |
+
resize_pos_embed(state_dict, model)
|
| 104 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
| 105 |
+
return incompatible_keys
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def create_model(
|
| 109 |
+
model_name: str,
|
| 110 |
+
pretrained: Optional[str] = None,
|
| 111 |
+
precision: str = 'fp32',
|
| 112 |
+
device: Union[str, torch.device] = 'cpu',
|
| 113 |
+
jit: bool = False,
|
| 114 |
+
force_quick_gelu: bool = False,
|
| 115 |
+
force_custom_text: bool = False,
|
| 116 |
+
force_patch_dropout: Optional[float] = None,
|
| 117 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
| 118 |
+
pretrained_image: bool = False,
|
| 119 |
+
pretrained_hf: bool = True,
|
| 120 |
+
cache_dir: Optional[str] = None,
|
| 121 |
+
output_dict: Optional[bool] = None,
|
| 122 |
+
require_pretrained: bool = False,
|
| 123 |
+
):
|
| 124 |
+
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
|
| 125 |
+
if has_hf_hub_prefix:
|
| 126 |
+
model_id = model_name[len(HF_HUB_PREFIX):]
|
| 127 |
+
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
| 128 |
+
config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
|
| 129 |
+
|
| 130 |
+
with open(config_path, 'r', encoding='utf-8') as f:
|
| 131 |
+
config = json.load(f)
|
| 132 |
+
pretrained_cfg = config['preprocess_cfg']
|
| 133 |
+
model_cfg = config['model_cfg']
|
| 134 |
+
else:
|
| 135 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
| 136 |
+
checkpoint_path = None
|
| 137 |
+
pretrained_cfg = {}
|
| 138 |
+
model_cfg = None
|
| 139 |
+
|
| 140 |
+
if isinstance(device, str):
|
| 141 |
+
device = torch.device(device)
|
| 142 |
+
|
| 143 |
+
if pretrained and pretrained.lower() == 'openai':
|
| 144 |
+
logging.info(f'Loading pretrained {model_name} from OpenAI.')
|
| 145 |
+
model = load_openai_model(
|
| 146 |
+
model_name,
|
| 147 |
+
precision=precision,
|
| 148 |
+
device=device,
|
| 149 |
+
jit=jit,
|
| 150 |
+
cache_dir=cache_dir,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# to always output dict even if it is clip
|
| 154 |
+
if output_dict and hasattr(model, "output_dict"):
|
| 155 |
+
model.output_dict = True
|
| 156 |
+
else:
|
| 157 |
+
model_cfg = model_cfg or get_model_config(model_name)
|
| 158 |
+
if model_cfg is not None:
|
| 159 |
+
logging.info(f'Loaded {model_name} model config.')
|
| 160 |
+
else:
|
| 161 |
+
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
| 162 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
| 163 |
+
|
| 164 |
+
if force_quick_gelu:
|
| 165 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
| 166 |
+
model_cfg["quick_gelu"] = True
|
| 167 |
+
|
| 168 |
+
if force_patch_dropout is not None:
|
| 169 |
+
# override the default patch dropout value
|
| 170 |
+
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
|
| 171 |
+
|
| 172 |
+
if force_image_size is not None:
|
| 173 |
+
# override model config's image size
|
| 174 |
+
model_cfg["vision_cfg"]["image_size"] = force_image_size
|
| 175 |
+
|
| 176 |
+
if pretrained_image:
|
| 177 |
+
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
| 178 |
+
# pretrained weight loading for timm models set via vision_cfg
|
| 179 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
| 180 |
+
else:
|
| 181 |
+
assert False, 'pretrained image towers currently only supported for timm models'
|
| 182 |
+
|
| 183 |
+
cast_dtype = get_cast_dtype(precision)
|
| 184 |
+
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
|
| 185 |
+
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
|
| 186 |
+
|
| 187 |
+
if custom_text:
|
| 188 |
+
if is_hf_model:
|
| 189 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
| 190 |
+
if "coca" in model_name:
|
| 191 |
+
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
|
| 192 |
+
else:
|
| 193 |
+
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
|
| 194 |
+
else:
|
| 195 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
| 196 |
+
|
| 197 |
+
pretrained_loaded = False
|
| 198 |
+
if pretrained:
|
| 199 |
+
checkpoint_path = ''
|
| 200 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
| 201 |
+
if pretrained_cfg:
|
| 202 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
| 203 |
+
elif os.path.exists(pretrained):
|
| 204 |
+
checkpoint_path = pretrained
|
| 205 |
+
|
| 206 |
+
if checkpoint_path:
|
| 207 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
| 208 |
+
load_checkpoint(model, checkpoint_path)
|
| 209 |
+
else:
|
| 210 |
+
error_str = (
|
| 211 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
| 212 |
+
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
| 213 |
+
logging.warning(error_str)
|
| 214 |
+
raise RuntimeError(error_str)
|
| 215 |
+
pretrained_loaded = True
|
| 216 |
+
elif has_hf_hub_prefix:
|
| 217 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
| 218 |
+
load_checkpoint(model, checkpoint_path)
|
| 219 |
+
pretrained_loaded = True
|
| 220 |
+
|
| 221 |
+
if require_pretrained and not pretrained_loaded:
|
| 222 |
+
# callers of create_model_from_pretrained always expect pretrained weights
|
| 223 |
+
raise RuntimeError(
|
| 224 |
+
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
|
| 225 |
+
|
| 226 |
+
model.to(device=device)
|
| 227 |
+
if precision in ("fp16", "bf16"):
|
| 228 |
+
convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)
|
| 229 |
+
|
| 230 |
+
# set image / mean metadata from pretrained_cfg if available, or use default
|
| 231 |
+
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
| 232 |
+
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
| 233 |
+
|
| 234 |
+
# to always output dict even if it is clip
|
| 235 |
+
if output_dict and hasattr(model, "output_dict"):
|
| 236 |
+
model.output_dict = True
|
| 237 |
+
|
| 238 |
+
if jit:
|
| 239 |
+
model = torch.jit.script(model)
|
| 240 |
+
|
| 241 |
+
return model
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def create_loss(args):
|
| 245 |
+
if args.distill:
|
| 246 |
+
return DistillClipLoss(
|
| 247 |
+
local_loss=args.local_loss,
|
| 248 |
+
gather_with_grad=args.gather_with_grad,
|
| 249 |
+
cache_labels=True,
|
| 250 |
+
rank=args.rank,
|
| 251 |
+
world_size=args.world_size,
|
| 252 |
+
use_horovod=args.horovod,
|
| 253 |
+
)
|
| 254 |
+
elif "coca" in args.model.lower():
|
| 255 |
+
return CoCaLoss(
|
| 256 |
+
caption_loss_weight=args.coca_caption_loss_weight,
|
| 257 |
+
clip_loss_weight=args.coca_contrastive_loss_weight,
|
| 258 |
+
local_loss=args.local_loss,
|
| 259 |
+
gather_with_grad=args.gather_with_grad,
|
| 260 |
+
cache_labels=True,
|
| 261 |
+
rank=args.rank,
|
| 262 |
+
world_size=args.world_size,
|
| 263 |
+
use_horovod=args.horovod,
|
| 264 |
+
)
|
| 265 |
+
return ClipLoss(
|
| 266 |
+
local_loss=args.local_loss,
|
| 267 |
+
gather_with_grad=args.gather_with_grad,
|
| 268 |
+
cache_labels=True,
|
| 269 |
+
rank=args.rank,
|
| 270 |
+
world_size=args.world_size,
|
| 271 |
+
use_horovod=args.horovod,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
class MLP(torch.nn.Module):
|
| 275 |
+
def __init__(self, input_size):
|
| 276 |
+
super().__init__()
|
| 277 |
+
self.input_size = input_size
|
| 278 |
+
self.layers = torch.nn.Sequential(
|
| 279 |
+
torch.nn.Linear(self.input_size, 1024),
|
| 280 |
+
torch.nn.Dropout(0.2),
|
| 281 |
+
torch.nn.Linear(1024, 128),
|
| 282 |
+
torch.nn.Dropout(0.2),
|
| 283 |
+
torch.nn.Linear(128, 64),
|
| 284 |
+
torch.nn.Dropout(0.1),
|
| 285 |
+
torch.nn.Linear(64, 16),
|
| 286 |
+
torch.nn.Linear(16, 1)
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
def forward(self, x):
|
| 290 |
+
return self.layers(x)
|
| 291 |
+
|
| 292 |
+
# class semantic_head(torch.nn.Module):
|
| 293 |
+
# def __init__(self, input_size):
|
| 294 |
+
# super().__init__()
|
| 295 |
+
# self.input_size = input_size # for ViT-L-14 is 1024
|
| 296 |
+
# self.seg_head = torch.nn.Sequential(
|
| 297 |
+
# torch.nn.Linear(input_size, 128),
|
| 298 |
+
# torch.nn.Dropout(0.2),
|
| 299 |
+
# torch.nn.Linear(128, 64),
|
| 300 |
+
# torch.nn.Dropout(0.1),
|
| 301 |
+
# torch.nn.Linear(64, 16),
|
| 302 |
+
# torch.nn.Linear(16, 1),
|
| 303 |
+
# )
|
| 304 |
+
# self.sigmoid = torch.nn.Sigmoid()
|
| 305 |
+
|
| 306 |
+
# def forward(self, x):
|
| 307 |
+
# return self.sigmoid(self.seg_head(x))
|
| 308 |
+
|
| 309 |
+
def create_model_and_transforms(
|
| 310 |
+
model_name: str,
|
| 311 |
+
pretrained: Optional[str] = None,
|
| 312 |
+
precision: str = 'fp32',
|
| 313 |
+
device: Union[str, torch.device] = 'cpu',
|
| 314 |
+
jit: bool = False,
|
| 315 |
+
force_quick_gelu: bool = False,
|
| 316 |
+
force_custom_text: bool = False,
|
| 317 |
+
force_patch_dropout: Optional[float] = None,
|
| 318 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
| 319 |
+
pretrained_image: bool = False,
|
| 320 |
+
pretrained_hf: bool = True,
|
| 321 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
| 322 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
| 323 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
| 324 |
+
cache_dir: Optional[str] = None,
|
| 325 |
+
light_augmentation = False,
|
| 326 |
+
output_dict: Optional[bool] = None,
|
| 327 |
+
with_score_predictor: bool = False,
|
| 328 |
+
with_region_predictor: bool = False
|
| 329 |
+
):
|
| 330 |
+
model = create_model(
|
| 331 |
+
model_name,
|
| 332 |
+
pretrained,
|
| 333 |
+
precision=precision,
|
| 334 |
+
device=device,
|
| 335 |
+
jit=jit,
|
| 336 |
+
force_quick_gelu=force_quick_gelu,
|
| 337 |
+
force_custom_text=force_custom_text,
|
| 338 |
+
force_patch_dropout=force_patch_dropout,
|
| 339 |
+
force_image_size=force_image_size,
|
| 340 |
+
pretrained_image=pretrained_image,
|
| 341 |
+
pretrained_hf=pretrained_hf,
|
| 342 |
+
cache_dir=cache_dir,
|
| 343 |
+
output_dict=output_dict,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
| 347 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
| 348 |
+
|
| 349 |
+
if with_score_predictor:
|
| 350 |
+
model.score_predictor = MLP(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
| 351 |
+
|
| 352 |
+
if with_region_predictor:
|
| 353 |
+
# model.region_predictor = semantic_head(model.visual.proj.size(1)).to(device=device, dtype=model.visual.proj.dtype)
|
| 354 |
+
model.region_predictor = torch.nn.Linear(model.visual.proj.size(0), 1).to(device=device, dtype=model.visual.proj.dtype)
|
| 355 |
+
# preprocess_train = image_transform_region(
|
| 356 |
+
# model.visual.image_size,
|
| 357 |
+
# is_train=True,
|
| 358 |
+
# mean=image_mean,
|
| 359 |
+
# std=image_std
|
| 360 |
+
# )
|
| 361 |
+
# preprocess_val = image_transform_region(
|
| 362 |
+
# model.visual.image_size,
|
| 363 |
+
# is_train=False,
|
| 364 |
+
# mean=image_mean,
|
| 365 |
+
# std=image_std
|
| 366 |
+
# )
|
| 367 |
+
|
| 368 |
+
if light_augmentation:
|
| 369 |
+
preprocess_val = image_transform(
|
| 370 |
+
model.visual.image_size,
|
| 371 |
+
is_train=False,
|
| 372 |
+
mean=image_mean,
|
| 373 |
+
std=image_std,
|
| 374 |
+
resize_longest_max=True,
|
| 375 |
+
)
|
| 376 |
+
preprocess_train = preprocess_val
|
| 377 |
+
else:
|
| 378 |
+
preprocess_train = image_transform(
|
| 379 |
+
model.visual.image_size,
|
| 380 |
+
is_train=True,
|
| 381 |
+
mean=image_mean,
|
| 382 |
+
std=image_std
|
| 383 |
+
)
|
| 384 |
+
preprocess_val = image_transform(
|
| 385 |
+
model.visual.image_size,
|
| 386 |
+
is_train=False,
|
| 387 |
+
mean=image_mean,
|
| 388 |
+
std=image_std
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
return model, preprocess_train, preprocess_val
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def create_model_from_pretrained(
|
| 395 |
+
model_name: str,
|
| 396 |
+
pretrained: Optional[str] = None,
|
| 397 |
+
precision: str = 'fp32',
|
| 398 |
+
device: Union[str, torch.device] = 'cpu',
|
| 399 |
+
jit: bool = False,
|
| 400 |
+
force_quick_gelu: bool = False,
|
| 401 |
+
force_custom_text: bool = False,
|
| 402 |
+
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
| 403 |
+
return_transform: bool = True,
|
| 404 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
| 405 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
| 406 |
+
cache_dir: Optional[str] = None,
|
| 407 |
+
):
|
| 408 |
+
model = create_model(
|
| 409 |
+
model_name,
|
| 410 |
+
pretrained,
|
| 411 |
+
precision=precision,
|
| 412 |
+
device=device,
|
| 413 |
+
jit=jit,
|
| 414 |
+
force_quick_gelu=force_quick_gelu,
|
| 415 |
+
force_custom_text=force_custom_text,
|
| 416 |
+
force_image_size=force_image_size,
|
| 417 |
+
cache_dir=cache_dir,
|
| 418 |
+
require_pretrained=True,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
if not return_transform:
|
| 422 |
+
return model
|
| 423 |
+
|
| 424 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
| 425 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
| 426 |
+
preprocess = image_transform(
|
| 427 |
+
model.visual.image_size,
|
| 428 |
+
is_train=False,
|
| 429 |
+
mean=image_mean,
|
| 430 |
+
std=image_std,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
return model, preprocess
|
diffsynth/extensions/ImageQualityMetric/open_clip/generation_utils.py
ADDED
|
File without changes
|
diffsynth/extensions/ImageQualityMetric/open_clip/hf_configs.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HF architecture dict:
|
| 2 |
+
arch_dict = {
|
| 3 |
+
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
| 4 |
+
"roberta": {
|
| 5 |
+
"config_names": {
|
| 6 |
+
"context_length": "max_position_embeddings",
|
| 7 |
+
"vocab_size": "vocab_size",
|
| 8 |
+
"width": "hidden_size",
|
| 9 |
+
"heads": "num_attention_heads",
|
| 10 |
+
"layers": "num_hidden_layers",
|
| 11 |
+
"layer_attr": "layer",
|
| 12 |
+
"token_embeddings_attr": "embeddings"
|
| 13 |
+
},
|
| 14 |
+
"pooler": "mean_pooler",
|
| 15 |
+
},
|
| 16 |
+
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
| 17 |
+
"xlm-roberta": {
|
| 18 |
+
"config_names": {
|
| 19 |
+
"context_length": "max_position_embeddings",
|
| 20 |
+
"vocab_size": "vocab_size",
|
| 21 |
+
"width": "hidden_size",
|
| 22 |
+
"heads": "num_attention_heads",
|
| 23 |
+
"layers": "num_hidden_layers",
|
| 24 |
+
"layer_attr": "layer",
|
| 25 |
+
"token_embeddings_attr": "embeddings"
|
| 26 |
+
},
|
| 27 |
+
"pooler": "mean_pooler",
|
| 28 |
+
},
|
| 29 |
+
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
| 30 |
+
"mt5": {
|
| 31 |
+
"config_names": {
|
| 32 |
+
# unlimited seqlen
|
| 33 |
+
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
| 34 |
+
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
| 35 |
+
"context_length": "",
|
| 36 |
+
"vocab_size": "vocab_size",
|
| 37 |
+
"width": "d_model",
|
| 38 |
+
"heads": "num_heads",
|
| 39 |
+
"layers": "num_layers",
|
| 40 |
+
"layer_attr": "block",
|
| 41 |
+
"token_embeddings_attr": "embed_tokens"
|
| 42 |
+
},
|
| 43 |
+
"pooler": "mean_pooler",
|
| 44 |
+
},
|
| 45 |
+
}
|
diffsynth/extensions/ImageQualityMetric/open_clip/hf_model.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" huggingface model adapter
|
| 2 |
+
|
| 3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch import TensorType
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import transformers
|
| 14 |
+
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
|
| 15 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
| 16 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
| 17 |
+
except ImportError as e:
|
| 18 |
+
transformers = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class BaseModelOutput:
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PretrainedConfig:
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
from .hf_configs import arch_dict
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# utils
|
| 32 |
+
def _camel2snake(s):
|
| 33 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# TODO: ?last - for gpt-like models
|
| 37 |
+
_POOLERS = {}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def register_pooler(cls):
|
| 41 |
+
"""Decorator registering pooler class"""
|
| 42 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
| 43 |
+
return cls
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@register_pooler
|
| 47 |
+
class MeanPooler(nn.Module):
|
| 48 |
+
"""Mean pooling"""
|
| 49 |
+
|
| 50 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
| 51 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
| 52 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@register_pooler
|
| 56 |
+
class MaxPooler(nn.Module):
|
| 57 |
+
"""Max pooling"""
|
| 58 |
+
|
| 59 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
| 60 |
+
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
| 61 |
+
return masked_output.max(1).values
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@register_pooler
|
| 65 |
+
class ClsPooler(nn.Module):
|
| 66 |
+
"""CLS token pooling"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, use_pooler_output=True):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.cls_token_position = 0
|
| 71 |
+
self.use_pooler_output = use_pooler_output
|
| 72 |
+
|
| 73 |
+
def forward(self, x: BaseModelOutput, attention_mask: TensorType):
|
| 74 |
+
if (self.use_pooler_output and
|
| 75 |
+
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
| 76 |
+
(x.pooler_output is not None)
|
| 77 |
+
):
|
| 78 |
+
return x.pooler_output
|
| 79 |
+
|
| 80 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class HFTextEncoder(nn.Module):
|
| 84 |
+
"""HuggingFace model adapter"""
|
| 85 |
+
output_tokens: torch.jit.Final[bool]
|
| 86 |
+
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
model_name_or_path: str,
|
| 90 |
+
output_dim: int,
|
| 91 |
+
config: PretrainedConfig = None,
|
| 92 |
+
pooler_type: str = None,
|
| 93 |
+
proj: str = None,
|
| 94 |
+
pretrained: bool = True,
|
| 95 |
+
output_tokens: bool = False,
|
| 96 |
+
):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.output_tokens = output_tokens
|
| 99 |
+
self.output_dim = output_dim
|
| 100 |
+
|
| 101 |
+
# TODO: find better way to get this information
|
| 102 |
+
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
| 103 |
+
|
| 104 |
+
if transformers is None:
|
| 105 |
+
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
| 106 |
+
if config is None:
|
| 107 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
| 108 |
+
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
| 109 |
+
AutoModel.from_config, self.config)
|
| 110 |
+
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
| 111 |
+
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
| 112 |
+
self.transformer = create_func(model_args)
|
| 113 |
+
self.transformer = self.transformer.encoder
|
| 114 |
+
else:
|
| 115 |
+
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
| 116 |
+
else:
|
| 117 |
+
self.config = config
|
| 118 |
+
self.transformer = AutoModel.from_config(config)
|
| 119 |
+
if pooler_type is None: # get default arch pooler
|
| 120 |
+
pooler_type = (arch_dict[self.config.model_type]["pooler"])
|
| 121 |
+
|
| 122 |
+
self.pooler = _POOLERS[pooler_type]()
|
| 123 |
+
|
| 124 |
+
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
| 125 |
+
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
| 126 |
+
self.proj = nn.Identity()
|
| 127 |
+
elif proj == 'linear':
|
| 128 |
+
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
| 129 |
+
elif proj == 'mlp':
|
| 130 |
+
hidden_size = (d_model + output_dim) // 2
|
| 131 |
+
self.proj = nn.Sequential(
|
| 132 |
+
nn.Linear(d_model, hidden_size, bias=False),
|
| 133 |
+
nn.GELU(),
|
| 134 |
+
nn.Linear(hidden_size, output_dim, bias=False),
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
def forward(self, x: TensorType):
|
| 138 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
| 139 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
| 140 |
+
pooled_out = self.pooler(out, attn_mask)
|
| 141 |
+
projected = self.proj(pooled_out)
|
| 142 |
+
|
| 143 |
+
seq_len = out.last_hidden_state.shape[1]
|
| 144 |
+
tokens = (
|
| 145 |
+
out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
|
| 146 |
+
if type(self.pooler) == ClsPooler
|
| 147 |
+
else out.last_hidden_state
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
if self.output_tokens:
|
| 151 |
+
return projected, tokens
|
| 152 |
+
return projected
|
| 153 |
+
|
| 154 |
+
def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
| 155 |
+
if not unlocked_layers: # full freezing
|
| 156 |
+
for n, p in self.transformer.named_parameters():
|
| 157 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
| 158 |
+
return
|
| 159 |
+
|
| 160 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
| 161 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
| 162 |
+
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
| 163 |
+
embeddings = getattr(
|
| 164 |
+
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
| 165 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
| 166 |
+
# freeze layers
|
| 167 |
+
for module in modules:
|
| 168 |
+
for n, p in module.named_parameters():
|
| 169 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
| 170 |
+
|
| 171 |
+
@torch.jit.ignore
|
| 172 |
+
def set_grad_checkpointing(self, enable=True):
|
| 173 |
+
self.transformer.gradient_checkpointing_enable()
|
| 174 |
+
|
| 175 |
+
def init_parameters(self):
|
| 176 |
+
pass
|
diffsynth/extensions/ImageQualityMetric/open_clip/loss.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import torch.distributed.nn
|
| 8 |
+
from torch import distributed as dist
|
| 9 |
+
|
| 10 |
+
has_distributed = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
has_distributed = False
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import horovod.torch as hvd
|
| 16 |
+
except ImportError:
|
| 17 |
+
hvd = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def gather_features(
|
| 21 |
+
image_features,
|
| 22 |
+
text_features,
|
| 23 |
+
local_loss=False,
|
| 24 |
+
gather_with_grad=False,
|
| 25 |
+
rank=0,
|
| 26 |
+
world_size=1,
|
| 27 |
+
use_horovod=False
|
| 28 |
+
):
|
| 29 |
+
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
| 30 |
+
if use_horovod:
|
| 31 |
+
assert hvd is not None, 'Please install horovod'
|
| 32 |
+
if gather_with_grad:
|
| 33 |
+
all_image_features = hvd.allgather(image_features)
|
| 34 |
+
all_text_features = hvd.allgather(text_features)
|
| 35 |
+
else:
|
| 36 |
+
with torch.no_grad():
|
| 37 |
+
all_image_features = hvd.allgather(image_features)
|
| 38 |
+
all_text_features = hvd.allgather(text_features)
|
| 39 |
+
if not local_loss:
|
| 40 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
| 41 |
+
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
| 42 |
+
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
| 43 |
+
gathered_image_features[rank] = image_features
|
| 44 |
+
gathered_text_features[rank] = text_features
|
| 45 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
| 46 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
| 47 |
+
else:
|
| 48 |
+
# We gather tensors from all gpus
|
| 49 |
+
if gather_with_grad:
|
| 50 |
+
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
| 51 |
+
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
| 52 |
+
else:
|
| 53 |
+
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
| 54 |
+
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
| 55 |
+
dist.all_gather(gathered_image_features, image_features)
|
| 56 |
+
dist.all_gather(gathered_text_features, text_features)
|
| 57 |
+
if not local_loss:
|
| 58 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
| 59 |
+
gathered_image_features[rank] = image_features
|
| 60 |
+
gathered_text_features[rank] = text_features
|
| 61 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
| 62 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
| 63 |
+
|
| 64 |
+
return all_image_features, all_text_features
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ClipLoss(nn.Module):
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
local_loss=False,
|
| 72 |
+
gather_with_grad=False,
|
| 73 |
+
cache_labels=False,
|
| 74 |
+
rank=0,
|
| 75 |
+
world_size=1,
|
| 76 |
+
use_horovod=False,
|
| 77 |
+
):
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.local_loss = local_loss
|
| 80 |
+
self.gather_with_grad = gather_with_grad
|
| 81 |
+
self.cache_labels = cache_labels
|
| 82 |
+
self.rank = rank
|
| 83 |
+
self.world_size = world_size
|
| 84 |
+
self.use_horovod = use_horovod
|
| 85 |
+
|
| 86 |
+
# cache state
|
| 87 |
+
self.prev_num_logits = 0
|
| 88 |
+
self.labels = {}
|
| 89 |
+
|
| 90 |
+
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
|
| 91 |
+
# calculated ground-truth and cache if enabled
|
| 92 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
| 93 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
| 94 |
+
if self.world_size > 1 and self.local_loss:
|
| 95 |
+
labels = labels + num_logits * self.rank
|
| 96 |
+
if self.cache_labels:
|
| 97 |
+
self.labels[device] = labels
|
| 98 |
+
self.prev_num_logits = num_logits
|
| 99 |
+
else:
|
| 100 |
+
labels = self.labels[device]
|
| 101 |
+
return labels
|
| 102 |
+
|
| 103 |
+
def get_logits(self, image_features, text_features, logit_scale):
|
| 104 |
+
if self.world_size > 1:
|
| 105 |
+
all_image_features, all_text_features = gather_features(
|
| 106 |
+
image_features, text_features,
|
| 107 |
+
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
| 108 |
+
|
| 109 |
+
if self.local_loss:
|
| 110 |
+
logits_per_image = logit_scale * image_features @ all_text_features.T
|
| 111 |
+
logits_per_text = logit_scale * text_features @ all_image_features.T
|
| 112 |
+
else:
|
| 113 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
| 114 |
+
logits_per_text = logits_per_image.T
|
| 115 |
+
else:
|
| 116 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
| 117 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
| 118 |
+
|
| 119 |
+
return logits_per_image, logits_per_text
|
| 120 |
+
|
| 121 |
+
def forward(self, image_features, text_features, logit_scale, output_dict=False):
|
| 122 |
+
device = image_features.device
|
| 123 |
+
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
|
| 124 |
+
|
| 125 |
+
labels = self.get_ground_truth(device, logits_per_image.shape[0])
|
| 126 |
+
|
| 127 |
+
total_loss = (
|
| 128 |
+
F.cross_entropy(logits_per_image, labels) +
|
| 129 |
+
F.cross_entropy(logits_per_text, labels)
|
| 130 |
+
) / 2
|
| 131 |
+
return total_loss
|
| 132 |
+
|
| 133 |
+
class PreferenceLoss(nn.Module):
|
| 134 |
+
|
| 135 |
+
def forward(self, logits_per_image, num_images, labels):
|
| 136 |
+
|
| 137 |
+
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
| 138 |
+
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-999)
|
| 139 |
+
|
| 140 |
+
ce_loss = F.cross_entropy(paired_logits, labels)
|
| 141 |
+
return ce_loss
|
| 142 |
+
|
| 143 |
+
class HPSLoss(nn.Module):
|
| 144 |
+
|
| 145 |
+
def forward(self, text_logits, labels):
|
| 146 |
+
|
| 147 |
+
device = text_logits.device
|
| 148 |
+
text_0_logits, text_1_logits = text_logits.chunk(2, dim=-1)
|
| 149 |
+
label_0, label_1 = labels.chunk(2, dim=-1)
|
| 150 |
+
|
| 151 |
+
index = torch.arange(text_0_logits.shape[0], device=device, dtype=torch.long)
|
| 152 |
+
text_0_logits = text_0_logits[index, index]
|
| 153 |
+
text_1_logits = text_1_logits[index, index]
|
| 154 |
+
text_logits = torch.stack([text_0_logits, text_1_logits], dim=-1)
|
| 155 |
+
text_0_labels = torch.zeros(text_logits.shape[0], device=device, dtype=torch.long)
|
| 156 |
+
text_1_labels = text_0_labels + 1
|
| 157 |
+
|
| 158 |
+
text_0_loss = torch.nn.functional.cross_entropy(text_logits, text_0_labels, reduction="none")
|
| 159 |
+
text_1_loss = torch.nn.functional.cross_entropy(text_logits, text_1_labels, reduction="none")
|
| 160 |
+
|
| 161 |
+
text_loss = label_0 * text_0_loss + label_1 * text_1_loss
|
| 162 |
+
|
| 163 |
+
# absolute_example_weight = 1 / num_per_prompt
|
| 164 |
+
# denominator = absolute_example_weight.sum()
|
| 165 |
+
# weight_per_example = absolute_example_weight / denominator
|
| 166 |
+
# text_loss *= weight_per_example
|
| 167 |
+
|
| 168 |
+
text_loss = text_loss.sum()
|
| 169 |
+
return text_loss
|
| 170 |
+
|
| 171 |
+
class RankingLoss(nn.Module):
|
| 172 |
+
|
| 173 |
+
def forward(self, logits_per_image, num_images, labels, margin = 1.0):
|
| 174 |
+
paired_logits_list = [logit[:,i] for i, logit in enumerate(logits_per_image.split(num_images.tolist()))]
|
| 175 |
+
label_list = [label for label in labels.split(num_images.tolist())]
|
| 176 |
+
# ranked_logits = [torch.index_select(paired_logits_list[i], 0, rank) for i, rank in enumerate(label_list)]
|
| 177 |
+
|
| 178 |
+
paired_logits = pad_sequence(paired_logits_list, batch_first=True, padding_value=-1)
|
| 179 |
+
padded_labels = pad_sequence(label_list, batch_first=True, padding_value=10)
|
| 180 |
+
|
| 181 |
+
# regulized_logits = torch.log(torch.sigmoid(paired_logits))
|
| 182 |
+
|
| 183 |
+
diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
| 184 |
+
# diff = paired_logits.unsqueeze(1) - paired_logits.unsqueeze(2)
|
| 185 |
+
# diff_label = torch.clamp(padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2), min=-1, max=1)
|
| 186 |
+
diff_label = - (padded_labels.unsqueeze(1) - padded_labels.unsqueeze(2))
|
| 187 |
+
mask = torch.triu(torch.ones(diff.shape[1], diff.shape[1]), diagonal=1).bool().detach()
|
| 188 |
+
|
| 189 |
+
loss = torch.clamp(margin - torch.mul(diff[:, ~mask],diff_label[:,~mask]), min=0).mean()
|
| 190 |
+
return loss
|
| 191 |
+
|
| 192 |
+
class CoCaLoss(ClipLoss):
|
| 193 |
+
def __init__(
|
| 194 |
+
self,
|
| 195 |
+
caption_loss_weight,
|
| 196 |
+
clip_loss_weight,
|
| 197 |
+
pad_id=0, # pad_token for open_clip custom tokenizer
|
| 198 |
+
local_loss=False,
|
| 199 |
+
gather_with_grad=False,
|
| 200 |
+
cache_labels=False,
|
| 201 |
+
rank=0,
|
| 202 |
+
world_size=1,
|
| 203 |
+
use_horovod=False,
|
| 204 |
+
):
|
| 205 |
+
super().__init__(
|
| 206 |
+
local_loss=local_loss,
|
| 207 |
+
gather_with_grad=gather_with_grad,
|
| 208 |
+
cache_labels=cache_labels,
|
| 209 |
+
rank=rank,
|
| 210 |
+
world_size=world_size,
|
| 211 |
+
use_horovod=use_horovod
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
self.clip_loss_weight = clip_loss_weight
|
| 215 |
+
self.caption_loss_weight = caption_loss_weight
|
| 216 |
+
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
|
| 217 |
+
|
| 218 |
+
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
|
| 219 |
+
clip_loss = super().forward(image_features, text_features, logit_scale)
|
| 220 |
+
clip_loss = self.clip_loss_weight * clip_loss
|
| 221 |
+
|
| 222 |
+
caption_loss = self.caption_loss(
|
| 223 |
+
logits.permute(0, 2, 1),
|
| 224 |
+
labels,
|
| 225 |
+
)
|
| 226 |
+
caption_loss = caption_loss * self.caption_loss_weight
|
| 227 |
+
|
| 228 |
+
if output_dict:
|
| 229 |
+
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
|
| 230 |
+
|
| 231 |
+
return clip_loss, caption_loss
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class DistillClipLoss(ClipLoss):
|
| 235 |
+
|
| 236 |
+
def dist_loss(self, teacher_logits, student_logits):
|
| 237 |
+
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
|
| 238 |
+
|
| 239 |
+
def forward(
|
| 240 |
+
self,
|
| 241 |
+
image_features,
|
| 242 |
+
text_features,
|
| 243 |
+
logit_scale,
|
| 244 |
+
dist_image_features,
|
| 245 |
+
dist_text_features,
|
| 246 |
+
dist_logit_scale,
|
| 247 |
+
output_dict=False,
|
| 248 |
+
):
|
| 249 |
+
logits_per_image, logits_per_text = \
|
| 250 |
+
self.get_logits(image_features, text_features, logit_scale)
|
| 251 |
+
|
| 252 |
+
dist_logits_per_image, dist_logits_per_text = \
|
| 253 |
+
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
|
| 254 |
+
|
| 255 |
+
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
|
| 256 |
+
|
| 257 |
+
contrastive_loss = (
|
| 258 |
+
F.cross_entropy(logits_per_image, labels) +
|
| 259 |
+
F.cross_entropy(logits_per_text, labels)
|
| 260 |
+
) / 2
|
| 261 |
+
|
| 262 |
+
distill_loss = (
|
| 263 |
+
self.dist_loss(dist_logits_per_image, logits_per_image) +
|
| 264 |
+
self.dist_loss(dist_logits_per_text, logits_per_text)
|
| 265 |
+
) / 2
|
| 266 |
+
|
| 267 |
+
if output_dict:
|
| 268 |
+
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
|
| 269 |
+
|
| 270 |
+
return contrastive_loss, distill_loss
|
diffsynth/extensions/ImageQualityMetric/open_clip/model.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" CLIP Model
|
| 2 |
+
|
| 3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
| 4 |
+
"""
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
import logging
|
| 7 |
+
import math
|
| 8 |
+
from typing import Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch import nn
|
| 14 |
+
from torch.utils.checkpoint import checkpoint
|
| 15 |
+
|
| 16 |
+
from .hf_model import HFTextEncoder
|
| 17 |
+
from .modified_resnet import ModifiedResNet
|
| 18 |
+
from .timm_model import TimmModel
|
| 19 |
+
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
| 20 |
+
from .utils import to_2tuple
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class CLIPVisionCfg:
|
| 25 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
| 26 |
+
width: int = 768
|
| 27 |
+
head_width: int = 64
|
| 28 |
+
mlp_ratio: float = 4.0
|
| 29 |
+
patch_size: int = 16
|
| 30 |
+
image_size: Union[Tuple[int, int], int] = 224
|
| 31 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
| 32 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
| 33 |
+
input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
|
| 34 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
| 35 |
+
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
|
| 36 |
+
n_queries: int = 256 # n_queries for attentional pooler
|
| 37 |
+
attn_pooler_heads: int = 8 # n heads for attentional_pooling
|
| 38 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
| 39 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
| 40 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
| 41 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
| 42 |
+
timm_proj_bias: bool = False # enable bias final projection
|
| 43 |
+
timm_drop: float = 0. # head dropout
|
| 44 |
+
timm_drop_path: Optional[float] = None # backbone stochastic depth
|
| 45 |
+
output_tokens: bool = False
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class CLIPTextCfg:
|
| 50 |
+
context_length: int = 77
|
| 51 |
+
vocab_size: int = 49408
|
| 52 |
+
width: int = 512
|
| 53 |
+
heads: int = 8
|
| 54 |
+
layers: int = 12
|
| 55 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
| 56 |
+
hf_model_name: str = None
|
| 57 |
+
hf_tokenizer_name: str = None
|
| 58 |
+
hf_model_pretrained: bool = True
|
| 59 |
+
proj: str = 'mlp'
|
| 60 |
+
pooler_type: str = 'mean_pooler'
|
| 61 |
+
embed_cls: bool = False
|
| 62 |
+
pad_id: int = 0
|
| 63 |
+
output_tokens: bool = False
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_cast_dtype(precision: str):
|
| 67 |
+
cast_dtype = None
|
| 68 |
+
if precision == 'bf16':
|
| 69 |
+
cast_dtype = torch.bfloat16
|
| 70 |
+
elif precision == 'fp16':
|
| 71 |
+
cast_dtype = torch.float16
|
| 72 |
+
return cast_dtype
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _build_vision_tower(
|
| 76 |
+
embed_dim: int,
|
| 77 |
+
vision_cfg: CLIPVisionCfg,
|
| 78 |
+
quick_gelu: bool = False,
|
| 79 |
+
cast_dtype: Optional[torch.dtype] = None
|
| 80 |
+
):
|
| 81 |
+
if isinstance(vision_cfg, dict):
|
| 82 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
| 83 |
+
|
| 84 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
| 85 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
| 86 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
| 87 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
| 88 |
+
|
| 89 |
+
if vision_cfg.timm_model_name:
|
| 90 |
+
visual = TimmModel(
|
| 91 |
+
vision_cfg.timm_model_name,
|
| 92 |
+
pretrained=vision_cfg.timm_model_pretrained,
|
| 93 |
+
pool=vision_cfg.timm_pool,
|
| 94 |
+
proj=vision_cfg.timm_proj,
|
| 95 |
+
proj_bias=vision_cfg.timm_proj_bias,
|
| 96 |
+
drop=vision_cfg.timm_drop,
|
| 97 |
+
drop_path=vision_cfg.timm_drop_path,
|
| 98 |
+
embed_dim=embed_dim,
|
| 99 |
+
image_size=vision_cfg.image_size,
|
| 100 |
+
)
|
| 101 |
+
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
| 102 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
| 103 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
| 104 |
+
visual = ModifiedResNet(
|
| 105 |
+
layers=vision_cfg.layers,
|
| 106 |
+
output_dim=embed_dim,
|
| 107 |
+
heads=vision_heads,
|
| 108 |
+
image_size=vision_cfg.image_size,
|
| 109 |
+
width=vision_cfg.width,
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
| 113 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
| 114 |
+
visual = VisionTransformer(
|
| 115 |
+
image_size=vision_cfg.image_size,
|
| 116 |
+
patch_size=vision_cfg.patch_size,
|
| 117 |
+
width=vision_cfg.width,
|
| 118 |
+
layers=vision_cfg.layers,
|
| 119 |
+
heads=vision_heads,
|
| 120 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
| 121 |
+
ls_init_value=vision_cfg.ls_init_value,
|
| 122 |
+
patch_dropout=vision_cfg.patch_dropout,
|
| 123 |
+
input_patchnorm=vision_cfg.input_patchnorm,
|
| 124 |
+
global_average_pool=vision_cfg.global_average_pool,
|
| 125 |
+
attentional_pool=vision_cfg.attentional_pool,
|
| 126 |
+
n_queries=vision_cfg.n_queries,
|
| 127 |
+
attn_pooler_heads=vision_cfg.attn_pooler_heads,
|
| 128 |
+
output_tokens=vision_cfg.output_tokens,
|
| 129 |
+
output_dim=embed_dim,
|
| 130 |
+
act_layer=act_layer,
|
| 131 |
+
norm_layer=norm_layer,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
return visual
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _build_text_tower(
|
| 138 |
+
embed_dim: int,
|
| 139 |
+
text_cfg: CLIPTextCfg,
|
| 140 |
+
quick_gelu: bool = False,
|
| 141 |
+
cast_dtype: Optional[torch.dtype] = None,
|
| 142 |
+
):
|
| 143 |
+
if isinstance(text_cfg, dict):
|
| 144 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
| 145 |
+
|
| 146 |
+
if text_cfg.hf_model_name:
|
| 147 |
+
text = HFTextEncoder(
|
| 148 |
+
text_cfg.hf_model_name,
|
| 149 |
+
output_dim=embed_dim,
|
| 150 |
+
proj=text_cfg.proj,
|
| 151 |
+
pooler_type=text_cfg.pooler_type,
|
| 152 |
+
pretrained=text_cfg.hf_model_pretrained,
|
| 153 |
+
output_tokens=text_cfg.output_tokens,
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
| 157 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
| 158 |
+
|
| 159 |
+
text = TextTransformer(
|
| 160 |
+
context_length=text_cfg.context_length,
|
| 161 |
+
vocab_size=text_cfg.vocab_size,
|
| 162 |
+
width=text_cfg.width,
|
| 163 |
+
heads=text_cfg.heads,
|
| 164 |
+
layers=text_cfg.layers,
|
| 165 |
+
ls_init_value=text_cfg.ls_init_value,
|
| 166 |
+
output_dim=embed_dim,
|
| 167 |
+
embed_cls=text_cfg.embed_cls,
|
| 168 |
+
output_tokens=text_cfg.output_tokens,
|
| 169 |
+
pad_id=text_cfg.pad_id,
|
| 170 |
+
act_layer=act_layer,
|
| 171 |
+
norm_layer=norm_layer,
|
| 172 |
+
)
|
| 173 |
+
return text
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class CLIP(nn.Module):
|
| 177 |
+
output_dict: torch.jit.Final[bool]
|
| 178 |
+
|
| 179 |
+
def __init__(
|
| 180 |
+
self,
|
| 181 |
+
embed_dim: int,
|
| 182 |
+
vision_cfg: CLIPVisionCfg,
|
| 183 |
+
text_cfg: CLIPTextCfg,
|
| 184 |
+
quick_gelu: bool = False,
|
| 185 |
+
cast_dtype: Optional[torch.dtype] = None,
|
| 186 |
+
output_dict: bool = False,
|
| 187 |
+
):
|
| 188 |
+
super().__init__()
|
| 189 |
+
self.output_dict = output_dict
|
| 190 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
| 191 |
+
|
| 192 |
+
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
| 193 |
+
self.transformer = text.transformer
|
| 194 |
+
self.vocab_size = text.vocab_size
|
| 195 |
+
self.token_embedding = text.token_embedding
|
| 196 |
+
self.positional_embedding = text.positional_embedding
|
| 197 |
+
self.ln_final = text.ln_final
|
| 198 |
+
self.text_projection = text.text_projection
|
| 199 |
+
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
| 200 |
+
|
| 201 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 202 |
+
|
| 203 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
| 204 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
| 205 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
| 206 |
+
|
| 207 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
| 208 |
+
locked_layers = []
|
| 209 |
+
locked_layers.append(self.token_embedding)
|
| 210 |
+
self.positional_embedding.requires_grad = False
|
| 211 |
+
if unlocked_layers > 0:
|
| 212 |
+
locked_layers.append(self.transformer.resblocks[:-unlocked_layers])
|
| 213 |
+
else:
|
| 214 |
+
locked_layers.append(self.transformer)
|
| 215 |
+
locked_layers.append(self.ln_final)
|
| 216 |
+
self.text_projection.requires_grad = False
|
| 217 |
+
|
| 218 |
+
# freeze layers
|
| 219 |
+
for module in locked_layers:
|
| 220 |
+
for n, p in module.named_parameters():
|
| 221 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
| 222 |
+
|
| 223 |
+
@torch.jit.ignore
|
| 224 |
+
def set_grad_checkpointing(self, enable=True):
|
| 225 |
+
self.visual.set_grad_checkpointing(enable)
|
| 226 |
+
self.transformer.grad_checkpointing = enable
|
| 227 |
+
|
| 228 |
+
def encode_image(self, image, normalize: bool = False):
|
| 229 |
+
features = self.visual(image)
|
| 230 |
+
return F.normalize(features, dim=-1) if normalize else features
|
| 231 |
+
|
| 232 |
+
def encode_text(self, text, normalize: bool = False):
|
| 233 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
| 234 |
+
|
| 235 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
| 236 |
+
|
| 237 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
| 238 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 239 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
| 240 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 241 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
| 242 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 243 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 244 |
+
return F.normalize(x, dim=-1) if normalize else x
|
| 245 |
+
|
| 246 |
+
def forward(self, image, text):
|
| 247 |
+
image_features = self.encode_image(image, normalize=True)
|
| 248 |
+
text_features = self.encode_text(text, normalize=True)
|
| 249 |
+
if self.output_dict:
|
| 250 |
+
return {
|
| 251 |
+
"image_features": image_features,
|
| 252 |
+
"text_features": text_features,
|
| 253 |
+
"logit_scale": self.logit_scale.exp()
|
| 254 |
+
}
|
| 255 |
+
return image_features, text_features, self.logit_scale.exp()
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class CustomTextCLIP(nn.Module):
|
| 259 |
+
output_dict: torch.jit.Final[bool]
|
| 260 |
+
|
| 261 |
+
def __init__(
|
| 262 |
+
self,
|
| 263 |
+
embed_dim: int,
|
| 264 |
+
vision_cfg: CLIPVisionCfg,
|
| 265 |
+
text_cfg: CLIPTextCfg,
|
| 266 |
+
quick_gelu: bool = False,
|
| 267 |
+
cast_dtype: Optional[torch.dtype] = None,
|
| 268 |
+
output_dict: bool = False,
|
| 269 |
+
):
|
| 270 |
+
super().__init__()
|
| 271 |
+
self.output_dict = output_dict
|
| 272 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
| 273 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
| 274 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 275 |
+
|
| 276 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
| 277 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
| 278 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
| 279 |
+
|
| 280 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
|
| 281 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
| 282 |
+
|
| 283 |
+
@torch.jit.ignore
|
| 284 |
+
def set_grad_checkpointing(self, enable=True):
|
| 285 |
+
self.visual.set_grad_checkpointing(enable)
|
| 286 |
+
self.text.set_grad_checkpointing(enable)
|
| 287 |
+
|
| 288 |
+
def encode_image(self, image, normalize: bool = False):
|
| 289 |
+
features = self.visual(image)
|
| 290 |
+
return F.normalize(features, dim=-1) if normalize else features
|
| 291 |
+
|
| 292 |
+
def encode_text(self, text, normalize: bool = False):
|
| 293 |
+
features = self.text(text)
|
| 294 |
+
return F.normalize(features, dim=-1) if normalize else features
|
| 295 |
+
|
| 296 |
+
def forward(self, image, text):
|
| 297 |
+
image_features = self.encode_image(image, normalize=True)
|
| 298 |
+
text_features = self.encode_text(text, normalize=True)
|
| 299 |
+
if self.output_dict:
|
| 300 |
+
return {
|
| 301 |
+
"image_features": image_features,
|
| 302 |
+
"text_features": text_features,
|
| 303 |
+
"logit_scale": self.logit_scale.exp()
|
| 304 |
+
}
|
| 305 |
+
return image_features, text_features, self.logit_scale.exp()
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
| 309 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
| 310 |
+
|
| 311 |
+
def _convert_weights(l):
|
| 312 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
| 313 |
+
l.weight.data = l.weight.data.to(dtype)
|
| 314 |
+
if l.bias is not None:
|
| 315 |
+
l.bias.data = l.bias.data.to(dtype)
|
| 316 |
+
|
| 317 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
| 318 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
| 319 |
+
tensor = getattr(l, attr)
|
| 320 |
+
if tensor is not None:
|
| 321 |
+
tensor.data = tensor.data.to(dtype)
|
| 322 |
+
|
| 323 |
+
for name in ["text_projection", "proj"]:
|
| 324 |
+
if hasattr(l, name):
|
| 325 |
+
attr = getattr(l, name)
|
| 326 |
+
if attr is not None:
|
| 327 |
+
attr.data = attr.data.to(dtype)
|
| 328 |
+
|
| 329 |
+
model.apply(_convert_weights)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
# used to maintain checkpoint compatibility
|
| 336 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
| 337 |
+
if 'text_projection' in state_dict:
|
| 338 |
+
# old format state_dict, move text tower -> .text
|
| 339 |
+
new_state_dict = {}
|
| 340 |
+
for k, v in state_dict.items():
|
| 341 |
+
if any(k.startswith(p) for p in (
|
| 342 |
+
'text_projection',
|
| 343 |
+
'positional_embedding',
|
| 344 |
+
'token_embedding',
|
| 345 |
+
'transformer',
|
| 346 |
+
'ln_final',
|
| 347 |
+
)):
|
| 348 |
+
k = 'text.' + k
|
| 349 |
+
new_state_dict[k] = v
|
| 350 |
+
return new_state_dict
|
| 351 |
+
return state_dict
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def build_model_from_openai_state_dict(
|
| 355 |
+
state_dict: dict,
|
| 356 |
+
quick_gelu=True,
|
| 357 |
+
cast_dtype=torch.float16,
|
| 358 |
+
):
|
| 359 |
+
vit = "visual.proj" in state_dict
|
| 360 |
+
|
| 361 |
+
if vit:
|
| 362 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
| 363 |
+
vision_layers = len(
|
| 364 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
| 365 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
| 366 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 367 |
+
image_size = vision_patch_size * grid_size
|
| 368 |
+
else:
|
| 369 |
+
counts: list = [
|
| 370 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
| 371 |
+
vision_layers = tuple(counts)
|
| 372 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
| 373 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 374 |
+
vision_patch_size = None
|
| 375 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
| 376 |
+
image_size = output_width * 32
|
| 377 |
+
|
| 378 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
| 379 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
| 380 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
| 381 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
| 382 |
+
transformer_heads = transformer_width // 64
|
| 383 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
| 384 |
+
|
| 385 |
+
vision_cfg = CLIPVisionCfg(
|
| 386 |
+
layers=vision_layers,
|
| 387 |
+
width=vision_width,
|
| 388 |
+
patch_size=vision_patch_size,
|
| 389 |
+
image_size=image_size,
|
| 390 |
+
)
|
| 391 |
+
text_cfg = CLIPTextCfg(
|
| 392 |
+
context_length=context_length,
|
| 393 |
+
vocab_size=vocab_size,
|
| 394 |
+
width=transformer_width,
|
| 395 |
+
heads=transformer_heads,
|
| 396 |
+
layers=transformer_layers,
|
| 397 |
+
)
|
| 398 |
+
model = CLIP(
|
| 399 |
+
embed_dim,
|
| 400 |
+
vision_cfg=vision_cfg,
|
| 401 |
+
text_cfg=text_cfg,
|
| 402 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
| 403 |
+
cast_dtype=cast_dtype,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
| 407 |
+
state_dict.pop(key, None)
|
| 408 |
+
|
| 409 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
| 410 |
+
model.load_state_dict(state_dict)
|
| 411 |
+
return model.eval()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
| 415 |
+
model.eval()
|
| 416 |
+
image_size = model.visual.image_size
|
| 417 |
+
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
| 418 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
| 419 |
+
model = torch.jit.trace_module(
|
| 420 |
+
model,
|
| 421 |
+
inputs=dict(
|
| 422 |
+
forward=(example_images, example_text),
|
| 423 |
+
encode_text=(example_text,),
|
| 424 |
+
encode_image=(example_images,)
|
| 425 |
+
))
|
| 426 |
+
model.visual.image_size = image_size
|
| 427 |
+
return model
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
|
| 431 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
| 432 |
+
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
| 433 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
| 434 |
+
return
|
| 435 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
| 436 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
| 437 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
| 438 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
| 439 |
+
return
|
| 440 |
+
|
| 441 |
+
if extra_tokens:
|
| 442 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
| 443 |
+
else:
|
| 444 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
| 445 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
| 446 |
+
|
| 447 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
| 448 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
| 449 |
+
pos_emb_img = F.interpolate(
|
| 450 |
+
pos_emb_img,
|
| 451 |
+
size=grid_size,
|
| 452 |
+
mode=interpolation,
|
| 453 |
+
antialias=antialias,
|
| 454 |
+
align_corners=False,
|
| 455 |
+
)
|
| 456 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
| 457 |
+
if pos_emb_tok is not None:
|
| 458 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
| 459 |
+
else:
|
| 460 |
+
new_pos_embed = pos_emb_img
|
| 461 |
+
state_dict['visual.positional_embedding'] = new_pos_embed
|