import math
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ..common.utils import is_nvidia_hopper, use_cuda_graph

# Candidate `num_warps` / `num_stages` values. They are not referenced by the
# autotune configs visible in this part of the file.
# NOTE(review): `is_nvidia_hopper` is tested by truthiness without being called;
# if it is a function rather than a bool it is always truthy and the narrower
# list is always selected — confirm against common.utils.
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
NUM_STAGES = [2, 3, 4]


@triton.jit
def exp_negative_only(x):
    # Element-wise exp that is only evaluated for non-positive inputs.
    # The gate is built as exp(gamma[:, None] - gamma[None, :]); taking exp of
    # that difference directly can overflow to inf in the upper-triangular
    # part. Positive entries are therefore replaced by -inf first, so they
    # come out as exp(-inf) instead of a huge value.
    masked = tl.where(x <= 0, x, float("-inf"))
    return tl.exp(masked)


@triton.heuristics(
    {
        "USE_BETA": lambda args: args["beta_ptr"] is not None,
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": 128, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 32, "BLOCK_SIZE_VD": 32}, num_warps=2, num_stages=3
        ),
    ],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v", "USE_BETA"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_fwd_kernel_sq_sk(
    q_ptr,
    k_ptr,
    v_ptr,
    sq_ptr,
    sk_ptr,
    alpha_ptr,
    beta_ptr,
    chunk_state_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_sq_b,
    stride_sq_n,
    stride_sq_h,
    stride_sq_d,
    stride_sk_b,
    stride_sk_n,
    stride_sk_h,
    stride_sk_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_b_b,
    stride_b_n,
    stride_b_h,
    stride_cs_b,
    stride_cs_n,
    stride_cs_h,
    stride_cs_kd,
    stride_cs_vd,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    # option
    USE_BETA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Forward kernel producing two outputs, `sq` and `sk`, one
    # [CHUNK_SIZE, BLOCK_SIZE_VD] tile per program.
    #   * grid axis 0: block along the value head dim (pid_vd)
    #   * grid axis 1: chunk index inside the sequence (pid_c)
    #   * grid axis 2: flattened (batch, head) index (pid_bh)
    # Both outputs share the same keys/values/state; they differ only in the
    # "query" rows: `sq` uses rows of q, `sk` uses rows of k. In both cases the
    # query rows are shifted by +1 relative to the key rows of the chunk, and
    # the result is stored at the shifted row positions.
    # When `beta_ptr` is None (USE_BETA == False) beta is never touched.
    pid_vd, pid_c, pid_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        # packed layout: sequence boundaries come from cu_seq_len, and the
        # first chunk-state slot of this sequence from cu_chunk_len
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        q_ptr = q_ptr + seq_start * stride_q_n + pid_h * stride_q_h
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        v_ptr = v_ptr + seq_start * stride_v_n + pid_h * stride_v_h
        sq_ptr = sq_ptr + seq_start * stride_sq_n + pid_h * stride_sq_h
        sk_ptr = sk_ptr + seq_start * stride_sk_n + pid_h * stride_sk_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        # only offset beta when it was actually passed; pointer arithmetic on a
        # None argument is not valid
        if USE_BETA:
            beta_ptr = beta_ptr + seq_start * stride_b_n + pid_h * stride_b_h
        chunk_state_ptr = (
            chunk_state_ptr + chunk_start * stride_cs_n + pid_h * stride_cs_h
        )
    else:
        q_ptr = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
        sq_ptr = sq_ptr + pid_b * stride_sq_b + pid_h * stride_sq_h
        sk_ptr = sk_ptr + pid_b * stride_sk_b + pid_h * stride_sk_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        if USE_BETA:
            beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h
        chunk_state_ptr = chunk_state_ptr + pid_b * stride_cs_b + pid_h * stride_cs_h
    # nothing to do when even the first shifted query row of this chunk lies
    # outside the sequence; the output rows of such a chunk are left untouched
    if pid_c * CHUNK_SIZE + 1 >= seq_len:
        return
    # ptrs
    q1_ptrs = tl.make_block_ptr(
        base=q_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_q_n, stride_q_d),
        offsets=(pid_c * CHUNK_SIZE + 1, 0),  # add 1 to compute S_{t}q_{t+1}
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # second "query": rows of k, shifted by one like q1
    q2_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_k_n, stride_k_d),
        offsets=(pid_c * CHUNK_SIZE + 1, 0),  # add 1 to compute S_{t}q_{t+1}
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # keys of this chunk, viewed transposed as [head_dim_qk, seq_len]
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_k_d, stride_k_n),
        offsets=(0, pid_c * CHUNK_SIZE),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    # state slot number pid_c of this (batch, head), [head_dim_qk, head_dim_v]
    chunk_state_ptrs = tl.make_block_ptr(
        base=chunk_state_ptr + pid_c * stride_cs_n,
        shape=(head_dim_qk, head_dim_v),
        strides=(stride_cs_kd, stride_cs_vd),
        offsets=(0, pid_vd * BLOCK_SIZE_VD),
        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    if USE_BETA:
        beta_ptrs = tl.make_block_ptr(
            base=beta_ptr,
            shape=(seq_len,),
            strides=(stride_b_n,),
            offsets=(pid_c * CHUNK_SIZE,),
            block_shape=(CHUNK_SIZE,),
            order=(0,),
        )
    # history output and qk buffer
    acc_sq1 = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_VD], dtype=tl.float32)
    acc_sq2 = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_VD], dtype=tl.float32)
    acc_q1k = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    acc_q2k = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    # for loop at key head dim, compute inter chunk output and qk
    for _ in range(0, head_dim_qk, BLOCK_SIZE_KD):
        q1 = tl.load(q1_ptrs, boundary_check=(0, 1), padding_option="zero")
        q2 = tl.load(q2_ptrs, boundary_check=(0, 1), padding_option="zero")
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        # state is cast to the query dtype so both tl.dot operands match
        state = tl.load(
            chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(q1.dtype)
        # inter chunk output, [CHUNK_SIZE, BLOCK_SIZE_KD] @ [BLOCK_SIZE_KD, BLOCK_SIZE_VD] -> [CHUNK_SIZE, BLOCK_SIZE_VD]
        acc_sq1 += tl.dot(q1, state)
        acc_sq2 += tl.dot(q2, state)
        # qk result, [CHUNK_SIZE, BLOCK_SIZE_KD] @ [BLOCK_SIZE_KD, CHUNK_SIZE] -> [CHUNK_SIZE, CHUNK_SIZE]
        acc_q1k += tl.dot(q1, k)
        acc_q2k += tl.dot(q2, k)
        # update ptrs
        q1_ptrs = tl.advance(q1_ptrs, (0, BLOCK_SIZE_KD))
        q2_ptrs = tl.advance(q2_ptrs, (0, BLOCK_SIZE_KD))
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_KD, 0))
        chunk_state_ptrs = tl.advance(chunk_state_ptrs, (BLOCK_SIZE_KD, 0))
    # load alpha and beta, compute gate
    alpha = tl.load(alpha_ptrs, boundary_check=(0,), padding_option="zero")
    # from here on `alpha` holds the within-chunk cumulative sum
    alpha = tl.cumsum(alpha, axis=0)
    gate = exp_negative_only(alpha[:, None] - alpha[None, :])
    if USE_BETA:
        beta = tl.load(beta_ptrs, boundary_check=(0,), padding_option="zero")
        # beta scales per key position (columns of the score matrix)
        gate = gate * beta[None, :]
    # the state contribution of row i is scaled by exp(cumulative alpha up to i)
    decay = tl.exp(alpha)[:, None]
    acc_sq1 = acc_sq1 * decay
    acc_sq2 = acc_sq2 * decay
    acc_q1k = acc_q1k * gate
    acc_q2k = acc_q2k * gate
    # causal mask (diagonal kept: query rows are already shifted by one)
    off_c = tl.arange(0, CHUNK_SIZE)
    causal = off_c[:, None] >= off_c[None, :]
    acc_q1k = tl.where(causal, acc_q1k, 0)
    acc_q2k = tl.where(causal, acc_q2k, 0)
    # intra chunk output, [CHUNK_SIZE, CHUNK_SIZE] @ [CHUNK_SIZE, BLOCK_SIZE_VD] -> [CHUNK_SIZE, BLOCK_SIZE_VD]
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_v_n, stride_v_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    acc_sq1 += tl.dot(acc_q1k.to(v.dtype), v)
    acc_sq2 += tl.dot(acc_q2k.to(v.dtype), v)
    acc_sq1 *= scale
    acc_sq2 *= scale
    # save output (at the shifted row positions, matching the query rows)
    sq_ptrs = tl.make_block_ptr(
        base=sq_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_sq_n, stride_sq_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    tl.store(sq_ptrs, acc_sq1.to(sq_ptr.dtype.element_ty), boundary_check=(0, 1))
    sk_ptrs = tl.make_block_ptr(
        base=sk_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_sk_n, stride_sk_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    tl.store(sk_ptrs, acc_sq2.to(sk_ptr.dtype.element_ty), boundary_check=(0, 1))


def chunk_fwd_sq_sk(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: Optional[torch.Tensor],
    chunk_state: torch.Tensor,
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``chunk_fwd_kernel_sq_sk`` and return its two outputs.

    Args:
        q: Query tensor; must be 4-D (four strides are passed to the kernel).
        k: Key tensor; must be 4-D. Without ``cu_seq_len`` its shape is read
            as ``(batch, seq_len, num_heads, head_dim_qk)``; with
            ``cu_seq_len`` only the last two dims are used.
        v: Value tensor, 4-D; its last dim is the value head dim.
        alpha: Per-position gate input, 3-D (three strides are passed).
        beta: Optional per-position scale, 3-D; ``None`` disables it.
        chunk_state: Per-chunk state, 5-D (five strides are passed).
        scale: Scalar multiplied into both outputs.
        cu_seq_len: Optional cumulative sequence lengths; enables the
            variable-length path when given.
        cu_chunk_len: Cumulative chunk counts; required together with
            ``cu_seq_len``.
        chunk_size: Chunk length; rounded up to a power of two for the kernel.

    Returns:
        ``(sq, sk)``, both shaped and typed like ``v``. Rows the kernel does
        not write stay zero.
    """
    if cu_seq_len is None:
        batch_size, seq_len, num_heads, head_dim_qk, head_dim_v = *k.shape, v.shape[-1]
        num_chunks = math.ceil(seq_len / chunk_size)
    else:
        _, _, num_heads, head_dim_qk, head_dim_v = *k.shape, v.shape[-1]
        batch_size = cu_seq_len.shape[0] - 1
        # Convert to Python ints: these are used as a launch-grid dimension and
        # as scalar kernel arguments, not as device tensors.
        seq_len = int(cu_seq_len[-1].item())
        num_chunks = int((cu_chunk_len[1:] - cu_chunk_len[:-1]).max().item())
    # zero-initialised because the kernel skips some rows (e.g. row 0 of each
    # sequence is never stored)
    sq = torch.zeros_like(v)
    sk = torch.zeros_like(v)

    def grid(meta):
        # (value-dim blocks, chunks, batch * heads)
        return (
            triton.cdiv(head_dim_v, meta["BLOCK_SIZE_VD"]),
            num_chunks,
            batch_size * num_heads,
        )

    # beta strides are dummies when beta is absent; the kernel ignores them
    if beta is not None:
        beta_strides = (beta.stride(0), beta.stride(1), beta.stride(2))
    else:
        beta_strides = (0, 0, 0)

    chunk_fwd_kernel_sq_sk[grid](
        q,
        k,
        v,
        sq,
        sk,
        alpha,
        beta,
        chunk_state,
        scale,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k.stride(3),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        sq.stride(0),
        sq.stride(1),
        sq.stride(2),
        sq.stride(3),
        sk.stride(0),
        sk.stride(1),
        sk.stride(2),
        sk.stride(3),
        alpha.stride(0),
        alpha.stride(1),
        alpha.stride(2),
        beta_strides[0],
        beta_strides[1],
        beta_strides[2],
        chunk_state.stride(0),
        chunk_state.stride(1),
        chunk_state.stride(2),
        chunk_state.stride(3),
        chunk_state.stride(4),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return sq, sk


@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 128}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 32, "BLOCK_SIZE_VD": 32}, num_warps=2, num_stages=3
        ),
    ],
    # da is accumulated with atomic_add below, so it is listed here to be
    # zeroed again between autotune benchmark runs
    reset_to_zero=["da_ptr"],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_sq_sk_bwd_dqka_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    alpha_ptr,
    beta_ptr,
    chunk_state_ptr,
    dsq_ptr,
    dsk_ptr,
    d_chunk_state_ptr,
    dq1_ptr,
    dq2_ptr,
    dk_ptr,
    da_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_b_b,
    stride_b_n,
    stride_b_h,
    stride_cs_b,
    stride_cs_n,
    stride_cs_h,
    stride_cs_kd,
    stride_cs_vd,
    stride_dsq_b,
    stride_dsq_n,
    stride_dsq_h,
    stride_dsq_d,
    stride_dsk_b,
    stride_dsk_n,
    stride_dsk_h,
    stride_dsk_d,
    stride_dcs_b,
    stride_dcs_n,
    stride_dcs_h,
    stride_dcs_kd,
    stride_dcs_vd,
    stride_dq1_b,
    stride_dq1_n,
    stride_dq1_h,
    stride_dq1_d,
    stride_dq2_b,
    stride_dq2_n,
    stride_dq2_h,
    stride_dq2_d,
    stride_dk_b,
    stride_dk_n,
    stride_dk_h,
    stride_dk_d,
    stride_da_b,
    stride_da_n,
    stride_da_h,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    # option
    IS_VARLEN: tl.constexpr,
):
    # Backward kernel producing dq1, dq2, dk and da for one chunk and one
    # block of the key head dim.
    #   * grid axis 0: flattened (batch, head) index (pid_bh)
    #   * grid axis 1: chunk index inside the sequence (pid_c)
    #   * grid axis 2: block along the key head dim (pid_kd)
    # "q1" rows come from q and "q2" rows come from k, both shifted by +1
    # relative to the key rows of the chunk; dq1/dq2 are stored at the same
    # shifted rows, dk at the unshifted rows. da does not depend on pid_kd in
    # its address, so the programs along grid axis 2 add into the same
    # locations via atomic_add.
    # Unlike the forward kernel there is no USE_BETA switch: beta is always
    # read here.
    pid_bh, pid_c, pid_kd = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        # packed layout: sequence boundaries come from cu_seq_len, and the
        # first chunk-state slot of this sequence from cu_chunk_len
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        q_ptr = q_ptr + seq_start * stride_q_n + pid_h * stride_q_h
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        v_ptr = v_ptr + seq_start * stride_v_n + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        beta_ptr = beta_ptr + seq_start * stride_b_n + pid_h * stride_b_h
        chunk_state_ptr = (
            chunk_state_ptr + chunk_start * stride_cs_n + pid_h * stride_cs_h
        )
        dsq_ptr = dsq_ptr + seq_start * stride_dsq_n + pid_h * stride_dsq_h
        dsk_ptr = dsk_ptr + seq_start * stride_dsk_n + pid_h * stride_dsk_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + chunk_start * stride_dcs_n + pid_h * stride_dcs_h
        )
        dq1_ptr = dq1_ptr + seq_start * stride_dq1_n + pid_h * stride_dq1_h
        dq2_ptr = dq2_ptr + seq_start * stride_dq2_n + pid_h * stride_dq2_h
        dk_ptr = dk_ptr + seq_start * stride_dk_n + pid_h * stride_dk_h
        da_ptr = da_ptr + seq_start * stride_da_n + pid_h * stride_da_h
    else:
        q_ptr = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h
        chunk_state_ptr = chunk_state_ptr + pid_b * stride_cs_b + pid_h * stride_cs_h
        dsq_ptr = dsq_ptr + pid_b * stride_dsq_b + pid_h * stride_dsq_h
        dsk_ptr = dsk_ptr + pid_b * stride_dsk_b + pid_h * stride_dsk_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + pid_b * stride_dcs_b + pid_h * stride_dcs_h
        )
        dq1_ptr = dq1_ptr + pid_b * stride_dq1_b + pid_h * stride_dq1_h
        dq2_ptr = dq2_ptr + pid_b * stride_dq2_b + pid_h * stride_dq2_h
        dk_ptr = dk_ptr + pid_b * stride_dk_b + pid_h * stride_dk_h
        da_ptr = da_ptr + pid_b * stride_da_b + pid_h * stride_da_h
    # Exit when even the first shifted row of this chunk lies outside the
    # sequence. NOTE: nothing is stored in that case, including dk for a row
    # at pid_c * CHUNK_SIZE that may still be inside the sequence, so the
    # output buffers must be initialised by the caller.
    if pid_c * CHUNK_SIZE + 1 >= seq_len:
        return
    # ptrs
    q1_ptrs = tl.make_block_ptr(
        base=q_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_q_n, stride_q_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # second "query": rows of k, shifted by one like q1
    q2_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_k_n, stride_k_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_k_n, stride_k_d),
        offsets=(pid_c * CHUNK_SIZE, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_v_n, stride_v_d),
        offsets=(pid_c * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    beta_ptrs = tl.make_block_ptr(
        base=beta_ptr,
        shape=(seq_len,),
        strides=(stride_b_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    # upstream gradients of the two outputs, read at the shifted rows
    do1_ptrs = tl.make_block_ptr(
        base=dsq_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dsq_n, stride_dsq_d),
        offsets=(pid_c * CHUNK_SIZE + 1, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    do2_ptrs = tl.make_block_ptr(
        base=dsk_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dsk_n, stride_dsk_d),
        offsets=(pid_c * CHUNK_SIZE + 1, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # state gradient and state of slot pid_c, both addressed as a transposed
    # [head_dim_v, head_dim_qk] view (value stride first)
    d_chunk_state_ptrs = tl.make_block_ptr(
        base=d_chunk_state_ptr + pid_c * stride_dcs_n,
        shape=(head_dim_v, head_dim_qk),
        strides=(stride_dcs_vd, stride_dcs_kd),
        offsets=(0, pid_kd * BLOCK_SIZE_KD),
        block_shape=(BLOCK_SIZE_VD, BLOCK_SIZE_KD),
        order=(0, 1),
    )
    chunk_state_ptrs = tl.make_block_ptr(
        base=chunk_state_ptr + pid_c * stride_cs_n,
        shape=(head_dim_v, head_dim_qk),
        strides=(stride_cs_vd, stride_cs_kd),
        offsets=(0, pid_kd * BLOCK_SIZE_KD),
        block_shape=(BLOCK_SIZE_VD, BLOCK_SIZE_KD),
        order=(0, 1),
    )
    # init dq, dk
    dq1 = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_KD], dtype=tl.float32)
    dq2 = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_KD], dtype=tl.float32)
    dk = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_KD], dtype=tl.float32)
    # accumulators for do @ (beta * v).T, one per output
    do1v = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    do2v = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    # per-position alpha gradient for this chunk
    da = tl.zeros(
        [
            CHUNK_SIZE,
        ],
        dtype=tl.float32,
    )
    # scalar term that is added to every position of da at the end
    d_sum_alpha = tl.zeros(
        [
            1,
        ],
        dtype=tl.float32,
    )

    # load beta, shape: [CHUNK_SIZE,]
    beta = tl.load(beta_ptrs, boundary_check=(0,), padding_option="zero")

    # compute inter block gradient for qk
    for _ in range(0, head_dim_v, BLOCK_SIZE_VD):
        # load v, do, shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # fold beta into v once; every later use of v in this loop is beta * v
        v = (v * beta[:, None]).to(v.dtype)
        do1 = tl.load(do1_ptrs, boundary_check=(0, 1), padding_option="zero")
        do2 = tl.load(do2_ptrs, boundary_check=(0, 1), padding_option="zero")
        # chunk_state and d_chunk_state, [BLOCK_SIZE_VD, BLOCK_SIZE_KD]
        d_chunk_state = tl.load(
            d_chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(v.dtype)
        chunk_state = tl.load(
            chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(do1.dtype)
        # compute do @ v.T, [CHUNK_SIZE, BLOCK_SIZE_VD] @ [BLOCK_SIZE_VD, CHUNK_SIZE] -> [CHUNK_SIZE, CHUNK_SIZE]
        do1v += tl.dot(do1, tl.trans(v)) * scale
        do2v += tl.dot(do2, tl.trans(v)) * scale
        # compute inter block dq = do @ state.T, [CHUNK_SIZE, BLOCK_SIZE_VD] @ [BLOCK_SIZE_VD, BLOCK_SIZE_KD] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
        dq1 += tl.dot(do1, chunk_state) * scale
        dq2 += tl.dot(do2, chunk_state) * scale
        # compute inter block dk = v @ d_chunk_state.T, [CHUNK_SIZE, BLOCK_SIZE_VD] @ [BLOCK_SIZE_VD, BLOCK_SIZE_KD] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
        dk += tl.dot(v, d_chunk_state)
        # compute state * d_chunk_state
        d_sum_alpha += tl.sum(d_chunk_state * chunk_state)
        # update ptrs
        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_VD))
        do1_ptrs = tl.advance(do1_ptrs, (0, BLOCK_SIZE_VD))
        do2_ptrs = tl.advance(do2_ptrs, (0, BLOCK_SIZE_VD))
        d_chunk_state_ptrs = tl.advance(d_chunk_state_ptrs, (BLOCK_SIZE_VD, 0))
        chunk_state_ptrs = tl.advance(chunk_state_ptrs, (BLOCK_SIZE_VD, 0))

    # load alpha, shape: [CHUNK_SIZE,]
    alpha = tl.load(alpha_ptrs, boundary_check=(0,), padding_option="zero")
    # chunk total and running (inclusive) sum of alpha
    sum_alpha = tl.sum(alpha)
    cum_alpha = tl.cumsum(alpha, axis=0)
    # load q, k, shape: [CHUNK_SIZE, BLOCK_SIZE_KD]
    q1 = tl.load(q1_ptrs, boundary_check=(0, 1), padding_option="zero")
    q2 = tl.load(q2_ptrs, boundary_check=(0, 1), padding_option="zero")
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")

    # gate for inter chunk dq and dk, [CHUNK_SIZE, BLOCK_SIZE_KD] * [CHUNK_SIZE, 1] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
    dq1 = dq1 * tl.exp(cum_alpha)[:, None]
    dq2 = dq2 * tl.exp(cum_alpha)[:, None]
    dk = dk * tl.exp(sum_alpha - cum_alpha)[:, None]

    # gate for dov (causal lower triangle including the diagonal, with the
    # pairwise decay applied; everything above the diagonal is zeroed)
    off_c = tl.arange(0, CHUNK_SIZE)
    do1v = tl.where(
        off_c[:, None] >= off_c[None, :],
        do1v * exp_negative_only(cum_alpha[:, None] - cum_alpha[None, :]),
        0,
    )
    do2v = tl.where(
        off_c[:, None] >= off_c[None, :],
        do2v * exp_negative_only(cum_alpha[:, None] - cum_alpha[None, :]),
        0,
    )

    # compute gradient of alpha
    # NOTE: at this point dq1/dq2/dk hold only the gated inter-chunk parts;
    # the intra-chunk parts are added further below, so statement order matters.
    d_sum_alpha *= tl.exp(sum_alpha)
    d_sum_alpha += tl.sum(k * dk)
    da += tl.sum(dq1 * q1, axis=1)
    da += tl.sum(dq2 * q2, axis=1)
    da -= tl.sum(k * dk, axis=1)
    # intra-chunk score gradients: row sums are added, column sums subtracted
    do1vq1k = do1v * tl.dot(q1, tl.trans(k))
    do2vq2k = do2v * tl.dot(q2, tl.trans(k))
    da += tl.sum(do1vq1k, axis=1) + tl.sum(do2vq2k, axis=1)
    da -= tl.sum(do1vq1k, axis=0) + tl.sum(do2vq2k, axis=0)
    # suffix sum over positions, then add the chunk-wide scalar term
    da = tl.cumsum(da, axis=0, reverse=True) + d_sum_alpha

    # intra chunk dq and dk, [CHUNK_SIZE, CHUNK_SIZE] @ [CHUNK_SIZE, BLOCK_SIZE_KD] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
    do1v = do1v.to(k.dtype)
    dq1 += tl.dot(do1v, k)
    dk += tl.dot(tl.trans(do1v), q1)
    do2v = do2v.to(k.dtype)
    dq2 += tl.dot(do2v, k)
    dk += tl.dot(tl.trans(do2v), q2)

    # store dq, dk, da
    dq1_ptrs = tl.make_block_ptr(
        base=dq1_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_dq1_n, stride_dq1_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    dq2_ptrs = tl.make_block_ptr(
        base=dq2_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_dq2_n, stride_dq2_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_dk_n, stride_dk_d),
        offsets=(pid_c * CHUNK_SIZE, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # da uses plain pointer arithmetic (not a block pointer) because it is
    # written with atomic_add and an explicit mask
    da_ptrs = da_ptr + (pid_c * CHUNK_SIZE + off_c) * stride_da_n
    tl.store(dq1_ptrs, dq1.to(dq1_ptrs.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dq2_ptrs, dq2.to(dq2_ptrs.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dk_ptrs, dk.to(dk_ptrs.dtype.element_ty), boundary_check=(0, 1))
    # every pid_kd program adds its partial sum into the same da locations;
    # the mask keeps positions past the end of the sequence untouched
    tl.atomic_add(
        da_ptrs,
        da.to(da_ptrs.dtype.element_ty),
        mask=off_c < seq_len - pid_c * CHUNK_SIZE,
        sem="relaxed",
    )


def chunk_sq_sk_bwd_dqka(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    chunk_state: torch.Tensor,
    dsq: torch.Tensor,
    dsk: torch.Tensor,
    d_chunk_state: torch.Tensor,
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch ``chunk_sq_sk_bwd_dqka_kernel`` and return its gradients.

    Args:
        q: Query tensor; must be 4-D (four strides are passed to the kernel).
        k: Key tensor; must be 4-D. Without ``cu_seq_len`` its shape is read
            as ``(batch, seq_len, num_heads, head_dim_qk)``; with
            ``cu_seq_len`` only the last two dims are used.
        v: Value tensor, 4-D; its last dim is the value head dim.
        alpha: Per-position gate input, 3-D.
        beta: Per-position scale, 3-D. Required here (the kernel always
            reads it).
        chunk_state: Per-chunk state, 5-D.
        dsq: Upstream gradient of the first forward output, 4-D.
        dsk: Upstream gradient of the second forward output, 4-D.
        d_chunk_state: Gradient of the per-chunk state, 5-D.
        scale: Scalar applied inside the kernel.
        cu_seq_len: Optional cumulative sequence lengths; enables the
            variable-length path when given.
        cu_chunk_len: Cumulative chunk counts; required together with
            ``cu_seq_len``.
        chunk_size: Chunk length; rounded up to a power of two for the kernel.

    Returns:
        ``(dq, dk, dk2, da)``. ``dq`` and ``dk2`` are the kernel's ``dq1`` and
        ``dq2`` buffers (``dq2`` belongs to the rows the kernel reads from
        ``k`` in its shifted "query" role), ``dk`` is the key gradient, and
        ``da`` is the float32 gradient of ``alpha``.
    """
    if cu_seq_len is None:
        batch_size, seq_len, num_heads, head_dim_qk, head_dim_v = *k.shape, v.shape[-1]
        num_chunks = math.ceil(seq_len / chunk_size)
    else:
        _, _, num_heads, head_dim_qk, head_dim_v = *k.shape, v.shape[-1]
        batch_size = cu_seq_len.shape[0] - 1
        # Convert to Python ints: these are used as a launch-grid dimension and
        # as scalar kernel arguments, not as device tensors.
        seq_len = int(cu_seq_len[-1].item())
        num_chunks = int((cu_chunk_len[1:] - cu_chunk_len[:-1]).max().item())

    dq1 = torch.zeros_like(q)
    dq2 = torch.zeros_like(q)
    # Must be zero-initialised: the kernel exits early for a chunk whose first
    # shifted row is out of range, leaving that chunk's dk rows unwritten
    # (e.g. the last row when a sequence length is 1 modulo the chunk size).
    dk = torch.zeros_like(k)
    # accumulated with atomic adds in the kernel, hence zeros and float32
    da = torch.zeros_like(alpha, dtype=torch.float32)

    # launch kernel
    def grid(meta):
        # (batch * heads, chunks, key-dim blocks)
        return (
            batch_size * num_heads,
            num_chunks,
            triton.cdiv(head_dim_qk, meta["BLOCK_SIZE_KD"]),
        )

    chunk_sq_sk_bwd_dqka_kernel[grid](
        q,
        k,
        v,
        alpha,
        beta,
        chunk_state,
        dsq,
        dsk,
        d_chunk_state,
        dq1,
        dq2,
        dk,
        da,
        scale,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k.stride(3),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        alpha.stride(0),
        alpha.stride(1),
        alpha.stride(2),
        beta.stride(0),
        beta.stride(1),
        beta.stride(2),
        chunk_state.stride(0),
        chunk_state.stride(1),
        chunk_state.stride(2),
        chunk_state.stride(3),
        chunk_state.stride(4),
        dsq.stride(0),
        dsq.stride(1),
        dsq.stride(2),
        dsq.stride(3),
        dsk.stride(0),
        dsk.stride(1),
        dsk.stride(2),
        dsk.stride(3),
        d_chunk_state.stride(0),
        d_chunk_state.stride(1),
        d_chunk_state.stride(2),
        d_chunk_state.stride(3),
        d_chunk_state.stride(4),
        dq1.stride(0),
        dq1.stride(1),
        dq1.stride(2),
        dq1.stride(3),
        dq2.stride(0),
        dq2.stride(1),
        dq2.stride(2),
        dq2.stride(3),
        dk.stride(0),
        dk.stride(1),
        dk.stride(2),
        dk.stride(3),
        da.stride(0),
        da.stride(1),
        da.stride(2),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    # dq1 is the query gradient; dq2 is returned separately as the gradient of
    # the shifted key rows
    return dq1, dk, dq2, da


@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": 128, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 32, "BLOCK_SIZE_VD": 32}, num_warps=2, num_stages=3
        ),
    ],
    reset_to_zero=["db_ptr"],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_sq_sk_bwd_dvb_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    alpha_ptr,
    beta_ptr,
    dsq_ptr,
    dsk_ptr,
    d_chunk_state_ptr,
    dv_ptr,
    db_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_b_b,
    stride_b_n,
    stride_b_h,
    stride_dsq_b,
    stride_dsq_n,
    stride_dsq_h,
    stride_dsq_d,
    stride_dsk_b,
    stride_dsk_n,
    stride_dsk_h,
    stride_dsk_d,
    stride_dcs_b,
    stride_dcs_n,
    stride_dcs_h,
    stride_dcs_kd,
    stride_dcs_vd,
    stride_dv_b,
    stride_dv_n,
    stride_dv_h,
    stride_dv_d,
    stride_db_b,
    stride_db_n,
    stride_db_h,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    # option
    IS_VARLEN: tl.constexpr,
):
    pid_bh, pid_c, pid_vd = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        q_ptr = q_ptr + seq_start * stride_q_n + pid_h * stride_q_h
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        v_ptr = v_ptr + seq_start * stride_v_n + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        beta_ptr = beta_ptr + seq_start * stride_b_n + pid_h * stride_b_h
        dsq_ptr = dsq_ptr + seq_start * stride_dsq_n + pid_h * stride_dsq_h
        dsk_ptr = dsk_ptr + seq_start * stride_dsk_n + pid_h * stride_dsk_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + chunk_start * stride_dcs_n + pid_h * stride_dcs_h
        )
        dv_ptr = dv_ptr + seq_start * stride_dv_n + pid_h * stride_dv_h
        db_ptr = db_ptr + seq_start * stride_db_n + pid_h * stride_db_h
    else:
        q_ptr = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h
        dsq_ptr = dsq_ptr + pid_b * stride_dsq_b + pid_h * stride_dsq_h
        dsk_ptr = dsk_ptr + pid_b * stride_dsk_b + pid_h * stride_dsk_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + pid_b * stride_dcs_b + pid_h * stride_dcs_h
        )
        dv_ptr = dv_ptr + pid_b * stride_dv_b + pid_h * stride_dv_h
        db_ptr = db_ptr + pid_b * stride_db_b + pid_h * stride_db_h
    if pid_c * CHUNK_SIZE + 1 >= seq_len:
        return
    # ptrs
    q1_ptrs = tl.make_block_ptr(
        base=q_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_q_d, stride_q_n),
        offsets=(0, pid_c * CHUNK_SIZE + 1),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    q2_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_k_d, stride_k_n),
        offsets=(0, pid_c * CHUNK_SIZE + 1),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_k_n, stride_k_d),
        offsets=(pid_c * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_v_n, stride_v_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dsq_ptrs = tl.make_block_ptr(
        base=dsq_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dsq_n, stride_dsq_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dsk_ptrs = tl.make_block_ptr(
        base=dsk_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dsk_n, stride_dsk_d),
        offsets=(pid_c * CHUNK_SIZE + 1, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    d_chunk_state_ptrs = tl.make_block_ptr(
        base=d_chunk_state_ptr + pid_c * stride_dcs_n,
        shape=(head_dim_qk, head_dim_v),
        strides=(stride_dcs_kd, stride_dcs_vd),
        offsets=(0, pid_vd * BLOCK_SIZE_VD),
        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    beta_ptrs = tl.make_block_ptr(
        base=beta_ptr,
        shape=(seq_len,),
        strides=(stride_b_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    # init buffer
    dbv = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_VD], dtype=tl.float32)
    kq1 = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    kq2 = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    # loop over k head dim, compute k @ q.T and inter block part of dv
    for _ in range(0, head_dim_qk, BLOCK_SIZE_KD):
        # load q, k, shape: [BLOCK_SIZE_KD, CHUNK_SIZE]
        q1 = tl.load(q1_ptrs, boundary_check=(0, 1), padding_option="zero")
        q2 = tl.load(q2_ptrs, boundary_check=(0, 1), padding_option="zero")
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        # load d_chunk_state, shape: [BLOCK_SIZE_VD, BLOCK_SIZE_KD]
        d_chunk_state = tl.load(
            d_chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(k.dtype)
        # compute k @ q.T
        kq1 += tl.dot(k, q1) * scale
        kq2 += tl.dot(k, q2) * scale
        # compute inter block part of dv
        dbv += tl.dot(k, d_chunk_state)
        # update ptrs
        q1_ptrs = tl.advance(q1_ptrs, (BLOCK_SIZE_KD, 0))
        q2_ptrs = tl.advance(q2_ptrs, (BLOCK_SIZE_KD, 0))
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_KD))
        d_chunk_state_ptrs = tl.advance(d_chunk_state_ptrs, (BLOCK_SIZE_KD, 0))

    # load alpha, shape: [CHUNK_SIZE,]
    alpha = tl.load(alpha_ptrs, boundary_check=(0,), padding_option="zero")
    sum_alpha = tl.sum(alpha, axis=0)
    cum_alpha = tl.cumsum(alpha, axis=0)

    # gate for inter block part of dv
    dbv *= tl.exp(sum_alpha - cum_alpha)[:, None]

    # gate for kq
    off_c = tl.arange(0, CHUNK_SIZE)
    kq1 = tl.where(
        off_c[:, None] <= off_c[None, :],
        kq1 * exp_negative_only(cum_alpha[None, :] - cum_alpha[:, None]),
        0,
    )
    kq2 = tl.where(
        off_c[:, None] <= off_c[None, :],
        kq2 * exp_negative_only(cum_alpha[None, :] - cum_alpha[:, None]),
        0,
    )

    # intra block part of dv, [CHUNK_SIZE, CHUNK_SIZE] @ [CHUNK_SIZE, BLOCK_SIZE_VD] -> [CHUNK_SIZE, BLOCK_SIZE_VD]
    # load do, shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
    do1 = tl.load(dsq_ptrs, boundary_check=(0, 1), padding_option="zero")
    do2 = tl.load(dsk_ptrs, boundary_check=(0, 1), padding_option="zero")
    dbv += tl.dot(kq1.to(do1.dtype), do1)
    dbv += tl.dot(kq2.to(do2.dtype), do2)

    # get actual dv and db
    # load v, shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load beta, shape: [CHUNK_SIZE,]
    beta = tl.load(beta_ptrs, boundary_check=(0,), padding_option="zero")
    dv = dbv * beta[:, None]
    db = tl.sum(dbv * v, axis=1)

    # store dv db
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dv_n, stride_dv_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    db_ptrs = db_ptr + (pid_c * CHUNK_SIZE + off_c) * stride_db_n
    tl.store(dv_ptrs, dv.to(dv_ptrs.dtype.element_ty), boundary_check=(0, 1))
    tl.atomic_add(
        db_ptrs,
        db.to(db_ptrs.dtype.element_ty),
        mask=pid_c * CHUNK_SIZE + off_c < seq_len,
        sem="relaxed",
    )


def chunk_sq_sk_bwd_dvb(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    dsq: torch.Tensor,
    dsk: torch.Tensor,
    d_chunk_state: torch.Tensor,
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
):
    """Launch ``chunk_sq_sk_bwd_dvb_kernel`` to compute the gradients of v and beta.

    Args:
        q: 4-D tensor; all four strides are forwarded to the kernel.
        k: 4-D tensor. Without ``cu_seq_len`` its shape is unpacked as
            (batch_size, seq_len, num_heads, head_dim_qk); with
            ``cu_seq_len`` only the last two dims are read here.
        v: 4-D tensor whose last dim is taken as ``head_dim_v``.
        alpha: 3-D tensor (three strides are forwarded).
        beta: 3-D tensor (three strides are forwarded); ``db`` mirrors it.
        dsq: 4-D upstream gradient tensor (four strides are forwarded).
        dsk: 4-D upstream gradient tensor (four strides are forwarded).
        d_chunk_state: 5-D tensor (five strides are forwarded).
        scale: scalar multiplier applied to the k @ q products in the kernel.
        cu_seq_len: optional cumulative sequence lengths; when given,
            ``batch_size`` is ``cu_seq_len.shape[0] - 1``.
        cu_chunk_len: optional cumulative chunk counts; required whenever
            ``cu_seq_len`` is given, since it determines the chunk grid dim.
        chunk_size: chunk length; rounded up to the next power of two before
            being handed to the kernel as ``CHUNK_SIZE``.

    Returns:
        ``(dv, db)`` where ``dv`` has the shape/dtype of ``v`` and ``db`` has
        the shape/dtype of ``beta``.
    """
    if cu_seq_len is None:
        batch_size, seq_len, num_heads, head_dim_qk, head_dim_v = *k.shape, v.shape[-1]
        num_chunks = math.ceil(seq_len / chunk_size)
    else:
        _, _, num_heads, head_dim_qk, head_dim_v = *k.shape, v.shape[-1]
        batch_size = cu_seq_len.shape[0] - 1
        # NOTE(review): seq_len and num_chunks are 0-d tensors on this path,
        # not Python ints; they are forwarded unchanged, exactly as before.
        seq_len = cu_seq_len[-1]
        num_chunks = (cu_chunk_len[1:] - cu_chunk_len[:-1]).max()

    # dv must start zeroed: the kernel returns early (before its dv store) for
    # any chunk with `pid_c * CHUNK_SIZE + 1 >= seq_len`, so when the last
    # chunk of a sequence holds a single token that row is never written.
    # With an uninitialised buffer it would leak garbage into the gradient.
    dv = torch.zeros_like(v)
    # db must start zeroed as well: the kernel accumulates into it with
    # atomic adds from every program along the head_dim_v grid axis.
    db = torch.zeros_like(beta)

    # launch kernel
    def grid(meta):
        # axis 0: (batch, head) pairs, axis 1: chunks, axis 2: head_dim_v tiles
        return (
            batch_size * num_heads,
            num_chunks,
            triton.cdiv(head_dim_v, meta["BLOCK_SIZE_VD"]),
        )

    chunk_sq_sk_bwd_dvb_kernel[grid](
        q,
        k,
        v,
        alpha,
        beta,
        dsq,
        dsk,
        d_chunk_state,
        dv,
        db,
        scale,
        # shapes
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        # strides
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k.stride(3),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        alpha.stride(0),
        alpha.stride(1),
        alpha.stride(2),
        beta.stride(0),
        beta.stride(1),
        beta.stride(2),
        dsq.stride(0),
        dsq.stride(1),
        dsq.stride(2),
        dsq.stride(3),
        dsk.stride(0),
        dsk.stride(1),
        dsk.stride(2),
        dsk.stride(3),
        d_chunk_state.stride(0),
        d_chunk_state.stride(1),
        d_chunk_state.stride(2),
        d_chunk_state.stride(3),
        d_chunk_state.stride(4),
        dv.stride(0),
        dv.stride(1),
        dv.stride(2),
        dv.stride(3),
        db.stride(0),
        db.stride(1),
        db.stride(2),
        # the kernel's block shapes need a power-of-two chunk length
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return dv, db
