import einops
import torch
import torch.nn.functional as F
from torch import nn

from kappamodules.functional.pos_embed import relative_position_indices
from kappamodules.init import (
    init_xavier_uniform_zero_bias,
    init_xavier_uniform_merged_linear,
    init_truncnormal_zero_bias,
)

from src.modules.rope import rope_rotation
from src.utils.attn_sinks import AttentionSinks


class AnchoredDotProductAttention(nn.Module):
    """Multi-head attention where only a prefix of "anchor" tokens provides keys/values.

    Every token of the input produces a query. Keys and values are computed only
    from the first ``n_anchors`` tokens, so each token attends exclusively to the
    anchors (plus optional attention sinks). A negative ``n_anchors`` disables
    anchoring: every token is used as an anchor.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        n_anchors: Number of leading tokens used as keys/values; negative values
            disable anchoring.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q / kv projections (and the optional gate projection)
            use a bias.
        proj_bias: Whether the output projection uses a bias.
        rel_pos_bias: ``"rope"`` (``forward`` then requires ``rope_freqs``),
            ``"learnable"`` or ``"none"``.
        seqlens: Required for ``rel_pos_bias="learnable"``; forwarded to
            ``relative_position_indices``.
        channel_first: Stored on the module; ``forward`` does not read it.
        init_weights: One of ``"torch"``, ``"xavier_uniform"``, ``"truncnormal"``,
            ``"truncnormal002"``.
        init_last_proj_zero: Zero-initialize the output projection.
        do_attn_gating: Apply a sigmoid gate computed from the input to the
            attention output (https://arxiv.org/abs/2505.06708).
        n_attn_sinks: Number of attention sinks added via ``AttentionSinks``;
            0 disables them.
    """

    def __init__(
            self,
            dim,
            n_anchors: int,  # negative values disable the anchors
            num_heads: int,
            qkv_bias=True,
            proj_bias=True,
            rel_pos_bias="rope",
            seqlens=None,
            channel_first=False,
            init_weights="truncnormal002",
            init_last_proj_zero=False,
            do_attn_gating=False,
            n_attn_sinks=0,
    ):
        super().__init__()
        assert hasattr(F, "scaled_dot_product_attention")
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.n_anchors = n_anchors
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.rel_pos_bias = rel_pos_bias
        self.seqlens = seqlens
        self.channel_first = channel_first
        self.init_weights = init_weights
        self.init_last_proj_zero = init_last_proj_zero
        self.do_attn_gating = do_attn_gating
        self.n_attn_sinks = n_attn_sinks

        # NOTE: submodules are created in a fixed order on purpose; with
        # init_weights="torch" the creation order determines RNG consumption.
        if do_attn_gating:
            # the gate operates on head_dim-sized slices of the input, so a single
            # head_dim -> head_dim projection is shared by all heads
            self.gate_proj = nn.Linear(self.head_dim, self.head_dim, bias=qkv_bias)

        # keys/values are computed from the anchors only while queries come from
        # every token, hence separate projections (k and v merged into one layer)
        self.kv_proj = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q_proj = nn.Linear(dim, dim * 1, bias=qkv_bias)

        if rel_pos_bias in ("none", "rope"):
            # rope is applied to q/k in forward; neither mode uses a bias table,
            # but the attributes are always defined so they can be checked uniformly
            self.rel_pos_bias_table = None
            self.rel_pos_idx = None
        elif rel_pos_bias == "learnable":
            assert seqlens is not None
            rel_pos_idx, num_distinct_distances = relative_position_indices(seqlens=seqlens, num_aux_tokens=1)
            self.register_buffer("rel_pos_idx", rel_pos_idx)
            # NOTE(review): this table is created but never added to the attention
            # scores in forward(); "learnable" currently has no effect — TODO wire it in
            self.rel_pos_bias_table = nn.Parameter(torch.empty(num_distinct_distances, num_heads))
        else:
            raise NotImplementedError
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        if self.n_attn_sinks > 0:
            self.attn_sinks: AttentionSinks = AttentionSinks(dim, n_attn_sinks)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialize the weights according to ``self.init_weights``.

        Raises:
            NotImplementedError: for an unknown ``init_weights`` value.
        """
        if self.init_weights == "torch":
            pass
        elif self.init_weights == "xavier_uniform":
            self.apply(init_xavier_uniform_zero_bias)
            # num_layers must match the number of projections merged into
            # kv_proj: k and v -> 2 (not 3, which would be a merged qkv layer)
            init_xavier_uniform_merged_linear(self.kv_proj, num_layers=2)
        elif self.init_weights in ["truncnormal", "truncnormal002"]:
            self.apply(init_truncnormal_zero_bias)
        else:
            raise NotImplementedError
        if self.rel_pos_bias_table is not None:
            # allocated with torch.empty (uninitialized memory); zeros_ consumes
            # no RNG, so the initialization of the other parameters is unaffected
            nn.init.zeros_(self.rel_pos_bias_table)
        if self.init_last_proj_zero:
            nn.init.zeros_(self.proj.weight)
            # init_weights == "torch" has no zero bias init
            if self.proj.bias is not None:
                nn.init.zeros_(self.proj.bias)

    @staticmethod
    def _extend_mask_for_sinks(attn_mask, n_attn_sinks):
        """Prepend ``n_attn_sinks`` always-visible key columns to ``attn_mask``.

        The last axis of the mask is the key axis; all leading axes are kept, so
        any mask rank accepted by ``F.scaled_dot_product_attention`` works.
        Boolean masks get ``True`` (take part in attention) for the sink columns;
        other (additive) masks get ``0`` (no bias).
        """
        sink_shape = (*attn_mask.shape[:-1], n_attn_sinks)
        if attn_mask.dtype == torch.bool:
            sink_cols = attn_mask.new_ones(sink_shape)
        else:
            sink_cols = attn_mask.new_zeros(sink_shape)
        return torch.cat([sink_cols, attn_mask], dim=-1)

    def forward(
            self,
            x,
            rope_freqs=None,
            attn_mask=None,
            **kwargs,
    ):
        """Attend from every token of ``x`` to the anchor tokens.

        Args:
            x: Input of shape (bs, seqlen, dim) (required by the einops patterns
                below).
            rope_freqs: Rotary frequencies passed to ``rope_rotation``; required
                when ``rel_pos_bias == "rope"``. The keys use
                ``rope_freqs[:, :n_anchors]``, so its second axis is presumably the
                sequence axis — TODO confirm.
            attn_mask: Optional mask for ``F.scaled_dot_product_attention`` whose
                last axis indexes the anchors. When sinks are enabled, columns for
                the sinks are prepended automatically.
            **kwargs: Ignored.

        Returns:
            Tensor of shape (bs, seqlen, dim).
        """
        if self.rel_pos_bias == "rope":
            assert rope_freqs is not None

        if self.do_attn_gating:
            # keep the raw input (split into heads) to compute the gate later
            x_for_gate = einops.rearrange(
                x,
                "bs seqlen (num_heads head_dim) -> bs num_heads seqlen head_dim",
                head_dim=self.head_dim,
            )

        # a negative n_anchors disables anchoring: every token becomes an anchor.
        # NOTE(review): n_anchors == 0 without sinks leaves no keys to attend to.
        n_anchors = x.shape[1] if self.n_anchors < 0 else self.n_anchors
        anchors = x[:, :n_anchors]

        k, v = einops.rearrange(
            self.kv_proj(anchors),
            "bs seqlen (two num_heads head_dim) -> two bs num_heads seqlen head_dim",
            two=2,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
        ).unbind(0)
        q = einops.rearrange(
            self.q_proj(x),
            "bs seqlen (num_heads head_dim) -> bs num_heads seqlen head_dim",
            num_heads=self.num_heads,
            head_dim=self.head_dim,
        )

        # rotary positional embedding; keys only get the anchor positions' frequencies
        if rope_freqs is not None:
            q = rope_rotation(q, rope_freqs)
            k = rope_rotation(k, rope_freqs[:, :n_anchors])

        if self.n_attn_sinks > 0:
            k, v = self.attn_sinks(k, v)
            if attn_mask is not None:
                # assumes AttentionSinks prepends the sinks along the key axis — TODO confirm
                attn_mask = self._extend_mask_for_sinks(attn_mask, self.n_attn_sinks)

        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        if self.do_attn_gating:
            x = self.attn_gate(x, x_for_gate)

        x = einops.rearrange(x, "bs num_heads seqlen head_dim -> bs seqlen (num_heads head_dim)")
        x = self.proj(x)
        return x

    def attn_gate(self, x, x_for_gate):
        """Scale ``x`` elementwise by ``sigmoid(gate_proj(x_for_gate))``.

        Gated attention: https://arxiv.org/abs/2505.06708
        """
        gate = torch.sigmoid(self.gate_proj(x_for_gate))
        return x * gate