from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


# Plot how much each of the 4 layers' parameters moved during training as a
# function of network width n, together with 1/n and 1/sqrt(n) reference slopes.
#
# Input files are named "<width><suffix>.npy" (e.g. "16_no_KL_gd.npy"). Each is
# loaded with allow_pickle=True and its element [0] is indexed as
# info[layer][k]: k=0 selects the rho quantity and k=1 the mu quantity. The last
# entry [-1] of that sequence is what gets plotted (presumably the value at the
# final training step T, matching the plot titles -- TODO confirm).
#
# NOTE(review): allow_pickle=True can execute arbitrary code stored in the file;
# only load result files you produced yourself.
#
# The rho figure is drawn by default. The mu figure can be drawn with
#     main("mu", widths=WIDTHS + (4096,))

# Network widths n plotted on the x axis (file names start with these).
WIDTHS = (16, 32, 64, 128, 256, 512, 1024, 2048)
# One file-name suffix per run. Error bars are the std across runs, so with a
# single suffix they are zero; add more suffixes (e.g. per seed) to get spread.
RUN_SUFFIXES = ("_no_KL_gd",)
N_LAYERS = 4

# Position of each quantity inside info[layer].
_COMPONENT_INDEX = {"rho": 0, "mu": 1}

TITLES = {
    "mu": r'$\left \|| \mu_{T}^{l} - \mu_{0}^{l} \right\||_F / \left \|| \mu_{0}^{l} \right\||_F$',
    "rho": r'$\left \|| \rho_{T}^{l} - \rho_{0}^{l} \right\||_F / \left \|| \rho_{0}^{l} \right\||_F$',
}
YLIMS = {
    "mu": (1 / np.power(2, 19), 1 / 32),
    "rho": (1 / np.power(2, 19) / 40, 1 / 200),
}

# (legend label, colour) for each layer, in layer order.
LAYER_STYLES = (
    ("Layer 1 (input)", "blue"),
    ("Layer 2 ", "green"),
    ("Layer 3 ", "red"),
    ("Layer 4 (output)", "skyblue"),
)

# (layer whose mean anchors the line, exponent p of the 1/n**p slope, colour, label)
REFERENCE_LINES = (
    (1, 1.0, "red", r"$1/n$"),
    (2, 1.0, "red", "_nolegend_"),
    (0, 0.5, "black", r"$1/\sqrt{n}$"),
    (3, 0.5, "black", "_nolegend_"),
)


def _final_values(info, component):
    """Return ``info[layer][k][-1]`` for every layer, with k chosen by *component*."""
    k = _COMPONENT_INDEX[component]
    return [info[layer][k][-1] for layer in range(N_LAYERS)]


def load_final_values(widths=WIDTHS, run_suffixes=RUN_SUFFIXES, component="rho", directory="."):
    """Load the final per-layer value of *component* for every (run, width) pair.

    Args:
        widths: network widths; each file name starts with the width.
        run_suffixes: one file-name suffix per run.
        component: "rho" or "mu".
        directory: folder holding the ``.npy`` files.

    Returns:
        Float array of shape ``(N_LAYERS, len(run_suffixes), len(widths))``.

    Raises:
        ValueError: if *component* is unknown or no widths/runs are given.
    """
    if component not in _COMPONENT_INDEX:
        raise ValueError(f"component must be one of {sorted(_COMPONENT_INDEX)}, got {component!r}")
    if not widths or not run_suffixes:
        raise ValueError("need at least one width and one run suffix")
    runs = []
    for suffix in run_suffixes:
        per_width = []
        for width in widths:
            big = np.load(Path(directory) / f"{width}{suffix}.npy", allow_pickle=True)
            per_width.append(_final_values(big[0], component))
        runs.append(per_width)
    # (runs, widths, layers) -> (layers, runs, widths)
    return np.asarray(runs, dtype=float).transpose(2, 0, 1)


def reference_curve(x, anchor_value, exponent):
    """Return a 1/x**exponent curve passing through (x[-1], anchor_value).

    Anchoring on the largest plotted width (rather than a hard-coded 2048) keeps
    the reference line on the data whatever widths are plotted.
    """
    x = np.asarray(x, dtype=float)
    return anchor_value * (x[-1] / x) ** exponent


def plot_relative_change(x, layer_values, title, ylim, ax=None):
    """Draw the mean (and std across runs) of each layer on log2-log2 axes.

    Args:
        x: widths, one per column of *layer_values*.
        layer_values: array ``(N_LAYERS, n_runs, len(x))`` from load_final_values.
        title: axes title (mathtext).
        ylim: ``(bottom, top)`` y-limits.
        ax: Axes to draw on; a new 4.5x4 inch figure is created when None.

    Returns:
        The Axes that was drawn on.
    """
    x = np.asarray(x)
    values = np.asarray(layer_values, dtype=float)
    if ax is None:
        _, ax = plt.subplots(figsize=(4.5, 4))
    means = values.mean(axis=1)
    stds = values.std(axis=1)
    for (label, color), mean, std in zip(LAYER_STYLES, means, stds):
        # Labelled line first, unlabelled error bars second, so the legend
        # lists the layers before the reference slopes.
        ax.plot(x, mean, linestyle="-", linewidth=1.0, color=color, label=label)
        ax.errorbar(x, mean, yerr=std, color=color, elinewidth=2, capsize=4)
    for layer, exponent, color, label in REFERENCE_LINES:
        ax.plot(x, reference_curve(x, means[layer][-1], exponent),
                linestyle="--", linewidth=1.0, color=color, label=label)
    # `base=` replaces the `basex`/`basey` keywords removed from Matplotlib.
    # Scales are set before ticks/limits because changing scale resets locators.
    ax.set_xscale("log", base=2)
    ax.set_yscale("log", base=2)
    ax.set_xlim(x[0], x[-1])
    ax.set_ylim(*ylim)
    ax.set_xticks(x)
    ax.tick_params(axis="both", labelsize=18)
    ax.set_title(title, fontsize=18)
    ax.grid(linestyle="-")
    ax.set_xlabel(r"n", fontsize=18)
    # One legend call: a second plt.legend() would replace this lower-left
    # legend with a default-placed one.
    legend = ax.legend(loc="lower left", borderaxespad=0.0)
    legend.get_frame().set_edgecolor("black")
    return ax


def main(component="rho", widths=WIDTHS, run_suffixes=RUN_SUFFIXES):
    """Load the results for *component* and draw its figure; return the Axes."""
    values = load_final_values(widths, run_suffixes, component)
    return plot_relative_change(widths, values, TITLES[component], YLIMS[component])


if __name__ == "__main__":
    main()