from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn as sk
from scipy.interpolate import make_smoothing_spline
from sklearn.ensemble import RandomForestClassifier

from tablator import results_directory
from tablator.analysis.rq3 import get_best_df, inconclusive_trial_cut_off


def sample_strategy_bias(res: pd.DataFrame, res_tpe: pd.DataFrame):
    """Compare TPE sampling against random sampling of training configurations.

    Each raw result frame is grouped by ``path`` and reduced with ``get_best_df``
    (presumably one best row per run — see that helper). For the selected
    dataset/metric the rows are grouped by optimizer name and tabulated:

    * ``max`` / ``mean`` of the metric over rows that beat the dataset's
      ``inconclusive_trial_cut_off``,
    * ``prob``: fraction of (pre-cut-off) rows that used each optimizer,
    * ``sampling``: ``"tpe"`` or ``"random"``.

    The mean TPE-minus-random difference of ``max`` and ``mean`` is printed and
    the table is written as LaTeX to
    ``results_directory / "rq2" / f"tpe_vs_random_{dataset}.tex"``.
    Only ``covtype`` / ``val_acc`` is reported at the moment.

    Args:
        res: raw trial results obtained with random sampling.
        res_tpe: raw trial results obtained with TPE sampling.
    """
    attr = "train_config.optimizer_config.name"
    ds = "train_config.dataset"
    metrics = ["val_acc", "val_rmse"]
    cols = [attr] + metrics
    random_optim = res.groupby("path").apply(get_best_df).set_index(ds)[cols]
    tpe_optim = res_tpe.groupby("path").apply(get_best_df).set_index(ds)[cols]

    def _make_df(_df, name, _df_counts):
        # _df: metric grouped by optimizer (after the cut-off);
        # _df_counts: per-optimizer row counts (before the cut-off).
        __df = pd.concat(
            [
                _df.max(),
                _df.mean(),
                _df_counts / _df_counts.sum(),
            ],
            axis=1,
        )
        __df.columns = ["max", "mean", "prob"]
        __df["sampling"] = name
        return __df

    def _cut_off_groups(optim_df, dataset, metric):
        # Keep rows strictly below the cut-off for val_rmse and strictly above
        # it for every other metric, then group the metric by optimizer.
        obj = -1 if metric == "val_rmse" else 1
        ds_df = optim_df.loc[dataset]
        keep = ds_df[metric] * obj > inconclusive_trial_cut_off[dataset] * obj
        return ds_df.loc[keep.values].groupby(attr)[metric]

    def _write_table(tpe_optim, random_optim, dataset, metric):
        tpe_optim_counts = tpe_optim.loc[dataset].groupby(attr)[metric].count()
        random_optim_counts = random_optim.loc[dataset].groupby(attr)[metric].count()

        res_table = pd.concat(
            [
                _make_df(
                    _cut_off_groups(tpe_optim, dataset, metric),
                    "tpe",
                    tpe_optim_counts,
                ),
                _make_df(
                    _cut_off_groups(random_optim, dataset, metric),
                    "random",
                    random_optim_counts,
                ),
            ]
        )
        by_sampling = res_table.set_index("sampling", append=True)

        def _tpe_minus_random(column):
            # Align on optimizer name instead of row position: an optimizer
            # present for only one strategy yields NaN (skipped by mean)
            # rather than a shape error or a mis-paired subtraction.
            tpe_vals = by_sampling[column].xs("tpe", level="sampling")
            random_vals = by_sampling[column].xs("random", level="sampling")
            return (tpe_vals - random_vals).mean()

        best_trial_error = _tpe_minus_random("max")
        mean_trial_error = _tpe_minus_random("mean")
        print(f"TPE best_trial error {best_trial_error}")
        print(f"TPE mean_trial error {mean_trial_error}")

        # This may run before anything else has created the rq2 directory.
        save_dir = results_directory.joinpath("rq2")
        save_dir.mkdir(parents=True, exist_ok=True)
        res_table.reset_index().set_index([attr, "sampling"]).sort_index().to_latex(
            save_dir.joinpath(f"tpe_vs_random_{dataset}.tex"),
            escape=False,
        )

    _write_table(tpe_optim, random_optim, "covtype", "val_acc")


