# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import warnings
from typing import Optional, Union

import torch
import triton
import triton.language as tl
from einops import rearrange

import sys
import os

# Import fla modules
# Note: If you get a ValueError about 'bitnet' being already used by Transformers,
# this is because fla.models conflicts with transformers. You may need to modify
# fla/models/bitnet/__init__.py to use exist_ok=True or rename the model type.
from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils import mean_pooling
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous

from ops.utils import _bitonic_merge

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning,
    )
    # Bind BOTH optional entry points so callers can test `is None` instead of
    # hitting a NameError on the varlen variant.
    flash_attn_func = None
    flash_attn_varlen_func = None


@triton.jit
def _sum_combine_fp32(a, b):
    """Binary combine helper: return the element-wise sum ``a + b``."""
    total = a + b
    return total


@triton.jit
def max_combine(a, b):
    """Binary combine helper: return the element-wise maximum of ``a`` and ``b``."""
    larger = tl.maximum(a, b)
    return larger

@triton.autotune(
    configs=[
        # No trailing comma after the Config: `[x, for ...]` is a SyntaxError.
        triton.Config({"BG": bg}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for bg in [1]
    ],
    key=[],
)
@triton.jit
def parallel_arm_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    q_indices,
    k_indices,
    offsets,
    T_q,
    T_kv,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    BG: tl.constexpr,
):
    """Sparse block-attention forward pass for one query token and BG query heads.

    Layouts (contiguous tensors assumed; everything is raw pointer arithmetic):
      q         [B, T_q, HQ, K]     k / v  [B, T_kv, H, K] / [B, T_kv, H, V]
      o         [B, T_q, HQ, V]     lse    [B, T_q, HQ]
      q_indices [B, T_q, HQ, S]     S selected block ids (units of BS tokens) per query head
      k_indices [B, T_kv, H]        per-key value compared against the query position

    Key position ``j`` of a selected block contributes to query ``i_t`` only when
    ``i_t >= k_indices[j]``. Block slots whose tokens fall outside ``[0, T_kv)``
    (e.g. negative padding ids, which ``parallel_arm_kernel_mask`` also skips) are
    masked instead of being read out of bounds. ``lse`` receives the running max
    plus ``log`` of the softmax denominator (``-inf`` when nothing was attended).

    Grid: (T_q, NV, B * H * G // BG). Axis 2 encodes
    ``i_bh = i_b * (H * NG) + i_h * NG + i_g`` with ``NG = G // BG``.
    ``offsets`` is accepted for interface compatibility but unused.
    """
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NG = G // BG  # number of BG-sized query-head groups per KV head

    i_b = i_bh // (H * NG)
    remainder = i_bh % (H * NG)
    i_h = remainder // NG
    i_g = remainder % NG

    bos_q = i_b * T_q
    bos_kv = i_b * T_kv
    i_hq = i_h * G + i_g * BG  # first query head handled by this program
    o_hq = i_hq + tl.arange(0, BG)  # [BG] query heads handled by this program

    k += (bos_kv * H + i_h) * K
    v += (bos_kv * H + i_h) * V
    k_indices += bos_kv * H + i_h
    # q_indices is laid out per *query* head, so one token spans HQ * S entries.
    # (Using H * S aliased neighbouring tokens whenever G > 1.)
    q_indices += ((bos_q + i_t) * HQ + o_hq) * S

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # BK/BV are powers of two and may exceed K/V; mask the padding lanes.
    m_k = o_k < K
    m_v = o_v < V

    p_q = tl.make_block_ptr(
        q + (bos_q + i_t) * HQ * K,
        (HQ, K),
        (K, 1),
        (i_hq, 0),
        (BG, BK),
        (1, 0),
    )
    # [BG, BK], pre-scaled once and reused for every selected block
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_o = tl.make_block_ptr(
        o + (bos_q + i_t) * HQ * V,
        (HQ, V),
        (V, 1),
        (i_hq, i_v * BV),
        (BG, BV),
        (1, 0),
    )
    p_lse = lse + (bos_q + i_t) * HQ + o_hq

    # Online-softmax state per query head.
    b_o = tl.zeros([BG, BV], dtype=tl.float32)
    b_m = tl.full([BG], float("-inf"), dtype=tl.float32)  # running max
    b_mp = tl.full([BG], float("-inf"), dtype=tl.float32)  # previous running max
    b_acc = tl.zeros([BG], dtype=tl.float32)  # running denominator
    for i in range(S):
        # q_indices stores block ids; multiply by BS to get the block's first token.
        i_s = tl.load(q_indices + i).to(tl.int32) * BS  # [BG]
        o_s = i_s[:, None] + tl.arange(0, BS)[None, :]  # [BG, BS] token positions
        m_s = (o_s >= 0) & (o_s < T_kv)
        # k_indices has stride H per token.
        b_k_idx = tl.load(k_indices + o_s * H, mask=m_s, other=0)
        m_causal = (i_t >= b_k_idx) & m_s  # [BG, BS]
        # Explicit int32 cast: summing a raw i1 tensor is version-dependent in Triton.
        mask_head = tl.sum(m_causal.to(tl.int32), 1) > 0  # [BG]
        if tl.sum(mask_head.to(tl.int32)) > 0:
            # [BG, BK, BS]
            b_k = tl.load(
                k + o_s[:, None, :] * (H * K) + o_k[None, :, None],
                mask=m_s[:, None, :] & m_k[None, :, None],
                other=0,
            )
            # [BG, BS, BV]
            b_v = tl.load(
                v + o_s[:, :, None] * (H * V) + o_v[None, None, :],
                mask=m_s[:, :, None] & m_v[None, None, :],
                other=0,
            )
            # [BG, BS]
            b_s = tl.sum(b_q[:, :, None] * b_k, 1)
            b_s = tl.where(m_causal, b_s, float("-inf"))
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = tl.exp(b_mp - b_m)
            b_p = tl.exp(b_s - b_m[:, None])
            # Heads with no visible key in this block keep their previous state
            # (their b_r / b_p may be NaN from -inf - -inf; tl.where discards them).
            b_acc = tl.where(mask_head, b_acc * b_r + tl.sum(b_p, 1), b_acc)
            b_o = tl.where(
                mask_head[:, None],
                b_o * b_r[:, None] + tl.sum(b_p.to(b_q.dtype)[:, :, None] * b_v, 1),
                b_o,
            )

    # Normalise heads that attended to something. Heads with b_acc == 0 keep
    # b_o == 0 and lse == -inf (dividing by 1 and adding log(1) is a no-op).
    b_mask = b_acc > 0
    b_den = tl.where(b_mask, b_acc, 1)
    b_o = b_o / b_den[:, None]
    b_m += tl.log(b_den)

    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))


