import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import math

def activation_function(t, alpha=0.2, l=1):
    """Piecewise-linear (leaky-ReLU style) activation.

    Returns ``l * t`` where ``t >= 0`` and ``alpha * t`` where ``t < 0``.

    Args:
        t: A scalar, or an array-like of pre-activations.
        alpha: Slope applied to negative inputs.
        l: Slope applied to non-negative inputs.

    Returns:
        For a scalar ``t`` the plain product ``l * t`` or ``alpha * t``;
        for array input an ndarray of the same shape, computed element-wise.
    """
    if np.ndim(t) == 0:
        # Scalar path: plain arithmetic so existing scalar callers get the
        # same value and type as before.
        return l * t if t >= 0 else alpha * t
    values = np.asarray(t)
    # Array path: element-wise selection avoids a Python-level loop.
    return np.where(values >= 0, l * values, alpha * values)


def generate_data(sample_number, hidden_neuron_number, dimension):
    """Draw a random two-layer-network problem instance.

    Args:
        sample_number: Number of samples (rows of ``x``, entries of ``y``).
        hidden_neuron_number: Number of hidden units (entries of ``a``).
        dimension: Number of input features (columns of ``x``).

    Returns:
        Tuple ``(a, x, y)`` where ``a`` holds random signs scaled by
        ``1 / hidden_neuron_number``, ``x`` holds zero-mean Gaussian inputs
        with standard deviation ``sqrt(5)``, and ``y`` holds random labels
        in ``{-1, +1}``.
    """
    # The draw order (a, then x, then y) is kept so that a seeded global RNG
    # reproduces the same data.
    output_signs = 2 * np.random.randint(0, 2, size=hidden_neuron_number) - 1
    a = output_signs / hidden_neuron_number
    x = np.random.normal(scale=np.sqrt(5), size=(sample_number, dimension))
    y = 2 * np.random.randint(0, 2, size=sample_number) - 1
    return a, x, y

# Problem size of the synthetic experiment: samples, hidden units, features.
sample_number, hidden_neuron_number, dimension = 500, 30, 10

# One data set (output weights a, inputs x, +/-1 labels y) is drawn once and
# passed to every optimizer run further down in the script.
a, x, y = generate_data(sample_number, hidden_neuron_number, dimension)

def f_out(t):
    """Exponential loss ``exp(-t)`` of a scalar margin ``t``.

    Args:
        t: Scalar margin.

    Returns:
        ``math.exp(-t)`` as a float.

    Raises:
        OverflowError: If ``exp(-t)`` does not fit in a float. The message
            carries the offending ``t`` so the failure is diagnosable,
            instead of printing and returning ``None``.
    """
    try:
        return math.exp(-t)
    except OverflowError as exc:
        raise OverflowError("exp(-t) overflowed for t={!r}".format(t)) from exc

def inner_mul(y, a, w, x, index, alpha=0.2, l=1):
    """Signed network outputs ``y_i * sum_j a_j * sigma(w_j . x_i)``.

    ``sigma`` is the same piecewise-linear activation as
    ``activation_function`` (slope ``l`` for non-negative inputs, ``alpha``
    for negative ones), applied here with array operations instead of a
    per-element Python call.

    Args:
        y: Labels, one per sample.
        a: Output-layer weights, one per hidden unit.
        w: Hidden-layer weights, shape ``(hidden units, features)``.
        x: Inputs, shape ``(samples, features)``.
        index: Iterable of sample indices to evaluate.
        alpha: Activation slope for negative pre-activations.
        l: Activation slope for non-negative pre-activations.

    Returns:
        Array of length ``samples``; entries listed in ``index`` hold the
        signed output, all other entries are 0.
    """
    x = np.asarray(x)
    inner_value = np.zeros(x.shape[0])
    selected = np.fromiter(index, dtype=np.intp)

    # Only the requested samples are pushed through the network.
    pre_activation = np.dot(w, x[selected].T)  # (hidden units, len(index))
    activated = np.where(pre_activation >= 0,
                         l * pre_activation,
                         alpha * pre_activation)
    inner_value[selected] = np.asarray(y)[selected] * np.dot(a, activated)
    return inner_value


