import os
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import copy
import random
import warnings
import json
import tempfile
import matplotlib.pyplot as plt
from io import StringIO
from collections.abc import Mapping
from dataclasses import dataclass
from random import randint
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from scipy import stats
from Bio import SeqIO
from tqdm import tqdm

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader

import transformers
from transformers import (
    T5Config, T5EncoderModel, T5PreTrainedModel, T5Tokenizer,
    Trainer, TrainingArguments, set_seed,
)
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from datasets import Dataset

import wandb
import argparse
from datetime import datetime

from utils.utils import (
    ClassConfig, ENMAdaptedTrainer, set_seeds, create_dataset, save_finetuned_model,
    DataCollatorForTokenRegression, do_topology_split, update_config, compute_metrics
)
from models.T5_encoder_per_token import PT5_classification_model, T5EncoderForTokenClassification
from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier,
    ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
)
|
|
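# Example invocation (script name, paths, and values are illustrative):
#   python train.py --run_name my_run --adaptor_architecture conv \
#       --data_path data/labels.txt --fasta_path data/sequences.fasta \
#       --enm_path data/enm_fluctuations.jsonl --splits_path data/splits.json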
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model on the CATH dataset')

    # Required arguments
    parser.add_argument('--run_name', type=str, required=True, help='Name of the run.')
    parser.add_argument('--adaptor_architecture', type=str, required=True, help='Which model to use to adapt the ENM values into the sequence latent space.')

    # Data and training arguments
    parser.add_argument('--data_path', type=str, help='Path to the data file.')
    parser.add_argument('--batch_size', type=int, help='Batch size for training.')
    parser.add_argument('--epochs', type=int, help='Number of training epochs.')
    parser.add_argument('--save_steps', type=int, help='After how many training steps to save a checkpoint.')
    parser.add_argument('--add_pearson_loss', action='store_true', help='If provided, a Pearson correlation term is added to the loss function.')
    parser.add_argument('--add_sse_loss', action='store_true', help='If provided, a term encouraging the model to predict the same value along SSE blocks is added to the loss function.')
    parser.add_argument('--fasta_path', type=str, help='Path to the FASTA file with the AA sequences for the dataset.')
    parser.add_argument('--enm_path', type=str, help='Path to the file with precomputed ENM flexibilities.')
    parser.add_argument('--splits_path', type=str, help='Path to the file with the data splits.')

    # ENM adaptor hyperparameters
    parser.add_argument('--enm_embed_dim', type=int, help='Dimension of the ENM embedding / number of conv filters.')
    parser.add_argument('--enm_att_heads', type=int, help='Number of attention heads for the ENM embedding.')
    parser.add_argument('--num_layers', type=int, help='Number of conv layers in the ENM adaptor.')
    parser.add_argument('--kernel_size', type=int, help='Size of the convolutional kernels in the ENM adaptor.')
    parser.add_argument('--mixed_precision', action='store_true', help='Enable mixed precision training.')
    parser.add_argument('--gradient_accumulation_steps', type=int, help='Number of steps to accumulate gradients before performing a backward/update pass.')

    return parser.parse_args()
|
|
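# Turn the raw dataframes into tokenized HuggingFace datasets: labels above 900
# are masked with -100 (the ignore index, so they are excluded from the loss),
# rare or ambiguous residues are mapped to X, and residues are space-separated
# as the T5 tokenizer expects.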
def preprocess_data(tokenizer, train, valid, test):
    train = train[["sequence", "label", "enm_vals"]]
    valid = valid[["sequence", "label", "enm_vals"]]
    test = test[["sequence", "label", "enm_vals"]]

    train.reset_index(drop=True, inplace=True)
    valid.reset_index(drop=True, inplace=True)
    test.reset_index(drop=True, inplace=True)

    train['label'] = train.apply(lambda row: [-100 if x > 900 else x for x in row['label']], axis=1)
    valid['label'] = valid.apply(lambda row: [-100 if x > 900 else x for x in row['label']], axis=1)
    test['label'] = test.apply(lambda row: [-100 if x > 900 else x for x in row['label']], axis=1)

    train["sequence"] = train["sequence"].str.replace('|'.join(["O", "B", "U", "Z", "-"]), "X", regex=True)
    valid["sequence"] = valid["sequence"].str.replace('|'.join(["O", "B", "U", "Z", "-"]), "X", regex=True)

    train['sequence'] = train.apply(lambda row: " ".join(row["sequence"]), axis=1)
    valid['sequence'] = valid.apply(lambda row: " ".join(row["sequence"]), axis=1)

    train_set = create_dataset(tokenizer, list(train['sequence']), list(train['label']), list(train['enm_vals']))
    valid_set = create_dataset(tokenizer, list(valid['sequence']), list(valid['label']), list(valid['enm_vals']))

    return train_set, valid_set, test
|
|
if __name__ == '__main__':
    args = parse_args()

    # Load the training config and override it with any provided CLI arguments.
    with open('configs/train_config.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config = update_config(config, args)

    # Propagate run-level settings into the HuggingFace TrainingArguments.
    config['training_args']['run_name'] = config['run_name']
    config['training_args']['output_dir'] = config['training_args']['output_dir'].format(
        run_name=config['run_name'],
        timestamp=datetime.now().strftime("%Y%m%d_%H%M%S")
    )
    config['training_args']['fp16'] = config['mixed_precision']
    config['training_args']['gradient_accumulation_steps'] = config['gradient_accumulation_steps']
    config['training_args']['num_train_epochs'] = config['epochs']
    config['training_args']['per_device_train_batch_size'] = config['batch_size']
    config['training_args']['per_device_eval_batch_size'] = config['batch_size']
    config['training_args']['eval_steps'] = config['training_args']['save_steps']

    print("Training with the following config:\n", config)
|
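    # Machine-specific settings (HF cache location, visible GPUs, wandb project)
    # are kept in a separate environment config.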
    with open('configs/env_config.yaml', 'r') as f:
        env_config = yaml.load(f, Loader=yaml.FullLoader)

    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']

    wandb.init(project=env_config['wandb']['project'], name=config['run_name'], config=config)
|
    DATA_PATH = config['data_path']
    FASTA_PATH = config['fasta_path']
    ENM_PATH = config['enm_path']
    SPLITS_PATH = config['splits_path']
|
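    # Build one dataframe holding the sequence, per-residue labels, and ENM
    # fluctuations for every structure.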
    sequences, names, labels, enm_vals = [], [], [], []

    with open(FASTA_PATH, "r") as fasta_file:
        for record in SeqIO.parse(fasta_file, "fasta"):
            sequences.append([record.name, str(record.seq)])

    df = pd.DataFrame(sequences, columns=["name", "sequence"])
|
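    # The ENM file is JSON-lines: each line holds a pdb_name and its per-residue
    # fluctuations. A trailing 0.0 is appended to each fluctuation list,
    # presumably to line up with the special token the tokenizer appends.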
    with open(ENM_PATH, "r") as f:
        enm_lines = f.readlines()

    enm_vals_dict = {}
    for l in enm_lines:
        _d = json.loads(l)
        _key = ".".join(_d['pdb_name'].split("_"))
        enm_vals_dict[_key] = _d['fluctuations']
        enm_vals_dict[_key].append(0.0)
|
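    # Each line of the labels file has the form "<name>:\t<v1>, <v2>, ...";
    # the file is assumed to be ordered consistently with the FASTA records,
    # since labels are attached to `df` by position.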
    with open(DATA_PATH, "r") as f:
        lines = f.readlines()

    for l in lines:
        _split_line = l.split(":\t")
        names.append(_split_line[0])
        labels.append([float(label) for label in _split_line[1].split(", ")])
        enm_vals.append(enm_vals_dict[_split_line[0]])

    df["label"] = labels
    df["enm_vals"] = enm_vals
|
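    # PT5_classification_model builds the T5 encoder with the ENM adaptor head
    # selected via --adaptor_architecture (see models.enm_adaptor_heads).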
    set_seeds(config['seed'])

    class_config = ClassConfig(config)
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)
|
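    # Split train/valid/test by topology (do_topology_split) to avoid leaking
    # similar structures across splits.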
    train, valid, test = do_topology_split(df, SPLITS_PATH)
    train_set, valid_set, test = preprocess_data(tokenizer, train, valid, test)
|
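    # ENMAdaptedTrainer extends the HuggingFace Trainer; the optional Pearson
    # and SSE loss terms (--add_pearson_loss / --add_sse_loss) are presumably
    # applied there.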
    training_args = TrainingArguments(**config['training_args'])

    data_collator = DataCollatorForTokenRegression(tokenizer)

    trainer = ENMAdaptedTrainer(
        model,
        training_args,
        train_dataset=train_set,
        eval_dataset=valid_set,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    save_finetuned_model(trainer.model, config['training_args']['output_dir'])