import copy
import re

import numpy as np
import torch
import torch.nn as nn

from transformers import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier,
    ENMAdaptedDirectClassifier,
    ENMAdaptedConvClassifier,
    ENMNoAdaptorClassifier,
)
from utils.lora_utils import LoRAConfig, modify_with_lora

class T5EncoderForTokenClassification(T5PreTrainedModel):
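    """ProtT5 encoder with an ENM-adapted token-classification head.

    The T5 encoder stack produces per-residue embeddings, and the classifier head
    combines them with per-residue ENM values. The head is selected through
    ``class_config.adaptor_architecture`` ('attention', 'direct', 'conv' or
    'no-adaptor').
    """
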
    def __init__(self, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)

        # Pick the adaptor head that fuses encoder embeddings with the ENM values
        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size, class_config.num_labels,
                class_config.enm_embed_dim, class_config.enm_att_heads)
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(config.hidden_size, class_config.num_labels)
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size, class_config.num_labels,
                class_config.kernel_size, class_config.enm_embed_dim, class_config.num_layers)
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(config.hidden_size, class_config.num_labels)
        else:
            raise ValueError(
                'Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            # T5Stack keeps its self-attention module in block[layer].layer[0].SelfAttention
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        enm_vals=None,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
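        """Encode the sequence and classify each token with the ENM adaptor head.

        ``enm_vals`` holds the per-residue ENM features aligned with ``input_ids``;
        they are forwarded to the classifier together with the encoder output and
        the attention mask. The remaining arguments follow the usual
        ``transformers`` encoder interface.
        """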
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)

        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        # No loss is computed inside forward; the Pearson/SSE loss options stored in
        # __init__ are assumed to be applied by the training code, so `loss` stays
        # None and both return paths below remain well-defined.
        loss = None

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

def PT5_classification_model(half_precision, class_config):
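    """Build a ProtT5-based token classifier with LoRA adapters.

    Loads the Rostlab ProtT5-XL-UniRef50 encoder and tokenizer (the half-precision
    variant requires a GPU), wraps the encoder in T5EncoderForTokenClassification
    with the given ``class_config``, injects LoRA weights configured in
    ``configs/lora_config.yaml``, and freezes every parameter except those matching
    ``config.trainable_param_names``. Returns the model and the tokenizer.
    """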
    if not half_precision:
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
    elif half_precision and torch.cuda.is_available():
        tokenizer = T5Tokenizer.from_pretrained(
            'Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, local_files_only=False)
        model = T5EncoderModel.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc",
            torch_dtype=torch.float16, local_files_only=False).to(torch.device('cuda'))
    else:
        raise ValueError('Half precision can be run on GPU only.')

    class_model = T5EncoderForTokenClassification(model.config, class_config)

    # Reuse the pretrained embedding and encoder weights
    class_model.shared = model.shared
    class_model.encoder = model.encoder

    model = class_model
    del class_model

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_Classifier\nTrainable Parameter: " + str(params))

    # Add LoRA adapters to the encoder
    config = LoRAConfig('configs/lora_config.yaml')
    model = modify_with_lora(model, config)

    # Freeze the embedding and encoder weights ...
    for (param_name, param) in model.shared.named_parameters():
        param.requires_grad = False
    for (param_name, param) in model.encoder.named_parameters():
        param.requires_grad = False

    # ... then unfreeze only the parameters matched by the LoRA config
    for (param_name, param) in model.named_parameters():
        if re.fullmatch(config.trainable_param_names, param_name):
            param.requires_grad = True

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_LoRA_Classifier\nTrainable Parameter: " + str(params) + "\n")

    return model, tokenizer
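
# Minimal usage sketch (illustrative only): the values below are assumptions, not
# settings shipped with this repository; class_config just needs the attributes that
# T5EncoderForTokenClassification.__init__ reads for the chosen adaptor architecture.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_class_config = SimpleNamespace(
        num_labels=2,                       # assumed number of per-residue classes
        add_pearson_loss=False,
        add_sse_loss=False,
        dropout_rate=0.1,                   # assumed dropout rate
        adaptor_architecture='no-adaptor',  # simplest head; needs no extra ENM hyperparameters
    )
    model, tokenizer = PT5_classification_model(half_precision=False, class_config=example_class_config)
    print(type(model).__name__, type(tokenizer).__name__)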