import jax
from jax import numpy as jnp, lax, vmap
from jax.numpy import einsum
from einshape import jax_einshape as einshape
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
from jax.scipy.special import logsumexp
import dataclasses
from time import time
import math
import numpy as np
from jax import Array
from flax import struct
from matplotlib import pyplot as plt
import scipy


from .pallas_monopole import attn as pallas_fma
from .pallas_cluster import do_clustering as pallas_do_clustering
from .flash_attention import attn as pallas_flash_noscale
from . import pallas_monopole as pmlib

def get_joint_basis(q, k):
    """Return an orthonormalised eigenbasis of the product of the q and k covariances.

    Both inputs are 2-D ``(N, D)`` arrays. Each is centred over its first
    axis, a ``(D, D)`` scatter matrix is formed for each, and the product
    ``qcov @ kcov`` is eigendecomposed in float32. Eigenvectors are ordered
    by descending real part of their eigenvalue, re-orthonormalised with a QR
    factorisation, and the real part is returned in ``q.dtype``.

    Returns:
        A ``(D, D)`` array whose columns are the basis vectors.
    """
    _, dim = q.shape
    q_centered = q - jnp.mean(q, axis=0)[None, :]
    k_centered = k - jnp.mean(k, axis=0)[None, :]
    # NOTE(review): the scatter matrices are divided by D, not by the number
    # of rows N. That only rescales the eigenvalues, so the returned basis is
    # unaffected.
    q_cov = einsum("nd,ne->de", q_centered, q_centered) / dim
    k_cov = einsum("nd,ne->de", k_centered, k_centered) / dim
    product = jnp.matmul(q_cov.astype(jnp.float32), k_cov.astype(jnp.float32))
    eigvals, eigvecs = jnp.linalg.eig(product)
    order = jnp.argsort(jnp.real(eigvals), descending=True)
    # QR turns the (generally non-orthogonal) eigenvectors into an
    # orthonormal set spanning the same nested subspaces.
    ortho, _ = jnp.linalg.qr(eigvecs[:, order])
    return jnp.real(ortho).astype(q.dtype)

def ref_preprocess(q, qmeta):
    """Forward to ``pallas_monopole.ref_preprocess`` and return its result unchanged.

    Callers in this file bind the result to ``q_bar``.
    """
    q_bar = pmlib.ref_preprocess(q, qmeta)
    return q_bar

def ref_global_dipole(q, k, v, qmeta, kmeta):
    """Approximate attention with a global mean plus a term linear in ``q``.

    ``q``, ``k`` and ``v`` are 2-D arrays sharing their first axis length N.
    ``qmeta`` and ``kmeta`` are accepted for signature parity with the other
    reference functions and are not read.

    Returns:
        ``(lse_out, v_out)``: ``lse_out`` is ``log(N)`` for every query and
        ``v_out`` is ``mean(v) + (v^T k / N) q / sqrt(D)``.
    """
    num_tokens, head_dim = q.shape
    scale = 1.0 / math.sqrt(head_dim)

    # Zeroth-order term: every query sees the plain average of the values.
    value_mean = jnp.mean(v, axis=0)

    # First-order term: value/key cross moment contracted with each query.
    value_key_moment = einsum("nv,nd->vd", v, k) / num_tokens
    first_order = einsum("vd,nd->nv", value_key_moment, q) * scale

    lse_out = jnp.zeros((num_tokens,), dtype=q.dtype) + jnp.log(num_tokens)
    return lse_out, value_mean[None, :] + first_order

def ref_global_quadrupole(q, k, v, qmeta, kmeta):
    """Experimental global approximation of attention using moments up to third order.

    ``q``, ``k`` and ``v`` are 2-D arrays with a shared first axis of length N.
    ``qmeta`` and ``kmeta`` are not read.

    The returned ``lse_out`` is the sum of a constant ``log(N)``, a term linear
    in ``q`` (via the key mean), a term quadratic in ``q`` (via the centred key
    second moment) and a cubic term evaluated in the leading ``M`` directions
    of the joint q/k basis. The returned ``v_out`` is the value mean plus the
    linear and quadratic value terms, damped by ``1 / (1 + lse_out_quadrupole)``.
    Presumably these are the leading terms of a series expansion of exact
    softmax attention -- TODO confirm with the author.

    NOTE(review): many intermediates below (``qq``, ``vk_b``, ``vkk_b``,
    ``vk_k2e_b``, everything built on ``rem_basis``, ``v_out_dipole_cheap``,
    ``v_out_quadrupole_r``, ``v_out_trace_e_oct``) do not feed the returned
    values; they look like leftovers from experiments.

    Returns:
        ``(lse_out, v_out)`` with one row per query.
    """
    N, D = q.shape
    M = 16  # number of leading basis directions kept for the "_b" quantities
    sm_scale = 1.0 / math.sqrt(D)
    v_bar = jnp.mean(v, axis=0)
    k_bar = jnp.mean(k, axis=0)
    q_bar = jnp.mean(q, axis=0)
    qres = q - q_bar[None, :]
    vres = v - v_bar[None, :]
    kres = k - k_bar[None, :]
    basis = get_joint_basis(qres, kres)  # D, D (truncated to D, M two lines below)
    basis, _ = jnp.linalg.qr(basis)
    # Leading M columns form the working basis; the rest is the remainder.
    basis, rem_basis = basis[:, :M], basis[:, M:]
    kres_b = einsum("nd,db->nb", kres, basis)
    qres_b = einsum("nd,db->nb", qres, basis)
    # Centred second- and third-order moments, in full space and in the basis.
    vk = einsum("nv,nd->vd", vres, kres) / N
    kk = einsum("np,nd->pd", kres, kres) / N
    qq = einsum("np,nd->pd", qres, qres) / N
    kk_b = einsum("np,nd->pd", kres_b, kres_b) / N
    vk_b = einsum("nv,nb->vb", vres, kres_b) / N
    qq_b = einsum("np,nb->pb", qres_b, qres_b) / N
    vkk = einsum("nv,np,nd->vpd", vres, kres, kres) / N
    vkk_b = einsum("nv,nc,nb->vcb", vres, kres_b, kres_b) / N
    kkk_b = einsum("na,nb,nc->abc", kres_b, kres_b, kres_b) / N
    k2e_b = kres_b**2 * jnp.diag(qq_b) * (sm_scale**2) # NM
    k2e_b_bar = jnp.mean(k2e_b, axis=0) # M
    vk_k2e_b = einsum("nvp,nd->vpd", einsum("nv,nd->nvd", vres,kres) - vk, k2e_b - k2e_b_bar) / N - einsum("vd,m->vdm", vk, k2e_b_bar)

    # Quantities in the remainder directions (not used in the returned values).
    qres_r = einsum("nd,dr->nr", qres, rem_basis)
    kres_r = einsum("nd,dr->nr", kres, rem_basis)
    qqe_r = jnp.mean(qres_r**2, axis=0)
    kke_r = jnp.mean(kres_r**2, axis=0)
    k2e_r = kres_r**2 * qqe_r * (sm_scale**2)
    v_k2e_r = einsum("nv,nr->vr", vres, k2e_r - jnp.mean(k2e_r,axis=0)) / N

    v_out_monopole = v_bar[None, :]
    v_out_dipole = einsum("vd,nd->nv", vk, q) * sm_scale
    q_b = einsum("nd,db->nb", q, basis)
    v_out_dipole_cheap = einsum("vm,nm->nv", vk_b, q_b) * sm_scale
    # NOTE(review): this basis-restricted quadrupole is overwritten by the
    # full-space version three lines below before it is ever read.
    v_out_quadrupole = 0.5 * einsum("vcb,nc,nb->nv", vkk_b, q_b, q_b) * (sm_scale **2)
    q2e_r = qres_r**2 * kke_r * (sm_scale**2) 
    v_out_quadrupole_r = 0.5 * einsum("vc,nc->nv", v_k2e_r, q2e_r)
    v_out_quadrupole = 0.5 * einsum("vcb,nc,nb->nv", vkk, q, q) * (sm_scale **2)
    q2e_b = qres_b**2 * jnp.diag(kk_b) * (sm_scale**2)
    v_out_trace_e_oct = (1.0/6.0) * einsum("vpd,np,nd->nv", vk_k2e_b, q, q2e_b) * (sm_scale**1)
    #v_out = v_out_monopole + v_out_dipole + v_out_quadrupole #+ v_out_quadrupole_r + v_out_trace_e_oct
    # Log-normaliser terms: constant, linear, quadratic and cubic in q.
    lse_out_monopole = jnp.zeros((N,), dtype=q.dtype) + jnp.log(N)
    lse_out_dipole = einsum("d,nd->n", k_bar, q) * sm_scale
    lse_out_quadrupole = 0.5 * einsum("pd,np,nd->n", kk, q, q) * (sm_scale **2)
    lse_out_octopole = (1.0/6.0) * einsum("abc,na,nb,nc->n", kkk_b, q_b, q_b, q_b) * (sm_scale**3)
    lse_out = lse_out_monopole + lse_out_dipole + lse_out_quadrupole + lse_out_octopole
    #v_out = v_out + v_bar * lse_out_quadrupole[:,None]
    #v_out = v_out /(1+lse_out_quadrupole)[:,None]
    # Active variant: damp the q-dependent value terms by 1 / (1 + quadratic lse term).
    v_out = v_out_monopole + jnp.reciprocal(1.0 + lse_out_quadrupole)[:,None] * (v_out_dipole + v_out_quadrupole)
    # Alternative damping schemes that were tried, kept for reference:
    #v_out = v_out * jnp.exp(-(lse_out_quadrupole))[:,None]
    #loq = lse_out_quadrupole
    #v_out = v_out / ((1+0.405*loq)*(1+0.045*loq)*(1+0.016*loq))[:,None]
    #v_out = v_out * (((1+0.101*loq)*(1+0.0253*loq))**2/((1+0.405*loq)*(1+0.045*loq))**2)[:,None]
    #v_out = v_out * (jnp.exp(-lse_out_quadrupole) + lse_out_quadrupole)[:,None]
    #v_out = v_out * D / (D + (loq)**1)[:,None]
    return lse_out, v_out

