import torch
import torch.nn as nn
import utils


class MLP(nn.Module):
    """Fully connected feed-forward network wrapped in an ``nn.Sequential``.

    For every consecutive pair of widths except the last one, a hidden block
    ``Linear -> [BatchNorm1d] -> activation -> [Dropout]`` is appended. The
    final pair gets a plain ``Linear``, optionally followed by the activation
    and/or a ``Sigmoid``.

    Args:
        dims_all: Layer widths ``[dim_in, hidden_1, ..., dim_out]``; must
            contain at least two entries (input and output width).
        activation: Forwarded to ``utils.actmodule`` to build each activation
            module (presumably a name/spec of the nonlinearity -- confirm
            against ``utils.actmodule``).
        dropout: Dropout probability used after each hidden activation. Any
            value ``<= 0.0`` (the default ``-1.0``) disables dropout.
        batchnorm: If True, insert ``BatchNorm1d`` after each hidden Linear.
        actfun_output: If True, also apply the activation after the output
            Linear.
        binary_output: If True, append a ``Sigmoid`` as the final module.

    Raises:
        ValueError: If ``dims_all`` has fewer than two entries.
    """

    def __init__(self, dims_all, activation, dropout=-1.0, batchnorm=False,
                 actfun_output=False, binary_output=False):
        super().__init__()
        # Materialize once so any iterable of widths is accepted and can be
        # both measured and sliced below.
        dims = list(dims_all)
        if len(dims) < 2:
            raise ValueError(
                f"dims_all needs at least an input and an output width, got {dims!r}"
            )

        modules = []
        # Hidden layers: one block per consecutive (in, out) pair, excluding
        # the final (hidden -> output) pair handled after the loop.
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            modules.append(nn.Linear(d_in, d_out))
            if batchnorm:
                modules.append(nn.BatchNorm1d(d_out))
            modules.append(utils.actmodule(activation))
            if dropout > 0.0:
                modules.append(nn.Dropout(p=dropout))

        # Output layer, with optional output nonlinearity and/or sigmoid.
        modules.append(nn.Linear(dims[-2], dims[-1]))
        if actfun_output:
            modules.append(utils.actmodule(activation))
        if binary_output:
            modules.append(nn.Sigmoid())

        self.net = nn.Sequential(*modules)
        # Input width; forward() flattens every input to (-1, dim_in).
        self.dim_in = dims[0]

    def forward(self, x):
        """Flatten ``x`` to ``(-1, dim_in)`` and run it through the network.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs
        (e.g. transposed or strided slices) are accepted; it only copies when
        a view is impossible. The total number of elements in ``x`` must be
        divisible by ``dim_in``.

        Returns:
            Tensor of shape ``(N, dims_all[-1])`` where ``N = x.numel() // dim_in``.
        """
        return self.net(x.reshape(-1, self.dim_in))