import jax
from jax import numpy as jnp
from jax.ops import segment_sum as segment_sum
from functools import partial
from jax.numpy import einsum
from jax import random
from einshape import jax_einshape as einshape
import math

import triton
import triton.language as tl
from jax_triton import triton_call

from jax.experimental import pallas as pl

@partial(jax.jit, static_argnums=(2,))
def segsum_centroids(xs, labels, K, weights=None):
    """Weighted mean of the rows of ``xs`` within each of ``K`` clusters.

    Args:
      xs: (N, D) points.
      labels: (N,) cluster index of each point, used as the segment id.
      K: number of clusters (static under jit).
      weights: optional (N,) per-point weights; uniform when None.

    Returns:
      (K, D) array whose row k is ``sum_{labels==k} w * x / sum_{labels==k} w``.
      A cluster with zero total weight (e.g. no points) yields 0/0 = NaN.
    """
    w = jnp.ones((xs.shape[0],), dtype=xs.dtype) if weights is None else weights
    w_col = w[:, None]
    # Append the weight as an extra column so a single segment_sum produces both
    # the weighted sums (first D columns) and the total weight (last column).
    augmented = jnp.concatenate([xs * w_col, w_col], axis=-1)  # (N, D+1)
    sums = segment_sum(augmented, labels, num_segments=K)
    numerator, denominator = sums[:, :-1], sums[:, -1:]
    return numerator / denominator

@jax.jit
def baseline_cross_moment(values, keys, weights):
    """Reference (pure jnp) weighted cross moment of ``values`` and ``keys``.

    Returns the (V, K) matrix ``sum_n weights[n] * outer(values[n], keys[n])``
    for ``values`` of shape (N, V), ``keys`` (N, K) and ``weights`` (N,).
    """
    # n is the shared sample axis and is summed out.
    subscripts = "n,nv,nk->vk"
    return jnp.einsum(subscripts, weights, values, keys)

def _fancy_pallas_cross_moment_kernel(values_ref, keys_ref, weights_ref, in_ref, out_ref):
    """Pallas kernel: add one sample block's contribution to einsum "nv,nk,qn->qvk".

    For the block of n samples visible to this program it computes
    ``out[q, v, k] += sum_n weights[q, n] * values[n, v] * keys[n, k]``. It forms
    the per-sample outer products (n, v, k), flattens them to (n, v*k), and
    contracts the sample axis with a single (q, n) @ (n, v*k) matmul.

    Args:
      values_ref: (n, v) block of values.
      keys_ref: (n, k) block of keys.
      weights_ref: (q, n) block of weights.
      in_ref: not read here. In ``pallas_cross_moment`` it is the zero buffer
        aliased to ``out_ref`` (``input_output_aliases={3: 0}``).
      out_ref: (q, v, k) accumulator, updated with ``pl.atomic_add`` over its
        whole extent.
    """
    values = values_ref[...]  # (n, v)
    keys = keys_ref[...]  # (n, k)
    # Per-sample outer products values[n, :] x keys[n, :].
    interm = values[:, :, None] * keys[:, None, :]  # (n, v, k)
    N, V, K = interm.shape  # N is not used below
    # Flatten (v, k) so the weighted sum over n becomes one 2-D matmul.
    flat_interm = einshape("nvk->n(vk)", interm)  # (n, v*k)
    weights = weights_ref[...] # (q, n)
    Q, _ = weights.shape  # Q is not used below
    out = weights @ flat_interm  # (q, v*k): contracts the sample axis n
    out = einshape("q(vk)->qvk", out, v=V, k=K)  # (q, v, k)
    # Equivalent single einsum, kept for reference:
    #out = einsum("qn,nvk->qvk", weights_ref[...], interm)  # (q, v, k)
    # Accumulate into the entire output. The add is atomic because every grid
    # program in pallas_cross_moment maps to the same full (Q, V, K) output block.
    pl.atomic_add(out_ref, (slice(None), slice(None), slice(None)), out)