def ref_gperf(q, k, v, qmeta, kmeta):
    """Experimental attention approximation routed through random probe vectors.

    ``2 * M`` probes (``M`` Gaussian draws from a fixed PRNG key plus their
    negations) act as an intermediate layer: keys are softly assigned to
    probes, per-probe value averages and log-masses are computed, and each
    query then attends over the probes. ``qmeta`` and ``kmeta`` are not read.

    NOTE(review): ``k_reconstruct``, ``k_res``, ``mk_k``, ``mk_kres``,
    ``raw_qm_weights``, ``raw_qm_reconstruct``, ``q_res`` and
    ``extra_qm_scores`` do not reach the returned values in the active
    code path (the line that would add ``extra_qm_scores`` is commented out).

    Returns:
        ``(qm_lse, v_out)`` with one row per query.
    """
    M = 1024
    rng = jax.random.PRNGKey(0)  # fixed key: the probe set is deterministic
    N, D = q.shape
    sm_scale = 1.0 / math.sqrt(D)
    probes = jax.random.normal(rng, (M, D), dtype=q.dtype)
    #probes = jnp.identity(D, dtype=q.dtype) * jnp.sqrt(D)
    #probes = jnp.array(scipy.linalg.hadamard(D)[:M]).astype(q.dtype)
    probes = jnp.concatenate([probes, -probes], axis=0) # 2M, D
    # sm_scale**0.5 is applied on both the key side and the query side.
    mk_scores = einsum("md,nd->mn", probes, k) * sm_scale**0.5
    # Normalise over the probe axis so each key distributes unit mass across probes.
    k_lognormalizer = logsumexp(mk_scores.astype(jnp.float32), axis=-2, keepdims=True)
    #k_lognormalizer = jnp.sum(jnp.square(k), axis=-1)[None, :] * (sm_scale**1) / 2.0 + jnp.log(M)
    k_reconstruct = einsum("mn,md->nd", jax.nn.softmax(mk_scores, axis=-2), probes)
    k_res = k*sm_scale**0.5 - k_reconstruct
    mk_scores = mk_scores - k_lognormalizer
    # Per-probe distribution over keys, its log-mass, and the weighted averages.
    mk_weights = jax.nn.softmax(mk_scores, axis=-1)
    mk_lse = logsumexp(mk_scores.astype(jnp.float32), axis=-1)
    mk_v = einsum("mn,nv->mv", mk_weights, v)
    mk_k = einsum("mn,nd->md", mk_weights, k)
    mk_kres = einsum("mn,nd->md", mk_weights, k_res)

    raw_qm_scores = einsum("nd,md->nm", q, probes) * sm_scale**0.5
    raw_qm_weights = jax.nn.softmax(raw_qm_scores, axis=-1)
    raw_qm_reconstruct = einsum("nm,md->nd", raw_qm_weights, probes)
    q_res = q*sm_scale**0.5 - raw_qm_reconstruct
    extra_qm_scores = einsum("nd,md->nm", q, mk_kres) * sm_scale**0.5 + einsum("nd,md->nm", q_res, mk_k) * sm_scale**0.5
    #extra_qm_scores = einsum("nd,md->nm", q, mk_kres) * sm_scale

    # Query-to-probe scores, biased by how much key mass each probe holds.
    qm_scores = raw_qm_scores + mk_lse[None, :] # NM
    #qm_scores = raw_qm_scores + extra_qm_scores + mk_lse[None, :] # NM
    qm_weights = jax.nn.softmax(qm_scores, axis=-1) # NM
    qm_lse = logsumexp(qm_scores.astype(jnp.float32), axis=-1)  - logsumexp(raw_qm_scores.astype(jnp.float32), axis=-1) + jnp.log(2*M)
    v_out = einsum("nm,mv->nv", qm_weights, mk_v)
    return qm_lse, v_out

def ref_vk_merge(lse_bar, vk_bar, v_bar_vk, k_bar_vk, v_bar, k_bar, q_bar, test=False):
    """Merge per-key-cluster value/key cross moments into one tensor per query cluster.

    Args:
        lse_bar: ``(Q, K)`` log-masses; their softmax over K gives the merge weights.
        vk_bar: cross moments, either ``(K, V, D)`` (shared across query
            clusters) or ``(Q, K, V, D)`` (one per query cluster).
        v_bar_vk, k_bar_vk: reference value/key means subtracted from
            ``v_bar`` / ``k_bar``; they must broadcast against them.
        v_bar: ``(Q, K, V)`` per-cluster value means.
        k_bar: ``(Q, K, D)`` per-cluster key means.
        q_bar: ``(Q, D)`` query-cluster centres; also fixes the softmax scale.
        test: experimental switch. When true (3-D ``vk_bar`` only) the merged
            tensor is replaced by the between-cluster ``extra_vk`` term and an
            extra ``(avg_v_diff, k_bar_bar)`` tuple is returned.

    Returns:
        ``(total_vk, weights)``, or ``(total_vk, weights, (avg_v_diff, k_bar_bar))``
        when ``test`` is true.

    Raises:
        ValueError: if ``vk_bar`` is not 3-D or 4-D, or if ``test`` is
            requested with a 4-D ``vk_bar`` (the test-only quantities are
            only defined for the 3-D layout).
    """
    rank = len(vk_bar.shape)
    if rank not in (3, 4):
        raise ValueError(f"vk_bar must be 3-D (KVD) or 4-D (QKVD), got rank {rank}")
    if test and rank != 3:
        raise ValueError("test=True is only supported for a 3-D (KVD) vk_bar")

    sm_scale = 1.0 / math.sqrt(q_bar.shape[-1])
    weights = jax.nn.softmax(lse_bar, axis=-1)
    v_diff = v_bar - v_bar_vk
    k_diff = k_bar - k_bar_vk
    v_bar_bar = einsum("qk,qkv->qv", weights, v_diff)
    k_bar_bar = einsum("qk,qkd->qd", weights, k_diff)

    if rank == 3:
        extra_vk = einsum("qkv,qkd->qkvd", v_diff, k_diff)
        extra_vk = 2. * einsum("qk,qkvd->qvd", weights, extra_vk)
        avg_expected_new_v = einsum("qd,qk,kvd->qv", q_bar, weights, vk_bar) * sm_scale  # QV
        avg_extra_vk = einsum("qv,qd->qvd", v_bar_bar - avg_expected_new_v, k_bar_bar)
        new_vk = einsum("qk,kvd->qvd", weights, vk_bar).astype(vk_bar.dtype)
        # The zero coefficients are deliberate experiment knobs; they are kept
        # so dtype promotion and NaN propagation match the previous behaviour.
        if test:
            new_vk = 0.0 * new_vk + 1.0 * extra_vk + 0.0 * avg_extra_vk
        else:
            new_vk = new_vk + 0. * extra_vk + 0. * avg_extra_vk
    else:
        new_vk = einsum("qk,qkvd->qvd", weights, vk_bar).astype(vk_bar.dtype)

    bar_bar_vk = einsum("qv,qd->qvd", v_bar_bar, k_bar_bar)
    extra_vk = einsum("qk,qkv,qkd->qvd", weights, v_diff, k_diff)
    total_vk = new_vk - 0. * extra_vk + 0. * bar_bar_vk
    if test:
        avg_v_diff = v_bar_bar - avg_expected_new_v
        return total_vk, weights, (avg_v_diff, k_bar_bar)
    return total_vk, weights

