import pandas as pd
from pathlib import Path
from time import time
from memory_profiler import memory_usage
import signal
import torch
import scipy.sparse
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pykeen.triples import TriplesFactory

from SEPAL.dataloader import DataLoader
from SEPAL.knowledge_graph import KnowledgeGraph
from SEPAL.partitioning import (
    LPA_partitions,
    metis_partitions,
    spectral_partitions,
    diffusion_map_argmax_partitions,
    diffusion_map_kmeans_partitions,
    louvain_partitions_ig,
    louvain_partitions_nx,
    leiden_partitions_ig,
    leading_eigenvector_partitions_ig,
    label_propagation_partitions_ig,
    infomap_partitions_ig,
)
from SEPAL.utils import keep_only_largest_cc, tf2nx, tf2igraph, create_graph
from SEPAL.subgraph_generation import generate_subgraphs




## TODO: write a decorator that reports time and peak memory usage
def timeis(func):
    '''Decorator that prints the wall-clock execution time of ``func``.

    After each call, prints the wrapped function's name followed by the
    elapsed time in seconds, then returns the call's result unchanged.
    '''
    from functools import wraps

    @wraps(func)  # keep func's __name__ / __doc__ on the wrapper
    def wrap(*args, **kwargs):
        # `time` is the function imported with `from time import time`,
        # not the `time` module, so it is called directly.
        start = time()
        result = func(*args, **kwargs)
        end = time()

        print(func.__name__, end - start)
        return result
    return wrap



# Directory containing this script; the results file and figures live here.
_script_dir = Path(__file__).absolute().parent
# Three directory levels above this file (presumably the SEPAL repository root — confirm).
sepal_dir = _script_dir.parents[1]
# Parquet file to which one row per (dataset, partitioning method) run is appended.
results_path = _script_dir / "application_oriented_results.parquet"


# Per-dataset benchmark settings, one row per dataset (row order fixes the
# order of `datasets`). Columns:
#   display name     -- label used in plot titles
#   partitions       -- num_partitions passed to the partitioning methods
#   steps            -- n_steps for the "argmax" / "KMeans" methods
#   interval         -- sampling interval given to memory_profiler.memory_usage
#   core             -- proportion of entities used to build the core subgraph
#   max_size, ngens  -- parameters of the "diffusion" method
_DATASET_SETTINGS = {
    # key          display name   partitions  steps  interval  core  max_size  ngens
    "mini_yago3": ("Mini Yago3",  5,          6,     .1,       .05,  4e4,      75),
    "yago3":      ("Yago3",       10,         7,     1,        .05,  4e5,      200),
    "yago4":      ("Yago4",       30,         8,     10,       .03,  2e6,      500),
    "yago4.5":    ("Yago4.5",     30,         9,     10,       .03,  2e6,      400),
}

datasets = list(_DATASET_SETTINGS)
names = {key: row[0] for key, row in _DATASET_SETTINGS.items()}
num_partitions = {key: row[1] for key, row in _DATASET_SETTINGS.items()}
n_steps = {key: row[2] for key, row in _DATASET_SETTINGS.items()}
time_intervals = {key: row[3] for key, row in _DATASET_SETTINGS.items()}
core_proportions = {key: row[4] for key, row in _DATASET_SETTINGS.items()}
max_sizes = {key: row[5] for key, row in _DATASET_SETTINGS.items()}
ngens = {key: row[6] for key, row in _DATASET_SETTINGS.items()}

# Partitioning methods to benchmark, keyed by the label stored in the results
# file. The __main__ block picks each function's keyword arguments from this
# key: "argmax"/"KMeans" get (num_partitions, graph, n_steps), "diffusion" gets
# (graph, max_size, ngens), the five "-ig" methods get ig_graph, "louvain-nx"
# gets nx_graph, and any other key gets (num_partitions, graph).
# Comment / uncomment entries to choose which methods are run; the commented
# entries reference functions imported at the top of this file.
methods = {
    #"diffusion": generate_subgraphs,
    #"LPA": LPA_partitions,
    #"argmax": diffusion_map_argmax_partitions,
    #"KMeans": diffusion_map_kmeans_partitions,
    #"metis": metis_partitions,
    "louvain-ig": louvain_partitions_ig,
    "leiden-ig": leiden_partitions_ig,
    "eigen-ig": leading_eigenvector_partitions_ig,
    "LPA-ig": label_propagation_partitions_ig,
    "infomap-ig": infomap_partitions_ig,
    "louvain-nx": louvain_partitions_nx,
    #"spectral": spectral_partitions,
}



