import torch
import torch.nn as nn


class ParallelLowRankLayer(nn.Module):
    """Rank-adaptive low-rank weight stored as the factors U, S and VT.

    The represented matrix is ``U[:, :r] @ S[:r, :r] @ VT[:r, :]`` where ``r``
    is the current rank. All factors are allocated with ``rmax`` columns/rows
    so the rank can grow and shrink in place during training.
    """

    def __init__(self, input_size, output_size, rank, tau, bias=False):
        """Constructs a low-rank layer with weight W = U*S*VT and optional bias b.

        Args:
            input_size: number of rows of the weight W
            output_size: number of columns of the weight W, dimension of bias b
            rank: initial rank of the factorized weight; the factors are
                allocated with ``2 * rank`` columns/rows (``rmax``)
            tau: relative truncation tolerance used by ``Truncate``
            bias: if True, a trainable bias vector of size ``output_size``
                is created; otherwise ``self.bias`` is None
        """
        super().__init__()

        # rank bookkeeping and truncation tolerance
        self.r = rank
        self.tol = tau
        self.rmax = rank * 2
        self.rmin = 2

        # Factor initialisation: random matrices, then orthonormalised by QR.
        # The QR is run on ``.data`` so no autograd graph is recorded.
        self.U = nn.Parameter(torch.randn(input_size, self.rmax))
        self.VT = nn.Parameter(torch.randn(self.rmax, output_size))
        self.U.data, _ = torch.linalg.qr(self.U.data, mode="reduced")
        self.VT.data = torch.linalg.qr(self.VT.data.T, mode="reduced")[0].T

        # Random non-negative singular values, sorted in descending order.
        self.singular_values, _ = torch.sort(
            torch.randn(self.rmax) ** 2, descending=True
        )
        self.S = nn.Parameter(torch.diag(self.singular_values))
        # Inverse of S, kept as a non-trainable parameter; used in ``step``
        # to rescale the gradients of U and VT.
        self.Sinv = nn.Parameter(
            torch.diag(1 / self.singular_values), requires_grad=False
        )

        # initialize bias
        if bias:
            self.bias = nn.Parameter(torch.randn(output_size))
        else:
            self.bias = None

    def forward(self):
        """Assembles the current rank-r weight matrix.

        Takes no input; the result is the matrix itself, not its action on
        a vector.

        Returns:
            ``U[:, :r] @ S[:r, :r] @ VT[:r, :]`` of shape
            (input_size, output_size), with the bias added (broadcast over
            rows) when the layer has one.
        """
        r = self.r
        out = self.U[:, :r] @ self.S[:r, :r] @ self.VT[:r, :]
        # ``is not None`` is required: truth-testing a multi-element tensor raises.
        if self.bias is not None:
            out = out + self.bias
        return out

    @torch.no_grad()
    def step(self, learning_rate):
        """Performs one rank-augmenting descent update followed by truncation.

        Expects ``.grad`` to be populated on U, S, VT (and the bias, if any).
        The gradients of U and VT are overwritten in place with their
        Sinv-rescaled versions.

        Args:
            learning_rate: learning rate for training
        """
        r = self.r
        # augmented rank, limited by the allocated factor size
        r1 = min(self.rmax, 2 * r)

        U, VT, S, Sinv = self.U, self.VT, self.S, self.Sinv

        # K-step: rescale the U-gradient by S^{-1} and build an orthonormal
        # basis of [U, grad_U]
        U.grad[:, :r] = U.grad[:, :r] @ Sinv[:r, :r]
        U1, _ = torch.linalg.qr(
            torch.cat((U[:, :r], U.grad[:, :r]), 1), mode="reduced"
        )

        # L-step: same for VT, with the rescaling applied from the left
        VT.grad[:r, :] = Sinv[:r, :r].T @ VT.grad[:r, :]
        V1, _ = torch.linalg.qr(
            torch.cat((VT[:r, :].T, VT.grad[:r, :].T), 1), mode="reduced"
        )

        # Augmented S matrix. The off-diagonal blocks need the OLD S[:r, :r],
        # so they are computed before the upper-left block is updated.
        # 1) SK (lower left block)
        S.data[r:r1, :r] = U1[:, r:r1].T @ (
            U[:, :r] @ S[:r, :r] - learning_rate * U.grad[:, :r]
        )
        # 2) SL (upper right block)
        S.data[:r, r:r1] = (
            S[:r, :r] @ VT[:r, :] - learning_rate * VT.grad[:r, :]
        ) @ V1[:, r:r1]
        # 3) plain gradient step on the upper left block
        S.data[:r, :r] = S[:r, :r] - learning_rate * S.grad[:r, :r]
        # 4) lower right block is zero
        S.data[r:r1, r:r1].zero_()

        # append the new basis directions to U and VT
        U.data[:, r:r1] = U1[:, r:r1]
        VT.data[r:r1, :] = V1[:, r:r1].T
        self.r = r1

        # update bias
        if self.bias is not None:
            self.bias.data = self.bias - learning_rate * self.bias.grad

        self.Truncate()

    @torch.no_grad()
    def BiasStep(self, learning_rate):
        """Performs a steepest descent training update on the bias.

        Does nothing when the layer was built without a bias.

        Args:
            learning_rate: learning rate for training
        """
        if self.bias is None:
            return
        self.bias.data = self.bias - learning_rate * self.bias.grad

    @torch.no_grad()
    def Truncate(self):
        """Truncates the factorization to a new rank via an SVD of S.

        The new rank is the smallest ``j`` for which the norm of the
        discarded singular values ``d[j:]`` falls below ``tol * ||d||``.
        It is then clamped to ``[rmin, rmax]`` and never exceeds the
        current rank. U, S, VT, Sinv and ``self.r`` are updated in place.
        """
        r = self.r
        # S[:r, :r] = P @ diag(d) @ Q  (torch returns Q already transposed)
        P, d, Q = torch.linalg.svd(self.S[:r, :r])

        tol = self.tol * torch.linalg.norm(d)
        r1 = r
        for j in range(r):
            if torch.linalg.norm(d[j:r]) < tol:
                r1 = j
                break

        # respect the configured bounds, but never exceed what the SVD gives
        r1 = min(max(r1, self.rmin), self.rmax, r)

        # update s
        self.S.data[:r1, :r1] = torch.diag(d[:r1])
        self.Sinv[:r1, :r1] = torch.diag(1.0 / d[:r1])

        # update u and v: rotate the bases into the singular directions
        self.U.data[:, :r1] = self.U[:, :r] @ P[:, :r1]
        self.VT.data[:r1, :] = Q[:r1, :] @ self.VT[:r, :]
        self.r = int(r1)


