import numpy as np
import torch
import torch.nn as nn
import copy
import re
from transformers import T5Config, T5PreTrainedModel, T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from models.enm_adaptor_heads import ENMAdaptedAttentionClassifier, ENMAdaptedDirectClassifier, ENMAdaptedConvClassifier, ENMNoAdaptorClassifier
from utils.lora_utils import LoRAConfig, modify_with_lora
class T5EncoderForTokenClassification(T5PreTrainedModel):

    def __init__(self, config: T5Config, class_config):
        super().__init__(config)
        self.num_labels = class_config.num_labels
        self.config = config
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self.dropout = nn.Dropout(class_config.dropout_rate)

        if class_config.adaptor_architecture == 'attention':
            self.classifier = ENMAdaptedAttentionClassifier(config.hidden_size, class_config.num_labels, class_config.enm_embed_dim, class_config.enm_att_heads)
        elif class_config.adaptor_architecture == 'direct':
            self.classifier = ENMAdaptedDirectClassifier(config.hidden_size, class_config.num_labels)
        elif class_config.adaptor_architecture == 'conv':
            self.classifier = ENMAdaptedConvClassifier(config.hidden_size, class_config.num_labels, class_config.kernel_size, class_config.enm_embed_dim, class_config.num_layers)
        elif class_config.adaptor_architecture == 'no-adaptor':
            self.classifier = ENMNoAdaptorClassifier(config.hidden_size, class_config.num_labels)
        else:
            raise ValueError('Only attention, direct, conv and no-adaptor architectures are supported for the adaptor.')

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
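
    # NOTE (assumption): `class_config` is a plain attribute container; the fields below
    # are inferred from their use in this module rather than from a published schema:
    #   num_labels, add_pearson_loss, add_sse_loss, dropout_rate, adaptor_architecture
    #   ('attention' | 'direct' | 'conv' | 'no-adaptor'), enm_embed_dim, enm_att_heads
    #   (attention head only), kernel_size and num_layers (conv head only).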

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.classifier = self.classifier.to(self.encoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See
        base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            # T5Stack keeps its layers in `block`, with self-attention at sub-layer index 0
            # (the original `self.encoder.layer[layer].attention` is BERT-style and would
            # raise an AttributeError on a T5Stack).
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
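
    # Usage sketch (hypothetical call): prune attention heads 1 and 2 of the first
    # encoder block:
    #   model._prune_heads({0: [1, 2]})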

    def forward(
        self,
        enm_vals=None,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        # TODO: check the enm_vals are padded properly and check that the sequence limit (in the transformer) is indeed 512
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        # Loss computation is handled outside this module; `labels` is accepted for API
        # compatibility but no loss is computed here (the original referenced an
        # undefined `loss`, which would raise a NameError).
        loss = None

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
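
# Minimal forward-pass sketch for the module above. Shapes are illustrative
# assumptions (batch of 2, 10 residues, one ENM value per residue); `_demo_forward`
# is a hypothetical helper, not part of the original training pipeline.
def _demo_forward(model: T5EncoderForTokenClassification) -> torch.Tensor:
    input_ids = torch.randint(0, model.config.vocab_size, (2, 10))
    attention_mask = torch.ones(2, 10, dtype=torch.long)
    enm_vals = torch.rand(2, 10)  # assumed layout: one ENM fluctuation value per residue
    out = model(enm_vals=enm_vals, input_ids=input_ids, attention_mask=attention_mask)
    return out.logits  # shape: (2, 10, num_labels)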

def PT5_classification_model(half_precision, class_config):
    # Load ProtT5 and its tokenizer
    # (it is also possible to load the half-precision model, thanks to @pawel-rezo for pointing that out)
    if not half_precision:
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", local_files_only=False)
    elif half_precision and torch.cuda.is_available():
        tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False, local_files_only=False)
        model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", torch_dtype=torch.float16, local_files_only=False).to(torch.device('cuda'))
    else:
        raise ValueError('Half precision can be run on GPU only.')

    # Create a new classifier model with ProtT5 dimensions
    class_model = T5EncoderForTokenClassification(model.config, class_config)

    # Set encoder and embedding weights to checkpoint weights
    class_model.shared = model.shared
    class_model.encoder = model.encoder

    # Delete the checkpoint model
    model = class_model
    del class_model

    # Print number of trainable parameters
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_Classifier\nTrainable Parameter: " + str(params))

    # Add LoRA model modification
    config = LoRAConfig('configs/lora_config.yaml')

    # Add LoRA layers
    model = modify_with_lora(model, config)

    # Freeze embeddings and encoder (except LoRA parameters)
    for param_name, param in model.shared.named_parameters():
        param.requires_grad = False
    for param_name, param in model.encoder.named_parameters():
        param.requires_grad = False
    for param_name, param in model.named_parameters():
        if re.fullmatch(config.trainable_param_names, param_name):
            param.requires_grad = True

    # Print trainable parameters again after LoRA injection and freezing
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("ProtT5_LoRA_Classifier\nTrainable Parameter: " + str(params) + "\n")

    return model, tokenizer
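
# Example entry point: a minimal sketch assuming a SimpleNamespace is sufficient for
# `class_config` (field names taken from their use above) and that
# configs/lora_config.yaml exists locally. Values are illustrative, not tuned.
if __name__ == "__main__":
    from types import SimpleNamespace

    class_config = SimpleNamespace(
        num_labels=2,
        add_pearson_loss=False,
        add_sse_loss=False,
        dropout_rate=0.1,
        adaptor_architecture='attention',
        enm_embed_dim=64,  # illustrative value
        enm_att_heads=4,   # illustrative value
        kernel_size=3,     # used by the 'conv' head only
        num_layers=2,      # used by the 'conv' head only
    )
    model, tokenizer = PT5_classification_model(half_precision=False, class_config=class_config)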