"""
Algorithm specifications for the experiments.
"""
import graph_tool as gt
import graph_tool.spectral
import numpy as np
import stag.graph
import stag.cluster
from scipy import sparse
from typing import Optional
from abc import ABC, abstractmethod
import math
from sklearn.cluster import KMeans

from ls24.contractedgraph import Contracted_Graph
from ls24.spectral_clustering import run_spectral_clustering, spectral_clustering, clusters_to_labels
from ls24.sparsifier import DynamicGraphSparsifier

from coreset_sc import CoresetSpectralClustering

from dynamic_csc.dynamic_csc import DynamicCoreset, FasterDynamicCoreset


def fast_spectral_cluster(g: stag.graph.Graph, k: int):
    """Cluster the vertices of *g* into *k* groups using a power-iteration embedding.

    A random Gaussian block with ``4 * max(2, ceil(log2 k))`` columns is
    multiplied ``10 * ceil(log2(n / k))`` times by the normalised signless
    Laplacian of *g*. The left singular vectors of the result are used as the
    spectral embedding, which is then partitioned with k-means.

    Args:
        g: the graph to cluster.
        k: the number of clusters.

    Returns:
        The k-means label array, one entry per vertex of *g*.
    """
    num_vertices = g.number_of_vertices()
    block_width = 4 * max(2, math.ceil(math.log(k, 2)))
    iterations = 10 * math.ceil(math.log(num_vertices / k, 2))
    signless = g.normalised_signless_laplacian()

    embedding = np.random.normal(size=(num_vertices, block_width))
    for _ in range(iterations):
        embedding = signless @ embedding

    # Only the orthonormal left singular vectors are needed.
    embedding = np.linalg.svd(embedding, full_matrices=False)[0]

    return KMeans(n_clusters=k, n_init='auto').fit(embedding).labels_

class DynamicSCAlgorithm(ABC):
    """Common interface for clustering algorithms fed by a stream of graph updates.

    Subclasses implement the single-item update operations and ``predict``.
    The plural ``add_*`` / ``remove_*`` helpers apply the corresponding
    single-item operation to every element of the given iterable, in order.
    """

    def __init__(self):
        pass

    @abstractmethod
    def add_node(self, new_node_id):
        """Insert one node identified by *new_node_id*."""

    @abstractmethod
    def add_edge(self, edge_tuple):
        """Insert one edge given as a ``(node_id_1, node_id_2)`` pair."""

    @abstractmethod
    def remove_edge(self, edge_tuple):
        """Delete one edge given as a ``(node_id_1, node_id_2)`` pair."""

    @abstractmethod
    def remove_node(self, node_id):
        """Delete the node identified by *node_id*."""

    @staticmethod
    def _apply_to_each(operation, items):
        """Call ``operation(item)`` for every element of *items*, in iteration order."""
        for item in items:
            operation(item)

    def add_nodes(self, new_node_ids):
        self._apply_to_each(self.add_node, new_node_ids)

    def add_edges(self, edge_tuples):
        self._apply_to_each(self.add_edge, edge_tuples)

    def remove_edges(self, edge_tuples):
        self._apply_to_each(self.remove_edge, edge_tuples)

    def remove_nodes(self, node_ids):
        self._apply_to_each(self.remove_node, node_ids)

    @abstractmethod
    def predict(self, num_clusters):
        """Return cluster labels for the current graph using *num_clusters* clusters."""


