from pathlib import Path
import pandas as pd
import numpy as np
from typing import Union
from time import time
from sklearn.ensemble import HistGradientBoostingRegressor as HGBR
from sklearn.model_selection import (
    cross_val_score,
    RepeatedKFold,
    GridSearchCV,
)

from SEPAL.dataloader import DataLoader
from SEPAL import SEPAL_DIR


# Evaluation targets: parquet path (relative to SEPAL_DIR) -> display name.
# Every target lives at datasets/evaluation/<slug>/target_log.parquet, so the
# table is generated from (slug, display name) pairs. Order is preserved.
# The "company_employees" / "Company employees" target is currently disabled.
TARGETS = {
    f"datasets/evaluation/{slug}/target_log.parquet": label
    for slug, label in (
        ("us_elections", "US elections"),
        ("housing_prices", "Housing prices"),
        ("us_accidents", "US accidents"),
        ("movie_revenues", "Movie revenues"),
    )
}

# Display names for the knowledge graphs, keyed by the dataset identifier that
# also names the graph's directory under datasets/knowledge_graphs/.
DATASETS_NAMES = dict(
    (
        ("mini_yago3_lcc", "Mini YAGO3"),
        ("yago3_lcc", "YAGO3"),
        ("core_yago4", "Core YAGO4"),
        ("yago4_lcc", "YAGO4"),
        ("yago4.5_lcc", "YAGO4.5"),
        ("yago4_with_full_ontology", "YAGO4 + taxonomy"),
    )
)


def prediction_scores(
    embeddings: np.ndarray,
    id: str,
    target_file: Union[Path, str],
    triples_dir: Union[Path, str],
    n_repeats: int,
    tune_hyperparameters: bool,
    scoring: str,
    results_file: Union[Path, str],
) -> None:
    """Cross-validate a regressor on entity embeddings and append the scores to a file.

    The target table is read from ``SEPAL_DIR / target_file``; each entity in
    its ``yago4_col_to_embed`` / ``yago3_col_to_embed`` column is replaced by
    the corresponding row of ``embeddings`` and a
    ``HistGradientBoostingRegressor`` is evaluated with repeated 5-fold
    cross-validation against the ``target`` column. One result row is then
    appended to ``results_file``.

    Parameters
    ----------
    embeddings : np.ndarray
        Embedding table, indexed along its first axis by the entity index
        obtained through ``DataLoader.entity_to_idx``.
    id : str
        Identifier of the evaluated model, stored in the ``id`` result column.
    target_file : Path or str
        Key of ``TARGETS``; also the parquet path relative to ``SEPAL_DIR``.
    triples_dir : Path or str
        Knowledge-graph directory handed to ``DataLoader`` (joined with
        ``SEPAL_DIR``). If its string form contains ``"yago4"`` the
        ``yago4_col_to_embed`` column is used, otherwise ``yago3_col_to_embed``.
    n_repeats : int
        Number of repetitions of the 5-fold split.
    tune_hyperparameters : bool
        If true, wrap the regressor in a 3-fold ``GridSearchCV`` over
        ``max_depth`` and ``min_samples_leaf``.
    scoring : str
        Scoring name passed to both the grid search and ``cross_val_score``.
    results_file : Path or str
        Parquet file the result row is appended to (created if missing).
    """
    print(f"{TARGETS[target_file]}:", end=" ")

    # Load target dataframe
    target = pd.read_parquet(SEPAL_DIR / target_file)
    # Init dataloader
    dataloader = DataLoader(SEPAL_DIR / triples_dir)

    # Entities without an index get an all-NaN feature vector of the same
    # width as a real embedding (the regressor supports missing values).
    nan_array = np.full(embeddings[0].shape, np.nan, dtype="float32")

    column_to_embed = (
        "yago4_col_to_embed" if "yago4" in str(triples_dir) else "yago3_col_to_embed"
    )
    # Replace entity names by their embedding.
    # NOTE(review): entity_to_idx is presumably a name -> index mapping;
    # unmatched names come back from ``map`` as missing values.
    entity_indices = target[column_to_embed].map(dataloader.entity_to_idx)
    X_emb = np.vstack(
        entity_indices.apply(
            lambda i: embeddings[int(i)] if pd.notna(i) else nan_array
        ).to_numpy()
    )
    # The dataloader is only needed for the lookup; release it before the
    # cross-validation runs.
    del dataloader
    y = target["target"]

    # Add party column (only for us_elections)
    if TARGETS[target_file] == "US elections":
        enc_col = pd.get_dummies(target["party"], prefix="party")
        X_emb = np.hstack([X_emb, enc_col.to_numpy()])

    model = HGBR()
    cv = RepeatedKFold(n_splits=5, n_repeats=n_repeats)

    if tune_hyperparameters:
        param_grid = {
            "max_depth": [2, 4, 6, None],
            "min_samples_leaf": [4, 6, 10, 20],
        }
        # Nested CV: the grid search is re-fitted inside every outer fold.
        model = GridSearchCV(
            model,
            param_grid=param_grid,
            scoring=scoring,
            cv=3,
        )
    start_time = time()
    cv_scores = cross_val_score(model, X_emb, y, cv=cv, scoring=scoring, n_jobs=15)
    duration = time() - start_time

    n_samples, n_features = X_emb.shape

    # Save results to a dataframe
    results = {
        "triples_dir": str(triples_dir),
        "id": id,
        "target_file": str(target_file),
        "scoring": scoring,
        "duration": duration,
        "n_samples": n_samples,
        "n_features": n_features,
        "scores": cv_scores,
    }
    new_df_res = pd.DataFrame([results])
    # Read the stored results only now, right before writing, so the appended
    # file reflects its latest on-disk content.
    if Path(results_file).is_file():
        df_res = pd.read_parquet(results_file)
        df_res = pd.concat([df_res, new_df_res]).reset_index(drop=True)
    else:
        df_res = new_df_res
    df_res.to_parquet(results_file, index=False)

    print("OK")


