# flexpert/utils/lora_utils.py
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import re


class LoRAConfig:
    """Loads a YAML file and exposes each top-level key as an attribute."""

    def __init__(self, config_file):
        # Load the YAML configuration file
        with open(config_file, 'r') as file:
            config = yaml.safe_load(file)
        # Set class attributes based on the loaded YAML config
        for key, value in config.items():
            setattr(self, key, value)
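
# Example of the YAML fields that modify_with_lora (below) expects the config to
# provide. The field names match the attributes read in this file; the regexes
# and values are illustrative placeholders, not defaults shipped with this repo:
#
#   lora_modules: ".*attention.*"
#   lora_layers: "q|k|v|o"
#   lora_rank: 4
#   lora_scaling_rank: 1
#   lora_init_scale: 0.01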


class LoRALinear(nn.Module):
    """nn.Linear wrapped with an additive low-rank (LoRA) update and/or an
    (IA)^3-style multiplicative rescaling, controlled by `rank` and `scaling_rank`."""

    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank
        self.scaling_rank = scaling_rank
        # Reuse the wrapped layer's weight and bias; only the adapter parameters are new.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        if self.rank > 0:
            # Additive low-rank update: lora_b @ lora_a has shape (out_features, in_features).
            self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
            if init_scale < 0:
                self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
            else:
                self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
        if self.scaling_rank:
            # Multiplicative rescaling factors, initialized around 1.
            self.multi_lora_a = nn.Parameter(
                torch.ones(self.scaling_rank, linear_layer.in_features)
                + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
            )
            if init_scale < 0:
                self.multi_lora_b = nn.Parameter(
                    torch.ones(linear_layer.out_features, self.scaling_rank)
                    + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
                )
            else:
                self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))

    def forward(self, input):
        if self.scaling_rank == 1 and self.rank == 0:
            # parsimonious implementation for ia3 and lora scaling
            if self.multi_lora_a.requires_grad:
                hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
            else:
                hidden = F.linear(input, self.weight, self.bias)
            if self.multi_lora_b.requires_grad:
                hidden = hidden * self.multi_lora_b.flatten()
            return hidden
        else:
            # general implementation for lora (adding and scaling)
            weight = self.weight
            if self.scaling_rank:
                weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
            if self.rank:
                weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
            return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
            self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
        )
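
# In the general branch of LoRALinear.forward the effective weight is
#
#     W_eff = W * (multi_lora_b @ multi_lora_a) / scaling_rank + (lora_b @ lora_a) / rank
#
# i.e. an elementwise rescaling of the frozen weight (applied only when
# scaling_rank is non-zero) plus a rank-`rank` additive update (applied only
# when rank is non-zero); both products have shape (out_features, in_features).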


def modify_with_lora(transformer, config):
    """Replaces every nn.Linear child whose parent module name matches
    `config.lora_modules` and whose own name matches `config.lora_layers`
    with a LoRALinear wrapper."""
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(config.lora_layers, c_name):
                    assert isinstance(
                        layer, nn.Linear
                    ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                    setattr(
                        module,
                        c_name,
                        LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
                    )
    return transformer
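

# --- Minimal usage sketch (illustrative; not part of the original API) ---------
# The toy module and inline config below are placeholders: in practice the config
# comes from a YAML file via LoRAConfig and the model is a transformer whose
# module/child names match the `lora_modules` / `lora_layers` regexes.
if __name__ == "__main__":
    class _ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.q = nn.Linear(16, 16)
            self.v = nn.Linear(16, 16)

        def forward(self, x):
            return self.v(torch.relu(self.q(x)))

    class _ToyConfig:
        # Same attribute names that modify_with_lora reads from a LoRAConfig.
        lora_modules = ".*"        # apply to every module whose children match below
        lora_layers = "q|v"        # wrap the q and v projections
        lora_rank = 4              # additive low-rank update of rank 4
        lora_scaling_rank = 1      # (IA)^3-style multiplicative scaling
        lora_init_scale = 0.01     # small random init for the adapter parameters

    model = modify_with_lora(_ToyBlock(), _ToyConfig())
    print(model)                             # q and v are now LoRALinear layers
    print(model(torch.randn(2, 16)).shape)   # expected: torch.Size([2, 16])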