from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd
import numpy as np
import argparse
import json
import glob
import os

def magg(values):
    """Format the mean of *values* as inline LaTeX math, e.g. ``"$0.123$"``.

    Used as a per-column aggregation function for ``DataFrame.groupby(...).agg``.
    """
    mean = np.mean(values)
    return f"${mean:.3f}$"

def agg(values):
    """Format *values* as LaTeX ``$mean \\pm std$`` with three decimals.

    The spread is ``np.std`` with NumPy's default ``ddof=0``.
    """
    mean, spread = np.mean(values), np.std(values)
    return f"${mean:.3f} \\pm {spread:.3f}$"


# Clearly visible marker styles, cycled through one per method by pair_plot.
# (Taking them from Line2D.markers picks ',' -- a one-pixel marker that
# ignores the size argument -- and later invisible entries such as 'None'.)
_PAIR_PLOT_MARKERS = ("o", "s", "^", "v", "D", "<", ">", "p", "*", "h", "X", "P")


def pair_plot(the_df, methods, x, y, fname, legend=False):
    """Scatter-plot column *y* against column *x*, one marker per method, and save it.

    Args:
        the_df: DataFrame with a "method" column plus the columns named by
            *x* and *y*.
        methods: method names to plot; each gets its own marker and legend
            label, in this order.
        x: column plotted on the horizontal axis (also the axis label).
        y: column plotted on the vertical axis (also the axis label).
        fname: output path passed to ``plt.savefig``.
        legend: whether to draw a legend.
    """
    # NOTE: this updates matplotlib's global rcParams, and "text.usetex"
    # requires a working LaTeX installation.
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "serif",
        "font.size": 12})

    fig = plt.figure(figsize=(3, 3))
    try:
        for m, method in enumerate(methods):
            df = the_df.loc[the_df["method"] == method]
            plt.scatter(
                df[x],
                df[y],
                s=30,
                marker=_PAIR_PLOT_MARKERS[m % len(_PAIR_PLOT_MARKERS)],
                label=method)

        plt.xlabel(x)
        plt.ylabel(y)
        if legend:
            plt.legend(prop={'size': 10})
        plt.tight_layout(pad=1)
        plt.savefig(fname)
    finally:
        # Release the figure; otherwise every call leaves one open in pyplot.
        plt.close(fig)


def select_rows(df, **conditions):
    """Return the rows of *df* where every named column equals the given value.

    ``select_rows(df, method="mixup", k=1)`` is shorthand for
    ``df.loc[(df["method"] == "mixup") & (df["k"] == 1)]``. With no
    conditions, every row is returned.
    """
    mask = pd.Series(True, index=df.index)
    for column, value in conditions.items():
        mask &= df[column] == value
    return df.loc[mask]


