import pickle
import numpy as np
import matplotlib.pyplot as plt
import argparse

# Defense methods plotted when "all" is requested via -m/--methods (or by default).
ALL_METHODS = ["GAT", "TRADES", "Nu_AT", "MART", "PGD_AT_p", "RFGSM_AT_p", "Ours", "Ours_AdvGAN"]

parser = argparse.ArgumentParser(
    description='Scatter-plot robust vs. natural accuracy for each defense method.')
parser.add_argument('-d', '--data', default='cifar10', type=str,
                    help='dataset name (used in the results/plot file names)')
parser.add_argument('-m', '--methods', default=['all'], type=str, nargs='+',
                    help='defense methods to plot, or "all" for every known method')
parser.add_argument('-a', '--attack', default='auto-att', type=str,
                    help='attack name (used in the results/plot file names)')
args = parser.parse_args()

# With nargs='+', an explicit "-m all" arrives as the list ['all'], never the
# bare string 'all', so check membership rather than comparing to a string.
if 'all' in args.methods:
    args.methods = list(ALL_METHODS)

# Directory holding both the pickled results (input) and the rendered plots (output).
folder = 'plot_results'
results_path = f'{folder}/plot_results_{args.data}_{args.attack}.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only load result files
# produced by your own pipeline, never ones from an untrusted source.
with open(results_path, 'rb') as results_file:
    plot_results = pickle.load(results_file)

plt.figure(figsize=(7,7))

# One scatter series per requested defense: x = robust accuracy, y = natural accuracy.
# presumably each entry holds parallel sequences of equal length — verify against
# the script that writes the pickle.
for defense in args.methods:
    defense_results = plot_results[defense]
    plt.scatter(defense_results["robust_accs"], defense_results["natural_accs"],
                marker='o', label=defense)

plt.legend()
plt.title(f'{args.data} {args.attack}')
plt.xlabel('Robust Accuracy')
plt.ylabel('Natural Accuracy')
# Save BEFORE show(): when a blocking show() returns, the figure has been closed,
# so a later savefig() would write an empty image.
plt.savefig(f'{folder}/plot_{args.data}_{args.attack}.png')
plt.show()