# Import Necessary Packages
import os
import pickle
import time

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.optimizers import SGD, Adam
from keras import datasets
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler
from tensorflow.keras.layers import Activation, Input, Dense, GlobalAveragePooling2D, BatchNormalization, Flatten

def load_config(yaml_file):
    """Load a YAML configuration file with the safe loader.

    Args:
        yaml_file: Path of the YAML file to read.

    Returns:
        The parsed YAML document (whatever Python object the safe loader
        builds for the file's top-level node).

    Raises:
        ImportError: If PyYAML is not installed.
        OSError: If the file cannot be opened.
    """
    # The module never imports `yaml` at the top level, so without this
    # import every call would fail with NameError. Importing here (rather
    # than at module level) keeps the rest of the script runnable on
    # machines where PyYAML is not installed.
    import yaml

    with open(yaml_file, 'r') as file:
        # safe_load is shorthand for yaml.load(..., Loader=yaml.SafeLoader).
        return yaml.safe_load(file)

def define_resnet50_CNN(input_shape, num_classes, weights):
    """Build a ResNet50-based image classifier.

    The input image is up-sampled 7x along height and width, passed through
    a ResNet50 backbone created without its classification head, flattened,
    and classified by two hidden dense layers followed by a softmax layer.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of units of the softmax output layer.
        weights: Forwarded unchanged to the ``weights`` argument of
            ``tf.keras.applications.resnet50.ResNet50``.

    Returns:
        An uncompiled ``keras.Model`` named 'SNN'.
    """
    upsample = 7
    height, width, channels = input_shape

    # The backbone must accept exactly what the up-sampling layer produces.
    # For the (32, 32, 3) inputs used in this file this is (224, 224, 3),
    # i.e. the value that used to be hard-coded here.
    # `classes` is intentionally not passed: it only matters when
    # include_top=True, and the head is built explicitly below.
    base = tf.keras.applications.resnet50.ResNet50(
        include_top=False,
        weights=weights,
        input_shape=(height * upsample, width * upsample, channels),
    )
    inputs = Input(input_shape)

    x = tf.keras.layers.UpSampling2D(size=(upsample, upsample))(inputs)

    x = base(x)
    # x = GlobalAveragePooling2D()(x)
    x = Flatten()(x)
    x = Dense(units=1024, activation='relu')(x)
    x = Dense(units=512, activation='relu')(x)

    outputs = Dense(units=num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs, name='SNN')
    return model

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Seed identifying this run; it is used to name the output files below.
# Please choose one of the seeds [54833, 7509, 40972, 11840, 46857]
seed = 54833

# # Number of experiments to run
# exp_num = 5

# verbose for training
verbose = True
batch_size = 128

epochs = 20

# Define the save path
full_path = 'train100ResNet/'+str(seed)
full_path_his = 'train100ResNet/his/'+str(seed)

# Set random seed
# NOTE: seeding is currently disabled, so `seed` only labels the output files.
# keras.utils.set_random_seed(seed)
# tf.config.experimental.enable_op_determinism()

# Prepare training dataset
(x_train, y_train), (x_test, y_test) = datasets.cifar100.load_data()

# Convert to float32 before scaling so no float64 copy of the data is made.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train = to_categorical(y_train, 100)
y_test = to_categorical(y_test, 100)

# standard normalizing (per-channel mean / std, broadcast over H and W).
# The constants are float32 on purpose: float64 constants would silently
# upcast the whole dataset to float64 and double its memory footprint.
# NOTE(review): these values look like the commonly quoted CIFAR-10
# statistics rather than CIFAR-100 ones - confirm this is intended.
channel_mean = np.array([[[0.4914, 0.4822, 0.4465]]], dtype=np.float32)
channel_std = np.array([[[0.2023, 0.1994, 0.2010]]], dtype=np.float32)
x_train = (x_train - channel_mean) / channel_std
x_test = (x_test - channel_mean) / channel_std

# Negative on purpose: used as a slice bound so that the LAST 10000 training
# samples become the validation split.
val_samples = -10000

x_val = x_train[val_samples:]
y_val = y_train[val_samples:]


x_train = x_train[:val_samples]
y_train = y_train[:val_samples]

# Augmentation: random shifts of up to 10% and horizontal flips.
datagen = ImageDataGenerator(zca_epsilon=1e-06, width_shift_range=0.1, height_shift_range=0.1, fill_mode='nearest',horizontal_flip=True)
# NOTE(review): per the Keras docs fit() is only required when feature-wise
# normalisation or ZCA whitening is enabled, which is not the case here -
# the call is kept for parity with the original script.
datagen.fit(x_train)

# augmented_data = datagen.flow(x_train, y_train, batch_size=len(x_train), shuffle=False)

# Get the processed data (x_train) and labels
# x_train, y_train = augmented_data.next()

def lr_scheduler(epoch):
    """Return the learning rate to use for the given epoch.

    The base rate is 1e-3. It is scaled by 1e-1 once ``epoch`` exceeds 15,
    by 1e-2 above 120, by 1e-3 above 160 and by 0.5e-3 above 180. Intended
    to be wrapped in a ``LearningRateScheduler`` callback.

    # Arguments
        epoch (int): Epoch index supplied by the callback.

    # Returns
        lr (float): learning rate for that epoch
    """
    base_lr = 1e-3
    # (epoch threshold, scale factor), ordered from the latest stage to the
    # earliest so that the first match is the stage the epoch belongs to.
    decay_stages = (
        (180, 0.5e-3),
        (160, 1e-3),
        (120, 1e-2),
        (15, 1e-1),
    )
    for threshold, factor in decay_stages:
        if epoch > threshold:
            return base_lr * factor
    return base_lr

# Keras callback that applies lr_scheduler at the start of every epoch.
lr_scheduler_mod = LearningRateScheduler(lr_scheduler)


# Build and compile ResNet
# opt = SGD(learning_rate=0.1, momentum=0.9)
opt = Adam(learning_rate=0.001)
# The builder defined in this file is define_resnet50_CNN; the previously
# called define_CNN_model does not exist and raised NameError.
# NOTE(review): weights='imagenet' (pre-trained backbone) is assumed to be
# the intent - pass weights=None to train the backbone from scratch.
SNN = define_resnet50_CNN(input_shape=(32, 32, 3), num_classes=100, weights='imagenet')
SNN.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['acc'])

# Create the output directories before training so that a missing folder
# cannot make the run fail only after all epochs have completed.
# full_path_his lives in a sub-directory of full_path's directory, so
# creating it creates both.
os.makedirs(os.path.dirname(full_path_his), exist_ok=True)

start = time.time()

result = SNN.fit(datagen.flow(x_train, y_train, batch_size=batch_size), epochs=epochs, verbose=verbose, validation_data=(x_val, y_val), validation_batch_size=batch_size, callbacks=[lr_scheduler_mod])

end = time.time()
# Wall-clock training time in seconds.
print(end-start)

SNN.evaluate(x_test, y_test)

# Save the trained weights (list returned by get_weights()) with pickle.
weights_to_save = SNN.get_weights()
with open(full_path + '_weights', 'wb') as w:
    pickle.dump(weights_to_save, w)

# Save training history
with open(full_path_his + '_result', 'wb') as file:
    pickle.dump(result.history, file)