import os,sys
import matplotlib
matplotlib.use('Agg')   
import pickle
import time  
import shutil 
import tensorflow.compat.v1 as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt   
#from BasicFunc import mySaveFig, mkdir
import platform
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from datetime import datetime
# import seaborn as sns
from matplotlib.lines import Line2D
import math

# Helper functions.
def mkdir(fn):
    """Create directory ``fn``, including any missing parents, if absent.

    Args:
        fn: Path of the directory to create.

    Raises:
        FileExistsError: if ``fn`` exists but is not a directory.
    """
    # Deliberately a plain Python function: under ``@tf.function`` the Python
    # body only executes while TensorFlow traces the function, so filesystem
    # side effects are not guaranteed to happen on every call.
    # ``exist_ok=True`` also removes the race between an ``isdir`` check and
    # a subsequent ``mkdir``.
    os.makedirs(fn, exist_ok=True)


def savefile(q=None, folder_name=None):
    """Save the run record dict into the run folder.

    Writes two files into ``folder_name``:

    * ``threelayers_Phasediagram.pkl`` -- the whole dict, pickled with
      protocol 4 (Python 3 only).
    * ``threelayers_Phasediagram.txt`` -- one ``key: value`` line for every
      entry whose ``np.size`` is at most 1000, followed by the command-line
      arguments (``sys.argv``), each followed by two spaces.

    Args:
        q: Dict to save. Defaults to the module-level ``Q``.
        folder_name: Destination directory. Defaults to the module-level
            ``FolderName``.
    """
    if q is None:
        q = Q
    if folder_name is None:
        folder_name = FolderName

    # Serialize the whole dict so the run can be reloaded later.
    with open('%s/threelayers_Phasediagram.pkl' % (folder_name), 'wb') as f:
        pickle.dump(q, f, protocol=4)

    # Human-readable summary. Large arrays are skipped to keep it short.
    # ``with`` guarantees the file is closed even if formatting a value fails.
    with open("%s/threelayers_Phasediagram.txt" % (folder_name), "w") as text_file:
        for para in q:
            if np.size(q[para]) > 1000:
                continue
            text_file.write('%s: %s\n' % (para, q[para]))
        for para in sys.argv:
            text_file.write('%s  ' % (para))


#   Layers compute <X, W> + b: X has one column per neuron and one row per
#   sample, so weights are [fan_in, fan_out] and biases are [1, fan_out].
#   Build the initial values of the weights W and biases b.
def initializer_generate(inp_size=10, hidden_layer=(20,), out_size=10):
    """Draw random-normal initial values for every layer of the network.

    Standard deviations are read from the module-level config dict ``Q``:

    * input -> first hidden layer: ``Q['W_1_beta_1']``;
    * hidden -> hidden layers:     ``Q['W_2_beta_2']``;
    * last hidden -> output:       ``Q['a_beta_3']``.

    Each bias uses the same stddev as its layer's weights. The output layer
    has no bias.

    Args:
        inp_size: Number of input features.
        hidden_layer: Widths of the hidden layers (non-empty sequence).
        out_size: Number of network outputs.

    Returns:
        ``(Weights_ini, Biases_ini)``: lists of float32 tensors holding
        ``len(hidden_layer) + 1`` weight matrices and ``len(hidden_layer)``
        bias rows, in layer order.
    """
    # A tuple default avoids the shared-mutable-default pitfall. Any sequence
    # supporting len(), indexing and slicing is accepted.
    Weights_ini = []
    Biases_ini = []
    fan_ins = [inp_size] + list(hidden_layer[:-1])
    for k, (fan_in, fan_out) in enumerate(zip(fan_ins, hidden_layer)):
        # Only the first layer (fed by the input) uses the beta_1 scale.
        std = Q['W_1_beta_1'] if k == 0 else Q['W_2_beta_2']
        Weights_ini.append(tf.random_normal(shape=[fan_in, fan_out], dtype='float32', mean=0.0, stddev=std))
        Biases_ini.append(tf.random_normal(shape=[1, fan_out], dtype='float32', mean=0.0, stddev=std))
    # Read-out layer: weights only, no bias.
    Weights_ini.append(tf.random_normal(shape=[hidden_layer[-1], out_size], dtype='float32', mean=0.0, stddev=Q['a_beta_3']))
    return Weights_ini, Biases_ini

