# Import Necessary Packages
import argparse
import os
import pickle
import time

import numpy as np
import tensorflow as tf
import yaml
from tensorflow import keras
from keras.optimizers import SGD, Adam
from keras import datasets
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler
from tensorflow.keras.layers import Activation, Input, Dense, GlobalAveragePooling2D, BatchNormalization, Flatten

def load_config(yaml_file):
    """Load an experiment configuration from a YAML file.

    Args:
        yaml_file: Path to the YAML configuration file.

    Returns:
        The parsed YAML document (whatever top-level object the file holds;
        ``None`` for an empty file).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of the platform default encoding so
    # the same config file parses identically on every operating system.
    with open(yaml_file, 'r', encoding='utf-8') as file:
        # safe_load is shorthand for load(..., Loader=yaml.SafeLoader): it only
        # builds plain Python objects and never executes arbitrary tags.
        return yaml.safe_load(file)

def define_resnet50_CNN(input_shape, num_classes, weights, upsample_factor=7):
    """Build a ResNet50-based classifier for small input images.

    The input is enlarged with ``UpSampling2D``, passed through a ResNet50
    backbone without its classification head, flattened, and classified by
    two fully connected ReLU layers followed by a softmax layer.

    Args:
        input_shape: ``(height, width, channels)`` of the raw input images.
        num_classes: Number of units in the final softmax layer.
        weights: Forwarded unchanged to ``ResNet50(weights=...)``
            (e.g. ``'imagenet'`` or ``None``).
        upsample_factor: Integer factor applied to height and width before
            the backbone. The default of 7 turns 32x32 inputs into 224x224.

    Returns:
        An uncompiled ``keras.Model`` named ``'SNN'``.
    """
    height, width, channels = input_shape
    # Derive the backbone input from the real input instead of hard-coding
    # (224, 224, 3), so the two always agree for any input size.
    backbone_shape = (height * upsample_factor, width * upsample_factor, channels)

    # The backbone is created first, as before, so the order in which layers
    # (and their seeded initial weights) are created stays the same.
    # `classes` is not passed: it only configures the classification head,
    # which include_top=False leaves out.
    base = tf.keras.applications.resnet50.ResNet50(
        include_top=False,
        weights=weights,
        input_shape=backbone_shape,
    )

    inputs = Input(input_shape)
    x = tf.keras.layers.UpSampling2D(size=(upsample_factor, upsample_factor))(inputs)
    x = base(x)

    # Classification head on top of the flattened backbone feature map.
    x = Flatten()(x)
    x = Dense(units=1024, activation='relu')(x)
    x = Dense(units=512, activation='relu')(x)
    outputs = Dense(units=num_classes, activation='softmax')(x)

    return keras.Model(inputs, outputs, name='SNN')


# Report how many GPUs TensorFlow can see.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

# The YAML configuration file is the only (positional) command-line argument.
# Please choose one of the seeds [54833, 7509, 40972, 11840, 46857]
parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
args = parser.parse_args()

# Read the seed for this run from the configuration.
config = load_config(args.config_file)
seed = config['Seed']
print("Applied Seed: ", seed)

# Training hyper-parameters
epochs = 20
batch_size = 128
verbose = True  # verbosity flag handed to training

# Output path prefixes; the seed identifies the run.
full_path = f'train_resultsCIFAR100Res/{seed!s}'
full_path_his = f'train_resultsCIFAR100Res/his/{seed!s}'

# Seed the random number generators and request deterministic TensorFlow ops.
keras.utils.set_random_seed(seed)
tf.config.experimental.enable_op_determinism()

# Prepare training dataset
(x_train, y_train), (x_test, y_test) = datasets.cifar100.load_data()

# Rescale pixel values by 1/255 and one-hot encode the 100 class labels.
x_train = x_train / 255.0
x_test = x_test / 255.0
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
y_train = to_categorical(y_train, 100)
y_test = to_categorical(y_test, 100)

# Standard normalizing: per-channel mean/std, broadcast over height and width.
# NOTE(review): these values look like the commonly quoted CIFAR-10 channel
# statistics rather than CIFAR-100 ones - confirm this is intentional.
channel_mean = np.array([[[0.4914, 0.4822, 0.4465]]])
channel_std = np.array([[[0.2023, 0.1994, 0.2010]]])
# The float64 constants promote the result to float64, which silently undoes
# the float32 cast above and doubles the memory of every image array; cast
# back to float32 once the arithmetic is done.
x_train = ((x_train - channel_mean) / channel_std).astype('float32')
x_test = ((x_test - channel_mean) / channel_std).astype('float32')

# Hold out the last 10000 training images for validation
# (negative value: counted from the end of the array).
val_samples = -10000

x_val = x_train[val_samples:]
y_val = y_train[val_samples:]

x_train = x_train[:val_samples]
y_train = y_train[:val_samples]

# Random shifts and horizontal flips. None of the featurewise options
# (centering, std normalization, ZCA whitening) is enabled, so there are no
# dataset statistics to compute and datagen.fit() is not needed.
datagen = ImageDataGenerator(zca_epsilon=1e-06, width_shift_range=0.1, height_shift_range=0.1, fill_mode='nearest',horizontal_flip=True)

# A single unshuffled batch covering the whole training set: every image is
# augmented exactly once, in its original order.
# NOTE(review): the augmented images are therefore fixed for all epochs rather
# than re-drawn each epoch - confirm this is the intended setup.
augmented_data = datagen.flow(x_train, y_train, batch_size=len(x_train), shuffle=False)

# Get the processed data (x_train) and labels. The built-in next() is used
# because it does not rely on the iterator exposing a .next() method.
x_train, y_train = next(augmented_data)

def lr_scheduler(epoch):
    """Learning Rate Schedule

    Starts from a base rate of 1e-3 and multiplies it by a fixed factor once
    the epoch index exceeds a threshold: 1e-1 after 15, 1e-2 after 120,
    1e-3 after 160 and 0.5e-3 after 180.
    Meant to be used as the schedule function of a learning-rate callback
    that is invoked once per epoch.

    # Arguments
        epoch (int): Index of the current epoch

    # Returns
        lr (float): learning rate for that epoch
    """
    base_lr = 1e-3
    # (exclusive epoch threshold, multiplier), highest threshold first so the
    # first match is the strongest applicable reduction.
    decay_steps = (
        (180, 0.5e-3),
        (160, 1e-3),
        (120, 1e-2),
        (15, 1e-1),
    )
    for threshold, factor in decay_steps:
        if epoch > threshold:
            return base_lr * factor
    return base_lr

# Callback that applies lr_scheduler at every epoch.
lr_scheduler_mod = LearningRateScheduler(lr_scheduler)


# Build and compile ResNet
opt = Adam(learning_rate=0.001)
SNN = define_resnet50_CNN(input_shape=(32, 32, 3), num_classes=100, weights='imagenet')
SNN.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['acc'])

# Create the output directories before training starts: otherwise a missing
# directory is only discovered when saving, after the whole run has finished,
# and the trained model and its history are lost.
for output_prefix in (full_path, full_path_his):
    os.makedirs(os.path.dirname(output_prefix) or '.', exist_ok=True)

# Train and report the elapsed wall-clock time in seconds.
start = time.time()

result = SNN.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=verbose, validation_data=(x_val, y_val), validation_batch_size=batch_size, callbacks=[lr_scheduler_mod])

end = time.time()
print(end-start)

# Evaluate on the held-out test set.
SNN.evaluate(x_test, y_test)

# Save the trained model.
SNN.save(full_path+'_modelRes.keras')

# Save training history
with open(full_path_his + '_result', 'wb') as file:
    pickle.dump(result.history, file)