class NaiveDynamicSC(DynamicSCAlgorithm):
    """Baseline that keeps the full graph and re-runs spectral clustering on each predict.

    The graph is stored in a graph-tool graph. Two dictionaries translate
    between the caller's node ids and graph-tool vertex indices. Node removals
    are queued and applied as one batch at the start of ``predict``.
    """

    def __init__(self):
        super().__init__()
        self.gt_graph = gt.Graph()
        # caller node id -> graph-tool vertex index, and the inverse mapping
        self.data_idx_to_gt_idx = {}
        self.gt_idx_to_data_idx = {}
        # graph-tool vertex indices queued for deletion at the next predict()
        self.nodes_to_remove_at_predict = []

    def add_node(self, new_node_id):
        gt_idx = int(self.gt_graph.add_vertex())
        self.data_idx_to_gt_idx[new_node_id] = gt_idx
        self.gt_idx_to_data_idx[gt_idx] = new_node_id

    def add_edge(self, edge_tuple):
        node_id_1, node_id_2 = edge_tuple
        gt_id_1 = self.data_idx_to_gt_idx[node_id_1]
        gt_id_2 = self.data_idx_to_gt_idx[node_id_2]
        self.gt_graph.add_edge(gt_id_1, gt_id_2)

    def remove_edge(self, edge_tuple):
        node_id_1, node_id_2 = edge_tuple
        gt_id_1 = self.data_idx_to_gt_idx[node_id_1]
        gt_id_2 = self.data_idx_to_gt_idx[node_id_2]

        # gt.Graph() is directed by default, so the edge may have been stored
        # with its endpoints in either order; look it up both ways round.
        edge = self.gt_graph.edge(gt_id_1, gt_id_2)
        if edge is None:
            edge = self.gt_graph.edge(gt_id_2, gt_id_1)

        self.gt_graph.remove_edge(edge)

    def remove_node(self, node_id):
        self.nodes_to_remove_at_predict.append(self.data_idx_to_gt_idx[node_id])

    def _flush_removed_nodes(self):
        """Delete the queued vertices and re-index both id mappings to match.

        graph-tool's ``remove_vertex`` (default ``fast=False``) keeps the
        surviving vertices in their original order but renumbers them so the
        indices stay contiguous. Without rebuilding the dictionaries, later
        ``add_edge``/``remove_edge`` calls and the label reordering in
        ``predict`` would address the wrong vertices.
        """
        pending = set(self.nodes_to_remove_at_predict)
        self.nodes_to_remove_at_predict.clear()
        if not pending:
            return

        self.gt_graph.remove_vertex(sorted(pending))

        survivors = sorted(idx for idx in self.gt_idx_to_data_idx if idx not in pending)
        self.gt_idx_to_data_idx = {
            new_idx: self.gt_idx_to_data_idx[old_idx]
            for new_idx, old_idx in enumerate(survivors)
        }
        self.data_idx_to_gt_idx = {
            node_id: gt_idx for gt_idx, node_id in self.gt_idx_to_data_idx.items()
        }

    def predict(self, num_clusters):
        """Cluster the current graph.

        Returns an array of labels ordered by ascending caller node id.
        """
        self._flush_removed_nodes()

        adj = gt.spectral.adjacency(self.gt_graph)
        # The stored graph is directed; symmetrise before clustering.
        stag_graph = stag.graph.Graph(adj + adj.transpose())
        clusters = stag.cluster.spectral_cluster(stag_graph,
                                                 num_clusters)
        sorted_indices = sorted(range(len(clusters)), key=lambda i: self.gt_idx_to_data_idx[i])
        return np.asarray([clusters[i] for i in sorted_indices])

