Update app.py
Browse files
app.py
CHANGED
|
@@ -277,7 +277,7 @@ print('Loading custom white-background unet ...')
|
|
| 277 |
if os.path.exists(infer_config.unet_path):
|
| 278 |
unet_ckpt_path = infer_config.unet_path
|
| 279 |
else:
|
| 280 |
-
unet_ckpt_path = hf_hub_download(repo_id="LTT/
|
| 281 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
| 282 |
pipeline.unet.load_state_dict(state_dict, strict=True)
|
| 283 |
|
|
@@ -289,7 +289,7 @@ model = instantiate_from_config(model_config)
|
|
| 289 |
if os.path.exists(infer_config.model_path):
|
| 290 |
model_ckpt_path = infer_config.model_path
|
| 291 |
else:
|
| 292 |
-
model_ckpt_path = hf_hub_download(repo_id="LTT/
|
| 293 |
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
| 294 |
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
| 295 |
model.load_state_dict(state_dict, strict=True)
|
|
|
|
| 277 |
if os.path.exists(infer_config.unet_path):
|
| 278 |
unet_ckpt_path = infer_config.unet_path
|
| 279 |
else:
|
| 280 |
+
unet_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="diffusion_pytorch_model.bin", repo_type="model")
|
| 281 |
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
| 282 |
pipeline.unet.load_state_dict(state_dict, strict=True)
|
| 283 |
|
|
|
|
| 289 |
if os.path.exists(infer_config.model_path):
|
| 290 |
model_ckpt_path = infer_config.model_path
|
| 291 |
else:
|
| 292 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
| 293 |
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
| 294 |
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
| 295 |
model.load_state_dict(state_dict, strict=True)
|