import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import math

def generate_data(sample_number, dimension):
    """Draw a synthetic data set of noisy squared-magnitude measurements.

    The ground-truth vector and the measurement matrix are sampled from a
    zero-mean normal distribution with variance 0.5. Each observation is the
    squared magnitude of the matching projection plus zero-mean normal noise
    with standard deviation 5.

    Args:
        sample_number: Number of measurements (rows of the matrix).
        dimension: Length of the ground-truth vector.

    Returns:
        Tuple ``(a, y, z_true)``: the ``(sample_number, dimension)``
        measurement matrix, the ``sample_number`` noisy observations and the
        ground-truth vector of length ``dimension``.
    """
    sigma = np.sqrt(0.5)
    # The draw order (truth, matrix, noise) is kept fixed so that seeding the
    # global NumPy RNG reproduces the same data set.
    z_true = np.random.normal(scale=sigma, size=dimension)
    a = np.random.normal(scale=sigma, size=(sample_number, dimension))
    clean = np.absolute(a.dot(z_true)) ** 2
    noise = np.random.normal(scale=5, size=sample_number)
    y = clean + noise
    return a, y, z_true

# Problem size: number of measurements and length of the unknown vector.
sample_number, dimension = 3000, 10
# A single data set is generated once at import time and shared by the
# experiments further down.
a, y, z_true = generate_data(sample_number, dimension)


def gradient_calculate(a,y,z_t1,batch_size=10):
    """Return a mini-batch gradient estimate at ``z_t1``.

    A batch of ``batch_size`` distinct rows is drawn uniformly at random and
    the result is the batch average of ``a_i * (a_i.z) * (|a_i.z|**2 - y_i)``.
    For real-valued inputs this is the gradient of the batch mean of
    ``(y_i - (a_i.z)**2)**2 / 4``.

    Args:
        a: 2-D measurement matrix, one sample per row.
        y: Observations, one per row of ``a``.
        z_t1: Point at which the gradient is evaluated.
        batch_size: Number of rows sampled without replacement.

    Returns:
        1-D array with one entry per column of ``a``.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``batch_size``
            exceeds the number of rows of ``a``.
    """
    sample_number = a.shape[0]
    index = np.random.choice(sample_number, batch_size, replace=False)
    # Only project the sampled rows; multiplying the full matrix and then
    # discarding all but ``batch_size`` entries wastes O(n * d) work per step.
    a_batch = a[index]
    Az = a_batch.dot(z_t1)
    residual = np.absolute(Az) ** 2 - y[index]
    gradient = a_batch.T.dot(Az * residual) / batch_size
    return gradient

def loss_calculate(a,y,z):
    """Return the quartic fitting loss of ``z`` over the whole data set.

    The loss is the mean over all samples of ``(y_i - (a_i.z)**2)**2``,
    divided by 4.

    Args:
        a: Measurement matrix, one sample per row.
        y: Observations, one per row of ``a``.
        z: Candidate solution vector.

    Returns:
        The scalar loss value.
    """
    predicted = a.dot(z) ** 2
    squared_error = (y - predicted) ** 2
    return squared_error.mean() / 4

def SGD(a,y, eta, total_iteration):
    """Run mini-batch stochastic gradient descent and record the loss curve.

    The starting point is drawn from a normal distribution with standard
    deviation ``sqrt(5)`` and shifted by 1. Every iteration takes one step of
    size ``eta`` along the negative mini-batch gradient and then evaluates the
    loss on the full data set.

    Args:
        a: 2-D measurement matrix, one sample per row.
        y: Observations, one per row of ``a``.
        eta: Step size.
        total_iteration: Number of update steps.

    Returns:
        1-D array of length ``total_iteration`` holding the loss after each
        update.
    """
    # Size the iterate from the data itself rather than from a module-level
    # ``dimension`` global, so the routine works for any input matrix.
    z = np.random.normal(scale=np.sqrt(5), size=a.shape[1]) + 1
    loss_temp = np.zeros(total_iteration)
    for i in range(total_iteration):
        gradient = gradient_calculate(a, y, z)
        z = z - eta * gradient
        loss_temp[i] = loss_calculate(a, y, z)
    return loss_temp

