Spaces:
Running
on
Zero
Running
on
Zero
| # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: LicenseRef-NvidiaProprietary | |
| # | |
| # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | |
| # property and proprietary rights in and to this material, related | |
| # documentation and any modifications thereto. Any use, reproduction, | |
| # disclosure or distribution of this material and related documentation | |
| # without an express license agreement from NVIDIA CORPORATION or | |
| # its affiliates is strictly prohibited. | |
| # import sys | |
| # sys.path.append('../../code_v1.0') | |
| import torch | |
| import numpy as np | |
| from torch_utils import persistence | |
| from recon.models.stylegan.networks_stylegan2 import MappingNetwork, SynthesisLayer | |
| from torch_utils import misc | |
| from training.volumetric_rendering.ray_sampler import RaySampler | |
| from training.volumetric_rendering.math_utils import normalize_vecs | |
| import dnnlib | |
| # def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: | |
| # """ | |
| # Normalize vector lengths. | |
| # """ | |
| # return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) | |
| # class RaySampler(torch.nn.Module): | |
| # def __init__(self): | |
| # super().__init__() | |
| # self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None | |
| # def forward(self, cam2world_matrix, intrinsics, resolution): | |
| # """ | |
| # Create batches of rays and return origins and directions. | |
| # cam2world_matrix: (N, 4, 4) | |
| # intrinsics: (N, 3, 3) | |
| # resolution: int | |
| # ray_origins: (N, M, 3) | |
| # ray_dirs: (N, M, 3) | |
| # """ | |
| # N, M = cam2world_matrix.shape[0], resolution**2 | |
| # cam_locs_world = cam2world_matrix[:, :3, 3] | |
| # fx = intrinsics[:, 0, 0] | |
| # fy = intrinsics[:, 1, 1] | |
| # cx = intrinsics[:, 0, 2] | |
| # cy = intrinsics[:, 1, 2] | |
| # sk = intrinsics[:, 0, 1] | |
| # # uv = torch.stack(torch.meshgrid(torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), indexing='ij')) * (1./resolution) + (0.5/resolution) | |
| # uv = torch.stack(torch.meshgrid(torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), torch.arange(resolution, dtype=torch.float32, device=cam2world_matrix.device), indexing='ij')) + 0.5 | |
| # uv = uv.flip(0).reshape(2, -1).transpose(1, 0) | |
| # uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) | |
| # x_cam = uv[:, :, 0].view(N, -1) | |
| # y_cam = uv[:, :, 1].view(N, -1) | |
| # # z_cam = torch.ones((N, M), device=cam2world_matrix.device) # Original EG3D implementation, z points inward | |
| # z_cam = - torch.ones((N, M), device=cam2world_matrix.device) # Our camera space coordinate | |
| # x_lift = - (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam | |
| # y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam | |
| # cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) | |
| # world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] | |
| # ray_dirs = world_rel_points - cam_locs_world[:, None, :] | |
| # ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) | |
| # ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) | |
| # return ray_origins, ray_dirs | |
def get_dist_from_origin(ray_origins, ray_dirs):
    """Return the perpendicular distance from the world origin to each ray.

    For a ray ``o + t * d`` with unit direction ``d``, the squared distance
    between the ray's supporting line and the origin is
    ``|o|^2 - (o . d)^2``.

    Args:
        ray_origins: Ray start points; the last dimension holds coordinates.
        ray_dirs: Ray directions, broadcast-compatible with ``ray_origins``.
            They are passed through ``normalize_vecs`` before use
            (presumably unit-normalizing along the last dimension -- the
            helper is imported from ``math_utils``; confirm there).

    Returns:
        Tensor of distances with the last (coordinate) dimension reduced.
    """
    unit_dirs = normalize_vecs(ray_dirs)
    origin_sq_norm = torch.sum(ray_origins ** 2, dim=-1)
    along_ray = torch.sum(ray_origins * unit_dirs, dim=-1)
    # The difference is non-negative in exact arithmetic, but floating-point
    # rounding can push it below zero for rays aimed (almost) straight at the
    # origin. Clamp so sqrt never yields NaN, which would make such rays fail
    # any later ``dist <= radius`` test.
    radicand = torch.clamp(origin_sq_norm - along_ray ** 2, min=0.0)
    # Small epsilon kept from the original so sqrt is never evaluated at 0.
    return torch.sqrt(radicand + 1e-10)
