# flexpert / get_correlation_analysis.py
# Uploaded by Honzus24 — "initial commit" (7968cb0)
import pickle
import numpy as np
import json
import pandas as pd
from data.scripts.extract_rmsf_labels import extract_rmsf_labels, extract_bfactor_labels, extract_plddt_labels
import yaml
from tqdm import tqdm
import os
def get_flucts_from_pickle(f):
    """Deserialize and return the object pickled in the open binary file *f*.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only call this on trusted, locally produced files.
    """
    loaded = pickle.load(f)
    return loaded
def get_flucts_from_jsonl(f):
    """Parse a JSON-Lines stream of per-protein profiles into a dict.

    Each non-blank line must be a JSON object with a ``'pdb_name'`` key and a
    ``'fluctuations'`` list; the list is converted to a NumPy array.

    Args:
        f: an open file object (text or binary mode) whose lines are JSON
            objects. ``json.loads`` accepts both ``str`` and ``bytes`` lines.

    Returns:
        dict mapping ``pdb_name`` to an ``np.ndarray`` of fluctuations. If a
        name appears more than once, the last occurrence wins.
    """
    pdb_code_to_fluct_dict = {}
    for line in f:
        line = line.strip()
        # Skip blank lines (e.g. a trailing empty line); json.loads would
        # otherwise raise on an empty string.
        if not line:
            continue
        json_obj = json.loads(line)
        pdb_code_to_fluct_dict[json_obj['pdb_name']] = np.array(json_obj['fluctuations'])
    return pdb_code_to_fluct_dict
def read_flexpert_predictions(path):
    """Read Flexpert predictions from a FASTA-like text file.

    The file alternates header and value lines, e.g.::

        >1abc.A
        0.12,0.34,0.56

    Header lines have surrounding whitespace and ``'>'`` characters stripped,
    and every ``'.'`` is replaced by ``'_'`` (``1abc.A`` -> ``1abc_A``).
    Value lines are comma-separated numbers parsed as float32. Blank lines
    are ignored.

    Args:
        path: path to the predictions file.

    Returns:
        dict mapping the normalised names to float32 ``np.ndarray`` values.

    Raises:
        ValueError: if, after dropping blank lines, the number of lines is
            odd (a header without a value line). Previously the trailing
            header was dropped silently.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # Drop blank lines so a stray empty line cannot shift the
        # header/value pairing for every entry that follows it.
        lines = [line for line in f if line.strip()]
    if len(lines) % 2:
        raise ValueError(
            f"{path}: expected alternating header/value lines, but found an odd "
            f"number of non-blank lines ({len(lines)})"
        )
    pdb_code_to_fluct_dict = {}
    for name_line, fluct_line in zip(lines[::2], lines[1::2]):
        name = name_line.strip().strip('>').replace('.', '_')
        pdb_code_to_fluct_dict[name] = np.array(fluct_line.strip().split(','), dtype=np.float32)
    return pdb_code_to_fluct_dict
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--evaluate_flexpert', action='store_true', default=False)
args = parser.parse_args()
config = yaml.load(open('configs/data_config.yaml', 'r'), Loader=yaml.FullLoader)
DATA_DIR = config['precomputed_flexibility_profiles_dir']
if args.evaluate_flexpert:
flexpert_3d_predictions_path = config['flexpert_3d_predictions_path']
flexpert_seq_predictions_path = config['flexpert_seq_predictions_path']
assert os.path.exists(flexpert_3d_predictions_path), f"Flexpert-3D predictions file does not exist: {flexpert_3d_predictions_path}"
assert os.path.exists(flexpert_seq_predictions_path), f"Flexpert-Seq predictions file does not exist: {flexpert_seq_predictions_path}"
flexpert_3d_predictions = read_flexpert_predictions(flexpert_3d_predictions_path)
flexpert_seq_predictions = read_flexpert_predictions(flexpert_seq_predictions_path)
with open(f'{DATA_DIR}/anm_square_fluctuations.pickle','rb') as f:
anm_sqFlucts = get_flucts_from_pickle(f)
with open(f'{DATA_DIR}/gnm_square_fluctuations.pickle','rb') as f:
gnm_sqFlucts = get_flucts_from_pickle(f)
with open(f'{DATA_DIR}/atlas_esm_plddt.jsonl','rb') as f:
esm_plddt = get_flucts_from_jsonl(f)
atlas_list_path = config['pdb_codes_path']
atlas_analyses_dir = config['atlas_out_dir']
atlas_bfactor_path = atlas_analyses_dir + "/{}_analysis/{}_Bfactor.tsv"
atlas_plddt_path = atlas_analyses_dir + "/{}_analysis/{}_pLDDT.tsv"
atlas_rmsf_path = atlas_analyses_dir + "/{}_analysis/{}_RMSF.tsv"
with open(atlas_list_path,'r') as f:
atlas_list = f.readlines()
atlas_list = [a.strip() for a in atlas_list]
fluctuations = {}
if args.evaluate_flexpert:
print("Filtering to testset only, to evaluate Flexpert-3D and Flexpert-Seq predictions")
atlas_list = [a for a in atlas_list if a in flexpert_seq_predictions.keys()]
for key in tqdm(atlas_list):
fluctuations[key] = pd.DataFrame()
fluctuations[key]['prody_ANM'] = np.sqrt(anm_sqFlucts.get(key, np.nan))
fluctuations[key]['prody_GNM'] = np.sqrt(gnm_sqFlucts.get(key, np.nan))
fluctuations[key]['esm_plddt'] = 1 - esm_plddt.get(key, np.nan)
fluctuations[key]['rmsf'] = extract_rmsf_labels(atlas_rmsf_path.format(key, key))[1]
fluctuations[key]['bfactor'] = extract_bfactor_labels(atlas_bfactor_path.format(key, key))[1]
fluctuations[key]['af2_plddt'] = 1 - extract_plddt_labels(atlas_plddt_path.format(key, key))[1]
if args.evaluate_flexpert and key in flexpert_seq_predictions.keys():
fluctuations[key]['flexpert_3d'] = flexpert_3d_predictions.get(key)
fluctuations[key]['flexpert_seq'] = flexpert_seq_predictions.get(key)
pearson_correlations = []
for pdb_code,df in fluctuations.items():
cols = ['rmsf', 'bfactor', 'af2_plddt', 'esm_plddt', 'prody_GNM', 'prody_ANM']
if args.evaluate_flexpert:
cols.append('flexpert_3d')
cols.append('flexpert_seq')
pc = df[cols].corr(method='pearson')
if np.any(np.isnan(pc)):
print(f'{pdb_code} has NaN values in Pearson correlation')
continue
pearson_correlations.append(pc)
#compute average across all pdb codes
columns = ['rmsf', 'bfactor', 'af2_plddt', 'esm_plddt', 'prody_GNM', 'prody_ANM']
if args.evaluate_flexpert:
columns.append('flexpert_3d')
columns.append('flexpert_seq')
print("Pearson correlations:")
pearson_mean = np.mean(pearson_correlations, axis=0)
pearson_mean_rounded = np.round(pearson_mean, 2)
print(pd.DataFrame(pearson_mean_rounded, index=columns, columns=columns))
print("\n")