import jax
from jax import numpy as jnp, lax, vmap
from jax.numpy import einsum
from einshape import jax_einshape as einshape
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
from jax.scipy.special import logsumexp
import dataclasses
from time import time
import math
import numpy as np
from jax import Array as Array
from .single_level_attention import MultiLevelAttention, MultiLevelClustering
from .flash_attention import attn as pallas_flash
from .pallas_cluster import do_clustering as pallas_do_clustering
from flax import struct

def cudnn_flash_mha(q, k, v):
    """Run cuDNN-backed dot-product attention and return ``(lse, out)``.

    The returned ``lse`` is an all-zero float32 placeholder of shape
    ``q.shape[:-1]``; it is NOT derived from the attention scores and only
    keeps the return signature aligned with the other attention routines.
    """
    attn_out = jax.nn.dot_product_attention(q, k, v, implementation="cudnn")
    placeholder_lse = jnp.zeros(q.shape[:-1], dtype=jnp.float32)
    return placeholder_lse, attn_out

# Large finite negative value written into attention-score slots that belong
# to padding (index < 0), so that those slots receive ~zero softmax weight.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)
# Trace-time switch read by `initial_kernel`: when True the kernel also
# accumulates the per-cluster value/key second-moment ("vk") statistics.
DIPOLE = True

@struct.dataclass
class ClusterData:
    # Cluster bookkeeping consumed by the attention routines in this file.
    # Shape tags (from the original annotations): K = clusters, u = slots per
    # cluster, N = rows of the clustered array.
    cnt: Array # K  -- used as a per-cluster divisor (presumably the member count; confirm with producer)
    lab: Array # N  -- first (cluster) index used when mapping (K, u, ...) results back to rows
    fwd: Array # Ku -- indexed as (K, u): row index per cluster slot; negative entries are treated as padding
    bwd: Array # N  -- second (slot) index used when mapping (K, u, ...) results back to rows

def ref_initial_full(q, k, v, kmeta):
    """Reference per-cluster attention with attention-weighted second moments.

    For every (query, cluster) pair a softmax over the cluster's slots is
    computed, with padded slots (``kmeta.fwd < 0``) masked out. Unlike
    ``ref_initial``, the means and the value/key second moment are weighted
    by the attention weights averaged over the query axis.

    Returns ``(lse, k_out, v_out, vk_out, v_bar, k_bar)`` with shapes
    (Q, K), (Q, K, D), (Q, K, V), (K, V, D), (K, V), (K, D).
    """
    slots = kmeta.fwd
    keys = k[slots, :] # (K, u, D)
    vals = v[slots, :] # (K, u, V)
    scale = 1.0 / math.sqrt(q.shape[-1])

    scores = einsum("qd,kud->qku", q, keys) * scale # (Q, K, u)
    scores = jnp.where(slots[None, :, :] >= 0, scores, DEFAULT_MASK_VALUE)
    probs = jax.nn.softmax(scores, axis=-1) # (Q, K, u)
    lse = logsumexp(scores, axis=-1) # (Q, K)

    v_out = einsum("qku,kuv->qkv", probs, vals) # (Q, K, V)
    k_out = einsum("qku,kud->qkd", probs, keys) # (Q, K, D)

    # Query-averaged statistics.
    v_bar = jnp.mean(v_out, axis=0) # (K, V)
    k_bar = jnp.mean(k_out, axis=0) # (K, D)
    mean_probs = jnp.mean(probs, axis=0) # (K, u)
    second_moment = einsum("ku,kuv,kud->kvd", mean_probs, vals, keys) # (K, V, D)
    vk_out = second_moment - einsum("kv,kd->kvd", v_bar, k_bar) # (K, V, D)
    return lse, k_out, v_out, vk_out, v_bar, k_bar


def ref_initial(q, k, v, kmeta): # (Q, D), (S, D), (S, V), (S,)
    """Reference implementation of the first (per-cluster) attention stage.

    For each (query, cluster) pair, softmax attention is computed over the
    cluster's slots; slots with ``kmeta.fwd < 0`` are padding and are masked.
    In addition, the unweighted per-cluster mean of values/keys and the
    centred value/key second moment over the valid slots are returned.

    Returns ``(lse, k_out, v_out, vk_out, v_bar, k_bar)`` with shapes
    (Q, K), (Q, K, D), (Q, K, V), (K, V, D), (K, V), (K, D).
    """
    slots = kmeta.fwd
    valid = slots[:, :] >= 0 # (K, u)
    valid3 = valid[..., None] # (K, u, 1)
    keys = k[slots, :] # (K, u, D)
    vals = v[slots, :] # (K, u, V)
    scale = 1.0 / math.sqrt(q.shape[-1])

    # Attention restricted to each cluster.
    scores = einsum("qd,kud->qku", q, keys) * scale # (Q, K, u)
    scores = jnp.where(slots[None, :, :] >= 0, scores, DEFAULT_MASK_VALUE)
    probs = jax.nn.softmax(scores, axis=-1) # (Q, K, u)
    v_out = einsum("qku,kuv->qkv", probs, vals) # (Q, K, V)
    k_out = einsum("qku,kud->qkd", probs, keys) # (Q, K, D)

    # Unweighted first and second moments over the valid slots.
    v_bar = jnp.mean(vals, axis=1, where=valid3) # (K, V)
    k_bar = jnp.mean(keys, axis=1, where=valid3) # (K, D)
    slot_count = jnp.sum(valid, axis=1)[:, None, None] # (K, 1, 1)
    raw_moment = einsum("kuv,kud->kvd", jnp.where(valid3, vals, 0.0), keys) / slot_count
    vk_out = raw_moment - einsum("kv,kd->kvd", v_bar, k_bar) # (K, V, D)

    lse = logsumexp(scores, axis=-1) # (Q, K)
    return lse, k_out, v_out, vk_out, v_bar, k_bar

def ref_initial_bwd(q, k, v, kmeta, lse, k_out, v_out, v_bar, k_bar, dlse, dk_out, dv_out, dvk_out):
    """Reference backward pass matching ``ref_initial``.

    Inputs (shapes per the forward pass): q (Q, D), k (S, D), v (S, V);
    saved lse (Q, K), k_out (Q, K, D), v_out (Q, K, V), v_bar (K, V),
    k_bar (K, D); cotangents dlse (Q, K), dk_out (Q, K, D), dv_out (Q, K, V),
    dvk_out (K, V, D).

    Returns ``(dq, dk, dv)``; dk/dv are gathered back from the (K, u, ...)
    cluster layout through ``kmeta.lab`` / ``kmeta.bwd``.
    """
    slots = kmeta.fwd
    valid3 = (slots[:, :] >= 0)[..., None] # (K, u, 1)
    scale = 1.0 / math.sqrt(q.shape[-1])

    # Pre-processing: the second-moment cotangent is divided by the cluster
    # count, and the softmax "row-sum" correction is folded into dlse.
    dvk_scaled = dvk_out / kmeta.cnt[:, None, None] # (K, V, D)
    dlse_eff = dlse - einsum("qkd,qkd->qk", dk_out, k_out) - einsum("qkv,qkv->qk", dv_out, v_out) # (Q, K)

    # Gather the cluster layout and zero the padded slots.
    keys = jnp.where(valid3, k[slots, :], 0.0) # (K, u, D)
    vals = jnp.where(valid3, v[slots, :], 0.0) # (K, u, V)

    # Attention weights are rebuilt from the saved log-sum-exp.
    scores = einsum("qd,kud->qku", q, keys) * scale # (Q, K, u)
    scores = jnp.where(slots[None, :, :] >= 0, scores, DEFAULT_MASK_VALUE)
    probs = jnp.exp(scores - lse[:, :, None]) # (Q, K, u)

    # Gradients flowing through the weighted sums.
    dv_direct = einsum("qku,qkv->kuv", probs, dv_out) # (K, u, V)
    dk_direct = einsum("qku,qkd->kud", probs, dk_out) # (K, u, D)

    # Gradients flowing through the scores.
    dscores = einsum("qkv,kuv->qku", dv_out, vals) + einsum("qkd,kud->qku", dk_out, keys)
    dscores = (dscores + dlse_eff[:, :, None]) * probs # (Q, K, u)
    dk_scores = einsum("qku,qd->kud", dscores, q) * scale # (K, u, D)
    dq = einsum("qku,kud->qd", dscores, keys) * scale # (Q, D)

    # Gradients flowing through the centred second moment.
    k_centered = jnp.where(valid3, keys - k_bar[:, None, :], 0.) # (K, u, D)
    dv_moment = einsum("kvd,kud->kuv", dvk_scaled, k_centered) # (K, u, V)
    v_centered = jnp.where(valid3, vals - v_bar[:, None, :], 0.) # (K, u, V)
    dk_moment = einsum("kvd,kuv->kud", dvk_scaled, v_centered) # (K, u, D)

    dv_slots = dv_direct + dv_moment
    dk_slots = dk_direct + dk_scores + dk_moment
    dk = dk_slots[kmeta.lab, kmeta.bwd, :] # (N, D)
    dv = dv_slots[kmeta.lab, kmeta.bwd, :] # (N, V)
    return dq, dk, dv

