import torch
import torch.nn as nn


class AugBUGLayer(nn.Module):
    """Rank-adaptive low-rank weight ``W = U S V^T`` with an optional bias.

    The factors are stored at their maximal size (``rmax = 2 * rank``
    columns); only the leading ``self.r`` columns/rows are "active" and used
    by :meth:`forward`.  Training proceeds by calling :meth:`step` with
    ``"basis"`` (augment the bases with gradient directions),
    ``"coefficients"`` (gradient step on ``S`` and the bias) and
    ``"truncate"`` (SVD-based rank truncation).

    Attributes:
        r: current active rank.
        tol: relative truncation tolerance used by :meth:`truncate`.
        rmax: upper bound on the active rank (``2 * rank``).
        rmin: lower bound on the rank after truncation (fixed to 2).
        U, V: basis factors; orthonormalised by a reduced QR at construction.
        S: coefficient matrix.
        bias: bias vector of length ``output_size`` or ``None``.
    """

    def __init__(
        self,
        U: torch.Tensor,
        S: torch.Tensor,
        V: torch.Tensor,
        output_size: int,
        rank: int,
        tau: float,
        bias: bool = False,
    ) -> None:
        """Build the layer from existing factors.

        Args:
            U: left factor; copied and orthonormalised with a reduced QR.
            S: coefficient matrix; copied so that the in-place updates done
                by :meth:`step` never write into the caller's tensor.
            V: right factor; copied and orthonormalised with a reduced QR.
            output_size: length of the bias vector (only used if ``bias``).
            rank: initial active rank; the maximal rank is ``2 * rank``.
            tau: relative tolerance for rank truncation.
            bias: if true, a randomly initialised trainable bias is created.

        Note:
            The "basis" step writes into the first ``min(rmax, 2 * r)``
            columns of ``U``/``V`` and the matching block of ``S``, so the
            supplied factors are expected to be at least that large
            (not validated here).
        """
        super().__init__()
        self.r = rank
        self.tol = tau
        self.rmax = rank * 2
        self.rmin = 2

        # Orthonormalise private copies of the bases; the QR is evaluated on
        # detached data so no autograd history leaks into the parameters.
        u_orth, _ = torch.linalg.qr(U.detach().clone(), "reduced")
        v_orth, _ = torch.linalg.qr(V.detach().clone(), "reduced")
        self.U = nn.Parameter(u_orth)
        self.V = nn.Parameter(v_orth)

        # Coefficient matrix: copy it, consistent with U and V.
        self.S = nn.Parameter(S.detach().clone())

        if bias:
            self.bias = nn.Parameter(torch.randn(output_size))
        else:
            self.bias = None

    @classmethod
    def from_dimensions(
        cls,
        input_size: int,
        output_size: int,
        rank: int,
        tau: float,
        bias: bool = False,
    ) -> "AugBUGLayer":
        """Build a randomly initialised layer from its dimensions.

        The file previously declared this as a second ``__init__``; Python
        keeps only the last definition of a name, so it was never reachable.
        It is exposed as an alternate constructor instead.

        Args:
            input_size: number of rows of ``U``.
            output_size: number of rows of ``V`` and length of the bias.
            rank: initial active rank; factors get ``2 * rank`` columns.
            tau: relative tolerance for rank truncation.
            bias: if true, a randomly initialised trainable bias is created.

        Returns:
            A new layer with random bases and a diagonal ``S`` holding
            squared normal samples sorted in descending order.
        """
        rmax = rank * 2
        U = torch.randn(input_size, rmax)
        V = torch.randn(output_size, rmax)
        singular_values, _ = torch.sort(torch.randn(rmax) ** 2, descending=True)
        S = torch.diag(singular_values)
        # __init__ performs the QR orthonormalisation of U and V.
        return cls(U, S, V, output_size, rank, tau, bias=bias)

    def forward(self) -> torch.Tensor:
        """Assemble the active low-rank matrix.

        Takes no input: the result is ``U[:, :r] @ S[:r, :r] @ V[:, :r].T``
        and, if a bias exists, the bias is added (broadcast over the last
        dimension).

        Returns:
            The assembled matrix with ``U.shape[0]`` rows and ``V.shape[0]``
            columns.
        """
        r = self.r
        out = self.U[:, :r] @ self.S[:r, :r] @ self.V[:, :r].T
        if self.bias is not None:
            out = out + self.bias
        return out

    @torch.no_grad()
    def step(
        self,
        learning_rate: float,
        dlrt_step: str = "basis",
        momentum: float = 0.0,
        weight_decay: float = 0.0,
    ) -> None:
        """Perform one sub-step of the training update.

        Args:
            learning_rate: step size for the "coefficients" step (also
                forwarded to :meth:`truncate`).
            dlrt_step: which sub-step to run:
                ``"basis"`` augments ``U``/``V`` with their negative
                gradients, re-orthonormalises, and projects ``S`` onto the
                new bases (requires ``U.grad`` and ``V.grad``);
                ``"coefficients"`` takes a plain gradient step on the active
                block of ``S`` and on the bias (requires their ``.grad``);
                ``"truncate"`` calls :meth:`truncate`.
                Any other value only prints a message.
            momentum: accepted for interface compatibility; currently unused.
            weight_decay: accepted for interface compatibility; currently
                unused.
        """
        if dlrt_step == "basis":
            r = self.r
            r1 = min(self.rmax, 2 * r)

            # Augmented bases: span{old basis, negative gradient}.
            U1, _ = torch.linalg.qr(
                torch.cat((self.U[:, :r], -self.U.grad[:, :r]), 1),
                "reduced",
            )
            V1, _ = torch.linalg.qr(
                torch.cat((self.V[:, :r], -self.V.grad[:, :r]), 1),
                "reduced",
            )

            # Project the old coefficients onto the new bases *before* the
            # old bases are overwritten below.
            M = U1[:, :r1].T @ self.U[:, :r]
            N = self.V[:, :r].T @ V1[:, :r1]
            self.S.data[:r1, :r1] = M @ self.S[:r, :r] @ N

            # Install the augmented bases.
            self.U.data[:, :r1] = U1[:, :r1]
            self.V.data[:, :r1] = V1[:, :r1]

            self.r = r1

        elif dlrt_step == "coefficients":
            r = self.r
            self.S.data[:r, :r] = (
                self.S[:r, :r] - learning_rate * self.S.grad[:r, :r]
            )
            # `if self.bias:` would evaluate the truth value of a tensor,
            # which raises for more than one element; test for presence.
            if self.bias is not None:
                self.bias.data = self.bias - learning_rate * self.bias.grad
        elif dlrt_step == "truncate":
            self.truncate(learning_rate)
        else:
            print("Wrong step defined: ", dlrt_step)

    @torch.no_grad()
    def truncate(self, learning_rate: float) -> None:
        """Truncate the factorisation to a new rank via an SVD of ``S``.

        The new rank is the smallest ``j`` such that the norm of the
        discarded singular values ``d[j:]`` is below ``tol * ||d||``; it is
        then clamped to ``[rmin, rmax]`` and never exceeds the current rank.

        Args:
            learning_rate: unused; kept for interface compatibility.
        """
        r = self.r
        active = self.S[:r, :r]
        try:
            P, d, Q = torch.linalg.svd(active)
        except RuntimeError:
            # torch's linalg failures derive from RuntimeError; retry on a
            # slightly perturbed matrix if the SVD did not converge.
            P, d, Q = torch.linalg.svd(active + 1e-8)

        tol = self.tol * torch.linalg.norm(d)
        r1 = r
        for j in range(r):
            # Norm of the singular values that would be discarded at rank j.
            if torch.linalg.norm(d[j:r]) < tol:
                r1 = j
                break

        # Keep the new rank within the legal bounds ...
        r1 = min(r1, self.rmax)
        r1 = max(r1, self.rmin)
        # ... but never above the number of available singular values;
        # otherwise the slice assignments below would silently broadcast.
        r1 = min(r1, r)

        # New coefficients are the retained singular values.
        self.S.data[:r1, :r1] = torch.diag(d[:r1])

        # Rotate the bases.  torch.linalg.svd returns (P, d, Vh) with
        # S = P @ diag(d) @ Vh, hence
        # U S V^T = (U P) diag(d) (V Vh^T)^T, i.e. V is multiplied by Q.T.
        self.U.data[:, :r1] = torch.matmul(self.U[:, :r], P[:, :r1])
        self.V.data[:, :r1] = torch.matmul(self.V[:, :r], Q.T[:, :r1])
        self.r = int(r1)
