import random
import math
import numpy as np
import matplotlib.pyplot as plt
import sys
# Seed NumPy's global RNG so that every np.random draw made below
# (training data, parameter initialization, gradient noise) is reproducible.
np.random.seed(42)

def relu(x):
    """Return the ReLU activation max(x, 0) of a scalar ``x``.

    Non-negative inputs are returned unchanged.  Negative inputs (and NaN,
    for which ``x >= 0`` is false) map to the float ``0.0``.

    The zero is deliberately a float: ``np.vectorize`` infers its output
    dtype from the result of the first element, so an integer ``0`` there
    would make the whole output integer-typed and truncate every later
    activation (e.g. ``[-0.5, 0.7]`` would become ``[0, 0]``).
    """
    return x if x >= 0 else 0.0
    
def relu_div(x):
    """Return the derivative of ReLU at a scalar ``x``.

    The value is 1 when ``x >= 0`` (so the kink at 0 is assigned slope 1)
    and 0 otherwise.
    """
    is_active = x >= 0
    return 1 if is_active else 0
    
# Element-wise array versions of relu / relu_div.
# The output dtype is pinned to float: without ``otypes`` np.vectorize infers
# the dtype from the result of the FIRST element only, so a leading negative
# pre-activation (for which relu yields an integer 0) would turn the whole
# output into integers and truncate the remaining activations.  Pinning the
# dtype also lets both callables accept empty arrays.
RELU = np.vectorize(relu, otypes=[float])
RELU_div = np.vectorize(relu_div, otypes=[float])


def clip(r, R=2):
    """Smoothly clip a parameter ``r`` to magnitude at most ``R``.

    Computes ``R * tanh(r * |r| / (2 * R))``.  The result has the sign of
    ``r`` and is built from NumPy functions, so ``r`` may be a scalar or an
    array.
    """
    squashed_arg = r * np.abs(r) / (2 * R)
    return R * np.tanh(squashed_arg)

def clip_div(r, R=2):
    """Return the derivative of ``clip`` with respect to ``r``.

    Differentiating ``R * tanh(r * |r| / (2 * R))`` gives
    ``|r| / cosh(r * |r| / (2 * R)) ** 2``, which is what is evaluated here.
    ``r`` may be a scalar or an array.
    """
    magnitude = np.abs(r)
    cosh_term = np.cosh(r * magnitude / (2 * R))
    return magnitude / cosh_term ** 2

# Element-wise array versions of clip / clip_div (always with the default R).
# ``otypes`` is pinned to float so that np.vectorize does not have to call the
# function on the first element just to infer the output dtype, and so that
# empty arrays are accepted instead of raising ValueError.
v_clip = np.vectorize(clip, otypes=[float])
v_clip_div = np.vectorize(clip_div, otypes=[float])


def NN(W, a, x):
    """Return the output ``a . relu(W x)`` of a two-layer ReLU network.

    Parameters
    ----------
    W : hidden-layer weights, one row per hidden unit.
    a : output-layer weights, one entry per hidden unit.
    x : a single input vector.

    The hidden activations are computed with ``np.maximum`` rather than an
    ``np.vectorize``-wrapped scalar ReLU: the latter infers its output dtype
    from the first pre-activation and can silently truncate the activations
    to integers when that first value is negative.  A NaN pre-activation
    propagates to the output instead of being treated as inactive.
    """
    hidden = np.maximum(np.dot(W, x), 0.0)
    return np.dot(a, hidden)

def make_sample(dim, atrue, Wtrue, sample_var, n=100):
    """Draw ``n`` noisy training pairs from the teacher network.

    Each input is a standard normal vector of length ``dim`` rescaled to unit
    Euclidean norm.  Its label is the teacher output ``NN(Wtrue, atrue, x)``
    plus zero-mean Gaussian noise.  Note that ``sample_var`` is handed to
    ``np.random.normal`` as its scale (standard deviation) argument, despite
    the name.

    Returns a pair ``(inputs, labels)`` of NumPy arrays with ``n`` entries each.
    """
    inputs, targets = [], []
    for _ in range(n):
        direction = np.random.normal(size=dim)
        unit_input = direction / np.linalg.norm(direction)
        clean_output = NN(Wtrue, atrue, unit_input)
        inputs.append(unit_input)
        # The noise is drawn after the input so the RNG stream order is fixed.
        targets.append(clean_output + np.random.normal(0, sample_var))
    return np.array(inputs), np.array(targets)

def inner_product(w, v):
    """Return the function-space inner product of two ReLU neurons.

    For weight vectors ``w`` and ``v`` of dimension ``d`` enclosing the angle
    ``phi``, this evaluates

        |w| |v| (sin(phi) + (pi - phi) cos(phi)) / (2 pi d),

    i.e. the expectation of ``relu(w.x) * relu(v.x)`` for inputs ``x`` drawn
    uniformly from the unit sphere (degree-1 arc-cosine kernel scaled by 1/d).

    The factor ``(pi - phi)`` multiplies only ``cos(phi)``; multiplying the
    whole sum by ``cos(phi)`` would wrongly give 0 for orthogonal neurons,
    whose ReLU outputs do overlap.  A zero weight vector represents the zero
    function, so the inner product is 0 in that case (instead of NaN from
    normalizing a zero vector).
    """
    w_len, v_len = np.linalg.norm(w), np.linalg.norm(v)
    if w_len == 0 or v_len == 0:
        return 0.0
    # Clamp against rounding so that arccos never sees a value outside [-1, 1].
    cos_phi = np.clip(np.dot(w, v) / (w_len * v_len), -1, 1)
    phi = np.arccos(cos_phi)
    kernel = np.sin(phi) + (np.pi - phi) * cos_phi
    return kernel * w_len * v_len / (2 * np.pi * len(w))