##RSAG
def RSAG_convex(a,y, beta, total_iteration):
    """Run the accelerated stochastic gradient scheme with growing weights.

    Three sequences are maintained: the iterate ``z_t``, the aggregated point
    ``z_t_ag`` and their weighted combination ``z_t_md``, at which the
    mini-batch gradient is evaluated. The combination weights come from the
    scalar recursion ``B_t = B_t + 0.5 * (1 + sqrt(4 * B_t + 1))`` with
    ``A_t = B_t + 1 / beta``.

    Args:
        a: 2-D measurement matrix, one sample per row.
        y: Observations, one per row of ``a``.
        beta: Base step size; must be non-zero because ``1 / beta`` is used.
        total_iteration: Number of update steps.

    Returns:
        1-D array of length ``total_iteration`` with the full-data loss at
        ``z_t_md`` for every iteration.
    """
    # Size the iterate from the data itself rather than from a module-level
    # ``dimension`` global, so the routine works for any input matrix.
    z_t = np.random.normal(scale=np.sqrt(5), size=a.shape[1]) + 1
    # Sharing the array is safe: both names are only ever rebound below,
    # never modified in place.
    z_t_ag = z_t
    B_t = 0
    A_t_1 = 1 / beta  # previous value of A_t
    loss_temp = np.zeros(total_iteration)
    for i in range(total_iteration):
        B_t = B_t + 0.5 * (1 + math.sqrt(4 * B_t + 1))
        A_t = B_t + 1 / beta
        z_t_md = A_t_1 / A_t * z_t_ag + (1 - A_t_1 / A_t) * z_t
        gradient = gradient_calculate(a, y, z_t_md)
        z_t = z_t - (0.25 * (A_t - A_t_1)) * beta * gradient
        z_t_ag = z_t_md - beta * gradient
        A_t_1 = A_t
        # The recorded loss is taken at the gradient evaluation point.
        loss_temp[i] = loss_calculate(a, y, z_t_md)
    return loss_temp

def RSAG_nonconvex(a,y, eta, total_iteration):
    """Run the accelerated stochastic gradient scheme with ``2/(i+2)`` mixing.

    At iteration ``i`` the gradient is evaluated at
    ``z_t_md = (1 - alpha_t) * z_t_ag + alpha_t * z_t`` with
    ``alpha_t = 2 / (i + 2)``. The iterate ``z_t`` then moves with step
    ``eta`` and the aggregated point ``z_t_ag`` moves from ``z_t_md`` with
    step ``(1 + alpha_t) * eta``.

    Args:
        a: 2-D measurement matrix, one sample per row.
        y: Observations, one per row of ``a``.
        eta: Base step size.
        total_iteration: Number of update steps.

    Returns:
        1-D array of length ``total_iteration`` with the full-data loss at
        ``z_t_md`` for every iteration.
    """
    # Size the iterate from the data itself rather than from a module-level
    # ``dimension`` global, so the routine works for any input matrix.
    z_t = np.random.normal(scale=np.sqrt(5), size=a.shape[1]) + 1
    # Sharing the array is safe: both names are only ever rebound below,
    # never modified in place.
    z_t_ag = z_t
    loss_temp = np.zeros(total_iteration)
    for i in range(total_iteration):
        alpha_t = 2 / (i + 2)
        beta_t = (1 + alpha_t) * eta
        z_t_md = (1 - alpha_t) * z_t_ag + alpha_t * z_t
        gradient = gradient_calculate(a, y, z_t_md)
        z_t = z_t - eta * gradient
        z_t_ag = z_t_md - beta_t * gradient
        # The recorded loss is taken at the gradient evaluation point.
        loss_temp[i] = loss_calculate(a, y, z_t_md)
    return loss_temp

def algorithm_mean(algorithm,eta,a,y,times, total_iteration):
    """Repeat one optimizer ``times`` times and collect every loss curve.

    Despite the name, no averaging happens here: the caller receives the raw
    per-run curves and reduces them itself.

    Args:
        algorithm: One of ``"SGD"``, ``"RSAG_nonconvex"`` or
            ``"RSAG_convex"``.
        eta: Step size forwarded to the optimizer.
        a: Measurement matrix forwarded to the optimizer.
        y: Observations forwarded to the optimizer.
        times: Number of independent runs.
        total_iteration: Number of iterations per run.

    Returns:
        Array of shape ``(times, total_iteration)``; row ``i`` is the loss
        curve of run ``i``.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names
            (previously an all-zero array was returned silently).
    """
    if algorithm == "SGD":
        run = SGD
    elif algorithm == "RSAG_nonconvex":
        run = RSAG_nonconvex
    elif algorithm == "RSAG_convex":
        run = RSAG_convex
    else:
        raise ValueError("unknown algorithm: {!r}".format(algorithm))

    loss = np.zeros((times, total_iteration))
    for i in range(times):
        # Write exactly one row per run; ``loss[i:]`` would broadcast the
        # curve over every remaining row on each pass.
        loss[i] = run(a, y, eta, total_iteration)
    return loss

