import numpy as np
import random
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed both RNGs so repeated runs of the script draw the same samples.
random.seed(1)
np.random.seed(1)

# ---- Problem and experiment configuration ----
num_expr = 100   # number of repetitions run for every algorithm
d = 50           # dimension of the iterate x
n = 10           # the shift k ranges over -n, ..., n
settings = ["exp", "poly"]
setting = settings[0]
# Objective: f(x) = sum_{j=1}^{d} sum_{k=-n}^{n} f_{j,k}(x) / [d(2n+1)], with
#   "exp":  f_{j,k}(x) = exp(x_j - k) + exp(k - x_j) + ||x||^2 / 2
#   "poly": f_{j,k}(x) = x_j^4 + k * x_j

algorithms = ["SGD", "fixed-shuffle", "begin-shuffle", "uniform-shuffle"]
# "fixed-shuffle":   every permutation pi^{(t)} is the data's original order.
# "begin-shuffle":   every permutation pi^{(t)} is one random order drawn
#                    before the algorithm starts.
# "uniform-shuffle": each permutation pi^{(t)} is drawn uniformly at random,
#                    i.i.d. across epochs.
total_iters = 10000

# Stepsize for each algorithm; all of them currently share the value eta1.
eta1 = 0.1**5 if setting == "exp" else 0.1**2
eta = {name: eta1 for name in algorithms}

# All (j, k) component indices, with j as the outer (slow) index.
jk_pairs_original = [
    (coord, shift)
    for coord in range(d)
    for shift in range(-n, n + 1)
]
total_len = len(jk_pairs_original)

x0 = np.ones(d)  # common starting point
# Per-algorithm containers: current iterate, recorded ||x_t||, recorded f(x_t).
x, x_norm, f = {}, {}, {}

k_vec = np.arange(-n, n + 1).reshape(-1, 1)
e = np.exp(1)
# (e^{n+1} - e^{-n}) / [(2n+1)(e-1)]: the geometric-sum closed form of the
# average of exp(k) over k = -n, ..., n.
const = (np.exp(n + 1) - np.exp(-n)) / (2 * n + 1) / (e - 1)

def _norm_and_objective(point):
    """Return (||point||, f(point)) for the objective selected by `setting`.

    "exp":  f(x) = ||x||^2 / 2 + const * mean_j[exp(x_j) + exp(-x_j)],
            where `const` is the average of exp(k) over k = -n..n.
    "poly": f(x) = mean_j[x_j^4]; the k * x_j terms cancel because k is
            summed over the symmetric range -n..n.
    """
    sq_norm = (point * point).sum()
    if setting == "exp":
        value = sq_norm / 2 + (np.exp(point) + np.exp(-point)).mean() * const
    else:
        value = (point * point * point * point).mean()
    return np.sqrt(sq_norm), value


# Run every algorithm num_expr times for total_iters iterations and record
# ||x_t|| and f(x_t) at each iteration in x_norm[alg] / f[alg].
for alg in algorithms:
    # -1 marks entries that have not been written yet.
    x_norm[alg] = -np.ones((num_expr, total_iters))
    f[alg] = -np.ones((num_expr, total_iters))
    step = eta[alg]
    # Any algorithm other than the three named ones is treated as
    # uniform-shuffle: a fresh permutation is drawn at the start of each epoch.
    reshuffle_each_epoch = alg not in ("SGD", "fixed-shuffle", "begin-shuffle")

    for expr_k in range(num_expr):
        x[alg] = x0.copy()
        # Alias of x[alg]; all updates below are in place, so x[alg] always
        # holds the current iterate.
        point = x[alg]
        x_norm[alg][expr_k, 0], f[alg][expr_k, 0] = _norm_and_objective(point)

        # Order in which the (j, k) components are visited within an epoch.
        # The position is indexed directly instead of popping from the front
        # of a list copy, which costs O(total_len) on every iteration.
        if alg == "begin-shuffle":
            epoch_order = random.sample(jk_pairs_original, total_len)
        else:
            epoch_order = jk_pairs_original

        for t in range(total_iters - 1):
            if alg == "SGD":
                # Sample one component uniformly, with replacement.
                j, k = random.choice(jk_pairs_original)
            else:
                pos = t % total_len
                if reshuffle_each_epoch and pos == 0:
                    epoch_order = random.sample(jk_pairs_original, total_len)
                j, k = epoch_order[pos]

            xj = point[j]
            if setting == "exp":
                # grad f_{j,k}(x) = x + [exp(x_j - k) - exp(k - x_j)] e_j:
                # every coordinate shrinks by (1 - step); coordinate j also
                # receives the exponential part, computed from the old x_j.
                new_xj = xj - step * (np.exp(xj - k) - np.exp(k - xj) + xj)
                point *= (1 - step)
                point[j] = new_xj
            else:
                # grad f_{j,k}(x) = (4 x_j^3 + k) e_j: only coordinate j moves.
                point[j] -= step * (k + 4 * xj * xj * xj)

            x_norm[alg][expr_k, t + 1], f[alg][expr_k, t + 1] = _norm_and_objective(point)

            # With a single run, report progress every 100 iterations
            # (the values printed are those recorded at iteration t).
            if num_expr == 1 and t % 100 == 0:
                print(f"{alg}: Iteration {t}, ||x_t||={x_norm[alg][expr_k, t]!s}, f(x_t)={f[alg][expr_k, t]!s}")

        if num_expr > 1:
            print(f"{alg}: The {expr_k}-th experiment finished.")