def initial_bwd_kernel(q_ref, k_ref, v_ref, kfwd_ref, lse_ref, k_bar_ref, v_bar_ref, dlse_ref, dok_ref, dov_ref, dovk_ref, dq_ref, dk_ref, dv_ref, *, block_u: int):
    """Pallas kernel body for the backward pass of the first attention stage.

    As launched by `initial_bwd`, one program instance handles one cluster:
    it walks that cluster's slots in chunks of `block_u`, scatters the
    key/value gradients of each chunk into `dk_ref` / `dv_ref` at the rows
    named by `kfwd_ref`, and accumulates this cluster's contribution to the
    query gradient, written to `dq_ref` at the end.

    `dlse_ref` is expected to already contain the softmax row-sum correction
    and `dovk_ref` to already be divided by the cluster count (both are done
    by `initial_bwd` before the call).
    """
    Q, D = q_ref.shape # (Q, D)
    N, V = v_ref.shape # (N, V); N itself is not used below
    U, = kfwd_ref.shape # (U,)
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])

    dq = jnp.zeros((Q, D), dtype=jnp.float32) # (Q, D) float32 accumulator

    q = q_ref[...] # (Q, D)
    lse = lse_ref[...] # (Q,)
    lse = math.log2(math.e) * lse # Convert to log2 scale so exp2 can be used below
    dlse = dlse_ref[...] # (Q,)
    v_bar = v_bar_ref[...] # (V,)
    k_bar = k_bar_ref[...] # (D,)
    dov = dov_ref[...] # (Q, V)
    dok = dok_ref[...] # (Q, D)
    dovk = dovk_ref[...] # (V, D)

    def body(start_u, dq_prev):
        # Process slots [start_u*block_u, (start_u+1)*block_u) of this cluster.
        curr_u_slice = pl.dslice(start_u*block_u, block_u)
        kfwd = kfwd_ref[curr_u_slice] # (u,)
        mask = kfwd >= 0 # (u,) negative indices mark padding
        # Gather the rows of this chunk; padded slots load zeros.
        k = pl.load(k_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, D)
        v = pl.load(v_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, V)
        qk_scale = math.log2(math.e) * sm_scale # scalar: softmax scale in log2 units
        qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale # (Q, u)
        qk = jnp.where(mask[None, :], qk, DEFAULT_MASK_VALUE)
        # Attention weights rebuilt from the saved log-sum-exp.
        s = jnp.exp2(qk - lse[:, None]) # (Q, u)

        # Gradients through the weighted sums.
        dv0 = pl.dot(s.astype(dov.dtype), dov, trans_a=True, allow_tf32=True) # (u, V)
        dk0 = pl.dot(s.astype(dok.dtype), dok, trans_a=True, allow_tf32=True) # (u, D)

        gv = pl.dot(dov.astype(v.dtype), v, trans_b=True, allow_tf32=True) # (Q, u)
        gk = pl.dot(dok.astype(k.dtype), k, trans_b=True, allow_tf32=True) # (Q, u)

        # Gradient with respect to the (scaled) scores.
        ds = gv + gk + dlse[:, None]
        dqk = ds * s # (Q, u)

        dk2 = pl.dot(dqk.astype(q.dtype), q, trans_a=True, allow_tf32=True) * sm_scale # (u, D)
        dq_curr = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale # (Q, D)
        dq_next = dq_prev + dq_curr # (Q, D)

        # Gradients through the centred value/key second moment.
        kres = jnp.where(mask[:, None], k - k_bar[None, :], 0.0) # (u, D)
        dv1 = pl.dot(kres, dovk, trans_b=True, allow_tf32=True) # (u, V)
        vres = jnp.where(mask[:, None], v - v_bar[None, :], 0.0) # (u, V)
        dk1 = pl.dot(vres, dovk, allow_tf32=True) # (u, D)

        # Scatter this chunk's gradients to the rows it was gathered from.
        dv = dv0 + dv1 # (u, V)
        pl.store(dv_ref, (kfwd, slice(None)), dv.astype(dv_ref.dtype), mask=mask[:,None])
        dk = dk0 + dk2 + dk1 # (u, D)
        pl.store(dk_ref, (kfwd, slice(None)), dk.astype(dk_ref.dtype), mask=mask[:,None])
        return dq_next

    lower_bound = 0
    upper_bound = pl.cdiv(U, block_u)
    assert U % block_u == 0, f"U ({U}) must be divisible by block_u ({block_u})"
    dq = lax.fori_loop(lower_bound, upper_bound, body, dq) # (Q, D)
    dq_ref[...] = dq.astype(dq_ref.dtype) # (Q, D)

def initial_bwd(q, k, v, kmeta, lse, k_out, v_out, v_bar, k_bar, dlse, dk_out, dv_out, dvk_out):
    """Launch ``initial_bwd_kernel`` over all clusters and reduce the result.

    The grid has one entry per cluster. Each entry produces its own partial
    query gradient (stacked along a leading K axis and summed here), while
    key/value gradients are scattered by the kernel directly into arrays
    shaped like ``k`` and ``v``.

    Returns ``(dq, dk, dv)`` with shapes (Q, D), (S, D), (S, V).
    """
    Q, D = q.shape # (Q, D)
    _, V = v.shape # (S, V)
    K, U = kmeta.fwd.shape # (K, u)

    # Pre-processing done outside the kernel.
    dvk_scaled = dvk_out / kmeta.cnt[:, None, None] # (K, V, D)
    dlse_eff = dlse - einsum("qkd,qkd->qk", dk_out, k_out) - einsum("qkv,qkv->qk", dv_out, v_out) # (Q, K)

    input_specs = [
        pl.BlockSpec(q.shape, lambda i: (0,0)), # q_ref: whole array
        pl.BlockSpec(k.shape, lambda i: (0,0)), # k_ref: whole array
        pl.BlockSpec(v.shape, lambda i: (0,0)), # v_ref: whole array
        pl.BlockSpec((None, U), lambda i: (i, 0)), # kfwd_ref
        pl.BlockSpec((Q, None), lambda i: (0, i)), # lse_ref
        pl.BlockSpec((None, D), lambda i: (i, 0)), # k_bar_ref
        pl.BlockSpec((None, V), lambda i: (i, 0)), # v_bar_ref
        pl.BlockSpec((Q, None), lambda i: (0, i)), # dlse_ref
        pl.BlockSpec((Q, None, D), lambda i: (0, i, 0)), # dok_ref
        pl.BlockSpec((Q, None, V), lambda i: (0, i, 0)), # dov_ref
        pl.BlockSpec((None, V, D), lambda i: (i, 0, 0)), # dovk_ref
    ]
    output_specs = [
        pl.BlockSpec((None,) + q.shape, lambda i: (i, 0, 0)), # dq_ref: one slab per cluster
        pl.BlockSpec(k.shape, lambda i: (0, 0)), # dk_ref: whole array
        pl.BlockSpec(v.shape, lambda i: (0, 0)), # dv_ref: whole array
    ]
    launch = pl.pallas_call(
        partial(initial_bwd_kernel, block_u=16),
        out_shape=(jax.ShapeDtypeStruct((K,)+q.shape, q.dtype), k, v),
        grid=(K,),
        in_specs=input_specs,
        out_specs=output_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="initial_backward",
    )
    dq_per_cluster, dk, dv = launch(
        q, k, v, kmeta.fwd, lse, k_bar, v_bar, dlse_eff, dk_out, dv_out, dvk_scaled
    )
    dq = jnp.sum(dq_per_cluster, axis=0) # (Q, D)
    return dq, dk, dv # (Q, D), (S, D), (S, V)

        




        



    
    


# Can maybe increase block size by changing to half precision intermediates, or separating out vk computation
def initial_kernel(q_ref, k_ref, v_ref, kfwd_ref, lse_ref, ok_ref, ov_ref, ovk_ref, ok_bar_ref, ov_bar_ref, *, block_k: int):
    """Pallas kernel body for the first (per-cluster) attention stage.

    As launched by `initial`, one program instance handles one cluster. The
    cluster's slots are visited in chunks of `block_k` with an online
    (running max / running sum) softmax in base 2, producing for every query
    the attention-weighted value (`ov_ref`) and key (`ok_ref`) and the
    natural-log log-sum-exp (`lse_ref`).

    When the module-level `DIPOLE` flag is set, the kernel additionally
    accumulates unweighted sums over the valid slots and writes the slot
    means (`ov_bar_ref`, `ok_bar_ref`) and the centred value/key second
    moment (`ovk_ref`). When it is not set, these three outputs are written
    as zeros.

    Slots whose index in `kfwd_ref` is negative are padding: they are loaded
    as zeros and their scores are replaced by `DEFAULT_MASK_VALUE`.
    """
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])
    Q, D = q_ref.shape
    _, V = v_ref.shape
    (U,) = kfwd_ref.shape
    assert U % block_k == 0, f"U ({U}) must be divisible by block_k ({block_k})"
    q = q_ref[...] # (Q, D)
    # float32 accumulators carried through the loop.
    ov = jnp.zeros((Q, V), dtype=jnp.float32) # (Q, V)
    ok = jnp.zeros((Q, D), dtype=jnp.float32) # (Q, D)
    ovk = jnp.zeros((V, D), dtype=jnp.float32) # (V, D)
    v_vk = jnp.zeros((V,), dtype=jnp.float32) # (V,)
    k_vk = jnp.zeros((D,), dtype=jnp.float32) # (D,)
    m = jnp.zeros(Q, dtype=jnp.float32) - jnp.inf # (Q,) running max (log2 units)
    l = jnp.zeros(Q, dtype=jnp.float32) # (Q,) running softmax denominator
    lvk = jnp.zeros((), dtype=jnp.float32) # () number of valid slots seen
    def body(start_k, carry):
        ov_prev, ok_prev, ovk_prev, v_vk_prev, k_vk_prev, m_prev, l_prev, lvk_prev = carry
        curr_k_slice = pl.dslice(start_k*block_k, block_k)
        # NOTE THIS IS INEFFICIENT - SHOULD JUST LOAD ALL AT THE BEGINNING
        kfwd = kfwd_ref[curr_k_slice] # (u,)
        mask = kfwd >= 0
        k = pl.load(k_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, D)
        v = pl.load(v_ref, (kfwd, slice(None)), mask=mask[:,None], other=0.0) # (u, V)
        qk_scale = math.log2(math.e) * sm_scale # scalar: softmax scale in log2 units
        qk = pl.dot(q, k, trans_b=True, allow_tf32=True) * qk_scale # (Q, u)
        qk = jnp.where(mask[None, :], qk, DEFAULT_MASK_VALUE)
        # Online softmax: rescale the previous accumulators to the new max.
        m_curr = jnp.max(qk, axis=-1) # (Q,)
        m_next = jnp.maximum(m_prev, m_curr) # (Q,)
        correction = jnp.exp2(m_prev - m_next) # (Q,)
        l_prev_corr = l_prev * correction # (Q,)
        s_curr = jnp.exp2(qk - m_next[:, None]) # (Q, u)
        l_curr = jnp.sum(s_curr, axis=-1) # (Q,)
        l_next = l_prev_corr + l_curr
        ov_prev_corr = ov_prev * correction[:, None] # (Q, V)
        ov_curr = pl.dot(s_curr.astype(v.dtype), v, allow_tf32=True) # (Q, V)
        ov_next = ov_prev_corr + ov_curr # (Q, V)
        ok_prev_corr = ok_prev * correction[:, None]
        ok_curr = pl.dot(s_curr.astype(k.dtype), k, allow_tf32=True)
        ok_next = ok_prev_corr + ok_curr
        if DIPOLE:
            # Padded slots were loaded as zeros, so they add nothing here.
            ovk_curr = pl.dot(v, k, trans_a=True, allow_tf32=True)
            ovk_next = ovk_prev + ovk_curr # (V, D)
            v_vk_curr = jnp.sum(v, axis=0)
            v_vk_next = v_vk_prev + v_vk_curr # (V,)
            k_vk_curr = jnp.sum(k, axis=0)
            k_vk_next = k_vk_prev + k_vk_curr # (D,)
            lvk_next = lvk_prev + jnp.sum(jnp.where(mask, 1.0, 0.0))
        else:
            ovk_next, v_vk_next, k_vk_next, lvk_next = ovk_prev, v_vk_prev, k_vk_prev, lvk_prev # (V, D), (V,), (D,)
        return ov_next, ok_next, ovk_next, v_vk_next, k_vk_next, m_next, l_next, lvk_next

    lower_bound = 0
    upper_bound = pl.cdiv(U, block_k)
    ov, ok, ovk, v_vk, k_vk, m, l, lvk = lax.fori_loop(lower_bound, upper_bound, body, (ov, ok, ovk, v_vk, k_vk, m, l, lvk))
    ov = ov / l[:, None] # (Q, V)
    ok = ok / l[:, None] # (Q, D)
    ov_ref[...] = ov.astype(ov_ref.dtype) # (Q, V)
    ok_ref[...] = ok.astype(ok_ref.dtype) # (Q, D)
    # m is in log2 units; convert back to a natural-log log-sum-exp.
    lse_ref[...] = math.log(2)*m + jnp.log(l) # (Q,)
    if DIPOLE:
        ovk = ovk / lvk # (V, D)
        ov_bar = v_vk / lvk # (V,)
        ok_bar = k_vk / lvk # (D,)
        ovk = ovk - ov_bar[:,None] * ok_bar[None,:] # (V, D)
    else:
        # The statistics were never accumulated; emit the (zero) initial
        # values so the stores below are well defined. Previously these names
        # were unbound on this path and tracing failed with a NameError.
        ov_bar, ok_bar = v_vk, k_vk # (V,), (D,)
    ovk_ref[...] = ovk.astype(ovk_ref.dtype) # (V, D)
    ov_bar_ref[...] = ov_bar.astype(ov_bar_ref.dtype) # (V,)
    ok_bar_ref[...] = ok_bar.astype(ok_bar_ref.dtype) # (D,)

