import jax
from jax import numpy as jnp
from chex import Array
from typing import Callable, Optional, Tuple
from functools import partial
from jax import random
from jax.random import PRNGKey
from jax.scipy.special import logsumexp
from einshape import jax_einshape as einshape

@partial(jax.jit, static_argnames=("k", "iters"))
def subsample_2kmeans(xs: Array, key:PRNGKey, k: Tuple[int]=(4, 4), iters: int=10, logmass: Optional[Array]=None, metric: Optional[Array]=None) -> Tuple[Array, Array]:
    """
    One- or two-level subsampled k-means.

    With a single entry in `k` this is exactly `subsample_kmeans`. With two
    entries (K0, K1) the data is first split into K0 parent clusters, then each
    parent cluster is split again into K1 child clusters.

    Parameters:
    - xs: Input data points (shape: [n_samples, n_features]).
    - key: Random key; independent sub-keys are derived for every clustering call.
    - k: Tuple with one or two cluster counts, (K0,) or (K0, K1). Static under jit.
    - iters: Number of k-means iterations for every clustering call. Static under jit.
    - logmass: Optional log weights for the data points (defaults to zeros).
    - metric: Optional matrix forwarded unchanged to `subsample_kmeans`.

    Returns:
    - centroids: For (K0, K1), shape [K0 * K1, n_features]; row p * K1 + c is
      child c of parent p.
    - labels: For (K0, K1), labels[i] == p * K1 + c for every data point.
    """
    if len(k) == 1:
        # Single level: defer entirely to subsample_kmeans with the caller's key.
        return subsample_kmeans(xs, key, k=k[0], iters=iters, logmass=logmass, metric=metric)
    assert len(k) == 2, "k must be a tuple of two integers"
    K0, K1 = k
    # A JAX key must only be consumed once: give level 0 its own key and derive
    # one key per parent cluster from the other half.
    level0_key, level1_key = random.split(key)
    # level 0
    parent_centroids, labels = subsample_kmeans(xs, level0_key, k=K0, iters=iters, logmass=logmass, metric=metric)
    # level 1
    centroids = jnp.zeros((K0, K1, xs.shape[1],), dtype=parent_centroids.dtype)
    # Make room for K1 child labels below every parent label.
    labels = labels * K1
    if logmass is None:
        logmass = jnp.zeros(xs.shape[0], dtype=xs.dtype)
    def do_cluster_k(parent, val):
        centroids, labels = val
        base_index = parent * K1
        # Points of this parent still carry the bare parent label; parents that
        # were already processed only hold labels below this base index.
        active = labels == base_index
        # Points outside the parent cluster get zero mass instead of being
        # removed, so array shapes stay static under jit.
        local_logmass = jnp.where(active, logmass, -jnp.inf)
        sub_key = random.fold_in(level1_key, parent)
        sub_centroids, sub_labels = subsample_kmeans(xs, sub_key, k=K1, iters=iters, logmass=local_logmass, metric=metric)
        # A parent without any mass yields an all -inf logmass whose softmax is
        # NaN; use the parent centroid for all of its children in that case.
        has_mass = jnp.any(local_logmass > -jnp.inf)
        sub_centroids = jnp.where(has_mass, sub_centroids, parent_centroids[parent][None, :])
        centroids = centroids.at[parent].set(sub_centroids)
        labels = jnp.where(active, sub_labels + base_index, labels)
        return centroids, labels
    centroids, labels = jax.lax.fori_loop(0, K0, do_cluster_k, (centroids, labels))
    # Flatten (parent, child, feature) into (parent * K1 + child, feature).
    centroids = einshape("abd->(ab)d", centroids, a=K0, b=K1)
    return centroids, labels



