Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import sys | |
| sys.path.insert(1, os.path.join(sys.path[0], '..')) | |
| import cv2 | |
| import os | |
| import time | |
| import imageio | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from PIL import Image, ImageDraw, ImageFont | |
| import torch | |
| import torchvision | |
| from torch import Tensor | |
| from torchvision.utils import make_grid | |
| from torchvision.transforms.functional import to_tensor | |
def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
    """
    Tile a batch of videos into one grid per time step and write it as an mp4.

    video: torch.Tensor laid out as [b, c, t, h, w].
        Values are clamped to [-1, 1]; with rescale=True they are then mapped
        to [0, 1] before the conversion to uint8.
    savepath: output file path handed to torchvision.io.write_video.
    fps: frame rate of the written video.
    rescale: set True when the input is in [-1, 1], False when it is in [0, 1].
    nrow: images per grid row; defaults to int(sqrt(b)).
    """
    batch_size = video.shape[0]
    if nrow is None:
        nrow = int(np.sqrt(batch_size))

    # Put time first so that iterating yields one [b, c, h, w] batch per frame.
    per_frame = video.permute(2, 0, 1, 3, 4)
    sheets = []
    for frame_batch in per_frame:
        sheets.append(torchvision.utils.make_grid(frame_batch, nrow=nrow))  # [3, grid_h, grid_w]

    clip = torch.stack(sheets, dim=0).float().clamp(-1., 1.)  # [T, 3, grid_h, grid_w]
    if rescale:
        clip = (clip + 1.0) / 2.0

    # write_video expects channel-last uint8 frames: [T, grid_h, grid_w, 3]
    clip = (clip * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(savepath, clip, fps=fps, video_codec='h264', options={'crf': '10'})
| # ---------------------------------------------------------------------------------------------- | |
def savenp2sheet(imgs, savepath, nrow=None):
    """ save multiple RGB imgs (in numpy array type) to a single img sheet.

    imgs:
        np array of size [N, H, W, 3] or List[array] with array size = [H,W,3].
        Channel-first images ([3, H, W]) are transposed to channel-last.
    savepath:
        output path, written with cv2.imwrite.
    nrow:
        number of images per sheet row (i.e. number of columns).
        Defaults to int(sqrt(N)).

    Raises:
        ValueError: if `imgs` is empty or `nrow` is smaller than 1.
        AssertionError: if an image is not [H, W, 3] after the optional transpose.
    """
    # A plain list has no `.ndim`; only split when we really got a 4-D array.
    if getattr(imgs, 'ndim', None) == 4:
        imgs = [imgs[i] for i in range(imgs.shape[0])]
    imgs = list(imgs)
    if not imgs:
        raise ValueError('savenp2sheet needs at least one image')

    imgs_new = []
    for img in imgs:
        if img.ndim == 3 and img.shape[0] == 3:
            # channel-first -> channel-last; make it contiguous for OpenCV
            img = np.ascontiguousarray(np.transpose(img, (1, 2, 0)))
        assert (img.ndim == 3 and img.shape[-1] == 3), img.shape  # h,w,3
        # cv2.imwrite expects BGR channel order
        imgs_new.append(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

    n = len(imgs_new)
    n_cols = nrow if nrow is not None else int(n ** 0.5)
    if n_cols < 1:
        raise ValueError(f'nrow must be >= 1, got {n_cols}')
    n_rows = int(np.ceil(n / n_cols))
    print(n_cols)
    print(n_rows)

    # Fill an incomplete last row with black tiles: cv2.vconcat requires every
    # row strip to have the same width.
    n_missing = n_rows * n_cols - n
    if n_missing:
        imgs_new.extend([np.zeros_like(imgs_new[0])] * n_missing)

    row_strips = [cv2.hconcat(imgs_new[r * n_cols:(r + 1) * n_cols]) for r in range(n_rows)]
    imgsheet = cv2.vconcat(row_strips)
    cv2.imwrite(savepath, imgsheet)
    print(f'saved in {savepath}')
| # ---------------------------------------------------------------------------------------------- | |
def save_np_to_img(img, path, norm=True):
    """Save a single numpy image to `path` with PIL.

    img: numpy array accepted by PIL.Image.fromarray (e.g. [H, W, 3]).
        With norm=True it is assumed to be in [-1, 1] and is mapped to
        uint8 [0, 255]; with norm=False it is saved as given.
    path: output file path; the format is chosen by PIL from the extension.
    norm: whether to apply the [-1, 1] -> [0, 255] mapping.
    """
    if norm:
        img = (img + 1) / 2 * 255
        # Clip first: values slightly outside [-1, 1] would otherwise wrap
        # around when cast to uint8.
        img = np.clip(img, 0, 255).astype(np.uint8)
    image = Image.fromarray(img)
    # Pillow's JPEG writer reads `quality`; the former `q=95` was silently ignored.
    image.save(path, quality=95)
| # ---------------------------------------------------------------------------------------------- | |
def npz_to_imgsheet_5d(data_path, res_dir, nrow=None,):
    """
    Flatten a batch of videos into a list of frames and save them as one image sheet.

    data_path: path to an .npz file whose 'arr_0' entry holds the videos, or the
        array itself (NTHWC).
    res_dir: an existing directory (the sheet is written to <res_dir>/samples.jpg)
        or a full output path ending in '.jpg'.
    nrow: forwarded to savenp2sheet.
    """
    if isinstance(data_path, np.ndarray):
        imgs = data_path
    elif isinstance(data_path, str):
        imgs = np.load(data_path)['arr_0']  # NTHWC
    else:
        raise Exception

    if not os.path.isdir(res_dir):
        assert (res_dir.endswith('.jpg'))
        res_path = res_dir
    else:
        res_path = os.path.join(res_dir, 'samples.jpg')

    # Merge the video and time axes: [N, T, H, W, C] -> [N*T, H, W, C]
    all_frames = np.concatenate(list(imgs), axis=0)
    savenp2sheet(all_frames, res_path, nrow=nrow)
| # ---------------------------------------------------------------------------------------------- | |
def npz_to_imgsheet_4d(data_path, res_path, nrow=None,):
    """
    Save a batch of images as one image sheet.

    data_path: path to an .npz file whose 'arr_0' entry holds the images, or the
        array itself (NHWC).
    res_path: output image path, forwarded to savenp2sheet.
    nrow: forwarded to savenp2sheet.
    """
    if isinstance(data_path, np.ndarray):
        batch = data_path
    elif isinstance(data_path, str):
        batch = np.load(data_path)['arr_0']  # NHWC
    else:
        raise Exception
    print(batch.shape)
    savenp2sheet(batch, res_path, nrow=nrow)
| # ---------------------------------------------------------------------------------------------- | |
def tensor_to_imgsheet(tensor, save_path):
    """
    Save a batch of videos as a single image sheet: one row per video,
    one column per frame.

    tensor: [b,c,t,h,w]
    save_path: output image path handed to torchvision.utils.save_image.
    """
    assert (tensor.dim() == 5)
    frames_per_video = tensor.shape[2]

    frames = []
    for clip in tensor:  # clip: [c, t, h, w]
        # split along time -> t tensors of [c, h, w], kept in video-major order
        frames.extend(clip.unbind(dim=1))

    torchvision.utils.save_image(frames, save_path, normalize=True, nrow=frames_per_video)
| # ---------------------------------------------------------------------------------------------- | |
def npz_to_frames(data_path, res_dir, norm, num_frames=None, num_samples=None):
    """
    Dump every frame of every video in an .npz file to individual jpg files.

    data_path: path to an .npz file whose 'arr_0' entry is [N, T, H, W, 3].
    res_dir: output root; frames go to <res_dir>/videoXXXX/frameXXXX.jpg.
    norm: forwarded to save_np_to_img.
    num_frames: if given, keep only the first num_frames frames of each video.
    num_samples: if given, keep only the first num_samples videos.
    """
    t_begin = time.time()
    imgs = np.load(data_path)['arr_0']  # [N, T, H, W, 3]
    print('original data shape: ', imgs.shape)

    if num_samples is not None:
        imgs = imgs[:num_samples, :, :, :, :]
        print('after sample selection: ', imgs.shape)

    if num_frames is not None:
        imgs = imgs[:, :num_frames, :, :, :]
        print('after frame selection: ', imgs.shape)

    n_videos, n_frames = imgs.shape[0], imgs.shape[1]
    for vid in tqdm(range(n_videos), desc='Video'):
        video_dir = os.path.join(res_dir, f'video{vid:04d}')
        os.makedirs(video_dir, exist_ok=True)
        for fid in range(n_frames):
            frame_path = os.path.join(video_dir, f'frame{fid:04d}.jpg')
            save_np_to_img(imgs[vid, fid, :, :, :], frame_path, norm=norm)  # HW3

    print('Finish')
    print(f'Total time = {time.time()- t_begin}')
| # ---------------------------------------------------------------------------------------------- | |
def npz_to_gifs(data_path, res_dir, duration=0.2, start_idx=0, num_videos=None, mode='gif'):
    """
    Write each video of a batch to its own file in res_dir.

    data_path: path to an .npz file whose 'arr_0' entry holds the videos, or the
        array itself (NTHWC).
    res_dir: output directory (created if missing); files are named
        samples_<start_idx + i>.gif / .mp4.
    duration: forwarded to imageio.mimwrite in 'gif' mode.
    start_idx: offset added to the video index in the file name.
    num_videos: stop after this many videos (None = all).
    mode: 'gif' or 'mp4'; any other value writes nothing.
    """
    os.makedirs(res_dir, exist_ok=True)

    if isinstance(data_path, np.ndarray):
        clips = data_path
    elif isinstance(data_path, str):
        clips = np.load(data_path)['arr_0']  # NTHWC
    else:
        raise Exception

    for idx in range(clips.shape[0]):
        clip = clips[idx]
        frames = [clip[j, :, :, :] for j in range(clip.shape[0])]  # [(h,w,3)]
        tag = start_idx + idx

        if mode == 'gif':
            gif_path = os.path.join(res_dir, f'samples_{tag}.gif')
            imageio.mimwrite(gif_path, frames, format='GIF', duration=duration)
        elif mode == 'mp4':
            as_tensor = torch.stack([torch.from_numpy(f) for f in frames], dim=0)
            as_tensor = as_tensor.to(torch.uint8)  # [T, H, W, C]
            mp4_path = os.path.join(res_dir, f'samples_{tag}.mp4')
            torchvision.io.write_video(mp4_path, as_tensor, fps=0.5,
                                       video_codec='h264', options={'crf': '10'})

        if idx + 1 == num_videos:
            break
| # ---------------------------------------------------------------------------------------------- | |
def fill_with_black_squares(video, desired_len: int) -> Tensor:
    """
    Pad a video along its first (time) axis with all-zero frames.

    video: 4-D tensor whose first axis is time.
    desired_len: minimum number of frames in the result.

    Returns `video` itself when it already has at least `desired_len` frames,
    otherwise a new tensor with zero frames appended at the end.
    """
    n_missing = desired_len - len(video)
    if n_missing <= 0:
        return video

    # One zero frame shaped like the first frame, with a leading time axis.
    blank = torch.zeros_like(video[0])[None]
    padding = blank.repeat(n_missing, 1, 1, 1)
    return torch.cat([video, padding], dim=0)
| # ---------------------------------------------------------------------------------------------- | |
def load_num_videos(data_path, num_videos):
    """
    Return (at most) the first `num_videos` videos of a batch.

    data_path: either a path to an .npz file whose 'arr_0' entry holds the
        videos, or the np array itself (NTHWC).
    num_videos: number of leading videos to keep; None keeps all of them.

    Raises:
        TypeError: if data_path is neither a str nor an np.ndarray.
    """
    if isinstance(data_path, str):
        # Close the archive as soon as the array has been read into memory.
        with np.load(data_path) as archive:
            videos = archive['arr_0']  # NTHWC
    elif isinstance(data_path, np.ndarray):
        videos = data_path
    else:
        raise TypeError(
            f'data_path must be a str path or np.ndarray, got {type(data_path).__name__}'
        )

    if num_videos is not None:
        videos = videos[:num_videos, :, :, :, :]
    return videos
| # ---------------------------------------------------------------------------------------------- | |
def npz_to_video_grid(data_path, out_path, num_frames=None, fps=8, num_videos=None, nrow=None, verbose=True):
    """
    Tile a batch of videos into a grid and write the result as one video file.

    data_path: path to an .npz file ('arr_0' entry) or an np array, [N, T, H, W, C].
    out_path: output video path; its parent directory is created if needed.
    num_frames: videos shorter than this are padded with black frames
        (defaults to T, i.e. no padding).
    fps: frame rate of the written video.
    num_videos: keep only the first num_videos videos (None = all). This is now
        honoured for array input as well as for file input.
    nrow: videos per grid row; defaults to ceil(sqrt(N)).
    verbose: show tqdm progress bars.
    """
    # load_num_videos accepts both a path and an array, and applies num_videos
    # in either case.
    videos = load_num_videos(data_path, num_videos)
    n, t, h, w, c = videos.shape

    # to_tensor converts each HWC frame to a CHW tensor (uint8 input is scaled to [0, 1]).
    videos_th = [torch.stack([to_tensor(frame) for frame in video]) for video in videos]

    if num_frames is None:
        num_frames = t
    padded = [
        fill_with_black_squares(v, num_frames)
        for v in tqdm(videos_th, desc='Adding empty frames', disable=not verbose)
    ]  # NTCHW

    per_frame = torch.stack(padded).permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(n)))
    frame_grids = [
        make_grid(fs, nrow=nrow)
        for fs in tqdm(per_frame, desc='Making grids', disable=not verbose)
    ]

    out_dir = os.path.dirname(out_path)
    if out_dir != "":
        os.makedirs(out_dir, exist_ok=True)

    # Clamp so that out-of-range floats cannot wrap around in the uint8 cast.
    frame_grids = (torch.stack(frame_grids).clamp(0, 1) * 255).to(torch.uint8).permute(0, 2, 3, 1)  # [T, H, W, C]
    torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
