import jax
from jax import numpy as jnp
from jax.ops import segment_sum as segment_sum
from functools import partial
from jax.numpy import einsum
from jax import random
from einshape import jax_einshape as einshape
import math

import triton
import triton.language as tl
from jax_triton import triton_call

from jax.experimental import pallas as pl

@partial(jax.jit, static_argnums=(2,))
def segsum_centroids(xs, labels, K, weights=None):
    """Weighted mean of the points in each of ``K`` clusters.

    Coordinates and weights are summed in a single ``segment_sum`` by appending
    the weight as an extra column. Accumulation happens in float32 and the
    result is cast back to ``xs.dtype``. A cluster whose total weight is zero
    gets an all-zero centroid (the tiny epsilon turns 0/0 into 0/eps).

    Args:
        xs: (N, D) points.
        labels: (N,) integer cluster ids.
        K: number of clusters (static under jit).
        weights: optional (N,) per-point weights; all ones when omitted.

    Returns:
        (K, D) centroids with the same dtype as ``xs``.
    """
    out_dtype = xs.dtype
    points = xs.astype(jnp.float32)
    if weights is None:
        weights = jnp.ones(points.shape[:1], dtype=points.dtype)
    w = weights.astype(points.dtype)[:, None]  # (N, 1)
    # Columns 0..D-1 carry weighted coordinates; column D carries the weight itself.
    augmented = jnp.concatenate([points * w, w], axis=-1)
    sums = segment_sum(augmented, labels, num_segments=K)  # (K, D+1)
    weighted_sum, total_weight = sums[:, :-1], sums[:, -1:]
    return (weighted_sum / (total_weight + 1e-20)).astype(out_dtype)

@partial(jax.jit, static_argnums=(2,))
def baseline_centroids(xs, labels, K):
    """Reference per-cluster mean, one masked ``jnp.mean`` per cluster.

    Empty clusters produce NaN rows (0 / 0).

    Args:
        xs: (N, D) points.
        labels: (N,) integer cluster ids.
        K: number of clusters (static under jit).

    Returns:
        (K, D) cluster means.
    """
    def masked_mean(cluster):
        member_mask = (labels == cluster)[:, None]
        return jnp.mean(xs, axis=0, where=member_mask)

    cluster_ids = jnp.arange(K)
    return jax.vmap(masked_mean)(cluster_ids)

@jax.jit
def matmul_assign(xs, centroids, metric=None):
    """Assign each point to its nearest centroid using one matmul.

    The distance is (x - c)^T M (x - c) with M = ``metric`` (identity when
    None). Dropping the per-point constant x^T M x, the cost minimised over k is
    c_k^T M c_k - 2 x^T M c_k, which is exact when M is symmetric.

    Args:
        xs: (N, D) points.
        centroids: (K, D) centroids.
        metric: optional (D, D) matrix, assumed symmetric.

    Returns:
        (N,) index of the closest centroid for each point.
    """
    if metric is None:
        dual_centroids = centroids
    else:
        # dual_centroids[k] = centroids[k] @ metric. The two indices of
        # `metric` must be distinct: a repeated index ("dd") would make einsum
        # read only the diagonal of the metric.
        dual_centroids = einsum("kd,de->ke", centroids, metric)
    centroid_sqmags = jnp.sum(centroids * dual_centroids, axis=-1)[None, :]
    similarities = einsum("nd,kd->nk", xs, dual_centroids)
    return jnp.argmin(-2. * similarities + centroid_sqmags, axis=-1)

@jax.jit
def balanced_assign(xs, centroids, penalties):
    """Nearest-centroid assignment with an additive per-centroid penalty.

    Args:
        xs: (N, D) points.
        centroids: (K, D) centroids.
        penalties: (K,) cost added to every point's score for each centroid.

    Returns:
        (N,) index minimising ||c_k||^2 - 2 x.c_k + penalties[k].
    """
    sq_norms = jnp.sum(centroids**2, axis=-1)
    dots = einsum("nd, kd->nk", xs, centroids)
    costs = -2. * dots + sq_norms[None, :] + penalties[None, :]
    return jnp.argmin(costs, axis=-1)

@jax.jit
def capped_scan_assign(xs, centroids):
    """Approximately capacity-capped nearest-centroid assignment.

    The N points are split into B = 8 contiguous batches processed in order by
    ``lax.scan``. Before each batch, every cluster whose running count already
    exceeds ``N // K`` is excluded. Inside a batch the assignment is a plain
    argmin, so a cluster can overshoot the cap by up to one batch (N // 8).

    Args:
        xs: (N, D) points; N must be divisible by 8.
        centroids: (K, D) centroids.

    Returns:
        (N,) integer labels.

    Raises:
        ValueError: if N is not divisible by the number of batches.
    """
    N, _ = xs.shape
    K, _ = centroids.shape
    B = 8
    if N % B != 0:
        raise ValueError(f"N ({N}) must be divisible by the number of batches ({B})")
    n = N // B
    cap = N // K
    centroid_sqmags = jnp.sum(centroids**2, axis=-1)[None, :]
    similarities = einsum("nd, kd->nk", xs, centroids)
    scores = -2. * similarities + centroid_sqmags

    def scan_body(counts, batch_scores):
        inactive = counts > cap
        batch_scores = jnp.where(inactive[None, :], jnp.inf, batch_scores)
        labels = jnp.argmin(batch_scores, axis=-1)
        return counts.at[labels].add(1), labels

    # Counts must be integers: a float16 counter (xs.dtype in the benchmarks)
    # stops incrementing at 2048, which would silently disable the cap.
    init_counts = jnp.zeros((K,), dtype=jnp.int32)
    # Batch b holds rows [b*n, (b+1)*n), i.e. "(bn)k -> bnk" with b outer.
    _, batched_labels = jax.lax.scan(scan_body, init_counts, scores.reshape(B, n, K))
    return batched_labels.reshape(N)

