"""Transformer modules used to build ACEv2 model."""

from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.nn.attention.flex_attention import BlockMask, flex_attention

from src.models.utils import (
    expand_kv_heads,
)


class MultiheadAttention(torch.nn.Module):
    """Multi-head attention with flexible key-value head configuration."""

    def __init__(
        self,
        in_features: int,
        num_heads: int,
        head_dim: Optional[int] = None,
        out_features: Optional[int] = None,
        key_features: Optional[int] = None,
        value_features: Optional[int] = None,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features or in_features
        self.num_heads = num_heads
        self.head_dim = head_dim or in_features // num_heads
        self.num_kv_heads = num_kv_heads or num_heads

        self.q_proj = torch.nn.Linear(
            in_features, num_heads * self.head_dim, bias=False
        )
        self.k_proj = torch.nn.Linear(
            key_features or in_features, self.num_kv_heads * self.head_dim, bias=False
        )
        self.v_proj = torch.nn.Linear(
            value_features or in_features, self.num_kv_heads * self.head_dim, bias=False
        )
        self.o_proj = torch.nn.Linear(
            num_heads * self.head_dim, self.out_features, bias=False
        )

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_mask: BlockMask
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass with flex attention."""
        B = q.size(0)
        qh = (
            self.q_proj(q)
            .view(B, q.size(1), self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        kh = (
            self.k_proj(k)
            .view(B, k.size(1), self.num_kv_heads, self.head_dim)
            .transpose(1, 2)
        )
        vh = (
            self.v_proj(v)
            .view(B, v.size(1), self.num_kv_heads, self.head_dim)
            .transpose(1, 2)
        )
        kh = expand_kv_heads(kh, self.num_heads // self.num_kv_heads)
        vh = expand_kv_heads(vh, self.num_heads // self.num_kv_heads)

        out = flex_attention(qh, kh, vh, block_mask=block_mask)
        out = out.transpose(1, 2).reshape(B, q.size(1), self.num_heads * self.head_dim)
        return self.o_proj(out), (kh, vh)


class TransformerLayer(torch.nn.Module):
    """One pre-norm transformer block: self-attention followed by a GELU MLP.

    Each sub-block normalises its input with LayerNorm and adds its result
    back onto the un-normalised residual stream.
    """

    def __init__(
        self,
        dim_model: int,
        num_head: int,
        *,
        dim_feedforward: int = 512,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        **mha_kw: Any,
    ) -> None:
        """Create the attention, feed-forward, dropout and norm sub-modules.

        Args:
            dim_model: Width of the residual stream.
            num_head: Number of attention heads.
            dim_feedforward: Hidden width of the feed-forward MLP.
            dropout: Dropout probability used after attention and inside the
                MLP.
            layer_norm_eps: Epsilon for both LayerNorm modules.
            **mha_kw: Extra keyword arguments forwarded to
                ``MultiheadAttention``.
        """
        super().__init__()

        # Attention sub-block.
        self.attn = MultiheadAttention(dim_model, num_head, **mha_kw)
        self.drop_attn = torch.nn.Dropout(dropout)

        # Feed-forward sub-block: dim_model -> dim_feedforward -> dim_model.
        self.ff1 = torch.nn.Linear(dim_model, dim_feedforward)
        self.ff2 = torch.nn.Linear(dim_feedforward, dim_model)
        self.drop_ff = torch.nn.Dropout(dropout)

        # One LayerNorm in front of each sub-block.
        self.norm1 = torch.nn.LayerNorm(dim_model, eps=layer_norm_eps)
        self.norm2 = torch.nn.LayerNorm(dim_model, eps=layer_norm_eps)

    def forward(
        self, x: torch.Tensor, block_mask: BlockMask
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run self-attention and the feed-forward MLP with residual adds.

        Args:
            x: Input to the block (the residual stream).
            block_mask: Mask passed through to the attention module.

        Returns:
            ``(output, (k, v))`` where ``output`` is the updated residual
            stream and ``k``/``v`` are the key/value tensors returned by the
            attention module.
        """
        # Self-attention: the same normalised tensor serves as q, k and v.
        normed = self.norm1(x)
        attn_out, (key_heads, value_heads) = self.attn(
            normed, normed, normed, block_mask
        )
        x = x + self.drop_attn(attn_out)

        # Feed-forward: dropout is applied to the activated hidden state only.
        hidden = torch.nn.functional.gelu(self.ff1(self.norm2(x)))
        ff_out = self.ff2(self.drop_ff(hidden))
        return x + ff_out, (key_heads, value_heads)


class Transformer(torch.nn.Module):
    """Multi-layer transformer encoder with optional gradient checkpointing."""

    def __init__(
        self,
        num_layers: int,
        dim_model: int,
        num_head: int,
        *,
        dim_feedforward: int = 512,
        dropout: float = 0.0,
        gradient_checkpointing: bool = False,
        **mha_kw: Any,
    ) -> None:
        super().__init__()
        self.dim_model = dim_model
        self.layers = torch.nn.ModuleList(
            TransformerLayer(
                dim_model,
                num_head,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                **mha_kw,
            )
            for _ in range(num_layers)
        )
        self.norm = torch.nn.LayerNorm(dim_model)
        self.grad_ckpt = gradient_checkpointing

    def forward(
        self, x: torch.Tensor, block_mask: BlockMask
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass through all layers with optional gradient checkpointing."""
        kv_cache = []
        for lyr in self.layers:
            if self.grad_ckpt and self.training:
                x = torch.utils.checkpoint.checkpoint(lyr, x, block_mask)
            else:
                x, kv = lyr(x, block_mask)
                kv_cache.append(kv)
        return self.norm(x), kv_cache
