"""Index building utilities for sparse attention.

Includes k-means clustering and product quantization block construction for
per-head KV token grouping. Optimized GPU and CPU variants are provided.
"""

from __future__ import annotations

import os
import torch


def batched_kmeans_cpu(
    q_blocks: torch.Tensor,
    num_centroids: int,
    num_iters: int = 10,
    seed: int = 42,
    use_cdist: bool = True,
) -> torch.Tensor:
    """Run batched Lloyd k-means on the CPU in float32.

    Every (head, block) pair is clustered independently: the input is
    flattened to ``B = Hk * m_blocks`` problems of ``S`` points each.

    Args:
        q_blocks: Points of shape [Hk, S, m_blocks, sub_dim]. Moved to the CPU
            and converted to float32 before clustering.
        num_centroids: Number of clusters per (head, block) problem.
        num_iters: Number of assignment/update iterations.
        seed: Seed of the private CPU generator used to pick the initial
            centroids, so results are reproducible for a given input.
        use_cdist: If True, distances come from ``torch.cdist``; otherwise
            from the expanded form ``|x|^2 + |y|^2 - 2 x.y``.

    Returns:
        Contiguous float32 CPU centroids of shape
        [Hk, m_blocks, num_centroids, sub_dim].

    Time: O(num_iters * Hk * S * m_blocks * num_centroids * sub_dim)
    Space: O(Hk * S * m_blocks * (num_centroids + sub_dim)); the [B, S, C]
        distance matrix dominates.
    """
    # `.to("cpu")` is a no-op for tensors that already live on the CPU.
    q_blocks = q_blocks.to("cpu").contiguous().to(torch.float32)

    n_kv_head, S, m_blocks, sub_dim = q_blocks.shape
    points = q_blocks.permute(0, 2, 1, 3).reshape(-1, S, sub_dim)  # [B,S,D]
    B = points.shape[0]

    g = torch.Generator(device="cpu")
    g.manual_seed(seed)
    # Initial centroids are sampled WITH replacement, so duplicates can occur
    # (and must occur when num_centroids > S).
    rand_idx = torch.randint(0, S, (B, num_centroids), generator=g, device="cpu")
    batch_indices = torch.arange(B, device="cpu").unsqueeze(1)
    centroids = points[batch_indices, rand_idx].clone()  # [B,C,D]

    # Loop-invariant tensors, built once instead of once per iteration.
    flat_points = points.reshape(-1, sub_dim)  # [B*S,D]
    # Shifts per-problem cluster ids into one flat [B*C] index space.
    batch_offset = batch_indices * num_centroids  # [B,1]
    ones = torch.ones(B * S, dtype=torch.float32, device="cpu")
    x2 = None if use_cdist else (points ** 2).sum(-1, keepdim=True)

    for _ in range(num_iters):
        if use_cdist:
            distances = torch.cdist(points, centroids, p=2)  # [B,S,C]
        else:
            y2 = (centroids ** 2).sum(-1).unsqueeze(1)
            xy = points @ centroids.transpose(1, 2)
            # Clamp guards against tiny negatives caused by rounding.
            distances = (x2 + y2 - 2 * xy).clamp_min_(0.0).sqrt_()

        assignments = distances.argmin(dim=-1)  # [B,S]
        flat_indices = (assignments + batch_offset).reshape(-1)  # [B*S]

        cluster_sums = torch.zeros_like(centroids)
        counts = torch.zeros(B, num_centroids, dtype=torch.float32, device="cpu")
        cluster_sums.view(-1, sub_dim).scatter_add_(
            0, flat_indices.unsqueeze(1).expand(-1, sub_dim), flat_points
        )
        counts.view(-1).scatter_add_(0, flat_indices, ones)

        # A cluster with no assigned points keeps its previous centroid.
        # Dividing its zero sum by a clamped count would silently move it to
        # the origin, which is unrelated to the data.
        means = cluster_sums / counts.clamp_min(1.0).unsqueeze(-1)
        centroids = torch.where((counts > 0).unsqueeze(-1), means, centroids)

    return centroids.view(n_kv_head, m_blocks, num_centroids, sub_dim).contiguous()


