from __future__ import annotations

import pathlib
import sys
from typing import Sequence

import fire
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import seaborn as sns
import tueplots
from inferno import datasets

sys.path.insert(0, "../../../experiments")
sys.path.insert(0, "analysis/experiments")
import utils

# Categories used for the ordered "Dataset" categorical built in
# `plot_robustness_illustration`.
# Currently disabled: "MNIST", "CIFAR10", "CIFAR100".
DATASETS = ("TinyImageNet",)

# Categories (and seaborn hue order) for the "Method" column.
# Currently disabled: "Temperature Scaling", "Laplace (Last-layer, GS)",
# "Laplace (Last-layer, ML)", "Weight-space VI (Mean-field)", "Ensemble".
INFERENCE_METHODS = ("Standard", "Implicit Bias VI (Low-rank)")

# Display names of the parametrizations, used as categories for the "Param."
# column.
PARAMETRIZATIONS = ("SP", "NTP", "$\\mu$P")


def _load_results(raw_results_file: str | pathlib.Path) -> pd.DataFrame:
    """Read the raw experiment results and prepare them for plotting.

    Selects the columns of interest, gives them display names, converts the
    dataset / method / parametrization columns to ordered categoricals and adds
    the derived error columns.

    Args:
        raw_results_file: CSV file containing the raw experiment results.

    Returns:
        The cleaned results with display column names.

    Raises:
        KeyError: If an expected column is missing from the CSV file.
    """
    # Raw column name -> display name, in the order the columns are selected.
    column_names = {
        "dataset": "Dataset",
        "OOD Dataset": "OOD Dataset",
        "model": "Model",
        "architecture": "Architecture",
        "inference_method": "Method",
        "parametrization": "Param.",
        "num_parameters_and_buffers": "Num. Parameters",
        "Runtime (s)": "Training Runtime",
        "Test Accuracy/dataloader_idx_0": "Test Accuracy",
        "Test Top-5 Accuracy/dataloader_idx_0": "Test Top-5 Accuracy",
        "Test NLL/dataloader_idx_0": "Test NLL",
        "Test ECE/dataloader_idx_0": "Test ECE",
        "Test Norm. Entropy/dataloader_idx_0": "Test Norm. Entropy",
        "Test Accuracy (OOD)/dataloader_idx_1": "Test Accuracy (OOD)",
        "Test Top-5 Accuracy (OOD)/dataloader_idx_1": "Test Top-5 Accuracy (OOD)",
        "Test NLL (OOD)/dataloader_idx_1": "Test NLL (OOD)",
        "Test ECE (OOD)/dataloader_idx_1": "Test ECE (OOD)",
        "Test Norm. Entropy (OOD)/dataloader_idx_1": "Test Norm. Entropy (OOD)",
        "Test AUROC/dataloader_idx_1": "Test AUROC",
    }

    # A non-inplace `rename` returns an independent frame, so the column
    # assignments below do not write into a slice of the raw data (which would
    # raise `SettingWithCopyWarning` on pandas < 3).
    df = pd.read_csv(raw_results_file)[list(column_names)].rename(
        columns=column_names
    )

    df["Dataset"] = df["Dataset"].astype(
        pd.CategoricalDtype(categories=DATASETS, ordered=True)
    )

    # Runs without a recorded inference method are treated as "Standard".
    # Values that are not in INFERENCE_METHODS become NaN in the categorical.
    df["Method"] = (
        df["Method"]
        .fillna("Standard")
        .replace(
            {
                "Implicit VI (Low-rank)": "Implicit Bias VI (Low-rank)",
                "Implicit VI (Kronecker)": "Implicit Bias VI (Kronecker)",
            },
        )
        .astype(pd.CategoricalDtype(categories=INFERENCE_METHODS, ordered=True))
    )

    df["Param."] = (
        df["Param."]
        .replace(
            {
                "Standard": "SP",
                "NeuralTangent": "NTP",
                "MaximalUpdate": "$\\mu$P",
            },
        )
        .astype(pd.CategoricalDtype(categories=PARAMETRIZATIONS, ordered=True))
    )

    # The raw column is named "Runtime (s)"; scale it by 1000 (i.e. to
    # milliseconds if the raw unit is indeed seconds).
    df["Training Runtime"] = df["Training Runtime"].astype(float) * 1000.0

    df["Test Error"] = 1.0 - df["Test Accuracy"]
    df["Test Top-5 Error"] = 1.0 - df["Test Top-5 Accuracy"]
    df["Test Error (OOD)"] = 1.0 - df["Test Accuracy (OOD)"]
    df["Test Top-5 Error (OOD)"] = 1.0 - df["Test Top-5 Accuracy (OOD)"]

    # Mark the direction of improvement in the in-distribution metric names.
    return df.rename(
        columns={
            "Test Accuracy": "Test Accuracy $\\uparrow$",
            "Test Top-5 Accuracy": "Test Top-5 Accuracy $\\uparrow$",
            "Test Error": "Test Error $\\downarrow$",
            "Test Top-5 Error": "Test Top-5 Error $\\downarrow$",
            "Test NLL": "Test NLL $\\downarrow$",
            "Test ECE": "Test ECE $\\downarrow$",
        },
    )


