import copy
import json
import typing as ty
from pathlib import Path

import pandas as pd
from ablator import Optim, PlotAnalysis
from sklearn.preprocessing import MinMaxScaler

from tablator import data_directory, results_directory

#### ANALYSIS SPECIFIC TO ABLATOR


def make_analysis(
    res,
    save_dir: Path,
    metrics,
    categorical_name_remap,
    numerical_name_remap,
    cache=True,
):
    """Build a ``PlotAnalysis`` for ``res`` and render its figures.

    Args:
        res: Results passed straight to ``PlotAnalysis``.
        save_dir: Directory handed to ``PlotAnalysis`` as ``save_dir``.
        metrics: Mapping passed as ``optim_metrics``.
        categorical_name_remap: Maps categorical attribute keys to display
            names; its keys are used as the categorical attributes.
        numerical_name_remap: Same as above, for numerical attributes.
        cache: Forwarded unchanged to ``PlotAnalysis``.
    """
    categorical_attributes = list(categorical_name_remap)
    numerical_attributes = list(numerical_name_remap)
    plotter = PlotAnalysis(
        res,
        save_dir=save_dir,
        cache=cache,
        optim_metrics=metrics,
        categorical_attributes=categorical_attributes,
        numerical_attributes=numerical_attributes,
    )

    # Numerical names are applied last, so they win on a key collision.
    attribute_labels = dict(categorical_name_remap)
    attribute_labels.update(numerical_name_remap)
    metric_labels = {
        "val_acc": "Accuracy",
        "val_rmse": "RMSE",
    }
    plotter.make_figures(
        metric_name_remap=metric_labels,
        attribute_name_remap=attribute_labels,
    )


def make_dataset_analysis(
    res: pd.DataFrame,
    save_dir: Path,
    aux_metrics,
    categorical_name_remap,
    numerical_name_remap,
):
    """Run ``make_analysis`` once per dataset returned by ``get_dataset_map``.

    Datasets with no rows in ``res`` are skipped. Figures for each dataset are
    written under ``save_dir/<dataset name>``.

    Args:
        res: Raw results; must contain a ``train_config.dataset`` column.
        save_dir: Parent directory for the per-dataset output directories.
        aux_metrics: Mapping from task type to the metrics to optimise.
        categorical_name_remap: Display names for categorical attributes.
        numerical_name_remap: Display names for numerical attributes.
    """
    for _, dataset_row in get_dataset_map():
        # Each row carries (basename, task_type, n_classes); the class count
        # is not needed here.
        ds_name, task_type, _n_classes = dataset_row
        subset = copy.deepcopy(res)
        subset = subset[subset["train_config.dataset"] == ds_name]
        if subset.empty:
            continue
        subset = parse_metrics(subset, task_type=task_type, ds_name=ds_name)
        make_analysis(
            subset,
            save_dir.joinpath(ds_name),
            aux_metrics[task_type],
            categorical_name_remap=categorical_name_remap,
            numerical_name_remap=numerical_name_remap,
        )


def make_dataset_cat_analysis(
    res: pd.DataFrame,
    save_dir: Path,
    aux_metrics,
    categorical_name_remap,
    numerical_name_remap,
):
    """Run ``make_analysis`` once per task type, pooling that type's datasets.

    For every ``(task_type, datasets)`` pair from ``get_dataset_types`` the
    rows of ``res`` belonging to those datasets are selected, reduced with
    ``parse_metrics_datasets`` and plotted under ``save_dir/<task_type>``.
    Task types with no matching rows are skipped, mirroring
    ``make_dataset_analysis``.

    Args:
        res: Raw results; must contain a ``train_config.dataset`` column.
        save_dir: Parent directory for the per-task-type output directories.
        aux_metrics: Mapping from task type to the metrics to optimise.
        categorical_name_remap: Display names for categorical attributes.
        numerical_name_remap: Display names for numerical attributes.
    """
    for task_type, datasets in get_dataset_types():
        # Filter first, then copy only the matching rows, so the caller's
        # frame is never touched and the full frame is not duplicated.
        in_task = res["train_config.dataset"].isin(datasets)
        _res = res[in_task].copy()
        if _res.empty:
            # Nothing to aggregate or rescale for this task type.
            continue
        _res = parse_metrics_datasets(_res, task_type=task_type, ds_names=datasets)
        ds_save_dir = save_dir.joinpath(task_type)
        ds_metrics = aux_metrics[task_type]
        make_analysis(
            _res,
            ds_save_dir,
            ds_metrics,
            categorical_name_remap=categorical_name_remap,
            numerical_name_remap=numerical_name_remap,
        )


