import jax
from jax import numpy as jnp
import numpy as np
from chex import Array, dataclass
from typing import Optional, Tuple, Callable, List
from functools import partial
from jax.scipy.special import logsumexp
from einshape import jax_einshape as einshape
from jax.tree_util import tree_map
from jax.random import PRNGKey
from jax import random
from math import prod
import math

from .kmeans import subsample_kmeans as kmeans
from .kmeans import subsample_2kmeans as hkmeans
#from .kmeans import kmeans as kmeans

from .kernels import level_zero as kernel_level_zero, level_final as kernel_level_final, level_mid as kernel_level_mid
from .kernels import level_mid_reclustering as kernel_level_mid_reclustering
from .kernels import level_zero_reclustering as kernel_level_zero_reclustering
from . import kernels
from . import fast_kmeans

def show_cluster_sizes(labels, K):
    """
    Debug helper: report the smallest, largest and mean cluster occupancy.

    Counts how many entries of `labels` equal each id in `range(K)` and emits
    a one-line summary through `jax.debug.print`. Returns nothing.
    """
    cluster_ids = jnp.arange(K)
    # One occupancy count per cluster id, vectorised over the ids.
    counts = jax.vmap(lambda cluster_id: jnp.sum(labels == cluster_id))(cluster_ids)
    jax.debug.print(
        "Clusters min {min_size}, max {max_size}, avg {avg_size}",
        min_size=jnp.min(counts),
        max_size=jnp.max(counts),
        avg_size=jnp.mean(counts),
    )
    

