import torch
import torch.nn as nn
from typing import Union
import tltorch,math
import torch.nn.functional as F


class low_rank_linear(nn.Module):
    """Wrap a linear layer and add a trainable low-rank branch ``U diag(s) V^T``.

    The wrapped layer is evaluated unchanged; the low-rank branch is scaled by
    ``alpha / rank`` and added to its output.
    """

    def __init__(
        self, original_linear, rank, alpha=16, lora_dropout=0.0
    ):
        """Build the low-rank factors around ``original_linear``.

        Args:
            original_linear: layer to wrap; its ``in_features``,
                ``out_features``, ``weight`` and ``bias`` attributes are read.
            rank: number of columns requested for the factors ``us``/``vs``
                and length of the coefficient vector ``s``.
            alpha: numerator of the branch scaling ``alpha / rank``.
            lora_dropout: dropout probability applied to the input of the
                low-rank branch only; ``0.0`` disables dropout.
        """
        super().__init__()

        self.original_linear = original_linear

        self.rank = rank
        self.out_features = original_linear.out_features
        self.in_features = original_linear.in_features
        self.dims = [self.out_features, self.in_features]

        # Scaling factor of the low-rank branch.
        self.alpha = alpha
        self.scaling = self.alpha / self.rank  # probably not needed for dlrt

        # Factors start with orthonormal columns (Q of a reduced QR of a
        # Gaussian matrix): us acts on the input side, vs on the output side.
        self.us = nn.Parameter(
            torch.linalg.qr(
                torch.randn(self.in_features, self.rank), mode="reduced"
            )[0],
            requires_grad=True,
        )
        self.vs = nn.Parameter(
            torch.linalg.qr(
                torch.randn(self.out_features, self.rank), mode="reduced"
            )[0],
            requires_grad=True,
        )

        # Zero coefficients: at construction the branch contributes nothing,
        # so the wrapper reproduces the wrapped layer exactly.
        self.s = nn.Parameter(
            torch.zeros(self.rank),
            requires_grad=True,
        )

        if lora_dropout > 0.0:
            self.lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout_layer = nn.Identity()

        # Expose the wrapped layer's parameters under the usual names.
        self.bias = original_linear.bias
        self.weight = original_linear.weight
        self.to(original_linear.weight.device)

    def forward(self, x):
        """Return ``original_linear(x) + scaling * dropout(x) @ U diag(s) V^T``.

        Args:
            x: input to layer
        Returns:
            output of layer
        """
        base = self.original_linear(x)
        # Multiplying the columns by s is the same as a matmul with diag(s),
        # without building the dense (rank x rank) matrix.
        branch = self.lora_dropout_layer(x) @ self.us[:, : self.rank]
        branch = branch * self.s[: self.rank]
        branch = branch @ self.vs[:, : self.rank].T
        return base + self.scaling * branch

    def activate_upper_level(self):
        """Enable gradients for ``us``, ``vs`` and ``s``."""
        self.us.requires_grad = True
        self.vs.requires_grad = True
        self.s.requires_grad = True

    def activate_lower_level(self):
        """Enable gradients for ``us``/``vs``, freeze ``s`` and drop its gradient."""
        self.us.requires_grad = True
        self.vs.requires_grad = True
        self.s.requires_grad = False
        self.s.grad = None

    @torch.no_grad()
    def get_hypergradient(self):
        """Add ``diag(U^T dU + dV^T V) / s`` to ``s.grad`` in place.

        Entries of ``s`` below ``1e-5`` (including negative ones) contribute
        zero instead of being divided by. ``us.grad``, ``vs.grad`` and
        ``s.grad`` must already be populated when this is called.
        """
        eps = 1e-5
        # Vectorised mask: also valid for rank == 1, where iterating over
        # the squeezed (0-d) tensor is not possible.
        inv_s = torch.where(
            self.s >= eps, 1.0 / self.s, torch.zeros_like(self.s)
        )
        # diag(A^T B) equals the column-wise sum of A * B.
        diagonal = (self.us * self.us.grad).sum(dim=0) + (
            self.vs.grad * self.vs
        ).sum(dim=0)
        self.s.grad.add_(diagonal * inv_s)


