import jax
from jax import numpy as jnp
from jax.numpy import einsum
from jax.nn import softmax
import numpy as np
from chex import Array, dataclass
from typing import Optional, Tuple, Callable
from functools import partial
from jax.scipy.special import logsumexp
from einshape import jax_einshape as einshape
from jax.tree_util import tree_map
from jax.random import PRNGKey
from jax import random
from math import prod
from jax.experimental import pallas as pl

DIPOLE = True # Module-wide switch: when True, the leaf functions that read this flag add the extra einsum term built from the value-key centroids (vkcent), i.e. the "dipole" correction

def _kmeans_centroids(
    xs: Array, # (N, D)
    labels: Array, # (N,) in [0, k)
    k: int, # number of clusters
    logmass: Array, # (N,) weights for each point
    ) -> Array: # (k, D) centroids
    """
    Compute one weighted centroid per cluster.

    For every cluster id in [0, k) the points carrying that label are averaged,
    with weights given by a softmax of `logmass` restricted to the members of
    the cluster (via the `where` argument of softmax).
    """
    def weighted_mean(cluster_id):
        membership = labels == cluster_id # (N,) True for members of this cluster
        point_weights = softmax(logmass, where=membership) # (N,)
        return einsum("n,nd->d", point_weights, xs) # (D,)

    cluster_ids = jnp.arange(k) # (k,)
    return jax.vmap(weighted_mean)(cluster_ids) # (k, D)

def _kmeans_assign_no_metric(
    xs: Array, # (N, D)
    centroids: Array, # (k, D)
    ) -> Array: # (N,) new labels
    """
    Label every point with the index of its nearest centroid, measured by
    plain squared Euclidean distance.
    """
    def nearest(point):
        offsets = point - centroids # (k, D) offset of the point from each centroid
        squared = einsum("kd,kd->k", offsets, offsets) # (k,)
        return jnp.argmin(squared) # ()

    return jax.vmap(nearest)(xs) # (N,)

def _kmeans_assign(
    xs: Array, # (N, D)
    dual_xs: Array, # (N, D) contracted with metric
    centroids: Array, # (k, D)
    dual_centroids: Array, # (k, D) contracted with metric
    ) -> Array: # (N,) new labels
    """
    Label every point with the centroid whose metric "squared distance" has
    the smallest magnitude.

    The distance is the inner product of the plain offset with the dual
    (metric-contracted) offset. Its absolute value is minimised, so the result
    is well defined even when that product is negative.
    """
    def nearest(point, dual_point):
        offsets = point - centroids # (k, D)
        dual_offsets = dual_point - dual_centroids # (k, D)
        interval = einsum("kd,kd->k", offsets, dual_offsets) # (k,)
        return jnp.argmin(jnp.abs(interval)) # ()

    return jax.vmap(nearest)(xs, dual_xs) # (N,)

def kmeans(
    xs: Array, # (N, D)
    labels: Array, # (N,) in [0, k)
    k: int, # number of clusters
    iters: int, # number of iterations
    logmass: Array, # (N,) logmass for each point
    metric: Optional[Array] = None, # (D,D) metric for the distance, if None then use identity matrix
    ) -> Array: # (N,) new labels
    """
    Refine an existing labelling with `iters` rounds of weighted k-means.

    Each round recomputes the softmax(logmass)-weighted centroid of every
    cluster and then reassigns every point to its nearest centroid. When
    `metric` is given, the distance is the quadratic form induced by the full
    (D, D) matrix; otherwise plain squared Euclidean distance is used.

    With `iters == 0` the input labels are returned unchanged.
    """
    use_metric = metric is not None
    if use_metric:
        # A real matrix contraction. The subscripts "nd,dd->nd" would repeat an
        # index inside one operand, which einsum interprets as taking the
        # diagonal, silently dropping every off-diagonal entry of the metric.
        dual_xs = einsum("nd,de->ne", xs, metric) # (N, D) contracted with metric
    for _ in range(iters):
        centroids = _kmeans_centroids(xs, labels, k, logmass) # (k, D) centroids
        if use_metric:
            dual_centroids = einsum("kd,de->ke", centroids, metric) # (k, D) contracted with metric
            labels = _kmeans_assign(xs, dual_xs, centroids, dual_centroids) # (N,) new labels
        else:
            labels = _kmeans_assign_no_metric(xs, centroids) # (N,) new labels
    return labels # (N,) new labels

