#!/usr/bin/env python3
"""
Convert a Wan2.1 model checkpoint from ComfyUI format to diffusers format.

Usage:
    python convert_comfyui_to_diffusers.py \
        --input /path/to/aniWan2114BFp8E4m3fn_i2v480pNew.safetensors \
        --output ./weights/aniWan2114B_diffusers/ \
        --dtype bfloat16 \
        --split-size 4.0
"""

import argparse
import json
import math
import os
import re
import shutil
from collections import OrderedDict
from typing import Optional

import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm

# Every ComfyUI transformer key carries this prefix; diffusers keys do not.
COMFY_PREFIX = "model.diffusion_model."


def build_key_mapping():
    """Build the ComfyUI -> diffusers key-mapping rules.

    Returns:
        A function converting a single ComfyUI state-dict key to its
        diffusers equivalent (keys with no matching rule pass through).
    """

    def convert_key(comfy_key: str) -> str:
        """Convert one ComfyUI key to a diffusers key."""
        key = comfy_key

        # Strip the prefix: "model.diffusion_model." -> "".
        # Guarded slice instead of str.replace so an accidental mid-string
        # occurrence could never be touched.
        if key.startswith(COMFY_PREFIX):
            key = key[len(COMFY_PREFIX):]

        # Attention type:
        #   cross_attn -> attn2 (cross attention with text/image)
        #   self_attn  -> attn1 (self attention)
        key = key.replace(".cross_attn.", ".attn2.")
        key = key.replace(".self_attn.", ".attn1.")

        # Attention projections:
        #   ComfyUI:   .q / .k / .v / .o
        #   diffusers: .to_q / .to_k / .to_v / .to_out.0
        key = re.sub(r'\.q\.(weight|bias)', r'.to_q.\1', key)
        key = re.sub(r'\.k\.(weight|bias)', r'.to_k.\1', key)
        key = re.sub(r'\.v\.(weight|bias)', r'.to_v.\1', key)
        key = re.sub(r'\.o\.(weight|bias)', r'.to_out.0.\1', key)

        # Image cross attention (I2V):
        #   ComfyUI:   k_img / v_img
        #   diffusers: add_k_proj / add_v_proj
        key = re.sub(r'\.k_img\.(weight|bias)', r'.add_k_proj.\1', key)
        key = re.sub(r'\.v_img\.(weight|bias)', r'.add_v_proj.\1', key)

        # QK-norm keys: norm_k_img -> norm_added_k; norm_q / norm_k stay as-is.
        key = key.replace('.norm_k_img.', '.norm_added_k.')

        # FFN / MLP:
        #   ComfyUI:   ffn.0 / ffn.2
        #   diffusers: ffn.net.0.proj / ffn.net.2
        key = re.sub(r'\.ffn\.0\.', '.ffn.net.0.proj.', key)
        key = re.sub(r'\.ffn\.2\.', '.ffn.net.2.', key)

        # Block norms:
        #   ComfyUI norm3 -> diffusers norm2
        #   ComfyUI modulation -> diffusers scale_shift_table (block-level only)
        key = key.replace('.norm3.', '.norm2.')
        key = re.sub(r'blocks\.(\d+)\.modulation', r'blocks.\1.scale_shift_table', key)

        # Output head:
        #   head.head       -> proj_out
        #   head.modulation -> scale_shift_table (root level)
        key = key.replace('head.head.weight', 'proj_out.weight')
        key = key.replace('head.head.bias', 'proj_out.bias')
        key = key.replace('head.modulation', 'scale_shift_table')

        # Time / text embeddings:
        #   time_embedding.{0,2} -> condition_embedder.time_embedder.linear_{1,2}
        #   text_embedding.{0,2} -> condition_embedder.text_embedder.linear_{1,2}
        key = key.replace('time_embedding.0.', 'condition_embedder.time_embedder.linear_1.')
        key = key.replace('time_embedding.2.', 'condition_embedder.time_embedder.linear_2.')
        key = key.replace('text_embedding.0.', 'condition_embedder.text_embedder.linear_1.')
        key = key.replace('text_embedding.2.', 'condition_embedder.text_embedder.linear_2.')

        # Time projection: time_projection.1 -> condition_embedder.time_proj
        key = key.replace('time_projection.1.', 'condition_embedder.time_proj.')

        # Image embedding (I2V only). Mapping verified against tensor shapes:
        #   img_emb.proj.0 (1280,)       -> image_embedder.norm1
        #   img_emb.proj.1 (1280, 1280)  -> image_embedder.ff.net.0.proj
        #   img_emb.proj.3 (5120, 1280)  -> image_embedder.ff.net.2
        #   img_emb.proj.4 (5120,)       -> image_embedder.norm2
        key = key.replace('img_emb.proj.0.', 'condition_embedder.image_embedder.norm1.')
        key = key.replace('img_emb.proj.1.', 'condition_embedder.image_embedder.ff.net.0.proj.')
        key = key.replace('img_emb.proj.3.', 'condition_embedder.image_embedder.ff.net.2.')
        key = key.replace('img_emb.proj.4.', 'condition_embedder.image_embedder.norm2.')

        return key

    return convert_key


