import jax
from jax._src import config
# Pin JAX to the CPU platform through the public `jax.config` handle rather
# than the private `jax._src` module, whose contents may change between
# releases. Presumably kept ahead of the remaining imports so the setting is
# in place before any of them can initialise a backend.
jax.config.update("jax_platforms", "cpu")
from jax import numpy as jnp
import numpy as np
import chex
from chex import Array
from typing import Optional, Tuple, Callable
from functools import partial
from jax import random
from jax.random import PRNGKey
from absl import app
from absl import flags
from jax.tree_util import tree_map
from einshape import jax_einshape as einshape

from kmeans import hierarchical_kmeans
from cluster import Cluster, QueryCluster, BasicCluster
from attention import exact_attention, hierarchical_attention

# Command-line interface: every option is an integer flag with a one-letter
# alias; the defaults below are used when the flag is not given.
FLAGS = flags.FLAGS
flags.DEFINE_integer(
    name='num_points', default=4096, help='Number of data points', short_name='n')
flags.DEFINE_integer(
    name='num_dims', default=16, help='Number of dimensions', short_name='d')
flags.DEFINE_integer(
    name='num_clusters', default=16, help='Number of clusters per level', short_name='k')
flags.DEFINE_integer(
    name='num_levels', default=2, help='Number of hierarchical levels', short_name='l')
flags.DEFINE_integer(
    name='seed', default=42, help='Random seed for reproducibility', short_name='s')
flags.DEFINE_integer(
    name='iters', default=20, help='Number of iterations for k-means', short_name='i')

def generate_xs(n: int, d: int, key: PRNGKey) -> Array:
    """Draw an ``(n, d)`` array of standard-normal samples using ``key``."""
    shape = (n, d)
    return random.normal(key, shape=shape)

def cluster_size(labels: Array, cluster_label: int) -> int:
    """Count the entries of ``labels`` that equal ``cluster_label``.

    The count is whatever ``jnp.sum`` yields for the boolean membership mask,
    i.e. a JAX scalar rather than a plain Python ``int``.
    """
    is_member = labels == cluster_label
    return jnp.sum(is_member)

def analyze_cluster_sizes(labels: Array, total_clusters: int) -> Array:
    """Print summary statistics of the cluster populations.

    Counts the members of every cluster id in ``range(total_clusters)`` and
    prints the 20th/50th/80th percentiles, the smallest and largest size, and
    the ratio of the largest size to the mean size. Nothing is returned; the
    statistics are only written to stdout.
    """
    count_members = jax.vmap(partial(cluster_size, labels))
    cluster_sizes = count_members(jnp.arange(total_clusters))

    q20, q50, q80 = (jnp.quantile(cluster_sizes, frac) for frac in (0.2, 0.5, 0.8))
    print(f"Cluster sizes 20:50:80: {q20:.0f} - {q50:.0f} - {q80:.0f}")

    smallest = jnp.min(cluster_sizes)
    largest = jnp.max(cluster_sizes)
    average = jnp.mean(cluster_sizes)
    print(f"Cluster sizes min: {smallest:.0f} max: {largest:.0f}")
    print(f"Oversize ratio: {largest / average:.2f}")
    # A hard cap on the oversize ratio is currently disabled:
    #assert max_size <= 4.0 * mean_size, f"Max size {max_size:.0f} is too large compared to mean size {mean_size:.0f}"

def cluster_sqmean_var(xs: Array, labels: Array, cluster_label: int) -> Tuple[Array, Array]:
    """Compute the squared centroid norm and intra-cluster variance of one cluster.

    Args:
        xs: Data points; rows are selected through ``labels``.
        labels: Cluster id of each row of ``xs``.
        cluster_label: Id of the cluster to summarise.

    Returns:
        A pair ``(sqmean, var)``: ``sqmean`` is the squared Euclidean norm of
        the cluster mean, ``var`` is the mean squared norm of the members
        minus ``sqmean``. Both are 0 when the cluster has no members.
    """
    active_indices = labels == cluster_label
    mean = xs.mean(axis=0, where=active_indices[..., None])
    meansq = jnp.square(xs).sum(axis=-1).mean(where=active_indices)
    sqmean = jnp.square(mean).sum()
    var = meansq - sqmean
    # With no members the masked means above are 0/0 = NaN. Callers weight
    # these statistics by cluster size, and NaN * 0 is still NaN, so a single
    # empty cluster would poison every aggregate. Report 0 instead.
    has_members = jnp.any(active_indices)
    sqmean = jnp.where(has_members, sqmean, 0.0)
    var = jnp.where(has_members, var, 0.0)
    return sqmean, var