def get_subgraph_entities(graph, prop):
    """Select the entities of the core subgraph of ``graph``.

    Takes the ``int(prop * graph.num_entities) + 1`` entities with the highest
    degree, then keeps only the largest connected component of the subgraph
    they induce (edge direction is ignored).

    Parameters
    ----------
    graph : KnowledgeGraph
        Graph exposing ``mapped_triples`` (a tensor whose columns 0 and 2 are
        used as edge endpoints), ``degrees`` and ``num_entities``.
    prop : float
        Proportion of the entities considered as core candidates.

    Returns
    -------
    list of int
        Entity ids of the selected nodes, in the decreasing-degree order
        produced by ``torch.argsort``.
    """
    print("Extracting subgraph...")
    # Select highest degree nodes
    print("... getting highest degree nodes")
    edges = graph.mapped_triples[:, [0, 2]]
    num_candidates = int(prop * graph.num_entities) + 1
    node_list_tensor = torch.argsort(torch.IntTensor(graph.degrees), descending=True)[:num_candidates]
    node_list = node_list_tensor.tolist()
    num_nodes = len(node_list)

    # Keep only the largest connected component of the subgraph
    print("... keeping largest connected component")
    # Edges whose two endpoints are both candidates.
    mask = torch.isin(edges, node_list_tensor).all(axis=1)
    # Map entity ids to local ids 0..num_nodes-1 (node_list[i] -> i) through a
    # lookup array. Unlike np.vectorize over a dict, this is vectorised and
    # also works when no edge survives the mask (size-0 input).
    lookup = np.full(max(node_list) + 1, -1, dtype=np.int64)
    lookup[node_list] = np.arange(num_nodes)
    subgraph_edges = lookup[edges[mask].cpu().numpy()]

    subgraph_adjacency = scipy.sparse.csr_matrix(
        (np.ones(len(subgraph_edges)), (subgraph_edges[:, 0], subgraph_edges[:, 1])),
        shape=(num_nodes, num_nodes),
    )
    _, labels = scipy.sparse.csgraph.connected_components(subgraph_adjacency, directed=False, return_labels=True)
    largest = np.argmax(np.bincount(labels))

    # Local id i is node_list[i], so no inverse dictionary is needed.
    return [node_list[i] for i in np.flatnonzero(labels == largest)]


def compute_core_subgraph(graph, prop):
    """Build the core subgraph of ``graph`` and record its entities on ``graph``.

    The core entities are chosen by ``get_subgraph_entities``: the
    ``int(prop * graph.num_entities) + 1`` highest-degree entities, restricted
    to the largest connected component they induce. A new ``TriplesFactory``
    and ``KnowledgeGraph`` are then built from the triples whose head and tail
    are both core entities, with those entities re-indexed ``0..n-1``.

    Side effect: sets ``graph.core_subgraph_idx`` to the list of core entity
    ids (in ``graph``'s own indexing), which the benchmark reads afterwards.

    Parameters
    ----------
    graph : KnowledgeGraph
        Graph exposing ``mapped_triples``, ``degrees``, ``num_entities`` and
        ``triples_factory``.
    prop : float
        Proportion of the entities considered as core candidates.

    Returns
    -------
    KnowledgeGraph
        The core subgraph. Callers that only need the side effect can ignore it.
    """
    # Select the core entities (the helper prints its own progress messages).
    node_list = get_subgraph_entities(graph, prop)

    ## Build triples_factory
    print("... building triples factory")
    old_tf = graph.triples_factory
    old_triples = old_tf.mapped_triples
    node_tensor = torch.as_tensor(node_list, dtype=torch.long, device=old_triples.device)

    # Keep only triples whose head and tail both belong to the core. Boolean
    # indexing returns a copy, so the in-place re-indexing below leaves old_tf
    # untouched.
    mask = torch.isin(old_triples[:, [0, 2]], node_tensor).all(axis=1)
    mapped_triples = old_triples[mask]

    # Reindex subgraph entities between 0 and n-1 (node_list[i] -> i) with a
    # lookup tensor: vectorised, and valid even when no triple survives the
    # mask (np.vectorize rejects size-0 inputs).
    lookup = torch.full(
        (max(node_list) + 1,), -1, dtype=mapped_triples.dtype, device=mapped_triples.device
    )
    lookup[node_tensor] = torch.arange(
        len(node_list), dtype=mapped_triples.dtype, device=mapped_triples.device
    )
    mapped_triples[:, [0, 2]] = lookup[mapped_triples[:, [0, 2]]]

    # Build the new entity_to_id dictionary: label of entity node_list[i] -> i.
    old_id_to_entity = {v: k for k, v in old_tf.entity_to_id.items()}
    entity_to_id = {old_id_to_entity[node]: i for i, node in enumerate(node_list)}

    # Create triple factory object
    triples_factory = TriplesFactory(
        mapped_triples=mapped_triples,
        entity_to_id=entity_to_id,
        relation_to_id=old_tf.relation_to_id,
        create_inverse_triples=old_tf.create_inverse_triples,
    )

    ## Build knowledge_graph instance
    print("... building knowledge graph object")
    subgraph = KnowledgeGraph(triples_factory)
    graph.core_subgraph_idx = node_list

    print(f"Core subgraph contains {subgraph.num_entities} entities ({subgraph.num_entities/graph.num_entities:.1%} of total graph)")

    return subgraph


