# -*- coding: utf-8 -*-
"""Phase_Retrival.ipynb
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
#import cvxpy as cp
import pdb
import random
import os
from fractions import Fraction

def _stoc_grad(A_H, Az, y, y_true, batch):
    """Return the averaged gradient estimate over the rows listed in ``batch``.

    Evaluates ``A_H[:, batch] @ (Az * (y - y_true))[batch] / len(batch)``: the
    same expression ``Spider_PhaseRetrieval`` uses for its full-gradient
    metric, restricted to (and averaged over) the sampled rows. ``Az`` and
    ``y`` may belong to an earlier iterate (SPIDER correction term).
    """
    return A_H[:, batch].dot(Az[batch] * (y[batch] - y_true[batch])) / len(batch)


def Spider_PhaseRetrieval(A, y_true, total_iters, epoch_vt, epoch_momentum, gamma, S0, S1, S01, beta=0, z0=None,
                          normalize_power=1.0, grad_max=0.0, clip_constant=0, epoch_eval=1,
                          independent_sampling=False, print_progress=False, z_star=None):
    """Run a (stochastic) first-order method on the phase-retrieval loss.

    The loss recorded in ``obj_set`` is ``mean((y_true - |A z|**2)**2) / 2``.
    One routine covers GD/SGD, normalized and clipped variants, momentum,
    SPIDER-style recursive estimators and the independent-sampling (I-NSGD)
    normalization; the variant is selected through the arguments below.

    Parameters
    ----------
    A : array, shape (m, d)
        Measurement matrix (``m, d = A.shape``).
    y_true : array, shape (m,)
        Observed measurements.
    total_iters : int
        Number of iterations.
    epoch_vt : int
        Every ``epoch_vt`` iterations the estimator ``v`` is recomputed from a
        fresh batch of ``S0`` rows; in between it is updated recursively
        (SPIDER) with batches of ``S1`` rows. Use 1 for GD/SGD-type methods.
    epoch_momentum : int
        Every ``epoch_momentum`` iterations the momentum buffer is reset to
        ``v``; otherwise ``m = beta * m + (1 - beta) * v``.
    gamma : float
        Step size.
    S0, S1, S01 : int or None
        Batch sizes: refresh batch, SPIDER correction batch, and the
        independent batch (only used when ``independent_sampling`` is True).
    beta : float
        Momentum parameter; 0 disables momentum.
    z0 : array, shape (d,), optional
        Initial point; defaults to i.i.d. normal entries with std sqrt(0.5).
    normalize_power : float
        Exponent ``p`` of the step normalization. ``p == 0`` gives
        ``z -= gamma * m``; otherwise ``z -= gamma / denom**p * m`` with
        ``denom = max(||m|| + clip_constant, grad_max**2)``.
    grad_max : float
        Clipping level; its SQUARE is used as the lower bound of ``denom``.
    clip_constant : float
        Constant added to the norm inside ``denom``.
    epoch_eval : int
        Metrics are recorded every ``epoch_eval`` iterations (always at k = 0).
        The objective is also printed at recorded iterations that are
        multiples of 100.
    independent_sampling : bool
        If True, ``denom = max(2 * ||m_ind|| + clip_constant, grad_max**2)``
        where ``m_ind`` is built from a second batch of ``S01`` rows (I-NSGD).
        Requires ``epoch_vt == 1``.
    print_progress : bool
        Print per-iteration progress messages.
    z_star : array, shape (d,), optional
        Reference point for the distance metric. When omitted the
        module-level ``z_true`` is used (original behaviour).

    Returns
    -------
    tuple
        ``(zt, z_err_set, obj_set, grad_norm_set, iters_set, complexity_set)``:
        the final iterate and, per evaluation, ``||z - z_star||``, the loss,
        the norm of ``A^H (Az * (|Az|^2 - y_true)) / m``, the iteration index
        and the number of rows sampled before that iteration.

    Raises
    ------
    ValueError
        If ``independent_sampling`` is True and ``epoch_vt != 1``: the
        recursive SPIDER step has no independent counterpart, so the
        estimators would silently go stale.
    """
    if independent_sampling and epoch_vt != 1:
        raise ValueError("independent_sampling=True requires epoch_vt == 1")

    m, d = A.shape
    # NOTE(review): grad_max is squared here but compared below against a norm
    # (not a squared norm). The tuned hyperparameters rely on this, so it is kept.
    grad_max = grad_max * grad_max
    A_H = np.conjugate(A.T)
    if z0 is None:
        z0 = np.random.normal(scale=np.sqrt(0.5), size=d)

    z_err_set = []
    obj_set = []
    grad_norm_set = []
    iters_set = []
    complexity_set = []
    complexity = 0  # number of rows sampled so far
    zt = z0.copy()
    # State carried from one iteration to the next.
    mt = mt_ind = Az_old = y_old = v_old = None
    for k in range(total_iters):
        if print_progress:
            print(str(k)+"-th iteration")
        Az = A.dot(zt)
        y = np.absolute(Az)**2

        # ---- evaluation (does not consume randomness) ----
        if k % epoch_eval == 0:
            if print_progress:
                print("evaluating "+str(k)+"-th iteration")
            # z_true is the module-level ground truth; only looked up when no z_star is given.
            z_ref = z_true if z_star is None else z_star
            z_err_set.append(np.sqrt(np.sum(np.absolute(zt - z_ref)**2)))
            obj_val = ((y_true - y)**2).mean()/2
            obj_set.append(obj_val)
            grad = A_H.dot(Az*(y - y_true))/m
            grad_norm_set.append(np.sqrt(np.sum(np.absolute(grad)**2)))
            iters_set.append(k)
            complexity_set.append(complexity)
            if k % 100 == 0:
                print('objective value at %d is %f'%(k,obj_val))

        # ---- gradient estimator(s) ----
        if k % epoch_vt == 0:
            complexity += S0
            batch = np.random.choice(m, S0, replace=False)
            vt = _stoc_grad(A_H, Az, y, y_true, batch)
            if independent_sampling:
                # Second batch drawn independently of the first (it may overlap);
                # it only feeds the normalization denominator.
                complexity += S01
                ind_batch = np.random.choice(m, S01, replace=False)
                vt_ind = _stoc_grad(A_H, Az, y, y_true, ind_batch)
        else:
            # SPIDER: correct the previous estimator with the gradient difference
            # between the current and the previous iterate on a common batch.
            complexity += S1
            batch = np.random.choice(m, S1, replace=False)
            vt = v_old + _stoc_grad(A_H, Az, y, y_true, batch) - _stoc_grad(A_H, Az_old, y_old, y_true, batch)

        # ---- momentum ----
        if beta == 0 or k % epoch_momentum == 0:
            m_next = vt
            if independent_sampling:
                m_next_ind = vt_ind
        else:
            m_next = beta*mt + (1-beta)*vt
            if independent_sampling:
                # Keep the normalization estimator on the same momentum schedule
                # instead of leaving it frozen at its last reset value.
                m_next_ind = beta*mt_ind + (1-beta)*vt_ind

        # Nothing below mutates these arrays in place, so plain rebinding is safe.
        Az_old, y_old, v_old, mt = Az, y, vt, m_next
        if independent_sampling:
            mt_ind = m_next_ind

        # ---- parameter update ----
        if normalize_power == 0:
            # Plain (S)GD step.
            zt = zt - gamma*m_next
        else:
            # Normalized / clipped step; grad_max == 0 disables the clipping floor.
            if independent_sampling:
                denom = max(2*np.sqrt(np.sum(np.absolute(m_next_ind)**2)) + clip_constant, grad_max)
            else:
                denom = max(np.sqrt(np.sum(np.absolute(m_next)**2)) + clip_constant, grad_max)
            # denom <= 0 happens e.g. at an exact stationary point with no clipping;
            # the normalized step is undefined there, so skip it rather than
            # producing inf/NaN iterates.
            if denom > 0:
                zt = zt - gamma/(denom**normalize_power)*m_next

    return zt, z_err_set, obj_set, grad_norm_set, iters_set, complexity_set

def num2str_neat(num):
    """Format ``num`` compactly for use in plot legends.

    If the exact fraction of ``num`` has a small numerator (absolute value at
    most 100) the number is rendered with ``str``. Otherwise (e.g. floats such
    as 1/3, whose exact binary fraction is huge) the closest fraction with
    denominator <= 10**6 is rendered as ``"p/q"``, or as ``"p"`` when that
    fraction is an integer (instead of the former ``"p/1"``).
    """
    exact = Fraction(num)
    if abs(exact.numerator) <= 100:
        return str(num)
    approx = exact.limit_denominator()
    if approx.denominator == 1:
        return str(approx.numerator)
    return str(approx.numerator) + '/' + str(approx.denominator)

# ---- problem size and noise ----
num_exprs=1   # number of independent repetitions of every configuration
m=3000        # number of measurements (rows of A)
d=100         # signal dimension (columns of A)
y_std=4.0     # std of the additive Gaussian noise on the measurements

# ---- iteration budgets ----
total_iters_GDs=10    # iterations for deterministic (full-batch) methods
total_iters_stoc=501  # iterations for stochastic methods

# epoch_momentum=1
# beta=0
epoch_eval=1  # record metrics every epoch_eval iterations

# ---- output / plotting ----
print_progress=False
percentile=95   # upper/lower percentile band (drawn only when num_exprs > 1)
colors=['red','black','blue','green','cyan','purple','gold','lime','darkorange']
# Every entry must be a valid matplotlib marker ('-' is a line style, not a marker).
markers=['.','v','s','P','*','+','D','x','o']
label_size=18   # axis-label font size
num_size=18     # tick-label font size
lgd_size=20     # legend font size
bottom_loc=0.1  # figure bottom margin (subplots_adjust)
left_loc=0.15   # figure left margin (subplots_adjust)

# Output folder for figures and the hyperparameter dump.
folder="./Phase_results"
os.makedirs(folder, exist_ok=True)

# ---- shared hyperparameters referenced by the configurations ----
S1= 64                 # mini-batch size of the stochastic methods
S01 = 4                # independent batch size for I-NSGD
clip_threshold = 15    # clip_constant for I-NSGD
clip_threshold_1 = 15  # clip_constant for Clip SGD

# Experiment: 1. align all normalization power to 2/3; 2. adjust batch size; 3. run I-NSGD only but vary beta
#
# Each entry configures one run of Spider_PhaseRetrieval. The keys epoch_vt,
# gamma, S0, S1, S01, normalize_power, grad_max, clip_constant, beta and
# epoch_momentum are passed through under the same names; 'sampling' becomes
# independent_sampling. 'legend' is the plot label AND selects the run type:
# legends containing "SGD" or "SPIDER" are treated as stochastic methods (warm
# start, total_iters_stoc iterations, plotted against sample complexity); all
# other entries are deterministic.

# ---- deterministic setting ----
hyps=[{'epoch_vt':1,'gamma':8e-4,'S0':m,'S1':None,'S01':None,'normalize_power':0,
       'grad_max':0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'GD'}]
hyps+=[{'epoch_vt':1,'gamma':0.03,'S0':m,'S1':None,'S01':None,'normalize_power':1/3,
        'grad_max':0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':r'$\frac{1}{3}$'+'GD'}]
hyps+=[{'epoch_vt':1,'gamma':0.1,'S0':m,'S1':None,'S01':None,'normalize_power':2/3,
        'grad_max':0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':r'$\frac{2}{3}$'+'GD'}]
hyps+=[{'epoch_vt':1,'gamma':0.2,'S0':m,'S1':None,'S01':None,'normalize_power':1,
        'grad_max':0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'1-GD'}]
hyps+=[{'epoch_vt':1,'gamma':0.9,'S0':m,'S1':None,'S01':None,'normalize_power':1,
        'grad_max':100.0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'Clipped GD'}]

# ---- SGD ----
hyps+=[{'epoch_vt':1,'gamma':5e-5,'S0':S1,'S1':None,'S01':None,'normalize_power':0,
        'grad_max':0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'SGD'}]
#hyps+=[{'epoch_vt':1,'gamma':0.25,'S0':S1,'S1':None,'S01':None,'normalize_power':1,
        #'grad_max':0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'NSGD'}]

# ---- NSGD with and without momentum ----
# epoch_momentum exceeds the iteration budget, so the momentum buffer is only
# reset at the first iteration.
hyps+=[{'epoch_vt':1,'gamma':0.2,'S0':S1,'S1':None,'S01':None,'normalize_power':1,
        'grad_max':0,'clip_constant':0,'beta':2e-1,'epoch_momentum':total_iters_stoc+9,'sampling':False,'legend':'NSGDm'}]
hyps+=[{'epoch_vt':1,'gamma':0.2,'S0':S1,'S1':None,'S01':None,'normalize_power':1,
        'grad_max':0,'clip_constant':0,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'NSGD'}]

# ---- clipped SGD ----
hyps+=[{'epoch_vt':1,'gamma':0.6,'S0':S1,'S1':None,'S01':None,'normalize_power':1,
        'grad_max':45.0,'clip_constant':clip_threshold_1,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'Clip SGD'}]
#hyps+=[{'epoch_vt':1,'gamma':5e-2,'S0':S1,'S1':None,'S01':None,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold_1,'beta':0,'epoch_momentum':1,'sampling':False,'legend':'Clip SGD with '+r"$\gamma = 0.05$"}]

# ---- SPIDER ----
# NOTE(review): with epoch_momentum=1 the momentum buffer is reset every
# iteration, so beta=0.4 currently has no effect on this run.
hyps+=[{'epoch_vt':20,'gamma':0.3,'S0':m,'S1':S1,'S01':None,'normalize_power':1,
       'grad_max':0,'clip_constant':0,'beta':0.4,'epoch_momentum':1,'sampling':False,'legend':'SPIDER'}]

# ---- I-NSGD ----
hyps+=[{'epoch_vt':1,'gamma':0.3,'S0':S1,'S1':None,'S01':S01,'normalize_power':2/3,
        'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD'}]
# Alternative I-NSGD normalization powers:
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':8,'normalize_power':1,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+ r"$\beta = 1$"}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':8,'normalize_power':4/5,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+ r"$\beta = \frac{4}{5}$"}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':8,'normalize_power':7/10,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+ r"$\beta = \frac{7}{10}$"}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':8,'normalize_power':0.65,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+ r"$\beta = \frac{13}{20}$ "}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':4,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':r"${B'}=4$,"+'I-NSGD'}]
# for normalize_power in normalize_power_Spider:
#     hyps+=[{'epoch_vt':100,'gamma':0.01,'S0':m,'S1':50,'normalize_power':normalize_power,'beta':0,\
#             'epoch_momentum':1,'legend':num2str_neat(normalize_power)+'-Spider'}]

# ---- I-NSGD ablation study: independent batch size B' ----
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':4,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+"${B'}=4$"}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':8,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+r"${B'}=8$" }]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':16,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+"${B'}=16$"}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':32,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+"${B'}=32$"}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':64,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':'I-NSGD with '+"${B'}=64$"}]
#hyps+=[{'epoch_vt':1,'gamma':0.5,'S0':S1,'S1':None,'S01':4,'normalize_power':2/3,
        #'grad_max':45.0,'clip_constant':clip_threshold,'beta':0,'epoch_momentum':1,'sampling':True,'legend':r"${B'}=4$,"+'I-NSGD'}]


# One initially empty result slot per configuration, keyed by its string form.
results = {str(hyp): {} for hyp in hyps}

# Fixed seed so ground truth, measurement matrix and noise are reproducible.
np.random.seed(2024)

# Ground-truth signal: d i.i.d. real Gaussian entries with std sqrt(0.5).
z_true = np.random.normal(scale=np.sqrt(0.5), size=d)

# Measurement matrix: m x d, i.i.d. real Gaussian entries with std sqrt(0.5).
A = np.random.normal(scale=np.sqrt(0.5), size=(m, d))

# Noisy squared-magnitude measurements |A z_true|^2 + N(0, y_std^2).
Az_true = A.dot(z_true)
y_true = np.absolute(Az_true) ** 2 + np.random.normal(scale=y_std, size=m)

# Starting point of the deterministic runs: Gaussian entries with std sqrt(6),
# shifted by +1.
z0 = np.random.normal(scale=np.sqrt(6), size=d) + 1
A_H = np.conjugate(A.T)

# Warm start z1 for the stochastic runs: one full-batch iteration of the
# configuration hyps[2], started from z0.
hyp = hyps[2].copy()
z1, z_err_set, obj_set, grad_norm_set, iters_set, complexity_set = Spider_PhaseRetrieval(
    A,
    y_true,
    total_iters=1,
    epoch_vt=hyp['epoch_vt'],
    epoch_momentum=hyp['epoch_momentum'],
    gamma=hyp['gamma'],
    S0=m,
    S1=None,
    S01=None,
    beta=hyp['beta'],
    z0=z0,
    normalize_power=hyp['normalize_power'],
    grad_max=hyp['grad_max'],
    clip_constant=0,
    epoch_eval=1,
    independent_sampling=False,
    print_progress=False,
)

# Run every configuration num_exprs times and store its metric curves in
# results[str(hyp)] as arrays of shape (num_exprs, number_of_evaluations).
for kk in range(num_exprs):
    for hyp in hyps:
        hyp_str = str(hyp)  # key of this configuration in `results`
        method = hyp['legend']
        normalize_power = hyp['normalize_power']
        print("Begin "+str(kk)+"-th experiment, final result: method="+str(method)+", normalize_power="+str(normalize_power))

        # Legends containing "SGD" or "SPIDER" mark stochastic methods: they start
        # from the warm start z1 and run total_iters_stoc iterations; deterministic
        # methods start from z0 and run total_iters_GDs iterations.
        is_stoc = ("SGD" in method) or ("SPIDER" in method)
        z0alg = (z1 if is_stoc else z0).copy()
        T = total_iters_stoc if is_stoc else total_iters_GDs

        # Pass the configuration straight through instead of unpacking it into
        # module-level globals (which used to overwrite S1, S01, beta, ...).
        zt, z_err_set, obj_set, grad_norm_set, iters_set, complexity_set = Spider_PhaseRetrieval(
            A, y_true, T,
            hyp['epoch_vt'], hyp['epoch_momentum'], hyp['gamma'],
            hyp['S0'], hyp['S1'], hyp['S01'], hyp['beta'], z0alg,
            normalize_power, hyp['grad_max'], hyp['clip_constant'],
            epoch_eval, hyp['sampling'], print_progress)

        if kk == 0:
            # Allocate storage on the first repetition; the x-axes (iterations and
            # sample complexity) are taken from that repetition.
            n_eval = len(z_err_set)
            results[hyp_str]['z_err'] = np.zeros((num_exprs, n_eval))
            results[hyp_str]['obj'] = np.zeros((num_exprs, n_eval))
            results[hyp_str]['grad_norm'] = np.zeros((num_exprs, n_eval))
            results[hyp_str]['iters'] = iters_set
            results[hyp_str]['complexity'] = complexity_set

        results[hyp_str]['z_err'][kk, :] = z_err_set
        results[hyp_str]['obj'][kk, :] = obj_set
        results[hyp_str]['grad_norm'][kk, :] = grad_norm_set

xlabels={'iters':'Iteration t','complexity':'Sample Complexity'}
ylabels={'z_err':r'$||z_t-z^*||$','obj':r'$f(z_t)$','grad_norm':r'$||\nabla f(z_t)||$'}

# Common sample budget: the smallest final recorded complexity over ALL
# configurations (deterministic ones included); complexity-axis curves are
# truncated there.
x_max = min((results[str(hyp)]['complexity'][-1] for hyp in hyps), default=np.inf)

# Configure font sizes before any figure exists so every figure (including the
# first one) is created with the same tick/label sizes.
plt.rc('axes', labelsize=label_size)   # fontsize of the x and y labels
plt.rc('xtick', labelsize=num_size)    # fontsize of the tick labels
plt.rc('ytick', labelsize=num_size)    # fontsize of the tick labels

for y_type in ['obj']:
    for x_type in ['complexity','iters']:
        # Complexity axis -> stochastic methods; iteration axis -> deterministic ones.
        opt_type = "stoc" if x_type == "complexity" else "GD"
        plt.figure(figsize=(8,6))
        k = 0  # colour index over the curves actually drawn in this figure
        for hyp in hyps:
            hyp_str = str(hyp)
            is_stoc = ("SGD" in hyp['legend']) or ("SPIDER" in hyp['legend'])
            if opt_type == "GD" and is_stoc:
                continue
            if opt_type == "stoc" and not is_stoc:
                continue
            curves = results[hyp_str][y_type]
            # Skip configurations that diverged (any value above twice the first one).
            if np.any(curves > 2*curves.reshape(-1)[0]):
                continue

            x_plot = np.array(results[hyp_str][x_type])
            upper_loss = np.percentile(curves, percentile, axis=0)
            lower_loss = np.percentile(curves, 100 - percentile, axis=0)
            avg_loss = np.mean(curves, axis=0)
            if x_type == 'complexity':
                keep = (x_plot <= x_max)
                plt.plot(x_plot[keep], avg_loss[keep], linewidth=2.5, color=colors[k], label=hyp['legend'])
                if num_exprs > 1:
                    plt.fill_between(x_plot[keep], lower_loss[keep], upper_loss[keep], color=colors[k], alpha=0.3, edgecolor="none")
            else:
                plt.plot(x_plot, avg_loss, color=colors[k], linewidth=2, label=hyp['legend'])
                if num_exprs > 1:
                    plt.fill_between(x_plot, lower_loss, upper_loss, color=colors[k], alpha=0.3, edgecolor="none")
            k += 1

        plt.legend(prop={'size':lgd_size},
                   loc='upper right', bbox_to_anchor=(1.03, 1.03), ncol=1)
        plt.xlabel(xlabels[x_type])
        plt.ylabel(ylabels[y_type])
        plt.grid(True)
        plt.gcf().subplots_adjust(bottom=bottom_loc)
        plt.gcf().subplots_adjust(left=left_loc)
        # ticklabel_format must run before yscale("log"): it needs the default
        # ScalarFormatter, which the log scale replaces.
        plt.ticklabel_format(axis="x", style="sci", scilimits=(0,0))
        plt.ticklabel_format(axis="y", style="sci", scilimits=(0,0))
        plt.yscale("log")
        plt.savefig(folder+'/'+y_type+'VS'+x_type+'_'+opt_type+'_FinalResults.png',dpi=200)
        plt.close()

# Dump every configuration next to the figures so results can be traced back
# to their settings. The context manager closes the file even if a write fails.
with open(folder+'/hyperparameters.txt', 'w', encoding='utf-8') as hyp_txt:
    for k, hyp in enumerate(hyps):
        hyp_txt.write('Hyperparameter '+str(k)+':\n')
        for hyp_name, hyp_value in hyp.items():
            hyp_txt.write(hyp_name+':'+str(hyp_value)+'\n')
        hyp_txt.write('\n\n')

