from __future__ import annotations

from typing import Callable, Optional

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from .rope import RotaryEmbedding


def _sdpa_flat(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0) -> Tensor:
    """Scaled dot-product attention with flattened batch dims to enable FlashAttention."""
    orig_shape = q.shape
    q = q.reshape(-1, *q.shape[-3:])
    k = k.reshape(-1, *k.shape[-3:])
    v = v.reshape(-1, *v.shape[-3:])
    if attn_mask is not None:
        attn_mask = attn_mask.reshape(-1, *attn_mask.shape[-3:])
    out = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p)
    return out.view(orig_shape)


def _to_additive_mask(mask: Optional[Tensor], dtype: torch.dtype) -> Optional[Tensor]:
    """Return *mask* as an additive float mask suitable for summing with other masks.

    Boolean masks follow the ``nn.MultiheadAttention`` convention (``True`` = position must NOT
    be attended to). They become ``-inf`` at masked positions and ``0`` elsewhere, in ``dtype``.
    ``None`` and non-boolean (already additive) masks are returned unchanged.
    """
    if mask is None or mask.dtype != torch.bool:
        return mask
    return torch.zeros(mask.shape, dtype=dtype, device=mask.device).masked_fill_(mask, float("-inf"))


def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    attn_mask: Optional[Tensor | int] = None,
    rope: Optional[RotaryEmbedding] = None,
) -> Tensor:
    """Multi-head attention with optional RoPE and tensor or integer ("split") masks.

    Accepts inputs of shape (..., tgt_len, embed_dim). Any number of leading batch dims is
    kept through the projections and flattened only for the attention kernel.

    Args:
        query: ``(..., tgt_len, embed_dim)``.
        key, value: ``(..., src_len, embed_dim)``. They must share a shape, and their leading
            dims must match ``query``'s.
        num_heads: Number of heads; must divide ``embed_dim``.
        in_proj_weight, in_proj_bias: Packed q/k/v input projection.
        dropout_p: Attention dropout probability; ignored when ``training`` is False.
        out_proj_weight, out_proj_bias: Output projection.
        training: Whether attention dropout is applied.
        key_padding_mask: Optional ``(..., src_len)`` mask over keys. For a boolean mask,
            ``True`` marks keys to ignore. A float mask is added to the attention scores.
        attn_mask: Either
            * a tensor, given as ``(tgt_len, src_len)`` (broadcast over batch and heads). Any
              other tensor is passed through and is expected to be
              ``(..., num_heads, tgt_len, src_len)``. Boolean ``True`` marks disallowed
              positions; float values are added to the scores. Or
            * an int ``cut_pos``: every query attends only to the first ``cut_pos`` keys/values.
              This form cannot be combined with ``key_padding_mask`` or ``rope``.
        rope: Optional rotary embedding applied to the per-head queries and keys.

    Returns:
        ``(..., tgt_len, embed_dim)`` tensor.

    Raises:
        ValueError: On a mis-shaped 2-D ``attn_mask`` or ``key_padding_mask``, or an integer
            ``attn_mask`` outside ``[0, tgt_len]``.
    """

    if isinstance(attn_mask, int):
        assert key_padding_mask is None, "key_padding_mask not supported with integer attn_mask"
        assert rope is None, "RoPE not supported with integer attn_mask"

    *batch_shape, tgt_len, embed_dim = query.shape
    src_len = key.shape[-2]
    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim
    assert key.shape == value.shape

    q, k, v = F._in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)

    q = q.view(*batch_shape, tgt_len, num_heads, head_dim).transpose(-3, -2)  # (..., nh, tgt, hd)
    k = k.view(*batch_shape, src_len, num_heads, head_dim).transpose(-3, -2)  # (..., nh, src, hd)
    v = v.view(*batch_shape, src_len, num_heads, head_dim).transpose(-3, -2)  # (..., nh, src, hd)

    if rope is not None:
        q = rope.rotate_queries_or_keys(q)
        k = rope.rotate_queries_or_keys(k)

    if not training:
        dropout_p = 0.0

    if isinstance(attn_mask, int):
        # Integer split mask: queries before and after cut_pos all attend to the same unmasked
        # prefix of keys/values, so a single SDPA call over every query is equivalent to
        # handling the two halves separately.
        cut_pos = attn_mask
        if not 0 <= cut_pos <= tgt_len:
            raise ValueError(f"integer attn_mask must be in [0, {tgt_len}], got {cut_pos}")
        attn_output = _sdpa_flat(q, k[..., :cut_pos, :], v[..., :cut_pos, :], dropout_p=dropout_p)
    else:
        # Tensor masks path. Masks are made additive *before* broadcasting, so bool masks keep
        # the "True = masked" convention and can be safely summed with one another.
        attn_mask = _to_additive_mask(attn_mask, q.dtype)
        if attn_mask is not None and attn_mask.dim() == 2:
            correct_2d = (tgt_len, src_len)
            if attn_mask.shape != correct_2d:
                raise ValueError(f"2D attn_mask should have shape {correct_2d}, got {attn_mask.shape}")
            attn_mask = attn_mask.expand(*batch_shape, num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            if key_padding_mask.shape != (*batch_shape, src_len):
                raise ValueError(
                    f"key_padding_mask should have shape {(*batch_shape, src_len)}, got {key_padding_mask.shape}"
                )
            key_padding_mask = _to_additive_mask(key_padding_mask, q.dtype)
            key_padding_mask = key_padding_mask.reshape(*batch_shape, 1, 1, src_len).expand(
                *batch_shape, num_heads, tgt_len, src_len
            )
            attn_mask = key_padding_mask if attn_mask is None else (attn_mask + key_padding_mask)

        attn_output = _sdpa_flat(q, k, v, attn_mask, dropout_p)

    # (..., nh, tgt, hd) -> (..., tgt, nh * hd); a single output projection serves both paths.
    attn_output = attn_output.transpose(-3, -2).reshape(*batch_shape, tgt_len, embed_dim)
    return F.linear(attn_output, out_proj_weight, out_proj_bias)


class MultiheadAttention(nn.MultiheadAttention):
    """Batch-first ``nn.MultiheadAttention`` that also accepts RoPE and integer split masks.

    The parent class creates the parameters (packed in-projection, output projection, dropout
    rate). The actual computation is delegated to ``multi_head_attention_forward``.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__(embed_dim, num_heads, dropout=dropout, batch_first=True)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor | int] = None,
        rope: Optional[RotaryEmbedding] = None,
    ) -> Tensor:
        """Attend from ``query`` to ``key``/``value``.

        ``key_padding_mask``, ``attn_mask`` and ``rope`` are handed unchanged to
        ``multi_head_attention_forward``. An integer ``attn_mask`` may not be combined with a
        ``key_padding_mask`` or with ``rope``. Attention dropout is used only in training mode.
        """
        uses_split_mask = isinstance(attn_mask, int)
        if uses_split_mask:
            assert key_padding_mask is None, "key_padding_mask is not supported with attn_mask as int"
            assert rope is None, "Rotary position embedding is not supported with attn_mask as int"

        out_proj = self.out_proj
        return multi_head_attention_forward(
            query,
            key,
            value,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            dropout_p=self.dropout,
            out_proj_weight=out_proj.weight,
            out_proj_bias=out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            rope=rope,
        )


class MultiheadAttentionBlock(nn.TransformerEncoderLayer):
    """TransformerEncoderLayer variant that routes RoPE (and integer split masks) into attention.

    The feed-forward sublayer, layer norms and dropouts come from
    ``nn.TransformerEncoderLayer`` (always built with ``batch_first=True``). The stock
    ``self_attn`` is replaced by the local ``MultiheadAttention``, stored as ``self.attn``.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float = 0.0,
        activation: str | Callable[[Tensor], Tensor] = "gelu",
        norm_first: bool = True,
    ):
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation, norm_first=norm_first, batch_first=True)
        # The parent's attention module is never used by this class's forward; deleting it keeps
        # its parameters out of the module tree.
        del self.self_attn
        self.attn = MultiheadAttention(d_model, nhead, dropout)
        self.init_weights()

    def init_weights(self) -> None:
        """Zero the attention output projection and the last feed-forward linear layer.

        Both residual branches therefore start out producing zeros. With ``norm_first=True``, a
        freshly initialized block returns its query input unchanged.
        """
        nn.init.zeros_(self.attn.out_proj.weight)
        nn.init.zeros_(self.attn.out_proj.bias)
        nn.init.zeros_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)

    def forward(
        self,
        q: Tensor,
        k: Optional[Tensor] = None,
        v: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor | int] = None,
        rope: Optional[RotaryEmbedding] = None,
    ) -> Tensor:
        """Apply the attention and feed-forward sublayers with residual connections.

        Args:
            q: Query input. It is also used as key/value when ``k``/``v`` are omitted.
            k, v: Optional key and value inputs.
            key_padding_mask, attn_mask, rope: Forwarded to the attention module. An integer
                ``attn_mask`` cannot be combined with ``key_padding_mask`` or ``rope``.

        Returns:
            Tensor with the same shape as ``q``.
        """
        if isinstance(attn_mask, int):
            assert key_padding_mask is None, "key_padding_mask is not supported with attn_mask as int"
            assert rope is None, "Rotary position embedding is not supported with attn_mask as int"

        k = q if k is None else k
        v = q if v is None else v

        x = q
        if self.norm_first:
            # Normalize each *distinct* input only once. Reusing the same tensor object when
            # k/v alias q avoids redundant norm passes. It also preserves the q/k/v identity,
            # in case the in-projection has an identity-based fast path for self-attention.
            q_in = self.norm1(q)
            k_in = q_in if k is q else self.norm1(k)
            if v is q:
                v_in = q_in
            elif v is k:
                v_in = k_in
            else:
                v_in = self.norm1(v)
            x = x + self._attn_block(q_in, k_in, v_in, key_padding_mask, attn_mask, rope)
            x = x + self._ff_block(self.norm2(x))
        else:
            attn = self._attn_block(q, k, v, key_padding_mask, attn_mask, rope)
            x = self.norm1(x + attn)
            x = self.norm2(x + self._ff_block(x))
        return x

    def _attn_block(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        key_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor | int],
        rope: Optional[RotaryEmbedding],
    ) -> Tensor:
        """Attention sublayer (without residual/norm), followed by residual dropout."""
        attn = self.attn(q, k, v, key_padding_mask, attn_mask, rope)
        return self.dropout1(attn)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward sublayer: linear -> activation -> dropout -> linear -> dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class Encoder(nn.Module):
    """A stack of ``MultiheadAttentionBlock`` layers sharing one optional rotary embedding.

    Note: when ``use_rope`` is True, the shared embedding is passed to every block. An integer
    ``attn_mask`` then trips the blocks' "RoPE not supported" assertion.
    """

    def __init__(
        self,
        num_blocks: int,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        dropout: float = 0.0,
        activation: str = "gelu",
        norm_first: bool = True,
        use_rope: bool = True,
        rope_base: float = 30000.0,
    ) -> None:
        """
        Args:
            num_blocks: Number of stacked blocks.
            d_model, nhead, dim_feedforward, dropout, activation, norm_first: Passed to every block.
            use_rope: Whether to build a ``RotaryEmbedding`` over the per-head dimension
                (``d_model // nhead``).
            rope_base: ``theta`` of that rotary embedding.
        """
        super().__init__()
        block_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            norm_first=norm_first,
        )
        self.blocks = nn.ModuleList(MultiheadAttentionBlock(**block_kwargs) for _ in range(num_blocks))
        if use_rope:
            self.rope = RotaryEmbedding(dim=d_model // nhead, theta=rope_base)
        else:
            self.rope = None

    def forward(
        self,
        src: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor | int] = None,
    ) -> Tensor:
        """Feed ``src`` through every block in order as self-attention and return the result.

        The same masks and the same rotary embedding are given to each block.
        """
        hidden = src
        for layer in self.blocks:
            hidden = layer(q=hidden, key_padding_mask=key_padding_mask, attn_mask=attn_mask, rope=self.rope)
        return hidden

