# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Triton FlashAttention implementation with learnable bias softmax support.

This is a modified version of the Flash Attention v2 implementation that supports
learnable bias softmax as used in the torchtitan codebase.
"""

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def _attn_fwd_learnable_bias(
    Q,
    K,
    V,
    BiasParams,
    sm_scale,
    M,
    Out,
    QKRowMax,
    Entropy,
    Start_q,
    Z,
    H,
    N_Q_TRUE,
    N_Q_CTX,
    N_KV_CTX,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_BIAS: tl.constexpr,
    USE_METRICS: tl.constexpr,
):
    """Causal attention forward pass with an optional per-head denominator bias.

    Each program handles one BLOCK_M tile of query rows (program_id(0)) for one
    flattened (batch, head) index (program_id(1)). It streams over key/value
    tiles of BLOCK_N columns with an online (running max / running sum) softmax.

    Q, K, V and Out are accessed through ``.load`` / ``.store`` with
    ``[batch, head, row, 0]`` offsets; M, QKRowMax and Entropy are plain
    pointers indexed as ``off_hz * N_Q_CTX + row``.

    With USE_BIAS the output is ``sum_j exp(qk_j) v_j / (sum_j exp(qk_j) + b)``
    where ``b`` is read from ``BiasParams[head]``. M always receives
    ``log(sum_j exp(qk_j))`` (bias excluded). Rows at or beyond N_Q_TRUE are
    treated as padding: their output is zero and their M / metric entries are
    not written. Z is accepted but not referenced.
    """
    start_q = tl.load(Start_q).to(tl.int32)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    # Load learnable bias parameter for this head
    if USE_BIAS:
        bias_param = tl.load(BiasParams + off_h).to(tl.float32)
    else:
        bias_param = 0.0

    # Row / column offsets of this tile; rows >= N_Q_TRUE are padding
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    valid_q = offs_m < N_Q_TRUE

    # Running max and running sum (avoid -inf to prevent inf deltas on first tile)
    m_i = tl.full([BLOCK_M], -1e20, dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    if USE_METRICS:
        t_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # sum e^(..) * (qk - m)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    # Load the query tile and zero the padded rows
    qk_scale = sm_scale
    q = (
        Q.load([off_z, off_h, start_m * BLOCK_M, 0])
        .reshape([BLOCK_M, HEAD_DIM])
        .to(tl.float32)
    )
    q = q * valid_q[:, None].to(tl.float32)

    # Early exit if this entire tile is padding on the query side
    if not ((start_m + 1) * BLOCK_M <= N_Q_TRUE):
        if start_m * BLOCK_M >= N_Q_TRUE:
            return

    # Compute attention range (causal masking)
    lo = 0  # not start_q
    hi = tl.minimum(start_q + (start_m + 1) * BLOCK_M, N_KV_CTX)

    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # Causal mask: True where the key position lies after the query position
        mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None]

        # Mask out padded KV tail to avoid contributing to denominator
        kv_mask = (start_n + offs_n) < N_KV_CTX

        k = (
            K.load([off_z, off_h, start_n, 0])
            .reshape([BLOCK_N, HEAD_DIM])
            .to(tl.float32)
            .T
        )
        qk = tl.dot(q, k, allow_tf32=True)

        qk = qk * qk_scale
        # Apply causal mask without tl.where to avoid slowdowns
        qk = qk + mask.to(tl.float32) * (-1e20)
        # Mask out padded KV tail without tl.where
        kv_mask_f32 = kv_mask.to(tl.float32)[None, :]
        qk = qk * kv_mask_f32 + (1.0 - kv_mask_f32) * (-1e20)

        # new running max for each row
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        delta = m_ij - m_i
        alpha = tl.math.exp(-delta)  # = exp(m_i - m_ij)

        # shift to new max and exponentiate
        qk = qk - m_ij[:, None]
        p = tl.math.exp(qk)
        l_ij = tl.sum(p, 1)

        # ---- entropy running update ----
        if USE_METRICS:
            t_i = alpha * (t_i - delta * l_i) + tl.sum(p * qk, 1)

        # ---- value acc / denominator updates ----
        acc = acc * alpha[:, None]

        v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
        v = v.to(tl.float32)
        acc = tl.dot(p, v, acc, allow_tf32=True)

        l_i = l_i * alpha + l_ij
        m_i = m_ij

    # Normalize. acc = sum_j exp(qk_j - m_i) * v_j and l_i = sum_j exp(qk_j - m_i),
    # so the biased output is acc / (l_i + b * exp(-m_i)).
    if USE_BIAS:
        full_denominator = l_i + bias_param * tl.math.exp(-m_i)
    else:
        # Standard softmax normalization
        full_denominator = l_i
    # Padded rows divide by 1 so both branches use the same guarded denominator
    full_denominator = tl.where(valid_q, full_denominator, 1.0)
    acc = acc / full_denominator[:, None]

    # Zero out invalid rows to avoid writing garbage
    acc = tl.where(valid_q[:, None], acc, 0.0)

    # Store the log-normalizer log(sum_exp). The bias is deliberately excluded
    # in both modes so backward can rebuild it from the stored value.
    m_store = m_i + tl.math.log(l_i)
    m_ptrs = M + off_hz * N_Q_CTX + offs_m
    tl.store(m_ptrs, m_store, mask=valid_q)

    # Metrics: optionally store row-wise max and entropy
    if USE_METRICS:
        tl.store(QKRowMax + off_hz * N_Q_CTX + offs_m, m_i, mask=valid_q)
        # Match reference metric: denom = sum_exp + bias (no exp(-m) scaling)
        eps = 1e-20
        denom_entropy = l_i + (bias_param if USE_BIAS else 0.0)
        denom_entropy = tl.maximum(denom_entropy, eps)
        # Named row_entropy so the head-count argument H is not overwritten
        row_entropy = tl.math.log(denom_entropy) - t_i / denom_entropy
        tl.store(Entropy + off_hz * N_Q_CTX + offs_m, row_entropy, mask=valid_q)

    acc = acc.to(Out.dtype)[None, None, :, :]
    Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


@triton.jit
def _attn_bwd_preprocess(
    O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr
):
    """Backward preprocessing: Delta[row] = sum over HEAD_DIM of O[row] * dO[row].

    program_id(0) selects a BLOCK_M tile of rows and program_id(1) the flattened
    (batch, head) index. O and DO are addressed as dense ``[*, N_CTX, HEAD_DIM]``
    buffers and loads are unmasked. Z and H are accepted but not referenced.
    """
    head_idx = tl.program_id(1)
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, HEAD_DIM)

    # Offsets of the [BLOCK_M, HEAD_DIM] tile inside one (batch, head) slab
    tile = rows[:, None] * HEAD_DIM + cols[None, :]

    out_tile = tl.load(O + head_idx * HEAD_DIM * N_CTX + tile)
    grad_tile = tl.load(DO + head_idx * HEAD_DIM * N_CTX + tile).to(tl.float32)

    # Row-wise dot product of the output with its incoming gradient
    row_dot = tl.sum(out_tile * grad_tile, axis=1)

    tl.store(Delta + head_idx * N_CTX + rows, row_dot)


@triton.jit
def _attn_bwd_dkdv_learnable_bias(
    dk,
    dv,
    Q,
    k,
    v,
    sm_scale,
    DO,
    M,
    D,
    BiasParams,
    stride_tok,
    stride_d,
    H,
    N_CTX,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    start_n,
    start_m,
    num_steps,
    off_h,
    USE_BIAS: tl.constexpr,
    MASK: tl.constexpr,
):
    """Accumulate dK and dV for one key/value tile over a range of query tiles.

    ``k`` and ``v`` are already-loaded tiles for key rows
    ``start_n .. start_n + BLOCK_N1``; ``dk`` and ``dv`` are the running
    accumulators, which are updated and returned. The loop visits ``num_steps``
    query tiles of BLOCK_M1 rows starting at ``start_m``.

    Q, DO, M and D are pointers that the caller has presumably already offset
    to the current (batch, head) slice -- confirm against the caller. All loads
    are unmasked, so the visited rows must be in bounds. H and N_CTX are
    accepted but not referenced. The returned ``dk`` is not yet multiplied by
    ``sm_scale``.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)

    # Load bias parameter if needed
    if USE_BIAS:
        bias_param = tl.load(BiasParams + off_h).to(tl.float32)
    else:
        bias_param = 0.0

    # qT_ptrs addresses a transposed [HEAD_DIM, BLOCK_M1] query tile,
    # do_ptrs a [BLOCK_M1, HEAD_DIM] tile of the output gradient.
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d

    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1

    for blk_idx in range(num_steps):
        qT = tl.load(qT_ptrs).to(tl.float32)
        offs_m = curr_m + tl.arange(0, BLOCK_M1)
        # Per-query-row value stored by the forward pass (log of the softmax sum)
        m = tl.load(M + offs_m)

        # Recompute the scaled scores, transposed: [BLOCK_N1, BLOCK_M1]
        qkT = tl.dot(k, qT, allow_tf32=False)
        qkT = qkT * sm_scale

        # std numerators
        pT = tl.math.exp(qkT - m[None, :])

        if USE_BIAS:
            # stable alpha = 1 / (1 + b * exp(-m))
            # Rescales standard softmax probabilities to exp(qk) / (sum_exp + b)
            alpha = 1.0 / (1.0 + bias_param * tl.math.exp(-m)[None, :])
            pT = pT * alpha  # sT

        # Apply causal mask: keep entries whose query index >= key index
        if MASK:
            mask = offs_m[None, :] >= offs_n[:, None]
            pT = tl.where(mask, pT, 0.0)

        do = tl.load(do_ptrs).to(tl.float32)

        # dV
        pT_f32 = pT.to(tl.float32)
        dv += tl.dot(pT_f32, do, allow_tf32=False)

        # dpT = V @ dO^T
        Di = tl.load(D + offs_m).to(tl.float32)
        v_f32 = v.to(tl.float32)
        dpT = tl.dot(v_f32, tl.trans(do), allow_tf32=True)

        # dsT = sT * (dpT - delta), with delta loaded per query row as Di
        dsT_f32 = pT_f32 * (dpT - Di[None, :])

        # dK accum (sm_scale is applied later once)
        qT_f32 = qT.to(tl.float32)
        dk += tl.dot(dsT_f32, tl.trans(qT_f32), allow_tf32=True)

        # Advance to the next query tile
        curr_m += step_m
        qT_ptrs += step_m * stride_tok
        do_ptrs += step_m * stride_tok

    return dk, dv


