"""SparseAttentionManager encapsulating prefill/decode logic.

This class manages CPU KV storage, PQ centroid/score tables, and per-step
search/retrieval to support sparse decoding attention.
"""

from __future__ import annotations

import torch

from .kv_cache import CpuKVStore
from .indexing import (
    build_pq_clusters_for_kv_heads_optimized_1346,
    append_new_k_token_to_score_tables_optimized_1346,
)
from .search import search_index_test


class SparseAttentionManager:
    """Coordinate the CPU KV store and PQ tables used for sparse decoding attention.

    Lifecycle: call ``prefill_step`` once with the prompt's states to create
    the CPU KV store and the PQ centroid/score tables, then call
    ``decode_step`` once per generated token to extend both and fetch the
    selected subset of K/V back onto ``device``.
    """

    def __init__(self, model_config, dtype, device, top_token_ratio: float = 0.05):
        """Read head geometry from the model config and set up empty state.

        Args:
            model_config: HF-style model config. ``num_attention_heads`` is
                required. ``num_key_value_heads`` is optional and falls back to
                ``num_attention_heads`` when missing or ``None``. ``head_dim``
                is used when present; otherwise it is derived as
                ``hidden_size // num_attention_heads``.
            dtype: torch dtype handed to the CPU KV store.
            device: target device for the K/V tensors returned by ``decode_step``.
            top_token_ratio: ratio forwarded to the index search on every
                decode step.
        """
        self.num_attn_heads = int(model_config.num_attention_heads)

        # Plain multi-head configs may omit num_key_value_heads (or leave it
        # as None); in that case every attention head has its own KV head.
        kv_heads = getattr(model_config, "num_key_value_heads", None)
        self.num_kv_heads = int(kv_heads) if kv_heads is not None else self.num_attn_heads

        # Some architectures declare head_dim explicitly and it need not equal
        # hidden_size // num_attention_heads, so the explicit value wins.
        explicit_head_dim = getattr(model_config, "head_dim", None)
        if explicit_head_dim is not None:
            self.head_dim = int(explicit_head_dim)
        else:
            self.head_dim = int(model_config.hidden_size) // self.num_attn_heads

        self.dtype = dtype
        self.device = device
        self.top_token_ratio = float(top_token_ratio)

        # All of the following are populated by prefill_step().
        self.cpu_kv_cache: CpuKVStore | None = None
        self.kv_block_centroids = None  # PQ centroid table, kept on CPU
        self.kv_block_scores = None  # PQ score table, kept on CPU
        self.current_seq_len = 0  # number of tokens covered by the score tables

    @torch.no_grad()
    def prefill_step(self, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Create the CPU KV store and the initial PQ tables from prefill states.

        Any state left over from a previous prefill is replaced.

        Args:
            query_states: prefill queries, passed through to the index builder.
            key_states: prefill keys; dimension 2 is taken as the sequence
                length (presumably laid out as (batch, kv_heads, seq, head_dim)
                -- TODO confirm against the caller).
            value_states: prefill values, stored next to the keys.

        Cost is determined by the index builder; see the indexing module.
        """
        seq_len = key_states.shape[2]
        # Reserve headroom beyond the prompt so early decode steps can append
        # without immediately outgrowing the reservation.
        self.cpu_kv_cache = CpuKVStore(
            self.num_kv_heads,
            self.head_dim,
            dtype=self.dtype,
            reserve_tokens=seq_len + 1024,
        )
        self.cpu_kv_cache.append_from_gpu(key_states, value_states)

        cents, scores, initial_len = build_pq_clusters_for_kv_heads_optimized_1346(query_states, key_states)
        # The tables are moved to CPU because decode_step feeds them CPU tensors.
        self.kv_block_centroids = cents.to("cpu")
        self.kv_block_scores = scores.to("cpu")
        self.current_seq_len = int(initial_len)

    @torch.no_grad()
    def decode_step(
        self, query_states: torch.Tensor, new_key_state: torch.Tensor, new_value_state: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append the new token's K/V, update the index, and fetch the selected K/V.

        Args:
            query_states: query for the current step; dimension 2 is squeezed
                before searching (presumably a single-token axis -- TODO confirm).
            new_key_state: key of the newly generated token; dimension 2 is
                squeezed before it is folded into the score tables.
            new_value_state: value of the newly generated token.

        Returns:
            ``(filtered_k, filtered_v)``: the K/V entries selected by the index
            search, transferred to ``self.device``.

        Raises:
            RuntimeError: if called before ``prefill_step``.
        """
        # An explicit check rather than ``assert``: asserts vanish under
        # ``python -O`` and the failure would surface later as an opaque
        # AttributeError on None.
        if self.cpu_kv_cache is None or self.kv_block_centroids is None or self.kv_block_scores is None:
            raise RuntimeError(
                "decode_step() called before prefill_step(): the KV cache and PQ index are not initialized"
            )

        self.cpu_kv_cache.append_from_gpu(new_key_state, new_value_state)

        index_dtype = self.kv_block_centroids.dtype

        # The helper receives the pre-increment length together with the new key.
        append_new_k_token_to_score_tables_optimized_1346(
            self.kv_block_centroids,
            self.kv_block_scores,
            new_key_state.squeeze(2).detach().to("cpu", dtype=index_dtype),
            self.current_seq_len,
        )
        self.current_seq_len += 1

        # The search runs with the post-increment length, after the append above.
        indices = search_index_test(
            query_states.squeeze(2).detach().to("cpu", dtype=index_dtype),
            self.kv_block_centroids,
            self.kv_block_scores,
            self.num_kv_heads,
            self.current_seq_len,
            top_token_ratio=self.top_token_ratio,
        )

        filtered_k_cpu, filtered_v_cpu = self.cpu_kv_cache.gather_cpu(indices)
        filtered_k = self.cpu_kv_cache.to_gpu_async(filtered_k_cpu, self.device)
        filtered_v = self.cpu_kv_cache.to_gpu_async(filtered_v_cpu, self.device)
        return filtered_k, filtered_v


