import os
import shutil

import gradio as gr
from omegaconf import OmegaConf, ListConfig

from train import main as train_main
from inference import inference as inference_main
# Train motion embeddings on the uploaded video.
def train_model(video, config):
    output_dir = 'results'
    cur_save_dir = os.path.join(output_dir, 'custom')
    os.makedirs(cur_save_dir, exist_ok=True)

    config.dataset.single_video_path = video
    config.train.output_dir = cur_save_dir

    # Keep a copy of the source video next to its checkpoints so the
    # Inference tab can preview it.
    shutil.copy(video, os.path.join(cur_save_dir, 'source.mp4'))

    train_main(config)
    return cur_save_dir
# Generate a video from a text prompt using a trained motion-embedding checkpoint.
def inference_model(text, checkpoint, inference_steps, video_type, seed):
    checkpoint = os.path.join('results', checkpoint)
    embedding_dir = os.path.dirname(checkpoint)
    video_round = os.path.basename(checkpoint)
    video_path = inference_main(
        embedding_dir=embedding_dir,
        prompt=text,
        video_round=video_round,
        save_dir=os.path.join('outputs', os.path.basename(embedding_dir)),
        motion_type=video_type,
        seed=seed,
        inference_steps=inference_steps,
    )
    return video_path
# Collect trained checkpoints: every directory under checkpoint_dir that
# contains a motion_embed.pt file, returned as '<video_name>/<round>'.
def get_checkpoints(checkpoint_dir):
    checkpoints = []
    for root, dirs, files in os.walk(checkpoint_dir):
        for file in files:
            if file == 'motion_embed.pt':
                checkpoints.append('/'.join(root.split('/')[-2:]))
    return checkpoints
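# A minimal sketch of the results layout this walk assumes ('pan_up' is
# illustrative, taken from the bundled examples):
#
#   results/pan_up/source.mp4
#   results/pan_up/checkpoint/motion_embed.pt   -> yields 'pan_up/checkpoint'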
def extract_combinations(motion_embeddings_combinations):
    assert len(motion_embeddings_combinations) > 0, "At least one motion embedding combination is required"
    combinations = []
    for combination in motion_embeddings_combinations:
        name, resolution = combination.split(" ")
        combinations.append([name, int(resolution)])
    return combinations
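# For example, given the '<block> <resolution>' strings offered in the UI:
#   extract_combinations(['down 1280', 'up 640']) == [['down', 1280], ['up', 640]]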
def generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    default_config = OmegaConf.load('configs/config.yaml')
    default_config.model.motion_embeddings.combinations = ListConfig(extract_combinations(motion_embeddings_combinations))
    default_config.model.unet = unet
    default_config.train.checkpointing_steps = checkpointing_steps
    default_config.train.max_train_steps = max_train_steps
    return default_config

# Currently identical to generate_config_train; kept as a separate entry
# point so the two paths can diverge later.
def generate_config_inference(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps):
    return generate_config_train(motion_embeddings_combinations, unet, checkpointing_steps, max_train_steps)
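# The fields written above imply at least this shape for configs/config.yaml.
# A hedged sketch of the assumed structure, not the full file:
#
#   model:
#     unet: videoCrafter2
#     motion_embeddings:
#       combinations: [[down, 1280], [up, 1280]]
#   train:
#     output_dir: results/custom
#     checkpointing_steps: 100
#     max_train_steps: 200
#   dataset:
#     single_video_path: ...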
# Map a checkpoint choice such as 'pan_up/checkpoint' to its source video,
# e.g. 'results/pan_up/source.mp4'.
def update_preview_video(checkpoint_dir):
    parent_dir = os.path.dirname(checkpoint_dir)
    return gr.update(value=f'results/{parent_dir}/source.mp4')
if __name__ == "__main__":
    # Start from a clean slate for user-trained checkpoints and outputs.
    if os.path.exists('results/custom'):
        shutil.rmtree('results/custom')
    if os.path.exists('outputs'):
        shutil.rmtree('outputs')

    inject_motion_embeddings_combinations = ['down 1280', 'up 1280', 'down 640', 'up 640']
    default_motion_embeddings_combinations = ['down 1280', 'up 1280']

    examples_train = [
        'assets/train/car_turn.mp4',
        'assets/train/pan_up.mp4',
        'assets/train/run_up.mp4',
        'assets/train/train_ride.mp4',
        'assets/train/orbit_shot.mp4',
        'assets/train/dolly_zoom_out.mp4',
        'assets/train/santa_dance.mp4',
    ]

    examples_inference = [
        ['results/pan_up/source.mp4', 'A flora garden.', 'camera', 'pan_up/checkpoint'],
        ['results/dolly_zoom/source.mp4', 'A firefighter standing in front of a burning forest captured with a dolly zoom.', 'camera', 'dolly_zoom/checkpoint'],
        ['results/orbit_shot/source.mp4', 'A micro garden with an orbit shot', 'camera', 'orbit_shot/checkpoint'],
        ['results/walk/source.mp4', 'An elephant walking in the desert', 'object', 'walk/checkpoint'],
        ['results/santa_dance/source.mp4', 'A skeleton in a suit is dancing with his hands', 'object', 'santa_dance/checkpoint'],
        ['results/car_turn/source.mp4', 'A toy train chugs around a roundabout tree', 'object', 'car_turn/checkpoint'],
        ['results/train_ride/source.mp4', 'A motorbike driving in a forest', 'object', 'train_ride/checkpoint'],
    ]
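    # Each entry above is [preview video, prompt, motion type, checkpoint],
    # matching the Examples inputs wired up in the Inference tab below.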
    # Build the Gradio interface.
    with gr.Blocks() as demo:
        with gr.Tab("Train"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    train_button = gr.Button("Train")
                with gr.Column():
                    checkpoint_output = gr.Textbox(label="Checkpoint Directory")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    motion_embeddings_combinations = gr.Dropdown(label="Motion Embeddings Combinations", choices=inject_motion_embeddings_combinations, multiselect=True, value=default_motion_embeddings_combinations)
                    unet_dropdown = gr.Dropdown(label="Unet", choices=["videoCrafter2", "zeroscope_v2_576w"], value="videoCrafter2")
                    checkpointing_steps = gr.Dropdown(label="Checkpointing Steps", choices=[100, 50], value=100)
                    max_train_steps = gr.Slider(label="Max Train Steps", minimum=200, maximum=500, value=200, step=50)

            gr.Examples(examples=examples_train, inputs=[video_input])

            train_button.click(
                lambda video, mec, u, cs, mts: train_model(video, generate_config_train(mec, u, cs, mts)),
                inputs=[video_input, motion_embeddings_combinations, unet_dropdown, checkpointing_steps, max_train_steps],
                outputs=checkpoint_output
            )

        with gr.Tab("Inference"):
            with gr.Row():
                with gr.Column():
                    preview_video = gr.Video(label="Preview Video")
                    text_input = gr.Textbox(label="Input Text")
                    checkpoint_dropdown = gr.Dropdown(label="Select Checkpoint", choices=get_checkpoints('results'))
                    seed = gr.Number(label="Seed", value=0)
                    inference_button = gr.Button("Generate Video")
                with gr.Column():
                    output_video = gr.Video(label="Output Video")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    inference_steps = gr.Number(label="Inference Steps", value=30)
                    motion_type = gr.Dropdown(label="Motion Type", choices=["camera", "object"], value="object")

            gr.Examples(examples=examples_inference, inputs=[preview_video, text_input, motion_type, checkpoint_dropdown])

            # Refresh the checkpoint list once training reports its output directory.
            def update_checkpoints(checkpoint_dir):
                return gr.update(choices=get_checkpoints('results'))

            checkpoint_dropdown.change(fn=update_preview_video, inputs=checkpoint_dropdown, outputs=preview_video)
            checkpoint_output.change(update_checkpoints, inputs=checkpoint_output, outputs=checkpoint_dropdown)
            inference_button.click(inference_model, inputs=[text_input, checkpoint_dropdown, inference_steps, motion_type, seed], outputs=output_video)

    # Launch the Gradio interface.
    demo.launch()