def initial(q, k, v, kmeta):
    """Launch ``initial_kernel`` with one grid entry per cluster.

    Returns ``(lse, k_out, v_out, vk, v_mean, k_mean)`` with shapes
    (Q, K), (Q, K, D), (Q, K, V), (K, V, D), (K, V), (K, D). Note that the
    kernel emits the key mean before the value mean; the two are swapped
    here so the return order matches ``ref_initial``.
    """
    K, U = kmeta.fwd.shape # (K, u)
    Q, D = q.shape # (Q, D)
    _, V = v.shape # (S, V)

    result_shapes = [
        jax.ShapeDtypeStruct((Q, K), jnp.float32), # lse
        jax.ShapeDtypeStruct((Q, K, D), k.dtype), # attention-weighted keys
        jax.ShapeDtypeStruct((Q, K, V), v.dtype), # attention-weighted values
        jax.ShapeDtypeStruct((K, V, D), v.dtype), # centred value/key second moment
        jax.ShapeDtypeStruct((K, D), k.dtype), # per-cluster key mean
        jax.ShapeDtypeStruct((K, V), v.dtype), # per-cluster value mean
    ]
    input_specs = [
        pl.BlockSpec(q.shape, lambda i: (0, 0)), # whole q
        pl.BlockSpec(k.shape, lambda i: (0, 0)), # whole k
        pl.BlockSpec(v.shape, lambda i: (0, 0)), # whole v
        pl.BlockSpec((None, U), lambda i: (i, 0)), # slot indices of cluster i
    ]
    output_specs = [
        pl.BlockSpec((Q, None), lambda i: (0, i)),
        pl.BlockSpec((Q, None, D), lambda i: (0, i, 0)),
        pl.BlockSpec((Q, None, V), lambda i: (0, i, 0)),
        pl.BlockSpec((None, V, D), lambda i: (i, 0, 0)),
        pl.BlockSpec((None, D), lambda i: (i, 0)),
        pl.BlockSpec((None, V), lambda i: (i, 0)),
    ]
    launch = pl.pallas_call(
        partial(initial_kernel, block_k=32),
        out_shape=result_shapes,
        grid=(K,),
        in_specs=input_specs,
        out_specs=output_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="initial_forward",
    )
    lse, k_out, v_out, vk_out, k_mean, v_mean = launch(q, k, v, kmeta.fwd)
    return lse, k_out, v_out, vk_out, v_mean, k_mean



def ref_final(q, q_bar, k, v, vk, m, qmeta):
    """Reference implementation of the second (per-query-cluster) stage.

    Args:
        q: query rows, (N, D); gathered into clusters via ``qmeta.fwd``.
        q_bar: per-cluster query centre, (Q, D).
        k, v: per-cluster summarised keys (Q, K, D) and values (Q, K, V).
        vk: per-cluster (Q, V, D) linear correction term, or ``None`` to
            skip the correction.
        m: additive score bias, (Q, K).
        qmeta: cluster metadata providing ``fwd``, ``lab`` and ``bwd``.

    Returns:
        ``(lse, v_out, v_out_mon)`` mapped back to row order through
        ``qmeta.lab`` / ``qmeta.bwd``. ``v_out_mon`` is the softmax-only
        output; ``v_out`` adds the ``vk`` correction. When ``vk`` is
        ``None`` the two are identical (previously this case raised
        ``UnboundLocalError`` because ``v_out`` was never assigned).
    """
    _q = q[qmeta.fwd, :] # (Q, n, D)
    qres = _q - q_bar[:, None, :] # (Q, n, D) residual w.r.t. the cluster centre
    sm_scale = 1.0 / math.sqrt(qres.shape[-1])
    attn_scores = einsum("qnd,qkd->qnk", qres, k) * sm_scale + m[:,None,:] # (Q, n, K)
    attn_weights = jax.nn.softmax(attn_scores, axis=-1) # (Q, n, K)
    v_out_mon = einsum("qnk,qkv->qnv", attn_weights, v) # (Q, n, V)
    if vk is not None:
        v_out = v_out_mon + einsum("qvd,qnd->qnv", vk, qres) * sm_scale # (Q, n, V)
    else:
        v_out = v_out_mon # no correction term requested
    lse = logsumexp(attn_scores, axis=-1) # (Q, n)
    lse = lse.astype(jnp.float32) # Ensure lse is in float32 for consistency
    return lse[qmeta.lab, qmeta.bwd], v_out[qmeta.lab, qmeta.bwd, :], v_out_mon[qmeta.lab, qmeta.bwd, :]

@jax.grad
def probe_func(attn_scores, v, dv_out):
    """Scalar probe whose gradient w.r.t. ``attn_scores`` is returned.

    The wrapped function computes ``sum(softmax(attn_scores) @ v * dv_out)``;
    because of the ``jax.grad`` decorator, calling ``probe_func`` yields the
    gradient with respect to the first argument. Shapes are printed each time
    the function body runs.
    """
    probs = jax.nn.softmax(attn_scores, axis=-1) # (Q, n, K)
    print("shapes:", probs.shape, v.shape, dv_out.shape)
    weighted = einsum("qnk,qkv->qnv", probs, v) # (Q, n, V)
    return jnp.sum(weighted * dv_out)

def test_grad(attn_scores, v, dv_out):
    attn_weights = jax.nn.softmax(attn_scores, axis=-1) # (Q, n, K)
    v_out = einsum("qnk,qkv->qnv", attn_weights, v) # (Q, n, V)
    correction = - einsum("qnv,qnv->qn", dv_out, v_out) # (Q, n)
    gv = einsum("qnv,qkv->qnk", dv_out, v) # (Q, n, K)
    gv = gv + correction[:, :, None]
    gv = gv * attn_weights # (Q, n, K)
    return gv

