Spaces:
Configuration error
Configuration error
| """ | |
| This file contains the MANO definition and mesh sampling operations for the MANO mesh. | |
| Adapted from open-source projects: | |
| MANOPTH (https://github.com/hassony2/manopth) | |
| Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) | |
| GraphCMR (https://github.com/nkolot/GraphCMR/) | |
| """ | |
| from __future__ import division | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import os.path as osp | |
| import json | |
| import code | |
| from custom_manopth.manolayer import ManoLayer | |
| import scipy.sparse | |
| import custom_mesh_graphormer.modeling.data.config as cfg | |
| from pathlib import Path | |
| from comfy.model_management import get_torch_device | |
| from wrapper_for_mps import sparse_to_dense | |
# Module-level torch device obtained from ComfyUI's model management.
# spmm() moves both of its operands onto this device before multiplying.
# NOTE: Mesh.__init__ has its own `device` parameter that shadows this name
# inside that method.
device = get_torch_device()
class MANO(nn.Module):
    """Right-hand MANO model exposing mesh faces and a 21-joint regressor.

    The joint regressor shipped with the MANO layer is extended with one
    one-hot row per fingertip vertex, then reordered to match ``joints_name``.
    """

    # Mesh vertex used as each fingertip joint when extending the regressor,
    # in thumb/index/middle/ring/pinky order.
    # NOTE(review): the middle tip uses vertex 445 here while
    # ``fingertip_vertex_idx`` below lists 444. The mismatch is kept on purpose
    # so the regressor (and every joint it produces) is unchanged from the
    # original code — TODO confirm which vertex is actually intended.
    _REGRESSOR_TIP_VERTEX_IDX = (745, 317, 445, 556, 673)
    # Row order mapping [base regressor rows, appended tip rows] onto
    # ``joints_name``. Assumes the base regressor has 16 rows, so the five
    # appended tip rows are 16-20 (Thumb_4, Index_4, Middle_4, Ring_4, Pinky_4).
    _JOINT_ORDER = (0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18,
                    10, 11, 12, 19, 7, 8, 9, 20)

    def __init__(self):
        super(MANO, self).__init__()
        self.mano_dir = str(Path(__file__).parent / "data")
        self.layer = self.get_layer()
        self.vertex_num = 778
        self.face = self.layer.th_faces.numpy()
        self.joint_num = 21
        self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4',
                            'Index_1', 'Index_2', 'Index_3', 'Index_4',
                            'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4',
                            'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4',
                            'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
        self.skeleton = ((0, 1), (0, 5), (0, 9), (0, 13), (0, 17),
                         (1, 2), (2, 3), (3, 4), (5, 6), (6, 7), (7, 8),
                         (9, 10), (10, 11), (11, 12), (13, 14), (14, 15), (15, 16),
                         (17, 18), (18, 19), (19, 20))
        self.root_joint_idx = self.joints_name.index('Wrist')
        # Fingertip mesh vertex indices (right hand).
        self.fingertip_vertex_idx = [745, 317, 444, 556, 673]
        # Add fingertips to the joint regressor and reorder rows to joints_name.
        self.joint_regressor = self._extend_joint_regressor(
            self.layer.th_J_regressor.numpy(),
            self._REGRESSOR_TIP_VERTEX_IDX,
            self._JOINT_ORDER,
        )
        joint_regressor_torch = torch.from_numpy(self.joint_regressor).float()
        self.register_buffer('joint_regressor_torch', joint_regressor_torch)

    @staticmethod
    def _extend_joint_regressor(joint_regressor, tip_vertex_idx, joint_order):
        """Append one-hot fingertip rows to ``joint_regressor`` and reorder rows.

        Args:
            joint_regressor: 2-D array (joints x vertices).
            tip_vertex_idx: vertex index for each appended one-hot row.
            joint_order: indices selecting/reordering rows of the extended array.

        Returns:
            Array of shape ``(len(joint_order), joint_regressor.shape[1])``.
        """
        tip_vertex_idx = list(tip_vertex_idx)
        tips = np.zeros((len(tip_vertex_idx), joint_regressor.shape[1]),
                        dtype=np.float32)
        # Each fingertip "joint" is exactly one mesh vertex.
        tips[np.arange(len(tip_vertex_idx)), tip_vertex_idx] = 1
        extended = np.concatenate((joint_regressor, tips))
        return extended[list(joint_order), :]

    def get_layer(self):
        """Build the MANO layer from ``self.mano_dir`` (flat_hand_mean=False, use_pca=False)."""
        # load right hand MANO model
        return ManoLayer(mano_root=self.mano_dir, flat_hand_mean=False, use_pca=False)

    def get_3d_joints(self, vertices):
        """
        This method is used to get the joint locations from the MANO mesh
        Input:
            vertices: size = (B, 778, 3)
        Output:
            3D joints: size = (B, 21, 3)
        """
        joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch])
        return joints
