import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(
        self,
        *,
        input_dim: int,
        hidden_dim: int | None = None,
        output_dim: int,
        use_bias: bool = False,
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = input_dim

        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias)
        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias)
        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=use_bias)
        self.act_fn = F.gelu

    def forward(self, x):
        down_proj = self.down_proj(
            self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        )
        return down_proj
