import torch
import torch.nn as nn
import csv


def create_layer(layer):
    """Builds the layer described by a configuration dictionary.

    Args:
        layer: dictionary with the key "type" (one of "dense",
            "vanilla_low_rank", "dynamical_low_rank", "parallel_low_rank")
            and the key "dims" (input size at index 0, output size at index 1).
            All low-rank types additionally read "rank"; "parallel_low_rank"
            optionally reads "init_compression".
    Returns:
        the constructed layer
    Raises:
        ValueError: if layer["type"] is not one of the supported types
    """
    layer_type = layer["type"]

    if layer_type == "dense":
        return DenseLayer(layer["dims"][0], layer["dims"][1])
    if layer_type == "vanilla_low_rank":
        return VanillaLowRankLayer(layer["dims"][0], layer["dims"][1], layer["rank"])
    if layer_type == "dynamical_low_rank":
        return LowRankLayer(layer["dims"][0], layer["dims"][1], layer["rank"])
    if layer_type == "parallel_low_rank":
        # only forward init_compression when configured, so the layer's own
        # default stays in effect otherwise
        optional = {}
        if "init_compression" in layer:
            optional["init_compression"] = layer["init_compression"]
        return ParallelLowRankLayer(
            layer["dims"][0], layer["dims"][1], layer["rank"], **optional
        )

    # fail loudly instead of silently returning None for a misspelled type
    raise ValueError(f"Unknown layer type: {layer_type!r}")


# Define standard layer
class DenseLayer(nn.Module):
    def __init__(self, input_size, output_size):
        """Constructs a dense layer of the form x*W + b, where W is the weight matrix and b is the bias vector
        Args:
            input_size: input dimension of weight W (number of rows of W)
            output_size: output dimension of weight W (number of columns of W), dimension of bias b
        """
        # construct parent class nn.Module
        super(DenseLayer, self).__init__()
        # define weights as trainable parameter
        self.W = nn.Parameter(torch.randn(input_size, output_size))
        # define bias as trainable parameter
        self.bias = nn.Parameter(torch.randn(output_size))

    def forward(self, x):
        """Returns the output of the layer. The formula implemented is output = x*W + bias.
        Args:
            x: input to layer, last dimension must match input_size
        Returns:
            output of layer
        """
        out = torch.matmul(x, self.W)
        return out + self.bias

    @torch.no_grad()
    def step(self, learning_rate):
        """Performs a steepest descent training update on weights and biases.
        Gradients must have been populated by a backward pass beforehand.
        Args:
            learning_rate: learning rate for training
        """
        # runs under no_grad so the update itself is not recorded by autograd
        self.W.data = self.W - learning_rate * self.W.grad
        self.bias.data = self.bias - learning_rate * self.bias.grad

    def write(self, file_name, use_txt=True):
        """Writes weight matrix and bias to disk
        Args:
            file_name: path prefix of the output files; the suffixes "_W.pth" and
                       "_b.pth" (and "_W.txt", "_b.txt" if use_txt is set) are appended
            use_txt: if True, additionally writes tab-separated text files
        """
        # save as pth
        torch.save(self.W, file_name + "_W.pth")
        torch.save(self.bias, file_name + "_b.pth")

        if use_txt:
            # the text file stores the transpose of W, one row of W.T per line
            with open(file_name + "_W.txt", "w") as file:
                for row in self.W.data.T:
                    row_str = "\t".join(map(str, row.tolist()))
                    file.write(row_str + "\n")

            # bias is written as a single tab-separated line without trailing newline
            with open(file_name + "_b.txt", "w") as file:
                bias_str = "\t".join(map(str, self.bias.data.tolist()))
                file.write(bias_str)


