import einops
import torch
import torch.nn.functional as F
from torch import nn

from kappamodules.init import (
    init_xavier_uniform_zero_bias,
    init_xavier_uniform_merged_linear,
    init_truncnormal_zero_bias,
)

from src.modules.rope import rope_rotation
from src.utils.attn_sinks import AttentionSinks


class CrossAttention(nn.Module):
    def __init__(
            self,
            dim,
            kv_dim=None,
            num_heads=8,
            bias=True,
            n_attn_sinks=4,
            concat_query_to_kv=False,
            init_weights="truncnormal002",
            init_last_proj_zero=False,
    ):
        super().__init__()
        assert hasattr(F, "scaled_dot_product_attention")
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.concat_query_to_kv = concat_query_to_kv
        self.init_weights = init_weights
        self.init_last_proj_zero = init_last_proj_zero
        self.n_attn_sinks = n_attn_sinks

        self.kv = nn.Linear(kv_dim or dim, dim * 2, bias=bias)
        self.q = nn.Linear(dim, dim, bias=bias)
        self.proj = nn.Linear(dim, dim, bias=bias)

        self.attn_sinks: AttentionSinks = AttentionSinks(dim, n_attn_sinks)

        self.reset_parameters()

    def reset_parameters(self):
        if self.init_weights == "torch":
            pass
        elif self.init_weights == "xavier_uniform":
            self.apply(init_xavier_uniform_zero_bias)
            init_xavier_uniform_merged_linear(self.kv, num_layers=2)
        elif self.init_weights in ["truncnormal", "truncnormal002"]:
            self.apply(init_truncnormal_zero_bias)
        else:
            raise NotImplementedError
        if self.init_last_proj_zero:
            nn.init.zeros_(self.proj.weight)
            # init_weights == "torch" has no zero bias init
            if self.proj.bias is not None:
                nn.init.zeros_(self.proj.bias)

    def forward(self, 
                q, 
                kv,
                rope_freqs, 
                kv_rope_freqs, 
                xa_attn_mask=None,  # name this differently to be able to pass attn_mask for main self-attn branch
                ):
        
        # project to attention space
        if self.concat_query_to_kv:
            kv = torch.concat([kv, q], dim=1)
        kv = self.kv(kv)
        q = self.q(q)

        # split per head
        q = einops.rearrange(q, "bs seqlen_q (num_heads head_dim) -> bs num_heads seqlen_q head_dim", num_heads=self.num_heads, head_dim=self.head_dim)
        k, v = einops.rearrange(kv, "bs seqlen_kv (two num_heads head_dim) -> two bs num_heads seqlen_kv head_dim", two=2, num_heads=self.num_heads, head_dim=self.head_dim).unbind(0)
        
        q = rope_rotation(q, rope_freqs)
        k = rope_rotation(k, kv_rope_freqs)

        if self.n_attn_sinks > 0:
            k, v = self.attn_sinks(k, v)

        x = F.scaled_dot_product_attention(q, k, v, attn_mask=xa_attn_mask)
        x = einops.rearrange(x, "bs num_heads seqlen head_dim -> bs seqlen (num_heads head_dim)")
        
        x = self.proj(x)
        return x

