import numpy as np

import tensorflow as tf

from collections import Counter

                                       
import matplotlib
import matplotlib.pylab as pl
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow.keras.datasets as kdds
from mnist_experiment_run import keras_image_format_to_std

import pickle as pkl



def norm21(res_weights):
    """Return the sum of the Euclidean norms taken along axis 0.

    The squared entries are summed over axis 0, the square root of each
    partial sum is taken, and the resulting values are added up.  For a
    2-D matrix this is the sum of the column 2-norms (the (2,1) norm).
    """
    squared_sums = (res_weights ** 2).sum(axis=0)
    norms_along_axis0 = np.sqrt(squared_sums)
    return norms_along_axis0.sum()

def norm2infty(res_weights):
    """Return the largest Euclidean norm taken along axis 0.

    The squared entries are summed over axis 0, the square root of each
    partial sum is taken, and the maximum of those values is returned.
    For a 2-D matrix this is the largest column 2-norm (the (2,inf) norm).
    """
    squared_sums = (res_weights ** 2).sum(axis=0)
    norms_along_axis0 = np.sqrt(squared_sums)
    return norms_along_axis0.max()

def normop(res_weights):
    """Return the 2-norm of ``res_weights`` as computed by ``np.linalg.norm``.

    With ``ord=2`` and a 2-D input this is the operator (spectral) norm,
    i.e. the largest singular value.
    """
    spectral_order = 2
    return np.linalg.norm(res_weights, ord=spectral_order)



def compute_bound_normed_by_k(d, lip_bound = True, use_phi_inf = False):
    """Combine the per-layer norms stored in ``d`` into a single scalar.

    Parameters
    ----------
    d : dict
        Must provide ``'op_norm'`` and ``'weight_sum'`` (array-likes that
        support slicing, ``/``, ``**``, ``.sum()`` and ``.prod()``) and
        ``'shapes'`` (a sequence whose last entry is indexable at ``[1]``).
        ``'weight_max'`` is only read when ``lip_bound`` is false.
    lip_bound : bool
        If true, every layer enters the product of ``op_norm`` and the
        ratio sum, and the result is divided by ``k``.  If false, the last
        layer contributes ``weight_max[-1]`` to the product and ``1`` to
        the ratio sum instead, and no division by ``k`` takes place.
    use_phi_inf : bool
        If false the product factor is applied twice, otherwise once.

    Returns
    -------
    The scalar ``s * prod`` (times ``prod`` again unless ``use_phi_inf``),
    divided by ``k = d['shapes'][-1][1]`` when ``lip_bound`` is true.
    """
    op_norms = d['op_norm']
    weight_sums = d['weight_sum']

    if lip_bound:
        prod = op_norms.prod()
        ratio_total = ((weight_sums / op_norms) ** (2/3)).sum()
    else:
        # The last layer is handled separately: its weight_max replaces its
        # operator norm in the product and it adds a constant 1 to the sum.
        prod = op_norms[:-1].prod() * d['weight_max'][-1]
        ratio_total = ((weight_sums[:-1] / op_norms[:-1]) ** (2/3)).sum() + 1

    s = ratio_total ** (3/2)

    # Multiplications are kept in left-to-right order: (s * prod) * prod.
    res = s * prod if use_phi_inf else s * prod * prod

    # Second entry of the last stored shape; looked up unconditionally.
    k = d['shapes'][-1][1]

    return res / k if lip_bound else res
    



def get_data(K, header, architecture, i):
    """Load one saved run from ``run_data/`` and split it into its parts.

    The archive read is
    ``run_data/mnist_run_{K}_{header}_180_{i}_{architecture}.npz``.

    Parameters
    ----------
    K, header, architecture, i
        Values interpolated into the file name above.

    Returns
    -------
    tuple
        ``(loss_lst, val_lst, train_eval_lst, test_eval_lst, weights_lst)``.
        The first four are the last four arrays stored in the archive, in
        order.  ``weights_lst`` holds every second array (positions 0, 2,
        4, ...) of the remaining ones.
        NOTE(review): presumably the skipped odd positions are bias vectors
        stored next to each weight matrix -- confirm against the writer.

    Raises
    ------
    OSError
        If the archive does not exist or cannot be read.
    """
    save_file = f'run_data/mnist_run_{K}_{header}_180_{i}_{architecture}.npz'

    # np.load returns a lazily-read NpzFile; read every array while the
    # archive is open and let the context manager close the file handle.
    with np.load(save_file) as load_res:
        res_tup = tuple(load_res[name] for name in load_res.files)

    (loss_lst, val_lst, train_eval_lst, test_eval_lst) = res_tup[-4:]

    # Keep the even-positioned arrays of everything before the last four.
    weights_lst = list(res_tup[:-4][::2])

    return (loss_lst, val_lst, train_eval_lst, test_eval_lst, weights_lst)





