Spaces:
Sleeping
Sleeping
File size: 6,095 Bytes
dcacefd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
#!/usr/bin/python
# -*- coding:utf-8 -*-
import torch
import numpy as np
from scipy.spatial.transform import Rotation
# from https://github.com/charnley/rmsd/blob/master/rmsd/calculate_rmsd.py
def kabsch_rotation(P, Q):
    """
    Optimal rotation (Kabsch algorithm) for two paired, pre-centered point sets.

    Both P and Q must already be translated so their centroids sit at the
    origin; this function only performs the remaining two steps:
    - build the covariance matrix C = P^T Q
    - take its SVD and assemble the rotation, flipping the last singular
      vector when needed so the result is a proper (right-handed) rotation.
    For more info see http://en.wikipedia.org/wiki/Kabsch_algorithm

    Parameters
    ----------
    P : array
        (N,D) matrix, where N is points and D is dimension.
    Q : array
        (N,D) matrix, where N is points and D is dimension.

    Returns
    -------
    U : matrix
        Rotation matrix (D,D) such that P @ U best matches Q.
    """
    covariance = np.transpose(P).dot(Q)
    left, _, right = np.linalg.svd(covariance)

    # A negative det(left) * det(right) means the SVD solution is a reflection;
    # negating the last left singular vector turns it into a proper rotation.
    needs_flip = np.linalg.det(left) * np.linalg.det(right) < 0.0
    if needs_flip:
        left[:, -1] *= -1

    return left.dot(right)
# have been validated with kabsch from RefineGNN
def kabsch(a, b):
    """
    Rigidly superimpose point set `a` onto point set `b`.

    Finds the rotation R and translation t minimizing ||a R + t - b||,
    where both inputs are paired [N, 3] coordinate arrays.

    Returns
    -------
    (a_aligned, rotation, t) where a_aligned = a @ rotation + t.
    """
    src = np.array(a)
    dst = np.array(b)

    src_centroid = src.mean(axis=0)
    dst_centroid = dst.mean(axis=0)

    # Rotation is solved on the centered clouds ...
    rotation = kabsch_rotation(src - src_centroid, dst - dst_centroid)
    # ... and the translation moves the rotated source centroid onto the target's.
    translation = dst_centroid - src_centroid.dot(rotation)

    return src.dot(rotation) + translation, rotation, translation
# a: [N, 3], b: [N, 3]
def compute_rmsd(a, b, aligned=False):  # amino acids level rmsd
    """
    Root-mean-square deviation between two paired point sets.

    Parameters
    ----------
    a, b : array_like, shape (N, D)
        Paired coordinates (one point per amino acid at this level).
    aligned : bool
        If True, `a` is taken as already superimposed on `b`. Otherwise `a`
        is first optimally rotated and translated onto `b` with `kabsch`.

    Returns
    -------
    float
        sqrt(sum_i ||a_i - b_i||^2 / N)
    """
    # Accept lists/tuples as well as arrays: the subtraction and `.shape`
    # below require ndarrays.
    a = np.asarray(a)
    b = np.asarray(b)
    if aligned:
        a_aligned = a
    else:
        a_aligned, _, _ = kabsch(a, b)
    sq_dist = np.sum((a_aligned - b) ** 2, axis=-1)
    return float(np.sqrt(sq_dist.sum() / a.shape[0]))
def _svd_is_degenerate(S):
    """Return True if singular values S are too small or too close together.

    Flags a near-zero singular value or two nearly equal squared singular
    values; the requires_grad path of kabsch_torch perturbs H until neither holds.
    """
    dim = S.shape[-1]
    S2 = S ** 2
    # Identity is added so the zero self-differences on the diagonal are ignored.
    gaps = torch.abs(S2.view(1, dim) - S2.view(dim, 1)
                     + torch.eye(dim, device=S.device, dtype=S.dtype))
    return bool(torch.min(S) < 1e-3 or torch.min(gaps) < 1e-2)


