import jax
from jax import numpy as jnp
import numpy as np
from chex import Array
import os
import ctypes
from functools import partial

def baseline_similarity(xs: Array, centroids: Array) -> Array:  # [512, 64], [64, 64]
    """Return the dot product of every point with every centroid (``xs @ centroids.T``).

    Pure-JAX reference for ``cuda_similarity``. Row ``n``, column ``k`` of the
    result is ``<xs[n], centroids[k]>``.
    """
    # Contract the shared feature axis ``d`` of both operands.
    contraction = 'nd,kd->nk'
    similarity = jnp.einsum(contraction, xs, centroids)
    return similarity  # [512, 64]

def baseline_costs(xs: Array, totals: Array, counts: Array): # [512, 64], [64, 64], [64]
    """Return per-point, per-cluster assignment costs for one k-means problem.

    Centroids are ``totals / counts``. Entry ``[n, k]`` is
    ``|c_k|^2 - 2 <x_n, c_k>``, which equals ``|x_n - c_k|^2 - |x_n|^2``. The
    dropped ``|x_n|^2`` term is constant along each row, so ``argmin`` over the
    last axis still selects the nearest centroid.

    Clusters with ``counts == 0`` have no centroid, and their cost is ``+inf``
    so they are never selected. Without this guard the division gives
    ``0 / 0 = NaN``, and ``jnp.argmin`` (like NumPy) returns the index of a NaN
    entry, which would assign every point to the empty cluster.

    Args:
        xs: points, shape [N, D].
        totals: per-cluster sums of assigned points, shape [K, D].
        counts: per-cluster sizes, shape [K].

    Returns:
        Cost matrix of shape [N, K].
    """
    nonempty = counts > 0  # [K]
    # Divide by max(count, 1) so empty clusters never produce NaN; they are masked below.
    cs = totals / jnp.maximum(counts, 1)[:, None]  # [K, D]
    sim = jnp.einsum('nd,kd->nk', xs, cs)  # [N, K]
    csq = jnp.einsum('kd,kd->k', cs, cs)  # [K]
    costs = csq[None, :] - 2 * sim  # [N, K]
    return jnp.where(nonempty[None, :], costs, jnp.inf)

def baseline_kmeans_iter(xs: Array, totals: Array, counts: Array):
    """Run one Lloyd step of k-means in pure JAX.

    Each point is assigned to the cluster with the lowest ``baseline_costs``
    entry. The per-cluster sums and sizes are then rebuilt from scratch with
    scatter-adds, so a cluster that receives no points ends up with zero
    totals and a zero count.

    Returns:
        ``(new_totals, new_counts, labels)``: new_totals has the shape and dtype
        of ``totals``, new_counts those of ``counts``, and labels holds one
        cluster index per point.
    """
    labels = jnp.argmin(baseline_costs(xs, totals, counts), axis=-1)  # [512]
    # The two scatters are independent; each starts from zeros in the input's dtype.
    new_counts = jnp.zeros_like(counts).at[labels].add(1)  # [64]
    new_totals = jnp.zeros_like(totals).at[labels].add(xs)  # [64, 64]
    return new_totals, new_counts, labels