# Define dynamical low-rank layer (rank-adaptive by default)
class LowRankLayer(nn.Module):
    """Low-rank layer whose weight is stored in factorized form U, S, V.

    Training alternates two calls to `step`: a "basis" step that rebuilds U and V
    (and re-projects S onto the new bases), followed by a "coefficients" step that
    performs a gradient update of S and, for adaptive layers, calls `Truncate`.
    """

    def __init__(
        self, input_size, output_size, rank, adaptive=True, init_compression=0.5
    ):
        """Constructs a low-rank layer of the form x*U*S*V' + b, where
           U, S, V represent the factorized weight W and b is the bias vector
        Args:
            input_size: input dimension of weight W
            output_size: output dimension of weight W, dimension of bias b
            rank: maximal rank of the factorized weight (stored as self.rmax)
            adaptive: set True if layer is rank-adaptive
            init_compression: the initial active rank is int(init_compression * rank)
        """
        # construct parent class nn.Module
        super(LowRankLayer, self).__init__()

        self.rmax = rank
        rmax = self.rmax
        # adaptive layers allocate twice as many columns so the basis can be augmented
        if adaptive:
            r1 = 2 * rmax
        else:
            r1 = rank

        # initializes factorized weight
        self.U = nn.Parameter(torch.randn(input_size, r1))
        self.S = nn.Parameter(torch.randn(r1, r1))
        self.V = nn.Parameter(torch.randn(output_size, r1))

        # ensure that U and V are orthonormal
        self.U.data, _ = torch.linalg.qr(self.U, "reduced")
        self.V.data, _ = torch.linalg.qr(self.V, "reduced")

        # initialize non-trainable Parameter fields that hold the new bases
        # computed in the basis step
        self.U1 = nn.Parameter(torch.randn(input_size, 2 * rank))
        self.V1 = nn.Parameter(torch.randn(output_size, 2 * rank))
        self.U1.requires_grad = False
        self.V1.requires_grad = False

        # initialize bias
        self.bias = nn.Parameter(torch.randn(output_size))

        # set rank and truncation tolerance
        # self.r is the active rank: only the leading r columns/rows of U, S, V are used
        self.r = int(init_compression * rank)
        self.tol = 1e-2
        self.adaptive = adaptive

    def forward(self, x):
        """Returns the output of the layer. The formula implemented is
        output = x*U*S*V' + bias, restricted to the leading self.r columns of U and V
        and the leading self.r x self.r block of S.
        Args:
            x: input to layer
        Returns:
            output of layer
        """
        r = self.r
        xU = torch.matmul(x, self.U[:, :r])
        xUS = torch.matmul(xU, self.S[:r, :r])
        out = torch.matmul(xUS, self.V[:, :r].T)
        return out + self.bias

    @torch.no_grad()
    def step(self, learning_rate, dlrt_step="basis"):
        """Performs a steepest descent training update on specified low-rank factors
        Args:
            learning_rate: learning rate for training
            dlrt_step: specifies the step that is taken. "basis" updates U and V
                       (K- and L-step), re-projects S onto the new bases and updates
                       the bias; "coefficients" updates S and, if self.adaptive is
                       set, truncates the rank. Any other value only prints a message.
        """
        r = self.r
        # rank after the basis step: doubled for adaptive layers, unchanged otherwise
        if self.adaptive:
            r1 = 2 * r
        else:
            r1 = r

        if dlrt_step == "basis":
            # perform K-step
            U0 = self.U[:, :r]
            V0 = self.V[:, :r]
            S0 = self.S[:r, :r]
            K = torch.matmul(U0, S0)
            dK = torch.matmul(
                self.U.grad[:, :r], S0
            )  # + torch.matmul(U0, self.S.grad[:r,:r])
            if self.adaptive:
                # augmented basis: orthonormalize the old basis together with dK
                self.U1[:, :r1], _ = torch.linalg.qr(torch.cat((U0, dK), 1), "reduced")
            else:
                # fixed rank: orthonormalize the updated K
                K = K - learning_rate * dK
                self.U1[:, :r1], _ = torch.linalg.qr(K, "reduced")

            # perform L-step
            L = torch.matmul(V0, S0.T)
            dL = torch.matmul(
                self.V.grad[:, :r], S0.T
            )  # + torch.matmul(V0, self.S.grad[:r,:r].T)
            if self.adaptive:
                self.V1[:, :r1], _ = torch.linalg.qr(torch.cat((V0, dL), 1), "reduced")
            else:
                L = L - learning_rate * dL
                self.V1[:, :r1], _ = torch.linalg.qr(L, "reduced")

            # projections of the old bases onto the new ones; computed before
            # U and V are overwritten below, since U0 and V0 are slices of them
            M = self.U1[:, :r1].T @ U0[:, :r]
            N = V0[:, :r].T @ self.V1[:, :r1]

            # update basis
            self.U.data[:, :r1] = self.U1[:, :r1]
            self.V.data[:, :r1] = self.V1[:, :r1]

            # express the old coefficient matrix in the new bases
            self.S.data[:r1, :r1] = M @ self.S[:r, :r] @ N

            self.r = r1

            # update bias
            self.bias.data = self.bias - learning_rate * self.bias.grad

        elif dlrt_step == "coefficients":
            # gradient step on the active block of S
            self.S.data[: self.r, : self.r] = (
                self.S[: self.r, : self.r]
                - learning_rate * self.S.grad[: self.r, : self.r]
            )

            # truncate to new rank
            if self.adaptive:
                self.Truncate()
        else:
            # unknown step names are reported on stdout; no factor is modified
            print("Wrong step defined: ", dlrt_step)

    @torch.no_grad()
    def BiasStep(self, learning_rate):
        """Performs a steepest descent training update on the bias
        Args:
            learning_rate: learning rate for training
        """
        self.bias.data = self.bias - learning_rate * self.bias.grad

    @torch.no_grad()
    def Truncate(self):
        """Truncates the weight matrix to a new rank.

        The active block of S is decomposed by an SVD, S is replaced by the leading
        singular values, and U and V are rotated by the corresponding singular vectors.

        NOTE(review): the tolerance-based rank computed in the loop below is
        overwritten by `r1 = r0`, so the new rank is always int(0.5 * self.r) and
        self.tol / self.rmax have no effect here. Confirm whether this is intended.
        """
        r0 = int(0.5 * self.r)
        # torch.linalg.svd returns (U, S, Vh), i.e. S_block = P @ diag(d) @ Q
        P, d, Q = torch.linalg.svd(self.S[: self.r, : self.r])

        # print(torch.linalg.matrix_norm(P @ torch.diag(d) @ Q.t() - self.S[:self.r, :self.r], 'fro'))

        tol = self.tol  # * torch.linalg.norm(d)
        r1 = self.r
        # smallest j for which the norm of the discarded singular values d[j:] is below tol
        for j in range(0, self.r):
            tmp = torch.linalg.norm(d[j : self.r])
            if tmp < tol:
                r1 = j
                break

        r1 = min(r1, self.rmax)
        r1 = max(r1, 2)
        # overrides the value computed above (see NOTE(review) in the docstring)
        r1 = r0

        # update s
        self.S.data[:r1, :r1] = torch.diag(d[:r1])

        # update u and v
        self.U.data[:, :r1] = torch.matmul(self.U[:, : self.r], P[:, :r1])
        self.V.data[:, :r1] = torch.matmul(
            self.V[:, : self.r], Q.T[:, :r1]
        )  # Q.T is correct: Q is Vh from torch.linalg.svd, so the columns of Q.T are the right singular vectors
        self.r = int(r1)


