import torch

import triton
import triton.language as tl

import einops


@triton.jit
def _attn_fwd_inner(
    O_block,
    l_i,
    m_i,
    Q_block,
    K_block_ptr,
    V_block_ptr,
    block_index_q,
    stride_q_seq,
    stride_kv_seq,
    softmax_scale,
    apply_kv_padding_mask,
    kv_padding_mask_block_ptr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    LOCAL_ATTENTION: tl.constexpr,
    offs_q: tl.constexpr,
    offs_kv: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    CWL: tl.constexpr,
):
    """Run one online-softmax accumulation pass over a range of K/V tiles.

    The key/value sequence is walked in steps of BLOCK_SIZE_KV. For every tile
    the running row maximum ``m_i``, the running softmax denominator ``l_i``
    and the un-normalised output accumulator ``O_block`` are updated for the
    query tile ``Q_block``.

    * ``LOCAL_ATTENTION == 1``: query position ``q`` only keeps keys ``k`` with
      ``q - (CWL // 2 - 1) <= k <= q + CWL // 2`` and ``k <= SEQ_LEN - 1``;
      when ``apply_kv_padding_mask`` is set, keys whose padding-mask entry is
      true are dropped as well. Only the tiles that can intersect the window
      of this query block are visited.
    * otherwise: every key index in ``[0, SEQ_LEN)`` is kept and the padding
      mask is not consulted.

    Args:
        O_block: un-normalised output accumulator, one row per query.
        l_i: running sum of exponentials, one entry per query.
        m_i: running maximum of the scaled scores, one entry per query.
        Q_block: query tile, multiplied against each K tile with ``tl.dot``.
        K_block_ptr: pointers to the first K tile; the key index runs along
            axis 1 (see the load mask below).
        V_block_ptr: pointers to the first V tile; the key index runs along
            axis 0.
        block_index_q: index of the query tile along the sequence.
        stride_q_seq: not referenced in the body.
        stride_kv_seq: element stride between consecutive keys in K and V.
        softmax_scale: factor applied to the raw ``Q @ K`` scores.
        apply_kv_padding_mask: enables the padding-mask load (local pass only).
        kv_padding_mask_block_ptr: pointers to the first BLOCK_SIZE_KV mask
            entries of the current batch row (already includes ``offs_kv``).
        BLOCK_SIZE_Q, BLOCK_SIZE_KV: tile sizes.
        LOCAL_ATTENTION: 1 for the windowed pass, anything else for the full pass.
        offs_q: absolute sequence indices of the queries in this tile.
        offs_kv: key offsets relative to the start of a tile.
        SEQ_LEN: length of the key/value sequence iterated by this pass.
        CWL: window-length parameter from which both window halves are derived.

    Returns:
        The updated ``(O_block, l_i, m_i)`` triple.

    NOTE(review): ``offs_q`` and ``offs_kv`` are annotated ``tl.constexpr``
    although the visible caller passes ``tl.arange``-based tensors — confirm
    this is intended.
    """
    if LOCAL_ATTENTION == 1:
        # Asymmetric window: CWL // 2 - 1 keys to the left, CWL // 2 to the right.
        past_window = CWL // 2 - 1
        future_window = CWL // 2
 
        # Key range [lo, hi) that can be visible to any query of this tile.
        lo = tl.maximum(block_index_q * BLOCK_SIZE_Q - past_window, 0)
        hi = tl.minimum((block_index_q + 1) * (BLOCK_SIZE_Q) + future_window, SEQ_LEN)
    else:
        lo, hi = 0, SEQ_LEN

    # Snap the range outwards to whole KV tiles; the per-element masks below
    # remove whatever lies outside the real window / sequence.
    start_kv_block = (lo // BLOCK_SIZE_KV) * BLOCK_SIZE_KV
    end_kv_block = tl.cdiv(hi, BLOCK_SIZE_KV) * BLOCK_SIZE_KV

    # Advance the tile pointers to the first visited tile.
    K_block_ptr += start_kv_block * stride_kv_seq
    V_block_ptr += start_kv_block * stride_kv_seq

    # loop over k, v and update accumulator
    for curr_kv in range(start_kv_block, end_kv_block, BLOCK_SIZE_KV):
        # Tell the compiler that curr_kv is a multiple of BLOCK_SIZE_KV so it can optimize on that.
        curr_kv = tl.multiple_of(curr_kv, BLOCK_SIZE_KV)

        # -- compute qk ----
        # Keys past the end of the sequence are loaded as 0 (key index is axis 1 of the K tile).
        K_block = tl.load(K_block_ptr, mask=curr_kv + offs_kv[None, :] < SEQ_LEN, other=0)  # Key index along axis 1
        QK_block = tl.dot(Q_block, K_block) # Shape: (BLOCK_SIZE_Q, BLOCK_SIZE_KV)

        # `mask` is True for the (query, key) pairs that must NOT contribute.
        if LOCAL_ATTENTION == 1:
            # Key lies before the left edge of the query's window.
            mask = tl.maximum(offs_q[:, None] - past_window, 0) > (curr_kv + offs_kv[None, :])
            # Key lies after the right edge; clamping to SEQ_LEN-1 also rejects key indices >= SEQ_LEN.
            mask |= tl.minimum(offs_q[:, None] + future_window, SEQ_LEN-1) < (curr_kv + offs_kv[None, :])
            if apply_kv_padding_mask:
                # Out-of-range entries default to 1, i.e. they are treated as padded.
                kv_padding_mask_block = tl.load(
                    kv_padding_mask_block_ptr + curr_kv,
                    mask=curr_kv + offs_kv < SEQ_LEN,
                    other=1
                )  # Shape: (BLOCK_SIZE_KV,)
                mask |= kv_padding_mask_block[None, :]  # Shape: (1, BLOCK_SIZE_KV)
        else:
            # Full pass: only reject key indices outside [0, SEQ_LEN).
            mask = curr_kv + offs_kv[None, :] >= SEQ_LEN
            # Cannot trigger here because this branch starts iterating at 0.
            mask |= curr_kv + offs_kv[None, :] < 0

        QK_block = tl.where(mask, -float("inf"), QK_block)
        # New running maximum of the scaled scores.
        m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
        # Scale and shift by the new maximum; masked entries stay at -inf so they exponentiate to 0.
        QK_block = tl.where(mask, -float("inf"), QK_block * softmax_scale - m_ij[:, None])

        # Compute the exponential of each dot product, so now we are computing exp(qk_ij - m_ij)
        P_block = tl.math.exp(QK_block)
        # Compute the sum by rows of the attention scores
        l_ij = tl.sum(P_block, 1)

        # This is the correction factor for the previous l_i
        alpha = tl.math.exp(m_i - m_ij)
        # If no key has been visible for a row yet, m_i and m_ij are both -inf and
        # the difference is NaN; use 0 so the stale accumulators are simply dropped.
        alpha = tl.where(m_ij == -float("inf"), 0, alpha)
        # Apply the correction factor to the previous l_i and add the new l_ij
        l_i = l_i * alpha + l_ij

        # Values past the end of the sequence are loaded as 0 (key index is axis 0 of the V tile).
        V_block = tl.load(V_block_ptr, mask=curr_kv + offs_kv[:, None] < SEQ_LEN, other=0).to(tl.float32)  # Key index along axis 0
        P_block = P_block.to(tl.float32)
        # This computes the following: O_new = P x V + O_old * alpha
        O_block = O_block * alpha[:, None]
        O_block = tl.dot(P_block, V_block) + O_block

        m_i = m_ij

        # Move to the next block of K and V
        K_block_ptr += BLOCK_SIZE_KV * stride_kv_seq
        V_block_ptr += BLOCK_SIZE_KV * stride_kv_seq
    return O_block, l_i, m_i


@triton.autotune(
    [
        triton.Config(
            {"BLOCK_SIZE_Q": BLOCK_SIZE_Q, "BLOCK_SIZE_KV": BLOCK_SIZE_KV},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for BLOCK_SIZE_Q in [64, 128]
        for BLOCK_SIZE_KV in [32, 64]
        for num_stages in ([3, 4, 7])
        for num_warps in [2, 4]
    ],
    # The key is deliberately empty, so no argument value is registered as a
    # trigger for re-tuning. Keying on SEQ_LEN and HEAD_DIM was tried and disabled.
    key=[]
)
@triton.jit
def _attn_fwd(
    Q,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    K,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    V,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    eK,  # BATCH_SIZE, NUM_HEADS, STATE_LEN, HEAD_DIM
    eV,  # BATCH_SIZE, NUM_HEADS, STATE_LEN, HEAD_DIM
    softmax_scale,
    M,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN
    O,  # BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM
    stride_Q_batch,
    stride_Q_head,
    stride_Q_seq,
    stride_Q_dim,
    stride_K_batch,
    stride_K_head,
    stride_K_seq,
    stride_K_dim,
    stride_V_batch,
    stride_V_head,
    stride_V_seq,
    stride_V_dim,
    stride_eK_batch,
    stride_eK_head,
    stride_eK_seq,
    stride_eK_dim,
    stride_eV_batch,
    stride_eV_head,
    stride_eV_seq,
    stride_eV_dim,
    stride_O_batch,
    stride_O_head,
    stride_O_seq,
    stride_O_dim,
    stride_M_batch,
    stride_M_head,
    stride_M_seq,
    apply_kv_padding_mask,
    kv_padding_mask,
    BATCH_SIZE,
    NUM_HEADS: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    STATE_LEN: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,
    BLOCK_SIZE_KV: tl.constexpr,
    CWL: tl.constexpr,
):
    """Forward kernel: windowed attention over K/V plus full attention over eK/eV.

    Each program handles one tile of BLOCK_SIZE_Q queries (``program_id(0)``)
    of one (batch, head) pair (``program_id(1)``). The queries first attend to
    the keys of K/V that fall inside their local window, then to all STATE_LEN
    keys of eK/eV; both passes share a single softmax normalisation. The
    normalised result is written to ``O`` and the per-query log-sum-exp of the
    scaled scores to ``M``.

    Things a caller has to respect (all visible in the address arithmetic):
      * K and V are offset with the batch/head strides of Q, and eV with those
        of eK; ``stride_K_batch/head``, ``stride_V_batch/head`` and
        ``stride_eV_batch/head`` are accepted but not used.
      * ``M`` is addressed as ``batch_head * SEQ_LEN + q``; the ``stride_M_*``
        arguments are not used.
      * ``kv_padding_mask`` is addressed as ``batch * SEQ_LEN + k`` and the
        pointer is formed regardless of ``apply_kv_padding_mask``.
      * ``BATCH_SIZE`` is not used inside the kernel.
    """
    tl.static_assert(BLOCK_SIZE_KV <= HEAD_DIM)

    # This indicate which block in the sequence length to process
    block_index_q = tl.program_id(0)

    # This indicates which head and batch to process. Each program is associated with a single head of a single batch
    index_batch_head = tl.program_id(1)
    # This indicate which batch this program is associated with (each batch has NUM_HEADS heads)
    index_batch = index_batch_head // NUM_HEADS
    # This indicate the position of the head in the batch
    index_head = index_batch_head % NUM_HEADS

    # Offset of the (SEQ_LEN, HEAD_DIM) slice of this batch/head. Computed from
    # the Q strides in 64-bit and reused for K and V below.
    qvk_offset = (
        index_batch.to(tl.int64) * stride_Q_batch
        + index_head.to(tl.int64) * stride_Q_head
    )

    # Same for the (STATE_LEN, HEAD_DIM) slice of eK; reused for eV below.
    e_kv_offset = (
        index_batch.to(tl.int64) * stride_eK_batch
        + index_head.to(tl.int64) * stride_eK_head
    )

    # offs_q: absolute sequence indices of the queries handled by this program
    offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    # offs_kv: key/value offsets relative to the start of a KV tile
    offs_kv = tl.arange(0, BLOCK_SIZE_KV)

    # Q tile pointers, shape (BLOCK_SIZE_Q, HEAD_DIM)
    Q_block_ptr = Q + qvk_offset +\
        offs_q[:, None] * stride_Q_seq +\
        tl.arange(0, HEAD_DIM)[None, :] * stride_Q_dim

    # V tile pointers, shape (BLOCK_SIZE_KV, HEAD_DIM)
    V_block_ptr = V + qvk_offset +\
        offs_kv[:, None] * stride_V_seq +\
            tl.arange(0, HEAD_DIM)[None, :] * stride_V_dim

    # K tile pointers in transposed form, shape (HEAD_DIM, BLOCK_SIZE_KV), so Q @ K needs no explicit transpose
    K_block_ptr = K + qvk_offset +\
        tl.arange(0, HEAD_DIM)[:, None] * stride_K_dim +\
        tl.arange(0, BLOCK_SIZE_KV)[None, :] * stride_K_seq

    # First BLOCK_SIZE_KV padding-mask entries of this batch row (row stride SEQ_LEN)
    kv_padding_mask_block_ptr = kv_padding_mask + index_batch * SEQ_LEN + offs_kv

    # eV tile pointers, shape (BLOCK_SIZE_KV, HEAD_DIM)
    eV_block_ptr = eV + e_kv_offset +\
        offs_kv[:, None] * stride_eV_seq +\
        tl.arange(0, HEAD_DIM)[None, :] * stride_eV_dim

    # eK tile pointers in transposed form, shape (HEAD_DIM, BLOCK_SIZE_KV)
    eK_block_ptr = eK + e_kv_offset +\
        tl.arange(0, HEAD_DIM)[:, None] * stride_eK_dim +\
        tl.arange(0, BLOCK_SIZE_KV)[None, :] * stride_eK_seq

    # Output tile pointers, shape (BLOCK_SIZE_Q, HEAD_DIM). Unlike qvk_offset
    # this offset is not widened to 64-bit before the multiplication.
    O_block_ptr = O + index_batch * stride_O_batch +\
        index_head * stride_O_head +\
        offs_q[:, None] * stride_O_seq +\
        tl.arange(0, HEAD_DIM)[None, :] * stride_O_dim

    # m_i: the running maximum. We have one for each query
    m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")
    # l_i: the running sum. We have one for each query (as we sum the attention scores by rows).
    # The initial 1.0 never survives: the first update multiplies it by alpha == 0.
    l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) + 1.0
    # acc: the accumulator for the output, which is a group of rows of the O matrix
    O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)

    # Load the Q tile once; it is reused by both passes. Rows past SEQ_LEN read as 0.
    Q_block = tl.load(Q_block_ptr, mask=offs_q[:, None] < SEQ_LEN, other=0)  # Shape: (BLOCK_SIZE_Q, HEAD_DIM)

    # Execute local (windowed) attention over K / V
    O_block, l_i, m_i = _attn_fwd_inner(
        O_block,
        l_i,
        m_i,
        Q_block,
        K_block_ptr,
        V_block_ptr,
        block_index_q,
        stride_Q_seq,
        stride_K_seq,
        softmax_scale,
        apply_kv_padding_mask,
        kv_padding_mask_block_ptr,
        BLOCK_SIZE_Q,
        BLOCK_SIZE_KV,
        LOCAL_ATTENTION=1,
        offs_q=offs_q,
        offs_kv=offs_kv,
        SEQ_LEN=SEQ_LEN,
        CWL=CWL,
    )

    # Execute global attention over eK / eV. The key length is STATE_LEN here, and
    # with LOCAL_ATTENTION=0 the padding-mask arguments are passed but not read.
    O_block, l_i, m_i = _attn_fwd_inner(
        O_block,
        l_i,
        m_i,
        Q_block,
        eK_block_ptr,
        eV_block_ptr,
        block_index_q,
        stride_Q_seq,
        stride_eK_seq,
        softmax_scale,
        apply_kv_padding_mask,
        kv_padding_mask_block_ptr,
        BLOCK_SIZE_Q,
        BLOCK_SIZE_KV,
        LOCAL_ATTENTION=0,
        offs_q= offs_q,
        offs_kv=offs_kv,
        SEQ_LEN=STATE_LEN,
        CWL=CWL,
    )

    # epilogue
    m_i += tl.math.log(
        l_i
    )  # m_i now holds max + log(sum) = logsumexp of the scaled scores, reused by the backward kernels
    # Normalise the accumulated output by the softmax denominator.
    O_block = O_block / l_i[:, None]
    # M is addressed as a contiguous (BATCH_SIZE * NUM_HEADS, SEQ_LEN) array.
    m_ptrs = M + index_batch_head * SEQ_LEN + offs_q
    tl.store(m_ptrs, m_i, mask=offs_q < SEQ_LEN)  # Store the logsumexp

    # Only rows that correspond to real queries are written back.
    O_mask = offs_q[:, None] < SEQ_LEN
    tl.store(O_block_ptr, O_block.to(O.type.element_ty), mask=O_mask)

