from collections import Counter
from pathlib import Path
import pickle
import random as rd

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import pykeen.datasets

from SEPAL.dataloader import DataLoader

data_dir = Path(__file__).parent


def tf2nx(triples_factory):
    """
    Convert the mapped triples of a triples factory into a networkx graph.

    Every (head, relation, tail) row becomes an undirected edge between head
    and tail carrying the relation index in its ``relation_idx`` attribute.
    Because a plain ``nx.Graph`` is used, several triples between the same
    pair of entities share a single edge.
    """
    triples = np.array(triples_factory.mapped_triples)
    heads, relations, tails = triples[:,0], triples[:,1], triples[:,2]

    graph = nx.Graph()
    # Stream the edges instead of materialising an intermediate list.
    graph.add_edges_from(
        (head, tail, {"relation_idx": relation})
        for head, relation, tail in zip(heads, relations, tails)
    )
    return graph

def load_metadata(dataset):
    """
    Return a DataLoader pointing at the directory of `dataset`
    (resolved relative to the directory containing this file).
    """
    return DataLoader(data_dir / dataset)

def load_data(dataset):
    """
    Return the DataLoader of `dataset` together with the triples factory it
    provides (requested with ``create_inverse_triples=True``).
    """
    loader = DataLoader(data_dir / dataset)
    triples_factory = loader.get_triples_factory(create_inverse_triples=True)
    return loader, triples_factory

def basic_stats(dl):
    """
    Print the number of entities and relations reported by the data loader.
    """
    report = (
        ("Number of entities: ", "n_entities"),
        ("Number of relations: ", "n_relations"),
    )
    for label, attribute in report:
        print(label, getattr(dl, attribute))

def advanced_stats(dataset, tf):
    """
    Compute graph statistics for a triples factory, print them and append
    them as one row to ``dataset_stats.parquet`` in the data directory.

    The statistics cover entity/relation/triple counts, node degrees,
    connected components, an estimated mean shortest path length and an
    approximate diameter of the largest connected component, and density.
    All graph measures are computed on the undirected graph built by `tf2nx`.

    dataset: name stored in the "dataset" column and printed in the header.
    tf: triples factory exposing `mapped_triples` and `num_triples`.
    """
    print(f"------------- Dataset: {dataset} -------------")
    stats = {"dataset": dataset}

    # Entities and relations are counted from the triples themselves, i.e.
    # only those that actually occur in this factory are taken into account.
    n_entities = np.unique(tf.mapped_triples[:,[0,2]]).size
    stats["n_entities"] = n_entities
    print("Number of entities: ", n_entities)

    n_relations = np.unique(tf.mapped_triples[:,1]).size
    stats["n_relations"] = n_relations
    print("Number of relations: ", n_relations)

    stats["n_triples"] = tf.num_triples
    print("Number of triples: ", tf.num_triples)

    # Create networkx graph
    G = tf2nx(tf)

    # Degree analysis
    degree_values = np.array(list(dict(G.degree).values()))

    stats["max_degree"] = degree_values.max()
    print("Max degree: ", stats["max_degree"])

    stats["mean_degree"] = degree_values.mean()
    print("Average degree: ", stats["mean_degree"])

    stats["min_degree"] = degree_values.min()
    print("Min degree: ", stats["min_degree"])

    # Connected components analysis.
    # A single pass over the components yields their number, the largest one
    # and whether the graph is connected, instead of traversing the graph
    # three times. Only the largest component seen so far is kept in memory.
    n_components = 0
    largest_cc = set()
    for component in nx.connected_components(G):
        n_components += 1
        if len(component) > len(largest_cc):
            largest_cc = component

    stats["connected"] = n_components == 1
    print('connected graph: ', stats["connected"])

    stats["n_connected_components"] = n_components
    print("Number of connected components: ", stats["n_connected_components"])

    stats["lcc_size"] = len(largest_cc)
    # The share is relative to the entities present in the graph (n_entities),
    # so that it is consistent with the "n_entities" and "lcc_size" columns.
    print('Size of largest connected component: ', len(largest_cc), "(%.1f%%)" % (100*len(largest_cc)/n_entities))

    # Build the subgraph view once and reuse it for both measures.
    lcc_graph = G.subgraph(largest_cc)

    stats["mspl"] = estimate_MSPL(lcc_graph)
    print('Estimated mean shortest path length in largest connected component: ', stats["mspl"])

    stats["diameter"] = nx.approximation.diameter(lcc_graph)
    print('Approximate diameter of largest connected component:', stats["diameter"])

    # Density measures
    stats["density"] = nx.density(G)
    print("Density: ", stats["density"])

    #print("Transitivity: ", nx.transitivity(G)) # computationally very expensive

    # Save stats: append the new row to the existing table, if any.
    stats_path = data_dir / "dataset_stats.parquet"
    df = pd.DataFrame([stats], index=[0])
    if stats_path.is_file():
        df = pd.concat([pd.read_parquet(stats_path), df], ignore_index=True)
    df.to_parquet(stats_path, index=False)

    return

