import nvtx
import jax
from jax import numpy as jnp, lax, vmap
from jax.numpy import einsum
from einshape import jax_einshape as einshape
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
from jax.scipy.special import logsumexp
import dataclasses
from time import time
import math
import numpy as np
from jax import Array as Array
from .single_level_attention import MultiLevelAttention, MultiLevelClustering
from .flash_attention import attn as pallas_flash
from .pallas_cluster import do_clustering as pallas_do_clustering, kmeans_centroids_cuda, custom_assign_indices, ClusterData
from flax import struct
from .moba_retrieval import spatial_retrieval as moba_spatial_retrieval

def cudnn_flash_mha(q, k, v):
    """Causal attention via ``jax.nn.dot_product_attention`` with the cuDNN implementation.

    Returns ``(lse, out)`` so the result has the same two-element layout as the other
    attention entry points in this file.
    """
    # Placeholder only: an all-zero float32 array shaped like q without its last axis.
    # It is NOT the log-sum-exp of the attention scores.
    lse = jnp.zeros(q.shape[:-1], dtype=jnp.float32)
    out = jax.nn.dot_product_attention(q, k, v, implementation="cudnn", is_causal=True)
    return lse, out

# Large-magnitude negative (but finite) float32 value written into attention scores at padded slots.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)
# NOTE(review): not read anywhere in the visible part of this file; presumably toggles the
# dipole (vk) correction elsewhere — confirm against callers.
DIPOLE = True

@struct.dataclass
class ClusterData:
    # Cluster bookkeeping consumed by the functions in this file.
    # NOTE: this definition rebinds the name `ClusterData` that is also imported from
    # `.pallas_cluster` at the top of the file.
    # Field roles below are read off how this file indexes them; the one-letter tags are the
    # original shape hints.
    cnt: Array # K  -- one value per cluster; the kernels use it as the loop bound over slots
    lab: Array # N  -- first index when scattering back per token: x[lab, bwd]
    fwd: Array # Ku -- indexed as a 2-D (cluster, slot) table of token indices; entries < 0 are padding
    bwd: Array # N  -- second index when scattering back per token: x[lab, bwd]

def ref_initial(q, k, v, kmeta): # (Q, D), (S, D), (S, V), (S,)
    """Pure-JAX reference: softmax attention of every query restricted to each key cluster.

    Keys and values are gathered into padded per-cluster slots through ``kmeta.fwd``; slots
    whose index is negative are masked with ``DEFAULT_MASK_VALUE`` before the softmax.

    Returns:
        lse:   (Q, K)    log-sum-exp of the masked, scaled scores inside each cluster.
        k_out: (Q, K, D) attention-weighted mean key per (query, cluster).
        v_out: (Q, K, V) attention-weighted mean value per (query, cluster).
    """
    slots = kmeta.fwd # (K, u)
    keys = k[slots, :] # (K, u, D)
    vals = v[slots, :] # (K, u, V)
    scale = 1.0 / math.sqrt(q.shape[-1])

    logits = einsum("qd,kud->qku", q, keys) * scale # (Q, K, u)
    logits = jnp.where(slots[None, :, :] >= 0, logits, DEFAULT_MASK_VALUE) # (Q, K, u)

    lse = logsumexp(logits, axis=-1) # (Q, K)
    probs = jax.nn.softmax(logits, axis=-1) # (Q, K, u)

    k_out = einsum("qku,kud->qkd", probs, keys) # (Q, K, D)
    v_out = einsum("qku,kuv->qkv", probs, vals) # (Q, K, V)
    return lse, k_out, v_out # (Q, K), (Q, K, D), (Q, K, V)

def ref_initial_dipole(q, k, v, kmeta):
    """Pure-JAX reference: per-cluster first moments and value/key covariance.

    Args:
        q: unused; kept so the signature matches the other ``ref_initial*`` helpers.
        k: (S, D) keys.
        v: (S, V) values.
        kmeta: object with ``fwd`` of shape (K, u) holding token indices; entries < 0 are padding.

    Returns:
        vk_out: (K, V, D) mean of v k^T over the valid slots minus v_bar k_bar^T.
        k_bar:  (K, D)    mean key over the valid slots.
        v_bar:  (K, V)    mean value over the valid slots.

    A cluster with no valid slot yields zeros (its count is clamped to 1).
    """
    _k = k[kmeta.fwd, :] # (K, u, D)
    _v = v[kmeta.fwd, :] # (K, u, V)

    mask = kmeta.fwd >= 0 # (K, u)
    # Number of valid slots per cluster, clamped so empty clusters do not divide by zero.
    cnt = jnp.maximum(jnp.sum(mask, axis=1), 1) # (K,)

    v_bar = jnp.sum(_v, axis=1, where=mask[..., None]) / cnt[:, None] # (K, V)
    k_bar = jnp.sum(_k, axis=1, where=mask[..., None]) / cnt[:, None] # (K, D)
    # Zeroing the padded values is enough to drop padded slots from the outer-product sum.
    vk_out = einsum("kuv,kud->kvd", jnp.where(mask[..., None], _v, 0.0), _k) / cnt[:, None, None] # (K, V, D)
    vk_out = vk_out - einsum("kv,kd->kvd", v_bar, k_bar) # (K, V, D)
    return vk_out, k_bar, v_bar

def ref_initial_dipole_full(q, k, v, kmeta):
    """Pure-JAX reference: query-dependent (softmax-weighted) cluster moments.

    For each (query, cluster) pair the softmax over the cluster's valid slots weights the
    mean key, the mean value, and the covariance of the value/key residuals about those means.

    Returns:
        vk_bar: (Q, K, V, D) weighted covariance of (v - v_bar) and (k - k_bar).
        k_bar:  (Q, K, D)    weighted mean key.
        v_bar:  (Q, K, V)    weighted mean value.
    """
    slots = kmeta.fwd # (K, u)
    keys = k[slots, :] # (K, u, D)
    vals = v[slots, :] # (K, u, V)
    scale = 1.0 / math.sqrt(q.shape[-1])

    logits = einsum("qd,kud->qku", q, keys) * scale # (Q, K, u)
    logits = jnp.where(slots[None, :, :] >= 0, logits, DEFAULT_MASK_VALUE) # (Q, K, u)
    probs = jax.nn.softmax(logits, axis=-1) # (Q, K, u)

    k_bar = einsum("qku,kud->qkd", probs, keys) # (Q, K, D)
    v_bar = einsum("qku,kuv->qkv", probs, vals) # (Q, K, V)

    # Residuals about the per-(query, cluster) means.
    k_dev = keys[None, :, :, :] - k_bar[:, :, None, :] # (Q, K, u, D)
    v_dev = vals[None, :, :, :] - v_bar[:, :, None, :] # (Q, K, u, V)
    vk_bar = einsum("qku,qkuv,qkud->qkvd", probs, v_dev, k_dev) # (Q, K, V, D)
    return vk_bar, k_bar, v_bar

