"""Plot hyperparameter transfer experiment results."""

from __future__ import annotations

import pathlib
import re
import sys
from typing import Sequence

import fire
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tueplots

sys.path.insert(0, "../../../experiments")
sys.path.insert(0, "analysis/experiments")
import utils


def _read_and_merge_results(
    raw_results_file_train: str,
    raw_results_file_test: str,
    raw_results_file_val: str,
) -> pd.DataFrame:
    """Read the train/test/val result CSVs and left-join them on ``"run name"``.

    From the test file every column whose name contains "Test" or "Validation"
    is joined; from the val file every column containing "Validation". A
    ``hidden_size`` column is derived from the first integer found in the
    ``hidden_sizes`` column, and an ``"Unnamed: 0"`` column is dropped if present.
    """
    df_train = pd.read_csv(raw_results_file_train)
    df_test = pd.read_csv(raw_results_file_test)
    df_val = pd.read_csv(raw_results_file_val)

    test_metrics = [
        col for col in df_test.columns if "Test" in col or "Validation" in col
    ]
    val_metrics = [col for col in df_val.columns if "Validation" in col]

    df = pd.merge(
        df_train, df_test[["run name"] + test_metrics], how="left", on=["run name"]
    )
    df = pd.merge(df, df_val[["run name"] + val_metrics], how="left", on=["run name"])

    # ``hidden_sizes`` is stored as text (assumed to look like "[64, 64]" — confirm);
    # the first integer is taken as the width. This also handles a single-layer
    # value such as "[64]".
    df["hidden_size"] = df["hidden_sizes"].map(
        lambda sizes: int(re.search(r"\d+", str(sizes)).group())
    )
    # Presumably the index column written by ``DataFrame.to_csv``; tolerate CSVs
    # that were saved without it.
    return df.drop(columns=["Unnamed: 0"], errors="ignore")


def _selection_direction(select_criteria: str) -> str:
    """Return "min" for NLL-type criteria and "max" for accuracy-type criteria.

    Raises:
        ValueError: if the direction cannot be inferred from the column name.
    """
    if "NLL" in select_criteria:
        return "min"
    if "Accuracy" in select_criteria:
        return "max"
    raise ValueError(
        f"Not sure if select_criteria {select_criteria!r} should be min or maxed"
    )


def _model_selection(
    df: pd.DataFrame,
    select_criteria: str,
    select_direction: str,
    width_select: int,
) -> pd.DataFrame:
    """Compare two learning-rate selection strategies per parametrization.

    For each parametrization two sets of rows are collected:

    * ``"extrapolated"``: the LR that is best (by ``select_criteria``) at width
      ``width_select`` is reused for every width.
    * ``"best validation"``: the LR is chosen separately at every width.

    Each selected row is joined with the best test accuracy at that width
    across all parametrizations and LRs. ``relative_test_accuracy`` is the
    selected test accuracy divided by that best value.
    """
    target = "Test Accuracy/dataloader_idx_0"
    frames = []
    for model in ["MLPIVILowRank"]:
        df_model = df.query(f"model == '{model}'")

        # Best achievable test accuracy per width (across parametrizations).
        best_idx = df_model.groupby("hidden_size")[target].idxmax()
        df_best_test = df_model.loc[best_idx, ["hidden_size", target]]

        for parametrization in ["Standard", "MaximalUpdate"]:
            df_param = df_model.query(f"parametrization == '{parametrization}'")

            # Strategy 1: tune the LR at the smallest width, transfer it to all widths.
            df_small = df_param.query("hidden_size == %d" % width_select)
            if select_direction == "min":
                best_small = df_small.nsmallest(1, select_criteria)
            else:
                best_small = df_small.nlargest(1, select_criteria)
            if best_small.empty:
                raise ValueError(
                    f"No {parametrization!r} runs at hidden_size={width_select} "
                    "to select a learning rate from"
                )
            opt_lr = best_small["lr"].item()
            df_extrapolated = df_param.query(f"lr == {opt_lr}").copy()
            df_extrapolated["select_method"] = "extrapolated"

            # Strategy 2: tune the LR independently at every width.
            grouped = df_param.groupby("hidden_size")[select_criteria]
            idx = grouped.idxmin() if select_direction == "min" else grouped.idxmax()
            df_not_extrapolated = df_param.loc[idx].copy()
            df_not_extrapolated["select_method"] = "best validation"

            df_concat = pd.concat((df_extrapolated, df_not_extrapolated), axis=0)
            frames.append(
                pd.merge(
                    df_concat,
                    df_best_test,
                    on=["hidden_size"],
                    how="left",
                    suffixes=("_selected", "_best"),
                )
            )

    df_model_selection = pd.concat(frames)
    df_model_selection["relative_test_accuracy"] = (
        df_model_selection[f"{target}_selected"] / df_model_selection[f"{target}_best"]
    )
    return df_model_selection


def _rename_legend_labels(labels: Sequence[str]) -> list[str]:
    """Map raw column/category names in a seaborn legend to display names."""
    display_names = {
        "parametrization": "Parametrization",
        "select_method": "LR Selection Method",
        "extrapolated": "Transferred Grid Search",
        "best validation": "Grid Search",
        "MaximalUpdate": "Maximal Update",
    }
    return [display_names.get(label, label) for label in labels]


