import jax
from jax._src import config
config.update("jax_platforms", "cpu")
from jax import numpy as jnp
import chex
from chex import Array
from typing import Optional, Tuple, Callable
from functools import partial
from jax import random
from jax.random import PRNGKey
from absl import app
from absl import flags
from jax.tree_util import tree_map
from einshape import jax_einshape as einshape

from kmeans import hierarchical_kmeans
from cluster import Cluster, QueryCluster
from attention import exact_attention, hierarchical_attention

# Command-line flags (all integers, each with a default and a one-letter alias)
# that configure the simulation.
FLAGS = flags.FLAGS
flags.DEFINE_integer(
    name='num_points',
    default=4096,
    help='Number of data points',
    short_name='n',
)
flags.DEFINE_integer(
    name='num_dims',
    default=16,
    help='Number of dimensions',
    short_name='d',
)
flags.DEFINE_integer(
    name='num_clusters',
    default=16,
    help='Number of clusters per level',
    short_name='k',
)
flags.DEFINE_integer(
    name='num_levels',
    default=2,
    help='Number of hierarchical levels',
    short_name='l',
)
flags.DEFINE_integer(
    name='seed',
    default=42,
    help='Random seed for reproducibility',
    short_name='s',
)
flags.DEFINE_integer(
    name='iters',
    default=20,
    help='Number of iterations for k-means',
    short_name='i',
)

def generate_xs(n: int, d: int, key: PRNGKey) -> Array:
    """Draw an ``(n, d)`` array of standard-normal samples using ``key``."""
    shape = (n, d)
    samples = random.normal(key, shape)
    return samples

def cluster_size(labels: Array, cluster_label: int) -> Array:
    """Count how many entries of ``labels`` equal ``cluster_label``.

    Args:
        labels: Array of cluster ids.
        cluster_label: The cluster id to count. Callers in this file pass a
            traced scalar through ``jax.vmap``; a plain Python int works too.

    Returns:
        A 0-d integer JAX array holding the count (not a Python ``int``), so
        the function stays usable under ``jax.vmap``.
    """
    is_member = labels == cluster_label
    return jnp.sum(is_member)

def analyze_cluster_sizes(labels: Array, total_clusters: int) -> None:
    """Print the 20th, 50th and 80th percentiles of the cluster sizes.

    Args:
        labels: Array of cluster ids, one per point.
        total_clusters: Number of cluster ids to examine; sizes are computed
            for ids ``0 .. total_clusters - 1``.

    Returns:
        None. The summary line is written to stdout.
    """
    cluster_sizes = jax.vmap(partial(cluster_size, labels))(jnp.arange(total_clusters))
    # A single quantile call evaluates all three levels at once instead of
    # repeating the computation for each level.
    q20, q50, q80 = jnp.quantile(cluster_sizes, jnp.array([0.2, 0.5, 0.8]))
    print(f"Cluster sizes 20:50:80: {q20:.0f} - {q50:.0f} - {q80:.0f}")

def cluster_sqmean_var(xs: Array, labels: Array, cluster_label: int) -> Tuple[Array, Array]:
    """Compute the squared-centroid magnitude and the variance of one cluster.

    Only rows of ``xs`` whose label equals ``cluster_label`` take part.

    Args:
        xs: Data points; the first axis indexes points and is reduced to form
            the cluster centroid.
        labels: Cluster id per point (broadcast against ``xs`` with a trailing
            singleton axis).
        cluster_label: Id of the cluster to summarise.

    Returns:
        A pair ``(sqmean, var)`` of scalar arrays:
        ``sqmean`` is the mean over coordinates of the squared centroid, and
        ``var`` is the mean squared coordinate of the members minus ``sqmean``.
        Both are NaN when the cluster has no members, because the masked means
        are then taken over zero elements.
    """
    member_mask = (labels == cluster_label)[..., None]
    centroid = xs.mean(axis=0, where=member_mask)
    # Averaged over both members and coordinates, so this is a single scalar.
    meansq = jnp.square(xs).mean(where=member_mask)
    sqmean = jnp.square(centroid).mean()
    return sqmean, meansq - sqmean