def ref_final_bwd(q, q_bar, k, v, vk, m, qmeta, lse, v_out, dlse, dv_out):
    """Reference backward pass matching ``ref_final``.

    Row-ordered inputs (``q``, ``lse``, ``v_out``, ``dlse``, ``dv_out``) are
    gathered into the (Q, n, ...) cluster layout with ``qmeta.fwd``; padded
    slots (index < 0) have their cotangents zeroed. ``v_out`` is expected to
    be the softmax-only output of the forward pass.

    Returns ``(dq, dq_bar, dk, dv, dvk, dm)`` with shapes
    (N, D), (Q, D), (Q, K, D), (Q, K, V), (Q, V, D), (Q, K).
    """
    slots = qmeta.fwd
    valid = slots >= 0 # (Q, n)
    valid3 = valid[:, :, None] # (Q, n, 1)

    # Gather into cluster layout; cotangents of padded slots are zeroed.
    q_c = q[slots, :] # (Q, n, D)
    v_out_c = v_out[slots, :] # (Q, n, V)
    dv_out_c = jnp.where(valid3, dv_out[slots, :], 0.) # (Q, n, V)
    dlse_c = jnp.where(valid, dlse[slots], 0.) # (Q, n)
    lse_c = lse[slots] # (Q, n)
    # Fold the softmax row-sum correction into the lse cotangent.
    dlse_eff = dlse_c - einsum("qnv,qnv->qn", dv_out_c, v_out_c) # (Q, n)

    # Rebuild the attention weights from the saved log-sum-exp.
    scale = 1.0 / math.sqrt(q_c.shape[-1])
    qres = q_c - q_bar[:, None, :] # (Q, n, D)
    scores = einsum("qnd,qkd->qnk", qres, k) * scale + m[:, None, :] # (Q, n, K)
    probs = jnp.exp(scores - lse_c[:, :, None]) # (Q, n, K)

    dv = einsum("qnk,qnv->qkv", probs, dv_out_c) # (Q, K, V)
    dscores = einsum("qnv,qkv->qnk", dv_out_c, v)
    dscores = dscores + dlse_eff[:, :, None]
    dscores = dscores * probs # (Q, n, K)
    dk = einsum("qnk,qnd->qkd", dscores, qres) * scale # (Q, K, D)
    dqres_attn = einsum("qnk,qkd->qnd", dscores, k) * scale # (Q, n, D)

    # Linear (vk) correction term.
    dqres_vk = einsum("qvd,qnv->qnd", vk, dv_out_c) * scale # (Q, n, D)
    dvk = einsum("qnd,qnv->qvd", qres, dv_out_c) * scale # (Q, V, D)

    dqres = jnp.where(valid3, dqres_attn + dqres_vk, 0.0) # (Q, n, D)
    # qres = q - q_bar, so the centre receives minus the summed residual gradient.
    dq_bar = - jnp.sum(dqres, axis=1) # (Q, D)
    dm = jnp.sum(dscores, axis=1) # (Q, K)

    dq = dqres[qmeta.lab, qmeta.bwd, :] # (N, D)
    return dq, dq_bar, dk, dv, dvk, dm

def final_bwd_kernel(
        q_ref, q_bar_ref, k_ref, v_ref, vk_ref, mu_ref, qfwd_ref,
        lse_ref, dlse_ref, do_ref,
        dq_ref, dq_bar_ref, dk_ref, dv_ref, dvk_ref, dmu_ref,
        *, block_n: int
    ):
    """Pallas kernel body for the backward pass of the second attention stage.

    As launched by `final_bwd`, one program instance handles one query
    cluster. It walks the cluster's slots in chunks of `block_n`, gathers the
    corresponding rows of `q_ref`, `lse_ref`, `dlse_ref` and `do_ref` through
    the indices in `qfwd_ref` (negative index = padding, loaded as zero),
    scatters the per-row query gradient into `dq_ref`, and accumulates the
    per-cluster gradients (`dq_bar`, `dk`, `dv`, `dvk`, `dmu`) in float32
    before writing them out once at the end.

    `dlse_ref` is expected to already include the softmax row-sum correction
    (computed by `final_bwd` before the call).
    """
    T, D = q_ref.shape # (T, D)
    K, V = v_ref.shape # (K, V)
    N, = qfwd_ref.shape # (N,)
    sm_scale = 1.0 / math.sqrt(q_bar_ref.shape[-1])

    q_bar = q_bar_ref[...].astype(q_ref.dtype) # (D,)
    k = k_ref[...] # (K, D)
    v = v_ref[...] # (K, V)
    vk = vk_ref[...] # (V, D)
    mu = mu_ref[...] # (K,)

    # float32 accumulators for the per-cluster gradients.
    dq_bar = jnp.zeros((D,), dtype=jnp.float32) # (D,)
    dk = jnp.zeros((K, D), dtype=jnp.float32) # (K, D)
    dv = jnp.zeros((K, V), dtype=jnp.float32) # (K, V)
    dvk = jnp.zeros((V, D), dtype=jnp.float32) # (V, D)
    dmu = jnp.zeros((K,), dtype=jnp.float32) # (K,)

    def body(start_n, carry):
        dq_bar_prev, dk_prev, dv_prev, dvk_prev, dmu_prev = carry # (D,), (K, D), (K, V), (V, D), (K,)

        # Gather the rows of this chunk; padded slots load zeros.
        curr_n_slice = pl.dslice(start_n*block_n, block_n)
        qfwd = qfwd_ref[curr_n_slice] # (n,)
        mask = qfwd >= 0
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, D)
        lse = pl.load(lse_ref, (qfwd,), mask=mask, other=0.0) # (n,)
        do = pl.load(do_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, V)
        dlse = pl.load(dlse_ref, (qfwd,), mask=mask, other=0.0) # (n,)
        #dlse = dlse_ref[...] # (n,)
        qres = q - q_bar[None, :] # (n, D)
        # Scores and lse are moved to log2 units so exp2 can be used.
        qk_scale = math.log2(math.e) * sm_scale
        mu_scale = math.log2(math.e)
        qk = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale # (n, K)
        # Attention weights rebuilt from the saved log-sum-exp.
        s = jnp.exp2(qk - lse[:,None] * mu_scale) # (n, K)

        dv_curr = pl.dot(s.astype(do.dtype), do, trans_a=True, allow_tf32=True) # (K, V)
        dv_next = dv_prev + dv_curr # (K, V)
        gv = pl.dot(do.astype(v.dtype), v, trans_b=True, allow_tf32=True) # (n, K)
        ds = gv + dlse[:, None]
        dqk = ds * s # (n, K) gradient w.r.t. the scores

        dk_curr = pl.dot(dqk.astype(qres.dtype), qres, trans_a=True, allow_tf32=True) * sm_scale # (K, D)
        dk_next = dk_prev + dk_curr # (K, D)
        dqres0 = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale # (n, D)

        # Contribution of the linear vk correction term.
        dqres1 = pl.dot(do, vk, allow_tf32=True) * sm_scale # (n, D)
        dvk_curr = pl.dot(do, qres, trans_a=True, allow_tf32=True) * sm_scale # (V, D)
        dvk_next = dvk_prev + dvk_curr # (V, D)

        dqres = dqres0 + dqres1 # (n, D)
        dq = dqres
        # qres = q - q_bar, so q_bar receives minus the summed residual gradient.
        dq_bar_curr = - jnp.sum(dqres, axis=0) # (D,)
        dq_bar_next = dq_bar_prev + dq_bar_curr # (D,)
        dmu_curr = jnp.sum(dqk, axis=0) # (K,)
        dmu_next = dmu_prev + dmu_curr # (K,)

        # Scatter the per-row query gradient; padded slots are not written.
        pl.store(dq_ref, (qfwd, slice(None)), dq.astype(dq_ref.dtype), mask=mask[:,None]) # (n, D)

        return dq_bar_next, dk_next, dv_next, dvk_next, dmu_next

    lower_bound = 0
    upper_bound = pl.cdiv(N, block_n)
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    dq_bar, dk, dv, dvk, dmu = lax.fori_loop(lower_bound, upper_bound, body, (dq_bar, dk, dv, dvk, dmu)) # (D,), (K, D), (K, V), (V, D), (K,)

    dq_bar_ref[...] = dq_bar.astype(dq_bar_ref.dtype) # (D,)
    dk_ref[...] = dk.astype(dk_ref.dtype) # (K, D)
    dv_ref[...] = dv.astype(dv_ref.dtype) # (K, V)
    dvk_ref[...] = dvk.astype(dvk_ref.dtype) # (V, D)
    dmu_ref[...] = dmu.astype(dmu_ref.dtype) # (K,)

def final_bwd(q, q_bar, k, v, vk, m, qmeta, lse, v_out, dlse, dv_out):
    """Launch ``final_bwd_kernel`` with one grid entry per query cluster.

    The softmax row-sum correction is folded into ``dlse`` here, outside the
    kernel. ``v_out`` is expected to be the softmax-only forward output.

    Returns ``(dq, dq_bar, dk, dv, dvk, dm)`` with shapes
    (T, D), (Q, D), (Q, K, D), (Q, K, V), (Q, V, D), (Q, K).
    """
    dlse_modified = dlse - jnp.sum(dv_out * v_out, axis=-1).astype(dlse.dtype) # (T,)
    T, D = q.shape # (T, D)
    assert dlse_modified.shape == q.shape[:-1], f"dlse_modified shape {dlse_modified.shape} does not match q shape {q.shape[:-1]}"
    Q, K, V = v.shape # (Q, K, V)
    _, N = qmeta.fwd.shape # (Q, N)
    block_n = 32
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"

    input_specs = [
        pl.BlockSpec(q.shape, lambda i: (0, 0)), # q_ref: whole array
        pl.BlockSpec((None, D), lambda i: (i, 0)), # q_bar_ref
        pl.BlockSpec((None, K, D), lambda i: (i, 0, 0)), # k_ref
        pl.BlockSpec((None, K, V), lambda i: (i, 0, 0)), # v_ref
        pl.BlockSpec((None, V, D), lambda i: (i, 0, 0)), # vk_ref
        pl.BlockSpec((None, K), lambda i: (i, 0)), # mu_ref
        pl.BlockSpec((None, N), lambda i: (i, 0)), # qfwd_ref
        pl.BlockSpec(lse.shape, lambda i: (0,)), # lse_ref: whole array
        pl.BlockSpec(dlse.shape, lambda i: (0,)), # dlse_ref: whole array
        pl.BlockSpec(dv_out.shape, lambda i: (0, 0)), # do_ref: whole array
    ]
    output_specs = [
        pl.BlockSpec(q.shape, lambda i: (0, 0)), # dq_ref: whole array, scattered into
        pl.BlockSpec((None, D), lambda i: (i, 0)), # dq_bar_ref
        pl.BlockSpec((None, K, D), lambda i: (i, 0, 0)), # dk_ref
        pl.BlockSpec((None, K, V), lambda i: (i, 0, 0)), # dv_ref
        pl.BlockSpec((None, V, D), lambda i: (i, 0, 0)), # dvk_ref
        pl.BlockSpec((None, K), lambda i: (i, 0)), # dmu_ref
    ]
    launch = pl.pallas_call(
        kernel=partial(final_bwd_kernel, block_n=block_n),
        out_shape=(q, q_bar, k, v, vk, m),
        grid=(Q,),
        in_specs=input_specs,
        out_specs=output_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_backward",
    )
    dq, dq_bar, dk, dv, dvk, dm = launch(
        q, q_bar, k, v, vk, m, qmeta.fwd, lse, dlse_modified, dv_out
    )
    return dq, dq_bar, dk, dv, dvk, dm










