
import os
import sys
import json
import yaml
import pickle
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm

from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import linkage, fcluster
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, silhouette_score


from utils.data.load_graph import load_graph, get_labels_from_graph, get_clusters_from_graph
from utils.data.positional_encoding import landmark_spd_features
from utils.evaluation.compute_embeddings import compute_embeddings
from utils.evaluation.hierarchical_clustering import create_leiden_partition, blockwise_pdist
from utils.evaluation.purity import get_samples_for_purity, compute_leiden_cluster_purity


# Folder that contains this script; the dataset is expected in its "dataset" sub-folder.
# NOTE(review): the folder is appended to the import search path only here, i.e. after
# the `utils` imports at the top of the file have already been resolved.
ROOT_DIR = os.path.split(__file__)[0]
DATASET_DIR = os.path.join(ROOT_DIR, "dataset")
sys.path += [ROOT_DIR]


def _load_or_create_partition(graph, filtered_nodes, leiden_path):
    """Return the Leiden partition cached at `leiden_path`, computing and caching it on a miss."""
    if os.path.exists(leiden_path):
        # NOTE: unpickling executes code embedded in the file. This is only safe because
        # the cache is written by this script; never point it at an untrusted file.
        with open(leiden_path, "rb") as f:
            return pickle.load(f)

    partition = create_leiden_partition(graph=graph, resolution_parameter=1., max_comm_size=65000,
                                        filtered_nodes=filtered_nodes)
    with open(leiden_path, "wb") as f:
        pickle.dump(partition, f)
    return partition


def _select_cut_threshold(partition, embeddings, thresholds, min_size=3, max_size=10000):
    """Return the cut distance with the highest mean silhouette score over the Leiden clusters.

    Only clusters with `min_size <= len(cluster) <= max_size` are used. Each of them is
    cut at every threshold and contributes exactly one score per threshold (0. when the
    cut yields fewer than 2 or more than n - 1 sub-clusters, where the silhouette score
    is undefined), so all thresholds are averaged over the same set of clusters.

    Raises:
        ValueError: if no cluster falls inside the size window.
    """
    silhouette_scores = defaultdict(list)

    for leiden_cluster in tqdm(partition, leave=False):

        # presumably the upper bound keeps the dense distance matrix below tractable
        if not min_size <= len(leiden_cluster) <= max_size:
            continue

        cos_distances = blockwise_pdist(embeddings[leiden_cluster], block_size=20000)
        cos_distances_mat = squareform(cos_distances)
        linkage_z = linkage(cos_distances, method="average")

        # simulate cuts and compute silhouette score
        for t in thresholds:
            sub_labels = fcluster(linkage_z, t=t, criterion="distance")
            k = len(np.unique(sub_labels))
            if 2 <= k <= len(leiden_cluster) - 1:
                sil = silhouette_score(cos_distances_mat, sub_labels, metric="precomputed")
            else:
                sil = 0.
            silhouette_scores[t].append(sil)

    if not silhouette_scores:
        raise ValueError(f"cannot select a threshold: no Leiden cluster has between "
                         f"{min_size} and {max_size} nodes")

    mean_scores = {t: np.mean(scores) for t, scores in silhouette_scores.items()}
    # ties are resolved in favour of the smallest threshold (first inserted)
    return float(max(mean_scores, key=mean_scores.get))


def _scores_by_cluster_size(gt, pr, bins):
    """Return NMI/ARI restricted to nodes whose ground-truth cluster size lies in each (lo, hi] bin.

    Sizes are counted among the nodes present in `gt`. Bins that select no node are
    omitted from the returned dict.
    """
    scores = {}
    # count every ground-truth cluster once instead of once per label per bin
    classes, counts = np.unique(gt, return_counts=True)
    for lo, hi in bins:
        idx = np.isin(gt, classes[(counts > lo) & (counts <= hi)])
        if np.any(idx):
            scores[f"nmi_{lo}_{hi}"] = float(normalized_mutual_info_score(gt[idx], pr[idx]))
            scores[f"ari_{lo}_{hi}"] = float(adjusted_rand_score(gt[idx], pr[idx]))
    return scores


