import jax
from jax import numpy as jnp, lax, vmap
from jax.numpy import einsum
from einshape import jax_einshape as einshape
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
from jax.scipy.special import logsumexp
import dataclasses
from time import time
import math
import numpy as np
from jax import Array as Array
from .single_level_attention import MultiLevelAttention, MultiLevelClustering
from .flash_attention import attn as pallas_flash
from flax import struct
from .fast_kmeans import pallas_assign_indices as original_assign_indices, segsum_centroids as original_segsum_centroids
from .cuda_kernels.assign_indices_py import assign_indices as custom_cuda_assign_indices
from .cuda_kernels.cluster_py import _cuda_adjust_fp16 as custom_cuda_adjust_fp16

# Large-magnitude negative (but finite) float32 value. Not referenced anywhere
# else in this file — presumably a fill value for masked attention scores
# (TODO confirm against the importing modules).
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)
# Not referenced anywhere else in this file; purpose is not evident from here
# (TODO confirm whether other modules read this flag).
DIPOLE = True

@struct.dataclass
class ClusterData:
    """Cluster-assignment bookkeeping for N points grouped into K clusters.

    The relationships between the fields are the ones verified by
    ``check_meta_valid`` in this file: ``cnt`` equals ``bincount(lab)`` and
    ``fwd[lab[i], bwd[i]] == i`` for every point ``i``.
    """
    cnt: Array # (K,) number of points assigned to each cluster
    lab: Array # (N,) cluster label of each point
    fwd: Array # (K, CAP) point indices grouped by cluster (original note read "Ku" — TODO confirm meaning)
    bwd: Array # (N,) column of each point inside its cluster's row of fwd

def baseline_centroids(K, xs, labels):
    """Recompute the K cluster centroids of ``xs`` for the given ``labels``.

    Delegates to ``segsum_centroids`` from ``fast_kmeans``. The unreachable
    fallback that used to follow the ``return`` (a vmapped per-cluster masked
    mean of ``xs``) has been removed; it could never execute.

    Args:
        K: number of clusters.
        xs: points, one per row.
        labels: cluster index of each row of ``xs``.

    Returns:
        Whatever ``segsum_centroids(xs, labels, K)`` returns — presumably a
        ``(K, D)`` array of per-cluster means (TODO confirm in fast_kmeans).
    """
    return original_segsum_centroids(xs, labels, K)

def baseline_assign(xs, centroids):
    """Return, for every row of ``xs``, the index of its nearest centroid.

    The score ``-2 * <x, c> + |c|^2`` is the squared Euclidean distance minus
    ``|x|^2``, which is constant per point, so its argmin is the nearest
    centroid.

    Args:
        xs: ``(N, D)`` points.
        centroids: ``(K, D)`` centroids.

    Returns:
        ``(N,)`` integer array of centroid indices.
    """
    dots = einsum("nd,kd->nk", xs, centroids)  # (N, K)
    sqnorms = jnp.square(centroids).sum(axis=-1)  # (K,)
    scores = -2 * dots + sqnorms[None, :]
    return jnp.argmin(scores, axis=-1)

def baseline_adjust(num_iters, xs, centroids):
    """Run ``num_iters`` assign/recompute rounds starting from ``centroids``.

    Each round labels the points with ``baseline_assign`` and replaces the
    centroids with ``baseline_centroids`` of those labels.

    Args:
        num_iters: number of rounds (upper bound of the ``fori_loop``).
        xs: points to cluster.
        centroids: initial centroids; the first dimension is the cluster count.

    Returns:
        The centroids after the final round.
    """
    num_clusters = centroids.shape[0]

    def lloyd_step(_, current):
        assignment = baseline_assign(xs, current)
        return baseline_centroids(num_clusters, xs, assignment)

    return lax.fori_loop(0, num_iters, lloyd_step, centroids)

def pallas_adjust(num_iters, xs, centroids):
    """Same refinement loop as ``baseline_adjust`` but labelling via ``pallas_assign``.

    Args:
        num_iters: number of assign/recompute rounds.
        xs: points to cluster.
        centroids: initial centroids; the first dimension is the cluster count.

    Returns:
        The centroids after the final round.
    """
    num_clusters = centroids.shape[0]

    def lloyd_step(_, current):
        assignment = pallas_assign(xs, current)
        return baseline_centroids(num_clusters, xs, assignment)

    return lax.fori_loop(0, num_iters, lloyd_step, centroids)

def baseline_kmeans(K, num_iters, xs):
    """Cluster ``xs`` into ``K`` centroids, seeding from evenly strided rows.

    Seeds are every ``N // K``-th row of ``xs`` (first ``K`` of them), then
    refined with ``baseline_adjust``.

    Args:
        K: number of clusters.
        num_iters: number of refinement rounds.
        xs: ``(N, D)`` points.

    Returns:
        The refined centroids.
    """
    num_points, dim = xs.shape
    stride = num_points // K
    seeds = xs[::stride][:K]
    assert seeds.shape == (K, dim), f"Centroids shape mismatch: {seeds.shape} != {(K, dim)}"
    return baseline_adjust(num_iters, xs, seeds)

def baseline_kmeans_indices(K, cluster_max, num_iters, xs):
    """Run ``baseline_kmeans`` and build index tables with the Pallas assigner.

    Args:
        K: number of clusters.
        cluster_max: passed to the assigner as ``max_cluster_size``.
        num_iters: number of refinement rounds.
        xs: points to cluster.

    Returns:
        ``ClusterData`` built from the four arrays returned by
        ``fast_kmeans.pallas_assign_indices``.
    """
    fitted = baseline_kmeans(K, num_iters, xs)
    tables = original_assign_indices(xs, fitted, metric=None, max_cluster_size=cluster_max)
    cnt, lab, fwd, bwd = tables
    return ClusterData(cnt, lab, fwd, bwd)