class SparseMM(torch.autograd.Function):
    """Redefine sparse @ dense matrix multiplication to enable backpropagation.

    The builtin matrix multiplication operation does not support backpropagation
    in some cases. Only ``dense`` receives a gradient; ``sparse`` is treated as
    a constant (its gradient is always ``None``).
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # the non-static (legacy) form is deprecated.
    @staticmethod
    def forward(ctx, sparse, dense):
        """Return ``sparse @ dense`` and stash ``sparse`` for the backward pass."""
        ctx.req_grad = dense.requires_grad
        ctx.save_for_backward(sparse)
        return torch.matmul(sparse, dense)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients ``(None, sparse^T @ grad_output)``; the second is ``None``
        when ``dense`` did not require grad."""
        grad_input = None
        sparse, = ctx.saved_tensors
        if ctx.req_grad:
            grad_input = torch.matmul(sparse.t(), grad_output)
        return None, grad_input
def spmm(sparse, dense):
    """Compute ``sparse @ dense`` through SparseMM after moving both operands
    onto the module-level ``device``."""
    operands = [tensor.to(device) for tensor in (sparse, dense)]
    return SparseMM.apply(*operands)
def _scipy_to_dense_tensor(mat):
    """Convert one scipy matrix to a torch tensor via a COO tensor and sparse_to_dense."""
    coo = scipy.sparse.coo_matrix(mat)
    indices = torch.LongTensor(np.array([coo.row, coo.col]))
    values = torch.FloatTensor(coo.data)
    return sparse_to_dense(torch.sparse_coo_tensor(indices, values, coo.shape))


def scipy_to_pytorch(A, U, D):
    """Convert scipy sparse up/downsampling matrices to pytorch tensors.

    Each matrix is built as a torch sparse COO tensor and then passed through
    ``sparse_to_dense`` (from ``wrapper_for_mps``), whose result is returned.

    Args:
        A: adjacency matrices. Unused here; kept for signature compatibility
            (get_graph_params converts adjacency separately via adjmat_sparse).
        U: sequence of upsampling matrices.
        D: sequence of downsampling matrices.

    Returns:
        (ptU, ptD): lists with one converted tensor per input matrix.
    """
    ptU = [_scipy_to_dense_tensor(u) for u in U]
    ptD = [_scipy_to_dense_tensor(d) for d in D]
    return ptU, ptD
def adjmat_sparse(adjmat, nsize=1):
    """Create row-normalized sparse graph adjacency matrix.

    Args:
        adjmat: adjacency matrix (anything ``scipy.sparse.csr_matrix`` accepts).
        nsize: when > 1, the adjacency is raised to the ``nsize``-th matrix
            power and binarized before self-loops are added.

    Returns:
        The row-normalized adjacency (self-loops included), converted from a
        torch sparse COO tensor through ``sparse_to_dense``.
    """
    adjmat = scipy.sparse.csr_matrix(adjmat)
    if nsize > 1:
        orig_adjmat = adjmat.copy()
        for _ in range(1, nsize):
            # Matrix product (explicit `@`; `*` is only matmul for spmatrix types).
            adjmat = adjmat @ orig_adjmat
        adjmat.data = np.ones_like(adjmat.data)
    # Set every diagonal entry to 1 in one vectorized step. LIL handles
    # sparsity-structure changes cheaply; element-wise assignment into CSR
    # is slow and emits SparseEfficiencyWarning.
    adjmat = adjmat.tolil()
    adjmat.setdiag(1)
    adjmat = adjmat.tocsr()
    # Divide each row by its sum so rows sum to 1.
    num_neighbors = np.array(1 / adjmat.sum(axis=-1))
    adjmat = adjmat.multiply(num_neighbors)
    adjmat = scipy.sparse.coo_matrix(adjmat)
    indices = torch.LongTensor(np.array([adjmat.row, adjmat.col]))
    values = torch.from_numpy(adjmat.data).float()
    return sparse_to_dense(torch.sparse_coo_tensor(indices, values, adjmat.shape))
