import torch
from torch import nn

from .perceiver_attention import PerceiverAttention
from .upactdown_mlp import UpActDownMlp


class PerceiverBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        kv_dim: int | None = None,
        mlp_hidden_dim: int | None = None,
        norm_ctor: type = nn.LayerNorm,
        eps: float = 1e-6,
        init_weights: str = "truncnormal002",
    ):
        super().__init__()
        self.norm1q = norm_ctor(dim, eps=eps)
        self.norm1kv = norm_ctor(kv_dim or dim, eps=eps)
        self.attn = PerceiverAttention(
            dim=dim,
            num_heads=num_heads,
            kv_dim=kv_dim,
            init_weights=init_weights,
        )
        self.norm2 = norm_ctor(dim, eps=eps)
        self.mlp = UpActDownMlp(
            input_dim=dim,
            hidden_dim=mlp_hidden_dim or dim * 4,
            init_weights=init_weights,
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        q = q + self.attn(q=self.norm1q(q), kv=self.norm1kv(kv))
        q = q + self.mlp(self.norm2(q))
        return q