def kmeans_with_indices(K, cluster_max, num_iters, xs):
    """Cluster ``xs`` (gradient-stopped) and return its ``ClusterData``.

    Pipeline: ``sqmag_init`` seeding, ``pallas_adjust`` refinement, then
    ``custom_assign_indices`` to produce counts, labels and index tables.

    Args:
        K: number of clusters.
        cluster_max: maximum cluster size forwarded to the index assigner.
        num_iters: number of refinement rounds.
        xs: ``(N, D)`` points.

    Returns:
        ``ClusterData`` for the final centroids.
    """
    # Clustering is treated as non-differentiable.
    points = jax.lax.stop_gradient(xs)
    _, dim = points.shape
    seeds = sqmag_init(K, points)
    assert seeds.shape == (K, dim), f"Centroids shape mismatch: {seeds.shape} != {(K, dim)}"
    refined = pallas_adjust(num_iters, points, seeds)
    return custom_assign_indices(cluster_max, points, refined)

def cuda_adjust_fp16(num_iters, xs, centroids):
    """Refine float16 centroids with the custom CUDA adjust kernel.

    The kernel is handed the centroids as initial ``totals`` together with
    all-ones ``counts`` and a fixed ``beta=0.90``; the refined centroids are
    the returned totals divided by the returned counts.
    NOTE(review): the meaning of ``beta`` is defined by the CUDA kernel and is
    not visible here.

    Args:
        num_iters: forwarded to the kernel as ``iters``.
        xs: float16 points.
        centroids: float16 initial centroids.

    Returns:
        ``totals / counts[:, None]`` as produced by the kernel.
    """
    assert centroids.dtype == jnp.float16, "Centroids must be float16"
    assert xs.dtype == jnp.float16, "xs must be float16"
    unit_counts = jnp.ones((centroids.shape[0],), dtype=jnp.float16)
    sums, weights = custom_cuda_adjust_fp16(
        beta=0.90, iters=num_iters, xs=xs, totals=centroids, counts=unit_counts
    )
    return sums / weights[:, None]