@triton.jit
def parallel_arm_kernel_mask(
    q_indices,
    block_mask,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    NS: tl.constexpr,
):
    """Scatter the block ids listed in ``q_indices`` into ``block_mask``.

    Meant to be launched with grid (T, B, H * S): one program per
    (token, batch, head, slot). ``q_indices`` is addressed as a contiguous
    [B, T, H, S] array and ``block_mask`` as a contiguous [B, T, H, NS] array.
    The program reads ``blk = q_indices[b, t, h, s]`` and, only when
    ``0 <= blk < NS``, writes 1.0 to ``block_mask[b, t, h, blk]``.
    ``BS`` is accepted but unused.
    """
    tok = tl.program_id(0)
    bat = tl.program_id(1)
    head_slot = tl.program_id(2)
    head = head_slot // S
    slot = head_slot % S

    # Flattened (batch, token, head) row index, shared by both arrays.
    row = (bat * T + tok) * H + head
    blk = tl.load(q_indices + row * S + slot)
    if blk >= 0 and blk < NS:
        tl.store(block_mask + row * NS + blk, 1.0)


@triton.jit
def parallel_arm_bwd_kernel_preprocess(o, do, delta, B: tl.constexpr, V: tl.constexpr):
    """Per-row dot product ``delta[n] = sum_j o[n, j] * do[n, j]``.

    One program per row ``n``. Rows are assumed to be contiguous with stride ``V``.
    ``B`` is the power-of-two lane count (>= V); lanes ``>= V`` are masked to 0.
    ``do`` is upcast to float32 before the product.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, B)
    in_bounds = cols < V
    base = row * V

    o_row = tl.load(o + base + cols, mask=in_bounds, other=0)
    do_row = tl.load(do + base + cols, mask=in_bounds, other=0).to(tl.float32)
    row_dot = tl.sum(o_row * do_row)
    tl.store(delta + row, row_dot.to(delta.dtype.element_ty))


# @triton.heuristics(
#     {
#         "USE_OFFSETS": lambda args: args["offsets"] is not None,
#     }
# )
@triton.autotune(
    configs=[
        triton.Config({"BG": bg}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for bg in [1]
    ],
    key=[],
)
@triton.jit(do_not_specialize=["T_q", "T_kv"])
def parallel_arm_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    q_indices,
    k_indices,
    offsets,
    T_q,
    T_kv,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    BG: tl.constexpr,
):
    """dQ for sparse block attention: one query token, BG query heads, one V slice.

    Layouts match the forward kernel (q [B, T_q, HQ, K], k/v [B, T_kv, H, K/V],
    q_indices [B, T_q, HQ, S], k_indices [B, T_kv, H]; lse/delta [B, T_q, HQ]).
    ``dq`` carries a leading V-slice axis, [NV, B, T_q, HQ, K]: program ``i_v``
    writes the contribution of V columns ``[i_v*BV, (i_v+1)*BV)``, presumably
    summed over axis 0 by the caller.

    NOTE(review): each V slice subtracts the full ``delta``, so summing NV
    slices subtracts it NV times; the result is exact only when NV == 1.

    Selected blocks whose tokens fall outside ``[0, T_kv)``, and K/V padding lanes,
    are masked instead of read out of bounds. Grid: (T_q, NV, B * H * G // BG).
    ``offsets`` is accepted for interface compatibility but unused.
    """
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NG = G // BG
    # i_bh = i_b * (H * NG) + i_h * NG + i_g
    i_b = i_bh // (H * NG)
    remainder = i_bh % (H * NG)
    i_h = remainder // NG
    i_g = remainder % NG

    bos_q = i_b * T_q
    bos_kv = i_b * T_kv
    i_hq = i_h * G + i_g * BG  # first query head of this group
    o_hq = i_hq + tl.arange(0, BG)  # [BG]

    q += (bos_q + i_t) * HQ * K
    do += (bos_q + i_t) * HQ * V
    lse += (bos_q + i_t) * HQ
    delta += (bos_q + i_t) * HQ
    dq += (i_v * B * T_q + bos_q + i_t) * HQ * K
    # q_indices is per *query* head: one token spans HQ * S entries (not H * S).
    q_indices += ((bos_q + i_t) * HQ + o_hq) * S
    k_indices += bos_kv * H + i_h
    k += (bos_kv * H + i_h) * K
    v += (bos_kv * H + i_h) * V

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    m_k = o_k < K
    m_v = o_v < V

    p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_hq, 0), (BG, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_hq, 0), (BG, BK), (1, 0))
    p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_hq, i_v * BV), (BG, BV), (1, 0))

    # [BG, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)
    # [BG, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [BG]
    b_lse = tl.load(lse + o_hq)
    b_delta = tl.load(delta + o_hq)

    b_dq = tl.zeros([BG, BK], dtype=tl.float32)
    for i in range(S):
        # Block id -> first token of the block; each head may pick a different block.
        i_s = tl.load(q_indices + i).to(tl.int32) * BS  # [BG]
        o_s = i_s[:, None] + tl.arange(0, BS)[None, :]  # [BG, BS]
        m_s = (o_s >= 0) & (o_s < T_kv)
        b_k_idx = tl.load(k_indices + o_s * H, mask=m_s, other=0)
        m_causal = (i_t >= b_k_idx) & m_s  # [BG, BS]
        if tl.sum(m_causal.to(tl.int32)) > 0:
            # [BG, BK, BS]
            b_k = tl.load(
                k + o_s[:, None, :] * (H * K) + o_k[None, :, None],
                mask=m_s[:, None, :] & m_k[None, :, None],
                other=0,
            )
            # [BG, BS, BV]
            b_v = tl.load(
                v + o_s[:, :, None] * (H * V) + o_v[None, None, :],
                mask=m_s[:, :, None] & m_v[None, None, :],
                other=0,
            )
            # [BG, BS] recomputed probabilities; masked positions contribute nothing.
            b_s = tl.sum(b_q[:, :, None] * b_k, 1)
            b_p = tl.where(m_causal, tl.exp(b_s - b_lse[:, None]), 0)
            # [BG, BS] = sum over BV of do[BG, BV] * v[BG, BS, BV]
            b_dp = tl.sum(b_do[:, None, :] * b_v, axis=2)
            b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
            # [BG, BK] = sum over BS of ds[BG, BS] * k[BG, BK, BS]
            b_dq += tl.sum(b_ds[:, None, :].to(b_k.dtype) * b_k, axis=2)

    b_dq *= scale
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))


# @triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BG": bg}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for bg in [1]
    ],
    key=[],
)
@triton.jit(do_not_specialize=["T_q", "T_kv"])
def parallel_arm_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dk,
    dv,
    q_indices,
    k_indices,
    block_mask,
    offsets,
    chunk_indices,
    scale,
    T_q,
    T_kv,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    M: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    BG: tl.constexpr,
):
    """dK / dV for one KV block of BS tokens, one KV head and one V slice.

    Grid: (NV, NS, B * H). The program scans every query position ``i`` and every
    BG-sized group of the G query heads sharing KV head ``i_h``. It processes a
    (query, head) pair only when ``block_mask`` marks block ``i_s`` as selected
    AND ``i >= k_indices[j]`` for the key token ``j``.

    ``block_mask`` is indexed per *query* head as a contiguous [B, T_q, HQ, M]
    array, so one token spans HQ * M entries (using H * M was wrong for G > 1).

    NOTE(review): dk is written by every ``i_v`` program with ``dp`` summed over
    only that program's BV slice of V, so dk is exact only when NV == 1.
    ``q_indices``, ``offsets`` and ``chunk_indices`` are accepted but unused.
    """
    i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    NG = G // BG
    i_b, i_h = i_bh // H, i_bh % H
    bos_q = i_b * T_q
    bos_kv = i_b * T_kv

    p_k = tl.make_block_ptr(
        k + (bos_kv * H + i_h) * K,
        (T_kv, K),
        (H * K, 1),
        (i_s * BS, 0),
        (BS, BK),
        (1, 0),
    )
    p_v = tl.make_block_ptr(
        v + (bos_kv * H + i_h) * V,
        (T_kv, V),
        (H * V, 1),
        (i_s * BS, i_v * BV),
        (BS, BV),
        (1, 0),
    )
    p_dk = tl.make_block_ptr(
        dk + (bos_kv * H + i_h) * K,
        (T_kv, K),
        (H * K, 1),
        (i_s * BS, 0),
        (BS, BK),
        (1, 0),
    )
    p_dv = tl.make_block_ptr(
        dv + (bos_kv * H + i_h) * V,
        (T_kv, V),
        (H * V, 1),
        (i_s * BS, i_v * BV),
        (BS, BV),
        (1, 0),
    )
    p_k_idx = tl.make_block_ptr(
        k_indices + (bos_kv * H + i_h),
        (T_kv, 1),
        (H * 1, 1),
        (i_s * BS, 0),
        (BS, 1),
        (1, 0),
    )

    # [BS, BK] / [BS, BV]; rows past T_kv are zero-padded and never stored.
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BS]
    b_k_idx = tl.reshape(tl.load(p_k_idx, boundary_check=(0, 1)), (BS,))
    b_dk = tl.zeros([BS, BK], dtype=tl.float32)
    b_dv = tl.zeros([BS, BV], dtype=tl.float32)

    for i in range(T_q):
        m_causal = i >= b_k_idx  # [BS], independent of the head group
        for i_g in range(NG):
            i_hq = i_h * G + i_g * BG  # first query head of this group
            o_hq = i_hq + tl.arange(0, BG)  # [BG]
            # [BG]: did query (i, head) select block i_s?
            b_m = tl.load(block_mask + ((bos_q + i) * HQ + o_hq) * M + i_s)
            b_m_mask = b_m[:, None] & m_causal[None, :]  # [BG, BS]
            if tl.sum(b_m_mask.to(tl.int32)) > 0:
                p_q = tl.make_block_ptr(
                    q + (bos_q + i) * HQ * K,
                    (HQ, K),
                    (K, 1),
                    (i_hq, 0),
                    (BG, BK),
                    (1, 0),
                )
                # [BG, BK]
                b_q = tl.load(p_q, boundary_check=(0, 1))
                b_q = (b_q * scale).to(b_q.dtype)

                p_do = tl.make_block_ptr(
                    do + (bos_q + i) * HQ * V,
                    (HQ, V),
                    (V, 1),
                    (i_hq, i_v * BV),
                    (BG, BV),
                    (1, 0),
                )
                # [BG, BV]
                b_do = tl.load(p_do, boundary_check=(0, 1))
                # [BG]
                b_lse = tl.load(lse + (bos_q + i) * HQ + o_hq)
                b_delta = tl.load(delta + (bos_q + i) * HQ + o_hq)

                # [BG, BS] scores via broadcast-multiply-reduce (no tl.dot, so BG < 16 works)
                b_s = tl.sum(b_k[None, :, :] * b_q[:, None, :], axis=2)
                b_p = tl.exp(b_s - b_lse[:, None])
                b_p = tl.where(b_m_mask, b_p, 0.0)

                # [BS, BV] += P^T @ dO
                b_dv += tl.sum(
                    b_p[:, :, None].to(b_do.dtype) * b_do[:, None, :], axis=0
                )
                # [BG, BS] = dO @ V^T
                b_dp = tl.sum(b_v[None, :, :] * b_do[:, None, :], axis=2)
                b_ds = b_p * (b_dp - b_delta[:, None])
                # [BS, BK] += dS^T @ (scale * Q)
                b_dk += tl.sum(
                    b_ds[:, :, None].to(b_q.dtype) * b_q[:, None, :], axis=0
                )

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


# save_fwd_kernel_tensors("store_data.pt",q, k, v, q_indices, k_indices, offsets,o , lse, do,block_mask,scale,block_size)
def save_fwd_kernel_tensors(
    filepath: str,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_indices: torch.Tensor,
    k_indices: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    o: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    do: Optional[torch.Tensor] = None,
    block_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    block_size: Optional[int] = None,
    **kwargs,
):
    """
    Save tensors passed for debugging/reproducibility.

    The file written with ``torch.save`` is a dict holding a CPU copy of every
    stored tensor plus a ``"metadata"`` dict containing:
      - ``<name>_shape`` / ``<name>_dtype`` for every stored tensor,
      - ``scale`` / ``block_size`` when given,
      - any extra ``kwargs`` (never overriding a key already present),
      - ``saved_tensors``: the names of the stored tensors, in order.

    Args:
        filepath: Path to save the tensors (should end with .pt)
        q: Query tensor [B, T, HQ, K] (always stored)
        k: Key tensor [B, T, H, K] (always stored)
        v: Value tensor [B, T, H, V] (always stored)
        q_indices: Query indices [B, T, H, S] (always stored)
        k_indices: Key indices [B, T, H] (always stored)
        offsets: Optional offsets tensor [B+1]
        o: Optional output tensor [B, T, HQ, V]
        lse: Optional log-sum-exp tensor [B, T, HQ]
        do: Optional output gradient tensor [B, T, HQ, V]
        block_mask: Optional block mask tensor [B, T, H, NS]
        scale: Optional scale parameter
        block_size: Optional block_size parameter
        **kwargs: Additional parameters to save in metadata
    """
    required = [
        ("q", q),
        ("k", k),
        ("v", v),
        ("q_indices", q_indices),
        ("k_indices", k_indices),
    ]
    optional = [
        ("offsets", offsets),
        ("o", o),
        ("lse", lse),
        ("do", do),
        ("block_mask", block_mask),
    ]
    # Required tensors are always stored (None raises, as before); optional ones
    # only when provided. The order here fixes the key order of the saved dicts.
    tensors = required + [(name, t) for name, t in optional if t is not None]

    save_dict = {name: t.cpu() for name, t in tensors}

    metadata = {}
    for name, t in tensors:
        metadata[f"{name}_shape"] = list(t.shape)
        metadata[f"{name}_dtype"] = str(t.dtype)
    if scale is not None:
        metadata["scale"] = scale
    if block_size is not None:
        metadata["block_size"] = block_size

    # Extra kwargs never override the bookkeeping entries written above.
    for key, value in kwargs.items():
        metadata.setdefault(key, value)

    saved_tensors = [name for name, _ in tensors]
    metadata["saved_tensors"] = saved_tensors
    save_dict["metadata"] = metadata

    torch.save(save_dict, filepath)
    print(f"[save_fwd_kernel_tensors] Tensors saved to {filepath}")
    print(f"  Saved tensors: {', '.join(saved_tensors)}")


def parallel_arm_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_indices: torch.LongTensor,
    k_indices: torch.LongTensor,
    block_size: int,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch ``parallel_arm_fwd_kernel`` and return ``(o, lse)``.

    Args:
        q: [B, T_q, HQ, K] queries.
        k: [B, T_kv, H, K] keys; HQ must be a multiple of H (G = HQ // H heads share a KV head).
        v: [B, T_kv, H, V] values.
        q_indices: [B, T_q, HQ, S] selected KV block ids per query head.
        k_indices: [B, T_kv, H] per-key values compared against the query position.
        block_size: tokens per KV block (BS).
        scale: softmax scale applied to q inside the kernel.
        offsets: forwarded to the kernel, which does not use it.
        token_indices: unused.

    Returns:
        o: [B, T_q, HQ, V] in ``v.dtype``.
        lse: [B, T_q, HQ] float32 log-sum-exp of the scaled scores.
    """
    # The kernel addresses every tensor with raw pointer arithmetic, so it needs
    # dense row-major layouts; .contiguous() is a no-op when that already holds.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    q_indices, k_indices = q_indices.contiguous(), k_indices.contiguous()

    B, T_kv, H, K = k.shape
    T_q = q.shape[1]
    V = v.shape[-1]
    S = q_indices.shape[-1]
    HQ = q.shape[2]
    assert HQ % H == 0, "The number of query heads must be a multiple of the number of KV heads"
    G = HQ // H
    BS = block_size
    # sm90+ can hold larger tiles. Query the device that owns q, not the current
    # default CUDA device, which may differ in multi-GPU runs.
    max_tile = 256 if torch.cuda.get_device_capability(q.device)[0] >= 9 else 128
    BK = min(max_tile, triton.next_power_of_2(K))
    BV = min(max_tile, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, f"The key dimension can not be larger than {max_tile}"

    # Axis 2 covers B * H * G query heads in groups of BG (chosen by autotune).
    grid = lambda META: (T_q, NV, triton.cdiv(B * H * G, META["BG"]))
    o = torch.empty(B, T_q, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T_q, HQ, dtype=torch.float, device=q.device)
    parallel_arm_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        q_indices=q_indices,
        k_indices=k_indices,
        offsets=offsets,
        T_q=T_q,
        T_kv=T_kv,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse


def parallel_arm_block_mask(
    q_indices: torch.LongTensor,
    offsets: torch.LongTensor,
    block_size: int,
    T_kv: Optional[int] = None,
):
    """Build a boolean mask of the KV blocks selected by each (batch, token, head).

    Inputs:
    - q_indices: [B, T_q, H, S] block ids. Slot ``s`` of position (b, t, h) names
      a KV block in units of ``block_size`` tokens. Ids outside ``[0, NS)`` (e.g.
      negative padding) are ignored. ``H`` here is ``q_indices.shape[2]``. The
      tensor is made contiguous first because the kernel uses raw pointer arithmetic.
    - offsets: accepted for interface compatibility; unused.
    - block_size: BS; the number of KV blocks is ``NS = ceil(T_kv / BS)``.
    - T_kv: key/value length. Defaults to ``T_q = q_indices.shape[1]``.

    Output:
    - block_mask: [B, T_q, H, NS] ``torch.bool``. ``block_mask[b, t, h, n]`` is True
      iff ``n`` appears among ``q_indices[b, t, h, :]``. The backward pass uses it
      to decide which KV blocks each query head touched.

    Example: ``q_indices[0, 5, 2, 0] == 3`` gives ``block_mask[0, 5, 2, 3] == True``.
    """
    q_indices = q_indices.contiguous()
    B, T_q, H, S = q_indices.shape
    if T_kv is None:
        T_kv = T_q  # backward compatibility: assume T_q == T_kv when not provided
    NS = triton.cdiv(T_kv, block_size)

    # Filled as float32 and converted at the end. Per the original author, having
    # the kernel write into a torch.bool buffer directly caused errors.
    block_mask = torch.zeros(B, T_q, H, NS, device=q_indices.device)
    # One program per (token, batch, head, slot).
    parallel_arm_kernel_mask[(T_q, B, H * S)](
        q_indices=q_indices, block_mask=block_mask, T=T_q, H=H, S=S, BS=block_size, NS=NS
    )
    return block_mask.to(torch.bool)


def parallel_arm_bwd_preprocess(o: torch.Tensor, do: torch.Tensor):
    """Compute the per-row `delta` term consumed by the dq / dkv backward kernels.

    Launches `parallel_arm_bwd_kernel_preprocess` once per element of
    `o[..., 0]`, i.e. once per row of the trailing (value) dimension.
    NOTE(review): by analogy with FlashAttention-style backward passes this is
    presumably `sum(o * do, dim=-1)`; confirm against the kernel.

    Args:
        o: forward output; its last dimension is the value head size.
        do: gradient w.r.t. `o` (same layout as `o`).

    Returns:
        A float32 tensor shaped like `o` with its last dimension removed.
    """
    head_dim = o.shape[-1]
    # One fp32 slot per row; the trailing value dimension is reduced away.
    delta = torch.empty_like(o[..., 0], dtype=torch.float32)
    launch_grid = (delta.numel(),)
    # `B` is the power-of-two tile width covering the value dimension, `V` the true size.
    padded_dim = triton.next_power_of_2(head_dim)
    parallel_arm_bwd_kernel_preprocess[launch_grid](
        o=o,
        do=do,
        delta=delta,
        B=padded_dim,
        V=head_dim,
    )
    return delta


def parallel_arm_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    q_indices: torch.Tensor,
    k_indices: torch.Tensor,
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Backward pass of the block-sparse ARM attention.

    Args:
        q: queries; `q.shape[1]` is the query length T_q and `q.shape[2]` the
            number of query heads HQ.
        k: keys of shape `[B, T_kv, H, K]`.
        v: values; the last dimension is the value head size V.
        o, lse: output and log-sum-exp saved by the forward pass.
        do: gradient w.r.t. `o`.
        q_indices: per-query selected block indices; the last dimension is the
            number of selected blocks S.
        k_indices: key-side indices forwarded to the kernels.
        block_size: KV block size (BS).
        scale: attention scale. Defaults to `1 / sqrt(K)` when `None`, matching
            the default used by `parallel_arm`.
        offsets: optional `[N+1]` sequence boundaries for variable-length input.
        token_indices: unused; kept for signature compatibility.

    Returns:
        `(dq, dk, dv)` with the shapes of `q`, `k` and `v` respectively.
    """
    B, T_kv, H, K = k.shape
    T_q = q.shape[1]
    V = v.shape[-1]
    S = q_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    BK = triton.next_power_of_2(K)
    BV = min(128, triton.next_power_of_2(V))
    NV = triton.cdiv(V, BV)
    if scale is None:
        # The Triton kernels need a concrete float; passing None would fail.
        scale = K ** -0.5

    delta = parallel_arm_bwd_preprocess(o, do)

    # One partial dq per value tile (NV); reduced with sum(0) below. Only a
    # single tile can be written directly in the input dtype.
    dq = torch.empty(
        NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device
    )
    # The last grid axis groups query heads by the autotuned BG, so the grid
    # must be a callable that reads BG from the selected config.
    grid = lambda META: (T_q, NV, triton.cdiv(B * H * G, META["BG"]))
    parallel_arm_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,  # [NV, *q.shape]
        q_indices=q_indices,
        k_indices=k_indices,
        offsets=offsets,
        scale=scale,
        T_q=T_q,
        T_kv=T_kv,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    dq = dq.sum(0)

    if offsets is not None:
        chunk_indices = prepare_chunk_indices(offsets, BS)
        NS = len(chunk_indices)
    else:
        chunk_indices = None
        NS = triton.cdiv(T_kv, BS)  # number of KV blocks

    # [B, T_q, H, NS] boolean mask of which KV blocks each query selected.
    block_mask = parallel_arm_block_mask(q_indices, offsets, block_size, T_kv=T_kv)

    # Like dq, dk gets one partial per value tile; dv is written directly.
    dk = torch.empty(
        NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device
    )
    dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)

    grid = (NV, NS, B * H)
    parallel_arm_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        q_indices=q_indices,
        k_indices=k_indices,
        block_mask=block_mask,
        offsets=offsets,
        chunk_indices=chunk_indices,
        scale=scale,
        T_q=T_q,
        T_kv=T_kv,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        M=block_mask.shape[-1],
        BS=BS,
        BK=BK,
        BV=BV,
    )
    dk = dk.sum(0)
    return dq, dk, dv


@torch.compile
class ParallelarmFunction(torch.autograd.Function):
    """Autograd bridge between `parallel_arm_fwd` and `parallel_arm_bwd`.

    Forward inputs: `q, k, v, q_indices, k_indices, block_size, scale, offsets`.
    Only `q`, `k` and `v` receive gradients.

    NOTE(review): `forward`/`backward` are reached through `.apply`, so
    applying `torch.compile` to this class most likely does not compile them.
    Confirm whether the decorator has any effect before relying on it.
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, q_indices, k_indices, block_size, scale, offsets):
        ctx.dtype = q.dtype
        o, lse = parallel_arm_fwd(
            q=q,
            k=k,
            v=v,
            q_indices=q_indices,
            k_indices=k_indices,
            block_size=block_size,
            scale=scale,
            offsets=offsets,
        )
        # o and lse are needed by the backward kernels (delta / softmax recompute).
        ctx.save_for_backward(q, k, v, o, lse)
        # Index tensors and hyperparameters are non-differentiable; keep them on ctx.
        ctx.q_indices = q_indices
        ctx.k_indices = k_indices
        ctx.offsets = offsets
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = parallel_arm_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            lse=lse,
            do=do,
            q_indices=ctx.q_indices,
            k_indices=ctx.k_indices,
            block_size=ctx.block_size,
            scale=ctx.scale,
            offsets=ctx.offsets,
        )
        # Exactly one entry per forward input: gradients for q, k, v and None
        # for q_indices, k_indices, block_size, scale, offsets.
        return (
            dq.to(q),
            dk.to(k),
            dv.to(v),
            None,
            None,
            None,
            None,
            None,
        )