def _pallas_cross_moment_kernel(values_ref, keys_ref, weights_ref, in_ref, out_ref):
    """Pallas kernel: per-weight-row variant of the "nv,nk,qn->qvk" accumulation.

    For each of the q weight rows, it computes
    ``values.T @ (keys * weights[i][:, None])``, a (v, k) matrix, and atomically
    adds it into ``out_ref[i]``. The loop is a static Python loop, unrolled at
    trace time.

    ``pallas_cross_moment`` below launches ``_fancy_pallas_cross_moment_kernel``
    instead. This kernel has the same ref signature, so it can presumably be
    swapped in there.

    Args:
      values_ref: (n, v) block of values.
      keys_ref: (n, k) block of keys.
      weights_ref: (q, n) block of weights.
      in_ref: not read. Presumably the zero buffer aliased to ``out_ref``.
      out_ref: (q, v, k) accumulator updated with ``pl.atomic_add``.
    """
    values = values_ref[...] # (n, v)
    keys = keys_ref[...] # (n, k)
    Q, N = weights_ref.shape  # N is not used below
    for i in range(Q):
        weights = weights_ref[i, :] # (n,) weights of row i
        # Alternative formulations, kept for reference:
        #result = values.T @ (keys * weights[:, None])
        #result = einsum("n,nv,nk->vk", weights, values, keys)
        result = einsum("nv,nk->vk", values, keys * weights[:, None])  # (v, k)
        pl.atomic_add(out_ref, (i, slice(None), slice(None)), result)



