"""
    This module implements Low-Rank Adaptation (LoRA) for timm Vision
    Transformers, following the design of the original LoRA paper.
    In addition, it provides:
        - _init_lora_A: selectable initializations for the LoRA A matrix,
        in line with the theoretical results presented in our paper.
"""
import math

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from safetensors import safe_open
from safetensors.torch import save_file
from timm.models.vision_transformer import VisionTransformer as timm_ViT
from torch import Tensor
from torch.nn.parameter import Parameter

from base_vit import ViT


def _init_lora_A(w: nn.Linear, init_type: str, init_scale: float):
    """
    Re-initialize, in place, the weight of a LoRA down-projection ``A``.
    ---
        - w (nn.Linear): layer whose ``weight`` is overwritten; its fan-in
              is taken as ``weight.size(1)``.
        - init_type (str): initialization scheme, matched case-insensitively:
              "uniform", "gaussian", "orthogonal", "xavier_uniform",
              "xavier_normal", "kaiming_uniform", "kaiming_normal", "zeros".
        - init_scale (float): how the chosen scheme is scaled:
              * uniform  -> U(-s, s) with s = init_scale / sqrt(fan_in)
              * gaussian -> N(0, std^2) with std = init_scale / sqrt(fan_in)
              * orthogonal, xavier_uniform, xavier_normal -> ``gain``
              * kaiming_uniform, kaiming_normal -> leaky-ReLU slope ``a``
                (mode="fan_in")
              * zeros -> ignored
    Raises:
        ValueError: if ``init_type`` is not one of the names above.
    """
    scheme = init_type.lower()
    weight = w.weight
    fan_in = weight.size(1)

    # The two "plain" distributions share the init_scale / sqrt(fan_in)
    # scaling; it is computed lazily so schemes that ignore fan_in never
    # divide by it.
    def _uniform():
        bound = init_scale / math.sqrt(fan_in)
        nn.init.uniform_(weight, -bound, +bound)

    def _gaussian():
        nn.init.normal_(weight, mean=0.0, std=init_scale / math.sqrt(fan_in))

    dispatch = {
        "uniform": _uniform,
        "gaussian": _gaussian,
        "orthogonal": lambda: nn.init.orthogonal_(weight, gain=init_scale),
        "xavier_uniform": lambda: nn.init.xavier_uniform_(weight,
                                                          gain=init_scale),
        "xavier_normal": lambda: nn.init.xavier_normal_(weight,
                                                        gain=init_scale),
        "kaiming_uniform": lambda: nn.init.kaiming_uniform_(
            weight, a=init_scale, mode="fan_in", nonlinearity="leaky_relu"),
        "kaiming_normal": lambda: nn.init.kaiming_normal_(
            weight, a=init_scale, mode="fan_in", nonlinearity="leaky_relu"),
        "zeros": lambda: nn.init.zeros_(weight),
    }

    initializer = dispatch.get(scheme)
    if initializer is None:
        raise ValueError(f"Unknown init_type: {init_type}")
    initializer()



