vivienfanghua committed
Commit 5e60fc9 · verified · 1 Parent(s): df1d959

Delete wan

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. wan/__init__.py +0 -5
  2. wan/__pycache__/__init__.cpython-310.pyc +0 -0
  3. wan/__pycache__/image2video.cpython-310.pyc +0 -0
  4. wan/__pycache__/text2video.cpython-310.pyc +0 -0
  5. wan/__pycache__/textimage2video.cpython-310.pyc +0 -0
  6. wan/configs/__init__.py +0 -39
  7. wan/configs/__pycache__/__init__.cpython-310.pyc +0 -0
  8. wan/configs/__pycache__/shared_config.cpython-310.pyc +0 -0
  9. wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc +0 -0
  10. wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc +0 -0
  11. wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc +0 -0
  12. wan/configs/shared_config.py +0 -20
  13. wan/configs/wan_i2v_A14B.py +0 -37
  14. wan/configs/wan_t2v_A14B.py +0 -37
  15. wan/configs/wan_ti2v_5B.py +0 -36
  16. wan/distributed/__init__.py +0 -1
  17. wan/distributed/__pycache__/__init__.cpython-310.pyc +0 -0
  18. wan/distributed/__pycache__/fsdp.cpython-310.pyc +0 -0
  19. wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc +0 -0
  20. wan/distributed/__pycache__/ulysses.cpython-310.pyc +0 -0
  21. wan/distributed/__pycache__/util.cpython-310.pyc +0 -0
  22. wan/distributed/fsdp.py +0 -43
  23. wan/distributed/sequence_parallel.py +0 -176
  24. wan/distributed/ulysses.py +0 -47
  25. wan/distributed/util.py +0 -51
  26. wan/image2video.py +0 -431
  27. wan/modules/__init__.py +0 -19
  28. wan/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  29. wan/modules/__pycache__/attention.cpython-310.pyc +0 -0
  30. wan/modules/__pycache__/model.cpython-310.pyc +0 -0
  31. wan/modules/__pycache__/t5.cpython-310.pyc +0 -0
  32. wan/modules/__pycache__/tokenizers.cpython-310.pyc +0 -0
  33. wan/modules/__pycache__/vae2_1.cpython-310.pyc +0 -0
  34. wan/modules/__pycache__/vae2_2.cpython-310.pyc +0 -0
  35. wan/modules/attention.py +0 -179
  36. wan/modules/model.py +0 -546
  37. wan/modules/t5.py +0 -513
  38. wan/modules/tokenizers.py +0 -82
  39. wan/modules/vae2_1.py +0 -663
  40. wan/modules/vae2_2.py +0 -1051
  41. wan/text2video.py +0 -378
  42. wan/textimage2video.py +0 -619
  43. wan/utils/__init__.py +0 -12
  44. wan/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  45. wan/utils/__pycache__/fm_solvers.cpython-310.pyc +0 -0
  46. wan/utils/__pycache__/fm_solvers_unipc.cpython-310.pyc +0 -0
  47. wan/utils/__pycache__/utils.cpython-310.pyc +0 -0
  48. wan/utils/fm_solvers.py +0 -859
  49. wan/utils/fm_solvers_unipc.py +0 -802
  50. wan/utils/prompt_extend.py +0 -542
wan/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- from . import configs, distributed, modules
3
- from .image2video import WanI2V
4
- from .text2video import WanT2V
5
- from .textimage2video import WanTI2V
 
wan/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (333 Bytes)
 
wan/__pycache__/image2video.cpython-310.pyc DELETED
Binary file (12.3 kB)
 
wan/__pycache__/text2video.cpython-310.pyc DELETED
Binary file (11.1 kB)
 
wan/__pycache__/textimage2video.cpython-310.pyc DELETED
Binary file (17.5 kB)
 
wan/configs/__init__.py DELETED
@@ -1,39 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import copy
3
- import os
4
-
5
- os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
-
7
- from .wan_i2v_A14B import i2v_A14B
8
- from .wan_t2v_A14B import t2v_A14B
9
- from .wan_ti2v_5B import ti2v_5B
10
-
11
- WAN_CONFIGS = {
12
- 't2v-A14B': t2v_A14B,
13
- 'i2v-A14B': i2v_A14B,
14
- 'ti2v-5B': ti2v_5B,
15
- }
16
-
17
- SIZE_CONFIGS = {
18
- '720*1280': (720, 1280),
19
- '1280*720': (1280, 720),
20
- '480*832': (480, 832),
21
- '832*480': (832, 480),
22
- '704*1280': (704, 1280),
23
- '1280*704': (1280, 704)
24
- }
25
-
26
- MAX_AREA_CONFIGS = {
27
- '720*1280': 720 * 1280,
28
- '1280*720': 1280 * 720,
29
- '480*832': 480 * 832,
30
- '832*480': 832 * 480,
31
- '704*1280': 704 * 1280,
32
- '1280*704': 1280 * 704,
33
- }
34
-
35
- SUPPORTED_SIZES = {
36
- 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
37
- 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
38
- 'ti2v-5B': ('704*1280', '1280*704'),
39
- }
 
wan/configs/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (737 Bytes)
 
wan/configs/__pycache__/shared_config.cpython-310.pyc DELETED
Binary file (848 Bytes)
 
wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc DELETED
Binary file (968 Bytes)
 
wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc DELETED
Binary file (955 Bytes)
 
wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc DELETED
Binary file (868 Bytes)
 
wan/configs/shared_config.py DELETED
@@ -1,20 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import torch
3
- from easydict import EasyDict
4
-
5
- #------------------------ Wan shared config ------------------------#
6
- wan_shared_cfg = EasyDict()
7
-
8
- # t5
9
- wan_shared_cfg.t5_model = 'umt5_xxl'
10
- wan_shared_cfg.t5_dtype = torch.bfloat16
11
- wan_shared_cfg.text_len = 512
12
-
13
- # transformer
14
- wan_shared_cfg.param_dtype = torch.bfloat16
15
-
16
- # inference
17
- wan_shared_cfg.num_train_timesteps = 1000
18
- wan_shared_cfg.sample_fps = 16
19
- wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
20
- wan_shared_cfg.frame_num = 81
 
wan/configs/wan_i2v_A14B.py DELETED
@@ -1,37 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import torch
3
- from easydict import EasyDict
4
-
5
- from .shared_config import wan_shared_cfg
6
-
7
- #------------------------ Wan I2V A14B ------------------------#
8
-
9
- i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
10
- i2v_A14B.update(wan_shared_cfg)
11
-
12
- i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
- i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
-
15
- # vae
16
- i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
- i2v_A14B.vae_stride = (4, 8, 8)
18
-
19
- # transformer
20
- i2v_A14B.patch_size = (1, 2, 2)
21
- i2v_A14B.dim = 5120
22
- i2v_A14B.ffn_dim = 13824
23
- i2v_A14B.freq_dim = 256
24
- i2v_A14B.num_heads = 40
25
- i2v_A14B.num_layers = 40
26
- i2v_A14B.window_size = (-1, -1)
27
- i2v_A14B.qk_norm = True
28
- i2v_A14B.cross_attn_norm = True
29
- i2v_A14B.eps = 1e-6
30
- i2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
- i2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
-
33
- # inference
34
- i2v_A14B.sample_shift = 5.0
35
- i2v_A14B.sample_steps = 40
36
- i2v_A14B.boundary = 0.900
37
- i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise
 
wan/configs/wan_t2v_A14B.py DELETED
@@ -1,37 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- from easydict import EasyDict
3
-
4
- from .shared_config import wan_shared_cfg
5
-
6
- #------------------------ Wan T2V A14B ------------------------#
7
-
8
- t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
9
- t2v_A14B.update(wan_shared_cfg)
10
-
11
- # t5
12
- t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
- t2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
-
15
- # vae
16
- t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
- t2v_A14B.vae_stride = (4, 8, 8)
18
-
19
- # transformer
20
- t2v_A14B.patch_size = (1, 2, 2)
21
- t2v_A14B.dim = 5120
22
- t2v_A14B.ffn_dim = 13824
23
- t2v_A14B.freq_dim = 256
24
- t2v_A14B.num_heads = 40
25
- t2v_A14B.num_layers = 40
26
- t2v_A14B.window_size = (-1, -1)
27
- t2v_A14B.qk_norm = True
28
- t2v_A14B.cross_attn_norm = True
29
- t2v_A14B.eps = 1e-6
30
- t2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
- t2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
-
33
- # inference
34
- t2v_A14B.sample_shift = 12.0
35
- t2v_A14B.sample_steps = 40
36
- t2v_A14B.boundary = 0.875
37
- t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise
 
wan/configs/wan_ti2v_5B.py DELETED
@@ -1,36 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- from easydict import EasyDict
3
-
4
- from .shared_config import wan_shared_cfg
5
-
6
- #------------------------ Wan TI2V 5B ------------------------#
7
-
8
- ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
9
- ti2v_5B.update(wan_shared_cfg)
10
-
11
- # t5
12
- ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
- ti2v_5B.t5_tokenizer = 'google/umt5-xxl'
14
-
15
- # vae
16
- ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
17
- ti2v_5B.vae_stride = (4, 16, 16)
18
-
19
- # transformer
20
- ti2v_5B.patch_size = (1, 2, 2)
21
- ti2v_5B.dim = 3072
22
- ti2v_5B.ffn_dim = 14336
23
- ti2v_5B.freq_dim = 256
24
- ti2v_5B.num_heads = 24
25
- ti2v_5B.num_layers = 30
26
- ti2v_5B.window_size = (-1, -1)
27
- ti2v_5B.qk_norm = True
28
- ti2v_5B.cross_attn_norm = True
29
- ti2v_5B.eps = 1e-6
30
-
31
- # inference
32
- ti2v_5B.sample_fps = 12
33
- ti2v_5B.sample_shift = 5.0
34
- ti2v_5B.sample_steps = 50
35
- ti2v_5B.sample_guide_scale = 5.0
36
- ti2v_5B.frame_num = 121
 
wan/distributed/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 
wan/distributed/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (143 Bytes)
 
wan/distributed/__pycache__/fsdp.cpython-310.pyc DELETED
Binary file (1.36 kB)
 
wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc DELETED
Binary file (5.24 kB)
 
wan/distributed/__pycache__/ulysses.cpython-310.pyc DELETED
Binary file (1.23 kB)
 
wan/distributed/__pycache__/util.cpython-310.pyc DELETED
Binary file (1.93 kB)
 
wan/distributed/fsdp.py DELETED
@@ -1,43 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import gc
3
- from functools import partial
4
-
5
- import torch
6
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
- from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
- from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
- from torch.distributed.utils import _free_storage
10
-
11
-
12
- def shard_model(
13
- model,
14
- device_id,
15
- param_dtype=torch.bfloat16,
16
- reduce_dtype=torch.float32,
17
- buffer_dtype=torch.float32,
18
- process_group=None,
19
- sharding_strategy=ShardingStrategy.FULL_SHARD,
20
- sync_module_states=True,
21
- ):
22
- model = FSDP(
23
- module=model,
24
- process_group=process_group,
25
- sharding_strategy=sharding_strategy,
26
- auto_wrap_policy=partial(
27
- lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
28
- mixed_precision=MixedPrecision(
29
- param_dtype=param_dtype,
30
- reduce_dtype=reduce_dtype,
31
- buffer_dtype=buffer_dtype),
32
- device_id=device_id,
33
- sync_module_states=sync_module_states)
34
- return model
35
-
36
-
37
- def free_model(model):
38
- for m in model.modules():
39
- if isinstance(m, FSDP):
40
- _free_storage(m._handle.flat_param.data)
41
- del model
42
- gc.collect()
43
- torch.cuda.empty_cache()
 
wan/distributed/sequence_parallel.py DELETED
@@ -1,176 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import torch
3
- import torch.cuda.amp as amp
4
-
5
- from ..modules.model import sinusoidal_embedding_1d
6
- from .ulysses import distributed_attention
7
- from .util import gather_forward, get_rank, get_world_size
8
-
9
-
10
- def pad_freqs(original_tensor, target_len):
11
- seq_len, s1, s2 = original_tensor.shape
12
- pad_size = target_len - seq_len
13
- padding_tensor = torch.ones(
14
- pad_size,
15
- s1,
16
- s2,
17
- dtype=original_tensor.dtype,
18
- device=original_tensor.device)
19
- padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
- return padded_tensor
21
-
22
-
23
- @torch.amp.autocast('cuda', enabled=False)
24
- def rope_apply(x, grid_sizes, freqs):
25
- """
26
- x: [B, L, N, C].
27
- grid_sizes: [B, 3].
28
- freqs: [M, C // 2].
29
- """
30
- s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
- # split freqs
32
- freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
-
34
- # loop over samples
35
- output = []
36
- for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
- seq_len = f * h * w
38
-
39
- # precompute multipliers
40
- x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
41
- s, n, -1, 2))
42
- freqs_i = torch.cat([
43
- freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
- freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
- freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
- ],
47
- dim=-1).reshape(seq_len, 1, -1)
48
-
49
- # apply rotary embedding
50
- sp_size = get_world_size()
51
- sp_rank = get_rank()
52
- freqs_i = pad_freqs(freqs_i, s * sp_size)
53
- s_per_rank = s
54
- freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
- s_per_rank), :, :]
56
- x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
- x_i = torch.cat([x_i, x[i, s:]])
58
-
59
- # append to collection
60
- output.append(x_i)
61
- return torch.stack(output).float()
62
-
63
-
64
- def sp_dit_forward(
65
- self,
66
- x,
67
- t,
68
- context,
69
- seq_len,
70
- y=None,
71
- ):
72
- """
73
- x: A list of videos each with shape [C, T, H, W].
74
- t: [B].
75
- context: A list of text embeddings each with shape [L, C].
76
- """
77
- if self.model_type == 'i2v':
78
- assert y is not None
79
- # params
80
- device = self.patch_embedding.weight.device
81
- if self.freqs.device != device:
82
- self.freqs = self.freqs.to(device)
83
-
84
- if y is not None:
85
- x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
86
-
87
- # embeddings
88
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
89
- grid_sizes = torch.stack(
90
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
91
- x = [u.flatten(2).transpose(1, 2) for u in x]
92
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
93
- assert seq_lens.max() <= seq_len
94
- x = torch.cat([
95
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
96
- for u in x
97
- ])
98
-
99
- # time embeddings
100
- if t.dim() == 1:
101
- t = t.expand(t.size(0), seq_len)
102
- with torch.amp.autocast('cuda', dtype=torch.float32):
103
- bt = t.size(0)
104
- t = t.flatten()
105
- e = self.time_embedding(
106
- sinusoidal_embedding_1d(self.freq_dim,
107
- t).unflatten(0, (bt, seq_len)).float())
108
- e0 = self.time_projection(e).unflatten(2, (6, self.dim))
109
- assert e.dtype == torch.float32 and e0.dtype == torch.float32
110
-
111
- # context
112
- context_lens = None
113
- context = self.text_embedding(
114
- torch.stack([
115
- torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
116
- for u in context
117
- ]))
118
-
119
- # Context Parallel
120
- x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
121
- e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
122
- e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
123
-
124
- # arguments
125
- kwargs = dict(
126
- e=e0,
127
- seq_lens=seq_lens,
128
- grid_sizes=grid_sizes,
129
- freqs=self.freqs,
130
- context=context,
131
- context_lens=context_lens)
132
-
133
- for block in self.blocks:
134
- x = block(x, **kwargs)
135
-
136
- # head
137
- x = self.head(x, e)
138
-
139
- # Context Parallel
140
- x = gather_forward(x, dim=1)
141
-
142
- # unpatchify
143
- x = self.unpatchify(x, grid_sizes)
144
- return [u.float() for u in x]
145
-
146
-
147
- def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
148
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
- half_dtypes = (torch.float16, torch.bfloat16)
150
-
151
- def half(x):
152
- return x if x.dtype in half_dtypes else x.to(dtype)
153
-
154
- # query, key, value function
155
- def qkv_fn(x):
156
- q = self.norm_q(self.q(x)).view(b, s, n, d)
157
- k = self.norm_k(self.k(x)).view(b, s, n, d)
158
- v = self.v(x).view(b, s, n, d)
159
- return q, k, v
160
-
161
- q, k, v = qkv_fn(x)
162
- q = rope_apply(q, grid_sizes, freqs)
163
- k = rope_apply(k, grid_sizes, freqs)
164
-
165
- x = distributed_attention(
166
- half(q),
167
- half(k),
168
- half(v),
169
- seq_lens,
170
- window_size=self.window_size,
171
- )
172
-
173
- # output
174
- x = x.flatten(2)
175
- x = self.o(x)
176
- return x
 
wan/distributed/ulysses.py DELETED
@@ -1,47 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import torch
3
- import torch.distributed as dist
4
-
5
- from ..modules.attention import flash_attention
6
- from .util import all_to_all
7
-
8
-
9
- def distributed_attention(
10
- q,
11
- k,
12
- v,
13
- seq_lens,
14
- window_size=(-1, -1),
15
- ):
16
- """
17
- Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
18
- please refer to https://arxiv.org/pdf/2309.14509
19
-
20
- Args:
21
- q: [B, Lq // p, Nq, C1].
22
- k: [B, Lk // p, Nk, C1].
23
- v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
24
- seq_lens: [B], length of each sequence in batch
25
- window_size: (left right). If not (-1, -1), apply sliding window local attention.
26
- """
27
- if not dist.is_initialized():
28
- raise ValueError("distributed group should be initialized.")
29
- b = q.shape[0]
30
-
31
- # gather q/k/v sequence
32
- q = all_to_all(q, scatter_dim=2, gather_dim=1)
33
- k = all_to_all(k, scatter_dim=2, gather_dim=1)
34
- v = all_to_all(v, scatter_dim=2, gather_dim=1)
35
-
36
- # apply attention
37
- x = flash_attention(
38
- q,
39
- k,
40
- v,
41
- k_lens=seq_lens,
42
- window_size=window_size,
43
- )
44
-
45
- # scatter q/k/v sequence
46
- x = all_to_all(x, scatter_dim=1, gather_dim=2)
47
- return x
 
wan/distributed/util.py DELETED
@@ -1,51 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import torch
3
- import torch.distributed as dist
4
-
5
-
6
- def init_distributed_group():
7
- """r initialize sequence parallel group.
8
- """
9
- if not dist.is_initialized():
10
- dist.init_process_group(backend='nccl')
11
-
12
-
13
- def get_rank():
14
- return dist.get_rank()
15
-
16
-
17
- def get_world_size():
18
- return dist.get_world_size()
19
-
20
-
21
- def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
22
- """
23
- `scatter` along one dimension and `gather` along another.
24
- """
25
- world_size = get_world_size()
26
- if world_size > 1:
27
- inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
28
- outputs = [torch.empty_like(u) for u in inputs]
29
- dist.all_to_all(outputs, inputs, group=group, **kwargs)
30
- x = torch.cat(outputs, dim=gather_dim).contiguous()
31
- return x
32
-
33
-
34
- def all_gather(tensor):
35
- world_size = dist.get_world_size()
36
- if world_size == 1:
37
- return [tensor]
38
- tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
39
- torch.distributed.all_gather(tensor_list, tensor)
40
- return tensor_list
41
-
42
-
43
- def gather_forward(input, dim):
44
- # skip if world_size == 1
45
- world_size = dist.get_world_size()
46
- if world_size == 1:
47
- return input
48
-
49
- # gather sequence
50
- output = all_gather(input)
51
- return torch.cat(output, dim=dim).contiguous()
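
The scatter/gather shape bookkeeping done by all_to_all above is the core of the Ulysses attention in ulysses.py: each rank trades a slice of the head dimension for the full sequence. A single-process illustration (not part of the commit) that simulates the exchange with chunk/cat under a hypothetical world size:

    import torch

    world_size = 4                       # hypothetical sequence-parallel group size
    B, L, N, C = 2, 32, 8, 64            # full sequence length L, N attention heads
    # every rank starts with a sequence shard of shape [B, L // p, N, C]
    shards = list(torch.randn(B, L, N, C).chunk(world_size, dim=1))

    def simulated_all_to_all(shards, scatter_dim, gather_dim):
        # rank r sends chunk j of scatter_dim to rank j, receives chunk r from every
        # rank, and concatenates what it received along gather_dim
        pieces = [s.chunk(world_size, dim=scatter_dim) for s in shards]
        return [torch.cat([pieces[src][dst] for src in range(world_size)], dim=gather_dim)
                for dst in range(world_size)]

    out = simulated_all_to_all(shards, scatter_dim=2, gather_dim=1)
    print(shards[0].shape, '->', out[0].shape)   # [2, 8, 8, 64] -> [2, 32, 2, 64]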
 
wan/image2video.py DELETED
@@ -1,431 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import gc
3
- import logging
4
- import math
5
- import os
6
- import random
7
- import sys
8
- import types
9
- from contextlib import contextmanager
10
- from functools import partial
11
-
12
- import numpy as np
13
- import torch
14
- import torch.cuda.amp as amp
15
- import torch.distributed as dist
16
- import torchvision.transforms.functional as TF
17
- from tqdm import tqdm
18
-
19
- from .distributed.fsdp import shard_model
20
- from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
- from .distributed.util import get_world_size
22
- from .modules.model import WanModel
23
- from .modules.t5 import T5EncoderModel
24
- from .modules.vae2_1 import Wan2_1_VAE
25
- from .utils.fm_solvers import (
26
- FlowDPMSolverMultistepScheduler,
27
- get_sampling_sigmas,
28
- retrieve_timesteps,
29
- )
30
- from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
-
32
-
33
- class WanI2V:
34
-
35
- def __init__(
36
- self,
37
- config,
38
- checkpoint_dir,
39
- device_id=0,
40
- rank=0,
41
- t5_fsdp=False,
42
- dit_fsdp=False,
43
- use_sp=False,
44
- t5_cpu=False,
45
- init_on_cpu=True,
46
- convert_model_dtype=False,
47
- ):
48
- r"""
49
- Initializes the image-to-video generation model components.
50
-
51
- Args:
52
- config (EasyDict):
53
- Object containing model parameters initialized from config.py
54
- checkpoint_dir (`str`):
55
- Path to directory containing model checkpoints
56
- device_id (`int`, *optional*, defaults to 0):
57
- Id of target GPU device
58
- rank (`int`, *optional*, defaults to 0):
59
- Process rank for distributed training
60
- t5_fsdp (`bool`, *optional*, defaults to False):
61
- Enable FSDP sharding for T5 model
62
- dit_fsdp (`bool`, *optional*, defaults to False):
63
- Enable FSDP sharding for DiT model
64
- use_sp (`bool`, *optional*, defaults to False):
65
- Enable distribution strategy of sequence parallel.
66
- t5_cpu (`bool`, *optional*, defaults to False):
67
- Whether to place T5 model on CPU. Only works without t5_fsdp.
68
- init_on_cpu (`bool`, *optional*, defaults to True):
69
- Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
70
- convert_model_dtype (`bool`, *optional*, defaults to False):
71
- Convert DiT model parameters dtype to 'config.param_dtype'.
72
- Only works without FSDP.
73
- """
74
- self.device = torch.device(f"cuda:{device_id}")
75
- self.config = config
76
- self.rank = rank
77
- self.t5_cpu = t5_cpu
78
- self.init_on_cpu = init_on_cpu
79
-
80
- self.num_train_timesteps = config.num_train_timesteps
81
- self.boundary = config.boundary
82
- self.param_dtype = config.param_dtype
83
-
84
- if t5_fsdp or dit_fsdp or use_sp:
85
- self.init_on_cpu = False
86
-
87
- shard_fn = partial(shard_model, device_id=device_id)
88
- self.text_encoder = T5EncoderModel(
89
- text_len=config.text_len,
90
- dtype=config.t5_dtype,
91
- device=torch.device('cpu'),
92
- checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
- tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
- shard_fn=shard_fn if t5_fsdp else None,
95
- )
96
-
97
- self.vae_stride = config.vae_stride
98
- self.patch_size = config.patch_size
99
- self.vae = Wan2_1_VAE(
100
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
101
- device=self.device)
102
-
103
- logging.info(f"Creating WanModel from {checkpoint_dir}")
104
- self.low_noise_model = WanModel.from_pretrained(
105
- checkpoint_dir, subfolder=config.low_noise_checkpoint)
106
- self.low_noise_model = self._configure_model(
107
- model=self.low_noise_model,
108
- use_sp=use_sp,
109
- dit_fsdp=dit_fsdp,
110
- shard_fn=shard_fn,
111
- convert_model_dtype=convert_model_dtype)
112
-
113
- self.high_noise_model = WanModel.from_pretrained(
114
- checkpoint_dir, subfolder=config.high_noise_checkpoint)
115
- self.high_noise_model = self._configure_model(
116
- model=self.high_noise_model,
117
- use_sp=use_sp,
118
- dit_fsdp=dit_fsdp,
119
- shard_fn=shard_fn,
120
- convert_model_dtype=convert_model_dtype)
121
- if use_sp:
122
- self.sp_size = get_world_size()
123
- else:
124
- self.sp_size = 1
125
-
126
- self.sample_neg_prompt = config.sample_neg_prompt
127
-
128
- def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
129
- convert_model_dtype):
130
- """
131
- Configures a model object. This includes setting evaluation modes,
132
- applying distributed parallel strategy, and handling device placement.
133
-
134
- Args:
135
- model (torch.nn.Module):
136
- The model instance to configure.
137
- use_sp (`bool`):
138
- Enable distribution strategy of sequence parallel.
139
- dit_fsdp (`bool`):
140
- Enable FSDP sharding for DiT model.
141
- shard_fn (callable):
142
- The function to apply FSDP sharding.
143
- convert_model_dtype (`bool`):
144
- Convert DiT model parameters dtype to 'config.param_dtype'.
145
- Only works without FSDP.
146
-
147
- Returns:
148
- torch.nn.Module:
149
- The configured model.
150
- """
151
- model.eval().requires_grad_(False)
152
-
153
- if use_sp:
154
- for block in model.blocks:
155
- block.self_attn.forward = types.MethodType(
156
- sp_attn_forward, block.self_attn)
157
- model.forward = types.MethodType(sp_dit_forward, model)
158
-
159
- if dist.is_initialized():
160
- dist.barrier()
161
-
162
- if dit_fsdp:
163
- model = shard_fn(model)
164
- else:
165
- if convert_model_dtype:
166
- model.to(self.param_dtype)
167
- if not self.init_on_cpu:
168
- model.to(self.device)
169
-
170
- return model
171
-
172
- def _prepare_model_for_timestep(self, t, boundary, offload_model):
173
- r"""
174
- Prepares and returns the required model for the current timestep.
175
-
176
- Args:
177
- t (torch.Tensor):
178
- current timestep.
179
- boundary (`int`):
180
- The timestep threshold. If `t` is at or above this value,
181
- the `high_noise_model` is considered as the required model.
182
- offload_model (`bool`):
183
- A flag intended to control the offloading behavior.
184
-
185
- Returns:
186
- torch.nn.Module:
187
- The active model on the target device for the current timestep.
188
- """
189
- if t.item() >= boundary:
190
- required_model_name = 'high_noise_model'
191
- offload_model_name = 'low_noise_model'
192
- else:
193
- required_model_name = 'low_noise_model'
194
- offload_model_name = 'high_noise_model'
195
- if offload_model or self.init_on_cpu:
196
- if next(getattr(
197
- self,
198
- offload_model_name).parameters()).device.type == 'cuda':
199
- getattr(self, offload_model_name).to('cpu')
200
- if next(getattr(
201
- self,
202
- required_model_name).parameters()).device.type == 'cpu':
203
- getattr(self, required_model_name).to(self.device)
204
- return getattr(self, required_model_name)
205
-
206
- def generate(self,
207
- input_prompt,
208
- img,
209
- max_area=720 * 1280,
210
- frame_num=81,
211
- shift=5.0,
212
- sample_solver='unipc',
213
- sampling_steps=40,
214
- guide_scale=5.0,
215
- n_prompt="",
216
- seed=-1,
217
- offload_model=True):
218
- r"""
219
- Generates video frames from input image and text prompt using diffusion process.
220
-
221
- Args:
222
- input_prompt (`str`):
223
- Text prompt for content generation.
224
- img (PIL.Image.Image):
225
- Input image tensor. Shape: [3, H, W]
226
- max_area (`int`, *optional*, defaults to 720*1280):
227
- Maximum pixel area for latent space calculation. Controls video resolution scaling
228
- frame_num (`int`, *optional*, defaults to 81):
229
- How many frames to sample from a video. The number should be 4n+1
230
- shift (`float`, *optional*, defaults to 5.0):
231
- Noise schedule shift parameter. Affects temporal dynamics
232
- [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
233
- sample_solver (`str`, *optional*, defaults to 'unipc'):
234
- Solver used to sample the video.
235
- sampling_steps (`int`, *optional*, defaults to 40):
236
- Number of diffusion sampling steps. Higher values improve quality but slow generation
237
- guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
238
- Classifier-free guidance scale. Controls prompt adherence vs. creativity.
239
- If tuple, the first guide_scale will be used for low noise model and
240
- the second guide_scale will be used for high noise model.
241
- n_prompt (`str`, *optional*, defaults to ""):
242
- Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
243
- seed (`int`, *optional*, defaults to -1):
244
- Random seed for noise generation. If -1, use random seed
245
- offload_model (`bool`, *optional*, defaults to True):
246
- If True, offloads models to CPU during generation to save VRAM
247
-
248
- Returns:
249
- torch.Tensor:
250
- Generated video frames tensor. Dimensions: (C, N H, W) where:
251
- - C: Color channels (3 for RGB)
252
- - N: Number of frames (81)
253
- - H: Frame height (from max_area)
254
- - W: Frame width from max_area)
255
- """
256
- # preprocess
257
- guide_scale = (guide_scale, guide_scale) if isinstance(
258
- guide_scale, float) else guide_scale
259
- img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
260
-
261
- F = frame_num
262
- h, w = img.shape[1:]
263
- aspect_ratio = h / w
264
- lat_h = round(
265
- np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
266
- self.patch_size[1] * self.patch_size[1])
267
- lat_w = round(
268
- np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
269
- self.patch_size[2] * self.patch_size[2])
270
- h = lat_h * self.vae_stride[1]
271
- w = lat_w * self.vae_stride[2]
272
-
273
- max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
274
- self.patch_size[1] * self.patch_size[2])
275
- max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
276
-
277
- seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
278
- seed_g = torch.Generator(device=self.device)
279
- seed_g.manual_seed(seed)
280
- noise = torch.randn(
281
- 16,
282
- 21,
283
- lat_h,
284
- lat_w,
285
- dtype=torch.float32,
286
- generator=seed_g,
287
- device=self.device)
288
-
289
- msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
290
- msk[:, 1:] = 0
291
- msk = torch.concat([
292
- torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
293
- ],
294
- dim=1)
295
- msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
296
- msk = msk.transpose(1, 2)[0]
297
-
298
- if n_prompt == "":
299
- n_prompt = self.sample_neg_prompt
300
-
301
- # preprocess
302
- if not self.t5_cpu:
303
- self.text_encoder.model.to(self.device)
304
- context = self.text_encoder([input_prompt], self.device)
305
- context_null = self.text_encoder([n_prompt], self.device)
306
- if offload_model:
307
- self.text_encoder.model.cpu()
308
- else:
309
- context = self.text_encoder([input_prompt], torch.device('cpu'))
310
- context_null = self.text_encoder([n_prompt], torch.device('cpu'))
311
- context = [t.to(self.device) for t in context]
312
- context_null = [t.to(self.device) for t in context_null]
313
-
314
- y = self.vae.encode([
315
- torch.concat([
316
- torch.nn.functional.interpolate(
317
- img[None].cpu(), size=(h, w), mode='bicubic').transpose(
318
- 0, 1),
319
- torch.zeros(3, 80, h, w)
320
- ],
321
- dim=1).to(self.device)
322
- ])[0]
323
- y = torch.concat([msk, y])
324
-
325
- @contextmanager
326
- def noop_no_sync():
327
- yield
328
-
329
- no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
330
- noop_no_sync)
331
- no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
332
- noop_no_sync)
333
-
334
- # evaluation mode
335
- with (
336
- torch.amp.autocast('cuda', dtype=self.param_dtype),
337
- torch.no_grad(),
338
- no_sync_low_noise(),
339
- no_sync_high_noise(),
340
- ):
341
- boundary = self.boundary * self.num_train_timesteps
342
-
343
- if sample_solver == 'unipc':
344
- sample_scheduler = FlowUniPCMultistepScheduler(
345
- num_train_timesteps=self.num_train_timesteps,
346
- shift=1,
347
- use_dynamic_shifting=False)
348
- sample_scheduler.set_timesteps(
349
- sampling_steps, device=self.device, shift=shift)
350
- timesteps = sample_scheduler.timesteps
351
- elif sample_solver == 'dpm++':
352
- sample_scheduler = FlowDPMSolverMultistepScheduler(
353
- num_train_timesteps=self.num_train_timesteps,
354
- shift=1,
355
- use_dynamic_shifting=False)
356
- sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
357
- timesteps, _ = retrieve_timesteps(
358
- sample_scheduler,
359
- device=self.device,
360
- sigmas=sampling_sigmas)
361
- else:
362
- raise NotImplementedError("Unsupported solver.")
363
-
364
- # sample videos
365
- latent = noise
366
-
367
- arg_c = {
368
- 'context': [context[0]],
369
- 'seq_len': max_seq_len,
370
- 'y': [y],
371
- }
372
-
373
- arg_null = {
374
- 'context': context_null,
375
- 'seq_len': max_seq_len,
376
- 'y': [y],
377
- }
378
-
379
- if offload_model:
380
- torch.cuda.empty_cache()
381
-
382
- for _, t in enumerate(tqdm(timesteps)):
383
- latent_model_input = [latent.to(self.device)]
384
- timestep = [t]
385
-
386
- timestep = torch.stack(timestep).to(self.device)
387
-
388
- model = self._prepare_model_for_timestep(
389
- t, boundary, offload_model)
390
- sample_guide_scale = guide_scale[1] if t.item(
391
- ) >= boundary else guide_scale[0]
392
-
393
- noise_pred_cond = model(
394
- latent_model_input, t=timestep, **arg_c)[0]
395
- if offload_model:
396
- torch.cuda.empty_cache()
397
- noise_pred_uncond = model(
398
- latent_model_input, t=timestep, **arg_null)[0]
399
- if offload_model:
400
- torch.cuda.empty_cache()
401
- noise_pred = noise_pred_uncond + sample_guide_scale * (
402
- noise_pred_cond - noise_pred_uncond)
403
-
404
- temp_x0 = sample_scheduler.step(
405
- noise_pred.unsqueeze(0),
406
- t,
407
- latent.unsqueeze(0),
408
- return_dict=False,
409
- generator=seed_g)[0]
410
- latent = temp_x0.squeeze(0)
411
-
412
- x0 = [latent]
413
- del latent_model_input, timestep
414
-
415
- if offload_model:
416
- self.low_noise_model.cpu()
417
- self.high_noise_model.cpu()
418
- torch.cuda.empty_cache()
419
-
420
- if self.rank == 0:
421
- videos = self.vae.decode(x0)
422
-
423
- del noise, latent, x0
424
- del sample_scheduler
425
- if offload_model:
426
- gc.collect()
427
- torch.cuda.synchronize()
428
- if dist.is_initialized():
429
- dist.barrier()
430
-
431
- return videos[0] if self.rank == 0 else None
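
A small illustration (not part of the commit) of the expert-switching rule implemented by _prepare_model_for_timestep above: timesteps at or above boundary * num_train_timesteps are handled by the high-noise model, the rest by the low-noise model, and generate() picks the matching entry of guide_scale. Values are taken from the i2v_A14B config shown earlier:

    num_train_timesteps = 1000
    boundary = 0.900                 # i2v_A14B.boundary
    guide_scale = (3.5, 3.5)         # (low-noise scale, high-noise scale)

    threshold = boundary * num_train_timesteps
    for t in (999, 950, 900, 899, 500, 10):
        use_high = t >= threshold
        scale = guide_scale[1] if use_high else guide_scale[0]
        print(t, 'high_noise_model' if use_high else 'low_noise_model', scale)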
 
wan/modules/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- from .attention import flash_attention
3
- from .model import WanModel
4
- from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
5
- from .tokenizers import HuggingfaceTokenizer
6
- from .vae2_1 import Wan2_1_VAE
7
- from .vae2_2 import Wan2_2_VAE
8
-
9
- __all__ = [
10
- 'Wan2_1_VAE',
11
- 'Wan2_2_VAE',
12
- 'WanModel',
13
- 'T5Model',
14
- 'T5Encoder',
15
- 'T5Decoder',
16
- 'T5EncoderModel',
17
- 'HuggingfaceTokenizer',
18
- 'flash_attention',
19
- ]
 
wan/modules/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (528 Bytes)
 
wan/modules/__pycache__/attention.cpython-310.pyc DELETED
Binary file (3.95 kB)
 
wan/modules/__pycache__/model.cpython-310.pyc DELETED
Binary file (16.9 kB)
 
wan/modules/__pycache__/t5.cpython-310.pyc DELETED
Binary file (12.9 kB)
 
wan/modules/__pycache__/tokenizers.cpython-310.pyc DELETED
Binary file (2.55 kB)
 
wan/modules/__pycache__/vae2_1.cpython-310.pyc DELETED
Binary file (16.9 kB)
 
wan/modules/__pycache__/vae2_2.cpython-310.pyc DELETED
Binary file (22.1 kB)
 
wan/modules/attention.py DELETED
@@ -1,179 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import torch
3
-
4
- try:
5
- import flash_attn_interface
6
- FLASH_ATTN_3_AVAILABLE = True
7
- except ModuleNotFoundError:
8
- FLASH_ATTN_3_AVAILABLE = False
9
-
10
- try:
11
- import flash_attn
12
- FLASH_ATTN_2_AVAILABLE = True
13
- except ModuleNotFoundError:
14
- FLASH_ATTN_2_AVAILABLE = False
15
-
16
- import warnings
17
-
18
- __all__ = [
19
- 'flash_attention',
20
- 'attention',
21
- ]
22
-
23
-
24
- def flash_attention(
25
- q,
26
- k,
27
- v,
28
- q_lens=None,
29
- k_lens=None,
30
- dropout_p=0.,
31
- softmax_scale=None,
32
- q_scale=None,
33
- causal=False,
34
- window_size=(-1, -1),
35
- deterministic=False,
36
- dtype=torch.bfloat16,
37
- version=None,
38
- ):
39
- """
40
- q: [B, Lq, Nq, C1].
41
- k: [B, Lk, Nk, C1].
42
- v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
43
- q_lens: [B].
44
- k_lens: [B].
45
- dropout_p: float. Dropout probability.
46
- softmax_scale: float. The scaling of QK^T before applying softmax.
47
- causal: bool. Whether to apply causal attention mask.
48
- window_size: (left right). If not (-1, -1), apply sliding window local attention.
49
- deterministic: bool. If True, slightly slower and uses more memory.
50
- dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
51
- """
52
- half_dtypes = (torch.float16, torch.bfloat16)
53
- assert dtype in half_dtypes
54
- assert q.device.type == 'cuda' and q.size(-1) <= 256
55
-
56
- # params
57
- b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
58
-
59
- def half(x):
60
- return x if x.dtype in half_dtypes else x.to(dtype)
61
-
62
- # preprocess query
63
- if q_lens is None:
64
- q = half(q.flatten(0, 1))
65
- q_lens = torch.tensor(
66
- [lq] * b, dtype=torch.int32).to(
67
- device=q.device, non_blocking=True)
68
- else:
69
- q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
70
-
71
- # preprocess key, value
72
- if k_lens is None:
73
- k = half(k.flatten(0, 1))
74
- v = half(v.flatten(0, 1))
75
- k_lens = torch.tensor(
76
- [lk] * b, dtype=torch.int32).to(
77
- device=k.device, non_blocking=True)
78
- else:
79
- k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
80
- v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
81
-
82
- q = q.to(v.dtype)
83
- k = k.to(v.dtype)
84
-
85
- if q_scale is not None:
86
- q = q * q_scale
87
-
88
- if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
89
- warnings.warn(
90
- 'Flash attention 3 is not available, use flash attention 2 instead.'
91
- )
92
-
93
- # apply attention
94
- if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
95
- # Note: dropout_p, window_size are not supported in FA3 now.
96
- x = flash_attn_interface.flash_attn_varlen_func(
97
- q=q,
98
- k=k,
99
- v=v,
100
- cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
101
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
102
- cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
103
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
104
- seqused_q=None,
105
- seqused_k=None,
106
- max_seqlen_q=lq,
107
- max_seqlen_k=lk,
108
- softmax_scale=softmax_scale,
109
- causal=causal,
110
- deterministic=deterministic)[0].unflatten(0, (b, lq))
111
- else:
112
- assert FLASH_ATTN_2_AVAILABLE
113
- x = flash_attn.flash_attn_varlen_func(
114
- q=q,
115
- k=k,
116
- v=v,
117
- cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
118
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
119
- cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
120
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
121
- max_seqlen_q=lq,
122
- max_seqlen_k=lk,
123
- dropout_p=dropout_p,
124
- softmax_scale=softmax_scale,
125
- causal=causal,
126
- window_size=window_size,
127
- deterministic=deterministic).unflatten(0, (b, lq))
128
-
129
- # output
130
- return x.type(out_dtype)
131
-
132
-
133
- def attention(
134
- q,
135
- k,
136
- v,
137
- q_lens=None,
138
- k_lens=None,
139
- dropout_p=0.,
140
- softmax_scale=None,
141
- q_scale=None,
142
- causal=False,
143
- window_size=(-1, -1),
144
- deterministic=False,
145
- dtype=torch.bfloat16,
146
- fa_version=None,
147
- ):
148
- if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
149
- return flash_attention(
150
- q=q,
151
- k=k,
152
- v=v,
153
- q_lens=q_lens,
154
- k_lens=k_lens,
155
- dropout_p=dropout_p,
156
- softmax_scale=softmax_scale,
157
- q_scale=q_scale,
158
- causal=causal,
159
- window_size=window_size,
160
- deterministic=deterministic,
161
- dtype=dtype,
162
- version=fa_version,
163
- )
164
- else:
165
- if q_lens is not None or k_lens is not None:
166
- warnings.warn(
167
- 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
168
- )
169
- attn_mask = None
170
-
171
- q = q.transpose(1, 2).to(dtype)
172
- k = k.transpose(1, 2).to(dtype)
173
- v = v.transpose(1, 2).to(dtype)
174
-
175
- out = torch.nn.functional.scaled_dot_product_attention(
176
- q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
177
-
178
- out = out.transpose(1, 2).contiguous()
179
- return out
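
When neither flash-attention 2 nor 3 is installed, attention() above falls back to PyTorch SDPA. A minimal sketch (not part of the commit) of that fallback path and the layout convention it assumes, with illustrative shapes:

    import torch
    import torch.nn.functional as F

    B, L, N, D = 1, 16, 8, 64                    # batch, sequence, heads, head_dim
    q = torch.randn(B, L, N, D)
    k = torch.randn(B, L, N, D)
    v = torch.randn(B, L, N, D)

    # SDPA expects [B, heads, L, head_dim], so transpose in and back out,
    # mirroring the non-flash branch of attention()
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=None, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).contiguous()       # back to [B, L, N, head_dim]
    print(out.shape)                             # torch.Size([1, 16, 8, 64])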
 
wan/modules/model.py DELETED
@@ -1,546 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import math
3
-
4
- import torch
5
- import torch.nn as nn
6
- from diffusers.configuration_utils import ConfigMixin, register_to_config
7
- from diffusers.models.modeling_utils import ModelMixin
8
-
9
- from .attention import flash_attention
10
-
11
- __all__ = ['WanModel']
12
-
13
-
14
- def sinusoidal_embedding_1d(dim, position):
15
- # preprocess
16
- assert dim % 2 == 0
17
- half = dim // 2
18
- position = position.type(torch.float64)
19
-
20
- # calculation
21
- sinusoid = torch.outer(
22
- position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
23
- x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
24
- return x
25
-
26
-
27
- @torch.amp.autocast('cuda', enabled=False)
28
- def rope_params(max_seq_len, dim, theta=10000):
29
- assert dim % 2 == 0
30
- freqs = torch.outer(
31
- torch.arange(max_seq_len),
32
- 1.0 / torch.pow(theta,
33
- torch.arange(0, dim, 2).to(torch.float64).div(dim)))
34
- freqs = torch.polar(torch.ones_like(freqs), freqs)
35
- return freqs
36
-
37
-
38
- @torch.amp.autocast('cuda', enabled=False)
39
- def rope_apply(x, grid_sizes, freqs):
40
- n, c = x.size(2), x.size(3) // 2
41
-
42
- # split freqs
43
- freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
44
-
45
- # loop over samples
46
- output = []
47
- for i, (f, h, w) in enumerate(grid_sizes.tolist()):
48
- seq_len = f * h * w
49
-
50
- # precompute multipliers
51
- x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
52
- seq_len, n, -1, 2))
53
- freqs_i = torch.cat([
54
- freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
55
- freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
56
- freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
57
- ],
58
- dim=-1).reshape(seq_len, 1, -1)
59
-
60
- # apply rotary embedding
61
- x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
62
- x_i = torch.cat([x_i, x[i, seq_len:]])
63
-
64
- # append to collection
65
- output.append(x_i)
66
- return torch.stack(output).float()
67
-
68
-
69
- class WanRMSNorm(nn.Module):
70
-
71
- def __init__(self, dim, eps=1e-5):
72
- super().__init__()
73
- self.dim = dim
74
- self.eps = eps
75
- self.weight = nn.Parameter(torch.ones(dim))
76
-
77
- def forward(self, x):
78
- r"""
79
- Args:
80
- x(Tensor): Shape [B, L, C]
81
- """
82
- return self._norm(x.float()).type_as(x) * self.weight
83
-
84
- def _norm(self, x):
85
- return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
86
-
87
-
88
- class WanLayerNorm(nn.LayerNorm):
89
-
90
- def __init__(self, dim, eps=1e-6, elementwise_affine=False):
91
- super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
92
-
93
- def forward(self, x):
94
- r"""
95
- Args:
96
- x(Tensor): Shape [B, L, C]
97
- """
98
- return super().forward(x.float()).type_as(x)
99
-
100
-
101
- class WanSelfAttention(nn.Module):
102
-
103
- def __init__(self,
104
- dim,
105
- num_heads,
106
- window_size=(-1, -1),
107
- qk_norm=True,
108
- eps=1e-6):
109
- assert dim % num_heads == 0
110
- super().__init__()
111
- self.dim = dim
112
- self.num_heads = num_heads
113
- self.head_dim = dim // num_heads
114
- self.window_size = window_size
115
- self.qk_norm = qk_norm
116
- self.eps = eps
117
-
118
- # layers
119
- self.q = nn.Linear(dim, dim)
120
- self.k = nn.Linear(dim, dim)
121
- self.v = nn.Linear(dim, dim)
122
- self.o = nn.Linear(dim, dim)
123
- self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
124
- self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
-
126
- def forward(self, x, seq_lens, grid_sizes, freqs):
127
- r"""
128
- Args:
129
- x(Tensor): Shape [B, L, num_heads, C / num_heads]
130
- seq_lens(Tensor): Shape [B]
131
- grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
132
- freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
133
- """
134
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
135
-
136
- # query, key, value function
137
- def qkv_fn(x):
138
- q = self.norm_q(self.q(x)).view(b, s, n, d)
139
- k = self.norm_k(self.k(x)).view(b, s, n, d)
140
- v = self.v(x).view(b, s, n, d)
141
- return q, k, v
142
-
143
- q, k, v = qkv_fn(x)
144
-
145
- x = flash_attention(
146
- q=rope_apply(q, grid_sizes, freqs),
147
- k=rope_apply(k, grid_sizes, freqs),
148
- v=v,
149
- k_lens=seq_lens,
150
- window_size=self.window_size)
151
-
152
- # output
153
- x = x.flatten(2)
154
- x = self.o(x)
155
- return x
156
-
157
-
158
- class WanCrossAttention(WanSelfAttention):
159
-
160
- def forward(self, x, context, context_lens):
161
- r"""
162
- Args:
163
- x(Tensor): Shape [B, L1, C]
164
- context(Tensor): Shape [B, L2, C]
165
- context_lens(Tensor): Shape [B]
166
- """
167
- b, n, d = x.size(0), self.num_heads, self.head_dim
168
-
169
- # compute query, key, value
170
- q = self.norm_q(self.q(x)).view(b, -1, n, d)
171
- k = self.norm_k(self.k(context)).view(b, -1, n, d)
172
- v = self.v(context).view(b, -1, n, d)
173
-
174
- # compute attention
175
- x = flash_attention(q, k, v, k_lens=context_lens)
176
-
177
- # output
178
- x = x.flatten(2)
179
- x = self.o(x)
180
- return x
181
-
182
-
183
- class WanAttentionBlock(nn.Module):
184
-
185
- def __init__(self,
186
- dim,
187
- ffn_dim,
188
- num_heads,
189
- window_size=(-1, -1),
190
- qk_norm=True,
191
- cross_attn_norm=False,
192
- eps=1e-6):
193
- super().__init__()
194
- self.dim = dim
195
- self.ffn_dim = ffn_dim
196
- self.num_heads = num_heads
197
- self.window_size = window_size
198
- self.qk_norm = qk_norm
199
- self.cross_attn_norm = cross_attn_norm
200
- self.eps = eps
201
-
202
- # layers
203
- self.norm1 = WanLayerNorm(dim, eps)
204
- self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
205
- eps)
206
- self.norm3 = WanLayerNorm(
207
- dim, eps,
208
- elementwise_affine=True) if cross_attn_norm else nn.Identity()
209
- self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
210
- eps)
211
- self.norm2 = WanLayerNorm(dim, eps)
212
- self.ffn = nn.Sequential(
213
- nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
214
- nn.Linear(ffn_dim, dim))
215
-
216
- # modulation
217
- self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
218
-
219
- def forward(
220
- self,
221
- x,
222
- e,
223
- seq_lens,
224
- grid_sizes,
225
- freqs,
226
- context,
227
- context_lens,
228
- ):
229
- r"""
230
- Args:
231
- x(Tensor): Shape [B, L, C]
232
- e(Tensor): Shape [B, L1, 6, C]
233
- seq_lens(Tensor): Shape [B], length of each sequence in batch
234
- grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
235
- freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
236
- """
237
- assert e.dtype == torch.float32
238
- with torch.amp.autocast('cuda', dtype=torch.float32):
239
- e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
240
- assert e[0].dtype == torch.float32
241
-
242
- # self-attention
243
- y = self.self_attn(
244
- self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
245
- seq_lens, grid_sizes, freqs)
246
- with torch.amp.autocast('cuda', dtype=torch.float32):
247
- x = x + y * e[2].squeeze(2)
248
-
249
- # cross-attention & ffn function
250
- def cross_attn_ffn(x, context, context_lens, e):
251
- x = x + self.cross_attn(self.norm3(x), context, context_lens)
252
- y = self.ffn(
253
- self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
254
- with torch.amp.autocast('cuda', dtype=torch.float32):
255
- x = x + y * e[5].squeeze(2)
256
- return x
257
-
258
- x = cross_attn_ffn(x, context, context_lens, e)
259
- return x
260
-
261
-
262
- class Head(nn.Module):
263
-
264
- def __init__(self, dim, out_dim, patch_size, eps=1e-6):
265
- super().__init__()
266
- self.dim = dim
267
- self.out_dim = out_dim
268
- self.patch_size = patch_size
269
- self.eps = eps
270
-
271
- # layers
272
- out_dim = math.prod(patch_size) * out_dim
273
- self.norm = WanLayerNorm(dim, eps)
274
- self.head = nn.Linear(dim, out_dim)
275
-
276
- # modulation
277
- self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
278
-
279
- def forward(self, x, e):
280
- r"""
281
- Args:
282
- x(Tensor): Shape [B, L1, C]
283
- e(Tensor): Shape [B, L1, C]
284
- """
285
- assert e.dtype == torch.float32
286
- with torch.amp.autocast('cuda', dtype=torch.float32):
287
- e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
288
- x = (
289
- self.head(
290
- self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))
291
- return x
292
-
293
-
294
- class WanModel(ModelMixin, ConfigMixin):
295
- r"""
296
- Wan diffusion backbone supporting both text-to-video and image-to-video.
297
- """
298
-
299
- ignore_for_config = [
300
- 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
301
- ]
302
- _no_split_modules = ['WanAttentionBlock']
303
-
304
- @register_to_config
305
- def __init__(self,
306
- model_type='t2v',
307
- patch_size=(1, 2, 2),
308
- text_len=512,
309
- in_dim=16,
310
- dim=2048,
311
- ffn_dim=8192,
312
- freq_dim=256,
313
- text_dim=4096,
314
- out_dim=16,
315
- num_heads=16,
316
- num_layers=32,
317
- window_size=(-1, -1),
318
- qk_norm=True,
319
- cross_attn_norm=True,
320
- eps=1e-6):
321
- r"""
322
- Initialize the diffusion model backbone.
323
-
324
- Args:
325
- model_type (`str`, *optional*, defaults to 't2v'):
326
- Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
327
- patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
328
- 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
329
- text_len (`int`, *optional*, defaults to 512):
330
- Fixed length for text embeddings
331
- in_dim (`int`, *optional*, defaults to 16):
332
- Input video channels (C_in)
333
- dim (`int`, *optional*, defaults to 2048):
334
- Hidden dimension of the transformer
335
- ffn_dim (`int`, *optional*, defaults to 8192):
336
- Intermediate dimension in feed-forward network
337
- freq_dim (`int`, *optional*, defaults to 256):
338
- Dimension for sinusoidal time embeddings
339
- text_dim (`int`, *optional*, defaults to 4096):
340
- Input dimension for text embeddings
341
- out_dim (`int`, *optional*, defaults to 16):
342
- Output video channels (C_out)
343
- num_heads (`int`, *optional*, defaults to 16):
344
- Number of attention heads
345
- num_layers (`int`, *optional*, defaults to 32):
346
- Number of transformer blocks
347
- window_size (`tuple`, *optional*, defaults to (-1, -1)):
348
- Window size for local attention (-1 indicates global attention)
349
- qk_norm (`bool`, *optional*, defaults to True):
350
- Enable query/key normalization
351
- cross_attn_norm (`bool`, *optional*, defaults to False):
352
- Enable cross-attention normalization
353
- eps (`float`, *optional*, defaults to 1e-6):
354
- Epsilon value for normalization layers
355
- """
356
-
357
- super().__init__()
358
-
359
- assert model_type in ['t2v', 'i2v', 'ti2v']
360
- self.model_type = model_type
361
-
362
- self.patch_size = patch_size
363
- self.text_len = text_len
364
- self.in_dim = in_dim
365
- self.dim = dim
366
- self.ffn_dim = ffn_dim
367
- self.freq_dim = freq_dim
368
- self.text_dim = text_dim
369
- self.out_dim = out_dim
370
- self.num_heads = num_heads
371
- self.num_layers = num_layers
372
- self.window_size = window_size
373
- self.qk_norm = qk_norm
374
- self.cross_attn_norm = cross_attn_norm
375
- self.eps = eps
376
-
377
- # embeddings
378
- self.patch_embedding = nn.Conv3d(
379
- in_dim, dim, kernel_size=patch_size, stride=patch_size)
380
- self.text_embedding = nn.Sequential(
381
- nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
382
- nn.Linear(dim, dim))
383
-
384
- self.time_embedding = nn.Sequential(
385
- nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
386
- self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
387
-
388
- # blocks
389
- self.blocks = nn.ModuleList([
390
- WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
391
- cross_attn_norm, eps) for _ in range(num_layers)
392
- ])
393
-
394
- # head
395
- self.head = Head(dim, out_dim, patch_size, eps)
396
-
397
- # buffers (don't use register_buffer otherwise dtype will be changed in to())
398
- assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
399
- d = dim // num_heads
400
- self.freqs = torch.cat([
401
- rope_params(1024, d - 4 * (d // 6)),
402
- rope_params(1024, 2 * (d // 6)),
403
- rope_params(1024, 2 * (d // 6))
404
- ],
405
- dim=1)
406
-
407
- # initialize weights
408
- self.init_weights()
409
-
410
- def forward(
411
- self,
412
- x,
413
- t,
414
- context,
415
- seq_len,
416
- y=None,
417
- ):
418
- r"""
419
- Forward pass through the diffusion model
420
-
421
- Args:
422
- x (List[Tensor]):
423
- List of input video tensors, each with shape [C_in, F, H, W]
424
- t (Tensor):
425
- Diffusion timesteps tensor of shape [B]
426
- context (List[Tensor]):
427
- List of text embeddings each with shape [L, C]
428
- seq_len (`int`):
429
- Maximum sequence length for positional encoding
430
- y (List[Tensor], *optional*):
431
- Conditional video inputs for image-to-video mode, same shape as x
432
-
433
- Returns:
434
- List[Tensor]:
435
- List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
436
- """
437
- if self.model_type == 'i2v':
438
- assert y is not None
439
- # params
440
- device = self.patch_embedding.weight.device
441
- if self.freqs.device != device:
442
- self.freqs = self.freqs.to(device)
443
-
444
- if y is not None:
445
- x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
446
-
447
- # embeddings
448
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
449
- grid_sizes = torch.stack(
450
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
451
- x = [u.flatten(2).transpose(1, 2) for u in x]
452
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
453
- assert seq_lens.max() <= seq_len
454
- x = torch.cat([
455
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
456
- dim=1) for u in x
457
- ])
458
-
459
- # time embeddings
460
- if t.dim() == 1:
461
- t = t.expand(t.size(0), seq_len)
462
- with torch.amp.autocast('cuda', dtype=torch.float32):
463
- bt = t.size(0)
464
- t = t.flatten()
465
- e = self.time_embedding(
466
- sinusoidal_embedding_1d(self.freq_dim,
467
- t).unflatten(0, (bt, seq_len)).float())
468
- e0 = self.time_projection(e).unflatten(2, (6, self.dim))
469
- assert e.dtype == torch.float32 and e0.dtype == torch.float32
470
-
471
- # context
472
- context_lens = None
473
- context = self.text_embedding(
474
- torch.stack([
475
- torch.cat(
476
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
477
- for u in context
478
- ]))
479
-
480
- # arguments
481
- kwargs = dict(
482
- e=e0,
483
- seq_lens=seq_lens,
484
- grid_sizes=grid_sizes,
485
- freqs=self.freqs,
486
- context=context,
487
- context_lens=context_lens)
488
-
489
- for block in self.blocks:
490
- x = block(x, **kwargs)
491
-
492
- # head
493
- x = self.head(x, e)
494
-
495
- # unpatchify
496
- x = self.unpatchify(x, grid_sizes)
497
- return [u.float() for u in x]
498
-
499
- def unpatchify(self, x, grid_sizes):
500
- r"""
501
- Reconstruct video tensors from patch embeddings.
502
-
503
- Args:
504
- x (List[Tensor]):
505
- List of patchified features, each with shape [L, C_out * prod(patch_size)]
506
- grid_sizes (Tensor):
507
- Original spatial-temporal grid dimensions before patching,
508
- shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
509
-
510
- Returns:
511
- List[Tensor]:
512
- Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
513
- """
514
-
515
- c = self.out_dim
516
- out = []
517
- for u, v in zip(x, grid_sizes.tolist()):
518
- u = u[:math.prod(v)].view(*v, *self.patch_size, c)
519
- u = torch.einsum('fhwpqrc->cfphqwr', u)
520
- u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
521
- out.append(u)
522
- return out
523
-
524
- def init_weights(self):
525
- r"""
526
- Initialize model parameters using Xavier initialization.
527
- """
528
-
529
- # basic init
530
- for m in self.modules():
531
- if isinstance(m, nn.Linear):
532
- nn.init.xavier_uniform_(m.weight)
533
- if m.bias is not None:
534
- nn.init.zeros_(m.bias)
535
-
536
- # init embeddings
537
- nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
538
- for m in self.text_embedding.modules():
539
- if isinstance(m, nn.Linear):
540
- nn.init.normal_(m.weight, std=.02)
541
- for m in self.time_embedding.modules():
542
- if isinstance(m, nn.Linear):
543
- nn.init.normal_(m.weight, std=.02)
544
-
545
- # init output layer
546
- nn.init.zeros_(self.head.head.weight)
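The `self.freqs` buffer assembled near the top of this hunk splits each attention head's dimension into three rotary bands, one per video axis (frames, height, width). A quick arithmetic check of that split, using illustrative `dim`/`num_heads` values rather than any shipped config:

```python
# Worked check of the per-head RoPE split in WanModel.__init__ above.
# dim and num_heads are illustrative, not necessarily the repository defaults.
dim, num_heads = 5120, 40
assert dim % num_heads == 0 and (dim // num_heads) % 2 == 0

d = dim // num_heads           # per-head dimension: 128
band_t = d - 4 * (d // 6)      # 44 -> rotary band for the frame axis
band_h = 2 * (d // 6)          # 42 -> rotary band for the height axis
band_w = 2 * (d // 6)          # 42 -> rotary band for the width axis
assert band_t + band_h + band_w == d
```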
wan/modules/t5.py DELETED
@@ -1,513 +0,0 @@
1
- # Modified from transformers.models.t5.modeling_t5
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import logging
4
- import math
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
-
10
- from .tokenizers import HuggingfaceTokenizer
11
-
12
- __all__ = [
13
- 'T5Model',
14
- 'T5Encoder',
15
- 'T5Decoder',
16
- 'T5EncoderModel',
17
- ]
18
-
19
-
20
- def fp16_clamp(x):
21
- if x.dtype == torch.float16 and torch.isinf(x).any():
22
- clamp = torch.finfo(x.dtype).max - 1000
23
- x = torch.clamp(x, min=-clamp, max=clamp)
24
- return x
25
-
26
-
27
- def init_weights(m):
28
- if isinstance(m, T5LayerNorm):
29
- nn.init.ones_(m.weight)
30
- elif isinstance(m, T5Model):
31
- nn.init.normal_(m.token_embedding.weight, std=1.0)
32
- elif isinstance(m, T5FeedForward):
33
- nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
- nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
- nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
- elif isinstance(m, T5Attention):
37
- nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
38
- nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
- nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
- nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
41
- elif isinstance(m, T5RelativeEmbedding):
42
- nn.init.normal_(
43
- m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
44
-
45
-
46
- class GELU(nn.Module):
47
-
48
- def forward(self, x):
49
- return 0.5 * x * (1.0 + torch.tanh(
50
- math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
51
-
52
-
53
- class T5LayerNorm(nn.Module):
54
-
55
- def __init__(self, dim, eps=1e-6):
56
- super(T5LayerNorm, self).__init__()
57
- self.dim = dim
58
- self.eps = eps
59
- self.weight = nn.Parameter(torch.ones(dim))
60
-
61
- def forward(self, x):
62
- x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
63
- self.eps)
64
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
65
- x = x.type_as(self.weight)
66
- return self.weight * x
67
-
68
-
69
- class T5Attention(nn.Module):
70
-
71
- def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
72
- assert dim_attn % num_heads == 0
73
- super(T5Attention, self).__init__()
74
- self.dim = dim
75
- self.dim_attn = dim_attn
76
- self.num_heads = num_heads
77
- self.head_dim = dim_attn // num_heads
78
-
79
- # layers
80
- self.q = nn.Linear(dim, dim_attn, bias=False)
81
- self.k = nn.Linear(dim, dim_attn, bias=False)
82
- self.v = nn.Linear(dim, dim_attn, bias=False)
83
- self.o = nn.Linear(dim_attn, dim, bias=False)
84
- self.dropout = nn.Dropout(dropout)
85
-
86
- def forward(self, x, context=None, mask=None, pos_bias=None):
87
- """
88
- x: [B, L1, C].
89
- context: [B, L2, C] or None.
90
- mask: [B, L2] or [B, L1, L2] or None.
91
- """
92
- # check inputs
93
- context = x if context is None else context
94
- b, n, c = x.size(0), self.num_heads, self.head_dim
95
-
96
- # compute query, key, value
97
- q = self.q(x).view(b, -1, n, c)
98
- k = self.k(context).view(b, -1, n, c)
99
- v = self.v(context).view(b, -1, n, c)
100
-
101
- # attention bias
102
- attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103
- if pos_bias is not None:
104
- attn_bias += pos_bias
105
- if mask is not None:
106
- assert mask.ndim in [2, 3]
107
- mask = mask.view(b, 1, 1,
108
- -1) if mask.ndim == 2 else mask.unsqueeze(1)
109
- attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110
-
111
- # compute attention (T5 does not use scaling)
112
- attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113
- attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114
- x = torch.einsum('bnij,bjnc->binc', attn, v)
115
-
116
- # output
117
- x = x.reshape(b, -1, n * c)
118
- x = self.o(x)
119
- x = self.dropout(x)
120
- return x
121
-
122
-
123
- class T5FeedForward(nn.Module):
124
-
125
- def __init__(self, dim, dim_ffn, dropout=0.1):
126
- super(T5FeedForward, self).__init__()
127
- self.dim = dim
128
- self.dim_ffn = dim_ffn
129
-
130
- # layers
131
- self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132
- self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133
- self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134
- self.dropout = nn.Dropout(dropout)
135
-
136
- def forward(self, x):
137
- x = self.fc1(x) * self.gate(x)
138
- x = self.dropout(x)
139
- x = self.fc2(x)
140
- x = self.dropout(x)
141
- return x
142
-
143
-
144
- class T5SelfAttention(nn.Module):
145
-
146
- def __init__(self,
147
- dim,
148
- dim_attn,
149
- dim_ffn,
150
- num_heads,
151
- num_buckets,
152
- shared_pos=True,
153
- dropout=0.1):
154
- super(T5SelfAttention, self).__init__()
155
- self.dim = dim
156
- self.dim_attn = dim_attn
157
- self.dim_ffn = dim_ffn
158
- self.num_heads = num_heads
159
- self.num_buckets = num_buckets
160
- self.shared_pos = shared_pos
161
-
162
- # layers
163
- self.norm1 = T5LayerNorm(dim)
164
- self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165
- self.norm2 = T5LayerNorm(dim)
166
- self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167
- self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168
- num_buckets, num_heads, bidirectional=True)
169
-
170
- def forward(self, x, mask=None, pos_bias=None):
171
- e = pos_bias if self.shared_pos else self.pos_embedding(
172
- x.size(1), x.size(1))
173
- x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174
- x = fp16_clamp(x + self.ffn(self.norm2(x)))
175
- return x
176
-
177
-
178
- class T5CrossAttention(nn.Module):
179
-
180
- def __init__(self,
181
- dim,
182
- dim_attn,
183
- dim_ffn,
184
- num_heads,
185
- num_buckets,
186
- shared_pos=True,
187
- dropout=0.1):
188
- super(T5CrossAttention, self).__init__()
189
- self.dim = dim
190
- self.dim_attn = dim_attn
191
- self.dim_ffn = dim_ffn
192
- self.num_heads = num_heads
193
- self.num_buckets = num_buckets
194
- self.shared_pos = shared_pos
195
-
196
- # layers
197
- self.norm1 = T5LayerNorm(dim)
198
- self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199
- self.norm2 = T5LayerNorm(dim)
200
- self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201
- self.norm3 = T5LayerNorm(dim)
202
- self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203
- self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204
- num_buckets, num_heads, bidirectional=False)
205
-
206
- def forward(self,
207
- x,
208
- mask=None,
209
- encoder_states=None,
210
- encoder_mask=None,
211
- pos_bias=None):
212
- e = pos_bias if self.shared_pos else self.pos_embedding(
213
- x.size(1), x.size(1))
214
- x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215
- x = fp16_clamp(x + self.cross_attn(
216
- self.norm2(x), context=encoder_states, mask=encoder_mask))
217
- x = fp16_clamp(x + self.ffn(self.norm3(x)))
218
- return x
219
-
220
-
221
- class T5RelativeEmbedding(nn.Module):
222
-
223
- def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224
- super(T5RelativeEmbedding, self).__init__()
225
- self.num_buckets = num_buckets
226
- self.num_heads = num_heads
227
- self.bidirectional = bidirectional
228
- self.max_dist = max_dist
229
-
230
- # layers
231
- self.embedding = nn.Embedding(num_buckets, num_heads)
232
-
233
- def forward(self, lq, lk):
234
- device = self.embedding.weight.device
235
- # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236
- # torch.arange(lq).unsqueeze(1).to(device)
237
- rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238
- torch.arange(lq, device=device).unsqueeze(1)
239
- rel_pos = self._relative_position_bucket(rel_pos)
240
- rel_pos_embeds = self.embedding(rel_pos)
241
- rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242
- 0) # [1, N, Lq, Lk]
243
- return rel_pos_embeds.contiguous()
244
-
245
- def _relative_position_bucket(self, rel_pos):
246
- # preprocess
247
- if self.bidirectional:
248
- num_buckets = self.num_buckets // 2
249
- rel_buckets = (rel_pos > 0).long() * num_buckets
250
- rel_pos = torch.abs(rel_pos)
251
- else:
252
- num_buckets = self.num_buckets
253
- rel_buckets = 0
254
- rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255
-
256
- # embeddings for small and large positions
257
- max_exact = num_buckets // 2
258
- rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259
- math.log(self.max_dist / max_exact) *
260
- (num_buckets - max_exact)).long()
261
- rel_pos_large = torch.min(
262
- rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263
- rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264
- return rel_buckets
265
-
266
-
267
- class T5Encoder(nn.Module):
268
-
269
- def __init__(self,
270
- vocab,
271
- dim,
272
- dim_attn,
273
- dim_ffn,
274
- num_heads,
275
- num_layers,
276
- num_buckets,
277
- shared_pos=True,
278
- dropout=0.1):
279
- super(T5Encoder, self).__init__()
280
- self.dim = dim
281
- self.dim_attn = dim_attn
282
- self.dim_ffn = dim_ffn
283
- self.num_heads = num_heads
284
- self.num_layers = num_layers
285
- self.num_buckets = num_buckets
286
- self.shared_pos = shared_pos
287
-
288
- # layers
289
- self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290
- else nn.Embedding(vocab, dim)
291
- self.pos_embedding = T5RelativeEmbedding(
292
- num_buckets, num_heads, bidirectional=True) if shared_pos else None
293
- self.dropout = nn.Dropout(dropout)
294
- self.blocks = nn.ModuleList([
295
- T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296
- shared_pos, dropout) for _ in range(num_layers)
297
- ])
298
- self.norm = T5LayerNorm(dim)
299
-
300
- # initialize weights
301
- self.apply(init_weights)
302
-
303
- def forward(self, ids, mask=None):
304
- x = self.token_embedding(ids)
305
- x = self.dropout(x)
306
- e = self.pos_embedding(x.size(1),
307
- x.size(1)) if self.shared_pos else None
308
- for block in self.blocks:
309
- x = block(x, mask, pos_bias=e)
310
- x = self.norm(x)
311
- x = self.dropout(x)
312
- return x
313
-
314
-
315
- class T5Decoder(nn.Module):
316
-
317
- def __init__(self,
318
- vocab,
319
- dim,
320
- dim_attn,
321
- dim_ffn,
322
- num_heads,
323
- num_layers,
324
- num_buckets,
325
- shared_pos=True,
326
- dropout=0.1):
327
- super(T5Decoder, self).__init__()
328
- self.dim = dim
329
- self.dim_attn = dim_attn
330
- self.dim_ffn = dim_ffn
331
- self.num_heads = num_heads
332
- self.num_layers = num_layers
333
- self.num_buckets = num_buckets
334
- self.shared_pos = shared_pos
335
-
336
- # layers
337
- self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338
- else nn.Embedding(vocab, dim)
339
- self.pos_embedding = T5RelativeEmbedding(
340
- num_buckets, num_heads, bidirectional=False) if shared_pos else None
341
- self.dropout = nn.Dropout(dropout)
342
- self.blocks = nn.ModuleList([
343
- T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344
- shared_pos, dropout) for _ in range(num_layers)
345
- ])
346
- self.norm = T5LayerNorm(dim)
347
-
348
- # initialize weights
349
- self.apply(init_weights)
350
-
351
- def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352
- b, s = ids.size()
353
-
354
- # causal mask
355
- if mask is None:
356
- mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357
- elif mask.ndim == 2:
358
- mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359
-
360
- # layers
361
- x = self.token_embedding(ids)
362
- x = self.dropout(x)
363
- e = self.pos_embedding(x.size(1),
364
- x.size(1)) if self.shared_pos else None
365
- for block in self.blocks:
366
- x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367
- x = self.norm(x)
368
- x = self.dropout(x)
369
- return x
370
-
371
-
372
- class T5Model(nn.Module):
373
-
374
- def __init__(self,
375
- vocab_size,
376
- dim,
377
- dim_attn,
378
- dim_ffn,
379
- num_heads,
380
- encoder_layers,
381
- decoder_layers,
382
- num_buckets,
383
- shared_pos=True,
384
- dropout=0.1):
385
- super(T5Model, self).__init__()
386
- self.vocab_size = vocab_size
387
- self.dim = dim
388
- self.dim_attn = dim_attn
389
- self.dim_ffn = dim_ffn
390
- self.num_heads = num_heads
391
- self.encoder_layers = encoder_layers
392
- self.decoder_layers = decoder_layers
393
- self.num_buckets = num_buckets
394
-
395
- # layers
396
- self.token_embedding = nn.Embedding(vocab_size, dim)
397
- self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398
- num_heads, encoder_layers, num_buckets,
399
- shared_pos, dropout)
400
- self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401
- num_heads, decoder_layers, num_buckets,
402
- shared_pos, dropout)
403
- self.head = nn.Linear(dim, vocab_size, bias=False)
404
-
405
- # initialize weights
406
- self.apply(init_weights)
407
-
408
- def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409
- x = self.encoder(encoder_ids, encoder_mask)
410
- x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411
- x = self.head(x)
412
- return x
413
-
414
-
415
- def _t5(name,
416
- encoder_only=False,
417
- decoder_only=False,
418
- return_tokenizer=False,
419
- tokenizer_kwargs={},
420
- dtype=torch.float32,
421
- device='cpu',
422
- **kwargs):
423
- # sanity check
424
- assert not (encoder_only and decoder_only)
425
-
426
- # params
427
- if encoder_only:
428
- model_cls = T5Encoder
429
- kwargs['vocab'] = kwargs.pop('vocab_size')
430
- kwargs['num_layers'] = kwargs.pop('encoder_layers')
431
- _ = kwargs.pop('decoder_layers')
432
- elif decoder_only:
433
- model_cls = T5Decoder
434
- kwargs['vocab'] = kwargs.pop('vocab_size')
435
- kwargs['num_layers'] = kwargs.pop('decoder_layers')
436
- _ = kwargs.pop('encoder_layers')
437
- else:
438
- model_cls = T5Model
439
-
440
- # init model
441
- with torch.device(device):
442
- model = model_cls(**kwargs)
443
-
444
- # set device
445
- model = model.to(dtype=dtype, device=device)
446
-
447
- # init tokenizer
448
- if return_tokenizer:
449
- from .tokenizers import HuggingfaceTokenizer
450
- tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451
- return model, tokenizer
452
- else:
453
- return model
454
-
455
-
456
- def umt5_xxl(**kwargs):
457
- cfg = dict(
458
- vocab_size=256384,
459
- dim=4096,
460
- dim_attn=4096,
461
- dim_ffn=10240,
462
- num_heads=64,
463
- encoder_layers=24,
464
- decoder_layers=24,
465
- num_buckets=32,
466
- shared_pos=False,
467
- dropout=0.1)
468
- cfg.update(**kwargs)
469
- return _t5('umt5-xxl', **cfg)
470
-
471
-
472
- class T5EncoderModel:
473
-
474
- def __init__(
475
- self,
476
- text_len,
477
- dtype=torch.bfloat16,
478
- device=torch.cuda.current_device(),
479
- checkpoint_path=None,
480
- tokenizer_path=None,
481
- shard_fn=None,
482
- ):
483
- self.text_len = text_len
484
- self.dtype = dtype
485
- self.device = device
486
- self.checkpoint_path = checkpoint_path
487
- self.tokenizer_path = tokenizer_path
488
-
489
- # init model
490
- model = umt5_xxl(
491
- encoder_only=True,
492
- return_tokenizer=False,
493
- dtype=dtype,
494
- device=device).eval().requires_grad_(False)
495
- logging.info(f'loading {checkpoint_path}')
496
- model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
497
- self.model = model
498
- if shard_fn is not None:
499
- self.model = shard_fn(self.model, sync_module_states=False)
500
- else:
501
- self.model.to(self.device)
502
- # init tokenizer
503
- self.tokenizer = HuggingfaceTokenizer(
504
- name=tokenizer_path, seq_len=text_len, clean='whitespace')
505
-
506
- def __call__(self, texts, device):
507
- ids, mask = self.tokenizer(
508
- texts, return_mask=True, add_special_tokens=True)
509
- ids = ids.to(device)
510
- mask = mask.to(device)
511
- seq_lens = mask.gt(0).sum(dim=1).long()
512
- context = self.model(ids, mask)
513
- return [u[:v] for u, v in zip(context, seq_lens)]
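For reference, a minimal usage sketch of the `T5EncoderModel` wrapper defined above. It assumes a checkout from before this commit and locally available UMT5-XXL encoder weights; both paths below are placeholders, not guaranteed filenames.

```python
# Minimal sketch: encode prompts with the T5EncoderModel wrapper above.
# Assumes wan/modules/t5.py still exists (pre-commit checkout) and that the
# checkpoint/tokenizer paths below point at locally downloaded UMT5-XXL assets.
import torch
from wan.modules.t5 import T5EncoderModel

device = torch.device('cuda')
text_encoder = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=device,
    checkpoint_path='models_t5_umt5-xxl-enc-bf16.pth',  # placeholder weight file
    tokenizer_path='google/umt5-xxl',                    # placeholder tokenizer id
)

# Returns one [num_tokens, 4096] embedding per prompt, trimmed to its true length.
context = text_encoder(['a corgi skateboarding through a neon-lit street'], device)
print(context[0].shape)
```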
wan/modules/tokenizers.py DELETED
@@ -1,82 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import html
3
- import string
4
-
5
- import ftfy
6
- import regex as re
7
- from transformers import AutoTokenizer
8
-
9
- __all__ = ['HuggingfaceTokenizer']
10
-
11
-
12
- def basic_clean(text):
13
- text = ftfy.fix_text(text)
14
- text = html.unescape(html.unescape(text))
15
- return text.strip()
16
-
17
-
18
- def whitespace_clean(text):
19
- text = re.sub(r'\s+', ' ', text)
20
- text = text.strip()
21
- return text
22
-
23
-
24
- def canonicalize(text, keep_punctuation_exact_string=None):
25
- text = text.replace('_', ' ')
26
- if keep_punctuation_exact_string:
27
- text = keep_punctuation_exact_string.join(
28
- part.translate(str.maketrans('', '', string.punctuation))
29
- for part in text.split(keep_punctuation_exact_string))
30
- else:
31
- text = text.translate(str.maketrans('', '', string.punctuation))
32
- text = text.lower()
33
- text = re.sub(r'\s+', ' ', text)
34
- return text.strip()
35
-
36
-
37
- class HuggingfaceTokenizer:
38
-
39
- def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
- assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41
- self.name = name
42
- self.seq_len = seq_len
43
- self.clean = clean
44
-
45
- # init tokenizer
46
- self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
- self.vocab_size = self.tokenizer.vocab_size
48
-
49
- def __call__(self, sequence, **kwargs):
50
- return_mask = kwargs.pop('return_mask', False)
51
-
52
- # arguments
53
- _kwargs = {'return_tensors': 'pt'}
54
- if self.seq_len is not None:
55
- _kwargs.update({
56
- 'padding': 'max_length',
57
- 'truncation': True,
58
- 'max_length': self.seq_len
59
- })
60
- _kwargs.update(**kwargs)
61
-
62
- # tokenization
63
- if isinstance(sequence, str):
64
- sequence = [sequence]
65
- if self.clean:
66
- sequence = [self._clean(u) for u in sequence]
67
- ids = self.tokenizer(sequence, **_kwargs)
68
-
69
- # output
70
- if return_mask:
71
- return ids.input_ids, ids.attention_mask
72
- else:
73
- return ids.input_ids
74
-
75
- def _clean(self, text):
76
- if self.clean == 'whitespace':
77
- text = whitespace_clean(basic_clean(text))
78
- elif self.clean == 'lower':
79
- text = whitespace_clean(basic_clean(text)).lower()
80
- elif self.clean == 'canonicalize':
81
- text = canonicalize(basic_clean(text))
82
- return text
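A short illustration of the `HuggingfaceTokenizer` wrapper above. The `google/umt5-xxl` name mirrors the tokenizer the T5 encoder in this repo pairs with, but any Hugging Face tokenizer id would work the same way:

```python
# Illustrative use of HuggingfaceTokenizer: pad/truncate to a fixed length and
# return an attention mask alongside the token ids.
from wan.modules.tokenizers import HuggingfaceTokenizer

tok = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
ids, mask = tok(['A   prompt\twith  messy   whitespace'], return_mask=True)
print(ids.shape)        # torch.Size([1, 512]) -- padded/truncated to seq_len
print(mask.sum(dim=1))  # count of real (non-pad) tokens per prompt
```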
wan/modules/vae2_1.py DELETED
@@ -1,663 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import logging
3
-
4
- import torch
5
- import torch.cuda.amp as amp
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from einops import rearrange
9
-
10
- __all__ = [
11
- 'Wan2_1_VAE',
12
- ]
13
-
14
- CACHE_T = 2
15
-
16
-
17
- class CausalConv3d(nn.Conv3d):
18
- """
19
- Causal 3d convolution.
20
- """
21
-
22
- def __init__(self, *args, **kwargs):
23
- super().__init__(*args, **kwargs)
24
- self._padding = (self.padding[2], self.padding[2], self.padding[1],
25
- self.padding[1], 2 * self.padding[0], 0)
26
- self.padding = (0, 0, 0)
27
-
28
- def forward(self, x, cache_x=None):
29
- padding = list(self._padding)
30
- if cache_x is not None and self._padding[4] > 0:
31
- cache_x = cache_x.to(x.device)
32
- x = torch.cat([cache_x, x], dim=2)
33
- padding[4] -= cache_x.shape[2]
34
- x = F.pad(x, padding)
35
-
36
- return super().forward(x)
37
-
38
-
39
- class RMS_norm(nn.Module):
40
-
41
- def __init__(self, dim, channel_first=True, images=True, bias=False):
42
- super().__init__()
43
- broadcastable_dims = (1, 1, 1) if not images else (1, 1)
44
- shape = (dim, *broadcastable_dims) if channel_first else (dim,)
45
-
46
- self.channel_first = channel_first
47
- self.scale = dim**0.5
48
- self.gamma = nn.Parameter(torch.ones(shape))
49
- self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
50
-
51
- def forward(self, x):
52
- return F.normalize(
53
- x, dim=(1 if self.channel_first else
54
- -1)) * self.scale * self.gamma + self.bias
55
-
56
-
57
- class Upsample(nn.Upsample):
58
-
59
- def forward(self, x):
60
- """
61
- Fix bfloat16 support for nearest neighbor interpolation.
62
- """
63
- return super().forward(x.float()).type_as(x)
64
-
65
-
66
- class Resample(nn.Module):
67
-
68
- def __init__(self, dim, mode):
69
- assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
70
- 'downsample3d')
71
- super().__init__()
72
- self.dim = dim
73
- self.mode = mode
74
-
75
- # layers
76
- if mode == 'upsample2d':
77
- self.resample = nn.Sequential(
78
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
79
- nn.Conv2d(dim, dim // 2, 3, padding=1))
80
- elif mode == 'upsample3d':
81
- self.resample = nn.Sequential(
82
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
- nn.Conv2d(dim, dim // 2, 3, padding=1))
84
- self.time_conv = CausalConv3d(
85
- dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
-
87
- elif mode == 'downsample2d':
88
- self.resample = nn.Sequential(
89
- nn.ZeroPad2d((0, 1, 0, 1)),
90
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
91
- elif mode == 'downsample3d':
92
- self.resample = nn.Sequential(
93
- nn.ZeroPad2d((0, 1, 0, 1)),
94
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
- self.time_conv = CausalConv3d(
96
- dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
97
-
98
- else:
99
- self.resample = nn.Identity()
100
-
101
- def forward(self, x, feat_cache=None, feat_idx=[0]):
102
- b, c, t, h, w = x.size()
103
- if self.mode == 'upsample3d':
104
- if feat_cache is not None:
105
- idx = feat_idx[0]
106
- if feat_cache[idx] is None:
107
- feat_cache[idx] = 'Rep'
108
- feat_idx[0] += 1
109
- else:
110
-
111
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
112
- if cache_x.shape[2] < 2 and feat_cache[
113
- idx] is not None and feat_cache[idx] != 'Rep':
114
- # cache last frame of last two chunk
115
- cache_x = torch.cat([
116
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117
- cache_x.device), cache_x
118
- ],
119
- dim=2)
120
- if cache_x.shape[2] < 2 and feat_cache[
121
- idx] is not None and feat_cache[idx] == 'Rep':
122
- cache_x = torch.cat([
123
- torch.zeros_like(cache_x).to(cache_x.device),
124
- cache_x
125
- ],
126
- dim=2)
127
- if feat_cache[idx] == 'Rep':
128
- x = self.time_conv(x)
129
- else:
130
- x = self.time_conv(x, feat_cache[idx])
131
- feat_cache[idx] = cache_x
132
- feat_idx[0] += 1
133
-
134
- x = x.reshape(b, 2, c, t, h, w)
135
- x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
136
- 3)
137
- x = x.reshape(b, c, t * 2, h, w)
138
- t = x.shape[2]
139
- x = rearrange(x, 'b c t h w -> (b t) c h w')
140
- x = self.resample(x)
141
- x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
142
-
143
- if self.mode == 'downsample3d':
144
- if feat_cache is not None:
145
- idx = feat_idx[0]
146
- if feat_cache[idx] is None:
147
- feat_cache[idx] = x.clone()
148
- feat_idx[0] += 1
149
- else:
150
-
151
- cache_x = x[:, :, -1:, :, :].clone()
152
- # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
153
- # # cache last frame of last two chunk
154
- # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
155
-
156
- x = self.time_conv(
157
- torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
158
- feat_cache[idx] = cache_x
159
- feat_idx[0] += 1
160
- return x
161
-
162
- def init_weight(self, conv):
163
- conv_weight = conv.weight
164
- nn.init.zeros_(conv_weight)
165
- c1, c2, t, h, w = conv_weight.size()
166
- one_matrix = torch.eye(c1, c2)
167
- init_matrix = one_matrix
168
- nn.init.zeros_(conv_weight)
169
- #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
170
- conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
171
- conv.weight.data.copy_(conv_weight)
172
- nn.init.zeros_(conv.bias.data)
173
-
174
- def init_weight2(self, conv):
175
- conv_weight = conv.weight.data
176
- nn.init.zeros_(conv_weight)
177
- c1, c2, t, h, w = conv_weight.size()
178
- init_matrix = torch.eye(c1 // 2, c2)
179
- #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
180
- conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
181
- conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
182
- conv.weight.data.copy_(conv_weight)
183
- nn.init.zeros_(conv.bias.data)
184
-
185
-
186
- class ResidualBlock(nn.Module):
187
-
188
- def __init__(self, in_dim, out_dim, dropout=0.0):
189
- super().__init__()
190
- self.in_dim = in_dim
191
- self.out_dim = out_dim
192
-
193
- # layers
194
- self.residual = nn.Sequential(
195
- RMS_norm(in_dim, images=False), nn.SiLU(),
196
- CausalConv3d(in_dim, out_dim, 3, padding=1),
197
- RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
198
- CausalConv3d(out_dim, out_dim, 3, padding=1))
199
- self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
200
- if in_dim != out_dim else nn.Identity()
201
-
202
- def forward(self, x, feat_cache=None, feat_idx=[0]):
203
- h = self.shortcut(x)
204
- for layer in self.residual:
205
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
206
- idx = feat_idx[0]
207
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
208
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
209
- # cache last frame of last two chunk
210
- cache_x = torch.cat([
211
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
212
- cache_x.device), cache_x
213
- ],
214
- dim=2)
215
- x = layer(x, feat_cache[idx])
216
- feat_cache[idx] = cache_x
217
- feat_idx[0] += 1
218
- else:
219
- x = layer(x)
220
- return x + h
221
-
222
-
223
- class AttentionBlock(nn.Module):
224
- """
225
- Causal self-attention with a single head.
226
- """
227
-
228
- def __init__(self, dim):
229
- super().__init__()
230
- self.dim = dim
231
-
232
- # layers
233
- self.norm = RMS_norm(dim)
234
- self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
235
- self.proj = nn.Conv2d(dim, dim, 1)
236
-
237
- # zero out the last layer params
238
- nn.init.zeros_(self.proj.weight)
239
-
240
- def forward(self, x):
241
- identity = x
242
- b, c, t, h, w = x.size()
243
- x = rearrange(x, 'b c t h w -> (b t) c h w')
244
- x = self.norm(x)
245
- # compute query, key, value
246
- q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
247
- -1).permute(0, 1, 3,
248
- 2).contiguous().chunk(
249
- 3, dim=-1)
250
-
251
- # apply attention
252
- x = F.scaled_dot_product_attention(
253
- q,
254
- k,
255
- v,
256
- )
257
- x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
258
-
259
- # output
260
- x = self.proj(x)
261
- x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
262
- return x + identity
263
-
264
-
265
- class Encoder3d(nn.Module):
266
-
267
- def __init__(self,
268
- dim=128,
269
- z_dim=4,
270
- dim_mult=[1, 2, 4, 4],
271
- num_res_blocks=2,
272
- attn_scales=[],
273
- temperal_downsample=[True, True, False],
274
- dropout=0.0):
275
- super().__init__()
276
- self.dim = dim
277
- self.z_dim = z_dim
278
- self.dim_mult = dim_mult
279
- self.num_res_blocks = num_res_blocks
280
- self.attn_scales = attn_scales
281
- self.temperal_downsample = temperal_downsample
282
-
283
- # dimensions
284
- dims = [dim * u for u in [1] + dim_mult]
285
- scale = 1.0
286
-
287
- # init block
288
- self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
289
-
290
- # downsample blocks
291
- downsamples = []
292
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
293
- # residual (+attention) blocks
294
- for _ in range(num_res_blocks):
295
- downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
296
- if scale in attn_scales:
297
- downsamples.append(AttentionBlock(out_dim))
298
- in_dim = out_dim
299
-
300
- # downsample block
301
- if i != len(dim_mult) - 1:
302
- mode = 'downsample3d' if temperal_downsample[
303
- i] else 'downsample2d'
304
- downsamples.append(Resample(out_dim, mode=mode))
305
- scale /= 2.0
306
- self.downsamples = nn.Sequential(*downsamples)
307
-
308
- # middle blocks
309
- self.middle = nn.Sequential(
310
- ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
311
- ResidualBlock(out_dim, out_dim, dropout))
312
-
313
- # output blocks
314
- self.head = nn.Sequential(
315
- RMS_norm(out_dim, images=False), nn.SiLU(),
316
- CausalConv3d(out_dim, z_dim, 3, padding=1))
317
-
318
- def forward(self, x, feat_cache=None, feat_idx=[0]):
319
- if feat_cache is not None:
320
- idx = feat_idx[0]
321
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
322
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
323
- # cache last frame of last two chunk
324
- cache_x = torch.cat([
325
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
326
- cache_x.device), cache_x
327
- ],
328
- dim=2)
329
- x = self.conv1(x, feat_cache[idx])
330
- feat_cache[idx] = cache_x
331
- feat_idx[0] += 1
332
- else:
333
- x = self.conv1(x)
334
-
335
- ## downsamples
336
- for layer in self.downsamples:
337
- if feat_cache is not None:
338
- x = layer(x, feat_cache, feat_idx)
339
- else:
340
- x = layer(x)
341
-
342
- ## middle
343
- for layer in self.middle:
344
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
345
- x = layer(x, feat_cache, feat_idx)
346
- else:
347
- x = layer(x)
348
-
349
- ## head
350
- for layer in self.head:
351
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
352
- idx = feat_idx[0]
353
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
354
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
355
- # cache last frame of last two chunk
356
- cache_x = torch.cat([
357
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
358
- cache_x.device), cache_x
359
- ],
360
- dim=2)
361
- x = layer(x, feat_cache[idx])
362
- feat_cache[idx] = cache_x
363
- feat_idx[0] += 1
364
- else:
365
- x = layer(x)
366
- return x
367
-
368
-
369
- class Decoder3d(nn.Module):
370
-
371
- def __init__(self,
372
- dim=128,
373
- z_dim=4,
374
- dim_mult=[1, 2, 4, 4],
375
- num_res_blocks=2,
376
- attn_scales=[],
377
- temperal_upsample=[False, True, True],
378
- dropout=0.0):
379
- super().__init__()
380
- self.dim = dim
381
- self.z_dim = z_dim
382
- self.dim_mult = dim_mult
383
- self.num_res_blocks = num_res_blocks
384
- self.attn_scales = attn_scales
385
- self.temperal_upsample = temperal_upsample
386
-
387
- # dimensions
388
- dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
389
- scale = 1.0 / 2**(len(dim_mult) - 2)
390
-
391
- # init block
392
- self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
393
-
394
- # middle blocks
395
- self.middle = nn.Sequential(
396
- ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
397
- ResidualBlock(dims[0], dims[0], dropout))
398
-
399
- # upsample blocks
400
- upsamples = []
401
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
402
- # residual (+attention) blocks
403
- if i == 1 or i == 2 or i == 3:
404
- in_dim = in_dim // 2
405
- for _ in range(num_res_blocks + 1):
406
- upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
407
- if scale in attn_scales:
408
- upsamples.append(AttentionBlock(out_dim))
409
- in_dim = out_dim
410
-
411
- # upsample block
412
- if i != len(dim_mult) - 1:
413
- mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
414
- upsamples.append(Resample(out_dim, mode=mode))
415
- scale *= 2.0
416
- self.upsamples = nn.Sequential(*upsamples)
417
-
418
- # output blocks
419
- self.head = nn.Sequential(
420
- RMS_norm(out_dim, images=False), nn.SiLU(),
421
- CausalConv3d(out_dim, 3, 3, padding=1))
422
-
423
- def forward(self, x, feat_cache=None, feat_idx=[0]):
424
- ## conv1
425
- if feat_cache is not None:
426
- idx = feat_idx[0]
427
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
428
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
429
- # cache last frame of last two chunk
430
- cache_x = torch.cat([
431
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
432
- cache_x.device), cache_x
433
- ],
434
- dim=2)
435
- x = self.conv1(x, feat_cache[idx])
436
- feat_cache[idx] = cache_x
437
- feat_idx[0] += 1
438
- else:
439
- x = self.conv1(x)
440
-
441
- ## middle
442
- for layer in self.middle:
443
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
444
- x = layer(x, feat_cache, feat_idx)
445
- else:
446
- x = layer(x)
447
-
448
- ## upsamples
449
- for layer in self.upsamples:
450
- if feat_cache is not None:
451
- x = layer(x, feat_cache, feat_idx)
452
- else:
453
- x = layer(x)
454
-
455
- ## head
456
- for layer in self.head:
457
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
458
- idx = feat_idx[0]
459
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
460
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
461
- # cache last frame of last two chunk
462
- cache_x = torch.cat([
463
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
464
- cache_x.device), cache_x
465
- ],
466
- dim=2)
467
- x = layer(x, feat_cache[idx])
468
- feat_cache[idx] = cache_x
469
- feat_idx[0] += 1
470
- else:
471
- x = layer(x)
472
- return x
473
-
474
-
475
- def count_conv3d(model):
476
- count = 0
477
- for m in model.modules():
478
- if isinstance(m, CausalConv3d):
479
- count += 1
480
- return count
481
-
482
-
483
- class WanVAE_(nn.Module):
484
-
485
- def __init__(self,
486
- dim=128,
487
- z_dim=4,
488
- dim_mult=[1, 2, 4, 4],
489
- num_res_blocks=2,
490
- attn_scales=[],
491
- temperal_downsample=[True, True, False],
492
- dropout=0.0):
493
- super().__init__()
494
- self.dim = dim
495
- self.z_dim = z_dim
496
- self.dim_mult = dim_mult
497
- self.num_res_blocks = num_res_blocks
498
- self.attn_scales = attn_scales
499
- self.temperal_downsample = temperal_downsample
500
- self.temperal_upsample = temperal_downsample[::-1]
501
-
502
- # modules
503
- self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
504
- attn_scales, self.temperal_downsample, dropout)
505
- self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
506
- self.conv2 = CausalConv3d(z_dim, z_dim, 1)
507
- self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
508
- attn_scales, self.temperal_upsample, dropout)
509
-
510
- def forward(self, x):
511
- mu, log_var = self.encode(x)
512
- z = self.reparameterize(mu, log_var)
513
- x_recon = self.decode(z)
514
- return x_recon, mu, log_var
515
-
516
- def encode(self, x, scale):
517
- self.clear_cache()
518
- ## cache
519
- t = x.shape[2]
520
- iter_ = 1 + (t - 1) // 4
521
- ## Split the input x for encoding into temporal chunks of 1, 4, 4, 4, ...
522
- for i in range(iter_):
523
- self._enc_conv_idx = [0]
524
- if i == 0:
525
- out = self.encoder(
526
- x[:, :, :1, :, :],
527
- feat_cache=self._enc_feat_map,
528
- feat_idx=self._enc_conv_idx)
529
- else:
530
- out_ = self.encoder(
531
- x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
532
- feat_cache=self._enc_feat_map,
533
- feat_idx=self._enc_conv_idx)
534
- out = torch.cat([out, out_], 2)
535
- mu, log_var = self.conv1(out).chunk(2, dim=1)
536
- if isinstance(scale[0], torch.Tensor):
537
- mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
538
- 1, self.z_dim, 1, 1, 1)
539
- else:
540
- mu = (mu - scale[0]) * scale[1]
541
- self.clear_cache()
542
- return mu
543
-
544
- def decode(self, z, scale):
545
- self.clear_cache()
546
- # z: [b,c,t,h,w]
547
- if isinstance(scale[0], torch.Tensor):
548
- z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
549
- 1, self.z_dim, 1, 1, 1)
550
- else:
551
- z = z / scale[1] + scale[0]
552
- iter_ = z.shape[2]
553
- x = self.conv2(z)
554
- for i in range(iter_):
555
- self._conv_idx = [0]
556
- if i == 0:
557
- out = self.decoder(
558
- x[:, :, i:i + 1, :, :],
559
- feat_cache=self._feat_map,
560
- feat_idx=self._conv_idx)
561
- else:
562
- out_ = self.decoder(
563
- x[:, :, i:i + 1, :, :],
564
- feat_cache=self._feat_map,
565
- feat_idx=self._conv_idx)
566
- out = torch.cat([out, out_], 2)
567
- self.clear_cache()
568
- return out
569
-
570
- def reparameterize(self, mu, log_var):
571
- std = torch.exp(0.5 * log_var)
572
- eps = torch.randn_like(std)
573
- return eps * std + mu
574
-
575
- def sample(self, imgs, deterministic=False):
576
- mu, log_var = self.encode(imgs)
577
- if deterministic:
578
- return mu
579
- std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
580
- return mu + std * torch.randn_like(std)
581
-
582
- def clear_cache(self):
583
- self._conv_num = count_conv3d(self.decoder)
584
- self._conv_idx = [0]
585
- self._feat_map = [None] * self._conv_num
586
- #cache encode
587
- self._enc_conv_num = count_conv3d(self.encoder)
588
- self._enc_conv_idx = [0]
589
- self._enc_feat_map = [None] * self._enc_conv_num
590
-
591
-
592
- def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
593
- """
594
- Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
595
- """
596
- # params
597
- cfg = dict(
598
- dim=96,
599
- z_dim=z_dim,
600
- dim_mult=[1, 2, 4, 4],
601
- num_res_blocks=2,
602
- attn_scales=[],
603
- temperal_downsample=[False, True, True],
604
- dropout=0.0)
605
- cfg.update(**kwargs)
606
-
607
- # init model
608
- with torch.device('meta'):
609
- model = WanVAE_(**cfg)
610
-
611
- # load checkpoint
612
- logging.info(f'loading {pretrained_path}')
613
- model.load_state_dict(
614
- torch.load(pretrained_path, map_location=device), assign=True)
615
-
616
- return model
617
-
618
-
619
- class Wan2_1_VAE:
620
-
621
- def __init__(self,
622
- z_dim=16,
623
- vae_pth='cache/vae_step_411000.pth',
624
- dtype=torch.float,
625
- device="cuda"):
626
- self.dtype = dtype
627
- self.device = device
628
-
629
- mean = [
630
- -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
- 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
- ]
633
- std = [
634
- 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
- 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
- ]
637
- self.mean = torch.tensor(mean, dtype=dtype, device=device)
638
- self.std = torch.tensor(std, dtype=dtype, device=device)
639
- self.scale = [self.mean, 1.0 / self.std]
640
-
641
- # init model
642
- self.model = _video_vae(
643
- pretrained_path=vae_pth,
644
- z_dim=z_dim,
645
- ).eval().requires_grad_(False).to(device)
646
-
647
- def encode(self, videos):
648
- """
649
- videos: A list of videos each with shape [C, T, H, W].
650
- """
651
- with amp.autocast(dtype=self.dtype):
652
- return [
653
- self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
654
- for u in videos
655
- ]
656
-
657
- def decode(self, zs):
658
- with amp.autocast(dtype=self.dtype):
659
- return [
660
- self.model.decode(u.unsqueeze(0),
661
- self.scale).float().clamp_(-1, 1).squeeze(0)
662
- for u in zs
663
- ]
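A round-trip sketch for the `Wan2_1_VAE` wrapper above, again assuming a pre-commit checkout; the weight path is a placeholder. Frame counts follow the `1 + 4k` temporal chunking used by `encode()`, and spatial sizes shrink by 8x in the latent:

```python
# Encode/decode round trip with the Wan2_1_VAE wrapper above (sketch only).
import torch
from wan.modules.vae2_1 import Wan2_1_VAE

vae = Wan2_1_VAE(z_dim=16, vae_pth='Wan2.1_VAE.pth', device='cuda')  # placeholder path
video = torch.rand(3, 17, 256, 256, device='cuda') * 2 - 1  # [C, T, H, W] in [-1, 1], T = 1 + 4k
latent = vae.encode([video])[0]   # [16, 5, 32, 32]: T -> 1 + (T - 1) / 4, H and W -> / 8
recon = vae.decode([latent])[0]   # [3, 17, 256, 256], clamped to [-1, 1]
```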
wan/modules/vae2_2.py DELETED
@@ -1,1051 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import logging
3
-
4
- import torch
5
- import torch.cuda.amp as amp
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from einops import rearrange
9
-
10
- __all__ = [
11
- "Wan2_2_VAE",
12
- ]
13
-
14
- CACHE_T = 2
15
-
16
-
17
- class CausalConv3d(nn.Conv3d):
18
- """
19
- Causal 3d convolution.
20
- """
21
-
22
- def __init__(self, *args, **kwargs):
23
- super().__init__(*args, **kwargs)
24
- self._padding = (
25
- self.padding[2],
26
- self.padding[2],
27
- self.padding[1],
28
- self.padding[1],
29
- 2 * self.padding[0],
30
- 0,
31
- )
32
- self.padding = (0, 0, 0)
33
-
34
- def forward(self, x, cache_x=None):
35
- padding = list(self._padding)
36
- if cache_x is not None and self._padding[4] > 0:
37
- cache_x = cache_x.to(x.device)
38
- x = torch.cat([cache_x, x], dim=2)
39
- padding[4] -= cache_x.shape[2]
40
- x = F.pad(x, padding)
41
-
42
- return super().forward(x)
43
-
44
-
45
- class RMS_norm(nn.Module):
46
-
47
- def __init__(self, dim, channel_first=True, images=True, bias=False):
48
- super().__init__()
49
- broadcastable_dims = (1, 1, 1) if not images else (1, 1)
50
- shape = (dim, *broadcastable_dims) if channel_first else (dim,)
51
-
52
- self.channel_first = channel_first
53
- self.scale = dim**0.5
54
- self.gamma = nn.Parameter(torch.ones(shape))
55
- self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
56
-
57
- def forward(self, x):
58
- return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
59
- self.scale * self.gamma + self.bias)
60
-
61
-
62
- class Upsample(nn.Upsample):
63
-
64
- def forward(self, x):
65
- """
66
- Fix bfloat16 support for nearest neighbor interpolation.
67
- """
68
- return super().forward(x.float()).type_as(x)
69
-
70
-
71
- class Resample(nn.Module):
72
-
73
- def __init__(self, dim, mode):
74
- assert mode in (
75
- "none",
76
- "upsample2d",
77
- "upsample3d",
78
- "downsample2d",
79
- "downsample3d",
80
- )
81
- super().__init__()
82
- self.dim = dim
83
- self.mode = mode
84
-
85
- # layers
86
- if mode == "upsample2d":
87
- self.resample = nn.Sequential(
88
- Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
89
- nn.Conv2d(dim, dim, 3, padding=1),
90
- )
91
- elif mode == "upsample3d":
92
- self.resample = nn.Sequential(
93
- Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
94
- nn.Conv2d(dim, dim, 3, padding=1),
95
- # nn.Conv2d(dim, dim//2, 3, padding=1)
96
- )
97
- self.time_conv = CausalConv3d(
98
- dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
99
- elif mode == "downsample2d":
100
- self.resample = nn.Sequential(
101
- nn.ZeroPad2d((0, 1, 0, 1)),
102
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
103
- elif mode == "downsample3d":
104
- self.resample = nn.Sequential(
105
- nn.ZeroPad2d((0, 1, 0, 1)),
106
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
107
- self.time_conv = CausalConv3d(
108
- dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
109
- else:
110
- self.resample = nn.Identity()
111
-
112
- def forward(self, x, feat_cache=None, feat_idx=[0]):
113
- b, c, t, h, w = x.size()
114
- if self.mode == "upsample3d":
115
- if feat_cache is not None:
116
- idx = feat_idx[0]
117
- if feat_cache[idx] is None:
118
- feat_cache[idx] = "Rep"
119
- feat_idx[0] += 1
120
- else:
121
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
122
- if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
123
- feat_cache[idx] != "Rep"):
124
- # cache last frame of last two chunk
125
- cache_x = torch.cat(
126
- [
127
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
128
- cache_x.device),
129
- cache_x,
130
- ],
131
- dim=2,
132
- )
133
- if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
134
- feat_cache[idx] == "Rep"):
135
- cache_x = torch.cat(
136
- [
137
- torch.zeros_like(cache_x).to(cache_x.device),
138
- cache_x
139
- ],
140
- dim=2,
141
- )
142
- if feat_cache[idx] == "Rep":
143
- x = self.time_conv(x)
144
- else:
145
- x = self.time_conv(x, feat_cache[idx])
146
- feat_cache[idx] = cache_x
147
- feat_idx[0] += 1
148
- x = x.reshape(b, 2, c, t, h, w)
149
- x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
150
- 3)
151
- x = x.reshape(b, c, t * 2, h, w)
152
- t = x.shape[2]
153
- x = rearrange(x, "b c t h w -> (b t) c h w")
154
- x = self.resample(x)
155
- x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
156
-
157
- if self.mode == "downsample3d":
158
- if feat_cache is not None:
159
- idx = feat_idx[0]
160
- if feat_cache[idx] is None:
161
- feat_cache[idx] = x.clone()
162
- feat_idx[0] += 1
163
- else:
164
- cache_x = x[:, :, -1:, :, :].clone()
165
- x = self.time_conv(
166
- torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
167
- feat_cache[idx] = cache_x
168
- feat_idx[0] += 1
169
- return x
170
-
171
- def init_weight(self, conv):
172
- conv_weight = conv.weight.detach().clone()
173
- nn.init.zeros_(conv_weight)
174
- c1, c2, t, h, w = conv_weight.size()
175
- one_matrix = torch.eye(c1, c2)
176
- init_matrix = one_matrix
177
- nn.init.zeros_(conv_weight)
178
- conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
179
- conv.weight = nn.Parameter(conv_weight)
180
- nn.init.zeros_(conv.bias.data)
181
-
182
- def init_weight2(self, conv):
183
- conv_weight = conv.weight.data.detach().clone()
184
- nn.init.zeros_(conv_weight)
185
- c1, c2, t, h, w = conv_weight.size()
186
- init_matrix = torch.eye(c1 // 2, c2)
187
- conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
188
- conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
189
- conv.weight = nn.Parameter(conv_weight)
190
- nn.init.zeros_(conv.bias.data)
191
-
192
-
193
- class ResidualBlock(nn.Module):
194
-
195
- def __init__(self, in_dim, out_dim, dropout=0.0):
196
- super().__init__()
197
- self.in_dim = in_dim
198
- self.out_dim = out_dim
199
-
200
- # layers
201
- self.residual = nn.Sequential(
202
- RMS_norm(in_dim, images=False),
203
- nn.SiLU(),
204
- CausalConv3d(in_dim, out_dim, 3, padding=1),
205
- RMS_norm(out_dim, images=False),
206
- nn.SiLU(),
207
- nn.Dropout(dropout),
208
- CausalConv3d(out_dim, out_dim, 3, padding=1),
209
- )
210
- self.shortcut = (
211
- CausalConv3d(in_dim, out_dim, 1)
212
- if in_dim != out_dim else nn.Identity())
213
-
214
- def forward(self, x, feat_cache=None, feat_idx=[0]):
215
- h = self.shortcut(x)
216
- for layer in self.residual:
217
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
218
- idx = feat_idx[0]
219
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
220
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
221
- # cache last frame of last two chunk
222
- cache_x = torch.cat(
223
- [
224
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
225
- cache_x.device),
226
- cache_x,
227
- ],
228
- dim=2,
229
- )
230
- x = layer(x, feat_cache[idx])
231
- feat_cache[idx] = cache_x
232
- feat_idx[0] += 1
233
- else:
234
- x = layer(x)
235
- return x + h
236
-
237
-
238
- class AttentionBlock(nn.Module):
239
- """
240
- Causal self-attention with a single head.
241
- """
242
-
243
- def __init__(self, dim):
244
- super().__init__()
245
- self.dim = dim
246
-
247
- # layers
248
- self.norm = RMS_norm(dim)
249
- self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
250
- self.proj = nn.Conv2d(dim, dim, 1)
251
-
252
- # zero out the last layer params
253
- nn.init.zeros_(self.proj.weight)
254
-
255
- def forward(self, x):
256
- identity = x
257
- b, c, t, h, w = x.size()
258
- x = rearrange(x, "b c t h w -> (b t) c h w")
259
- x = self.norm(x)
260
- # compute query, key, value
261
- q, k, v = (
262
- self.to_qkv(x).reshape(b * t, 1, c * 3,
263
- -1).permute(0, 1, 3,
264
- 2).contiguous().chunk(3, dim=-1))
265
-
266
- # apply attention
267
- x = F.scaled_dot_product_attention(
268
- q,
269
- k,
270
- v,
271
- )
272
- x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
273
-
274
- # output
275
- x = self.proj(x)
276
- x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
277
- return x + identity
278
-
279
-
280
- def patchify(x, patch_size):
281
- if patch_size == 1:
282
- return x
283
- if x.dim() == 4:
284
- x = rearrange(
285
- x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
286
- elif x.dim() == 5:
287
- x = rearrange(
288
- x,
289
- "b c f (h q) (w r) -> b (c r q) f h w",
290
- q=patch_size,
291
- r=patch_size,
292
- )
293
- else:
294
- raise ValueError(f"Invalid input shape: {x.shape}")
295
-
296
- return x
297
-
298
-
299
- def unpatchify(x, patch_size):
300
- if patch_size == 1:
301
- return x
302
-
303
- if x.dim() == 4:
304
- x = rearrange(
305
- x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
306
- elif x.dim() == 5:
307
- x = rearrange(
308
- x,
309
- "b (c r q) f h w -> b c f (h q) (w r)",
310
- q=patch_size,
311
- r=patch_size,
312
- )
313
- return x
314
-
315
-
316
- class AvgDown3D(nn.Module):
317
-
318
- def __init__(
319
- self,
320
- in_channels,
321
- out_channels,
322
- factor_t,
323
- factor_s=1,
324
- ):
325
- super().__init__()
326
- self.in_channels = in_channels
327
- self.out_channels = out_channels
328
- self.factor_t = factor_t
329
- self.factor_s = factor_s
330
- self.factor = self.factor_t * self.factor_s * self.factor_s
331
-
332
- assert in_channels * self.factor % out_channels == 0
333
- self.group_size = in_channels * self.factor // out_channels
334
-
335
- def forward(self, x: torch.Tensor) -> torch.Tensor:
336
- pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
337
- pad = (0, 0, 0, 0, pad_t, 0)
338
- x = F.pad(x, pad)
339
- B, C, T, H, W = x.shape
340
- x = x.view(
341
- B,
342
- C,
343
- T // self.factor_t,
344
- self.factor_t,
345
- H // self.factor_s,
346
- self.factor_s,
347
- W // self.factor_s,
348
- self.factor_s,
349
- )
350
- x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
351
- x = x.view(
352
- B,
353
- C * self.factor,
354
- T // self.factor_t,
355
- H // self.factor_s,
356
- W // self.factor_s,
357
- )
358
- x = x.view(
359
- B,
360
- self.out_channels,
361
- self.group_size,
362
- T // self.factor_t,
363
- H // self.factor_s,
364
- W // self.factor_s,
365
- )
366
- x = x.mean(dim=2)
367
- return x
368
-
369
-
370
- class DupUp3D(nn.Module):
371
-
372
- def __init__(
373
- self,
374
- in_channels: int,
375
- out_channels: int,
376
- factor_t,
377
- factor_s=1,
378
- ):
379
- super().__init__()
380
- self.in_channels = in_channels
381
- self.out_channels = out_channels
382
-
383
- self.factor_t = factor_t
384
- self.factor_s = factor_s
385
- self.factor = self.factor_t * self.factor_s * self.factor_s
386
-
387
- assert out_channels * self.factor % in_channels == 0
388
- self.repeats = out_channels * self.factor // in_channels
389
-
390
- def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
391
- x = x.repeat_interleave(self.repeats, dim=1)
392
- x = x.view(
393
- x.size(0),
394
- self.out_channels,
395
- self.factor_t,
396
- self.factor_s,
397
- self.factor_s,
398
- x.size(2),
399
- x.size(3),
400
- x.size(4),
401
- )
402
- x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
403
- x = x.view(
404
- x.size(0),
405
- self.out_channels,
406
- x.size(2) * self.factor_t,
407
- x.size(4) * self.factor_s,
408
- x.size(6) * self.factor_s,
409
- )
410
- if first_chunk:
411
- x = x[:, :, self.factor_t - 1:, :, :]
412
- return x
413
-
414
-
415
- class Down_ResidualBlock(nn.Module):
416
-
417
- def __init__(self,
418
- in_dim,
419
- out_dim,
420
- dropout,
421
- mult,
422
- temperal_downsample=False,
423
- down_flag=False):
424
- super().__init__()
425
-
426
- # Shortcut path with downsample
427
- self.avg_shortcut = AvgDown3D(
428
- in_dim,
429
- out_dim,
430
- factor_t=2 if temperal_downsample else 1,
431
- factor_s=2 if down_flag else 1,
432
- )
433
-
434
- # Main path with residual blocks and downsample
435
- downsamples = []
436
- for _ in range(mult):
437
- downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
438
- in_dim = out_dim
439
-
440
- # Add the final downsample block
441
- if down_flag:
442
- mode = "downsample3d" if temperal_downsample else "downsample2d"
443
- downsamples.append(Resample(out_dim, mode=mode))
444
-
445
- self.downsamples = nn.Sequential(*downsamples)
446
-
447
- def forward(self, x, feat_cache=None, feat_idx=[0]):
448
- x_copy = x.clone()
449
- for module in self.downsamples:
450
- x = module(x, feat_cache, feat_idx)
451
-
452
- return x + self.avg_shortcut(x_copy)
453
-
454
-
455
- class Up_ResidualBlock(nn.Module):
456
-
457
- def __init__(self,
458
- in_dim,
459
- out_dim,
460
- dropout,
461
- mult,
462
- temperal_upsample=False,
463
- up_flag=False):
464
- super().__init__()
465
- # Shortcut path with upsample
466
- if up_flag:
467
- self.avg_shortcut = DupUp3D(
468
- in_dim,
469
- out_dim,
470
- factor_t=2 if temperal_upsample else 1,
471
- factor_s=2 if up_flag else 1,
472
- )
473
- else:
474
- self.avg_shortcut = None
475
-
476
- # Main path with residual blocks and upsample
477
- upsamples = []
478
- for _ in range(mult):
479
- upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
480
- in_dim = out_dim
481
-
482
- # Add the final upsample block
483
- if up_flag:
484
- mode = "upsample3d" if temperal_upsample else "upsample2d"
485
- upsamples.append(Resample(out_dim, mode=mode))
486
-
487
- self.upsamples = nn.Sequential(*upsamples)
488
-
489
- def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
490
- x_main = x.clone()
491
- for module in self.upsamples:
492
- x_main = module(x_main, feat_cache, feat_idx)
493
- if self.avg_shortcut is not None:
494
- x_shortcut = self.avg_shortcut(x, first_chunk)
495
- return x_main + x_shortcut
496
- else:
497
- return x_main
498
-
499
-
500
- class Encoder3d(nn.Module):
501
-
502
- def __init__(
503
- self,
504
- dim=128,
505
- z_dim=4,
506
- dim_mult=[1, 2, 4, 4],
507
- num_res_blocks=2,
508
- attn_scales=[],
509
- temperal_downsample=[True, True, False],
510
- dropout=0.0,
511
- ):
512
- super().__init__()
513
- self.dim = dim
514
- self.z_dim = z_dim
515
- self.dim_mult = dim_mult
516
- self.num_res_blocks = num_res_blocks
517
- self.attn_scales = attn_scales
518
- self.temperal_downsample = temperal_downsample
519
-
520
- # dimensions
521
- dims = [dim * u for u in [1] + dim_mult]
522
- scale = 1.0
523
-
524
- # init block
525
- self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
526
-
527
- # downsample blocks
528
- downsamples = []
529
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
530
- t_down_flag = (
531
- temperal_downsample[i]
532
- if i < len(temperal_downsample) else False)
533
- downsamples.append(
534
- Down_ResidualBlock(
535
- in_dim=in_dim,
536
- out_dim=out_dim,
537
- dropout=dropout,
538
- mult=num_res_blocks,
539
- temperal_downsample=t_down_flag,
540
- down_flag=i != len(dim_mult) - 1,
541
- ))
542
- scale /= 2.0
543
- self.downsamples = nn.Sequential(*downsamples)
544
-
545
- # middle blocks
546
- self.middle = nn.Sequential(
547
- ResidualBlock(out_dim, out_dim, dropout),
548
- AttentionBlock(out_dim),
549
- ResidualBlock(out_dim, out_dim, dropout),
550
- )
551
-
552
- # # output blocks
553
- self.head = nn.Sequential(
554
- RMS_norm(out_dim, images=False),
555
- nn.SiLU(),
556
- CausalConv3d(out_dim, z_dim, 3, padding=1),
557
- )
558
-
559
- def forward(self, x, feat_cache=None, feat_idx=[0]):
560
-
561
- if feat_cache is not None:
562
- idx = feat_idx[0]
563
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
564
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
565
- cache_x = torch.cat(
566
- [
567
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
568
- cache_x.device),
569
- cache_x,
570
- ],
571
- dim=2,
572
- )
573
- x = self.conv1(x, feat_cache[idx])
574
- feat_cache[idx] = cache_x
575
- feat_idx[0] += 1
576
- else:
577
- x = self.conv1(x)
578
-
579
- ## downsamples
580
- for layer in self.downsamples:
581
- if feat_cache is not None:
582
- x = layer(x, feat_cache, feat_idx)
583
- else:
584
- x = layer(x)
585
-
586
- ## middle
587
- for layer in self.middle:
588
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
589
- x = layer(x, feat_cache, feat_idx)
590
- else:
591
- x = layer(x)
592
-
593
- ## head
594
- for layer in self.head:
595
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
596
- idx = feat_idx[0]
597
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
598
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
599
- cache_x = torch.cat(
600
- [
601
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
602
- cache_x.device),
603
- cache_x,
604
- ],
605
- dim=2,
606
- )
607
- x = layer(x, feat_cache[idx])
608
- feat_cache[idx] = cache_x
609
- feat_idx[0] += 1
610
- else:
611
- x = layer(x)
612
-
613
- return x
614
-
615
-
616
- class Decoder3d(nn.Module):
617
-
618
- def __init__(
619
- self,
620
- dim=128,
621
- z_dim=4,
622
- dim_mult=[1, 2, 4, 4],
623
- num_res_blocks=2,
624
- attn_scales=[],
625
- temperal_upsample=[False, True, True],
626
- dropout=0.0,
627
- ):
628
- super().__init__()
629
- self.dim = dim
630
- self.z_dim = z_dim
631
- self.dim_mult = dim_mult
632
- self.num_res_blocks = num_res_blocks
633
- self.attn_scales = attn_scales
634
- self.temperal_upsample = temperal_upsample
635
-
636
- # dimensions
637
- dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
638
- scale = 1.0 / 2**(len(dim_mult) - 2)
639
- # init block
640
- self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
641
-
642
- # middle blocks
643
- self.middle = nn.Sequential(
644
- ResidualBlock(dims[0], dims[0], dropout),
645
- AttentionBlock(dims[0]),
646
- ResidualBlock(dims[0], dims[0], dropout),
647
- )
648
-
649
- # upsample blocks
650
- upsamples = []
651
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
652
- t_up_flag = temperal_upsample[i] if i < len(
653
- temperal_upsample) else False
654
- upsamples.append(
655
- Up_ResidualBlock(
656
- in_dim=in_dim,
657
- out_dim=out_dim,
658
- dropout=dropout,
659
- mult=num_res_blocks + 1,
660
- temperal_upsample=t_up_flag,
661
- up_flag=i != len(dim_mult) - 1,
662
- ))
663
- self.upsamples = nn.Sequential(*upsamples)
664
-
665
- # output blocks
666
- self.head = nn.Sequential(
667
- RMS_norm(out_dim, images=False),
668
- nn.SiLU(),
669
- CausalConv3d(out_dim, 12, 3, padding=1),
670
- )
671
-
672
- def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
673
- if feat_cache is not None:
674
- idx = feat_idx[0]
675
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
676
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
677
- cache_x = torch.cat(
678
- [
679
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
680
- cache_x.device),
681
- cache_x,
682
- ],
683
- dim=2,
684
- )
685
- x = self.conv1(x, feat_cache[idx])
686
- feat_cache[idx] = cache_x
687
- feat_idx[0] += 1
688
- else:
689
- x = self.conv1(x)
690
-
691
- for layer in self.middle:
692
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
693
- x = layer(x, feat_cache, feat_idx)
694
- else:
695
- x = layer(x)
696
-
697
- ## upsamples
698
- for layer in self.upsamples:
699
- if feat_cache is not None:
700
- x = layer(x, feat_cache, feat_idx, first_chunk)
701
- else:
702
- x = layer(x)
703
-
704
- ## head
705
- for layer in self.head:
706
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
707
- idx = feat_idx[0]
708
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
709
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
710
- cache_x = torch.cat(
711
- [
712
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
713
- cache_x.device),
714
- cache_x,
715
- ],
716
- dim=2,
717
- )
718
- x = layer(x, feat_cache[idx])
719
- feat_cache[idx] = cache_x
720
- feat_idx[0] += 1
721
- else:
722
- x = layer(x)
723
- return x
724
-
725
-
726
- def count_conv3d(model):
727
- count = 0
728
- for m in model.modules():
729
- if isinstance(m, CausalConv3d):
730
- count += 1
731
- return count
732
-
733
-
734
- class WanVAE_(nn.Module):
735
-
736
- def __init__(
737
- self,
738
- dim=160,
739
- dec_dim=256,
740
- z_dim=16,
741
- dim_mult=[1, 2, 4, 4],
742
- num_res_blocks=2,
743
- attn_scales=[],
744
- temperal_downsample=[True, True, False],
745
- dropout=0.0,
746
- ):
747
- super().__init__()
748
- self.dim = dim
749
- self.z_dim = z_dim
750
- self.dim_mult = dim_mult
751
- self.num_res_blocks = num_res_blocks
752
- self.attn_scales = attn_scales
753
- self.temperal_downsample = temperal_downsample
754
- self.temperal_upsample = temperal_downsample[::-1]
755
-
756
- # modules
757
- self.encoder = Encoder3d(
758
- dim,
759
- z_dim * 2,
760
- dim_mult,
761
- num_res_blocks,
762
- attn_scales,
763
- self.temperal_downsample,
764
- dropout,
765
- )
766
- self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
767
- self.conv2 = CausalConv3d(z_dim, z_dim, 1)
768
- self.decoder = Decoder3d(
769
- dec_dim,
770
- z_dim,
771
- dim_mult,
772
- num_res_blocks,
773
- attn_scales,
774
- self.temperal_upsample,
775
- dropout,
776
- )
777
-
778
- def forward(self, x, scale=[0, 1]):
779
- mu = self.encode(x, scale)
780
- x_recon = self.decode(mu, scale)
781
- return x_recon, mu
782
-
783
- def encode(self, x, scale):
784
- self.clear_cache()
785
- x = patchify(x, patch_size=2)
786
- t = x.shape[2]
787
- iter_ = 1 + (t - 1) // 4
788
- for i in range(iter_):
789
- self._enc_conv_idx = [0]
790
- if i == 0:
791
- out = self.encoder(
792
- x[:, :, :1, :, :],
793
- feat_cache=self._enc_feat_map,
794
- feat_idx=self._enc_conv_idx,
795
- )
796
- else:
797
- out_ = self.encoder(
798
- x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
799
- feat_cache=self._enc_feat_map,
800
- feat_idx=self._enc_conv_idx,
801
- )
802
- out = torch.cat([out, out_], 2)
803
- mu, log_var = self.conv1(out).chunk(2, dim=1)
804
- if isinstance(scale[0], torch.Tensor):
805
- mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
806
- 1, self.z_dim, 1, 1, 1)
807
- else:
808
- mu = (mu - scale[0]) * scale[1]
809
- self.clear_cache()
810
- return mu
811
-
812
- def decode(self, z, scale):
813
- self.clear_cache()
814
- if isinstance(scale[0], torch.Tensor):
815
- z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
816
- 1, self.z_dim, 1, 1, 1)
817
- else:
818
- z = z / scale[1] + scale[0]
819
- iter_ = z.shape[2]
820
- x = self.conv2(z)
821
- for i in range(iter_):
822
- self._conv_idx = [0]
823
- if i == 0:
824
- out = self.decoder(
825
- x[:, :, i:i + 1, :, :],
826
- feat_cache=self._feat_map,
827
- feat_idx=self._conv_idx,
828
- first_chunk=True,
829
- )
830
- else:
831
- out_ = self.decoder(
832
- x[:, :, i:i + 1, :, :],
833
- feat_cache=self._feat_map,
834
- feat_idx=self._conv_idx,
835
- )
836
- out = torch.cat([out, out_], 2)
837
- out = unpatchify(out, patch_size=2)
838
- self.clear_cache()
839
- return out
840
-
841
- def reparameterize(self, mu, log_var):
842
- std = torch.exp(0.5 * log_var)
843
- eps = torch.randn_like(std)
844
- return eps * std + mu
845
-
846
- def sample(self, imgs, deterministic=False):
847
- mu, log_var = self.encode(imgs)
848
- if deterministic:
849
- return mu
850
- std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
851
- return mu + std * torch.randn_like(std)
852
-
853
- def clear_cache(self):
854
- self._conv_num = count_conv3d(self.decoder)
855
- self._conv_idx = [0]
856
- self._feat_map = [None] * self._conv_num
857
- # cache encode
858
- self._enc_conv_num = count_conv3d(self.encoder)
859
- self._enc_conv_idx = [0]
860
- self._enc_feat_map = [None] * self._enc_conv_num
861
-
862
-
863
- def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
864
- # params
865
- cfg = dict(
866
- dim=dim,
867
- z_dim=z_dim,
868
- dim_mult=[1, 2, 4, 4],
869
- num_res_blocks=2,
870
- attn_scales=[],
871
- temperal_downsample=[True, True, True],
872
- dropout=0.0,
873
- )
874
- cfg.update(**kwargs)
875
-
876
- # init model
877
- with torch.device("meta"):
878
- model = WanVAE_(**cfg)
879
-
880
- # load checkpoint
881
- logging.info(f"loading {pretrained_path}")
882
- model.load_state_dict(
883
- torch.load(pretrained_path, map_location=device), assign=True)
884
-
885
- return model
886
-
887
-
888
- class Wan2_2_VAE:
889
-
890
- def __init__(
891
- self,
892
- z_dim=48,
893
- c_dim=160,
894
- vae_pth=None,
895
- dim_mult=[1, 2, 4, 4],
896
- temperal_downsample=[False, True, True],
897
- dtype=torch.float,
898
- device="cuda",
899
- ):
900
-
901
- self.dtype = dtype
902
- self.device = device
903
-
904
- mean = torch.tensor(
905
- [
906
- -0.2289,
907
- -0.0052,
908
- -0.1323,
909
- -0.2339,
910
- -0.2799,
911
- 0.0174,
912
- 0.1838,
913
- 0.1557,
914
- -0.1382,
915
- 0.0542,
916
- 0.2813,
917
- 0.0891,
918
- 0.1570,
919
- -0.0098,
920
- 0.0375,
921
- -0.1825,
922
- -0.2246,
923
- -0.1207,
924
- -0.0698,
925
- 0.5109,
926
- 0.2665,
927
- -0.2108,
928
- -0.2158,
929
- 0.2502,
930
- -0.2055,
931
- -0.0322,
932
- 0.1109,
933
- 0.1567,
934
- -0.0729,
935
- 0.0899,
936
- -0.2799,
937
- -0.1230,
938
- -0.0313,
939
- -0.1649,
940
- 0.0117,
941
- 0.0723,
942
- -0.2839,
943
- -0.2083,
944
- -0.0520,
945
- 0.3748,
946
- 0.0152,
947
- 0.1957,
948
- 0.1433,
949
- -0.2944,
950
- 0.3573,
951
- -0.0548,
952
- -0.1681,
953
- -0.0667,
954
- ],
955
- dtype=dtype,
956
- device=device,
957
- )
958
- std = torch.tensor(
959
- [
960
- 0.4765,
961
- 1.0364,
962
- 0.4514,
963
- 1.1677,
964
- 0.5313,
965
- 0.4990,
966
- 0.4818,
967
- 0.5013,
968
- 0.8158,
969
- 1.0344,
970
- 0.5894,
971
- 1.0901,
972
- 0.6885,
973
- 0.6165,
974
- 0.8454,
975
- 0.4978,
976
- 0.5759,
977
- 0.3523,
978
- 0.7135,
979
- 0.6804,
980
- 0.5833,
981
- 1.4146,
982
- 0.8986,
983
- 0.5659,
984
- 0.7069,
985
- 0.5338,
986
- 0.4889,
987
- 0.4917,
988
- 0.4069,
989
- 0.4999,
990
- 0.6866,
991
- 0.4093,
992
- 0.5709,
993
- 0.6065,
994
- 0.6415,
995
- 0.4944,
996
- 0.5726,
997
- 1.2042,
998
- 0.5458,
999
- 1.6887,
1000
- 0.3971,
1001
- 1.0600,
1002
- 0.3943,
1003
- 0.5537,
1004
- 0.5444,
1005
- 0.4089,
1006
- 0.7468,
1007
- 0.7744,
1008
- ],
1009
- dtype=dtype,
1010
- device=device,
1011
- )
1012
- self.scale = [mean, 1.0 / std]
1013
-
1014
- # init model
1015
- self.model = (
1016
- _video_vae(
1017
- pretrained_path=vae_pth,
1018
- z_dim=z_dim,
1019
- dim=c_dim,
1020
- dim_mult=dim_mult,
1021
- temperal_downsample=temperal_downsample,
1022
- ).eval().requires_grad_(False).to(device))
1023
-
1024
- def encode(self, videos):
1025
- try:
1026
- if not isinstance(videos, list):
1027
- raise TypeError("videos should be a list")
1028
- with amp.autocast(dtype=self.dtype):
1029
- return [
1030
- self.model.encode(u.unsqueeze(0),
1031
- self.scale).float().squeeze(0)
1032
- for u in videos
1033
- ]
1034
- except TypeError as e:
1035
- logging.info(e)
1036
- return None
1037
-
1038
- def decode(self, zs):
1039
- try:
1040
- if not isinstance(zs, list):
1041
- raise TypeError("zs should be a list")
1042
- with amp.autocast(dtype=self.dtype):
1043
- return [
1044
- self.model.decode(u.unsqueeze(0),
1045
- self.scale).float().clamp_(-1,
1046
- 1).squeeze(0)
1047
- for u in zs
1048
- ]
1049
- except TypeError as e:
1050
- logging.info(e)
1051
- return None
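For reference, the Wan2_2_VAE wrapper removed above takes and returns lists of (C, T, H, W) tensors in [-1, 1]. A minimal usage sketch, assuming a locally available checkpoint (the filename below is a placeholder, not taken from this commit):

import torch
from wan.modules.vae2_2 import Wan2_2_VAE

# "Wan2.2_VAE.pth" is a placeholder path, not taken from this commit.
vae = Wan2_2_VAE(vae_pth="Wan2.2_VAE.pth", dtype=torch.float, device="cuda")

clip = torch.rand(3, 17, 256, 256, device="cuda") * 2 - 1  # (C, T, H, W), T = 4n + 1
latents = vae.encode([clip])   # list with one normalized latent tensor
recon = vae.decode(latents)    # list with one reconstruction clamped to [-1, 1]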
wan/text2video.py DELETED
@@ -1,378 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import gc
3
- import logging
4
- import math
5
- import os
6
- import random
7
- import sys
8
- import types
9
- from contextlib import contextmanager
10
- from functools import partial
11
-
12
- import torch
13
- import torch.cuda.amp as amp
14
- import torch.distributed as dist
15
- from tqdm import tqdm
16
-
17
- from .distributed.fsdp import shard_model
18
- from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
19
- from .distributed.util import get_world_size
20
- from .modules.model import WanModel
21
- from .modules.t5 import T5EncoderModel
22
- from .modules.vae2_1 import Wan2_1_VAE
23
- from .utils.fm_solvers import (
24
- FlowDPMSolverMultistepScheduler,
25
- get_sampling_sigmas,
26
- retrieve_timesteps,
27
- )
28
- from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
-
30
-
31
- class WanT2V:
32
-
33
- def __init__(
34
- self,
35
- config,
36
- checkpoint_dir,
37
- device_id=0,
38
- rank=0,
39
- t5_fsdp=False,
40
- dit_fsdp=False,
41
- use_sp=False,
42
- t5_cpu=False,
43
- init_on_cpu=True,
44
- convert_model_dtype=False,
45
- ):
46
- r"""
47
- Initializes the Wan text-to-video generation model components.
48
-
49
- Args:
50
- config (EasyDict):
51
- Object containing model parameters initialized from config.py
52
- checkpoint_dir (`str`):
53
- Path to directory containing model checkpoints
54
- device_id (`int`, *optional*, defaults to 0):
55
- Id of target GPU device
56
- rank (`int`, *optional*, defaults to 0):
57
- Process rank for distributed training
58
- t5_fsdp (`bool`, *optional*, defaults to False):
59
- Enable FSDP sharding for T5 model
60
- dit_fsdp (`bool`, *optional*, defaults to False):
61
- Enable FSDP sharding for DiT model
62
- use_sp (`bool`, *optional*, defaults to False):
63
- Enable distribution strategy of sequence parallel.
64
- t5_cpu (`bool`, *optional*, defaults to False):
65
- Whether to place T5 model on CPU. Only works without t5_fsdp.
66
- init_on_cpu (`bool`, *optional*, defaults to True):
67
- Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
68
- convert_model_dtype (`bool`, *optional*, defaults to False):
69
- Convert DiT model parameters dtype to 'config.param_dtype'.
70
- Only works without FSDP.
71
- """
72
- self.device = torch.device(f"cuda:{device_id}")
73
- self.config = config
74
- self.rank = rank
75
- self.t5_cpu = t5_cpu
76
- self.init_on_cpu = init_on_cpu
77
-
78
- self.num_train_timesteps = config.num_train_timesteps
79
- self.boundary = config.boundary
80
- self.param_dtype = config.param_dtype
81
-
82
- if t5_fsdp or dit_fsdp or use_sp:
83
- self.init_on_cpu = False
84
-
85
- shard_fn = partial(shard_model, device_id=device_id)
86
- self.text_encoder = T5EncoderModel(
87
- text_len=config.text_len,
88
- dtype=config.t5_dtype,
89
- device=torch.device('cpu'),
90
- checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
91
- tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
92
- shard_fn=shard_fn if t5_fsdp else None)
93
-
94
- self.vae_stride = config.vae_stride
95
- self.patch_size = config.patch_size
96
- self.vae = Wan2_1_VAE(
97
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
98
- device=self.device)
99
-
100
- logging.info(f"Creating WanModel from {checkpoint_dir}")
101
- self.low_noise_model = WanModel.from_pretrained(
102
- checkpoint_dir, subfolder=config.low_noise_checkpoint)
103
- self.low_noise_model = self._configure_model(
104
- model=self.low_noise_model,
105
- use_sp=use_sp,
106
- dit_fsdp=dit_fsdp,
107
- shard_fn=shard_fn,
108
- convert_model_dtype=convert_model_dtype)
109
-
110
- self.high_noise_model = WanModel.from_pretrained(
111
- checkpoint_dir, subfolder=config.high_noise_checkpoint)
112
- self.high_noise_model = self._configure_model(
113
- model=self.high_noise_model,
114
- use_sp=use_sp,
115
- dit_fsdp=dit_fsdp,
116
- shard_fn=shard_fn,
117
- convert_model_dtype=convert_model_dtype)
118
- if use_sp:
119
- self.sp_size = get_world_size()
120
- else:
121
- self.sp_size = 1
122
-
123
- self.sample_neg_prompt = config.sample_neg_prompt
124
-
125
- def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
126
- convert_model_dtype):
127
- """
128
- Configures a model object. This includes setting evaluation modes,
129
- applying distributed parallel strategy, and handling device placement.
130
-
131
- Args:
132
- model (torch.nn.Module):
133
- The model instance to configure.
134
- use_sp (`bool`):
135
- Enable distribution strategy of sequence parallel.
136
- dit_fsdp (`bool`):
137
- Enable FSDP sharding for DiT model.
138
- shard_fn (callable):
139
- The function to apply FSDP sharding.
140
- convert_model_dtype (`bool`):
141
- Convert DiT model parameters dtype to 'config.param_dtype'.
142
- Only works without FSDP.
143
-
144
- Returns:
145
- torch.nn.Module:
146
- The configured model.
147
- """
148
- model.eval().requires_grad_(False)
149
-
150
- if use_sp:
151
- for block in model.blocks:
152
- block.self_attn.forward = types.MethodType(
153
- sp_attn_forward, block.self_attn)
154
- model.forward = types.MethodType(sp_dit_forward, model)
155
-
156
- if dist.is_initialized():
157
- dist.barrier()
158
-
159
- if dit_fsdp:
160
- model = shard_fn(model)
161
- else:
162
- if convert_model_dtype:
163
- model.to(self.param_dtype)
164
- if not self.init_on_cpu:
165
- model.to(self.device)
166
-
167
- return model
168
-
169
- def _prepare_model_for_timestep(self, t, boundary, offload_model):
170
- r"""
171
- Prepares and returns the required model for the current timestep.
172
-
173
- Args:
174
- t (torch.Tensor):
175
- current timestep.
176
- boundary (`int`):
177
- The timestep threshold. If `t` is at or above this value,
178
- the `high_noise_model` is considered as the required model.
179
- offload_model (`bool`):
180
- A flag intended to control the offloading behavior.
181
-
182
- Returns:
183
- torch.nn.Module:
184
- The active model on the target device for the current timestep.
185
- """
186
- if t.item() >= boundary:
187
- required_model_name = 'high_noise_model'
188
- offload_model_name = 'low_noise_model'
189
- else:
190
- required_model_name = 'low_noise_model'
191
- offload_model_name = 'high_noise_model'
192
- if offload_model or self.init_on_cpu:
193
- if next(getattr(
194
- self,
195
- offload_model_name).parameters()).device.type == 'cuda':
196
- getattr(self, offload_model_name).to('cpu')
197
- if next(getattr(
198
- self,
199
- required_model_name).parameters()).device.type == 'cpu':
200
- getattr(self, required_model_name).to(self.device)
201
- return getattr(self, required_model_name)
202
-
203
- def generate(self,
204
- input_prompt,
205
- size=(1280, 720),
206
- frame_num=81,
207
- shift=5.0,
208
- sample_solver='unipc',
209
- sampling_steps=50,
210
- guide_scale=5.0,
211
- n_prompt="",
212
- seed=-1,
213
- offload_model=True):
214
- r"""
215
- Generates video frames from text prompt using diffusion process.
216
-
217
- Args:
218
- input_prompt (`str`):
219
- Text prompt for content generation
220
- size (`tuple[int]`, *optional*, defaults to (1280,720)):
221
- Controls video resolution, (width,height).
222
- frame_num (`int`, *optional*, defaults to 81):
223
- How many frames to sample from a video. The number should be 4n+1
224
- shift (`float`, *optional*, defaults to 5.0):
225
- Noise schedule shift parameter. Affects temporal dynamics
226
- sample_solver (`str`, *optional*, defaults to 'unipc'):
227
- Solver used to sample the video.
228
- sampling_steps (`int`, *optional*, defaults to 50):
229
- Number of diffusion sampling steps. Higher values improve quality but slow generation
230
- guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
231
- Classifier-free guidance scale. Controls prompt adherence vs. creativity.
232
- If tuple, the first guide_scale will be used for low noise model and
233
- the second guide_scale will be used for high noise model.
234
- n_prompt (`str`, *optional*, defaults to ""):
235
- Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
236
- seed (`int`, *optional*, defaults to -1):
237
- Random seed for noise generation. If -1, use random seed.
238
- offload_model (`bool`, *optional*, defaults to True):
239
- If True, offloads models to CPU during generation to save VRAM
240
-
241
- Returns:
242
- torch.Tensor:
243
- Generated video frames tensor. Dimensions: (C, N H, W) where:
244
- - C: Color channels (3 for RGB)
245
- - N: Number of frames (81)
246
- - H: Frame height (from size)
247
- - W: Frame width from size)
248
- """
249
- # preprocess
250
- guide_scale = (guide_scale, guide_scale) if isinstance(
251
- guide_scale, float) else guide_scale
252
- F = frame_num
253
- target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
254
- size[1] // self.vae_stride[1],
255
- size[0] // self.vae_stride[2])
256
-
257
- seq_len = math.ceil((target_shape[2] * target_shape[3]) /
258
- (self.patch_size[1] * self.patch_size[2]) *
259
- target_shape[1] / self.sp_size) * self.sp_size
260
-
261
- if n_prompt == "":
262
- n_prompt = self.sample_neg_prompt
263
- seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
264
- seed_g = torch.Generator(device=self.device)
265
- seed_g.manual_seed(seed)
266
-
267
- if not self.t5_cpu:
268
- self.text_encoder.model.to(self.device)
269
- context = self.text_encoder([input_prompt], self.device)
270
- context_null = self.text_encoder([n_prompt], self.device)
271
- if offload_model:
272
- self.text_encoder.model.cpu()
273
- else:
274
- context = self.text_encoder([input_prompt], torch.device('cpu'))
275
- context_null = self.text_encoder([n_prompt], torch.device('cpu'))
276
- context = [t.to(self.device) for t in context]
277
- context_null = [t.to(self.device) for t in context_null]
278
-
279
- noise = [
280
- torch.randn(
281
- target_shape[0],
282
- target_shape[1],
283
- target_shape[2],
284
- target_shape[3],
285
- dtype=torch.float32,
286
- device=self.device,
287
- generator=seed_g)
288
- ]
289
-
290
- @contextmanager
291
- def noop_no_sync():
292
- yield
293
-
294
- no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
295
- noop_no_sync)
296
- no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
297
- noop_no_sync)
298
-
299
- # evaluation mode
300
- with (
301
- torch.amp.autocast('cuda', dtype=self.param_dtype),
302
- torch.no_grad(),
303
- no_sync_low_noise(),
304
- no_sync_high_noise(),
305
- ):
306
- boundary = self.boundary * self.num_train_timesteps
307
-
308
- if sample_solver == 'unipc':
309
- sample_scheduler = FlowUniPCMultistepScheduler(
310
- num_train_timesteps=self.num_train_timesteps,
311
- shift=1,
312
- use_dynamic_shifting=False)
313
- sample_scheduler.set_timesteps(
314
- sampling_steps, device=self.device, shift=shift)
315
- timesteps = sample_scheduler.timesteps
316
- elif sample_solver == 'dpm++':
317
- sample_scheduler = FlowDPMSolverMultistepScheduler(
318
- num_train_timesteps=self.num_train_timesteps,
319
- shift=1,
320
- use_dynamic_shifting=False)
321
- sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
322
- timesteps, _ = retrieve_timesteps(
323
- sample_scheduler,
324
- device=self.device,
325
- sigmas=sampling_sigmas)
326
- else:
327
- raise NotImplementedError("Unsupported solver.")
328
-
329
- # sample videos
330
- latents = noise
331
-
332
- arg_c = {'context': context, 'seq_len': seq_len}
333
- arg_null = {'context': context_null, 'seq_len': seq_len}
334
-
335
- for _, t in enumerate(tqdm(timesteps)):
336
- latent_model_input = latents
337
- timestep = [t]
338
-
339
- timestep = torch.stack(timestep)
340
-
341
- model = self._prepare_model_for_timestep(
342
- t, boundary, offload_model)
343
- sample_guide_scale = guide_scale[1] if t.item(
344
- ) >= boundary else guide_scale[0]
345
-
346
- noise_pred_cond = model(
347
- latent_model_input, t=timestep, **arg_c)[0]
348
- noise_pred_uncond = model(
349
- latent_model_input, t=timestep, **arg_null)[0]
350
-
351
- noise_pred = noise_pred_uncond + sample_guide_scale * (
352
- noise_pred_cond - noise_pred_uncond)
353
-
354
- temp_x0 = sample_scheduler.step(
355
- noise_pred.unsqueeze(0),
356
- t,
357
- latents[0].unsqueeze(0),
358
- return_dict=False,
359
- generator=seed_g)[0]
360
- latents = [temp_x0.squeeze(0)]
361
-
362
- x0 = latents
363
- if offload_model:
364
- self.low_noise_model.cpu()
365
- self.high_noise_model.cpu()
366
- torch.cuda.empty_cache()
367
- if self.rank == 0:
368
- videos = self.vae.decode(x0)
369
-
370
- del noise, latents
371
- del sample_scheduler
372
- if offload_model:
373
- gc.collect()
374
- torch.cuda.synchronize()
375
- if dist.is_initialized():
376
- dist.barrier()
377
-
378
- return videos[0] if self.rank == 0 else None
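The docstrings above fully describe the public entry point of the deleted WanT2V pipeline. A hedged usage sketch (the checkpoint directory is a placeholder; t2v_A14B is the config re-exported by wan.configs, and the guidance scales are example values only):

import wan
from wan.configs import t2v_A14B

# "./Wan2.2-T2V-A14B" is a placeholder checkpoint directory.
pipe = wan.WanT2V(
    config=t2v_A14B,
    checkpoint_dir="./Wan2.2-T2V-A14B",
    device_id=0,
    convert_model_dtype=True,
)
video = pipe.generate(
    "A red panda climbing a snow-covered pine tree",
    size=(1280, 720),
    frame_num=81,              # must be 4n + 1
    sampling_steps=50,
    guide_scale=(3.0, 4.0),    # example (low-noise, high-noise) scales
    seed=42,
)
# rank 0 receives a (C, N, H, W) tensor in [-1, 1]; other ranks receive None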
wan/textimage2video.py DELETED
@@ -1,619 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import gc
3
- import logging
4
- import math
5
- import os
6
- import random
7
- import sys
8
- import types
9
- from contextlib import contextmanager
10
- from functools import partial
11
-
12
- import torch
13
- import torch.cuda.amp as amp
14
- import torch.distributed as dist
15
- import torchvision.transforms.functional as TF
16
- from PIL import Image
17
- from tqdm import tqdm
18
-
19
- from .distributed.fsdp import shard_model
20
- from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
- from .distributed.util import get_world_size
22
- from .modules.model import WanModel
23
- from .modules.t5 import T5EncoderModel
24
- from .modules.vae2_2 import Wan2_2_VAE
25
- from .utils.fm_solvers import (
26
- FlowDPMSolverMultistepScheduler,
27
- get_sampling_sigmas,
28
- retrieve_timesteps,
29
- )
30
- from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
- from .utils.utils import best_output_size, masks_like
32
-
33
-
34
- class WanTI2V:
35
-
36
- def __init__(
37
- self,
38
- config,
39
- checkpoint_dir,
40
- device_id=0,
41
- rank=0,
42
- t5_fsdp=False,
43
- dit_fsdp=False,
44
- use_sp=False,
45
- t5_cpu=False,
46
- init_on_cpu=True,
47
- convert_model_dtype=False,
48
- ):
49
- r"""
50
- Initializes the Wan text-to-video generation model components.
51
-
52
- Args:
53
- config (EasyDict):
54
- Object containing model parameters initialized from config.py
55
- checkpoint_dir (`str`):
56
- Path to directory containing model checkpoints
57
- device_id (`int`, *optional*, defaults to 0):
58
- Id of target GPU device
59
- rank (`int`, *optional*, defaults to 0):
60
- Process rank for distributed training
61
- t5_fsdp (`bool`, *optional*, defaults to False):
62
- Enable FSDP sharding for T5 model
63
- dit_fsdp (`bool`, *optional*, defaults to False):
64
- Enable FSDP sharding for DiT model
65
- use_sp (`bool`, *optional*, defaults to False):
66
- Enable distribution strategy of sequence parallel.
67
- t5_cpu (`bool`, *optional*, defaults to False):
68
- Whether to place T5 model on CPU. Only works without t5_fsdp.
69
- init_on_cpu (`bool`, *optional*, defaults to True):
70
- Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
71
- convert_model_dtype (`bool`, *optional*, defaults to False):
72
- Convert DiT model parameters dtype to 'config.param_dtype'.
73
- Only works without FSDP.
74
- """
75
- self.device = torch.device(f"cuda:{device_id}")
76
- self.config = config
77
- self.rank = rank
78
- self.t5_cpu = t5_cpu
79
- self.init_on_cpu = init_on_cpu
80
-
81
- self.num_train_timesteps = config.num_train_timesteps
82
- self.param_dtype = config.param_dtype
83
-
84
- if t5_fsdp or dit_fsdp or use_sp:
85
- self.init_on_cpu = False
86
-
87
- shard_fn = partial(shard_model, device_id=device_id)
88
- self.text_encoder = T5EncoderModel(
89
- text_len=config.text_len,
90
- dtype=config.t5_dtype,
91
- device=torch.device('cpu'),
92
- checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
- tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
- shard_fn=shard_fn if t5_fsdp else None)
95
-
96
- self.vae_stride = config.vae_stride
97
- self.patch_size = config.patch_size
98
- self.vae = Wan2_2_VAE(
99
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
100
- device=self.device)
101
-
102
- logging.info(f"Creating WanModel from {checkpoint_dir}")
103
- self.model = WanModel.from_pretrained(checkpoint_dir)
104
- self.model = self._configure_model(
105
- model=self.model,
106
- use_sp=use_sp,
107
- dit_fsdp=dit_fsdp,
108
- shard_fn=shard_fn,
109
- convert_model_dtype=convert_model_dtype)
110
-
111
- if use_sp:
112
- self.sp_size = get_world_size()
113
- else:
114
- self.sp_size = 1
115
-
116
- self.sample_neg_prompt = config.sample_neg_prompt
117
-
118
- def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
119
- convert_model_dtype):
120
- """
121
- Configures a model object. This includes setting evaluation modes,
122
- applying distributed parallel strategy, and handling device placement.
123
-
124
- Args:
125
- model (torch.nn.Module):
126
- The model instance to configure.
127
- use_sp (`bool`):
128
- Enable distribution strategy of sequence parallel.
129
- dit_fsdp (`bool`):
130
- Enable FSDP sharding for DiT model.
131
- shard_fn (callable):
132
- The function to apply FSDP sharding.
133
- convert_model_dtype (`bool`):
134
- Convert DiT model parameters dtype to 'config.param_dtype'.
135
- Only works without FSDP.
136
-
137
- Returns:
138
- torch.nn.Module:
139
- The configured model.
140
- """
141
- model.eval().requires_grad_(False)
142
-
143
- if use_sp:
144
- for block in model.blocks:
145
- block.self_attn.forward = types.MethodType(
146
- sp_attn_forward, block.self_attn)
147
- model.forward = types.MethodType(sp_dit_forward, model)
148
-
149
- if dist.is_initialized():
150
- dist.barrier()
151
-
152
- if dit_fsdp:
153
- model = shard_fn(model)
154
- else:
155
- if convert_model_dtype:
156
- model.to(self.param_dtype)
157
- if not self.init_on_cpu:
158
- model.to(self.device)
159
-
160
- return model
161
-
162
- def generate(self,
163
- input_prompt,
164
- img=None,
165
- size=(1280, 704),
166
- max_area=704 * 1280,
167
- frame_num=81,
168
- shift=5.0,
169
- sample_solver='unipc',
170
- sampling_steps=50,
171
- guide_scale=5.0,
172
- n_prompt="",
173
- seed=-1,
174
- offload_model=True):
175
- r"""
176
- Generates video frames from text prompt using diffusion process.
177
-
178
- Args:
179
- input_prompt (`str`):
180
- Text prompt for content generation
181
- img (PIL.Image.Image):
182
- Input image tensor. Shape: [3, H, W]
183
- size (`tuple[int]`, *optional*, defaults to (1280,704)):
184
- Controls video resolution, (width,height).
185
- max_area (`int`, *optional*, defaults to 704*1280):
186
- Maximum pixel area for latent space calculation. Controls video resolution scaling
187
- frame_num (`int`, *optional*, defaults to 81):
188
- How many frames to sample from a video. The number should be 4n+1
189
- shift (`float`, *optional*, defaults to 5.0):
190
- Noise schedule shift parameter. Affects temporal dynamics
191
- sample_solver (`str`, *optional*, defaults to 'unipc'):
192
- Solver used to sample the video.
193
- sampling_steps (`int`, *optional*, defaults to 50):
194
- Number of diffusion sampling steps. Higher values improve quality but slow generation
195
- guide_scale (`float`, *optional*, defaults 5.0):
196
- Classifier-free guidance scale. Controls prompt adherence vs. creativity.
197
- n_prompt (`str`, *optional*, defaults to ""):
198
- Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
199
- seed (`int`, *optional*, defaults to -1):
200
- Random seed for noise generation. If -1, use random seed.
201
- offload_model (`bool`, *optional*, defaults to True):
202
- If True, offloads models to CPU during generation to save VRAM
203
-
204
- Returns:
205
- torch.Tensor:
206
- Generated video frames tensor. Dimensions: (C, N H, W) where:
207
- - C: Color channels (3 for RGB)
208
- - N: Number of frames (81)
209
- - H: Frame height (from size)
210
- - W: Frame width from size)
211
- """
212
- # i2v
213
- if img is not None:
214
- return self.i2v(
215
- input_prompt=input_prompt,
216
- img=img,
217
- max_area=max_area,
218
- frame_num=frame_num,
219
- shift=shift,
220
- sample_solver=sample_solver,
221
- sampling_steps=sampling_steps,
222
- guide_scale=guide_scale,
223
- n_prompt=n_prompt,
224
- seed=seed,
225
- offload_model=offload_model)
226
- # t2v
227
- return self.t2v(
228
- input_prompt=input_prompt,
229
- size=size,
230
- frame_num=frame_num,
231
- shift=shift,
232
- sample_solver=sample_solver,
233
- sampling_steps=sampling_steps,
234
- guide_scale=guide_scale,
235
- n_prompt=n_prompt,
236
- seed=seed,
237
- offload_model=offload_model)
238
-
239
- def t2v(self,
240
- input_prompt,
241
- size=(1280, 704),
242
- frame_num=121,
243
- shift=5.0,
244
- sample_solver='unipc',
245
- sampling_steps=50,
246
- guide_scale=5.0,
247
- n_prompt="",
248
- seed=-1,
249
- offload_model=True):
250
- r"""
251
- Generates video frames from text prompt using diffusion process.
252
-
253
- Args:
254
- input_prompt (`str`):
255
- Text prompt for content generation
256
- size (`tuple[int]`, *optional*, defaults to (1280,704)):
257
- Controls video resolution, (width,height).
258
- frame_num (`int`, *optional*, defaults to 121):
259
- How many frames to sample from a video. The number should be 4n+1
260
- shift (`float`, *optional*, defaults to 5.0):
261
- Noise schedule shift parameter. Affects temporal dynamics
262
- sample_solver (`str`, *optional*, defaults to 'unipc'):
263
- Solver used to sample the video.
264
- sampling_steps (`int`, *optional*, defaults to 50):
265
- Number of diffusion sampling steps. Higher values improve quality but slow generation
266
- guide_scale (`float`, *optional*, defaults 5.0):
267
- Classifier-free guidance scale. Controls prompt adherence vs. creativity.
268
- n_prompt (`str`, *optional*, defaults to ""):
269
- Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
270
- seed (`int`, *optional*, defaults to -1):
271
- Random seed for noise generation. If -1, use random seed.
272
- offload_model (`bool`, *optional*, defaults to True):
273
- If True, offloads models to CPU during generation to save VRAM
274
-
275
- Returns:
276
- torch.Tensor:
277
- Generated video frames tensor. Dimensions: (C, N H, W) where:
278
- - C: Color channels (3 for RGB)
279
- - N: Number of frames (81)
280
- - H: Frame height (from size)
281
- - W: Frame width from size)
282
- """
283
- # preprocess
284
- F = frame_num
285
- target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
286
- size[1] // self.vae_stride[1],
287
- size[0] // self.vae_stride[2])
288
-
289
- seq_len = math.ceil((target_shape[2] * target_shape[3]) /
290
- (self.patch_size[1] * self.patch_size[2]) *
291
- target_shape[1] / self.sp_size) * self.sp_size
292
-
293
- if n_prompt == "":
294
- n_prompt = self.sample_neg_prompt
295
- seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
296
- seed_g = torch.Generator(device=self.device)
297
- seed_g.manual_seed(seed)
298
-
299
- if not self.t5_cpu:
300
- self.text_encoder.model.to(self.device)
301
- context = self.text_encoder([input_prompt], self.device)
302
- context_null = self.text_encoder([n_prompt], self.device)
303
- if offload_model:
304
- self.text_encoder.model.cpu()
305
- else:
306
- context = self.text_encoder([input_prompt], torch.device('cpu'))
307
- context_null = self.text_encoder([n_prompt], torch.device('cpu'))
308
- context = [t.to(self.device) for t in context]
309
- context_null = [t.to(self.device) for t in context_null]
310
-
311
- noise = [
312
- torch.randn(
313
- target_shape[0],
314
- target_shape[1],
315
- target_shape[2],
316
- target_shape[3],
317
- dtype=torch.float32,
318
- device=self.device,
319
- generator=seed_g)
320
- ]
321
-
322
- @contextmanager
323
- def noop_no_sync():
324
- yield
325
-
326
- no_sync = getattr(self.model, 'no_sync', noop_no_sync)
327
-
328
- # evaluation mode
329
- with (
330
- torch.amp.autocast('cuda', dtype=self.param_dtype),
331
- torch.no_grad(),
332
- no_sync(),
333
- ):
334
-
335
- if sample_solver == 'unipc':
336
- sample_scheduler = FlowUniPCMultistepScheduler(
337
- num_train_timesteps=self.num_train_timesteps,
338
- shift=1,
339
- use_dynamic_shifting=False)
340
- sample_scheduler.set_timesteps(
341
- sampling_steps, device=self.device, shift=shift)
342
- timesteps = sample_scheduler.timesteps
343
- elif sample_solver == 'dpm++':
344
- sample_scheduler = FlowDPMSolverMultistepScheduler(
345
- num_train_timesteps=self.num_train_timesteps,
346
- shift=1,
347
- use_dynamic_shifting=False)
348
- sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
349
- timesteps, _ = retrieve_timesteps(
350
- sample_scheduler,
351
- device=self.device,
352
- sigmas=sampling_sigmas)
353
- else:
354
- raise NotImplementedError("Unsupported solver.")
355
-
356
- # sample videos
357
- latents = noise
358
- mask1, mask2 = masks_like(noise, zero=False)
359
-
360
- arg_c = {'context': context, 'seq_len': seq_len}
361
- arg_null = {'context': context_null, 'seq_len': seq_len}
362
-
363
- if offload_model or self.init_on_cpu:
364
- self.model.to(self.device)
365
- torch.cuda.empty_cache()
366
-
367
- for _, t in enumerate(tqdm(timesteps)):
368
- latent_model_input = latents
369
- timestep = [t]
370
-
371
- timestep = torch.stack(timestep)
372
-
373
- temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
374
- temp_ts = torch.cat([
375
- temp_ts,
376
- temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
377
- ])
378
- timestep = temp_ts.unsqueeze(0)
379
-
380
- noise_pred_cond = self.model(
381
- latent_model_input, t=timestep, **arg_c)[0]
382
- noise_pred_uncond = self.model(
383
- latent_model_input, t=timestep, **arg_null)[0]
384
-
385
- noise_pred = noise_pred_uncond + guide_scale * (
386
- noise_pred_cond - noise_pred_uncond)
387
-
388
- temp_x0 = sample_scheduler.step(
389
- noise_pred.unsqueeze(0),
390
- t,
391
- latents[0].unsqueeze(0),
392
- return_dict=False,
393
- generator=seed_g)[0]
394
- latents = [temp_x0.squeeze(0)]
395
- x0 = latents
396
- if offload_model:
397
- self.model.cpu()
398
- torch.cuda.synchronize()
399
- torch.cuda.empty_cache()
400
- if self.rank == 0:
401
- videos = self.vae.decode(x0)
402
-
403
- del noise, latents
404
- del sample_scheduler
405
- if offload_model:
406
- gc.collect()
407
- torch.cuda.synchronize()
408
- if dist.is_initialized():
409
- dist.barrier()
410
-
411
- return videos[0] if self.rank == 0 else None
412
-
413
- def i2v(self,
414
- input_prompt,
415
- img,
416
- max_area=704 * 1280,
417
- frame_num=121,
418
- shift=5.0,
419
- sample_solver='unipc',
420
- sampling_steps=40,
421
- guide_scale=5.0,
422
- n_prompt="",
423
- seed=-1,
424
- offload_model=True):
425
- r"""
426
- Generates video frames from input image and text prompt using diffusion process.
427
-
428
- Args:
429
- input_prompt (`str`):
430
- Text prompt for content generation.
431
- img (PIL.Image.Image):
432
- Input image tensor. Shape: [3, H, W]
433
- max_area (`int`, *optional*, defaults to 704*1280):
434
- Maximum pixel area for latent space calculation. Controls video resolution scaling
435
- frame_num (`int`, *optional*, defaults to 121):
436
- How many frames to sample from a video. The number should be 4n+1
437
- shift (`float`, *optional*, defaults to 5.0):
438
- Noise schedule shift parameter. Affects temporal dynamics
439
- [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
440
- sample_solver (`str`, *optional*, defaults to 'unipc'):
441
- Solver used to sample the video.
442
- sampling_steps (`int`, *optional*, defaults to 40):
443
- Number of diffusion sampling steps. Higher values improve quality but slow generation
444
- guide_scale (`float`, *optional*, defaults 5.0):
445
- Classifier-free guidance scale. Controls prompt adherence vs. creativity.
446
- n_prompt (`str`, *optional*, defaults to ""):
447
- Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
448
- seed (`int`, *optional*, defaults to -1):
449
- Random seed for noise generation. If -1, use random seed
450
- offload_model (`bool`, *optional*, defaults to True):
451
- If True, offloads models to CPU during generation to save VRAM
452
-
453
- Returns:
454
- torch.Tensor:
455
- Generated video frames tensor. Dimensions: (C, N H, W) where:
456
- - C: Color channels (3 for RGB)
457
- - N: Number of frames (121)
458
- - H: Frame height (from max_area)
459
- - W: Frame width (from max_area)
460
- """
461
- # preprocess
462
- ih, iw = img.height, img.width
463
- dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[
464
- 2] * self.vae_stride[2]
465
- ow, oh = best_output_size(iw, ih, dw, dh, max_area)
466
-
467
- scale = max(ow / iw, oh / ih)
468
- img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
469
-
470
- # center-crop
471
- x1 = (img.width - ow) // 2
472
- y1 = (img.height - oh) // 2
473
- img = img.crop((x1, y1, x1 + ow, y1 + oh))
474
- assert img.width == ow and img.height == oh
475
-
476
- # to tensor
477
- img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
478
-
479
- F = frame_num
480
- seq_len = ((F - 1) // self.vae_stride[0] + 1) * (
481
- oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // (
482
- self.patch_size[1] * self.patch_size[2])
483
- seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size
484
-
485
- seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
486
- seed_g = torch.Generator(device=self.device)
487
- seed_g.manual_seed(seed)
488
- noise = torch.randn(
489
- self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
490
- oh // self.vae_stride[1],
491
- ow // self.vae_stride[2],
492
- dtype=torch.float32,
493
- generator=seed_g,
494
- device=self.device)
495
-
496
- if n_prompt == "":
497
- n_prompt = self.sample_neg_prompt
498
-
499
- # preprocess
500
- if not self.t5_cpu:
501
- self.text_encoder.model.to(self.device)
502
- context = self.text_encoder([input_prompt], self.device)
503
- context_null = self.text_encoder([n_prompt], self.device)
504
- if offload_model:
505
- self.text_encoder.model.cpu()
506
- else:
507
- context = self.text_encoder([input_prompt], torch.device('cpu'))
508
- context_null = self.text_encoder([n_prompt], torch.device('cpu'))
509
- context = [t.to(self.device) for t in context]
510
- context_null = [t.to(self.device) for t in context_null]
511
-
512
- z = self.vae.encode([img])
513
-
514
- @contextmanager
515
- def noop_no_sync():
516
- yield
517
-
518
- no_sync = getattr(self.model, 'no_sync', noop_no_sync)
519
-
520
- # evaluation mode
521
- with (
522
- torch.amp.autocast('cuda', dtype=self.param_dtype),
523
- torch.no_grad(),
524
- no_sync(),
525
- ):
526
-
527
- if sample_solver == 'unipc':
528
- sample_scheduler = FlowUniPCMultistepScheduler(
529
- num_train_timesteps=self.num_train_timesteps,
530
- shift=1,
531
- use_dynamic_shifting=False)
532
- sample_scheduler.set_timesteps(
533
- sampling_steps, device=self.device, shift=shift)
534
- timesteps = sample_scheduler.timesteps
535
- elif sample_solver == 'dpm++':
536
- sample_scheduler = FlowDPMSolverMultistepScheduler(
537
- num_train_timesteps=self.num_train_timesteps,
538
- shift=1,
539
- use_dynamic_shifting=False)
540
- sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
541
- timesteps, _ = retrieve_timesteps(
542
- sample_scheduler,
543
- device=self.device,
544
- sigmas=sampling_sigmas)
545
- else:
546
- raise NotImplementedError("Unsupported solver.")
547
-
548
- # sample videos
549
- latent = noise
550
- mask1, mask2 = masks_like([noise], zero=True)
551
- latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
552
-
553
- arg_c = {
554
- 'context': [context[0]],
555
- 'seq_len': seq_len,
556
- }
557
-
558
- arg_null = {
559
- 'context': context_null,
560
- 'seq_len': seq_len,
561
- }
562
-
563
- if offload_model or self.init_on_cpu:
564
- self.model.to(self.device)
565
- torch.cuda.empty_cache()
566
-
567
- for _, t in enumerate(tqdm(timesteps)):
568
- latent_model_input = [latent.to(self.device)]
569
- timestep = [t]
570
-
571
- timestep = torch.stack(timestep).to(self.device)
572
-
573
- temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
574
- temp_ts = torch.cat([
575
- temp_ts,
576
- temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
577
- ])
578
- timestep = temp_ts.unsqueeze(0)
579
-
580
- noise_pred_cond = self.model(
581
- latent_model_input, t=timestep, **arg_c)[0]
582
- if offload_model:
583
- torch.cuda.empty_cache()
584
- noise_pred_uncond = self.model(
585
- latent_model_input, t=timestep, **arg_null)[0]
586
- if offload_model:
587
- torch.cuda.empty_cache()
588
- noise_pred = noise_pred_uncond + guide_scale * (
589
- noise_pred_cond - noise_pred_uncond)
590
-
591
- temp_x0 = sample_scheduler.step(
592
- noise_pred.unsqueeze(0),
593
- t,
594
- latent.unsqueeze(0),
595
- return_dict=False,
596
- generator=seed_g)[0]
597
- latent = temp_x0.squeeze(0)
598
- latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
599
-
600
- x0 = [latent]
601
- del latent_model_input, timestep
602
-
603
- if offload_model:
604
- self.model.cpu()
605
- torch.cuda.synchronize()
606
- torch.cuda.empty_cache()
607
-
608
- if self.rank == 0:
609
- videos = self.vae.decode(x0)
610
-
611
- del noise, latent, x0
612
- del sample_scheduler
613
- if offload_model:
614
- gc.collect()
615
- torch.cuda.synchronize()
616
- if dist.is_initialized():
617
- dist.barrier()
618
-
619
- return videos[0] if self.rank == 0 else None
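WanTI2V.generate() above dispatches to i2v() when an image is supplied and to t2v() otherwise. A hedged sketch of the image-conditioned path (file paths are placeholders; ti2v_5B is assumed to be the config object defined in the deleted wan/configs/wan_ti2v_5B.py):

from PIL import Image

import wan
from wan.configs import ti2v_5B

# Both paths below are placeholders, not taken from this commit.
pipe = wan.WanTI2V(
    config=ti2v_5B,
    checkpoint_dir="./Wan2.2-TI2V-5B",
    convert_model_dtype=True,
)
video = pipe.generate(
    "The cat slowly turns its head toward the camera",
    img=Image.open("cat.jpg").convert("RGB"),
    max_area=704 * 1280,
    frame_num=121,     # 4n + 1
    shift=5.0,         # the docstring recommends 3.0 for 480p outputs
    sampling_steps=40,
    seed=0,
)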
wan/utils/__init__.py DELETED
@@ -1,12 +0,0 @@
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
- from .fm_solvers import (
-     FlowDPMSolverMultistepScheduler,
-     get_sampling_sigmas,
-     retrieve_timesteps,
- )
- from .fm_solvers_unipc import FlowUniPCMultistepScheduler
-
- __all__ = [
-     'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
-     'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
- ]
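These re-exports are what the deleted pipelines pull in for sampling. A short sketch of the UniPC path as it is wired up in text2video.py above (num_train_timesteps=1000 is the scheduler's own default, and shift=5.0 mirrors the generate() default):

import torch
from wan.utils import FlowUniPCMultistepScheduler

scheduler = FlowUniPCMultistepScheduler(
    num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
scheduler.set_timesteps(50, device=torch.device("cuda"), shift=5.0)
timesteps = scheduler.timesteps   # descending flow-matching timesteps for 50 steps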
wan/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (393 Bytes)
 
wan/utils/__pycache__/fm_solvers.cpython-310.pyc DELETED
Binary file (26.1 kB)
 
wan/utils/__pycache__/fm_solvers_unipc.cpython-310.pyc DELETED
Binary file (22.2 kB)
 
wan/utils/__pycache__/utils.cpython-310.pyc DELETED
Binary file (4.31 kB)
 
wan/utils/fm_solvers.py DELETED
@@ -1,859 +0,0 @@
- # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
- # Convert dpm solver for flow matching
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-
- import inspect
- import math
- from typing import List, Optional, Tuple, Union
-
- import numpy as np
- import torch
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.schedulers.scheduling_utils import (
-     KarrasDiffusionSchedulers,
-     SchedulerMixin,
-     SchedulerOutput,
- )
- from diffusers.utils import deprecate, is_scipy_available
- from diffusers.utils.torch_utils import randn_tensor
-
- if is_scipy_available():
-     pass
-
-
- def get_sampling_sigmas(sampling_steps, shift):
-     sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
-     sigma = (shift * sigma / (1 + (shift - 1) * sigma))
-
-     return sigma
-
-
- def retrieve_timesteps(
-     scheduler,
-     num_inference_steps=None,
-     device=None,
-     timesteps=None,
-     sigmas=None,
-     **kwargs,
- ):
-     if timesteps is not None and sigmas is not None:
-         raise ValueError(
-             "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
-         )
-     if timesteps is not None:
-         accepts_timesteps = "timesteps" in set(
-             inspect.signature(scheduler.set_timesteps).parameters.keys())
-         if not accepts_timesteps:
-             raise ValueError(
-                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-                 f" timestep schedules. Please check whether you are using the correct scheduler."
-             )
-         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
-         timesteps = scheduler.timesteps
-         num_inference_steps = len(timesteps)
-     elif sigmas is not None:
-         accept_sigmas = "sigmas" in set(
-             inspect.signature(scheduler.set_timesteps).parameters.keys())
-         if not accept_sigmas:
-             raise ValueError(
-                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-                 f" sigmas schedules. Please check whether you are using the correct scheduler."
-             )
-         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-         timesteps = scheduler.timesteps
-         num_inference_steps = len(timesteps)
-     else:
-         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-         timesteps = scheduler.timesteps
-     return timesteps, num_inference_steps
-
-
- class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
-     """
-     `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
-     This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
-     methods the library implements for all schedulers such as loading and saving.
-     Args:
-         num_train_timesteps (`int`, defaults to 1000):
-             The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
-         solver_order (`int`, defaults to 2):
-             The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
-             sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
-             and used in multistep updates.
-         prediction_type (`str`, defaults to "flow_prediction"):
-             Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
-             the flow of the diffusion process.
-         shift (`float`, *optional*, defaults to 1.0):
-             A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
-             process.
-         use_dynamic_shifting (`bool`, defaults to `False`):
-             Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
-             applied on the fly.
-         thresholding (`bool`, defaults to `False`):
-             Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
-             saturation and improve photorealism.
-         dynamic_thresholding_ratio (`float`, defaults to 0.995):
-             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
-         sample_max_value (`float`, defaults to 1.0):
-             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
-             `algorithm_type="dpmsolver++"`.
-         algorithm_type (`str`, defaults to `dpmsolver++`):
-             Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
-             `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
-             paper, and the `dpmsolver++` type implements the algorithms in the
-             [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
-             `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
-         solver_type (`str`, defaults to `midpoint`):
-             Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
-             sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
-         lower_order_final (`bool`, defaults to `True`):
-             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
-             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
-         euler_at_final (`bool`, defaults to `False`):
-             Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
-             richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
-             steps, but sometimes may result in blurring.
-         final_sigmas_type (`str`, *optional*, defaults to "zero"):
-             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-             sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
-         lambda_min_clipped (`float`, defaults to `-inf`):
-             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
-             cosine (`squaredcos_cap_v2`) noise schedule.
-         variance_type (`str`, *optional*):
-             Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-             contains the predicted Gaussian variance.
-     """
-
-     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
-     order = 1
-
-     @register_to_config
-     def __init__(
-         self,
-         num_train_timesteps: int = 1000,
-         solver_order: int = 2,
-         prediction_type: str = "flow_prediction",
-         shift: Optional[float] = 1.0,
-         use_dynamic_shifting=False,
-         thresholding: bool = False,
-         dynamic_thresholding_ratio: float = 0.995,
-         sample_max_value: float = 1.0,
-         algorithm_type: str = "dpmsolver++",
-         solver_type: str = "midpoint",
-         lower_order_final: bool = True,
-         euler_at_final: bool = False,
-         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
-         lambda_min_clipped: float = -float("inf"),
-         variance_type: Optional[str] = None,
-         invert_sigmas: bool = False,
-     ):
-         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
-             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
152
- deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0",
153
- deprecation_message)
154
-
155
- # settings for DPM-Solver
156
- if algorithm_type not in [
157
- "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"
158
- ]:
159
- if algorithm_type == "deis":
160
- self.register_to_config(algorithm_type="dpmsolver++")
161
- else:
162
- raise NotImplementedError(
163
- f"{algorithm_type} is not implemented for {self.__class__}")
164
-
165
- if solver_type not in ["midpoint", "heun"]:
166
- if solver_type in ["logrho", "bh1", "bh2"]:
167
- self.register_to_config(solver_type="midpoint")
168
- else:
169
- raise NotImplementedError(
170
- f"{solver_type} is not implemented for {self.__class__}")
171
-
172
- if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"
173
- ] and final_sigmas_type == "zero":
174
- raise ValueError(
175
- f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
176
- )
177
-
178
- # setable values
179
- self.num_inference_steps = None
180
- alphas = np.linspace(1, 1 / num_train_timesteps,
181
- num_train_timesteps)[::-1].copy()
182
- sigmas = 1.0 - alphas
183
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
184
-
185
- if not use_dynamic_shifting:
186
- # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
187
- sigmas = shift * sigmas / (1 +
188
- (shift - 1) * sigmas) # pyright: ignore
189
-
190
- self.sigmas = sigmas
191
- self.timesteps = sigmas * num_train_timesteps
192
-
193
- self.model_outputs = [None] * solver_order
194
- self.lower_order_nums = 0
195
- self._step_index = None
196
- self._begin_index = None
197
-
198
- # self.sigmas = self.sigmas.to(
199
- # "cpu") # to avoid too much CPU/GPU communication
200
- self.sigma_min = self.sigmas[-1].item()
201
- self.sigma_max = self.sigmas[0].item()
202
-
203
- @property
204
- def step_index(self):
205
- """
206
-         The index counter for the current timestep. It increases by 1 after each scheduler step.
207
- """
208
- return self._step_index
209
-
210
- @property
211
- def begin_index(self):
212
- """
213
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
214
- """
215
- return self._begin_index
216
-
217
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
218
- def set_begin_index(self, begin_index: int = 0):
219
- """
220
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
221
- Args:
222
- begin_index (`int`):
223
- The begin index for the scheduler.
224
- """
225
- self._begin_index = begin_index
226
-
227
- # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
228
- def set_timesteps(
229
- self,
230
- num_inference_steps: Union[int, None] = None,
231
- device: Union[str, torch.device] = None,
232
- sigmas: Optional[List[float]] = None,
233
- mu: Optional[Union[float, None]] = None,
234
- shift: Optional[Union[float, None]] = None,
235
- ):
236
- """
237
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
238
- Args:
239
- num_inference_steps (`int`):
240
-                 The number of discrete timesteps to generate for the sampling schedule.
241
- device (`str` or `torch.device`, *optional*):
242
-                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
243
- """
244
-
245
- if self.config.use_dynamic_shifting and mu is None:
246
- raise ValueError(
247
-                 "you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
248
- )
249
-
250
- if sigmas is None:
251
- sigmas = np.linspace(self.sigma_max, self.sigma_min,
252
- num_inference_steps +
253
- 1).copy()[:-1] # pyright: ignore
254
-
255
- if self.config.use_dynamic_shifting:
256
- sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
257
- else:
258
- if shift is None:
259
- shift = self.config.shift
260
- sigmas = shift * sigmas / (1 +
261
- (shift - 1) * sigmas) # pyright: ignore
262
-
263
- if self.config.final_sigmas_type == "sigma_min":
264
- sigma_last = ((1 - self.alphas_cumprod[0]) /
265
- self.alphas_cumprod[0])**0.5
266
- elif self.config.final_sigmas_type == "zero":
267
- sigma_last = 0
268
- else:
269
- raise ValueError(
270
- f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
271
- )
272
-
273
- timesteps = sigmas * self.config.num_train_timesteps
274
- sigmas = np.concatenate([sigmas, [sigma_last]
275
- ]).astype(np.float32) # pyright: ignore
276
-
277
- self.sigmas = torch.from_numpy(sigmas)
278
- self.timesteps = torch.from_numpy(timesteps).to(
279
- device=device, dtype=torch.int64)
280
-
281
- self.num_inference_steps = len(timesteps)
282
-
283
- self.model_outputs = [
284
- None,
285
- ] * self.config.solver_order
286
- self.lower_order_nums = 0
287
-
288
- self._step_index = None
289
- self._begin_index = None
290
- # self.sigmas = self.sigmas.to(
291
- # "cpu") # to avoid too much CPU/GPU communication
292
-
293
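When `use_dynamic_shifting=True`, the static shift above is replaced by the exponential `time_shift` defined further down, driven by a resolution-dependent `mu`. A small standalone sketch (not part of the deleted file; values rounded), reproducing that formula:

    import math

    # time_shift(mu, 1.0, t) = e^mu / (e^mu + (1/t - 1)); larger mu pushes the
    # schedule toward high-noise sigmas (values closer to 1).
    def time_shift(mu, sigma_exp, t):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma_exp)

    for t in (0.8, 0.5, 0.2):
        print(round(time_shift(1.0, 1.0, t), 3))   # 0.916, 0.731, 0.405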
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
294
- def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
295
- """
296
- "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
297
- prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
298
- s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
299
- pixels from saturation at each step. We find that dynamic thresholding results in significantly better
300
- photorealism as well as better image-text alignment, especially when using very large guidance weights."
301
- https://arxiv.org/abs/2205.11487
302
- """
303
- dtype = sample.dtype
304
- batch_size, channels, *remaining_dims = sample.shape
305
-
306
- if dtype not in (torch.float32, torch.float64):
307
- sample = sample.float(
308
- ) # upcast for quantile calculation, and clamp not implemented for cpu half
309
-
310
- # Flatten sample for doing quantile calculation along each image
311
- sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
312
-
313
- abs_sample = sample.abs() # "a certain percentile absolute pixel value"
314
-
315
- s = torch.quantile(
316
- abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
317
- s = torch.clamp(
318
- s, min=1, max=self.config.sample_max_value
319
- ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
320
- s = s.unsqueeze(
321
- 1) # (batch_size, 1) because clamp will broadcast along dim=0
322
- sample = torch.clamp(
323
- sample, -s, s
324
- ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
325
-
326
- sample = sample.reshape(batch_size, channels, *remaining_dims)
327
- sample = sample.to(dtype)
328
-
329
- return sample
330
-
331
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
332
- def _sigma_to_t(self, sigma):
333
- return sigma * self.config.num_train_timesteps
334
-
335
- def _sigma_to_alpha_sigma_t(self, sigma):
336
- return 1 - sigma, sigma
337
-
338
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
339
- def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
340
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
341
-
342
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
343
- def convert_model_output(
344
- self,
345
- model_output: torch.Tensor,
346
- *args,
347
- sample: torch.Tensor = None,
348
- **kwargs,
349
- ) -> torch.Tensor:
350
- """
351
- Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
352
- designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
353
- integral of the data prediction model.
354
- <Tip>
355
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
356
- prediction and data prediction models.
357
- </Tip>
358
- Args:
359
- model_output (`torch.Tensor`):
360
- The direct output from the learned diffusion model.
361
- sample (`torch.Tensor`):
362
- A current instance of a sample created by the diffusion process.
363
- Returns:
364
- `torch.Tensor`:
365
- The converted model output.
366
- """
367
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
368
- if sample is None:
369
- if len(args) > 1:
370
- sample = args[1]
371
- else:
372
- raise ValueError(
373
-                     "missing `sample` as a required keyword argument")
374
- if timestep is not None:
375
- deprecate(
376
- "timesteps",
377
- "1.0.0",
378
- "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
379
- )
380
-
381
- # DPM-Solver++ needs to solve an integral of the data prediction model.
382
- if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
383
- if self.config.prediction_type == "flow_prediction":
384
- sigma_t = self.sigmas[self.step_index]
385
- x0_pred = sample - sigma_t * model_output
386
- else:
387
- raise ValueError(
388
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
389
- " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
390
- )
391
-
392
- if self.config.thresholding:
393
- x0_pred = self._threshold_sample(x0_pred)
394
-
395
- return x0_pred
396
-
397
- # DPM-Solver needs to solve an integral of the noise prediction model.
398
- elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
399
- if self.config.prediction_type == "flow_prediction":
400
- sigma_t = self.sigmas[self.step_index]
401
- epsilon = sample - (1 - sigma_t) * model_output
402
- else:
403
- raise ValueError(
404
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
405
- " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
406
- )
407
-
408
- if self.config.thresholding:
409
- sigma_t = self.sigmas[self.step_index]
410
- x0_pred = sample - sigma_t * model_output
411
- x0_pred = self._threshold_sample(x0_pred)
412
- epsilon = model_output + x0_pred
413
-
414
- return epsilon
415
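A quick algebraic check of the data-prediction branch above (illustrative only, not from the deleted file): under the flow-matching convention used here, x_t = (1 - sigma) * x0 + sigma * eps and the network predicts the velocity v = eps - x0, so subtracting sigma * v recovers x0 exactly.

    import torch

    x0, eps, sigma = torch.randn(4), torch.randn(4), 0.3
    x_t = (1 - sigma) * x0 + sigma * eps   # forward interpolation (cf. add_noise below)
    v = eps - x0                           # the "flow_prediction" target
    assert torch.allclose(x_t - sigma * v, x0)   # matches x0_pred = sample - sigma_t * model_output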
-
416
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
417
- def dpm_solver_first_order_update(
418
- self,
419
- model_output: torch.Tensor,
420
- *args,
421
- sample: torch.Tensor = None,
422
- noise: Optional[torch.Tensor] = None,
423
- **kwargs,
424
- ) -> torch.Tensor:
425
- """
426
- One step for the first-order DPMSolver (equivalent to DDIM).
427
- Args:
428
- model_output (`torch.Tensor`):
429
- The direct output from the learned diffusion model.
430
- sample (`torch.Tensor`):
431
- A current instance of a sample created by the diffusion process.
432
- Returns:
433
- `torch.Tensor`:
434
- The sample tensor at the previous timestep.
435
- """
436
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
437
- prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
438
- "prev_timestep", None)
439
- if sample is None:
440
- if len(args) > 2:
441
- sample = args[2]
442
- else:
443
- raise ValueError(
444
-                     "missing `sample` as a required keyword argument")
445
- if timestep is not None:
446
- deprecate(
447
- "timesteps",
448
- "1.0.0",
449
- "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
450
- )
451
-
452
- if prev_timestep is not None:
453
- deprecate(
454
- "prev_timestep",
455
- "1.0.0",
456
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
457
- )
458
-
459
- sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[
460
- self.step_index] # pyright: ignore
461
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
462
- alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
463
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
464
- lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
465
-
466
- h = lambda_t - lambda_s
467
- if self.config.algorithm_type == "dpmsolver++":
468
- x_t = (sigma_t /
469
- sigma_s) * sample - (alpha_t *
470
- (torch.exp(-h) - 1.0)) * model_output
471
- elif self.config.algorithm_type == "dpmsolver":
472
- x_t = (alpha_t /
473
- alpha_s) * sample - (sigma_t *
474
- (torch.exp(h) - 1.0)) * model_output
475
- elif self.config.algorithm_type == "sde-dpmsolver++":
476
- assert noise is not None
477
- x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample +
478
- (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output +
479
- sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
480
- elif self.config.algorithm_type == "sde-dpmsolver":
481
- assert noise is not None
482
- x_t = ((alpha_t / alpha_s) * sample - 2.0 *
483
- (sigma_t * (torch.exp(h) - 1.0)) * model_output +
484
- sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
485
- return x_t # pyright: ignore
486
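Illustrative sanity check (not from the deleted file): with alpha = 1 - sigma as above, the `dpmsolver++` first-order update reproduces the flow-matching marginal exactly whenever the x0 prediction is exact, regardless of step size, which is why it behaves like a DDIM step.

    import math
    import torch

    x0, eps = torch.randn(4), torch.randn(4)
    sigma_s, sigma_t = 0.8, 0.5
    alpha_s, alpha_t = 1 - sigma_s, 1 - sigma_t
    h = math.log(alpha_t / sigma_t) - math.log(alpha_s / sigma_s)   # lambda_t - lambda_s
    x_s = alpha_s * x0 + sigma_s * eps                              # exact sample at sigma_s
    x_t = (sigma_t / sigma_s) * x_s - alpha_t * (math.exp(-h) - 1.0) * x0
    assert torch.allclose(x_t, alpha_t * x0 + sigma_t * eps, atol=1e-6)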
-
487
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
488
- def multistep_dpm_solver_second_order_update(
489
- self,
490
- model_output_list: List[torch.Tensor],
491
- *args,
492
- sample: torch.Tensor = None,
493
- noise: Optional[torch.Tensor] = None,
494
- **kwargs,
495
- ) -> torch.Tensor:
496
- """
497
- One step for the second-order multistep DPMSolver.
498
- Args:
499
- model_output_list (`List[torch.Tensor]`):
500
-                 The direct outputs from the learned diffusion model at the current and previous timesteps.
501
- sample (`torch.Tensor`):
502
- A current instance of a sample created by the diffusion process.
503
- Returns:
504
- `torch.Tensor`:
505
- The sample tensor at the previous timestep.
506
- """
507
- timestep_list = args[0] if len(args) > 0 else kwargs.pop(
508
- "timestep_list", None)
509
- prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
510
- "prev_timestep", None)
511
- if sample is None:
512
- if len(args) > 2:
513
- sample = args[2]
514
- else:
515
- raise ValueError(
516
-                     "missing `sample` as a required keyword argument")
517
- if timestep_list is not None:
518
- deprecate(
519
- "timestep_list",
520
- "1.0.0",
521
- "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
522
- )
523
-
524
- if prev_timestep is not None:
525
- deprecate(
526
- "prev_timestep",
527
- "1.0.0",
528
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
529
- )
530
-
531
- sigma_t, sigma_s0, sigma_s1 = (
532
- self.sigmas[self.step_index + 1], # pyright: ignore
533
- self.sigmas[self.step_index],
534
- self.sigmas[self.step_index - 1], # pyright: ignore
535
- )
536
-
537
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
538
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
539
- alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
540
-
541
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
542
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
543
- lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
544
-
545
- m0, m1 = model_output_list[-1], model_output_list[-2]
546
-
547
- h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
548
- r0 = h_0 / h
549
- D0, D1 = m0, (1.0 / r0) * (m0 - m1)
550
- if self.config.algorithm_type == "dpmsolver++":
551
- # See https://arxiv.org/abs/2211.01095 for detailed derivations
552
- if self.config.solver_type == "midpoint":
553
- x_t = ((sigma_t / sigma_s0) * sample -
554
- (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 *
555
- (alpha_t * (torch.exp(-h) - 1.0)) * D1)
556
- elif self.config.solver_type == "heun":
557
- x_t = ((sigma_t / sigma_s0) * sample -
558
- (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
559
- (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
560
- elif self.config.algorithm_type == "dpmsolver":
561
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
562
- if self.config.solver_type == "midpoint":
563
- x_t = ((alpha_t / alpha_s0) * sample -
564
- (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 *
565
- (sigma_t * (torch.exp(h) - 1.0)) * D1)
566
- elif self.config.solver_type == "heun":
567
- x_t = ((alpha_t / alpha_s0) * sample -
568
- (sigma_t * (torch.exp(h) - 1.0)) * D0 -
569
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
570
- elif self.config.algorithm_type == "sde-dpmsolver++":
571
- assert noise is not None
572
- if self.config.solver_type == "midpoint":
573
- x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
574
- (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 *
575
- (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 +
576
- sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
577
- elif self.config.solver_type == "heun":
578
- x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample +
579
- (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 +
580
- (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
581
- (-2.0 * h) + 1.0)) * D1 +
582
- sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
583
- elif self.config.algorithm_type == "sde-dpmsolver":
584
- assert noise is not None
585
- if self.config.solver_type == "midpoint":
586
- x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
587
- (sigma_t * (torch.exp(h) - 1.0)) * D0 -
588
- (sigma_t * (torch.exp(h) - 1.0)) * D1 +
589
- sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
590
- elif self.config.solver_type == "heun":
591
- x_t = ((alpha_t / alpha_s0) * sample - 2.0 *
592
- (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 *
593
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
594
- sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
595
- return x_t # pyright: ignore
596
-
597
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
598
- def multistep_dpm_solver_third_order_update(
599
- self,
600
- model_output_list: List[torch.Tensor],
601
- *args,
602
- sample: torch.Tensor = None,
603
- **kwargs,
604
- ) -> torch.Tensor:
605
- """
606
- One step for the third-order multistep DPMSolver.
607
- Args:
608
- model_output_list (`List[torch.Tensor]`):
609
-                 The direct outputs from the learned diffusion model at the current and previous timesteps.
610
- sample (`torch.Tensor`):
611
-                 A current instance of a sample created by the diffusion process.
612
- Returns:
613
- `torch.Tensor`:
614
- The sample tensor at the previous timestep.
615
- """
616
-
617
- timestep_list = args[0] if len(args) > 0 else kwargs.pop(
618
- "timestep_list", None)
619
- prev_timestep = args[1] if len(args) > 1 else kwargs.pop(
620
- "prev_timestep", None)
621
- if sample is None:
622
- if len(args) > 2:
623
- sample = args[2]
624
- else:
625
- raise ValueError(
626
-                     "missing `sample` as a required keyword argument")
627
- if timestep_list is not None:
628
- deprecate(
629
- "timestep_list",
630
- "1.0.0",
631
- "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
632
- )
633
-
634
- if prev_timestep is not None:
635
- deprecate(
636
- "prev_timestep",
637
- "1.0.0",
638
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
639
- )
640
-
641
- sigma_t, sigma_s0, sigma_s1, sigma_s2 = (
642
- self.sigmas[self.step_index + 1], # pyright: ignore
643
- self.sigmas[self.step_index],
644
- self.sigmas[self.step_index - 1], # pyright: ignore
645
- self.sigmas[self.step_index - 2], # pyright: ignore
646
- )
647
-
648
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
649
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
650
- alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1)
651
- alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2)
652
-
653
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
654
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
655
- lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
656
- lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)
657
-
658
- m0, m1, m2 = model_output_list[-1], model_output_list[
659
- -2], model_output_list[-3]
660
-
661
- h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2
662
- r0, r1 = h_0 / h, h_1 / h
663
- D0 = m0
664
- D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2)
665
- D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
666
- D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
667
- if self.config.algorithm_type == "dpmsolver++":
668
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
669
- x_t = ((sigma_t / sigma_s0) * sample -
670
- (alpha_t * (torch.exp(-h) - 1.0)) * D0 +
671
- (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 -
672
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2)
673
- elif self.config.algorithm_type == "dpmsolver":
674
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
675
- x_t = ((alpha_t / alpha_s0) * sample - (sigma_t *
676
- (torch.exp(h) - 1.0)) * D0 -
677
- (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 -
678
- (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2)
679
- return x_t # pyright: ignore
680
-
681
- def index_for_timestep(self, timestep, schedule_timesteps=None):
682
- if schedule_timesteps is None:
683
- schedule_timesteps = self.timesteps
684
-
685
- indices = (schedule_timesteps == timestep).nonzero()
686
-
687
- # The sigma index that is taken for the **very** first `step`
688
- # is always the second index (or the last index if there is only 1)
689
- # This way we can ensure we don't accidentally skip a sigma in
690
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
691
- pos = 1 if len(indices) > 1 else 0
692
-
693
- return indices[pos].item()
694
-
695
- def _init_step_index(self, timestep):
696
- """
697
- Initialize the step_index counter for the scheduler.
698
- """
699
-
700
- if self.begin_index is None:
701
- if isinstance(timestep, torch.Tensor):
702
- timestep = timestep.to(self.timesteps.device)
703
- self._step_index = self.index_for_timestep(timestep)
704
- else:
705
- self._step_index = self._begin_index
706
-
707
- # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
708
- def step(
709
- self,
710
- model_output: torch.Tensor,
711
- timestep: Union[int, torch.Tensor],
712
- sample: torch.Tensor,
713
- generator=None,
714
- variance_noise: Optional[torch.Tensor] = None,
715
- return_dict: bool = True,
716
- ) -> Union[SchedulerOutput, Tuple]:
717
- """
718
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
719
- the multistep DPMSolver.
720
- Args:
721
- model_output (`torch.Tensor`):
722
- The direct output from learned diffusion model.
723
- timestep (`int`):
724
- The current discrete timestep in the diffusion chain.
725
- sample (`torch.Tensor`):
726
- A current instance of a sample created by the diffusion process.
727
- generator (`torch.Generator`, *optional*):
728
- A random number generator.
729
- variance_noise (`torch.Tensor`):
730
- Alternative to generating noise with `generator` by directly providing the noise for the variance
731
- itself. Useful for methods such as [`LEdits++`].
732
- return_dict (`bool`):
733
- Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
734
- Returns:
735
- [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
736
- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
737
- tuple is returned where the first element is the sample tensor.
738
- """
739
- if self.num_inference_steps is None:
740
- raise ValueError(
741
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
742
- )
743
-
744
- if self.step_index is None:
745
- self._init_step_index(timestep)
746
-
747
- # Improve numerical stability for small number of steps
748
- lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
749
- self.config.euler_at_final or
750
- (self.config.lower_order_final and len(self.timesteps) < 15) or
751
- self.config.final_sigmas_type == "zero")
752
- lower_order_second = ((self.step_index == len(self.timesteps) - 2) and
753
- self.config.lower_order_final and
754
- len(self.timesteps) < 15)
755
-
756
- model_output = self.convert_model_output(model_output, sample=sample)
757
- for i in range(self.config.solver_order - 1):
758
- self.model_outputs[i] = self.model_outputs[i + 1]
759
- self.model_outputs[-1] = model_output
760
-
761
- # Upcast to avoid precision issues when computing prev_sample
762
- sample = sample.to(torch.float32)
763
- if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"
764
- ] and variance_noise is None:
765
- noise = randn_tensor(
766
- model_output.shape,
767
- generator=generator,
768
- device=model_output.device,
769
- dtype=torch.float32)
770
- elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
771
- noise = variance_noise.to(
772
- device=model_output.device,
773
- dtype=torch.float32) # pyright: ignore
774
- else:
775
- noise = None
776
-
777
- if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
778
- prev_sample = self.dpm_solver_first_order_update(
779
- model_output, sample=sample, noise=noise)
780
- elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
781
- prev_sample = self.multistep_dpm_solver_second_order_update(
782
- self.model_outputs, sample=sample, noise=noise)
783
- else:
784
- prev_sample = self.multistep_dpm_solver_third_order_update(
785
- self.model_outputs, sample=sample)
786
-
787
- if self.lower_order_nums < self.config.solver_order:
788
- self.lower_order_nums += 1
789
-
790
- # Cast sample back to expected dtype
791
- prev_sample = prev_sample.to(model_output.dtype)
792
-
793
- # upon completion increase step index by one
794
- self._step_index += 1 # pyright: ignore
795
-
796
- if not return_dict:
797
- return (prev_sample,)
798
-
799
- return SchedulerOutput(prev_sample=prev_sample)
800
-
801
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
802
- def scale_model_input(self, sample: torch.Tensor, *args,
803
- **kwargs) -> torch.Tensor:
804
- """
805
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
806
- current timestep.
807
- Args:
808
- sample (`torch.Tensor`):
809
- The input sample.
810
- Returns:
811
- `torch.Tensor`:
812
- A scaled input sample.
813
- """
814
- return sample
815
-
816
-     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
817
- def add_noise(
818
- self,
819
- original_samples: torch.Tensor,
820
- noise: torch.Tensor,
821
- timesteps: torch.IntTensor,
822
- ) -> torch.Tensor:
823
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
824
- sigmas = self.sigmas.to(
825
- device=original_samples.device, dtype=original_samples.dtype)
826
- if original_samples.device.type == "mps" and torch.is_floating_point(
827
- timesteps):
828
- # mps does not support float64
829
- schedule_timesteps = self.timesteps.to(
830
- original_samples.device, dtype=torch.float32)
831
- timesteps = timesteps.to(
832
- original_samples.device, dtype=torch.float32)
833
- else:
834
- schedule_timesteps = self.timesteps.to(original_samples.device)
835
- timesteps = timesteps.to(original_samples.device)
836
-
837
- # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
838
- if self.begin_index is None:
839
- step_indices = [
840
- self.index_for_timestep(t, schedule_timesteps)
841
- for t in timesteps
842
- ]
843
- elif self.step_index is not None:
844
- # add_noise is called after first denoising step (for inpainting)
845
- step_indices = [self.step_index] * timesteps.shape[0]
846
- else:
847
- # add noise is called before first denoising step to create initial latent(img2img)
848
- step_indices = [self.begin_index] * timesteps.shape[0]
849
-
850
- sigma = sigmas[step_indices].flatten()
851
- while len(sigma.shape) < len(original_samples.shape):
852
- sigma = sigma.unsqueeze(-1)
853
-
854
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
855
- noisy_samples = alpha_t * original_samples + sigma_t * noise
856
- return noisy_samples
857
-
858
- def __len__(self):
859
-         return self.config.num_train_timesteps
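For orientation, a minimal sketch of how the pieces deleted above fit together in a sampling loop. The denoiser below is only a placeholder (not the Wan transformer), the tensor shape is a toy value, and the import assumes the module as it existed before this deletion; shift is applied once, through `get_sampling_sigmas`, so the scheduler itself keeps `shift=1`.

    import torch
    from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                      get_sampling_sigmas, retrieve_timesteps)

    def denoiser(latents, t):
        # placeholder for the real flow-prediction model
        return torch.zeros_like(latents)

    scheduler = FlowDPMSolverMultistepScheduler(
        num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
    sampling_sigmas = get_sampling_sigmas(sampling_steps=40, shift=5.0)
    timesteps, _ = retrieve_timesteps(scheduler, device="cpu", sigmas=sampling_sigmas)

    latents = torch.randn(1, 16, 1, 8, 8)          # toy latent tensor
    for t in timesteps:
        flow_pred = denoiser(latents, t)
        latents = scheduler.step(flow_pred, t, latents).prev_sample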
wan/utils/fm_solvers_unipc.py DELETED
@@ -1,802 +0,0 @@
1
- # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2
- # Convert unipc for flow matching
3
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
4
-
5
- import math
6
- from typing import List, Optional, Tuple, Union
7
-
8
- import numpy as np
9
- import torch
10
- from diffusers.configuration_utils import ConfigMixin, register_to_config
11
- from diffusers.schedulers.scheduling_utils import (
12
- KarrasDiffusionSchedulers,
13
- SchedulerMixin,
14
- SchedulerOutput,
15
- )
16
- from diffusers.utils import deprecate, is_scipy_available
17
-
18
- if is_scipy_available():
19
- import scipy.stats
20
-
21
-
22
- class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
23
- """
24
- `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
25
-
26
- This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
27
- methods the library implements for all schedulers such as loading and saving.
28
-
29
- Args:
30
- num_train_timesteps (`int`, defaults to 1000):
31
- The number of diffusion steps to train the model.
32
- solver_order (`int`, default `2`):
33
- The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
34
- due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
35
- unconditional sampling.
36
- prediction_type (`str`, defaults to "flow_prediction"):
37
- Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
38
- the flow of the diffusion process.
39
- thresholding (`bool`, defaults to `False`):
40
- Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
41
- as Stable Diffusion.
42
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
43
- The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
44
- sample_max_value (`float`, defaults to 1.0):
45
- The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
46
- predict_x0 (`bool`, defaults to `True`):
47
- Whether to use the updating algorithm on the predicted x0.
48
- solver_type (`str`, default `bh2`):
49
- Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
50
- otherwise.
51
- lower_order_final (`bool`, default `True`):
52
- Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
53
- stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
54
- disable_corrector (`list`, default `[]`):
55
- Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
56
- and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
57
- usually disabled during the first few steps.
58
- solver_p (`SchedulerMixin`, default `None`):
59
-             Another scheduler that, if specified, replaces the UniP predictor so the algorithm becomes `solver_p + UniC`.
60
- use_karras_sigmas (`bool`, *optional*, defaults to `False`):
61
- Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
62
- the sigmas are determined according to a sequence of noise levels {σi}.
63
- use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
64
- Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
65
- timestep_spacing (`str`, defaults to `"linspace"`):
66
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
67
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
68
- steps_offset (`int`, defaults to 0):
69
- An offset added to the inference steps, as required by some model families.
70
- final_sigmas_type (`str`, defaults to `"zero"`):
71
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
72
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
73
- """
74
-
75
- _compatibles = [e.name for e in KarrasDiffusionSchedulers]
76
- order = 1
77
-
78
- @register_to_config
79
- def __init__(
80
- self,
81
- num_train_timesteps: int = 1000,
82
- solver_order: int = 2,
83
- prediction_type: str = "flow_prediction",
84
- shift: Optional[float] = 1.0,
85
- use_dynamic_shifting=False,
86
- thresholding: bool = False,
87
- dynamic_thresholding_ratio: float = 0.995,
88
- sample_max_value: float = 1.0,
89
- predict_x0: bool = True,
90
- solver_type: str = "bh2",
91
- lower_order_final: bool = True,
92
- disable_corrector: List[int] = [],
93
- solver_p: SchedulerMixin = None,
94
- timestep_spacing: str = "linspace",
95
- steps_offset: int = 0,
96
- final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
97
- ):
98
-
99
- if solver_type not in ["bh1", "bh2"]:
100
- if solver_type in ["midpoint", "heun", "logrho"]:
101
- self.register_to_config(solver_type="bh2")
102
- else:
103
- raise NotImplementedError(
104
- f"{solver_type} is not implemented for {self.__class__}")
105
-
106
- self.predict_x0 = predict_x0
107
- # setable values
108
- self.num_inference_steps = None
109
- alphas = np.linspace(1, 1 / num_train_timesteps,
110
- num_train_timesteps)[::-1].copy()
111
- sigmas = 1.0 - alphas
112
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
113
-
114
- if not use_dynamic_shifting:
115
- # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
116
- sigmas = shift * sigmas / (1 +
117
- (shift - 1) * sigmas) # pyright: ignore
118
-
119
- self.sigmas = sigmas
120
- self.timesteps = sigmas * num_train_timesteps
121
-
122
- self.model_outputs = [None] * solver_order
123
- self.timestep_list = [None] * solver_order
124
- self.lower_order_nums = 0
125
- self.disable_corrector = disable_corrector
126
- self.solver_p = solver_p
127
- self.last_sample = None
128
- self._step_index = None
129
- self._begin_index = None
130
-
131
- self.sigmas = self.sigmas.to(
132
- "cpu") # to avoid too much CPU/GPU communication
133
- self.sigma_min = self.sigmas[-1].item()
134
- self.sigma_max = self.sigmas[0].item()
135
-
136
- @property
137
- def step_index(self):
138
- """
139
-         The index counter for the current timestep. It increases by 1 after each scheduler step.
140
- """
141
- return self._step_index
142
-
143
- @property
144
- def begin_index(self):
145
- """
146
- The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
147
- """
148
- return self._begin_index
149
-
150
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
151
- def set_begin_index(self, begin_index: int = 0):
152
- """
153
- Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
154
-
155
- Args:
156
- begin_index (`int`):
157
- The begin index for the scheduler.
158
- """
159
- self._begin_index = begin_index
160
-
161
- # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
162
- def set_timesteps(
163
- self,
164
- num_inference_steps: Union[int, None] = None,
165
- device: Union[str, torch.device] = None,
166
- sigmas: Optional[List[float]] = None,
167
- mu: Optional[Union[float, None]] = None,
168
- shift: Optional[Union[float, None]] = None,
169
- ):
170
- """
171
- Sets the discrete timesteps used for the diffusion chain (to be run before inference).
172
- Args:
173
- num_inference_steps (`int`):
174
-                 The number of discrete timesteps to generate for the sampling schedule.
175
- device (`str` or `torch.device`, *optional*):
176
-                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
177
- """
178
-
179
- if self.config.use_dynamic_shifting and mu is None:
180
- raise ValueError(
181
-                 "you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`"
182
- )
183
-
184
- if sigmas is None:
185
- sigmas = np.linspace(self.sigma_max, self.sigma_min,
186
- num_inference_steps +
187
- 1).copy()[:-1] # pyright: ignore
188
-
189
- if self.config.use_dynamic_shifting:
190
- sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore
191
- else:
192
- if shift is None:
193
- shift = self.config.shift
194
- sigmas = shift * sigmas / (1 +
195
- (shift - 1) * sigmas) # pyright: ignore
196
-
197
- if self.config.final_sigmas_type == "sigma_min":
198
- sigma_last = ((1 - self.alphas_cumprod[0]) /
199
- self.alphas_cumprod[0])**0.5
200
- elif self.config.final_sigmas_type == "zero":
201
- sigma_last = 0
202
- else:
203
- raise ValueError(
204
- f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
205
- )
206
-
207
- timesteps = sigmas * self.config.num_train_timesteps
208
- sigmas = np.concatenate([sigmas, [sigma_last]
209
- ]).astype(np.float32) # pyright: ignore
210
-
211
- self.sigmas = torch.from_numpy(sigmas)
212
- self.timesteps = torch.from_numpy(timesteps).to(
213
- device=device, dtype=torch.int64)
214
-
215
- self.num_inference_steps = len(timesteps)
216
-
217
- self.model_outputs = [
218
- None,
219
- ] * self.config.solver_order
220
- self.lower_order_nums = 0
221
- self.last_sample = None
222
- if self.solver_p:
223
- self.solver_p.set_timesteps(self.num_inference_steps, device=device)
224
-
225
- # add an index counter for schedulers that allow duplicated timesteps
226
- self._step_index = None
227
- self._begin_index = None
228
- self.sigmas = self.sigmas.to(
229
- "cpu") # to avoid too much CPU/GPU communication
230
-
231
- # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
232
- def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
233
- """
234
- "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
235
- prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
236
- s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
237
- pixels from saturation at each step. We find that dynamic thresholding results in significantly better
238
- photorealism as well as better image-text alignment, especially when using very large guidance weights."
239
-
240
- https://arxiv.org/abs/2205.11487
241
- """
242
- dtype = sample.dtype
243
- batch_size, channels, *remaining_dims = sample.shape
244
-
245
- if dtype not in (torch.float32, torch.float64):
246
- sample = sample.float(
247
- ) # upcast for quantile calculation, and clamp not implemented for cpu half
248
-
249
- # Flatten sample for doing quantile calculation along each image
250
- sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
251
-
252
- abs_sample = sample.abs() # "a certain percentile absolute pixel value"
253
-
254
- s = torch.quantile(
255
- abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
256
- s = torch.clamp(
257
- s, min=1, max=self.config.sample_max_value
258
- ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
259
- s = s.unsqueeze(
260
- 1) # (batch_size, 1) because clamp will broadcast along dim=0
261
- sample = torch.clamp(
262
- sample, -s, s
263
- ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
264
-
265
- sample = sample.reshape(batch_size, channels, *remaining_dims)
266
- sample = sample.to(dtype)
267
-
268
- return sample
269
-
270
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
271
- def _sigma_to_t(self, sigma):
272
- return sigma * self.config.num_train_timesteps
273
-
274
- def _sigma_to_alpha_sigma_t(self, sigma):
275
- return 1 - sigma, sigma
276
-
277
- # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
278
- def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
279
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma)
280
-
281
- def convert_model_output(
282
- self,
283
- model_output: torch.Tensor,
284
- *args,
285
- sample: torch.Tensor = None,
286
- **kwargs,
287
- ) -> torch.Tensor:
288
- r"""
289
- Convert the model output to the corresponding type the UniPC algorithm needs.
290
-
291
- Args:
292
- model_output (`torch.Tensor`):
293
- The direct output from the learned diffusion model.
294
- timestep (`int`):
295
- The current discrete timestep in the diffusion chain.
296
- sample (`torch.Tensor`):
297
- A current instance of a sample created by the diffusion process.
298
-
299
- Returns:
300
- `torch.Tensor`:
301
- The converted model output.
302
- """
303
- timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
304
- if sample is None:
305
- if len(args) > 1:
306
- sample = args[1]
307
- else:
308
- raise ValueError(
309
-                     "missing `sample` as a required keyword argument")
310
- if timestep is not None:
311
- deprecate(
312
- "timesteps",
313
- "1.0.0",
314
- "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
315
- )
316
-
317
- sigma = self.sigmas[self.step_index]
318
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
319
-
320
- if self.predict_x0:
321
- if self.config.prediction_type == "flow_prediction":
322
- sigma_t = self.sigmas[self.step_index]
323
- x0_pred = sample - sigma_t * model_output
324
- else:
325
- raise ValueError(
326
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
327
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
328
- )
329
-
330
- if self.config.thresholding:
331
- x0_pred = self._threshold_sample(x0_pred)
332
-
333
- return x0_pred
334
- else:
335
- if self.config.prediction_type == "flow_prediction":
336
- sigma_t = self.sigmas[self.step_index]
337
- epsilon = sample - (1 - sigma_t) * model_output
338
- else:
339
- raise ValueError(
340
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
341
- " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
342
- )
343
-
344
- if self.config.thresholding:
345
- sigma_t = self.sigmas[self.step_index]
346
- x0_pred = sample - sigma_t * model_output
347
- x0_pred = self._threshold_sample(x0_pred)
348
- epsilon = model_output + x0_pred
349
-
350
- return epsilon
351
-
352
- def multistep_uni_p_bh_update(
353
- self,
354
- model_output: torch.Tensor,
355
- *args,
356
- sample: torch.Tensor = None,
357
- order: int = None, # pyright: ignore
358
- **kwargs,
359
- ) -> torch.Tensor:
360
- """
361
-         One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if it is specified.
362
-
363
- Args:
364
- model_output (`torch.Tensor`):
365
- The direct output from the learned diffusion model at the current timestep.
366
- prev_timestep (`int`):
367
- The previous discrete timestep in the diffusion chain.
368
- sample (`torch.Tensor`):
369
- A current instance of a sample created by the diffusion process.
370
- order (`int`):
371
- The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
372
-
373
- Returns:
374
- `torch.Tensor`:
375
- The sample tensor at the previous timestep.
376
- """
377
- prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
378
- "prev_timestep", None)
379
- if sample is None:
380
- if len(args) > 1:
381
- sample = args[1]
382
- else:
383
- raise ValueError(
384
-                     "missing `sample` as a required keyword argument")
385
- if order is None:
386
- if len(args) > 2:
387
- order = args[2]
388
- else:
389
- raise ValueError(
390
-                     "missing `order` as a required keyword argument")
391
- if prev_timestep is not None:
392
- deprecate(
393
- "prev_timestep",
394
- "1.0.0",
395
- "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
396
- )
397
- model_output_list = self.model_outputs
398
-
399
- s0 = self.timestep_list[-1]
400
- m0 = model_output_list[-1]
401
- x = sample
402
-
403
- if self.solver_p:
404
- x_t = self.solver_p.step(model_output, s0, x).prev_sample
405
- return x_t
406
-
407
- sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
408
- self.step_index] # pyright: ignore
409
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
410
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
411
-
412
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
413
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
414
-
415
- h = lambda_t - lambda_s0
416
- device = sample.device
417
-
418
- rks = []
419
- D1s = []
420
- for i in range(1, order):
421
- si = self.step_index - i # pyright: ignore
422
- mi = model_output_list[-(i + 1)]
423
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
424
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
425
- rk = (lambda_si - lambda_s0) / h
426
- rks.append(rk)
427
- D1s.append((mi - m0) / rk) # pyright: ignore
428
-
429
- rks.append(1.0)
430
- rks = torch.tensor(rks, device=device)
431
-
432
- R = []
433
- b = []
434
-
435
- hh = -h if self.predict_x0 else h
436
- h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
437
- h_phi_k = h_phi_1 / hh - 1
438
-
439
- factorial_i = 1
440
-
441
- if self.config.solver_type == "bh1":
442
- B_h = hh
443
- elif self.config.solver_type == "bh2":
444
- B_h = torch.expm1(hh)
445
- else:
446
- raise NotImplementedError()
447
-
448
- for i in range(1, order + 1):
449
- R.append(torch.pow(rks, i - 1))
450
- b.append(h_phi_k * factorial_i / B_h)
451
- factorial_i *= i + 1
452
- h_phi_k = h_phi_k / hh - 1 / factorial_i
453
-
454
- R = torch.stack(R)
455
- b = torch.tensor(b, device=device)
456
-
457
- if len(D1s) > 0:
458
- D1s = torch.stack(D1s, dim=1) # (B, K)
459
- # for order 2, we use a simplified version
460
- if order == 2:
461
- rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
462
- else:
463
- rhos_p = torch.linalg.solve(R[:-1, :-1],
464
- b[:-1]).to(device).to(x.dtype)
465
- else:
466
- D1s = None
467
-
468
- if self.predict_x0:
469
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
470
- if D1s is not None:
471
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
472
- D1s) # pyright: ignore
473
- else:
474
- pred_res = 0
475
- x_t = x_t_ - alpha_t * B_h * pred_res
476
- else:
477
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
478
- if D1s is not None:
479
- pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
480
- D1s) # pyright: ignore
481
- else:
482
- pred_res = 0
483
- x_t = x_t_ - sigma_t * B_h * pred_res
484
-
485
- x_t = x_t.to(x.dtype)
486
- return x_t
487
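For readers following the linear algebra above: the predictor (and the corrector below) assemble a small Vandermonde-style system in the step-size ratios r_k and solve it for the combination weights. A standalone numeric sketch (not part of the deleted file), shown for order 2 with `solver_type="bh2"` and `predict_x0=True`; the ratio and step values are arbitrary examples.

    import torch

    h = torch.tensor(0.3)              # lambda_t - lambda_s0 for the current step
    rks = torch.tensor([0.8, 1.0])     # step-size ratios r_k; the last entry is always 1.0
    hh = -h                            # predict_x0=True flips the sign of h
    h_phi_1 = torch.expm1(hh)          # h * phi_1(h) = e^h - 1, as in the comment above
    h_phi_k = h_phi_1 / hh - 1
    B_h = torch.expm1(hh)              # "bh2" solver type

    R, b, factorial_i = [], [], 1
    for i in range(1, len(rks) + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R, b = torch.stack(R), torch.stack(b)
    rhos_c = torch.linalg.solve(R, b)  # corrector weights; the predictor solves R[:-1, :-1], b[:-1]
    print(rhos_c)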
-
488
- def multistep_uni_c_bh_update(
489
- self,
490
- this_model_output: torch.Tensor,
491
- *args,
492
- last_sample: torch.Tensor = None,
493
- this_sample: torch.Tensor = None,
494
- order: int = None, # pyright: ignore
495
- **kwargs,
496
- ) -> torch.Tensor:
497
- """
498
- One step for the UniC (B(h) version).
499
-
500
- Args:
501
- this_model_output (`torch.Tensor`):
502
- The model outputs at `x_t`.
503
- this_timestep (`int`):
504
- The current timestep `t`.
505
- last_sample (`torch.Tensor`):
506
- The generated sample before the last predictor `x_{t-1}`.
507
- this_sample (`torch.Tensor`):
508
- The generated sample after the last predictor `x_{t}`.
509
- order (`int`):
510
- The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
511
-
512
- Returns:
513
- `torch.Tensor`:
514
- The corrected sample tensor at the current timestep.
515
- """
516
- this_timestep = args[0] if len(args) > 0 else kwargs.pop(
517
- "this_timestep", None)
518
- if last_sample is None:
519
- if len(args) > 1:
520
- last_sample = args[1]
521
- else:
522
- raise ValueError(
523
-                     "missing `last_sample` as a required keyword argument")
524
- if this_sample is None:
525
- if len(args) > 2:
526
- this_sample = args[2]
527
- else:
528
- raise ValueError(
529
-                     "missing `this_sample` as a required keyword argument")
530
- if order is None:
531
- if len(args) > 3:
532
- order = args[3]
533
- else:
534
- raise ValueError(
535
-                     "missing `order` as a required keyword argument")
536
- if this_timestep is not None:
537
- deprecate(
538
- "this_timestep",
539
- "1.0.0",
540
- "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
541
- )
542
-
543
- model_output_list = self.model_outputs
544
-
545
- m0 = model_output_list[-1]
546
- x = last_sample
547
- x_t = this_sample
548
- model_t = this_model_output
549
-
550
- sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
551
- self.step_index - 1] # pyright: ignore
552
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
553
- alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
554
-
555
- lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
556
- lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
557
-
558
- h = lambda_t - lambda_s0
559
- device = this_sample.device
560
-
561
- rks = []
562
- D1s = []
563
- for i in range(1, order):
564
- si = self.step_index - (i + 1) # pyright: ignore
565
- mi = model_output_list[-(i + 1)]
566
- alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
567
- lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
568
- rk = (lambda_si - lambda_s0) / h
569
- rks.append(rk)
570
- D1s.append((mi - m0) / rk) # pyright: ignore
571
-
572
- rks.append(1.0)
573
- rks = torch.tensor(rks, device=device)
574
-
575
- R = []
576
- b = []
577
-
578
- hh = -h if self.predict_x0 else h
579
- h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
580
- h_phi_k = h_phi_1 / hh - 1
581
-
582
- factorial_i = 1
583
-
584
- if self.config.solver_type == "bh1":
585
- B_h = hh
586
- elif self.config.solver_type == "bh2":
587
- B_h = torch.expm1(hh)
588
- else:
589
- raise NotImplementedError()
590
-
591
- for i in range(1, order + 1):
592
- R.append(torch.pow(rks, i - 1))
593
- b.append(h_phi_k * factorial_i / B_h)
594
- factorial_i *= i + 1
595
- h_phi_k = h_phi_k / hh - 1 / factorial_i
596
-
597
- R = torch.stack(R)
598
- b = torch.tensor(b, device=device)
599
-
600
- if len(D1s) > 0:
601
- D1s = torch.stack(D1s, dim=1)
602
- else:
603
- D1s = None
604
-
605
- # for order 1, we use a simplified version
606
- if order == 1:
607
- rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
608
- else:
609
- rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
610
-
611
- if self.predict_x0:
612
- x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
613
- if D1s is not None:
614
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
615
- else:
616
- corr_res = 0
617
- D1_t = model_t - m0
618
- x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
619
- else:
620
- x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
621
- if D1s is not None:
622
- corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
623
- else:
624
- corr_res = 0
625
- D1_t = model_t - m0
626
- x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
627
- x_t = x_t.to(x.dtype)
628
- return x_t
629
-
630
- def index_for_timestep(self, timestep, schedule_timesteps=None):
631
- if schedule_timesteps is None:
632
- schedule_timesteps = self.timesteps
633
-
634
- indices = (schedule_timesteps == timestep).nonzero()
635
-
636
- # The sigma index that is taken for the **very** first `step`
637
- # is always the second index (or the last index if there is only 1)
638
- # This way we can ensure we don't accidentally skip a sigma in
639
- # case we start in the middle of the denoising schedule (e.g. for image-to-image)
640
- pos = 1 if len(indices) > 1 else 0
641
-
642
- return indices[pos].item()
643
-
644
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
645
- def _init_step_index(self, timestep):
646
- """
647
- Initialize the step_index counter for the scheduler.
648
- """
649
-
650
- if self.begin_index is None:
651
- if isinstance(timestep, torch.Tensor):
652
- timestep = timestep.to(self.timesteps.device)
653
- self._step_index = self.index_for_timestep(timestep)
654
- else:
655
- self._step_index = self._begin_index
656
-
657
- def step(self,
658
- model_output: torch.Tensor,
659
- timestep: Union[int, torch.Tensor],
660
- sample: torch.Tensor,
661
- return_dict: bool = True,
662
- generator=None) -> Union[SchedulerOutput, Tuple]:
663
- """
664
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
665
- the multistep UniPC.
666
-
667
- Args:
668
- model_output (`torch.Tensor`):
669
- The direct output from learned diffusion model.
670
- timestep (`int`):
671
- The current discrete timestep in the diffusion chain.
672
- sample (`torch.Tensor`):
673
- A current instance of a sample created by the diffusion process.
674
- return_dict (`bool`):
675
- Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
676
-
677
- Returns:
678
- [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
679
- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
680
- tuple is returned where the first element is the sample tensor.
681
-
682
- """
683
- if self.num_inference_steps is None:
684
- raise ValueError(
685
- "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
686
- )
687
-
688
- if self.step_index is None:
689
- self._init_step_index(timestep)
690
-
691
- use_corrector = (
692
- self.step_index > 0 and
693
- self.step_index - 1 not in self.disable_corrector and
694
- self.last_sample is not None # pyright: ignore
695
- )
696
-
697
- model_output_convert = self.convert_model_output(
698
- model_output, sample=sample)
699
- if use_corrector:
700
- sample = self.multistep_uni_c_bh_update(
701
- this_model_output=model_output_convert,
702
- last_sample=self.last_sample,
703
- this_sample=sample,
704
- order=self.this_order,
705
- )
706
-
707
- for i in range(self.config.solver_order - 1):
708
- self.model_outputs[i] = self.model_outputs[i + 1]
709
- self.timestep_list[i] = self.timestep_list[i + 1]
710
-
711
- self.model_outputs[-1] = model_output_convert
712
- self.timestep_list[-1] = timestep # pyright: ignore
713
-
714
- if self.config.lower_order_final:
715
- this_order = min(self.config.solver_order,
716
- len(self.timesteps) -
717
- self.step_index) # pyright: ignore
718
- else:
719
- this_order = self.config.solver_order
720
-
721
- self.this_order = min(this_order,
722
- self.lower_order_nums + 1) # warmup for multistep
723
- assert self.this_order > 0
724
-
725
- self.last_sample = sample
726
- prev_sample = self.multistep_uni_p_bh_update(
727
- model_output=model_output, # pass the original non-converted model output, in case solver-p is used
728
- sample=sample,
729
- order=self.this_order,
730
- )
731
-
732
- if self.lower_order_nums < self.config.solver_order:
733
- self.lower_order_nums += 1
734
-
735
- # upon completion increase step index by one
736
- self._step_index += 1 # pyright: ignore
737
-
738
- if not return_dict:
739
- return (prev_sample,)
740
-
741
- return SchedulerOutput(prev_sample=prev_sample)
742
-
743
- def scale_model_input(self, sample: torch.Tensor, *args,
744
- **kwargs) -> torch.Tensor:
745
- """
746
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
747
- current timestep.
748
-
749
- Args:
750
- sample (`torch.Tensor`):
751
- The input sample.
752
-
753
- Returns:
754
- `torch.Tensor`:
755
- A scaled input sample.
756
- """
757
- return sample
758
-
759
- # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
760
- def add_noise(
761
- self,
762
- original_samples: torch.Tensor,
763
- noise: torch.Tensor,
764
- timesteps: torch.IntTensor,
765
- ) -> torch.Tensor:
766
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
767
- sigmas = self.sigmas.to(
768
- device=original_samples.device, dtype=original_samples.dtype)
769
- if original_samples.device.type == "mps" and torch.is_floating_point(
770
- timesteps):
771
- # mps does not support float64
772
- schedule_timesteps = self.timesteps.to(
773
- original_samples.device, dtype=torch.float32)
774
- timesteps = timesteps.to(
775
- original_samples.device, dtype=torch.float32)
776
- else:
777
- schedule_timesteps = self.timesteps.to(original_samples.device)
778
- timesteps = timesteps.to(original_samples.device)
779
-
780
- # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
781
- if self.begin_index is None:
782
- step_indices = [
783
- self.index_for_timestep(t, schedule_timesteps)
784
- for t in timesteps
785
- ]
786
- elif self.step_index is not None:
787
- # add_noise is called after first denoising step (for inpainting)
788
- step_indices = [self.step_index] * timesteps.shape[0]
789
- else:
790
- # add_noise is called before the first denoising step to create the initial latent (img2img)
791
- step_indices = [self.begin_index] * timesteps.shape[0]
792
-
793
- sigma = sigmas[step_indices].flatten()
794
- while len(sigma.shape) < len(original_samples.shape):
795
- sigma = sigma.unsqueeze(-1)
796
-
797
- alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
798
- noisy_samples = alpha_t * original_samples + sigma_t * noise
799
- return noisy_samples
800
-
801
- def __len__(self):
802
- return self.config.num_train_timesteps
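For orientation, the scheduler deleted above follows the usual diffusers-style driving loop: set_timesteps() builds the timestep/sigma schedule, and step() is then called once per timestep, running the UniC corrector on the previous sample (when enabled) before the UniP predictor produces the next latent. A minimal sketch of that loop, assuming the class defined in fm_solvers_unipc.py is exported as FlowUniPCMultistepScheduler (adjust the name if it differs) and using a dummy denoiser in place of the real model:

import torch
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler  # assumed class name

scheduler = FlowUniPCMultistepScheduler()      # default solver order and bh2 solver type
scheduler.set_timesteps(50)                    # must run before step(), see the ValueError above

latents = torch.randn(1, 16, 21, 60, 104)      # example latent shape, not prescriptive
denoiser = lambda x, t: torch.zeros_like(x)    # stand-in for the actual diffusion model

for t in scheduler.timesteps:
    model_output = denoiser(latents, t)
    # step() applies the corrector and predictor updates and advances the internal
    # step_index; prev_sample is the latent passed into the next iteration.
    latents = scheduler.step(model_output, t, latents).prev_sample
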
wan/utils/prompt_extend.py DELETED
@@ -1,542 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import json
3
- import logging
4
- import math
5
- import os
6
- import random
7
- import sys
8
- import tempfile
9
- from dataclasses import dataclass
10
- from http import HTTPStatus
11
- from typing import Optional, Union
12
-
13
- import dashscope
14
- import torch
15
- from PIL import Image
16
-
17
- try:
18
- from flash_attn import flash_attn_varlen_func
19
- FLASH_VER = 2
20
- except ModuleNotFoundError:
21
- flash_attn_varlen_func = None # keep the module importable on CPU-only machines
22
- FLASH_VER = None
23
-
24
- from .system_prompt import *
25
-
26
- DEFAULT_SYS_PROMPTS = {
27
- "t2v-A14B": {
28
- "zh": T2V_A14B_ZH_SYS_PROMPT,
29
- "en": T2V_A14B_EN_SYS_PROMPT,
30
- },
31
- "i2v-A14B": {
32
- "zh": I2V_A14B_ZH_SYS_PROMPT,
33
- "en": I2V_A14B_EN_SYS_PROMPT,
34
- "empty": {
35
- "zh": I2V_A14B_EMPTY_ZH_SYS_PROMPT,
36
- "en": I2V_A14B_EMPTY_EN_SYS_PROMPT,
37
- }
38
- },
39
- "ti2v-5B": {
40
- "t2v": {
41
- "zh": T2V_A14B_ZH_SYS_PROMPT,
42
- "en": T2V_A14B_EN_SYS_PROMPT,
43
- },
44
- "i2v": {
45
- "zh": I2V_A14B_ZH_SYS_PROMPT,
46
- "en": I2V_A14B_EN_SYS_PROMPT,
47
- }
48
- },
49
- }
50
-
51
-
52
- @dataclass
53
- class PromptOutput(object):
54
- status: bool
55
- prompt: str
56
- seed: int
57
- system_prompt: str
58
- message: str
59
-
60
- def add_custom_field(self, key: str, value) -> None:
61
- self.__setattr__(key, value)
62
-
63
-
64
- class PromptExpander:
65
-
66
- def __init__(self, model_name, task, is_vl=False, device=0, **kwargs):
67
- self.model_name = model_name
68
- self.task = task
69
- self.is_vl = is_vl
70
- self.device = device
71
-
72
- def extend_with_img(self,
73
- prompt,
74
- system_prompt,
75
- image=None,
76
- seed=-1,
77
- *args,
78
- **kwargs):
79
- pass
80
-
81
- def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
82
- pass
83
-
84
- def decide_system_prompt(self, tar_lang="zh", prompt=None):
85
- assert self.task is not None
86
- if "ti2v" in self.task:
87
- if self.is_vl:
88
- return DEFAULT_SYS_PROMPTS[self.task]["i2v"][tar_lang]
89
- else:
90
- return DEFAULT_SYS_PROMPTS[self.task]["t2v"][tar_lang]
91
- if "i2v" in self.task and len(prompt) == 0:
92
- return DEFAULT_SYS_PROMPTS[self.task]["empty"][tar_lang]
93
- return DEFAULT_SYS_PROMPTS[self.task][tar_lang]
94
-
95
- def __call__(self,
96
- prompt,
97
- system_prompt=None,
98
- tar_lang="zh",
99
- image=None,
100
- seed=-1,
101
- *args,
102
- **kwargs):
103
- if system_prompt is None:
104
- system_prompt = self.decide_system_prompt(
105
- tar_lang=tar_lang, prompt=prompt)
106
- if seed < 0:
107
- seed = random.randint(0, sys.maxsize)
108
- if image is not None and self.is_vl:
109
- return self.extend_with_img(
110
- prompt, system_prompt, image=image, seed=seed, *args, **kwargs)
111
- elif not self.is_vl:
112
- return self.extend(prompt, system_prompt, seed, *args, **kwargs)
113
- else:
114
- raise NotImplementedError
115
-
116
-
117
- class DashScopePromptExpander(PromptExpander):
118
-
119
- def __init__(self,
120
- api_key=None,
121
- model_name=None,
122
- task=None,
123
- max_image_size=512 * 512,
124
- retry_times=4,
125
- is_vl=False,
126
- **kwargs):
127
- '''
128
- Args:
129
- api_key: The API key for DashScope authentication and access to related services.
130
- model_name: Model name, 'qwen-plus' for extending prompts, 'qwen-vl-max' for extending prompt-images.
131
- task: Task name. This is required to determine the default system prompt.
132
- max_image_size: The maximum image area in pixels (width * height); larger images are resized to fit within this area before being sent to the API.
133
- retry_times: Number of retry attempts in case of request failure.
134
- is_vl: A flag indicating whether the task involves visual-language processing.
135
- **kwargs: Additional keyword arguments that can be passed to the function or method.
136
- '''
137
- if model_name is None:
138
- model_name = 'qwen-plus' if not is_vl else 'qwen-vl-max'
139
- super().__init__(model_name, task, is_vl, **kwargs)
140
- if api_key is not None:
141
- dashscope.api_key = api_key
142
- elif 'DASH_API_KEY' in os.environ and os.environ[
143
- 'DASH_API_KEY'] is not None:
144
- dashscope.api_key = os.environ['DASH_API_KEY']
145
- else:
146
- raise ValueError("DASH_API_KEY is not set")
147
- if 'DASH_API_URL' in os.environ and os.environ[
148
- 'DASH_API_URL'] is not None:
149
- dashscope.base_http_api_url = os.environ['DASH_API_URL']
150
- else:
151
- dashscope.base_http_api_url = 'https://dashscope.aliyuncs.com/api/v1'
152
- self.api_key = api_key
153
-
154
- self.max_image_size = max_image_size
155
- self.model = model_name
156
- self.retry_times = retry_times
157
-
158
- def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
159
- messages = [{
160
- 'role': 'system',
161
- 'content': system_prompt
162
- }, {
163
- 'role': 'user',
164
- 'content': prompt
165
- }]
166
-
167
- exception = None
168
- for _ in range(self.retry_times):
169
- try:
170
- response = dashscope.Generation.call(
171
- self.model,
172
- messages=messages,
173
- seed=seed,
174
- result_format='message', # set the result to be "message" format.
175
- )
176
- assert response.status_code == HTTPStatus.OK, response
177
- expanded_prompt = response['output']['choices'][0]['message'][
178
- 'content']
179
- return PromptOutput(
180
- status=True,
181
- prompt=expanded_prompt,
182
- seed=seed,
183
- system_prompt=system_prompt,
184
- message=json.dumps(response, ensure_ascii=False))
185
- except Exception as e:
186
- exception = e
187
- return PromptOutput(
188
- status=False,
189
- prompt=prompt,
190
- seed=seed,
191
- system_prompt=system_prompt,
192
- message=str(exception))
193
-
194
- def extend_with_img(self,
195
- prompt,
196
- system_prompt,
197
- image: Union[Image.Image, str] = None,
198
- seed=-1,
199
- *args,
200
- **kwargs):
201
- if isinstance(image, str):
202
- image = Image.open(image).convert('RGB')
203
- w = image.width
204
- h = image.height
205
- area = min(w * h, self.max_image_size)
206
- aspect_ratio = h / w
207
- resized_h = round(math.sqrt(area * aspect_ratio))
208
- resized_w = round(math.sqrt(area / aspect_ratio))
209
- image = image.resize((resized_w, resized_h))
210
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
211
- image.save(f.name)
212
- fname = f.name
213
- image_path = f"file://{f.name}"
214
- prompt = f"{prompt}"
215
- messages = [
216
- {
217
- 'role': 'system',
218
- 'content': [{
219
- "text": system_prompt
220
- }]
221
- },
222
- {
223
- 'role': 'user',
224
- 'content': [{
225
- "text": prompt
226
- }, {
227
- "image": image_path
228
- }]
229
- },
230
- ]
231
- response = None
232
- result_prompt = prompt
233
- exception = None
234
- status = False
235
- for _ in range(self.retry_times):
236
- try:
237
- response = dashscope.MultiModalConversation.call(
238
- self.model,
239
- messages=messages,
240
- seed=seed,
241
- result_format='message', # set the result to be "message" format.
242
- )
243
- assert response.status_code == HTTPStatus.OK, response
244
- result_prompt = response['output']['choices'][0]['message'][
245
- 'content'][0]['text'].replace('\n', '\\n')
246
- status = True
247
- break
248
- except Exception as e:
249
- exception = e
250
- result_prompt = result_prompt.replace('\n', '\\n')
251
- os.remove(fname)
252
-
253
- return PromptOutput(
254
- status=status,
255
- prompt=result_prompt,
256
- seed=seed,
257
- system_prompt=system_prompt,
258
- message=str(exception) if not status else json.dumps(
259
- response, ensure_ascii=False))
260
-
261
-
262
- class QwenPromptExpander(PromptExpander):
263
- model_dict = {
264
- "QwenVL2.5_3B": "Qwen/Qwen2.5-VL-3B-Instruct",
265
- "QwenVL2.5_7B": "Qwen/Qwen2.5-VL-7B-Instruct",
266
- "Qwen2.5_3B": "Qwen/Qwen2.5-3B-Instruct",
267
- "Qwen2.5_7B": "Qwen/Qwen2.5-7B-Instruct",
268
- "Qwen2.5_14B": "Qwen/Qwen2.5-14B-Instruct",
269
- }
270
-
271
- def __init__(self,
272
- model_name=None,
273
- task=None,
274
- device=0,
275
- is_vl=False,
276
- **kwargs):
277
- '''
278
- Args:
279
- model_name: Use predefined model names such as 'QwenVL2.5_7B' and 'Qwen2.5_14B',
280
- which are specific versions of the Qwen model. Alternatively, you can use the
281
- local path to a downloaded model or the model name from Hugging Face.
282
- Detailed Breakdown:
283
- Predefined Model Names:
284
- * 'QwenVL2.5_7B' and 'Qwen2.5_14B' are specific versions of the Qwen model.
285
- Local Path:
286
- * You can provide the path to a model that you have downloaded locally.
287
- Hugging Face Model Name:
288
- * You can also specify the model name from Hugging Face's model hub.
289
- task: Task name. This is required to determine the default system prompt.
290
- is_vl: A flag indicating whether the task involves visual-language processing.
291
- **kwargs: Additional keyword arguments that can be passed to the function or method.
292
- '''
293
- if model_name is None:
294
- model_name = 'Qwen2.5_14B' if not is_vl else 'QwenVL2.5_7B'
295
- super().__init__(model_name, task, is_vl, device, **kwargs)
296
- if (not os.path.exists(self.model_name)) and (self.model_name
297
- in self.model_dict):
298
- self.model_name = self.model_dict[self.model_name]
299
-
300
- if self.is_vl:
301
- # default: Load the model on the available device(s)
302
- from transformers import (
303
- AutoProcessor,
304
- AutoTokenizer,
305
- Qwen2_5_VLForConditionalGeneration,
306
- )
307
- try:
308
- from .qwen_vl_utils import process_vision_info
309
- except ImportError:
310
- from qwen_vl_utils import process_vision_info
311
- self.process_vision_info = process_vision_info
312
- min_pixels = 256 * 28 * 28
313
- max_pixels = 1280 * 28 * 28
314
- self.processor = AutoProcessor.from_pretrained(
315
- self.model_name,
316
- min_pixels=min_pixels,
317
- max_pixels=max_pixels,
318
- use_fast=True)
319
- self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
320
- self.model_name,
321
- torch_dtype=torch.bfloat16 if FLASH_VER == 2 else
322
- torch.float16 if "AWQ" in self.model_name else "auto",
323
- attn_implementation="flash_attention_2"
324
- if FLASH_VER == 2 else None,
325
- device_map="cpu")
326
- else:
327
- from transformers import AutoModelForCausalLM, AutoTokenizer
328
- self.model = AutoModelForCausalLM.from_pretrained(
329
- self.model_name,
330
- torch_dtype=torch.float16
331
- if "AWQ" in self.model_name else "auto",
332
- attn_implementation="flash_attention_2"
333
- if FLASH_VER == 2 else None,
334
- device_map="cpu")
335
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
336
-
337
- def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
338
- self.model = self.model.to(self.device)
339
- messages = [{
340
- "role": "system",
341
- "content": system_prompt
342
- }, {
343
- "role": "user",
344
- "content": prompt
345
- }]
346
- text = self.tokenizer.apply_chat_template(
347
- messages, tokenize=False, add_generation_prompt=True)
348
- model_inputs = self.tokenizer([text],
349
- return_tensors="pt").to(self.model.device)
350
-
351
- generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
352
- generated_ids = [
353
- output_ids[len(input_ids):] for input_ids, output_ids in zip(
354
- model_inputs.input_ids, generated_ids)
355
- ]
356
-
357
- expanded_prompt = self.tokenizer.batch_decode(
358
- generated_ids, skip_special_tokens=True)[0]
359
- self.model = self.model.to("cpu")
360
- return PromptOutput(
361
- status=True,
362
- prompt=expanded_prompt,
363
- seed=seed,
364
- system_prompt=system_prompt,
365
- message=json.dumps({"content": expanded_prompt},
366
- ensure_ascii=False))
367
-
368
- def extend_with_img(self,
369
- prompt,
370
- system_prompt,
371
- image: Union[Image.Image, str] = None,
372
- seed=-1,
373
- *args,
374
- **kwargs):
375
- self.model = self.model.to(self.device)
376
- messages = [{
377
- 'role': 'system',
378
- 'content': [{
379
- "type": "text",
380
- "text": system_prompt
381
- }]
382
- }, {
383
- "role":
384
- "user",
385
- "content": [
386
- {
387
- "type": "image",
388
- "image": image,
389
- },
390
- {
391
- "type": "text",
392
- "text": prompt
393
- },
394
- ],
395
- }]
396
-
397
- # Preparation for inference
398
- text = self.processor.apply_chat_template(
399
- messages, tokenize=False, add_generation_prompt=True)
400
- image_inputs, video_inputs = self.process_vision_info(messages)
401
- inputs = self.processor(
402
- text=[text],
403
- images=image_inputs,
404
- videos=video_inputs,
405
- padding=True,
406
- return_tensors="pt",
407
- )
408
- inputs = inputs.to(self.device)
409
-
410
- # Inference: Generation of the output
411
- generated_ids = self.model.generate(**inputs, max_new_tokens=512)
412
- generated_ids_trimmed = [
413
- out_ids[len(in_ids):]
414
- for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
415
- ]
416
- expanded_prompt = self.processor.batch_decode(
417
- generated_ids_trimmed,
418
- skip_special_tokens=True,
419
- clean_up_tokenization_spaces=False)[0]
420
- self.model = self.model.to("cpu")
421
- return PromptOutput(
422
- status=True,
423
- prompt=expanded_prompt,
424
- seed=seed,
425
- system_prompt=system_prompt,
426
- message=json.dumps({"content": expanded_prompt},
427
- ensure_ascii=False))
428
-
429
-
430
- if __name__ == "__main__":
431
- logging.basicConfig(
432
- level=logging.INFO,
433
- format="[%(asctime)s] %(levelname)s: %(message)s",
434
- handlers=[logging.StreamHandler(stream=sys.stdout)])
435
-
436
- seed = 100
437
- prompt = "夏日海滩度假风格,一只戴着墨镜的白色猫咪坐在冲浪板上。猫咪毛发蓬松,表情悠闲,直视镜头。背景是模糊的海滩景色,海水清澈,远处有绿色的山丘和蓝天白云。猫咪的姿态自然放松,仿佛在享受海风和阳光。近景特写,强调猫咪的细节和海滩的清新氛围。"
438
- en_prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
439
- image = "./examples/i2v_input.JPG"
440
-
441
- def test(method,
442
- prompt,
443
- model_name,
444
- task,
445
- image=None,
446
- en_prompt=None,
447
- seed=None):
448
- prompt_expander = method(
449
- model_name=model_name, task=task, is_vl=image is not None)
450
- result = prompt_expander(prompt, image=image, tar_lang="zh")
451
- logging.info(f"zh prompt -> zh: {result.prompt}")
452
- result = prompt_expander(prompt, image=image, tar_lang="en")
453
- logging.info(f"zh prompt -> en: {result.prompt}")
454
- if en_prompt is not None:
455
- result = prompt_expander(en_prompt, image=image, tar_lang="zh")
456
- logging.info(f"en prompt -> zh: {result.prompt}")
457
- result = prompt_expander(en_prompt, image=image, tar_lang="en")
458
- logging.info(f"en prompt -> en: {result.prompt}")
459
-
460
- ds_model_name = None
461
- ds_vl_model_name = None
462
- qwen_model_name = None
463
- qwen_vl_model_name = None
464
-
465
- for task in ["t2v-A14B", "i2v-A14B", "ti2v-5B"]:
466
- # test prompt extend
467
- if "t2v" in task or "ti2v" in task:
468
- # test dashscope api
469
- logging.info(f"-" * 40)
470
- logging.info(f"Testing {task} dashscope prompt extend")
471
- test(
472
- DashScopePromptExpander,
473
- prompt,
474
- ds_model_name,
475
- task,
476
- image=None,
477
- en_prompt=en_prompt,
478
- seed=seed)
479
-
480
- # test qwen api
481
- logging.info(f"-" * 40)
482
- logging.info(f"Testing {task} qwen prompt extend")
483
- test(
484
- QwenPromptExpander,
485
- prompt,
486
- qwen_model_name,
487
- task,
488
- image=None,
489
- en_prompt=en_prompt,
490
- seed=seed)
491
-
492
- # test prompt-image extend
493
- if "i2v" in task:
494
- # test dashscope api
495
- logging.info(f"-" * 40)
496
- logging.info(f"Testing {task} dashscope vl prompt extend")
497
- test(
498
- DashScopePromptExpander,
499
- prompt,
500
- ds_vl_model_name,
501
- task,
502
- image=image,
503
- en_prompt=en_prompt,
504
- seed=seed)
505
-
506
- # test qwen api
507
- logging.info(f"-" * 40)
508
- logging.info(f"Testing {task} qwen vl prompt extend")
509
- test(
510
- QwenPromptExpander,
511
- prompt,
512
- qwen_vl_model_name,
513
- task,
514
- image=image,
515
- en_prompt=en_prompt,
516
- seed=seed)
517
-
518
- # test empty prompt extend
519
- if "i2v-A14B" in task:
520
- # test dashscope api
521
- logging.info(f"-" * 40)
522
- logging.info(f"Testing {task} dashscope vl empty prompt extend")
523
- test(
524
- DashScopePromptExpander,
525
- "",
526
- ds_vl_model_name,
527
- task,
528
- image=image,
529
- en_prompt=None,
530
- seed=seed)
531
-
532
- # test qwen api
533
- logging.info(f"-" * 40)
534
- logging.info(f"Testing {task} qwen vl empty prompt extend")
535
- test(
536
- QwenPromptExpander,
537
- "",
538
- qwen_vl_model_name,
539
- task,
540
- image=image,
541
- en_prompt=None,
542
- seed=seed)
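
All expanders above share the PromptExpander.__call__ interface, so callers only choose a backend and a task. A minimal usage sketch, assuming the module is importable as wan.utils.prompt_extend and that either a local Qwen checkpoint or a valid DASH_API_KEY is available:

from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

# Local Qwen backend: resolves to Qwen/Qwen2.5-14B-Instruct by default;
# pass model_name="/path/to/checkpoint" to use a local copy instead.
expander = QwenPromptExpander(task="t2v-A14B", is_vl=False, device=0)
result = expander("a white cat surfing at the beach", tar_lang="en", seed=42)
print(result.status, result.prompt)

# DashScope backend with an input image (requires the DASH_API_KEY environment variable).
vl_expander = DashScopePromptExpander(task="i2v-A14B", is_vl=True)
result = vl_expander("", image="./examples/i2v_input.JPG", tar_lang="zh")
print(result.prompt if result.status else result.message)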