def batched_kmeans_torch(q_blocks: torch.Tensor, num_centroids: int, num_iters: int = 7) -> torch.Tensor:
    """Run batched Lloyd k-means on the input tensor's own device and dtype.

    Every (head, block) pair is clustered independently: the input is
    flattened to ``B = Hk * m_blocks`` problems of ``S`` points each.

    Args:
        q_blocks: Points of shape [Hk, S, m_blocks, sub_dim].
        num_centroids: Number of clusters per (head, block) problem.
        num_iters: Number of assignment/update iterations.

    Returns:
        Centroids of shape [Hk, m_blocks, num_centroids, sub_dim] on the same
        device and with the same dtype as ``q_blocks``.

    Note:
        Initial centroids are drawn with the global torch RNG, so results are
        reproducible only if the caller seeds it.

    Time: O(num_iters * Hk * S * m_blocks * num_centroids * sub_dim)
    Space: O(Hk * S * m_blocks * (num_centroids + sub_dim)); the [B, S, C]
        distance matrix dominates.
    """
    n_kv_head, S, m_blocks, sub_dim = q_blocks.shape
    device = q_blocks.device
    points = q_blocks.permute(0, 2, 1, 3).reshape(-1, S, sub_dim)  # [B,S,D]
    B = points.shape[0]

    # Initial centroids are sampled WITH replacement, so duplicates can occur.
    rand_idx = torch.randint(0, S, (B, num_centroids), device=device)
    batch_indices = torch.arange(B, device=device).unsqueeze(1)
    centroids = points[batch_indices, rand_idx]  # [B,C,D]

    # Loop-invariant tensors, built once instead of once per iteration.
    flat_points = points.reshape(-1, sub_dim)  # [B*S,D]
    # Shifts per-problem cluster ids into one flat [B*C] index space.
    batch_offset = batch_indices * num_centroids  # [B,1]
    ones = torch.ones(B * S, dtype=points.dtype, device=device)

    for _ in range(num_iters):
        assignments = torch.cdist(points, centroids, p=2).argmin(dim=-1)  # [B,S]
        flat_indices = (assignments + batch_offset).reshape(-1)  # [B*S]

        cluster_sums = torch.zeros_like(centroids)
        counts = torch.zeros(B, num_centroids, dtype=points.dtype, device=device)
        cluster_sums.view(-1, sub_dim).scatter_add_(
            0, flat_indices.unsqueeze(1).expand(-1, sub_dim), flat_points
        )
        counts.view(-1).scatter_add_(0, flat_indices, ones)

        # A cluster with no assigned points keeps its previous centroid.
        # Dividing its zero sum by a clamped count would silently move it to
        # the origin, which is unrelated to the data.
        means = cluster_sums / counts.clamp(min=1.0).unsqueeze(-1)
        centroids = torch.where((counts > 0).unsqueeze(-1), means, centroids)

    return centroids.view(n_kv_head, m_blocks, num_centroids, sub_dim)


