import math
import os
from typing import Any

import torch
import triton
import triton.language as tl

from ..utils.set_device import SetDevice
from .full_backward import attn_flash_bwd_triton
from ..common import calc_dims
from .sparse_attn import bmm, pbmm


@triton.jit
def sparse_attn_bwd_kernel(
        q_buffer, q_stride_n, q_stride_h, q_stride_i, q_stride_d,  # (bsz, num_heads, q_len, head_dim)
        k_buffer, k_stride_n, k_stride_h, k_stride_i, k_stride_d,  # (bsz, num_heads, k_len, head_dim)
        m_buffer, m_stride_n, m_stride_h, m_stride_i, m_stride_t,  # (bsz, num_heads, q_blocks, top_k)
        v_buffer, v_stride_n, v_stride_h, v_stride_i, v_stride_v,  # (bsz, num_heads, k_len, value_dim)
        do_buffer, do_stride_n, do_stride_h, do_stride_i, do_stride_v,  # (bsz, num_heads, q_len, value_dim)
        l_buffer, l_stride_n, l_stride_h, l_stride_i,  # (bsz, num_heads, q_len)
        d_buffer, d_stride_n, d_stride_h, d_stride_i,  # (bsz, num_heads, q_len)
        dq_buffer, dq_stride_n, dq_stride_h, dq_stride_i, dq_stride_d,  # (bsz, num_heads, q_len, head_dim)
        dk_buffer, dk_stride_n, dk_stride_h, dk_stride_i, dk_stride_d,  # (bsz, num_heads, k_len, head_dim)
        dv_buffer, dv_stride_n, dv_stride_h, dv_stride_i, dv_stride_v,  # (bsz, num_heads, k_len, value_dim)
        mult_factor: float,
        bsz, num_heads, q_len, k_len, head_dim: tl.constexpr, value_dim: tl.constexpr, top_k: tl.constexpr,
        block_size_q: tl.constexpr, block_size_k: tl.constexpr, query_offset, start_sink_tokens, end_sink_tokens,
        B_r: tl.constexpr, B_c: tl.constexpr) -> Any:
    """
    Backward kernel for block-sparse attention with start and end sink tokens.

    Program ids: axis 0 is a flattened (row tile i, column tile j) pair, axis 1 is the batch
    index and axis 2 is the head index. Each program loads one tile of queries (B_r rows,
    organised as B_r // block_size_q query blocks of block_size_q tokens) and one tile of keys
    and values (B_c columns), then adds its contribution to dq/dk/dv with ``tl.atomic_add``.
    The kernel only ever adds to the gradient buffers, so they have to be zero-initialised by
    the caller.

    The column tile index j selects the key source:
      * ``j < T_c``: the top-k key blocks listed in ``m_buffer`` for each query block.
        Negative entries of ``m_buffer`` are treated as "no block" and skipped.
      * the next ``T_start_sink`` tiles: keys addressed from position 0 (start sink).
      * the last ``T_end_sink`` tiles: a per-query-block window of keys that starts at
        ``sink_end_start_indices`` (end sink).

    ``l_buffer`` is subtracted from the scores before exponentiation and ``d_buffer`` is
    subtracted from dP; the host launcher fills ``d_buffer`` with ``sum(dO * O)`` over the value
    dimension. ``mult_factor`` scales the queries on load and the dQ update on store.

    Shape comments on the bmm/pbmm results follow the annotations of the original code; those
    helpers are defined outside this file.
    """

    n = tl.program_id(1)
    h = tl.program_id(2)
    ij = tl.program_id(0)

    # Query-block geometry: the first query block may start with q_start_padding padded tokens
    # when query_offset is not a multiple of block_size_q.
    q_block_offset = query_offset // block_size_q
    q_start_padding = query_offset - q_block_offset * block_size_q
    q_blocks = tl.cdiv(q_len + query_offset, block_size_q) - q_block_offset

    # Split the flattened program id into (row tile i, column tile j).
    T_c = tl.cdiv(top_k * block_size_k, B_c)
    T_start_sink = tl.cdiv(start_sink_tokens, B_c)
    T_end_sink = tl.cdiv(end_sink_tokens, B_c)
    i = ij // (T_c + T_start_sink + T_end_sink)
    j = ij % (T_c + T_start_sink + T_end_sink)

    d = tl.arange(0, head_dim)
    v = tl.arange(0, value_dim)

    r_block_begin = i * B_r // block_size_q
    r_block_end = min(q_blocks, r_block_begin + B_r // block_size_q)
    r_indices = r_block_begin + tl.arange(0, B_r // block_size_q)

    q = (
        tl.arange(0, B_r // block_size_q)[None, :] * block_size_q
        + tl.arange(0, block_size_q)[:, None]
    )  # token-level indices
    q_positions = (q_block_offset + r_block_begin) * block_size_q + q  # (block_size_q, B_r // block_size_q)
    q_indices = r_block_begin * block_size_q + q

    # Rows outside [query_offset, query_offset + q_len) are padding and load as zero.
    Q_i = tl.load(
        q_buffer
        + n * q_stride_n
        + h * q_stride_h
        + (q_indices[:, None, :] - q_start_padding) * q_stride_i
        + d[None, :, None] * q_stride_d,
        mask=(
            (query_offset <= q_positions[:, None, :])
            & (q_positions[:, None, :] < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, head_dim, B_r // block_size_q)
    dtype = Q_i.dtype
    Q_i *= mult_factor.to(dtype)

    dO_i = tl.load(
        do_buffer
        + n * do_stride_n
        + h * do_stride_h
        + (q_indices[:, None, :] - q_start_padding) * do_stride_i
        + v[None, :, None] * do_stride_v,
        mask=(
            (query_offset <= q_positions[:, None, :])
            & (q_positions[:, None, :] < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, value_dim, B_r // block_size_q)

    L_i = tl.load(
        l_buffer
        + n * l_stride_n
        + h * l_stride_h
        + (q_indices - q_start_padding) * l_stride_i,
        mask=(
            (query_offset <= q_positions)
            & (q_positions < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, B_r // block_size_q)

    D_i = tl.load(
        d_buffer
        + n * d_stride_n
        + h * d_stride_h
        + (q_indices - q_start_padding) * d_stride_i,
        mask=(
            (query_offset <= q_positions)
            & (q_positions < query_offset + q_len)
        ),
        other=0.0
    )  # (block_size_q, B_r // block_size_q)

    # First key position of the end-sink window of every query block in this row tile: the end
    # of the query block, rounded up to a multiple of block_size_k, minus end_sink_tokens.
    sink_end_start_indices = tl.cdiv(
        (q_block_offset + r_block_begin + 1 + tl.arange(0, B_r // block_size_q)) * block_size_q,
        block_size_k
    ) * block_size_k - end_sink_tokens

    # Score assigned to masked-out positions before the exponential.
    NEG_INF = -10000.0

    # Handle top-k
    if j < T_c:
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(top_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)

        sparse_indices = tl.load(
            m_buffer
            + n * m_stride_n
            + h * m_stride_h
            + r_indices[None, :] * m_stride_i
            + c_indices[:, None] * m_stride_t,
            mask=(
                (r_indices[None, :] < r_block_end)
                & (c_indices[:, None] < c_block_end)
            ),
            other=0
        )  # sparse-block-level indices, (B_c // block_size_k, B_r // block_size_q)

        # NOTE: this must be a tensor expression. A trailing comma inside the parentheses would
        # turn it into a 1-tuple and break every subscript of k_indices below.
        k_indices = (
            start_sink_tokens + sparse_indices[:, None, :] * block_size_k
            + tl.arange(0, block_size_k)[None, :, None]
        )  # (B_c // block_size_k, block_size_k, B_r // block_size_q)

        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :, :] * k_stride_i
            + d[:, None, None, None] * k_stride_d,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[None, :, None, None] < c_block_end)
                & (sparse_indices[None, :, None, :] >= 0)
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k, B_r // block_size_q)
        K_j = K_j.reshape(head_dim, B_c, B_r // block_size_q)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + k_indices[:, :, None, :] * v_stride_i
            + v[None, None, :, None] * v_stride_v,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (sparse_indices[:, None, None, :] >= 0)
            ),
            other=0
        )  # (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q)
        V_j = V_j.reshape(B_c, value_dim, B_r // block_size_q)

        S_ij = bmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        valid_mask = (k_indices[None, :, :, :] <= q_positions[:, None, None, :])

        # Skip length 0 blocks
        valid_mask &= (sparse_indices >= 0)[None, :, None, :]

        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)
        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        P_ij = tl.exp((S_ij - L_i[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        dV_j = bmm(P_ij, dO_i, ta=True)  # (B_c, value_dim, B_r // block_size_q)
        dP_ij = bmm(dO_i, V_j, tb=True)  # (block_size_q, B_c, B_r // block_size_q)
        dS_ij = P_ij * (dP_ij - D_i[:, None, :])  # (block_size_q, B_c, B_r // block_size_q)
        dQ_i_update = bmm(dS_ij, K_j, tb=True)  # (block_size_q, head_dim, B_r // block_size_q)
        dK_j = bmm(dS_ij, Q_i, ta=True)  # (B_c, head_dim, B_r // block_size_q)

        dK_j = tl.reshape(dK_j, (B_c // block_size_k, block_size_k, head_dim, B_r // block_size_q))
        dV_j = tl.reshape(dV_j, (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q))

        tl.atomic_add(
            dq_buffer
            + n * dq_stride_n
            + h * dq_stride_h
            + (q_indices[:, None, :] - q_start_padding) * dq_stride_i
            + d[None, :, None] * dq_stride_d,
            dQ_i_update * mult_factor.to(dtype),
            mask=(
                (query_offset <= q_positions[:, None, :])
                & (q_positions[:, None, :] < query_offset + q_len)
            ),
        )
        tl.atomic_add(
            dk_buffer
            + n * dk_stride_n
            + h * dk_stride_h
            + k_indices[:, :, None, :] * dk_stride_i
            + d[None, None, :, None] * dk_stride_d,
            dK_j,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (sparse_indices[:, None, None, :] >= 0)
            ),
        )
        tl.atomic_add(
            dv_buffer
            + n * dv_stride_n
            + h * dv_stride_h
            + k_indices[:, :, None, :] * dv_stride_i
            + v[None, None, :, None] * dv_stride_v,
            dV_j,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (sparse_indices[:, None, None, :] >= 0)
            ),
        )

    # Handle start sink tokens: the key tile is the same for every query block of the row tile,
    # so K_j / V_j carry no query-block axis here.
    elif j < T_c + T_start_sink:
        j -= T_c
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(start_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)

        k_indices = c_indices[:, None] * block_size_k + tl.arange(0, block_size_k)[None, :]
        # (B_c // block_size_k, block_size_k)
        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :] * k_stride_i
            + d[:, None, None] * k_stride_d,
            mask=(
                (c_indices[None, :, None] < c_block_end)
                & (k_indices[None, :, :] < k_len)
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k)
        K_j = K_j.reshape(head_dim, B_c)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + k_indices[:, :, None] * v_stride_i
            + v[None, None, :] * v_stride_v,
            mask=(
                (c_indices[:, None, None] < c_block_end)
                & (k_indices[:, :, None] < k_len)
            ),
            other=0
        )  # (B_c // block_size_k, block_size_k, value_dim)
        V_j = V_j.reshape(B_c, value_dim)

        S_ij = pbmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask; keys at or beyond a query block's end-sink window start are also
        # masked out in this branch.
        k_positions = k_indices[None, :, :, None]  # (1, B_c // block_size_k, block_size_k, 1)
        valid_mask = (
            (k_positions <= q_positions[:, None, None, :])
            & (k_positions < sink_end_start_indices[None, None, None, :])
            & (c_indices[None, :, None, None] < c_block_end)
        )
        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)

        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        P_ij = tl.exp((S_ij - L_i[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        # dV / dK sum over every query row of the tile, so the (block_size_q, B_r // block_size_q)
        # axes are flattened into one contraction axis for a single tl.dot.
        dV_j = tl.dot(
            P_ij.permute(1, 0, 2).reshape(B_c, block_size_q * (B_r // block_size_q)),
            dO_i.permute(0, 2, 1).reshape(block_size_q * (B_r // block_size_q), value_dim),
        ).to(dtype)  # (B_c, value_dim)
        dP_ij = pbmm(dO_i, tl.trans(V_j))  # (block_size_q, B_c, B_r // block_size_q)
        dS_ij = P_ij * (dP_ij - D_i[:, None, :])  # (block_size_q, B_c, B_r // block_size_q)
        dQ_i_update = pbmm(dS_ij, tl.trans(K_j))  # (block_size_q, head_dim, B_r // block_size_q)
        dK_j = tl.dot(
            dS_ij.permute(1, 0, 2).reshape(B_c, block_size_q * (B_r // block_size_q)),
            Q_i.permute(0, 2, 1).reshape(block_size_q * (B_r // block_size_q), head_dim),
        ).to(dtype)  # (B_c, head_dim)

        tl.atomic_add(
            dq_buffer
            + n * dq_stride_n
            + h * dq_stride_h
            + (q_indices[:, None, :] - q_start_padding) * dq_stride_i
            + d[None, :, None] * dq_stride_d,
            dQ_i_update * mult_factor.to(dtype),
            mask=(
                (query_offset <= q_positions[:, None, :])
                & (q_positions[:, None, :] < query_offset + q_len)
            ),
        )
        tl.atomic_add(
            dk_buffer
            + n * dk_stride_n
            + h * dk_stride_h
            + k_indices[:, :, None] * dk_stride_i
            + d[None, None, :] * dk_stride_d,
            dK_j.reshape(B_c // block_size_k, block_size_k, head_dim),
            mask=(
                (c_indices[:, None, None] < c_block_end)
                & (k_indices[:, :, None] < k_len)
            ),
        )
        tl.atomic_add(
            dv_buffer
            + n * dv_stride_n
            + h * dv_stride_h
            + k_indices[:, :, None] * dv_stride_i
            + v[None, None, :] * dv_stride_v,
            dV_j.reshape(B_c // block_size_k, block_size_k, value_dim),
            mask=(
                (c_indices[:, None, None] < c_block_end)
                & (k_indices[:, :, None] < k_len)
            ),
        )

    # Handle end sink tokens: every query block has its own key window, so the key tile keeps a
    # trailing query-block axis as in the top-k branch.
    elif j < T_c + T_start_sink + T_end_sink:
        j -= T_c + T_start_sink
        c_block_begin = j * B_c // block_size_k
        c_block_end = min(end_sink_tokens // block_size_k, c_block_begin + B_c // block_size_k)
        c_indices = c_block_begin + tl.arange(0, B_c // block_size_k)

        k_block_indices = sink_end_start_indices[None, :] // block_size_k + c_indices[:, None]
        # Windows that would start before key position 0 yield negative block indices; skip them.
        k_block_indices_valid = (k_block_indices >= 0)  # (B_c // block_size_k, B_r // block_size_q)
        k_indices = (
            k_block_indices[:, None, :] * block_size_k + tl.arange(0, block_size_k)[None, :, None]
        )  # (B_c // block_size_k, block_size_k, B_r // block_size_q)

        K_j = tl.load(
            k_buffer
            + n * k_stride_n
            + h * k_stride_h
            + k_indices[None, :, :, :] * k_stride_i
            + d[:, None, None, None] * k_stride_d,
            mask=(
                (k_indices[None, :, :, :] < k_len)
                & (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[None, :, None, None] < c_block_end)
                & k_block_indices_valid[None, :, None, :]
            ),
            other=0
        )  # (head_dim, B_c // block_size_k, block_size_k, B_r // block_size_q)
        K_j = K_j.reshape(head_dim, B_c, B_r // block_size_q)

        V_j = tl.load(
            v_buffer
            + n * v_stride_n
            + h * v_stride_h
            + k_indices[:, :, None, :] * v_stride_i
            + v[None, None, :, None] * v_stride_v,
            mask=(
                (k_indices[:, :, None, :] < k_len)
                & (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & k_block_indices_valid[:, None, None, :]
            ),
            other=0
        )  # (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q)
        V_j = V_j.reshape(B_c, value_dim, B_r // block_size_q)

        S_ij = bmm(Q_i, K_j)  # (block_size_q, B_c, B_r // block_size_q)

        # Apply causal mask
        k_positions = k_indices[None, :, :, :]  # (1, B_c // block_size_k, block_size_k, B_r // block_size_q)
        valid_mask = (
            (k_positions <= q_positions[:, None, None, :])
            & (c_indices[None, :, None, None] < c_block_end)
            & k_block_indices_valid[None, :, None, :]
        )
        valid_mask = valid_mask.reshape(block_size_q, B_c, B_r // block_size_q)
        S_ij = tl.where(valid_mask, S_ij, tl.full((), NEG_INF, dtype=dtype))

        P_ij = tl.exp((S_ij - L_i[:, None, :]).to(tl.float32)).to(dtype)  # (block_size_q, B_c, B_r // block_size_q)
        dV_j = bmm(P_ij, dO_i, ta=True)  # (B_c, value_dim, B_r // block_size_q)
        dP_ij = bmm(dO_i, V_j, tb=True)  # (block_size_q, B_c, B_r // block_size_q)
        dS_ij = P_ij * (dP_ij - D_i[:, None, :])  # (block_size_q, B_c, B_r // block_size_q)
        dQ_i_update = bmm(dS_ij, K_j, tb=True)  # (block_size_q, head_dim, B_r // block_size_q)
        dK_j = bmm(dS_ij, Q_i, ta=True)  # (B_c, head_dim, B_r // block_size_q)

        dK_j = tl.reshape(dK_j, (B_c // block_size_k, block_size_k, head_dim, B_r // block_size_q))
        dV_j = tl.reshape(dV_j, (B_c // block_size_k, block_size_k, value_dim, B_r // block_size_q))

        tl.atomic_add(
            dq_buffer
            + n * dq_stride_n
            + h * dq_stride_h
            + (q_indices[:, None, :] - q_start_padding) * dq_stride_i
            + d[None, :, None] * dq_stride_d,
            dQ_i_update * mult_factor.to(dtype),
            mask=(
                (query_offset <= q_positions[:, None, :])
                & (q_positions[:, None, :] < query_offset + q_len)
            ),
        )
        tl.atomic_add(
            dk_buffer
            + n * dk_stride_n
            + h * dk_stride_h
            + k_indices[:, :, None, :] * dk_stride_i
            + d[None, None, :, None] * dk_stride_d,
            dK_j,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (k_indices[:, :, None, :] < k_len)
                & k_block_indices_valid[:, None, None, :]
            ),
        )
        tl.atomic_add(
            dv_buffer
            + n * dv_stride_n
            + h * dv_stride_h
            + k_indices[:, :, None, :] * dv_stride_i
            + v[None, None, :, None] * dv_stride_v,
            dV_j,
            mask=(
                (r_indices[None, None, None, :] < r_block_end)
                & (c_indices[:, None, None, None] < c_block_end)
                & (k_indices[:, :, None, :] < k_len)
                & k_block_indices_valid[:, None, None, :]
            ),
        )


def sparse_attn_bwd_triton_impl(
        query_states, key_states, sparse_indices, value_states, output, grad_output, L,
        block_size_q: int, block_size_k: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int,
        B_r: int = None, B_c: int = None):
    """
    Launch the sparse flashattention backward kernel.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, value_dim)
    :param grad_output: (bsz, num_heads, q_len, value_dim)
    :param L: (bsz, num_heads, q_len)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key
    :param B_r: number of rows in the flash block; must be a multiple of block_size_q.
        When None, it is max(QE_BWD_SA_BLOCK_BQ environment variable (default 32), block_size_q).
    :param B_c: number of columns in the flash block; must be a multiple of block_size_k.
        When None, it is max(QE_BWD_SA_BLOCK_BK environment variable (default 64), block_size_k).
    :return: ((dQ, dK, dV), compile_info), where compile_info is the value returned by the
        kernel launch. dQ is cast back to the dtype of query_states; dK and dV are returned in
        the accumulation dtype (float32 when the inputs are bfloat16, the input dtype otherwise).
    :raises AssertionError: if B_r / B_c are not multiples of block_size_q / block_size_k
    """
    dtype = query_states.dtype

    bsz, num_heads, q_len, head_dim = query_states.size()
    _, _, k_len, value_dim = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    if B_r is None:
        B_r = max(int(os.environ.get('QE_BWD_SA_BLOCK_BQ', 32)), block_size_q)
    if B_c is None:
        B_c = max(int(os.environ.get('QE_BWD_SA_BLOCK_BK', 64)), block_size_k)

    # The kernel views a tile as (B_r // block_size_q) query blocks by (B_c // block_size_k)
    # key blocks, so both tile sizes have to be whole multiples of the block sizes.
    assert B_r % block_size_q == 0, \
        f"B_r ({B_r}) must be a multiple of block_size_q ({block_size_q})"
    assert B_c % block_size_k == 0, \
        f"B_c ({B_c}) must be a multiple of block_size_k ({block_size_k})"

    # Only the number of query blocks is needed on the host side.
    _, q_blocks, _, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    T_r = triton.cdiv(q_blocks * block_size_q, B_r)
    T_c = triton.cdiv(top_k * block_size_k, B_c)

    # Per-row dot product of dO and O, consumed by the kernel as d_buffer.
    D = (grad_output * output).sum(dim=3)  # (bsz, num_heads, q_len)

    q, k, m, v, do = query_states, key_states, sparse_indices, value_states, grad_output

    acc_dtype = dtype
    if dtype == torch.bfloat16:  # workaround for bf16 atomic_add not supported
        acc_dtype = torch.float32
    # The kernel accumulates with atomic adds, so the gradient buffers must start at zero.
    dq = torch.zeros_like(q, dtype=acc_dtype)
    dk = torch.zeros_like(k, dtype=acc_dtype)
    dv = torch.zeros_like(v, dtype=acc_dtype)

    # Grid axis 0 enumerates every (row tile, column tile) pair; the column tiles cover the
    # top-k blocks followed by the start-sink and end-sink tiles.
    grid = ((T_c + triton.cdiv(start_sink_tokens, B_c) + triton.cdiv(end_sink_tokens, B_c)) * T_r, bsz, num_heads)
    with SetDevice(query_states.device):
        compile_info = sparse_attn_bwd_kernel[grid](
            q_buffer=q, q_stride_n=q.stride(0), q_stride_h=q.stride(1), q_stride_i=q.stride(2), q_stride_d=q.stride(3),
            k_buffer=k, k_stride_n=k.stride(0), k_stride_h=k.stride(1), k_stride_i=k.stride(2), k_stride_d=k.stride(3),
            m_buffer=m, m_stride_n=m.stride(0), m_stride_h=m.stride(1), m_stride_i=m.stride(2), m_stride_t=m.stride(3),
            v_buffer=v, v_stride_n=v.stride(0), v_stride_h=v.stride(1), v_stride_i=v.stride(2), v_stride_v=v.stride(3),
            do_buffer=do, do_stride_n=do.stride(0), do_stride_h=do.stride(1), do_stride_i=do.stride(2), do_stride_v=do.stride(3),
            l_buffer=L, l_stride_n=L.stride(0), l_stride_h=L.stride(1), l_stride_i=L.stride(2),
            d_buffer=D, d_stride_n=D.stride(0), d_stride_h=D.stride(1), d_stride_i=D.stride(2),
            dq_buffer=dq, dq_stride_n=dq.stride(0), dq_stride_h=dq.stride(1), dq_stride_i=dq.stride(2), dq_stride_d=dq.stride(3),
            dk_buffer=dk, dk_stride_n=dk.stride(0), dk_stride_h=dk.stride(1), dk_stride_i=dk.stride(2), dk_stride_d=dk.stride(3),
            dv_buffer=dv, dv_stride_n=dv.stride(0), dv_stride_h=dv.stride(1), dv_stride_i=dv.stride(2), dv_stride_v=dv.stride(3),
            mult_factor=1.0 / math.sqrt(head_dim),
            bsz=bsz, num_heads=num_heads, q_len=q_len, k_len=k_len, head_dim=head_dim, value_dim=value_dim, top_k=top_k,
            block_size_q=block_size_q, block_size_k=block_size_k, query_offset=query_offset,
            start_sink_tokens=start_sink_tokens, end_sink_tokens=end_sink_tokens, B_r=B_r, B_c=B_c,
            num_stages=1,
            num_warps=16,
        )

    # NOTE(review): only dq is cast back to the input dtype; dk and dv stay in acc_dtype.
    # Kept as-is because callers may rely on it - confirm whether this asymmetry is intended.
    dq = dq.to(dtype)

    return (dq, dk, dv), compile_info


def sparse_attn_bwd_triton(
        query_states, key_states, sparse_indices, value_states, output, grad_output, L,
        block_size_q: int, block_size_k: int, query_offset: int,
        start_sink_tokens: int, end_sink_tokens: int,
        B_r: int = None, B_c: int = None):
    """
    Sparse flashattention backward pass.

    The query axis is split at a cutoff position: queries before the cutoff go through the
    dense backward (attn_flash_bwd_triton), the remaining queries go through the sparse
    backward (sparse_attn_bwd_triton_impl). dQ of both parts is concatenated along the query
    axis, while dK and dV of both parts are summed.

    :param query_states: (bsz, num_heads, q_len, head_dim)
    :param key_states:   (bsz, num_heads, k_len, head_dim)
    :param sparse_indices: (bsz, num_heads, q_blocks, top_k)
    :param value_states: (bsz, num_heads, k_len, value_dim)
    :param output: (bsz, num_heads, q_len, value_dim)
    :param grad_output: (bsz, num_heads, q_len, value_dim)
    :param L: (bsz, num_heads, q_len)
    :param block_size_q: query block size
    :param block_size_k: key block size
    :param query_offset: offset of the query
    :param start_sink_tokens: number of sink tokens at the start of the key
    :param end_sink_tokens: number of sink tokens at the end of the key; asserted to be at
        least block_size_q
    :param B_r: number of rows in the flash block
    :param B_c: number of columns in the flash block
    :return: (dQ, dK, dV)
    """
    _, _, q_len, _ = query_states.size()
    _, _, k_len, _ = value_states.size()
    _, _, _, top_k = sparse_indices.size()

    assert end_sink_tokens >= block_size_q
    q_block_offset, q_blocks, q_start_padding, _, _ = calc_dims(
        q_len, k_len, block_size_q, block_size_k, query_offset
    )

    # Index of the first query block handed to the sparse kernel, derived from the number of
    # keys a sparse query attends to (both sinks plus the top-k blocks).
    attended = start_sink_tokens + top_k * block_size_k + end_sink_tokens
    first_sparse_block = triton.cdiv(attended, block_size_q) - 1 - q_block_offset
    first_sparse_block = min(q_blocks, first_sparse_block)

    # Token position where the dense part ends and the sparse part begins (never negative).
    cutoff = max(0, first_sparse_block * block_size_q - q_start_padding)
    dense_part = slice(None, cutoff)
    sparse_part = slice(cutoff, None)

    # Dense backward over the leading queries; the dense routine is called with an extra
    # singleton axis at dim 3, which is removed from its dQ afterwards.
    dense_grads, _ = attn_flash_bwd_triton(
        query_states[:, :, dense_part].unsqueeze(3), key_states, value_states,
        output[:, :, dense_part].unsqueeze(3),
        grad_output[:, :, dense_part].unsqueeze(3),
        L[:, :, dense_part].unsqueeze(3),
        B_r=B_r, B_c=B_c, query_offset=query_offset,
    )
    dq_dense, dk_dense, dv_dense, _, _ = dense_grads
    dq_dense = dq_dense.squeeze(3)

    # Sparse backward over the trailing queries, with the query offset advanced by the cutoff.
    sparse_grads, _ = sparse_attn_bwd_triton_impl(
        query_states[:, :, sparse_part], key_states, sparse_indices, value_states,
        output[:, :, sparse_part], grad_output[:, :, sparse_part], L[:, :, sparse_part],
        block_size_q, block_size_k, cutoff + query_offset,
        start_sink_tokens, end_sink_tokens, B_r, B_c,
    )
    dq_sparse, dk_sparse, dv_sparse = sparse_grads

    dQ = torch.cat([dq_dense, dq_sparse], dim=2)
    dK = dk_sparse + dk_dense
    dV = dv_sparse + dv_dense
    return dQ, dK, dV