@triton.jit
def _attn_bwd_dq_learnable_bias(
    dq,
    Q,
    K,
    V,
    sm_scale,
    DO,
    M,
    D,
    BiasParams,
    stride_tok,
    stride_d,
    H,
    N_CTX,
    BLOCK_M2: tl.constexpr,
    BLOCK_N2: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    start_m,
    start_n,
    num_steps,
    off_h,
    USE_BIAS: tl.constexpr,
    MASK: tl.constexpr,
):
    """Accumulate dQ for one query tile over a range of key/value tiles.

    The query and output-gradient tiles for rows ``start_m .. start_m + BLOCK_M2``
    are loaded here, together with the matching M and D rows. The loop then
    visits ``num_steps`` key/value tiles of BLOCK_N2 rows starting at
    ``start_n``. ``dq`` is the running accumulator, updated and returned
    without the ``sm_scale`` factor.

    Q, K, V, DO, M and D are pointers that the caller has presumably already
    offset to the current (batch, head) slice -- confirm against the caller.
    All loads are unmasked. H and N_CTX are accepted but not referenced.
    """
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)

    # Load bias parameter if needed
    if USE_BIAS:
        bias_param = tl.load(BiasParams + off_h).to(tl.float32)
    else:
        bias_param = 0.0

    # Load q and grad_out tiles, each [BLOCK_M2, HEAD_DIM]
    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d).to(
        tl.float32
    )
    grad_out = tl.load(
        DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    ).to(tl.float32)

    # Load per-row log-sum-exp m and make it a [BLOCK_M2, 1] column
    m = tl.load(M + offs_m)
    m = m[:, None]

    # K and V are addressed transposed, as [HEAD_DIM, BLOCK_N2] tiles, so they
    # can be used directly as the right-hand operand of the dot products below.
    kT_ptrs = K + offs_k[:, None] * stride_d + offs_n[None, :] * stride_tok
    vT_ptrs = V + offs_k[:, None] * stride_d + offs_n[None, :] * stride_tok

    # Per-row delta values for this query tile
    Di = tl.load(D + offs_m).to(tl.float32)

    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    curr_n = start_n
    step_n = BLOCK_N2

    for blk_idx in range(num_steps):
        kT = tl.load(kT_ptrs).to(tl.float32)  # [HEAD_DIM, BLOCK_N2]
        vT = tl.load(vT_ptrs).to(tl.float32)  # [HEAD_DIM, BLOCK_N2]

        # qk: [BLOCK_M2, HEAD_DIM] @ [HEAD_DIM, BLOCK_N2]
        qk = tl.dot(q, kT, allow_tf32=True)

        # Scale by attention scale factor
        qk = qk * sm_scale

        p = tl.math.exp(qk - m)  # std numerators

        if USE_BIAS:
            # Rescales standard softmax probabilities to exp(qk) / (sum_exp + b)
            alpha = 1.0 / (1.0 + bias_param * tl.math.exp(-m))  # m is [BLOCK_M2,1]
            p = p * alpha  # s

        # Apply causal mask: keep entries whose query index >= key index
        if MASK:
            offs_n = curr_n + tl.arange(0, BLOCK_N2)
            mask = offs_m[:, None] >= offs_n[None, :]
            p = tl.where(mask, p, 0.0)

        # dp = dO @ V^T
        dp = tl.dot(grad_out, vT, allow_tf32=True)

        # ds = s * (dp - delta), with delta loaded per query row as Di
        p_f32 = p.to(tl.float32)
        ds_f32 = p_f32 * (dp - Di[:, None])

        # dQ accumulation; sm_scale is not applied inside this helper
        dq += tl.dot(ds_f32, tl.trans(kT), allow_tf32=True)

        # Advance to the next key/value tile
        curr_n += step_n
        kT_ptrs += step_n * stride_tok
        vT_ptrs += step_n * stride_tok

    return dq


