import pandas as pd
import numpy as np
from pathlib import Path
from time import time
from sklearn.ensemble import HistGradientBoostingRegressor as HGBR
from sklearn.model_selection import (
    cross_val_score,
    RepeatedKFold,
    GridSearchCV,
)
import json

from SEPAL import SEPAL_DIR
from SEPAL.downstream_evaluation import prediction_scores, TARGETS, DATASETS_NAMES

from SEPAL.baselines.PBG.utils import load_PBG_embeddings


def prediction_scores(
    embeddings,
    id,
    target_file,
    data,
    entity_to_idx,
    n_repeats,
    tune_hyperparameters,
    scoring,
    results_file,
):
    """Cross-validate a gradient-boosting regressor on one downstream target.

    Each row of the target table is featurized with the embedding of the
    entity named in its ``yago4_col_to_embed`` / ``yago3_col_to_embed``
    column, a ``HistGradientBoostingRegressor`` is evaluated with repeated
    5-fold cross-validation, and one result row is appended to the parquet
    file ``results_file``.

    Note: this module-level definition shadows the ``prediction_scores``
    name imported from ``SEPAL.downstream_evaluation`` at the top of the file.

    Parameters
    ----------
    embeddings
        Indexable collection of embedding vectors; ``embeddings[i]`` is the
        vector of the entity whose index is ``i`` in ``entity_to_idx``.
    id
        Label stored in the ``"id"`` column of the result row.
    target_file
        Key of ``TARGETS``; also the path, relative to ``SEPAL_DIR``, of the
        parquet file holding the target table.
    data
        Name of the knowledge graph. If it contains ``"yago4"`` the
        ``yago4_col_to_embed`` column is used, otherwise
        ``yago3_col_to_embed``.
    entity_to_idx
        Mapping from entity name to its index in ``embeddings``.
    n_repeats
        Number of repetitions of the 5-fold split.
    tune_hyperparameters
        If true, wrap the regressor in a 3-fold ``GridSearchCV`` over
        ``max_depth`` and ``min_samples_leaf``.
    scoring
        Scoring name forwarded to scikit-learn.
    results_file
        Parquet file the result row is appended to (created if missing).

    Returns
    -------
    None
        Results are written to ``results_file``.
    """
    print(f"{TARGETS[target_file]}:", end=" ")

    # Load target dataframe
    target = pd.read_parquet(SEPAL_DIR / target_file)

    # Replace entity names by their embedding. Entities absent from
    # ``entity_to_idx`` come out of ``.map`` as NaN and get an all-NaN
    # feature vector instead of being dropped.
    nan_array = np.full(embeddings[0].shape, np.nan, dtype="float32")

    column_to_embed = (
        "yago4_col_to_embed" if "yago4" in str(data) else "yago3_col_to_embed"
    )
    X_emb = np.vstack(
        target[column_to_embed]
        .map(entity_to_idx)
        .apply(lambda i: embeddings[int(i)] if pd.notna(i) else nan_array)
        .to_numpy()
    )
    y = target["target"]

    # Add party column (only for us_elections)
    if TARGETS[target_file] == "US elections":
        enc_col = pd.get_dummies(target["party"], prefix="party")
        X_emb = np.hstack([X_emb, enc_col.to_numpy()])

    model = HGBR()
    cv = RepeatedKFold(n_splits=5, n_repeats=n_repeats)

    if tune_hyperparameters:
        # Nested CV: the grid search runs inside every outer fold.
        param_grid = {
            "max_depth": [2, 4, 6, None],
            "min_samples_leaf": [4, 6, 10, 20],
        }
        model = GridSearchCV(
            model,
            param_grid=param_grid,
            scoring=scoring,
            cv=3,
        )
    start_time = time()
    cv_scores = cross_val_score(model, X_emb, y, cv=cv, scoring=scoring, n_jobs=15)
    duration = time() - start_time

    X_shape = X_emb.shape

    # Save results to a dataframe
    results = {
        "data": data,
        "id": id,
        "target_file": str(target_file),
        "scoring": scoring,
        "duration": duration,
        "n_samples": X_shape[0],
        "n_features": X_shape[1],
        "scores": cv_scores,
    }
    new_df_res = pd.DataFrame([results])

    # Read the stored results only now, after the (long) evaluation, so the
    # new row is appended to the file's current content.
    if Path(results_file).is_file():
        df_res = pd.read_parquet(results_file)
        df_res = pd.concat([df_res, new_df_res]).reset_index(drop=True)
    else:
        df_res = new_df_res
    df_res.to_parquet(results_file, index=False)

    print("OK")


def main():
    """Evaluate PBG embeddings of each selected dataset on every target.

    For each dataset name in the list below, loads the PBG embeddings and the
    matching entity-name file, then calls ``prediction_scores`` once per key
    of ``TARGETS``. All result rows go to a single parquet file,
    ``baselines/PBG/downstream_scores_pbg.parquet`` under ``SEPAL_DIR``.
    """
    scores_path = SEPAL_DIR / "baselines/PBG/downstream_scores_pbg.parquet"
    for data in [
        # "mini_yago3_lcc",
        # "yago3_lcc",
        # "rel_core_yago4",
        # "rel_core_yago4.5",
        # "core_yago4",
        # "core_yago4.5",
        # "yago4_lcc",
        # "yago4.5_lcc",
        "yago4_with_full_ontology",
    ]:
        print(f"Evaluating {data} embeddings on downstream tasks")
        embeddings = load_PBG_embeddings(data, subset="all")
        entity_path = SEPAL_DIR / f"baselines/PBG/data/{data}/entity_names_all_0.json"
        # JSON text is UTF-8; being explicit avoids depending on the locale's
        # default encoding when entity names contain non-ASCII characters.
        with open(entity_path, encoding="utf-8") as f:
            entity_names = json.load(f)
        # Position in the file is used as the row index into ``embeddings``
        # (assumes the names are stored in embedding order -- TODO confirm).
        entity_to_idx = {name: idx for idx, name in enumerate(entity_names)}
        for target_file in TARGETS.keys():
            prediction_scores(
                embeddings=embeddings,
                id=f"PBG - {DATASETS_NAMES[data]} all",
                target_file=target_file,
                data=data,
                entity_to_idx=entity_to_idx,
                n_repeats=5,
                tune_hyperparameters=True,
                scoring="r2",
                results_file=scores_path,
            )


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
