"""Plot experiment results."""

from __future__ import annotations

import pathlib
import re
import sys
from typing import Sequence

import fire
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tueplots

# Make the project-local ``utils`` package importable. Both entries are
# relative paths, so which one resolves depends on the directory the script is
# started from (presumably the script's own folder or the repository root --
# TODO confirm).
sys.path.insert(0, "../../../experiments")
sys.path.insert(0, "analysis/experiments")
import utils

# Categories (and their order) of the "Dataset" column once it is cast to an
# ordered categorical. Commented-out entries are currently not part of the
# categories.
DATASETS = (
    "MNIST",
    # "CIFAR10",
    # "CIFAR100",
    # "TinyImageNet",
    # "ImageNet",
)

# Categories (and their order) of the "Param." column. These are the display
# labels that the raw values "Standard", "NeuralTangent" and "MaximalUpdate"
# are replaced with before the column is cast to an ordered categorical.
PARAMETRIZATIONS = (
    "SP",
    "NTP",
    "$\\mu$P",
)


# Run-configuration columns read from both the test and the training results.
_CONFIG_COLUMNS = (
    "Seed",
    "dataset",
    "OOD Dataset",
    "model",
    "architecture",
    "inference_method",
    "parametrization",
    "num_samples_train",
    "optimizer",
    "max_epochs",
    "batch_size",
    "lr",
    "momentum",
    "nesterov",
)

# Display names for the run-configuration columns that get renamed.
_CONFIG_RENAMES = {
    "dataset": "Dataset",
    "architecture": "Architecture",
    "inference_method": "Method",
    "parametrization": "Param.",
    "model": "Model",
    "num_samples_train": "Parameter Samples",
    "momentum": "Momentum",
    "batch_size": "Batch Size",
    "lr": "Learning Rate",
}


def _optimizer_label(df_row) -> str:
    """Return the display name of the optimizer used in one result row.

    Rows with zero momentum and rows using Adam keep the raw optimizer name;
    all other rows get a suffix naming the momentum variant.
    """
    name = df_row["optimizer"]
    if (df_row["Momentum"] == 0) or (name == "Adam"):
        return name
    if df_row["nesterov"]:
        return name + " + Nesterov Momentum"
    return name + " + Heavy Ball Momentum"


def _load_results(
    csv_file,
    metric_renames: dict[str, str],
    accuracy_column: str,
    error_column: str,
    batch_size: int,
    integer_parameter_samples: bool = False,
) -> pd.DataFrame:
    """Read one results CSV and bring it into the shape the plots expect.

    Args:
        csv_file: Path of the CSV file to read.
        metric_renames: Maps the raw names of the file-specific columns to
            their display names. These columns are kept in addition to the
            shared run-configuration columns.
        accuracy_column: Display name of the accuracy column.
        error_column: Name of the column to create as ``1 - accuracy``.
        batch_size: Only rows with this batch size are kept.
        integer_parameter_samples: Whether to cast "Parameter Samples" to int.

    Returns:
        A new data frame restricted to rows without Nesterov momentum, with the
        requested batch size and a non-missing optimizer.

    Raises:
        KeyError: If an expected column is missing from the file.
        ValueError: If no row is left after filtering.
    """
    raw = pd.read_csv(csv_file)
    df = raw[[*_CONFIG_COLUMNS, *metric_renames]]

    keep = (
        df["nesterov"].eq(False)
        & df["batch_size"].eq(batch_size)
        & df["optimizer"].notna()
    )
    # Work on an explicit copy so that the column assignments below never
    # write into a slice of the raw frame.
    df = df[keep].copy()
    df = df.rename(columns={**_CONFIG_RENAMES, **metric_renames})

    if df.empty:
        raise ValueError(
            f"No results without Nesterov momentum, with batch_size={batch_size} "
            f"and a known optimizer found in '{csv_file}'."
        )

    df["Optimizer"] = df[["optimizer", "Momentum", "nesterov"]].apply(
        _optimizer_label, axis=1
    )

    df["Dataset"] = df["Dataset"].astype(
        pd.CategoricalDtype(categories=DATASETS, ordered=True)
    )
    df["Method"] = (
        df["Method"]
        .fillna("Standard")
        .astype(
            pd.CategoricalDtype(
                categories=list(utils.plotting.colors.INFERENCE_METHODS.keys()),
                ordered=True,
            )
        )
    )
    df["Param."] = (
        df["Param."]
        .replace(
            {
                "Standard": "SP",
                "NeuralTangent": "NTP",
                "MaximalUpdate": "$\\mu$P",
            }
        )
        .astype(pd.CategoricalDtype(categories=PARAMETRIZATIONS, ordered=True))
    )
    if integer_parameter_samples:
        df["Parameter Samples"] = df["Parameter Samples"].astype(int)
    df[error_column] = 1.0 - df[accuracy_column]

    return df


