import h5py
import json
import pickle
import numpy as np
import pandas as pd
from pathlib import Path

from SEPAL import SEPAL_DIR


def load_PBG_embeddings(id: str) -> np.ndarray:
    """Load the entity embeddings of a PBG run.

    The run is looked up by ``id`` in ``baselines/PBG/checkpoints_pbg.parquet``
    (relative to ``SEPAL_DIR``); the ``embeddings`` dataset of the matching
    checkpoint file is then read into memory.

    Parameters
    ----------
    id : str
        Value of the ``id`` column identifying the run. If several rows share
        this id, the first one is used.

    Returns
    -------
    np.ndarray
        Content of the ``embeddings`` dataset.

    Raises
    ------
    IndexError
        If no row of the checkpoints table has this ``id``.
    """
    pbg_checkpoints = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"
    )
    matches = pbg_checkpoints.loc[pbg_checkpoints["id"] == id]
    if matches.empty:
        # Same exception type a bare ``.iloc[0]`` would raise on an empty
        # selection, but with a message that names the missing id.
        raise IndexError(f"No PBG checkpoint found with id {id!r}")
    row = matches.iloc[0]

    # The file name carries a version number: the product of the number of
    # edge paths, edge chunks and epochs recorded for the run.
    version = len(row.edge_paths) * row.num_edge_chunks * row.num_epochs
    embeddings_path = (
        SEPAL_DIR / row.checkpoint_path / f"embeddings_all_0.v{version}.h5"
    )
    with h5py.File(embeddings_path, "r") as file:
        embeddings = file["embeddings"][()]  # returns as a numpy array
    return embeddings


def load_PBG_relation_embeddings(id: str) -> np.ndarray:
    """Load the relation parameters of a PBG run.

    The run is looked up by ``id`` in ``baselines/PBG/checkpoints_pbg.parquet``
    (relative to ``SEPAL_DIR``); the dataset
    ``model/relations/0/operator/rhs/diagonals`` of the matching model file is
    then read into memory.

    Parameters
    ----------
    id : str
        Value of the ``id`` column identifying the run. If several rows share
        this id, the first one is used.

    Returns
    -------
    np.ndarray
        Content of the ``model/relations/0/operator/rhs/diagonals`` dataset.

    Raises
    ------
    IndexError
        If no row of the checkpoints table has this ``id``.
    """
    pbg_checkpoints = pd.read_parquet(
        SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"
    )
    matches = pbg_checkpoints.loc[pbg_checkpoints["id"] == id]
    if matches.empty:
        # Same exception type a bare ``.iloc[0]`` would raise on an empty
        # selection, but with a message that names the missing id.
        raise IndexError(f"No PBG checkpoint found with id {id!r}")
    row = matches.iloc[0]

    # The file name carries a version number: the product of the number of
    # edge paths, edge chunks and epochs recorded for the run.
    version = len(row.edge_paths) * row.num_edge_chunks * row.num_epochs
    model_path = SEPAL_DIR / row.checkpoint_path / f"model.v{version}.h5"
    with h5py.File(model_path, "r") as file:
        rel_embeddings = file["model/relations/0/operator/rhs/diagonals"][
            ()
        ]  # returns as a numpy array

    return rel_embeddings


def read_embeddings(file_path: Path) -> np.ndarray:
    """Read the first top-level entry of an HDF5 file into memory.

    "First" means the first name returned by the file's ``keys()``.

    Parameters
    ----------
    file_path : Path
        Location of the ``.h5`` file.

    Returns
    -------
    np.ndarray
        Full content of that entry (presumably an embedding matrix; the
        shape is not checked here).

    Raises
    ------
    IndexError
        If the file has no top-level entry.
    """
    with h5py.File(file_path, "r") as h5_file:
        top_level_names = list(h5_file.keys())
        first_name = top_level_names[0]
        # Indexing with () reads the whole entry, so the data stays valid
        # after the file is closed.
        return h5_file[first_name][()]