@triton.jit
def _attn_bwd_dbias_learnable_bias(
    DBias,
    O,
    DO,
    M,
    BiasParams,
    stride_z,
    stride_h,
    stride_tok,
    stride_d,
    Z,
    H,
    N_CTX,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    """Reduce the bias gradient for one (batch, head) pair into ``DBias[pair]``.

    One program per flattened (batch, head) index walks all N_CTX rows in tiles
    of BLOCK_M and accumulates ``-sum_i delta_i / (exp(M_i) + b)`` with
    ``delta_i = sum(O_i * dO_i)`` and ``b = BiasParams[head]``. Z is accepted
    but not referenced.
    """
    pid_bh = tl.program_id(0)
    batch_idx = pid_bh // H
    head_idx = pid_bh % H

    # Move the base pointers to this (batch, head) slice
    slice_off = (stride_h * head_idx + stride_z * batch_idx).to(tl.int64)
    o_base = O + slice_off
    do_base = DO + slice_off
    lse_base = M + (pid_bh * N_CTX).to(tl.int64)

    bias_param = tl.load(BiasParams + head_idx).to(tl.float32)

    cols = tl.arange(0, HEAD_DIM)
    total = tl.zeros([], dtype=tl.float32)

    for row_start in range(0, N_CTX, BLOCK_M):
        rows = row_start + tl.arange(0, BLOCK_M)
        in_bounds = rows < N_CTX

        row_off = rows[:, None] * stride_tok
        col_off = cols[None, :] * stride_d

        out_tile = tl.load(o_base + row_off + col_off, mask=in_bounds[:, None])
        grad_tile = tl.load(
            do_base + row_off + col_off, mask=in_bounds[:, None]
        ).to(tl.float32)
        lse = tl.load(lse_base + rows, mask=in_bounds)

        # delta_i = <O_i, dO_i> for every query row in the tile
        row_dot = tl.sum(out_tile * grad_tile, axis=1)

        # dL/db = -sum_i delta_i / (S_i + b) with S_i = exp(M_i)
        inv_denominator = 1.0 / (tl.math.exp(lse) + bias_param)
        total += tl.sum(tl.where(in_bounds, -row_dot * inv_denominator, 0.0))

    # One partial result per (batch, head) pair
    tl.store(DBias + pid_bh, total)


@triton.jit
def _attn_bwd_learnable_bias(
    Q,
    K,
    V,
    sm_scale,
    DO,
    DQ,
    DK,
    DV,
    DBias,
    M,
    D,
    BiasParams,
    stride_z,
    stride_h,
    stride_tok,
    stride_d,
    H,
    N_CTX,
    BLOCK_M1: tl.constexpr,
    BLOCK_N1: tl.constexpr,
    BLOCK_M2: tl.constexpr,
    BLOCK_N2: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """Main backward kernel for learnable bias attention.

    program_id(2) is the flattened (batch, head) index and program_id(0) the
    block index ``pid``. Each program first produces dK/dV for key rows
    ``pid * BLOCK_N1 ..`` and then dQ for query rows ``pid * BLOCK_M2 ..``.
    Both phases split their work into a diagonal part that needs the causal
    mask (smaller tiles, ``MASK=True``) and an off-diagonal part that does not.

    All tile loads and stores are unmasked. ``DBias`` is accepted but not
    written here. The causal mask is built from raw row/column indices only.
    """

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)
    off_h = bhid % H

    # Offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    offs_k = tl.arange(0, HEAD_DIM)
    start_n = pid * BLOCK_N1
    start_m = start_n

    MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # Load K and V
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d).to(
        tl.float32
    )
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d).to(
        tl.float32
    )

    # Query tiles overlapping the diagonal, visited in MASK_BLOCK_M1-row steps
    num_steps = BLOCK_N1 // MASK_BLOCK_M1

    # Compute dK and dV for masked blocks
    dk, dv = _attn_bwd_dkdv_learnable_bias(
        dk,
        dv,
        Q,
        k,
        v,
        sm_scale,
        DO,
        M,
        D,
        BiasParams,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        MASK_BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,
        start_n,
        start_m,
        num_steps,
        off_h,
        USE_BIAS,
        MASK=True,
    )

    # Remaining query rows lie strictly below the diagonal block
    start_m += num_steps * MASK_BLOCK_M1
    num_steps = (N_CTX - start_m) // BLOCK_M1

    # Compute dK and dV for non-masked blocks
    dk, dv = _attn_bwd_dkdv_learnable_bias(
        dk,
        dv,
        Q,
        k,
        v,
        sm_scale,
        DO,
        M,
        D,
        BiasParams,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,
        start_n,
        start_m,
        num_steps,
        off_h,
        USE_BIAS,
        MASK=False,
    )

    # Store dV
    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Store dK (the helper accumulates it without sm_scale)
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # Compute dQ
    start_m = pid * BLOCK_M2
    end_n = start_m + BLOCK_M2

    MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    # The dQ helper loads its own q / dO / M / D tiles, so only the
    # accumulator is created here.
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)

    # Compute dQ for masked blocks
    num_steps = BLOCK_M2 // MASK_BLOCK_N2
    dq = _attn_bwd_dq_learnable_bias(
        dq,
        Q,
        K,
        V,
        sm_scale,
        DO,
        M,
        D,
        BiasParams,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=MASK_BLOCK_N2,
        HEAD_DIM=HEAD_DIM,
        start_m=start_m,
        start_n=end_n - num_steps * MASK_BLOCK_N2,
        num_steps=num_steps,
        off_h=off_h,
        USE_BIAS=USE_BIAS,
        MASK=True,
    )

    # Everything left of the diagonal block needs no mask
    end_n -= num_steps * MASK_BLOCK_N2
    num_steps = end_n // BLOCK_N2

    # Compute dQ for non-masked blocks
    dq = _attn_bwd_dq_learnable_bias(
        dq,
        Q,
        K,
        V,
        sm_scale,
        DO,
        M,
        D,
        BiasParams,
        stride_tok,
        stride_d,
        H,
        N_CTX,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=HEAD_DIM,
        start_m=start_m,
        start_n=end_n - num_steps * BLOCK_N2,
        num_steps=num_steps,
        off_h=off_h,
        USE_BIAS=USE_BIAS,
        MASK=False,
    )

    # Apply scaling once after all dQ computations
    dq = dq * sm_scale

    # Store dQ (no LN2 scaling needed for base-e)
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dq_ptrs, dq)


