import math
import torch
import triton
import triton.language as tl
import time
from .quant_utils import triton_quantize_and_pack_along_last_dim
from .quant_utils import unpack_and_dequant_vcache
from .utils import repeat_kv
# from triton_quant import triton_quantize_and_pack_along_last_dim
# from quant_utils import unpack_and_dequant_vcache
# from utils import repeat_kv




# _stage1a_configs = [
#     triton.Config({'KV_CHUNK': 32},  num_warps=2, num_stages=2),
#     triton.Config({'KV_CHUNK': 32},  num_warps=4, num_stages=2),
#     triton.Config({'KV_CHUNK': 32},  num_warps=8, num_stages=2),
#     triton.Config({'KV_CHUNK': 32},  num_warps=4, num_stages=3),
#     triton.Config({'KV_CHUNK': 32},  num_warps=4, num_stages=4),

#     triton.Config({'KV_CHUNK': 64},  num_warps=2, num_stages=2),
#     triton.Config({'KV_CHUNK': 64},  num_warps=4, num_stages=2),
#     triton.Config({'KV_CHUNK': 64},  num_warps=8, num_stages=2),
#     triton.Config({'KV_CHUNK': 64},  num_warps=4, num_stages=3),
#     triton.Config({'KV_CHUNK': 64},  num_warps=4, num_stages=4),
#     triton.Config({'KV_CHUNK': 64},  num_warps=8, num_stages=3),

#     triton.Config({'KV_CHUNK': 128}, num_warps=2, num_stages=2),
#     triton.Config({'KV_CHUNK': 128}, num_warps=4, num_stages=2),
#     triton.Config({'KV_CHUNK': 128}, num_warps=8, num_stages=2),
#     triton.Config({'KV_CHUNK': 128}, num_warps=4, num_stages=3),
#     triton.Config({'KV_CHUNK': 128}, num_warps=8, num_stages=3),
# ]
# @triton.autotune(configs=_stage1a_configs, key=['D'])
@triton.jit
def _kvlinc_stage_1a(
    Q_ptr, K_ptr, V_ptr,
    M_ptr, L_ptr, OSUM_ptr,
    BH: tl.int32, S: tl.int32, D: tl.int32, Dv: tl.int32, NB: tl.int32,
    Hq: tl.int32, Hkv: tl.int32, R: tl.int32,
    q_stride_bh: tl.int32, q_stride_d: tl.int32,
    k_stride_bh: tl.int32, k_stride_s: tl.int32, k_stride_d: tl.int32,
    v_stride_bh: tl.int32, v_stride_s: tl.int32, v_stride_d: tl.int32,
    m_stride_bh: tl.int32, m_stride_nb: tl.int32,
    l_stride_bh: tl.int32, l_stride_nb: tl.int32,
    os_stride_bh: tl.int32, os_stride_nb: tl.int32, os_stride_dv: tl.int32,
    softmax_scale,
    BLOCK_KV: tl.constexpr, KV_CHUNK: tl.constexpr,
    BLOCK_D: tl.constexpr, BLOCK_DV: tl.constexpr,
):
    # Stage 1a: partial (un-normalised) softmax attention of one query row
    # against one block of un-quantised keys/values.
    #
    # Program (pid_blk, pid_bh) covers key rows
    # [pid_blk * BLOCK_KV, (pid_blk + 1) * BLOCK_KV), walked in pieces of
    # KV_CHUNK rows, and stores three partials for that block:
    #   M[pid_bh, pid_blk]       running maximum of the scaled scores
    #   L[pid_bh, pid_blk]       sum of exp(score - max)
    #   OSUM[pid_bh, pid_blk, :] sum of exp(score - max) * V row
    # Normalisation is left to a later merge step.
    #
    # All arithmetic below is done in float32: Q, K and V are cast with
    # .to(tl.float32) right after loading and the accumulators are float32.
    # BH and NB are accepted but not read inside this kernel.
    pid_blk = tl.program_id(0)
    pid_bh  = tl.program_id(1)
    
    # Grouped-query attention: split the flat query index into (batch, query
    # head), map the query head to its KV head (R query heads per KV head),
    # and rebuild the flat index used to address K and V.
    b = pid_bh // Hq
    h_q = pid_bh % Hq
    h_kv = h_q // R
    pid_bh_kv = b * Hkv + h_kv

    # exp2(x * LOG2E) == exp(x); used for every exponential below.
    LOG2E = 1.4426950408889634

    # Load the query row and promote it to float32.
    # dmask is computed but not applied: the load always reads BLOCK_D
    # elements.
    # NOTE(review): this looks safe only when D == BLOCK_D -- confirm with
    # the launcher.
    offs_d = tl.arange(0, BLOCK_D)
    dmask  = offs_d < D
    q = tl.load(Q_ptr + pid_bh * q_stride_bh + offs_d * q_stride_d).to(tl.float32)

    # Running online-softmax state (float32): max, exp-sum, weighted V sum.
    m    = tl.full((), -float("inf"), dtype=tl.float32)
    l    = tl.full((), 0.0,            dtype=tl.float32)
    offs_dv = tl.arange(0, BLOCK_DV)
    # dv_mask is likewise computed but never applied to a load or store.
    dv_mask = offs_dv < Dv
    osum = tl.zeros([BLOCK_DV], dtype=tl.float32)

    blk_start   = pid_blk * BLOCK_KV
    neg_inf_vec = tl.full([KV_CHUNK], -float("inf"), dtype=tl.float32)

    for k0 in range(0, BLOCK_KV, KV_CHUNK):
        rel     = k0 + tl.arange(0, KV_CHUNK)
        kv_idx  = blk_start + rel
        # A row is valid if it lies inside this block and inside the sequence.
        kv_mask = (rel < BLOCK_KV) & (kv_idx < S)

        # K tile [KV_CHUNK, BLOCK_D]; invalid rows are read as 0. Only the row
        # axis is masked, the feature axis is not.
        k_ptrs = (K_ptr + pid_bh_kv * k_stride_bh
                         + kv_idx[:, None] * k_stride_s
                         + offs_d[None, :] * k_stride_d)
        k_tile = tl.load(k_ptrs, mask=kv_mask[:, None], other=0).to(tl.float32)

        # Scaled dot-product scores; invalid rows are forced to -inf so they
        # contribute exp(-inf) = 0 to the sums.
        s_tile = (tl.sum(k_tile * q[None, :], axis=1) * softmax_scale)
        s_tile = tl.where(kv_mask, s_tile, neg_inf_vec)

        # Softmax numerator of this chunk relative to the chunk's own maximum.
        # NOTE(review): if a chunk holds no valid row, m_t is -inf and
        # (s_tile - m_t) evaluates -inf - (-inf); callers appear to rely on
        # every chunk having at least one valid row -- confirm.
        m_t  = tl.max(s_tile, axis=0)
        e    = tl.exp2(((s_tile - m_t) * LOG2E))
        l_t  = tl.sum(e, axis=0)

        # V tile [KV_CHUNK, BLOCK_DV]; same row-only masking as K.
        v_ptrs = (V_ptr + pid_bh_kv * v_stride_bh
                         + kv_idx[:, None] * v_stride_s
                         + offs_dv[None, :] * v_stride_d)
        v_tile = tl.load(v_ptrs, mask=kv_mask[:, None], other=0).to(tl.float32)

        # Weighted V sum of this chunk, still relative to m_t.
        o_t = tl.sum((v_tile * e[:, None]), axis=0)

        # Online-softmax merge: rescale the running state and the chunk
        # result to the common maximum m_new before adding them.
        m_new = tl.maximum(m, m_t)
        alpha = tl.exp2(((m   - m_new) * LOG2E))
        beta  = tl.exp2(((m_t - m_new) * LOG2E))

        osum = (osum * alpha + o_t * beta)
        l    = (l    * alpha + l_t * beta)
        m    = m_new

    # Store the block partials; the OSUM store writes all BLOCK_DV lanes
    # (no dv_mask).
    tl.store(M_ptr + pid_bh * m_stride_bh + pid_blk * m_stride_nb, m)
    tl.store(L_ptr + pid_bh * l_stride_bh + pid_blk * l_stride_nb, l)
    os_ptrs = (OSUM_ptr + pid_bh * os_stride_bh
                          + pid_blk * os_stride_nb
                          + offs_dv * os_stride_dv)
    tl.store(os_ptrs, osum)