class StaticCSC(NaiveDynamicSC):
    """Recompute a coreset spectral clustering from scratch on every predict.

    Graph updates are handled by the inherited naive storage. ``predict``
    builds a ``CoresetSpectralClustering`` over the full adjacency matrix.

    Args:
        coreset_size: target coreset size, converted into a ratio of the
            current vertex count (capped at 1.0).
        coreset_modifier: passed as ``k_over_sampling_factor``.
        fixed_pid_shift: passed as ``shift``.
    """

    def __init__(self, coreset_size=4096, coreset_modifier=2.0, fixed_pid_shift = 1000.0):
        super().__init__()

        self.coreset_size = coreset_size
        self.coreset_modifier = coreset_modifier
        self.fixed_pid_shift = fixed_pid_shift

    def _order_by_data_id(self, labels):
        """Return *labels* (indexed by graph-tool vertex) reordered by ascending caller node id."""
        sorted_indices = sorted(range(len(labels)), key=lambda i: self.gt_idx_to_data_idx[i])
        return np.asarray([labels[i] for i in sorted_indices])

    def predict(self, num_clusters):
        # Apply the removals queued by remove_node(). graph-tool's
        # remove_vertex (default fast=False) keeps surviving vertices in order
        # but renumbers them contiguously. The id mappings are rebuilt to match
        # so later updates and the reordering below use valid indices.
        pending = set(self.nodes_to_remove_at_predict)
        self.nodes_to_remove_at_predict.clear()
        if pending:
            self.gt_graph.remove_vertex(sorted(pending))
            survivors = sorted(idx for idx in self.gt_idx_to_data_idx if idx not in pending)
            self.gt_idx_to_data_idx = {
                new_idx: self.gt_idx_to_data_idx[old_idx]
                for new_idx, old_idx in enumerate(survivors)
            }
            self.data_idx_to_gt_idx = {
                node_id: gt_idx for gt_idx, node_id in self.gt_idx_to_data_idx.items()
            }

        n = self.gt_graph.num_vertices()

        # Trivial one-cluster case: every vertex gets label 0. Returned as an
        # array, like every other prediction path.
        if num_clusters == 1:
            return np.zeros(n, dtype=np.int64)

        csc = CoresetSpectralClustering(
            num_clusters=num_clusters,
            coreset_ratio=min(1.0, self.coreset_size / n),
            k_over_sampling_factor=float(self.coreset_modifier),
            shift=float(self.fixed_pid_shift),
        )

        # Symmetrised adjacency matrix with a unit diagonal.
        adj = gt.spectral.adjacency(self.gt_graph)
        adj.setdiag(1)
        adj = (adj + adj.transpose()) / 2.0

        coreset_graph = csc.get_coreset_graph(adj)

        # coreset-sc quirk: get_coreset_graph returns the estimator itself
        # (not a matrix) when the coreset would be the full graph. In that
        # case, cluster the full graph directly.
        if isinstance(coreset_graph, CoresetSpectralClustering):
            clusters = stag.cluster.spectral_cluster(stag.graph.Graph(adj), num_clusters)
            return self._order_by_data_id(clusters)

        coreset_labels = stag.cluster.spectral_cluster(stag.graph.Graph(coreset_graph), num_clusters)
        csc.set_coreset_graph_labels(coreset_labels.astype(np.uint64))
        csc.label_full_graph()
        return self._order_by_data_id(csc.labels_.astype(np.int64))