class _learnable_bias_attention(torch.autograd.Function):  # noqa: N801
    """Autograd wrapper around the Triton learnable-bias attention kernels.

    Inputs use the ``(batch, seq, heads, head_dim)`` layout. Internally the
    tensors are transposed to ``(batch, heads, seq, head_dim)`` and padded
    along the sequence axis to a multiple of 128, which the backward kernels
    require.
    """

    @staticmethod
    def forward(
        ctx, q, k, v, bias_params, sm_scale, start_q, use_metrics: bool = False
    ):
        """Run the forward kernel.

        Args:
            q, k, v: ``(batch, seq, heads, head_dim)`` tensors; head_dim must be
                one of 16/32/64/128/256 and q/k/v must share the head count.
            bias_params: per-head bias tensor, or ``None`` for plain softmax.
            sm_scale: multiplier applied to the q.k scores.
            start_q: one-element tensor holding the query position offset.
            use_metrics: when true, also compute row max and entropy.

        Returns:
            ``(o, qk_rowmax, entropy)`` where ``o`` is
            ``(batch, heads, n_ctx, head_dim)`` and the other two are
            ``(batch, heads, n_ctx)``. When ``use_metrics`` is false the last
            two are views of the internal log-sum-exp buffer, not metrics.
        """
        assert len(start_q) == 1
        bs, n_ctx, n_heads, HEAD_DIM_Q = q.shape
        bs, n_kv_ctx, n_kv_heads, HEAD_DIM_K = k.shape
        bs, n_kv_ctx, n_kv_heads, HEAD_DIM_V = v.shape

        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        assert n_heads == n_kv_heads, "GQA not yet supported in this implementation"

        q = q.transpose(1, 2).contiguous()
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()

        # Use fixed tile sizes for forward
        BLOCK_M = 64
        BLOCK_N = 128
        # Backward kernels require N_CTX to be divisible by 128; pad Q/K/V to that.
        target_ctx = ((n_ctx + 127) // 128) * 128
        m_pad_size = target_ctx - n_ctx
        # Pad q to target_ctx in the n_ctx dimension (-2)
        if m_pad_size > 0:
            q = F.pad(q, (0, 0, 0, m_pad_size))
        # Pad k and v to the same target_ctx to satisfy backward tile shapes
        kv_pad_size = target_ctx - n_kv_ctx
        if kv_pad_size > 0:
            k = F.pad(k, (0, 0, 0, kv_pad_size))
            v = F.pad(v, (0, 0, 0, kv_pad_size))

        # Zero-initialise o and M. The forward kernel never writes M for padded
        # query rows and skips fully padded tiles entirely, yet the backward
        # kernels read every padded row without a mask. Uninitialised NaN/Inf
        # there would turn into NaN gradients (NaN * 0) for real rows.
        o = torch.zeros_like(q)
        M = torch.zeros(
            (bs, n_heads, target_ctx), device=q.device, dtype=torch.float32
        )

        # Launch enough blocks to cover the padded query length
        grid = (triton.cdiv(target_ctx, BLOCK_M), bs * n_heads, 1)

        use_bias = bias_params is not None

        # Prepare metric buffers only if requested; otherwise reuse M as dummy to avoid allocation
        if use_metrics:
            QKRowMax = torch.empty_like(M)
            Entropy = torch.empty_like(M)
        else:
            QKRowMax = M
            Entropy = M

        _attn_fwd_learnable_bias[grid](
            TensorDescriptor.from_tensor(q, [1, 1, BLOCK_M, HEAD_DIM_K]),
            TensorDescriptor.from_tensor(k, [1, 1, BLOCK_N, HEAD_DIM_K]),
            TensorDescriptor.from_tensor(v, [1, 1, BLOCK_N, HEAD_DIM_K]),
            bias_params if use_bias else None,
            sm_scale,
            M,
            TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, HEAD_DIM_K]),
            QKRowMax,
            Entropy,
            start_q,
            q.shape[0],  # Z
            q.shape[1],  # H
            n_ctx,  # N_Q_TRUE
            N_Q_CTX=target_ctx,
            # Only compute over the true KV context; padded tail is masked out
            N_KV_CTX=n_kv_ctx,
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            USE_BIAS=use_bias,
            USE_METRICS=use_metrics,
        )

        ctx.save_for_backward(q, k, v, bias_params, o, M, start_q)
        ctx.sm_scale = sm_scale
        ctx.use_bias = use_bias
        ctx.n_ctx = n_ctx
        # Needed to un-pad dK/dV to the key/value length, which may differ from n_ctx
        ctx.n_kv_ctx = n_kv_ctx
        ctx.bs = bs
        ctx.n_heads = n_heads
        ctx.HEAD_DIM_V = HEAD_DIM_V

        # slice to original length
        o_unpadded = o[:, :, :n_ctx, :]
        qk_rowmax = QKRowMax[:, :, :n_ctx]  # (bs, n_heads, n_ctx)
        entropy = Entropy[:, :, :n_ctx]  # (bs, n_heads, n_ctx)

        # these are metrics only – not part of autograd
        if use_metrics:
            ctx.mark_non_differentiable(qk_rowmax, entropy)

        return o_unpadded, qk_rowmax, entropy

    @staticmethod
    def backward(ctx, do, *_ignored_metric_grads):
        """Compute gradients for q, k, v and (optionally) the bias.

        Returns one entry per ``forward`` argument; ``sm_scale``, ``start_q``
        and ``use_metrics`` receive ``None``.

        NOTE(review): ``start_q`` is not passed to the backward kernels, whose
        causal mask uses raw row/column indices -- gradients presumably assume
        ``start_q == 0``; confirm before using a non-zero offset in training.

        Raises:
            NotImplementedError: if the saved key/value length differs from the
                padded query length (key/value sequence longer than the padded
                query sequence), which the backward kernels cannot address.
        """
        q, k, v, bias_params, o, M, start_q = ctx.saved_tensors
        sm_scale = ctx.sm_scale
        use_bias = ctx.use_bias

        # Ensure contiguous gradient and pad to the padded sequence length used in forward
        do = do.contiguous()

        BATCH, N_HEAD, N_CTX_PAD = q.shape[:3]
        HEAD_DIM = q.shape[-1]
        n_ctx_orig = ctx.n_ctx
        n_kv_ctx_orig = ctx.n_kv_ctx

        # The kernels index K/V/dK/dV with q's strides and N_CTX, so all
        # sequence axes must have been padded to the same length.
        if k.shape[2] != N_CTX_PAD or v.shape[2] != N_CTX_PAD:
            raise NotImplementedError(
                "backward requires the key/value length not to exceed the "
                "padded query length"
            )

        if do.shape[2] < N_CTX_PAD:
            pad_len = N_CTX_PAD - do.shape[2]
            do_padded = F.pad(do, (0, 0, 0, pad_len))
        else:
            do_padded = do

        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        dbias = torch.empty_like(bias_params) if use_bias else None

        N_CTX = N_CTX_PAD
        PRE_BLOCK = 128
        NUM_WARPS, NUM_STAGES = 4, 5
        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
        BLK_SLICE_FACTOR = 2

        # Preprocess: compute delta = sum(O * dO, dim=-1)
        assert N_CTX % PRE_BLOCK == 0
        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)

        _attn_bwd_preprocess[pre_grid](
            o,
            do_padded,
            delta,
            BATCH,
            N_HEAD,
            N_CTX,
            BLOCK_M=PRE_BLOCK,
            HEAD_DIM=HEAD_DIM,
        )

        # Main backward pass (base-e throughout, so no RCP_LN2 scaling of k)
        grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
        _attn_bwd_learnable_bias[grid](
            q,
            k,
            v,
            sm_scale,
            do_padded,
            dq,
            dk,
            dv,
            dbias,
            M,
            delta,
            bias_params if use_bias else None,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            N_HEAD,
            N_CTX,
            BLOCK_M1=BLOCK_M1,
            BLOCK_N1=BLOCK_N1,
            BLOCK_M2=BLOCK_M2,
            BLOCK_N2=BLOCK_N2,
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
            HEAD_DIM=HEAD_DIM,
            USE_BIAS=use_bias,
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES,
        )

        # Compute bias gradient separately if needed
        if use_bias:
            # One partial gradient per (batch, head) pair, reduced over batch below
            bias_grad_temp = torch.zeros(
                BATCH * N_HEAD, device=q.device, dtype=torch.float32
            )

            bias_grid = (BATCH * N_HEAD,)
            _attn_bwd_dbias_learnable_bias[bias_grid](
                bias_grad_temp,
                o,
                do_padded,
                M,
                bias_params,
                q.stride(0),
                q.stride(1),
                q.stride(2),
                q.stride(3),
                BATCH,
                N_HEAD,
                N_CTX,
                BLOCK_M=PRE_BLOCK,
                HEAD_DIM=HEAD_DIM,
            )

            # Sum across batch dimension to get per-head gradients; copy_
            # overwrites every element of dbias, so no prior zeroing is needed.
            bias_grad_temp = bias_grad_temp.view(BATCH, N_HEAD)
            dbias.copy_(bias_grad_temp.sum(dim=0))

        # Unpad and transpose gradients back to match original input shapes.
        # dK/dV are cut to the key/value length, dQ to the query length.
        dq_out = dq[:, :, :n_ctx_orig, :].transpose(1, 2).contiguous()
        dk_out = dk[:, :, :n_kv_ctx_orig, :].transpose(1, 2).contiguous()
        dv_out = dv[:, :, :n_kv_ctx_orig, :].transpose(1, 2).contiguous()

        return dq_out, dk_out, dv_out, dbias, None, None, None


