"""Plot experiment results."""

from __future__ import annotations

import pathlib
import re
import sys
from typing import Sequence

import fire
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tueplots

sys.path.insert(0, "../../../experiments")
sys.path.insert(0, "analysis/experiments")
import utils

# Datasets to plot. The order fixes both the ordered categories of the
# "Dataset" column and the left-to-right column order of the figure grids.
DATASETS = (
    "MNIST",
    "CIFAR10",
    "CIFAR100",
    "TinyImageNet",
    # "ImageNet",
)
# Architectures are currently not used by any plot in this file.
# ARCHITECTURES = (
#     "MLP",
#     "LeNet5",
#     "ResNet",
# )
# Inference methods to plot. Used as the ordered categories of the "Method"
# column and as the hue order of the plots.
# NOTE(review): `plot` renames "Implicit VI (Kronecker)" to
# "Implicit Bias VI (Kronecker)", which is not listed here; casting to this
# categorical dtype turns such rows into missing values. Confirm that dropping
# them from the figures is intended.
INFERENCE_METHODS = (
    "Standard",
    "Temperature Scaling",
    "Laplace (Last-layer, GS)",
    "Laplace (Last-layer, ML)",
    "Weight-space VI (Mean-field)",
    "Implicit Bias VI (Low-rank)",
    "SWAG",
    "Ensemble",
)

# Display labels of the parametrizations (ordered categories of "Param.").
# One figure is written per label; "$\\mu$P" is a mathtext string.
PARAMETRIZATIONS = (
    "SP",
    "NTP",
    "$\\mu$P",
)


# Raw CSV column -> display name, for every column the plots need. Columns of
# the raw results file that are not listed here are discarded.
_COLUMN_NAMES = {
    "dataset": "Dataset",
    "OOD Dataset": "OOD Dataset",
    "model": "Model",
    "architecture": "Architecture",
    "inference_method": "Method",
    "parametrization": "Param.",
    "num_parameters_and_buffers": "Num. Parameters",
    "Runtime (s)": "Training Runtime",
    "Test Accuracy/dataloader_idx_0": "Test Accuracy",
    "Test Top-5 Accuracy/dataloader_idx_0": "Test Top-5 Accuracy",
    "Test NLL/dataloader_idx_0": "Test NLL",
    "Test ECE/dataloader_idx_0": "Test ECE",
    "Test Norm. Entropy/dataloader_idx_0": "Test Norm. Entropy",
    "Test Accuracy (OOD)/dataloader_idx_1": "Test Accuracy (OOD)",
    "Test Top-5 Accuracy (OOD)/dataloader_idx_1": "Test Top-5 Accuracy (OOD)",
    "Test NLL (OOD)/dataloader_idx_1": "Test NLL (OOD)",
    "Test ECE (OOD)/dataloader_idx_1": "Test ECE (OOD)",
    "Test Norm. Entropy (OOD)/dataloader_idx_1": "Test Norm. Entropy (OOD)",
    "Test AUROC/dataloader_idx_1": "Test AUROC",
}


