from typing import Optional, Tuple, Union
import re

import numpy as np
import torch
import torch.nn as nn
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.esm.modeling_esm import EsmModel, EsmPreTrainedModel
from peft import LoraConfig, inject_adapter_in_model

from models.enm_adaptor_heads import (
    ENMAdaptedAttentionClassifier,
    ENMAdaptedConvClassifier,
    ENMAdaptedDirectClassifier,
    ENMNoAdaptorClassifier,
)


class EsmForTokenRegression(EsmPreTrainedModel):
    """ESM-2 encoder with a per-token regression head that also consumes ENM values."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, class_config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Loss-shaping flags; they are not used in forward() and are kept on the
        # module so the training loop can read them.
        self.add_pearson_loss = class_config.add_pearson_loss
        self.add_sse_loss = class_config.add_sse_loss

        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        if class_config.adaptor_architecture == "attention":
            self.classifier = ENMAdaptedAttentionClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.enm_embed_dim,
                class_config.enm_att_heads,
            )
        elif class_config.adaptor_architecture == "direct":
            self.classifier = ENMAdaptedDirectClassifier(config.hidden_size, class_config.num_labels)
        elif class_config.adaptor_architecture == "conv":
            self.classifier = ENMAdaptedConvClassifier(
                config.hidden_size,
                class_config.num_labels,
                class_config.kernel_size,
                class_config.enm_embed_dim,
                class_config.num_layers,
            )
        elif class_config.adaptor_architecture == "no-adaptor":
            self.classifier = ENMNoAdaptorClassifier(config.hidden_size, class_config.num_labels)
        else:
            raise ValueError("Only attention, direct, conv and no-adaptor architectures are supported.")

        self.init_weights()

    def forward(
        self,
        enm_vals=None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.FloatTensor] = None,  # accepted for Trainer compatibility; the loss is computed outside this module
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(outputs[0])
        # The head combines the per-token embeddings with the per-token ENM values.
        logits = self.classifier(sequence_output, enm_vals, attention_mask)

        if not return_dict:
            return (logits,) + outputs[2:]

        return TokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
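
# Hedged sketch: the real head API lives in models.enm_adaptor_heads, which is not
# shown in this file. From the call in forward() above, each head is assumed to map
# (sequence_output, enm_vals, attention_mask) -> per-token logits of shape
# (batch, seq_len, num_labels). The stand-in below follows that assumed contract and
# is only meant for exercising the wrapper without the real heads; it is not one of
# the project's classifiers.
class _DummyENMHead(nn.Module):
    def __init__(self, hidden_size, num_labels):
        super().__init__()
        # one scalar ENM value per token is concatenated onto its embedding
        self.proj = nn.Linear(hidden_size + 1, num_labels)

    def forward(self, sequence_output, enm_vals, attention_mask):
        # assumes enm_vals has shape (batch, seq_len)
        x = torch.cat([sequence_output, enm_vals.unsqueeze(-1).to(sequence_output.dtype)], dim=-1)
        return self.proj(x)
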
def ESM_classification_model(half_precision, class_config, lora_config):
    """Build the LoRA-wrapped ESM regression model and its tokenizer.

    Note: the `lora_config` argument is currently unused; the LoRA
    hyperparameters are hard-coded below.
    """
    # Load the pretrained ESM-2 encoder and tokenizer
    checkpoint = "facebook/esm2_t36_3B_UR50D"
    if not half_precision:
        model = EsmModel.from_pretrained(checkpoint)
    elif torch.cuda.is_available():
        model = EsmModel.from_pretrained(checkpoint, torch_dtype=torch.float16).to(torch.device("cuda"))
    else:
        raise ValueError("Half precision can be run on GPU only.")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Create a new classifier model with the ESM dimensions, then swap in the
    # pretrained encoder weights and drop the extra reference to the checkpoint
    class_model = EsmForTokenRegression(model.config, class_config)
    class_model.esm = model
    del model

    # Print the number of trainable parameters before adding LoRA
    params = sum(int(np.prod(p.size())) for p in class_model.parameters() if p.requires_grad)
    print(f"ESM_Classifier\nTrainable Parameter: {params}")

    # Inject LoRA adapters into the encoder's attention and dense projections
    esm_lora_peft_config = LoraConfig(
        r=4,
        lora_alpha=1,
        bias="all",
        target_modules=["query", "key", "value", "dense"],
    )
    class_model.esm = inject_adapter_in_model(esm_lora_peft_config, class_model.esm)

    # Freeze the encoder, keeping only LoRA and layer-norm weights trainable
    # (the classifier head, which lives outside class_model.esm, stays trainable)
    for param_name, param in class_model.esm.named_parameters():
        param.requires_grad = bool(
            re.fullmatch(r".*lora.*", param_name) or re.fullmatch(r".*layer_norm.*", param_name)
        )

    # Print the number of trainable parameters after adding LoRA and freezing
    params = sum(int(np.prod(p.size())) for p in class_model.parameters() if p.requires_grad)
    print(f"ESM_LoRA_Classifier\nTrainable Parameter: {params}\n")

    return class_model, tokenizer
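

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the training pipeline. It assumes
    # `class_config` can be a plain namespace carrying the fields read in
    # __init__; the shape expected for `enm_vals` (one value per token) is an
    # assumption based on the head signatures, not verified against
    # models.enm_adaptor_heads. Note this downloads the 3B-parameter checkpoint.
    from types import SimpleNamespace

    class_config = SimpleNamespace(
        adaptor_architecture="direct",
        num_labels=1,
        add_pearson_loss=False,
        add_sse_loss=False,
    )
    class_model, tokenizer = ESM_classification_model(
        half_precision=False,
        class_config=class_config,
        lora_config=None,  # currently unused, see ESM_classification_model
    )

    tokens = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
    enm_vals = torch.rand(tokens["input_ids"].shape)  # placeholder ENM values, one per token
    with torch.no_grad():
        out = class_model(enm_vals=enm_vals, **tokens)
    print(out.logits.shape)  # expected: (1, seq_len, num_labels)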