import tensorflow as tf
import tensorflow.keras
import os
import numpy as np
from tensorflow.keras import datasets, layers, models, applications
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten, Input, Add, ZeroPadding2D, BatchNormalization, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D, GlobalAveragePooling2D
# from tensorflow.keras.utils.np_utils import *
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow.keras.backend as K
from matplotlib.pyplot import imshow
import scipy.misc
from tensorflow.keras.initializers import glorot_uniform,RandomNormal,glorot_normal
# from tensorflow.keras.utils import plot_model
# from tensorflow.keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
import pydot
# from tensorflow.keras.applications.imagenet_utils import preprocess_input
# from tensorflow.keras.utils.data_utils import get_file
# from tensorflow.keras.utils import layer_utils
from tensorflow.keras.preprocessing import image
import pickle
import sys
import platform
import time  
import shutil 
from cifar10_res18like_load_model import get_cos_distance

# Restrict TensorFlow to the first GPU. TensorFlow only honours
# CUDA_VISIBLE_DEVICES if it is set before the GPUs are first enumerated,
# so it must come before list_physical_devices() below (setting it afterwards,
# as this script originally did, has no effect).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Random train/validation split of the 50,000 CIFAR training images: the
# first 10,000 entries of a shuffled index order form the validation set and
# the remaining 40,000 the training set. The split is unseeded; call
# np.random.seed(...) beforehand for a reproducible split.
allset = np.random.permutation(50000)
allset_validation, allset_train = allset[:10000], allset[10000:]

# Run-level results / metadata; written out by savefile().
R = {}

def mkdir(fn):
    """Create directory ``fn`` (and any missing parents) if it does not exist.

    Idempotent: calling it on an existing directory is a no-op. Raises
    FileExistsError if ``fn`` exists but is not a directory.
    """
    # exist_ok avoids the check-then-create race of an isdir() test, and
    # makedirs also creates missing parent directories.
    os.makedirs(fn, exist_ok=True)
# Random integer used as the per-run sub-folder name: |N(1, 1) sample| * 1e5,
# truncated. Index [0] takes the scalar out of the 1-element array (int() of
# an ndim>0 array is deprecated in recent NumPy).
ran = int(np.absolute(np.random.normal([1])[0] * 100000))
sBaseDir0 = 'fitnd_cifar_100'
if platform.system() == 'Windows':
    # NOTE(review): 'XXX' is a placeholder root; replace it with a real path.
    BaseDir0 = r'XXX/%s' % (sBaseDir0)
else:
    BaseDir0 = sBaseDir0
    # Render figures off-screen (non-interactive Agg backend) on non-Windows
    # hosts. Switch through pyplot, which is the name imported in this file.
    plt.switch_backend('Agg')
subFolderName = '%s' % (ran)
FolderName = '%s/%s/' % (BaseDir0, subFolderName)
mkdir(BaseDir0)
mkdir(FolderName)
R['FolderName'] = FolderName

# Keep a copy of this script next to its outputs for reproducibility.
shutil.copy(__file__, '%s%s' % (FolderName, os.path.basename(__file__)))

def savefile():
    """Persist the global results dict ``R`` into the run folder ``FolderName``.

    Writes two files:
      * ``objs.pkl`` -- all of ``R``, pickled with protocol 4;
      * ``Output.txt`` -- a readable ``key: value`` line for every entry of
        ``R`` whose ``np.size`` is at most 20 (long arrays/histories are
        skipped), followed by the command-line arguments from ``sys.argv``.
    """
    with open('%s/objs.pkl' % (FolderName), 'wb') as f:
        pickle.dump(R, f, protocol=4)

    # Text summary for quick inspection; `with` guarantees the file is closed
    # even if formatting one of the values raises.
    with open("%s/Output.txt" % (FolderName), "w", encoding="utf-8") as text_file:
        for para in R:
            if np.size(R[para]) > 20:
                continue
            text_file.write('%s: %s\n' % (para, R[para]))

        for para in sys.argv:
            text_file.write('%s  ' % (para))

def tanhx(x):
    """Element-wise self-gated activation ``x * tanh(x)`` built from Keras backend ops."""
    gate = K.tanh(x)
    return x * gate

# Load CIFAR-100 with the fine-grained labels (label_mode='coarse' would
# select the coarse labels instead).
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data(label_mode='fine')
print(np.shape(train_images))
print(np.shape(train_labels))


# Scale pixel values from [0, 255] to [-0.5, 0.5] and one-hot encode the
# 100 class labels.
train_images0, test_images0 = train_images / 255.0 - 0.5, test_images / 255.0 - 0.5
train_labels0 = tf.keras.utils.to_categorical(train_labels, num_classes=100)
test_labels0 = tf.keras.utils.to_categorical(test_labels, num_classes=100)

# Split the official training set into train / validation subsets using the
# random index arrays allset_train / allset_validation.
train_images = train_images0[allset_train, :, :, :]
train_labels = train_labels0[allset_train, :]

validation_images = train_images0[allset_validation, :, :, :]
validation_labels = train_labels0[allset_validation, :]

test_images = test_images0[0:10000, :, :, :]
test_labels = test_labels0[0:10000, :]

print(train_images.shape)
print(train_labels.shape)
print(validation_images.shape)
print(validation_labels.shape)
print(test_images.shape)
print(test_labels.shape)
# Sanity check: one-hot labels of the last 10 training samples.
print(train_labels[-10:])

# ResNet merges the shortcut and the main path by element-wise addition
# (layers.add); GoogLeNet / DenseNet instead concatenate feature maps.
# On add vs. concatenate see: https://blog.csdn.net/u012193416/article/details/79479935
def Resnet18_block(X, filters, s ,stage, block):
    """Two-convolution residual block used by ResNet18().

    Main path: Conv3x3 (stride s) -> BN -> ReLU -> Conv3x3 (stride 1) -> BN -> ReLU.
    Shortcut: the input itself or, when the main path changes the spatial size
    (s != 1) or the channel count, a 1x1 Conv (stride s) -> BN -> ReLU
    projection to F2 channels. The two branches are summed and passed through
    a final ReLU.
    NOTE(review): unlike the reference ResNet design, both branches go through
    a ReLU before the addition.

    Arguments:
    X -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
    filters -- [F1, F2]: number of filters of the first and second 3x3 conv
    s -- integer stride of the first 3x3 conv and of the shortcut projection
    stage -- integer, used to name the layers, depending on their position in the network
    block -- string/character, used to name the layers, depending on their position in the network
    Returns:
    X -- output tensor of shape (m, n_H, n_W, F2)
    """
    # defining name basis
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    # Retrieve Filters
    F1, F2 = filters
    # Save the input value; it is added back to the main path at the end.
    X_shortcut = X
    # The identity shortcut only fits when it already matches the main path's
    # output shape. Project whenever the spatial size changes (s != 1) or the
    # input width differs from F2. The projection must produce F2 channels,
    # the width of the main path's last conv, or the addition fails.
    if s != 1 or X.shape[-1] != F2:
        X_shortcut = Conv2D(filters = F2, kernel_size = (1, 1), strides = (s,s), padding = 'same', name = conv_name_base + '2c', kernel_initializer = glorot_uniform(seed=3))(X_shortcut)
        X_shortcut = BatchNormalization(axis = 3, name = bn_name_base + '2c')(X_shortcut)
        X_shortcut = Activation('relu')(X_shortcut)
    # First component of main path (this conv carries the block's stride)
    X = Conv2D(filters = F1, kernel_size = (3, 3), strides = (s,s), padding = 'same', name = conv_name_base + '2a', kernel_initializer = glorot_uniform(seed=4))(X)
    X = BatchNormalization(axis = 3, name = bn_name_base + '2a')(X)
    X = Activation('relu')(X)
    # Second component of main path
    X = Conv2D(filters = F2, kernel_size = (3, 3), strides = (1,1), padding = 'same', name = conv_name_base + '2b', kernel_initializer = glorot_uniform(seed=5))(X)
    X = BatchNormalization(axis=3, name = bn_name_base + '2b')(X)
    X = Activation('relu')(X)
    # Final step: add the shortcut to the main path, then apply ReLU
    X = layers.add([X, X_shortcut])
    X = Activation('relu')(X)
    return X

def CNN_block(X, filters, s, stage, block, seedd = 10):
    """Plain (non-residual) block: one 3x3 Conv (stride s) -> BN -> ReLU.

    Arguments:
    X -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
    filters -- two-element sequence [F1, F2]; only F1 (the conv's filter
               count) is used, but exactly two values are required
    s -- integer stride of the convolution
    stage -- integer, used to build the layer names
    block -- string/character, used to build the layer names
    seedd -- seed of the conv's glorot_uniform kernel initializer
    Returns:
    X -- output tensor with F1 channels
    """
    suffix = str(stage) + block + '_branch'
    # Unpacking still enforces the two-element `filters` contract.
    n_filters, _ = filters

    conv = Conv2D(filters=n_filters, kernel_size=(3, 3), strides=(s, s), padding='same',
                  name='res' + suffix + '2a', kernel_initializer=glorot_uniform(seed=seedd))
    norm = BatchNormalization(axis=3, name='bn' + suffix + '2a')
    return Activation('relu')(norm(conv(X)))

def ResNet18(input_shape = (32,32,3),classes=100):
    """Build the (uncompiled) ResNet-18-like classifier.

    Layout:
      * stem: 3x3 Conv (64 filters, 'conv1') -> BN -> ReLU -> Dropout(0.2);
      * four residual stages (stage ids 2-5; 64/128/256/512 filters) of two
        Resnet18_block each. The first block of stages 3-5 uses stride 2.
        Each stage is followed by Dropout;
      * GlobalAveragePooling2D -> Dropout(0.2) -> Flatten;
      * two 1024-unit Dense layers ('dense1', 'dense2') with ReLU, each
        followed by Dropout(0.2), kernels drawn from N(0, (1/1024**1.5)**2);
      * softmax output layer named 'fc<classes>'.

    Arguments:
    input_shape -- shape of one input image (H, W, C)
    classes -- number of output classes
    Returns:
    model -- a Keras Model named 'ResNet18'
    """
    inputs = Input(input_shape)

    # Stem.
    h = Conv2D(64, (3, 3), strides=(1, 1), name='conv1', padding='same',
               kernel_initializer=glorot_uniform(seed=2))(inputs)
    h = BatchNormalization(axis=3, name='bn_conv1')(h)
    h = Activation('relu')(h)
    h = Dropout(0.2)(h)

    # Residual stages: (stage id, width, stride of first block, dropout rate).
    stage_specs = (
        (2, 64, 1, 0.1),
        (3, 128, 2, 0.2),
        (4, 256, 2, 0.2),
        (5, 512, 2, 0.2),
    )
    for stage_id, width, first_stride, drop_rate in stage_specs:
        h = Resnet18_block(h, filters=[width, width], s=first_stride, stage=stage_id, block='a')
        h = Resnet18_block(h, filters=[width, width], s=1, stage=stage_id, block='b')
        h = Dropout(drop_rate)(h)

    h = GlobalAveragePooling2D()(h)
    h = Dropout(0.2)(h)
    # The pooled tensor is already 2-D; Flatten is kept so the model's layer
    # list stays unchanged.
    h = Flatten()(h)

    # Fully connected head.
    for dense_name in ('dense1', 'dense2'):
        h = Dense(1024, activation=None,
                  kernel_initializer=RandomNormal(mean=0.0, stddev=1/1024**1.5, seed=None),
                  name=dense_name)(h)
        h = Activation('relu')(h)
        h = Dropout(0.2)(h)

    outputs = Dense(classes, activation='softmax', name='fc' + str(classes),
                    kernel_initializer=glorot_uniform(seed=0))(h)

    return Model(inputs=inputs, outputs=outputs, name='ResNet18')



model = ResNet18(input_shape = (32,32,3),classes=100)

model.summary()

### Optimizers
# Only `adam1` is passed to compile() below; `sgd` and `adam2` are
# constructed but not used in this file.
# `decay` is deliberately not passed to SGD. It defaulted to 0 in the legacy
# Keras optimizers, and the optimizers shipped since TF 2.11 reject the
# argument with a ValueError.
sgd = tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9, nesterov=True, name='SGD')
adam1 = tf.keras.optimizers.Adam(
    learning_rate=0.0000001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False,
    name='Adam'
)

adam2 = tf.keras.optimizers.Adam(
    learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, amsgrad=False,
    name='Adam'
)

model.compile(optimizer=adam1,
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

# NOTE(review): `early_stopping` is never passed to model.fit() in this file,
# and the training loop calls fit(epochs=1) repeatedly, so it has no effect as
# written.
early_stopping = EarlyStopping(monitor='val_loss', patience=40, verbose=2)

history_val_acc = []
history_train_acc = []
history_val_loss = []
history_train_loss = []

n_epochs = 100


def _plot_history(curves, ylabel, path):
    """Plot per-epoch curves on one figure and save it to ``path``.

    curves -- iterable of (values, color, label) tuples, one line per tuple
    ylabel -- text of the y-axis label
    path -- output image file path
    """
    plt.figure()
    plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
    plt.rcParams['figure.dpi'] = 200   # figure resolution
    for values, color, label in curves:
        plt.plot(values, linewidth=2.0, color=color, label=label, linestyle='-')
    plt.xlabel(r'epochs', fontsize=22)
    plt.ylabel(ylabel, rotation=0, fontsize=22)
    plt.tick_params(axis='both', which='major', labelsize=16)
    plt.legend()
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


# Train one epoch per fit() call so that get_cos_distance can be run on the
# model before every epoch. `epochs` is the 0-based index of the epoch about
# to be trained.
for epochs in range(n_epochs):
    get_cos_distance(model = model, target_size = 1024 , pathh = FolderName ,epochs = epochs )

    historyy = model.fit(train_images, train_labels, epochs=1,
                         batch_size=128,
                         validation_data=(validation_images, validation_labels),
                         shuffle=True, verbose=1)

    history_val_acc += historyy.history['val_accuracy']
    history_train_acc += historyy.history['accuracy']
    history_val_loss += historyy.history['val_loss']
    history_train_loss += historyy.history['loss']

    # Refresh the plots every other epoch, and always after the last epoch so
    # the final figures cover the whole run.
    if epochs % 2 == 0 or epochs == n_epochs - 1:
        _plot_history(
            [(history_val_acc, 'blue', 'val acc'),
             (history_train_acc, 'yellow', 'train acc')],
            r'acc',
            r'%s/acc_step_%d.png' % (R['FolderName'], epochs),
        )
        _plot_history(
            [(history_val_loss, 'orange', 'val loss'),
             (history_train_loss, 'grey', 'train loss')],
            r'loss',
            r'%s/loss_step_%d.png' % (R['FolderName'], epochs),
        )

# Attach the per-epoch curves to the run results and write them to disk.
R.update(
    history_val_acc=history_val_acc,
    history_train_acc=history_train_acc,
    history_val_loss=history_val_loss,
    history_train_loss=history_train_loss,
)

savefile()

# Save the trained model as an HDF5 file named model.h5 in the run folder.
model_path = r'%s/model.h5' % R['FolderName']
model.save(model_path)

del model  # release the in-memory model

# To reload it later (returns a compiled model identical to the saved one):
#   model = load_model(r'%s/model.h5' % R['FolderName'])
# To inspect an intermediate layer's activations for one training image:
#   layer_model = Model(inputs=model.input, outputs=model.layers[5].output)
#   feature = layer_model.predict(train_images[0:1, :, :, :])
