import matplotlib.pyplot as plt

# Accuracy results in percent (the tick labels below append "%").
# Entry i of each natural_* list pairs with entry i of the matching robust_*
# list; presumably one entry per evaluated checkpoint -- TODO confirm.

# Older result set (natural accuracy around 99%), kept for reference only and
# not plotted. The commented-out xlim/ylim further down fit this data.
#natural_baseline = [99.41, 99.38, 99.35, 99.33, 99.31, 99.28, 99.29, 99.24, 99.16, 99.16]
#robust_baseline = [91.09, 92.18, 93.21, 93.87, 94.32, 94.75, 95.45, 95.57, 95.65, 95.65]
# natural_1 = [99.25, 99.28, 99.28, 99.27, 99.28, 99.34, 99.21, 99.19, 99.13, 99.17]
# robust_1 = [92.16, 93.73, 94.83, 95.12, 95.32, 95.53, 95.94, 96.40, 96.58, 96.76]
# natural_2 = [99.27, 99.29, 99.31, 99.32, 99.28, 99.34, 99.32, 99.34, 99.33, 99.34]
# robust_2 = [92.46, 94.02, 94.79, 95.25, 95.53, 95.72, 96.15, 96.40, 96.40, 96.66]

# Longer variants of the older lists; each begins with the same 10 values as
# the corresponding short list above.
#natural_1 = [99.25, 99.28, 99.28, 99.27, 99.28, 99.34, 99.21, 99.19, 99.13, 99.17, 99.25, 99.23, 99.27, 99.28, 99.28, 99.26, 99.26, 99.21, 99.28, 99.26, 99.27, 99.28, 99.25, 99.29, 99.34, 99.23, 99.29, 99.15, 99.21, 99.10, 99.19, 99.0, 99.13, 98.98, 99.17]
#robust_1 = [92.16, 93.73, 94.83, 95.12, 95.32, 95.53, 95.94, 96.40, 96.58, 96.76, 92.16, 92.16, 91.93, 91.72, 93.73, 93.63, 93.49, 94.69, 94.83, 93.47, 95.12, 95.32, 95.42, 95.52, 95.53, 95.70, 95.75, 95.99, 95.94, 96.42, 96.40, 96.44, 96.58, 96.61, 96.76]
#natural_2 = [99.27, 99.29, 99.31, 99.32, 99.28, 99.34, 99.32, 99.34, 99.33, 99.34, 99.24, 99.32, 99.34, 99.31, 99.31, 99.15, 99.27, 99.28, 99.28, 99.28, 99.34, 99.21, 99.29, 99.26, 99.31, 99.32, 99.30, 99.29, 99.28, 99.34, 99.30, 99.32, 99.34, 99.33, 99.34]
#robust_2 = [92.46, 94.02, 94.79, 95.25, 95.53, 95.72, 96.15, 96.40, 96.40, 96.66, 95.43, 95.70, 95.72, 95.71, 95.61, 92.58, 92.46, 91.86, 91.87, 91.85, 92.24, 93.90, 94.02, 94.64, 94.79, 95.25, 95.24, 95.48, 95.53, 95.72, 95.68, 96.15, 96.40, 96.40, 96.66]

# Current result set, the one that is plotted. The *_baseline pair is drawn
# with the label 'Trades' and the *_1 pair with 'Trades + SNR'.
natural_baseline = [91.31, 89.56, 87.91, 87.50, 87.11, 87.01, 85.22, 83.82, 82.90, 81.72]
robust_baseline = [26.53, 37.71, 41.50, 43.37, 44.17, 44.68, 48.22, 49.67, 50.25, 50.64]
natural_1 = [86.35, 86.79, 87.18, 87.42, 87.09, 86.77, 87.42, 87.57, 87.09, 86.95, 83.37, 85.20, 87.80, 87.32, 87.72, 85.49, 84.89, 84.31, 83.52]
robust_1 = [47.79, 47.82, 47.95, 47.90, 47.91, 48.10, 47.75, 47.82, 47.66, 48.22, 8.72, 29.53, 44.14, 46.46, 47.12, 50.73, 52.24, 53.25, 53.53]


# Re-order the "Trades + SNR" points by natural accuracy, keeping every
# (natural, robust) pair together. Both names are tuples afterwards.
natural_1, robust_1 = (
    tuple(column) for column in zip(*sorted(zip(natural_1, robust_1)))
)
#natural_2, robust_2 = zip(*sorted(zip(natural_2, robust_2)))

