import torch
import torch.nn as nn
import math


class LoRALinear(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) additive update.

    The output is ``original_linear(x) + scaling * (x @ lora_A @ lora_B)``,
    where ``lora_A`` has shape ``(in_features, rank)``, ``lora_B`` has shape
    ``(rank, out_features)`` and ``scaling = alpha / rank``. ``lora_B`` is
    zero-initialized, so a freshly constructed wrapper produces exactly the
    same output as ``original_linear``.

    The wrapped layer's parameters are left untouched (they are not frozen
    here); callers that want to train only the adapter must disable
    ``requires_grad`` on ``original_linear`` themselves.

    Args:
        original_linear: Layer to adapt. Must expose ``in_features``,
            ``out_features`` and a ``weight`` tensor (e.g. ``nn.Linear``).
        rank: Inner dimension of the low-rank update; must be >= 1.
        alpha: Numerator of the scaling factor applied to the update.

    Raises:
        ValueError: If ``rank`` is less than 1.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16):
        super().__init__()
        if rank < 1:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha

        # Allocate the adapter next to the wrapped weight (same device and, for
        # floating-point weights, same dtype) so forward() never mixes CPU/GPU
        # or float32/half tensors.
        weight = original_linear.weight
        factory = {"device": weight.device}
        if weight.is_floating_point():
            factory["dtype"] = weight.dtype

        # Shapes chosen so that (..., in) @ (in, r) @ (r, out) -> (..., out).
        self.lora_A = nn.Parameter(
            torch.empty(original_linear.in_features, rank, **factory)
        )
        self.lora_B = nn.Parameter(
            torch.empty(rank, original_linear.out_features, **factory)
        )

        # Scaling factor applied to the low-rank update.
        self.scaling = alpha / rank

        # Random A (fan-in = rank, as before) and zero B: the product A @ B is
        # zero at construction, so the wrapper starts as an exact identity
        # over the original layer's output.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update.

        Args:
            x: Input whose last dimension is ``in_features``; any leading
                batch dimensions are carried through by ``@``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        low_rank_update = (x @ self.lora_A) @ self.lora_B
        return self.original_linear(x) + self.scaling * low_rank_update