def evaluate_partitions(partitions, graph, core_entities):
    """
    Measure the balance and the connectivity of a partitioning.

    Parameters
    ----------
    partitions : sequence of numpy arrays
        Node ids of each partition.
    graph : object
        Must expose ``adjacency``, indexable as ``adjacency[rows, :][:, cols]``
        (e.g. a SciPy sparse matrix).
    core_entities : array-like
        Node ids of the core subgraph, merged into every partition before
        counting connected components.

    Returns
    -------
    sizes : list of int
        Number of nodes in each partition (core not included).
    balance : float
        Smallest partition size divided by the largest one.
    nb_cc : list of int
        Number of connected components of each partition merged with the core
        (SciPy's defaults count weakly connected components).
    """
    sizes = [part.shape[0] for part in partitions]
    balance = min(sizes) / max(sizes)

    def _count_components(part):
        # Induced subgraph on the partition plus the core entities.
        nodes = np.union1d(part, core_entities)
        induced = graph.adjacency[nodes, :][:, nodes]
        return scipy.sparse.csgraph.connected_components(induced, return_labels=False)

    nb_cc = [_count_components(part) for part in partitions]
    return sizes, balance, nb_cc


def handler(signum, frame):
    """SIGALRM handler used to abort partitioning runs that exceed the time budget.

    Raises ``TimeoutError("end of time")``. ``TimeoutError`` subclasses
    ``Exception``, so the benchmark loop's ``except Exception`` still catches
    it, while other callers can now single out timeouts.
    """
    raise TimeoutError("end of time")




