import os,sys
import matplotlib
matplotlib.use('Agg')   
import pickle
import time  
import shutil 
import tensorflow.compat.v1 as tf
import numpy as np
import matplotlib.pyplot as plt   
#from BasicFunc import mySaveFig, mkdir
import platform
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from datetime import datetime
import seaborn as sns
from matplotlib.lines import Line2D
import math

# Expose only GPU 0 to TensorFlow and switch to TF1 graph mode
# (the script builds placeholders and runs them in a tf.Session).
os.environ.update(CUDA_VISIBLE_DEVICES='0')
tf.disable_eager_execution()

############# Training hyper-parameters ##############
# Q collects every hyper-parameter and every quantity recorded during training;
# savefile() pickles the whole dict into the run folder.  Key spellings
# ('nerons', 'thredhood', ...) are kept as-is because they are stored keys.
Q = {}
Q['input_size'] = 1
Q['output_size'] = 1
Q['nerons_each_layer_m'] = 2500  # width m of every hidden layer
Q['hidden_layer'] = [Q['nerons_each_layer_m'], Q['nerons_each_layer_m']]
Q['in_learning_rate'] = 5e2
Epoches = 1000000
# Number of longest second-layer neurons shown in the cosine-similarity heatmap.
Q['thredhood'] = int(Q['nerons_each_layer_m'] / 2)
Q['cos_dis_thr'] = 0.5  # similarity above which neurons are grouped together
Q['tol'] = 1e-7         # training stops once the loss falls below this

# Histories filled in during training.
Q['RD_w_0'] = []
Q['RD_w_1'] = []
Q['RD_w_0_last'] = 0
Q['RD_w_1_last'] = 0
Q['S_w_0'] = []
Q['S_w_1'] = []
Q['Omega_layer1'] = []
Q['cos_distance_matrix'] = []
Q['step'] = []
Q['loss'] = []
Q['w_plot'] = []
Q['b_plot'] = []
Q['w_dot_plot_layer1'] = []
Q['b_dot_plot_layer1'] = []
Q['w1_w2_w1_of_layer_2to3'] = []
Q['w1_w2_w2_of_layer_2to3'] = []

# Initialization scales: standard deviation of the first-layer weights/biases,
# of the hidden-to-hidden weights/biases and of the output weights; the
# network output is additionally multiplied by 1 / alpha.
Q['W_1_beta_1'] = 1.0
Q['W_2_beta_2'] = Q['W_1_beta_1'] * Q['nerons_each_layer_m'] ** (-0.5)
Q['a_beta_3'] = Q['W_2_beta_2']
Q['alpha'] = Q['W_1_beta_1'] ** 3 * Q['nerons_each_layer_m'] ** (1.5)
Q['kappa_1'] = Q['a_beta_3'] / Q['W_2_beta_2']
Q['kappa_2'] = Q['a_beta_3'] / Q['W_1_beta_1']
Q['kappa_3'] = Q['a_beta_3'] * Q['W_1_beta_1'] * Q['W_2_beta_2'] / Q['alpha']
# Exponents defined by kappa_i = m ** (-gamma_i).
Q['gamma_1'] = - np.log(Q['kappa_1']) / np.log(Q['nerons_each_layer_m'])
Q['gamma_2'] = - np.log(Q['kappa_2']) / np.log(Q['nerons_each_layer_m'])
Q['gamma_3'] = - np.log(Q['kappa_3']) / np.log(Q['nerons_each_layer_m'])

Q['Y_test_network'] = []
Q['Y_net_train'] = []
print('gamma_2:', Q['gamma_2'])
print('gamma_3:', Q['gamma_3'])

############# Output directories ##############
# Plain Python helper: it only touches the filesystem, so it must not be
# wrapped in tf.function (a stray decorator used to sit above this comment).
def mkdir(fn):
    """Create directory ``fn``, including missing parents; do nothing if it exists.

    Raises:
        FileExistsError: if ``fn`` exists but is not a directory.
    """
    os.makedirs(fn, exist_ok=True)
