
import numpy as np

import igraph as ig
import leidenalg as la
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from networkx.algorithms.community import louvain_communities
from torch_geometric.utils import to_networkx
from collections import defaultdict

from utils.data.load_graph import load_graph, get_labels_from_graph, get_clusters_from_graph


def get_pred_from_clusters(clusters, num_nodes: int, filtered_nodes: set[int]) -> np.ndarray:
    """Turn a sequence of node clusters into a per-node label vector.

    Args:
        clusters: Iterable of node collections. The position of a cluster in
            ``clusters`` is used as its label.
        num_nodes: Length of the returned label vector.
        filtered_nodes: Only nodes contained in this collection receive a label.

    Returns:
        Integer array of shape ``(num_nodes,)``. ``labels[node]`` is the index
        of the cluster containing ``node``; if a node appears in several
        clusters, the last one wins. Nodes that are not in ``filtered_nodes``,
        or that belong to no cluster, keep the sentinel ``-1``.
    """
    labels = [-1] * num_nodes
    for cluster_id, cluster in enumerate(clusters):
        for node in cluster:
            if node in filtered_nodes:
                labels[node] = cluster_id
    # An explicit dtype keeps the result integer even when num_nodes == 0
    # (np.array([]) would otherwise be float64).
    return np.array(labels, dtype=int)


def nx_to_igraph(G):
    """Convert a NetworkX graph into an undirected igraph ``Graph``.

    Vertices are renumbered 0..N-1 following ``G.nodes()`` iteration order.
    Each igraph vertex stores its original NetworkX node in the ``"node_id"``
    attribute. Each edge carries a ``"weight"`` attribute taken from the
    NetworkX edge data, defaulting to 1.0 when absent.

    Returns:
        A ``(igraph_graph, nodes)`` tuple where ``nodes[i]`` is the NetworkX
        node mapped to igraph vertex ``i``.
    """
    nodes = list(G.nodes())
    position = {node: pos for pos, node in enumerate(nodes)}

    edge_list = []
    edge_weights = []
    for src, dst in G.edges():
        edge_list.append((position[src], position[dst]))
        edge_weights.append(G[src][dst].get("weight", 1.0))

    graph = ig.Graph(n=len(nodes), edges=edge_list, directed=False)
    graph.es["weight"] = edge_weights
    graph.vs["node_id"] = nodes
    return graph, nodes


_SUPPORTED_BASELINES = ("louvain", "leiden")


def _load_split(folder: str, tp: str, baseline: str):
    """Load one dataset split and build the graph views a baseline needs.

    Returns ``(filtered_nodes, ground_truth_labels, nx_graph, ig_graph, ig_nodes)``.
    ``ig_graph`` and ``ig_nodes`` are ``None`` unless ``baseline == "leiden"``.
    """
    graph = load_graph(folder=folder, tp=tp, undirected=False, use_edge_features=False)
    # Only the node filter is needed here; the ground-truth clusters are unused.
    filtered_nodes, _ = get_clusters_from_graph(graph=graph, min_adjacency=2)
    # NOTE(review): assumed to be a numpy array (it is boolean-mask indexed later) — confirm.
    ground_truth_labels = get_labels_from_graph(graph=graph)

    nx_graph = to_networkx(data=graph, to_undirected=True)
    if baseline == "leiden":
        ig_graph, ig_nodes = nx_to_igraph(nx_graph)
    else:
        ig_graph, ig_nodes = None, None
    return filtered_nodes, ground_truth_labels, nx_graph, ig_graph, ig_nodes


def _run_baseline(baseline: str, nx_graph, ig_graph, ig_nodes, resolution: float,
                  seed: int, n_iterations: int) -> list:
    """Cluster the graph with ``baseline`` at ``resolution``; return a list of node sets."""
    if baseline == "louvain":
        return louvain_communities(nx_graph, seed=seed, weight="weight", resolution=resolution)

    # Leiden: use the same "weight" edge attribute Louvain uses, so both
    # baselines optimise the same weighted objective.
    part = la.RBConfigurationVertexPartition(ig_graph, weights="weight",
                                             resolution_parameter=resolution)
    opt = la.Optimiser()
    opt.set_rng_seed(seed)
    opt.optimise_partition(part, n_iterations=n_iterations)

    # Map igraph vertex indices back to the original NetworkX node ids.
    comm_to_nodes = defaultdict(set)
    for idx, comm_id in enumerate(part.membership):
        comm_to_nodes[comm_id].add(ig_nodes[idx])
    return list(comm_to_nodes.values())


