import math
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
import torch
from palettable.tableau import GreenOrange_12, TableauMedium_10, Tableau_20, ColorBlind_10
from palettable.cartocolors.qualitative import Bold_8, Prism_10, Vivid_10, Safe_10
from palettable.colorbrewer.qualitative import Accent_8
from palettable.cartocolors.sequential import DarkMint_4, Magenta_4, TealGrn_4, PurpOr_4
# --- Colormaps ---------------------------------------------------------------
# 18-entry qualitative map: ColorBlind_10 followed by Bold_8. Calling a
# ListedColormap with an integer (e.g. my_cmap(12)) indexes this list directly.
new_colors = np.vstack((ColorBlind_10.mpl_colors, Bold_8.mpl_colors))
my_cmap = ListedColormap(new_colors, name='BoldBlind')
my_cmap2 = ListedColormap(Tableau_20.mpl_colors)
sequent = np.vstack((TealGrn_4.mpl_colors, Magenta_4.mpl_colors, DarkMint_4.mpl_colors, PurpOr_4.mpl_colors))
my_cmap3 = ListedColormap(Prism_10.mpl_colors)
my_cmap4 = ListedColormap(sequent)

# --- Global style ------------------------------------------------------------
# seaborn's set_style() rewrites font rcParams (it sets font.family to
# sans-serif), so the serif/Times settings must be applied after it.
sns.set_style("whitegrid", {"grid.color": ".8", "grid.linestyle": "--"})
# NOTE: sns.despine() only acts on axes that already exist. Calling it here,
# before any figure is created, merely opened an empty figure via plt.gcf().
# Call sns.despine(fig=...) after creating the axes if spine removal is wanted.
plt.rcParams.update({
    'text.usetex': False,
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'mathtext.fontset': 'stix',
    # The three entries below are only read by matplotlib when
    # mathtext.fontset == 'custom'; with 'stix' (a Times-compatible face)
    # they have no effect.
    'mathtext.rm': 'Times New Roman',
    'mathtext.it': 'Times New Roman:italic',
    'mathtext.bf': 'Times New Roman:bold',
})

## Helper functions for plotting statistics
def round_near_five(x, base=5):
    """Round ``x`` to the nearest multiple of ``base`` (5 by default).

    Uses Python's built-in ``round`` on ``x / base``, so exact ties go to the
    even quotient (e.g. 12.5 -> 10, 17.5 -> 20) rather than always rounding up.
    """
    n_multiples = round(x / base)
    return n_multiples * base

# Earlier version: parse per-iteration losses from training logs (kept for reference).
# models = ["pgd_5step", "apgd_5step"] #, "vit_s_cvst_25ep_final3", "vit_s_cvst_25ep_convstem_high_lr"] #, "conviso_cvblk_300AT", "conviso_300AT"] #, "base_cvblk"] #, "tb10_model4", "tb10_model5", "model1"]
# train_stats = []
# for m in models:
#     print(m)
#     with open(f"/Users/nmndeep/Documents/logs_semseg/{m}_log.txt", 'r') as fp:
#         line_numbers = np.arange(0,50,1)
#         cnvnxt = []
#         for i, line in enumerate(fp):
#             print(line)
#             if i in line_numbers:
#                 i_0 = (line.index("Loss"))
#                 i_1 = (line.index("Cost"))
#                 cnvnxt.append(float(line[i_0+6:i_1-6]))
#         train_stats.append(cnvnxt)
# print(train_stats)

# Figure 1: 2x2 grid. Rows are the attack radius (8/255 on top, 12/255 below);
# columns are robust accuracy (left) and robust mIoU (right).
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(18, 7))

# Shared axis titles, placed in figure-fraction coordinates rather than as
# per-axes labels so each appears once per column / once for the x axis.
title_size = 23
fig.text(0.08, 0.5, 'Robust accuracy (%)', va='center', rotation='vertical', fontsize=title_size)
fig.text(.505, 0.5, 'Robust mIoU (%)', va='center', rotation='vertical', fontsize=title_size)
fig.text(.5, 0.03, 'Number of attack iterations', ha='center', fontsize=title_size)

# Robust accuracy (_ac) and robust mIoU (_mi), in %, after 50/100/300/500
# attack iterations at l_inf radii 8/255 and 12/255.
eps_8_ac = [72.5, 71.9, 71.2, 71.1]
eps_12_ac = [33.4, 29.7, 27.3, 27.1]
eps_8_mi = [38.4, 37.7, 37.0, 36.9]
eps_12_mi = [10.1, 8.8, 8.0, 7.9]

# Iteration counts given as strings, so matplotlib uses a categorical
# (evenly spaced) x axis.
xx1 = ['50', '100', '300', '500']

# Top row: 8/255 (colour 14); bottom row: 12/255 (colour 12). Only the
# left-column lines carry labels, so each radius appears once in the legend.
# Raw strings keep "\ell" / "\infty" from being parsed as (invalid) escapes.
axs[0, 0].plot(xx1, eps_8_ac, linewidth=3, color=my_cmap(14), marker='o',
               markeredgecolor='black', label=r"Attack radius, $\ell_\infty$=8/255")
