import torch.nn as nn


class MLP(nn.Module):
    """
    A simple Multi-layer Perceptron model.

    Layer layout (stored in ``self.layers`` and applied in order)::

        Linear(input_dim, hidden_dim) -> LayerNorm(hidden_dim)
        -> [Linear(hidden_dim, hidden_dim) -> Dropout(dropout) -> activation]
           repeated (num_layers - 1) times
        -> Linear(hidden_dim, output_dim)

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        num_layers: Controls how many hidden Linear/Dropout/activation groups
            follow the first layer (``num_layers - 1`` of them). Values below 1
            build the same network as ``num_layers=1``.
        hidden_dim: Width of every hidden layer.
        dropout: Dropout probability used in each hidden group.
        activation: Either an ``nn.Module`` subclass (instantiated once here)
            or an already-constructed ``nn.Module`` instance.
        layer_norm: If True, ``forward`` standardizes each input row with
            ``normalize`` before the first layer. This flag does NOT control
            the ``nn.LayerNorm`` after the first Linear; that module is always
            present.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        num_layers,
        hidden_dim=1024,
        dropout=0.1,
        activation=nn.ReLU,
        layer_norm=True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.layer_norm = layer_norm
        # Accept either an activation class or a ready-made module instance.
        # NOTE: the same instance is appended to every hidden group below, so a
        # parametric activation (e.g. nn.PReLU) would share its parameters.
        self.activation = activation() if isinstance(activation, type) else activation
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        # LayerNorm standardizes each sample's hidden features to zero mean and
        # unit variance, then applies a learnable per-feature scale and shift.
        # It does NOT bound values to [0, 1].
        # NOTE(review): no activation follows this LayerNorm. Inserting one
        # would shift the ModuleList indices (and state_dict keys) of all later
        # layers, breaking existing checkpoints.
        self.layers.append(nn.LayerNorm(hidden_dim))

        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.Dropout(dropout))
            self.layers.append(self.activation)

        self.layers.append(nn.Linear(hidden_dim, output_dim))

    def normalize(self, x, eps=1e-6):
        """Standardize ``x`` along dim 1: subtract the mean, divide by the std.

        The std is torch's default (unbiased) estimator. It is clamped from
        below by ``eps`` so that a constant row (std == 0) maps to zeros
        instead of producing 0/0 = NaN. Rows with std >= ``eps`` give exactly
        the unclamped result.

        Args:
            x: Input tensor. Assumes a 2-D ``(batch, features)`` layout, where
                dim 1 is the feature axis. TODO confirm for higher-rank inputs,
                where dim 1 would not be the last (Linear) axis.
            eps: Lower bound applied to the standard deviation.

        Returns:
            Tensor with the same shape as ``x``.
        """
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True).clamp_min(eps)
        return (x - mean) / std

    def forward(self, x):
        """Run ``x`` through the optional input standardization and all layers."""
        if self.layer_norm:
            # Standardize each input row (zero mean, unit std) before layer 0.
            x = self.normalize(x)

        for layer in self.layers:
            x = layer(x)
        return x
