from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame, Index

from experiments import result_io

# Output directory for the generated LaTeX tables and figures, placed next to
# this script and created on import if it does not exist yet.
artifact_dir = Path(__file__).parent.joinpath("paper_artifacts")
artifact_dir.mkdir(parents=True, exist_ok=True)

# Redirect result_io to the sibling "arxiv-results" directory (same parent as
# its configured results_dir). This rebinds an attribute of the result_io
# module at import time, so any other code in this process that reads
# result_io.results_dir sees the new location too.
result_io.results_dir = result_io.results_dir.parent.joinpath("arxiv-results")


def write_table(
    model,
    measure,
    datasets=None,
    orders=None,
    summary_fn=lambda x: (np.mean(x), np.std(x)),
):
    """Write a LaTeX table summarising ``measure`` for ``model`` across datasets.

    The table has one row per dataset and one column per order. Each cell is
    rendered as LaTeX math "centre +/- spread" with three decimals.

    Args:
        model: Name of the model's sub-directory inside ``result_io.results_dir``.
        measure: Key of the metric looked up in each ``res[order]`` mapping.
        datasets: Dataset names to include. If empty or None, the stems of all
            files in the model's result directory are used, sorted so that the
            row order is reproducible.
        orders: Result keys used as columns (labelled ``$\\alpha=<order>$``).
            If empty or None, the keys of the first dataset's results are used.
        summary_fn: Maps the raw measure values to a ``(centre, spread)`` pair.
            Defaults to the mean and the standard deviation.

    Raises:
        ValueError: If no datasets are given or found.
        KeyError: If an order or measure is missing from a dataset's results.

    The table is written to ``artifact_dir / f"{model}_{measure}.tex"``.
    """
    model_dir = result_io.results_dir / model
    if not datasets:
        # iterdir() order is filesystem-dependent; sort for reproducible rows.
        datasets = sorted(loc.stem for loc in model_dir.iterdir() if loc.is_file())
    else:
        # Materialise so a one-shot iterable is not exhausted before the
        # index labels are built below.
        datasets = list(datasets)
    if not datasets:
        raise ValueError(f"no datasets found for model {model!r} in {model_dir}")
    orders = list(orders) if orders else []

    measures = defaultdict(list)
    for dataset in datasets:
        res = result_io.read_result(model, dataset)
        if not orders:
            # The column set is taken from the first dataset's results.
            orders = list(res.keys())
        for order in orders:
            try:
                values = res[order][measure]
            except KeyError as err:
                raise KeyError(
                    f"missing measure={measure!r} for dataset={dataset!r}, order={order!r}"
                ) from err
            measures[order].append(summary_fn(values))

    df = DataFrame.from_dict(dict(measures))
    # Row label: the first underscore-separated token, e.g. "energy_heating_load" -> "Energy".
    df.index = Index([name.split("_")[0].capitalize() for name in datasets])
    df.columns = Index([rf"$\alpha={order}$" for order in orders])
    # Raw f-string: "\p" in a regular string literal is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+). The output text is unchanged.
    style = df.style.format(lambda v: rf"${v[0]:.3f}\pm{v[1]:.3f}$")

    style.to_latex(artifact_dir / f"{model}_{measure}.tex")


def make_errorbar_figure(
    model, measure, datasets=None, orders=None, prefix="", *, legend_index=3
):
    """Save a row of error-bar panels of ``measure`` for ``model``, one per dataset.

    Each panel shows one point per order, at ``values.mean()`` with an error bar
    of ``values.std()``. The panel's x-label is the dataset name's first
    underscore-separated token, capitalised.

    Args:
        model: Name of the model's sub-directory inside ``result_io.results_dir``.
        measure: Key of the metric looked up in each ``res[order]`` mapping.
            The y-label is "rmse" when this is ``"rmse"``, otherwise
            "log likelihood".
        datasets: Dataset names to plot. If empty or None, the stems of all
            files in the model's result directory are used, sorted for a
            reproducible panel order.
        orders: Result keys to plot (legend ``$\\alpha=<order>$``). If empty or
            None, the keys of the first dataset's results are used.
        prefix: Inserted into the output file name.
        legend_index: Panel that carries the legend. It is clamped to the last
            panel when fewer panels exist. The default of 3 matches the
            original layout.

    Raises:
        ValueError: If no datasets are given or found.
        KeyError: If an order or measure is missing from a dataset's results.

    The figure is saved to
    ``artifact_dir / f"{model}_{measure}_{prefix}_errorbar.png"`` and then closed.
    """
    model_dir = result_io.results_dir / model
    if not datasets:
        # iterdir() order is filesystem-dependent; sort for reproducible panels.
        datasets = sorted(loc.stem for loc in model_dir.iterdir() if loc.is_file())
    else:
        datasets = list(datasets)
    if not datasets:
        raise ValueError(f"no datasets found for model {model!r} in {model_dir}")
    orders = list(orders) if orders else []

    measures = defaultdict(list)
    for dataset in datasets:
        res = result_io.read_result(model, dataset)
        if not orders:
            # The plotted orders are taken from the first dataset's results.
            orders = list(res.keys())
        for order in orders:
            try:
                values = res[order][measure]
            except KeyError as err:
                # Raise (instead of print + exit) so callers can handle it and
                # the traceback shows exactly what is missing.
                raise KeyError(
                    f"missing measure={measure!r} for dataset={dataset!r}, order={order!r}"
                ) from err
            measures[dataset].append((values.mean(), values.std()))

    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # dataset yields a bare Axes and axs[0] below would fail.
    fig, axs = plt.subplots(
        1, len(datasets), figsize=(2 * len(datasets), 2.5), squeeze=False
    )
    axs = axs[0]
    axs[0].set_ylabel(measure if measure == "rmse" else "log likelihood")

    for dataset, ax in zip(datasets, axs):
        for i, (order, (y, err)) in enumerate(zip(orders, measures[dataset])):
            ax.errorbar(
                i,
                y,
                fmt="o",
                yerr=err,
                label=rf"$\alpha={order}$",
                linewidth=2,
                capsize=6,
            )
        ax.set_xlabel(dataset.split("_")[0].capitalize())
        ax.set_xticks([])

    # Clamp so fewer than legend_index + 1 panels does not raise IndexError.
    axs[min(legend_index, len(axs) - 1)].legend(loc=2)

    fig.tight_layout(pad=0.98)
    fig.savefig(artifact_dir / f"{model}_{measure}_{prefix}_errorbar.png")
    # Release the figure; pyplot otherwise keeps every figure open.
    plt.close(fig)


if __name__ == "__main__":
    # Datasets shown in the BNN RMSE error-bar figure, in panel order.
    bnn_rmse_datasets = [
        "boston_housing",
        "concrete",
        "kin8nm",
        "yacht",
        "power",
        "wine",
        "energy_heating_load",
    ]
    make_errorbar_figure("bnn", "rmse", datasets=bnn_rmse_datasets)

    # Uncomment to regenerate the VAE log-likelihood table:
    # write_table("vae", "log_like")

    # Recorded SVGD results (value pm spread), noted for reference:
    # energy_heating_load
    # rmse 1.382340949302857 pm 0.08777002419971856
    # ll -2.446017806777368 pm 0.29742536153900223

