import pandas as pd
import numpy as np
from pathlib import Path
from time import time
from sklearn.ensemble import HistGradientBoostingRegressor as HGBR
from sklearn.model_selection import (
    cross_val_score,
    RepeatedKFold,
    GridSearchCV,
)

from SEPAL import SEPAL_DIR, TARGETS, DATASETS_NAMES
from SEPAL.downstream_evaluation import prediction_scores
from SEPAL.baselines.dglke.utils import (
    load_dglke_checkpoint_info,
    load_dglke_embeddings,
)


def _column_to_embed(data):
    """Return the name of the target-file column holding entity names for graph `data`.

    Matching is by substring of `data`, checked in order. "yago4.5" must be
    tested before "yago4" because the latter is a substring of the former.
    Anything unmatched falls back to the YAGO3 column.
    """
    kg_columns = (
        ("yago4.5", "yago4.5_col_to_embed"),
        ("yago4", "yago4_col_to_embed"),
        ("freebase", "freebase_col_to_embed"),
        ("wikikg", "wikidata_col_to_embed"),
    )
    for key, column in kg_columns:
        if key in data:
            return column
    return "yago3_col_to_embed"


def _embed_entities(names, entity_to_idx, embeddings):
    """Stack one embedding row per entity name in `names` (a pandas Series).

    Names absent from `entity_to_idx` (or missing in the column itself) get a
    float32 row of NaN, with the same shape as ``embeddings[0]``.
    """
    nan_row = np.full(embeddings[0].shape, np.nan, dtype="float32")
    # `Series.map` with a dict yields NaN for unknown keys.
    indices = names.map(entity_to_idx)
    return np.vstack(
        [embeddings[int(i)] if pd.notna(i) else nan_row for i in indices]
    )


def prediction_scores(
    embeddings,
    id,
    target_file,
    data,
    entity_to_idx,
    n_repeats,
    tune_hyperparameters,
    scoring,
    results_file,
    n_jobs=15,
):
    """Cross-validate a gradient-boosting regressor on entity embeddings and log the scores.

    NOTE: this definition shadows the same-named function imported from
    ``SEPAL.downstream_evaluation`` at the top of this module.

    Parameters
    ----------
    embeddings : indexable by int row
        Entity embeddings; row ``i`` is the embedding of the entity with index
        ``i``. Presumably a 2-D array (n_entities, dim) -- TODO confirm.
    id :
        Identifier of the embedding checkpoint, stored in the results row.
    target_file :
        Key of ``TARGETS``; also the path (relative to ``SEPAL_DIR``) of the
        parquet file containing the entity-name column and ``"target"``.
    data : str
        Knowledge-graph name; selects which entity-name column to embed.
    entity_to_idx : dict
        Maps entity name -> row index into ``embeddings``.
    n_repeats : int
        Number of repeats of the outer 5-fold ``RepeatedKFold``.
    tune_hyperparameters : bool
        If true, wrap the model in a 3-fold ``GridSearchCV`` over
        ``max_depth`` and ``min_samples_leaf``.
    scoring : str
        scikit-learn scoring name used for both inner and outer CV.
    results_file : str or Path
        Parquet file the results row is appended to (created if absent).
    n_jobs : int, default 15
        Parallel jobs for the outer ``cross_val_score``.

    Returns
    -------
    None
        Results are persisted to ``results_file``.
    """
    # Flush so the target label is visible while the (long) CV runs.
    print(f"{TARGETS[target_file]}:", end=" ", flush=True)

    target = pd.read_parquet(SEPAL_DIR / target_file)

    # Replace entity names by their embeddings (NaN rows for unknown entities).
    X_emb = _embed_entities(target[_column_to_embed(data)], entity_to_idx, embeddings)
    y = target["target"]

    # The US elections task also gets a one-hot encoding of the party column.
    if TARGETS[target_file] == "US elections":
        party_dummies = pd.get_dummies(target["party"], prefix="party")
        X_emb = np.hstack([X_emb, party_dummies.to_numpy()])

    model = HGBR()
    cv = RepeatedKFold(n_splits=5, n_repeats=n_repeats)

    if tune_hyperparameters:
        param_grid = {
            "max_depth": [2, 4, 6, None],
            "min_samples_leaf": [4, 6, 10, 20],
        }
        model = GridSearchCV(model, param_grid=param_grid, scoring=scoring, cv=3)

    start_time = time()
    cv_scores = cross_val_score(
        model, X_emb, y, cv=cv, scoring=scoring, n_jobs=n_jobs
    )
    duration = time() - start_time

    n_samples, n_features = X_emb.shape

    # Append this run as one row to the results parquet file.
    new_df_res = pd.DataFrame(
        [
            {
                "data": data,
                "id": id,
                "target_file": str(target_file),
                "scoring": scoring,
                "duration": duration,
                "n_samples": n_samples,
                "n_features": n_features,
                "scores": cv_scores,
            }
        ]
    )
    results_path = Path(results_file)
    if results_path.is_file():
        df_res = pd.concat(
            [pd.read_parquet(results_path), new_df_res]
        ).reset_index(drop=True)
    else:
        df_res = new_df_res
    df_res.to_parquet(results_path, index=False)

    print("OK")


def main():
    """Evaluate every non-"train" DGL-KE checkpoint on all downstream targets.

    Scores are appended to ``baselines/dglke/downstream_scores_dglke.parquet``
    under ``SEPAL_DIR``. Work is skipped per (checkpoint id, target file)
    pair that already has a row in that file. An interrupted run therefore
    resumes with the targets it had not reached, instead of treating the
    whole checkpoint as done after its first saved target.
    """
    import csv

    scores_path = SEPAL_DIR / "baselines/dglke/downstream_scores_dglke.parquet"
    if scores_path.is_file():
        scores = pd.read_parquet(scores_path, columns=["id", "target_file"])
        # `prediction_scores` stores target_file as str(target_file).
        done = set(zip(scores["id"], scores["target_file"]))
    else:
        done = set()

    # Load checkpoint info, excluding checkpoints trained on the "train" subset.
    checkpoints = load_dglke_checkpoint_info()
    checkpoints = checkpoints[checkpoints.subset != "train"]

    for _, row in checkpoints.iterrows():
        model_id = row["id"]
        kg_name = row["data"]
        pending = [t for t in TARGETS if (model_id, str(t)) not in done]
        if not pending:
            continue  # every target already scored for this checkpoint

        print(
            f"---------------- Evaluating: {DATASETS_NAMES[kg_name]} embeddings ----------------"
        )
        embeddings = load_dglke_embeddings(model_id)

        # entities.tsv is "<index>\t<name>" without a header. Disable NA
        # coercion and quote handling so names such as "NA", "null" or ones
        # containing '"' are kept verbatim. Use the file's own index column.
        entity_path = SEPAL_DIR / f"baselines/dglke/data/full_{kg_name}/entities.tsv"
        entities = pd.read_csv(
            entity_path,
            sep="\t",
            header=None,
            index_col=0,
            keep_default_na=False,
            quoting=csv.QUOTE_NONE,
        )
        entity_to_idx = dict(zip(entities[1], entities.index))

        for target_file in pending:
            prediction_scores(
                embeddings=embeddings,
                id=model_id,
                target_file=target_file,
                data=kg_name,
                entity_to_idx=entity_to_idx,
                n_repeats=5,
                tune_hyperparameters=True,
                scoring="r2",
                results_file=scores_path,
            )


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
