import torch
import torch.nn.functional as F


class MLP(torch.nn.Module):
    """
    Simple 2 layer MLP computing ``fc2(activation(fc1(x)))``.

    Args:
        input_dim: Number of input features of the first linear layer.
        hidden_dim: Number of features between the two linear layers.
        output_dim: Number of output features of the second linear layer.
        activation: Callable applied between the two linear layers. When
            ``None`` (the default), every instance gets its own
            ``torch.nn.ReLU()``.
        bias: Passed as ``bias`` to both linear layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, activation=None, bias=True):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim, bias=bias)
        # Build the default here rather than in the signature: a module placed
        # in the signature is created once and then shared by every instance,
        # so hooks or attribute changes on it would leak between models.
        self.activation = torch.nn.ReLU() if activation is None else activation
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        """Apply ``fc1``, the activation, then ``fc2`` to ``x`` and return the result."""
        return self.fc2(self.activation(self.fc1(x)))


class GatedMLP(torch.nn.Module):
    """
    Two-layer MLP whose hidden features are modulated by a learned gate.

    ``fc1`` produces ``2 * hidden_dim`` features in a single projection. The
    first half is the value branch and the second half is the gate branch;
    the value is multiplied element-wise by ``activation(gate)`` and the
    product is passed through ``fc2``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, activation=F.silu, bias=True):
        super().__init__()

        # Sizes are recorded as plain attributes for later inspection.
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.output_size = output_dim

        self.activation = activation

        # One projection yields both branches, hence the doubled width.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim * 2, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x):
        """
        x: (batch_size, input_dim) Input tensor

        Returns ``fc2(value * activation(gate))``, where ``value`` and ``gate``
        are the two halves of ``fc1(x)`` along the last dimension.
        """
        projected = self.fc1(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        gated = value * self.activation(gate)
        return self.fc2(gated)