@partial(jax.jit, static_argnames=("k", "iters"))
def subsample_kmeans(xs: Array, key: PRNGKey, k: int=4, iters: int=10, logmass: Optional[Array]=None, metric: Optional[Array]=None) -> Tuple[Array, Array]:
    """
    K-means fitted on a random subsample, with labels for every input point.

    Parameters:
    - xs: Input data points (shape: [n_samples, n_features]).
    - key: Random key; split into one key for subsampling and one for clustering.
    - k: Number of clusters. Static under jit.
    - iters: Number of k-means iterations. Static under jit.
    - logmass: Optional log weights for the data points.
    - metric: Optional matrix; when given, distances are the absolute value of
      sum((x - c) * (x @ metric - c @ metric)) instead of the Euclidean norm.

    Returns:
    - centroids: Centroids fitted by `kmeans` on the subsample.
    - labels: Index of the closest centroid for every row of xs.
    """
    n_points = xs.shape[0]
    n_features = xs.shape[1]
    # The subsample holds k * n_features points, capped at the dataset size.
    sample_count = min(n_points, k * n_features)
    pick_key, fit_key = random.split(key)

    # Draw the subsample without replacement, proportionally to the point mass.
    if logmass is None:
        pick_probs = jnp.ones(n_points) / n_points
    else:
        pick_probs = jax.nn.softmax(logmass)
    picked = random.choice(pick_key, jnp.arange(n_points), shape=(sample_count,), replace=False, p=pick_probs)
    # TODO adjust importance weights after subsampling
    picked_xs = xs[picked]

    if logmass is None:
        picked_logmass = None
    else:
        # Sum the mass of every point that was not drawn and spread it evenly
        # over the drawn points.
        not_picked = jnp.ones(n_points, dtype=jnp.bool_).at[picked].set(False)
        leftover_logmass = jnp.where(not_picked.any(), logsumexp(logmass, where=not_picked), -jnp.inf)
        picked_logmass = jnp.logaddexp(logmass[picked], leftover_logmass - jnp.log(sample_count))

    centroids, _ = kmeans(picked_xs, fit_key, k=k, iters=iters, logmass=picked_logmass, metric=metric)

    # Label every original point (not only the subsample) by its closest centroid.
    offsets = xs[:, None] - centroids[None, :]
    if metric is None:
        distances = jnp.linalg.norm(offsets, axis=-1)
    else:
        xs_mapped = xs @ metric
        centroids_mapped = centroids @ metric
        quadratic = jnp.sum(offsets * (xs_mapped[:, None] - centroids_mapped[None, :]), axis=-1)
        distances = jnp.abs(quadratic)
    return centroids, jnp.argmin(distances, axis=1)

@partial(jax.jit, static_argnames=("k", "iters"))
def kmeans(xs: Array, key: PRNGKey, k: int=4, iters: int=10, logmass: Optional[Array]=None, metric: Optional[Array]=None) -> Tuple[Array, Array]:
    """
    K-means clustering algorithm.
    
    Parameters:
    - xs: Input data points (shape: [n_samples, n_features]).
    - key: Random key for initialization.
    - k: Number of clusters. Static under jit.
    - iters: Number of iterations for the algorithm (0 returns the initial
      centroids). Static under jit.
    - logmass: log weights for the data points (defaults to zeros, i.e. uniform).
    - metric: Optional matrix (expected shape: [n_features, n_features]). When
      given, the distance between x and c is the absolute value of
      sum((x - c) * (x @ metric - c @ metric)) instead of the Euclidean norm.
    
    Returns:
    - centroids: Final cluster centroids (shape: [k, n_features]).
    - labels: Index of the closest final centroid for each data point.
    """

    # Ensure input is a 2D array
    assert xs.ndim == 2, "Input data must be a 2D array."

    # Set default logmass to zeros if not provided
    if logmass is None:
        logmass = jnp.zeros(xs.shape[0])
    # Ensure logmass shape matches number of samples
    assert logmass.ndim == 1, "logmass must be a 1D array."
    assert logmass.shape[0] == xs.shape[0], "logmass must have the same number of samples as xs."

    weights = jax.nn.softmax(logmass)

    # Initialize centroids with k distinct data points, drawn proportionally to
    # their weight
    centroids = random.choice(key, xs, shape=(k,), replace=False, p=weights)

    # The transformed data does not change between iterations: compute it once.
    xs_metric = xs @ metric if metric is not None else None

    def nearest(current_centroids):
        # Index of the closest centroid for every data point.
        if metric is None:
            distances = jnp.linalg.norm(xs[:, None] - current_centroids[None, :], axis=-1)
        else:
            centroids_metric = current_centroids @ metric
            distances = jnp.sum(
                (xs[:, None] - current_centroids[None, :]) * (xs_metric[:, None] - centroids_metric[None, :]), axis=-1)
            distances = jnp.abs(distances)  # Ensure distances are non-negative
        return jnp.argmin(distances, axis=1)

    def update(current_centroids, labels):
        # Weighted mean of the members of every cluster.
        def get_centroid(l):
            members = labels == l
            total_xs = jnp.sum(xs * weights[:, None], axis=0, where=members[..., None])
            total_weight = jnp.sum(weights, where=members)
            # A cluster without any mass keeps its previous centroid.
            return jnp.where(total_weight == 0., current_centroids[l], total_xs / total_weight)
        return jax.vmap(get_centroid)(jnp.arange(k))

    for _ in range(iters):
        centroids = update(centroids, nearest(centroids))

    # Assign once more against the final centroids: labels taken inside the loop
    # belong to the centroids *before* the last update, and would not exist at
    # all for iters == 0.
    labels = nearest(centroids)
    return centroids, labels

