import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats

'''
Progressive sharpening
'''

# lrs = [0.08, 0.04, 0.02]

# def label_entropy(smoothing):
#     vec = [1.] + [0. for i in range(9)]
#     vec = np.array(vec)
#     uniform = np.array([0.1 for i in range(10)])
#     combination = (1.-smoothing)*vec + smoothing*uniform
#     return stats.entropy(combination)

# def plot(x, y, xticks, ylabel, fname):
#     ym = y.mean(axis = 0)
#     ystd = y.std(axis = 0)
#     fig, ax = plt.subplots()
#     if ylabel == 'Loss':
#         ym[0] = ym[0] - label_entropy(0.0)
#         ym[1] = ym[1] - label_entropy(0.5)
#         ym[2] = ym[2] - label_entropy(0.75)
#     ax.plot(x, ym[0], label = '0.0', linestyle = 'solid', linewidth = 2.0, color = 'blue')
#     ax.fill_between(x, ym[0]-ystd[0], ym[0]+ystd[0], alpha = 0.3)
#     ax.plot(x, ym[1], label = '0.5', linestyle = 'dashed', linewidth = 2.0, color = 'orange')
#     ax.fill_between(x, ym[1]-ystd[1], ym[1]+ystd[1], alpha = 0.3)
#     ax.plot(x, ym[2], label = '0.75', linestyle = 'dotted', linewidth = 2.0, color = 'green')
#     ax.fill_between(x, ym[2]-ystd[2], ym[2]+ystd[2], alpha = 0.3)

#     if ylabel == 'Loss':
#         ax.legend(fontsize = 20)
#     # ax.set_yscale('log')
#     # if ylabel == 'Jacobian norm':
#     #     plt.yscale('log')
#     plt.xticks(xticks, fontsize = 20)
#     plt.xlabel('Iteration', fontsize = 20)
#     plt.ylabel(ylabel, fontsize = 20)
#     plt.yticks(fontsize = 20)
#     fig.tight_layout()
#     plt.savefig(fname)
#     plt.show()

# for lr in lrs:
#     path = f'./fullbatch/vgg/{lr:g}/'

#     names = ['hess', 'loss', 'j1t', 'j1e', 'j1ts', 'j1es']
#     for name in names:
#         arrays = [np.load(path+name+f'_{i}.npy') for i in range(5)]
#         array = np.stack(arrays, axis = 0)
#         np.save(path+name+'.npy', array)

#     h = np.load(path+'hess.npy')
#     l = np.load(path+'loss.npy')
#     j1t = np.load(path+'j1t.npy')
#     j1e = np.load(path+'j1e.npy')
#     j1ts = np.load(path+'j1ts.npy')
#     j1es = np.load(path+'j1es.npy')

#     t1 = np.linspace(0, l.shape[2], l.shape[2])
#     t2 = np.linspace(0, l.shape[2], h.shape[2])
#     xticks = np.linspace(0, l.shape[2], 5)

#     plot(t1, l, xticks, 'Loss', path+'loss.pdf')
#     plot(t2, h, xticks, 'Sharpness', path+'hess.pdf')
#     plot(t2, j1t, xticks, 'Jacobian norm', path+'jac.pdf')
#     plot(t2, j1ts, xticks, 'Jacobian norm', path+'jacsoft.pdf')

######################################################################################################################################

'''
Progressive sharpening with SGD
'''
# momentum = 0

# path = f'./vgg_sgd/{momentum:g}/'

# names = ['hess', 'j1t', 'j1e', 'j1ts', 'j1es', 'outnorm']

# trials = [0, 1, 2, 3, 4]

# smoothings = [0.0, 0.5, 0.75]

# for name in names:
#     trialarrays = []
#     for trial in trials:
#         smoothingarrays = [np.load(path+name+f'_{sm:g}_{trial}.npy') for sm in smoothings]
#         smoothingarray = np.stack(smoothingarrays)
#         trialarrays.append(smoothingarray)
#     array = np.stack(trialarrays)
#     print(array.shape)
#     np.save(path+name+'.npy', array)
    
# names = ['hess', 'j1es', 'j1e', 'outnorm']