def kabsch_torch(A, B, requires_grad=False):
    """
    Rigid (proper rotation + translation) alignment of A onto B with known correspondences.

    See: https://en.wikipedia.org/wiki/Kabsch_algorithm
    Works in any dimension D (typically 2-D or 3-D). Registration occurs in
    the zero-centered coordinate system and is then transported back.
    Reflections are excluded: the returned R is corrected so that
    det(R) = +1, so a mirrored target cannot be matched exactly.

    Args:
        - A: Torch tensor of shape (N,D) -- Point Cloud to Align (source)
        - B: Torch tensor of shape (N,D) -- Reference Point Cloud (target)
        - requires_grad: if True, randomly perturb the diagonal of the
          covariance matrix until its singular values are non-degenerate
          (see _svd_is_degenerate), for a stable SVD.
    Returns:
        - A_aligned: (N,D) tensor, every row a of A mapped to R @ a + t
        - R: (D,D) optimal rotation
        - t: (D,) optimal translation
    Raises:
        - AssertionError: requires_grad and the covariance matrix has NaNs
        - RuntimeError: requires_grad and the SVD stays degenerate after
          repeated perturbation

    Example (rotation + translation):
    >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]])
    >>> R0 = torch.tensor([[0.6, -0.8], [0.8, 0.6]])
    >>> B = A @ R0.T + torch.tensor([3., 3.])
    >>> A_aligned, R, t = kabsch_torch(A, B)
    >>> bool(torch.allclose(A_aligned, B, atol=1e-5))
    True
    """
    a_mean = A.mean(dim=0)
    b_mean = B.mean(dim=0)
    A_c = A - a_mean
    B_c = B - b_mean
    # Covariance matrix (D, D)
    H = A_c.T.mm(B_c)
    dim = H.shape[0]

    if requires_grad:  # try more times to find a stable solution
        assert not torch.isnan(H).any()
        U, S, Vt = torch.linalg.svd(H)
        num_it = 0
        while _svd_is_degenerate(S):
            # Random diagonal jitter, matched to H's device and dtype.
            H = H + torch.diag(torch.rand(dim, device=H.device, dtype=H.dtype))
            U, S, Vt = torch.linalg.svd(H)
            num_it += 1
            if num_it > 10:
                raise RuntimeError('SVD consistently numerically unstable! Exitting ... ')
    else:
        U, S, Vt = torch.linalg.svd(H)
    V = Vt.T

    # det(U) * det(V) < 0 means V U^T would be a reflection; flipping the
    # last left singular vector yields the best proper rotation instead.
    if torch.linalg.det(U) * torch.linalg.det(V) < 0.0:
        SS = torch.ones(dim, device=U.device, dtype=U.dtype)
        SS[-1] = -1.0
        U = U @ torch.diag(SS)

    # Rotation matrix
    R = V.mm(U.T)
    # Translation vector: moves the rotated source centroid onto the target centroid.
    t = b_mean - R @ a_mean
    return R.mm(A.T).T + t, R, t
def batch_kabsch_torch(A, B):
    '''
    Batched Kabsch alignment: rigidly align each source cloud A[i] onto the
    paired target cloud B[i] using a proper rotation (det(R) = +1) and a translation.

    Args:
        A: [bs, N, D] source point clouds (D is typically 3)
        B: [bs, N, D] target point clouds
    Returns:
        A_aligned: [bs, N, D], each point a of A[i] mapped to R[i] @ a + t[i]
        R: [bs, D, D] optimal rotations
        t: [bs, 1, D] optimal translations
    '''
    a_mean = A.mean(dim=1, keepdim=True)  # [bs, 1, D]
    b_mean = B.mean(dim=1, keepdim=True)  # [bs, 1, D]
    A_c = A - a_mean
    B_c = B - b_mean
    # Covariance matrix
    H = torch.bmm(A_c.transpose(1, 2), B_c)  # [bs, D, D]
    U, S, Vt = torch.linalg.svd(H)  # [bs, D, D]
    V = Vt.transpose(1, 2)

    # Reflection correction: where det(U) * det(V) < 0, negate the last column
    # of U. The sign vector is sized by the matrix dimension D = U.shape[-1],
    # not len(U) (which is the batch size).
    reflect = torch.linalg.det(U) * torch.linalg.det(V) < 0.0  # [bs]
    sign = torch.ones(U.shape[0], U.shape[-1], device=U.device, dtype=U.dtype)
    sign[:, -1] = 1.0 - 2.0 * reflect.to(U.dtype)
    U = U * sign.unsqueeze(1)  # same as U @ diag(sign) for each batch element

    # Rotation matrix
    R = torch.bmm(V, U.transpose(1, 2))  # [bs, D, D]
    # Translation vector
    t = b_mean - torch.bmm(R, a_mean.transpose(1, 2)).transpose(1, 2)
    A_aligned = torch.bmm(R, A.transpose(1, 2)).transpose(1, 2) + t
    return A_aligned, R, t