# Define vanilla low-rank layer
class VanillaLowRankLayer(nn.Module):
    """Fixed-rank factorized layer computing x*V*S*UT + bias, trained by plain
    gradient descent on all factors."""

    def __init__(self, input_size, output_size, rank):
        """Creates the factors UT (rank x output_size), S (rank x rank),
           V (input_size x rank) and the bias, then orthonormalizes UT and V
        Args:
            input_size: input dimension of the factorized weight
            output_size: output dimension of the factorized weight, dimension of bias
            rank: rank of the factorization
        """
        super().__init__()

        # random initialization of all trainable tensors
        self.UT = nn.Parameter(torch.randn(rank, output_size))
        self.S = nn.Parameter(torch.randn(rank, rank))
        self.V = nn.Parameter(torch.randn(input_size, rank))
        self.bias = nn.Parameter(torch.randn(output_size))

        # replace the random bases by the Q factors of their reduced QR decompositions
        q_out = torch.linalg.qr(self.UT.T, "reduced")[0]
        q_in = torch.linalg.qr(self.V, "reduced")[0]
        self.UT.data = q_out.T
        self.V.data = q_in

    def forward(self, x):
        """Evaluates the layer as x*V*S*UT + bias, multiplying from left to right.
        Args:
            x: input to layer
        Returns:
            output of layer
        """
        return x @ self.V @ self.S @ self.UT + self.bias

    @torch.no_grad()
    def step(self, learning_rate):
        """Applies one gradient descent update to UT, V, S and the bias
        Args:
            learning_rate: learning rate for training
        """
        for factor in (self.UT, self.V, self.S, self.bias):
            factor.data = factor - learning_rate * factor.grad