def _masked_labels(clusters, filtered_nodes, ground_truth_labels):
    """Return ``(true, pred)`` restricted to the nodes that received a predicted label."""
    pred_labels = get_pred_from_clusters(clusters=clusters, num_nodes=len(ground_truth_labels),
                                         filtered_nodes=filtered_nodes)
    mask = pred_labels != -1
    return ground_truth_labels[mask], pred_labels[mask]


def evaluate_baseline(baseline: str, n: int, seed: int = 0, folder: str = "../../dataset") -> dict:
    """Tune and evaluate a community-detection baseline.

    The resolution parameter is grid-searched over ``n`` evenly spaced values
    in ``[0.5, 3.0]``. The value that maximises NMI on the ``"valid"`` split
    is kept, and the baseline is then re-run on the ``"test"`` split with that
    resolution. Only nodes kept by ``get_clusters_from_graph(min_adjacency=2)``
    are scored.

    Args:
        baseline: ``"louvain"`` or ``"leiden"``.
        n: Number of resolution values to try (must be >= 1).
        seed: RNG seed passed to the clustering algorithm.
        folder: Dataset root passed to ``load_graph``.

    Returns:
        ``{"baseline": baseline, "nmi": <test NMI>, "ari": <test ARI>}``.

    Raises:
        ValueError: If ``baseline`` is unsupported or ``n < 1``.
    """
    # Validate before doing any expensive loading/conversion.
    if baseline not in _SUPPORTED_BASELINES:
        raise ValueError("Baseline must be 'louvain' or 'leiden'")
    if n < 1:
        raise ValueError("n must be at least 1 to select a resolution")

    # NOTE(review): Leiden runs 10 optimiser iterations during the search but
    # 20 at test time (carried over from the original code). The tuned
    # resolution is therefore chosen under a different budget — confirm intent.
    search_iterations, test_iterations = 10, 20

    # grid-search of the best resolution parameter on the validation graph
    filtered_nodes, labels, nx_graph, ig_graph, ig_nodes = _load_split(folder, "valid", baseline)
    best_resolution = None
    best_score = None
    for resolution in np.linspace(0.5, 3.0, n):
        clusters = _run_baseline(baseline, nx_graph, ig_graph, ig_nodes, resolution,
                                 seed, search_iterations)
        y_true, y_pred = _masked_labels(clusters, filtered_nodes, labels)
        nmi = normalized_mutual_info_score(y_true, y_pred)
        if best_score is None or nmi > best_score:
            best_score = nmi
            best_resolution = resolution

    # evaluate the selected resolution on the test graph
    filtered_nodes, labels, nx_graph, ig_graph, ig_nodes = _load_split(folder, "test", baseline)
    clusters = _run_baseline(baseline, nx_graph, ig_graph, ig_nodes, best_resolution,
                             seed, test_iterations)
    y_true, y_pred = _masked_labels(clusters, filtered_nodes, labels)
    nmi = normalized_mutual_info_score(y_true, y_pred)
    ari = adjusted_rand_score(y_true, y_pred)

    return {"baseline": baseline, "nmi": nmi, "ari": ari}


if __name__ == "__main__":

    # Evaluate each baseline over five seeds. After every seed, print the
    # running mean and standard deviation of the test NMI/ARI.
    report_prefixes = {"leiden": "Leiden - ", "louvain": "Louvain - "}
    for baseline_name, prefix in report_prefixes.items():
        nmi_scores, ari_scores = [], []
        for seed in range(5):
            result = evaluate_baseline(baseline=baseline_name, n=10, seed=seed)
            nmi_scores.append(result["nmi"])
            ari_scores.append(result["ari"])
            print(f"{prefix}"
                  f"NMI: {np.mean(nmi_scores):.3f} (+- {np.std(nmi_scores):.3f}), "
                  f"ARI: {np.mean(ari_scores):.3f} (+- {np.std(ari_scores):.3f})")





















