import h5py
import numpy as np
from pathlib import Path

from SEPAL import SEPAL_DIR


def load_PBG_embeddings(data: str, subset: str) -> np.ndarray:
    """Load the PBG embeddings saved for a given dataset and subset.

    The file is looked up under
    ``SEPAL_DIR / "baselines/PBG/model/<data>/<subset>/"``.

    Args:
        data: Dataset name, used as a directory component of the path.
        subset: Either ``"all"`` or ``"train"``. It selects the directory and
            the file-name suffix (``v1200`` for ``"all"``, ``v400`` for
            ``"train"``).

    Returns:
        The content of the ``"embeddings"`` dataset as a numpy array.

    Raises:
        ValueError: If ``subset`` is neither ``"all"`` nor ``"train"``.
    """
    # File-name suffix of the embeddings file for each supported subset.
    suffixes = {"all": "v1200", "train": "v400"}
    # Fail early with a clear message instead of hitting an unbound
    # local variable further down.
    if subset not in suffixes:
        raise ValueError(f"subset must be 'all' or 'train', got {subset!r}")

    suffix = suffixes[subset]
    embeddings_path = (
        SEPAL_DIR
        / f"baselines/PBG/model/{data}/{subset}/embeddings_all_0.{suffix}.h5"
    )
    with h5py.File(embeddings_path, "r") as file:
        embeddings = file["embeddings"][()]  # returns as a numpy array
    return embeddings


def load_PBG_relation_embeddings(data: str, subset: str) -> np.ndarray:
    """Load the relation parameters stored in a saved PBG model file.

    The file is looked up under
    ``SEPAL_DIR / "baselines/PBG/model/<data>/<subset>/"``.

    Args:
        data: Dataset name, used as a directory component of the path.
        subset: Either ``"all"`` or ``"train"``. It selects the directory and
            the file-name suffix (``v1200`` for ``"all"``, ``v400`` for
            ``"train"``).

    Returns:
        The content of the ``"model/relations/0/operator/rhs/diagonals"``
        dataset as a numpy array.

    Raises:
        ValueError: If ``subset`` is neither ``"all"`` nor ``"train"``.
    """
    # File-name suffix of the model file for each supported subset.
    suffixes = {"all": "v1200", "train": "v400"}
    # Fail early with a clear message instead of hitting an unbound
    # local variable further down.
    if subset not in suffixes:
        raise ValueError(f"subset must be 'all' or 'train', got {subset!r}")

    suffix = suffixes[subset]
    model_path = SEPAL_DIR / f"baselines/PBG/model/{data}/{subset}/model.{suffix}.h5"

    with h5py.File(model_path, "r") as file:
        rel_embeddings = file["model/relations/0/operator/rhs/diagonals"][
            ()
        ]  # returns as a numpy array

    return rel_embeddings


def read_embeddings(file_path: Path) -> np.ndarray:
    """Read the first top-level entry of an HDF5 file.

    Args:
        file_path: Path of the ``.h5`` file to open in read mode.

    Returns:
        The content of the first key listed by the file, read in full.
        The entry is assumed to be a dataset (not a group) -- TODO confirm.
    """
    with h5py.File(file_path, "r") as h5_file:
        # Only the first key reported by the file is used.
        top_level_keys = list(h5_file.keys())
        first_entry = h5_file[top_level_keys[0]]
        # Indexing with an empty tuple reads the entry as a numpy array.
        return first_entry[()]


def read_edges(file_path: Path) -> np.ndarray:
    """Read the ``lhs``, ``rel`` and ``rhs`` datasets of an HDF5 edge file.

    Args:
        file_path: Path of the ``.h5`` file to open in read mode.

    Returns:
        The three datasets stacked vertically in the order
        ``lhs``, ``rel``, ``rhs`` and then transposed. Assuming each dataset
        is 1-D with the same length, row ``i`` of the result is
        ``(lhs[i], rel[i], rhs[i])`` -- TODO confirm the dataset shapes.
    """
    with h5py.File(file_path, "r") as h5_file:
        # Read each dataset into memory before the file is closed.
        lhs = h5_file["lhs"][()]
        rel = h5_file["rel"][()]
        rhs = h5_file["rhs"][()]

    # One stacked row per dataset; transposing puts the datasets in columns.
    stacked = np.vstack([lhs, rel, rhs])
    return stacked.T
