from nesim.utils.json_stuff import load_json_as_dict
import matplotlib.pyplot as plt
import os

# Evaluation results produced elsewhere; the loop below indexes it as
# results[model][attack][split]["accuracy_top_1"] and
# results[model][attack]["attack"]["info"] (schema assumed from that usage).
results = load_json_as_dict("./results.json")

# Destination directory for the generated per-attack bar charts.
plots_folder = "./figures"

# Attacks to plot, one figure each. "Pixle" is intentionally left out.
attack_names = ["FGSM", "PGD", "OnePixel", "Square", "Jitter"]
# Display labels for each model key in ``results`` and for each evaluation
# split. Dict insertion order fixes the left-to-right bar order in every
# figure: ours, pretrained, random weights; original before adversarial.
model_labels = {
    "ours": "ours",
    "pretrained": "pretrained",
    "random_weights": "random weights",
}
split_labels = {
    "original": "original",
    "adversarial": "adv.",
}

# fig.savefig raises if the output directory does not exist yet.
os.makedirs(plots_folder, exist_ok=True)

for method_name in attack_names:
    # One bar per (model, split): top-1 accuracy on the original inputs and
    # on the adversarially perturbed inputs for this attack.
    plot_this = {
        f"{model_label}\n{split_label}": (
            results[model_key][method_name][split_key]["accuracy_top_1"]
        )
        for model_key, model_label in model_labels.items()
        for split_key, split_label in split_labels.items()
    }

    fig, ax = plt.subplots()
    # Title with the description of the attack actually being plotted,
    # not a fixed attack.
    ax.set_title(
        f"Adversarial Robustness\nMethod: {results['ours'][method_name]['attack']['info']}"
    )
    ax.bar(list(plot_this.keys()), list(plot_this.values()))
    ax.grid()
    ax.set_ylabel('Top 1 Val acc.')

    filename = os.path.join(plots_folder, f"{method_name}.png")
    fig.savefig(filename)
    # Release the figure; otherwise each iteration leaves one open in pyplot.
    plt.close(fig)
    print(filename)