class CPSDynamicSC(DynamicSCAlgorithm):
    """Spectral clustering on a dynamically maintained graph sparsifier.

    Edges are buffered in ``gt_graph`` until ``sparsifier_edge_threshold``
    insertions have been seen, or until the first ``predict``. At that point a
    ``DynamicGraphSparsifier`` is built. Afterwards, nodes and edges go to the
    sparsifier's ``original_graph``, and each new edge is pushed into the
    sparsifier. Edge and node removal are not supported.

    Args:
        sampling_constant: forwarded to ``DynamicGraphSparsifier``.
        sparsifier_edge_threshold: number of inserted edges after which the
            sparsifier is built eagerly (default 1000).
    """

    def __init__(self, sampling_constant=5, sparsifier_edge_threshold=1000):
        super().__init__()
        self.sparsifier: Optional[DynamicGraphSparsifier] = None
        self.gt_graph = gt.Graph(directed=False)
        # caller node id -> graph-tool vertex index, and the inverse mapping
        self.data_idx_to_gt_idx = {}
        self.gt_idx_to_data_idx = {}
        self.number_of_edges = 0
        self.sampling_constant = sampling_constant
        self.sparsifier_edge_threshold = sparsifier_edge_threshold

    def _current_graph(self):
        """Return the graph that currently receives node/edge insertions."""
        return self.gt_graph if self.sparsifier is None else self.sparsifier.original_graph

    def _build_sparsifier(self):
        """Create the sparsifier from the buffered ``gt_graph``."""
        self.sparsifier = DynamicGraphSparsifier(self.gt_graph, sampling_constant=self.sampling_constant)
        self.sparsifier.create_sparsifier()

    def add_node(self, new_node_id):
        gt_idx = int(self._current_graph().add_vertex())
        self.data_idx_to_gt_idx[new_node_id] = gt_idx
        self.gt_idx_to_data_idx[gt_idx] = new_node_id

    def add_edge(self, edge_tuple):
        node_id_1, node_id_2 = edge_tuple
        gt_id_1 = self.data_idx_to_gt_idx[node_id_1]
        gt_id_2 = self.data_idx_to_gt_idx[node_id_2]
        if self.sparsifier is None:
            self.gt_graph.add_edge(gt_id_1, gt_id_2)
            self.number_of_edges += 1
            if self.number_of_edges >= self.sparsifier_edge_threshold:
                self._build_sparsifier()
        else:
            self.sparsifier.original_graph.add_edge(gt_id_1, gt_id_2)
            self.sparsifier.update_sparsifier(np.asarray([[gt_id_1, gt_id_2]]), verbose=False)
            self.number_of_edges += 1

    def remove_edge(self, edge_tuple):
        raise NotImplementedError

    def remove_node(self, node_id):
        raise NotImplementedError

    def predict(self, num_clusters):
        """Cluster the non-isolated vertices of the sparsifier.

        Returns one label per non-isolated vertex, ordered by ascending caller
        node id.
        """
        if self.sparsifier is None:
            self._build_sparsifier()

        sparsified_graph = self.sparsifier.get_sparsified_graph()

        # Mark the vertices that have at least one edge in the sparsifier.
        not_isolated = sparsified_graph.new_vertex_property("bool")
        for v in sparsified_graph.vertices():
            not_isolated[v] = v.out_degree() + v.in_degree() > 0

        sparsified_graph.set_vertex_filter(not_isolated)
        try:
            # With a vertex filter set, graph-tool's adjacency() numbers the
            # visible vertices contiguously in index order, so row j belongs to
            # vertex kept[j], not to vertex j.
            kept = [int(v) for v in sparsified_graph.vertices()]
            adj = gt.spectral.adjacency(sparsified_graph)
        finally:
            # Never leave the filter on the sparsifier's graph between calls.
            sparsified_graph.set_vertex_filter(None)

        stag_graph = stag.graph.Graph(adj + adj.transpose())
        clusters = stag.cluster.spectral_cluster(stag_graph,
                                                 num_clusters)
        sorted_rows = sorted(range(len(clusters)), key=lambda j: self.gt_idx_to_data_idx[kept[j]])
        return np.asarray([clusters[j] for j in sorted_rows])


