import numpy as np
import math
import sys

# ---------------------------------------------------------------------------
# Problem sizes (module-level globals read by the functions below).
# ---------------------------------------------------------------------------
dim = 30    # rows of each A_i / A0, i.e. length of each B_i / B0
xdim = 20   # dimension of the lower-level variable w
N = 15      # number of clients
rate = 1    # weight of the cosine term rate*cos(a_i^T w - b_i)

# ---------------------------------------------------------------------------
# Command-line configuration:  <script> NOISE JOIN_P LOCAL_STEPS
# ---------------------------------------------------------------------------
if len(sys.argv) < 4:
    sys.exit("usage: {} NOISE JOIN_P LOCAL_STEPS".format(sys.argv[0]))

noise = float(sys.argv[1])   # scale of the Gaussian noise added to gradients / Hessians
join_p = float(sys.argv[2])  # client i counts as active when its uniform draw is >= join_p
# np.random.random draws from [0, 1), so join_p >= 1 would make every
# "resample until some client is active" loop spin forever.
if join_p >= 1:
    sys.exit("JOIN_P must be < 1 (got {})".format(join_p))

total_step = 2001  # outer iterations run by every method

local_steps = int(sys.argv[3])  # inner lower-level / Neumann steps per outer iteration
# PD_x updates x with the participation mask sampled in its last local step,
# so at least one local step is required.
if local_steps < 1:
    sys.exit("LOCAL_STEPS must be >= 1 (got {})".format(local_steps))

r_step = 40  # print progress every r_step outer iterations
def calc_gradient_w(p,w,A,B,a,b,active):
    """Weighted, participation-rescaled gradient of the lower-level objective.

    Sums ``p[k] * grad_w f_k(w)`` over the clients whose ``active[k]`` is at
    least ``join_p``. Here ``grad_w f_k(w) = A_k^T (A_k w - B_k)
    - rate * sin(a_k^T w - b_k) * a_k``. The sum is then multiplied by
    ``N / (number of active clients)``.

    Raises ZeroDivisionError when no client is active.
    """
    chosen = [k for k in range(N) if active[k] >= join_p]
    acc = 0
    for k in chosen:
        Ak = A[k, :, :].reshape(dim, xdim)
        g = Ak.T @ (Ak @ w) - Ak.T @ B[k, :].reshape(dim)
        ak = a[k, :].reshape(xdim)
        g -= rate * np.sin(ak.T @ w - b[k]) * ak
        acc = acc + p[k] * g
    return acc * N / len(chosen)
    
def calc_gradient_wi(w,A,B,a,b,i):
    """Gradient w.r.t. ``w`` of client ``i``'s lower-level loss.

    The loss is ``0.5*||A_i w - B_i||^2 + rate*cos(a_i^T w - b_i)``, so this
    returns ``A_i^T A_i w - A_i^T B_i - rate*sin(a_i^T w - b_i) * a_i``.
    """
    Ai = A[i, :, :].reshape(dim, xdim)
    ai = a[i, :].reshape(xdim)
    grad = Ai.T @ (Ai @ w) - Ai.T @ B[i, :].reshape(dim)
    grad -= rate * np.sin(ai.T @ w - b[i]) * ai
    return grad
    
def calc_gradient_w_0(w,A0,B0,a0,b0):
    """Gradient w.r.t. ``w`` of the upper-level objective.

    ``F(w) = 0.5*||A0 w - B0||^2 + rate*cos(a0^T w - b0)``, so the gradient is
    ``A0^T A0 w - A0^T B0 - rate*sin(a0^T w - b0) * a0``. If ``b0`` is a
    length-1 array, the sine factor broadcasts over ``a0``.
    """
    quadratic_part = A0.T @ (A0 @ w) - A0.T @ B0
    phase = a0.T @ w - b0
    return quadratic_part - rate * np.sin(phase) * a0