@dataclass
class SimpleAttention:
    # Configuration for the reference attention routines and the clustered
    # single-level approximation implemented by the methods of this class.
    K: int # Number of clusters
    # single_level_attention: add a value correction built from the
    # value-key covariance of each key cluster.
    dipole: bool = False
    # single_level_attention: bias key weights by log(||value|| + 0.1) and
    # undo that bias in the returned logmass / values.
    importance: bool = False
    # do_clustering: add log(||value|| + 0.1) to the per-key logmass handed
    # to the k-means routine.
    cluster_importance: bool = False
    # single_level_attention: add a key-centroid correction built from the
    # key-key covariance of each key cluster.
    logmass_dipole: bool = False
    # do_clustering: add <mean query, key> to the per-key logmass handed to
    # the k-means routine.
    mean_query_importance: bool = False
    def exact_attention(
        self,
        queries: Array,
        keys: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Perform exact attention, for comparison purposes.
        """
        scores = jnp.einsum('ij, kj->ik', queries, keys)
        logmass = logsumexp(scores, axis=-1)
        weights = jax.nn.softmax(scores, axis=-1)
        output = jnp.einsum('ik, kj->ij', weights, values)
        return logmass, output

    def flash_attention(
        self,
        queries: Array,
        keys: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Dubious implementation of flash attention using scan.
        """
        BLK_K = 2**10
        assert keys.shape[0] % BLK_K == 0, f"Keys shape {keys.shape} must be divisible by {BLK_K}."
        BLK_Q = 2**8
        assert queries.shape[0] % BLK_Q == 0, f"Queries shape {queries.shape} must be divisible by {BLK_Q}."
        BATCH_SIZE = 2**6 # Optional parallelization with vmap implicitly
        D = keys.shape[-1]
        V = values.shape[-1]
        assert keys.shape[-1] == queries.shape[-1], f"Keys and queries must have the same last dimension, got {keys.shape[-1]} and {queries.shape[-1]}."
        num_qblocks = queries.shape[0] // BLK_Q
        num_kblocks = keys.shape[0] // BLK_K
        qblocks = einshape("(nq)d->nqd", queries, n=num_qblocks, q=BLK_Q, d=D)
        kblocks = einshape("(nk)d->nkd", keys, n=num_kblocks, k=BLK_K, d=D)
        vblocks = einshape("(nk)d->nkd", values, n=num_kblocks, k=BLK_K, d=V)

        def compute_for_qblock(qblock: Array): # (BLK_Q, D)
            out_logmass = jnp.full((BLK_Q,), -jnp.inf, dtype=keys.dtype) # (BLK_Q,)
            out_vals = jnp.zeros((BLK_Q, V), dtype=values.dtype) # (BLK_Q, V)
            def scan_body(
                    carry: Tuple[Array, Array], kv_block: Tuple[Array, Array]
                ) -> Tuple[Array, Array]:
                """
                Compute the attention for a single key-value block.
                """
                out_logmass, out_vals = carry # (BLK_Q,), (BLK_Q, V)
                keys_block, values_block = kv_block # (BLK_K, D), (BLK_K, V)
                scores = jnp.einsum('qd, kd->qk', qblock, keys_block) # (BLK_Q, BLK_K)
                new_logmass = logsumexp(scores, axis=-1) # (BLK_Q,)
                weights = jax.nn.softmax(scores, axis=-1) # (BLK_Q, BLK_K)
                new_values = jnp.einsum('qk, kv->qv', weights, values_block) # (BLK_Q, V)
                out_logmass = jnp.logaddexp(out_logmass, new_logmass) # (BLK_Q,)
                old_weight, new_weight = jax.nn.softmax(
                    jnp.stack([out_logmass, new_logmass], axis=0), axis=0
                ) # (BLK_Q,), (BLK_Q,)
                out_vals = old_weight[:,None] * out_vals + new_weight[:,None] * new_values # (BLK_Q, V)
                return (out_logmass, out_vals), None

            (out_logmass, out_vals), _ = jax.lax.scan(
                scan_body,
                init=(out_logmass, out_vals),
                xs=(kblocks, vblocks)
            )
            return out_logmass, out_vals
        out_logmass, out_vals = jax.lax.map(
            compute_for_qblock,
            qblocks,
            batch_size=BATCH_SIZE, # Optional parallelization with vmap implicitly
        )
        out_logmass = einshape("nq->(nq)", out_logmass, n=num_qblocks, q=BLK_Q)
        out_vals = einshape("nqv->(nq)v", out_vals, n=num_qblocks, q=BLK_Q, v=V)
        return out_logmass, out_vals






    def cluster_then_attention(
        self,
        queries: Array,
        keys: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Cluster queries and keys, then run the single-level clustered attention.

        Convenience wrapper: labels come from `self.do_clustering` and are fed
        straight into `self.single_level_attention`, whose result is returned.
        """
        labels = self.do_clustering(queries=queries, keys=keys, values=values)
        query_labels, key_labels = labels
        return self.single_level_attention(
            queries=queries,
            qlabels=query_labels,
            keys=keys,
            klabels=key_labels,
            values=values,
        )

    def do_clustering(
        self,
        queries: Array,
        keys: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Alternately cluster queries and keys for single-level attention.

        Runs three rounds. In each round the queries are clustered with a
        metric equal to the covariance of the current key residuals, then the
        keys are clustered with a metric equal to the covariance of the query
        residuals (point minus assigned centroid). Keys additionally carry a
        per-key logmass that is zero unless `cluster_importance` and/or
        `mean_query_importance` are set. The PRNG seed is fixed, so the same
        two keys are reused in every round.

        Returns the final (qlabels, klabels).
        """
        ITERS = 10
        TWO_LEVEL = True
        if not TWO_LEVEL:
            kmeans_k = (self.K,)
        else:
            root = np.sqrt(self.K)
            assert root.is_integer(), "K must be a perfect square for two-level clustering."
            kmeans_k = (int(root), int(root))

        klabel_key, qlabel_key = random.split(PRNGKey(42))

        # Per-key logmass handed to the key clustering.
        key_logmass = jnp.zeros((keys.shape[0],), dtype=jnp.float32)
        if self.cluster_importance:
            key_logmass = key_logmass + jnp.log(jnp.linalg.norm(values, axis=-1) + 1e-1)
        if self.mean_query_importance:
            key_logmass = key_logmass + jnp.einsum('d,nd->n', jnp.mean(queries, axis=0), keys)

        # Before the first round nothing is clustered: residuals are the raw keys.
        key_residuals = keys
        for _ in range(3):
            query_metric = jnp.cov(key_residuals, bias=True, rowvar=False)
            qcentroids, qlabels = hkmeans(xs=queries, key=qlabel_key, k=kmeans_k, iters=ITERS, metric=query_metric)
            key_metric = jnp.cov(queries - qcentroids[qlabels], bias=True, rowvar=False)
            kcentroids, klabels = hkmeans(xs=keys, key=klabel_key, k=kmeans_k, iters=ITERS, logmass=key_logmass, metric=key_metric)
            key_residuals = keys - kcentroids[klabels]
        return qlabels, klabels

    def single_level_attention(
        self,
        queries: Array,
        qlabels: Array,
        keys: Array,
        klabels: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Compute simplified single-level hierarchical attention for testing purposes.
        Returns the logmass and value for every query.
        """
        qcindices = jnp.arange(self.K)
        def qcentroid(idx):
            return jnp.mean(queries, axis=0, where=(qlabels == idx)[:, None])
        qcentroids = jax.vmap(qcentroid)(qcindices)
        key_weights = jnp.einsum('ij, kj->ik', qcentroids, keys)
        if self.importance:
            logvmags = jnp.log(jnp.linalg.norm(values, axis=-1) + 1e-1)
            importance_mass = logvmags[None, :]
            key_weights = key_weights + importance_mass
            values = values * jnp.exp(-logvmags[:, None])
        kcindices = jnp.arange(self.K)
        def kcentroid(key_weights, idx):
            key_weights = jnp.where(klabels == idx, key_weights, -jnp.inf)
            return self.centroid(keys, key_weights)
        kcentroids_for_ki = jax.vmap(kcentroid, in_axes=(0, None), out_axes=0)
        kcentroids = jax.vmap(kcentroids_for_ki, in_axes=(None, 0), out_axes=1)(key_weights, kcindices)
        def vcentroid(key_weights, idx):
            key_weights = jnp.where(klabels == idx, key_weights, -jnp.inf)
            return self.centroid(values, key_weights)
        vcentroids_for_ki = jax.vmap(vcentroid, in_axes=(0, None), out_axes=0)
        vcentroids = jax.vmap(vcentroids_for_ki, in_axes=(None, 0), out_axes=1)(key_weights, kcindices)
        if self.dipole:
            def vkcorr(key_weights, idx):
                key_weights = jnp.where(klabels == idx, key_weights, -jnp.inf)
                kcent = kcentroid(key_weights, idx)
                vcent = vcentroid(key_weights, idx)
                outer = (values - vcent)[:, :, None] * (keys - kcent)[:, None, :]
                DV = values.shape[-1]
                DK = keys.shape[-1]
                outer = einshape("nvk->n(vk)", outer, v=DV, k=DK)
                centroid_outer = self.centroid(outer, key_weights)
                centroid_outer = einshape("(vk)->vk", centroid_outer, v=DV, k=DK)
                return centroid_outer
            vkcorr_for_ki = jax.vmap(vkcorr, in_axes=(0, None), out_axes=0)
            vkcorrs = jax.vmap(vkcorr_for_ki, in_axes=(None, 0), out_axes=1)(key_weights, kcindices)

        if self.logmass_dipole:
            def kkcorr(key_weights, idx):
                key_weights = jnp.where(klabels == idx, key_weights, -jnp.inf)
                kcent = kcentroid(key_weights, idx)
                #vcent = vcentroid(key_weights, idx)
                outer = (keys - kcent)[:, :, None] * (keys - kcent)[:, None, :]
                DV = values.shape[-1]
                DK = keys.shape[-1]
                outer = einshape("nvk->n(vk)", outer, v=DV, k=DK)
                centroid_outer = self.centroid(outer, key_weights)
                centroid_outer = einshape("(vk)->vk", centroid_outer, v=DV, k=DK)
                return centroid_outer
            kkcorr_for_ki = jax.vmap(kkcorr, in_axes=(0, None), out_axes=0)
            kkcorrs = jax.vmap(kkcorr_for_ki, in_axes=(None, 0), out_axes=1)(key_weights, kcindices)

        def cluster_logmass(key_weights, idx):
            key_weights = jnp.where(klabels == idx, key_weights, -jnp.inf)
            return logsumexp(key_weights, axis=0)
        clogmass_for_ki = jax.vmap(cluster_logmass, in_axes=(0, None), out_axes=0)
        clogmass = jax.vmap(clogmass_for_ki, in_axes=(None, 0), out_axes=1)(key_weights, kcindices)
        if self.importance:
            clogmass_true = jax.vmap(clogmass_for_ki, in_axes=(None, 0), out_axes=1)(key_weights - importance_mass, kcindices)
        print(f"clogmass shape: {clogmass.shape}")
        resqueries = queries - qcentroids[qlabels]
        print(f"resqueries shape: {resqueries.shape}")
        selected_key_clusters = kcentroids[qlabels]
        print(f"selected_key_clusters shape: {selected_key_clusters.shape}")
        selected_value_clusters = vcentroids[qlabels]
        print(f"selected_value_clusters shape: {selected_value_clusters.shape}")
        if self.dipole:
            selected_vkcorrs = vkcorrs[qlabels]
            print(f"selected_vkcorrs shape: {selected_vkcorrs.shape}")
            extra_value = jnp.einsum('ij,ikvj->ikv', resqueries, selected_vkcorrs)
            print(f"extra_value shape: {extra_value.shape}")
            selected_value_clusters = selected_value_clusters + extra_value
        if self.logmass_dipole:
            selected_kkcorrs = kkcorrs[qlabels]
            print(f"selected_kkcorrs shape: {selected_kkcorrs.shape}")
            extra_keys = 0.5 * jnp.einsum('ij,ikvj->ikv', resqueries, selected_kkcorrs)
            print(f"extra_keys shape: {extra_keys.shape}")
            selected_key_clusters = selected_key_clusters + extra_keys

        base_logmass = clogmass[qlabels]
        print(f"base_logmass shape: {base_logmass.shape}")
        extra_logmass = jnp.einsum('ij,ikj->ik', resqueries, selected_key_clusters)
        print(f"extra_logmass shape: {extra_logmass.shape}")
        qclogmass = base_logmass + extra_logmass
        print(f"qclogmass shape: {qclogmass.shape}")

        final_values = jax.vmap(self.centroid)(
            selected_value_clusters, qclogmass
        )
        print(f"final_values shape: {final_values.shape}")
        final_logmass = logsumexp(qclogmass, axis=-1)
        print(f"final_logmass shape: {final_logmass.shape}")
        
        if self.importance:
            base_logmass_true = clogmass_true[qlabels]
            qclogmass_true = base_logmass_true + extra_logmass
            final_logmass_true = logsumexp(qclogmass_true, axis=-1)
            logupweighting = final_logmass - final_logmass_true
            final_logmass = final_logmass - logupweighting
            final_values = final_values * jnp.exp(logupweighting[:, None])

        return final_logmass, final_values


    def centroid(self, xs, logmass):
        ws = jax.nn.softmax(logmass, axis=-1)
        return jnp.einsum('i,ij->j', ws, xs)

@dataclass
class FastSimpleAttention(SimpleAttention):
    """
    A simple attention mechanism that performs clustering and then attention.
    This is a fast version of the SimpleAttention class.
    """
    # Slack factor used by get_cluster_size: the assumed maximum cluster size
    # is ceil((N / K) * size_multiplier).
    size_multiplier: float = 6.0
    def __post_init__(self):
        """
        Validate the configuration.

        Requires a positive cluster count and rejects the `importance` and
        `logmass_dipole` options, which this class does not support.
        """
        assert self.K > 0, "Number of clusters K must be greater than 0."
        unsupported = (
            (self.importance, "Importance weighting is not supported in FastSimpleAttention."),
            (self.logmass_dipole, "Logmass dipole correction is not supported in FastSimpleAttention."),
        )
        for enabled, message in unsupported:
            assert not enabled, message

    def get_cluster_size(self, N: int) -> int:
        """
        Return the assumed maximum cluster size for `N` points.

        `N` must be an exact multiple of `self.K`; the even share `N // K` is
        scaled by `self.size_multiplier` and rounded up.
        """
        even_share, leftover = divmod(N, self.K)
        assert leftover == 0, f"Number of queries {N} must be divisible by number of clusters {self.K}."
        padded_share = even_share * self.size_multiplier
        return math.ceil(padded_share)

    def _two_level_attention(
        self,
        queries: Array,
        qlabels: Array,
        keys: Array,
        klabels: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Two-level clustered attention built from the `.kernels` stages.

        `self.K` must be a perfect square. A fine label `l` is read as the
        pair (l // sqrt(K), l % sqrt(K)), i.e. (coarse cluster, sub-cluster).
        The pipeline is:

        1. `kernel_level_zero` summarises all K fine key clusters against the
           coarse query centroids.
        2. `kernel_level_mid`, vmapped over the coarse query clusters, refines
           those summaries using the offsets of the fine query centroids from
           their coarse centroid.
        3. `kernel_level_final` evaluates every query against the summaries of
           its fine query cluster using its residual from the fine centroid.

        Shape comments below are carried over from the original annotations;
        the kernels themselves are defined in `.kernels` and are not visible
        here.

        Returns whatever `kernel_level_final` returns (annotated as a pair of
        arrays).

        Raises:
            AssertionError: if `self.K` is not a perfect square.
        """
        assert np.sqrt(self.K).is_integer(), "K must be a perfect square for two-level clustering."
        K0 = int(np.sqrt(self.K))
        K1 = K0
        Q0, Q1 = K0, K1
        coarse_qlabels = qlabels // K1

        def qcentroid(idx):
            return jnp.mean(queries, axis=0, where=(qlabels == idx)[:, None])
        def qcentroid_coarse(idx):
            return jnp.mean(queries, axis=0, where=(coarse_qlabels == idx)[:, None])
        qcentroids = jax.vmap(qcentroid)(jnp.arange(self.K)) # (Q, D)
        qcentroids1 = einshape("(ab)d->abd", qcentroids, a=K0, b=K1, d=queries.shape[-1]) # (Q0, Q1, D)
        qcentroids0 = jax.vmap(qcentroid_coarse)(jnp.arange(Q0)) # (Q0, D)

        kcentroids, vcentroids, vkcentroids, logmass = kernel_level_zero(
            qcentroids=qcentroids0, # (Q0, D)
            keys=keys, # (N, D)
            klabels=klabels, # (N,)
            values=values, # (N, D)
            K=K0 * K1, # Number of clusters
            n=self.get_cluster_size(keys.shape[0]), # Assumed max size of each cluster
        ) # (Q0, K, D), (Q0, K, D), (Q0, K, D, D), (Q0, K)

        # Offsets of the fine query centroids from their coarse centroid.
        qcentroids1_res = qcentroids1 - qcentroids0[:, None, :] # (Q0, Q1, D)

        def split_k(arr):
            # Unflatten the fine key-cluster axis into (coarse, sub-cluster).
            return einshape("q(ab)...->qab...", arr, q=Q0, a=K0, b=K1)
        def merge_q(arr):
            # Flatten (coarse, sub-cluster) query axes back into one fine axis.
            return einshape("abk...->(ab)k...", arr, a=K0, b=K1, k=K0)

        kcentroids, vcentroids, vkcentroids, logmass = map(merge_q, jax.vmap(kernel_level_mid)(
            qcentroids=qcentroids1_res, # (Q0, Q1, D)
            kcentroids=split_k(kcentroids), # (Q0, K0, K1, D)
            vcentroids=split_k(vcentroids), # (Q0, K0, K1, D)
            vkcentroids=split_k(vkcentroids), # (Q0, K0, K1, D, D)
            logmass=split_k(logmass), # (Q0, K0, K1)
        )) # (Q0, Q1, K0, D), (Q0, Q1, K0, D), (Q0, Q1, K0, D, D), (Q0, Q1, K0)

        qres = queries - qcentroids[qlabels] # (N, D)
        return kernel_level_final(
            qres=qres, # (N, D)
            qlabels=qlabels, # (N,)
            kcentroids=kcentroids, # (Q, K0, D)
            vcentroids=vcentroids, # (Q, K0, D)
            vkcentroids=vkcentroids, # (Q, K0, D, D)
            logmass=logmass, # (Q, K0)
            Q=Q0 * Q1, # Number of clusters
            n=self.get_cluster_size(queries.shape[0]), # Assumed max size of each cluster
        )

    def _simple_attention(
        self,
        queries: Array,
        qlabels: Array,
        keys: Array,
        klabels: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Single-level clustered attention built from the `.kernels` stages.

        Computes the mean of each query cluster, lets `kernel_level_zero`
        summarise every key cluster against those centroids, and then lets
        `kernel_level_final` evaluate each query from its residual to its
        cluster centroid. Both kernels are told the assumed maximum cluster
        size from `get_cluster_size`.

        `_capped_level_zero` and `_capped_level_one` on this class accept the
        same inputs as the two kernel calls below and are kept as alternative
        implementations; they are not used here.

        Returns whatever `kernel_level_final` returns (annotated as a pair of
        arrays).
        """

        @jax.vmap
        def qcentroid(idx):
            """
            Compute the centroid of queries for a given label.
            """
            return jnp.mean(queries, axis=0, where=(qlabels == idx)[:, None])
        qcentroids = qcentroid(jnp.arange(self.K)) # (Q, D)

        kcentroids, vcentroids, vkcentroids, logmass = kernel_level_zero(
            qcentroids=qcentroids, # (Q, D)
            keys=keys, # (N, D)
            klabels=klabels, # (N,)
            values=values, # (N, D)
            K=self.K, # Number of clusters
            n=self.get_cluster_size(keys.shape[0]), # Assumed max size of each cluster
        )

        qres = queries - qcentroids[qlabels] # (N, D)

        return kernel_level_final(
            qres=qres, # (N, D)
            qlabels=qlabels, # (N,)
            kcentroids=kcentroids, # (Q, K, D)
            vcentroids=vcentroids, # (Q, K, D)
            vkcentroids=vkcentroids, # (Q, K, D, D)
            logmass=logmass, # (Q, K)
            Q=self.K, # Number of clusters
            n=self.get_cluster_size(queries.shape[0]), # Assumed max size of each cluster
        ) # (N,), (N, D)

    def _simple_level_zero(
        self,
        qcentroids: Array, # (Q, D)
        keys: Array,
        klabels: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Summarise every key cluster against every query centroid (dense version).

        For each key cluster the scores of all keys outside it are masked to
        -inf, and the softmax over the remaining keys gives the weights for
        the key and value centroids. With `self.dipole` the weighted
        value-key covariance is returned as well, otherwise `None`.

        Returns (kcentroids, vcentroids, vkcentroids, logmass) with the key
        cluster on axis 1.
        """
        scores = jnp.einsum('qd,nd->qn', qcentroids, keys) # (Q, N)

        def summarize_cluster(cluster_id):
            """
            Centroids, optional covariance and logmass of one key cluster.
            """
            masked = jnp.where((klabels == cluster_id)[None, :], scores, -jnp.inf)
            probs = jax.nn.softmax(masked, axis=1) # (Q, N)
            key_mean = jnp.einsum('qn,nd->qd', probs, keys) # (Q, D)
            value_mean = jnp.einsum('qn,nd->qd', probs, values) # (Q, D)
            if self.dipole:
                # Shift by the centroid averages before forming the products.
                key_shift = jnp.mean(key_mean, axis=0) # (D,)
                value_shift = jnp.mean(value_mean, axis=0) # (D,)
                cross = jnp.einsum('qn,nv,nk->qvk', probs, values-value_shift, keys-key_shift) # (Q, D, D)
                cross = cross - jnp.einsum('qv,qk->qvk', value_mean-value_shift, key_mean-key_shift) # (Q, D, D)
            else:
                cross = None
            return key_mean, value_mean, cross, logsumexp(masked, axis=1)

        # (Q, K, D), (Q, K, D), (Q, K, D, D), (Q, K)
        return jax.vmap(summarize_cluster, out_axes=1)(jnp.arange(self.K))

    @partial(jax.named_call, name="capped_level_zero")
    def _capped_level_zero(
        self,
        qcentroids: Array, # (Q, D)
        keys: Array, # (N, D)
        klabels: Array, # (N,)
        values: Array # (N, D)
        )-> Tuple[Array, Array, Optional[Array], Array]:
        """
        Summarise every key cluster against every query centroid, one cluster
        at a time.

        Assumes clusters are no larger than self.get_cluster_size(N): each
        cluster's members are gathered into a fixed-size buffer of that many
        slots, and members beyond the buffer are not gathered.

        Returns:
            (kcentroids, vcentroids, vkcentroids, logmass) with shapes
            (Q, K, D), (Q, K, V), (Q, K, V, D) and (Q, K). `vkcentroids` is
            `None` unless `self.dipole` is set.
        """
        n = self.get_cluster_size(keys.shape[0])

        def get_indices_and_valid(key_idx):
            """
            Get the indices of the keys for a given key index and whether they are valid.
            """
            (indices,) = jnp.nonzero(klabels == key_idx, size=n, fill_value=-1)
            valid = indices >= 0
            return indices, valid

        def get_k_v_logmass(qcentroids, key_idx):
            """
            Compute the key, value centroids and logmass for a given query centroid and key index.
            """
            indices, valid = get_indices_and_valid(key_idx)
            # Padding slots hold -1 and therefore gather the last row; they
            # are masked to -inf below so they carry zero weight.
            local_keys = keys[indices] # (n, D)
            local_values = values[indices] # (n, D)
            logmass = jnp.einsum('qd,nd->qn', qcentroids, local_keys) # (Q, n)
            logmass = jnp.where(valid[None,:], logmass, -jnp.inf)
            weights = jax.nn.softmax(logmass, axis=1) # (Q, n)
            kcentroid = jnp.einsum('qn,nd->qd', weights, local_keys) # (Q, D)
            kmu = jnp.mean(kcentroid, axis=0) # (D,)
            vcentroid = jnp.einsum('qn,nd->qd', weights, local_values) # (Q, D)
            vmu = jnp.mean(vcentroid, axis=0) # (D,)
            if self.dipole:
                vkcentroid = jnp.einsum(
                    'qn,nv,nk->qvk', weights, local_values-vmu, local_keys-kmu) # (Q, D, D)
                vkcentroid = vkcentroid - jnp.einsum(
                    'qv,qk->qvk', vcentroid-vmu, kcentroid-kmu) # (Q, D, D)
            else:
                vkcentroid = None
            total_logmass = logsumexp(logmass, axis=1) # (Q,)
            return kcentroid, vcentroid, vkcentroid, total_logmass

        def kvl_for_k(k, val):
            # Write the summary of key cluster k into column k of each buffer.
            kc, vc, vkc, lm = val
            kcentroid, vcentroid, vkcentroid, total_logmass = get_k_v_logmass(qcentroids, k)
            kc = kc.at[:,k].set(kcentroid)
            vc = vc.at[:,k].set(vcentroid)
            if self.dipole:
                vkc = vkc.at[:,k].set(vkcentroid)
            lm = lm.at[:,k].set(total_logmass)
            return kc, vc, vkc, lm

        Q, D = qcentroids.shape
        K = self.K
        V = values.shape[-1]
        kc = jnp.zeros((Q, K, D), dtype=keys.dtype) # (Q, K, D)
        vc = jnp.zeros((Q, K, V), dtype=values.dtype) # (Q, K, V)
        vkc = jnp.zeros((Q, K, V, D), dtype=values.dtype) if self.dipole else None # (Q, K, V, D)
        lm = jnp.zeros((Q, K), dtype=keys.dtype) # (Q, K)
        # Clusters are processed sequentially with fori_loop rather than vmapped.
        kc, vc, vkc, lm = jax.lax.fori_loop(
            lower=0,
            upper=self.K,
            body_fun=kvl_for_k,
            init_val=(kc, vc, vkc, lm)
        )
        return kc, vc, vkc, lm # (Q, K, D), (Q, K, V), (Q, K, V, D), (Q, K)


    def _simple_level_one(
        self,
        qres: Array, # (N, D)
        qlabels: Array, # (N,)
        kcentroids: Array, # (Q, K, D)
        vcentroids: Array, # (Q, K, D)
        vkcentroids: Optional[Array], # (Q, K, D, D)
        logmass: Array # (Q, K)
        )-> Tuple[Array, Array]:
        """
        Evaluate every query against the key-cluster summaries of its query
        cluster (dense version).

        The cluster log-mass is shifted by the dot product of the query
        residual with the key centroid; with `self.dipole` the value centroid
        is shifted by the residual applied to `vkcentroids`. Returns the
        per-query logsumexp over key clusters and the softmax-weighted values.
        """
        # (N, K): summary log-mass as seen from each individual query.
        per_query_logmass = logmass[qlabels] + jnp.einsum(
            'nd,nkd->nk', qres, kcentroids[qlabels]
        )
        per_query_values = vcentroids[qlabels] # (N, K, D)
        if self.dipole:
            correction = jnp.einsum('nd,nkvd->nkv', qres, vkcentroids[qlabels]) # (N, K, D)
            per_query_values = per_query_values + correction
        mix = jax.nn.softmax(per_query_logmass, axis=1) # (N, K)
        return (
            logsumexp(per_query_logmass, axis=1), # (N,)
            jnp.einsum('nk,nkd->nd', mix, per_query_values), # (N, D)
        )

    @partial(jax.named_call, name="capped_level_one")
    def _capped_level_one(
        self,
        qres: Array, # (N, D)
        qlabels: Array, # (N,)
        kcentroids: Array, # (Q, K, D)
        vcentroids: Array, # (Q, K, D)
        vkcentroids: Optional[Array], # (Q, K, D, D)
        logmass: Array # (Q, K)
        )-> Tuple[Array, Array]:
        """
        Evaluate every query against the key-cluster summaries of its query
        cluster, one query cluster at a time.

        Each query cluster's members are gathered into a fixed-size buffer of
        `self.get_cluster_size(N)` slots, processed together, and the results
        are scattered back to the original query order.

        Returns:
            (logmass, values) with shapes (N,) and (N, V).

        NOTE(review): queries of a cluster larger than the buffer are never
        assigned a slot; their entry in the slot lookup stays -1, so they read
        the result of the very last slot. Callers must keep clusters within
        the assumed size.
        """
        N = qres.shape[0]
        n = self.get_cluster_size(N)

        def get_indices_and_valid(q):
            """
            Get the indices of the queries for a given query index and whether they are valid.
            """
            (indices,) = jnp.nonzero(qlabels == q, size=n, fill_value=-1)
            valid = indices >= 0
            return indices, valid

        indices, valid = jax.vmap(get_indices_and_valid)(jnp.arange(self.K)) # (Q, n), (Q, n)

        def do_compute(qidx):
            indices_q = indices[qidx] # (n,)
            kcent_q, vcent_q = kcentroids[qidx], vcentroids[qidx] # (K, D), (K, D)
            logmass_q = logmass[None, qidx, :] # (1, K), broadcast against (n, K) below

            # Padding slots hold -1 and gather the last query's residual;
            # their results are never scattered back.
            local_qres = qres[indices_q] # (n, D)

            assert local_qres.ndim == 2, f"local_qres should be 2D, got {local_qres.ndim}D."
            logmass_q = logmass_q + jnp.einsum('nd,kd->nk', local_qres, kcent_q) # (n, K)
            out_logmass_q = logsumexp(logmass_q, axis=1) # (n,)
            weights_q = jax.nn.softmax(logmass_q, axis=1) # (n, K)
            out_values_q = jnp.einsum('nk,kd->nd', weights_q, vcent_q) # (n, D)

            if self.dipole:
                vkcent_q = vkcentroids[qidx] # (K, D, D)
                extra_value_q = jnp.einsum('nd,nk,kvd->nv', local_qres, weights_q, vkcent_q) # (n, D)
                out_values_q = out_values_q + extra_value_q # (n, D)

            return out_logmass_q, out_values_q

        out_logmass_qs, out_values_qs = jax.lax.map(
            do_compute,
            jnp.arange(self.K), # (K,)
            batch_size=16, # Number of query clusters processed together
        ) # (Q, n), (Q, n, D)
        Q, n_slots, V = out_values_qs.shape
        slot_ids = jnp.arange(Q * n_slots).reshape((Q, n_slots)) # (Q, n)
        # Padding slots have index -1, which as a scatter target would wrap
        # around to row N-1 and could overwrite the last query's slot with a
        # padding slot from another cluster. Send them to an out-of-range row
        # and drop those writes instead.
        scatter_rows = jnp.where(valid, indices, N) # (Q, n)
        out_indices = jnp.full((N,), -1, dtype=jnp.int32)
        out_indices = out_indices.at[scatter_rows].set(slot_ids, mode="drop") # (N,)
        out_logmass = einshape("qn->(qn)", out_logmass_qs, q=Q, n=n_slots)[out_indices] # (N,)
        out_values = einshape("qnv->(qn)v", out_values_qs, q=Q, n=n_slots, v=V)[out_indices] # (N, D)
        return out_logmass, out_values # (N,), (N, D)

    def single_level_attention(
        self,
        queries: Array,
        qlabels: Array,
        keys: Array,
        klabels: Array,
        values: Array
        )-> Tuple[Array, Array]:
        """
        Kernel-backed replacement for the base-class method.

        Assumes clusters are equal size. Forwards all arguments unchanged to
        `self._simple_attention` and returns its result.
        """
        forwarded = dict(
            queries=queries, # (N, D)
            qlabels=qlabels, # (N,)
            keys=keys, # (N, D)
            klabels=klabels, # (N,)
            values=values, # (N, D)
        )
        return self._simple_attention(**forwarded)

@dataclass
class Clustering:
    """
    Base configuration for query/key clustering strategies.

    Holds the per-level cluster counts; subclasses supply `do_clustering`.
    """
    Qs: Tuple[int, ...] # Number of query clusters at each level
    Ks: Tuple[int, ...] # Number of key clusters at each level

    def __post_init__(self):
        # Queries and keys must describe the same number of levels.
        query_levels, key_levels = len(self.Qs), len(self.Ks)
        assert query_levels == key_levels, "Qs and Ks must have the same length."

    def do_clustering(
        self,
        queries: Array, # (N, D)
        keys: Array, # (N, D)
        values: Array, # (N, V)
        qlogmass: Optional[Array] = None, # (N,) optional
        klogmass: Optional[Array] = None, # (N,) optional
    ) -> Tuple[Array, Array]: # (N,) qlabels, (N,) klabels
        """
        Assign a cluster label to every query and every key.

        Abstract here: always raises NotImplementedError.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")

@dataclass
class StubClustering(Clustering):
    """
    Single-level clustering that defers to `SimpleAttention.do_clustering`.

    Requires exactly one level with the same number of query and key clusters.
    """
    def __post_init__(self):
        for level_counts in (self.Qs, self.Ks):
            assert len(level_counts) == 1, "StubClustering should only have one level of clustering."
        (num_query_clusters,), (num_key_clusters,) = self.Qs, self.Ks
        assert num_query_clusters == num_key_clusters, "StubClustering should have equal number of query and key clusters."

    def do_clustering(
        self,
        queries: Array, # (N, D)
        keys: Array, # (N, D)
        values: Array, # (N, V)
        qlogmass: Optional[Array] = None, # (N,) optional
        klogmass: Optional[Array] = None, # (N,) optional
    ) -> Tuple[Array, Array]: # (N,) qlabels, (N,) klabels
        """
        Cluster queries and keys with a default-configured `SimpleAttention`.

        Per-point log-masses are not supported and must be left as None.
        """
        assert qlogmass is None, "qlogmass is not supported in StubClustering."
        assert klogmass is None, "klogmass is not supported in StubClustering."
        delegate = SimpleAttention(K=self.Ks[0])
        return delegate.do_clustering(
            queries=queries, # (N, D)
            keys=keys, # (N, D)
            values=values, # (N, V)
        )

@dataclass
class MultiLevelClustering(Clustering):
    # do_clustering: add log(||value|| + 0.1) to the key logmass
    # (asserted off when coupled_clustering is False).
    cluster_importance: bool = False
    # do_clustering: add <mean query, key> to the key logmass
    # (asserted off when coupled_clustering is False).
    mean_query_importance: bool = False
    # False: do_clustering dispatches to uniform_parallel_clustering (when
    # Qs == Ks) or uniform_static_clustering. True: it runs the alternating
    # query/key k-means loop.
    coupled_clustering: bool = False
    # Passed as `iters` to each k-means call in the coupled loop.
    inner_iters: int = 10
    # Number of rounds of the coupled loop.
    outer_iters: int = 3
    # Slack factor used by get_cluster_size: ceil((N / K) * max_cluster_scale).
    max_cluster_scale: float = 2.0
    # Asserted to be False when coupled_clustering is enabled.
    compute_indices: bool = False

    def get_cluster_size(self, N: int, K: int) -> int:
        """
        Get the size of each cluster based on the number of queries.
        """
        n = N / K
        return math.ceil(n * self.max_cluster_scale)

    def do_clustering(
        self,
        queries: Array, # (N, D)
        keys: Array, # (N, D)
        values: Array, # (N, V)
        qlogmass: Optional[Array] = None, # (N,) optional
        klogmass: Optional[Array] = None, # (N,) optional
    ) -> Tuple[Array, Array]: # (N,) qlabels, (N,) klabels
        if klogmass is None:
            klogmass = jnp.zeros((keys.shape[0],), dtype=keys.dtype)
        if self.cluster_importance:
            klogmass = klogmass + jnp.log(jnp.linalg.norm(values, axis=-1) + 1e-1)
        if self.mean_query_importance:
            mean_query = jnp.mean(queries, axis=0)
            klogmass = klogmass + jnp.einsum('d,nd->n', mean_query, keys)
        if not self.coupled_clustering:
            assert not self.cluster_importance
            assert not self.mean_query_importance
            if self.Qs == self.Ks:
                return self.uniform_parallel_clustering(queries, keys, values)
            return self.uniform_static_clustering(queries, keys, values)
        assert not self.compute_indices, "compute_indices is not supported yet in MultiLevelClustering with coupled clustering"
        key = PRNGKey(42)
        klabel_key, qlabel_key = jax.random.split(key, 2)
        kres = keys
        for _ in range(self.outer_iters):
            qmetric = jnp.cov(kres, bias=True, rowvar=False)
            qcentroids, qlabels = hkmeans(xs=queries, key=qlabel_key, k=self.Qs, iters=self.inner_iters, logmass=qlogmass, metric=qmetric)
            qres = queries - qcentroids[qlabels] # (N, D)
            kmetric = jnp.cov(qres, bias=True, rowvar=False)
            kcentroids, klabels = hkmeans(xs=keys, key=klabel_key, k=self.Ks, iters=self.inner_iters, logmass=klogmass, metric=kmetric)
            kres = keys - kcentroids[klabels] # (N, D)
        return qlabels, klabels # (N,) qlabels, (N,) klabels

    def static_clustering(
        self,
        queries: Array, # (N, D)
        keys: Array, # (N, D)
        values: Array, # (N, V)
        qlogmass: Optional[Array] = None, # (N,) optional
        klogmass: Optional[Array] = None, # (N,) optional
    ) -> Tuple[Array, Array]:
        key = PRNGKey(42)
        klabel_key, qlabel_key = jax.random.split(key, 2)
        qmetric = jnp.cov(keys, bias=True, rowvar=False)
        kmetric = jnp.cov(queries, bias=True, rowvar=False)
        (Q0,) = self.Qs
        (K0,) = self.Ks
        if qlogmass is None:
            qlogmass = jnp.zeros((queries.shape[0],), dtype=queries.dtype)
        if klogmass is None:
            klogmass = jnp.zeros((keys.shape[0],), dtype=keys.dtype)
        qlabels = kernels.kmeans_with_init(xs=queries, key=qlabel_key,k=Q0, iters=self.inner_iters, logmass=qlogmass, metric=None)
        klabels = kernels.kmeans_with_init(xs=keys, key=klabel_key, k=K0, iters=self.inner_iters, logmass=klogmass, metric=None)
        return qlabels, klabels # (N,) qlabels, (N,) klabels
        qcentroids, qlabels = hkmeans(xs=queries, key=qlabel_key, k=self.Qs, iters=self.inner_iters, logmass=qlogmass, metric=qmetric)
        kcentroids, klabels = hkmeans(xs=keys, key=klabel_key, k=self.Ks, iters=self.inner_iters, logmass=klogmass, metric=kmetric)
        return qlabels, klabels # (N,) qlabels, (N,) klabels

    def uniform_static_clustering(
        self,
        queries: Array, # (N, D)
        keys: Array, # (N, D)
        values: Array, # (N, V)
    ) -> Tuple[Array, Array]:
        klabel_key, qlabel_key = jax.random.split(PRNGKey(42), 2)
        Qs = self.Qs
        Ks = self.Ks
        Qcap = self.get_cluster_size(queries.shape[0], prod(Qs)) # Assumed max size of each cluster
        Kcap = self.get_cluster_size(keys.shape[0], prod(Ks)) # Assumed max size of each cluster
        qmetric = jnp.cov(keys, bias=True, rowvar=False)
        kmetric = jnp.cov(queries, bias=True, rowvar=False)
        if self.compute_indices:
            (K,) = self.Ks
            kcounts, klabels, kfwd_indices, kbwd_indices = fast_kmeans.balanced_kmeans_with_indices(
                xs=keys, k=K, iters=self.inner_iters, metric=kmetric, max_cluster_size=Kcap)
            ktotals = jax.ops.segment_sum(keys, klabels, K) # (K, D)
            kcent = ktotals / kcounts[:, None]
        else:
            _, klabels = fast_kmeans.balanced_hkmeans_with_init(xs=keys, key=klabel_key, ks=Ks, iters=self.inner_iters, metric=kmetric, max_cluster_size=Kcap)
            kcent = jax.ops.segment_sum(keys, klabels, prod(Ks)) / jax.ops.segment_sum(jnp.ones(keys.shape[0], keys.dtype), klabels, prod(Ks))[:, None] # (K, D)
        kres = keys = kcent[klabels]
        qmetric = jnp.cov(kres, bias=True, rowvar=False)
        if self.compute_indices:
            (Q,) = self.Qs
            qcounts, qlabels, qfwd_indices, qbwd_indices = fast_kmeans.balanced_kmeans_with_indices(
                xs=queries, k=Q, iters=self.inner_iters, metric=qmetric, max_cluster_size=Qcap)
            return (qcounts, qlabels, qfwd_indices, qbwd_indices), (kcounts, klabels, kfwd_indices, kbwd_indices)
        else:
            _, qlabels = fast_kmeans.balanced_hkmeans_with_init(xs=queries, key=qlabel_key, ks=Qs, iters=self.inner_iters, metric=qmetric, max_cluster_size=Qcap)
            return qlabels, klabels
        qlabels = kernels.uniform_kmeans(xs=queries, key=qlabel_key, k=Q0, iters=self.inner_iters)
        klabels = kernels.uniform_kmeans(xs=keys, key=klabel_key, k=K0, iters=self.inner_iters)
        return qlabels, klabels # (N,) qlabels, (N,) klabels
        qlabels = kernels.kmeans_with_init(xs=queries, key=qlabel_key, k=Q0, iters=self.inner_iters, logmass=None, metric=None)
        klabels = kernels.kmeans_with_init(xs=keys, key=klabel_key, k=K0, iters=self.inner_iters, logmass=None, metric=None)
        return qlabels, klabels # (N,) qlabels, (N,) klabels

    def uniform_parallel_clustering(
        self,
        queries: Array, # (N, D)
        keys: Array, # (N, D)
        values: Array, # (N, V)
    ) -> Tuple[Array, Array]:
        (Q,) = self.Qs
        (K,) = self.Ks
        assert Q == K, "Uniform parallel clustering requires equal number of query and key clusters."
        assert self.compute_indices, "Uniform parallel clustering requires compute_indices to be True."
        Cap = self.get_cluster_size(queries.shape[0], Q) # Assumed max size of each cluster

        @jax.vmap
        def make_clusters(xs, metric):
            return fast_kmeans.balanced_kmeans_with_indices(
                xs=xs, k=K, iters=self.inner_iters, metric=metric, max_cluster_size=Cap)

        N, D = queries.shape
        qmetric = jnp.cov(keys, bias=True, rowvar=False)
        kmetric = jnp.cov(queries, bias=True, rowvar=False)
        #qmetric = jnp.identity(D, dtype=queries.dtype)
        #kmetric = jnp.identity(D, dtype=keys.dtype)
        metrics = jnp.stack([qmetric, kmetric], axis=0) # (2, D, D)

        xs = jnp.stack([queries, keys], axis=0) # (2, N, D)

        counts, labels, fwd_indices, bwd_indices = make_clusters(xs, metrics) # (2, K), (2, N), (2, K, n), (2, N)
        return (counts[0], labels[0], fwd_indices[0], bwd_indices[0]), (counts[1], labels[1], fwd_indices[1], bwd_indices[1])





@dataclass
class MultiLevelAttention:
    """Approximate attention computed over clustered queries and keys.

    The `clustering` strategy supplies the cluster labels (and optionally
    precomputed collation indices); the attention itself is delegated to the
    `kernel_level_zero` / `kernel_level_mid` / `kernel_level_final` kernels.
    """
    clustering: Clustering

    def cluster_then_attention(
        self,
        queries: Array, # (N, D)
        keys: Array, # (N, D)
        values: Array, # (N, V)
        qlogmass: Optional[Array] = None, # for importance weighted queries
        klogmass: Optional[Array] = None, # for importance weighted key-values
    ) -> Tuple[Array, Array]: # (N,) logmass, (N, V) values
        """
        Cluster queries and keys, then run clustered attention on the result.

        When the clustering has `compute_indices` set, its results are
        unpacked as `(counts, labels, fwd_indices, bwd_indices)` tuples and
        the indices are forwarded to the attention routine; otherwise the
        results are taken to be plain label arrays.
        """
        qresults, kresults = self.clustering.do_clustering(
            queries=queries, # (N, D)
            keys=keys, # (N, D)
            values=values, # (N, V)
            qlogmass=qlogmass, # (N,) optional
            klogmass=klogmass, # (N,) optional
        )
        if self.clustering.compute_indices:
            # The per-cluster counts are not needed downstream.
            _, qlabels, qfwd_indices, qbwd_indices = qresults
            _, klabels, kfwd_indices, kbwd_indices = kresults
        else:
            qlabels, klabels = qresults, kresults
            qfwd_indices, qbwd_indices = None, None
            kfwd_indices, kbwd_indices = None, None
        return self.attend(queries, qlabels, keys, klabels, values,
            qfwd_indices=qfwd_indices, qbwd_indices=qbwd_indices, kfwd_indices=kfwd_indices, kbwd_indices=kbwd_indices)

    def attend(self, *args, **kwargs):
        """Dispatch to single-level or two-level attention based on the clustering depth."""
        if len(self.clustering.Qs) == 1 and len(self.clustering.Ks) == 1:
            # Single level clustering
            return self.flattened_attention(*args, **kwargs)
        else:
            return self.two_level_attention(*args, **kwargs)

    def flattened_attention(self, queries, qlabels, keys, klabels, values, **kwargs):
        """Run single-level attention over the product of all per-level cluster counts."""
        Q = prod(self.clustering.Qs)
        K = prod(self.clustering.Ks)
        return self._single_level_attention(queries, qlabels, keys, klabels, values, Q, K, **kwargs)

    def coarse_level_attention(self, queries, qlabels, keys, klabels, values):
        """Run single-level attention using only the coarsest level of the hierarchy."""
        Q = self.clustering.Qs[0]
        K = self.clustering.Ks[0]
        # Integer division by the number of finer clusters maps a flat label to its coarse label.
        qlabels = qlabels // prod(self.clustering.Qs[1:])
        klabels = klabels // prod(self.clustering.Ks[1:])
        return self._single_level_attention(queries, qlabels, keys, klabels, values, Q, K)

    def _single_level_attention(
        self,
        queries: Array, # (N, D)
        qlabels: Array, # (N,) cluster labels for queries
        keys: Array, # (N, D)
        klabels: Array, # (N,) cluster labels for keys
        values: Array, # (N, V)
        Q: int, # Number of query clusters
        K: int, # Number of key clusters
        qfwd_indices: Optional[Array] = None, # (N,) optional, precomputed collation indices for queries
        qbwd_indices: Optional[Array] = None, # (N,) optional, precomputed decollation indices for queries
        kfwd_indices: Optional[Array] = None, # (N,) optional, precomputed collation indices for keys
        kbwd_indices: Optional[Array] = None, # (N,) optional, precomputed decollation indices for keys
    ) -> Tuple[Array, Array]: # (N,) logmass, (N, V) values
        """
        Treats a hierarchical clustering as a big single-level clustering.

        Key clusters are summarised against the query centroids by
        `kernel_level_zero`; `kernel_level_final` then combines those summaries
        with each query's residual from its centroid. `kbwd_indices` is
        accepted for signature symmetry but not used.
        """
        qcentroids = fast_kmeans.segsum_centroids(queries, qlabels, Q) # (Q, D)
        kcentroids, vcentroids, vkcentroids, logmass = kernel_level_zero(
            qcentroids=qcentroids, # (Q, D)
            keys=keys, # (N, D)
            klabels=klabels, # (N,)
            values=values, # (N, V)
            K=K, # Number of clusters
            n=self.clustering.get_cluster_size(keys.shape[0], K), # Assumed max size of each cluster
            fwd_indices=kfwd_indices, # (N,) optional, precomputed collation indices for keys
        )
        # NOTE(review): this expects a (K, V, D) result, while other comments in this
        # file describe the kernel outputs with a leading Q axis - confirm against kernels.
        assert vkcentroids.shape == (K, values.shape[-1], queries.shape[-1]), f"vkcentroids shape mismatch: {vkcentroids.shape} != {(K, values.shape[-1], queries.shape[-1])}"
        qres = queries - qcentroids[qlabels] # (N, D)
        return kernel_level_final(
            qres=qres, # (N, D)
            qlabels=qlabels, # (N,)
            kcentroids=kcentroids,
            vcentroids=vcentroids,
            vkcentroids=vkcentroids,
            logmass=logmass,
            Q=Q, # Number of clusters
            n=self.clustering.get_cluster_size(queries.shape[0], Q), # Assumed max size of each cluster
            fwd_indices=qfwd_indices, # (N,) optional, precomputed collation indices for queries
            bwd_indices=qbwd_indices, # (N,) optional, precomputed decollation indices for queries
        )

    def two_level_attention(
        self,
        queries: Array, # (N, D)
        qlabels: Array, # (N,) cluster labels for queries
        keys: Array, # (N, D)
        klabels: Array, # (N,) cluster labels for keys
        values: Array, # (N, V)
        qfwd_indices: Optional[Array] = None, # must be None: precomputed indices are not supported here
        qbwd_indices: Optional[Array] = None, # must be None
        kfwd_indices: Optional[Array] = None, # must be None
        kbwd_indices: Optional[Array] = None, # must be None
    ) -> Tuple[Array, Array]: # (N,) logmass, (N, V) values
        """
        Attention over a two-level hierarchy of query and key clusters.

        Flat labels are interpreted as `coarse * fine_count + fine`. Keys are
        first summarised against the coarse query centroids, the summaries are
        refined towards the fine query centroids with `kernel_level_mid`, and
        `kernel_level_final` applies them to the per-query residuals.

        The index parameters exist so that `attend` can forward the same
        keyword arguments to either attention routine; passing a non-None
        value raises NotImplementedError.
        """
        unsupported = [
            name
            for name, idx in (
                ("qfwd_indices", qfwd_indices),
                ("qbwd_indices", qbwd_indices),
                ("kfwd_indices", kfwd_indices),
                ("kbwd_indices", kbwd_indices),
            )
            if idx is not None
        ]
        if unsupported:
            raise NotImplementedError(f"two_level_attention does not support precomputed indices: {unsupported}")

        K0, K1 = self.clustering.Ks
        Q0, Q1 = self.clustering.Qs

        def qcentroid(idx):
            return jnp.mean(queries, axis=0, where=(qlabels == idx)[:, None])

        def qcentroid_coarse(idx):
            return jnp.mean(queries, axis=0, where=((qlabels // Q1) == idx)[:, None])

        qcentroids = jax.vmap(qcentroid)(jnp.arange(Q0 * Q1)) # (Q0 * Q1, D)
        qcentroids1 = einshape("(ab)d->abd", qcentroids, a=Q0, b=Q1) # (Q0, Q1, D)
        qcentroids0 = jax.vmap(qcentroid_coarse)(jnp.arange(Q0)) # (Q0, D)

        kcentroids, vcentroids, vkcentroids, logmass = kernel_level_zero(
            qcentroids=qcentroids0, # (Q0, D)
            keys=keys, # (N, D)
            klabels=klabels, # (N,)
            values=values, # (N, V)
            K=K0 * K1, # Number of clusters
            n=self.clustering.get_cluster_size(keys.shape[0], K0 * K1), # Assumed max size of each cluster
        )
        # Offset of each fine query centroid from its coarse parent.
        qcentroids1_res = qcentroids1 - qcentroids0[:, None, :] # (Q0, Q1, D)

        def split_k(arr):
            # Unflatten the key-cluster axis into (coarse, fine).
            return einshape("q(ab)...->qab...", arr, q=Q0, a=K0, b=K1)

        def merge_q(arr):
            # Flatten (coarse, fine) query clusters back into one axis.
            return einshape("abk...->(ab)k...", arr, a=Q0, b=Q1, k=K0)

        kcentroids, vcentroids, vkcentroids, logmass = map(merge_q, jax.vmap(kernel_level_mid)(
            qcentroids=qcentroids1_res, # (Q0, Q1, D)
            kcentroids=split_k(kcentroids), # (Q0, K0, K1, D)
            vcentroids=split_k(vcentroids), # (Q0, K0, K1, V)
            vkcentroids=split_k(vkcentroids), # (Q0, K0, K1, V, D)
            logmass=split_k(logmass), # (Q0, K0, K1)
        ))
        qres = queries - qcentroids[qlabels] # (N, D)
        return kernel_level_final(
            qres=qres, # (N, D)
            qlabels=qlabels, # (N,)
            kcentroids=kcentroids, # (Q0 * Q1, K0, ...) after merge_q
            vcentroids=vcentroids, # (Q0 * Q1, K0, ...) after merge_q
            vkcentroids=vkcentroids, # (Q0 * Q1, K0, ...) after merge_q
            logmass=logmass, # (Q0 * Q1, K0) after merge_q
            Q=Q0 * Q1, # Number of clusters
            n=self.clustering.get_cluster_size(queries.shape[0], Q0 * Q1), # Assumed max size of each cluster
        )




@dataclass
class CausalAttention:
    """Causal attention computed block by block over the sequence.

    Diagonal blocks use this class's own causal `exact_attention`;
    off-diagonal blocks use an implementation chosen by the calling method.
    """
    simple_attention: SimpleAttention # Provides the off-diagonal implementations (exact_attention / cluster_then_attention)
    block_size: int = 256 # Tokens per block; the sequence length must be a multiple of this

    def exact_attention(
            self,
            queries: Array,
            keys: Array,
            values: Array
        ) -> Array:
        """
        Perform exact attention, for comparison purposes.
        """
        N = queries.shape[-2]
        causal_bias = jnp.log(jnp.tril(jnp.ones((N, N)), k=-1))
        scores = jnp.einsum('ij, kj->ik', queries, keys)
        scores = scores + causal_bias
        logmass = logsumexp(scores, axis=-1)
        weights = jax.nn.softmax(scores, axis=-1)
        weights = weights.at[...,0,:].set(0.)
        output = jnp.einsum('ik, kj->ij', weights, values)
        #jax.debug.print("logmass tail: {logmass}", logmass=logmass[-5:])
        #jax.debug.print("logmass end: {logmass}", logmass=logmass[-1])
        #jax.debug.print("value mag tail: {value_mag}", value_mag=jnp.linalg.norm(output, axis=-1)[-5:])
        return logmass, output

    def null_attention(
        self,
        queries: Array,
        keys: Array,
        values: Array
    ) -> Tuple[Array, Array]:
        """
        Perform null attention, which returns zero logmass and zero values.
        This is useful for testing purposes.
        """
        N = queries.shape[-2]
        logmass = jnp.ones((N,), dtype=queries.dtype) * -1e9
        output_shape = queries.shape[:-1] + (values.shape[-1],)
        output = jnp.zeros(output_shape, dtype=queries.dtype)
        return logmass, output

    def block_attention(
            self,
            queries: Array,
            keys: Array,
            values: Array
        ) -> Array:
        """
        Perform exact attention but blockwise, such that we have a baseline for when we
        swap out the off-diagonal algorithm.
        """
        return self._block_attention(
            off_diag_impl=self.simple_attention.exact_attention,
            queries=queries,
            keys=keys,
            values=values
        )

    def _block_attention(
            self,
            off_diag_impl: Callable,
            queries: Array,
            keys: Array,
            values: Array
        ) -> Array:
        assert queries.shape[-2] == keys.shape[-2] == values.shape[-2], f"Queries, keys, and values must have the same sequence length, found qs:{queries.shape[-2]}, ks:{keys.shape[-2]}, vs:{values.shape[-2]}."
        N = queries.shape[-2]
        D = queries.shape[-1]
        assert N % self.block_size == 0, f"Sequence length {N} must be divisible by the number of blocks {self.block_size}."
        n = N // self.block_size
        b = self.block_size
        query_blocks = einshape("...(nb)d->...nbd", queries, n=n, b=b, d=D)
        key_blocks = einshape("...(nb)d->...nbd", keys, n=n, b=b, d=D)
        value_blocks = einshape("...(nb)d->...nbd", values, n=n, b=b, d=D)

        diag_scores, diag_values = jax.vmap(self.exact_attention)(
            query_blocks, key_blocks, value_blocks
        )

        off_diag_row, off_diag_col = jnp.tril_indices(n, k=-1)
        off_diag_scores, off_diag_values = jax.vmap(off_diag_impl)(
                query_blocks[..., off_diag_row, :, :],
                key_blocks[..., off_diag_col, :, :],
                value_blocks[..., off_diag_col, :, :]
            )

        all_scores = jnp.log(jnp.zeros((n, b, n), dtype=diag_scores.dtype))
        print(all_scores[jnp.arange(n), :, jnp.arange(n)].shape)
        print(diag_scores.shape)
        all_scores = all_scores.at[jnp.arange(n), :, jnp.arange(n)].set(diag_scores)
        all_scores = all_scores.at[off_diag_row, :, off_diag_col].set(off_diag_scores)
        all_scores = einshape("nbn->(nb)n", all_scores, n=n, b=b)
        all_values = jnp.zeros((n, b, n, D), dtype=diag_values.dtype)
        all_values = all_values.at[jnp.arange(n), :, jnp.arange(n), :].set(diag_values)
        all_values = all_values.at[off_diag_row, :, off_diag_col, :].set(off_diag_values)
        all_values = einshape("nbnv->(nb)nv", all_values, n=n, b=b, v=D)
        total_scores = logsumexp(all_scores, axis=-1)
        total_weights = jax.nn.softmax(all_scores, axis=-1)
        total_weights = total_weights.at[..., 0, :].set(0.) # first token cannot attend to anything
        total_values = jnp.einsum('ij, ijv->iv', total_weights, all_values)
        jax.debug.print("logmass tail: {logmass}", logmass=total_scores[-5:])
        jax.debug.print("logmass end: {logmass}", logmass=total_scores[-1])
        jax.debug.print("logmass prefix: {logmass}", logmass=total_scores[:5])
        jax.debug.print("value mag tail: {value_mag}", value_mag=jnp.linalg.norm(total_values, axis=-1)[-5:])
        jax.debug.print("value mag prefix: {value_mag}", value_mag=jnp.linalg.norm(total_values, axis=-1)[:5])
        jax.debug.print("finite values: {finite_values}", finite_values=jnp.isfinite(total_values).all())
        #total_scores = total_scores.at[..., 0].set(0.)
        return total_scores, total_values

 


    def fast_attention(
            self,
            queries: Array,
            keys: Array,
            values: Array
        ) -> Array:
        """
        Perform exact attention on diagonal blocks and simple attention on off diagonals.
        """
        return self._block_attention(
            off_diag_impl=self.simple_attention.cluster_then_attention,
            queries=queries,
            keys=keys,
            values=values
        )
        
    def block_diag_attention(
            self,
            queries: Array,
            keys: Array,
            values: Array
        ) -> Array:
        """
        Perform exact attention on diagonal blocks and null attention on off diagonals.
        """
        return self._block_attention(
            off_diag_impl=self.null_attention,
            queries=queries,
            keys=keys,
            values=values
        )

