import torch
from torch import nn

from .dot_product_attention import DotProductAttention
from .upactdown_mlp import UpActDownMlp


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: a residual attention sub-layer followed by
    a residual MLP sub-layer.

    Each sub-layer normalizes its input, transforms it, and adds the result
    back onto the un-normalized stream::

        x <- x + attn(norm1(x))
        x <- x + mlp(norm2(x))

    Args:
        dim: Feature width of the residual stream. It is handed to both norms,
            to the attention module, and to the MLP as its input width.
        num_heads: Number of attention heads, forwarded to ``attn_ctor``.
        mlp_hidden_dim: Hidden width of the MLP. Any falsy value (``None`` or
            ``0``) selects the default of ``dim * 4``.
        mlp_ctor: Factory for the MLP. It is called as
            ``mlp_ctor(input_dim=..., hidden_dim=..., init_weights=...)``.
        norm_ctor: Factory for the normalization layers. It is called as
            ``norm_ctor(dim, eps=eps)``.
        attn_ctor: Factory for the attention module. It is called as
            ``attn_ctor(dim=..., num_heads=..., init_weights=...)``.
        eps: Epsilon forwarded to both normalization layers.
        init_weights: Initialization scheme name, forwarded untouched to the
            attention and MLP factories (not to the norms). Its interpretation
            lives in those modules. ``"truncnormal002"`` presumably means a
            truncated normal with std 0.02; TODO confirm.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int | None = None,
        mlp_ctor: type = UpActDownMlp,
        norm_ctor: type = nn.LayerNorm,
        attn_ctor: type = DotProductAttention,
        eps: float = 1e-6,
        init_weights: str = "truncnormal002",
    ):
        super().__init__()
        # Same truthiness rule as `mlp_hidden_dim or dim * 4`: a falsy value
        # (None or 0) falls back to the conventional 4x expansion.
        hidden = mlp_hidden_dim if mlp_hidden_dim else dim * 4

        # Submodules are created and assigned in a fixed order
        # (norm1, attn, norm2, mlp). This keeps parameter registration order
        # and any RNG use during weight init deterministic. The attribute
        # names double as state_dict keys, so they must not change.
        self.norm1 = norm_ctor(dim, eps=eps)
        self.attn = attn_ctor(
            dim=dim,
            num_heads=num_heads,
            init_weights=init_weights,
        )
        self.norm2 = norm_ctor(dim, eps=eps)
        self.mlp = mlp_ctor(
            input_dim=dim,
            hidden_dim=hidden,
            init_weights=init_weights,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual sub-layers over ``x`` and return the result.

        The last axis of ``x`` is presumably ``dim``. The leading axes are
        presumably (batch, seq); TODO confirm against DotProductAttention.
        """
        attn_delta = self.attn(self.norm1(x))
        x = x + attn_delta
        mlp_delta = self.mlp(self.norm2(x))
        return x + mlp_delta