# Import Necessary Packages
import argparse
import os
import pickle
import time

import numpy as np
import tensorflow as tf
import yaml
from keras import datasets
from keras.callbacks import LearningRateScheduler
from keras.optimizers import SGD, Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
from tensorflow import keras
from tensorflow.keras.layers import Activation, Input, Dense, GlobalAveragePooling2D, BatchNormalization, Flatten
from vit_keras import vit

from model.resnet import build_resnet

def load_config(yaml_file):
    """Read a YAML file and return its parsed content.

    # Arguments
        yaml_file (str): path to the YAML configuration file

    # Returns
        The object produced by the safe YAML loader for the file content.
    """
    with open(yaml_file, 'r') as stream:
        # safe_load is shorthand for load(..., Loader=yaml.SafeLoader).
        return yaml.safe_load(stream)
    

def define_CNN_model(input_shape, num_classes, weights):
    """Build a ResNet50-backed classifier with a small dense head.

    The input image is upsampled to the backbone's fixed 224x224 resolution,
    passed through ResNet50 (built without its classification top), global
    average pooled, and classified by two ReLU dense layers followed by a
    softmax layer.

    # Arguments
        input_shape (tuple): (height, width, channels) of a single input.
            Each known spatial size must divide 224 evenly so that an
            integer upsampling factor exists (32 -> factor 7).
        num_classes (int): number of units in the softmax output layer.
        weights: forwarded unchanged to `ResNet50(weights=...)`.

    # Returns
        keras.Model named 'SNN' mapping `input_shape` images to
        `num_classes` softmax outputs.

    # Raises
        ValueError: if a known spatial dimension is not a positive divisor
            of 224.
    """
    backbone_size = 224

    def _upsample_factor(dim):
        # Integer factor that scales `dim` up to the backbone resolution.
        if dim is None:
            # Unknown spatial size: keep the previously hard-coded factor.
            return 7
        if dim <= 0 or backbone_size % dim != 0:
            raise ValueError(
                'input_shape spatial size %r must be a positive divisor of %d'
                % (dim, backbone_size)
            )
        return backbone_size // dim

    factors = (_upsample_factor(input_shape[0]), _upsample_factor(input_shape[1]))

    # NOTE(review): Keras documents `classes` as only relevant when
    # include_top=True; it is kept here to leave the backbone call unchanged.
    base = tf.keras.applications.resnet50.ResNet50(
        include_top=False,
        weights=weights,
        input_shape=(backbone_size, backbone_size, 3),
        classes=num_classes,
    )

    inputs = Input(input_shape)

    # Upsample so the backbone receives the resolution it was built for.
    x = tf.keras.layers.UpSampling2D(size=factors)(inputs)

    x = base(x)
    x = GlobalAveragePooling2D()(x)
    # Kept from the original layer stack so the model structure is unchanged.
    x = Flatten()(x)
    x = Dense(units=1024, activation='relu')(x)
    x = Dense(units=512, activation='relu')(x)

    outputs = Dense(units=num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs, name='SNN')
    return model

# def define_CNN_model(input_shape, num_classes):
    
#     base = vit.vit_b16(
#         image_size = (224, 224),
#         activation = 'softmax',
#         pretrained = True,
#         include_top = False,
#         pretrained_top = False,
#         classes = num_classes)
    
#     inputs = Input(input_shape)
    
#     x = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
    
#     x = base(x)
#     # x = GlobalAveragePooling2D()(x)     
#     x = Flatten()(x)
#     x = Dense(units=1024, activation='relu')(x)
#     x = Dense(units=512, activation='relu')(x) 
    
#     outputs = Dense(units=num_classes, activation='softmax')(x)
    
#     model = keras.Model(inputs, outputs, name='SNNViT')
#     return model


# ---- Command line and configuration ----
parser = argparse.ArgumentParser(
    description='Process parameters from a YAML file.'
)
parser.add_argument(
    'config_file',
    type=str,
    help='Path to the YAML configuration file',
)
args = parser.parse_args()
config = load_config(args.config_file)

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# The seed is read from the YAML file; it seeds the RNGs and is also used in
# the output file names below.
seed = config['Seed']
print("Applied Seed: ", seed)

