import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml


class LoRAConfig:
    def __init__(self, config_file):
        # Load the YAML configuration file
        with open(config_file, "r") as file:
            config = yaml.safe_load(file)
        # Set class attributes based on the loaded YAML config
        for key, value in config.items():
            setattr(self, key, value)


class LoRALinear(nn.Module):
    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank
        self.scaling_rank = scaling_rank
        # Reuse the wrapped layer's weight and bias directly.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            # Additive low-rank update: W + (lora_b @ lora_a) / rank.
            self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
            if init_scale < 0:
                self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
            else:
                # Zero-initialize lora_b so the update starts as a no-op.
                self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
        if self.scaling_rank:
            # Multiplicative (IA3-style) scaling vectors, initialized around 1.
            self.multi_lora_a = nn.Parameter(
                torch.ones(self.scaling_rank, linear_layer.in_features)
                + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
            )
            if init_scale < 0:
                self.multi_lora_b = nn.Parameter(
                    torch.ones(linear_layer.out_features, self.scaling_rank)
                    + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
                )
            else:
                self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))

    def forward(self, input):
        if self.scaling_rank == 1 and self.rank == 0:
            # Parsimonious implementation for IA3 and LoRA scaling.
            if self.multi_lora_a.requires_grad:
                hidden = F.linear(input * self.multi_lora_a.flatten(), self.weight, self.bias)
            else:
                hidden = F.linear(input, self.weight, self.bias)
            if self.multi_lora_b.requires_grad:
                hidden = hidden * self.multi_lora_b.flatten()
            return hidden
        else:
            # General implementation for LoRA (adding and scaling).
            weight = self.weight
            if self.scaling_rank:
                weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
            if self.rank:
                weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
            return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
            self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
        )


def modify_with_lora(transformer, config):
    # Replace every nn.Linear child whose parent module name matches config.lora_modules
    # and whose attribute name matches config.lora_layers with a LoRALinear wrapper.
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(config.lora_layers, c_name):
                    assert isinstance(
                        layer, nn.Linear
                    ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                    setattr(
                        module,
                        c_name,
                        LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
                    )
    return transformer
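

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of applying modify_with_lora to a toy model. The regexes,
# ranks, and the SimpleConfig stand-in for a YAML-backed LoRAConfig are
# assumptions for demonstration only; in practice these values would come
# from the YAML file consumed by LoRAConfig.
if __name__ == "__main__":

    class SimpleConfig:
        # Hypothetical values mirroring the attributes LoRAConfig would load from YAML.
        lora_modules = ".*block"
        lora_layers = "fc[12]"
        lora_rank = 4
        lora_scaling_rank = 0
        lora_init_scale = 0.01

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(16, 32)
            self.fc2 = nn.Linear(32, 16)

        def forward(self, x):
            return self.fc2(F.relu(self.fc1(x)))

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = ToyBlock()

        def forward(self, x):
            return self.block(x)

    model = modify_with_lora(ToyModel(), SimpleConfig())
    print(model)  # fc1 and fc2 should now appear as LoRALinear modules
    out = model(torch.randn(2, 16))
    print(out.shape)  # torch.Size([2, 16])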