# This file contains a few functions that have been tested to partition large graphs

import torch_sparse
import more_itertools
import numpy as np
import torch
import scipy.sparse
import networkx as nx
import random
from sklearn.preprocessing import quantile_transform
from sklearn.decomposition import PCA  # , RandomizedPCA
from sklearn.cluster import SpectralClustering, BisectingKMeans



def leiden_partitions_ig(ig_graph, resolution_parameter=1e-3):
    """Partition a graph with igraph's Leiden community detection.

    Args:
        ig_graph: Object exposing ``community_leiden(resolution_parameter=...)``
            whose result iterates over communities as sequences of vertex ids
            (presumably an ``igraph.Graph`` -- confirm against callers).
        resolution_parameter: Value forwarded to ``community_leiden``. The
            default is the value that used to be hard-coded here.

    Returns:
        A list holding one numpy array of vertex ids per community.
    """
    communities = ig_graph.community_leiden(
        resolution_parameter=resolution_parameter
    )
    partitions = [np.array(community) for community in communities]
    print(f"Partitioned the graph into {len(partitions)} partitions.")

    sizes = [members.shape[0] for members in partitions]
    # `default=0` keeps the size report from raising when no community was
    # returned at all (e.g. a graph without vertices).
    smallest = min(sizes, default=0)
    largest = max(sizes, default=0)
    print(f"Partition sizes: min: {smallest}, max: {largest}")
    return partitions


def leading_eigenvector_partitions_ig(ig_graph):
    """Partition a graph with igraph's leading-eigenvector community detection.

    Args:
        ig_graph: Object exposing ``community_leading_eigenvector()`` whose
            result iterates over communities as sequences of vertex ids.

    Returns:
        A list holding one numpy array of vertex ids per community.
    """
    partitions = []
    for community in ig_graph.community_leading_eigenvector():
        partitions.append(np.array(community))

    n_found = len(partitions)
    print(f"Partitioned the graph into {n_found} partitions.")

    # Report the spread of community sizes.
    sizes = [members.shape[0] for members in partitions]
    smallest = min(sizes)
    largest = max(sizes)
    print(f"Partition sizes: min: {smallest}, max: {largest}")
    return partitions


def label_propagation_partitions_ig(ig_graph):
    """Partition a graph with igraph's label-propagation community detection.

    Args:
        ig_graph: Object exposing ``community_label_propagation()`` whose
            result iterates over communities as sequences of vertex ids.

    Returns:
        A list holding one numpy array of vertex ids per community.
    """
    clustering = ig_graph.community_label_propagation()
    partitions = list(map(np.array, clustering))
    print(f"Partitioned the graph into {len(partitions)} partitions.")

    sizes = []
    for members in partitions:
        sizes.append(members.shape[0])
    lower, upper = min(sizes), max(sizes)
    print(f"Partition sizes: min: {lower}, max: {upper}")
    return partitions


def infomap_partitions_ig(ig_graph):
    """Partition a graph with igraph's Infomap community detection.

    Args:
        ig_graph: Object exposing ``community_infomap()`` whose result
            iterates over communities as sequences of vertex ids.

    Returns:
        A list holding one numpy array of vertex ids per community.
    """
    communities = ig_graph.community_infomap()
    partitions = [np.array(vertex_ids) for vertex_ids in communities]

    count = len(partitions)
    print(f"Partitioned the graph into {count} partitions.")

    sizes = [block.shape[0] for block in partitions]
    min_size = min(sizes)
    max_size = max(sizes)
    print(f"Partition sizes: min: {min_size}, max: {max_size}")

    return partitions


def louvain_partitions_ig(ig_graph):
    """Partition a weighted graph with igraph's multilevel (Louvain) method.

    The edge attribute ``"weight"`` is read from ``ig_graph.es`` and passed as
    the weights; only the final level is requested (``return_levels=False``).

    Args:
        ig_graph: Object exposing ``es["weight"]`` and
            ``community_multilevel(weights=..., return_levels=...)`` whose
            result iterates over communities as sequences of vertex ids.

    Returns:
        A list holding one numpy array of vertex ids per community.
    """
    edge_weights = ig_graph.es["weight"]
    clustering = ig_graph.community_multilevel(
        weights=edge_weights, return_levels=False
    )

    partitions = []
    for community in clustering:
        partitions.append(np.array(community))
    print(f"Partitioned the graph into {len(partitions)} partitions.")

    sizes = [members.shape[0] for members in partitions]
    smallest, largest = min(sizes), max(sizes)
    print(f"Partition sizes: min: {smallest}, max: {largest}")
    return partitions


def louvain_partitions_nx(nx_graph):
    """Partition a graph with networkx's Louvain community detection.

    Args:
        nx_graph: Graph handed to ``nx.community.louvain_communities``.

    Returns:
        A list holding one numpy array of nodes per community; each community
        returned by networkx is turned into a list before the conversion.
    """
    partitions = []
    for community in nx.community.louvain_communities(nx_graph):
        partitions.append(np.array(list(community)))

    found = len(partitions)
    print(f"Partitioned the graph into {found} partitions.")

    sizes = [members.shape[0] for members in partitions]
    smallest = min(sizes)
    largest = max(sizes)
    print(f"Partition sizes: min: {smallest}, max: {largest}")
    return partitions


