import torch
import torch.nn as nn
import math


class LoRAParallelDLRT(nn.Module):
    """LoRA adapter computing ``original_linear(x) + scaling * (x @ U @ S @ V^T)``.

    ``U`` (in_features x rmax) and ``V`` (out_features x rmax) are initialised
    with orthonormal columns and ``S`` is an ``rmax x rmax`` coefficient
    matrix. Only the leading ``r`` columns of ``U``/``V`` and the leading
    ``r x r`` block of ``S`` are active. ``step`` updates the factors from
    their gradients and then ``Truncate`` re-selects ``r`` (never above
    ``rmax``, and at least ``rmin`` when the current rank allows it).
    """

    def __init__(self, original_linear, rank, tau, alpha=16, max_rank=32):
        """Construct a low-rank adapter around an existing layer.

        Args:
            original_linear: wrapped layer; must expose ``in_features`` and
                ``out_features`` (e.g. ``nn.Linear``).
            rank: initial active rank of the factorization.
            tau: relative truncation tolerance used by ``Truncate``.
            alpha: LoRA scaling numerator. ``scaling = alpha / rank`` is fixed
                here and does not follow later rank changes.
            max_rank: upper bound on the rank. It is clamped to
                ``min(in_features, out_features)`` because an orthonormal
                factor cannot have more columns than rows.
        """
        super().__init__()

        self.original_linear = original_linear
        in_features = original_linear.in_features
        out_features = original_linear.out_features

        # Without this cap the reduced QRs below would silently return fewer
        # than rmax columns, so U, V and S would disagree in shape once the
        # rank grows past min(in_features, out_features).
        self.rmax = min(max_rank, in_features, out_features)
        self.rmin = 2
        self.r = min(rank, self.rmax)
        self.tol = tau

        # Scaling factor, fixed from the requested initial rank.
        self.alpha = alpha
        self.scaling = self.alpha / rank

        self.lora_U = nn.Parameter(
            torch.linalg.qr(torch.randn(in_features, self.rmax), "reduced")[0],
            requires_grad=True,
        )
        self.lora_V = nn.Parameter(
            torch.linalg.qr(torch.randn(out_features, self.rmax), "reduced")[0],
            requires_grad=True,
        )
        # Zero coefficients: the adapter starts as an exact no-op.
        self.lora_S = nn.Parameter(
            torch.zeros(self.rmax, self.rmax),
            requires_grad=True,
        )
        # Inverse of the diagonal S written by ``Truncate``; used by ``step`` to
        # rescale factor gradients. Kept as a frozen Parameter (original note:
        # for multi-GPU). Identity initialisation because S starts at zero.
        self.Sinv = nn.Parameter(torch.eye(self.rmax), requires_grad=False)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank term.

        The formula is
        ``original_linear(x) + scaling * (x @ U[:, :r] @ S[:r, :r] @ V[:, :r].T)``.

        Args:
            x: input whose last dimension is ``in_features``.
        Returns:
            Output of the layer; its last dimension is ``out_features``.
        """
        r = self.r
        low_rank = (x @ self.lora_U[:, :r]) @ self.lora_S[:r, :r]
        return self.original_linear(x) + self.scaling * (
            low_rank @ self.lora_V[:, :r].T
        )

    @torch.no_grad()
    def step(self, learning_rate, weight_decay):
        """Perform one gradient update of U, S and V, then truncate the rank.

        Must be called after ``backward()``, because it reads the gradients of
        ``lora_U``, ``lora_V`` and ``lora_S``. The bases are augmented with up
        to ``r`` new directions each (total capped at ``rmax``), and the
        augmented coefficient matrix is assembled. The three gradients are
        then zeroed and ``Truncate`` is called.

        Args:
            learning_rate: step size of the gradient update.
            weight_decay: accepted for API compatibility; currently unused.
        """
        r = self.r
        r1 = min(2 * r, self.rmax)

        # Aliases to the parameters themselves (not copies).
        U = self.lora_U
        V = self.lora_V
        S = self.lora_S
        Sinv_r = self.Sinv[:r, :r]

        # K-step. Sinv holds diag(1/d) from the last Truncate (identity before
        # the first one). Rescaling dL/dU by it removes the S factor, giving the
        # gradient with respect to K = U S.
        U.grad[:, :r] = U.grad[:, :r] @ Sinv_r
        U1, _ = torch.linalg.qr(torch.cat((U[:, :r], U.grad[:, :r]), 1), "reduced")

        # L-step: same rescaling for the gradient with respect to L = V S^T.
        V.grad[:, :r] = V.grad[:, :r] @ Sinv_r
        V1, _ = torch.linalg.qr(torch.cat((V[:, :r], V.grad[:, :r]), 1), "reduced")

        # U[:, :r] and V[:, :r] are orthonormal, so the trailing QR columns
        # are new directions orthogonal to the current bases. rmax <=
        # min(in_features, out_features), so both slices have r1 - r columns.
        U_new = U1[:, r:r1]
        V_new = V1[:, r:r1]

        # Off-diagonal blocks of the augmented S. They must use the OLD S, so
        # they are written before the S gradient step below.
        S.data[r:r1, :r] = U_new.T @ (
            U[:, :r] @ S[:r, :r] - learning_rate * U.grad[:, :r]
        )
        S.data[:r, r:r1] = (
            V[:, :r] @ S[:r, :r].T - learning_rate * V.grad[:, :r]
        ).T @ V_new

        S.data[:r, :r] = S[:r, :r] - learning_rate * S.grad[:r, :r]
        # zero_() rather than "*= 0" so stale NaN/inf entries are cleared too.
        S.data[r:r1, r:r1].zero_()

        U.data[:, r:r1] = U_new
        V.data[:, r:r1] = V_new
        self.r = r1

        S.grad.zero_()
        U.grad.zero_()
        V.grad.zero_()

        self.Truncate()

    @torch.no_grad()
    def Truncate(self):
        """Re-diagonalise the active block of S and drop negligible directions.

        Computes ``S[:r, :r] = P diag(d) Q`` and keeps the smallest rank ``r1``
        whose discarded tail satisfies ``||d[r1:]|| < tol * ||d||``. The result
        is bounded above by ``rmax`` and below by ``min(rmin, r)``. U and V are
        rotated into the singular bases, ``diag(d[:r1])`` is stored in S, and
        its pseudo-inverse is stored in Sinv.
        """
        r = self.r
        P, d, Q = torch.linalg.svd(self.lora_S[:r, :r])

        tol = self.tol * torch.linalg.norm(d)
        r1 = r
        for j in range(r):
            if torch.linalg.norm(d[j:r]) < tol:
                r1 = j
                break

        r1 = min(r1, self.rmax)
        # Respect rmin, but never keep more singular values than exist.
        r1 = max(r1, min(self.rmin, r))

        d_kept = d[:r1]
        self.lora_S.data[:r1, :r1] = torch.diag(d_kept)
        # Pseudo-inverse: the rmin floor can keep an exactly-zero singular
        # value. Mapping it to 0 (instead of inf) stops the next step's
        # gradient rescaling from producing NaN.
        inv = torch.where(d_kept == 0, torch.zeros_like(d_kept), 1.0 / d_kept)
        self.Sinv[:r1, :r1] = torch.diag(inv)

        # S = P diag(d) Q, hence U S V^T = (U P) diag(d) (V Q^T)^T.
        self.lora_U.data[:, :r1] = self.lora_U[:, :r] @ P[:, :r1]
        self.lora_V.data[:, :r1] = self.lora_V[:, :r] @ Q.T[:, :r1]
        self.r = int(r1)