import os

import matplotlib.pyplot as plt
import numpy as np
from tool import loaddata, makedir

savedir = '/home/***/data/undergraky/experiment/Cifar10/corssentropy2linear/230327103217tanhadam3layer'
data = loaddata(os.path.join(savedir,'trainpro.pkl'))

# Position of the analysed weight array inside every snapshot.  The output
# file names say "layer3", so this is presumably the third layer of the
# network -- TODO confirm against the code that writes trainpro.pkl.
LAYER_INDEX = 2
# Row of the weight array whose direction is used as the sorting reference.
REFERENCE_KERNEL = 1
# Font size shared by all axis labels and tick labels.
FONT_SIZE = 20


def _plot_length_and_similarity(lengths, similarities, path):
    """Save a twin-axis figure of kernel lengths and cosine similarities.

    lengths      -- 1-D array drawn as a red line on a log-scaled left axis.
    similarities -- 1-D array drawn as blue points on the right axis, whose
                    limits are fixed to [-1.05, 1.05].
    path         -- file the figure is written to (300 dpi).
    """
    fig, ax_length = plt.subplots()

    ax_length.set_xlabel('kernel index', fontsize=FONT_SIZE)
    ax_length.set_ylabel('kernel length (log)', fontsize=FONT_SIZE)
    # The log axis shows minor tick labels as well, so size both kinds.
    ax_length.tick_params(axis='y', which='both', labelsize=FONT_SIZE)
    ax_length.tick_params(axis='x', labelsize=FONT_SIZE)
    ax_length.set_yscale('log')
    length_line, = ax_length.plot(lengths, color='r')

    ax_similarity = ax_length.twinx()
    ax_similarity.set_yticks([-1.0, -0.5, 0, 0.5, 1.0])
    ax_similarity.tick_params(axis='y', which='both', labelsize=FONT_SIZE)
    ax_similarity.set_ylabel('cosine similarity', fontsize=FONT_SIZE)
    ax_similarity.set_ylim(-1.05, 1.05)
    # x positions follow the data length instead of a fixed kernel count.
    similarity_points = ax_similarity.scatter(
        np.arange(len(similarities)), similarities, color='b')

    # Handles are passed explicitly so each label is tied to its artist
    # rather than to the order in which matplotlib discovers artists.
    fig.legend(handles=[length_line, similarity_points],
               labels=['kernel length', 'cosine similarity'],
               loc='lower right', bbox_to_anchor=(0.81, 0.17))
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)


def _plot_inner_products(inner_products, path):
    """Save a heat map of a square inner-product matrix.

    inner_products -- 2-D array; colours are clipped to the range [-1, 1].
    path           -- file the figure is written to (300 dpi).
    """
    fig, ax = plt.subplots()
    mesh = ax.pcolormesh(inner_products, vmin=-1, vmax=1, cmap='YlGnBu')
    ax.set_xlabel('kernel index', fontsize=FONT_SIZE)
    ax.set_ylabel('kernel index', fontsize=FONT_SIZE)
    ax.tick_params(axis='both', labelsize=FONT_SIZE)
    colorbar = fig.colorbar(mesh, ax=ax, ticks=[-1, -0.5, 0, 0.5, 1])
    colorbar.ax.tick_params(labelsize=FONT_SIZE)
    fig.tight_layout()
    fig.savefig(path, dpi=300)
    plt.close(fig)


# One pair of figures is written per recorded snapshot.  Indexed access is
# kept because the container type returned by loaddata is not visible here.
for i in range(len(data['weightvector'])):
    kernels = np.asarray(data['weightvector'][i][LAYER_INDEX])

    # Row-wise Euclidean length and the corresponding unit-length rows.
    # NOTE: an all-zero row would produce a division by zero here.
    lengths = np.linalg.norm(kernels, axis=1)
    directions = kernels / lengths[:, np.newaxis]

    # Unit vector along the all-ones direction; the dot product of a
    # unit-length row with it is the cosine similarity that gets plotted.
    diagonal = np.ones_like(kernels[0])
    diagonal = diagonal / np.linalg.norm(diagonal)

    # Order the rows by their similarity to the reference row (ascending).
    order = np.argsort(np.dot(directions, directions[REFERENCE_KERNEL]))
    sorted_directions = directions[order]

    _plot_length_and_similarity(
        lengths[order],
        np.dot(sorted_directions, diagonal),
        os.path.join(savedir, 'layer3lengthandinn%s.png' % (i + 1)))

    # Pairwise inner products of the sorted unit-length rows.
    _plot_inner_products(
        np.dot(sorted_directions, sorted_directions.T),
        os.path.join(savedir, 'layer3linp%s.png' % (i + 1)))

    # Per-snapshot marker on stdout, kept from the original script.
    print(1)


# l1 = np.asarray(data['weightvector'][680][2])
    
# l1norm = np.linalg.norm(l1,axis=1)
# l1norm = l1norm.reshape(1024,1)
# # print(l1norm.shape)
# l1normal = l1/l1norm

# # print(l1normal.shape)
# # break
# l1norm = l1norm.reshape(1024)
# ind = np.argsort(l1norm)[::-1]
# # print(l1norm.shape)
# l1m = l1normal[ind]
# l1n = l1norm[ind]
# rankvector = np.dot(l1m,l1m[1])
# index = np.argsort(rankvector)

