import pandas as pd
import json

from SEPAL import SEPAL_DIR, TARGETS, DATASETS_NAMES
from SEPAL.baselines.dglke.downstream_evaluation_dglke import prediction_scores
from SEPAL.baselines.PBG.utils import load_PBG_embeddings, load_pbg_checkpoints


def _load_entity_to_idx(data_name) -> dict:
    """Map each entity name of a PBG dataset to its position in the name list.

    Reads ``baselines/PBG/data/<data_name>/entity_names_all_0.json`` under
    ``SEPAL_DIR``. The position is presumably the row index into the PBG
    embedding matrix for that dataset — verify against ``load_PBG_embeddings``.
    """
    entity_path = SEPAL_DIR / f"baselines/PBG/data/{data_name}/entity_names_all_0.json"
    with open(entity_path, encoding="utf-8") as f:
        entity_names = json.load(f)
    return {name: idx for idx, name in enumerate(entity_names)}


def main() -> None:
    """Compute downstream prediction scores for PBG models not yet evaluated.

    The routine runs in these steps:

    - Reads the ids already present in
      ``SEPAL_DIR/baselines/PBG/downstream_scores_pbg.parquet``, if that file
      exists.
    - Loads the checkpoint table and keeps the rows whose ``subset`` is
      ``"all"``.
    - Drops every checkpoint whose ``id`` was already scored.
    - For each remaining model, loads its embeddings and its entity-name to
      index mapping.
    - Calls ``prediction_scores`` once per key of ``TARGETS``, with
      ``n_repeats=5``, ``tune_hyperparameters=True`` and ``scoring="r2"``.
      The call passes the parquet path as ``results_file``, which is presumably
      where results are written — verify in ``prediction_scores``.

    NOTE(review): resumption is per model id. A model whose id is already in
    the scores file is skipped entirely, even if only some targets were scored
    before an interruption.
    """
    scores_path = SEPAL_DIR / "baselines/PBG/downstream_scores_pbg.parquet"

    # Ids already present in the results file are not evaluated again.
    if scores_path.is_file():
        evaluated = set(pd.read_parquet(scores_path)["id"])
    else:
        evaluated = set()

    # Load checkpoint info and keep only the rows whose subset is "all".
    checkpoints = load_pbg_checkpoints()
    checkpoints = checkpoints[checkpoints["subset"] == "all"]

    # Mask the models already evaluated.
    to_evaluate = checkpoints[~checkpoints["id"].isin(evaluated)]

    for _, row in to_evaluate.iterrows():
        # Bracket access instead of ``row.data``/``row.id``: attribute access
        # can resolve to a pandas.Series attribute instead of the column
        # (older pandas exposed a ``Series.data`` attribute).
        model_id = row["id"]
        data_name = row["data"]
        print(
            f"---------------- Evaluating: {DATASETS_NAMES[data_name]} embeddings ----------------"
        )
        embeddings = load_PBG_embeddings(model_id)
        entity_to_idx = _load_entity_to_idx(data_name)
        for target_file in TARGETS.keys():
            prediction_scores(
                embeddings=embeddings,
                id=model_id,
                target_file=target_file,
                data=data_name,
                entity_to_idx=entity_to_idx,
                n_repeats=5,
                tune_hyperparameters=True,
                scoring="r2",
                results_file=scores_path,
            )


if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported.
    main()
