import torch
import torch.nn as nn


class AugBUGDiagSLayer(nn.Module):
    """Low-rank factorized layer ``W = U[:, :r] @ diag(S[:r]) @ V[:, :r].T``.

    Only the leading ``r`` columns of the bases ``U`` and ``V`` and the leading
    ``r`` entries of the coefficient vector ``S`` are active. Training is driven
    externally through :meth:`step` with one of three sub-steps: ``"basis"``
    (augment the bases with gradient directions), ``"coefficients"`` (gradient
    step on ``S`` and the bias) and ``"truncate"`` (shrink the active rank).
    """

    def __init__(self, U, S, V, output_size, rank, tau, bias=False):
        """Build the layer from given factors.

        Args:
            U: basis matrix whose rows index the first dimension of the weight;
                its columns are re-orthonormalized with a reduced QR (the R
                factor is discarded).
            S: square coefficient matrix; only its diagonal is kept. The diagonal
                is copied, so training never modifies the caller's tensor.
            V: basis matrix whose rows index the second dimension of the weight;
                re-orthonormalized like ``U``. Must have at least ``rank``
                columns, as must ``U``.
            output_size: length of the (optional) bias vector.
            rank: initial active rank ``r``; the maximal rank is ``2 * rank``.
            tau: relative tolerance used by :meth:`truncate`.
            bias: if True, add a trainable bias initialized with ``torch.randn``.
        """
        super().__init__()
        self.r = rank
        self.tol = tau
        self.rmax = rank * 2
        self.rmin = 2

        # Orthonormalize the bases before registering them as parameters.
        q_u, _ = torch.linalg.qr(U.detach(), mode="reduced")
        q_v, _ = torch.linalg.qr(V.detach(), mode="reduced")
        self.U = nn.Parameter(q_u)
        self.V = nn.Parameter(q_v)

        # torch.diagonal returns a view into S; clone it so the in-place updates
        # performed by `step` do not write through into the caller's matrix.
        self.S = nn.Parameter(torch.diagonal(S).detach().clone())

        if bias:
            self.bias = nn.Parameter(torch.randn(output_size))
        else:
            self.bias = None

    @classmethod
    def from_random(cls, input_size, output_size, rank, tau, bias=False):
        """Build a layer with random orthonormal bases and random coefficients.

        Args:
            input_size: number of rows of ``U``.
            output_size: number of rows of ``V`` and length of the bias.
            rank: initial active rank; ``2 * rank`` basis columns are allocated.
            tau: relative truncation tolerance.
            bias: if True, add a randomly initialized trainable bias.
        Returns:
            A new ``AugBUGDiagSLayer``.
        """
        rmax = 2 * rank
        U = torch.randn(input_size, rmax)
        V = torch.randn(output_size, rmax)
        # Non-negative coefficients in decreasing order, so that truncation
        # discards the smallest ones first.
        singular_values, _ = torch.sort(torch.randn(rmax) ** 2, descending=True)
        return cls(
            U, torch.diag(singular_values), V, output_size, rank, tau, bias=bias
        )

    def forward(self):
        """Return the current weight matrix, plus the bias if one exists.

        The layer takes no input: it materializes
        ``U[:, :r] @ diag(S[:r]) @ V[:, :r].T`` and, when a bias is present,
        adds it with broadcasting over the last dimension.
        """
        r = self.r
        out = (self.U[:, :r] @ torch.diag(self.S[:r])) @ self.V[:, :r].T
        if self.bias is not None:
            out = out + self.bias
        return out

    @staticmethod
    def _augmented_basis(basis, grad, r):
        """Return an orthonormal basis of span([basis[:, :r], -grad[:, :r]]).

        For an orthonormal ``basis``, the first ``r`` returned columns coincide
        with ``basis[:, :r]``.
        """
        q, rfac = torch.linalg.qr(
            torch.cat((basis[:, :r], -grad[:, :r]), 1), mode="reduced"
        )
        # QR is unique only up to column signs, and the Householder-based
        # factorization can return negative entries on diag(R), flipping the
        # matching columns of Q. S is not re-projected onto the new basis, so
        # normalize the signs to keep q[:, :r] == basis[:, :r] and leave the
        # represented weight matrix unchanged.
        d = torch.diagonal(rfac)
        signs = torch.where(d < 0, -torch.ones_like(d), torch.ones_like(d))
        return q * signs

    @torch.no_grad()
    def step(
        self, learning_rate, dlrt_step="basis", momentum=0.0, weight_decay=0.0
    ) -> None:
        """Perform one sub-step of the training update in place.

        Args:
            learning_rate: step size for the ``"coefficients"`` update.
            dlrt_step: ``"basis"``, ``"coefficients"`` or ``"truncate"``; any
                other value only prints a message.
            momentum: currently unused (kept for interface compatibility).
            weight_decay: currently unused (kept for interface compatibility).

        The ``"basis"`` and ``"coefficients"`` steps read ``.grad`` of the
        parameters, so a backward pass must have been run beforehand.
        """
        if dlrt_step == "basis":
            U1 = self._augmented_basis(self.U, self.U.grad, self.r)
            V1 = self._augmented_basis(self.V, self.V.grad, self.r)
            # Augmented rank: 2r, capped by rmax and by the number of columns
            # QR could produce (fewer than 2r when a dimension is small).
            r1 = min(self.rmax, 2 * self.r, U1.shape[1], V1.shape[1])

            # The first r columns of U1/V1 equal the old bases, so the existing
            # coefficients stay valid; new directions start with zero weight.
            self.S.data[self.r : r1] = 0.0

            self.U.data[:, :r1] = U1[:, :r1]
            self.V.data[:, :r1] = V1[:, :r1]
            self.r = r1

        elif dlrt_step == "coefficients":
            r = self.r
            self.S.data[:r] = self.S[:r] - learning_rate * self.S.grad[:r]
            # `bias` is a tensor or None; truth-testing a multi-element tensor
            # raises, so compare against None explicitly (as `forward` does).
            if self.bias is not None:
                self.bias.data = self.bias - learning_rate * self.bias.grad

        elif dlrt_step == "truncate":
            self.truncate(learning_rate)

        else:
            print("Wrong step defined: ", dlrt_step)

    @torch.no_grad()
    def truncate(self, learning_rate) -> None:
        """Shrink the active rank using the relative tolerance ``self.tol``.

        The new rank is the first index ``j`` whose tail ``S[j:r]`` has norm
        below ``tol * ||S[:r]||``, clamped to ``[rmin, rmax]``.

        Args:
            learning_rate: unused; accepted so ``step`` can forward it.

        NOTE(review): only the tail of ``S`` is dropped, so this assumes the
        active coefficients are ordered by decreasing magnitude; nothing in this
        class re-sorts ``S`` after a coefficient step.
        """
        tol = self.tol * torch.linalg.norm(self.S[: self.r])
        r1 = self.r
        for j in range(self.r):
            if torch.linalg.norm(self.S[j : self.r]) < tol:
                r1 = j
                break

        # Keep the new rank within the legal bounds.
        r1 = max(min(r1, self.rmax), self.rmin)
        self.r = int(r1)
