import math
import torch
import triton
import triton.language as tl
from typing import Optional
from ..common.utils import is_nvidia_hopper, use_cuda_graph

# Lists of warp counts and pipeline-stage counts, presumably intended as an
# autotune search space. Neither list is referenced by the autotune configs
# visible in this file — TODO confirm whether they are still used elsewhere.
# NOTE(review): `is_nvidia_hopper` is evaluated as a truth value; this assumes it
# is a boolean flag and not a function (a function object is always truthy).
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
NUM_STAGES = [2, 3, 4]


@triton.jit
def exp_negative_only(x):
    # Element-wise exp(x) for entries with x <= 0; all other entries become
    # exp(-inf). The gate is formed as exp(gamma[:, None] - gamma[None, :]),
    # and exponentiating the raw difference would yield inf in the upper
    # triangular part, so those entries are replaced by -inf first.
    non_positive = tl.where(x <= 0, x, float("-inf"))
    return tl.exp(non_positive)


# Forward kernel: computes one [CHUNK_SIZE, BLOCK_SIZE_VD] output tile for a
# single (value-dim block, chunk, batch*head) program.
#
# Grid layout: axis 0 = BLOCK_SIZE_VD tile of head_dim_v, axis 1 = chunk index,
# axis 2 = flattened batch * num_heads.
#
# With A = cumsum(alpha) inside the chunk, the tile written to `o` is
#   scale * ( exp(A)[:, None] * (q @ chunk_state)
#             + tril((q @ k^T) * gate) @ v )
# where gate[i, j] = exp(A[i] - A[j]) for A[i] - A[j] <= 0 (0 otherwise),
# additionally multiplied by beta[j] when USE_BETA is set.
#
# Arguments:
#   q_ptr, k_ptr, v_ptr, o_ptr : tensors addressed with (b, n, h, d) strides.
#   alpha_ptr, beta_ptr        : tensors addressed with (b, n, h) strides;
#                                beta_ptr may be None (USE_BETA is then False).
#   chunk_state_ptr            : per-chunk state addressed with
#                                (b, chunk, h, kd, vd) strides.
#   scale                      : scalar multiplied into the final output.
#   cu_seq_len, cu_chunk_len   : cumulative sequence / chunk offsets; when
#                                cu_seq_len is not None (IS_VARLEN) the
#                                per-sequence seq_len and num_chunks are read
#                                from them and the batch stride is not used.
#   CHUNK_SIZE, BLOCK_SIZE_*   : compile-time tile sizes.
@triton.heuristics(
    {
        "USE_BETA": lambda args: args["beta_ptr"] is not None,
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": 128, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 32, "BLOCK_SIZE_VD": 32}, num_warps=2, num_stages=3
        ),
    ],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v", "USE_BETA"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_fwd_kernel_o(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    alpha_ptr,
    beta_ptr,
    chunk_state_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_o_b,
    stride_o_n,
    stride_o_h,
    stride_o_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_b_b,
    stride_b_n,
    stride_b_h,
    stride_cs_b,
    stride_cs_n,
    stride_cs_h,
    stride_cs_kd,
    stride_cs_vd,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    # option
    USE_BETA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    pid_vd, pid_c, pid_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        # sequences are packed along the token axis; locate this sequence and
        # its chunks through the cumulative offset tables
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        chunk_end = tl.load(cu_chunk_len + pid_b + 1).to(tl.int32)
        num_chunks = chunk_end - chunk_start
        q_ptr = q_ptr + seq_start * stride_q_n + pid_h * stride_q_h
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        v_ptr = v_ptr + seq_start * stride_v_n + pid_h * stride_v_h
        o_ptr = o_ptr + seq_start * stride_o_n + pid_h * stride_o_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        if USE_BETA:
            # beta_ptr is None when beta is disabled, so it must not take part
            # in pointer arithmetic; it is only dereferenced under USE_BETA
            beta_ptr = beta_ptr + seq_start * stride_b_n + pid_h * stride_b_h
        chunk_state_ptr = (
            chunk_state_ptr + chunk_start * stride_cs_n + pid_h * stride_cs_h
        )
    else:
        q_ptr = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
        o_ptr = o_ptr + pid_b * stride_o_b + pid_h * stride_o_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        if USE_BETA:
            # same guard as in the varlen branch: skip the offset for a None pointer
            beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h
        chunk_state_ptr = chunk_state_ptr + pid_b * stride_cs_b + pid_h * stride_cs_h
    # the grid is sized for the largest chunk count; programs past the end of
    # this sequence have nothing to do
    if pid_c >= num_chunks:
        return
    # ptrs
    q_ptrs = tl.make_block_ptr(
        base=q_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_q_n, stride_q_d),
        offsets=(pid_c * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # k is addressed as its transpose, (head_dim_qk, seq_len), so q @ k needs no tl.trans
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_k_d, stride_k_n),
        offsets=(0, pid_c * CHUNK_SIZE),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    # state entry belonging to this chunk, viewed as (head_dim_qk, head_dim_v)
    chunk_state_ptrs = tl.make_block_ptr(
        base=chunk_state_ptr + pid_c * stride_cs_n,
        shape=(head_dim_qk, head_dim_v),
        strides=(stride_cs_kd, stride_cs_vd),
        offsets=(0, pid_vd * BLOCK_SIZE_VD),
        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    if USE_BETA:
        beta_ptrs = tl.make_block_ptr(
            base=beta_ptr,
            shape=(seq_len,),
            strides=(stride_b_n,),
            offsets=(pid_c * CHUNK_SIZE,),
            block_shape=(CHUNK_SIZE,),
            order=(0,),
        )
    # history output and qk buffer
    acc_o = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_VD], dtype=tl.float32)
    acc_qk = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    # for loop at key head dim, compute inter chunk output and qk
    for _ in range(0, head_dim_qk, BLOCK_SIZE_KD):
        # out-of-range rows/columns are zero-padded so they add nothing to the dots
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        state = tl.load(
            chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(q.dtype)
        # inter chunk output, [CHUNK_SIZE, BLOCK_SIZE_KD] @ [BLOCK_SIZE_KD, BLOCK_SIZE_VD] -> [CHUNK_SIZE, BLOCK_SIZE_VD]
        acc_o += tl.dot(q, state)
        # qk result, [CHUNK_SIZE, BLOCK_SIZE_KD] @ [BLOCK_SIZE_KD, CHUNK_SIZE] -> [CHUNK_SIZE, CHUNK_SIZE]
        acc_qk += tl.dot(q, k)
        # update ptrs
        q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_KD))
        k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_KD, 0))
        chunk_state_ptrs = tl.advance(chunk_state_ptrs, (BLOCK_SIZE_KD, 0))
    # load alpha and beta, compute gate
    alpha = tl.load(alpha_ptrs, boundary_check=(0,), padding_option="zero")
    # from here on `alpha` holds the within-chunk cumulative sum
    alpha = tl.cumsum(alpha, axis=0)
    gate = exp_negative_only(alpha[:, None] - alpha[None, :])
    if USE_BETA:
        beta = tl.load(beta_ptrs, boundary_check=(0,), padding_option="zero")
        # beta is indexed by the key/value position (column of the qk matrix)
        gate = gate * beta[None, :]
    acc_o = acc_o * tl.exp(alpha)[:, None]
    acc_qk = acc_qk * gate
    # causal mask
    off_c = tl.arange(0, CHUNK_SIZE)
    acc_qk = tl.where(off_c[:, None] >= off_c[None, :], acc_qk, 0)
    # intra chunk output, [CHUNK_SIZE, CHUNK_SIZE] @ [CHUNK_SIZE, BLOCK_SIZE_VD] -> [CHUNK_SIZE, BLOCK_SIZE_VD]
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_v_n, stride_v_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    acc_o += tl.dot(acc_qk.to(v.dtype), v)
    acc_o *= scale
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_o_n, stride_o_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    tl.store(o_ptrs, acc_o.to(o_ptrs.dtype.element_ty), boundary_check=(0, 1))