def analyze_model_structure(input_path: str) -> dict:
    """Inspect the checkpoint and return a structural summary.

    Returns a dict with key count, source dtype, a few sample keys and the
    distinct 3-component key prefixes (an overview of the layout).
    """
    print(f"Analyzing model structure: {input_path}")
    with safe_open(input_path, framework="pt") as f:
        keys = list(f.keys())
        # Load a single tensor only to discover the source dtype.
        sample_tensor = f.get_tensor(keys[0])
        dtype = sample_tensor.dtype
        prefixes = set()
        for key in keys:
            parts = key.split('.')
            if len(parts) >= 3:
                prefixes.add('.'.join(parts[:3]))
    return {
        'total_keys': len(keys),
        'dtype': dtype,
        'sample_keys': keys[:10],
        'prefixes': sorted(prefixes)[:20],
    }


def convert_dtype(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor:
    """Cast ``tensor`` to ``target_dtype``; FP8 inputs go through FP32 first."""
    if tensor.dtype == target_dtype:
        return tensor
    if tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # FP8 -> FP32 -> target (avoid a direct FP8 cast).
        return tensor.to(torch.float32).to(target_dtype)
    return tensor.to(target_dtype)


def _save_state_dict(state_dict, output_dir: str, split_size_gb: float) -> None:
    """Save ``state_dict`` as a single file or as size-limited shards plus index.

    Shards are split by key count, not by byte size, so shard sizes are
    only approximately equal.
    """
    total_size = sum(t.numel() * t.element_size() for t in state_dict.values())
    total_size_gb = total_size / (1024 ** 3)
    print(f"\nTotal model size: {total_size_gb:.2f} GB")

    if total_size_gb <= split_size_gb:
        # Small enough: single file.
        output_path = os.path.join(output_dir, "diffusion_pytorch_model.safetensors")
        print(f"Saving to {output_path}...")
        save_file(state_dict, output_path)
        return

    # BUGFIX: the old count int(total/split) + 1 produced one extra shard
    # when the size divided evenly; ceil is exact. Also clamp to the key
    # count so no shard can be empty.
    num_shards = max(1, min(math.ceil(total_size_gb / split_size_gb), len(state_dict)))
    print(f"Splitting into {num_shards} shards...")

    keys_per_shard = len(state_dict) // num_shards
    all_keys_list = list(state_dict.keys())
    weight_map = {}

    for shard_idx in range(num_shards):
        start_idx = shard_idx * keys_per_shard
        # Last shard absorbs the remainder.
        end_idx = start_idx + keys_per_shard if shard_idx < num_shards - 1 else len(all_keys_list)
        shard_keys = all_keys_list[start_idx:end_idx]
        shard_dict = OrderedDict((k, state_dict[k]) for k in shard_keys)

        shard_filename = f"diffusion_pytorch_model-{shard_idx + 1:05d}-of-{num_shards:05d}.safetensors"
        print(f"Saving {shard_filename}...")
        save_file(shard_dict, os.path.join(output_dir, shard_filename))
        for k in shard_keys:
            weight_map[k] = shard_filename

    # Hugging Face sharded-checkpoint index.
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    index_path = os.path.join(output_dir, "diffusion_pytorch_model.safetensors.index.json")
    with open(index_path, 'w') as f:
        json.dump(index, f, indent=2)
    print(f"Saved index: {index_path}")


def _write_config(input_path: str, output_dir: str) -> None:
    """Copy config.json from beside the input file, or write a default Wan2.1 I2V config."""
    config_src = os.path.join(os.path.dirname(input_path), "config.json")
    config_dst = os.path.join(output_dir, "config.json")
    if os.path.exists(config_src):
        shutil.copy(config_src, config_dst)
        print("Copied config.json")
        return

    # Default Wan2.1 14B I2V transformer config.
    default_config = {
        "_class_name": "WanTransformer3DModel",
        "_diffusers_version": "0.33.0.dev0",
        "added_kv_proj_dim": 5120,
        "attention_head_dim": 128,
        "cross_attn_norm": True,
        "eps": 1e-06,
        "ffn_dim": 13824,
        "freq_dim": 256,
        "image_dim": 1280,
        "in_channels": 36,
        "num_attention_heads": 40,
        "num_layers": 40,
        "out_channels": 16,
        "patch_size": [1, 2, 2],
        "qk_norm": "rms_norm_across_heads",
        "rope_max_seq_len": 1024,
        "text_dim": 4096,
    }
    with open(config_dst, 'w') as f:
        json.dump(default_config, f, indent=2)
    print("Created default config.json")


def convert_model(
    input_path: str,
    output_dir: str,
    target_dtype: torch.dtype = torch.bfloat16,
    split_size_gb: float = 2.0,
    dry_run: bool = False,
):
    """Convert a ComfyUI-format checkpoint to diffusers format.

    Args:
        input_path: Input safetensors file path (ComfyUI format).
        output_dir: Output directory for the diffusers-format weights.
        target_dtype: Output dtype (default: bfloat16).
        split_size_gb: Shard size threshold in GB.
        dry_run: If True, only display the key mapping; do not convert.

    Returns:
        The old-key -> new-key mapping dict.
    """
    convert_key = build_key_mapping()

    print(f"Input: {input_path}")
    print(f"Output: {output_dir}")
    print(f"Target dtype: {target_dtype}")

    info = analyze_model_structure(input_path)
    print(f"Total keys: {info['total_keys']}")
    print(f"Source dtype: {info['dtype']}")

    print("\nBuilding key mapping...")
    key_mapping = {}
    unmapped_keys = []
    with safe_open(input_path, framework="pt") as f:
        all_keys = list(f.keys())
    for old_key in all_keys:
        new_key = convert_key(old_key)
        key_mapping[old_key] = new_key
        # BUGFIX: the original check (old_key == new_key for prefixed keys)
        # was unsatisfiable because convert_key always strips the prefix.
        # Compare against the prefix-stripped key instead so keys no mapping
        # rule touched are actually reported. NOTE(review): some keys (e.g.
        # patch_embedding) may be intentionally identical in both formats —
        # this warning is informational only.
        if old_key.startswith(COMFY_PREFIX) and new_key == old_key[len(COMFY_PREFIX):]:
            unmapped_keys.append(old_key)

    print("\nSample key mappings:")
    for old, new in list(key_mapping.items())[:15]:
        print(f" {old}")
        print(f" -> {new}")

    if unmapped_keys:
        print(f"\nWarning: {len(unmapped_keys)} keys were not mapped:")
        for key in unmapped_keys[:5]:
            print(f" {key}")

    if dry_run:
        print("\n[DRY RUN] Skipping actual conversion")
        return key_mapping

    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): the entire converted state dict is held in memory before
    # saving; for very large checkpoints a shard-by-shard streaming pass
    # would lower peak RSS.
    print("\nConverting tensors...")
    converted_state_dict = OrderedDict()
    with safe_open(input_path, framework="pt") as f:
        for old_key in tqdm(all_keys, desc="Converting"):
            tensor = f.get_tensor(old_key)
            converted_state_dict[key_mapping[old_key]] = convert_dtype(tensor, target_dtype)

    _save_state_dict(converted_state_dict, output_dir, split_size_gb)
    _write_config(input_path, output_dir)

    print("\nConversion complete!")
    print(f"Output directory: {output_dir}")
    return key_mapping