def make_plots():
    """Plot cost (time, memory) against quality metrics for every dataset.

    Reads the benchmark results from ``results_path``. For each dataset in
    ``datasets`` that has at least one successful run, it draws a 2x2 grid of
    scatter plots (one colour per method):

    * rows: mean number of connected components per partition, and standard
      deviation of the partition sizes;
    * columns: running time (s) and peak memory (GiB).

    Each figure is saved as ``<data>_applied_partitioning.pdf`` next to this
    script and then closed.
    """
    # Load results
    df = pd.read_parquet(results_path)
    df = df.sort_values('method')

    # Remove methods that failed (any missing value outside edge_coverage).
    # .copy() so the columns added below go to an independent frame rather
    # than to a view of the filtered one.
    df = df[~df.loc[:, df.columns != 'edge_coverage'].isna().any(axis=1)].copy()

    # Change unit for memory (MiB -> GiB)
    df["memory"] = df["memory"] / 1024

    # Compute relevant quality metrics
    df["nb_cc/partition"] = df["nb_cc"].apply(np.mean)
    df["size_std"] = df["sizes"].apply(np.std)

    # Figure-legend layout per dataset: (number of columns, bbox_to_anchor).
    legend_layouts = {
        "mini_yago3": (6, (1.02, 1.02)),
        "yago3": (5, (0.92, 1.02)),
    }
    default_legend_layout = (4, (0.85, 1.02))

    for data in datasets:
        df2 = df[df["data"] == data]
        if df2.empty:
            # No successful run for this dataset: there is nothing to plot and
            # no legend to move, so skip it instead of crashing.
            continue

        fig, axes = plt.subplots(nrows=2, ncols=2, sharex="col", sharey="row", figsize=(7,7))
        fig.suptitle(names[data] + " partitions + core subgraph", x=0.52, y=1.05)

        sns.scatterplot(x='time', y='nb_cc/partition', data=df2, hue='method', ax=axes[0,0])
        sns.scatterplot(x='memory', y='nb_cc/partition', data=df2, hue='method', ax=axes[0,1], legend=False)
        sns.scatterplot(x='time', y='size_std', data=df2, hue='method', ax=axes[1,0], legend=False)
        sns.scatterplot(x='memory', y='size_std', data=df2, hue='method', ax=axes[1,1], legend=False)

        # x is shared per column and y per row, so these calls put all four
        # subplots on log-log scales.
        axes[0,0].set_xscale('log')
        axes[0,0].set_yscale('log')
        axes[1,1].set_xscale('log')
        axes[1,1].set_yscale('log')

        axes[1,0].set_xlabel('Time (s)')
        axes[0,0].set_ylabel('# Connected Components / partition')
        axes[1,0].set_ylabel('Partition sizes STD')
        axes[1,1].set_xlabel('Memory usage (GiB)')

        # Replace the first subplot's legend by a single figure-level legend.
        handles, labels = axes[0,0].get_legend_handles_labels()
        ncol, anchor = legend_layouts.get(data, default_legend_layout)
        fig.legend(handles, labels, title="Partitioning algorithm", fancybox=True, ncol=ncol, bbox_to_anchor=anchor)
        axes[0,0].get_legend().remove()
        fig.tight_layout()

        # Direction arrows in axes-fraction coordinates of the bottom-right
        # subplot: cost grows to the right, quality grows downwards.
        ref_ax = axes[1,1]
        ref_ax.annotate('', xy=(1, -0.25), xycoords='axes fraction', xytext=(-1.05, -0.25), arrowprops=dict(arrowstyle="->", color='r'))
        ref_ax.annotate('Cost ↗', xy=(-0.1, -0.3), xycoords='axes fraction', color="r")
        ref_ax.annotate('', xy=(1.1, 0), xycoords='axes fraction', xytext=(1.1, 2.05), arrowprops=dict(arrowstyle="->", color='g'))
        ref_ax.annotate('Quality ↗', xy=(1.12, 0.95), xycoords='axes fraction', rotation=90, color="g")

        fig.savefig(Path(__file__).absolute().parent / f"{data}_applied_partitioning.pdf", bbox_inches='tight')
        # Release the figure so figures do not pile up across datasets.
        plt.close(fig)

    return





def plot_quality_metrics():
    """Scatter partition size vs. number of connected components, per dataset.

    Each successful result row is expanded into one point per partition.
    Two vertical reference lines are drawn: a dashed black one at a hard-coded
    graph size divided by ``num_partitions[data]`` (the size of a perfectly
    balanced partitioning), and a solid red one at ``max_sizes[data]``. Each
    figure is saved as ``<data>_sizes_vs_nbcc.pdf`` next to this script, shown,
    and then closed so the next dataset starts on a fresh figure.
    """
    # Hard-coded graph sizes (presumably the number of entities of each graph
    # once reduced to its largest connected component -- confirm) and the top
    # of the reference lines, per dataset.
    reference_lines = {
        "mini_yago3": (129493, 3e3),
        "yago3": (2570716, 5e4),
        "yago4": (37959466, 3e5),
        "yago4.5": (32408607, 3e5),
    }

    # Load results
    df = pd.read_parquet(results_path)
    df = df.sort_values('method')

    # Remove methods that failed
    df = df[~df.loc[:, df.columns != 'edge_coverage'].isna().any(axis=1)]

    for data in datasets:
        df2 = df[df["data"] == data][["method", "nb_cc", "sizes"]]
        if df2.empty:
            # No successful run for this dataset.
            continue

        # Long format: one row per partition, with its size and its number of
        # connected components, keyed by the original row index.
        nbcc_col = df2.nb_cc.apply(pd.Series).stack().reset_index(level=1, drop=True)
        nbcc_col.name = "nb_cc"
        size_col = df2.sizes.apply(pd.Series).stack().reset_index(level=1, drop=True)
        size_col.name = "sizes"
        df2 = df2.drop(['nb_cc', 'sizes'], axis=1).join(pd.concat([nbcc_col, size_col],axis=1))

        # Restrict which methods are shown for some datasets.
        if data == "mini_yago3":
            df2 = df2[df2.method.isin(["louvain-ig", "eigen-ig", "argmax", "metis", "diffusion"])]

        elif data == "yago3":
            df2 = df2[df2.method != "louvain-ig"]

        # Fresh figure per dataset, so points from previous datasets never
        # end up on this plot.
        fig = plt.figure()
        sns.scatterplot(data=df2, x="sizes", y="nb_cc", hue="method")
        if data in reference_lines:
            graph_size, ymax = reference_lines[data]
            plt.vlines(x=graph_size / num_partitions[data], ymin=0, ymax=ymax, colors="k", ls='--')
            plt.vlines(x=max_sizes[data], ymin=0, ymax=ymax, colors="r", lw=2)

        plt.xlabel("Partition sizes")
        plt.ylabel("Number of connected components")
        plt.xscale("log")
        plt.yscale("log")
        plt.title(f"Quality metrics - {names[data]}")
        fig.savefig(Path(__file__).absolute().parent / f"{data}_sizes_vs_nbcc.pdf", bbox_inches='tight')
        plt.show()
        plt.close(fig)
    return