def final_kernel(q_ref, q_bar_ref, k_ref, v_ref, vk_ref, m_ref, qfwd_ref, lse_ref, v_out_ref, v_out_mon_ref):
    """Pallas kernel body for the second (per-query-cluster) attention stage.

    As launched by `final`, one program instance handles one block of slot
    indices (`qfwd_ref`) of one query cluster. It gathers those query rows,
    subtracts the cluster centre, runs a base-2 softmax over the cluster's
    summarised keys (with additive bias `m_ref`), and scatters three results
    back to the gathered rows: the natural-log log-sum-exp (`lse_ref`), the
    softmax-only output (`v_out_mon_ref`) and the output with the linear
    `vk` correction added (`v_out_ref`). Slots with a negative index are
    padding: they are loaded as zero rows and never stored.
    """
    q_bar = q_bar_ref[...] # (D,)
    sm_scale = 1.0 / math.sqrt(q_bar.shape[-1])
    k = k_ref[...] # (K, D)
    v = v_ref[...] # (K, V)
    vk = vk_ref[...] # (V, D)
    # Pre-scale the correction matrix so the later product needs no extra scaling.
    vk = (vk * sm_scale).astype(v.dtype)
    mu = m_ref[...] # (K,)
    qfwd = qfwd_ref[...] # (n,)
    mask = qfwd >= 0
    q = pl.load(q_ref, (qfwd, slice(None)), mask=mask[:,None], other=0.0) # (n, D)

    qres = q - q_bar[None, :] # (n, D)
    # Scores are expressed in log2 units so exp2 can be used.
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)
    attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale # (n, K)
    m = jnp.max(attn_scores, axis=-1) # (n,)
    s = jnp.exp2(attn_scores - m[:,None]) # (n, K)
    l = jnp.sum(s, axis=-1) # (n,)
    ov = pl.dot(s.astype(v.dtype), v, allow_tf32=True) # (n, V)
    ov = ov / l[:, None] # (n, V)
    # Softmax-only output is stored before the correction is added.
    pl.store(v_out_mon_ref, (qfwd, slice(None)), ov.astype(v_out_mon_ref.dtype), mask=mask[:,None]) # (n, V)
    extra_ov = pl.dot(qres.astype(v.dtype), vk, trans_b=True, allow_tf32=True) # (n, V)
    ov = ov + extra_ov # (n, V)
    # Convert the base-2 running max back to a natural-log log-sum-exp.
    lse = math.log(2) * m + jnp.log(l) # (n,)
    pl.store(lse_ref, (qfwd,), lse, mask=mask)
    pl.store(v_out_ref, (qfwd, slice(None)), ov.astype(v_out_ref.dtype), mask=mask[:,None])

# Check whether usage of half precision is optimal - some stuff might be unnecessarily in float32
def final(q, q_bar, k, v, vk, m, qmeta):
    """Launch ``final_kernel`` over a (cluster, slot-block) grid.

    Each grid entry handles ``block_q`` slots of one query cluster and
    scatters its results into row-ordered outputs.

    Returns ``(lse, v_out, v_out_mon)`` with shapes (N,), (N, V), (N, V),
    where ``v_out_mon`` is the softmax-only output and ``v_out`` includes
    the ``vk`` correction.
    """
    Q, n = qmeta.fwd.shape # (Q, n)
    N, D = q.shape # (N, D)
    _, K, V = v.shape # (Q, K, V)
    block_q = 32
    assert n % block_q == 0, f"n ({n}) must be divisible by block_q ({block_q})"

    result_shapes = [
        jax.ShapeDtypeStruct((N,), jnp.float32), # lse
        jax.ShapeDtypeStruct((N, V), v.dtype), # v_out
        jax.ShapeDtypeStruct((N, V), v.dtype), # v_out_mon
    ]
    input_specs = [
        pl.BlockSpec(q.shape, lambda i,j: (0, 0)), # whole q
        pl.BlockSpec((None, D,), lambda i,j: (i, 0)), # q_bar of cluster i
        pl.BlockSpec((None, K, D), lambda i,j: (i, 0, 0)), # k of cluster i
        pl.BlockSpec((None, K, V), lambda i,j: (i, 0, 0)), # v of cluster i
        pl.BlockSpec((None, V, D), lambda i,j: (i, 0, 0)), # vk of cluster i
        pl.BlockSpec((None, K), lambda i,j: (i, 0)), # score bias of cluster i
        pl.BlockSpec((None, block_q), lambda i,j: (i, j)), # j-th slot block of cluster i
    ]
    output_specs = [
        pl.BlockSpec((N,), lambda i,j: (0,)),
        pl.BlockSpec((N, V), lambda i,j: (0, 0)),
        pl.BlockSpec((N, V), lambda i,j: (0, 0)),
    ]
    launch = pl.pallas_call(
        final_kernel,
        out_shape=result_shapes,
        grid=(Q, pl.cdiv(n, block_q)),
        in_specs=input_specs,
        out_specs=output_specs,
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="final_forward",
    )
    lse, v_out, v_out_mon = launch(q, q_bar, k, v, vk, m, qmeta.fwd)
    return lse, v_out, v_out_mon

def ref_preprocess(q, qmeta):
    """Return the per-cluster mean of the query rows, shape (Q, D).

    Rows are gathered with ``qmeta.fwd``; slots with a negative index are
    padding and are excluded from the mean.
    """
    slots = qmeta.fwd
    is_member = slots[:, :, None] >= 0 # (Q, n, 1)
    gathered = q[slots, :] # (Q, n, D)
    return jnp.mean(gathered, axis=1, where=is_member)

def ref_preprocess_bwd(qmeta, dq_bar): # _, (Q, D)
    """Backward of ``ref_preprocess``: spread each cluster-mean gradient to its rows.

    Every row receives its cluster's gradient multiplied by ``1 / qmeta.cnt``;
    the row-to-cluster map is ``qmeta.lab``. Returns an (N, D) array.
    """
    inv_count = 1.0 / qmeta.cnt # (Q,)
    per_cluster = dq_bar * inv_count[:, None] # (Q, D)
    return per_cluster[qmeta.lab, :] # (N, D)

def ref_vk_merge(lse_bar, vk_bar):
    """Blend the per-cluster (K, V, D) matrices into one matrix per query.

    Args:
        lse_bar: (Q, K) log-sum-exp of each query against each cluster.
        vk_bar: (K, V, D) per-cluster matrices.

    Returns:
        ``(merged, weights)`` where ``weights = softmax(lse_bar, axis=-1)``
        has shape (Q, K) and ``merged`` is the weighted sum over clusters,
        shape (Q, V, D), cast to ``vk_bar.dtype``.
    """
    weights = jax.nn.softmax(lse_bar, axis=-1) # (Q, K)
    merged = einsum("qk,kvd->qvd", weights, vk_bar) # (Q, V, D)
    return merged.astype(vk_bar.dtype), weights

def ref_vk_merge_bwd(weights, vk_bar, dvk_merged):
    """Backward of ``ref_vk_merge``.

    Args:
        weights: (Q, K) softmax weights saved from the forward pass.
        vk_bar: (K, V, D) per-cluster matrices.
        dvk_merged: (Q, V, D) cotangent of the merged matrices.

    Returns:
        ``(dlse_bar, dvk_bar)`` with shapes (Q, K) and (K, V, D);
        ``dvk_bar`` is cast to ``dvk_merged.dtype``.
    """
    merged = einsum("qk,kvd->qvd", weights, vk_bar) # (Q, V, D)
    dvk_bar = einsum("qvd,qk->kvd", dvk_merged, weights) # (K, V, D)
    # Softmax backward: w * (dL/dw - <dL/dw, w>).
    dw = einsum("qvd,kvd->qk", dvk_merged, vk_bar) # (Q, K)
    row_term = - einsum("qvd,qvd->q", dvk_merged, merged) # (Q,)
    dlse_bar = (dw + row_term[:, None]) * weights # (Q, K)
    return dlse_bar, dvk_bar.astype(dvk_merged.dtype)



def ref_attn(q, k, v, qmeta, kmeta):
    """Pure-JAX reference forward pass of the clustered attention.

    Pipeline: cluster centres of q -> per-key-cluster attention of the
    centres -> merge of the per-cluster vk matrices -> per-row attention
    against the cluster summaries. Returns ``(lse, v_out)``.
    """
    centres = ref_preprocess(q, qmeta) # (Q, D)
    lse_bar, k_bar, v_bar, vk_bar = ref_initial(centres, k, v, kmeta)[:4]
    vk_merged = ref_vk_merge(lse_bar, vk_bar)[0] # (Q, V, D)
    lse, v_out = ref_final(q, centres, k_bar, v_bar, vk_merged, lse_bar, qmeta)[:2]
    return lse, v_out

def ref_attn_fwd(q, k, v, qmeta, kmeta):
    """Reference forward pass that also returns the residuals for ``ref_attn_bwd``.

    Returns ``((lse, v_out), residuals)`` where ``residuals`` is the tuple
    ``(lse, v_out_mon, vk_merged, q_bar, lse_bar, k_bar, v_bar, vk_bar,
    v_bar_vk, k_bar_vk, merge_weights)``.
    """
    q_bar = ref_preprocess(q, qmeta) # (Q, D)
    stage_one = ref_initial(q_bar, k, v, kmeta) # (Q, K), (Q, K, D), (Q, K, V), (K, V, D), (K, V), (K, D)
    lse_bar, k_bar, v_bar, vk_bar, v_bar_vk, k_bar_vk = stage_one
    vk_merged, merge_weights = ref_vk_merge(lse_bar, vk_bar) # (Q, V, D), (Q, K)
    lse, v_out, v_out_mon = ref_final(q, q_bar, k_bar, v_bar, vk_merged, lse_bar, qmeta)
    outputs = (lse, v_out)
    residuals = (
        lse, v_out_mon, vk_merged, q_bar, lse_bar, k_bar, v_bar, vk_bar,
        v_bar_vk, k_bar_vk, merge_weights,
    )
    return outputs, residuals

