# -*- coding: utf-8 -*-
"""
Fused cache update kernel for Priority 2 optimization.

This kernel fuses:
1. Coefficient computation (a = 1 - alpha * slot_probs)
2. Masking operations
3. b computation (b = alpha * slot_probs * k/v)
4. Linear recurrence (sequential scan for final state only)

Since we only need the final state of the recurrence, we can use a sequential
scan instead of parallel scan, which simplifies the implementation significantly.

Two kernel designs are provided:

1. **fused_cache_update_kernel_per_sample** (Original Design):
   - Grid: (B * H * num_bucket * N_sample_per_bucket,)
   - Each program handles one (B, H, num_bucket, N_sample_per_bucket) position
   - Processes D elements in parallel
   - Tokens loaded N_sample_per_bucket times (redundant)
   - Better GPU occupancy (more programs)
   - Lower register pressure

2. **fused_cache_update_kernel_per_bucket** (Alternative Design):
   - Grid: (B * H * num_bucket,)
   - Each program handles one (B, H, num_bucket) position
   - Processes N_sample_per_bucket samples in parallel, each with D elements
   - Tokens loaded ONCE and reused across all samples (an N_sample_per_bucket-fold reduction in token loads, e.g. 32x when N_sample_per_bucket = 32)
   - Better memory efficiency (recommended for memory-bound workloads)
   - Higher register pressure (N_sample_per_bucket * D elements per program)

Use `use_per_bucket_design=True` in `fused_cache_update()` to select the alternative design.

Drawback: currently the cache is precomputed first for prefilling, so the state computation for the sparse-retrieval part may not be able to attend to past tokens that have since been overwritten in the KV cache.
"""

import torch
import triton
import triton.language as tl
from typing import Tuple, Optional

# Prefer the decorators shipped with fla; when the package is missing, bind the
# same three names to a pass-through so the decorated code below still imports.
try:
    from fla.utils import autocast_custom_fwd, autocast_custom_bwd, contiguous
except ImportError:
    def _passthrough(fn):
        """Return ``fn`` untouched (stand-in for the unavailable fla decorators)."""
        return fn

    # NOTE: these stand-ins do nothing at all - in particular ``contiguous``
    # does not make any argument contiguous.
    autocast_custom_fwd = _passthrough
    autocast_custom_bwd = _passthrough
    contiguous = _passthrough

@triton.autotune(
    configs=[triton.Config({}, num_warps=1)],
    key=["D"],
)
@triton.jit
def fused_cache_update_kernel_per_sample(
    cache_k_out,      # [B*H*num_bucket*N_sample_per_bucket, D] - final K states (written)
    cache_v_out,      # [B*H*num_bucket*N_sample_per_bucket, D] - final V states (written)
    k_gen,            # [B, H, num_bucket, max_tokens_per_bucket, D] - token keys
    v_gen,            # [B, H, num_bucket, max_tokens_per_bucket, D] - token values
    alpha,            # [B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket] - mixing coefficients
    slot_probs,       # [B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket] - slot probabilities
    initial_k,        # [B, H, num_bucket, N_sample_per_bucket, D] - starting K states
    initial_v,        # [B, H, num_bucket, N_sample_per_bucket, D] - starting V states
    valid_mask,       # [B, H, num_bucket, max_tokens_per_bucket] - non-zero marks a real token
    B: tl.constexpr,
    H: tl.constexpr,
    num_bucket: tl.constexpr,
    N_sample_per_bucket: tl.constexpr,
    max_tokens_per_bucket: tl.constexpr,
    D: tl.constexpr,
):
    """
    Sequential-scan cache update, one program per (batch, head, bucket, sample) slot.

    Starting from the slot's initial K/V row, every token of the bucket is folded
    in, in order, with

        w_t   = alpha[t, sample] * slot_probs[t, sample]
        state = (1 - w_t) * state + w_t * x_t

    For tokens whose ``valid_mask`` entry is not > 0 the decay factor is forced to
    1 and the injected term to 0.  Only the state after the last token is written.

    Grid: (B * H * num_bucket * N_sample_per_bucket,).  The program id is the
    flattened (b, h, bucket, sample) index with ``sample`` varying fastest.

    All offsets assume contiguous row-major tensors with the shapes listed on the
    parameters.  ``D`` is used directly as the ``tl.arange`` extent and no load or
    store is masked, so ``D`` must be an extent ``tl.arange`` accepts.  ``B`` is
    part of the signature but is not read by the body.
    """
    slot_id = tl.program_id(0)  # in [0, B*H*num_bucket*N_sample_per_bucket)

    # Peel the flattened id apart, fastest-varying index first.
    i_sample = slot_id % N_sample_per_bucket
    rest = slot_id // N_sample_per_bucket
    i_bucket = rest % num_bucket
    rest = rest // num_bucket
    i_h = rest % H
    i_b = rest // H

    # Linear index of the (b, h, bucket) triple; every tensor below is addressed
    # as "bucket row * row size + position inside the row".
    bucket_row = (i_b * H + i_h) * num_bucket + i_bucket

    lane = tl.arange(0, D)  # one lane per feature element

    kv_row = bucket_row * max_tokens_per_bucket * D
    # alpha and slot_probs share a layout; this sample's column offset is folded in here.
    gate_row = bucket_row * max_tokens_per_bucket * N_sample_per_bucket + i_sample
    mask_row = bucket_row * max_tokens_per_bucket

    # initial_k / initial_v and both outputs all place this slot's D-vector at
    # the same flat offset.
    state_off = (bucket_row * N_sample_per_bucket + i_sample) * D + lane

    acc_k = tl.load(initial_k + state_off)
    acc_v = tl.load(initial_v + state_off)

    # Only the final state is needed, so the tokens are simply visited in order.
    for t in range(max_tokens_per_bucket):
        token_ok = tl.load(valid_mask + mask_row + t) > 0

        gate_off = gate_row + t * N_sample_per_bucket
        weight = tl.load(alpha + gate_off) * tl.load(slot_probs + gate_off)

        # Padding tokens: decay 1, injection 0.  The decay is deliberately not clamped.
        keep = tl.where(token_ok, 1.0 - weight, 1.0)

        tok_off = kv_row + t * D + lane
        inject_k = tl.where(token_ok, weight * tl.load(k_gen + tok_off), 0.0)
        inject_v = tl.where(token_ok, weight * tl.load(v_gen + tok_off), 0.0)

        # Multiply first, then add - the order is kept explicit to match the
        # reference computation's order of operations.
        acc_k = (keep * acc_k) + inject_k
        acc_v = (keep * acc_v) + inject_v

    tl.store(cache_k_out + state_off, acc_k)
    tl.store(cache_v_out + state_off, acc_v)


