"""Format experiment results into a latex table."""

from __future__ import annotations

import fire
import pandas as pd

# Ordered category values for the table rows. The order of each tuple is the
# order in which the corresponding values are sorted in the final table.
# ("ImageNet" and "ResNet" are currently disabled.)
DATASETS = ("MNIST", "CIFAR10")
ARCHITECTURES = ("MLP", "LeNet5")
INFERENCE_METHODS = (
    "Standard",
    "Weight-space VI (Mean-field)",
    "Laplace (Last-layer, GS)", "Laplace (Last-layer, ML)",
    "Ensemble",
    "Implicit VI (Kronecker)", "Implicit VI (Low-rank)",
)
PARAMETRIZATIONS = ("SP", "NTP", "$\\mu$P")

# Metric columns (after renaming) that are aggregated and reported.
METRICS = ("Test Accuracy", "Test NLL", "Test ECE", "Test AUROC")

# Columns identifying one table row; raw results sharing all of these values
# are aggregated together. The last two entries ("Method", "Param.") are the
# ones compared against each other when selecting the best value.
COLUMNS_TO_GROUPBY = ["Dataset", "OOD Dataset", "Architecture", "Method", "Param."]


def _format_cell(
    mean: float,
    std: float,
    best_mean: float,
    num_stdevs: int,
    decimals: int,
) -> str:
    """Render one table cell as ``mean ± num_stdevs * std`` in LaTeX markup."""
    mean_str = f"{mean:.{decimals}f}"
    # Compare at display precision so that every entry which *prints* the same
    # as the best one is highlighted. NaN never compares equal, so cells
    # without a valid value are simply left unbolded.
    if round(best_mean, decimals) == round(mean, decimals):
        mean_str = "{\\bfseries" + mean_str + "}"
    return mean_str + "\\scriptsize$\\pm$ " + f"{(num_stdevs * std):.{decimals}f}"