@jax.jit
def baseline_assign(xs, centroids):
    """Reference nearest-centroid assignment that materialises every difference.

    Builds an (N, K, D) tensor of point-minus-centroid offsets, so it is memory
    hungry. It is meant as a ground truth for the matmul-based variants.

    Returns:
        (N,) index of the closest centroid (first one on ties).
    """
    offsets = xs[:, None, :] - centroids[None, :, :]  # (N, K, D)
    sq_dists = (offsets**2).sum(axis=-1)  # (N, K)
    return sq_dists.argmin(axis=1)

def fast_kmeans_with_init(xs, key, k, iters, metric=None, weights=None):
    """Random-initialisation wrapper around ``matmul_segsum_kmeans``.

    ``k`` distinct rows of ``xs`` are drawn without replacement as the initial
    centroids. When ``weights`` is given, rows are drawn with probability
    proportional to their weight.

    Returns:
        The ``(centroids, labels)`` pair produced by ``matmul_segsum_kmeans``.
    """
    if weights is None:
        probs = None
    else:
        probs = weights / jnp.sum(weights)
    initial = random.choice(key, xs, shape=(k,), replace=False, p=probs)
    return matmul_segsum_kmeans(xs, initial, k, iters, metric=metric, weights=weights)

@partial(jax.jit, static_argnums=(2, 3))
def matmul_segsum_kmeans(xs, centroids, K, num_iters, metric=None, weights=None):
    """Lloyd's k-means using ``matmul_assign`` and ``segsum_centroids``.

    Args:
        xs: (N, D) points.
        centroids: (K, D) initial centroids.
        K: number of clusters (static).
        num_iters: number of assign/update iterations (static).
        metric: optional (D, D) metric passed to every assignment.
        weights: optional (N,) per-point weights for the centroid update.

    Returns:
        (centroids, labels): final centroids and the (N,) assignment of each
        point to them, computed under the same ``metric`` as the iterations.
    """
    def body_fn(_, centroids):
        labels = matmul_assign(xs, centroids, metric=metric)
        return segsum_centroids(xs, labels, K, weights=weights)

    centroids = jax.lax.fori_loop(0, num_iters, body_fn, centroids)
    # The final labels must use the same distance the iterations optimised.
    return centroids, matmul_assign(xs, centroids, metric=metric)

def balanced_kmeans_with_init(xs, key, k, iters, metric=None, max_cluster_size=None, weights=None):
    """Run ``balanced_kmeans`` from a deterministic, evenly strided initialisation.

    The initial centroids are the first ``k`` rows of ``xs[::N // k]``.
    ``key`` is accepted for signature parity with ``fast_kmeans_with_init`` but
    is not used.

    Returns:
        The ``(centroids, labels)`` pair produced by ``balanced_kmeans``.
    """
    del key  # initialisation is deterministic
    N, D = xs.shape
    stride = N // k
    centroids = xs[::stride][:k]
    assert centroids.shape == (k, D), f"Centroids shape mismatch: {centroids.shape} != {(k, D)}"
    return balanced_kmeans(
        xs, centroids, k, iters,
        metric=metric, max_cluster_size=max_cluster_size, weights=weights,
    )

