import torch
import torch.nn as nn


class SVDLayer(nn.Module):
    """Dense layer whose weight is stored in factorised form W = A @ S @ BT.

    A has shape (input_size, rank), S has shape (rank, rank) and BT has shape
    (rank, output_size), so W has shape (input_size, output_size). An optional
    trainable bias of length output_size is added to the output.
    """

    def __init__(self, A, S, BT, output_size, rank, bias=False):
        """Build the layer from existing factors, truncated to ``rank``.

        Args:
            A: left factor; only its first ``rank`` columns are kept.
            S: core factor; only its leading ``rank x rank`` block is kept.
            BT: transposed right factor; only its first ``rank`` rows are kept.
            output_size: length of the bias vector (number of columns of W).
            rank: number of singular directions to keep.
            bias: if True, add a trainable bias initialised with ``torch.randn``.
        """
        super().__init__()
        self.r = rank
        # Detach + clone so the parameters neither share storage nor autograd
        # history with the tensors the caller passed in.
        self.A = nn.Parameter(A.detach().clone()[:, : self.r])
        self.S = nn.Parameter(S.detach().clone()[: self.r, : self.r])
        self.BT = nn.Parameter(BT.detach().clone()[: self.r, :])

        if bias:
            self.bias = nn.Parameter(torch.randn(output_size))
        else:
            self.bias = None

    @classmethod
    def random_init(cls, input_size, output_size, rank, bias=False):
        """Create a layer with randomly initialised factors.

        This replaces a former second ``__init__`` overload that Python
        silently discarded (only the last ``def __init__`` in a class is kept).

        Args:
            input_size: number of rows of W (rows of A).
            output_size: number of columns of W (columns of BT), bias length.
            rank: inner dimension of the factorisation.
            bias: if True, add a trainable randomly initialised bias.
        Returns:
            A new ``SVDLayer``.
        """
        A = torch.randn(input_size, rank) / input_size
        S = torch.randn(rank, rank) / input_size
        BT = torch.randn(rank, output_size) / output_size
        return cls(A, S, BT, output_size, rank, bias=bias)

    def forward(self, x=None):
        """Return the layer output.

        Args:
            x: optional input whose last dimension is input_size. When it is
                omitted, the factorised weight matrix itself is returned.
        Returns:
            ``W + bias`` when ``x`` is None, otherwise ``x @ W + bias``. The
            bias (if any) broadcasts over the last dimension.
        """
        if x is None:
            out = self.A @ self.S @ self.BT
        else:
            # Left-to-right evaluation ((x @ A) @ S) @ BT keeps intermediates
            # at width `rank` instead of materialising the full W.
            out = x @ self.A @ self.S @ self.BT
        if self.bias is not None:
            out = out + self.bias
        return out

    def ortho_regularization(self):
        """Return 0.5 * (||A^T A - I||_F^2 + ||BT BT^T - I||_F^2).

        The penalty is zero exactly when A has orthonormal columns and BT has
        orthonormal rows.
        """
        # Build the identity on the parameters' device/dtype so the penalty
        # also works for parameters moved to a GPU or cast to another dtype.
        eye = torch.eye(self.r, dtype=self.A.dtype, device=self.A.device)
        return 0.5 * (
            torch.linalg.norm(self.A.T @ self.A - eye) ** 2
            + torch.linalg.norm(self.BT @ self.BT.T - eye) ** 2
        )

    @torch.no_grad()
    def step(self, learning_rate):
        """Perform one steepest-descent update and reset the gradients.

        Every parameter (A, S, BT and, if present, bias) that has a gradient
        is updated as ``p -= learning_rate * p.grad``. Its gradient is then
        zeroed, so gradients do not accumulate across steps. Parameters
        without a gradient (e.g. before any backward pass) are left untouched.

        Args:
            learning_rate: step size of the update.
        """
        for param in (self.A, self.S, self.BT, self.bias):
            if param is None or param.grad is None:
                continue
            param.sub_(learning_rate * param.grad)
            param.grad.zero_()