def latex_table(
    file: str = "../../../../publication/paper/tables/ood_detection_results.tex",
    raw_results_file: str = "experiment_results.csv",
    num_stdevs: int = 2,
    decimals: int = 3,
) -> None:
    """Format experiment results into a latex table.

    Reads the raw per-run results, averages every metric in ``METRICS`` over
    all runs sharing the same ``COLUMNS_TO_GROUPBY`` values, and writes a
    table whose cells read ``mean ± num_stdevs * std``. Within each
    (dataset, OOD dataset, architecture) group the best mean per metric is
    set in bold: highest for accuracy and AUROC, lowest for NLL and ECE.

    Args:
        file: Path of the LaTeX file to write.
        raw_results_file: Path of the CSV file holding the raw results.
        num_stdevs: Multiple of the standard deviation shown after the mean.
        decimals: Number of decimal places for all reported values.
    """
    # Read the raw results
    df = pd.read_csv(raw_results_file)

    # Keep only the relevant columns and give them their display names
    df = df[
        [
            "dataset",
            "OOD Dataset",
            "model",
            "architecture",
            "inference_method",
            "parametrization",
            "Test Accuracy/dataloader_idx_0",
            "Test NLL/dataloader_idx_0",
            "Test ECE/dataloader_idx_0",
            "Test AUROC/dataloader_idx_1",
        ]
    ].rename(
        columns={
            "dataset": "Dataset",
            "architecture": "Architecture",
            "inference_method": "Method",
            "parametrization": "Param.",
            "model": "Model",
            "Test Accuracy/dataloader_idx_0": "Test Accuracy",
            "Test NLL/dataloader_idx_0": "Test NLL",
            "Test ECE/dataloader_idx_0": "Test ECE",
            "Test AUROC/dataloader_idx_1": "Test AUROC",
        },
    )

    # Column preprocessing: ordered categoricals fix the row order of the table
    df["Dataset"] = df["Dataset"].astype(
        pd.CategoricalDtype(categories=DATASETS, ordered=True)
    )
    # Runs without an inference method are reported as "Standard"
    df["Method"] = df["Method"].fillna("Standard")
    df["Method"] = df["Method"].astype(
        pd.CategoricalDtype(categories=INFERENCE_METHODS, ordered=True)
    )
    df["Param."] = df["Param."].replace(
        {
            "Standard": "SP",
            "NeuralTangent": "NTP",
            "MaximalUpdate": "$\\mu$P",
        },
    )
    df["Param."] = df["Param."].astype(
        pd.CategoricalDtype(categories=PARAMETRIZATIONS, ordered=True)
    )

    # Aggregate metrics; columns become a (metric, "mean"/"std") MultiIndex
    df = (
        df.groupby(
            by=COLUMNS_TO_GROUPBY,
            dropna=False,
            observed=True,
        )
        .aggregate({metric: ["mean", "std"] for metric in METRICS})
        .reset_index()
    )

    # Format numerical values
    higher_is_better = {"Test Accuracy", "Test AUROC"}
    # Methods and parametrizations compete against each other, so the best
    # value is searched over everything except the last two grouping columns.
    comparison_keys = COLUMNS_TO_GROUPBY[:-2]
    for metric in METRICS:
        # Best mean of each comparison group, broadcast back to every row.
        # dropna=False mirrors the aggregation above: rows whose key is NaN
        # form their own group instead of ending up without a best value.
        best_means = df.groupby(comparison_keys, observed=True, dropna=False)[
            metric
        ].transform("max" if metric in higher_is_better else "min")

        # Combine mean and uncertainty and bold best values
        cells = [
            _format_cell(mean, std, best_mean, num_stdevs, decimals)
            for mean, std, best_mean in zip(
                df[metric]["mean"],
                df[metric]["std"],
                best_means[metric]["mean"],
            )
        ]

        # Replace the (mean, std) column pair with the single formatted column
        df = df.drop(columns=[metric], level=0)
        df[metric] = cells

    # Adjust multiindices: the second column level is no longer needed
    df = df.droplevel(1, axis=1)

    # Final touches
    df = df.rename(
        columns={
            "Test Accuracy": "Test Accuracy $\\uparrow$",
            "Test NLL": "Test NLL $\\downarrow$",
            "Test ECE": "Test ECE $\\downarrow$",
            "Test AUROC": "Test AUROC $\\uparrow$",
        },
    )

    # Colors
    df["Method"] = df["Method"].cat.rename_categories(
        {
            "Standard": "\\textcolor{Standard}{Standard}",
            "Weight-space VI (Mean-field)": "\\textcolor{WSVI}{Weight-space VI (Mean-field)}",
            "Laplace (Last-layer, GS)": "\\textcolor{Laplace}{Laplace (Last-layer, GS)}",
            "Laplace (Last-layer, ML)": "\\textcolor{Laplace}{Laplace (Last-layer, ML)}",
            "Ensemble": "\\textcolor{Ensemble}{Ensemble}",
            "Implicit VI (Kronecker)": "\\textcolor{ImplicitVI}{Implicit VI (Kronecker)}",
            "Implicit VI (Low-rank)": "\\textcolor{ImplicitVI}{Implicit VI (Low-rank)}",
        },
    )

    # Multiindex rows and columns
    df = df.set_index(COLUMNS_TO_GROUPBY)

    # Save the table to a latex file
    # NOTE(review): "llllccccc" covers the 5 index levels plus the 4 metric
    # columns, which leaves the last index level ("Param.") centred rather
    # than left-aligned -- confirm this is intended.
    df.style.to_latex(
        file,
        hrules=True,
        clines=None,
        column_format="llllccccc",
        multirow_align="t",
        multicol_align="c",
        sparse_index=True,
        siunitx=False,
    )


if __name__ == "__main__":
    # Script entry point: hand `latex_table` to python-fire, which turns it
    # into the command-line interface of this file.
    fire.Fire(latex_table)