def non_random_errors(res: pd.DataFrame):
    """Report which configuration features best separate errored trials.

    A random forest (``random_state=0``) is trained to distinguish the trials in
    ``res`` (label 0) from the rows of ``results_directory/errored_trials.csv``
    (label 1). Object-typed feature columns are label-encoded over the combined
    data. The three features with the highest impurity-based importance are
    printed. The script also prints the gap between their mean importance and the
    mean importance of the remaining features, normalised by the overall mean
    importance.

    Args:
        res: raw trial results. Each row has a ``path`` column (the trial
            directory, whose last component is the trial uid) and the
            configuration columns listed in ``feature_names``.
    """
    from sklearn.preprocessing import LabelEncoder

    errored_df = pd.read_csv(results_directory.joinpath("errored_trials.csv"))
    # One configuration row per trial: the first row of each ``path`` group.
    all_trials_df = res.groupby("path").apply(lambda x: x.iloc[0])
    # The uid is the last component of the trial path.
    all_trials_df["uid"] = all_trials_df["path"].apply(lambda x: Path(x).name)
    all_trials_df.set_index("uid", inplace=True)
    # Remove errored uids from the results so no trial lands in both classes.
    # NOTE(review): assumes errored_trials.csv identifies trials by a ``uid``
    # column; otherwise falls back to its index, as before. When the CSV has no
    # index column that is a RangeIndex and nothing is removed.
    errored_uids = (
        errored_df["uid"] if "uid" in errored_df.columns else errored_df.index
    )
    # Boolean mask (not a set) keeps row order deterministic across runs.
    all_trials_df = all_trials_df[~all_trials_df.index.isin(errored_uids)]

    feature_names = np.array(
        [
            "train_config.optimizer_config.name",
            "train_config.optimizer_config.lr",
            "train_config.epochs",
            "train_config.dataset",
            "train_config.normalization",
            "train_config.cat_nan_policy",
            "train_config.cat_policy",
            "train_config.cat_min_frequency",
            "train_config.dataset_seed",
            "model_config.token_bias",
            "model_config.n_layers",
            "model_config.d_token",
            "model_config.n_heads",
            "model_config.d_ffn_factor",
            "model_config.attention_dropout",
            "model_config.ffn_dropout",
            "model_config.residual_dropout",
            "model_config.prenormalization",
            "model_config.initialization",
            "model_config.mask_type",
            "model_config.activation",
            "model_config.residual",
            "model_config.random_mask_alpha",
        ]
    )
    X_train = pd.concat(
        [all_trials_df[feature_names], errored_df[feature_names]], axis=0
    )
    # Encode categorical columns over the combined data so both classes share
    # one encoding.
    for fname, dtype in X_train.dtypes.items():
        if str(dtype) == "object":
            X_train[fname] = LabelEncoder().fit_transform(X_train[fname])
    y_train = np.concatenate(
        [np.zeros(all_trials_df.shape[0]), np.ones(errored_df.shape[0])]
    )
    forest = RandomForestClassifier(random_state=0)
    forest.fit(X_train, y_train)

    importances = forest.feature_importances_
    imp_feats = feature_names[np.argsort(importances)[::-1]][:3]
    print(f"Most important factors leading to OOM errors. {imp_feats}")

    sorted_imp = np.sort(importances)[::-1]
    avg_feature_importance = (
        sorted_imp[:3].mean() - sorted_imp[3:].mean()
    ) / sorted_imp.mean()
    print(f"Mean feature importance. {avg_feature_importance}")


def smooth_line(y, n_points=500):
    """Fit a smoothing spline to a Series and sample it on an even grid.

    Args:
        y: Series whose numeric index is the abscissa and whose values are the
            ordinates. The index does not need to be sorted, and NaN values are
            dropped before fitting. At least 5 points must remain, as required
            by ``scipy.interpolate.make_smoothing_spline``.
        n_points: number of evenly spaced abscissas at which to evaluate the
            spline.

    Returns:
        Tuple ``(x, y)`` of arrays of length ``n_points``. ``x`` spans the
        smallest to largest remaining index value.
    """
    # make_smoothing_spline rejects non-ascending abscissas, and a NaN
    # ordinate would propagate through the whole fit.
    y = y.dropna().sort_index()
    x = y.index
    spline = make_smoothing_spline(x.values, y.values)
    grid = np.linspace(x.min(), x.max(), n_points)
    return grid, spline(grid)


def _plot_smoothed_band(axis, mean_series, std_series, label):
    """Plot a smoothed mean curve with a +/- one-std band on ``axis``.

    Both series are indexed by budget (``n_samples``) and are smoothed with
    ``smooth_line``. The band reuses the line's colour.
    """
    x, centre = smooth_line(mean_series)
    _, spread = smooth_line(std_series)
    (line,) = axis.plot(x, centre, label=label)
    axis.fill_between(
        x=x,
        y1=centre - spread,
        y2=centre + spread,
        alpha=0.35,
        color=line.get_color(),
    )