def build_pq_clusters_for_kv_heads_cpu(
    q_vectors: torch.Tensor,
    k_vectors: torch.Tensor,
    m_blocks: int = 8,
    num_centroids_per_block: int = 512,
    topk_tokens_per_centroid=None,
    max_decode_len: int = 512,
):
    """Build per-KV-head PQ centroids and the initial score table on the CPU.

    Args:
        q_vectors: Queries of shape [1, n_q_head, S, d].
        k_vectors: Keys of shape [1, n_kv_head, S, d].
        m_blocks: Number of sub-blocks the head dimension is split into;
            must divide ``d``.
        num_centroids_per_block: Number of k-means centroids per sub-block.
        topk_tokens_per_centroid: If given, only the largest scores per
            centroid (sorted descending) are stored, and the returned length
            equals this value.
        max_decode_len: Extra uninitialized slots reserved after the initial
            scores.

    Returns:
        ``(centroids, scores, length)``: float32 CPU centroids of shape
        [n_kv_head, m_blocks, num_centroids_per_block, sub_dim]; a CPU score
        buffer of shape [n_kv_head, m_blocks, num_centroids_per_block,
        length + max_decode_len] whose first ``length`` columns are filled;
        and ``length`` itself (``S`` or ``topk_tokens_per_centroid``).
    """
    q_cpu = (q_vectors.to("cpu") if q_vectors.is_cuda else q_vectors).contiguous()
    k_cpu = (k_vectors.to("cpu") if k_vectors.is_cuda else k_vectors).contiguous()

    bsz, n_q_head, S, d = q_cpu.shape
    _, n_kv_head, k_len, k_dim = k_cpu.shape
    assert bsz == 1 and S == k_len and d == k_dim
    assert d % m_blocks == 0
    assert n_q_head % n_kv_head == 0

    sub_dim = d // m_blocks
    group_size = n_q_head // n_kv_head

    # Lay the query heads of each KV group side by side on the last axis,
    # then keep only the first `d` features, i.e. the first query head of
    # every group.
    merged_q = (
        q_cpu.squeeze(0)
        .view(n_q_head, S, d)
        .view(n_kv_head, group_size, S, d)
        .permute(0, 2, 1, 3)
        .reshape(n_kv_head, S, group_size * d)
    )
    first_head_q = merged_q[:, :, :d].contiguous()

    # Split the head dimension into m_blocks chunks of sub_dim features.
    q_blocks = first_head_q.unfold(2, sub_dim, sub_dim)
    k_blocks = k_cpu.squeeze(0).contiguous().unfold(2, sub_dim, sub_dim)

    centroids = batched_kmeans_cpu(q_blocks, num_centroids=num_centroids_per_block, num_iters=10)

    # [Hk, m_blocks, C, sub_dim] @ [Hk, m_blocks, sub_dim, S] -> [Hk, m_blocks, C, S]
    keys_t = k_blocks.permute(0, 2, 3, 1).contiguous()
    scores = torch.matmul(centroids.to(torch.float32), keys_t.to(torch.float32))

    if topk_tokens_per_centroid is not None:
        scores = torch.topk(
            scores, k=topk_tokens_per_centroid, dim=-1, largest=True, sorted=True
        ).values
        S = topk_tokens_per_centroid

    # Slots past the first S columns stay uninitialized.
    capacity = int(S + max_decode_len)
    score_buffer = torch.empty(
        n_kv_head, m_blocks, num_centroids_per_block, capacity, dtype=scores.dtype, device="cpu"
    )
    score_buffer[:, :, :, :S] = scores
    return centroids.contiguous(), score_buffer.contiguous(), S