# def plot(x, y, xticks, ylabel, fname, xaxis):
#     ym = y.mean(axis = 0)
#     ystd = y.std(axis = 0)
#     fig, ax = plt.subplots()
#     ax.plot(x, ym[0], label = '0.0', linestyle = 'solid', linewidth = 2.0)
#     ax.fill_between(x, ym[0]-ystd[0], ym[0]+ystd[0], alpha = 0.3)
#     ax.plot(x, ym[1], label = '0.5', linestyle = 'dashed', linewidth = 2.0)
#     ax.fill_between(x, ym[1]-ystd[1], ym[1]+ystd[1], alpha = 0.3)
#     ax.plot(x, ym[2], label = '0.75', linestyle = 'dotted', linewidth = 2.0)
#     ax.fill_between(x, ym[2]-ystd[2], ym[2]+ystd[2], alpha = 0.3)

#     if ylabel == 'Hessian norm':
#         ax.legend(fontsize = 20)
#     # ax.set_yscale('log')
#     plt.xticks(xticks, fontsize = 20)
#     plt.xlabel(xaxis, fontsize = 20)
#     plt.ylabel(ylabel, fontsize = 20)
#     plt.yticks(fontsize = 20)
#     fig.tight_layout()
#     # plt.savefig(fname)
#     plt.show()
    
# def name_conversion(name):
#     if name == 'hess':
#         return 'Hessian norm'
#     elif name == 'j1es':
#         return 'Softmax Jacobian norm'
#     elif name == 'j1e':
#         return 'Jacobian norm'
#     elif name == 'outnorm':
#         return 'Output norm'

# for name in names:
#     x = np.linspace(0, 390, 78)
#     plot(x, np.load(path+name+'.npy')[:,:,:78], np.linspace(0, 390, 4), name_conversion(name), path+name+'_first_epoch.pdf', 'Iterations')

#     x = np.linspace(0, 90, 46)
#     array = np.load(path+name+'.npy')[:,:,[0] + [i for i in range(78, 123)]]
#     plot(x, array, np.linspace(0, 90, 4), name_conversion(name), path+name+'_all_epochs.pdf', 'Epochs')






#######################################################################################################################################

'''
Batch size vs learning rate
'''

# Directory that holds the result arrays for this batch-size / learning-rate sweep.
path = './batch_lr/batch_lr/cifar10/vgg11/0.0001/'

# Every recorded quantity; consumed only by the disabled stacking step below.
names = [
    'hess',
    'jac1eval', 'jac1train',
    'jac2eval', 'jac2train',
    'trainacc', 'testacc',
    'trainloss', 'testloss',
]

# One-off preprocessing (disabled): stack the five per-trial files into one array.
# for name in names:
#     arrays = [np.load(path+name+f'_{i}.npy') for i in range(5)]
#     array = np.stack(arrays, axis = 0)
#     np.save(path+name+'.npy', array)

# Absolute difference between train and test loss, stored as its own array.
train_loss = np.load(path + 'trainloss.npy')
test_loss = np.load(path + 'testloss.npy')
gap = np.abs(train_loss - test_loss)
np.save(path + 'gap.npy', gap)

# NOTE(review): the loop below overwrites each Jacobian file with its square
# root, so running this script more than once takes the root repeatedly.
# Presumably the stored values are squared norms and this is meant to run
# exactly once after stacking -- confirm before re-running.
names = ['jac1eval', 'jac1train', 'jac2eval', 'jac2train']
for name in names:
    array = np.sqrt(np.load(path + name + '.npy'))
    np.save(path + name + '.npy', array)

# Sweep values; used as tick labels, legend entries and file-name suffixes.
lrs = np.array([0.1, 0.01, 0.001])
bs = np.array([64, 128, 256, 512, 1024])
# bs_down = np.array([64, 256, 1024])

def name_conversion(name):
    """Translate an internal metric key into the title shown on its figure.

    Known keys are 'hess', 'jac1eval', 'gap' and 'trainloss'.  Any other
    value yields None.
    """
    titles = (
        ('hess', 'Sharpness'),
        ('jac1eval', 'Jacobian norm'),
        ('gap', 'Gen. gap (loss)'),
        ('trainloss', 'Train loss (excl. reg.)'),
    )
    for key, title in titles:
        if name == key:
            return title
    return None