# Per-algorithm plot styling; the plotting code picks entries by the
# algorithm's position in `algorithms`.
colors = ['red', 'black', 'blue', 'green', 'cyan', 'purple', 'gold', 'lime', 'darkorange']
markers = ['P', 'v', '*', 's', '.']
legends = {
    "SGD": "SGD",
    "fixed-shuffle": "Fixed-shuffling",
    "begin-shuffle": "Shuffle-once",
    "uniform-shuffle": "Uniform-shuffling",
}

# Each setting writes its results to its own directory.
png_dir = "exponential_result/" if setting == "exp" else "poly_result/"
# exist_ok avoids the check-then-create race of a separate isdir() test.
os.makedirs(png_dir, exist_ok=True)

# Record the hyperparameters of this run next to the figures. The context
# manager guarantees the file is closed even if a write fails.
with open(png_dir + 'hyperparameters.txt', 'w') as hyp_txt:
    hyp_txt.write('number of iterations=' + str(total_iters) + '\n')
    for alg in algorithms:
        hyp_txt.write('Stepsize eta for ' + str(alg) + ':' + str(eta[alg]) + '\n')
    
label_size = 16   # font size of the axis labels
num_size = 14     # font size of the tick labels
lgd_size = 18     # font size of the legend entries
percentile = 95   # upper percentile of the shaded band (lower is 100 - percentile)

# rc settings only affect artists created after they are set, so apply them
# before the figure and its axes exist; otherwise this first figure would be
# drawn with the default label size.
plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels

# Figure 1: ||x_t|| against the iteration count, one curve per algorithm.
plt.figure()
for alg_k, alg in enumerate(algorithms):
    curve_style = dict(
        color=colors[alg_k],
        marker=markers[alg_k],
        # Different marker spacing per algorithm keeps overlapping curves
        # distinguishable.
        markevery=int(total_iters / (alg_k + 6)),
        label=legends[alg],
    )
    if num_expr == 1:
        # Single run: plot the raw trajectory.
        plt.plot(range(total_iters), x_norm[alg][0], **curve_style)
    else:
        # Several runs: plot the mean and shade the band between the
        # (100 - percentile)-th and percentile-th percentiles across runs.
        x_plot = np.arange(total_iters)
        y_plot = x_norm[alg]
        upper_loss = np.percentile(y_plot, percentile, axis=0)
        lower_loss = np.percentile(y_plot, 100 - percentile, axis=0)
        avg_loss = np.mean(y_plot, axis=0)
        plt.plot(x_plot, avg_loss, **curve_style)
        plt.fill_between(x_plot, lower_loss, upper_loss, color=colors[alg_k], alpha=0.3, edgecolor="none")
plt.legend(prop={'size': lgd_size}, loc=1)
plt.xlabel("Iteration t")
plt.ylabel(r'$\|x_t\|$')
# Leave room for the enlarged axis labels.
plt.gcf().subplots_adjust(bottom=0.15)
plt.gcf().subplots_adjust(left=0.15)
plt.grid(True)
plt.savefig(png_dir + "xnorm.png", dpi=200)
plt.close()

# Figure 2: f(x_t) against the iteration count, one curve per algorithm.
plt.figure()
for alg_k, alg in enumerate(algorithms):
    line_kwargs = {
        "color": colors[alg_k],
        "marker": markers[alg_k],
        "markevery": int(total_iters / (alg_k + 6)),
        "label": legends[alg],
    }
    if num_expr != 1:
        # Several runs: mean curve plus a shaded band between the
        # (100 - percentile)-th and percentile-th percentiles across runs.
        iterations = np.arange(total_iters)
        runs = f[alg]
        band_top = np.percentile(runs, percentile, axis=0)
        band_bottom = np.percentile(runs, 100 - percentile, axis=0)
        mean_curve = np.mean(runs, axis=0)
        plt.plot(iterations, mean_curve, **line_kwargs)
        plt.fill_between(iterations, band_bottom, band_top, color=colors[alg_k], alpha=0.3, edgecolor="none")
    else:
        # Single run: plot the raw trajectory.
        plt.plot(range(total_iters), f[alg][0], **line_kwargs)
plt.legend(prop={'size': lgd_size}, loc=1)
plt.xlabel("Iteration t")
plt.ylabel(r'$f(x_t)$')
# Re-applies the font-size rc values used for the previous figure.
for rc_group, rc_size in (('axes', label_size), ('xtick', num_size), ('ytick', num_size)):
    plt.rc(rc_group, labelsize=rc_size)
# Leave room for the enlarged axis labels.
plt.gcf().subplots_adjust(bottom=0.15, left=0.15)
plt.grid(True)
plt.savefig(png_dir + "fx.png", dpi=200)
plt.close()



