from pathlib import Path
import networkx as nx
import numpy as np
import pandas as pd
import pickle
import torch
import scipy.sparse
from functools import wraps
from time import time
from memory_profiler import memory_usage
from pykeen.triples import TriplesFactory

from SEPAL.knowledge_graph import KnowledgeGraph
from SEPAL.dataloader import DataLoader
from SEPAL import SEPAL_DIR



def create_graph(data, create_inverse_triples=True):
    """Load a dataset's triples and wrap them in a ``KnowledgeGraph``.

    Parameters
    ----------
    data: name of the dataset folder under
        ``SEPAL_DIR / "datasets/knowledge_graphs"``.
    create_inverse_triples: forwarded to ``DataLoader.get_triples_factory``.

    Returns
    -------
    A KnowledgeGraph object built from the loaded triples factory.
    """
    print(f"Loading {data} data...")
    loader = DataLoader(SEPAL_DIR / f"datasets/knowledge_graphs/{data}")
    triples_factory = loader.get_triples_factory(
        create_inverse_triples=create_inverse_triples
    )
    print("Building knowledge graph...")
    return KnowledgeGraph(triples_factory)


def create_train_graph(data, create_inverse_triples=True):
    """Build a ``KnowledgeGraph`` from a dataset's pickled training triples.

    The pickled factory stored in ``training_tf.pkl`` is not used directly:
    its mapped triples and its entity/relation mappings are fed to a fresh
    ``TriplesFactory`` so that ``create_inverse_triples`` can be chosen here.

    Parameters
    ----------
    data: name of the dataset folder under
        ``SEPAL_DIR / "datasets/knowledge_graphs"``.
    create_inverse_triples: forwarded to the ``TriplesFactory`` constructor.

    Returns
    -------
    A KnowledgeGraph object built from the training triples.
    """
    print(f"Loading {data} data...")
    pickle_path = SEPAL_DIR / f"datasets/knowledge_graphs/{data}/training_tf.pkl"
    with open(pickle_path, "rb") as handle:
        pickled_tf = pickle.load(handle)

    training_factory = TriplesFactory(
        mapped_triples=pickled_tf.mapped_triples,
        entity_to_id=pickled_tf.entity_to_id,
        relation_to_id=pickled_tf.relation_to_id,
        create_inverse_triples=create_inverse_triples,
    )
    print("Building knowledge graph...")
    return KnowledgeGraph(training_factory)


def tf2nx(triples_factory):
    """Convert a triples factory into a NetworkX ``MultiDiGraph``.

    Each mapped triple (after the factory has added inverse triples, if it
    is configured to) becomes one directed edge: column 0 is the source
    node, column 2 the target node, and column 1 is stored on the edge as
    the ``relation_idx`` attribute.

    Parameters
    ----------
    triples_factory: object exposing ``mapped_triples`` and
        ``_add_inverse_triples_if_necessary``.

    Returns
    -------
    A networkx.MultiDiGraph with one edge per triple.
    """
    print("Creating NetworkX object...")
    triples = np.array(
        triples_factory._add_inverse_triples_if_necessary(
            triples_factory.mapped_triples
        )
    )
    sources, relations, targets = triples[:, 0], triples[:, 1], triples[:, 2]
    edges = [
        (source, target, {"relation_idx": relation})
        for source, relation, target in zip(sources, relations, targets)
    ]

    multigraph = nx.MultiDiGraph()
    multigraph.add_edges_from(edges)
    return multigraph


def store_embeddings(
    ctrl, embeddings, relations_embed, embedding_path, checkpoint_path, **kwargs
):
    """Save the embeddings to disk and log the run in the checkpoint table.

    Parameters
    ----------
    ctrl: run controller; ``get_config()``, ``id`` and ``embed_method`` are read.
    embeddings: entity embeddings, written to ``embedding_path`` with ``np.save``.
    relations_embed: relation embeddings; ``.cpu()`` is called before saving
        to ``SEPAL_DIR / "embeddings/{ctrl.id}_relations.npy"``.
    embedding_path: destination of the entity embeddings.
    checkpoint_path: parquet file holding one row per stored run. It is
        created if missing, otherwise the new row is appended.
    **kwargs: must contain ``core_tensor`` when ``ctrl.embed_method`` is
        ``"tucker"``; it is saved to ``SEPAL_DIR / "embeddings/{ctrl.id}_tensor.npy"``.

    Returns
    -------
    None
    """
    # One-row frame describing this run (parameters and model state info)
    run_info = pd.DataFrame(ctrl.get_config(), index=[0])

    # Write the arrays
    np.save(embedding_path, embeddings)
    np.save(SEPAL_DIR / f"embeddings/{ctrl.id}_relations.npy", relations_embed.cpu())
    if ctrl.embed_method == "tucker":
        np.save(
            SEPAL_DIR / f"embeddings/{ctrl.id}_tensor.npy",
            kwargs["core_tensor"].cpu(),
        )

    # Append the run to the existing checkpoint table, if there is one
    history = run_info
    if Path(checkpoint_path).is_file():
        previous_runs = pd.read_parquet(checkpoint_path)
        history = pd.concat([previous_runs, run_info]).reset_index(drop=True)

    history.to_parquet(checkpoint_path, index=False)