def main() -> None:
    """Evaluate every non-train checkpoint that has no stored downstream score yet.

    Reads ``downstream_scores.parquet`` (if present) and
    ``checkpoints_sepal.parquet`` from ``SEPAL_DIR``, keeps the checkpoints
    whose ``subset`` is not ``"train"`` and whose ``id`` is absent from the
    scores file, then runs ``prediction_scores`` on each of them for every
    entry of ``TARGETS``. YAGO4.5 checkpoints are skipped.
    """
    ## Evaluate models that have not been evaluated yet
    # Load prediction scores if file exists
    scores_path = SEPAL_DIR / "downstream_scores.parquet"
    if scores_path.is_file():
        evaluated = pd.read_parquet(scores_path)["id"].tolist()
    else:
        evaluated = []

    # Load checkpoint info
    checkpoints = pd.read_parquet(SEPAL_DIR / "checkpoints_sepal.parquet")
    checkpoints = checkpoints[checkpoints["subset"] != "train"]

    # Mask the models already evaluated
    to_evaluate = checkpoints[~checkpoints["id"].isin(evaluated)]

    # Loop over the models to evaluate
    for _, row in to_evaluate.iterrows():
        # Use item access rather than attribute access for the row fields:
        # "data" can collide with a Series attribute on older pandas versions.
        dataset = row["data"]
        model_id = row["id"]
        dataset_name = DATASETS_NAMES[dataset]

        # The label must match the DATASETS_NAMES value exactly ("YAGO4.5"),
        # otherwise this guard never triggers.
        if dataset_name == "YAGO4.5":
            print("Cannot evaluate Yago4.5 for now (no entity linkage). ")
            continue

        print(
            f"---------------- Evaluating: {dataset_name} embeddings ----------------"
        )
        embeddings = np.load(SEPAL_DIR / f"embeddings/{model_id}.npy")
        for target_file in TARGETS:
            prediction_scores(
                embeddings=embeddings,
                id=model_id,
                target_file=target_file,
                triples_dir=SEPAL_DIR / f"datasets/knowledge_graphs/{dataset}",
                n_repeats=5,
                tune_hyperparameters=True,
                scoring="r2",
                results_file=scores_path,
            )



# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()