def kmeans_centroids_cuda(K, num_iters, xs):
    """Compute ``K`` centroids of ``xs`` with the float16 CUDA path.

    The points are gradient-stopped, cast to float16 and shuffled with a fixed
    PRNG key; seeds are every ``N // K``-th shuffled row. The result is cast
    back to the dtype of ``xs``.

    Args:
        K: number of clusters.
        num_iters: number of kernel iterations.
        xs: ``(N, D)`` points.

    Returns:
        Centroids in ``xs.dtype``.
    """
    # Keep autograd from differentiating through the clustering.
    xs = jax.lax.stop_gradient(xs)
    half = xs.astype(jnp.float16)
    order = jax.random.permutation(jax.random.PRNGKey(0), half.shape[0])
    shuffled = half[order]
    num_points, _ = xs.shape
    seeds = shuffled[::num_points // K][:K].astype(jnp.float16)
    refined = cuda_adjust_fp16(num_iters, shuffled, seeds)
    return refined.astype(xs.dtype)

def kmeans_with_indices_cuda(K, cluster_max, num_iters, xs):
    """Cluster ``xs`` with the CUDA k-means path and return its ``ClusterData``.

    Centroids come from ``kmeans_centroids_cuda`` (which works on a
    gradient-stopped, shuffled float16 copy); the final assignment is made
    against the original, unshuffled ``xs`` by ``custom_assign_indices``.

    The unreachable block that used to follow the ``return`` (a stale inlined
    copy of ``kmeans_centroids_cuda``) has been removed.

    Args:
        K: number of clusters.
        cluster_max: maximum cluster size forwarded to the index assigner.
        num_iters: number of k-means iterations.
        xs: ``(N, D)`` points.

    Returns:
        ``ClusterData`` for the computed centroids.
    """
    centroids = kmeans_centroids_cuda(K, num_iters, xs)
    return custom_assign_indices(cluster_max, xs, centroids)

def do_clustering(K, cluster_max, num_iters, qs, ks, vs):
    """Cluster queries and keys independently with the CUDA k-means path.

    ``qs`` and ``ks`` are stacked on a new leading axis and
    ``kmeans_with_indices_cuda`` is vmapped over it. ``vs`` is accepted but
    not used.

    Args:
        K: number of clusters.
        cluster_max: maximum cluster size.
        num_iters: number of k-means iterations.
        qs: ``(N, D)`` queries.
        ks: ``(N, D)`` keys (must stack with ``qs``).
        vs: unused.

    Returns:
        ``(qmeta, kmeta)``: one ``ClusterData`` for the queries, one for the keys.
    """
    cluster_one = partial(kmeans_with_indices_cuda, K, cluster_max, num_iters)
    stacked = jax.lax.stop_gradient(jnp.stack((qs, ks), axis=0))  # (2, N, D)
    batched = vmap(cluster_one)(stacked)
    # Every field has a leading axis of size 2: index 0 is q, index 1 is k.
    q_fields, k_fields = zip(batched.cnt, batched.lab, batched.fwd, batched.bwd)
    qmeta = ClusterData(*q_fields)
    kmeta = ClusterData(*k_fields)
    return qmeta, kmeta


def sqmag_init(K, xs, centering=True):
    """Pick ``K`` distinct rows of ``xs`` as seeds, weighted by squared magnitude.

    Sampling probabilities are proportional to the squared norm of each point,
    measured relative to the data mean when ``centering`` is true and relative
    to the origin otherwise. A fixed PRNG key is used, so the choice is
    deterministic for a given input.

    Bug fix: the rows are drawn from the *uncentered* ``xs``, so they are
    already in the original coordinate frame. The previous version added the
    mean to them a second time when ``centering`` was true, returning
    ``x_i + mean`` instead of actual data points.

    Args:
        K: number of seeds to draw (without replacement).
        xs: ``(N, D)`` points.
        centering: measure magnitudes relative to the mean of ``xs``.

    Returns:
        ``(K, D)`` array whose rows are rows of ``xs``.
    """
    key = jax.random.PRNGKey(0)
    if centering:
        offsets = xs - jnp.mean(xs, axis=0)[None, :]
    else:
        offsets = xs
    weights = jnp.square(offsets).sum(axis=-1)
    probs = weights / jnp.sum(weights)
    return jax.random.choice(key, xs, shape=(K,), replace=False, p=probs)

def subsample_baseline_kmeans(K, M, num_iters, xs):
    """Run ``baseline_kmeans`` on ``M`` rows of ``xs`` sampled without replacement.

    Rows are sampled with probability proportional to their squared distance
    from the data mean, using a fixed PRNG key. Prints the subsample and
    original shapes as a side effect.

    Args:
        K: number of clusters.
        M: number of rows to subsample.
        num_iters: number of refinement rounds.
        xs: ``(N, D)`` points.

    Returns:
        Centroids fitted on the subsample.
    """
    centered = xs - jnp.mean(xs, axis=0)
    energy = jnp.square(centered).sum(axis=-1)
    probs = energy / jnp.sum(energy)
    subsampled = jax.random.choice(
        jax.random.PRNGKey(0), xs, shape=(M,), replace=False, p=probs
    )
    print(f"Subsampled shape: {subsampled.shape}, original shape: {xs.shape}, K: {K}, M: {M}")
    return baseline_kmeans(K, num_iters, subsampled)

def naive_kmeans_plus_plus(K, xs):
    """Choose ``K`` seed centroids from the rows of ``xs``, k-means++ style.

    The first seed is a row picked with a fixed PRNG key. Each later seed ``i``
    is drawn with probability proportional to the distance from every point to
    its nearest already-chosen seed.
    NOTE(review): textbook k-means++ weights by *squared* distance; this
    variant uses the plain distance, and that behavior is kept as is.

    The unreachable statements that used to follow the ``return`` (a NaN
    assert and a ``baseline_adjust`` call on an undefined ``num_iters``) have
    been removed; this function only performs the seeding.

    Args:
        K: number of seeds.
        xs: ``(N, D)`` points.

    Returns:
        ``(K, D)`` array of seeds, each a row of ``xs``.
    """
    N, D = xs.shape
    first_idx = jax.random.randint(jax.random.PRNGKey(0), (), 0, N)
    centroids = jnp.zeros((K, D), dtype=xs.dtype).at[0].set(xs[first_idx])

    def body(i, centroids):
        dists = jnp.linalg.norm(xs[:, None, :] - centroids[None, :], axis=-1)  # (N, K)
        # Only the first i rows of `centroids` are filled in so far.
        mask = jnp.arange(K) < i
        min_dists = jnp.min(dists, axis=1, where=mask[None, :], initial=jnp.inf)  # (N,)
        probs = min_dists / jnp.sum(min_dists)
        new_centroid_idx = jax.random.choice(jax.random.PRNGKey(i), N, p=probs)
        return centroids.at[i].set(xs[new_centroid_idx])

    return lax.fori_loop(1, K, body, centroids)

def assign_kernel(xs_ref, centroids_ref, labels_ref, *, block_n: int):
    """Pallas kernel body: write the nearest-centroid index of every point.

    Loads all centroids once, then walks ``xs_ref`` in blocks of ``block_n``
    rows, scoring each row against every centroid and storing the argmin
    into ``labels_ref``.

    Args:
        xs_ref: reference to the points; first dimension N must be a
            multiple of ``block_n``.
        centroids_ref: reference to the (K, D) centroids.
        labels_ref: output reference receiving one label per point.
        block_n: number of rows processed per loop iteration.
    """
    K, D = centroids_ref.shape
    N = xs_ref.shape[0]
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    centroids = centroids_ref[...] # (K, D)
    centroid_sqmags = jnp.sum(jnp.square(centroids), axis=-1) # (K,)

    def body(start_n, _):
        # start_n is the block index; the carry is unused (None).
        current_n_slice = pl.dslice(start_n*block_n, block_n)
        xs = xs_ref[current_n_slice, :] # (block_n, D)
        similarities = pl.dot(xs, centroids, trans_b=True, allow_tf32=True)  # (block_n, K)
        # Squared distance minus |x|^2 (constant per row), so argmin is unchanged.
        scores = -2 * similarities + centroid_sqmags[None, :]  # (block_n, K)
        labels = jnp.argmin(scores, axis=-1)  # (block_n,)
        labels_ref[current_n_slice] = labels
    lower_bound = 0
    upper_bound = pl.cdiv(N, block_n)
    lax.fori_loop(lower_bound, upper_bound, body, None)

def pallas_assign(xs, centroids):
    """Label every row of ``xs`` with its nearest centroid via ``assign_kernel``.

    Launches the kernel with an empty grid and a block size of 128 rows, so
    the number of rows must be a multiple of 128 (asserted by the kernel).

    Args:
        xs: ``(N, D)`` points.
        centroids: ``(K, D)`` centroids.

    Returns:
        ``(N,)`` int32 labels.
    """
    num_points, _ = xs.shape
    launch = pl.pallas_call(
        kernel=partial(assign_kernel, block_n=128),
        out_shape=jax.ShapeDtypeStruct((num_points,), jnp.int32),
        grid=(),
        compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=1),
        name="assign",
    )
    return launch(xs, centroids)

