import torch
import torch.nn as nn


def create_new_forward(module):
    """Build a replacement forward function that adds a low-rank update.

    The returned callable reads ``weight``, ``bias``, ``a_lora`` and
    ``b_lora`` from *module* each time it is invoked, so the attributes must
    exist on *module* by the time the callable is used.

    Args:
        module: Object exposing ``weight``, ``bias``, ``a_lora`` and
            ``b_lora`` attributes usable with ``nn.functional.linear``.

    Returns:
        A function of a single tensor argument computing the regular linear
        output plus the result of passing the input through ``a_lora`` and
        then ``b_lora``.
    """
    linear = nn.functional.linear

    def lora_forward(input):
        # Regular projection using the layer's own weight and bias.
        base_out = linear(input, module.weight, module.bias)
        # linear(x, w) multiplies by w transposed, so handing it the
        # transposed factors multiplies by a_lora and then by b_lora.
        down_projected = linear(input, module.a_lora.t())
        low_rank_out = linear(down_projected, module.b_lora.t())
        return base_out + low_rank_out

    return lora_forward


def lora(model, prompt_rank, model_name):
    """Attach low-rank adapter parameters to attention ``nn.Linear`` layers.

    Every selected layer gets two new parameters, ``a_lora`` with shape
    ``(in_features, prompt_rank)`` drawn from a standard normal and
    ``b_lora`` with shape ``(prompt_rank, out_features)`` filled with zeros,
    and its ``forward`` is replaced by the function built by
    ``create_new_forward``. Because ``b_lora`` starts at zero, the added
    term is zero right after patching.

    A layer is selected when it is an ``nn.Linear`` whose qualified name
    contains ``"attention"`` and either:

    * ``model_name`` contains ``"roberta"``, or
    * ``model_name`` contains ``"deberta"`` and the layer name also contains
      ``"in_proj"`` or ``"dense"``.

    Args:
        model: Module whose sub-modules are scanned and patched in place.
        prompt_rank: Inner dimension of the two adapter matrices.
        model_name: Model identifier string used to pick the selection rule.

    Returns:
        None. ``model`` is modified in place.
    """
    is_roberta = "roberta" in model_name
    is_deberta = "deberta" in model_name

    # Materialise the module list up front so the layers are not mutated
    # while the named_modules() generator is still being consumed.
    model_modules = list(model.named_modules())

    with torch.no_grad():
        for name, layer in model_modules:
            if not isinstance(layer, nn.Linear) or "attention" not in name:
                continue

            deberta_target = is_deberta and (
                "in_proj" in name or "dense" in name
            )
            if not (is_roberta or deberta_target):
                continue

            weight = layer.weight
            in_features, out_features = weight.size(1), weight.size(0)

            # Create the adapters in the weight's dtype so the patched
            # forward also works on models cast to fp16/bf16/float64; fall
            # back to the default dtype for non-floating weights, which
            # torch.randn cannot sample.
            dtype = weight.dtype if weight.dtype.is_floating_point else None

            a = torch.randn(
                in_features, prompt_rank, device=weight.device, dtype=dtype
            )
            b = torch.zeros(
                prompt_rank, out_features, device=weight.device, dtype=dtype
            )
            layer.register_parameter("a_lora", nn.Parameter(a))
            layer.register_parameter("b_lora", nn.Parameter(b))

            # Instance-level override: only this layer's forward changes.
            layer.forward = create_new_forward(layer)
