import pickle
import numpy as np
import tensorflow as tf
import argparse
import yaml
from tensorflow import keras
from keras.optimizers import Adam
from test_core.load_test_data import load_cifar_corruption, load_cinic10_test, load_svhn_test, load_cifar10, load_tinyimage_test
import time
from tensorflow.keras.applications import EfficientNetB2
from tensorflow.keras.layers import Activation, Input, Dense, GlobalAveragePooling2D, Flatten
from vit_keras import vit

# def define_CNN_model(input_shape, num_classes, weights):
    
#     # base = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=weights, input_shape=(224, 224, 3), classes=num_classes)
#     base = EfficientNetB2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    
#     inputs = Input(input_shape)
    
#     x = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
    
#     x = base(x)
#     x = GlobalAveragePooling2D()(x)     
#     x = Flatten()(x)
#     x = Dense(units=1024, activation='relu')(x)
#     x = Dense(units=512, activation='relu')(x) 
    
#     outputs = Dense(units=num_classes, activation='softmax')(x)
    
#     model = keras.Model(inputs, outputs, name='SNN')
#     return model

# def define_CNN_model(input_shape, num_classes):
    
#     # base = vit.vit_b16(
#     #     image_size = (224, 224),
#     #     activation = 'softmax',
#     #     pretrained = True,
#     #     include_top = False,
#     #     pretrained_top = False,
#     #     classes = num_classes)
    
#     # base = tf.keras.applications.DenseNet121(
#     #     include_top=False,
#     #     weights='imagenet',
#     #     input_shape=(224, 224, 3),
#     #     classifier_activation='softmax')
    
#     # base = tf.keras.applications.Xception(
#     #     include_top=False,
#     #     weights='imagenet',
#     #     input_shape=(224, 224, 3),
#     #     classifier_activation='softmax')

#     base = tf.keras.applications.ConvNeXtTiny(
#         model_name='convnext_tiny',
#         include_top=False,
#         include_preprocessing=False,
#         weights='imagenet',
#         input_shape=(224, 224, 3),
#         classifier_activation='softmax'
#     )
    
#     inputs = Input(input_shape)
    
#     x = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
    
#     x = base(x)
#     x = GlobalAveragePooling2D()(x)     
#     # x = Flatten()(x)
#     x = Dense(units=1024, activation='relu')(x)
#     x = Dense(units=512, activation='relu')(x) 
    
#     outputs = Dense(units=num_classes, activation='softmax')(x)
    
#     model = keras.Model(inputs, outputs)
#     return model


# def define_CNN_model(input_shape, num_classes):
    
#     base = vit.vit_b16(
#         image_size = (224, 224),
#         activation = 'softmax',
#         pretrained = True,
#         include_top = False,
#         pretrained_top = False,
#         classes = num_classes)
    
#     inputs = Input(input_shape)
    
#     x = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
    
#     x = base(x)
#     # x = GlobalAveragePooling2D()(x)     
#     x = Flatten()(x)
#     x = Dense(units=1024, activation='relu')(x)
#     x = Dense(units=512, activation='relu')(x) 
    
#     outputs = Dense(units=num_classes, activation='softmax')(x)
    
#     model = keras.Model(inputs, outputs, name='SNNViT')
#     return model


# def resnet50(input_shape, num_classes):
#     inputs = Input(input_shape)
#     x = EfficientNetB2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))(inputs)
#     # x = tf.keras.applications.resnet50.ResNet50(include_top=False, weights=None, input_shape=(32, 32, 3), classes=num_classes)(inputs)
#     x = GlobalAveragePooling2D()(x)
#     outputs = Dense(units=num_classes, activation='softmax')(x)
#     model = keras.Model(inputs, outputs, name='RES50')  
    
#     return model
    
def load_model(full_path):
    """Load one saved Keras model from disk.

    Args:
        full_path: Path prefix of the saved model. The suffix
            ``_modelRes.keras`` is appended to build the file name.

    Returns:
        Whatever ``tf.keras.saving.load_model`` returns for that file.
    """
    model_file = full_path + "_modelRes.keras"
    return tf.keras.saving.load_model(model_file)

