"""LLaMA attention forward override integrated with SparseAttentionManager.

This module provides two factories, `get_custom_llama_forward` and
`get_custom_llama_forward_mask_topk`, each returning a closure to replace the
`forward` method of `transformers.models.llama.modeling_llama.LlamaAttention`.

The custom forward integrates sparse attention by:
- During prefill (q_len > 1), building PQ indices (or, for the mask-topk
  variant, storing the full K/V) and running full attention
- During decode (q_len == 1), searching top tokens and attending to filtered K/V
"""

from __future__ import annotations

from typing import Callable, Tuple

import torch
from transformers.models.llama.modeling_llama import eager_attention_forward

from ..sparse_attention.search import (
    gqa_prefill_on_full_kvcache,
    gqa_decode_on_selected_kvcache,
)
from ..sparse_attention.mask_topk import MaskTopKAttentionManager


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last axis, negating the trailing half.

    The last axis is cut at ``x.shape[-1] // 2``; for ``x = [a, b]`` the
    result is ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    leading = x[..., :split_at]
    trailing = x[..., split_at:]
    return torch.cat([-trailing, leading], dim=-1)


def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, *, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; an axis is inserted at ``unsqueeze_dim`` before it
            is multiplied with ``q`` and ``k``.
        sin: Sine table, handled the same way as ``cos``.
        unsqueeze_dim: Index at which the extra axis is inserted into ``cos``
            and ``sin`` (keyword-only).

    Returns:
        The rotated ``(q, k)`` pair.
    """
    cos_b, sin_b = (table.unsqueeze(unsqueeze_dim) for table in (cos, sin))

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # t * cos + rotate_half(t) * sin
        return (t * cos_b) + (_rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


def get_custom_llama_forward(attention_module, sparse_manager) -> Callable:
    """Return a closure that overrides LLaMA attention forward.

    The closure projects Q/K/V with the module's own ``q_proj``/``k_proj``/
    ``v_proj``, applies rotary embeddings when they are available, and then
    dispatches on the query length:

    - ``q_len > 1`` (prefill): ``sparse_manager.prefill_step(q, k, v)`` is
      called and dense attention runs over the full K/V.
    - ``q_len == 1`` (decode): ``sparse_manager.decode_step(q, k, v)`` returns
      the filtered K/V that the single query token attends to.

    Args:
        attention_module: Instance of LlamaAttention for a single layer.
        sparse_manager: A per-layer SparseAttentionManager instance exposing
            ``prefill_step`` and ``decode_step``.

    Returns:
        The ``custom_forward`` callable. It returns ``(attn_output, None)``;
        attention weights are never produced.
    """

    def custom_forward(
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        # `past_key_value`, `output_attentions`, `use_cache`, `cache_position`
        # and `**kwargs` are accepted for signature compatibility only; this
        # forward never reads them.
        bsz, q_len, _ = hidden_states.size()

        # Projections
        query_states = attention_module.q_proj(hidden_states)
        key_states = attention_module.k_proj(hidden_states)
        value_states = attention_module.v_proj(hidden_states)

        # Prefer head counts stored on the module; only touch `config` when the
        # module does not expose them, so a missing `config` is not an error
        # for modules that carry the attributes themselves.
        num_heads = getattr(attention_module, "num_heads", None)
        if num_heads is None:
            num_heads = attention_module.config.num_attention_heads
        num_kv_heads = getattr(attention_module, "num_key_value_heads", None)
        if num_kv_heads is None:
            num_kv_heads = attention_module.config.num_key_value_heads
        head_dim = attention_module.head_dim

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

        # RoPE: use precomputed embeddings when the caller supplies them,
        # otherwise fall back to the module's own rotary embedding, if any.
        if position_embeddings is not None:
            cos, sin = position_embeddings
        elif hasattr(attention_module, "rotary_emb"):
            cos, sin = attention_module.rotary_emb(value_states, position_ids)
        else:
            cos = sin = None

        if cos is not None and sin is not None:
            if q_len == 1:
                # Decode rotates a single token: keep only the last position.
                cos = cos[:, -1:, :]
                sin = sin[:, -1:, :]
            query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if q_len > 1:
            # Prefill: build index and run full attention
            sparse_manager.prefill_step(query_states, key_states, value_states)
            attn_output = gqa_prefill_on_full_kvcache(query_states, key_states, value_states, attention_mask)
        else:
            # Decode: retrieve sparse K/V and attend
            filtered_k, filtered_v = sparse_manager.decode_step(query_states, key_states, value_states)
            attn_output = gqa_decode_on_selected_kvcache(query_states, filtered_k, filtered_v)

        # Merge the head axes back into the hidden dimension before o_proj.
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = attention_module.o_proj(attn_output)
        return attn_output, None

    return custom_forward



def _last_position_id(position_ids):
    """Return the last entry of ``position_ids`` as an ``int``, or ``None``.

    Best effort: ``None`` is returned when ``position_ids`` is ``None`` or the
    value cannot be extracted (for example an empty tensor or an object that
    is not tensor-like).
    """
    if position_ids is None:
        return None
    try:
        return int(position_ids.reshape(-1)[-1].item())
    except (AttributeError, IndexError, RuntimeError, TypeError, ValueError):
        return None


def get_custom_llama_forward_mask_topk(attention_module, manager: MaskTopKAttentionManager) -> Callable:
    """Return a closure overriding LLaMA attention for mask-topk decoding.

    Prefill (q_len > 1): store full K/V and run dense attention.
    Decode (q_len == 1): append new K/V, select exact top-k via logits, and
    attend only on selected tokens (equivalent to masking others to -inf).

    Args:
        attention_module: LlamaAttention instance for a single layer; its
            ``q_proj``/``k_proj``/``v_proj``/``o_proj`` and ``head_dim`` are used.
        manager: Per-layer manager exposing ``prefill_step``, ``seq_len``,
            ``append_decode_token``, ``select_topk_indices`` and
            ``gather_selected_kv``.

    Returns:
        The ``custom_forward`` callable. It returns ``(attn_output, None)``;
        attention weights are never produced.
    """

    def custom_forward(
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        # `past_key_value`, `output_attentions`, `use_cache`, `cache_position`
        # and `**kwargs` are accepted for signature compatibility only; this
        # forward never reads them.
        bsz, q_len, _ = hidden_states.size()

        # Projections
        query_states = attention_module.q_proj(hidden_states)
        key_states = attention_module.k_proj(hidden_states)
        value_states = attention_module.v_proj(hidden_states)

        # Prefer head counts stored on the module; only touch `config` when the
        # module does not expose them, so a missing `config` is not an error
        # for modules that carry the attributes themselves.
        num_heads = getattr(attention_module, "num_heads", None)
        if num_heads is None:
            num_heads = attention_module.config.num_attention_heads
        num_kv_heads = getattr(attention_module, "num_key_value_heads", None)
        if num_kv_heads is None:
            num_kv_heads = attention_module.config.num_key_value_heads
        head_dim = attention_module.head_dim

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)

        # RoPE: use precomputed embeddings when the caller supplies them,
        # otherwise fall back to the module's own rotary embedding, if any.
        if position_embeddings is not None:
            cos, sin = position_embeddings
        elif hasattr(attention_module, "rotary_emb"):
            cos, sin = attention_module.rotary_emb(value_states, position_ids)
        else:
            cos = sin = None

        if cos is not None and sin is not None:
            if q_len == 1:
                # Decode rotates a single token: keep only the last position.
                cos = cos[:, -1:, :]
                sin = sin[:, -1:, :]
            query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if q_len > 1:
            # Prefill: store K/V and run dense attention
            manager.prefill_step(key_states, value_states)
            attn_output = gqa_prefill_on_full_kvcache(query_states, key_states, value_states, attention_mask)
        else:
            # Decode: append K/V (except when echoing last prompt token), select top-k, and attend.
            # `position_ids` is a named parameter, so it can never arrive via
            # `**kwargs`; only the parameter needs to be inspected.
            cur_pos = _last_position_id(position_ids)
            # Append only if we are beyond prefill boundary; when the position
            # is unknown the token is appended as well.
            if cur_pos is None or cur_pos >= manager.seq_len:
                manager.append_decode_token(key_states, value_states)
            # Support recall-aware managers exposing the same method name
            indices = manager.select_topk_indices(query_states)
            filtered_k, filtered_v = manager.gather_selected_kv(indices)
            attn_output = gqa_decode_on_selected_kvcache(query_states, filtered_k, filtered_v)

        # Merge the head axes back into the hidden dimension before o_proj.
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = attention_module.o_proj(attn_output)
        return attn_output, None

    return custom_forward