def gradient_calculate(a, x, y, w, batch_size=10, alpha=0.2, l=1):
    """Mini-batch gradient of the exponential loss w.r.t. hidden weights.

    A batch of ``batch_size`` distinct samples is drawn from the global NumPy
    RNG. For each sample ``i`` and hidden unit ``j`` the contribution is
    ``-f_out(m_i) * y_i * a_j * s_ji * x_i`` where ``m_i`` is the signed
    network output and ``s_ji`` is ``l`` if ``w_j . x_i >= 0`` else ``alpha``.
    Contributions are averaged over the batch.

    The forward pass uses the same ``alpha``/``l`` as the slopes, so
    non-default values are applied consistently.

    Args:
        a: Output-layer weights, one per hidden unit.
        x: Inputs, shape ``(samples, features)``.
        y: Labels, one per sample.
        w: Hidden-layer weights, shape ``(hidden units, features)``.
        batch_size: Number of samples drawn without replacement.
        alpha: Activation slope for negative pre-activations.
        l: Activation slope for non-negative pre-activations.

    Returns:
        Gradient array of shape ``(hidden units, features)``.

    Raises:
        ValueError: From ``np.random.choice`` when ``batch_size`` exceeds the
            number of samples.
    """
    a = np.asarray(a)
    x = np.asarray(x)
    index = np.random.choice(x.shape[0], batch_size, replace=False)
    x_batch = x[index]
    y_batch = np.asarray(y)[index]

    # One matrix product serves both the forward pass and the slope selection.
    pre_activation = np.dot(w, x_batch.T)  # (hidden units, batch)
    non_negative = pre_activation >= 0
    activated = np.where(non_negative, l * pre_activation, alpha * pre_activation)
    margins = y_batch * np.dot(a, activated)

    # f_out works on scalars, so the per-sample loss factor is built in a
    # short loop over the batch only.
    sample_factor = np.array(
        [f_out(margin) * label for margin, label in zip(margins, y_batch)]
    )

    slopes = np.where(non_negative, l, alpha)
    per_pair = a[:, np.newaxis] * slopes * sample_factor[np.newaxis, :]
    return -np.dot(per_pair, x_batch) / batch_size

def loss_calculate(a, x, y, w):
    """Mean of ``f_out`` over the signed outputs of every sample.

    Args:
        a: Output-layer weights, one per hidden unit.
        x: Inputs, shape ``(samples, features)``.
        y: Labels, one per sample.
        w: Hidden-layer weights, shape ``(hidden units, features)``.

    Returns:
        The average of ``f_out(inner_value[i])`` over all samples.
    """
    sample_count, _ = x.shape
    every_sample = np.arange(0, sample_count, 1)
    margins = inner_mul(y, a, w, x, every_sample)
    # Left-to-right accumulation starting from 0, then a single division.
    return sum(map(f_out, margins)) / sample_count

def SGD(a, x, y, beta, total_iteration):
    """Run mini-batch stochastic gradient descent on the hidden weights.

    Args:
        a: Output-layer weights, one per hidden unit.
        x: Inputs, shape ``(samples, features)``.
        y: Labels, one per sample.
        beta: Step size.
        total_iteration: Number of iterations.

    Returns:
        Array of length ``total_iteration`` with the full-data loss after
        each update. An iteration that fails numerically is skipped and its
        entry stays 0.
    """
    # Size the weights from the data passed in rather than from module
    # globals, so the function works for any (a, x) pair.
    hidden_neuron_number = a.shape[0]
    dimension = x.shape[1]
    w_t = np.random.normal(scale=np.sqrt(5), size=(hidden_neuron_number, dimension)) + 1
    loss_temp = np.zeros(total_iteration)
    for t in range(total_iteration):
        try:
            gradient = gradient_calculate(a, x, y, w_t)
            w_t = w_t - beta * gradient
            loss_temp[t] = loss_calculate(a, x, y, w_t)
        except (ArithmeticError, TypeError):
            # Best effort: a numeric blow-up (overflow in the exponential
            # loss, or a non-numeric loss value reaching the arithmetic)
            # skips this iteration. If the loss evaluation failed, w_t has
            # already been updated. Other errors propagate.
            continue
    return loss_temp



