from pathlib import Path

import matplotlib.pyplot as plt
from results import load_data, plot_loss_over_l2, make_average_and_std, extract_best_hyperparams


def plot_loss_over_reg(data, path, limits):
    # plotting the different loss metrics on test data
    for loss_type in ["loss", "P@1", "P@3", "P@5", "PSP@1", "PSP@3", "PSP@5", "R@1", "R@3", "R@5"]:
        grid = plot_loss_over_l2(data, loss_type, "noisy-test", height=1.5, aspect=1.75)

        for (row_val, col_val), ax in grid.axes_dict.items():
            lims = limits[(row_val, col_val)]
            ax.set_xlim(xmin=lims[0], xmax=lims[1], emit=True)
            if row_val == "cce" and col_val is False:
                ax.set_title("CCE")
            elif row_val == "cce" and col_val is True:
                ax.set_title("CCE Normalized")
            elif row_val == "bce" and col_val is True:
                ax.set_title("BCE Normalized")
            elif row_val == "bce":
                ax.set_title("BCE")
        grid.savefig(f"{path}-test-{loss_type}-over-reg.png", dpi=300, transparent=False)
        grid.savefig(f"../../plots/{path}/test-{loss_type}.pgf")
        plt.close(grid.fig)

    grid = plot_loss_over_l2(data, "loss", "noisy-train")
    grid.savefig(f"{path}-train-loss-over-reg.png", dpi=300, transparent=False)
    grid.savefig(f"../../plots/{path}/train-loss.pgf")
    plt.close(grid.fig)


# Per-facet x-axis limits, keyed by (loss, normalized) -> (lower, upper).
# Used for every dataset except AmazonCat-13K.
eurlex_limits = {
    ("cce", True):  (1e-4, 1e-1),
    ("bce", True):  (1e-7, 1e-5),
    ("cce", False): (1e-4, 1),
    ("bce", False): (1e-7, 1e-4),
}

# Same layout as above, tuned for the AmazonCat-13K runs.
acat_limits = {
    ("cce", True):  (1e-6, 1e-3),
    ("bce", True):  (1e-11, 1e-7),
    ("cce", False): (1e-5, 1e-2),
    ("bce", False): (1e-11, 1e-7),
}

# Dataset to process. Other datasets this script is used with:
# "amazoncat13k" and "eurlex". Keep exactly one assignment so the active
# choice is unambiguous.
mode = "wiki10"

# Only AmazonCat-13K has its own axis limits; every other dataset
# (currently eurlex and wiki10) reuses the EURLex ones.
# NOTE(review): confirm the EURLex limits are appropriate for wiki10.
axis_limits = acat_limits if mode == "amazoncat13k" else eurlex_limits

data = load_data([f"{mode}.json"])
plot_loss_over_reg(data, mode, axis_limits)

# Select the best hyper-parameters using PSP@3 as the criterion.
best_data = extract_best_hyperparams(data, criterion='PSP@3')

# Aggregate the selected runs (presumably mean/std across repeats — see
# make_average_and_std) and write them as a sorted CSV table.
table_data = make_average_and_std(best_data)
table_data = table_data.sort_values(
    by=["config/loss", "config/mode", "config/data", "config/normalized"]
)
table_data.to_csv(f"{mode}.csv", index=False, float_format='%.15f')
