"""Plot experiment results."""

from __future__ import annotations

import pathlib
import re
import sys
from typing import Sequence

import fire
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tueplots

sys.path.insert(0, "../../../experiments")
sys.path.insert(0, "analysis/experiments")
import utils

# Raw model identifiers mapped to the display labels used for the "Model"
# column; the order of the values is the order of its ordered categories.
MODELS = dict(
    MLPIVI="Implicit Bias VI",
    MLPIVITheoreticalScaling="Implicit Bias VI + Theoretical Scaling",
    MLPIVITemperatureScaling="Implicit Bias VI + Temperature Scaling",
)

# Display names of the inference methods, in category order.
INFERENCE_METHODS = (
    "Standard",
    "Temperature Scaling",
    "Laplace (Last-layer, GS)",
    "Laplace (Last-layer, ML)",
    "Weight-space VI (Mean-field)",
    "Implicit Bias VI (Low-rank)",
    "Ensemble",
)

# Display names of the parametrizations, in category order.
PARAMETRIZATIONS = ("SP", "NTP", r"$\mu$P")


def _load_results(raw_results_file: str) -> pd.DataFrame:
    """Read the raw results CSV and return a frame prepared for plotting.

    Selects the columns needed for the figures, gives them display names,
    converts the categorical columns to ordered categoricals and derives the
    validation error from the validation accuracy.

    Args:
        raw_results_file: Path to the CSV file with the raw experiment results.

    Returns:
        The preprocessed results.

    Raises:
        KeyError: If one of the required columns is missing from the CSV.
    """
    # Raw column name -> display name. Only these columns are kept. The
    # `Variance val x*/dataloader_idx_*` columns are intentionally not selected.
    column_names = {
        "dataset": "Dataset",
        "model": "Model",
        "architecture": "Architecture",
        "inference_method": "Method",
        "parametrization": "Param.",
        "num_parameters_and_buffers": "Num. Parameters",
        "epoch": "Epoch",
        "Runtime (s)": "Training Runtime",
        "Validation Accuracy/dataloader_idx_0": "Validation Accuracy",
        "Validation NLL/dataloader_idx_0": "Validation NLL",
        "Validation ECE/dataloader_idx_0": "Validation ECE",
        "Mean Parameter Norm/dataloader_idx_0": "Norm of Mean Params.",
    }
    raw = pd.read_csv(raw_results_file)
    df = raw[list(column_names)].rename(columns=column_names)

    # Rows without an inference method are labelled "Standard".
    method = df["Method"].fillna("Standard")
    method = method.replace(
        {
            "Implicit VI (Low-rank)": "Implicit Bias VI (Low-rank)",
            "Implicit VI (Kronecker)": "Implicit Bias VI (Kronecker)",
        },
    )
    # NOTE(review): "Implicit Bias VI (Kronecker)" is not listed in
    # INFERENCE_METHODS, so the cast below turns it into NaN -- confirm that
    # this is intended.
    df["Method"] = method.astype(
        pd.CategoricalDtype(categories=INFERENCE_METHODS, ordered=True)
    )

    # Models that are not listed in MODELS become NaN in the cast.
    df["Model"] = (
        df["Model"]
        .replace(MODELS)
        .astype(pd.CategoricalDtype(categories=list(MODELS.values()), ordered=True))
    )

    df["Param."] = (
        df["Param."]
        .replace(
            {
                "Standard": "SP",
                "NeuralTangent": "NTP",
                "MaximalUpdate": "$\\mu$P",
            },
        )
        .astype(pd.CategoricalDtype(categories=PARAMETRIZATIONS, ordered=True))
    )

    # Presumably converts seconds to milliseconds -- TODO confirm the unit.
    df["Training Runtime"] = df["Training Runtime"].astype(float) * 1000.0

    df["Validation Error"] = 1.0 - df["Validation Accuracy"]

    # Append arrows that indicate whether lower or higher is better. Keys that
    # are not columns of the frame (e.g. "Train Loss") are ignored by `rename`.
    return df.rename(
        columns={
            "Train Loss": "Training Loss $\\downarrow$",
            "Validation Accuracy": "Validation Accuracy $\\uparrow$",
            "Validation Error": "Validation Error $\\downarrow$",
            "Validation NLL": "Validation NLL $\\downarrow$",
            "Validation ECE": "Validation ECE $\\downarrow$",
        },
    )