def ref_attn_bwd(q, k, v, qmeta, kmeta, res, dlse, dv_out):
    """Reference backward pass; ``res`` is the residual tuple of ``ref_attn_fwd``.

    Runs the stage backward passes in reverse order (final -> vk merge ->
    initial -> preprocess) and returns ``(dq, dk, dv)``.
    """
    (lse, v_out_mon, vk_merged, q_bar, lse_bar, k_bar, v_bar, vk_bar,
     v_bar_vk, k_bar_vk, merge_weights) = res

    # Second stage.
    dq_rows, dq_bar_final, dk_bar, dv_bar, dvk_merged, dlse_bar_final = ref_final_bwd(
        q, q_bar, k_bar, v_bar, vk_merged, lse_bar, qmeta, lse, v_out_mon, dlse, dv_out
    )
    # vk merge; lse_bar receives gradient from both the final stage and the merge.
    dlse_bar_merge, dvk_bar = ref_vk_merge_bwd(merge_weights, vk_bar, dvk_merged) # (Q, K), (K, V, D)
    dlse_bar = dlse_bar_final + dlse_bar_merge # (Q, K)
    # First stage.
    dq_bar_initial, dk, dv = ref_initial_bwd(
        q_bar, k, v, kmeta, lse_bar, k_bar, v_bar, v_bar_vk, k_bar_vk,
        dlse_bar, dk_bar, dv_bar, dvk_bar
    )
    # Cluster-centre gradients are pushed back to the individual rows.
    dq_from_centres = ref_preprocess_bwd(qmeta, dq_bar_final + dq_bar_initial) # (N, D)
    dq = dq_rows + dq_from_centres
    return dq, dk, dv


@jax.custom_vjp
def attn(q, k, v, qmeta, kmeta):
    """Clustered attention forward pass using the Pallas kernels.

    Same pipeline as ``ref_attn`` but with ``initial`` / ``final`` in place
    of the reference stages. Returns ``(lse, v_out)``. Gradients are supplied
    through the custom VJP registered on this function.
    """
    centres = ref_preprocess(q, qmeta) # (Q, D)
    lse_bar, k_bar, v_bar, vk_bar = initial(centres, k, v, kmeta)[:4]
    vk_merged = ref_vk_merge(lse_bar, vk_bar)[0] # (Q, V, D)
    lse, v_out = final(q, centres, k_bar, v_bar, vk_merged, lse_bar, qmeta)[:2]
    return lse, v_out

def attn_fwd(q, k, v, qmeta, kmeta):
    """Kernel-based forward pass that also returns the residuals for ``attn_bwd``.

    Returns ``((lse, v_out), residuals)`` where ``residuals`` is the tuple
    ``(lse, v_out_mon, vk_merged, q_bar, lse_bar, k_bar, v_bar, vk_bar,
    v_bar_vk, k_bar_vk, merge_weights)``.
    """
    q_bar = ref_preprocess(q, qmeta) # (Q, D)
    stage_one = initial(q_bar, k, v, kmeta) # (Q, K), (Q, K, D), (Q, K, V), (K, V, D), (K, V), (K, D)
    lse_bar, k_bar, v_bar, vk_bar, v_bar_vk, k_bar_vk = stage_one
    vk_merged, merge_weights = ref_vk_merge(lse_bar, vk_bar) # (Q, V, D), (Q, K)
    lse, v_out, v_out_mon = final(q, q_bar, k_bar, v_bar, vk_merged, lse_bar, qmeta)
    outputs = (lse, v_out)
    residuals = (
        lse, v_out_mon, vk_merged, q_bar, lse_bar, k_bar, v_bar, vk_bar,
        v_bar_vk, k_bar_vk, merge_weights,
    )
    return outputs, residuals

def _attn_fwd(q, k, v, qmeta, kmeta):
    """custom_vjp forward rule: prepend the primal inputs to the residuals."""
    outputs, saved = attn_fwd(q, k, v, qmeta, kmeta)
    primals = (q, k, v, qmeta, kmeta)
    return outputs, primals + saved

def attn_bwd(q, k, v, qmeta, kmeta, res, dlse, dv_out):
    """Kernel-based backward pass; ``res`` is the residual tuple of ``attn_fwd``.

    Runs the stage backward passes in reverse order (final -> vk merge ->
    initial -> preprocess) and returns ``(dq, dk, dv)``.
    """
    (lse, v_out_mon, vk_merged, q_bar, lse_bar, k_bar, v_bar, vk_bar,
     v_bar_vk, k_bar_vk, merge_weights) = res

    # Second stage (Pallas kernel).
    dq_rows, dq_bar_final, dk_bar, dv_bar, dvk_merged, dlse_bar_final = final_bwd(
        q, q_bar, k_bar, v_bar, vk_merged, lse_bar, qmeta, lse, v_out_mon, dlse, dv_out
    )
    # vk merge; lse_bar receives gradient from both the final stage and the merge.
    dlse_bar_merge, dvk_bar = ref_vk_merge_bwd(merge_weights, vk_bar, dvk_merged) # (Q, K), (K, V, D)
    dlse_bar = dlse_bar_final + dlse_bar_merge # (Q, K)
    # First stage (Pallas kernel).
    dq_bar_initial, dk, dv = initial_bwd(
        q_bar, k, v, kmeta, lse_bar, k_bar, v_bar, v_bar_vk, k_bar_vk,
        dlse_bar, dk_bar, dv_bar, dvk_bar
    )
    # Cluster-centre gradients are pushed back to the individual rows.
    dq_from_centres = ref_preprocess_bwd(qmeta, dq_bar_final + dq_bar_initial) # (N, D)
    dq = dq_rows + dq_from_centres
    return dq, dk, dv

def _attn_bwd(res, grads):
    """Backward rule registered with ``attn.defvjp``.

    ``res`` is the tuple built by ``_attn_fwd``: the five primal inputs
    followed by the intermediates of ``attn_fwd``. The clustering metadata
    gets a ``None`` cotangent.
    """
    primal_inputs, intermediates = res[:5], res[5:]
    q, k, v, qmeta, kmeta = primal_inputs
    dlse, dv_out = grads
    qkv_grads = attn_bwd(q, k, v, qmeta, kmeta, intermediates, dlse, dv_out)
    dq, dk, dv = qkv_grads
    return dq, dk, dv, None, None

# Register the hand-written forward/backward rules used when `attn` is differentiated.
attn.defvjp(_attn_fwd, _attn_bwd)






##############################################################################################
#Evaluation code:
##############################################################################################

def sq_err(ref, x):
    """Return the relative mean-squared error of ``x`` with respect to ``ref``.

    The result is ``mean((ref - x)**2) / mean(ref**2)`` as a float32 scalar.

    Args:
        ref: reference array.
        x: array to compare; must broadcast against ``ref``.

    Returns:
        A float32 scalar array. If the mean square of ``ref`` is not strictly
        positive (for example ``ref`` is all zeros) the ratio is undefined and
        ``inf`` is returned instead.
    """
    ref_msq = jnp.mean(jnp.square(ref))
    err_msq = jnp.mean(jnp.square(ref - x))
    # jnp.where evaluates both branches; the division result is simply
    # discarded when the reference has no energy.
    return jnp.where(ref_msq > 0, err_msq / ref_msq, jnp.inf).astype(jnp.float32)

def axis_corr(ref, x, axis=None):
    """Return the correlation between ``ref`` and ``x`` after centering along one axis.

    Both arrays are centered by subtracting their mean along ``axis``; the
    variances and the covariance are then averaged over all elements.

    Args:
        ref: reference array.
        x: array to compare; same shape as ``ref``.
        axis: axis along which the means are removed. ``None`` (the default)
            selects the longest axis of ``ref`` (the first one on ties), which
            is what this function previously did regardless of the argument.

    Returns:
        A float32 scalar array. A small constant (1e-6) is added to the
        denominator, so the result is 0 rather than NaN when either centered
        array is identically zero.
    """
    if axis is None:
        axis = int(np.argmax(ref.shape))
    ref_centered = ref - jnp.mean(ref, axis=axis, keepdims=True)
    x_centered = x - jnp.mean(x, axis=axis, keepdims=True)
    ref_var = jnp.mean(jnp.square(ref_centered))
    x_var = jnp.mean(jnp.square(x_centered))
    cov = jnp.mean(ref_centered * x_centered)
    corr = cov / (jnp.sqrt(ref_var) * jnp.sqrt(x_var) + 1e-6)
    return corr.astype(jnp.float32)

# sq_err mapped over axis 0 and then over axis 1 of what remains
# (presumably batch and head for (B, N, H, ...) inputs — confirm against callers).
mha_sq_err = vmap(vmap(sq_err, in_axes=1))


def _leaf_sq_err(ref, x):
    """Relative squared error of one pytree leaf, as a Python float."""
    return float(sq_err(ref, x))


def _leaf_axis_corr(ref, x):
    """Axis correlation of one pytree leaf, as a Python float."""
    return float(axis_corr(ref, x))


# Leaf-wise comparison of two pytrees with matching structure.
tree_sq_err = partial(jax.tree.map, _leaf_sq_err)
tree_axis_corr = partial(jax.tree.map, _leaf_axis_corr)

def benchmark_fn(fn, args, ref_out, *, warmup_iters=100, iters=100):
    """Jit ``fn``, report its accuracy against ``ref_out`` and time it.

    Prints, in order: the duration of the first call (which includes creating
    the jitted function and its first execution), the leaf-wise relative
    squared errors and axis correlations of the output versus ``ref_out``, and
    the average time per call over ``iters`` timed iterations.

    Args:
        fn: function to benchmark; called as ``fn(*args)``.
        args: tuple of positional arguments for ``fn``.
        ref_out: reference output with the same pytree structure as ``fn``'s.
        warmup_iters: number of untimed calls made before timing starts.
        iters: number of timed calls. If it is not positive, timing is skipped
            instead of dividing by zero.
    """
    # Local import: the module only imports `time` from the `time` module.
    # perf_counter is monotonic and high-resolution, unlike the wall clock.
    from time import perf_counter

    start_time = perf_counter()
    jit_fn = jax.jit(fn)
    fn_out = jax.block_until_ready(jit_fn(*args))
    end_time = perf_counter()
    sq_errs = tree_sq_err(ref_out, fn_out)
    print(f"First call took {end_time - start_time:.4f} seconds.")
    print(f"Output errors: {sq_errs}")
    corrs = tree_axis_corr(ref_out, fn_out)
    print(f"Output correlations: {corrs}")
    # Warmup iterations
    for _ in range(warmup_iters):
        jax.block_until_ready(jit_fn(*args))
    if iters <= 0:
        print("No benchmark iterations requested; skipping timing.")
        return
    # Benchmark iterations; block on every call so that asynchronous dispatch
    # does not hide the execution time.
    start_time = perf_counter()
    for _ in range(iters):
        jax.block_until_ready(jit_fn(*args))
    end_time = perf_counter()
    print(f"Average time over {iters} iterations: {1e3 * (end_time - start_time) / iters:.4f} ms.")

