import os,sys
import matplotlib
matplotlib.use('Agg')   
import pickle
import time  
import shutil 
import numpy as np
import matplotlib.pyplot as plt   
#from BasicFunc import mySaveFig, mkdir
import platform
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from datetime import datetime
from mayavi import mlab
from sympy import *
from mayavi.mlab import *
from matplotlib.lines import Line2D
from sklearn.cluster import KMeans
import seaborn as sns

# NOTE: despite its name, `target_folder` holds the path of the pickle file
# itself, not of a directory.
target_folder = r'XXX/object.pkl'

# Unpickling can execute arbitrary code, so only point this at a trusted file.
with open(target_folder, mode='rb') as f:
    RR = pickle.load(f)

# List the fields stored in the loaded record.
print(RR.keys())

# Number of indices (0 .. size-1) that are reordered and drawn in the plots.
size = 50

def _similarity_order(cos_mat, count, same_thresh=0.85, opposite_thresh=-0.5):
    """Return a permutation of ``range(count)`` that groups similar indices.

    Starting from pivot index 0, every still-unassigned index whose entry in
    the pivot's row of ``cos_mat`` exceeds ``same_thresh`` is appended to the
    result as one group.  The next pivot is the last unassigned index whose
    entry is below ``opposite_thresh``; if there is none, the first index
    that is still unassigned is used instead.

    Args:
        cos_mat: square matrix indexable as ``cos_mat[row][col]``.
        count: number of leading indices to order.
        same_thresh: entries strictly above this join the pivot's group.
        opposite_thresh: entries strictly below this mark the next pivot.

    Returns:
        A list containing each index in ``range(count)`` exactly once.
    """
    remaining = list(range(count))
    ordered = []
    pivot = 0
    for _ in range(count):
        row = cos_mat[pivot]
        group = []
        rest = []
        opposite = None
        for idx in remaining:
            if row[idx] > same_thresh:
                group.append(idx)
            else:
                rest.append(idx)
            if row[idx] < opposite_thresh:
                opposite = idx
        ordered.extend(group)
        if not rest:
            # Every index has been placed.
            break
        pivot = rest[0] if opposite is None else opposite
        remaining = rest
    else:
        # The pass budget ran out without placing everything, which happens
        # when a pivot is not similar to itself (e.g. a NaN row).  Append the
        # leftovers so the caller always receives a full permutation.
        ordered.extend(remaining)
    return ordered


def _save_heatmap(matrix, out_path):
    """Draw ``matrix`` as a heat map with a fixed [-1, 1] colour scale and save it."""
    fig, ax = plt.subplots()
    image = ax.imshow(matrix, cmap='YlGnBu_r')
    cb = fig.colorbar(image, ax=ax, ticks=[-1.0, -0.5, 0.0, 0.5, 1.0])
    cb.ax.tick_params(labelsize=14)
    image.set_clim(-1, 1)
    ax.set_xticks([])  # hide the x-axis ticks
    ax.set_yticks([])  # hide the y-axis ticks
    ax.set_title("cos distance: $x^{2}tanh(x)$", fontsize=18)
    fig.tight_layout()
    fig.savefig(out_path)
    # Close this exact figure.  Calling plt.clf() after plt.close() would
    # implicitly create a fresh empty figure that is never released.
    plt.close(fig)


def _save_amplitude_plot(amplitudes, count, out_path):
    """Plot ``amplitudes`` against the 1-based positions ``1..count`` and save it."""
    fig, ax = plt.subplots()
    ax.plot(np.arange(1, count + 1), amplitudes)
    ax.set_xlabel(r'Index', fontsize=18)
    ax.set_ylabel(r'Amplitude', fontsize=18)
    ax.tick_params(axis='y', which='major', labelsize=14)
    ax.tick_params(axis='x', which='major', labelsize=7)
    ax.set_xticks([1, 10, 20, 30, 40, count])
    ax.set_title("Amplitude: $x^{2}tanh(x)$", fontsize=18)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


# Output resolution is the same for every figure, so configure it once.
plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image files
plt.rcParams['figure.dpi'] = 200  # resolution of the figure itself

cos_distance_matrix = RR['cos_distance_matrix']

# One heat map and one amplitude plot per recorded snapshot.
# NOTE(review): file names use iii*5, presumably because snapshots were
# recorded every 5 training steps -- confirm against the code that wrote RR.
for iii in range(200):
    W_now = RR['W'][iii]
    # Euclidean norm of every row of W_now, kept as a column vector.
    W_now_lenth = np.reshape(
        np.sqrt(np.sum(W_now**2, axis=1)), (RR['hidden_layer'][0], 1)
    )
    print(np.shape(W_now_lenth))

    cos_distance_matrix_temp = cos_distance_matrix[iii]

    # Reorder rows and columns so that mutually similar indices are adjacent.
    order_temp = _similarity_order(cos_distance_matrix_temp, size)
    cos_distance_matrix_temp = cos_distance_matrix_temp[order_temp, :]
    cos_distance_matrix_temp = cos_distance_matrix_temp[:, order_temp]
    W_now_lenth = W_now_lenth[order_temp, :]

    # Drop rows/columns whose norm is below 0.03 from the heat map only; the
    # amplitude plot below still shows every index.
    record = np.flatnonzero(W_now_lenth[:, 0] < 0.03)
    cos_distance_matrix_temp = np.delete(cos_distance_matrix_temp, record, axis=1)
    cos_distance_matrix_temp = np.delete(cos_distance_matrix_temp, record, axis=0)

    _save_heatmap(cos_distance_matrix_temp, r'XXX/heatmap_step_%s.png' % (iii * 5))
    _save_amplitude_plot(W_now_lenth, size, r'XXX/length_step_%s.png' % (iii * 5))


