import jax
from jax import numpy as jnp, Array, ShapeDtypeStruct
import numpy as np
import os
import ctypes
from functools import partial

# Load the compiled kernel libraries that sit next to this module and expose
# their entry points to JAX as FFI targets for the CUDA platform.
library_path = os.path.join(os.path.dirname(__file__), "cluster.so")
cluster_lib = ctypes.cdll.LoadLibrary(library_path)
for _target_name in ("adjust", "adjust_fp16"):
    # The FFI target name matches the symbol exported by the shared library.
    jax.ffi.register_ffi_target(
        _target_name,
        jax.ffi.pycapsule(getattr(cluster_lib, _target_name)),
        platform="CUDA",
    )

# The D=128 fp16 kernel is loaded from its own shared library.
library_path_d128 = os.path.join(os.path.dirname(__file__), "cluster_d128.so")
cluster_lib_d128 = ctypes.cdll.LoadLibrary(library_path_d128)
jax.ffi.register_ffi_target(
    "adjust_fp16_d128",
    jax.ffi.pycapsule(cluster_lib_d128.adjust_fp16_d128),
    platform="CUDA",
)

def cuda_adjust(xs: Array, centroids: Array):
    """Invoke the registered ``"adjust"`` FFI target on ``xs`` and ``centroids``.

    The call declares two outputs: a ``(64, 64)`` float32 array and an int32
    array with one entry per row of ``xs``.  Under ``jax.vmap`` the call is
    mapped sequentially (``vmap_method="sequential"``).

    Args:
        xs: Rank-2 array; its leading dimension sizes the second output.
        centroids: Forwarded to the kernel unchanged.

    Returns:
        The pair of arrays produced by the FFI call.
    """
    num_points, _ = xs.shape  # unpacking also enforces that xs is rank 2
    output_specs = (
        ShapeDtypeStruct((64, 64), jnp.float32),
        ShapeDtypeStruct((num_points,), jnp.int32),
    )
    adjust_call = jax.ffi.ffi_call("adjust", output_specs, vmap_method="sequential")
    return adjust_call(xs, centroids)

def _cuda_adjust_fp16(beta: float, iters: int, xs: Array, totals: Array, counts: Array):
    """Invoke the fp16 "adjust" FFI target that matches the feature width of ``xs``.

    Args:
        beta: Forwarded to the kernel as a float32 attribute.
        iters: Forwarded to the kernel as an int32 attribute.
        xs: Rank-2 array whose trailing dimension must be 64 or 128.
        totals: Rank-2 array; its leading dimension ``K`` sizes the outputs.
        counts: Array whose shape equals ``totals.shape[:-1]``.

    Returns:
        ``(new_totals, new_counts)``, declared as float16 arrays of shape
        ``(K, D)`` and ``(K,)`` respectively.

    Raises:
        ValueError: If the trailing dimension of ``xs`` is neither 64 nor 128.
    """
    _, D = xs.shape
    K, _ = totals.shape
    assert counts.shape == totals.shape[:-1]

    if D not in (64, 128):
        raise ValueError(f"Unsupported D {D} for fp16 adjust (only 64 and 128 supported)")
    if D == 128:
        print("Using D=128 fp16 adjust")
    call_name = "adjust_fp16" if D == 64 else "adjust_fp16_d128"

    output_specs = (
        ShapeDtypeStruct((K, D), jnp.float16),
        ShapeDtypeStruct((K,), jnp.float16),
    )
    kernel = jax.ffi.ffi_call(call_name, output_specs, vmap_method="broadcast_all")
    new_totals, new_counts = kernel(
        xs, totals, counts, beta=np.float32(beta), iters=np.int32(iters)
    )
    return new_totals, new_counts


# Public entry point with beta=0.9 and iters=1 bound in advance.
cuda_adjust_fp16 = partial(_cuda_adjust_fp16, 0.9, 1)


def ref_assign(xs: Array, centroids: Array):
    """Return, for every row of ``xs``, the index of its best-scoring centroid.

    The score of centroid ``c`` for point ``x`` is ``x . c - 0.5 * |c|^2``;
    the label is the argmax of that score over centroids.

    Args:
        xs: ``(N, D)`` points.
        centroids: ``(K, D)`` centroids.

    Returns:
        ``(N,)`` integer labels.
    """
    half_sqnorm = 0.5 * jnp.sum(jnp.square(centroids), axis=1)
    scores = xs @ centroids.T - half_sqnorm[None, :]
    return jnp.argmax(scores, axis=1)