def get_graph_params(filename, nsize=1):
    """Load and process graph adjacency matrix and upsampling/downsampling matrices.

    Args:
        filename: path to a file readable by ``np.load`` that provides 'A',
            'U' and 'D' entries (presumably object arrays of scipy matrices,
            hence ``allow_pickle`` — confirm against the data file).
        nsize: neighbourhood size forwarded to ``adjmat_sparse``.

    Returns:
        (A, U, D): converted adjacency, upsampling and downsampling matrices.
    """
    # NOTE(security): allow_pickle=True can execute arbitrary code embedded in
    # the file; only load trusted, bundled sampling matrices here.
    # encoding='latin1' lets numpy read pickles written by Python 2.
    data = np.load(filename, encoding='latin1', allow_pickle=True)
    try:
        A = data['A']
        U = data['U']
        D = data['D']
    finally:
        # For .npz archives np.load returns an NpzFile holding an open file
        # handle; close it once the arrays have been read into memory.
        close = getattr(data, 'close', None)
        if callable(close):
            close()
    U, D = scipy_to_pytorch(A, U, D)
    A = [adjmat_sparse(a, nsize=nsize) for a in A]
    return A, U, D
class Mesh(object):
    """Mesh object that is used for handling certain graph operations.

    Holds the adjacency matrices (``_A``) plus the upsampling (``_U``) and
    downsampling (``_D``) operators returned by ``get_graph_params``, and
    applies the sampling operators to per-vertex features with ``spmm``.
    """

    def __init__(self, filename=cfg.MANO_sampling_matrix,
                 num_downsampling=1, nsize=1, device=torch.device('cuda')):
        adjacency, upsamplers, downsamplers = get_graph_params(filename=filename, nsize=nsize)
        # Only the sampling operators are moved to `device`; the adjacency
        # matrices are kept exactly as get_graph_params returned them.
        # NOTE(review): with the 'cuda' default, `.to(device)` fails on hosts
        # without CUDA unless callers pass `device` explicitly.
        self._A = adjacency
        self._U = [op.to(device) for op in upsamplers]
        self._D = [op.to(device) for op in downsamplers]
        self.num_downsampling = num_downsampling

    @staticmethod
    def _apply_operators(operators, order, x):
        """Left-multiply ``x`` by ``operators[k]`` for each ``k`` in ``order``.

        Inputs with fewer than 3 dims are transformed directly. 3-D inputs are
        treated as a batch along dim 0, transformed item by item and re-stacked.
        Inputs with more than 3 dims are returned unchanged.
        """
        rank = x.ndimension()
        if rank < 3:
            for k in order:
                x = spmm(operators[k], x)
            return x
        if rank == 3:
            items = [Mesh._apply_operators(operators, order, x[b]) for b in range(x.shape[0])]
            return torch.stack(items, dim=0)
        # NOTE(review): higher-rank inputs silently pass through untouched.
        return x

    def downsample(self, x, n1=0, n2=None):
        """Downsample mesh: apply D[n1], ..., D[n2 - 1] in order
        (``n2`` defaults to ``self.num_downsampling``)."""
        stop = self.num_downsampling if n2 is None else n2
        return self._apply_operators(self._D, list(range(n1, stop)), x)

    def upsample(self, x, n1=1, n2=0):
        """Upsample mesh: apply U[n1 - 1], ..., U[n2], highest index first."""
        return self._apply_operators(self._U, list(reversed(range(n2, n1))), x)