# -*- coding: utf-8 -*-


import numpy as np
import matplotlib.pyplot as plt
import os 
from scipy.linalg import fractional_matrix_power
# Move to the data directory holding the saved .npy results.
# NOTE(review): "your directory" is a placeholder; point it at the real path.
os.chdir("your directory")

# Values to investigate
data_for_prior_training = 0.5
kl_penalty = 1

# Retrieve the error bound (entry 0) and the testing error (entry 1) saved for
# each sigma value.
sigma = [0.05, 0.07, 0.09, 0.1, 0.3, 0.5, 0.7, 0.9]
bound_temp_list = []
testing_error_temp_list = []

# The filenames expect kl_penalty == 1 to be written as "1.0" (the previous
# version converted 1 to 1.0 before building the name); any other value is
# formatted with str() unchanged. Using a separate tag avoids mutating
# kl_penalty inside the loop.
kl_penalty_tag = str(1.0) if kl_penalty == 1 else str(kl_penalty)
for kk in sigma:
    # Load each file once and read both entries from the same array.
    results = np.load(
        kl_penalty_tag + '_' + str(kk) + '_' + str(data_for_prior_training) + '_cnn_.npy',
        allow_pickle=True,
    )
    bound_temp_list.append(results[0])
    testing_error_temp_list.append(results[1])


# Values to investigate
kl_penalty = 1
data_for_prior_training = 0.5
sample_per_class = 75
NTK_methods = "ntk_init_withdivnothing"  # not used further in this section

# Retrieve the label matrix (second-to-last entry) and the NTK matrix (last
# entry) from a single load of the results file. Both are converted with
# .cpu().numpy(), so they are presumably pickled torch tensors — TODO confirm.
os.chdir("your directory")
saved_results = np.load(
    '0.03_' + str(sample_per_class) + '_' + str(data_for_prior_training) + '_cnn_.npy',
    allow_pickle=True,
)
rr = saved_results[-2].cpu().numpy()
temp_ntk_matrix = saved_results[-1].cpu().numpy()

# Agreement matrix: for each of the first 10 columns of rr, entry (a, b) is +1
# when rows a and b hold the same value in that column and -1 otherwise; the
# per-column matrices are then summed. Vectorized with broadcasting instead of
# the previous O(n^2) Python double loop.
haha = [
    np.where(rr[:, ww][:, None] == rr[:, ww][None, :], 1, -1)
    for ww in range(10)
]
final_matrix = np.sum(haha, axis=0)

# PA for each sigma, following the formula
#   (1/sigma) Y^T (K + (lambda/sigma) I)^{-2} Y
#     + (lambda/sigma) sqrt(Y^T (K + (lambda/sigma) I)^{-2} Y)
# with the quadratic form normalized by sample_per_class.
# NOTE(review): if final_matrix is not PSD the quadratic form can be negative
# and np.sqrt returns nan — holds for binary (e.g. one-hot) columns of rr.
sigma = [0.05, 0.07, 0.09, 0.1, 0.3, 0.5, 0.7, 0.9]
identity = np.identity(len(temp_ntk_matrix))
temp_list = []
for i in sigma:
    ridge = kl_penalty / i
    # (K + (lambda/sigma) I)^{-2}, computed once per sigma.
    inv_sq = fractional_matrix_power(temp_ntk_matrix + identity * ridge, -2)
    # Quadratic form via the MATRIX product. The previous elementwise `*` made
    # the trace read only the diagonal of final_matrix (always 10), so the
    # labels had no influence on the result.
    quad_form = np.trace(final_matrix @ inv_sq)
    r_temp_align_value = np.sqrt(quad_form / sample_per_class) * ridge
    l_temp_align_value = quad_form * (1 / (i * sample_per_class))
    temp_align_value = l_temp_align_value + r_temp_align_value
    temp_list.append(temp_align_value)
    
    

from matplotlib.pyplot import figure
import scipy.stats as stats

# One colour and one legend label per sigma value. Labels are built from sigma
# itself so the legend always matches the plotted points. zip() below truncates
# to the shortest input, so at most len(colors) points are drawn.
colors = ["red", "green", "black", "orange", "purple", "lime", "cyan", "magenta"]
labels = [r"$\rho = " + str(s) + " $" for s in sigma]

# Rank correlation between the PA values and the error bounds.
tau, p_value = stats.kendalltau(temp_list, bound_temp_list)
print("Kendall tau = %.4f (p-value = %.4g)" % (tau, p_value))

# Optional: figure(figsize=(8, 6), dpi=80)
x = temp_list
y = bound_temp_list
for xx, yy, zz, jj in zip(x, y, colors, labels):
    plt.scatter(xx, yy, c=zz, label=jj)
plt.ylabel('Error Bound', fontsize=18)
plt.xlabel(r'$\mathcal{PA}$', fontsize=18)
plt.grid(linestyle='-')
# Create the legend once. A second bare plt.legend() call would replace this
# one and discard the requested placement.
legend = plt.legend(loc='upper left', borderaxespad=0.)
legend.get_frame().set_edgecolor('black')
plt.title(r"Correlation between $\mathcal{PA}$ and bound"'\n' r"under CNN with different $\rho$", fontsize=18)
plt.show()