def calc_hessian_w(p,w,A,B,a,b,A0,B0,a0,b0,active):
    """Return the weighted Hessian (w.r.t. ``w``) of the active clients' losses.

    Client i's Hessian is ``A_i^T A_i - rate*cos(a_i^T w - b_i) * a_i a_i^T``.
    The result is ``sum_i p[i] * H_i`` over clients with
    ``active[i] >= join_p``, multiplied by ``N / (number of active clients)``
    (the same rescaling as ``calc_gradient_w``).

    ``B``, ``A0``, ``B0``, ``a0`` and ``b0`` are unused; they appear to be kept
    so the signature matches ``calc_hessian_w_v`` / ``calc_hessian_w_V``.

    Raises ZeroDivisionError if no client is active.
    """
    SO = 0
    total = 0
    for i in range(N):
        if active[i] >= join_p:
            total = total + 1
            AA = A[i, :, :].reshape(dim, xdim)
            aa = a[i, :].reshape(xdim, 1)
            # cos(...) has shape (1,) and broadcasts over the outer product.
            tmp = AA.T @ AA - rate * np.cos(aa.T @ w - b[i]) * (aa @ aa.T)
            SO = SO + p[i] * tmp
    return SO * N / total
        
def calc_hessian_w_V(p,w,A,B,a,b,A0,B0,a0,b0,active,v,Noise):
    """Return the noise-perturbed weighted Hessian applied to ``v``.

    ``v`` is only multiplied from the left, so it may be a length-``xdim``
    vector or an ``(xdim, k)`` matrix; ``unrolling_gd`` passes an
    ``(xdim, N)`` array. Each active client (``active[i] >= join_p``)
    contributes ``p[i] * (H_i v + Noise[i] v)``. Here
    ``H_i = A_i^T A_i - rate*cos(a_i^T w - b_i) * a_i a_i^T`` and
    ``Noise[i]`` is an ``(xdim, xdim)`` perturbation. The sum is multiplied by
    ``N / (number of active clients)``.

    ``B``, ``A0``, ``B0``, ``a0`` and ``b0`` are unused.
    Raises ZeroDivisionError if no client is active.
    """
    SO = 0
    total = 0
    for i in range(N):
        if active[i] >= join_p:
            total = total + 1
            AA = A[i, :, :].reshape(dim, xdim)
            aa = a[i, :].reshape(xdim, 1)
            tmp = AA.T @ (AA @ v)
            # Apply the rank-one curvature term as aa @ (aa^T v) without forming aa aa^T.
            tmp = tmp - rate * np.cos(aa.T @ w - b[i]) * (aa @ (aa.T @ v))
            SO = SO + p[i] * tmp + p[i] * (Noise[i, :, :].reshape(xdim, xdim) @ v)
    return SO * N / total
        
def calc_hessian_w_v(p,w,A,B,a,b,A0,B0,a0,b0,active,v,Noise):
    """Return the noise-perturbed weighted Hessian-vector product ``H v``.

    Vector-only counterpart of ``calc_hessian_w_V``: ``v`` must be a 1-D
    vector of length ``xdim`` (the rank-one term uses ``sum(a_i * v)``).
    Each active client (``active[i] >= join_p``) contributes
    ``p[i] * (H_i v + Noise[i] v)``. Here
    ``H_i = A_i^T A_i - rate*cos(a_i^T w - b_i) * a_i a_i^T``. The sum is
    multiplied by ``N / (number of active clients)``.

    ``B``, ``A0``, ``B0``, ``a0`` and ``b0`` are unused.
    Raises ZeroDivisionError if no client is active.
    """
    SO = 0
    total = 0
    for i in range(N):
        if active[i] >= join_p:
            total = total + 1
            AA = A[i, :, :].reshape(dim, xdim)
            aa = a[i, :].reshape(xdim)
            tmp = AA.T @ (AA @ v)
            tmp = tmp - rate * np.cos(np.sum(aa * w) - b[i]) * np.sum(aa * v) * aa
            SO = SO + p[i] * tmp + p[i] * (Noise[i, :, :].reshape(xdim, xdim) @ v)
    return SO * N / total