def single_model_evaluate(model, x_test, y_test, IFAcc):
    """Run one model on a test set and score its predictions.

    Args:
        model: Object exposing ``predict(x_test)``.
        x_test: Inputs handed unchanged to ``model.predict``.
        y_test: True labels given to ``CategoricalAccuracy``; only read when
            ``IFAcc`` is truthy (presumably one-hot encoded -- confirm
            against the data loaders).
        IFAcc: When truthy, also compute the categorical accuracy.

    Returns:
        A tuple ``(pred, acc, entropy)`` where ``pred`` is the raw output of
        ``model.predict``, ``acc`` is the accuracy (``None`` when ``IFAcc``
        is falsy) and ``entropy`` is the base-2 entropy of ``pred`` summed
        over the last axis.
    """
    probs = model.predict(x_test)

    # The small offset keeps log2 finite when a probability is exactly zero.
    tiny = 1e-12
    per_sample_entropy = -np.sum(probs * np.log2(probs + tiny), axis=-1)

    if not IFAcc:
        return probs, None, per_sample_entropy

    metric = tf.keras.metrics.CategoricalAccuracy()
    metric.update_state(y_test, probs)
    accuracy = metric.result().numpy()
    metric.reset_state()
    return probs, accuracy, per_sample_entropy

def ensembl_evaluate(preds, y_test, IFAcc):
    """Combine member predictions into an ensemble prediction and score it.

    Args:
        preds: Stacked member predictions; axis 0 indexes the ensemble
            members and the last axis is the one the entropy is summed over.
        y_test: True labels given to ``CategoricalAccuracy``; only read when
            ``IFAcc`` is truthy.
        IFAcc: When truthy, also compute the categorical accuracy of the
            averaged prediction.

    Returns:
        A tuple ``(pred_ensemble, acc, entropy)``. ``pred_ensemble`` is the
        mean of ``preds`` over axis 0, ``acc`` is the accuracy (``None`` when
        ``IFAcc`` is falsy) and ``entropy`` is a dict with keys ``'TU'``
        (base-2 entropy of the averaged prediction), ``'AU'`` (mean of the
        members' base-2 entropies) and ``'EU'`` (``TU - AU``).
    """
    # The small offset keeps log2 finite when a probability is exactly zero.
    tiny = 1e-12

    mean_probs = np.mean(preds, axis=0)
    entropy_of_mean = -np.sum(mean_probs * np.log2(mean_probs + tiny), axis=-1)

    member_entropies = -np.sum(preds * np.log2(preds + tiny), axis=-1)
    mean_of_entropies = np.mean(member_entropies, axis=0)

    entropy = {
        'TU': entropy_of_mean,
        'EU': entropy_of_mean - mean_of_entropies,
        'AU': mean_of_entropies,
    }

    accuracy = None
    if IFAcc:
        metric = tf.keras.metrics.CategoricalAccuracy()
        metric.update_state(y_test, mean_probs)
        accuracy = metric.result().numpy()
        metric.reset_state()
    return mean_probs, accuracy, entropy

def _record_result(store, pred, acc, entropy):
    """Append one evaluation result to ``store``, skipping accuracy if untracked."""
    store['pred'].append(pred)
    if 'acc' in store:
        store['acc'].append(acc)
    store['entro'].append(entropy)


def _evaluate_models_and_ensembles(model_list, ensembles3, ensembles5,
                                   x_test, y_test, with_acc, num_ensembles=15):
    """Evaluate every single model, then every 3- and 5-member ensemble."""
    keys = ('pred', 'acc', 'entro') if with_acc else ('pred', 'entro')
    single = {key: [] for key in keys}
    ens3 = {key: [] for key in keys}
    ens5 = {key: [] for key in keys}

    ################################
    ######### Single Model #########
    ################################
    for model in model_list:
        pred, acc, entropy = single_model_evaluate(model, x_test, y_test, IFAcc=with_acc)
        _record_result(single, pred, acc, entropy)

    ################################
    ####### Ensembles-3 / -5 #######
    ################################
    # Ensemble predictions are built from the stored single-model predictions,
    # so no model is run a second time.
    preds_all = np.stack(single['pred'])
    for ensembles, store in ((ensembles3, ens3), (ensembles5, ens5)):
        for idx in range(num_ensembles):
            member_ids = ensembles[str(idx)]
            # The trailing comma keeps the selection on axis 0 (the model axis).
            preds = preds_all[member_ids,]
            pred_ensemble, acc, entropy = ensembl_evaluate(preds, y_test, IFAcc=with_acc)
            _record_result(store, pred_ensemble, acc, entropy)

    return single, ens3, ens5