class LoRA_ViT_timm(nn.Module):
    """
    Wrap a timm VisionTransformer with LoRA adapters on the q and v outputs.

    Every parameter of ``vit_model`` is frozen. For each block index in
    ``lora_layer`` (every block by default), the block's fused ``attn.qkv``
    linear is replaced by ``_LoRA_qkv_timm``, which adds trainable low-rank
    updates to the query and value slices. The A matrices are initialized
    with ``_init_lora_A(init_type, init_scale)`` and the B matrices with
    zeros, so the wrapped model initially matches the frozen backbone.

    Args:
        vit_model: timm VisionTransformer to adapt (modified in place).
        r: LoRA rank, must be > 0.
        alpha: LoRA scaling numerator, must be > 0; updates are scaled by
            ``alpha / r``.
        num_classes: if > 0, ``vit_model.head`` is replaced by
            ``nn.Identity`` and a new trainable ``self.head`` is created.
        lora_layer: indices of the blocks to adapt. A falsy value (None or
            an empty list) adapts every block.
        init_type, init_scale: forwarded to ``_init_lora_A`` for the A
            matrices.
    """

    def __init__(self, vit_model: timm_ViT,
                r: int, alpha: int, num_classes: int = 0,
                lora_layer=None,
                # NOTE(review): PEFT's LoRA default is reportedly
                # kaiming_uniform with a=sqrt(5); with init_scale=1.0 the
                # slope here is a=1.0 -- confirm which one is intended.
                init_type="kaiming_uniform", #this is the default in PEFT Package
                init_scale=1.0):

        super().__init__()

        self.init_type = init_type
        self.init_scale = float(init_scale)

        assert r > 0
        assert alpha > 0
        if lora_layer:
            self.lora_layer = lora_layer
        else:
            self.lora_layer = list(range(len(vit_model.blocks)))
        # Set copy for O(1) membership tests in the loop below.
        adapted = set(self.lora_layer)

        self.w_As = []
        self.w_Bs = []

        # Freeze the whole backbone; only modules created below train.
        for param in vit_model.parameters():
            param.requires_grad = False

        # Define the embedding width up front so ``self.dim`` exists (and
        # the classifier below can be built) even if no block is adapted.
        if len(vit_model.blocks) > 0:
            self.dim = vit_model.blocks[0].attn.qkv.in_features

        # Swap the fused qkv projection of every selected block for a
        # LoRA-augmented one (q and v get adapters, k does not).
        for t_layer_i, blk in enumerate(vit_model.blocks):
            if t_layer_i not in adapted:
                continue
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            w_a_linear_q = nn.Linear(self.dim, r, bias=False)
            w_b_linear_q = nn.Linear(r, self.dim, bias=False)
            w_a_linear_v = nn.Linear(self.dim, r, bias=False)
            w_b_linear_v = nn.Linear(r, self.dim, bias=False)
            self.w_As.append(w_a_linear_q)
            self.w_Bs.append(w_b_linear_q)
            self.w_As.append(w_a_linear_v)
            self.w_Bs.append(w_b_linear_v)
            blk.attn.qkv = _LoRA_qkv_timm(
                w_qkv_linear,
                w_a_linear_q,
                w_b_linear_q,
                w_a_linear_v,
                w_b_linear_v,
                r,
                alpha
            )
        self.reset_parameters()
        # The adapter layers are reachable (and thus registered) through
        # lora_vit, since they live inside the replaced qkv modules.
        self.lora_vit = vit_model
        # NOTE(review): proj_3d is not used by forward(); it is kept so
        # existing state_dicts still load. With num_classes == 0 it is an
        # empty Linear(0, 0).
        self.proj_3d = nn.Linear(num_classes * 30, num_classes)
        if num_classes > 0:
            self.lora_vit.head = nn.Identity()
            self.head = nn.Linear(self.dim, num_classes)

    # ------------------------------------------------------------------ #
    # Classifier (fc) helpers shared by the save/load methods.
    # ------------------------------------------------------------------ #
    def _fc_layer(self):
        """Return the linear classifier to save/load, or None if there is none.

        With ``num_classes > 0`` the backbone head is an ``nn.Identity`` and
        the trainable classifier is ``self.head``; otherwise it is the
        backbone's own ``head``.
        """
        fc = getattr(self, "head", None)
        if fc is None:
            fc = self.lora_vit.head
        return fc if isinstance(fc, nn.Linear) else None

    @staticmethod
    def _fc_key(fc: nn.Linear) -> str:
        """Key under which the classifier weight is stored (encodes its shape)."""
        return f"fc_{fc.in_features}in_{fc.out_features}out"

    @staticmethod
    def _as_parameter(saved: Tensor, like: Tensor) -> Parameter:
        """Wrap a loaded tensor as a Parameter on the device/dtype of ``like``.

        ``safe_open`` yields CPU tensors; moving them keeps a GPU (or
        half-precision) model consistent after loading.
        """
        return Parameter(saved.to(device=like.device, dtype=like.dtype))

    def _load_fc_weight(self, f) -> None:
        """Load the classifier weight from an open safetensors handle ``f``.

        Prints a message instead of raising when the file has no weight for
        this model's classifier shape (or the model has no linear classifier).
        """
        fc = self._fc_layer()
        saved_key = None if fc is None else self._fc_key(fc)
        # Check membership explicitly: a missing key makes get_tensor raise
        # safetensors' own error type, not ValueError.
        if saved_key is None or saved_key not in f.keys():
            print("this fc weight is not for this model")
            return
        fc.weight = self._as_parameter(f.get_tensor(saved_key), fc.weight)

    def save_fc_parameters(self, filename: str) -> None:
        r"""Save only the classifier weight. Only safetensors is supported now.

        pip install safetensors if you do not have it installed yet.

        Raises:
            ValueError: if the model has no linear classifier head.
        """
        assert filename.endswith(".safetensors")
        fc = self._fc_layer()
        if fc is None:
            raise ValueError("this model has no linear classifier head to save")
        save_file({self._fc_key(fc): fc.weight}, filename)

    def load_fc_parameters(self, filename: str) -> None:
        """Load the classifier weight written by ``save_fc_parameters``."""
        assert filename.endswith(".safetensors")
        with safe_open(filename, framework="pt") as f:
            self._load_fc_weight(f)

    def save_lora_parameters(self, filename: str) -> None:
        """Save every LoRA A/B weight plus the classifier weight (if any).

        A and B weights are stored as ``w_a_XXX`` / ``w_b_XXX`` in creation
        order (q then v for each adapted block).
        """
        assert filename.endswith(".safetensors")

        a_tensors = {f"w_a_{i:03d}": w_A.weight for i, w_A in enumerate(self.w_As)}
        b_tensors = {f"w_b_{i:03d}": w_B.weight for i, w_B in enumerate(self.w_Bs)}

        fc = self._fc_layer()
        fc_tensors = {} if fc is None else {self._fc_key(fc): fc.weight}

        merged_dict = {**a_tensors, **b_tensors, **fc_tensors}
        save_file(merged_dict, filename)

    def load_lora_parameters(self, filename: str) -> None:
        """Load weights written by ``save_lora_parameters``.

        Every ``w_a_XXX`` / ``w_b_XXX`` key must be present; a missing or
        mismatched classifier weight only prints a message.
        """
        assert filename.endswith(".safetensors")

        with safe_open(filename, framework="pt") as f:
            for i, w_A_linear in enumerate(self.w_As):
                saved_tensor = f.get_tensor(f"w_a_{i:03d}")
                w_A_linear.weight = self._as_parameter(saved_tensor,
                                                       w_A_linear.weight)

            for i, w_B_linear in enumerate(self.w_Bs):
                saved_tensor = f.get_tensor(f"w_b_{i:03d}")
                w_B_linear.weight = self._as_parameter(saved_tensor,
                                                       w_B_linear.weight)

            self._load_fc_weight(f)

    def reset_parameters(self) -> None:
        """
        Initialize the LoRA matrices.
            - A matrices use ``_init_lora_A`` with the configured
              ``init_type`` / ``init_scale``.
            - B matrices are set to 0, as in the original LoRA paper, so the
              adapted model starts identical to the frozen backbone.
        """
        for w_A in self.w_As:
            _init_lora_A(w_A, self.init_type, self.init_scale)
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)

    def forward(self, x: Tensor) -> Tensor:
        """Run the adapted backbone on ``x``.

        NOTE(review): when ``num_classes > 0`` the backbone head is an
        ``nn.Identity`` and ``self.head`` is NOT applied here, so this
        returns the backbone output without a classifier; confirm callers
        apply ``self.head`` themselves.
        """
        return self.lora_vit(x)