if __name__ == '__main__':
    # Script entry point: load every saved run, compute per-layer weight
    # norms, average them over the repeated runs, and plot normalised norm
    # ratios of the last layer against k for each activation.

    # NOTE(review): the dataset and the derived Nclasses / D are not
    # referenced by anything further down in this script.
    (X_train, y_train), (X_test, y_test) = kdds.mnist.load_data()
    ntmp = np.prod(X_train.shape[1:])
    X_train = keras_image_format_to_std(X_train).reshape(-1, ntmp)
    X_test = keras_image_format_to_std(X_test).reshape(-1, ntmp)

    Nclasses = 10

    D = X_train.shape[1]

    print('Done Loading...')

    k_lst = [50, 100, 500, 1000, 1500, 5000]
    run_lst = [1, 2, 3]

    header_lst = ['relu', 'sigmoid']
    architecture_lst = ['x500x', 'x']

    # ------------------------------------------------------------------
    # Per-run statistics, keyed by (K, header, architecture, run index).
    # ------------------------------------------------------------------
    res_dict = {}

    for K in k_lst:
        print(K)
        for header in header_lst:
            for i in run_lst:
                for architecture in architecture_lst:

                    (loss_lst,
                     val_lst,
                     train_eval_lst,
                     test_eval_lst,
                     weights_lst
                     ) = get_data(K, header, architecture, i)

                    cdict = {
                        'weight_sum': np.array([norm21(w) for w in weights_lst]),
                        'weight_max': np.array([norm2infty(w) for w in weights_lst]),
                        'op_norm': np.array([normop(w) for w in weights_lst]),
                        'loss_lst': loss_lst,
                        'val_lst': val_lst,
                        'train_eval_lst': train_eval_lst,
                        'test_eval_lst': test_eval_lst,
                        'shapes': [w.shape for w in weights_lst],
                    }
                    # Both ratios are taken relative to the (2,inf) norm;
                    # the division by k / sqrt(k) happens at plot time.
                    cdict['op_norm_ratios'] = cdict['op_norm'] / cdict['weight_max']
                    cdict['sum_norm_ratios'] = cdict['weight_sum'] / cdict['weight_max']

                    res_dict[(K, header, architecture, i)] = cdict

    # ------------------------------------------------------------------
    # Mean and (population, ddof=0) standard deviation across the runs,
    # keyed by (K, header, architecture).  'shapes' is not numeric data
    # and is skipped.
    # ------------------------------------------------------------------
    mean_res_dict = {}
    std_res_dict = {}
    for K in k_lst:
        for header in header_lst:
            for architecture in architecture_lst:
                runs = [res_dict[(K, header, architecture, i)] for i in run_lst]

                mean_entry = {}
                std_entry = {}
                for key in runs[0]:
                    if key == 'shapes':
                        continue
                    # Stack the runs on a new leading axis and reduce over
                    # it; this never modifies the arrays held in res_dict.
                    stacked = np.stack([run[key] for run in runs])
                    mean_entry[key] = stacked.mean(axis=0)
                    std_entry[key] = stacked.std(axis=0)

                mean_res_dict[(K, header, architecture)] = mean_entry
                std_res_dict[(K, header, architecture)] = std_entry

    # ------------------------------------------------------------------
    # Plotting.
    # ------------------------------------------------------------------
    n_k = len(k_lst)

    zzparam = {'markersize': 15, 'linewidth': 5}

    matplotlib.rcParams.update({'font.size': 30})
    pl.rc('text', usetex=True)

    def arch_map(s):
        """Map an architecture tag to the legend text 'L=1' or 'L=2'."""
        return 'L=1' if s == 'x' else 'L=2'

    def header_map(header):
        """Return the legend prefix for an activation (currently empty)."""
        return ''

    def normalized_last_layer_curve(header, architecture, key, scale):
        """Return (mean, std) of the last layer's ``key`` over ``k_lst``.

        Every value is divided by ``scale(K)``; both curves are then
        divided by the first mean value so the mean curve starts at 1.
        """
        mean_curve = np.array(
            [mean_res_dict[(K, header, architecture)][key][-1] / scale(K) for K in k_lst]
        )
        std_curve = np.array(
            [std_res_dict[(K, header, architecture)][key][-1] / scale(K) for K in k_lst]
        )
        first = mean_curve[0]
        return mean_curve / first, std_curve / first

    def plot_norm_ratios(header):
        """Draw and show one figure of norm ratios for activation ``header``."""
        fig = pl.figure()

        # One colour per architecture: solid line for the (2,1) ratio,
        # dashed line for the operator-norm ratio.
        for color, architecture in enumerate(['x500x', 'x']):
            label_prefix = f'{header_map(header)} {arch_map(architecture)} '

            train_res, train_std = normalized_last_layer_curve(
                header, architecture, 'sum_norm_ratios', lambda K: K
            )
            pl.errorbar(range(n_k), train_res, yerr=train_std,
                        fmt=f'-oC{color}',
                        label=label_prefix + r'$\|A\|_{2,1}/(k\cdot \|A\|_{2,\infty})$ ',
                        **zzparam)

            train_res, train_std = normalized_last_layer_curve(
                header, architecture, 'op_norm_ratios', np.sqrt
            )
            pl.errorbar(range(n_k), train_res, yerr=train_std,
                        fmt=f'--oC{color}',
                        label=label_prefix + r'$\|A\|_{op}/(\sqrt{k} \cdot \|A\|_{2,\infty})$ ',
                        **zzparam)

        # The x axis is categorical: evenly spaced positions labelled by k.
        pl.xticks(range(n_k), [str(K) for K in k_lst])

        fig.suptitle(f'MNIST Norm Ratios, {header}')
        pl.ylabel('ratio')
        pl.xlabel('k')

        pl.legend()
        plt.grid(True)
        pl.show()

    plot_norm_ratios('relu')
    plot_norm_ratios('sigmoid')




    