def assign_indices_kernel(xs_ref, centroids_ref, in_cnt_ref, in_lab_ref, in_fwd_ref, in_bwd_ref, cnt_ref, lab_ref, fwd_ref, bwd_ref, *, block_n: int):
    """Pallas kernel body intended to produce labels, counts and index tables.

    NOTE(review): this looks like work in progress. As written it only
    processes the first block of ``block_n`` points (``upper_bound`` is
    overwritten with 1), writes labels and per-cluster counts for that block,
    and never writes ``fwd_ref`` or ``bwd_ref`` (those stores are commented
    out). The ``in_*`` refs are not read in the body; the caller in this file
    aliases them to the outputs via ``input_output_aliases``.

    Args:
        xs_ref: reference to the points; N must be a multiple of ``block_n``.
        centroids_ref: reference to the (K, D) centroids.
        in_cnt_ref, in_lab_ref, in_fwd_ref, in_bwd_ref: initial values of the
            outputs (unused here).
        cnt_ref: (K,) output counts.
        lab_ref: (N,) output labels.
        fwd_ref: (K, CAP) output table (not written).
        bwd_ref: (N,) output table (not written).
        block_n: number of rows processed per loop iteration.
    """
    K, D = centroids_ref.shape
    N = xs_ref.shape[0]
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    _, CAP = fwd_ref.shape
    assert cnt_ref.shape == (K,)
    assert lab_ref.shape == (N,)
    assert fwd_ref.shape == (K, CAP)
    assert bwd_ref.shape == (N,)

    centroids = centroids_ref[...]  # (K, D)
    centroid_sqmags = jnp.sum(jnp.square(centroids), axis=-1)  # (K,)
    counts = jnp.zeros((K,), dtype=jnp.int32)  # (K,)
    #lab_ref[...] = -1  # Reset labels
    #fwd_ref[...] = -1  # Reset forward indices
    #bwd_ref[...] = -1  # Reset backward indices


    def body(start_n, counts):
        # start_n is the block index; counts is the running (K,) histogram.
        current_n_slice = pl.dslice(start_n * block_n, block_n)
        xs = xs_ref[current_n_slice, :] # (block_n, D)
        similarities = pl.dot(xs, centroids, trans_b=True, allow_tf32=True)  # (block_n, K)
        scores = -2 * similarities + centroid_sqmags[None, :]  # (block_n, K)
        labels = jnp.argmin(scores, axis=-1)  # (block_n,)
        # One-hot the labels to (block_n, K), then sum over the block to get a (K,) histogram.
        new_counts = counts + (jnp.arange(K, dtype=jnp.int32)[None,:] == labels[:, None]).sum(axis=0)  # (K,)
        #new_counts = counts
        #mask = new_counts <= CAP  # Mask for valid counts
        mask = jnp.ones((block_n,), dtype=jnp.bool_)  # No mask, all counts are valid
        pl.store(lab_ref, current_n_slice, labels, mask=mask)  # Mask is all-True, so every label is stored
        # NOTE(review): current_indices is only used by the commented-out fwd store below.
        current_indices = start_n * block_n + jnp.arange(block_n)  # (block_n,)
        #pl.store(fwd_ref, (labels, old_counts), current_indices, mask=mask)
        #pl.store(bwd_ref, current_n_slice, old_counts, mask=mask)
        return new_counts
    lower_bound = 0
    upper_bound = pl.cdiv(N, block_n)
    # NOTE(review): this override limits the loop to the first block — looks like a debugging leftover; confirm.
    upper_bound = 1
    counts = lax.fori_loop(lower_bound, upper_bound, body, counts)
    cnt_ref[...] = counts  # Update counts in the output reference


def pallas_assign_indices(max_cluster_size, xs, centroids):
    """Launch ``assign_indices_kernel`` and wrap its outputs in ``ClusterData``.

    The four outputs are aliased to pre-filled input buffers: counts start at
    5 and labels and both index tables start at -1.
    NOTE(review): the initial count of 5 looks like a debugging sentinel, and
    the kernel in this file currently does not write ``fwd``/``bwd``.

    Args:
        max_cluster_size: number of columns of the ``fwd`` table.
        xs: ``(N, D)`` points; N must be a multiple of the kernel block size (16).
        centroids: ``(K, D)`` centroids.

    Returns:
        ``ClusterData(cnt, lab, fwd, bwd)`` as produced by the kernel.
    """
    num_clusters, _ = centroids.shape
    num_points = xs.shape[0]
    fwd_shape = (num_clusters, max_cluster_size)

    # Initial contents of the aliased output buffers.
    init_cnt = jnp.full((num_clusters,), 5, dtype=jnp.int32)
    init_lab = jnp.full((num_points,), -1, dtype=jnp.int32)
    init_fwd = jnp.full(fwd_shape, -1, dtype=jnp.int32)
    init_bwd = jnp.full((num_points,), -1, dtype=jnp.int32)

    out_shapes = (
        jax.ShapeDtypeStruct((num_clusters,), jnp.int32),  # cnt
        jax.ShapeDtypeStruct((num_points,), jnp.int32),  # lab
        jax.ShapeDtypeStruct(fwd_shape, jnp.int32),  # fwd
        jax.ShapeDtypeStruct((num_points,), jnp.int32),  # bwd
    )
    launch = pl.pallas_call(
        kernel=partial(assign_indices_kernel, block_n=16),
        out_shape=out_shapes,
        grid=(),
        compiler_params=plgpu.CompilerParams(num_warps=1, num_stages=1),
        name="assign_indices",
        # Inputs 2..5 (the init buffers) alias outputs 0..3.
        input_output_aliases={2: 0, 3: 1, 4: 2, 5: 3},
    )
    cnt, lab, fwd, bwd = launch(xs, centroids, init_cnt, init_lab, init_fwd, init_bwd)
    return ClusterData(cnt, lab, fwd, bwd)