def kmeans_with_init(
    xs: Array, # (N, D)
    key: PRNGKey, # random key for initialization
    k: int, # number of clusters
    iters: int, # number of iterations
    logmass: Array, # (N,) logmass for each point
    metric: Optional[Array] = None, # (D,D) metric for the distance, if None then use identity matrix
    ) -> Array: # (N,) new labels
    """
    Weighted k-means with random initialisation.

    `k` distinct points are drawn from `xs` (without replacement, with
    probability softmax(logmass)) and used as the initial centroids. Every
    point is then assigned to its nearest centroid, followed by `iters`
    rounds of centroid update and reassignment.

    `logmass=None` is accepted and treated as uniform weights. When `metric`
    is given, the distance is the quadratic form of the full (D, D) matrix.
    """
    if logmass is None:
        logmass = jnp.zeros(xs.shape[0])
    weights = softmax(logmass) # (N,) sampling probabilities
    centroids = random.choice(key, xs, shape=(k,), replace=False, p=weights) # (k, D) random initialization

    if metric is not None:
        # A real matrix contraction. "nd,dd->nd" repeats an index inside one
        # operand, which einsum reads as the diagonal of the metric only.
        dual_xs = einsum("nd,de->ne", xs, metric) # (N, D) contracted with metric

        def assign(current_centroids):
            dual_centroids = einsum("kd,de->ke", current_centroids, metric) # (k, D)
            return _kmeans_assign(xs, dual_xs, current_centroids, dual_centroids)
    else:
        def assign(current_centroids):
            return _kmeans_assign_no_metric(xs, current_centroids)

    labels = assign(centroids) # (N,) initial labels
    for _ in range(iters):
        labels = assign(_kmeans_centroids(xs, labels, k, logmass))
    return labels # (N,) new labels

def accum_kmeans(
    xs: Array, # (N, D)
    centroids: Array, # (k, D) initial centroids
    ) -> Array: # (N,) new labels
    """
    Single-pass, mini-batch style k-means.

    The points are consumed in batches of BATCH_SIZE. For each batch, points
    are assigned to the nearest running centroid (squared Euclidean distance)
    and then folded into that centroid's running total and count. Each initial
    centroid counts as one pseudo-observation, so no count is ever zero.
    After all batches, every point is labelled with its nearest final centroid.

    Raises AssertionError when N is not a multiple of BATCH_SIZE.
    """
    N, D = xs.shape
    K = centroids.shape[0]
    BATCH_SIZE = 2**8 # batch size for the accumulation
    assert N % BATCH_SIZE == 0, f"N {N} must be divisible by BATCH_SIZE {BATCH_SIZE}"
    NUM_BATCHES = N // BATCH_SIZE # number of batches
    # (BATCH_SIZE, NUM_BATCHES, D): batch b is the column [:, b], i.e. the
    # points b, b + NUM_BATCHES, b + 2 * NUM_BATCHES, ... of xs.
    batched_xs = jnp.reshape(xs, (BATCH_SIZE, NUM_BATCHES, D))
    initial_counts = jnp.ones((K,), dtype=centroids.dtype) # (k,) one pseudo-count per initial centroid

    def do_for_b(b, val):
        totals, counts = val # (k, D) running sums, (k,) running counts
        local_xs = batched_xs[:, b] # (BATCH_SIZE, D) local xs
        current = totals / counts[:, None] # (k, D) current centroids
        diffs = local_xs[None, :, :] - current[:, None, :] # (k, BATCH_SIZE, D) differences
        labels = jnp.argmin(einsum("kbd,kbd->kb", diffs, diffs), axis=0) # (BATCH_SIZE,) new labels
        totals = totals.at[labels].add(local_xs) # (k, D) update running sums
        counts = counts.at[labels].add(1.) # (k,) update counts
        return totals, counts

    totals, counts = jax.lax.fori_loop(0, NUM_BATCHES, do_for_b, (centroids, initial_counts))
    final_centroids = totals / counts[:, None] # (k, D) final centroids
    # Nearest centroid by squared distance: |x|^2 is constant per point, so
    # argmin_k |x - c_k|^2 == argmin_k (|c_k|^2 - 2 x.c_k). Taking the argmin
    # of the raw dot product would instead pick the least similar centroid.
    centroid_sqmags = jnp.sum(final_centroids**2, axis=-1)[None, :] # (1, k)
    similarities = einsum("nd,kd->nk", xs, final_centroids) # (N, k)
    labels = jnp.argmin(centroid_sqmags - 2. * similarities, axis=1) # (N,) new labels
    return labels # (N,) new labels

