"""
Several variants of spectral clustering that we would like to compare.
"""
import numpy as np
import scipy.sparse.linalg
from sklearn.cluster import KMeans
import stag.graph
import stag.stag_internal
import math


#########################################
# Helper functions
#########################################
def labels_to_clusters(labels, k=None):
    """
    Convert a list of cluster labels into a list of clusters.

    :param labels: iterable in which entry i is the integer label of item i.
                   Labels are used directly as indices into the result.
    :param k: number of clusters to return. If None, it is taken to be one
              more than the largest label (0 when there are no labels).
    :return: a list of k lists; list c holds, in increasing order, the
             indices i for which labels[i] == c.
    :raises IndexError: if some label is not smaller than k.
    """
    if k is None:
        # max() raises on an empty input, so fall back to -1 there: an empty
        # labelling then yields no clusters instead of a ValueError.
        k = max(labels, default=-1) + 1

    clusters = [[] for _ in range(k)]
    for i, c in enumerate(labels):
        clusters[c].append(i)

    return clusters


def clusters_to_labels(clusters):
    """
    Convert a list of clusters into a flat list of labels.

    :param clusters: list of clusters, each a collection of item indices.
    :return: a list whose length is the total size of all clusters, in which
             entry j is the position in `clusters` of the cluster containing j.
    """
    total_items = sum(map(len, clusters))
    labels = [0] * total_items
    for label, members in enumerate(clusters):
        for item in members:
            labels[item] = label
    return labels


def kmeans(data, k):
    """
    Run the k-means algorithm with k clusters on the given data.

    :param data: the points to cluster, passed unchanged to KMeans.fit.
    :param k: the number of clusters to find.
    :return: a pair (labels, centers), where labels is a list of Python ints
             giving the cluster of each data point and centers is the
             cluster_centers_ attribute of the fitted model.
    """
    # fit() returns the fitted estimator, so construction and fitting chain.
    model = KMeans(n_clusters=k, n_init='auto').fit(data)
    labels = list(map(int, model.labels_))
    return labels, model.cluster_centers_


#############################################
# Normal Spectral Clustering
#############################################
def spectral_cluster(g: stag.graph.Graph, k: int):
    """
    Standard spectral clustering using k eigenvectors.

    Computes k eigenvectors of the normalised Laplacian of g, requesting the
    eigenvalues of smallest magnitude, and runs k-means on them.

    :param g: the graph to cluster.
    :param k: the number of clusters, and the number of eigenvectors used.
    :return: the list of cluster labels produced by k-means.
    """
    laplacian = g.normalised_laplacian()
    # eigsh returns (eigenvalues, eigenvectors); only the vectors are needed.
    embedding = scipy.sparse.linalg.eigsh(laplacian, k, which='SM')[1]
    # kmeans returns (labels, centers); the centers are discarded.
    return kmeans(embedding, k)[0]


def spectral_cluster_logk(g: stag.graph.Graph, k: int):
    """
    Normal spectral clustering with only log(k) eigenvectors.

    Computes max(1, ceil(log2(k))) eigenvectors of the normalised Laplacian
    of g, requesting the eigenvalues of smallest magnitude, and runs k-means
    with k clusters on them.

    :param g: the graph to cluster.
    :param k: the number of clusters; must be at least 1.
    :return: the list of cluster labels produced by k-means.
    """
    # ceil(log2(k)) is 0 for k == 1, and eigsh cannot be asked for zero
    # eigenvectors, so always request at least one. math.log2 is used rather
    # than math.log(k, 2) because the latter can land slightly above an
    # integer for exact powers of two, making ceil overshoot by one.
    num_eigenvectors = max(1, math.ceil(math.log2(k)))
    lap_mat = g.normalised_laplacian()
    _, eigenvectors = scipy.sparse.linalg.eigsh(lap_mat, num_eigenvectors, which='SM')
    labels, _ = kmeans(eigenvectors, k)
    return labels


########################################
# Power Method spectral clustering
########################################
def fast_spectral_cluster(g: stag.graph.Graph, k: int, t_const=None):
    """
    Spectral clustering using the power method with about log(k) vectors.

    Starts from a random Gaussian matrix with max(1, ceil(log2(k))) columns,
    multiplies it t times by the normalised signless Laplacian of g, where
    t = t_const * ceil(log2(n / k)) and n is the number of vertices, and
    runs k-means with k clusters on the result.

    :param g: the graph to cluster.
    :param k: the number of clusters; must be at least 1.
    :param t_const: multiplier for the number of power iterations. Defaults
                    to 2 when None. The product is passed to range(), so it
                    needs to be an integer.
    :return: the list of cluster labels produced by k-means.
    """
    if t_const is None:
        t_const = 2
    n = g.number_of_vertices()

    # ceil(log2(k)) never exceeds k for k >= 1, but it is 0 for k == 1, which
    # would leave k-means with a zero-column matrix; use at least one vector.
    num_vectors = max(1, math.ceil(math.log2(k)))
    num_iterations = t_const * math.ceil(math.log2(n / k))

    M = g.normalised_signless_laplacian()
    Y = np.random.normal(size=(n, num_vectors))
    for _ in range(num_iterations):
        Y = M @ Y
    labels, _ = kmeans(Y, k)
    return labels


def spectral_cluster_pm_k(g: stag.graph.Graph, k: int, t_const=None):
    """
    Spectral clustering using the power method with k vectors.

    Starts from a random Gaussian matrix with k columns, multiplies it
    t_const * ceil(log2(n)) times by the normalised signless Laplacian of g
    (n being the number of vertices), orthogonalises the result with an SVD,
    and runs k-means with k clusters on the left singular vectors.

    :param g: the graph to cluster.
    :param k: the number of clusters, and the number of random vectors used.
    :param t_const: multiplier for the number of power iterations. Defaults
                    to 2 when None.
    :return: the list of cluster labels produced by k-means.
    """
    multiplier = 2 if t_const is None else t_const
    n = g.number_of_vertices()
    num_iterations = multiplier * math.ceil(math.log(n, 2))
    signless_laplacian = g.normalised_signless_laplacian()

    block = np.random.normal(size=(n, k))
    for _ in range(num_iterations):
        block = signless_laplacian @ block

    # Orthogonalise: keep only the left singular vectors of the iterate.
    left_vectors = np.linalg.svd(block, full_matrices=False)[0]

    # kmeans returns (labels, centers); the centers are discarded.
    return kmeans(left_vectors, k)[0]
