import math
from typing import Any

import torch
import triton
import triton.language as tl

from .full_backward import attn_flash_bwd_triton
from .mask_gen import mask_gen_triton
from ..common import calc_dims
from .sparse_attn import bmm, pbmm, sparse_attn_triton


@triton.jit
def sparse_attn_dv_kernel(
        q_buffer, q_stride_n, q_stride_h, q_stride_i, q_stride_d,  # (bsz, num_heads, q_len, head_dim)
        k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,  # (bsz, num_heads, k_len, head_dim)
        m_buffer, m_stride_n, m_stride_h, m_stride_i, m_stride_t,  # (bsz, num_heads, q_blocks, top_k)
        v_buffer, v_stride_n, v_stride_h, v_stride_i, v_stride_v,  # (bsz, num_heads, k_len, value_dim)
        do_buffer, do_stride_n, do_stride_h, do_stride_i, do_stride_v,  # (bsz, num_heads, q_len, value_dim)
        l_buffer, l_stride_n, l_stride_h, l_stride_i,  # (bsz, num_heads, q_len)
        dv_buffer, dv_stride_n, dv_stride_h, dv_stride_i, dv_stride_v,  # (bsz, num_heads, k_len, value_dim)
        mult_factor: float,
        bsz, num_heads, q_len, k_len, head_dim: tl.constexpr, value_dim: tl.constexpr, top_k: tl.constexpr,
        block_size_q: tl.constexpr, block_size_k: tl.constexpr, query_offset, start_sink_tokens, end_sink_tokens,
        B_r: tl.constexpr, B_c: tl.constexpr) -> Any:
    """
    Accumulate the value gradient dV of the block-sparse causal attention.

    Program ids: axis 0 is a flattened (query tile i, key tile j) pair, axis 1 the batch index,
    axis 2 the head index. The key tile index j selects one of three key sources:

    * ``j < T_c``: the top-k key blocks listed in ``m_buffer`` for each query block,
    * the next ``T_start_sink`` tiles: the sink tokens at the start of the key sequence,
    * the last ``T_end_sink`` tiles: the sink tokens ending at each query block's key-block boundary.

    For its tile, the program recomputes ``P = exp(Q K^T * mult_factor - L)`` and atomically adds
    ``P^T dO`` into ``dv_buffer``.

    :param q_buffer: queries, (bsz, num_heads, q_len, head_dim)
    :param k_buffer: keys, (bsz, num_heads, k_len, head_dim)
    :param m_buffer: sparse key-block indices, (bsz, num_heads, q_blocks, top_k); negative entries are skipped
    :param v_buffer: values, (bsz, num_heads, k_len, value_dim); not read by this kernel
    :param do_buffer: gradient of the attention output, (bsz, num_heads, q_len, value_dim)
    :param l_buffer: per-query normaliser L subtracted from the scores before `exp`, (bsz, num_heads, q_len)
    :param dv_buffer: accumulation target for dV, (bsz, num_heads, k_len, value_dim); updated with atomic adds
    :param mult_factor: scale applied to the queries before the dot product
    :param bsz, num_heads: not read by the kernel body
    :param query_offset: position of the first query token
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param B_r: query rows per tile (a multiple of block_size_q)
    :param B_c: key columns per tile (a multiple of block_size_k)
    """

    n = tl.program_id(1)
    h = tl.program_id(2)
    ij = tl.program_id(0)

    # Query blocks are aligned to absolute positions, so the first block may be padded at its start.
    q_block_offset = query_offset // block_size_q
    q_start_padding = query_offset - q_block_offset * block_size_q
    q_blocks = tl.cdiv(q_len + query_offset, block_size_q) - q_block_offset

    # Decompose the flattened program id into (query tile, key tile).
    T_c = tl.cdiv(top_k * block_size_k, B_c)
    T_start_sink = tl.cdiv(start_sink_tokens, B_c)
    T_end_sink = tl.cdiv(end_sink_tokens, B_c)
    i = ij // (T_c + T_start_sink + T_end_sink)
    j = ij % (T_c + T_start_sink + T_end_sink)

    d = tl.arange(0, head_dim)
    v = tl.arange(0, value_dim)

    r_block_begin = i * B_r // block_size_q
    r_block_end = min(q_blocks, r_block_begin + B_r // block_size_q)
    r_indices = r_block_begin + tl.arange(0, B_r // block_size_q)

    q = (
        tl.arange(0, B_r // block_size_q)[None, :] * block_size_q
        + tl.arange(0, block_size_q)[:, None]
    )  # token-level indices
    q_positions = (q_block_offset + r_block_begin) * block_size_q + q  # (block_size_q, B_r // block_size_q)
    q_indices = r_block_begin * block_size_q + q

    Q_i = tl.load(
        q_buffer
        + n * q_stride_n
        + h * q_stride_h
        + (q_indices[:, None, :] - q_start_padding) * q_stride_i
        + d[None, :, None] * q_stride_d,
        mask=(
            (query_offset <= q_positions[:, None, :])
            & (q_positions[:, None, :] < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, head_dim, B_r // block_size_q)
    dtype = Q_i.dtype
    Q_i *= mult_factor.to(dtype)

    dO_i = tl.load(
        do_buffer
        + n * do_stride_n
        + h * do_stride_h
        + (q_indices[:, None, :] - q_start_padding) * do_stride_i
        + v[None, :, None] * do_stride_v,
        mask=(
            (query_offset <= q_positions[:, None, :])
            & (q_positions[:, None, :] < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, value_dim, B_r // block_size_q)

    L_i = tl.load(
        l_buffer
        + n * l_stride_n
        + h * l_stride_h
        + (q_indices - q_start_padding) * l_stride_i,
        mask=(
            (query_offset <= q_positions)
            & (q_positions < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, B_r // block_size_q)

    # First key position of the end-sink window of every query block in this tile:
    # the end of the query block rounded up to a key-block boundary, minus end_sink_tokens.
    sink_end_start_indices = tl.cdiv(
        (q_block_offset + r_block_begin + 1 + tl.arange(0, B_r // block_size_q)) * block_size_q,
        block_size_k
    ) * block_size_k - end_sink_tokens

    # Finite stand-in for -inf used to mask scores before the exponential.
    NEG_INF = -10000.0

    # Handle top-k
    if j < T_c:
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(top_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)

        sparse_indices = tl.load(
            m_buffer
            + n * m_stride_n
            + h * m_stride_h
            + r_indices[None, :] * m_stride_i
            + c_indices[:, None] * m_stride_t,
            mask=(
                (r_indices[None, :] < r_block_end)
                & (c_indices[:, None] < c_block_end)
            ),
            other=0
        )  # sparse-block-level indices, (B_c // block_size_k, B_r // block_size_q)

        # NOTE: this must stay a plain tensor expression. A trailing comma inside the
        # parentheses turns it into a 1-tuple and breaks every `k_indices[...]` below.
        k_indices = (
            start_sink_tokens + sparse_indices[:, None, :] * block_size_k
            + tl.arange(0, block_size_k)[None, :, None]
        )  # (B_c // block_size_k, block_size_k, B_r // block_size_q)

        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :, :] * k_stride_i
            + d[:, None, None, None] * k_stride_d,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[None, :, None, None] < c_block_end)
                & (sparse_indices[None, :, None, :] >= 0)
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k, B_r // block_size_q)
        K_j = K_j.reshape(head_dim, B_c, B_r // block_size_q)

        S_ij = bmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        valid_mask = (k_indices[None, :, :, :] <= q_positions[:, None, None, :])

        # Skip length 0 blocks
        valid_mask &= (sparse_indices >= 0)[None, :, None, :]

        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)
        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        P_ij = tl.exp((S_ij - L_i[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        dV_j = bmm(P_ij, dO_i, ta=True)  # (B_c, value_dim, B_r // block_size_q)
        dV_j = tl.reshape(dV_j, (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q))

        # Different query blocks may select the same key block, hence the atomic accumulation.
        tl.atomic_add(
            dv_buffer
            + n * dv_stride_n
            + h * dv_stride_h
            + k_indices[:, :, None, :] * dv_stride_i
            + v[None, None, :, None] * dv_stride_v,
            dV_j,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (sparse_indices[:, None, None, :] >= 0)
            ),
        )

    # Handle sink tokens
    elif j < T_c + T_start_sink:
        j -= T_c
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(start_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)

        # Start-sink keys are shared by every query block, so K has no per-query-block axis here.
        k_indices = c_indices[:, None] * block_size_k + tl.arange(0, block_size_k)[None, :]
        # (B_c // block_size_k, block_size_k)
        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :] * k_stride_i
            + d[:, None, None] * k_stride_d,
            mask=(
                (c_indices[None, :, None] < c_block_end)
                & (k_indices[None, :, :] < k_len)
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k)
        K_j = K_j.reshape(head_dim, B_c)

        S_ij = pbmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        k_positions = k_indices[None, :, :, None]  # (1, B_c // block_size_k, block_size_k, 1)
        valid_mask = (
            (k_positions <= q_positions[:, None, None, :])
            # keys at or beyond the end-sink window start are handled by the end-sink branch
            & (k_positions < sink_end_start_indices[None, None, None, :])
            & (c_indices[None, :, None, None] < c_block_end)
        )
        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)

        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        P_ij = tl.exp((S_ij - L_i[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        # Contract over all query tokens of the tile at once: (B_c, tokens) @ (tokens, value_dim).
        dV_j = tl.dot(
            P_ij.permute(1, 0, 2).reshape(B_c, block_size_q * (B_r // block_size_q)),
            dO_i.permute(0, 2, 1).reshape(block_size_q * (B_r // block_size_q), value_dim),
        ).to(dtype)  # (B_c, value_dim)

        tl.atomic_add(
            dv_buffer
            + n * dv_stride_n
            + h * dv_stride_h
            + k_indices[:, :, None] * dv_stride_i
            + v[None, None, :] * dv_stride_v,
            dV_j.reshape(B_c // block_size_k, block_size_k, value_dim),
            mask=(
                (c_indices[:, None, None] < c_block_end)
                & (k_indices[:, :, None] < k_len)
            ),
        )

    elif j < T_c + T_start_sink + T_end_sink:
        j -= T_c + T_start_sink
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(end_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)

        # End-sink key blocks differ per query block; blocks before the start of the sequence are invalid.
        k_block_indices = sink_end_start_indices[None, :] // block_size_k + c_indices[:, None]
        k_block_indices_valid = (k_block_indices >= 0)  # (B_c // block_size_k, B_r // block_size_q)
        k_indices = (
            k_block_indices[:, None, :] * block_size_k + tl.arange(0, block_size_k)[None, :, None]
        )  # (B_c // block_size_k, block_size_k, B_r // block_size_q)

        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :, :] * k_stride_i
            + d[:, None, None, None] * k_stride_d,
            mask=(
                (k_indices[None, :, :, :] < k_len)
                & (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[None, :, None, None] < c_block_end)
                & k_block_indices_valid[None, :, None, :]
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k, B_r // block_size_q)
        K_j = K_j.reshape(head_dim, B_c, B_r // block_size_q)

        S_ij = bmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        k_positions = k_indices[None, :, :, :]  # (1, B_c // block_size_k, block_size_k, B_r // block_size_q)
        valid_mask = (
            (k_positions <= q_positions[:, None, None, :])
            & (c_indices[None, :, None, None] < c_block_end)
            & k_block_indices_valid[None, :, None, :]
        )
        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)
        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        P_ij = tl.exp((S_ij - L_i[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        dV_j = bmm(P_ij, dO_i, ta=True)  # (B_c, value_dim, B_r // block_size_q)
        dV_j = tl.reshape(dV_j, (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q))

        tl.atomic_add(
            dv_buffer
            + n * dv_stride_n
            + h * dv_stride_h
            + k_indices[:, :, None, :] * dv_stride_i
            + v[None, None, :, None] * dv_stride_v,
            dV_j,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (k_indices[:, :, None, :] < k_len)
                & k_block_indices_valid[:, None, None, :]
            ),
        )


@triton.jit
def flash_attn_kernel(
        q_buffer, q_stride_n, q_stride_h, q_stride_i, q_stride_d,  # (bsz, num_heads, q_len, head_dim)
        k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,  # (bsz, num_heads, k_len, head_dim)
        v_buffer, v_stride_n, v_stride_h, v_stride_i, v_stride_v,  # (bsz, num_heads, k_len, value_dim)
        r_buffer, r_stride_n, r_stride_h, r_stride_i, r_stride_v,  # (bsz, num_heads, q_len, value_dim)
        l_buffer, l_stride_n, l_stride_h, l_stride_i,  # (bsz, num_heads, q_len)
        mult_factor: float,
        q_len, k_len, head_dim: tl.constexpr, value_dim: tl.constexpr,
        query_offset, B_r: tl.constexpr, B_c: tl.constexpr, return_l: tl.constexpr) -> Any:
    """
    Dense causal attention forward pass with an online softmax.

    Program ids: axis 0 is the query row tile, axis 1 the batch index, axis 2 the head index.
    Each program walks over every key tile, keeps a running row maximum and row sum, and
    finally writes the normalised output and, optionally, the per-row log-sum-exp.

    All tile arithmetic is carried out in bfloat16 regardless of the input dtype; only
    `exp` and `log` are evaluated in float32.

    :param q_buffer: queries, (bsz, num_heads, q_len, head_dim)
    :param k_buffer: keys, (bsz, num_heads, k_len, head_dim)
    :param v_buffer: values, (bsz, num_heads, k_len, value_dim)
    :param r_buffer: output buffer, (bsz, num_heads, q_len, value_dim)
    :param l_buffer: log-sum-exp buffer, (bsz, num_heads, q_len); written only if `return_l`
    :param mult_factor: scale applied to the queries before the dot product
    :param query_offset: position of the first query token, used by the causal mask
    :param B_r: query rows per tile
    :param B_c: key columns per tile
    :param return_l: whether to store the log-sum-exp into `l_buffer`
    """

    n = tl.program_id(1)
    h = tl.program_id(2)
    i = tl.program_id(0)

    # NOTE: T_r is not referenced below; the number of row tiles is set by the launch grid.
    T_r = tl.cdiv(q_len, B_r)
    T_c = tl.cdiv(k_len, B_c)

    d = tl.arange(0, head_dim)
    v = tl.arange(0, value_dim)

    r_block_begin = i * B_r
    r_block_end = min(q_len, r_block_begin + B_r)
    r_indices = r_block_begin + tl.arange(0, B_r)

    # Compute dtype is fixed to bfloat16 (inputs are cast on load).
    dtype = tl.bfloat16

    Q_i = tl.load(
        q_buffer
        + n * q_stride_n
        + h * q_stride_h
        + r_indices[:, None] * q_stride_i
        + d[None, :] * q_stride_d,
        mask=(r_indices[:, None] < r_block_end),
        other=0.0
    ).to(dtype)
    Q_i *= mult_factor.to(dtype)  # (B_r, head_dim)

    # Running state of the online softmax: un-normalised output, row sum and row maximum.
    # The initial l_ij value is multiplied by exp(-inf - m) on the first iteration, so it is
    # discarded as long as the first key tile yields a finite row maximum.
    O_ij = tl.zeros((B_r, value_dim), dtype=dtype)
    l_ij = tl.full((B_r,), 1.0, dtype=dtype)
    m_ij = tl.full((B_r,), float('-inf'), dtype=dtype)

    for j in range(T_c):
        c_block_begin = j * B_c
        c_block_end = min(k_len, c_block_begin + B_c)
        c_indices = c_block_begin + tl.arange(0, B_c)

        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + c_indices[:, None] * k_stride_i
            + d[None, :] * k_stride_d,
            mask=(c_indices[:, None] < c_block_end),
            other=0.0,
        ).to(dtype)  # (B_c, head_dim)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + c_indices[:, None] * v_stride_i
            + v[None, :] * v_stride_v,
            mask=(c_indices[:, None] < c_block_end),
            other=0.0,
        ).to(dtype)  # (B_c, value_dim)

        S_ij = tl.dot(Q_i, tl.trans(K_j)).to(dtype)  # (B_r, B_c)

        # Apply causal mask: key position k is visible to query row r iff k <= query_offset + r.
        k_positions = c_indices
        q_positions = query_offset + r_indices
        causal_mask = k_positions[None, :] <= q_positions[:, None]
        S_ij = tl.where(causal_mask, S_ij, tl.full((), float('-inf'), dtype=dtype))

        # Online softmax update: rescale the previous sum and output by exp(m_old - m_new).
        m_ijm1 = m_ij
        m_ij = tl.maximum(m_ij, tl.max(S_ij, axis=1)).to(dtype)  # r
        P_tilde_ij = tl.exp((S_ij - m_ij[:, None]).to(tl.float32)).to(dtype)  # r x c
        m_diff_exp = tl.exp((m_ijm1 - m_ij).to(tl.float32)).to(dtype)  # r
        l_ij = l_ij * m_diff_exp + tl.sum(P_tilde_ij, axis=1).to(dtype)  # r
        O_ij = O_ij * m_diff_exp[:, None] + tl.dot(P_tilde_ij, V_j).to(dtype)

    # Normalise the output and form the log-sum-exp L = m + log(l).
    O_ij *= (1 / l_ij)[:, None]
    L_i = m_ij + tl.log(l_ij.to(tl.float32)).to(dtype)

    tl.store(
        r_buffer
        + n * r_stride_n
        + h * r_stride_h
        + r_indices[:, None] * r_stride_i
        + v[None, :] * r_stride_v,
        O_ij,
        mask=(r_indices[:, None] < r_block_end),
    )

    if return_l:
        tl.store(
            l_buffer
            + n * l_stride_n
            + h * l_stride_h
            + r_indices * l_stride_i,
            L_i,
            mask=(r_indices < r_block_end),
        )


def ste_sparse_attn_bwd_triton_impl(
        query_states, key_states, sparse_indices, value_states, output, grad_output, L,
        block_size_q: int, block_size_k: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int,
        B_r: int = None, B_c: int = None):
    """
    Sparse flashattention backward pass with STE

    The value gradient is computed by `sparse_attn_dv_kernel` from the sparse attention
    pattern. The query and key gradients are computed by a dense causal backward pass
    (`attn_flash_bwd_triton`) on a dense output and log-sum-exp that are recomputed here
    with `flash_attn_kernel`.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, value_dim); only used as a shape/device template
    :param grad_output: (bsz, num_heads, q_len, value_dim)
    :param L: (bsz, num_heads, q_len), normaliser of the sparse forward pass
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param B_r: number of rows in the flash block (default: max(32, block_size_q))
    :param B_c: number of columns in the flash block (default: max(64, block_size_k))
    :return: (dQ, dK, dV); dV is float32 when the inputs are bfloat16
    """

    dtype = query_states.dtype

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    # Both kernels below share these tile sizes; the launch grids are derived from them.
    if B_r is None:
        B_r = max(32, block_size_q)
    if B_c is None:
        B_c = max(64, block_size_k)

    assert B_r % block_size_q == 0
    assert B_c % block_size_k == 0

    _, q_blocks, _, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    T_r = triton.cdiv(q_blocks * block_size_q, B_r)
    T_c = triton.cdiv(top_k * block_size_k, B_c)
    T_sink = triton.cdiv(start_sink_tokens, B_c) + triton.cdiv(end_sink_tokens, B_c)

    q, k, m, v, do = query_states, key_states, sparse_indices, value_states, grad_output
    mult_factor = 1.0 / math.sqrt(head_dim)

    acc_dtype = dtype
    if dtype == torch.bfloat16:  # workaround for bf16 atomic_add not supported
        acc_dtype = torch.float32
    dv = torch.zeros_like(v, dtype=acc_dtype)

    # --- dV from the sparse attention pattern --------------------------------------------
    grid = ((T_c + T_sink) * T_r, bsz, num_heads)
    sparse_attn_dv_kernel[grid](
        q_buffer=q, q_stride_n=q.stride(0), q_stride_h=q.stride(1), q_stride_i=q.stride(2), q_stride_d=q.stride(3),
        k_buffer=k, k_stride_n=k.stride(0), k_stride_h=k.stride(1), k_stride_i=k.stride(2), k_stride_d=k.stride(3),
        m_buffer=m, m_stride_n=m.stride(0), m_stride_h=m.stride(1), m_stride_i=m.stride(2), m_stride_t=m.stride(3),
        v_buffer=v, v_stride_n=v.stride(0), v_stride_h=v.stride(1), v_stride_i=v.stride(2), v_stride_v=v.stride(3),
        do_buffer=do, do_stride_n=do.stride(0), do_stride_h=do.stride(1), do_stride_i=do.stride(2), do_stride_v=do.stride(3),
        l_buffer=L, l_stride_n=L.stride(0), l_stride_h=L.stride(1), l_stride_i=L.stride(2),
        dv_buffer=dv, dv_stride_n=dv.stride(0), dv_stride_h=dv.stride(1), dv_stride_i=dv.stride(2), dv_stride_v=dv.stride(3),
        mult_factor=mult_factor,
        bsz=bsz, num_heads=num_heads, q_len=q_len, k_len=k_len, head_dim=head_dim, value_dim=value_dim, top_k=top_k,
        block_size_q=block_size_q, block_size_k=block_size_k, query_offset=query_offset,
        start_sink_tokens=start_sink_tokens, end_sink_tokens=end_sink_tokens, B_r=B_r, B_c=B_c,
        num_stages=1,
        num_warps=16,
    )

    # --- dense forward recomputation: output and log-sum-exp for the dense backward ------
    dense_out = torch.zeros_like(output, dtype=acc_dtype)
    dense_L = torch.zeros_like(L, dtype=acc_dtype)

    grid = (T_r, bsz, num_heads)
    flash_attn_kernel[grid](
        q_buffer=q, q_stride_n=q.stride(0), q_stride_h=q.stride(1), q_stride_i=q.stride(2), q_stride_d=q.stride(3),
        k_buffer=k, k_stride_n=k.stride(0), k_stride_h=k.stride(1), k_stride_i=k.stride(2), k_stride_d=k.stride(3),
        v_buffer=v, v_stride_n=v.stride(0), v_stride_h=v.stride(1), v_stride_i=v.stride(2), v_stride_v=v.stride(3),
        r_buffer=dense_out, r_stride_n=dense_out.stride(0), r_stride_h=dense_out.stride(1),
        r_stride_i=dense_out.stride(2), r_stride_v=dense_out.stride(3),
        l_buffer=dense_L, l_stride_n=dense_L.stride(0), l_stride_h=dense_L.stride(1), l_stride_i=dense_L.stride(2),
        mult_factor=mult_factor,
        q_len=q_len, k_len=k_len, head_dim=head_dim, value_dim=value_dim,
        query_offset=query_offset, B_r=B_r, B_c=B_c, return_l=True,
        num_stages=1,
        num_warps=8,
    )

    # --- dQ, dK from the dense backward; its dV is discarded in favour of the sparse one --
    (dq, dk, _, _, _), _ = attn_flash_bwd_triton(
        query_states.unsqueeze(3), key_states, value_states,
        dense_out.unsqueeze(3), grad_output.unsqueeze(3), dense_L.unsqueeze(3),
        query_offset=query_offset, begin_offset=0
    )
    dq = dq.squeeze(3)

    return dq, dk, dv


def ste_sparse_attn_bwd_triton(
        query_states, key_states, sparse_indices, value_states, output, grad_output, L,
        block_size_q: int, block_size_k: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int,
        B_r: int = None, B_c: int = None):
    """
    Backward pass of the sparse attention with STE, split into a dense and a sparse part.

    The leading `cutoff` query tokens go through the dense backward `attn_flash_bwd_triton`;
    the remaining ones go through `ste_sparse_attn_bwd_triton_impl`. The two query gradients
    are concatenated along the sequence axis, the key and value gradients are summed.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, value_dim)
    :param grad_output: (bsz, num_heads, q_len, value_dim)
    :param L: (bsz, num_heads, q_len)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key; must be >= block_size_q
    :param B_r: number of rows in the flash block
    :param B_c: number of columns in the flash block
    :return: (dQ, dK, dV)
    """
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    assert end_sink_tokens >= block_size_q
    q_block_offset, q_blocks, q_start_padding, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Index of the first query block handled by the sparse path, derived from the total
    # number of keys one query attends to (sinks + top-k).
    attended = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    first_sparse_block = min(
        q_blocks,
        triton.cdiv(attended, block_size_q) - 1 - q_block_offset,
    )
    cutoff = max(0, first_sparse_block * block_size_q - q_start_padding)

    dense_part = slice(None, cutoff)
    sparse_part = slice(cutoff, None)

    # Dense backward over the first `cutoff` query tokens.
    dense_grads, _ = attn_flash_bwd_triton(
        query_states[:, :, dense_part].unsqueeze(3), key_states, value_states,
        output[:, :, dense_part].unsqueeze(3),
        grad_output[:, :, dense_part].unsqueeze(3),
        L[:, :, dense_part].unsqueeze(3),
        B_r=B_r, B_c=B_c, query_offset=query_offset,
    )
    dQ_dense, dK_dense, dV_dense, _, _ = dense_grads
    dQ_dense = dQ_dense.squeeze(3)

    # Sparse (STE) backward over the remaining query tokens, shifted by `cutoff`.
    dQ_sparse, dK_sparse, dV_sparse = ste_sparse_attn_bwd_triton_impl(
        query_states[:, :, sparse_part], key_states, sparse_indices, value_states,
        output[:, :, sparse_part], grad_output[:, :, sparse_part], L[:, :, sparse_part],
        block_size_q, block_size_k, cutoff + query_offset,
        start_sink_tokens, end_sink_tokens, B_r, B_c,
    )

    return (
        torch.cat([dQ_dense, dQ_sparse], dim=2),
        dK_sparse + dK_dense,
        dV_sparse + dV_dense,
    )


class SteHipAttention(torch.autograd.Function):
    """
    Autograd function: sparse attention forward, STE sparse attention backward.

    The forward pass runs `sparse_attn_triton` (generating the sparse indices with
    `mask_gen_triton` when none are given); the backward pass delegates to
    `ste_sparse_attn_bwd_triton`.
    """

    @staticmethod
    def forward(  # noqa
            ctx, query_states, key_states, value_states,
            hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
            start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None):
        """
        :param query_states: (bsz, num_heads, q_len, head_dim)
        :param key_states: (bsz, num_heads, k_len, head_dim)
        :param value_states: (bsz, num_heads, k_len, value_dim)
        :param hip_block_size_q: block size for the query
        :param hip_block_size_k: block size for the key
        :param hip_top_k_elems: number of top-k elements; must be a multiple of `hip_block_size_k`
        :param query_offset: offset of the query
        :param start_sink_tokens: number of sink tokens at the start of the key
        :param end_sink_tokens: number of sink tokens at the end of the key
        :param sparse_indices: optional precomputed indices, (bsz, num_heads, q_blocks, top_k)
        :return: attention output
        """
        assert hip_top_k_elems % hip_block_size_k == 0
        top_k_blocks = hip_top_k_elems // hip_block_size_k
        if sparse_indices is None:
            sparse_indices, _ = mask_gen_triton(
                query_states, key_states, top_k_blocks,
                hip_block_size_q, hip_block_size_k, query_offset,
                start_sink_tokens, end_sink_tokens
            )
        output, L = sparse_attn_triton(
            query_states, key_states, sparse_indices, value_states,
            hip_block_size_q, hip_block_size_k, query_offset,
            start_sink_tokens, end_sink_tokens
        )
        ctx.save_for_backward(query_states, key_states, sparse_indices, value_states, output, L)
        # Non-tensor configuration is kept as plain attributes on the context.
        ctx.query_offset = query_offset
        ctx.block_size_q = hip_block_size_q
        ctx.block_size_k = hip_block_size_k
        ctx.start_sink_tokens = start_sink_tokens
        ctx.end_sink_tokens = end_sink_tokens
        return output

    @staticmethod
    def backward(ctx, grad_output):  # noqa
        """
        :param grad_output: gradient with respect to the forward output
        :return: one entry per forward input: (dQ, dK, dV) followed by `None` for the
                 seven non-differentiable arguments
        """
        query_states, key_states, sparse_indices, value_states, output, L = ctx.saved_tensors
        query_offset = ctx.query_offset

        dq, dk, dv = ste_sparse_attn_bwd_triton(
            query_states, key_states, sparse_indices, value_states, output, grad_output, L,
            ctx.block_size_q, ctx.block_size_k, query_offset,
            ctx.start_sink_tokens, ctx.end_sink_tokens
        )

        # forward() takes 10 inputs after ctx, so exactly 10 gradients are returned:
        # block_size_q, block_size_k, top_k_elems, query_offset, start/end sink tokens
        # and sparse_indices receive no gradient.
        return dq, dk, dv, None, None, None, None, None, None, None


def ste_hip_attn(query_states, key_states, value_states,
                 hip_block_size_q: int, hip_block_size_k: int, hip_top_k_elems: int, query_offset: int,
                 start_sink_tokens: int, end_sink_tokens: int, sparse_indices=None):
    """
    Functional entry point: applies `SteHipAttention` to the given arguments.
    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states: (bsz, num_heads, k_len, head_dim)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param hip_block_size_q: block size for the query
    :param hip_block_size_k: block size for the key
    :param hip_top_k_elems: number of top-k elements. Must be a multiple of `hip_block_size_k`.
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param sparse_indices: optional (bsz, num_heads, q_blocks, top_k) indices; passed through unchanged
    :return: output (bsz, num_heads, q_len, value_dim)
    """
    # Positional order must match SteHipAttention.forward (autograd `apply` takes no keywords).
    forward_args = (
        query_states, key_states, value_states,
        hip_block_size_q, hip_block_size_k, hip_top_k_elems, query_offset,
        start_sink_tokens, end_sink_tokens, sparse_indices,
    )
    return SteHipAttention.apply(*forward_args)