def _scatter_trials(x, samples, tick_labels, xlabel, title, fname):
    """Draw, save and show one figure of per-trial points with mean and std.

    Args:
        x: 1-D array of x positions, one per column of ``samples``.
        samples: 2-D array; axis 0 is averaged over (one row per trial).
        tick_labels: labels placed at the positions in ``x``.
        xlabel: x-axis label.
        title: figure title.
        fname: path the figure is saved to.
    """
    # A fresh figure per call: without it, backends whose show() does not
    # block/destroy the figure keep stacking artists from earlier calls
    # onto the same axes, so every saved file after the first is polluted.
    fig = plt.figure()
    mean = samples.mean(axis=0)
    plt.scatter(x, mean)
    plt.scatter(np.tile(x, (samples.shape[0], 1)), samples, alpha=0.4)
    plt.errorbar(x, mean, yerr=samples.std(axis=0))
    plt.xticks(x, tick_labels, fontsize=20)
    plt.xlabel(xlabel, fontsize=20)
    plt.yticks(fontsize=20)
    plt.title(title, fontsize=20)
    plt.tight_layout()
    plt.savefig(fname)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def plot(mode, name):
    """Save one figure per sweep setting for the quantity ``name``.

    Loads ``path + name + '.npy'``.  The array is indexed as
    ``y[trial, batch-size index, learning-rate index]`` (as implied by how
    it is sliced against ``bs`` and ``lrs``).

    Args:
        mode: ``'lr'`` plots the quantity against learning rate, one figure
            per batch size in ``bs``; ``'batch'`` plots it against batch
            size, one figure per learning rate in ``lrs``.
        name: metric key; also used to build the input and output file names
            and, via ``name_conversion``, the title.

    Raises:
        ValueError: if ``mode`` is neither ``'lr'`` nor ``'batch'``
            (previously this was a silent no-op).
    """
    if mode not in ('lr', 'batch'):
        raise ValueError(f"unknown mode {mode!r}; expected 'lr' or 'batch'")

    y = np.load(path + name + '.npy')
    title = name_conversion(name)

    if mode == 'lr':
        # Evenly spaced positions; the actual learning rates are tick labels.
        x = np.arange(len(lrs), dtype=float)
        for i, batch_size in enumerate(bs):
            _scatter_trials(x, y[:, i, :], lrs, 'Learning rate', title,
                            path + name + f'_vs_lr_{batch_size}.pdf')
    else:
        x = np.arange(len(bs), dtype=float)
        for i, lr in enumerate(lrs):
            _scatter_trials(x, y[:, :, i], bs, 'Batch size', title,
                            path + name + f'_vs_bs_{lr}.pdf')

# Quantities that get a figure, in the order they are rendered.
names = [
    'hess',
    'jac1eval',
    'gap',
    'trainloss',
]

# Sweep directions accepted by plot() and plot_together().
modes = [
    'batch',
    'lr',
]

# Per-setting figures (disabled):
# for name in names:
#     for mode in modes:
#         plot(mode, name)

def _draw_series(x, samples, label, colour, linestyle, marker, markersize):
    """Overlay one series on the current axes.

    Draws the mean over axis 0 of ``samples`` as a marked line, a band of
    one standard deviation either side of it, and every individual row as
    faint points (excluded from the legend).
    """
    mean = samples.mean(axis=0)
    std = samples.std(axis=0)
    plt.plot(x, mean, label=label, color=colour, linewidth=1.0,
             linestyle=linestyle, markersize=7, marker=marker,
             markeredgecolor='white')
    plt.fill_between(x, mean - std, mean + std, alpha=0.3, color=colour)
    plt.scatter(np.tile(x, (samples.shape[0], 1)), samples, alpha=0.3,
                label='_nolegend_', color=colour, marker=marker, s=markersize)


