import torch
import torch.nn as nn


class AdaloraLayer(nn.Module):
    """Low-rank weight factorisation ``W = U @ S @ VT`` with adaptive rank.

    The layer stores three trainable factors ``U`` (rows x r), ``S`` (r x r)
    and ``VT`` (r x columns) plus an optional bias vector. ``forward`` returns
    the assembled matrix, ``step`` performs a plain gradient-descent update
    and ``Truncate`` re-compresses ``S`` with an SVD, possibly lowering the
    active rank ``self.r``.
    """

    # Lower bound on the rank kept by ``Truncate``; it is never allowed to
    # exceed the rank that is currently stored in the factors.
    MIN_RANK = 5

    def __init__(self, U, S, VT, output_size, rank, tau, bias=False):
        """Build the layer from existing factors.

        Args:
            U: tensor whose first ``rank`` columns initialise ``self.U``.
            S: tensor whose leading ``rank x rank`` block initialises ``self.S``.
            VT: tensor whose first ``rank`` rows initialise ``self.VT``.
            output_size: length of the bias vector (used only if ``bias``).
            rank: requested active rank of the factorisation.
            tau: relative tolerance used by ``Truncate``.
            bias: if true, a randomly initialised trainable bias is created.
        """
        super().__init__()
        # Never let the active rank exceed what the supplied factors hold;
        # otherwise later slices silently clamp while ``self.r`` does not.
        self.r = int(min(rank, U.shape[1], S.shape[0], S.shape[1], VT.shape[0]))
        self.tol = tau
        self.rmax = rank * 2

        # Factors are copied so the caller's tensors are left untouched.
        self.U = nn.Parameter(U.detach().clone()[:, : self.r])
        self.S = nn.Parameter(S.detach().clone()[: self.r, : self.r])
        self.VT = nn.Parameter(VT.detach().clone()[: self.r, :])

        if bias:
            self.bias = nn.Parameter(torch.randn(output_size))
        else:
            self.bias = None

    @classmethod
    def from_sizes(cls, input_size, output_size, rank, tau, bias=False):
        """Alternate constructor with randomly initialised factors.

        Args:
            input_size: number of rows of ``U``.
            output_size: number of columns of ``VT`` and length of the bias.
            rank: rank of the factorisation.
            tau: relative tolerance used by ``Truncate``.
            bias: if true, a randomly initialised trainable bias is created.

        Returns:
            A new ``AdaloraLayer`` instance.
        """
        U = torch.randn(input_size, rank) / input_size
        S = torch.randn(rank, rank) / input_size
        VT = torch.randn(rank, output_size) / output_size
        return cls(U, S, VT, output_size, rank, tau, bias=bias)

    def forward(self):
        """Assemble the low-rank matrix from the active part of the factors.

        Returns:
            ``U[:, :r] @ S[:r, :r] @ VT[:r, :]``; when a bias exists it is
            added to the result (broadcast by the usual tensor rules).
        """
        r = self.r
        out = self.U[:, :r] @ self.S[:r, :r] @ self.VT[:r, :]
        # ``is not None`` instead of truthiness: bool() of a multi-element
        # tensor raises a RuntimeError.
        if self.bias is not None:
            out = out + self.bias
        return out

    def ortho_regularization(self):
        """Return ``0.5 * (||U^T U - I||^2 + ||VT VT^T - I||^2)`` on the active rank."""
        r = self.r
        U = self.U[:, :r]
        VT = self.VT[:r, :]
        # Build the identity on the parameters' device/dtype so the penalty
        # also works after the module has been moved or cast.
        eye = torch.eye(r, device=U.device, dtype=U.dtype)
        return 0.5 * (
            torch.linalg.norm(U.T @ U - eye) ** 2
            + torch.linalg.norm(VT @ VT.T - eye) ** 2
        )

    @torch.no_grad()
    def step(self, learning_rate):
        """Perform one steepest-descent update, reset gradients and truncate.

        Args:
            learning_rate: step size applied to every gradient.
        """
        r = self.r
        self.U.data[:, :r] = self.U[:, :r] - learning_rate * self.U.grad[:, :r]
        self.S.data[:r, :r] = self.S[:r, :r] - learning_rate * self.S.grad[:r, :r]
        self.VT.data[:r, :] = self.VT[:r, :] - learning_rate * self.VT.grad[:r, :]

        # The bias may not have taken part in the last backward pass.
        if self.bias is not None and self.bias.grad is not None:
            self.bias.data = self.bias - learning_rate * self.bias.grad
            # Reset like the other gradients; otherwise it keeps accumulating.
            self.bias.grad.zero_()

        self.U.grad.zero_()
        self.S.grad.zero_()
        self.VT.grad.zero_()
        self.Truncate()

    @torch.no_grad()
    def Truncate(self):
        """Re-compress ``S`` via SVD and update the active rank.

        The new rank is the first index ``j`` for which the norm of the
        trailing singular values ``d[j:]`` falls below ``tol * ||d||``; it is
        raised to ``MIN_RANK`` and then capped by ``rmax`` and by the current
        rank, because the stored factors hold no more columns than that.
        """
        r = self.r
        P, d, Q = torch.linalg.svd(self.S[:r, :r])

        tol = self.tol * torch.linalg.norm(d)
        r1 = r
        for j in range(r):
            if torch.linalg.norm(d[j:r]) < tol:
                r1 = j
                break

        r1 = max(r1, self.MIN_RANK)
        r1 = min(r1, self.rmax, r)

        # S becomes diagonal; the rotations P and Q are absorbed into U and VT
        # so that U @ S @ VT is unchanged up to the discarded singular values.
        self.S.data[:r1, :r1] = torch.diag(d[:r1])
        self.U.data[:, :r1] = self.U[:, :r] @ P[:, :r1]
        self.VT.data[:r1, :] = Q[:r1, :] @ self.VT[:r, :]
        self.r = int(r1)
