"""
Datasets for the experiments.
"""
import alglab
import random
import pickle
from typing import List, Tuple
import os
import numpy as np
from abc import abstractmethod, ABC
import stag.random
import stag.graph
import graph_tool as gt
import graph_tool.generation
from sklearn.neighbors import NearestNeighbors
from sklearn.datasets import fetch_openml
import emnist


class DynamicGraphDataset(alglab.dataset.ClusterableDataset, ABC):
    """
    Base class for clusterable datasets whose graph changes over time.

    Time step ``t`` is described by four collections of changes: the vertices
    inserted, the vertices deleted, the edges inserted and the edges deleted.
    Subclasses supply each of these, the number of clusters at each time step,
    and a way to move the dataset's ground truth to a given time step.
    """

    def __init__(self, num_updates: int, num_vertices: int):
        """
        :param num_updates: The number of time steps in the update sequence.
        :param num_vertices: The number of vertices in the dataset.
        """
        self.num_vertices = num_vertices
        # Left unset here; the subclasses in this module assign it in set_iteration.
        self.num_clusters = None
        self.num_updates = num_updates
        # None is passed to the alglab base class; subclasses assign self.gt_labels
        # themselves (presumably this argument is the labels -- confirm against alglab).
        alglab.dataset.ClusterableDataset.__init__(self, None)

    @abstractmethod
    def get_node_insertions_at_time(self, t):
        """Return the vertices inserted at time step t."""

    @abstractmethod
    def get_node_deletions_at_time(self, t):
        """Return the vertices deleted at time step t."""

    @abstractmethod
    def get_edge_insertions_at_time(self, t):
        """Return the edges inserted at time step t."""

    @abstractmethod
    def get_edge_deletions_at_time(self, t):
        """Return the edges deleted at time step t."""

    @abstractmethod
    def get_num_clusters_at_time(self, t):
        """Return the number of ground-truth clusters at time step t."""

    @abstractmethod
    def set_iteration(self, t):
        """Update the dataset's ground truth so that it describes time step t."""

    def get_update(self, t):
        """
        Collect every change made at time step t.

        :return: A 4-tuple (node insertions, node deletions, edge insertions, edge deletions).
        """
        inserted_nodes = self.get_node_insertions_at_time(t)
        deleted_nodes = self.get_node_deletions_at_time(t)
        inserted_edges = self.get_edge_insertions_at_time(t)
        deleted_edges = self.get_edge_deletions_at_time(t)
        return inserted_nodes, deleted_nodes, inserted_edges, deleted_edges