class ParallelLowRankLayerTranspose(nn.Module):
    """Rank-adaptive low-rank weight stored as the factors U, S and V.

    The represented matrix is ``U[:, :r] @ S[:r, :r] @ V[:, :r].T`` where
    ``r`` is the current rank. All factors are allocated with ``rmax``
    columns so the rank can grow and shrink in place during training.
    """

    def __init__(self, input_size, output_size, rank, tau, bias=False):
        """Constructs a low-rank layer with weight W = U*S*V' and optional bias b.

        Args:
            input_size: number of rows of the weight W
            output_size: number of columns of the weight W, dimension of bias b
            rank: initial rank of the factorized weight
            tau: relative truncation tolerance used by ``Truncate``
            bias: if True, a trainable bias vector of size ``output_size``
                is created; otherwise ``self.bias`` is None
        """
        super().__init__()

        # rank bookkeeping and truncation tolerance
        self.r = rank
        self.tol = tau
        self.rmax = int(min(input_size, output_size) / 2)
        self.rmin = 2

        # Factor initialisation: random matrices, then orthonormalised by QR.
        # The QR is run on ``.data`` so no autograd graph is recorded.
        self.U = nn.Parameter(torch.randn(input_size, self.rmax))
        self.V = nn.Parameter(torch.randn(output_size, self.rmax))
        self.U.data, _ = torch.linalg.qr(self.U.data, mode="reduced")
        self.V.data, _ = torch.linalg.qr(self.V.data, mode="reduced")

        # Random non-negative singular values, sorted in descending order.
        self.singular_values, _ = torch.sort(
            torch.randn(self.rmax) ** 2, descending=True
        )
        self.S = nn.Parameter(torch.diag(self.singular_values))
        # Inverse of S, used in ``step`` to rescale the gradients of U and V.
        # Registered as a non-persistent buffer so it follows the module
        # across devices/dtypes without adding a key to the state_dict.
        self.register_buffer(
            "Sinv", torch.diag(1 / self.singular_values), persistent=False
        )

        # initialize bias
        if bias:
            self.bias = nn.Parameter(torch.randn(output_size))
        else:
            self.bias = None

    def forward(self):
        """Assembles the current rank-r weight matrix.

        Takes no input; the result is the matrix itself, not its action on
        a vector.

        Returns:
            ``U[:, :r] @ S[:r, :r] @ V[:, :r].T`` of shape
            (input_size, output_size), with the bias added (broadcast over
            rows) when the layer has one.
        """
        r = self.r

        out = self.U[:, :r] @ self.S[:r, :r] @ self.V[:, :r].T
        # ``is not None`` is required: truth-testing a multi-element tensor raises.
        if self.bias is not None:
            out = out + self.bias
        return out

    @torch.no_grad()
    def step(self, learning_rate):
        """Performs one rank-augmenting descent update followed by truncation.

        Expects ``.grad`` to be populated on U, S, V (and the bias, if any).
        The gradients of U and V are overwritten in place with their
        Sinv-rescaled versions.

        Args:
            learning_rate: learning rate for training
        """
        r = self.r
        # augmented rank, limited by the allocated factor size; without the
        # clamp the block assignments below fail once 2 * r exceeds rmax
        r1 = min(self.rmax, 2 * r)

        U, V, S, Sinv = self.U, self.V, self.S, self.Sinv

        # K-step: rescale the U-gradient by S^{-1} and build an orthonormal
        # basis of [U, grad_U]
        U.grad[:, :r] = U.grad[:, :r] @ Sinv[:r, :r]
        U1, _ = torch.linalg.qr(
            torch.cat((U[:, :r], U.grad[:, :r]), 1), mode="reduced"
        )

        # L-step: same for V
        V.grad[:, :r] = V.grad[:, :r] @ Sinv[:r, :r]
        V1, _ = torch.linalg.qr(
            torch.cat((V[:, :r], V.grad[:, :r]), 1), mode="reduced"
        )

        # Augmented S matrix. The off-diagonal blocks need the OLD S[:r, :r],
        # so they are computed before the upper-left block is updated.
        # 1) SK (lower left block)
        S.data[r:r1, :r] = U1[:, r:r1].T @ (
            U[:, :r] @ S[:r, :r] - learning_rate * U.grad[:, :r]
        )
        # 2) SL (upper right block)
        S.data[:r, r:r1] = (
            V[:, :r] @ S[:r, :r].T - learning_rate * V.grad[:, :r]
        ).T @ V1[:, r:r1]
        # 3) plain gradient step on the upper left block
        S.data[:r, :r] = S[:r, :r] - learning_rate * S.grad[:r, :r]
        # 4) lower right block is zero
        S.data[r:r1, r:r1].zero_()

        # append the new basis directions to U and V
        U.data[:, r:r1] = U1[:, r:r1]
        V.data[:, r:r1] = V1[:, r:r1]
        self.r = r1

        # update bias
        if self.bias is not None:
            self.bias.data = self.bias - learning_rate * self.bias.grad

        self.Truncate()

    @torch.no_grad()
    def BiasStep(self, learning_rate):
        """Performs a steepest descent training update on the bias.

        Does nothing when the layer was built without a bias.

        Args:
            learning_rate: learning rate for training
        """
        if self.bias is None:
            return
        self.bias.data = self.bias - learning_rate * self.bias.grad

    @torch.no_grad()
    def Truncate(self):
        """Truncates the factorization to a new rank via an SVD of S.

        The new rank is the smallest ``j`` for which the norm of the
        discarded singular values ``d[j:]`` falls below ``tol * ||d||``.
        It is then clamped to ``[rmin, rmax]`` and never exceeds the
        current rank. U, S, V, Sinv and ``self.r`` are updated in place.
        """
        r = self.r
        # S[:r, :r] = P @ diag(d) @ Q  (torch returns Q already transposed)
        P, d, Q = torch.linalg.svd(self.S[:r, :r])

        tol = self.tol * torch.linalg.norm(d)
        r1 = r
        for j in range(r):
            if torch.linalg.norm(d[j:r]) < tol:
                r1 = j
                break

        # respect the configured bounds, but never exceed what the SVD gives
        r1 = min(max(r1, self.rmin), self.rmax, r)

        # update s
        self.S.data[:r1, :r1] = torch.diag(d[:r1])
        self.Sinv[:r1, :r1] = torch.diag(1.0 / d[:r1])

        # update u and v: rotate the bases into the singular directions.
        # W = U P diag(d) Q V', hence the new V is V Q' -- Q.T is correct
        # here because torch.linalg.svd returns the transposed right factor.
        self.U.data[:, :r1] = torch.matmul(self.U[:, :r], P[:, :r1])
        self.V.data[:, :r1] = torch.matmul(self.V[:, :r], Q.T[:, :r1])
        self.r = int(r1)