def run_analysis(res, save_dir: Path, cat_analysis: bool = True):
    """Produce the per-dataset (and optionally per-task-type) figures.

    A deep copy of ``res`` is relabelled with ``rename_categorical`` before
    plotting, so the caller's frame is left untouched. Output goes to
    ``save_dir/dataset_comparison`` and, when ``cat_analysis`` is true,
    ``save_dir/cat_comparison``; both directories are always created.

    Args:
        res: Raw results data frame.
        save_dir: Root output directory.
        cat_analysis: Whether to also run the task-type level comparison.
    """
    res = copy.deepcopy(res)
    remaps = {
        "categorical_name_remap": {
            "model_config.activation": "Activation",
            "model_config.initialization": "Weight Init.",
            "train_config.optimizer_config.name": "Optimizer",
            "model_config.mask_type": "Mask Type",
            "train_config.cat_nan_policy": "Policy for Cat. Missing",
            "train_config.normalization": "Dataset Normalization",
        },
        "numerical_name_remap": {
            "model_config.n_heads": "N. Heads",
            "model_config.n_layers": "N. Layers",
            "model_config.d_token": "Token Hidden Dim.",
            "model_config.d_ffn_factor": "Hidden Dim. Factor",
        },
    }
    # Metric and optimisation direction per task type.
    aux_metrics = {
        "regression": {"val_rmse": Optim.min},
        "classification": {"val_acc": Optim.max},
    }

    cat_save_dir = save_dir.joinpath("cat_comparison")
    ds_save_dir = save_dir.joinpath("dataset_comparison")
    for out_dir in (ds_save_dir, cat_save_dir):
        out_dir.mkdir(exist_ok=True, parents=True)

    res = rename_categorical(res)
    if cat_analysis:
        make_dataset_cat_analysis(res, cat_save_dir, aux_metrics, **remaps)
    make_dataset_analysis(res, ds_save_dir, aux_metrics, **remaps)


# Motivation https://www.ncbi.nlm.nih.gov/pmc/articles/PMC1124168/
# Per-dataset cut-off used by `parse_metrics` and `parse_metrics_datasets` to
# discard inconclusive trials: the value is compared against `val_rmse` for
# regression tasks and against `val_acc` otherwise. Each dataset appears
# exactly once (repeated keys in a dict literal silently override each other).
inconclusive_trial_cut_off = {
    "epsilon": 0.55,
    "covtype": 0.55,
    "aloi": 0.08,
    "adult": 0.77,
    "california_housing": 1.1,
    "helena": 0.08,
    "higgs_small": 0.55,
    "jannis": 0.55,
    "microsoft": 0.82,
    "yahoo": 0.95,
    "year": 10.5,
}


def get_dataset_types() -> ty.List[ty.Tuple[str, ty.List[str]]]:
    """Return ``(task_type, basenames)`` pairs, one per task type.

    The basenames come from the ``basename`` column of ``get_dataset_info``,
    grouped by its ``task_type`` column.
    """

    def _basenames(group):
        return group["basename"].values

    names_by_task = get_dataset_info().groupby("task_type").apply(_basenames)
    return [*names_by_task.items()]


def get_dataset_map() -> ty.List[ty.Tuple[int, ty.Tuple[str, str, int]]]:
    """Return ``(index, row)`` pairs describing every known dataset.

    Each row holds the ``basename``, ``task_type`` and ``n_classes`` values
    (in that order) taken from ``get_dataset_info``.
    """
    wanted_columns = ["basename", "task_type", "n_classes"]
    selected = get_dataset_info()[wanted_columns]
    return list(selected.iterrows())


def get_dataset_info():
    """Load every ``info.json`` found under ``data_directory`` into a frame.

    One row is produced per file. The ``multiclass`` and ``binclass`` task
    types are merged into a single ``classification`` label.
    """
    records = []
    for info_path in data_directory.rglob("info.json"):
        records.append(json.loads(info_path.read_text()))
    dataset_info = pd.DataFrame(records)

    is_classification = dataset_info["task_type"].isin(["multiclass", "binclass"])
    dataset_info.loc[is_classification, "task_type"] = "classification"
    return dataset_info