def plot(
    dir: str | pathlib.Path = pathlib.Path.cwd()
    / "../../../../publication/paper/figures/parameter_sample_size/",
    raw_results_file: str = "experiment_results.csv",
    extensions: Sequence[str] = ("pdf",),
    batch_size: int = 1024,
    lr: float = 0.003,
):
    """Plot experiment results.

    Reads the test results from ``raw_results_file`` and the training curves
    from the sibling file ``<stem>_training<suffix>``, and writes all figures
    to ``dir``.

    Args:
        dir: Output directory for the figures; created if it does not exist.
        raw_results_file: CSV file containing the test results.
        extensions: File extensions (without dot) to save each figure with. A
            single string is treated as one extension.
        batch_size: Only runs with this batch size are plotted.
        lr: Learning rate used by the plots that fix the learning rate.
    """
    # A single extension passed on the command line arrives as a plain string;
    # iterating it would yield its characters.
    if isinstance(extensions, str):
        extensions = (extensions,)

    # Create output directory
    dir = pathlib.Path(dir)
    dir.mkdir(parents=True, exist_ok=True)

    ## ---- Test results ---- ##
    df = _load_results(
        raw_results_file,
        metric_renames={
            "Test Accuracy/dataloader_idx_0": "Test Accuracy $\\uparrow$",
            "Test NLL/dataloader_idx_0": "Test NLL $\\downarrow$",
            "Test ECE/dataloader_idx_0": "Test ECE $\\downarrow$",
            "Test AUROC/dataloader_idx_1": "Test AUROC $\\uparrow$",
        },
        accuracy_column="Test Accuracy $\\uparrow$",
        error_column="Test Error $\\downarrow$",
        batch_size=batch_size,
    )
    metrics = [
        "Test Error $\\downarrow$",
        "Test NLL $\\downarrow$",
        "Test ECE $\\downarrow$",
        # "Test AUROC $\\uparrow$",
    ]

    ## ---- Learning rate ---- ##
    plot_learning_rate(
        df=df,
        metrics=metrics,
        optimizers=df["Optimizer"].unique(),
        dir=dir,
        extensions=extensions,
    )
    plot_learning_rate_compact(
        df=df,
        dir=dir,
        extensions=extensions,
        lr=lr,
    )

    ## ---- Optimizer choice ---- ##
    plot_optimizer_choice(
        df=df,
        metrics=metrics,
        dir=dir,
        extensions=extensions,
        lr=lr,
    )

    ## ---- Learning Rate vs Parameter Samples ---- ##
    plot_learning_rate_parameter_samples(
        df=df,
        dir=dir,
        extensions=extensions,
    )

    ## ---- Training curves ---- ##
    results_path = pathlib.Path(raw_results_file)
    df = _load_results(
        results_path.with_name(f"{results_path.stem}_training{results_path.suffix}"),
        metric_renames={
            "trainer/global_step": "Optimizer Step",
            "epoch": "Epoch",
            "Validation Accuracy": "Valid. Accuracy $\\uparrow$",
            "Validation NLL": "Valid. NLL $\\downarrow$",
            "Validation ECE": "Valid. ECE $\\downarrow$",
        },
        accuracy_column="Valid. Accuracy $\\uparrow$",
        error_column="Valid. Error $\\downarrow$",
        batch_size=batch_size,
        integer_parameter_samples=True,
    )
    metrics = [
        "Valid. Error $\\downarrow$",
        "Valid. NLL $\\downarrow$",
        "Valid. ECE $\\downarrow$",
    ]

    plot_optimizer_steps_for_parameter_samples(
        df=df,
        metrics=metrics,
        dir=dir,
        extensions=extensions,
        lr=lr,
    )

    plot_optimizer_steps_for_learning_rates(
        df=df,
        metrics=metrics,
        dir=dir,
        extensions=extensions,
    )