def analyze_cluster_variances(xs: Array, labels: Array, total_clusters: int) -> Array:
    """Print size-weighted averages of the per-cluster statistics.

    For every cluster id in ``range(total_clusters)`` this gathers the squared
    centroid norm and intra-cluster variance (via ``cluster_sqmean_var``) and
    the member count (via ``cluster_size``), weights the statistics by member
    count, and prints the mean cluster size together with the two weighted
    averages. Nothing is returned.
    """
    cluster_ids = jnp.arange(total_clusters)
    per_cluster_stats = jax.vmap(partial(cluster_sqmean_var, xs, labels))
    per_cluster_count = jax.vmap(partial(cluster_size, labels))

    sqmean, var = per_cluster_stats(cluster_ids)
    sizes = per_cluster_count(cluster_ids)

    mean_size = jnp.mean(sizes)
    # mean(stat * size) / mean(size) is the size-weighted average of stat.
    weighted_sqmean = jnp.mean(sqmean * sizes) / mean_size
    weighted_var = jnp.mean(var * sizes) / mean_size
    print(f"Cluster size {mean_size:.0f} sqmean {weighted_sqmean:.2f} var {weighted_var:.2f}")


@chex.dataclass
class SimulatedData:
    """Attention inputs bundled with their cluster assignments.

    Instances are built by ``generate_data``; the field comments describe what
    that function stores in each slot.
    """
    queries: Array  # query vectors after subtracting their mean
    keys: Array  # key vectors after subtracting their mean
    values: Array  # value vectors, stored as loaded
    query_labels: Array  # hierarchical_kmeans output for the queries
    key_labels: Array  # hierarchical_kmeans output for the keys
    query_clusters: QueryCluster  # QueryCluster.make_leaves vmapped over the leaf labels
    # BasicCluster.make_leaves vmapped over the leaf labels; annotated as
    # Cluster, so presumably BasicCluster is a Cluster subtype - TODO confirm.
    key_clusters: Cluster

def generate_data() -> SimulatedData:
    """Load attention tensors from disk, centre and cluster them, and bundle the result.

    Sizes, seed and iteration count come from ``FLAGS``, so absl flags must
    already be parsed. From entry 0 of each array in ``data/data.npz`` the
    first ``num_points`` keys and values and the last ``num_points`` queries
    are taken, each truncated to ``num_dims`` features. Queries and keys are
    mean-centred and clustered with ``hierarchical_kmeans``; one leaf cluster
    object is then built per leaf label.

    Returns:
        A ``SimulatedData`` holding the centred queries and keys, the values,
        both label arrays, and the stacked query/key leaf clusters.

    Raises:
        AssertionError: if the sliced keys do not have shape
            ``(num_points, num_dims)``.
    """
    n = FLAGS.num_points
    d = FLAGS.num_dims
    k = FLAGS.num_clusters
    q = k  # queries use the same branching factor as keys
    l = FLAGS.num_levels
    seed = FLAGS.seed
    iters = FLAGS.iters

    # np.load on an .npz keeps the archive open until it is closed; convert the
    # slices we need inside the `with` block so the file handle is released.
    with np.load("data/data.npz") as data:
        realks = jnp.array(data["keys"][0, :n, :d])
        realvs = jnp.array(data["values"][0, :n, :d])
        realqs = jnp.array(data["queries"][0, -n:, :d])
    assert realks.shape == (n, d), f"Keys shape mismatch: {realks.shape} != ({n}, {d})"

    # The key-splitting structure is kept as it was (including the sub-keys
    # that are no longer consumed) so the k-means keys, and therefore the
    # clustering results, stay the same for a given seed.
    key = random.PRNGKey(seed)
    qs_key, ks_key = random.split(key)
    _, qlabkey = random.split(qs_key)
    _, _, klabkey = random.split(ks_key, 3)

    # `size` handed to make_leaves: four times the average leaf population.
    leaf_size = 4 * n // (q ** l)

    # Centre the queries and cluster them.
    qmean = jnp.mean(realqs, axis=0, keepdims=False)
    qs = realqs - qmean
    qlabels = hierarchical_kmeans(xs=qs, key=qlabkey, k=q, levels=l, iters=iters)

    @jax.vmap
    def make_qclusters(i):
        return QueryCluster.make_leaves(qs, qlabels, label=i, size=leaf_size)
    qclusters = make_qclusters(jnp.arange(q**l))

    # Centre the keys and cluster them; values are used as loaded.
    ks = realks - jnp.mean(realks, axis=0, keepdims=True)
    vs = realvs
    klabels = hierarchical_kmeans(xs=ks, key=klabkey, k=k, levels=l, iters=iters)

    @jax.vmap
    def make_kclusters(i):
        # The query mean is forwarded because the queries were centred above.
        return BasicCluster.make_leaves(ks, vs, klabels, label=i, size=leaf_size, qmean=qmean)
    kclusters = make_kclusters(jnp.arange(k**l))

    return SimulatedData(
        queries=qs,
        keys=ks,
        values=vs,
        query_labels=qlabels,
        key_labels=klabels,
        query_clusters=qclusters,
        key_clusters=kclusters
    )