def verify_conversion(converted_dir: str, reference_model: Optional[str] = None):
    """Sanity-check the converted output: key count, config values, sample keys."""
    print(f"\nVerifying conversion: {converted_dir}")

    # Union of keys across all shards in the output directory.
    converted_keys = set()
    for fname in os.listdir(converted_dir):
        if fname.endswith('.safetensors'):
            path = os.path.join(converted_dir, fname)
            with safe_open(path, framework="pt") as sf:
                converted_keys.update(sf.keys())
    print(f"Converted model has {len(converted_keys)} keys")

    # Optional: confirm diffusers is importable and echo the written config.
    if reference_model:
        try:
            from diffusers import WanTransformer3DModel  # noqa: F401 — availability check only
            print(f"Loading reference model structure from {reference_model}...")
            config_path = os.path.join(converted_dir, "config.json")
            with open(config_path) as cfg_file:
                config = json.load(cfg_file)
            print("Config loaded successfully")
            print(f" num_layers: {config.get('num_layers')}")
            print(f" num_attention_heads: {config.get('num_attention_heads')}")
        except Exception as e:
            print(f"Could not load reference model: {e}")

    print("\nSample converted keys:")
    for key in sorted(converted_keys)[:10]:
        print(f" {key}")


def main():
    """CLI entry point."""
    parser = argparse.ArgumentParser(
        description="Convert ComfyUI Wan2.1 model to diffusers format"
    )
    parser.add_argument(
        "--input", "-i", required=True,
        help="Input safetensors file (ComfyUI format)"
    )
    parser.add_argument(
        "--output", "-o", required=True,
        help="Output directory for diffusers format"
    )
    parser.add_argument(
        "--dtype", choices=["bfloat16", "float16", "float32"],
        default="bfloat16", help="Target dtype (default: bfloat16)"
    )
    parser.add_argument(
        "--split-size", type=float, default=2.0,
        help="Split size in GB (default: 2.0)"
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Show key mapping without converting"
    )
    parser.add_argument(
        "--verify", action="store_true",
        help="Verify conversion after completion"
    )
    args = parser.parse_args()

    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }

    convert_model(
        input_path=args.input,
        output_dir=args.output,
        target_dtype=dtype_map[args.dtype],
        split_size_gb=args.split_size,
        dry_run=args.dry_run,
    )

    if args.verify and not args.dry_run:
        verify_conversion(args.output)


if __name__ == "__main__":
    main()