# Description: This script evaluates the link prediction performance of the DGL-KE models on the validation and test sets.
import pandas as pd

from SEPAL import SEPAL_DIR, DATASETS_NAMES
from SEPAL.link_prediction import transductive_link_prediction
from SEPAL.baselines.dglke.utils import (
    load_dglke_embeddings,
    load_dglke_relation_embeddings,
    load_dglke_checkpoint_info,
    get_dglke2sepal_reordering,
)


def _evaluate_subset(subset, checkpoints):
    """Run link prediction on ``subset`` for every checkpoint not yet scored.

    Args:
        subset: Name of the evaluation split (the script uses "val" and
            "test"). It selects the scores file and is forwarded to
            ``transductive_link_prediction``.
        checkpoints: DataFrame of candidate checkpoints. The columns read
            here are ``id``, ``data``, ``embed_method`` and ``embed_dim``.
    """
    scores_path = SEPAL_DIR / f"{subset}_lp_scores_dglke.parquet"

    # Ids already present in the scores file are skipped, so an interrupted
    # run can be restarted. NOTE(review): this presumes that
    # transductive_link_prediction records each result in `results_file`
    # -- confirm against its implementation.
    if scores_path.is_file():
        evaluated = pd.read_parquet(scores_path)["id"].tolist()
    else:
        evaluated = []

    to_evaluate = checkpoints[~checkpoints["id"].isin(evaluated)]

    # Order the runs by dataset key; the index is reset so that it can serve
    # as the progress counter below. NOTE(review): the original comment says
    # this starts with the smaller models, which presumes dataset keys sort
    # by size -- confirm.
    to_evaluate = to_evaluate.sort_values(by=["data"]).reset_index(drop=True)
    num_models = len(to_evaluate)

    for index, row in to_evaluate.iterrows():
        # Bracket access instead of `row.data` / `row.id`: attribute lookup
        # on a Series silently resolves to a real attribute when the column
        # name collides with one (e.g. `Series.data` in pandas < 1.0).
        model_id = row["id"]
        dataset = row["data"]

        print(
            f"---------------- {index+1}/{num_models}: Evaluating: {DATASETS_NAMES[dataset]} embeddings on {subset} set ----------------"
        )

        # Load entity and relation embeddings by checkpoint id
        embeddings = load_dglke_embeddings(model_id)
        rel_embeddings = load_dglke_relation_embeddings(model_id)

        # Reorder both embedding arrays with the DGL-KE -> SEPAL index maps
        entity_reordering, relation_reordering = get_dglke2sepal_reordering(
            dataset
        )
        embeddings = embeddings[entity_reordering]
        rel_embeddings = rel_embeddings[relation_reordering]

        # Evaluate embeddings
        transductive_link_prediction(
            data=dataset,
            subset=subset,
            model_name=row["embed_method"],
            dim=row["embed_dim"],
            embeddings=embeddings,
            rel_embeddings=rel_embeddings,
            id=model_id,
            results_file=scores_path,
            num_negatives=10000,
        )


def main():
    """Evaluate the DGL-KE checkpoints on the validation and test sets."""
    # The checkpoint table does not depend on the evaluation subset, so it
    # is loaded and filtered once instead of once per subset.
    checkpoints = load_dglke_checkpoint_info()
    checkpoints = checkpoints[checkpoints["subset"] == "train"]

    for subset in ["val", "test"]:
        _evaluate_subset(subset, checkpoints)


if __name__ == "__main__":
    main()
