import copy
import math
import torch


class Module(torch.nn.Module):
    """Fully connected ReLU network with explicitly managed parameters.

    The network is a stack of linear layers: ``input_size -> width``,
    then ``depth - 2`` layers of ``width -> width``, then
    ``width -> class_size``.  A ReLU follows every layer except the last.

    Weights are stored in the ``(out_features, in_features)`` layout that
    ``torch.nn.functional.linear`` expects.  The weight and bias lists are
    also exposed through ``param_dict`` under the keys ``"linear"`` and
    ``"bias"``.
    """

    def __init__(self, device, kwargs):
        """Create and initialise the layer parameters.

        Args:
            device: Stored on the instance; nothing in this class moves
                tensors to it.
            kwargs: Mapping with the keys ``"input_size"``,
                ``"class_size"``, ``"width"`` and ``"depth"``.  ``depth``
                is the number of linear layers; values below 2 still
                produce the two outer layers.

        Raises:
            KeyError: If one of the required keys is missing from
                ``kwargs``.
        """
        super().__init__()

        self.__device = device

        input_size = kwargs["input_size"]
        class_size = kwargs["class_size"]
        width = kwargs["width"]
        depth = kwargs["depth"]

        # (in_features, out_features) of every layer, first to last.
        layer_dims = (
            [(input_size, width)]
            + [(width, width)] * (depth - 2)
            + [(width, class_size)]
        )

        weights = []
        biases = []
        for fan_in, fan_out in layer_dims:
            # functional.linear computes x @ weight.T + bias, so the weight
            # is (out, in) and the bias has one entry per output unit.
            weight = torch.nn.Parameter(torch.empty(fan_out, fan_in))
            bias = torch.nn.Parameter(torch.empty(fan_out))

            # Xavier-normal weights; biases uniform in +-1/sqrt(fan_in),
            # the same bound torch.nn.Linear uses by default.
            torch.nn.init.xavier_normal_(weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(bias, -bound, bound)

            weights.append(weight)
            biases.append(bias)

        # ParameterList registers the tensors so that parameters(),
        # state_dict() and .to() see them.
        self.__linear = torch.nn.ParameterList(weights)
        self.__bias = torch.nn.ParameterList(biases)

        # For the learner
        self.param_dict = {
            "linear": self.__linear,
            "bias": self.__bias,
        }

    def forward(self, batch):
        """Run the network on ``batch["x"]`` and store the result.

        Args:
            batch: Mapping holding the input tensor under the key ``"x"``.
                Its last dimension must equal ``input_size``.

        The output of the last layer (no activation applied) is written
        to ``self.out``; nothing is returned.
        """
        x = batch["x"]

        last = len(self.__linear) - 1
        for idx, (weight, bias) in enumerate(
                zip(self.__linear, self.__bias)):
            x = torch.nn.functional.linear(x, weight, bias=bias)
            # Every layer but the output layer is followed by a ReLU.
            if idx < last:
                x = torch.nn.functional.relu(x)

        # --------------------------------------------------------------- #

        self.out = x
        # NOTE: hard predictions were previously sketched as
        # torch.max(x, dim=1)[1].unsqueeze(1); currently disabled.