import pandas as pd
from pathlib import Path
from time import time
from memory_profiler import memory_usage
import signal
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.sparse.csgraph import connected_components

from SEPAL.dataloader import DataLoader
from SEPAL.knowledge_graph import KnowledgeGraph
from SEPAL.partitioning import get_LPA_partitions, get_metis_partitions, get_spectral_partitions, get_diffusion_map_argmax_partitions, get_diffusion_map_kmeans_partitions
from SEPAL.utils import create_graph


# Location of this script; the repository root is two directories above its folder,
# and benchmark rows are appended to a parquet file stored next to it.
_this_file = Path(__file__).absolute()
sepal_dir = _this_file.parents[2]
results_path = _this_file.parent / "results.parquet"


# One row per benchmarked dataset:
#   (display name, number of partitions,
#    step count passed to the diffusion-map methods,
#    sampling interval handed to memory_usage)
_DATASET_SETTINGS = {
    "mini_yago3": ("Mini Yago3", 5, 6, .1),
    "yago3": ("Yago3", 10, 7, 1),
    "yago4": ("Yago4", 30, 8, 10),
}

# Datasets are processed in the order of the table above.
datasets = list(_DATASET_SETTINGS)

names = {key: cfg[0] for key, cfg in _DATASET_SETTINGS.items()}
num_partitions = {key: cfg[1] for key, cfg in _DATASET_SETTINGS.items()}
n_steps = {key: cfg[2] for key, cfg in _DATASET_SETTINGS.items()}
time_intervals = {key: cfg[3] for key, cfg in _DATASET_SETTINGS.items()}

# Partitioning strategies to benchmark, keyed by the label stored in the
# "method" column of the results. Iteration order is the order they are run in.
methods = dict(
    spectral=get_spectral_partitions,
    LPA=get_LPA_partitions,
    argmax=get_diffusion_map_argmax_partitions,
    KMeans=get_diffusion_map_kmeans_partitions,
    metis=get_metis_partitions,
)

def evaluate_partitions(partitions, graph):
    """
    Check partition balance and connectivity.

    Parameters
    ----------
    partitions : sequence of array-like of int
        Node indices of each partition (numpy arrays or plain lists).
    graph : object
        Any object with an ``adjacency`` attribute supporting row/column fancy
        indexing, e.g. a scipy sparse matrix.

    Returns
    -------
    sizes : list of int
        Number of nodes in each partition.
    balance : float
        Smallest partition size divided by the largest (1.0 = perfectly balanced).
    nb_cc : list of int
        Number of connected components of the subgraph induced by each
        partition (scipy defaults: directed input, weak connectivity).

    Raises
    ------
    ValueError
        If ``partitions`` is empty.
    """
    # len() rather than truthiness: a numpy array of arrays has no unambiguous bool.
    if len(partitions) == 0:
        raise ValueError("evaluate_partitions() needs at least one partition")

    sizes = [len(nodes) for nodes in partitions]
    balance = min(sizes) / max(sizes)

    adjacency = graph.adjacency
    nb_cc = [
        # Restrict the adjacency matrix to the partition's own nodes (rows and columns).
        connected_components(adjacency[nodes, :][:, nodes], return_labels=False)
        for nodes in partitions
    ]

    return sizes, balance, nb_cc


def handler(signum, frame):
    """
    SIGALRM handler that aborts whatever benchmark step is currently running.

    Raises
    ------
    TimeoutError
        Always. TimeoutError subclasses Exception, so ``except Exception``
        blocks around the benchmark still catch it.
    """
    raise TimeoutError("end of time")




