import jax
from jax._src import config
config.update("jax_platforms", "cpu")
from jax import numpy as jnp
from chex import Array
from typing import Optional, Tuple, Callable
from functools import partial
from jax import random
from jax.random import PRNGKey
from absl import app
from absl import flags

from kmeans import hierarchical_kmeans

# Command-line flags (absl), each with its default value and a one-letter
# short_name alias. They are read through FLAGS inside main().
FLAGS = flags.FLAGS
flags.DEFINE_integer(name='num_points', default=4096,
                     help='Number of data points', short_name='n')
flags.DEFINE_integer(name='num_dims', default=16,
                     help='Number of dimensions', short_name='d')
flags.DEFINE_integer(name='num_clusters', default=16,
                     help='Number of clusters per level', short_name='k')
flags.DEFINE_integer(name='num_levels', default=2,
                     help='Number of hierarchical levels', short_name='l')
flags.DEFINE_integer(name='seed', default=42,
                     help='Random seed for reproducibility', short_name='s')
flags.DEFINE_integer(name='iters', default=20,
                     help='Number of iterations for k-means', short_name='i')

def generate_xs(n: int, d: int, key: PRNGKey) -> Array:
    """Draws an ``(n, d)`` array of standard-normal samples from ``key``."""
    shape = (n, d)
    samples = random.normal(key, shape=shape)
    return samples

def do_kmeans(xs: Array, k: int, levels: int, iters: int, key: PRNGKey) -> Array:
    """Runs ``hierarchical_kmeans`` on ``xs`` and returns its result.

    Every argument is forwarded by keyword, unchanged. Callers in this file
    treat the result as one integer cluster label per point.
    """
    return hierarchical_kmeans(
        xs=xs,
        key=key,
        k=k,
        levels=levels,
        iters=iters,
    )

def generate_and_cluster(n: int, d: int, k: int, levels: int, iters: int, seed: int) -> Tuple[Array, Array]:
    """Generates random data points and performs hierarchical k-means clustering.

    The key derived from ``seed`` is split into two independent subkeys: one
    for generating the data and one for the clustering. Reusing a single JAX
    key for both would make the clustering's random choices statistically
    dependent on the data values it was used to generate.

    Args:
      n: number of points to generate.
      d: number of dimensions per point.
      k: clusters per level, forwarded to ``do_kmeans``.
      levels: number of hierarchy levels, forwarded to ``do_kmeans``.
      iters: k-means iterations, forwarded to ``do_kmeans``.
      seed: integer seed; the whole run is deterministic given it.

    Returns:
      ``(xs, labels)``: the generated ``(n, d)`` points and the result of
      ``do_kmeans`` on them.
    """
    key = random.PRNGKey(seed)
    data_key, kmeans_key = random.split(key)
    xs = generate_xs(n, d, data_key)
    labels = do_kmeans(xs, k, levels, iters, kmeans_key)
    return xs, labels

def cluster_size(labels: Array, cluster_label: int) -> Array:
    """Counts the points whose label equals ``cluster_label``.

    The result is a 0-d JAX integer array, not a Python ``int``. This lets
    the function be ``jax.vmap``-ed over ``cluster_label``, which is then a
    traced scalar. This file uses it that way.

    Args:
      labels: integer cluster id per point.
      cluster_label: the id to count.

    Returns:
      0-d array holding the number of matching labels (0 if none match).
    """
    return jnp.sum(labels == cluster_label)

def analyze_cluster_sizes(labels: Array, total_clusters: int) -> Array:
    """Prints the 20th/50th/80th percentiles of cluster sizes and returns the sizes.

    Args:
      labels: integer cluster id per point.
      total_clusters: ids ``0 .. total_clusters - 1`` are counted. Ids that
        never occur get size 0 and are included in the percentiles.

    Returns:
      Array of shape ``(total_clusters,)`` with the point count of each id.
      The original version printed the percentiles but returned ``None``,
      despite its ``Array`` annotation.
    """
    cluster_sizes = jax.vmap(partial(cluster_size, labels))(jnp.arange(total_clusters))
    # One quantile call for all three levels; same interpolation as separate calls.
    q20, q50, q80 = jnp.quantile(cluster_sizes, jnp.array([0.2, 0.5, 0.8]))
    print(f"Cluster sizes 20:50:80: {q20:.0f} - {q50:.0f} - {q80:.0f}")
    return cluster_sizes