class DynamicSBMAddingSmallClusters(DynamicGraphDataset):
    """
    A dynamic stochastic block model in which new small clusters appear over time.

    Time step 0 inserts every vertex together with the edges of an SBM graph with
    k clusters of n vertices. Each later time step picks n_new vertices that have
    not been picked before, connects them into a clique, gives them a fresh cluster
    label, and adds random Erdős–Rényi noise edges.
    """

    def __init__(self, k, n, p, q, n_new, s, num_updates: int):
        """
        Initializes a dynamic graph based on the stochastic block model (SBM).

        The sequence of updates is based on the experimental setup in LS24.

        Parameters:
        - k (int): Number of initial clusters.
        - n (int): Size of each initial cluster.
        - p (float): Probability of each edge inside a cluster of the initial graph.
        - q (float): Probability of each edge between two different clusters of the initial graph.
        - n_new (int): Number of vertices placed into a new cluster at each update.
        - s (float): Edge probability of the Erdős–Rényi noise added at each update (at most 1).
        - num_updates (int): Number of updates after the initial graph.

        Raises:
        - ValueError: if the updates need more vertices than the graph has, or if s > 1.
        """
        super().__init__(num_updates + 1, k * n)
        self.k = k
        self.n = n
        self.p = p
        self.q = q
        self.n_new = n_new
        self.s = s
        # Vertices that have not yet been moved into one of the new clusters.
        self.available_vertices = set(range(k * n))

        # Prepare the updates to the dataset
        initial_edges, initial_labels = self._generate_sbm_graph()
        self.update_schedule = [initial_edges]
        self.labels_by_iteration = [initial_labels]
        for i in range(num_updates):
            # Work on a real copy: for a numpy array, [:] is only a view, and
            # _sample_internal_edges writes into the labels it is given, which would
            # otherwise rewrite the stored labels of every earlier time step too.
            previous_labels = np.array(self.labels_by_iteration[i], copy=True)
            clique_edges, true_labels = self._sample_internal_edges(self.n_new, 1, current_labels=previous_labels)
            random_edges_to_add = self._overlay_erdos_renyi_optimized(self.s)
            self.update_schedule.append(np.vstack((self._as_edge_array(clique_edges),
                                                   self._as_edge_array(random_edges_to_add))))
            self.labels_by_iteration.append(true_labels)

        # Set the ground truth to be the initial labels
        self.gt_labels = initial_labels

    @staticmethod
    def _as_edge_array(edges):
        """
        Return edges as an (m, 2) int64 array.

        Unlike np.asarray alone, this keeps the (0, 2) shape for an empty edge
        list, so the result can always be passed to np.vstack.
        """
        return np.asarray(edges, dtype=np.int64).reshape(-1, 2)

    def _generate_sbm_graph(self):
        """
        Generates a graph from a stochastic block model with given parameters.

        Returns:
        - edges (numpy.ndarray): An (m, 2) array with one row (u, v), u > v, per edge.
        - labels: The cluster of each vertex, as returned by stag.random.sbm_gt_labels.
        """
        graph: stag.graph.Graph = stag.random.sbm(self.n * self.k, self.k, self.p, self.q)
        labels = stag.random.sbm_gt_labels(self.n * self.k, self.k)

        # Record each undirected edge once, from its larger endpoint.
        all_edges = []
        for v in range(graph.number_of_vertices()):
            for u in graph.neighbors_unweighted(v):
                if v < u:
                    all_edges.append((u, v))

        return self._as_edge_array(all_edges), labels

    def _sample_internal_edges(self, n_small, r, current_labels):
        """
        Move a random subset of the available vertices into a new cluster and
        sample edges inside that subset.

        Parameters:
        - n_small (int): The size of the subset to select and internally connect.
        - r (float): The probability of adding an edge between any two vertices in the subset.
        - current_labels (numpy.ndarray): The labels before the update. Modified in place.

        Returns:
        - sampled_edges (list of tuples): The list of edges sampled within the subset.
        - updated_labels (numpy.ndarray): current_labels, with the selected vertices relabelled.

        Raises:
        - ValueError: if fewer than n_small vertices are still available.
        """
        if n_small > len(self.available_vertices):
            raise ValueError("n_small is larger than the number of available vertices.")

        selected_vertices = np.random.choice(list(self.available_vertices), size=n_small, replace=False)
        # Each vertex joins at most one new cluster.
        self.available_vertices -= set(selected_vertices)

        # Update labels for selected vertices to a new cluster
        new_cluster_id = max(current_labels) + 1 if len(current_labels) > 0 else 0
        current_labels[selected_vertices] = new_cluster_id

        # Sample edges with probability r
        sampled_edges = []
        for i in range(n_small):
            for j in range(i + 1, n_small):
                if np.random.rand() < r:
                    sampled_edges.append((selected_vertices[i], selected_vertices[j]))

        return sampled_edges, current_labels

    def _overlay_erdos_renyi_optimized(self, s):
        """
        Sample a set of distinct random edges whose size matches the Erdős–Rényi model parameter s.

        Exactly round(s * N * (N - 1) / 2) distinct non-loop edges are drawn uniformly.

        Parameters:
        - s (float): Probability of an edge existing between any two vertices.

        Returns:
        - new_edges (list): A list of distinct (u, v) pairs with u < v.

        Raises:
        - ValueError: if s > 1, since that many distinct edges do not exist.
        """
        N = self.num_vertices
        # Calculate the expected number of new edges to add
        total_possible_edges = N * (N - 1) // 2
        num_edges_to_add = int(np.round(s * total_possible_edges))
        if num_edges_to_add > total_possible_edges:
            # Without this check the rejection loop below would never terminate.
            raise ValueError("s must be at most 1.")

        new_edges = set()
        while len(new_edges) < num_edges_to_add:
            u, v = np.random.randint(0, N, size=2)
            if u == v:
                # Self-loops are not among the N * (N - 1) / 2 possible edges.
                continue
            new_edges.add((min(u, v), max(u, v)))  # Ensure consistency in edge direction

        return list(new_edges)

    def get_node_insertions_at_time(self, t):
        """All vertices are inserted at time 0; none afterwards."""
        if t == 0:
            return list(range(self.num_vertices))
        else:
            return []

    def get_node_deletions_at_time(self, t):
        """No vertex is ever deleted."""
        return []

    def get_edge_insertions_at_time(self, t):
        """Return the (m, 2) array of edges inserted at time t."""
        return self.update_schedule[t]

    def get_edge_deletions_at_time(self, t):
        """No edge is ever deleted."""
        return []

    def set_iteration(self, t):
        """Set the ground-truth labels and cluster count to those of time t."""
        self.gt_labels = self.labels_by_iteration[t]
        self.num_clusters = self.get_num_clusters_at_time(t)

    def get_num_clusters_at_time(self, t):
        """Each update adds one new cluster to the initial k."""
        return self.k + t

    def __repr__(self):
        return f"DynamicSBMAddingSmallClusters({self.num_vertices}, {self.k})"