def rename_categorical(res: pd.DataFrame):
    """Rewrite categorical config values as bold mathtext labels.

    The six categorical columns are overwritten in place on ``res`` and the
    same frame is returned. Every value is wrapped as ``$\\bf{...}$``;
    underscore-separated values are split and each word is wrapped on its own.

    Raises:
        KeyError: If a column is missing or an optimizer name is unknown.
    """

    def _bold(text):
        return r"$\bf{" + text + r"}$"

    def _bold_capitalized(value):
        return _bold(value.capitalize())

    # Activation names are upper-cased, one bold token per word.
    res["model_config.activation"] = res["model_config.activation"].apply(
        lambda value: " ".join(_bold(word) for word in value.upper().split("_"))
    )

    for column in ("model_config.mask_type", "model_config.initialization"):
        res[column] = res[column].apply(_bold_capitalized)

    # Optimizers have fixed display spellings rather than a simple capitalize.
    optimizer_remap = {
        "adabelief": "AdaBelief",
        "adamw": "AdamW",
        "radam": "RAdam",
        "sgd": "SGD",
        "adam": "Adam",
    }
    optimizer_column = "train_config.optimizer_config.name"
    res[optimizer_column] = res[optimizer_column].apply(
        lambda value: _bold(optimizer_remap[value])
    )

    res["train_config.normalization"] = res["train_config.normalization"].apply(
        _bold_capitalized
    )

    res["train_config.cat_nan_policy"] = res["train_config.cat_nan_policy"].apply(
        lambda value: " ".join(_bold_capitalized(word) for word in value.split("_"))
    )
    return res


def parse_metrics(res: pd.DataFrame, task_type: str, ds_name: str):
    """Keep the best row per ``path`` and drop inconclusive trials.

    Args:
        res: Results for a single dataset; needs a ``path`` column and the
            metric column for ``task_type``.
        task_type: ``"regression"`` selects ``val_rmse`` (lower is better);
            anything else selects ``val_acc`` (higher is better).
        ds_name: Key into ``inconclusive_trial_cut_off``.

    Returns:
        The best row of every ``path`` group whose metric is present and
        strictly better than the dataset's cut-off.
    """
    best_per_path = (
        res.groupby("path")
        .apply(lambda trial: get_best(trial, task_type))
        .reset_index(drop=True)
    )
    if task_type == "regression":
        scores = best_per_path["val_rmse"]
        conclusive = scores < inconclusive_trial_cut_off[ds_name]
    else:
        scores = best_per_path["val_acc"]
        conclusive = scores > inconclusive_trial_cut_off[ds_name]

    return best_per_path[conclusive & scores.notna()]


def get_best(x: pd.DataFrame, task_type: str):
    """Return the row of ``x`` with the best metric for ``task_type``.

    For ``"regression"`` this is the row with the lowest ``val_rmse``; for any
    other task type, the row with the highest ``val_acc``. Rows with a missing
    metric are only chosen when no row has a value.
    """
    if task_type == "regression":
        metric, na_position, pick = "val_rmse", "last", 0
    else:
        metric, na_position, pick = "val_acc", "first", -1
    # NaNs are pushed to the end opposite the one we read from.
    ordered = x.sort_values(metric, na_position=na_position)
    return ordered.iloc[pick]


def get_best_df(x):
    """Return the best row of ``x``, inferring the task type from its metrics.

    A frame whose ``val_rmse`` column is entirely missing is treated as
    classification; otherwise it is treated as regression.
    """
    if x["val_rmse"].isna().all():
        return get_best(x, "classification")
    return get_best(x, "regression")


def parse_metrics_datasets(res: pd.DataFrame, task_type: str, ds_names: ty.List[str]):
    """Keep the best row per ``path`` and rescale the metric per dataset.

    For every dataset in ``ds_names`` the trials on the wrong side of that
    dataset's ``inconclusive_trial_cut_off`` are dropped, as are rows with a
    missing metric. The surviving metric values of that dataset are then
    min-max scaled so that datasets can be pooled.

    Args:
        res: Results for all datasets of one task type; needs ``path`` and
            ``train_config.dataset`` columns plus the metric column.
        task_type: ``"regression"`` selects ``val_rmse`` (lower is better);
            anything else selects ``val_acc`` (higher is better).
        ds_names: Dataset names to filter and rescale.

    Returns:
        The filtered frame with the metric column rescaled per dataset.

    Raises:
        KeyError: If a dataset name has no entry in
            ``inconclusive_trial_cut_off``.
    """
    res = (
        res.groupby("path")
        .apply(lambda x: get_best(x, task_type))
        .reset_index(drop=True)
    )

    # Multiplying by `obj` turns both cases into "smaller is worse", so one
    # comparison covers minimised and maximised metrics alike.
    if task_type == "regression":
        metric, obj = "val_rmse", -1
    else:
        metric, obj = "val_acc", 1

    for ds_name in ds_names:
        is_ds = res["train_config.dataset"] == ds_name
        inconclusive = is_ds & (
            res[metric] * obj < inconclusive_trial_cut_off[ds_name] * obj
        )
        # Copy so the in-place rescale below writes to an independent frame
        # instead of a slice of the previous one.
        res = res[~inconclusive & res[metric].notna()].copy()

        is_ds = res["train_config.dataset"] == ds_name
        if not is_ds.any():
            # MinMaxScaler cannot be fitted on zero rows; nothing to rescale.
            continue
        res.loc[is_ds, [metric]] = MinMaxScaler().fit_transform(
            res.loc[is_ds, [metric]]
        )

    return res


