from pathlib import Path
from time import time
import pickle

from SEPAL.settings import set_control_params
from SEPAL.utils import (
    store_embeddings,
    create_graph,
    create_train_graph,
    measure_performance,
    reorder_outer_subgraphs,
    subgraph_hash,
)
from SEPAL.core_extraction import extract_subgraph
from SEPAL.subgraph_generation import make_subgraphs
from SEPAL.propagation import propagate_embeddings
import SEPAL.embed as embed_module
from SEPAL import SEPAL_DIR


def sepal(ctrl):
    """Run the SEPAL pipeline for one configuration and store the embeddings.

    Steps: load the knowledge graph, extract the core subgraph, build the
    subgraphs (or reload them from a pickle cache), embed the core subgraph
    with the function of ``SEPAL.embed`` named by ``ctrl.embed_method``,
    propagate the embeddings, and pass the result to ``store_embeddings``.

    The function returns immediately, after printing a message, when
    ``SEPAL_DIR / "embeddings/<ctrl.id>.npy"`` already exists.

    Parameters
    ----------
    ctrl
        Control-parameter object. Attributes read here: ``id``, ``subset``,
        ``data``, ``create_inverse_triples``, ``time_interval``,
        ``reorder_subgraphs`` and ``embed_method``. Timing and memory-usage
        attributes (``*_time``, ``*_mem_usage``), ``load_subgraphs``,
        ``subgraph_time`` and ``total_time`` are written onto it.
        NOTE: ``cover_subgraph_time`` / ``cover_subgraph_mem_usage`` are only
        written when the subgraphs are computed, not when they are reloaded
        from the cache file.

    Returns
    -------
    None
    """
    # Skip the run when this configuration has already been embedded
    embedding_path = SEPAL_DIR / f"embeddings/{ctrl.id}.npy"
    if embedding_path.is_file():
        print("The embedding file for this configuration already exists.")
        return

    ### Load Knowledge Graph
    build_graph = create_train_graph if ctrl.subset == "train" else create_graph
    graph = build_graph(ctrl.data, ctrl.create_inverse_triples)

    ### Set time intervals for memory usage monitoring
    ti = ctrl.time_interval

    ### Compute Embeddings
    run_start = time()

    ## Step 1: Generate subgraphs
    step1_start = time()

    core_result = measure_performance(extract_subgraph)(ti, ctrl, graph)
    core_graph, ctrl.core_subgraph_time, ctrl.core_subgraph_mem_usage = core_result

    # The cache file name comes from subgraph_hash(ctrl)
    subgraphs_file = subgraph_hash(ctrl)
    cache_hit = Path(subgraphs_file).is_file()
    ctrl.load_subgraphs = cache_hit

    if cache_hit:
        with open(subgraphs_file, "rb") as f:
            subgraphs = pickle.load(f)
    else:
        cover_result = measure_performance(make_subgraphs)(ti, ctrl, graph)
        subgraphs, ctrl.cover_subgraph_time, ctrl.cover_subgraph_mem_usage = (
            cover_result
        )
        # Save subgraphs for next time
        with open(subgraphs_file, "wb") as f:
            pickle.dump(subgraphs, f)

    ctrl.subgraph_time = time() - step1_start

    if ctrl.reorder_subgraphs:
        subgraphs = reorder_outer_subgraphs(graph, subgraphs)

    ## Step 2: Embed core subgraph
    method_name = ctrl.embed_method
    # "tucker" additionally returns a core tensor that the later steps need
    uses_core_tensor = method_name == "tucker"

    timed_embed = measure_performance(getattr(embed_module, method_name))
    embed_output, embed_time, embed_mem_usage = timed_embed(ti, ctrl, core_graph)
    if uses_core_tensor:
        core_tensor, core_embed, relations_embed = embed_output
    else:
        core_embed, relations_embed = embed_output
    ctrl.embed_time = embed_time
    ctrl.embed_mem_usage = embed_mem_usage

    # Extra keyword argument forwarded to propagation and storage for tucker
    tucker_kwargs = {"core_tensor": core_tensor} if uses_core_tensor else {}

    ## Step 3: Propagate embeddings
    propagation_result = measure_performance(propagate_embeddings)(
        ti, ctrl, graph, core_embed, relations_embed, subgraphs, **tucker_kwargs
    )
    embeddings, ctrl.propagation_time, ctrl.propagation_mem_usage = (
        propagation_result
    )

    ctrl.total_time = time() - run_start

    ### Store embeddings
    store_embeddings(
        ctrl,
        embeddings,
        relations_embed,
        embedding_path,
        checkpoint_path=SEPAL_DIR / "checkpoints_sepal.parquet",
        **tucker_kwargs,
    )


def sepal_wrapper(data="mini_yago3_lcc", gpu=0, **kwargs):
    """
    Builds the control parameters and runs the SEPAL model with them.
    --------------
    data: Passed as first argument to set_control_params.
    gpu: The gpu to use (-1 to use cpu).
    kwargs: Forwarded unchanged to set_control_params.
    """
    # Build the control object, then run the model on it
    sepal(set_control_params(data, gpu, **kwargs))


if __name__ == "__main__":
    # (dataset name, subgraph_max_size) pairs to run; uncomment to enable.
    dataset_configs = [
        # ("mini_yago3_lcc", 4e4),
        # ("yago3_lcc", 4e5),
        # ("yago4.5_lcc", 4e6),
        # ("yago4.5_with_full_ontology", 4e6),
        # ("yago4_lcc", 4e6),
        # ("yago4_with_full_ontology", 4e6),
        # ("full_freebase_lcc", 4e6),
        ("wikikg90mv2_lcc", 2e6),
    ]

    for subset in ("train",):
        for dataset, max_size in dataset_configs:
            # Second positional argument is the gpu passed to sepal_wrapper
            sepal_wrapper(
                dataset,
                1,
                subgraph_max_size=max_size,
                embed_method="distmult",
                partitioning="blocs",
                subset=subset,
                core_selection="hybrid",
                propagation_type="normalized_sum",
                num_negs_per_pos=100,
            )
