import torch
import matplotlib
import numpy as np
import scipy.stats as sp_stats

# Element symbol -> integer code. The codes are the elements' atomic numbers
# (H = 1, C = 6, ...); ``atom_decoder`` is the inverse mapping (code -> symbol).
atom_encoder = {'H': 1, 'C': 6, 'N': 7, 'O': 8, 'F': 9, 'P': 15, 'S': 16, 'Cl': 17}
atom_decoder = {v: k for k, v in atom_encoder.items()}

# Bond lengths from http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html
# bonds1 / bonds2 / bonds3 hold single / double / triple bond lengths, indexed
# as bondsN[atom1][atom2]. The values appear to be in picometres (H-H = 74).
# In bonds2 and bonds3, -1 is a sentinel for "no bond of this order tabulated".
bonds1 = {'H': {'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92, 'P': 144, 'S': 134, 'Cl': 127},
          'C': {'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135, 'P': 184, 'S': 182, 'Cl': 177},
          'N': {'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136, 'P': 177, 'S': 168, 'Cl': 175},
          'O': {'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142, 'P': 163, 'S': 151, 'Cl': 164},
          'F': {'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142, 'P': 156, 'S': 158, 'Cl': 166},
          'P': {'H': 144, 'C': 184, 'N': 177, 'O': 163, 'F': 156, 'P': 221, 'S': 210, 'Cl': 203},
          'S': {'H': 134, 'C': 182, 'N': 168, 'O': 151, 'F': 158, 'P': 210, 'S': 204, 'Cl': 207},
          'Cl': {'H': 127, 'C': 177, 'N': 175, 'O': 164, 'F': 166, 'P': 203, 'S': 207, 'Cl': 199}
          }

bonds2 = {'H': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'C': {'H': -1, 'C': 134, 'N': 129, 'O': 120, 'F': -1, 'P': -1, 'S': 160, 'Cl': -1},
          'N': {'H': -1, 'C': 129, 'N': 125, 'O': 121, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'O': {'H': -1, 'C': 120, 'N': 121, 'O': 121, 'F': -1, 'P': 150, 'S': -1, 'Cl': -1},
          'F': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'P': {'H': -1, 'C': -1, 'N': -1, 'O': 150, 'F': -1, 'P': -1, 'S': 186, 'Cl': -1},
          'S': {'H': -1, 'C': 160, 'N': -1, 'O': -1, 'F': -1, 'P': 186, 'S': -1, 'Cl': -1},
          'Cl': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          }

bonds3 = {'H': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'C': {'H': -1, 'C': 120, 'N': 116, 'O': 113, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'N': {'H': -1, 'C': 116, 'N': 110, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'O': {'H': -1, 'C': 113, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'F': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'P': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'S': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          'Cl': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
          }

# Per-element values; not referenced by the functions in this module.
# NOTE(review): purpose unclear from this file — confirm with other users.
stdv = {'H': 5, 'C': 1, 'N': 1, 'O': 2, 'F': 3}

# Tolerances added to the tabulated single / double / triple bond lengths when
# get_bond_order decides the bond order (same units as the bond tables).
margin1, margin2, margin3 = 10, 5, 3

# Allowed total bond order per element; check_stability compares each atom's
# summed bond orders against this value.
allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'P': 5, 'S': 4, 'Cl': 1}


def normalize_histogram(hist):
    """Return ``hist`` as a numpy array rescaled so that it sums to 1.

    An all-zero histogram yields NaNs (0 / 0).
    """
    hist = np.array(hist)
    prob = hist / np.sum(hist)
    return prob


def coord2distances(x):
    """Return every pairwise Euclidean distance in ``x`` as a flat tensor.

    ``x`` must be a 3-D tensor; presumably (batch, n_nodes, n_dims) — TODO
    confirm with callers. For each batch entry the full n_nodes x n_nodes
    distance matrix is computed, including the zero diagonal and both (i, j)
    and (j, i). The result is flattened to shape (batch * n_nodes * n_nodes,).
    """
    x = x.unsqueeze(2)               # (B, N, 1, D)
    x_t = x.transpose(1, 2)          # (B, 1, N, D)
    dist = (x - x_t) ** 2            # broadcasts to (B, N, N, D)
    dist = torch.sqrt(torch.sum(dist, 3))
    dist = dist.flatten()
    return dist


def earth_mover_distance(h1, h2):
    """Return the 1-D earth mover's (Wasserstein-1) distance between two histograms.

    Each histogram is normalised to a probability vector, and its mass is
    placed at the bin positions 0, 1, ..., len - 1. The distance is therefore
    expressed in units of bins. Multiply by the bin width to convert it to
    the binned quantity's units.
    """
    p1 = normalize_histogram(h1)
    p2 = normalize_histogram(h2)
    # wasserstein_distance(u_values, v_values, u_weights, v_weights): the
    # histogram entries are *weights* at the bin positions, not samples.
    # Passing them as values (as before) compared the multisets of probability
    # values, e.g. [0.2, 0.8] vs [0.8, 0.2] gave 0.
    bins1 = np.arange(len(p1))
    bins2 = np.arange(len(p2))
    distance = sp_stats.wasserstein_distance(bins1, bins2, u_weights=p1, v_weights=p2)
    return distance


def kl_divergence(p1, p2):
    """Return KL(p1 || p2) = sum(p1 * log(p1 / p2)), using the natural log.

    Both arguments should be strictly positive arrays of the same shape,
    because zeros produce inf/NaN. kl_divergence_sym adds a small epsilon
    for this reason.
    """
    return np.sum(p1 * np.log(p1 / p2))


def kl_divergence_sym(h1, h2):
    """Return the symmetrised KL divergence (mean of both directions) between two histograms.

    Each histogram is normalised and then offset by 1e-10, so empty bins do
    not produce log(0). The offset vectors are not renormalised.
    """
    p1 = normalize_histogram(h1) + 1e-10
    p2 = normalize_histogram(h2) + 1e-10
    kl = kl_divergence(p1, p2)
    kl_flipped = kl_divergence(p2, p1)
    return (kl + kl_flipped) / 2.
def js_divergence(h1, h2):
    """Return the Jensen-Shannon divergence between two histograms.

    Each histogram is normalised and then offset by 1e-10, so empty bins do
    not produce log(0). The result uses the natural log.
    """
    p1 = normalize_histogram(h1) + 1e-10
    p2 = normalize_histogram(h2) + 1e-10
    mixture = (p1 + p2) / 2
    js = (kl_divergence(p1, mixture) + kl_divergence(p2, mixture)) / 2
    return js


def get_bond_order(atom1, atom2, distance):
    """Return the bond order (0 = no bond, 1, 2 or 3) between two atoms.

    Args:
        atom1, atom2: element symbols (keys of bonds1 / bonds2 / bonds3).
        distance: interatomic distance. It is multiplied by 100 to match the
            bond tables, which appear to be in picometres, so it is presumably
            given in Angstrom — TODO confirm.

    Order k is returned when the distance is below the tabulated order-k
    length plus margin k. Higher orders are only tested once the lower
    order matched.

    NOTE(review): bonds2 / bonds3 use -1 for "no such bond", so the thresholds
    become -1 + margin2 = 4 and -1 + margin3 = 2. Near-coincident atoms (scaled
    distance below those values) therefore still get order 2 or 3 for pairs
    with no such bond. This is kept unchanged because the margins are tuned
    for the stability metric.
    """
    distance = 100 * distance  # We change the metric
    # margin1, margin2 and margin3 have been tuned to maximize the stability of the QM9 true samples
    if distance < bonds1[atom1][atom2] + margin1:
        thr_bond2 = bonds2[atom1][atom2] + margin2
        if distance < thr_bond2:
            thr_bond3 = bonds3[atom1][atom2] + margin3
            if distance < thr_bond3:
                return 3
            return 2
        return 1
    return 0


def check_stability(positions, atom_type, debug=False, hs=False, return_nr_bonds=False):
    """Assign bonds from pairwise distances and check every atom's total bond order.

    Args:
        positions: array of shape (n_atoms, 3) with atomic coordinates; the
            shape is enforced by the asserts below.
        atom_type: per-atom integer codes, decoded through ``atom_decoder``.
        debug: print a message for every atom whose bond count is invalid.
        hs: if True, an atom is stable only when its total bond order equals
            ``allowed_bonds`` exactly. Otherwise it must lie in
            [1, allowed_bonds].
        return_nr_bonds: also return the per-atom total bond order array.

    Returns:
        ``(molecule_stable, n_stable_atoms, n_atoms)``, plus ``nr_bonds`` when
        ``return_nr_bonds`` is True. The molecule is stable when every atom is.
    """
    positions = np.asarray(positions)
    assert len(positions.shape) == 2
    assert positions.shape[1] == 3
    n_atoms = len(positions)

    # Sum bond orders over all unordered atom pairs (O(n^2)).
    nr_bonds = np.zeros(n_atoms, dtype='int')
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            dist = np.sqrt(np.sum((positions[i] - positions[j]) ** 2))
            atom1 = atom_decoder[atom_type[i]]
            atom2 = atom_decoder[atom_type[j]]
            order = get_bond_order(atom1, atom2, dist)
            nr_bonds[i] += order
            nr_bonds[j] += order

    # Count atoms whose total bond order is acceptable for their element.
    n_stable_atoms = 0
    for atom_type_i, nr_bonds_i in zip(atom_type, nr_bonds):
        symbol = atom_decoder[atom_type_i]
        max_bonds = allowed_bonds[symbol]
        if hs:
            is_stable = max_bonds == nr_bonds_i
        else:
            is_stable = max_bonds >= nr_bonds_i > 0
        if not is_stable and debug:
            print("Invalid bonds for molecule %s with %d bonds" % (symbol, nr_bonds_i))
        n_stable_atoms += int(is_stable)

    molecule_stable = n_stable_atoms == n_atoms
    if return_nr_bonds:
        return molecule_stable, n_stable_atoms, n_atoms, nr_bonds
    return molecule_stable, n_stable_atoms, n_atoms


def analyze_stability_for_molecules(molecule_list):
    """Compute molecule- and atom-level stability fractions over a list of samples.

    Args:
        molecule_list: sequence of ``(one_hot, x)`` tensor pairs.
            ``one_hot.argmax(2)`` requires at least 3 dims, and ``.squeeze(0)``
            drops a leading size-1 dim. Presumably each entry is a single
            molecule with a batch dim of 1 — TODO confirm with callers.

    Returns:
        ``(validity_dict, stable_molecules)``.

        - ``validity_dict['mol_stable']`` is the fraction of stable molecules.
        - ``validity_dict['atm_stable']`` is the fraction of stable atoms.
        - Each fraction is NaN when its denominator is 0 (no molecules, or
          no atoms).
        - ``stable_molecules`` lists the ``(positions, atom_type)`` numpy
          arrays of the stable molecules.
    """
    n_samples = len(molecule_list)
    stable_molecules = []
    n_stable_molecules = 0
    n_stable_atoms = 0
    n_atoms = 0
    for one_hot, x in molecule_list:
        # NOTE(review): argmax yields channel indices 0..C-1, while check_stability
        # decodes them via atom_decoder, whose keys are 1, 6, 7, 8, 9, 15, 16, 17.
        # This only works if the one-hot channel index equals that code — confirm.
        atom_type = one_hot.argmax(2).squeeze(0).cpu().detach().numpy()
        x = x.squeeze(0).cpu().detach().numpy()
        mol_stable, nr_stable, nr_atoms = check_stability(x, atom_type)
        n_stable_molecules += int(mol_stable)
        n_stable_atoms += int(nr_stable)
        n_atoms += int(nr_atoms)
        if mol_stable:
            stable_molecules.append((x, atom_type))

    # Undefined ratios become NaN instead of raising ZeroDivisionError.
    nan = float('nan')
    fraction_mol_stable = n_stable_molecules / float(n_samples) if n_samples else nan
    fraction_atm_stable = n_stable_atoms / float(n_atoms) if n_atoms else nan
    validity_dict = {
        'mol_stable': fraction_mol_stable,
        'atm_stable': fraction_atm_stable,
    }
    return validity_dict, stable_molecules