import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Fully connected network whose output is a softmax over the last dim.

    Layout: optional positional encoding -> ``fc1`` -> optional BatchNorm ->
    activation -> ``layers`` hidden blocks (``self.mid``) -> ``out`` -> softmax.

    Args:
        input_size: Number of input features to ``fc1``. Ignored when
            ``use_pe`` is non-zero; ``fc1`` then expects ``2 * use_pe``
            features (presumably the output width of
            ``PositionalEncodingLayer`` — TODO confirm).
        hidden_size: Width of ``fc1`` and of every hidden layer.
        layers: Number of hidden ``Linear`` blocks after ``fc1``.
        out_size: Width of the final layer; softmax is taken over it.
        act: Activation module. The same instance is reused after ``fc1`` and
            after every hidden layer. Defaults to a fresh ``nn.ReLU()`` created
            per model.
        bn: If True, add ``BatchNorm1d`` after ``fc1`` and each hidden layer.
        bias: Whether the final ``out`` layer has a bias term.
        use_pe: If non-zero, passed as ``L`` to ``PositionalEncodingLayer``,
            which is applied to the input before ``fc1``.
    """

    def __init__(self,
            input_size  : int,
            hidden_size : int,
            layers      : int,
            out_size    : int,
            act=None, bn=True, bias=False, use_pe=0):

        super().__init__()

        # Build the activation per instance. A module used as a default
        # argument is created once at definition time and shared by every MLP.
        if act is None:
            act = nn.ReLU()

        if use_pe:
            self.pe = PositionalEncodingLayer(L=use_pe)
            self.fc1 = nn.Linear(2 * use_pe, hidden_size, bias=False)
        else:
            self.pe = None
            self.fc1 = nn.Linear(input_size, hidden_size, bias=False)

        self.bn = nn.BatchNorm1d(hidden_size) if bn else None

        mid_list = []
        for _ in range(layers):
            # NOTE(review): the hidden Linear keeps its bias only when it is
            # followed by BatchNorm, which already adds a learned shift, and
            # drops it otherwise. This looks inverted, but it is kept as-is so
            # existing state_dicts still load.
            if bn:
                mid_list += [nn.Linear(hidden_size, hidden_size), nn.BatchNorm1d(hidden_size), act]
            else:
                mid_list += [nn.Linear(hidden_size, hidden_size, bias=False), act]
        self.mid = nn.Sequential(*mid_list)
        self.out = nn.Linear(hidden_size, out_size, bias=bias)
        self.act = act

        # The previous key 'weights' matched no torch attribute, so Xavier init
        # was silently skipped. Apply it to Linear layers only, because
        # xavier_normal_ requires >= 2-D tensors and BatchNorm weights are 1-D.
        init_weights(self, {'weight': 'xavier'}, gain=1, input_class=nn.Linear)
        # Every module exposing a non-None ``bias`` (Linear and BatchNorm
        # alike) gets bias = 1, exactly as before.
        init_weights(self, {'bias': 'ones'}, gain=1)

    def forward(self, x):
        """Run the network and return softmax probabilities over the last dim."""
        out = self.pe(x) if self.pe is not None else x
        out = self.fc1(out)
        if self.bn is not None:
            out = self.bn(out)
        out = self.act(out)
        out = self.mid(out)
        out = self.out(out)
        return torch.softmax(out, -1)

def init_weights(net, init_dict, gain=1, input_class=None):
    """Re-initialise attributes of ``net`` and all its submodules in place.

    ``net.apply`` visits ``net`` and every submodule. For each visited module
    (only modules whose exact type is ``input_class`` when one is given),
    every ``attribute_name -> method`` pair in ``init_dict`` is processed:
    ``getattr(module, attribute_name, None)`` is looked up. If the result is
    not ``None``, its ``.data`` is re-initialised with ``method``.

    Args:
        net: The ``nn.Module`` to initialise.
        init_dict: Mapping of attribute name (e.g. ``'weight'``, ``'bias'``)
            to method name. Supported methods: ``'normal'``, ``'xavier'``,
            ``'kaiming'``, ``'orthogonal'``, ``'uniform'``, ``'zeros'``,
            ``'very_small'``, ``'ones'``, ``'xavier1D'``, ``'identity'``.
        gain: Used by ``'normal'`` (as std), ``'xavier'``, ``'orthogonal'``,
            ``'very_small'`` (constant ``1e-3 * gain``) and ``'xavier1D'``
            (std ``gain / sqrt(numel)``). Ignored by the other methods.
        input_class: If given, only modules with ``type(module) is input_class``
            are touched. Subclasses are not matched.

    Raises:
        NotImplementedError: if a visited module has the attribute but the
            requested method name is unknown. Unknown names are not reported
            when no module has the attribute.
    """
    initializers = {
        'normal': lambda t: nn.init.normal_(t, 0.0, gain),
        'xavier': lambda t: nn.init.xavier_normal_(t, gain=gain),
        'kaiming': lambda t: nn.init.kaiming_normal_(t, a=0, mode='fan_in'),
        'orthogonal': lambda t: nn.init.orthogonal_(t, gain=gain),
        'uniform': lambda t: nn.init.uniform_(t, 0, 2),
        'zeros': nn.init.zeros_,
        'very_small': lambda t: nn.init.constant_(t, 1e-3 * gain),
        'ones': lambda t: nn.init.constant_(t, 1),
        # numel() returns a Python int, which has no .sqrt(). The previous
        # ``param.numel().sqrt()`` always raised AttributeError.
        'xavier1D': lambda t: nn.init.normal_(t, 0.0, gain / t.numel() ** 0.5),
        'identity': nn.init.eye_,
    }

    def init_func(m):
        # Exact-type match (not isinstance), as in the original.
        if input_class is not None and type(m) is not input_class:
            return
        for key, value in init_dict.items():
            param = getattr(m, key, None)
            if param is None:
                continue
            initializer = initializers.get(value)
            if initializer is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % value)
            initializer(param.data)

    net.apply(init_func)