##RSAG
def RSAG_convex(a, x, y, beta, total_iteration):
    """Run the accelerated stochastic gradient variant with growing weights.

    Three iterates are maintained: ``w_t``, the aggregated ``w_t_ag`` and the
    interpolated ``w_t_md`` at which the gradient and the loss are evaluated.
    The interpolation weight is ``A_{t-1} / A_t`` with
    ``A_t = B_t + 1 / beta`` and
    ``B_t = B_{t-1} + 0.5 * (1 + sqrt(4 * B_{t-1} + 1))``.
    NOTE(review): presumably the convex-case RSAG schedule; confirm against
    the reference the script is based on.

    Args:
        a: Output-layer weights, one per hidden unit.
        x: Inputs, shape ``(samples, features)``.
        y: Labels, one per sample.
        beta: Base step size (must be non-zero).
        total_iteration: Number of iterations.

    Returns:
        Array of length ``total_iteration`` with the full-data loss at
        ``w_t_md`` per iteration. An iteration that fails numerically is
        skipped and its entry stays 0.
    """
    # Size the weights from the data passed in rather than from module globals.
    hidden_neuron_number = a.shape[0]
    dimension = x.shape[1]
    w_t = np.random.normal(scale=np.sqrt(5), size=(hidden_neuron_number, dimension)) + 1
    w_t_ag = w_t  # safe alias: both names are only ever rebound, never mutated
    B_t = 0
    A_t_1 = 1 / beta
    loss_temp = np.zeros(total_iteration)
    for t in range(total_iteration):
        try:
            B_t = B_t + 0.5 * (1 + math.sqrt(4 * B_t + 1))
            A_t = B_t + 1 / beta
            mix = A_t_1 / A_t
            w_t_md = mix * w_t_ag + (1 - mix) * w_t
            gradient = gradient_calculate(a, x, y, w_t_md)
            w_t = w_t - 0.25 * (A_t - A_t_1) * beta * gradient
            w_t_ag = w_t_md - beta * gradient
            A_t_1 = A_t
            loss_temp[t] = loss_calculate(a, x, y, w_t_md)
        except (ArithmeticError, TypeError):
            # Best effort: a numeric blow-up (overflow in the exponential
            # loss, or a non-numeric loss value reaching the arithmetic)
            # skips the rest of this iteration; state updated before the
            # failure is kept. Other errors propagate.
            continue
    return loss_temp

def RSAG_nonconvex(a, x, y, eta, total_iteration):
    """Run the accelerated stochastic gradient variant with ``2/(t+2)`` mixing.

    Three iterates are maintained: ``w_t``, the aggregated ``w_t_ag`` and the
    interpolated ``w_t_md = (1 - alpha_t) * w_t_ag + alpha_t * w_t`` with
    ``alpha_t = 2 / (t + 2)``. ``w_t`` moves with step ``eta`` and ``w_t_ag``
    with step ``(1 + alpha_t) * eta``, both along the gradient at ``w_t_md``.

    Args:
        a: Output-layer weights, one per hidden unit.
        x: Inputs, shape ``(samples, features)``.
        y: Labels, one per sample.
        eta: Base step size.
        total_iteration: Number of iterations.

    Returns:
        Array of length ``total_iteration`` with the full-data loss at
        ``w_t_md`` per iteration. An iteration that fails numerically is
        skipped and its entry stays 0.
    """
    # Size the weights from the data passed in rather than from module globals.
    hidden_neuron_number = a.shape[0]
    dimension = x.shape[1]
    w_t = np.random.normal(scale=np.sqrt(5), size=(hidden_neuron_number, dimension)) + 1
    w_t_ag = w_t  # safe alias: both names are only ever rebound, never mutated
    loss_temp = np.zeros(total_iteration)
    for t in range(total_iteration):
        try:
            alpha_t = 2 / (t + 2)
            beta_t = (1 + alpha_t) * eta
            w_t_md = (1 - alpha_t) * w_t_ag + alpha_t * w_t
            gradient = gradient_calculate(a, x, y, w_t_md)
            w_t = w_t - eta * gradient
            w_t_ag = w_t_md - beta_t * gradient
            loss_temp[t] = loss_calculate(a, x, y, w_t_md)
        except (ArithmeticError, TypeError):
            # Best effort: a numeric blow-up (overflow in the exponential
            # loss, or a non-numeric loss value reaching the arithmetic)
            # skips the rest of this iteration; state updated before the
            # failure is kept. Other errors propagate.
            continue
    return loss_temp

