# -*- coding: utf-8 -*-
#
import numpy as np
import matplotlib.pyplot as plt
import os 
from scipy.linalg import fractional_matrix_power

# Values to investigate: fraction of the data used for prior training and the
# sigma value. Both are part of the result file names loaded below.
data_for_prior_training = 0.5
sigma = 0.03

# Retrieve the error bound and the testing error for every KL-penalty value.
# Entry [0] of each result file is taken as the bound and entry [1] as the
# testing error (file layout assumed from this usage -- TODO confirm).
kl_penalty = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1]
bound_temp_list = []
testing_error_temp_list = []
# The results directory is the same for every penalty, so enter it only once.
os.chdir("your directory")
for kk in kl_penalty:
    # str(1) would give "1", but the file for lambda = 1 is looked up as "1.0";
    # formatting every penalty as a float yields that name (all other values
    # are already floats, so their names are unchanged).
    file_name = f"{float(kk)}_{sigma}_{data_for_prior_training}_fcn_.npy"
    # Load each file once instead of once per extracted entry.
    # NOTE: allow_pickle=True can execute code stored in the file -- only load
    # result files you produced yourself.
    results = np.load(file_name, allow_pickle=True)
    bound = results[0]
    test = results[1]
    bound_temp_list.append(bound)
    testing_error_temp_list.append(test)


# Values to investigate (these re-assign the same sigma / prior-data values
# used above). sample_per_class is part of the result file name and is also
# used as a normalizer in the alignment computation.
sigma = 0.03
data_for_prior_training = 0.5
sample_per_class = 325
NTK_methods = "ntk_init_withdivnothing"  # not used by the computations below

# Retrieve the label matrix: the second-to-last entry of the result file.
# The entry exposes .numpy() (presumably a torch tensor -- TODO confirm).
os.chdir("your directory")
rr = np.load(
    f"{sigma}_{sample_per_class}_{data_for_prior_training}_fcn_.npy",
    allow_pickle=True,  # only load trusted files: pickles can run code
)[-2].numpy()

# For each of the first `num_classes` columns of rr, build an n x n matrix with
# +1 where samples i and j have the same value in that column and -1 otherwise.
# Broadcasting replaces the original O(n^2) pure-Python double loop.
num_classes = 10
haha = []
for ww in range(num_classes):
    column = rr[:, ww]
    haha.append(np.where(column[:, None] == column[None, :], 1, -1))

# Sum the per-column agreement matrices. If rr uses +/-1 entries, this equals
# rr[:, :10] @ rr[:, :10].T, i.e. Y Y^T in the alignment formula.
final_matrix = haha[0].copy()
for agreement in haha[1:]:
    final_matrix += agreement

# Retrieve the NTK Gram matrix K: the last entry of the same result file.
# The entry exposes .numpy() (presumably a torch tensor -- TODO confirm).
os.chdir("your directory")
temp_ntk_matrix = np.load(
    f"{sigma}_{sample_per_class}_{data_for_prior_training}_fcn_.npy",
    allow_pickle=True,  # only load trusted files: pickles can run code
)[-1].numpy()

# Alignment score PA for every KL penalty lambda, following
#   (1/sigma) * Y^T (K + (lambda/sigma) I)^{-2} Y / n
#   + (lambda/sigma) * sqrt(Y^T (K + (lambda/sigma) I)^{-2} Y / n)
# with Y Y^T represented by final_matrix and n = sample_per_class.
kl_penalty = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1]
temp_list = []
identity = np.identity(len(temp_ntk_matrix))  # loop-invariant
for i in kl_penalty:
    ridge = i / sigma
    # (K + (lambda/sigma) I)^{-2}, computed once per lambda.
    temp_ntk_matrix_with_element = fractional_matrix_power(
        temp_ntk_matrix + identity * ridge, -2
    )
    # Y^T A Y == trace(A @ Y Y^T) == sum(final_matrix * A) for symmetric A.
    # np.trace(final_matrix * A) would read only the diagonal of the
    # elementwise product; final_matrix has a constant diagonal, so that value
    # would not depend on the labels at all.
    quadratic_term = np.sum(final_matrix * temp_ntk_matrix_with_element)
    l_temp_align_value = quadratic_term * (1 / sigma / sample_per_class)
    r_temp_align_value = np.sqrt(quadratic_term / sample_per_class) * ridge
    temp_align_value = l_temp_align_value + r_temp_align_value
    temp_list.append(temp_align_value)
    
    
from matplotlib.pyplot import figure
import scipy.stats as stats

# One color / legend label per KL-penalty value, in the same order as kl_penalty.
colors = np.array(["red", "green", "black", "orange", "purple", "lime", "cyan", "magenta", "navy"])
labels = np.array([
    r"$\lambda = 1x10^{-4} $",
    r"$\lambda = 5x10^{-4} $",
    r"$\lambda = 1x10^{-3} $",
    r"$\lambda = 5x10^{-3} $",
    r"$\lambda = 1x10^{-2} $",
    r"$\lambda = 5x10^{-2} $",
    r"$\lambda = 1x10^{-1} $",
    r"$\lambda = 5x10^{-1} $",
    r"$\lambda = 1 $",
])

# Kendall rank correlation between the alignment scores and the error bounds
# (kept in tau / p_value for inspection).
tau, p_value = stats.kendalltau(temp_list, bound_temp_list)

# Scatter one point per lambda so each gets its own legend entry.
# PA formula: (1/sigma_0) Y^T (k(X,X) + (lambda/sigma_0) I)^{-2} Y
#             + (lambda/sigma_0) sqrt(Y^T (k(X,X) + (lambda/sigma_0) I)^{-2} Y)
x = temp_list
y = bound_temp_list
for xx, yy, zz, jj in zip(x, y, colors, labels):
    plt.scatter(xx, yy, c=zz, label=jj)
plt.ylabel('Error Bound', fontsize=18)
plt.xlabel(r'$\mathcal{PA}$', fontsize=18)
plt.grid(linestyle='-')
# Create the legend exactly once: a second bare plt.legend() call would
# replace this one and silently drop the loc / borderaxespad settings.
legend = plt.legend(loc='upper left', borderaxespad=0.)
legend.get_frame().set_edgecolor('black')
plt.title(r"Correlation between $\mathcal{PA}$ and bound"'\n' r"under FCN with different $\lambda$", fontsize=18)
plt.show()