from pathlib import Path
import pandas as pd
import json
import pickle
import numpy as np

from SEPAL import SEPAL_DIR


def load_dglke_checkpoint_info():
    """Gather the metadata of every DGL-KE checkpoint into a single DataFrame.

    Each sub-directory of ``SEPAL_DIR/baselines/dglke/ckpts`` that contains a
    ``config.json`` file contributes one row. The row is the content of
    ``config.json`` updated with the content of ``training_info.json``, after
    the following normalization:

    * ``id`` is set to ``"DGLKE_" + <directory name>``;
    * ``dataset`` is renamed to ``data``; a leading ``"full_"`` is stripped
      unless the name starts with ``"full_freebase"``; in every other case
      ``subset`` is set to ``"train"``;
    * ``emb_size`` is renamed to ``embed_dim``;
    * ``model`` is lower-cased and renamed to ``embed_method``;
    * ``method`` is set to ``"DGL-KE"``;
    * ``initialize_time``, when present, is added to ``total_time``.

    Returns:
        pandas.DataFrame: one row per checkpoint, in sorted directory order.
        Every row keeps the index label 0 (the rows are concatenated without
        re-indexing). An empty DataFrame is returned when no checkpoint is found.

    Raises:
        FileNotFoundError: if the checkpoint directory does not exist, or if a
            directory has a ``config.json`` but no ``training_info.json``.
        KeyError: if a merged config lacks ``dataset``, ``emb_size``, ``model``,
            or has ``initialize_time`` without ``total_time``.
    """
    checkpoint_dir = Path(SEPAL_DIR, "baselines", "dglke", "ckpts")

    # Collect one single-row frame per checkpoint and concatenate once at the
    # end: calling pd.concat inside the loop re-copies all previous rows on
    # every iteration.
    rows = []
    # iterdir() yields entries in arbitrary order; sorting makes the row order
    # reproducible across runs and file systems.
    for model_dir in sorted(checkpoint_dir.iterdir()):
        config_path = model_dir / "config.json"
        # Entries without a config.json are not usable checkpoints: skip them
        if not config_path.exists():
            continue

        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Merge the training information into the config (training_info wins
        # on key collisions)
        with open(model_dir / "training_info.json", "r", encoding="utf-8") as f:
            config.update(json.load(f))

        config["id"] = "DGLKE_" + model_dir.name

        # "full_<name>" datasets lose their prefix, except "full_freebase*";
        # every other dataset is tagged with subset="train"
        dataset = config.pop("dataset")
        if dataset.startswith("full_") and not dataset.startswith("full_freebase"):
            dataset = dataset.removeprefix("full_")
        else:
            config["subset"] = "train"
        config["data"] = dataset

        config["embed_dim"] = config.pop("emb_size")
        config["embed_method"] = config.pop("model").lower()
        config["method"] = "DGL-KE"

        # Report initialization + training as a single total time
        if "initialize_time" in config:
            config["total_time"] = config["initialize_time"] + config["total_time"]

        rows.append(pd.DataFrame([config]))

    # pd.concat raises on an empty list, so handle "no checkpoint" explicitly
    if not rows:
        return pd.DataFrame()
    return pd.concat(rows)


def load_dglke_embeddings(ckpt_id):
    """Load the entity embeddings stored in a DGL-KE checkpoint.

    Args:
        ckpt_id: checkpoint identifier; a leading ``"DGLKE_"`` is stripped to
            obtain the checkpoint directory name under
            ``SEPAL_DIR/baselines/dglke/ckpts``.

    Returns:
        The object returned by ``numpy.load`` for the file of the checkpoint
        directory whose name ends with ``entity.npy``.

    Raises:
        FileNotFoundError: if no ``*entity.npy`` file exists in the checkpoint
            directory.
    """
    model_name = ckpt_id.removeprefix("DGLKE_")
    model_dir = Path(SEPAL_DIR, "baselines", "dglke", "ckpts", model_name)

    # Sort the matches so the choice is deterministic if several files match
    # (glob yields them in arbitrary order).
    candidates = sorted(model_dir.glob("*entity.npy"))
    if not candidates:
        # An explicit error instead of the bare StopIteration that
        # next(glob(...)) would leak to the caller.
        raise FileNotFoundError(f"No '*entity.npy' file found in {model_dir}")

    return np.load(candidates[0])


def load_dglke_relation_embeddings(ckpt_id):
    """Load the relation embeddings stored in a DGL-KE checkpoint.

    Args:
        ckpt_id: checkpoint identifier; a leading ``"DGLKE_"`` is stripped to
            obtain the checkpoint directory name under
            ``SEPAL_DIR/baselines/dglke/ckpts``.

    Returns:
        The object returned by ``numpy.load`` for the file of the checkpoint
        directory whose name ends with ``relation.npy``.

    Raises:
        FileNotFoundError: if no ``*relation.npy`` file exists in the
            checkpoint directory.
    """
    model_name = ckpt_id.removeprefix("DGLKE_")
    model_dir = Path(SEPAL_DIR, "baselines", "dglke", "ckpts", model_name)

    # Sort the matches so the choice is deterministic if several files match
    # (glob yields them in arbitrary order).
    candidates = sorted(model_dir.glob("*relation.npy"))
    if not candidates:
        # An explicit error instead of the bare StopIteration that
        # next(glob(...)) would leak to the caller.
        raise FileNotFoundError(f"No '*relation.npy' file found in {model_dir}")

    return np.load(candidates[0])


def get_dglke2sepal_reordering(data):
    """Return the index lists that map DGL-KE ordering onto SEPAL ordering.

    SEPAL and DGLKE do not index entities and relations in the same way.
    This function returns the reordering needed to match DGLKE indices with
    SEPAL indices.

    Args:
        data: name of the dataset; used as the directory name under both
            ``SEPAL_DIR/datasets/knowledge_graphs`` and
            ``SEPAL_DIR/baselines/dglke/data``.

    Returns:
        tuple[list[int], list[int]]: ``(entity_reordering, relation_reordering)``.
        The i-th element of each list is the DGL-KE position (row of the name
        in ``entities.tsv`` / ``relations.tsv``) of the i-th key of the SEPAL
        ``entity_to_id`` / ``relation_to_id`` mapping, in that mapping's
        iteration order (presumably sorted by SEPAL id -- to be confirmed
        against the triples factory).

    Raises:
        KeyError: if a SEPAL entity or relation name is absent from the
            DGL-KE files.
    """
    # Load SEPAL training triples factory
    sepal_data_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
    # NOTE(review): pickle.load can execute arbitrary code; this file must only
    # ever be a trusted, locally generated artifact.
    with open(sepal_data_dir / "training_tf.pkl", "rb") as f:
        tf = pickle.load(f)
    sepal_entity_to_idx = tf.entity_to_id
    sepal_relation_to_idx = tf.relation_to_id

    # Load DGLKE entity and relation names (second column of the TSV files).
    # keep_default_na=False keeps names such as "NA", "null" or "nan" as
    # strings; with the pandas defaults they are parsed as float NaN and the
    # lookups below fail with a KeyError.
    dglke_data_dir = SEPAL_DIR / f"baselines/dglke/data/{data}"
    entity_names = pd.read_csv(
        dglke_data_dir / "entities.tsv",
        sep="\t",
        header=None,
        index_col=0,
        keep_default_na=False,
    )[1].tolist()
    relation_names = pd.read_csv(
        dglke_data_dir / "relations.tsv",
        sep="\t",
        header=None,
        index_col=0,
        keep_default_na=False,
    )[1].tolist()

    # DGL-KE index of a name = its row position in the file
    dglke_entity_to_idx = {name: idx for idx, name in enumerate(entity_names)}
    dglke_relation_to_idx = {name: idx for idx, name in enumerate(relation_names)}

    # Get the reorderings: walk the SEPAL names and look up their DGL-KE index
    entity_reordering = [dglke_entity_to_idx[name] for name in sepal_entity_to_idx]
    relation_reordering = [
        dglke_relation_to_idx[name] for name in sepal_relation_to_idx
    ]

    return entity_reordering, relation_reordering
