import tensorflow as tf
import pickle
import sys
import tensorflow.keras
import os
import numpy as np
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
# from tensorflow.keras.utils.np_utils import *
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model

def _weight_norms_and_cosine(weights):
    """Return per-column L2 norms and the pairwise column cosine-similarity matrix.

    Args:
        weights: 2-D array whose columns are the vectors to compare
            (one column per neuron of the inspected layer).

    Returns:
        ``(norms, cos)``: ``norms`` has shape ``(n_cols,)`` and holds the raw
        (un-normalized) column lengths; ``cos`` has shape ``(n_cols, n_cols)``
        with ``cos[i, j]`` the cosine of the angle between columns i and j.
        A column with zero norm has no direction, so its row and column in
        ``cos`` are NaN.
    """
    norms = np.sqrt(np.sum(weights ** 2, axis=0))
    # Zero-norm columns yield NaN here; silence the expected 0/0 warning.
    with np.errstate(divide='ignore', invalid='ignore'):
        unit = weights / norms[np.newaxis, :]
    # A single matrix product replaces an O(n^2) Python double loop.
    cos = unit.T @ unit
    return norms, cos


def _greedy_similarity_order(cos, threshold=0.6):
    """Return a permutation of neuron indices that groups similar neurons together.

    Starting with neuron 0 as the pivot, every not-yet-placed neuron whose
    cosine with the pivot exceeds ``threshold`` is appended (in ascending
    index order) as one group. The next pivot is the highest-indexed
    remaining neuron that is strongly anti-aligned with the current pivot
    (cosine < -threshold), or otherwise the lowest-indexed remaining neuron.

    The pivot is always placed in its own group, even when its
    self-similarity is NaN (zero-norm neuron). This guarantees progress, so
    every index appears exactly once in the result.

    Args:
        cos: square ``(n, n)`` cosine-similarity matrix.
        threshold: similarity magnitude used for grouping and for picking an
            anti-aligned next pivot.

    Returns:
        List of ``n`` distinct ints.
    """
    remaining = list(range(cos.shape[0]))
    order = []
    pivot = 0
    while remaining:
        row = cos[pivot]
        group = []
        rest = []
        for i in remaining:
            # NaN comparisons are False, hence the explicit pivot check.
            if i == pivot or row[i] > threshold:
                group.append(i)
            else:
                rest.append(i)
        order.extend(group)
        remaining = rest
        if rest:
            anti_aligned = [i for i in rest if row[i] < -threshold]
            pivot = anti_aligned[-1] if anti_aligned else rest[0]
    return order


def _save_similarity_heatmap(cos, path):
    """Save ``cos`` as a heatmap (colour range [-1, 1]) to ``path`` and close the figure."""
    n = cos.shape[0]
    fig, ax = plt.subplots()
    image = ax.imshow(cos, cmap='YlGnBu_r', vmin=-1, vmax=1)
    colorbar = fig.colorbar(image, ax=ax, ticks=[-1.0, -0.5, 0.0, 0.5, 1.0])
    colorbar.ax.tick_params(labelsize=14)
    ax.set_xlabel(r'Neu index', fontsize=18, labelpad=-17.0)
    ax.set_ylabel(r'Neu index', fontsize=18, rotation=90, labelpad=-15.0)
    # Only label the extremes of each axis.
    ax.set_xticks([0, n - 1])
    ax.set_yticks([n - 1])
    ax.set_title("$D(u,v)$", fontsize=18)
    ax.invert_yaxis()
    ax.tick_params(labelsize=18)
    fig.tight_layout()
    fig.savefig(path)
    # Close this specific figure; no trailing clf(), which would spawn a new one.
    plt.close(fig)


def _save_length_plot(lengths, path):
    """Plot ``lengths`` against 1-based neuron index, save to ``path``, and close the figure."""
    size = lengths.shape[0]
    fig, ax = plt.subplots()
    ax.plot(np.arange(1, size + 1), lengths)
    ax.set_xticks([1, size])
    ax.set_xlabel(r'Index', fontsize=18)
    ax.set_ylabel(r'Amplitude', fontsize=18)
    ax.tick_params(axis='y', which='major', labelsize=14)
    ax.tick_params(axis='x', which='major', labelsize=7)
    ax.set_title("Amplitude: $x^{2}tanh(x)$", fontsize=18)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def get_cos_distance(model, target_size, pathh, epochs):
    """Plot pairwise cosine similarity and weight norms of the 'dense1' layer's neurons.

    The first weight array of layer ``'dense1'`` is treated as a 2-D matrix
    whose columns are neurons. The function computes the cosine similarity
    between every pair of columns (despite the function name, this is a
    similarity in [-1, 1], not a distance). It reorders the neurons so that
    similar ones are adjacent, then writes two images:

    * ``<pathh>/heatmap_step_<epochs>.png``: the reordered similarity matrix.
    * ``<pathh>/length_step_<epochs>.png``: the raw L2 norm of each neuron's
      weight column, in the same order.

    Side effect: sets ``plt.rcParams['savefig.dpi']`` and
    ``plt.rcParams['figure.dpi']`` to 200 globally.

    Args:
        model: object with ``get_layer(name='dense1')``; that layer's
            ``get_weights()[0]`` is used (presumably a Keras Dense kernel of
            shape ``(input_dim, units)``; confirm against caller).
        target_size: number of neurons (columns) expected in that kernel.
        pathh: output directory, interpolated as ``'%s/...'``.
        epochs: integer step number used in the output file names.

    Raises:
        ValueError: if the kernel is not 2-D with ``target_size`` columns.
    """
    size = target_size
    kernel = np.asarray(model.get_layer(name='dense1').get_weights()[0])
    if kernel.ndim != 2 or kernel.shape[1] != size:
        raise ValueError(
            "expected a 2-D 'dense1' kernel with %d columns, got shape %s"
            % (size, kernel.shape))

    # Lengths come from the raw kernel. Computing them after normalization
    # would make every amplitude exactly 1.
    lengths, cos = _weight_norms_and_cosine(kernel)

    order = np.asarray(_greedy_similarity_order(cos), dtype=np.intp)
    cos_sorted = cos[np.ix_(order, order)]
    lengths_sorted = lengths[order]

    plt.rcParams['savefig.dpi'] = 200  # DPI of saved images
    plt.rcParams['figure.dpi'] = 200   # DPI of created figures

    _save_similarity_heatmap(cos_sorted, r'%s/heatmap_step_%d.png' % (pathh, epochs))
    _save_length_plot(lengths_sorted, r'%s/length_step_%d.png' % (pathh, epochs))