from pathlib import Path
import pandas as pd
import numpy as np
from typing import Union
from time import time
from sklearn.ensemble import HistGradientBoostingRegressor as HGBR
from sklearn.model_selection import (
    cross_val_score,
    RepeatedKFold,
    GridSearchCV,
)

from SEPAL import SEPAL_DIR, TARGETS, DATASETS_NAMES
from SEPAL.dataloader import DataLoader
from SEPAL.utils import get_checkpoint_info


def _column_to_embed(data: str) -> str:
    """Return the target-table column holding entity names for knowledge graph ``data``.

    Markers are tested in order; "yago4.5" must be checked before "yago4"
    because the latter is a substring of the former.
    """
    for marker, column in (
        ("yago4.5", "yago4.5_col_to_embed"),
        ("yago4", "yago4_col_to_embed"),
        ("freebase", "freebase_col_to_embed"),
        ("wikikg", "wikidata_col_to_embed"),
    ):
        if marker in data:
            return column
    # Any other graph name falls back to the YAGO3 column.
    return "yago3_col_to_embed"


def _embed_entities(names: pd.Series, entity_to_idx, embeddings: np.ndarray) -> np.ndarray:
    """Return one embedding row per entry of ``names``; unknown entities get all-NaN rows.

    ``entity_to_idx`` is any mapping accepted by ``pandas.Series.map`` that sends
    an entity name to its row index in ``embeddings``.
    """
    idx = names.map(entity_to_idx)
    found = idx.notna().to_numpy()
    # NaN rows require a floating dtype; float32 is the minimum used for the fill.
    out = np.full(
        (len(idx),) + embeddings.shape[1:],
        np.nan,
        dtype=np.result_type(embeddings.dtype, np.float32),
    )
    out[found] = embeddings[idx[found].astype(np.int64).to_numpy()]
    return out


def _append_results(results: dict, results_file: Union[Path, str]) -> None:
    """Append ``results`` as one row to the parquet file, creating the file if absent."""
    df_res = pd.DataFrame([results])
    if Path(results_file).is_file():
        df_res = pd.concat([pd.read_parquet(results_file), df_res]).reset_index(drop=True)
    df_res.to_parquet(results_file, index=False)


def prediction_scores(
    embeddings: np.ndarray,
    id: str,
    target_file: Union[Path, str],
    data: str,
    n_repeats: int,
    tune_hyperparameters: bool,
    scoring: str,
    results_file: Union[Path, str],
    random_state: Union[int, None] = None,
    n_jobs: int = 15,
):
    """Cross-validate a gradient-boosting regressor on entity embeddings and log the scores.

    Each row of the target table is featurized by the embedding of its entity.
    Entities absent from the knowledge graph get NaN features, which
    HistGradientBoostingRegressor accepts natively. The CV scores are appended
    as a single row to ``results_file``.

    Parameters
    ----------
    embeddings:
        Embedding matrix, indexed by the DataLoader's ``entity_to_idx`` values.
    id:
        Model identifier stored in the results row.
    target_file:
        Key of ``TARGETS`` and path, relative to ``SEPAL_DIR``, of the parquet
        target table (must contain a ``"target"`` column).
    data:
        Knowledge-graph name. Selects the triples directory and the entity column.
    n_repeats:
        Number of repetitions of the outer 5-fold CV.
    tune_hyperparameters:
        If True, wrap the model in a 3-fold GridSearchCV over ``max_depth`` and
        ``min_samples_leaf``.
    scoring:
        scikit-learn scoring name used for both tuning and evaluation.
    results_file:
        Parquet file to which the results row is appended.
    random_state:
        Seed for the CV splitter and the regressor. ``None`` (default) keeps
        the previous unseeded behaviour.
    n_jobs:
        Parallel jobs for the outer cross-validation.
    """
    target_name = TARGETS[target_file]
    # Flush so the label is visible while the (long) cross-validation runs.
    print(f"{target_name}:", end=" ", flush=True)

    target = pd.read_parquet(SEPAL_DIR / target_file)
    dataloader = DataLoader(SEPAL_DIR / f"datasets/knowledge_graphs/{data}")
    X_emb = _embed_entities(
        target[_column_to_embed(data)], dataloader.entity_to_idx, embeddings
    )
    # Drop the DataLoader reference before the expensive CV step.
    del dataloader
    y = target["target"]

    # Add one-hot party columns (only for us_elections).
    if target_name == "US elections":
        enc_col = pd.get_dummies(target["party"], prefix="party")
        X_emb = np.hstack([X_emb, enc_col.to_numpy()])

    model = HGBR(random_state=random_state)
    cv = RepeatedKFold(n_splits=5, n_repeats=n_repeats, random_state=random_state)
    if tune_hyperparameters:
        model = GridSearchCV(
            model,
            param_grid={
                "max_depth": [2, 4, 6, None],
                "min_samples_leaf": [4, 6, 10, 20],
            },
            scoring=scoring,
            cv=3,
        )

    start_time = time()
    cv_scores = cross_val_score(model, X_emb, y, cv=cv, scoring=scoring, n_jobs=n_jobs)
    duration = time() - start_time

    _append_results(
        {
            "data": data,
            "id": id,
            "target_file": str(target_file),
            "scoring": scoring,
            "duration": duration,
            "n_samples": X_emb.shape[0],
            "n_features": X_emb.shape[1],
            "scores": cv_scores,
        },
        results_file,
    )

    print("OK")


def main():
    """Evaluate every non-"train" checkpoint on each downstream target not yet scored.

    Scores are appended to ``SEPAL_DIR / "downstream_scores.parquet"``.
    Completed work is tracked per (model id, target file) pair rather than per
    model. A run interrupted midway through a model's targets, or a target
    newly added to ``TARGETS``, is therefore picked up on the next invocation
    instead of being silently skipped.
    """
    scores_path = SEPAL_DIR / "downstream_scores.parquet"
    if scores_path.is_file():
        # Only the key columns are needed; the "scores" column holds arrays.
        done_df = pd.read_parquet(scores_path, columns=["id", "target_file"])
        done = set(zip(done_df["id"], done_df["target_file"]))
    else:
        done = set()

    checkpoints = get_checkpoint_info()
    checkpoints = checkpoints[checkpoints["subset"] != "train"]

    for _, row in checkpoints.iterrows():
        model_id = row["id"]
        # prediction_scores records target_file as str(target_file).
        pending = [t for t in TARGETS if (model_id, str(t)) not in done]
        if not pending:
            continue

        print(
            f"---------------- Evaluating: {DATASETS_NAMES[row['data']]} embeddings ----------------"
        )
        embeddings = np.load(SEPAL_DIR / f"embeddings/{model_id}.npy")
        if row["embed_method"] == "rotate":
            # Concatenate the real and imaginary parts of complex embeddings.
            embeddings = np.hstack([embeddings.real, embeddings.imag])

        for target_file in pending:
            prediction_scores(
                embeddings=embeddings,
                id=model_id,
                target_file=target_file,
                data=row["data"],
                n_repeats=5,
                tune_hyperparameters=True,
                scoring="r2",
                results_file=scores_path,
            )


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
