import torch.nn as nn
import torch.nn.utils.parametrizations as param


class SimpleMLP(nn.Module):
    def __init__(self, input_dim, output_dim, dropout_prob=0.2, use_layernorm=True):
        super().__init__()

        h1 = max(1, int((2 / 3) * input_dim))
        h2 = max(1, int((1 / 3) * input_dim))

        fc1 = param.weight_norm(nn.Linear(input_dim, h1))
        fc2 = param.weight_norm(nn.Linear(h1, h2))
        fc3 = nn.Linear(h2, output_dim)

        layers = [fc1]
        if use_layernorm:
            layers.append(nn.LayerNorm(h1))
        layers += [nn.ReLU(), nn.Dropout(dropout_prob)]

        layers.append(fc2)
        if use_layernorm:
            layers.append(nn.LayerNorm(h2))
        layers += [nn.ReLU(), nn.Dropout(dropout_prob)]

        layers.append(fc3)

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)