class LS24(DynamicSCAlgorithm):
    """Dynamic spectral clustering via a sparsifier plus a contracted graph.

    Edges are buffered in ``gt_graph`` until the first ``predict`` builds the
    sparsifier and the contracted graph. After that, insertions go to the
    sparsifier's ``original_graph`` and each edge is forwarded to the
    contracted graph. Once ``reinit_threshold`` edge insertions have
    accumulated since the last build, the next ``predict`` rebuilds both
    structures. Edge and node removal are not supported.

    Args:
        sampling_constant: forwarded to ``DynamicGraphSparsifier``.
        degree_trigger: forwarded to ``Contracted_Graph``.
        reinit_threshold: edge insertions after which ``predict`` rebuilds.
    """

    def __init__(self, sampling_constant=1, degree_trigger=0.2, reinit_threshold=100000):
        super().__init__()
        self.gt_graph = gt.Graph(directed=False)
        self.sparsifier: Optional[DynamicGraphSparsifier] = None
        self.cg: Optional[Contracted_Graph] = None
        # caller node id -> graph-tool vertex index, and the inverse mapping
        self.data_idx_to_gt_idx = {}
        self.gt_idx_to_data_idx = {}
        self.updates_since_reinit = 0
        self.sampling_constant = sampling_constant
        self.degree_trigger = degree_trigger
        self.reinit_threshold = reinit_threshold

    def _current_graph(self):
        """Return the graph that currently receives node/edge insertions."""
        return self.gt_graph if self.sparsifier is None else self.sparsifier.original_graph

    def initialise_cg(self, num_clusters):
        """(Re)build the sparsifier and contracted graph from the current full graph."""
        # After the first build, updates are written to the sparsifier's
        # original_graph rather than to self.gt_graph. Rebuild from whichever
        # graph holds every update so far.
        self.sparsifier = DynamicGraphSparsifier(self._current_graph(), sampling_constant=self.sampling_constant)
        self.sparsifier.create_sparsifier()
        sparsified_graph = self.sparsifier.get_sparsified_graph()

        # Hide vertices that have no edges in the sparsified graph.
        not_isolated = sparsified_graph.new_vertex_property("bool")
        for v in sparsified_graph.vertices():
            not_isolated[v] = v.out_degree() + v.in_degree() > 0

        sparsified_graph.set_vertex_filter(not_isolated)
        try:
            # Ground-truth labels are not needed here (metrics are computed by
            # the experiment code), so pass placeholders.
            subgraph_labels = [0] * sparsified_graph.num_vertices()

            sparse_clusters, sparse_labels, _, _ = run_spectral_clustering(sparsified_graph,
                                                                           subgraph_labels,
                                                                           n_clusters=max(1, num_clusters))
            self.cg = Contracted_Graph(self.sparsifier.original_graph, sparsified_graph,
                                       degree_trigger=self.degree_trigger)
            self.cg.initialize(sparse_clusters, sparse_labels)
        finally:
            # Clear the filter even if clustering/initialisation raises.
            sparsified_graph.set_vertex_filter(None)
        self.updates_since_reinit = 0

    def add_node(self, new_node_id):
        gt_idx = int(self._current_graph().add_vertex())
        self.data_idx_to_gt_idx[new_node_id] = gt_idx
        self.gt_idx_to_data_idx[gt_idx] = new_node_id

    def add_edge(self, edge_tuple):
        node_id_1, node_id_2 = edge_tuple
        self.updates_since_reinit += 1
        gt_id_1 = self.data_idx_to_gt_idx[node_id_1]
        gt_id_2 = self.data_idx_to_gt_idx[node_id_2]

        # Record the edge in the graph currently receiving updates.
        self._current_graph().add_edge(gt_id_1, gt_id_2)

        # Forward the edge to the contracted graph once it exists.
        if self.cg is not None:
            self.cg.update(np.asarray([[gt_id_1, gt_id_2]]))

    def remove_edge(self, edge_tuple):
        raise NotImplementedError

    def remove_node(self, node_id):
        raise NotImplementedError

    def predict(self, num_clusters):
        """Return contracted-graph cluster labels ordered by ascending caller node id."""
        if self.cg is None or self.updates_since_reinit >= self.reinit_threshold:
            self.initialise_cg(num_clusters)

        _, contract_labels, _, _ = self.cg.spectral_clustering_on_contracted(
            num_clusters, [0] * len(self.data_idx_to_gt_idx))
        sorted_indices = sorted(range(len(contract_labels)), key=lambda i: self.gt_idx_to_data_idx[i])
        return np.asarray([contract_labels[i] for i in sorted_indices])