def get_intersections_with_sphere(ray_origins, ray_dirs, radius):
    """Intersect rays with a sphere of the given radius centred at the origin.

    A ray is considered valid when its distance to the origin (as computed
    by ``get_dist_from_origin``) does not exceed ``radius``. For valid rays
    the returned point uses the larger root of the ray/sphere quadratic,
    ``t = -(o . d) + sqrt(radius^2 + (o . d)^2 - |o|^2)``.

    NOTE(review): the root formula uses ``ray_dirs`` as given, so it assumes
    they are already unit length -- confirm against callers.

    Args:
        ray_origins: Ray start points, last dimension holds coordinates.
        ray_dirs: Ray directions, same shape as ``ray_origins``.
        radius: Sphere radius.

    Returns:
        Tuple ``(intersections, valid_mask)``. ``intersections`` has the
        shape of ``ray_dirs`` and is zero for rays that miss the sphere;
        ``valid_mask`` is a boolean tensor marking the rays that hit.
    """
    valid_mask = get_dist_from_origin(ray_origins, ray_dirs) <= radius

    # Gather the hitting rays once instead of re-indexing for every term.
    hit_origins = ray_origins[valid_mask]
    hit_dirs = ray_dirs[valid_mask]

    origin_dot_dir = torch.sum(hit_origins * hit_dirs, dim=-1, keepdim=True)
    origin_sq_norm = torch.sum(hit_origins ** 2, dim=-1, keepdim=True)
    t_far = -origin_dot_dir + torch.sqrt(radius ** 2 + origin_dot_dir ** 2 - origin_sq_norm)

    intersections = torch.zeros_like(ray_dirs)
    intersections[valid_mask] = hit_origins + t_far * hit_dirs
    return intersections, valid_mask
def get_theta_phi_bg(ray_origins, ray_dirs, radius):
    """Encode where each ray hits the background sphere as two sine values.

    The hit point from ``get_intersections_with_sphere`` is converted to
    two angles:

    * ``phi``   = ``arcsin(y / radius)``
    * ``theta`` = angle in the xz-plane measured from the +z axis, taken as
      ``arccos(z / sqrt(x^2 + z^2))`` when ``x > 0`` and ``2*pi`` minus that
      value otherwise.

    Both angles are then passed through ``sin``. Rays that miss the sphere
    keep angle 0 and therefore encode to 0.

    Args:
        ray_origins: Ray start points, last dimension holds coordinates.
        ray_dirs: Ray directions, same shape as ``ray_origins``.
        radius: Radius of the background sphere.

    Returns:
        Tuple ``(angles, valid_mask)`` where ``angles`` stacks
        ``[sin(theta), sin(phi)]`` along a new last dimension and
        ``valid_mask`` marks the rays that hit the sphere.
    """
    intersections, valid_mask = get_intersections_with_sphere(ray_origins, ray_dirs, radius)

    # Hit points of the valid rays only, split into components.
    hits = intersections[valid_mask]
    hit_x, hit_y, hit_z = hits[..., 0], hits[..., 1], hits[..., 2]

    phi = torch.zeros_like(intersections[..., 0])
    theta = torch.zeros_like(intersections[..., 0])

    # Clamp keeps the inverse trig arguments inside their domain.
    phi[valid_mask] = torch.arcsin(torch.clamp(hit_y / radius, -1 + 1e-10, 1 - 1e-10))

    radius_xz = torch.sqrt(hit_x ** 2 + hit_z ** 2)
    theta_no_sign = torch.arccos(torch.clamp(hit_z / radius_xz, -1 + 1e-10, 1 - 1e-10))
    # arccos only covers [0, pi]; mirror to (pi, 2*pi) when x is not positive.
    theta[valid_mask] = torch.where(hit_x > 0, theta_no_sign, 2 * np.pi - theta_no_sign)

    # Map both angles to [-1, 1] with sin (a linear (theta - pi) / pi
    # rescaling was used previously and is no longer active).
    return torch.stack([torch.sin(theta), torch.sin(phi)], dim=-1), valid_mask
