import torch
import torch.nn as nn


class model(nn.Module):
    """Fully-connected network: ``depth`` Linear layers with LeakyReLU in between.

    Layer ``i`` (1-based) is stored in ``self.layers`` under key ``'layer_<i>'``.
    The first layer takes ``in_dim`` features, the last layer emits ``out_dim``
    features, and every other width is ``hid_dim``. With ``depth == 1`` the
    single layer maps ``in_dim -> out_dim`` directly. With ``depth < 1`` no
    layers are created and ``forward`` returns its input unchanged.

    Args:
        in_dim: number of input features.
        hid_dim: width of the hidden layers.
        out_dim: number of output features.
        depth: number of Linear layers.
        relu: ``negative_slope`` passed to ``nn.LeakyReLU``.
        bias: whether each Linear layer has a bias term.
        gamma: initialization scale. If nonzero, every Linear weight (and
            bias, when present) is re-drawn from a normal distribution with
            mean 0 and std ``gamma / sqrt(weight.numel())``. If zero,
            PyTorch's default ``nn.Linear`` initialization is kept.
    """

    def __init__(self, in_dim, hid_dim, out_dim, depth, relu, bias, gamma):
        super().__init__()
        self.depth = depth
        self.activation = nn.LeakyReLU(negative_slope=relu)
        self.layers = nn.ModuleDict()

        for i in range(1, depth + 1):
            # Decide input and output widths independently. A single-layer
            # network is then both "first" and "last": in_dim -> out_dim.
            fan_in = in_dim if i == 1 else hid_dim
            fan_out = out_dim if i == depth else hid_dim
            self.layers[f'layer_{i}'] = nn.Linear(fan_in, fan_out, bias=bias)

        if gamma != 0:
            self._init_weights(gamma)

    def forward(self, x):
        """Apply the layers in order, with the activation between consecutive layers."""
        for i in range(1, self.depth + 1):
            x = self.layers[f'layer_{i}'](x)
            if i != self.depth:  # no activation after the output layer
                x = self.activation(x)
        return x

    def _init_weights(self, gamma):
        """Re-initialize every ``nn.Linear`` submodule with scale ``gamma``.

        Weight and bias are both drawn from N(0, std**2), where
        std = gamma / sqrt(weight.numel()), i.e. gamma / sqrt(fan_in * fan_out).
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # NOTE(review): this reseeds the *global* RNG before every
                # layer. Each layer therefore draws from the same random
                # stream, so layers of identical shape (e.g. all hidden
                # layers) receive identical weights, and any RNG state the
                # caller set up is overwritten. This is kept as-is for
                # reproducibility of existing results; confirm it is intended.
                torch.manual_seed(1)
                std = gamma / (torch.numel(m.weight) ** 0.5)
                nn.init.normal_(m.weight, mean=0, std=std)
                if m.bias is not None:
                    # The bias intentionally reuses the weight-based std.
                    nn.init.normal_(m.bias, mean=0, std=std)