def plot_together(mode, name):
    """Draw all series for ``name`` on one small figure, save and show it.

    Loads ``path + name + '.npy'``, indexed as
    ``y[trial, batch-size index, learning-rate index]`` (as implied by how
    it is sliced against ``bs`` and ``lrs``), and writes
    ``path + name + mode + '.pdf'``.

    Args:
        mode: ``'lr'`` puts learning rate on the x axis with one series for
            every other batch size (indices 0, 2, 4, ...); ``'batch'`` puts
            batch size on the x axis with one series per learning rate.
        name: metric key; used for the file names and, via
            ``name_conversion``, the title.

    Raises:
        ValueError: if ``mode`` is neither ``'lr'`` nor ``'batch'``
            (previously an empty figure was saved).
    """
    if mode not in ('lr', 'batch'):
        raise ValueError(f"unknown mode {mode!r}; expected 'lr' or 'batch'")

    y = np.load(path + name + '.npy')

    if mode == 'lr':
        tick_labels = lrs
        # Only every other batch size is drawn.
        series = [(y[:, j, :], f'{bs[j]:g}') for j in range(0, len(bs), 2)]
    else:
        tick_labels = bs
        series = [(y[:, :, j], f'{lrs[j]:g}') for j in range(len(lrs))]

    # Evenly spaced positions; the real sweep values are shown as tick labels.
    x = np.arange(len(tick_labels), dtype=float)
    linestyles = ['solid', 'dashed', 'dotted']
    markers = ['P', 'o', 'X']
    markersize = 50

    fig = plt.figure(figsize=(3, 2.5))
    for k, (samples, label) in enumerate(series):
        _draw_series(x, samples, label, f'C{k}',
                     linestyles[k % len(linestyles)],
                     markers[k % len(markers)], markersize)

    # Axis decoration does not depend on the series, so it is applied once.
    plt.xticks(x, tick_labels)
    plt.grid(True, alpha=0.3)
    plt.title(name_conversion(name))
    plt.tight_layout()
    # Only the train-loss figure carries a legend.
    # NOTE(review): presumably the figures are shown side by side and share
    # it -- confirm.
    if name == 'trainloss':
        plt.legend()
    plt.savefig(path + name + f'{mode}.pdf')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Render the combined batch-size figure for every selected quantity.
for name in names:
    plot_together(mode='batch', name=name)
    

###########################################################################################################################################################

'''
Small network fullbatch
'''

# path = './small_network/tanh/'

# names = ['feat', 'hess', 'jac', 'loss']
# scalings = [0.5, 1.0, 1.5]

# for name in names:
#     arr = []
#     for scaling in scalings:
#         arrays = [np.load(path+f'{name}_{scaling:g}_{i}.npy') for i in range(5)]
#         array = np.stack(arrays)
#         arr.append(array)
#     a = np.stack(arr)
#     np.save(path+name+'.npy', a)

# # j = np.load(path+'jac.npy')
# # j = np.sqrt(j)
# # np.save(path+'jac.npy', j)

# def name_conversion(name):
#     if name == 'hess':
#         return 'Sharpness'
#     elif name == 'jac':
#         return 'Jacobian norm'
#     elif name == 'loss':
#         return 'Loss'

# for name in names:
#     x = np.linspace(0, 300, 300)
#     y = np.load(path+name+'.npy')[:,:,:300]
#     if name == 'feat':
#         for i in range(2):
#             plt.figure(figsize = (3, 2.5))
#             for j in range(len(scalings)):
#                 plt.plot(x, y[j,:,:,i].mean(axis = 0), label = f'{scalings[j]}')
#                 plt.fill_between(x, y[j,:,:,i].mean(axis = 0) - y[j,:,:,i].std(axis = 0), y[j,:,:,i].mean(axis = 0) + y[j,:,:,i].std(axis = 0), alpha = 0.3)
#             plt.ylabel(f'Layer {i+1} feature norm')
#             plt.xlabel('Iterations')
#             plt.tight_layout()
#             plt.savefig(path+name+f'_{i}.pdf')
#             plt.show()
#     else:
#         plt.figure(figsize = (3,2.5))
#         for j in range(len(scalings)):
#             plt.plot(x, y[j,:,:].mean(axis = 0), label = f'{scalings[j]}')
#             plt.fill_between(x, y[j,:,:].mean(axis = 0) - y[j,:,:].std(axis = 0), y[j,:,:].mean(axis = 0) + y[j,:,:].std(axis = 0), alpha = 0.3)
#         if name == 'hess':
#             plt.legend()
#         plt.ylabel(name_conversion(name))
#         plt.xlabel('Iterations')
#         plt.tight_layout()
#         plt.savefig(path+name+'.pdf')
#         plt.show()

