vivienfanghua commited on
Commit
a9cc2b4
·
verified ·
1 Parent(s): 5608ba7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -66
app.py CHANGED
@@ -77,22 +77,39 @@ def _extract_video_from_history(history: Dict[str, Any]) -> Dict[str, str]:
77
  return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]}
78
  raise RuntimeError("No video file found in history outputs")
79
 
80
- with gr.Blocks(title="Wan 2.2 T2V UI running on AMD MI300x") as demo:
 
 
 
 
 
 
 
 
 
 
81
  st_token = gr.State()
82
- with gr.Row():
83
- text = gr.Textbox(label="Prompt", placeholder="Text to generate", lines=3)
84
- with gr.Row():
85
- width = gr.Number(label="Width", value=1280, precision=0)
86
- height = gr.Number(label="Height", value=704, precision=0)
87
- length = gr.Number(label="FPS", value=121, precision=0)
88
- fps = gr.Number(label="FPS", value=24, precision=0)
89
- with gr.Row():
90
- steps = gr.Number(label="Steps", value=20, precision=0)
91
- cfg = gr.Number(label="CFG", value=5.0)
92
- seed = gr.Number(label="Seed", value=None)
93
- filename_prefix = gr.Textbox(label="Prefix of video", value="video/ComfyUI")
94
- run_btn = gr.Button("Generate")
95
- prog_bar = gr.Slider(label="Step", minimum=0, maximum=100, value=0, step=1, interactive=False)
 
 
 
 
 
 
 
96
  out_video = gr.Video(label="Result")
97
 
98
  def _init_token():
@@ -101,62 +118,63 @@ with gr.Blocks(title="Wan 2.2 T2V UI running on AMD MI300x") as demo:
101
  demo.load(_init_token, outputs=st_token)
102
 
103
  def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
104
- def _runner():
105
- req = T2VReq(
106
- token=token,
107
- text=text,
108
- seed=int(seed) if seed is not None else None,
109
- steps=int(steps) if steps is not None else None,
110
- cfg=float(cfg) if cfg is not None else None,
111
- width=int(width) if width is not None else None,
112
- height=int(height) if height is not None else None,
113
- length=int(length) if length is not None else None,
114
- fps=int(fps) if fps is not None else None,
115
- filename_prefix=filename_prefix if filename_prefix else None,
116
- )
117
- prompt = _inject_params(WORKFLOW_TEMPLATE, req)
118
- client_id = str(uuid.uuid4())
119
- ws = _open_ws(client_id, req.token)
120
- prompt_id = _queue_prompt(prompt, client_id, req.token)
121
- total_nodes = max(1, len(prompt))
122
- seen = set()
123
- p = 0
124
- last_emit = -1
125
- start = time.time()
126
- ws.settimeout(60)
127
- while True:
128
- out = ws.recv()
129
- if isinstance(out, (bytes, bytearray)):
130
- if p < 95 and time.time() - start > 2:
131
- p = min(95, p + 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if p != last_emit:
133
  last_emit = p
134
  yield p, None
135
- continue
136
- msg = json.loads(out)
137
- if msg.get("type") == "executing":
138
- data = msg.get("data", {})
139
- if data.get("prompt_id") != prompt_id:
140
- continue
141
- node = data.get("node")
142
- if node is None:
143
- break
144
- if node not in seen:
145
- seen.add(node)
146
- p = min(99, int(len(seen) / total_nodes * 100))
147
- if p != last_emit:
148
- last_emit = p
149
- yield p, None
150
- ws.close()
151
- hist = _get_history(prompt_id, req.token)
152
- info = _extract_video_from_history(hist)
153
- q = urlencode(info)
154
- video_url = f"http://{COMFY_HOST}/view?{q}"
155
- yield 100, video_url
156
- return _runner()
157
 
158
  run_btn.click(
159
  generate_fn,
160
  inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
161
  outputs=[prog_bar, out_video]
162
- )
 
 
 
 
77
  return {"filename": it["filename"], "subfolder": it["subfolder"], "type": it["type"]}
78
  raise RuntimeError("No video file found in history outputs")
79
 
80
+ sample_prompts = [
81
+ "A golden retriever running across a beach at sunset, cinematic, 24fps",
82
+ "A cyberpunk city street at night with neon lights, light rain, slow pan",
83
+ "Aerial drone shot over snowy mountains, 5 seconds, dramatic lighting",
84
+ "Cartoon-style cat riding a skateboard in a park, vibrant colors"
85
+ ]
86
+
87
+ with gr.Blocks(
88
+ title="T2V UI",
89
+ theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue", neutral_hue="slate"),
90
+ ) as demo:
91
  st_token = gr.State()
92
+ gr.Markdown("# Wan2.2 T2V running on AMD MI300x")
93
+ gr.Markdown("### Prompt")
94
+ text = gr.Textbox(label="Prompt", placeholder="Describe the video you want", lines=3)
95
+
96
+ gr.Examples(examples=sample_prompts, inputs=text)
97
+
98
+ with gr.Accordion("Advanced Settings", open=False):
99
+ with gr.Row():
100
+ width = gr.Number(label="Width", value=1280, precision=0)
101
+ height = gr.Number(label="Height", value=704, precision=0)
102
+ with gr.Row():
103
+ length = gr.Number(label="Frames", value=121, precision=0)
104
+ fps = gr.Number(label="FPS", value=24, precision=0)
105
+ with gr.Row():
106
+ steps = gr.Number(label="Steps", value=20, precision=0)
107
+ cfg = gr.Number(label="CFG", value=5.0)
108
+ seed = gr.Number(label="Seed (optional)", value=None)
109
+ filename_prefix = gr.Textbox(label="Filename prefix", value="video/ComfyUI")
110
+
111
+ run_btn = gr.Button("Generate", variant="primary")
112
+ prog_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, step=1, interactive=False)
113
  out_video = gr.Video(label="Result")