def parallel_arm(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_topk: torch.Tensor,
    bucket_keys: torch.Tensor,
    bucket_values: torch.Tensor,
    q_indices: torch.LongTensor,
    k_indices: torch.LongTensor,
    block_size: int = 64,
    window_size: int = 0,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_causal_swa: bool = True,
) -> torch.Tensor:
    r"""
    Gated sum of block-sparse ("top-k") attention and optional sliding-window attention.

    The block-sparse branch attends `q` to `bucket_keys` / `bucket_values` through
    `ParallelarmFunction` and is weighted by `g_topk[..., [0]]`. When
    `window_size > 0`, a FlashAttention sliding-window branch over `k` / `v` is
    added, weighted by `g_topk[..., [1]]`.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys for the sliding-window branch, `[B, T, H, K]` if `head_first=False`
            else `[B, H, T, K]`.
        v (torch.Tensor):
            values for the sliding-window branch, `[B, T, H, V]` if `head_first=False`
            else `[B, H, T, V]`.
        g_topk (torch.Tensor):
            Gate scores. The last dimension selects the branch: index 0 gates the
            block-sparse output, index 1 the sliding-window output (only read when
            `window_size > 0`). Presumably `[B, T, HQ, 2]` (`[B, HQ, T, 2]` if
            `head_first=True`) so it broadcasts against the `[B, T, HQ, V]` output
            — TODO confirm.
        bucket_keys (torch.Tensor), bucket_values (torch.Tensor):
            Keys / values for the block-sparse branch, passed unchanged to
            `ParallelarmFunction`. NOTE: they are not rearranged when `head_first=True`.
        q_indices (torch.LongTensor):
            Per-query selected block indices; the last dimension is the number of
            selected blocks. Not rearranged when `head_first=True`.
        k_indices (torch.LongTensor):
            Key-side indices passed unchanged to `ParallelarmFunction`.
        block_size (int):
            Selected block size. Default: 64.
        window_size (int):
            Sliding window size; 0 disables the sliding-window branch. Default: 0.
        scale (Optional[float]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API. Requires batch size 1.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
        use_causal_swa (bool):
            Passed as `causal` to the FlashAttention sliding-window call. Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.

    Raises:
        ImportError: if `window_size > 0` and flash-attn is not installed.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
        # The kernels (prepare_chunk_indices / parallel_arm_block_mask) consume
        # `offsets` as [N+1] token-level sequence boundaries, i.e. cu_seqlens itself.
        offsets = cu_seqlens
    else:
        offsets = None
    if head_first:
        q, k, v = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v))
        # g_topk is a single tensor (not a tuple of gates); swap head/time axes and
        # keep any trailing gate dimension.
        g_topk = rearrange(g_topk, "b h t ... -> b t h ...")
    o_slc = ParallelarmFunction.apply(
        q, bucket_keys, bucket_values, q_indices, k_indices, block_size, scale, offsets
    )
    o = o_slc * g_topk[..., [0]]
    if window_size > 0:
        # The module-level import fallback only sets flash_attn_func = None
        # (flash_attn_varlen_func is left unbound), so fail with a clear message.
        if flash_attn_func is None:
            raise ImportError(
                "Flash Attention is required when window_size > 0. Please install it via "
                "`pip install flash-attn --no-build-isolation`"
            )
        if cu_seqlens is not None:
            max_seqlen = q.shape[1]
            o_swa = flash_attn_varlen_func(
                q.squeeze(0),
                k.squeeze(0),
                v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=use_causal_swa,
                window_size=(window_size - 1, 0),
            ).unsqueeze(0)
        else:
            o_swa = flash_attn_func(
                q, k, v, causal=use_causal_swa, window_size=(window_size - 1, 0)
            )
        # o += o_swa * g_topk[..., [1]]
        o = torch.addcmul(o, o_swa, g_topk[..., [1]])

    if head_first:
        o = rearrange(o, "b t h d -> b h t d")

    return o
