| | import glob |
| | import os |
| | import os.path as osp |
| |
|
| | import fire |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from PIL import Image |
| | from tqdm import tqdm |
| |
|
| | from seva.data_io import get_parser |
| | from seva.eval import ( |
| | IS_TORCH_NIGHTLY, |
| | compute_relative_inds, |
| | create_transforms_simple, |
| | infer_prior_inds, |
| | infer_prior_stats, |
| | run_one_scene, |
| | ) |
| | from seva.geometry import ( |
| | generate_interpolated_path, |
| | generate_spiral_path, |
| | get_arc_horizontal_w2cs, |
| | get_default_intrinsics, |
| | get_lookat, |
| | get_preset_pose_fov, |
| | ) |
| | from seva.model import SGMWrapper |
| | from seva.modules.autoencoder import AutoEncoder |
| | from seva.modules.conditioner import CLIPConditioner |
| | from seva.sampling import DDPMDiscretization, DiscreteDenoiser |
| | from seva.utils import load_model |
| |
|
# Device used for all inference below; assumes a single-GPU machine.
device = "cuda:0"

# Root directory where per-scene outputs (frames, videos, transforms.json)
# are written, organized as WORK_DIR/<task>/<save_subdir>/<scene>.
WORK_DIR = "work_dirs/demo"

# torch.compile is only enabled on nightly torch builds; the inductor cache
# env vars below make repeated compilations across runs cheap.
if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False
| |
|
# Model components are built once at import time and shared by every scene.
# The diffusion model is loaded on CPU first, then moved to the target device.
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)

# Global generation configuration. NOTE: `main()` mutates "H"/"W"/"T" and
# "options" of this dict in place.
VERSION_DICT = {
    "H": 576,  # output frame height in pixels
    "W": 576,  # output frame width in pixels
    "T": 21,  # number of frames per sampling window
    "C": 4,  # latent channels — TODO confirm against model config
    "f": 8,  # autoencoder spatial downsampling factor — TODO confirm
    "options": {},
}

if COMPILE:
    MODEL = torch.compile(MODEL, dynamic=False)
    CONDITIONER = torch.compile(CONDITIONER, dynamic=False)
    AE = torch.compile(AE, dynamic=False)