# Module-level baseline clustering configuration, used by
# _do_multi_level_clustering below. get_data() builds a separate instance
# with the same settings.
_multi_level_clustering = MultiLevelClustering(
    Qs=(64,),
    Ks=(64,),
    coupled_clustering=False,
    inner_iters=3,
    outer_iters=3,
    max_cluster_scale=1.5,
    #max_cluster_scale=2,
    compute_indices=True,
)

def _do_multi_level_clustering(qs, ks, vs):
    """Cluster queries and keys with the module-level baseline clustering.

    Returns:
        ``(qmeta, kmeta)``: the two 4-tuples returned by ``do_clustering``
        repackaged as ``ClusterData`` instances.
    """
    query_parts, key_parts = _multi_level_clustering.do_clustering(qs, ks, vs)
    qcnt, qlab, qfwd, qbwd = query_parts
    kcnt, klab, kfwd, kbwd = key_parts
    return (
        ClusterData(cnt=qcnt, lab=qlab, fwd=qfwd, bwd=qbwd),
        ClusterData(cnt=kcnt, lab=klab, fwd=kfwd, bwd=kbwd),
    )

# Baseline clustering mapped over axis 0 (outer vmap) and then over axis 1 of
# what remains (inner vmap); presumably batch and head for (B, N, H, D) inputs — confirm.
_do_multi_level_clustering_mha = vmap(vmap(_do_multi_level_clustering, in_axes=1, out_axes=1))

def cluster_attn(qs, ks, vs):
    """Cluster a single head with the Pallas clustering kernel, then run ``attn``.

    Args:
        qs: 2-D query array; its first dimension is the sequence length.
        ks, vs: key and value arrays for the same head.

    Returns:
        ``(lse, v_out)`` as returned by ``attn``.
    """
    num_clusters = 64
    expand_ratio = 1.5
    seq_len, _ = qs.shape  # the unpacking also enforces a 2-D input
    # Each cluster may hold up to `expand_ratio` times the average cluster size.
    max_cluster_size = math.ceil(expand_ratio * seq_len / num_clusters)
    # NOTE(review): the meaning of the third positional argument (1) is defined
    # by pallas_do_clustering and is not visible here.
    # The MultiLevelClustering baseline (_do_multi_level_clustering) is the alternative.
    qmeta, kmeta = pallas_do_clustering(num_clusters, max_cluster_size, 1, qs, ks, vs)
    return attn(qs, ks, vs, qmeta, kmeta)

# cluster_attn mapped over axis 0 (outer vmap) and then over axis 1 of what
# remains (inner vmap); presumably batch and head for (B, N, H, D) inputs — confirm.
cluster_attn_mha = vmap(vmap(cluster_attn, in_axes=1, out_axes=1))

def get_data(dtype, n, only_first_batch=False):
    """Load saved q/k/v arrays, split them into length-``n`` sequences and cluster them.

    Args:
        dtype: dtype that the returned ``qs``, ``ks`` and ``vs`` are cast to.
        n: sequence length of each chunk; must divide the stored sequence length.
        only_first_batch: if true, keep only the first entry along the batch axis.

    Returns:
        An 8-tuple ``(pallas_flash_mha, baseline_approx, baseline_approx_mha,
        qs, ks, vs, qmeta, kmeta)``: the batched exact flash attention, the
        single-head and batched MultiLevelAttention baseline, the reshaped and
        cast data, and the clustering metadata as ``ClusterData``.
    """
    # Other checkpoints previously used from the same directory:
    # qkv_step_1600.npz, qkv_step_11000.npz.
    archive = jnp.load("examples/minigpt/qkv_data_64_T1/qkv_step_15200.npz")
    ks, vs, qs = archive["k"], archive["v"], archive["q"]  # (B, N, H, D)
    if only_first_batch:
        print(f"WARNING: Only using first batch of data, original shape: {ks.shape}, {vs.shape}, {qs.shape}")
        ks, vs, qs = ks[:1, ...], vs[:1, ...], qs[:1, ...]

    # The unpackings double as rank-4 checks; the sequence length is read from vs.
    _, _, _, _ = qs.shape
    _, N, _, _ = vs.shape
    assert N % n == 0, f"N ({N}) must be divisible by n ({n})"
    num_chunks = N // n

    # Fold the chunks of each sequence into the batch axis.
    split_sequence = partial(einshape, m=num_chunks, n=n)
    qs = split_sequence("b(mn)hd->(bm)nhd", qs)
    ks = split_sequence("b(mn)hd->(bm)nhd", ks)
    vs = split_sequence("b(mn)hv->(bm)nhv", vs)

    clustering = MultiLevelClustering(
        Qs=(64,),
        Ks=(64,),
        coupled_clustering=False,
        inner_iters=3,
        outer_iters=3,
        max_cluster_scale=1.5,  # 2 was also tried
        compute_indices=True,
    )
    multi_attention = MultiLevelAttention(clustering=clustering)

    def baseline_approx(qs, ks, vs, qmeta, kmeta):
        """Single-head baseline: MultiLevelAttention on pre-scaled q and k."""
        # Dividing both q and k by D**(1/4) scales their dot product by 1/sqrt(D).
        scale = math.sqrt(math.sqrt(qs.shape[-1]))
        lse, out = multi_attention.attend(
            queries=qs / scale,
            qlabels=qmeta.lab,
            keys=ks / scale,
            klabels=kmeta.lab,
            values=vs,
            qfwd_indices=qmeta.fwd,
            qbwd_indices=qmeta.bwd,
            kfwd_indices=kmeta.fwd,
            kbwd_indices=kmeta.bwd,
        )
        return lse, out

    def over_batch_and_heads(fn):
        """Map ``fn`` over axis 0 and then over axis 1 of the remaining array."""
        return vmap(vmap(fn, in_axes=1, out_axes=1))

    baseline_approx_mha = over_batch_and_heads(baseline_approx)

    softmax_scale = 1.0 / math.sqrt(qs.shape[-1])
    pallas_flash_mha = over_batch_and_heads(partial(pallas_flash, sm_scale=softmax_scale))

    # Clustering runs on the data as loaded; the cast to `dtype` only applies
    # to the returned arrays.
    q_parts, k_parts = over_batch_and_heads(clustering.do_clustering)(qs, ks, vs)
    qcnt, qlab, qfwd, qbwd = q_parts
    kcnt, klab, kfwd, kbwd = k_parts
    qmeta = ClusterData(qcnt, qlab, qfwd, qbwd)
    kmeta = ClusterData(kcnt, klab, kfwd, kbwd)

    return (
        pallas_flash_mha,
        baseline_approx,
        baseline_approx_mha,
        qs.astype(dtype),
        ks.astype(dtype),
        vs.astype(dtype),
        qmeta,
        kmeta,
    )

