"""Search and retrieval utilities for sparse attention.

Provides functions to select top token indices from PQ scores and to form
GQA-aligned K/V slices for attention computation.
"""

from __future__ import annotations

import torch


def search_index_test(
    q_tokens: torch.Tensor,  # [n_q_head, d]
    kv_block_centroids_tensor: torch.Tensor,  # [n_kv_head, m_blocks, num_c, sub_dim]
    kv_block_key_scores_tensor: torch.Tensor,  # [n_kv_head, m_blocks, num_c, S_total]
    n_kv_head: int,
    current_valid_S: int,
    *,
    top_centroids_per_block: int = 1,
    top_token_ratio: float = 0.05,
) -> torch.Tensor:
    """Select the highest-scoring token positions per KV head from PQ scores.

    Each query is split into ``m_blocks`` sub-vectors of size ``sub_dim``
    (requires ``d == m_blocks * sub_dim`` and ``n_q_head`` divisible by
    ``n_kv_head``). For every KV head and block, the centroids most similar to
    the group-averaged query sub-vector are chosen. Their per-token scores are
    then summed over blocks and chosen centroids. Finally, the top
    ``int(current_valid_S * top_token_ratio)`` tokens are kept: at least one
    when any token is valid, and never more than the number of scored
    positions.

    Args:
        q_tokens: queries, ``[n_q_head, d]``.
        kv_block_centroids_tensor: centroids, ``[n_kv_head, m_blocks, num_c, sub_dim]``.
        kv_block_key_scores_tensor: per-centroid token scores,
            ``[n_kv_head, m_blocks, num_c, S_total]``.
        n_kv_head: number of KV heads.
        current_valid_S: number of leading token positions to consider.
        top_centroids_per_block: centroids chosen per (head, block); clamped to ``num_c``.
        top_token_ratio: fraction of valid tokens to return.

    Returns:
        ``[n_kv_head, topk_num]`` long tensor of token indices, sorted ascending
        per row (``[n_kv_head, 0]`` when nothing can be selected).
    """
    n_q_head, d = q_tokens.shape
    device = q_tokens.device
    _, m_blocks, num_c, sub_dim = kv_block_centroids_tensor.shape
    n_nv_group = n_q_head // n_kv_head

    # [n_kv_head, group, m_blocks, sub_dim]: queries grouped by the KV head they share.
    q_blocks = q_tokens.view(n_kv_head, n_nv_group, d).view(n_kv_head, n_nv_group, m_blocks, sub_dim)

    sim_all_heads = torch.einsum("kgbs,kbcs->kbgc", q_blocks, kv_block_centroids_tensor)  # [k, b, g, c]
    sim_mean_all_heads = sim_all_heads.mean(dim=2)  # average over the query group -> [k, b, c]
    # topk raises if asked for more centroids than exist in a block.
    n_top_c = min(top_centroids_per_block, num_c)
    _, top_c_idx = torch.topk(sim_mean_all_heads, k=n_top_c, dim=2, largest=True)  # [k, b, n_top_c]

    kv_head_idx = torch.arange(n_kv_head, device=device).view(n_kv_head, 1, 1).expand(-1, m_blocks, n_top_c)
    block_idx = torch.arange(m_blocks, device=device).view(1, m_blocks, 1).expand(n_kv_head, -1, n_top_c)

    valid_scores = kv_block_key_scores_tensor[:, :, :, :current_valid_S]
    gathered_scores = valid_scores[kv_head_idx, block_idx, top_c_idx]  # [k, b, n_top_c, S]
    final_head_scores = gathered_scores.sum(dim=(1, 2))  # [n_kv_head, S]
    # The slice above silently truncates to S_total, so count what actually exists.
    n_scored = final_head_scores.shape[1]

    topk_num = int(current_valid_S * top_token_ratio)
    if topk_num == 0 and current_valid_S > 0:
        topk_num = 1
    # topk raises when k exceeds the available positions (ratio > 1 or
    # current_valid_S > S_total); clamp instead of failing.
    topk_num = min(topk_num, n_scored)
    if topk_num == 0:
        return torch.empty(n_kv_head, 0, dtype=torch.long, device=device)

    _, top_token_indices = torch.topk(final_head_scores, k=topk_num, dim=1, largest=True)
    top_token_indices, _ = torch.sort(top_token_indices, dim=1)
    return top_token_indices