def plot_learning_rate(
    df: pd.DataFrame,
    metrics: list[str],
    optimizers: list[str],
    dir: str,
    extensions: Sequence[str],
) -> None:
    """Plot each metric against the number of parameter samples per learning rate.

    One figure is saved for every combination of parametrization and
    optimizer, with one column of axes per metric and one line per learning
    rate.

    Args:
        df: Preprocessed test results.
        metrics: Names of the metric columns, one per column of axes.
        optimizers: Display names of the optimizers to plot.
        dir: Directory the figures are saved to.
        extensions: File extensions (without dot) to save each figure with.

    Raises:
        ValueError: If an optimizer has no colormap assigned.
    """
    nrows = 1
    ncols = len(metrics)

    # Sequential colormap used for the learning-rate hue of each optimizer.
    colormaps = {
        "SGD": "Blues",
        "SGD + Heavy Ball Momentum": "Oranges",
        "SGD + Nesterov Momentum": "Greens",
    }
    # Y-axis configuration of the first three columns of axes.
    y_axis_settings = (
        {"ylim": (0.0, 0.1)},
        {"ylim": (0.0, 0.4)},
        {"yscale": "log"},
    )

    # Rows with a missing parametrization can never match the equality filter
    # below, so they are skipped instead of failing when the file name is built.
    for param in df["Param."].dropna().unique():
        param_slug = re.sub(r"[^a-zA-Z0-9]", "", param)

        for optimizer in optimizers:
            # Filter the dataframe
            df_plot = df[(df["Optimizer"] == optimizer) & (df["Param."] == param)]

            if optimizer not in colormaps:
                raise ValueError(f"Unknown optimizer: {optimizer}")
            palette = sns.color_palette(
                colormaps[optimizer], n_colors=df_plot["Learning Rate"].nunique()
            )
            optimizer_slug = re.sub(r"[^a-zA-Z0-9]", "", optimizer)

            with plt.rc_context(
                utils.plotting.style.neurips(
                    rel_width=1.0,
                    nrows=nrows * 1.25,
                    ncols=ncols,
                )
            ):
                fig, axs = plt.subplots(
                    nrows=nrows,
                    ncols=ncols,
                    squeeze=False,
                    sharex=True,
                )

                # Configure only the columns that exist, so fewer than three
                # metrics no longer raise an IndexError. The log scale is set
                # before drawing, as in the original statement order.
                # NOTE(review): seaborn appears to aggregate in log space on
                # axes that are already log-scaled -- confirm before moving
                # this below the plotting loop.
                for ax, settings in zip(axs[0], y_axis_settings):
                    ax.set(**settings)

                for idx_metric, metric in enumerate(metrics):
                    sns.lineplot(
                        df_plot,
                        x="Parameter Samples",
                        y=metric,
                        hue="Learning Rate",
                        palette=palette,
                        marker="o",
                        markersize=3,
                        ax=axs[0, idx_metric],
                        # Only the first axes provides the legend entries.
                        legend=(idx_metric == 0),
                    )

                    # Labels and ticks
                    axs[0, idx_metric].set_xscale("log", base=10)

                # Move the legend of the first axes to a shared figure legend
                handles, labels = axs[0, 0].get_legend_handles_labels()
                fig.legend(
                    handles,
                    labels,
                    title="Learning Rate",
                    loc="upper center",
                    bbox_to_anchor=(0.5, 0.1),
                    fancybox=False,
                    shadow=False,
                    ncol=df["Learning Rate"].nunique(),
                    frameon=False,
                )
                if axs[0, 0].get_legend() is not None:
                    axs[0, 0].get_legend().remove()

                fig.align_labels()
                fig.tight_layout()

                for extension in extensions:
                    fig.savefig(
                        pathlib.Path(dir)
                        / f"parameter_sample_size_learning_rate_{optimizer_slug}_{param_slug}.{extension}",
                        dpi=400,
                        bbox_inches="tight",
                        pad_inches=0.085,
                    )

                plt.close(fig)


