| """ | |
| Copyright (c) Microsoft Corporation. | |
| Licensed under the MIT license. | |
| End-to-end inference codes for | |
| 3D hand mesh reconstruction from an image | |
| """ | |
from __future__ import absolute_import, division, print_function
import argparse
import os
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import gc
import numpy as np
import cv2
from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer
from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
from custom_mesh_graphormer.modeling._mano import MANO, Mesh
from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config
from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config
import custom_mesh_graphormer.modeling.data.config as cfg
from custom_mesh_graphormer.datasets.build import make_hand_data_loader
from custom_mesh_graphormer.utils.logger import setup_logger
from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed
from custom_mesh_graphormer.utils.metric_logger import AverageMeter
from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error
from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection
from PIL import Image
from torchvision import transforms
from comfy.model_management import get_torch_device

device = get_torch_device()

# Model input preprocessing: resize/crop to 224x224 and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])])

# Same geometry as above, but without normalization, for visualization.
transform_visualize = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()])
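
# A minimal preprocessing sketch (assumption: a local RGB image 'hand.jpg');
# this mirrors what run_inference() below does for each image:
#
#   img = Image.open('hand.jpg').convert('RGB')
#   batch_imgs = transform(img).unsqueeze(0).to(device)                  # [1, 3, 224, 224], normalized
#   batch_visual_imgs = transform_visualize(img).unsqueeze(0).to(device) # [1, 3, 224, 224], un-normalized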

def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler):
    # switch to evaluate mode
    Graphormer_model.eval()
    mano.eval()
    with torch.no_grad():
        for image_file in image_list:
            if 'pred' not in image_file:
                att_all = []
                print(image_file)
                img = Image.open(image_file).convert('RGB')  # force 3 channels (e.g. for RGBA PNGs)
                img_tensor = transform(img)
                img_visual = transform_visualize(img)

                batch_imgs = torch.unsqueeze(img_tensor, 0).to(device)
                batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device)

                # forward pass
                pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)

                # obtain 3d joints from the full mesh, then center joints and vertices at the wrist
                pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
                pred_3d_pelvis = pred_3d_joints_from_mesh[:, cfg.J_NAME.index('Wrist'), :]
                pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
                pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]

                # save attention
                att_max_value = att[-1]
                att_cpu = np.asarray(att_max_value.cpu().detach())
                att_all.append(att_cpu)

                # obtain 3d joints, which are regressed from the full mesh
                pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
                # obtain 2d joints, which are projected from 3d joints of the mesh
                pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
                pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())

                visual_imgs_output = visualize_mesh(renderer, batch_visual_imgs[0],
                                                    pred_vertices[0].detach(),
                                                    pred_camera.detach())
                # visual_imgs_output = visualize_mesh_and_attention(renderer, batch_visual_imgs[0],
                #                                                   pred_vertices[0].detach(),
                #                                                   pred_vertices_sub[0].detach(),
                #                                                   pred_2d_coarse_vertices_from_mesh[0].detach(),
                #                                                   pred_2d_joints_from_mesh[0].detach(),
                #                                                   pred_camera.detach(),
                #                                                   att[-1][0].detach())

                visual_imgs = visual_imgs_output.transpose(1, 2, 0)
                visual_imgs = np.asarray(visual_imgs)

                temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
                print('save to ', temp_fname)
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:, :, ::-1] * 255))
    return

def visualize_mesh(renderer, images,
                   pred_vertices_full,
                   pred_camera):
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predicted vertices for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    # Visualize only mesh reconstruction
    rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img

def visualize_mesh_and_attention(renderer, images,
                                 pred_vertices_full,
                                 pred_vertices,
                                 pred_2d_vertices,
                                 pred_2d_joints,
                                 pred_camera,
                                 attention):
    img = images.cpu().numpy().transpose(1, 2, 0)
    # Get predicted vertices, joints, camera, and attention for the particular example
    vertices_full = pred_vertices_full.cpu().numpy()
    vertices = pred_vertices.cpu().numpy()
    vertices_2d = pred_2d_vertices.cpu().numpy()
    joints_2d = pred_2d_joints.cpu().numpy()
    cam = pred_camera.cpu().numpy()
    att = attention.cpu().numpy()
    # Visualize reconstruction and attention
    rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
    rend_img = rend_img.transpose(2, 0, 1)
    return rend_img

def parse_args():
    parser = argparse.ArgumentParser()
    #########################################################
    # Data related arguments
    #########################################################
    parser.add_argument("--num_workers", default=4, type=int,
                        help="Workers in dataloader.")
    parser.add_argument("--img_scale_factor", default=1, type=int,
                        help="Adjust image resolution.")
    parser.add_argument("--image_file_or_path", default='./samples/hand', type=str,
                        help="Test image or directory of test images.")
    #########################################################
    # Loading/saving checkpoints
    #########################################################
    parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
                        help="Path to pre-trained transformer model or model type.")
    parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                        help="Path to specific checkpoint for resume training.")
    parser.add_argument("--output_dir", default='output/', type=str, required=False,
                        help="The output directory to save checkpoint and test results.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name.")
    parser.add_argument('-a', '--arch', default='hrnet-w64',
                        help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
    #########################################################
    # Model architectures
    #########################################################
    parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
                        help="Update model config if given. Note that hidden_size "
                             "must be divisible by num_attention_heads.")
    parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
                        help="Update model config if given.")
    parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
                        help="Input image feature dimension of each encoder block.")
    parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
                        help="Hidden feature dimension of each encoder block.")
    parser.add_argument("--which_gcn", default='0,0,1', type=str,
                        help="Which encoder blocks use graph conv (Encoder1, Encoder2, Encoder3). "
                             "Default: only Encoder3 has graph conv.")
    parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")
    #########################################################
    # Others
    #########################################################
    parser.add_argument("--run_eval_only", default=True, action='store_true')
    parser.add_argument("--device", type=str, default='cuda',
                        help="cuda or cpu")
    parser.add_argument('--seed', type=int, default=88,
                        help="random seed for initialization.")
    args = parser.parse_args()
    return args

def main(args):
    global logger
    # Setup CUDA, GPU & distributed training
    args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
    print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))

    mkdir(args.output_dir)
    logger = setup_logger("Graphormer", args.output_dir, get_rank())
    set_seed(args.seed, args.num_gpus)
    logger.info("Using {} GPUs".format(args.num_gpus))

    # Mesh and MANO utils
    mano_model = MANO().to(args.device)
    mano_model.layer = mano_model.layer.to(device)
    mesh_sampler = Mesh()

    # Renderer for visualization
    renderer = Renderer(faces=mano_model.face)

    # Load pretrained model
    trans_encoder = []

    input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
    hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
    output_feat_dim = input_feat_dim[1:] + [3]

    # which encoder blocks have graph convs
    which_blk_graph = [int(item) for item in args.which_gcn.split(',')]

    if args.run_eval_only and args.resume_checkpoint is not None and args.resume_checkpoint != 'None' and 'state_dict' not in args.resume_checkpoint:
        # if only running eval, load the full checkpoint
        logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
        _model = torch.load(args.resume_checkpoint)
    else:
        # init three transformer-encoder blocks in a loop
        for i in range(len(output_feat_dim)):
            config_class, model_class = BertConfig, Graphormer
            config = config_class.from_pretrained(args.config_name if args.config_name
                                                  else args.model_name_or_path)

            config.output_attentions = False
            config.img_feature_dim = input_feat_dim[i]
            config.output_feature_dim = output_feat_dim[i]
            args.hidden_size = hidden_feat_dim[i]
            args.intermediate_size = int(args.hidden_size * 2)

            if which_blk_graph[i] == 1:
                config.graph_conv = True
                logger.info("Add Graph Conv")
            else:
                config.graph_conv = False

            config.mesh_type = args.mesh_type

            # update model structure if specified in arguments
            update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
            for idx, param in enumerate(update_params):
                arg_param = getattr(args, param)
                config_param = getattr(config, param)
                if arg_param > 0 and arg_param != config_param:
                    logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
                    setattr(config, param, arg_param)

            # init a transformer encoder and append it to a list
            assert config.hidden_size % config.num_attention_heads == 0
            model = model_class(config=config)
            logger.info("Init model from scratch.")
            trans_encoder.append(model)

        # create backbone model
        if args.arch == 'hrnet':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w40 model')
        elif args.arch == 'hrnet-w64':
            hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
            hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
            hrnet_update_config(hrnet_config, hrnet_yaml)
            backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
            logger.info('=> loading hrnet-v2-w64 model')
        else:
            print("=> using pre-trained model '{}'".format(args.arch))
            backbone = models.__dict__[args.arch](pretrained=True)
            # remove the last fc layer
            backbone = torch.nn.Sequential(*list(backbone.children())[:-1])

        trans_encoder = torch.nn.Sequential(*trans_encoder)
        total_params = sum(p.numel() for p in trans_encoder.parameters())
        logger.info('Graphormer encoders total parameters: {}'.format(total_params))
        backbone_total_params = sum(p.numel() for p in backbone.parameters())
        logger.info('Backbone total parameters: {}'.format(backbone_total_params))

        # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
        _model = Graphormer_Network(args, config, backbone, trans_encoder)

        if args.resume_checkpoint is not None and args.resume_checkpoint != 'None':
            # for fine-tuning, resuming training, or inference, load weights from checkpoint
            logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
            # workaround approach to load sparse tensor in graph conv.
            state_dict = torch.load(args.resume_checkpoint)
            _model.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    # update configs to enable attention outputs
    setattr(_model.trans_encoder[-1].config, 'output_attentions', True)
    setattr(_model.trans_encoder[-1].config, 'output_hidden_states', True)
    _model.trans_encoder[-1].bert.encoder.output_attentions = True
    _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
    for iter_layer in range(4):
        _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
    for inter_block in range(3):
        setattr(_model.trans_encoder[-1].config, 'device', args.device)

    _model.to(args.device)
    logger.info("Run inference")

    image_list = []
    if not args.image_file_or_path:
        raise ValueError("image_file_or_path not specified")
    if op.isfile(args.image_file_or_path):
        image_list = [args.image_file_or_path]
    elif op.isdir(args.image_file_or_path):
        # should be a path with images only
        for filename in os.listdir(args.image_file_or_path):
            # parentheses matter here: previously rendered '*_pred.png' outputs must be skipped too
            if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename:
                image_list.append(op.join(args.image_file_or_path, filename))
    else:
        raise ValueError("Cannot find images at {}".format(args.image_file_or_path))

    run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)
| if __name__ == "__main__": | |
| args = parse_args() | |
| main(args) | |