class _LoRA_qkv_timm(nn.Module):
    """
    LoRA wrapper around the fused ``qkv`` projection of a timm attention block.

    The wrapped ``qkv`` layer is evaluated unchanged. Then
    ``(alpha / r) * B(A(x))`` updates are added to the first ``dim`` output
    features (query) and the last ``dim`` output features (value). The
    middle slice (key) is left untouched. This assumes the fused output is
    laid out as ``[q | k | v]`` along the last axis, as the slicing implies.
    """

    def __init__(
        self,
        qkv: nn.Module,
        linear_a_q: nn.Module,
        linear_b_q: nn.Module,
        linear_a_v: nn.Module,
        linear_b_v: nn.Module,
        r: int,
        alpha: int
    ):
        super().__init__()
        self.qkv = qkv
        self.linear_a_q = linear_a_q
        self.linear_b_q = linear_b_q
        self.linear_a_v = linear_a_v
        self.linear_b_v = linear_b_v
        # Width of each of the q / k / v slices (the input embedding width).
        self.dim = qkv.in_features
        self.r = r
        self.alpha = alpha

    def forward(self, x):
        """Return ``qkv(x)`` with LoRA updates added to the q and v slices.

        Any number of leading dimensions is supported; the q/v slices are
        taken along the last axis.
        """
        qkv = self.qkv(x)
        scale = self.alpha / self.r
        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
        # ``...`` indexing keeps this valid for (B, N, C) as well as (N, C)
        # or any other rank, instead of hard-coding three dimensions.
        qkv[..., : self.dim] += scale * new_q
        qkv[..., -self.dim :] += scale * new_v

        return qkv

if "__main__" == __name__:
    # This module only provides definitions; there is nothing to run directly.
    pass