class DynamicCSC(DynamicSCAlgorithm):
    """Dynamic coreset spectral clustering backed by ``DynamicCoreset``.

    Edges are forwarded to the coreset structure with node ids converted to
    strings; self-loops are dropped. ``node_ids_in_graph`` records nodes
    announced via ``add_node``. Any such node missing from the labelled output
    gets the sentinel label ``num_clusters + 1``.
    """

    def __init__(self, coreset_size=4096 * 2, degree_threshold=5.0, pid_target=0.01, update_threads=8,
                 coreset_modifier=300, update_buffer_size=10_000_000):
        super().__init__()
        # NOTE(review): coreset_modifier is what DynamicCoreset receives as num_clusters.
        self.dynamic_csc = DynamicCoreset(
            coreset_size=coreset_size,
            num_clusters=coreset_modifier,
            pid_target=pid_target,
            degree_threshold=degree_threshold,
            update_threads=update_threads,
            update_buffer_size=update_buffer_size,
        )
        self.node_ids_in_graph = set()

    def add_node(self, new_node_id):
        self.node_ids_in_graph.add(new_node_id)

    def remove_node(self, node_id):
        self.node_ids_in_graph.remove(node_id)

    def add_edge(self, edge_tuple):
        source, target = edge_tuple
        if source == target:
            return  # self-loops are never forwarded
        self.dynamic_csc.insert_edge(str(source), str(target), 1)

    def remove_edge(self, edge_tuple):
        source, target = edge_tuple
        if source == target:
            return
        self.dynamic_csc.delete_edge(str(source), str(target))

    def predict(self, num_clusters):
        """Cluster the current coreset and extend the labels to the whole graph.

        Returns an int32 array of labels ordered by ascending node id.
        """
        n, indptr, indices, data, _nnz = self.dynamic_csc.rust_get_coreset_graph()
        raw_graph = sparse.csr_array((data, indices, indptr), shape=(n, n))

        # Symmetrise, then cast the index arrays to int32 for STAG.
        symmetric = (raw_graph + raw_graph.transpose()) / 2.0
        symmetric.indptr = symmetric.indptr.astype(np.int32)
        symmetric.indices = symmetric.indices.astype(np.int32)
        coreset_labels = stag.cluster.spectral_cluster(stag.graph.Graph(symmetric), num_clusters)

        names, full_labels, _distances = self.dynamic_csc.label_entire_graph(
            coreset_labels.astype(np.uint64), num_clusters)

        # Names come back as the stringified node ids passed to insert_edge.
        label_of = {}
        for position, label in enumerate(full_labels):
            label_of[int(names[position])] = label

        sentinel = num_clusters + 1
        for node_id in self.node_ids_in_graph:
            label_of.setdefault(node_id, sentinel)

        ordered = [label_of[name] for name in sorted(label_of)]
        return np.asarray(ordered).astype(np.int32)


class FasterDynamicCSC(DynamicSCAlgorithm):
    """Dynamic coreset spectral clustering backed by ``FasterDynamicCoreset``.

    Each ``predict`` samples a coreset of ``coreset_size`` using
    ``num_clusters * coreset_modifier`` sampling seeds. The coreset is
    clustered with STAG spectral clustering, and the coreset structure
    extends those labels to every node.

    Args:
        coreset_size: requested coreset size per prediction.
        pid_target: forwarded to ``FasterDynamicCoreset``.
        filtering_constant: forwarded to ``FasterDynamicCoreset``.
        coreset_modifier: multiplier on ``num_clusters`` giving the sampling seeds.
        fixed_pid_shift: forwarded as ``const_pid_output``.
    """

    def __init__(self, coreset_size=4096 * 2, pid_target=0.01, filtering_constant=0.9, coreset_modifier=10, fixed_pid_shift=None):
        super().__init__()
        self.dynamic_csc = FasterDynamicCoreset(
            pid_target=pid_target,
            filtering_constant=filtering_constant,
            const_pid_output=fixed_pid_shift,
        )
        self.coreset_size = coreset_size
        self.coreset_modifier = coreset_modifier
        # ids announced through add_node and not yet removed
        self.node_ids_in_graph = set()

    def add_node(self, new_node_id):
        self.node_ids_in_graph.add(new_node_id)

    def remove_node(self, node_id):
        self.node_ids_in_graph.remove(node_id)

    def add_edge(self, edge_tuple):
        source, target = edge_tuple
        if source == target:
            return  # self-loops are never forwarded
        self.dynamic_csc.insert_edge(str(source), str(target), 1)

    def remove_edge(self, edge_tuple):
        source, target = edge_tuple
        if source == target:
            return
        self.dynamic_csc.delete_edge_weighted(str(source), str(target), 1.0)

    def predict(self, num_clusters):
        """Sample and cluster a coreset, then label the whole graph.

        Returns an int32 array of labels ordered by ascending node id. Nodes
        announced via add_node but absent from the labelling get
        ``num_clusters + 1``.
        """
        seeds = num_clusters * self.coreset_modifier
        (n, indptr, indices, data, _nnz,
         coreset_node_names, coreset_weights) = self.dynamic_csc.rust_get_coreset_graph(
            coreset_size=self.coreset_size,
            sampling_seeds=seeds,
        )
        raw_graph = sparse.csr_array((data, indices, indptr), shape=(n, n))

        # Symmetrise, then cast the index arrays to int32 for STAG.
        symmetric = (raw_graph + raw_graph.transpose()) / 2.0
        symmetric.indptr = symmetric.indptr.astype(np.int32)
        symmetric.indices = symmetric.indices.astype(np.int32)
        coreset_labels = stag.cluster.spectral_cluster(stag.graph.Graph(symmetric), num_clusters)

        names, full_labels, _distances = self.dynamic_csc.label_entire_graph(
            coreset_labels.astype(np.uint64),
            coreset_node_names,
            coreset_weights,
            num_clusters,
        )

        label_of = {}
        for position, label in enumerate(full_labels):
            label_of[int(names[position])] = label

        sentinel = num_clusters + 1
        for node_id in self.node_ids_in_graph:
            label_of.setdefault(node_id, sentinel)

        ordered = [label_of[name] for name in sorted(label_of)]
        return np.asarray(ordered).astype(np.int32)