axs[1, 0].plot(xx1, eps_12_ac, linewidth=3, color=my_cmap(12), marker='o',
               markeredgecolor='black', label=r"Attack radius, $\ell_\infty$=12/255")
axs[0, 1].plot(xx1, eps_8_mi, linewidth=3, color=my_cmap(14), marker='o',
               markeredgecolor='black')
axs[1, 1].plot(xx1, eps_12_mi, linewidth=3, color=my_cmap(12), marker='o',
               markeredgecolor='black')

# Fixed, hand-picked y ticks per panel.
for ax, yticks in (
    (axs[0, 0], np.arange(68, 75, 2)),
    (axs[1, 0], np.arange(24, 37, 3)),
    (axs[0, 1], np.arange(36, 40, 1)),
    (axs[1, 1], np.arange(6, 14, 2)),
):
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticks, fontsize=19)

# The categorical axis already labels its ticks with xx1. Restyle the bottom
# row and hide the top row's labels via tick_params; calling set_xticklabels
# on a non-FixedLocator axis makes matplotlib emit a warning.
for ax in axs[1, :]:
    ax.tick_params(axis='x', labelsize=19)
for ax in axs[0, :]:
    ax.tick_params(axis='x', labelbottom=False)
# One legend per radius, anchored outside the axes in axes-fraction
# coordinates (presumably so both sit above the top row of panels).
axs[0, 0].legend(bbox_to_anchor=(1.1, 1.38), ncol=2,
                 fancybox=True, shadow=False, fontsize=23)
axs[1, 0].legend(bbox_to_anchor=(1.85, 2.58), ncol=2,
                 fancybox=True, shadow=False, fontsize=23)
fig.savefig("/Users/nmndeep/Documents/logs_semseg/large_eps_iterations_acc_miou.pdf", dpi=600)

# Stop here: everything below builds a second, independent figure and is
# unreachable while this exit is in place. sys.exit() is used instead of
# exit(), which is a site-module helper meant for the interactive shell.
sys.exit()
BASE_DIR = '/Users/nmndeep/Documents/logs_semseg/'

# Worst-case result files, one per l_inf radius as encoded in the file name
# (0.0157, 0.0314, 0.0471, 0.0627, i.e. roughly 4, 8, 12, 16 / 255).
worse_comp = [
    'WORST_CASE_5iter_rob_mod_0.0157_n_it_100_pascalvoc_ConvNeXt-T_CVST_ROB.pt',
    'WORST_CASE_5iter_rob_mod_0.0314_n_it_100_pascalvoc_ConvNeXt-T_CVST_ROB.pt',
    'WORST_CASE_5iter_rob_mod_0.0471_n_it_100_pascalvoc_ConvNeXt-T_CVST_ROB.pt',
    'WORST_CASE_5iter_rob_mod_0.0627_n_it_100_pascalvoc_ConvNeXt-T_CVST_ROB.pt',
]

# Short tag derived from the first file name (not used further in this file).
out_str = worse_comp[0][:16] + worse_comp[0][-22:-3]
# x-axis labels for the bar chart. eps_[0] ('0') is a leftover placeholder;
# the remaining five name the loss groups. Raw strings avoid invalid escapes.
eps_ = ['0', r'$\mathcal{L}_{Mask-CE}$', r'$\mathcal{L}_{Bal-CE}$', r'$\mathcal{L}_{JS}$',
        r'$\mathcal{L}_{Mask-Sph.}$', 'Worst case']

# For each radius, count how often each row of final_matrix attains the minimum
# along dim 0 (presumably one row per attack loss -- TODO confirm).
liss = []
for fname in worse_comp:
    # NOTE: torch.load unpickles the file; only load trusted results.
    # map_location='cpu' lets GPU-saved files load on a CPU-only machine
    # (.numpy() below needs CPU tensors anyway).
    vall = torch.load(BASE_DIR + f"worst_case_numbers/{fname}", map_location="cpu")
    final_matrix = vall['final_matrix']
    worst_idx = final_matrix.min(0)[1]
    # bincount(minlength=...) keeps one slot per row even when a row is never
    # the minimum. unique(return_counts=True) silently dropped such rows, which
    # misaligned counts with labels and made the rows ragged for np.asarray.
    final_acc_ = torch.bincount(worst_idx.flatten(), minlength=final_matrix.shape[0])
    liss.append(final_acc_.numpy())
print(liss)

liss = np.asarray(liss)
data = np.transpose(liss)