class DynamicSBMChangingClusters(DynamicGraphDataset):
    """
    A dynamic SBM with k clusters. At each iteration, we split two clusters in half, and merge their halves
    back together, changing the clusters.

    We repeat this k/2 times, on the disjoint pairs of clusters (0, 1), (2, 3), ...,
    so every cluster is changed exactly once.
    """

    def __init__(self, cluster_size, p, q, k):
        """
        :param cluster_size: The number of vertices in each cluster (must be even).
        :param p: The probability of each edge inside a cluster.
        :param q: The probability of each edge between two different clusters.
        :param k: The number of clusters (must be even).
        """
        assert k % 2 == 0, "Number of clusters must be even."
        assert cluster_size % 2 == 0, "Cluster size must be even."
        self.cluster_size = cluster_size
        self.k = k
        self.n = cluster_size * k
        num_updates = k // 2
        super().__init__(num_updates + 1, self.n)
        self.p = p
        self.q = q
        # The initial SBM graph, set by _generate_sbm_graph; the edges to delete are read from it.
        self.original_graph = None

        # Prepare the updates to the dataset. Each entry of update_schedule is a
        # pair (edges to insert, edges to delete).
        initial_edges, initial_labels = self._generate_sbm_graph()
        self.update_schedule = [(initial_edges, [])]
        self.labels_by_iteration = [initial_labels]
        for i in range(num_updates):
            first_cluster, second_cluster = 2 * i, 2 * i + 1
            # Each cluster is touched only once, so the edges between its halves
            # are still exactly those of the original graph.
            edges_to_delete = self._split_clusters(first_cluster, second_cluster)
            edges_to_add = self._rejoin_clusters(first_cluster, second_cluster)
            new_labels = self._labels_after_split_and_merge(first_cluster, second_cluster,
                                                            self.labels_by_iteration[i])
            self.update_schedule.append((edges_to_add, edges_to_delete))
            self.labels_by_iteration.append(new_labels)

        # Set the ground truth to be the initial labels
        self.gt_labels = initial_labels

    def _get_nodes_list(self, cluster_id, split_id):
        """
        :param cluster_id: The id (0-indexed) of the cluster.
        :param split_id: Which half (1 or 2) of the cluster to return.
        :return: A list of the node ids in this half of the given cluster.
        :raises ValueError: If split_id is not 1 or 2.
        """
        # Assumes the SBM numbers vertices cluster by cluster, so each cluster is
        # the contiguous id range [cluster_id * cluster_size, (cluster_id + 1) * cluster_size).
        cluster_start = cluster_id * self.cluster_size
        half = self.cluster_size // 2
        if split_id == 1:
            return list(range(cluster_start, cluster_start + half))
        elif split_id == 2:
            return list(range(cluster_start + half, cluster_start + self.cluster_size))
        else:
            raise ValueError("split_id must be 1 or 2.")

    def _labels_after_split_and_merge(self, cluster_id_1, cluster_id_2, current_labels):
        """
        Return the labels after the halves of two clusters have been exchanged.

        The second half of cluster_id_1 joins cluster_id_2 and the first half of
        cluster_id_2 joins cluster_id_1. current_labels itself is not modified.
        """
        new_labels = current_labels.copy()

        # We only need to update half of each cluster.
        for node_id in self._get_nodes_list(cluster_id_1, 2):
            new_labels[node_id] = cluster_id_2
        for node_id in self._get_nodes_list(cluster_id_2, 1):
            new_labels[node_id] = cluster_id_1

        return new_labels

    def _get_original_edges_between(self, first_cluster, second_cluster):
        """
        Return the edges of the original graph with one endpoint in first_cluster
        and the other in second_cluster, as (first endpoint, second endpoint) pairs.
        """
        edges = []
        second_cluster = set(second_cluster)
        for node_id_1 in first_cluster:
            for node_id_2 in self.original_graph.neighbors_unweighted(node_id_1):
                if node_id_2 in second_cluster:
                    edges.append((node_id_1, node_id_2))
        return edges

    def _split_clusters(self, cluster_id_1, cluster_id_2):
        """
        Split two clusters in half.

        Return a list of the edges to remove in order to split the clusters
        in half: every original edge between two different halves of the two clusters.
        """
        # Get the node ids in each half of each cluster
        sub_clusters = [self._get_nodes_list(cluster_id_1, 1),
                        self._get_nodes_list(cluster_id_1, 2),
                        self._get_nodes_list(cluster_id_2, 1),
                        self._get_nodes_list(cluster_id_2, 2)]

        edges_to_delete = []

        # Remove the edges between each pair of sub-clusters
        for i in range(len(sub_clusters)):
            for j in range(i + 1, len(sub_clusters)):
                edges_to_delete += self._get_original_edges_between(sub_clusters[i], sub_clusters[j])

        return edges_to_delete

    def _sample_edges_between(self, first_cluster, second_cluster, edge_probability):
        """
        Sample edges between the given two equally-sized clusters, each pair
        independently with the given probability.
        """
        assert len(first_cluster) == len(second_cluster), "Clusters must have the same size."
        cluster_size: int = len(first_cluster)
        edges = []

        # Sample a bipartite graph (probability 0 inside each side) on 2 * cluster_size
        # temporary vertices, then re-number its vertices afterwards.
        temp_graph = stag.random.sbm(2 * cluster_size, 2, 0, edge_probability)

        for i in range(cluster_size):
            for j in temp_graph.neighbors_unweighted(i):
                edges.append((first_cluster[i], second_cluster[j - cluster_size]))
        return edges

    def _rejoin_clusters(self, cluster_id_1, cluster_id_2):
        """
        Join two halves of two clusters back together.

        The new clusters are (half 1 of cluster 1 + half 1 of cluster 2) and
        (half 2 of cluster 1 + half 2 of cluster 2). Halves that now share a cluster
        are connected with probability p; all other pairs of halves with probability q.
        """
        # Get the node ids in each half of each cluster
        cluster_1_half_1_nodes = self._get_nodes_list(cluster_id_1, 1)
        cluster_1_half_2_nodes = self._get_nodes_list(cluster_id_1, 2)
        cluster_2_half_1_nodes = self._get_nodes_list(cluster_id_2, 1)
        cluster_2_half_2_nodes = self._get_nodes_list(cluster_id_2, 2)

        edges_to_add = []
        edges_to_add += self._sample_edges_between(cluster_1_half_1_nodes, cluster_1_half_2_nodes, self.q)
        edges_to_add += self._sample_edges_between(cluster_1_half_1_nodes, cluster_2_half_1_nodes, self.p)
        edges_to_add += self._sample_edges_between(cluster_1_half_1_nodes, cluster_2_half_2_nodes, self.q)
        edges_to_add += self._sample_edges_between(cluster_1_half_2_nodes, cluster_2_half_1_nodes, self.q)
        edges_to_add += self._sample_edges_between(cluster_1_half_2_nodes, cluster_2_half_2_nodes, self.p)
        edges_to_add += self._sample_edges_between(cluster_2_half_1_nodes, cluster_2_half_2_nodes, self.q)
        return edges_to_add

    def _generate_sbm_graph(self):
        """
        Generates a graph from a stochastic block model with given parameters,
        storing it in self.original_graph.

        Returns:
        - edges (numpy.ndarray): An (m, 2) array with one row (u, v), u > v, per edge.
        - labels: The cluster of each vertex, as returned by stag.random.sbm_gt_labels.
        """
        self.original_graph: stag.graph.Graph = stag.random.sbm(self.n, self.k, self.p, self.q)
        labels = stag.random.sbm_gt_labels(self.n, self.k)

        # Record each undirected edge once, from its larger endpoint.
        all_edges = []
        for v in range(self.n):
            for u in self.original_graph.neighbors_unweighted(v):
                if v < u:
                    all_edges.append((u, v))

        # reshape keeps the (0, 2) shape when the graph has no edges.
        return np.asarray(all_edges, dtype=np.int64).reshape(-1, 2), labels

    def get_num_clusters_at_time(self, t):
        """The number of clusters never changes."""
        return self.k

    def get_node_insertions_at_time(self, t):
        """All vertices are inserted at time 0; none afterwards."""
        if t == 0:
            return list(range(self.n))
        else:
            return []

    def get_node_deletions_at_time(self, t):
        """No vertex is ever deleted."""
        return []

    def get_edge_insertions_at_time(self, t):
        return self.update_schedule[t][0]

    def get_edge_deletions_at_time(self, t):
        return self.update_schedule[t][1]

    def set_iteration(self, t):
        """Set the ground-truth labels and cluster count to those of time t."""
        self.gt_labels = self.labels_by_iteration[t]
        self.num_clusters = self.get_num_clusters_at_time(t)

    def __repr__(self):
        return f"DynamicSBMChangingClusters({self.n}, {self.k})"


