#! -*- coding: utf-8
# https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/layer.py#L81
import typing
import warnings

import torch
import numpy as np


class LoRALinear(torch.nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) update.

    The wrapped layer is frozen and the module computes::

        y = base_layer(x) + (lora_alpha / r) * lora_B(lora_A(dropout(x)))

    ``lora_A`` projects ``in_features -> r`` (no bias) and ``lora_B`` projects
    ``r -> out_features`` (bias optional). ``lora_B`` (and its bias, if any) is
    zero-initialised, so right after construction the module's output equals
    ``base_layer(x)``.

    Args:
        base_layer: Layer to adapt; must expose ``in_features`` and
            ``out_features``. Its parameters get ``requires_grad = False``.
        r: Rank of the low-rank update; must be positive.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Dropout probability applied only to the input of the
            LoRA branch; values ``<= 0.0`` disable dropout.
        lora_bias: Whether ``lora_B`` carries a (zero-initialised) bias.
        init_weight: Initialisation scheme for ``lora_A``:
            ``"kaiming_uniform"`` or ``"gaussian"``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``r`` is not positive or ``init_weight`` is unknown.
    """

    def __init__(self, base_layer: torch.nn.Linear, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0,
                 lora_bias: bool = True, init_weight: str = "kaiming_uniform", **kwargs):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, after which r == 0 would crash with ZeroDivisionError.
        if r <= 0:
            raise ValueError(f"LoRA rank r must be a positive integer, got {r}")
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_bias = lora_bias
        self.scaling = lora_alpha / r

        self.base_layer = base_layer
        # Freeze the wrapped layer; only the LoRA factors are trained.
        for p in self.base_layer.parameters():
            p.requires_grad = False

        self.in_features = self.base_layer.in_features
        self.out_features = self.base_layer.out_features

        self.lora_A = torch.nn.Linear(self.in_features, r, bias=False)
        self.lora_B = torch.nn.Linear(r, self.out_features, bias=lora_bias)
        # Assigned once, directly as a module; the raw float probability is
        # not kept as an attribute.
        self.lora_dropout = (torch.nn.Dropout(p=lora_dropout) if lora_dropout > 0.0
                             else torch.nn.Identity())

        self._init_parameters(init_weight=init_weight)

    def _init_parameters(self, init_weight: str = "kaiming_uniform") -> None:
        """Initialise the LoRA factors.

        ``lora_A`` gets the requested random init; ``lora_B`` (weight and, if
        present, bias) is set to zero so the initial LoRA update is zero.

        Raises:
            ValueError: If ``init_weight`` is not a supported scheme.
        """
        if init_weight == "kaiming_uniform":
            # Same a=sqrt(5) as torch.nn.Linear's default weight init.
            torch.nn.init.kaiming_uniform_(self.lora_A.weight, a=np.sqrt(5))
        elif init_weight == "gaussian":
            torch.nn.init.normal_(self.lora_A.weight, std=1.0 / self.r)
        else:
            raise ValueError(f"Unsupported init weight: {init_weight}")
        torch.nn.init.zeros_(self.lora_B.weight)
        if self.lora_bias:
            torch.nn.init.zeros_(self.lora_B.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``base_layer(x)`` plus the scaled LoRA update.

        The result has the dtype of ``base_layer(x)``.
        """
        result = self.base_layer(x)
        # The adapter factors may live in a different dtype than the base layer
        # (e.g. float32 adapters over a half/double-precision base). Cast the
        # input into the adapter dtype so the matmuls are well-typed; this is
        # a no-op when the dtypes already match.
        lora_x = x.to(self.lora_A.weight.dtype)
        o = self.lora_B(self.lora_A(self.lora_dropout(lora_x)))
        # Cast the update back so the sum keeps the base layer's output dtype.
        return result + (o * self.scaling).to(result.dtype)
