#!/opt/conda/bin/python3
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Seed both random number generators so the random vector drawn below (and
# therefore the saved figure) is reproducible from run to run.
np.random.seed(0)
random.seed(0)

# Figure-wide look: start from the ggplot theme, then override the canvas size,
# the text (large, serif) and the axes/tick colours. The overrides must come
# after style.use so they are not replaced by the theme's own values.
plt.style.use('ggplot')
plt.rcParams.update({
    "figure.figsize": (16, 9),
    "font.size": 40,
    "font.family": 'serif',
    "xtick.color": 'black',
    "ytick.color": 'black',
    "axes.edgecolor": 'black',
    "axes.linewidth": 1,
})

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Helper
#----------------------------------------------------------------------------------------------------------------------------------------------------

def project(u, v):
    """Return the orthogonal projection of vector *u* onto the line spanned by *v*.

    Computes ``(u . v) / (v . v) * v``.

    Parameters
    ----------
    u, v : array_like
        1-D vectors of equal length.

    Returns
    -------
    numpy.ndarray
        Float vector: the component of *u* along *v*. When *v* is the zero
        vector the spanned subspace is {0}, so a zero vector is returned
        rather than the NaNs (and divide-by-zero warning) that the plain
        formula would produce.
    """
    v = np.asarray(v, dtype=float)
    # v . v equals ||v||^2 exactly, without a sqrt followed by squaring.
    v_sq_norm = np.dot(v, v)
    if v_sq_norm == 0:
        return np.zeros_like(v)
    return np.dot(u, v) / v_sq_norm * v

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main
#----------------------------------------------------------------------------------------------------------------------------------------------------

num_nodes = 100
samples = 1001

# Node signal on a uniform grid, plus a random integer per node in [2, 9]
# (randint's upper bound is exclusive).
x = np.linspace(-1.5, 1.5, num_nodes)
d = np.random.randint(2, 10, size=num_nodes)
# Reference direction sqrt(d). NOTE(review): presumably the dominant
# eigenvector of a degree-normalized graph operator, with d as node degrees —
# confirm against the accompanying derivation.
eigenspace = np.sqrt(d)
# Sweep the shift parameter alpha over the same range as x.
alpha = np.linspace(np.min(x), np.max(x), samples)

a = .01  # slope of the leaky ReLU on the negative side


def _smoothness(signal, basis, tol=1e-9):
    """Return ||project(signal, basis)||_2 / ||signal||_2.

    This is the |cosine| between *signal* and *basis*, a value in [0, 1].
    A signal whose norm is at most *tol* has no meaningful direction and is
    assigned 1 by convention. The guard deliberately tests the denominator
    (the quantity we divide by): a non-zero signal orthogonal to *basis* must
    score ~0, not be misreported as 1.
    """
    signal_norm = np.linalg.norm(signal, 2)
    if signal_norm <= tol:
        return 1.0
    return np.linalg.norm(project(signal, basis), 2) / signal_norm


# smoothness[0]: raw z_alpha, [1]: ReLU(z_alpha), [2]: leaky ReLU(z_alpha);
# one entry per alpha sample.
smoothness = [[], [], []]
for alpha_i in alpha:
    z = x + alpha_i * eigenspace
    relu_z = np.maximum(z, 0)
    leaky_z = np.where(z > 0, z, z * a)

    smoothness[0].append(_smoothness(z, eigenspace))
    smoothness[1].append(_smoothness(relu_z, eigenspace))
    smoothness[2].append(_smoothness(leaky_z, eigenspace))


plt.figure(tight_layout=True)

# Dashed vertical guides at three alpha values.
plt.axvline(-.75, color='dimgrey', linestyle='dashed', linewidth=7)  # Decrease smoothness
plt.axvline(-.2, color='dimgrey', linestyle='dashed', linewidth=7)   # Preserve smoothness
plt.axvline(.25, color='dimgrey', linestyle='dashed', linewidth=7)   # Increase smoothness
# Tags for the guides. With the y-axis transform, x is an axes fraction and
# y (0.92) is in data units. With xlim (-1.25, 1.25) set below, the fractions
# .205/.425/.605 sit just right of the guides at -.75/-.2/.25. Keep them in
# sync if either the guides or the limits change.
tag_transform = plt.gca().get_yaxis_transform()
for tag_x, tag in ((.605, '(c)'), (.425, '(b)'), (.205, '(a)')):
    plt.text(tag_x, 0.92, tag, transform=tag_transform, color='dimgrey', fontweight='bold')

plt.plot(alpha, smoothness[0], color='black', linewidth=7, label=r'$s(z_\alpha)$')

# The ReLU curve is drawn in two pieces: samples with s < 1 as a plain line,
# and samples with s exactly 1 (e.g. the value assigned to an all-zero ReLU
# output) as a step line. Only the second piece carries a label, so the legend
# shows a single ReLU entry.
alpha_arr = np.asarray(alpha)
relu_s = np.asarray(smoothness[1])
below_one = relu_s < 1
plt.plot(alpha_arr[below_one], relu_s[below_one], color='blue', dashes=[4, 1], linewidth=7)
at_one = relu_s == 1
plt.plot(alpha_arr[at_one], relu_s[at_one], color='blue', dashes=[4, 1], linewidth=7,
         label=r'$s(\sigma(z_\alpha))$', drawstyle='steps-pre')

plt.plot(alpha, smoothness[2], color='red', dashes=[2, 2], linewidth=10,
         label=r'$s(\sigma_a(z_\alpha))$')
plt.legend(bbox_to_anchor=(.65, .5), handlelength=1)
plt.xlim(-1.25, 1.25)
plt.xlabel(r'Parameter $(\alpha)$', color='black')
plt.ylabel('Smoothness $(s)$', color='black')

# savefig does not create missing directories; make sure the target exists so
# the run does not fail after all the computation is done.
out_path = Path('/root/workspace/out/sct_gnn/smoothness.pdf')
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, format='pdf', bbox_inches='tight', dpi=300)