def plot_edge_coverage():
    """Bar plot of the edge coverage of each method, grouped by dataset.

    Only result rows without any missing value are kept; datasets are shown
    under their display name from ``names``. The figure is saved as
    ``edge_coverage.pdf`` next to this script, shown, and then closed.
    """
    # Load results
    df = pd.read_parquet(results_path)
    df = df.sort_values('data')

    # Remove methods that failed
    df = df[~df.isna().any(axis=1)]

    # .copy(): the "data" column is overwritten below, so work on an
    # independent frame rather than a slice.
    df = df[["data", "method", "edge_coverage"]].copy()
    df["data"] = df.data.map(names)

    # Draw on a fresh figure rather than whatever figure happens to be current.
    fig, ax = plt.subplots()
    sns.barplot(data=df, x="data", y="edge_coverage", hue="method", edgecolor='black', width=.5, ax=ax)
    plt.setp(ax.patches, linewidth=2)
    ax.set_ylabel("Proportion of preserved edges after partitioning")
    ax.set_xlabel("")
    ax.set_title("Edge coverage")
    ax.legend(title="Algorithm", loc='lower center', bbox_to_anchor=(0.5, -0.25), fancybox=True, ncol=3)
    fig.savefig(Path(__file__).absolute().parent / "edge_coverage.pdf", bbox_inches='tight')
    plt.show()
    plt.close(fig)

    return



def plot_cost_metrics():
    """Scatter running time vs. peak memory of each method, per dataset.

    A dashed horizontal line marks a hard-coded memory baseline for each
    dataset. Each figure is saved as ``<data>_memory_vs_time.pdf`` next to this
    script, shown, and then closed so the next dataset starts on a fresh
    figure.
    """
    # Memory baseline (GiB) and x-extent of its line, per dataset.
    # NOTE(review): these look like adjacency-matrix sizes as computed by
    # evaluate_adjacency_size -- confirm.
    baselines = {
        "mini_yago3": (0.012643806636333466, 1e2),
        "yago3": (0.08754613995552063, 2e2),
        "yago4": (3.8061579391360283, 7300),
        "yago4.5": (1.4390715137124062, 7300),
    }

    # Load results
    df = pd.read_parquet(results_path)
    df = df.sort_values('method')

    # Remove methods that failed; .copy() because "memory" is rescaled below.
    df = df[~df.loc[:, df.columns != 'edge_coverage'].isna().any(axis=1)].copy()

    # Change unit for memory (MiB -> GiB)
    df["memory"] = df["memory"] / 1024

    for data in datasets:
        df2 = df[df["data"] == data][["method", "time", "memory"]]
        if df2.empty:
            # No successful run for this dataset.
            continue

        # Fresh figure per dataset, so points from previous datasets never
        # end up on this plot.
        fig = plt.figure()
        sns.scatterplot(data=df2, x="time", y="memory", hue="method")
        if data in baselines:
            baseline, xmax = baselines[data]
            plt.hlines(y=baseline, xmin=0, xmax=xmax, colors="k", ls='--')

        plt.xlabel("Time (s)")
        plt.ylabel("Memory usage (GiB)")
        plt.xscale("log")
        plt.yscale("log")
        plt.title(f"Cost metrics - {names[data]}")
        fig.savefig(Path(__file__).absolute().parent / f"{data}_memory_vs_time.pdf", bbox_inches='tight')
        plt.show()
        plt.close(fig)
    return


