# Source snapshot: revision 7968cb0 (file size: 9,389 bytes).
#import dependencies
import os
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import copy
import random
import warnings
import json
import tempfile
import matplotlib.pyplot as plt
from io import StringIO
from collections.abc import Mapping
from dataclasses import dataclass
from random import randint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from scipy import stats
from Bio import SeqIO
from tqdm import tqdm
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader
import transformers
from transformers import T5EncoderModel, T5Tokenizer, TrainingArguments, Trainer, set_seed
from transformers.modeling_outputs import TokenClassifierOutput
from transformers import T5Config, T5PreTrainedModel
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from datasets import Dataset
import wandb
import argparse
from datetime import datetime
# from utils.lora_utils import LoRAConfig, modify_with_lora
from utils.utils import (
ClassConfig, ENMAdaptedTrainer, set_seeds, create_dataset, save_finetuned_model,
DataCollatorForTokenRegression, do_topology_split, update_config, compute_metrics
)
from models.T5_encoder_per_token import PT5_classification_model, T5EncoderForTokenClassification
from models.enm_adaptor_heads import (
ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier,
ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
)
def parse_args():
    """Define the training command-line interface and parse ``sys.argv``.

    Only ``--run_name`` and ``--adaptor_architecture`` are mandatory. Every other
    valued option defaults to ``None``, and every switch (``store_true``)
    defaults to ``False``.

    Returns:
        argparse.Namespace holding one attribute per option below.
    """
    cli = argparse.ArgumentParser(description='Train a model on the CATH dataset')

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = (
        # Required arguments
        ('--run_name', {'type': str, 'required': True, 'help': 'Name of the run.'}),
        ('--adaptor_architecture', {'type': str, 'required': True, 'help': 'What model to use to adapt the ENM values into the sequence latent space.'}),
        # Optional arguments
        ('--data_path', {'type': str, 'help': 'Path to the data file'}),
        ('--batch_size', {'type': int, 'help': 'Size of the batch for training.'}),
        ('--epochs', {'type': int, 'help': 'Number of epochs for training.'}),
        ('--save_steps', {'type': int, 'help': 'After how many training steps to save the checkpoint.'}),
        ('--add_pearson_loss', {'action': 'store_true', 'help': 'If provided, Pearson correlation term will be added to the loss function.'}),
        ('--add_sse_loss', {'action': 'store_true', 'help': 'If provided, term forcing the model to predict same values along sse blocks will be added to the loss function.'}),
        ('--fasta_path', {'type': str, 'help': 'Path to the FASTA file with the AA sequences for the dataset.'}),
        ('--enm_path', {'type': str, 'help': 'Path to the enm file with precomputed flexibilities (ENM).'}),
        ('--splits_path', {'type': str, 'help': 'Path to the file with the data splits.'}),
        # Optional ENM adaptor arguments
        ('--enm_embed_dim', {'type': int, 'help': 'Dimension of the ENM embedding / number of conv filters.'}),
        ('--enm_att_heads', {'type': int, 'help': 'Number of attention heads for the ENM embedding.'}),
        ('--num_layers', {'type': int, 'help': 'Number of conv layers in the ENM adaptor.'}),
        ('--kernel_size', {'type': int, 'help': 'Size of the convolutional kernels in the ENM adaptor.'}),
        ('--mixed_precision', {'action': 'store_true', 'help': 'Enable mixed precision training.'}),
        ('--gradient_accumulation_steps', {'type': int, 'help': 'Number of steps to accumulate gradients before performing a backward/update pass.'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
def _mask_invalid_labels(labels):
    """Return a copy of ``labels`` with every value above 900 replaced by -100.

    -100 is PyTorch's default ``ignore_index`` for losses such as
    ``CrossEntropyLoss``. NOTE(review): the custom regression loss used by the
    trainer is assumed to skip the same sentinel -- confirm.
    """
    return [-100 if x > 900 else x for x in labels]


def _clean_split(df, format_sequences):
    """Return a cleaned, independent copy of one data split.

    Keeps only the ``sequence``, ``label`` and ``enm_vals`` columns, resets the
    index and masks invalid labels via ``_mask_invalid_labels``. When
    ``format_sequences`` is true, the residues O, B, U, Z and the gap '-' are
    replaced by "X" and residues are separated by single spaces (the format
    PT5 needs to tokenize them correctly).
    """
    # Work on an explicit copy: the column assignments below then neither
    # trigger SettingWithCopyWarning nor risk touching the caller's frame.
    out = df[["sequence", "label", "enm_vals"]].copy().reset_index(drop=True)
    # Series.map works on an empty split; DataFrame.apply(axis=1) on an empty
    # frame returns a DataFrame, and assigning that to one column raises.
    out['label'] = out['label'].map(_mask_invalid_labels)
    if format_sequences:
        out['sequence'] = (
            out['sequence']
            .str.replace('|'.join(["O", "B", "U", "Z", "-"]), "X", regex=True)
            .map(" ".join)
        )
    return out


def preprocess_data(tokenizer, train, valid, test):
    """Clean the train/valid/test splits and build datasets for train and valid.

    Args:
        tokenizer: Tokenizer forwarded to ``create_dataset``.
        train, valid, test: DataFrames with at least ``sequence``, ``label``
            (an iterable of numbers per row) and ``enm_vals`` columns. They are
            not modified.

    Returns:
        ``(train_set, valid_set, test)``. The first two are whatever
        ``create_dataset`` returns. ``test`` is a cleaned DataFrame: only the
        three columns above are kept, the index is reset and labels above 900
        are masked to -100.
    """
    train = _clean_split(train, format_sequences=True)
    valid = _clean_split(valid, format_sequences=True)
    # NOTE(review): as before, test sequences are neither X-substituted nor
    # space-separated; confirm this is what downstream evaluation expects.
    test = _clean_split(test, format_sequences=False)

    train_set = create_dataset(tokenizer, list(train['sequence']), list(train['label']), list(train['enm_vals']))
    valid_set = create_dataset(tokenizer, list(valid['sequence']), list(valid['label']), list(valid['enm_vals']))
    return train_set, valid_set, test
def _load_yaml(path):
    """Parse the YAML file at ``path`` and return its content, closing the handle."""
    with open(path, 'r') as f:
        # NOTE(review): FullLoader kept to match previous behaviour; prefer
        # yaml.safe_load if these configs never rely on Python-specific tags.
        return yaml.load(f, Loader=yaml.FullLoader)


def _read_enm_values(enm_path):
    """Read per-residue ENM fluctuations from a JSON-lines file.

    Each non-blank line must be a JSON object with ``pdb_name`` and
    ``fluctuations`` keys. The returned dict is keyed by ``pdb_name`` with
    every ``_`` replaced by ``.``; these keys are looked up with the names from
    the label file. Each value gets one extra trailing ``0.0``. Presumably this
    pads for the tokenizer's end-of-sequence token (TODO confirm).
    """
    enm_vals_dict = {}
    with open(enm_path, "r") as f:
        for line in f:
            if not line.strip():  # tolerate blank / trailing empty lines
                continue
            record = json.loads(line)
            key = ".".join(record['pdb_name'].split("_"))
            enm_vals_dict[key] = list(record['fluctuations']) + [0.0]
    return enm_vals_dict


def _read_labels(data_path, enm_vals_dict):
    """Parse the label file and attach the matching ENM values.

    Each non-blank line has the form ``<name>:<TAB><v1>, <v2>, ...``.

    Returns:
        ``(names, labels, enm_vals)`` as parallel lists in file order.

    Raises:
        KeyError: if a name in the label file has no entry in ``enm_vals_dict``.
    """
    names, labels, enm_vals = [], [], []
    with open(data_path, "r") as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.split(":\t")
            name = fields[0]
            if name not in enm_vals_dict:
                raise KeyError(f"No ENM values found for {name!r}")
            names.append(name)
            labels.append([float(value) for value in fields[1].split(", ")])
            enm_vals.append(enm_vals_dict[name])
    return names, labels, enm_vals


def _load_dataframe(fasta_path, enm_path, data_path):
    """Build a DataFrame with ``name``, ``sequence``, ``label`` and ``enm_vals`` columns.

    Sequences come from the FASTA file. Labels and ENM values come from the
    label and ENM files. Rows are joined by position, so the FASTA and label
    files must list proteins in the same order.
    """
    with open(fasta_path, "r") as fasta_file:
        records = [[record.name, str(record.seq)] for record in SeqIO.parse(fasta_file, "fasta")]
    df = pd.DataFrame(records, columns=["name", "sequence"])

    names, labels, enm_vals = _read_labels(data_path, _read_enm_values(enm_path))
    # Rows are matched purely by position; make a silent misalignment visible.
    if list(df["name"]) != names:
        warnings.warn(
            "FASTA record names and label-file names differ in content or order; "
            "rows are matched by position.",
            stacklevel=2,
        )
    df["label"] = labels
    df["enm_vals"] = enm_vals
    return df


if __name__ == '__main__':
    ### Read and update config
    args = parse_args()
    config = update_config(_load_yaml('configs/train_config.yaml'), args)

    # Mirror the top-level run settings into the TrainingArguments kwargs
    # (training_cfg aliases config['training_args'], so config is updated too).
    training_cfg = config['training_args']
    training_cfg['run_name'] = config['run_name']
    training_cfg['output_dir'] = training_cfg['output_dir'].format(
        run_name=config['run_name'],
        timestamp=datetime.now().strftime("%Y%m%d_%H%M%S"),
    )
    training_cfg['fp16'] = config['mixed_precision']
    training_cfg['gradient_accumulation_steps'] = config['gradient_accumulation_steps']
    training_cfg['num_train_epochs'] = config['epochs']
    training_cfg['per_device_train_batch_size'] = config['batch_size']
    training_cfg['per_device_eval_batch_size'] = config['batch_size']
    # Evaluate at the same cadence as checkpointing.
    training_cfg['eval_steps'] = training_cfg['save_steps']
    print("Training with the following config: \n", config)

    env_config = _load_yaml('configs/env_config.yaml')

    ### Set environment variables
    # NOTE(review): transformers / huggingface_hub are imported at the top of
    # this file and may already have resolved their cache paths from HF_HOME,
    # in which case this assignment has no effect. Consider exporting HF_HOME
    # before those imports (or in the launching shell).
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    # Only takes effect if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']

    ### Initialize wandb
    wandb.init(project=env_config['wandb']['project'], name=config['run_name'], config=config)

    ### Load data into a dataframe
    df = _load_dataframe(config['fasta_path'], config['enm_path'], config['data_path'])

    ### Set all random seeds
    set_seeds(config['seed'])

    ### Load model
    class_config = ClassConfig(config)
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)

    ### Split data into train, valid, test and preprocess
    train, valid, test = do_topology_split(df, config['splits_path'])
    train_set, valid_set, test = preprocess_data(tokenizer, train, valid, test)

    ### Set training arguments
    training_args = TrainingArguments(**training_cfg)

    ### For token classification (regression) we need a data collator here to pad correctly
    data_collator = DataCollatorForTokenRegression(tokenizer)

    ### Trainer
    trainer = ENMAdaptedTrainer(
        model,
        training_args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    ### Train model and save
    trainer.train()
    save_finetuned_model(trainer.model, training_cfg['output_dir'])