def extract_index_test(k_cache: torch.Tensor, v_cache: torch.Tensor, indices: torch.Tensor):
    """Select per-head time positions from K/V caches via advanced indexing.

    After dropping a leading size-1 dim, each cache must be 3-D ``[Hk, S, D]``.
    ``indices`` has one row per KV head; row ``h`` lists the time positions
    taken from head ``h``.

    Returns: k_sel, v_sel with shape [1, Hk, Tk, D]
    """
    keys = k_cache.squeeze(0)
    values = v_cache.squeeze(0)
    # Unpacking enforces a 3-D cache after the squeeze.
    num_heads, _, _ = keys.shape
    num_selected = indices.shape[1]
    # Pair row h of `indices` with head h so each head gathers its own positions.
    rows = torch.arange(num_heads, device=k_cache.device)[:, None].expand(num_heads, num_selected)
    picked_keys = keys[rows, indices]
    picked_values = values[rows, indices]
    return picked_keys[None], picked_values[None]


def extract_index_test2_fast(k_cache: torch.Tensor, v_cache: torch.Tensor, indices: torch.Tensor, workspace=None):
    """Gather selected K/V rows per query head and append the last cached token.

    After dropping a leading size-1 dim, each cache must be 3-D ``[Hk, S, D]``.
    ``indices`` is ``[Hq, Tk]``. Query head ``h`` reads from KV head
    ``h // (Hq // Hk)``. Output position ``t < Tk`` holds cache row
    ``indices[h, t]``, and position ``Tk`` holds the last row ``S - 1``.

    Args:
        k_cache, v_cache: key/value caches.
        indices: time positions to gather per query head (converted to long
            on the cache's device).
        workspace: optional ``(k_out, v_out)`` pair of preallocated tensors of
            shape ``[1, Hq, Tk + 1, D]`` that are filled in place and returned.

    Returns:
        ``(k_out, v_out)``, each ``[1, Hq, Tk + 1, D]``.

    Raises:
        ValueError: if ``Hq`` is not a multiple of ``Hk``, the cache is empty,
            or the workspace has the wrong shape.
    """
    device = k_cache.device
    k = k_cache.squeeze(0)
    v = v_cache.squeeze(0)
    Hk, S, D = k.shape
    Hq, Tk = indices.shape
    if Hq % Hk != 0:
        raise ValueError(f"query heads ({Hq}) must be a multiple of KV heads ({Hk})")
    if S == 0:
        raise ValueError("KV cache is empty; there is no last token to append")
    indices = indices.to(device=device, dtype=torch.long)
    n_nv_group = Hq // Hk
    kv_for_q = torch.arange(Hq, device=device) // n_nv_group  # [Hq] source KV head per query head

    total_T = Tk + 1
    if workspace is not None:
        k_out, v_out = workspace
        expected = (1, Hq, total_T, D)
        if tuple(k_out.shape) != expected or tuple(v_out.shape) != expected:
            raise ValueError(f"workspace tensors must have shape {expected}")
    else:
        k_out = k.new_empty((1, Hq, total_T, D))
        v_out = v.new_empty((1, Hq, total_T, D))

    # Index (kv_head, time) pairs directly: [Hq, 1] x [Hq, Tk] -> [Hq, Tk, D].
    # This copies only the selected rows instead of first replicating the whole
    # cache per query head.
    head_rows = kv_for_q.unsqueeze(1)
    k_out[0, :, :Tk] = k[head_rows, indices]
    v_out[0, :, :Tk] = v[head_rows, indices]
    # Always append the most recent cached token.
    k_out[0, :, Tk] = k[kv_for_q, S - 1]
    v_out[0, :, Tk] = v[kv_for_q, S - 1]
    return k_out, v_out


def gqa_decode_on_selected_kvcache(q_state: torch.Tensor, filtered_k: torch.Tensor, filtered_v: torch.Tensor):
    """Unmasked scaled dot-product attention over selected K/V, expanding KV heads for GQA.

    Each KV head is repeated ``Hq // Hk`` times along the head axis, so query
    head ``h`` attends to KV head ``h // (Hq // Hk)``. No mask is applied.

    Args:
        q_state: [B, Hq, Tq, D]
        filtered_k/v: [B, Hk, Tk, D]
    Returns:
        [B, Hq, Tq, D]
    """
    _, n_q_head, _, head_dim = q_state.shape
    _, n_kv_head, _, _ = filtered_k.shape
    assert n_q_head % n_kv_head == 0
    group = n_q_head // n_kv_head
    # Interleaved repeat along the head axis: [kv0, kv0, ..., kv1, kv1, ...].
    keys = filtered_k.repeat_interleave(group, dim=1)
    values = filtered_v.repeat_interleave(group, dim=1)
    logits = (q_state / (head_dim ** 0.5)) @ keys.transpose(-2, -1)
    weights = logits.softmax(dim=-1)
    return weights @ values


def gqa_prefill_on_full_kvcache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor | None):
    n_kv_group = q.shape[1] // k.shape[1]
    k_cache = k.repeat_interleave(n_kv_group, dim=1)
    v_cache = v.repeat_interleave(n_kv_group, dim=1)
    return torch.nn.functional.scaled_dot_product_attention(
        q, k_cache, v_cache, attn_mask=attention_mask, dropout_p=0.0, is_causal=True
    )