class DynamicSBMMergingClusters(DynamicGraphDataset):
    """
    A dynamic SBM in which disjoint pairs of clusters are merged over time.

    Time step 0 inserts a general SBM graph. Update i (0-indexed) merges the
    clusters labelled k - 2i - 2 and k - 2i - 1 by adding random edges between
    them, and also adds random Erdős–Rényi noise edges.
    """

    def __init__(self, cluster_sizes, p, q, r, s, num_updates: int):
        """
        Initializes a dynamic graph based on the stochastic block model (SBM).

        The sequence of edge updates is based on the experiment on merging clusters in LS24.

        Parameters:
        - cluster_sizes: a list of cluster sizes to generate; k = len(cluster_sizes).
        - p (float): Probability of each edge inside a cluster of the initial graph.
        - q (float): Probability of each edge between two different clusters of the initial graph.
        - r (float): Probability of each edge added between two clusters being merged.
        - s (float): Edge probability of the Erdős–Rényi noise added at each update (at most 1).
        - num_updates (int): Number of merges; at most k // 2, since each merge uses a fresh pair.

        Raises:
        - ValueError: if num_updates > k // 2, or if s > 1.
        """
        self.sizes = np.asarray(cluster_sizes)
        self.k = len(cluster_sizes)
        self.n = sum(cluster_sizes)
        if num_updates > self.k // 2:
            # Further merges would refer to negative (non-existent) cluster labels.
            raise ValueError("num_updates must be at most half the number of clusters.")
        super().__init__(num_updates + 1, self.n)
        self.p = p
        self.q = q
        self.r = r
        self.s = s

        # Prepare the updates to the dataset
        initial_edges, initial_labels = self._generate_sbm_graph()
        self.update_schedule = [initial_edges]
        self.labels_by_iteration = [initial_labels]
        clusters_to_merge = (self.k - 2, self.k - 1)
        for i in range(num_updates):
            # Work on a real copy: _merge_clusters writes into the labels it is
            # given, and [:] on a numpy array would only be a view of the earlier labels.
            previous_labels = np.array(self.labels_by_iteration[i], copy=True)
            edges_to_add, true_labels = self._merge_clusters(
                clusters_to_merge[0], clusters_to_merge[1], self.r, previous_labels)
            random_edges_to_add = self._overlay_erdos_renyi_optimized(s)
            self.update_schedule.append(np.vstack((self._as_edge_array(edges_to_add),
                                                   self._as_edge_array(random_edges_to_add))))
            self.labels_by_iteration.append(true_labels)
            # Move on to the next disjoint pair of clusters.
            clusters_to_merge = (clusters_to_merge[0] - 2, clusters_to_merge[1] - 2)

        # Set the ground truth to be the initial labels
        self.gt_labels = initial_labels

    @staticmethod
    def _as_edge_array(edges):
        """
        Return edges as an (m, 2) int64 array.

        Unlike np.asarray alone, this keeps the (0, 2) shape for an empty edge
        list, so the result can always be passed to np.vstack.
        """
        return np.asarray(edges, dtype=np.int64).reshape(-1, 2)

    def _generate_sbm_graph(self):
        """
        Generates a graph from a stochastic block model with given parameters.

        Returns:
        - edges (numpy.ndarray): An (m, 2) array with one row (u, v), u > v, per edge.
        - labels: The cluster of each vertex, as returned by stag.random.general_sbm_gt_labels.
        """
        # Probability q between every pair of clusters, p inside each cluster.
        # dtype=float so that an integer q does not truncate the diagonal p.
        probs = np.full((self.k, self.k), self.q, dtype=float)
        np.fill_diagonal(probs, self.p)

        graph: stag.graph.Graph = stag.random.general_sbm(self.sizes, probs)
        labels = stag.random.general_sbm_gt_labels(self.sizes)

        # Record each undirected edge once, from its larger endpoint.
        all_edges = []
        for v in range(graph.number_of_vertices()):
            for u in graph.neighbors_unweighted(v):
                if v < u:
                    all_edges.append((u, v))

        return self._as_edge_array(all_edges), labels

    def _overlay_erdos_renyi_optimized(self, s):
        """
        Sample a set of distinct random edges whose size matches the Erdős–Rényi model parameter s.

        Exactly round(s * N * (N - 1) / 2) distinct non-loop edges are drawn uniformly.

        Parameters:
        - s (float): Probability of an edge existing between any two vertices.

        Returns:
        - new_edges (list): A list of distinct (u, v) pairs with u < v.

        Raises:
        - ValueError: if s > 1, since that many distinct edges do not exist.
        """
        N = self.num_vertices
        # Calculate the expected number of new edges to add
        total_possible_edges = N * (N - 1) // 2
        num_edges_to_add = int(np.round(s * total_possible_edges))
        if num_edges_to_add > total_possible_edges:
            # Without this check the rejection loop below would never terminate.
            raise ValueError("s must be at most 1.")

        new_edges = set()
        while len(new_edges) < num_edges_to_add:
            u, v = np.random.randint(0, N, size=2)
            if u == v:
                # Self-loops are not among the N * (N - 1) / 2 possible edges.
                continue
            new_edges.add((min(u, v), max(u, v)))  # Ensure consistency in edge direction

        return list(new_edges)

    def _merge_clusters(self, label_a, label_b, r, current_labels):
        """
        Sample edges between two clusters with probability r and relabel the
        second cluster as the first, without directly updating any graph.

        Parameters:
        - label_a (int): The label of the first cluster; the merged cluster keeps this label.
        - label_b (int): The label of the second cluster.
        - r (float): The probability of adding an edge between any two vertices in the two clusters.
        - current_labels (numpy.ndarray): The labels before the merge. Modified in place.

        Returns:
        - sampled_edges (numpy.ndarray): An (m, 2) array with one row (a, b) per sampled edge.
        - new_labels (numpy.ndarray): current_labels, updated to reflect the merge.
        """
        # Find vertices belonging to each cluster
        vertices_a = np.where(current_labels == label_a)[0]
        vertices_b = np.where(current_labels == label_b)[0]

        # Draw one uniform number per (a, b) pair, row by row; this consumes the
        # random stream in the same order as a nested loop over vertices_b.
        sampled_edges = []
        for v_a in vertices_a:
            keep = np.random.rand(len(vertices_b)) < r
            sampled_edges.extend((v_a, v_b) for v_b in vertices_b[keep])

        # Update labels to reflect the merged cluster
        current_labels[vertices_b] = label_a

        return self._as_edge_array(sampled_edges), current_labels

    def get_num_clusters_at_time(self, t):
        """Each update merges two clusters into one."""
        return self.k - t

    def get_node_insertions_at_time(self, t):
        """All vertices are inserted at time 0; none afterwards."""
        if t == 0:
            return list(range(self.num_vertices))
        else:
            return []

    def get_node_deletions_at_time(self, t):
        """No vertex is ever deleted."""
        return []

    def get_edge_insertions_at_time(self, t):
        """Return the (m, 2) array of edges inserted at time t."""
        return self.update_schedule[t]

    def get_edge_deletions_at_time(self, t):
        """No edge is ever deleted."""
        return []

    def set_iteration(self, t):
        """Set the ground-truth labels and cluster count to those of time t."""
        self.gt_labels = self.labels_by_iteration[t]
        self.num_clusters = self.get_num_clusters_at_time(t)

    def __repr__(self):
        return f"DynamicSBMMergingClusters({self.n}, {self.k})"