def assign_pallas(
    xs: Array, # (N, D)
    centroids: Array, # (k, D)
    ) -> Array: # (N,) new labels
    """
    Nearest-centroid assignment expressed as a Pallas kernel.

    The points are split into N // BLOCK_SIZE blocks of BLOCK_SIZE rows. Each
    grid step receives one block of points plus the full centroid array and
    writes the index of the closest centroid (squared Euclidean distance) for
    every point of the block.

    Raises AssertionError when N is not a multiple of BLOCK_SIZE.
    NOTE(review): whether `jnp.argmin` lowers inside the kernel depends on the
    Pallas backend in use — confirm on the target platform.
    """

    def assign_kernel(xs_ref, centroids_ref, out_ref):
        # Load the whole block of points and all centroids handed to this grid step.
        xs = xs_ref[:]
        centroids = centroids_ref[:]
        diffs = xs[None, :, :] - centroids[:, None, :] # (k, BLOCK_SIZE, D) differences
        sqdists = jnp.sum(diffs * diffs, axis=-1) # (k, BLOCK_SIZE) squared distances
        labels = jnp.argmin(sqdists, axis=0) # (BLOCK_SIZE,) index of the closest centroid
        out_ref[:] = labels

    BLOCK_SIZE = 16 # size of the block to process at once
    N, D = xs.shape
    K = centroids.shape[0]
    assert N % BLOCK_SIZE == 0, f"N {N} must be divisible by BLOCK_SIZE {BLOCK_SIZE}"
    num_blocks = N // BLOCK_SIZE # number of blocks

    return pl.pallas_call(
        assign_kernel,
        out_shape=jax.ShapeDtypeStruct((N,), jnp.int32), # output shape and dtype
        grid=(num_blocks,), # one grid step per block of points
        in_specs=(
            pl.BlockSpec((BLOCK_SIZE, D), lambda i: (i, 0)), # grid step i reads the i-th block of xs
            pl.BlockSpec((K, D), lambda i: (0, 0)), # every grid step reads the full centroid array
        ),
        out_specs=pl.BlockSpec((BLOCK_SIZE,), lambda i: (i,)), # grid step i writes the i-th block of labels
    )(xs, centroids)

def matmul_assign(
    xs: Array, # (N, D)
    centroids: Array, # (k, D)
    ) -> Array: # (N,) new labels
    """
    Nearest-centroid assignment via a single matrix product.

    Minimises |c|^2 - 2 x.c over the centroids. This ranks centroids exactly
    like the squared distance |x - c|^2, because |x|^2 does not depend on the
    centroid.
    """
    dots = einsum("nd, kd->nk", xs, centroids) # (N, k) dot products with centroids
    sq_norms = jnp.sum(centroids**2, axis=-1) # (k,) squared magnitudes of centroids
    scores = -2.*dots + sq_norms[None,:] # (N, k) squared distance up to a per-point constant
    return jnp.argmin(scores, axis=-1) # (N,) new labels



def uniform_kmeans(
    xs: Array, # (N, D)
    key: PRNGKey, # random key for initialization
    k: int, # number of clusters
    iters: int, # number of iterations
    ) -> Array: # (N,) new labels
    """
    Unweighted k-means with random initialisation.

    `k` distinct rows of `xs` are drawn uniformly as initial centroids. Points
    are then assigned with `matmul_assign`, followed by `iters` rounds of
    centroid update (plain mean of the members) and reassignment.

    A cluster that loses all of its members keeps its previous centroid. The
    masked mean over an empty selection is NaN, and a NaN centroid would
    otherwise win every subsequent argmin.
    """
    cluster_ids = jnp.arange(k) # (k,)

    def update_centroids(labels, previous):
        def one_cluster(cluster_id, previous_centroid):
            members = labels == cluster_id # (N,)
            mean = jnp.mean(xs, where=members[:, None], axis=0) # (D,) NaN when the cluster is empty
            return jnp.where(jnp.any(members), mean, previous_centroid) # (D,)
        return jax.vmap(one_cluster)(cluster_ids, previous) # (k, D)

    centroids = random.choice(key, xs, shape=(k,), replace=False) # (k, D) random initialization
    labels = matmul_assign(xs, centroids) # (N,) initial labels
    for _ in range(iters):
        centroids = update_centroids(labels, centroids) # (k, D) centroids
        labels = matmul_assign(xs, centroids) # (N,) new labels
    return labels # (N,) new labels


