# This file contains functions to partition large graphs

import torch_sparse
import more_itertools
import numpy as np
import torch
import networkx as nx
from sklearn.cluster import SpectralClustering

from SEPAL import SEPAL_DIR

path_to_hdrf = SEPAL_DIR / "baselines/hdrf/build"
import sys

sys.path.append(str(path_to_hdrf))
import hdrf


def streaming_hdrf_partitions(graph, num_partitions, lambda_value=1):
    """Partition the entities of ``graph`` by streaming its edges through HDRF.

    Every row of ``graph.triples_factory.mapped_triples`` is unpacked as
    ``(src, relation, dest)`` and handed to the HDRF partitioner one edge at
    a time. Both endpoints are then recorded as members of the partition the
    edge was assigned to, so one entity may appear in several partitions.

    Parameters
    ----------
    graph
        Object exposing ``triples_factory.mapped_triples``.
    num_partitions : int
        Number of partitions, forwarded to ``hdrf.HDRF``.
    lambda_value
        Second argument forwarded to ``hdrf.HDRF``
        (presumably the HDRF balance parameter -- TODO confirm).

    Returns
    -------
    list of numpy.ndarray
        One array of entity ids per partition, ordered by partition id.
    """
    print(f"Partitioning the graph into {num_partitions} partitions.")
    hdrf_instance = hdrf.HDRF(num_partitions, lambda_value)

    triples = np.array(graph.triples_factory.mapped_triples)

    # One set of member entities per partition.
    partitions = [set() for _ in range(num_partitions)]
    for src, _, dest in triples:
        partition_id = hdrf_instance.partitionEdge(src, dest)
        partitions[partition_id].add(src)
        partitions[partition_id].add(dest)

    # Pin the dtype to the one of the triples: without it, a partition that
    # received no edge would turn into an empty *float64* array, which is
    # inconsistent with the other partitions and cannot be used for indexing.
    partitions = [
        np.array(list(partition), dtype=triples.dtype) for partition in partitions
    ]

    sizes = [a.shape[0] for a in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")

    return partitions


def hdrf_partitions(graph, num_partitions, lambda_val=1):
    """Partition the entities of ``graph`` with one batch HDRF call.

    The nonzero entries of ``graph.adjacency`` form the edge list. HDRF
    assigns each edge to a partition; a partition is then returned as the
    unique ids of the endpoints of the edges assigned to it.

    Parameters
    ----------
    graph
        Object exposing an ``adjacency`` attribute with a ``nonzero()`` method.
    num_partitions : int
        Number of partitions, forwarded to ``hdrf.HDRF``.
    lambda_val
        Second argument forwarded to ``hdrf.HDRF``.

    Returns
    -------
    list of numpy.ndarray
        One array of entity ids per partition, ordered by partition id.
    """
    print(f"Partitioning the graph into {num_partitions} partitions.")

    sources, targets = graph.adjacency.nonzero()
    edge_list = list(zip(sources.tolist(), targets.tolist()))

    # HDRF returns one partition id per edge, in edge order.
    assignment = np.array(
        hdrf.HDRF(num_partitions, lambda_val).partitionGraph(edge_list)
    )
    edge_array = np.array(edge_list)

    partitions = []
    for part_id in range(num_partitions):
        edge_ids = np.nonzero(assignment == part_id)[0]
        # np.unique flattens the selected (src, dest) rows into entity ids.
        partitions.append(np.unique(edge_array[edge_ids]))

    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")

    return partitions


def leiden_partitions_ig(ig_graph, resolution_parameter):
    """Partition ``ig_graph`` using its ``community_leiden`` method.

    Parameters
    ----------
    ig_graph
        Graph object providing ``community_leiden``.
    resolution_parameter
        Forwarded to ``community_leiden`` as ``resolution_parameter``.

    Returns
    -------
    list of numpy.ndarray
        One array per detected community, in the order they are yielded.
    """
    communities = ig_graph.community_leiden(
        resolution_parameter=resolution_parameter
    )
    partitions = list(map(np.array, communities))
    print(f"Partitioned the graph into {len(partitions)} partitions.")
    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")
    return partitions


def label_propagation_partitions_ig(ig_graph):
    """Partition ``ig_graph`` using its ``community_label_propagation`` method.

    Returns
    -------
    list of numpy.ndarray
        One array per detected community, in the order they are yielded.
    """
    partitions = []
    for members in ig_graph.community_label_propagation():
        partitions.append(np.array(members))
    print(f"Partitioned the graph into {len(partitions)} partitions.")
    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")
    return partitions


def infomap_partitions_ig(ig_graph):
    """Partition ``ig_graph`` using its ``community_infomap`` method.

    Returns
    -------
    list of numpy.ndarray
        One array per detected community, in the order they are yielded.
    """
    communities = ig_graph.community_infomap()
    partitions = [np.array(members) for members in communities]
    print(f"Partitioned the graph into {len(partitions)} partitions.")
    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")
    return partitions


def louvain_partitions_ig(ig_graph):
    """Partition ``ig_graph`` using its ``community_multilevel`` method.

    The edge attribute ``"weight"`` (read from ``ig_graph.es``) is passed as
    the edge weights, and only the final level is requested
    (``return_levels=False``).

    Returns
    -------
    list of numpy.ndarray
        One array per detected community, in the order they are yielded.
    """
    edge_weights = ig_graph.es["weight"]
    communities = ig_graph.community_multilevel(
        weights=edge_weights, return_levels=False
    )
    partitions = list(map(np.array, communities))
    print(f"Partitioned the graph into {len(partitions)} partitions.")
    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")
    return partitions


def louvain_partitions_nx(nx_graph):
    """Partition ``nx_graph`` with ``nx.community.louvain_communities``.

    Returns
    -------
    list of numpy.ndarray
        One array of node ids per detected community.
    """
    communities = nx.community.louvain_communities(nx_graph)
    partitions = []
    for members in communities:
        # Each community is materialised as a list before conversion.
        partitions.append(np.array(list(members)))
    print(f"Partitioned the graph into {len(partitions)} partitions.")
    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")
    return partitions


def spectral_partitions(graph, num_partitions):
    """Partition the entities of ``graph`` with spectral clustering.

    ``graph.adjacency`` is given to scikit-learn's ``SpectralClustering`` as
    a precomputed affinity matrix.

    Parameters
    ----------
    graph
        Object exposing an ``adjacency`` attribute.
    num_partitions : int
        Number of clusters requested.

    Returns
    -------
    list of numpy.ndarray
        For each cluster id in ``range(num_partitions)``, the indices of the
        entities assigned to it.
    """
    print(f"Partitioning the graph into {num_partitions} partitions.")

    model = SpectralClustering(
        n_clusters=num_partitions,
        eigen_solver="amg",
        affinity="precomputed",
        assign_labels="cluster_qr",
    )
    labels = model.fit_predict(graph.adjacency)

    partitions = []
    for cluster_id in range(num_partitions):
        member_ids = np.nonzero(labels == cluster_id)[0]
        partitions.append(member_ids)

    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")

    return partitions


def leading_eigenvector_partitions_ig(ig_graph, num_partitions):
    """Partition ``ig_graph`` using its ``community_leading_eigenvector`` method.

    Parameters
    ----------
    ig_graph
        Graph object providing ``community_leading_eigenvector``.
    num_partitions
        Forwarded to ``community_leading_eigenvector`` as ``clusters``.

    Returns
    -------
    list of numpy.ndarray
        One array per detected community, in the order they are yielded.
    """
    communities = ig_graph.community_leading_eigenvector(clusters=num_partitions)
    partitions = [np.array(members) for members in communities]
    print(f"Partitioned the graph into {len(partitions)} partitions.")
    sizes = [part.shape[0] for part in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")
    return partitions


def metis_partitions(graph, num_partitions, inverse_triples_created=True):
    """Partition the entities of ``graph`` with METIS (via ``torch_sparse``).

    Parameters
    ----------
    graph
        Object providing ``get_edge_index_and_type()`` (its first return
        value is used as a ``(2, num_edges)`` edge-index tensor) and
        ``num_entities``.
    num_partitions : int
        Number of parts requested from ``torch_sparse.partition``.
    inverse_triples_created : bool
        When False, the reversed edges are appended to the edge index before
        partitioning. When True, the edge index is used as given
        (presumably it already contains both directions -- verify against
        the caller).

    Returns
    -------
    list of numpy.ndarray
        One array of entity ids per partition. A partition that METIS leaves
        empty is returned as an empty array.
    """
    edge_index, _ = graph.get_edge_index_and_type()
    num_entities = graph.num_entities

    print(f"Partitioning the graph into {num_partitions} partitions.")

    # To prevent possible segfaults in the METIS C code, METIS expects a graph
    # (1) without self-loops; (2) with inverse edges added; (3) with unique edges only
    # https://github.com/KarypisLab/METIS/blob/94c03a6e2d1860128c2d0675cbbb86ad4f261256/libmetis/checkgraph.c#L18

    # (1) Remove self-loops. Entities that end up without any edge are still
    # part of the graph because the sparse sizes are set from num_entities.
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]

    # (2) Add inverse edges if necessary
    if not inverse_triples_created:
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)

    # (3) Remove duplicates
    row, col = edge_index.unique(dim=1)

    # The re-ordered adjacency returned first is not needed here.
    _, bound, perm = torch_sparse.partition(
        src=torch_sparse.SparseTensor(
            row=row, col=col, sparse_sizes=(num_entities, num_entities)
        ).to(device="cpu"),
        num_parts=num_partitions,
        recursive=True,
    )

    sizes = bound.diff()
    print(f"Partition sizes: min: {sizes.min().item()}, max: {sizes.max().item()}")

    # perm[i] is the original entity id stored at position i of the
    # re-ordered graph, and consecutive entries of bound delimit one
    # partition in that ordering. Plain slicing performs the lookup and,
    # unlike np.vectorize, also works for an empty partition.
    entity_ids = perm.numpy()
    partitions = [
        entity_ids[low:high] for low, high in more_itertools.pairwise(bound.tolist())
    ]

    return partitions
