import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import math
import numpy as np
import networkx as nx
import scipy as sp
import matplotlib.pyplot as plt


# Number of nodes in the random regular graph built below; also the side
# length of the adjacency matrix A and of the overlap matrix D.
n = 100
# Length of the weight vectors w_star and w_0 drawn before the training loop.
r = 10
# Degree of the random regular graph; loss() normalises D by (d + 1).
d = 10 
# Multiplier applied to the gradient in update() and squared_update().
step_size = 0.0002111

# Matrix-power exponent used by squared_loss().
L  = 3

def relu(x):
  """Rectified linear unit: the element-wise maximum of ``x`` and zero."""
  zero_floor = 0
  return jnp.maximum(x, zero_floor)

def dual_relu(x):
  """Evaluate ``(sqrt(1 - x**2) + (pi - arccos(x)) * x) / pi`` element-wise.

  The expression is only defined for ``x`` in [-1, 1]: both the square root
  and ``arccos`` return NaN outside that interval. Callers in this file pass
  normalised quantities that should lie in [-1, 1] but can overshoot by a
  rounding error (e.g. a float32 cosine of 1.0000001), so the input is
  clamped to the valid interval before evaluation.

  Args:
    x: scalar or array; values outside [-1, 1] are treated as the nearest
      endpoint.

  Returns:
    Array of the same shape as ``x``; 0 at x = -1, 1/pi at x = 0 and 1 at
    x = 1.
  """
  # Guard against rounding pushing |x| marginally above 1 (would yield NaN).
  x = jnp.clip(x, -1.0, 1.0)
  return (jnp.sqrt(1 - x**2) + (math.pi - jnp.arccos(x))*x)/math.pi

def squared_loss(M1, M2):
  """Squared norm of the gap between the L-th matrix powers of M1 and M2.

  ``L`` is the module-level exponent. The powers are true matrix powers
  (``jnp.linalg.matrix_power``), not element-wise powers, and the norm is
  ``jnp.linalg.norm`` with its default arguments.
  """
  power_gap = jnp.linalg.matrix_power(M1, L) - jnp.linalg.matrix_power(M2, L)
  return jnp.linalg.norm(power_gap)**2

def loss(w, w_star, D):
  """Objective minimised by ``update``, as a function of the weights ``w``.

  Computes::

      (||w||^2 + 1) * sum(dual_relu(D / (d + 1)))
        - 2 * ||w|| * sum(dual_relu(D * <w, w_star> / ((d + 1) * ||w||)))

  where ``d`` is the module-level graph degree and both sums run over every
  entry of ``D``.

  Args:
    w: current weight vector. Must be non-zero: the second term divides by
      ``||w||``.
    w_star: target weight vector (the script normalises it to unit norm
      before use).
    D: matrix whose entries are fed to ``dual_relu`` after scaling by
      ``1 / (d + 1)``.

  Returns:
    Scalar loss value.
  """
  # The norm of w appears three times in the formula; compute it once.
  w_norm = jnp.linalg.norm(w)

  self_term = (w_norm**2 + 1)*jnp.sum(dual_relu(D/(d+1)))

  cross_arg = D*jnp.dot(w, w_star)/((d+1)*w_norm)
  cross_term = 2*w_norm*jnp.sum(dual_relu(cross_arg))

  return self_term - cross_term



@jit
def update(w, w_star, D):
  """Take one gradient-descent step on ``loss`` with respect to ``w``.

  Returns ``w`` moved against the gradient by the module-level ``step_size``;
  ``w_star`` and ``D`` are passed through to ``loss`` unchanged.
  """
  # argnums=0: differentiate with respect to the first argument only (w).
  loss_gradient = grad(loss, argnums=0)(w, w_star, D)
  scaled_step = step_size*loss_gradient
  return w - scaled_step

@jit
def squared_update(w, w_star):
  """Take one gradient-descent step on ``squared_loss`` with respect to ``w``.

  Returns ``w`` moved against the gradient by the module-level ``step_size``;
  ``w_star`` is passed through to ``squared_loss`` unchanged.
  """
  # argnums=0: differentiate with respect to the first argument only (w).
  loss_gradient = grad(squared_loss, argnums=0)(w, w_star)
  scaled_step = step_size*loss_gradient
  return w - scaled_step


# Sample a random d-regular graph on n nodes and take its dense adjacency
# matrix.
G = nx.generators.random_graphs.random_regular_graph(d,n)

A = nx.linalg.graphmatrix.adjacency_matrix(G)

A = A.todense()

# Add a self-loop at every node (set the whole diagonal to 1).
np.fill_diagonal(A, 1)

# D[i, j] = sum_k A[i, k] * A[j, k]: the inner product of rows i and j of A.
# A single matrix product gives every entry at once instead of an n x n
# Python loop. Stored as floats, matching the original np.zeros buffer.
D = np.asarray(A @ A.T, dtype=float)

# JAX copy of D, passed to the jitted update step.
Dnp = jnp.asarray(D)

# train loop
T = 1000

# Random target direction, normalised to unit norm, and a random
# (unnormalised) starting point.
w_star = np.random.normal(size=(r))
w_star = w_star/np.linalg.norm(w_star)

w_0 = np.random.normal(size=(r))

jw_star = jnp.asarray(w_star)
jw = jnp.asarray(w_0)

# Compile the loss once for monitoring and feed it the same JAX arrays the
# update step uses, so the reported value is the objective being optimised
# and the n x n matrix is not re-converted from NumPy on every iteration.
loss_jit = jit(loss)

err_1 = []  # sqrt(loss) after each step
err_2 = []  # ||w_star - w|| after each step

for i in range(T):
  jw = update(jw, jw_star, Dnp)
  err_2.append(float(jnp.linalg.norm(jw_star - jw)))
  # loss is a difference of two large sums, so near convergence it can round
  # to a slightly negative number; clamp at zero so the square root does not
  # turn into NaN.
  loss_value = jnp.maximum(loss_jit(jw, jw_star, Dnp), 0.0)
  err_1.append(float(jnp.sqrt(loss_value)))


# The first recorded value is skipped, as in the x-range 1..T-1.
# The curve is labelled so that plt.legend() has an entry to show.
plt.plot(range(1,T), err_1[1:T], label='sqrt(loss)')

plt.xlabel('Num. Iterations')
plt.ylabel('Loss Value')

plt.legend()
plt.show()