def plot(
    dir: str | pathlib.Path = pathlib.Path.cwd() / "../../../../publication/paper/figures/",
    raw_results_file: str = "experiment_results.csv",
    extensions: Sequence[str] = ("pdf",),
):
    """Plot the training curves of the "TwoMoons" experiment.

    For every parametrization with results on the dataset, one figure is
    written to ``theory_classification_<parametrization>.<extension>``. It
    shows validation error, validation NLL and the norm of the mean parameters
    over the epochs, with one line per model.

    Args:
        dir: Directory the figures are written to. Created if it does not exist.
        raw_results_file: Path to the CSV file with the raw experiment results.
        extensions: File extensions (without leading dot) to save each figure
            with. A single string is treated as one extension.
    """
    import warnings

    # A bare string (e.g. `--extensions=pdf` on the command line) would
    # otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)

    df = _load_results(raw_results_file)

    metrics = [
        "Validation Error $\\downarrow$",
        "Validation NLL $\\downarrow$",
        "Norm of Mean Params.",
    ]
    dataset = "TwoMoons"

    nrows = 1
    ncols = len(metrics)

    df_dataset = df[df["Dataset"] == dataset]
    # Rows with a missing or unknown parametrization are NaN after the
    # categorical cast; they cannot be plotted or turned into a file name.
    params = df_dataset["Param."].dropna().unique()
    if len(params) == 0:
        warnings.warn(
            f"No results for dataset '{dataset}' in '{raw_results_file}'; "
            "nothing to plot.",
            stacklevel=2,
        )
        return

    output_dir = pathlib.Path(dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for param in params:
        df_plot_dataset = df_dataset[df_dataset["Param."] == param]

        with plt.rc_context(
            utils.plotting.style.neurips(
                rel_width=1.0,
                nrows=nrows * 1.2,
                ncols=ncols,
            )
        ):
            fig, axs = plt.subplots(
                nrows=nrows,
                ncols=ncols,
                sharex="row",
                squeeze=False,
            )

            try:
                for idx_metric, metric in enumerate(metrics):
                    sns.lineplot(
                        df_plot_dataset,
                        x="Epoch",
                        y=metric,
                        hue="Model",
                        ax=axs[0, idx_metric],
                        # Only the first panel draws a legend; it is promoted
                        # to a figure-level legend below.
                        legend=idx_metric == 0,
                    )

                # Custom axis scales
                axs[0, 0].set(xscale="linear", yscale="log")
                axs[0, 1].set(yscale="log")
                axs[0, 2].set(yscale="linear")

                # Replace the legend of the first panel by a shared legend
                # below the figure.
                handles, labels = axs[0, 0].get_legend_handles_labels()
                if handles:
                    fig.legend(
                        handles,
                        labels,
                        loc="upper center",
                        bbox_to_anchor=(0.5, 0.0),
                        fancybox=False,
                        shadow=False,
                        ncol=3,
                        frameon=False,
                    )
                # No legend exists if none of the rows has a known model.
                panel_legend = axs[0, 0].get_legend()
                if panel_legend is not None:
                    panel_legend.remove()

                fig.align_labels()

                # Keep only alphanumeric characters, e.g. "$\mu$P" -> "muP".
                param_name = re.sub(r"[^a-zA-Z0-9]", "", param)
                for extension in extensions:
                    fig.savefig(
                        output_dir / f"theory_classification_{param_name}.{extension}",
                        dpi=300,
                        bbox_inches="tight",
                        pad_inches=0.08,
                    )
            finally:
                # Release the figure even if plotting or saving failed.
                plt.close(fig)


if __name__ == "__main__":
    # Hand `plot` to python-fire so it can be run from the command line.
    fire.Fire(plot)