def ref_segsum(K, xs: Array, labels: Array):
    """Sum the rows of ``xs`` per label and count the members of each label.

    Args:
        K: Number of labels; labels ``0 .. K-1`` are reported.
        xs: ``(N, D)`` points.
        labels: ``(N,)`` integer label of each point.

    Returns:
        ``(totals, counts)`` where ``totals`` is ``(K, D)`` and ``counts`` is
        ``(K,)``.
    """
    def cluster_stats(cluster_id):
        member = labels == cluster_id
        # Rows outside the cluster are masked out of the reduction.
        total = jnp.sum(xs, where=member[:, None], axis=0)
        return total, jnp.sum(member)

    return jax.vmap(cluster_stats)(jnp.arange(K))

def ref_full_iter(xs: Array, centroids: Array):
    """Run one full k-means iteration: assign every point, then recompute means.

    A cluster that receives no points keeps its previous centroid.  Dividing
    its zero total by a zero count would otherwise give a NaN centroid, and
    that NaN would contaminate the scores of the next assignment step.

    Args:
        xs: ``(N, D)`` points.
        centroids: ``(K, D)`` current centroids.

    Returns:
        ``(K, D)`` updated centroids.
    """
    K, _ = centroids.shape
    labels = ref_assign(xs, centroids)
    totals, counts = ref_segsum(K, xs, labels)
    occupied = counts > 0
    # Use a dummy denominator for empty clusters; their rows are replaced below.
    safe_counts = jnp.where(occupied, counts, 1)
    means = totals / safe_counts[:, None]
    return jnp.where(occupied[:, None], means, centroids.astype(means.dtype))

def ref_blocked_iter(xs: Array, totals: Array, counts: Array):
    """Sweep ``xs`` in blocks of 64 rows, updating running totals after each block.

    For every block the current centroids are ``totals / counts``; the block
    is assigned against them, summed per cluster, and folded into the running
    state with ``update_totals``.

    Args:
        xs: ``(N, D)`` points; ``N`` must be a multiple of 64.
        totals: ``(K, D)`` running per-cluster sums.
        counts: ``(K,)`` running per-cluster weights.

    Returns:
        The updated ``(totals, counts)`` pair.
    """
    N, _ = xs.shape
    K, _ = totals.shape
    BLOCK_SIZE = 64
    assert N % BLOCK_SIZE == 0
    for start in range(0, N, BLOCK_SIZE):
        x_block = xs[start:start + BLOCK_SIZE]
        block_labels = ref_assign(x_block, totals / counts[:, None])
        block_totals, block_counts = ref_segsum(K, x_block, block_labels)
        totals, counts = update_totals(totals, block_totals, counts, block_counts)
    return totals, counts

def update_totals(old, new, old_count, new_count):
    """Fold a block's sums into the running sums with an exponential decay.

    The previous state is scaled by ``exp(-new_count / 8)`` per cluster before
    the new contribution is added, so clusters that received more new points
    keep less of their history.

    Args:
        old: ``(K, D)`` running sums.
        new: ``(K, D)`` sums of the new block.
        old_count: ``(K,)`` running weights.
        new_count: ``(K,)`` member counts of the new block.

    Returns:
        ``(updated_sums, updated_counts)``.
    """
    TAU = 8
    keep = jnp.exp(-new_count / TAU)
    updated_sums = old * keep[:, None] + new
    updated_counts = keep * old_count + new_count
    return updated_sums, updated_counts