def print_latex_table(df, caption, label):
    """Print *df* as a LaTeX table with a left-aligned index and centred data columns.

    ``escape=False`` keeps the ``$...$`` cells produced by agg/magg as math.
    The column format is derived from *df* itself, so it always matches the
    number of columns actually printed.
    """
    print(df.to_latex(
          caption=caption,
          label=label,
          column_format="l" + "c" * len(df.columns),
          escape=False))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): --debug, --x and --y are parsed but not used below; the
    # scatter plots use a fixed "mixup loss (te)" x-axis.
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--output_dir', type=str, default="results/jobs_100/")
    parser.add_argument('--plot', action="store_true")
    parser.add_argument('--x', type=str, default="data loss (tr)")
    parser.add_argument('--y', type=str, default="data loss (te)")
    parser.add_argument('--k', type=int, default=20)
    parser.add_argument('--lam', type=float, default=0.8)
    parser.add_argument(
        '--methods', default=["ERM", "ERM (ensemble)", "mixup",
                              "Extra-mixup-v1", "Extra-mixup-v2",
                              "TestTime-mixup", "TestTime-ERM",
                              "TestTime-mixup-aug"],
        nargs="+", type=str)
    args = vars(parser.parse_args())

    # Collect the "stats" dict of every evaluation file (sorted for a
    # deterministic row order).
    records = []
    for fname in sorted(glob.glob(os.path.join(args["output_dir"], "*.eval.json"))):
        with open(fname, "r") as f:
            records.append(json.load(f)["stats"])
    if not records:
        raise SystemExit(
            "No *.eval.json files found in {}".format(args["output_dir"]))

    all_df = pd.DataFrame(records)

    # Divide every (OOD) diversity column by the run's k.
    for metric in ("diversity", "OOD diversity"):
        for split in ("tr", "te", "ne"):
            all_df["{} ({})".format(metric, split)] /= all_df["k"]

    # Main table.
    # These methods are reported only at the requested k (and lam, where
    # applicable). Every other method is reported from its k == 1 runs.
    # Excluding them from the k == 1 selection keeps their k-ablation runs out
    # of the main-table statistics. It also avoids counting rows twice when
    # --k 1 is given.
    fixed_k_methods = ("ERM (ensemble)", "TestTime-mixup",
                       "TestTime-mixup-aug", "TestTime-ERM")
    one_df = all_df.loc[(all_df["k"] == 1) &
                        ~all_df["method"].isin(fixed_k_methods)]
    erm_ens_df = select_rows(all_df, method="ERM (ensemble)", k=args["k"])
    ttm_ens_df = select_rows(all_df, method="TestTime-mixup",
                             k=args["k"], lam=args["lam"])
    aaa_ens_df = select_rows(all_df, method="TestTime-mixup-aug",
                             k=args["k"], lam=args["lam"])
    erm_ttt_df = select_rows(all_df, method="TestTime-ERM", k=args["k"])

    the_df = pd.concat((one_df, erm_ens_df, ttm_ens_df, aaa_ens_df, erm_ttt_df))
    the_df = the_df.drop(["k", "lam"], axis=1)

    main_df = the_df.groupby(["method"]).agg(agg).transpose()
    main_df = main_df[args["methods"]]
    print_latex_table(main_df, "Main results", "table:main")

    # ERM ensemble and TestTime-mixup: ablation wrt k.
    ens_df = select_rows(all_df, method="ERM (ensemble)")
    ens_df = ens_df.drop(["method", "lam"], axis=1)
    ens_df = ens_df.groupby(["k"]).agg(magg).transpose()

    ttm_df = select_rows(all_df, method="TestTime-mixup", lam=args["lam"])
    ttm_df = ttm_df.drop(["method", "lam"], axis=1)
    ttm_df = ttm_df.groupby(["k"]).agg(magg).transpose()

    # The two halves may cover different sets of k; print_latex_table sizes
    # the column format from the concatenated frame.
    print_latex_table(pd.concat((ens_df, ttm_df), axis=1),
                      "Ablation wrt k", "table:k")

    # TestTime-mixup: ablation wrt lam (own caption/label, distinct from the
    # k table so LaTeX references stay unique).
    lam_df = select_rows(all_df, method="TestTime-mixup", k=args["k"])
    lam_df = lam_df.drop(["method", "k"], axis=1)
    lam_df = lam_df.groupby(["lam"]).agg(magg).transpose()
    print_latex_table(lam_df, "Ablation wrt lam", "table:lam")

    if args["plot"]:
        # savefig does not create missing directories.
        os.makedirs("figures", exist_ok=True)
        for split in ("te", "ne"):
            # (metric column prefix, file-name stem, draw legend?)
            for metric, stem, legend in (("avg acc", "avg_acc", False),
                                         ("adv acc", "adv_acc", False),
                                         ("ood acc", "ood_acc", True),
                                         ("ECE", "ece_acc", False)):
                pair_plot(the_df, args["methods"], "mixup loss (te)",
                          "{} ({})".format(metric, split),
                          "figures/figure_scatter_{}_{}.pdf".format(stem, split),
                          legend=legend)