def main():
    """Run the accuracy checks and benchmarks for the clustered attention kernels.

    Loads the data via ``get_data``, builds batched ("mha") and single-head
    ("oh") variants of the reference and kernel implementations, computes
    reference outputs and gradients, prints a few error summaries and then
    benchmarks the selected functions.

    NOTE: the ``exit()`` calls act as manual switches. With the code as
    written, only the first benchmark (mha Cluster Attention) runs;
    everything after the first ``exit()`` is unreachable.
    """
    dtype = jnp.bfloat16
    requested_n = 2**13
    ONLY_FIRST_BATCH = False
    pallas_flash_mha, baseline_approx, baseline_approx_mha, qs, ks, vs, qmeta, kmeta = get_data(dtype=dtype, n=requested_n, only_first_batch=ONLY_FIRST_BATCH)
    ref_attn_mha = vmap(vmap(ref_attn, in_axes=1, out_axes=1))
    attn_mha = vmap(vmap(attn, in_axes=1, out_axes=1))
    # Select index 0 along axis 0 and along axis 2 (first batch entry and first head).
    def one_head(array):
        return array[0,:, 0, ...]
    oh_qs, oh_ks, oh_vs = one_head(qs), one_head(ks), one_head(vs)
    oh_qmeta = jax.tree.map(one_head, qmeta)
    oh_kmeta = jax.tree.map(one_head, kmeta)
    print(f"oh shapes: {oh_qs.shape}, {oh_ks.shape}, {oh_vs.shape}")
    print(f"oh qmeta: {oh_qmeta.fwd.shape}, oh kmeta: {oh_kmeta.fwd.shape}")
    
    B, N, H, D = qs.shape
    _, _, _, V = vs.shape
    sm_scale = 1.0 / math.sqrt(D)
    print(f"B: {B}, N: {N}, H: {H}, D: {D}, V: {V}")
    # Fixed-seed random cotangents for the two outputs (lse, v_out).
    dlse_shape = (B, N, H)
    dv_out_shape = (B, N, H, V)
    dlse = jax.random.normal(jax.random.PRNGKey(0), shape=dlse_shape, dtype=jnp.float32)
    dv_out = jax.random.normal(jax.random.PRNGKey(1), shape=dv_out_shape, dtype=dtype)
    oh_dlse = one_head(dlse)
    oh_dv_out = one_head(dv_out)

    # Each *_grad function below differentiates the scalar
    # sum(v_out * dv_out) + sum(lse * dlse) with respect to (qs, ks, vs), i.e.
    # it evaluates the VJP of the wrapped attention with cotangents (dlse, dv_out).
    @partial(jax.grad, argnums=(0, 1, 2))
    def attn_grad_mha(qs, ks, vs, qmeta, kmeta):
        lse, v_out = attn_mha(qs, ks, vs, qmeta, kmeta)
        return jnp.sum(v_out * dv_out) + jnp.sum(lse * dlse)

    @partial(jax.grad, argnums=(0, 1, 2))
    def cluster_attn_grad_mha(qs, ks, vs):
        lse, v_out = cluster_attn_mha(qs, ks, vs)
        return jnp.sum(v_out * dv_out) + jnp.sum(lse * dlse)

    @partial(jax.grad, argnums=(0, 1, 2))
    def pallas_flash_mha_grad(qs, ks, vs):
        lse, v_out = pallas_flash_mha(qs, ks, vs)
        return jnp.sum(v_out * dv_out) + jnp.sum(lse * dlse)

    # cudnn_flash_mha returns an all-zero lse, so the lse term contributes no
    # gradient here.
    @partial(jax.grad, argnums=(0, 1, 2))
    def cudnn_flash_mha_grad(qs, ks, vs):
        lse, v_out = cudnn_flash_mha(qs, ks, vs)
        return jnp.sum(v_out * dv_out) + jnp.sum(lse * dlse)

    #attn_grad_mha_out = attn_grad_mha(qs, ks, vs, qmeta, kmeta)

    # Exact attention in float32 is the accuracy reference for forward outputs.
    exact_out = pallas_flash_mha(qs.astype(jnp.float32), ks.astype(jnp.float32), vs.astype(jnp.float32))
    assert (dlse.shape, dv_out.shape) == (exact_out[0].shape, exact_out[1].shape), f"Expected shapes {exact_out[0].shape}, {exact_out[1].shape} but got {dlse.shape}, {dv_out.shape}"
    exact_oh_out = pallas_flash(oh_qs.astype(jnp.float32), oh_ks.astype(jnp.float32), oh_vs.astype(jnp.float32), sm_scale=sm_scale)
    baseline_out = baseline_approx_mha(qs, ks, vs, qmeta, kmeta)
    #baseline_oh_out = baseline_approx(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)
    ref_oh_out = ref_attn(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)
    # NOTE(review): attn_oh_out is computed but not used below.
    attn_oh_out = attn(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)

    # Relative squared errors: baseline vs exact (all heads, then one head),
    # and reference implementation vs exact (one head).
    print(tree_sq_err(exact_out, baseline_out))
    print(tree_sq_err(jax.tree.map(one_head, exact_out), jax.tree.map(one_head, baseline_out)))
    print(tree_sq_err(exact_oh_out, ref_oh_out))


    ref_attn_out, ref_attn_res = ref_attn_fwd(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)
    attn_out, attn_res = attn_fwd(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)
    #print(tree_sq_err(ref_attn_res, attn_res))
    ref_attn_bwd_out = ref_attn_bwd(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta, ref_attn_res, oh_dlse, oh_dv_out)
    # NOTE(review): the kernel backward is fed the *reference* residuals
    # (ref_attn_res), not attn_res — confirm this is intentional. The two
    # single-head backward results are not compared or used below.
    attn_bwd_out = attn_bwd(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta, ref_attn_res, oh_dlse, oh_dv_out)
    #exit()
    # Batched reference forward/backward; ref_attn_bwd_mha_out is the gradient
    # reference passed to the gradient benchmarks below.
    ref_attn_fwd_mha = vmap(vmap(ref_attn_fwd, in_axes=1, out_axes=1))
    _, ref_attn_res_mha = ref_attn_fwd_mha(qs, ks, vs, qmeta, kmeta)
    ref_attn_bwd_mha = vmap(vmap(ref_attn_bwd, in_axes=1, out_axes=1))
    ref_attn_bwd_mha_out = ref_attn_bwd_mha(qs, ks, vs, qmeta, kmeta, ref_attn_res_mha, dlse, dv_out)
    attn_bwd_mha = vmap(vmap(attn_bwd, in_axes=1, out_axes=1))
    attn_bwd_mha_out = attn_bwd_mha(qs, ks, vs, qmeta, kmeta, ref_attn_res_mha, dlse, dv_out)
    #print(f"attn_grad error: {tree_sq_err(attn_grad_mha_out, attn_bwd_mha_out)}")

    # Autodiff VJP of the single-head reference, in the working dtype and in float32.
    def curried_ref_attn_oh(qs, ks, vs):
        return ref_attn(qs, ks, vs, oh_qmeta, oh_kmeta)
    _, ref_attn_oh_vjp = jax.vjp(curried_ref_attn_oh, oh_qs, oh_ks, oh_vs)
    ref_attn_oh_vjp_out = ref_attn_oh_vjp((oh_dlse, oh_dv_out))
    _, ref_attn_oh_vjp_fp32 = jax.vjp(curried_ref_attn_oh, oh_qs.astype(jnp.float32), oh_ks.astype(jnp.float32), oh_vs.astype(jnp.float32))
    ref_attn_oh_vjp_out_fp32 = ref_attn_oh_vjp_fp32((oh_dlse.astype(jnp.float32), oh_dv_out.astype(jnp.float32)))

    # VJPs of the exact attention implementations.
    _, pallas_flash_mha_vjp = jax.vjp(pallas_flash_mha, qs, ks, vs)
    pallas_flash_mha_vjp_out = pallas_flash_mha_vjp((dlse, dv_out))
    _, cudnn_flash_mha_vjp = jax.vjp(cudnn_flash_mha, qs, ks, vs)
    cudnn_flash_mha_vjp_out = cudnn_flash_mha_vjp((dlse, dv_out))

    print("Benchmarking mha Cluster Attention...")
    benchmark_fn(cluster_attn_mha, (qs, ks, vs), exact_out)
    # Everything below this exit() is currently unreachable; remove or move the
    # exit() calls to enable other benchmark groups.
    exit()
    print("Benchmarking mha Cluster Attention grad...")
    benchmark_fn(cluster_attn_grad_mha, (qs, ks, vs), ref_attn_bwd_mha_out)
    #exit()
    print("Benchmarking mha Pallas Flash Attention grad...")
    benchmark_fn(pallas_flash_mha_grad, (qs, ks, vs), ref_attn_bwd_mha_out)
    print("Benchmarking mha CUDNN Flash Attention grad...")
    benchmark_fn(cudnn_flash_mha_grad, (qs, ks, vs), ref_attn_bwd_mha_out)
    exit()

    print("Benchmarking mha baseline clustering...")
    benchmark_fn(_do_multi_level_clustering_mha, (qs, ks, vs), (qmeta, kmeta))

    print("Benchmarking mha Kernel Attention backward...")
    benchmark_fn(attn_bwd_mha, (qs, ks, vs, qmeta, kmeta, ref_attn_res_mha, dlse, dv_out), ref_attn_bwd_mha_out)
    print("Benchmarking mha Kernel Attention...")
    benchmark_fn(attn_mha, (qs, ks, vs, qmeta, kmeta), exact_out)
    print("Benchmarking mha Kernel Attention grad...")
    benchmark_fn(attn_grad_mha, (qs, ks, vs, qmeta, kmeta), ref_attn_bwd_mha_out)

    print("Benchmarking mha Cluster Attention grad...")
    benchmark_fn(cluster_attn_grad_mha, (qs, ks, vs), ref_attn_bwd_mha_out)

    print("Benchmarking mha Pallas Flash Attention grad...")
    benchmark_fn(pallas_flash_mha_grad, (qs, ks, vs), ref_attn_bwd_mha_out)
    print("Benchmarking mha CUDNN Flash Attention grad...")
    benchmark_fn(cudnn_flash_mha_grad, (qs, ks, vs), ref_attn_bwd_mha_out)

    exit()

    print("Benchmarking mha CUDNN Flash Attention backward...")
    benchmark_fn(cudnn_flash_mha_vjp, ((dlse, dv_out),), cudnn_flash_mha_vjp_out)
    print("Benchmarking mha Pallas Flash Attention backward...")
    benchmark_fn(pallas_flash_mha_vjp, ((dlse, dv_out),), pallas_flash_mha_vjp_out)
    exit()
    print("Benchmarking oh Reference Attention auto VJP...")
    benchmark_fn(ref_attn_oh_vjp, ((oh_dlse, oh_dv_out),), ref_attn_oh_vjp_out_fp32)
    print("Benchmarking oh Reference Attention backward...")
    benchmark_fn(ref_attn_bwd, (oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta, ref_attn_res, oh_dlse, oh_dv_out), ref_attn_oh_vjp_out_fp32)
    print("Benchmarking mha Reference Attention backward...")
    benchmark_fn(ref_attn_bwd_mha, (qs, ks, vs, qmeta, kmeta, ref_attn_res_mha, dlse, dv_out), ref_attn_bwd_mha_out)

    exit()

    print("Benchmarking mha CUDNN Flash Attention...")
    benchmark_fn(cudnn_flash_mha, (qs, ks, vs), exact_out)
    #print("Benchmarking Pallas Flash Attention...")
    #benchmark_fn(partial(pallas_flash, sm_scale=sm_scale), (oh_qs, oh_ks, oh_vs), exact_oh_out)
    print("Benchmarking mha Pallas Flash Attention...")
    benchmark_fn(pallas_flash_mha, (qs, ks, vs), exact_out)
    print("Benchmarking Baseline Approximation...")
    benchmark_fn(baseline_approx_mha, (qs, ks, vs, qmeta, kmeta), exact_out)
    #print("Benchmarking Reference Attention...")
    #benchmark_fn(ref_attn, (oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta), exact_oh_out)
    #print("Benchmarking Kernel Attention...")
    #benchmark_fn(attn, (oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta), exact_oh_out)

    print("Benchmarking mha Reference Attention...")
    benchmark_fn(ref_attn_mha, (qs, ks, vs, qmeta, kmeta), exact_out)
    print("Benchmarking mha Kernel Attention...")
    benchmark_fn(attn_mha, (qs, ks, vs, qmeta, kmeta), exact_out)



# Run the evaluation/benchmark script only when this file is executed directly.
if __name__ == "__main__":
    main()

    