def _plot_generalization_grid(
    df: pd.DataFrame,
    *,
    file_prefix: str,
    out_dir,
    extensions: Sequence[str],
    ood_titles: bool = False,
    linear_ece_params: Sequence[str] = (),
) -> None:
    """Write one metrics-by-datasets grid of point plots per parametrization.

    Rows are error, NLL and ECE; columns follow ``DATASETS``. Each figure is
    saved as ``<file_prefix>_<param>.<extension>`` in ``out_dir``, where
    ``<param>`` is the parametrization label stripped to alphanumerics.

    Args:
        df: Results with the display column names ("Dataset", "Method",
            "Param.", the arrow-suffixed metric columns and, if
            ``ood_titles`` is set, "OOD Dataset").
        file_prefix: Leading part of the output file names.
        out_dir: Directory the figures are written to.
        extensions: File extensions to save each figure with.
        ood_titles: Title each column with the first "OOD Dataset" value of
            its rows instead of the dataset name. Columns without rows get
            no title in that case.
        linear_ece_params: Parametrizations whose first-column ECE axis uses
            a linear instead of a logarithmic scale.
    """
    nrows = 3
    ncols = len(DATASETS)
    oom_method = "Weight-space VI (Mean-field)"

    # Unknown or missing parametrizations are NaN after the categorical cast;
    # they have no rows to plot and no usable file name, so skip them.
    for param in df["Param."].dropna().unique():
        df_param = df[df["Param."] == param]

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=1.0,
                nrows=nrows,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                sharex="col",
                squeeze=False,
            )

            for idx_dataset, dataset in enumerate(DATASETS):
                df_dataset = df_param[df_param["Dataset"] == dataset]

                metrics = [
                    "Test Error $\\downarrow$",
                    "Test NLL $\\downarrow$",
                    "Test ECE $\\downarrow$",
                ]
                if dataset in ("CIFAR100", "TinyImageNet"):
                    metrics[0] = "Test Top-5 Error $\\downarrow$"

                for idx_metric, metric in enumerate(metrics):
                    ax = axs[idx_metric, idx_dataset]

                    # Means with error bars, one marker per method
                    sns.pointplot(
                        df_dataset,
                        x="Method",
                        y=metric,
                        hue="Method",
                        hue_order=INFERENCE_METHODS,
                        palette=utils.plotting.colors.INFERENCE_METHODS,
                        err_kws={
                            "linewidth": 1.5,
                            "alpha": 0.3,
                        },
                        marker="_",
                        markersize=8,
                        ax=ax,
                        # Only one axes needs a legend; its handles are moved
                        # to a figure-level legend below.
                        legend=(idx_dataset == 0 and idx_metric == 0),
                    )

                    # Keep y-labels only where a metric name first appears:
                    # the whole MNIST column, and the top row of CIFAR100,
                    # which switches to the top-5 error.
                    keep_ylabel = dataset == "MNIST" or (
                        dataset == "CIFAR100" and idx_metric == 0
                    )
                    if not keep_ylabel:
                        ax.set(ylabel=None)

                    ax.set(xlabel=None)
                    ax.tick_params(labelbottom=False)

                # Title
                if ood_titles:
                    ood_names = df_dataset["OOD Dataset"].unique()
                    if len(ood_names) > 0:
                        axs[0, idx_dataset].set(title=ood_names[0])
                else:
                    axs[0, idx_dataset].set(title=dataset)

                # OOM for Weight-space VI on TinyImageNet
                if dataset == "TinyImageNet":
                    for ax in axs[:, idx_dataset]:
                        ymin, ymax = ax.get_ylim()
                        ax.text(
                            oom_method,
                            (ymax - ymin) * 0.2 + ymin,
                            "OOM",
                            rotation="vertical",
                            verticalalignment="center",
                            horizontalalignment="center",
                            fontsize=6,
                            color=utils.plotting.colors.INFERENCE_METHODS[oom_method],
                        )

            # Custom y-axis scales (first column only)
            axs[0, 0].set(yscale="linear")
            axs[1, 0].set(yscale="linear")
            axs[2, 0].set(yscale="linear" if param in linear_ece_params else "log")

            # Figure-level legend below the grid
            handles, labels = axs[0, 0].get_legend_handles_labels()
            fig.legend(
                handles,
                labels,
                title="Method",
                loc="upper center",
                bbox_to_anchor=(0.5, 0.0),
                fancybox=False,
                shadow=False,
                ncol=4,
                frameon=False,
            )
            # No axes legend exists if nothing was drawn in the first axes.
            axes_legend = axs[0, 0].get_legend()
            if axes_legend is not None:
                axes_legend.remove()

            fig.align_labels()

            param_tag = re.sub(r"[^a-zA-Z0-9]", "", param)
            for extension in extensions:
                fig.savefig(
                    pathlib.Path(out_dir) / f"{file_prefix}_{param_tag}.{extension}",
                    dpi=300,
                    bbox_inches="tight",
                    pad_inches=0.01,
                )

            plt.close(fig)


