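'''
Scatter-plot the minimum attack margin found by our method (y-axis) against
the margin found by each baseline attack method (x-axis), one series per
method, and save the figure as "<model_name>.png".
'''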

import matplotlib.pyplot as plt
import torch
import numpy as np

# Model whose results are plotted. Each hard-coded entry below is
# [sample index, label, margin found by our method]; whether "label" is the
# true class or the attack target is not visible here -- TODO confirm.
model_name = "cnn_4_layer"


def merge_min_margin(*runs):
    """Merge several runs of [index, label, margin] entries.

    For every sample index the entry with the smallest margin is kept (ties
    keep the entry seen first). The result preserves the first-seen order of
    the indices and keeps the 3-element [index, label, margin] shape used by
    the other models, so downstream code can unpack entries uniformly.
    """
    best = {}
    for run in runs:
        for idx, label, margin in run:
            if idx not in best or margin < best[idx][2]:
                best[idx] = [idx, label, margin]
    return list(best.values())


if model_name == "cnn_4_layer":
    data_list = [[1150, 5, 0.03968451], [8178, 9, 0.04562723], [6991, 3, 0.04517244], [9874, 3, 0.03447832], [2152, 3, 0.07273999], [5995, 3, 0.03525582], [1268, 7, 0.03721436], [1285, 3, 0.03270999], [3577, 1, 0.01615015], [5602, 7, 0.02632373], [9723, 7, 0.0380844], [758, 4, 0.03523076], [2460, 9, 0.02390697], [7587, 8, 0.0285167], [1599, 9, 0.02867397], [2929, 7, 0.00579698], [4899, 8, 0.03676112], [5986, 9, 0.05513051], [8006, 7, 0.02189252], [1291, 5, 0.02397713], [9155, 8, 0.0140906], [212, 4, 0.03108337], [124, 4, 0.03095981], [395, 8, 0.03789664], [3029, 9, 0.05042984], [8928, 7, 0.00312109], [3498, 9, 0.05272638], [7070, 7, 0.01824653], [4088, 4, 0.00572934], [451, 2, 0.00566403], [7178, 2, 0.03809859], [3795, 2, 0.05580203], [7029, 3, 0.00585325]]
elif model_name == "madrycnn_no_maxpool_tiny":
    data_list = [[8416, 8, 0.539685297], [2425, 3, 1.34910693], [3153, 5, 0.108600129], [952, 2, 0.06509659500000001], [1454, 5, 0.959299821], [4439, 4, 0.186001812], [2000, 8, 1.407823281], [9921, 2, 0.050690466], [6668, 5, 0.1747431], [66, 4, 1.545566661], [4126, 1, 0.195029055], [2479, 7, 0.27865067400000004], [8759, 3, 0.6096991230000001], [9720, 5, 0.016555356], [3068, 3, 0.298558566], [3302, 4, 0.057713246999999995], [6596, 2, 0.097022556], [5236, 6, 0.32228621100000004], [5736, 2, 0.29543328900000004], [9210, 8, 0.087765993], [1775, 5, 0.245375352], [3846, 0, 0.93976407]]
elif model_name == "cnn_4layer_b":
    # Results of three separate runs; per sample keep the smallest margin.
    d1 = [[7648, 8, 0.00484309], [1759, 3, 0.0089553], [233, 2, 0.00800648], [6358, 2, 0.00522499], [9796, 8, 0.00794274], [9462, 4, 0.00542963], [4173, 0, 0.00279543], [455, 5, 0.00972146], [4267, 4, 0.00068247], [314, 2, 0.00941603], [432, 5, 0.00291255], [4025, 4, 0.00441736]]
    d2 = [[1502, 6, 0.03996308], [1303, 2, 0.04355151], [3666, 3, 0.04413575], [4863, 9, 0.03844466], [5379, 2, 0.04220087], [7059, 6, 0.04229824], [1499, 0, 0.00861354], [8501, 0, 0.03784466], [5649, 6, 0.0411936]]
    d3 = [[746, 2, 0.0547759], [2371, 8, 0.05769562], [645, 9, 0.05898961], [8023, 5, 0.05782123], [8134, 2, 0.05210367], [2972, 4, 0.05391872], [7262, 3, 0.05810022], [8181, 2, 0.05981432], [1897, 0, 0.05125416]]
    data_list = merge_min_margin(d1, d2, d3)
else:
    # Fail here rather than with a NameError on data_list further down.
    raise ValueError("unknown model_name: {!r}".format(model_name))

def sorted_margin_rows(attack_margin):
    """Return *attack_margin* as a NumPy array with every row sorted ascending.

    ``torch.Tensor.sort`` is not in-place (it returns a new values/indices
    pair), whereas ``ndarray.sort`` is, so both input types are converted to
    a NumPy array and sorted explicitly along axis 1.
    """
    if isinstance(attack_margin, torch.Tensor):
        attack_margin = attack_margin.detach().cpu().numpy()
    return np.sort(np.asarray(attack_margin), axis=1)


# Both axes are fixed to [0, AXIS_LIMIT]. NOTE(review): some hard-coded margins
# (e.g. for madrycnn_no_maxpool_tiny) exceed 1.0 and fall outside the view.
AXIS_LIMIT = 1.0
METHODS = ["pgd_attack", "diversed_pgd_attack", "new_auto_attack"]

plt.ylabel("minimum attack margin (Ours)", fontsize=17)
plt.xlabel("minimum attack margin (PGD, Diversed PGD, AutoAttack)", fontsize=17)
plt.xlim(0, AXIS_LIMIT)
plt.ylim(0, AXIS_LIMIT)

# Entries are indexed positionally (index first, our margin last) so this
# works whether or not an entry carries a label column.
sample_indices = [entry[0] for entry in data_list]
our_margins = [entry[-1] for entry in data_list]

for method in METHODS:
    # NOTE: torch >= 2.6 defaults torch.load(weights_only=True); if these files
    # hold pickled NumPy arrays, pass weights_only=False (trusted local files only).
    attack_margin = sorted_margin_rows(
        torch.load("my_data/{}_{}_margin.torch".format(model_name, method))
    )
    # Column 1 = second-smallest margin of each sample's sorted row; presumably
    # column 0 is the (zero) margin of the class against itself -- TODO confirm.
    margin_list = [attack_margin[idx][1] for idx in sample_indices]
    plt.scatter(margin_list, our_margins, s=20, label=method)

# y = x reference: points below the line are samples where our method found a
# smaller margin than the baseline attack.
plt.plot([0, 5.0], [0, 5.0], label="y=x", linewidth=0.5, color='grey')
plt.legend(fontsize=15)
plt.savefig("{}.png".format(model_name))
