import math
import random

import networkit as nw
import networkx as nx
import numpy as np
from sklearn.neighbors import KDTree

# Seed both the stdlib and NumPy global RNGs at import time for
# reproducibility: the initial embedding is drawn with np.random.rand and
# the node visiting order in coarsen() is produced by random.shuffle.
random.seed(42)
np.random.seed(42)


class EmbeddingCoarsening:
    """
    Class for graph coarsening using node embeddings.

    Nodes are placed in a ``d``-dimensional space, their positions are
    iteratively refined, the working graph is optionally sparsified, and
    nodes are then paired up. Every pair (plus at most one leftover node)
    becomes a single node of the coarse graph ``cG``.
    """

    # Embedding refinement in ``coarsen`` stops once the average change per
    # sweep is no larger than the tolerance, or after this many sweeps.
    _EMBED_TOLERANCE = 0.01
    _EMBED_MAX_ITERATIONS = 31
    # Upper bound on the nearest-neighbour candidates examined per node
    # when pairing nodes in ``match``.
    _MATCH_CANDIDATES = 40

    def __init__(self, G, d, shape, ratio):
        """
        Initialize the EmbeddingCoarsening instance.

        Parameters:
        - G: networkit Graph, the graph to coarsen
        - d: int, dimension of the embedding space
        - shape: str, shape parameter (unused in the current implementation)
        - ratio: float, ratio of edges to sparsify
        """
        self.G = G
        # Weighted working copy; sparsify() mutates this, never ``G``.
        self.sG = nw.graphtools.toWeighted(G)
        self.d = d
        self.n = G.numberOfNodes()
        # One random starting position per node, one row each.
        self.space = np.random.rand(self.n, d)
        self.shape = shape  # Currently not used in the code
        self.M = set()      # Set to store matched node pairs
        self.R = -1         # Remaining node if the number of nodes is odd
        self.ratio = ratio  # Ratio of edges to remove during sparsification

    def _closest_adjacent_edge(self, node, excluded, edge_map):
        """
        Return the edge ``(node, x)`` with the smallest ``edge_map`` value
        among the current neighbors ``x`` of ``node`` other than
        ``excluded``, or ``None`` if there is no such neighbor.
        """
        best = None
        for x in self.sG.iterNeighbors(node):
            if x == excluded:
                continue
            candidate = (node, x)
            if best is None or edge_map[candidate] < edge_map[best]:
                best = candidate
        return best

    def sparsify(self):
        """
        Sparsify the working graph ``sG`` by removing a fraction of its edges.

        Edges are ranked by ``weight * embedding distance`` and the
        lowest-ranked ones are removed. The weight of each removed edge is
        added to the lowest-ranked edge adjacent to either endpoint, so the
        total edge weight is preserved whenever such an adjacent edge exists.
        """
        if self.ratio == 0:
            return

        # Rank every edge by its weighted length in the embedding space.
        edge_distances = []
        edge_map = {}
        for u, v in self.sG.iterEdges():
            w = self.sG.weight(u, v)
            distance = w * np.linalg.norm(self.space[u] - self.space[v])
            edge_distances.append((distance, u, v))
            edge_map[(u, v)] = distance
            edge_map[(v, u)] = distance  # Since the graph is undirected

        edge_distances.sort()

        # Never try to remove more edges than exist (guards ratio > 1).
        remove_count = min(
            int(self.ratio * self.sG.numberOfEdges()),
            len(edge_distances),
        )

        for i in range(remove_count):
            _, u, v = edge_distances[i]
            w = self.sG.weight(u, v)

            min_edge_u = self._closest_adjacent_edge(u, v, edge_map)
            min_edge_v = self._closest_adjacent_edge(v, u, edge_map)

            # Prefer u's candidate only when it is strictly lower-ranked
            # than v's (or v has none); otherwise fall back to v's.
            if min_edge_u is not None and (
                min_edge_v is None
                or edge_map[min_edge_u] < edge_map[min_edge_v]
            ):
                target = min_edge_u
            else:
                target = min_edge_v

            # Reweight to preserve total edge weight; an isolated edge has
            # no adjacent edge and its weight is simply dropped.
            if target is not None:
                t1, t2 = target
                if self.sG.weight(t1, t2) != 0:
                    self.sG.increaseWeight(t1, t2, w)

            self.sG.removeEdge(u, v)

    def optimal(self, u):
        """
        Compute a new embedding position for node ``u`` from its neighbors.

        With ``k = 2 * weightedDegree(u)`` and ``S`` the sum of all squared
        components of ``2 * w(u, v) * space[v]`` over the neighbors ``v``,
        the roots of ``lambda^2 - 2*k*lambda + (k^2 - S) = 0`` give up to two
        candidate positions ``sum_v(2 * w(u, v) * space[v]) / (k - lambda)``.
        The candidate with the smallest summed Euclidean distance to the
        neighbors is returned.

        NOTE(review): the original description speaks of maximizing
        neighbor distances, while the selection picks the minimal total
        distance; confirm which is intended.

        Parameters:
        - u: int, node index

        Returns:
        - res: np.ndarray, new position of the node in the embedding space
          (the current position if no candidate exists)
        - change: float, magnitude of change from the previous position
        """
        neighbors = list(self.sG.iterNeighbors(u))
        scaled = [2 * self.sG.weight(u, v) * self.space[v] for v in neighbors]

        k = 2 * self.sG.weightedDegree(u)
        a = 1
        b = -2 * k
        c = k ** 2 - np.sum([s ** 2 for s in scaled])

        discriminant = b ** 2 - 4 * a * c
        if discriminant < 0:
            # No real roots, return current position
            return self.space[u], 0

        sqrt_discriminant = np.sqrt(discriminant)
        lambda1 = (-b + sqrt_discriminant) / (2 * a)
        lambda2 = (-b - sqrt_discriminant) / (2 * a)

        # The numerator is the same for both roots, so compute it once.
        weighted_sum = np.sum(scaled, axis=0)
        candidates = [
            weighted_sum / p
            for p in (k - lambda1, k - lambda2)
            if p != 0  # a zero denominator yields no candidate
        ]

        # Choose the position that results in minimal total distance to
        # neighbors; keep the current position when there is no candidate.
        min_total_distance = float('inf')
        best_position = self.space[u]
        for pos in candidates:
            total_distance = np.sum([
                np.linalg.norm(pos - self.space[v]) for v in neighbors
            ])
            if total_distance < min_total_distance:
                min_total_distance = total_distance
                best_position = pos

        change = np.linalg.norm(best_position - self.space[u])
        return best_position, change

    def embed(self, nodes):
        """
        Update the embedding position of each given node, in order, using
        ``optimal``. Positions are updated in place, so later nodes see the
        new positions of earlier ones.

        Parameters:
        - nodes: list of node indices

        Returns:
        - avg_change: float, total position change divided by the number of
          nodes in the graph (0.0 for an empty graph)
        """
        total_change = 0
        for i in nodes:
            res, change = self.optimal(i)
            self.space[i] = res
            total_change += change

        # An empty graph has nothing to move; avoid dividing by zero.
        if self.n == 0:
            return 0.0
        return total_change / self.n

    def match(self):
        """
        Greedily match each vertex with the nearest neighbor in the embedding.

        Nodes sharing exactly the same position are paired first. The
        remaining nodes are paired with a nearby, non-adjacent node among
        their nearest neighbors, and anything still unmatched is paired in
        index order. Pairs are added to ``self.M``; if one node is left over
        it is stored in ``self.R``.
        """
        n = self.sG.numberOfNodes()
        if n == 0:
            # Nothing to match, and KDTree cannot be built on zero points.
            return

        tree = KDTree(self.space)
        # r=0 yields, for every node, the nodes at exactly the same position.
        indices = tree.query_radius(self.space, r=0)

        used = set()
        clusters = []
        singletons = []

        # Group nodes based on proximity in the embedding space
        for idx_list in indices:
            if idx_list[0] in used:
                continue  # this co-located group was already recorded
            if len(idx_list) == 1:
                singletons.append(idx_list[0])
            else:
                clusters.append(idx_list)
                used.update(idx_list)

        # Pair nodes within clusters; from here on ``used`` tracks nodes
        # that already have a final assignment.
        used.clear()
        for cluster in clusters:
            k = len(cluster)
            if k % 2 == 1:
                # Odd cluster: the last member joins the singleton pool.
                singletons.append(cluster[-1])
                k -= 1
            for i in range(0, k, 2):
                self.M.add((cluster[i], cluster[i + 1]))
                used.update([cluster[i], cluster[i + 1]])

        # Handle singletons: with an odd count, one node stays unpaired.
        k = len(singletons)
        if k % 2 == 1:
            self.R = singletons.pop()
            used.add(self.R)
            k -= 1

        if k == 0:
            return

        indices = list(singletons)
        new_space = np.array([self.space[idx] for idx in singletons])
        tree = KDTree(new_space)
        neighbors = tree.query(
            new_space,
            k=min(self._MATCH_CANDIDATES, k),
            return_distance=False,
        )

        # Match remaining nodes based on nearest neighbors, skipping
        # candidates that are already taken or adjacent in the graph.
        for i, neighbor_indices in enumerate(neighbors):
            idx = indices[i]
            if idx in used:
                continue
            for j in neighbor_indices:
                neighbor_idx = indices[j]
                if neighbor_idx in used or idx == neighbor_idx:
                    continue
                if not self.sG.hasEdge(idx, neighbor_idx):
                    self.M.add((idx, neighbor_idx))
                    used.update([idx, neighbor_idx])
                    break

        # Pair any remaining unmatched nodes
        unmatched = [i for i in range(n) if i not in used]
        for i in range(0, len(unmatched) - 1, 2):
            self.M.add((unmatched[i], unmatched[i + 1]))

    def coarsen(self):
        """
        Construct a coarse graph from the matching of nodes.

        The coarsening process involves embedding optimization, sparsification,
        node matching, and building the coarse graph. Afterwards
        ``mapCoarseToFine`` / ``mapFineToCoarse`` hold the node mappings and
        ``cG`` holds the weighted, undirected coarse graph.
        """
        # Optimize node positions in the embedding space, visiting nodes in
        # a fixed random order until the average change is small enough or
        # the sweep budget is exhausted.
        nodes = list(range(self.n))
        random.shuffle(nodes)
        change = self.embed(nodes)
        iteration = 1
        while (
            change > self._EMBED_TOLERANCE
            and iteration < self._EMBED_MAX_ITERATIONS
        ):
            change = self.embed(nodes)
            iteration += 1

        # Sparsify the graph
        self.sparsify()

        # Match nodes to create the coarse graph
        self.match()

        # Initialize mappings between coarse and fine nodes
        idx = 0
        self.mapCoarseToFine = {}
        self.mapFineToCoarse = {}

        for u, v in self.M:
            self.mapCoarseToFine[idx] = [u, v]
            self.mapFineToCoarse[u] = idx
            self.mapFineToCoarse[v] = idx
            idx += 1

        if self.R != -1:
            self.mapCoarseToFine[idx] = [self.R]
            self.mapFineToCoarse[self.R] = idx
            idx += 1

        # Build the coarse graph from the sparsified graph. The weights must
        # come from ``sG`` as well: sparsify() moves the weight of removed
        # edges onto surviving edges of ``sG``, and reading the original
        # weights from ``G`` would silently discard that redistribution.
        self.cG = nw.graph.Graph(n=idx, weighted=True, directed=False)
        for u, v in self.sG.iterEdges():
            cu = self.mapFineToCoarse[u]
            cv = self.mapFineToCoarse[v]
            if cu != cv:
                self.cG.increaseWeight(cu, cv, self.sG.weight(u, v))

        self.cG.removeSelfLoops()
        self.cG.indexEdges()
