from torch import nn
import torch.nn.functional as F
import math


class Fusion(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        out_dim: Size of the last dimension of the output.
        hidden: Width of the single hidden layer.
    """

    def __init__(self, input_dim=2816, out_dim=8, hidden=512):
        super().__init__()
        project_in = nn.Linear(input_dim, hidden)
        project_out = nn.Linear(hidden, out_dim)
        # The position of each layer inside the Sequential determines the
        # parameter names (enc_net.0.*, enc_net.2.*), so the order is fixed.
        self.enc_net = nn.Sequential(project_in, nn.ReLU(), project_out)

    def forward(self, feat):
        """Return the result of running ``feat`` through ``enc_net``."""
        encoded = self.enc_net(feat)
        return encoded


class MLP(nn.Module):

    def __init__(self, input_dim, hidden_dims, use_bn=True, input_dropout=0., hidden_dropout=0., normalize=False):
        super().__init__()

        layers = []
        for index, dim in enumerate(hidden_dims[:-1]):
            layers.append(nn.Linear(input_dim, dim, bias=True))
            nn.init.normal_(layers[-1].weight, std=0.01)
            nn.init.constant_(layers[-1].bias, 0.)

            if index < len(hidden_dims) - 1:
                if use_bn:
                    layers.append(nn.BatchNorm1d(dim))
#                 layers.append(nn.LeakyReLU(negative_slope=0.2))
                layers.append(nn.ReLU(inplace=True))
            if input_dropout and index == 0:
                layers.append(nn.Dropout(p=input_dropout))
            elif hidden_dropout and index < len(hidden_dims) - 1:
                layers.append(nn.Dropout(p=hidden_dropout))

            input_dim = dim

        layers.append(nn.Linear(input_dim, hidden_dims[-1]))
        nn.init.normal_(layers[-1].weight, std=0.01)
        nn.init.constant_(layers[-1].bias, 0.)

        self.mlp = nn.Sequential(*layers)
        
        self.normalize = normalize

    def forward(self, x):
        if self.normalize:
            return F.normalize(self.mlp(x), dim=1)
        return self.mlp(x)


# +
class MaxOut_MLP(nn.Module):
    """MLP built from Maxout layers with batch normalization and dropout.

    The stack starts with ``BatchNorm1d(input_dim, eps=1e-4)``; each width in
    ``hidden_dims`` then adds ``Maxout(pool size 2) -> BatchNorm1d(eps=1e-4)
    -> Dropout(p=0.3)``.  When ``linear_layer_out_dim`` is given, a final
    ``Linear(hidden_dims[-1], linear_layer_out_dim)`` is appended.

    Args:
        input_dim: Number of input features.
        hidden_dims: Sequence of Maxout output widths.
        linear_layer_out_dim: Optional width of a trailing linear layer.
    """

    def __init__(self, input_dim, hidden_dims, linear_layer_out_dim=None):
        super().__init__()

        stages = [nn.BatchNorm1d(input_dim, eps=1e-4)]
        width_in = input_dim
        for width_out in hidden_dims:
            stages += [
                Maxout(width_in, width_out, 2),
                nn.BatchNorm1d(width_out, eps=1e-4),
                nn.Dropout(p=0.3),
            ]
            width_in = width_out

        if linear_layer_out_dim is not None:
            stages.append(nn.Linear(hidden_dims[-1], linear_layer_out_dim))

        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Return the output of the sequential stack applied to ``x``."""
        return self.mlp(x)
    
class Maxout(nn.Module):
    """Maxout unit: a linear map followed by a max over groups of outputs.

    Args:
        d: Number of input features.
        m: Number of output features.
        k: Pool size, i.e. how many linear pieces compete for each output.
    """

    def __init__(self, d, m, k):
        super().__init__()
        self.d_in = d
        self.d_out = m
        self.pool_size = k
        # One linear layer produces all m * k candidate activations at once.
        self.fc = nn.Linear(d, m * k)

    def forward(self, inputs):
        """Project ``inputs`` and keep the largest of each group of ``pool_size``.

        The last dimension of the projection (size ``d_out * pool_size``) is
        viewed as ``(d_out, pool_size)`` and reduced with a max over the
        trailing axis; leading dimensions are left as they are.
        """
        leading = inputs.size()[:-1]
        candidates = self.fc(inputs)
        grouped = candidates.view(*leading, self.d_out, self.pool_size)
        return grouped.max(dim=-1)[0]


# -

class MNISTMLP(nn.Module):
    """MLP classifier: a two-hidden-layer ReLU encoder plus a linear head.

    Args:
        indim: Input size.  Either a single number of features, or a
            tuple/list of per-sample dimensions, in which case the input is
            flattened (``nn.Flatten``) before the first linear layer.
        hiddim: Width of both hidden layers.
        outdim: Number of outputs produced by the prediction head.
    """

    def __init__(self, indim, hiddim, outdim):
        super().__init__()
        self.indim = indim  # indim is not necessarily an integer
        self.hiddim = hiddim
        self.outdim = outdim

        # Construct the encoder and the predictor.
        if isinstance(indim, (tuple, list)):
            # math.prod replaces np.prod: numpy is never imported in this
            # file, so the original call raised NameError on this branch.
            flat_dim = int(math.prod(indim))
            self.enc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(flat_dim, hiddim),
                nn.ReLU(),
                nn.Linear(hiddim, hiddim),
                nn.ReLU(),
            )
        else:
            self.enc = nn.Sequential(
                nn.Linear(indim, hiddim),
                nn.ReLU(),
                nn.Linear(hiddim, hiddim),
                nn.ReLU(),
            )

        self.pred = nn.Linear(hiddim, outdim)
        self.reset_parameters()

    def forward(self, x):
        """Encode ``x`` and return the prediction head's output (logits)."""
        feats = self.enc(x)
        logits = self.pred(feats)

        return logits

    def reset_parameters(self) -> None:
        """Re-initialize every linear layer.

        Weights are drawn with Xavier (Glorot) uniform initialization, i.e.
        uniformly from ``[-a, a]`` with ``a = sqrt(6 / (fan_in + fan_out))``;
        biases, where present, are set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