# Experiment configuration.
begin = 0  # only referenced by the disabled plotting code
total_iteration, times = 250, 100  # iterations per run, runs per setting
eta = [0.0001, 0.0005]  # step sizes to compare
# Plot styling: one colour per algorithm, one line style per step size.
color = ['blue', 'green', 'red']
linestyle = ['-', 'dashed']

# NOTE: the bare string literal below is disabled code, not a docstring. It is
# evaluated and discarded, so none of these lists is ever created; the loop
# further down assigns the SGD and RSAG_convex statistics directly.
'''avg_SGD=[]
std_SGD=[]
high_SGD=[]
low_SGD=[]
avg_RSAG_convex=[]
std_RSAG_convex=[]
high_RSAG_convex=[]
low_RSAG_convex=[]
avg_RSAG_nonconvex=[]
std_RSAG_nonconvex=[]
high_RSAG_nonconvex=[]
low_RSAG_nonconvex=[]'''

# For every step size, run both optimizers `times` times and overlay the mean
# loss curves. The line style identifies the step size and the colour
# identifies the algorithm.
for i, step in enumerate(eta):
    # Each result has one row per run and one column per iteration.
    loss_SGD = algorithm_mean("SGD", step, a, y, times, total_iteration)
    loss_RSAG_convex = algorithm_mean("RSAG_convex", step, a, y, times, total_iteration)
    #loss_RSAG_nonconvex=algorithm_mean("RSAG_nonconvex",step,a,y,times,total_iteration)

    # Per-iteration mean and one-standard-deviation band across runs.
    avg_SGD = np.mean(loss_SGD, axis=0)
    std_SGD = np.std(loss_SGD, axis=0)
    high_SGD = avg_SGD + std_SGD
    low_SGD = avg_SGD - std_SGD

    avg_RSAG_convex = np.mean(loss_RSAG_convex, axis=0)
    std_RSAG_convex = np.std(loss_RSAG_convex, axis=0)
    high_RSAG_convex = avg_RSAG_convex + std_RSAG_convex
    low_RSAG_convex = avg_RSAG_convex - std_RSAG_convex

    #avg_RSAG_nonconvex=np.mean(loss_RSAG_nonconvex,axis=0)
    #std_RSAG_nonconvex=np.std(loss_RSAG_nonconvex,axis=0)
    #high_RSAG_nonconvex=avg_RSAG_nonconvex+std_RSAG_nonconvex
    #low_RSAG_nonconvex=avg_RSAG_nonconvex-std_RSAG_nonconvex

    plt.plot(avg_SGD, color=color[0], linestyle=linestyle[i],
             label=r'SGD, $\beta={}$'.format(step))
    plt.plot(avg_RSAG_convex, color=color[2], linestyle=linestyle[i],
             label=r'SNAG (Alg 2), $\beta={}$'.format(step))
    #plt.plot(avg_RSAG_nonconvex, color=color[1],linestyle=linestyle[i],label='stochastic AGD (Alg 3),' r' $\eta={}$'.format(step))

    # The bands are computed above but shading is currently switched off.
    #plt.fill_between(np.arange(total_iteration-begin),low_SGD,high_SGD, color=color[0],alpha=0.2)
    #plt.fill_between(np.arange(total_iteration-begin),low_RSAG_convex,high_RSAG_convex, color=color[2],alpha=0.2)
    #plt.fill_between(np.arange(total_iteration-begin),low_RSAG_nonconvex,high_RSAG_nonconvex, color=color[2],alpha=0.2)

plt.xlabel("Iteration t")
plt.ylabel("Loss")

plt.title("Convergence behavior for the phase retrieval problem")
plt.legend()
plt.show()


# NOTE: the bare string literal below is disabled legacy plotting code and is
# never executed. It refers to loss_mean_SGD / loss_mean_RSAG, which are not
# defined anywhere in the visible code.
'''plt.plot(loss_mean_SGD[begin:])
plt.plot(loss_mean_RSAG[begin:],color='red')
plt.title('Loss Function Convergence')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.show()

plt.savefig('RSAG and SGD.png')'''