def custom_assign_indices(max_cluster_size, xs, centroids):
    """Build ``ClusterData`` from a dense cost matrix via the custom CUDA op.

    The cost of putting point ``n`` in cluster ``k`` is
    ``-2 * <x_n, c_k> + |c_k|^2`` (squared distance up to a per-point
    constant), laid out as ``(K, N)`` and passed to the op as float32.

    Args:
        max_cluster_size: forwarded to the CUDA op as its first argument.
        xs: ``(N, D)`` points.
        centroids: ``(K, D)`` centroids.

    Returns:
        ``ClusterData(cnt, lab, fwd, bwd)`` from the four arrays the op returns.
    """
    sqnorms = jnp.sum(jnp.square(centroids), axis=-1)  # (K,)
    dots = einsum("nd,kd->kn", xs, centroids)  # (K, N)
    costs = (-2 * dots + sqnorms[:, None]).astype(jnp.float32)  # (K, N)
    cnt, lab, fwd, bwd = custom_cuda_assign_indices(max_cluster_size, costs)
    return ClusterData(cnt, lab, fwd, bwd)


        






##############################################################################################
#Evaluation code:
##############################################################################################

def evaluate_clustering(K, xs, labels):
    """Print how the mean squared magnitude of ``xs`` splits across a clustering.

    Reports three per-point averages: the total squared magnitude, the part
    carried by the centroids (count-weighted squared centroid magnitudes), and
    the residual after subtracting each point's centroid.

    Args:
        K: number of clusters.
        xs: ``(N, D)`` points.
        labels: ``(N,)`` integer cluster label of each point.
    """
    num_points = xs.shape[0]
    centers = baseline_centroids(K, xs, labels)
    sizes = jnp.bincount(labels, minlength=K)
    residuals = xs - centers[labels]
    total = jnp.sum(jnp.square(xs))
    unexplained = jnp.sum(jnp.square(residuals))
    explained = jnp.sum(jnp.square(centers) * sizes[:, None])
    print(f"total variance: {float(total / num_points):.2f}, centroids: {float(explained / num_points):.2f}, residual: {float(unexplained / num_points):.2f}")


def consider_joint_clustered_variances(qs, ks, vs, qmeta, kmeta):
    """Run ``consider_joint_variances`` on the within-cluster residuals of q and k.

    Each query/key has the centroid of its own cluster subtracted first; the
    values are passed through unchanged.

    Args:
        qs, ks, vs: queries, keys and values, one row per point.
        qmeta, kmeta: ``ClusterData`` for the queries and the keys.

    Returns:
        The result of ``consider_joint_variances`` on the residuals.
    """
    def residual(points, meta):
        num_clusters = meta.cnt.shape[0]
        centers = baseline_centroids(num_clusters, points, meta.lab)
        return points - centers[meta.lab]

    q_residual = residual(qs, qmeta)
    k_residual = residual(ks, kmeta)
    return consider_joint_variances(q_residual, k_residual, vs)

def consider_joint_variances(qs, ks, vs):
    """Print eigenvalue spectra coupling the query and key covariances.

    Queries and keys are mean-centered and each divided by ``D ** 0.25``.
    Two spectra are printed, both sorted in descending order: the absolute
    eigenvalues of ``cov(q) @ cov(k)``, and the eigenvalues of
    ``L.T @ cov(k) @ L`` where ``L`` is the Cholesky factor of ``cov(q)``
    (with a ``1e-6`` diagonal jitter).

    Args:
        qs: ``(N, D)`` queries.
        ks: ``(N, D)`` keys.
        vs: values; centered but otherwise unused.
    """
    N, D = qs.shape
    scale = math.sqrt(math.sqrt(D))

    def center(rows):
        return rows - jnp.mean(rows, axis=0)[None, :]

    qs = center(qs) / scale
    ks = center(ks) / scale
    vs = center(vs)  # kept from the original; not used below

    q_cov = jnp.einsum("nd,ne->de", qs, qs) / N
    k_cov = jnp.einsum("nd,ne->de", ks, ks) / N

    # Spectrum of the plain product (a general matrix, hence eigvals + abs).
    product_spectrum = jnp.abs(jnp.linalg.eigvals(q_cov @ k_cov))
    qk_eig = jnp.sort(product_spectrum, descending=True)

    # Symmetric form built from the Cholesky factor of the jittered q covariance.
    q_chol = jnp.linalg.cholesky(q_cov + 1e-6 * jnp.eye(D))
    qkq = q_chol.T @ k_cov @ q_chol
    qkq_eig = jnp.sort(jnp.linalg.eigvalsh(qkq), descending=True)

    print(f"QK eig: {qk_eig}")
    print(f"QKQ eig: {qkq_eig}")


def sq_err(ref, x):
    """Return the mean squared error of ``x`` relative to the mean square of ``ref``.

    Args:
        ref: reference array.
        x: array compared against ``ref``.

    Returns:
        float32 scalar ``mean((ref - x)^2) / mean(ref^2)``, or ``inf`` when the
        mean square of ``ref`` is not positive.
    """
    signal_power = jnp.mean(jnp.square(ref))
    noise_power = jnp.mean(jnp.square(ref - x))
    relative = jnp.where(signal_power > 0, noise_power / signal_power, jnp.inf)
    return relative.astype(jnp.float32)

# sq_err batched twice: the outer vmap maps over axis 0, the inner over axis 1
# of what remains (axis 2 of the full array) — presumably batch and head of a
# (B, N, H, ...) layout; TODO confirm.
mha_sq_err = vmap(vmap(sq_err, in_axes=1))
# Apply sq_err leaf-by-leaf to two matching pytrees, returning Python floats.
tree_sq_err = partial(jax.tree.map, lambda ref, x: float(sq_err(ref, x)))