def ref_blocked_adjust(xs: Array, centroids: Array):
    """Refine ``centroids`` by streaming ``xs`` through in blocks of 256 rows.

    The running sums start at the given centroids with a weight of one each.
    Every block is assigned against the current means and its per-cluster sums
    and counts are added to the running state (no decay).

    Args:
        xs: ``(N, D)`` points; ``N`` must be a multiple of 256.
        centroids: ``(K, D)`` initial centroids.

    Returns:
        ``(K, D)`` centroids after all blocks have been absorbed.
    """
    N, _ = xs.shape
    K, _ = centroids.shape
    BLOCK_SIZE = 256
    assert N % BLOCK_SIZE == 0
    cluster_ids = jnp.arange(K)
    totals = centroids
    counts = jnp.ones((K,))
    for start in range(0, N, BLOCK_SIZE):
        current = totals / counts[:, None]
        x_block = xs[start:start + BLOCK_SIZE]
        half_bias = 0.5 * jnp.sum(current**2, axis=1)
        block_labels = jnp.argmax(x_block @ current.T - half_bias[None, :], axis=1)

        def block_stats(cluster_id):
            member = block_labels == cluster_id
            return jnp.sum(x_block, where=member[:, None], axis=0), jnp.sum(member)

        block_totals, block_counts = jax.vmap(block_stats)(cluster_ids)
        totals += block_totals
        counts += block_counts
    return totals / counts[:, None]

def ref_kmeans_plusplus(K: int, xs: Array):
    """Pick ``K`` rows of ``xs`` as seeds, weighting by squared distance to the seeds so far.

    The first seed is drawn uniformly.  Each later seed is drawn with
    probability proportional to the squared distance from the point to its
    closest already-chosen seed.  The random stream is fixed
    (``PRNGKey(0)``), so the result is deterministic for a given input.

    Args:
        K: Number of seeds to return.
        xs: ``(N, D)`` candidate points.

    Returns:
        ``(K, D)`` array of the selected rows, in selection order.
    """
    N, _ = xs.shape
    rng, init_key = jax.random.split(jax.random.PRNGKey(0))
    first = jax.random.choice(init_key, N, (1,), replace=False)
    centroids = jnp.array(xs[first])
    for _ in range(K - 1):
        deltas = xs[:, None, :] - centroids[None, :, :]
        nearest_sq = jnp.min(jnp.sum(deltas**2, axis=-1), axis=1)
        weights = nearest_sq / jnp.sum(nearest_sq)
        rng, choice_key = jax.random.split(rng)
        pick = jax.random.choice(choice_key, N, (1,), replace=False, p=weights)
        centroids = jnp.vstack([centroids, xs[pick]])
    return centroids

def evaluate_clustering(xs: Array, centroids: Array):
    """Print and return the mean squared distance from each point to its nearest centroid.

    Leading batch dimensions are supported: the error is averaged over the
    points of each batch element and then over the batch.

    Args:
        xs: ``(..., N, D)`` points.
        centroids: ``(..., K, D)`` centroids.

    Returns:
        Scalar array holding the mean clustering error (the value that is
        printed), so callers can log or compare it.
    """
    sq_dists = jnp.sum(
        jnp.square(xs[..., :, None, :] - centroids[..., None, :, :]), axis=-1
    )
    labels = jnp.argmin(sq_dists, axis=-1)
    # Gather each point's nearest centroid; works with leading batch dims.
    residuals = xs - jnp.take_along_axis(centroids, labels[..., None], axis=-2)
    error = jnp.mean(jnp.sum(residuals**2, axis=-1), axis=-1)
    mean_error = jnp.mean(error)
    print("Clustering error:", mean_error)
    return mean_error
    