def plot_learning_rate_compact(
    df: pd.DataFrame,
    dir: str,
    extensions: Sequence[str],
    lr: float,
    optimizers: Sequence[str] = ("SGD", "SGD + Heavy Ball Momentum"),
):
    """Plot the test error against the number of parameter samples, compactly.

    One figure is saved per parametrization. It contains one row of axes per
    optimizer and one line per learning rate.

    Args:
        df: Preprocessed test results.
        dir: Directory the figures are saved to.
        extensions: File extensions (without dot) to save each figure with.
        lr: Not used by this function; kept so existing callers keep working.
        optimizers: Display names of the optimizers to plot, one per row.

    Raises:
        ValueError: If an optimizer has no colormap assigned.
    """
    optimizers = tuple(optimizers)

    # One row per optimizer, so the grid always matches the axes indexed
    # below. At least one row is kept so that an empty selection still yields a
    # (blank) figure instead of an error.
    nrows = max(len(optimizers), 1)
    ncols = 1

    # Sequential colormap used for the learning-rate hue of each optimizer.
    colormaps = {
        "SGD": "Blues",
        "SGD + Heavy Ball Momentum": "Oranges",
        "SGD + Nesterov Momentum": "Greens",
    }

    # Rows with a missing parametrization can never match the equality filter
    # below, so they are skipped instead of failing when the file name is built.
    for param in df["Param."].dropna().unique():
        param_slug = re.sub(r"[^a-zA-Z0-9]", "", param)

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=0.35,
                nrows=nrows * 1.1,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                squeeze=False,
                sharex="col",
                sharey="col",
            )

            for idx_optimizer, optimizer in enumerate(optimizers):
                # Filter the dataframe
                df_plot = df[(df["Optimizer"] == optimizer) & (df["Param."] == param)]

                if optimizer not in colormaps:
                    raise ValueError(f"Unknown optimizer: {optimizer}")
                palette = sns.color_palette(
                    colormaps[optimizer], n_colors=df_plot["Learning Rate"].nunique()
                )

                ax = axs[idx_optimizer, 0]
                sns.lineplot(
                    df_plot,
                    x="Parameter Samples",
                    y="Test Error $\\downarrow$",
                    hue="Learning Rate",
                    palette=palette,
                    marker="o",
                    markersize=3,
                    ax=ax,
                    legend=False,
                )

                ax.set(
                    xscale="log",
                    ylim=(0.0, 0.1),
                    title="SGD + Momentum" if optimizer != "SGD" else "SGD",
                )

            fig.align_labels()
            fig.tight_layout()

            for extension in extensions:
                fig.savefig(
                    pathlib.Path(dir)
                    / f"parameter_sample_size_learning_rate_compact_{param_slug}.{extension}",
                    dpi=400,
                    bbox_inches="tight",
                    pad_inches=0.085,
                )

            plt.close(fig)