114
 
115
  def _init_token():
 
118
  demo.load(_init_token, outputs=st_token)
119
 
120
  def generate_fn(text, width, height, length, fps, steps, cfg, seed, filename_prefix, token):
121
+ req = T2VReq(
122
+ token=token,
123
+ text=text,
124
+ seed=int(seed) if seed is not None else None,
125
+ steps=int(steps) if steps is not None else None,
126
+ cfg=float(cfg) if cfg is not None else None,
127
+ width=int(width) if width is not None else None,
128
+ height=int(height) if height is not None else None,
129
+ length=int(length) if length is not None else None,
130
+ fps=int(fps) if fps is not None else None,
131
+ filename_prefix=filename_prefix if filename_prefix else None,
132
+ )
133
+ prompt = _inject_params(WORKFLOW_TEMPLATE, req)
134
+ client_id = str(uuid.uuid4())
135
+ ws = _open_ws(client_id, req.token)
136
+ prompt_id = _queue_prompt(prompt, client_id, req.token)
137
+ total_nodes = max(1, len(prompt))
138
+ seen = set()
139
+ p = 0
140
+ last_emit = -1
141
+ start = time.time()
142
+ ws.settimeout(60)
143
+ while True:
144
+ out = ws.recv()
145
+ if isinstance(out, (bytes, bytearray)):
146
+ if p < 95 and time.time() - start > 2:
147
+ p = min(95, p + 1)
148
+ if p != last_emit:
149
+ last_emit = p
150
+ yield p, None
151
+ continue
152
+ msg = json.loads(out)
153
+ if msg.get("type") == "executing":
154
+ data = msg.get("data", {})
155
+ if data.get("prompt_id") != prompt_id:
156
+ continue
157
+ node = data.get("node")
158
+ if node is None:
159
+ break
160
+ if node not in seen:
161
+ seen.add(node)
162
+ p = min(99, int(len(seen) / total_nodes * 100))
163
  if p != last_emit:
164
  last_emit = p
165
  yield p, None
166
+ ws.close()
167
+ hist = _get_history(prompt_id, req.token)
168
+ info = _extract_video_from_history(hist)
169
+ q = urlencode(info)
170
+ video_url = f"http://{COMFY_HOST}/view?{q}&token={req.token}"
171
+ yield 100, video_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  run_btn.click(
174
  generate_fn,
175
  inputs=[text, width, height, length, fps, steps, cfg, seed, filename_prefix, st_token],
176
  outputs=[prog_bar, out_video]
177
+ )
178
+
179
+ if __name__ == "__main__":
180
+ demo.queue().launch(server_name="0.0.0.0", server_port=9001)