"""MaskTopKAttentionManager for eager top-k masking without PQ indexing.

This manager stores per-layer K/V on GPU, supports a prefill stage that
captures the full prompt, and a decode stage that performs exact top-k
selection per KV head by computing q @ K^T and masking others.

Big-O:
- Prefill storage: O(Hk * S * D) memory
- Decode per step: O(Hk * S * D) for logits + O(Hk * S) for top-k +
  O(Hq * Tk * D) for attention, where Tk ≈ ratio * S
"""

from __future__ import annotations

from typing import Optional, Tuple

import torch


class MaskTopKAttentionManager:
    """Per-layer manager for mask-topk decoding without PQ/indexing.

    The manager keeps growable K/V buffers on the configured device. The
    prompt is stored once via ``prefill_step``; each decode step appends one
    token via ``append_decode_token``; ``select_topk_indices`` scores every
    cached key against the current query and returns the best positions per
    KV head; ``gather_selected_kv`` materialises the selected K/V rows.

    Shapes follow K/V -> [1, Hk, S, D], Q -> [1, Hq, Tq, D].
    """

    # Minimum number of token slots allocated the first time buffers are created.
    _MIN_CAPACITY = 1024

    def __init__(self, model_config, dtype, device, sparsity_ratio: float = 0.05):
        """Read head geometry from ``model_config`` and start with no buffers.

        Args:
            model_config: object exposing ``num_key_value_heads``,
                ``num_attention_heads`` and either ``head_dim`` or
                ``hidden_size``.
            dtype: dtype of the K/V buffers.
            device: device the K/V buffers are allocated on.
            sparsity_ratio: fraction of cached tokens kept by top-k selection.
        """
        self.num_kv_heads = int(getattr(model_config, "num_key_value_heads"))
        self.num_attn_heads = int(getattr(model_config, "num_attention_heads"))
        # Prefer an explicit per-head dimension when the config carries one;
        # it is not always equal to hidden_size // num_attention_heads.
        explicit_head_dim = getattr(model_config, "head_dim", None)
        if explicit_head_dim is not None:
            self.head_dim = int(explicit_head_dim)
        else:
            self.head_dim = int(getattr(model_config, "hidden_size")) // self.num_attn_heads
        self.dtype = dtype
        self.device = device
        self._ratio = float(sparsity_ratio)

        self.k_gpu: Optional[torch.Tensor] = None  # [1, Hk, cap, D]
        self.v_gpu: Optional[torch.Tensor] = None  # [1, Hk, cap, D]
        self.cap: int = 0  # allocated token slots
        self.seq_len: int = 0  # valid tokens currently stored
        self.prefill_len: int = 0  # tokens stored by the last prefill

    def set_sparsity_ratio(self, r: float) -> None:
        """Set the fraction of cached tokens kept by top-k selection."""
        self._ratio = float(r)

    def _allocate(self, cap: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return uninitialised K and V buffers with ``cap`` token slots."""
        k_buf = torch.empty(
            (1, self.num_kv_heads, cap, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        return k_buf, torch.empty_like(k_buf)

    def _ensure_capacity(self, need: int) -> None:
        """Make sure the buffers can hold ``need`` tokens.

        The first call allocates at least ``_MIN_CAPACITY`` slots and resets
        ``seq_len``. Later calls grow to ``max(need, 2 * cap)`` and carry over
        the first ``seq_len`` tokens.
        """
        if self.k_gpu is None:
            cap = max(need, self._MIN_CAPACITY)
            self.k_gpu, self.v_gpu = self._allocate(cap)
            self.cap = cap
            self.seq_len = 0
            return
        if need <= self.cap:
            return
        new_cap = max(need, self.cap * 2)
        k_new, v_new = self._allocate(new_cap)
        live = self.seq_len
        if live > 0:
            k_new[..., :live, :].copy_(self.k_gpu[..., :live, :])
            v_new[..., :live, :].copy_(self.v_gpu[..., :live, :])
        self.k_gpu = k_new
        self.v_gpu = v_new
        self.cap = new_cap

    @torch.no_grad()
    def prefill_step(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
        """Store the full prompt K/V, replacing anything stored before.

        Args:
            key_states, value_states: [1, Hk, S, D]
        """
        S = key_states.shape[2]
        # Prefill overwrites from position 0, so nothing already stored has to
        # survive a reallocation; zeroing seq_len first skips that copy.
        self.seq_len = 0
        self._ensure_capacity(S)
        self.k_gpu[..., :S, :].copy_(key_states)
        self.v_gpu[..., :S, :].copy_(value_states)
        self.seq_len = S
        self.prefill_len = S

    @torch.no_grad()
    def append_decode_token(self, k_new: torch.Tensor, v_new: torch.Tensor) -> None:
        """Append one decode token's K/V after the currently stored tokens.

        Args:
            k_new, v_new: [1, Hk, 1, D]
        """
        pos = self.seq_len
        self._ensure_capacity(pos + 1)
        self.k_gpu[..., pos : pos + 1, :].copy_(k_new)
        self.v_gpu[..., pos : pos + 1, :].copy_(v_new)
        self.seq_len = pos + 1

    @torch.no_grad()
    def reset_to_prefill(self) -> None:
        """Truncate sequence back to prefill length, keeping buffers allocated."""
        self.seq_len = int(self.prefill_len)

    @torch.no_grad()
    def _kv_head_logits(self, q_state: torch.Tensor) -> torch.Tensor:
        """Score every stored key against the group-averaged query.

        Args:
            q_state: [1, Hq, 1, D]
        Returns:
            logits: [Hk, S] scaled dot products, S == ``seq_len``
        """
        assert self.k_gpu is not None and self.seq_len > 0, "no K/V stored yet"
        _, Hq, _, D = q_state.shape
        Hk = self.num_kv_heads
        assert Hq % Hk == 0, "query heads must be a multiple of KV heads"

        # Reduce Q heads within each KV group to a single vector per KV head.
        # reshape (not view) so a non-contiguous query is accepted as well.
        q_flat = q_state.squeeze(0).squeeze(1)  # [Hq, D]
        q_kv = q_flat.reshape(Hk, Hq // Hk, D).mean(dim=1)  # [Hk, D]

        k_all = self.k_gpu.squeeze(0)[:, : self.seq_len, :]  # [Hk, S, D]
        return torch.einsum("hd,hsd->hs", q_kv / (D ** 0.5), k_all)

    def _topk_count(self, S: int) -> int:
        """Number of tokens to keep out of ``S``: ``int(S * ratio)`` clamped to [1, S].

        Returns 0 only when ``S`` is 0.
        """
        if S <= 0:
            return 0
        # Clamp so ratios above 1 (or below 0) never hand torch.topk an
        # out-of-range k.
        return max(1, min(S, int(S * self._ratio)))

    @torch.no_grad()
    def select_topk_indices(self, q_state: torch.Tensor) -> torch.Tensor:
        """Compute exact logits and pick top-k per KV head.

        Args:
            q_state: [1, Hq, 1, D]
        Returns:
            indices: [Hk, Tk] long, ascending along the last dim, where Tk is
            ``int(S * ratio)`` clamped to [1, S]
        """
        logits = self._kv_head_logits(q_state)  # [Hk, S]
        topk = self._topk_count(logits.shape[1])
        if topk == 0:
            return torch.empty(logits.shape[0], 0, dtype=torch.long, device=logits.device)
        _, idx = torch.topk(logits, k=topk, dim=1, largest=True)
        # Return positions in sequence order rather than score order.
        idx, _ = torch.sort(idx, dim=1)
        return idx

    @torch.no_grad()
    def gather_selected_kv(self, indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Gather selected tokens into [1, Hk, Tk, D] key and value tensors.

        Args:
            indices: [Hk, Tk] long positions into the stored sequence
        Returns:
            (k_sel, v_sel), each [1, Hk, Tk, D]
        """
        assert self.k_gpu is not None and self.v_gpu is not None, "no K/V stored yet"
        k = self.k_gpu.squeeze(0)[:, : self.seq_len, :]
        v = self.v_gpu.squeeze(0)[:, : self.seq_len, :]
        # Broadcast each position across the feature dim so gather pulls whole rows.
        idx_exp = indices.unsqueeze(-1).expand(-1, -1, self.head_dim)
        k_sel = k.gather(dim=1, index=idx_exp)
        v_sel = v.gather(dim=1, index=idx_exp)
        return k_sel.unsqueeze(0), v_sel.unsqueeze(0)



class MaskTopKRecallAttentionManager(MaskTopKAttentionManager):
    """Variant of MaskTopK that models imperfect recall of the top set.

    For each decode step and KV head, we:
      1) Identify the base top-K tokens (K = base_ratio * S, clamped to [1, S])
      2) Randomly keep `recall_ratio` fraction of these K tokens
      3) Replace the missing (1 - recall_ratio) * K tokens with the next-best
         tokens immediately after the top-K boundary

    The final number of selected tokens remains K per KV head. When fewer
    than the required number of next-best tokens exist beyond the boundary,
    the exact top-K set is returned instead.

    Big-O per step: Same logits computation as base class
      - O(Hk * S * D) for logits + O(Hk * S) for top-k, then small bookkeeping.
    """

    def __init__(self, model_config, dtype, device, *, base_sparsity_ratio: float = 0.05, recall_ratio: float = 1.0):
        """Create the manager.

        Args:
            model_config, dtype, device: forwarded to the base manager.
            base_sparsity_ratio: fraction of cached tokens selected per step.
            recall_ratio: fraction of the true top-K retained; the rest is
                substituted by the next-best tokens.
        """
        super().__init__(model_config, dtype, device, sparsity_ratio=base_sparsity_ratio)
        self._recall = float(recall_ratio)

    def set_recall_ratio(self, r: float) -> None:
        """Set the fraction of the true top-K that is retained."""
        self._recall = float(r)

    @torch.no_grad()
    def select_topk_indices_with_recall(self, q_state: torch.Tensor) -> torch.Tensor:
        """Select indices using recall-aware substitution strategy.

        Args:
            q_state: [1, Hq, 1, D]
        Returns:
            indices: [Hk, K] long, ascending along the last dim, where
            K = int(base_ratio * S) clamped to [1, S]
        """
        assert self.k_gpu is not None and self.seq_len > 0, "no K/V stored yet"
        _, Hq, _, D = q_state.shape
        Hk = self.num_kv_heads
        assert Hq % Hk == 0, "query heads must be a multiple of KV heads"

        # Reduce Q heads within each KV group to a single vector per KV head.
        # reshape (not view) so a non-contiguous query is accepted as well.
        q_flat = q_state.squeeze(0).squeeze(1)  # [Hq, D]
        q_kv = q_flat.reshape(Hk, Hq // Hk, D).mean(dim=1)  # [Hk, D]

        k_all = self.k_gpu.squeeze(0)[:, : self.seq_len, :]  # [Hk, S, D]
        logits = torch.einsum("hd,hsd->hs", q_kv / (D ** 0.5), k_all)  # [Hk, S]

        S = logits.shape[1]
        if S == 0:
            return torch.empty(Hk, 0, dtype=torch.long, device=logits.device)

        # Clamp so ratios above 1 (or below 0) never hand torch.topk an
        # out-of-range k.
        base_k = max(1, min(S, int(S * self._ratio)))
        keep_k = max(0, min(base_k, int(round(base_k * self._recall))))
        replace_k = base_k - keep_k

        # Both conditions depend only on sizes, so they are decided once for
        # all heads: with nothing to replace, or with too few tokens beyond
        # the boundary to substitute from, the result is the exact top-K set.
        # No random numbers are drawn on this path.
        if replace_k == 0 or base_k + replace_k > S:
            _, exact = torch.topk(logits, k=base_k, dim=1, largest=True)
            exact, _ = torch.sort(exact, dim=1)
            return exact

        # Top-(base_k + replace_k) in descending score order; the extra slice
        # holds the immediate next-best tokens used as substitutes.
        _, idx_desc = torch.topk(logits, k=base_k + replace_k, dim=1, largest=True, sorted=True)
        true_top = idx_desc[:, :base_k]  # [Hk, base_k]
        substitutes = idx_desc[:, base_k:]  # [Hk, replace_k]

        if keep_k > 0:
            # One independent random permutation per head, drawn in head
            # order; the first keep_k entries choose which of the true top-K
            # positions survive.
            picks = torch.stack(
                [torch.randperm(base_k, device=logits.device)[:keep_k] for _ in range(Hk)]
            )  # [Hk, keep_k]
            survivors = true_top.gather(1, picks)
            merged = torch.cat([survivors, substitutes], dim=1)
        else:
            merged = substitutes

        # Sort by time index to match gather expectations
        merged, _ = torch.sort(merged, dim=1)
        return merged

    # Expose the method name expected by llama_patch
    @torch.no_grad()
    def select_topk_indices(self, q_state: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Delegate to ``select_topk_indices_with_recall``."""
        return self.select_topk_indices_with_recall(q_state)

