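"""Compare per-residue protein flexibility profiles on the ATLAS dataset.

For each ATLAS entry, this script collects RMSF and B-factor profiles, AlphaFold2 and
ESM pLDDT-based scores (used as 1 - pLDDT), and ProDy ANM/GNM fluctuations, then prints
the Pearson correlation matrix averaged over all entries. With --evaluate_flexpert,
Flexpert-3D and Flexpert-Seq predictions are included and evaluated on the test-set
entries only.
"""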
import json
import os
import pickle

import numpy as np
import pandas as pd
import yaml
from tqdm import tqdm

from data.scripts.extract_rmsf_labels import (
    extract_bfactor_labels,
    extract_plddt_labels,
    extract_rmsf_labels,
)

def get_flucts_from_pickle(f):
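    """Load a dict of per-entry fluctuation arrays from an open pickle file handle."""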
    return pickle.load(f)

def get_flucts_from_jsonl(f):
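    """Parse a JSONL file handle where each line holds a 'pdb_name' and its 'fluctuations'.

    Returns a dict mapping PDB name -> np.ndarray of per-residue fluctuations.
    """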
    _flucts = f.readlines()
    pdb_code_to_fluct_dict = {}
    for line in _flucts:
        json_obj = json.loads(line.strip())
        pdb_code_to_fluct_dict[json_obj['pdb_name']] = np.array(json_obj['fluctuations'])
    return pdb_code_to_fluct_dict

def read_flexpert_predictions(path):
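    """Read Flexpert predictions from a FASTA-like file.

    Entries alternate between a '>name' header line and a line of comma-separated
    per-residue flexibility values. Dots in names are replaced with underscores so
    they match the entry names used as keys elsewhere in this script.

    Returns a dict mapping name -> np.ndarray of float32 values.
    """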
    with open(path, 'r') as f:
        lines = f.readlines()
        pdb_code_to_fluct_dict = {}

        for name_line, fluct_line in zip(lines[::2], lines[1::2]):
            _name = name_line.strip().strip('>')
            _name = _name.replace('.', '_')
            pdb_code_to_fluct_dict[_name] = np.array(fluct_line.strip().split(','), dtype=np.float32)
    return pdb_code_to_fluct_dict

if __name__ == "__main__":

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--evaluate_flexpert', action='store_true',
                        help='Also evaluate Flexpert-3D and Flexpert-Seq predictions on the test set.')
    args = parser.parse_args()

    with open('configs/data_config.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    DATA_DIR = config['precomputed_flexibility_profiles_dir']

    if args.evaluate_flexpert:
        flexpert_3d_predictions_path = config['flexpert_3d_predictions_path']
        flexpert_seq_predictions_path = config['flexpert_seq_predictions_path']
        assert os.path.exists(flexpert_3d_predictions_path), f"Flexpert-3D predictions file does not exist: {flexpert_3d_predictions_path}"
        assert os.path.exists(flexpert_seq_predictions_path), f"Flexpert-Seq predictions file does not exist: {flexpert_seq_predictions_path}"
        flexpert_3d_predictions = read_flexpert_predictions(flexpert_3d_predictions_path)
        flexpert_seq_predictions = read_flexpert_predictions(flexpert_seq_predictions_path)

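    # Load precomputed ANM/GNM squared fluctuations and ESM pLDDT profiles.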
    with open(f'{DATA_DIR}/anm_square_fluctuations.pickle', 'rb') as f:
        anm_sqFlucts = get_flucts_from_pickle(f)

    with open(f'{DATA_DIR}/gnm_square_fluctuations.pickle', 'rb') as f:
        gnm_sqFlucts = get_flucts_from_pickle(f)

    with open(f'{DATA_DIR}/atlas_esm_plddt.jsonl', 'rb') as f:
        esm_plddt = get_flucts_from_jsonl(f)

    atlas_list_path = config['pdb_codes_path']
    atlas_analyses_dir = config['atlas_out_dir']

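    # Templates for the per-entry analysis files in the ATLAS output directory.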
    atlas_bfactor_path = atlas_analyses_dir + "/{}_analysis/{}_Bfactor.tsv"
    atlas_plddt_path = atlas_analyses_dir + "/{}_analysis/{}_pLDDT.tsv"
    atlas_rmsf_path = atlas_analyses_dir + "/{}_analysis/{}_RMSF.tsv"

    with open(atlas_list_path, 'r') as f:
        atlas_list = [a.strip() for a in f.readlines()]

    fluctuations = {}

    if args.evaluate_flexpert:
        print("Filtering to test set only, to evaluate Flexpert-3D and Flexpert-Seq predictions")
        atlas_list = [a for a in atlas_list if a in flexpert_seq_predictions]

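    # Collect all per-residue flexibility profiles for each ATLAS entry into one DataFrame.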
    for key in tqdm(atlas_list):
        fluctuations[key] = pd.DataFrame()
        # ANM/GNM give squared fluctuations; take the square root to get RMSF-like values.
        fluctuations[key]['prody_ANM'] = np.sqrt(anm_sqFlucts.get(key, np.nan))
        fluctuations[key]['prody_GNM'] = np.sqrt(gnm_sqFlucts.get(key, np.nan))
        # pLDDT is a confidence score, so 1 - pLDDT serves as a flexibility proxy.
        fluctuations[key]['esm_plddt'] = 1 - esm_plddt.get(key, np.nan)
        fluctuations[key]['rmsf'] = extract_rmsf_labels(atlas_rmsf_path.format(key, key))[1]
        fluctuations[key]['bfactor'] = extract_bfactor_labels(atlas_bfactor_path.format(key, key))[1]
        fluctuations[key]['af2_plddt'] = 1 - extract_plddt_labels(atlas_plddt_path.format(key, key))[1]
        if args.evaluate_flexpert and key in flexpert_seq_predictions:
            fluctuations[key]['flexpert_3d'] = flexpert_3d_predictions.get(key)
            fluctuations[key]['flexpert_seq'] = flexpert_seq_predictions.get(key)

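    # Compute a per-entry Pearson correlation matrix between all collected measures,
    # skipping entries where any correlation is NaN.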
    pearson_correlations = []

    for pdb_code, df in fluctuations.items():
        cols = ['rmsf', 'bfactor', 'af2_plddt', 'esm_plddt', 'prody_GNM', 'prody_ANM']
        if args.evaluate_flexpert:
            cols.append('flexpert_3d')
            cols.append('flexpert_seq')

        pc = df[cols].corr(method='pearson')
        if np.any(np.isnan(pc)):
            print(f'{pdb_code} has NaN values in Pearson correlation')
            continue
        pearson_correlations.append(pc)

    # Compute the average correlation matrix across all PDB codes.
    columns = ['rmsf', 'bfactor', 'af2_plddt', 'esm_plddt', 'prody_GNM', 'prody_ANM']
    if args.evaluate_flexpert:
        columns.append('flexpert_3d')
        columns.append('flexpert_seq')
    print("Pearson correlations:")
    pearson_mean = np.mean(pearson_correlations, axis=0)
    pearson_mean_rounded = np.round(pearson_mean, 2)
    print(pd.DataFrame(pearson_mean_rounded, index=columns, columns=columns))
    print("\n")