import jax
from jax import numpy as jnp
from jax.ops import segment_sum
from functools import partial
from jax.numpy import einsum
from jax import random

@partial(jax.jit, static_argnums=(3,))
def baseline_aggregate(keys, values, labels, K):
    """Per-cluster key/value means and value-key cross-covariance via masking.

    For each cluster id ``i`` in ``range(K)`` (vectorised with ``jax.vmap``),
    a boolean mask selects the rows with ``labels == i``. All three statistics
    are masked means over those rows only.

    Args:
        keys: rows indexed as ``keys[n, :]``; assumed shape (N, Dk).
        values: rows indexed as ``values[n, :]``; assumed shape (N, Dv).
        labels: integer cluster id per row, shape (N,).
        K: number of clusters; static under ``jax.jit``.

    Returns:
        ``(kmeans, vmeans, vkcorrs)`` with shapes (K, Dk), (K, Dv), (K, Dv, Dk).
        ``vkcorrs[i]`` is the mean over cluster-i rows of
        ``outer(v - vmeans[i], k - kmeans[i])``. This is a population
        cross-covariance, not a normalised correlation. A cluster with no
        members gives NaN (the masked mean divides 0 by a count of 0).
    """
    def aggregate(i):
        mask = (labels == i)
        kmean = jnp.mean(keys, where=mask[:, None], axis=0)
        vmean = jnp.mean(values, where=mask[:, None], axis=0)
        # Per-row outer product, shape (N, Dv, Dk). It must be masked like the
        # means above so only rows of cluster i contribute to its covariance.
        outer = (values - vmean)[:, :, None] * (keys - kmean)[:, None, :]
        vkcorr = jnp.mean(outer, where=mask[:, None, None], axis=0)
        return kmean, vmean, vkcorr
    kmeans, vmeans, vkcorrs = jax.vmap(aggregate)(jnp.arange(K))
    return kmeans, vmeans, vkcorrs

@partial(jax.jit, static_argnums=(3,))
def segsum_aggregate(keys, values, labels, K):
    """Per-cluster key/value means and value-key cross-covariance via segment sums.

    Each statistic is a ``segment_sum`` over ``labels`` with ``K`` segments,
    divided by the per-cluster row count. The cross-covariance uses the
    moment form ``E[v k^T] - E[v] E[k]^T``. This form can lose precision
    relative to a centred computation when the means are large compared with
    the spread.

    Args:
        keys: rows indexed as ``keys[n, :]``; assumed shape (N, Dk).
        values: rows indexed as ``values[n, :]``; assumed shape (N, Dv).
        labels: integer cluster id per row, shape (N,).
        K: number of clusters / segments; static under ``jax.jit``.

    Returns:
        ``(kmean, vmean, vkcorr)`` with shapes (K, Dk), (K, Dv), (K, Dv, Dk).
        ``vkcorr`` is a population cross-covariance, not a normalised
        correlation. A cluster with a zero count gives NaN (0 / 0).
    """
    ones = jnp.ones((keys.shape[0],), dtype=keys.dtype)
    counts = segment_sum(ones, labels, num_segments=K)
    kmean = segment_sum(keys, labels, num_segments=K) / counts[:, None]
    vmean = segment_sum(values, labels, num_segments=K) / counts[:, None]
    # Second moment E[v k^T], shape (K, Dv, Dk): v on axis 1, k on axis 2.
    vkmom = segment_sum(values[:, :, None] * keys[:, None, :], labels, num_segments=K) / counts[:, None, None]
    # The mean outer product must use the same (K, Dv, Dk) orientation as vkmom.
    vkcorr = vkmom - vmean[:, :, None] * kmean[:, None, :]
    return kmean, vmean, vkcorr



def main():
    """Dump the lowered HLO of both aggregators, check they agree, and time them.

    The inputs are Gaussian keys/values with N=2**14 rows and D=2**6 features,
    plus K=2**6 uniformly drawn integer labels. Every call is wrapped in
    ``jax.block_until_ready``, so a timing covers completion of the computation
    and not only its dispatch.
    """
    from time import perf_counter

    N = 2**14
    D = 2**6
    K = 2**6
    keys = random.normal(random.PRNGKey(0), (N, D))
    values = random.normal(random.PRNGKey(1), (N, D))
    labels = random.randint(random.PRNGKey(2), (N,), 0, K)

    def call(fn):
        """Run one aggregator on the benchmark inputs and wait for its results."""
        return jax.block_until_ready(fn(keys, values, labels, K))

    def time_per_call_us(fn, reps=100):
        """Return the mean wall time of one call to ``fn`` in microseconds."""
        # perf_counter is monotonic and high-resolution; time.time() is neither.
        start = perf_counter()
        for _ in range(reps):
            call(fn)
        return (perf_counter() - start) / reps * 1e6

    # print hlo for baseline_aggregate
    lowered_baseline_aggregate = baseline_aggregate.lower(keys, values, labels, K)
    print(f"baseline_aggregate HLO:\n{lowered_baseline_aggregate.as_text()}")

    # print hlo for segsum_aggregate
    lowered_segsum_aggregate = segsum_aggregate.lower(keys, values, labels, K)
    print(f"segsum_aggregate HLO:\n{lowered_segsum_aggregate.as_text()}")

    # Warmup: the first call of each jitted function triggers compilation,
    # which must not be included in the timings below.
    for _ in range(10):
        call(baseline_aggregate)
        call(segsum_aggregate)

    # Sanity check that the two implementations compute the same statistics.
    diffs = [
        float(jnp.max(jnp.abs(a - b)))
        for a, b in zip(call(baseline_aggregate), call(segsum_aggregate))
    ]
    print(
        "max |baseline - segsum| (kmean, vmean, vkcorr): "
        + ", ".join(f"{d:.3e}" for d in diffs)
    )

    ####################
    baseline_aggregate_mics = time_per_call_us(baseline_aggregate)
    print(f"baseline_aggregate time: {baseline_aggregate_mics:.2f} μs")
    segsum_aggregate_mics = time_per_call_us(segsum_aggregate)
    print(f"segsum_aggregate time: {segsum_aggregate_mics:.2f} μs")


if __name__ == "__main__":
    # Run the HLO dump and benchmark only when executed as a script, not on import.
    main()