def baseline_kmeans(xs: Array, totals: Array, counts: Array, num_iters: int = 100):
    """Run ``num_iters`` Lloyd iterations (``baseline_kmeans_iter``) in pure JAX.

    The first iteration is traced directly. This yields a concrete
    ``(totals, counts, labels)`` carry whose dtypes the loop can reuse. The
    remaining ``num_iters - 1`` iterations run in ``jax.lax.fori_loop``, so
    under ``jax.jit`` the step is traced once instead of being unrolled into
    the compiled program ``num_iters`` times.

    Args:
        xs: points of one problem.
        totals: initial per-cluster sums.
        counts: initial per-cluster sizes.
        num_iters: number of Lloyd iterations (default 100). Must be a Python
            int >= 1 so that ``labels`` is always defined.

    Returns:
        ``(totals, counts, labels)`` after the last iteration.

    Raises:
        ValueError: if ``num_iters < 1``.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    state = baseline_kmeans_iter(xs, totals, counts)

    def step(_i, carry):
        prev_totals, prev_counts, _prev_labels = carry
        return baseline_kmeans_iter(xs, prev_totals, prev_counts)

    totals, counts, labels = jax.lax.fori_loop(1, num_iters, step, state)
    return totals, counts, labels

# --- CUDA FFI bindings (kernels compiled into kmeans.so) ---

# Load the compiled CUDA library, which is expected next to this file.
library_path = os.path.join(os.path.dirname(__file__), "kmeans.so")
kmeans_lib = ctypes.cdll.LoadLibrary(library_path)


def _register_cuda_target(target_name, handler):
    """Expose one exported handler of ``kmeans_lib`` to XLA as a CUDA-only FFI target."""
    jax.ffi.register_ffi_target(target_name, jax.ffi.pycapsule(handler), platform="CUDA")


# Each target name must match the one passed to jax.ffi.ffi_call by its wrapper.
_register_cuda_target("kmeans_costs", kmeans_lib.KmeansCosts)
_register_cuda_target("kmeans_similarity", kmeans_lib.KmeansSimilarity)
_register_cuda_target("kmeans_kernel", kmeans_lib.KmeansKernel)

def cuda_costs(xs: Array, totals: Array, counts: Array) -> Array:
    """Compute k-means assignment costs with the ``kmeans_costs`` CUDA kernel.

    GPU counterpart of ``baseline_costs`` for a single problem. The call is
    declared with ``vmap_method="broadcast_all"`` so it can be used under
    ``jax.vmap``.

    Args:
        xs: points, shape [512, 64], dtype bfloat16.
        totals: per-cluster sums, shape [64, 64], dtype bfloat16.
        counts: per-cluster sizes, shape [64], dtype int32.

    Returns:
        Cost matrix of shape [512, 64] and dtype bfloat16.

    Raises:
        ValueError: if an argument has the wrong shape or dtype. All three
            shapes are checked before any dtype.
    """
    # (name, value, expected shape, shape as printed, expected dtype, dtype as printed)
    expectations = (
        ("xs", xs, (512, 64), "[512, 64]", jnp.bfloat16, "bfloat16"),
        ("totals", totals, (64, 64), "[64, 64]", jnp.bfloat16, "bfloat16"),
        ("counts", counts, (64,), "[64]", jnp.int32, "int32"),
    )
    for name, value, shape, shape_text, _, _ in expectations:
        if value.shape != shape:
            raise ValueError(f"{name} must have shape {shape_text}, got {value.shape}")
    for name, value, _, _, dtype, dtype_text in expectations:
        if value.dtype != dtype:
            raise ValueError(f"{name} must be {dtype_text}, got {value.dtype}")

    kernel = jax.ffi.ffi_call(
        "kmeans_costs",
        jax.ShapeDtypeStruct((512, 64), jnp.bfloat16),
        vmap_method="broadcast_all",
    )
    return kernel(xs, totals, counts)

def cuda_similarity(xs: Array, centroids: Array) -> Array:
    """Compute ``xs @ centroids.T`` with the ``kmeans_similarity`` CUDA kernel.

    GPU counterpart of ``baseline_similarity`` for a single problem. The call
    is declared with ``vmap_method="broadcast_all"`` so it can be used under
    ``jax.vmap``. Unlike ``cuda_costs`` (bfloat16), this kernel takes and
    returns float16.

    Args:
        xs: points, shape [512, 64], dtype float16.
        centroids: centroid matrix, shape [64, 64], dtype float16.

    Returns:
        Similarity matrix of shape [512, 64] and dtype float16.

    Raises:
        ValueError: if an argument has the wrong shape or dtype (shapes are
            checked first).
    """
    if xs.shape != (512, 64):
        raise ValueError(f"xs must have shape [512, 64], got {xs.shape}")
    if centroids.shape != (64, 64):
        raise ValueError(f"centroids must have shape [64, 64], got {centroids.shape}")

    # The kernel's declared dtype is float16, so the messages must say float16 too.
    if xs.dtype != jnp.float16:
        raise ValueError(f"xs must be float16, got {xs.dtype}")
    if centroids.dtype != jnp.float16:
        raise ValueError(f"centroids must be float16, got {centroids.dtype}")

    return jax.ffi.ffi_call(
        "kmeans_similarity",
        jax.ShapeDtypeStruct((512, 64), jnp.float16),  # output shape/dtype
        vmap_method="broadcast_all",
    )(xs, centroids)

def _cluster_sums(xs, labels, num_clusters):
    """Return ``(totals, counts)`` per cluster from labels, mapping over any leading batch dims.

    ``jax.ops.segment_sum`` scatters along the leading axis of ``data`` only.
    Batched inputs are therefore handled by vmapping this function over one
    batch dimension at a time until ``labels`` is 1-D.
    """
    if labels.ndim > 1:
        return jax.vmap(lambda x, lab: _cluster_sums(x, lab, num_clusters))(xs, labels)
    counts = jax.ops.segment_sum(
        jnp.ones(labels.shape, dtype=jnp.int32), labels, num_segments=num_clusters
    )
    totals = jax.ops.segment_sum(xs, labels, num_segments=num_clusters)
    return totals, counts


@partial(jax.jit, donate_argnums=(1, 2))
def cuda_kmeans(xs: Array, totals: Array, counts: Array) -> tuple[Array, Array, Array]:
    """Run the ``kmeans_kernel`` CUDA kernel and rebuild cluster sums from its labels.

    The FFI call receives ``xs``, ``totals`` and ``counts`` (presumably the
    initial cluster state; confirm against the CUDA source) and returns only
    per-point labels. The returned totals and counts are recomputed here from
    those labels with ``jax.ops.segment_sum``.

    ``totals`` and ``counts`` are donated (``donate_argnums=(1, 2)``). When this
    function is called directly, rather than inside another ``jit``/``vmap``
    trace, their buffers may be invalidated, so callers must not reuse them.

    Args:
        xs: points, shape [..., 512, 64], dtype float16.
        totals: cluster totals, shape [..., 64, 64], dtype float16.
        counts: cluster counts, shape [..., 64], dtype int32.
        All three must share the same leading batch shape.

    Returns:
        ``(totals, counts, labels)``: per-cluster sums [..., 64, 64] (float16),
        per-cluster sizes [..., 64] (int32) and labels [..., 512] (int32).

    Raises:
        ValueError: on a wrong trailing shape, wrong dtype, or mismatched
            batch dimensions.
    """
    if len(xs.shape) < 2 or xs.shape[-2:] != (512, 64):
        raise ValueError(f"xs must have final dimensions [512, 64], got {xs.shape}")
    if len(totals.shape) < 2 or totals.shape[-2:] != (64, 64):
        raise ValueError(f"totals must have final dimensions [64, 64], got {totals.shape}")
    if len(counts.shape) < 1 or counts.shape[-1:] != (64,):
        raise ValueError(f"counts must have final dimensions [64], got {counts.shape}")

    if xs.dtype != jnp.float16:
        raise ValueError(f"xs must be float16, got {xs.dtype}")
    if totals.dtype != jnp.float16:
        raise ValueError(f"totals must be float16, got {totals.dtype}")
    if counts.dtype != jnp.int32:
        raise ValueError(f"counts must be int32, got {counts.dtype}")

    batch_shape = xs.shape[:-2]
    # The raw kernel receives all three buffers; mismatched batch sizes would
    # make it index past the end of the smaller ones.
    if totals.shape[:-2] != batch_shape or counts.shape[:-1] != batch_shape:
        raise ValueError(
            "xs, totals and counts must share the same batch dimensions, got "
            f"{xs.shape}, {totals.shape}, {counts.shape}"
        )

    labels = jax.ffi.ffi_call(
        "kmeans_kernel",
        jax.ShapeDtypeStruct(batch_shape + (512,), jnp.int32),
        vmap_method="broadcast_all",
    )(xs, totals, counts)

    new_totals, new_counts = _cluster_sums(xs, labels, totals.shape[-2])
    return new_totals, new_counts, labels

# --- end of CUDA FFI bindings ---

def print_kmeans_score(totals, counts, name=""):
    """Print a scalar quality score for one or a batch of k-means results.

    Centroids are ``totals / (counts + 1e-3)``; the epsilon keeps empty
    clusters finite. For each problem the score is
    ``sum_k <c_k, totals_k> / sum_k counts_k``. The printed value is the mean
    over any leading batch axes.

    Raises:
        AssertionError: if any centroid is non-finite (via ``assert``, so the
            check is skipped under ``python -O``).
    """
    centroids = totals / (1e-3 + counts[..., None])
    finite = jnp.isfinite(centroids)
    assert jnp.all(finite), f"Centroids contain non-finite values: {jnp.sum(~finite)} non-finite values found: {centroids[~finite]}"
    per_problem = jnp.sum(centroids * totals, axis=(-1, -2)) / jnp.sum(counts, axis=-1)
    score = jnp.mean(per_problem)
    print(f"Kmeans score [{name}]: {score:.3f}")

def _time_per_call_us(fn, args, warmup, reps):
    """Return the mean wall-clock time, in microseconds, of one ``fn(*args)`` call.

    ``warmup`` untimed calls run first (to exclude compilation), then ``reps``
    timed calls. Every call is followed by ``jax.block_until_ready`` so JAX's
    asynchronous dispatch does not hide device time.
    """
    from time import perf_counter  # monotonic, high-resolution benchmark clock

    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    start = perf_counter()
    for _ in range(reps):
        jax.block_until_ready(fn(*args))
    return (perf_counter() - start) / reps * 1e6


def main():
    """Check the CUDA FFI kernels against the pure-JAX baselines, then benchmark both.

    Needs the CUDA FFI targets registered from ``kmeans.so`` at import time.
    The script:

    1. Builds ``B`` random problems.
    2. Prints relative errors and agreement percentages between CUDA and
       baseline results.
    3. Asserts on shape mismatches and non-finite totals.
    4. Prints per-call timings.
    """
    B = 1024  # number of independent problems (vmapped)
    N = 512  # points per problem
    D = 64  # feature dimension
    K = 64  # clusters per problem
    test_xs = jax.random.normal(jax.random.PRNGKey(0), (B, N, D), dtype=jnp.float16) / 8.
    test_xs = test_xs.at[0, 8, :].set(1.0)
    test_xs_bf16 = test_xs.astype(jnp.bfloat16)  # [B, N, D]
    # Seed each problem's K clusters with every (N // K)-th point.
    initial_totals = test_xs[:, ::N // K, :][:, :K, :]  # [B, K, D]
    initial_totals_bf16 = initial_totals.astype(jnp.bfloat16)  # [B, K, D]
    initial_counts = jnp.ones((B, K), dtype=jnp.int32)  # [B, K]

    # Divide each cluster's total by its own count: counts broadcast as [B, K, 1].
    initial_centroids = initial_totals / initial_counts[..., None]  # [B, K, D]
    jit_baseline_similarity = jax.jit(jax.vmap(baseline_similarity))
    baseline_similarity_result = jit_baseline_similarity(test_xs, initial_centroids)
    jit_cuda_similarity = jax.jit(jax.vmap(cuda_similarity))
    cuda_similarity_result = jit_cuda_similarity(test_xs, initial_centroids)
    assert cuda_similarity_result.shape == baseline_similarity_result.shape, "CUDA similarity shape mismatch"
    error = jnp.linalg.norm(cuda_similarity_result - baseline_similarity_result)
    print(f"CUDA similarity relative error: {float(error / jnp.linalg.norm(baseline_similarity_result)):.6f}")

    jit_baseline_kmeans = jax.jit(jax.vmap(baseline_kmeans))
    baseline_totals, baseline_counts, baseline_labels = jit_baseline_kmeans(test_xs, initial_totals, initial_counts)
    print_kmeans_score(baseline_totals, baseline_counts, "baseline")

    jit_baseline_costs = jax.jit(jax.vmap(baseline_costs))
    baseline_costs_result = jit_baseline_costs(test_xs, initial_totals, initial_counts)

    jit_cuda_costs = jax.jit(jax.vmap(cuda_costs))
    cuda_costs_result = jax.block_until_ready(jit_cuda_costs(test_xs_bf16, initial_totals_bf16, initial_counts))

    assert cuda_costs_result.shape == baseline_costs_result.shape, "CUDA costs shape mismatch"
    error = jnp.linalg.norm(cuda_costs_result - baseline_costs_result)
    print(f"CUDA costs relative error: {float(error/ jnp.linalg.norm(baseline_costs_result)):.6f}")

    jit_cuda_kmeans = jax.jit(jax.vmap(cuda_kmeans))
    cuda_totals, cuda_counts, cuda_labels = jit_cuda_kmeans(test_xs, initial_totals, initial_counts)
    print_kmeans_score(cuda_totals, cuda_counts, "cuda")
    print(f"Base counts: {baseline_counts}")
    print(f"Indices with cluster 1: baseline {np.where(baseline_labels[0] == 1)[0]} cuda {np.where(cuda_labels[0] == 1)[0]}")
    print(f"Cuda counts: {cuda_counts}")
    print(jnp.sum(cuda_counts))
    print(f"totals min: {jnp.nanmin(cuda_totals)}, max: {jnp.nanmax(cuda_totals)}")

    assert cuda_totals.shape == baseline_totals.shape, "CUDA totals shape mismatch"
    print(f"CUDA totals relative error: {float(jnp.linalg.norm(cuda_totals - baseline_totals) / jnp.linalg.norm(baseline_totals)):.6f}")
    assert cuda_counts.shape == baseline_counts.shape, "CUDA counts shape mismatch"
    print(f"CUDA counts correctness: {jnp.mean(cuda_counts == baseline_counts) * 100:.2f}%")
    assert cuda_labels.shape == baseline_labels.shape, "CUDA labels shape mismatch"
    print(f"CUDA labels correctness: {jnp.mean(cuda_labels == baseline_labels) * 100:.2f}%")
    assert jnp.all(jnp.isfinite(cuda_totals)), f"CUDA totals contain non-finite values: {jnp.sum(~jnp.isfinite(cuda_totals))} non-finite values found: {cuda_totals[~jnp.isfinite(cuda_totals)]}"

    # Cheap kernels get more repetitions than the full k-means runs.
    fast_warmup, fast_reps = 25, 100
    slow_warmup, slow_reps = 5, 10
    benchmarks = (
        ("Baseline similarity", jit_baseline_similarity, (test_xs, initial_centroids), fast_warmup, fast_reps),
        ("CUDA similarity", jit_cuda_similarity, (test_xs, initial_centroids), fast_warmup, fast_reps),
        ("Baseline costs", jit_baseline_costs, (test_xs, initial_totals, initial_counts), fast_warmup, fast_reps),
        ("CUDA costs", jit_cuda_costs, (test_xs_bf16, initial_totals_bf16, initial_counts), fast_warmup, fast_reps),
        ("Baseline kmeans", jit_baseline_kmeans, (test_xs, initial_totals, initial_counts), slow_warmup, slow_reps),
        ("CUDA kmeans", jit_cuda_kmeans, (test_xs, initial_totals, initial_counts), slow_warmup, slow_reps),
    )
    for label, fn, args, warmup, reps in benchmarks:
        micros = _time_per_call_us(fn, args, warmup, reps)
        print(f"{label} took {micros:.6f} microseconds per iteration")

if __name__ == "__main__":
    # Run the correctness checks and benchmarks only when executed as a script.
    main()
    

    