def _plot_image_grid(ax, images, title: str) -> None:
    """Show up to four images as a 2x2 grid of inset axes inside ``ax``.

    Images fill the grid row by row, starting in the lower-left quadrant.
    """
    for position, image in enumerate(images):
        row, col = divmod(position, 2)
        ax_inset = ax.inset_axes([col * 0.5, row * 0.5, 0.5, 0.5])
        ax_inset.imshow(np.asarray(image), interpolation="nearest")
        ax_inset.set_xticks([])
        ax_inset.set_yticks([])
        ax_inset.set_aspect("equal")
    ax.axis("off")
    ax.set(title=title)


def _plot_method_means(ax, data: pd.DataFrame, y: str, legend: bool) -> None:
    """Draw a per-method point plot of column ``y`` (with error bars) on ``ax``."""
    sns.pointplot(
        data,
        x="Method",
        y=y,
        hue="Method",
        hue_order=INFERENCE_METHODS,
        palette=utils.plotting.colors.INFERENCE_METHODS,
        err_kws={
            "linewidth": 1.5,
            "alpha": 0.3,
        },
        marker="_",
        markersize=12,
        ax=ax,
        legend=legend,
    )


def plot_robustness_illustration(
    dir: str | pathlib.Path = (
        pathlib.Path.cwd() / "../../../../publication/paper/figures/"
    ),
    raw_results_file: str | pathlib.Path = "experiment_results.csv",
    dataset_root_dir: str | pathlib.Path = pathlib.Path("../../datasets"),
    extensions: Sequence[str] | str = ("pdf",),
) -> None:
    """Create the robustness illustration figure for TinyImageNet.

    The figure is a single row of six panels: example in-distribution images,
    test NLL on in-distribution data, test NLL on the OOD (corrupted) data,
    example corrupted images, number of parameters and training time. Only
    results for the "SP" parametrization are shown.

    Args:
        dir: Directory the figure is saved to.
        raw_results_file: CSV file containing the raw experiment results.
        dataset_root_dir: Root directory passed to the dataset constructors.
        extensions: File extension(s) to save the figure with. A single string
            is accepted as well, since a command-line value such as
            ``--extensions=pdf`` arrives as a plain string.
    """
    # Guard against iterating over the characters of a single extension.
    if isinstance(extensions, str):
        extensions = (extensions,)

    # Experiment results, restricted to the standard parametrization
    df = _load_results(raw_results_file)
    df = df[df["Param."] == "SP"]
    df_tinyimagenet = df[df["Dataset"] == "TinyImageNet"]
    df_methods = df_tinyimagenet[
        df_tinyimagenet["Method"].isin(["Standard", "Implicit Bias VI (Low-rank)"])
    ]

    # Datasets providing the example images
    dataset = datasets.TinyImageNet(root=dataset_root_dir, train=False)
    dataset_corrupted = datasets.TinyImageNetC(root=dataset_root_dir)

    # Plot
    nrows = 1
    ncols = 6
    with plt.rc_context(
        utils.plotting.style.neurips(
            rel_width=1.0,
            # NOTE(review): fractional row count, presumably to obtain a taller
            # figure than the style's default for a single row -- confirm.
            nrows=nrows * 1.9,
            ncols=ncols,
        )
    ):
        fig, axs = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            squeeze=False,
            width_ratios=(1, 1, 1, 1, 0.5, 0.5),
        )
        # Close the figure even if plotting or saving fails.
        try:
            (
                ax_id_images,
                ax_id,
                ax_ood,
                ax_ood_images,
                ax_num_params,
                ax_runtime,
            ) = axs[0]

            # In-distribution images (hand-picked test set indices)
            _plot_image_grid(
                ax_id_images,
                (dataset[index][0] for index in (220, 1138, 1509, 1721)),
                title="ID Images",
            )

            # In-distribution performance
            _plot_method_means(
                ax_id, df_methods, y="Test NLL $\\downarrow$", legend=True
            )
            ax_id.set(
                xlabel=None,
                ylabel="Test Loss (CE) $\\downarrow$",
                title="ID",
                ylim=(1.5, 7.5),
            )
            ax_id.set_xticklabels([])

            # Performance on corrupted images (shares the y-axis with the
            # in-distribution panel so both are directly comparable)
            _plot_method_means(ax_ood, df_methods, y="Test NLL (OOD)", legend=False)
            ax_ood.set(
                xlabel=None,
                ylabel=None,
                title="OOD",
            )
            ax_ood.sharey(ax_id)
            ax_ood.set_xticklabels([])
            ax_ood.tick_params(labelleft=False)

            # Corrupted images (hand-picked indices, 10000 apart)
            _plot_image_grid(
                ax_ood_images,
                (dataset_corrupted[7954 + i * 10000][0] for i in range(4)),
                title="OOD Images",
            )

            # Number of parameters
            sns.barplot(
                df_tinyimagenet,
                x="Method",
                y="Num. Parameters",
                hue="Method",
                hue_order=INFERENCE_METHODS,
                palette=utils.plotting.colors.INFERENCE_METHODS,
                ax=ax_num_params,
                legend=False,
                errorbar=None,
                dodge=False,
            )
            ax_num_params.set(xlabel=None, ylabel="\\# Parameters")
            ax_num_params.set_xticklabels([])

            # Training time
            sns.barplot(
                df_tinyimagenet,
                x="Dataset",
                y="Training Runtime",
                hue="Method",
                hue_order=INFERENCE_METHODS,
                palette=utils.plotting.colors.INFERENCE_METHODS,
                ax=ax_runtime,
                legend=False,
                dodge="auto",
                width=1.0,
                gap=0.2,
            )
            ax_runtime.yaxis.set_major_formatter(
                utils.plotting.tick_formatters.TimeFormatter()
            )
            # Candidate ticks every 3,600,000 units (one hour if the runtime is
            # in milliseconds), thinned out by `nbins`.
            ax_runtime.yaxis.set_major_locator(
                mpl.ticker.FixedLocator(locs=np.arange(5) * 3600000, nbins=4)
            )
            ax_runtime.set_xticklabels([])
            ax_runtime.set(xlabel=None, ylabel="Training Time")

            for ax in axs[0, :]:
                ax.tick_params(bottom=False)

            # Legend: reuse the handles of the in-distribution panel, but with
            # publication labels, and place a single legend below the figure.
            handles, _ = ax_id.get_legend_handles_labels()
            fig.legend(
                handles,
                ["ResNet", "ResNet + Implicit Bias VI (Ours)"],
                loc="upper center",
                bbox_to_anchor=(0.5, 0.05),
                fancybox=False,
                shadow=False,
                ncol=4,
                frameon=False,
            )
            ax_id.get_legend().remove()

            for extension in extensions:
                fig.savefig(
                    pathlib.Path(dir) / f"robustness_illustration.{extension}",
                    dpi=300,
                    bbox_inches="tight",
                    pad_inches=0.01,
                )
        finally:
            plt.close(fig)


if __name__ == "__main__":
    # Expose the plotting function as a command-line interface via Python Fire,
    # so that its keyword arguments can be supplied as command-line flags.
    fire.Fire(plot_robustness_illustration)