def plot_optimizer_choice(
    df: pd.DataFrame,
    metrics: list[str],
    dir: str,
    extensions: Sequence[str],
    lr: float,
) -> None:
    """Compare optimizers at a fixed learning rate.

    One figure is saved per parametrization, with one column of axes per
    metric and one line per optimizer, plotted against the number of parameter
    samples.

    Args:
        df: Preprocessed test results.
        metrics: Names of the metric columns, one per column of axes.
        dir: Directory the figures are saved to.
        extensions: File extensions (without dot) to save each figure with.
        lr: Learning rate to select; rows are matched by exact equality.
    """
    nrows = 1
    ncols = len(metrics)

    # Y-axis limits of the first three columns of axes.
    y_limits = ((0.0, 0.075), (0.0, 0.15), (0.0, 0.05))

    # Rows with a missing parametrization can never match the equality filter
    # below, so they are skipped instead of failing when the file name is built.
    for param in df["Param."].dropna().unique():
        # Filter the dataframe
        df_plot = df[(df["Learning Rate"] == lr) & (df["Param."] == param)]
        param_slug = re.sub(r"[^a-zA-Z0-9]", "", param)

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=1.0,
                nrows=nrows * 1.5,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                squeeze=False,
                sharex=True,
            )

            for idx_metric, metric in enumerate(metrics):
                sns.lineplot(
                    df_plot,
                    x="Parameter Samples",
                    y=metric,
                    hue="Optimizer",
                    marker="o",
                    markersize=3,
                    ax=axs[0, idx_metric],
                    # Only the first axes provides the legend entries.
                    legend=(idx_metric == 0),
                )

                # Labels and ticks
                axs[0, idx_metric].set_xscale("log", base=10)

            # Apply the limits only to the columns that exist, so fewer than
            # three metrics no longer raise an IndexError.
            for ax, ylim in zip(axs[0], y_limits):
                ax.set(ylim=ylim)

            # Move the legend of the first axes to a shared figure legend
            handles, labels = axs[0, 0].get_legend_handles_labels()
            fig.legend(
                handles,
                labels,
                title="Optimizer",
                loc="upper center",
                bbox_to_anchor=(0.5, 0.0),
                fancybox=False,
                shadow=False,
                ncol=4,
                frameon=False,
            )
            if axs[0, 0].get_legend() is not None:
                axs[0, 0].get_legend().remove()

            fig.align_labels()
            fig.tight_layout()

            for extension in extensions:
                fig.savefig(
                    pathlib.Path(dir)
                    / f"parameter_sample_size_optimizer_choice_{param_slug}.{extension}",
                    dpi=400,
                    bbox_inches="tight",
                    pad_inches=0.085,
                )

            plt.close(fig)


def plot_learning_rate_parameter_samples(
    df: pd.DataFrame,
    dir: str,
    extensions: Sequence[str],
) -> None:
    """Plot the best number of parameter samples against the learning rate.

    For plain SGD, the row with the lowest test NLL is selected for every
    combination of batch size, learning rate and seed, and its number of
    parameter samples is plotted. One figure is saved per parametrization.

    Args:
        df: Preprocessed test results.
        dir: Directory the figures are saved to.
        extensions: File extensions (without dot) to save each figure with.
    """
    nrows = 1
    ncols = 1
    metric = "Test NLL $\\downarrow$"

    # Rows with a missing parametrization can never match the equality filter
    # below, so they are skipped instead of failing when the file name is built.
    for param in df["Param."].dropna().unique():
        # Filter into a separate variable: overwriting ``df`` here would leave
        # every later parametrization with an empty frame.
        df_param = df[df["Param."] == param]
        param_slug = re.sub(r"[^a-zA-Z0-9]", "", param)

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=0.5,
                nrows=nrows,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                squeeze=False,
                sharex=True,
            )

            # Find optimal parameter sample size for each learning rate.
            # Rows without an NLL are dropped first so that every group has a
            # well-defined best row.
            df_sgd = df_param[df_param["Optimizer"] == "SGD"].dropna(subset=[metric])
            # The NLL is a lower-is-better metric, hence the minimum.
            best_rows = df_sgd.groupby(["Batch Size", "Learning Rate", "Seed"])[
                metric
            ].idxmin()
            df_plot = df_sgd.loc[best_rows].reset_index()

            sns.lineplot(
                df_plot,
                x="Learning Rate",
                y="Parameter Samples",
                hue="Batch Size",
                palette=sns.color_palette(),
                marker="o",
                markersize=3,
                ax=axs[0, 0],
                legend=True,
            )

            # Move the legend of the axes to a shared figure legend
            handles, labels = axs[0, 0].get_legend_handles_labels()
            fig.legend(
                handles,
                labels,
                title="Batch Size",
                loc="upper center",
                bbox_to_anchor=(0.5, 0.0),
                fancybox=False,
                shadow=False,
                ncol=4,
                frameon=False,
            )
            if axs[0, 0].get_legend() is not None:
                axs[0, 0].get_legend().remove()

            fig.align_labels()
            fig.tight_layout()

            for extension in extensions:
                fig.savefig(
                    pathlib.Path(dir)
                    / f"learning_rate_optimal_sample_size_{param_slug}.{extension}",
                    dpi=400,
                    bbox_inches="tight",
                    pad_inches=0.085,
                )

            plt.close(fig)