| # ---------------------------------------------------------------------------------------------- | |
def npz_to_gif_grid(data_path, out_path, n_cols=None, num_videos=20):
    """
    Tile the first `num_videos` videos of an .npz file into a grid and save it as a GIF.

    data_path: path to an .npz file whose 'arr_0' entry is [N, T, H, W, 3].
    out_path: output GIF path.
    n_cols: videos per grid row; defaults to all videos in a single row.
    num_videos: number of videos to use; the file must contain at least this
        many (checked with an assert).
    """
    # Close the archive once the array has been read.
    with np.load(data_path) as arr:
        imgs = arr['arr_0']  # [N, T, H, W, 3]
    imgs = imgs[:num_videos]
    n, t, h, w, c = imgs.shape
    assert (n == num_videos)

    n_cols = n_cols if n_cols else n
    # Plain Python int: the former np.int8 overflowed in `h * n_rows` for
    # realistic frame heights.
    n_rows = int(np.ceil(n / n_cols))

    grid = np.zeros((t, h * n_rows, w * n_cols, c), dtype=np.uint8)
    for idx in range(n):
        row, col = divmod(idx, n_cols)
        grid[:, row * h:(row + 1) * h, col * w:(col + 1) * w, :] = imgs[idx]

    frames = list(grid)  # grid: TH'W'C -> T frames of H'W'C
    imageio.mimwrite(out_path, frames, format='GIF', duration=0.5, palettesize=256)
