|
|
import pickle |
|
|
import numpy as np |
|
|
import json |
|
|
import pandas as pd |
|
|
from data.scripts.extract_rmsf_labels import extract_rmsf_labels, extract_bfactor_labels, extract_plddt_labels |
|
|
import yaml |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
def get_flucts_from_pickle(f):
    """Deserialize and return the single object stored in a pickle stream.

    Args:
        f: A readable binary file-like object (e.g. opened with ``'rb'``).

    Returns:
        The unpickled object, exactly as it was written.

    Note:
        Unpickling can execute arbitrary code, so only call this on files
        produced by a trusted source.
    """
    unpickler = pickle.Unpickler(f)
    return unpickler.load()
|
|
|
|
|
def get_flucts_from_jsonl(f):
    """Parse a JSON-Lines stream into a mapping of PDB name -> fluctuation array.

    Each non-blank line must be a JSON object with a ``pdb_name`` key and a
    ``fluctuations`` list. Blank or whitespace-only lines (e.g. an empty line
    at the end of the file) are skipped rather than raising
    ``json.JSONDecodeError``.

    Args:
        f: An open file-like object in text or binary mode; ``json.loads``
            accepts both ``str`` and ``bytes`` lines.

    Returns:
        dict mapping each ``pdb_name`` to ``np.array(fluctuations)``. If the
        same ``pdb_name`` appears twice, the later line wins.
    """
    pdb_code_to_fluct_dict = {}
    for raw_line in f:
        line = raw_line.strip()
        if not line:
            # json.loads('') raises, so ignore empty lines instead of crashing.
            continue
        json_obj = json.loads(line)
        pdb_code_to_fluct_dict[json_obj['pdb_name']] = np.array(json_obj['fluctuations'])
    return pdb_code_to_fluct_dict
|
|
|
|
|
def read_flexpert_predictions(path):
    """Read Flexpert predictions from a FASTA-like text file.

    The file alternates a header line (``>name``) with a line of
    comma-separated per-residue values. Any ``.`` in the name is replaced by
    ``_`` (e.g. ``>1abc.A`` -> ``1abc_A``).

    Blank lines are ignored before pairing, so a stray empty line cannot
    shift every following header/value pair out of alignment. If the number
    of non-blank lines is odd, the final unpaired line is ignored.

    Args:
        path: Path to the predictions file.

    Returns:
        dict mapping the normalized name to a ``np.float32`` array of values.
    """
    with open(path, 'r') as f:
        lines = [line.strip() for line in f if line.strip()]

    pdb_code_to_fluct_dict = {}
    for name_line, fluct_line in zip(lines[::2], lines[1::2]):
        name = name_line.strip('>').replace('.', '_')
        pdb_code_to_fluct_dict[name] = np.array(fluct_line.split(','), dtype=np.float32)
    return pdb_code_to_fluct_dict
|
|
|
|
|
if __name__ == "__main__":
    # Build per-protein tables of flexibility profiles, compute a Pearson
    # correlation matrix per protein, and print the element-wise mean matrix.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--evaluate_flexpert', action='store_true', default=False)
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed promptly.
    with open('configs/data_config.yaml', 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    DATA_DIR = config['precomputed_flexibility_profiles_dir']

    if args.evaluate_flexpert:
        flexpert_3d_predictions_path = config['flexpert_3d_predictions_path']
        flexpert_seq_predictions_path = config['flexpert_seq_predictions_path']
        assert os.path.exists(flexpert_3d_predictions_path), f"Flexpert-3D predictions file does not exist: {flexpert_3d_predictions_path}"
        assert os.path.exists(flexpert_seq_predictions_path), f"Flexpert-Seq predictions file does not exist: {flexpert_seq_predictions_path}"
        flexpert_3d_predictions = read_flexpert_predictions(flexpert_3d_predictions_path)
        flexpert_seq_predictions = read_flexpert_predictions(flexpert_seq_predictions_path)

    # NOTE: these pickles are deserialized without validation; they must come
    # from a trusted source.
    with open(f'{DATA_DIR}/anm_square_fluctuations.pickle', 'rb') as f:
        anm_sqFlucts = get_flucts_from_pickle(f)

    with open(f'{DATA_DIR}/gnm_square_fluctuations.pickle', 'rb') as f:
        gnm_sqFlucts = get_flucts_from_pickle(f)

    with open(f'{DATA_DIR}/atlas_esm_plddt.jsonl', 'rb') as f:
        esm_plddt = get_flucts_from_jsonl(f)

    atlas_list_path = config['pdb_codes_path']
    atlas_analyses_dir = config['atlas_out_dir']

    atlas_bfactor_path = atlas_analyses_dir + "/{}_analysis/{}_Bfactor.tsv"
    atlas_plddt_path = atlas_analyses_dir + "/{}_analysis/{}_pLDDT.tsv"
    atlas_rmsf_path = atlas_analyses_dir + "/{}_analysis/{}_RMSF.tsv"

    with open(atlas_list_path, 'r') as f:
        # Skip blank lines so an empty string is never treated as a PDB code.
        atlas_list = [a.strip() for a in f if a.strip()]

    # Single source of truth for the column order, used both for the
    # per-protein correlation and for labelling the printed mean matrix.
    columns = ['rmsf', 'bfactor', 'af2_plddt', 'esm_plddt', 'prody_GNM', 'prody_ANM']
    if args.evaluate_flexpert:
        columns += ['flexpert_3d', 'flexpert_seq']

    fluctuations = {}

    if args.evaluate_flexpert:
        print("Filtering to testset only, to evaluate Flexpert-3D and Flexpert-Seq predictions")
        atlas_list = [a for a in atlas_list if a in flexpert_seq_predictions]

    for key in tqdm(atlas_list):
        df = pd.DataFrame()
        # Square fluctuations -> fluctuations; missing entries become NaN so the
        # protein is later skipped by the NaN check rather than crashing here.
        df['prody_ANM'] = np.sqrt(anm_sqFlucts.get(key, np.nan))
        df['prody_GNM'] = np.sqrt(gnm_sqFlucts.get(key, np.nan))
        # pLDDT is inverted (1 - x) so that it increases with flexibility.
        df['esm_plddt'] = 1 - esm_plddt.get(key, np.nan)
        df['rmsf'] = extract_rmsf_labels(atlas_rmsf_path.format(key, key))[1]
        df['bfactor'] = extract_bfactor_labels(atlas_bfactor_path.format(key, key))[1]
        df['af2_plddt'] = 1 - extract_plddt_labels(atlas_plddt_path.format(key, key))[1]
        if args.evaluate_flexpert and key in flexpert_seq_predictions:
            # Default to NaN (like the other sources) instead of None.
            df['flexpert_3d'] = flexpert_3d_predictions.get(key, np.nan)
            df['flexpert_seq'] = flexpert_seq_predictions.get(key, np.nan)
        fluctuations[key] = df

    pearson_correlations = []
    for pdb_code, df in fluctuations.items():
        pc = df[columns].corr(method='pearson')
        if np.any(np.isnan(pc)):
            print(f'{pdb_code} has NaN values in Pearson correlation')
            continue
        pearson_correlations.append(pc)

    print("Pearson correlations:")
    if not pearson_correlations:
        # Avoid averaging an empty list (which only yields a NaN warning).
        print("No proteins with a NaN-free Pearson correlation matrix.")
    else:
        pearson_mean = np.mean(pearson_correlations, axis=0)
        pearson_mean_rounded = np.round(pearson_mean, 2)
        print(pd.DataFrame(pearson_mean_rounded, index=columns, columns=columns))
    print("\n")
|
|
|