import matplotlib.pyplot as plt
import numpy as np
import pdb
import random
import os
from fractions import Fraction
import time

# Seed both the stdlib and NumPy generators so repeated runs draw the same numbers.
random.seed(1)
np.random.seed(1)

# Names of the sampling schemes compared by this script.
algorithms = [
    "SGD",
    "fixed-shuffle",
    "begin-shuffle",
    "uniform-shuffle",
]

# Step size handed to Shuffle_PhaseRetrieval for each scheme.
gamma = {
    "SGD": 2e-6,
    "fixed-shuffle": 7e-3,
    "begin-shuffle": 7e-3,
    "uniform-shuffle": 7e-3,
}

# Wall-clock reference used later to report the total running time.
time_start = time.time()
def Shuffle_PhaseRetrieval(A,y_true,total_iters,gamma,z0=None,alg=algorithms[0],epoch_eval=1,print_progress=False):
    """Run single-sample gradient steps on a phase-retrieval least-squares objective.

    The reported objective is ``mean((y_true - |A z|^2)^2) / 2``.  Every
    iteration selects one sample index according to ``alg`` and updates
    ``z <- z - step * A_H[:, i] * (Az[i] * (|Az[i]|^2 - y_true[i]))``.

    Parameters
    ----------
    A : 2-D ndarray, shape (m, d)
        Measurement matrix; ``m`` and ``d`` are read from ``A.shape``.
    y_true : ndarray
        Observed measurements, indexed by sample (expected length ``m``).
    total_iters : int
        Number of update steps to perform.
    gamma : float
        Step size.  For every scheme except ``"SGD"`` the step actually used
        is ``gamma / m``.
    z0 : ndarray or None
        Starting point.  When None, ``d`` values are drawn from a normal
        distribution with standard deviation ``sqrt(0.5)``.
    alg : str
        Sampling scheme:
        ``"SGD"`` draws one index with ``np.random.choice`` at every step;
        ``"fixed-shuffle"`` cycles through ``0, 1, ..., m-1`` in order;
        ``"begin-shuffle"`` draws one permutation up front and cycles through it;
        ``"uniform-shuffle"`` draws a new permutation whenever ``k % m == 0``.
    epoch_eval : int
        Metrics are recorded at iterations ``k`` with ``k % epoch_eval == 0``.
    print_progress : bool
        Print a line per iteration and per evaluation when true.

    Returns
    -------
    tuple
        ``(zt, z_err_set, obj_set, grad_norm_set)``: the final iterate and
        three lists with one entry per recorded iteration, holding
        ``||zt - z_true||``, the objective value, and the norm of
        ``A_H.dot(Az * (y - y_true)) / m``.

    Raises
    ------
    ValueError
        If ``alg`` is not one of the four scheme names above.

    Notes
    -----
    ``z_true`` is not a parameter; it is read from module scope and must be
    defined before this function is called.
    """
    valid_algs=("SGD","fixed-shuffle","begin-shuffle","uniform-shuffle")
    if alg not in valid_algs:
        # An unrecognised name used to sample like "SGD" while still receiving the
        # 1/m step scaling of the shuffling schemes; fail loudly instead.
        raise ValueError("Unknown alg "+repr(alg)+"; expected one of "+", ".join(valid_algs))

    m,d=A.shape
    A_H=np.conjugate(A.T)
    if z0 is None:
        z0=np.random.normal(scale=np.sqrt(0.5),size=d)

    # Rebind rather than `gamma/=m` so an array-valued gamma passed by the caller
    # is never modified in place.
    if alg!="SGD":
        gamma=gamma/m

    z_err_set=[]
    obj_set=[]
    grad_norm_set=[]
    zt=z0.copy()
    if alg=="begin-shuffle":
        # One permutation of the sample indices, reused on every pass.
        begin_seq=np.random.choice(m, m, replace=False)
    for k in range(total_iters):
        if print_progress:
            print(str(k)+"-th iteration")
        Az=A.dot(zt)
        y=np.absolute(Az)**2

        if k%epoch_eval==0:
            if print_progress:
                print("evaluating "+str(k)+"-th iteration")
            # z_true comes from module scope (see Notes in the docstring).
            z_err_set.append(np.sqrt(np.sum(np.absolute(zt-z_true)**2)))
            obj_set.append(((y_true-y)**2).mean()/2)
            grad=A_H.dot(Az*(y-y_true))/m
            grad_norm_set.append(np.sqrt(np.sum(np.absolute(grad)**2)))

        if alg=="uniform-shuffle":
            # Fresh permutation at the start of every pass over the m samples.
            if k%m==0:
                remain_batch=np.random.choice(m, m, replace=False)
            batch=remain_batch[k%m]
        elif alg=="fixed-shuffle":
            batch=(k%m)
        elif alg=="begin-shuffle":
            batch=begin_seq[k%m]
        else:   #SGD
            # Length-1 index array; the dot product below collapses it again.
            batch=np.random.choice(m, 1, replace=False)

        vt=A_H[:,batch].dot(Az[batch]*(y[batch]-y_true[batch]))
        zt=zt-gamma*vt

    return zt,z_err_set,obj_set,grad_norm_set

def num2str_neat(num):
    """Return a compact string for ``num``, using ``p/q`` form when that is shorter.

    ``num`` is converted exactly to a ``Fraction``.  If the exact numerator is
    at most 100 in absolute value, ``str(num)`` is returned unchanged.
    Otherwise the value is approximated with ``Fraction.limit_denominator()``
    (default bound) and returned as ``"numerator/denominator"``.

    When that approximation has denominator 1 (for example the integer 200,
    or a tiny value that rounds to 0), ``str(num)`` is returned instead of a
    string such as ``"200/1"`` or ``"0/1"``.

    Parameters
    ----------
    num : int, float, str or any other value accepted by ``Fraction``.

    Returns
    -------
    str
    """
    exact=Fraction(num)
    if abs(exact.numerator)<=100:
        return str(num)
    approx=exact.limit_denominator()
    if approx.denominator==1:
        # "N/1" is never neater than the number itself, and "0/1" would hide
        # a small non-zero value entirely.
        return str(num)
    return str(approx.numerator)+'/'+str(approx.denominator)
    

# ---- experiment hyperparameters ----
num_exprs = 100      # number of experiments
m = 3000             # number of samples
d = 100              # dimensionality
y_std = 4.0          # noise std of y
total_iters = 1001   # update steps per run

epoch_eval = 1       # record metrics every epoch_eval iterations

print_progress = False

# ---- plot appearance ----
percentile = 95      # upper band edge; the lower edge is 100 - percentile
colors = [
    'red', 'black', 'blue', 'green', 'cyan',
    'purple', 'gold', 'lime', 'darkorange',
]
markers = ['v', 's', 'P', '*', 'h', '.']
label_size = 16      # axis-label font size
num_size = 14        # tick-label font size
lgd_size = 18        # legend font size
bottom_loc = 0.15    # bottom margin passed to subplots_adjust
left_loc = 0.2       # left margin passed to subplots_adjust

#Output directory for the hyperparameter log and the saved figure
folder="./Phase_results/"
# exist_ok=True replaces the separate isdir() check, so there is no window in
# which the directory can appear between the check and the creation.
os.makedirs(folder,exist_ok=True)
        
# Re-seed NumPy so the problem instance below is the same on every run.
np.random.seed(1)

# Ground-truth signal; each entry is normal with standard deviation sqrt(0.5).
z_true = np.random.normal(size=d, scale=np.sqrt(0.5))
# m-by-d measurement matrix: the r-th row multiplies z in the r-th measurement.
A = np.random.normal(size=(m, d), scale=np.sqrt(0.5))
Az_true = A.dot(z_true)
# Squared magnitudes of A z_true plus normal noise with standard deviation y_std.
y_true = np.abs(Az_true)**2 + np.random.normal(size=m, scale=y_std)

