import tensorflow as tf
from tensorflow import keras
from keras.optimizers import SGD, Adam
from keras import datasets
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
import time
import argparse
import yaml
import logging
import pickle
from models.res18_edd import resnet18_EDD
from models.vgg16_edd import VGG16_EDD


####################################
######## Training Settings  ########
####################################
def parse_config(config_str):
    """Parse a comma-separated ``key=value`` string into a dict of strings.

    Whitespace around keys and values is removed, and one layer of
    surrounding single or double quotes is stripped from each value.
    Empty items (for example those produced by a trailing comma) are
    ignored, and only the first ``=`` of an item separates key from value,
    so values may themselves contain ``=``.

    Args:
        config_str: Text such as ``"ExpNum=1, Architecture='RES18'"``.

    Returns:
        dict mapping each key to its (string) value.

    Raises:
        ValueError: if a non-empty item contains no ``=``.
    """
    config = {}
    for item in config_str.split(','):
        if not item.strip():
            # Tolerate "a=1," and empty input instead of failing on unpacking.
            continue
        key, separator, value = item.partition('=')
        if not separator:
            raise ValueError(
                f"Invalid configuration entry {item.strip()!r}: expected key=value"
            )
        # Remove extra spaces first, then any quotes wrapped around the value.
        config[key.strip()] = value.strip().strip("'\"")
    return config

# Command line: one required --config string holding key=value pairs.
parser = argparse.ArgumentParser(description='Process some configuration.')
parser.add_argument(
    '--config',
    type=str,
    required=True,
    help="Comma-separated list of key=value pairs for configuration, e.g., ExpNum=1, Architecture='RES18'",
)
args = parser.parse_args()
config = parse_config(args.config)

# Values taken from --config (with fallbacks when a key is absent).
ExpNum = int(config.get('ExpNum', 1))
backbone = config.get('Architecture', 'RES50')

# Fixed hyper-parameters of the run.
Temp = 10             # initial temperature (alternative tried: 2.5)
epochs = 100          # alternatives tried: 3, 200
learning_rate = 0.001
cycle_length = 60

print(f'ExpNum: {ExpNum}; BACKBONE: {backbone}; Temperature: {Temp}; Epochs: {epochs}; Cycle length: {cycle_length}')