class FastFasterDynamicCSC(FasterDynamicCSC):
    """FasterDynamicCSC variant that clusters the coreset with ``fast_spectral_cluster``.

    Construction and graph updates are inherited unchanged; only the
    coreset-clustering step inside ``predict`` differs.
    """

    def __init__(self, coreset_size=4096 * 2, pid_target=0.01, filtering_constant=0.9, coreset_modifier=10, fixed_pid_shift=None):
        super().__init__(
            coreset_size=coreset_size,
            pid_target=pid_target,
            filtering_constant=filtering_constant,
            coreset_modifier=coreset_modifier,
            fixed_pid_shift=fixed_pid_shift,
        )

    def predict(self, num_clusters):
        """Sample a coreset, cluster it by power iteration, then label the whole graph.

        Returns an int32 array of labels ordered by ascending node id. Nodes
        announced via add_node but absent from the labelling get
        ``num_clusters + 1``.
        """
        seeds = num_clusters * self.coreset_modifier
        (n, indptr, indices, data, _nnz,
         coreset_node_names, coreset_weights) = self.dynamic_csc.rust_get_coreset_graph(
            coreset_size=self.coreset_size,
            sampling_seeds=seeds,
        )
        raw_graph = sparse.csr_array((data, indices, indptr), shape=(n, n))

        # Symmetrise, then cast the index arrays to int32 for STAG.
        symmetric = (raw_graph + raw_graph.transpose()) / 2.0
        symmetric.indptr = symmetric.indptr.astype(np.int32)
        symmetric.indices = symmetric.indices.astype(np.int32)
        coreset_labels = fast_spectral_cluster(stag.graph.Graph(symmetric), num_clusters)

        names, full_labels, _distances = self.dynamic_csc.label_entire_graph(
            coreset_labels.astype(np.uint64),
            coreset_node_names,
            coreset_weights,
            num_clusters,
        )

        label_of = {}
        for position, label in enumerate(full_labels):
            label_of[int(names[position])] = label

        sentinel = num_clusters + 1
        for node_id in self.node_ids_in_graph:
            label_of.setdefault(node_id, sentinel)

        ordered = [label_of[name] for name in sorted(label_of)]
        return np.asarray(ordered).astype(np.int32)
    

class FixedShiftFastFasterDynamicCSC(FastFasterDynamicCSC):
    """FastFasterDynamicCSC with a constant PID shift.

    ``pid_target`` is pinned to 0.01 and ``filtering_constant`` to 0.9. Only
    the coreset size, the fixed shift (forwarded as ``fixed_pid_shift``) and
    the coreset modifier are configurable.
    """

    def __init__(self, coreset_size=4096 * 2, fixed_pid_shift=0.5, coreset_modifier=10):
        super().__init__(
            coreset_size=coreset_size,
            pid_target=0.01,
            filtering_constant=0.9,
            coreset_modifier=coreset_modifier,
            fixed_pid_shift=fixed_pid_shift,
        )

