import torch
from torch import nn

from .functional import init_truncnormal_zero_bias


class UpActDownMlp(nn.Module):
    """Two-layer MLP: linear up-projection, GELU, linear down-projection.

    The input's last dimension is mapped ``input_dim -> hidden_dim`` by
    ``fc1``, passed through ``nn.GELU``, then mapped back
    ``hidden_dim -> input_dim`` by ``fc2``, so the output has the same
    shape as the input.

    Args:
        input_dim: Size of the last dimension of the input (and output).
        hidden_dim: Size of the intermediate representation.
        init_weights: Name of the initialization scheme. ``"torch"`` keeps
            the parameters exactly as ``nn.Linear`` created them;
            ``"truncnormal"`` and ``"truncnormal002"`` both apply
            ``init_truncnormal_zero_bias`` to this module and all of its
            submodules.

    Raises:
        NotImplementedError: If ``init_weights`` is not one of the schemes
            listed above.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        init_weights: str = "truncnormal002",
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.init_weights = init_weights
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        # Must run after the layers exist: it re-initializes them in place.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re-)initialize the parameters according to ``self.init_weights``.

        Raises:
            NotImplementedError: If ``self.init_weights`` names an unknown
                scheme.
        """
        if self.init_weights == "torch":
            # Keep whatever nn.Linear set up when the layers were constructed.
            return
        if self.init_weights in ("truncnormal", "truncnormal002"):
            # Both names currently share one initializer.
            # NOTE(review): the helper presumably draws weights from a
            # truncated normal (std 0.02?) and zeroes the biases -- confirm
            # against its definition in .functional.
            self.apply(init_truncnormal_zero_bias)
            return
        # Name the offending value so a mistyped config is easy to diagnose.
        raise NotImplementedError(
            f"Unsupported init_weights={self.init_weights!r}; expected one of "
            "'torch', 'truncnormal', 'truncnormal002'."
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1 -> GELU -> fc2`` to ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
