import pickle
import numpy as np
import argparse
import yaml
from test_core.min_max_entropy_mod import min_max_entropy_calculation
import time
import multiprocessing


def load_single_model_pred_result(full_path, num):
    """Return the 'pred' entry of a pickled single-model result file.

    Args:
        full_path: Path prefix of the result file; '_result' and ``num`` are
            appended to it to form the file name.
        num: Suffix of the file name, converted with ``str``.

    Returns:
        Whatever object is stored under the 'pred' key of the unpickled data.

    Note:
        ``pickle.load`` can execute arbitrary code while loading, so this
        should only be pointed at result files from a trusted source.
    """
    result_file = full_path + '_result' + str(num)
    with open(result_file, 'rb') as handle:
        return pickle.load(handle)['pred']

def credal_uncertainty(pred, pred_avg):
    """Delegate to ``min_max_entropy_calculation`` and return its result.

    Args:
        pred: First argument forwarded to ``min_max_entropy_calculation``.
        pred_avg: Second argument forwarded to ``min_max_entropy_calculation``.

    Returns:
        The value produced by ``min_max_entropy_calculation(pred, pred_avg)``.
    """
    return min_max_entropy_calculation(pred, pred_avg)
    
def load_eval_result(full_path, num):
    """Unpickle and return the lower/upper probability file for ``num``.

    Args:
        full_path: Path prefix of the file; '_lower_upper_probs' and ``num``
            are appended to it to form the file name.
        num: Suffix of the file name, converted with ``str``.

    Returns:
        The unpickled object, unchanged.

    Note:
        ``pickle.load`` can execute arbitrary code while loading, so this
        should only be pointed at files from a trusted source.
    """
    probs_file = full_path + '_lower_upper_probs' + str(num)
    with open(probs_file, 'rb') as handle:
        return pickle.load(handle)
    
    
def credal_evaluation(dataset, num):
    """Compute the credal uncertainty for the first 15 stored entries.

    Reads the lower/upper probabilities from
    'ResNetTest30_CRE/<dataset>_lower_upper_probs<num>' and the single-model
    predictions from 'ResNetTest30/<dataset>_result<num>', then evaluates
    ``credal_uncertainty`` once per leading index 0..14.

    Args:
        dataset: Dataset name used to build both file prefixes.
        num: File-name suffix passed on to the two loaders.

    Returns:
        A list with 15 entries, one ``credal_uncertainty`` result per index.
    """
    cre_prefix = 'ResNetTest30_CRE/' + dataset
    avg_prefix = 'ResNetTest30/' + dataset

    bounds = load_eval_result(cre_prefix, num)
    member_preds = np.stack(load_single_model_pred_result(avg_prefix, num))

    lower = np.stack(bounds['lower probs'])
    upper = np.stack(bounds['upper probs'])

    # NOTE(review): 15 is hard-coded; presumably the number of runs stored in
    # the files -- confirm before reusing with other data.
    return [
        credal_uncertainty(
            # Lower and upper bounds are joined along the last axis.
            np.concatenate((lower[run], upper[run]), axis=-1),
            # Mean over axis 0 -- presumably the ensemble-member axis; verify.
            np.mean(member_preds[run], axis=0),
        )
        for run in range(15)
    ]


def load_config(yaml_file):
    """Parse a YAML file with the safe loader and return its contents.

    Args:
        yaml_file: Path to the YAML file to read.

    Returns:
        The Python object the YAML document deserializes to.
    """
    with open(yaml_file, 'r') as stream:
        # safe_load is shorthand for load(stream, Loader=yaml.SafeLoader).
        return yaml.safe_load(stream)

def calculate_credal_evaluation(params):
    """Unpack a two-item argument bundle and run ``credal_evaluation``.

    Takes a single argument so it can be used with ``Pool.map``.

    Args:
        params: A two-item iterable ``(dataset_name, num)``.

    Returns:
        The result of ``credal_evaluation(dataset_name, num)``.
    """
    dataset, ensemble_size = params
    return credal_evaluation(dataset, ensemble_size)
    
def main():
    """Run the credal evaluation for each ensemble size and pickle the results.

    Reads the dataset name from the 'Dataset' key of the YAML file given on
    the command line, evaluates ``credal_evaluation`` for every entry of
    ``ens``, and writes each result to
    'ResNetTest30_CRE/entropy/<dataset>_uncertainty_mod<num>'.

    Each (dataset, ensemble size) pair is submitted to the pool exactly once.
    The previous version submitted the same pair once per CPU core and kept
    only the first result, multiplying run time and memory for no benefit.
    """
    import os  # local import: only needed to create the output directory

    # Accept a YAML file as a command-line argument.
    parser = argparse.ArgumentParser(description='Process parameters from a YAML file.')
    parser.add_argument('config_file', type=str, help='Path to the YAML configuration file')
    args = parser.parse_args()

    config = load_config(args.config_file)

    # Access hyperparameters from the loaded configuration.
    dataset_name = config['Dataset']

    # Ensemble sizes to evaluate; each one becomes an independent task.
    ens = [25]
    params_list = [(dataset_name, num) for num in ens]

    # More workers than tasks would only sit idle, so cap the pool size.
    num_processes = max(1, min(multiprocessing.cpu_count(), len(params_list)))
    print('num_processes: ', num_processes)

    # The context manager releases the workers even if a task raises.
    with multiprocessing.Pool(processes=num_processes) as pool:
        results = pool.map(calculate_credal_evaluation, params_list)

    # Create the output directory up front so a finished computation is not
    # lost to a missing folder at write time.
    output_dir = 'ResNetTest30_CRE/' + 'entropy/'
    os.makedirs(output_dir, exist_ok=True)

    for num, result in zip(ens, results):
        full_path = output_dir + dataset_name
        with open(full_path + '_uncertainty_mod' + str(num), 'wb') as file:
            pickle.dump(result, file)

# Entry-point guard: run main() only when this file is executed as a script.
# main() creates a multiprocessing.Pool; with the 'spawn' start method the
# worker processes re-import this module, and the guard keeps them from
# running main() again.
if __name__ == "__main__":
    main()