# Functional alias for the autograd Function's ``apply``; takes the same
# positional arguments as ``_learnable_bias_attention.forward`` (without ctx).
learnable_bias_attention = _learnable_bias_attention.apply


def triton_attention_with_bias(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bias_params: torch.Tensor = None,
    sm_scale: float = None,
    start_q: int = 0,
    return_metrics: bool = False,  # NEW
):
    """
    Run the Triton attention kernels, optionally with a learnable softmax bias.

    Args:
        query: Query tensor (bs, seqlen, n_heads, head_dim)
        key: Key tensor (bs, seqlen, n_heads, head_dim)
        value: Value tensor (bs, seqlen, n_heads, head_dim)
        bias_params: Optional per-head bias (n_heads,); moved to the query
            device and cast to float32 before use
        sm_scale: Score multiplier; defaults to 1 / sqrt(head_dim)
        start_q: Position offset of the first query
        return_metrics: When true, also return the per-row metrics

    Returns:
        The attention output of shape (bs, seqlen, n_heads * head_dim). With
        ``return_metrics`` a tuple ``(output, metrics)`` is returned instead,
        where ``metrics`` has the keys "qk_row_max" and "entropy", each of
        shape (bs, n_heads, seqlen).
    """
    scale = sm_scale if sm_scale is not None else 1.0 / (query.shape[-1] ** 0.5)

    # The kernels read the bias as a float32 buffer on the query's device
    bias = bias_params
    if bias is not None:
        bias = bias.to(query.device, dtype=torch.float32).contiguous()

    # The forward kernel loads the query offset from device memory
    start_q_tensor = torch.tensor([start_q], dtype=torch.int32, device=query.device)

    attn_out, row_max, row_entropy = _learnable_bias_attention.apply(
        query, key, value, bias, scale, start_q_tensor, return_metrics
    )

    # (bs, n_heads, seqlen, head_dim) -> (bs, seqlen, n_heads * head_dim)
    batch, heads, seq_len, dim = attn_out.shape
    flat_out = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, heads * dim)

    if not return_metrics:
        return flat_out
    return flat_out, {"qk_row_max": row_max, "entropy": row_entropy}