def ref_initial_bwd(q, k, v, kmeta, lse, k_out, v_out, dlse, dk_out, dv_out):
    """Pure-JAX reference backward pass of ``ref_initial``.

    Inputs:  q (Q, D), k (S, D), v (S, V), kmeta,
             forward results lse (Q, K), k_out (Q, K, D), v_out (Q, K, V),
             cotangents dlse (Q, K), dk_out (Q, K, D), dv_out (Q, K, V).
    Returns: dq (Q, D), dk (N, D), dv (N, V), the last two scattered back to token order
             through ``kmeta.lab`` / ``kmeta.bwd``.

    The covariance ("vk") contribution is not part of this gradient.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    valid = kmeta.fwd >= 0 # (K, u)

    # Fold the softmax-normalisation terms of both outputs into the lse cotangent.
    dlse_eff = dlse - einsum("qkd,qkd->qk", dk_out, k_out) - einsum("qkv,qkv->qk", dv_out, v_out) # (Q, K)

    # Gather into padded slots and zero the padding so it cannot leak into any product.
    keys = jnp.where(valid[..., None], k[kmeta.fwd, :], 0.0) # (K, u, D)
    vals = jnp.where(valid[..., None], v[kmeta.fwd, :], 0.0) # (K, u, V)

    # Rebuild the softmax probabilities from the saved lse instead of renormalising.
    logits = einsum("qd,kud->qku", q, keys) * scale # (Q, K, u)
    logits = jnp.where(valid[None, :, :], logits, DEFAULT_MASK_VALUE) # (Q, K, u)
    probs = jnp.exp(logits - lse[:, :, None]) # (Q, K, u)

    # Gradient with respect to the scaled scores.
    grad_logits = einsum("qkv,kuv->qku", dv_out, vals) + einsum("qkd,kud->qku", dk_out, keys) # (Q, K, u)
    grad_logits = grad_logits + dlse_eff[:, :, None]
    grad_logits = grad_logits * probs # (Q, K, u)

    dq = einsum("qku,kud->qd", grad_logits, keys) * scale # (Q, D)

    # Keys receive a direct term (through k_out) and a score term; values only a direct term.
    dk_direct = einsum("qku,qkd->kud", probs, dk_out) # (K, u, D)
    dk_score = einsum("qku,qd->kud", grad_logits, q) * scale # (K, u, D)
    dv_slots = einsum("qku,qkv->kuv", probs, dv_out) # (K, u, V)
    dk_slots = dk_direct + dk_score

    dk = dk_slots[kmeta.lab, kmeta.bwd, :] # (N, D)
    dv = dv_slots[kmeta.lab, kmeta.bwd, :] # (N, V)
    return dq, dk, dv

#def initial_bwd_kernel(q_ref, k_ref, v_ref, kfwd_ref, lse_ref, k_bar_ref, v_bar_ref, dlse_ref, dok_ref, dov_ref, dovk_ref, dq_ref, dk_ref, dv_ref, *, block_u: int):
def initial_bwd_kernel(q_ref, k_ref, v_ref, kfwd_ref, kcnt_ref, lse_ref, dlse_ref, dok_ref, dov_ref, dq_ref, dk_ref, dv_ref, *, block_u: int):
    """Pallas kernel body for the backward pass of the per-cluster attention (one grid cell).

    As launched by `initial_bwd`, one invocation handles one key cluster and one tile of
    queries. It walks the cluster's slots in chunks of `block_u`, recomputes the softmax
    probabilities from the saved `lse`, accumulates dq for the query tile, and writes dk / dv
    for the gathered key rows.

    Refs (shapes as seen inside the kernel):
        q_ref (Q, D), k_ref (N, D), v_ref (N, V), kfwd_ref (U,), kcnt_ref (), lse_ref (Q,),
        dlse_ref (Q,), dok_ref (Q, D), dov_ref (Q, V)  -- inputs
        dq_ref (Q, D), dk_ref (N, D), dv_ref (N, V)    -- outputs
    `dlse_ref` is expected to already contain the softmax-normalisation correction
    (`initial_bwd` passes `dlse_modified`).
    """
    Q, D = q_ref.shape # (Q, D) query tile
    N, V = v_ref.shape # (N, V)
    U, = kfwd_ref.shape # (U,)
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])

    dq = jnp.zeros((Q, D), dtype=jnp.float32) # (Q, D) float32 accumulator

    kcnt = kcnt_ref[...] # (), bounds the slot loop below
    q = q_ref[...] # (Q, D)
    lse = lse_ref[...] # (Q,)
    lse = math.log2(math.e) * lse # Convert to log2 scale so exp2 can be used below
    dlse = dlse_ref[...] # (Q,)
    #v_bar = v_bar_ref[...] # (V,)
    #k_bar = k_bar_ref[...] # (D,)
    dov = dov_ref[...] # (Q, V)
    dok = dok_ref[...] # (Q, D)
    #dovk = dovk_ref[...] # (V, D)

    def body(start_u, dq_prev):
        # One chunk of `block_u` slots of this cluster.
        curr_u_slice = pl.dslice(start_u*block_u, block_u)
        kfwd = kfwd_ref[curr_u_slice] # (u,)
        mask = kfwd >= 0 # (u,) negative index == padded slot
        # Gather the key/value rows of this chunk; padded slots load as zeros.
        k = pl.load(k_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, D)
        v = pl.load(v_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, V)
        qk_scale = math.log2(math.e) * sm_scale # scalar: softmax scale in log2 units
        qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale # (Q, u)
        qk = jnp.where(mask[None, :], qk, DEFAULT_MASK_VALUE)
        # Softmax probabilities rebuilt from the saved lse (both in log2 units).
        s = jnp.exp2(qk - lse[:, None]) # (Q, u)

        # Direct terms: the outputs are probability-weighted sums of v and k.
        dv0 = pl.dot(s.astype(dov.dtype), dov, trans_a=True, allow_tf32=True) # (u, V)
        dk0 = pl.dot(s.astype(dok.dtype), dok, trans_a=True, allow_tf32=True) # (u, D)

        gv = pl.dot(dov.astype(v.dtype), v, trans_b=True, allow_tf32=True) # (Q, u)
        gk = pl.dot(dok.astype(k.dtype), k, trans_b=True, allow_tf32=True) # (Q, u)

        # Gradient with respect to the scaled scores.
        ds = gv + gk + dlse[:, None]
        dqk = ds * s # (Q, u)

        dk2 = pl.dot(dqk.astype(q.dtype), q, trans_a=True, allow_tf32=True) * sm_scale # (u, D)
        dq_curr = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale # (Q, D)
        dq_next = dq_prev + dq_curr # (Q, D)

        #kres = jnp.where(mask[:, None], k - k_bar[None, :], 0.0) # (u, D)
        #dv1 = pl.dot(kres, dovk, trans_b=True, allow_tf32=True) # (u, V)
        #vres = jnp.where(mask[:, None], v - v_bar[None, :], 0.0) # (u, V)
        #dk1 = pl.dot(vres, dovk, allow_tf32=True) # (u, D)

        # Scatter the chunk's gradients to the rows they were gathered from (padding skipped).
        dv = dv0 #+ dv1 # (u, V)
        pl.store(dv_ref, (kfwd, slice(None)), dv.astype(dv_ref.dtype), mask=mask[:,None])
        dk = dk0 + dk2# + dk1 # (u, D)
        pl.store(dk_ref, (kfwd, slice(None)), dk.astype(dk_ref.dtype), mask=mask[:,None])
        return dq_next

    lower_bound = 0
    # Only chunks up to ceil(kcnt / block_u) are visited.
    # NOTE(review): this assumes the valid slots are packed at the front of kfwd — confirm.
    upper_bound = pl.cdiv(kcnt, block_u)
    assert U % block_u == 0, f"U ({U}) must be divisible by block_u ({block_u})"
    dq = lax.fori_loop(lower_bound, upper_bound, body, dq) # (Q, D)
    dq_ref[...] = dq.astype(dq_ref.dtype) # (Q, D)

def initial_bwd(q, k, v, kmeta, lse, k_out, v_out, dlse, dk_out, dv_out):
    """Launch `initial_bwd_kernel` over a (cluster, query-tile) grid and reduce the partials.

    Inputs:  q (Q, D), k (S, D), v (S, V), kmeta, forward results lse / k_out / v_out and
             their cotangents dlse / dk_out / dv_out.
    Returns: dq (Q, D), dk (S, D), dv (S, V).

    dq is produced as one partial per cluster and dk / dv as one partial per query tile;
    each stack of partials is summed over its leading axis.
    """
    # Softmax-normalisation correction folded into the lse cotangent before the kernel runs.
    dlse_modified = dlse - einsum("qkd,qkd->qk", dk_out, k_out) - einsum("qkv,qkv->qk", dv_out, v_out) # (Q, K)

    num_q, head_dim = q.shape # (Q, D)
    _, val_dim = v.shape # (S, V)
    num_clusters, num_slots = kmeta.fwd.shape # (K, u)

    q_tile = min(num_q, 64)
    assert num_q % q_tile == 0, f"Q ({num_q}) must be divisible by block_Q ({q_tile})"
    q_tiles = pl.cdiv(num_q, q_tile)
    slot_tile = min(num_slots, 64)

    # Grid axis i indexes the cluster, axis j the query tile.
    result_shapes = (
        jax.ShapeDtypeStruct((num_clusters,) + q.shape, q.dtype),
        jax.ShapeDtypeStruct((q_tiles,) + k.shape, k.dtype),
        jax.ShapeDtypeStruct((q_tiles,) + v.shape, v.dtype),
    )
    operand_specs = [
        pl.BlockSpec((q_tile, head_dim), lambda i, j: (j, 0)), # q_ref
        pl.BlockSpec(k.shape, lambda i, j: (0, 0)), # k_ref
        pl.BlockSpec(v.shape, lambda i, j: (0, 0)), # v_ref
        pl.BlockSpec((None, num_slots), lambda i, j: (i, 0)), # kfwd_ref
        pl.BlockSpec((None,), lambda i, j: (i,)), # kcnt_ref
        pl.BlockSpec((q_tile, None), lambda i, j: (j, i)), # lse_ref
        pl.BlockSpec((q_tile, None), lambda i, j: (j, i)), # dlse_ref
        pl.BlockSpec((q_tile, None, head_dim), lambda i, j: (j, i, 0)), # dok_ref
        pl.BlockSpec((q_tile, None, val_dim), lambda i, j: (j, i, 0)), # dov_ref
    ]
    result_specs = [
        pl.BlockSpec((None, q_tile, head_dim), lambda i, j: (i, j, 0)), # dq_ref
        pl.BlockSpec((None,) + k.shape, lambda i, j: (j, 0, 0)), # dk_ref
        pl.BlockSpec((None,) + v.shape, lambda i, j: (j, 0, 0)), # dv_ref
    ]

    launch = pl.pallas_call(
        partial(initial_bwd_kernel, block_u=slot_tile),
        out_shape=result_shapes,
        grid=(num_clusters, q_tiles),
        in_specs=operand_specs,
        out_specs=result_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="initial_backward",
    )
    dq_parts, dk_parts, dv_parts = launch(
        q, k, v, kmeta.fwd, kmeta.cnt, lse, dlse_modified, dk_out, dv_out
    )

    dq = jnp.sum(dq_parts, axis=0) # (Q, D)
    dk = jnp.sum(dk_parts, axis=0) # (S, D)
    dv = jnp.sum(dv_parts, axis=0) # (S, V)
    return dq, dk, dv # (Q, D), (S, D), (S, V)



# Can maybe increase block size by changing to half precision intermediates, or separating out vk computation
def initial_kernel(q_ref, k_ref, v_ref, kfwd_ref, kcnt_ref, lse_ref, ok_ref, ov_ref, *, block_k: int):
    """Pallas kernel body: attention of one query tile against one key cluster (online softmax).

    The cluster's slots are visited in chunks of `block_k`. A running maximum `m` and running
    normaliser `l` are kept in base-2 units so that `exp2` can be used, and the weighted sums
    of keys and values are rescaled whenever the maximum grows.

    Writes:
        ov_ref (Q, V): softmax-weighted mean value.
        ok_ref (Q, D): softmax-weighted mean key.
        lse_ref (Q,):  natural-log log-sum-exp of the scaled scores.
    """
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])
    Q, D = q_ref.shape
    _, V = v_ref.shape
    (U,) = kfwd_ref.shape
    kcnt = kcnt_ref[...] # (), bounds the slot loop below
    assert U % block_k == 0, f"U ({U}) must be divisible by block_k ({block_k})"
    q = q_ref[...] # (Q, D)
    #kfwd_all = kfwd_ref[...] # (u,)
    # float32 accumulators: un-normalised weighted sums, running max and running normaliser.
    ov = jnp.zeros((Q, V), dtype=jnp.float32) # (Q, V)
    ok = jnp.zeros((Q, D), dtype=jnp.float32) # (Q, D)
    m = jnp.zeros(Q, dtype=jnp.float32) - jnp.inf # (Q,)
    l = jnp.zeros(Q, dtype=jnp.float32) # (Q,)
    def body(start_k, carry):
        ov_prev, ok_prev, m_prev, l_prev = carry
        curr_k_slice = pl.dslice(start_k*block_k, block_k)
        # NOTE THIS IS INEFFICIENT - SHOULD JUST LOAD ALL AT THE BEGINNING
        kfwd = kfwd_ref[curr_k_slice] # (u,)
        #kfwd = lax.dynamic_slice(kfwd_all, (start_k*block_k,), (block_k,))
        mask = kfwd >= 0 # (u,) negative index == padded slot
        # Gather this chunk's key/value rows; padded slots load as zeros.
        k = pl.load(k_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, D)
        v = pl.load(v_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, V)
        qk_scale = math.log2(math.e) * sm_scale # scalar: softmax scale in log2 units
        qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale # (Q, u)
        qk = jnp.where(mask[None, :], qk, DEFAULT_MASK_VALUE)
        # Online-softmax update: raise the running max and rescale the previous partial sums.
        m_curr = jnp.max(qk, axis=-1) # (Q,)
        m_next = jnp.maximum(m_prev, m_curr) # (Q,)
        correction = jnp.exp2(m_prev - m_next) # (Q,)
        l_prev_corr = l_prev * correction # (Q,)
        s_curr = jnp.exp2(qk - m_next[:, None]) # (Q, u)
        l_curr = jnp.sum(s_curr, axis=-1) # (Q,)
        l_next = l_prev_corr + l_curr
        ov_prev_corr = ov_prev * correction[:, None] # (Q, V)
        ov_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=True) # (Q, V)
        ov_next = ov_prev_corr + ov_curr # (Q, V)
        ok_prev_corr = ok_prev * correction[:, None] # (Q, D)
        ok_curr = pl.dot(s_curr.astype(k.dtype), k, allow_tf32=True) # (Q, D)
        ok_next = ok_prev_corr + ok_curr # (Q, D)
        return ov_next, ok_next, m_next, l_next

    lower_bound = 0
    #upper_bound = pl.cdiv(U, block_k)
    # Only chunks up to ceil(kcnt / block_k) are visited.
    # NOTE(review): this assumes the valid slots are packed at the front of kfwd — confirm.
    upper_bound = pl.cdiv(kcnt, block_k)
    ov, ok, m, l = lax.fori_loop(lower_bound, upper_bound, body, (ov, ok, m, l)) # (Q, V), (Q, D), (Q,), (Q,)
    # Clamp the normaliser so a zero `l` (no chunk visited) does not divide by zero.
    eps = 1e-30
    ov = ov / jnp.maximum(eps, l[:, None]) # (Q, V)
    ok = ok / jnp.maximum(eps, l[:, None]) # (Q, D)
    ov_ref[...] = ov.astype(ov_ref.dtype) # (Q, V)
    ok_ref[...] = ok.astype(ok_ref.dtype) # (Q, D)
    # m is in log2 units; multiply by ln(2) to return a natural-log lse.
    lse_ref[...] = math.log(2)*m + jnp.log(l) # (Q,)

def initial(q, k, v, kmeta):
    """Launch `initial_kernel` over a (cluster, query-tile) grid.

    Inputs:  q (Q, D), k (S, D), v (S, V), kmeta with `fwd` (K, u) and `cnt` (K,).
    Returns: lse (Q, K) float32, k_bar (Q, K, D), v_bar (Q, K, V).
    """
    num_clusters, num_slots = kmeta.fwd.shape # (K, u)
    num_q, head_dim = q.shape # (Q, D)
    _, val_dim = v.shape # (S, V)

    q_tile = min(num_q, 64)
    assert num_q % q_tile == 0, f"Q ({num_q}) must be divisible by block_Q ({q_tile})"
    q_tiles = pl.cdiv(num_q, q_tile)
    slot_tile = min(num_slots, 64)

    # Grid axis i indexes the cluster, axis j the query tile.
    result_shapes = [
        jax.ShapeDtypeStruct((num_q, num_clusters), jnp.float32), # lse
        jax.ShapeDtypeStruct((num_q, num_clusters, head_dim), k.dtype), # k_bar
        jax.ShapeDtypeStruct((num_q, num_clusters, val_dim), v.dtype), # v_bar
    ]
    operand_specs = [
        pl.BlockSpec((q_tile, head_dim), lambda i, j: (j, 0)), # q tile
        pl.BlockSpec(k.shape, lambda i, j: (0, 0)), # all keys
        pl.BlockSpec(v.shape, lambda i, j: (0, 0)), # all values
        pl.BlockSpec((None, num_slots), lambda i, j: (i, 0)), # slot table row of cluster i
        pl.BlockSpec((None,), lambda i, j: (i,)), # count of cluster i
    ]
    result_specs = [
        pl.BlockSpec((q_tile, None), lambda i, j: (j, i)),
        pl.BlockSpec((q_tile, None, head_dim), lambda i, j: (j, i, 0)),
        pl.BlockSpec((q_tile, None, val_dim), lambda i, j: (j, i, 0)),
    ]

    launch = pl.pallas_call(
        partial(initial_kernel, block_k=slot_tile),
        out_shape=result_shapes,
        grid=(num_clusters, q_tiles),
        in_specs=operand_specs,
        out_specs=result_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="initial_forward",
    )
    lse, k_bar, v_bar = launch(q, k, v, kmeta.fwd, kmeta.cnt) # (Q, K), (Q, K, D), (Q, K, V)
    return lse, k_bar, v_bar



def ref_final(q, q_bar, k, v, m, qmeta): # (N, D), (Q, D), (Q, K, D), (Q, K, V), (Q, K)
    """Pure-JAX reference: second-stage attention of query residuals over cluster summaries.

    Queries are gathered per query cluster through ``qmeta.fwd``, centred on ``q_bar`` and
    scored against the summary keys ``k``; ``m`` is added to the scaled scores as a bias.
    Results are scattered back to token order with ``qmeta.lab`` / ``qmeta.bwd``.

    Returns:
        lse:   float32 log-sum-exp of the biased scores, one per token.
        v_out: softmax-weighted combination of ``v``, one row per token.
    """
    centred = q[qmeta.fwd, :] - q_bar[:, None, :] # (Q, n, D)
    scale = 1.0 / math.sqrt(centred.shape[-1])
    logits = einsum("qnd,qkd->qnk", centred, k) * scale + m[:, None, :] # (Q, n, K)

    # lse is always reported in float32, whatever the score dtype.
    lse = logsumexp(logits, axis=-1).astype(jnp.float32) # (Q, n)
    probs = jax.nn.softmax(logits, axis=-1) # (Q, n, K)
    mixed = einsum("qnk,qkv->qnv", probs, v) # (Q, n, V)

    lse_tok = lse[qmeta.lab, qmeta.bwd]
    v_tok = mixed[qmeta.lab, qmeta.bwd, :]
    return lse_tok, v_tok

def ref_final_dipole(q, q_bar, vk, qmeta):
    """Pure-JAX reference: first-order ("dipole") value correction with one matrix per query cluster.

    Each centred query is multiplied by its cluster's (V, D) matrix ``vk`` and scaled by
    1/sqrt(D); rows are returned in token order.
    """
    centred = q[qmeta.fwd, :] - q_bar[:, None, :] # (Q, n, D)
    scale = 1.0 / math.sqrt(centred.shape[-1])
    correction = einsum("qvd,qnd->qnv", vk, centred) * scale # (Q, n, V)
    per_token = correction[qmeta.lab, qmeta.bwd, :]
    return per_token

def ref_final_dipole_full(q, q_bar, k, m, vk, qmeta):
    """Pure-JAX reference: dipole value correction with one (V, D) matrix per (query cluster, key cluster).

    The second-stage softmax weights (centred query against ``k``, biased by ``m``) mix the
    per-key-cluster matrices before they are applied to the centred query.
    """
    centred = q[qmeta.fwd, :] - q_bar[:, None, :] # (Q, n, D)
    scale = 1.0 / math.sqrt(centred.shape[-1])
    logits = einsum("qnd,qkd->qnk", centred, k) * scale + m[:, None, :] # (Q, n, K)
    probs = jax.nn.softmax(logits, axis=-1) # (Q, n, K)
    correction = einsum("qnk,qkvd,qnd->qnv", probs, vk, centred) * scale # (Q, n, V)
    per_token = correction[qmeta.lab, qmeta.bwd, :]
    return per_token


def ref_final_dipole_ret(q, q_bar, _vk_merged, merge_weights, vk_bar, qmeta, ret_idx):
    """Pure-JAX reference: dipole correction that leaves out each query's retrieved key clusters.

    Args:
        q: (N, D) queries in token order.
        q_bar: (Q, D) query-cluster centres.
        _vk_merged: unused; kept for interface compatibility.
        merge_weights: (Q, K) weight of every key cluster for every query cluster.
        vk_bar: (K, V, D) per-key-cluster dipole matrices.
        qmeta: provides ``fwd`` (gather into clusters) and ``lab`` / ``bwd`` (scatter back).
        ret_idx: (N, R) key-cluster indices retrieved per query; their weights are set to zero.

    Returns:
        The correction in token order, last axis V.
    """
    _q = q[qmeta.fwd, :] # (Q, n, D)
    Q, n, _ = _q.shape
    _ret_idx = ret_idx[qmeta.fwd, :] # (Q, n, R)
    qres = _q - q_bar[:, None, :] # (Q, n, D)
    sm_scale = 1.0 / math.sqrt(qres.shape[-1])

    # Per-query copy of the cluster weights with the retrieved clusters zeroed out.
    local_merge_weights = jnp.broadcast_to(merge_weights[:, None, :], (Q, n, merge_weights.shape[1])) # (Q, n, K)
    local_merge_weights = local_merge_weights.at[
        jnp.arange(Q)[:, None, None], jnp.arange(n)[None, :, None], _ret_idx
    ].set(0.0)

    local_vk = einsum("qnk,kvd->qnvd", local_merge_weights, vk_bar) # (Q, n, V, D)
    v_out = einsum("qnvd,qnd->qnv", local_vk, qres) * sm_scale # (Q, n, V)
    return v_out[qmeta.lab, qmeta.bwd, :]

def ref_final_dipole_full_ret(q, q_bar, k, m, vk, qmeta, ret_idx):
    """Pure-JAX reference: per-(query, key-cluster) dipole correction that leaves out retrieved clusters.

    Same as ``ref_final_dipole_full`` except that, after the softmax, the weights of the key
    clusters listed in ``ret_idx`` (one row per token) are set to zero; the remaining weights
    are not renormalised.
    """
    gathered = q[qmeta.fwd, :] # (Q, n, D)
    num_clusters, per_cluster, _ = gathered.shape
    retrieved = ret_idx[qmeta.fwd, :] # (Q, n, R)

    centred = gathered - q_bar[:, None, :] # (Q, n, D)
    scale = 1.0 / math.sqrt(centred.shape[-1])
    logits = einsum("qnd,qkd->qnk", centred, k) * scale + m[:, None, :] # (Q, n, K)
    probs = jax.nn.softmax(logits, axis=-1) # (Q, n, K)

    # Zero the weights of the retrieved key clusters.
    cluster_ix = jnp.arange(num_clusters)[:, None, None]
    slot_ix = jnp.arange(per_cluster)[None, :, None]
    probs = probs.at[cluster_ix, slot_ix, retrieved].set(0.0)

    correction = einsum("qnk,qkvd,qnd->qnv", probs, vk, centred) * scale # (Q, n, V)
    return correction[qmeta.lab, qmeta.bwd, :]


@jax.grad
def probe_func(attn_scores, v, dv_out):
    """Debug probe: because of ``jax.grad`` this returns d/d(attn_scores) of sum(softmax(attn_scores) @ v * dv_out).

    Prints the shapes of the intermediate arrays each time it is traced.
    """
    probs = jax.nn.softmax(attn_scores, axis=-1) # (Q, n, K)
    print("shapes:", probs.shape, v.shape, dv_out.shape)
    mixed = einsum("qnk,qkv->qnv", probs, v) # (Q, n, V)
    objective = jnp.sum(mixed * dv_out)
    return objective

def test_grad(attn_scores, v, dv_out):
    """Hand-written gradient of sum(softmax(attn_scores) @ v * dv_out) with respect to attn_scores.

    Takes the same arguments as `probe_func` and appears intended as its closed-form
    counterpart. Despite the `test_` prefix it takes arguments and returns an array, so it is
    a helper rather than a self-contained test case.
    """
    attn_weights = jax.nn.softmax(attn_scores, axis=-1) # (Q, n, K)
    v_out = einsum("qnk,qkv->qnv", attn_weights, v) # (Q, n, V)
    # Softmax-Jacobian term: subtract <dv_out, v_out> from every score gradient of the row.
    correction = - einsum("qnv,qnv->qn", dv_out, v_out) # (Q, n)
    gv = einsum("qnv,qkv->qnk", dv_out, v) # (Q, n, K)
    gv = gv + correction[:, :, None]
    gv = gv * attn_weights # (Q, n, K)
    return gv

def ref_final_bwd(q, q_bar, k, v, m, qmeta, lse, v_out, dlse, dv_out):
    """Pure-JAX reference backward pass of ``ref_final``.

    Token-ordered inputs (q, lse, v_out, dlse, dv_out) are gathered into padded per-cluster
    slots through ``qmeta.fwd``; cotangents at padded slots are zeroed.

    Returns:
        dq:     gradient for q, in token order.
        dq_bar: (Q, D)    gradient for the cluster centres (minus the summed residual gradient).
        dk:     (Q, K, D) gradient for the summary keys.
        dv:     (Q, K, V) gradient for the summary values.
        dm:     (Q, K)    gradient for the score bias ``m``.

    The dipole ("vk") contribution is not part of this gradient.
    """
    slots = qmeta.fwd
    valid = slots >= 0 # (Q, n)

    # Gather everything into cluster layout; cotangents of padded slots become zero.
    q_slots = q[slots, :] # (Q, n, D)
    out_slots = v_out[slots, :] # (Q, n, V)
    dout_slots = jnp.where(valid[:, :, None], dv_out[slots, :], 0.) # (Q, n, V)
    dlse_slots = jnp.where(valid, dlse[slots], 0.) # (Q, n)
    lse_slots = lse[slots] # (Q, n)
    # Softmax-normalisation correction folded into the lse cotangent.
    dlse_eff = dlse_slots - einsum("qnv,qnv->qn", dout_slots, out_slots) # (Q, n)

    scale = 1.0 / math.sqrt(q_slots.shape[-1])
    centred = q_slots - q_bar[:, None, :] # (Q, n, D)
    logits = einsum("qnd,qkd->qnk", centred, k) * scale + m[:, None, :] # (Q, n, K)
    # Probabilities rebuilt from the saved lse.
    probs = jnp.exp(logits - lse_slots[:, :, None]) # (Q, n, K)

    dv = einsum("qnk,qnv->qkv", probs, dout_slots) # (Q, K, V)

    # Gradient with respect to the biased, scaled scores.
    grad_logits = einsum("qnv,qkv->qnk", dout_slots, v)
    grad_logits = grad_logits + dlse_eff[:, :, None]
    grad_logits = grad_logits * probs # (Q, n, K)

    dk = einsum("qnk,qnd->qkd", grad_logits, centred) * scale # (Q, K, D)
    d_centred = einsum("qnk,qkd->qnd", grad_logits, k) * scale # (Q, n, D)
    d_centred = jnp.where(valid[:, :, None], d_centred, 0.0) # (Q, n, D)

    # centred = q - q_bar, so the centre receives minus the summed residual gradient.
    dq_bar = - jnp.sum(d_centred, axis=1) # (Q, D)
    dm = jnp.sum(grad_logits, axis=1) # (Q, K)

    dq = d_centred[qmeta.lab, qmeta.bwd, :] # (N, D)
    return dq, dq_bar, dk, dv, dm

def final_bwd_kernel(
        q_ref, q_bar_ref, k_ref, v_ref, mu_ref, qfwd_ref, qcnt_ref,
        lse_ref, dlse_ref, do_ref,
        dq_ref, dq_bar_ref, dk_ref, dv_ref, dmu_ref,
        *, block_n: int
    ):
    """Pallas kernel body for the backward pass of the second-stage attention (one grid cell).

    As launched by `final_bwd`, one invocation handles one query cluster and one tile of
    summary keys. It walks the cluster's query slots in chunks of `block_n`, rebuilds the
    softmax probabilities from the saved `lse`, writes dq for the gathered query rows and
    accumulates dq_bar, dk, dv and dmu for the tile.

    `dlse_ref` is expected to already contain the softmax-normalisation correction
    (`final_bwd` passes `dlse_modified`).
    """
    T, D = q_ref.shape # (T, D)
    K, V = v_ref.shape # (K, V) key tile
    N, = qfwd_ref.shape # (N,)
    sm_scale = 1.0 / math.sqrt(q_bar_ref.shape[-1])

    q_bar = q_bar_ref[...].astype(q_ref.dtype) # (D,)
    qcnt = qcnt_ref[...] # (), bounds the slot loop below
    k = k_ref[...] # (K, D)
    v = v_ref[...] # (K, V)
    mu = mu_ref[...] # (K,)

    # float32 accumulators carried through the slot loop.
    dq_bar = jnp.zeros((D,), dtype=jnp.float32) # (D,)
    dk = jnp.zeros((K, D), dtype=jnp.float32) # (K, D)
    dv = jnp.zeros((K, V), dtype=jnp.float32) # (K, V)
    dmu = jnp.zeros((K,), dtype=jnp.float32) # (K,)

    def body(start_n, carry):
        dq_bar_prev, dk_prev, dv_prev, dmu_prev = carry # (D,), (K, D), (K, V), (K,)

        curr_n_slice = pl.dslice(start_n*block_n, block_n)
        qfwd = qfwd_ref[curr_n_slice] # (n,)
        mask = qfwd >= 0 # (n,) negative index == padded slot
        # Gather the per-token inputs of this chunk; padded slots load as zeros.
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, D)
        lse = pl.load(lse_ref, (qfwd,), mask=mask, other=0.0) # (n,)
        do = pl.load(do_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, V)
        dlse = pl.load(dlse_ref, (qfwd,), mask=mask, other=0.0) # (n,)
        qres = q - q_bar[None, :] # (n, D)
        qk_scale = math.log2(math.e) * sm_scale
        mu_scale = math.log2(math.e)
        qk = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale # (n, K)
        # Probabilities rebuilt from the saved lse (both in log2 units).
        # NOTE(review): padded rows use lse = 0 here, so `s` is not masked for them; they only
        # drop out because `do` and `dlse` are zero. Confirm `s` cannot overflow to inf on such
        # rows (inf * 0 would give NaN).
        s = jnp.exp2(qk - lse[:,None] * mu_scale) # (n, K)

        dv_curr = pl.dot(s.astype(do.dtype), do, trans_a=True, allow_tf32=True) # (K, V)
        dv_next = dv_prev + dv_curr # (K, V)
        gv = pl.dot(do.astype(v.dtype), v, trans_b=True, allow_tf32=True) # (n, K)
        # Gradient with respect to the biased, scaled scores.
        ds = gv + dlse[:, None]
        dqk = ds * s # (n, K)

        dk_curr = pl.dot(dqk.astype(qres.dtype), qres, trans_a=True, allow_tf32=True) * sm_scale # (K, D)
        dk_next = dk_prev + dk_curr # (K, D)
        dqres0 = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale # (n, D)

        dqres = dqres0 #+ dqres1 # (n, D)
        dq = dqres
        # qres = q - q_bar, so the centre receives minus the summed residual gradient.
        dq_bar_curr = - jnp.sum(dqres, axis=0) # (D,)
        dq_bar_next = dq_bar_prev + dq_bar_curr # (D,)
        dmu_curr = jnp.sum(dqk, axis=0) # (K,)
        dmu_next = dmu_prev + dmu_curr # (K,)

        # Scatter dq to the token rows it was gathered from (padding skipped).
        pl.store(dq_ref, (qfwd, slice(None)), dq.astype(dq_ref.dtype), mask=mask[:,None]) # (n, D)

        return dq_bar_next, dk_next, dv_next, dmu_next

    lower_bound = 0
    # Only chunks up to ceil(qcnt / block_n) are visited.
    # NOTE(review): this assumes the valid slots are packed at the front of qfwd — confirm.
    upper_bound = pl.cdiv(qcnt, block_n)
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    dq_bar, dk, dv, dmu = lax.fori_loop(lower_bound, upper_bound, body, (dq_bar, dk, dv, dmu)) # (D,), (K, D), (K, V), (K,)

    dq_bar_ref[...] = dq_bar.astype(dq_bar_ref.dtype) # (D,)
    dk_ref[...] = dk.astype(dk_ref.dtype) # (K, D)
    dv_ref[...] = dv.astype(dv_ref.dtype) # (K, V)
    dmu_ref[...] = dmu.astype(dmu_ref.dtype) # (K,)

def final_bwd(q, q_bar, k, v, m, qmeta, lse, v_out, dlse, dv_out):
    """Launch `final_bwd_kernel` over a (query-cluster, key-tile) grid and reduce the partials.

    Args:
        q: (T, D) queries in token order.
        q_bar: (Q, D) query-cluster centres.
        k, v, m: (Q, K, D), (Q, K, V), (Q, K) summary keys, values and score bias.
        qmeta: provides `fwd` (Q, N) slot table and `cnt` (Q,) per-cluster counts.
        lse, v_out: forward results in token order, (T,) and (T, V).
        dlse, dv_out: their cotangents.

    Returns:
        dq (T, D), dq_bar (Q, D), dk (Q, K, D), dv (Q, K, V), dm (Q, K).

    dq and dq_bar are produced as one partial per key tile and summed over that axis.
    """
    #dlse_modified = dlse - einsum("tv,tv->t", dv_out, v_out).astype(dlse.dtype) # (T,)
    # Softmax-normalisation correction folded into the lse cotangent before the kernel runs.
    dlse_modified = dlse - jnp.sum(dv_out * v_out, axis=-1).astype(dlse.dtype) # (T,)
    _, D = q.shape # (T, D)
    assert dlse_modified.shape == q.shape[:-1], f"dlse_modified shape {dlse_modified.shape} does not match q shape {q.shape[:-1]}"
    Q, K, V = v.shape # (Q, K, V)
    _, N = qmeta.fwd.shape # (Q, N)
    block_n = min(N, 32)
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    block_K = min(K, 64)
    assert K % block_K == 0, f"K ({K}) must be divisible by block_k ({block_K})"
    num_blocks_K = pl.cdiv(K, block_K)

    dqs, dq_bars, dk, dv, dm = pl.pallas_call(
        kernel=partial(final_bwd_kernel, block_n=block_n),
        # Only shapes and dtypes are needed here, so describe the outputs instead of
        # materialising tiled copies of q and q_bar.
        out_shape=(
            jax.ShapeDtypeStruct((num_blocks_K,) + q.shape, q.dtype), # dq partials
            jax.ShapeDtypeStruct((num_blocks_K,) + q_bar.shape, q_bar.dtype), # dq_bar partials
            jax.ShapeDtypeStruct(k.shape, k.dtype), # dk
            jax.ShapeDtypeStruct(v.shape, v.dtype), # dv
            jax.ShapeDtypeStruct(m.shape, m.dtype), # dm
        ),
        # Grid axis i indexes the query cluster, axis j the key tile.
        grid=(Q, num_blocks_K),
        in_specs=[
            pl.BlockSpec(q.shape, lambda i, j: (0, 0)), # q_ref
            pl.BlockSpec((None, D), lambda i, j: (i, 0)), # q_bar_ref
            pl.BlockSpec((None, block_K, D), lambda i, j: (i, j, 0)), # k_ref
            pl.BlockSpec((None, block_K, V), lambda i, j: (i, j, 0)), # v_ref
            pl.BlockSpec((None, block_K), lambda i, j: (i, j)), # mu_ref
            pl.BlockSpec((None, N), lambda i, j: (i, 0)), # qfwd_ref
            pl.BlockSpec((None,), lambda i, j: (i,)), # qcnt_ref
            pl.BlockSpec(lse.shape, lambda i, j: (0,)), # lse_ref
            pl.BlockSpec(dlse.shape, lambda i, j: (0,)), # dlse_ref
            pl.BlockSpec(dv_out.shape, lambda i, j: (0, 0)), # do_ref
        ],
        out_specs=[
            pl.BlockSpec((None,) + q.shape, lambda i, j: (j, 0, 0)), # dq_ref
            pl.BlockSpec((None, None, D), lambda i, j: (j, i, 0)), # dq_bar_ref
            pl.BlockSpec((None, block_K, D), lambda i, j: (i, j, 0)), # dk_ref
            pl.BlockSpec((None, block_K, V), lambda i, j: (i, j, 0)), # dv_ref
            pl.BlockSpec((None, block_K), lambda i, j: (i, j)), # dmu_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_backward",
    )(q, q_bar, k, v, m, qmeta.fwd, qmeta.cnt, lse, dlse_modified, dv_out)
    dq = jnp.sum(dqs, axis=0) # (T, D)
    dq_bar = jnp.sum(dq_bars, axis=0) # (Q, D)
    return dq, dq_bar, dk, dv, dm # (T, D), (Q, D), (Q, K, D), (Q, K, V), (Q, K)










def old_final_kernel(q_ref, q_bar_ref, k_ref, v_ref, m_ref, qfwd_ref, lse_ref, v_out_ref):
    """Pallas kernel body: second-stage attention for one query cluster in a single pass.

    Processes every slot of `qfwd_ref` and every summary key at once (no chunking).
    NOTE(review): not referenced in the visible part of this file; judging by the name it is
    superseded by `final_kernel` — confirm before removing.
    """
    q_bar = q_bar_ref[...] # (D,)
    sm_scale = 1.0 / math.sqrt(q_bar.shape[-1])
    k = k_ref[...] # (K, D)
    v = v_ref[...] # (K, V)
    mu = m_ref[...] # (K,)
    qfwd = qfwd_ref[...] # (n,)
    mask = qfwd >= 0 # (n,) negative index == padded slot
    # Gather this cluster's query rows; padded slots load as zeros.
    q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, D)

    qres = q - q_bar[None, :] # (n, D)
    # Scores are kept in log2 units so exp2 can be used.
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)
    attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale # (n, K)
    m = jnp.max(attn_scores, axis=-1) # (n,)
    s = jnp.exp2(attn_scores - m[:,None]) # (n, K)
    l = jnp.sum(s, axis=-1) # (n,)
    ov = pl.dot(s.astype(v.dtype), v, allow_tf32=True) # (n, V)
    ov = ov / l[:, None] # (n, V)
    # m is in log2 units; multiply by ln(2) to return a natural-log lse.
    lse = math.log(2) * m + jnp.log(l) # (n,)
    # Scatter results back to the token rows (padding skipped).
    pl.store(lse_ref, (qfwd,), lse, mask=mask)
    pl.store(v_out_ref, (qfwd, slice(None)), ov.astype(v_out_ref.dtype), mask=mask[:,None])

def final_kernel(q_ref, q_bar_ref, k_ref, v_ref, m_ref, qfwd_ref, qcnt_ref, lse_ref, v_out_ref, *, block_k: int, block_n: int):
    """Pallas kernel body: second-stage attention for one query cluster (online softmax).

    The cluster's query slots are processed in chunks of `block_n`; for each chunk the summary
    keys are streamed in chunks of `block_k` with a running max / normaliser in log2 units.
    Results are scattered to the token rows given by `qfwd_ref`:
        lse_ref:   natural-log log-sum-exp of the biased scores.
        v_out_ref: softmax-weighted combination of `v_ref`.
    """
    K, D = k_ref.shape # (K, D)
    N, = qfwd_ref.shape # (N,)
    q_bar = q_bar_ref[...] # (D,)
    qcnt = qcnt_ref[...] # (), bounds the outer slot loop below
    sm_scale = 1.0 / math.sqrt(q_bar.shape[-1])
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)
    def q_body(start_q):
        # One chunk of `block_n` query slots of this cluster.
        curr_q_slice = pl.dslice(start_q*block_n, block_n)
        qfwd = qfwd_ref[curr_q_slice] # (n,)
        mask = qfwd >= 0 # (n,) negative index == padded slot
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, D)
        qres = q - q_bar[None, :] # (n, D)

        def body(start_k, carry):
            ov_prev, m_prev, l_prev = carry
            curr_k_slice = pl.dslice(start_k*block_k, block_k)

            k = k_ref[curr_k_slice] # (block_k, D)
            v = v_ref[curr_k_slice] # (block_k, V)
            mu = m_ref[curr_k_slice] # (block_k,)

            attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale # (n, block_k)
            # Online-softmax update: raise the running max and rescale the previous partial sums.
            m_curr = jnp.max(attn_scores, axis=-1) # (n,)
            m_next = jnp.maximum(m_prev, m_curr) # (n,)
            correction = jnp.exp2(m_prev - m_next) # (n,)
            l_prev_corr = l_prev * correction # (n,)
            s_curr = jnp.exp2(attn_scores - m_next[:, None]) # (n, block_k)
            l_curr = jnp.sum(s_curr, axis=-1) # (n,)
            l_next = l_prev_corr + l_curr
            ov_prev_corr = ov_prev * correction[:, None] # (n, V)
            ov_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=True) # (n, V)
            ov_next = ov_prev_corr + ov_curr # (n, V)
            return ov_next, m_next, l_next

        # float32 accumulators; the running max starts at the finite mask value, not -inf.
        ov = jnp.zeros((qfwd.shape[0], v_ref.shape[1]), dtype=jnp.float32) # (n, V)
        m = jnp.full(qfwd.shape, DEFAULT_MASK_VALUE, dtype=jnp.float32) # (n,)
        l = jnp.zeros(qfwd.shape, dtype=jnp.float32) # (n,)

        lower_bound = 0
        upper_bound = pl.cdiv(K, block_k)
        ov, m, l = lax.fori_loop(lower_bound, upper_bound, body, (ov, m, l)) # (n, V), (n,), (n,)

        ov = ov / l[:, None] # (n, V)
        # m is in log2 units; multiply by ln(2) to return a natural-log lse.
        lse = math.log(2) * m + jnp.log(l) # (n,)
        # Scatter results back to the token rows (padding skipped).
        pl.store(lse_ref, (qfwd,), lse, mask=mask)
        pl.store(v_out_ref, (qfwd, slice(None)), ov.astype(v_out_ref.dtype), mask=mask[:,None])

    q_lower_bound = 0
    # Only chunks up to ceil(qcnt / block_n) are visited.
    # NOTE(review): this assumes the valid slots are packed at the front of qfwd — confirm.
    q_upper_bound = pl.cdiv(qcnt, block_n)
    # The loop carries no state (carry is None); q_body works purely through its stores.
    jax.lax.fori_loop(q_lower_bound, q_upper_bound, lambda i, _: q_body(i), None)


# Check whether usage of half precision is optimal - some stuff might be unnecessarily in float32
def final(q, q_bar, k, v, m, qmeta):
    """Launch `final_kernel` with one grid cell per query cluster.

    Inputs:  q (N, D), q_bar (Q, D), k (Q, K, D), v (Q, K, V), m (Q, K),
             qmeta with `fwd` (Q, n) and `cnt` (Q,).
    Returns: lse (N,) float32 and v_out (N, V), both in token order.
    """
    num_clusters, num_slots = qmeta.fwd.shape
    num_tokens, head_dim = q.shape # (N, D)
    _, num_keys, val_dim = v.shape # (Q, K, V)

    block_n = min(num_slots, 64)
    assert num_slots % block_n == 0, f"n ({num_slots}) must be divisible by block_q ({block_n})"
    block_k = min(num_keys, 64)
    assert num_keys % block_k == 0, f"K ({num_keys}) must be divisible by block_k ({block_k})"

    result_shapes = [
        jax.ShapeDtypeStruct((num_tokens,), jnp.float32), # lse
        jax.ShapeDtypeStruct((num_tokens, val_dim), v.dtype), # v_out
    ]
    # Grid index i selects the query cluster; q and both outputs are passed whole.
    operand_specs = [
        pl.BlockSpec(q.shape, lambda i: (0, 0)), # q
        pl.BlockSpec((None, head_dim,), lambda i: (i, 0)), # q_bar row
        pl.BlockSpec((None, num_keys, head_dim), lambda i: (i, 0, 0)), # k
        pl.BlockSpec((None, num_keys, val_dim), lambda i: (i, 0, 0)), # v
        pl.BlockSpec((None, num_keys), lambda i: (i, 0)), # m
        pl.BlockSpec((None, num_slots), lambda i: (i, 0)), # slot table row
        pl.BlockSpec((None,), lambda i: (i,)), # cluster count
    ]
    result_specs = [
        pl.BlockSpec((num_tokens,), lambda i: (0,)),
        pl.BlockSpec((num_tokens, val_dim), lambda i: (0, 0)),
    ]

    launch = pl.pallas_call(
        partial(final_kernel, block_k=block_k, block_n=block_n),
        out_shape=result_shapes,
        grid=(num_clusters,),
        in_specs=operand_specs,
        out_specs=result_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_forward",
    )
    lse, v_out = launch(q, q_bar, k, v, m, qmeta.fwd, qmeta.cnt)
    return lse, v_out

def ref_final_retrieval(num_retrievals: int, q, q_bar, k, v, m, qmeta):
    """Dense reference for the retrieval variant of the final attention pass.

    Queries are gathered per cluster through ``qmeta.fwd``, centred on the
    cluster mean ``q_bar`` and scored against that cluster's keys with the
    additive bias ``m``. The key axis is split into blocks of
    ``min(K, 64)`` keys; in every block the ``num_retrievals // num_blocks``
    best-scoring keys are recorded and then masked out with
    ``DEFAULT_MASK_VALUE``. The remaining scores are reduced to an
    *unnormalised* softmax state.

    Args:
        num_retrievals: total retrievals per query; must be a multiple of the
            number of key blocks.
        q: ``(T, D)`` queries, gathered with ``qmeta.fwd``.
        q_bar: ``(Q, D)`` per-cluster query means.
        k: ``(Q, K, D)`` keys.
        v: ``(Q, K, V)`` values.
        m: ``(Q, K)`` additive score bias (natural-log units).
        qmeta: object exposing ``fwd`` ``(Q, n)`` plus ``lab`` and ``bwd``,
            which are used to scatter results back to query order.

    Returns:
        ``(m_out, l, ov, ret_idx)``, each indexed by ``[qmeta.lab, qmeta.bwd]``:
        the running maximum converted to base-2 units (``log2(e) * max``), the
        softmax denominator, the unnormalised weighted value sum, and the
        int32 indices of the retrieved keys along the ``K`` axis.
    """
    _q = q[qmeta.fwd, :] # (Q, n, D)
    qres = _q - q_bar[:, None, :] # (Q, n, D)
    sm_scale = 1.0 / math.sqrt(qres.shape[-1])
    attn_scores = einsum("qnd,qkd->qnk", qres, k) * sm_scale + m[:,None,:] # (Q, n, K)
    Q, n, K = attn_scores.shape
    block_k = min(K, 64)
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    num_blocks_k = pl.cdiv(K, block_k)
    assert num_retrievals % num_blocks_k == 0, f"num_retrievals ({num_retrievals}) must be divisible by num_blocks_k ({num_blocks_k})"
    retrievals_per_block = num_retrievals // num_blocks_k

    ret_idx_out = jnp.zeros((Q, n, num_retrievals), dtype=jnp.int32)
    key_ids = jnp.arange(K)[None, None, :] # (1, 1, K)
    for blk_idx in range(num_blocks_k):
        blk_start = blk_idx * block_k
        for r in range(retrievals_per_block):
            # Best remaining key inside this block, expressed as an index into the full K axis.
            ret_idx = jnp.argmax(attn_scores[:, :, blk_start:blk_start + block_k], axis=-1) + blk_start # (Q, n)
            ret_idx_out = ret_idx_out.at[:, :, blk_idx * retrievals_per_block + r].set(ret_idx)
            # Mask the retrieved key so that it is neither picked again nor counted below.
            attn_to_kill = key_ids == ret_idx[:, :, None]
            assert attn_to_kill.shape == attn_scores.shape
            attn_scores = jnp.where(attn_to_kill, DEFAULT_MASK_VALUE, attn_scores)

    # Unnormalised softmax over the keys that were not retrieved.
    m = jnp.max(attn_scores, axis=-1) # (Q, n)
    s = jnp.exp(attn_scores - m[:,:,None]) # (Q, n, K)
    l = jnp.sum(s, axis=-1) # (Q, n)
    ov = einsum("qnk,qkv->qnv", s.astype(v.dtype), v) # (Q, n, V)
    m_out = math.log2(math.e) * m # running max in base-2 units
    return m_out[qmeta.lab, qmeta.bwd], l[qmeta.lab, qmeta.bwd], ov[qmeta.lab, qmeta.bwd, :], ret_idx_out[qmeta.lab, qmeta.bwd, :]

def final_retrieval_kernel(q_ref, q_bar_ref, k_ref, v_ref, m_ref, qfwd_ref, qcnt_ref, m_out_ref, l_out_ref, v_out_ref, ret_idx_out_ref, *, block_k: int, block_n: int):
    """Pallas kernel: per-block top-key retrieval plus online softmax over the rest.

    One program processes one list of query indices (``qfwd_ref``) against one
    set of keys/values. Queries are handled in tiles of ``block_n`` and keys in
    tiles of ``block_k``. For every key tile the kernel repeatedly takes the
    arg-max score per query, writes the global key index of that hit to
    ``ret_idx_out_ref`` and masks it with ``DEFAULT_MASK_VALUE``; the scores
    that remain are folded into a running (max, denominator, weighted value
    sum) triple. Scores are kept in base-2 units (scaled by ``log2(e)``) so
    that ``exp2`` can be used.

    Outputs are written with scatter stores at the rows given by ``qfwd``;
    entries of ``qfwd`` that are negative are treated as padding and skipped.
    The stored state is *unnormalised*: ``m_out`` is the base-2 running max,
    ``l_out`` the denominator and ``v_out`` the weighted value sum, all
    relative to ``m_out``.

    Args:
        block_k: key tile size.
        block_n: query tile size.
    """
    K, D = k_ref.shape # (K, D)
    N, = qfwd_ref.shape # length of this program's query-index list
    _, R = ret_idx_out_ref.shape # R = total retrievals recorded per query
    q_bar = q_bar_ref[...] # (D,)
    qcnt = qcnt_ref[...] # ()
    sm_scale = 1.0 / math.sqrt(q_bar.shape[-1])
    # Fold log2(e) into the scales so the softmax can be evaluated with exp2.
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)
    def q_body(start_q):
        curr_q_slice = pl.dslice(start_q*block_n, block_n)
        qfwd = qfwd_ref[curr_q_slice] # (block_n,) row indices into q / the outputs
        mask = qfwd >= 0 # negative index == padding slot
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (block_n, D), padding rows read as 0
        qres = q - q_bar[None, :] # (block_n, D) query centred on q_bar

        num_blocks_k = pl.cdiv(K, block_k)
        Rb = R // num_blocks_k # retrievals taken from each key tile
        def body(start_k, carry):
            ov_prev, m_prev, l_prev = carry
            curr_k_slice = pl.dslice(start_k*block_k, block_k)

            k = k_ref[curr_k_slice] # (block_k, D)
            v = v_ref[curr_k_slice] # (block_k, V)
            mu = m_ref[curr_k_slice] # (block_k,)

            attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale # (block_n, block_k), base-2 units

            # Python loop: unrolled Rb times at trace time.
            for r in range(Rb):
                ret_idx = jnp.argmax(attn_scores, axis=-1) # (block_n,) index inside this key tile
                attn_to_kill = jnp.arange(attn_scores.shape[1])[None, :] == ret_idx[:, None]
                assert(attn_to_kill.shape == attn_scores.shape)
                # Remove the retrieved key from the scores so the next iteration picks a different one
                # and so it does not contribute to the softmax state accumulated below.
                attn_scores = jnp.where(attn_to_kill, DEFAULT_MASK_VALUE, attn_scores)
                base_idx = start_k * block_k # offset turning the tile-local index into an index along K
                # Output column layout: Rb consecutive slots per key tile.
                pl.store(ret_idx_out_ref, (qfwd, start_k*Rb + r), ret_idx + base_idx, mask=mask)

            # Online softmax update over the scores that were not retrieved.
            m_curr = jnp.max(attn_scores, axis=-1) # (block_n,)
            m_next = jnp.maximum(m_prev, m_curr) # (block_n,)
            correction = jnp.exp2(m_prev - m_next) # (block_n,) rescales the previous state to the new max
            l_prev_corr = l_prev * correction # (block_n,)
            s_curr = jnp.exp2(attn_scores - m_next[:, None]) # (block_n, block_k)
            l_curr = jnp.sum(s_curr, axis=-1) # (block_n,)
            l_next = l_prev_corr + l_curr
            ov_prev_corr = ov_prev * correction[:, None] # (block_n, V)
            ov_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=True) # (block_n, V)
            ov_next = ov_prev_corr + ov_curr # (block_n, V)
            return ov_next, m_next, l_next

        ov = jnp.zeros((qfwd.shape[0], v_ref.shape[1]), dtype=jnp.float32) # (block_n, V)
        m = jnp.full(qfwd.shape, DEFAULT_MASK_VALUE, dtype=jnp.float32) # (block_n,)
        l = jnp.zeros(qfwd.shape, dtype=jnp.float32) # (block_n,)

        lower_bound = 0
        upper_bound = pl.cdiv(K, block_k)
        ov, m, l = lax.fori_loop(lower_bound, upper_bound, body, (ov, m, l)) # (block_n, V), (block_n,), (block_n,)


        # The state is stored unnormalised (no division by l, no lse) so that a later
        # stage can merge further contributions into it.
        pl.store(m_out_ref, (qfwd,), m.astype(m_out_ref.dtype), mask=mask)
        pl.store(l_out_ref, (qfwd,), l.astype(l_out_ref.dtype), mask=mask)
        pl.store(v_out_ref, (qfwd, slice(None)), ov.astype(v_out_ref.dtype), mask=mask[:,None])

    # Only the first ceil(qcnt / block_n) query tiles are visited.
    q_lower_bound = 0
    q_upper_bound = pl.cdiv(qcnt, block_n)
    jax.lax.fori_loop(q_lower_bound, q_upper_bound, lambda i, _: q_body(i), None)

def final_retrieval(num_retrievals: int, q, q_bar, k, v, m, qmeta):
    """Launch ``final_retrieval_kernel`` with one program per query cluster.

    Program ``i`` receives the whole ``q`` array and slice ``i`` of ``q_bar``,
    ``k``, ``v``, ``m``, ``qmeta.fwd`` and ``qmeta.cnt``; all four outputs are
    handed to every program unsliced.

    Args:
        num_retrievals: number of retrieved key indices per query; must be a
            multiple of the number of key tiles (``K / min(K, 64)``).
        q: ``(N, D)`` array.
        q_bar: sliced to ``(D,)`` per program.
        k: sliced to ``(K, D)`` per program.
        v: three-dimensional array whose trailing dims are ``(K, V)``.
        m: sliced to ``(K,)`` per program.
        qmeta: object exposing ``fwd`` of shape ``(Q, n)`` and ``cnt``.

    Returns:
        ``(m_out, l_out, v_out, ret_idx)`` with shapes ``(N,)`` float32,
        ``(N,)`` float32, ``(N, V)`` in ``v.dtype`` and
        ``(N, num_retrievals)`` int32.
    """
    Q, n = qmeta.fwd.shape
    N, D = q.shape
    _, K, V = v.shape

    # Tile sizes are capped at 64 and must divide the tiled extent exactly.
    block_n = min(n, 64)
    assert n % block_n == 0, f"n ({n}) must be divisible by block_q ({block_n})"
    block_k = min(K, 64)
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"

    # The kernel takes the same number of retrievals from every key tile.
    num_blocks_k = pl.cdiv(K, block_k)
    assert num_retrievals % num_blocks_k == 0, f"num_retrievals ({num_retrievals}) must be divisible by num_blocks_k ({num_blocks_k})"

    kernel = partial(final_retrieval_kernel, block_k=block_k, block_n=block_n)

    result_shapes = [
        jax.ShapeDtypeStruct((N,), jnp.float32),  # m
        jax.ShapeDtypeStruct((N,), jnp.float32),  # l
        jax.ShapeDtypeStruct((N, V), v.dtype),  # v_out
        jax.ShapeDtypeStruct((N, num_retrievals), jnp.int32),  # ret_idx
    ]
    operand_specs = [
        pl.BlockSpec(q.shape, lambda i: (0, 0)),  # q (whole array)
        pl.BlockSpec((None, D), lambda i: (i, 0)),  # q_bar
        pl.BlockSpec((None, K, D), lambda i: (i, 0, 0)),  # k
        pl.BlockSpec((None, K, V), lambda i: (i, 0, 0)),  # v
        pl.BlockSpec((None, K), lambda i: (i, 0)),  # m
        pl.BlockSpec((None, n), lambda i: (i, 0)),  # qmeta.fwd
        pl.BlockSpec((None,), lambda i: (i,)),  # qmeta.cnt
    ]
    result_specs = [
        pl.BlockSpec((N,), lambda i: (0,)),  # m
        pl.BlockSpec((N,), lambda i: (0,)),  # l
        pl.BlockSpec((N, V), lambda i: (0, 0)),  # v_out
        pl.BlockSpec((N, num_retrievals), lambda i: (0, 0)),  # ret_idx
    ]

    launch = pl.pallas_call(
        kernel,
        out_shape=result_shapes,
        grid=(Q,),
        in_specs=operand_specs,
        out_specs=result_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_retrieval_forward",
    )
    m_out, l_out, v_out, ret_idx = launch(q, q_bar, k, v, m, qmeta.fwd, qmeta.cnt)
    return m_out, l_out, v_out, ret_idx

def final_retrieval_causal(num_retrievals: int, causal_block: int, q, q_bar, k, v, m, qmeta):
    """Launch ``final_retrieval_kernel`` over a ``(C, Q)`` grid of causal blocks and clusters.

    Program ``(i, j)`` receives the whole ``q`` array, row ``j`` of ``q_bar``
    and slice ``[i, j]`` of ``k``, ``v``, ``m``, the shifted query indices and
    ``qmeta.cnt``. The query indices of causal block ``c`` are offset by
    ``c * causal_block`` before the launch; negative (padding) indices are
    kept at ``-1``.

    Args:
        num_retrievals: number of retrieved key indices per query; must be a
            multiple of the number of key tiles (``K / min(K, 64)``).
        causal_block: index offset between consecutive causal blocks.
        q: ``(N, D)`` array.
        q_bar: sliced to ``(D,)`` per program, selected by the cluster index only.
        k: sliced to ``(K, D)`` per program.
        v: four-dimensional array whose trailing dims are ``(K, V)``.
        m: sliced to ``(K,)`` per program.
        qmeta: object exposing ``fwd`` of shape ``(C, Q, n)`` and ``cnt``.

    Returns:
        ``(m_out, l_out, v_out, ret_idx)`` with shapes ``(N,)`` float32,
        ``(N,)`` float32, ``(N, V)`` in ``v.dtype`` and
        ``(N, num_retrievals)`` int32.
    """
    C, Q, n = qmeta.fwd.shape
    N, D = q.shape
    _, _, K, V = v.shape

    # Tile sizes are capped at 64 and must divide the tiled extent exactly.
    block_n = min(n, 64)
    assert n % block_n == 0, f"n ({n}) must be divisible by block_q ({block_n})"
    block_k = min(K, 64)
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"

    # The kernel takes the same number of retrievals from every key tile.
    num_blocks_k = pl.cdiv(K, block_k)
    assert num_retrievals % num_blocks_k == 0, f"num_retrievals ({num_retrievals}) must be divisible by num_blocks_k ({num_blocks_k})"

    # Offset the indices of causal block c by c * causal_block, leaving padding at -1.
    block_offsets = causal_block * jnp.arange(C)[:, None, None]  # (C, 1, 1)
    shifted_qfwd = jnp.where(qmeta.fwd < 0, -1, qmeta.fwd + block_offsets)  # (C, Q, n)

    kernel = partial(final_retrieval_kernel, block_k=block_k, block_n=block_n)

    result_shapes = [
        jax.ShapeDtypeStruct((N,), jnp.float32),  # m
        jax.ShapeDtypeStruct((N,), jnp.float32),  # l
        jax.ShapeDtypeStruct((N, V), v.dtype),  # v_out
        jax.ShapeDtypeStruct((N, num_retrievals), jnp.int32),  # ret_idx
    ]
    operand_specs = [
        pl.BlockSpec(q.shape, lambda i, j: (0, 0)),  # q (whole array)
        pl.BlockSpec((None, D), lambda i, j: (j, 0)),  # q_bar
        pl.BlockSpec((None, None, K, D), lambda i, j: (i, j, 0, 0)),  # k
        pl.BlockSpec((None, None, K, V), lambda i, j: (i, j, 0, 0)),  # v
        pl.BlockSpec((None, None, K), lambda i, j: (i, j, 0)),  # m
        pl.BlockSpec((None, None, n), lambda i, j: (i, j, 0)),  # qfwd
        pl.BlockSpec((None, None), lambda i, j: (i, j)),  # qcnt
    ]
    result_specs = [
        pl.BlockSpec((N,), lambda i, j: (0,)),  # m
        pl.BlockSpec((N,), lambda i, j: (0,)),  # l
        pl.BlockSpec((N, V), lambda i, j: (0, 0)),  # v_out
        pl.BlockSpec((N, num_retrievals), lambda i, j: (0, 0)),  # ret_idx
    ]

    launch = pl.pallas_call(
        kernel,
        out_shape=result_shapes,
        grid=(C, Q),
        in_specs=operand_specs,
        out_specs=result_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_retrieval_forward",
    )
    m_out, l_out, v_out, ret_idx = launch(q, q_bar, k, v, m, shifted_qfwd, qmeta.cnt)
    return m_out, l_out, v_out, ret_idx

def final_retrieval_bwd_kernel(
        q_ref, q_bar_ref, k_ref, v_ref, mu_ref, qfwd_ref, qcnt_ref,
        lse_ref, dlse_ref, do_ref,
        dq_ref, dq_bar_ref, dk_ref, dv_ref, dmu_ref,
        *, block_n: int, num_retrievals: int,
    ):
    """Pallas kernel: backward pass of the retrieval attention for one key block.

    One program owns one block of keys/values/biases (``k_ref``, ``v_ref``,
    ``mu_ref``) and walks over the query-index list ``qfwd_ref`` in tiles of
    ``block_n``. For every tile it recomputes the scores, re-applies the
    arg-max masking ``num_retrievals`` times on this key block, and
    accumulates the gradients for the block's keys, values and biases and for
    ``q_bar``. The query gradient restricted to this key block is scattered
    into ``dq_ref`` at rows ``qfwd``.

    Args:
        block_n: query tile size.
        num_retrievals: how many arg-max entries to mask out of the key block
            handled by this program.
    """
    T, D = q_ref.shape # (T, D)
    K, V = v_ref.shape # K is the extent of the key block handed to this program
    N, = qfwd_ref.shape # length of this program's query-index list
    sm_scale = 1.0 / math.sqrt(q_bar_ref.shape[-1])

    q_bar = q_bar_ref[...].astype(q_ref.dtype) # (D,)
    qcnt = qcnt_ref[...] # ()
    k = k_ref[...] # (K, D)
    v = v_ref[...] # (K, V)
    mu = mu_ref[...] # (K,)

    # float32 accumulators carried through the query loop.
    dq_bar = jnp.zeros((D,), dtype=jnp.float32) # (D,)
    dk = jnp.zeros((K, D), dtype=jnp.float32) # (K, D)
    dv = jnp.zeros((K, V), dtype=jnp.float32) # (K, V)
    dmu = jnp.zeros((K,), dtype=jnp.float32) # (K,)

    def body(start_n, carry):
        dq_bar_prev, dk_prev, dv_prev, dmu_prev = carry # (D,), (K, D), (K, V), (K,)

        curr_n_slice = pl.dslice(start_n*block_n, block_n)
        qfwd = qfwd_ref[curr_n_slice] # (n,) row indices into q / lse / dlse / do / dq
        mask = qfwd >= 0 # negative index == padding slot
        # Padding rows load as zeros. Because do and dlse are zero there, ds (and hence dqk)
        # vanishes for those rows and their dv contribution is multiplied by do == 0.
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, D)
        lse = pl.load(lse_ref, (qfwd,), mask=mask, other=0.0) # (n,)
        do = pl.load(do_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, V)
        dlse = pl.load(dlse_ref, (qfwd,), mask=mask, other=0.0) # (n,)
        qres = q - q_bar[None, :] # (n, D)
        # Work in base-2 units so that exp2 can be used below.
        qk_scale = math.log2(math.e) * sm_scale
        mu_scale = math.log2(math.e)
        qk = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale # (n, K)
        qk = qk - lse[:,None] * mu_scale # subtract the log-normaliser (converted to base 2)
        # Re-apply the retrieval masking on this key block (unrolled at trace time).
        for r in range(num_retrievals):
            ret_idx = jnp.argmax(qk, axis=-1)
            attn_to_kill = jnp.arange(qk.shape[1])[None, :] == ret_idx[:, None]
            assert(attn_to_kill.shape == qk.shape)
            qk = jnp.where(attn_to_kill, DEFAULT_MASK_VALUE, qk)
        # NOTE(review): an earlier author note here questioned whether lse must be subtracted
        # before the retrieval masking. Subtracting a per-row constant does not change the
        # per-row argmax, so the order of those two steps does not affect which keys are masked.
        s = jnp.exp2(qk) # (n, K); masked entries underflow to 0

        dv_curr = pl.dot(s.astype(do.dtype), do, trans_a=True, allow_tf32=True) # (K, V)
        dv_next = dv_prev + dv_curr # (K, V)
        gv = pl.dot(do.astype(v.dtype), v, trans_b=True, allow_tf32=True) # (n, K)
        ds = gv + dlse[:, None]
        dqk = ds * s # (n, K) gradient w.r.t. the (natural-log) scores

        dk_curr = pl.dot(dqk.astype(qres.dtype), qres, trans_a=True, allow_tf32=True) * sm_scale # (K, D)
        dk_next = dk_prev + dk_curr # (K, D)
        dqres0 = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale # (n, D)

        dqres = dqres0 # (n, D)
        dq = dqres
        # qres = q - q_bar, so q_bar receives the negated sum of the query gradients.
        dq_bar_curr = - jnp.sum(dqres, axis=0) # (D,)
        dq_bar_next = dq_bar_prev + dq_bar_curr # (D,)
        dmu_curr = jnp.sum(dqk, axis=0) # (K,)
        dmu_next = dmu_prev + dmu_curr # (K,)

        pl.store(dq_ref, (qfwd, slice(None)), dq.astype(dq_ref.dtype), mask=mask[:,None]) # (n, D)

        return dq_bar_next, dk_next, dv_next, dmu_next

    # Only the first ceil(qcnt / block_n) query tiles are visited.
    lower_bound = 0
    upper_bound = pl.cdiv(qcnt, block_n)
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    dq_bar, dk, dv, dmu = lax.fori_loop(lower_bound, upper_bound, body, (dq_bar, dk, dv, dmu)) # (D,), (K, D), (K, V), (K,)

    dq_bar_ref[...] = dq_bar.astype(dq_bar_ref.dtype) # (D,)
    dk_ref[...] = dk.astype(dk_ref.dtype) # (K, D)
    dv_ref[...] = dv.astype(dv_ref.dtype) # (K, V)
    dmu_ref[...] = dmu.astype(dmu_ref.dtype) # (K,)

def final_retrieval_bwd(num_retrievals: int, q, q_bar, k, v, m, qmeta, lse, v_out, dlse, dv_out):
    """Backward pass of the retrieval attention, launched over ``(Q, key blocks)``.

    Program ``(i, j)`` runs ``final_retrieval_bwd_kernel`` for cluster ``i``
    and key block ``j``. Each key block writes its own partial ``dq`` and
    ``dq_bar`` into a separate leading slot; the partials are summed over the
    key-block axis after the launch. ``dk``, ``dv`` and ``dm`` are written
    block by block directly.

    Args:
        num_retrievals: total retrievals per query; must be a multiple of the
            number of key blocks (``K / min(K, 64)``).
        q: ``(T, D)`` queries.
        q_bar: ``(Q, D)`` per-cluster query means.
        k: keys, tiled as ``(block_K, D)`` per program.
        v: ``(Q, K, V)`` values.
        m: per-key bias, tiled as ``(block_K,)`` per program.
        qmeta: object exposing ``fwd`` of shape ``(Q, N)`` and ``cnt``.
        lse: ``(T,)`` log-sum-exp from the forward pass.
        v_out: ``(T, V)`` forward output.
        dlse: ``(T,)`` cotangent of ``lse``.
        dv_out: ``(T, V)`` cotangent of ``v_out``.

    Returns:
        ``(dq, dq_bar, dk, dv, dm)`` with the shapes and dtypes of ``q``,
        ``q_bar``, ``k``, ``v`` and ``m``.
    """
    # Fold the sum(dv_out * v_out) term of the softmax backward into the lse cotangent.
    dlse_modified = dlse - jnp.sum(dv_out * v_out, axis=-1).astype(dlse.dtype) # (T,)
    T, D = q.shape # (T, D)
    assert dlse_modified.shape == q.shape[:-1], f"dlse_modified shape {dlse_modified.shape} does not match q shape {q.shape[:-1]}"
    Q, K, V = v.shape # (Q, K, V)
    _, N = qmeta.fwd.shape # (Q, N)
    block_n = min(N, 32)
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    block_K = min(K, 64)
    assert K % block_K == 0, f"K ({K}) must be divisible by block_k ({block_K})"
    num_blocks_K = pl.cdiv(K, block_K)

    assert num_retrievals % num_blocks_K == 0, f"num_retrievals ({num_retrievals}) must be divisible by num_blocks_K ({num_blocks_K})"
    retrievals_per_block_k = num_retrievals // num_blocks_K

    dqs, dq_bars, dk, dv, dm = pl.pallas_call(
        kernel=partial(final_retrieval_bwd_kernel, block_n=block_n, num_retrievals=retrievals_per_block_k),
        # Only shapes and dtypes are needed here; describing them directly avoids
        # materialising tiled copies of q and q_bar.
        out_shape=(
            jax.ShapeDtypeStruct((num_blocks_K,) + tuple(q.shape), q.dtype), # dqs: one partial per key block
            jax.ShapeDtypeStruct((num_blocks_K,) + tuple(q_bar.shape), q_bar.dtype), # dq_bars: one partial per key block
            jax.ShapeDtypeStruct(k.shape, k.dtype), # dk
            jax.ShapeDtypeStruct(v.shape, v.dtype), # dv
            jax.ShapeDtypeStruct(m.shape, m.dtype), # dm
        ),
        grid=(Q, num_blocks_K),
        in_specs=[
            pl.BlockSpec(q.shape, lambda i, j: (0, 0)), # q_ref
            pl.BlockSpec((None, D), lambda i, j: (i, 0)), # q_bar_ref
            pl.BlockSpec((None, block_K, D), lambda i, j: (i, j, 0)), # k_ref
            pl.BlockSpec((None, block_K, V), lambda i, j: (i, j, 0)), # v_ref
            pl.BlockSpec((None, block_K), lambda i, j: (i, j)), # mu_ref
            pl.BlockSpec((None, N), lambda i, j: (i, 0)), # qfwd_ref
            pl.BlockSpec((None,), lambda i, j: (i,)), # qcnt_ref
            pl.BlockSpec(lse.shape, lambda i, j: (0,)), # lse_ref
            pl.BlockSpec(dlse.shape, lambda i, j: (0,)), # dlse_ref
            pl.BlockSpec(dv_out.shape, lambda i, j: (0, 0)), # do_ref
        ],
        out_specs=[
            pl.BlockSpec((None,) + q.shape, lambda i, j: (j, 0, 0)), # dq_ref
            pl.BlockSpec((None, None, D), lambda i, j: (j, i, 0)), # dq_bar_ref
            pl.BlockSpec((None, block_K, D), lambda i, j: (i, j, 0)), # dk_ref
            pl.BlockSpec((None, block_K, V), lambda i, j: (i, j, 0)), # dv_ref
            pl.BlockSpec((None, block_K), lambda i, j: (i, j)), # dmu_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_retrieval_backward",
    )(q, q_bar, k, v, m, qmeta.fwd, qmeta.cnt, lse, dlse_modified, dv_out)
    # Reduce the per-key-block partial gradients.
    dq = jnp.sum(dqs, axis=0) # (T, D)
    dq_bar = jnp.sum(dq_bars, axis=0) # (Q, D)
    return dq, dq_bar, dk, dv, dm # shaped like q, q_bar, k, v, m


def final_retrieval_causal_bwd(num_retrievals: int, causal_block: int, q, q_bar, k, v, m, qmeta, lse, v_out, dlse, dv_out):
    """Backward pass of the causal retrieval attention over a ``(C, Q, key blocks)`` grid.

    Program ``(c, j, kb)`` runs ``final_retrieval_bwd_kernel`` for causal block
    ``c``, cluster ``j`` and key block ``kb``. The query indices of causal
    block ``c`` are offset by ``c * causal_block`` (padding stays ``-1``).
    Partial ``dq`` results are kept per key block and partial ``dq_bar``
    results per (causal block, key block); both are summed after the launch.

    Args:
        num_retrievals: total retrievals per query; must be a multiple of the
            number of key blocks (``K / min(K, 64)``).
        causal_block: index offset between consecutive causal blocks.
        q: ``(T, D)`` queries.
        q_bar: ``(Q, D)`` per-cluster query means, shared by all causal blocks.
        k: keys, tiled as ``(block_K, D)`` per program.
        v: ``(C, Q, K, V)`` values.
        m: per-key bias, tiled as ``(block_K,)`` per program.
        qmeta: object exposing ``fwd`` of shape ``(C, Q, N)`` and ``cnt``.
        lse: ``(T,)`` log-sum-exp from the forward pass.
        v_out: ``(T, V)`` forward output.
        dlse: ``(T,)`` cotangent of ``lse``.
        dv_out: ``(T, V)`` cotangent of ``v_out``.

    Returns:
        ``(dq, dq_bar, dk, dv, dm)`` with the shapes and dtypes of ``q``,
        ``q_bar``, ``k``, ``v`` and ``m``.
    """
    # Fold the sum(dv_out * v_out) term of the softmax backward into the lse cotangent.
    dlse_modified = dlse - jnp.sum(dv_out * v_out, axis=-1).astype(dlse.dtype) # (T,)
    T, D = q.shape # (T, D)
    assert dlse_modified.shape == q.shape[:-1], f"dlse_modified shape {dlse_modified.shape} does not match q shape {q.shape[:-1]}"
    C, Q, K, V = v.shape # (C, Q, K, V)
    _, _, N = qmeta.fwd.shape # (C, Q, N)
    block_n = min(N, 32)
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    block_K = min(K, 64)
    assert K % block_K == 0, f"K ({K}) must be divisible by block_k ({block_K})"
    num_blocks_K = pl.cdiv(K, block_K)

    assert num_retrievals % num_blocks_K == 0, f"num_retrievals ({num_retrievals}) must be divisible by num_blocks_K ({num_blocks_K})"
    retrievals_per_block_k = num_retrievals // num_blocks_K

    # Offset the indices of causal block c by c * causal_block, leaving padding at -1.
    shifted_qfwd = qmeta.fwd + (causal_block * jnp.arange(C)[:, None, None]) # (C, Q, N)
    shifted_qfwd = jnp.where(qmeta.fwd < 0, -1, shifted_qfwd)
    dqs, dq_bars, dk, dv, dm = pl.pallas_call(
        kernel=partial(final_retrieval_bwd_kernel, block_n=block_n, num_retrievals=retrievals_per_block_k),
        # Only shapes and dtypes are needed here; describing them directly avoids
        # materialising tiled copies of q and q_bar.
        out_shape=(
            jax.ShapeDtypeStruct((num_blocks_K,) + tuple(q.shape), q.dtype), # dqs: one partial per key block
            jax.ShapeDtypeStruct((C, num_blocks_K) + tuple(q_bar.shape), q_bar.dtype), # dq_bars: one partial per (causal block, key block)
            jax.ShapeDtypeStruct(k.shape, k.dtype), # dk
            jax.ShapeDtypeStruct(v.shape, v.dtype), # dv
            jax.ShapeDtypeStruct(m.shape, m.dtype), # dm
        ),
        grid=(C, Q, num_blocks_K),
        # The third grid index is named `kb` so that it does not shadow the `k` array.
        in_specs=[
            pl.BlockSpec(q.shape, lambda i, j, kb: (0, 0)), # q_ref
            pl.BlockSpec((None, D), lambda i, j, kb: (j, 0)), # q_bar_ref
            pl.BlockSpec((None, None, block_K, D), lambda i, j, kb: (i, j, kb, 0)), # k_ref
            pl.BlockSpec((None, None, block_K, V), lambda i, j, kb: (i, j, kb, 0)), # v_ref
            pl.BlockSpec((None, None, block_K), lambda i, j, kb: (i, j, kb)), # mu_ref
            pl.BlockSpec((None, None, N), lambda i, j, kb: (i, j, 0)), # qfwd_ref
            pl.BlockSpec((None, None), lambda i, j, kb: (i, j)), # qcnt_ref
            pl.BlockSpec(lse.shape, lambda i, j, kb: (0,)), # lse_ref
            pl.BlockSpec(dlse.shape, lambda i, j, kb: (0,)), # dlse_ref
            pl.BlockSpec(dv_out.shape, lambda i, j, kb: (0, 0)), # do_ref
        ],
        out_specs=[
            pl.BlockSpec((None,) + q.shape, lambda i, j, kb: (kb, 0, 0)), # dq_ref
            pl.BlockSpec((None, None, None, D), lambda i, j, kb: (i, kb, j, 0)), # dq_bar_ref
            pl.BlockSpec((None, None, block_K, D), lambda i, j, kb: (i, j, kb, 0)), # dk_ref
            pl.BlockSpec((None, None, block_K, V), lambda i, j, kb: (i, j, kb, 0)), # dv_ref
            pl.BlockSpec((None, None, block_K), lambda i, j, kb: (i, j, kb)), # dmu_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_retrieval_backward",
    )(q, q_bar, k, v, m, shifted_qfwd, qmeta.cnt, lse, dlse_modified, dv_out)
    # Reduce the partial gradients over key blocks (dq) and over causal and key blocks (dq_bar).
    dq = jnp.sum(dqs, axis=0) # (T, D)
    dq_bar = jnp.sum(jnp.sum(dq_bars, axis=0), axis=0) # (Q, D)
    return dq, dq_bar, dk, dv, dm # shaped like q, q_bar, k, v, m





def counting_argsegment(K: int, labels):
    """Group item indices by label with a stable counting sort.

    Args:
        K: number of distinct labels; labels are compared against ``0..K-1``.
        labels: ``(N,)`` integer label of every item.

    Returns:
        ``(sizes, offsets, fwd, bwd)`` where ``sizes[c]`` is the number of
        items with label ``c``, ``offsets[c]`` is the start of label ``c`` in
        the grouped order, ``bwd[i]`` is the grouped position of item ``i``
        and ``fwd`` is the inverse permutation (grouped position -> item).
    """
    num_items = labels.shape[0]
    label_ids = jnp.arange(K)[None, :]

    # running_counts[i, c] = number of items with label c among items 0..i.
    running_counts = jnp.cumsum(labels[:, None] == label_ids, axis=0)  # (N, K)
    sizes = running_counts[-1, :]  # (K,)
    offsets = jnp.cumsum(sizes) - sizes  # (K,) exclusive prefix sum

    # Slot item i would take under each label; keep only the column of its own label.
    slot_table = running_counts - 1 + offsets[None, :]  # (N, K)
    bwd = jnp.take_along_axis(slot_table, labels[:, None], axis=1).squeeze(-1)  # (N,)

    # Invert the permutation by scattering every item index into its slot.
    fwd = jnp.zeros_like(bwd).at[bwd].set(jnp.arange(num_items))  # (N,)
    return sizes, offsets, fwd, bwd


#def leaf_preprocess_ret(qmeta, kmeta, ret_idx):
def leaf_preprocess_ret(K: int, ret_idx):
    """Sort all (query, retrieval-slot) pairs by the cluster they retrieved.

    Args:
        K: number of clusters; used as the length of the histogram.
        ret_idx: ``(N, R)`` integer array, ``ret_idx[i, r]`` being the cluster
            retrieved by query ``i`` in slot ``r``.

    Returns:
        ``(sorted_qfwd, sorted_ridx, sizes, offsets)``: the row and column of
        every entry of ``ret_idx`` listed in ascending order of its value,
        the per-cluster entry counts ``(K,)`` and their exclusive prefix sums
        ``(K,)``.
    """
    _, R = ret_idx.shape
    cluster_of_entry = ret_idx.flatten()  # (N * R,)

    # Histogram of retrieved clusters and the start of each cluster's run.
    sizes = jnp.bincount(cluster_of_entry, length=K)  # (K,)
    leading_zero = jnp.array([0])
    offsets = jnp.concatenate([leading_zero, jnp.cumsum(sizes[:-1])])  # (K,)

    # A flat position p corresponds to entry (p // R, p % R) of ret_idx.
    order = jnp.argsort(cluster_of_entry)  # (N * R,)
    sorted_qfwd, sorted_ridx = order // R, order % R
    return sorted_qfwd, sorted_ridx, sizes, offsets
def spatial_preprocess_ret(K: int, Q: int, ret_idx, qlabels, valid_idx):
    """Sort (query, retrieval-slot) pairs by the combined (query label, retrieved key) id.

    Every entry gets the id ``qlabels[t] * K + ret_idx[t, r]``. Entries whose
    ``valid_idx`` is false get the id ``Q * K`` instead, which lies outside
    the histogram of length ``Q * K``.

    Args:
        K: multiplier applied to the query label when forming the combined id.
        Q: number of query labels; ``Q * K`` is the histogram length.
        ret_idx: ``(T, R)`` integer array of retrieved indices.
        qlabels: ``(T,)`` integer label of every query.
        valid_idx: boolean mask selecting the entries of ``ret_idx`` to keep.

    Returns:
        ``(sorted_qfwd, sorted_ridx, sizes, offsets)``: the row and column of
        every entry listed in ascending order of its combined id, the
        per-id counts ``(Q * K,)`` and their exclusive prefix sums.
    """
    _, R = ret_idx.shape
    num_ids = Q * K

    combined = qlabels[:, None] * K + ret_idx  # (T, R)
    flat_ids = jnp.where(valid_idx, combined, num_ids).flatten()  # (T * R,)

    sizes = jnp.bincount(flat_ids, length=num_ids)  # (Q * K,)
    leading_zero = jnp.array([0])
    offsets = jnp.concatenate([leading_zero, jnp.cumsum(sizes[:-1])])  # (Q * K,)

    # A flat position p corresponds to entry (p // R, p % R) of ret_idx.
    order = jnp.argsort(flat_ids)  # (T * R,)
    sorted_qfwd, sorted_ridx = order // R, order % R
    return sorted_qfwd, sorted_ridx, sizes, offsets

def ref_leaf_retrieval(q, q_bar, k, v, qmeta, kmeta, ret_idx, m_in, l_in, v_in):
    """Dense reference for attending to the keys of the retrieved clusters.

    For every query a boolean mask over all keys is built from the members of
    the clusters listed in ``ret_idx`` (``kmeta.fwd`` rows, negative entries
    being padding). The masked scores are merged into the incoming
    unnormalised softmax state ``(m_in, l_in, v_in)``, which is kept in base-2
    units, and the merged state is normalised.

    Args:
        q: ``(T, D)`` queries.
        q_bar: unused; kept for signature compatibility.
        k: ``(S, D)`` keys.
        v: ``(S, V)`` values.
        qmeta: unused; kept for signature compatibility.
        kmeta: object exposing ``fwd`` of shape ``(K, U)`` holding key indices.
        ret_idx: ``(T, R)`` retrieved cluster ids per query.
        m_in: ``(T,)`` incoming running max (base-2 units).
        l_in: ``(T,)`` incoming softmax denominator.
        v_in: ``(T, V)`` incoming unnormalised weighted value sum.

    Returns:
        ``(lse_out, v_out)``: natural-log log-sum-exp ``(T,)`` and the
        normalised output ``(T, V)``.
    """
    T, R = ret_idx.shape
    S, _ = v.shape
    _, D = q.shape
    sm_scale = 1.0 / math.sqrt(D)
    qk_scale = math.log2(math.e) * sm_scale # base-2 scores so that exp2 can be used
    attn_scores = einsum("td,sd->ts", q, k) * qk_scale # (T, S)

    # Mark, for every query, the keys that belong to one of its retrieved clusters.
    k_mask = kmeta.fwd >= 0 # (K, U)
    rows = jnp.arange(T)[:, None] # (T, 1)
    attn_mask = jnp.zeros((T, S), dtype=bool) # (T, S)
    for r in range(R):
        ret_idx_r = ret_idx[:, r] # (T,)
        # Padding slots scatter False, and max with False leaves the mask unchanged.
        attn_mask = attn_mask.at[rows, kmeta.fwd[ret_idx_r, :]].max(k_mask[ret_idx_r, :])
    attn_scores = jnp.where(attn_mask, attn_scores, DEFAULT_MASK_VALUE)

    # Merge the retrieved keys into the incoming online-softmax state.
    m_curr = jnp.max(attn_scores, axis=-1) # (T,)
    m_next = jnp.maximum(m_in, m_curr) # (T,)
    correction = jnp.exp2(m_in - m_next) # (T,)
    l_in_corr = l_in * correction # (T,)
    s = jnp.exp2(attn_scores - m_next[:, None]) # (T, S)
    v_next = correction[:, None] * v_in + einsum("ts,sv->tv", s.astype(v.dtype), v) # (T, V)
    l_next = l_in_corr + jnp.sum(s, axis=-1) # (T,)

    lse_out = math.log(2) * m_next + jnp.log(l_next) # (T,) converted back to natural log
    v_out = v_next / l_next[:, None] # (T, V)
    return lse_out, v_out

def leaf_retrieval_kernel(
    q_ref, k_ref, v_ref, m_in_ref,
    kmeta_fwd_ref, kmeta_cnt_ref,
    sorted_qfwd_ref, sorted_ridx_ref, sizes_ref, offsets_ref, kidx_ref,
    m_out_ref, l_out_ref, v_out_ref,
    *, block_t: int, block_u: int
):
    """Pallas kernel: attend one run of sorted (query, slot) pairs to one key cluster.

    A program is given a cluster id (``kidx_ref``) and a contiguous range
    ``[offset, offset + size)`` of the sorted pair lists
    (``sorted_qfwd_ref`` / ``sorted_ridx_ref``). The range is processed in
    tiles of ``block_t`` pairs; for every tile the cluster's key indices
    (row ``kidx`` of ``kmeta_fwd_ref``) are walked in tiles of ``block_u`` and
    an online softmax in base-2 units is accumulated.

    The running max starts from ``m_in_ref[qfwd]`` while the denominator and
    the weighted value sum start at zero, so the values written to
    ``m_out_ref`` / ``l_out_ref`` / ``v_out_ref`` at ``[qfwd, ridx]`` are this
    cluster's unnormalised contribution relative to
    ``max(m_in, cluster max)``.

    Args:
        block_t: tile size over the (query, slot) pairs.
        block_u: tile size over the keys of the cluster.
    """
    S, V = v_ref.shape # (S, V)
    _, D = q_ref.shape # (T, D)
    K, U = kmeta_fwd_ref.shape # (K, U)
    qk_scale = math.log2(math.e) * (1.0 / math.sqrt(D)) # log2(e) folded in so exp2 can be used

    kidx = kidx_ref[...] # () cluster handled by this program
    kcnt = kmeta_cnt_ref[kidx] # () entry of kmeta_cnt_ref for that cluster; bounds the key loop below
    t_size = sizes_ref[...] # () number of pairs in this program's range
    t_offset = offsets_ref[...] # () start of the range in the sorted lists

    # Cache / eviction hints for loads and stores; None keeps the defaults.
    STORE_EVICTION_POLICY = None
    QLOAD_CACHE_MODIFIER = None
    KLOAD_CACHE_MODIFIER = None
    # Override defaults to increase L1 hit rate
    #STORE_EVICTION_POLICY = "evict_first"
    #QLOAD_CACHE_MODIFIER = "cg"
    #KLOAD_CACHE_MODIFIER = "ca"

    def t_body(start_t, _):
        base_t = start_t * block_t + t_offset
        curr_t_slice = pl.dslice(base_t, block_t)
        # Lanes past the end of this program's range are disabled.
        mask_t = jnp.arange(block_t) + base_t < t_offset + t_size
        qfwd = pl.load(sorted_qfwd_ref, (curr_t_slice,), mask=mask_t, other=0, cache_modifier=QLOAD_CACHE_MODIFIER) # (t,) query row
        ridx = pl.load(sorted_ridx_ref, (curr_t_slice,), mask=mask_t, other=0, cache_modifier=QLOAD_CACHE_MODIFIER) # (t,) retrieval slot
        # The running max is seeded with the incoming max of each query.
        m = pl.load(m_in_ref, (qfwd,), mask=mask_t, other=DEFAULT_MASK_VALUE, cache_modifier=QLOAD_CACHE_MODIFIER) # (t,)
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask_t[:,None], other=0.0, cache_modifier=QLOAD_CACHE_MODIFIER) # (t, D)

        def u_body(start_u, carry):
            ov_prev, m_prev, l_prev = carry
            curr_k_slice = pl.dslice(start_u*block_u, block_u)
            kfwd = kmeta_fwd_ref[kidx, curr_k_slice] # (u,) key row indices of this cluster
            mask_k = kfwd >= 0 # negative index == padding slot

            k = pl.load(k_ref, (kfwd, slice(None)), mask=mask_k[:,None], other=0.0, cache_modifier=KLOAD_CACHE_MODIFIER) # (u, D)
            v = pl.load(v_ref, (kfwd, slice(None)), mask=mask_k[:,None], other=0.0, cache_modifier=KLOAD_CACHE_MODIFIER) # (u, V)

            attn_scores = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale # (t, u)
            attn_scores = jnp.where(mask_k[None, :], attn_scores, DEFAULT_MASK_VALUE)
            # Online softmax update (base-2).
            m_curr = jnp.max(attn_scores, axis=-1)
            m_next = jnp.maximum(m_prev, m_curr)
            correction = jnp.exp2(m_prev - m_next) # rescales the previous state to the new max
            l_prev_corr = l_prev * correction
            s_curr = jnp.exp2(attn_scores - m_next[:, None])
            l_curr = jnp.sum(s_curr, axis=-1)
            l_next = l_prev_corr + l_curr
            ov_prev_corr = ov_prev * correction[:, None]
            ov_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=True)
            ov_next = ov_prev_corr + ov_curr
            return ov_next, m_next, l_next

        ov = jnp.zeros((block_t, V), dtype=jnp.float32) # (t, V)
        l = jnp.zeros((block_t,), dtype=jnp.float32) # (t,)

        # Only the first ceil(kcnt / block_u) key tiles of the cluster are visited.
        u_lower_bound = 0
        u_upper_bound = pl.cdiv(kcnt, block_u)
        ov, m, l = lax.fori_loop(u_lower_bound, u_upper_bound, u_body, (ov, m, l)) # (t, V), (t,), (t,)

        # Each (query, slot) pair owns one output cell.
        pl.store(m_out_ref, (qfwd, ridx), m.astype(m_out_ref.dtype), mask=mask_t, eviction_policy=STORE_EVICTION_POLICY)
        pl.store(l_out_ref, (qfwd, ridx), l.astype(l_out_ref.dtype), mask=mask_t, eviction_policy=STORE_EVICTION_POLICY)
        pl.store(v_out_ref, (qfwd, ridx, slice(None)), ov.astype(v_out_ref.dtype), mask=mask_t[:,None], eviction_policy=STORE_EVICTION_POLICY)
    
    t_lower_bound = 0
    t_upper_bound = pl.cdiv(t_size, block_t)
    jax.lax.fori_loop(t_lower_bound, t_upper_bound, t_body, None)


def as_subsegments(N: int, B: int, sizes, offsets):
    """Cut every segment into pieces of at most ``B`` items, with a static piece count.

    Segment ``s`` with ``sizes[s]`` items is turned into ``sizes[s] // B``
    full pieces followed by one remainder piece (possibly empty). The result
    always has ``N // B + K`` entries; entries beyond the pieces actually
    produced have size 0.

    Args:
        N: total number of items; offsets are clamped to ``N - 1``.
        B: maximum piece size.
        sizes: ``(K,)`` integer size of every segment.
        offsets: ``(K,)`` integer start of every segment.

    Returns:
        ``(subseg_sizes, subseg_offsets, subseg_to_seg)``, each of length
        ``N // B + K``: size and start of every piece and the segment it
        belongs to.
    """
    K, = sizes.shape
    max_total_subsegs = N // B + K

    full_subsegs = sizes // B  # (K,) number of full pieces per segment
    pieces_per_seg = full_subsegs + 1  # full pieces plus the remainder piece
    first_piece = jnp.cumsum(pieces_per_seg) - full_subsegs - 1  # (K,) index of each segment's first piece

    # Expand per-segment data to one entry per piece (padded up to the static length).
    subseg_to_seg = jnp.repeat(jnp.arange(K), pieces_per_seg, total_repeat_length=max_total_subsegs)
    first_piece_of = jnp.repeat(first_piece, pieces_per_seg, total_repeat_length=max_total_subsegs)
    subseg_local_idx = jnp.arange(max_total_subsegs) - first_piece_of  # position of the piece inside its segment

    full_of_piece = full_subsegs[subseg_to_seg]
    remainder = sizes[subseg_to_seg] % B
    # Full pieces hold B items, the last piece holds the remainder, padding entries hold nothing.
    subseg_sizes = jnp.where(subseg_local_idx == full_of_piece, remainder, B)
    subseg_sizes = jnp.where(subseg_local_idx > full_of_piece, 0, subseg_sizes)

    subseg_offsets = jnp.minimum(offsets[subseg_to_seg] + subseg_local_idx * B, N-1)
    return subseg_sizes, subseg_offsets, subseg_to_seg
    


def leaf_retrieval_nop(q, q_bar, k, v, qmeta, kmeta, ret_idx, m_in, l_in, v_in):
    """Normalise the incoming softmax state without performing any retrieval.

    Only ``m_in`` (running max in base-2 units), ``l_in`` (denominator) and
    ``v_in`` (unnormalised weighted value sum) are used; the remaining
    arguments are ignored.

    Returns:
        ``(lse, v_out)``: natural-log log-sum-exp ``(N,)`` and the normalised
        output ``(N, V)``.
    """
    denominator = l_in[:, None]  # (N, 1)
    v_out = v_in / denominator  # (N, V)
    lse = math.log(2) * m_in + jnp.log(l_in)  # (N,) base-2 max converted to natural log
    return lse, v_out
    

def ref_leaf_retrieval_postprocess(m_in, l_in, v_in, m_ret, l_ret, v_ret):
    """Merge per-retrieval softmax states into the incoming state and normalise.

    All maxima are in base-2 units. The common reference is the largest
    ``m_ret`` of each query; ``m_in`` itself is not included in that maximum.

    Args:
        m_in: ``(N,)`` incoming running max.
        l_in: ``(N,)`` incoming denominator.
        v_in: ``(N, V)`` incoming unnormalised weighted value sum.
        m_ret: ``(N, R)`` running max of every retrieval.
        l_ret: ``(N, R)`` denominator of every retrieval.
        v_ret: ``(N, R, V)`` unnormalised weighted value sum of every retrieval.

    Returns:
        ``(lse_out, v_out)``: natural-log log-sum-exp ``(N,)`` and the
        normalised output ``(N, V)``.
    """
    eps = 1e-30
    m_out = jnp.max(m_ret, axis=1)  # (N,)

    # Factors that rescale every partial state to the common reference m_out.
    scale_in = jnp.exp2(m_in - m_out)  # (N,)
    scale_ret = jnp.exp2(m_ret - m_out[:, None])  # (N, R)

    l_total = l_in * scale_in + jnp.sum(l_ret * scale_ret, axis=1)  # (N,)
    v_ret_scaled = v_ret.astype(scale_ret.dtype) * scale_ret[:, :, None]  # (N, R, V)
    v_total = v_in * scale_in[:, None] + jnp.sum(v_ret_scaled, axis=1)  # (N, V)

    # The division is guarded against a zero denominator; the logarithm is not.
    v_out = v_total / jnp.maximum(eps, l_total[:, None])  # (N, V)
    lse_out = math.log(2) * m_out + jnp.log(l_total)  # (N,)
    return lse_out, v_out
def ref_spatial_retrieval_postprocess(m_in, l_in, v_in, m_ret, l_ret, v_ret):
    """Merge per-retrieval softmax states into the incoming state without normalising.

    All maxima are in base-2 units. The common reference is the largest
    ``m_ret`` of each query, floored at ``DEFAULT_MASK_VALUE``; ``m_in``
    itself is not included in that maximum.

    Args:
        m_in: ``(N,)`` incoming running max.
        l_in: ``(N,)`` incoming denominator.
        v_in: ``(N, V)`` incoming unnormalised weighted value sum.
        m_ret: ``(N, R)`` running max of every retrieval.
        l_ret: ``(N, R)`` denominator of every retrieval.
        v_ret: ``(N, R, V)`` unnormalised weighted value sum of every retrieval.

    Returns:
        ``(m_total, l_total, v_total)``: the merged running max ``(N,)``,
        denominator ``(N,)`` and unnormalised weighted value sum ``(N, V)``.
    """
    m_total = jnp.maximum(jnp.max(m_ret, axis=1), DEFAULT_MASK_VALUE)  # (N,)

    # Factors that rescale every partial state to the common reference m_total.
    scale_in = jnp.exp2(m_in - m_total)  # (N,)
    scale_ret = jnp.exp2(m_ret - m_total[:, None])  # (N, R)

    l_total = l_in * scale_in + jnp.sum(l_ret * scale_ret, axis=1)  # (N,)
    v_ret_scaled = v_ret.astype(scale_ret.dtype) * scale_ret[:, :, None]  # (N, R, V)
    v_total = v_in * scale_in[:, None] + jnp.sum(v_ret_scaled, axis=1)  # (N, V)
    return m_total, l_total, v_total

def leaf_retrieval(q, q_bar, k, v, qmeta, kmeta, ret_idx, m_in, l_in, v_in):
    """Attend every query to the keys of its retrieved clusters and merge the result.

    The (query, slot) pairs of ``ret_idx`` are sorted by retrieved cluster,
    each cluster's run is cut into pieces of at most ``4 * block_t`` pairs,
    and ``leaf_retrieval_kernel`` is launched with one program per piece. The
    per-pair partial states are then merged with ``(m_in, l_in, v_in)`` by
    ``ref_leaf_retrieval_postprocess``.

    Args:
        q: two-dimensional query array, passed to the kernel unsliced.
        q_bar: ``(Q, D)`` array; only its shape is read here.
        k: two-dimensional key array, passed to the kernel unsliced.
        v: ``(S, V)`` values, passed to the kernel unsliced.
        qmeta: unused.
        kmeta: object exposing ``fwd`` of shape ``(K, U)`` and ``cnt`` of shape ``(K,)``.
        ret_idx: ``(T, R)`` retrieved cluster ids per query.
        m_in: incoming running max, one-dimensional.
        l_in: incoming denominator.
        v_in: incoming unnormalised weighted value sum.

    Returns:
        ``(lse_out, v_out, (sorted_qfwd, sorted_ridx, sizes, offsets))``:
        the merged log-sum-exp, the merged output cast to ``v.dtype`` and the
        sorting metadata produced by ``leaf_preprocess_ret``.
    """
    T, R = ret_idx.shape
    Q, D = q_bar.shape
    K, U = kmeta.fwd.shape
    S, V = v.shape

    block_u = min(U, 64)
    assert U % block_u == 0, f"U ({U}) must be divisible by block_u ({block_u})"
    block_t = 128
    assert (T * R) % block_t == 0, f"N * R ({T*R}) must be divisible by block_n ({block_t})"

    # Group the (query, slot) pairs by cluster, then bound the work of a single program.
    sort_meta = leaf_preprocess_ret(K, ret_idx)
    sorted_qfwd, sorted_ridx, sizes, offsets = sort_meta
    subseg_sizes, subseg_offsets, subseg_labels = as_subsegments(T * R, block_t*4, sizes, offsets)

    kernel = partial(leaf_retrieval_kernel, block_t=block_t, block_u=block_u)

    result_shapes = [
        jax.ShapeDtypeStruct((T, R), jnp.float32),  # m_out
        jax.ShapeDtypeStruct((T, R), jnp.float32),  # l_out
        jax.ShapeDtypeStruct((T, R, V), v.dtype),  # v_out
    ]
    operand_specs = [
        pl.BlockSpec(q.shape, lambda i: (0, 0)),  # q (whole array)
        pl.BlockSpec(k.shape, lambda i: (0, 0)),  # k (whole array)
        pl.BlockSpec(v.shape, lambda i: (0, 0)),  # v (whole array)
        pl.BlockSpec(m_in.shape, lambda i: (0,)),  # m_in (whole array)
        pl.BlockSpec((K, U), lambda i: (0, 0)),  # kmeta.fwd (whole array)
        pl.BlockSpec((K,), lambda i: (0,)),  # kmeta.cnt (whole array)
        pl.BlockSpec(sorted_qfwd.shape, lambda i: (0,)),  # sorted_qfwd (whole array)
        pl.BlockSpec(sorted_ridx.shape, lambda i: (0,)),  # sorted_ridx (whole array)
        pl.BlockSpec((None,), lambda i: (i,)),  # piece size
        pl.BlockSpec((None,), lambda i: (i,)),  # piece offset
        pl.BlockSpec((None,), lambda i: (i,)),  # cluster of the piece
    ]
    result_specs = [
        pl.BlockSpec((T, R), lambda i: (0, 0)),  # m_out
        pl.BlockSpec((T, R), lambda i: (0, 0)),  # l_out
        pl.BlockSpec((T, R, V), lambda i: (0, 0, 0)),  # v_out
    ]

    launch = pl.pallas_call(
        kernel,
        out_shape=result_shapes,
        grid=(len(subseg_labels),),
        in_specs=operand_specs,
        out_specs=result_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="leaf_retrieval_forward",
    )
    m_ret, l_ret, v_ret = launch(
        q, k, v, m_in, kmeta.fwd, kmeta.cnt,
        sorted_qfwd, sorted_ridx, subseg_sizes, subseg_offsets, subseg_labels,
    )

    lse_out, v_out = ref_leaf_retrieval_postprocess(m_in, l_in, v_in, m_ret, l_ret, v_ret)
    return lse_out, v_out.astype(v.dtype), (sorted_qfwd, sorted_ridx, sizes, offsets)

def leaf_retrieval_causal(causal_block: int, q, q_bar, k, v, qmeta, kmeta, ret_idx, blk_idx, m_in, l_in, v_in):
    """Leaf retrieval over (block, cluster) pairs, merged with the incoming partial results.

    Every retrieved cluster id in ``ret_idx`` is paired with each of the ``P``
    block ids in ``blk_idx``. The pair is encoded as ``ret_idx + blk_idx * K`` so
    that ``leaf_retrieval_kernel`` can run against key metadata flattened to
    ``C*K`` rows.

    Args:
        causal_block: multiplied by the block index ``c`` and added to the
            non-negative entries of ``kmeta.fwd[c]`` before they are handed to
            the kernel.
        q, k, v: passed to the kernel as whole-array blocks; ``v`` is unpacked
            as ``(S, V)``.
        q_bar: only unpacked for its 2-D shape.
        qmeta: not read by this function.
        kmeta: ``kmeta.fwd`` is unpacked as ``(C, K, U)``; ``kmeta.cnt`` is
            flattened from ``(C, K)``.
        ret_idx: 2-D; repeated ``P`` times along axis 1 so it lines up with
            the flattened ``blk_idx``.
        blk_idx: ``(T, _R, P)``.
        m_in, l_in, v_in: incoming partial results, forwarded to
            ``ref_leaf_retrieval_postprocess``. ``m_in`` is also a kernel input
            and the fallback value for invalid slots.

    Returns:
        ``(lse_out, v_out, retrieval_meta)`` where ``v_out`` is cast to
        ``v.dtype`` and ``retrieval_meta`` is
        ``(sorted_qfwd, sorted_ridx, sizes, offsets, valid_idx)``.
    """
    T, _R, P = blk_idx.shape
    # Row-major flatten of (_R, P): P varies fastest.
    blk_idx = blk_idx.reshape(T, _R*P)
    #ret_idx = jnp.tile(ret_idx, (1, P))
    # repeat (not tile) keeps each ret_idx entry adjacent to its own P block ids,
    # matching the flattening of blk_idx above.
    ret_idx = jnp.repeat(ret_idx, P, axis=1)
    assert ret_idx.shape == blk_idx.shape

    # From here on R is the expanded retrieval count (_R * P).
    T, R = ret_idx.shape
    Q, D = q_bar.shape
    C, K, U = kmeta.fwd.shape
    S, V = v.shape

    block_u = min(U, 64)
    assert U % block_u == 0, f"U ({U}) must be divisible by block_u ({block_u})"
    block_t = 128
    # NOTE(review): the message below says "N * R" / "block_n" but the checked
    # quantities are T * R and block_t.
    assert (T * R) % block_t == 0, f"N * R ({T*R}) must be divisible by block_n ({block_t})"

    # Combined id of the (block, cluster) pair; block index is the major digit.
    blkret_idx = ret_idx + blk_idx * K
    valid_idx = (ret_idx >= 0) & (ret_idx < K) & (blk_idx >= 0) & (blk_idx < C)
    # Invalid pairs get an id outside [0, K*C).
    # presumably leaf_preprocess_ret drops out-of-range ids — TODO confirm
    blkret_idx = jnp.where(valid_idx, blkret_idx, K*C+1)
    #sorted_qfwd, sorted_ridx, sizes, offsets = leaf_preprocess_ret(K, ret_idx)
    sorted_qfwd, sorted_ridx, sizes, offsets = leaf_preprocess_ret(K*C, blkret_idx)
    subseg_sizes, subseg_offsets, subseg_labels = as_subsegments(T * R, block_t*4, sizes, offsets)
    #subseg_sizes, subseg_offsets, subseg_labels = sizes, offsets, jnp.arange(K, dtype=jnp.int32)
    # Offset the entries of block c by causal_block * c; negative (padding)
    # entries are reset to -1 so they stay recognisable as padding.
    shifted_kfwd = kmeta.fwd + causal_block*jnp.arange(C)[:, None, None] # (C, K, n)
    shifted_kfwd = jnp.where(kmeta.fwd < 0, -1, shifted_kfwd)
    # Flatten (C, K) -> C*K with c major, matching the blkret_idx encoding.
    flat_kfwd = einshape("cku->(ck)u", shifted_kfwd, c=C, k=K, u=U)
    flat_kcnt = einshape("ck->(ck)", kmeta.cnt, c=C, k=K)
    m_ret, l_ret, v_ret = pl.pallas_call(
        partial(leaf_retrieval_kernel, block_t=block_t, block_u=block_u),
        out_shape=[
            jax.ShapeDtypeStruct((T, R), jnp.float32), # m_out
            jax.ShapeDtypeStruct((T, R), jnp.float32), # l_out
            jax.ShapeDtypeStruct((T, R, V), v.dtype), # v_out
        ],
        # One grid program per sub-segment.
        grid=(len(subseg_labels),),
        in_specs=[
            pl.BlockSpec(q.shape, lambda i: (0, 0)), # q
            pl.BlockSpec(k.shape, lambda i: (0, 0)), # k
            pl.BlockSpec(v.shape, lambda i: (0, 0)), # v
            pl.BlockSpec(m_in.shape, lambda i: (0,)), # m_in
            pl.BlockSpec((C*K, U), lambda i: (0, 0)), # kmeta.fwd
            pl.BlockSpec((C*K,), lambda i: (0,)), # kmeta.cnt
            pl.BlockSpec(sorted_qfwd.shape, lambda i: (0,)), # sorted_qfwd
            pl.BlockSpec(sorted_ridx.shape, lambda i: (0,)), # sorted_ridx
            pl.BlockSpec((None,), lambda i: (i,)), # subseg_sizes
            pl.BlockSpec((None,), lambda i: (i,)), # subseg_offsets
            pl.BlockSpec((None,), lambda i: (i,)), # subseg_labels
        ],
        out_specs=[
            pl.BlockSpec((T, R), lambda i: (0, 0)), # m_out
            pl.BlockSpec((T, R), lambda i: (0, 0)), # l_out
            pl.BlockSpec((T, R, V), lambda i: (0, 0, 0)), # v_out
        ],
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="leaf_retrieval_forward",
    #)(q, k, v, m_in, kmeta.fwd, kmeta.cnt, sorted_qfwd, sorted_ridx, sizes, offsets, labels)
    )(q, k, v, m_in, flat_kfwd, flat_kcnt, sorted_qfwd, sorted_ridx, subseg_sizes, subseg_offsets, subseg_labels)
    #total_size = jnp.sum(sizes)
    #covered = jnp.zeros_like(ret_idx, dtype=bool).at[sorted_qfwd, sorted_ridx].set(jnp.arange(T * R) < total_size)
    #valid_idx = (ret_idx >= 0) & (ret_idx < K) & (blk_idx >= 0) & (blk_idx < C)
    #covered_with_patch = jnp.where(valid_idx, covered, True)
    #jax.debug.print("total_size: {} covered: {}, covered_with_patch: {}", total_size, jnp.sum(covered), jnp.sum(covered_with_patch))
    #first_uncovered = jnp.argmin(covered_with_patch.reshape(-1))
    #jax.debug.print("first_uncovered idx: {}", first_uncovered)
    #jax.debug.print("sizes: {} offsets: {} labels: {} invalid_count: {}", sizes, offsets, subseg_labels, jnp.sum(~valid_idx))
    #valid_idx = (ret_idx >= 0) & (ret_idx < K) & (blk_idx >= 0) & (blk_idx < C)
    # Neutralise the slots of invalid pairs: running max falls back to m_in,
    # and the normaliser / value contributions are zeroed.
    m_ret = jnp.where(valid_idx, m_ret, m_in[:, None])
    l_ret = jnp.where(valid_idx, l_ret, 0.0)
    v_ret = jnp.where(valid_idx[:, :, None], v_ret, 0.0)

    lse_out, v_out = ref_leaf_retrieval_postprocess(m_in, l_in, v_in, m_ret, l_ret, v_ret)
    # valid_idx is appended to the metadata; leaf_retrieval_causal_bwd unpacks it.
    return lse_out, v_out.astype(v.dtype), (sorted_qfwd, sorted_ridx, sizes, offsets, valid_idx)

def leaf_retrieval_bwd_kernel(
    q_ref, k_ref, v_ref,
    kmeta_fwd_ref, kmeta_cnt_ref,
    sorted_qfwd_ref, sorted_ridx_ref, sizes_ref, offsets_ref, kidx_ref,
    lse_ref, dlse_ref, do_ref,
    dq_ref, dk_ref, dv_ref,
    *, block_t: int, block_u: int
):
    """Pallas kernel: backward pass of leaf retrieval for one key cluster.

    The program reads its cluster id from ``kidx_ref`` and its slice
    ``[offset, offset + size)`` of ``sorted_qfwd_ref`` / ``sorted_ridx_ref``
    from ``offsets_ref`` / ``sizes_ref``. For every ``block_u`` tile of the
    cluster's key rows it walks that slice in ``block_t`` tiles, recomputes the
    probabilities ``exp2(q.k * qk_scale - lse * log2(e))`` and accumulates:

    * ``dv += s^T @ do`` and ``dk += (ds * s)^T @ q * sm_scale`` in registers,
      stored to ``dv_ref`` / ``dk_ref`` at the key rows of the tile;
    * ``dq[qfwd, ridx, :] += (ds * s) @ k * sm_scale`` directly in ``dq_ref``,
      where ``ds = do @ v^T + dlse``.

    ``lse_ref`` is multiplied by ``log2(e)`` before use, i.e. it is treated as
    a natural-log quantity. The wrappers in this file pass
    ``dlse - sum(dv_out * v_out, axis=-1)`` as ``dlse_ref``.

    Keyword Args:
        block_t: tile size along the sorted (query, slot) slice.
        block_u: tile size along the per-cluster key list.
    """
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])
    # exp2 is used instead of exp, so logits and lse are converted to base 2.
    qk_scale = math.log2(math.e) * sm_scale
    lse_scale = math.log2(math.e)

    kidx = kidx_ref[...] # ()
    kcnt = kmeta_cnt_ref[kidx] # ()
    t_size = sizes_ref[...] # ()
    t_offset = offsets_ref[...] # ()

    t_lower_bound = 0
    # I don't think this line below does anything - presumably it was intended to ensure dq was written to, but the masked writes mean this can't do anything to dq.
    #t_upper_bound = jnp.maximum(1,pl.cdiv(t_size, block_t))
    t_upper_bound = pl.cdiv(t_size, block_t)
    u_lower_bound = 0
    # At least one u tile is always processed, so the dk/dv stores below run
    # (with zeros when the t loop is empty) even when kcnt is 0.
    u_upper_bound = jnp.maximum(1, pl.cdiv(kcnt, block_u))
    def body_u(start_u, _):
        curr_u_slice = pl.dslice(start_u*block_u, block_u)
        kfwd = kmeta_fwd_ref[kidx, curr_u_slice] # (u,)
        # Negative entries mark padding in the key list.
        mask_k = kfwd >= 0

        k = pl.load(k_ref, (kfwd, slice(None)), mask=mask_k[:,None], other=0.0) # (u, D)
        v = pl.load(v_ref, (kfwd, slice(None)), mask=mask_k[:,None], other=0.0) # (u, V)

        def body_t(start_t, carry):
            dk_prev, dv_prev = carry

            base_t = start_t * block_t + t_offset
            curr_t_slice = pl.dslice(base_t, block_t)
            # Lanes past the end of this program's slice are disabled.
            mask_t = jnp.arange(block_t) + base_t < t_offset + t_size

            qfwd = pl.load(sorted_qfwd_ref, (curr_t_slice,), mask=mask_t, other=-1)
            ridx = pl.load(sorted_ridx_ref, (curr_t_slice,), mask=mask_t, other=0)
            # Disabled lanes were loaded as -1, so this also covers the bound
            # check above (and any negative entries stored in the array).
            mask_t = qfwd >= 0

            q = pl.load(q_ref, (qfwd, slice(None)), mask=mask_t[:,None], other=0.0) # (t, D)
            lse = pl.load(lse_ref, (qfwd), mask=mask_t, other=0.0) # (t,)
            do = pl.load(do_ref, (qfwd, slice(None)), mask=mask_t[:,None], other=0.0)
            dlse = pl.load(dlse_ref, (qfwd,), mask=mask_t, other=0.0)

            # Recompute normalised probabilities from the saved lse.
            qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale # (t, u)
            qk = qk - lse[:,None] * lse_scale
            qk = jnp.where(mask_k[None, :], qk, DEFAULT_MASK_VALUE)
            qk = jnp.where(mask_t[:, None], qk, DEFAULT_MASK_VALUE)
            s = jnp.exp2(qk)
            # Force exact zeros on masked rows/columns.
            s = jnp.where(mask_k[None, :], s, 0.0)
            s = jnp.where(mask_t[:, None], s, 0.0)

            dv = pl.dot(s.astype(do.dtype), do, trans_a=True, allow_tf32=True) # (u, V)
            dv_next = dv_prev + dv # (u, V)
            gv = pl.dot(do.astype(v.dtype), v, trans_b=True, allow_tf32=True) # (t, u)
            ds = gv + dlse[:, None]
            dqk = ds * s # (t, u)

            dk_curr = pl.dot(dqk.astype(q.dtype), q, trans_a=True, allow_tf32=True) * sm_scale # (u, D)
            dk_next = dk_prev + dk_curr # (u, D)
            dq_curr = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale # (t, D)
            # dq is accumulated across u tiles in the output buffer itself: the
            # first u tile (start_u == 0) starts from zero instead of reading
            # dq_ref, later tiles read back what the previous tile stored.
            dq_prev = pl.load(dq_ref, (qfwd, ridx, slice(None)), mask=mask_t[:,None] & (start_u > 0), other=0.0)
            dq_next = dq_prev.astype(dq_curr.dtype) + dq_curr
            pl.store(dq_ref, (qfwd, ridx, slice(None)), dq_next.astype(dq_ref.dtype), mask=mask_t[:,None])
            return dk_next, dv_next

        dk = jnp.zeros((block_u, k_ref.shape[1]), dtype=jnp.float32) # (u, D)
        dv = jnp.zeros((block_u, v_ref.shape[1]), dtype=jnp.float32) # (u, V)
        dk, dv = lax.fori_loop(t_lower_bound, t_upper_bound, body_t, (dk, dv))
        # Plain (non-accumulating) stores at the key rows of this tile.
        # assumes each key row appears in only one cluster's list — TODO confirm
        pl.store(dk_ref, (kfwd, slice(None)), dk.astype(dk_ref.dtype), mask=mask_k[:,None])
        pl.store(dv_ref, (kfwd, slice(None)), dv.astype(dv_ref.dtype), mask=mask_k[:,None])

    jax.lax.fori_loop(u_lower_bound, u_upper_bound, body_u, None)

def leaf_retrieval_bwd(
    q, q_bar, k, v, qmeta, kmeta, ret_idx,
    lse, v_out,
    dlse, dv_out,
    retrieval_meta,
):
    """Backward pass of leaf retrieval: launch ``leaf_retrieval_bwd_kernel`` once per cluster.

    Args:
        q, k, v: forward inputs, passed to the kernel as whole-array blocks.
        q_bar, qmeta: not read by this function.
        kmeta: ``kmeta.fwd`` is unpacked as ``(K, U)``; ``kmeta.cnt`` is passed
            along as the per-cluster count.
        ret_idx: only its ``(T, R)`` shape is used.
        lse, v_out: forward outputs.
        dlse, dv_out: cotangents of ``lse`` and ``v_out``.
        retrieval_meta: ``(sorted_qfwd, sorted_ridx, sizes, offsets)`` as
            returned alongside the forward outputs.

    Returns:
        ``(dq, dk, dv)``; ``dq`` is the per-slot ``(T, R, D)`` kernel output
        summed over the ``R`` axis.
    """
    K, U = kmeta.fwd.shape
    S, V = v.shape
    T, R = ret_idx.shape
    _, D = q.shape

    # Fold the "-sum(dO * O)" softmax-backward term into the lse cotangent so
    # the kernel only has to add one per-row scalar.
    do_dot_o = jnp.sum(dv_out * v_out, axis=-1)
    dlse_adjusted = dlse - do_dot_o.astype(dlse.dtype) # (T,)

    # One grid program per cluster: segment i of the sorted lists belongs to
    # cluster i.
    sorted_qfwd, sorted_ridx, seg_sizes, seg_offsets = retrieval_meta
    cluster_ids = jnp.arange(K, dtype=jnp.int32)

    block_t = min(T * R, 64)
    assert (T * R) % block_t == 0, f"T * R ({T*R}) must be divisible by block_t ({block_t})"
    block_u = min(U, 64)
    assert U % block_u == 0, f"U ({U}) must be divisible by block_u ({block_u})"

    # Index maps shared by the block specs below.
    def whole_1d(i):
        return (0,)

    def whole_2d(i):
        return (0, 0)

    def whole_3d(i):
        return (0, 0, 0)

    def own_entry(i):
        return (i,)

    input_specs = [
        pl.BlockSpec(q.shape, whole_2d), # q_ref
        pl.BlockSpec(k.shape, whole_2d), # k_ref
        pl.BlockSpec(v.shape, whole_2d), # v_ref
        pl.BlockSpec((K, U), whole_2d), # kmeta_fwd_ref
        pl.BlockSpec((K,), whole_1d), # kmeta_cnt_ref
        pl.BlockSpec(sorted_qfwd.shape, whole_1d), # sorted_qfwd_ref
        pl.BlockSpec(sorted_ridx.shape, whole_1d), # sorted_ridx_ref
        pl.BlockSpec((None,), own_entry), # sizes_ref
        pl.BlockSpec((None,), own_entry), # offsets_ref
        pl.BlockSpec((None,), own_entry), # kidx_ref
        pl.BlockSpec(lse.shape, whole_1d), # lse_ref
        pl.BlockSpec(dlse.shape, whole_1d), # dlse_ref
        pl.BlockSpec(dv_out.shape, whole_2d), # do_ref
    ]
    output_specs = [
        pl.BlockSpec((T, R, D), whole_3d), # dq_ref
        pl.BlockSpec((S, D), whole_2d), # dk_ref
        pl.BlockSpec((S, V), whole_2d), # dv_ref
    ]
    output_shapes = [
        jax.ShapeDtypeStruct((T, R, D), q.dtype), # dqs
        jax.ShapeDtypeStruct((S, D), k.dtype), # dk
        jax.ShapeDtypeStruct((S, V), v.dtype), # dv
    ]

    backward_call = pl.pallas_call(
        kernel=partial(leaf_retrieval_bwd_kernel, block_t=block_t, block_u=block_u),
        out_shape=output_shapes,
        grid=(len(cluster_ids),),
        in_specs=input_specs,
        out_specs=output_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="leaf_retrieval_backward",
    )
    dq_per_slot, dk, dv = backward_call(
        q, k, v, kmeta.fwd, kmeta.cnt,
        sorted_qfwd, sorted_ridx, seg_sizes, seg_offsets, cluster_ids,
        lse, dlse_adjusted, dv_out,
    )
    # Each (query, slot) pair has its own gradient row; reduce over the slots.
    dq = jnp.sum(dq_per_slot, axis=-2) # (T, D)
    return dq, dk, dv

def leaf_retrieval_causal_bwd(
    causal_block: int,
    q, q_bar, k, v, qmeta, kmeta, ret_idx, blk_idx,
    lse, v_out,
    dlse, dv_out,
    retrieval_meta,
):
    """Backward pass matching ``leaf_retrieval_causal``.

    Rebuilds the flattened ``(C*K, U)`` key metadata exactly as the forward
    pass does and launches ``leaf_retrieval_bwd_kernel`` with one grid program
    per flattened ``(block, cluster)`` id.

    Args:
        causal_block: per-block offset applied to ``kmeta.fwd`` (see below).
        q, k, v: forward inputs, passed to the kernel as whole-array blocks.
        q_bar, qmeta: not read by this function.
        kmeta: ``kmeta.fwd`` is unpacked as ``(C, K, U)``; ``kmeta.cnt`` is
            flattened from ``(C, K)``.
        ret_idx: 2-D; repeated ``P`` times along axis 1 (only its resulting
            shape is used).
        blk_idx: ``(T, _R, P)``; only its shape is used.
        lse, v_out: forward outputs.
        dlse, dv_out: cotangents of ``lse`` and ``v_out``.
        retrieval_meta: ``(sorted_qfwd, sorted_ridx, sizes, offsets, valid_idx)``
            as returned by ``leaf_retrieval_causal``.

    Returns:
        ``(dq, dk, dv)``; ``dq`` is the per-slot kernel output with invalid
        slots zeroed, summed over the slot axis.
    """
    T, _R, P = blk_idx.shape
    blk_idx = blk_idx.reshape(T, _R*P)
    #ret_idx = jnp.tile(ret_idx, (1, P))
    # Same expansion as the forward pass, so R below is _R * P.
    ret_idx = jnp.repeat(ret_idx, P, axis=1)
    assert ret_idx.shape == blk_idx.shape



    C, K, U = kmeta.fwd.shape
    S, V = v.shape
    T, R = ret_idx.shape
    _, D = q.shape


    # Fold the "-sum(dO * O)" term into the lse cotangent before the kernel.
    dlse_modified = dlse - jnp.sum(dv_out * v_out, axis=-1).astype(dlse.dtype) # (T,)

    #sorted_qfwd, sorted_ridx, sizes, offsets = leaf_preprocess_ret(qmeta, kmeta, ret_idx)
    sorted_qfwd, sorted_ridx, sizes, offsets, valid_idx = retrieval_meta
    # Unlike the forward pass (which splits segments with as_subsegments), the
    # backward pass uses one program per flattened (block, cluster) id.
    subseg_sizes, subseg_offsets, subseg_labels = sizes, offsets, jnp.arange(K*C, dtype=jnp.int32)
    
    block_t = min(T * R, 64)
    assert (T * R) % block_t == 0, f"T * R ({T*R}) must be divisible by block_t ({block_t})"
    block_u = min(U, 64)
    assert U % block_u == 0, f"U ({U}) must be divisible by block_u ({block_u})"

    # Offset the entries of block c by causal_block * c and keep padding as -1;
    # identical to the construction in leaf_retrieval_causal.
    shifted_kfwd = kmeta.fwd + causal_block*jnp.arange(C)[:, None, None] # (C, K, n)
    shifted_kfwd = jnp.where(kmeta.fwd < 0, -1, shifted_kfwd)
    flat_kfwd = einshape("cku->(ck)u", shifted_kfwd, c=C, k=K, u=U)
    flat_kcnt = einshape("ck->(ck)", kmeta.cnt, c=C, k=K)

    dqs, dk, dv = pl.pallas_call(
        kernel=partial(leaf_retrieval_bwd_kernel, block_t=block_t, block_u=block_u),
        out_shape=[
            jax.ShapeDtypeStruct((T, R, D), q.dtype), # dqs
            jax.ShapeDtypeStruct((S, D), k.dtype), # dk
            jax.ShapeDtypeStruct((S, V), v.dtype), # dv
        ],
        grid=(len(subseg_labels),),
        in_specs=[
            pl.BlockSpec(q.shape, lambda i: (0, 0)), # q_ref
            pl.BlockSpec(k.shape, lambda i: (0, 0)), # k_ref
            pl.BlockSpec(v.shape, lambda i: (0, 0)), # v_ref
            pl.BlockSpec((C*K, U), lambda i: (0, 0)), # kmeta_fwd_ref
            pl.BlockSpec((C*K,), lambda i: (0,)), # kmeta_cnt_ref
            pl.BlockSpec(sorted_qfwd.shape, lambda i: (0,)), # sorted_qfwd_ref
            pl.BlockSpec(sorted_ridx.shape, lambda i: (0,)), # sorted_ridx_ref
            pl.BlockSpec((None,), lambda i: (i,)), # sizes_ref
            pl.BlockSpec((None,), lambda i: (i,)), # offsets_ref
            pl.BlockSpec((None,), lambda i: (i,)), # kidx_ref
            pl.BlockSpec(lse.shape, lambda i: (0,)), # lse_ref
            pl.BlockSpec(dlse.shape, lambda i: (0,)), # dlse_ref
            pl.BlockSpec(dv_out.shape, lambda i: (0, 0)), # do_ref
        ],
        out_specs=[
            pl.BlockSpec((T, R, D), lambda i: (0, 0, 0)), # dq_ref
            pl.BlockSpec((S, D), lambda i: (0, 0)), # dk_ref
            pl.BlockSpec((S, V), lambda i: (0, 0)), # dv_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="leaf_retrieval_backward",
    )(q, k, v, flat_kfwd, flat_kcnt, sorted_qfwd, sorted_ridx, subseg_sizes, subseg_offsets, subseg_labels, lse, dlse_modified, dv_out)
    # Zero the slots the forward pass flagged as invalid before reducing.
    # presumably the kernel never writes those slots, leaving them undefined — TODO confirm
    dqs = jnp.where(valid_idx[:, :, None], dqs, 0.0)
    dq = jnp.sum(dqs, axis=-2) # (T, D)
    return dq, dk, dv


def leaf_retrieval_postprocess_kernel(m1_in_ref, l1_in_ref, v1_in_ref, m2_in_ref, l2_in_ref, v2_in_ref, lse_out_ref, v_out_ref):
    """Merge one running softmax partial with ``R`` retrieved partials.

    Inputs are (max, normaliser, unnormalised value) triples with the maxima
    in base-2 units: ``(m1, l1, v1)`` per row and ``(m2, l2, v2)`` per row and
    retrieval slot. Every partial is rescaled to the common row maximum and
    summed; the result is written as a normalised value and a natural-log
    log-sum-exp.
    """
    prev_max = m1_in_ref[...] # (T,)
    prev_norm = l1_in_ref[...] # (T,)
    prev_val = v1_in_ref[...] # (T, V)
    ret_max = m2_in_ref[...] # (T, R)
    ret_norm = l2_in_ref[...] # (T, R)
    ret_val = v2_in_ref[...] # (T, R, V)

    # Common reference maximum across the running and retrieved partials.
    joint_max = jnp.maximum(prev_max, jnp.max(ret_max, axis=1)) # (T,)

    # Rescale factors bringing every partial onto joint_max.
    prev_scale = jnp.exp2(prev_max - joint_max) # (T,)
    ret_scale = jnp.exp2(ret_max - joint_max[:, None]) # (T, R)

    # Unnormalised values, rescaled and summed.
    prev_val_scaled = prev_val.astype(prev_scale.dtype) * prev_scale[:, None] # (T, V)
    ret_val_scaled = ret_val.astype(ret_scale.dtype) * ret_scale[:, :, None] # (T, R, V)
    val_sum = prev_val_scaled + jnp.sum(ret_val_scaled, axis=1) # (T, V)

    # Normalisers, rescaled and summed.
    prev_norm_scaled = prev_norm * prev_scale # (T,)
    ret_norm_scaled = ret_norm * ret_scale # (T, R)
    norm_sum = prev_norm_scaled + jnp.sum(ret_norm_scaled, axis=1) # (T,)

    merged_val = val_sum / norm_sum[:, None] # (T, V)
    # joint_max is in base 2; multiplying by ln(2) converts it to natural log.
    merged_lse = math.log(2) * joint_max + jnp.log(norm_sum) # (T,)

    lse_out_ref[...] = merged_lse.astype(lse_out_ref.dtype) # (T,)
    v_out_ref[...] = merged_val.astype(v_out_ref.dtype) # (T, V)

def leaf_retrieval_postprocess(m1_in, l1_in, v1_in, m2_in, l2_in, v2_in, target_dtype):
    """Launch ``leaf_retrieval_postprocess_kernel`` over row tiles of 64.

    Args:
        m1_in, l1_in, v1_in: running partial; ``v1_in`` is unpacked as ``(_, V)``.
        m2_in, l2_in, v2_in: retrieved partials; ``m2_in`` is unpacked as ``(T, R)``.
        target_dtype: dtype of the returned value array.

    Returns:
        ``(lse_out, v_out)`` with shapes ``(T,)`` (float32) and ``(T, V)``.
    """
    T, R = m2_in.shape
    _, V = v1_in.shape

    block_t_post = 64
    assert T % block_t_post == 0, f"T ({T}) must be divisible by block_t_post ({block_t_post})"
    num_blocks_post = pl.cdiv(T, block_t_post)

    # Program i owns rows [i * block_t_post, (i + 1) * block_t_post) of every
    # operand; trailing axes are taken whole.
    def rows_1d(i):
        return (i,)

    def rows_2d(i):
        return (i, 0)

    def rows_3d(i):
        return (i, 0, 0)

    input_specs = [
        pl.BlockSpec((block_t_post,), rows_1d), # m1_in_ref
        pl.BlockSpec((block_t_post,), rows_1d), # l1_in_ref
        pl.BlockSpec((block_t_post, V), rows_2d), # v1_in_ref
        pl.BlockSpec((block_t_post, R), rows_2d), # m2_in_ref
        pl.BlockSpec((block_t_post, R), rows_2d), # l2_in_ref
        pl.BlockSpec((block_t_post, R, V), rows_3d), # v2_in_ref
    ]
    output_specs = [
        pl.BlockSpec((block_t_post,), rows_1d), # lse_out_ref
        pl.BlockSpec((block_t_post, V), rows_2d), # v_out_ref
    ]
    output_shapes = [
        jax.ShapeDtypeStruct((T,), jnp.float32), # lse_out
        jax.ShapeDtypeStruct((T, V), target_dtype), # v_out
    ]

    merge_call = pl.pallas_call(
        leaf_retrieval_postprocess_kernel,
        out_shape=output_shapes,
        grid=(num_blocks_post,),
        in_specs=input_specs,
        out_specs=output_specs,
        name="leaf_retrieval_postprocess_forward",
    )
    lse_out, v_out = merge_call(m1_in, l1_in, v1_in, m2_in, l2_in, v2_in)
    return lse_out, v_out

def ref_lse_value_merge(lse_0, lse_1, value_0, value_1):
    """Combine two normalised attention partials given their log-sum-exps.

    Returns ``(lse_merged, value_merged)``: the log of the summed
    normalisers and the values averaged with softmax weights over
    ``(lse_0, lse_1)``. The outputs keep the dtypes of ``lse_0`` and
    ``value_0`` respectively.
    """
    # Subtract the larger lse before exponentiating for numerical stability.
    pivot = jnp.maximum(lse_0, lse_1)
    mass_0 = jnp.exp(lse_0 - pivot)
    mass_1 = jnp.exp(lse_1 - pivot)
    total_mass = mass_0 + mass_1

    lse_merged = jnp.log(total_mass) + pivot

    share_0 = mass_0 / total_mass
    share_1 = mass_1 / total_mass
    value_merged = value_0 * share_0[..., None] + value_1 * share_1[..., None]

    return lse_merged.astype(lse_0.dtype), value_merged.astype(value_0.dtype)

def spatial_retrieval_kernel(
    q_ref, q_bar_ref, k_ref, v_ref, mu_ref,
    sorted_qfwd_ref, sorted_ridx_ref, sizes_ref, offsets_ref, kidx_ref, qidx_ref,
    m_in_ref,
    m_out_ref, l_out_ref, v_out_ref, blk_idx_out_ref,
    *,
    causal_stride: int, block_t: int, max_c: int, bidiagonal: bool,
):
    """Pallas kernel: pick the top-``P`` blocks per query and attend to the rest.

    One grid program handles one ``(qidx, kidx)`` pair. It loads the per-block
    rows ``k_ref[:, qidx, kidx, :]``, ``v_ref[:, qidx, kidx, :]`` and
    ``mu_ref[:, qidx, kidx]`` and walks its slice ``[offset, offset + size)`` of
    ``sorted_qfwd_ref`` / ``sorted_ridx_ref`` in tiles of ``block_t``.

    For each (query, slot) in the slice the score of block ``c`` is
    ``(q - q_bar[qidx]) . k[c] * sm_scale + mu[c]`` (converted to base 2),
    restricted to causally allowed blocks. The ``P`` highest-scoring blocks are
    removed from the scores one at a time and written to
    ``blk_idx_out_ref[qfwd, ridx, p]`` (``-1`` if the pick is not causally
    allowed). The remaining allowed blocks yield the partial softmax
    statistics written to ``m_out_ref`` / ``l_out_ref`` / ``v_out_ref`` at
    ``[qfwd, ridx]``.

    Keyword Args:
        causal_stride: divisor turning a query index into its block index.
        block_t: tile size along the sorted slice.
        max_c: number of leading block rows that are real; rows at index
            ``>= max_c`` are masked off on load.
        bidiagonal: if true, block ``c`` is allowed only when ``c < qblk - 1``
            instead of ``c < qblk``.
    """
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])
    # exp2 is used instead of exp, so both score terms are scaled by log2(e).
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)

    qidx = qidx_ref[...] # ()
    q_bar = q_bar_ref[qidx, :] # (D,)

    size_t = sizes_ref[...] # ()
    offset_t = offsets_ref[...] # ()
    kidx = kidx_ref[...] # ()

    # CHECK INDEXING HERE
    # C here is the leading extent of the k block, which may exceed max_c.
    C, Q, K, D = k_ref.shape
    T, R, P = blk_idx_out_ref.shape
    kblk = jnp.arange(C)
    mask_k = kblk < max_c
    k = pl.load(k_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:,None], other=0.0) # (C, D)
    v = pl.load(v_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:,None], other=0.0) # (C, V)
    # Masked-off rows get the mask value as their bias term.
    mu = pl.load(mu_ref, (slice(None), qidx, kidx), mask=mask_k, other=DEFAULT_MASK_VALUE) # (C,)
    kblk = jnp.arange(k.shape[0])

    
    #raise NotImplementedError("spatial retrieval kernel not implemented yet")
    def t_body(start_t, _):
        base_t = start_t * block_t + offset_t
        curr_t_slice = pl.dslice(base_t, block_t)
        # Lanes past the end of this program's slice are disabled.
        mask_t = jnp.arange(block_t) + base_t < offset_t + size_t
        qfwd = pl.load(sorted_qfwd_ref, (curr_t_slice,), mask=mask_t, other=0)
        ridx = pl.load(sorted_ridx_ref, (curr_t_slice,), mask=mask_t, other=0)
        qblk = qfwd // causal_stride
        m = pl.load(m_in_ref, (qfwd,), mask=mask_t, other=DEFAULT_MASK_VALUE)
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask_t[:,None], other=0.0) # (t, D)
        # Scores use the query's residual relative to q_bar[qidx].
        qres = q - q_bar[None, :]

        causal_mask = kblk < qblk[:, None] if not bidiagonal else kblk < (qblk[:, None] - 1)

        # NEED TO ADD CAUSAL MASKING
        attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale
        attn_scores = jnp.where(causal_mask, attn_scores, DEFAULT_MASK_VALUE)
        # Iterative top-P: take the argmax, mask it out, repeat. P is a static
        # shape, so this Python loop is unrolled at trace time.
        for pidx in range(P):
            blk_idx = jnp.argmax(attn_scores, axis=-1)
            #blk_idx = jnp.zeros_like(blk_idx)
            #blk_idx = qblk - 1
            attn_to_kill = jnp.arange(attn_scores.shape[1])[None, :] == blk_idx[:, None]
            attn_scores = jnp.where(attn_to_kill, DEFAULT_MASK_VALUE, attn_scores)
            # A pick that is not causally allowed (e.g. when every score was
            # already masked) is reported as -1.
            if bidiagonal:
                blk_idx = jnp.where(blk_idx < (qblk - 1), blk_idx, -1)
            else:
                blk_idx = jnp.where(blk_idx < qblk, blk_idx, -1)
            pl.store(blk_idx_out_ref, (qfwd, ridx, pidx), blk_idx.astype(blk_idx_out_ref.dtype), mask=mask_t)

        # Softmax statistics over the blocks that were NOT picked above; the
        # running max starts from the incoming m.
        m_curr = jnp.max(attn_scores, axis=-1) # (t,)
        m_next = jnp.maximum(m, m_curr)
        s = jnp.exp2(attn_scores - m_next[:, None]) # (t, C)
        s = jnp.where(causal_mask, s, 0.0)

        l_curr = jnp.sum(s, axis=-1) # (t,)
        ov_curr = pl.dot(s.astype(v.dtype), v, allow_tf32=True) # (t, V)

        pl.store(m_out_ref, (qfwd, ridx), m_next.astype(m_out_ref.dtype), mask=mask_t)
        pl.store(l_out_ref, (qfwd, ridx), l_curr.astype(l_out_ref.dtype), mask=mask_t)
        pl.store(v_out_ref, (qfwd, ridx, slice(None)), ov_curr.astype(v_out_ref.dtype), mask=mask_t[:, None])

    t_lower_bound = 0
    t_upper_bound = pl.cdiv(size_t, block_t)
    jax.lax.fori_loop(t_lower_bound, t_upper_bound, t_body, None)

def spatial_retrieval(num_spatial: int, causal_block: int, bidiagonal: bool, q, q_bar, k, v, m, qmeta, ret_idx, m_in, l_in, v_in):
    """Figure out which spatial blocks to retrieve from and attend to summaries of the others.

    Launches ``spatial_retrieval_kernel`` with one grid program per segment
    produced by ``spatial_preprocess_ret``, then patches the results for the
    first query blocks and merges them with the incoming partial results.

    Args:
        num_spatial: number ``P`` of blocks selected per (query, slot).
        causal_block: number of consecutive queries per block; the block of
            query ``t`` is ``t // causal_block``.
        bidiagonal: selects the stricter causal bound in the kernel and moves
            the first fully handled query block from 2 to 3. Only ``P == 1``
            is supported in this mode.
        q, q_bar: queries and per-cluster query means (``q_bar`` is 2-D).
        k, v, m: per-block summaries; ``v`` is unpacked as ``(C, Q, K, V)``.
        qmeta: only ``qmeta.lab`` is read (must be int32).
        ret_idx: ``(T, R)`` int32 retrieval indices.
        m_in, l_in, v_in: incoming partial results.

    Returns:
        ``(m_out, l_out, v_out, blk_idx, spatial_retrieval_meta)`` where
        ``blk_idx`` has shape ``(T, R, P)`` and ``spatial_retrieval_meta`` is
        ``(sorted_qfwd, sorted_ridx, sizes, offsets)``.
    """
    assert ret_idx.dtype == jnp.int32, "ret_idx must be int32"
    assert qmeta.lab.dtype == jnp.int32, "qmeta.lab must be int32"
    T, R = ret_idx.shape
    Q, D = q_bar.shape
    C, Q, K, V = v.shape
    P = num_spatial

    assert C <= 256, "C too large for spatial retrieval with no blocking over C"
    if C >= 128:
        print(f"Warning: C={C} is large (>=128) for spatial retrieval with no blocking over C, performance may be suboptimal")
    # The kernel sees at least 16 rows along C; rows >= C are masked via max_c.
    block_C = max(16, C)

    block_t = 16
    qblk = jnp.arange(T) // causal_block
    # Queries in earlier blocks are excluded from the kernel and patched below.
    first_valid_qblk = 3 if bidiagonal else 2
    valid_idx = jnp.broadcast_to((qblk >= first_valid_qblk)[:,None], (T, R))
    flat_labels = qmeta.lab.flatten()
    sorted_qfwd, sorted_ridx, sizes, offsets = spatial_preprocess_ret(K, Q, ret_idx, flat_labels, valid_idx)
    # Segment i is handled as (q, k) = divmod(i, K): q is the major index.
    subseg_labels_q = jnp.broadcast_to(jnp.arange(Q)[:, None], (Q, K)).flatten()
    subseg_labels_k = jnp.broadcast_to(jnp.arange(K)[None, :], (Q, K)).flatten()

    padded_k_shape = (block_C,) + k.shape[1:]
    padded_v_shape = (block_C,) + v.shape[1:]
    padded_m_shape = (block_C,) + m.shape[1:]
    m_ret, l_ret, v_ret, blk_idx = pl.pallas_call(
        # The kernel's causal_stride divides a query index, so it receives
        # causal_block.
        partial(spatial_retrieval_kernel, causal_stride=causal_block, block_t=block_t, max_c=C, bidiagonal=bidiagonal),
        out_shape = [
            jnp.tile(m_in[:, None], (1, R)), # m_out
            jnp.tile(l_in[:, None], (1, R)), # l_out
            jnp.tile(v_in[:, None, :], (1, R, 1)), # v_out
            jnp.tile(ret_idx[:, :, None], (1, 1, P)), # blk_idx_out
            #ret_idx,
        ],
        grid = (len(sizes),),
        in_specs = [
            pl.BlockSpec(q.shape, lambda i: (0, 0)), # q_ref
            pl.BlockSpec(q_bar.shape, lambda i: (0, 0)), # q_bar_ref
            pl.BlockSpec(padded_k_shape, lambda i: (0, 0, 0, 0)), # k_ref
            pl.BlockSpec(padded_v_shape, lambda i: (0, 0, 0, 0)), # v_ref
            pl.BlockSpec(padded_m_shape, lambda i: (0, 0, 0)), # mu_ref
            pl.BlockSpec(sorted_qfwd.shape, lambda i: (0,)), # sorted_qfwd_ref
            pl.BlockSpec(sorted_ridx.shape, lambda i: (0,)), # sorted_ridx_ref
            pl.BlockSpec((None,), lambda i: (i,)), # sizes_ref
            pl.BlockSpec((None,), lambda i: (i,)), # offsets_ref
            pl.BlockSpec((None,), lambda i: (i,)), # kidx_ref
            pl.BlockSpec((None,), lambda i: (i,)), # qidx_ref
            pl.BlockSpec(m_in.shape, lambda i: (0,)), # m_in_ref
        ],
        out_specs = [
            pl.BlockSpec((T, R), lambda i: (0, 0)), # m_out_ref
            pl.BlockSpec((T, R), lambda i: (0, 0)), # l_out_ref
            pl.BlockSpec((T, R, V), lambda i: (0, 0, 0)), # v_out_ref
            pl.BlockSpec((T, R, P), lambda i: (0, 0, 0)), # blk_idx_out_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=1, num_stages=2),
        name="spatial_retrieval_forward",
    )(q, q_bar, k, v, m, sorted_qfwd, sorted_ridx, sizes, offsets, subseg_labels_k, subseg_labels_q, m_in)
    # Slots of excluded queries contribute nothing to the merge below.
    m_ret = jnp.where(valid_idx, m_ret, m_in[:, None])
    l_ret = jnp.where(valid_idx, l_ret, 0.0)
    v_ret = jnp.where(valid_idx[:, :, None], v_ret, 0.0)
    #blk_idx = blk_idx[:,:,0]
    # Fixed block choices for the query blocks excluded from the kernel.
    if bidiagonal:
        assert P == 1, "bidiagonal spatial retrieval only supports P=1 for now"
        blk_idx = jnp.where((qblk == 0)[:, None, None], -1, blk_idx)
        blk_idx = jnp.where((qblk == 1)[:, None, None], -1, blk_idx)
        blk_idx = jnp.where((qblk == 2)[:, None, None], 0, blk_idx)
    else:
        blk_idx = jnp.where((qblk == 0)[:, None, None], -1, blk_idx)
        # Query block 1: first pick is block 0, remaining picks are -1.
        qblk_1_blk_idx = jnp.full((P,), -1, dtype=blk_idx.dtype).at[0].set(0)
        blk_idx = jnp.where((qblk == 1)[:, None, None], qblk_1_blk_idx[None,None,:], blk_idx)

    spatial_retrieval_meta = (sorted_qfwd, sorted_ridx, sizes, offsets)

    m_out, l_out, v_out = ref_spatial_retrieval_postprocess(m_in, l_in, v_in, m_ret, l_ret, v_ret)
    return m_out, l_out, v_out, blk_idx, spatial_retrieval_meta

def spatial_retrieval_bwd_kernel(
    q_ref, q_bar_ref, k_ref, v_ref, mu_ref,
    sorted_qfwd_ref, sorted_ridx_ref, sizes_ref, offsets_ref, kidx_ref, qidx_ref,
    lse_ref, dlse_ref, do_ref,
    dq_out_ref, dk_out_ref, dv_out_ref, dmu_out_ref, dq_bar_out_ref,
    *,
    causal_stride: int, block_t: int, max_c: int, bidiagonal: bool, num_spatial: int,
):
    """Pallas kernel: backward pass of ``spatial_retrieval_kernel`` for one ``(qidx, kidx)`` pair.

    Recomputes the block scores for the program's slice of
    ``sorted_qfwd_ref`` / ``sorted_ridx_ref``, re-derives which
    ``num_spatial`` blocks the forward pass removed, and back-propagates
    through the softmax over the remaining causally allowed blocks.

    Outputs:
        dq_out_ref[ridx, qfwd, :]: gradient w.r.t. the query, per slot.
        dk_out_ref / dv_out_ref / dmu_out_ref at ``[:, qidx, kidx]``: gradients
            of the block summaries, accumulated over the slice.
        dq_bar_out_ref[kidx, qidx, :]: minus the sum of the query gradients
            (the score uses ``q - q_bar``).

    Keyword Args:
        causal_stride: divisor turning a query index into its block index.
        block_t: tile size along the sorted slice.
        max_c: number of leading block rows that are real.
        bidiagonal: if true, block ``c`` is allowed only when ``c < qblk - 1``.
        num_spatial: number of top-scoring blocks to exclude per row.
    """
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])
    # exp2 is used instead of exp, so scores and lse are converted to base 2.
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)
    lse_scale = math.log2(math.e)

    qidx = qidx_ref[...] # ()
    q_bar = q_bar_ref[qidx, :] # (D,)

    size_t = sizes_ref[...] # ()
    offset_t = offsets_ref[...] # ()
    kidx = kidx_ref[...] # ()

    # CHECK INDEXING HERE
    # C here is the leading extent of the k block, which may exceed max_c.
    C, Q, K, D = k_ref.shape
    kblk = jnp.arange(C)
    mask_k = kblk < max_c
    k = pl.load(k_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:,None], other=0.0) # (C, D)
    v = pl.load(v_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:,None], other=0.0) # (C, V)
    mu = pl.load(mu_ref, (slice(None), qidx, kidx), mask=mask_k, other=DEFAULT_MASK_VALUE) # (C,)
    kblk = jnp.arange(k.shape[0])

    
    #raise NotImplementedError("spatial retrieval kernel not implemented yet")
    def t_body(start_t, carry):
        dk_prev, dv_prev, dmu_prev, dq_bar_prev = carry

        base_t = start_t * block_t + offset_t
        curr_t_slice = pl.dslice(base_t, block_t)
        # Lanes past the end of this program's slice are disabled.
        mask_t = jnp.arange(block_t) + base_t < offset_t + size_t
        qfwd = pl.load(sorted_qfwd_ref, (curr_t_slice,), mask=mask_t, other=0)
        ridx = pl.load(sorted_ridx_ref, (curr_t_slice,), mask=mask_t, other=0)
        qblk = qfwd // causal_stride
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask_t[:,None], other=0.0) # (t, D)
        lse = pl.load(lse_ref, (qfwd,), mask=mask_t, other=0.0)
        do = pl.load(do_ref, (qfwd, slice(None)), mask=mask_t[:,None], other=0.0)
        dlse = pl.load(dlse_ref, (qfwd,), mask=mask_t, other=0.0)
        qres = q - q_bar[None, :]

        causal_mask = kblk < qblk[:, None] if not bidiagonal else kblk < (qblk[:, None] - 1)

        attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale
        # Subtracting the per-row lse does not change the argmax below and
        # makes exp2(attn_scores) the normalised probability.
        attn_scores = attn_scores - lse[:, None] * lse_scale
        attn_scores = jnp.where(causal_mask, attn_scores, DEFAULT_MASK_VALUE)
        # Union of the blocks removed by the top-P selection.
        attn_to_kill = jnp.zeros_like(attn_scores, dtype=bool)
        for pidx in range(num_spatial):
            blk_idx = jnp.argmax(attn_scores, axis=-1)
            #blk_idx = jnp.zeros_like(blk_idx)
            #blk_idx = qblk - 1
            # The comparison must be parenthesised: `|` binds tighter than
            # `==`, so without the parentheses the accumulated mask would be
            # OR-ed with the integer column ids before being compared.
            attn_to_kill = attn_to_kill | (jnp.arange(attn_scores.shape[1])[None, :] == blk_idx[:, None])
            attn_scores = jnp.where(attn_to_kill, DEFAULT_MASK_VALUE, attn_scores)

        s = jnp.exp2(attn_scores)
        # Force exact zeros on removed and causally disallowed blocks.
        s = jnp.where(causal_mask & (~attn_to_kill), s, 0.0)

        dv = pl.dot(s.astype(do.dtype), do, trans_a=True, allow_tf32=True) # (C, V)
        dv_next = dv_prev + dv # (C, V)
        gv = pl.dot(do.astype(v.dtype), v, trans_b=True, allow_tf32=True)
        ds = gv + dlse[:, None]
        dqk = ds * s # (t, C)

        dmu = jnp.sum(dqk, axis=0) * mu_scale # (C,)
        dmu_next = dmu_prev + dmu # (C,)
        dk = pl.dot(dqk.astype(q.dtype), qres, trans_a=True, allow_tf32=True) * sm_scale # (C, D)
        dk_next = dk_prev + dk # (C, D)
        dq = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale # (t, D)
        # Note the (slot, query) index order of the dq output.
        pl.store(dq_out_ref, (ridx, qfwd, slice(None)), dq.astype(dq_out_ref.dtype), mask=mask_t[:,None])
        # The score depends on q - q_bar, so q_bar receives minus the query gradient.
        dq_bar_next = dq_bar_prev - jnp.sum(dq, axis=0)

        return dk_next, dv_next, dmu_next, dq_bar_next

    dk = jnp.zeros(k.shape, dtype=jnp.float32) # (C, D)
    dv = jnp.zeros(v.shape, dtype=jnp.float32) # (C, V)
    dmu = jnp.zeros(mu.shape, dtype=jnp.float32) # (C,)
    dq_bar = jnp.zeros(q_bar.shape, dtype=jnp.float32) # (D,)

    t_lower_bound = 0
    t_upper_bound = pl.cdiv(size_t, block_t)
    dk, dv, dmu, dq_bar = jax.lax.fori_loop(t_lower_bound, t_upper_bound, t_body, (dk, dv, dmu, dq_bar))

    # Stores are unconditional, so a program with an empty slice writes zeros.
    pl.store(dk_out_ref, (slice(None), qidx, kidx, slice(None)), dk.astype(dk_out_ref.dtype), mask=mask_k[:,None])
    pl.store(dv_out_ref, (slice(None), qidx, kidx, slice(None)), dv.astype(dv_out_ref.dtype), mask=mask_k[:,None])
    pl.store(dmu_out_ref, (slice(None), qidx, kidx), dmu.astype(dmu_out_ref.dtype), mask=mask_k)
    pl.store(dq_bar_out_ref, (kidx, qidx, slice(None),), dq_bar.astype(dq_bar_out_ref.dtype))

def spatial_retrieval_bwd(num_spatial: int, causal_block: int, bidiagonal: bool, q, q_bar, k, v, m, qmeta, ret_idx, lse, v_out, dlse, dv_out, spatial_retrieval_meta):
    """Backward pass of ``spatial_retrieval``.

    Launches ``spatial_retrieval_bwd_kernel`` with one grid program per entry
    of ``sizes`` and reduces the per-slot / per-program partial gradients.

    Args:
        num_spatial, causal_block, bidiagonal: same values as in the forward call.
        q, q_bar: queries and per-cluster query means (``q_bar`` is 2-D).
        k, v, m: per-block summaries; ``v`` is unpacked as ``(C, Q, K, V)``.
        qmeta: not read by this function.
        ret_idx: only its ``(T, R)`` shape is used.
        lse, v_out: forward outputs.
        dlse, dv_out: cotangents of ``lse`` and ``v_out``.
        spatial_retrieval_meta: ``(sorted_qfwd, sorted_ridx, sizes, offsets)``
            as returned by ``spatial_retrieval``.

    Returns:
        ``(dq, dk, dv, dm, dq_bar)``.
    """
    T, R = ret_idx.shape
    Q, D = q_bar.shape
    C, Q, K, V = v.shape

    # Fold the "-sum(dO * O)" term into the lse cotangent before the kernel.
    dlse_modified = dlse - jnp.sum(dv_out * v_out, axis=-1).astype(dlse.dtype) # (T,)

    # NOTE(review): the forward pass allows C <= 256; this limit is stricter.
    assert C <= 64, "C too large for spatial retrieval with no blocking over C"
    # The kernel sees at least 16 rows along C; rows >= C are masked via max_c.
    block_C = max(16, C)

    block_t = 16
    
    qblk = jnp.arange(T) // causal_block
    # Same threshold as the forward pass: earlier query blocks get no gradient.
    first_valid_qblk = 3 if bidiagonal else 2
    valid_qblk = qblk >= first_valid_qblk

    sorted_qfwd, sorted_ridx, sizes, offsets = spatial_retrieval_meta
    # Segment i is handled as (q, k) = divmod(i, K): q is the major index.
    subseg_labels_q = jnp.broadcast_to(jnp.arange(Q)[:, None], (Q, K)).flatten()
    subseg_labels_k = jnp.broadcast_to(jnp.arange(K)[None, :], (Q, K)).flatten()

    # NOTE(review): causal_stride is not used below; the kernel receives
    # causal_block as its causal_stride.
    causal_stride = causal_block * R
    #print("Note: should set max_c to C-1 once spatial_retrieval works, and check it is faster")
    padded_k_shape = (block_C,) + k.shape[1:]
    padded_v_shape = (block_C,) + v.shape[1:]
    padded_m_shape = (block_C,) + m.shape[1:]
    dqs, dk, dv, dm, dq_bars = pl.pallas_call(
        partial(spatial_retrieval_bwd_kernel, causal_stride=causal_block, block_t=block_t, max_c=C, bidiagonal=bidiagonal, num_spatial=num_spatial),
        out_shape = [
            jax.ShapeDtypeStruct((R, T, D), q.dtype), # dq_out
            # The input arrays themselves are passed as output templates.
            # presumably only their shape/dtype are used — TODO confirm
            k, v, m,
            jax.ShapeDtypeStruct((K, Q, D), q_bar.dtype), # dq_bar_out
        ],
        grid = (len(sizes),),
        in_specs = [
            pl.BlockSpec(q.shape, lambda i: (0, 0)), # q_ref
            pl.BlockSpec(q_bar.shape, lambda i: (0, 0)), # q_bar_ref
            pl.BlockSpec(padded_k_shape, lambda i: (0, 0, 0, 0)), # k_ref
            pl.BlockSpec(padded_v_shape, lambda i: (0, 0, 0, 0)), # v_ref
            pl.BlockSpec(padded_m_shape, lambda i: (0, 0, 0)), # mu_ref
            pl.BlockSpec(sorted_qfwd.shape, lambda i: (0,)), # sorted_qfwd_ref
            pl.BlockSpec(sorted_ridx.shape, lambda i: (0,)), # sorted_ridx_ref
            pl.BlockSpec((None,), lambda i: (i,)), # sizes_ref
            pl.BlockSpec((None,), lambda i: (i,)), # offsets_ref
            pl.BlockSpec((None,), lambda i: (i,)), # kidx_ref
            pl.BlockSpec((None,), lambda i: (i,)), # qidx_ref
            pl.BlockSpec(lse.shape, lambda i: (0,)), # lse_ref
            pl.BlockSpec(dlse_modified.shape, lambda i: (0,)), # dlse_ref
            pl.BlockSpec(dv_out.shape, lambda i: (0, 0)), # do_ref
        ],
        out_specs = [
            pl.BlockSpec((R, T, D), lambda i: (0, 0, 0)), # dq_out_ref
            pl.BlockSpec(padded_k_shape, lambda i: (0, 0, 0, 0)), # dk_out_ref
            pl.BlockSpec(padded_v_shape, lambda i: (0, 0, 0, 0)), # dv_out_ref
            pl.BlockSpec(padded_m_shape, lambda i: (0, 0, 0)), # dmu_out_ref
            pl.BlockSpec((K, Q, D), lambda i: (0, 0, 0)), # dq_bar_out_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=1, num_stages=2),
        name="spatial_retrieval_backward",
    )(q, q_bar, k, v, m, sorted_qfwd, sorted_ridx, sizes, offsets, subseg_labels_k, subseg_labels_q, lse, dlse_modified, dv_out)
    # dqs is laid out (R, T, D): zero the rows of excluded queries, then
    # reduce over the slot axis.
    dqs = jnp.where(valid_qblk[None, :, None], dqs, 0.0)
    dq = jnp.sum(dqs, axis=0) # (T, D)
    # dq_bars holds one partial per (k, q) program; reduce over k.
    dq_bar = jnp.sum(dq_bars, axis=0) # (Q, D)
    return dq, dk, dv, dm, dq_bar

def ref_preprocess(q, qmeta):
    """Reference (pure-JAX) per-cluster mean of the rows of ``q``.

    ``qmeta.fwd`` is a (Q, n) table of row indices into ``q``. Negative
    entries mark unused slots and are left out of the mean. A cluster with
    no used slot averages to NaN, which is replaced by zero.

    Returns:
        Array of shape (Q, D) holding one mean row per cluster.
    """
    slot_index = qmeta.fwd
    slot_is_used = slot_index[:, :, None] >= 0  # (Q, n, 1)
    gathered = q[slot_index, :]  # (Q, n, D)
    cluster_mean = jnp.mean(gathered, axis=1, where=slot_is_used)  # (Q, D)
    return jnp.nan_to_num(cluster_mean, nan=0.0)

def prepocess_kernel(q_ref, qfwd_ref, qcnt_ref, qbar_ref, *, block_n: int):
    """Pallas kernel body: mean of the ``q`` rows listed in ``qfwd_ref``.

    Args:
        q_ref: 2-D ref (N, D) of rows to gather from.
        qfwd_ref: 1-D ref of row indices into ``q_ref``; negative entries are
            treated as padding and contribute nothing.
        qcnt_ref: scalar ref; only the first ``ceil(qcnt / block_n)`` blocks
            of ``qfwd_ref`` are visited.
        qbar_ref: output ref (D,) receiving the mean.
        block_n: number of index slots processed per loop iteration.
    """
    N, D = q_ref.shape # (N, D)
    qcnt = qcnt_ref[...] # ()

    def body(start_n, carry):
        # One iteration gathers `block_n` rows and adds them to the running sum.
        qsum_prev, cnt_prev = carry
        curr_n_slice = pl.dslice(start_n*block_n, block_n)
        qfwd = qfwd_ref[curr_n_slice]
        mask = qfwd >= 0
        # Padding slots (negative index) are masked out and read as 0.0.
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, D)
        qsum_curr = jnp.sum(q.astype(qsum_prev.dtype), axis=0) # (D,)
        cnt_curr = jnp.sum(mask.astype(cnt_prev.dtype)) # ()
        qsum_next = qsum_prev + qsum_curr
        cnt_next = cnt_prev + cnt_curr
        return qsum_next, cnt_next

    # The sum is accumulated in float32 regardless of the input dtype.
    qsum = jnp.zeros((D,), dtype=jnp.float32) # (D,)
    cnt = jnp.zeros((), dtype=jnp.int32) # ()

    lower_bound = 0
    upper_bound = pl.cdiv(qcnt, block_n)
    qsum, cnt = lax.fori_loop(lower_bound, upper_bound, body, (qsum, cnt)) # (D,), ()
    # Dividing by at least 1 makes an empty selection produce zeros rather than NaN.
    qbar = qsum / jnp.maximum(cnt, 1) # (D,)
    qbar_ref[...] = qbar.astype(qbar_ref.dtype) # (D,)

def preprocess(q, qmeta):
    """Pallas version of the per-cluster mean of ``q``.

    Launches ``prepocess_kernel`` on a grid with one program per row of
    ``qmeta.fwd``. Every program sees all of ``q``, its own row of
    ``qmeta.fwd`` and its own entry of ``qmeta.cnt``, and writes one row of
    the result.

    Returns:
        ``q_bar`` of shape (Q, D) in the dtype of ``q``.
    """
    num_clusters, n = qmeta.fwd.shape
    _, feat_dim = q.shape
    block_n = min(n, 64)
    assert n % block_n == 0, f"n ({n}) must be divisible by block_n ({block_n})"

    kernel = partial(prepocess_kernel, block_n=block_n)
    input_specs = [
        pl.BlockSpec(q.shape, lambda i: (0, 0)),  # all of q
        pl.BlockSpec((None, n), lambda i: (i, 0)),  # row i of qmeta.fwd
        pl.BlockSpec((None,), lambda i: (i,)),  # entry i of qmeta.cnt
    ]
    output_specs = [
        pl.BlockSpec((None, feat_dim), lambda i: (i, 0)),  # row i of q_bar
    ]
    (q_bar,) = pl.pallas_call(
        kernel,
        out_shape=(jax.ShapeDtypeStruct((num_clusters, feat_dim), q.dtype),),
        grid=(num_clusters,),
        in_specs=input_specs,
        out_specs=output_specs,
        name="preprocess_forward",
    )(q, qmeta.fwd, qmeta.cnt)
    return q_bar

def ref_preprocess_bwd(qmeta, dq_bar): # _, (Q, D)
    """Reference backward pass of the per-cluster mean.

    Each cluster's cotangent row is multiplied by ``1 / max(cnt, 1)`` and
    then gathered per point through ``qmeta.lab``.

    Returns:
        Array with one row per entry of ``qmeta.lab``, i.e. (N, D).
    """
    inv_count = 1.0 / jnp.maximum(qmeta.cnt, 1)  # (Q,)
    per_cluster = dq_bar * inv_count[:, None]  # (Q, D)
    return per_cluster[qmeta.lab, :]

def ref_vk_merge(lse_bar, vk_bar):
    """Mix the per-key-cluster ``vk_bar`` tensors with softmax weights.

    Args:
        lse_bar: (Q, K) logits; a softmax over the last axis gives the weights.
        vk_bar: (K, V, D) tensor per key cluster.

    Returns:
        ``(vk_merged, weights)`` where ``vk_merged`` is (Q, V, D) in the dtype
        of ``vk_bar`` and ``weights`` is the (Q, K) softmax.
    """
    weights = jax.nn.softmax(lse_bar, axis=-1)  # (Q, K)
    if False:  # manual debugging switch, normally off
        weights = weights * 0.
        print("WARNING: DIPOLE DISABLED IN REF_VK_MERGE")
    merged = einsum("qk,kvd->qvd", weights, vk_bar)  # (Q, V, D)
    return merged.astype(vk_bar.dtype), weights

def ref_vk_merge_bwd(weights, vk_bar, dvk_merged):
    """Backward pass of the softmax-weighted ``vk_bar`` merge.

    Args:
        weights: (Q, K) softmax weights used in the forward pass.
        vk_bar: (K, V, D) per-cluster tensors.
        dvk_merged: (Q, V, D) cotangent of the merged output.

    Returns:
        ``(dlse_bar, dvk_bar)`` with shapes (Q, K) and (K, V, D);
        ``dvk_bar`` is cast to the dtype of ``dvk_merged``.
    """
    # Cotangent with respect to the mixed tensors.
    dvk_bar = einsum("qvd,qk->kvd", dvk_merged, weights)  # (K, V, D)

    # Softmax backward: dlogit = w * (dw - <dout, out>).
    dweights = einsum("qvd,kvd->qk", dvk_merged, vk_bar)  # (Q, K)
    vk_merged = einsum("qk,kvd->qvd", weights, vk_bar)  # (Q, V, D)
    inner = einsum("qvd,qvd->q", dvk_merged, vk_merged)  # (Q,)
    dlse_bar = (dweights - inner[:, None]) * weights

    return dlse_bar, dvk_bar.astype(dvk_merged.dtype)

def lse_cumsum_exclusive_kernel(lse_ref, v_ref, k_ref, lse_out_ref, v_out_ref, k_out_ref):
    """Pallas kernel body: exclusive running log-sum-exp merge along axis 0.

    The loop carries a running state (max ``m`` in base-2 units, normaliser
    ``l``, weighted sums ``v``/``k``). At step ``c`` it first stores the
    state built from positions ``0..c-1`` and only then folds position ``c``
    in, so the output at ``c`` excludes element ``c``. The output at
    position 0 is therefore the initial state: zero vectors and a very
    negative log-sum-exp.

    Args:
        lse_ref: (C,) log-weights.
        v_ref: (C, V) values merged with weights ``exp(lse)``.
        k_ref: (C, D) keys merged with the same weights.
        lse_out_ref, v_out_ref, k_out_ref: outputs, same shapes as the inputs.
    """
    C, = lse_ref.shape
    _, V = v_ref.shape
    _, D = k_ref.shape

    def body_c(cidx, carry):
        next_cidx = cidx + 1
        # False on the last iteration, where there is nothing left to prefetch.
        pipe_not_tail = next_cidx < C
        # (lse, v, k) are the inputs for `cidx`, loaded during the previous iteration.
        m_prev, l_prev, v_prev, k_prev, lse, v, k = carry
        inv_l_prev = 1.0 / l_prev
        # Emit the state accumulated so far (exclusive of `cidx`); m is base-2, so convert back.
        lse_out_prev = math.log(2) * m_prev + jnp.log(l_prev)
        pl.store(lse_out_ref, (cidx,), lse_out_prev.astype(lse_out_ref.dtype))
        v_out_prev = v_prev * inv_l_prev[None]
        pl.store(v_out_ref, (cidx, slice(None)), v_out_prev.astype(v_out_ref.dtype))
        k_out_prev = k_prev * inv_l_prev[None]
        pl.store(k_out_ref, (cidx, slice(None)), k_out_prev.astype(k_out_ref.dtype))

        #lse = lse_ref[cidx] # ()
        # Prefetch the next element while the current one is being folded in.
        lse_prefetch = pl.load(lse_ref, (next_cidx,), mask=pipe_not_tail, other=0.0)
        # Online-softmax style update with a running maximum, done with exp2.
        lse_base2 = lse * math.log2(math.e)
        m_next = jnp.maximum(m_prev, lse_base2)
        correction = jnp.exp2(m_prev - m_next)
        weight_curr = jnp.exp2(lse_base2 - m_next)
        l_next = correction*l_prev + weight_curr
        #v = v_ref[cidx, :] # (V,)
        v_prefetch = pl.load(v_ref, (next_cidx, slice(None)), mask=pipe_not_tail, other=0.0)
        v_next = correction*v_prev + weight_curr*v
        #k = k_ref[cidx, :] # (D,)
        k_prefetch = pl.load(k_ref, (next_cidx, slice(None)), mask=pipe_not_tail, other=0.0)
        k_next = correction*k_prev + weight_curr*k
        return m_next, l_next, v_next, k_next, lse_prefetch, v_prefetch, k_prefetch
    # Initial state: very negative running max, tiny non-zero normaliser (keeps 1/l and log(l) finite).
    m = jnp.full((), DEFAULT_MASK_VALUE, dtype=jnp.float32)
    l = jnp.full((), 1e-10, dtype=jnp.float32)
    v = jnp.zeros((V,), dtype=jnp.float32)
    k = jnp.zeros((D,), dtype=jnp.float32)
    lower_bound = 0
    upper_bound = C
    # Prime the pipeline with element 0.
    lse_initial = lse_ref[0]
    v_initial = v_ref[0, :]
    k_initial = k_ref[0, :]
    # The final loop state is not used; all results were stored inside the loop.
    m, l, v, k, _, _, _ = lax.fori_loop(lower_bound, upper_bound, body_c, (m, l, v, k, lse_initial, v_initial, k_initial))

def lse_cumsum_exclusive(lse, v, k):
    """Run ``lse_cumsum_exclusive_kernel`` independently for every (q, k) pair.

    Args:
        lse: (C, Q, K) log-weights.
        v: (C, Q, K, V) values.
        k: (C, Q, K, D) keys.

    Returns:
        ``(lse_out, v_out, k_out)`` with the same shapes and dtypes as the
        inputs, scanned exclusively along the leading axis.
    """
    C, Q, K, V = v.shape
    assert lse.shape == (C, Q, K), f"lse shape {lse.shape} does not match v shape {v.shape}"
    k_lead_c, k_lead_q, k_lead_k, D = k.shape
    assert (k_lead_c, k_lead_q, k_lead_k) == (C, Q, K), f"k shape {k.shape} does not match v shape {v.shape}"

    def column_spec(*feature_dims):
        # Program (i, j) sees the full leading axis of column (i, j); Q and K axes are squeezed.
        n_feat = len(feature_dims)
        return pl.BlockSpec(
            (C, None, None, *feature_dims),
            lambda i, j: (0, i, j) + (0,) * n_feat,
        )

    scan = pl.pallas_call(
        lse_cumsum_exclusive_kernel,
        out_shape=[lse, v, k],
        grid=(Q, K),
        in_specs=[column_spec(), column_spec(V), column_spec(D)],  # lse, v, k
        out_specs=[column_spec(), column_spec(V), column_spec(D)],  # lse_out, v_out, k_out
        compiler_params=plgpu.CompilerParams(num_warps=1, num_stages=1),
        name="lse_cumsum_exclusive_forward",
    )
    lse_out, v_out, k_out = scan(lse, v, k)
    return lse_out, v_out, k_out

def lse_cumsum_exclusive_bwd_kernel(
    lse_ref, v_ref, k_ref,
    clse_ref,
    dlse_in_ref, dv_in_ref, dk_in_ref,
    dlse_out_ref, dv_out_ref, dk_out_ref,
    ):
    """Pallas kernel body: reverse sweep paired with ``lse_cumsum_exclusive_kernel``.

    Walks the leading axis from ``C - 1`` down to ``0`` while carrying
    running cotangent accumulators. In this kernel the ``*_in_ref`` names are
    the incoming cotangents and the ``*_out_ref`` names are the results that
    get written.

    Args:
        lse_ref, v_ref, k_ref: (C,), (C, V), (C, D) forward inputs.
        clse_ref: (C,) cumulative log-sum-exp values; presumably the forward
            kernel's ``lse_out`` (confirm against the caller).
        dlse_in_ref, dv_in_ref, dk_in_ref: incoming cotangents, read per position.
        dlse_out_ref, dv_out_ref, dk_out_ref: written per position.
    """
    C, = lse_ref.shape
    _, V = v_ref.shape
    _, D = k_ref.shape

    def body_c(start_c, carry):
        # Loop counter runs upward; cidx runs from C-1 down to 0.
        cidx = C - 1 - start_c
        # Clamp so the masked prefetch on the last iteration still has a valid index.
        c_next_idx = jnp.maximum(cidx - 1, 0)
        pipe_not_tail = cidx > 0
        # clse_prev is the clse value read in the previous iteration (position cidx + 1).
        clse_prev, dlse_prev, dv_prev, dk_prev, lse_curr, v_curr, k_curr = carry
        #lse_curr = lse_ref[cidx] # ()
        lse_prefetch = pl.load(lse_ref, (c_next_idx,), mask=pipe_not_tail, other=0.0)
        # Weight of element cidx relative to the cumulative value at the following position.
        weight_curr = jnp.exp(lse_curr - clse_prev)
        pl.store(dv_out_ref, (cidx, slice(None)), (weight_curr*dv_prev).astype(dv_out_ref.dtype))
        pl.store(dk_out_ref, (cidx, slice(None)), (weight_curr*dk_prev).astype(dk_out_ref.dtype))
        #v_curr = v_ref[cidx, :] # (V,)
        #k_curr = k_ref[cidx, :] # (D,)
        dlse_out = weight_curr * (dlse_prev + jnp.sum(dv_prev * v_curr.astype(dv_prev.dtype)) + jnp.sum(dk_prev * k_curr.astype(dk_prev.dtype)))
        pl.store(dlse_out_ref, (cidx,), dlse_out.astype(dlse_out_ref.dtype))
        v_prefetch = pl.load(v_ref, (c_next_idx, slice(None)), mask=pipe_not_tail, other=0.0)
        k_prefetch = pl.load(k_ref, (c_next_idx, slice(None)), mask=pipe_not_tail, other=0.0)

        clse_curr = clse_ref[cidx] # ()
        # Rescale the accumulators from the clse_prev reference to the clse_curr reference.
        correction = jnp.exp(clse_curr - clse_prev)

        # Fold the incoming cotangents at cidx into the rescaled accumulators.
        dlse_curr = dlse_in_ref[cidx]
        dlse_next = dlse_prev * correction  + dlse_curr
        dv_curr = dv_in_ref[cidx, :] # (V,)
        dv_next = dv_prev * correction + dv_curr
        dk_curr = dk_in_ref[cidx, :] # (D,)
        dk_next = dk_prev * correction + dk_curr
        return clse_curr, dlse_next, dv_next, dk_next, lse_prefetch, v_prefetch, k_prefetch
    # A huge positive clse makes the first weight_curr and correction evaluate to exp(-huge) = 0.
    clse = jnp.full((), -DEFAULT_MASK_VALUE, dtype=jnp.float32)
    dlse = jnp.zeros((), dtype=jnp.float32)
    dv = jnp.zeros((V,), dtype=jnp.float32)
    dk = jnp.zeros((D,), dtype=jnp.float32)
    lower_bound = 0
    upper_bound = C
    # The first prefetch values are not actually used because the scan is exclusive, so init to zero instead.
    lse_initial = jnp.zeros_like(lse_ref[C-1])
    v_initial = jnp.zeros_like(v_ref[C-1, :])
    k_initial = jnp.zeros_like(k_ref[C-1, :])
    init_state = (clse, dlse, dv, dk, lse_initial, v_initial, k_initial)
    # The final loop state is not used; all results were stored inside the loop.
    clse, dlse, dv, dk, _, _, _ = lax.fori_loop(lower_bound, upper_bound, body_c, init_state)

def lse_cumsum_exclusive_bwd(lse_in, v_in, k_in, lse_out, v_out, k_out, dlse_out, dv_out, dk_out):
    """Backward pass of ``lse_cumsum_exclusive``.

    Before launching the kernel, the cotangent of the cumulative
    log-sum-exp is reduced by ``<dv_out, v_out>`` and ``<dk_out, k_out>``
    (inner products over the last axis). The kernel then runs once per
    (q, k) column.

    Returns:
        ``(dlse_in, dv_in, dk_in)`` with the shapes and dtypes of
        ``lse_in``, ``v_in`` and ``k_in``.
    """
    C, Q, K, V = v_in.shape
    _, _, _, D = k_in.shape

    v_inner = jnp.sum(dv_out * v_out, axis=-1).astype(dlse_out.dtype)  # (C, Q, K)
    k_inner = jnp.sum(dk_out * k_out, axis=-1).astype(dlse_out.dtype)  # (C, Q, K)
    dlse_modified = dlse_out - v_inner - k_inner  # (C, Q, K)

    def column_spec(*feature_dims):
        # Program (i, j) sees the full leading axis of column (i, j); Q and K axes are squeezed.
        n_feat = len(feature_dims)
        return pl.BlockSpec(
            (C, None, None, *feature_dims),
            lambda i, j: (0, i, j) + (0,) * n_feat,
        )

    sweep = pl.pallas_call(
        lse_cumsum_exclusive_bwd_kernel,
        out_shape=[lse_in, v_in, k_in],
        grid=(Q, K),
        in_specs=[
            column_spec(),   # lse_ref
            column_spec(V),  # v_ref
            column_spec(D),  # k_ref
            column_spec(),   # clse_ref
            column_spec(),   # dlse_in_ref
            column_spec(V),  # dv_in_ref
            column_spec(D),  # dk_in_ref
        ],
        out_specs=[
            column_spec(),   # dlse_out_ref
            column_spec(V),  # dv_out_ref
            column_spec(D),  # dk_out_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=1, num_stages=1),
        name="lse_cumsum_exclusive_backward",
    )
    dlse_in, dv_in, dk_in = sweep(lse_in, v_in, k_in, lse_out, dlse_modified, dv_out, dk_out)
    return dlse_in, dv_in, dk_in


# Retrieval count read by ref_attn; setting it to None makes ref_attn take its no-retrieval path.
NUM_RETRIEVALS = 8

def ref_attn(q, k, v, qmeta, kmeta):
    """Reference clustered attention built from the ``ref_*`` stages.

    Pipeline: cluster-mean queries -> coarse per-cluster attention ->
    final stage. With ``NUM_RETRIEVALS`` set to ``None`` the final stage is
    ``ref_final``; otherwise ``ref_final_retrieval`` followed by
    ``ref_leaf_retrieval``. The dipole correction is controlled by a local
    switch that is currently off.

    Returns:
        ``(lse, v_out)``.
    """
    use_dipole = False  # manual experiment switch
    q_bar = ref_preprocess(q, qmeta)  # (Q, D)
    lse_bar, k_bar, v_bar = ref_initial(q_bar, k, v, kmeta)
    if use_dipole:
        vk_bar, k_bar_vk, v_bar_vk = ref_initial_dipole(q_bar, k, v, kmeta)
        vk_merged, merge_weights = ref_vk_merge(lse_bar, vk_bar)
        vk_bar_full, _, _ = ref_initial_dipole_full(q_bar, k, v, kmeta)

    if NUM_RETRIEVALS is None:
        lse, v_out = ref_final(q, q_bar, k_bar, v_bar, lse_bar, qmeta)
        if use_dipole:
            # Several variants are evaluated in turn; only the last result is added.
            extra_v_out = ref_final_dipole(q, q_bar, vk_merged, qmeta)
            extra_v_out = ref_final_dipole_full(q, q_bar, k_bar, lse_bar, vk_bar_full, qmeta)
            v_out = v_out + extra_v_out
        return lse, v_out

    num_retrievals = NUM_RETRIEVALS
    m_mon, l_mon, v_out_mon, ret_idx = ref_final_retrieval(
        num_retrievals, q, q_bar, k_bar, v_bar, lse_bar, qmeta
    )
    lse, v_out = ref_leaf_retrieval(
        q, q_bar, k, v, qmeta, kmeta, ret_idx, m_mon, l_mon, v_out_mon
    )
    if use_dipole:
        lse_noret, _ = ref_final(q, q_bar, k_bar, v_bar, lse_bar, qmeta)
        # Several variants are evaluated in turn; only the last result is added.
        extra_v_out = ref_final_dipole(q, q_bar, vk_merged, qmeta)
        extra_v_out = ref_final_dipole_ret(q, q_bar, vk_merged, merge_weights, vk_bar, qmeta, ret_idx)
        extra_v_out = ref_final_dipole_full(q, q_bar, k_bar, lse_bar, vk_bar_full, qmeta)
        extra_v_out = ref_final_dipole_full_ret(q, q_bar, k_bar, lse_bar, vk_bar_full, qmeta, ret_idx)
        rescale = jnp.exp(lse_noret - lse)[:, None]
        v_out = v_out + rescale * extra_v_out
        print(f"ret_idx shape: {ret_idx.shape}, merge_weights shape: {merge_weights.shape}")
    return lse, v_out

def ref_attn_fwd(q, k, v, qmeta, kmeta):
    """Forward pass of the reference attention, also returning residuals.

    Returns:
        ``((lse, v_out), residuals)`` where ``residuals`` is
        ``(lse, v_out, q_bar, lse_bar, k_bar, v_bar)``, the tuple that
        ``ref_attn_bwd`` unpacks.
    """
    q_bar = ref_preprocess(q, qmeta)  # (Q, D)
    coarse = ref_initial(q_bar, k, v, kmeta)  # (Q, K), (Q, K, D), (Q, K, V)
    lse_bar, k_bar, v_bar = coarse
    outputs = ref_final(q, q_bar, k_bar, v_bar, lse_bar, qmeta)  # (Q, n), (Q, n, V)
    lse, v_out = outputs
    residuals = (lse, v_out, q_bar, lse_bar, k_bar, v_bar)
    return (lse, v_out), residuals

def ref_attn_bwd(q, k, v, qmeta, kmeta, res, dlse, dv_out):
    """Backward pass of the reference attention.

    Runs the stage backward functions in reverse order (final, initial,
    preprocess). ``q`` receives gradient along two paths: directly from the
    final stage, and through the cluster means ``q_bar``.

    Returns:
        ``(dq, dk, dv)``.
    """
    lse, v_out, q_bar, lse_bar, k_bar, v_bar = res

    # Final stage: gradients for q, q_bar and the coarse per-cluster quantities.
    dq_direct, dq_bar_final, dk_bar, dv_bar, dlse_bar = ref_final_bwd(
        q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out, dlse, dv_out
    )
    # Initial stage: gradients for k, v and a second contribution to q_bar.
    dq_bar_initial, dk, dv = ref_initial_bwd(
        q_bar, k, v, kmeta, lse_bar, k_bar, v_bar, dlse_bar, dk_bar, dv_bar
    )
    # Push the combined q_bar gradient back to the individual queries.
    dq_via_mean = ref_preprocess_bwd(qmeta, dq_bar_final + dq_bar_initial)  # (N, D)
    return dq_direct + dq_via_mean, dk, dv

def lse_v_out_cumsum(exclusive: bool, lse_b, k_b, v_b):
    """Cumulative log-sum-exp merge along axis 0 using plain cumsums.

    Weights are ``exp(lse_b - total)`` with ``total`` the log-sum-exp over
    the whole leading axis. Running weighted sums of ``k_b``/``v_b`` are
    divided by the running weight sum (floored at 1e-10).

    Args:
        exclusive: if true, every output is shifted down by one position.
            The first row becomes zeros for k/v and ``DEFAULT_MASK_VALUE/10``
            for the cumulative log-sum-exp.
        lse_b: (C, Q, K); k_b: (C, Q, K, D); v_b: (C, Q, K, V).

    Returns:
        ``(clse, c_k_b, c_v_b)``.
    """
    floor = DEFAULT_MASK_VALUE
    lse_b = jnp.maximum(lse_b, floor)
    total_lse = jnp.maximum(logsumexp(lse_b, axis=0), floor)  # (Q, K)

    weights = jnp.exp(lse_b - total_lse)  # (C, Q, K)
    weights4 = weights[:, :, :, None]
    denom = jnp.maximum(lax.cumsum(weights, axis=0), 1e-10)[:, :, :, None]  # (C, Q, K, 1)

    c_v_b = (lax.cumsum(weights4 * v_b, axis=0) / denom).astype(v_b.dtype)  # (C, Q, K, V)
    c_k_b = (lax.cumsum(weights4 * k_b, axis=0) / denom).astype(k_b.dtype)  # (C, Q, K, D)
    clse = lax.cumlogsumexp(lse_b, axis=0)  # (C, Q, K)

    if exclusive:
        first_clse = jnp.full_like(total_lse[None, :, :], floor / 10.)
        clse = jnp.concatenate((first_clse, clse[:-1, :, :]), axis=0)
        c_v_b = jnp.concatenate((jnp.zeros_like(c_v_b[0:1, :, :, :]), c_v_b[:-1, :, :, :]), axis=0)
        c_k_b = jnp.concatenate((jnp.zeros_like(c_k_b[0:1, :, :, :]), c_k_b[:-1, :, :, :]), axis=0)
    return clse, c_k_b, c_v_b

def lse_v_out_ass_scan_exclusive(lse_b, k_b, v_b):
    """Exclusive log-sum-exp merge along axis 0 via ``lax.associative_scan``.

    Each input is shifted down by one position and an identity element
    (very negative max, zero mass, zero vectors) is placed first, so an
    inclusive scan of the shifted sequence gives the exclusive result.

    Args:
        lse_b: (C, Q, K) log-weights.
        k_b: (C, Q, K, D) keys.
        v_b: (C, Q, K, V) values.

    Returns:
        ``(clse, c_k_b, c_v_b)``: the exclusive cumulative log-sum-exp
        (first row ``DEFAULT_MASK_VALUE``) and the exclusive weighted means
        of ``k_b`` and ``v_b``, in their input dtypes.
    """
    clse = lax.cumlogsumexp(lse_b, axis=0) # (C,Q,K)
    clse = jnp.concatenate((jnp.full_like(clse[0:1,:,:], DEFAULT_MASK_VALUE), clse[:-1,:,:]), axis=0)

    def combine(a, b):
        # Merge two partial states (running max, mass, weighted k sum, weighted v sum).
        ma, la, ka, va = a
        mb, lb, kb, vb = b
        m = jnp.maximum(ma, mb)
        corr_a, corr_b = jnp.exp(ma - m), jnp.exp(mb - m)
        l = corr_a * la + corr_b * lb
        # associative_scan passes slices that keep the scan axis, so the
        # scalars are (c, Q, K) and the vectors (c, Q, K, F). The new axis
        # must go last to broadcast over the feature dimension.
        k = corr_a[..., None] * ka + corr_b[..., None] * kb
        v = corr_a[..., None] * va + corr_b[..., None] * vb
        return m, l, k, v

    # Shifted inputs: position c holds element c-1; position 0 holds the identity.
    lse_b_id = jnp.concatenate((jnp.full_like(lse_b[0:1,:,:], DEFAULT_MASK_VALUE), lse_b[:-1,:,:]), axis=0)
    l_b_id = jnp.concatenate((jnp.zeros_like(lse_b[0:1,:,:]), jnp.ones_like(lse_b[:-1,:,:])), axis=0)
    k_b_id = jnp.concatenate((jnp.zeros_like(k_b[0:1,:,:,:]), k_b[:-1,:,:,:]), axis=0)
    v_b_id = jnp.concatenate((jnp.zeros_like(v_b[0:1,:,:,:]), v_b[:-1,:,:,:]), axis=0)

    m, l, k, v = lax.associative_scan(combine, (lse_b_id, l_b_id, k_b_id, v_b_id))
    # Normalise by the accumulated mass; the floor keeps the identity row at zero instead of NaN.
    c_v_b = (v / jnp.maximum(l[..., None], 1e-30)).astype(v_b.dtype) # (C,Q,K,V)
    c_k_b = (k / jnp.maximum(l[..., None], 1e-30)).astype(k_b.dtype) # (C,Q,K,D)
    return clse, c_k_b, c_v_b

def lse_v_out_scan_exclusive(lse_b, k_b, v_b):
    """Exclusive log-sum-exp merge along axis 0 via a sequential ``lax.scan``.

    At each step the current running state is emitted first and the step's
    element is folded in afterwards, so output ``c`` covers positions
    ``0..c-1``. The initial state has running max ``-10``, unit mass and
    zero vectors, so every output includes that extra ``exp(-10)`` mass with
    a zero vector.

    Args:
        lse_b: (C, Q, K) log-weights.
        k_b: (C, Q, K, D) keys.
        v_b: (C, Q, K, V) values.

    Returns:
        ``(lse_out_b, v_out_b, k_out_b)``. Note the order: values before keys.
    """
    def scan_fn(carry, x):
        ma, la, ka, va = carry
        eps = 1e-30
        # Emit the state accumulated before this step (exclusive).
        lse_out = ma + jnp.log(jnp.maximum(la, eps))
        # Per-step scalars are (Q, K) and vectors (Q, K, F): the new axis must
        # go last so the scalar broadcasts over the feature dimension.
        v_out = (va / jnp.maximum(la[..., None], eps)).astype(v_b.dtype)
        k_out = (ka / jnp.maximum(la[..., None], eps)).astype(k_b.dtype)
        # Fold in the current element with a running-max rescale.
        mb, lb, kb, vb = x
        m = jnp.maximum(ma, mb)
        corr_a, corr_b = jnp.exp(ma - m), jnp.exp(mb - m)
        l = corr_a * la + corr_b * lb
        k = corr_a[..., None] * ka + corr_b[..., None] * kb
        v = corr_a[..., None] * va + corr_b[..., None] * vb
        return (m, l, k, v), (lse_out, v_out, k_out)
    init_carry = (jnp.full_like(lse_b[0,:,:], -10.), jnp.ones_like(lse_b[0,:,:]), jnp.zeros_like(k_b[0,:,:,:]), jnp.zeros_like(v_b[0,:,:,:]))
    ones_b = jnp.ones_like(lse_b)  # every element enters with unit mass
    _, (lse_out_b, v_out_b, k_out_b) = lax.scan(scan_fn, init_carry, (lse_b, ones_b, k_b, v_b))
    return lse_out_b, v_out_b, k_out_b

def lse_v_out_matmul_cumsum(lse_b, k_b, v_b): # (C,Q,K), (C,Q,K,D), (C,Q,K,V)
    """Cumulative merge along axis 0 written as a masked (C x C) contraction.

    Builds weights ``exp(lse_b[s] - clse[t])`` for source positions ``s``
    strictly below target position ``t`` (all other entries are masked out)
    and contracts them with ``k_b`` and ``v_b``. ``clse`` is the inclusive
    cumulative log-sum-exp and is returned unshifted.

    Returns:
        ``(clse, k_out_b, v_out_b)``.
    """
    C, Q, K = lse_b.shape
    clse = lax.cumlogsumexp(lse_b, axis=0)  # (C, Q, K)

    target_ref = jnp.maximum(clse[:, None, ...], DEFAULT_MASK_VALUE)  # (C, 1, Q, K)
    log_w = lse_b[None, :, ...] - target_ref  # (C, C, Q, K), indexed [t, s, q, k]

    position = jnp.arange(C)
    strictly_before = position[:, None, None, None] > position[None, :, None, None]  # (C, C, 1, 1)
    log_w = jnp.where(strictly_before, log_w, DEFAULT_MASK_VALUE)
    w = jnp.exp(log_w)  # (C, C, Q, K)

    k_out_b = einsum("tsqk,sqkd->tqkd", w.astype(k_b.dtype), k_b)  # (C, Q, K, D)
    v_out_b = einsum("tsqk,sqkv->tqkv", w.astype(v_b.dtype), v_b)  # (C, Q, K, V)
    return clse, k_out_b, v_out_b

    


def do_causal_cluster(q, k, v, Q=64, K=64, B=2**12):
    """Cluster queries and keys, then assign points to clusters per causal block.

    When ``Q != K`` the work is delegated to ``do_irregular_causal_cluster``.
    Otherwise one vmapped k-means call computes centroids for ``q`` and ``k``
    together over the full sequence, and each length-``B`` block of the
    sequence is assigned to those centroids separately.

    Args:
        q, k: (T, D) and (S, D); ``v`` is (S, V) and is only used for its length.
        Q, K: number of query / key clusters.
        B: causal block length; must divide ``S``, and ``S`` must equal ``T``.

    Returns:
        ``(qmeta_b, kmeta_b)``: per-block assignment metadata.
    """
    if Q != K:
        return do_irregular_causal_cluster(q, k, v, Q=Q, K=K, B=B)

    T, D = q.shape
    S, V = v.shape
    cluster_num_iters = 1

    # Centroids for q and k in one batched call over the stacked pair.
    qk_stack = jnp.stack([q, k], axis=0)  # (2, T, D)
    assert Q == K
    batched_kmeans = vmap(kmeans_centroids_cuda, in_axes=(None, None, 0))
    qkcentroids = batched_kmeans(Q, cluster_num_iters, qk_stack)

    assert S % B == 0, f"S ({S}) must be divisible by B ({B})"
    assert S == T, f"S ({S}) must be equal to T ({T})"
    C = S // B

    qb = einshape("(cb)d->cbd", q, a=C, b=B)  # (C, B, D)
    kb = einshape("(cb)d->cbd", k, a=C, b=B)  # (C, B, D)
    qk_stack_b = einshape("a(cb)d->acbd", qk_stack, a=2, b=B, c=C)  # (2, C, B, D); not used below

    Qcap = math.ceil(4*B / Q)
    Kcap = math.ceil(4*B / K)

    def assign_per_block(capacity, blocks, centroids):
        # Note this will unnecessarily replicate centroids for the cuda call
        per_block = vmap(partial(custom_assign_indices, capacity), in_axes=(0, None))
        return per_block(jax.lax.stop_gradient(blocks), centroids)

    qmeta_b = assign_per_block(Qcap, qb, qkcentroids[0])
    kmeta_b = assign_per_block(Kcap, kb, qkcentroids[1])
    return qmeta_b, kmeta_b

def do_irregular_causal_cluster(q, k, v, Q=64, K=64, B=2**12):
    """Variant of the causal clustering that allows ``Q`` and ``K`` to differ.

    Query and key centroids are computed by two separate k-means calls over
    the full sequence; each length-``B`` block is then assigned to them
    separately. A warning is printed on every call.

    Args:
        q, k: (T, D) and (S, D); ``v`` is (S, V) and is only used for its length.
        Q, K: number of query / key clusters.
        B: causal block length; must divide ``S``, and ``S`` must equal ``T``.

    Returns:
        ``(qmeta_b, kmeta_b)``: per-block assignment metadata.
    """
    print("WARNING: USING EXPERIMENTAL IRREGULAR CAUSAL CLUSTERING")
    T, D = q.shape
    S, V = v.shape
    cluster_num_iters = 1

    qk_stack = jnp.stack([q, k], axis=0)  # (2, T, D); not used below
    qcentroids = kmeans_centroids_cuda(Q, cluster_num_iters, q)
    kcentroids = kmeans_centroids_cuda(K, cluster_num_iters, k)

    assert S % B == 0, f"S ({S}) must be divisible by B ({B})"
    assert S == T, f"S ({S}) must be equal to T ({T})"
    C = S // B

    qb = einshape("(cb)d->cbd", q, a=C, b=B)  # (C, B, D)
    kb = einshape("(cb)d->cbd", k, a=C, b=B)  # (C, B, D)

    Qcap = math.ceil(4*B / Q)
    Kcap = math.ceil(4*B / K)

    def assign_per_block(capacity, blocks, centroids):
        # Note this will unnecessarily replicate centroids for the cuda call
        per_block = vmap(partial(custom_assign_indices, capacity), in_axes=(0, None))
        return per_block(jax.lax.stop_gradient(blocks), centroids)

    qmeta_b = assign_per_block(Qcap, qb, qcentroids)
    kmeta_b = assign_per_block(Kcap, kb, kcentroids)
    return qmeta_b, kmeta_b

def ref_qbar_b_merge(q_bar_b, qmeta_b):
    """Combine per-block cluster means into one count-weighted mean per cluster.

    Args:
        q_bar_b: (C, Q, D) cluster means for each of the C blocks.
        qmeta_b: object whose ``cnt`` attribute is the (C, Q) member count
            of each cluster in each block.

    Returns:
        (Q, D) array; clusters with zero total count come out as zeros
        because the divisor is floored at 1.
    """
    C, Q, D = q_bar_b.shape  # unpacking also enforces a rank-3 input
    counts = qmeta_b.cnt  # (C, Q)
    weighted_total = jnp.sum(q_bar_b * counts[:, :, None], axis=0)  # (Q, D)
    count_total = jnp.sum(counts, axis=0)  # (Q,)
    divisor = jnp.maximum(count_total[:, None], 1)
    return weighted_total / divisor

@jax.custom_vjp
def inspect_gradients(x):
    """Identity function whose backward pass prints statistics of the cotangent."""
    return x


def inspect_gradients_fwd(x):
    """Forward rule: pass ``x`` through and keep no residuals."""
    return x, ()


def inspect_gradients_bwd(res, dx):
    """Backward rule: print NaN/Inf flags and max |dx|, then return ``dx`` unchanged."""
    stats = (
        jnp.isnan(dx).any(),
        jnp.isinf(dx).any(),
        jnp.max(jnp.abs(dx)),
    )
    jax.debug.print("Gradient stats - has_nan: {}, has_inf: {}, max_abs: {}", *stats)
    return (dx,)


inspect_gradients.defvjp(inspect_gradients_fwd, inspect_gradients_bwd)

def weighted_cummean_exclusive(x, weights, eps=1.0):
    """Exclusive weighted running mean along axis 0.

    Output row ``c`` is the weighted mean of rows ``0..c-1``; row 0 is zero.
    The running weight sum is floored at ``eps`` before dividing.
    """
    numer = lax.cumsum(x * weights, axis=0)
    denom = jnp.maximum(lax.cumsum(weights, axis=0), eps)
    inclusive = numer / denom
    leading_zero = jnp.zeros_like(inclusive[0:1, ...])
    return jnp.concatenate([leading_zero, inclusive[:-1, ...]], axis=0)

def custom_vk_merge_and_lse(q_bar, k_bar, lse_bar, vk_bar, cnt): # (Q, D), (K, D), (Q, K), (K,)
    """Merge ``vk_bar`` with weights from count-biased, clamped centroid scores.

    The logit for pair (q, k) is the scaled dot product of the two
    centroids plus ``log(cnt[k] + 0.1)``, clamped from above by
    ``lse_bar[q, k]``.

    Returns:
        ``(lse, vk_merged, weights, logweights)`` with shapes (Q,),
        (Q, V, D), (Q, K) and (Q, K); ``vk_merged`` is in the dtype of
        ``vk_bar``.
    """
    sm_scale = 1.0 / math.sqrt(q_bar.shape[-1])
    scores = einsum("qd,kd->qk", q_bar, k_bar) * sm_scale  # (Q, K)
    count_bias = jnp.log(cnt[None, :]+1e-1)  # (1, K)
    logweights = jnp.minimum(scores + count_bias, lse_bar)  # (Q, K)

    lse = logsumexp(logweights, axis=-1)  # (Q,)
    weights = jax.nn.softmax(logweights, axis=-1)  # (Q, K)
    merged = einsum("qk,kvd->qvd", weights, vk_bar)  # (Q, V, D)
    return lse, merged.astype(vk_bar.dtype), weights, logweights

def block_diagonal_causal_attn(q, k, v, B=2**13):
    """Attention restricted to independent length-``B`` blocks of the sequence.

    The sequence is cut into ``S // B`` consecutive blocks and
    ``pallas_flash`` is vmapped over them with ``window=(B, 0)`` and the
    usual ``1/sqrt(D)`` scale, so no query attends outside its own block.

    Args:
        q, k: (T, D) and (S, D); v: (S, V). Requires ``S == T`` and ``S % B == 0``.

    Returns:
        ``(lse, v_out)`` with shapes (T,) and (T, V).
    """
    T, D = q.shape
    S, V = v.shape
    assert S % B == 0, f"S ({S}) must be divisible by B ({B})"
    assert S == T, f"S ({S}) must be equal to T ({T})"
    C = S // B
    softmax_scale = 1.0 / math.sqrt(D)

    def one_block(q_blk, k_blk, v_blk):
        lse_blk, out_blk = pallas_flash(q_blk, k_blk, v_blk, window=(B,0), sm_scale=softmax_scale)
        return lse_blk, out_blk

    blocks = (
        einshape("(cb)d->cbd", q, a=C, b=B),  # (C, B, D)
        einshape("(cb)d->cbd", k, a=C, b=B),  # (C, B, D)
        einshape("(cb)v->cbv", v, a=C, b=B),  # (C, B, V)
    )
    lse_b, v_out_b = vmap(one_block)(*blocks)  # (C, B), (C, B, V)

    lse = einshape("cb->(cb)", lse_b, a=C, b=B)  # (T,)
    v_out = einshape("cbv->(cb)v", v_out_b, a=C, b=B)  # (T, V)
    return lse, v_out

def moba_causal_attn(q, k, v, B=2**9, num_spatial=1, MB=2**13):
    """Causal attention combining a block-diagonal base with a MoBA retrieval path.

    Two partial attentions are computed and merged by their log-sum-exp:

    * ``block_diagonal_causal_attn`` over blocks of length ``MB``;
    * ``moba_retrieval_causal_auto`` with retrieval blocks of length ``B``.

    Args:
        q, k: (T, D) and (S, D); v: (S, V). Requires ``S == T`` and ``S % B == 0``.
        B: retrieval (causal) block length.
        num_spatial: forwarded to ``moba_retrieval_causal_auto``.
        MB: block length of the block-diagonal base; also forwarded as the
            meta-block size of the retrieval path.

    Returns:
        ``(lse, v_out)`` as produced by ``ref_lse_value_merge``.
    """
    lse_base, v_out_base = block_diagonal_causal_attn(q, k, v, B=MB)

    T, _ = q.shape
    S, _ = v.shape
    assert S % B == 0, f"S ({S}) must be divisible by B ({B})"
    assert S == T, f"S ({S}) must be equal to T ({T})"

    # The retrieval path builds its own per-block summaries and dummy
    # metadata, so nothing needs to be prepared here.
    lse_moba, v_out_moba = moba_retrieval_causal_auto(num_spatial, B, MB, q, k, v)
    lse, v_out = ref_lse_value_merge(lse_moba, lse_base, v_out_moba, v_out_base)
    return lse, v_out
#def spatial_retrieval(num_spatial: int, causal_block: int, bidiagonal: bool, q, q_bar, k, v, m, qmeta, ret_idx, m_in, l_in, v_in):
#    lse, v_out, ret_meta = leaf_retrieval_causal(causal_block, q, q_bar, k, v, qmeta, kmeta, ret_idx, blk_idx, m_mon, l_mon, v_out_mon)

# MoBA-style causal retrieval with a hand-written VJP. The three leading
# integer arguments are non-differentiable; gradients are defined for q, k, v.
@partial(jax.custom_vjp, nondiff_argnums=(0,1,2))
def moba_retrieval_causal_auto(num_spatial: int, causal_block: int, meta_block: int, q, k, v):
    """Primal function: run the forward rule and drop its residuals."""
    (lse, v_out), res = moba_retrieval_causal_auto_fwd(num_spatial, causal_block, meta_block, q, k, v)
    return lse, v_out
def moba_retrieval_causal_auto_fwd(num_spatial: int, causal_block: int, meta_block: int, q, k, v):
    """Forward rule: select key blocks per query, then attend inside them.

    Keys and values are cut into blocks of length ``causal_block`` and each
    block is summarised by its mean key and mean value.
    ``moba_spatial_retrieval`` is run on those summaries and returns block
    indices; ``leaf_retrieval_causal`` then produces ``lse`` and ``v_out``
    for those blocks. The project kernels expect cluster metadata, so
    single-cluster dummy metadata is built here.

    Returns:
        ``((lse, v_out), res)`` where ``res`` is the tuple unpacked by
        ``moba_retrieval_causal_auto_bwd``.
    """
    assert meta_block >= causal_block
    assert meta_block % causal_block == 0
    meta_multiplier = int(meta_block // causal_block)
    B = causal_block
    T, D = q.shape
    S, V = v.shape
    assert S % B == 0, f"S ({S}) must be divisible by B ({B})"
    assert S == T, f"S ({S}) must be equal to T ({T})"
    C = S // B

    qb = einshape("(cb)d->cbd", q, a=C, b=B) # (C, B, D)
    kb = einshape("(cb)d->cbd", k, a=C, b=B) # (C, B, D)
    vb = einshape("(cb)v->cbv", v, a=C, b=B) # (C, B, V)

    # One mean key / mean value per block, with two singleton axes inserted.
    kbar_b = jnp.mean(kb, axis=1) # (C, D)
    kbar_b = kbar_b[:, None, None, :] # (C, 1, 1, D)
    vbar_b = jnp.mean(vb, axis=1) # (C, V)
    vbar_b = vbar_b[:, None, None, :] # (C, 1, 1, V)

    # Minimal stand-ins for the cluster metadata objects taken by the retrieval functions.
    @struct.dataclass
    class DummyMeta:
        lab: jnp.ndarray

    @struct.dataclass
    class DummyKMeta:
        lab: jnp.ndarray
        fwd: jnp.ndarray
        cnt: jnp.ndarray
    # Every query and key gets label 0; each block has one key cluster listing all B positions.
    dummy_qmetab = DummyMeta(lab=jnp.zeros((C, B), dtype=jnp.int32))
    dummy_klab = jnp.zeros((C, B), dtype=jnp.int32)
    dummy_kcnt = jnp.full((C, 1), B, dtype=jnp.int32)
    dummy_kfwd = jnp.tile(jnp.arange(B)[None, None, :], (C, 1, 1)).astype(jnp.int32) # (C, 1, B)
    dummy_kmetab = DummyKMeta(lab=dummy_klab, fwd=dummy_kfwd, cnt=dummy_kcnt)
    dummy_m = jnp.ones((C, 1, 1)) * jnp.log(B)
    dummy_ret_idx = jnp.zeros((T, 1), dtype=jnp.int32)
    assert dummy_ret_idx.dtype == jnp.int32
    dummy_qbar = jnp.zeros((1, D), dtype=q.dtype)

    #m_in = lse_base
    #l_in = jnp.ones_like(m_in)
    #v_in = v_out_base
    # Empty accumulators: very negative max, zero mass, zero output.
    m_in = jnp.full((T,), DEFAULT_MASK_VALUE/10., dtype=jnp.float32)
    l_in = jnp.zeros((T,), dtype=jnp.float32)
    v_in = jnp.zeros((T, V), dtype=v.dtype)
    # Only blk_idx is used from this call; m_spat, l_spat, v_spat are not forwarded.
    m_spat, l_spat, v_spat, blk_idx, spatial_retrieval_meta = moba_spatial_retrieval(num_spatial, B, False, meta_multiplier, q, dummy_qbar, kbar_b, vbar_b, dummy_m, dummy_qmetab, dummy_ret_idx, m_in, l_in, v_in)
    # The leaf stage starts from the empty accumulators, not from the spatial-stage outputs.
    lse, v_out, ret_meta = leaf_retrieval_causal(B, q, dummy_qbar, k, v, dummy_qmetab, dummy_kmetab, dummy_ret_idx, blk_idx, m_in, l_in, v_in)
    res = (q, k, v, dummy_qbar, dummy_qmetab, dummy_kmetab, dummy_ret_idx, blk_idx, lse, v_out, ret_meta)
    return (lse, v_out), res
def moba_retrieval_causal_auto_bwd(num_spatial: int, causal_block: int, meta_block: int, res, grads):
    """Backward rule: gradients come from ``leaf_retrieval_causal_bwd`` only.

    No gradient is computed through ``moba_spatial_retrieval``; the block
    indices saved in ``res`` are reused as-is.

    Returns:
        ``(dq, dk, dv)``.
    """
    dlse, dv_out = grads
    q, k, v, q_bar, qmeta, kmeta, ret_idx, blk_idx, lse, v_out, ret_meta = res

    dq_r, dk_r, dv_r = leaf_retrieval_causal_bwd(causal_block, q, q_bar, k, v, qmeta, kmeta, ret_idx, blk_idx, lse, v_out, dlse, dv_out, ret_meta)
    return dq_r, dk_r, dv_r
moba_retrieval_causal_auto.defvjp(moba_retrieval_causal_auto_fwd, moba_retrieval_causal_auto_bwd)




def causal_attn(Q, K, B, num_retrievals, bidiagonal, dipole, q, k, v, num_spatial=1):
    """Configuration-first entry point that forwards to ``ref_causal_attn``.

    The settings come first in the signature and the tensors ``q, k, v``
    last; every setting is passed on by keyword.
    """
    settings = dict(
        Q=Q,
        K=K,
        B=B,
        num_retrievals=num_retrievals,
        bidiagonal=bidiagonal,
        dipole=dipole,
        num_spatial=num_spatial,
    )
    return ref_causal_attn(q, k, v, **settings)

#def ref_causal_attn(q, k, v, Q=128, K=128, B=2**13, num_retrievals=NUM_RETRIEVALS, bidiagonal=False, dipole=False, approx_dipole=True):
def ref_causal_attn(q, k, v, Q=128, K=128, B=2**13, num_retrievals=8, bidiagonal=False, dipole=False, approx_dipole=True, num_spatial=1):
    """Approximate causal attention for one head: block-local flash attention merged with clustered far-field terms.

    The sequence is handled as C causal blocks of B tokens, where C is read from
    the clustering metadata returned by do_causal_cluster and B is recomputed as
    S // C. The returned (lse, v_out) is the log-sum-exp merge (ref_lse_value_merge) of:

      * attention inside each block (pallas_flash with window=(B, 0), presumably
        the causal window), plus, when `bidiagonal` is set, attention of each
        block's queries to the previous block's keys/values;
      * a far-field term built from per-block cluster summaries (initial_auto),
        accumulated over blocks by lse_cumsum_exclusive_auto and evaluated per
        query by final_auto, or by final_ret_causal_auto when retrievals are on;
      * when `dipole` is set, an extra value correction added to the far-field term.

    Args:
        q: (T, D) queries.
        k: keys; reshaped with the same (C, B) split as q and v, so presumably
            (S, D) with S == T -- TODO confirm.
        v: (S, V) values.
        Q, K: cluster counts forwarded to do_causal_cluster. Both names are
            rebound afterwards to the sizes found in the returned metadata.
        B: block size forwarded to do_causal_cluster; rebound to S // C afterwards.
        num_retrievals: when not None and > 0, the far-field term is recomputed
            by final_ret_causal_auto with this many retrievals.
        bidiagonal: adds the previous-block flash term and shifts the
            accumulated cluster summaries back by one further block.
        dipole: enables the dipole value correction.
        approx_dipole: only read when `dipole` is set; selects
            ref_initial_dipole / ref_final_dipole (True) or the *_full variants (False).
        num_spatial: forwarded to final_ret_causal_auto.

    Returns:
        (lse_flat, v_out_flat); per the shape comments below, (T,) and (T, V).

    Note:
        The print calls run whenever this Python body executes (under jax.jit
        that is presumably once per trace, hence the "Compiling" message).
    """
    print("Compiling fma.pallas_retrieval.ref_causal_attn...")
    # Disabled debug toggle (never executes).
    if False:
        #print("WARNING: Hard-Coding K to 16, Q=16...")
        print("WARNING: Hard-Coding K to 128")
        #prev_Q = int(Q)
        #Q = 16
        K = 128
    qmeta_b, kmeta_b = do_causal_cluster(q, k, v, Q=Q, K=K, B=B)
    T, D = q.shape
    S, V = v.shape
    # From here on Q, K and B mean the sizes actually present in the metadata,
    # not the values passed in by the caller.
    C, Q, N = qmeta_b.fwd.shape
    _, K, U = kmeta_b.fwd.shape
    B = S // C

    def expand_T(arr):
        # Split the leading token axis into (C, B).
        # NOTE(review): the size keyword is `a=C` although the pattern names that axis `c`;
        # the split is still pinned by b=B -- confirm `a` is not a typo for `c`.
        return einshape("(cb)...->cb...", arr, a=C, b=B)
    qb = expand_T(q) # (C, B, D)
    kb = expand_T(k) # (C, B, D)
    vb = expand_T(v) # (C, B, V)
    #qb = lax.reshape(qb, (C, B, D))
    #kb = lax.reshape(kb, (C, B, D))
    #vb = lax.reshape(vb, (C, B, V))

    # Per-block query-cluster representatives, merged across blocks into a single
    # set that is shared by every block.
    #q_bar_b = vmap(ref_preprocess)(qb, qmeta_b) # (C, Q, D)
    q_bar_b = vmap(preprocess_auto)(qb, qmeta_b) # (C, Q, D)
    q_bar = ref_qbar_b_merge(q_bar_b, qmeta_b).astype(q.dtype) # (Q, D)
    # Always-on: no gradient flows back through q_bar.
    if True:
        print("WARNING: Stopped Gradients on q_bar")
        q_bar = jax.lax.stop_gradient(q_bar)
    # Disabled experiment (never executes).
    if False:
        reduce = 8
        print(f"WARNING: Artifically reducing q-cluster count by factor of {reduce} through averaging")
        q_bar_prime = einshape("(ab)d->abd", q_bar, a=Q//reduce, b=reduce)
        q_bar_prime = jnp.ones_like(q_bar_prime)*jnp.mean(q_bar_prime, keepdims=True, axis=1)
        q_bar = einshape("abd->(ab)d", q_bar_prime, a=Q//reduce, b=reduce)
        q_bar = jnp.ones_like(q_bar)*jnp.mean(q_bar, keepdims=True, axis=0)
    q_bar_b = lax.broadcast(q_bar, (C,)) # (C, Q, D)
    # Disabled experiment (never executes).
    if False:
        q_bar = jnp.zeros_like(q_bar)
        q_bar_b = jnp.zeros_like(q_bar_b)
        print("WARNING: USING ZERO q_bar_b TO DISABLE QUERY CLUSTERS")
    # Disabled experiment (never executes). NOTE(review): it reads `prev_Q`, whose only
    # assignment is commented out above, so enabling it as-is would raise NameError.
    if False:
        scale_factor = float(prev_Q / 128)
        print(f"WARNING: SCALING DOWN q_bar_B by {scale_factor:.2f}")
        q_bar = scale_factor * q_bar
        q_bar_b = scale_factor * q_bar_b
    #q_bar_b = jnp.zeros((C, Q, D), dtype=q.dtype) # (C, Q, D)
    #jax.debug.print("q_bar_b has NaN: {}", jnp.isnan(q_bar_b).any())
    # Helpers that drop the last / first block from every leaf of a pytree.
    drop_last = partial(jax.tree.map, lambda x: x[:-1])
    drop_first = partial(jax.tree.map, lambda x: x[1:])

    # Near field: attention restricted to each block.
    lse_flash, v_out_flash = vmap(partial(pallas_flash, window=(B,0), sm_scale=1.0 / math.sqrt(D)))(qb, kb, vb) # (C, B), (C, B, V)
    if bidiagonal:
        # Queries of block c+1 against keys/values of block c; block 0 has no predecessor,
        # so a masked (DEFAULT_MASK_VALUE lse, zero value) row is prepended before merging.
        extra_lse_flash, extra_v_out_flash = vmap(partial(pallas_flash, window=(B,B), sm_scale=1.0 / math.sqrt(D)))(drop_first(qb), drop_last(kb), drop_last(vb)) # (Cm1, B), (Cm1, B, V)
        extra_lse_flash = jnp.concatenate([jnp.full_like(extra_lse_flash[0:1], DEFAULT_MASK_VALUE), extra_lse_flash], axis=0)
        extra_v_out_flash = jnp.concatenate([jnp.zeros_like(extra_v_out_flash[0:1]), extra_v_out_flash], axis=0)
        lse_flash, v_out_flash = ref_lse_value_merge(lse_flash, extra_lse_flash, v_out_flash, extra_v_out_flash)
    # Far field: per-block summaries of the keys/values for every (query cluster, key cluster) pair.
    #lse_bar, k_bar, v_bar = vmap(initial_auto)(q_bar_b[:-1], kb[:-1], vb[:-1], drop_last(kmeta_b)) # CQK, CQKD, CQKV
    lse_bar, k_bar, v_bar = vmap(initial_auto)(q_bar_b, kb, vb, kmeta_b) # CQK, CQKD, CQKV
    if dipole and approx_dipole:
        vk_bar, k_bar_vk, v_bar_vk = vmap(ref_initial_dipole)(q_bar_b, kb, vb, kmeta_b) # CKVD, CKD, CKV
        vk_bar_weights = kmeta_b.cnt # CK
        #vk_bar_cumsum = lax.cumsum(vk_bar * vk_bar_weights[:,:,None, None], axis=0) # CKVD
        #vk_bar_weights_cumsum = lax.cumsum(vk_bar_weights, axis=0) # CK
        #vk_bar_cummean = (vk_bar_cumsum / jnp.maximum(vk_bar_weights_cumsum[:,:,None, None], 1)).astype(vk_bar.dtype) # CKVD
        #vk_bar = jnp.concatenate([jnp.zeros_like(vk_bar_cummean[0:1,:,:,:]), vk_bar_cummean[:-1,:,:,:]], axis=0)
        # Cluster-count-weighted running mean over blocks (the commented lines above are the
        # earlier inline version of the same idea).
        vk_bar = weighted_cummean_exclusive(vk_bar, vk_bar_weights[:,:,None, None])
        k_bar_vk = weighted_cummean_exclusive(k_bar_vk, vk_bar_weights[:,:,None])
        v_bar_vk = weighted_cummean_exclusive(v_bar_vk, vk_bar_weights[:,:,None])
    elif dipole and not approx_dipole:
        vk_bar, _, _ = vmap(ref_initial_dipole_full)(q_bar_b, kb, vb, kmeta_b) # CQKVD, CQKD, CQKV
        # vk_bar is passed in both payload slots; only the second output is kept.
        _, vk_bar_merged, _ = vmap(lse_cumsum_exclusive_auto, in_axes=(None, -2, -2), out_axes=-2)(lse_bar, vk_bar, vk_bar)
    # Keep the per-block (un-accumulated) summaries; final_ret_causal_auto receives them separately.
    lse_bar_old, k_bar_old, v_bar_old = lse_bar, k_bar, v_bar
    #jax.debug.print("Init v_bar has NaN: {}", jnp.isnan(v_bar).any())
    #lse_bar, k_bar, v_bar = lse_v_out_cumsum(False, lse_bar, k_bar, v_bar)
    #lse_bar, k_bar, v_bar = lse_v_out_cumsum(True, lse_bar, k_bar, v_bar)
    #lse_bar, k_bar, v_bar = lse_v_out_matmul_cumsum(lse_bar, k_bar, v_bar)
    lse_bar, k_bar, v_bar = lse_cumsum_exclusive_auto(lse_bar, k_bar, v_bar)
    if bidiagonal:
        # Shift the accumulated summaries back by one more block (the previous block is
        # already covered by the extra flash term above); block 0 gets a masked entry.
        lse_bar = jnp.concatenate([jnp.full_like(lse_bar[0:1], 0.5*DEFAULT_MASK_VALUE), lse_bar[:-1]], axis=0)
        k_bar = jnp.concatenate([jnp.zeros_like(k_bar[0:1]), k_bar[:-1]], axis=0)
        v_bar = jnp.concatenate([jnp.zeros_like(v_bar[0:1]), v_bar[:-1]], axis=0)
    #jax.debug.print("Cumsum v_bar has NaN: {}", jnp.isnan(v_bar).any())
    k_bar, v_bar = k_bar.astype(k.dtype), v_bar.astype(v.dtype)
    #lse_b, v_out_b = vmap(final_auto)(drop_first(qb), drop_first(q_bar_b), k_bar, v_bar, lse_bar, drop_first(qmeta_b))
    lse_b, v_out_b = vmap(final_auto)(qb, q_bar_b, k_bar, v_bar, lse_bar, qmeta_b)
    if dipole:
        if approx_dipole:
            vk_bar_merged, merge_weights = vmap(ref_vk_merge)(lse_bar, vk_bar)
            extra_v_out_b = vmap(ref_final_dipole)(qb, q_bar_b, vk_bar_merged, qmeta_b)
        else:
            extra_v_out_b = vmap(ref_final_dipole_full)(qb, q_bar_b, k_bar, lse_bar, vk_bar_merged, qmeta_b)
        v_out_b = v_out_b + extra_v_out_b
    lse_b_flat = einshape("cb->(cb)", lse_b, c=C, b=B) # (T, )
    v_out_b_flat = einshape("cbv->(cb)v", v_out_b, c=C, b=B) # (T, V)
    #lse_b_flat, v_out_b_flat = partial(final_causal_auto, B)(q, q_bar, k_bar, v_bar, lse_bar, qmeta_b)
    #lse_b, v_out_b = vmap(final_ret_auto)(drop_first(qb), drop_last(kb), drop_last(vb), drop_first(q_bar_b), k_bar, v_bar, lse_bar, drop_first(qmeta_b), drop_last(kmeta_b))
    #lse_b, v_out_b = final_ret_causal_auto(drop_first(qb), drop_last(kb), drop_last(vb), drop_first(q_bar_b), k_bar, v_bar, lse_bar, drop_first(qmeta_b), drop_last(kmeta_b))
    #lse_b, v_out_b = final_ret_causal_auto(qb, kb, vb, q_bar_b, k_bar, v_bar, lse_bar, qmeta_b, kmeta_b)
    #num_retrievals = NUM_RETRIEVALS if NUM_RETRIEVALS is not None else 2
    if num_retrievals is not None and num_retrievals > 0:
        # With retrievals on, the far-field result computed above is replaced wholesale
        # (including any dipole term already added to v_out_b).
        lse_b_flat, v_out_b_flat, ret_idx = partial(final_ret_causal_auto, num_retrievals, B, bidiagonal, num_spatial)(q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta_b, kmeta_b, k_bar_old, v_bar_old, lse_bar_old)
        if dipole:
            lse_b = einshape("(cb)->cb", lse_b_flat, c=C, b=B)
            ret_idx_unflat = einshape("(cb)r->cbr", ret_idx, c=C, b=B)
            #lse_noret_b, _ = vmap(final_auto)(qb*0., q_bar_b, k_bar, v_bar, lse_bar, qmeta_b)
            if approx_dipole:
                #lse_noret_b, _ = vmap(final_auto)(qb, q_bar_b, k_bar, v_bar, lse_bar, qmeta_b)
                # Exclusive cumulative cluster counts over blocks.
                cumcnt = lax.cumsum(kmeta_b.cnt, axis=0) - kmeta_b.cnt
                lse_qcluster, vk_bar_merged, merge_weights, merge_logweights = vmap(custom_vk_merge_and_lse)(q_bar_b, k_bar_vk, lse_bar, vk_bar, cumcnt)
                lse_noret_b, _ = vmap(final_auto)(qb, q_bar_b, jnp.broadcast_to(k_bar_vk[:,None], k_bar.shape), v_bar, merge_logweights, qmeta_b)
                #lse_noret_b, _ = vmap(final_auto)(qb, jnp.zeros_like(q_bar_b), jnp.broadcast_to(k_bar_vk[:,None], k_bar.shape), v_bar, jnp.minimum(lse_bar, jnp.broadcast_to(jnp.log(cumcnt[:,None] + 1e-1), lse_bar.shape)), qmeta_b)
                #lse_noret_b = lse_qcluster[jnp.arange(C)[:,None], qmeta_b.lab]
                extra_v_out_b = vmap(ref_final_dipole)(qb, q_bar_b, vk_bar_merged, qmeta_b)#*3e-1
                #extra_v_out_b = vmap(ref_final_dipole_ret)(qb, q_bar_b, vk_bar_merged, merge_weights, vk_bar, qmeta_b, ret_idx_unflat)
            else:
                lse_noret_b, _ = vmap(final_auto)(qb, q_bar_b, k_bar, v_bar, lse_bar, qmeta_b)
                extra_v_out_b = vmap(ref_final_dipole_full_ret)(qb, q_bar_b, k_bar, lse_bar, vk_bar_merged, qmeta_b, ret_idx_unflat)
            # Rescale the dipole term by exp(lse_noret - lse) before adding it to the
            # retrieval-based values.
            extra_v_out_b_flat = einshape("cbv->(cb)v", extra_v_out_b*jnp.exp(lse_noret_b-lse_b)[:,:,None], c=C, b=B)
            v_out_b_flat = v_out_b_flat + extra_v_out_b_flat
    #jax.debug.print("Final v_out_b has NaN: {}", jnp.isnan(v_out_b).any())
    #lse_b = jnp.nan_to_num(lse_b)
    #lse_b = jnp.concatenate([jnp.full_like(lse_b[0:1], DEFAULT_MASK_VALUE), lse_b])
    #v_out_b = jnp.concatenate([jnp.zeros_like(v_out_b[0:1]), v_out_b])

    #lse, v_out = lse_flash, v_out_flash
    #lse, v_out = lse_b, v_out_b
    #lse_merged, v_out_merged = vmap(ref_lse_value_merge)(lse_flash, lse_b, v_out_flash, v_out_b) # (Cm1, B), (Cm1, B, V)
    #lse, v_out = lse_merged, v_out_merged
    #lse_flat = einshape("cb->(cb)", lse, c=C, b=B) # (T, )
    #v_out_flat = einshape("cbv->(cb)v", v_out, c=C, b=B) # (T, V)

    # Merge the near-field (flash) and far-field results per token.
    lse_flash_flat = einshape("cb->(cb)", lse_flash, c=C, b=B) # (T, )
    v_out_flash_flat = einshape("cbv->(cb)v", v_out_flash, c=C, b=B) # (T, V)
    lse_merged_flat, v_out_merged_flat = ref_lse_value_merge(lse_flash_flat, lse_b_flat, v_out_flash_flat, v_out_b_flat)
    lse_flat, v_out_flat = lse_merged_flat, v_out_merged_flat

    return lse_flat, v_out_flat

    

@jax.custom_vjp
def lse_cumsum_exclusive_auto(lse, v, k):
    """lse_cumsum_exclusive with a hand-written backward pass.

    Takes an lse array and two payload arrays and returns the triple produced by
    lse_cumsum_exclusive. Gradients come from lse_cumsum_exclusive_bwd rather
    than from tracing the forward computation.
    """
    outputs = lse_cumsum_exclusive(lse, v, k)
    return outputs
def lse_cumsum_exclusive_auto_fwd(lse, v, k):
    # Forward rule: the residuals are the three inputs followed by the three outputs.
    lse_out, v_out, k_out = lse_cumsum_exclusive(lse, v, k)
    primal_out = (lse_out, v_out, k_out)
    residuals = (lse, v, k) + primal_out
    return primal_out, residuals
def lse_cumsum_exclusive_auto_bwd(res, grads):
    # Backward rule: hand inputs, outputs and output cotangents to the manual backward.
    lse_in, v_in, k_in, lse_out, v_out, k_out = res
    dlse_out, dv_out, dk_out = grads
    grad_lse, grad_v, grad_k = lse_cumsum_exclusive_bwd(
        lse_in, v_in, k_in,
        lse_out, v_out, k_out,
        dlse_out, dv_out, dk_out,
    )
    return grad_lse, grad_v, grad_k
lse_cumsum_exclusive_auto.defvjp(lse_cumsum_exclusive_auto_fwd, lse_cumsum_exclusive_auto_bwd)

@jax.custom_vjp
def preprocess_auto(q, qmeta):
    """preprocess(q, qmeta) with a hand-written backward pass.

    The gradient w.r.t. q is computed by ref_preprocess_bwd from the cluster
    metadata and the output cotangent alone; qmeta receives no gradient.
    """
    q_bar = preprocess(q, qmeta)
    return q_bar
def preprocess_auto_fwd(q, qmeta):
    # Forward rule: only the metadata is needed again in the backward pass.
    return preprocess(q, qmeta), (qmeta,)
def preprocess_auto_bwd(res, dq_bar):
    # Backward rule: one cotangent per primal argument (q, qmeta).
    [qmeta] = res
    grad_q = ref_preprocess_bwd(qmeta, dq_bar)
    return (grad_q, None)
preprocess_auto.defvjp(preprocess_auto_fwd, preprocess_auto_bwd)

@jax.custom_vjp
def initial_auto(q_bar, k, v, kmeta):
    """initial(q_bar, k, v, kmeta) with a hand-written backward pass.

    Returns the (lse_bar, k_bar, v_bar) triple produced by `initial`. Gradients
    for q_bar, k and v come from initial_bwd; kmeta receives no gradient.
    """
    outputs = initial(q_bar, k, v, kmeta)
    return outputs
def initial_auto_fwd(q_bar, k, v, kmeta):
    # Forward rule: keep the inputs and the outputs for the manual backward.
    lse_bar, k_bar, v_bar = initial(q_bar, k, v, kmeta)
    primal_out = (lse_bar, k_bar, v_bar)
    residuals = (q_bar, k, v, kmeta) + primal_out
    return primal_out, residuals
def initial_auto_bwd(res, grads):
    # Backward rule: one cotangent per primal argument (q_bar, k, v, kmeta).
    q_bar, k, v, kmeta, lse_bar, k_bar, v_bar = res
    dlse_bar, dk_bar, dv_bar = grads
    grad_q_bar, grad_k, grad_v = initial_bwd(
        q_bar, k, v, kmeta,
        lse_bar, k_bar, v_bar,
        dlse_bar, dk_bar, dv_bar,
    )
    return grad_q_bar, grad_k, grad_v, None
initial_auto.defvjp(initial_auto_fwd, initial_auto_bwd)

@jax.custom_vjp
def final_auto(q, q_bar, k_bar, v_bar, lse_bar, qmeta):
    """final(...) with a hand-written backward pass.

    Returns the (lse, v_out) pair produced by `final`. Gradients for q, q_bar,
    k_bar, v_bar and lse_bar come from final_bwd; qmeta receives no gradient.
    """
    outputs = final(q, q_bar, k_bar, v_bar, lse_bar, qmeta)
    return outputs
def final_auto_fwd(q, q_bar, k_bar, v_bar, lse_bar, qmeta):
    # Forward rule: keep all inputs plus both outputs for the manual backward.
    lse, v_out = final(q, q_bar, k_bar, v_bar, lse_bar, qmeta)
    residuals = (q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out)
    return (lse, v_out), residuals
def final_auto_bwd(res, grads):
    # Backward rule: one cotangent per primal argument, None for qmeta.
    q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out = res
    dlse, dv_out = grads
    grad_q, grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar = final_bwd(
        q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out, dlse, dv_out
    )
    return grad_q, grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar, None
final_auto.defvjp(final_auto_fwd, final_auto_bwd)

@jax.custom_vjp
def final_ret_auto(q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta):
    """Cluster-level evaluation (final_retrieval) followed by leaf_retrieval, with a manual backward.

    The number of retrievals is the module-level NUM_RETRIEVALS, or 2 when that
    is None. Returns the (lse, v_out) pair produced by leaf_retrieval.
    """
    # Same computation as the forward rule; the residuals are simply dropped.
    primal_out, _ = final_ret_auto_fwd(q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta)
    return primal_out
def final_ret_auto_fwd(q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta):
    # Forward rule: run both stages and keep inputs, outputs and retrieval bookkeeping.
    n_ret = 2 if NUM_RETRIEVALS is None else NUM_RETRIEVALS
    m_mon, l_mon, v_out_mon, ret_idx = final_retrieval(n_ret, q, q_bar, k_bar, v_bar, lse_bar, qmeta)
    lse, v_out, ret_meta = leaf_retrieval(q, q_bar, k, v, qmeta, kmeta, ret_idx, m_mon, l_mon, v_out_mon)
    residuals = (q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, lse, v_out, ret_idx, ret_meta)
    return (lse, v_out), residuals
def final_ret_auto_bwd(res, grads):
    # Backward rule: sum the query gradients of both stages; k and v only get
    # gradients from the leaf stage; qmeta and kmeta get none.
    n_ret = 2 if NUM_RETRIEVALS is None else NUM_RETRIEVALS
    q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, lse, v_out, ret_idx, ret_meta = res
    dlse, dv_out = grads
    grad_q, grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar = final_retrieval_bwd(
        n_ret, q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out, dlse, dv_out
    )
    leaf_dq, leaf_dk, leaf_dv = leaf_retrieval_bwd(
        q, q_bar, k, v, qmeta, kmeta, ret_idx, lse, v_out, dlse, dv_out, ret_meta
    )
    grad_q = grad_q + leaf_dq
    return grad_q, leaf_dk, leaf_dv, grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar, None, None
final_ret_auto.defvjp(final_ret_auto_fwd, final_ret_auto_bwd)

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def final_causal_auto(causal_block: int, q, q_bar, k_bar, v_bar, lse_bar, qmeta): # (T,D) (Q,D) (C,Q,K,D) (C,Q,K,V) (C,Q,K) (C,Q)
    """Causal cluster-level evaluation without retrievals, with a manual backward.

    Calls final_retrieval_causal with zero retrievals and normalises its
    (m, l, v) accumulator triple into (lse, v_out). `causal_block` is a static
    (non-differentiable) argument.
    """
    # Same computation as the forward rule; the residuals are simply dropped.
    primal_out, _ = final_causal_auto_fwd(causal_block, q, q_bar, k_bar, v_bar, lse_bar, qmeta)
    return primal_out
def final_causal_auto_fwd(causal_block:int, q, q_bar, k_bar, v_bar, lse_bar, qmeta):
    m_mon, l_mon, v_out_mon, _ = final_retrieval_causal(0, causal_block, q, q_bar, k_bar, v_bar, lse_bar, qmeta)
    # Clamp the normaliser away from zero before taking its log / dividing by it.
    l_safe = jnp.maximum(l_mon, 1e-30)
    # m_mon is rescaled by ln(2) (presumably it is kept in base-2 units -- TODO confirm).
    lse = math.log(2)*m_mon + jnp.log(l_safe)
    v_out = v_out_mon / l_safe[..., None]
    residuals = (q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out)
    return (lse, v_out), residuals
def final_causal_auto_bwd(causal_block: int, res, grads):
    # Backward rule: one cotangent per differentiable argument, None for qmeta.
    q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out = res
    dlse, dv_out = grads
    grad_q, grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar = final_retrieval_causal_bwd(
        0, causal_block, q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out, dlse, dv_out
    )
    return grad_q, grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar, None
final_causal_auto.defvjp(final_causal_auto_fwd, final_causal_auto_bwd)

def retrieve_all_blocks(causal_block: int, ret_idx):
    """Expand a (T, R) retrieval table so every query retrieves from every earlier block.

    Each query's R retrieval indices are repeated once per block. The matching
    block index is the block number for blocks strictly before the query's own
    block (token index // causal_block) and -1 otherwise.

    Returns:
        (ret_idx_flat, blk_idx_flat), both of shape (T, C*R) with C = T // causal_block.
    """
    T, R = ret_idx.shape
    C = T // causal_block
    blocks = jnp.arange(C) # (C,)
    query_block = jnp.arange(T) // causal_block # (T,)
    is_earlier = query_block[:, None] > blocks[None, :] # (T, C)
    blk_idx = jnp.where(is_earlier, blocks[None, :], -1) # (T, C)

    ret_idx_out = jnp.broadcast_to(ret_idx[:, None, :], (T, C, R)) # (T, C, R)
    blk_idx_out = jnp.broadcast_to(blk_idx[:, :, None], (T, C, R)) # (T, C, R)
    return (
        einshape("tcr->t(cr)", ret_idx_out, c=C, r=R), # (T, C*R)
        einshape("tcr->t(cr)", blk_idx_out, c=C, r=R), # (T, C*R)
    )
def also_retrieve_previous(causal_block: int, ret_idx, blk_idx):
    """Double a (T, R) retrieval table so each query also retrieves from its preceding block.

    The first R output columns pair every retrieval index with the block just
    before the query's own block (token index // causal_block - 1, which is -1
    for the first block). The last R columns repeat the original table, with
    any entry that already pointed at that preceding block replaced by -1 so it
    is not listed twice.

    Returns:
        (ret_idx_out, blk_idx_out), both of shape (T, 2R).
    """
    T, R = ret_idx.shape
    prev_blk = jnp.arange(T) // causal_block - 1 # (T,)
    prev_blk = jnp.broadcast_to(prev_blk[:, None], (T, R)) # (T, R)
    deduped_blk = jnp.where(blk_idx == prev_blk, -1, blk_idx) # (T, R)
    blk_idx_out = jnp.concatenate([prev_blk, deduped_blk], axis=1) # (T, 2R)
    ret_idx_out = jnp.concatenate([ret_idx, ret_idx], axis=1) # (T, 2R)
    return ret_idx_out, blk_idx_out

@partial(jax.custom_vjp, nondiff_argnums=(0,1,2,3))
def final_ret_causal_auto(num_retrievals: int, causal_block: int, bidiagonal: bool, num_spatial: int, q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, k_bar_a, v_bar_a, lse_bar_a):
    """Causal far-field attention with retrievals, with a manual backward.

    Three stages run in sequence: final_retrieval_causal on the accumulated
    summaries (k_bar, v_bar, lse_bar), spatial_retrieval on the second set of
    summaries (k_bar_a, v_bar_a, lse_bar_a), and leaf_retrieval_causal on the
    raw keys/values. The first four arguments are static (non-differentiable).

    Returns:
        (lse, v_out, ret_idx).
    """
    # Same computation as the forward rule; the residuals are simply dropped.
    primal_out, _ = final_ret_causal_auto_fwd(
        num_retrievals, causal_block, bidiagonal, num_spatial,
        q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, k_bar_a, v_bar_a, lse_bar_a,
    )
    lse, v_out, ret_idx = primal_out
    return lse, v_out, ret_idx
def final_ret_causal_auto_fwd(num_retrievals: int, causal_block: int, bidiagonal: bool, num_spatial: int, q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, k_bar_a, v_bar_a, lse_bar_a):
    # Stage 1: cluster-level accumulators plus the retrieval indices.
    m_mon, l_mon, v_out_mon, ret_idx = final_retrieval_causal(
        num_retrievals, causal_block, q, q_bar, k_bar, v_bar, lse_bar, qmeta
    )
    # Stage 2: update the accumulators and obtain the block index of every retrieval.
    # (retrieve_all_blocks(causal_block, ret_idx) is the alternative that was tried here.)
    m_mon, l_mon, v_out_mon, blk_idx, spatial_retrieval_meta = spatial_retrieval(
        num_spatial, causal_block, bidiagonal, q, q_bar, k_bar_a, v_bar_a, lse_bar_a, qmeta,
        ret_idx, m_mon, l_mon, v_out_mon,
    )
    # Debug switch: resets the accumulators so only the leaf stage contributes.
    DISABLE_MONOPOLES_IN_FORWARD = False
    if DISABLE_MONOPOLES_IN_FORWARD:
        print("WARNING: Monopoles disabled in forward")
        m_mon = jnp.ones_like(m_mon)*DEFAULT_MASK_VALUE/10
        l_mon = jnp.zeros_like(l_mon)
        v_out_mon = jnp.zeros_like(v_out_mon)
    # Stage 3: leaf-level retrieval on the raw keys/values.
    lse, v_out, ret_meta = leaf_retrieval_causal(
        causal_block, q, q_bar, k, v, qmeta, kmeta, ret_idx, blk_idx, m_mon, l_mon, v_out_mon
    )
    residuals = (
        q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, k_bar_a, v_bar_a, lse_bar_a,
        lse, v_out, ret_idx, blk_idx, ret_meta, spatial_retrieval_meta,
    )
    return (lse, v_out, ret_idx), residuals
def final_ret_causal_auto_bwd(num_retrievals: int, causal_block: int, bidiagonal: bool, num_spatial: int, res, grads):
    (q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, k_bar_a, v_bar_a, lse_bar_a,
     lse, v_out, ret_idx, blk_idx, ret_meta, spatial_retrieval_meta) = res
    # The cotangent of ret_idx is ignored.
    dlse, dv_out, _ = grads

    grad_q, grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar = final_retrieval_causal_bwd(
        num_retrievals, causal_block, q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out, dlse, dv_out
    )
    spat_dq, grad_k_bar_a, grad_v_bar_a, grad_lse_bar_a, spat_dq_bar = spatial_retrieval_bwd(
        num_spatial, causal_block, bidiagonal, q, q_bar, k_bar_a, v_bar_a, lse_bar_a, qmeta,
        ret_idx, lse, v_out, dlse, dv_out, spatial_retrieval_meta,
    )
    leaf_dq, leaf_dk, leaf_dv = leaf_retrieval_causal_bwd(
        causal_block, q, q_bar, k, v, qmeta, kmeta, ret_idx, blk_idx, lse, v_out, dlse, dv_out, ret_meta
    )

    # All three stages contribute to dq; only the first two contribute to dq_bar.
    grad_q = grad_q + leaf_dq + spat_dq
    grad_q_bar = grad_q_bar + spat_dq_bar
    # One cotangent per differentiable argument; qmeta and kmeta get None.
    return (
        grad_q, leaf_dk, leaf_dv,
        grad_q_bar, grad_k_bar, grad_v_bar, grad_lse_bar,
        None, None,
        grad_k_bar_a, grad_v_bar_a, grad_lse_bar_a,
    )
final_ret_causal_auto.defvjp(final_ret_causal_auto_fwd, final_ret_causal_auto_bwd)


@jax.custom_vjp
def attn(q, k, v, qmeta, kmeta):
    """Clustered attention for one head given precomputed cluster metadata, with a manual backward.

    Pipeline: preprocess -> initial -> final, or, when the module-level
    NUM_RETRIEVALS is not None, preprocess -> initial -> final_retrieval ->
    leaf_retrieval.

    Returns:
        (lse, v_out).
    """
    # Same computation as attn_fwd; the residuals are simply dropped.
    primal_out, _ = attn_fwd(q, k, v, qmeta, kmeta)
    return primal_out

def attn_fwd(q, k, v, qmeta, kmeta):
    """Run the forward pipeline and return ((lse, v_out), intermediates).

    The intermediates tuple has 6 entries without retrievals and 8 with them
    (ret_idx and ret_meta appended); attn_bwd relies on that layout.
    """
    q_bar = preprocess(q, qmeta) # (Q, D)
    lse_bar, k_bar, v_bar = initial(q_bar, k, v, kmeta) # (Q, K), (Q, K, D), (Q, K, V), (K, V, D)
    if NUM_RETRIEVALS is not None:
        m_mon, l_mon, v_out_mon, ret_idx = final_retrieval(NUM_RETRIEVALS, q, q_bar, k_bar, v_bar, lse_bar, qmeta)
        lse, v_out, ret_meta = leaf_retrieval(q, q_bar, k, v, qmeta, kmeta, ret_idx, m_mon, l_mon, v_out_mon)
        return (lse, v_out), (lse, v_out, q_bar, lse_bar, k_bar, v_bar, ret_idx, ret_meta)
    lse, v_out = final(q, q_bar, k_bar, v_bar, lse_bar, qmeta) # (Q, n), (Q, n, V)
    return (lse, v_out), (lse, v_out, q_bar, lse_bar, k_bar, v_bar)

def _attn_fwd(q, k, v, qmeta, kmeta):
    # custom_vjp forward rule: prepend the five primal inputs to attn_fwd's intermediates.
    out, intermediates = attn_fwd(q, k, v, qmeta, kmeta)
    return out, (q, k, v, qmeta, kmeta) + intermediates

def attn_bwd(q, k, v, qmeta, kmeta, res, dlse, dv_out):
    """Manual backward of the attn pipeline; returns (dq, dk, dv).

    `res` is the intermediates tuple produced by attn_fwd. With retrievals the
    number of retrievals is read from ret_idx.shape[-1].
    """
    with_retrieval = NUM_RETRIEVALS is not None
    if with_retrieval:
        lse, v_out, q_bar, lse_bar, k_bar, v_bar, ret_idx, ret_meta = res
        dq_direct, dq_bar_direct, dk_bar, dv_bar, dlse_bar = final_retrieval_bwd(
            ret_idx.shape[-1], q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out, dlse, dv_out
        )
        leaf_dq, leaf_dk, leaf_dv = leaf_retrieval_bwd(
            q, q_bar, k, v, qmeta, kmeta, ret_idx, lse, v_out, dlse, dv_out, ret_meta
        )
    else:
        lse, v_out, q_bar, lse_bar, k_bar, v_bar = res
        dq_direct, dq_bar_direct, dk_bar, dv_bar, dlse_bar = final_bwd(
            q, q_bar, k_bar, v_bar, lse_bar, qmeta, lse, v_out, dlse, dv_out
        )
    # Back through `initial`, then through `preprocess` for the combined q_bar gradient.
    dq_bar_initial, dk, dv = initial_bwd(q_bar, k, v, kmeta, lse_bar, k_bar, v_bar, dlse_bar, dk_bar, dv_bar)
    dq_via_qbar = ref_preprocess_bwd(qmeta, dq_bar_direct + dq_bar_initial) # (N, D)
    dq = dq_direct + dq_via_qbar # (Q, D)
    if with_retrieval:
        dq, dk, dv = dq + leaf_dq, dk + leaf_dk, dv + leaf_dv
    return dq, dk, dv

def _attn_bwd(res, grads):
    # custom_vjp backward rule: split the residuals back into primal inputs and
    # intermediates; qmeta and kmeta receive no gradient.
    q, k, v, qmeta, kmeta = res[:5]
    intermediates = res[5:]
    dlse, dv_out = grads
    dq, dk, dv = attn_bwd(q, k, v, qmeta, kmeta, intermediates, dlse, dv_out)
    return dq, dk, dv, None, None

attn.defvjp(_attn_fwd, _attn_bwd)






##############################################################################################
#Evaluation code:
##############################################################################################

def sq_err(ref, x):
    """Relative mean squared error of x against ref, as a float32 scalar.

    Computes mean((ref - x)^2) / mean(ref^2) over all elements. When
    mean(ref^2) is not positive the result is +inf.
    """
    ref_power = jnp.mean(jnp.square(ref))
    err_power = jnp.mean(jnp.square(ref - x))
    rel_err = jnp.where(ref_power > 0, err_power / ref_power, jnp.inf)
    return rel_err.astype(jnp.float32)

def axis_corr(ref, x, axis=-3):
    """Correlation between ref and x after removing each one's mean along the longest axis of ref.

    The `axis` argument is accepted but overridden: the axis actually used is
    always the longest axis of `ref` (np.argmax of its shape). Means are taken
    along that axis; the covariance and both variances are then averaged over
    all elements. A 1e-6 term in the denominator guards against division by zero.

    Returns:
        float32 scalar.
    """
    # Deliberately ignore the caller's value and centre along the longest axis.
    axis = np.argmax(ref.shape)
    assert ref.shape[axis] >= max(ref.shape)
    ref_centered = ref - jnp.mean(ref, axis=axis, keepdims=True)
    x_centered = x - jnp.mean(x, axis=axis, keepdims=True)
    cov = jnp.mean(ref_centered * x_centered)
    ref_std = jnp.sqrt(jnp.mean(jnp.square(ref_centered)))
    x_std = jnp.sqrt(jnp.mean(jnp.square(x_centered)))
    corr = cov / (ref_std * x_std + 1e-6)
    return corr.astype(jnp.float32)

def compute_error_metrics(v_out_ref, v_out):
    """Per-head relative squared error and correlation between two 4-D outputs.

    Both metrics are mapped over the second-to-last axis (the H axis of a
    (B, T, H, D) tensor), so each returned array has one entry per head.

    Returns:
        (head_sqerrs, head_corrs).
    """
    _, _, _, _ = v_out_ref.shape # rank check: the reference must be 4-D
    per_head = partial(vmap, in_axes=-2, out_axes=-1)
    head_sqerrs = per_head(sq_err)(v_out_ref, v_out)
    head_corrs = per_head(axis_corr)(v_out_ref, v_out)
    return head_sqerrs, head_corrs

def compute_error_metrics_list(v_out_ref_list, v_out_list):
    """Run compute_error_metrics on each (reference, output) pair and concatenate the per-head results.

    Returns:
        (sq_errs, corrs): 1-D arrays holding the per-head metrics of every pair, in list order.
    """
    per_pair = tuple(
        compute_error_metrics(ref, out) for ref, out in zip(v_out_ref_list, v_out_list)
    )
    sq_err_parts, corr_parts = zip(*per_pair)
    return jnp.concatenate(sq_err_parts), jnp.concatenate(corr_parts)

def print_list_metrics(fn, qs_list, ks_list, vs_list, exact_list):
    """Evaluate `fn` on each (q, k, v) triple and print error statistics against exact results.

    Args:
        fn: callable taking (q, k, v) and returning an (lse, v_out) pair; it is jitted here.
        qs_list, ks_list, vs_list: parallel lists of inputs.
        exact_list: list of exact (lse, v_out) pairs, one per input triple.

    Only the v_out halves are compared. Prints mean / std / worst of the
    per-head relative squared error and of the per-head correlation.
    """
    jit_fn = jax.jit(fn)
    _, exact_v_outs = zip(*exact_list)
    approx_results = [
        jax.block_until_ready(jit_fn(q, k, v))
        for q, k, v in zip(qs_list, ks_list, vs_list)
    ]
    _, approx_v_outs = zip(*approx_results)

    sq_errs, corrs = compute_error_metrics_list(exact_v_outs, approx_v_outs)
    err_mean, err_std, err_worst = jnp.mean(sq_errs), jnp.std(sq_errs), jnp.max(sq_errs)
    corr_mean, corr_std, corr_worst = jnp.mean(corrs), jnp.std(corrs), jnp.min(corrs)
    print(f"OutRelsqerr: mean {err_mean:.5f} std {err_std:.5f} worst {err_worst:.5f}")
    print(f"OutCorr: mean {corr_mean:.4f} std {corr_std:.4f} worst {corr_worst:.4f}")
    
    

# Convenience wrappers around sq_err / axis_corr.
# mha_sq_err maps sq_err over axis 0 of both inputs and then over axis 1 of what remains,
# i.e. over axes 0 and 2 of the original arrays (presumably batch and heads of (B, N, H, D) -- TODO confirm).
mha_sq_err = vmap(vmap(sq_err, in_axes=1))
# tree_*: apply the metric leaf-by-leaf to two matching pytrees and convert each result to a Python float.
tree_sq_err = partial(jax.tree.map, lambda ref, x: float(sq_err(ref, x)))
tree_axis_corr = partial(jax.tree.map, lambda ref, x: float(axis_corr(ref, x)))
# _tree_*: same mapping without the float() conversion, so the results stay arrays
# (benchmark_fn wraps these in vmap).
_tree_sq_err = partial(jax.tree.map, lambda ref, x: sq_err(ref, x))
_tree_axis_corr = partial(jax.tree.map, lambda ref, x: axis_corr(ref, x))

def profile_fn(fn, args, *other):
    """Run jit(fn)(*args) once inside an NVTX range named "target" and return the result.

    The call is blocked until the result is ready, so the range covers the
    whole first call. Extra positional arguments are accepted and ignored.
    """
    compiled = jax.jit(fn)
    with nvtx.annotate("target"):
        result = compiled(*args)
        result = jax.block_until_ready(result)
    return result

def benchmark_fn(fn, args, ref_out, *, warmup_iters=100, iters=100):
    """Check jit(fn)(*args) against `ref_out`, then time repeated calls.

    Prints, in order: the wall time of the first call (jit construction and
    first execution), the relative squared errors (overall and the maximum over
    axis 2 of every leaf), the correlations (overall and the minimum over
    axis 2), and the average wall time per call over `iters` timed calls that
    follow `warmup_iters` untimed ones.

    Args:
        fn: function to benchmark.
        args: positional arguments for fn.
        ref_out: reference output with the same pytree structure as fn's output.
        warmup_iters: number of untimed calls before timing.
        iters: number of timed calls.
    """
    def run():
        return jax.block_until_ready(jit_fn(*args))

    def extreme_over_axis2(tree_metric, pick):
        # Metric per slice along axis 2 of every leaf, reduced to one Python float per leaf.
        per_slice = vmap(tree_metric, in_axes=2)(ref_out, fn_out)
        return jax.tree.map(lambda leaf: float(pick(leaf)), per_slice)

    t_start = time()
    jit_fn = jax.jit(fn)
    fn_out = run()
    t_end = time()

    sq_errs = tree_sq_err(ref_out, fn_out)
    max_sq_errs = extreme_over_axis2(_tree_sq_err, jnp.max)
    print(f"First call took {t_end - t_start:.4f} seconds.")
    print(f"Output errors: {sq_errs}, max: {max_sq_errs}")
    corrs = tree_axis_corr(ref_out, fn_out)
    min_corrs = extreme_over_axis2(_tree_axis_corr, jnp.min)
    print(f"Output correlations: {corrs}, min: {min_corrs}")

    # Untimed warmup calls.
    for _ in range(warmup_iters):
        run()
    # Timed calls.
    t_start = time()
    for _ in range(iters):
        run()
    t_end = time()
    print(f"Average time over {iters} iterations: {1e3 * (t_end - t_start) / iters:.4f} ms.")

# Module-level clustering configuration used by _do_multi_level_clustering below.
# Qs / Ks are one-element tuples (presumably a single level with 64 query and 64 key clusters).
# NOTE(review): the meaning of the iteration counts and max_cluster_scale is defined by
# MultiLevelClustering in .single_level_attention -- confirm there before changing them.
_multi_level_clustering = MultiLevelClustering(
    Qs=(64,),
    Ks=(64,),
    coupled_clustering=False,
    inner_iters=3,
    outer_iters=3,
    max_cluster_scale=1.5,
    #max_cluster_scale=2,
    compute_indices=True,
)

def _do_multi_level_clustering(qs, ks, vs):
    """Cluster queries and keys with the module-level MultiLevelClustering instance.

    Returns:
        (qmeta, kmeta): ClusterData records built from the (cnt, lab, fwd, bwd)
        4-tuples returned for the queries and for the keys.
    """
    q_parts, k_parts = _multi_level_clustering.do_clustering(qs, ks, vs)
    qcnt, qlab, qfwd, qbwd = q_parts
    kcnt, klab, kfwd, kbwd = k_parts
    return ClusterData(qcnt, qlab, qfwd, qbwd), ClusterData(kcnt, klab, kfwd, kbwd)

# Same, mapped over axis 0 and then over axis 1 of what remains (results placed on the same axes).
_do_multi_level_clustering_mha = vmap(vmap(_do_multi_level_clustering, in_axes=1, out_axes=1))

def do_clustering(qs, ks, vs):
    """Cluster one head's q/k/v with pallas_do_clustering using fixed settings.

    Uses 64 clusters and a per-cluster capacity of ceil(4.0 * N / 64), where N
    is the first dimension of the 2-D `qs`; the third leading argument passed
    to pallas_do_clustering is the constant 1.

    Returns:
        Whatever pallas_do_clustering returns (callers unpack it as (qmeta, kmeta)).
    """
    num_clusters = 64
    expand_ratio = 4.0
    num_tokens, _ = qs.shape
    max_cluster_size = math.ceil(expand_ratio * num_tokens / num_clusters)
    return pallas_do_clustering(num_clusters, max_cluster_size, 1, qs, ks, vs)

def cluster_attn(qs, ks, vs):
    """Cluster one head's inputs with do_clustering, then run `attn` on the result.

    Returns:
        The (lse, v_out) pair returned by attn.
    """
    cluster_meta = do_clustering(qs, ks, vs)
    qmeta, kmeta = cluster_meta
    return attn(qs, ks, vs, qmeta, kmeta)

# Same, mapped over axis 0 and then over axis 1 of what remains (results placed on the same axes).
cluster_attn_mha = vmap(vmap(cluster_attn, in_axes=1, out_axes=1))

def get_data(dtype, n, only_first_batch=False, subsample_heads=None):
    """Load a recorded q/k/v dump, re-chunk it to sequence length n and build the baseline attention callables.

    The dump is read from a hard-coded .npz path with arrays "q", "k", "v" of
    layout (B, N, H, D). All batches are concatenated along the token axis and
    re-split into chunks of n tokens, which become the new leading axis.

    Args:
        dtype: dtype the returned q/k/v arrays are cast to.
        n: tokens per returned sequence; must divide the total token count.
        only_first_batch: keep only the first batch entry of the dump.
        subsample_heads: if not None, keep every `subsample_heads`-th head, starting at head 3.

    Returns:
        A 10-tuple: (pallas_flash_mha, pallas_flash_causal, pallas_flash_causal_mha,
        baseline_approx, baseline_approx_mha, qs, ks, vs, qmeta, kmeta). The last
        two entries are always None (clustering is no longer precomputed here).
    """
    # Other dumps that have been used here (swap the path below to switch):
    #   examples/minigpt/qkv_data_64_T1/qkv_step_1600.npz
    #   examples/minigpt/qkv_data_64_T1/qkv_step_11000.npz
    #   examples/minigpt/qkv_data_64_T1/qkv_step_15200.npz
    #   examples/minigpt/qkv_data_stack_bucket6_64k/qkv_step_1800.npz
    #   examples/minigpt/qkv_data_1B_science_64k_cudnn/qkv_step_180000.npz
    #   examples/minigpt/qkv_data_1B_science_64k_hard_cudnn/qkv_step_54.npz
    #   examples/minigpt/qkv_data_96M2B_stack5678_64k_cudnn/qkv_step_15000.npz
    #   examples/minigpt/qkv_data_96M2B_stack5678_64k_cudnn_head128/qkv_step_7000.npz
    data = jnp.load("examples/minigpt/qkv_data_1B_stack_64k_cudnn/qkv_step_180000.npz")
    ks, vs, qs = data["k"], data["v"], data["q"] # (B, N, H, D)

    if only_first_batch:
        print(f"WARNING: Only using first batch of data, original shape: {ks.shape}, {vs.shape}, {qs.shape}")
        ks, vs, qs = ks[:1, ...], vs[:1, ...], qs[:1, ...]

    if subsample_heads is not None:
        OFFSET = 3
        print(f"WARNING: Only using one in {subsample_heads} heads, original shape: {ks.shape}, {vs.shape}, {qs.shape}")
        head_pick = (slice(None), slice(None), slice(OFFSET, None, subsample_heads), slice(None))
        ks, vs, qs = ks[head_pick], vs[head_pick], qs[head_pick]

    # Manual switch: pair up adjacent heads into double-width ones and keep the first 8.
    make_8_double_width_heads = False
    if make_8_double_width_heads:
        print(f"WARNING: Making 8 double-width heads, original shape: {ks.shape}, {vs.shape}, {qs.shape}")
        ks = einshape("bn(ht)d->bnh(td)", ks, t=2)[:,:,:8,:]
        vs = einshape("bn(ht)v->bnh(tv)", vs, t=2)[:,:,:8,:]
        qs = einshape("bn(ht)d->bnh(td)", qs, t=2)[:,:,:8,:]

    # Concatenate all batches along the token axis: (B, N, H, D) -> (1, B*N, H, D).
    qs, ks, vs = (einshape("bnhd->1(bn)hd", arr) for arr in (qs, ks, vs))
    _, _, _, _ = qs.shape # rank check only
    _, N, _, _ = vs.shape
    assert N % n == 0, f"N ({N}) must be divisible by n ({n})"
    M = N // n
    # Re-split into M sequences of n tokens each: (1, M*n, H, D) -> (M, n, H, D).
    qs = einshape("b(mn)hd->(bm)nhd", qs, m=M, n=n)
    ks = einshape("b(mn)hd->(bm)nhd", ks, m=M, n=n)
    vs = einshape("b(mn)hv->(bm)nhv", vs, m=M, n=n)

    multi_attention = MultiLevelAttention(
        clustering=MultiLevelClustering(
            Qs=(64,),
            Ks=(64,),
            coupled_clustering=False,
            inner_iters=5,
            outer_iters=5,
            max_cluster_scale=4.0,
            compute_indices=True,
        ),
    )

    def baseline_approx(qs, ks, vs, qmeta, kmeta):
        # Queries and keys are each divided by D**0.25.
        scale = math.sqrt(math.sqrt(qs.shape[-1]))
        lse, out = multi_attention.attend(
            queries=qs / scale,
            qlabels=qmeta.lab,
            keys=ks / scale,
            klabels=kmeta.lab,
            values=vs,
            qfwd_indices=qmeta.fwd,
            qbwd_indices=qmeta.bwd,
            kfwd_indices=kmeta.fwd,
            kbwd_indices=kmeta.bwd,
        )
        return lse, out

    def over_batch_and_heads(fn):
        # Map over axis 0, then over axis 1 of what remains (results placed on the same axes).
        return vmap(vmap(fn, in_axes=1, out_axes=1))

    sm_scale = 1.0 / math.sqrt(qs.shape[-1])
    baseline_approx_mha = over_batch_and_heads(baseline_approx)
    pallas_flash_mha = over_batch_and_heads(partial(pallas_flash, sm_scale=sm_scale))
    # The window reaches back N tokens, where N is the total token count before re-chunking.
    pallas_flash_causal = partial(pallas_flash, sm_scale=sm_scale, window=(N, 0))
    pallas_flash_causal_mha = over_batch_and_heads(pallas_flash_causal)

    qmeta, kmeta = None, None
    return (
        pallas_flash_mha,
        pallas_flash_causal,
        pallas_flash_causal_mha,
        baseline_approx,
        baseline_approx_mha,
        qs.astype(dtype),
        ks.astype(dtype),
        vs.astype(dtype),
        qmeta,
        kmeta,
    )

def main():
    """Evaluate and benchmark approximate causal attention against exact attention.

    Steps, in order:

    1. Load q/k/v with ``get_data`` and split them into ``split_fac`` chunks
       along the head axis (axis ``-2``); the first chunk becomes the working
       ``qs``/``ks``/``vs``.
    2. Compute the exact causal output of every chunk with the causal Pallas
       flash kernel on float32 inputs; these serve as ground truth.
    3. Print list metrics for the reference and block-diagonal causal variants
       over all chunks.
    4. Compare outputs and gradients of ``ref_causal_attn`` with the exact
       kernel on a single (batch, head) slice.
    5. Benchmark forward and/or gradient passes of each variant on the first
       chunk inside an NVTX range.

    Gradients are taken of ``sum(v_out * dv_out) + sum(lse * dlse)`` with
    fixed random ``dv_out``/``dlse``, i.e. they are VJPs with those cotangents.

    The function returns normally after the last benchmark instead of calling
    the interactive ``exit()`` helper, so it can also be invoked from other
    code without terminating the interpreter.
    """
    from jax.sharding import PartitionSpec as P

    MULTI_DEVICE = True
    if MULTI_DEVICE:
        # Mesh of shape (1, 8) over ('batch', 'heads'); requires 8 devices.
        mesh = jax.make_mesh((1, 8), ('batch', 'heads'), axis_types=(jax.sharding.AxisType.Auto,) * 2)
        jax.set_mesh(mesh)
        # Inputs are (batch, seq, heads, feature); outputs are (lse, v_out).
        shard_mha = partial(
            jax.shard_map,
            in_specs=P('batch', None, 'heads', None),
            out_specs=(P('batch', None, 'heads'), P('batch', None, 'heads', None)),
            check_vma=False,
        )
    else:
        def shard_mha(fn):
            # Single-device runs use the callables unchanged.
            return fn

    dtype = jnp.bfloat16
    requested_n = 2**16
    ONLY_FIRST_BATCH = False
    SUBSAMPLE_HEADS = None
    # Set to True to keep only the first head chunk (speed-only benchmarking).
    BENCHMARK_FIRST_CHUNK_ONLY = False

    # get_data also returns non-causal/baseline callables and cluster metadata;
    # none of them are needed by this driver.
    (_, pallas_flash_causal, pallas_flash_causal_mha, _, _,
     qs, ks, vs, _, _) = get_data(dtype=dtype, n=requested_n, only_first_batch=ONLY_FIRST_BATCH, subsample_heads=SUBSAMPLE_HEADS)

    split_fac = 5
    qs_list = jnp.split(qs, split_fac, axis=-2)
    ks_list = jnp.split(ks, split_fac, axis=-2)
    vs_list = jnp.split(vs, split_fac, axis=-2)
    if BENCHMARK_FIRST_CHUNK_ONLY:
        print("WARNING: Discarding most of qkv for benchmarking speed only!")
        qs_list, ks_list, vs_list = qs_list[:1], ks_list[:1], vs_list[:1]
    qs, ks, vs = qs_list[0], ks_list[0], vs_list[0]

    def as_f32(*arrays):
        """Cast every array to float32 (the exact kernels are run in float32)."""
        return tuple(a.astype(jnp.float32) for a in arrays)

    pallas_flash_causal_mha = shard_mha(pallas_flash_causal_mha)
    exact_causal_list = [
        pallas_flash_causal_mha(*as_f32(q, k, v))
        for q, k, v in zip(qs_list, ks_list, vs_list)
    ]

    B, N, H, D = qs.shape
    V = vs.shape[-1]
    print(f"B: {B}, N: {N}, H: {H}, D: {D}, V: {V}")

    BATCH_ELEM = 0
    # NOTE(review): qs is a single head chunk here, so HEAD_ELEM must be < H.
    # JAX clamps out-of-range indices instead of raising, so a too-large value
    # would silently select the last head - confirm against the data.
    HEAD_ELEM = 23

    def one_head(array):
        """Select one (batch, head) slice, keeping the sequence axis."""
        return array[BATCH_ELEM, :, HEAD_ELEM, ...]

    oh_qs, oh_ks, oh_vs = one_head(qs), one_head(ks), one_head(vs)
    print(f"oh shapes: {oh_qs.shape}, {oh_ks.shape}, {oh_vs.shape}")

    # Fixed random cotangents for the (lse, v_out) outputs.
    gradient_scale = 1e0
    dlse = jax.random.normal(jax.random.PRNGKey(0), shape=(B, N, H), dtype=jnp.float32) * gradient_scale
    dv_out = jax.random.normal(jax.random.PRNGKey(1), shape=(B, N, H, V), dtype=dtype) * gradient_scale
    oh_dlse = one_head(dlse)
    oh_dv_out = one_head(dv_out)

    def cotangent_dot(outputs):
        """Contract (lse, v_out) with (dlse, dv_out) into the scalar that is differentiated."""
        lse, v_out = outputs
        return jnp.sum(v_out * dv_out) + jnp.sum(lse * dlse)

    ######### Multi-head callables and their gradients #########

    sharded_cudnn_flash_mha = shard_mha(cudnn_flash_mha)

    @partial(jax.grad, argnums=(0, 1, 2))
    def cudnn_flash_mha_grad(qs, ks, vs):
        return cotangent_dot(sharded_cudnn_flash_mha(qs, ks, vs))

    ref_causal_attn_mha = shard_mha(vmap(vmap(ref_causal_attn, in_axes=1, out_axes=1)))

    @partial(jax.grad, argnums=(0, 1, 2))
    def ref_causal_attn_mha_grad(qs, ks, vs):
        return cotangent_dot(ref_causal_attn_mha(qs, ks, vs))

    print("Evaluating ref_causal_attn_mha...")
    print_list_metrics(ref_causal_attn_mha, qs_list, ks_list, vs_list, exact_causal_list)

    @partial(jax.grad, argnums=(0, 1, 2))
    def pallas_flash_causal_mha_grad(qs, ks, vs):
        return cotangent_dot(pallas_flash_causal_mha(qs, ks, vs))

    block_diagonal_causal_mha = shard_mha(vmap(vmap(block_diagonal_causal_attn, in_axes=1, out_axes=1)))

    @partial(jax.grad, argnums=(0, 1, 2))
    def block_diagonal_causal_mha_grad(qs, ks, vs):
        return cotangent_dot(block_diagonal_causal_mha(qs, ks, vs))

    print("Evaluating block_diagonal_causal_mha ...")
    print_list_metrics(block_diagonal_causal_mha, qs_list, ks_list, vs_list, exact_causal_list)

    # Only the gradient of the MoBA variant is benchmarked below.
    moba_causal_mha = shard_mha(vmap(vmap(moba_causal_attn, in_axes=1, out_axes=1)))

    @partial(jax.grad, argnums=(0, 1, 2))
    def moba_causal_mha_grad(qs, ks, vs):
        return cotangent_dot(moba_causal_mha(qs, ks, vs))

    ######### One Head Causal Gradients #########

    @partial(jax.grad, argnums=(0, 1, 2))
    def pallas_flash_causal_grad(oh_qs, oh_ks, oh_vs):
        lse, v_out = pallas_flash_causal(oh_qs.astype(jnp.float32), oh_ks.astype(jnp.float32), oh_vs.astype(jnp.float32))
        return jnp.sum(v_out * oh_dv_out.astype(jnp.float32)) + jnp.sum(lse * oh_dlse.astype(jnp.float32))

    @partial(jax.grad, argnums=(0, 1, 2))
    def ref_causal_attn_grad(oh_qs, oh_ks, oh_vs):
        lse, v_out = ref_causal_attn(oh_qs, oh_ks, oh_vs)
        return jnp.sum(v_out * oh_dv_out) + jnp.sum(lse * oh_dlse)

    ######### One head: reference vs exact #########

    oh_f32 = as_f32(oh_qs, oh_ks, oh_vs)
    exact_oh_causal_out = pallas_flash_causal(*oh_f32)
    ref_oh_causal_out = ref_causal_attn(oh_qs, oh_ks, oh_vs)
    print("Ref causal vs exact causal one head output correlations:")
    print(tree_axis_corr(exact_oh_causal_out, ref_oh_causal_out))
    print("... output errors:")
    print(tree_sq_err(exact_oh_causal_out, ref_oh_causal_out))
    exact_oh_causal_grad = pallas_flash_causal_grad(*oh_f32)
    ref_oh_causal_grad = ref_causal_attn_grad(oh_qs, oh_ks, oh_vs)
    print("Ref causal vs exact causal one head grad correlations:")
    print(tree_axis_corr(exact_oh_causal_grad, ref_oh_causal_grad))

    ######### Multi-head benchmarks on the first head chunk #########

    # qs/ks/vs are the first chunk, so its exact output was already computed.
    exact_causal_out = exact_causal_list[0]
    exact_causal_grad = pallas_flash_causal_mha_grad(*as_f32(qs, ks, vs))

    with nvtx.annotate("approximate_grad"):
        print("Benchmarking mha ref causal Attn Attention...")
        benchmark_fn(ref_causal_attn_mha, (qs, ks, vs), exact_causal_out, warmup_iters=10, iters=100)
        print("Benchmarking mha ref causal attn grad...")
        benchmark_fn(ref_causal_attn_mha_grad, (qs, ks, vs), exact_causal_grad, warmup_iters=10, iters=100)
        print("Benchmarking mha moba causal attn grad...")
        benchmark_fn(moba_causal_mha_grad, (qs, ks, vs), exact_causal_grad, warmup_iters=10, iters=100)
        print("Benchmarking mha causal block diagonal Attention...")
        benchmark_fn(block_diagonal_causal_mha, (qs, ks, vs), exact_causal_out, warmup_iters=10, iters=100)
        print("Benchmarking mha causal block diagonal grad...")
        benchmark_fn(block_diagonal_causal_mha_grad, (qs, ks, vs), exact_causal_grad, warmup_iters=10, iters=100)
        print("Benchmarking mha causal Pallas Flash Attention grad...")
        benchmark_fn(pallas_flash_causal_mha_grad, (qs, ks, vs), exact_causal_grad, warmup_iters=10, iters=100)
        print("Benchmarking mha causal CuDNN Flash Attention grad...")
        benchmark_fn(cudnn_flash_mha_grad, (qs, ks, vs), exact_causal_grad, warmup_iters=10, iters=100)



# Run the benchmark driver only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

    

