import jax
import jax.numpy as jnp
from jax import random
import scipy
import seaborn as sns
import matplotlib.pyplot as plt
from typing import Tuple

# handling jax's random generator
# Module-level PRNG state. Each get_key() call advances it, so callers can
# obtain a fresh key without threading keys through their own signatures.
_random_key = random.key(0)

def get_key():
    """Return a new PRNG subkey and advance the module-level key state."""
    global _random_key
    next_state, subkey = random.split(_random_key)
    _random_key = next_state
    return subkey

# Problem sizes; num_points is 8192 x 50.
num_points = 8192 * 50  # == 409600
num_clusters = 8
dimensions = 512


def run_kmeans_and_compare(u, w, maxiter=10_000, normalize=False):
    '''
    Cluster ``u`` with k-means and compare the fitted centroids to ``w``.

    Shapes are read from the inputs: ``u`` is ``(num_points, dim)`` and ``w``
    is ``(num_clusters, dim)``; one centroid is fitted per row of ``w``.
    Initial centroids are ``num_clusters`` distinct rows of ``u`` drawn with
    the module-level PRNG (``get_key``).

    Args:
        u: points to cluster, shape ``(num_points, dim)``.
        w: reference vectors, shape ``(num_clusters, dim)``.
        maxiter: maximum number of k-means iterations.
        normalize: how centroids are compared with ``w``:
            ``False`` -- raw inner products (magnitude matters);
            ``True``  -- inner products of L2-normalized vectors (cosine);
            ``'l2'``  -- Euclidean distance.

    Returns:
        A ``(num_clusters, num_clusters)`` comparison matrix. Centroids are
        first matched one-to-one to rows of ``w`` (Hungarian algorithm,
        maximizing similarity or minimizing distance). Rows and columns are
        then permuted together so the diagonal (the matched pairs) runs from
        best to worst match: decreasing for similarities, increasing for
        ``'l2'`` distances.

    Raises:
        ValueError: if ``u`` and ``w`` differ in feature dimension, or if
            ``normalize`` is not ``True``, ``False`` or ``'l2'``.
    '''
    # Explicit submodule import: a bare `import scipy` does not expose
    # scipy.optimize on older SciPy releases.
    from scipy.optimize import linear_sum_assignment

    num_points, dim = u.shape
    num_clusters, dim_w = w.shape
    if dim != dim_w:
        raise ValueError(f'Dimension mismatch: u has {dim} features, w has {dim_w}')
    # Validate before the (potentially long) clustering loop runs.
    if not (isinstance(normalize, bool) or normalize == 'l2'):
        raise ValueError('Invalid value for normalize!')

    # initialize centroids with distinct data points
    idxs = random.choice(get_key(), num_points, (num_clusters,), replace=False)
    centroids = u[idxs]
    # Placeholder "previous assignments"; the positive initial change count
    # below guarantees at least one iteration regardless of these values.
    old_clusters = jnp.zeros((num_points,), dtype=int)

    # No @jax.jit here: lax.while_loop traces these functions itself, and
    # jitting a closure over `u` would embed the whole dataset as a
    # compile-time constant.
    def while_step(carry):
        step, centroids, assignments, _ = carry
        centroids, new_assignments = kmeans_step(u, centroids)
        changed = jnp.count_nonzero(assignments != new_assignments)
        return step + 1, centroids, new_assignments, changed

    def cond(carry):
        step, _, _, changed = carry
        return (step < maxiter) & (changed > 0)

    n_iter, centroids, _, changed = jax.lax.while_loop(
        cond,
        while_step,
        (0, centroids, old_clusters, 1000)
    )

    print()
    if changed == 0:
        print(f'Converged in {n_iter} iterations!')
    else:
        print(f'Did not converge after {maxiter} iterations!')
    print()

    if isinstance(normalize, bool):
        # don't always normalize centroids, since magnitude matters
        if normalize:
            centroids = centroids / jnp.linalg.vector_norm(centroids, axis=-1, keepdims=True)
            w = w / jnp.linalg.vector_norm(w, axis=-1, keepdims=True)
        centroid_distances = jnp.einsum('pd,dc->pc', centroids, w.T)
        higher_is_better = True
    else:
        # normalize == 'l2' (validated above): Euclidean distance
        centroid_distances = jnp.linalg.norm(centroids[:, None, :] - w[None, :, :], axis=-1)
        higher_is_better = False

    # One-to-one matching. After reordering, row j corresponds to w[j].
    # Similarities are maximized; distances must be minimized.
    _, new_idxs = linear_sum_assignment(centroid_distances.T, maximize=higher_is_better)
    centroid_distances = centroid_distances[new_idxs]

    # Permute rows and columns together so the diagonal goes best -> worst.
    sort_idxs = jnp.argsort(jnp.diag(centroid_distances), descending=higher_is_better)
    centroid_distances = centroid_distances[sort_idxs, :][:, sort_idxs]

    return centroid_distances