def make_results_table(res: pd.DataFrame):
    """Write a LaTeX table comparing hard-coded reference scores with ours.

    The best score per dataset is taken from ``res`` via ``get_best_df`` and
    placed next to the hard-coded ``ft-transformer`` column. Row labels get an
    up arrow for classification and a down arrow for regression. The
    transposed table is saved to ``results_directory/rq3/results.tex``.

    Args:
        res: Raw results with ``train_config.dataset``, ``val_acc`` and
            ``val_rmse`` columns.
    """
    results_directory.joinpath("rq3").mkdir(exist_ok=True, parents=True)
    results_table = pd.Series(
        [0.459, 0.859, 0.391, 0.732, 0.729, 0.960, 0.8982, 8.855, 0.970, 0.756, 0.746],
        index=[
            "california_housing",
            "adult",
            "helena",
            "jannis",
            "higgs_small",
            "aloi",
            "epsilon",
            "year",
            "covtype",
            "yahoo",
            "microsoft",
        ],
    )
    results_table.name = "ft-transformer"
    res_table = res.groupby("train_config.dataset")[["val_acc", "val_rmse"]].apply(
        get_best_df
    )
    # Flattens the non-missing cells into one column.
    # NOTE(review): this only lines up if every dataset has exactly one of
    # val_acc / val_rmse populated -- confirm against the raw results.
    res_table["tablator"] = res_table.values[~res_table.isna()]
    final_table = pd.concat([results_table, res_table["tablator"]], axis=1)

    final_table["task_type"] = "regression"
    final_table.loc[res_table["val_rmse"].isna(), "task_type"] = "classification"

    final_table["is_best"] = "$\\uparrow$"
    # Single .loc assignment: chained `df[col][mask] = ...` may write to a
    # temporary and is a no-op under pandas Copy-on-Write.
    final_table.loc[
        final_table["task_type"] == "regression", "is_best"
    ] = "$\\downarrow$"

    final_table.index = [
        f"{name} {arrow}"
        for name, arrow in zip(final_table.index, final_table["is_best"])
    ]
    final_table[["ft-transformer", "tablator"]].T.to_latex(
        results_directory.joinpath("rq3", "results.tex"), escape=False
    )


def optim_rank(
    results: pd.DataFrame, optim: ty.Optional[str] = None, is_best: bool = False
):
    """Rank optimizers per dataset by their aggregated validation metric.

    Scores are aggregated per (dataset, optimizer) pair: by mean, or by the
    best value (max ``val_acc`` / min ``val_rmse``) when ``is_best`` is true.
    Optimizers are then ordered best-first within each dataset.

    Args:
        results: Frame with ``train_config.dataset``,
            ``train_config.optimizer_config.name``, ``val_acc`` and
            ``val_rmse`` columns.
        optim: If ``None``, return the name of the top optimizer per dataset;
            otherwise return the zero-based rank of ``optim`` per dataset.
        is_best: Aggregate with the best value instead of the mean.

    Returns:
        A Series indexed by dataset: ``val_acc`` entries first, followed by
        ``val_rmse`` entries.

    Raises:
        ValueError: If ``optim`` has no score for one of the datasets.
    """
    group_columns = ["train_config.dataset", "train_config.optimizer_config.name"]

    def _get_for_metric(metric="val_acc"):
        higher_is_better = metric == "val_acc"
        if not is_best:
            reduction = "mean"
        elif higher_is_better:
            reduction = "max"
        else:
            reduction = "min"

        def _best_first(scores):
            ascending = scores.sort_values()
            return ascending[::-1] if higher_is_better else ascending

        def _summarise(scores):
            ranked = _best_first(scores)
            if optim is None:
                # Index entries are (dataset, optimizer) tuples.
                return ranked[:1].index.values[0][1]
            optimizer_names = ranked.reset_index()[
                "train_config.optimizer_config.name"
            ]
            return optimizer_names.values.tolist().index(optim)

        grouped_metric = results.groupby(group_columns)[metric]
        aggregated = getattr(grouped_metric, reduction)().dropna()
        return aggregated.groupby("train_config.dataset").apply(_summarise)

    best_optim_for_classification = _get_for_metric(metric="val_acc")
    best_optim_for_regression = _get_for_metric(metric="val_rmse")
    return pd.concat([best_optim_for_classification, best_optim_for_regression])


