Spaces:
Build error
Build error
resolve conflict
Browse files
app.py
CHANGED
|
@@ -28,7 +28,7 @@ def handler(signum, frame):
|
|
| 28 |
if res == 'y':
|
| 29 |
gr.close_all()
|
| 30 |
exit(1)
|
| 31 |
-
|
| 32 |
signal.signal(signal.SIGINT, handler)
|
| 33 |
|
| 34 |
|
|
@@ -56,7 +56,7 @@ def check_name(model_name='FFHQ512'):
|
|
| 56 |
"""Gets model by name."""
|
| 57 |
if model_name == 'FFHQ512':
|
| 58 |
network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
|
| 59 |
-
|
| 60 |
# TODO: checkpoint to be updated!
|
| 61 |
# elif model_name == 'FFHQ512v2':
|
| 62 |
# network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
|
|
@@ -109,10 +109,10 @@ def proc_seed(history, seed):
|
|
| 109 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
| 110 |
history = history or {}
|
| 111 |
seeds = []
|
| 112 |
-
|
| 113 |
if model_find != "":
|
| 114 |
model_name = model_find
|
| 115 |
-
|
| 116 |
model_name = check_name(model_name)
|
| 117 |
if model_name != history.get("model_name", None):
|
| 118 |
model, res, imgs = get_model(model_name, render_option)
|
|
@@ -139,7 +139,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
| 139 |
ws = ws.detach().cpu().numpy()
|
| 140 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
| 141 |
|
| 142 |
-
|
| 143 |
imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
|
| 144 |
np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
|
| 145 |
(res//2, res//2), cv2.INTER_AREA)
|
|
@@ -151,7 +151,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
| 151 |
history[f'seed{idx}'] = seed
|
| 152 |
history['trunc'] = trunc
|
| 153 |
history['model_name'] = model_name
|
| 154 |
-
|
| 155 |
set_random_seed(sum(seeds))
|
| 156 |
|
| 157 |
# style mixing (?)
|
|
@@ -159,18 +159,18 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
| 159 |
ws = ws1.clone()
|
| 160 |
ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
|
| 161 |
ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)
|
| 162 |
-
|
| 163 |
# set visualization for other types of inputs.
|
| 164 |
if early == 'Normal Map':
|
| 165 |
render_option += ',normal,early'
|
| 166 |
elif early == 'Gradient Map':
|
| 167 |
render_option += ',gradient,early'
|
| 168 |
-
|
| 169 |
start_t = time.time()
|
| 170 |
with torch.no_grad():
|
| 171 |
cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
|
| 172 |
image = model.get_final_output(
|
| 173 |
-
styles=ws, camera_matrices=cam,
|
| 174 |
theta=roll * np.pi,
|
| 175 |
render_option=render_option)
|
| 176 |
end_t = time.time()
|
|
@@ -184,7 +184,7 @@ def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed
|
|
| 184 |
b = int(imgs.shape[1] / imgs.shape[0] * a)
|
| 185 |
print(f'resize {a} {b} {image.shape} {imgs.shape}')
|
| 186 |
image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)
|
| 187 |
-
|
| 188 |
print(f'rendering time = {end_t-start_t:.4f}s')
|
| 189 |
image = (image * 255).astype('uint8')
|
| 190 |
return image, history
|
|
@@ -210,4 +210,4 @@ gr.Interface(fn=f_synthesis,
|
|
| 210 |
outputs=["image", "state"],
|
| 211 |
layout='unaligned',
|
| 212 |
css=css, theme='dark-huggingface',
|
| 213 |
-
live=True).launch(
|
|
|
|
| 28 |
if res == 'y':
|
| 29 |
gr.close_all()
|
| 30 |
exit(1)
|
| 31 |
+
|
| 32 |
signal.signal(signal.SIGINT, handler)
|
| 33 |
|
| 34 |
|
|
|
|
| 56 |
"""Gets model by name."""
|
| 57 |
if model_name == 'FFHQ512':
|
| 58 |
network_pkl = hf_hub_download(repo_id='thomagram/stylenerf-ffhq-config-basic', filename='ffhq_512.pkl')
|
| 59 |
+
|
| 60 |
# TODO: checkpoint to be updated!
|
| 61 |
# elif model_name == 'FFHQ512v2':
|
| 62 |
# network_pkl = "./pretrained/ffhq_512_eg3d.pkl"
|
|
|
|
| 109 |
def f_synthesis(model_name, model_find, render_option, early, trunc, seed1, seed2, mix1, mix2, yaw, pitch, roll, fov, history):
|
| 110 |
history = history or {}
|
| 111 |
seeds = []
|
| 112 |
+
|
| 113 |
if model_find != "":
|
| 114 |
model_name = model_find
|
| 115 |
+
|
| 116 |
model_name = check_name(model_name)
|
| 117 |
if model_name != history.get("model_name", None):
|
| 118 |
model, res, imgs = get_model(model_name, render_option)
|
|
|
|
| 139 |
ws = ws.detach().cpu().numpy()
|
| 140 |
img = img[0].permute(1,2,0).detach().cpu().numpy()
|
| 141 |
|
| 142 |
+
|
| 143 |
imgs[idx * res // 2: (1 + idx) * res // 2] = cv2.resize(
|
| 144 |
np.asarray(img).clip(-1, 1) * 0.5 + 0.5,
|
| 145 |
(res//2, res//2), cv2.INTER_AREA)
|
|
|
|
| 151 |
history[f'seed{idx}'] = seed
|
| 152 |
history['trunc'] = trunc
|
| 153 |
history['model_name'] = model_name
|
| 154 |
+
|
| 155 |
set_random_seed(sum(seeds))
|
| 156 |
|
| 157 |
# style mixing (?)
|
|
|
|
| 159 |
ws = ws1.clone()
|
| 160 |
ws[:, :8] = ws1[:, :8] * mix1 + ws2[:, :8] * (1 - mix1)
|
| 161 |
ws[:, 8:] = ws1[:, 8:] * mix2 + ws2[:, 8:] * (1 - mix2)
|
| 162 |
+
|
| 163 |
# set visualization for other types of inputs.
|
| 164 |
if early == 'Normal Map':
|
| 165 |
render_option += ',normal,early'
|
| 166 |
elif early == 'Gradient Map':
|
| 167 |
render_option += ',gradient,early'
|
| 168 |
+
|
| 169 |
start_t = time.time()
|
| 170 |
with torch.no_grad():
|
| 171 |
cam = get_camera_traj(model, pitch, yaw, fov, model_name=model_name)
|
| 172 |
image = model.get_final_output(
|
| 173 |
+
styles=ws, camera_matrices=cam,
|
| 174 |
theta=roll * np.pi,
|
| 175 |
render_option=render_option)
|
| 176 |
end_t = time.time()
|
|
|
|
| 184 |
b = int(imgs.shape[1] / imgs.shape[0] * a)
|
| 185 |
print(f'resize {a} {b} {image.shape} {imgs.shape}')
|
| 186 |
image = np.concatenate([cv2.resize(imgs, (b, a), cv2.INTER_AREA), image], 1)
|
| 187 |
+
|
| 188 |
print(f'rendering time = {end_t-start_t:.4f}s')
|
| 189 |
image = (image * 255).astype('uint8')
|
| 190 |
return image, history
|
|
|
|
| 210 |
outputs=["image", "state"],
|
| 211 |
layout='unaligned',
|
| 212 |
css=css, theme='dark-huggingface',
|
| 213 |
+
live=True).launch(enable_queue=True)
|