def inner_calc(a, W, atrue, Wtrue):
    """Return half the squared function-space distance between two networks.

    Expands ``0.5 * <f - f*, f - f*>`` for the student ``f`` (output weights
    ``a``, hidden weights ``W``) and the teacher ``f*`` (``atrue``, ``Wtrue``)
    into student-student, student-teacher and teacher-teacher sums of
    ``inner_product`` terms.

    The student and the teacher are iterated over independently, so the two
    networks may have different widths; each of the three sums is evaluated
    exactly once.
    """
    student = list(zip(a, W))
    teacher = list(zip(atrue, Wtrue))

    student_sq = sum(
        a_i * a_j * inner_product(w_i, w_j)
        for a_i, w_i in student
        for a_j, w_j in student
    )
    cross = sum(
        a_i * b_j * inner_product(w_i, u_j)
        for a_i, w_i in student
        for b_j, u_j in teacher
    )
    teacher_sq = sum(
        b_i * b_j * inner_product(u_i, u_j)
        for b_i, u_i in teacher
        for b_j, u_j in teacher
    )
    return (student_sq - 2 * cross + teacher_sq) / 2.


def _data_loss(W, a, sample, label):
    """Return sum_i (NN(W, a, x_i) - y_i)**2 / (2 n), the unregularized training loss."""
    n = len(sample)
    return sum((NN(W, a, x) - y) ** 2 / (2 * n) for x, y in zip(sample, label))


def _ridge_penalty(W, a, reg):
    """Return reg * (|a|^2 + |W|_F^2), the penalty whose gradient is 2 * reg * (a, W)."""
    return reg * (np.sum(a * a) + np.sum(W * W))