if __name__ == "__main__":
    # Demo / smoke test for the kernels registered at the top of this module.
    K = 64
    D = 64

    # ---- "adjust" kernel vs. a plain matmul + argmax reference ----
    xs = jax.random.normal(jax.random.PRNGKey(0), (64, D), dtype=jnp.bfloat16)
    centroids = jax.random.normal(jax.random.PRNGKey(1), (K, D), dtype=jnp.bfloat16)
    S, lab = cuda_adjust(xs, centroids)
    true_S = xs @ centroids.T
    true_lab = jnp.argmax(true_S, axis=1)
    def true_sum(centroid):
        return jnp.sum(xs, where=true_lab[:, None] == centroid, axis=0)
    true_totals = jax.vmap(true_sum)(jnp.arange(K))
    print("Max absolute difference in S:", jnp.max(jnp.abs(S - true_S)))
    print(S[:4, :4])
    print(true_S[:4, :4])
    print("Labels:", lab[:8])
    print("True Labels:", true_lab[:8])
    print("Incorrect label count:", jnp.sum(lab != true_lab))

    # NOTE(review): this compares S (the first kernel output) against the
    # per-cluster sums although the label says "totals" -- confirm which
    # kernel output is expected to hold the totals.
    print("Max absolute difference in totals:", jnp.max(jnp.abs(S - true_totals)))

    # ---- fp16 kernel on low-rank data ----
    many_N = 1024
    low_rank = 8
    many_low_dim_xs = jax.random.normal(jax.random.PRNGKey(2), (many_N, low_rank), dtype=jnp.float32)
    from scipy.linalg import hadamard
    # Embed the low-dimensional points into D dimensions using Hadamard rows.
    projected_many_xs = many_low_dim_xs @ hadamard(D)[:low_rank, :] / jnp.sqrt(D*low_rank)

    many_xs = projected_many_xs
    # The first K points serve as the initial centroids.
    many_centroids = many_xs[:K]

    many_xs_fp16 = many_xs.astype(jnp.float16)
    many_centroids_fp16 = many_centroids.astype(jnp.float16)
    many_centroids_sqmag_fp16 = jnp.sum(many_centroids_fp16**2, axis=1)

    cuda_totals, cuda_counts = cuda_adjust_fp16(
        many_xs_fp16, many_centroids_fp16, jnp.ones_like(many_centroids_fp16[...,0]))
    true_S = many_xs_fp16 @ many_centroids_fp16.T
    true_lab = jnp.argmax(true_S - 0.5*many_centroids_sqmag_fp16[None,:], axis=1)
    def true_sum(centroid):
        return jnp.sum(many_xs_fp16, where=true_lab[:, None] == centroid, axis=0)
    true_totals = jax.vmap(true_sum)(jnp.arange(K)) + many_centroids_fp16
    def true_count(centroid):
        return jnp.sum(true_lab == centroid).astype(jnp.float16)
    true_counts = jax.vmap(true_count)(jnp.arange(K)) + 1.0
    print("Max absolute difference in FP16 totals:", jnp.max(jnp.abs(cuda_totals - true_totals)))
    # Both count vectors have shape (K,); compare them element-wise.  Adding a
    # trailing axis to one side would broadcast to (K, K) and compare every
    # cluster's count with every other cluster's.
    print("Max absolute difference in FP16 counts:", jnp.max(jnp.abs(cuda_counts - true_counts)))

    print(f"True FP16 S[0,:4,:4]: {true_S[:64:8, ::8]}")
    print(f"True FP16 lab[0,:8]: {true_lab[:8]}")
    print(f"FP16 totals[:4,:4]: {cuda_totals[:4, :4]}")
    print(f"True FP16 totals[:4,:4]: {true_totals[:4, :4]}")
    print(f"FP16 counts[:4,:4]: {cuda_counts}")
    print(f"True FP16 counts[:4,:4]: {true_counts}")
    print(f"fp16 counts shape: {cuda_counts.shape}")
    print(f"true counts shape: {true_counts.shape}")

    # ---- clustering quality: full vs. blocked reference vs. CUDA ----
    baseline_error = jnp.mean(jnp.sum(jnp.square(many_xs), axis=-1))
    print("Baseline error (no clustering):", baseline_error)
    # evaluate_clustering prints the error itself, so print a header and call
    # it rather than printing its return value.
    print("Initial clustering:")
    evaluate_clustering(many_xs, many_centroids)
    full_centroids = many_centroids
    print("Full K-means:")
    for step in range(10):
        full_centroids = ref_full_iter(many_xs, full_centroids)
        evaluate_clustering(many_xs, full_centroids)
    blocked_totals = many_centroids
    blocked_counts = jnp.ones((K,))
    print("Blocked K-means:")
    for step in range(5):
        blocked_totals, blocked_counts = ref_blocked_iter(many_xs, blocked_totals, blocked_counts)
        # Optional rescaling of the running state between sweeps (2**0 == 1: off).
        blocked_totals /= 2**0
        blocked_counts /= 2**0
        blocked_centroids = blocked_totals / blocked_counts[:, None]
        evaluate_clustering(many_xs, blocked_centroids)
    # Optional full-iteration polish after the blocked sweeps (disabled).
    for step in range(0):
        blocked_centroids = ref_full_iter(many_xs, blocked_centroids)
        evaluate_clustering(many_xs, blocked_centroids)


    print("Blocked K-means cuda:")
    blocked_totals = many_xs_fp16[:K]
    blocked_counts = jnp.ones_like(blocked_totals[...,0])
    for step in range(5):
        blocked_totals, blocked_counts = cuda_adjust_fp16(many_xs_fp16, blocked_totals, blocked_counts)
        evaluate_clustering(many_xs_fp16.astype(jnp.float32), (blocked_totals / blocked_counts[...,None]).astype(jnp.float32))
    blocked_centroids = blocked_totals / blocked_counts[:, None]
    # Optional full-iteration polish after the CUDA sweeps (disabled).
    for step in range(0):
        blocked_centroids = ref_full_iter(many_xs_fp16, blocked_centroids)
        evaluate_clustering(many_xs_fp16.astype(jnp.float32), blocked_centroids.astype(jnp.float32))