@jax.jit
def cluster_update(cluster, assignments, u):
    """Return the mean of the rows of ``u`` whose assignment equals ``cluster``.

    Args:
        cluster: cluster index to average.
        assignments: per-row cluster index, one entry per row of ``u``.
        u: 2-D array of points (rows are points).

    Returns:
        The mean row of the selected points. If no row is assigned to
        ``cluster``, the division is 0 / 0 and the result is NaN.
    """
    members = (assignments == cluster).astype(jnp.int32)
    member_total = (u * members[:, None]).sum(axis=0)
    member_count = members.sum()
    return member_total / member_count


@jax.jit
def kmeans_step(u, centroids):
    """Run one k-means iteration using inner-product similarity.

    Each point is assigned to the centroid with the largest inner product
    (``u @ centroids.T``), not the smallest Euclidean distance. Each centroid
    is then replaced by the mean of its assigned points. A centroid that
    receives no points keeps its previous value, instead of becoming NaN
    from a 0 / 0 division.

    Args:
        u: points, shape ``(num_points, dim)``.
        centroids: current centroids, shape ``(num_clusters, dim)``.

    Returns:
        ``(new_centroids, assignments)``. ``new_centroids`` has the shape of
        ``centroids``. ``assignments`` gives, for each point, the index of the
        chosen *input* centroid.
    """
    num_clusters = centroids.shape[0]

    distance = jnp.einsum('pd,dc->pc', u, centroids.T)
    assignments = jnp.argmax(distance, axis=1)

    # Per-cluster sums and sizes in a single pass over `u`, rather than one
    # masked copy of `u` per cluster.
    sums = jax.ops.segment_sum(u, assignments, num_segments=num_clusters)
    counts = jnp.bincount(assignments, length=num_clusters)

    # An empty cluster would yield 0 / 0 = NaN, and NaN scores then corrupt
    # the next argmax. Keep the previous centroid for empty clusters.
    means = sums / jnp.maximum(counts, 1)[:, None]
    new_centroids = jnp.where((counts > 0)[:, None], means, centroids)

    return new_centroids, assignments


def plot_distances(distances, ax=None, title='', x_labels='embeddings', y_labels='centroids', **subplot_kwargs):
    """Draw ``distances`` as an annotated heatmap and return the Axes.

    Args:
        distances: matrix passed to ``sns.heatmap``; each cell is annotated
            with two decimals.
        ax: Axes to draw on. When None, a new one is created with
            ``plt.subplots(**subplot_kwargs)``.
        title: Axes title.
        x_labels: x-axis label.
        y_labels: y-axis label.
        **subplot_kwargs: forwarded to ``plt.subplots``; only used when
            ``ax`` is None.

    Returns:
        The Axes returned by ``sns.heatmap``.
    """
    target_ax = ax if ax is not None else plt.subplots(**subplot_kwargs)[1]
    heatmap_ax = sns.heatmap(distances, annot=True, linewidth=0.5, ax=target_ax, fmt='.2f')
    heatmap_ax.set_xlabel(x_labels)
    heatmap_ax.set_ylabel(y_labels)
    heatmap_ax.set_title(title)
    return heatmap_ax