def twophase_gd(m, dim, eta1=0.01, eta2=0.01, beta=1000, n=100, max_iter1=100,
                max_iter2=100, sample_var=0.1, reg=0.01, atrue=None, Wtrue=None):
    """Train a width-``m`` two-layer ReLU student with two-phase gradient descent.

    Phase I runs ``max_iter1`` steps of noisy gradient descent on the clipped
    parameters ``(v_clip(a), v_clip(W))`` with a ridge penalty on the raw
    parameters.  The network is then reparameterized so that the output
    weights become signs, and phase II runs ``max_iter2`` steps of plain
    gradient descent on ``W`` only.

    Parameters
    ----------
    m, dim : student width and input dimension.
    eta1, eta2 : step sizes of phase I and phase II.
    beta : inverse temperature; the phase-I noise is scaled by sqrt(eta1 / beta).
    n : number of training samples drawn with ``make_sample``.
    max_iter1, max_iter2 : number of iterations of each phase.
    sample_var : label-noise level forwarded to ``make_sample``.
    reg : ridge coefficient used in phase I.
    atrue, Wtrue : teacher parameters.  When omitted, the module-level
        variables of the same names are used, which is how the function
        originally obtained them.

    Returns
    -------
    (a, W, loss, objs, objs_noreg) where ``a`` / ``W`` are the final
    parameters, ``loss`` lists the test error after every update
    (``max_iter1 + max_iter2`` entries), and ``objs`` / ``objs_noreg`` are
    arrays of ``max_iter1 + max_iter2 + 1`` training objectives with and
    without the phase-I penalty (identical during phase II, which has no
    penalty).  The objective of the random initialization is dropped.

    Raises
    ------
    NameError
        If a teacher parameter is neither passed nor defined at module level.
    """
    if atrue is None or Wtrue is None:
        module_scope = globals()
        missing = [name for name, value in (("atrue", atrue), ("Wtrue", Wtrue))
                   if value is None and name not in module_scope]
        if missing:
            raise NameError(
                "teacher parameters not provided: {}".format(", ".join(missing)))
        if atrue is None:
            atrue = module_scope["atrue"]
        if Wtrue is None:
            Wtrue = module_scope["Wtrue"]

    W = np.random.normal(size=dim*m).reshape(m, dim)  # initialization of W
    a = np.random.normal(size=m)  # initialization of a
    sample, label = make_sample(dim, atrue, Wtrue, sample_var, n)
    loss = []
    objs = []
    objs_noreg = []

    # phase I (noisy gradient descent)
    noise_scale = np.sqrt(eta1 / beta)
    for t in range(max_iter1):
        xi = np.random.normal(size=m*(dim+1)).reshape(m, dim+1)  # Langevin noise
        # NOTE(review): the weight-decay term 2*reg*theta is applied without the
        # eta1 factor that multiplies the data gradient below, and the noise
        # uses sqrt(eta1/beta) -- confirm both are intended.
        delta_a = 2*reg*a - noise_scale*xi[:, 0]
        delta_W = 2*reg*W - noise_scale*xi[:, 1:]

        # The parameters are fixed while the gradient is accumulated, so the
        # clipped values and clip derivatives are computed once per iteration.
        a_clipped, W_clipped = v_clip(a), v_clip(W)
        dclip_a, dclip_W = v_clip_div(a), v_clip_div(W)

        obj_noreg = 0
        for x, y in zip(sample, label):  # accumulate loss and gradient per sample
            pre_activation = np.dot(W_clipped, x)
            residual = NN(W_clipped, a_clipped, x) - y
            obj_noreg += residual**2/(2*n)
            delta_a += eta1*residual*RELU(pre_activation)*dclip_a/n
            delta_W += eta1*residual*np.outer(a_clipped*RELU_div(pre_activation), x)*dclip_W/n
        # The penalty covers both a and W, matching the 2*reg*a / 2*reg*W terms.
        obj = _ridge_penalty(W, a, reg) + obj_noreg

        a, W = a - delta_a, W - delta_W  # updating parameters
        # The objectives logged here belong to the parameters BEFORE this
        # update; the test error below is evaluated AFTER it.
        objs.append(obj)
        objs_noreg.append(obj_noreg)
        test_error = inner_calc(v_clip(a), v_clip(W), atrue, Wtrue)
        loss.append(test_error)
        if (t+1) % 5 == 0:
            print("phase1 iteration:{} done. trainig error:{}, test error:{}".format(t+1, obj_noreg, test_error))

    print("phase1 complete")
    # Objective of the final phase-I parameters.
    obj_noreg = _data_loss(v_clip(W), v_clip(a), sample, label)
    objs.append(_ridge_penalty(W, a, reg) + obj_noreg)
    objs_noreg.append(obj_noreg)

    # reparameterize: fold |clip(a_i)| into row i of clip(W) and keep only the
    # sign of a_i, which leaves the network function unchanged.
    W = np.array([np.abs(clip(a[i]))*v_clip(W[i]) for i in range(m)])
    a = np.sign(a)

    # phase II (vanilla gradient descent on W only)
    for t in range(max_iter2):
        delta_W = np.zeros_like(W)
        obj = 0
        for x, y in zip(sample, label):  # accumulate loss and gradient per sample
            residual = NN(W, a, x) - y
            obj += residual**2/(2*n)
            delta_W += eta2*residual*np.outer(a*RELU_div(np.dot(W, x)), x)/n
        W = W - delta_W  # updating parameters
        objs.append(obj)
        objs_noreg.append(obj)
        test_error = inner_calc(a, W, atrue, Wtrue)  # test error after the update
        loss.append(test_error)
        if (t+1) % 5 == 0:
            print("phase2 iteration:{} done. trainig error:{}, test error:{}".format(t+1, obj, test_error))

    # Objective of the final parameters.
    obj = _data_loss(W, a, sample, label)
    objs.append(obj)
    objs_noreg.append(obj)
    print("phase2 complete")

    return a, W, loss, np.array(objs[1:]), np.array(objs_noreg[1:])


if __name__ == "__main__":
    # Usage: <script> WIDTH DIM
    if len(sys.argv) < 3:
        sys.exit("usage: {} WIDTH DIM".format(sys.argv[0]))
    m, dim = int(sys.argv[1]), int(sys.argv[2])  # network width, dimensionality
    if m < 1 or m > dim:
        # The teacher uses one distinct coordinate direction per hidden unit.
        sys.exit("WIDTH must satisfy 1 <= WIDTH <= DIM")

    # Teacher network.  atrue and Wtrue must stay module-level names because
    # twophase_gd looks them up as globals.
    # The first m//2 output weights are +1 and the remaining ones -1, so atrue
    # has exactly m entries even when m is odd.
    atrue = np.array([1.]*(m//2) + [-1.]*(m - m//2))
    # Row i is the i-th coordinate direction in R^dim (the identity when dim == m).
    Wtrue = np.eye(m, dim)

    max_iter1, max_iter2 = 1000, 2000
    a, W, loss, obj, obj_noreg = twophase_gd(
        m, dim, eta1=1e-2, eta2=1e-2, beta=100, n=1000,
        max_iter1=max_iter1, max_iter2=max_iter2, reg=1e-2)

    # visualization
    plt.plot(np.arange(len(loss)), loss, label="test error")
    # Phase I is plotted without the regularizer so that the curve is
    # comparable with phase II, whose objective has no penalty term.
    plt.plot(np.arange(max_iter1), obj_noreg[:max_iter1], color="orange", label="training error")
    plt.plot(np.arange(max_iter1, len(obj)), obj[max_iter1:], color="orange")
    plt.yscale("log")
    plt.xlabel("iterations")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig("result.pdf")
    np.savetxt('result_train.csv', obj, delimiter=",")
    np.savetxt('result_train_noreg.csv', obj_noreg, delimiter=",")
    np.savetxt('result_test.csv', loss, delimiter=",")