# Random run id so repeated runs with the same width land in separate folders.
# (A scalar draw: int() of a 1-element array is deprecated in recent NumPy.)
ran = int(np.absolute(np.random.normal(1.0) * 100000))
sBaseDir0 = 'fitnd'
if platform.system() == 'Windows':
    BaseDir0 = r'E:\deep_study\20210601_threelayers_limit_width\fitn2_gam2_0.5/%s' % (sBaseDir0)
else:
    # The Agg backend is already selected at import time.
    BaseDir0 = sBaseDir0
subsubFolderName = 'm_%s_%s' % (Q['nerons_each_layer_m'], ran)
subFolderName = 'gam3_%.2f' % (Q['gamma_3'])
FolderName = '%s/%s/' % (BaseDir0, subFolderName)
mkdir(BaseDir0)
mkdir(FolderName)
# Final run folder: <base>/gam3_<gamma_3>/m_<width>_<run id>/
FolderName = r'%s/%s/%s/' % (BaseDir0, subFolderName, subsubFolderName)
mkdir(FolderName)

# Keep a copy of this script next to its results for reproducibility.
shutil.copy(__file__, '%s%s' % (FolderName, os.path.basename(__file__)))


def _entry_size(value):
    """Element count of a Q entry, matching np.size for homogeneous values.

    Lists/tuples are counted item by item.  This avoids copying long histories
    (e.g. a list of large matrices) into one temporary array on every save, and
    it also works for ragged lists, where np.size raises on recent NumPy.
    """
    if isinstance(value, (list, tuple)):
        return sum(np.size(item) for item in value)
    return np.size(value)


def savefile():
    """Save the run state ``Q`` into ``FolderName``.

    Writes ``threelayers_Phasediagram.pkl`` (all of Q, pickle protocol 4) and
    ``threelayers_Phasediagram.txt`` (one ``key: value`` line for every entry
    with at most 1000 elements, followed by the command-line arguments).
    """
    with open('%s/threelayers_Phasediagram.pkl' % (FolderName), 'wb') as f:
        pickle.dump(Q, f, protocol=4)
    with open("%s/threelayers_Phasediagram.txt" % (FolderName), "w") as text_file:
        for para in Q:
            if _entry_size(Q[para]) > 1000:
                continue  # skip long histories; they are in the pickle
            text_file.write('%s: %s\n' % (para, Q[para]))
        for para in sys.argv:
            text_file.write('%s  ' % (para))
############# Training / test data ##############
# Four training points of the target function; the sample count is derived
# from the data so the two cannot drift apart.
x_train = [-1.5, -0.5, 0.5, 1.5]
y_train = [1.0, 0, 0, 1.0]

Q['data_num_train'] = len(x_train)
Q['data_num_test'] = 400

Q['x_train'] = list(x_train)
Q['y_train'] = list(y_train)

plt.scatter(x_train, y_train, c='b', s=10)
plt.savefig(r'%s/true.png' % (FolderName))
plt.close()

x_test = np.linspace(start=-2.0, stop=2.0, num=Q['data_num_test'], endpoint=True)

print(np.shape(x_train))
print(np.shape(y_train))

############ Arrays fed to the network: one sample per row #############
Q['FolderName'] = FolderName
Q['train_set'] = np.reshape(x_train, [Q['data_num_train'], 1])
Q['train_label'] = np.reshape(y_train, [Q['data_num_train'], 1])
Q['test_set'] = np.reshape(x_test, [Q['data_num_test'], 1])

