import pickle
import numpy as np
import tensorflow as tf
import argparse
import yaml
from keras.optimizers import Adam
from test_core.load_test_data import load_cifar_corruption, load_svhn_test, load_cifar10, load_tinyimage_test
import time
import scipy.special

####################################
####### Evaluation Settings  #######
####################################
def parse_config(config_str):
    """Parse a comma-separated ``key=value`` string into a dict.

    Whitespace around keys and values is stripped, and any leading/trailing
    single or double quote characters are stripped from values, e.g.
    ``"DATASET='CIFAR10', Architecture='RES18'"`` ->
    ``{'DATASET': 'CIFAR10', 'Architecture': 'RES18'}``.

    Args:
        config_str: The raw configuration string (as given to ``--config``).

    Returns:
        dict mapping each key to its (string) value. If a key appears more
        than once, the last occurrence wins.

    Raises:
        ValueError: if a non-empty item contains no ``=``.
    """
    config = {}
    for item in config_str.split(','):
        if not item.strip():
            # Tolerate empty items, e.g. from a trailing comma ("A=1,").
            continue
        # Split on the first '=' only, so values may themselves contain '='.
        key, sep, value = item.partition('=')
        if not sep:
            raise ValueError(f"Invalid config item {item!r}: expected key=value")
        config[key.strip()] = value.strip().strip("'\"")  # drop spaces and quotes
    return config

parser = argparse.ArgumentParser(description='Process some configuration.')
parser.add_argument(
    '--config',
    type=str,
    required=True,
    help="Comma-separated list of key=value pairs for configuration, e.g., DATASET='CIFAR10', Architecture='RES18'",
)
args = parser.parse_args()
config = parse_config(args.config)

# Missing keys fall back to CIFAR-10 as the test set and ResNet-18 as the backbone.
dataset_name = config.get('DATASET', 'CIFAR10')
backbone = config.get('Architecture', 'RES18')

# NOTE(review): this flag is not read anywhere in this file -- confirm whether it is still needed.
enable_temp = True

print(f'DATASET: {dataset_name}; BACKBONE: {backbone}')

####################################
###### Load Distillation Model #####
####################################
def load_model(exp_num, model_dir='path_to_where_the_single_EDD_model_saved/'):
    """Build one distilled EDD model and load its pickled weights.

    The architecture is selected by the module-level ``backbone`` setting:
    ``'VGG16'`` builds ``VGG16_EDD``; any other value builds ``resnet18_EDD``.
    Both are built with ``input_shape=(32, 32, 3)`` and ``num_classes=10``.

    Args:
        exp_num: Experiment index; weights are read from
            ``<model_dir>/<exp_num>_weights``.
        model_dir: Directory holding the saved weight files. The default is
            the original placeholder path, so existing calls are unaffected.

    Returns:
        The compiled model (Adam, lr=0.001) with the loaded weights.
    """
    import os

    # Imports are deferred so only the selected architecture module is loaded.
    if backbone == 'VGG16':
        from models.vgg16_edd import VGG16_EDD
        eddSNN = VGG16_EDD(input_shape=(32, 32, 3), num_classes=10)
    else:
        from models.res18_edd import resnet18_EDD
        eddSNN = resnet18_EDD(input_shape=(32, 32, 3), num_classes=10)

    eddSNN.compile(optimizer=Adam(learning_rate=0.001))

    weights_path = os.path.join(model_dir, f'{exp_num}_weights')
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only load weight files produced by a trusted training run.
    with open(weights_path, 'rb') as file:
        weights = pickle.load(file)
    eddSNN.set_weights(weights)

    return eddSNN

####################################
##### Single Model Evaluation ######
####################################
def expected_entropy_pn(logits):
    """Expected entropy (data uncertainty) of a Dirichlet prior network.

    The logits are treated as log-concentrations, ``alpha = exp(logits)``, and
    the expected categorical entropy under ``Dir(alpha)`` is computed in
    closed form::

        E[H] = -sum_c (alpha_c / alpha_0) * (digamma(alpha_c + 1) - digamma(alpha_0 + 1))

    with ``alpha_0 = sum_c alpha_c``. Because digamma is natural-log based,
    the result is in nats.

    Args:
        logits: Array-like of shape ``(..., N_classes)``, typically
            ``(N_data_points, N_classes)``; classes are on the last axis.

    Returns:
        float64 values with the class axis reduced, e.g. shape
        ``(N_data_points,)`` for 2-D input (a scalar for 1-D input).
    """
    logits = np.asarray(logits, dtype=np.float64)

    # NOTE(review): exp overflows to inf for logits above ~709, giving nan.
    alpha = np.exp(logits)
    alpha_0 = np.sum(alpha, axis=-1, keepdims=True)
    probs = alpha / alpha_0

    return np.sum(
        -probs * (scipy.special.digamma(alpha + 1) - scipy.special.digamma(alpha_0 + 1)),
        axis=-1,
    )


