import os
from typing import Optional

import torch
import triton
import triton.language as tl
import triton.tools.experimental_descriptor
from triton_bwd import autotune, triton_bwd

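# ENABLE_LHS_TO_TMEM is an experimental environment variable for Blackwell.
# Setting it to 1 can improve the performance of Blackwell attention. However,
# it defaults to 0 because it is known to cause correctness issues outside of
# the _attn_fwd_tma kernel.
# NOTE(review): neither ENABLE_LHS_TO_TMEM nor _attn_fwd_tma is referenced in
# the code visible in this file; confirm whether this comment is still needed.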


# True when the TRITON_DEBUG environment variable is exactly "1" (default "0").
# When enabled, the autograd wrapper below asserts that kernel outputs and
# gradients contain no NaN/Inf values.
TRITON_DEBUG = os.getenv("TRITON_DEBUG", "0") == "1"


def is_hip():
    """Return True when the active Triton target reports the "hip" backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "hip"


def is_cuda():
    """Return True when the active Triton target reports the "cuda" backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"


@triton.jit
def dot(a, b, acc: tl.constexpr = None):
    # Matrix product of `a` and `b` with an optional additive accumulator.
    #
    # float64 inputs take a broadcast-multiply + tl.sum fallback instead of
    # tl.dot (a compile-time warning is printed); every other dtype goes
    # through tl.dot.
    #
    # `acc`, when not None, is added to the product on BOTH paths. Previously
    # the tl.dot path silently dropped `acc`, so the two branches disagreed.
    use_sum: tl.constexpr = a.dtype == tl.float64
    if use_sum:
        tl.static_print("WARNING: Tensorcores disabled.")
        prod = tl.sum(a[:, :, None] * b[None, :, :], 1)
    else:
        prod = tl.dot(a, b)
    if acc is not None:
        return acc + prod
    else:
        return prod


@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    K_block_ptr,
    V_block_ptr,
    Score_block_ptr,
    start_m,
    qk_scale,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m,
    offs_n,
    seq_lens,
    M_CTX: tl.constexpr,
    N_CTX: tl.constexpr,
    Q_FACTOR: tl.constexpr,
    KV_FACTOR: tl.constexpr,
    fp8_v: tl.constexpr,
    EXCLUDE_LAST_WINDOW: tl.constexpr,
):
    # Online-softmax inner loop of the forward pass: walks over key/value
    # blocks of width BLOCK_N and updates the running output accumulator
    # `acc`, the running row sum `l_i` and the running row maximum `m_i`
    # (all in base-2 log space, since `qk_scale` already includes 1/ln(2)).
    #
    # STAGE selects the key range and whether per-row masking is applied:
    #   1 -> [0, diag_begin), no mask
    #   2 -> [diag_begin, ceil(max_seq_len / BLOCK_N) * BLOCK_N), masked with
    #        `seq_lens` (strict `>` when EXCLUDE_LAST_WINDOW, else `>=`)
    #   else -> [0, N_CTX), no mask
    # When Score_block_ptr is not None the scaled scores are also written out.
    # Returns the updated (acc, l_i, m_i).

    # range of values handled by this stage
    max_seq_len = tl.max(seq_lens).to(tl.int32)
    # Clamp at zero. Without the clamp, max_seq_len < BLOCK_N yields
    # diag_begin == -BLOCK_N, and stage 2 would then visit negative key
    # indices: K is zero-padded there (score 0) and the seq_lens mask accepts
    # them, which corrupts the softmax normalisation.
    diag_begin = tl.maximum(max_seq_len // BLOCK_N * BLOCK_N - BLOCK_N, 0)
    if STAGE == 1:
        lo, hi = 0, diag_begin
    elif STAGE == 2:
        lo, hi = diag_begin, tl.cdiv(max_seq_len, BLOCK_N) * BLOCK_N
    else:
        lo, hi = 0, N_CTX

    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    if Score_block_ptr is not None:
        Score_block_ptr = tl.advance(Score_block_ptr, (0, lo))
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N, max_iters=tl.cdiv(N_CTX, BLOCK_N)):
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
        qk = dot(q, k)
        if STAGE == 2:
            idx_tsrc = offs_n + start_n
            if EXCLUDE_LAST_WINDOW:
                mask = seq_lens[:, None] > (idx_tsrc[None, :] + 1)
            else:
                mask = seq_lens[:, None] >= (idx_tsrc[None, :] + 1)
            # qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            qk = tl.where(mask, qk * qk_scale, float('-inf'))
            if Score_block_ptr is not None:
                tl.store(
                    Score_block_ptr,
                    qk,
                    boundary_check=(0, 1)
                )
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
            # Re-apply the mask: for a fully masked row m_ij is -inf, so the
            # subtraction above produces NaN (-inf - -inf) in masked slots.
            qk = tl.where(mask, qk, float('-inf'))
        else:
            qk = qk * qk_scale
            if Score_block_ptr is not None:
                tl.store(
                    Score_block_ptr,
                    qk,
                    boundary_check=(0, 1),
                )
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        # Rows whose running max is still (effectively) -inf have seen no
        # valid key yet; keep their l_i / acc untouched for this block.
        invalid = m_ij < -1.0e5
        l_i = tl.where(invalid, l_i, l_i * alpha + l_ij)
        # -- update output accumulator --
        # update acc
        v = tl.load(
            V_block_ptr,
            boundary_check=(0,),
            padding_option="zero"
        )
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(v.dtype)
        acc = tl.where(invalid[:, None], acc, tl.dot(p, v, acc * alpha[:, None]))
        # update m_i and l_i
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if Score_block_ptr is not None:
            Score_block_ptr = tl.advance(Score_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n},
        num_stages=stages,
        num_warps=warps,
    )
    # FIXME: autotuning over the wider search space below leads to failures in
    # decoding (cause unknown), so every dimension is pinned to a single value.
    #   BLOCK_M:    [64, 128]
    #   BLOCK_N:    [32, 64]
    #   num_stages: [1] if is_hip() else [3, 4, 7]
    #   num_warps:  [4, 8]
    for block_m in (64,)
    for block_n in (32,)
    for stages in (3,)
    for warps in (4,)
]


def keep(conf):
    """Autotune filter: reject 8-warp configs whose tile is below 128 x 128.

    ``conf`` must expose ``kwargs["BLOCK_M"]``, ``kwargs["BLOCK_N"]`` and
    ``num_warps``. Returns ``False`` for rejected configs, ``True`` otherwise.
    """
    tile_area = conf.kwargs["BLOCK_M"] * conf.kwargs["BLOCK_N"]
    is_small_tile = tile_area < 128 * 128
    return not (is_small_tile and conf.num_warps == 8)


# Forward attention kernel. Each program handles one BLOCK_M tile of query
# rows (program_id(0)) for one (batch, head) pair (program_id(1)), and writes:
#   * Out    - the normalised attention output tile,
#   * M      - per-row log2-sum-exp (running max + log2 of the running sum),
#   * Scores - the scaled scores, only when a Scores buffer is provided.
# K/V heads are shared by HEAD_GROUP consecutive query heads.
@autotune(list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM"])
@triton_bwd(["Q", "K", "V"], ["Out", "M", "Scores"])
def _attn_fwd(
    Q,
    K,
    V,
    sm_scale,
    M,
    Out,  #
    Scores,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,  #
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,  #
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vk,  #
    stride_oz,
    stride_oh,
    stride_om,
    stride_ok,  #
    stride_sz,
    stride_sh,
    stride_sm,
    stride_sn,  #
    SEQ_LEN,
    stride_seq_len_bsz,
    stride_seq_len_tdst,
    Z,
    H,
    M_CTX,
    N_CTX,  #
    Q_FACTOR: tl.constexpr,  #
    KV_FACTOR: tl.constexpr,  #
    HEAD_GROUP: tl.constexpr,
    HEAD_DIM: tl.constexpr,  #
    BLOCK_M: tl.constexpr,  #
    BLOCK_N: tl.constexpr,  #
    STAGE: tl.constexpr,  #
    EXCLUDE_LAST_WINDOW: tl.constexpr,
):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # 64-bit base offsets so large tensors do not overflow int32 arithmetic.
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    k_offset = (
        off_z.to(tl.int64) * stride_kz + (off_h // HEAD_GROUP).to(tl.int64) * stride_kh
    )
    v_offset = (
        off_z.to(tl.int64) * stride_vz + (off_h // HEAD_GROUP).to(tl.int64) * stride_vh
    )
    if Scores is not None:
        score_offset = off_z.to(tl.int64) * stride_sz + off_h.to(tl.int64) * stride_sh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(M_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    # K is addressed transposed: (HEAD_DIM, N_CTX), so q @ k needs no tl.trans.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(M_CTX, HEAD_DIM),
        strides=(stride_om, stride_ok),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    if Scores is not None:
        Score_block_ptr = tl.make_block_ptr(
            base=Scores + score_offset,
            shape=(M_CTX, N_CTX),
            strides=(stride_sm, stride_sn),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_N),
            order=(0, 1),
        )
    else:
        Score_block_ptr = None
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # offs_m indexes query rows, and M is laid out with M_CTX entries per
    # (batch, head) (see m_ptrs below), so the guard must use M_CTX. The
    # previous N_CTX bound wrote out of range when N_CTX > M_CTX and skipped
    # valid rows when N_CTX < M_CTX.
    mask_m = offs_m < M_CTX
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    q = tl.load(
        Q_block_ptr,
        boundary_check=(0,),
        padding_option="zero"
    )
    # Per-query-row sequence lengths; rows past M_CTX read as 0.
    idx_tdst = (tl.arange(0, BLOCK_M) + start_m * BLOCK_M)
    seq_lens = tl.load(
        SEQ_LEN
        + off_z.to(tl.int64) * stride_seq_len_bsz
        + idx_tdst * stride_seq_len_tdst,
        mask=idx_tdst < M_CTX,
        other=0,
    )

    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,  #
            Score_block_ptr,
            start_m,
            qk_scale,  #
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,  #
            4 - STAGE,
            offs_m,
            offs_n,
            seq_lens,
            M_CTX,
            N_CTX,
            Q_FACTOR,
            KV_FACTOR,
            V.dtype.element_ty == tl.float8e5,
            EXCLUDE_LAST_WINDOW,
        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc,
            l_i,
            m_i,
            q,
            K_block_ptr,
            V_block_ptr,
            Score_block_ptr,
            start_m,
            qk_scale,
            BLOCK_M,
            HEAD_DIM,
            BLOCK_N,
            2,
            offs_m,
            offs_n,
            seq_lens,
            M_CTX,
            N_CTX,
            Q_FACTOR,
            KV_FACTOR,
            V.dtype.element_ty == tl.float8e5,
            EXCLUDE_LAST_WINDOW,
        )
    # epilogue: fold the row sum into the running max to get log2-sum-exp,
    # then normalise the accumulator.
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * M_CTX + offs_m
    tl.store(
        m_ptrs,
        m_i,
        mask=mask_m
    )
    tl.store(
        O_block_ptr,
        acc.to(Out.type.element_ty),
        boundary_check=(0,)
    )


@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, M_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr
):
    # Backward pre-pass: for a tile of BLOCK_M rows of one (batch * head)
    # slice, writes Delta[row] = sum over HEAD_DIM of O[row, :] * DO[row, :].
    # Addressing uses fixed strides (HEAD_DIM per row, HEAD_DIM * M_CTX per
    # slice) and the loads/stores are unmasked, so the launch grid has to
    # cover exactly M_CTX rows. Z and H are accepted but not read here.
    head_slice = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, HEAD_DIM)
    slice_base = head_slice * HEAD_DIM * M_CTX
    # gather the output tile and the matching upstream-gradient tile
    out_tile = tl.load(O + slice_base + rows[:, None] * HEAD_DIM + cols[None, :])
    grad_tile = tl.load(
        DO + slice_base + rows[:, None] * HEAD_DIM + cols[None, :]
    ).to(tl.float32)
    # row-wise dot product, then write-back
    row_dot = tl.sum(out_tile * grad_tile, axis=1)
    tl.store(Delta + head_slice * M_CTX + rows, row_dot)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
    dk,
    dv,  #
    Q,
    k,
    v,
    sm_scale,  #
    DO,  #
    M,
    D,  #
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    M_CTX,
    N_CTX,
    Q_FACTOR: tl.constexpr,
    KV_FACTOR: tl.constexpr,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,  #
    # Filled in by the wrapper.
    start_n,
    start_m,
    num_steps,  #
    MASK: tl.constexpr,
    EXCLUDE_LAST_WINDOW: tl.constexpr,
):
    # Accumulates into `dk` and `dv` for one already-loaded key/value tile
    # (`k`, `v`) by stepping over `num_steps` query tiles of BLOCK_M1 rows,
    # starting at row `start_m`. Returns the updated (dk, dv).
    #
    # `Q`, `DO`, `M` and `D` are pointers indexed by query row; all loads here
    # are unmasked, so the caller must keep every visited row in range.
    # `sm_scale`, `H`, `M_CTX` and `N_CTX` are accepted but not read in this
    # function.
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    # Q is addressed transposed: the tile is (HEAD_DIM, BLOCK_M1).
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        qT = tl.load(qT_ptrs)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = dot(k, qT)
        # Recompute the (transposed) probabilities from the values stored in M.
        pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            # Query/key indices are compared after scaling by Q_FACTOR and
            # KV_FACTOR; EXCLUDE_LAST_WINDOW switches to a strict comparison
            # on the 1-based indices.
            if EXCLUDE_LAST_WINDOW:
                mask = (offs_m[None, :] + 1) * Q_FACTOR > (
                    offs_n[:, None] + 1
                ) * KV_FACTOR
            else:
                mask = offs_m[None, :] * Q_FACTOR >= offs_n[:, None] * KV_FACTOR
            pT = tl.where(mask, pT, 0.0)
        do = tl.load(do_ptrs)
        # Compute dV.
        ppT = pT
        ppT = ppT.to(do.dtype)
        dv += dot(ppT, do)
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(qT.dtype)
        dk += dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        qT_ptrs += step_m * stride_tok
        do_ptrs += step_m * stride_tok
    return dk, dv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
    dq,
    q,
    K,
    V,  #
    do,
    m,
    D,
    # shared by Q/K/V/DO.
    stride_tok,
    stride_d,  #
    H,
    M_CTX,
    N_CTX,  #
    Q_FACTOR,
    KV_FACTOR,
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,
    # Filled in by the wrapper.
    start_m,
    start_n,
    num_steps,  #
    MASK: tl.constexpr,
    EXCLUDE_LAST_WINDOW: tl.constexpr,
):
    # Accumulates into `dq` for one already-loaded query tile (`q`, `do`, `m`)
    # by stepping over `num_steps` key/value tiles of BLOCK_N2 columns,
    # starting at column `start_n`. Returns the updated dq.
    #
    # `K`, `V` and `D` are pointers; all loads here are unmasked, so the
    # caller must keep every visited index in range. `H`, `M_CTX` and `N_CTX`
    # are accepted but not read in this function.
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    # Guards against a negative starting column computed by the caller.
    tl.device_assert(offs_n >= 0, "offs_n < 0: message anno_6 for assistance")
    offs_k = tl.arange(0, HEAD_DIM)
    # K and V are addressed transposed: tiles are (HEAD_DIM, BLOCK_N2).
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        kT = tl.load(kT_ptrs)
        vT = tl.load(vT_ptrs)
        qk = dot(q, kT)
        # Recompute the probabilities from the per-row values passed in `m`.
        p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            # Same index comparison as in _attn_bwd_dkdv, in (row, col) layout.
            if EXCLUDE_LAST_WINDOW:
                mask = (offs_m[:, None] + 1) * Q_FACTOR > (
                    offs_n[None, :] + 1
                ) * KV_FACTOR
            else:
                mask = offs_m[:, None] * Q_FACTOR >= offs_n[None, :] * KV_FACTOR
            p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        dp = dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        if MASK:
            # Mask again so masked entries are exactly zero even if the
            # product above produced a non-finite value.
            ds = tl.where(mask, ds, 0.0)
        ds = ds.to(kT.dtype)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        kT_ptrs += step_n * stride_tok
        vT_ptrs += step_n * stride_tok
    return dq


@triton.jit
def _attn_bwd(
    Q,
    K,
    V,
    sm_scale,  #
    DO,  #
    DQ,
    DK,
    DV,  #
    M,
    D,
    # shared by Q/K/V/DO.
    stride_z,
    stride_h,
    stride_tok,
    stride_d,  #
    stride_kz,
    stride_kh,
    stride_ktok,
    stride_kd,  #
    H,
    HEAD_GROUP,
    M_CTX,
    N_CTX,  #
    Q_FACTOR: tl.constexpr,
    KV_FACTOR: tl.constexpr,
    BLOCK_M1: tl.constexpr,  #
    BLOCK_N1: tl.constexpr,  #
    BLOCK_M2: tl.constexpr,  #
    BLOCK_N2: tl.constexpr,  #
    BLK_SLICE_FACTOR: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,
    EXCLUDE_LAST_WINDOW: tl.constexpr,
):
    # Backward attention kernel. Each program (program_id(0) = pid,
    # program_id(2) = batch * head index) does two things:
    #   1. computes dK/dV for the key/value tile [pid * BLOCK_N1, +BLOCK_N1)
    #      and accumulates them into DK/DV with atomic adds;
    #   2. computes dQ for the query tile [pid * BLOCK_M2, +BLOCK_M2) and
    #      stores it into DQ.
    # The device asserts below require the grid to tile N_CTX by BLOCK_N1 and
    # M_CTX by BLOCK_M2 exactly; all loads and stores are unmasked.
    LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * M_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    # K/V (and DK/DV) use their own strides, and HEAD_GROUP consecutive query
    # heads map onto the same key/value head.
    adj_kv = (stride_kh * ((bhid % H) // HEAD_GROUP) + stride_kz * (bhid // H)).to(
        tl.int64
    )
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj_kv
    V += adj_kv
    DO += adj
    DQ += adj
    DK += adj_kv
    DV += adj_kv
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, HEAD_DIM)

    # ---- Part 1: dK / dV for this program's key/value tile ----
    start_n = pid * BLOCK_N1
    # First query row paired with this key tile after Q_FACTOR/KV_FACTOR scaling.
    start_m = (start_n * KV_FACTOR) // Q_FACTOR
    tl.device_assert(start_n + BLOCK_N1 <= N_CTX, "start_n + BLOCK_N1 >= N_CTX")
    tl.device_assert(
        tl.num_programs(0) * BLOCK_N1 == N_CTX, "num_programs(0) * BLOCK_N1 != N_CTX"
    )
    tl.device_assert(
        (start_n * KV_FACTOR) % Q_FACTOR == 0,
        "start_n * KV_FACTOR % Q_FACTOR != 0: message anno_6 for assistance",
    )

    # The masked (diagonal) region is processed with smaller query tiles.
    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_ktok + offs_k[None, :] * stride_kd)
    v = tl.load(V + offs_n[:, None] * stride_ktok + offs_k[None, :] * stride_kd)

    num_steps = (BLOCK_N1 * KV_FACTOR // Q_FACTOR) // MASK_BLOCK_M1
    tl.device_assert(
        (BLOCK_N1 * KV_FACTOR) % Q_FACTOR == 0,
        "BLOCK_N1 * KV_FACTOR % Q_FACTOR != 0: message anno_6 for assistance",
    )

    # Iterate in query direction
    dk, dv = _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        k,
        v,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        M_CTX,
        N_CTX,  #
        Q_FACTOR,
        KV_FACTOR,
        MASK_BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=True,
        EXCLUDE_LAST_WINDOW=EXCLUDE_LAST_WINDOW,
    )

    # Continue with the remaining query rows, which are handled unmasked.
    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (M_CTX - start_m) // BLOCK_M1

    # Compute dK and dV for non-masked blocks.
    dk, dv = _attn_bwd_dkdv(  #
        dk,
        dv,  #
        Q,
        k,
        v,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        stride_tok,
        stride_d,  #
        H,
        M_CTX,
        N_CTX,  #
        Q_FACTOR,
        KV_FACTOR,
        BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,  #
        start_n,
        start_m,
        num_steps,  #
        MASK=False,  #
        EXCLUDE_LAST_WINDOW=EXCLUDE_LAST_WINDOW,
    )

    # Atomic accumulation: several programs (one per query head in a head
    # group) can target the same DV/DK locations.
    dv_ptrs = DV + offs_n[:, None] * stride_ktok + offs_k[None, :] * stride_kd
    tl.atomic_add(dv_ptrs, dv.to(DV.dtype.element_ty))

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_ktok + offs_k[None, :] * stride_kd
    tl.atomic_add(dk_ptrs, dk.to(DK.dtype.element_ty))

    # ---- Part 2: dQ for this program's query tile ----
    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    # One past the last key column paired with this query tile after scaling.
    end_n = (start_m + BLOCK_M2) * Q_FACTOR // KV_FACTOR
    tl.device_assert(start_m + BLOCK_M2 <= M_CTX, "start_m + BLOCK_M2 >= M_CTX")
    tl.device_assert(
        tl.num_programs(0) * BLOCK_M2 == M_CTX, "num_programs(0) * BLOCK_M2 != M_CTX"
    )
    tl.device_assert(
        (start_m * Q_FACTOR) % KV_FACTOR == 0,
        "start_m * Q_FACTOR % KV_FACTOR != 0: message anno_6 for assistance",
    )

    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m = tl.load(M + offs_m)
    m = m[:, None]

    # Compute dQ for masked (diagonal) blocks.
    # Iterate in Key direction.
    # NOTE: This code scans each row of QK^T backward (from right to left,
    # but inside each call to _attn_bwd_dq, from left to right), but that's
    # not due to anything important.  I just wanted to reuse the loop
    # structure for dK & dV above as much as possible.
    num_steps = ((BLOCK_M2 * Q_FACTOR) // KV_FACTOR) // MASK_BLOCK_N2
    tl.device_assert(
        (BLOCK_M2 * Q_FACTOR) % KV_FACTOR == 0,
        "BLOCK_M2 * Q_FACTOR % KV_FACTOR != 0: message anno_6 for assistance",
    )
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        do,
        m,
        D,  #
        stride_ktok,
        stride_kd,  #
        H,
        M_CTX,
        N_CTX,  #
        Q_FACTOR,
        KV_FACTOR,
        BLOCK_M2,
        MASK_BLOCK_N2,
        HEAD_DIM,  #
        start_m,
        end_n - num_steps * MASK_BLOCK_N2,
        num_steps,  #
        MASK=True,
        EXCLUDE_LAST_WINDOW=EXCLUDE_LAST_WINDOW,
    )
    end_n -= num_steps * MASK_BLOCK_N2
    # stage 2: the key columns to the left of the masked region, unmasked.
    num_steps = end_n // BLOCK_N2
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        do,
        m,
        D,  #
        stride_ktok,
        stride_kd,  #
        H,
        M_CTX,
        N_CTX,  #
        Q_FACTOR,
        KV_FACTOR,
        BLOCK_M2,
        BLOCK_N2,
        HEAD_DIM,  #
        start_m,
        end_n - num_steps * BLOCK_N2,
        num_steps,  #
        MASK=False,
        EXCLUDE_LAST_WINDOW=EXCLUDE_LAST_WINDOW,
    )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # Multiply by ln(2) to undo the 1/ln(2) factor carried by the pre-scaled K
    # (see the NOTE in _attn_bwd_dq).
    dq *= LN2
    tl.store(dq_ptrs, dq)


def forward(
    q,
    k,
    v,
    seq_lens,
    causal,
    sm_scale,
    RETURN_SCORES=False,
    EXCLUDE_LAST_WINDOW=False,
    q_factor=1,
    kv_factor=1,
    USE_TMA=False,
    use_torch_fwd=False,
):
    """Allocate outputs and launch the ``_attn_fwd`` kernel.

    ``q``, ``k`` and ``v`` are indexed as 4-D tensors (strides 0..3 are
    passed to the kernel); dimension 2 is used as the sequence length and the
    last dimension as the head size, which must match across all three and be
    one of 16, 32, 64, 128 or 256. ``seq_lens`` must be 2-D.

    Returns ``(o, M, scores)`` as produced by ``_attn_fwd.forward``;
    ``scores`` is ``None`` unless ``RETURN_SCORES`` is set. ``USE_TMA`` is
    accepted but not used here.
    """
    # shape constraints
    head_dim = k.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    assert q.shape[-1] == head_dim and head_dim == v.shape[-1]
    assert head_dim in {16, 32, 64, 128, 256}

    batch, n_heads, m_ctx = q.shape[0], q.shape[1], q.shape[2]
    n_ctx = k.shape[2]

    o = torch.zeros_like(q)
    if causal:
        stage = 3
    else:
        stage = 1

    # Tuning for AMD target
    extra_kern_args = (
        {
            "waves_per_eu": 3 if head_dim <= 64 else 2,
            "allow_flush_denorm": True,
        }
        if is_hip()
        else {}
    )

    if RETURN_SCORES:
        # (B, H, q_seq, k_seq), pre-filled with -inf
        scores = o.new_full(
            (batch, n_heads, m_ctx, n_ctx),
            float("-inf"),
            dtype=torch.float32,
        )
        score_strides = (
            scores.stride(0),
            scores.stride(1),
            scores.stride(2),
            scores.stride(3),
        )
    else:
        scores = None
        score_strides = (None, None, None, None)

    M = torch.empty(
        (batch, n_heads, m_ctx),
        device=q.device,
        dtype=torch.float32,
    )

    assert seq_lens.ndim == 2

    def grid(args):
        # one program per (query tile, batch * head)
        return (
            triton.cdiv(m_ctx, args["BLOCK_M"]),
            batch * n_heads,
            1,
        )

    o, M, scores = _attn_fwd.forward(
        grid,
        q,
        k,
        v,
        sm_scale,
        M,
        o,
        scores,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k.stride(3),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        v.stride(3),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        *score_strides,
        seq_lens,
        seq_lens.stride(0),
        seq_lens.stride(1),
        batch,
        n_heads,
        M_CTX=m_ctx,
        N_CTX=n_ctx,
        Q_FACTOR=q_factor,
        KV_FACTOR=kv_factor,
        HEAD_GROUP=n_heads // k.shape[1],
        HEAD_DIM=head_dim,
        STAGE=stage,
        EXCLUDE_LAST_WINDOW=EXCLUDE_LAST_WINDOW,
        use_torch_fwd=use_torch_fwd,
        **extra_kern_args,
    )

    return o, M, scores


class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton attention forward/backward kernels."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        seq_lens,
        causal,
        sm_scale,
        RETURN_SCORES,
        EXCLUDE_LAST_WINDOW,
        q_factor,
        kv_factor,
        USE_TMA,
    ):
        """Run the forward kernel and stash what ``backward`` needs.

        Returns ``(o, scores)``; ``scores`` is ``None`` unless
        ``RETURN_SCORES`` is set.
        """
        o, M, scores = forward(
            q,
            k,
            v,
            seq_lens,
            causal,
            sm_scale,
            RETURN_SCORES,
            EXCLUDE_LAST_WINDOW,
            q_factor,
            kv_factor,
            USE_TMA,
        )

        if TRITON_DEBUG:
            assert not torch.any(torch.isinf(o))
            if torch.any(torch.isnan(o)):
                print(torch.isnan(o).nonzero())
                assert False, "o has nan"
            if scores is not None:
                # -inf is a legitimate value in scores (masked positions),
                # so only NaN is checked here.
                assert not torch.any(torch.isnan(scores))

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = k.shape[-1]
        ctx.causal = causal
        ctx.EXCLUDE_LAST_WINDOW = EXCLUDE_LAST_WINDOW
        ctx.Q_FACTOR = q_factor
        ctx.KV_FACTOR = kv_factor
        return o, scores

    @staticmethod
    def backward(ctx, do, dscores):
        """Compute float32 gradients for q, k and v.

        ``dscores`` is ignored. Requires ``q``, ``o`` and ``do`` to share
        strides, ``k`` and ``v`` to share strides, and ``M_CTX`` to be a
        multiple of 128.
        """
        do = do.contiguous()
        q, k, v, o, M = ctx.saved_tensors
        assert do.is_contiguous()
        assert q.stride() == o.stride() == do.stride()
        assert k.stride() == v.stride()
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.zeros_like(k, dtype=torch.float32)
        dv = torch.zeros_like(v, dtype=torch.float32)
        BATCH, N_HEAD, M_CTX = q.shape[:3]
        N_CTX = k.shape[2]
        NUM_WARPS, NUM_STAGES = 4, 2
        if q.dtype == torch.float32:
            BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
        else:
            BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
        # The kernel uses one grid for both the K/V tiles and the Q tiles, so
        # N_CTX / BLOCK_N1 must equal M_CTX / BLOCK_M2. Build two candidate
        # pairs (derive BLOCK_M2 from BLOCK_N1, or BLOCK_N1 from BLOCK_M2) and
        # keep the one with the smaller tile area.
        BLOCK_N1_alt, BLOCK_M2_alt = N_CTX * BLOCK_M2 // M_CTX, BLOCK_M2
        BLOCK_M2 = M_CTX * BLOCK_N1 // N_CTX
        if BLOCK_N1_alt * BLOCK_M2_alt < BLOCK_N1 * BLOCK_M2:
            BLOCK_N1, BLOCK_M2 = BLOCK_N1_alt, BLOCK_M2_alt
        BLOCK_M1 = min(BLOCK_N1, BLOCK_M1)
        BLOCK_N2 = min(BLOCK_N2, BLOCK_M2)
        assert BLOCK_N1 % BLOCK_M1 == 0
        assert BLOCK_M2 % BLOCK_N2 == 0
        assert M_CTX // BLOCK_M2 == N_CTX // BLOCK_N1
        BLK_SLICE_FACTOR = 2
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        # Pre-scale K so the kernel can work in base-2 exponentials.
        arg_k = k * (ctx.sm_scale * RCP_LN2)

        # TODO(anno_0): is this min ok? sometimes in testing M_CTX was 64 for some reason
        # TODO(hj): I think this should be a constant. If we need to pass the non 128 divisible sequence length, then we need to pad
        # PRE_BLOCK = min(128, M_CTX)
        PRE_BLOCK = 128
        assert M_CTX % PRE_BLOCK == 0, f"{M_CTX} % {PRE_BLOCK}"
        pre_grid = (M_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)
        _attn_bwd_preprocess[pre_grid](
            o,
            do,
            delta,
            BATCH,
            N_HEAD,
            M_CTX,
            BLOCK_M=PRE_BLOCK,
            HEAD_DIM=ctx.HEAD_DIM,
        )
        grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
        _attn_bwd[grid](
            q,
            arg_k,
            v,
            ctx.sm_scale,
            do,
            dq,
            dk,
            dv,
            M,
            delta,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            N_HEAD,
            q.shape[1] // k.shape[1],
            M_CTX,
            N_CTX,
            ctx.Q_FACTOR,
            ctx.KV_FACTOR,
            BLOCK_M1=BLOCK_M1,
            BLOCK_N1=BLOCK_N1,
            BLOCK_M2=BLOCK_M2,
            BLOCK_N2=BLOCK_N2,
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
            HEAD_DIM=ctx.HEAD_DIM,
            EXCLUDE_LAST_WINDOW=ctx.EXCLUDE_LAST_WINDOW,
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES,
        )

        if TRITON_DEBUG:
            assert not torch.any(torch.isnan(dq))
            assert not torch.any(torch.isnan(dk))
            assert not torch.any(torch.isnan(dv))
            assert not torch.any(torch.isinf(dq))
            assert not torch.any(torch.isinf(dk))
            assert not torch.any(torch.isinf(dv))

        # autograd requires exactly one gradient per forward() input. forward
        # takes 11 inputs after ctx (q, k, v, seq_lens, causal, sm_scale,
        # RETURN_SCORES, EXCLUDE_LAST_WINDOW, q_factor, kv_factor, USE_TMA),
        # so 8 Nones follow dq/dk/dv. Returning only 10 values made autograd
        # raise "returned an incorrect number of gradients".
        return dq, dk, dv, None, None, None, None, None, None, None, None


def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_lens: torch.Tensor,
    causal: bool,
    sm_scale: Optional[float] = None,
    USE_TMA: bool = False,
    EXCLUDE_LAST_WINDOW: bool = False,
    RETURN_SCORES: bool = False,
    q_factor=1,
    kv_factor=1,
):
    """Differentiable attention entry point backed by the Triton kernels.

    Args:
        q, k, v: query/key/value tensors; the last dimension is the head size.
        seq_lens: 2-D tensor of per-query-row sequence lengths used for
            masking in the forward kernel.
        causal: selects the two-stage (masked) forward path when true.
        sm_scale: multiplier applied to the raw ``q @ k`` scores. Defaults to
            ``head_dim ** -0.5``.
        USE_TMA: forwarded to the autograd function.
        EXCLUDE_LAST_WINDOW: use the strict variant of the mask comparison.
        RETURN_SCORES: also return the scaled attention scores.
        q_factor, kv_factor: index scaling factors used by the backward mask.

    Returns:
        ``output`` alone, or ``(output, scores)`` when scores were produced.
    """
    if sm_scale is None:
        # The kernels multiply the scores by sm_scale, so the conventional
        # softmax scaling is 1 / sqrt(head_dim). The previous default of
        # head_dim ** 0.5 scaled the logits up instead of down.
        sm_scale = q.shape[-1] ** -0.5

    output, scores = _attention.apply(
        q,
        k,
        v,
        seq_lens,
        causal,
        sm_scale,
        RETURN_SCORES,
        EXCLUDE_LAST_WINDOW,
        q_factor,
        kv_factor,
        USE_TMA,
    )

    if scores is not None:
        return output, scores
    return output
