from torch import nn


class MLP(nn.Module):
    """Fully connected classifier head built from repeated hidden blocks.

    The network is ``num_layers`` hidden blocks, each ``Linear -> ELU -> Dropout``,
    followed by a final ``Linear`` projecting to ``num_classes``. No activation
    follows the final ``Linear``, so outputs are unnormalized scores (presumably
    consumed as logits by the caller -- confirm against the training loss).

    If ``num_layers`` is 0 (or negative), there are no hidden blocks and the model
    is a single ``Linear(in_features, num_classes)``.

    Args:
        in_features: Size of the last dimension of the input.
        num_classes: Size of the last dimension of the output.
        drop_prob: Dropout probability used in every hidden block.
        units: Width of every hidden ``Linear`` layer.
        num_layers: Number of hidden ``Linear -> ELU -> Dropout`` blocks.
    """

    def __init__(self, in_features, num_classes, drop_prob=0.15, units=128, num_layers=4):
        super().__init__()
        # Fan-in of each hidden Linear: the raw input width first, then the hidden
        # width for every later block. The last entry feeds the output layer.
        fan_ins = [in_features] + [units] * num_layers
        stack = []
        for width_in in fan_ins[:-1]:
            # Modules are constructed in the same Linear, ELU, Dropout order, so
            # seeded weight initialization and state_dict keys (layers.0, layers.3,
            # ...) remain stable.
            stack += [nn.Linear(width_in, units), nn.ELU(), nn.Dropout(drop_prob)]
        stack.append(nn.Linear(fan_ins[-1], num_classes))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_features)`` to ``(..., num_classes)``."""
        scores = self.layers(x)
        return scores
