# Import necessary packages
import argparse
import os
import pickle
import time

import tensorflow as tf
from tensorflow import keras
from keras.optimizers import SGD, Adam
from keras import datasets
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler
from tensorflow.keras.layers import Input, Dense, Flatten
from vit_keras import vit
import numpy as np
import yaml

# Restrict TensorFlow to a single physical GPU. This must run before any op
# initialises the GPU runtime, otherwise set_visible_devices raises RuntimeError.
GPU_INDEX = 3
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > GPU_INDEX:
    tf.config.set_visible_devices(physical_devices[GPU_INDEX], 'GPU')
else:
    # Indexing physical_devices[GPU_INDEX] here would raise a bare IndexError;
    # keep TensorFlow's default device visibility and say so instead.
    print(f"Warning: GPU index {GPU_INDEX} requested but only "
          f"{len(physical_devices)} GPU(s) found; using default device visibility.")
logical_devices = tf.config.list_logical_devices('GPU')

def load_config(yaml_file):
    """Read ``yaml_file`` and return its parsed contents.

    Parsing uses PyYAML's safe loader, so only standard YAML types are
    constructed; arbitrary Python objects cannot be instantiated from the file.
    """
    with open(yaml_file, 'r') as stream:
        return yaml.safe_load(stream)


def define_CNN_model(input_shape, num_classes):
    """Build a ViT-B/16 image classifier for small input images.

    Each input image is nearest-neighbour upsampled by an integer factor per
    axis to the 224x224 resolution the pretrained ViT-B/16 backbone is built
    for. It then passes through the backbone and a dense head
    (1024 -> 512 -> ``num_classes`` softmax).

    Args:
        input_shape: ``(height, width, channels)`` of one input image. Height
            and width must each divide 224 exactly, e.g. ``(32, 32, 3)`` for
            CIFAR-10 gives an upsampling factor of 7.
        num_classes: Number of output classes.

    Returns:
        An uncompiled ``keras.Model`` named ``'SNNViT'`` whose output is a
        softmax over ``num_classes`` classes.

    Raises:
        ValueError: If the height or width of ``input_shape`` does not divide
            224, i.e. the image cannot be upsampled to 224x224 by an integer
            factor.
    """
    vit_size = 224
    height, width = input_shape[0], input_shape[1]
    # Validate before constructing the backbone (which may fetch pretrained
    # weights), so a bad shape fails immediately.
    if vit_size % height or vit_size % width:
        raise ValueError(
            f"input_shape {tuple(input_shape)} cannot be upsampled to "
            f"{vit_size}x{vit_size} by an integer factor"
        )
    # Derived from input_shape so the backbone always sees 224x224
    # (a factor of 7 per axis for 32x32 inputs).
    upsample = (vit_size // height, vit_size // width)

    base = vit.vit_b16(
        image_size=(vit_size, vit_size),
        activation='softmax',
        pretrained=True,
        include_top=False,
        pretrained_top=False,
        classes=num_classes)
    # NOTE(review): `activation` and `classes` presumably only configure
    # vit_keras' own classification top, which include_top=False disables.

    inputs = Input(input_shape)
    x = tf.keras.layers.UpSampling2D(size=upsample)(inputs)
    x = base(x)
    # Flatten is likely a no-op if the backbone already returns a flat
    # (batch, features) vector — TODO confirm for the installed vit_keras.
    x = Flatten()(x)
    x = Dense(units=1024, activation='relu')(x)
    x = Dense(units=512, activation='relu')(x)
    outputs = Dense(units=num_classes, activation='softmax')(x)

    return keras.Model(inputs, outputs, name='SNNViT')


# --- Command-line arguments and run configuration ---------------------------
# The script takes a single positional argument: the YAML config path.
parser = argparse.ArgumentParser(
    description='Process parameters from a YAML file.',
)
parser.add_argument(
    'config_file',
    type=str,
    help='Path to the YAML configuration file',
)
args = parser.parse_args()
config = load_config(args.config_file)

gpu_count = len(tf.config.list_physical_devices('GPU'))
print("Num GPUs Available: ", gpu_count)

# The YAML config must provide a top-level 'Seed' entry.
seed = config['Seed']
print("Applied Seed: ", seed)

# Mini-batch size for training and validation.
batch_size = 128

# Seed the Python, NumPy and TensorFlow RNGs and request deterministic TF ops
# so that runs with the same seed are reproducible.
keras.utils.set_random_seed(seed)
tf.config.experimental.enable_op_determinism()

# --- Dataset preparation ----------------------------------------------------
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# Per-channel mean / std used for standardisation. Kept as float32 on purpose:
# subtracting a float64 array from the float32 images would promote the whole
# dataset back to float64 (twice the memory), undoing the float32 cast.
# NOTE(review): these widely-copied std values differ from the per-pixel
# channel std of the CIFAR-10 training set (~0.247, 0.243, 0.261); kept
# unchanged so results stay comparable with earlier runs.
cifar_mean = np.array([0.4914, 0.4822, 0.4465], dtype='float32')
cifar_std = np.array([0.2023, 0.1994, 0.2010], dtype='float32')

# Cast to float32, scale 8-bit pixels to [0, 1], then standardise per channel
# (the (3,) statistics broadcast over the trailing channel axis).
x_train = (x_train.astype('float32') / 255.0 - cifar_mean) / cifar_std
x_test = (x_test.astype('float32') / 255.0 - cifar_mean) / cifar_std

# One-hot encode the 10 class labels.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Hold out the last 10,000 training samples as the validation set
# (negative index = count from the end).
val_samples = -10000

x_val = x_train[val_samples:]
y_val = y_train[val_samples:]

x_train = x_train[:val_samples]
y_train = y_train[:val_samples]

# --- Output locations -------------------------------------------------------
# Filename prefixes for this run; '_weights' / '_result' suffixes are appended
# when the artefacts are written after training.
full_path = 'trainViTSNN/' + str(seed)
full_path_his = 'trainViTSNN/his/' + str(seed)

# open(..., 'wb') does not create missing parent directories. Create them now,
# rather than failing only after the whole training run has finished.
os.makedirs(os.path.dirname(full_path), exist_ok=True)
os.makedirs(os.path.dirname(full_path_his), exist_ok=True)

# --- Model ------------------------------------------------------------------
# Initial learning rate; the LearningRateScheduler callback passed to fit()
# overrides it at the start of every epoch.
opt = Adam(learning_rate=0.001)
SNN = define_CNN_model(input_shape=(32, 32, 3), num_classes=10)
SNN.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['acc'])