###################
# Parameters follow the affine form <X, W> + b: X has one row per sample and
# one column per neuron, so each W is [fan_in, fan_out] and each b is [1, fan_out].
# Build the initial values of the parameters W and b.
def initializer_generate(inp_size=10, hidden_layer=[20], out_size=10):
    """Create random-normal (mean 0) initial-value tensors for every layer.

    - input layer  (inp_size -> hidden_layer[0]): weights and biases, std Q['W_1_beta_1']
    - each hidden -> hidden layer: weights and biases, std Q['W_2_beta_2'];
      the layer index k is printed
    - output layer (hidden_layer[-1] -> out_size): weights only, std Q['a_beta_3']

    Returns:
        (weights, biases): lists of float32 tensors in layer order; ``weights``
        has one extra trailing entry because the output layer has no bias.
    """
    def _gaussian(shape, std):
        return tf.random_normal(shape=shape, dtype='float32', mean=0.0, stddev=std)

    first_width = hidden_layer[0]
    weights = [_gaussian([inp_size, first_width], Q['W_1_beta_1'])]
    biases = [_gaussian([1, first_width], Q['W_1_beta_1'])]
    for k, (fan_in, fan_out) in enumerate(zip(hidden_layer[:-1], hidden_layer[1:])):
        weights.append(_gaussian([fan_in, fan_out], Q['W_2_beta_2']))
        biases.append(_gaussian([1, fan_out], Q['W_2_beta_2']))
        print(k)
    weights.append(_gaussian([hidden_layer[-1], out_size], Q['a_beta_3']))
    return weights, biases

def Init_DNN(inp_size=10, hidden_layer=[20], out_size=10, Weights_ini=0, Biases_ini=0):
    """Wrap the initial-value tensors in trainable ``tf.Variable`` objects.

    Only ``len(hidden_layer)`` is used, to know how many weight/bias pairs
    precede the output layer.  ``inp_size`` and ``out_size`` are accepted to
    mirror ``initializer_generate`` but are not used.

    Returns:
        (Weights, Biases): lists of variables in layer order; ``Weights`` has
        one extra trailing entry for the bias-free output layer.
    """
    # The first weight/bias pair is always created, even for an empty hidden_layer.
    Weights = [tf.Variable(Weights_ini[0])]
    Biases = [tf.Variable(Biases_ini[0])]
    for idx in range(1, len(hidden_layer)):
        Weights.append(tf.Variable(Weights_ini[idx]))
        Biases.append(tf.Variable(Biases_ini[idx]))
    Weights.append(tf.Variable(Weights_ini[-1]))
    return Weights, Biases

#   Build the network.  For a fully connected net the whole architecture is
#   implied by the Weights and Biases lists.
def multilayer(X, Weights, Biases, activation=tf.nn.relu):
    """Forward pass of the fully connected network.

    Every weight matrix except the last is applied as
    ``activation(H @ W + B)``.  The last one is a bias-free read-out scaled
    by ``1 / Q['alpha']``.

    Returns:
        (out, outs): the network output and the list of every layer value
        ``[X, h_1, ..., h_L, out]``.
    """
    hidden = X
    outs = [hidden]
    for k, W in enumerate(Weights[:-1]):
        hidden = activation(tf.add(tf.matmul(hidden, W), Biases[k]))
        outs.append(hidden)
    out = (1 / Q['alpha']) * tf.matmul(hidden, Weights[-1])
    outs.append(out)
    return out, outs

############### Build the network ###################
with tf.variable_scope('Graph', reuse=tf.AUTO_REUSE) as scope:
    # Placeholder widths come from Q (not a hard-coded 1) so they stay in sync
    # with the layer sizes passed to initializer_generate / Init_DNN below.
    X = tf.placeholder(tf.float32, shape=[None, Q['input_size']], name='X')
    # NOTE: the graph node name 'Y_ture' (sic) is kept for compatibility.
    Y_true = tf.placeholder(tf.float32, shape=[None, Q['output_size']], name='Y_ture')
    Weights0, Biases0 = initializer_generate(inp_size=Q['input_size'], hidden_layer=Q['hidden_layer'], out_size=Q['output_size'])
    Weights1, Biases1 = Init_DNN(inp_size=Q['input_size'], hidden_layer=Q['hidden_layer'], out_size=Q['output_size'], Weights_ini=Weights0, Biases_ini=Biases0)
    Y, layers = multilayer(X, Weights1, Biases1)
    Loss = tf.reduce_mean((Y - Y_true) ** 2)  # mean squared error
    # Plain gradient descent.  It used to be bound to the misleading name
    # `adam`; that name is kept as an alias for existing references.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=Q['in_learning_rate'])
    adam = optimizer
    train_op = optimizer.minimize(Loss)
    print('the project seems healthy!')

