import math
import torch
import triton
import triton.language as tl
from typing import Optional
from ..common.utils import is_nvidia_hopper, use_cuda_graph

# Candidate `num_warps` values swept by the autotuners below; a narrower set is
# used when the Hopper flag is truthy.
# NOTE(review): `is_nvidia_hopper` is used as a plain truthy value here, not
# called. If it is a function in ..common.utils this always selects [2, 4] —
# confirm it is a bool.
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
# Candidate `num_stages` values swept by the autotuners below.
NUM_STAGES = [2, 3, 4]


@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["initial_state_ptr"] is not None,
        "STORE_FINAL_STATE": lambda args: args["last_state_ptr"] is not None,
        "USE_BETA": lambda args: args["beta_ptr"] is not None,
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": BK, "BLOCK_SIZE_VD": BV}, num_warps=nw, num_stages=ns
        )
        for BK in [32, 64]
        for BV in [32, 64]
        for nw in NUM_WARPS
        for ns in NUM_STAGES
    ],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_fwd_kernel_h(
    k_ptr,  # shape: 1, total_seqlen, num_heads, head_dim_qk
    v_ptr,  # shape: 1, total_seqlen, num_heads, head_dim_v
    alpha_ptr,  # shape: 1, total_seqlen, num_heads
    beta_ptr,  # shape: 1, total_seqlen, num_heads (may be None)
    initial_state_ptr,  # shape: batch_size, num_heads, head_dim_qk, head_dim_v
    chunk_state_ptr,  # shape: total_num_chunks, num_heads, head_dim_qk, head_dim_v
    last_state_ptr,  # shape: batch_size, num_heads, head_dim_qk, head_dim_v
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_b_b,
    stride_b_n,
    stride_b_h,
    stride_is_b,
    stride_is_h,
    stride_is_kd,
    stride_is_vd,
    stride_cs_b,
    stride_cs_n,
    stride_cs_h,
    stride_cs_kd,
    stride_cs_vd,
    stride_ls_b,
    stride_ls_h,
    stride_ls_kd,
    stride_ls_vd,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    # optional
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_BETA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Chunk-wise forward recurrence of the hidden state.
    #
    # Each program owns one [BLOCK_SIZE_KD, BLOCK_SIZE_VD] tile of the state for
    # one (batch, head) pair and walks the sequence chunk by chunk:
    #   1. the state *entering* the chunk is written to `chunk_state`,
    #   2. the state is decayed by exp(sum(alpha)) over the chunk and updated
    #      with k @ (v * exp(sum(alpha) - cumsum(alpha)) [* beta]).
    # After the last chunk the state is optionally written to `last_state`.
    #
    # Grid axes: 0 -> KD tile, 1 -> VD tile, 2 -> flattened (batch, head).
    pid_kd, pid_vd, pid_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        # Sequences are packed along the token axis; per-sequence token and
        # chunk ranges come from the cumulative-length tables.
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        chunk_end = tl.load(cu_chunk_len + pid_b + 1).to(tl.int32)
        num_chunks = chunk_end - chunk_start
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        v_ptr = v_ptr + seq_start * stride_v_n + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        # beta is optional: only do pointer arithmetic when it was provided,
        # mirroring how the other optional pointers are handled below.
        if USE_BETA:
            beta_ptr = beta_ptr + seq_start * stride_b_n + pid_h * stride_b_h
        chunk_state_ptr = (
            chunk_state_ptr + chunk_start * stride_cs_n + pid_h * stride_cs_h
        )
    else:
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        if USE_BETA:
            beta_ptr = beta_ptr + pid_b * stride_b_b + pid_h * stride_b_h
        chunk_state_ptr = chunk_state_ptr + pid_b * stride_cs_b + pid_h * stride_cs_h
    if USE_INITIAL_STATE:
        initial_state_ptr = (
            initial_state_ptr + pid_b * stride_is_b + pid_h * stride_is_h
        )
    if STORE_FINAL_STATE:
        last_state_ptr = last_state_ptr + pid_b * stride_ls_b + pid_h * stride_ls_h
    # ptrs
    # k is addressed transposed, as [head_dim_qk, seq_len], so that
    # tl.dot(k, v) directly yields a [KD, VD] tile.
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_k_d, stride_k_n),
        offsets=(pid_kd * BLOCK_SIZE_KD, 0),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_v_n, stride_v_d),
        offsets=(0, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(0,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    if USE_BETA:
        beta_ptrs = tl.make_block_ptr(
            base=beta_ptr,
            shape=(seq_len,),
            strides=(stride_b_n,),
            offsets=(0,),
            block_shape=(CHUNK_SIZE,),
            order=(0,),
        )
    chunk_state_ptrs = tl.make_block_ptr(
        base=chunk_state_ptr,
        shape=(num_chunks, head_dim_qk, head_dim_v),
        strides=(stride_cs_n, stride_cs_kd, stride_cs_vd),
        offsets=(0, pid_kd * BLOCK_SIZE_KD, pid_vd * BLOCK_SIZE_VD),
        block_shape=(1, BLOCK_SIZE_KD, BLOCK_SIZE_VD),
        order=(2, 1, 0),
    )
    # cumulative hidden state (k.T @ v), shape: [BLOCK_SIZE_KD, BLOCK_SIZE_VD]
    # The running state is always accumulated in fp32.
    if USE_INITIAL_STATE:
        initial_state_ptrs = tl.make_block_ptr(
            base=initial_state_ptr,
            shape=(head_dim_qk, head_dim_v),
            strides=(stride_is_kd, stride_is_vd),
            offsets=(pid_kd * BLOCK_SIZE_KD, pid_vd * BLOCK_SIZE_VD),
            block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
            order=(1, 0),
        )
        state = tl.load(
            initial_state_ptrs, boundary_check=(0, 1), padding_option="zero"
        ).to(tl.float32)
    else:
        state = tl.zeros([BLOCK_SIZE_KD, BLOCK_SIZE_VD], dtype=tl.float32)
    # for loop over chunks
    for _ in range(num_chunks):
        # save the state accumulated *before* this chunk
        tl.store(
            chunk_state_ptrs,
            state.to(chunk_state_ptrs.dtype.element_ty)[None, :, :],
            boundary_check=(0, 1, 2),
        )
        # load current chunk's k and v and gate; out-of-range positions are
        # zero-padded, so they contribute nothing to k @ v and add 0 to the
        # gate sums (i.e. a decay factor of exp(0) = 1).
        k = tl.load(
            k_ptrs, boundary_check=(0, 1), padding_option="zero"
        )  # shape: [BLOCK_SIZE_KD, CHUNK_SIZE]
        v = tl.load(
            v_ptrs, boundary_check=(0, 1), padding_option="zero"
        )  # shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
        alpha = tl.load(
            alpha_ptrs, boundary_check=(0,), padding_option="zero"
        )  # shape: [CHUNK_SIZE,]
        if USE_BETA:
            beta = tl.load(
                beta_ptrs, boundary_check=(0,), padding_option="zero"
            )  # shape: [CHUNK_SIZE,]
        # cumsum gate (alpha is combined additively and only exponentiated
        # afterwards, i.e. it is treated as a log-space gate)
        sum_alpha = tl.sum(alpha, axis=0)  # scalar: total gate over the chunk
        cum_alpha = tl.cumsum(alpha, axis=0)  # shape: [CHUNK_SIZE,]
        # update state: each row of v is decayed from its position to the end
        # of the chunk, then cast back to v's dtype for the dot product
        if USE_BETA:
            v = (v * tl.exp(sum_alpha - cum_alpha)[:, None] * beta[:, None]).to(v.dtype)
        else:
            v = (v * tl.exp(sum_alpha - cum_alpha)[:, None]).to(v.dtype)
        state = state * tl.exp(sum_alpha) + tl.dot(k, v)
        # update ptrs
        k_ptrs = tl.advance(k_ptrs, (0, CHUNK_SIZE))
        v_ptrs = tl.advance(v_ptrs, (CHUNK_SIZE, 0))
        alpha_ptrs = tl.advance(alpha_ptrs, (CHUNK_SIZE,))
        if USE_BETA:
            beta_ptrs = tl.advance(beta_ptrs, (CHUNK_SIZE,))
        chunk_state_ptrs = tl.advance(chunk_state_ptrs, (1, 0, 0))
    # save final kv state
    if STORE_FINAL_STATE:
        last_state_ptrs = tl.make_block_ptr(
            base=last_state_ptr,
            shape=(head_dim_qk, head_dim_v),
            strides=(stride_ls_kd, stride_ls_vd),
            offsets=(pid_kd * BLOCK_SIZE_KD, pid_vd * BLOCK_SIZE_VD),
            block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
            order=(1, 0),
        )
        tl.store(
            last_state_ptrs,
            state.to(last_state_ptrs.dtype.element_ty),
            boundary_check=(0, 1),
        )


def chunk_fwd_h(
    k: torch.Tensor,
    v: torch.Tensor,
    alpha: torch.Tensor,
    beta: Optional[torch.Tensor],
    initial_state: Optional[torch.Tensor],
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
    output_final_state: bool,
    state_in_fp32: bool = True,
):
    """Launch ``chunk_fwd_kernel_h`` and return the per-chunk and final states.

    Args:
        k: Keys, unpacked as ``(batch, seq_len, num_heads, head_dim_qk)``.
        v: Values; the last dimension is taken as ``head_dim_v``.
        alpha: Per-token gate; its first three strides are forwarded.
        beta: Optional per-token scale applied to ``v`` inside the kernel.
        initial_state: Optional starting state per (batch, head).
        cu_seq_len: Cumulative sequence lengths. ``None`` selects the dense
            (fixed-length) path; otherwise the batch size is
            ``cu_seq_len.shape[0] - 1``.
        cu_chunk_len: Cumulative chunk counts; only read when ``cu_seq_len``
            is given.
        chunk_size: Tokens per chunk. The kernel is launched with
            ``CHUNK_SIZE = next_power_of_2(chunk_size)``.
        output_final_state: Whether to allocate and fill the final state.
        state_in_fp32: Store states in float32 instead of ``k.dtype``.

    Returns:
        ``(chunk_state, last_state)``. ``chunk_state`` has shape
        ``(batch, num_chunks, num_heads, head_dim_qk, head_dim_v)`` on the
        dense path and ``(1, total_chunks, ...)`` on the varlen path;
        ``last_state`` is ``None`` unless ``output_final_state`` is set.
    """
    out_dtype = torch.float32 if state_in_fp32 else k.dtype
    dim0, dim1, num_heads, head_dim_qk = k.shape
    head_dim_v = v.shape[-1]

    if cu_seq_len is not None:
        # Packed layout: the kernel re-derives per-sequence lengths from the
        # cumulative tables, so these totals are 0-d tensors.
        batch_size = cu_seq_len.shape[0] - 1
        seq_len = cu_seq_len[-1]
        num_chunks = cu_chunk_len[-1]
        state_batch = 1
    else:
        batch_size, seq_len = dim0, dim1
        num_chunks = math.ceil(seq_len / chunk_size)
        state_batch = batch_size

    chunk_state = torch.empty(
        state_batch,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        device=k.device,
        dtype=out_dtype,
    )
    last_state = None
    if output_final_state:
        last_state = torch.empty(
            batch_size,
            num_heads,
            head_dim_qk,
            head_dim_v,
            device=k.device,
            dtype=out_dtype,
        )

    def strides(tensor, ndim):
        # First `ndim` strides of `tensor`, or zeros when the tensor is absent.
        if tensor is None:
            return (0,) * ndim
        return tuple(tensor.stride(d) for d in range(ndim))

    # One program per (KD tile, VD tile, batch * head).
    launch_grid = lambda meta: (  # noqa: E731
        triton.cdiv(head_dim_qk, meta["BLOCK_SIZE_KD"]),
        triton.cdiv(head_dim_v, meta["BLOCK_SIZE_VD"]),
        num_heads * batch_size,
    )

    chunk_fwd_kernel_h[launch_grid](
        k,
        v,
        alpha,
        beta,
        initial_state,
        chunk_state,
        last_state,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        *strides(k, 4),
        *strides(v, 4),
        *strides(alpha, 3),
        *strides(beta, 3),
        *strides(initial_state, 4),
        *strides(chunk_state, 5),
        *strides(last_state, 4),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return chunk_state, last_state


@triton.heuristics(
    {
        "USE_LAST_STATE_GRADIENT": lambda args: args["dls_ptr"] is not None,
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": BK, "BLOCK_SIZE_VD": BV}, num_warps=nw, num_stages=ns
        )
        for BK in [32, 64]
        for BV in [32, 64]
        for nw in NUM_WARPS
        for ns in NUM_STAGES
    ],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_bwd_kernel_h(
    q_ptr,
    do_ptr,
    alpha_ptr,
    dls_ptr,
    dh_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_do_b,
    stride_do_n,
    stride_do_h,
    stride_do_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_dls_b,
    stride_dls_h,
    stride_dls_kd,
    stride_dls_vd,
    stride_dh_b,
    stride_dh_n,
    stride_dh_h,
    stride_dh_kd,
    stride_dh_vd,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    USE_LAST_STATE_GRADIENT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Reverse-time accumulation of the state gradient dh.
    #
    # Each program owns one [BLOCK_SIZE_KD, BLOCK_SIZE_VD] tile for one
    # (batch, head) pair. Chunks are visited last-to-first; for every chunk the
    # gradient accumulated so far is stored first, then decayed by
    # exp(sum(alpha)) and increased by (q * exp(cumsum(alpha))) @ do * scale.
    #
    # Grid axes: 0 -> flattened (batch, head), 1 -> KD tile, 2 -> VD tile.
    bh_idx = tl.program_id(0)
    kd_idx = tl.program_id(1)
    vd_idx = tl.program_id(2)
    batch_idx = bh_idx // num_heads
    head_idx = bh_idx % num_heads

    # Rebase every pointer onto this program's (sequence, head) slice.
    if IS_VARLEN:
        tok_begin = tl.load(cu_seq_len + batch_idx).to(tl.int32)
        tok_end = tl.load(cu_seq_len + batch_idx + 1).to(tl.int32)
        seq_len = tok_end - tok_begin
        chk_begin = tl.load(cu_chunk_len + batch_idx).to(tl.int32)
        chk_end = tl.load(cu_chunk_len + batch_idx + 1).to(tl.int32)
        num_chunks = chk_end - chk_begin
        q_ptr = q_ptr + tok_begin * stride_q_n + head_idx * stride_q_h
        do_ptr = do_ptr + tok_begin * stride_do_n + head_idx * stride_do_h
        alpha_ptr = alpha_ptr + tok_begin * stride_a_n + head_idx * stride_a_h
        dh_ptr = dh_ptr + chk_begin * stride_dh_n + head_idx * stride_dh_h
    else:
        q_ptr = q_ptr + batch_idx * stride_q_b + head_idx * stride_q_h
        do_ptr = do_ptr + batch_idx * stride_do_b + head_idx * stride_do_h
        alpha_ptr = alpha_ptr + batch_idx * stride_a_b + head_idx * stride_a_h
        dh_ptr = dh_ptr + batch_idx * stride_dh_b + head_idx * stride_dh_h
    if USE_LAST_STATE_GRADIENT:
        dls_ptr = dls_ptr + batch_idx * stride_dls_b + head_idx * stride_dls_h

    # Tile origins, plus the token offset of the final chunk where the reverse
    # sweep begins.
    kd_origin = kd_idx * BLOCK_SIZE_KD
    vd_origin = vd_idx * BLOCK_SIZE_VD
    tail_chunk = num_chunks - 1
    tail_token = tail_chunk * CHUNK_SIZE

    # q is addressed transposed ([head_dim_qk, seq_len]) so tl.dot(q, do)
    # yields a [KD, VD] tile.
    q_blk = tl.make_block_ptr(
        base=q_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_q_d, stride_q_n),
        offsets=(kd_origin, tail_token),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    gate_blk = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=(tail_token,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    do_blk = tl.make_block_ptr(
        base=do_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_do_n, stride_do_d),
        offsets=(tail_token, vd_origin),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dh_blk = tl.make_block_ptr(
        base=dh_ptr,
        shape=(num_chunks, head_dim_qk, head_dim_v),
        strides=(stride_dh_n, stride_dh_kd, stride_dh_vd),
        offsets=(tail_chunk, kd_origin, vd_origin),
        block_shape=(1, BLOCK_SIZE_KD, BLOCK_SIZE_VD),
        order=(2, 1, 0),
    )

    # fp32 accumulator, optionally seeded with the gradient of the final state.
    dh = tl.zeros([BLOCK_SIZE_KD, BLOCK_SIZE_VD], dtype=tl.float32)
    if USE_LAST_STATE_GRADIENT:
        dls_blk = tl.make_block_ptr(
            base=dls_ptr,
            shape=(head_dim_qk, head_dim_v),
            strides=(stride_dls_kd, stride_dls_vd),
            offsets=(kd_origin, vd_origin),
            block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
            order=(1, 0),
        )
        dls = tl.load(dls_blk, boundary_check=(0, 1), padding_option="zero")
        dh += dls

    # Sweep the chunks from last to first.
    for _ in range(num_chunks - 1, -1, -1):
        # Write the gradient as it stands before this chunk is folded in.
        tl.store(
            dh_blk,
            dh.to(dh_blk.dtype.element_ty)[None, :, :],
            boundary_check=(0, 1, 2),
        )
        # q tile: [BLOCK_SIZE_KD, CHUNK_SIZE]
        q = tl.load(q_blk, boundary_check=(0, 1), padding_option="zero")
        # do tile: [CHUNK_SIZE, BLOCK_SIZE_VD]
        do = tl.load(do_blk, boundary_check=(0, 1), padding_option="zero")
        # gate values for the chunk: [CHUNK_SIZE,]
        alpha = tl.load(gate_blk, boundary_check=(0,), padding_option="zero")
        chunk_gate = tl.sum(alpha, axis=0)  # total gate over the chunk
        prefix_gate = tl.cumsum(alpha, axis=0)  # shape: [CHUNK_SIZE,]
        # Scale each q column by its prefix decay, cast back for the dot.
        prefix_decay = tl.exp(prefix_gate)
        q = (q * prefix_decay[None, :]).to(q.dtype)
        dh = dh * tl.exp(chunk_gate) + tl.dot(q, do) * scale
        # Step every block pointer back by one chunk.
        q_blk = tl.advance(q_blk, (0, -CHUNK_SIZE))
        do_blk = tl.advance(do_blk, (-CHUNK_SIZE, 0))
        gate_blk = tl.advance(gate_blk, (-CHUNK_SIZE,))
        dh_blk = tl.advance(dh_blk, (-1, 0, 0))


def chunk_bwd_h(
    q: torch.Tensor,
    do: torch.Tensor,
    alpha: torch.Tensor,
    dls: Optional[torch.Tensor],
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
    state_in_fp32: bool = False,
) -> torch.Tensor:
    """Launch ``chunk_bwd_kernel_h`` and return the per-chunk state gradient.

    Args:
        q: Queries, unpacked as ``(batch, seq_len, num_heads, head_dim_qk)``.
        do: Output gradient; the last dimension is taken as ``head_dim_v``.
        alpha: Per-token gate; its first three strides are forwarded.
        dls: Optional gradient of the final state, used to seed the
            accumulation.
        scale: Scalar multiplied into each chunk's ``q @ do`` contribution.
        cu_seq_len: Cumulative sequence lengths. ``None`` selects the dense
            (fixed-length) path; otherwise the batch size is
            ``cu_seq_len.shape[0] - 1``.
        cu_chunk_len: Cumulative chunk counts; only read when ``cu_seq_len``
            is given.
        chunk_size: Tokens per chunk. The kernel is launched with
            ``CHUNK_SIZE = next_power_of_2(chunk_size)``.
        state_in_fp32: Store ``dh`` in float32 instead of ``q.dtype``.

    Returns:
        ``dh`` of shape ``(batch, num_chunks, num_heads, head_dim_qk,
        head_dim_v)`` on the dense path, or ``(1, total_chunks, ...)`` on the
        varlen path.
    """
    dim0, dim1, num_heads, head_dim_qk = q.shape
    head_dim_v = do.shape[-1]
    dh_dtype = torch.float32 if state_in_fp32 else q.dtype

    if cu_seq_len is None:
        batch_size, seq_len = dim0, dim1
        num_chunks = math.ceil(seq_len / chunk_size)
        dh_batch = batch_size
    else:
        # Packed layout: the kernel re-derives per-sequence lengths from the
        # cumulative tables, so these totals are 0-d tensors.
        batch_size = cu_seq_len.shape[0] - 1
        seq_len = cu_seq_len[-1]
        num_chunks = cu_chunk_len[-1]
        dh_batch = 1

    dh = torch.empty(
        dh_batch,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        device=q.device,
        dtype=dh_dtype,
    )

    def strides(tensor, ndim):
        # First `ndim` strides of `tensor`, or zeros when the tensor is absent.
        if tensor is None:
            return (0,) * ndim
        return tuple(tensor.stride(d) for d in range(ndim))

    def grid(meta):
        # One program per (batch * head, KD tile, VD tile).
        return (
            batch_size * num_heads,
            triton.cdiv(head_dim_qk, meta["BLOCK_SIZE_KD"]),
            triton.cdiv(head_dim_v, meta["BLOCK_SIZE_VD"]),
        )

    chunk_bwd_kernel_h[grid](
        q,
        do,
        alpha,
        dls,
        dh,
        scale,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        *strides(q, 4),
        *strides(do, 4),
        *strides(alpha, 3),
        *strides(dls, 4),
        *strides(dh, 5),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return dh


# Optional-argument flags are derived from which pointers are None at launch.
@triton.heuristics(
    {
        "USE_LAST_STATE_GRADIENT": lambda args: args["dls_ptr"] is not None,
        "IS_VARLEN": lambda args: args["cu_seq_len"] is not None,
    }
)
# Sweep KD/VD tile sizes, warp counts and stage counts; results are cached per
# (CHUNK_SIZE, head_dim_qk, head_dim_v).
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_KD": BK, "BLOCK_SIZE_VD": BV}, num_warps=nw, num_stages=ns
        )
        for BK in [32, 64]
        for BV in [32, 64]
        for nw in NUM_WARPS
        for ns in NUM_STAGES
    ],
    key=["CHUNK_SIZE", "head_dim_qk", "head_dim_v"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_sq_sk_bwd_kernel_h(
    q_ptr,
    k_ptr,
    dsq_ptr,
    dsk_ptr,
    alpha_ptr,
    dls_ptr,
    dh_ptr,
    scale,
    # shapes
    cu_seq_len,
    cu_chunk_len,
    seq_len,
    num_chunks,
    num_heads,
    head_dim_qk,
    head_dim_v,
    # strides
    stride_q_b,
    stride_q_n,
    stride_q_h,
    stride_q_d,
    stride_k_b,
    stride_k_n,
    stride_k_h,
    stride_k_d,
    stride_dsq_b,
    stride_dsq_n,
    stride_dsq_h,
    stride_dsq_d,
    stride_dsk_b,
    stride_dsk_n,
    stride_dsk_h,
    stride_dsk_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_dls_b,
    stride_dls_h,
    stride_dls_kd,
    stride_dls_vd,
    stride_dh_b,
    stride_dh_n,
    stride_dh_h,
    stride_dh_kd,
    stride_dh_vd,
    # block sizes
    CHUNK_SIZE: tl.constexpr,
    BLOCK_SIZE_KD: tl.constexpr,
    BLOCK_SIZE_VD: tl.constexpr,
    USE_LAST_STATE_GRADIENT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Reverse-time accumulation of the state gradient dh from two sources:
    # (q, dsq) and (k, dsk). Each program owns one
    # [BLOCK_SIZE_KD, BLOCK_SIZE_VD] tile for one (batch, head) pair; chunks are
    # visited last-to-first, the running gradient is stored before each chunk
    # is folded in, then decayed by exp(sum(alpha)) and increased by
    # (q * exp(cumsum(alpha))) @ dsq * scale + (k * exp(cumsum(alpha))) @ dsk * scale.
    #
    # Grid axes: 0 -> flattened (batch, head), 1 -> KD tile, 2 -> VD tile.
    pid_bh, pid_kd, pid_vd = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptr to the start of this batch
    if IS_VARLEN:
        # Packed sequences: token and chunk ranges for this sequence come from
        # the cumulative-length tables and override the scalar arguments.
        seq_start = tl.load(cu_seq_len + pid_b).to(tl.int32)
        seq_end = tl.load(cu_seq_len + pid_b + 1).to(tl.int32)
        seq_len = seq_end - seq_start
        chunk_start = tl.load(cu_chunk_len + pid_b).to(tl.int32)
        chunk_end = tl.load(cu_chunk_len + pid_b + 1).to(tl.int32)
        num_chunks = chunk_end - chunk_start
        q_ptr = q_ptr + seq_start * stride_q_n + pid_h * stride_q_h
        k_ptr = k_ptr + seq_start * stride_k_n + pid_h * stride_k_h
        dsq_ptr = dsq_ptr + seq_start * stride_dsq_n + pid_h * stride_dsq_h
        dsk_ptr = dsk_ptr + seq_start * stride_dsk_n + pid_h * stride_dsk_h
        alpha_ptr = alpha_ptr + seq_start * stride_a_n + pid_h * stride_a_h
        dh_ptr = dh_ptr + chunk_start * stride_dh_n + pid_h * stride_dh_h
    else:
        q_ptr = q_ptr + pid_b * stride_q_b + pid_h * stride_q_h
        k_ptr = k_ptr + pid_b * stride_k_b + pid_h * stride_k_h
        dsq_ptr = dsq_ptr + pid_b * stride_dsq_b + pid_h * stride_dsq_h
        dsk_ptr = dsk_ptr + pid_b * stride_dsk_b + pid_h * stride_dsk_h
        alpha_ptr = alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h
        dh_ptr = dh_ptr + pid_b * stride_dh_b + pid_h * stride_dh_h
    if USE_LAST_STATE_GRADIENT:
        dls_ptr = dls_ptr + pid_b * stride_dls_b + pid_h * stride_dls_h
    # ptrs
    # All block pointers start at the last chunk because the loop runs in
    # reverse.
    # NOTE(review): q, k, dsq and dsk start one token past the last chunk
    # boundary (`+ 1`) while alpha and dh do not. The reason for this one-token
    # shift is not visible in this file — confirm it is intended.
    # q and k are addressed transposed, as [head_dim_qk, seq_len].
    q_ptrs = tl.make_block_ptr(
        base=q_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_q_d, stride_q_n),
        offsets=(pid_kd * BLOCK_SIZE_KD, (num_chunks - 1) * CHUNK_SIZE + 1),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    k_ptrs = tl.make_block_ptr(
        base=k_ptr,
        shape=(head_dim_qk, seq_len),
        strides=(stride_k_d, stride_k_n),
        offsets=(pid_kd * BLOCK_SIZE_KD, (num_chunks - 1) * CHUNK_SIZE + 1),
        block_shape=(BLOCK_SIZE_KD, CHUNK_SIZE),
        order=(0, 1),
    )
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr,
        shape=(seq_len,),
        strides=(stride_a_n,),
        offsets=((num_chunks - 1) * CHUNK_SIZE,),
        block_shape=(CHUNK_SIZE,),
        order=(0,),
    )
    # dsq / dsk are addressed as [seq_len, head_dim_v].
    dsq_ptrs = tl.make_block_ptr(
        base=dsq_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dsq_n, stride_dsq_d),
        offsets=((num_chunks - 1) * CHUNK_SIZE + 1, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dsk_ptrs = tl.make_block_ptr(
        base=dsk_ptr,
        shape=(seq_len, head_dim_v),
        strides=(stride_dsk_n, stride_dsk_d),
        offsets=((num_chunks - 1) * CHUNK_SIZE + 1, pid_vd * BLOCK_SIZE_VD),
        block_shape=(CHUNK_SIZE, BLOCK_SIZE_VD),
        order=(1, 0),
    )
    dh_ptrs = tl.make_block_ptr(
        base=dh_ptr,
        shape=(num_chunks, head_dim_qk, head_dim_v),
        strides=(stride_dh_n, stride_dh_kd, stride_dh_vd),
        offsets=(num_chunks - 1, pid_kd * BLOCK_SIZE_KD, pid_vd * BLOCK_SIZE_VD),
        block_shape=(1, BLOCK_SIZE_KD, BLOCK_SIZE_VD),
        order=(2, 1, 0),
    )
    # init dh: fp32 accumulator, optionally seeded with the final-state gradient
    dh = tl.zeros([BLOCK_SIZE_KD, BLOCK_SIZE_VD], dtype=tl.float32)
    if USE_LAST_STATE_GRADIENT:
        dls_ptrs = tl.make_block_ptr(
            base=dls_ptr,
            shape=(head_dim_qk, head_dim_v),
            strides=(stride_dls_kd, stride_dls_vd),
            offsets=(pid_kd * BLOCK_SIZE_KD, pid_vd * BLOCK_SIZE_VD),
            block_shape=(BLOCK_SIZE_KD, BLOCK_SIZE_VD),
            order=(1, 0),
        )
        dls = tl.load(dls_ptrs, boundary_check=(0, 1), padding_option="zero")
        dh += dls
    # loop over reversed chunks
    for _ in range(num_chunks - 1, -1, -1):
        # store dh as it stands before this chunk's contribution is added
        tl.store(
            dh_ptrs,
            dh.to(dh_ptrs.dtype.element_ty)[None, :, :],
            boundary_check=(0, 1, 2),
        )
        # load q and k, shape: [BLOCK_SIZE_KD, CHUNK_SIZE]
        q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero")
        k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero")
        # load dsq and dsk, shape: [CHUNK_SIZE, BLOCK_SIZE_VD]
        dsq = tl.load(dsq_ptrs, boundary_check=(0, 1), padding_option="zero")
        dsk = tl.load(dsk_ptrs, boundary_check=(0, 1), padding_option="zero")
        # load alpha, shape: [CHUNK_SIZE,]
        alpha = tl.load(alpha_ptrs, boundary_check=(0,), padding_option="zero")
        sum_alpha = tl.sum(alpha, axis=0)  # scalar: total gate over the chunk
        gamma = tl.cumsum(alpha, axis=0)  # shape: [CHUNK_SIZE,]
        # update dh with gate: scale q/k columns by exp(cumsum(alpha)) and cast
        # back to their input dtype before the dot products
        q = (q * tl.exp(gamma)[None, :]).to(q.dtype)
        k = (k * tl.exp(gamma)[None, :]).to(k.dtype)
        dh = dh * tl.exp(sum_alpha) + tl.dot(q, dsq) * scale + tl.dot(k, dsk) * scale
        # update ptrs: step every block pointer back by one chunk
        q_ptrs = tl.advance(q_ptrs, (0, -CHUNK_SIZE))
        k_ptrs = tl.advance(k_ptrs, (0, -CHUNK_SIZE))
        dsq_ptrs = tl.advance(dsq_ptrs, (-CHUNK_SIZE, 0))
        dsk_ptrs = tl.advance(dsk_ptrs, (-CHUNK_SIZE, 0))
        alpha_ptrs = tl.advance(alpha_ptrs, (-CHUNK_SIZE,))
        dh_ptrs = tl.advance(dh_ptrs, (-1, 0, 0))


def chunk_sq_sk_bwd_h(
    q: torch.Tensor,
    k: torch.Tensor,
    dsq: torch.Tensor,
    dsk: torch.Tensor,
    alpha: torch.Tensor,
    dls: Optional[torch.Tensor],
    scale: float,
    cu_seq_len: Optional[torch.Tensor],
    cu_chunk_len: Optional[torch.Tensor],
    chunk_size: int,
    state_in_fp32: bool = False,
) -> torch.Tensor:
    """Launch ``chunk_sq_sk_bwd_kernel_h`` and return the state gradient.

    Args:
        q: Queries, unpacked as ``(batch, seq_len, num_heads, head_dim_qk)``.
        k: Keys; four strides are forwarded, same layout as ``q`` is assumed
            by the kernel's block pointers.
        dsq: Gradient paired with ``q``; its last dimension is taken as
            ``head_dim_v``.
        dsk: Gradient paired with ``k``.
        alpha: Per-token gate; its first three strides are forwarded.
        dls: Optional gradient of the final state, used to seed the
            accumulation.
        scale: Scalar multiplied into each chunk's contributions.
        cu_seq_len: Cumulative sequence lengths. ``None`` selects the dense
            (fixed-length) path; otherwise the batch size is
            ``cu_seq_len.shape[0] - 1``.
        cu_chunk_len: Cumulative chunk counts; only read when ``cu_seq_len``
            is given.
        chunk_size: Tokens per chunk. The kernel is launched with
            ``CHUNK_SIZE = next_power_of_2(chunk_size)``.
        state_in_fp32: Store ``dh`` in float32 instead of ``q.dtype``.

    Returns:
        ``dh`` of shape ``(batch, num_chunks, num_heads, head_dim_qk,
        head_dim_v)`` on the dense path, or ``(1, total_chunks, ...)`` on the
        varlen path.
    """
    dim0, dim1, num_heads, head_dim_qk = q.shape
    head_dim_v = dsq.shape[-1]
    dh_dtype = torch.float32 if state_in_fp32 else q.dtype

    if cu_seq_len is None:
        batch_size, seq_len = dim0, dim1
        num_chunks = math.ceil(seq_len / chunk_size)
        dh_batch = batch_size
    else:
        # Packed layout: the kernel re-derives per-sequence lengths from the
        # cumulative tables, so these totals are 0-d tensors.
        batch_size = cu_seq_len.shape[0] - 1
        seq_len = cu_seq_len[-1]
        num_chunks = cu_chunk_len[-1]
        dh_batch = 1

    dh = torch.empty(
        dh_batch,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        device=q.device,
        dtype=dh_dtype,
    )

    def strides(tensor, ndim):
        # First `ndim` strides of `tensor`, or zeros when the tensor is absent.
        if tensor is None:
            return (0,) * ndim
        return tuple(tensor.stride(d) for d in range(ndim))

    def grid(meta):
        # One program per (batch * head, KD tile, VD tile).
        return (
            batch_size * num_heads,
            triton.cdiv(head_dim_qk, meta["BLOCK_SIZE_KD"]),
            triton.cdiv(head_dim_v, meta["BLOCK_SIZE_VD"]),
        )

    chunk_sq_sk_bwd_kernel_h[grid](
        q,
        k,
        dsq,
        dsk,
        alpha,
        dls,
        dh,
        scale,
        cu_seq_len,
        cu_chunk_len,
        seq_len,
        num_chunks,
        num_heads,
        head_dim_qk,
        head_dim_v,
        *strides(q, 4),
        *strides(k, 4),
        *strides(dsq, 4),
        *strides(dsk, 4),
        *strides(alpha, 3),
        *strides(dls, 4),
        *strides(dh, 5),
        CHUNK_SIZE=triton.next_power_of_2(chunk_size),
    )
    return dh