def get_partition_list_from_df(df):
    """Extract the partitions stored column-wise in a dataframe.

    Columns ``Partition_0`` ... ``Partition_{n-1}`` are read, where ``n`` is
    the number of columns of ``df``. Missing values (used as padding) are
    dropped from each column.

    Parameters
    ----------
    df: a pandas DataFrame with one ``Partition_k`` column per partition.

    Returns
    -------
    A list of int32 numpy arrays, one per partition.
    """
    partitions = []
    for k in range(df.shape[1]):
        column = df[f"Partition_{k}"]
        partitions.append(column[column.notna()].astype(np.int32).values)
    return partitions


def keep_only_largest_cc(graph):
    """Remove all connected components from the graph except the largest one.

    Components are computed on ``graph.adjacency`` treated as undirected.
    Triples whose head and tail both belong to the largest component are
    kept, and their entities are re-indexed to ``0 .. n-1`` following the
    ascending order of their original ids. Relation ids are left untouched.

    Parameters
    ----------
    graph: a KnowledgeGraph object.

    Returns
    -------
    new_graph: a KnowledgeGraph object, corresponding to the largest component of the graph.
    """
    print("Computing largest connected component...")
    _, labels = scipy.sparse.csgraph.connected_components(
        graph.adjacency, directed=False, return_labels=True
    )
    # np.where yields indices in ascending order, so node_list is sorted;
    # the re-indexing below relies on that.
    node_list = np.where(labels == np.argmax(np.bincount(labels)))[0]

    old_tf = graph.triples_factory
    # Keep the triples whose two endpoints are both in the largest component.
    # from_numpy keeps the int64 ids as they are instead of narrowing to int32.
    mask = torch.isin(
        old_tf.mapped_triples[:, [0, 2]], torch.from_numpy(node_list)
    ).all(axis=1)
    # Boolean-mask indexing returns a new tensor, so the in-place update
    # below does not touch the triples of the original factory.
    mapped_triples = old_tf.mapped_triples[mask]

    # Reindex subgraph entities between 0 and n-1. Every remaining endpoint
    # is an element of the sorted node_list, hence its new id is its position
    # in node_list. This is vectorised and also works when no triple is left.
    endpoints = mapped_triples[:, [0, 2]].numpy()
    mapped_triples[:, [0, 2]] = torch.from_numpy(
        np.searchsorted(node_list, endpoints)
    )

    # Build the new entity_to_id dictionary
    old_id_to_entity = {v: k for k, v in old_tf.entity_to_id.items()}
    entity_to_id = {
        old_id_to_entity[old_id]: new_id
        for new_id, old_id in enumerate(node_list.tolist())
    }

    # Create triple factory object
    triples_factory = TriplesFactory(
        mapped_triples=mapped_triples,
        entity_to_id=entity_to_id,
        relation_to_id=old_tf.relation_to_id,
        create_inverse_triples=old_tf.create_inverse_triples,
    )

    # Build knowledge_graph instance
    new_graph = KnowledgeGraph(triples_factory)

    return new_graph


def measure_performance(func):
    """Decorator measuring the run time and peak memory usage of ``func``.

    The decorated function takes an extra leading argument,
    ``time_interval``, which is passed to ``memory_usage`` as its sampling
    ``interval``. All remaining positional and keyword arguments are
    forwarded to ``func``.

    Returns
    -------
    A wrapper returning the tuple
    ``(result, execution_time, peak_memory_usage)``, where ``execution_time``
    is the elapsed wall-clock time in seconds and ``peak_memory_usage`` is the
    maximum reported by ``memory_usage``.
    """

    @wraps(func)
    def profiled(time_interval, *args, **kwargs):
        started_at = time()
        # memory_usage runs func itself; with retval=True it returns
        # (memory measurement, return value of func).
        peak_memory_usage, result = memory_usage(
            (func, args, kwargs),
            retval=True,
            max_usage=True,
            interval=time_interval,
        )
        execution_time = time() - started_at
        return result, execution_time, peak_memory_usage

    return profiled