def evaluate_adjacency_size(A):
    """Estimate the memory footprint, in GiB, of a compressed sparse matrix.

    Counts ``A.nnz`` stored values plus every entry of ``A.indptr`` and
    ``A.indices``, each assumed to take 32 bits regardless of the arrays'
    actual dtypes.
    """
    bits_per_entry = 32
    num_entries = A.nnz + A.indptr.shape[0] + A.indices.shape[0]
    total_bits = num_entries * bits_per_entry
    bits_per_gib = 8 * 1024 ** 3
    return total_bits / bits_per_gib



if __name__ == "__main__":
    # Seconds after which a partitioning run is interrupted (SIGALRM -> handler).
    TIMEOUT_S = 10000
    # Methods that take an igraph / networkx graph instead of the KnowledgeGraph.
    IG_METHODS = ["louvain-ig", "leiden-ig", "eigen-ig", "LPA-ig", "infomap-ig"]
    NX_METHODS = ["louvain-nx"]

    signal.signal(signal.SIGALRM, handler)

    for data in datasets:
        print(f"---------- {data} ----------")
        graph = create_graph(data)
        graph = keep_only_largest_cc(graph)
        compute_core_subgraph(graph, prop=core_proportions[data])
        core_entities = graph.core_subgraph_idx

        # Graph conversions required by the "-ig" / "-nx" methods, built once
        # per dataset and only when an enabled method needs them.
        needs_ig = any(m in IG_METHODS for m in methods)
        ig_graph = tf2igraph(graph.triples_factory) if needs_ig else None
        needs_nx = any(m in NX_METHODS for m in methods)
        nx_graph = tf2nx(graph.triples_factory) if needs_nx else None

        for method, partition_fn in methods.items():

            if method == "spectral" and data in ["yago3", "yago4"]:
                continue

            print(method)
            num_partition = num_partitions[data]

            # Keyword arguments expected by each family of partitioning functions.
            if method in ["argmax", "KMeans"]:
                kwargs = {"num_partitions": num_partition, "graph": graph, "n_steps": n_steps[data]}
            elif method == "diffusion":
                kwargs = {"graph": graph, "max_size": max_sizes[data], "ngens": ngens[data]}
            elif method in IG_METHODS:
                kwargs = {"ig_graph": ig_graph}
            elif method in NX_METHODS:
                kwargs = {"nx_graph": nx_graph}
            else:
                kwargs = {"num_partitions": num_partition, "graph": graph}

            # Set timer
            signal.alarm(TIMEOUT_S)

            try:
                start = time()
                memory, partitions = memory_usage(
                    (partition_fn, (), kwargs),
                    retval=True,
                    max_usage=True,
                    interval=time_intervals[data],
                )
                duration = time() - start
                # Stop the timer: the quality evaluation is not time-limited.
                signal.alarm(0)

                sizes, balance, nb_cc = evaluate_partitions(partitions, graph, core_entities)

            except Exception as exc:
                print(exc)
                print(f"{method} failed for {data}.")
                duration = None
                sizes = None
                balance = None
                nb_cc = None
                memory = None

            finally:
                # Never leave an alarm pending after a failed run.
                signal.alarm(0)

            # Save results: append one row to the results file.
            res_dict = {
                "data": data,
                "method": method,
                "num_partitions": num_partition,
                "balance": balance,
                "sizes": [sizes],
                "nb_cc": [nb_cc],
                "time": duration,
                "memory": memory,
            }

            new_results = pd.DataFrame(res_dict, index=[0])

            if results_path.is_file():
                results = pd.read_parquet(results_path)
                results = pd.concat([results, new_results]).reset_index(drop=True)

            else:
                results = new_results

            results.to_parquet(results_path, index=False)

        # Release this dataset's converted graphs before the next one is loaded.
        ig_graph = nx_graph = None