@triton.jit
def _attn_bwd_preprocess(
    O,
    dO,
    D,
    SEQ_LEN,
    BLOCK_SIZE_Q: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    """Compute ``D[q] = sum_d O[q, d] * dO[q, d]`` for every query row.

    Each program handles BLOCK_SIZE_Q rows (``program_id(0)``) of one
    batch/head slice (``program_id(1)``).

    The address arithmetic hard-codes a contiguous layout: ``O`` and ``dO``
    are read as ``(batch_head, SEQ_LEN, HEAD_DIM)`` and ``D`` is written as
    ``(batch_head, SEQ_LEN)``; no strides are taken as arguments.
    """
    block_index_q = tl.program_id(0)
    offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
    index_batch_head = tl.program_id(1)
    offs_dim = tl.arange(0, HEAD_DIM)
    # Load a single block of BLOCK_SIZE_Q rows of O (rows past SEQ_LEN read as 0)
    O_block = tl.load(
        O
        + index_batch_head * HEAD_DIM * SEQ_LEN
        + offs_q[:, None] * HEAD_DIM
            + offs_dim[None, :], mask=offs_q[:, None] < SEQ_LEN, other=0
    )
    # Load a single block of BLOCK_SIZE_Q rows of dO, widened to float32
    dO_block = tl.load(
        dO
        + index_batch_head * HEAD_DIM * SEQ_LEN
        + offs_q[:, None] * HEAD_DIM
            + offs_dim[None, :], mask=offs_q[:, None] < SEQ_LEN, other=0
    ).to(tl.float32)
    # Compute the D block: row-wise dot product of dO and O
    D_block = tl.sum(dO_block * O_block, axis=1)  # Shape: (BLOCK_SIZE_Q,)
    # Store the D block, skipping rows past the end of the sequence
    D_block_ptrs = D + index_batch_head * SEQ_LEN + offs_q
    mask = offs_q < SEQ_LEN
    tl.store(D_block_ptrs, D_block, mask=mask)


@triton.jit
def _attn_bwd_dq(
    Q,
    K,
    V,
    eK,
    eV,
    softmax_scale,
    dO,
    dQ,
    M,
    D,
    stride_batch,
    stride_head,
    stride_seq,
    stride_dim,
    e_stride_batch,
    e_stride_head,
    e_stride_seq,
    e_stride_dim,
    NUM_HEADS,
    SEQ_LEN,
    STATE_LEN,
    apply_kv_padding_mask,
    kv_padding_mask_block_ptr,
    BLOCK_Q: tl.constexpr,
    BLOCK_KV: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    CWL: tl.constexpr,
):
    """Backward kernel for dQ.

    Each program owns one tile of BLOCK_Q queries (``program_id(0)``) of one
    batch/head pair (``program_id(2)``). It accumulates the gradient
    contributions of the local K/V window first and of all STATE_LEN eK/eV
    keys second, then writes the tile of ``dQ``.

    Addressing conventions visible in the body:
      * Q, K, V, dO and dQ all use ``stride_batch/head/seq/dim``; eK and eV
        use the ``e_stride_*`` set.
      * M and D are addressed as contiguous ``(batch_head, SEQ_LEN)`` arrays.
      * the padding mask is addressed as ``batch * SEQ_LEN + k``.

    ``M`` is used as the per-query log-sum-exp to recompute the softmax
    probabilities, and ``D`` as the per-query term subtracted from ``dP``.
    """
    index_batch_head = tl.program_id(2)
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS
    # NOTE(review): the cast to int64 happens after the multiplications.
    offset_batch_head = (stride_batch * index_batch + stride_head * index_head).to(
        tl.int64
    )
    e_offset_batch_head = (e_stride_batch * index_batch + e_stride_head * index_head).to(
        tl.int64
    )
    # This is the offset that allows us to select the right sequence given the batch and head.
    offset_batch_head_seq = (index_batch_head * SEQ_LEN).to(tl.int64)

    # Make sure the pointers are in the right place w.r.t batch and head
    # The reason we don't access the blocks through make_block_ptr is because we need to use the range of offsets to apply the masking
    Q += offset_batch_head
    K += offset_batch_head
    V += offset_batch_head
    eK += e_offset_batch_head
    eV += e_offset_batch_head
    dO += offset_batch_head
    dQ += offset_batch_head
    if apply_kv_padding_mask:
        # Select the padding-mask row of this batch element.
        kv_padding_mask_block_ptr += index_batch * SEQ_LEN

    # Make sure the pointers are in the right place w.r.t batch, head and sequence
    M += offset_batch_head_seq
    D += offset_batch_head_seq

    # Offsets along the head dimension
    offs_dim = tl.arange(0, HEAD_DIM)

    index_block_q = tl.program_id(0)

    start_q = index_block_q * BLOCK_Q
    offs_q = start_q + tl.arange(0, BLOCK_Q)

    # Q and dO tiles are loaded once and reused by both loops; rows past SEQ_LEN read as 0.
    Q_block = tl.load(Q + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim, mask=offs_q[:, None] < SEQ_LEN, other=0)  # Shape: (BLOCK_Q, HEAD_DIM)
    dQ_block = tl.zeros([BLOCK_Q, HEAD_DIM], dtype=tl.float32)
    dO_block = tl.load(
        dO + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim, mask=offs_q[:, None] < SEQ_LEN, other=0
    )

    # Per-query logsumexp, broadcast as a column for the subtraction below.
    M_block = tl.load(M + offs_q, mask=offs_q < SEQ_LEN, other=0)  # Shape: (BLOCK_Q,)
    M_block = M_block[:, None]

    offs_kv = tl.arange(0, BLOCK_KV)

    # ---- local pass over K / V ----
    # Same window as the forward local pass: CWL // 2 - 1 keys back, CWL // 2 keys ahead.
    past_window = CWL // 2 - 1
    future_window = CWL // 2

    # Key range [lo, hi) that can be visible to any query of this tile.
    lo = tl.maximum(index_block_q * BLOCK_Q - past_window, 0)
    hi = tl.minimum((index_block_q + 1) * (BLOCK_Q) + future_window, SEQ_LEN)

    # Snap the range outwards to whole KV tiles.
    start_kv_block = (lo // BLOCK_KV) * BLOCK_KV
    end_kv_block = tl.cdiv(hi, BLOCK_KV) * BLOCK_KV

    # We access the K and V as transposed blocks, shape (HEAD_DIM, BLOCK_KV)
    kT_ptrs = K + (start_kv_block + offs_kv[None, :]) * stride_seq + offs_dim[:, None] * stride_dim
    vT_ptrs = V + (start_kv_block + offs_kv[None, :]) * stride_seq + offs_dim[:, None] * stride_dim

    Di = tl.load(D + offs_q, mask=offs_q < SEQ_LEN, other=0)  # Shape: (BLOCK_Q,)

    for curr_kv in range(start_kv_block, end_kv_block, BLOCK_KV):
        K_T_block = tl.load(kT_ptrs, mask=curr_kv + offs_kv[None, :] < SEQ_LEN, other=0)
        V_T_block = tl.load(vT_ptrs, mask=curr_kv + offs_kv[None, :] < SEQ_LEN, other=0)

        # Mask is True for (query, key) pairs outside the window or past the end of the sequence.
        mask = tl.maximum(offs_q[:, None] - past_window, 0) > (curr_kv + offs_kv[None, :])
        mask |= tl.minimum(offs_q[:, None] + future_window, SEQ_LEN-1) < (curr_kv + offs_kv[None, :])
        if apply_kv_padding_mask:
            # Out-of-range entries default to 1, i.e. they are treated as padded.
            kv_padding_mask_block = tl.load(
                kv_padding_mask_block_ptr + curr_kv + offs_kv,
                mask=curr_kv + offs_kv < SEQ_LEN,
                other=1
            )
            mask |= kv_padding_mask_block[None, :]  # Shape: (1, BLOCK_KV)
        # Recompute the probabilities from the stored logsumexp; masked pairs become exp(-inf) = 0.
        QK_block = tl.where(mask, -float("inf"), softmax_scale * tl.dot(Q_block, K_T_block))
        P_block = tl.math.exp(QK_block - M_block)

        # Compute dP and dS.
        dP_block = tl.dot(dO_block, V_T_block).to(tl.float32)
        dS_block = P_block * (dP_block - Di[:, None])
        dS_block = dS_block.to(tl.float32)
        # Compute dQ.
        # softmax_scale is applied here because the scores were scaled before the softmax.
        dQ_block += softmax_scale * tl.dot(dS_block, tl.trans(K_T_block))
        # Increment pointers.
        kT_ptrs += BLOCK_KV * stride_seq
        vT_ptrs += BLOCK_KV * stride_seq

    # ---- global pass over eK / eV: every key in [0, STATE_LEN) ----
    lo, hi = 0, STATE_LEN
    start_kv_block = (lo // BLOCK_KV) * BLOCK_KV
    end_kv_block = tl.cdiv(hi, BLOCK_KV) * BLOCK_KV

    # We access eK and eV as transposed blocks, shape (HEAD_DIM, BLOCK_KV)
    kT_ptrs = eK + (start_kv_block + offs_kv[None, :]) * e_stride_seq + offs_dim[:, None] * e_stride_dim
    vT_ptrs = eV + (start_kv_block + offs_kv[None, :]) * e_stride_seq + offs_dim[:, None] * e_stride_dim

    for curr_kv in range(start_kv_block, end_kv_block, BLOCK_KV):
        K_T_block = tl.load(kT_ptrs, mask=curr_kv + offs_kv[None, :] < STATE_LEN, other=0)
        V_T_block = tl.load(vT_ptrs, mask=curr_kv + offs_kv[None, :] < STATE_LEN, other=0)
        # Only key indices outside [0, STATE_LEN) are masked; no window, no padding mask.
        mask = curr_kv + offs_kv[None, :] >= STATE_LEN
        # Cannot trigger here because the loop starts at 0.
        mask |= curr_kv + offs_kv[None, :] < 0
        QK_block = tl.where(mask, -float("inf"), softmax_scale * tl.dot(Q_block, K_T_block))
        P_block = tl.math.exp(QK_block - M_block)

        # Compute dP and dS.
        dP_block = tl.dot(dO_block, V_T_block).to(tl.float32)
        dS_block = P_block * (dP_block - Di[:, None])
        dS_block = dS_block.to(tl.float32)
        # Compute dQ.
        # softmax_scale is applied here because the scores were scaled before the softmax.
        dQ_block += softmax_scale * tl.dot(dS_block, tl.trans(K_T_block))
        # Increment pointers.
        kT_ptrs += BLOCK_KV * e_stride_seq
        vT_ptrs += BLOCK_KV * e_stride_seq

    # Store back to HBM, skipping rows past the end of the sequence
    dQ_block_ptrs = dQ + offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim
    tl.store(dQ_block_ptrs, dQ_block, mask= offs_q[:, None] < SEQ_LEN)


@triton.jit
def _attn_bwd_dk_dv(
    Q,
    K,
    V,
    softmax_scale,
    dO,
    dK,
    dV,
    M,
    D,
    stride_q_batch,
    stride_q_head,
    stride_q_seq,
    stride_q_dim,
    stride_kv_batch,
    stride_kv_head,
    stride_kv_seq,
    stride_kv_dim,
    NUM_HEADS,
    SEQ_LEN,
    STATE_LEN,
    apply_kv_padding_mask,
    kv_padding_mask_block_ptr,
    BLOCK_Q: tl.constexpr,
    BLOCK_KV: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    LOCAL_ATTENTION: tl.constexpr,
    CWL: tl.constexpr,
):
    """Backward kernel for dK and dV of one key/value tensor pair.

    Each program owns one tile of BLOCK_KV keys (``program_id(0)``) of one
    batch/head pair (``program_id(2)``), keeps that K/V tile resident and
    iterates over query tiles, accumulating ``dV += P^T @ dO`` and
    ``dK += softmax_scale * dS^T @ Q``.

    * ``LOCAL_ATTENTION == 1``: the key sequence has length SEQ_LEN, only the
      query tiles whose window can contain these keys are visited, and the
      window / padding masks of the forward local pass are re-applied.
    * otherwise: the key sequence has length STATE_LEN and all queries in
      ``[0, SEQ_LEN)`` are visited without a window or padding mask.

    Q and dO share the ``stride_q_*`` strides; K, V, dK and dV share the
    ``stride_kv_*`` strides. M and D are addressed as contiguous
    ``(batch_head, SEQ_LEN)`` arrays.
    """
    index_batch_head = tl.program_id(2)
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS
    # NOTE(review): the cast to int64 happens after the multiplications.
    offset_batch_head_q = (stride_q_batch * index_batch + stride_q_head * index_head).to(
        tl.int64
    )
    offset_batch_head_kv = (stride_kv_batch * index_batch + stride_kv_head * index_head).to(
        tl.int64
    )
    # This is the offset that allows us to select the right sequence given the batch and head.
    offset_batch_head_seq = (index_batch_head * SEQ_LEN).to(tl.int64)

    # Make sure the pointers are in the right place w.r.t batch and head
    # The reason we don't access the blocks through make_block_ptr is because we need to use the range of offsets to apply the masking
    Q += offset_batch_head_q
    K += offset_batch_head_kv
    V += offset_batch_head_kv
    dO += offset_batch_head_q
    dK += offset_batch_head_kv
    dV += offset_batch_head_kv
    if apply_kv_padding_mask:
        # Select the padding-mask row of this batch element (only read in the local branch).
        kv_padding_mask_block_ptr += index_batch * SEQ_LEN

    # Make sure the pointers are in the right place w.r.t batch, head and sequence
    M += offset_batch_head_seq
    D += offset_batch_head_seq

    # Offsets along the head dimension
    offs_dim = tl.arange(0, HEAD_DIM)

    index_block_kv = tl.program_id(0)
    start_kv = index_block_kv * BLOCK_KV

    # Absolute indices of the keys owned by this program
    offs_kv = start_kv + tl.arange(0, BLOCK_KV)

    if LOCAL_ATTENTION == 1:
        # Window seen from the key side: the two halves are swapped with respect
        # to the query-side window (CWL // 2 - 1 back, CWL // 2 ahead).
        past_window = CWL // 2
        future_window = CWL // 2 - 1
 
        # Query range [lo, hi) that can attend to any key of this tile.
        lo = tl.maximum(index_block_kv * BLOCK_KV - past_window, 0)
        hi = tl.minimum((index_block_kv + 1) * (BLOCK_KV) + future_window, SEQ_LEN)
    else:
        lo, hi = 0, SEQ_LEN

    # Snap the range outwards to whole Q tiles.
    start_q_block = (lo // BLOCK_Q) * BLOCK_Q
    end_q_block = tl.cdiv(hi, BLOCK_Q) * BLOCK_Q

    dV_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32)
    dK_block = tl.zeros([BLOCK_KV, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    # The valid key length depends on which tensor pair is being processed.
    mask_kv = offs_kv < SEQ_LEN if LOCAL_ATTENTION == 1 else offs_kv < STATE_LEN
    K_block = tl.load(
        K + offs_kv[:, None] * stride_kv_seq + offs_dim[None, :] * stride_kv_dim, mask=mask_kv[:, None], other=0
    )  # Shape: (BLOCK_KV, HEAD_DIM)
    V_block = tl.load(
        V + offs_kv[:, None] * stride_kv_seq + offs_dim[None, :] * stride_kv_dim, mask=mask_kv[:, None], other=0
    )  # Shape: (BLOCK_KV, HEAD_DIM)

    offs_q = tl.arange(0, BLOCK_Q)

    # Q is accessed as a transposed tile (HEAD_DIM, BLOCK_Q): offs_q indexes the columns and offs_dim the rows.
    # dO is accessed untransposed as (BLOCK_Q, HEAD_DIM).
    # Both point at the first visited Q tile; the loop advances them by BLOCK_Q rows per iteration.
    qT_ptrs = Q + (start_q_block + offs_q[None, :]) * stride_q_seq + offs_dim[:, None] * stride_q_dim
    dO_ptrs = dO + (start_q_block + offs_q[:, None]) * stride_q_seq + offs_dim[None, :] * stride_q_dim

    # Iterates over the sequence dimension of the query
    for curr_q in range(start_q_block, end_q_block, BLOCK_Q):
        # Load a transposed block of Q (queries past SEQ_LEN read as 0)
        offs_q = curr_q + tl.arange(0, BLOCK_Q)
        qT_block = tl.load(qT_ptrs, mask=offs_q[None, :] < SEQ_LEN, other=0)  # Shape: (HEAD_DIM, BLOCK_Q)
        # Load the logsumexp values for the queries in the current block
        mask = offs_q < SEQ_LEN
        m = tl.load(M + offs_q, mask=mask, other=0)  # Shape: (BLOCK_Q,)

        # Calculate the mask: True for (key, query) pairs that did not contribute in the forward pass
        if LOCAL_ATTENTION == 1:
            mask = tl.maximum(offs_kv[:, None] - past_window, 0) > offs_q[None, :] # Shape [BLOCK_KV, BLOCK_Q]
            mask |= tl.minimum(offs_kv[:, None] + future_window, SEQ_LEN-1) < offs_q[None, :] # Shape [BLOCK_KV, BLOCK_Q]
            if apply_kv_padding_mask:
                # Out-of-range entries default to 1, i.e. they are treated as padded.
                kv_padding_mask_block = tl.load(
                    kv_padding_mask_block_ptr + offs_kv, mask=offs_kv < SEQ_LEN, other=1
                )
                mask |= kv_padding_mask_block[:, None]  # Shape: (BLOCK_KV, 1)
        else:
            # Only queries outside [0, SEQ_LEN) are masked.
            mask = offs_q[None, :] >= SEQ_LEN # Shape [1, BLOCK_Q]
            # Cannot trigger here because the loop starts at 0.
            mask |= offs_q[None, :] < 0 # Shape [1, BLOCK_Q]

        # This gives us (QK^T)^T = (K^T)^T(Q^T) = K(Q^T) = P^T
        QK_T_block = tl.where(mask, -float("inf"), softmax_scale * tl.dot(K_block, qT_block))
        # We apply the softmax by using the logsumexp trick
        P_T_block = tl.math.exp(QK_T_block - m[None, :]) # Shape: (BLOCK_KV, BLOCK_Q)

        dO_block = tl.load(dO_ptrs, mask=offs_q[:, None]<SEQ_LEN, other=0) # Shape: (BLOCK_Q, HEAD_DIM)
        # According to the formula: dV_new = dV_old + P^T x dO, where x is the matrix multiplication
        dV_block += tl.dot(P_T_block.to(tl.float32), dO_block) # The summation

        # Delta = rowsum(O * dO) where * is the element-wise product
        Di = tl.load(D + offs_q, mask=offs_q<SEQ_LEN, other=0) # Shape: (BLOCK_Q,)

        # dP = dO x V^T, so dP^T = V x dO^T
        # Where x is the matrix multiplication
        dpT_block = tl.dot(V_block, tl.trans(dO_block)).to(tl.float32)

        # We know that dS = P * (dP - Delta), so dS^T = P^T * (dP^T - Delta^T)

        dS_T_block = P_T_block * (dpT_block - Di[None, :])
        dS_T_block = dS_T_block.to(tl.float32)

        # According to the formula on the paper: dK_new = dK_old + dS^T x Q
        dK_block += softmax_scale * tl.dot(dS_T_block, tl.trans(qT_block))
        # Increment pointers.
        qT_ptrs += BLOCK_Q * stride_q_seq
        dO_ptrs += BLOCK_Q * stride_q_seq

    # Write back dV, skipping keys past the end of the key sequence.
    dV_block_ptrs = dV + offs_kv[:, None] * stride_kv_seq + offs_dim[None, :] * stride_kv_dim
    tl.store(dV_block_ptrs, dV_block, mask=mask_kv[:, None])

    # Write back dK.
    dK_block_ptrs = dK + offs_kv[:, None] * stride_kv_seq + offs_dim[None, :] * stride_kv_dim
    tl.store(dK_block_ptrs, dK_block, mask=mask_kv[:, None])


class Contextualization(torch.autograd.Function):
    """Autograd function wrapping the fused local-window + global-state attention kernels.

    Queries attend to the keys/values ``K``/``V`` that fall inside a window
    derived from ``cwl`` and, in the same softmax, to every key/value of
    ``eK``/``eV``. Call it through ``Contextualization.apply(...)``.
    """

    @staticmethod
    def forward(ctx, Q, K, V, eK, eV, cwl, apply_kv_padding_mask, kv_padding_mask):
        """Launch the forward kernel and stash what the backward pass needs.

        Args:
            ctx: autograd context.
            Q, K, V: tensors of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).
            eK, eV: tensors whose third dimension is the state length.
            cwl: window-length parameter forwarded to the kernels as ``CWL``.
            apply_kv_padding_mask: whether ``kv_padding_mask`` is applied to
                the local K/V pass.
            kv_padding_mask: mask tensor indexed by the kernels as
                ``batch * SEQ_LEN + key``; true entries drop the key.
                NOTE(review): the forward kernel forms a pointer from this
                argument even when ``apply_kv_padding_mask`` is false —
                confirm callers always pass a tensor.

        Returns:
            A float32 tensor of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM).
        """
        HEAD_DIM_Q, HEAD_DIM_K, HEAD_DIM_V = Q.shape[-1], K.shape[-1], V.shape[-1]
        BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = Q.shape

        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V

        # `contiguous()` hands back the same tensor when the layout is already
        # contiguous, so no explicit `is_contiguous()` guard is needed.
        Q, K, V, eK, eV = (t.contiguous() for t in (Q, K, V, eK, eV))
        if apply_kv_padding_mask:
            kv_padding_mask = kv_padding_mask.contiguous()

        softmax_scale = HEAD_DIM_K**-0.5

        # The output is always float32, independent of the input dtype.
        O = torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM), device=Q.device, dtype=torch.float32
        )
        # M is the logsumexp for the backward pass, one for each query
        M = torch.empty(
            (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
        )

        # One program per query tile and per (batch, head) pair; the tile size
        # is chosen by the autotuner.
        grid = lambda args: (
            triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),
            BATCH_SIZE * NUM_HEADS,
            1,
        )

        _attn_fwd[grid](
            Q=Q,
            K=K,
            V=V,
            eK=eK,
            eV=eV,
            softmax_scale=softmax_scale,
            M=M,
            O=O,
            stride_Q_batch=Q.stride(0),
            stride_Q_head=Q.stride(1),
            stride_Q_seq=Q.stride(2),
            stride_Q_dim=Q.stride(3),
            stride_K_batch=K.stride(0),
            stride_K_head=K.stride(1),
            stride_K_seq=K.stride(2),
            stride_K_dim=K.stride(3),
            stride_V_batch=V.stride(0),
            stride_V_head=V.stride(1),
            stride_V_seq=V.stride(2),
            stride_V_dim=V.stride(3),
            stride_eK_batch=eK.stride(0),
            stride_eK_head=eK.stride(1),
            stride_eK_seq=eK.stride(2),
            stride_eK_dim=eK.stride(3),
            stride_eV_batch=eV.stride(0),
            stride_eV_head=eV.stride(1),
            stride_eV_seq=eV.stride(2),
            stride_eV_dim=eV.stride(3),
            stride_O_batch=O.stride(0),
            stride_O_head=O.stride(1),
            stride_O_seq=O.stride(2),
            stride_O_dim=O.stride(3),
            stride_M_batch=M.stride(0),
            stride_M_head=M.stride(1),
            stride_M_seq=M.stride(2),
            apply_kv_padding_mask=apply_kv_padding_mask,
            kv_padding_mask=kv_padding_mask,
            BATCH_SIZE=Q.shape[0],
            NUM_HEADS=Q.shape[1],
            SEQ_LEN=Q.shape[2],
            STATE_LEN=eK.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            CWL=cwl,
        )

        ctx.save_for_backward(Q, K, V, eK, eV, O, M, kv_padding_mask)
        # `grid` is not read by `backward`; the attribute is kept so that any
        # external code inspecting the context keeps working.
        ctx.grid = grid
        ctx.softmax_scale = softmax_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.CWL = cwl
        ctx.apply_kv_padding_mask = apply_kv_padding_mask
        return O

    @staticmethod
    def backward(ctx, dO):
        """Compute the gradients with respect to Q, K, V, eK and eV.

        Runs, in order: the ``D = rowsum(O * dO)`` preprocessing kernel, the
        dK/dV kernel for the local K/V pair, the same kernel for the global
        eK/eV pair, and finally the dQ kernel.

        NOTE(review): ``dO`` follows the float32 output while Q/K/V keep their
        input dtype, and the backward kernels feed both into the same
        ``tl.dot`` calls; this presumably requires float32 inputs — confirm
        before using half-precision tensors.

        Returns:
            ``(dQ, dK, dV, deK, deV, None, None, None)`` — the last three
            inputs (cwl, apply_kv_padding_mask, kv_padding_mask) have no gradient.
        """
        dO = dO.contiguous()

        Q, K, V, eK, eV, O, M, kv_padding_mask = ctx.saved_tensors
        apply_kv_padding_mask = ctx.apply_kv_padding_mask

        # The kernels address several tensors with one shared set of strides.
        assert Q.stride() == K.stride() == V.stride() == O.stride() == dO.stride()
        assert eK.stride() == eV.stride()

        dQ = torch.empty_like(Q)
        dK = torch.empty_like(K)
        dV = torch.empty_like(V)
        deK = torch.empty_like(eK)
        deV = torch.empty_like(eV)

        BATCH_SIZE, NUM_HEADS, SEQ_LEN = Q.shape[:3]
        STATE_LEN = eK.shape[2]
        NUM_WARPS, NUM_STAGES = 4, 3
        BLOCK_SIZE_MICRO, BLOCK_SIZE_MACRO = 32, 128
        batch_heads = BATCH_SIZE * NUM_HEADS

        # Compute all the elements Di
        D = torch.empty_like(M)  # Shape: (BATCH_SIZE, NUM_HEADS, SEQ_LEN)
        preprocess_grid = (triton.cdiv(SEQ_LEN, BLOCK_SIZE_MACRO), batch_heads)
        _attn_bwd_preprocess[preprocess_grid](
            O=O,
            dO=dO,
            D=D,
            SEQ_LEN=SEQ_LEN,
            BLOCK_SIZE_Q=BLOCK_SIZE_MACRO,
            HEAD_DIM=ctx.HEAD_DIM,
        )

        # Fix a key/value tile and iterate over the Q tiles: first for the
        # windowed K/V pair (LOCAL_ATTENTION=1), then for the global eK/eV
        # pair (LOCAL_ATTENTION=0). Only the tensors, the key length that
        # sizes the grid and the flag differ between the two launches.
        for keys, values, d_keys, d_values, kv_len, local_flag in (
            (K, V, dK, dV, SEQ_LEN, 1),
            (eK, eV, deK, deV, STATE_LEN, 0),
        ):
            kv_grid = (triton.cdiv(kv_len, BLOCK_SIZE_MACRO), 1, batch_heads)
            _attn_bwd_dk_dv[kv_grid](
                Q=Q,
                K=keys,
                V=values,
                softmax_scale=ctx.softmax_scale,
                dO=dO,
                dK=d_keys,
                dV=d_values,
                M=M,
                D=D,
                stride_q_batch=Q.stride(0),
                stride_q_head=Q.stride(1),
                stride_q_seq=Q.stride(2),
                stride_q_dim=Q.stride(3),
                stride_kv_batch=keys.stride(0),
                stride_kv_head=keys.stride(1),
                stride_kv_seq=keys.stride(2),
                stride_kv_dim=keys.stride(3),
                NUM_HEADS=NUM_HEADS,
                SEQ_LEN=SEQ_LEN,
                STATE_LEN=STATE_LEN,
                apply_kv_padding_mask=apply_kv_padding_mask,
                kv_padding_mask_block_ptr=kv_padding_mask,
                BLOCK_Q=BLOCK_SIZE_MICRO,
                BLOCK_KV=BLOCK_SIZE_MACRO,
                HEAD_DIM=ctx.HEAD_DIM,
                LOCAL_ATTENTION=local_flag,
                CWL=ctx.CWL,
                num_warps=NUM_WARPS,
                num_stages=NUM_STAGES,
            )

        # Fix a Q tile and iterate over the K/V tiles (local first, then global).
        q_grid = (triton.cdiv(SEQ_LEN, BLOCK_SIZE_MACRO), 1, batch_heads)
        _attn_bwd_dq[q_grid](
            Q=Q,
            K=K,
            V=V,
            eK=eK,
            eV=eV,
            softmax_scale=ctx.softmax_scale,
            dO=dO,
            dQ=dQ,
            M=M,
            D=D,
            stride_batch=Q.stride(0),
            stride_head=Q.stride(1),
            stride_seq=Q.stride(2),
            stride_dim=Q.stride(3),
            e_stride_batch=eK.stride(0),
            e_stride_head=eK.stride(1),
            e_stride_seq=eK.stride(2),
            e_stride_dim=eK.stride(3),
            NUM_HEADS=NUM_HEADS,
            SEQ_LEN=SEQ_LEN,
            STATE_LEN=STATE_LEN,
            apply_kv_padding_mask=apply_kv_padding_mask,
            kv_padding_mask_block_ptr=kv_padding_mask,
            BLOCK_Q=BLOCK_SIZE_MACRO,
            BLOCK_KV=BLOCK_SIZE_MICRO,
            HEAD_DIM=ctx.HEAD_DIM,
            CWL=ctx.CWL,
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES,
        )

        # cwl, apply_kv_padding_mask and kv_padding_mask have no gradients
        return dQ, dK, dV, deK, deV, None, None, None

