"""CPU-pinned KV cache store for sparse attention retrieval.

This module provides a lightweight `CpuKVStore` that holds past keys/values
on CPU pinned memory to support efficient H2D asynchronous transfers for
selected token indices during decoding.
"""

from __future__ import annotations

import torch
from torch.cuda import Stream


class CpuKVStore:
    """Host-memory KV store backed by page-locked (pinned) buffers.

    Shapes follow the convention: K/V -> [1, Hk, S, D], where S is the time axis.
    Only the first ``len`` time steps of ``k``/``v`` hold valid data; the rest of
    the buffer up to ``cap`` is reserved, uninitialised space.

    - append_from_gpu: Append new tokens from GPU tensors to the host buffers
    - gather_cpu: Gather selected indices into contiguous (pinned) buffers
    - to_gpu_async: Transfer gathered slices to GPU via a separate CUDA stream

    When CUDA is unavailable, buffers fall back to pageable memory and no side
    stream is created, so the store remains usable (e.g. in CPU-only tests);
    ``to_gpu_async`` then degrades to a plain synchronous ``.to()``.
    """

    def __init__(self, n_kv_head: int, head_dim: int, *, dtype=torch.bfloat16, reserve_tokens: int = 8192):
        self.dtype = dtype
        self.n_kv_head = n_kv_head
        self.head_dim = head_dim
        self.cap = reserve_tokens
        self.len = 0  # number of valid time steps currently stored
        shape = (1, n_kv_head, reserve_tokens, head_dim)
        # Allocate K and V independently: torch.empty_like does not carry over
        # the pinned-memory property, so deriving V from K would leave V pageable.
        self.k = self._host_empty(shape, dtype)
        self.v = self._host_empty(shape, dtype)
        # Dedicated stream for H2D copies; None when CUDA is not available.
        self.h2d_stream: Stream | None = Stream() if torch.cuda.is_available() else None

    @staticmethod
    def _host_empty(shape: tuple[int, ...], dtype: torch.dtype) -> torch.Tensor:
        """Allocate an uninitialised host tensor, page-locked when CUDA is usable."""
        # pin_memory=True can raise on hosts without a usable CUDA runtime, so
        # only request pinning when CUDA is available.
        return torch.empty(shape, dtype=dtype, pin_memory=torch.cuda.is_available())

    def _grow(self, need: int) -> None:
        """Reallocate K/V with capacity >= ``need``, preserving the valid prefix."""
        # Geometric growth amortises reallocation cost across repeated appends.
        new_cap = max(need, self.cap * 2)
        shape = (1, self.n_kv_head, new_cap, self.head_dim)

        def grow(buf: torch.Tensor) -> torch.Tensor:
            nb = self._host_empty(shape, self.dtype)
            if self.len > 0:
                nb[..., : self.len, :].copy_(buf[..., : self.len, :])
            return nb

        self.k = grow(self.k)
        self.v = grow(self.v)
        self.cap = new_cap

    @torch.no_grad()
    def append_from_gpu(self, k_gpu: torch.Tensor, v_gpu: torch.Tensor) -> None:
        """Append new K/V time steps to the host buffers.

        Args:
            k_gpu, v_gpu: [1, Hk, t, D] tensors of identical shape, with Hk and D
                equal to the store's ``n_kv_head`` / ``head_dim``. They may live on
                any device; values are cast to the store dtype.

        Raises:
            ValueError: if the shapes differ from each other or from
                ``[1, n_kv_head, t, head_dim]`` (``copy_`` would otherwise
                silently broadcast, e.g. duplicating a single head).

        Big-O: O(Hk * t * D) element copies. The copy is synchronous, so the data
        is in the host buffers when this returns.
        """
        if k_gpu.dim() != 4 or k_gpu.shape != v_gpu.shape:
            raise ValueError(
                f"k/v must be 4-D with identical shapes; got {tuple(k_gpu.shape)} and {tuple(v_gpu.shape)}"
            )
        t = k_gpu.shape[2]
        expected = (1, self.n_kv_head, t, self.head_dim)
        if tuple(k_gpu.shape) != expected:
            raise ValueError(f"expected k/v of shape {expected}, got {tuple(k_gpu.shape)}")
        start, end = self.len, self.len + t
        if end > self.cap:
            self._grow(end)
        # copy_ performs the cross-device transfer and any dtype cast straight
        # into the destination slice (non_blocking defaults to False).
        self.k[..., start:end, :].copy_(k_gpu)
        self.v[..., start:end, :].copy_(v_gpu)
        self.len = end

    @torch.no_grad()
    def gather_cpu(self, indices_cpu: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather selected time indices for each KV head on CPU.

        Args:
            indices_cpu: [Hk, Tk] CPU integer tensor (int64 or int32); row ``h``
                lists the time steps to take from head ``h``. Every index must lie
                in ``[0, len)``.

        Returns:
            (fi_k, fi_v): [1, Hk, Tk, D] tensors, pinned when CUDA is available.

        Raises:
            IndexError: if any index lies outside the valid prefix ``[0, len)``.
        """
        Hk, Tk = indices_cpu.shape
        if indices_cpu.numel() > 0:
            lo, hi = int(indices_cpu.min()), int(indices_cpu.max())
            if lo < 0 or hi >= self.len:
                # Slots in [len, cap) are allocated but uninitialised; reading
                # them would silently return garbage instead of failing.
                raise IndexError(f"gather indices must lie in [0, {self.len}); got range [{lo}, {hi}]")
        shape = (1, Hk, Tk, self.head_dim)
        fi_k = self._host_empty(shape, self.dtype)
        fi_v = self._host_empty(shape, self.dtype)
        # Restrict the source to the valid prefix as a second line of defence.
        k_valid = self.k[0, :, : self.len]
        v_valid = self.v[0, :, : self.len]
        for h in range(Hk):
            idx = indices_cpu[h]
            # Write directly into the output slice; no per-head temporary.
            torch.index_select(k_valid[h], 0, idx, out=fi_k[0, h])
            torch.index_select(v_valid[h], 0, idx, out=fi_v[0, h])
        return fi_k, fi_v

    @torch.no_grad()
    def to_gpu_async(self, cpu_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Asynchronously copy a host tensor to ``device`` using the dedicated stream.

        The copy is issued with ``non_blocking=True`` on ``h2d_stream``, and the
        destination device's current stream is made to wait on it, so work queued
        afterwards on that stream sees the data without a host-side sync. The copy
        only overlaps with host work if ``cpu_tensor`` is pinned.

        When there is no side stream (CUDA unavailable) or ``device`` is not a CUDA
        device, this falls back to a plain ``cpu_tensor.to(device)``.
        """
        target = torch.device(device)
        if self.h2d_stream is None or target.type != "cuda":
            return cpu_tensor.to(target)
        with torch.cuda.stream(self.h2d_stream):
            out = cpu_tensor.to(target, non_blocking=True)
        consumer = torch.cuda.current_stream(out.device)
        consumer.wait_stream(self.h2d_stream)
        # `out` was allocated on h2d_stream but is consumed on `consumer`; record
        # that use so the caching allocator does not hand its memory back to
        # h2d_stream (and overwrite it) while `consumer` may still be reading it.
        out.record_stream(consumer)
        return out


