"""Qwen attention forward override integrated with SparseAttentionManager."""

from __future__ import annotations

from typing import Callable, Tuple

import torch

from ..sparse_attention.search import (
    gqa_prefill_on_full_kvcache,
    gqa_decode_on_selected_kvcache,
)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last axis of ``x`` by half: ``[a, b] -> [-b, a]``.

    The last axis is split at ``x.shape[-1] // 2`` into a leading part ``a``
    and a trailing part ``b``; the result is ``-b`` followed by ``a`` along
    that same axis. All leading axes are left untouched.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat([back.neg(), front], dim=-1)


def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, *, unsqueeze_dim: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to a query/key pair.

    Each input ``t`` (``q`` and ``k``) is transformed as
    ``t * cos + _rotate_half(t) * sin``. Before that, a singleton axis is
    inserted into ``cos`` and ``sin`` at ``unsqueeze_dim`` so they broadcast
    against ``q``/``k``. NOTE(review): with the default ``unsqueeze_dim=1``
    this presumably targets a heads axis, i.e. cos/sin of shape
    (batch, seq, dim) against q/k of shape (batch, heads, seq, dim). Confirm
    this against callers.

    Returns:
        ``(rotated_q, rotated_k)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Same operand order as the reference formula: (t*cos) + (rot(t)*sin).
        return (t * cos_b) + (_rotate_half(t) * sin_b)

    rotated_q = _rotate(q)
    rotated_k = _rotate(k)
    return rotated_q, rotated_k


def get_custom_qwen_forward(attention_module, sparse_manager) -> Callable:
    """Build a replacement ``forward`` for a Qwen attention module.

    The returned closure reproduces the Q/K/V projection, the per-head
    ``q_norm``/``k_norm`` and the rotary step of ``attention_module``. It then
    hands attention to ``sparse_manager`` and the ``sparse_attention.search``
    kernels:

    * ``q_len > 1`` (prefill): the full Q/K/V go to
      ``sparse_manager.prefill_step``. Attention then runs over them with
      ``gqa_prefill_on_full_kvcache`` (with ``attention_mask``).
    * ``q_len == 1`` (decode): ``sparse_manager.decode_step`` returns a
      filtered K/V pair. Attention then runs over it with
      ``gqa_decode_on_selected_kvcache``.

    Args:
        attention_module: The attention layer being patched. It must expose
            ``head_dim``, ``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``,
            ``q_norm`` and ``k_norm``. ``rotary_emb`` is only used as a
            fallback when no ``position_embeddings`` are passed in.
        sparse_manager: Object providing ``prefill_step(q, k, v)`` and
            ``decode_step(q, k, v) -> (k, v)``.

    Returns:
        A callable with a Hugging Face style attention-``forward`` keyword
        signature. It returns ``(attn_output, None)``.
    """

    def custom_forward(
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):
        # past_key_value / output_attentions / use_cache / cache_position /
        # kwargs are accepted only so the signature matches the caller. None of
        # them is read below. NOTE(review): KV-cache bookkeeping is presumably
        # handled inside sparse_manager; confirm.
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, attention_module.head_dim)
        # Unpacking into three names also enforces a 3-D hidden_states input.
        _, q_len, _ = hidden_states.shape

        # Project, split the last axis into heads of size head_dim, normalize
        # Q/K per head, then move the heads axis in front of the sequence axis.
        query_states = attention_module.q_norm(
            attention_module.q_proj(hidden_states).view(hidden_shape)
        ).transpose(1, 2)
        key_states = attention_module.k_norm(
            attention_module.k_proj(hidden_states).view(hidden_shape)
        ).transpose(1, 2)
        value_states = attention_module.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Prefer the precomputed rotary tables. Otherwise fall back to the
        # module's own rotary_emb, or skip rotary entirely if neither exists.
        if position_embeddings is not None:
            cos, sin = position_embeddings
        elif hasattr(attention_module, "rotary_emb"):
            cos, sin = attention_module.rotary_emb(value_states, position_ids)
        else:
            cos = sin = None

        if cos is not None and sin is not None:
            if q_len == 1:
                # Single-token decode: keep only the last position's angles.
                # Assumes cos/sin are 3-D (batch, seq, head_dim); TODO confirm.
                cos = cos[:, -1:, :]
                sin = sin[:, -1:, :]
            query_states, key_states = _apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if q_len > 1:
            sparse_manager.prefill_step(query_states, key_states, value_states)
            attn_output = gqa_prefill_on_full_kvcache(query_states, key_states, value_states, attention_mask)
        else:
            # NOTE(review): a one-token prompt also lands here, so decode_step
            # runs with no preceding prefill_step. attention_mask is also not
            # forwarded on this path. Confirm both are acceptable for padded or
            # very short batches.
            filtered_k, filtered_v = sparse_manager.decode_step(query_states, key_states, value_states)
            attn_output = gqa_decode_on_selected_kvcache(query_states, filtered_k, filtered_v)

        # Move heads back behind the sequence axis and merge them into the
        # hidden axis. This presumes the kernels return (batch, heads, q_len,
        # head_dim); confirm.
        attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
        attn_output = attention_module.o_proj(attn_output)
        return attn_output, None

    return custom_forward