def wasabi_level_zero(qcentroids, keys, klabels, values, K, n=None):
    """
    Summarise every key cluster separately for every query centroid.

    Arguments: qcentroids (Q, D), keys (N, D), klabels (N,) in [0, K),
    values (N, V), K number of key clusters. `n` is accepted for interface
    compatibility and is not used.

    Returns (kcentroids, vcentroids, vkcentroids, logmass):
      kcentroids  (Q, K, D)    per-query softmax-weighted mean key of each cluster
      vcentroids  (Q, K, V)    per-query softmax-weighted mean value of each cluster
      vkcentroids (Q, K, V, D) value-key covariance of each cluster, computed
                               with the query-averaged weights and broadcast
                               over the query axis
      logmass     (Q, K)       logsumexp of the scores inside each cluster
    """
    Q, D = qcentroids.shape
    N, V = values.shape
    scores = einsum("qd,nd->qn", qcentroids, keys) # (Q, N)

    def segment_weights(qscores):
        # Numerically stable per-cluster softmax: shift by the cluster maximum first.
        max_score = jax.ops.segment_max(qscores, klabels, num_segments=K) # (K,) max score for each cluster
        exp_safe_scores = jnp.exp(qscores - max_score[klabels]) # (N,) shifted exponentials
        cluster_sums = jax.ops.segment_sum(exp_safe_scores, klabels, num_segments=K) # (K,)
        cluster_logsumexp = max_score + jnp.log(cluster_sums) # (K,) logsumexp for each cluster
        weights = exp_safe_scores / cluster_sums[klabels] # (N,) sums to one inside each cluster
        return weights, cluster_logsumexp

    all_weights, logmass = jax.vmap(segment_weights)(scores) # (Q, N), (Q, K)

    def compute_centroids(weights):
        kcentroids = jax.ops.segment_sum(weights[:, None] * keys, klabels, num_segments=K) # (K, D)
        vcentroids = jax.ops.segment_sum(weights[:, None] * values, klabels, num_segments=K) # (K, V)
        return kcentroids, vcentroids

    kcentroids, vcentroids = jax.vmap(compute_centroids)(all_weights) # (Q, K, D), (Q, K, V)

    # The covariance term uses weights averaged over the queries, so it is
    # computed once per key cluster and shared by every query.
    mean_weights = jnp.mean(all_weights, axis=0) # (N,)

    def compute_kvd(k):
        weights = jnp.where(klabels == k, mean_weights, 0.0) # (N,) weights restricted to cluster k
        kcent = einsum("n,nd->d", weights, keys) # (D,) key centroid
        vcent = einsum("n,nv->v", weights, values) # (V,) value centroid
        second_moment = einsum("n,nv,nd->vd", weights, values, keys) # (V, D)
        return second_moment - einsum("v,d->vd", vcent, kcent) # (V, D) covariance

    k_vkcentroids = jax.vmap(compute_kvd)(jnp.arange(K)) # (K, V, D)
    vkcentroids = jnp.broadcast_to(k_vkcentroids[None, :, :, :], (Q, K, V, D)) # (Q, K, V, D)
    return kcentroids, vcentroids, vkcentroids, logmass # (Q, K, D), (Q, K, V), (Q, K, V, D), (Q, K)

def wasabi_keys_leaf(qcentroids, keys, logmass, values) -> Tuple[Array, Array, Array, Array]:
    """
    Summarise one set of keys/values for all query centroids at once.

    Inputs: qcentroids (Q, D), keys (N, D), logmass (N,), values (N, V).
    Returns (kcent, vcent, vkcent, total_logmass) with shapes
    (Q, D), (Q, V), (V, D), (Q,).

    Note that vkcent has no query axis: it is the value-key covariance computed
    from the query-averaged weights and the query-averaged centroids.
    """
    scores = logmass[None, :] + einsum("qd,nd->qn", qcentroids, keys) # (Q, N)
    attn = softmax(scores, axis=-1) # (Q, N) weights for each key
    kcent = einsum("qn,nd->qd", attn, keys) # (Q, D)
    vcent = einsum("qn,nv->qv", attn, values) # (Q, V)

    Q, D = qcentroids.shape
    N, V = values.shape

    # Average everything over the query axis before forming the covariance.
    avg_attn = jnp.mean(attn, axis=0) # (N,)
    avg_vcent = jnp.mean(vcent, axis=0) # (V,)
    avg_kcent = jnp.mean(kcent, axis=0) # (D,)
    second_moment = einsum("n,nv,nd->vd", avg_attn, values, keys) # (V, D)
    vkcent = second_moment - einsum("v,d->vd", avg_vcent, avg_kcent) # (V, D)
    assert vkcent.shape == (V, D), f"vkcent shape {vkcent.shape} does not match expected shape {(V, D)}"

    total_logmass = logsumexp(scores, axis=-1) # (Q,) total logmass for each query
    return kcent, vcent, vkcent, total_logmass


