Spaces:
Running
on
Zero
Running
on
Zero
| # Dataloader for training GenHead, modified from EG3D: https://github.com/NVlabs/eg3d | |
| # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: LicenseRef-NvidiaProprietary | |
| # | |
| # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual | |
| # property and proprietary rights in and to this material, related | |
| # documentation and any modifications thereto. Any use, reproduction, | |
| # disclosure or distribution of this material and related documentation | |
| # without an express license agreement from NVIDIA CORPORATION or | |
| # its affiliates is strictly prohibited. | |
| """Streaming images and labels from datasets created with dataset_tool.py.""" | |
| import os | |
| import sys | |
| import numpy as np | |
| import zipfile | |
| import PIL.Image | |
| import json | |
| import torch | |
| import dnnlib | |
| from training.dataloader.protocols import datum_genhead_pb2 as datum_pb2 | |
| import lmdb | |
| import cv2 | |
| # try: | |
| # import pyspng | |
| # except ImportError: | |
| pyspng = None | |
| #---------------------------------------------------------------------------- | |
class Dataset(torch.utils.data.Dataset):
    """Abstract base class for image datasets with optional conditioning labels.

    Subclasses provide the storage backend by overriding `_load_raw_image()`,
    `_load_raw_labels()` and, when they hold resources, `close()`.

    Indexing returns `(image, label)` where `image` is a uint8 array whose shape
    equals `image_shape`, and `label` is either the raw float32 label row or a
    float32 one-hot vector when the raw labels are int64 class indices.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None         # Loaded lazily by _get_raw_labels().
        self._raw_labels_std = None     # Filled in together with _raw_labels.
        self._label_shape = None        # Cached by the label_shape property.

        # Apply max_size: keep a seeded random subset, stored in ascending order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: the second half of the index space mirrors the first.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        """Return the full raw label array, loading and validating it on first use."""
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                # No labels available: use a zero-width label per raw image.
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
            self._raw_labels_std = self._raw_labels.std(0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Drop the cached labels from the pickled state; they are reloaded lazily.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: a destructor must never raise, but unlike a bare
        # `except:` this still lets KeyboardInterrupt/SystemExit propagate.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return `(image, label)` for dataset index `idx`, applying the x-flip if set."""
        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        """Return the label for `idx`; int64 class indices are expanded to float32 one-hot."""
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return an EasyDict with `raw_idx`, `xflip` and `raw_label` for `idx`."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    def get_label_std(self):
        """Return the standard deviation of the raw labels along axis 0."""
        # The std is computed as a side effect of loading the labels, so make
        # sure they are loaded instead of failing on a missing attribute.
        self._get_raw_labels()
        return self._raw_labels_std

    # NOTE: the accessors below are properties because the rest of this class
    # reads them as plain attributes (e.g. `self.image_shape`, `self.label_shape`).

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                # Class indices: the one-hot width is the largest index plus one.
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
class PortraitSynthesisSegLmdbFolderDatasetV2(Dataset):
    """LMDB-backed dataset yielding images, segmentation maps and packed labels.

    Every record is a serialized `Datum_genhead` message stored under the
    zero-padded 8-digit raw index. Its `image` and `seg` fields are encoded
    images decoded with OpenCV; `labels` is a flat float32 vector which the
    accessors below slice as follows:

        [  0: 16]  camera extrinsics (4x4)      \\ returned together by get_label()
        [ 16: 25]  camera intrinsics (3x3)      /
        [ 25:325]  get_shape_param()
        [325:425]  get_exp_param()
        [425:431]  get_pose_param()             ([428:431] is reused as the jaw pose)
        [431:437]  get_eye_pose_param()

    The image size and sample count are parsed from the directory name, which
    must end in `_<size>_<num>` and be the second-to-last '/'-separated
    component of `path` (e.g. './data/FFHQ_512_50/').

    The LMDB environment is opened lazily on first read; the handles are never
    pickled, so an unpickled copy reopens the database on its first read.

    NOTE: `__getitem__` does not apply the x-flip bookkeeping of the base class.
    """

    # Layout of the flat per-sample label vector (see class docstring).
    _CAMERA_SLICE = slice(0, 25)
    _SHAPE_SLICE = slice(25, 325)
    _EXP_SLICE = slice(325, 425)
    _JAW_POSE_SLICE = slice(428, 431)
    _POSE_SLICE = slice(425, 431)
    _EYE_POSE_SLICE = slice(431, 437)

    def __init__(self,
        path,                       # Path to the LMDB directory.
        resolution      = None,     # Ensure specific resolution, None = highest available.
        data_type       = "vox2",   # Set dataset type, deprecated
        rescale_camera  = False,    # Rescale camera extrinsics and intrinsics to align with an older version of camera labels
        **super_kwargs,             # Additional arguments for the Dataset base class.
    ):
        self._path = path
        print(self._path)
        self._resolution = resolution
        self._zipfile = None
        self._data_type = data_type
        self.rescale_camera = rescale_camera

        # LMDB handles are created lazily by open_lmdb().
        if not os.path.isdir(self._path):
            raise IOError('Path must point to a directory or zip')
        self.db = None
        self.txn = None
        self.num = None
        self.datum = None

        # The directory name encodes the image size and the number of samples.
        name_fields = self._path.split("/")[-2].split("_")
        img_size = int(name_fields[-2])
        num = int(name_fields[-1])
        raw_shape = [num, 3, img_size, img_size]
        if resolution is None:
            self._resolution = raw_shape[2]

        name = os.path.splitext(os.path.basename(self._path))[0]
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    def open_lmdb(self):
        """Open the LMDB environment read-only and prepare a reusable datum message."""
        self.db = lmdb.open(self._path, map_size=1024 ** 4, readonly=True, lock=False)
        self.txn = self.db.begin()
        self.num = int(self.txn.get('num_samples'.encode()))
        self.datum = datum_pb2.Datum_genhead()

    def close(self):
        """Release the LMDB environment, if it was opened."""
        # getattr: __init__ may have raised before the handles were assigned,
        # and __del__ still calls close() on the half-built instance.
        db = getattr(self, 'db', None)
        if db is not None:
            db.close()
        self.db = None
        self.txn = None
        self.datum = None

    def __getstate__(self):
        # Live LMDB/protobuf handles must not be pickled; clearing them makes
        # the unpickled copy reopen the database lazily.
        return dict(super().__getstate__(), _zipfile=None, db=None, txn=None, datum=None)

    def _read_datum(self, raw_idx):
        """Fetch and parse the record for `raw_idx` into the shared datum message."""
        if self.txn is None:
            self.open_lmdb()
        key = '{:0>8d}'.format(raw_idx)
        value = self.txn.get(key.encode())
        if value is None:
            raise KeyError('No LMDB record with key {!r} in {!r}'.format(key, self._path))
        self.datum.ParseFromString(value)
        return self.datum

    def get_details(self, idx):
        """Return an EasyDict with `raw_idx`, `xflip` and the full raw label vector."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._load_raw_labels(d.raw_idx).copy()
        return d

    def get_label_std(self):
        return 0

    # NOTE: the accessors below are properties because this class reads
    # `label_shape` (and the base class `image_shape`) as plain attributes.

    @property
    def resolution(self):
        return self._resolution

    @property
    def label_shape(self):
        # NOTE(review): this is the shape of the *full* label vector of sample 0,
        # not of the 25-value camera label returned by get_label() - confirm
        # that callers of label_dim expect this.
        if self._label_shape is None:
            raw_labels = self._load_raw_labels(0)
            self._label_shape = raw_labels.shape
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._load_raw_labels(0).dtype == np.int64

    def _load_raw_image(self, raw_idx):
        """Decode the stored image for `raw_idx` and return it as a uint8 RGB CHW array."""
        datum = self._read_datum(raw_idx)
        encoded = np.frombuffer(datum.image, dtype=np.uint8)
        image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        if image is None:
            raise IOError('Failed to decode image for sample {}'.format(raw_idx))
        # The array is indexed on three axes right away, so a separate
        # grayscale (HW) branch could never be reached and is omitted.
        image = image[:, :, [2, 1, 0]] # bgr -> rgb
        image = image.transpose(2, 0, 1) # HWC => CHW
        return image

    def _load_raw_seg(self, raw_idx):
        """Decode the stored segmentation map for `raw_idx` and return it as a CHW array."""
        datum = self._read_datum(raw_idx)
        encoded = np.frombuffer(datum.seg, dtype=np.uint8)
        seg = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        if seg is None:
            raise IOError('Failed to decode segmentation for sample {}'.format(raw_idx))
        if seg.ndim == 2:
            seg = seg[:, :, np.newaxis] # HW => HWC
        seg = seg.transpose(2, 0, 1) # HWC => CHW
        if seg.shape[0] == 1:
            seg = np.tile(seg, (3, 1, 1))
        return seg

    def _load_raw_labels(self, raw_idx):
        """Return the full float32 label vector for `raw_idx`.

        With `rescale_camera` enabled, the intrinsics are rescaled to the
        target resolution and divided by it, and the extrinsic translation is
        multiplied by 3.
        """
        datum = self._read_datum(raw_idx)
        # frombuffer yields a read-only view of the protobuf bytes; copy it so
        # the in-place camera rescaling below is legal and the caller owns the data.
        labels = np.frombuffer(datum.labels, dtype=np.float32).copy()
        intrinsics = labels[16:25].reshape(3,3) # view into `labels`

        if self.rescale_camera:
            # normalize intrinsics
            if self._resolution != intrinsics[0,2]*2:
                intrinsics[:2,:] *= (0.5*self._resolution/intrinsics[0,2])
            intrinsics[0, 0] /= self._resolution
            intrinsics[1, 1] /= self._resolution
            intrinsics[0, 2] /= self._resolution
            intrinsics[1, 2] /= self._resolution

            # rescale extrinsics
            extrinsics = labels[:16].reshape(4,4) # Our face scale is around 0.1~0.2. Multiply by 3 to match the scale of EG3D
            extrinsics[:3,3] *= 3

        return labels

    def _labels_for(self, idx):
        """Return the full label vector for dataset index `idx`."""
        return self._load_raw_labels(self._raw_idx[idx])

    def get_label(self, idx):
        """Return the 25 camera values (extrinsics followed by intrinsics)."""
        return self._labels_for(idx)[self._CAMERA_SLICE].copy()

    def get_shape_param(self, idx):
        return self._labels_for(idx)[self._SHAPE_SLICE].copy()

    def get_exp_param(self, idx):
        return self._labels_for(idx)[self._EXP_SLICE].copy()

    def get_exp_param_w_jaw_pose(self, idx):
        """Return the expression values followed by the last three pose values."""
        labels = self._labels_for(idx)
        return np.concatenate([labels[self._EXP_SLICE], labels[self._JAW_POSE_SLICE]], axis=0)

    def get_pose_param(self, idx):
        return self._labels_for(idx)[self._POSE_SLICE].copy()

    def get_eye_pose_param(self, idx):
        return self._labels_for(idx)[self._EYE_POSE_SLICE].copy()

    def get_label_all(self, idx):
        """Return `(c, shape_param, exp_param, pose_param, eye_pose_param)` for `idx`."""
        # Read and parse the record once instead of once per component.
        labels = self._labels_for(idx)
        c = labels[self._CAMERA_SLICE].copy()
        shape_param = labels[self._SHAPE_SLICE].copy()
        exp_param = labels[self._EXP_SLICE].copy()
        pose_param = labels[self._POSE_SLICE].copy()
        eye_pose_param = labels[self._EYE_POSE_SLICE].copy()
        return c, shape_param, exp_param, pose_param, eye_pose_param

    def __getitem__(self, idx):
        """Return `(image, seg, c, shape_param, exp_param, pose_param, eye_pose_param)`."""
        raw_idx = self._raw_idx[idx]
        image = self._load_raw_image(raw_idx)
        seg = self._load_raw_seg(raw_idx)
        assert isinstance(image, np.ndarray)
        assert isinstance(seg, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert seg.shape[1] == self.image_shape[1] and seg.shape[2] == self.image_shape[2]
        assert image.dtype == np.uint8
        c, shape_param, exp_param, pose_param, eye_pose_param = self.get_label_all(idx)
        return image.copy(), seg.copy(), c, shape_param, exp_param, pose_param, eye_pose_param