def _plot_lr_sweep(ax, data: pd.DataFrame, title: str) -> None:
    """Draw train loss vs. learning rate (log2 x-axis), one line per hidden size."""
    ax.set_title(title)
    sns.lineplot(
        x="lr",
        y="Train Loss",
        hue="hidden_size",
        linestyle="-",
        marker="o",
        data=data,
        ax=ax,
    )
    ax.set_xscale("log", base=2)
    # The per-axes legend is replaced by a shared figure-level legend.
    if ax.get_legend() is not None:
        ax.get_legend().remove()
    ax.set_xlabel("Learning Rate")


def plot(
    dir: str | pathlib.Path = pathlib.Path.cwd() / "../../../../publication/paper/figures/",
    raw_results_file_train: str = "experiment_results_train.csv",
    raw_results_file_test: str = "experiment_results_test.csv",
    raw_results_file_val: str = "experiment_results_val.csv",
    extensions: Sequence[str] = ("pdf",),
):
    """Plot Hyperparameter Transfer.

    Produces a 1x3 figure:

    1. Train loss vs. learning rate per hidden size, Standard parametrization.
    2. The same for the Maximal Update parametrization.
    3. Relative test accuracy vs. hidden size for a per-width grid search
       ("Grid Search") versus transferring the LR found at the smallest width
       ("Transferred Grid Search").

    Args:
        dir: Output directory (name kept for CLI compatibility even though it
            shadows the builtin).
        raw_results_file_train: CSV with training results.
        raw_results_file_test: CSV with test results.
        raw_results_file_val: CSV with validation results.
        extensions: File extensions to save, e.g. ``("pdf", "png")``. A single
            string is accepted too.

    Raises:
        KeyError: if the model-selection column is missing from the results.
        ValueError: if no runs remain after filtering.
    """
    # ``fire`` parses ``--extensions=pdf`` as the plain string "pdf"; iterating
    # that would save ".p", ".d" and ".f" files.
    if isinstance(extensions, str):
        extensions = (extensions,)

    df = _read_and_merge_results(
        raw_results_file_train, raw_results_file_test, raw_results_file_val
    )

    # Filter to the desired hyperparameters.
    df = df[df["model"] == "MLPIVILowRank"]
    df = df[df["max_epochs"] == 20]
    df = df[df["scale_mean_input_init_weight"] == 16.0]  # alternative: 1.0
    df = df[df["bias"].eq(True)]
    df = df[df["momentum"] == 0.0]
    df = df[df["Seed"] == 321587]
    if df.empty:
        raise ValueError("No runs left after filtering to the desired hyperparameters")

    # Column used for model selection (alternative: "Validation NLL").
    select_criteria = "Final Validation NLL/dataloader_idx_2"
    if select_criteria not in df.columns:
        raise KeyError(f"Model-selection column {select_criteria!r} not in results")

    width_select = df["hidden_size"].min().item()  # width to extrapolate from
    select_direction = _selection_direction(select_criteria)
    df_model_selection = _model_selection(
        df, select_criteria, select_direction, width_select
    )

    nrows = 1
    ncols = 3
    with plt.rc_context(
        utils.plotting.style.neurips(
            rel_width=1.0,
            nrows=nrows * 1.4,
            ncols=ncols,
        )
    ):
        fig, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            squeeze=False,
            sharex=False,
        )

        # 1: Standard parametrization LR sweep, with the shared hidden-size legend.
        ax = axs[0][0]
        _plot_lr_sweep(ax, df[df["parametrization"] == "Standard"], "Standard Param. (SP)")
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(
            handles,
            labels,
            title="Hidden Size",
            loc="upper center",
            bbox_to_anchor=(0.25, 0.0),
            fancybox=False,
            shadow=False,
            ncol=3,
            frameon=False,
        )

        # 2: muP LR sweep (y-label shared with panel 1).
        ax = axs[0][1]
        _plot_lr_sweep(
            ax,
            df[df["parametrization"] == "MaximalUpdate"],
            r"Maximal Update Param. ($\mu$P)",
        )
        ax.set_ylabel(None)

        # 3: relative test accuracy of the two LR selection strategies.
        ax = axs[0][2]
        sns.lineplot(
            x="hidden_size",
            y="relative_test_accuracy",
            hue="parametrization",
            style="select_method",
            ax=ax,
            data=df_model_selection,
            marker="o",
            style_order=["best validation", "extrapolated"],
            dashes={"best validation": (1, 0), "extrapolated": (1, 1)},
        )
        ax.set_xscale("log", base=2)
        ax.set_xlabel("Hidden Size")
        ax.set_ylabel("Relative Test Accuracy")

        handles, labels = ax.get_legend_handles_labels()
        fig.legend(
            handles,
            _rename_legend_labels(labels),
            title=None,
            loc="upper center",
            bbox_to_anchor=(0.75, 0.0),
            fancybox=False,
            shadow=False,
            ncol=2,
            frameon=False,
        )
        if ax.get_legend() is not None:
            ax.get_legend().remove()

        out_dir = pathlib.Path(dir)
        for extension in extensions:
            fig.savefig(
                out_dir / f"hyperparameter_transfer.{extension}",
                dpi=400,
                bbox_inches="tight",
                pad_inches=0.085,
            )

        plt.close(fig)


if __name__ == "__main__":
    # Expose ``plot``'s keyword arguments as command-line flags.
    fire.Fire(component=plot)