def main(argv):
    """Run the clustering statistics and the exact-vs-hierarchical attention comparison.

    Prints the flag values, builds the clustered data via ``generate_data``,
    reports key-cluster variance and size statistics for each hierarchy level,
    then evaluates ``exact_attention`` and ``hierarchical_attention`` and
    prints the mean squared result, the mean squared error and the correlation
    between the two outputs. ``argv`` is accepted for ``app.run`` and unused.
    """
    n = FLAGS.num_points
    d = FLAGS.num_dims
    k = FLAGS.num_clusters
    l = FLAGS.num_levels
    seed = FLAGS.seed
    iters = FLAGS.iters
    q = k
    divider = "--------------------------------------"

    print(f"Running kmeans_stats with:")
    print(f"  Number of points: {n}")
    print(f"  Number of dimensions: {d}")
    print(f"  Number of clusters per level: {k}")
    print(f"  Number of hierarchical levels: {l}")
    print(f"  Random seed: {seed}")
    print(f"  Number of iterations: {iters}")

    data = generate_data()
    key_sq_norms = jnp.sum(data.keys ** 2, axis=-1)
    print(jnp.mean(key_sq_norms))

    # One 'k' axis per hierarchy level. The first pattern merges the leading
    # cluster axis with the following member axis; the second splits the
    # leading cluster axis into one axis per level.
    level_axes = l * 'k'
    merge_pattern = f"({level_axes})m...->({level_axes}m)..."
    split_pattern = f"({level_axes})...->{level_axes}..."

    def merge_members(tree, branching):
        return tree_map(lambda leaf: einshape(merge_pattern, leaf, k=branching), tree)

    def split_levels(tree, branching):
        return tree_map(lambda leaf: einshape(split_pattern, leaf, k=branching), tree)

    flat_qclusters = merge_members(data.query_clusters, q)
    qclusters = split_levels(data.query_clusters, q)
    flat_kclusters = merge_members(data.key_clusters, k)
    kclusters = split_levels(data.key_clusters, k)

    for level in range(l):
        clusters_at_level = k ** (l - level)
        print(divider)
        print(f"Level {level}: {clusters_at_level} clusters")
        # Integer division by k**level maps leaf labels to their ancestor's label.
        level_labels = data.key_labels // (k ** level)
        analyze_cluster_variances(data.keys, level_labels, clusters_at_level)
        analyze_cluster_sizes(level_labels, clusters_at_level)

    print(divider)
    # Query slots whose log-mass is -inf are excluded from the comparison.
    valid_q_indices = flat_qclusters.logmass > -jnp.inf
    print(divider)

    # The query mean is not added back to the centroids here.
    exact_results = exact_attention(
        queries=flat_qclusters.centroid,
        keys=flat_kclusters.centroid,
        values=flat_kclusters.value,
        logmass=flat_kclusters.logmass,
    )[valid_q_indices]
    print("Exact attention output type:")
    print(tree_map(lambda leaf: leaf.shape, exact_results))

    print(divider)

    hierarchical_results = hierarchical_attention(qclusters, kclusters)
    picked = hierarchical_results.value[valid_q_indices]
    h_result_value = picked.reshape(exact_results.shape)
    print("Hierarchical attention output type:")
    print(tree_map(lambda leaf: leaf.shape, hierarchical_results))
    print(divider)

    residual = exact_results - h_result_value
    correlation = jnp.corrcoef(exact_results.flatten(), h_result_value.flatten())[0, 1]
    print(f"Result Sq: {jnp.mean(jnp.square(exact_results)):.4f}")
    print(f"Error Sq: {jnp.mean(jnp.square(residual)):.4f}")
    print(f"Corr: {correlation:.4f}")

    print(divider)

# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main().
if __name__ == '__main__':
    app.run(main)
