
import pickle
import matplotlib.pyplot as plt
from numpy import linspace
import numpy as np
# Input directories holding the pickled NC statistics, one per figure column
# (same order as `title`). Each entry is prepended verbatim to the file names,
# so a non-empty value should end with a path separator; '' means the current
# working directory.
# NOTE: `dir` shadows the Python builtin of the same name; it is kept because
# the plotting code below indexes it under this name.
dir = [''] * 4

# Column titles, one per model/dataset combination.
title = [
    'VGG on CIFAR10',
    'VGG on FMNIST',
    'ResNet18 on CIFAR10',
    'ResNet18 on FMNIST',
]

# Output directory for the finished PDF (also prepended verbatim).
dir_local = ''

# x-axis grid: every other epoch count (c[1], c[3], ..., c[499], i.e.
# 2, 4, ..., 500) mapped through log(log(.)). Note log(log(1)) would be -inf,
# which is presumably why the grid does not start at 1.
c = range(1, 500 + 1, 1)
epochs = [np.log(np.log(c[k])) for k in range(1, len(c) + 1, 2)]

# Tick marks: x positions of grid points 0, 5, 50 and 249, with their labels.
# NOTE(review): those grid points are epoch counts 2, 12, 102 and 500, so the
# first three labels are approximate — confirm this is the intended labelling.
index = [0, 5, 50, 249]
E = [epochs[k] for k in index]
F = ['$1$', '$10$', '$100$', '$500$']

# Light-grey reference lines drawn in every panel before the data curves, so
# the curves are painted on top of them.
_GRID_COLOR = '#ededed'
_GRID_XS = (0, 0.5, 1, 1.5, 2)

# One entry per figure row (top to bottom). For each row:
#   curves  -- (file name, colour, legend label) for every series; the file is
#              read from each column's input directory
#   hlines  -- y positions of the horizontal reference lines
#   xlabel  -- (text, extra keyword arguments for set_xlabel)
#   ylabel  -- label shown on the first column only
#   ylim    -- y-axis limits shared by every panel in the row
_PANEL_ROWS = [
    {  # norms
        'curves': (('w_norm.txt', '#2E2300', 'Classifier'),
                   ('h_norm.txt', '#6E6702', 'Last layer feature')),
        'hlines': (0.2, 0.05, 0.0125),
        'xlabel': ('Epoch', {}),
        'ylabel': 'Avg Relative Std',
        'ylim': (0.008, 0.4),
    },
    {  # cosines
        'curves': (('w_cos.txt', '#2E2300', 'Classifier'),
                   ('h_cos.txt', '#6E6702', 'Last layer feature')),
        'hlines': (0.15 * 5, 0.05 * 5, 0.05 * 5 / 3),
        'xlabel': ('Epoch', {}),
        'ylabel': 'Avg cosine',
        'ylim': (0.04, 1.1),
    },
    {  # duality
        'curves': (('dual.txt', '#C05805', 'Dual'),),
        'hlines': (0.2 * 4, 0.1 * 4, 0.05 * 4),
        'xlabel': ('Epoch', {}),
        'ylabel': 'Avg difference',
        'ylim': (0.1, 1.2),
    },
    {  # within-class
        # NOTE(review): this label matches the feature curve of the norms row,
        # so the figure legend lists 'Last layer feature' twice (different
        # colours); the ylabel also repeats the duality row's. Confirm intended.
        'curves': (('wi.txt', '#DB9501', 'Last layer feature'),),
        'hlines': (0.2 * 5, 0.12 * 5, 0.072 * 5),
        'xlabel': ('Epochs', {'fontsize': 12}),
        'ylabel': 'Avg difference',
        'ylim': (0.2, 1.2),
    },
]


def _load_pickle(path):
    """Return the object unpickled from *path*.

    pickle.load can execute arbitrary code, so only point this at trusted,
    locally produced files.
    """
    with open(path, "rb") as fp:
        return pickle.load(fp)


fig_norm, axs = plt.subplots(
    4, 4, figsize=[3 * 4, 2 * 4],
    gridspec_kw={'wspace': 0, 'hspace': 0, 'top': 0.88, 'bottom': 0.1})

for row, spec in enumerate(_PANEL_ROWS):
    for col, data_dir in enumerate(dir):
        ax = axs[row][col]
        series = [_load_pickle(data_dir + fname)
                  for fname, _, _ in spec['curves']]

        for y in spec['hlines']:
            ax.axhline(y, color=_GRID_COLOR, ls='-')
        for x in _GRID_XS:
            ax.axvline(x, color=_GRID_COLOR, ls='-')

        for values, (_, color, label) in zip(series, spec['curves']):
            # Trim to the x grid so data and x values always have equal length.
            ax.plot(epochs, values[:len(epochs)], color=color, label=label,
                    linewidth=2)

        if row == 0:
            ax.set_title(title[col], fontsize=15)
        xlabel, xlabel_kw = spec['xlabel']
        ax.set_xlabel(xlabel, **xlabel_kw)
        ax.set_yscale('log')
        if col == 0:
            ax.set_ylabel(spec['ylabel'], fontsize=12)
        else:
            # Every panel in a row gets the same ylim, so only the first
            # column shows the y-axis.
            ax.get_yaxis().set_visible(False)
        ax.set_xticks(E)
        ax.set_xticklabels(F)
        ax.set_ylim(*spec['ylim'])

# Build a single figure-level legend from the first-column panels of rows 0, 2
# and 3 (row 1 plots curves with the same labels as row 0).
lines, labels = [], []
for legend_ax in (axs[0][0], axs[2][0], axs[3][0]):
    ax_lines, ax_labels = legend_ax.get_legend_handles_labels()
    lines.extend(ax_lines)
    labels.extend(ax_labels)
# One legend column per entry keeps the legend on a single row above the
# grid; floor at 1 because matplotlib rejects ncol=0.
fig_norm.legend(lines, labels, frameon=False, loc=(0.1, 0.95),
                ncol=max(1, len(labels)), fontsize=15)
fig_norm.show()
# save the figure
fig_norm.savefig(dir_local + "nn-plot-new.pdf")