def level_zero(
    qcentroids: Array, # (Q, D)
    keys: Array, # (N, D)
    klabels: Array, # (N,) in [0, K)
    values: Array, # (N, V)
    K: int, # number of keys
    n: int, # size to downsample to, if None then use 2N//K 
    fwd_indices: Optional[Array] = None, # (K, n) precomputed indices for key collation, if None then construct from klabels
    ) -> Tuple[
        Array, # (Q, K, D)
        Array, # (Q, K, V)
        Array, # (K, V, D)
        Array, # (Q, K)
        ]:
    """
    Level zero of the hierarchical attention, takes full key array and labels.

    For every key cluster, up to `n` member keys/values are gathered (padding
    slots are masked out with a logmass of -inf) and summarised for all query
    centroids by `wasabi_keys_leaf`.

    Returns key centroids (Q, K, D), value centroids (Q, K, V), value-key
    centroids (K, V, D) and total logmass (Q, K). The value-key centroids carry
    no query axis because `wasabi_keys_leaf` returns a single (V, D) matrix per
    cluster.

    When `fwd_indices` is given it is used as-is (padding is expected to be
    marked with negative indices) and `n` is ignored. Otherwise the indices
    are built from `klabels`, and `n=None` falls back to 2N//K.
    """
    if fwd_indices is None and n is None:
        n = 2 * keys.shape[0] // K # documented default for the per-cluster slot count

    def do_cluster_k(k):
        if fwd_indices is not None:
            indices = fwd_indices[k]
        else:
            (indices,) = jnp.where(klabels == k, size=n, fill_value=-1)
        # Padding slots (index -1) gather an arbitrary row; the -inf logmass removes
        # them from the softmax inside the leaf.
        local_keys = keys[indices] # (n, D)
        local_logmass = jnp.where(indices >= 0, 0.0, -jnp.inf) # (n,)
        local_values = values[indices] # (n, V)
        return wasabi_keys_leaf(qcentroids, local_keys, local_logmass, local_values)

    kcent, vcent, vkcent, total_logmass = jax.lax.map(
        do_cluster_k,
        jnp.arange(K),
        batch_size=256,
    ) # (K, Q, D), (K, Q, V), (K, V, D), (K, Q)

    def transpose(arr):
        return einshape("kq...->qk...", arr)
    # vkcent has no query axis, so it is returned without swapping axes.
    return transpose(kcent), transpose(vcent), vkcent, transpose(total_logmass)

def level_zero_reclustering(
    qcentroids: Array, # (Q, D)
    keys: Array, # (N, D)
    klabels: Array, # (N,) in [0, K)
    values: Array, # (N, V)
    K: int, # number of keys
    n: int, # size to downsample to, if None then use 2N//K
    metrics: Optional[Array] = None, # (Q, D, D) metrics for the distance, if None then use identity matrix
    ) -> Tuple[
        Array, # (Q, K, D)
        Array, # (Q, K, V)
        Array, # (Q, K, V, D)
        Array, # (Q, K)
    ]:
    """
    Level zero of the hierarchical attention with per-query reclustering.

    For every query centroid the keys are re-clustered with 16 iterations of
    `kmeans`, starting from `klabels` and weighted by the query-key scores
    (using that query's metric when `metrics` is given). Each resulting cluster
    is then summarised by `keys_leaf` from at most `n` gathered members.
    Returns query clusters at index 0 and key clusters at index 1.
    """
    def per_query(query, metric=None):
        key_scores = einsum("d,nd->n", query, keys) # (N,) logmass for each key
        reclustered = kmeans(
            xs=keys,
            labels=klabels,
            k=K,
            iters=16,
            logmass=key_scores,
            metric=metric, # None selects the plain Euclidean assignment
        ) # (N,) new labels

        def per_cluster(cluster_id):
            (member_idx,) = jnp.where(reclustered == cluster_id, size=n, fill_value=-1)
            # -inf removes the padding slots (index -1) from the leaf's softmax
            padding_logmass = jnp.where(member_idx >= 0, 0.0, -jnp.inf)
            return keys_leaf(query, keys[member_idx], padding_logmass, values[member_idx])

        return jax.vmap(per_cluster)(jnp.arange(K))

    for_all_queries = jax.vmap(per_query)
    if metrics is None:
        return for_all_queries(qcentroids)
    return for_all_queries(qcentroids, metrics) # (Q, K, D) key centroid, (Q, K, V) value centroid, (Q, K, V, D) value-key centroid, (Q, K) total_logmass
        