def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: Optional[torch.Tensor],
    chunk_state: torch.Tensor,
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
) -> torch.Tensor:
    """Allocate the output tensor and launch ``chunk_fwd_kernel_o``.

    Args:
        q, k, v: 4-D tensors; ``k`` is unpacked as
            ``(batch, seq_len, num_heads, head_dim_qk)`` and the last dim of
            ``v`` is taken as ``head_dim_v``.
        alpha: 3-D tensor (three strides are forwarded to the kernel).
        beta: optional 3-D tensor; zero strides are forwarded when it is None.
        chunk_state: 5-D tensor (five strides are forwarded to the kernel).
        scale: scalar forwarded to the kernel.
        cu_seq_len: cumulative sequence offsets; ``None`` selects the
            fixed-length path.
        cu_chunk_len: cumulative chunk offsets, read only when ``cu_seq_len``
            is given.
        chunk_size: chunk length along the sequence axis.

    Returns:
        Output tensor with the dtype and device of ``v``; its shape is
        ``(batch, seq_len, num_heads, head_dim_v)`` in the fixed-length case
        and ``(1, cu_seq_len[-1], num_heads, head_dim_v)`` otherwise.
    """
    _, _, num_heads, head_dim_qk = k.shape
    head_dim_v = v.shape[-1]

    if cu_seq_len is not None:
        # packed variable-length layout: sizes come from the offset tables
        batch_size = cu_seq_len.shape[0] - 1
        seq_len = cu_seq_len[-1]
        num_chunks = (cu_chunk_len[1:] - cu_chunk_len[:-1]).max()
        out_batch = 1
    else:
        batch_size, seq_len = k.shape[:2]
        num_chunks = math.ceil(seq_len / chunk_size)
        out_batch = batch_size

    o = torch.empty(
        out_batch, seq_len, num_heads, head_dim_v, device=v.device, dtype=v.dtype
    )

    def _strides(tensor, ndim):
        # first `ndim` strides of `tensor`, in dimension order
        return tuple(tensor.stride(dim) for dim in range(ndim))

    beta_strides = (0, 0, 0) if beta is None else _strides(beta, 3)

    def launch_grid(meta):
        # (value-dim tiles, chunks, batch * heads)
        vd_tiles = triton.cdiv(head_dim_v, meta["BLOCK_SIZE_VD"])
        return (vd_tiles, num_chunks, batch_size * num_heads)

    # NOTE(review): CHUNK_SIZE is rounded up to a power of two while num_chunks
    # is derived from chunk_size itself; the two agree only when chunk_size is
    # already a power of two — confirm callers guarantee this.
    chunk_fwd_kernel_o[launch_grid](
        q,
        k,
        v,
        o,
        alpha,
        beta,
        chunk_state,
        scale,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        *_strides(q, 4),
        *_strides(k, 4),
        *_strides(v, 4),
        *_strides(o, 4),
        *_strides(alpha, 3),
        *beta_strides,
        *_strides(chunk_state, 5),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return o


# Backward kernel producing dq, dk and the gradient with respect to alpha (da)
# for one (batch*head, chunk, key-dim block) program.
#
# Grid layout: axis 0 = flattened batch * num_heads, axis 1 = chunk index,
# axis 2 = BLOCK_SIZE_KD tile of head_dim_qk.
#
# dq and dk tiles are written with plain stores. da is accumulated with
# tl.atomic_add: the da addresses do not depend on pid_kd, so every key-dim
# tile adds its partial sum to the same entries. `reset_to_zero=["da_ptr"]`
# asks the autotuner to zero that buffer when benchmarking configs.
#
# When cu_seq_len is not None (IS_VARLEN) the per-sequence seq_len and
# num_chunks are read from cu_seq_len / cu_chunk_len and the batch strides are
# not used.
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 128}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 32, "BLOCK_SIZE_VD": 32}, num_warps=2, num_stages=3
        ),
    ],
    reset_to_zero=["da_ptr"],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_bwd_dqka_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    alpha_ptr,
    beta_ptr,
    chunk_state_ptr,
    do_ptr,
    d_chunk_state_ptr,
    dq_ptr,
    dk_ptr,
    da_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_b_b,
    stride_b_n,
    stride_b_h,
    stride_cs_b,
    stride_cs_n,
    stride_cs_h,
    stride_cs_kd,
    stride_cs_vd,
    stride_do_b,
    stride_do_n,
    stride_do_h,
    stride_do_d,
    stride_dcs_b,
    stride_dcs_n,
    stride_dcs_h,
    stride_dcs_kd,
    stride_dcs_vd,
    stride_dq_b,
    stride_dq_n,
    stride_dq_h,
    stride_dq_d,
    stride_dk_b,
    stride_dk_n,
    stride_dk_h,
    stride_dk_d,
    stride_da_b,
    stride_da_n,
    stride_da_h,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    # option
    IS_VARLEN: tl.constexpr,
):
    pid_bh, pid_c, pid_kd = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        # sequences are packed along the token axis; locate this sequence and
        # its chunks through the cumulative offset tables
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        chunk_end = tl.load(cu_chunk_len + pid_b + 1).to(tl.int32)
        num_chunks = chunk_end - chunk_start
        q_ptr = q_ptr + seq_start * stride_q_n + pid_h * stride_q_h
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        v_ptr = v_ptr + seq_start * stride_v_n + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        beta_ptr = beta_ptr + seq_start * stride_b_n + pid_h * stride_b_h
        chunk_state_ptr = (
            chunk_state_ptr + chunk_start * stride_cs_n + pid_h * stride_cs_h
        )
        do_ptr = do_ptr + seq_start * stride_do_n + pid_h * stride_do_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + chunk_start * stride_dcs_n + pid_h * stride_dcs_h
        )
        dq_ptr = dq_ptr + seq_start * stride_dq_n + pid_h * stride_dq_h
        dk_ptr = dk_ptr + seq_start * stride_dk_n + pid_h * stride_dk_h
        da_ptr = da_ptr + seq_start * stride_da_n + pid_h * stride_da_h
    else:
        q_ptr = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h
        chunk_state_ptr = chunk_state_ptr + pid_b * stride_cs_b + pid_h * stride_cs_h
        do_ptr = do_ptr + pid_b * stride_do_b + pid_h * stride_do_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + pid_b * stride_dcs_b + pid_h * stride_dcs_h
        )
        dq_ptr = dq_ptr + pid_b * stride_dq_b + pid_h * stride_dq_h
        dk_ptr = dk_ptr + pid_b * stride_dk_b + pid_h * stride_dk_h
        da_ptr = da_ptr + pid_b * stride_da_b + pid_h * stride_da_h
    # the grid is sized for the largest chunk count; programs past the end of
    # this sequence have nothing to do
    if pid_c >= num_chunks:
        return
    # ptrs
    q_ptrs = tl.make_block_ptr(
        base=q_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_q_n, stride_q_d),
        offsets=(pid_c * CHUNK_SIZE, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_k_n, stride_k_d),
        offsets=(pid_c * CHUNK_SIZE, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_v_n, stride_v_d),
        offsets=(pid_c * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    beta_ptrs = tl.make_block_ptr(
        base=beta_ptr,
        shape=(seq_len,),
        strides=(stride_b_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_do_n, stride_do_d),
        offsets=(pid_c * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # both state tensors are addressed as (head_dim_v, head_dim_qk), i.e. with
    # the kd/vd strides swapped, so the loaded tiles are [BLOCK_SIZE_VD, BLOCK_SIZE_KD]
    d_chunk_state_ptrs = tl.make_block_ptr(
        base=d_chunk_state_ptr + pid_c * stride_dcs_n,
        shape=(head_dim_v, head_dim_qk),
        strides=(stride_dcs_vd, stride_dcs_kd),
        offsets=(0, pid_kd * BLOCK_SIZE_KD),
        block_shape=(BLOCK_SIZE_VD, BLOCK_SIZE_KD),
        order=(0, 1),
    )
    chunk_state_ptrs = tl.make_block_ptr(
        base=chunk_state_ptr + pid_c * stride_cs_n,
        shape=(head_dim_v, head_dim_qk),
        strides=(stride_cs_vd, stride_cs_kd),
        offsets=(0, pid_kd * BLOCK_SIZE_KD),
        block_shape=(BLOCK_SIZE_VD, BLOCK_SIZE_KD),
        order=(0, 1),
    )
    # init dq, dk
    dq = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_KD], dtype=tl.float32)
    dk = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_KD], dtype=tl.float32)
    # dov accumulates scale * do @ (beta * v).T over the value dimension
    dov = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    da = tl.zeros(
        [
            CHUNK_SIZE,
        ],
        dtype=tl.float32,
    )
    # single-element accumulator, later added to every entry of da
    d_sum_alpha = tl.zeros(
        [
            1,
        ],
        dtype=tl.float32,
    )

    # load beta, shape: [CHUNK_SIZE,]
    beta = tl.load(beta_ptrs, boundary_check=(0,), padding_option="zero")

    # compute inter block gradient for qk
    for _ in range(0, head_dim_v, BLOCK_SIZE_VD):
        # load v, do, shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
        v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
        # scale each value row by its beta, then cast back to the input dtype
        v = (v * beta[:, None]).to(v.dtype)
        do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
        # chunk_state and d_chunk_state, [BLOCK_SIZE_VD, BLOCK_SIZE_KD]
        d_chunk_state = tl.load(
            d_chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(v.dtype)
        chunk_state = tl.load(
            chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(do.dtype)
        # compute do @ v.T, [CHUNK_SIZE, BLOCK_SIZE_VD] @ [BLOCK_SIZE_VD, CHUNK_SIZE] -> [CHUNK_SIZE, CHUNK_SIZE]
        dov += tl.dot(do, tl.trans(v)) * scale
        # compute inter block dq = do @ state.T, [CHUNK_SIZE, BLOCK_SIZE_VD] @ [BLOCK_SIZE_VD, BLOCK_SIZE_KD] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
        dq += tl.dot(do, chunk_state) * scale
        # compute inter block dk = v @ d_chunk_state.T, [CHUNK_SIZE, BLOCK_SIZE_VD] @ [BLOCK_SIZE_VD, BLOCK_SIZE_KD] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
        dk += tl.dot(v, d_chunk_state)
        # compute state * d_chunk_state
        d_sum_alpha += tl.sum(d_chunk_state * chunk_state)
        # update ptrs
        v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_VD))
        do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_VD))
        d_chunk_state_ptrs = tl.advance(d_chunk_state_ptrs, (BLOCK_SIZE_VD, 0))
        chunk_state_ptrs = tl.advance(chunk_state_ptrs, (BLOCK_SIZE_VD, 0))

    # load alpha, shape: [CHUNK_SIZE,]
    alpha = tl.load(alpha_ptrs, boundary_check=(0,), padding_option="zero")
    # total and running sum of alpha inside this chunk
    sum_alpha = tl.sum(alpha)
    cum_alpha = tl.cumsum(alpha, axis=0)
    # load q, k, shape: [CHUNK_SIZE, BLOCK_SIZE_KD]
    q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
    k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")

    # gate for inter chunk dq and dk, [CHUNK_SIZE, BLOCK_SIZE_KD] * [CHUNK_SIZE, 1] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
    dq = dq * tl.exp(cum_alpha)[:, None]
    dk = dk * tl.exp(sum_alpha - cum_alpha)[:, None]

    # gate for dov
    # keep the lower triangle (diagonal included) and apply the decay gate
    off_c = tl.arange(0, CHUNK_SIZE)
    dov = tl.where(
        off_c[:, None] >= off_c[None, :],
        dov * exp_negative_only(cum_alpha[:, None] - cum_alpha[None, :]),
        0,
    )

    # compute gradient of alpha
    # at this point dq and dk hold only the gated inter-chunk contributions
    d_sum_alpha *= tl.exp(sum_alpha)
    d_sum_alpha += tl.sum(k * dk)
    da += tl.sum(dq * q, axis=1)
    da -= tl.sum(k * dk, axis=1)
    dovqk = dov * tl.dot(q, tl.trans(k))
    da += tl.sum(dovqk, axis=1)
    da -= tl.sum(dovqk, axis=0)
    # reverse (suffix) cumulative sum, plus the term shared by all positions
    da = tl.cumsum(da, axis=0, reverse=True) + d_sum_alpha

    # intra chunk dq and dk, [CHUNK_SIZE, CHUNK_SIZE] @ [CHUNK_SIZE, BLOCK_SIZE_KD] -> [CHUNK_SIZE, BLOCK_SIZE_KD]
    dov = dov.to(k.dtype)
    dq += tl.dot(dov, k)
    dk += tl.dot(tl.trans(dov), q)

    # store dq, dk, da
    dq_ptrs = tl.make_block_ptr(
        base=dq_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_dq_n, stride_dq_d),
        offsets=(pid_c * CHUNK_SIZE, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    dk_ptrs = tl.make_block_ptr(
        base=dk_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_dk_n, stride_dk_d),
        offsets=(pid_c * CHUNK_SIZE, pid_kd * BLOCK_SIZE_KD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    # da uses plain pointer arithmetic (not a block pointer) for the atomic add
    da_ptrs = da_ptr + (pid_c * CHUNK_SIZE + off_c) * stride_da_n
    tl.store(dq_ptrs, dq.to(dq_ptrs.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dk_ptrs, dk.to(dk_ptrs.dtype.element_ty), boundary_check=(0, 1))
    # mask out positions past the end of the sequence in the last chunk
    tl.atomic_add(
        da_ptrs,
        da.to(da_ptrs.dtype.element_ty),
        mask=off_c < seq_len - pid_c * CHUNK_SIZE,
        sem="relaxed",
    )


def chunk_bwd_dqka(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    chunk_state: torch.Tensor,
    do: torch.Tensor,
    d_chunk_state: torch.Tensor,
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
):
    """Allocate dq, dk, da and launch ``chunk_bwd_dqka_kernel``.

    Args:
        q, k, v, do: 4-D tensors; ``k`` is unpacked as
            ``(batch, seq_len, num_heads, head_dim_qk)`` and the last dim of
            ``v`` is taken as ``head_dim_v``.
        alpha, beta: 3-D tensors (three strides each are forwarded).
        chunk_state, d_chunk_state: 5-D tensors (five strides each are
            forwarded).
        scale: scalar forwarded to the kernel.
        cu_seq_len: cumulative sequence offsets; ``None`` selects the
            fixed-length path.
        cu_chunk_len: cumulative chunk offsets, read only when ``cu_seq_len``
            is given.
        chunk_size: chunk length along the sequence axis.

    Returns:
        Tuple ``(dq, dk, da)``: ``dq`` and ``dk`` are shaped like ``q`` and
        ``k``; ``da`` is shaped like ``alpha`` but always float32.
    """
    _, _, num_heads, head_dim_qk = k.shape
    head_dim_v = v.shape[-1]

    if cu_seq_len is not None:
        # packed variable-length layout: sizes come from the offset tables
        batch_size = cu_seq_len.shape[0] - 1
        seq_len = cu_seq_len[-1]
        num_chunks = (cu_chunk_len[1:] - cu_chunk_len[:-1]).max()
    else:
        batch_size, seq_len = k.shape[:2]
        num_chunks = math.ceil(seq_len / chunk_size)

    dq, dk = torch.empty_like(q), torch.empty_like(k)
    # zero-initialised float32 buffer: the kernel accumulates into it atomically
    da = torch.zeros_like(alpha, dtype=torch.float32)

    def _strides(tensor, ndim):
        # first `ndim` strides of `tensor`, in dimension order
        return tuple(tensor.stride(dim) for dim in range(ndim))

    def launch_grid(meta):
        # (batch * heads, chunks, key-dim tiles)
        kd_tiles = triton.cdiv(head_dim_qk, meta["BLOCK_SIZE_KD"])
        return (batch_size * num_heads, num_chunks, kd_tiles)

    # NOTE(review): CHUNK_SIZE is rounded up to a power of two while num_chunks
    # is derived from chunk_size itself; the two agree only when chunk_size is
    # already a power of two — confirm callers guarantee this.
    chunk_bwd_dqka_kernel[launch_grid](
        q,
        k,
        v,
        alpha,
        beta,
        chunk_state,
        do,
        d_chunk_state,
        dq,
        dk,
        da,
        scale,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        *_strides(q, 4),
        *_strides(k, 4),
        *_strides(v, 4),
        *_strides(alpha, 3),
        *_strides(beta, 3),
        *_strides(chunk_state, 5),
        *_strides(do, 4),
        *_strides(d_chunk_state, 5),
        *_strides(dq, 4),
        *_strides(dk, 4),
        *_strides(da, 3),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return dq, dk, da


# Backward kernel producing dv and the gradient with respect to beta (db) for
# one (batch*head, chunk, value-dim block) program.
#
# Grid layout: axis 0 = flattened batch * num_heads, axis 1 = chunk index,
# axis 2 = BLOCK_SIZE_VD tile of head_dim_v.
#
# The kernel first forms dbv, the gradient with respect to (beta * v), then
# writes dv = dbv * beta with a plain store and accumulates
# db = sum over the value dim of (dbv * v) with tl.atomic_add: the db addresses
# do not depend on pid_vd, so every value-dim tile adds its partial sum to the
# same entries. `reset_to_zero=["db_ptr"]` asks the autotuner to zero that
# buffer when benchmarking configs.
#
# When cu_seq_len is not None (IS_VARLEN) the per-sequence seq_len and
# num_chunks are read from cu_seq_len / cu_chunk_len and the batch strides are
# not used.
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": 128, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 64, "BLOCK_SIZE_VD": 64}, num_warps=4, num_stages=3
        ),
        triton.Config(
            {"BLOCK_SIZE_KD": 32, "BLOCK_SIZE_VD": 32}, num_warps=2, num_stages=3
        ),
    ],
    reset_to_zero=["db_ptr"],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_bwd_dvb_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    alpha_ptr,
    beta_ptr,
    do_ptr,
    d_chunk_state_ptr,
    dv_ptr,
    db_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_b_b,
    stride_b_n,
    stride_b_h,
    stride_do_b,
    stride_do_n,
    stride_do_h,
    stride_do_d,
    stride_dcs_b,
    stride_dcs_n,
    stride_dcs_h,
    stride_dcs_kd,
    stride_dcs_vd,
    stride_dv_b,
    stride_dv_n,
    stride_dv_h,
    stride_dv_d,
    stride_db_b,
    stride_db_n,
    stride_db_h,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    # option
    IS_VARLEN: tl.constexpr,
):
    pid_bh, pid_c, pid_vd = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        # sequences are packed along the token axis; locate this sequence and
        # its chunks through the cumulative offset tables
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        chunk_end = tl.load(cu_chunk_len + pid_b + 1).to(tl.int32)
        num_chunks = chunk_end - chunk_start
        q_ptr = q_ptr + seq_start * stride_q_n + pid_h * stride_q_h
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        v_ptr = v_ptr + seq_start * stride_v_n + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        beta_ptr = beta_ptr + seq_start * stride_b_n + pid_h * stride_b_h
        do_ptr = do_ptr + seq_start * stride_do_n + pid_h * stride_do_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + chunk_start * stride_dcs_n + pid_h * stride_dcs_h
        )
        dv_ptr = dv_ptr + seq_start * stride_dv_n + pid_h * stride_dv_h
        db_ptr = db_ptr + seq_start * stride_db_n + pid_h * stride_db_h
    else:
        q_ptr = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h
        do_ptr = do_ptr + pid_b * stride_do_b + pid_h * stride_do_h
        d_chunk_state_ptr = (
            d_chunk_state_ptr + pid_b * stride_dcs_b + pid_h * stride_dcs_h
        )
        dv_ptr = dv_ptr + pid_b * stride_dv_b + pid_h * stride_dv_h
        db_ptr = db_ptr + pid_b * stride_db_b + pid_h * stride_db_h
    # the grid is sized for the largest chunk count; programs past the end of
    # this sequence have nothing to do
    if pid_c >= num_chunks:
        return
    # ptrs
    # q is addressed as its transpose, (head_dim_qk, seq_len), so k @ q needs no tl.trans
    q_ptrs = tl.make_block_ptr(
        base=q_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_q_d, stride_q_n),
        offsets=(0, pid_c * CHUNK_SIZE),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(seq_len, head_dim_qk),
        strides=(stride_k_n, stride_k_d),
        offsets=(pid_c * CHUNK_SIZE, 0),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_KD),
        order=(1, 0),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_v_n, stride_v_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    do_ptrs = tl.make_block_ptr(
        base=do_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_do_n, stride_do_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # state-gradient entry belonging to this chunk, viewed as (head_dim_qk, head_dim_v)
    d_chunk_state_ptrs = tl.make_block_ptr(
        base=d_chunk_state_ptr + pid_c * stride_dcs_n,
        shape=(head_dim_qk, head_dim_v),
        strides=(stride_dcs_kd, stride_dcs_vd),
        offsets=(0, pid_vd * BLOCK_SIZE_VD),
        block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    beta_ptrs = tl.make_block_ptr(
        base=beta_ptr,
        shape=(seq_len,),
        strides=(stride_b_n,),
        offsets=(pid_c * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    # init buffer
    # dbv: gradient w.r.t. (beta * v); kq: scale * k @ q.T for this chunk
    dbv = tl.zeros([CHUNK_SIZE, BLOCK_SIZE_VD], dtype=tl.float32)
    kq = tl.zeros([CHUNK_SIZE, CHUNK_SIZE], dtype=tl.float32)
    # loop over k head dim, compute k @ q.T and inter block part of dv
    for _ in range(0, head_dim_qk, BLOCK_SIZE_KD):
        # load q (transposed view), shape: [BLOCK_SIZE_KD, CHUNK_SIZE]
        # load k, shape: [CHUNK_SIZE, BLOCK_SIZE_KD]
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        # load d_chunk_state, shape: [BLOCK_SIZE_KD, BLOCK_SIZE_VD]
        d_chunk_state = tl.load(
            d_chunk_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(k.dtype)
        # compute k @ q.T (q is already loaded transposed)
        kq += tl.dot(k, q) * scale
        # compute inter block part of dv
        dbv += tl.dot(k, d_chunk_state)
        # update ptrs
        q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_KD, 0))
        k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_KD))
        d_chunk_state_ptrs = tl.advance(d_chunk_state_ptrs, (BLOCK_SIZE_KD, 0))

    # load alpha, shape: [CHUNK_SIZE,]
    alpha = tl.load(alpha_ptrs, boundary_check=(0,), padding_option="zero")
    # total and running sum of alpha inside this chunk
    sum_alpha = tl.sum(alpha, axis=0)
    cum_alpha = tl.cumsum(alpha, axis=0)

    # gate for inter block part of dv
    dbv *= tl.exp(sum_alpha - cum_alpha)[:, None]

    # gate for kq
    # keep the upper triangle (diagonal included): row = key position,
    # column = query position, and apply the decay gate
    off_c = tl.arange(0, CHUNK_SIZE)
    kq = tl.where(
        off_c[:, None] <= off_c[None, :],
        kq * exp_negative_only(cum_alpha[None, :] - cum_alpha[:, None]),
        0,
    )

    # intra block part of dv, [CHUNK_SIZE, CHUNK_SIZE] @ [CHUNK_SIZE, BLOCK_SIZE_VD] -> [CHUNK_SIZE, BLOCK_SIZE_VD]
    # load do, shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    dbv += tl.dot(kq.to(do.dtype), do)

    # get actual dv and db
    # load v, shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    # load beta, shape: [CHUNK_SIZE,]
    beta = tl.load(beta_ptrs, boundary_check=(0,), padding_option="zero")
    # product rule on (beta * v): dv = dbv * beta, db = sum_d(dbv * v)
    dv = dbv * beta[:, None]
    db = tl.sum(dbv * v, axis=1)

    # store dv db
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dv_n, stride_dv_d),
        offsets=(pid_c * CHUNK_SIZE, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    # db uses plain pointer arithmetic (not a block pointer) for the atomic add
    db_ptrs = db_ptr + (pid_c * CHUNK_SIZE + off_c) * stride_db_n
    tl.store(dv_ptrs, dv.to(dv_ptrs.dtype.element_ty), boundary_check=(0, 1))
    # mask out positions past the end of the sequence in the last chunk
    tl.atomic_add(
        db_ptrs,
        db.to(db_ptrs.dtype.element_ty),
        mask=pid_c * CHUNK_SIZE + off_c < seq_len,
        sem="relaxed",
    )


def chunk_bwd_dvb(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
    do: torch.Tensor,
    d_chunk_state: torch.Tensor,
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
):
    """Allocate dv, db and launch ``chunk_bwd_dvb_kernel``.

    Args:
        q, k, v, do: 4-D tensors; ``k`` is unpacked as
            ``(batch, seq_len, num_heads, head_dim_qk)`` and the last dim of
            ``v`` is taken as ``head_dim_v``.
        alpha, beta: 3-D tensors (three strides each are forwarded).
        d_chunk_state: 5-D tensor (five strides are forwarded).
        scale: scalar forwarded to the kernel.
        cu_seq_len: cumulative sequence offsets; ``None`` selects the
            fixed-length path.
        cu_chunk_len: cumulative chunk offsets, read only when ``cu_seq_len``
            is given.
        chunk_size: chunk length along the sequence axis.

    Returns:
        Tuple ``(dv, db)`` with the shape and dtype of ``v`` and ``beta``.
    """
    _, _, num_heads, head_dim_qk = k.shape
    head_dim_v = v.shape[-1]

    if cu_seq_len is not None:
        # packed variable-length layout: sizes come from the offset tables
        batch_size = cu_seq_len.shape[0] - 1
        seq_len = cu_seq_len[-1]
        num_chunks = (cu_chunk_len[1:] - cu_chunk_len[:-1]).max()
    else:
        batch_size, seq_len = k.shape[:2]
        num_chunks = math.ceil(seq_len / chunk_size)

    dv = torch.empty_like(v)
    # zero-initialised: the kernel accumulates into it atomically.
    # NOTE(review): db keeps beta's dtype, whereas chunk_bwd_dqka accumulates
    # da in float32 — confirm low-precision atomic accumulation is intended.
    db = torch.zeros_like(beta)

    def _strides(tensor, ndim):
        # first `ndim` strides of `tensor`, in dimension order
        return tuple(tensor.stride(dim) for dim in range(ndim))

    def launch_grid(meta):
        # (batch * heads, chunks, value-dim tiles)
        vd_tiles = triton.cdiv(head_dim_v, meta["BLOCK_SIZE_VD"])
        return (batch_size * num_heads, num_chunks, vd_tiles)

    # NOTE(review): CHUNK_SIZE is rounded up to a power of two while num_chunks
    # is derived from chunk_size itself; the two agree only when chunk_size is
    # already a power of two — confirm callers guarantee this.
    chunk_bwd_dvb_kernel[launch_grid](
        q,
        k,
        v,
        alpha,
        beta,
        do,
        d_chunk_state,
        dv,
        db,
        scale,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        *_strides(q, 4),
        *_strides(k, 4),
        *_strides(v, 4),
        *_strides(alpha, 3),
        *_strides(beta, 3),
        *_strides(do, 4),
        *_strides(d_chunk_state, 5),
        *_strides(dv, 4),
        *_strides(db, 3),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return dv, db
