import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Linear layer with a trainable low-rank (LoRA) update.

    Computes ``linear(x) + scaling * (x @ A @ B)`` where ``A`` has shape
    ``(in_features, rank)`` and ``B`` has shape ``(rank, out_features)``.
    ``B`` starts at zero, so the low-rank update is exactly zero until it is
    trained: loading pretrained weights into ``linear`` leaves the layer's
    output unchanged at initialization.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factors ``A`` and ``B``.
        std: Standard deviation of the Gaussian initialization of ``A``.
        bias: Whether the base ``nn.Linear`` has a bias term.
        alpha: Optional scaling numerator; the low-rank update is multiplied
            by ``alpha / rank``. ``None`` (default) means a scale of 1.
        freeze_base: If True, disable gradients for the base ``linear``
            parameters so that only ``A`` and ``B`` are trained.
    """

    def __init__(self, in_features, out_features, rank=8, std=0.02, bias=True,
                 *, alpha=None, freeze_base=False):
        super().__init__()
        self.rank = rank
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.A = nn.Parameter(torch.randn(in_features, rank) * std)
        # Zero B makes the initial update A @ B exactly zero; A stays random so
        # that B receives a non-zero gradient ((x @ A)^T @ grad) from step one.
        self.B = nn.Parameter(torch.zeros(rank, out_features))
        # Plain float attribute (not a buffer/parameter) so state_dict keys are
        # unchanged; the rank > 0 guard avoids division by zero.
        self.scaling = alpha / rank if alpha is not None and rank > 0 else 1.0
        if freeze_base:
            for param in self.linear.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        """Apply the base linear map plus the scaled low-rank update.

        ``x`` may have any leading dimensions; its last dimension must equal
        ``in_features``. The result's last dimension is ``out_features``.
        """
        base_out = self.linear(x)
        # (x @ A) @ B keeps the intermediate at width ``rank`` instead of
        # materializing the full (in_features, out_features) delta matrix.
        lora_out = torch.matmul(torch.matmul(x, self.A), self.B)
        return base_out + lora_out * self.scaling

    def extra_repr(self):
        """Show the adapter hyper-parameters in the module's repr."""
        return f"rank={self.rank}, scaling={self.scaling}"