def read_edges(file_path: Path) -> np.ndarray:
    """Read an edge list from an HDF5 file.

    The ``lhs``, ``rel`` and ``rhs`` datasets are read in that order, stacked
    vertically and transposed, so they become the columns of the result.

    Parameters
    ----------
    file_path : Path
        Location of the ``.h5`` file.

    Returns
    -------
    np.ndarray
        Array whose columns are (``lhs``, ``rel``, ``rhs``); one row per edge,
        assuming the three datasets are 1-D and of equal length.
    """
    with h5py.File(file_path, "r") as h5_file:
        columns = [h5_file[name][()] for name in ("lhs", "rel", "rhs")]
    return np.vstack(columns).T


def get_pbg2sepal_reordering(data):
    """Compute the index permutations mapping PBG ordering to SEPAL ordering.

    SEPAL and PBG do not index entities and relations in the same way. For
    each entity (resp. relation) name, taken in the iteration order of the
    SEPAL mapping, the returned list holds the index of that name in PBG.

    Parameters
    ----------
    data
        Dataset name, used as the sub-directory under both
        ``datasets/knowledge_graphs/`` and ``baselines/PBG/data/``.

    Returns
    -------
    tuple
        ``(entity_reordering, relation_reordering)``, two lists of PBG indices.

    Raises
    ------
    KeyError
        If a SEPAL entity or relation name is absent from the PBG name files.
    """

    def _index_by_name(names):
        # Position of each name in the PBG list.
        return {name: position for position, name in enumerate(names)}

    # SEPAL side: the pickled training triples factory carries both mappings.
    sepal_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
    with open(sepal_dir / "training_tf.pkl", "rb") as pkl_file:
        triples_factory = pickle.load(pkl_file)
    sepal_entities = triples_factory.entity_to_id
    sepal_relations = triples_factory.relation_to_id

    # PBG side: names are stored as JSON lists, whose position is the index.
    pbg_dir = SEPAL_DIR / f"baselines/PBG/data/{data}"
    with open(pbg_dir / "entity_names_all_0.json") as json_file:
        pbg_entity_names = json.load(json_file)
    with open(pbg_dir / "dynamic_rel_names.json") as json_file:
        pbg_relation_names = json.load(json_file)
    pbg_entity_index = _index_by_name(pbg_entity_names)
    pbg_relation_index = _index_by_name(pbg_relation_names)

    # Look up the PBG index of every SEPAL name, following SEPAL's key order.
    entity_reordering = [pbg_entity_index[name] for name in sepal_entities.keys()]
    relation_reordering = [
        pbg_relation_index[name] for name in sepal_relations.keys()
    ]
    return entity_reordering, relation_reordering


def load_pbg_checkpoints():
    """Load the PBG checkpoints table and add derived columns.

    Reads ``baselines/PBG/checkpoints_pbg.parquet`` (relative to ``SEPAL_DIR``)
    and then:

    * copies ``training_time`` into ``total_time``;
    * sets ``embed_method`` to ``"distmult"`` on rows whose first relation uses
      the ``"diagonal"`` operator;
    * copies ``dimension`` into ``embed_dim``;
    * sets ``method`` to ``"PyTorch-BigGraph"`` on every row;
    * drops the rows whose ``id`` starts with ``"PBG - core"``.

    Returns
    -------
    pandas.DataFrame
        The augmented and filtered checkpoints table.
    """
    checkpoints_file = SEPAL_DIR / "baselines/PBG/checkpoints_pbg.parquet"
    checkpoints = pd.read_parquet(checkpoints_file)

    checkpoints["total_time"] = checkpoints["training_time"]

    # Only the operator of the first relation is inspected.
    first_operator = checkpoints["relations"].apply(
        lambda relations: relations[0]["operator"]
    )
    checkpoints.loc[first_operator == "diagonal", "embed_method"] = "distmult"

    checkpoints["embed_dim"] = checkpoints["dimension"]
    checkpoints["method"] = "PyTorch-BigGraph"

    # Remove ids starting with "PBG - core"
    is_core_run = checkpoints["id"].str.startswith("PBG - core")
    return checkpoints[~is_core_run]