# Define parallel low-rank layer
class ParallelLowRankLayer(nn.Module):
    """Rank-adaptive low-rank layer whose basis and coefficient updates are
    performed together in a single call to `step`, followed by a truncation."""

    def __init__(self, input_size, output_size, rank, init_compression=0.5):
        """Constructs a low-rank layer of the form x*U*S*V' + b, where
           U, S, V represent the factorized weight W and b is the bias vector
        Args:
            input_size: input dimension of weight W
            output_size: output dimension of weight W, dimension of bias b
            rank: maximal rank of the factorized weight (stored as self.rmax)
            init_compression: the initial active rank is int(init_compression * rank)
        """
        # construct parent class nn.Module
        super(ParallelLowRankLayer, self).__init__()

        self.rmax = rank
        # storage holds 2 * rmax columns so the basis can be augmented in step()
        r1 = 2 * self.rmax

        # initializes factorized weight
        self.U = nn.Parameter(torch.randn(input_size, r1))
        self.V = nn.Parameter(torch.randn(output_size, r1))

        # ensure that U and V are orthonormal
        self.U.data, _ = torch.linalg.qr(self.U, "reduced")
        self.V.data, _ = torch.linalg.qr(self.V, "reduced")

        # S starts as a diagonal matrix of sorted, non-negative random values
        self.singular_values, _ = torch.sort(torch.randn(r1) ** 2, descending=True)
        self.S = nn.Parameter(torch.diag(self.singular_values))
        # Inverse of the diagonal S, used to rescale gradients in step().
        # Registered as a buffer so it follows the module across .to()/.cuda()/
        # .double(); non-persistent so the keys of state_dict() stay unchanged.
        self.register_buffer(
            "Sinv", torch.diag(1 / self.singular_values), persistent=False
        )

        # initialize bias
        self.bias = nn.Parameter(torch.randn(output_size))

        # set rank and truncation tolerance
        # self.r is the active rank: only the leading r columns/rows of U, S, V are used
        self.r = int(init_compression * rank)
        self.tol = 1e-2

    def forward(self, x):
        """Returns the output of the layer. The formula implemented is
        output = x*U*S*V' + bias, restricted to the leading self.r columns of U and V
        and the leading self.r x self.r block of S.
        Args:
            x: input to layer
        Returns:
            output of layer
        """
        r = self.r
        xU = torch.matmul(x, self.U[:, :r])
        xUS = torch.matmul(xU, self.S[:r, :r])
        out = torch.matmul(xUS, self.V[:, :r].T)
        return out + self.bias

    @torch.no_grad()
    def step(self, learning_rate):
        """Performs a steepest descent training update on all low-rank factors and
        the bias, augmenting the rank from r to 2r, and then truncates again.
        Gradients must have been populated by a backward pass beforehand; the
        gradients of U and V are rescaled in place.
        Args:
            learning_rate: learning rate for training
        """
        r = self.r
        r1 = 2 * r

        # perform K-step
        # gradient modification: right-multiplying by the stored inverse of S turns
        # the gradient of U into the gradient with respect to K = U*S (S[:r, :r]
        # is diagonal after Truncate, so no transpose is needed)
        self.U.grad[:, :r] = self.U.grad[:, :r] @ self.Sinv[:r, :r]
        U1, _ = torch.linalg.qr(
            torch.cat((self.U[:, :r], self.U.grad[:, :r]), 1), "reduced"
        )

        # perform L-step
        # gradient modification, analogous to the K-step with L = V*S'
        self.V.grad[:, :r] = self.V.grad[:, :r] @ self.Sinv[:r, :r]
        V1, _ = torch.linalg.qr(
            torch.cat((self.V[:, :r], self.V.grad[:, :r]), 1), "reduced"
        )

        # set up augmented S matrix
        # lower-left block: updated K projected onto the new basis directions of U
        self.S.data[r:r1, :r] = U1[:, r:r1].T @ (
            self.U[:, :r] @ self.S[:r, :r] - learning_rate * self.U.grad[:, :r]
        )
        # upper-right block: updated L projected onto the new basis directions of V
        self.S.data[:r, r:r1] = (
            self.V[:, :r] @ self.S[:r, :r].T - learning_rate * self.V.grad[:, :r]
        ).T @ V1[:, r:r1]
        # must come after the two blocks above, which still need the old S[:r, :r]
        self.S.data[:r, :r] = self.S[:r, :r] - learning_rate * self.S.grad[:r, :r]
        # lower-right block is zero; assigned (not multiplied by 0) so that stale
        # non-finite entries cannot survive
        self.S.data[r:r1, r:r1] = 0.0

        # append the new basis directions behind the old basis
        self.U.data[:, r:r1] = U1[:, r:r1]
        self.V.data[:, r:r1] = V1[:, r:r1]
        self.r = r1

        # update bias
        self.bias.data = self.bias - learning_rate * self.bias.grad

        self.Truncate()

    @torch.no_grad()
    def BiasStep(self, learning_rate):
        """Performs a steepest descent training update on the bias
        Args:
            learning_rate: learning rate for training
        """
        self.bias.data = self.bias - learning_rate * self.bias.grad

    @torch.no_grad()
    def Truncate(self):
        """Truncates the weight matrix to a new rank.

        The active block of S is decomposed by an SVD. The new rank is the smallest
        j for which the norm of the discarded singular values d[j:] drops below
        self.tol times the norm of all singular values, clipped to at most self.rmax
        and at least 2. S and Sinv are reset to the (inverse) leading singular
        values and U and V are rotated by the corresponding singular vectors.
        """
        # torch.linalg.svd returns (U, S, Vh), i.e. S_block = P @ diag(d) @ Q
        P, d, Q = torch.linalg.svd(self.S[: self.r, : self.r])

        tol = self.tol * torch.linalg.norm(d)
        r1 = self.r
        for j in range(0, self.r):
            tmp = torch.linalg.norm(d[j : self.r])
            if tmp < tol:
                r1 = j
                break

        r1 = min(r1, self.rmax)
        r1 = max(r1, 2)

        # update s and its inverse
        self.S.data[:r1, :r1] = torch.diag(d[:r1])
        self.Sinv[:r1, :r1] = torch.diag(1.0 / d[:r1])

        # update u and v
        self.U.data[:, :r1] = torch.matmul(self.U[:, : self.r], P[:, :r1])
        # Q is Vh, so the columns of Q.T are the right singular vectors
        self.V.data[:, :r1] = torch.matmul(self.V[:, : self.r], Q.T[:, :r1])
        self.r = int(r1)
