# Description: Evaluate PBG embeddings on link prediction task
import pandas as pd

from SEPAL import SEPAL_DIR, DATASETS_NAMES
from SEPAL.link_prediction import transductive_link_prediction
from SEPAL.baselines.PBG.utils import (
    load_PBG_embeddings,
    load_PBG_relation_embeddings,
    load_pbg_checkpoints,
    get_pbg2sepal_reordering,
)


if __name__ == "__main__":
    # Evaluate every PBG checkpoint trained on the "train" split with
    # transductive link prediction, once on the "val" subset and once on the
    # "test" subset. Models whose id is already present in the per-subset
    # results file are skipped.

    # Load checkpoint info once: the selection below does not depend on the
    # evaluation subset, so re-reading it for every subset is redundant work.
    checkpoints = load_pbg_checkpoints()
    checkpoints = checkpoints[checkpoints.subset == "train"]

    for subset in ["val", "test"]:
        scores_path = SEPAL_DIR / f"{subset}_lp_scores_pbg.parquet"
        if scores_path.is_file():
            # Only the "id" column is needed to know which models are done.
            evaluated = pd.read_parquet(scores_path, columns=["id"])["id"].tolist()
        else:
            evaluated = []

        # Mask the models already evaluated
        mask = ~checkpoints["id"].isin(evaluated)
        to_evaluate = checkpoints[mask]

        # Order by dataset key. NOTE(review): the original comment said this
        # starts with smaller models; that only holds if dataset keys happen to
        # sort by size — confirm.
        # reset_index makes `index` below a 0-based progress counter.
        to_evaluate = to_evaluate.sort_values(by=["data"]).reset_index(drop=True)
        num_models = len(to_evaluate)

        # Loop over the models to evaluate
        for index, row in to_evaluate.iterrows():
            print(
                f"---------------- {index+1}/{num_models}: Evaluating: {DATASETS_NAMES[row.data]} embeddings on {subset} set ----------------"
            )
            # Load entity and relation embeddings by checkpoint id
            embeddings = load_PBG_embeddings(row.id)
            rel_embeddings = load_PBG_relation_embeddings(row.id)

            # Reorder rows from PBG's entity/relation ordering into the order
            # expected by SEPAL's evaluation for this dataset.
            entity_reordering, relation_reordering = get_pbg2sepal_reordering(row.data)
            embeddings = embeddings[entity_reordering]
            rel_embeddings = rel_embeddings[relation_reordering]

            # Evaluate embeddings. Results go to `scores_path`, the same file
            # read above to skip finished models (presumably appended to by
            # transductive_link_prediction — verify).
            transductive_link_prediction(
                data=row.data,
                subset=subset,
                model_name=row.embed_method,
                dim=row.embed_dim,
                embeddings=embeddings,
                rel_embeddings=rel_embeddings,
                id=row.id,
                results_file=scores_path,
                num_negatives=10000,
            )