def balanced_kmeans_with_indices(xs, k, iters, metric=None, max_cluster_size=None, weights=None):
    """Lloyd iterations followed by a capped assignment with bucket index tables.

    Centroids start at the first ``k`` rows of ``xs[::N // k]``. They are refined
    for ``iters`` rounds of ``matmul_assign`` + ``segsum_centroids``, and then
    ``pallas_assign_indices`` produces the final capped assignment.

    Returns:
        (counts, labels, fwd_indices, bwd_indices) from ``pallas_assign_indices``.
    """
    N, D = xs.shape
    centroids = xs[::N // k][:k]
    assert centroids.shape == (k, D), f"Centroids shape mismatch: {centroids.shape} != {(k, D)}"

    def lloyd_step(_, carry):
        (current,) = carry
        assignment = matmul_assign(xs, current, metric=metric)
        return (segsum_centroids(xs, assignment, k, weights=weights),)

    (centroids,) = jax.lax.fori_loop(0, iters, lloyd_step, (centroids,))
    return pallas_assign_indices(xs, centroids, metric=metric, max_cluster_size=max_cluster_size)

# Jitted entry point. k, iters (positions 1 and 2) and max_cluster_size must be
# hashable Python values because they determine array shapes or loop counts.
jit_balanced_kmeans_with_indices = jax.jit(
    balanced_kmeans_with_indices,
    static_argnums=(1, 2),
    static_argnames=("max_cluster_size",),
)

@partial(jax.jit, static_argnums=(2, 3), static_argnames="max_cluster_size")
def balanced_kmeans(xs, centroids, K, num_iters, metric=None, max_cluster_size=None, weights=None):
    """Lloyd iterations, then a capacity-capped final assignment.

    The iterations use the uncapped ``matmul_assign``. Only the final labels
    come from ``pallas_assign``, which enforces ``max_cluster_size``. Any
    centroid row containing inf/NaN is replaced by its initial value, both
    before every iteration and after the loop.

    Args:
        xs: (N, D) points.
        centroids: (K, D) initial centroids. K is taken from this shape.
        K: static cluster count (superseded by ``centroids.shape[0]``).
        num_iters: number of Lloyd iterations (static).
        metric: optional (D, D) metric.
        max_cluster_size: per-cluster capacity for the final assignment (static).
        weights: optional (N,) per-point weights for the centroid update.

    Returns:
        (centroids, labels).
    """
    init_centroids = centroids
    _num_points, _dim = xs.shape  # xs must be (N, D)
    K, _ = centroids.shape

    def restore_nonfinite(cs):
        # Rows containing inf/NaN fall back to the matching initial centroid.
        finite_rows = jnp.all(jnp.isfinite(cs), axis=-1, keepdims=True)
        return jnp.where(finite_rows, cs, init_centroids)

    def lloyd_step(_, cs):
        cs = restore_nonfinite(cs)
        assignment = matmul_assign(xs, cs, metric=metric)
        return segsum_centroids(xs, assignment, K, weights=weights)

    centroids = restore_nonfinite(jax.lax.fori_loop(0, num_iters, lloyd_step, centroids))
    labels = pallas_assign(xs, centroids, metric=metric, max_cluster_size=max_cluster_size)
    return centroids, labels

@partial(jax.jit, static_argnums=(2, 3))
def baseline_kmeans(xs, centroids, K, num_iters):
    """Reference Lloyd's k-means built from ``baseline_assign`` / ``baseline_centroids``.

    Returns:
        (centroids, labels) after ``num_iters`` iterations.
    """
    def lloyd_step(_, current):
        return baseline_centroids(xs, baseline_assign(xs, current), K)

    final = jax.lax.fori_loop(0, num_iters, lloyd_step, centroids)
    return final, baseline_assign(xs, final)

@triton.jit
def _triton_assign_kernel(
    costs_ptr, # (N, K)
    in_counts_ptr, # (K,) input/output
    labels_ptr, # (N,) output
    counts_ptr, # (K,) input/output
    N: tl.constexpr, # number of points
    K: tl.constexpr, # number of centroids
    CAP: tl.constexpr, # maximum number of points per centroid
    BLOCK_N: tl.constexpr, # number of points per block
    ):
    """ Simple sequential capped assignment kernel.

    Each program handles BLOCK_N consecutive rows of the row-major (N, K)
    costs matrix, one row after another. For each row it picks the cheapest
    cluster that this program has not already seen full. It reserves a slot
    with ``tl.atomic_add`` on the shared counts buffer. If that cluster already
    held CAP points, it releases the slot, excludes the cluster and retries.
    The chosen cluster index is stored in ``labels_ptr``.

    ``in_counts_ptr`` is not read here. ``triton_assign`` aliases its counts
    input to the counts output (``input_output_aliases={1:1}``), so it
    presumably refers to the same buffer as ``counts_ptr``.

    NOTE(review): if every cluster is full the ``while`` loop never exits.
    Callers must keep K * CAP >= N (``triton_assign`` uses CAP = N // K + 102).
    """
    assert N % BLOCK_N == 0, "N must be divisible by BLOCK_N"

    block_id = tl.program_id(0) # Which block of inputs we are processing
    cluster_id = tl.arange(0,K)
    n_base = block_id * BLOCK_N
    # Program-local exclusion mask, not a real count. An entry is 0 until this
    # program finds the cluster full, then CAP + 1. It persists across rows.
    counts = tl.zeros((K,), dtype=tl.int32)
    for n_off in tl.range(0, BLOCK_N, loop_unroll_factor=16):
        nidx = n_base + n_off
        costs = tl.load(costs_ptr + nidx * K + cluster_id)
        costs = tl.where(counts < CAP, costs, float("inf"))  # Apply the cap
        min_idx = tl.argmin(costs, axis=0)
        # Reserve a slot; old_count is the global count before this increment.
        old_count = tl.atomic_add(counts_ptr + min_idx, 1)
        while old_count >= CAP:
            # Cluster was already full: give the slot back and try the next-best one.
            tl.atomic_add(counts_ptr + min_idx, -1)
            costs = tl.where(min_idx == cluster_id, float("inf"), costs)  # Recompute costs
            counts = tl.where(min_idx == cluster_id, CAP + 1, counts)  # Update counts to avoid reassigning
            min_idx = tl.argmin(costs, axis=0)
            old_count = tl.atomic_add(counts_ptr + min_idx, 1)
        tl.store(labels_ptr + nidx, min_idx)

@jax.jit
def triton_assign(xs, centroids):
    """Capped Euclidean nearest-centroid assignment via ``_triton_assign_kernel``.

    Each cluster accepts at most ``N // K + 102`` points. The +102 slack is a
    fixed constant whose rationale is not documented here. Points overflowing
    their best cluster go to the next-best cluster with room.

    Args:
        xs: (N, D) points; N must be a multiple of 16.
        centroids: (K, D) centroids (the kernel's ``tl.arange(0, K)`` needs a
            power-of-two K).

    Returns:
        (N,) int32 labels.
    """
    N, D = xs.shape
    K, _ = centroids.shape
    centroid_sqmags = jnp.sum(centroids**2, axis=-1)[None, :]
    similarities = einsum("nd, kd->nk", xs, centroids)
    costs = -2. * similarities + centroid_sqmags
    out_shape = [
        jax.ShapeDtypeStruct(shape=(N,), dtype=jnp.int32),  # labels
        jax.ShapeDtypeStruct(shape=(K,), dtype=jnp.int32),  # counts
    ]
    block_n = 16
    assert N % block_n == 0, "N must be divisible by BLOCK_N"
    # Shared per-cluster counters; aliased to the kernel's counts output.
    counts = jnp.zeros((K,), dtype=jnp.int32)
    labels, _ = triton_call(
        costs,
        counts,
        kernel=_triton_assign_kernel,
        out_shape=out_shape,
        grid=(N // block_n,),
        input_output_aliases={1: 1},
        num_warps=1,
        N=N,
        K=K,
        CAP=(N // K) + 102,
        BLOCK_N=block_n,
    )
    return labels


def _pallas_assign_kernel(
    costs_ref, # (n, K)
    in_counts_ref, # (K,) input/output
    labels_ref, # (n,) output
    counts_ref, # (K,) input/output
    N: int, # number of points
    CAP: int, # maximum number of points per centroid
    ):
    """Capacity-capped greedy assignment for one block of ``n`` points.

    For each row of ``costs_ref`` the kernel picks the cheapest cluster not yet
    seen full. It reserves a slot with ``pl.atomic_add`` on ``counts_ref``,
    which is shared by all grid steps. If that cluster already held ``CAP``
    points, the slot is released, the cluster is excluded for the rest of this
    block, and the next-cheapest cluster is tried.

    The atomic increments are done outside the ``while_loop`` condition, which
    stays side-effect free (as in ``_pallas_assign_indices_kernel``). The loop
    also stops once no cluster is available. In that case the stray reservation
    is undone and the point is labelled -1. This should not happen when
    ``CAP * K >= N``, which ``pallas_assign`` asserts.

    ``in_counts_ref`` is not read: ``pallas_assign`` aliases it to
    ``counts_ref``. ``N`` is unused.
    """
    n, K = costs_ref.shape
    # Clusters this grid step has not seen full; persists across the block's points.
    available = jnp.ones((K,), dtype=jnp.bool)
    for i in range(n):
        costs = costs_ref[i, :]
        min_idx = jnp.argmin(jnp.where(available, costs, jnp.inf))
        # Tentatively reserve a slot; old_count is the count before the increment.
        old_count = pl.atomic_add(counts_ref, min_idx, 1)

        def cond_fn(tup):
            old_count, _, available = tup
            return (old_count >= CAP) & (jnp.sum(available) > 0)

        def body_fn(tup):
            old_count, min_idx, available = tup
            pl.atomic_add(counts_ref, min_idx, -1)  # give back the slot in the full cluster
            available = jnp.where(min_idx == jnp.arange(K), False, available)
            min_idx = jnp.argmin(jnp.where(available, costs, jnp.inf))
            old_count = pl.atomic_add(counts_ref, min_idx, 1)
            return (old_count, min_idx, available)

        (old_count, min_idx, available) = jax.lax.while_loop(
            cond_fn, body_fn, (old_count, min_idx, available))
        exhausted = jnp.sum(available) == 0

        @pl.when(exhausted)
        def _release_stray_slot():
            # The last reservation landed on a full cluster; undo it.
            pl.atomic_add(counts_ref, min_idx, -1)

        labels_ref[i] = jnp.where(exhausted, -1, min_idx)

@partial(jax.jit, static_argnames="max_cluster_size")
def pallas_assign(xs, centroids, metric=None, max_cluster_size=None):
    """Capacity-capped nearest-centroid assignment using ``_pallas_assign_kernel``.

    Args:
        xs: (N, D) points; N must be a multiple of 16.
        centroids: (K, D) centroids.
        metric: optional (D, D) metric M, assumed symmetric. The per-point cost
            is c^T M c - 2 x^T M c, i.e. (x - c)^T M (x - c) up to a per-point
            constant.
        max_cluster_size: per-cluster capacity (static); defaults to 2 * (N // K).

    Returns:
        (N,) int32 labels.
    """
    N, D = xs.shape
    K, _ = centroids.shape
    if max_cluster_size is None:
        max_cluster_size = (N // K) * 2
    assert max_cluster_size * K >= N, "max_cluster_size * K must be at least N"
    # Precompute the costs matrix
    if metric is None:
        dual_centroids = centroids
    else:
        # Distinct indices: a repeated "dd" would make einsum read only metric's diagonal.
        dual_centroids = einsum("kd,de->ke", centroids, metric)
    similarities = einsum("nd, kd->nk", xs, dual_centroids)
    centroid_sqmags = jnp.sum(centroids * dual_centroids, axis=-1)[None, :]
    costs = -2. * similarities + centroid_sqmags
    out_shape = [
        jax.ShapeDtypeStruct(shape=(N,), dtype=jnp.int32),  # labels
        jax.ShapeDtypeStruct(shape=(K,), dtype=jnp.int32),  # counts
    ]
    # Shared per-cluster counters; aliased to the kernel's counts output.
    counts = jnp.zeros((K,), dtype=jnp.int32)
    BLOCK_N = 16
    assert N % BLOCK_N == 0, "N must be divisible by BLOCK_N"
    labels, _ = pl.pallas_call(
        partial(_pallas_assign_kernel, N=N, CAP=max_cluster_size),
        out_shape=out_shape,
        input_output_aliases={1: 1},
        grid=(N // BLOCK_N,),
        in_specs=(
            pl.BlockSpec((BLOCK_N, K), lambda i: (i, 0)),
            pl.BlockSpec((K,), lambda i: (0,)),
        ),
        out_specs=(
            pl.BlockSpec((BLOCK_N,), lambda i: (i,)),
            pl.BlockSpec((K,), lambda i: (0,)),
        ),
    )(costs, counts)
    return labels



def _pallas_assign_indices_kernel(
    costs_ref, # (n, K) input
    in_counts_ref, # (K,) input/output
    in_labels_ref, # (n,) input/output
    in_fwd_indices_ref, # (K, CAP) input/output
    in_bwd_indices_ref, # (n,) input/output
    counts_ref, # (K,) input/output
    labels_ref, # (n,) output
    fwd_indices_ref, # (K, CAP) output
    bwd_indices_ref, # (n,) output
    N: int, # number of points
    ):
    """Capped greedy assignment of one block of points that also records slot indices.

    For each of the ``n`` rows of ``costs_ref``, the kernel picks the cheapest
    cluster still marked available. It reserves a slot with ``pl.atomic_add``
    on ``counts_ref``; the value returned by the add is the slot number inside
    that cluster. If the cluster was already full (slot >= CAP), the slot is
    released, the cluster is marked unavailable and the next-cheapest is tried.
    On success it writes:

        labels_ref[i]                  = cluster
        fwd_indices_ref[cluster, slot] = i + block_id * n
        bwd_indices_ref[i]             = slot

    ``i + block_id * n`` is the global row index, given the
    ``(BLOCK_N, K)`` / ``lambda i: (i, 0)`` costs BlockSpec used by
    ``pallas_assign_indices``.

    The ``in_*`` refs are not read. ``pallas_assign_indices`` aliases inputs
    1-4 to outputs 0-3, so presumably they share buffers with the outputs,
    which that caller pre-fills (counts with 0, the others with -1).
    ``N`` is unused.
    """
    n, K = costs_ref.shape
    _, CAP = in_fwd_indices_ref.shape
    # Clusters not yet seen full by this grid step; persists across the block's points.
    available = jnp.ones((K,), dtype=jnp.bool)
    block_id = pl.program_id(0)  # Get the block ID
    for i in range(n):
        costs = costs_ref[i, :]
        costs = jnp.where(available, costs, jnp.inf)  # Apply the cap
        min_idx = jnp.argmin(costs)
        # Tentative reservation; old_count is the cluster's count before the increment.
        old_count = pl.atomic_add(counts_ref, min_idx, 1)
        def cond_fn(tup):
            old_count, min_idx, available = tup
            # Retry while the chosen cluster was already full and something is still available.
            return (old_count >= CAP) & (jnp.sum(available) > 0)
        def body_fn(tup):
            old_count, min_idx, available = tup
            pl.atomic_add(counts_ref, min_idx, -1)  # give back the slot in the full cluster
            available = jnp.where(min_idx == jnp.arange(K), False, available)
            masked_costs = jnp.where(available, costs, jnp.inf)
            min_idx = jnp.argmin(masked_costs)
            old_count = pl.atomic_add(counts_ref, min_idx, 1)
            return (old_count, min_idx, available)
        (old_count, min_idx, available) = jax.lax.while_loop(
            cond_fn, body_fn, (old_count, min_idx, available))
        # NOTE(review): if no cluster remains available, the loop exits right after
        # an atomic_add on argmin(all-inf) that is never undone (over-counting that
        # cluster), and the point keeps its -1 initial values. pallas_assign_indices
        # asserts max_cluster_size * K >= N, which is meant to prevent this.
        @pl.when(jnp.sum(available) > 0)
        def do_assign():
            labels_ref[i] = min_idx  # Store the assigned cluster index for this point
            fwd_indices_ref[min_idx, old_count] = i + block_id * n  # Store the forward index
            bwd_indices_ref[i] = old_count  # Store the backward index within the cluster
        


@partial(jax.jit, static_argnames="max_cluster_size")
def pallas_assign_indices(xs, centroids, metric=None, max_cluster_size=None):
    """Capped nearest-centroid assignment that also builds bucket index tables.

    Args:
        xs: (N, D) points; N must be a multiple of 16.
        centroids: (K, D) centroids.
        metric: optional (D, D) metric M, assumed symmetric. The per-point cost
            is c^T M c - 2 x^T M c, i.e. (x - c)^T M (x - c) up to a per-point
            constant.
        max_cluster_size: per-cluster capacity (static); defaults to 2 * (N // K).

    Returns:
        counts: (K,) number of points per cluster.
        labels: (N,) cluster of each point (-1 if it could not be placed).
        fwd_indices: (K, max_cluster_size) point index stored in each slot, -1 if empty.
        bwd_indices: (N,) slot of each point within its cluster (-1 if not placed).
    """
    N, D = xs.shape
    K, _ = centroids.shape
    if max_cluster_size is None:
        max_cluster_size = (N // K) * 2
    assert max_cluster_size * K >= N, "max_cluster_size * K must be at least N"
    # Precompute the costs matrix
    if metric is None:
        dual_centroids = centroids
    else:
        # Distinct indices: a repeated "dd" would make einsum read only metric's diagonal.
        dual_centroids = einsum("kd,de->ke", centroids, metric)
    similarities = einsum("nd, kd->nk", xs, dual_centroids)
    centroid_sqmags = jnp.sum(centroids * dual_centroids, axis=-1)[None, :]
    costs = -2. * similarities + centroid_sqmags
    out_shape = [
        jax.ShapeDtypeStruct(shape=(K,), dtype=jnp.int32),  # counts
        jax.ShapeDtypeStruct(shape=(N,), dtype=jnp.int32),  # labels
        jax.ShapeDtypeStruct(shape=(K, max_cluster_size), dtype=jnp.int32),  # fwd_indices
        jax.ShapeDtypeStruct(shape=(N,), dtype=jnp.int32),  # bwd_indices
    ]
    # Initial values of the aliased outputs: zero counts, -1 meaning "unassigned".
    counts = jnp.zeros((K,), dtype=jnp.int32)
    labels = jnp.full((N,), -1, dtype=jnp.int32)
    fwd_indices = jnp.full((K, max_cluster_size), -1, dtype=jnp.int32)
    bwd_indices = jnp.full((N,), -1, dtype=jnp.int32)
    BLOCK_N = 16
    assert N % BLOCK_N == 0, "N must be divisible by BLOCK_N"
    counts, labels, fwd_indices, bwd_indices = pl.pallas_call(
        partial(_pallas_assign_indices_kernel, N=N),
        out_shape=out_shape,
        input_output_aliases={1: 0, 2: 1, 3: 2, 4: 3},
        grid=(N // BLOCK_N,),
        in_specs=(
            pl.BlockSpec((BLOCK_N, K), lambda i: (i, 0)),
            pl.BlockSpec((K,), lambda i: (0,)),
            pl.BlockSpec((BLOCK_N,), lambda i: (i,)),
            pl.BlockSpec((K, max_cluster_size), lambda i: (0, 0)),
            pl.BlockSpec((BLOCK_N,), lambda i: (i,)),
        ),
        out_specs=(
            pl.BlockSpec((K,), lambda i: (0,)),
            pl.BlockSpec((BLOCK_N,), lambda i: (i,)),
            pl.BlockSpec((K, max_cluster_size), lambda i: (0, 0)),
            pl.BlockSpec((BLOCK_N,), lambda i: (i,)),
        ),
    )(jax.lax.stop_gradient(costs), counts, labels, fwd_indices, bwd_indices)
    return counts, labels, fwd_indices, bwd_indices


def verify_counts_and_indices(counts, labels, fwd_indices, bwd_indices):
    """Check that the outputs of ``pallas_assign_indices`` are mutually consistent.

    Args:
        counts: (K,) number of points assigned to each cluster.
        labels: (N,) cluster id of each point.
        fwd_indices: (K, CAP) point index held in each cluster slot, -1 if empty.
        bwd_indices: (N,) slot of each point inside its cluster.

    Raises:
        AssertionError: describing the first inconsistency found.
    """
    K, n = fwd_indices.shape
    N = bwd_indices.shape[0]
    assert counts.shape == (K,), "Counts shape mismatch"
    assert labels.shape == (N,), "Labels shape mismatch"

    # Check the label range first: segment_sum silently drops out-of-range
    # labels, which would otherwise surface as a confusing count mismatch.
    assert jnp.all(labels >= 0), "Labels contain negative indices - some points were not assigned"
    assert jnp.all(labels < K), "Labels contain indices >= K - some points were assigned to non-existent clusters"
    segsum_counts = segment_sum(jnp.ones_like(labels), labels, num_segments=K)
    assert jnp.all(segsum_counts == counts), "Counts do not match segment sum of labels"
    assert jnp.all(bwd_indices >= 0), "Backward indices contain negative values - some points were not assigned"
    assert jnp.all(bwd_indices < n), "Backward indices out of bounds - some points were assigned to non-existent clusters"
    total_fwd_nonnegative = jnp.sum(fwd_indices >= 0)
    assert int(total_fwd_nonnegative) == N, "Forward indices do not cover all points - some points were not assigned"

    # Round trip: gather point ids through the forward table, then read each
    # point back through (label, slot). Empty (-1) slots wrap to the last id
    # but are never read back.
    test_data = jnp.arange(N)
    interm = test_data[fwd_indices]
    assert interm.shape == (K, n), "Forward indices do not have the correct shape"
    recon = interm[labels, bwd_indices]
    assert recon.shape == (N,), "Reconstructed shape does not match original shape"
    assert jnp.all(recon == test_data), "Reconstructed data does not match original data"

@partial(jax.jit, static_argnums=(2, 3), static_argnames=("ks", "iters", "max_cluster_size"))
def balanced_hkmeans_with_init(xs, key, ks, iters, metric=None, max_cluster_size=None):
    """One- or two-level (coarse -> fine) balanced k-means.

    Args:
        xs: (N, D) points.
        key: PRNG key. It is split into per-level / per-bucket keys and
            forwarded to ``balanced_kmeans_with_init``.
        ks: tuple of cluster counts per level, of length 1 or 2 (static).
        iters: Lloyd iterations per k-means call (static).
        metric: optional (D, D) metric.
        max_cluster_size: capacity of each finest-level cluster (static).
            Required when ``len(ks) == 2``.

    Returns:
        For one level: ``((centroids,), labels)``.
        For two levels: ``((coarse_centroids (K0, D), fine_centroids (K0, K1, D)),
        labels)``. Here ``labels`` is (N,) with values ``coarse * K1 + fine``,
        and -1 for any point that ended up in no coarse bucket.

    Raises:
        ValueError: if two levels are requested without ``max_cluster_size``.
    """
    if len(ks) == 1:
        centroids, labels = balanced_kmeans_with_init(xs, key, ks[0], iters, metric=metric, max_cluster_size=max_cluster_size)
        return (centroids,), labels

    assert len(ks) == 2, "Only two levels of clustering are supported for balanced_hkmeans"
    if max_cluster_size is None:
        raise ValueError("max_cluster_size is required for two-level balanced_hkmeans")
    K0, K1 = ks
    N, D = xs.shape
    # Overall capacity slack, split geometrically between the two levels.
    total_upscale = (max_cluster_size * K0 * K1) / N
    coarse_key, fine_key = random.split(key)
    coarse_max_size = math.ceil(math.sqrt(total_upscale) * N / K0)
    coarse_centroids, coarse_labels = balanced_kmeans_with_init(
        xs, coarse_key, K0, iters, metric=metric, max_cluster_size=coarse_max_size)

    # Gather each coarse cluster's members into a fixed-size bucket padded with
    # -1. The bucket size is rounded up to a multiple of 16 (the Pallas BLOCK_N).
    bucket_size = math.ceil(coarse_max_size / 16) * 16

    @jax.vmap
    def get_bucket_indices(i):
        (indices,) = jnp.where(coarse_labels == i, size=bucket_size, fill_value=-1)
        return indices

    bucket_indices = get_bucket_indices(jnp.arange(K0))  # (K0, bucket_size)
    valid = bucket_indices >= 0
    bucket_xs = xs[bucket_indices]  # (K0, bucket_size, D); padding rows read xs[-1]
    weights = jnp.where(valid, 1.0, 0.0)  # padding rows do not move the centroids

    fine_keys = random.split(fine_key, K0)
    fine_kmeans = jax.vmap(
        partial(balanced_kmeans_with_init, k=K1, iters=iters, metric=metric, max_cluster_size=max_cluster_size),
    )
    fine_centroids, fine_labels = fine_kmeans(bucket_xs, fine_keys, weights=weights)
    # Make labels globally unique: coarse cluster c owns [c * K1, (c + 1) * K1).
    fine_labels = jnp.where(fine_labels >= 0, fine_labels + jnp.arange(K0)[:, None] * K1, -1)

    # Scatter bucket labels back into the original point order. Padding slots
    # (-1) are redirected to the out-of-range index N and dropped. A plain
    # `.at[-1]` would wrap around and overwrite the label of the last point.
    scatter_idx = jnp.where(valid, bucket_indices, N)
    out_fine_labels = jnp.full((N,), -1, dtype=fine_labels.dtype)
    out_fine_labels = out_fine_labels.at[scatter_idx].set(fine_labels, mode="drop")
    return (coarse_centroids, fine_centroids), out_fine_labels


@partial(jax.jit, static_argnums=(1, 2))
def baseline_make_indices(labels, K, max_size):
    """Build a (K, max_size) table of the point indices belonging to each cluster.

    Row k lists, in ascending order, the positions where ``labels == k``.
    It is padded with -1, or truncated if the cluster has more than
    ``max_size`` members.
    """
    def members_of(cluster):
        return jnp.where(labels == cluster, size=max_size, fill_value=-1)[0]

    return jax.vmap(members_of)(jnp.arange(K))  # (K, max_size)

@jax.jit
def as_bucketed(xs, indices): # (N, D), (K, n) -> (K, n, D)
    """Gather rows of ``xs`` into a (K, n, D) bucket tensor using a (K, n) index table.

    Negative indices follow NumPy semantics and wrap around, so -1 padding
    slots pick up the last row of ``xs``.
    """
    bucketed = xs[indices]
    return bucketed

@partial(jax.jit, static_argnums=(1, 2))
def baseline_make_bucketed(labels, K, max_size, xs):
    """Bucket ``xs`` by cluster: a (K, max_size, D) tensor built from ``baseline_make_indices``."""
    return as_bucketed(xs, baseline_make_indices(labels, K, max_size))


    




def _benchmark(label, fn, reps):
    """Time ``reps`` blocking calls of ``fn()`` and print the mean latency.

    Prints ``"<label> time: <mean> μs"`` and returns the result of the last
    call, so the caller can inspect or validate it.
    """
    from time import time

    result = None
    start = time()
    for _ in range(reps):
        result = jax.block_until_ready(fn())
    end = time()
    mean_mics = (end - start) / reps * 1e6  # convert to microseconds
    print(f"{label} time: {mean_mics:.2f} μs")
    return result


def main(*, run_hkmeans=False):
    """Benchmark and sanity-check the clustering routines in this module.

    Prints the lowered HLO of several jitted functions and warms every variant
    up. It then times each variant, printing cluster quality (mean squared
    centroid norm vs. that of the data) and cluster-size extremes.

    Args:
        run_hkmeans: also benchmark ``balanced_hkmeans_with_init``. This section
            used to sit after an ``exit()`` call and never ran, so it is opt-in.
    """
    N = 2**16
    D = 2**6
    K = 2**6
    num_iters = 10
    xs = random.normal(random.PRNGKey(0), (N, D))
    labels = random.randint(random.PRNGKey(1), (N,), 0, K)
    centroids = random.choice(random.PRNGKey(2), xs, shape=(K,), replace=False)
    metric = random.normal(random.PRNGKey(3), (D, D))
    metric = metric @ metric.T  # symmetric positive semi-definite

    max_upscale = 1.5
    MAX_SIZE = math.ceil(N / K * max_upscale)

    dtype = jnp.float16
    xs = xs.astype(dtype)
    centroids = centroids.astype(dtype)
    metric = metric.astype(dtype)

    meansq_xs = jnp.sum(xs**2, axis=-1).mean()

    def evaluate_clustering(centroids, labels):
        """Print the mean squared centroid norm and the min/max cluster size."""
        # Size by the centroids actually passed, so the K*K fine level of the
        # hierarchical run is counted correctly too.
        num_clusters = centroids.shape[0]
        meansq_centroids = jnp.mean(jnp.sum(centroids**2, axis=-1))
        print(f"Score: {float(meansq_centroids):.2f} out of {float(meansq_xs):.2f}")
        counts = segment_sum(jnp.ones_like(labels), labels, num_segments=num_clusters)
        print(f"Clusters ({N//num_clusters}) Min: {counts.min()}, Max: {counts.max()}")

    separator = "-------------------------------------"

    for name, lowered in (
        ("segsum_centroids", segsum_centroids.lower(xs, labels, K)),
        ("matmul_assign", matmul_assign.lower(xs, centroids)),
        ("matmul_segsum_kmeans", matmul_segsum_kmeans.lower(xs, centroids, K, num_iters)),
        ("balanced_kmeans", balanced_kmeans.lower(xs, centroids, K, num_iters)),
    ):
        print(f"{name} HLO:\n{lowered.as_text()}")

    # Warm up (compile) every variant before any timing.
    for _ in range(10):
        jax.block_until_ready(baseline_centroids(xs, labels, K))
        jax.block_until_ready(segsum_centroids(xs, labels, K))
        jax.block_until_ready(matmul_assign(xs, centroids))
        jax.block_until_ready(baseline_assign(xs, centroids))
        jax.block_until_ready(capped_scan_assign(xs, centroids))
        jax.block_until_ready(triton_assign(xs, centroids))
        jax.block_until_ready(pallas_assign(xs, centroids))
        jax.block_until_ready(matmul_segsum_kmeans(xs, centroids, K, num_iters))
        jax.block_until_ready(baseline_kmeans(xs, centroids, K, num_iters))
        jax.block_until_ready(balanced_kmeans(xs, centroids, K, num_iters))

        jax.block_until_ready(matmul_assign(xs, centroids, metric))
        jax.block_until_ready(pallas_assign(xs, centroids, metric))
        jax.block_until_ready(matmul_segsum_kmeans(xs, centroids, K, num_iters, metric=metric))
        jax.block_until_ready(balanced_kmeans(xs, centroids, K, num_iters, metric=metric, max_cluster_size=MAX_SIZE))

        jax.block_until_ready(baseline_make_indices(labels, K, MAX_SIZE))
        jax.block_until_ready(baseline_make_bucketed(labels, K, MAX_SIZE, xs))
        jax.block_until_ready(jit_balanced_kmeans_with_indices(xs, K, num_iters, metric=metric, max_cluster_size=MAX_SIZE))
    for _ in range(10):
        new_counts, new_labels, new_fwd_indices, new_bwd_indices = jax.block_until_ready(
            pallas_assign_indices(xs, centroids, metric=metric, max_cluster_size=MAX_SIZE))
    verify_counts_and_indices(new_counts, new_labels, new_fwd_indices, new_bwd_indices)

    # Micro-benchmarks of single steps.
    reps = 10000
    _benchmark("baseline_centroids", lambda: baseline_centroids(xs, labels, K), reps)
    _benchmark("segsum_centroids", lambda: segsum_centroids(xs, labels, K), reps)
    _benchmark("matmul_assign", lambda: matmul_assign(xs, centroids), reps)
    _benchmark("baseline_assign", lambda: baseline_assign(xs, centroids), reps)
    _benchmark("matmul_assign with metric", lambda: matmul_assign(xs, centroids, metric), reps)

    reps = 1000
    _benchmark("triton_assign", lambda: triton_assign(xs, centroids), reps)
    _benchmark("pallas_assign", lambda: pallas_assign(xs, centroids), reps)

    # Full clustering runs.
    reps = 100
    print(separator)
    new_centroids, new_labels = _benchmark(
        "matmul_segsum_kmeans", lambda: matmul_segsum_kmeans(xs, centroids, K, num_iters), reps)
    evaluate_clustering(new_centroids, new_labels)

    print(separator)
    new_centroids, new_labels = _benchmark(
        "baseline_kmeans", lambda: baseline_kmeans(xs, centroids, K, num_iters), reps)
    evaluate_clustering(new_centroids, new_labels)

    print(separator)
    new_centroids, new_labels = _benchmark(
        "balanced_kmeans", lambda: balanced_kmeans(xs, centroids, K, num_iters), reps)
    evaluate_clustering(new_centroids, new_labels)

    print(separator)
    new_centroids, new_labels = _benchmark(
        "matmul_segsum_kmeans with metric",
        lambda: matmul_segsum_kmeans(xs, centroids, K, num_iters, metric=metric), reps)
    evaluate_clustering(new_centroids, new_labels)

    print(separator)
    print(f"Running balanced kmeans with metric and max cluster size {MAX_SIZE}")
    bal_centroids, bal_labels = _benchmark(
        "balanced_kmeans with metric",
        lambda: balanced_kmeans(xs, centroids, K, num_iters, metric=metric, max_cluster_size=MAX_SIZE), reps)
    evaluate_clustering(bal_centroids, bal_labels)
    print(separator)

    _benchmark("baseline_make_indices", lambda: baseline_make_indices(labels, K, MAX_SIZE), reps)
    print(separator)

    _benchmark("baseline_make_bucketed", lambda: baseline_make_bucketed(labels, K, MAX_SIZE, xs), reps)
    print(separator)

    new_counts, new_labels, new_fwd_indices, new_bwd_indices = _benchmark(
        "pallas_assign_indices",
        lambda: pallas_assign_indices(xs, bal_centroids, metric=metric, max_cluster_size=MAX_SIZE), reps)
    verify_counts_and_indices(new_counts, new_labels, new_fwd_indices, new_bwd_indices)
    print(separator)

    print(f"Running balanced kmeans with indices and num_iters {num_iters}, metric and max cluster size {MAX_SIZE}")
    new_counts, new_labels, new_fwd_indices, new_bwd_indices = _benchmark(
        "jit_balanced_kmeans_with_indices",
        lambda: jit_balanced_kmeans_with_indices(xs, K, num_iters, metric=metric, max_cluster_size=MAX_SIZE), reps)
    verify_counts_and_indices(new_counts, new_labels, new_fwd_indices, new_bwd_indices)
    new_centroids = segsum_centroids(xs, new_labels, K)
    evaluate_clustering(new_centroids, new_labels)
    print(separator)

    if not run_hkmeans:
        return

    reps = 10
    (coarse_centroids, fine_centroids), fine_labels = _benchmark(
        "balanced_hkmeans",
        lambda: balanced_hkmeans_with_init(
            xs, random.PRNGKey(4), (K, K), num_iters, metric=metric, max_cluster_size=N//K**2 * 2),
        reps)
    evaluate_clustering(coarse_centroids, fine_labels // K)
    evaluate_clustering(fine_centroids.reshape(-1, D), fine_labels)

if __name__ == "__main__":
    # Script entry point: run the benchmark / sanity-check suite in main().
    main()
