from pathlib import Path
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
from time import time
from sklearn.ensemble import (
    HistGradientBoostingRegressor as HGBR,
    HistGradientBoostingClassifier as HGBC,
)
from sklearn.model_selection import (
    cross_val_score,
    RepeatedKFold,
    GridSearchCV,
)

from SEPAL import SEPAL_DIR, DATASETS_NAMES
from SEPAL.dataloader import DataLoader
from SEPAL.utils import (
    get_full_checkpoint_info,
    get_cls_wikidb_files,
    get_reg_wikidb_files,
    get_cls_val_files,
    get_reg_val_files,
)
from SEPAL.baselines.dglke.utils import load_dglke_embeddings
from SEPAL.baselines.PBG.utils import load_PBG_embeddings
from SEPAL.plotting.utils import (
    make_score_df,
    make_scores_normalized,
    get_best_model,
)


# (substring of the dataset name, column of the target table to embed).
# Order matters: "yago4" is a substring of "yago4.5", so the more specific
# key has to be tested first.
_COLUMNS_TO_EMBED = (
    ("yago4.5", "yago4.5_col_to_embed"),
    ("yago4", "yago4_col_to_embed"),
    ("freebase", "freebase_col_to_embed"),
    ("wikikg", "wikidata_col_to_embed"),
)
_DEFAULT_COLUMN_TO_EMBED = "yago3_col_to_embed"
_PREDICTION_KINDS = ("classification", "regression")


def prediction_scores(
    embeddings,
    id,
    target_file,
    data,
    entity_to_idx,
    n_repeats,
    tune_hyperparameters,
    scoring,
    results_file,
    kind,
    n_jobs=15,
):
    """Cross-validate a gradient-boosting model on one target table.

    The entities of the target table are replaced by their embedding, a
    histogram gradient-boosting model is evaluated with repeated 5-fold
    cross-validation, and one result row is appended to ``results_file``.

    Parameters
    ----------
    embeddings : indexable
        Entity embeddings; ``embeddings[i]`` must return the vector of the
        entity with index ``i``.
    id : str
        Identifier of the evaluated model, stored in the result row.
    target_file : path-like
        Parquet file holding the target table. It must contain the column
        to embed for ``data`` and a ``"target"`` column.
    data : str
        Name of the knowledge graph; the substring it contains selects the
        column of the target table that holds the entity names.
    entity_to_idx : mapping
        Maps entity names to their index in ``embeddings``. Entities that
        are not in the mapping get an all-NaN feature vector.
    n_repeats : int
        Number of repetitions of the 5-fold cross-validation.
    tune_hyperparameters : bool
        If true, wrap the model in a 3-fold grid search over ``max_depth``
        and ``min_samples_leaf``.
    scoring : str
        Scoring name, used for both the grid search and the outer CV.
    results_file : path-like
        Parquet file the result row is appended to (created if missing).
    kind : {"classification", "regression"}
        Type of prediction task.
    n_jobs : int, default=15
        Number of parallel jobs passed to ``cross_val_score``.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the supported task types.
    """
    # Fail early: without this check an unknown kind would only surface as
    # an unbound ``model`` variable after the target file has been read.
    if kind not in _PREDICTION_KINDS:
        raise ValueError(f"kind must be one of {_PREDICTION_KINDS}, got {kind!r}")

    # Load target dataframe
    target = pd.read_parquet(target_file)

    column_to_embed = next(
        (column for key, column in _COLUMNS_TO_EMBED if key in data),
        _DEFAULT_COLUMN_TO_EMBED,
    )

    # Replace entity names by their embedding. Unknown entities map to NaN
    # and get an all-NaN row. Rows are looked up one at a time so that
    # ``embeddings`` only needs to support integer indexing.
    nan_array = np.full(embeddings[0].shape, np.nan, dtype="float32")
    entity_indices = target[column_to_embed].map(entity_to_idx)
    X_emb = np.vstack(
        [
            embeddings[int(i)] if i == i else nan_array  # NaN != NaN
            for i in entity_indices
        ]
    )
    y = target["target"]

    model = HGBC() if kind == "classification" else HGBR()
    cv = RepeatedKFold(n_splits=5, n_repeats=n_repeats)

    if tune_hyperparameters:
        param_grid = {
            "max_depth": [2, 4, 6, None],
            "min_samples_leaf": [4, 6, 10, 20],
        }
        model = GridSearchCV(
            model,
            param_grid=param_grid,
            scoring=scoring,
            cv=3,
        )

    start_time = time()
    cv_scores = cross_val_score(
        model, X_emb, y, cv=cv, scoring=scoring, n_jobs=n_jobs
    )
    duration = time() - start_time

    n_samples, n_features = X_emb.shape

    # Append the result row to the results file
    results = {
        "data": data,
        "id": id,
        "target_file": str(target_file),
        "scoring": scoring,
        "duration": duration,
        "n_samples": n_samples,
        "n_features": n_features,
        "scores": cv_scores,
        "kind": kind,
    }
    df_res = pd.DataFrame([results])
    if Path(results_file).is_file():
        previous = pd.read_parquet(results_file)
        df_res = pd.concat([previous, df_res]).reset_index(drop=True)
    df_res.to_parquet(results_file, index=False)


