from typing import List

import torch.nn as nn


class MLP(nn.Module):
    """Stack of fully connected stages built from a list of layer widths.

    For every consecutive pair ``(d_in, d_out)`` in ``features`` one stage is
    emitted::

        Linear(d_in, d_out) -> BatchNorm1d(d_out) -> LeakyReLU [-> Dropout]

    A stage follows every Linear, including the last one, so the network's
    output is batch-normalized and activated as well. A ``features`` list with
    fewer than two entries produces an empty stack, i.e. an identity map.

    Args:
        features: Layer widths, input width first and output width last.
        momentum: ``momentum`` passed to every ``BatchNorm1d``.
        negative_slope: Slope of ``LeakyReLU`` for negative inputs; the default
            ``0.0`` makes it behave like a plain ReLU.
        dropout: Dropout probability; a ``Dropout`` module is only inserted
            into each stage when this is strictly positive.
    """

    def __init__(self, features: List[int], momentum: float = 0.1, negative_slope: float = 0.0, dropout: float = 0.0):
        super().__init__()

        stages = []
        for width_in, width_out in zip(features[:-1], features[1:]):
            stages += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out, momentum=momentum),
                nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
            ]
            if dropout > 0.0:
                stages.append(nn.Dropout(p=dropout))

        # Passing the modules positionally yields the same "0", "1", ... child
        # names as repeated ``append`` calls, so state_dict keys are unchanged.
        self.layers = nn.Sequential(*stages)

    def forward(self, input):
        """Run ``input`` through every stage in order and return the result.

        ``input`` is presumably a ``(batch, features[0])`` tensor — TODO confirm
        against callers.
        """
        return self.layers(input)