def analyze_cluster_variances(xs: Array, labels: Array, total_clusters: int) -> None:
    """Print size-weighted averages of the per-cluster statistics.

    For every cluster id in ``0 .. total_clusters - 1`` this computes the
    ``(sqmean, var)`` pair from ``cluster_sqmean_var`` and the member count
    from ``cluster_size``, then prints the mean cluster size together with the
    size-weighted means of ``sqmean`` and ``var``.

    Args:
        xs: Data points.
        labels: Cluster id per point.
        total_clusters: Number of cluster ids to examine.

    Returns:
        None. The summary line is written to stdout.
    """
    cluster_ids = jnp.arange(total_clusters)
    sqmean, var = jax.vmap(partial(cluster_sqmean_var, xs, labels))(cluster_ids)
    sizes = jax.vmap(partial(cluster_size, labels))(cluster_ids)
    mean_size = jnp.mean(sizes)
    # An empty cluster has NaN statistics (mean over zero members), and
    # NaN * 0 is still NaN, so a single empty cluster would turn the whole
    # aggregate into NaN. Empty clusters carry zero weight, so drop them.
    occupied = sizes > 0
    sqmean = jnp.mean(jnp.where(occupied, sqmean * sizes, 0.0)) / mean_size
    var = jnp.mean(jnp.where(occupied, var * sizes, 0.0)) / mean_size
    print(f"Cluster size {mean_size:.0f} sqmean {sqmean:.2f} var {var:.2f}")


@chex.dataclass
class SimulatedData:
    """Container for the synthetic points, their labels and cluster summaries.

    Instances are produced by ``generate_data`` in this file.
    """
    # Query points drawn by generate_xs.
    queries: Array
    # Key points drawn by generate_xs.
    keys: Array
    # Value points drawn by generate_xs (same n and d as the keys).
    values: Array
    # Output of hierarchical_kmeans on the queries.
    query_labels: Array
    # Output of hierarchical_kmeans on the keys.
    key_labels: Array
    # QueryCluster.make results, vmapped over every cluster id.
    query_clusters: QueryCluster
    # Cluster.make results, vmapped over every cluster id.
    key_clusters: Cluster

def generate_data() -> SimulatedData:
    """Build random queries, keys and values and cluster them hierarchically.

    All sizes, the seed and the k-means iteration count are read from FLAGS.
    Queries and keys are clustered independently with ``hierarchical_kmeans``;
    for each cluster id a ``QueryCluster`` / ``Cluster`` is built from the
    points and a log-membership vector (0 for members, -inf otherwise).
    """
    num_points = FLAGS.num_points
    num_dims = FLAGS.num_dims
    key_branching = FLAGS.num_clusters
    query_branching = key_branching
    num_levels = FLAGS.num_levels
    kmeans_iters = FLAGS.iters

    def log_membership(labels, cluster_id):
        # log(1) = 0 for members of the cluster, log(0) = -inf for the rest.
        return jnp.log((labels == cluster_id).astype(jnp.float32))

    # Derive independent random keys for every sampling / clustering step.
    root_key = random.PRNGKey(FLAGS.seed)
    query_side_key, key_side_key = random.split(root_key)
    query_gen_key, query_kmeans_key = random.split(query_side_key)

    # Queries: sample, cluster, then summarise each cluster.
    queries = generate_xs(num_points, num_dims, query_gen_key)
    query_labels = hierarchical_kmeans(
        xs=queries,
        key=query_kmeans_key,
        k=query_branching,
        levels=num_levels,
        iters=kmeans_iters,
    )
    query_clusters = jax.vmap(
        lambda cluster_id: QueryCluster.make(
            queries, log_membership(query_labels, cluster_id)
        )
    )(jnp.arange(query_branching**num_levels))

    # Keys and values: sample both, cluster the keys, summarise each cluster.
    key_gen_key, value_gen_key, key_kmeans_key = random.split(key_side_key, 3)
    keys = generate_xs(num_points, num_dims, key_gen_key)
    values = generate_xs(num_points, num_dims, value_gen_key)
    key_labels = hierarchical_kmeans(
        xs=keys,
        key=key_kmeans_key,
        k=key_branching,
        levels=num_levels,
        iters=kmeans_iters,
    )
    key_clusters = jax.vmap(
        lambda cluster_id: Cluster.make(
            keys, values, log_membership(key_labels, cluster_id)
        )
    )(jnp.arange(key_branching**num_levels))

    return SimulatedData(
        queries=queries,
        keys=keys,
        values=values,
        query_labels=query_labels,
        key_labels=key_labels,
        query_clusters=query_clusters,
        key_clusters=key_clusters,
    )