|
| |
|
def parse_task(
    task,
    scene,
    num_inputs,
    T,
    version_dict,
):
    """Resolve a generation task into image paths, camera poses, and indices.

    Args:
        task: One of "img2trajvid_s-prob", "img2img", "img2vid", or
            "img2trajvid".
        scene: Path to a single conditioning image (for "img2trajvid_s-prob")
            or to a scene directory understood by the "reconfusion" parser.
        num_inputs: Number of conditioning views. May be None (inferred when
            the dataset defines exactly one split), an int split key, or a
            string split key whose leading integer (before "-") is the view
            count.
        T: Number of frames per sampling window.
        version_dict: Global version/config dict; its "options" sub-dict
            supplies "traj_prior", "num_targets", etc.

    Returns:
        Tuple of (all_imgs_path, num_inputs, num_targets, input_indices,
        anchor_indices, c2ws, Ks, anchor_c2ws, anchor_Ks). Pose/intrinsic
        outputs are float32 torch tensors; the anchor entries may be None.

    Raises:
        ValueError: If `task` is not one of the supported task names.
    """
    options = version_dict["options"]

    anchor_indices = None
    anchor_c2ws = None
    anchor_Ks = None

    if task == "img2trajvid_s-prob":
        # Single-view conditioning with a procedurally generated camera prior;
        # no dataset parsing is involved.
        if num_inputs is not None:
            assert (
                num_inputs == 1
            ), "Task `img2trajvid_s-prob` only support 1-view conditioning..."
        else:
            num_inputs = 1
        num_targets = options.get("num_targets", T - 1)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )

        input_indices = [0]
        anchor_indices = np.linspace(1, num_targets, num_anchors).tolist()

        # Only the conditioning image exists on disk; targets are generated.
        all_imgs_path = [scene] + [None] * num_targets

        c2ws, fovs = get_preset_pose_fov(
            option=options.get("traj_prior", "orbit"),
            num_frames=num_targets + 1,
            start_w2c=torch.eye(4),
            look_at=torch.Tensor([0, 0, 10]),
        )

        with Image.open(scene) as img:
            W, H = img.size
        aspect_ratio = W / H
        Ks = get_default_intrinsics(fovs, aspect_ratio=aspect_ratio)
        # Scale normalized intrinsics to pixel units of the input image.
        Ks[:, :2] *= (
            torch.tensor([W, H]).reshape(1, -1, 1).repeat(Ks.shape[0], 1, 1)
        )
        Ks = Ks.numpy()

        anchor_c2ws = c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = Ks[[round(ind) for ind in anchor_indices]]

    else:
        # All remaining tasks read a real scene through the dataset parser.
        parser = get_parser(
            parser_type="reconfusion",
            data_dir=scene,
            normalize=False,
        )
        all_imgs_path = parser.image_paths
        c2ws = parser.camtoworlds
        camera_ids = parser.camera_ids
        Ks = np.concatenate([parser.Ks_dict[cam_id][None] for cam_id in camera_ids], 0)

        if num_inputs is None:
            # The caller did not choose a split: exactly one must be defined.
            assert len(parser.splits_per_num_input_frames.keys()) == 1
            num_inputs = next(iter(parser.splits_per_num_input_frames))
            split_dict = parser.splits_per_num_input_frames[num_inputs]
        elif isinstance(num_inputs, str):
            # String keys (e.g. "3-variant") select a named split; the leading
            # integer is the actual number of input views.
            split_dict = parser.splits_per_num_input_frames[num_inputs]
            num_inputs = int(num_inputs.split("-")[0])
        else:
            split_dict = parser.splits_per_num_input_frames[num_inputs]

        num_targets = len(split_dict["test_ids"])

        if task == "img2img":
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            sampled_indices = np.sort(
                np.array(split_dict["train_ids"] + split_dict["test_ids"])
            )

            # Optionally build an auxiliary anchor trajectory used as a prior.
            traj_prior = options.get("traj_prior", None)
            if traj_prior == "spiral":
                assert parser.bounds is not None
                # Spiral path operates in the OpenGL convention, hence the
                # [1, -1, -1, 1] axis flips on the way in and out.
                anchor_c2ws = generate_spiral_path(
                    c2ws[sampled_indices] @ np.diagflat([1, -1, -1, 1]),
                    parser.bounds[sampled_indices],
                    n_frames=num_anchors + 1,
                    n_rots=2,
                    zrate=0.5,
                    endpoint=False,
                )[1:] @ np.diagflat([1, -1, -1, 1])
            elif traj_prior == "interpolated":
                assert num_inputs > 1
                anchor_c2ws = generate_interpolated_path(
                    c2ws[split_dict["train_ids"], :3],
                    round((num_anchors + 1) / (num_inputs - 1)),
                    endpoint=False,
                )[1 : num_anchors + 1]
            elif traj_prior == "orbit":
                c2ws_th = torch.as_tensor(c2ws)
                lookat = get_lookat(
                    c2ws_th[sampled_indices, :3, 3],
                    c2ws_th[sampled_indices, :3, 2],
                )
                anchor_c2ws = torch.linalg.inv(
                    get_arc_horizontal_w2cs(
                        torch.linalg.inv(c2ws_th[split_dict["train_ids"][0]]),
                        lookat,
                        # Up vector: mean of the training cameras' up axes.
                        -F.normalize(
                            c2ws_th[split_dict["train_ids"]][:, :3, 1].mean(0),
                            dim=-1,
                        ),
                        num_frames=num_anchors + 1,
                        endpoint=False,
                    )
                ).numpy()[1:, :3]
            else:
                anchor_c2ws = None

            # Restrict everything to the sampled (train + test) frames.
            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            # Input indices are positions of the train ids inside the sampled
            # set; anchors are appended after all sampled frames.
            input_indices = compute_relative_inds(
                sampled_indices,
                np.array(split_dict["train_ids"]),
            )
            anchor_indices = np.arange(
                sampled_indices.shape[0],
                sampled_indices.shape[0] + num_anchors,
            ).tolist()

        elif task == "img2vid":
            # Every non-input frame of the sequence is a target.
            num_targets = len(all_imgs_path) - num_inputs
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            input_indices = split_dict["train_ids"]
            anchor_indices = infer_prior_inds(
                c2ws,
                num_prior_frames=num_anchors,
                input_frame_indices=input_indices,
                options=options,
            ).tolist()
            # infer_prior_inds may return fewer anchors than requested.
            num_anchors = len(anchor_indices)
            anchor_c2ws = c2ws[anchor_indices, :3]
            anchor_Ks = Ks[anchor_indices]

        elif task == "img2trajvid":
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            # Anchors are evenly spaced samples of the target trajectory;
            # compute the index array once and reuse it for poses/intrinsics.
            anchor_ids = (
                np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
            )
            target_c2ws = c2ws[split_dict["test_ids"], :3]
            target_Ks = Ks[split_dict["test_ids"]]
            anchor_c2ws = target_c2ws[anchor_ids]
            anchor_Ks = target_Ks[anchor_ids]

            # Reorder frames as inputs first, then targets.
            sampled_indices = split_dict["train_ids"] + split_dict["test_ids"]
            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            input_indices = np.arange(num_inputs).tolist()
            anchor_indices = np.linspace(
                num_inputs, num_inputs + num_targets - 1, num_anchors
            ).tolist()

        else:
            raise ValueError(f"Unknown task: {task}")

    return (
        all_imgs_path,
        num_inputs,
        num_targets,
        input_indices,
        anchor_indices,
        torch.tensor(c2ws[:, :3]).float(),
        torch.tensor(Ks).float(),
        (torch.tensor(anchor_c2ws[:, :3]).float() if anchor_c2ws is not None else None),
        (torch.tensor(anchor_Ks).float() if anchor_Ks is not None else None),
    )
