import pickle
import numpy as np
import argparse
import yaml
import time

def load_config(yaml_file):
    """Read a YAML configuration file and return its parsed contents.

    Parsing uses PyYAML's safe loader, so only plain YAML types (mappings,
    lists, scalars) are constructed.
    """
    with open(yaml_file) as handle:
        return yaml.safe_load(handle)

def load_lower_upper_probs(full_path, num):
    """Unpickle the lower/upper probability results stored for ``num``.

    The file read is ``<full_path>_lower_upper_probs<num>``.
    NOTE(review): ``pickle.load`` executes arbitrary code on malicious input;
    only point this at result files produced by this project.
    """
    pickle_path = full_path + '_lower_upper_probs' + str(num)
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)


def entropy_evaluate(intersection_probs):
    """Return the Shannon entropy, in bits, along the last axis.

    A small epsilon is added inside the logarithm so that zero-probability
    entries contribute 0 instead of producing ``0 * -inf = nan``.
    """
    eps = 1e-12
    log_terms = np.log2(intersection_probs + eps)
    return -np.sum(intersection_probs * log_terms, axis=-1)

def compute_intersection_probability(upper_probs, lower_probs):
    """Return the intersection probability of lower/upper probability bounds.

    Along the last axis this computes
        p = lower + alpha * (upper - lower),
        alpha = (1 - sum(lower)) / sum(upper - lower),
    i.e. every interval is widened from its lower end by the same fraction
    ``alpha``, chosen so that ``p`` sums to 1.

    Args:
        upper_probs: array-like of upper bounds; last axis indexes classes.
        lower_probs: array-like of lower bounds, broadcast-compatible with
            ``upper_probs``.

    Returns:
        ndarray of the same broadcast shape as the inputs.

    When ``upper == lower`` along the last axis, the total width is zero.
    ``alpha`` is then taken as 0 instead of dividing by zero (which produced
    NaN/inf). The result for that slice is exactly ``lower_probs``, since the
    width term vanishes regardless of ``alpha``.
    """
    upper_probs = np.asarray(upper_probs)
    lower_probs = np.asarray(lower_probs)

    width = upper_probs - lower_probs
    alpha_num = 1.0 - np.sum(lower_probs, axis=-1, keepdims=True)
    alpha_denom = np.sum(width, axis=-1, keepdims=True)

    # Only divide where the total width is non-zero; elsewhere alpha stays 0.
    alpha = np.divide(alpha_num, alpha_denom,
                      out=np.zeros_like(alpha_num),
                      where=alpha_denom != 0)

    return width * alpha + 1.0 * lower_probs
    

def lower_upper_probs_evaluation(dataset, num, base_dir='ResNetTest30_CRE/', num_entries=15):
    """Compute intersection probabilities and their entropies for stored bounds.

    Loads the pickle written for ``dataset``/``num`` via
    ``load_lower_upper_probs(base_dir + dataset, num)``. That pickle must be
    a mapping with ``'lower probs'`` and ``'upper probs'`` sequences. For each
    of the first ``num_entries`` entries it computes the intersection
    probability and its entropy.

    Args:
        dataset: dataset name, appended to ``base_dir`` to form the path prefix.
        num: suffix passed to ``load_lower_upper_probs`` (``main`` passes the
            ensemble size here).
        base_dir: directory prefix of the stored results; the default keeps
            the previously hard-coded location.
        num_entries: how many leading entries to process. The default of 15
            keeps the previous hard-coded count; ``None`` processes every
            stored entry.

    Returns:
        dict with keys ``'prob'`` and ``'entropy'``, each a list with one
        element per processed entry.

    Raises:
        IndexError: if ``num_entries`` exceeds the number of stored entries.
    """
    full_path = base_dir + dataset
    stored = load_lower_upper_probs(full_path, num)

    lower_probs = np.stack(stored['lower probs'])
    upper_probs = np.stack(stored['upper probs'])

    available = min(len(lower_probs), len(upper_probs))
    if num_entries is None:
        num_entries = available
    elif num_entries > available:
        # Fail up front with a clear message instead of part-way through the loop.
        raise IndexError(
            f'requested {num_entries} entries but {full_path!r} '
            f'(num={num}) only holds {available}'
        )

    intsec_pred = {'prob': [], 'entropy': []}
    for j in range(num_entries):
        prob_intsec = compute_intersection_probability(upper_probs[j], lower_probs[j])
        intsec_pred['prob'].append(prob_intsec)
        intsec_pred['entropy'].append(entropy_evaluate(prob_intsec))

    return intsec_pred

def main():
    """Compute and pickle intersection probabilities for each ensemble size.

    Reads the dataset name (key ``'Dataset'``) from the YAML file given as the
    single command-line argument. For each ensemble size it evaluates the
    stored lower/upper bounds and writes the resulting dict to
    ``TestResNet30C_CRE/intersection_probs/<dataset>_intersec_probab<size>``.
    """
    import os

    # Accept a YAML file as a command-line argument
    parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
    parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
    args = parser.parse_args()

    config = load_config(args.config_file)
    dataset_name = config['Dataset']

    out_dir = 'TestResNet30C_CRE/intersection_probs/'
    # Without this, the first open(..., 'wb') fails if the directory is missing.
    os.makedirs(out_dir, exist_ok=True)
    full_path = out_dir + dataset_name

    ensemble_sizes = (3, 5, 10, 15, 20, 25)
    for num in ensemble_sizes:
        intsec_results = lower_upper_probs_evaluation(dataset_name, num)
        with open(full_path + '_intersec_probab' + str(num), 'wb') as out_file:
            pickle.dump(intsec_results, out_file)

if __name__ == '__main__':
    # Run the batch job only when executed as a script, not when imported.
    main()

    