# Import Necessary Packages
import argparse
import os
import pickle
import time

import numpy as np
import tensorflow as tf
import yaml
from tensorflow import keras
from tensorflow.keras.applications import EfficientNetB2
from tensorflow.keras.layers import Activation, Input, Dense, GlobalAveragePooling2D, BatchNormalization, Flatten
from keras import datasets
from keras.callbacks import LearningRateScheduler
from keras.optimizers import SGD, Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
from vit_keras import vit

from model.resnet import build_resnet

def load_config(yaml_file):
    """Read a YAML configuration file and return its parsed contents.

    # Arguments
        yaml_file (str): Path of the YAML file to read.

    # Returns
        The Python object PyYAML's safe loader builds from the document.
    """
    with open(yaml_file, 'r') as stream:
        # safe_load is shorthand for load(..., Loader=yaml.SafeLoader).
        return yaml.safe_load(stream)
    

def define_CNN_model(input_shape, num_classes, weights):
    """Assemble a ResNet50-backed classifier with a small dense head.

    The input image is upsampled 7x in both spatial dimensions, fed through
    a ResNet50 backbone (built without its top, for 224x224x3 inputs),
    globally average-pooled, and classified by two ReLU dense layers
    (1024 and 512 units) followed by a softmax layer.

    # Arguments
        input_shape (tuple): Shape of one input image, without the batch
            dimension. NOTE(review): the fixed 7x upsampling only reaches
            the backbone's 224x224 input for 32x32 images -- confirm before
            passing any other size.
        num_classes (int): Number of units in the softmax output layer.
        weights: Forwarded unchanged to `ResNet50(weights=...)`.

    # Returns
        A `keras.Model` named 'SNN' mapping the input image to the softmax
        output.
    """
    # NOTE(review): `classes` is presumably ignored by Keras when
    # include_top=False -- confirm against the ResNet50 documentation.
    backbone = tf.keras.applications.resnet50.ResNet50(
        include_top=False,
        weights=weights,
        input_shape=(224, 224, 3),
        classes=num_classes,
    )
    # Alternative backbone left by the author:
    # backbone = EfficientNetB2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

    image = Input(input_shape)

    # Enlarge the image before handing it to the 224x224 backbone.
    features = tf.keras.layers.UpSampling2D(size=(7, 7))(image)
    features = backbone(features)
    features = GlobalAveragePooling2D()(features)
    # NOTE(review): Flatten presumably does nothing to the already pooled
    # tensor; it is kept so the layer list matches the original model.
    features = Flatten()(features)

    # Fully connected head: 1024 -> 512 -> num_classes.
    for width in (1024, 512):
        features = Dense(units=width, activation='relu')(features)
    probabilities = Dense(units=num_classes, activation='softmax')(features)

    return keras.Model(image, probabilities, name='SNN')




# --- Command line and configuration -----------------------------------------
# The script takes one positional argument: the path of a YAML config file.
parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
args = parser.parse_args()

config = load_config(args.config_file)

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# The seed comes from the 'Seed' key of the YAML config.
seed = config['Seed']

print("Applied Seed: ", seed)

# Number of experiments to run
exp_num = 6666

# --- Training hyper-parameters ----------------------------------------------
# verbose for training
verbose = True
batch_size = 128

epochs = 15

# --- Output locations -------------------------------------------------------
# Path prefixes for the saved model and the pickled training history; the
# seed is part of the file name.
full_path = 'TrainResNet/'+str(seed)
full_path_his = 'TrainResNet/his/'+str(seed)

# NOTE: random seeding is currently disabled, so in this section the seed
# only names the output files. Uncomment the two lines below for
# reproducible runs.
# keras.utils.set_random_seed(seed)
# tf.config.experimental.enable_op_determinism()

# --- Dataset ----------------------------------------------------------------
# Prepare training dataset
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# Per-channel mean / standard deviation used for standardization. They are
# float32 on purpose: subtracting a float64 array from float32 images would
# promote the whole dataset back to float64 and double its memory footprint.
channel_mean = np.array([[[0.4914, 0.4822, 0.4465]]], dtype='float32')
channel_std = np.array([[[0.2023, 0.1994, 0.2010]]], dtype='float32')

# Cast first, then divide by 255 and standardize per channel, so every
# intermediate array stays float32.
x_train = (x_train.astype('float32') / 255.0 - channel_mean) / channel_std
x_test = (x_test.astype('float32') / 255.0 - channel_mean) / channel_std

# One-hot encode the 10 class labels.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Hold out the last 10000 training samples for validation. `val_samples` is
# negative so it can be used directly as a slice bound from the end.
val_samples = -10000

x_val = x_train[val_samples:]
y_val = y_train[val_samples:]


x_train = x_train[:val_samples]
y_train = y_train[:val_samples]

   
def lr_scheduler(epoch):
    """Return the learning rate for the given epoch.

    The base rate is 1e-3. Once the epoch index exceeds a threshold the base
    rate is multiplied by a fixed factor: 1e-1 after 10, 1e-2 after 120,
    1e-3 after 160 and 0.5e-3 after 180.
    Intended to be called once per epoch by a learning-rate callback.

    # Arguments
        epoch (int): Index of the epoch.

    # Returns
        lr (float): learning rate
    """
    base_lr = 1e-3
    # (threshold, factor) pairs, highest threshold first so that the first
    # match is the most aggressive decay that applies.
    decay_steps = (
        (180, 0.5e-3),
        (160, 1e-3),
        (120, 1e-2),
        (10, 1e-1),
    )
    for threshold, factor in decay_steps:
        if epoch > threshold:
            return base_lr * factor
    return base_lr

# Callback that asks lr_scheduler for the learning rate each epoch.
lr_scheduler_mod = LearningRateScheduler(lr_scheduler)

# Create the output directories before training starts. Otherwise a missing
# folder only surfaces as FileNotFoundError when the results are written,
# after the whole training run has already been spent.
for output_prefix in (full_path, full_path_his):
    output_dir = os.path.dirname(output_prefix)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)


# Build and compile ResNet
# opt = SGD(learning_rate=0.1, momentum=0.9)
opt = Adam(learning_rate=0.001)
SNN = define_CNN_model(input_shape=(32, 32, 3), num_classes=10, weights='imagenet')
SNN.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['acc'])

# Train, timing the whole fit call; the elapsed seconds are printed below.
start = time.time()
result = SNN.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=verbose,  # honour the module-level verbosity setting
    validation_data=(x_val, y_val),
    validation_batch_size=batch_size,
    callbacks=[lr_scheduler_mod],
)
end = time.time()
print(end-start)

# Evaluate on the held-out test split.
SNN.evaluate(x_test, y_test)

SNN.save(full_path+'_modelRes.keras')

# Save training history
with open(full_path_his + '_result', 'wb') as file:
    pickle.dump(result.history, file)