class DynamicKNNDataset(DynamicGraphDataset):
    """
    A dynamic k-nearest-neighbour graph over a labelled set of points.

    The classes are visited in a random order, and the points are inserted in that
    order in batches of ``batch_size`` vertices. Each batch brings the k-NN edges
    between its vertices and the vertices already present. The deletion steps remove
    the batches again in reverse order until only the first batch remains; they are
    counted in num_updates only when ``include_deletions`` is set.
    """

    def __init__(self, X, y, k=5, batch_size=1000, name=None, include_deletions=False):
        """
        Parameters:
        - X: The data points, one per row.
        - y (numpy.ndarray): The integer class label of each point.
        - k (int): The number of nearest neighbours of each point, not counting the point itself.
        - batch_size (int): The number of vertices inserted at each time step. Points after
          the last complete batch are never inserted.
        - name (str or None): If given, the neighbour indices are cached in
          data/{name}_{k}_nns.pkl. The cache is only meaningful for exactly the same X.
        - include_deletions (bool): Whether the update sequence includes the deletion steps.
        """
        self.k = k
        self.include_deletions = include_deletions
        self.gt_labels = y.tolist()
        self.all_labels = y.tolist()

        # Load or compute the nearest neighbors
        self.nns = self._load_or_construct_knn_graph(X, name)

        self.batch_size = batch_size
        num_classes = max(self.all_labels) + 1
        self.num_batches = len(self.all_labels) // self.batch_size

        # With deletions, every batch except the first is also removed again.
        num_updates = (2 * self.num_batches) - 1 if self.include_deletions else self.num_batches

        super().__init__(num_updates, len(self.all_labels))

        # Prepare the order in which classes will be added to the dataset
        self.class_update_order = list(range(num_classes))
        random.shuffle(self.class_update_order)
        self.node_update_order = []
        for this_cluster in self.class_update_order:
            self.node_update_order.extend([i for i, label in enumerate(self.all_labels) if label == this_cluster])

        # Prepare the insertions
        self.node_insertions_by_time = {}
        self.edge_insertions_by_time = {}
        self.node_deletions_by_time = {}
        self.edge_deletions_by_time = {}
        self.num_clusters_by_time = []
        unique_clusters_so_far = set()
        # Vertices inserted at earlier time steps, grown incrementally instead of
        # being rebuilt from a prefix of node_update_order at every step.
        inserted_so_far = set()
        for t in range(self.num_batches):
            batch = self.node_update_order[t * self.batch_size:(t + 1) * self.batch_size]
            self.node_insertions_by_time[t] = batch
            new_edges = self._get_edges_to_add(inserted_so_far, set(batch))
            # reshape keeps the (0, 2) shape when a batch brings no edges.
            self.edge_insertions_by_time[t] = np.asarray(new_edges, dtype=np.int64).reshape(-1, 2)
            inserted_so_far.update(batch)
            unique_clusters_so_far.update(self.all_labels[i] for i in batch)
            self.num_clusters_by_time.append(len(unique_clusters_so_far))
            self.node_deletions_by_time[t] = []
            self.edge_deletions_by_time[t] = []

        # Prepare the deletions: time num_batches + t removes batch num_batches - 1 - t,
        # undoing exactly the vertices and edges that batch inserted.
        for t in range(self.num_batches - 1):
            removed_batch = self.num_batches - 1 - t
            self.node_insertions_by_time[self.num_batches + t] = []
            self.edge_insertions_by_time[self.num_batches + t] = []
            self.node_deletions_by_time[self.num_batches + t] = self.node_insertions_by_time[removed_batch]
            self.edge_deletions_by_time[self.num_batches + t] = self.edge_insertions_by_time[removed_batch]
            # Batches 0 .. removed_batch - 1 remain.
            self.num_clusters_by_time.append(self.num_clusters_by_time[removed_batch - 1])

        # Set the ground truth to the labels at time 0 (the first batch), matching set_iteration(0).
        self.gt_labels = self._labels_of_first_batches(0) if self.num_batches > 0 else []

    def _load_or_construct_knn_graph(self, X, name):
        """
        Return the neighbour indices of X, using the on-disk cache when name is given.

        A cached array whose number of rows differs from len(X) cannot describe X,
        so it is recomputed (and the cache overwritten) rather than used silently.
        """
        if name is None:
            return self._construct_knn_graph(X)

        data_filename = f'data/{name}_{self.k}_nns.pkl'
        if os.path.exists(data_filename):
            # NOTE: pickle.load can execute arbitrary code; only load cache files this project wrote.
            with open(data_filename, 'rb') as f:
                nns = pickle.load(f)
            if len(nns) == len(X):
                return nns

        nns = self._construct_knn_graph(X)
        os.makedirs(os.path.dirname(data_filename), exist_ok=True)
        with open(data_filename, 'wb') as f:
            pickle.dump(nns, f)
        return nns

    def _construct_knn_graph(self, X):
        """
        Return, for each row of X, the indices of its k + 1 nearest rows of X.

        The query points are X itself, so each point is normally reported as one of
        its own neighbours. This is why k + 1 neighbours are requested; the
        self-match is discarded in _get_edges_to_add.
        """
        nn = NearestNeighbors(n_neighbors=self.k + 1).fit(X)
        _, indices = nn.kneighbors(X)
        return indices

    def _get_edges_to_add(self, prev_nodes, new_nodes):
        """
        Return the k-NN edges that appear when new_nodes join a graph containing prev_nodes.

        For every new vertex i, an edge (i, j) is produced for each neighbour j of i
        that is present (previously inserted or new), except i itself.
        NOTE(review): if i and j are both new and each lists the other as a
        neighbour, both (i, j) and (j, i) are produced; confirm the consumer expects this.
        """
        edges_to_add = []

        for i in new_nodes:
            for j in self.nns[i]:
                # Skip the point's match with itself, which would be a self-loop.
                if j != i and (j in prev_nodes or j in new_nodes):
                    edges_to_add.append((i, j))

        return edges_to_add

    def _labels_of_first_batches(self, t):
        """Return the labels of the vertices in batches 0 to t, ordered by vertex id."""
        present = sorted(self.node_update_order[:(t + 1) * self.batch_size])
        return [self.all_labels[v] for v in present]

    def get_node_insertions_at_time(self, t):
        return self.node_insertions_by_time[t]

    def get_node_deletions_at_time(self, t):
        return self.node_deletions_by_time[t]

    def get_edge_insertions_at_time(self, t):
        return self.edge_insertions_by_time[t]

    def get_edge_deletions_at_time(self, t):
        return self.edge_deletions_by_time[t]

    def set_iteration(self, t):
        """Set the ground-truth labels and cluster count to those of time t."""
        if t >= self.num_batches:
            # Time num_batches + s leaves batches 0 .. num_batches - 2 - s, the same
            # vertex set as insertion time num_batches - 2 - s.
            t = (self.num_batches - 2) - (t - self.num_batches)
        self.gt_labels = self._labels_of_first_batches(t)
        self.num_clusters = self.get_num_clusters_at_time(t)

    def get_num_clusters_at_time(self, t):
        return self.num_clusters_by_time[t]