def spectral_partitions(graph, num_partitions):
    """Partition entities with scikit-learn's spectral clustering.

    ``graph.adjacency`` is used directly as the precomputed affinity matrix.

    Args:
        graph: Object exposing an ``adjacency`` attribute.
        num_partitions: Number of clusters to request.

    Returns:
        A list of ``num_partitions`` numpy index arrays, one per cluster label
        (an array is empty when no entity received that label).
    """
    print(f"Partitioning the graph into {num_partitions} partitions.")

    spectral_options = {
        "n_clusters": num_partitions,
        "eigen_solver": "amg",
        "affinity": "precomputed",
        "assign_labels": "cluster_qr",
    }
    model = SpectralClustering(**spectral_options)
    labels = model.fit_predict(graph.adjacency)

    # Collect, for every label, the indices of the entities carrying it.
    partitions = []
    for label in range(num_partitions):
        members = np.where(labels == label)[0]
        partitions.append(members)

    sizes = [members.shape[0] for members in partitions]
    smallest, largest = min(sizes), max(sizes)
    print(f"Partition sizes: min: {smallest}, max: {largest}")

    return partitions


def LPA_partitions(graph, num_partitions, n_steps=15):
    """
    LPA: Label Propagation Algorithm.
    Implements a variation (where all nodes are not initialized with their own label) of the method described in:
        'Near linear time algorithm to detect community structures in large-scale networks'
        https://arxiv.org/pdf/0709.2938.pdf

    Every entity starts with a label drawn uniformly at random from
    ``range(num_partitions)``. At each step all entities simultaneously adopt
    the label that collects the largest total adjacency weight among their
    neighbours.

    Args:
        graph: Object exposing ``num_entities`` and an ``adjacency`` attribute
            that supports ``.tocsc()`` (a scipy sparse matrix).
        num_partitions: Number of distinct labels, i.e. of returned partitions.
        n_steps: Maximum number of propagation steps. The default is the value
            that used to be hard-coded here.

    Returns:
        A list of ``num_partitions`` numpy index arrays, one per label (an
        array is empty when no entity ends up with that label).
    """
    print(f"Partitioning the graph into {num_partitions} partitions.")

    num_entities = graph.num_entities
    entity_ids = np.arange(num_entities)

    # set initial labels (one label per entity)
    labels = np.random.randint(num_partitions, size=(num_entities,))

    # propagate labels
    A = graph.adjacency.tocsc()
    for _ in range(n_steps):
        # one-hot encoding of the current labels (one 1 by row)
        x = scipy.sparse.csc_matrix(
            (np.ones(num_entities), (entity_ids, labels)),
            shape=(num_entities, num_partitions),
        )
        # votes[i, k] = total adjacency weight between i and entities labelled k
        votes = A @ x
        # keep the label shared by most neighbors; ties go to the lowest label
        # and a row without any vote yields label 0
        new_labels = np.asarray(votes.argmax(axis=1)).ravel()
        if np.array_equal(new_labels, labels):
            # Each step is a deterministic function of the labels, so once
            # they stop changing the remaining steps cannot change them.
            break
        labels = new_labels

    partitions = [np.where(labels == j)[0] for j in range(num_partitions)]
    sizes = [members.shape[0] for members in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")
    return partitions


def diffusion_map_argmax_partitions(graph, num_partitions, n_steps=7):
    """Partition entities by diffusing one random seed per partition.

    One seed entity is drawn at random for each partition, the seeds' one-hot
    columns are multiplied ``n_steps`` times by the matrix returned by
    ``graph.get_right_normalized_adjacency(self_loops=True)``, the scores are
    passed through ``quantile_transform`` (default settings), and each entity
    is assigned to the column holding its largest score.

    Args:
        graph: Object exposing ``num_entities`` and
            ``get_right_normalized_adjacency(self_loops=True)``, whose result
            must support ``.tocsc()``.
        num_partitions: Number of seeds, i.e. of returned partitions.
        n_steps: Number of propagation steps.

    Returns:
        A list of ``num_partitions`` numpy index arrays, one per column (an
        array is empty when no entity is assigned to that column).
    """
    print(f"Partitioning the graph into {num_partitions} partitions.")

    # Normalized adjacency
    transition = graph.get_right_normalized_adjacency(self_loops=True).tocsc()

    # set diffusers (1 one by column); seeds are drawn with replacement
    seed_values = np.ones(num_partitions)
    seed_rows = np.random.randint(graph.num_entities, size=(num_partitions,))
    seed_cols = np.arange(num_partitions)
    scores = scipy.sparse.csc_matrix(
        (seed_values, (seed_rows, seed_cols)),
        shape=(graph.num_entities, num_partitions),
    )

    # propagate
    for _ in range(n_steps):
        scores = transition * scores

    # Quantile normalization
    scores = quantile_transform(scores)

    # take max col by row
    best_column = np.argmax(scores, axis=1)
    assignments = np.asarray(best_column).ravel()

    partitions = []
    for column in range(num_partitions):
        partitions.append(np.where(assignments == column)[0])

    sizes = [members.shape[0] for members in partitions]
    smallest, largest = min(sizes), max(sizes)
    print(f"Partition sizes: min: {smallest}, max: {largest}")

    return partitions


def diffusion_map_kmeans_partitions(
    graph, num_partitions, n_steps=7, n_diffusers=100, n_components=10
):
    """Partition entities by clustering a low-dimensional diffusion embedding.

    Random one-hot "diffuser" columns are multiplied ``n_steps`` times by the
    matrix returned by ``graph.get_right_normalized_adjacency(self_loops=True)``,
    passed through ``quantile_transform`` (default settings), reduced with a
    randomized PCA and finally clustered with bisecting k-means.

    Args:
        graph: Object exposing ``num_entities`` and
            ``get_right_normalized_adjacency(self_loops=True)``, whose result
            must support ``.tocsc()``.
        num_partitions: Number of clusters requested from the k-means step.
        n_steps: Number of propagation steps.
        n_diffusers: Number of random diffuser columns. The default is the
            value that used to be hard-coded here.
        n_components: ``n_components`` passed to the PCA. The default is the
            value that used to be hard-coded here.

    Returns:
        A list of ``num_partitions`` numpy index arrays, one per cluster label
        (an array is empty when no entity received that label).
    """
    print(f"Partitioning the graph into {num_partitions} partitions.")

    # Normalized adjacency
    A = graph.get_right_normalized_adjacency(self_loops=True).tocsc()

    # set diffusers (1 one by column); seeds are drawn with replacement, one
    # `random.randint` call per column and in column order, so runs seeded
    # through the `random` module give the same draws as before
    x = np.zeros((graph.num_entities, n_diffusers))
    for j in range(n_diffusers):
        x[random.randint(0, graph.num_entities - 1), j] = 1

    # propagate (`@` is an explicit matrix product)
    for _ in range(n_steps):
        x = A @ x

    # Quantile normalization
    x = quantile_transform(x)

    # randomized PCA
    pca = PCA(n_components=n_components, svd_solver="randomized")
    x = pca.fit_transform(x)

    # cluster entities
    kmeans = BisectingKMeans(
        n_clusters=num_partitions,
        init="k-means++",
        bisecting_strategy="largest_cluster",
    )
    assignments = kmeans.fit_predict(x)

    partitions = [np.where(assignments == j)[0] for j in range(num_partitions)]
    sizes = [members.shape[0] for members in partitions]
    print(f"Partition sizes: min: {min(sizes)}, max: {max(sizes)}")

    return partitions


def metis_partitions(graph, num_partitions, inverse_triples_created=True):
    """Partition entities with METIS through ``torch_sparse.partition``.

    Args:
        graph: Object exposing ``edge_index`` (a tensor whose two rows are
            used as the source and target entity ids of the edges) and
            ``num_entities``.
        num_partitions: Number of parts requested from METIS.
        inverse_triples_created: Set to ``False`` when ``edge_index`` does not
            already contain the reversed copy of every edge; the reversed
            edges are then added here.

    Returns:
        A list of numpy arrays of entity ids, one per pair of consecutive
        entries of the partition boundaries returned by
        ``torch_sparse.partition``. An array is empty when its part is empty.
    """
    edge_index, num_entities = graph.edge_index, graph.num_entities

    print(f"Partitioning the graph into {num_partitions} partitions.")

    # To prevent possible segfaults in the METIS C code, METIS expects a graph
    # (1) without self-loops; (2) with inverse edges added; (3) with unique edges only
    # https://github.com/KarypisLab/METIS/blob/94c03a6e2d1860128c2d0675cbbb86ad4f261256/libmetis/checkgraph.c#L18

    # Add inverse edges if necessary
    if not inverse_triples_created:
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)

    # Remove self-loops (requirement (1) above)
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]

    # Remove duplicates
    row, col = edge_index.unique(dim=1)

    # Only the boundaries and the permutation are needed; the re-ordered
    # adjacency returned first is discarded.
    _, bound, perm = torch_sparse.partition(
        src=torch_sparse.SparseTensor(
            row=row, col=col, sparse_sizes=(num_entities, num_entities)
        ).to(device="cpu"),
        num_parts=num_partitions,
        recursive=True,
    )

    sizes = bound.diff()
    print(f"Partition sizes: min: {sizes.min().item()}, max: {sizes.max().item()}")

    # Positions low..high-1 of `perm` hold the entity ids of one part. Slicing
    # is equivalent to looking every position up individually and, unlike
    # `np.vectorize`, also works for an empty part (low == high).
    entity_ids = perm.numpy()
    partitions = [
        entity_ids[low:high].copy()
        for low, high in more_itertools.pairwise(bound.tolist())
    ]

    return partitions