def benchmark_fn(fn, args, ref_out, *, warmup_iters=100, iters=100):
    """Jit ``fn``, report its error against ``ref_out`` and its average runtime.

    Prints the wall-clock time of the first (compiling) call, the leaf-wise
    relative squared error of its output versus ``ref_out``, and the mean time
    per call over ``iters`` timed runs that follow ``warmup_iters`` untimed runs.

    Args:
        fn: function to benchmark.
        args: tuple of positional arguments for ``fn``.
        ref_out: reference output with the same pytree structure as ``fn``'s.
        warmup_iters: number of untimed warm-up calls.
        iters: number of timed calls.
    """
    def run_once():
        return jax.block_until_ready(compiled(*args))

    # The first call includes tracing and compilation.
    t_start = time()
    compiled = jax.jit(fn)
    first_out = run_once()
    t_end = time()
    errors = tree_sq_err(ref_out, first_out)
    print(f"First call took {t_end - t_start:.4f} seconds.")
    print(f"Output errors: {errors}")

    for _ in range(warmup_iters):
        run_once()

    t_start = time()
    for _ in range(iters):
        run_once()
    t_end = time()
    print(f"Average time over {iters} iterations: {1e3 * (t_end - t_start) / iters:.4f} ms.")

# Module-level single-level clustering configuration used by
# _do_multi_level_clustering. The same settings are repeated inside get_data.
# NOTE(review): the parameter meanings are defined by MultiLevelClustering in
# single_level_attention and are not visible from this file.
_multi_level_clustering = MultiLevelClustering(
    Qs=(64,),
    Ks=(64,),
    coupled_clustering=False,
    inner_iters=3,
    outer_iters=3,
    max_cluster_scale=1.5,
    compute_indices=True,
)

def _do_multi_level_clustering(qs, ks, vs):
    """Cluster q/k with the module-level ``_multi_level_clustering`` config.

    Args:
        qs, ks, vs: queries, keys and values passed straight to
            ``_multi_level_clustering.do_clustering``.

    Returns:
        ``(qmeta, kmeta)``: the two 4-tuples returned by ``do_clustering``,
        each repackaged as ``ClusterData``.
    """
    q_fields, k_fields = _multi_level_clustering.do_clustering(qs, ks, vs)
    qcnt, qlab, qfwd, qbwd = q_fields
    kcnt, klab, kfwd, kbwd = k_fields
    return ClusterData(qcnt, qlab, qfwd, qbwd), ClusterData(kcnt, klab, kfwd, kbwd)

# Batched clustering: the outer vmap maps over axis 0 and the inner over axis 1
# of what remains (axis 2 of the full array, also used for the outputs) —
# presumably batch and head of a (B, N, H, D) layout; TODO confirm.
_do_multi_level_clustering_mha = vmap(vmap(_do_multi_level_clustering, in_axes=1, out_axes=1))

def cluster_attn(qs, ks, vs):
    """Cluster q/k and run attention with the resulting metadata.

    NOTE(review): ``attn`` is not defined or imported under that name in this
    file (``flash_attention.attn`` is imported as ``pallas_flash``), so calling
    this function raises ``NameError`` unless ``attn`` is provided elsewhere —
    confirm the intended target.
    """
    qmeta, kmeta = _do_multi_level_clustering(qs, ks, vs)
    return attn(qs, ks, vs, qmeta, kmeta)

# cluster_attn batched twice: the outer vmap maps over axis 0, the inner over
# axis 1 of what remains (axis 2 of the full array, also used for the outputs)
# — presumably batch and head of a (B, N, H, D) layout; TODO confirm.
cluster_attn_mha = vmap(vmap(cluster_attn, in_axes=1, out_axes=1))

def get_data(dtype, n):
    """Load recorded q/k/v tensors, chunk them to length ``n`` and cluster them.

    Reads ``q``, ``k`` and ``v`` from a hard-coded ``.npz`` path, splits the
    sequence axis into chunks of ``n`` that are folded into the batch axis,
    clusters the result with a fresh ``MultiLevelClustering`` and returns the
    data together with reference attention callables.

    Args:
        dtype: dtype the returned q/k/v arrays are cast to.
        n: chunk length; must divide the stored sequence length.

    Returns:
        Tuple ``(pallas_flash_mha, baseline_approx, baseline_approx_mha, qs,
        ks, vs, qmeta, kmeta)``. Clustering runs on the arrays as loaded; only
        the returned q/k/v are cast to ``dtype``.
    """
    archive = jnp.load("examples/minigpt/qkv_data_64_T1/qkv_step_1600.npz")
    ks = archive["k"]
    vs = archive["v"]
    qs = archive["q"]  # each (B, N, H, D) / (B, N, H, V)

    # Both unpacks double as rank-4 checks; the sequence length is taken from vs.
    _, seq_len, _, _ = qs.shape
    _, seq_len, _, _ = vs.shape
    assert seq_len % n == 0, f"N ({seq_len}) must be divisible by n ({n})"
    num_chunks = seq_len // n

    def fold_chunks(pattern, array):
        return einshape(pattern, array, m=num_chunks, n=n)

    qs = fold_chunks("b(mn)hd->(bm)nhd", qs)
    ks = fold_chunks("b(mn)hd->(bm)nhd", ks)
    vs = fold_chunks("b(mn)hv->(bm)nhv", vs)

    clustering = MultiLevelClustering(
        Qs=(64,),
        Ks=(64,),
        coupled_clustering=False,
        inner_iters=3,
        outer_iters=3,
        max_cluster_scale=1.5,
        compute_indices=True,
    )
    multi_attention = MultiLevelAttention(clustering=clustering)

    def baseline_approx(qs, ks, vs, qmeta, kmeta):
        # q and k each carry D ** -0.25, so their product carries D ** -0.5.
        root_scale = math.sqrt(math.sqrt(qs.shape[-1]))
        lse, out = multi_attention.attend(
            queries=qs / root_scale,
            qlabels=qmeta.lab,
            keys=ks / root_scale,
            klabels=kmeta.lab,
            values=vs,
            qfwd_indices=qmeta.fwd,
            qbwd_indices=qmeta.bwd,
            kfwd_indices=kmeta.fwd,
            kbwd_indices=kmeta.bwd,
        )
        return lse, out

    def over_batch_and_heads(fn):
        return vmap(vmap(fn, in_axes=1, out_axes=1))

    baseline_approx_mha = over_batch_and_heads(baseline_approx)
    flash_one_head = partial(pallas_flash, sm_scale=1.0 / math.sqrt(qs.shape[-1]))
    pallas_flash_mha = over_batch_and_heads(flash_one_head)

    q_fields, k_fields = over_batch_and_heads(clustering.do_clustering)(qs, ks, vs)
    qcnt, qlab, qfwd, qbwd = q_fields
    kcnt, klab, kfwd, kbwd = k_fields
    qmeta = ClusterData(qcnt, qlab, qfwd, qbwd)
    kmeta = ClusterData(kcnt, klab, kfwd, kbwd)
    return (
        pallas_flash_mha,
        baseline_approx,
        baseline_approx_mha,
        qs.astype(dtype),
        ks.astype(dtype),
        vs.astype(dtype),
        qmeta,
        kmeta,
    )