##################
# Session, training loop and final snapshot.
#
# Each diagnostic lives in one helper below.  The loop calls the helpers
# periodically, and the final snapshot calls them once more on the last
# parameters (instead of repeating the whole loop body after the loop).

# 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in Matplotlib 3.6,
# and the old name was removed in 3.8; use whichever this Matplotlib provides.
_WHITEGRID_STYLE = ('seaborn-v0_8-whitegrid'
                    if 'seaborn-v0_8-whitegrid' in plt.style.available
                    else 'seaborn-whitegrid')


def _weight_directions(weights, biases, layer):
    """Stack weight matrix ``layer`` over its bias row and transpose.

    Row i is the incoming weight vector of unit i followed by its bias.
    """
    return np.transpose(np.concatenate((weights[layer], biases[layer]), axis=0))


def _relative_distance(now, initial):
    """Frobenius ||now - initial|| / ||initial||.

    The tiny epsilons keep the value positive at step 0 (now == initial),
    because its log is taken afterwards.
    """
    return (np.sqrt(np.sum((now - initial) ** 2 + 1e-20))
            / np.sqrt(np.sum(initial ** 2) + 1e-20))


def _record_relative_distance(weights, biases, initial_dirs):
    """Append RD_w_{0,1} and S_w_{0,1} = log(RD) / log(m) to Q, then print RD."""
    log_m = np.log(Q['nerons_each_layer_m'])
    for layer in (0, 1):
        rd = _relative_distance(_weight_directions(weights, biases, layer),
                                initial_dirs[layer])
        Q['RD_w_%d' % layer].append(rd)
        Q['S_w_%d' % layer].append(np.log(rd) / log_m)
        Q['RD_w_%d_last' % layer] = rd
    print('RD_w_0', Q['RD_w_0'][-1])
    print('RD_w_1', Q['RD_w_1'][-1])


def _plot_layer1(weights, biases, itepch):
    """First layer: histogram of arctan(b / w) and a w-b scatter (initial vs now)."""
    rangee = np.linspace(start=-np.pi / 2, stop=np.pi / 2, num=20, endpoint=True)
    Q['Omega_layer1'].append(biases[0][0] / weights[0][0])
    plt.style.use(_WHITEGRID_STYLE)
    plt.hist(np.arctan(Q['Omega_layer1'][-1]), rangee, histtype='barstacked',
             edgecolor='k', color='steelblue', label='num')
    plt.legend()
    plt.xlim([-2.0, 2.0])
    plt.xlabel(r'angle', fontsize=18)
    plt.ylabel(r'num', fontsize=18)
    plt.savefig(r'%s/Omega_of_layer_1to2_%s.png' % (Q['FolderName'], itepch))
    plt.close()
    plt.style.use('default')

    Q['w_dot_plot_layer1'].append(np.array(weights[0][0]))
    Q['b_dot_plot_layer1'].append(np.array(biases[0][0]))
    plt.figure()
    plt.rcParams['savefig.dpi'] = 200  # saved-image resolution
    plt.rcParams['figure.dpi'] = 200
    plt.xlabel(r'w', fontsize=18)
    plt.ylabel(r'b', fontsize=18)
    plt.scatter(Q['w_dot_plot_layer1'][0], Q['b_dot_plot_layer1'][0], color='r', s=10, label='initial')
    plt.scatter(Q['w_dot_plot_layer1'][-1], Q['b_dot_plot_layer1'][-1], color='g', s=10, label='now')
    plt.legend()
    plt.savefig(r'%s/w_b_of_layer_1to2_%s.png' % (Q['FolderName'], itepch))
    plt.close()


