import torch.nn as nn
import torch.nn.functional as F


class SimpleNet(nn.Module):
    """Fully connected classifier: ReLU hidden layers and a log-softmax output.

    Args:
        dims: Layer widths ``[in, h1, ..., out]``. One ``nn.Linear`` is built
            for each consecutive pair, giving ``len(dims) - 1`` layers.
        mnist_like: If True, ``forward`` first reshapes its input with
            ``view(-1, 784)``. 784 = 28 * 28, presumably for flattening
            MNIST-style images, so ``dims[0]`` should then be 784.
    """

    def __init__(self, dims, mnist_like=False):
        super(SimpleNet, self).__init__()
        self.mnist_like = mnist_like
        # Layers are created (and randomly initialised) in the same order as
        # the pairs of ``dims``, so they register as layers.0, layers.1, ...
        self.layers = nn.ModuleList(
            nn.Linear(i, o) for i, o in zip(dims, dims[1:])
        )

    def forward(self, x):
        """Run the network and return log-probabilities over the last dim.

        Args:
            x: Input tensor whose last dimension is ``dims[0]`` (assumed
                ``(batch, dims[0])`` — TODO confirm with callers). When
                ``mnist_like`` is set, any tensor with a multiple of 784
                elements is accepted and flattened to ``(-1, 784)``.

        Returns:
            ``F.log_softmax`` of the final layer's output along dim ``-1``.

        Raises:
            IndexError: if the network has no layers (``len(dims) < 2``).
        """
        if self.mnist_like:
            x = x.view(-1, 784)
        # Rebind ``x`` instead of collecting every intermediate output: only
        # the final result is returned, and not holding references lets
        # hidden activations be freed early (e.g. under torch.no_grad()).
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        # The last layer gets no ReLU; its logits go straight to log-softmax.
        return F.log_softmax(self.layers[-1](x), -1)


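# NOTE(review): disabled code. Re-enabling it requires `import torch`; this file only imports torch.nn (as nn) and torch.nn.functional (as F).
# class SimpleNetGauss(torch.nn.Module):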

#     def __init__(self, dims, mnist_like=False):
#         super(SimpleNetGauss, self).__init__()
#         dims_tup = list(zip(dims, dims[1:]))
#         self.mnist_like = mnist_like
#         self.layers = torch.nn.ModuleList()
#         for i, o in dims_tup:
#             self.layers.append(torch.nn.Linear(i, o))

#     def forward(self, x):
#         for i, layer in enumerate(self.layers[:-1]):
#             x = F.relu(layer(x))
#         return torch.sigmoid(self.layers[-1](x))