# _stage1b_configs = [
#     triton.Config({}, num_warps=2, num_stages=3),
#     triton.Config({}, num_warps=4, num_stages=3),
#     triton.Config({}, num_warps=4, num_stages=4),
#     triton.Config({}, num_warps=8, num_stages=3),
#     triton.Config({}, num_warps=2, num_stages=2),
#     triton.Config({}, num_warps=4, num_stages=2),
#     triton.Config({}, num_warps=8, num_stages=2),
# ]

# @triton.autotune(configs=_stage1b_configs, key=['D', 'QBITS'])
@triton.jit
def _kvlinc_stage_1b(
    Q_ptr, K_ptr, Ksc_ptr, Kmn_ptr,
    V_ptr, Vsc_ptr, Vmn_ptr,
    M_ptr, L_ptr, OSUM_ptr,
    BH: tl.int32, S: tl.int32, D: tl.int32, NB: tl.int32,
    Hq: tl.int32, Hkv: tl.int32, R: tl.int32,
    PACK: tl.int32, QBITS: tl.int32, GROUP_SIZE: tl.int32,
    q_stride_bh: tl.int32, q_stride_d: tl.int32,
    k_stride_bh: tl.int32, k_stride_d: tl.int32, k_stride_s: tl.int32,
    k_sc_stride_bh: tl.int32, k_sc_stride_d:tl.int32, k_sc_stride_g: tl.int32,
    k_mn_stride_bh: tl.int32, k_mn_stride_d:tl.int32, k_mn_stride_g: tl.int32,
    v_stride_bh: tl.int32, v_stride_s: tl.int32, v_stride_d: tl.int32,
    v_sc_stride_bh: tl.int32, v_sc_stride_s: tl.int32, v_sc_stride_g: tl.int32,
    v_mn_stride_bh: tl.int32, v_mn_stride_s: tl.int32, v_mn_stride_g: tl.int32,
    m_stride_bh: tl.int32, m_stride_nb: tl.int32,
    l_stride_bh: tl.int32, l_stride_nb: tl.int32,
    os_stride_bh: tl.int32, os_stride_nb: tl.int32, os_stride_dv: tl.int32,
    softmax_scale,
    BLOCK_KV: tl.constexpr,
    BLOCK_D: tl.constexpr, 
):  
    # Stage 1b: partial (un-normalised) softmax attention of one query row
    # against one block of bit-packed, group-quantised keys/values.
    #
    # Program (pid_blk, pid_bh) handles the BLOCK_KV key rows starting at
    # pid_blk * BLOCK_KV as a single tile and stores the tile's maximum
    # score, exp-sum and exp-weighted V sum to M / L / OSUM.
    #
    # Packing, as read by the code below:
    #   * each packed word holds PACK values of QBITS bits; a value is
    #     extracted with (word >> shift) & ((1 << QBITS) - 1)
    #   * K is packed along the key axis  (word index = kv_idx // PACK)
    #   * V is packed along the feature axis (word index = offs_dv // PACK)
    #   * dequantised value = unpacked * scale + min
    #
    # BH, NB and the running-state variables m, l, osum are not used by the
    # computation; the per-tile results m_t, l_t, o_t are stored directly.

    # ---- Static / alignment hints (let compiler vectorize) ----
    tl.static_assert(BLOCK_D == 128)
    # The return values of the hints below are not assigned.
    # NOTE(review): confirm the hints take effect when used this way.
    tl.multiple_of(k_stride_s, 16)
    tl.multiple_of(v_stride_d, 16)
    tl.multiple_of(q_stride_d, 16)
    pid_blk = tl.program_id(0)
    pid_bh  = tl.program_id(1)

    # Grouped-query attention: map the flat (batch, query head) index to the
    # flat (batch, KV head) index; R query heads share one KV head.
    b = pid_bh // Hq
    h_q = pid_bh % Hq
    h_kv = h_q // R
    pid_bh_kv = b * Hkv + h_kv

    # exp2(x * LOG2E) == exp(x)
    LOG2E = 1.4426950408889634

    # Load the query row as float32. dmask is computed but not applied, so
    # BLOCK_D elements are always read.
    # NOTE(review): looks safe only when D == BLOCK_D -- confirm.
    offs_d = tl.arange(0, BLOCK_D)
    dmask  = offs_d < D
    q = tl.load(Q_ptr + pid_bh * q_stride_bh + offs_d * q_stride_d).to(tl.float32)
    
    # Leftover running state from a multi-chunk variant; not read below.
    m    = tl.full((), -float("inf"), dtype=tl.float32)
    l    = tl.full((), 0.0,            dtype=tl.float32)
    offs_dv = tl.arange(0, BLOCK_D)
    dv_mask = offs_dv < D
    osum = tl.zeros([BLOCK_D], dtype=tl.float32)

    blk_start   = pid_blk * BLOCK_KV
    # Bit mask selecting the low QBITS bits of a shifted word.
    mask_u32 = tl.full((), (1 << QBITS) - 1, tl.uint32) 
    kv_idx  = blk_start + tl.arange(0, BLOCK_KV)
    # kv_mask is only used to clamp the packed-word index (sp_safe); the
    # scores and the V / scale / min loads below are not masked.
    # NOTE(review): this looks correct only when S is a multiple of BLOCK_KV
    # -- confirm with the launcher.
    kv_mask = (kv_idx < S)

    # Packed-word index and bit offset of every key row inside its word.
    sp      = kv_idx // PACK                              # [BLOCK_KV]
    shifts  = ((kv_idx % PACK) * QBITS)                   # [BLOCK_KV]
    sp_safe = tl.where(kv_mask, sp, 0)

    # === K: per-feature scale / min, then packed words [BLOCK_KV, BLOCK_D] ===
    # A single quantisation group (gk) is used for the whole block, i.e. the
    # block is assumed to start on a group boundary and to lie inside one
    # group.
    gk = (blk_start) // GROUP_SIZE
    k_sc_ptrs = (Ksc_ptr + pid_bh_kv * k_sc_stride_bh
                            + offs_d * k_sc_stride_d 
                            + gk * k_sc_stride_g)
    k_mn_ptrs = (Kmn_ptr + pid_bh_kv * k_mn_stride_bh 
                            + offs_d * k_mn_stride_d
                            + gk * k_mn_stride_g)

    k_sc_vec = tl.load(k_sc_ptrs) # [BLOCK_D]
    k_mn_vec = tl.load(k_mn_ptrs)  # [BLOCK_D]

    k_ptrs = (K_ptr + pid_bh_kv * k_stride_bh
                        + offs_d[None, :] * k_stride_d
                        + sp_safe[:, None] * k_stride_s)
    k_pack = tl.load(k_ptrs)

    # Extract the QBITS-wide integer of every (key row, feature) pair.
    k_unpack = ((k_pack >> shifts[:, None]) & mask_u32).to(tl.float32)        # [BLOCK_KV, BLOCK_D]

    # Fold the dequantisation into the query instead of materialising K:
    #   sum_d q[d] * (k_unpack * sc[d] + mn[d])
    #     == sum_d (k_unpack * (sc[d] * q[d]) + mn[d] * q[d])
    q_sc = (k_sc_vec * q).to(tl.float32)
    q_mn  = (k_mn_vec * q).to(tl.float32)   
    s_tile = tl.sum(tl.fma(k_unpack, q_sc[None, :], q_mn[None, :]), axis=1)  # [BLOCK_KV]
    s_tile = s_tile * softmax_scale
    
    # Softmax numerator of the tile relative to its own maximum (float32).
    m_t  = tl.max(s_tile, axis=0)
    e    = tl.exp2(((s_tile - m_t) * LOG2E))
    l_t  = tl.sum(e, axis=0)
    
    # === V: per-(row, group) scale / min, then packed words ===
    # Quantisation groups for V run along the feature axis.
    gcol = offs_dv // GROUP_SIZE

    v_sc_ptrs = (Vsc_ptr + pid_bh_kv * v_sc_stride_bh
                    + kv_idx[:, None] * v_sc_stride_s
                    + gcol[None, :] * v_sc_stride_g)
    v_mn_ptrs = (Vmn_ptr + pid_bh_kv * v_mn_stride_bh
                    + kv_idx[:, None] * v_mn_stride_s
                    + gcol[None, :] * v_mn_stride_g)
    
    # Unmasked loads (see the note on kv_mask above).
    v_sc_vec = tl.load(v_sc_ptrs).to(tl.float32)
    v_mn_vec = tl.load(v_mn_ptrs).to(tl.float32)

    
    
    # Packed-word index and bit offset of every feature inside its word.
    gv = (offs_dv // PACK)
    shifts_v = ((offs_dv % PACK) * QBITS)

    v_ptrs = (V_ptr + pid_bh_kv * v_stride_bh
                + kv_idx[:, None] * v_stride_s
                + gv[None, :] * v_stride_d)
    
    v_pack = tl.load(v_ptrs)

    v_unpack = ((v_pack >> shifts_v[None, :]) & mask_u32).to(tl.float32)   

    # Dequantise V explicitly: unpacked * scale + min.
    v_tile = tl.fma(v_unpack, v_sc_vec, v_mn_vec)  # [BLOCK_KV, BLOCK_D]

    # Exp-weighted V sum of the tile, relative to m_t.
    o_t = tl.sum((v_tile * e[:, None]), axis=0)

    # The block is a single tile, so there is no online merge here: the tile
    # results are the block partials.
    tl.store(M_ptr + pid_bh * m_stride_bh + pid_blk * m_stride_nb, m_t)
    tl.store(L_ptr + pid_bh * l_stride_bh + pid_blk * l_stride_nb, l_t)
    os_ptrs = (OSUM_ptr + pid_bh * os_stride_bh
                          + pid_blk * os_stride_nb
                          + offs_dv * os_stride_dv)
    # All BLOCK_D lanes are written (no dv_mask).
    tl.store(os_ptrs, o_t)

# _stage2_configs = [
#     triton.Config({}, num_warps=2, num_stages=2),
#     triton.Config({}, num_warps=4, num_stages=2),
#     triton.Config({}, num_warps=4, num_stages=3),
#     triton.Config({}, num_warps=8, num_stages=2),
# ]

# @triton.autotune(configs=_stage2_configs, key=['Dv'])
@triton.jit
def _flashdec_stage2_online_bf16(
    M_ptr, L_ptr, OSUM_ptr, Out_ptr, Aln_ptr, Sum_ptr,
    BH: tl.int32, NB: tl.int32, Dv: tl.int32,
    m_stride_bh: tl.int32, m_stride_nb: tl.int32,
    l_stride_bh: tl.int32, l_stride_nb: tl.int32,
    os_stride_bh: tl.int32, os_stride_nb: tl.int32, os_stride_dv: tl.int32,
    out_stride_bh: tl.int32, out_stride_dv: tl.int32,
    aln_stride_bh: tl.int32, aln_stride_dv: tl.int32,
    sumln_stride_bh: tl.int32, BLOCK_DV: tl.constexpr,
):
    # Stage 2: merge the NB per-block partials (M, L, OSUM) of one query row
    # into the final output.
    #
    # Grid axis 0 selects a BLOCK_DV-wide slice of the output features, grid
    # axis 1 selects the flat (batch, head) row. The merge accumulates in
    # float32 and the result is written as bfloat16:
    #   Out = (sum_b w_b * OSUM_b + Aln) / (sum_b w_b * L_b + Sum)
    # with w_b = exp(M_b - max_b M_b).
    # BH is accepted but not read inside this kernel.
    pid_bh = tl.program_id(1)
    dv_blk = tl.program_id(0)
    offs_dv = dv_blk * BLOCK_DV + tl.arange(0, BLOCK_DV)
    dv_mask = offs_dv < Dv

    # Running online-softmax state across blocks.
    m    = tl.full((), -float("inf"), dtype=tl.float32)
    l    = tl.full((), 0.0,            dtype=tl.float32)
    osum = tl.zeros([BLOCK_DV], dtype=tl.float32)

    # exp2(x * LOG2E) == exp(x)
    LOG2E = 1.4426950408889634

    for b in range(0, NB):
        m_i = tl.load(M_ptr + pid_bh * m_stride_bh + b * m_stride_nb)
        l_i = tl.load(L_ptr + pid_bh * l_stride_bh + b * l_stride_nb)

        os_ptrs = (OSUM_ptr + pid_bh * os_stride_bh
                              + b * os_stride_nb
                              + offs_dv * os_stride_dv)
        os_i = tl.load(os_ptrs, mask=dv_mask, other=0)

        # Rescale both the running state and block b to their common
        # maximum before adding them.
        m_new = tl.maximum(m, m_i)
        alpha = tl.exp2(((m   - m_new) * LOG2E))
        beta  = tl.exp2(((m_i - m_new) * LOG2E))

        osum = (osum * alpha + os_i * beta)
        l    = (l    * alpha + l_i * beta)
        m    = m_new

    # Extra additive terms: Aln is added to the numerator and Sum to the
    # denominator as loaded, without rescaling by the running maximum.
    aln_ptrs = Aln_ptr + pid_bh * aln_stride_bh + offs_dv * aln_stride_dv
    aln = tl.load(aln_ptrs, mask=dv_mask, other=0)
    sum_ln = tl.load(Sum_ptr + pid_bh * sumln_stride_bh)
    O = (osum + aln) / (l + sum_ln)
    out_ptrs = Out_ptr + pid_bh * out_stride_bh + offs_dv * out_stride_dv
    tl.store(out_ptrs, O.to(tl.bfloat16), mask=dv_mask)
# =========================
# Python wrapper
# =========================
def flashdecoding_two_stage_online_bf16(
    Q, Kfp, Kq, Kq_scale, Kq_mn, Vfp,
    Vq, Vq_scale, Vq_mn,
    ALN, SUMLN,
    softmax_scale: float = None,
    Q_group_size=128, BITS=2,
    BLOCK_KV=128,
    BLOCK_D=128,
):
    """Single-query (decode) attention over a mixed fp / quantised KV cache.

    Three kernels are launched:

    * stage 1a computes un-normalised softmax partials over the
      un-quantised keys/values (``Kfp`` / ``Vfp``), in blocks of 128 rows;
    * stage 1b does the same over the bit-packed keys/values, in blocks of
      ``BLOCK_KV`` rows;
    * stage 2 merges all partials with online-softmax rescaling, adds
      ``ALN`` to the numerator and ``SUMLN`` to the denominator, and
      normalises once at the end.

    Shapes (``P = 32 // BITS`` values per packed int32 word):
        Q:                [B, H, D]            bfloat16
        Kfp, Vfp:         [B, HKV, Sfp, D]
        Kq:               [B, HKV, D, S // P]  int32, packed along S
        Kq_scale, Kq_mn:  [B, HKV, D, num_k_groups]
        Vq:               [B, HKV, S, D // P]  int32, packed along D
        Vq_scale, Vq_mn:  [B, HKV, S, num_v_groups]
        ALN:              viewable as [B * H, D]
        SUMLN:            viewable as [B * H]

    Args:
        softmax_scale: score multiplier; defaults to ``1 / sqrt(D)``.
        Q_group_size: quantisation group size passed to stage 1b.
        BITS: bits per quantised value.
        BLOCK_KV: number of quantised key rows handled by one stage-1b
            program.
        BLOCK_D: feature tile width of the kernels; must equal ``D``.

    Returns:
        bfloat16 tensor of shape ``[B, H, 1, D]``.

    Raises:
        AssertionError: on dtype / shape inconsistencies between the inputs.
        ValueError: if ``D != BLOCK_D``, if ``S`` is not a multiple of
            ``BLOCK_KV``, or if ``Q_group_size`` is not a multiple of
            ``BLOCK_KV``. The kernels read the feature axis and the
            quantised rows without masks and use one K scale group per
            block, so those configurations would silently read wrong or
            out-of-bounds data.
    """
    assert Q.dtype == torch.bfloat16 and Kq.dtype == torch.int32 and Vq.dtype == torch.int32
    B, H, D = Q.shape
    _, HKV, Dk, S_comp = Kq.shape
    _, _, S, Dv_comp = Vq.shape
    _, _, Sfp, _ = Kfp.shape
    _, _, _, Num_k_groups = Kq_mn.shape
    _, _, _, Num_v_groups = Vq_mn.shape
    feat_per_int = 32 // BITS

    assert S_comp * feat_per_int == S
    assert Dv_comp * feat_per_int == D
    assert Kq_mn.shape == Kq_scale.shape
    assert Dk == D
    assert H % HKV == 0
    assert D <= BLOCK_D

    # The kernels load Q / K / V along the feature axis without a mask, so a
    # partially filled feature tile would read neighbouring rows.
    if D != BLOCK_D:
        raise ValueError(
            f"head dim D={D} must equal BLOCK_D={BLOCK_D}: the kernels load "
            "the feature axis unmasked"
        )
    # Stage 1b neither masks the scores nor the V loads of a partial block.
    if S % BLOCK_KV != 0:
        raise ValueError(
            f"quantised length S={S} must be a multiple of BLOCK_KV={BLOCK_KV}"
        )
    # Stage 1b reads a single K scale/min group per block.
    if Q_group_size % BLOCK_KV != 0:
        raise ValueError(
            f"Q_group_size={Q_group_size} must be a multiple of "
            f"BLOCK_KV={BLOCK_KV} so that a block never spans two groups"
        )

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(D)

    BH = B * H
    BHKV = B * HKV
    Qc = Q.contiguous().view(BH, D)

    Kc = Kq.contiguous().view(BHKV, D, S_comp)
    Kc_mn = Kq_mn.contiguous().view(BHKV, D, Num_k_groups)
    Kc_sc = Kq_scale.contiguous().view(BHKV, D, Num_k_groups)

    Vc = Vq.contiguous().view(BHKV, S, Dv_comp)
    Vc_mn = Vq_mn.contiguous().view(BHKV, S, Num_v_groups)
    Vc_sc = Vq_scale.contiguous().view(BHKV, S, Num_v_groups)

    Kcfp = Kfp.contiguous().view(BHKV, Sfp, D)
    Vcfp = Vfp.contiguous().view(BHKV, Sfp, D)

    ALNc = ALN.contiguous().view(BH, D)
    SUMLNc = SUMLN.contiguous().view(BH)

    # Block counts. The fp part is tiled with a fixed block of 128 rows so
    # that every fp row is visited (a single block would silently drop rows
    # beyond index 127).
    FP_BLOCK_KV = 128
    NB_fp = (Sfp + FP_BLOCK_KV - 1) // FP_BLOCK_KV
    NB = (S + BLOCK_KV - 1) // BLOCK_KV
    NB_total = NB_fp + NB

    # One shared buffer per partial; the fp blocks occupy the first NB_fp
    # slots and the quantised blocks the rest, so stage 2 can merge them in
    # a single pass.
    M = torch.empty((BH, NB_total), dtype=torch.float32, device=Q.device)
    Mfp = M[:, :NB_fp]
    Mq = M[:, NB_fp:]
    L = torch.empty((BH, NB_total), dtype=torch.float32, device=Q.device)
    Lfp = L[:, :NB_fp]
    Lq = L[:, NB_fp:]
    OSUM = torch.empty((BH, NB_total, D), dtype=torch.float32, device=Q.device)
    OSUMfp = OSUM[:, :NB_fp, :]
    OSUMq = OSUM[:, NB_fp:, :]
    Out = torch.empty((BH, D), dtype=torch.bfloat16, device=Q.device)

    # Stage 1a : attention on fp keys and values
    if NB_fp > 0:
        grid1a = (NB_fp, BH)
        _kvlinc_stage_1a[grid1a](
            Qc, Kcfp, Vcfp, Mfp, Lfp, OSUMfp,
            BH, Sfp, D, D, NB_fp,
            H, HKV, H // HKV,
            Qc.stride(0), Qc.stride(1),
            Kcfp.stride(0), Kcfp.stride(1), Kcfp.stride(2),
            Vcfp.stride(0), Vcfp.stride(1), Vcfp.stride(2),
            Mfp.stride(0), Mfp.stride(1),
            Lfp.stride(0), Lfp.stride(1),
            OSUMfp.stride(0), OSUMfp.stride(1), OSUMfp.stride(2),
            softmax_scale,
            BLOCK_KV=FP_BLOCK_KV,
            BLOCK_D=BLOCK_D, BLOCK_DV=BLOCK_D,
            KV_CHUNK=FP_BLOCK_KV,
            num_warps=8, num_stages=3
        )

    # Stage 1b : attention on quantized keys and values
    if NB > 0:
        grid1b = (NB, BH)
        _kvlinc_stage_1b[grid1b](
            Qc, Kc, Kc_sc, Kc_mn, Vc, Vc_sc, Vc_mn, Mq, Lq, OSUMq,
            BH, S, D, NB,
            H, HKV, H // HKV,
            feat_per_int, BITS, Q_group_size,
            Qc.stride(0), Qc.stride(1),
            Kc.stride(0), Kc.stride(1), Kc.stride(2),
            Kc_sc.stride(0), Kc_sc.stride(1), Kc_sc.stride(2),
            Kc_mn.stride(0), Kc_mn.stride(1), Kc_mn.stride(2),
            Vc.stride(0), Vc.stride(1), Vc.stride(2),
            Vc_sc.stride(0), Vc_sc.stride(1), Vc_sc.stride(2),
            Vc_mn.stride(0), Vc_mn.stride(1), Vc_mn.stride(2),
            Mq.stride(0), Mq.stride(1),
            Lq.stride(0), Lq.stride(1),
            OSUMq.stride(0), OSUMq.stride(1), OSUMq.stride(2),
            softmax_scale,
            BLOCK_KV=BLOCK_KV,
            BLOCK_D=BLOCK_D,
            num_warps=4, num_stages=3
        )

    # Stage 2 : merge all partials and normalise once
    grid2 = ((D + BLOCK_D - 1) // BLOCK_D, BH)
    _flashdec_stage2_online_bf16[grid2](
        M, L, OSUM, Out, ALNc, SUMLNc,
        BH, NB_total, D,
        M.stride(0), M.stride(1),
        L.stride(0), L.stride(1),
        OSUM.stride(0), OSUM.stride(1), OSUM.stride(2),
        Out.stride(0), Out.stride(1),
        ALNc.stride(0), ALNc.stride(1),
        SUMLNc.stride(0), BLOCK_DV=BLOCK_D,
        num_warps=4, num_stages=2
    )

    return Out.view(B, H, 1, D)

def kvlinc_attention_forward(q, k_fp, v_fp, k_q, k_sc, k_mn, v_q, v_sc, v_mn, a_ln, sum_ln, q_group_size, bits, softmax_scale):
    """Adapter from a 4-D single-token query to the two-stage decode kernel.

    ``q`` is unpacked as ``[B, H, _, D]`` and reshaped to ``[B, H, D]``
    before the launch. The remaining arguments are forwarded unchanged, with
    ``BLOCK_KV`` and ``BLOCK_D`` fixed to 128. The result is cast back to the
    dtype of ``q``.
    """
    batch, heads, _, head_dim = q.shape
    attn_out = flashdecoding_two_stage_online_bf16(
        q.reshape(batch, heads, head_dim),
        k_fp, k_q, k_sc, k_mn,
        v_fp, v_q, v_sc, v_mn,
        a_ln, sum_ln,
        softmax_scale=softmax_scale,
        Q_group_size=q_group_size,
        BITS=bits,
        BLOCK_KV=128,
        BLOCK_D=128,
    )
    return attn_out.to(q.dtype)
# -----------------------------
# Validation & quick benchmark
# -----------------------------


@torch.no_grad()
def reference_decode(Q, Kfp, Kq, Kq_sc, Kq_mn, Vfp, Vq, Vq_sc, Vq_mn, Aln, Sumln, group_size, bits, scale=None):
    """
    PyTorch reference for the single-query (decode) attention kernel.

    The packed K / V tensors are dequantised with
    ``unpack_and_dequant_vcache`` and cast to bf16. K is then concatenated
    along the last axis with ``Kfp`` transposed on its last two axes, and V
    along the second-to-last axis with ``Vfp``. KV heads are repeated to
    match the ``H`` query heads, and the softmax is evaluated in fp32:

        out = (exp(s - max(s)) @ V + Aln) / (sum(exp(s - max(s))) + Sumln)

    Q: [B, H, D]. ``scale`` defaults to 1 / sqrt(D).
    Return: [B, H, 1, Dv] (bf16), where Dv is the last dim of the
    concatenated V.
    """
    B, H, D = Q.shape
    if scale is None:
        scale = 1.0 / math.sqrt(D)

    def _dequant(packed, sc, mn):
        # Scale / min get a trailing singleton axis before dequantisation.
        full = unpack_and_dequant_vcache(packed, sc.unsqueeze(-1), mn.unsqueeze(-1), group_size, bits)
        return full.to(torch.bfloat16)

    # Keys: un-quantised part first, then the dequantised part.
    keys = torch.cat([Kfp.transpose(2, 3), _dequant(Kq, Kq_sc, Kq_mn)], dim=-1)
    # Values: same ordering.
    values = torch.cat([Vfp, _dequant(Vq, Vq_sc, Vq_mn)], dim=-2)

    total_len = keys.size(3)
    Dv = values.size(3)
    kv_heads = keys.size(1)

    # Expand KV heads for grouped-query attention.
    n_rep = H // kv_heads
    keys = repeat_kv(keys, n_rep)
    values = repeat_kv(values, n_rep)

    rows = B * H
    q32 = Q.reshape(rows, D).to(torch.float32)
    k32 = keys.reshape(rows, D, total_len).to(torch.float32)
    v32 = values.reshape(rows, total_len, Dv).to(torch.float32)
    aln32 = Aln.reshape(rows, Dv).to(torch.float32)
    sumln32 = Sumln.reshape(rows, 1).to(torch.float32)

    logits = torch.einsum('bds,bd->bs', k32, q32) * scale          # [rows, total_len]
    peak = logits.amax(dim=-1, keepdim=True)                       # [rows, 1]
    weights = torch.exp(logits - peak)                             # [rows, total_len]
    denom = weights.sum(dim=-1, keepdim=True)                      # [rows, 1]
    numer = torch.bmm(weights.unsqueeze(1), v32).squeeze(1)        # [rows, Dv]
    result = (numer + aln32) / (denom + sumln32)                   # [rows, Dv]
    return result.view(B, H, 1, Dv).to(torch.bfloat16)


def error_report(y, yref, tau=1e-3, eps=1e-12):
    """Check ``y`` against ``yref`` and return a tuple of error metrics.

    Asserts ``torch.allclose(y, yref, rtol=1e-2, atol=1e-2)`` first, then
    compares the flattened fp32 copies.

    Returns:
        ``(max_abs, max_rel_masked, rel_l2, cos, mean_abs)`` where
        ``max_rel_masked`` is the largest relative error over the reference
        entries with ``|ref| >= tau`` (NaN if there are none), ``rel_l2`` is
        ``||y - ref|| / max(||ref||, eps)`` and ``cos`` is the cosine
        similarity of the two flattened tensors.
    """
    assert torch.allclose(y, yref, rtol=1e-2, atol=1e-2)

    got = y.float().reshape(-1)
    want = yref.float().reshape(-1)
    abs_err = torch.abs(got - want)

    max_abs = abs_err.max().item()
    mean_abs = abs_err.mean().item()

    # Relative error is only meaningful where the reference is not ~0.
    want_mag = want.abs()
    significant = want_mag >= tau
    max_rel_masked = float('nan')
    if significant.any():
        rel_err = abs_err[significant] / want_mag[significant].clamp_min(eps)
        max_rel_masked = rel_err.max().item()

    rel_l2 = (abs_err.norm() / want.norm().clamp_min(eps)).item()
    cos = torch.nn.functional.cosine_similarity(got.unsqueeze(0), want.unsqueeze(0)).item()
    return max_abs, max_rel_masked, rel_l2, cos, mean_abs


@torch.no_grad()
def run_case_lse_kernel(
    B=2, H=8, HKV=8, S=1024, Sfp=100, D=64, Dv=128, 
    Q_GRP_SIZE=128, BITS=2,
    seed=0,
    BLOCK_KV=256, KV_CHUNK=128, BLOCK_D=128, BLOCK_DV=128,
    atol=2e-2, rtol=2e-2, do_bench=True
):
    """Validate the Triton decode kernel on random data and optionally time it.

    Random bf16 inputs are generated on CUDA, K and V are quantised with
    ``triton_quantize_and_pack_along_last_dim``, and the kernel output is
    compared with ``reference_decode`` through ``error_report`` (which
    asserts closeness). With ``do_bench`` the kernel and the reference are
    each timed over 50 iterations after 10 warm-up calls.

    ``Dv`` is only used in the printed summary; ``KV_CHUNK``, ``BLOCK_DV``,
    ``atol`` and ``rtol`` are accepted but not read.
    """
    assert torch.cuda.is_available(), "CUDA required"
    torch.manual_seed(seed)
    assert Sfp <= 128

    def bf16_randn(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # bf16 inputs (creation order matters for reproducibility under the seed)
    Q = bf16_randn(B, H, D)
    KQ = bf16_randn(B, HKV, D, S)
    KFP = bf16_randn(B, HKV, Sfp, D)
    VQ = bf16_randn(B, HKV, S, D)
    VFP = bf16_randn(B, HKV, Sfp, D)
    ALN = bf16_randn(B, H, D)
    SUMLN = torch.zeros(B, H, 1, 1, device="cuda", dtype=torch.bfloat16)
    scale = 1.0 / math.sqrt(D)

    # Quantize and pack K and V along their last dimension; each call yields
    # the packed tensor plus its per-group scale and min.
    Kq, Kq_scale, Kq_mn, _, _ = triton_quantize_and_pack_along_last_dim(KQ.clone(), Q_GRP_SIZE, BITS) 
    Vq, Vq_scale, Vq_mn, _, _ = triton_quantize_and_pack_along_last_dim(VQ.clone(), Q_GRP_SIZE, BITS) 

    def call_kernel():
        return flashdecoding_two_stage_online_bf16(
            Q,
            KFP,
            Kq, Kq_scale, Kq_mn,
            VFP,
            Vq, Vq_scale, Vq_mn,
            ALN, SUMLN,
            Q_group_size=Q_GRP_SIZE, BITS=BITS,
            softmax_scale=scale,
            BLOCK_KV=BLOCK_KV,
            BLOCK_D=BLOCK_D,
        )

    def call_reference():
        # The quantised tensors are cloned on every call, as in the timed path.
        return reference_decode(
            Q, KFP, Kq.clone(), Kq_scale.clone(), Kq_mn.clone(),
            VFP, Vq.clone(), Vq_scale.clone(), Vq_mn.clone(),
            ALN, SUMLN, Q_GRP_SIZE, BITS, scale=scale,
        )

    def average_ms(fn, warmup=10, iters=50):
        # Warm up, then time `iters` back-to-back calls between device syncs.
        torch.cuda.synchronize()
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        started = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.time() - started) * 1000 / iters

    # Triton kernel vs. reference
    O_kernel = call_kernel().to(torch.bfloat16)
    O_ref = call_reference()

    # Errors
    max_abs, max_rel_masked, rel_l2, cos, mean_abs = error_report(O_kernel, O_ref)
    print(f"[OK] [B={B},H={H},HKV={HKV},S={S},D={D},Dv={Dv}] "
          f"max_abs={max_abs:.3e}  max_rel_masked={max_rel_masked:.3e} "
          f"rel_l2={rel_l2:.3e}  cos={cos:.3e} mean_abs={mean_abs:.3e}")

    # Optional quick bench
    if not do_bench:
        return

    ms_kernel = average_ms(call_kernel)
    ms_ref = average_ms(call_reference)

    print(f"Avg time: Kernel={ms_kernel:.3f} ms   Reference={ms_ref:.3f} ms\n")

if __name__ == "__main__":
    # Sanity runs over increasing batch sizes; every case uses 32 query
    # heads, 8 KV heads, 2048 quantised keys and head dim 128. Only the batch
    # size and the number of un-quantised keys vary.
    # (tweak BLOCK_D / BLOCK_DV if you change D / Dv)
    for batch, fp_len in ((1, 126), (32, 126), (64, 126), (512, 2)):
        run_case_lse_kernel(
            B=batch, H=32, HKV=8, S=2048, Sfp=fp_len, D=128, Dv=128,
            BLOCK_KV=128, KV_CHUNK=128, BLOCK_D=128, BLOCK_DV=128, seed=1,
        )