def _plot_largest_w1_w2(weights, biases, itepch):
    """Scatter the (weights, bias) vectors of the two longest layer-2 neurons against each other."""
    w_dir = _weight_directions(weights, biases, 1)
    pos = np.argsort(np.sum(w_dir ** 2, axis=1))
    # .copy(): a row slice is a view that would keep the whole matrix alive in Q.
    Q['w1_w2_w1_of_layer_2to3'].append(w_dir[pos[-1], :].copy())
    Q['w1_w2_w2_of_layer_2to3'].append(w_dir[pos[-2], :].copy())

    plt.figure()
    plt.rcParams['savefig.dpi'] = 200
    plt.rcParams['figure.dpi'] = 200
    plt.xlabel(r'w1', fontsize=18)
    plt.ylabel(r'w2', fontsize=18)
    plt.scatter(Q['w1_w2_w1_of_layer_2to3'][0], Q['w1_w2_w2_of_layer_2to3'][0], color='r', s=10, label='initial')
    plt.scatter(Q['w1_w2_w1_of_layer_2to3'][-1], Q['w1_w2_w2_of_layer_2to3'][-1], color='g', s=10, label='now')
    plt.legend()
    plt.savefig(r'%s/largest_w1_w2_layer2to3_%s.png' % (Q['FolderName'], itepch))
    plt.close()


def _cluster_order(cos_matrix, thr):
    """Greedy grouping order for the rows of a cosine-similarity matrix.

    Starting from pivot 0, every still-unplaced index whose similarity to the
    pivot exceeds ``thr`` is appended as one group.  The first leftover index
    becomes the next pivot.  Stops once all indices are placed, or after n rounds.
    """
    n = cos_matrix.shape[0]
    remaining = list(range(n))
    ordered = []
    pivot = 0
    for _ in range(n):
        group, rest = [], []
        for i in remaining:
            if cos_matrix[pivot][i] > thr:
                group.append(i)
            else:
                rest.append(i)
        ordered += group
        if len(ordered) == n:
            break
        pivot = rest[0]
        remaining = rest
    return ordered


def _plot_cos_heatmaps(weights, biases, itepch):
    """Cosine-similarity heatmaps of the Q['thredhood'] longest layer-2 neurons."""
    w_dir = _weight_directions(weights, biases, 1)
    lengths = np.sum(w_dir ** 2, axis=1)
    longest_first = np.argsort(lengths)[::-1][:Q['thredhood']]
    top = w_dir[longest_first].astype(np.float64)
    norms = np.sqrt(np.sum(top ** 2, axis=1))
    # Vectorized cos(i, j) = <v_i, v_j> / (|v_i| |v_j|) (was an O(T^2) Python loop).
    cos_distance = (top @ top.T) / np.outer(norms, norms)
    order = _cluster_order(cos_distance, Q['cos_dis_thr'])
    cos_distance_matrix = cos_distance[np.ix_(order, order)]
    Q['cos_distance_matrix'].append(cos_distance_matrix)

    plt.rcParams['savefig.dpi'] = 300
    plt.rcParams['figure.dpi'] = 300
    fig, ax = plt.subplots()
    sns.heatmap(cos_distance_matrix, vmin=-1, vmax=1, cmap='YlGnBu_r', ax=ax)
    ax.xaxis.set_ticks_position('top')
    ax.set_title("cos distance: relu", fontsize=16)
    plt.savefig(r'%s/heatmap_1_step_%s.png' % (Q['FolderName'], itepch))
    plt.close(fig)

    print(np.shape(cos_distance_matrix))
    # One figure only: the old code opened two, so `ax` pointed at the wrong
    # (never closed) figure and the ticks setting was lost.
    fig, ax = plt.subplots()
    image = ax.imshow(cos_distance_matrix, cmap='YlGnBu_r')
    fig.colorbar(image)
    ax.xaxis.set_ticks_position('top')
    ax.set_title("cos distance: relu", fontsize=16)
    image.set_clim(-1, 1)
    fig.savefig(r'%s/heatmap_2_step_%s.png' % (Q['FolderName'], itepch))
    plt.close(fig)