class DynamicMNISTDataset(DynamicKNNDataset):
    """Dynamic k-NN graph over the MNIST digits (OpenML dataset 'mnist_784')."""

    def __init__(self, k=5, images_per_class=None, batch_size=1000, include_deletions=False):
        """
        Parameters:
        - k (int): The number of nearest neighbours of each image.
        - images_per_class (int or None): If given, keep a random sample of at most
          this many images of each digit.
        - batch_size (int): The number of images inserted at each time step.
        - include_deletions (bool): Whether the update sequence includes the deletion steps.
        """
        data_filename = 'data/mnist.pkl'
        if os.path.exists(data_filename):
            # NOTE: pickle.load can execute arbitrary code; only load a cache file this project wrote.
            with open(data_filename, 'rb') as f:
                mnist = pickle.load(f)
        else:
            mnist = fetch_openml('mnist_784', version=1)
            os.makedirs(os.path.dirname(data_filename), exist_ok=True)
            with open(data_filename, 'wb') as f:
                pickle.dump(mnist, f)

        X = np.array(mnist.data)
        y = np.array(mnist.target).astype(int)

        # The neighbour cache is keyed only by name and k, so it is valid only for the
        # full dataset; a random subset must neither read nor overwrite it.
        cache_name = 'mnist'
        if images_per_class is not None:
            X, y = self._select_subset(X, y, images_per_class)
            cache_name = None

        super().__init__(X, y, k=k, batch_size=batch_size, name=cache_name, include_deletions=include_deletions)

    @staticmethod
    def _select_subset(X, y, images_per_class):
        """
        Keep a uniformly random sample of at most images_per_class points of each label.

        Labels with fewer points keep all of them. The result is grouped by label in
        ascending label order.
        """
        unique_labels = np.unique(y)
        selected_indices = []

        for label in unique_labels:
            indices = np.where(y == label)[0]
            if len(indices) > images_per_class:
                selected_indices.extend(np.random.choice(indices, images_per_class, replace=False))
            else:
                selected_indices.extend(indices)

        return X[selected_indices], y[selected_indices]

    def __repr__(self):
        return 'DynamicMNISTDataset'