@triton.autotune(
    configs=[triton.Config({}, num_warps=4)],
    key=["D"],
)
@triton.jit
def fused_cache_update_kernel_per_bucket(
    cache_k_out,      # [B*H*num_bucket*N_sample_per_bucket, D] - final K states (written)
    cache_v_out,      # [B*H*num_bucket*N_sample_per_bucket, D] - final V states (written)
    k_gen,            # [B, H, num_bucket, max_tokens_per_bucket, D] - token keys
    v_gen,            # [B, H, num_bucket, max_tokens_per_bucket, D] - token values
    alpha,            # [B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket] - mixing coefficients
    slot_probs,       # [B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket] - slot probabilities
    initial_k,        # [B, H, num_bucket, N_sample_per_bucket, D] - starting K states
    initial_v,        # [B, H, num_bucket, N_sample_per_bucket, D] - starting V states
    valid_mask,       # [B, H, num_bucket, max_tokens_per_bucket] - non-zero marks a real token
    B: tl.constexpr,
    H: tl.constexpr,
    num_bucket: tl.constexpr,
    N_sample_per_bucket: tl.constexpr,
    max_tokens_per_bucket: tl.constexpr,
    D: tl.constexpr,
):
    """
    Sequential-scan cache update, one program per (batch, head, bucket).

    Applies the same recurrence as the per-sample kernel,

        w_t   = alpha[t, sample] * slot_probs[t, sample]
        state = (1 - w_t) * state + w_t * x_t      (decay 1 / injection 0 on padding)

    but a single program carries the states of all N_sample_per_bucket samples of
    its bucket as one flat vector of N_sample_per_bucket * D lanes.  Each token's
    D key/value elements are fetched once per program with a broadcast index
    (the same D addresses repeated for every sample) instead of once per sample
    program.  Only the final states are written.

    Grid: (B * H * num_bucket,).  The program id is the flattened (b, h, bucket)
    index with ``bucket`` varying fastest.

    All offsets assume contiguous row-major tensors with the shapes listed on the
    parameters.  ``N_sample_per_bucket * D`` is used directly as the ``tl.arange``
    extent and nothing is masked, so it must be an extent ``tl.arange`` accepts.
    ``B`` is part of the signature but is not read by the body.
    """
    bucket_pid = tl.program_id(0)  # in [0, B*H*num_bucket)

    # Peel the flattened id apart, fastest-varying index first.
    i_bucket = bucket_pid % num_bucket
    rest = bucket_pid // num_bucket
    i_h = rest % H
    i_b = rest // H

    # Linear index of the (b, h, bucket) triple used for all row offsets below.
    bucket_row = (i_b * H + i_h) * num_bucket + i_bucket

    # Flat lane index over (sample, feature): lane = sample * D + feature.
    flat = tl.arange(0, N_sample_per_bucket * D)
    lane_sample = flat // D   # each sample id occupies D consecutive lanes
    lane_feat = flat % D      # 0..D-1, repeated once per sample

    # initial_k / initial_v and both outputs store a bucket's
    # [N_sample_per_bucket, D] block at the same flat offset.
    state_off = bucket_row * N_sample_per_bucket * D + flat
    acc_k = tl.load(initial_k + state_off)
    acc_v = tl.load(initial_v + state_off)

    kv_row = bucket_row * max_tokens_per_bucket * D
    gate_row = bucket_row * max_tokens_per_bucket * N_sample_per_bucket  # alpha / slot_probs
    mask_row = bucket_row * max_tokens_per_bucket

    for t in range(max_tokens_per_bucket):
        # One validity flag per token, shared by every sample lane.
        token_ok = tl.load(valid_mask + mask_row + t) > 0

        # Token features, broadcast across samples through lane_feat.
        tok_off = kv_row + t * D + lane_feat
        k_tok = tl.load(k_gen + tok_off)
        v_tok = tl.load(v_gen + tok_off)

        # Per-sample gate, broadcast across features through lane_sample.
        gate_off = gate_row + t * N_sample_per_bucket + lane_sample
        weight = tl.load(alpha + gate_off) * tl.load(slot_probs + gate_off)

        keep = tl.where(token_ok, 1.0 - weight, 1.0)
        inject_k = tl.where(token_ok, weight * k_tok, 0.0)
        inject_v = tl.where(token_ok, weight * v_tok, 0.0)

        acc_k = keep * acc_k + inject_k
        acc_v = keep * acc_v + inject_v

    tl.store(cache_k_out + state_off, acc_k)
    tl.store(cache_v_out + state_off, acc_v)


# Note: No custom backward kernel is implemented for the fused cache update.
# Since the forward kernel only computes the final state (not intermediate states),
# implementing a custom backward kernel would require recomputing intermediate states,
# which defeats the purpose of the optimization.
#
# Caveat: operations performed inside torch.autograd.Function.forward (including
# Triton kernel launches) are not recorded by autograd, so the Function wrapper
# below is forward-only as written; differentiating through it requires adding
# a backward method.