####################################
######### DATA AUGMENTATION ########
####################################
def dataset_generator(images, labels, batch_size):
    """Build the augmented, shuffled and batched training pipeline.

    Each (image, label) pair is passed through ``_augment_fn``, the whole
    set is shuffled with a buffer as large as ``images``, grouped into
    batches of ``batch_size`` and prefetched.

    Args:
        images: Indexable collection of images; its length sets the shuffle buffer.
        labels: Labels aligned with ``images``.
        batch_size: Number of samples per batch.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    autotune = tf.data.experimental.AUTOTUNE
    pipeline = (
        tf.data.Dataset.from_tensor_slices((images, labels))
        .map(_augment_fn, num_parallel_calls=autotune)
        .shuffle(len(images))
        .batch(batch_size)
        .prefetch(buffer_size=autotune)
    )
    return pipeline

def _augment_fn(images, labels):
    """Apply pad-and-random-crop plus random horizontal flip to one sample.

    The image is padded by 4 pixels on every side (32x32 -> 40x40), a random
    32x32x3 window is cropped back out, and the result is randomly flipped
    left/right. Labels pass through untouched.

    Returns:
        Tuple ``(augmented_image, labels)``.
    """
    pad = 4
    side = 32
    padded_side = side + 2 * pad

    padded = tf.image.pad_to_bounding_box(images, pad, pad, padded_side, padded_side)
    cropped = tf.image.random_crop(padded, (side, side, 3))
    flipped = tf.image.random_flip_left_right(cropped)
    return flipped, labels

####################################
######### Load Teacher Model #######
####################################
def load_model(teacher_model_path):
    """Build a teacher network and restore its pickled weights.

    The architecture follows the module-level ``backbone`` setting:
    ``'RES18'`` selects ``resnet18``; any other value selects ``VGG16``.
    Weights are read from the file ``teacher_model_path + '_weights'``.

    Args:
        teacher_model_path: Path prefix of the pickled weight list.

    Returns:
        The compiled teacher model with the stored weights applied.
    """
    # Import lazily so only the architecture actually needed is loaded.
    if backbone == 'RES18':
        from models.res18 import resnet18 as build_teacher
    else:
        from models.vgg16 import VGG16 as build_teacher

    teacher_model = build_teacher(input_shape=(32, 32, 3), num_classes=10)
    teacher_model.compile(optimizer=Adam(learning_rate=0.001))

    # The weights were stored with pickle; restore them into the fresh model.
    weights_file = teacher_model_path + '_weights'
    with open(weights_file, 'rb') as handle:
        stored_weights = pickle.load(handle)
    teacher_model.set_weights(stored_weights)

    return teacher_model

####################################
##### Single Model Evaluation ######
####################################
def single_model_evaluate(model, x_test, y_test):
    """Predict on ``x_test`` and score categorical accuracy against ``y_test``.

    The model outputs are divided by the module-level ``Temp`` and passed
    through a softmax before being compared with the one-hot labels.

    Args:
        model: Model exposing ``predict``.
        x_test: Inputs to evaluate.
        y_test: One-hot encoded targets.

    Returns:
        Tuple ``(raw_outputs, accuracy)`` where ``raw_outputs`` is the
        unscaled result of ``model.predict`` and ``accuracy`` is a NumPy scalar.
    """
    raw_outputs = model.predict(x_test)
    tempered_probs = tf.nn.softmax(raw_outputs / Temp, axis=-1)

    accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
    accuracy_metric.update_state(y_test, tempered_probs)
    accuracy = accuracy_metric.result().numpy()
    accuracy_metric.reset_state()

    return raw_outputs, accuracy
    
def main():
    """Distil an ensemble of teacher networks into a single ``*_EDD`` student.

    Steps, in order:
      1. load CIFAR-10, scale and standardise it, and hold out the last
         10,000 training images for validation;
      2. load the teacher networks listed for this experiment in the pickled
         ``five_ensembles`` mapping;
      3. train the student for ``epochs`` epochs with ``DirichletEDDLoss``,
         an annealed temperature and a step learning-rate schedule;
      4. print the test accuracy and pickle the training history and the
         student weights.

    Reads the module-level settings ``ExpNum``, ``backbone``, ``Temp``,
    ``epochs``, ``learning_rate`` and ``cycle_length``.

    Raises:
        ValueError: if ``cycle_length`` is odd or ``epochs`` is not greater
            than ``cycle_length``.
    """
    import os  # local import: only needed to create the output directories

    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

    batch_size = 128

    full_path = 'path_to_save_the_model_weights' + str(Temp) + '/' + str(ExpNum)
    full_path_his = 'path_to_save_the_training_history'  + str(Temp) + '/his/' + str(ExpNum)

    # Create the output directories before training so that a missing folder
    # cannot make the final save fail after the whole run has completed.
    for output_prefix in (full_path, full_path_his):
        parent_dir = os.path.dirname(output_prefix)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Set random seed
    # keras.utils.set_random_seed(ExpNum)
    # tf.config.experimental.enable_op_determinism()

    ###################################################################
    ##################### Prepare training dataset ####################
    ###################################################################
    (x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

    x_train = (x_train / 255.0).astype('float32')
    x_test = (x_test / 255.0).astype('float32')
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)

    # Per-channel standardisation; the (1, 1, 3) arrays broadcast over the
    # last (channel) axis. Combining float32 images with these float64 arrays
    # promotes the result to float64, so cast back to float32 explicitly.
    channel_mean = np.array([[[0.4914, 0.4822, 0.4465]]])
    channel_std = np.array([[[0.2023, 0.1994, 0.2010]]])
    x_train = ((x_train - channel_mean) / channel_std).astype('float32')
    x_test = ((x_test - channel_mean) / channel_std).astype('float32')

    # The last 10,000 training samples are held out for validation.
    val_samples = -10000

    x_val = x_train[val_samples:]
    y_val = y_train[val_samples:]
    x_train = x_train[:val_samples]
    y_train = y_train[:val_samples]

    train_dataset = dataset_generator(x_train, y_train, batch_size)
    val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)

    # Define Learning Scheduler
    def lr_scheduler(epoch):
        """Return the step-decayed learning rate for a zero-based epoch."""
        lr = 1e-3
        if epoch >= 180:
            lr *= 0.5e-3
        elif epoch >= 160:
            lr *= 1e-3
        elif epoch >= 120:
            lr *= 1e-2
        elif epoch >= 80:
            lr *= 1e-1
        return lr
    lr_scheduler_mod = lr_scheduler

    # Alternative schedule, kept for reference (enable via the assignment below):
    #
    # def get_one_cycle_lr(epoch):
    #     """
    #     Calculate the learning rate for a given epoch based on the OneCycleLRPolicy.
    #
    #     # Arguments:
    #         init_lr      : the initial learning rate
    #         max_lr       : the maximum learning rate
    #         min_lr       : the minimum learning rate
    #         cycle_length : the length of the cycle, in epochs, must be even
    #         epochs       : the total number of epochs, must be larger than cycle_length
    #         epoch        : the current epoch for which to compute the learning rate
    #
    #     # Returns:
    #         lr: The learning rate for the given epoch
    #     """
    #     init_lr = learning_rate
    #     max_lr = init_lr * 10
    #     min_lr = init_lr / 1000
    #
    #     assert (cycle_length % 2 == 0), "cycle_length must be an even number"
    #     assert (epochs > cycle_length), "epochs must be greater than cycle_length"
    #
    #     # Create the schedule for the learning rate
    #     lr_schedule = np.hstack(
    #         (np.linspace(init_lr, max_lr, cycle_length // 2),
    #         np.linspace(max_lr, init_lr, cycle_length // 2),
    #         np.linspace(init_lr, min_lr, epochs - cycle_length))
    #     )
    #
    #     # Return the learning rate for the given epoch
    #     return lr_schedule[epoch]
    #
    # lr_scheduler_mod = get_one_cycle_lr

    ####################################################
    ############## Define Teacher Model ################
    ####################################################
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open('five_ensembles', 'rb') as file:
        DEs = pickle.load(file)

    model_key = str(ExpNum-1)

    # Placeholder: replace with the list of identifiers of the trained single
    # models. Defined once here instead of on every loop iteration.
    seeds = ['a List of indices to recogonize the different single model trained']

    teacher_models = []
    for model_idx in DEs[model_key]:
        # Both branches are placeholders for backbone-specific path prefixes.
        if backbone == 'RES18':
            teacher_model_path = 'Define_teacher_model_path'+str(seeds[model_idx])
        else:
            teacher_model_path = 'Define_teacher_model_path'+str(seeds[model_idx])

        teacher_models.append(load_model(teacher_model_path))

    ####################################################
    ######### Extract Ensemble Preds Teacher ###########
    ####################################################
    def get_ensemble_preds(ensemble_pred_list):
        """Stack per-teacher outputs into shape (batch, n_teachers, n_classes).

        Stacking along axis 1 yields the same layout as stacking along axis 0
        and swapping the first two axes, without leaving TensorFlow.
        """
        return tf.stack(ensemble_pred_list, axis=1)

    def teacher_ensemble_logits(inputs):
        """Run every teacher in inference mode and stack their outputs."""
        return get_ensemble_preds(
            [teacher(inputs, training=False) for teacher in teacher_models]
        )

    ####################################################
    ############## Define Student Model ################
    ####################################################
    opt = Adam(learning_rate=learning_rate)

    if backbone == 'RES18':
        edd_model = resnet18_EDD(input_shape=(32, 32, 3), num_classes=10)
    else:
        edd_model = VGG16_EDD(input_shape=(32, 32, 3), num_classes=10)
    edd_model.compile(optimizer=opt)

    # Define the metrics
    loss_tracker = keras.metrics.Mean(name="Loss")
    val_loss_tracker = keras.metrics.Mean(name="ValLoss")
    acc_tracker = keras.metrics.CategoricalAccuracy(name='Acc')
    val_acc_tracker = keras.metrics.CategoricalAccuracy(name='ValAcc')

    ####################################################
    ################## Define Train Step ###############
    ####################################################
    def DirichletEDDLoss(logits, ensemble_logits, temperature):
        """Negative Dirichlet log-likelihood of the teachers' predictions.

        Args:
            logits: Student outputs, reduced over axis 1 (classes).
            ensemble_logits: Teacher outputs laid out as
                (batch, n_teachers, n_classes).
            temperature: Divisor applied to both student and teacher logits.

        Returns:
            Scalar float64 tensor: batch mean of the cost times ``temperature**2``.
        """
        # Hyperparameter
        epsilon = 1e-8
        ensemble_epsilon = 1e-4
        smooth_val = epsilon
        tp_scaling = 1 - ensemble_epsilon

        logits = tf.cast(logits, dtype=tf.float64)
        ensemble_logits = tf.cast(ensemble_logits, dtype=tf.float64)

        # Dirichlet concentration parameters predicted by the student.
        alphas = tf.math.exp(logits / temperature)

        precision = tf.reduce_sum(alphas, axis=1)  # sum over classes

        ensemble_probs = tf.nn.softmax(ensemble_logits / temperature, axis=2)

        # Smooth for num. stability:
        probs_mean = 1 / (tf.shape(ensemble_probs)[2])  # divide by nr of classes

        # Subtract mean, scale down, add mean back
        ensemble_probs = tp_scaling * (ensemble_probs - probs_mean) + probs_mean

        log_ensemble_probs_geo_mean = tf.math.reduce_mean(
            tf.math.log(ensemble_probs + smooth_val), axis=1)  # mean over ensembles

        # sum over lgamma of classes - lgamma(precision)
        target_independent_term = tf.math.reduce_sum(
            tf.math.lgamma(alphas + smooth_val), axis=1
        ) - tf.math.lgamma(precision + smooth_val)

        # -sum over classes
        target_dependent_term = -tf.math.reduce_sum(
            (alphas - 1.) * log_ensemble_probs_geo_mean, axis=1)

        cost = target_dependent_term + target_independent_term

        return tf.math.reduce_mean(cost) * temperature**2

    ####################################################
    ########### Define Temperature Annealing ###########
    ####################################################
    def build_temperature_schedule():
        """Return one temperature per epoch: hold, linear decrease, then 1.

        The first half of ``cycle_length`` keeps ``Temp``, the second half
        decreases linearly from ``Temp`` towards 1, and every remaining
        epoch uses 1.
        """
        # Explicit exceptions instead of asserts so the checks survive `-O`.
        if cycle_length % 2 != 0:
            raise ValueError("cycle_length must be an even number")
        if epochs <= cycle_length:
            raise ValueError("epochs must be greater than cycle_length")

        hold_length = cycle_length // 2
        decay_length = cycle_length - hold_length
        tail_length = epochs - cycle_length

        slope = (Temp - 1) / decay_length
        hold_phase = [Temp] * hold_length
        decay_phase = [Temp - slope * i for i in range(decay_length)]
        tail_phase = [1] * tail_length

        # The three phases add up to exactly `epochs` entries by construction.
        return hold_phase + decay_phase + tail_phase

    # Built once up front rather than being recomputed on every epoch.
    temperature_schedule = build_temperature_schedule()

    def train_step(data, temperature):
        """Run one optimisation step on a batch and update the train metrics."""
        inputs, labels = data

        # Teachers are evaluated outside the tape: they are not trained.
        preds_teachers = teacher_ensemble_logits(inputs)

        with tf.GradientTape() as tape:
            logits = edd_model(inputs, training=True)
            loss_total = DirichletEDDLoss(logits, preds_teachers, temperature)

        grads = tape.gradient(loss_total, edd_model.trainable_variables)
        edd_model.optimizer.apply_gradients(zip(grads, edd_model.trainable_variables))

        loss_tracker.update_state(loss_total)

        # Update accuracy
        acc_tracker.update_state(labels, tf.nn.softmax(logits, axis=-1))

        return loss_total

    ###################################################################
    ######################## Define Test Step #########################
    ###################################################################
    def test_step(data, temperature):
        """Evaluate one validation batch and update the validation metrics."""
        inputs, labels = data
        logits = edd_model(inputs, training=False)

        preds_teachers = teacher_ensemble_logits(inputs)
        loss_total = DirichletEDDLoss(logits, preds_teachers, temperature)

        val_loss_tracker.update_state(loss_total)

        # Update validation accuracy
        val_acc_tracker.update_state(labels, tf.nn.softmax(logits, axis=-1))

    ###################################################################
    ######################## Train & Validate Loop ####################
    ###################################################################
    start = time.time()
    result_history = {'Acc': [], 'Loss': [], 'val_Acc': [], 'val_Loss': []}
    template = ("Epoch {}, Loss: {:.4f}, Acc: {:.4f}, ValLoss: {:.4f}, ValAcc: {:.4f}")

    for epoch in range(epochs):

        temperature = temperature_schedule[epoch]

        # Apply the schedule BEFORE training so that the rate computed (and
        # printed) for this epoch is the one actually used during it.
        edd_model.optimizer.learning_rate = lr_scheduler_mod(epoch)
        print(f'Epoch {epoch + 1}/{epochs}, Learning Rate: {edd_model.optimizer.learning_rate.numpy()}')

        # TRAIN LOOP
        for batch in train_dataset:
            train_step(batch, temperature)
        train_acc = acc_tracker.result().numpy()
        train_loss = loss_tracker.result().numpy()

        # VALIDATION LOOP
        for batch in val_dataset:
            test_step(batch, temperature)
        val_acc = val_acc_tracker.result().numpy()
        val_loss = val_loss_tracker.result().numpy()

        # Update to history per epoch
        result_history['Acc'].append(train_acc)
        result_history['Loss'].append(train_loss)
        result_history['val_Acc'].append(val_acc)
        result_history['val_Loss'].append(val_loss)

        print(template.format(epoch + 1, train_loss, train_acc, val_loss, val_acc))

        # `reset_state` is the current metric API (the plural `reset_states`
        # is deprecated) and matches its use in `single_model_evaluate`.
        acc_tracker.reset_state()
        loss_tracker.reset_state()
        val_acc_tracker.reset_state()
        val_loss_tracker.reset_state()

    end = time.time()
    print(end-start)

    _, acc = single_model_evaluate(edd_model, x_test, y_test)

    print('Test Acc: ', acc)

    # Save training history
    with open(full_path_his + '_result', 'wb') as file:
        pickle.dump(result_history, file)

    # Save the student weights
    with open(full_path + '_weights', 'wb') as file2:
        pickle.dump(edd_model.get_weights(), file2)

# Call main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
    