@partial(jax.named_call, name="keys_leaf")
def keys_leaf(
    query: Array, # (D,)
    keys: Array, # (N, D)
    logmass: Array, # (N,)
    values: Array, # (N, V)
    ) -> Tuple[
        Array, # (D,)
        Array, # (V,)
        Array, # (V, D)
        Array, # ()
        ]:
    """
    Leaf summary of a dense set of keys/values for a single query.

    The attention weights are softmax(logmass + query.keys). Returns the
    weighted key centroid, the weighted value centroid, the weighted
    value-key covariance around those centroids, and the logsumexp of the
    scores (total_logmass).
    """
    (D,) = query.shape
    (N, V) = values.shape
    assert keys.shape == (N, D), f"keys shape {keys.shape} does not match expected shape {(N, D)}"
    assert logmass.shape == (N,), f"logmass shape {logmass.shape} does not match expected shape {(N,)}"

    scores = logmass + einsum("d,nd->n", query, keys) # (N,)
    attn = softmax(scores) # (N,)
    total_logmass = logsumexp(scores) # ()

    key_mean = einsum("n,nd->d", attn, keys) # (D,)
    value_mean = einsum("n,nv->v", attn, values) # (V,)
    # Covariance of values and keys under the attention weights, centred on the
    # two centroids computed just above.
    centred_values = values - value_mean # (N, V)
    centred_keys = keys - key_mean # (N, D)
    covariance = einsum("n,nv,nk->vk", attn, centred_values, centred_keys) # (V, D)
    return key_mean, value_mean, covariance, total_logmass

def wasabi_queries_leaf(queries, kcent, vcent, vkcent, logmass, mask) -> Tuple[Array, Array]:
    """
    Attend from collated queries to the key-cluster summaries of their query cluster.

    Inputs: queries (Q, n, D), kcent (Q, K, D), vcent (Q, K, V), logmass (Q, K),
    mask (Q, n) marking the real (non-padding) query slots. vkcent is contracted
    with subscripts "kvd", i.e. it is read as a (K, V, D) array.
    Returns total_logmass (Q, n) and values (Q, n, V).

    When the module flag DIPOLE is set, a correction term is added. It is built
    from the attention weights averaged over the unmasked queries of each
    cluster.
    """
    cluster_scores = einsum("qnd,qkd->qnk", queries, kcent) + logmass[:,None,:] # (Q, n, K)
    attn = softmax(cluster_scores, axis=-1) # (Q, n, K) weights for each key cluster
    total_logmass = logsumexp(cluster_scores, axis=-1) # (Q, n) total logmass for each query
    out_values = einsum("qnk,qkv->qnv", attn, vcent) # (Q, n, V) values for each query

    if not DIPOLE:
        return total_logmass, out_values # (Q, n), (Q, n, V)

    # Average the weights over the real queries of every query cluster only.
    avg_attn = jnp.mean(attn, axis=1, where=mask[:,:,None]) # (Q, K)
    dipole_term = einsum("qnd,qk,kvd->qnv", queries, avg_attn, vkcent) # (Q, n, V)
    return total_logmass, out_values + dipole_term # (Q, n), (Q, n, V)


def level_final(
    qres: Array, # (N, D) residual queries
    qlabels: Array, # (N,) in [0, Q)
    kcentroids: Array, # (Q, K, D) key centroids
    vcentroids: Array, # (Q, K, V) value centroids
    vkcentroids: Array, # value-key centroids; NOTE(review): wasabi_queries_leaf contracts this as "kvd" (rank 3) — confirm against caller
    logmass: Array, # (Q, K)
    Q: int, # number of query clusters
    n: int, # assumed max number of queries per cluster
    fwd_indices: Optional[Array] = None, # (Q, n) precomputed indices for query collation, if None then construct from qlabels
    bwd_indices: Optional[Array] = None, # (N,) precomputed indices for query decollation, if None then construct from qlabels
    ) -> Tuple[
        Array, # (N,) total_logmass
        Array, # (N, V) values
    ]:
    """
    Final level of the hierarchical attention, takes residual queries and key-value clusters.

    The queries are collated into (Q, n) slots by query cluster (padding slots
    hold index -1), evaluated with `wasabi_queries_leaf`, and scattered back to
    the original query order. Returns total_logmass (N,) and values (N, V).

    With `bwd_indices`, query i reads slot [qlabels[i], bwd_indices[i]].
    Without it, the inverse mapping is rebuilt from the collation indices. A
    query that did not fit into the `n` slots of its cluster keeps the marker
    -1 and therefore reads the last flattened slot.
    """
    if fwd_indices is not None:
        indices = fwd_indices
    else:
        def get_indices(q):
            (members,) = jnp.where(qlabels == q, size=n, fill_value=-1)
            return members
        indices = jax.vmap(get_indices)(jnp.arange(Q)) # (Q, n)

    total_logmass, values = wasabi_queries_leaf(
        qres[indices], kcentroids, vcentroids, vkcentroids, logmass, mask=indices>=0,
    ) # (Q, n), (Q, n, V)

    if bwd_indices is not None:
        out_logmass = total_logmass[qlabels, bwd_indices]
        out_values = values[qlabels, bwd_indices]
        return out_logmass, out_values # (N,) total_logmass, (N, V) values

    num_queries = qres.shape[0]
    V = values.shape[-1]
    slot_ids = jnp.arange(Q*n).reshape((Q, n)) # flat slot id of every (q, n) position
    # Padding slots carry index -1, which JAX would wrap to the LAST query and
    # could overwrite its real slot. Redirect them out of range and drop them.
    scatter_targets = jnp.where(indices >= 0, indices, num_queries) # (Q, n)
    out_indices = jnp.full((num_queries,), -1, dtype=jnp.int32)
    out_indices = out_indices.at[scatter_targets].set(slot_ids, mode="drop") # (N,) slot of each query
    out_logmass = einshape("qn->(qn)", total_logmass, q=Q, n=n)[out_indices]
    out_values = einshape("qnv->(qn)v", values, q=Q, n=n, v=V)[out_indices]
    return out_logmass, out_values # (N,) total_logmass, (N, V) values

