from .dynamic_csc import DynamicCoreset
from scipy import sparse
import math
from coreset_sc import gen_sbm
import numpy
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.metrics.cluster import adjusted_rand_score

def signless_laplacian_and_D_sparse(A, D=None):
    """Build the operator ``M = I - N / 2`` and the diagonal degree matrix.

    ``N = D^(-1/2) (D - A) D^(-1/2)`` is the normalised Laplacian of ``A``, so
    for nodes with positive degree ``M`` equals
    ``(I + D^(-1/2) A D^(-1/2)) / 2``.

    Args:
        A: Square sparse adjacency matrix (``scipy.sparse`` matrix or array).
        D: Optional 1-D sequence of node degrees. When omitted, the row sums
            of ``A`` are used.

    Returns:
        A tuple ``(M, D)`` where ``M`` is the sparse ``n x n`` operator
        described above and ``D`` is the sparse diagonal degree matrix.
    """
    n = A.shape[0]
    if D is None:
        # ``sum`` on a sparse *matrix* yields an (n, 1) ``numpy.matrix``;
        # flatten it so ``sparse.diags`` always receives one 1-D diagonal.
        degrees = numpy.asarray(A.sum(axis=1)).ravel()
    else:
        degrees = numpy.asarray(D).ravel()
    D = sparse.diags(degrees)

    # Nodes with non-positive degree get 0 instead of inf/nan in D^(-1/2), so
    # a single isolated node cannot poison the whole operator with NaNs.
    inv_sqrt_degrees = numpy.zeros(degrees.shape, dtype=float)
    positive = degrees > 0
    inv_sqrt_degrees[positive] = 1.0 / numpy.sqrt(degrees[positive])
    D_inv_half = sparse.diags(inv_sqrt_degrees)

    L = D - A
    N = D_inv_half @ L @ D_inv_half
    M = sparse.eye(n) - 0.5 * N
    return M, D



def fast_spectral_cluster(M, D, k: int, kmeans_alg=None):
    """Cluster nodes using a power-method spectral embedding followed by k-means.

    ``M`` is expected to be the operator returned by
    ``signless_laplacian_and_D_sparse``, i.e. ``I - N / 2`` with ``N`` the
    normalised Laplacian (equivalently ``(I + D^(-1/2) A D^(-1/2)) / 2``).

    Args:
        M: Square ``n x n`` operator supporting ``M @ Y`` and ``M.shape``.
        D: Degree matrix; ``D @ ones`` must give the per-node degrees.
        k: Number of clusters.
        kmeans_alg: Optional object exposing ``fit(Y)`` and ``labels_``. When
            omitted, ``KMeans(n_clusters=k, n_init="auto")`` is used.

    Returns:
        The ``labels_`` attribute of the fitted k-means object.
    """
    n = M.shape[0]
    # Use O(log k) random embedding dimensions, but never fewer than one:
    # for k == 1 the logarithm is 0 and k-means cannot fit zero features.
    num_dims = max(1, min(k, math.ceil(math.log(k, 2))))
    num_iters = 5 * math.ceil(math.log(n / k, 2))
    Y = numpy.random.normal(size=(n, num_dims))

    # sqrt(degree) is the eigenvector of M with eigenvalue 1. It carries no
    # cluster information, so the power method is kept orthogonal to it.
    top_eigvec = numpy.sqrt(D @ numpy.full((n,), 1))
    norm = numpy.linalg.norm(top_eigvec)
    if norm > 0:
        top_eigvec /= norm

    for _ in range(num_iters):
        Y = M @ Y
        # Remove the component along the top eigenvector from every column
        # in a single rank-1 update.
        Y -= numpy.outer(top_eigvec, top_eigvec @ Y)

    kmeans = (
        kmeans_alg if kmeans_alg is not None else KMeans(n_clusters=k, n_init="auto")
    )
    kmeans.fit(Y)
    return kmeans.labels_

def spectral_clustering(A, k):
    """Cluster a precomputed affinity matrix with scikit-learn spectral clustering.

    Args:
        A: Affinity matrix passed straight to ``SpectralClustering.fit``.
        k: Number of clusters.

    Returns:
        The fitted estimator's ``labels_``.
    """
    # Fixed seed and k-means label assignment, matching the baseline setup.
    options = {
        "n_clusters": k,
        "affinity": "precomputed",
        "assign_labels": "kmeans",
        "random_state": 0,
    }
    model = SpectralClustering(**options)
    model.fit(A)
    return model.labels_

def extract_coreset_graph(graph):
    """Return the coreset graph of ``graph`` as a square ``scipy.sparse.csr_array``.

    ``graph.rust_get_coreset_graph()`` must return the 5-tuple
    ``(n, indptr, indices, data, nnz)``:

    * ``n``       -- number of nodes
    * ``indptr``  -- array of size ``n + 1``
    * ``indices`` -- array of size ``nnz``
    * ``data``    -- array of size ``nnz``
    * ``nnz``     -- number of non-zero entries (not needed to build the array)
    """
    num_nodes, row_ptr, col_idx, values, _nnz = graph.rust_get_coreset_graph()
    return sparse.csr_array(
        (values, col_idx, row_ptr),
        shape=(num_nodes, num_nodes),
    )