| # ---------------------------------------------------------------------------------------------- | |
def torch_to_video_grid(videos, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True):
    """
    Tile a batch of videos into a grid and write the result as one video file.

    videos: torch.Tensor in -1 ~ 1, laid out [N, T, C, H, W]. Each video is
        padded along its first axis and every time step is passed to make_grid
        as [N, C, H, W], so this is the layout the code requires.
    out_path: output video path; its parent directory is created if needed.
    num_frames: videos shorter than this are padded with black (zero) frames.
    fps: frame rate of the written video.
    num_videos: keep only the first num_videos videos (None = all).
    nrow: videos per grid row; defaults to ceil(sqrt(N)).
    verbose: show tqdm progress bars.

    Raises:
        ValueError: if `videos` is not 5-dimensional.
    """
    if videos.dim() != 5:
        raise ValueError(f'expected a 5-D tensor [N, T, C, H, W], got shape {tuple(videos.shape)}')
    if num_videos is not None:
        videos = videos[:num_videos]
    n = videos.shape[0]

    padded = [
        fill_with_black_squares(videos[i, ...], num_frames)
        for i in tqdm(range(n), desc='Adding empty frames', disable=not verbose)
    ]  # NTCHW

    per_frame = torch.stack(padded).permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(n)))
    frame_grids = [
        make_grid(fs, nrow=nrow)
        for fs in tqdm(per_frame, desc='Making grids', disable=not verbose)
    ]

    out_dir = os.path.dirname(out_path)
    if out_dir != "":
        os.makedirs(out_dir, exist_ok=True)

    # Clamp to the documented range first: values just outside [-1, 1] would
    # otherwise wrap around in the uint8 cast.
    stacked = torch.stack(frame_grids).clamp(-1, 1)
    frame_grids = ((stacked + 1) / 2 * 255).to(torch.uint8).permute(0, 2, 3, 1)  # [T, H, W, C]
    torchvision.io.write_video(out_path, frame_grids, fps=fps, video_codec='h264', options={'crf': '10'})
def log_txt_as_img(wh, xc, size=10):
    """
    Render each caption as black text on a white image.

    wh: a tuple of (width, height) for every rendered image.
    xc: a list of captions to plot.
    size: font size passed to ImageFont.truetype.

    Returns a torch tensor of shape [len(xc), 3, height, width] with values
    in [-1, 1].
    """
    # The font and the wrap width are the same for every caption: compute once.
    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
    # Characters per line scale with the image width (40 chars at width 256);
    # keep at least 1 so the range() step below is never 0.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = list()
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank image rather than failing the logging call.
            print("Cant encode string for logging. Skipping.")

        # HWC uint8 [0, 255] -> CHW float [-1, 1]
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)

    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts