import os,sys
import matplotlib
matplotlib.use('Agg')   
import pickle
import time  
import shutil 
import tensorflow.compat.v1 as tf
import numpy as np
import matplotlib.pyplot as plt   
#from BasicFunc import mySaveFig, mkdir
import platform
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from datetime import datetime
from mayavi import mlab
from sympy import *
from mayavi.mlab import *
from matplotlib.lines import Line2D
from sklearn.cluster import KMeans


# Restrict CUDA to device 0; set before any TensorFlow session is created below.
os.environ["CUDA_VISIBLE_DEVICES"]='0'
# The script builds its model with placeholders and runs it through tf.Session,
# so eager execution is switched off.
tf.disable_eager_execution()

############# Create directories ##############
def mkdir(fn):
    """Create directory ``fn`` if it does not exist yet.

    Missing parent directories are created as well, and an already existing
    directory is silently accepted (no check-then-create race).

    Args:
        fn: Path of the directory to create.

    Raises:
        FileExistsError: If ``fn`` exists but is not a directory.
    """
    os.makedirs(fn, exist_ok=True)
# Random run id used as the per-run sub-folder name: |x| * 1e5 truncated to an
# int, with x drawn from a normal distribution with mean 1 and std 1.
# A scalar ``loc`` makes np.random.normal return a scalar; calling int() on a
# 1-element array (the previous form) is deprecated since NumPy 1.25.
ran = int(np.absolute(np.random.normal(1.0) * 100000))
sBaseDir0 = 'fitnd_mnist'
if platform.system() == 'Windows':
    # NOTE(review): 'XXX' looks like a placeholder for a local path - confirm.
    BaseDir0 = r'XXX/%s' % (sBaseDir0)
else:
    BaseDir0 = sBaseDir0
    # Non-interactive backend (same backend as selected at import time).
    matplotlib.use('Agg')
subFolderName = '%s' % (ran)
FolderName = '%s/%s/' % (BaseDir0, subFolderName)
mkdir(BaseDir0)
mkdir(FolderName)

# Keep a copy of this script next to the results it produced.
shutil.copy(__file__, '%s%s' % (FolderName, os.path.basename(__file__)))

# Load MNIST through Keras (cached under the name 'mnist.npz').
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='mnist.npz')
# One-hot encode the labels of both splits into 10 float32 columns.
y_train, y_test = (
    tf.keras.utils.to_categorical(labels, num_classes=10, dtype='float32')
    for labels in (y_train, y_test)
)
print(np.shape(y_train))

############ Hyper-parameters and run state #############
# Number of training iterations (each one feeds the whole training set).
Epoches = 540

# Q collects the data, the hyper-parameters and everything logged during the
# run. The key order is kept because the text dump iterates over Q.
Q = {
    'FolderName': FolderName,
    # Pixels rescaled from [0, 255] to [-0.5, 0.5] and flattened to 784 columns.
    'train_set': np.reshape(x_train/255-0.5, [60000, 28*28]),
    'train_label': np.reshape(y_train, [60000, 10]),
    'test_set': np.reshape(x_test/255-0.5, [10000, 28*28]),
    'test_label': np.reshape(y_test, [10000, 10]),
    'input_size': 28*28,
    'output_size': 10,
    'hidden_layer': [30],           # widths of the hidden layers
    'in_learning_rate': 5e-5,       # Adam learning rate
    'loss': [],                     # training loss per iteration
    'tol': 1e-4,                    # stop once the training loss drops below this
    'cos_distance_matrix': [],      # snapshots of the first-layer cosine matrix
    'step': [],                     # iteration index of each snapshot
    'W': [],                        # snapshots of the first-layer weights+bias
}
###################
#   Each layer computes <X,W>+b; in this form the columns of X index the neurons and the rows index the samples.
#   Build the parameters w and b.
def initializer_generate(inp_size=10 , hidden_layer = [20] , out_size=10 ):
    """Create the random initial values for all weights and biases.

    Every tensor is drawn with ``tf.random_normal`` (mean 0, stddev 0.001,
    float32). Tensors are created in the order W0, B0, W1, B1, ..., W_out.

    Args:
        inp_size: Number of input features.
        hidden_layer: Widths of the hidden layers (must not be empty).
        out_size: Number of outputs.

    Returns:
        ``(Weights_ini, Biases_ini)``: a list with ``len(hidden_layer) + 1``
        weight tensors and a list with ``len(hidden_layer)`` bias tensors of
        shape ``[1, width]``; the output layer has no bias.
    """
    def _small_normal(n_rows, n_cols):
        return tf.random_normal(shape=[n_rows, n_cols], dtype='float32', mean=0.0, stddev=0.001)

    # Fan-in of each hidden layer: the input size, then the preceding widths.
    fan_ins = [inp_size] + list(hidden_layer[:-1])

    Weights_ini, Biases_ini = [], []
    for fan_in, width in zip(fan_ins, hidden_layer):
        Weights_ini.append(_small_normal(fan_in, width))
        Biases_ini.append(_small_normal(1, width))

    # Output projection (weights only).
    Weights_ini.append(_small_normal(hidden_layer[-1], out_size))
    return Weights_ini, Biases_ini

def Init_DNN( inp_size=10 , hidden_layer = [20] , out_size=10 ,Weights_ini=0,Biases_ini=0): 
    """Wrap the given initial values in trainable ``tf.Variable`` objects.

    Variables are created in the order W0, B0, W1, B1, ..., W_out.

    Args:
        inp_size: Not used by this function.
        hidden_layer: Hidden layer widths; only its length is used.
        out_size: Not used by this function.
        Weights_ini: Sequence of initial weight values; the last entry is the
            output layer. The default ``0`` is not usable (it is indexed).
        Biases_ini: Sequence of initial bias values, one per hidden layer.
            The default ``0`` is not usable (it is indexed).

    Returns:
        ``(Weights, Biases)``: lists of variables; the output layer has a
        weight variable but no bias.
    """
    # The first hidden layer is always present.
    Weights = [tf.Variable(Weights_ini[0])]
    Biases = [tf.Variable(Biases_ini[0])]

    # Remaining hidden layers, weight first and then bias for each layer.
    for layer_idx in range(1, len(hidden_layer)):
        Weights.append(tf.Variable(Weights_ini[layer_idx]))
        Biases.append(tf.Variable(Biases_ini[layer_idx]))

    # Output projection.
    Weights.append(tf.Variable(Weights_ini[-1]))
    return Weights, Biases

#   Build the network. For fully connected layers the architecture is implied
#   by the shapes stored in Weights and Biases.
def multilayer(X, Weights, Biases, activation = tf.nn.tanh): 
    """Build the forward pass of the fully connected network.

    Each hidden layer computes ``Z = H @ W + B`` and outputs
    ``activation(Z) * Z**2``; the last weight matrix is applied as a plain
    linear map without bias.

    Args:
        X: Input tensor, one sample per row.
        Weights: Weight tensors; all but the last belong to hidden layers.
        Biases: Bias tensors, one per hidden layer.
        activation: Element-wise function applied to the pre-activation.

    Returns:
        The network output tensor ``H @ Weights[-1]``.
    """
    H = X
    for k in range(len(Weights) - 1):
        # Build the pre-activation once and reuse it in both factors instead
        # of adding the matmul/add ops to the graph twice.
        Z = tf.add(tf.matmul(H, Weights[k]), Biases[k])
        H = activation(Z) * Z**2
    return tf.matmul(H, Weights[-1])

############### Build the network ###################
# All ops below are created inside the 'Graph' scope.
with tf.variable_scope('Graph',reuse=tf.AUTO_REUSE) as scope:
    # Inputs: flattened 28*28 images and 10-column targets; batch size is left open.
    X = tf.placeholder(tf.float32,shape=[None,28*28],name = 'X')
    Y_true = tf.placeholder(tf.float32,shape=[None,10],name = 'Y_ture')
    # Random initial values, then the trainable variables created from them.
    Weights0 ,Biases0 = initializer_generate(inp_size=Q['input_size'],hidden_layer=Q['hidden_layer'],out_size=Q['output_size'])
    Weights1 ,Biases1 = Init_DNN(inp_size=Q['input_size'],hidden_layer=Q['hidden_layer'],out_size=Q['output_size'],Weights_ini=Weights0,Biases_ini=Biases0)
    # Network output and mean squared error against the targets.
    Y = multilayer(X,Weights1,Biases1)
    Loss=tf.reduce_mean((Y-Y_true)**2)
    # Adam step on the loss, using the learning rate stored in Q.
    adam = tf.train.AdamOptimizer(learning_rate=Q['in_learning_rate'])
    train_op = adam.minimize(Loss)
    print('the project seems healthy!')

##################

config = tf.ConfigProto(allow_soft_placement=True) # the following configures device placement
# NOTE(review): gpu_options is created here but is not passed to `config` in
# the visible code, so the 0.7 memory fraction appears to have no effect.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
# Initialise all variables. Original author's note (translated): it does not
# matter that parameters are initialised again later; this only initialises
# and does not run the network - actually running it here caused problems when
# tried (y_true is one-dimensional, so it can broadcast to y's shape).
# NOTE(review): that note looks stale - Y_true is declared as [None, 10] and no
# second initialisation is visible in this file.
sess.run(tf.global_variables_initializer())
# NOTE(review): saver is not used in the visible code.
saver = tf.train.Saver() 

##### Training loop: every iteration feeds the whole training set
for itepch in range(Epoches):
    if itepch%5==0:
        # Snapshot of the first layer, taken before this iteration's update.
        # NOTE(review): Y_net_train is fetched but not used in the visible code.
        Y_net_train ,Weights_Net_1 ,Biases1_Net_1 = sess.run([Y,Weights1,Biases1], feed_dict={X:Q['train_set'], Y_true :Q['train_label']})
        # One row per first-layer neuron: its incoming weights followed by its bias.
        W_dir = np.transpose(np.concatenate((Weights_Net_1[0],Biases1_Net_1[0]),axis=0))
        Q['W'].append(W_dir)
        print(np.shape(W_dir))
        # Scale every row to unit Euclidean length.
        W_dir = W_dir/np.sqrt(np.sum(W_dir**2, axis=1, keepdims=True))
        print(np.max(W_dir))
        # Row 0 always exists, unlike a fixed row 2 when the layer has < 3 neurons.
        print(np.shape(W_dir[0]))
        # Pairwise cosine between neuron directions (despite the "distance"
        # name), computed for all pairs at once; stored as float64 as before.
        row_norms = np.sqrt(np.sum(W_dir**2, axis=1))
        cos_distance = (np.matmul(W_dir, W_dir.T) / np.outer(row_norms, row_norms)).astype(np.float64)
        Q['cos_distance_matrix'].append(cos_distance)
        Q['step'].append(itepch)

    # Loss on the full training set before the parameter update.
    train_loss = sess.run(Loss, feed_dict={X:Q['train_set'] , Y_true :Q['train_label']})
    Q['loss'].append(train_loss)
    if train_loss < Q['tol']:
        break
    if itepch%5==0:
        print('training loss:',train_loss)
        # NOTE(review): Y_test_network is computed but not used in the visible code.
        Y_test_network = sess.run(Y, feed_dict={X:Q['test_set']})
        ############## plot the loss curve ############
        plt.figure()
        ax = plt.gca()
        plt.plot(Q['loss'])
        plt.title('loss',fontsize=15)
        ax.set_yscale('log')
        plt.xlabel(r'epochs',fontsize=18)
        plt.ylabel(r'loss',fontsize=18)
        plt.savefig(r'%s/loss_%s.png'%(Q['FolderName'],itepch))
        plt.close()

    # One Adam step on the full training set. (The former bare `plt.clf`
    # expression was never called and has been removed; figures are closed
    # right after saving.)
    _= sess.run(train_op, feed_dict={X:Q['train_set'] , Y_true :Q['train_label']})


def savefile(): # save the run's parameters and logged results
    """Write the global ``Q`` dict to ``FolderName``.

    Two files are produced:

    * ``object.pkl`` - ``Q`` pickled with protocol 4.
    * ``object.txt`` - one ``key: value`` line for every entry of ``Q`` whose
      ``np.size`` is at most 20, followed by the command-line arguments, each
      followed by two spaces.

    Both files are opened with context managers so they are closed even if
    writing fails part-way.
    """
    with open('%s/object.pkl'%(FolderName), 'wb') as f:
        pickle.dump(Q, f, protocol=4)

    with open("%s/object.txt"%(FolderName), "w") as text_file:
        for para in Q:
            # Skip large entries (arrays, long histories); they are in the pickle.
            if np.size(Q[para])>20:
                continue
            text_file.write('%s: %s\n'%(para,Q[para]))

        for para in sys.argv:
            text_file.write('%s  '%(para))

# Replace the datasets with 0 so they are not written into the pickle, then
# persist everything that was logged during the run.
Q.update(train_set=0, train_label=0, test_set=0, test_label=0)
savefile()