_BASELINE_METHODS = ("DGL-KE", "PyTorch-BigGraph")


def _load_entity_to_idx(method, data):
    """Load the entity-name -> embedding-index mapping of ``method`` for ``data``."""
    if method == "DGL-KE":
        entity_path = SEPAL_DIR / f"baselines/dglke/data/full_{data}/entities.tsv"
        entity_names = pd.read_csv(
            entity_path, sep="\t", header=None, index_col=0
        )[1].tolist()
    elif method == "PyTorch-BigGraph":
        entity_path = SEPAL_DIR / f"baselines/PBG/data/{data}/entity_names_all_0.json"
        with open(entity_path) as f:
            entity_names = json.load(f)
    else:
        triples_dir = SEPAL_DIR / f"datasets/knowledge_graphs/{data}"
        return DataLoader(triples_dir).entity_to_idx
    return {name: idx for idx, name in enumerate(entity_names)}


def _load_embeddings(row):
    """Load the embeddings of the checkpoint described by ``row``."""
    model_id = row["id"]
    if row["method"] == "DGL-KE":
        return load_dglke_embeddings(model_id)
    if row["method"] == "PyTorch-BigGraph":
        return load_PBG_embeddings(model_id)
    # TODO: add support for GraSH
    embeddings = np.load(SEPAL_DIR / f"embeddings/{model_id}.npy")
    if row["embed_method"] == "rotate":
        # concatenate the real and imaginary parts for complex embeddings
        embeddings = np.hstack([embeddings.real, embeddings.imag])
    return embeddings