class low_rank_CP(nn.Module):
    """Wrap a 2D convolution and add a trainable CP-factorized kernel update.

    The update kernel is stored as a weight vector ``s`` and one factor
    matrix per kernel mode (out channels, in channels, kernel height,
    kernel width), i.e. ``einsum("i,ai,bi,ci,di->abcd", s, *us)``.
    """

    def __init__(
        self, original_linear, rank, alpha=16, lora_dropout=0.0
    ):
        """Build the CP factors around ``original_linear``.

        Args:
            original_linear: convolution to wrap; its ``out_channels``,
                ``in_channels``, ``kernel_size``, ``stride``, ``padding``,
                ``dilation``, ``groups``, ``weight`` and ``bias`` attributes
                are read.
            rank: number of CP components (length of ``s`` and number of
                columns of every factor in ``us``).
            alpha: numerator of the branch scaling ``alpha / rank``.
            lora_dropout: dropout probability applied to the input of the
                CP branch only; ``0.0`` disables dropout.
        """
        super().__init__()

        self.original_linear = original_linear

        self.rank = rank  # min(rank,min(self.dims))
        self.out_channels = original_linear.out_channels
        self.in_channels = original_linear.in_channels
        kernel_size = original_linear.kernel_size
        self.kernel_size = (
            kernel_size
            if isinstance(kernel_size, tuple)
            else (kernel_size, kernel_size)
        )
        self.dims = [
            self.out_channels,
            self.in_channels,
            self.kernel_size[0],
            self.kernel_size[1],
        ]

        # Scaling factor of the CP branch.
        self.alpha = alpha
        self.scaling = self.alpha / self.rank

        # Placeholder values; reset_parameters() below sets the real ones.
        self.s = torch.nn.Parameter(torch.randn(self.rank), requires_grad=True)
        self.us = torch.nn.ParameterList(
            [
                torch.nn.Parameter(
                    torch.randn(size=(d, self.rank)), requires_grad=True
                )
                for d in self.dims
            ]
        )

        if lora_dropout > 0.0:
            self.lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout_layer = nn.Identity()

        self.reset_parameters()

        # Expose the wrapped layer's parameters and conv settings.
        self.bias = original_linear.bias
        self.weight = original_linear.weight
        self.stride = original_linear.stride
        self.padding = original_linear.padding
        self.dilation = original_linear.dilation
        self.groups = original_linear.groups
        self.to(original_linear.weight.device)

    @torch.no_grad()
    def reset_parameters(self):
        """Zero ``s`` and re-initialise every factor with orthonormal columns."""
        # torch.nn.init.uniform_(self.s,a =0.01 , b=10)  ### blo
        # uniform_ over [0, 0] yields zeros (adalora-style start); kept as a
        # uniform_ call so the random-number stream is consumed as before.
        torch.nn.init.uniform_(self.s, a=0.0, b=0.0)
        for u in self.us:
            torch.nn.init.kaiming_uniform_(u, a=math.sqrt(5))
            # Reduced QR yields min(rows, cols) orthonormal columns; when a
            # factor has fewer rows than columns the remaining columns keep
            # their Kaiming values.
            n_cols = min(self.rank, min(u.shape[0], u.shape[1]))
            u[:, :n_cols].copy_(torch.linalg.qr(u, mode="reduced")[0])

    def forward(self, inputs):
        """Return ``original_linear(inputs) + scaling * cp_conv(dropout(inputs))``.

        The CP branch reuses the wrapped layer's stride, padding and dilation
        and adds no bias of its own, since the wrapped layer already applies
        ``bias``.

        Args:
            inputs: input to the wrapped convolution.
        Returns:
            output of the wrapped convolution plus the scaled CP update.
        """
        result = self.original_linear(inputs)
        cp_kernel = tltorch.CPTensor(
            weights=self.s[: self.rank],
            factors=[p[:, : self.rank] for p in self.us],
            rank=self.rank,
        )
        # NOTE(review): self.groups is not forwarded to cp_conv; presumably
        # only groups == 1 is supported here -- TODO confirm.
        update = tltorch.functional.convolution.cp_conv(
            x=self.lora_dropout_layer(inputs),
            cp_tensor=cp_kernel,
            bias=None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
        return result + self.scaling * update

    @torch.no_grad()
    def construct_weight_tensor(self):
        """
        just for debugging purposes, don't use it

        Returns the dense 4D update kernel rebuilt from ``s`` and ``us``
        (without the ``scaling`` factor).
        """
        return torch.einsum("i,ai,bi,ci,di->abcd", self.s, *self.us)

    def activate_upper_level(self):
        """Enable gradients for every factor in ``us`` and for ``s``."""
        for p in self.us:
            p.requires_grad = True
        self.s.requires_grad = True

    def activate_lower_level(self):
        """Enable gradients for the factors, freeze ``s`` and drop its gradient."""
        for p in self.us:
            p.requires_grad = True
        self.s.requires_grad = False  ### False
        self.s.grad = None

    @torch.no_grad()
    def get_hypergradient(self):
        """Add ``diag(sum_k U_k^T dU_k) / s`` to ``s.grad`` in place.

        Each factor is paired with its own gradient. Entries of ``s`` below
        ``1e-5`` (including negative ones) contribute zero instead of being
        divided by. The gradients of all factors and of ``s`` must already be
        populated when this is called.
        """
        eps = 1e-5
        r = self.rank
        # Vectorised mask: also valid for rank == 1, where iterating over
        # the squeezed (0-d) tensor is not possible.
        inv_s = torch.where(
            self.s >= eps, 1.0 / self.s, torch.zeros_like(self.s)
        )
        # diag(A^T B) equals the column-wise sum of A * B.
        hypergradient_diagonal = sum(
            (u[:, :r] * u.grad[:, :r]).sum(dim=0) for u in self.us
        )
        self.s.grad[:r].add_(hypergradient_diagonal * inv_s[:r])


def blo_layer(original_linear, rank, alpha=16, lora_dropout=0.0):
    """Wrap a layer in the matching low-rank adapter.

    Args:
        original_linear: layer to wrap.
        rank: rank passed to the adapter.
        alpha: scaling numerator passed to the adapter.
        lora_dropout: dropout probability passed to the adapter.
    Returns:
        a ``low_rank_linear`` for ``torch.nn.Linear`` inputs, a
        ``low_rank_CP`` for ``torch.nn.Conv2d`` inputs, and ``None`` for
        any other type.
    """
    adapter_kwargs = dict(
        original_linear=original_linear,
        rank=rank,
        alpha=alpha,
        lora_dropout=lora_dropout,
    )
    if isinstance(original_linear, torch.nn.Linear):
        return low_rank_linear(**adapter_kwargs)
    if isinstance(original_linear, torch.nn.Conv2d):
        return low_rank_CP(**adapter_kwargs)
    # Unsupported layer types are not wrapped.
    return None
