import torch
import torch.nn as nn
from torch import Tensor


class WLayer(nn.Module):
    """Trainable weight tensor ``W`` with an optional additive bias.

    The layer takes no input: calling it returns ``W`` itself, or
    ``W + bias`` when a bias is present (see ``forward``).

    Construct it from an existing tensor with ``WLayer(W_init)`` or with
    random initial values via ``WLayer.from_size(input_size, output_size)``.
    """

    def __init__(self, W_init: Tensor, bias: bool = False) -> None:
        """Constructs the layer from an initial weight tensor.

        Python keeps only the last definition of a method name, so a second
        ``__init__`` taking sizes cannot coexist with this one; the size-based
        constructor is provided as the classmethod ``from_size`` instead.

        Args:
            W_init: initial value of the weight W. It is detached and cloned,
                so the layer does not share storage or autograd history with
                the tensor passed in.
            bias: if truthy, a trainable bias vector drawn from a standard
                normal distribution is created, sized to the last dimension
                of ``W_init``; otherwise ``self.bias`` is ``None``.

        Raises:
            TypeError: if ``W_init`` is not a ``torch.Tensor``.
        """
        if not isinstance(W_init, Tensor):
            raise TypeError(
                f"W_init must be a torch.Tensor, got {type(W_init).__name__}; "
                "use WLayer.from_size(input_size, output_size) to create a "
                "randomly initialised layer"
            )
        # construct parent class nn.Module
        super().__init__()
        # define weights as trainable parameter (independent copy of W_init)
        self.W = nn.Parameter(W_init.detach().clone())
        # define bias as trainable parameter; shape[-1] equals shape[1] for a
        # 2-D weight and keeps the bias broadcastable for other ranks
        if bias:
            self.bias = nn.Parameter(torch.randn(W_init.shape[-1]))
        else:
            self.bias = None

    @classmethod
    def from_size(cls, input_size: int, output_size: int, bias: bool = False) -> "WLayer":
        """Constructs a layer whose weight is drawn from a standard normal.

        Args:
            input_size: first dimension of weight W
            output_size: second dimension of weight W, dimension of bias b
            bias: if truthy, also create a randomly initialised bias vector

        Returns:
            a new layer with W of shape ``(input_size, output_size)``
        """
        return cls(torch.randn(input_size, output_size), bias=bias)

    def forward(self):
        """Returns the output of the layer.

        The layer takes no input. The formula implemented is
        ``output = W + bias`` (bias broadcast against W) when a bias exists,
        and ``output = W`` otherwise.

        Returns:
            the weight parameter itself when there is no bias, otherwise the
            tensor ``W + bias``
        """
        if self.bias is None:
            return self.W
        return self.W + self.bias

    @torch.no_grad()
    def step(self, learning_rate):
        """Performs a steepest descent training update on weights and biases.

        Each parameter ``p`` that has a gradient is updated in place as
        ``p <- p - learning_rate * p.grad`` and its gradient is then reset to
        zero. Parameters whose ``.grad`` is ``None`` (no backward pass has
        populated it yet) are left untouched instead of raising.

        Args:
            learning_rate: learning rate for training
        """
        for param in (self.W, self.bias):
            if param is None or param.grad is None:
                continue
            # in-place update is safe here because autograd is disabled by the
            # torch.no_grad() decorator
            param.sub_(learning_rate * param.grad)
            param.grad.zero_()
