import argparse
import json
import os
import warnings
from datetime import datetime
from pathlib import Path

import torch
import yaml
from Bio import SeqIO

from data.scripts.data_utils import parse_PDB, modify_bfactor_biotite
from data.scripts.get_enm_fluctuations_for_dataset import get_fluctuation_for_json_dict
from models.T5_encoder_per_token import PT5_classification_model
from utils.utils import ClassConfig, DataCollatorForTokenRegression, process_in_batches_and_combine, get_dot_separated_name
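
# Example invocation (illustrative; the script name and input path are placeholders,
# the flags are the ones defined below):
#   python flexpert_inference.py --input_file 1BUI_C.pdb --modality 3D --output_name demo --output_enm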

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file (.fasta, .pdb, .jsonl or .pdb_list).")
    parser.add_argument("--modality", type=str, required=True, help="Indicate 'Seq' or '3D' to use Flexpert-Seq or Flexpert-3D.")
    parser.add_argument("--splits_file", type=str, required=False, help="Path to the file defining the splits, in case input_file is a dataset that should be subsampled.")
    parser.add_argument("--split", type=str, required=False, help="Specify test/train/val to select the respective split. If specified, the splits file needs to be provided as well.")
    parser.add_argument("--output_enm", action='store_true', help="If set, the ENM values will be written to (a) separate file(s).")
    parser.add_argument("--output_fasta", action='store_true', help="If set, the sequences used for the prediction will be written to a fasta file (relevant when working with an input list of PDB files).")
    parser.add_argument("--output_name", type=str, required=False, help="Name of the output file.")
    args = parser.parse_args()

    args.modality = args.modality.upper()
    filename, suffix = os.path.splitext(args.input_file)

    if args.modality not in ["SEQ", "3D"]:
        raise ValueError("Modality must be either 'Seq' or '3D'.")
    if args.splits_file is not None and args.split is None:
        raise ValueError("If splits_file is provided, split must be specified.")
    if args.split is not None and args.splits_file is None:
        raise ValueError("If split is specified, splits_file must be provided.")
    if args.split is not None and args.split not in ["test", "train", "val", "validation"]:
        raise ValueError("Split must be either 'test', 'train', 'val' or 'validation'.")
    if args.output_enm and args.modality != "3D":
        raise ValueError("Outputting ENM values is only supported for the 3D modality.")
    if not args.output_name:
        default_name = 'untitled_{}'.format(datetime.now().strftime('%Y%m%d_%H%M%S'))
        args.output_name = default_name
        warnings.warn("Output name is not provided, using default name: {}".format(default_name))

    if args.splits_file is not None:
        with open(args.splits_file, 'r') as f:
            splits = json.load(f)
        # Reconcile the 'val'/'validation' naming between the CLI argument and the splits file.
        if 'val' in splits.keys() and args.split == 'validation':
            args.split = 'val'
        elif 'validation' in splits.keys() and args.split == 'val':
            args.split = 'validation'
        datapoint_for_eval = splits[args.split]
    else:
        datapoint_for_eval = 'all'
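
    # The splits file is assumed to be a JSON object mapping split names to lists of
    # dot-separated 'PDBID.CHAIN' identifiers, e.g.:
    #   {"train": ["1ABC.A", ...], "val": [...], "test": ["1BUI.C", ...]}
    # (inferred from how splits[args.split] is matched against names below).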

    sequences = []
    names = []
    backbones = []
    pdb_files = []
    flucts_list = []

    def process_pdb_file(pdb_file, backbones, sequences, names):
        parsed_name = os.path.splitext(os.path.basename(pdb_file))[0].split('_')
        # Check the 'name_chain' pattern before indexing, to avoid an IndexError on malformed names.
        if len(parsed_name) < 2 or len(parsed_name[0]) != 4 or len(parsed_name[1]) != 1 or not parsed_name[1].isalpha():
            raise ValueError("PDB file name is expected to be in the format 'name_chain.pdb', e.g.: 1BUI_C.pdb")
        _name = parsed_name[0]
        _chain = parsed_name[1]
        parsed_pdb = parse_PDB(pdb_file, name=_name, input_chain_list=[_chain])[0]
        backbone, sequence = parsed_pdb['coords_chain_{}'.format(_chain)], parsed_pdb['seq_chain_{}'.format(_chain)]
        if len(sequence) > 1023:
            print("Sequence length is greater than 1023, skipping {}".format(_name + "." + _chain))
        else:
            backbones.append(backbone)
            sequences.append(sequence)
            names.append(_name + "." + _chain)
        return backbones, sequences, names
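
    # Usage sketch (path is illustrative): the helper accumulates into the shared lists,
    #   backbones, sequences, names = process_pdb_file('pdbs/1BUI_C.pdb', backbones, sequences, names)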

    if suffix == ".fasta":
        if args.modality == "3D":
            raise ValueError("Flexpert-3D needs the structure, a fasta file is not enough.")

        for record in SeqIO.parse(args.input_file, "fasta"):
            # Normalize record names to the dot-separated 'PDBID.CHAIN' convention.
            if '_' in record.name:
                dot_separated_name = '.'.join(record.name.split('_'))
            elif '.' in record.name:
                dot_separated_name = record.name
            else:
                raise ValueError("Sequence name must contain either an underscore or a dot to separate the PDB code and the chain code.")
            if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                names.append(dot_separated_name)
                sequences.append(str(record.seq))
                backbones.append(None)

    elif suffix == ".pdb":
        backbones, sequences, names = process_pdb_file(args.input_file, backbones, sequences, names)
        pdb_files.append(args.input_file)

    elif suffix == ".jsonl":
        for line in open(args.input_file, 'r'):
            _dict = json.loads(line)

            if 'fluctuations' in _dict.keys():
                # Fluctuations are precomputed, so use them instead of rerunning the ENM.
                # Note: this assumes that within one .jsonl either all or none of the
                # selected records carry precomputed fluctuations, so that flucts_list
                # stays aligned with the global datapoint index used below.
                print("fluctuations are precomputed, using them")
                dot_separated_name = get_dot_separated_name(key='pdb_name', _dict=_dict)
                if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                    names.append(_dict['pdb_name'])
                    backbones.append(None)
                    sequences.append(_dict['sequence'])
                    # Append a trailing 0.0 so the fluctuations align with the end-of-sequence token.
                    flucts_list.append(_dict['fluctuations'] + [0.0])
                continue

            dot_separated_name = get_dot_separated_name(key='name', _dict=_dict)
            if datapoint_for_eval == 'all' or dot_separated_name in datapoint_for_eval:
                backbones.append(_dict['coords'])
                sequences.append(_dict['seq'])
                names.append(dot_separated_name)
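
    # Sketch of the two .jsonl record shapes consumed above (field names as read by
    # the code, values illustrative):
    #   {"name": "1BUI_C", "seq": "MKT...", "coords": {...}}                      # ENM computed on the fly
    #   {"pdb_name": "1BUI.C", "sequence": "MKT...", "fluctuations": [0.61, ...]} # precomputed ENM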

    elif suffix == ".pdb_list":
        # A .pdb_list file contains one PDB file path per line.
        with open(args.input_file, 'r') as f:
            pdb_files = f.read().splitlines()
        for pdb_file in pdb_files:
            backbones, sequences, names = process_pdb_file(pdb_file, backbones, sequences, names)

    else:
        raise ValueError("Input file must be a .fasta, .pdb, .jsonl or .pdb_list file.")

    env_config = yaml.load(open('configs/env_config.yaml', 'r'), Loader=yaml.FullLoader)
    os.environ['HF_HOME'] = env_config['huggingface']['HF_HOME']
    os.environ["CUDA_VISIBLE_DEVICES"] = env_config['gpus']['cuda_visible_device']

    config = yaml.load(open('configs/train_config.yaml', 'r'), Loader=yaml.FullLoader)
    class_config = ClassConfig(config)
    # Flexpert-Seq uses no adaptor; Flexpert-3D uses the convolutional adaptor that
    # consumes the ENM fluctuations ('enm_vals') alongside the sequence.
    class_config.adaptor_architecture = 'no-adaptor' if args.modality == 'SEQ' else 'conv'
    model, tokenizer = PT5_classification_model(half_precision=config['mixed_precision'], class_config=class_config)

    model.to(config['inference_args']['device'])

    if args.modality == 'SEQ':
        state_dict = torch.load(config['inference_args']['seq_model_path'], map_location=config['inference_args']['device'])
        model.load_state_dict(state_dict, strict=False)
    elif args.modality == '3D':
        print("Loading 3D model from {}".format(config['inference_args']['3d_model_path']))
        state_dict = torch.load(config['inference_args']['3d_model_path'], map_location=config['inference_args']['device'])
        model.load_state_dict(state_dict, strict=False)
    model.eval()

    data_to_collate = []
    for idx, (backbone, sequence) in enumerate(zip(backbones, sequences)):
        if args.modality == '3D':
            if backbone is not None:
                # Compute the ENM fluctuation profile from the backbone coordinates.
                _dict = {'coords': backbone, 'seq': sequence}
                flucts, _ = get_fluctuation_for_json_dict(_dict, enm_type=config['inference_args']['enm_type'])
                flucts = flucts.tolist()
                flucts.append(0.0)  # pad for the end-of-sequence token
                flucts = torch.tensor(flucts).to(config['inference_args']['device'])
            else:
                # Precomputed fluctuations (from the .jsonl input); convert to a tensor
                # so both branches hand the collator the same type.
                flucts = torch.tensor(flucts_list[idx]).to(config['inference_args']['device'])

        sequence = sequence.replace('-', 'X')  # map gap symbols to the unknown-residue token

        # ProtT5 expects space-separated residues; add_special_tokens appends the </s> token.
        tokenizer_out = tokenizer(' '.join(sequence), add_special_tokens=True, return_tensors='pt')
        tokenized_seq, attention_mask = tokenizer_out['input_ids'].to(config['inference_args']['device']), tokenizer_out['attention_mask'].to(config['inference_args']['device'])

        if args.modality == '3D':
            data_to_collate.append({'input_ids': tokenized_seq[0, :], 'attention_mask': attention_mask[0, :], 'enm_vals': flucts})
        elif args.modality == 'SEQ':
            data_to_collate.append({'input_ids': tokenized_seq[0, :], 'attention_mask': attention_mask[0, :]})
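
    # Each item in data_to_collate is a dict of 1-D tensors of per-token length
    # (sequence length + 1 for the trailing </s>); the collator is expected to pad
    # them to a common length to form the batch.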

    data_collator = DataCollatorForTokenRegression(tokenizer)
    batch = data_collator(data_to_collate)
    batch.to(model.device)

    # Debug printout of the collated batch: number of unpadded tokens per attention
    # mask, tensor shapes for everything else.
    for key in batch.keys():
        print("___________-", key, "-___________")
        for b in batch[key]:
            if key == 'attention_mask':
                print(b.sum())
            else:
                print(b.shape)

    with torch.no_grad():
        output_logits = process_in_batches_and_combine(model, batch, config['inference_args']['batch_size'])
    predictions = output_logits[:, :, 0]
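
    # predictions has shape (num_proteins, max_tokens): one flexibility value per
    # token, including padding and the trailing </s> position, which are stripped
    # again when writing the outputs below.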

    output_filename = Path(config['inference_args']['prediction_output_dir'].format(args.output_name, args.modality, 'all' if not args.split else args.split))
    output_filename.parent.mkdir(parents=True, exist_ok=True)

    with open(output_filename.with_suffix('.txt'), 'w') as f:
        print("Saving predictions to {}.".format(output_filename.with_suffix('.txt')))
        for prediction, mask, name, sequence in zip(predictions, batch['attention_mask'], names, sequences):
            prediction = prediction[mask.bool()]  # drop the padded positions
            assert len(prediction) == len(sequence) + 1, "Prediction length {} is not equal to sequence length + 1 ({})".format(len(prediction), len(sequence) + 1)
            if '.' in name:
                name = name.replace('.', '_')
            f.write('>' + name + '\n')
            # Drop the last value (the </s> position) so one value per residue remains.
            f.write(', '.join([str(p) for p in prediction.tolist()[:-1]]) + '\n')
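
    # The .txt output is fasta-like: a '>name' header line followed by one line of
    # comma-separated per-residue flexibility values.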

    if suffix == ".pdb" or suffix == ".pdb_list":
        # Write the predictions into the B-factor column of (a copy of) each input PDB.
        for name, pdb_file, prediction, mask in zip(names, pdb_files, predictions, batch['attention_mask']):
            chain_id = name.split('.')[1]
            # Mask out padding (relevant for .pdb_list batches of unequal lengths)
            # and drop the trailing </s> value.
            _prediction = prediction[mask.bool()][:-1].reshape(1, -1)
            _outname = output_filename.with_name(output_filename.stem + '_{}.pdb'.format(name.replace('.', '_')))
            print("Saving prediction to {}.".format(_outname))
            modify_bfactor_biotite(pdb_file, chain_id, _outname, _prediction)

    if args.output_enm:
        _outname = output_filename.with_name(output_filename.stem + '_enm.txt')
        with open(_outname, 'w') as f:
            print("Saving ENM predictions to {}.".format(_outname))
            for enm_prediction, mask, name in zip(batch['enm_vals'], batch['attention_mask'], names):
                f.write('>' + name + '\n')
                # Mask out padding before dropping the trailing </s> value.
                f.write(', '.join([str(p) for p in enm_prediction[mask.bool()].tolist()[:-1]]) + '\n')

        if suffix == ".pdb" or suffix == ".pdb_list":
            for name, pdb_file, enm_vals_single, mask in zip(names, pdb_files, batch['enm_vals'], batch['attention_mask']):
                # Use a distinct '_enm' suffix so these files do not overwrite the
                # prediction PDBs written above.
                _outname = output_filename.with_name(output_filename.stem + '_{}_enm.pdb'.format(name.replace('.', '_')))
                print("Saving ENM prediction to {}.".format(_outname))
                chain_id = name.split('.')[1]
                _enm_vals = enm_vals_single[mask.bool()][:-1].reshape(1, -1)
                modify_bfactor_biotite(pdb_file, chain_id, _outname, _enm_vals)

    if args.output_fasta:
        _outname = output_filename.with_name(output_filename.stem + '_fasta.fasta')
        with open(_outname, 'w') as f:
            print("Saving fasta to {}.".format(_outname))
            for name, sequence in zip(names, sequences):
                f.write('>' + name + '\n')
                f.write(sequence + '\n')