# One small, semi-transparent scatter series per method.
plt.scatter(
    natural_baseline, robust_baseline,
    s=10, alpha=0.3, color='green', label='Trades',
)
plt.scatter(
    natural_1, robust_1,
    s=10, alpha=0.3, color='blue', label='Trades + SNR',
)
#plt.scatter(natural_2, robust_2, color='orange', label='AD + SNR', s=10, alpha=0.3)

# Collapse points that share a natural accuracy, keeping the highest robust
# accuracy for each one. Then order the survivors by natural accuracy so they
# can be drawn as a connected line.
# NOTE(review): new_natural_1 / new_robust_1 are reassigned to hard-coded lists
# further down the script before being plotted, so this computed result is not
# what currently ends up in the figure.
temp_dict = {}
for natural, robust in zip(natural_1, robust_1):
    # max(existing, new) only switches to `robust` when it is strictly greater,
    # matching a "replace if >" check; a first-seen key starts from its own value.
    temp_dict[natural] = max(temp_dict.get(natural, robust), robust)

# Keys are unique, so ordering by key alone fully determines the order.
sorted_temp_dict = sorted(temp_dict.items(), key=lambda pair: pair[0])

new_natural_1 = [natural for natural, _ in sorted_temp_dict]
new_robust_1 = [robust for _, robust in sorted_temp_dict]


# Disabled: the same per-natural-accuracy de-duplication as the natural_1 block
# earlier in the script, applied to the natural_2/robust_2 series.
# temp_dict = {}
# new_natural_2 = []
# new_robust_2 = []
#
# for i, item in enumerate(natural_2):
#     if item in temp_dict:
#         if robust_2[i] > temp_dict[item]:
#             temp_dict[item] = robust_2[i]
#     else:
#         temp_dict[item] = robust_2[i]
#
# sorted_temp_dict = sorted(temp_dict.items(), key=lambda x: x[0])
#
# for item in sorted_temp_dict:
#     new_natural_2.append(item[0])
#     new_robust_2.append(item[1])

# Disabled: hand-curated line data for the older (~99% natural accuracy) results.
# new_natural_1 = [98.98, 99.0, 99.1, 99.13, 99.15, 99.17, 99.19, 99.21, 99.23, 99.29, 99.34]
# new_robust_1 = [96.61, 96.44, 96.42, 96.58, 95.99, 96.76, 96.4, 95.94, 95.7, 95.75, 95.53]
# new_natural_2 = [99.15, 99.21, 99.24, 99.28, 99.29, 99.3, 99.31, 99.32, 99.33, 99.34]
# new_robust_2 = [92.58, 93.9, 95.43, 95.53, 95.48, 95.68, 95.71, 96.15, 96.4, 96.66]

# Hand-curated points for the blue "Trades + SNR" line. These overwrite the
# new_natural_1/new_robust_1 computed from natural_1/robust_1 earlier in the
# script. They match that computed result except that the points (83.37, 8.72)
# and (85.20, 29.53) are left out. The reason for dropping them is not recorded
# here; confirm before regenerating the figure from new data.
new_natural_1 = [83.52, 84.31, 84.89, 85.49, 86.35, 86.77, 86.79, 86.95, 87.09, 87.18, 87.32, 87.42, 87.57, 87.72, 87.8]
new_robust_1 = [53.53, 53.25, 52.24, 50.73, 47.79, 48.1, 47.82, 48.22, 47.91, 47.95, 46.46, 47.9, 47.82, 47.12, 44.14]
# Draw each method's trade-off line over its scatter points.
plt.plot(natural_baseline, robust_baseline, color='green')
plt.plot(new_natural_1, new_robust_1, color='blue')
#plt.plot(new_natural_2, new_robust_2, color='orange')

plt.xlabel('natural_acc')
plt.ylabel('robust_acc')

# View window for the current result set. The commented-out limits and xticks
# line fit the older data, whose natural accuracy is around 99%.
#plt.xlim(99, 99.5)
#plt.ylim(91, 100)
plt.xlim(81, 92)
plt.ylim(25, 55)
#plt.xticks([x / 10 for x in range(990, 995, 1)])

# Render tick values as percentages through a formatter. The previous
# set_xticklabels(get_xticks()) pattern installed fixed label strings while the
# tick positions stayed automatic. Matplotlib warns about that ("FixedFormatter
# should only be used together with FixedLocator"), and the labels land on the
# wrong ticks if the locator picks different positions at draw time.
# A format string given to set_major_formatter is wrapped in a
# StrMethodFormatter, whose value field is named "x" (matplotlib >= 3.3).
plt.gca().xaxis.set_major_formatter('{x:.1f}%')
plt.gca().yaxis.set_major_formatter('{x:.0f}%')

plt.legend()

plt.savefig('/home/verification/Unsupervised-Robust-Learning/TRADES/result_cifar.png')