def algorithm_mean(algorithm, eta, a, x, y, times, total_iteration):
    """Repeat one optimizer several times and collect every loss curve.

    Despite the name, no averaging happens here; the caller reduces over the
    first axis of the returned matrix.

    Args:
        algorithm: One of ``"SGD"``, ``"RSAG_nonconvex"``, ``"RSAG_convex"``.
        eta: Step size handed to the optimizer.
        a: Output-layer weights, one per hidden unit.
        x: Inputs, shape ``(samples, features)``.
        y: Labels, one per sample.
        times: Number of independent runs.
        total_iteration: Iterations per run.

    Returns:
        Array of shape ``(times, total_iteration)``; row ``i`` is the loss
        curve of run ``i``.

    Raises:
        ValueError: If ``algorithm`` is not a known name (previously a matrix
            of zeros was returned silently).
    """
    if algorithm == "SGD":
        run = SGD
    elif algorithm == "RSAG_nonconvex":
        run = RSAG_nonconvex
    elif algorithm == "RSAG_convex":
        run = RSAG_convex
    else:
        raise ValueError("unknown algorithm: {!r}".format(algorithm))

    loss = np.zeros((times, total_iteration))
    for i in range(times):
        # Write exactly row i; the old ``loss[i:]`` slice rewrote every
        # remaining row on each run.
        loss[i, :] = run(a, x, y, eta, total_iteration)
    return loss

# Experiment configuration.
begin = 0
total_iteration = 250          # iterations per optimizer run
times = 100                    # independent runs per (algorithm, step size)
eta = [0.0005, 0.005]          # step sizes to compare
color = ['blue', 'green', 'red']
linestyle = ['-', 'dashed']    # one line style per entry of eta

# One pass per step size; step sizes and line styles are paired by position.
for step_size, line_style in zip(eta, linestyle):
    loss_SGD = algorithm_mean("SGD", step_size, a, x, y, times, total_iteration)
    loss_RSAG_convex = algorithm_mean("RSAG_convex", step_size, a, x, y, times, total_iteration)
    # RSAG_nonconvex is not plotted, so its (expensive) runs are not executed.

    # Mean loss curve over the runs, plus a one-standard-deviation band that
    # can be passed to plt.fill_between if shaded bands are wanted.
    avg_SGD = np.mean(loss_SGD, axis=0)
    std_SGD = np.std(loss_SGD, axis=0)
    high_SGD = avg_SGD + std_SGD
    low_SGD = avg_SGD - std_SGD

    avg_RSAG_convex = np.mean(loss_RSAG_convex, axis=0)
    std_RSAG_convex = np.std(loss_RSAG_convex, axis=0)
    high_RSAG_convex = avg_RSAG_convex + std_RSAG_convex
    low_RSAG_convex = avg_RSAG_convex - std_RSAG_convex

    plt.plot(avg_SGD, color=color[0], linestyle=line_style,
             label=r'SGD, $\beta={}$'.format(step_size))
    plt.plot(avg_RSAG_convex, color=color[2], linestyle=line_style,
             label=r'SNAG (Alg 2), $\beta={}$'.format(step_size))

plt.xlabel("Iteration t")
plt.ylabel("Loss")

plt.title("Convergence behavior for two-layer neural network")
plt.legend()
plt.show()