# Robust values (%) for the bar chart: 5 loss groups x 3 APGD schemes
# (schemes in the order of losses_lis below).
# Swap in these rows to plot robust accuracy instead of mIoU:
# yy_8 = [[74.5, 74.3, 73.7], [74.1, 73.7, 72.9], [75.1, 75.1, 73.8], [80.2, 80.2, 80.1], [72.2, 72.0, 71.3]]
# yy_12 = [[37.7, 34.6, 31.6], [42.6, 38.6, 35.9], [44.2, 42.83, 38.6], [37.9, 36.6, 38.1], [31.1, 29.1, 27.3]]
# mIoU:
yy_8 = [[41.8, 41.1, 40.2], [41.3, 40.3, 39.4], [43.5, 43.3, 41.2], [49.6, 49.2, 49.9], [38.5, 38.0, 37.0]]
yy_12 = [[13.2, 11.3, 10.2], [14.9, 12.8, 11.6], [18.6, 16.1, 15.0], [12.1, 11.0, 11.6], [8.4, 8.3, 8.0]]
# Transpose to (scheme, group): yy_8[k] holds scheme k's value for each group.
yy_8 = np.transpose(yy_8)
yy_12 = np.transpose(yy_12)
# Legend entries, one per scheme. Raw strings keep "\epsilon" from being
# parsed as an (invalid) escape sequence.
losses_lis = [r'APGD const-$\epsilon$, 3 x 100', r'APGD const-$\epsilon$, 300 x 1', r'APGD red-$\epsilon$, 300 x 1']
print(yy_8)

# Generic grouped-bar geometry (the drawing code below appears to use its own
# barWidth instead of these values).
total_width = 0.8  # 0 <= total_width <= 1
d = 0.01  # gap between bars, as a fraction of the bar width
width = total_width/(3+(3-1)*d)
offset = -total_width/50
print(offset)
### Figure 2: grouped bars, 8/255 on the left axes and 12/255 on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 7.2))

plt.subplots_adjust(wspace=0.14, hspace=1.2)

# Width of a single bar, in x-axis data units.
barWidth = 0.5

# r[k] = nominal x positions of scheme k's bars. Groups are 2 units apart and
# consecutive schemes are offset by one bar width.
r = [np.arange(0, len(yy_8[0]) * 2, 2)]
r.append([pos + barWidth for pos in r[0]])
r.append([pos + barWidth for pos in r[1]])

print("rrrr", r[0])

# (colormap index, hatch) per scheme; the last scheme is hatched.
scheme_style = {0: (7, None), 1: (12, None), 2: (14, "/")}
for idx in range(3):
    ix, hatch = scheme_style[idx]
    # Shift every bar left by one width so the middle scheme sits on r[0].
    bar_x = [pos - barWidth for pos in r[idx]]
    ax1.bar(bar_x, yy_8[idx], barWidth, align='center', edgecolor='black',
            color=my_cmap(ix), hatch=hatch)
    # Only the right axes' bars carry labels; its legend covers both panels.
    ax2.bar(bar_x, yy_12[idx], barWidth, align='center', edgecolor='black',
            label=losses_lis[idx], color=my_cmap(ix), hatch=hatch)

# One x tick per loss group. Bars are drawn at r[k] - barWidth, so the middle
# scheme sits on r[0]; the tick is placed a quarter bar to its left.
print(eps_)
ll = [pos - barWidth*.5 for pos in r[0]]
print(ll)
# eps_ begins with a placeholder entry. Once ticks are fixed, matplotlib raises
# ValueError unless there is exactly one label per tick, so use the last
# len(ll) labels (one per loss group).
group_labels = eps_[-len(ll):]
ax1.set_xticks(ll)
ax2.set_xticks(ll)

ax1.set_xticklabels(group_labels, fontsize=24)
labels = ax2.set_xticklabels(group_labels, fontsize=24)
# Attempted stagger of every other label. NOTE(review): matplotlib re-positions
# tick labels on each draw, so this shift likely has no visible effect; confirm.
for i, label in enumerate(labels):
    label.set_x(label.get_position()[0] + (i % 2) * 0.075)

ax1.set_yticks(np.arange(35, 52, 3))
ax2.set_yticks(np.arange(5, 22, 3))

ax1.set_yticklabels(np.arange(35, 52, 3), fontsize=24)
ax2.set_yticklabels(np.arange(5, 22, 3), fontsize=24)
ax1.set_ylim(35, 52)
ax2.set_ylim(5, 22)

# Raw strings keep "\epsilon" / "\infty" from being parsed as escapes.
ax1.set_title(r"$\epsilon_\infty = 8/255$", y=0.95, pad=0, fontsize=25)
ax2.set_title(r"$\epsilon_\infty = 12/255$", y=0.95, pad=0, fontsize=25)
ax2.legend(bbox_to_anchor=(0.67, 1.04), ncol=1,
           fancybox=True, shadow=False, fontsize=26)
ax1.set_ylabel('mIoU (%)', fontsize=25)
fig.text(0.5, 0.008, 'Loss optimized for attack with APGD', ha='center', fontsize=26)

plt.show()
# plt.savefig("/Users/nmndeep/Documents/logs_semseg/miou_over_iteration_schemes.pdf", dpi=800, bbox_inches="tight")
sys.exit()