def cluster_sqmean_var(xs: Array, labels: Array, cluster_label: int) -> Tuple[Array, Array]:
    """Computes centroid energy and within-cluster variance for one cluster.

    Only rows of ``xs`` whose label equals ``cluster_label`` are used. Let
    ``mean`` be the per-column mean of those rows (the centroid). Then:

    * ``sqmean`` is the average over columns of ``mean ** 2``.
    * ``var`` is the average of ``x ** 2`` over member rows and columns,
      minus ``sqmean``. This equals the per-column variance averaged over
      columns.

    A cluster with no members makes the masked means divide zero by zero, so
    both outputs are NaN. Callers that aggregate over clusters should mask
    empty ones.

    Args:
      xs: data points, presumably ``(num_points, num_dims)`` as produced by
        ``generate_xs``. The label mask is broadcast along the last axis.
      labels: cluster id per row of ``xs``.
      cluster_label: the id to analyze. It may be a traced scalar under
        ``jax.vmap``.

    Returns:
      ``(sqmean, var)`` as 0-d arrays.
    """
    active_indices = labels == cluster_label
    mask = active_indices[..., None]
    mean = xs.mean(axis=0, where=mask)
    meansq = jnp.square(xs).mean(where=mask)
    sqmean = jnp.square(mean).mean()
    var = meansq - sqmean
    return sqmean, var

def analyze_cluster_variances(xs: Array, labels: Array, total_clusters: int) -> Tuple[Array, Array]:
    """Prints size-weighted averages of per-cluster centroid energy and variance.

    For each id in ``0 .. total_clusters - 1`` this computes, via
    ``cluster_sqmean_var``, the squared-centroid mean ("sqmean") and the
    within-cluster variance ("var"). It then averages both over clusters,
    weighting each cluster by its size.

    Empty clusters are masked out before weighting. ``cluster_sqmean_var``
    can return NaN for them (a masked mean over zero elements), and
    ``NaN * 0`` is still NaN. Without the mask, a single empty cluster would
    turn both printed averages into NaN.

    Args:
      xs: data points passed through to ``cluster_sqmean_var``.
      labels: integer cluster id per point.
      total_clusters: number of cluster ids to analyze.

    Returns:
      ``(weighted_sqmean, weighted_var)``, the same scalars that are printed.
    """
    cluster_ids = jnp.arange(total_clusters)
    sqmean, var = jax.vmap(partial(cluster_sqmean_var, xs, labels))(cluster_ids)
    sizes = jax.vmap(partial(cluster_size, labels))(cluster_ids)
    mean_size = jnp.mean(sizes)
    nonempty = sizes > 0
    # Replace NaNs from empty clusters; their weight is 0 anyway.
    sqmean = jnp.where(nonempty, sqmean, 0.0)
    var = jnp.where(nonempty, var, 0.0)
    # mean(x * sizes) / mean(sizes) == sum(x * sizes) / sum(sizes)
    weighted_sqmean = jnp.mean(sqmean * sizes) / mean_size
    weighted_var = jnp.mean(var * sizes) / mean_size
    print(f"Cluster size {mean_size:.0f} sqmean {weighted_sqmean:.2f} var {weighted_var:.2f}")
    return weighted_sqmean, weighted_var

def main(argv):
    """Generates data, clusters it, and prints statistics for each hierarchy level.

    Reads the module-level flags and prints the run configuration. Then, for
    each ``level`` in ``range(num_levels)``, it merges labels by integer
    division by ``k ** level`` and prints variance and size statistics over
    ``k ** (num_levels - level)`` clusters. Level 0 is therefore the finest.

    NOTE(review): this assumes ``hierarchical_kmeans`` returns flat labels in
    ``[0, k ** num_levels)`` with finer levels in the low-order base-k digits.
    Confirm against the kmeans module.

    Raises:
      app.UsageError: if positional arguments beyond the program name are
        given. Previously they were silently ignored.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    n = FLAGS.num_points
    d = FLAGS.num_dims
    k = FLAGS.num_clusters
    num_levels = FLAGS.num_levels
    seed = FLAGS.seed
    iters = FLAGS.iters

    print("Running kmeans_stats with:")
    print(f"  Number of points: {n}")
    print(f"  Number of dimensions: {d}")
    print(f"  Number of clusters per level: {k}")
    print(f"  Number of hierarchical levels: {num_levels}")
    print(f"  Random seed: {seed}")
    print(f"  Number of iterations: {iters}")

    xs, labels = generate_and_cluster(n, d, k, num_levels, iters, seed)
    for level in range(num_levels):
        cluster_count = k ** (num_levels - level)
        print("--------------------------------------")
        print(f"Level {level}: {cluster_count} clusters")
        # Merge k**level fine clusters into one coarse cluster at this level.
        coarse_labels = labels // (k ** level)
        analyze_cluster_variances(xs, coarse_labels, cluster_count)
        analyze_cluster_sizes(coarse_labels, cluster_count)

if __name__ == '__main__':
    # Hand control to absl, which parses the flags defined above before
    # invoking main(argv); main reads its configuration from FLAGS.
    app.run(main)