class BGSynthesisNetwork(torch.nn.Module):
    """Stack of five ``SynthesisLayer`` modules applied to encoded coordinates.

    The input coordinates are positionally encoded into ``L * 4`` channels
    and pushed through layers ``b0`` .. ``b4``. Each layer is built with
    ``kernel_size=1`` and ``use_noise=False`` and receives its own latent
    vector. The first four layers use ``'lrelu'`` and output
    ``hidden_channels``; the last uses ``'sigmoid'`` and outputs
    ``img_channels``.
    """

    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_channels,               # Number of color channels.
        hidden_channels = 64,       # Width of the four hidden layers.
        L = 10,                     # Number of positional-encoding frequencies.
        **block_kwargs,             # Extra arguments forwarded to every SynthesisLayer.
    ):
        super().__init__()
        self.w_dim = w_dim
        self.img_channels = img_channels
        self.hidden_channels = hidden_channels
        self.L = L
        self.num_ws = 0

        for idx in range(5):
            is_output_layer = idx == 4
            setattr(self, f'b{idx}', SynthesisLayer(
                L * 4 if idx == 0 else hidden_channels,
                img_channels if is_output_layer else hidden_channels,
                w_dim=w_dim,
                resolution=64,
                kernel_size=1,
                use_noise=False,
                activation='sigmoid' if is_output_layer else 'lrelu',
                **block_kwargs,
            ))
            # One latent vector is consumed per layer.
            self.num_ws += 1

    def positional_encoding(self, p, use_pos=False):
        """Append sin/cos features of ``p`` at ``self.L`` octave frequencies.

        For each ``i`` in ``range(self.L)`` the features
        ``sin(2**i * pi * p)`` and ``cos(2**i * pi * p)`` are concatenated
        along the last dimension, in that order. When ``use_pos`` is true
        the raw ``p`` is appended at the end.
        """
        bands = []
        for i in range(self.L):
            freq = (2 ** i) * np.pi
            bands.append(torch.sin(freq * p))
            bands.append(torch.cos(freq * p))
        encoded = torch.cat(bands, dim=-1)
        if not use_pos:
            return encoded
        return torch.cat([encoded, p], -1)

    def forward(self, ws, x, update_emas, **block_kwargs):
        """Run the layer stack on ``x``.

        ``ws`` must have shape ``[batch, num_ws, w_dim]``; column ``i`` is
        fed to layer ``b{i}``. ``update_emas`` is accepted but not used.
        ``block_kwargs`` are forwarded to every layer call.
        """
        del update_emas  # Accepted for signature compatibility only.

        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            layer_ws = [ws[:, idx] for idx in range(self.num_ws)]

        x = self.positional_encoding(x)          # (N,M,L*4)
        x = x.permute(0, 2, 1).unsqueeze(-1)     # (N,L*4,M,1)
        for idx, cur_ws in enumerate(layer_ws):
            x = getattr(self, f'b{idx}')(x, cur_ws, **block_kwargs)
        return x

    def extra_repr(self):
        """Summarize the configuration for ``repr``/``print`` of the module."""
        parts = (
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
            f'hidden_channels={self.hidden_channels:d}, img_channels={self.img_channels:d},',
            f'L={self.L:d}',
        )
        return ' '.join(parts)
class BGGenerator(torch.nn.Module):
    """Background generator: a mapping network feeding ``BGSynthesisNetwork``.

    ``forward`` maps ``(z, c)`` to per-layer latents with ``MappingNetwork``
    and uses them to drive the synthesis network on the input ``x``.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_channels,               # Number of output color channels.
        mapping_kwargs = None,      # Arguments for MappingNetwork (None means no extra arguments).
        **synthesis_kwargs,         # Arguments for BGSynthesisNetwork.
    ):
        super().__init__()
        # A ``{}`` default would be a single dict shared by every call of
        # this constructor; use a None sentinel and create a fresh one.
        if mapping_kwargs is None:
            mapping_kwargs = {}
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_channels = img_channels
        self.synthesis = BGSynthesisNetwork(w_dim=w_dim, img_channels=img_channels, **synthesis_kwargs)
        # The mapping network must emit one latent per synthesis layer.
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, x, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Generate the background output for input ``x``.

        Args:
            z: Input latent, forwarded to the mapping network.
            c: Conditioning label, forwarded to the mapping network.
            x: Input handed to the synthesis network.
            truncation_psi: Forwarded to the mapping network.
            truncation_cutoff: Forwarded to the mapping network.
            update_emas: Forwarded to both sub-networks.
            **synthesis_kwargs: Extra arguments for the synthesis network.

        Returns:
            The synthesis network's output.
        """
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        img = self.synthesis(ws, x, update_emas=update_emas, **synthesis_kwargs)
        return img
if __name__=='__main__':
    # Smoke test: cast rays from a single camera, compute the background
    # sphere encoding and save the hit mask as an image.
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    from PIL import Image

    neural_rendering_resolution = 64
    res = neural_rendering_resolution

    # Identity camera-to-world matrix with 4.5 added to the z translation.
    camera_params = torch.eye(4).unsqueeze(0)
    camera_params[...,2,3] += 4.5

    # Intrinsics: focal length 300 on both axes, principal point (32, 32).
    intrinsics = torch.eye(3).unsqueeze(0)
    intrinsics[:,0,0] = intrinsics[:,1,1] = 300
    intrinsics[:,0,2] = intrinsics[:,1,2] = 32

    ray_sampler = RaySampler()
    ray_origins, ray_directions = ray_sampler(camera_params, intrinsics, res)

    angles, valid_mask = get_theta_phi_bg(ray_origins, ray_directions, radius=1.0)
    angles = angles.reshape(1, res, res, 2)
    print(angles[0,31])
    print(angles.shape)
    print(valid_mask.shape)

    # Convert the boolean hit mask into a 0/255 grayscale image.
    mask_2d = valid_mask.reshape(1, res, res).squeeze(0)
    valid_mask = mask_2d.numpy().astype(np.uint8)*255
    Image.fromarray(valid_mask, 'L').save('bg_mask.png')