def check_meta_valid(meta):
    """Assert that a cluster-metadata object is internally consistent.

    Checks that ``meta.cnt`` equals the histogram of ``meta.lab`` and that
    looking up each point through ``fwd[lab[i], bwd[i]]`` returns ``i``.

    Args:
        meta: object with ``cnt`` (K,), ``lab`` (N,), ``fwd`` (K, CAP) and
            ``bwd`` (N,) integer arrays.

    Raises:
        AssertionError: if either check fails.
    """
    (num_points,) = meta.lab.shape
    num_clusters, _ = meta.fwd.shape

    histogram = jnp.bincount(meta.lab, minlength=num_clusters, length=num_clusters)
    assert jnp.all(histogram == meta.cnt), f"Counts mismatch: {histogram} != {meta.cnt}"

    probe = jnp.arange(num_points)
    grouped = probe[meta.fwd]  # (K, CAP)
    round_trip = grouped[meta.lab, meta.bwd]  # (N,)
    assert jnp.all(round_trip == probe), f"Forward and backward indices mismatch at {jnp.mean(round_trip != probe)*100.} % of indices"

def main():
    """Ad-hoc experiment driver for the clustering utilities in this file.

    Loads recorded q/k/v data, clusters a single head, prints covariance
    spectra and exits.

    NOTE(review): the first ``exit()`` below makes everything after it
    unreachable; the remaining sections (correctness checks and benchmarks)
    only run if that call is removed, and the last section additionally sits
    behind a second ``exit()``.
    """
    dtype = jnp.bfloat16
    requested_n = 2**12
    pallas_flash_mha, baseline_approx, baseline_approx_mha, qs, ks, vs, qmeta, kmeta = get_data(dtype=dtype, n=requested_n)
    def one_head(array):
        # First batch element, first head: (B, N, H, ...) -> (N, ...).
        return array[0,:, 0, ...]
    oh_qs, oh_ks, oh_vs = one_head(qs), one_head(ks), one_head(vs)
    oh_qmeta = jax.tree.map(one_head, qmeta)
    oh_kmeta = jax.tree.map(one_head, kmeta)
    print(f"oh shapes: {oh_qs.shape}, {oh_ks.shape}, {oh_vs.shape}")
    print(f"oh qmeta: {oh_qmeta.fwd.shape}, oh kmeta: {oh_kmeta.fwd.shape}")
    
    B, N, H, D = qs.shape
    _, _, _, V = vs.shape
    sm_scale = 1.0 / math.sqrt(D)
    print(f"B: {B}, N: {N}, H: {H}, D: {D}, V: {V}")
    # Random cotangents; not used further in the code below.
    dlse_shape = (B, N, H)
    dv_out_shape = (B, N, H, V)
    dlse = jax.random.normal(jax.random.PRNGKey(0), shape=dlse_shape, dtype=jnp.float32)
    dv_out = jax.random.normal(jax.random.PRNGKey(1), shape=dv_out_shape, dtype=dtype)
    oh_dlse = one_head(dlse)
    oh_dv_out = one_head(dv_out)

    #print(tree_sq_err(exact_oh_out, ref_oh_out))

    # Single-head clustering (K=64, cluster_max=96, 3 iterations) followed by
    # spectra of the raw and of the within-cluster-residual q/k covariances.
    ohqmeta, ohkmeta = do_clustering(64, 96, 3, oh_qs.astype(jnp.float32), oh_ks.astype(jnp.float32), oh_vs.astype(jnp.float32))
    consider_joint_variances(oh_qs.astype(jnp.float32), oh_ks.astype(jnp.float32), oh_vs.astype(jnp.float32))
    consider_joint_clustered_variances(oh_qs.astype(jnp.float32), oh_ks.astype(jnp.float32), oh_vs.astype(jnp.float32), ohqmeta, ohkmeta)
    # NOTE(review): everything below this call is unreachable.
    exit()


    K = 64
    CLUSTER_MAX = math.ceil(1.5 * N / K)
    print(f"K: {K}, CLUSTER_MAX: {CLUSTER_MAX}")
    NUM_ITERS = 1
    baseline_kmeans_oh = partial(baseline_kmeans, K, NUM_ITERS)
    baseline_kmeans_mha = vmap(vmap(baseline_kmeans_oh, in_axes=1, out_axes=1))
    baseline_assign_mha = vmap(vmap(baseline_assign, in_axes=1, out_axes=1))
    pallas_assign_mha = vmap(vmap(pallas_assign, in_axes=1, out_axes=1))
    baseline_kmeans_out_oh = baseline_kmeans_oh(oh_ks)
    pallas_assign_oh_out = pallas_assign(oh_ks, baseline_kmeans_out_oh)
    custom_assign_indices_oh = partial(custom_assign_indices, CLUSTER_MAX)
    custom_assign_indices_out_oh = custom_assign_indices_oh(oh_ks, baseline_kmeans_out_oh)
    # Single-head sanity checks on the custom CUDA index assignment.
    check_meta_valid(custom_assign_indices_out_oh)
    print(f"custom assign indices cnt: {custom_assign_indices_out_oh.cnt}")
    print(f"custom assign indices error: {tree_sq_err(custom_assign_indices_out_oh.lab, pallas_assign_oh_out)}")
    assert jnp.all(custom_assign_indices_out_oh.lab >=0) and jnp.all(custom_assign_indices_out_oh.lab < K), "Custom assign indices output has invalid labels"
    custom_assign_indices_mha = vmap(vmap(custom_assign_indices_oh, in_axes=1, out_axes=1))

    baseline_kmeans_out_mha = baseline_kmeans_mha(ks)
    baseline_assign_out_mha = baseline_assign_mha(ks, baseline_kmeans_out_mha)
    pallas_assign_out_mha = pallas_assign_mha(ks, baseline_kmeans_out_mha)
    subsample_baseline_kmeans_oh = partial(subsample_baseline_kmeans, K, 2**12, NUM_ITERS)
    subsample_baseline_kmeans_mha = vmap(vmap(subsample_baseline_kmeans_oh, in_axes=1, out_axes=1))
    custom_assign_indices_out_mha = custom_assign_indices_mha(ks, baseline_kmeans_out_mha)

    baseline_kmeans_indices_oh = partial(baseline_kmeans_indices, K, CLUSTER_MAX, NUM_ITERS)

    # NOTE(review): naive_kmeans_plus_plus takes (K, xs); binding NUM_ITERS here
    # makes the call two lines below pass three arguments — confirm and fix
    # before re-enabling this section.
    kmeans_plus_plus_oh = partial(naive_kmeans_plus_plus, K, NUM_ITERS)
    kmeans_plus_plus_mha = vmap(vmap(kmeans_plus_plus_oh, in_axes=1, out_axes=1))
    kmeans_plus_plus_out_oh = kmeans_plus_plus_oh(oh_ks)

    # Compare clustering quality of the different seeding strategies.
    evaluate_clustering(K, oh_ks, baseline_assign(oh_ks, baseline_kmeans_out_oh))
    evaluate_clustering(K, oh_ks, baseline_assign(oh_ks, kmeans_plus_plus_out_oh))
    evaluate_clustering(K, oh_ks, baseline_assign(oh_ks, subsample_baseline_kmeans(K, 2**10, NUM_ITERS, oh_ks)))
    print("Evaluating balanced clustering from custom assign indices...")
    evaluate_clustering(K, oh_ks, custom_assign_indices_out_oh.lab)

    print("Benchmarking mha baseline assign...")
    benchmark_fn(baseline_assign_mha, (ks, baseline_kmeans_out_mha), baseline_assign_out_mha)
    print("Benchmarking mha pallas assign...")
    benchmark_fn(pallas_assign_mha, (ks, baseline_kmeans_out_mha), baseline_assign_out_mha)
    #print("Benchmarking mha pallas assign indices...")
    #benchmark_fn(pallas_assign_indices_mha, (ks, baseline_kmeans_out_mha), pallas_assign_indices_out_mha)
    print("Benchmarking mha custom assign indices...")
    benchmark_fn(custom_assign_indices_mha, (ks, baseline_kmeans_out_mha), custom_assign_indices_out_mha)

    baseline_adjust_oh = partial(baseline_adjust, NUM_ITERS)
    baseline_adjust_mha = vmap(vmap(baseline_adjust_oh, in_axes=1, out_axes=1))
    print("Benchmarking mha baseline adjust...")
    benchmark_fn(baseline_adjust_mha, (ks, baseline_kmeans_out_mha), baseline_kmeans_out_mha)
    baseline_centroids_oh = partial(baseline_centroids, K)
    baseline_centroids_mha = vmap(vmap(baseline_centroids_oh, in_axes=1, out_axes=1))
    print("Benchmasking mha baseline centroids...")
    benchmark_fn(baseline_centroids_mha, (ks, baseline_assign_out_mha), baseline_kmeans_out_mha)




    kmeans_with_indices_oh = partial(kmeans_with_indices, K, CLUSTER_MAX, NUM_ITERS)
    kmeans_with_indices_mha = vmap(vmap(kmeans_with_indices_oh, in_axes=1, out_axes=1))
    print("Benchmarking mha kmeans with indices...")
    benchmark_fn(kmeans_with_indices_mha, (ks,), custom_assign_indices_out_mha)

    do_clustering_mha = vmap(vmap(partial(do_clustering, K, CLUSTER_MAX, NUM_ITERS), in_axes=1, out_axes=1))
    print("Benchmarking mha qk clustering...")
    benchmark_fn(do_clustering_mha, (qs, ks, vs), (qmeta, kmeta))

    # NOTE(review): second early exit; the benchmarks below it never run.
    exit()



    print("Benchmarking mha baseline kmeans...")
    benchmark_fn(baseline_kmeans_mha, (ks,), baseline_kmeans_out_mha)
    print("Benchmarking mha subsample baseline kmeans...")
    benchmark_fn(subsample_baseline_kmeans_mha, (ks,), baseline_kmeans_out_mha)
    print("Benchmarking mha baseline indices clustering...")
    benchmark_fn(_do_multi_level_clustering_mha, (qs, ks, vs), (qmeta, kmeta))

# Run the experiment driver only when this file is executed as a script.
if __name__ == "__main__":
    main()

    