def main(argv):
    """Print clustering statistics and compare exact with hierarchical attention.

    Steps: echo the flag configuration, generate the simulated data, zero the
    key clusters' ``corr`` and ``var`` fields, print per-level key-cluster
    statistics, run ``exact_attention`` on the flat cluster centroids and
    ``hierarchical_attention`` on the per-level reshaped clusters, then print
    the mean square of the exact result, the mean squared difference, and the
    correlation between the two outputs.
    """
    num_points = FLAGS.num_points
    num_dims = FLAGS.num_dims
    key_branching = FLAGS.num_clusters
    num_levels = FLAGS.num_levels
    seed = FLAGS.seed
    kmeans_iters = FLAGS.iters
    query_branching = key_branching
    rule = "--------------------------------------"

    banner = (
        "Running kmeans_stats with:",
        f"  Number of points: {num_points}",
        f"  Number of dimensions: {num_dims}",
        f"  Number of clusters per level: {key_branching}",
        f"  Number of hierarchical levels: {num_levels}",
        f"  Random seed: {seed}",
        f"  Number of iterations: {kmeans_iters}",
    )
    for line in banner:
        print(line)

    data = generate_data()
    # The first proper clusters are evaluated as if they were leaves, so their
    # second-order statistics are zeroed out.
    leaf_like_keys = data.key_clusters.replace(
        corr=jnp.zeros_like(data.key_clusters.corr),
        var=jnp.zeros_like(data.key_clusters.var),
    )
    data = data.replace(key_clusters=leaf_like_keys)

    # Split the single flat cluster axis into one axis per hierarchy level.
    level_axes = num_levels * 'k'
    split_pattern = f"({level_axes})...->{level_axes}..."

    def split_levels(clusters, branching):
        return tree_map(
            lambda leaf: einshape(split_pattern, leaf, k=branching), clusters
        )

    flat_qclusters = data.query_clusters
    qclusters = split_levels(flat_qclusters, query_branching)
    flat_kclusters = data.key_clusters
    kclusters = split_levels(flat_kclusters, key_branching)

    for level in range(num_levels):
        cluster_count = key_branching ** (num_levels - level)
        print(rule)
        print(f"Level {level}: {cluster_count} clusters")
        coarse_labels = data.key_labels // (key_branching ** level)
        analyze_cluster_variances(data.keys, coarse_labels, cluster_count)
        analyze_cluster_sizes(coarse_labels, cluster_count)

    print(rule)
    print(rule)

    exact_results = exact_attention(
        queries=flat_qclusters.centroid,
        keys=flat_kclusters.centroid,
        values=flat_kclusters.value,
        logmass=flat_kclusters.logmass,
    )
    print("Exact attention output type:")
    print(tree_map(lambda leaf: leaf.shape, exact_results))

    print(rule)

    hierarchical_results = hierarchical_attention(qclusters, kclusters)
    # Bring the hierarchical output into the exact output's shape for comparison.
    h_result_value = hierarchical_results.value.reshape(exact_results.shape)
    print("Hierarchical attention output type:")
    print(tree_map(lambda leaf: leaf.shape, hierarchical_results))
    print(rule)

    result_sq = jnp.mean(jnp.square(exact_results))
    print(f"Result Sq: {result_sq:.4f}")
    error_sq = jnp.mean(jnp.square(exact_results - h_result_value))
    print(f"Error Sq: {error_sq:.4f}")
    correlation = jnp.corrcoef(exact_results.flatten(), h_result_value.flatten())[0, 1]
    print(f"Corr: {correlation:.4f}")

    print(rule)

# Script entry point: absl parses the command-line flags, then calls main.
if __name__ == '__main__':
    app.run(main)