class FusedCacheUpdateFunction(torch.autograd.Function):
    """
    Forward-only ``torch.autograd.Function`` wrapper around the fused cache-update kernels.

    ``forward`` reshapes the padded inputs into per-bucket layout, launches one of
    the two sequential-scan Triton kernels defined in this module, and returns the
    final cache states.  Nothing is saved on ``ctx`` and no gradient formula is
    implemented: ``backward`` raises ``NotImplementedError`` with an explicit
    message, so attempting to differentiate through this op fails loudly instead
    of silently producing wrong gradients.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(
        ctx,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_gen_padded: torch.Tensor,
        v_gen_padded: torch.Tensor,
        slot_probs_padded: torch.Tensor,
        alpha_padded: torch.Tensor,
        bucket_id_padded: torch.Tensor,
        padded_gen_len: int,
        B: int,
        H: int,
        N_sample_per_bucket: int,
        D: int,
        num_leaf_buckets: int,
        use_per_bucket_design: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the fused cache update and return the final cache states.

        Args:
            cache_k, cache_v: initial cache states, viewed here as
                [B, H, num_leaf_buckets, N_sample_per_bucket, D].
            k_gen_padded, v_gen_padded: token keys/values, viewed here as
                [B, H, num_leaf_buckets, max_tokens_per_bucket, D].
            slot_probs_padded, alpha_padded: per-token, per-sample gates, viewed
                here as [B, H, num_leaf_buckets, max_tokens_per_bucket,
                N_sample_per_bucket].
            bucket_id_padded: per-token bucket ids; entries ``< 0`` are treated as
                padding and leave the state untouched.
            padded_gen_len: total padded token count; ``max_tokens_per_bucket`` is
                ``padded_gen_len // num_leaf_buckets``.
            B, H, N_sample_per_bucket, D, num_leaf_buckets: sizes used for the
                reshapes and passed to the kernel as compile-time constants.
            use_per_bucket_design: ``True`` launches one program per bucket,
                ``False`` (default) one program per (bucket, sample).

        Returns:
            ``(cache_k, cache_v)`` - newly allocated tensors of shape
            [B, H, cache_k.size(2), D] with the dtype and device of ``cache_k``.
        """
        device = cache_k.device
        dtype = cache_k.dtype
        cache_len_val = cache_k.size(2)
        num_bucket = num_leaf_buckets
        max_tokens_per_bucket = padded_gen_len // num_bucket if num_bucket > 0 else padded_gen_len

        # The kernels index raw memory assuming row-major layout.  Make each
        # tensor contiguous *before* viewing it: the fallback ``contiguous``
        # decorator used when fla is unavailable is a no-op, and ``.view`` on a
        # non-contiguous tensor would raise.
        token_shape = (B, H, num_bucket, max_tokens_per_bucket)
        state_shape = (B, H, num_bucket, N_sample_per_bucket, D)

        valid_mask = (bucket_id_padded >= 0).contiguous().view(*token_shape)
        k_gen = k_gen_padded.contiguous().view(*token_shape, D)
        v_gen = v_gen_padded.contiguous().view(*token_shape, D)
        slot_probs = slot_probs_padded.contiguous().view(*token_shape, N_sample_per_bucket)
        alpha = alpha_padded.contiguous().view(*token_shape, N_sample_per_bucket)
        initial_k = cache_k.contiguous().view(*state_shape)
        initial_v = cache_v.contiguous().view(*state_shape)

        # One output row per (b, h, bucket, sample); allocated exactly once.
        batch_size_scan = B * H * num_bucket * N_sample_per_bucket
        cache_k_out = torch.empty(batch_size_scan, D, device=device, dtype=dtype)
        cache_v_out = torch.empty(batch_size_scan, D, device=device, dtype=dtype)

        # Both designs run a plain sequential scan over the tokens of a bucket.
        # The two-level chunked kernels defined later in this module
        # (chunk_scan_chunk_params_kernel_v2 / chunk_scan_combine_chunks_kernel_v2)
        # are intentionally not launched here: their previous call site was
        # guarded by ``if False`` and wrote into buffers that were never returned.
        if use_per_bucket_design:
            kernel = fused_cache_update_kernel_per_bucket
            grid = (B * H * num_bucket,)
        else:
            kernel = fused_cache_update_kernel_per_sample
            grid = (batch_size_scan,)

        kernel[grid](
            cache_k_out, cache_v_out,
            k_gen, v_gen, alpha, slot_probs,
            initial_k, initial_v, valid_mask,
            B=B, H=H, num_bucket=num_bucket,
            N_sample_per_bucket=N_sample_per_bucket,
            max_tokens_per_bucket=max_tokens_per_bucket,
            D=D,
        )

        # Rows are ordered (b, h, bucket, sample), i.e. already in cache order.
        cache_k_result = cache_k_out.view(B, H, cache_len_val, D)
        cache_v_result = cache_v_out.view(B, H, cache_len_val, D)
        return cache_k_result, cache_v_result

    @staticmethod
    def backward(ctx, grad_cache_k, grad_cache_v):
        """
        Gradients are not implemented for the fused cache update.

        Raises:
            NotImplementedError: always.  The forward pass keeps only the final
                recurrence state and saves nothing for backward.
        """
        raise NotImplementedError(
            "FusedCacheUpdateFunction does not implement a backward pass: the fused "
            "kernels only produce the final recurrence state and save nothing for "
            "gradient computation. Call it under torch.no_grad() or use a "
            "differentiable reference implementation."
        )