# Reference implementation for testing
def attention_ref_with_bias(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bias_params: torch.Tensor = None,
    sm_scale: float = None,
    start_q: int = 0,
    return_metrics: bool = False,
):
    """
    Reference implementation of causal attention with learnable bias softmax.

    Args:
        query: Query tensor (bs, num_queries, n_heads, head_dim)
        key: Key tensor (bs, num_keys, n_heads, head_dim)
        value: Value tensor (bs, num_keys, n_heads, head_dim)
        bias_params: Optional per-head bias (n_heads,) added to the softmax
            denominator; ``None`` selects the standard softmax
        sm_scale: Score multiplier; defaults to 1 / sqrt(head_dim)
        start_q: Position of the first query; key ``j`` is visible to query
            ``i`` when ``j <= i + start_q``
        return_metrics: When true, also return the metrics dict

    Returns:
        Output of shape (bs, num_queries, n_heads * head_dim) in the dtype of
        ``query``. If return_metrics is True, a tuple ``(output, metrics)``
        where metrics holds detached tensors:
        - "qk_row_max": tensor of shape (bs, n_heads, seqlen)
        - "entropy": tensor of shape (bs, n_heads, seqlen)
    """
    batch_size, num_queries, num_heads, head_dim = query.shape
    num_keys = key.shape[1]

    if sm_scale is None:
        sm_scale = 1.0 / (head_dim**0.5)

    # Causal mask: True where the key lies after the (offset) query position
    pos_keys = torch.arange(num_keys, device=query.device)
    pos_queries = torch.arange(num_queries, device=query.device) + start_q
    future = pos_keys[None, :] > pos_queries[:, None]

    logits = torch.einsum("bqhd,bkhd->bhqk", query.float(), key.float()) * sm_scale
    logits = logits.masked_fill(future[None, None, :, :], float("-inf"))

    # Apply learnable bias softmax
    if bias_params is not None:
        # bias_params shape: (n_heads,) -> (1, n_heads, 1, 1)
        bias_expanded = bias_params.view(1, -1, 1, 1)
        # exp(l) / (b + sum exp(l)) is invariant to multiplying numerator and
        # denominator by exp(-shift). Using shift = max(row_max, 0) keeps every
        # exponent <= 0, so large logits no longer overflow to inf / inf = NaN.
        # The shift is detached because the value does not depend on it.
        shift = logits.detach().amax(dim=-1, keepdim=True).clamp_min(0.0)
        exp_logits = torch.exp(logits - shift)
        sum_exp = torch.sum(exp_logits, dim=-1, keepdim=True)
        denominator = bias_expanded * torch.exp(-shift) + sum_exp
        scores = exp_logits / denominator
    else:
        scores = torch.softmax(logits, dim=-1)

    output = torch.einsum("bhqk,bkhd->bqhd", scores, value.float())
    output = output.reshape(batch_size, num_queries, num_heads * head_dim)
    output = output.to(query.dtype)

    # Skip the metric computation entirely when it is not requested
    if not return_metrics:
        return output

    # --- Metrics for parity with Triton (returned detached, so no graph needed) ---
    with torch.no_grad():
        # qk_row_max: max over keys per (b, h, q); masked entries are -inf
        qk_row_max = logits.max(dim=-1).values  # (bs, n_heads, seqlen)

        # Stable numerators for entropy: p = exp(logits - m)
        centered = logits - qk_row_max[..., None]
        p = torch.exp(centered)  # (bs, n_heads, seqlen, num_keys)
        sum_exp = p.sum(dim=-1)  # (bs, n_heads, seqlen)
        # zero out masked contributions to avoid 0 * -inf -> NaN
        centered = torch.where(
            torch.isfinite(centered), centered, torch.zeros_like(centered)
        )
        t_i = (p * centered).sum(dim=-1)  # (bs, n_heads, seqlen)

        if bias_params is not None:
            # bias per head, broadcast to (bs, n_heads, seqlen); added to the
            # max-shifted sum without rescaling, matching the Triton metric
            full_denominator = sum_exp + bias_params.view(1, -1, 1)
        else:
            full_denominator = sum_exp

        eps = 1e-20
        safe_denominator = full_denominator.clamp_min(eps)
        entropy = torch.log(safe_denominator) - t_i / safe_denominator

    return output, {
        "qk_row_max": qk_row_max.detach(),
        "entropy": entropy.detach(),
    }