def hierarchical_kmeans(xs: Array, key: PRNGKey, k: int=4, levels: int=2, iters: int=10) -> Array:
    """
    Hierarchical K-means clustering algorithm.

    The data is clustered into k clusters, then every cluster is clustered
    again into k sub-clusters, and so on for `levels` levels. This function
    branches on concrete array values, so it cannot itself be traced by jit.
    
    Parameters:
    - xs: Input data points (shape: [n_samples, n_features]).
    - key: Random key for initialization.
    - k: Number of clusters at each level.
    - levels: Number of hierarchical levels.
    - iters: Number of iterations for the algorithm.
    
    Returns:
    - labels: Cluster labels for each datapoint. The first K labels are subclusters of the first cluster of the previous level.
    """

    # Ensure input is a 2D array
    assert xs.ndim == 2, "Input data must be a 2D array."
    assert levels > 0, "Number of levels must be positive."

    # Initial level. Split first: a JAX key must not be both consumed by a
    # sampler and split again later.
    key, subkey = random.split(key)
    _, labels = kmeans(xs, subkey, k=k, iters=iters)

    def cluster_size(labels, label):
        return jnp.sum(labels == label)

    def sufficient_size(labels, cluster_count):
        # Largest cluster size rounded up to a power of two (keeps the number
        # of distinct chunk shapes passed to the jitted kmeans small).
        all_sizes = jax.vmap(partial(cluster_size, labels))(jnp.arange(cluster_count))
        return jnp.max(2**jnp.ceil(jnp.log2(all_sizes))).astype(jnp.int32)

    for level in range(1, levels):
        starting_clusters = k ** level
        chunk_size = int(sufficient_size(labels, starting_clusters))
        # multiply base labels by k to make room for k sub-labels per cluster
        labels = labels * k
        for meta_cluster in range(starting_clusters):
            base_index = meta_cluster * k
            # Clusters processed earlier in this loop only hold labels below
            # base_index, so this mask selects exactly the current meta-cluster.
            members = labels == base_index
            meta_point_count = int(jnp.sum(members))
            if meta_point_count == 0:
                continue
            elif meta_point_count <= k:
                # Not enough points to cluster: every point is its own sub-cluster.
                sub_labels = jnp.arange(meta_point_count)
            else:
                # Stable descending argsort of the mask: member indices first
                # (in their original order), then padding up to chunk_size.
                selected = jnp.argsort(members, descending=True)[:chunk_size]
                in_cluster = members[selected]
                meta_cluster_points = xs[selected]
                # log(1) = 0 for members, log(0) = -inf for padding points.
                logmass = jnp.log(in_cluster.astype(jnp.float32))
                # Get new key for this meta-cluster
                key, subkey = random.split(key)
                # Perform kmeans on the points in the meta-cluster
                _, sub_labels = kmeans(meta_cluster_points, subkey, k=k, iters=iters, logmass=logmass)
                # Drop the labels of the padding points.
                sub_labels = sub_labels[in_cluster]
            # Assign new labels to the points in the meta-cluster
            labels = labels.at[members].add(sub_labels)
    return labels

