import torch.nn as nn


class MLPWithLayerNorm(nn.Module):
    """Multi-layer perceptron with LayerNorm + ReLU after every hidden Linear.

    Layout (``num_layers`` counts the Linear layers)::

        Linear(input_dim, hidden_dim) -> LayerNorm -> ReLU
        [Linear(hidden_dim, hidden_dim) -> LayerNorm -> ReLU] * (num_layers - 2)
        Linear(hidden_dim, output_dim)

    ``nn.Linear`` and ``nn.LayerNorm(hidden_dim)`` both operate on the last
    dimension, so inputs of shape ``(*, input_dim)`` map to ``(*, output_dim)``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every hidden layer.
        output_dim: Size of the last dimension of the output.
        num_layers: Total number of Linear layers; must be at least 2.

    Raises:
        ValueError: If ``num_layers < 2``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3):
        super().__init__()
        # Previously any value below 2 was silently treated as 2 (the hidden
        # loop just ran zero times); reject it so the config is not misleading.
        if num_layers < 2:
            raise ValueError(
                f"num_layers must be >= 2 (input and output Linear layers), got {num_layers}"
            )
        self.layers = nn.ModuleList()

        # input layers
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        self.layers.append(nn.LayerNorm(hidden_dim))
        self.layers.append(nn.ReLU())

        # hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.LayerNorm(hidden_dim))
            self.layers.append(nn.ReLU())

        # output layer (no normalization / activation)
        self.layers.append(nn.Linear(hidden_dim, output_dim))

    def forward(self, x):
        """Apply every layer in order and return the ``(*, output_dim)`` result."""
        for layer in self.layers:
            x = layer(x)
        return x


class MLPWithGroupNorm(nn.Module):
    """Multi-layer perceptron with GroupNorm + ReLU after every hidden Linear.

    Layout (``num_layers`` counts the Linear layers)::

        Linear(input_dim, hidden_dim) -> GroupNorm -> ReLU
        [Linear(hidden_dim, hidden_dim) -> GroupNorm -> ReLU] * (num_layers - 2)
        Linear(hidden_dim, output_dim)

    ``nn.GroupNorm`` expects channels on dim 1 (``(N, C, *)``), while
    ``nn.Linear`` produces features on the last dim. ``forward`` therefore
    flattens all leading dims to ``(N, C)`` around each GroupNorm call. This
    lets inputs of shape ``(*, input_dim)`` map to ``(*, output_dim)``,
    matching the LayerNorm variant.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every hidden layer; must be divisible by
            ``num_groups`` (``nn.GroupNorm`` raises otherwise).
        output_dim: Size of the last dimension of the output.
        num_layers: Total number of Linear layers; must be at least 2.
        num_groups: Number of channel groups for each GroupNorm.

    Raises:
        ValueError: If ``num_layers < 2``, or (from ``nn.GroupNorm``) if
            ``hidden_dim`` is not divisible by ``num_groups``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, num_groups=32):
        super().__init__()
        # Previously any value below 2 was silently treated as 2; reject it.
        if num_layers < 2:
            raise ValueError(
                f"num_layers must be >= 2 (input and output Linear layers), got {num_layers}"
            )
        self.layers = nn.ModuleList()

        # input layers
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        self.layers.append(nn.GroupNorm(num_groups, hidden_dim))
        self.layers.append(nn.ReLU())

        # hidden layers
        for _ in range(num_layers - 2):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.layers.append(nn.GroupNorm(num_groups, hidden_dim))
            self.layers.append(nn.ReLU())

        # output layer (no normalization / activation)
        self.layers.append(nn.Linear(hidden_dim, output_dim))

    @staticmethod
    def _group_norm_last_dim(norm, x):
        """Apply ``norm`` treating the LAST dim of ``x`` as the channel axis."""
        shape = x.shape
        # (*, C) -> (N, C) so GroupNorm sees channels on dim 1, then restore.
        return norm(x.reshape(-1, shape[-1])).reshape(shape)

    def forward(self, x):
        """Apply every layer in order and return the ``(*, output_dim)`` result."""
        for layer in self.layers:
            if isinstance(layer, nn.GroupNorm):
                x = self._group_norm_last_dim(layer, x)
            else:
                x = layer(x)
        return x