# Number of experiments to run
# NOTE(review): not referenced anywhere else in this script.
exp_num = 6666

# Training settings
# NOTE(review): `verbose` is defined here, but the fit call further down
# passes verbose=False directly.
verbose = True
batch_size = 128
epochs = 15

# Output file prefixes (one per seed)
full_path = f'train_resultsSNN10/{seed}'
full_path_his = f'train_resultsSNN10/his/{seed}'

# Seed the RNGs and request deterministic ops before any model is built
keras.utils.set_random_seed(seed)
tf.config.experimental.enable_op_determinism()

# ---- Dataset ----
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# Divide by 255 and store as float32 (assumes 8-bit pixel values)
x_train = (x_train / 255.0).astype('float32')
x_test = (x_test / 255.0).astype('float32')

# One-hot encode the 10 class labels
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Standard normalizing: per-channel mean / std, broadcast over the images
channel_mean = np.array([[[0.4914, 0.4822, 0.4465]]])
channel_std = np.array([[[0.2023, 0.1994, 0.2010]]])
x_train = (x_train - channel_mean) / channel_std
x_test = (x_test - channel_mean) / channel_std

# Hold out the last 10000 training samples for validation
val_samples = -10000
x_val, y_val = x_train[val_samples:], y_train[val_samples:]
x_train, y_train = x_train[:val_samples], y_train[:val_samples]

# NOTE(review): `datagen` is configured and fitted here, but the fit call
# further down trains on x_train / y_train directly.
datagen = ImageDataGenerator(
    zca_epsilon=1e-06,
    width_shift_range=0.1,
    height_shift_range=0.1,
    fill_mode='nearest',
    horizontal_flip=True,
)
datagen.fit(x_train)

   
def lr_scheduler(epoch):
    """Learning Rate Schedule

    Step schedule starting at 1e-3. The rate is scaled down once the epoch
    value exceeds 10, 120, 160 and 180.
    Meant to be wrapped in a `LearningRateScheduler` callback.

    # Arguments
        epoch (int): epoch value supplied by the caller

    # Returns
        lr (float): learning rate for that epoch
    """
    base_lr = 1e-3
    # (epoch threshold, multiplier) pairs, checked from the highest threshold
    # down so the first match wins. The original kept commented-out
    # alternatives (80, 40, 25) for the lowest threshold.
    steps = (
        (180, 0.5e-3),
        (160, 1e-3),
        (120, 1e-2),
        (10, 1e-1),
    )
    for threshold, factor in steps:
        if epoch > threshold:
            return base_lr * factor
    return base_lr

# Wrap the schedule function in a Keras callback for use during training.
lr_scheduler_mod = LearningRateScheduler(lr_scheduler)

# Create the output directories before training starts, so a missing folder
# cannot make the model / history writes fail after the run has finished.
for out_prefix in (full_path, full_path_his):
    out_dir = os.path.dirname(out_prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

# Build and compile ResNet
# opt = SGD(learning_rate=0.1, momentum=0.9)
opt = Adam(learning_rate=0.001)
SNN = define_CNN_model(input_shape=(32, 32, 3), num_classes=10, weights='imagenet')
SNN.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['acc'])

# Train and print the elapsed wall-clock time in seconds.
# NOTE(review): verbose is hard-coded to False here; the module-level
# `verbose` flag is not used.
start = time.time()
result = SNN.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=False,
    validation_data=(x_val, y_val),
    validation_batch_size=batch_size,
    callbacks=[lr_scheduler_mod],
)
end = time.time()
print(end-start)

# Evaluate on the held-out test split.
SNN.evaluate(x_test, y_test)

# Save the trained model.
SNN.save(full_path+'_modelRes.keras')

# Save training history
with open(full_path_his + '_result', 'wb') as file:
    pickle.dump(result.history, file)