def budget_allocation(res: pd.DataFrame):
    """Plot how the trial budget affects mean and best validation accuracy.

    ``res`` is grouped by ``path`` and reduced with ``get_best_df``. For each
    dataset the resulting rows are then randomly sub-sampled with 5 seeds x
    budgets of 10, 15, ..., 45 rows (capped at the available rows). Two values
    are recorded for every sub-sample:

    * ``mean``: mean ``val_acc`` over rows above the dataset's
      ``inconclusive_trial_cut_off``,
    * ``best``: maximum ``val_acc``.

    Sub-samples whose ``mean`` is NaN are dropped. The across-seed mean, with a
    +/- one-standard-deviation band, is smoothed and plotted against budget.
    The figure is saved to ``results_directory / "rq2" / "resource_allocation.png"``.

    Args:
        res: raw trial results with ``path``, ``train_config.dataset`` and
            ``val_acc`` columns.
    """
    save_dir = results_directory.joinpath("rq2")
    save_dir.mkdir(exist_ok=True, parents=True)
    best_res = res.groupby("path").apply(get_best_df)

    def _sub_sample_mean(x, n_samples, j):
        # The seed depends on budget and repetition, so runs are reproducible.
        rand_idxs = np.random.RandomState(seed=n_samples + j * 3).permutation(
            x.shape[0]
        )
        x = x.iloc[rand_idxs[:n_samples]]
        ds_name = x["train_config.dataset"].iloc[0]

        return pd.Series(
            {
                "mean": x[(x["val_acc"] > inconclusive_trial_cut_off[ds_name])][
                    "val_acc"
                ].mean(),
                "best": x["val_acc"].max(),
                "n_samples": n_samples,
            }
        )

    rows = []
    for j in range(5):
        for i in range(2, 10):
            n_samples = int(i / 2 * 10)
            results = (
                best_res.groupby("train_config.dataset")
                # Applied immediately, so the closure sees this iteration's values.
                .apply(lambda x: _sub_sample_mean(x, n_samples, j))
                .dropna()
            )
            rows.append(results)
    _df = pd.concat(rows)

    fig, ax = plt.subplots(1, 2, figsize=(6, 6))
    for ds_name in _df.index.unique():
        # A list key keeps a DataFrame even when the dataset has a single row.
        per_budget = _df.loc[[ds_name]].groupby("n_samples")
        mean_perf = per_budget.mean()
        y_err = per_budget.std()
        ds_display_name = " ".join(ds_name.capitalize().split("_"))
        _plot_smoothed_band(ax[0], mean_perf["mean"], y_err["mean"], ds_display_name)
        _plot_smoothed_band(ax[1], mean_perf["best"], y_err["best"], ds_name + " best")
    # The right panel shares the left panel's y-range, so its ticks are redundant.
    ax[1].set_yticks([])
    ax[0].legend()
    both_limits = ax[0].get_ylim() + ax[1].get_ylim()
    lims = min(both_limits), max(both_limits)
    ax[0].set_title("Mean Performance", size=12)
    ax[0].set_xlabel("Budget N. Trials", size=12)
    ax[1].set_title("Best Performance", size=12)
    ax[1].set_xlabel("Budget N. Trials", size=12)
    ax[0].set_ylabel("Accuracy", size=12)
    ax[1].set_ylim(lims)
    ax[0].set_ylim(lims)
    fig.tight_layout()
    fig.savefig(save_dir.joinpath("resource_allocation.png"))
    plt.close(fig)


def resource_utilization():
    """Print mean +/- std and max of the per-trial memory usage.

    Reads ``results_directory / "rq2" / "memory_reqs.csv"`` (indexed by
    ``uid``). Each ``memory`` entry is parsed by dropping its last three
    characters (presumably a unit suffix such as ``MiB``, matching the printed
    label — TODO confirm) and converting the rest to ``int``. The summary line
    contains a LaTeX ``\\pm`` so it can be pasted into a table.
    """
    mem_usage = pd.read_csv(
        results_directory.joinpath("rq2", "memory_reqs.csv")
    ).set_index("uid")
    # Strip the trailing 3-character suffix and keep the integer amount.
    memory = mem_usage["memory"].apply(lambda x: int(x[:-3]))
    # "\\pm" emits a literal backslash; a bare "\pm" is an invalid escape sequence.
    print(
        f"Mean Usage (MiB): {memory.mean():,.2f} \\pm {memory.std():,.2f} "
        f"max: {memory.max():,.2f}"
    )


def stat_test(res: pd.DataFrame):
    """Return the mean validation metrics for each dataset.

    Args:
        res: raw trial results with a ``train_config.dataset`` column and the
            metric columns ``val_rmse``, ``val_auc`` and ``val_acc``.

    Returns:
        DataFrame indexed by dataset, with one column per metric (NaNs skipped).
    """
    return res.groupby("train_config.dataset")[
        ["val_rmse", "val_auc", "val_acc"]
    ].mean()


if __name__ == "__main__":
    # Raw per-trial results: random sampling first, then TPE sampling.
    res, res_tpe = (
        pd.read_csv(results_directory.joinpath(csv_name))
        for csv_name in ("raw_results_random.csv", "raw_results_tpe.csv")
    )

    # Analyses run in a fixed order. Each prints to stdout and/or writes files
    # under results_directory.
    sample_strategy_bias(res, res_tpe)
    non_random_errors(res)
    resource_utilization()
    budget_allocation(res)