epochs = 25

def lr_scheduler(epoch):
    """Return the learning rate for a given 0-based epoch index.

    Step schedule from a base rate of 1e-3. Each row applies when ``epoch``
    is strictly greater than the lower bound:

        epoch index     learning rate
        0 .. 20         1e-3
        21 .. 120       1e-4
        121 .. 160      1e-5
        161 .. 180      1e-6
        > 180           5e-7

    Used as the schedule of a Keras ``LearningRateScheduler`` callback, which
    calls it at the start of every epoch. Keep the single-argument signature:
    that callback passes the current rate as a second argument to schedules
    that accept one.

    # Arguments
        epoch (int): 0-based epoch index.

    # Returns
        lr (float): learning rate for that epoch.
    """
    base_lr = 1e-3
    # (exclusive lower bound on epoch, multiplier), checked latest stage first.
    stages = ((180, 0.5e-3), (160, 1e-3), (120, 1e-2), (20, 1e-1))
    for lower_bound, factor in stages:
        if epoch > lower_bound:
            return base_lr * factor
    return base_lr

lr_scheduler_mod = LearningRateScheduler(lr_scheduler)

# --- Training ---------------------------------------------------------------
# perf_counter is monotonic, so the reported duration is unaffected by system
# clock adjustments during the (long) run.
start = time.perf_counter()
result = SNN.fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=True,
    validation_data=(x_val, y_val),
    validation_batch_size=batch_size,
    callbacks=[lr_scheduler_mod],
)
end = time.perf_counter()

print('TT:', end - start)

# --- Final evaluation -------------------------------------------------------
# Report the test metrics ([loss, acc] for this compile() setup) instead of
# discarding them; use the same batch size as training/validation.
test_scores = SNN.evaluate(x_test, y_test, batch_size=batch_size)
print('Test [loss, acc]:', test_scores)

weights_to_save = SNN.get_weights()

# Save training history (the per-epoch metrics dict recorded by fit()).
with open(full_path_his + '_result', 'wb') as file:
    pickle.dump(result.history, file)

# Save model weights (the list of arrays returned by get_weights()).
with open(full_path + '_weights', 'wb') as file_weights:
    pickle.dump(weights_to_save, file_weights)
