# File size: 9,389 Bytes
# 7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#import dependencies
import os
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import copy
import random
import warnings
import json
import tempfile
import matplotlib.pyplot as plt
from io import StringIO
from collections.abc import Mapping
from dataclasses import dataclass
from random import randint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from scipy import stats
from Bio import SeqIO
from tqdm import tqdm

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader

import transformers
from transformers import T5EncoderModel, T5Tokenizer, TrainingArguments, Trainer, set_seed
from transformers.modeling_outputs import TokenClassifierOutput
from transformers import T5Config, T5PreTrainedModel
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from datasets import Dataset

import wandb
import argparse
from datetime import datetime

# from utils.lora_utils import LoRAConfig, modify_with_lora
from utils.utils import (
    ClassConfig, ENMAdaptedTrainer, set_seeds, create_dataset, save_finetuned_model, 
    DataCollatorForTokenRegression, do_topology_split, update_config, compute_metrics
)
from models.T5_encoder_per_token import PT5_classification_model, T5EncoderForTokenClassification
from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier, 
    ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
)

def parse_args():
    """Parse the command-line options for a single training run.

    Only ``--run_name`` and ``--adaptor_architecture`` are mandatory. Every
    other value option defaults to ``None`` and every flag to ``False``
    (presumably so ``update_config`` can fall back to the YAML config for
    anything not given — verify against utils.utils).

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    cli = argparse.ArgumentParser(description='Train a model on the CATH dataset')

    # Mandatory options: (flag, help text); all take a string value.
    mandatory_specs = (
        ('--run_name', 'Name of the run.'),
        ('--adaptor_architecture', 'What model to use to adapt the ENM values into the sequence latent space.'),
    )
    for flag, text in mandatory_specs:
        cli.add_argument(flag, type=str, required=True, help=text)

    # Optional options: (flag, value type, help text). A value type of None
    # marks a boolean switch (store_true). Registration order is preserved so
    # the generated --help output stays the same.
    optional_specs = (
        ('--data_path', str, 'Path to the data file'),
        ('--batch_size', int, 'Size of the batch for training.'),
        ('--epochs', int, 'Number of epochs for training.'),
        ('--save_steps', int, 'After how many training steps to save the checkpoint.'),
        ('--add_pearson_loss', None, 'If provided, Pearson correlation term will be added to the loss function.'),
        ('--add_sse_loss', None, 'If provided, term forcing the model to predict same values along sse blocks will be added to the loss function.'),
        ('--fasta_path', str, 'Path to the FASTA file with the AA sequences for the dataset.'),
        ('--enm_path', str, 'Path to the enm file with precomputed flexibilities (ENM).'),
        ('--splits_path', str, 'Path to the file with the data splits.'),
        # ENM adaptor hyper-parameters and training-loop switches
        ('--enm_embed_dim', int, 'Dimension of the ENM embedding / number of conv filters.'),
        ('--enm_att_heads', int, 'Number of attention heads for the ENM embedding.'),
        ('--num_layers', int, 'Number of conv layers in the ENM adaptor.'),
        ('--kernel_size', int, 'Size of the convolutional kernels in the ENM adaptor.'),
        ('--mixed_precision', None, 'Enable mixed precision training.'),
        ('--gradient_accumulation_steps', int, 'Number of steps to accumulate gradients before performing a backward/update pass.'),
    )
    for flag, value_type, text in optional_specs:
        if value_type is None:
            cli.add_argument(flag, action='store_true', help=text)
        else:
            cli.add_argument(flag, type=value_type, help=text)

    return cli.parse_args()

def _select_and_mask_labels(frame):
    """Return a re-indexed copy of *frame* with only the model columns and masked labels.

    Label values above 900 are invalid and are replaced with -100, which marks
    the position to be ignored by the loss.
    """
    # `.copy()` makes the column subset an independent frame, so the column
    # assignment below does not trigger pandas' SettingWithCopyWarning.
    cleaned = frame[["sequence", "label", "enm_vals"]].copy()
    cleaned.reset_index(drop=True, inplace=True)
    # Series.map (instead of DataFrame.apply(axis=1)) also works for an empty
    # split, where apply would return a DataFrame that cannot be assigned to a
    # single column.
    cleaned["label"] = cleaned["label"].map(
        lambda values: [-100 if x > 900 else x for x in values]
    )
    return cleaned


def _space_separate_residues(sequences):
    """Map the uncommon residues O/B/U/Z and gaps "-" to "X", then put a space between residues.

    The space-separated form is the input format PT5 needs to tokenize each
    amino acid individually.
    """
    normalized = sequences.str.replace('|'.join(["O", "B", "U", "Z", "-"]), "X", regex=True)
    return normalized.map(" ".join)


def preprocess_data(tokenizer, train, valid, test):
    """Clean the three data splits and turn train/valid into model datasets.

    Every split is reduced to the ``sequence``, ``label`` and ``enm_vals``
    columns with a fresh 0..n-1 index, and label values above 900 are replaced
    by -100 (ignored by the loss). For train and valid only, sequences are then
    normalized (O/B/U/Z/"-" -> "X") and space-separated, and passed with their
    labels and ENM values to ``create_dataset``.

    Args:
        tokenizer: tokenizer forwarded to ``create_dataset``.
        train, valid, test: DataFrames with at least the columns ``sequence``
            (str), ``label`` (iterable of numbers) and ``enm_vals``. The
            caller's frames are not modified.

    Returns:
        ``(train_set, valid_set, test)``: the objects returned by
        ``create_dataset`` for train and valid, and the cleaned test DataFrame.
    """
    train = _select_and_mask_labels(train)
    valid = _select_and_mask_labels(valid)
    test = _select_and_mask_labels(test)

    # NOTE(review): the test split is returned with labels masked but with raw,
    # un-spaced sequences (it is not tokenized here). Any caller that tokenizes
    # it must apply the same residue formatting as train/valid.
    train["sequence"] = _space_separate_residues(train["sequence"])
    valid["sequence"] = _space_separate_residues(valid["sequence"])

    # Create Datasets
    train_set = create_dataset(tokenizer, list(train['sequence']), list(train['label']), list(train['enm_vals']))
    valid_set = create_dataset(tokenizer, list(valid['sequence']), list(valid['label']), list(valid['enm_vals']))

    return train_set, valid_set, test

if __name__ == '__main__':
    ### Read and update config
    args = parse_args()
    # `with` guarantees the config file handle is closed after parsing.
    with open('configs/train_config.yaml', 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    # Merge command-line values into the YAML config (precedence is defined
    # by utils.utils.update_config).
    config = update_config(config, args)

    # Mirror run-level settings into the kwargs later passed to TrainingArguments.
    training_args_cfg = config['training_args']
    training_args_cfg['run_name'] = config['run_name']
    # `output_dir` is a str.format template; {run_name} and {timestamp} are filled in here.
    training_args_cfg['output_dir'] = training_args_cfg['output_dir'].format(
        run_name=config['run_name'],
        timestamp=datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    training_args_cfg['fp16'] = config['mixed_precision']
    training_args_cfg['gradient_accumulation_steps'] = config['gradient_accumulation_steps']
    training_args_cfg['num_train_epochs'] = config['epochs']
    training_args_cfg['per_device_train_batch_size'] = config['batch_size']
    training_args_cfg['per_device_eval_batch_size'] = config['batch_size']
    # Evaluate at the same cadence as checkpoints are saved.
    training_args_cfg['eval_steps'] = training_args_cfg['save_steps']

    print("Training with the following config: \n", config)

    with open('configs/env_config.yaml', 'r') as env_file:
        env_config = yaml.load(env_file, Loader=yaml.FullLoader)

    ### Set environment variables
    # NOTE(review): transformers / huggingface_hub are imported at the top of
    # this file and typically resolve their cache paths from HF_HOME at import
    # time, so setting it here may not redirect the cache. Consider exporting
    # HF_HOME before launching the script instead.
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Only takes effect if CUDA has not been initialized yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']

    ### Initialize wandb
    wandb.init(project=env_config['wandb']['project'], name=config['run_name'], config=config)

    ### Load data - into dataframe
    DATA_PATH = config['data_path']
    FASTA_PATH = config['fasta_path']
    ENM_PATH = config['enm_path']
    SPLITS_PATH = config['splits_path']

    # One (name, sequence) row per FASTA record, in file order.
    with open(FASTA_PATH, "r") as fasta_file:
        fasta_records = [[record.name, str(record.seq)] for record in SeqIO.parse(fasta_file, "fasta")]
    df = pd.DataFrame(fasta_records, columns=["name", "sequence"])

    # ENM file: one JSON object per line with 'pdb_name' and 'fluctuations'.
    # '_' in 'pdb_name' is rewritten to '.' so keys match the names used in
    # the label file (looked up below).
    enm_vals_dict = {}
    with open(ENM_PATH, "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            entry = json.loads(line)
            key = entry['pdb_name'].replace("_", ".")
            # One extra 0.0 is appended to every profile; presumably a
            # placeholder for the tokenizer's end-of-sequence token — TODO confirm.
            enm_vals_dict[key] = entry['fluctuations'] + [0.0]

    # Label file: one "<name>:\t<v1>, <v2>, ..." line per protein.
    names, labels, enm_vals = [], [], []
    with open(DATA_PATH, "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            parts = line.split(":\t")
            names.append(parts[0])
            labels.append([float(label) for label in parts[1].split(", ")])
            enm_vals.append(enm_vals_dict[parts[0]])

    # Labels and ENM values are attached to the FASTA-ordered frame by
    # position, so both files must list the proteins in the same order.
    if names != df["name"].tolist():
        warnings.warn(
            "Protein order in the label file does not match the FASTA file; "
            "labels and ENM values may be attached to the wrong sequences.",
            RuntimeWarning,
        )

    # Add label and enm_vals columns
    df["label"] = labels
    df["enm_vals"] = enm_vals

    ### Set all random seeds
    set_seeds(config['seed'])

    ### Load model
    class_config = ClassConfig(config)
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)

    ### Split data into train, valid, test and preprocess
    train, valid, test = do_topology_split(df, SPLITS_PATH)
    train_set, valid_set, test = preprocess_data(tokenizer, train, valid, test)

    ### Set training arguments
    training_args = TrainingArguments(**config['training_args'])

    ### For token classification (regression) we need a data collator here to pad correctly
    data_collator = DataCollatorForTokenRegression(tokenizer)

    ### Trainer
    trainer = ENMAdaptedTrainer(
        model,
        training_args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    ### Train model and save
    trainer.train()
    save_finetuned_model(trainer.model, config['training_args']['output_dir'])