def best_optim(res: pd.DataFrame):
    """Print optimizer ranking statistics computed from ``res``.

    The best row per ``path`` is selected first. The function then prints how
    often SGD is the top optimizer, SGD's mean one-based rank, and a LaTeX
    table of mean/best ranks for every optimizer.

    Args:
        res: Raw results with a ``path`` column and the columns required by
            ``get_best_df`` and ``optim_rank``.
    """
    best_res_by_config = res.groupby("path").apply(get_best_df)

    winners = optim_rank(best_res_by_config)
    sgd_count = (winners == "sgd").sum()
    print(f"SGD outperforms on average {sgd_count} / {winners.shape[0]}")

    sgd_rank = optim_rank(best_res_by_config, "sgd")
    print(f"SGD rank {sgd_rank.mean()+1}.")

    optimizer_column = best_res_by_config["train_config.optimizer_config.name"]
    optimizers = optimizer_column.unique().tolist()

    def best_vs_worst(optim):
        # optim_rank is zero-based; shift to one-based ranks for reporting.
        mean_based = optim_rank(best_res_by_config, optim) + 1
        best_based = optim_rank(best_res_by_config, optim, is_best=True) + 1
        # NOTE(review): the -4 split assumes the last four entries are the
        # regression datasets (optim_rank lists val_acc entries first, then
        # val_rmse entries) -- confirm the dataset count.
        return dict(
            mean_ranks=mean_based.mean(),
            best_ranks=best_based.mean(),
            regression_mean=mean_based.iloc[-4:].mean(),
            classification_mean=mean_based.iloc[:-4].mean(),
            regression_best=best_based.iloc[-4:].mean(),
            classification_best=best_based.iloc[:-4].mean(),
            optim=optim,
        )

    res_df = pd.DataFrame([best_vs_worst(name) for name in optimizers])
    print("Optimizer Rank")
    print(res_df.to_latex())


def make_dataset_table():
    """Write a LaTeX summary of the datasets to ``dataset_info.tex``.

    The table lists each dataset's basename, feature counts, split sizes and
    class count. Classification datasets with a missing ``n_classes`` are
    given the value 2. The file is written to
    ``results_directory/dataset_info.tex`` without the index column.
    """
    dataset_info = get_dataset_info()
    columns = [
        "basename",
        "task_type",
        "n_num_features",
        "n_cat_features",
        "train_size",
        "val_size",
        "n_classes",
    ]
    # Explicit copy so the edits below act on an independent frame rather
    # than on a column-selected slice of `dataset_info`.
    dataset_table = dataset_info[columns].copy()

    missing_class_count = dataset_table["n_classes"].isna() & (
        dataset_table["task_type"] == "classification"
    )
    dataset_table.loc[missing_class_count, "n_classes"] = 2

    # task_type was only needed to find the rows above; it is not reported.
    dataset_table = dataset_table.drop("task_type", axis=1)
    dataset_table.to_latex(results_directory.joinpath("dataset_info.tex"), index=False)


if __name__ == "__main__":
    # Load both raw result files before any analysis starts.
    res_random = pd.read_csv(results_directory.joinpath("raw_results_random.csv"))
    res_tpe = pd.read_csv(results_directory.joinpath("raw_results_tpe.csv"))

    # The `random` results get both analyses; the `tpe` results skip the
    # task-type level comparison.
    run_analysis(res_random, results_directory.joinpath("random"))
    run_analysis(res_tpe, results_directory.joinpath("tpe"), cat_analysis=False)

    # Tables and optimizer statistics are computed from the `random` results.
    make_results_table(res_random)
    best_optim(res_random)
    make_dataset_table()