#Initialize
# Starting point shared by every run: a random draw shifted by +5 in each coordinate.
z0 = np.random.normal(size=d, scale=np.sqrt(0.5)) + 5
A_H = np.conj(A.T)

# Iterations at which metrics are recorded, and how many of them there are.
x_plot=np.arange(0,total_iters,epoch_eval)
len_x=len(x_plot)

# results[alg][metric] is a (num_exprs, len_x) array with one row per experiment.
results={
    alg:{y_type:np.zeros((num_exprs,len_x)) for y_type in ('z_err','obj','grad_norm')}
    for alg in algorithms
}

for kk in range(num_exprs):
    for alg in algorithms:
        print("Begin {}-th experiment: Algorithm {}".format(kk,alg))

        zt,z_err_set,obj_set,grad_norm_set=Shuffle_PhaseRetrieval(
            A,y_true,total_iters,gamma[alg],
            z0=z0,alg=alg,epoch_eval=epoch_eval,print_progress=print_progress)

        # Store this run's curves in row kk of the per-algorithm arrays.
        run=results[alg]
        run['z_err'][kk]=z_err_set
        run['obj'][kk]=obj_set
        run['grad_norm'][kk]=grad_norm_set

# Seconds elapsed since time_start was taken.
time_seconds=time.time()-time_start

# Record the iteration count, per-algorithm step sizes and total running time.
# The with-block closes the file even if one of the writes raises.
with open(folder+'hyperparameters.txt','w') as hyp_txt:
    hyp_txt.write('number of iterations='+str(total_iters)+'\n')
    for alg in algorithms:
        hyp_txt.write('Stepsize gamma for '+str(alg)+':'+str(gamma[alg])+'\n')
    hyp_txt.write('Total time consumption: '+str(time_seconds/60)+' minutes.\n\n')

# Axis-label text per metric and legend caption per algorithm.
ylabels={'z_err':r'$||z_t-z^*||$','obj':r'$f(z_t)$','grad_norm':r'$||\nabla f(z_t)||$'}
legends={"SGD":"SGD","fixed-shuffle":"Fixed-shuffling","begin-shuffle":"Shuffle-once","uniform-shuffle":"Uniform-shuffling"}

# rc settings are picked up when axes and labels are created, so they have to
# be set before the figure exists; setting them after plotting left the
# configured font sizes unused for this figure.
plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels

# NOTE: every metric listed here is saved to the same file name, so adding a
# second entry would overwrite the first figure.
for y_type in ['obj']:
    plt.figure(figsize=(5, 5))
    for k,alg in enumerate(algorithms):
        y_plot=results[alg][y_type]
        # Band between the (100 - percentile)-th and percentile-th percentiles
        # across experiments, plus the mean curve.
        upper_loss = np.percentile(y_plot, percentile, axis=0)
        lower_loss = np.percentile(y_plot, 100 - percentile, axis=0)
        avg_loss = np.mean(y_plot, axis=0)
        # Different marker spacing per curve so markers do not overlap;
        # at least 1 so short curves never produce a zero step.
        marker_step=max(1,len(avg_loss)//(k+6))
        plt.plot(x_plot,avg_loss,color=colors[k],marker=markers[k],markevery=marker_step,label=legends[alg])
        if num_exprs>1:
            plt.fill_between(x_plot,lower_loss,upper_loss,color=colors[k],alpha=0.3,edgecolor="none")
    plt.legend(prop={'size':lgd_size},loc=1)
    plt.xlabel('Iteration t')
    plt.ylabel(ylabels[y_type])
    plt.gcf().subplots_adjust(bottom=bottom_loc)
    plt.gcf().subplots_adjust(left=left_loc)
    plt.ticklabel_format(axis="x", style="sci", scilimits=(0,0))
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0,0))
    plt.grid(True)
    plt.savefig(folder+'PhaseResults.png',dpi=200)
    plt.close()
        