def single_model_evaluate(model, x_test, y_test, IFAcc):
    """Predict with one model and compute its uncertainty measures.

    The model outputs are treated as Dirichlet log-concentrations (as in
    ``expected_entropy_pn``). With the predictive distribution
    ``p = softmax(logits)``:

    * ``'TU'`` (total)     = H[p], entropy of the predictive distribution;
    * ``'AU'`` (aleatoric) = expected entropy under the Dirichlet (data uncertainty);
    * ``'EU'`` (epistemic) = TU - AU, the mutual information (knowledge uncertainty).

    All three are reported in bits (log base 2).

    Args:
        model: Object with a ``predict(x)`` method returning logits; assumed
            shape ``(N, N_classes)`` -- confirm for new model types.
        x_test: Inputs passed directly to ``model.predict``.
        y_test: Labels, used only when ``IFAcc`` is true. They are compared by
            argmax over the last axis, so one-hot labels are expected.
        IFAcc: Whether to compute classification accuracy.

    Returns:
        ``(pred_probs, uncertainty, acc)``:
        ``pred_probs`` is the NumPy softmax array, ``uncertainty`` maps
        ``'TU'``/``'EU'``/``'AU'`` to per-sample arrays, and ``acc`` is a
        float (or ``None`` when ``IFAcc`` is false).
    """
    logits = np.asarray(model.predict(x_test))
    # NumPy/SciPy softmax keeps the pickled results free of TF tensor objects.
    pred_probs = scipy.special.softmax(logits, axis=-1)

    eps = 1e-12  # keeps log2 finite when a probability underflows to 0
    total_unc = -np.sum(pred_probs * np.log2(pred_probs + eps), axis=-1)

    # expected_entropy_pn is digamma-based (nats); convert to bits so it can
    # be combined with the log2-based total entropy above.
    data_unc = expected_entropy_pn(logits) / np.log(2)
    knowledge_unc = total_unc - data_unc

    uncertainty = {
        'TU': total_unc,
        'EU': knowledge_unc,
        'AU': data_unc
    }

    if IFAcc:
        # Same rule as tf.keras.metrics.CategoricalAccuracy: fraction of samples
        # whose predicted argmax matches the label argmax.
        hits = np.argmax(y_test, axis=-1) == np.argmax(pred_probs, axis=-1)
        acc = float(np.mean(hits))
    else:
        acc = None

    return pred_probs, uncertainty, acc

def ensembel_distillation_evaluation(dataset, num_models=15):
    """Evaluate each saved EDD model on one test set.

    Args:
        dataset: ``'CIFAR10'``, ``'SVHN'``, ``'TinyImage'``, or a CIFAR-10-C
            name of the form ``'CIFAR10C_<corruption>_<severity>'`` whose last
            two parts are integers passed to ``load_cifar_corruption``.
        num_models: Number of models to evaluate; model ``k`` is loaded with
            ``load_model(k)`` for ``k = 1 .. num_models`` (default 15).

    Returns:
        dict whose lists hold one entry per model: ``'pred'`` (softmax
        outputs) and ``'entropy'`` (the uncertainty dict from
        ``single_model_evaluate``). ``'acc'`` is included for CIFAR10 and CIFAR10C;
        ``'label'`` (the test labels) is included for CIFAR10 only.
        An unknown dataset name prints a message and returns an empty dict.

    Raises:
        ValueError: if a CIFAR10C name does not have exactly three
            ``_``-separated parts, or its last two parts are not integers.
    """
    # Load the test split before touching any model, so a bad dataset name
    # fails fast instead of after every model has been loaded.
    if dataset == 'CIFAR10':
        (_, _), (x_test, y_test) = load_cifar10()
        with_acc = True
    elif 'CIFAR10C' in dataset:
        parts = dataset.split('_')
        if len(parts) != 3:
            # Corruption type and severity are both required to load the data.
            raise ValueError(
                f"Invalid dataset string format: {dataset!r} "
                "(expected 'CIFAR10C_<corruption>_<severity>')"
            )
        cor_type, sev_level = int(parts[1]), int(parts[2])
        x_test, y_test = load_cifar_corruption(cor_type, sev_level)
        with_acc = True
    elif dataset == 'SVHN':
        x_test, y_test = load_svhn_test()
        with_acc = False
    elif dataset == 'TinyImage':
        x_test, y_test = load_tinyimage_test()
        with_acc = False
    else:
        print("Invalid Dataset Name. Try again...")
        return dict()

    eval_results = {'pred': [], 'entropy': []}
    if with_acc:
        eval_results['acc'] = []

    # Load and evaluate one model at a time instead of keeping all alive at once.
    for exp_num in range(1, num_models + 1):
        model = load_model(exp_num)
        pred, uncertainty, acc = single_model_evaluate(model, x_test, y_test, IFAcc=with_acc)
        eval_results['pred'].append(pred)
        eval_results['entropy'].append(uncertainty)
        if with_acc:
            eval_results['acc'].append(acc)

    if dataset == 'CIFAR10':
        eval_results['label'] = y_test

    return eval_results


def main():
    """Evaluate the configured dataset and pickle the results.

    Uses the module-level ``dataset_name``, prints the elapsed running time,
    and writes the results to
    ``Path_to_save_the_test_result/<dataset_name>_result``. Nothing is written
    when the evaluation returned no results (an unrecognised dataset name).
    """
    # perf_counter is monotonic, so the measurement is immune to clock changes.
    start_time = time.perf_counter()
    result = ensembel_distillation_evaluation(dataset_name)
    elapsed = time.perf_counter() - start_time
    print(f'Running Time: {elapsed}')

    if not result:
        # Avoid writing (or overwriting with) an empty results file.
        return

    full_path = 'Path_to_save_the_test_result/' + dataset_name

    # Save test history
    with open(full_path + '_result', 'wb') as file:
        pickle.dump(result, file)

if __name__ == '__main__':
    # Run the evaluation only when executed as a script.
    main()