@partial(jax.named_call, name="queries_leaf")
def queries_leaf(
    query: Array, # (D)
    kcent: Array, # (K, D)
    vcent: Array, # (K, V)
    vkcent: Array, # (K, V, D)
    logmass: Array, # (K,)
    ) -> Tuple[
        Array, # () total_logmass
        Array, # (V) value
    ]:
    """
    Top level of the hierarchical attention for one query, takes dense inputs.

    Attends over the K cluster summaries with weights
    softmax(logmass + query.kcent). Returns the logsumexp of those scores and
    the weighted value, plus the vkcent-based correction when DIPOLE is set.
    """
    scores = logmass + einsum("d,kd->k", query, kcent) # (K,)
    attn = softmax(scores) # (K,)
    value = einsum("k,kv->v", attn, vcent) # (V,)
    if DIPOLE:
        # first-order correction from the value-key centroids
        value = value + einsum("d,k,kvd->v", query, attn, vkcent)
    return logsumexp(scores), value

def level_mid(
    qcentroids: Array, # (Q, D)
    kcentroids: Array, # (K0, K1, D)
    vcentroids: Array, # (K0, K1, V)
    vkcentroids: Array, # (K0, K1, V, D)
    logmass: Array, # (K0, K1)
) -> Tuple[
        Array, # (Q, K0, D)
        Array, # (Q, K0, V)
        Array, # (Q, K0, V, D)
        Array, # (Q, K0)
    ]:
    """
    Middle level of the hierarchical attention.

    Collapses the inner cluster axis K1: for every outer cluster k0 and every
    query centroid, the K1 child summaries are merged by `central_leaf`.
    The outer clusters are processed with `jax.lax.map` in batches of 16.
    """
    assert kcentroids.ndim == 3, f"kcentroids shape {kcentroids.shape} does not match expected shape (K0, K1, D)"

    # One query at a time; the summaries of the current outer cluster are shared.
    merge_for_all_queries = jax.vmap(
        central_leaf,
        in_axes=(0, None, None, None, None),
        out_axes=(0, 0, 0, 0),
    )

    def merge_outer_cluster(children):
        child_k, child_v, child_vk, child_logmass = children # (K1, D), (K1, V), (K1, V, D), (K1,)
        return merge_for_all_queries(qcentroids, child_k, child_v, child_vk, child_logmass)

    per_cluster = jax.lax.map(
        merge_outer_cluster,
        (kcentroids, vcentroids, vkcentroids, logmass), # sliced along the leading K0 axis
        batch_size=16,
    ) # (K0, Q, D), (K0, Q, V), (K0, Q, V, D), (K0, Q)

    # Move the query axis to the front of every output.
    kcent, vcent, vkcent, total_logmass = (einshape("kq...->qk...", arr) for arr in per_cluster)
    return kcent, vcent, vkcent, total_logmass



@partial(jax.named_call, name="central_leaf")
def central_leaf(
    query: Array, # (D,)
    kcent: Array, # (K, D)
    vcent: Array, # (K, V)
    vkcent: Array, # (K, V, D)
    logmass: Array, # (K,)
    ) -> Tuple[
        Array, # (D,) key centroid
        Array, # (V,) value centroid
        Array, # (V, D) value-key centroid
        Array, # () total_logmass
    ]:
    """
    Middle level of the hierarchical attention for one query, takes dense inputs.

    Merges K child summaries into one, weighting the children by
    softmax(logmass + query.kcent). Returns the merged key centroid, value
    centroid, value-key centroid and the logsumexp of the scores.
    """
    scores = logmass + einsum("d,kd->k", query, kcent) # (K,)
    attn = softmax(scores) # (K,)
    total_logmass = logsumexp(scores) # ()

    merged_k = einsum("k,kd->d", attn, kcent) # (D,)
    merged_v = einsum("k,kv->v", attn, vcent) # (V,)
    merged_vk = einsum("k,kvd->vd", attn, vkcent) # (V, D) average of the children's own terms
    if DIPOLE:
        # The value centroid is corrected first, so the between-children term
        # below is centred on the corrected value.
        merged_v = merged_v + einsum("d,vd->v", query, merged_vk)
    # Add the spread of the child centroids around the merged centroids.
    between_children = einsum("k,kv,kd->vd", attn, vcent-merged_v, kcent-merged_k) # (V, D)
    merged_vk = merged_vk + between_children
    return merged_k, merged_v, merged_vk, total_logmass