def _report_and_plot(train_loss, itepch):
    """Print the loss, record test/train predictions and plot the loss curve and fit."""
    print('training loss:', train_loss)
    Q['Y_test_network'].append(sess.run(Y, feed_dict={X: Q['test_set']}))

    plt.figure()
    ax = plt.gca()
    plt.plot(Q['loss'])
    plt.title('loss', fontsize=15)
    ax.set_yscale('log')
    plt.xlabel(r'epochs', fontsize=18)
    plt.ylabel(r'loss', fontsize=18)
    ax.set_xscale('log')
    plt.savefig(r'%s/loss_%s.png' % (Q['FolderName'], itepch))
    plt.close()

    Q['Y_net_train'].append(sess.run(Y, feed_dict={X: Q['train_set']}))
    plt.figure()
    plt.plot(Q['test_set'], Q['Y_test_network'][-1], linewidth=1.0, label='test')
    plt.scatter(Q['x_train'], Q['y_train'], label='train_true', color='b')
    plt.scatter(Q['x_train'], Q['Y_net_train'][-1], label='train_net', color='y')
    plt.xlabel(r'x', fontsize=18)
    plt.ylabel(r'y', fontsize=18)
    plt.legend()
    plt.savefig(r'%s/y_test_%s.png' % (Q['FolderName'], itepch))
    plt.close()


# Device setup: soft placement, GPU memory grown on demand and capped at 70%.
# (The cap used to be built into a GPUOptions object that was never attached.)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7, allow_growth=True)
config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())  # initialize all network parameters
saver = tf.train.Saver()

train_feed = {X: Q['train_set'], Y_true: Q['train_label']}

##### Training
for itepch in range(Epoches):
    if itepch % 100 == 0:
        # One fetch serves every diagnostic of this step: the variables only
        # change when train_op runs at the end of the iteration.
        weights_now, biases_now = sess.run([Weights1, Biases1])
        if itepch == 0:
            W_dir_initial_0 = _weight_directions(weights_now, biases_now, 0)
            W_dir_initial_1 = _weight_directions(weights_now, biases_now, 1)
        _record_relative_distance(weights_now, biases_now, (W_dir_initial_0, W_dir_initial_1))
        _plot_layer1(weights_now, biases_now, itepch)
        if itepch % 500 == 0:
            _plot_largest_w1_w2(weights_now, biases_now, itepch)
            _plot_cos_heatmaps(weights_now, biases_now, itepch)

    Q['step'].append(itepch)
    train_loss = sess.run(Loss, feed_dict=train_feed)
    Q['loss'].append(train_loss)
    if train_loss < Q['tol']:
        break
    if itepch % 100 == 0:
        _report_and_plot(train_loss, itepch)

    sess.run(train_op, feed_dict=train_feed)  # one gradient step

    # Checkpoint periodically: pickling the ever-growing Q on every step is
    # very expensive once the histories (e.g. heatmap matrices) accumulate.
    if itepch % 100 == 0:
        savefile()

########## Final snapshot of the last parameters (also after an early stop) ##########
weights_now, biases_now = sess.run([Weights1, Biases1])
_record_relative_distance(weights_now, biases_now, (W_dir_initial_0, W_dir_initial_1))
_plot_layer1(weights_now, biases_now, itepch)
_plot_largest_w1_w2(weights_now, biases_now, itepch)
_plot_cos_heatmaps(weights_now, biases_now, itepch)
Q['step'].append(itepch)
train_loss = sess.run(Loss, feed_dict=train_feed)
Q['loss'].append(train_loss)
_report_and_plot(train_loss, itepch)

savefile()
