flexpert / models /T5_encoder_per_token.py
Honzus24's picture
Update models/T5_encoder_per_token.py
5f38146 verified
import numpy as np
import torch
import torch.nn as nn
import copy
import re
from transformers import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from models.enm_adaptor_heads import ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier, ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
from utils.lora_utils import LoRAConfig, modify_with_lora
class T5EncoderForTokenClassification(T5PreTrainedModel):
    """T5 encoder stack followed by a per-token head that also receives ``enm_vals``.

    The head is one of the ``ENM*Classifier`` modules, selected by
    ``class_config.adaptor_architecture``. No loss is computed inside the
    model; ``forward`` only produces logits.
    """

    def __init__(self, config: T5Config, class_config):
        """Build the embedding, the encoder stack and the selected adaptor head.

        Args:
            config: T5 configuration; ``vocab_size``, ``d_model`` and
                ``hidden_size`` are read from it.
            class_config: Object exposing ``num_labels``, ``add_pearson_loss``,
                ``add_sse_loss``, ``dropout_rate`` and ``adaptor_architecture``,
                plus the fields required by the chosen architecture
                (``enm_embed_dim``/``enm_att_heads`` for ``'attention'``;
                ``kernel_size``/``enm_embed_dim``/``num_layers`` for ``'conv'``).

        Raises:
            ValueError: If ``adaptor_architecture`` is not one of
                ``'attention'``, ``'direct'``, ``'conv'`` or ``'no-adaptor'``.
        """
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        # Stored for external use; nothing in this class reads these flags.
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # The encoder gets its own config copy so the flags below do not
        # leak into the config object shared with the caller.
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)

        architecture = class_config.adaptor_architecture
        if architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.enm_embed_dim,
                class_config.enm_att_heads,
            )
        elif architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(
                config.hidden_size, class_config.num_labels
            )
        elif architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.kernel_size,
                class_config.enm_embed_dim,
                class_config.num_layers,
            )
        elif architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(
                config.hidden_size, class_config.num_labels
            )
        else:
            raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the encoder blocks over devices and move the head next to them.

        Args:
            device_map: Mapping of devices to encoder block indices. When
                ``None``, one is generated over all visible CUDA devices.
        """
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo ``parallelize``: bring encoder and head back to CPU and free CUDA cache."""
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        # parallelize() moved the head to the encoder's first device; move it
        # back as well so the whole model ends up on one device again.
        self.classifier = self.classifier.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the shared token embedding module."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding on both this model and its encoder."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        """Return the encoder stack."""
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            # T5Stack keeps its layers in ``block`` (see parallelize above);
            # the self-attention module is the first sub-layer of each block.
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        enm_vals=None,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the input and compute per-token logits.

        Args:
            enm_vals: Extra per-token input forwarded unchanged to the
                classifier head (presumably Elastic Network Model values
                aligned with the tokens -- confirm against the data collator).
            input_ids, attention_mask, head_mask, inputs_embeds,
            output_attentions, output_hidden_states: Forwarded to the encoder.
                ``attention_mask`` is additionally passed to the head.
            labels: Accepted for interface compatibility; not used, since no
                loss is computed here.
            return_dict: Whether to return a ``TokenClassifierOutput``;
                defaults to ``self.config.use_return_dict``.

        Returns:
            A ``TokenClassifierOutput`` with ``logits``, ``hidden_states`` and
            ``attentions`` (``loss`` is always ``None``), or a plain tuple
            starting with the logits when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])

        #TODO: check the enm_vals are padded properly and check that the sequence limit (in the transformer) is indeed 512
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        # The model does not compute a loss itself. The name must still be
        # bound: the tuple branch below reads it and previously raised
        # NameError whenever return_dict was false.
        loss = None

        if not return_dict:
            # NOTE(review): with use_cache disabled the encoder tuple may carry
            # hidden states at index 1, in which case ``[2:]`` drops them --
            # confirm against the installed transformers version.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
def PT5_classification_model(half_precision, class_config, lora_config_path='configs/lora_config.yaml'):
    """Load ProtT5, wrap it in the token classifier, add LoRA and freeze the rest.

    Args:
        half_precision: If true, load the fp16 encoder-only checkpoint and
            move it to CUDA; otherwise load the full-precision checkpoint.
        class_config: Head configuration passed to
            ``T5EncoderForTokenClassification``.
        lora_config_path: Path of the LoRA YAML config. Defaults to the
            previously hard-coded, working-directory-relative location.

    Returns:
        Tuple ``(model, tokenizer)``. Only the parameters whose full name
        matches ``config.trainable_param_names`` (from the LoRA config), plus
        any parameters outside ``shared``/``encoder`` that were already
        trainable, require gradients.

    Raises:
        ValueError: If ``half_precision`` is requested without CUDA.
    """

    def _count_trainable(module):
        # Total number of scalar entries across parameters that require grad.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    # Load PT5 and tokenizer
    # possible to load the half precision model (thanks to @pawel-rezo for pointing that out)
    if not half_precision:
        checkpoint = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
    elif torch.cuda.is_available():
        tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False, local_files_only=False)
        checkpoint = T5EncoderModel.from_pretrained(
            "Rostlab/prot_t5_xl_half_uniref50-enc",
            torch_dtype=torch.float16,
            local_files_only=False,
        ).to(torch.device('cuda'))
    else:
        raise ValueError('Half precision can be run on GPU only.')

    # Create new Classifier model with PT5 dimensions
    model = T5EncoderForTokenClassification(checkpoint.config, class_config)

    # Set encoder and embedding weights to checkpoint weights
    model.shared = checkpoint.shared
    model.encoder = checkpoint.encoder

    # Drop the checkpoint wrapper; its embedding and encoder stay alive
    # through the references just stored on ``model``.
    del checkpoint

    # Print number of trainable parameters
    print("ProtT5_Classfier\nTrainable Parameter: "+ str(_count_trainable(model)))

    # Add model modification lora
    config = LoRAConfig(lora_config_path)

    # Add LoRA layers
    model = modify_with_lora(model, config)

    # Freeze Embeddings and Encoder (except LoRA)
    for param in model.shared.parameters():
        param.requires_grad = False
    for param in model.encoder.parameters():
        param.requires_grad = False

    # Re-enable exactly the parameters whose full name matches the pattern.
    trainable_pattern = re.compile(config.trainable_param_names)
    for param_name, param in model.named_parameters():
        if trainable_pattern.fullmatch(param_name):
            param.requires_grad = True

    # Print trainable Parameter
    print("ProtT5_LoRA_Classfier\nTrainable Parameter: "+ str(_count_trainable(model)) + "\n")

    return model, tokenizer