@partial(jax.named_call, name="level_mid_reclustering")
def level_mid_reclustering(
    qcentroids: Array, # (Q, D)
    kcentroids: Array, # (K0, K1, D)
    vcentroids: Array, # (K0, K1, V)
    vkcentroids: Array, # (K0, K1, V, D)
    logmass: Array, # (K0, K1)
    metrics: Optional[Array] = None, # (Q, D, D) metrics for the distance, if None then use identity matrix
) -> Tuple[
        Array, # (Q, K0, D)
        Array, # (Q, K0, V)
        Array, # (Q, K0, V, D)
        Array, # (Q, K0)
    ]:
    """
    Middle level of the hierarchical attention with per-query reclustering.

    Runs `reclustering_leaf` once per query centroid. The cluster summaries are
    shared by all queries, while `metrics` is passed by keyword and is
    therefore split along its leading axis together with the queries.
    """
    assert kcentroids.ndim == 3, f"kcentroids shape {kcentroids.shape} does not match expected shape (K0, K1, D)"
    shared = None # in_axes marker for arguments that are not batched over queries
    recluster_for_all_queries = jax.vmap(
        reclustering_leaf,
        in_axes=(0, shared, shared, shared, shared),
        out_axes=(0, 0, 0, 0),
    )
    merged = recluster_for_all_queries(
        qcentroids,
        kcentroids,
        vcentroids,
        vkcentroids,
        logmass,
        metric=metrics,
    )
    return merged # (Q, K0, D) key centroid, (Q, K0, V) value centroid, (Q, K0, V, D) value-key centroid, (Q, K0) total_logmass

@partial(jax.named_call, name="reclustering_leaf")
def reclustering_leaf(
    query: Array, # (D,)
    kcent: Array, # (C, K, D)
    vcent: Array, # (C, K, V)
    vkcent: Array, # (C, K, V, D)
    logmass: Array, # (C, K)
    metric: Optional[Array] = None, # (D, D) metric for the distance, if None then use identity matrix
    ) -> Tuple[
        Array, # (C, D) key centroid
        Array, # (C, V) value centroid
        Array, # (C, V, D) value-key centroid
        Array, # (C,) total_logmass
    ]:
    """
    Re-cluster the C*K child summaries into C groups for one query, then merge each group.

    The children start in their original parent cluster and are regrouped by 16
    iterations of `kmeans` on their key centroids, weighted by the
    query-adjusted logmass. Each of the C groups is then merged with
    `central_leaf`, with children outside the group masked by a logmass of -inf.
    """
    C, K, D = kcent.shape

    def flatten(arr):
        return einshape("ck...->(ck)...", arr, c=C, k=K)

    # Every child starts in the cluster it currently belongs to.
    parent_ids = jnp.arange(C)[:,None] + jnp.zeros((K,), dtype=jnp.int32)[None,:] # (C, K) initial labels
    query_logmass = logmass + einsum("d,kld->kl", query, kcent) # (C, K)

    flat_kcent = flatten(kcent) # (CK, D)
    flat_vcent = flatten(vcent) # (CK, V)
    flat_vkcent = flatten(vkcent) # (CK, V, D)
    flat_logmass = flatten(logmass) # (CK,)

    group_of_child = kmeans(
        xs=flat_kcent, # (CK, D)
        labels=flatten(parent_ids), # (CK,)
        k=C, # number of clusters
        iters=16,
        logmass=flatten(query_logmass), # (CK,)
        metric=metric, # (D, D)
    ) # (CK,) new labels

    def merge_group(group_id):
        # central_leaf adds the query-key score itself, so it receives the
        # unadjusted logmass; children of other groups are masked out.
        masked_logmass = jnp.where(group_of_child == group_id, flat_logmass, -jnp.inf)
        return central_leaf(
            query=query,
            kcent=flat_kcent,
            vcent=flat_vcent,
            vkcent=flat_vkcent,
            logmass=masked_logmass,
        )

    return jax.vmap(merge_group)(jnp.arange(C))
    # (C, D) key centroid, (C, V) value centroid, (C, V, D) value-key centroid, (C,) total_logmass



