import einops
import torch
import torch.nn.functional as F
from torch import nn

from .functional import init_truncnormal_zero_bias


class DotProductAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        init_weights: str = "truncnormal002",
    ):
        super().__init__()
        assert hasattr(F, "scaled_dot_product_attention")
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.init_weights = init_weights

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.init_weights == "torch":
            pass
        elif self.init_weights in ["truncnormal", "truncnormal002"]:
            self.apply(init_truncnormal_zero_bias)
        else:
            raise NotImplementedError

    def forward(self, x: torch.Tensor, attn_mask: None | torch.Tensor = None) -> torch.Tensor:
        q, k, v = einops.rearrange(
            self.qkv(x),
            "bs seqlen (three num_heads head_dim) -> three bs num_heads seqlen head_dim",
            three=3,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
        ).unbind(0)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        x = einops.rearrange(x, "bs num_heads seqlen head_dim -> bs seqlen (num_heads head_dim)")
        x = self.proj(x)

        return x