def plot_optimizer_steps_for_parameter_samples(
    df: pd.DataFrame,
    metrics: list[str],
    dir: str,
    extensions: Sequence[str],
    lr: float,
) -> None:
    """Plot validation metrics over optimizer steps per number of parameter samples.

    One figure is saved per parametrization, with one column of axes per
    metric. Each optimizer is drawn in its own colormap, with one line per
    number of parameter samples.

    Args:
        df: Preprocessed training results.
        metrics: Names of the metric columns, one per column of axes.
        dir: Directory the figures are saved to.
        extensions: File extensions (without dot) to save each figure with.
        lr: Learning rate to select; rows are matched by exact equality.
    """
    nrows = 1
    ncols = len(metrics)

    n_colors = df["Parameter Samples"].nunique()
    # (optimizer, palette, zorder) of every group of lines drawn per axes.
    line_groups = (
        ("SGD", sns.color_palette("Blues", n_colors=n_colors), 1),
        (
            "SGD + Heavy Ball Momentum",
            sns.color_palette("Oranges", n_colors=n_colors),
            -1,
        ),
        (
            "SGD + Nesterov Momentum",
            sns.color_palette("Greens", n_colors=n_colors),
            -2,
        ),
    )
    # Y-axis limits of the first three columns of axes.
    y_limits = ((0.0, 0.5), (0.0, 3.5), (0.0, 0.6))

    # Rows with a missing parametrization can never match the equality filter
    # below, so they are skipped instead of failing when the file name is built.
    for param in df["Param."].dropna().unique():
        # Filter the dataframe
        df_plot = df[(df["Learning Rate"] == lr) & (df["Param."] == param)]
        param_slug = re.sub(r"[^a-zA-Z0-9]", "", param)

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=1.0,
                nrows=nrows * 1.3,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                squeeze=False,
                sharex=True,
            )

            for idx_metric, metric in enumerate(metrics):
                for optimizer, palette, zorder in line_groups:
                    sns.lineplot(
                        df_plot[df_plot["Optimizer"] == optimizer],
                        x="Optimizer Step",
                        y=metric,
                        hue="Parameter Samples",
                        palette=palette,
                        ax=axs[0, idx_metric],
                        zorder=zorder,
                        # Only the first axes provides the legend entries.
                        legend=(idx_metric == 0),
                    )

                # Labels and ticks
                axs[0, idx_metric].set_xscale("log", base=10)

            # Apply the limits only to the columns that exist, so fewer than
            # three metrics no longer raise an IndexError.
            for ax, ylim in zip(axs[0], y_limits):
                ax.set(ylim=ylim)

            # Move the legend of the first axes to a shared figure legend
            handles, labels = axs[0, 0].get_legend_handles_labels()
            fig.legend(
                handles,
                labels,
                title="Parameter Samples",
                loc="upper center",
                bbox_to_anchor=(0.5, 0.05),
                fancybox=False,
                shadow=False,
                ncol=4 * df["Optimizer"].nunique(),
                frameon=False,
            )
            if axs[0, 0].get_legend() is not None:
                axs[0, 0].get_legend().remove()

            fig.align_labels()
            fig.tight_layout()

            for extension in extensions:
                fig.savefig(
                    pathlib.Path(dir)
                    / f"parameter_sample_size_optimizer_step_{param_slug}.{extension}",
                    dpi=400,
                    bbox_inches="tight",
                    pad_inches=0.01,
                )

            plt.close(fig)