def test_dynamic_csc():
    """Stream an SBM into a ``DynamicCoreset`` and track clustering quality.

    Every edge of a generated stochastic block model is inserted one at a
    time. Every ``query_interval`` insertions, once all ``n * k`` nodes are
    present, the coreset graph is extracted and clustered with
    ``fast_spectral_cluster`` (and, when ``slow_sc`` is set, with
    ``spectral_clustering``). The coreset labels are propagated to the whole
    graph and the adjusted Rand index against the SBM labels is recorded.
    The ARI curves are written to ``ari.png``.
    """
    import numpy as np
    from tqdm import tqdm
    # Imported up front so a missing plotting dependency fails before the
    # long-running insertion loop rather than after it.
    import matplotlib.pyplot as plt

    print("Dynamic CSC test")

    # sbm params
    n = 1000
    k = 25
    p = 0.5
    q = 2 / (n * k)

    slow_sc = True

    # dynamic csc params
    coreset_size = 2048
    affinity_shift = 0.0
    degree_threshold = 1.25
    update_threads = 4

    # evaluation params
    query_interval = 200_000

    dynamic_csc = DynamicCoreset(
        coreset_size=coreset_size,
        num_clusters=5 * k,
        affinity_shift=affinity_shift,
        degree_threshold=degree_threshold,
        update_threads=update_threads,
        update_buffer_size=query_interval * 2,
    )

    print("Generating sbm")
    sbm, y = gen_sbm(n, k, p, q)
    assert sbm.shape[0] == n * k

    def random_edges():
        """List every non-self-loop SBM edge, visiting nodes in random order."""
        node_order = np.arange(n * k)
        np.random.shuffle(node_order)
        return [
            (str(node), str(neighbour), np.float32(1.0))
            for node in node_order
            for neighbour in sbm[node].indices
            if node != neighbour
        ]

    def score_coreset_labels(coreset_labels):
        """Propagate coreset labels to the full graph; return (ARI, distances)."""
        names, full_labels, distances = dynamic_csc.label_entire_graph(
            coreset_labels.astype(np.uint64), k
        )
        # Node names are the stringified SBM indices, so they index into y.
        node_ids = np.array([int(name) for name in names])
        return adjusted_rand_score(y[node_ids], full_labels), distances

    print("Inserting edges")
    edges = random_edges()

    filtered_errors = []
    affinity_shifts = []
    fast_aris = []
    slow_aris = []
    xs = []

    with tqdm(total=len(edges)) as pbar:
        for i, (source, target, weight) in enumerate(edges):
            pbar.update(1)
            dynamic_csc.insert_edge(source, target, weight)
            # rust_delete_edge(source, target, weight) delete an amount from the edge
            # rust_delete_entire_edge(source, target) delete an entire edge

            filtered_errors.append(
                dynamic_csc.extract_filtered_average_distance_error()
            )
            affinity_shifts.append(dynamic_csc.extract_affinity_shift())

            if i % query_interval != 0:
                continue

            # Only evaluate once all the nodes are in the graph.
            if len(dynamic_csc.get_names_in_graph()) < n * k:
                pbar.set_description("Waiting for graph to fill")
                continue

            coreset_graph = extract_coreset_graph(dynamic_csc)
            # NOTE(review): the result is unused; the call is kept in case the
            # extraction has side effects in the backend -- confirm and remove.
            dynamic_csc.extract_coreset_indices_and_weights()

            # cluster the coreset graph:
            M, D = signless_laplacian_and_D_sparse(coreset_graph)
            fast_labels = fast_spectral_cluster(M, D, k)
            fast_ari, fast_distances = score_coreset_labels(fast_labels)
            fast_aris.append(fast_ari)

            if slow_sc:
                # need to set the indptr and indices dtypes to int32
                coreset_graph.indptr = coreset_graph.indptr.astype(np.int32)
                coreset_graph.indices = coreset_graph.indices.astype(np.int32)
                slow_labels = spectral_clustering(coreset_graph, k)

                slow_ari, slow_distances = score_coreset_labels(slow_labels)
                slow_aris.append(slow_ari)
                pbar.set_description(
                    f"fast ARI: {fast_aris[-1]:.3f} | slow ARI: {slow_aris[-1]:.3f} dist: {sum(slow_distances):.4f}"
                )
            else:
                pbar.set_description(
                    f"fast ARI: {fast_ari:.4f} | dist: {sum(fast_distances):.4f}"
                )
            xs.append(i)

    # get the coreset indices and weights:
    indices, _weights = dynamic_csc.extract_coreset_indices_and_weights()
    names_in_graph = sorted(int(name) for name in dynamic_csc.get_names_in_graph())

    # Printed for manual comparison against the SBM indices; nothing is asserted.
    expected_index_set = list(range(n * k))
    actual_index_set = sorted(int(index) for index in indices)

    print(len(expected_index_set), len(actual_index_set), len(names_in_graph))
    print(expected_index_set[:10])
    print(actual_index_set[:10])
    print(names_in_graph[:10])

    # Optional diagnostic plot of the per-insertion series collected above:
    # plt.plot(filtered_errors, label="Filtered errors")
    # plt.plot(affinity_shifts, label="Affinity shifts")
    # plt.xlabel("Iteration")
    # plt.ylabel("Value")
    # plt.legend()
    # plt.title("Filtered errors and affinity shifts")
    # plt.show()

    # plot the aris:
    plt.plot(xs, fast_aris, label="fast ARI")
    if slow_sc:
        plt.plot(xs, slow_aris, label="slow ARI")

    plt.xlabel("Iteration")
    plt.ylabel("ARI")
    plt.legend()
    plt.savefig("ari.png")