if __name__ == "__main__":
    # Batched comparison of the reference and CUDA blocked k-means, followed
    # by a timing loop for the jitted, vmapped CUDA kernel.
    import nvtx
    from scipy.linalg import hadamard
    from time import perf_counter
    K = 64
    D = 64
    B = 64
    N = 1024
    low_rank = 8
    batch_low_dim_xs = jax.random.normal(jax.random.PRNGKey(3), (B, N, low_rank), dtype=jnp.float32)
    batch_xs = batch_low_dim_xs @ hadamard(D)[:low_rank, :] / jnp.sqrt(D*low_rank)
    # Replace the Hadamard embedding with a zero-padded one: the low-rank
    # coordinates occupy the first `low_rank` features, all others are zero.
    batch_xs = jnp.zeros_like(batch_xs).at[..., :low_rank].set(batch_low_dim_xs) / jnp.sqrt(low_rank)
    batch_centroids = batch_xs[:, :K, :]
    print("Batched Blocked K-means:")
    totals = batch_centroids
    counts = jnp.ones((B, K))
    for step in range(5):
        totals, counts = jax.vmap(ref_blocked_iter)(batch_xs, totals, counts)
        centroids = totals / counts[:, :, None]
        evaluate_clustering(batch_xs, centroids)
    # Optional full-iteration polish (disabled).
    for step in range(0):
        centroids = jax.vmap(ref_full_iter)(batch_xs, centroids)
        evaluate_clustering(batch_xs, centroids)

    print("Batched Blocked K-means cuda:")
    batch_xs_fp16 = batch_xs.astype(jnp.float16)
    batch_centroids_fp16 = batch_centroids.astype(jnp.float16)
    totals = batch_centroids_fp16
    counts = jnp.ones_like(totals[...,0])
    for step in range(5):
        totals, counts = jax.vmap(cuda_adjust_fp16)(batch_xs_fp16, totals, counts)
        centroids = totals / counts[..., None]
        evaluate_clustering(batch_xs_fp16.astype(jnp.float32), centroids.astype(jnp.float32))
    # Optional full-iteration polish (disabled).
    for step in range(0):
        centroids = jax.vmap(ref_full_iter)(batch_xs_fp16, centroids)
        evaluate_clustering(batch_xs_fp16.astype(jnp.float32), centroids.astype(jnp.float32))

    jit_vmap_cuda_adjust_fp16 = jax.jit(jax.vmap(cuda_adjust_fp16))

    # Warm up so compilation and first-launch costs are excluded from timing.
    for warmup in range(1000):
        jax.block_until_ready(jit_vmap_cuda_adjust_fp16(batch_xs_fp16, totals, counts))
    # perf_counter is monotonic and high-resolution, unlike wall-clock time(),
    # which can jump if the system clock is adjusted during the run.
    start = perf_counter()
    total_reps = 1000
    with nvtx.annotate("cuda_adjust_fp16"):
        for rep in range(total_reps):
            # Block on every call so each repetition's device work is included.
            jax.block_until_ready(jit_vmap_cuda_adjust_fp16(batch_xs_fp16, totals, counts))
    end = perf_counter()
    print(f"[cuda_adjust_fp16 K{K}] us per rep over {total_reps} reps: {(end-start)*1e6/total_reps:.1f}")







    