def plot(
    dir: str = pathlib.Path.cwd() / "../../../../publication/paper/figures/",
    raw_results_file: str = "experiment_results.csv",
    extensions: Sequence[str] = ("pdf",),
):
    """Plot UQ and OOD detection results.

    Reads the raw results, normalizes column names and categories, and then
    writes, per parametrization, the in-distribution figure
    (``generalization_id_*``), the figures produced by
    ``plot_in_distribution_computational_cost`` and the out-of-distribution
    figure (``generalization_ood_*``).

    Args:
        dir: Directory the figures are written to (string or path object).
        raw_results_file: CSV file with the raw experiment results. It must
            contain every column listed in ``_COLUMN_NAMES``.
        extensions: File extensions to save each figure with. A single
            string such as ``"png"`` is treated as one extension.

    Raises:
        KeyError: If a required column is missing from the results file.
    """
    # A lone string (e.g. `--extensions=png` on the command line) would
    # otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)

    # Read the raw results, keep the needed columns and rename them. The
    # result is a fresh frame, so the column assignments below do not operate
    # on a slice of the raw frame.
    df = pd.read_csv(raw_results_file)
    df = df[list(_COLUMN_NAMES)].rename(columns=_COLUMN_NAMES)

    # Column preprocessing
    df["Dataset"] = df["Dataset"].astype(
        pd.CategoricalDtype(categories=DATASETS, ordered=True)
    )
    df["Method"] = df["Method"].fillna("Standard")
    df["Method"] = df["Method"].replace(
        {
            "Implicit VI (Low-rank)": "Implicit Bias VI (Low-rank)",
            "Implicit VI (Kronecker)": "Implicit Bias VI (Kronecker)",
        },
    )
    df["Method"] = df["Method"].astype(
        pd.CategoricalDtype(categories=INFERENCE_METHODS, ordered=True)
    )
    df["Param."] = df["Param."].replace(
        {
            "Standard": "SP",
            "NeuralTangent": "NTP",
            "MaximalUpdate": "$\\mu$P",
        },
    )
    df["Param."] = df["Param."].astype(
        pd.CategoricalDtype(categories=PARAMETRIZATIONS, ordered=True)
    )

    # "Runtime (s)" is scaled by 1000, i.e. seconds to milliseconds.
    df["Training Runtime"] = df["Training Runtime"].astype(float) * 1000.0

    df["Test Error"] = 1.0 - df["Test Accuracy"]
    df["Test Top-5 Error"] = 1.0 - df["Test Top-5 Accuracy"]
    df["Test Error (OOD)"] = 1.0 - df["Test Accuracy (OOD)"]
    df["Test Top-5 Error (OOD)"] = 1.0 - df["Test Top-5 Accuracy (OOD)"]

    # Final touches: mark whether higher or lower is better
    df = df.rename(
        columns={
            "Test Accuracy": "Test Accuracy $\\uparrow$",
            "Test Top-5 Accuracy": "Test Top-5 Accuracy $\\uparrow$",
            "Test Error": "Test Error $\\downarrow$",
            "Test Top-5 Error": "Test Top-5 Error $\\downarrow$",
            "Test NLL": "Test NLL $\\downarrow$",
            "Test ECE": "Test ECE $\\downarrow$",
        },
    )

    # In-distribution performance
    _plot_generalization_grid(
        df,
        file_prefix="generalization_id",
        out_dir=dir,
        extensions=extensions,
    )

    # Plot in distribution generalization and computational cost
    plot_in_distribution_computational_cost(
        df,
        dir=dir,
        extensions=extensions,
    )

    # Out-of-distribution performance: replace the in-distribution metric
    # columns by their OOD counterparts so the same plotting code applies.
    df_ood = df.drop(
        columns=[
            "Test Accuracy $\\uparrow$",
            "Test Top-5 Accuracy $\\uparrow$",
            "Test Error $\\downarrow$",
            "Test Top-5 Error $\\downarrow$",
            "Test NLL $\\downarrow$",
            "Test ECE $\\downarrow$",
            "Test Norm. Entropy",
        ],
    ).rename(
        columns={
            "Test Accuracy (OOD)": "Test Accuracy $\\uparrow$",
            "Test Top-5 Accuracy (OOD)": "Test Top-5 Accuracy $\\uparrow$",
            "Test Error (OOD)": "Test Error $\\downarrow$",
            "Test Top-5 Error (OOD)": "Test Top-5 Error $\\downarrow$",
            "Test NLL (OOD)": "Test NLL $\\downarrow$",
            "Test ECE (OOD)": "Test ECE $\\downarrow$",
            "Test AUROC": "Test AUROC $\\uparrow$",
        },
    )

    _plot_generalization_grid(
        df_ood,
        file_prefix="generalization_ood",
        out_dir=dir,
        extensions=extensions,
        ood_titles=True,
        linear_ece_params=("SP",),
    )


def _format_sci_notation(x, pos):
    """Format tick value ``x`` as a mathtext power of ten.

    The value is rounded to one significant digit. ``pos`` is the tick
    position passed by matplotlib and is not used.
    """
    mantissa, exponent = "{:.0e}".format(x).split("e")
    exponent = int(exponent)
    if mantissa == "1":
        # Braces keep multi-digit and negative exponents in the superscript.
        return f"$10^{{{exponent}}}$"
    if mantissa == "0":
        return "0"
    return rf"${mantissa}\times 10^{{{exponent}}}$"


