import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml


class LoRAConfig:
    """Reads LoRA hyperparameters from a YAML file and exposes them as attributes."""

    def __init__(self, config_file):
        with open(config_file, "r") as file:
            config = yaml.safe_load(file)

        # Expose every top-level key of the YAML file as an attribute, e.g.
        # config.lora_rank, config.lora_modules, config.lora_init_scale.
        for key, value in config.items():
            setattr(self, key, value)

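
# A minimal sketch of the kind of YAML file LoRAConfig is meant to read.  The key
# names are the ones consumed by modify_with_lora below; the values (including the
# T5-style module/layer regexes) are illustrative assumptions, not prescribed here:
#
#   lora_modules: ".*SelfAttention|.*EncDecAttention"
#   lora_layers: "q|k|v|o"
#   lora_rank: 4
#   lora_scaling_rank: 1
#   lora_init_scale: 0.01

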
class LoRALinear(nn.Module):
    def __init__(self, linear_layer, rank, scaling_rank, init_scale):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.rank = rank
        self.scaling_rank = scaling_rank

        # The wrapped layer's weight and bias are reused as-is (and typically frozen);
        # only the LoRA parameters created below are meant to be trained.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias

        if self.rank > 0:
            # Additive low-rank update: lora_b @ lora_a has the same shape as weight.
            self.lora_a = nn.Parameter(torch.randn(rank, linear_layer.in_features) * init_scale)
            if init_scale < 0:
                self.lora_b = nn.Parameter(torch.randn(linear_layer.out_features, rank) * init_scale)
            else:
                # Zero-initialised so the module starts out identical to the wrapped layer.
                self.lora_b = nn.Parameter(torch.zeros(linear_layer.out_features, rank))
        if self.scaling_rank:
            # Multiplicative (IA)^3-style rescaling of the frozen weight, initialised
            # around 1 so the module starts out close to the wrapped layer.
            self.multi_lora_a = nn.Parameter(
                torch.ones(self.scaling_rank, linear_layer.in_features)
                + torch.randn(self.scaling_rank, linear_layer.in_features) * init_scale
            )
            if init_scale < 0:
                self.multi_lora_b = nn.Parameter(
                    torch.ones(linear_layer.out_features, self.scaling_rank)
                    + torch.randn(linear_layer.out_features, self.scaling_rank) * init_scale
                )
            else:
                self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, self.scaling_rank))

    def forward(self, input):
        if self.scaling_rank == 1 and self.rank == 0:
            # Parsimonious (IA)^3 path: rescale the input before and the output after
            # the frozen linear layer instead of materialising a modified weight.
            if self.multi_lora_a.requires_grad:
                hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
            else:
                hidden = F.linear(input, self.weight, self.bias)
            if self.multi_lora_b.requires_grad:
                hidden = hidden * self.multi_lora_b.flatten()
            return hidden
        else:
            # General path: build the effective weight
            #   weight * (multi_lora_b @ multi_lora_a) / scaling_rank  (elementwise rescaling)
            #   + (lora_b @ lora_a) / rank                             (additive low-rank update)
            weight = self.weight
            if self.scaling_rank:
                weight = weight * torch.matmul(self.multi_lora_b, self.multi_lora_a) / self.scaling_rank
            if self.rank:
                weight = weight + torch.matmul(self.lora_b, self.lora_a) / self.rank
            return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}, rank={}, scaling_rank={}".format(
            self.in_features, self.out_features, self.bias is not None, self.rank, self.scaling_rank
        )

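
# A quick sketch (hypothetical layer sizes) of what wrapping a single layer adds:
#
#     layer = nn.Linear(16, 32)
#     lora = LoRALinear(layer, rank=4, scaling_rank=1, init_scale=0.01)
#
# keeps the frozen weight (32, 16) and bias (32,) of the wrapped layer and adds
#     lora_a       (4, 16),  lora_b       (32, 4)   -> additive update lora_b @ lora_a
#     multi_lora_a (1, 16),  multi_lora_b (32, 1)   -> multiplicative rescaling

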
def modify_with_lora(transformer, config):
    """Replaces every nn.Linear whose parent module name matches config.lora_modules and
    whose attribute name matches config.lora_layers with an equivalent LoRALinear wrapper."""
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(config.lora_layers, c_name):
                    assert isinstance(
                        layer, nn.Linear
                    ), f"LoRA can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                    setattr(
                        module,
                        c_name,
                        LoRALinear(layer, config.lora_rank, config.lora_scaling_rank, config.lora_init_scale),
                    )
    return transformer
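

# A minimal, self-contained sketch of how modify_with_lora is meant to be used.  The
# toy model, the regexes, and the hyperparameter values below are illustrative
# assumptions only; in practice the hyperparameters would come from a YAML file via
# LoRAConfig and the transformer from a real pretrained model.
if __name__ == "__main__":
    from types import SimpleNamespace

    class ToyAttention(nn.Module):
        def __init__(self, d_model):
            super().__init__()
            self.q = nn.Linear(d_model, d_model)
            self.k = nn.Linear(d_model, d_model)
            self.v = nn.Linear(d_model, d_model)
            self.o = nn.Linear(d_model, d_model)

        def forward(self, x):
            return self.o(torch.relu(self.q(x) + self.k(x) + self.v(x)))

    class ToyModel(nn.Module):
        def __init__(self, d_model=16):
            super().__init__()
            self.attention = ToyAttention(d_model)

        def forward(self, x):
            return self.attention(x)

    # Stand-in for a LoRAConfig object loaded from YAML.
    config = SimpleNamespace(
        lora_modules=".*attention",  # regex over module names: which parents to touch
        lora_layers="q|k|v|o",       # regex over child names: which nn.Linear layers to wrap
        lora_rank=4,                 # rank of the additive update lora_b @ lora_a
        lora_scaling_rank=1,         # rank of the multiplicative rescaling
        lora_init_scale=0.01,        # scale of the random initialisation
    )

    model = modify_with_lora(ToyModel(), config)

    # Typically only the new LoRA parameters are left trainable.
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name

    x = torch.randn(2, 8, 16)
    print(model(x).shape)  # torch.Size([2, 8, 16])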