import torch.nn as nn


class MLP(nn.Module):
    """Multi-layer perceptron built from a stack of ``nn.Linear`` layers.

    The layer widths are ``input_size``, each entry of ``hidden_sizes`` in
    order, then ``output_size``; one linear layer connects each adjacent
    pair, giving ``len(hidden_sizes) + 1`` layers in total.

    ReLU is applied only after layers that are neither the first nor the
    last one; the outputs of the first and last layers are left linear.

    Attributes:
        input_size: Value passed as ``input_size``.
        hidden_sizes: Value passed as ``hidden_sizes`` (stored as given).
        output_size: Value passed as ``output_size``.
        activations: Per-layer outputs recorded by the most recent
            ``forward(..., keep_activations=True)`` call; starts empty.
        activations_from_abs_input: Initialised to ``None``; never written
            by this class.
        layers: ``nn.ModuleList`` holding the linear layers.
        non_linearity: The ``nn.ReLU`` module used between layers.
        uses_bias: Value passed as ``bias``.
        alpha: Multiplier applied to the network output; initialised to 1.
    """

    def __init__(self, input_size: int, hidden_sizes, output_size: int,
                 bias: bool = True):
        """Create the linear layers.

        Args:
            input_size: Number of input features of the first layer.
            hidden_sizes: Iterable of hidden layer widths (list, tuple, ...).
                May be empty, in which case a single linear layer maps
                ``input_size`` to ``output_size``.
            output_size: Number of output features of the last layer.
            bias: Whether every linear layer has a bias term.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.activations = []
        self.activations_from_abs_input = None
        self.layers = nn.ModuleList()
        self.non_linearity = nn.ReLU()
        self.uses_bias = bias
        self.alpha = 1
        # Unpacking (rather than list concatenation) accepts tuples and other
        # iterables for hidden_sizes, not only lists.
        layer_sizes = [input_size, *hidden_sizes, output_size]
        for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
            self.layers.append(nn.Linear(fan_in, fan_out, bias=bias))

    def forward(self, x, keep_activations: bool = False):
        """Run the network on ``x`` and return the output scaled by ``alpha``.

        Args:
            x: Input tensor. All dimensions from dim 1 onward are flattened
                into one before the first layer (dim 0 is presumably the
                batch dimension).
            keep_activations: When true, ``self.activations`` is reset and
                then filled with the output of every layer (after ReLU where
                ReLU is applied). The last entry is the final layer output
                before the ``alpha`` scaling.

        Returns:
            The last layer's output multiplied by ``self.alpha``.
        """
        x = x.flatten(start_dim=1)
        if keep_activations:
            self.activations = []
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # NOTE(review): the first layer is deliberately kept linear here
            # (no ReLU for i == 0), matching the existing behaviour; confirm
            # this is intended.
            if 0 < i < last_index:
                x = self.non_linearity(x)
            if keep_activations:
                self.activations.append(x)
        return x * self.alpha
