Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -116,13 +116,15 @@ os.makedirs("./gradio_tmp", exist_ok=True)
|
|
| 116 |
upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
|
| 117 |
frame_interpolation_model = load_rife_model("model_rife")
|
| 118 |
|
| 119 |
-
|
|
|
|
| 120 |
prompt: str,
|
| 121 |
image_input: str,
|
| 122 |
num_inference_steps: int,
|
| 123 |
guidance_scale: float,
|
| 124 |
seed: int = 42,
|
| 125 |
-
|
|
|
|
| 126 |
):
|
| 127 |
if seed == -1:
|
| 128 |
seed = random.randint(0, 2**8 - 1)
|
|
@@ -167,6 +169,12 @@ def infer(
|
|
| 167 |
).frames
|
| 168 |
|
| 169 |
free_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
return (video_pt, seed)
|
| 171 |
|
| 172 |
|
|
@@ -320,8 +328,8 @@ with gr.Blocks() as demo:
|
|
| 320 |
</table>
|
| 321 |
""")
|
| 322 |
|
| 323 |
-
|
| 324 |
-
def
|
| 325 |
prompt,
|
| 326 |
image_input,
|
| 327 |
seed_value,
|
|
@@ -329,18 +337,15 @@ with gr.Blocks() as demo:
|
|
| 329 |
rife_status,
|
| 330 |
progress=gr.Progress(track_tqdm=True)
|
| 331 |
):
|
| 332 |
-
latents, seed =
|
| 333 |
prompt,
|
| 334 |
image_input,
|
| 335 |
num_inference_steps=50,
|
| 336 |
guidance_scale=7.0,
|
| 337 |
seed=seed_value,
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
latents = upscale_batch_and_concatenate(upscale_model, latents, device)
|
| 342 |
-
if rife_status:
|
| 343 |
-
latents = rife_inference_with_latents(frame_interpolation_model, latents)
|
| 344 |
|
| 345 |
batch_size = latents.shape[0]
|
| 346 |
batch_video_frames = []
|
|
@@ -361,11 +366,11 @@ with gr.Blocks() as demo:
|
|
| 361 |
return video_path, video_update, gif_update, seed_update
|
| 362 |
|
| 363 |
generate_button.click(
|
| 364 |
-
|
| 365 |
inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
|
| 366 |
outputs=[video_output, download_video_button, download_gif_button, seed_text],
|
| 367 |
)
|
| 368 |
|
| 369 |
if __name__ == "__main__":
|
| 370 |
demo.queue(max_size=15)
|
| 371 |
-
demo.launch()
|
|
|
|
| 116 |
upscale_model = load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
|
| 117 |
frame_interpolation_model = load_rife_model("model_rife")
|
| 118 |
|
| 119 |
+
@spaces.GPU(duration=300)
|
| 120 |
+
def generate(
|
| 121 |
prompt: str,
|
| 122 |
image_input: str,
|
| 123 |
num_inference_steps: int,
|
| 124 |
guidance_scale: float,
|
| 125 |
seed: int = 42,
|
| 126 |
+
scale_status: bool = False,
|
| 127 |
+
rife_status: bool = False,
|
| 128 |
):
|
| 129 |
if seed == -1:
|
| 130 |
seed = random.randint(0, 2**8 - 1)
|
|
|
|
| 169 |
).frames
|
| 170 |
|
| 171 |
free_memory()
|
| 172 |
+
|
| 173 |
+
if scale_status:
|
| 174 |
+
video_pt = upscale_batch_and_concatenate(upscale_model, video_pt, device)
|
| 175 |
+
if rife_status:
|
| 176 |
+
video_pt = rife_inference_with_latents(frame_interpolation_model, video_pt)
|
| 177 |
+
|
| 178 |
return (video_pt, seed)
|
| 179 |
|
| 180 |
|
|
|
|
| 328 |
</table>
|
| 329 |
""")
|
| 330 |
|
| 331 |
+
|
| 332 |
+
def run(
|
| 333 |
prompt,
|
| 334 |
image_input,
|
| 335 |
seed_value,
|
|
|
|
| 337 |
rife_status,
|
| 338 |
progress=gr.Progress(track_tqdm=True)
|
| 339 |
):
|
| 340 |
+
latents, seed = generate(
|
| 341 |
prompt,
|
| 342 |
image_input,
|
| 343 |
num_inference_steps=50,
|
| 344 |
guidance_scale=7.0,
|
| 345 |
seed=seed_value,
|
| 346 |
+
scale_status=scale_status,
|
| 347 |
+
rife_status=rife_status,
|
| 348 |
+
)
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
batch_size = latents.shape[0]
|
| 351 |
batch_video_frames = []
|
|
|
|
| 366 |
return video_path, video_update, gif_update, seed_update
|
| 367 |
|
| 368 |
generate_button.click(
|
| 369 |
+
fn=run,
|
| 370 |
inputs=[prompt, image_input, seed_param, enable_scale, enable_rife],
|
| 371 |
outputs=[video_output, download_video_button, download_gif_button, seed_text],
|
| 372 |
)
|
| 373 |
|
| 374 |
if __name__ == "__main__":
|
| 375 |
demo.queue(max_size=15)
|
| 376 |
+
demo.launch()
|