def proj_to_one(x):
    """Euclidean projection of a 1-D array ``x`` onto the probability simplex.

    Returns the point ``y`` closest to ``x`` in the 2-norm with ``y >= 0`` and
    ``sum(y) == 1``. It uses the sort-based threshold rule:
    1. Sort ``x`` decreasingly as ``u``.
    2. Let ``rho`` be the largest 1-based index with
       ``u_rho + (1 - sum_{j<=rho} u_j) / rho > 0``.
    3. Set ``theta = (1 - sum_{j<=rho} u_j) / rho`` and return
       ``max(x + theta, 0)``.

    Works for any length, not only the module-level ``N``. ``x`` is not
    modified; a new float array is returned. An empty input yields an empty
    array.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x.copy()
    u = np.sort(x)[::-1]
    css = np.cumsum(u)
    shifts = (1 - css) / np.arange(1, u.size + 1)
    candidates = np.nonzero(u + shifts > 0)[0]
    # For finite input the first index always qualifies (u_1 + 1 - u_1 = 1 > 0).
    # The fallback of 0 only matters for non-finite input.
    theta = shifts[candidates[-1]] if candidates.size else 0.0
    y = x + theta
    y[y < 0] = 0
    return y


def calc_function_value(x,A,B,a,b,A0,B0,a0,b0):
    """Approximately solve the lower level for weights ``x`` and score it.

    Runs 200 plain gradient steps from ``w = 0`` on the ``x``-weighted
    lower-level objective, with step size 0.01 and an all-ones active mask.
    It then evaluates the upper-level objective
    ``F(w) = 0.5*||A0 w - B0||^2 + rate*cos(a0^T w - b0)``.

    Returns ``(w, F(w))``; ``F(w)`` has the shape of ``b0``.
    """
    everyone = np.ones(N)
    w = np.zeros(xdim)
    for _ in range(200):
        w = w - 0.01 * calc_gradient_w(x, w, A, B, a, b, everyone)
    residual = A0 @ w - B0
    value = 0.5 * np.sum(residual * residual) + rate * np.cos(a0.T @ w - b0)
    return w, value

def calc_function(x,A,B,a,b,A0,B0,a0,b0):
    """Deterministic progress measure used by every driver's periodic print.

    Steps:
    1. Approximate ``w*(x)`` with 100 noise-free gradient steps (step 0.01,
       all-ones active mask) from ``w = 0``.
    2. Approximate ``-H^{-1} grad F(w)`` with a 100-term Neumann series
       (step 0.01, zero Hessian noise).
    3. Form the per-client hypergradient ``sum(direction * grad f_i(w))``.

    Returns ``(x - proj_to_one(x - 0.01 * hypergrad), F(w))``: the
    projected-gradient mapping and the upper-level value.
    """
    everyone = np.ones(N)
    w = np.zeros(xdim)
    for _ in range(100):
        w = w - 0.01 * calc_gradient_w(x, w, A, B, a, b, everyone)

    residual = A0 @ w - B0
    value = 0.5 * np.sum(residual * residual) + rate * np.cos(a0.T @ w - b0)

    # Truncated Neumann series: series = sum_k (I - 0.01 H)^k grad F, so
    # -0.01 * series approximates -H^{-1} grad F.  (An exact alternative is
    # -inv(calc_hessian_w(...)) @ grad F.)
    term = calc_gradient_w_0(w, A0, B0, a0, b0)
    series = 0
    zero_noise = np.zeros((N, xdim, xdim))
    for _ in range(100):
        series = series + term
        term = term - 0.01 * calc_hessian_w_v(x, w, A, B, a, b, A0, B0, a0, b0, everyone, term, zero_noise)
    direction = -series * 0.01

    stepped = np.copy(x)
    for k in range(N):
        stepped[k] = stepped[k] - 0.01 * np.sum(direction * calc_gradient_wi(w, A, B, a, b, k))
    return x - proj_to_one(stepped), value
    
def GD_x(A,B,a,b,A0,B0,a0,b0,eta_w,eta_x):
    """Baseline that inverts a noisy Hessian explicitly at every outer step.

    Each outer step does the following:
    1. Take 30 noisy gradient steps on ``w``. The mask starts as all ones, so
       every client counts as active whenever ``join_p <= 1``.
    2. Form ``direction = -(H + noise*E)^{-1} grad F(w)`` with a
       Gaussian-perturbed Hessian.
    3. Update each active client's weight ``x[k]`` along
       ``sum(direction * noisy grad f_k(w))``.
    4. Project ``x`` back onto the simplex.

    Progress is printed every ``r_step`` outer steps.

    Returns ``(x, calc_function_value(x, ...))``.
    """
    weights = np.ones(N) / N
    w = np.zeros(xdim)
    largest_sq_norm = 0  # debugging statistic only; never reported
    for step in range(total_step):
        if step % r_step == 0:
            mapping, value = calc_function(weights,A,B,a,b,A0,B0,a0,b0)
            print("GD_x", step, np.sum(mapping * mapping), value[0])

        mask = np.ones(N)
        while np.max(mask) < join_p:
            mask = np.random.random(N)

        for _ in range(30):
            noisy_grad = calc_gradient_w(weights,w,A,B,a,b,mask) + noise * np.random.randn(xdim)
            w = w - eta_w * noisy_grad

        hessian = calc_hessian_w(weights,w,A,B,a,b,A0,B0,a0,b0,mask)
        inv_hessian = np.linalg.inv(hessian + noise * np.random.randn(xdim, xdim))
        direction = -(inv_hessian @ calc_gradient_w_0(w,A0,B0,a0,b0))
        largest_sq_norm = max(largest_sq_norm, np.sum(direction * direction))

        for k in range(N):
            if mask[k] >= join_p:
                client_grad = calc_gradient_wi(w,A,B,a,b,k) + noise * np.random.randn(xdim)
                weights[k] = weights[k] - eta_x * np.sum(direction * client_grad)
        weights = proj_to_one(weights)

    return weights, calc_function_value(weights,A,B,a,b,A0,B0,a0,b0)


def BSA_x(A,B,a,b,A0,B0,a0,b0, eta_w, eta_x, eta_SO):
    """Stochastic method with a Neumann-series Hessian-inverse approximation.

    Each outer step does the following:
    1. Sample a participation mask.
    2. Take ``local_steps`` noisy gradient steps on ``w``.
    3. Approximate ``-H^{-1} grad F(w)`` as
       ``-eta_SO * sum_k term_k``, where
       ``term_{k+1} = term_k - eta_SO * H_k term_k``. Each ``H_k`` uses a
       freshly sampled mask and one Hessian-noise draw shared by the whole
       series.
    4. Update the active clients' weights and project onto the simplex.

    Progress is printed every ``r_step`` outer steps.

    Returns ``(x, calc_function_value(x, ...))``.
    """
    def sample_active():
        # Rejection-sample participation scores until at least one client
        # reaches the join_p threshold.
        scores = np.random.random(N)
        while np.max(scores) < join_p:
            scores = np.random.random(N)
        return scores

    weights = np.ones(N) / N
    w = np.zeros(xdim)
    for step in range(total_step):
        if step % r_step == 0:
            mapping, value = calc_function(weights,A,B,a,b,A0,B0,a0,b0)
            print("BSA_x", step, np.sum(mapping * mapping), value[0])

        mask = sample_active()
        for _ in range(local_steps):
            w = w - eta_w * (calc_gradient_w(weights,w,A,B,a,b,mask) + noise * np.random.randn(xdim))

        term = calc_gradient_w_0(w,A0,B0,a0,b0) + noise * np.random.randn(xdim)
        series = 0
        hess_noise = np.random.randn(N, xdim, xdim) * noise
        for _ in range(local_steps):
            hvp_mask = sample_active()
            series = series + term
            term = term - eta_SO * calc_hessian_w_v(weights,w,A,B,a,b,A0,B0,a0,b0,hvp_mask,term,hess_noise)
        direction = -series * eta_SO

        for k in range(N):
            if mask[k] >= join_p:
                client_grad = calc_gradient_wi(w,A,B,a,b,k) + noise * np.random.randn(xdim)
                weights[k] = weights[k] - eta_x * np.sum(direction * client_grad)

        weights = proj_to_one(weights)

    return weights, calc_function_value(weights,A,B,a,b,A0,B0,a0,b0)


def double_momentum(A,B,a,b,A0,B0,a0,b0,mw,mx,eta_w,eta_x,eta_SO):
    """Momentum variant with recursive estimators for both ``w`` and ``x``.

    ``hw`` tracks the lower-level gradient and ``hx`` the per-client
    hypergradient. On the first iteration both are plain stochastic
    estimates. Afterwards each uses the recursive (STORM-style) update
    ``h = g(current) + (1 - m) * (h - g(previous))``, with
    ``m = mw`` / ``mx``. The previous iterate is ``(x0, w0)``, and both
    evaluations share the same participation mask.

    The hypergradient direction ``-H^{-1} grad F`` is approximated by the same
    truncated Neumann series used in ``BSA_x``, run at both the current and
    the previous iterate. Progress is printed every ``r_step`` outer steps.

    Returns ``(x, calc_function_value(x, ...))``.
    """
    x = np.ones(N)/N
    w = np.zeros(xdim)
    hw = np.zeros(xdim)
    hx = np.zeros(N)

    for iteration in range(total_step):
        if (iteration%r_step==0):
            grad,val = calc_function(x,A,B,a,b,A0,B0,a0,b0)
            print("double_momentum", iteration, np.sum(grad*grad),val[0])
        # Rejection-sample a mask with at least one client >= join_p.
        active = np.random.random(N)
        while np.max(active)<join_p:
            active = np.random.random(N)

        if (iteration==0):
            # No previous iterate yet (x0/w0 are first bound at the end of
            # this iteration), so start from plain stochastic estimates.
            hw = calc_gradient_w(x,w,A,B,a,b,active) + noise*np.random.randn(xdim)
            hx = np.zeros(N)
            tmp0 = calc_gradient_w_0(w,A0,B0,a0,b0) + noise*np.random.randn(xdim)
            tmp1 = 0
            Noise = np.random.randn(N,xdim,xdim)*noise
            # Neumann series: tmp1 accumulates (I - eta_SO*H)^k applied to grad F.
            for i in range(local_steps):
                active1 = np.random.random(N)
                while np.max(active1)<join_p:
                    active1 = np.random.random(N)

                tmp1 = tmp1 + tmp0
                tmp0 = tmp0 - eta_SO* calc_hessian_w_v(x,w,A,B,a,b,A0,B0,a0,b0,active1,tmp0,Noise)
            # tmp1 now approximates -H^{-1} grad F.
            tmp1 = -tmp1 *eta_SO   
            # Only active clients get a hypergradient estimate; the rest stay 0.
            for i in range(N):
                if (active[i]>=join_p):
                    hx[i] = np.sum(tmp1*(calc_gradient_wi(w,A,B,a,b,i) + noise*np.random.randn(xdim)))

        else:
            # Recursive momentum for w: both gradient evaluations use the same
            # mask `active`; the injected noise is scaled by mw.
            hw = calc_gradient_w(x,w,A,B,a,b,active) + (1-mw)*(hw - calc_gradient_w(x0,w0,A,B,a,b,active)) + mw*noise*np.random.randn(xdim)
            tmphx1 = np.zeros(N)
            tmphx2 = np.zeros(N)
            # NOTE(review): tmp0 and tmp02 receive independent noise draws,
            # whereas the per-client noise below (tmpnoise) is shared between
            # the two iterates — confirm this asymmetry is intended.
            tmp0 = calc_gradient_w_0(w,A0,B0,a0,b0) + noise*np.random.randn(xdim)
            tmp02 = calc_gradient_w_0(w0,A0,B0,a0,b0) + noise*np.random.randn(xdim)
            tmp1 = 0
            tmp12 = 0
            Noise = np.random.randn(N,xdim,xdim)*noise
            # Run the Neumann series at the current (x, w) and previous
            # (x0, w0) iterates with the same masks and Hessian noise.
            for i in range(local_steps):
                active1 = np.random.random(N)
                while np.max(active1)<join_p:
                    active1 = np.random.random(N)

                tmp1 = tmp1 + tmp0
                tmp12 = tmp12 + tmp02
                tmp0 = tmp0 - eta_SO* calc_hessian_w_v(x,w,A,B,a,b,A0,B0,a0,b0,active1,tmp0,Noise)
                tmp02 = tmp02 - eta_SO*calc_hessian_w_v(x0,w0,A,B,a,b,A0,B0,a0,b0,active1,tmp02,Noise)
            tmp1 = -tmp1*eta_SO
            tmp12 = -tmp12*eta_SO
            for i in range(N):
                if (active[i]>=join_p):
                    tmpnoise = noise*np.random.randn(xdim)
                    tmphx1[i] = np.sum(tmp1*(calc_gradient_wi(w,A,B,a,b,i) + tmpnoise))
                    tmphx2[i] = np.sum(tmp12*(calc_gradient_wi(w0,A,B,a,b,i) + tmpnoise))


            # Recursive momentum for x (inactive clients contribute 0 to both
            # tmphx1 and tmphx2, so their hx entries just decay by (1-mx)).
            hx = tmphx1 + (1-mx)*(hx - tmphx2)
        # Remember this iterate; the next iteration's momentum correction
        # re-evaluates gradients at (x0, w0).
        x0 = np.copy(x)
        w0 = np.copy(w)
        x = x - eta_x*hx
        w = w - eta_w*hw
        x = proj_to_one(x)
    return x,calc_function_value(x,A,B,a,b,A0,B0,a0,b0)



def PD_x(A,B,a,b,A0,B0,a0,b0,Gamma, eta_w,eta_x,eta_lamb):
    """Primal-dual style method with a multiplier ``lamb`` for the lower level.

    Each of the ``local_steps`` inner steps samples a participation mask and
    computes noisy estimates of ``grad F(w)`` and ``grad g(w)``, where ``g``
    is the ``x``-weighted lower-level objective. It then updates:
    * ``w <- w - eta_w * (grad F + Gamma * grad g + H lamb)``, with ``H`` the
      noisy weighted Hessian.
    * ``lamb <- lamb + eta_lamb * grad g``.

    Both updates use values computed at the pre-update ``w`` and ``lamb``.
    After the inner loop, every client active in the last mask gets
    ``x[i] -= eta_x * sum(lamb * noisy grad f_i(w))``, and ``x`` is projected
    onto the simplex. Progress is printed every ``r_step`` outer steps.

    Returns ``(x, calc_function_value(x, ...))``.

    Raises ValueError if ``local_steps < 1``: the x-update needs the mask
    sampled in the last local step.
    """
    if local_steps < 1:
        raise ValueError("PD_x needs local_steps >= 1 (the x-update uses the "
                         "participation mask from the last local step)")
    x = np.ones(N)/N
    w = np.zeros(xdim)
    lamb = np.zeros(xdim)
    for iterate in range(total_step):
        if iterate % r_step == 0:
            grad, val = calc_function(x,A,B,a,b,A0,B0,a0,b0)
            print("PD_x", iterate, np.sum(grad*grad), val[0])
        for _ in range(local_steps):
            active = np.random.random(N)
            while np.max(active) < join_p:
                active = np.random.random(N)

            gradient_w = calc_gradient_w_0(w,A0,B0,a0,b0) + noise*np.random.randn(xdim)
            gradient_lamb = calc_gradient_w(x,w,A,B,a,b,active) + noise*np.random.randn(xdim)
            gradient_w = gradient_w + Gamma*gradient_lamb
            gradient_w = gradient_w + calc_hessian_w_v(x,w,A,B,a,b,A0,B0,a0,b0,active,lamb, np.random.randn(N,xdim,xdim)*noise)

            w = w - eta_w*gradient_w
            lamb = lamb + eta_lamb*gradient_lamb

        # `active` is the mask sampled in the last local step above.
        for i in range(N):
            if active[i] >= join_p:
                x[i] = x[i] - eta_x*np.sum(lamb*(calc_gradient_wi(w,A,B,a,b,i)+noise*np.random.randn(xdim)))
        x = proj_to_one(x)

    return x, calc_function_value(x,A,B,a,b,A0,B0,a0,b0)

def unrolling_gd(A,B,A0,B0,a,b,a0,b0,eta_w,eta_x):
    """Hypergradient by propagating ``dw/dx`` through unrolled gradient steps.

    Note the argument order ``(A, B, A0, B0, a, b, a0, b0)``, which differs
    from the other drivers.

    Each outer step runs ``local_steps`` lower-level gradient steps on ``w``.
    ``w`` is warm-started across outer steps, while the Jacobian
    ``J = dw/dx`` restarts from zero every outer step.
    ``gradient_x_iter`` (shape ``(xdim, N)``) holds ``M = J / eta_w``. Each
    step applies ``M <- M - eta_w * H M - (N / #active) * G``. Here ``H`` is
    the noisy weighted Hessian from ``calc_hessian_w_V``, and column ``i`` of
    ``G`` is active client ``i``'s noisy gradient. The hypergradient is
    ``eta_w * M^T grad F(w)`` (i.e. ``J^T grad F(w)``), and ``x`` takes a
    projected step along it. Progress is printed every ``r_step`` outer steps.

    Returns ``(x, calc_function_value(x, ...))``.
    """
    w = np.zeros(xdim)
    x = np.ones(N)/N

    for outer_iter in range(total_step):
        if outer_iter % r_step == 0:
            grad, val = calc_function(x,A,B,a,b,A0,B0,a0,b0)
            print("unrolling_gd", outer_iter, np.sum(grad*grad), val[0])
        gradient_x_iter = np.zeros((xdim,N))
        for _ in range(local_steps):
            active = np.random.random(N)
            while np.max(active) < join_p:
                active = np.random.random(N)

            # Jacobian recursion uses H and client gradients at the pre-update w.
            gradient_x_iter = gradient_x_iter - eta_w*calc_hessian_w_V(x,w,A,B,a,b,A0,B0,a0,b0,active,gradient_x_iter, np.random.randn(N,xdim,xdim)*noise)
            total = np.sum(active >= join_p)
            for i in range(N):
                if active[i] >= join_p:
                    tmp = calc_gradient_wi(w,A,B,a,b,i) + noise*np.random.randn(xdim)
                    gradient_x_iter[:,i] = gradient_x_iter[:,i] - tmp*N/total

            # NOTE(review): unlike the other methods, this lower-level step adds
            # no gradient noise — confirm that is intended.
            w = w - eta_w * calc_gradient_w(x,w,A,B,a,b,active)

        tmp0 = calc_gradient_w_0(w,A0,B0,a0,b0)
        gradient_x = eta_w*np.matmul(np.transpose(gradient_x_iter),tmp0)
        x = proj_to_one(x - eta_x*gradient_x)

    return x, calc_function_value(x,A,B,a,b,A0,B0,a0,b0)
    







def check_SC(A, a, threshold=0.1, cos_weight=None):
    """Redraw clients until each worst-case lower-level Hessian is well conditioned.

    Client i's loss Hessian is ``A_i^T A_i - rate*cos(.) * a_i a_i^T``. Its
    worst case (cos = 1, for a non-negative weight) is
    ``A_i^T A_i - rate * a_i a_i^T``. While that matrix's smallest eigenvalue
    is ``<= threshold``, ``A[i]`` and ``a[i]`` are redrawn as
    ``randn / sqrt(xdim)``.

    Args:
        A: client matrices, shape ``(n_clients, rows, cols)``; modified in place.
        a: client vectors, shape ``(n_clients, cols)``; modified in place.
        threshold: required strict lower bound on the smallest eigenvalue.
        cos_weight: weight of the cosine term; ``None`` uses the module-level
            ``rate``.

    Returns:
        The same ``(A, a)`` objects after resampling.

    The loop has no retry limit, so it never terminates if ``threshold`` is
    unreachable for random draws.
    """
    c = rate if cos_weight is None else cos_weight
    n_clients, rows, cols = A.shape

    def min_curvature(Ai, ai):
        # The matrix is symmetric, so eigvalsh gives real eigenvalues
        # (np.linalg.eig may return a complex array).
        return np.linalg.eigvalsh(Ai.T @ Ai - c * np.outer(ai, ai)).min()

    for i in range(n_clients):
        while min_curvature(A[i], a[i]) <= threshold:
            A[i, :, :] = np.random.randn(rows, cols) / math.sqrt(cols)
            a[i, :] = np.random.randn(cols) / math.sqrt(cols)
    return A, a
        
def print_SC(A,B,A0,B0,a,b,a0,b0, cos_weight=None):
    """Print the eigenvalue range of every worst-case curvature matrix.

    Prints one ``min max`` line per client for
    ``A_i^T A_i - rate * a_i a_i^T``, then a final line for the same matrix
    built from ``A0`` and ``a0``. ``B``, ``B0``, ``b`` and ``b0`` are unused.
    ``cos_weight`` overrides the module-level ``rate`` when given.
    """
    c = rate if cos_weight is None else cos_weight

    def eig_range(Ai, ai):
        # Symmetric matrix: eigvalsh returns real eigenvalues.
        eigs = np.linalg.eigvalsh(Ai.T @ Ai - c * np.outer(ai, ai))
        return np.min(eigs), np.max(eigs)

    for i in range(A.shape[0]):
        print(*eig_range(A[i], a[i]))
    print(*eig_range(np.asarray(A0).reshape(-1, A.shape[-1]), np.ravel(a0)))



def estimate_w(x,w,A,B,A0,B0,a,b,a0,b0):
    """Score a candidate ``(x, w)`` pair.

    Returns ``(||grad_w g(x, w)||^2, F(w))``. ``g`` is the ``x``-weighted
    lower-level objective with every client active, passing an all-ones mask
    as ``calc_function_value`` does. ``F`` is the upper-level objective
    ``0.5*||A0 w - B0||^2 + rate*cos(a0^T w - b0)``.
    """
    # calc_gradient_w requires an `active` mask; omitting it was a TypeError.
    gradient = calc_gradient_w(x,w,A,B,a,b,np.ones(N))
    tmp = np.matmul(A0,w) - B0
    res = 0.5*np.sum(tmp*tmp)+ rate*np.cos(np.matmul(np.transpose(a0),w)-b0)
    return np.sum(gradient*gradient),res
    
# Experiment driver. range(1) means a single random problem instance. Every
# method below runs on the same data and prints its own progress lines; their
# return values are discarded.
for i in range(1):
    # Client data (A_i, B_i, a_i, b_i) for the N lower-level losses, and
    # upper-level data (A0, B0, a0, b0). b0 has shape (1,), so objective
    # values computed from it are length-1 arrays (the drivers print val[0]).
    A = np.random.randn(N,dim,xdim)/math.sqrt(xdim)
    B = np.random.randn(N,dim)/math.sqrt(dim)
    A0 = np.random.randn(dim,xdim)/math.sqrt(xdim)
    B0 = np.random.randn(dim)/math.sqrt(dim)
    a = np.random.randn(N,xdim)/math.sqrt(xdim)
    b = np.random.randn(N)
    a0 = np.random.randn(xdim)/math.sqrt(xdim)
    b0 = np.random.randn(1)
    
    # Redraw (A_i, a_i) until every client's worst-case Hessian
    # A_i^T A_i - rate * a_i a_i^T has smallest eigenvalue above 0.1.
    A,a = check_SC(A,a)
    
    #print_SC(A,B,A0,B0,a,b,a0,b0)
    
    # Earlier result-collection / comparison code, kept for reference. These
    # calls omit the step-size arguments the drivers now require, so they
    # would fail if simply uncommented.
    #res1_x, res1 = GD_x(A,B,a,b,A0,B0,a0,b0)
    #res1_w, res1_res = res1
    
    #res2_x, res2 = PD_x(A,B,a,b,A0,B0,a0,b0)
    #res2_w, res2_res = res2
    
    #res3_x,res3 = unrolling_gd(A,B,A0,B0,a,b,a0,b0)
    #res3_w,res3_res = res3
    
    #res4_x, res4 = BSA_x(A,B,a,b,A0,B0,a0,b0)
    #res4_w, res4_res = res4
    
    #res5_x,res5 = double_momentum(A,B,a,b,A0,B0,a0,b0)
    #res5_w,res5_res = res5
    #print(res1_x,res2_x)
    #print(res1_w,res2_w)
    #print(res1_res,res2_res,res3_res,res4_res,res5_res)
    #print((res2_res - res1_res)/res1_res, (res3_res - res1_res)/res1_res, (res4_res - res1_res)/res1_res)
    #print(estimate_w(res1_x,res1_w,A,B,A0,B0,a,b,a0,b0), estimate_w(res2_x,res2_w,A,B,A0,B0,a,b,a0,b0), estimate_w(res3_x,res3_w,A,B,A0,B0,a,b,a0,b0), estimate_w(res4_x,res4_w,A,B,A0,B0,a,b,a0,b0))
    #print(res1_x,res2_x,res3_x,res4_x)
    #print(calc_function_value(res1_x,A,B,a,b,A0,B0,a0,b0), calc_function_value(res4_x,A,B,a,b,A0,B0,a0,b0))
    
    # Baseline that inverts a noisy Hessian explicitly.
    eta_w = 0.001
    eta_x = 0.005
    GD_x(A,B,a,b,A0,B0,a0,b0,eta_w,eta_x)
    
    # PD_x runs with a smaller eta_w; eta_x is still 0.005 because the
    # override below is commented out.
    eta_w  =0.0005 
    #eta_x =0.05
    Gamma = 1
    eta_lamb = 0.02
    PD_x(A,B,a,b,A0,B0,a0,b0,Gamma, eta_w,eta_x,eta_lamb)
               
    # unrolling_gd takes (A,B,A0,B0,a,b,...): a different argument order
    # from the other drivers, matching its definition.
    eta_w = 0.001
    eta_x = 0.005
    unrolling_gd(A,B,A0,B0,a,b,a0,b0,eta_w,eta_x)
    
    # BSA_x and double_momentum reuse eta_w = 0.001 and eta_x = 0.005.
    #eta_w = 0.1
    #eta_x = 0.05
    eta_SO = 0.01
    BSA_x(A,B,a,b,A0,B0,a0,b0, eta_w, eta_x, eta_SO)
            
            
            
            
    #eta_w = 0.1
    #eta_x  = 0.05
    #eta_SO=0.1
    mx,mw = 0.1,0.1
    double_momentum(A,B,a,b,A0,B0,a0,b0,mw,mx,eta_w,eta_x,eta_SO)
    # NOTE(review): the triple-quoted block below is a bare string expression,
    # so it never runs. It refers to names (f, res1_res, res2_res) that are
    # not defined here, so re-enabling it needs more than removing the quotes.
    '''
    if ((res2_res - res1_res)/res1_res>0.5):
            np.save(f,A)
            np.save(f,A0)
            np.save(f,B)
            np.save(f,B0)
            np.save(f,a)
            np.save(f,a0)
            np.save(f,b)
            np.save(f,b0)
    '''