def plot_in_distribution_computational_cost(
    df: pd.DataFrame,
    dataset: str = "CIFAR100",
    dir: str = pathlib.Path.cwd() / "../../../../publication/paper/figures/",
    extensions: Sequence[str] = ("pdf",),
):
    """Plot in-distribution metrics next to computational cost.

    For every parametrization in ``df`` a single-row figure with four panels
    is written: test error (top-5 error for CIFAR100 and TinyImageNet) and
    test NLL per method on ``dataset``, the number of parameters per method
    on ``dataset``, and the training runtime per dataset for the "Standard"
    and "Implicit Bias VI (Low-rank)" methods. Figures are saved as
    ``generalization_id_computational_cost_<param>.<extension>`` in ``dir``.

    Args:
        df: Results with the display column names ("Dataset", "Method",
            "Param.", "Num. Parameters", "Training Runtime" and the
            arrow-suffixed in-distribution metric columns).
        dataset: Dataset shown in the first three panels.
        dir: Directory the figures are written to (string or path object).
        extensions: File extensions to save each figure with. A single
            string such as ``"png"`` is treated as one extension.
    """
    # A lone string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)

    nrows = 1
    ncols = 4

    # Test metrics
    metrics = [
        "Test Error $\\downarrow$",
        "Test NLL $\\downarrow$",
    ]
    if dataset in ("CIFAR100", "TinyImageNet"):
        metrics[0] = "Test Top-5 Error $\\downarrow$"

    # Missing parametrizations (NaN) have no rows to plot and no usable file
    # name, so skip them.
    for param in df["Param."].dropna().unique():
        # Filter the dataframe
        df_plot = df[df["Param."] == param]
        df_plot_dataset = df_plot[df_plot["Dataset"] == dataset]

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=1.0,
                # NOTE(review): the row count is scaled, presumably to get a
                # taller figure -- confirm against `style.neurips`.
                nrows=nrows * 1.75,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                squeeze=False,
            )

            for idx_metric, metric in enumerate(metrics):
                # Means with error bars, one marker per method
                sns.pointplot(
                    df_plot_dataset,
                    x="Method",
                    y=metric,
                    hue="Method",
                    hue_order=INFERENCE_METHODS,
                    palette=utils.plotting.colors.INFERENCE_METHODS,
                    err_kws={
                        "linewidth": 1.5,
                        "alpha": 0.3,
                    },
                    marker="_",
                    markersize=8,
                    ax=axs[0, idx_metric],
                    # Only the first panel needs a legend; its handles are
                    # moved to a figure-level legend below.
                    legend=(idx_metric == 0),
                )

                # Title
                axs[0, idx_metric].set(title=dataset)

            # Number of parameters
            sns.barplot(
                df_plot_dataset,
                x="Method",
                y="Num. Parameters",
                hue="Method",
                hue_order=INFERENCE_METHODS,
                palette=utils.plotting.colors.INFERENCE_METHODS,
                ax=axs[0, 2],
                legend=False,
                errorbar=None,
                dodge=False,
            )

            axs[0, 2].set(title=dataset)
            axs[0, 2].set_yticks([1e7, 1e8])
            axs[0, 2].yaxis.set_major_formatter(
                mpl.ticker.FuncFormatter(_format_sci_notation)
            )

            # The methods are identified by the legend, not by tick labels.
            for ax in axs[0, :3]:
                ax.set(xlabel=None)
                ax.tick_params(
                    bottom=False,
                    labelbottom=False,
                )

            # Training time
            sns.barplot(
                df_plot[
                    df_plot["Method"].isin(["Standard", "Implicit Bias VI (Low-rank)"])
                ],
                x="Dataset",
                y="Training Runtime",
                hue="Method",
                hue_order=utils.plotting.colors.INFERENCE_METHODS.keys(),
                palette=utils.plotting.colors.INFERENCE_METHODS,
                ax=axs[0, 3],
                dodge="auto",
                legend=False,
                width=0.5,
                gap=-3.5,
            )
            axs[0, 3].yaxis.set_major_formatter(
                utils.plotting.tick_formatters.TimeFormatter()
            )
            # Candidate ticks at multiples of 3,600,000, i.e. full hours if
            # the runtime is given in milliseconds (as `plot` provides it).
            axs[0, 3].yaxis.set_major_locator(
                mpl.ticker.FixedLocator(locs=np.arange(5) * 3600000, nbins=4)
            )
            axs[0, 3].set_xticklabels(
                axs[0, 3].get_xticklabels(), rotation=25, ha="center"
            )
            axs[0, 3].tick_params(axis="x", pad=0)
            axs[0, 3].set(xlabel=None)

            # Figure-level legend
            handles, labels = axs[0, 0].get_legend_handles_labels()
            fig.legend(
                handles,
                labels,
                title="Method",
                loc="upper center",
                bbox_to_anchor=(0.5, 0.2),
                fancybox=False,
                shadow=False,
                ncol=4,
                frameon=False,
            )
            # No axes legend exists if nothing was drawn in the first panel.
            axes_legend = axs[0, 0].get_legend()
            if axes_legend is not None:
                axes_legend.remove()

            param_tag = re.sub(r"[^a-zA-Z0-9]", "", param)
            for extension in extensions:
                fig.savefig(
                    pathlib.Path(dir)
                    / f"generalization_id_computational_cost_{param_tag}.{extension}",
                    dpi=300,
                    bbox_inches="tight",
                    pad_inches=0.01,
                )

            plt.close(fig)


if __name__ == "__main__":
    # Expose `plot` as a command-line interface via python-fire.
    fire.Fire(plot)