def plot_optimizer_steps_for_learning_rates(
    df: pd.DataFrame,
    metrics: list[str],
    dir: str,
    extensions: Sequence[str],
    parameter_sample_size: int = 8,
) -> None:
    """Plot validation metrics over optimizer steps per learning rate.

    One figure is saved per parametrization, with one column of axes per
    metric. SGD and SGD with heavy ball momentum are drawn in separate
    colormaps, with one line per learning rate.

    Args:
        df: Preprocessed training results.
        metrics: Names of the metric columns, one per column of axes.
        dir: Directory the figures are saved to.
        extensions: File extensions (without dot) to save each figure with.
        parameter_sample_size: Number of parameter samples to select.
    """
    nrows = 1
    ncols = len(metrics)

    n_colors = df["Learning Rate"].nunique()
    # (optimizer, palette, zorder) of every group of lines drawn per axes.
    line_groups = (
        ("SGD", sns.color_palette("Blues", n_colors=n_colors), 1),
        (
            "SGD + Heavy Ball Momentum",
            sns.color_palette("Oranges", n_colors=n_colors),
            -1,
        ),
    )
    # Y-axis limits of the first three columns of axes.
    y_limits = ((0.0, 0.3), (0.0, 3.0), (0.0, 0.6))

    # Rows with a missing parametrization can never match the equality filter
    # below, so they are skipped instead of failing when the file name is built.
    for param in df["Param."].dropna().unique():
        # Filter the dataframe
        df_plot = df[
            (df["Parameter Samples"] == parameter_sample_size) & (df["Param."] == param)
        ]
        param_slug = re.sub(r"[^a-zA-Z0-9]", "", param)

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=1.0,
                nrows=nrows * 1.3,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                squeeze=False,
                sharex=True,
            )

            for idx_metric, metric in enumerate(metrics):
                for optimizer, palette, zorder in line_groups:
                    sns.lineplot(
                        df_plot[df_plot["Optimizer"] == optimizer],
                        x="Optimizer Step",
                        y=metric,
                        hue="Learning Rate",
                        palette=palette,
                        ax=axs[0, idx_metric],
                        zorder=zorder,
                        # Only the first axes provides the legend entries.
                        legend=(idx_metric == 0),
                    )

                # Labels and ticks
                axs[0, idx_metric].set_xscale("log", base=10)

            # Apply the limits only to the columns that exist, so fewer than
            # three metrics no longer raise an IndexError.
            for ax, ylim in zip(axs[0], y_limits):
                ax.set(ylim=ylim)

            # Move the legend of the first axes to a shared figure legend
            handles, labels = axs[0, 0].get_legend_handles_labels()
            fig.legend(
                handles,
                labels,
                title="Learning Rate",
                loc="upper center",
                bbox_to_anchor=(0.5, 0.05),
                fancybox=False,
                shadow=False,
                ncol=df["Optimizer"].nunique(),
                frameon=False,
            )
            if axs[0, 0].get_legend() is not None:
                axs[0, 0].get_legend().remove()

            fig.align_labels()
            fig.tight_layout()

            for extension in extensions:
                fig.savefig(
                    pathlib.Path(dir)
                    / f"learning_rates_optimizer_step_{param_slug}.{extension}",
                    dpi=400,
                    bbox_inches="tight",
                    pad_inches=0.01,
                )

            plt.close(fig)


if __name__ == "__main__":
    # Expose ``plot`` as a command-line interface when the file is run as a
    # script.
    fire.Fire(plot)