def estimate_MSPL(G, n=500):
    """
    Estimates the Mean Shortest Path Length.

    Node pairs are drawn at random without replacement (so no node is used
    twice) and the shortest path lengths between the two nodes of each pair
    are averaged.

    G: networkx graph. It is expected to be connected, since
       `nx.shortest_path_length` fails for a pair of nodes that are not
       linked by any path.
    n: maximum number of node pairs to sample; capped at half the number of
       nodes of G.

    Returns the mean sampled shortest path length as a float, or 0.0 when no
    pair can be sampled (fewer than two nodes, or n <= 0).
    """
    n_pairs = min(G.number_of_nodes() // 2, n)
    if n_pairs <= 0:
        # No pair to measure: avoid dividing by zero below.
        return 0.0

    sampled = rd.sample(list(G.nodes), 2*n_pairs)
    total_length = 0.0
    # Consecutive sampled nodes form the (source, target) pairs.
    for source, target in zip(sampled[::2], sampled[1::2]):
        total_length += nx.shortest_path_length(G, source=source, target=target)
    return total_length/n_pairs

def display_relation_distribution(dataset, log=True):
    """
    Plot a bar graph representing the number of triples per type of relation.

    The figure is saved as a PDF in the data directory
    (``log_relations_<dataset>.pdf`` when `log` is true,
    ``relations_<dataset>.pdf`` otherwise) and then displayed.

    dataset: name of the dataset directory, as accepted by `load_data`.
    log: if true, use a logarithmic scale for the number of triples.
    """
    dl, tf = load_data(dataset)

    # Count all relation indices in one pass rather than scanning the whole
    # list of triples once per relation.
    triples_per_idx = Counter(tf.mapped_triples[:,1].tolist())
    # Relations sorted by increasing number of triples.
    data = dict(sorted(
        ((rel, triples_per_idx[idx]) for rel, idx in dl.rel_to_idx.items()),
        key=lambda item: item[1],
    ))

    names = list(data.keys())

    plt.figure(figsize=(3,12))
    plt.barh(names, list(data.values()), color = mcolors.XKCD_COLORS)
    if log:
        plt.xscale('log')
    plt.xlabel("Number of triples")

    # Save before showing: once plt.show() returns the figure may already
    # have been closed, in which case savefig would write an empty page.
    if log:
        plt.savefig(data_dir / f"log_relations_{dataset}.pdf", bbox_inches='tight')
    else:
        plt.savefig(data_dir / f"relations_{dataset}.pdf", bbox_inches='tight')
    plt.show()
    return

def display_relation_distributions(datasets, log=True):
    """
    Plot, side by side, one bar graph per dataset representing the number of
    triples per type of relation.

    The figure is saved as a PDF in the data directory
    (``datasets_log_relations.pdf`` when `log` is true,
    ``datasets_relations.pdf`` otherwise).

    datasets: sequence of dataset names, as accepted by `load_data`.
              Nothing is plotted or saved when it is empty.
    log: if true, use a logarithmic scale for the number of triples.
    """
    if len(datasets) == 0:
        # A figure with zero columns cannot be created.
        return

    # squeeze=False keeps `axs` two-dimensional, so that a single dataset
    # is handled like any other number of datasets.
    fig, axs = plt.subplots(nrows=1, ncols=len(datasets), figsize=(4*len(datasets), 12), squeeze=False)

    for ax, dataset in zip(axs[0], datasets):
        dl, tf = load_data(dataset)
        # Count all relation indices in one pass rather than scanning the
        # whole list of triples once per relation.
        triples_per_idx = Counter(tf.mapped_triples[:,1].tolist())
        # Relations sorted by increasing number of triples.
        data = dict(sorted(
            ((rel, triples_per_idx[idx]) for rel, idx in dl.rel_to_idx.items()),
            key=lambda item: item[1],
        ))
        names = list(data.keys())
        ax.barh(names, list(data.values()), color = mcolors.XKCD_COLORS)
        if log:
            ax.set_xscale('log')
        ax.set_xlabel("Number of triples")
        ax.set_title(dataset)
    fig.tight_layout()
    if log:
        fig.savefig(data_dir / "datasets_log_relations.pdf", bbox_inches='tight')
    else:
        fig.savefig(data_dir / "datasets_relations.pdf", bbox_inches='tight')
    return

def make_stats(datasets):
    """
    Run `advanced_stats` on the triples factory of every dataset in `datasets`.
    """
    for name in datasets:
        # Only the triples factory is needed here; the loader is discarded.
        _, triples_factory = load_data(name)
        advanced_stats(name, triples_factory)

def make_train_test_val_stats(datasets):
    """
    Run `advanced_stats` on the pickled training, testing and validation
    triples factories of every dataset in `datasets`.

    The factories are read from ``training_tf.pkl``, ``testing_tf.pkl`` and
    ``validation_tf.pkl`` inside the dataset directory, and reported under the
    names ``<dataset>_train``, ``<dataset>_test`` and ``<dataset>_validation``.
    """
    # (suffix of the reported name, pickle file), in processing order.
    splits = (
        ("_train", "training_tf.pkl"),
        ("_test", "testing_tf.pkl"),
        ("_validation", "validation_tf.pkl"),
    )
    for data in datasets:
        for suffix, filename in splits:
            # NOTE: pickle must only be used on trusted, locally produced files.
            with open(data_dir / f"{data}/{filename}", "rb") as handle:
                triples_factory = pickle.load(handle)
            advanced_stats(data + suffix, triples_factory)

def literature_datasets_stats(datasets):
    """
    Run `advanced_stats` on the training, testing and validation splits of
    every dataset of `pykeen.datasets` whose class name is listed in `datasets`.
    """
    # (suffix of the reported name, attribute of the dataset), in processing order.
    splits = (
        ("_train", "training"),
        ("_test", "testing"),
        ("_validation", "validation"),
    )
    for dataset_name in datasets:
        dataset_cls = getattr(pykeen.datasets, dataset_name)
        data = dataset_cls()
        for suffix, attribute in splits:
            advanced_stats(dataset_name + suffix, getattr(data, attribute))

def make_latex_table():
    """
    Render the statistics stored in ``dataset_stats.parquet`` (in the data
    directory) as a LaTeX table.

    Returns the LaTeX source of the table as a string, without the index
    column.
    """
    stats_path = data_dir / "dataset_stats.parquet"
    df = pd.read_parquet(stats_path)
    # to_latex only returns the table; hand it back to the caller instead of
    # discarding it.
    return df.to_latex(index=False)

if __name__ == "__main__":
    # Datasets to process; uncomment an entry to include it in the run.
    selected_datasets = [
        # "mini_yago3_lcc",
        # "mini_yago3",
        # "core_yago3",
        # "core_yago4",
        # "core_yago4.5",
        # "rel_core_yago3",
        # "rel_core_yago4",
        # "rel_core_yago4.5",
        # "yago3_lcc",
        # "yago3",
        # "yago4_lcc",
        # "yago4",
        # "yago4.5_lcc",
        # "yago4.5",
        # "yago4_with_ontology",
        # "yago4_with_full_ontology",
        # "yago4.5_with_ontology",
        # "yago4.5_with_full_ontology",
        "full_freebase_lcc",
        # "full_freebase",
        # "wikikg90mv2",
        "wikikg90mv2_lcc",
    ]

    # Statistics of the whole graphs (disabled):
    # make_stats(selected_datasets)

    # Statistics of the pickled training / testing / validation splits.
    make_train_test_val_stats(selected_datasets)

    # Statistics of benchmark datasets shipped with pykeen (disabled):
    # literature_datasets_stats(["FB15k","FB15k237","WN18","WN18RR", "YAGO310"])
    