import torch
import torch.nn as nn


class ABLayer(nn.Module):
    """Low-rank factorised weight ``W = A @ BT`` with an optional bias.

    Only the leading ``rank`` columns of ``A`` and the leading ``rank`` rows
    of ``BT`` take part in ``forward`` and ``step``.
    """

    def __init__(self, A, BT, input_size, output_size, rank, bias=False):
        """Construct the layer from given factors or from a random draw.

        Python keeps only the last ``__init__`` defined in a class body, so
        both construction modes live in this single constructor.

        Args:
            A: tensor used as the left factor; it is detached and cloned, so
                the caller's tensor is never modified. Pass ``None`` to draw
                a random ``(input_size, rank)`` factor scaled by
                ``1 / input_size``.
            BT: tensor used as the right factor; it is detached and cloned.
                Pass ``None`` to draw a random ``(rank, output_size)`` factor
                scaled by ``1 / output_size``.
            input_size: input dimension of weight W (only used for the random
                initialisation of ``A``)
            output_size: output dimension of weight W, dimension of bias b
            rank: number of leading columns of ``A`` / rows of ``BT`` in use
            bias: if True, a randomly initialised trainable bias of length
                ``output_size`` is created; otherwise ``self.bias`` is None
        """
        # construct parent class nn.Module
        super().__init__()
        self.r = rank

        # define weights as trainable parameters
        if A is None:
            self.A = nn.Parameter(torch.randn(input_size, self.r) / input_size)
        else:
            self.A = nn.Parameter(A.detach().clone())
        if BT is None:
            self.BT = nn.Parameter(
                torch.randn(self.r, output_size) / output_size
            )
        else:
            self.BT = nn.Parameter(BT.detach().clone())

        # define bias as trainable parameter
        if bias:
            self.bias = nn.Parameter(torch.randn(output_size))
        else:
            self.bias = None

    @classmethod
    def from_sizes(cls, input_size, output_size, rank, bias=False):
        """Build a layer with randomly initialised factors.

        Args:
            input_size: input dimension of weight W
            output_size: output dimension of weight W, dimension of bias b
            rank: rank of the factorisation
            bias: if True, a trainable bias is created
        Returns:
            a new layer whose factors are drawn at random
        """
        return cls(None, None, input_size, output_size, rank, bias=bias)

    def forward(self):
        """Returns the product of the active factors, plus the bias if present.

        The method takes no input: the formula implemented is
        output = A[:, :r] @ BT[:r, :] (+ bias), with r = self.r.

        Returns:
            the matrix product of the active factors, with the bias added
            (broadcast by torch) when the layer has one
        """
        out = self.A[:, : self.r] @ self.BT[: self.r, :]
        if self.bias is not None:
            out = out + self.bias
        return out

    @torch.no_grad()
    def step(self, learning_rate):
        """Performs a steepest descent training update on weights and biases.

        Only the active slices (first ``self.r`` columns of ``A``, first
        ``self.r`` rows of ``BT``) are updated. Gradients must already be
        populated, i.e. ``backward()`` has to run before this method.

        Args:
            learning_rate: learning rate for training
        """
        r = self.r
        self.A.data[:, :r] = self.A[:, :r] - learning_rate * self.A.grad[:, :r]
        self.BT.data[:r, :] = (
            self.BT[:r, :] - learning_rate * self.BT.grad[:r, :]
        )

        if self.bias is not None:
            self.bias.data = self.bias - learning_rate * self.bias.grad

        # NOTE: gradients are not reset here; callers have to zero them
        # (e.g. with zero_grad()) before the next backward pass, otherwise
        # they accumulate.