def Init_DNN(inp_size=10, hidden_layer=(20,), out_size=10, Weights_ini=None, Biases_ini=None):
    """Wrap initial weight/bias values in trainable ``tf.Variable`` objects.

    Variables are created in the order W0, B0, W1, B1, ..., W_out, which
    determines their auto-generated graph names.

    Args:
        inp_size: Unused; kept for signature symmetry with
            ``initializer_generate``.
        hidden_layer: Widths of the hidden layers. Only its length is used.
        out_size: Unused; kept for signature symmetry.
        Weights_ini: Sequence of ``len(hidden_layer) + 1`` initial weight
            values, for example from ``initializer_generate``.
        Biases_ini: Sequence of ``len(hidden_layer)`` initial bias values.

    Returns:
        ``(Weights, Biases)``: lists of ``tf.Variable``.

    Raises:
        ValueError: if ``Weights_ini`` or ``Biases_ini`` is not supplied.
    """
    if Weights_ini is None or Biases_ini is None:
        raise ValueError('Init_DNN needs Weights_ini and Biases_ini '
                         '(e.g. from initializer_generate()).')
    Weights = [tf.Variable(Weights_ini[0])]
    Biases = [tf.Variable(Biases_ini[0])]
    for k in range(1, len(hidden_layer)):
        Weights.append(tf.Variable(Weights_ini[k]))
        Biases.append(tf.Variable(Biases_ini[k]))
    # Read-out layer has weights only.
    Weights.append(tf.Variable(Weights_ini[-1]))
    return Weights, Biases

#   Build the forward pass. For a fully-connected net the whole architecture
#   is implied by the shapes of the tensors in Weights and Biases.
def multilayer(X, Weights, Biases, activation=tf.nn.relu):
    """Run ``X`` through the fully-connected network.

    Every layer except the last computes ``activation(H @ W + B)``. The last
    layer is linear, has no bias, and is scaled by ``1 / Q['alpha']`` (read
    from the module-level config dict ``Q``).

    Args:
        X: Input tensor, one row per sample.
        Weights: Per-layer weight tensors; the last one is the read-out layer.
        Biases: Bias tensors for every layer except the read-out layer.
        activation: Nonlinearity applied after each hidden layer.

    Returns:
        ``(out, outs)``: the network output and the list of every layer's
        value, starting with ``X`` itself and ending with ``out``.
    """
    hidden = X
    outs = [hidden]
    for k, weight in enumerate(Weights[:-1]):
        # Ordinary hidden layer; the read-out layer below is special.
        hidden = activation(tf.add(tf.matmul(hidden, weight), Biases[k]))
        outs.append(hidden)
    out = (1 / Q['alpha']) * tf.matmul(hidden, Weights[-1])
    outs.append(out)
    return out, outs

# Load MNIST and shift pixel values from [0, 255] to [-0.5, 0.5].
# The GPU is selected here, before any TensorFlow session exists.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
mnist = keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0 - 0.5, x_test / 255.0 - 0.5
print("********************************")
print("mnist is loaded!")
print("********************************")

# Network widths to sweep over.
M = [10000]
# The rest of the script builds TF1-style graphs (placeholders + sessions).
tf.disable_eager_execution()

# Train on the first n_image images only, each flattened to a vector. Labels
# become a column vector to match the network's [batch, 1] output.
# reshape(-1, 1) avoids hard-coding the training-set size.
y_train = y_train.reshape(-1, 1)
n_image = 45
ImgSize = 28
x_train = np.reshape(x_train[0:n_image, :, :], [n_image, ImgSize * ImgSize])
y_train = y_train[0:n_image, :]
# Full sweep used previously:
#   Gamma2 = [0.2, 0.1, 0, -0.1]
#   Gamma3 = [0.9, 1.1, 1.3, 1.5, 1.7, 1.9, 2.1]
Gamma2 = [0.1]
Gamma3 = [1.3]
# Training stops early once the loss drops below tol.
tol = 1e-5
in_learning_rate = 3.8e-5

def _weight_directions(weights, biases, layer):
    """Return layer ``layer``'s weight matrix with its bias row appended, transposed.

    ``weights[layer]`` is ``[fan_in, fan_out]`` and ``biases[layer]`` is
    ``[1, fan_out]`` (shapes as produced by ``initializer_generate``), so each
    row of the result is one neuron's incoming weights followed by its bias.
    """
    return np.transpose(np.concatenate((weights[layer], biases[layer]), axis=0))


def _relative_distance(now, initial):
    """Return ||now - initial|| / ||initial|| (Frobenius norms).

    The 1e-20 terms keep the ratio and its logarithm finite. The numerator
    adds it element-wise before summing, matching the recorded statistics.
    """
    return (np.sqrt(np.sum((now - initial) ** 2 + 1e-20))
            / np.sqrt(np.sum(initial ** 2) + 1e-20))


