import os,sys
import matplotlib
matplotlib.use('Agg')   
import pickle
import time  
import shutil 
import tensorflow.compat.v1 as tf
import numpy as np
import matplotlib.pyplot as plt   
#from BasicFunc import mySaveFig, mkdir
import platform
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from datetime import datetime
from mayavi import mlab
from sympy import *
from mayavi.mlab import *
from matplotlib.lines import Line2D

# Set CUDA_VISIBLE_DEVICES to '0' before any TensorFlow session is created,
# so that only the first GPU is visible to this process.
os.environ["CUDA_VISIBLE_DEVICES"]='0'
# The script builds a static graph and drives it through tf.Session further
# down, so eager execution is switched off.
tf.disable_eager_execution()

############# Create the output directories ##############
def mkdir(fn):
    """Create the directory ``fn`` if it does not exist yet.

    Missing parent directories are created as well, and an already existing
    directory is left untouched.

    Args:
        fn: Path of the directory to create.

    Raises:
        FileExistsError: If ``fn`` exists but is not a directory.
    """
    # exist_ok=True replaces the racy "isdir() then mkdir()" sequence.
    os.makedirs(fn, exist_ok=True)
# Random run id: |N(1, 1)| * 1e5 truncated to an integer.  A scalar sample is
# drawn so that int() is never applied to a one-element array (a conversion
# NumPy has deprecated).
ran = int(abs(np.random.normal(loc=1.0)) * 100000)
sBaseDir0 = 'fitnd'
if platform.system() == 'Windows':
    BaseDir0 = r'XXX/%s' % (sBaseDir0)
else:
    BaseDir0 = sBaseDir0
    matplotlib.use('Agg')
subFolderName = '%s' % (ran)
# FolderName keeps its trailing '/', later code concatenates file names to it.
FolderName = '%s/%s/' % (BaseDir0, subFolderName)
mkdir(BaseDir0)
mkdir(FolderName)


# Keep a copy of this script next to the results it produced.
shutil.copy(__file__, '%s%s' % (FolderName, os.path.basename(__file__)))


############# Training-related parameters ##############
# Q collects the configuration and the recorded results of this run; it is
# pickled by savefile() at the end of the script.
Q = {}
Q['data_num_train'] = 40
Q['data_num_test'] = 400
# Filled inside the training loop with the per-neuron values of D_g_theta_w.
D_g_theta_weight = []

########### Target function ##########
# The number of training points comes from Q so that it always agrees with the
# reshape to (data_num_train, 1) below.
x_train = np.linspace(start=-1.0, stop=1.5, num=Q['data_num_train'], endpoint=True)
y_train = np.sin(x_train) + 1/2 * np.sin(3*x_train)

# Plot of the training samples, stored in the run folder with the other output.
plt.scatter(x_train, y_train, c='b', s=10)
plt.savefig('%strue.png' % (FolderName))
plt.close()

x_test = np.linspace(start=-1.0, stop=1.5, num=Q['data_num_test'], endpoint=True)

print(np.shape(x_train))
print(np.shape(y_train))


############ Network / optimisation settings #############
Q['FolderName'] = FolderName
# Data as column vectors, matching the [None, 1] placeholders of the graph.
Q['train_set'] = np.reshape(x_train, [Q['data_num_train'], 1])
Q['train_label'] = np.reshape(y_train, [Q['data_num_train'], 1])
Q['test_set'] = np.reshape(x_test, [Q['data_num_test'], 1])
Q['input_size'] = 1
Q['output_size'] = 1
Q['hidden_layer'] = [100]          # width of each hidden layer
Q['in_learning_rate'] = 2e-4       # learning rate handed to Adam
Epoches = 1050                     # upper bound on training iterations
Q['loss'] = []                     # training loss, one entry per iteration
Q['tol'] = 1e-4                    # training stops once the loss drops below this
# Per-snapshot records appended by the training loop.
Q['w_plot'] = []
Q['b_plot'] = []
Q['w_dot_plot'] = []
Q['b_dot_plot'] = []
Q['plot_num'] = 75                 # grid points per axis of the (w, b) field plot
Q['Omega'] = []
Q['Omega_ini'] = 0
Q['Amplitude'] = []
Q['Amplitude_ini'] = 0
########### Definition of g_{theta}(w) #########
def relu(inX):
    """Return the element-wise maximum of ``inX`` and zero (rectified linear unit)."""
    floor = 0
    return np.maximum(floor, inX)

def D_relu(iinX):
    """Derivative of the rectified linear unit.

    Args:
        iinX: A scalar or an array-like of pre-activation values.

    Returns:
        For a scalar input, the int ``1`` if ``iinX > 0`` and ``0`` otherwise
        (this includes NaN, for which no comparison holds).  For an array
        input, an integer array of the same shape filled with 1 where the
        entry is positive and 0 elsewhere.
    """
    if np.ndim(iinX) == 0:
        # Single expression, so every scalar (NaN included) gets a value
        # instead of falling through both branches and returning None.
        return 1 if iinX > 0 else 0
    return (np.asarray(iinX) > 0).astype(int)

def D_tanh(iinX):
    """Return the derivative of tanh at ``iinX``, i.e. ``1 - tanh(iinX)**2``."""
    tanh_value = np.tanh(iinX)
    return 1 - tanh_value**2


def D_g_theta_w(w,b,a,y_net_train):
    """Accumulate the (w, b) direction of one neuron over the training set.

    For each of the first ``Q['data_num_train']`` training samples ``x_i`` the
    pre-activation ``z = w * x_i + b`` is formed and the term

        a * d/dz[z * tanh(z)] * (x_i, 1) * (y_net_train[i] - y_train[i])

    is subtracted from a running total.  ``z * tanh(z)`` is the hidden-layer
    activation built in ``multilayer``.

    Reads the module-level ``Q``, ``x_train`` and ``y_train``.

    Args:
        w: Input weight of the neuron.
        b: Bias of the neuron.
        a: Factor multiplying every term (the callers pass the neuron's output
            weight, or 1).
        y_net_train: Network outputs on the training set, indexed per sample.

    Returns:
        The negative sum of the terms above; its first component belongs to
        ``w`` and its second to ``b``.  The exact shape follows NumPy
        broadcasting with ``y_net_train[i]``.
    """
    total = 0
    for sample in range(Q['data_num_train']):
        x_i = x_train[sample]
        z = w * x_i + b * 1
        # Product rule for z**1 * tanh(z), kept in its explicit power form.
        slope = z**1 * D_tanh(z) + 1 * z**0 * np.tanh(z)
        residual = y_net_train[sample] - y_train[sample]
        total = total - a * slope * np.array([x_i, 1]) * residual
    return total


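###################
#   With the form <X, W> + b, the number of columns of X is the number of neurons and the number of rows is the number of samples.
#   Build the parameters w and b.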
def initializer_generate(inp_size=10 , hidden_layer = [20] , out_size=10 ):
    """Create random-normal initial values for a fully connected network.

    Every tensor is drawn with mean 0 and standard deviation 0.005 as float32.

    Args:
        inp_size: Width of the network input.
        hidden_layer: Sequence with the width of each hidden layer; it must
            contain at least one entry.
        out_size: Width of the network output.

    Returns:
        ``(Weights_ini, Biases_ini)``: a weight tensor for every layer
        including the output layer, and a ``[1, width]`` bias tensor for every
        hidden layer (the output layer has no bias).
    """
    def draw(shape):
        # All initial values share the same distribution.
        return tf.random_normal(shape=shape, dtype='float32', mean=0.0, stddev=0.005)

    # Input -> first hidden layer.
    Weights_ini = [draw([inp_size, hidden_layer[0]])]
    Biases_ini = [draw([1, hidden_layer[0]])]
    # Hidden -> hidden, for each pair of consecutive widths.
    for fan_in, fan_out in zip(hidden_layer, hidden_layer[1:]):
        Weights_ini.append(draw([fan_in, fan_out]))
        Biases_ini.append(draw([1, fan_out]))
    # Last hidden layer -> output (weights only).
    Weights_ini.append(draw([hidden_layer[-1], out_size]))
    return Weights_ini, Biases_ini

def Init_DNN( inp_size=10 , hidden_layer = [20] , out_size=10 ,Weights_ini=0,Biases_ini=0): 
    """Wrap the given initial values in trainable ``tf.Variable`` objects.

    Args:
        inp_size: Unused; kept so the signature mirrors ``initializer_generate``.
        hidden_layer: Sequence of hidden-layer widths; only its length is used.
        out_size: Unused; kept so the signature mirrors ``initializer_generate``.
        Weights_ini: Indexable collection of initial weights, one per layer
            including the output layer.
        Biases_ini: Indexable collection of initial biases, one per hidden layer.

    Returns:
        ``(Weights, Biases)``: lists of variables in layer order.  The
        variables are created as weight, bias, weight, bias, ... followed by
        the output weight.
    """
    Weights = []
    Biases = []
    # The first layer is always read, even for an empty hidden_layer.
    hidden_count = max(len(hidden_layer), 1)
    for layer in range(hidden_count):
        Weights.append(tf.Variable(Weights_ini[layer]))
        Biases.append(tf.Variable(Biases_ini[layer]))
    # The output layer has a weight matrix but no bias.
    Weights.append(tf.Variable(Weights_ini[-1]))
    return Weights, Biases

#   Build the network. In the fully connected case, even the network structure is implicit in Weights and Biases.
def multilayer(X, Weights, Biases, activation = tf.nn.tanh): 
    """Build the forward pass of the fully connected network.

    Each hidden layer forms the pre-activation ``Z = H @ W + B`` and outputs
    ``activation(Z) * Z**1``.  The output layer is a plain matrix product
    without bias or activation.

    Args:
        X: Input tensor with one sample per row.
        Weights: List of weight tensors, one per layer; the last entry belongs
            to the output layer.
        Biases: List of bias tensors, one per hidden layer.
        activation: Function applied to each hidden pre-activation.

    Returns:
        The output tensor of the network.
    """
    H = X
    for k in range(len(Weights) - 1):
        # Build the pre-activation once and reuse it for both factors, instead
        # of adding two identical matmul/add pairs to the graph.
        Z = tf.add(tf.matmul(H, Weights[k]), Biases[k])
        # The exponent 1 matches the derivative hard-coded in D_g_theta_w.
        H = activation(Z) * Z**1
    out = tf.matmul(H, Weights[-1])
    return out

############### Build the network ###################
# Assemble the graph: placeholders, parameters, forward pass, loss and optimiser.
with tf.variable_scope('Graph',reuse=tf.AUTO_REUSE) as scope:
    # Network input and regression target, one value per row.
    X = tf.placeholder(tf.float32,shape=[None,1],name = 'X')
    # NOTE(review): the placeholder name 'Y_ture' looks like a typo of 'Y_true';
    # nothing in this file looks the tensor up by name.
    Y_true = tf.placeholder(tf.float32,shape=[None,1],name = 'Y_ture')
    # Random initial values first, then the trainable variables created from them.
    Weights0 ,Biases0 = initializer_generate(inp_size=Q['input_size'],hidden_layer=Q['hidden_layer'],out_size=Q['output_size'])
    Weights1 ,Biases1 = Init_DNN(inp_size=Q['input_size'],hidden_layer=Q['hidden_layer'],out_size=Q['output_size'],Weights_ini=Weights0,Biases_ini=Biases0)
    # Network output for the fed input X.
    Y = multilayer(X,Weights1,Biases1)
    # Mean squared error between the network output and the target.
    Loss=tf.reduce_mean((Y-Y_true)**2)
    # Adam with the learning rate stored in Q; train_op minimises Loss.
    adam = tf.train.AdamOptimizer(learning_rate=Q['in_learning_rate'])
    train_op = adam.minimize(Loss)
    print('the project seems healthy!')

##################

# Session configuration: soft placement lets TensorFlow fall back to another
# device when an op cannot run on the requested one, and allow_growth makes the
# GPU memory pool grow on demand.  (A separate GPUOptions object with a memory
# fraction used to be created here but was never attached to the config, so it
# had no effect and has been dropped.)
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Give every variable of the graph its initial value before training starts.
sess.run(tf.global_variables_initializer())
# Created for checkpointing; not used further in this file.
saver = tf.train.Saver()

##### Start training
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases removed the old names, so pick whichever one is installed.
if 'seaborn-v0_8-whitegrid' in plt.style.available:
    whitegrid_style = 'seaborn-v0_8-whitegrid'
else:
    whitegrid_style = 'seaborn-whitegrid'

for itepch in range(Epoches):
    # Every 10th iteration: snapshot the network and draw the diagnostic figures.
    if itepch%10==0:
        # Half-width of the square (w, b) window shown in the field plot.
        if itepch<=510:
            ranges = 0.3
        else:
            ranges = 0.6
        Y_net_train ,Weights_Net_1 ,Biases1_Net_1 = sess.run([Y,Weights1,Biases1], feed_dict={X:Q['train_set'], Y_true :Q['train_label']})

        # Input weight and bias of every neuron of the first hidden layer
        # (row 0 of the weight matrix; the network has a single input).
        w_hidden = Weights_Net_1[0][0]
        b_hidden = Biases1_Net_1[0][0]

        ####### Orientation and amplitude of each neuron's (w, b) vector ########
        Omega = b_hidden/w_hidden
        Amplitude = np.sqrt(b_hidden**2 + w_hidden**2)
        Q['Amplitude'].append(Amplitude)
        Q['Omega'].append(Omega)
        # arctan2 yields the full-circle angle of (w, b) in one call; it
        # replaces the hand-written four-quadrant case analysis and does not
        # divide by w.
        Omega_real = np.arctan2(b_hidden, w_hidden)
        Amplitude_real = Amplitude
        if itepch == 0:
            # Initial configuration, drawn for comparison in every later figure.
            Omega_real_ini = Omega_real
            Amplitude_real_ini = Amplitude_real
        print(np.shape(Amplitude))
        plt.scatter(Omega_real,Amplitude_real/np.max(Amplitude),color = 'r',s=10 , label = 'finial')
        plt.scatter(Omega_real_ini,Amplitude_real_ini/np.max(Amplitude),color = 'c',s=10 , label='initial')
        my_x_ticks = [-2.0,0,2.0]
        plt.xticks(my_x_ticks)
        my_y_ticks = [0.0,0.5,1.0]
        plt.yticks(my_y_ticks)
        plt.xlim([-3.5,3.5])
        plt.xlabel(r'orientation',fontsize=22)
        plt.ylabel(r'A',rotation=0,fontsize=22)
        plt.tick_params(axis='both',which='major',labelsize=20)
        plt.tight_layout()
        plt.savefig(r'%s/Omega_%s.png'%(Q['FolderName'],itepch))
        plt.close()

        ####### Vector field of D_g_theta_w on a regular (w, b) grid ########
        w_plot = []
        b_plot = []
        w_dot_plot = []
        b_dot_plot = []
        w_fake = np.linspace(-ranges,ranges, num = Q['plot_num'] )
        b_fake = np.linspace(-ranges,ranges, num = Q['plot_num'] )
        w_fake, b_fake = np.meshgrid(w_fake, b_fake)
        # distance holds the field magnitude, W_U / B_V its two components.
        distance = np.zeros((Q['plot_num'] ,Q['plot_num'] ))
        W_U = np.zeros((Q['plot_num'] ,Q['plot_num'] ))
        B_V = np.zeros((Q['plot_num'] ,Q['plot_num'] ))
        for irt in range(Q['plot_num'] ):
            for iiirt in range(Q['plot_num'] ):
                D_g_theta_weight_temp = D_g_theta_w(a = 1 , w = w_fake[irt][iiirt],b = b_fake[irt][iiirt],y_net_train = Y_net_train)
                distance[irt][iiirt] = np.sqrt(D_g_theta_weight_temp[0]**2+D_g_theta_weight_temp[1]**2)
                W_U[irt][iiirt] = D_g_theta_weight_temp[0]
                B_V[irt][iiirt] = D_g_theta_weight_temp[1]

        # The same quantity at the actual (w, b) of each hidden neuron, scaled
        # by that neuron's output weight.
        for irt in range(Q['hidden_layer'][0]):
            D_g_theta_weight_temp = D_g_theta_w(a = Weights_Net_1[1][irt][0] ,w = w_hidden[irt],b = b_hidden[irt],y_net_train = Y_net_train)
            D_g_theta_weight.append(D_g_theta_weight_temp)
            w_plot.append(w_hidden[irt])
            b_plot.append(b_hidden[irt])
            w_dot_plot.append(D_g_theta_weight_temp[0])
            b_dot_plot.append(D_g_theta_weight_temp[1])

        ######## Direction of the field at the origin ########
        D_g_theta_weight_0 = D_g_theta_w(a = 1 , w = 0 ,b = 0 ,y_net_train = Y_net_train)
        len_0 = np.sqrt(D_g_theta_weight_0[0]**2+D_g_theta_weight_0[1]**2)
        print(len_0)

        Q['w_plot'].append(np.array(w_plot))
        Q['b_plot'].append(np.array(b_plot))
        Q['w_dot_plot'].append(np.array(w_dot_plot))
        Q['b_dot_plot'].append(np.array(b_dot_plot))

        ####### Histogram of the per-neuron direction angles ########
        rangee = np.linspace(start=-np.pi/2,stop=np.pi/2,num=20,endpoint=True)
        # The small offset keeps the ratio finite when a w-component is zero.
        Omega = np.array(b_dot_plot)/(np.array(w_dot_plot)+1e-12)
        plt.style.use(whitegrid_style)
        plt.hist(np.arctan(Omega),rangee,histtype='barstacked',edgecolor = 'k',color = 'steelblue',label='num')
        plt.legend()
        plt.xlim([-2.0 ,2.0 ])
        plt.xlabel(r'angle',fontsize=18)
        plt.ylabel(r'num',fontsize=18)
        plt.tick_params(axis='both',which='major',labelsize=20)
        plt.tight_layout()
        plt.savefig(r'%s/D_Omega_%s.png'%(Q['FolderName'],itepch))
        plt.close()
        plt.style.use( 'default')

        ####### Field pattern: streamlines plus the neurons' positions ########
        ax = plt.gca()
        plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
        plt.rcParams['figure.dpi'] = 200   # resolution of figures created from here on
        plt.xlim(-ranges,ranges)
        plt.ylim(-ranges,ranges)
        plt.xlabel(r'w',fontsize=22)
        plt.ylabel(r'b',rotation=0,fontsize=22)
        plt.scatter(np.array(w_plot), np.array(b_plot),color = 'g',s = 20)
        ax.streamplot(w_fake, b_fake, W_U, B_V, color='k', linewidth=1.0, density=2 ,arrowsize = 1,arrowstyle ='-|>' )
        plt.contour(w_fake,b_fake,distance,levels=[1e-8], colors=['r'], linewidths = 2)
        # Dashed red arrow of total length 0.18, centred on the origin and
        # pointing along the field direction there.
        ax.arrow(0.18 * -1/2 * D_g_theta_weight_0[0]/len_0,0.18 *  -1/2 * D_g_theta_weight_0[1]/len_0,0.18 *  D_g_theta_weight_0[0]/len_0,0.18 * D_g_theta_weight_0[1]/len_0, width=0.004,fill=False,ec='red',linestyle = '--',length_includes_head=True,head_width = 0.012)
        my_x_ticks = [-ranges+0.1,0,ranges-0.1]
        plt.xticks(my_x_ticks)
        my_y_ticks = [-ranges+0.1,0,ranges-0.1]
        plt.yticks(my_y_ticks)
        plt.scatter(np.array(w_plot), np.array(b_plot),color = 'coral',s = 20)
        # Dashed coordinate axes through the origin.
        plt.plot(np.zeros((50,)),np.linspace(-ranges,ranges,num=50,endpoint=True),color='b',linewidth=1,linestyle='--')
        plt.plot(np.linspace(-ranges,ranges,num=50,endpoint=True),np.zeros((50,)),color='b',linewidth=1,linestyle='--')
        plt.tick_params(axis='both',which='major',labelsize=20)
        plt.tight_layout()
        plt.savefig(r'%s/field_pattern_%s.png'%(Q['FolderName'],itepch))
        plt.close()

    # Loss on the training set before this iteration's parameter update.
    train_loss = sess.run(Loss, feed_dict={X:Q['train_set'] , Y_true :Q['train_label']})
    Q['loss'].append(train_loss)

    # Stop as soon as the loss is below the tolerance.
    if train_loss < Q['tol']:
        break
    if itepch%10==0:
        print('training loss:',train_loss)
        Y_test_network = sess.run(Y, feed_dict={X:Q['test_set']})
        ############## Loss curve ############
        plt.figure()
        ax = plt.gca()
        plt.plot(Q['loss'])
        plt.title('loss',fontsize=15)
        ax.set_yscale('log')
        plt.xlabel(r'epochs',fontsize=20)
        plt.ylabel(r'loss',fontsize=20)
        plt.tick_params(axis='both',which='major',labelsize=20)
        plt.tight_layout()
        plt.savefig(r'%s/loss_%s.png'%(Q['FolderName'],itepch))
        plt.close()

        ############## Network output on the test grid ############
        x_poly = x_test
        # The parameters have not changed since Y_test_network was evaluated,
        # so its single output column is reused instead of running the graph again.
        Y_poly_network = Y_test_network[:,0]
        print(np.shape(x_poly))
        print(np.shape(Y_poly_network))
        # Degree-1 least-squares fit of the network output (the auxiliary line).
        z1 = np.polyfit(x_poly, Y_poly_network, 1)
        print(z1)
        p1 = np.poly1d(z1)  # coefficients ordered from highest to lowest degree
        Y_poly = x_poly**1 * z1[0] +  x_poly**0 * z1[1]
        print(p1)

        plt.figure()
        plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
        plt.rcParams['figure.dpi'] = 200   # resolution of figures created from here on
        plt.plot(Q['test_set'],Y_test_network, linewidth=2.0, color='blue',label = 'NN output', linestyle='-')
        plt.plot(x_poly,Y_poly, linewidth=2.0, color='red',label = 'auxiliary line',linestyle='--')
        plt.scatter(x_train,y_train,label = 'training data', color='g')
        plt.xlabel(r'x',fontsize=24)
        plt.ylabel(r'y',rotation=0,fontsize=24)
        my_x_ticks = [-1.0,-0.5,0,0.5,1.0,1.5]
        plt.xticks(my_x_ticks)
        my_y_ticks = [-1.0,-0.5,0.0,0.5,1.0]
        plt.yticks(my_y_ticks)
        plt.tick_params(axis='both',which='major',labelsize=20)
        plt.tight_layout()
        plt.savefig(r'%s/y_test_%s.png'%(Q['FolderName'],itepch))
        plt.close()

    # Every figure above is closed right after it is saved, so no extra figure
    # clean-up is needed here (the former bare `plt.clf` was never called).
    sess.run(train_op, feed_dict={X:Q['train_set'] , Y_true :Q['train_label']})


def savefile():
    """Store the results of this run in the run folder.

    Writes two files into ``FolderName``: a pickle of the whole ``Q``
    dictionary, and a text file listing every entry of ``Q`` with at most 20
    elements followed by the command-line arguments of the script.
    """
    # Serialise the complete dictionary.
    with open('%s/cifar10_检查.pkl'%(FolderName), 'wb') as f:
        pickle.dump(Q, f, protocol=4)
    # Context manager so the handle is closed even if a write fails.
    with open("%s/Output_检查.txt"%(FolderName), "w") as text_file:
        for para in Q:
            # Large entries (recorded histories, data sets) live in the pickle only.
            if np.size(Q[para])>20:
                continue
            text_file.write('%s: %s\n'%(para,Q[para]))

        for para in sys.argv:
            text_file.write('%s  '%(para))


savefile()