import os
import warnings

import triton
import triton.language as tl

from hip_attn.v1_2.uvm_gpu_cache import load_tokens


@triton.jit
def apply_rope(
    seq,
    seq_rot,
    rope_t,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    mask_t,
    idx_hid,
    HEAD_DIM: tl.constexpr,
):
    """Apply a rotary embedding to ``seq`` using its pre-rotated partner tile.

    The cos / sin tables are read at row ``rope_t`` and at column
    ``idx_hid % (HEAD_DIM // 2)``; rows whose ``mask_t`` is false read 0.0.
    Both factors are cast to the dtype of ``seq``.

    Returns:
        A tuple ``(seq * cos + seq_rot * sin, cos, sin)`` where the first item
        is cast back to the dtype of ``seq``.
    """
    # Hidden indices are folded onto the first half before indexing the tables.
    idx_freq = idx_hid % (HEAD_DIM // 2)
    row_mask = mask_t[:, None]

    cos_ptrs = (
        COS + rope_t[:, None] * stride_cos_t + idx_freq[None, :] * stride_cos_hid
    )
    sin_ptrs = (
        SIN + rope_t[:, None] * stride_sin_t + idx_freq[None, :] * stride_sin_hid
    )

    rope_cos = tl.load(cos_ptrs, mask=row_mask, other=0.0).to(seq.dtype)
    rope_sin = tl.load(sin_ptrs, mask=row_mask, other=0.0).to(seq.dtype)

    roped = seq * rope_cos + seq_rot * rope_sin

    return roped.to(seq.dtype), rope_cos, rope_sin


@triton.jit
def load_rot_and_apply_rope(
    seq,
    rope_t,
    SEQ,
    stride_seq_t,
    stride_seq_hid,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    idx_t,
    mask_t,
    idx_hid,
    HEAD_DIM: tl.constexpr,
):
    """Load the rotation partner of ``seq`` from ``SEQ`` and apply RoPE.

    For hidden index ``i`` the partner element is read from index
    ``(i + HEAD_DIM // 2) % HEAD_DIM`` of the same row and multiplied by -1
    when ``i + HEAD_DIM // 2 < HEAD_DIM`` and by +1 otherwise. Rows whose
    ``mask_t`` is false read 0.

    Returns:
        Whatever ``apply_rope`` returns: ``(roped_seq, cos, sin)``.
    """
    idx_hid_shifted = idx_hid + HEAD_DIM // 2
    idx_hid_partner = idx_hid_shifted % HEAD_DIM

    partner_ptrs = (
        SEQ + idx_t[:, None] * stride_seq_t + idx_hid_partner[None, :] * stride_seq_hid
    )
    partner = tl.load(partner_ptrs, mask=mask_t[:, None], other=0)

    # -1 where the shifted index is still inside the head, +1 where it wrapped.
    sign = ((idx_hid_shifted[None, :] < HEAD_DIM) * (-2) + 1).to(partner.dtype)
    seq_rot = partner * sign

    return apply_rope(
        seq=seq,
        seq_rot=seq_rot,
        rope_t=rope_t,
        COS=COS,
        stride_cos_t=stride_cos_t,
        stride_cos_hid=stride_cos_hid,
        SIN=SIN,
        stride_sin_t=stride_sin_t,
        stride_sin_hid=stride_sin_hid,
        mask_t=mask_t,
        idx_hid=idx_hid,
        HEAD_DIM=HEAD_DIM,
    )


@triton.jit
def load_and_apply_rope(
    idx_t,
    rope_t,
    mask_t,
    SEQ,
    stride_seq_t,
    stride_seq_hid,
    USING_EXTEND,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    DIM,
):
    """Load rows ``idx_t`` of ``SEQ`` and, optionally, their RoPE'd version.

    Rows whose ``mask_t`` is false are filled with 0.

    Returns:
        ``(raw, roped)``. When ``USING_EXTEND`` is falsy both items are the
        raw tile; otherwise ``roped`` is the tile rotated at positions
        ``rope_t`` by ``load_rot_and_apply_rope``.
    """
    idx_hid = tl.arange(0, DIM)

    seq_ptrs = SEQ + idx_t[:, None] * stride_seq_t + idx_hid[None, :] * stride_seq_hid
    seq = tl.load(seq_ptrs, mask=mask_t[:, None], other=0)

    # Without RoPE re-application the raw tile doubles as the "roped" tile.
    if not USING_EXTEND:
        return seq, seq

    seq_roped, _, _ = load_rot_and_apply_rope(
        seq=seq,
        rope_t=rope_t,
        SEQ=SEQ,
        stride_seq_t=stride_seq_t,
        stride_seq_hid=stride_seq_hid,
        COS=COS,
        stride_cos_t=stride_cos_t,
        stride_cos_hid=stride_cos_hid,
        SIN=SIN,
        stride_sin_t=stride_sin_t,
        stride_sin_hid=stride_sin_hid,
        idx_t=idx_t,
        mask_t=mask_t,
        idx_hid=idx_hid,
        HEAD_DIM=DIM,
    )

    return seq, seq_roped


@triton.jit
def block_sparse_attention_backward_preprocess(
    CONTEXT,
    stride_context_bsz,
    stride_context_tdst,
    stride_context_head,
    stride_context_hid,
    GRAD_CONTEXT,
    stride_grad_context_bsz,
    stride_grad_context_tdst,
    stride_grad_context_head,
    stride_grad_context_hid,
    DELTA,
    stride_delta_bsz,
    stride_delta_tdst,
    stride_delta_head,
    BSZ,
    TDST,
    HEAD,
    BLOCK_TDST: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    """Compute ``DELTA[b, t, h] = sum_hid(CONTEXT * GRAD_CONTEXT)``.

    Grid layout: ``program_id(0)`` selects a block of ``BLOCK_TDST`` query
    rows, ``program_id(1)`` selects a flattened ``(batch, head)`` pair that is
    decoded as ``batch = pid // HEAD`` and ``head = pid % HEAD``.

    The gradient tile is promoted to float32 before the product. Rows at or
    beyond ``TDST`` are neither read (they load as 0) nor written.
    """
    pid_batch_head = tl.program_id(1)
    idx_head = pid_batch_head % HEAD
    idx_bsz = pid_batch_head // HEAD

    idx_tdst = tl.program_id(0) * BLOCK_TDST + tl.arange(0, BLOCK_TDST)
    mask_tdst = idx_tdst < TDST

    idx_hid = tl.arange(0, HEAD_DIM)

    # `other=0.0` keeps the out-of-range rows well defined instead of leaving
    # them with whatever the masked load happens to produce.
    context = tl.load(
        CONTEXT
        + idx_bsz * stride_context_bsz
        + idx_tdst[:, None] * stride_context_tdst
        + idx_head * stride_context_head
        + idx_hid[None, :] * stride_context_hid,
        mask=mask_tdst[:, None],
        other=0.0,
    )
    grad_context = tl.load(
        GRAD_CONTEXT
        + idx_bsz * stride_grad_context_bsz
        + idx_tdst[:, None] * stride_grad_context_tdst
        + idx_head * stride_grad_context_head
        + idx_hid[None, :] * stride_grad_context_hid,
        mask=mask_tdst[:, None],
        other=0.0,
    ).to(tl.float32)

    delta = tl.sum(context * grad_context, axis=1)

    # The store must carry the same row mask as the loads; otherwise the last
    # block writes past the end of DELTA whenever TDST % BLOCK_TDST != 0.
    tl.store(
        DELTA
        + idx_bsz * stride_delta_bsz
        + idx_tdst * stride_delta_tdst
        + idx_head * stride_delta_head,
        value=delta,
        mask=mask_tdst,
    )


@triton.jit
def bsa_bwd_grad_kv_update(
    grad_k,
    grad_v,
    keys,
    keys_rot,
    values,
    Q,
    stride_q_tdst,
    stride_q_hid,
    GRAD_CONTEXT,
    stride_grad_context_tdst,
    stride_grad_context_hid,
    SCORE_MAXIMUM,
    stride_score_maximum_tdst,
    DELTA,
    stride_delta_tdst,
    USING_EXTEND: tl.constexpr,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    EXCLUDE_SLIDING_WINDOW: tl.constexpr,
    SLIDING_WINDOW_SIZE,
    rope_tdst,
    idx_tdst,
    mask_tdst,
    pos_tdst,
    rope_tsrc,
    idx_tsrc,
    mask_tsrc,
    idx_hid,
    HEAD_DIM: tl.constexpr,
):
    """Add one query block's contribution to the dK / dV accumulators.

    ``keys`` / ``values`` are the already-loaded key / value tiles (rows are
    source tokens ``idx_tsrc``); queries, the output gradient, the score
    maximum and delta are loaded here for the query rows ``idx_tdst``.

    When ``USING_EXTEND`` is set, queries are rotated at ``rope_tdst`` and keys
    at ``rope_tsrc`` (using the caller-provided ``keys_rot`` partner tile), and
    the key gradient is mapped back through that rotation.

    Masking: with ``EXCLUDE_SLIDING_WINDOW`` a (query, key) pair is kept when
    ``pos_tdst - SLIDING_WINDOW_SIZE >= idx_tsrc``; otherwise it is kept when
    ``pos_tdst - SLIDING_WINDOW_SIZE < idx_tsrc <= pos_tdst``. Both cases also
    require ``mask_tdst`` and ``mask_tsrc``.

    Returns:
        The updated ``(grad_k, grad_v)`` accumulators.
    """
    # Queries are loaded row-major (optionally RoPE'd) and transposed so that
    # every score tile below has keys on axis 0 and queries on axis 1.
    _, qTT = load_and_apply_rope(
        idx_tdst,
        rope_tdst,
        mask_tdst,
        Q,
        stride_q_tdst,
        stride_q_hid,
        USING_EXTEND,
        COS,
        stride_cos_t,
        stride_cos_hid,
        SIN,
        stride_sin_t,
        stride_sin_hid,
        HEAD_DIM,
    )
    qT = tl.trans(qTT)

    if USING_EXTEND:
        # cos / sin are kept: they are needed again to back-propagate through
        # the rotation when forming grad_k.
        keys, rope_cos, rope_sin = apply_rope(
            keys,
            keys_rot,
            rope_tsrc,
            COS,
            stride_cos_t,
            stride_cos_hid,
            SIN,
            stride_sin_t,
            stride_sin_hid,
            mask_tsrc,
            idx_hid,
            HEAD_DIM,
        )

    # Load m before computing qk to reduce pipeline stall.
    m = tl.load(
        SCORE_MAXIMUM + idx_tdst * stride_score_maximum_tdst,
        mask=mask_tdst,
        other=0,
    )

    qkT = tl.dot(keys, qT)

    # NOTE(review): exp2 implies the scores / maxima are in the log2 domain;
    # any softmax scale is presumably folded in by the caller - confirm.
    pT = tl.math.exp2(qkT - m[None, :])
    # Autoregressive masking.
    if EXCLUDE_SLIDING_WINDOW:
        # For sink and block sparse
        mask = ((pos_tdst - SLIDING_WINDOW_SIZE)[None, :] >= idx_tsrc[:, None]) & (
            mask_tdst[None, :] & mask_tsrc[:, None]
        )
    else:
        # For sliding window
        mask = (
            (pos_tdst[None, :] >= idx_tsrc[:, None])
            & ((pos_tdst - SLIDING_WINDOW_SIZE)[None, :] < idx_tsrc[:, None])
            & (mask_tdst[None, :] & mask_tsrc[:, None])
        )
    pT = tl.where(mask, pT, 0.0)
    do = tl.load(
        GRAD_CONTEXT
        + idx_tdst[:, None] * stride_grad_context_tdst
        + idx_hid[None, :] * stride_grad_context_hid,
        mask=mask_tdst[:, None],
        other=0,
    )
    # Compute dV.
    ppT = pT
    ppT = ppT.to(do.dtype)

    grad_v += tl.dot(ppT, do).to(grad_v.dtype) * mask_tsrc[:, None]

    # Masked lanes must be given a defined value: an undefined (possibly
    # NaN / Inf) delta would survive the multiplication by the zeroed pT
    # below and then spread over whole rows of grad_k through the dot product.
    Di = tl.load(
        DELTA + idx_tdst * stride_delta_tdst,
        mask=mask_tdst,
        other=0.0,
    )

    # Compute dP and dS.
    dpT = tl.dot(values, tl.trans(do)).to(
        tl.float32
    )  # NOTE(hj): this cast is carried over from the original implementation
    dsT = pT * (dpT - Di[None, :])
    dsT = dsT.to(qT.dtype)

    if USING_EXTEND:
        # Gradient w.r.t. the rotated keys.
        grad_k_inner = tl.dot(dsT, tl.trans(qT)).to(grad_k.dtype)

        # Direct path: roped = keys * cos + ...
        grad_k_cos = grad_k_inner * rope_cos

        # Partner path: roped = ... + sign * keys[partner] * sin. Apply the
        # same sign as the forward pass, then swap the two halves of the
        # hidden axis so each gradient lands on the element it came from.
        grad_k_sin = grad_k_inner * rope_sin
        grad_k_sin = grad_k_sin * (
            ((idx_hid + HEAD_DIM // 2)[None, :] < HEAD_DIM) * (-2) + 1
        ).to(grad_k_sin.dtype)
        grad_k_sin_lo, grad_k_sin_hi = tl.split(
            tl.trans(
                tl.reshape(grad_k_sin, grad_k.shape[0], 2, HEAD_DIM // 2),
                0,
                2,
                1,
            )
        )
        grad_k_sin = tl.reshape(
            tl.trans(tl.join(grad_k_sin_hi, grad_k_sin_lo), 0, 2, 1),
            (grad_k.shape[0], HEAD_DIM),
        )

        grad_k_inner = grad_k_cos + grad_k_sin

        grad_k += grad_k_inner * mask_tsrc[:, None]
    else:
        grad_k += tl.dot(dsT, tl.trans(qT)).to(grad_k.dtype) * mask_tsrc[:, None]

    return grad_k, grad_v


@triton.jit
def bsa_bwd_grad_kv_sw(
    grad_k,
    grad_v,
    keys,
    keys_rot,
    values,
    sm_scale,
    Q,
    stride_q_tdst,
    stride_q_hid,
    GRAD_CONTEXT,
    stride_grad_context_tdst,
    stride_grad_context_hid,
    SCORE_MAXIMUM,
    stride_score_maximum_tdst,
    DELTA,
    stride_delta_tdst,
    POS_IDS,
    stride_pos_ids_tdst,
    USING_EXTEND: tl.constexpr,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    HEAD,
    TDST,
    TSRC,
    SINK_TOKEN_SIZE,
    SLIDING_WINDOW_SIZE,
    NUM_K,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    start_tsrc,
    idx_tsrc,
    mask_tsrc,
):
    """Accumulate the sliding-window part of dK / dV for one key tile.

    Every query block of ``BLOCK_SIZE_Q`` rows in ``[0, TDST)`` is visited.
    A block is skipped when it has no valid row, when the key tile has no
    valid row, or when its largest position cannot reach the tile according
    to the coarse range test below. Surviving blocks are handed to
    ``bsa_bwd_grad_kv_update`` with ``EXCLUDE_SLIDING_WINDOW=False``.

    ``sm_scale``, ``HEAD`` and ``start_tsrc`` are accepted but not used here.

    Returns:
        The updated ``(grad_k, grad_v)`` accumulators.
    """
    idx_hid = tl.arange(0, HEAD_DIM)
    # Smallest valid key index of this tile (TSRC when the tile is empty).
    min_tsrc = tl.min(tl.where(mask_tsrc, idx_tsrc, TSRC))

    # The query-side RoPE positions do not depend on the query block.
    rope_tdst = (
        tl.arange(0, BLOCK_SIZE_Q)
        + SINK_TOKEN_SIZE
        + SLIDING_WINDOW_SIZE
        + NUM_K * BLOCK_SIZE_K
        - BLOCK_SIZE_Q
    )

    for start_tdst in range(0, TDST, BLOCK_SIZE_Q):
        idx_tdst = start_tdst + tl.arange(0, BLOCK_SIZE_Q)
        mask_tdst = idx_tdst < TDST

        num_valid_tdst = tl.sum(mask_tdst.to(tl.int32))
        num_valid_tsrc = tl.sum(mask_tsrc.to(tl.int32))
        if (num_valid_tdst * num_valid_tsrc) > 0:
            pos_tdst = tl.load(
                POS_IDS + idx_tdst * stride_pos_ids_tdst,
                mask=mask_tdst,
                other=0,
            )
            # Negative positions mark query rows that must be ignored.
            mask_tdst = mask_tdst & (pos_tdst >= 0)

            max_pos_tdst = tl.max(pos_tdst)
            window_reaches_tile = (
                max_pos_tdst - SLIDING_WINDOW_SIZE - BLOCK_SIZE_Q
            ) <= (min_tsrc + BLOCK_SIZE_K)
            tile_not_in_future = max_pos_tdst >= min_tsrc

            if window_reaches_tile & tile_not_in_future:
                # Key RoPE position is taken relative to the largest query
                # position of the block and clamped at zero.
                rope_tsrc = (
                    idx_tsrc
                    - max_pos_tdst
                    + SINK_TOKEN_SIZE
                    + SLIDING_WINDOW_SIZE
                    + NUM_K * BLOCK_SIZE_K
                    - 1
                )
                rope_tsrc = tl.maximum(0, rope_tsrc)

                grad_k, grad_v = bsa_bwd_grad_kv_update(
                    grad_k=grad_k,
                    grad_v=grad_v,
                    keys=keys,
                    keys_rot=keys_rot,
                    values=values,
                    Q=Q,
                    stride_q_tdst=stride_q_tdst,
                    stride_q_hid=stride_q_hid,
                    GRAD_CONTEXT=GRAD_CONTEXT,
                    stride_grad_context_tdst=stride_grad_context_tdst,
                    stride_grad_context_hid=stride_grad_context_hid,
                    SCORE_MAXIMUM=SCORE_MAXIMUM,
                    stride_score_maximum_tdst=stride_score_maximum_tdst,
                    DELTA=DELTA,
                    stride_delta_tdst=stride_delta_tdst,
                    USING_EXTEND=USING_EXTEND,
                    COS=COS,
                    stride_cos_t=stride_cos_t,
                    stride_cos_hid=stride_cos_hid,
                    SIN=SIN,
                    stride_sin_t=stride_sin_t,
                    stride_sin_hid=stride_sin_hid,
                    # This pass *is* the sliding window, so do not exclude it.
                    EXCLUDE_SLIDING_WINDOW=False,
                    SLIDING_WINDOW_SIZE=SLIDING_WINDOW_SIZE,
                    rope_tdst=rope_tdst,
                    idx_tdst=idx_tdst,
                    mask_tdst=mask_tdst,
                    pos_tdst=pos_tdst,
                    rope_tsrc=rope_tsrc,
                    idx_tsrc=idx_tsrc,
                    mask_tsrc=mask_tsrc,
                    idx_hid=idx_hid,
                    HEAD_DIM=HEAD_DIM,
                )

    return grad_k, grad_v


@triton.jit
def bsa_bwd_grad_kv_sink(
    grad_k,
    grad_v,
    keys,
    keys_rot,
    values,
    sm_scale,
    Q,
    stride_q_tdst,
    stride_q_hid,
    GRAD_CONTEXT,
    stride_grad_context_tdst,
    stride_grad_context_hid,
    SCORE_MAXIMUM,
    stride_score_maximum_tdst,
    DELTA,
    stride_delta_tdst,
    POS_IDS,
    stride_pos_ids_tdst,
    USING_EXTEND: tl.constexpr,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    HEAD,
    TDST,
    TSRC,
    SINK_TOKEN_SIZE,
    SLIDING_WINDOW_SIZE,
    NUM_K,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    start_tsrc,
    idx_tsrc,
    mask_tsrc,
):
    """Accumulate the sink-token part of dK / dV for one key tile.

    Only keys with ``idx_tsrc < SINK_TOKEN_SIZE`` take part; the incoming
    ``mask_tsrc`` is replaced by that condition. Query blocks are visited
    starting from ``start_tsrc`` rounded down to a multiple of
    ``BLOCK_SIZE_Q``, and each one is handed to ``bsa_bwd_grad_kv_update``
    with ``EXCLUDE_SLIDING_WINDOW=True``. Sink keys keep their own index as
    RoPE position.

    ``sm_scale``, ``HEAD`` and ``TSRC`` are accepted but not used here.

    Returns:
        The updated ``(grad_k, grad_v)`` accumulators.
    """
    idx_hid = tl.arange(0, HEAD_DIM)
    mask_sink = idx_tsrc < SINK_TOKEN_SIZE

    # Neither RoPE position vector depends on the query block.
    rope_tdst = (
        tl.arange(0, BLOCK_SIZE_Q)
        + SINK_TOKEN_SIZE
        + SLIDING_WINDOW_SIZE
        + NUM_K * BLOCK_SIZE_K
        - BLOCK_SIZE_Q
    )

    # First query row to visit: start_tsrc aligned down to the query block size.
    first_tdst = start_tsrc // BLOCK_SIZE_Q * BLOCK_SIZE_Q

    for i_bdst in range(tl.cdiv(TDST - start_tsrc, BLOCK_SIZE_Q)):
        idx_tdst = i_bdst * BLOCK_SIZE_Q + first_tdst + tl.arange(0, BLOCK_SIZE_Q)
        in_range_tdst = idx_tdst < TDST

        pos_tdst = tl.load(
            POS_IDS + idx_tdst * stride_pos_ids_tdst,
            mask=in_range_tdst,
            other=0,
        )
        # Negative positions mark query rows that must be ignored.
        mask_tdst = in_range_tdst & (pos_tdst >= 0)

        grad_k, grad_v = bsa_bwd_grad_kv_update(
            grad_k=grad_k,
            grad_v=grad_v,
            keys=keys,
            keys_rot=keys_rot,
            values=values,
            Q=Q,
            stride_q_tdst=stride_q_tdst,
            stride_q_hid=stride_q_hid,
            GRAD_CONTEXT=GRAD_CONTEXT,
            stride_grad_context_tdst=stride_grad_context_tdst,
            stride_grad_context_hid=stride_grad_context_hid,
            SCORE_MAXIMUM=SCORE_MAXIMUM,
            stride_score_maximum_tdst=stride_score_maximum_tdst,
            DELTA=DELTA,
            stride_delta_tdst=stride_delta_tdst,
            USING_EXTEND=USING_EXTEND,
            COS=COS,
            stride_cos_t=stride_cos_t,
            stride_cos_hid=stride_cos_hid,
            SIN=SIN,
            stride_sin_t=stride_sin_t,
            stride_sin_hid=stride_sin_hid,
            EXCLUDE_SLIDING_WINDOW=True,
            SLIDING_WINDOW_SIZE=SLIDING_WINDOW_SIZE,
            rope_tdst=rope_tdst,
            idx_tdst=idx_tdst,
            mask_tdst=mask_tdst,
            pos_tdst=pos_tdst,
            rope_tsrc=idx_tsrc,
            idx_tsrc=idx_tsrc,
            mask_tsrc=mask_sink,
            idx_hid=idx_hid,
            HEAD_DIM=HEAD_DIM,
        )

    return grad_k, grad_v


# The main inner-loop logic for computing dK and dV.
@triton.jit
def bsa_bwd_grad_kv(
    grad_k,
    grad_v,
    Q,
    stride_q_tdst,
    stride_q_hid,
    keys,
    keys_rot,
    values,
    sm_scale,
    GRAD_CONTEXT,
    stride_grad_context_tdst,
    stride_grad_context_hid,
    SCORE_MAXIMUM,
    stride_score_maximum_tdst,
    DELTA,
    stride_delta_tdst,
    POS_IDS,
    stride_pos_ids_tdst,
    HEAD,
    TDST,
    TSRC,
    SINK_TOKEN_SIZE,
    SLIDING_WINDOW_SIZE,
    NUM_K,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_COLWISE_BK: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    start_tsrc,
    idx_tsrc,
    mask_tsrc,
    INDICES_COLMAJOR_COLSTARTS,
    stride_indices_colmajor_colstarts_bsrc,
    INDICES_COLMAJOR_ROWIS,
    stride_indices_colmajor_rowis_z,
    INDICES_COLMAJOR_IXS,
    stride_indices_colmajor_ixs_z,
    INDICES_COLMAJOR_NNZ,
    USING_EXTEND: tl.constexpr,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
):
    """Accumulate the block-sparse part of dK / dV for one key tile.

    The key tile covers ``BLOCK_COLWISE_BK`` consecutive key blocks of
    ``BLOCK_SIZE_K`` tokens starting at block ``start_tsrc // BLOCK_SIZE_K``.
    For each key block, ``INDICES_COLMAJOR_COLSTARTS[b] .. [b + 1]`` delimits
    a run of entries in ``INDICES_COLMAJOR_ROWIS`` (first query row of a query
    block that selected this key block) and ``INDICES_COLMAJOR_IXS`` (slot
    index used to derive the key's RoPE position).

    One cursor per key block (``idx_z``) walks its run. Each loop iteration
    picks the smallest query-block start among the active cursors, processes
    that query block against exactly the key blocks whose cursor points at
    it, and advances those cursors.
    NOTE(review): this presumably relies on every run in ROWIS being sorted
    in ascending order - confirm against the code that builds the indices.

    Keys below ``SINK_TOKEN_SIZE`` are excluded here; the incoming
    ``mask_tsrc`` is replaced by ``idx_tsrc >= SINK_TOKEN_SIZE``.
    ``sm_scale`` and ``HEAD`` are accepted but not used here.

    Returns:
        The updated ``(grad_k, grad_v)`` accumulators.
    """
    idx_hid = tl.arange(0, HEAD_DIM)

    # Key-block indices covered by this tile; blocks past the end of the
    # sequence or inside the sink region have no sparse entries to visit.
    idx_bsrc = (start_tsrc // BLOCK_SIZE_K) + tl.arange(0, BLOCK_COLWISE_BK)
    mask_bsrc = (idx_bsrc < (TSRC // BLOCK_SIZE_K)) & (
        idx_bsrc >= (SINK_TOKEN_SIZE // BLOCK_SIZE_K)
    )
    mask_tsrc = idx_tsrc >= SINK_TOKEN_SIZE

    # Half-open entry range [nnz_start, nnz_end) of every key block.
    nnz_start = tl.load(
        INDICES_COLMAJOR_COLSTARTS + idx_bsrc * stride_indices_colmajor_colstarts_bsrc,
        mask=mask_bsrc,
        other=0,
    )
    nnz_end = tl.load(
        INDICES_COLMAJOR_COLSTARTS
        + (idx_bsrc + 1) * stride_indices_colmajor_colstarts_bsrc,
        mask=mask_bsrc,
        other=0,
    )
    # Clamp both ends into [0, NNZ] and force end >= start so the per-block
    # entry count can never be negative.
    nnz_start = tl.minimum(nnz_start, INDICES_COLMAJOR_NNZ)
    nnz_end = tl.minimum(tl.maximum(nnz_start, nnz_end), INDICES_COLMAJOR_NNZ)
    # NOTE: worst-case loop length (reached when no two key blocks of the tile
    # ever share a query block, so only one cursor advances per iteration).
    nnz_max_count = tl.sum(nnz_end - nnz_start)

    # tl.device_print('asdf', nnz_end - nnz_start)

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    # tl.static_assert(BLOCK_SIZE_K % BLOCK_SIZE_Q == 0)
    # current_tdst = start_tdst
    # step_tdst = BLOCK_SIZE_Q

    # Per-key-block cursor into its entry range.
    idx_z = tl.zeros((BLOCK_COLWISE_BK,), dtype=tl.int32)

    for _ in range(nnz_max_count):
        # Cursors that still have entries left.
        mask_z = idx_z < (nnz_end - nnz_start)
        # Exhausted cursors read TDST so they can never win the minimum below
        # against a real query-block start.
        start_tdst = tl.load(
            INDICES_COLMAJOR_ROWIS
            + (idx_z + nnz_start) * stride_indices_colmajor_rowis_z,
            mask=mask_z,
            other=TDST,
        )
        min_start_tdst = tl.min(start_tdst)
        # Keep only the key blocks whose next entry is the chosen query block.
        mask_z = mask_z & (start_tdst == min_start_tdst)
        num_active = tl.sum(mask_z.to(tl.int32))
        if num_active > 0:
            # Expand the per-block selection to a per-token mask.
            mask_z_src = tl.reshape(
                tl.reshape(mask_z, (BLOCK_COLWISE_BK, 1)).broadcast_to(
                    BLOCK_COLWISE_BK, BLOCK_SIZE_K
                ),
                BLOCK_COLWISE_BK * BLOCK_SIZE_K,
            )
            mask_tsrc_inner = mask_tsrc & mask_z_src

            idx_tdst = min_start_tdst + tl.arange(0, BLOCK_SIZE_Q)
            mask_tdst = idx_tdst < TDST
            # rope_tdst = tl.full(
            #     (BLOCK_SIZE_Q,),
            #     SINK_TOKEN_SIZE + SLIDING_WINDOW_SIZE + NUM_K * BLOCK_SIZE_K - 1,
            #     tl.int64,
            # )
            # Query RoPE positions: the last BLOCK_SIZE_Q positions before
            # SINK_TOKEN_SIZE + SLIDING_WINDOW_SIZE + NUM_K * BLOCK_SIZE_K.
            rope_tdst = (
                tl.arange(0, BLOCK_SIZE_Q)
                + SINK_TOKEN_SIZE
                + SLIDING_WINDOW_SIZE
                + NUM_K * BLOCK_SIZE_K
                - BLOCK_SIZE_Q
            )

            pos_tdst = tl.load(
                POS_IDS + idx_tdst * stride_pos_ids_tdst,
                mask=mask_tdst,
                other=0,
            )
            # Negative positions mark query rows that must be ignored.
            mask_tdst = mask_tdst & (pos_tdst >= 0)

            # Slot index of each active entry. Inactive lanes are loaded
            # without a fill value, but the resulting positions are only used
            # for cos / sin loads that are masked by mask_tsrc_inner.
            idx_bk = tl.load(
                INDICES_COLMAJOR_IXS
                + (idx_z + nnz_start) * stride_indices_colmajor_ixs_z,
                mask=mask_z,
            )
            # Key RoPE position: slot * BLOCK_SIZE_K + offset in block, placed
            # after the sink tokens; zeroed for keys inside the sink region.
            rope_tsrc = (
                tl.reshape(
                    idx_bk[:, None] * BLOCK_SIZE_K
                    + tl.arange(0, BLOCK_SIZE_K)[None, :],
                    BLOCK_SIZE_K * BLOCK_COLWISE_BK,
                )
                + SINK_TOKEN_SIZE
            ) * mask_tsrc

            grad_k, grad_v = bsa_bwd_grad_kv_update(
                grad_k,
                grad_v,
                keys,
                keys_rot,
                values,
                Q,
                stride_q_tdst,
                stride_q_hid,
                GRAD_CONTEXT,
                stride_grad_context_tdst,
                stride_grad_context_hid,
                SCORE_MAXIMUM,
                stride_score_maximum_tdst,
                DELTA,
                stride_delta_tdst,
                USING_EXTEND,
                COS,
                stride_cos_t,
                stride_cos_hid,
                SIN,
                stride_sin_t,
                stride_sin_hid,
                True,
                SLIDING_WINDOW_SIZE,
                rope_tdst,
                idx_tdst,
                mask_tdst,
                pos_tdst,
                rope_tsrc,
                idx_tsrc,
                mask_tsrc_inner,
                idx_hid,
                HEAD_DIM,
            )

            # Advance only the cursors that were consumed in this iteration.
            idx_z = tl.where(mask_z, idx_z + 1, idx_z)

    return grad_k, grad_v


@triton.jit
def bsa_bwd_grad_q_update(
    grad_q,
    queries,
    grad_context,
    delta_i,
    score_maximum,
    K,
    stride_k_tsrc,
    stride_k_hid,
    V,
    stride_v_tsrc,
    stride_v_hid,
    USING_EXTEND,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    EXCLUDE_SLIDING_WINDOW: tl.constexpr,
    SLIDING_WINDOW_SIZE,
    idx_tdst,
    pos_tdst,
    mask_tdst,
    rope_tsrc,
    idx_tsrc,
    mask_tsrc,
    idx_hid,
    HEAD_DIM: tl.constexpr,
):
    """Add one key tile's contribution to the dQ accumulator.

    Keys ``idx_tsrc`` are loaded from ``K`` (and rotated at ``rope_tsrc`` when
    ``USING_EXTEND`` is set); values are loaded from ``V``. Masked-out keys
    load as 0.

    Masking: with ``EXCLUDE_SLIDING_WINDOW`` a (query, key) pair is kept when
    ``pos_tdst - SLIDING_WINDOW_SIZE >= idx_tsrc``; otherwise it is kept when
    ``pos_tdst - SLIDING_WINDOW_SIZE < idx_tsrc <= pos_tdst``. Both cases also
    require ``mask_tdst`` and ``mask_tsrc``.

    Returns:
        The updated ``grad_q`` accumulator.
    """
    _, k_roped = load_and_apply_rope(
        idx_t=idx_tsrc,
        rope_t=rope_tsrc,
        mask_t=mask_tsrc,
        SEQ=K,
        stride_seq_t=stride_k_tsrc,
        stride_seq_hid=stride_k_hid,
        USING_EXTEND=USING_EXTEND,
        COS=COS,
        stride_cos_t=stride_cos_t,
        stride_cos_hid=stride_cos_hid,
        SIN=SIN,
        stride_sin_t=stride_sin_t,
        stride_sin_hid=stride_sin_hid,
        DIM=HEAD_DIM,
    )
    kT = tl.trans(k_roped)

    # Values are loaded directly in transposed (hidden x key) layout.
    v_ptrs = V + idx_tsrc[None, :] * stride_v_tsrc + idx_hid[:, None] * stride_v_hid
    vT = tl.load(v_ptrs, mask=mask_tsrc[None, :], other=0)

    # Recompute the (unnormalised) probabilities of the forward pass.
    scores = tl.dot(queries, kT.to(queries.dtype))
    probs = tl.math.exp2(scores - score_maximum)

    # Autoregressive masking.
    key_cols = idx_tsrc[None, :]
    window_start = (pos_tdst - SLIDING_WINDOW_SIZE)[:, None]
    both_valid = mask_tdst[:, None] & mask_tsrc[None, :]
    if EXCLUDE_SLIDING_WINDOW:
        keep = (window_start >= key_cols) & both_valid
    else:
        keep = (pos_tdst[:, None] >= key_cols) & (window_start < key_cols) & both_valid
    probs = tl.where(keep, probs, 0.0)

    # Compute dP and dS.
    grad_p = tl.dot(grad_context, vT.to(grad_context.dtype))
    grad_s = (probs * (grad_p - delta_i[:, None])).to(kT.dtype)

    # Compute dQ.
    # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
    grad_q = grad_q + tl.dot(grad_s, tl.trans(kT))

    return grad_q


@triton.jit
def bsa_bwd_grad_q_sw(
    grad_q,
    queries,
    grad_context,
    score_maximum,
    delta_i,
    K,
    stride_k_tsrc,
    stride_k_hid,
    V,
    stride_v_tsrc,
    stride_v_hid,
    USING_EXTEND: tl.constexpr,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    HEAD_DIM: tl.constexpr,
    start_tdst,
    idx_tdst,
    mask_tdst,
    pos_tdst,
    TSRC,
    NUM_K,
    SINK_TOKEN_SIZE,
    SLIDING_WINDOW_SIZE,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_KS: tl.constexpr,
):
    """Accumulate the sliding-window part of dQ for one query block.

    Keys are scanned in steps of ``BLOCK_SIZE_K * BLOCK_KS`` from
    ``max(0, CURR_TSRC - SLIDING_WINDOW_SIZE - BLOCK_SIZE_Q)`` up to and
    including ``CURR_TSRC``, the largest position among the valid query rows.
    Each step is handed to ``bsa_bwd_grad_q_update`` with
    ``EXCLUDE_SLIDING_WINDOW=False``, which applies the exact per-pair window
    mask.

    ``start_tdst`` and ``TSRC`` are accepted but not used here.

    Returns:
        The updated ``grad_q`` accumulator.
    """
    idx_hid = tl.arange(0, HEAD_DIM)
    # Largest position among the valid query rows of this block.
    CURR_TSRC = tl.max(tl.where(mask_tdst, pos_tdst, 0))
    i_tsrc_range_start = tl.maximum(0, CURR_TSRC - SLIDING_WINDOW_SIZE - BLOCK_SIZE_Q)

    BLOCK_STEP: tl.constexpr = BLOCK_SIZE_K * BLOCK_KS
    # The scan must be inclusive of CURR_TSRC: the key at that index passes
    # both `mask_tsrc` below and the `pos_tdst >= idx_tsrc` test of the update
    # step. With an exclusive bound it is silently dropped whenever
    # (CURR_TSRC - i_tsrc_range_start) is a multiple of BLOCK_STEP.
    for i_tsrc in range(i_tsrc_range_start, CURR_TSRC + 1, BLOCK_STEP):
        idx_tsrc = i_tsrc + tl.arange(0, BLOCK_STEP)
        mask_tsrc = idx_tsrc <= CURR_TSRC
        # Key RoPE position is taken relative to the largest query position
        # and clamped at zero.
        # NOTE(review): this uses the unmasked tl.max(pos_tdst) while
        # CURR_TSRC is masked by mask_tdst - confirm both agree for the
        # pos_tdst the caller passes in.
        rope_tsrc = (
            idx_tsrc
            - tl.max(pos_tdst)
            + SINK_TOKEN_SIZE
            + SLIDING_WINDOW_SIZE
            + NUM_K * BLOCK_SIZE_K
            - 1
        )
        rope_tsrc = tl.maximum(0, rope_tsrc)

        grad_q = bsa_bwd_grad_q_update(
            grad_q,
            queries,
            grad_context,
            delta_i,
            score_maximum,
            K,
            stride_k_tsrc,
            stride_k_hid,
            V,
            stride_v_tsrc,
            stride_v_hid,
            USING_EXTEND,
            COS,
            stride_cos_t,
            stride_cos_hid,
            SIN,
            stride_sin_t,
            stride_sin_hid,
            False,  # this pass is the sliding window, so do not exclude it
            SLIDING_WINDOW_SIZE,
            idx_tdst,
            pos_tdst,
            mask_tdst,
            rope_tsrc,
            idx_tsrc,
            mask_tsrc,
            idx_hid,
            HEAD_DIM,
        )
    return grad_q


@triton.jit
def bsa_bwd_grad_q_sink(
    grad_q,
    queries,
    grad_context,
    score_maximum,
    delta_i,
    K,
    stride_k_tsrc,
    stride_k_hid,
    V,
    stride_v_tsrc,
    stride_v_hid,
    USING_EXTEND,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    HEAD_DIM: tl.constexpr,
    idx_tdst,
    mask_tdst,
    pos_tdst,
    TSRC,
    SINK_TOKEN_SIZE,
    SLIDING_WINDOW_SIZE,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_KS: tl.constexpr,
):
    """Accumulate the sink-token part of dQ for one query block.

    Keys ``[0, SINK_TOKEN_SIZE)`` (also bounded by ``TSRC``) are scanned in
    steps of ``BLOCK_SIZE_K * BLOCK_KS`` and handed to
    ``bsa_bwd_grad_q_update`` with ``EXCLUDE_SLIDING_WINDOW=True``. Sink keys
    use their own index as RoPE position.

    Returns:
        The updated ``grad_q`` accumulator.
    """
    BLOCK_STEP: tl.constexpr = BLOCK_SIZE_K * BLOCK_KS
    idx_hid = tl.arange(0, HEAD_DIM)

    for i_step in range(tl.cdiv(SINK_TOKEN_SIZE, BLOCK_STEP)):
        idx_tsrc = i_step * BLOCK_STEP + tl.arange(0, BLOCK_STEP)
        inside_sequence = idx_tsrc < TSRC
        inside_sink = idx_tsrc < SINK_TOKEN_SIZE
        mask_tsrc = inside_sequence & inside_sink

        grad_q = bsa_bwd_grad_q_update(
            grad_q=grad_q,
            queries=queries,
            grad_context=grad_context,
            delta_i=delta_i,
            score_maximum=score_maximum,
            K=K,
            stride_k_tsrc=stride_k_tsrc,
            stride_k_hid=stride_k_hid,
            V=V,
            stride_v_tsrc=stride_v_tsrc,
            stride_v_hid=stride_v_hid,
            USING_EXTEND=USING_EXTEND,
            COS=COS,
            stride_cos_t=stride_cos_t,
            stride_cos_hid=stride_cos_hid,
            SIN=SIN,
            stride_sin_t=stride_sin_t,
            stride_sin_hid=stride_sin_hid,
            EXCLUDE_SLIDING_WINDOW=True,
            SLIDING_WINDOW_SIZE=SLIDING_WINDOW_SIZE,
            idx_tdst=idx_tdst,
            pos_tdst=pos_tdst,
            mask_tdst=mask_tdst,
            rope_tsrc=idx_tsrc,
            idx_tsrc=idx_tsrc,
            mask_tsrc=mask_tsrc,
            idx_hid=idx_hid,
            HEAD_DIM=HEAD_DIM,
        )
    return grad_q


# the main inner-loop logic for computing dQ
@triton.jit
def bsa_bwd_grad_q(
    grad_q,
    queries,
    grad_context,
    score_maximum,
    delta_i,
    K,
    stride_k_tsrc,
    stride_k_hid,
    V,
    stride_v_tsrc,
    stride_v_hid,
    USING_EXTEND,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    HEAD_DIM: tl.constexpr,
    idx_tdst,
    mask_tdst,
    pos_tdst,
    TSRC,
    SINK_TOKEN_SIZE,
    SLIDING_WINDOW_SIZE,
    idx_bsz,
    idx_head,
    idx_bdst,
    HEAD: tl.constexpr,
    INDICES,
    stride_indices_bdst,
    stride_indices_k,
    NUM_KS,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_KS: tl.constexpr,
):
    """Accumulate the block-sparse part of dQ for one query block.

    ``INDICES[idx_bdst, :NUM_KS]`` holds the first token of every selected
    key block. The selection is consumed ``BLOCK_KS`` blocks at a time; each
    chunk is expanded to ``BLOCK_KS * BLOCK_SIZE_K`` key tokens and handed to
    ``bsa_bwd_grad_q_update`` with ``EXCLUDE_SLIDING_WINDOW=True``. Keys
    inside the sink region or past ``TSRC`` are masked out. A key's RoPE
    position is ``SINK_TOKEN_SIZE`` plus its flat offset within the selection.

    ``idx_bsz``, ``idx_head`` and ``HEAD`` are accepted but not used here.

    Returns:
        The updated ``grad_q`` accumulator.
    """
    idx_hid = tl.arange(0, HEAD_DIM)

    for idx_bks in range(tl.cdiv(NUM_KS, BLOCK_KS)):
        idx_ks = idx_bks * BLOCK_KS + tl.arange(0, BLOCK_KS)
        mask_ks = idx_ks < NUM_KS

        block_start_ptrs = (
            INDICES + idx_bdst * stride_indices_bdst + idx_ks * stride_indices_k
        )
        block_starts = tl.load(block_start_ptrs, mask=mask_ks)
        # Slots past NUM_KS are pushed to TSRC so every token they expand to
        # fails the `idx_tsrc < TSRC` test below.
        block_starts = tl.where(mask_ks, block_starts, TSRC)

        # Expand block starts to individual key tokens and flatten.
        idx_tsrc_2d = block_starts[:, None] + tl.arange(0, BLOCK_SIZE_K)[None, :]
        idx_tsrc = tl.reshape(idx_tsrc_2d, idx_tsrc_2d.numel)
        mask_tsrc = (idx_tsrc >= SINK_TOKEN_SIZE) & (idx_tsrc < TSRC)

        rope_tsrc = SINK_TOKEN_SIZE + (
            idx_bks * BLOCK_KS * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K * BLOCK_KS)
        )

        grad_q = bsa_bwd_grad_q_update(
            grad_q=grad_q,
            queries=queries,
            grad_context=grad_context,
            delta_i=delta_i,
            score_maximum=score_maximum,
            K=K,
            stride_k_tsrc=stride_k_tsrc,
            stride_k_hid=stride_k_hid,
            V=V,
            stride_v_tsrc=stride_v_tsrc,
            stride_v_hid=stride_v_hid,
            USING_EXTEND=USING_EXTEND,
            COS=COS,
            stride_cos_t=stride_cos_t,
            stride_cos_hid=stride_cos_hid,
            SIN=SIN,
            stride_sin_t=stride_sin_t,
            stride_sin_hid=stride_sin_hid,
            EXCLUDE_SLIDING_WINDOW=True,
            SLIDING_WINDOW_SIZE=SLIDING_WINDOW_SIZE,
            idx_tdst=idx_tdst,
            pos_tdst=pos_tdst,
            mask_tdst=mask_tdst,
            rope_tsrc=rope_tsrc,
            idx_tsrc=idx_tsrc,
            mask_tsrc=mask_tsrc,
            idx_hid=idx_hid,
            HEAD_DIM=HEAD_DIM,
        )
    return grad_q


@triton.jit
def block_sparse_attention_backward(
    Q,
    stride_q_bsz,
    stride_q_tdst,
    stride_q_head,
    stride_q_hid,
    K,
    stride_k_bsz,
    stride_k_tsrc,
    stride_k_head_kv,
    stride_k_hid,
    V,
    stride_v_bsz,
    stride_v_tsrc,
    stride_v_head_kv,
    stride_v_hid,
    COS,
    stride_cos_t,
    stride_cos_hid,
    SIN,
    stride_sin_t,
    stride_sin_hid,
    GRAD_CONTEXT,
    stride_grad_context_bsz,
    stride_grad_context_tdst,
    stride_grad_context_head,
    stride_grad_context_hid,
    GRAD_Q,
    stride_grad_q_bsz,
    stride_grad_q_tdst,
    stride_grad_q_head,
    stride_grad_q_hid,
    GRAD_K,
    stride_grad_k_bsz,
    stride_grad_k_tsrc,
    stride_grad_k_head_kv,
    stride_grad_k_hid,
    GRAD_V,
    stride_grad_v_bsz,
    stride_grad_v_tsrc,
    stride_grad_v_head_kv,
    stride_grad_v_hid,
    SCORE_MAXIMUM,
    stride_score_maximum_bsz,
    stride_score_maximum_tdst,
    stride_score_maximum_head,
    DELTA,
    stride_delta_bsz,
    stride_delta_tdst,
    stride_delta_head,
    POS_IDS,
    stride_pos_ids_bsz,
    stride_pos_ids_tdst,
    INDICES,
    stride_indices_n,
    stride_indices_bdst,
    stride_indices_k,
    NUM_K,
    BLOCK_KS: tl.constexpr,
    INDICES_COLMAJOR_COLSTARTS,
    stride_indices_colmajor_colstarts_bh,
    stride_indices_colmajor_colstarts_bsrc,
    INDICES_COLMAJOR_ROWIS,
    stride_indices_colmajor_rowis_bh,
    stride_indices_colmajor_rowis_z,
    INDICES_COLMAJOR_IXS,
    stride_indices_colmajor_ixs_bh,
    stride_indices_colmajor_ixs_z,
    INDICES_COLMAJOR_NNZ,
    sm_scale,
    HEAD,
    HEAD_GROUP,
    TDST,
    TSRC,
    SINK_TOKEN_SIZE,
    SLIDING_WINDOW_SIZE,
    HEAD_DIM: tl.constexpr,
    USING_EXTEND: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_COLWISE_BK: tl.constexpr,
):
    """Backward kernel of block sparse attention: accumulates dK/dV, writes dQ.

    Launch grid:
        program_id(0) -> ``idx_bt``; the same index selects a key/value tile of
            ``BLOCK_SIZE_K * BLOCK_COLWISE_BK`` source tokens (dK/dV phase) and
            a query tile of ``BLOCK_SIZE_Q`` destination tokens (dQ phase).
        program_id(1) -> flattened ``idx_bsz * HEAD + idx_head``.

    Pointer arguments (``Q``, ``K``, ``V``, ``GRAD_*``, ``SCORE_MAXIMUM``,
    ``DELTA``, ``POS_IDS``, ``INDICES*``) are base pointers; the accompanying
    ``stride_*`` arguments are the per-axis strides used to address them.
    ``K``/``V``/``GRAD_K``/``GRAD_V`` are addressed with the key/value head
    ``idx_head // HEAD_GROUP``.

    dK/dV phase (runs when the tile start is ``< TSRC``):
        Loads the key and value tile once, accumulates ``grad_k``/``grad_v``
        through ``bsa_bwd_grad_kv`` (driven by the ``INDICES_COLMAJOR_*``
        arrays), then ``bsa_bwd_grad_kv_sink`` and ``bsa_bwd_grad_kv_sw`` when
        sink tokens / a sliding window are enabled, and finally adds the
        results into ``GRAD_V`` and ``GRAD_K`` (``grad_k`` scaled by
        ``sm_scale``) with ``tl.atomic_add``.
        NOTE(review): since the outputs are only added to, the caller
        presumably has to zero-initialize ``GRAD_K``/``GRAD_V`` -- confirm.

    dQ phase (runs when the tile start is ``< TDST``):
        Loads queries (with RoPE applied when ``USING_EXTEND``), the output
        gradient, ``SCORE_MAXIMUM`` and ``DELTA``; accumulates ``grad_q``
        through ``bsa_bwd_grad_q`` (driven by ``INDICES``),
        ``bsa_bwd_grad_q_sink`` and ``bsa_bwd_grad_q_sw``; back-propagates
        through the query RoPE when ``USING_EXTEND``; scales by ln(2) and
        stores the tile into ``GRAD_Q``. Destination rows whose ``POS_IDS``
        entry is negative are masked out of every load.

    Compile-time requirement: ``BLOCK_SIZE_Q % BLOCK_SIZE_K == 0``.
    """
    # ln(2), rounded to float32.
    # NOTE(review): presumably undoes a base-2 exponent used by the forward
    # pass -- confirm against the forward kernel.
    LN2: tl.constexpr = 0.6931471824645996

    idx_bsz = tl.program_id(1) // HEAD
    idx_head = tl.program_id(1) % HEAD

    idx_bt = tl.program_id(0)

    # offset pointers for batch/head
    Q += idx_bsz * stride_q_bsz + idx_head * stride_q_head
    K += idx_bsz * stride_k_bsz + (idx_head // HEAD_GROUP) * stride_k_head_kv
    V += idx_bsz * stride_v_bsz + (idx_head // HEAD_GROUP) * stride_v_head_kv
    GRAD_CONTEXT += (
        idx_bsz * stride_grad_context_bsz + idx_head * stride_grad_context_head
    )
    GRAD_Q += idx_bsz * stride_grad_q_bsz + idx_head * stride_grad_q_head
    GRAD_K += (
        idx_bsz * stride_grad_k_bsz + (idx_head // HEAD_GROUP) * stride_grad_k_head_kv
    )
    GRAD_V += (
        idx_bsz * stride_grad_v_bsz + (idx_head // HEAD_GROUP) * stride_grad_v_head_kv
    )
    SCORE_MAXIMUM += (
        idx_bsz * stride_score_maximum_bsz + idx_head * stride_score_maximum_head
    )
    DELTA += idx_bsz * stride_delta_bsz + idx_head * stride_delta_head
    POS_IDS += idx_bsz * stride_pos_ids_bsz
    idx_bh = idx_bsz * HEAD + idx_head
    INDICES += idx_bh * stride_indices_n
    INDICES_COLMAJOR_COLSTARTS += idx_bh * stride_indices_colmajor_colstarts_bh
    INDICES_COLMAJOR_ROWIS += idx_bh * stride_indices_colmajor_rowis_bh
    INDICES_COLMAJOR_IXS += idx_bh * stride_indices_colmajor_ixs_bh

    # offsets along the hidden (head) dimension, shared by both phases
    idx_hid = tl.arange(0, HEAD_DIM)

    # ------------------------------------------------------------------
    # THIS BLOCK DOES DK AND DV:
    # ------------------------------------------------------------------
    start_tsrc = idx_bt * (BLOCK_SIZE_K * BLOCK_COLWISE_BK)

    # TODO: split this into two kernels (if needed)
    if start_tsrc < TSRC:
        idx_tsrc = start_tsrc + tl.arange(0, BLOCK_SIZE_K * BLOCK_COLWISE_BK)
        mask_tsrc = idx_tsrc < TSRC

        grad_v = tl.zeros([BLOCK_SIZE_K * BLOCK_COLWISE_BK, HEAD_DIM], dtype=tl.float32)
        grad_k = tl.zeros([BLOCK_SIZE_K * BLOCK_COLWISE_BK, HEAD_DIM], dtype=tl.float32)

        # load K and V: they stay in SRAM throughout the inner loop.
        keys = tl.load(
            K + idx_tsrc[:, None] * stride_k_tsrc + idx_hid[None, :] * stride_k_hid,
            mask=mask_tsrc[:, None],
            other=0,
        )
        if USING_EXTEND:
            # Rotated copy of the keys for RoPE: the two halves of the hidden
            # dimension are swapped and the half that lands in the lower
            # positions is negated, i.e. keys_rot = [-k_hi, k_lo].
            idx_hid_rot = (idx_hid + HEAD_DIM // 2) % HEAD_DIM
            keys_rot = tl.load(
                K
                + idx_tsrc[:, None] * stride_k_tsrc
                + idx_hid_rot[None, :] * stride_k_hid,
                mask=mask_tsrc[:, None],
                other=0,
            )
            keys_rot = keys_rot * (
                ((idx_hid + HEAD_DIM // 2)[None, :] < HEAD_DIM) * (-2) + 1
            ).to(keys_rot.dtype)
        else:
            keys_rot = None

        values = tl.load(
            V + idx_tsrc[:, None] * stride_v_tsrc + idx_hid[None, :] * stride_v_hid,
            mask=mask_tsrc[:, None],
            other=0,
        )

        # Compute dK and dV for non-masked blocks.
        grad_k, grad_v = bsa_bwd_grad_kv(
            grad_k,
            grad_v,
            Q,
            stride_q_tdst,
            stride_q_hid,
            keys,
            keys_rot,
            values,
            sm_scale,
            GRAD_CONTEXT,
            stride_grad_context_tdst,
            stride_grad_context_hid,
            SCORE_MAXIMUM,
            stride_score_maximum_tdst,
            DELTA,
            stride_delta_tdst,
            POS_IDS,
            stride_pos_ids_tdst,
            HEAD,
            TDST,
            TSRC,
            SINK_TOKEN_SIZE,
            SLIDING_WINDOW_SIZE,
            NUM_K,
            BLOCK_SIZE_Q,
            BLOCK_SIZE_K,
            BLOCK_COLWISE_BK,
            HEAD_DIM,
            start_tsrc,
            idx_tsrc,
            mask_tsrc,
            INDICES_COLMAJOR_COLSTARTS,
            stride_indices_colmajor_colstarts_bsrc,
            INDICES_COLMAJOR_ROWIS,
            stride_indices_colmajor_rowis_z,
            INDICES_COLMAJOR_IXS,
            stride_indices_colmajor_ixs_z,
            INDICES_COLMAJOR_NNZ,
            USING_EXTEND,
            COS,
            stride_cos_t,
            stride_cos_hid,
            SIN,
            stride_sin_t,
            stride_sin_hid,
        )

        # Compute dK and dV for sink tokens (only tiles that overlap the
        # first SINK_TOKEN_SIZE source positions take part).
        if (SINK_TOKEN_SIZE > 0) & (start_tsrc < SINK_TOKEN_SIZE):
            grad_k, grad_v = bsa_bwd_grad_kv_sink(
                grad_k,
                grad_v,
                keys,
                keys_rot,
                values,
                sm_scale,
                Q,
                stride_q_tdst,
                stride_q_hid,
                GRAD_CONTEXT,
                stride_grad_context_tdst,
                stride_grad_context_hid,
                SCORE_MAXIMUM,
                stride_score_maximum_tdst,
                DELTA,
                stride_delta_tdst,
                POS_IDS,
                stride_pos_ids_tdst,
                USING_EXTEND,
                COS,
                stride_cos_t,
                stride_cos_hid,
                SIN,
                stride_sin_t,
                stride_sin_hid,
                HEAD,
                TDST,
                TSRC,
                SINK_TOKEN_SIZE,
                SLIDING_WINDOW_SIZE,
                NUM_K,
                BLOCK_SIZE_Q,
                BLOCK_SIZE_K,
                HEAD_DIM,
                start_tsrc,
                idx_tsrc,
                mask_tsrc,
            )

        # Compute dK and dV for sliding window
        if SLIDING_WINDOW_SIZE > 0:
            grad_k, grad_v = bsa_bwd_grad_kv_sw(
                grad_k,
                grad_v,
                keys,
                keys_rot,
                values,
                sm_scale,
                Q,
                stride_q_tdst,
                stride_q_hid,
                GRAD_CONTEXT,
                stride_grad_context_tdst,
                stride_grad_context_hid,
                SCORE_MAXIMUM,
                stride_score_maximum_tdst,
                DELTA,
                stride_delta_tdst,
                POS_IDS,
                stride_pos_ids_tdst,
                USING_EXTEND,
                COS,
                stride_cos_t,
                stride_cos_hid,
                SIN,
                stride_sin_t,
                stride_sin_hid,
                HEAD,
                TDST,
                TSRC,
                SINK_TOKEN_SIZE,
                SLIDING_WINDOW_SIZE,
                NUM_K,
                BLOCK_SIZE_Q,
                BLOCK_SIZE_K,
                HEAD_DIM,
                start_tsrc,
                idx_tsrc,
                mask_tsrc,
            )

        # Write back dV.
        # FIXME: need to remove atomic
        tl.atomic_add(
            GRAD_V
            + idx_tsrc[:, None] * stride_grad_v_tsrc
            + idx_hid[None, :] * stride_grad_v_hid,
            val=grad_v.to(GRAD_V.dtype.element_ty),
            mask=mask_tsrc[:, None],
        )

        # Write back dK.
        grad_k *= sm_scale
        tl.atomic_add(
            GRAD_K
            + idx_tsrc[:, None] * stride_grad_k_tsrc
            + idx_hid[None, :] * stride_grad_k_hid,
            val=grad_k.to(GRAD_K.dtype.element_ty),
            mask=mask_tsrc[:, None],
        )

    # ------------------------------------------------------------------
    # THIS BLOCK DOES DQ:
    # ------------------------------------------------------------------
    start_tdst = idx_bt * BLOCK_SIZE_Q

    if start_tdst < TDST:
        idx_tdst = start_tdst + tl.arange(0, BLOCK_SIZE_Q)
        mask_tdst = idx_tdst < TDST
        pos_tdst = tl.load(
            POS_IDS + idx_tdst * stride_pos_ids_tdst,
            mask=mask_tdst,
            other=0,
        )
        # Rows with a negative position id are excluded from all loads below.
        mask_tdst = mask_tdst & (pos_tdst >= 0)

        queries = tl.load(
            Q + idx_tdst[:, None] * stride_q_tdst + idx_hid[None, :] * stride_q_hid,
            mask=mask_tdst[:, None],
            other=0,
        )

        if USING_EXTEND:
            # The query tile is placed on the last BLOCK_SIZE_Q RoPE positions
            # of a context of SINK_TOKEN_SIZE + SLIDING_WINDOW_SIZE
            # + NUM_K * BLOCK_SIZE_K tokens.
            rope_tdst = (
                tl.arange(0, BLOCK_SIZE_Q)
                + SINK_TOKEN_SIZE
                + SLIDING_WINDOW_SIZE
                + NUM_K * BLOCK_SIZE_K
                - BLOCK_SIZE_Q
            )
            # rope_cos / rope_sin are kept for the RoPE backward step below.
            queries, rope_cos, rope_sin = load_rot_and_apply_rope(
                queries,
                rope_tdst,
                Q,
                stride_q_tdst,
                stride_q_hid,
                COS,
                stride_cos_t,
                stride_cos_hid,
                SIN,
                stride_sin_t,
                stride_sin_hid,
                idx_tdst,
                mask_tdst,
                idx_hid,
                HEAD_DIM,
            )

        grad_context = tl.load(
            GRAD_CONTEXT
            + idx_tdst[:, None] * stride_grad_context_tdst
            + idx_hid[None, :] * stride_grad_context_hid,
            mask=mask_tdst[:, None],
            other=0,
        )

        # `other` is given explicitly so masked-off rows hold a defined value
        # instead of whatever the masked load happens to return.
        score_maximum = tl.load(
            SCORE_MAXIMUM + idx_tdst * stride_score_maximum_tdst,
            mask=mask_tdst,
            other=0.0,
        )
        score_maximum = score_maximum[:, None]

        # D (= delta) is pre-divided by ds_scale.
        delta_i = tl.load(
            DELTA + idx_tdst * stride_delta_tdst,
            mask=mask_tdst,
            other=0.0,
        )

        grad_q = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)

        # BLOCK_SIZE_Q must be a multiple of BLOCK_SIZE_K, otherwise the code
        # wouldn't work.
        tl.static_assert(BLOCK_SIZE_Q % BLOCK_SIZE_K == 0)
        grad_q = bsa_bwd_grad_q(
            grad_q,
            queries,
            grad_context,
            score_maximum,
            delta_i,
            K,
            stride_k_tsrc,
            stride_k_hid,
            V,
            stride_v_tsrc,
            stride_v_hid,
            USING_EXTEND,
            COS,
            stride_cos_t,
            stride_cos_hid,
            SIN,
            stride_sin_t,
            stride_sin_hid,
            HEAD_DIM,
            idx_tdst,
            mask_tdst,
            pos_tdst,
            TSRC,
            SINK_TOKEN_SIZE,
            SLIDING_WINDOW_SIZE,
            idx_bsz,
            idx_head,
            idx_bt,
            HEAD,
            INDICES,
            stride_indices_bdst,
            stride_indices_k,
            NUM_K,
            BLOCK_SIZE_K,
            BLOCK_KS,
        )

        if SINK_TOKEN_SIZE > 0:
            grad_q = bsa_bwd_grad_q_sink(
                grad_q,
                queries,
                grad_context,
                score_maximum,
                delta_i,
                K,
                stride_k_tsrc,
                stride_k_hid,
                V,
                stride_v_tsrc,
                stride_v_hid,
                USING_EXTEND,
                COS,
                stride_cos_t,
                stride_cos_hid,
                SIN,
                stride_sin_t,
                stride_sin_hid,
                HEAD_DIM,
                idx_tdst,
                mask_tdst,
                pos_tdst,
                TSRC,
                SINK_TOKEN_SIZE,
                SLIDING_WINDOW_SIZE,
                BLOCK_SIZE_K,
                BLOCK_KS,
            )

        if SLIDING_WINDOW_SIZE > 0:
            grad_q = bsa_bwd_grad_q_sw(
                grad_q,
                queries,
                grad_context,
                score_maximum,
                delta_i,
                K,
                stride_k_tsrc,
                stride_k_hid,
                V,
                stride_v_tsrc,
                stride_v_hid,
                USING_EXTEND,
                COS,
                stride_cos_t,
                stride_cos_hid,
                SIN,
                stride_sin_t,
                stride_sin_hid,
                HEAD_DIM,
                start_tdst,
                idx_tdst,
                mask_tdst,
                pos_tdst,
                TSRC,
                NUM_K,
                SINK_TOKEN_SIZE,
                SLIDING_WINDOW_SIZE,
                BLOCK_SIZE_Q,
                BLOCK_SIZE_K,
                BLOCK_KS,
            )

        # RoPE backward. Forward was q' = q * cos + q_rot * sin with
        # q_rot = [-q_hi, q_lo], hence
        #   dq = dq' * cos + swap_halves(dq' * sin * sign)
        # where sign is -1 on the lower half and +1 on the upper half.
        if USING_EXTEND:
            grad_q_cos = grad_q * rope_cos

            grad_q_sin = grad_q * rope_sin
            grad_q_sin = grad_q_sin * (
                ((idx_hid + HEAD_DIM // 2)[None, :] < HEAD_DIM) * (-2) + 1
            ).to(grad_q_sin.dtype)
            # Swap the lower and upper halves of the hidden dimension:
            # view as (Q, 2, D/2), split the halves, re-join in swapped order.
            grad_q_sin_lo, grad_q_sin_hi = tl.split(
                tl.trans(
                    tl.reshape(grad_q_sin, BLOCK_SIZE_Q, 2, HEAD_DIM // 2),
                    0,
                    2,
                    1,
                )
            )
            grad_q_sin = tl.reshape(
                tl.trans(tl.join(grad_q_sin_hi, grad_q_sin_lo), 0, 2, 1),
                (BLOCK_SIZE_Q, HEAD_DIM),
            )

            grad_q = grad_q_cos + grad_q_sin

        # Write back dQ.
        grad_q *= LN2
        # Only rows inside [0, TDST) may be written; the last tile can extend
        # past the end of GRAD_Q when TDST is not a multiple of BLOCK_SIZE_Q.
        # The plain bounds mask (not mask_tdst) is used so in-bounds rows with
        # a negative position id are still written, as before.
        tl.store(
            GRAD_Q
            + idx_tdst[:, None] * stride_grad_q_tdst
            + idx_hid[None, :] * stride_grad_q_hid,
            value=grad_q,
            mask=(idx_tdst < TDST)[:, None],
        )


# Public API of this module. `__all__` must contain the *names* as strings;
# listing the objects themselves makes `from <module> import *` raise
# TypeError.
__all__ = [
    "block_sparse_attention_backward",
    "block_sparse_attention_backward_preprocess",
]