def snn_evaluation(dataset):
    """Evaluate 15 trained models and their 3-/5-member ensembles on a dataset.

    The models are loaded from ``train_resultsSNNVGG/<seed>`` for a fixed list
    of seeds. The ensemble memberships are read from the pickle files
    ``three_ensembles`` and ``five_ensembles`` in the working directory; each
    is indexed by the string keys ``'0'`` .. ``'14'`` (presumably mapping to
    lists of model indices -- confirm against the script that writes them).

    Args:
        dataset: Name of the test set. Supported values are ``'CIFAR10'``,
            ``'CIFAR10C_<corruption_type>_<severity_level>'`` (both integers),
            ``'SVHN'`` and ``'TinyImage'``.

    Returns:
        A tuple ``(single, ens3, ens5)`` of dicts holding the results of the
        single models, the 3-member ensembles and the 5-member ensembles.
        Each dict has the list-valued keys ``'pred'`` and ``'entro'``; for the
        CIFAR10 and CIFAR10C datasets it also has ``'acc'``. For ``'CIFAR10'``
        the single-model dict additionally stores the test labels under
        ``'label'``. For an unrecognised dataset name a message is printed and
        three empty dicts are returned.

    Raises:
        ValueError: If a CIFAR10C dataset name does not consist of exactly
            three ``_``-separated parts, or its last two parts are not
            integers.
    """
    seeds = [0, 66, 99, 314, 524, 803, 888, 908, 1103, 1208, 7509, 11840, 40972, 46857, 54833]

    ######### Get 15 Ensembles #########
    # NOTE: pickle.load can execute arbitrary code; these files must come from
    # a trusted source.
    with open('three_ensembles', 'rb') as file:
        ensembles3 = pickle.load(file)

    with open('five_ensembles', 'rb') as file:
        ensembles5 = pickle.load(file)

    ######### Load models #########
    model_list = [load_model('train_resultsSNNVGG/' + str(seed)) for seed in seeds]

    ######### Select the test data #########
    keep_labels = False
    if dataset == 'CIFAR10':
        ######### Test on CIFAR10 dataset #########
        (_, _), (x_test, y_test) = load_cifar10()
        with_acc = True
        keep_labels = True

    elif 'CIFAR10C' in dataset:
        ######### Test on CIFAR10_C dataset #########
        ######### Determine Corruption Type and Severity Level #########
        parts = dataset.split('_')
        if len(parts) != 3:
            # Fail with a clear message instead of continuing with an unset
            # corruption type and severity level.
            raise ValueError(
                f"Invalid dataset string format: {dataset!r}; expected "
                "'CIFAR10C_<corruption_type>_<severity_level>'."
            )
        cor_type = int(parts[1])
        sev_level = int(parts[2])
        x_test, y_test = load_cifar_corruption(cor_type, sev_level)
        with_acc = True

    elif dataset == 'SVHN':
        ######### Test on SVHN dataset #########
        x_test, y_test = load_svhn_test()
        with_acc = False

    elif dataset == 'TinyImage':
        ######### Test on TinyImage dataset #########
        x_test, y_test = load_tinyimage_test()
        with_acc = False

    else:
        print("Invalid Dataset Name. Try again...")
        return {}, {}, {}

    cifar, cifar3, cifar5 = _evaluate_models_and_ensembles(
        model_list, ensembles3, ensembles5, x_test, y_test, with_acc
    )

    if keep_labels:
        cifar['label'] = y_test

    return cifar, cifar3, cifar5


def load_config(yaml_file):
    """Parse a YAML configuration file with PyYAML's safe loader.

    Args:
        yaml_file: Path of the YAML file to read.

    Returns:
        The parsed content of the file.
    """
    with open(yaml_file, 'r') as handle:
        return yaml.safe_load(handle)


def main():
    """Evaluate the dataset named in a YAML config and pickle the results.

    The path of the YAML file is taken from the command line. Its ``Dataset``
    entry selects the test set passed to ``snn_evaluation``. The elapsed
    evaluation time in seconds is printed, and the three result dicts are
    written to ``test_resultsSNNVGG/<Dataset>_result``, ``..._result3`` and
    ``..._result5``.
    """
    import os  # Local import: only this entry point needs it.

    # Accept a YAML file as a command-line argument
    parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
    parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
    args = parser.parse_args()

    config = load_config(args.config_file)
    dataset_name = config['Dataset']

    # Create the output directory before the long evaluation so the results
    # are not lost to a missing folder when they are saved.
    output_dir = 'test_resultsSNNVGG'
    os.makedirs(output_dir, exist_ok=True)

    start_time = time.time()
    result, result3, result5 = snn_evaluation(dataset_name)
    end_time = time.time()
    print(end_time - start_time)

    full_path = output_dir + '/' + dataset_name

    # Save test history
    outputs = (('_result', result), ('_result3', result3), ('_result5', result5))
    for suffix, payload in outputs:
        with open(full_path + suffix, 'wb') as file:
            pickle.dump(payload, file)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