def build_pq_clusters_for_kv_heads_optimized_1346(
    q_vectors: torch.Tensor,
    k_vectors: torch.Tensor,
    m_blocks: int = 8,
    num_centroids_per_block: int = 300,
    topk_tokens_per_centroid=None,
    max_decode_len: int = 512,
):
    """Build per-KV-head PQ centroids and the initial score table in place.

    No device transfers or dtype conversions are performed: all work happens
    on the device of ``q_vectors``.

    Args:
        q_vectors: Queries of shape [1, n_q_head, S, d].
        k_vectors: Keys of shape [1, n_kv_head, S, d].
        m_blocks: Number of sub-blocks the head dimension is split into;
            must divide ``d``.
        num_centroids_per_block: Number of k-means centroids per sub-block.
        topk_tokens_per_centroid: If given, only the largest scores per
            centroid (sorted descending) are stored, and the returned length
            equals this value.
        max_decode_len: Extra uninitialized slots reserved after the initial
            scores.

    Returns:
        ``(centroids, scores_buffer, length)``: centroids of shape
        [n_kv_head, m_blocks, num_centroids_per_block, sub_dim]; a score
        buffer of shape [n_kv_head, m_blocks, num_centroids_per_block,
        length + max_decode_len] whose first ``length`` columns are filled;
        and ``length`` itself (``S`` or ``topk_tokens_per_centroid``).
    """
    bsz, n_q_head, S, d = q_vectors.shape
    _, n_kv_head, k_len, k_dim = k_vectors.shape
    device = q_vectors.device

    assert bsz == 1
    assert S == k_len
    # Same check as the CPU variant: without it, `unfold` below can silently
    # drop trailing key features when the head dims differ.
    assert d == k_dim
    assert d % m_blocks == 0
    assert n_q_head % n_kv_head == 0

    sub_dim = d // m_blocks
    n_nv_group = n_q_head // n_kv_head

    # Lay the query heads of each KV group side by side on the last axis,
    # then keep only the first `d` features, i.e. the first query head of
    # every group.
    q_grouped = q_vectors.squeeze(0).view(n_kv_head, n_nv_group, S, d)
    q_merged_for_all_kv = q_grouped.permute(0, 2, 1, 3).reshape(n_kv_head, S, n_nv_group * d)
    q_relevant_d_dim = q_merged_for_all_kv[:, :, :d]

    # Split the head dimension into m_blocks chunks of sub_dim features.
    q_blocks = q_relevant_d_dim.unfold(2, sub_dim, sub_dim)
    k_blocks = k_vectors.squeeze(0).unfold(2, sub_dim, sub_dim)

    kv_block_centroids_tensor = batched_kmeans_torch(q_blocks, num_centroids=num_centroids_per_block, num_iters=7)

    # [Hk, m_blocks, C, sub_dim] @ [Hk, m_blocks, sub_dim, S] -> [Hk, m_blocks, C, S]
    k_blocks_transposed = k_blocks.permute(0, 2, 3, 1)
    initial_kv_block_key_scores_content = torch.matmul(kv_block_centroids_tensor, k_blocks_transposed)

    if topk_tokens_per_centroid is not None:
        initial_kv_block_key_scores_content, _ = torch.topk(
            initial_kv_block_key_scores_content, k=topk_tokens_per_centroid, dim=-1, largest=True, sorted=True
        )
        S = topk_tokens_per_centroid

    # Slots past the first S columns stay uninitialized.
    total_capacity = S + max_decode_len
    kv_block_key_scores_buffer = torch.empty(
        n_kv_head,
        m_blocks,
        num_centroids_per_block,
        total_capacity,
        device=device,
        dtype=initial_kv_block_key_scores_content.dtype,
    )
    kv_block_key_scores_buffer[:, :, :, :S] = initial_kv_block_key_scores_content
    return kv_block_centroids_tensor, kv_block_key_scores_buffer, S


def append_new_k_token_to_score_tables_optimized_1346(
    kv_block_centroids: torch.Tensor,
    kv_block_key_scores_buffer: torch.Tensor,
    new_k_token: torch.Tensor,
    current_token_idx: int,
) -> None:
    """Write the centroid scores of one new K token into the score buffer.

    Args:
        kv_block_centroids: Centroids of shape
            [Hk, m_blocks, num_centroids, sub_dim].
        kv_block_key_scores_buffer: Score table, modified in place; column
            ``current_token_idx`` of its last axis is overwritten.
        new_k_token: Key of the new token; must be viewable as
            [Hk, m_blocks, sub_dim].
        current_token_idx: Column of the buffer that receives the scores.

    Time: O(Hk * m_blocks * num_centroids * sub_dim)
    """
    heads, blocks, _, block_dim = kv_block_centroids.shape
    # Split the token's features into the same sub-blocks as the centroids.
    token_blocks = new_k_token.view(heads, blocks, block_dim)
    # Dot product of every centroid with the matching sub-block of the token.
    kv_block_key_scores_buffer[:, :, :, current_token_idx] = torch.einsum(
        "hmcs,hms->hmc", kv_block_centroids, token_blocks
    )