def ref_initial_full(q, k, v, kmeta):
    """Attend each query row to every key cluster separately and return per-cluster statistics.

    ``kmeta.fwd`` is a ``(K, u)`` index table into the rows of ``k`` / ``v``;
    negative entries are treated as padding and masked out of the softmax.

    Returns:
        A 6-tuple ``(lse, k_mean, v_mean, vk_cov, v_mean, k_mean)`` where
        ``lse`` is ``(Q, K)``, the means are the attention-weighted key/value
        averages per (query, cluster) pair and ``vk_cov`` is the
        attention-weighted value/key covariance ``(Q, K, V, D)``. The means are
        repeated at the end so the tuple can be unpacked the same way as the
        result of ``pmlib.ref_initial``.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    slots = kmeta.fwd
    keys = k[slots, :]  # KuD
    values = v[slots, :]  # KuV

    logits = einsum("qd,kud->qku", q, keys) * scale  # QKu
    logits = jnp.where(slots[None, :, :] >= 0, logits, -jnp.inf)  # QKu
    lse = logsumexp(logits.astype(jnp.float32), axis=-1)  # QK
    probs = jax.nn.softmax(logits, axis=-1)  # QKu

    v_mean = einsum("qku,kuv->qkv", probs, values)  # QKV
    k_mean = einsum("qku,kud->qkd", probs, keys)  # QKD

    # Covariance of values and keys under the within-cluster attention weights.
    v_centered = values - v_mean[:, :, None, :]  # QKuV
    k_centered = keys - k_mean[:, :, None, :]  # QKuD
    vk_cov = einsum("qku,qkuv,qkud->qkvd", probs, v_centered, k_centered)  # QKVD
    return lse, k_mean, v_mean, vk_cov, v_mean, k_mean

def ref_initial_full_kcov(q, k, v, kmeta):
    """Return the attention-weighted key covariance per (query, key-cluster) pair.

    ``kmeta.fwd`` is a ``(K, u)`` index table into the rows of ``k``; negative
    entries are treated as padding and masked out of the softmax.

    ``v`` is accepted for signature parity with ``ref_initial_full`` but is
    not needed to compute a key covariance, so it is no longer gathered or
    averaged here (that work was discarded before).

    Returns:
        ``kk_out`` of shape ``(Q, K, D, D)``.
    """
    _k = k[kmeta.fwd, :]  # KuD
    sm_scale = 1.0 / math.sqrt(q.shape[-1])
    attn_scores = einsum("qd,kud->qku", q, _k) * sm_scale  # QKu
    attn_scores = jnp.where(kmeta.fwd[None, :, :] >= 0, attn_scores, -jnp.inf)  # QKu
    attn_weights = jax.nn.softmax(attn_scores, axis=-1)  # QKu
    k_out = einsum("qku,kud->qkd", attn_weights, _k)  # QKD

    kres = _k - k_out[:, :, None, :]  # QKuD
    kk_out = einsum("qku,qkup,qkud->qkpd", attn_weights, kres, kres)  # QKDD
    return kk_out

def ref_final(q, q_bar, k, v, vk, m, qmeta, vkk=None):
    """Combine coarse per-cluster statistics into per-query outputs.

    Queries are gathered per query cluster through ``qmeta.fwd`` and centred
    on ``q_bar``. Each centred query attends over the K key-cluster summaries
    (``k``, ``v``) using ``m`` as a prior log-mass, which gives the monopole
    output. A merged cross moment ``vk`` (``(Q, V, D)``) adds a term linear in
    the centred query (the dipole output).

    Args:
        q: ``(T, D)`` queries.
        q_bar: ``(Q, D)`` query-cluster centres.
        k: ``(Q, K, D)`` key summaries.
        v: ``(Q, K, V)`` value summaries.
        vk: ``(Q, V, D)`` merged value/key cross moment.
        m: ``(Q, K)`` log-masses added to the scores.
        qmeta: object with ``fwd`` (gather table) and ``lab`` / ``bwd``
            (index arrays used to scatter results back to query order).
        vkk: accepted for backward compatibility and ignored. The previous
            implementation computed a correction from it that was never
            added to any output.

    Returns:
        ``(lse, v_out, v_out_monopole)``, each indexed by ``[qmeta.lab, qmeta.bwd]``.
    """
    sm_scale = 1.0 / math.sqrt(q.shape[-1])
    qres = q[qmeta.fwd, :] - q_bar[:, None, :]  # QnD
    scores = einsum("qnd,qkd->qnk", qres, k) * sm_scale + m[:, None, :]  # QnK
    lse = logsumexp(scores.astype(jnp.float32), axis=-1)  # Qn
    attn_weights = jax.nn.softmax(scores, axis=-1)  # QnK
    v_out_monopole = einsum("qnk,qkv->qnv", attn_weights, v)  # QnV
    v_out_dipole = einsum("qvd,qnd->qnv", vk, qres) * sm_scale  # QnV
    v_out = v_out_monopole + v_out_dipole
    return lse[qmeta.lab, qmeta.bwd], v_out[qmeta.lab, qmeta.bwd], v_out_monopole[qmeta.lab, qmeta.bwd]

def ref_final_retrieval(q, k, v, q_bar, k_bar, v_bar, m, qmeta, kmeta, vk=None, kk=None, vk_to_merge=None):
    """Final attention stage that retrieves the top key clusters exactly and approximates the rest.

    For every query, key clusters are ranked by a coarse score (centred query
    against ``k_bar`` plus the log-mass ``m``, optionally with a second-order
    ``kk`` term). The top ``max(K // 16, 1)`` clusters are "retrieved": their
    member keys are scored exactly against the query. The remaining clusters
    contribute through their summaries (``v_bar``), with the retrieved ones
    masked out. Both parts are normalised by a shared log-sum-exp.

    Optional corrections:
        vk_to_merge: ``(K, V, D)`` cross moments merged with ``softmax(m)`` and
            applied as a dipole term, minus a reconstruction of the part
            already covered by retrieved keys.
        vk: 3-D ``(K, V, D)`` or 4-D ``(Q, K, V, D)`` cross moments applied as
            a dipole term over the non-retrieved clusters (only used when
            ``vk_to_merge`` is None).

    NOTE(review): ``naive_priorities``, ``oracle_scores`` / ``oracle_priorities``
    and ``non_ret_m`` are computed but do not affect the returned values; the
    oracle path builds a full (T, S) score matrix.

    Returns:
        ``(lse, v_out)`` indexed back to query order via ``[qmeta.lab, qmeta.bwd]``.
    """
    _q = q[qmeta.fwd, :] # QnD
    T, D = q.shape
    S, V = v.shape
    # N is the per-cluster slot count here (second axis of the gathered queries).
    Q,N,D = _q.shape
    K,U = kmeta.fwd.shape
    qres = _q - q_bar[:, None, :] # QnD
    sm_scale = 1.0 / math.sqrt(q.shape[-1])
    # Three candidate rankings of key clusters per query: by log-mass only
    # (naive), by the coarse score (query), and by the exact per-cluster
    # log-sum-exp (oracle). Only the "query" ranking is used below.
    naive_scores = m[:, None, :] + jnp.zeros(N, dtype=m.dtype)[None,:,None]# QnK
    naive_priorities = jnp.argsort(naive_scores, axis=-1)[:, :, ::-1] # QnK
    attn_scores = einsum("qnd,qkd->qnk", qres, k_bar) * sm_scale + m[:, None, :] # QnK
    query_priority_scores = attn_scores
    if kk is not None:
        extra_score_kk = 0.5 * einsum("qnd,qnp,qkpd->qnk", qres, qres, kk) * (sm_scale **2) # QnK
        query_priority_scores = query_priority_scores + extra_score_kk
    query_priorities = jnp.argsort(query_priority_scores, axis=-1)[:, :, ::-1] # QnK
    retrieval_scores = einsum("qd,kd->qk", q, k) * sm_scale # TS
    # Slots beyond each cluster's count are padding.
    valid_q = jnp.arange(N)[None, :] < qmeta.cnt[:, None] # Qn
    valid_k = jnp.arange(U)[None, :] < kmeta.cnt[:, None] # Ku
    oracle_scores = logsumexp(retrieval_scores[qmeta.fwd[:,:,None,None], kmeta.fwd[None,None,:,:]].astype(jnp.float32), axis=-1, where=valid_k[None,None,:,:]) # QnK
    oracle_priorities = jnp.argsort(oracle_scores, axis=-1)[:, :, ::-1] # QnK
    total_queries = max(K // 16, 1)  # number of key clusters retrieved per query
    #total_queries = 4
    #ret_idx = oracle_priorities[:, :, :total_queries] # Qnr
    ret_idx = query_priorities[:, :, :total_queries] # Qnr
    #ret_idx = naive_priorities[:, :, :total_queries] # Qnr
    # Mark (query, key) pairs covered by retrieval; everything else is masked to -inf.
    retrieval_mask = jnp.zeros((T, S), dtype=bool)
    retrieval_mask = retrieval_mask.at[qmeta.fwd[:,:,None,None], kmeta.fwd[ret_idx,:]].set(valid_q[:, :, None, None] & valid_k[ret_idx, :])
    retrieval_scores = jnp.where(retrieval_mask, retrieval_scores, -jnp.inf)
    lse_ret = logsumexp(retrieval_scores.astype(jnp.float32), axis=-1) # T
    # Coarse scores with the retrieved clusters removed, so they are not counted twice.
    attn_scores_rest = attn_scores.at[jnp.arange(Q)[:, None, None], jnp.arange(N)[None, :, None], ret_idx].set(-jnp.inf) # QnK
    #attn_scores_rest = attn_scores
    #attn_scores_not_rest = jnp.full_like(attn_scores, -jnp.inf).at[jnp.arange(Q)[:, None, None], jnp.arange(N)[None, :, None], ret_idx].set(attn_scores[jnp.arange(Q)[:, None, None], jnp.arange(N)[None, :, None], ret_idx]) # QnK
    lse_rest = logsumexp(attn_scores_rest.astype(jnp.float32), axis=-1) # Qn
    # Shared normaliser over the exact (retrieved) and coarse (rest) parts.
    lse_total = jnp.logaddexp(lse_rest, lse_ret[qmeta.fwd]) # Qn
    #lse_total = lse_rest
    lse_original = logsumexp(attn_scores.astype(jnp.float32), axis=-1) # Qn
    lse_total_ret = lse_total[qmeta.lab, qmeta.bwd]
    attn_weights_ret = jnp.exp(retrieval_scores.astype(jnp.float32) - lse_total_ret[:, None]) # TS
    v_out_ret = einsum("qk,kv->qv", attn_weights_ret, v) # TV
    attn_weights_rest = jnp.exp(attn_scores_rest.astype(jnp.float32) - lse_total[:,:, None]) # QnK
    v_out_rest = einsum("qnk,qkv->qnv", attn_weights_rest, v_bar) # QnV
    if vk_to_merge is not None:
        m_ = m[:,None,:] + jnp.zeros(N, dtype=m.dtype)[None, :, None] # QnK
        non_ret_m = m_.at[jnp.arange(Q)[:, None,None], jnp.arange(N)[None,:,None], ret_idx].set(-jnp.inf) # QnK
        # Ratio of the purely coarse normaliser to the combined one.
        mass_fraction = jnp.exp(lse_original - lse_total) # Qn
        merge_weights = jax.nn.softmax(m, axis=-1) #/ kmeta.cnt[None,:]# QK
        vk_dipole = einsum("qk,kvd->qvd", merge_weights, vk_to_merge) # QVD
        v_out_dipole_rest = einsum("qvd, qnd->qnv", vk_dipole, qres) * sm_scale
        v_out_rest = v_out_rest + v_out_dipole_rest * mass_fraction[:, :, None]
        
        # Subtract the dipole contribution of keys that were retrieved exactly.
        _v = v[kmeta.fwd, :] # KuV
        vres = _v - jnp.mean(_v, axis=-2, keepdims=True, where=(kmeta.fwd[:, :, None] >=0))
        _k = k[kmeta.fwd, :] # KuD
        kres = _k - jnp.mean(_k, axis=-2, keepdims=True, where=(kmeta.fwd[:, :, None] >=0))
        qres_dot_k = einsum("td,sd->ts", qres[qmeta.lab, qmeta.bwd], kres[kmeta.lab,kmeta.bwd]) * sm_scale # TS
        scaled_qres_dot_k = qres_dot_k.at[qmeta.fwd[:,:,None,None], kmeta.fwd[None,None,:,:]].mul((merge_weights/kmeta.cnt)[:, None, :, None]) # TS
        scaled_qres_dot_k = jnp.where(retrieval_mask, scaled_qres_dot_k, 0.0)
        dipole_recon = einsum("ts,sv->tv", scaled_qres_dot_k, vres[kmeta.lab, kmeta.bwd])
        v_out_rest = v_out_rest - dipole_recon[qmeta.fwd, :] * mass_fraction[:, :, None]
    elif vk is not None:
        #attn_weights_not_rest = jnp.exp(attn_scores_not_rest.astype(jnp.float32) - lse_total[:,:, None]) # QnK
        # Dipole over the non-retrieved clusters; vk is shared (KVD) or per query cluster (QKVD).
        if len(vk.shape) == 3:
            v_out_dipole_rest = einsum("kvd,qnd,qnk->qnv", vk, qres, attn_weights_rest) * sm_scale
        if len(vk.shape) == 4:
            v_out_dipole_rest = einsum("qkvd,qnd,qnk->qnv", vk, qres, attn_weights_rest) * sm_scale
        v_out_rest = v_out_rest + v_out_dipole_rest
    v_out = v_out_rest + v_out_ret[qmeta.fwd, :]
    lse_out = lse_total
    return lse_out[qmeta.lab, qmeta.bwd], v_out[qmeta.lab, qmeta.bwd]

def ref_final_full(q, q_bar, k, v, vk, m, qmeta, kk=None):
    """Final stage using an un-merged value/key cross moment per key cluster.

    Works like ``ref_final`` but ``vk`` is kept per key cluster, either shared
    ``(K, V, D)`` or per query cluster ``(Q, K, V, D)``, and is weighted by the
    same attention weights as the monopole term. An optional ``kk``
    ``(Q, K, D, D)`` adds a second-order term to the scores.

    Returns:
        ``(lse, v_out, v_out_monopole)``, each indexed by ``[qmeta.lab, qmeta.bwd]``.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    q_centered = q[qmeta.fwd, :] - q_bar[:, None, :]  # QnD

    logits = einsum("qnd,qkd->qnk", q_centered, k) * scale + m[:, None, :]  # QnK
    if kk is not None:
        second_order = 0.5 * einsum("qnd,qnp,qkpd->qnk", q_centered, q_centered, kk) * (scale ** 2)  # QnK
        logits = logits + second_order

    lse = logsumexp(logits.astype(jnp.float32), axis=-1)  # Qn
    lse_base = logsumexp(m.astype(jnp.float32), axis=-1)  # Q
    lse_excess = lse - lse_base[:, None]  # Qn
    probs = jax.nn.softmax(logits, axis=-1)  # QnK
    monopole = einsum("qnk,qkv->qnv", probs, v)  # QnV

    # The exponent coefficient is an experiment knob currently set to zero,
    # so this factor is exp(0) for finite lse values.
    dipole_rescale = jnp.exp(0.0 * lse_excess)[:, :, None]
    vk_rank = len(vk.shape)
    if vk_rank == 3:
        dipole = einsum("kvd,qnd,qnk->qnv", vk, q_centered, probs) * scale * dipole_rescale  # QnV
    elif vk_rank == 4:
        dipole = einsum("qkvd,qnd,qnk->qnv", vk, q_centered, probs) * scale * dipole_rescale  # QnV

    combined = monopole + dipole
    restore = (qmeta.lab, qmeta.bwd)
    return lse[restore], combined[restore], monopole[restore]

def ref_dipole_attn(q, k, v, qmeta, kmeta):
    """Reference pipeline: preprocess, coarse initial stage, cross-moment merge, final stage.

    Uses ``pmlib.ref_initial`` for the coarse stage and ``ref_final`` with the
    merged cross moment. Returns ``(lse, v_out)`` including the dipole term.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = pmlib.ref_initial(q_bar, k, v, kmeta)  # QK, QKD, QKV, KVD, KV, KD
    lse_bar, k_bar, v_bar, vk_bar, v_bar_vk, k_bar_vk = coarse
    vk_merged, _ = ref_vk_merge(lse_bar, vk_bar, v_bar_vk, k_bar_vk, v_bar, k_bar, q_bar)  # QVD
    lse, v_out, _ = ref_final(q, q_bar, k_bar, v_bar, vk_merged, lse_bar, qmeta)
    return lse, v_out

def ref_dipole_init_full_attn(q, k, v, qmeta, kmeta):
    """Dipole pipeline whose coarse stage is ``ref_initial_full`` (per-query-cluster moments).

    The per-cluster cross moments are merged with ``ref_vk_merge`` and applied
    by ``ref_final``. Returns ``(lse, v_out)``.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = ref_initial_full(q_bar, k, v, kmeta)  # QK, QKD, QKV, QKVD, KV, KD
    lse_bar, k_bar, v_bar, vk_bar, v_bar_vk, k_bar_vk = coarse
    vk_merged, _ = ref_vk_merge(lse_bar, vk_bar, v_bar_vk, k_bar_vk, v_bar, k_bar, q_bar)  # QVD
    lse, v_out, _ = ref_final(q, q_bar, k_bar, v_bar, vk_merged, lse_bar, qmeta)
    return lse, v_out

def ref_dipole_full_attn(q, k, v, qmeta, kmeta):
    """Dipole pipeline that keeps the cross moments un-merged end to end.

    The coarse stage is ``ref_initial_full`` and the final stage is
    ``ref_final_full``, which consumes the per-key-cluster ``vk_bar``
    directly. Returns ``(lse, v_out)``.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = ref_initial_full(q_bar, k, v, kmeta)  # QK, QKD, QKV, QKVD, KV, KD
    lse_bar, k_bar, v_bar, vk_bar, _v_bar_vk, _k_bar_vk = coarse
    lse, v_out, _ = ref_final_full(q, q_bar, k_bar, v_bar, vk_bar, lse_bar, qmeta)
    return lse, v_out

def ref_dipole_fkcov_attn(q, k, v, qmeta, kmeta):
    """Same as ``ref_dipole_full_attn`` plus the second-order key-covariance score term.

    ``ref_initial_full_kcov`` supplies the ``(Q, K, D, D)`` key covariance,
    which is passed to ``ref_final_full`` as ``kk``. Returns ``(lse, v_out)``.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = ref_initial_full(q_bar, k, v, kmeta)  # QK, QKD, QKV, QKVD, KV, KD
    lse_bar, k_bar, v_bar, vk_bar, _v_bar_vk, _k_bar_vk = coarse
    kk_bar = ref_initial_full_kcov(q_bar, k, v, kmeta)  # QKDD
    lse, v_out, _ = ref_final_full(q, q_bar, k_bar, v_bar, vk_bar, lse_bar, qmeta, kk=kk_bar)
    return lse, v_out

def ref_monopole_attn(q, k, v, qmeta, kmeta):
    """Reference pipeline that returns only the monopole part of the output.

    Runs the same stages as ``ref_dipole_attn`` (with ``ref_vk_merge`` in its
    ``test=True`` mode and the resulting tuple forwarded as ``vkk``) but
    returns ``v_out_monopole`` instead of the dipole-corrected output.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = pmlib.ref_initial(q_bar, k, v, kmeta)  # QK, QKD, QKV, KVD, KV, KD
    lse_bar, k_bar, v_bar, vk_bar, v_bar_vk, k_bar_vk = coarse
    vk_merged, _, vkk = ref_vk_merge(
        lse_bar, vk_bar, v_bar_vk, k_bar_vk, v_bar, k_bar, q_bar, test=True
    )  # QVD
    lse, _, v_out_monopole = ref_final(q, q_bar, k_bar, v_bar, vk_merged, lse_bar, qmeta, vkk=vkk)
    return lse, v_out_monopole

def ref_monopole_retrieval_attn(q, k, v, qmeta, kmeta):
    """Retrieval pipeline without any dipole correction.

    The coarse stage is ``pmlib.ref_initial``; ``ref_final_retrieval`` is
    called with neither ``vk`` nor ``vk_to_merge``. Returns ``(lse, v_out)``.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = pmlib.ref_initial(q_bar, k, v, kmeta)  # QK, QKD, QKV, KVD, KV, KD
    lse_bar, k_bar, v_bar, _vk_bar, _v_bar_vk, _k_bar_vk = coarse
    return ref_final_retrieval(q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta)

def ref_dipole_retrieval_attn(q, k, v, qmeta, kmeta):
    """Retrieval pipeline with the merged-dipole correction.

    The coarse stage is ``pmlib.ref_initial``; its un-merged ``vk_bar`` is
    handed to ``ref_final_retrieval`` as ``vk_to_merge``. Returns ``(lse, v_out)``.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = pmlib.ref_initial(q_bar, k, v, kmeta)
    lse_bar, k_bar, v_bar, vk_bar, v_bar_vk, k_bar_vk = coarse
    # NOTE(review): the merge result is not used below; the final stage
    # performs its own merge from ``vk_to_merge``. The call is kept so that
    # behaviour (including any error it raises) is unchanged.
    _vk_merged, _merge_weights = ref_vk_merge(lse_bar, vk_bar, v_bar_vk, k_bar_vk, v_bar, k_bar, q_bar)
    lse, v_out = ref_final_retrieval(
        q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, vk_to_merge=vk_bar
    )
    return lse, v_out

def ref_full_dipole_retrieval_attn(q, k, v, qmeta, kmeta):
    """Retrieval pipeline with per-query-cluster cross moments.

    The coarse stage is ``ref_initial_full``; its ``(Q, K, V, D)`` ``vk_bar``
    is passed to ``ref_final_retrieval`` as ``vk``. Returns ``(lse, v_out)``.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = ref_initial_full(q_bar, k, v, kmeta)  # QK, QKD, QKV, QKVD, KV, KD
    lse_bar, k_bar, v_bar, vk_bar, _v_bar_vk, _k_bar_vk = coarse
    return ref_final_retrieval(q, k, v, q_bar, k_bar, v_bar, lse_bar, qmeta, kmeta, vk=vk_bar)

def ref_coarseq_attn(q, k, v, qmeta, kmeta):
    """Attention evaluated only at the query-cluster centres.

    The per-key-cluster results of ``pmlib.ref_initial`` are combined with a
    softmax over their log-masses, and every query receives the result of its
    cluster (looked up through ``qmeta.lab``). Returns ``(lse, v_out)``.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = pmlib.ref_initial(q_bar, k, v, kmeta)  # QK, QKD, QKV, KVD, KV, KD
    lse_bar, _k_bar, v_bar, _vk_bar, _v_bar_vk, _k_bar_vk = coarse
    cluster_lse = logsumexp(lse_bar.astype(jnp.float32), axis=-1)
    cluster_weights = jax.nn.softmax(lse_bar, axis=-1)
    cluster_out = einsum("nm,nmv->nv", cluster_weights, v_bar)
    labels = qmeta.lab
    return cluster_lse[labels], cluster_out[labels]

def ref_coarseq_per_kcluster(q, k, v, qmeta, kmeta):
    """Return the raw per-(query-cluster, key-cluster) log-masses and value means.

    These are the first and third outputs of ``pmlib.ref_initial`` evaluated
    at the preprocessed query centres.
    """
    q_bar = ref_preprocess(q, qmeta)
    coarse = pmlib.ref_initial(q_bar, k, v, kmeta)  # QK, QKD, QKV, KVD, KV, KD
    lse_bar, _k_bar, v_bar, _vk_bar, _v_bar_vk, _k_bar_vk = coarse
    return lse_bar, v_bar

def pallas_flash(qs, ks, vs):
    """Run the flash-attention kernel with the standard ``1 / sqrt(D)`` scale.

    ``qs`` must be 2-D ``(N, D)``; ``D`` determines the softmax scale passed
    to ``pallas_flash_noscale``.
    """
    _, head_dim = qs.shape
    return pallas_flash_noscale(qs, ks, vs, sm_scale=1.0 / math.sqrt(head_dim))

def attn_scores(qs, ks):
    """Return the scaled dot-product scores ``qs @ ks.T / sqrt(D)``.

    ``qs`` is ``(N, D)`` and ``ks`` is ``(M, D)``; the result is ``(N, M)``.
    """
    _, head_dim = qs.shape
    raw = einsum("nd,md->nm", qs, ks)
    return raw * (1.0 / math.sqrt(head_dim))

def coarse_q_k_mean(qmean, ks, kcnt):
    """Return the softmax-weighted mean of the first ``kcnt`` rows of ``ks``.

    Weights come from the unscaled dot products ``qmean . ks[m]``; rows at
    index ``kcnt`` or beyond are excluded from the softmax.
    """
    in_cluster = jnp.arange(ks.shape[0]) < kcnt
    logits = einsum("d,md->m", qmean, ks)
    probs = jax.nn.softmax(logits, where=in_cluster, axis=0)
    return einsum("m,md->d", probs, ks)

def show_scores(qs, ks, qmeta, kmeta):
    """Print magnitude statistics and save heat-maps of attention scores.

    Three figures are written under ``visualizations/``: the raw scaled
    scores, the scores between cluster-centred residuals, and the residual
    scores with their constant and linear parts removed
    (``log(expm1(s) - s)``). Only the leading 2048 x 2048 corner of each
    matrix is drawn.

    Compared with the earlier version, the output directory is created when
    missing and every figure is closed after saving, so repeated calls no
    longer accumulate open matplotlib figures.

    Returns:
        None.
    """
    import os

    N, D = ks.shape
    sm_scale = 1.0 / math.sqrt(qs.shape[-1])
    qmap = jnp.argsort(qmeta.lab)
    kmap = jnp.argsort(kmeta.lab)
    bucketed_qs = qs[qmeta.fwd]
    # Cluster means over the valid (non-padding) slots only.
    qmeans = jnp.mean(bucketed_qs, axis=1, where=jnp.arange(bucketed_qs.shape[1])[None,:,None] < qmeta.cnt[:, None,None])
    bucketed_ks = ks[kmeta.fwd]
    # NOTE(review): qkmeans is not used further down.
    qkmeans = jax.vmap(jax.vmap(coarse_q_k_mean, in_axes=(None, 0,0)), in_axes=(0, None,None))(qmeans, bucketed_ks,kmeta.cnt)
    qres = qs - qmeans[qmeta.lab]
    kmeans = jnp.mean(bucketed_ks, axis=1, where=jnp.arange(bucketed_ks.shape[1])[None,:,None] < kmeta.cnt[:, None,None])
    kres = ks - kmeans[kmeta.lab]
    # Scores explained by the cluster means (everything except qres . kres).
    base_scores = einsum("nd,md->nm", qs, kmeans[kmeta.lab]) + einsum("nd,md->nm", qmeans[qmeta.lab], kres)
    print(f"sqmag kmean: {jnp.sum(jnp.square(ks.mean(axis=0)))}, qmean: {jnp.sum(jnp.square(qs.mean(axis=0)))}")
    print(f"mean sqmag kmeans: {jnp.mean(jnp.sum(jnp.square(kmeans), axis=-1))}, qmeans: {jnp.mean(jnp.sum(jnp.square(qmeans), axis=-1))}")
    print(f"mean sqmag ks: {jnp.mean(jnp.sum(jnp.square(ks), axis=-1))}, qs: {jnp.mean(jnp.sum(jnp.square(qs), axis=-1))}")

    def make_figure(array2d, name):
        """Save the leading 2048 x 2048 corner of ``array2d`` as a heat-map."""
        array2d = array2d[:2048, :2048]
        fig = plt.figure(figsize=(30, 18))
        plt.imshow(np.array(array2d.astype(jnp.float32)), vmin=-3, vmax=3)
        plt.colorbar()
        plt.title(f"Attention scores: {name}")
        os.makedirs("visualizations", exist_ok=True)
        plt.savefig(f"visualizations/attention_matrix_{name}.png")
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)

    scores = attn_scores(qs, ks)
    make_figure(scores, "scores")
    res_scores = attn_scores(qres, kres)
    make_figure(res_scores, "residual_scores")
    res_scores_nolin = jnp.log(jnp.expm1(res_scores) - res_scores)
    make_figure(res_scores_nolin, "residual_scores_nolin")
    # NOTE(review): the quantities below are computed but neither plotted nor
    # returned; they are kept unchanged pending a decision on their use.
    scores_diff = scores - base_scores * sm_scale
    mapped_scores_diff = scores_diff[qmap][:, kmap]
    scores = scores_diff
    scores = scores - logsumexp(scores, axis=-1, keepdims=True) + jnp.log(N)
    errors = jnp.log(jnp.expm1(scores) - scores)
    errors = jnp.maximum(errors, -10.0)


################# Benchmarking code #################
# Batched multi-head wrapper around ``pallas_fma``: the outer vmap maps the
# leading axis of every argument, the inner vmap maps axis 1 of what remains
# and stacks its outputs on axis 1. For (B, N, H, D) inputs that means one
# call per (batch, head) pair -- TODO confirm the metadata arguments share
# this layout.
mha_pallas_fma = vmap(vmap(pallas_fma, in_axes=1, out_axes=1))

def approx_meansq_attn_weight(qs, ks):
    """Estimate how much mean-square energy attention preserves for random values.

    Ten Gaussian value matrices of shape ``(N, D)`` are pushed through the
    flash-attention kernel. For each, the ratio of output mean-square to
    input mean-square is taken, and the average ratio is returned. The PRNG
    key is fixed, so the estimate is deterministic for given ``qs`` / ``ks``.
    """
    num_tokens, head_dim = qs.shape
    flash = partial(pallas_flash_noscale, sm_scale=1.0 / math.sqrt(head_dim))

    def energy_ratio(key):
        probe_values = jax.random.normal(key, (num_tokens, head_dim), dtype=qs.dtype)
        input_energy = jnp.mean(jnp.square(probe_values))
        _, attended = flash(qs, ks, probe_values)
        return jnp.mean(jnp.square(attended)) / input_energy

    keys = jax.random.split(jax.random.PRNGKey(0), 10)
    return jnp.mean(jax.vmap(energy_ratio)(keys))

def analyze_retrieval_thresholds(qs, ks):
    """Report how many top keys each query needs to cover a given share of attention.

    For every query the exact softmax weights are sorted in descending order.
    For each threshold in ``(0.8, 0.9, 0.95)`` the function finds the number
    of leading keys whose cumulative weight first reaches the threshold. The
    same is done for the squared weights, normalised to sum to one.

    The counts are one-based: a query whose single largest weight already
    meets the threshold has count 1. The earlier version returned the
    zero-based position of that key and labelled it a count.

    Returns:
        ``(thresholds, thresh_counts, thresh_sq_counts)`` where the count
        arrays have shape ``(num_thresholds, N)``.
    """
    scores = attn_scores(qs, ks)
    weights = jax.nn.softmax(scores, axis=-1)
    sorted_weights = jnp.sort(weights, axis=-1, descending=True)
    cum_weights = jnp.cumsum(sorted_weights, axis=-1)
    thresholds = jnp.array([0.8, 0.9, 0.95], dtype=jnp.float32)

    def find_thresholds(cum_fraction):
        def find_for_one(threshold):
            # argmax gives the position of the first key reaching the
            # threshold; the number of keys needed is that position plus one.
            return jnp.argmax(cum_fraction >= threshold, axis=-1) + 1
        return jax.vmap(find_for_one)(thresholds)

    thresh_counts = find_thresholds(cum_weights)
    sorted_sq_weights = jnp.square(sorted_weights)
    cum_sq_weights = jnp.cumsum(sorted_sq_weights, axis=-1) / jnp.sum(sorted_sq_weights, axis=-1, keepdims=True)
    thresh_sq_counts = find_thresholds(cum_sq_weights)
    for i, t in enumerate(thresholds):
        mean_count = jnp.mean(thresh_counts[i])
        mean_sq_count = jnp.mean(thresh_sq_counts[i])
        print(f"Threshold {t:.2f}: count: mean {mean_count:.1f}, median {jnp.median(thresh_counts[i])}, sq mean {mean_sq_count:.1f}, sq median {jnp.median(thresh_sq_counts[i])}")
    return thresholds, thresh_counts, thresh_sq_counts

def random_backward_probe(attn_fn, qs, ks, vs, *args, return_index=False):
    """Recover one query's effective attention weights by differentiating w.r.t. the values.

    A query index is drawn with a fixed PRNG key. The gradient of the summed
    output row of that query with respect to ``vs`` is computed, and its first
    column is returned as the per-key weight vector.

    Args:
        attn_fn: callable ``attn_fn(qs, ks, vs, *args) -> (lse, out)``.
        qs: 2-D query array; its first axis is sampled.
        ks, vs: forwarded to ``attn_fn``.
        *args: extra positional arguments forwarded to ``attn_fn``.
        return_index: when true, also return the sampled query index.

    Returns:
        ``(lse, weight_vec)`` or ``(lse, weight_vec, probe_index)``.
    """
    num_queries, _ = qs.shape
    probe_index = jax.random.choice(jax.random.PRNGKey(1), num_queries)

    lse_all, _ = attn_fn(qs, ks, vs, *args)
    probe_lse = lse_all[probe_index]

    def probed_row_sum(values):
        _, out = attn_fn(qs, ks, values, *args)
        return out[probe_index].sum()

    values_grad = jax.grad(probed_row_sum)(vs)
    weight_vec = values_grad[:, 0]
    if not return_index:
        return probe_lse, weight_vec
    return probe_lse, weight_vec, probe_index


def sq_err(ref, x):
    """Return the mean-squared error of ``x`` relative to the mean square of ``ref``.

    The result is a float32 scalar; it is ``inf`` when ``ref`` has zero mean
    square.
    """
    signal_power = jnp.mean(jnp.square(ref))
    error_power = jnp.mean(jnp.square(ref - x))
    relative = jnp.where(signal_power > 0, error_power / signal_power, jnp.inf)
    return relative.astype(jnp.float32)

# ``sq_err`` evaluated independently per (batch, head): the outer vmap maps the
# leading axis of both arguments and the inner vmap maps axis 1 of the rest.
mha_sq_err = vmap(vmap(sq_err, in_axes=1))

def normalize_outputs(vs):
    """Scale every vector along the last axis to (approximately) unit length.

    A small constant is added to the norm so all-zero vectors map to zero
    instead of producing a division by zero.
    """
    lengths = jnp.linalg.norm(vs, axis=-1, keepdims=True)
    safe_lengths = lengths + 1e-6
    return vs / safe_lengths

def axis_corr(ref, x):
    """Return the correlation between ``ref`` and ``x`` along their longest axis.

    Both inputs are first normalised row-wise with ``normalize_outputs``. The
    Pearson-style correlation (with a small stabiliser in the denominator) is
    computed along the longest axis of ``ref`` and then averaged over the last
    remaining axis. The result is float32.
    """
    ref_unit = normalize_outputs(ref)
    x_unit = normalize_outputs(x)

    # The longest axis is taken to be the token axis.
    axis = np.argmax(ref_unit.shape)
    assert ref_unit.shape[axis] >= max(ref_unit.shape)

    ref_centered = ref_unit - jnp.mean(ref_unit, axis=axis, keepdims=True)
    x_centered = x_unit - jnp.mean(x_unit, axis=axis, keepdims=True)
    ref_var = jnp.mean(jnp.square(ref_centered), axis=axis)
    x_var = jnp.mean(jnp.square(x_centered), axis=axis)
    covariance = jnp.mean(ref_centered * x_centered, axis=axis)
    corr = covariance / (jnp.sqrt(ref_var * x_var) + 1e-6)
    return corr.astype(jnp.float32).mean(axis=-1)  # average over the V dimension

# ``axis_corr`` evaluated independently per (batch, head): the outer vmap maps
# the leading axis of both arguments and the inner vmap maps axis 1 of the rest.
mha_axis_corr = vmap(vmap(axis_corr, in_axes=1))

def do_clustering(qs, ks, vs):
    """Cluster one head's queries and keys with the Pallas clustering kernel.

    Uses 64 clusters and 5 iterations, and caps each cluster at
    ``ceil(4.0 * N / 64)`` members, where ``N`` is the first dimension of ``qs``.

    Returns:
        ``(qmeta, kmeta)`` as produced by ``pallas_do_clustering``.
    """
    num_clusters = 64
    num_iterations = 5
    expand_ratio = 4.0
    num_tokens, _ = qs.shape
    # Allow clusters to grow to expand_ratio times the perfectly balanced size.
    capacity = math.ceil(expand_ratio * num_tokens / num_clusters)
    qmeta, kmeta = pallas_do_clustering(num_clusters, capacity, num_iterations, qs, ks, vs)
    return qmeta, kmeta

# ``do_clustering`` applied per (batch, head): the outer vmap maps the leading
# axis of every argument; the inner vmap maps axis 1 of the rest and stacks the
# per-head results on axis 1 of each output leaf.
mha_do_clustering = vmap(vmap(do_clustering, in_axes=1, out_axes=1))

def get_data(dtype, n, only_first_batch=False):
    """Load saved q/k/v activations and cut every sequence into chunks of length ``n``.

    The arrays are read from a fixed ``.npz`` checkpoint with layout
    ``(B, N, H, D)``. Each sequence of length ``N`` is split into ``N // n``
    consecutive chunks that are folded into the batch axis, giving
    ``(B * N // n, n, H, D)``.

    Args:
        dtype: dtype of the returned arrays.
        n: chunk length; must divide the stored sequence length.
        only_first_batch: when true, keep only the first batch element (a
            warning with the original shape is printed).

    Returns:
        ``(qs, ks, vs)`` cast to ``dtype``.
    """
    archive = jnp.load("examples/minigpt/qkv_data_64_T1/qkv_step_15200.npz")
    ks = archive["k"]  # (B, N, H, D)
    vs = archive["v"]
    qs = archive["q"]
    if only_first_batch:
        print(f"WARNING: only using first batch of data, orirginal shape {ks.shape}")
        ks, vs, qs = ks[:1], vs[:1], qs[:1]

    _, N, _, _ = ks.shape
    _, N, _, _ = vs.shape

    # Number of leading feature dimensions to zero out (0 disables culling).
    cull = 0
    ks = jnp.array(ks).at[..., :cull].set(0.)
    qs = jnp.array(qs).at[..., :cull].set(0.)

    # Disabled experiment: keys/values from the first half of the sequence,
    # queries from the second half.
    use_split_halves = False
    if use_split_halves:
        ks = ks[..., :N//2, :, :]
        vs = vs[..., :N//2, :, :]
        qs = qs[..., -N//2:, :, :]
        N = N // 2

    assert N % n == 0, f"N={N} must be divisible by n={n}"
    M = N // n

    def fold_chunks(pattern, array):
        return einshape(pattern, array, m=M, n=n).astype(dtype)

    qs_out = fold_chunks("b(mn)hd->(bm)nhd", qs)
    ks_out = fold_chunks("b(mn)hd->(bm)nhd", ks)
    vs_out = fold_chunks("b(mn)hv->(bm)nhv", vs)
    return qs_out, ks_out, vs_out

def main():
    """Run the attention-approximation study on recorded q/k/v data.

    In order, the driver:

    1. loads the data with ``get_data``, subtracts from q and k their mean
       over the sequence axis, and divides both by ``sqrt(TEMP)``;
    2. selects one (batch, head) slice, clusters it with ``do_clustering``
       and prints ``sq_err`` / ``axis_corr`` of ``pallas_fma`` and of each
       reference approximation against ``pallas_flash_noscale``;
    3. prints selectivity statistics and calls
       ``analyze_retrieval_thresholds``;
    4. runs ``random_backward_probe`` for exact attention and several
       reference approximations, prints their errors and writes three plots
       under ``visualizations/`` (the directory must already exist);
    5. only when ``RUN_MHA_COMPARISON`` is set, compares exact and
       approximate attention over every batch element and head.

    Step 5 used to sit behind an unconditional ``exit()`` and was therefore
    unreachable; it is now gated by a flag (off by default, so the default
    run is unchanged) and the function returns normally instead of
    terminating the interpreter.
    """
    dtype = jnp.float32
    requested_n = 2**13
    ONLY_FIRST_BATCH = False
    TEMP = 1.0
    BATCH_ELEM = 2
    HEAD = 7

    # Section toggles; the defaults reproduce the previous behaviour.
    RUN_REFERENCE_COMPARISON = True
    RUN_SELECTIVITY_STATS = True
    RUN_RETRIEVAL_THRESHOLDS = True
    RUN_RANDOM_PROBE = True
    RUN_MHA_COMPARISON = False

    qs, ks, vs = get_data(dtype, requested_n, only_first_batch=ONLY_FIRST_BATCH)
    # Axis -3 is the sequence axis of the (batch, seq, head, dim) layout
    # returned by get_data.
    qs = qs - qs.mean(axis=-3, keepdims=True)
    ks = ks - ks.mean(axis=-3, keepdims=True)
    qs = qs / jnp.sqrt(TEMP).astype(dtype)
    ks = ks / jnp.sqrt(TEMP).astype(dtype)
    _, N, _, D = qs.shape
    sm_scale = 1.0 / math.sqrt(D)

    def one_head(array):
        return array[BATCH_ELEM, :, HEAD, :]

    oh_qs, oh_ks, oh_vs = one_head(qs), one_head(ks), one_head(vs)
    oh_qmeta, oh_kmeta = do_clustering(oh_qs, oh_ks, oh_vs)

    oh_exact_lse, oh_exact_out = pallas_flash_noscale(oh_qs, oh_ks, oh_vs, sm_scale=sm_scale)
    oh_approx_lse, oh_approx_out = pallas_fma(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)
    print(f"One head exact vs approx lse sq err: {sq_err(oh_exact_lse, oh_approx_lse)}")
    print(f"One head exact vs approx out sq err: {sq_err(oh_exact_out, oh_approx_out)}")
    print(f"One head exact vs approx out axis corr: {axis_corr(oh_exact_out, oh_approx_out)}")

    if RUN_REFERENCE_COMPARISON:
        # (label used in the printed report, reference implementation).
        # Evaluated and reported one at a time, in this order.
        reference_methods = (
            ("gdipole", ref_global_dipole),
            ("gquad", ref_global_quadrupole),
            ("gperf", ref_gperf),
            ("coarseq", ref_coarseq_attn),
            ("monopole", ref_monopole_attn),
            ("monopole retrieval", ref_monopole_retrieval_attn),
            ("dipole", ref_dipole_attn),
            ("dipole retrieval", ref_dipole_retrieval_attn),
            ("dipole init full", ref_dipole_init_full_attn),
            ("dipole full", ref_dipole_full_attn),
            ("full dipole retrieval", ref_full_dipole_retrieval_attn),
            ("dipole fkcov", ref_dipole_fkcov_attn),
        )
        for label, ref_fn in reference_methods:
            ref_lse, ref_out = ref_fn(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)
            print(f"One head flash vs ref {label} lse sq err: {sq_err(oh_exact_lse, ref_lse)}")
            print(f"One head flash vs ref {label} out sq err: {sq_err(oh_exact_out, ref_out)}")
            print(f"One head flash vs ref {label} out axis corr: {axis_corr(oh_exact_out, ref_out)}")

    if RUN_SELECTIVITY_STATS:
        msqaw = float(approx_meansq_attn_weight(oh_qs, oh_ks))
        effective_n = 1.0 / msqaw
        selectivity = N*msqaw
        print(f"One head msqaw: {msqaw:.4f}, effective_n: {effective_n:.0f}, selectivity: {selectivity:.0f}")

    if RUN_RETRIEVAL_THRESHOLDS:
        analyze_retrieval_thresholds(oh_qs, oh_ks)

    if RUN_RANDOM_PROBE:
        rn_lse, rn_weights, qidx = random_backward_probe(pallas_flash, oh_qs, oh_ks, oh_vs, return_index=True)
        rn_msqaw = jnp.sum(jnp.square(rn_weights))
        rn_effective_n = 1.0 / rn_msqaw
        rn_selectivity = N * rn_msqaw
        print(f"Random probe msqaw: {rn_msqaw:.4f}, effective_n: {rn_effective_n:.0f}, selectivity: {rn_selectivity:.0f}, at qidx {qidx}")
        rn_weight_per_k = jax.ops.segment_sum(rn_weights, oh_kmeta.lab, num_segments=oh_kmeta.cnt.shape[0])

        # Cumulative share of the squared weights, largest weights first.
        sq_weights = jnp.square(rn_weights)
        cum_sq_weights = jnp.cumsum(jnp.sort(sq_weights, axis=0)[::-1], axis=0) / jnp.sum(sq_weights)
        plt.plot(np.array(cum_sq_weights.astype(jnp.float32)))
        plt.title(f"Head {HEAD} batch {BATCH_ELEM} cumulative squared attention weights for random probe")
        plt.xscale("log")
        plt.savefig("visualizations/rn_cum_sq_weights.png")

        # Start from empty axes: without this the squared-weight curve above
        # is drawn into the next image as well.
        plt.clf()
        cum_weights = jnp.cumsum(jnp.sort(rn_weights, axis=0)[::-1], axis=0)
        plt.plot(np.array(cum_weights.astype(jnp.float32)))
        plt.title(f"Head {HEAD} batch {BATCH_ELEM} cumulative attention weights for random probe")
        plt.xscale("log")
        plt.savefig("visualizations/rn_cum_weights.png")

        # Clear the small figure and open a large one for the ratio plot.
        plt.clf()
        plt.figure(figsize=(30, 18))

        # (tag used in the printed report, attention function to probe).
        # Other attention functions with the same call signature can be
        # probed by adding a row here.
        probe_methods = (
            ("mon", ref_monopole_attn),
            ("mre", ref_monopole_retrieval_attn),
            ("dip", ref_dipole_attn),
            ("dre", ref_dipole_retrieval_attn),
            ("coq", ref_coarseq_attn),
            ("difu", ref_dipole_init_full_attn),
            ("dfu", ref_dipole_full_attn),
            ("fdr", ref_full_dipole_retrieval_attn),
            ("fkc", ref_dipole_fkcov_attn),
        )
        # tag -> probe weights rescaled by exp(lse - flash lse).
        adjusted = {}
        for tag, probe_fn in probe_methods:
            probe_lse, probe_weights = random_backward_probe(probe_fn, oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)
            # rn_msqaw is sum(rn_weights ** 2), the normaliser of the
            # relative squared error.
            relsqerr = jnp.sum(jnp.square(rn_weights - probe_weights)) / rn_msqaw
            print(f"Random probe flash vs {tag} weights rel sq err: {relsqerr:.6f}")
            print(f"Random probe flash vs {tag} lse: {rn_lse:.3f} vs {probe_lse:.3f}")
            adjusted[tag] = jnp.exp(probe_lse - rn_lse) * probe_weights
        coq_adjusted = adjusted["coq"]

        # Result currently unused: it feeds an alternative cluster score
        # (coarse-q lse per key cluster minus log cluster count) that is not
        # enabled.  The call is kept so execution is unchanged.
        _coq_lse_per_k, _coq_v_per_k = ref_coarseq_per_kcluster(oh_qs, oh_ks, oh_vs, oh_qmeta, oh_kmeta)

        # Rank key clusters by probe weight per cluster divided by
        # oh_kmeta.cnt (presumably the cluster sizes), highest first.
        cluster_score = rn_weight_per_k / oh_kmeta.cnt
        kcluster_sort = jnp.argsort(cluster_score)[::-1]
        kcluster_rank = jnp.zeros_like(kcluster_sort).at[kcluster_sort].set(jnp.arange(kcluster_sort.shape[0]))

        # Order keys by descending flash/coarseq ratio, then stably regroup
        # them by cluster rank: clusters appear best-first, and inside each
        # cluster the keys stay in descending-ratio order.
        sort_indices = jnp.argsort(rn_weights/coq_adjusted, axis=0)[::-1]
        sorted_klabels = kcluster_rank[oh_kmeta.lab][sort_indices]
        sort_indices = sort_indices[jnp.argsort(sorted_klabels, axis=0, stable=True)]

        # x position of a key = share of the total coarseq weight carried by
        # the keys plotted before it (0 for the first key).
        xpos_template = np.arange(rn_weights.shape[0])/rn_weights.shape[0]
        coq_cumsum = jnp.cumsum(coq_adjusted[sort_indices])
        xpos = jnp.zeros_like(xpos_template).at[1:].set(coq_cumsum[:-1]/coq_cumsum[-1])

        def ratio_to_coarseq(weights):
            # Per-key weight relative to the coarseq weight, in plot order.
            return np.array((weights/coq_adjusted)[sort_indices].astype(jnp.float32))

        plt.plot(xpos, ratio_to_coarseq(rn_weights), label="flash / coarseq")
        plotted_methods = (
            ("mon", "monopole / coarseq"),
            ("mre", "monopole_retrieval / coarseq"),
            ("dre", "dipole_retrieval / coarseq"),
            ("fdr", "full_dipole_retrieval / coarseq"),
        )
        for tag, legend_label in plotted_methods:
            plt.plot(xpos, ratio_to_coarseq(adjusted[tag]), label=legend_label, alpha=0.5)
        plt.yscale("log")
        plt.title(f"Head {HEAD} batch {BATCH_ELEM} flash vs fma attention weights for random probe")
        plt.xlim(0, 1.0)
        plt.ylim(1e-3, 1e4)
        plt.legend()
        plt.savefig("visualizations/rn_flash_vs_fma_weights.png")

    if not RUN_MHA_COMPARISON:
        return

    show_scores(oh_qs, oh_ks, oh_qmeta, oh_kmeta)

    # Clustering every batch element and head is only needed for this
    # comparison, so it is done here rather than up front.
    qmeta, kmeta = mha_do_clustering(qs, ks, vs)
    mha_pallas_flash = vmap(vmap(partial(pallas_flash_noscale, sm_scale=sm_scale), in_axes=1, out_axes=1))

    exact_lse, exact_out = mha_pallas_flash(qs, ks, vs)
    approx_lse, approx_out = mha_pallas_fma(qs, ks, vs, qmeta, kmeta)
    lse_sq_err = mha_sq_err(exact_lse, approx_lse)
    out_sq_err = mha_sq_err(exact_out, approx_out)
    out_axis_corr = mha_axis_corr(exact_out, approx_out)
    # For the errors "worst" is the largest value; for the correlation it
    # is the smallest.
    summaries = (
        ("lse sq err", lse_sq_err, jnp.max, jnp.argmax),
        ("out sq err", out_sq_err, jnp.max, jnp.argmax),
        ("out axis corr", out_axis_corr, jnp.min, jnp.argmin),
    )
    for label, values, worst, arg_worst in summaries:
        print(f"MHA exact vs approx {label}: mean {values.mean()}, worst {worst(values)}, at {jnp.unravel_index(arg_worst(values), values.shape)}")

    mha_per_head_corr = jnp.mean(out_axis_corr, axis=0)
    mha_worst_to_best_heads = jnp.argsort(mha_per_head_corr)
    print(f"MHA per-head axis corr (worst to best): {mha_per_head_corr[mha_worst_to_best_heads]}, heads {mha_worst_to_best_heads}")
    mha_per_batch_corr = jnp.mean(out_axis_corr, axis=1)
    mha_worst_to_best_batches = jnp.argsort(mha_per_batch_corr)
    print(f"MHA per-batch axis corr (worst to best): {mha_per_batch_corr[mha_worst_to_best_batches]}, batches {mha_worst_to_best_batches}")

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