| |
|
| |
|
def main(
    data_path,
    data_items=None,
    task="img2img",
    save_subdir="",
    H=None,
    W=None,
    T=None,
    use_traj_prior=False,
    **overwrite_options,
):
    """Sample every selected scene under `data_path` with the global model.

    Parameters:
        data_path: Directory containing scenes (or single images, depending
            on the task).
        data_items: Optional scene selection — a list/tuple of names or a
            comma-separated string; all entries of `data_path` when None.
        task: Task name forwarded to `parse_task`.
        save_subdir: Extra sub-directory under WORK_DIR/<task>/ for outputs.
        H, W: Optional overrides for output frame height/width.
        T: Optional override for the sampling window size; an int or a
            comma-separated string of ints.
        use_traj_prior: Forwarded to `run_one_scene`.
        **overwrite_options: Arbitrary overrides merged into the options
            dict after the defaults below.
    """
    # NOTE: mutates the module-level VERSION_DICT in place, so overrides
    # persist across calls within the same process.
    if H is not None:
        VERSION_DICT["H"] = H
    if W is not None:
        VERSION_DICT["W"] = W
    if T is not None:
        VERSION_DICT["T"] = [int(t) for t in T.split(",")] if isinstance(T, str) else T

    # Default sampling options; anything in overwrite_options wins.
    options = VERSION_DICT["options"]
    options["chunk_strategy"] = "nearest-gt"
    options["video_save_fps"] = 30.0
    options["beta_linear_start"] = 5e-6
    options["log_snr_shift"] = 2.4
    options["guider_types"] = 1
    options["cfg"] = 2.0
    options["camera_scale"] = 2.0
    options["num_steps"] = 50
    options["cfg_min"] = 1.2
    options["encoding_t"] = 1
    options["decoding_t"] = 1
    options["num_inputs"] = None
    options["seed"] = 23
    options.update(overwrite_options)

    num_inputs = options["num_inputs"]
    seed = options["seed"]

    # Resolve the list of scene paths to process.
    if data_items is not None:
        if not isinstance(data_items, (list, tuple)):
            data_items = data_items.split(",")
        scenes = [os.path.join(data_path, item) for item in data_items]
    else:
        scenes = glob.glob(osp.join(data_path, "*"))

    for scene in tqdm(scenes):
        save_path_scene = os.path.join(
            WORK_DIR, task, save_subdir, os.path.splitext(os.path.basename(scene))[0]
        )
        # transforms.json is written last, so its presence marks completion.
        if options.get("skip_saved", False) and os.path.exists(
            os.path.join(save_path_scene, "transforms.json")
        ):
            print(f"Skipping {scene} as it is already sampled.")
            continue

        # NOTE(review): `num_inputs` is reassigned here, so a value inferred
        # for one scene carries over to subsequent scenes — confirm intended.
        (
            all_imgs_path,
            num_inputs,
            num_targets,
            input_indices,
            anchor_indices,
            c2ws,
            Ks,
            anchor_c2ws,
            anchor_Ks,
        ) = parse_task(
            task,
            scene,
            num_inputs,
            VERSION_DICT["T"],
            VERSION_DICT,
        )
        assert num_inputs is not None

        # Conditioning dicts consumed by run_one_scene.
        image_cond = {
            "img": all_imgs_path,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }

        camera_cond = {
            "c2w": c2ws.clone(),
            "K": Ks.clone(),
            "input_indices": list(range(num_inputs + num_targets)),
        }

        # run_one_scene yields as it goes; drain the generator to completion.
        video_path_generator = run_one_scene(
            task,
            VERSION_DICT,
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=save_path_scene,
            use_traj_prior=use_traj_prior,
            traj_prior_Ks=anchor_Ks,
            traj_prior_c2ws=anchor_c2ws,
            seed=seed,
        )
        for _ in video_path_generator:
            pass

        # Convert poses to the OpenGL convention (flip y/z) for transforms.json.
        c2ws = c2ws @ torch.tensor(np.diag([1, -1, -1, 1])).float()
        img_paths = sorted(glob.glob(osp.join(save_path_scene, "samples-rgb", "*.png")))
        if len(img_paths) != len(c2ws):
            # Sampled outputs cover targets only; interleave the saved input
            # frames back into pose order.
            input_img_paths = sorted(
                glob.glob(osp.join(save_path_scene, "input", "*.png"))
            )
            assert len(img_paths) == num_targets
            assert len(input_img_paths) == num_inputs
            assert c2ws.shape[0] == num_inputs + num_targets
            target_indices = [i for i in range(c2ws.shape[0]) if i not in input_indices]
            img_paths = [
                input_img_paths[input_indices.index(i)]
                if i in input_indices
                else img_paths[target_indices.index(i)]
                for i in range(c2ws.shape[0])
            ]
        create_transforms_simple(
            save_path=save_path_scene,
            img_paths=img_paths,
            img_whs=np.array([VERSION_DICT["W"], VERSION_DICT["H"]])[None].repeat(
                num_inputs + num_targets, 0
            ),
            c2ws=c2ws,
            Ks=Ks,
        )
| |
|
| |
|
if __name__ == "__main__":
    # Expose `main` as a CLI via python-fire; kwargs become command flags.
    fire.Fire(main)
| |
|