def _record_relative_distance(q, weights, biases, initial_0, initial_1):
    """Append RD (relative distance) and S (= log_m RD) for hidden layers 0 and 1 to ``q``, then print RD."""
    rd_0 = _relative_distance(_weight_directions(weights, biases, 0), initial_0)
    rd_1 = _relative_distance(_weight_directions(weights, biases, 1), initial_1)
    log_m = np.log(q['nerons_each_layer_m'])
    q['RD_w_0'].append(rd_0)
    q['RD_w_1'].append(rd_1)
    q['S_w_0'].append(np.log(rd_0) / log_m)
    q['S_w_1'].append(np.log(rd_1) / log_m)
    q['RD_w_0_last'] = rd_0
    q['RD_w_1_last'] = rd_1
    print('RD_w_0', rd_0)
    print('RD_w_1', rd_1)


# Sweep over the (gamma2, gamma3, width) grid. For each setting run trials
# i = 40..69. A trial whose training loss ever increases is marked failed
# (success = 0) and rerun with the same index i; the learning rate is not
# lowered.
for gamma2 in Gamma2:
    for gamma3 in Gamma3:
        for m in M:
            Q = {}
            Q['in_learning_rate'] = in_learning_rate
            Q['tol'] = tol
            Q['nerons_each_layer_m'] = m
            # Initialisation scales for the requested (gamma2, gamma3) point.
            # The kappa/gamma values computed below recover these exponents.
            Q['W_1_beta_1'] = m ** (-0.5 * gamma3 + 0.75 * gamma2)
            Q['W_2_beta_2'] = m ** (-0.5 * gamma3 - 0.25 * gamma2)
            Q['alpha'] = m ** (-0.5 * gamma3 + 0.25 * gamma2)
            Q['a_beta_3'] = Q['W_2_beta_2']

            i = 40
            while i < 70:
                success = 1
                Q['x_train'] = x_train
                Q['x_test'] = x_test
                Q['y_train'] = y_train
                Q['y_test'] = y_test
                Q['input_size'] = 784
                Q['output_size'] = 1
                Q['hidden_layer'] = [m, m]
                Epoches = 100000
                Q['thredhood'] = int(m / 2)
                Q['cos_dis_thr'] = 0.5
                Q['RD_w_0'] = []
                Q['RD_w_1'] = []
                Q['RD_w_0_last'] = 0
                Q['RD_w_1_last'] = 0
                Q['S_w_0'] = []
                Q['S_w_1'] = []
                Q['Omega_layer1'] = []
                Q['cos_distance_matrix'] = []
                Q['step'] = []
                Q['loss'] = []
                Q['w_plot'] = []
                Q['b_plot'] = []
                Q['w_dot_plot_layer1'] = []
                Q['b_dot_plot_layer1'] = []
                Q['w1_w2_w1_of_layer_2to3'] = []
                Q['w1_w2_w2_of_layer_2to3'] = []
                Q['kappa_1'] = Q['a_beta_3'] / Q['W_2_beta_2']
                Q['kappa_2'] = Q['a_beta_3'] / Q['W_1_beta_1']
                Q['kappa_3'] = Q['a_beta_3'] * Q['W_1_beta_1'] * Q['W_2_beta_2'] / Q['alpha']
                Q['gamma_1'] = - np.log(Q['kappa_1']) / np.log(m)
                Q['gamma_2'] = - np.log(Q['kappa_2']) / np.log(m)
                Q['gamma_3'] = - np.log(Q['kappa_3']) / np.log(m)

                Q['Y_test_network'] = []
                Q['Y_net_train'] = []
                print('gamma_2:', Q['gamma_2'])
                print('gamma_3:', Q['gamma_3'])

                # ---- Output folder: <base>/gam2_X/gam3_Y/m_<m>_<lr>_<i>/ ----
                sBaseDir0 = 'fitnd'
                if platform.system() == 'Windows':
                    BaseDir0 = r'E:\deep_study\20210601_threelayers_limit_width\fitn2_gam2_0.5/%s' % (sBaseDir0)
                else:
                    BaseDir0 = sBaseDir0
                subsubsubFolderName = 'm_%s_%s_%s' % (m, Q['in_learning_rate'], i)
                subsubFolderName = 'gam3_%.2f' % (Q['gamma_3'])
                subFolderName = 'gam2_%.2f' % (Q['gamma_2'])
                FolderName = '%s/%s/' % (BaseDir0, subFolderName)
                print(FolderName)
                mkdir(BaseDir0)
                mkdir(FolderName)
                FolderName = r'%s/%s/%s/' % (BaseDir0, subFolderName, subsubFolderName)
                print(FolderName)
                mkdir(FolderName)
                # FolderName (module-level) is also read by savefile().
                FolderName = r'%s/%s/%s/%s/' % (BaseDir0, subFolderName, subsubFolderName, subsubsubFolderName)
                print(FolderName)
                mkdir(FolderName)
                print(FolderName)
                print('****************************')

                # ---- Training data ----
                Q['data_num_train'] = n_image
                Q['FolderName'] = FolderName
                Q['train_set'], Q['train_label'] = np.reshape(x_train, [Q['data_num_train'], 784]), np.reshape(y_train, [Q['data_num_train'], 1])

                # ---- Graph ----
                with tf.variable_scope('Graph', reuse=tf.AUTO_REUSE):
                    X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
                    Y_true = tf.placeholder(tf.float32, shape=[None, 1], name='Y_ture')
                    Weights0, Biases0 = initializer_generate(inp_size=Q['input_size'], hidden_layer=Q['hidden_layer'], out_size=Q['output_size'])
                    Weights1, Biases1 = Init_DNN(inp_size=Q['input_size'], hidden_layer=Q['hidden_layer'], out_size=Q['output_size'], Weights_ini=Weights0, Biases_ini=Biases0)
                    Y, layers = multilayer(X, Weights1, Biases1)
                    Loss = tf.reduce_mean((Y - Y_true) ** 2)
                    # Plain full-batch gradient descent.
                    optimizer = tf.train.GradientDescentOptimizer(learning_rate=Q['in_learning_rate'])
                    train_op = optimizer.minimize(Loss)
                    print('the project seems healthy!')

                # ---- Session ----
                config = tf.ConfigProto(allow_soft_placement=True)
                # Cap this process at 70% of GPU memory and grow allocations on
                # demand. The cap must be set on config.gpu_options itself; a
                # detached tf.GPUOptions object has no effect.
                config.gpu_options.per_process_gpu_memory_fraction = 0.7
                config.gpu_options.allow_growth = True

                # The context manager closes the session (and frees its
                # resources) before the next trial builds a fresh graph.
                with tf.Session(config=config) as sess:
                    sess.run(tf.global_variables_initializer())
                    train_feed = {X: Q['train_set'], Y_true: Q['train_label']}

                    for itepch in range(Epoches):
                        # Every 100 steps: relative distance of the hidden
                        # layers' weights from their initial values.
                        if itepch % 100 == 0:
                            weights_now, biases_now = sess.run([Weights1, Biases1])
                            if itepch == 0:
                                W_dir_initial_0 = _weight_directions(weights_now, biases_now, 0)
                                W_dir_initial_1 = _weight_directions(weights_now, biases_now, 1)
                            _record_relative_distance(Q, weights_now, biases_now, W_dir_initial_0, W_dir_initial_1)

                        Q['step'].append(itepch)
                        train_loss = sess.run(Loss, feed_dict=train_feed)
                        Q['loss'].append(train_loss)
                        if train_loss < Q['tol']:
                            break
                        # Any increase in loss marks the trial as failed.
                        if itepch > 1 and train_loss > Q['loss'][-2]:
                            success = 0
                            break
                        if itepch % 100 == 0:
                            print('training loss:', train_loss)
                            Y_net_train = sess.run(Y, feed_dict={X: Q['train_set']})
                            Q['Y_net_train'].append(Y_net_train)

                        sess.run(train_op, feed_dict=train_feed)

                    # ---- Statistics at the step where training stopped ----
                    weights_now, biases_now = sess.run([Weights1, Biases1])
                    _record_relative_distance(Q, weights_now, biases_now, W_dir_initial_0, W_dir_initial_1)
                    Q['step'].append(itepch)
                    train_loss = sess.run(Loss, feed_dict=train_feed)
                    Q['loss'].append(train_loss)
                    print('training loss:', train_loss)
                    Y_net_train = sess.run(Y, feed_dict={X: Q['train_set']})
                    Q['Y_net_train'].append(Y_net_train)

                # ---- Loss curve (log-log) ----
                plt.figure()
                ax = plt.gca()
                plt.plot(Q['loss'])
                plt.title('loss', fontsize=15)
                ax.set_yscale('log')
                plt.xlabel(r'epochs', fontsize=18)
                plt.ylabel(r'loss', fontsize=18)
                ax.set_xscale('log')
                plt.savefig(r'%s/loss_%s.png' % (Q['FolderName'], itepch))
                plt.close()

                savefile()
                # Advance only on success; a failed trial is rerun with the same i.
                i = i + success

                tf.keras.backend.clear_session()
                tf.reset_default_graph()