import argparse
import numpy as np
import random
from matplotlib.pyplot import cm
import matplotlib.pyplot as plt
import pickle

import os
# os.chdir('/.../code')
import seaborn as sns



from required_func import * 
from main_func import * 

def truncation_length(loss, tol=1e-6):
    """Return how many leading entries of ``loss`` to plot.

    Counts every entry with ``loss >= tol`` and adds one, so the plotted curve
    also shows the first value below ``tol``. The result is capped at the
    number of entries. For a monotonically decreasing curve this equals the
    index of the first entry below ``tol`` plus one.

    Args:
        loss: Sequence or array of values; multi-dimensional input is
            flattened.
        tol: Threshold below which the curve is considered converged.

    Returns:
        Number of entries to plot (0 for empty input).
    """
    values = np.asarray(loss, dtype=float).ravel()
    return min(int(np.count_nonzero(values >= tol)) + 1, values.size)


def marker_stride(length, n_markers=10):
    """Return a ``markevery`` stride that places roughly ``n_markers`` markers.

    Never returns 0: matplotlib treats an integer ``markevery`` as a slice
    step, and a step of 0 raises ``ValueError`` for curves shorter than
    ``n_markers`` points.
    """
    return max(length // n_markers, 1)


def plot_loss_curve(loss, label, marker, tol=1e-6):
    """Plot ``loss`` on the current axes, truncated just after it drops below ``tol``."""
    values = np.asarray(loss, dtype=float).ravel()
    length = truncation_length(values, tol)
    plt.errorbar(range(length), values[:length], label=label,
                 marker=marker, markevery=marker_stride(length))


def main():
    """Run the four Hessian-averaging variants and save the convergence plot."""
    parser = argparse.ArgumentParser(description='Hessian Averaging')
    # Both arguments are mandatory: without them the script would fail only
    # after the expensive exact solve.
    parser.add_argument('--size', type=float, required=True,
                        help='Sketch Size')
    parser.add_argument('--lambd', type=float, required=True,
                        help='Regularization parameter')
    args = parser.parse_args()
    lambd = args.lambd

    ### Simulated Data
    # Seeding first keeps the generated data (and any later global-RNG use)
    # reproducible.
    np.random.seed(2022)
    n, d = 1000, 100
    U = np.random.randn(n, d)
    Prob = np.random.rand(n, 1)

    logsumexp = LogSumExp(U, Prob, lambd, max_smoothing=0.01)
    # Return value is unused here; presumably the call stores the optimum
    # that the loss curves are measured against -- TODO confirm in main_func.
    logsumexp.solve_exactly()

    # --size is given as a multiple of the dimension d.
    SS_size = int(args.size * d)
    # Other sketches: 'Gaussian', 'CountSketch', 'LESS-uniform'.
    SF = 'Subsampled'
    pp = 1

    # (solver, weighting scheme, legend label, marker)
    runs = [
        (logsumexp.sto_weight_Sket_Newton, 'power', 'SN-UnifAvg', '^'),
        (logsumexp.sto_weight_Sket_Newton, 'log_power', 'SN-WeightAvg', '^'),
        (logsumexp.sto_weight_Sket_NPE, 'power', 'SNPE-UnifAvg', 'D'),
        (logsumexp.sto_weight_Sket_NPE, 'log_power', 'SNPE-WeightAvg', 'D'),
    ]
    curves = []
    for solver, weighting, label, marker in runs:
        print(SF + ',size=' + str(SS_size))
        # Only the second returned value (the loss history) is plotted.
        _, loss, _, _ = solver(SS_size, weighting, pp, SF)
        curves.append((loss, label, marker))

    # Function value plot
    sns.set_style('ticks')
    sns.set_palette('colorblind')

    SMALL_SIZE = 14
    MEDIUM_SIZE = 14
    BIGGER_SIZE = 16

    plt.rc('font', size=SMALL_SIZE)          # default text size
    plt.rc('axes', titlesize=MEDIUM_SIZE)    # axes title
    plt.rc('axes', labelsize=MEDIUM_SIZE)    # x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend
    plt.rc('figure', titlesize=BIGGER_SIZE)  # figure title

    plt.rc('lines', markersize=6)
    plt.rc('lines', markerfacecolor='none')
    plt.rc('lines', linewidth=2)

    plt.figure()
    for loss, label, marker in curves:
        plot_loss_curve(loss, label, marker)

    plt.legend(loc='upper right', fontsize=10)
    plt.xlabel('$t$', fontsize=22)
    plt.ylabel(r'$f(x) - f^*$', fontsize=22)
    plt.yscale('log')
    plt.grid()

    # Create the output directory so savefig does not fail on a fresh checkout.
    out_dir = os.path.join('Figure', 'logsumexp')
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, 'k={}_lam={}_subsampling.pdf'.format(SS_size, lambd)),
                bbox_inches='tight')


if __name__ == '__main__':
    main()