def fused_cache_update(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    k_gen_padded: torch.Tensor,
    v_gen_padded: torch.Tensor,
    slot_probs_padded: torch.Tensor,
    alpha_padded: torch.Tensor,
    bucket_id_padded: torch.Tensor,
    padded_gen_len: int,
    B: int,
    H: int,
    N_sample_per_bucket: int,
    D: int,
    num_leaf_buckets: int,
    use_per_bucket_design: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Update the KV cache with the fused Triton kernels.

    Thin functional entry point: every argument is forwarded unchanged, and in
    the same order, to ``FusedCacheUpdateFunction.apply``.

    Args:
        cache_k: [B, H, cache_len, D] - cache keys (initial states)
        cache_v: [B, H, cache_len, D] - cache values (initial states)
        k_gen_padded: [B, H, padded_gen_len, D] - generated keys
        v_gen_padded: [B, H, padded_gen_len, D] - generated values
        slot_probs_padded: [B, H, padded_gen_len, N_sample_per_bucket] - slot probabilities
        alpha_padded: [B, H, padded_gen_len, N_sample_per_bucket] - mixing coefficients
        bucket_id_padded: [B, H, padded_gen_len] - bucket id of each token
        padded_gen_len: total padded generation length
        B: batch size
        H: number of heads
        N_sample_per_bucket: number of samples per bucket
        D: feature dimension
        num_leaf_buckets: number of leaf buckets (``num_bucket`` in the kernels)
        use_per_bucket_design: selects the kernel design (default ``False``).

    Returns:
        ``(cache_k, cache_v)`` - updated cache keys and values, each
        [B, H, cache_len, D].

    Kernel designs:
        use_per_bucket_design=False (one program per sample)
            * Grid: (B * H * num_bucket * N_sample_per_bucket,)
            * More, smaller programs; each keeps D state elements
            * Every token is loaded once per sample program
        use_per_bucket_design=True (one program per bucket)
            * Grid: (B * H * num_bucket,)
            * Each token is loaded once per bucket program and shared by all
              of that bucket's samples
            * Each program keeps N_sample_per_bucket * D state elements
    """
    # Collected in the exact positional order that forward() declares.
    forward_args = (
        cache_k,
        cache_v,
        k_gen_padded,
        v_gen_padded,
        slot_probs_padded,
        alpha_padded,
        bucket_id_padded,
        padded_gen_len,
        B,
        H,
        N_sample_per_bucket,
        D,
        num_leaf_buckets,
        use_per_bucket_design,
    )
    return FusedCacheUpdateFunction.apply(*forward_args)
# ============================================================
# Final-state-only two-level scan (v2): affine params per chunk
#   Level 1: for each chunk compute (A_scalar, B_k, B_v) such that
#       h_end = A_scalar * h_start + B
#   Level 2: apply chunk affines sequentially over chunks to get FINAL state.
#
# Notes:
# - No per-token outputs (avoids allocating o_k/o_v).
# - A is stored as scalar per (pid, chunk) (not replicated across D).
# - Scalar loads for alpha/slot_probs/valid_mask: do NOT pass block masks.
# ============================================================

@triton.autotune(
    configs=[
        # Only the 2-warp configuration is active; the commented entries are
        # disabled candidates.
        # triton.Config({}, num_warps=1, ),
        triton.Config({}, num_warps=2, ),
        # triton.Config({}, num_warps=4, ),
        # triton.Config({}, num_warps=8, ),
    ],
    key=['D'],
)
@triton.jit
def chunk_scan_chunk_params_kernel_v2(
    k_gen, v_gen,
    alpha, slot_probs,
    valid_mask,
    A_scalar,     # [batch_size_scan, num_chunks] - written with tl.store, no explicit cast
    B_k, B_v,     # [batch_size_scan, num_chunks, D] - cast to the buffer's element type on store
    T: tl.constexpr, D: tl.constexpr,
    BT: tl.constexpr, BD: tl.constexpr,
    B: tl.constexpr, H: tl.constexpr, num_bucket: tl.constexpr,
    N_sample_per_bucket: tl.constexpr,
):
    """
    Level 1 of the two-level final-state scan: per-chunk affine parameters.

    For one scan sequence ``pid`` (a flattened (b, h, bucket, sample) index), one
    chunk of at most ``BT`` tokens and one block of ``BD`` feature elements, the
    recurrence ``h <- a_t * h + b_t`` is run from a zero state over the chunk's
    tokens, while the running product of the ``a_t`` is tracked.  The results are
    the parameters of ``h_end = A * h_start + B`` for that chunk:

    * ``A`` (scalar) is stored at ``A_scalar[pid, chunk]``;
    * ``B`` (vector) is stored at ``B_k[pid, chunk, :]`` / ``B_v[pid, chunk, :]``.

    Here ``a_t = 1 - alpha_t * slot_prob_t`` and ``b_t = alpha_t * slot_prob_t * x_t``
    for tokens whose ``valid_mask`` entry is > 0, and ``a_t = 1``, ``b_t = 0``
    otherwise.  All arithmetic is done in fp32.

    Grid: (ceil(D / BD), num_chunks, batch_size_scan).

    Offsets assume contiguous row-major inputs: ``k_gen``/``v_gen`` as
    [B, H, num_bucket, T, D], ``alpha``/``slot_probs`` as
    [B, H, num_bucket, T, N_sample_per_bucket] and ``valid_mask`` as
    [B, H, num_bucket, T].  ``B`` is part of the signature but is not read by
    the body.
    """
    # Grid: (ceil(D/BD), num_chunks, batch_size_scan)
    pid_d = tl.program_id(0)
    pid_chunk = tl.program_id(1)
    pid = tl.program_id(2)   # flattened (b,h,bucket,sample)

    # Decode flattened id (sample varies fastest, then bucket, head, batch)
    i_sample = pid % N_sample_per_bucket
    tmp = pid // N_sample_per_bucket
    i_bucket = tmp % num_bucket
    tmp //= num_bucket
    i_h = tmp % H
    i_b = tmp // H

    # Feature elements handled by this program; mask_d guards the tail when
    # D is not a multiple of BD.
    o_d = pid_d * BD + tl.arange(0, BD)
    mask_d = o_d < D

    # Base offsets of the first token of this chunk in each input tensor
    k_base = (i_b * H * num_bucket * T * D +
              i_h * num_bucket * T * D +
              i_bucket * T * D +
              pid_chunk * BT * D)
    a_base = (i_b * H * num_bucket * T * N_sample_per_bucket +
              i_h * num_bucket * T * N_sample_per_bucket +
              i_bucket * T * N_sample_per_bucket +
              pid_chunk * BT * N_sample_per_bucket)
    vld_base = (i_b * H * num_bucket * T +
                i_h * num_bucket * T +
                i_bucket * T +
                pid_chunk * BT)

    num_chunks = (T + BT - 1) // BT

    # Output base for B vectors: row (pid, pid_chunk) of a [*, num_chunks, D] buffer
    out_vec_base = (pid * num_chunks + pid_chunk) * D + o_d

    # Local recurrence within the chunk, starting from zero state
    h_k = tl.zeros([BD], dtype=tl.float32)
    h_v = tl.zeros([BD], dtype=tl.float32)
    A = tl.full([], 1.0, tl.float32)  # scalar running product of the a_t

    # Running pointers; each is advanced by one token at the end of every iteration.
    p_k = k_gen + k_base + o_d
    p_v = v_gen + k_base + o_d   # v_gen has the same layout as k_gen
    p_alpha = alpha + a_base + i_sample
    p_slot  = slot_probs + a_base + i_sample
    p_valid = valid_mask + vld_base

    # Number of tokens in this chunk; the last chunk may hold fewer than BT.
    BC = tl.minimum(BT, T - pid_chunk * BT)
    for _ in range(BC):
        is_valid = tl.load(p_valid) > 0

        k_vec = tl.load(p_k, mask=mask_d, other=0).to(tl.float32)
        v_vec = tl.load(p_v, mask=mask_d, other=0).to(tl.float32)

        # scalar loads (no block masks!)
        alpha_val = tl.load(p_alpha).to(tl.float32)
        slot_val  = tl.load(p_slot).to(tl.float32)

        # Padding tokens use a = 1 and b = 0.
        alpha_slot = alpha_val * slot_val
        a = tl.where(is_valid, 1.0 - alpha_slot, 1.0)
        b_scale = tl.where(is_valid, alpha_slot, 0.0)

        h_k = a * h_k + b_scale * k_vec
        h_v = a * h_v + b_scale * v_vec
        A = A * a

        # Step every pointer to the next token.
        p_alpha += N_sample_per_bucket
        p_slot  += N_sample_per_bucket
        p_k += D
        p_v += D
        p_valid += 1

    # Store B vectors (the zero-start end state of this chunk)
    tl.store(B_k + out_vec_base, h_k.to(B_k.dtype.element_ty), mask=mask_d)
    tl.store(B_v + out_vec_base, h_v.to(B_v.dtype.element_ty), mask=mask_d)

    # Store scalar A once per (pid, chunk): every pid_d program computes the same
    # value, so only the pid_d == 0 program writes it.
    if pid_d == 0:
        tl.store(A_scalar + pid * num_chunks + pid_chunk, A)


@triton.autotune(
    configs=[
        # triton.Config({}, num_warps=1, ),
        triton.Config({}, num_warps=2, ),
        # triton.Config({}, num_warps=4, ),
        # triton.Config({}, num_warps=8, ),
    ],
    key=['D'], #, 'T'],
)
@triton.jit
def chunk_scan_combine_chunks_kernel_v2(
    initial_k, initial_v,     # [B,H,num_bucket,N_sample,D]
    A_scalar,                 # [batch_size_scan, num_chunks]
    B_k, B_v,                 # [batch_size_scan, num_chunks, D]
    out_final_k, out_final_v, # [batch_size_scan, D]
    T: tl.constexpr, D: tl.constexpr,
    BT: tl.constexpr, BD: tl.constexpr,
    B: tl.constexpr, H: tl.constexpr, num_bucket: tl.constexpr,
    N_sample_per_bucket: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
):
    """Fold per-chunk (A, B) summaries into the final recurrence state.

    For every scan row ``pid`` and every block of ``BD`` feature lanes, the
    kernel starts from the initial state (or zeros when
    ``USE_INITIAL_STATE`` is false) and applies ``h <- A_c * h + B_c`` for
    ``c = 0 .. num_chunks - 1`` where ``num_chunks = ceil(T / BT)``. The
    accumulation is done in float32 and the result is cast to the output
    tensor's element type on store.

    Grid: ``(ceil(D / BD), batch_size_scan)``.

    All tensors are addressed with flat row-major offsets, i.e. they are
    assumed to be contiguous with the shapes noted in the signature.

    ``B``, ``H``, ``num_bucket`` and ``N_sample_per_bucket`` are kept in the
    signature for caller compatibility; they are not needed for addressing
    because ``pid`` is already the row-major flattening of
    ``(b, h, bucket, sample)``, so the row offset into
    ``initial_k`` / ``initial_v`` is simply ``pid * D``.
    """
    # Grid: (ceil(D/BD), batch_size_scan)
    pid_d = tl.program_id(0)
    pid = tl.program_id(1)

    # Promote the row index to int64 before building pointer offsets:
    # pid * num_chunks * D is computed in int32 otherwise and can wrap
    # around for large batch_size_scan * num_chunks * D.
    row = pid.to(tl.int64)

    o_d = pid_d * BD + tl.arange(0, BD)
    mask_d = o_d < D

    # Offset of this row's feature block inside a [batch_size_scan, D] view.
    # Decoding pid into (i_b, i_h, i_bucket, i_sample) and recombining it
    # with the [B,H,num_bucket,N_sample,D] strides yields exactly pid * D,
    # so the same offset addresses both the initial state and the output.
    row_off = row * D + o_d

    h_k = tl.zeros([BD], dtype=tl.float32)
    h_v = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        h_k = tl.load(initial_k + row_off, mask=mask_d, other=0).to(tl.float32)
        h_v = tl.load(initial_v + row_off, mask=mask_d, other=0).to(tl.float32)

    num_chunks = (T + BT - 1) // BT
    base_a = row * num_chunks          # first A entry of this row
    base_vec = base_a * D + o_d        # first B vector of this row

    # Combine chunks in order: h <- A_c * h + B_c
    for c in range(num_chunks):
        a = tl.load(A_scalar + base_a + c).to(tl.float32)  # scalar decay of chunk c
        b_k = tl.load(B_k + base_vec + c * D, mask=mask_d, other=0).to(tl.float32)
        b_v = tl.load(B_v + base_vec + c * D, mask=mask_d, other=0).to(tl.float32)
        h_k = a * h_k + b_k
        h_v = a * h_v + b_v

    tl.store(out_final_k + row_off, h_k.to(out_final_k.dtype.element_ty), mask=mask_d)
    tl.store(out_final_v + row_off, h_v.to(out_final_v.dtype.element_ty), mask=mask_d)




# ============================================================================
# Reference + tests (uses v2 FINAL-STATE chunk scan kernels)
# ============================================================================

# import time


# def reference_sequential_scan(
#     k_gen: torch.Tensor,
#     v_gen: torch.Tensor,
#     alpha: torch.Tensor,
#     slot_probs: torch.Tensor,
#     valid_mask: torch.Tensor,
#     initial_k: torch.Tensor,
#     initial_v: torch.Tensor,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     """Reference implementation using sequential scan in PyTorch."""
#     B, H, num_bucket, max_tokens_per_bucket, D = k_gen.shape
#     N_sample_per_bucket = alpha.shape[-1]

#     state_k = initial_k.clone()
#     state_v = initial_v.clone()

#     for t in range(max_tokens_per_bucket):
#         is_valid = valid_mask[:, :, :, t].unsqueeze(-1).unsqueeze(-1)  # [B,H,bucket,1,1]
#         alpha_t = alpha[:, :, :, t, :].unsqueeze(-1)                   # [B,H,bucket,S,1]
#         probs_t = slot_probs[:, :, :, t, :].unsqueeze(-1)              # [B,H,bucket,S,1]

#         alpha_slot = alpha_t * probs_t
#         k_t = k_gen[:, :, :, t, :].unsqueeze(-2)  # [B,H,bucket,1,D]
#         v_t = v_gen[:, :, :, t, :].unsqueeze(-2)

#         a = torch.where(is_valid, 1.0 - alpha_slot, 1.0)
#         b_k = torch.where(is_valid, alpha_slot * k_t, 0.0)
#         b_v = torch.where(is_valid, alpha_slot * v_t, 0.0)

#         state_k = a * state_k + b_k
#         state_v = a * state_v + b_v

#     return state_k, state_v


# def test_hierarchical_scan_correctness():
#     """Test that v2 hierarchical scan produces the same final state as sequential scan."""
#     print("=" * 80)
#     print("Testing Two-Level Hierarchical Scan Correctness (v2 final-state kernels)")
#     print("=" * 80)

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     dtype = torch.bfloat16

#     test_cases = [
#         {"B": 1, "H": 2, "num_bucket": 32, "N_sample_per_bucket": 32, "D": 64, "max_tokens": 128},
#     ]

#     for case_idx, case in enumerate(test_cases):
#         print(f"\nTest Case {case_idx + 1}: {case}")

#         B = case["B"]
#         H = case["H"]
#         num_bucket = case["num_bucket"]
#         N_sample_per_bucket = case["N_sample_per_bucket"]
#         D = case["D"]
#         max_tokens_per_bucket = case["max_tokens"]

#         torch.manual_seed(42 + case_idx)
#         k_gen = torch.randn(B, H, num_bucket, max_tokens_per_bucket, D, device=device, dtype=dtype)
#         v_gen = torch.randn(B, H, num_bucket, max_tokens_per_bucket, D, device=device, dtype=dtype)
#         alpha = torch.sigmoid(torch.randn(B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket, device=device, dtype=dtype))
#         slot_probs = torch.softmax(torch.randn(B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket, device=device, dtype=dtype), dim=-1)
#         valid_mask = torch.ones(B, H, num_bucket, max_tokens_per_bucket, device=device, dtype=torch.bool)
#         initial_k = torch.randn(B, H, num_bucket, N_sample_per_bucket, D, device=device, dtype=dtype)
#         initial_v = torch.randn(B, H, num_bucket, N_sample_per_bucket, D, device=device, dtype=dtype)

#         if max_tokens_per_bucket > 10:
#             valid_mask[:, :, :, -5:] = False

#         ref_k, ref_v = reference_sequential_scan(k_gen, v_gen, alpha, slot_probs, valid_mask, initial_k, initial_v)

#         # v2 hierarchical (final-state only)
#         BT = 512
#         BD = min(128, triton.next_power_of_2(D))
#         num_chunks = (max_tokens_per_bucket + BT - 1) // BT
#         batch_size_scan = B * H * num_bucket * N_sample_per_bucket

#         if max_tokens_per_bucket > BT and num_chunks > 1:
#             A_scalar = torch.empty(batch_size_scan, num_chunks, device=device, dtype=torch.bfloat16)
#             B_k = torch.empty(batch_size_scan, num_chunks, D, device=device, dtype=torch.bfloat16)
#             B_v = torch.empty(batch_size_scan, num_chunks, D, device=device, dtype=torch.bfloat16)
#             out_final_k = torch.empty(batch_size_scan, D, device=device, dtype=torch.bfloat16)
#             out_final_v = torch.empty(batch_size_scan, D, device=device, dtype=torch.bfloat16)

#             def grid_h(meta):
#                 return (triton.cdiv(D, BD), num_chunks, batch_size_scan)

#             chunk_scan_chunk_params_kernel_v2[grid_h](
#                 k_gen, v_gen, alpha, slot_probs, valid_mask,
#                 A_scalar, B_k, B_v,
#                 T=max_tokens_per_bucket, D=D, BT=BT, BD=BD,
#                 B=B, H=H, num_bucket=num_bucket,
#                 N_sample_per_bucket=N_sample_per_bucket,
#             )

#             def grid_o(meta):
#                 return (triton.cdiv(D, BD), batch_size_scan)

#             chunk_scan_combine_chunks_kernel_v2[grid_o](
#                 initial_k, initial_v,
#                 A_scalar, B_k, B_v,
#                 out_final_k, out_final_v,
#                 T=max_tokens_per_bucket, D=D, BT=BT, BD=BD,
#                 B=B, H=H, num_bucket=num_bucket,
#                 N_sample_per_bucket=N_sample_per_bucket,
#                 USE_INITIAL_STATE=True,
#             )

#             hier_k = out_final_k.view(B, H, num_bucket, N_sample_per_bucket, D)
#             hier_v = out_final_v.view(B, H, num_bucket, N_sample_per_bucket, D)
#         else:
#             print(f"  Note: Sequence length {max_tokens_per_bucket} <= BT={BT}, using per-sample sequential kernel")
#             cache_k_out = torch.empty(batch_size_scan, D, device=device, dtype=dtype)
#             cache_v_out = torch.empty(batch_size_scan, D, device=device, dtype=dtype)
#             grid = (batch_size_scan,)
#             fused_cache_update_kernel_per_sample[grid](
#                 cache_k_out, cache_v_out,
#                 k_gen, v_gen, alpha, slot_probs,
#                 initial_k, initial_v, valid_mask,
#                 B=B, H=H, num_bucket=num_bucket,
#                 N_sample_per_bucket=N_sample_per_bucket,
#                 max_tokens_per_bucket=max_tokens_per_bucket,
#                 D=D,
#             )
#             hier_k = cache_k_out.view(B, H, num_bucket, N_sample_per_bucket, D)
#             hier_v = cache_v_out.view(B, H, num_bucket, N_sample_per_bucket, D)

#         # Per-bucket kernel (optional comparison)
#         cache_k_bucket = torch.empty(batch_size_scan, D, device=device, dtype=dtype)
#         cache_v_bucket = torch.empty(batch_size_scan, D, device=device, dtype=dtype)

#         def grid_bucket(meta):
#             # Handle both autotune config (has BD/BS) and bound_args (might not have them)
#             BD = meta.get("BD", 64)  # Default to 64 if not in meta
#             BS = meta.get("BS", 16)  # Default to 16 if not in meta
#             return (triton.cdiv(D, BD),
#                     triton.cdiv(N_sample_per_bucket, BS),
#                     B * H * num_bucket)

#         fused_cache_update_kernel_per_bucket[grid_bucket](
#             cache_k_bucket, cache_v_bucket,
#             k_gen, v_gen, alpha, slot_probs,
#             initial_k, initial_v, valid_mask,
#             B=B, H=H, num_bucket=num_bucket,
#             N_sample_per_bucket=N_sample_per_bucket,
#             max_tokens_per_bucket=max_tokens_per_bucket,
#             D=D,
#         )

#         bucket_k = cache_k_bucket.view(B, H, num_bucket, N_sample_per_bucket, D)
#         bucket_v = cache_v_bucket.view(B, H, num_bucket, N_sample_per_bucket, D)

#         print("\n  === v2 Hierarchical vs Reference ===")
#         print(f"  Max diff (K): {(hier_k - ref_k).abs().max().item():.6e}")
#         print(f"  Max diff (V): {(hier_v - ref_v).abs().max().item():.6e}")
#         print(f"  Mean diff (K): {(hier_k - ref_k).abs().mean().item():.6e}")
#         print(f"  Mean diff (V): {(hier_v - ref_v).abs().mean().item():.6e}")

#         print("\n  === Per-Bucket vs Reference ===")
#         print(f"  Max diff (K): {(bucket_k - ref_k).abs().max().item():.6e}")
#         print(f"  Max diff (V): {(bucket_v - ref_v).abs().max().item():.6e}")
#         print(f"  Mean diff (K): {(bucket_k - ref_k).abs().mean().item():.6e}")
#         print(f"  Mean diff (V): {(bucket_v - ref_v).abs().mean().item():.6e}")

#         rtol, atol = 1e-4, 1e-5
#         hier_pass = torch.allclose(hier_k, ref_k, rtol=rtol, atol=atol) and torch.allclose(hier_v, ref_v, rtol=rtol, atol=atol)
#         bucket_pass = torch.allclose(bucket_k, ref_k, rtol=rtol, atol=atol) and torch.allclose(bucket_v, ref_v, rtol=rtol, atol=atol)
#         hier_bucket_match = torch.allclose(hier_k, bucket_k, rtol=rtol, atol=atol) and torch.allclose(hier_v, bucket_v, rtol=rtol, atol=atol)

#         print("\n  === Overall Results ===")
#         print(f"  {'✓' if hier_pass else '✗'} v2 Hierarchical {'PASS' if hier_pass else 'FAIL'} (rtol={rtol}, atol={atol})")
#         print(f"  {'✓' if bucket_pass else '✗'} Per-Bucket {'PASS' if bucket_pass else 'FAIL'} (rtol={rtol}, atol={atol})")
#         print(f"  {'✓' if hier_bucket_match else '✗'} Hierarchical vs Per-Bucket {'MATCH' if hier_bucket_match else 'MISMATCH'} (rtol={rtol}, atol={atol})")

#     print("\n" + "=" * 80)


# def compare_sequential_vs_hierarchical_speed():
#     """Compare speed: per-sample sequential vs per-bucket vs v2 hierarchical final-state scan."""
#     print("=" * 80)
#     print("Comparing Speed: Sequential (per-sample) vs Per-bucket vs Hierarchical (v2 final-state)")
#     print("=" * 80)

#     if not torch.cuda.is_available():
#         print("CUDA not available, skipping speed comparison")
#         return

#     device = torch.device("cuda")
#     dtype = torch.bfloat16

#     test_cases = [
#         {"B": 1, "H": 8, "num_bucket": 64, "N_sample_per_bucket": 32, "D": 128, "max_tokens": 7417},
#     ]

#     num_warmup = 10
#     num_iterations = 100

#     for case_idx, case in enumerate(test_cases):
#         print(f"\nTest Case {case_idx + 1}: {case}")

#         B = case["B"]
#         H = case["H"]
#         num_bucket = case["num_bucket"]
#         N_sample_per_bucket = case["N_sample_per_bucket"]
#         D = case["D"]
#         max_tokens_per_bucket = case["max_tokens"]

#         torch.manual_seed(42 + case_idx)
#         k_gen = torch.randn(B, H, num_bucket, max_tokens_per_bucket, D, device=device, dtype=dtype)
#         v_gen = torch.randn(B, H, num_bucket, max_tokens_per_bucket, D, device=device, dtype=dtype)
#         alpha = torch.sigmoid(torch.randn(B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket, device=device, dtype=dtype))
#         slot_probs = torch.softmax(torch.randn(B, H, num_bucket, max_tokens_per_bucket, N_sample_per_bucket, device=device, dtype=dtype), dim=-1)
#         valid_mask = torch.ones(B, H, num_bucket, max_tokens_per_bucket, device=device, dtype=torch.bool)
#         initial_k = torch.randn(B, H, num_bucket, N_sample_per_bucket, D, device=device, dtype=dtype)
#         initial_v = torch.randn(B, H, num_bucket, N_sample_per_bucket, D, device=device, dtype=dtype)

#         batch_size_scan = B * H * num_bucket * N_sample_per_bucket

#         # ---- Sequential per-sample ----
#         cache_k_seq = torch.empty(batch_size_scan, D, device=device, dtype=dtype)
#         cache_v_seq = torch.empty(batch_size_scan, D, device=device, dtype=dtype)
#         grid_seq = (batch_size_scan,)

#         for _ in range(num_warmup):
#             fused_cache_update_kernel_per_sample[grid_seq](
#                 cache_k_seq, cache_v_seq,
#                 k_gen, v_gen, alpha, slot_probs,
#                 initial_k, initial_v, valid_mask,
#                 B=B, H=H, num_bucket=num_bucket,
#                 N_sample_per_bucket=N_sample_per_bucket,
#                 max_tokens_per_bucket=max_tokens_per_bucket,
#                 D=D,
#             )
#         torch.cuda.synchronize()
#         t0 = time.perf_counter()
#         for _ in range(num_iterations):
#             fused_cache_update_kernel_per_sample[grid_seq](
#                 cache_k_seq, cache_v_seq,
#                 k_gen, v_gen, alpha, slot_probs,
#                 initial_k, initial_v, valid_mask,
#                 B=B, H=H, num_bucket=num_bucket,
#                 N_sample_per_bucket=N_sample_per_bucket,
#                 max_tokens_per_bucket=max_tokens_per_bucket,
#                 D=D,
#             )
#         torch.cuda.synchronize()
#         seq_time = (time.perf_counter() - t0) / num_iterations * 1000

#         # ---- Per-bucket design ----
#         cache_k_bucket = torch.empty(batch_size_scan, D, device=device, dtype=dtype)
#         cache_v_bucket = torch.empty(batch_size_scan, D, device=device, dtype=dtype)

#         def grid_bucket(meta):
#             # Handle both autotune config (has BD/BS) and bound_args (might not have them)
#             BD = meta.get("BD", 128)  # Default to 128 if not in meta
#             BS = meta.get("BS", 16)  # Default to 16 if not in meta
#             return (triton.cdiv(D, BD),
#                     triton.cdiv(N_sample_per_bucket, BS),
#                     B * H * num_bucket)

#         for _ in range(num_warmup):
#             fused_cache_update_kernel_per_bucket[grid_bucket](
#                 cache_k_bucket, cache_v_bucket,
#                 k_gen, v_gen, alpha, slot_probs,
#                 initial_k, initial_v, valid_mask,
#                 B=B, H=H, num_bucket=num_bucket,
#                 N_sample_per_bucket=N_sample_per_bucket,
#                 max_tokens_per_bucket=max_tokens_per_bucket,
#                 D=D,
#             )
#         torch.cuda.synchronize()
#         t0 = time.perf_counter()
#         for _ in range(num_iterations):
#             fused_cache_update_kernel_per_bucket[grid_bucket](
#                 cache_k_bucket, cache_v_bucket,
#                 k_gen, v_gen, alpha, slot_probs,
#                 initial_k, initial_v, valid_mask,
#                 B=B, H=H, num_bucket=num_bucket,
#                 N_sample_per_bucket=N_sample_per_bucket,
#                 max_tokens_per_bucket=max_tokens_per_bucket,
#                 D=D,
#             )
#         torch.cuda.synchronize()
#         bucket_time = (time.perf_counter() - t0) / num_iterations * 1000

#         # ---- v2 hierarchical final-state scan ----
#         BT = 512
#         BD = min(128, triton.next_power_of_2(D))
#         num_chunks = (max_tokens_per_bucket + BT - 1) // BT

#         if max_tokens_per_bucket > BT and num_chunks > 1:
#             A_scalar = torch.empty(batch_size_scan, num_chunks, device=device, dtype=torch.bfloat16)
#             Bk = torch.empty(batch_size_scan, num_chunks, D, device=device, dtype=torch.bfloat16)
#             Bv = torch.empty(batch_size_scan, num_chunks, D, device=device, dtype=torch.bfloat16)
#             out_final_k = torch.empty(batch_size_scan, D, device=device, dtype=torch.bfloat16)
#             out_final_v = torch.empty(batch_size_scan, D, device=device, dtype=torch.bfloat16)

#             def grid_h(meta):
#                 return (triton.cdiv(D, BD), num_chunks, batch_size_scan)

#             def grid_o(meta):
#                 return (triton.cdiv(D, BD), batch_size_scan)

#             for _ in range(num_warmup):
#                 chunk_scan_chunk_params_kernel_v2[grid_h](
#                     k_gen, v_gen, alpha, slot_probs, valid_mask,
#                     A_scalar, Bk, Bv,
#                     T=max_tokens_per_bucket, D=D, BT=BT, BD=BD,
#                     B=B, H=H, num_bucket=num_bucket,
#                     N_sample_per_bucket=N_sample_per_bucket,
#                 )
#                 chunk_scan_combine_chunks_kernel_v2[grid_o](
#                     initial_k, initial_v,
#                     A_scalar, Bk, Bv,
#                     out_final_k, out_final_v,
#                     T=max_tokens_per_bucket, D=D, BT=BT, BD=BD,
#                     B=B, H=H, num_bucket=num_bucket,
#                     N_sample_per_bucket=N_sample_per_bucket,
#                     USE_INITIAL_STATE=True,
#                 )
#             torch.cuda.synchronize()
#             t0 = time.perf_counter()
#             for _ in range(num_iterations):
#                 chunk_scan_chunk_params_kernel_v2[grid_h](
#                     k_gen, v_gen, alpha, slot_probs, valid_mask,
#                     A_scalar, Bk, Bv,
#                     T=max_tokens_per_bucket, D=D, BT=BT, BD=BD,
#                     B=B, H=H, num_bucket=num_bucket,
#                     N_sample_per_bucket=N_sample_per_bucket,
#                 )
#                 chunk_scan_combine_chunks_kernel_v2[grid_o](
#                     initial_k, initial_v,
#                     A_scalar, Bk, Bv,
#                     out_final_k, out_final_v,
#                     T=max_tokens_per_bucket, D=D, BT=BT, BD=BD,
#                     B=B, H=H, num_bucket=num_bucket,
#                     N_sample_per_bucket=N_sample_per_bucket,
#                     USE_INITIAL_STATE=True,
#                 )
#             torch.cuda.synchronize()
#             hier_time = (time.perf_counter() - t0) / num_iterations * 1000

#             print(f"  Sequential (per-sample): {seq_time:.4f} ms")
#             print(f"  Per-bucket:              {bucket_time:.4f} ms")
#             print(f"  Hierarchical v2:         {hier_time:.4f} ms")
#             print(f"  Speedup vs sequential:   Per-bucket={seq_time / bucket_time:.2f}x, Hier-v2={seq_time / hier_time:.2f}x")
#             print(f"  Speedup vs per-bucket:   Hier-v2={bucket_time / hier_time:.2f}x")
#             print(f"  Num chunks: {num_chunks}")
#         else:
#             print(f"  Note: max_tokens_per_bucket={max_tokens_per_bucket} <= BT={BT}; hierarchical not applicable")
#             print(f"  Sequential (per-sample): {seq_time:.4f} ms")
#             print(f"  Per-bucket:              {bucket_time:.4f} ms")
#             print(f"  Speedup (per-bucket vs sequential): {seq_time / bucket_time:.2f}x")

#     print("\n" + "=" * 80)


# if __name__ == "__main__":
#     test_hierarchical_scan_correctness()
#     print("\n")
#     compare_sequential_vs_hierarchical_speed()