class DynamicEMNISTDataset(DynamicKNNDataset):
    """Dynamic k-NN graph over the EMNIST 'letters' split (training and test sets combined)."""

    def __init__(self, k=5, images_per_class=None, batch_size=1000, include_deletions=False):
        """
        Parameters:
        - k (int): The number of nearest neighbours of each image.
        - images_per_class (int or None): If given, keep a random sample of at most
          this many images of each letter.
        - batch_size (int): The number of images inserted at each time step.
        - include_deletions (bool): Whether the update sequence includes the deletion steps.
        """
        # NOTE(review): EMNIST 'letters' labels are believed to be 1..26, which would make
        # class 0 empty in DynamicKNNDataset; this is harmless, but confirm.
        train_images, train_labels = emnist.extract_training_samples('letters')
        test_images, test_labels = emnist.extract_test_samples('letters')

        # Combine training and test sets, flattening each image into one row.
        X = np.concatenate((train_images, test_images))
        y = np.concatenate((train_labels, test_labels))
        X = X.reshape((X.shape[0], -1))

        # The neighbour cache is keyed only by name and k, so it is valid only for the
        # full dataset; a random subset must neither read nor overwrite it.
        cache_name = 'emnist'
        if images_per_class is not None:
            X, y = DynamicMNISTDataset._select_subset(X, y, images_per_class)
            cache_name = None

        super().__init__(X, y, k=k, batch_size=batch_size, name=cache_name, include_deletions=include_deletions)

    def __repr__(self):
        return 'DynamicEMNISTDataset'