def make_plots():
    """
    Plot partition quality against computational cost, one figure per dataset.

    Reads the benchmark rows stored at ``results_path`` and, for every dataset
    in ``datasets`` that has usable results, saves ``<dataset>_partitioning.pdf``
    next to this file. Each figure is a 2x2 grid coloured by method:
    time (left column) and memory (right column) against the mean number of
    connected components per partition (top row) and the standard deviation
    of partition sizes (bottom row).

    Rows containing any missing value are dropped first: failed runs store
    ``None`` for their metrics, and so do runs whose memory measurement failed.
    """
    # Load results
    df = pd.read_parquet(results_path).sort_values('method')

    # Remove runs that failed (any column is None/NaN)
    df = df.dropna()

    # Memory is recorded in MiB; the axis label reports GiB
    df["memory"] = df["memory"] / 1024

    # Compute relevant quality metrics
    df["nb_cc/partition"] = df["nb_cc"].apply(np.mean)
    df["size_std"] = df["sizes"].apply(np.std)

    for data in datasets:
        df2 = df[df["data"] == data]
        if df2.empty:
            # Nothing to plot: dataset not benchmarked yet, or every run failed.
            continue

        fig, axes = plt.subplots(nrows=2, ncols=2, sharex="col", sharey="row", figsize=(7, 7))
        fig.suptitle(names[data], x=0.52, y=1.05)

        # Only the top-left panel draws a legend; it is lifted to the figure below.
        sns.scatterplot(x='time', y='nb_cc/partition', data=df2, hue='method', ax=axes[0, 0])
        sns.scatterplot(x='memory', y='nb_cc/partition', data=df2, hue='method', ax=axes[0, 1], legend=False)
        sns.scatterplot(x='time', y='size_std', data=df2, hue='method', ax=axes[1, 0], legend=False)
        sns.scatterplot(x='memory', y='size_std', data=df2, hue='method', ax=axes[1, 1], legend=False)

        # x is shared per column and y per row, so these four calls make every
        # panel of the grid log-log.
        axes[0, 0].set_xscale('log')
        axes[0, 0].set_yscale('log')
        axes[1, 1].set_xscale('log')
        axes[1, 1].set_yscale('log')

        axes[1, 0].set_xlabel('Time (s)')
        axes[0, 0].set_ylabel('# Connected Components / partition')
        axes[1, 0].set_ylabel('Partition sizes STD')
        axes[1, 1].set_xlabel('Memory usage (GiB)')

        # Replace the per-axes legend with one figure-level legend. Spectral is
        # only benchmarked on mini_yago3, so that legend has one more entry.
        handles, labels = axes[0, 0].get_legend_handles_labels()
        if data == "mini_yago3":
            ncol, anchor = 5, (0.92, 1.02)
        else:
            ncol, anchor = 4, (0.85, 1.02)
        fig.legend(handles, labels, title="Partitioning algorithm", fancybox=True, ncol=ncol, bbox_to_anchor=anchor)
        axes[0, 0].get_legend().remove()
        fig.tight_layout()

        # Cost / quality arrows, in axes-fraction coordinates of the bottom-right
        # panel so that they stretch across the whole grid.
        corner = axes[1, 1]
        corner.annotate('', xy=(1, -0.25), xycoords='axes fraction', xytext=(-1.05, -0.25), arrowprops=dict(arrowstyle="->", color='r'))
        corner.annotate('Cost ↗', xy=(-0.1, -0.3), xycoords='axes fraction', color="r")
        corner.annotate('', xy=(1.1, 0), xycoords='axes fraction', xytext=(1.1, 2.05), arrowprops=dict(arrowstyle="->", color='g'))
        corner.annotate('Quality ↗', xy=(1.12, 0.95), xycoords='axes fraction', rotation=90, color="g")

        fig.savefig(Path(__file__).absolute().parent / f"{data}_partitioning.pdf", bbox_inches='tight')
        # Release the figure; otherwise every dataset's figure stays open in pyplot.
        plt.close(fig)

        # TODO: add per-method confidence ellipses (mean/std of time,
        # nb_cc/partition and size_std), see
        # https://matplotlib.org/stable/gallery/statistics/confidence_ellipse.html





def _method_args(method, data, graph):
    """Return the positional arguments used to call ``methods[method]`` on ``data``.

    The diffusion-map methods ("argmax", "KMeans") additionally receive the
    step count from ``n_steps``; the others only get the partition count and graph.
    """
    if method in ["argmax", "KMeans"]:
        return (num_partitions[data], graph, n_steps[data])
    return (num_partitions[data], graph)


def _benchmark_method(method, data, graph, timeout=12000):
    """Time, evaluate and memory-profile one partitioning method on one dataset.

    The timed run and the memory-profiled run each get a SIGALRM deadline of
    ``timeout`` seconds; the caller must have installed ``handler`` for
    SIGALRM. Partition quality is evaluated without a deadline.

    Returns
    -------
    tuple
        ``(duration, sizes, balance, nb_cc, memory)``. Every entry is ``None``
        if partitioning or evaluation failed; only ``memory`` is ``None`` if
        just the memory measurement failed.
    """
    func = methods[method]
    args = _method_args(method, data, graph)

    signal.alarm(timeout)
    try:
        start = time()
        partitions = func(*args)
        duration = time() - start
        signal.alarm(0)

        sizes, balance, nb_cc = evaluate_partitions(partitions, graph)

        # Second run, this time sampling process memory. Arguments are passed
        # positionally, exactly like the timed call above, so the measurement
        # does not depend on the partitioning functions' parameter names.
        signal.alarm(timeout)
        try:
            memory = max(memory_usage((func, args, {}), interval=time_intervals[data]))
        except Exception as exc:
            # Best effort: keep timing/quality results even if profiling fails
            # or times out (KeyboardInterrupt is no longer swallowed here).
            print(f"Memory measurement of {method} failed for {data}: {exc}")
            memory = None

    except Exception as exc:
        print(exc)
        print(f"{method} failed for {data}.")
        duration = sizes = balance = nb_cc = memory = None

    finally:
        # Never leave a pending alarm that could fire during the next run.
        signal.alarm(0)

    return duration, sizes, balance, nb_cc, memory


def _save_result(res_dict):
    """Append one result row to the parquet file at ``results_path``.

    The file is re-read and re-written on every call, so each finished run is
    persisted immediately.
    """
    new_results = pd.DataFrame(res_dict, index=[0])

    if Path(results_path).is_file():
        results = pd.read_parquet(results_path)
        results = pd.concat([results, new_results]).reset_index(drop=True)
    else:
        results = new_results

    results.to_parquet(results_path, index=False)


def main():
    """Benchmark every partitioning method on every dataset and append the results."""
    # The handler never changes, so install it once rather than per run.
    signal.signal(signal.SIGALRM, handler)

    for data in datasets:
        print(f"---------- {data} ----------")
        graph = create_graph(data)
        for method in methods:

            # Spectral partitioning is skipped on the two larger datasets.
            if method == "spectral" and data in ["yago3", "yago4"]:
                continue

            print(method)
            duration, sizes, balance, nb_cc, memory = _benchmark_method(method, data, graph)

            # Save results
            _save_result({
                "data": data,
                "method": method,
                "num_partitions": num_partitions[data],
                "balance": balance,
                # Wrapped in a list so each per-partition list becomes one cell.
                "sizes": [sizes],
                "nb_cc": [nb_cc],
                "time": duration,
                "memory": memory,
            })


if __name__ == "__main__":
    main()