@jax.jit
def pallas_cross_moment(values, keys, weights):
    """Multi-weight cross moment computed with a Pallas kernel.

    Computes ``out[q, v, k] = sum_n weights[q, n] * values[n, v] * keys[n, k]``
    (einsum ``"qn,nv,nk->qvk"``). This is ``baseline_cross_moment`` evaluated for
    every row of ``weights`` at once.

    The sample axis N is split into blocks of BLOCK_N rows. Each grid program
    handles one block and atomically adds its partial sum into a shared,
    zero-initialised output. If N is not a multiple of BLOCK_N, the inputs are
    zero-padded along N. Padded samples have zero weight, so they contribute
    nothing and the result is unchanged.

    Args:
      values: (N, V) array.
      keys: (N, K) array.
      weights: (Q, N) array with one weighting of the N samples per row.

    Returns:
      (Q, V, K) array with dtype ``values.dtype``.
    """
    BLOCK_N = 16
    Q, N = weights.shape
    _, V = values.shape
    _, K = keys.shape

    # Round N up to a multiple of BLOCK_N. Shapes are static under jit, so this
    # is decided at trace time and is a no-op for already-aligned N.
    pad = -N % BLOCK_N
    if pad:
        values = jnp.pad(values, ((0, pad), (0, 0)))
        keys = jnp.pad(keys, ((0, pad), (0, 0)))
        weights = jnp.pad(weights, ((0, 0), (0, pad)))
        N += pad

    out_shape = jax.ShapeDtypeStruct(shape=(Q, V, K), dtype=values.dtype)
    # Zero-initialised accumulator, aliased to the output (input 3 -> output 0)
    # so the kernel's atomic adds start from zero.
    zeros = jnp.zeros((Q, V, K), dtype=values.dtype)
    return pl.pallas_call(
        _fancy_pallas_cross_moment_kernel,
        out_shape=out_shape,
        grid=(N // BLOCK_N,),
        in_specs=(
            pl.BlockSpec((BLOCK_N, V), lambda i: (i, 0)),  # values: rows of block i
            pl.BlockSpec((BLOCK_N, K), lambda i: (i, 0)),  # keys: rows of block i
            pl.BlockSpec((Q, BLOCK_N), lambda i: (0, i)),  # weights: columns of block i
            pl.BlockSpec((Q, V, K), lambda i: (0, 0, 0)),  # zero-initialised accumulator
        ),
        # Every program sees (and accumulates into) the whole output.
        out_specs=pl.BlockSpec((Q, V, K), lambda i: (0, 0, 0)),
        input_output_aliases={3: 0},
    )(
        values,   # (N, V)
        keys,     # (N, K)
        weights,  # (Q, N)
        zeros,    # (Q, V, K)
    )




def _pallas_assign_kernel(
    costs_ref, # (n, K) block of per-point costs; argmin picks the cluster
    in_counts_ref, # (K,) counts input; not read directly (aliased to counts_ref via input_output_aliases={1: 1} in pallas_assign)
    labels_ref, # (n,) output: assigned cluster index for each point of the block
    counts_ref, # (K,) per-cluster occupancy counters shared by all grid programs, updated with pl.atomic_add
    N: int, # number of points (not used in the kernel body)
    CAP: int, # maximum number of points per centroid
    ):
    """Pallas kernel: greedy capacity-constrained assignment for one block of points.

    Points in the block are processed in order. Each point tries its cheapest
    still-available cluster by optimistically incrementing that cluster's
    counter with ``pl.atomic_add``. If the previous count was already ``>= CAP``,
    the increment is rolled back, the cluster is marked unavailable, and the
    next cheapest cluster is tried. The successful increment is left in place,
    so ``counts_ref`` accumulates the occupancy of every cluster.

    ``available`` is carried across the points of the block. A cluster found
    full for one point is therefore skipped for all later points in the same
    block.

    NOTE(review): if every cluster becomes unavailable, the masked costs are all
    +inf. argmin then keeps returning an already-full cluster, so the
    while-loop may spin until another program's rollback frees a slot, or
    forever. The ``max_cluster_size * K >= N`` check in pallas_assign bounds
    total capacity. Confirm it also rules this out under concurrent rollbacks.
    """
    n, K = costs_ref.shape
    # Clusters this program has not (yet) found full.
    available = jnp.ones((K,), dtype=jnp.bool)
    # Static Python loop over the n points of the block (unrolled at trace time).
    for i in range(n):
        costs = costs_ref[i, :]
        costs = jnp.where(available, costs, jnp.inf)  # Apply the cap
        min_idx = jnp.argmin(costs)
        # Optimistically reserve a slot in cluster min_idx. Keep looping (run
        # body_fn) while the cluster was already at capacity.
        def cond_fn(tup):
            min_idx, available = tup
            old_count = pl.atomic_add(counts_ref, min_idx, 1)
            return old_count >= CAP
        # Undo the failed reservation, mark the cluster unavailable, and pick
        # the next cheapest available cluster.
        def body_fn(tup):
            min_idx, available = tup
            pl.atomic_add(counts_ref, min_idx, -1)
            available = jnp.where(min_idx == jnp.arange(K), False, available)
            masked_costs = jnp.where(available, costs, jnp.inf)
            min_idx = jnp.argmin(masked_costs)
            return (min_idx, available)
        (min_idx, available) = jax.lax.while_loop(
            cond_fn, body_fn, (min_idx, available))
        labels_ref[i] = min_idx  # Store the assigned cluster index for this point
        
@partial(jax.jit, static_argnames="max_cluster_size")
def pallas_assign(xs, centroids, metric=None, max_cluster_size=None):
    """Capacity-constrained nearest-centroid assignment.

    Each point gets the cluster with the lowest cost
    ``cost[n, k] = -2 <x_n, M c_k> + <c_k, M c_k>``, subject to at most
    ``max_cluster_size`` points per cluster. For symmetric M this cost equals the
    squared M-distance ``(x_n - c_k)^T M (x_n - c_k)`` minus the per-point
    constant ``x_n^T M x_n``. ``M`` is ``metric``, or the identity when
    ``metric`` is None. The greedy capacity-aware assignment is done by
    ``_pallas_assign_kernel``.

    Args:
      xs: (N, D) points. N must be a multiple of 16 (the kernel block size).
      centroids: (K, D) centroids.
      metric: optional (D, D) matrix M, assumed symmetric.
      max_cluster_size: static per-cluster capacity. Defaults to
        ``2 * (N // K)``, raised to ``ceil(N / K)`` when that is smaller
        (i.e. when N < K).

    Returns:
      (N,) int32 cluster labels.

    Raises:
      AssertionError: if ``max_cluster_size * K < N`` or N is not a multiple
        of the block size.
    """
    N, _ = xs.shape
    K, _ = centroids.shape
    if max_cluster_size is None:
        # Twice the average cluster size, but never below ceil(N / K). For N < K,
        # 2 * (N // K) alone is 0 and would always fail the capacity check.
        max_cluster_size = max(2 * (N // K), -(-N // K))
    assert max_cluster_size * K >= N, "max_cluster_size * K must be at least N"

    # Precompute the costs matrix.
    if metric is None:
        dual_centroids = centroids
    else:
        # dual_centroids[k, d] = (M c_k)_d = sum_e M[d, e] c_k[e]. A repeated
        # subscript such as "dd" would select M's diagonal instead of applying
        # the full matrix, so distinct indices are used.
        dual_centroids = einsum("de,ke->kd", metric, centroids)
    similarities = einsum("nd,kd->nk", xs, dual_centroids)
    centroid_sqmags = jnp.sum(centroids * dual_centroids, axis=-1)[None, :]
    costs = -2. * similarities + centroid_sqmags

    BLOCK_N = 16
    assert N % BLOCK_N == 0, "N must be divisible by BLOCK_N"
    # out_shape and out_specs use the same pytree structure (tuples).
    out_shape = (
        jax.ShapeDtypeStruct(shape=(N,), dtype=jnp.int32),  # labels
        jax.ShapeDtypeStruct(shape=(K,), dtype=jnp.int32),  # counts
    )
    # Per-cluster occupancy counters, aliased to the counts output
    # (input 1 -> output 1) so the kernel's atomic increments start from zero.
    counts = jnp.zeros((K,), dtype=jnp.int32)
    labels, _ = pl.pallas_call(
        partial(_pallas_assign_kernel, N=N, CAP=max_cluster_size),
        out_shape=out_shape,
        input_output_aliases={1: 1},
        grid=(N // BLOCK_N,),
        in_specs=(
            pl.BlockSpec((BLOCK_N, K), lambda i: (i, 0)),  # costs: rows of block i
            pl.BlockSpec((K,), lambda i: (0,)),  # counts: shared by all programs
        ),
        out_specs=(
            pl.BlockSpec((BLOCK_N,), lambda i: (i,)),  # labels of block i
            pl.BlockSpec((K,), lambda i: (0,)),  # counts: shared by all programs
        ),
    )(costs, counts)
    return labels

def main():
    """Sanity-check and benchmark the kernels in this module.

    Prints the lowered HLO of ``baseline_cross_moment``. It then runs a warm-up
    pass that compiles everything, checks output shapes, and reports the max
    absolute difference between the jnp and Pallas multi-weight cross moments.
    Finally it prints the mean wall time per call (µs) of each entry point
    over ``reps`` calls.
    """
    from time import perf_counter

    def time_us(fn, *args, reps):
        # Mean wall-clock time per call in microseconds. block_until_ready makes
        # each asynchronously dispatched call finish before the next starts, so
        # the measurement covers the computation, not just dispatch.
        start = perf_counter()
        for _ in range(reps):
            jax.block_until_ready(fn(*args))
        return (perf_counter() - start) / reps * 1e6

    N = 2**8
    D = 2**4
    K = 2**4
    Q = K
    warmup_iters = 10
    reps = 100

    xs = random.normal(random.PRNGKey(0), (N, D))
    labels = random.randint(random.PRNGKey(1), (N,), 0, K)
    centroids = random.choice(random.PRNGKey(2), xs, shape=(K,), replace=False)
    # Uniform sample weights. Swap in e.g. a normalized exp(normal) draw to
    # exercise non-trivial weighting.
    weights = jnp.ones((N,), dtype=xs.dtype)
    keys = random.normal(random.PRNGKey(4), (N, D))
    values = random.normal(random.PRNGKey(5), (N, D))

    stacked_keys = jnp.stack([keys] * K, axis=0)  # (K, N, D)
    stacked_values = jnp.stack([values] * K, axis=0)  # (K, N, D)
    stacked_weights = jnp.stack([weights] * Q, axis=0)  # (Q, N)

    # (K, N, D), (K, N, D), (Q, N) -> (Q, K, D, D)
    vmapped_baseline_cross_moment = jax.vmap(
        jax.vmap(baseline_cross_moment, in_axes=(None, None, 0), out_axes=0),
        in_axes=(0, 0, None),
        out_axes=1,
    )
    # (N, D), (N, D), (Q, N) -> (Q, D, D)
    multiweight_baseline_cross_moment = jax.vmap(
        baseline_cross_moment, in_axes=(None, None, 0), out_axes=0,
    )
    # (K, N, D), (K, N, D), (Q, N) -> (Q, K, D, D)
    vmapped_pallas_cross_moment = jax.vmap(
        pallas_cross_moment, in_axes=(0, 0, None), out_axes=1,
    )

    lowered_cross_moment = baseline_cross_moment.lower(values, keys, weights)
    print(f"baseline_cross_moment HLO:\n{lowered_cross_moment.as_text()}")

    # Warm-up: triggers compilation and sanity-checks output shapes.
    for _ in range(warmup_iters):
        new_centroids = jax.block_until_ready(segsum_centroids(xs, labels, K))
        jax.block_until_ready(pallas_assign(xs, new_centroids))
        baseline_cross_moment(values, keys, weights)
        vmap_vkcorr = vmapped_baseline_cross_moment(stacked_values, stacked_keys, stacked_weights)
        assert vmap_vkcorr.shape == (Q, K, D, D)
        multi_vkcorr = multiweight_baseline_cross_moment(values, keys, stacked_weights)
        assert multi_vkcorr.shape == (Q, D, D)
        print("Calling pallas_cross_moment...")
        pallas_multi_vkcorr = pallas_cross_moment(values, keys, stacked_weights)
        assert pallas_multi_vkcorr.shape == (Q, D, D)
        vmapped_pallas_cross_moment(stacked_values, stacked_keys, stacked_weights)
    print(f"Error: {jnp.max(jnp.abs(multi_vkcorr - pallas_multi_vkcorr)):.4f}")

    ####################
    benchmarks = (
        ("baseline_cross_moment", baseline_cross_moment,
         (values, keys, weights)),
        ("vmapped baseline_cross_moment", vmapped_baseline_cross_moment,
         (stacked_values, stacked_keys, stacked_weights)),
        ("multiweight baseline_cross_moment", multiweight_baseline_cross_moment,
         (values, keys, stacked_weights)),
        ("pallas_cross_moment", pallas_cross_moment,
         (values, keys, stacked_weights)),
        ("vmapped pallas_cross_moment", vmapped_pallas_cross_moment,
         (stacked_values, stacked_keys, stacked_weights)),
        ("segsum_centroids", segsum_centroids,
         (xs, labels, K)),
        ("pallas_assign", pallas_assign,
         (xs, centroids)),
    )
    for name, fn, args in benchmarks:
        print(f"{name} time: {time_us(fn, *args, reps=reps):.2f} μs")

if __name__ == "__main__":
    # Run the benchmark / sanity-check driver only when executed as a script.
    main()
