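# Gradio demo for OnlyFlow: optical-flow-based motion conditioning for video diffusion models.
# The app extracts optical flow from an input video with RAFT and uses it to condition an
# AnimateDiff-style video diffusion pipeline (FlowCtrlPipeline).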
import os
import imageio
import numpy as np
import torch
import random
import spaces
import gradio as gr
import torchvision
import torchvision.transforms as T
from einops import rearrange
from huggingface_hub import hf_hub_download
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
from torchvision.utils import flow_to_image
from diffusers import AutoencoderKL, MotionAdapter, UNet2DConditionModel
from diffusers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

from onlyflow.models.flow_adaptor import FlowEncoder, FlowAdaptor
from onlyflow.models.unet import UNetMotionModel
from onlyflow.pipelines.pipeline_animation_long import FlowCtrlPipeline
from tools.optical_flow import get_optical_flow
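# Utility: tile a batch of videos into a grid and save it as an animation with imageio.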
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
class AnimateController:
    def __init__(self):
        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(self.basedir, "samples")
        os.makedirs(self.savedir, exist_ok=True)

        ckpt_path = hf_hub_download('obvious-research/onlyflow', 'weights_fp16.ckpt')
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        self.flow_encoder_state_dict = ckpt['flow_encoder_state_dict']
        self.attention_processor_state_dict = ckpt['attention_processor_state_dict']

        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.motion_adapter = None

    def update_base_model(self, base_model_id, progress=gr.Progress()):
        progress(0, desc="Starting...")
        self.tokenizer = CLIPTokenizer.from_pretrained(base_model_id, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder")
        self.vae = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")
        return base_model_id

    def update_motion_module(self, motion_module_id, progress=gr.Progress()):
        self.motion_adapter = MotionAdapter.from_pretrained(motion_module_id)
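    # Assemble the OnlyFlow pipeline: wrap the base UNet with the motion adapter, attach the
    # flow encoder and the flow-conditioned attention processors, then run generation.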
    def animate(
            self,
            id_base_model,
            id_motion_module,
            prompt_textbox_positive,
            prompt_textbox_negative,
            seed_textbox,
            input_video,
            height,
            width,
            flow_scale,
            cfg,
            diffusion_steps,
            temporal_ds,
            ctx_stride
    ):
        # if any([x is None for x in [self.tokenizer, self.text_encoder, self.vae, self.unet, self.motion_adapter]]) or isinstance(self.unet, str):
        self.update_base_model(id_base_model)
        self.update_motion_module(id_motion_module)

        self.unet = UNetMotionModel.from_unet2d(
            self.unet,
            motion_adapter=self.motion_adapter
        )

        self.raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval()

        self.flow_encoder = FlowEncoder(
            downscale_factor=8,
            channels=[320, 640, 1280, 1280],
            nums_rb=2,
            ksize=1,
            sk=True,
            use_conv=False,
            compression_factor=1,
            temporal_attention_nhead=8,
            positional_embeddings="sinusoidal",
            num_positional_embeddings=16,
            checkpointing=False
        ).eval()

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.raft.requires_grad_(False)
        self.flow_encoder.requires_grad_(False)

        self.unet.set_all_attn(
            flow_channels=[320, 640, 1280, 1280],
            add_spatial=False,
            add_temporal=True,
            encoder_only=False,
            query_condition=True,
            key_value_condition=True,
            flow_scale=1.0,
        )

        self.flow_adaptor = FlowAdaptor(self.unet, self.flow_encoder).eval()

        # load the flow encoder weights
        pose_enc_m, pose_enc_u = self.flow_adaptor.flow_encoder.load_state_dict(
            self.flow_encoder_state_dict,
            strict=False
        )
        assert len(pose_enc_m) == 0 and len(pose_enc_u) == 0

        # load the attention processor weights
        _, attention_processor_u = self.flow_adaptor.unet.load_state_dict(
            self.attention_processor_state_dict,
            strict=False
        )
        assert len(attention_processor_u) == 0

        pipeline = FlowCtrlPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            unet=self.unet,
            motion_adapter=self.motion_adapter,
            flow_encoder=self.flow_encoder,
            scheduler=DDIMScheduler.from_pretrained(id_base_model, subfolder="scheduler"),
        )

        if int(seed_textbox) > 0:
            seed = int(seed_textbox)
        else:
            seed = random.randint(1, int(1e16))

        return animate_diffusion(seed, pipeline, self.raft, input_video, prompt_textbox_positive, prompt_textbox_negative, width, height, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride)
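# Run the full generation: preprocess the input video, extract optical flow with RAFT,
# pick a windowing/chunking configuration if the clip is long, and render the output videos.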
@spaces.GPU  # assumed ZeroGPU decorator: implied by the otherwise unused `spaces` import and the ZeroGPU quota notes in the UI
def animate_diffusion(seed, pipeline, raft_model, base_video, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, flow_scale, cfg, diffusion_steps, temporal_ds, context_stride):
    savedir = './samples'
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)

    raft_model = raft_model.to(device)
    pipeline = pipeline.to(device)

    pixel_values = torchvision.io.read_video(base_video, output_format="TCHW", pts_unit='sec')[0][::temporal_ds]
    print("Video loaded, shape:", pixel_values.shape)

    if width_slider / height_slider > pixel_values.shape[3] / pixel_values.shape[2]:
        print("Resizing video to fit width because the input video is not wide enough")
        temp_height = int(width_slider * pixel_values.shape[2] / pixel_values.shape[3])
        temp_width = width_slider
    else:
        print("Resizing video to fit height because the input video is not tall enough")
        temp_height = height_slider
        temp_width = int(height_slider * pixel_values.shape[3] / pixel_values.shape[2])

    print("Resizing video to:", temp_height, temp_width)
    pixel_values = T.Resize((temp_height, temp_width))(pixel_values)
    pixel_values = T.CenterCrop((height_slider, width_slider))(pixel_values)
    pixel_values = T.ConvertImageDtype(torch.float32)(pixel_values)[None, ...].contiguous().to(device)

    save_sample_path_input = os.path.join(savedir, "input.mp4")
    pixel_values_save = pixel_values[0] * 255
    pixel_values_save = pixel_values_save.cpu()
    pixel_values_save = torch.permute(pixel_values_save, (0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_input, pixel_values_save, fps=8)
    del pixel_values_save

    print("Video preprocessed, shape:", pixel_values.shape)
    flow = get_optical_flow(
        raft_model,
        (pixel_values * 2) - 1,
        pixel_values.shape[1] - 1,
        encode_chunk_size=16,
    ).to('cpu')

    sample_flow = flow_to_image(rearrange(flow[0], "c f h w -> f c h w"))  # N, 3, H, W
    save_sample_path_flow = os.path.join(savedir, "flow.mp4")
    sample_flow = sample_flow.cpu().to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(save_sample_path_flow, sample_flow, fps=8)
    del sample_flow

    original_flow_shape = flow.shape
    print("Optical flow computed, shape:", flow.shape)
    if flow.shape[2] < 16:
        print("Video is too short, padding to 16 frames")
        video_length = 16
        n = 16 - flow.shape[2]
        # create a tensor containing the last frame optical flow repeated n times
        to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1)
        flow = torch.cat([flow, to_add], dim=2).to(device)
    elif flow.shape[2] > 16:
        print("Video is too long, enabling windowing")

        print("Enabling model CPU offload")
        pipeline.enable_model_cpu_offload()
        print("Enabling VAE slicing")
        pipeline.enable_vae_slicing()
        print("Enabling VAE tiling")
        pipeline.enable_vae_tiling()
        print("Enabling free noise")
        pipeline.enable_free_noise(
            context_length=16,
            context_stride=context_stride,
        )
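        # Helpers to choose forward-chunking and FreeNoise split sizes that divide evenly into
        # the latent dimensions and the (possibly extended) video length.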
        import math

        def find_divisors(n: int):
            """
            Return sorted list of all positive divisors of n.
            Uses a sqrt(n) approach for efficiency.
            """
            divs = set()
            limit = int(math.isqrt(n))
            for i in range(1, limit + 1):
                if n % i == 0:
                    divs.add(i)
                    divs.add(n // i)
            return sorted(divs)

        def multiples_in_range(k: int, min_val: int, max_val: int):
            """
            Return all multiples of k within [min_val, max_val].
            """
            if k == 0:
                return []
            # First multiple of k >= min_val
            start = ((min_val + k - 1) // k) * k
            # Last multiple of k <= max_val
            end = (max_val // k) * k
            return list(range(start, end + 1, k)) if start <= end else []
        def adjust_video_length(original_length: int,
                                context_stride: int,
                                chunk_size: int,
                                temporal_split_size: int) -> int:
            """
            Find the minimal video_length >= original_length satisfying:
              1) (video_length - 16) is divisible by context_stride.
              2) EITHER (2*video_length) is divisible by temporal_split_size
                 OR (2*video_length) is divisible by chunk_size
                 (when 2*video_length is not a multiple of temporal_split_size).
            """
            # We start at least at 16 (though in practice original_length likely > 16)
            candidate = max(original_length, 16)

            # We want (candidate - 16) % context_stride == 0,
            # so we step by `context_stride` beyond 16.
            # This ensures (candidate - 16) is a multiple of context_stride.
            # Then we check the second condition, else keep stepping.

            # If candidate < 16, bump it to 16
            if candidate < 16:
                candidate = 16

            # Make sure we jump to the correct "starting multiple" of context_stride
            offset = (candidate - 16) % context_stride
            if offset != 0:
                candidate += (context_stride - offset)  # jump to the next multiple

            while True:
                # Condition: (candidate - 16) is a multiple of context_stride (already enforced by stepping)
                # Check second part:
                #   - if (2*candidate) % temporal_split_size == 0, we are good
                #   - else we require (2*candidate) % chunk_size == 0
                twoL = 2 * candidate
                if (twoL % temporal_split_size == 0) or (twoL % chunk_size == 0):
                    return candidate
                # Go to next valid candidate
                candidate += context_stride
        def find_valid_configs(original_video_length: int,
                               width: int,
                               height: int,
                               context_stride: int):
            """
            Generate all valid tuples (chunk_size, spatial_split_size, temporal_split_size, video_length)
            subject to the constraints:
              1) chunk_size divides temporal_split_size
              2) chunk_size divides spatial_split_size
              3) chunk_size divides (2 * (width//64) * (height//64))
              4) if (2*video_length) % temporal_split_size != 0, then chunk_size divides (2*video_length)
              5) context_stride divides (video_length - 16)
              6) 480 <= spatial_split_size <= 512
              7) 1 <= temporal_split_size <= 32
              8) 1 <= chunk_size <= 32
            We allow increasing original_video_length minimally if needed to satisfy constraints #4 and #5.
            """
            factor = 2 * (width // 64) * (height // 64)

            # 1) find all possible chunk_size values as divisors of factor, in [1..32]
            possible_chunks = [d for d in find_divisors(factor) if 1 <= d <= 32]

            # For storing results
            valid_tuples = []

            for chunk_size in possible_chunks:
                # 2) generate all spatial_split_size in [480..512] that are multiples of chunk_size
                spatial_splits = multiples_in_range(chunk_size, 480, 512)
                # 3) generate all temporal_split_size in [1..32] that are multiples of chunk_size
                temporal_splits = multiples_in_range(chunk_size, 1, 32)

                for ssp in spatial_splits:
                    for tsp in temporal_splits:
                        # 4) & 5) Adjust video_length minimally to satisfy constraints
                        final_length = adjust_video_length(original_video_length,
                                                           context_stride,
                                                           chunk_size,
                                                           tsp)
                        # Now we have a valid (chunk_size, ssp, tsp, final_length)
                        valid_tuples.append((chunk_size, ssp, tsp, final_length))

            return valid_tuples
        def find_pareto_optimal(configs):
            """
            Given a list of tuples (chunk_size, spatial_split_size, temporal_split_size, video_length),
            return the Pareto-optimal subset under the criteria:
              - chunk_size: larger is better
              - temporal_split_size: larger is better
              - video_length: smaller is better
            (spatial_split_size is not part of the dominance comparison.)
            """
            def dominates(A, B):
                cA, sA, tA, lA = A
                cB, sB, tB, lB = B
                # A dominates B if:
                #   cA >= cB, tA >= tB, and lA <= lB
                #   AND at least one of these is a strict inequality.
                better_or_equal = (cA >= cB) and (tA >= tB) and (lA <= lB)
                strictly_better = (cA > cB) or (tA > tB) or (lA < lB)
                return better_or_equal and strictly_better

            pareto = []
            for i, cfg_i in enumerate(configs):
                # Check if cfg_i is dominated by any cfg_j
                is_dominated = False
                for j, cfg_j in enumerate(configs):
                    if i == j:
                        continue
                    if dominates(cfg_j, cfg_i):
                        is_dominated = True
                        break
                if not is_dominated:
                    pareto.append(cfg_i)
            return pareto
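        # Enumerate feasible configurations, keep the Pareto front, then pick the candidate that
        # favors large chunk/temporal split sizes while staying close to the original frame count.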
| print("Finding valid configurations...") | |
| valid_configs = find_valid_configs( | |
| original_video_length=flow.shape[2], | |
| width=width_slider, | |
| height=height_slider, | |
| context_stride=context_stride | |
| ) | |
| print("Found", len(valid_configs), "valid configurations") | |
| print("Finding Pareto-optimal configurations...") | |
| pareto_optimal = find_pareto_optimal(valid_configs) | |
| print("Found", pareto_optimal) | |
| criteria = lambda cs, sss, tss, vl: cs + tss - 3 * int(abs(flow.shape[2] - vl) / 10) | |
| pareto_optimal.sort(key=lambda x: criteria(*x), reverse=True) | |
| print("Found sorted", pareto_optimal) | |
| solution = pareto_optimal[0] | |
| chunk_size, spatial_split_size, temporal_split_size, video_length = solution | |
| n = video_length - original_flow_shape[2] | |
| to_add = flow[:, :, -1].unsqueeze(2).expand(-1, -1, n, -1, -1) | |
| flow = torch.cat([flow, to_add], dim=2) | |
| pipeline.enable_free_noise_split_inference( | |
| temporal_split_size=temporal_split_size, | |
| spatial_split_size=spatial_split_size | |
| ) | |
| pipeline.unet.enable_forward_chunking(chunk_size) | |
| print("Chunking enabled with chunk size:", chunk_size) | |
| print("Temporal split size:", temporal_split_size) | |
| print("Spatial split size:", spatial_split_size) | |
| print("Context stride:", context_stride) | |
| print("Temporal downscale:", temporal_ds) | |
| print("Video length:", video_length) | |
| print("Flow shape:", flow.shape) | |
    else:
        print("Video is just right, no padding or windowing needed")

    flow = flow.to(device)
    video_length = flow.shape[2]

    sample_vid = pipeline(
        prompt_textbox,
        negative_prompt=negative_prompt_textbox,
        optical_flow=flow,
        num_inference_steps=diffusion_steps,
        guidance_scale=cfg,
        width=width_slider,
        height=height_slider,
        num_frames=video_length,
        val_scale_factor_temporal=flow_scale,
        generator=generator,
    ).frames[0]

    del flow
    if device == "cuda":
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    save_sample_path_video = os.path.join(savedir, "sample.mp4")
    sample_vid = sample_vid[:original_flow_shape[2]] * 255.
    sample_vid = sample_vid.cpu().numpy()
    sample_vid = np.transpose(sample_vid, axes=(0, 2, 3, 1))
    torchvision.io.write_video(save_sample_path_video, sample_vid, fps=8)

    return gr.Video(value=save_sample_path_flow), gr.Video(value=save_sample_path_video)
controller = AnimateController()
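# Map the input video's aspect ratio to the closest width/height pair that is a multiple of 64
# and whose total area is near 512x512.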
def find_closest_ratio(target_ratio):
    width_list = list(reversed(range(256, 1025, 64)))
    height_list = list(reversed(range(256, 1025, 64)))
    ratio_list = [(h, w, w / h) for h in height_list for w in width_list]
    ratio_list.sort(key=lambda x: abs(x[2] - target_ratio))
    ratio_list = list(filter(lambda x: x[2] == ratio_list[0][2], ratio_list))
    ratio_list.sort(key=lambda x: abs(x[0] * x[1] - 512 * 512))
    return ratio_list[0][:2]


def find_dimension(video):
    import av
    container = av.open(open(video, 'rb'))
    height, width = container.streams.video[0].height, container.streams.video[0].width
    target_ratio = width / height
    return find_closest_ratio(target_ratio)
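# Build the Gradio interface: model selection, prompts, flow/CFG controls, and the output videos.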
def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # <p style="text-align:center;">OnlyFlow: Optical Flow based Motion Conditioning for Video Diffusion Models</p>
            Mathis Koroglu, Hugo Caselles-Dupré, Guillaume Jeanneret Sanmiguel, Matthieu Cord<br>
            [Arxiv Report](https://arxiv.org/abs/2411.10501) | [Project Page](https://obvious-research.github.io/onlyflow/) | [Github](https://github.com/obvious-research/onlyflow/)
            """
        )

        gr.Markdown(
            """
            ### Quick Start:
            1. Select the desired `Base Model`.
            2. Select a `Motion Module`. We recommend trying guoyww/animatediff-motion-adapter-v1-5-3 for the best results.
            3. Provide a `Positive Prompt` and a `Negative Prompt`. You are encouraged to refer to each model's webpage on the HuggingFace Hub or CivitAI to learn how to write prompts for it.
            4. Upload a video to extract the optical flow from.
            5. Select a `Flow Scale` to modulate the optical flow conditioning from the input video.
            6. Select a `CFG` value and the number of `Diffusion Steps` to control prompt adherence and the quality of the generated video.
            7. Select a `Temporal Downsample` factor to reduce the number of frames taken from the input video.
            8. If you want to use a custom dimension, check the `Custom Dimension` box and adjust the `Width` and `Height` sliders.
            9. If the video is too long, you can adjust the generation window offset with the `Context Stride` slider.
            10. Click `Generate`, wait for ~1/3 min, and enjoy the result!

            If you get an error concerning GPU limits, please try again later when your ZeroGPU quota is reset, or try with a shorter video.
            Otherwise, you can also duplicate this Space and select a custom GPU plan.
            """
        )
        with gr.Row():
            with gr.Column():
                gr.Markdown("# INPUTS")

                with gr.Row(equal_height=True, show_progress=True):
                    base_model = gr.Dropdown(
                        label="Select or type a base model id",
                        choices=[
                            "stable-diffusion-v1-5/stable-diffusion-v1-5",
                            "digiplay/Photon_v1",
                        ],
                        interactive=True,
                        scale=4,
                        allow_custom_value=True,
                        show_label=True
                    )
                    base_model_btn = gr.Button(value="Update", scale=1, size='lg')

                with gr.Row(equal_height=True, show_progress=True):
                    motion_module = gr.Dropdown(
                        label="Select or type a motion module id",
                        choices=[
                            "guoyww/animatediff-motion-adapter-v1-5-3",
                            "guoyww/animatediff-motion-adapter-v1-5-2"
                        ],
                        interactive=True,
                        scale=4
                    )
                    motion_module_btn = gr.Button(value="Update", scale=1, size='lg')

                base_model_btn.click(fn=controller.update_base_model, inputs=[base_model])
                motion_module_btn.click(fn=controller.update_motion_module, inputs=[motion_module])

                prompt_textbox_positive = gr.Textbox(label="Positive Prompt", lines=3)
                prompt_textbox_negative = gr.Textbox(label="Negative Prompt", lines=2, value="worst quality, low quality, nsfw, logo")

                flow_scale = gr.Slider(label="Flow Scale", value=1.0, minimum=0, maximum=2, step=0.025)
                diffusion_steps = gr.Slider(label="Diffusion Steps", value=25, minimum=0, maximum=100, step=1)
                cfg = gr.Slider(label="CFG", value=7.5, minimum=0, maximum=30, step=0.1)
                temporal_ds = gr.Slider(label="Temporal Downsample", value=1, minimum=1, maximum=30, step=1)

                input_video = gr.Video(label="Input Video", interactive=True)

                ctx_stride = gr.State(12)

                with gr.Accordion("Advanced", open=False):
                    use_custom_dim = gr.Checkbox(label="Custom Dimension", value=False)

                    with gr.Row(equal_height=True):
                        height, width = gr.State(512), gr.State(512)
                        @gr.render(inputs=[use_custom_dim, input_video])  # assumed decorator: re-render the sliders when these inputs change
                        def render_custom_dim(use_custom_dim, input_video):
                            if input_video is not None:
                                loc_height, loc_width = find_dimension(input_video)
                            else:
                                loc_height, loc_width = 512, 512
                            slider_width = gr.Slider(label="Width", value=loc_width, minimum=256, maximum=1024,
                                                     step=64, visible=use_custom_dim)
                            slider_height = gr.Slider(label="Height", value=loc_height, minimum=256, maximum=1024,
                                                      step=64, visible=use_custom_dim)
                            slider_width.change(lambda x: x, inputs=[slider_width], outputs=[width])
                            slider_height.change(lambda x: x, inputs=[slider_height], outputs=[height])

                    with gr.Row():
                        @gr.render(inputs=[input_video])  # assumed decorator: show the Context Stride slider only for longer videos
                        def render_ctx_stride(input_video):
                            if input_video is not None:
                                video = open(input_video, 'rb')
                                import av
                                container = av.open(video)
                                num_frames = container.streams.video[0].frames
                                if num_frames > 17:
                                    stride_slider = gr.Slider(label="Context Stride", value=12, minimum=1, maximum=16, step=1)
                                    stride_slider.input(lambda x: x, inputs=[stride_slider], outputs=[ctx_stride])
                                    # check the hard limit (> 64) before the soft warning (> 32)
                                    if num_frames > 64:
                                        raise gr.Error(f"Video is too long ({num_frames} frames), please use a shorter video, increase the context stride, or select a custom GPU plan. The current parameters won't allow generation on ZeroGPU.")
                                    elif num_frames > 32:
                                        gr.Warning(f"Video is long ({num_frames} frames), consider using a shorter video, increasing the context stride, or selecting a custom GPU plan.")
                with gr.Row(equal_height=True):
                    seed_textbox = gr.Textbox(label="Seed", value='-1')
                    seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                    seed_button.click(
                        fn=lambda: random.randint(1, int(1e16)),
                        inputs=[],
                        outputs=[seed_textbox]
                    )

                with gr.Row():
                    clear_btn = gr.ClearButton(value="Clear & Reset", size='lg', variant='secondary', scale=1)
                    generate_button = gr.Button(value="Generate", variant='primary', scale=2, size='lg')
                    clear_btn.add([base_model, motion_module, input_video, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, use_custom_dim, ctx_stride])

            with gr.Column():
                gr.Markdown("# OUTPUTS")
                result_optical_flow = gr.Video(label="Optical Flow", interactive=False)
                result_video = gr.Video(label="Generated Animation", interactive=False)

        inputs = [base_model, motion_module, prompt_textbox_positive, prompt_textbox_negative, seed_textbox, input_video, height, width, flow_scale, cfg, diffusion_steps, temporal_ds, ctx_stride]
        outputs = [result_optical_flow, result_video]

        generate_button.click(fn=controller.animate, inputs=inputs, outputs=outputs)

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(max_size=20)
    demo.launch()