def main(subset="val", filters=None):
    """Evaluate the embeddings that have not been scored yet on wikidb tables.

    Parameters
    ----------
    subset : {"val", "test"}, default="val"
        Which target tables to evaluate on. With ``"test"``, only the models
        performing best on the validation tables are evaluated.
    filters : dict or None, default=None
        Maps a column of the checkpoint table to the list of values to keep.
        ``None`` applies no filtering. In test mode the value is forwarded
        unchanged to ``make_score_df``.

    Raises
    ------
    ValueError
        If ``subset`` is neither ``"val"`` nor ``"test"``.
    """
    if subset not in ("val", "test"):
        raise ValueError(f"subset must be 'val' or 'test', got {subset!r}")

    ## Evaluate models that have not been evaluated yet
    if subset == "val":
        cls_files = get_cls_val_files()
        reg_files = get_reg_val_files()
    else:
        cls_files = get_cls_wikidb_files()
        reg_files = get_reg_wikidb_files()

    # Load prediction scores if file exists
    scores_path = SEPAL_DIR / "wikidb_scores.parquet"
    if scores_path.is_file():
        scores = pd.read_parquet(scores_path)
        eval_files = [str(f) for f in cls_files + reg_files]
        scores = scores[scores["target_file"].isin(eval_files)]
        scores = scores.drop_duplicates(subset=["id", "target_file"])
        # A model is fully evaluated once it has a score for every file.
        n_files_done = scores.groupby("id")["target_file"].nunique()
        evaluated = n_files_done[n_files_done == len(eval_files)].index.tolist()
    else:
        # First run: nothing has been scored yet. The empty frame keeps the
        # per-model "already done" lookup below working.
        scores = pd.DataFrame(columns=["id", "target_file"])
        evaluated = []

    # Load checkpoint info
    checkpoints = get_full_checkpoint_info()
    checkpoints = checkpoints[checkpoints["subset"] != "train"]

    # Mask the models already evaluated
    to_evaluate = checkpoints[~checkpoints["id"].isin(evaluated)]

    # Filter out experiments
    for column, allowed in (filters or {}).items():
        to_evaluate = to_evaluate[to_evaluate[column].isin(allowed)]

    # If evaluating on test tables, evaluate only the models performing the best on validation tables
    if subset == "test":
        # NOTE(review): as in the original code, ``scores`` is rebound here,
        # so in test mode the "already done" lookup below relies on the frame
        # returned by make_score_df -- confirm it has one row per
        # (id, target_file) already scored.
        scores = make_score_df(filters=filters, tasks="wikidb")
        val_files = [str(f) for f in get_cls_val_files() + get_reg_val_files()]
        val_scores = scores[scores["target_file"].isin(val_files)]
        val_scores = make_scores_normalized(val_scores)
        val_scores = get_best_model(val_scores, ["data", "method", "embed_method"])

        to_evaluate = to_evaluate[
            to_evaluate["id"].isin(val_scores["id"].unique())
        ].reset_index(drop=True)

    # (kind, progress-bar label, scoring, candidate files)
    tasks = (
        ("classification", "Classification", "f1_weighted", cls_files),
        ("regression", "Regression", "r2", reg_files),
    )

    # Loop over the models to evaluate
    for data in to_evaluate["data"].unique():
        data_to_evaluate = to_evaluate[to_evaluate["data"] == data].reset_index(
            drop=True
        )
        num_models = len(data_to_evaluate)

        # Entity indexes are heavy in memory: load each one only when a model
        # needs it, and drop them all when moving on to the next dataset.
        entity_maps = {}

        for index, row in data_to_evaluate.iterrows():
            method = row["method"]
            print(
                f"{index + 1}/{num_models}: Evaluating {DATASETS_NAMES[data]} embeddings by {method} on wikidb tables"
            )

            # Files already scored for this model, computed once per model.
            done_files = set(scores.loc[scores["id"] == row["id"], "target_file"])
            pending = [
                (
                    kind,
                    label,
                    scoring,
                    [f for f in files if str(f) not in done_files],
                )
                for kind, label, scoring, files in tasks
            ]
            if not any(files for *_, files in pending):
                # Nothing left to do: skip loading the embeddings.
                continue

            # Load models' embeddings
            embeddings = _load_embeddings(row)
            index_key = method if method in _BASELINE_METHODS else None
            if index_key not in entity_maps:
                entity_maps[index_key] = _load_entity_to_idx(index_key, data)
            entity_to_idx = entity_maps[index_key]

            for kind, label, scoring, files in pending:
                for target_file in tqdm(files, desc=label, unit="file"):
                    prediction_scores(
                        embeddings=embeddings,
                        id=row["id"],
                        target_file=target_file,
                        data=row["data"],
                        n_repeats=5,
                        tune_hyperparameters=True,
                        scoring=scoring,
                        results_file=scores_path,
                        kind=kind,
                        entity_to_idx=entity_to_idx,
                    )


if __name__ == "__main__":
    # Run the validation tables first, then the test tables, with the same
    # experiment selection. The filters are rebuilt for each call so that
    # both runs receive their own, identical dictionary.
    for subset in ("val", "test"):
        main(
            subset=subset,
            filters={
                "embed_dim": [100, 50],
                "embed_method": ["distmult", "fastrp", "transe", "rotate"],
                "data": [
                    "mini_yago3_lcc",
                    "yago3_lcc",
                    "yago4.5_lcc",
                    "yago4.5_with_full_ontology",
                    "yago4_lcc",
                    "yago4_with_full_ontology",
                    "full_freebase_lcc",
                    "wikikg90mv2_lcc",
                ],
                "num_negs_per_pos": [100],
            },
        )