def evaluate(config: dict):
    """Evaluate the saved model on the test graph and return clustering metrics.

    Steps: load the test graph (optionally with landmark positional features), obtain a
    Leiden partition (cached as "filtered_partition.pkl" in `config["save_folder"]`),
    embed the nodes with the model stored as "model.pt", pick the hierarchical cut
    threshold that maximises the mean silhouette score, split every Leiden cluster at
    that threshold and score the result.

    Args:
        config: configuration dict; reads `training.undirected`,
            `training.use_edge_features`, `positional_encoding.use`,
            `positional_encoding.num_landmarks` and `save_folder`.

    Returns:
        dict with "purity", "nmi", "ari", "opt_threshold" and, for every non-empty
        cluster-size bin, "nmi_{lo}_{hi}" / "ari_{lo}_{hi}". "purity" is NaN when no
        purity sample was available.

    Raises:
        ValueError: if no Leiden cluster is eligible for threshold selection.
    """

    # load the test graph
    graph = load_graph(folder=DATASET_DIR, tp="test", undirected=config["training"]["undirected"],
                       use_edge_features=config["training"]["use_edge_features"])
    if config["positional_encoding"]["use"]:
        graph = landmark_spd_features(data=graph, num_landmarks=config["positional_encoding"]["num_landmarks"])
    filtered_nodes, ground_truth_clusters = get_clusters_from_graph(graph=graph, min_adjacency=2)
    ground_truth_labels: np.ndarray = get_labels_from_graph(graph=graph)

    # create the first partition with the Leiden algorithm
    partition = _load_or_create_partition(graph=graph, filtered_nodes=filtered_nodes,
                                          leiden_path=os.path.join(config["save_folder"],
                                                                   "filtered_partition.pkl"))

    # compute the nodes' embeddings
    embeddings = compute_embeddings(graph=graph,
                                    saved_model_file=os.path.join(config["save_folder"], "model.pt"),
                                    config=config)

    # divide each leiden cluster into hierarchical sub-clusters
    thresholds = np.round(np.geomspace(0.00001, 1., 150), 10)
    opt_threshold = _select_cut_threshold(partition=partition, embeddings=embeddings, thresholds=thresholds)

    # get the samples for purity computation
    samples_purity = get_samples_for_purity(partition=partition, labels=ground_truth_labels,
                                            clusters=ground_truth_clusters,
                                            num_samples=10000)
    total, purity = 0, 0.

    # create the optimal clustering
    opt_clustering = []

    for i, leiden_cluster in tqdm(enumerate(partition), total=len(partition), leave=False):

        # too small to split: keep the Leiden cluster as it is
        if len(leiden_cluster) < 3:
            opt_clustering.append(leiden_cluster)
            continue

        cos_distances = blockwise_pdist(embeddings[leiden_cluster], block_size=20000)
        linkage_z = linkage(cos_distances, method="average")
        labels_leiden_clusters = fcluster(linkage_z, t=opt_threshold, criterion="distance")

        # compute the sub-clusters
        sub_clusters = defaultdict(list)
        for node, label in zip(leiden_cluster, labels_leiden_clusters):
            sub_clusters[label].append(node)
        opt_clustering.extend(sub_clusters.values())

        # compute the purity
        purity += compute_leiden_cluster_purity(linkage_z=linkage_z, leiden_cluster=leiden_cluster,
                                                samples=samples_purity[i], labels=ground_truth_labels)
        total += len(samples_purity[i])

    # purity is undefined without samples; report NaN instead of dividing by zero
    purity = purity / total if total else float("nan")

    # create the prediction (-1 marks nodes that belong to no cluster)
    pred_labels = np.full(len(ground_truth_labels), -1)
    for cluster_id, cluster in enumerate(opt_clustering):
        for node_id in cluster:
            pred_labels[node_id] = cluster_id

    # compute the scores on the clustered nodes only
    mask = pred_labels != -1
    gt, pr = ground_truth_labels[mask], pred_labels[mask]
    # builtin floats keep the result JSON-serialisable whatever scalar type the helpers return
    results = {"purity": float(purity),
               "nmi": float(normalized_mutual_info_score(gt, pr)),
               "ari": float(adjusted_rand_score(gt, pr)),
               "opt_threshold": opt_threshold}

    bins = [(0, 10), (10, 100), (100, 1000), (1000, np.inf)]
    results.update(_scores_by_cluster_size(gt, pr, bins))

    return results


if __name__ == "__main__":

    # script entry point: read config.yaml from the working directory, run the
    # evaluation and store the metrics next to the model
    if not os.path.exists("config.yaml"):
        raise FileNotFoundError("config.yaml not found, please create one following 'example_config.yaml'")
    # a context manager closes the config file deterministically
    # NOTE(review): FullLoader can construct more than plain data types; consider
    # yaml.safe_load if the config may ever come from an untrusted source
    with open("config.yaml", "r") as f:
        conf = yaml.load(f, Loader=yaml.FullLoader)
    res = evaluate(conf)
    # save the results
    with open(os.path.join(conf["save_folder"], "results.json"), "w") as f:
        json.dump(res, f, indent=4)

