import numpy as np
import pickle
# import math
# import gurobipy as gp
from cleanup import student 

import argparse
import torch
import scipy.stats as stats
from sklearn.metrics import r2_score
from sklearn.preprocessing import PolynomialFeatures
import math 
import wandb 
from model_comparison import load_all_models_v2, get_true_value_of_allocation, generate_random_bundles, model_predict_values, measure_generalization_performance_all_students, keep_high_valued_bundles, calculate_allocation, measure_gui_language_metrics, compare_gui_reports
from tabu_search import binary_search, binary_search_hacky, construct_implied_dataset_v2, solve_student, create_actual_student_list, create_gui_student_list, create_iterative_student_list, project_mvnns
import wandb 
import random 
from acquisition_function_comparison import generate_new_queries, check_agreements_on_cq, get_ordinal_dataset_size
from preference_generator import generate_problem_instance_principled
import numpy as np
import random

# from pdb import set_trace


def generate_llm_cqs(number_of_cqs, method, capacities, previous_bundles, seed = None ): 
    """
    Build the list of comparison queries (CQs) that will be shown to the LLM.

    Each CQ is a pair ``(bundle1, bundle2)`` of two different bundles drawn from a
    pool of 500 random bundles over 25 courses. CQs from ``previous_bundles`` are
    kept at the front of the list and only the missing ones are newly sampled, so
    the CQ set grows monotonically across calls.

    Parameters:
    number_of_cqs: total number of CQs wanted (previous + new). If it is not larger
        than the number of previous CQs, no new CQ is added.
    method: sampling strategy; only 'random' is implemented.
    capacities: currently unused, kept for interface compatibility.
    previous_bundles: ``None`` or an object exposing ``.tolist()`` (e.g. the numpy
        array returned by ``simulate_llm_answers_cqs``) holding the CQs so far.
    seed: forwarded to ``generate_random_bundles`` only. The pairing itself uses the
        module-level ``random`` state, so seed that separately for reproducibility.

    Returns:
    A list of CQs (previous ones first, then the newly sampled pairs).

    Raises:
    ValueError: if ``method`` is not a supported sampling strategy.
    """
    # Fail loudly instead of silently returning None for an unknown strategy.
    if method != 'random':
        raise ValueError(f'Invalid method for generating the LLM CQs: {method}')

    # Candidate pool; built first so that any seeding done by the helper happens
    # before the pairs are drawn (same order as before).
    bundles = generate_random_bundles(number_of_bundles= 500, number_of_courses= 25, seed = seed)

    # Start from the CQs that were already asked (if any).
    bundle_tuples = [] if previous_bundles is None else previous_bundles.tolist()

    for _ in range(number_of_cqs - len(bundle_tuples)):
        bundle1 = random.choice(bundles)
        bundle2 = random.choice(bundles)
        # A CQ comparing a bundle with itself carries no information: redraw.
        while np.all(bundle1 == bundle2):
            bundle2 = random.choice(bundles)
        bundle_tuples.append((bundle1, bundle2))

    return bundle_tuples


def _sharpen_binary_probability(p, temperature):
    """Two-class MixMatch sharpening: p^(1/T) / (p^(1/T) + (1 - p)^(1/T))."""
    if temperature == 1:
        return p  # identity; skip the arithmetic so the value is returned bit-for-bit

    # Equivalent odds form 1 / (1 + ((1 - p) / p)^(1/T)): it cannot produce 0/0 and
    # maps p = 0 -> 0 and p = 1 -> 1 (through inf / 0 intermediates).
    with np.errstate(divide='ignore', over='ignore'):
        return 1.0 / (1.0 + ((1.0 - p) / p) ** (1.0 / temperature))


def simulate_llm_answers_cqs(bundle_tuples, model_type, model_param_dictionary, 
            benchmark_model_type, benchmark_student, model_student_list,
            llm_simulation_dictionary = None
            ): 
    """
    Simulate the answers of an LLM to the given comparison queries (CQs).

    For every CQ the benchmark student's values of both bundles are computed with
    ``model_predict_values`` and turned into a (noisy) label.

    Parameters:
    bundle_tuples: iterable of ``(bundle1, bundle2)`` pairs.
    model_type, model_param_dictionary, model_student_list: currently unused, kept
        for interface compatibility.
    benchmark_model_type, benchmark_student: the model used as ground truth.
    llm_simulation_dictionary: simulation settings. ``None`` means
        ``{'reply_encoding': 'hard', 'hard_labels_flip_probability': 0.1}``.
        - 'reply_encoding' == 'hard': label is 1 if bundle1 is strictly more valuable
          than bundle2 and 0 otherwise, flipped with probability
          'hard_labels_flip_probability'.
        - 'reply_encoding' == 'soft': label is sigmoid(value difference) plus Gaussian
          noise with std 'soft_labels_noise_std', clipped to [0, 1] and sharpened with
          temperature 'temperature-smoothening' (MixMatch-style, normalised over the
          two outcomes).

    Returns: 
    ``(np.array(bundle_tuples), np.array(labels))``: the CQs and, per CQ, the label /
    probability with which the LLM says that the first bundle is preferred to the
    second one (the X_train_ord, y_train_ord used to train the TL MVNN).

    Raises:
    ValueError: for an unknown reply encoding or a non-positive temperature.
    """
    if llm_simulation_dictionary is None:
        llm_simulation_dictionary = {'reply_encoding': 'hard', 'hard_labels_flip_probability': 0.1}

    # Validate the configuration up front so that a bad setting is reported even
    # when there are no CQs to answer yet.
    reply_encoding = llm_simulation_dictionary['reply_encoding']
    if reply_encoding == 'hard':
        flip_probability = llm_simulation_dictionary['hard_labels_flip_probability']
    elif reply_encoding == 'soft':
        temperature = llm_simulation_dictionary['temperature-smoothening']
        noise_std = llm_simulation_dictionary['soft_labels_noise_std']
        if temperature <= 0:
            raise ValueError('The temperature for smoothening the LLM CQ labels must be positive')
    else:
        raise ValueError('Invalid reply encoding for the LLM CQs')

    labels = []
    for (bundle1, bundle2) in bundle_tuples:

        # get the true value of the bundles involved in each CQ
        true_values = model_predict_values(model_type= benchmark_model_type, model_param_dictionary= None, model= benchmark_student, bundles= [bundle1, bundle2])
        value_diff = true_values[0] - true_values[1]

        if reply_encoding == 'hard':
            # ties count as "second bundle preferred" (label 0)
            true_label = 1 if value_diff > 0 else 0

            # flip the label with the given probability
            if random.random() < flip_probability:
                true_label = 1 - true_label
            labels.append(true_label)

        else:  # 'soft'
            # the *true* probability that the first bundle is preferred to the second
            # bundle according to a perfectly tuned MVNN
            p = 1 / (1 + np.exp(-value_diff))

            # add noise to the probability and bring it back into [0, 1]
            p = p + np.random.normal(0, noise_std)
            p = np.clip(p, 0, 1)

            # smoothen/sharpen the labels in the same way as the MixMatch paper
            labels.append(_sharpen_binary_probability(p, temperature))

    return np.array(bundle_tuples), np.array(labels)   # those are the X_train_ord, y_train_ord for training the TL MVNN 


def _as_bool(value):
    """Interpret a CLI string flag: only the (case-insensitive) literal 'true' is True."""
    return str(value).lower() == 'true'


def _build_argument_parser():
    """Create the command-line parser of the experiment (names, types and defaults)."""
    parser = argparse.ArgumentParser()

    # --- Mechanism parameters, no need to change --- #
    parser.add_argument('--supply_ratio', type=float, default= 1.25, help='supply_ratio')
    parser.add_argument('--number_of_popular', type= int, default= 9, help='afffects_correlation')
    parser.add_argument('--model_family', type= str, default= 'BPpop9v.1', help='the models to run') # options: [PO/B]pop9_projected_v1.1_from_scratch_gui_reports, 'BPpop9v.1' and 'POpop9'
    parser.add_argument('--linear_instances', type= str, default= 'false', help='whether we are in linear instances or not')
    parser.add_argument('--seed', type= int, default= 627, help='the seed to use for the random number generator')
    parser.add_argument('--cq_method', type= str, default= 'basic_plus', help='acquisition function to use') # options: 'basic', 'random', 'basic_plus', 'complete_ordering', 'complete_ordering_pruned'
    parser.add_argument('--wandb_tracking', type= str, default= 'false', help='whether to track the results with wandb')
    parser.add_argument('--TL_model', type= str, default= 'true', help='whether to use transfer learning model')
    parser.add_argument('--project_to_gui', type= str, default= 'false', help='whether to log details about the MVNNs projected to the GUI language')
    parser.add_argument('--load_students', type= str, default= 'false', help='whether to load the students from the file, if not, they will be created')
    parser.add_argument('--student_to_check', type= int, default= 42, help='the student that this run should check')

    # --- arguments for the generation of LLM CQs --- #
    parser.add_argument('--llm_cq_method', type= str, default= 'random', help='the method to use for generating the LLM CQs')
    parser.add_argument('--llm_cq_labels', type= str, default= 'hard', help='the method to use for generating the LLM CQ labels: options: hard, soft')  
    parser.add_argument('--llm_hard_labels_flip_probability', type= float, default= 0.2, help='the probability of flipping the label in the hard label case')
    parser.add_argument('--llm_soft_labels_noise_std', type= float, default= 0.00, help='the std of the noise to add to the soft labels')
    parser.add_argument('--llm_temperature_smoothening', type= float, default= 1, help='the temperature for the smoothening of the labels in the soft label case')

    # --- arguments for learning on those LLM CQs --- #
    parser.add_argument('--llm_cq_loss_string', type= str, default= 'BCE', help='the loss function to use for the LLM CQs')  # should add robust to noise versions 
    parser.add_argument('--llm_cq_epochs', type= int, default= 10, help='the number of epochs to use for the LLM CQs')
    parser.add_argument('--llm_cq_lr', type= float, default= 0.01, help='the learning rate to use for the LLM CQs')
    parser.add_argument('--llm_cq_wd', type= float, default= 0.00, help='the regularization to use for the LLM CQs')
    parser.add_argument('--llm_cq_clip', type= float, default= 0.1, help='the gradient clipping to use for the LLM CQs')
    parser.add_argument('--llm_cq_batch_size', type= int, default= 8, help='the batch size to use for the LLM CQs')

    parser.add_argument('--llm_cq_trainsize', type= int, default= 5000, help='the number of training samples to use for the LLM CQs')

    return parser


def _build_transfer_learning_model_info(lr_ordinal=0.01, weight_decay_ordinal=0.0, epochs_ordinal=10,
                                        batch_size_ordinal=8, loss_string_ordinal='BCE', clip_ordinal=0.1):
    """Return the ('UNN_transfer_learning', hyper-parameter dict) pair; only the ordinal (CQ) HPs vary."""
    return ('UNN_transfer_learning', {
        'samples': 0, 'samples_in_range': 0, 'range_min_value': 350,
        'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09,
        'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
        'gui_forget_base': 0.5, 'gui_forget_adjustments': 0.48, 'gui_base_noise_std': 23, 'gui_adj_std': 0.2,
        'gui_sample_categories_weights': [1, 1, 1],
        'sample_relative_frequencies': (1, 1, 1, 1),
        'points_to_add': 40, 'points_to_hallucinate': 20,
        'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        'use_implied_dataset': True, 'use_cqs': True, 'cq_method': 'complete_ordering_pruned',
        'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8,
        'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1,
        # --- HPs that refer to the (used to be student, now LLM) CQs --- #
        'lr_ordinal': lr_ordinal, 'weight_decay_ordinal': weight_decay_ordinal, 'epochs_ordinal': epochs_ordinal,
        'batch_size_ordinal': batch_size_ordinal, 'UNN_loss_string_ordinal': loss_string_ordinal,
        'clip_ordinal': clip_ordinal,
    })


def _define_wandb_metrics(project_to_gui):
    """Register every logged metric with wandb, using "Elicited CQs" as the x-axis."""
    step = "Elicited CQs"
    subsets = ("", " top 5", " top 10")
    wandb.define_metric(step)

    # Standard learning metrics
    for subset in subsets:
        for metric in ("KT", "R2", "MAE", "MSE"):
            wandb.define_metric(f"{metric}{subset}", step_metric=step)

    # Centered versions of the above (no centered KT is logged)
    for subset in subsets:
        for metric in ("R2", "MAE", "MSE"):
            wandb.define_metric(f"{metric}{subset} - Centered", step_metric=step)

    # GUI metrics
    if project_to_gui:
        for name in ("Base Values", "Adjustments"):
            wandb.define_metric(name, step_metric=step)
        for subset in subsets:
            for metric in ("KT", "R2", "MAE", "MSE"):
                wandb.define_metric(f"{metric}{subset} - GUI", step_metric=step)

    # Other metrics ("Found New Max" is only useful for the basic plus method)
    for name in ("Ordinal Dataset Size", "Agreement on CQs", "Allocated Bundle Value", "Found New Max"):
        wandb.define_metric(name, step_metric=step)


def main_function():
    """
    Run the LLM-CQ experiment for one student.

    Pipeline: parse the CLI arguments, load or generate the problem instance, build
    the initial training set from the (simulated) GUI reports, then repeatedly
    (i) grow the set of LLM comparison queries, (ii) simulate the LLM's answers,
    (iii) retrain the student's model and (iv) measure allocation value and
    generalization metrics, which are printed, stored in ``result_dict`` and, if
    requested, logged to wandb.

    Raises:
    ValueError: for flag combinations for which no model / benchmark configuration
        is defined (these used to fail later with a NameError).
    """
    args = _build_argument_parser().parse_args()

    # --- Environment-related arguments --- #
    supply_ratio = args.supply_ratio
    number_of_popular = args.number_of_popular
    index = args.student_to_check // 100          # which instance to load
    student_index = args.student_to_check % 100   # which student inside that instance
    model_family = args.model_family
    queries = '30-10'  # only needed to load the dictionaries, does not play any role
    linear_instances = _as_bool(args.linear_instances)
    seed = args.seed
    wandb_tracking = _as_bool(args.wandb_tracking)
    tl_model = _as_bool(args.TL_model)
    project_to_gui = _as_bool(args.project_to_gui)
    load_students = _as_bool(args.load_students)

    # --- LLM CQs-related arguments --- #
    llm_cq_method = args.llm_cq_method
    llm_cq_labels = args.llm_cq_labels
    llm_hard_labels_flip_probability = args.llm_hard_labels_flip_probability
    llm_soft_labels_noise_std = args.llm_soft_labels_noise_std
    llm_temperature_smoothening = args.llm_temperature_smoothening
    llm_cq_loss_string = args.llm_cq_loss_string
    llm_cq_epochs = args.llm_cq_epochs
    llm_cq_lr = args.llm_cq_lr
    llm_cq_wd = args.llm_cq_wd
    llm_cq_clip = args.llm_cq_clip
    llm_cq_batch_size = args.llm_cq_batch_size
    llm_cq_trainsize = args.llm_cq_trainsize

    # --- Model-related arguments --- #
    students_to_check = 1

    # Freshly generated instances only define the benchmark model for the non-linear
    # case, so reject the unsupported combination before doing any work.
    if linear_instances and not load_students:
        raise ValueError('Linear instances are only supported when the students are loaded from file '
                         '(use --load_students true).')

    if tl_model:
        # Transfer-learning model whose ordinal (CQ) hyper-parameters come from the CLI.
        model_info = _build_transfer_learning_model_info(
            lr_ordinal=llm_cq_lr, weight_decay_ordinal=llm_cq_wd, epochs_ordinal=llm_cq_epochs,
            batch_size_ordinal=llm_cq_batch_size, loss_string_ordinal=llm_cq_loss_string,
            clip_ordinal=llm_cq_clip)
    elif not load_students:
        # Fallback configuration used for freshly generated instances: same model with
        # the default ordinal hyper-parameters.
        model_info = _build_transfer_learning_model_info()
    else:
        raise ValueError('No model configuration is defined for --TL_model false together with '
                         '--load_students true.')

    true_family = 'true_linear' if linear_instances else 'true'

    result_dict = {} # will save the exact same details as the wandb tracking, but in a dictionary that we can save to a pickle file

    if wandb_tracking:
        wandb_config_dict = {
            'Supply Ratio': supply_ratio,
            'Number of Popular': number_of_popular,
            'Linear Instances': linear_instances,
            'Student Number': args.student_to_check,
            'Acquisition Function': args.cq_method,
            'Model Type': model_info[0],
            'llm_cq_method': llm_cq_method,
            'llm_cq_labels': llm_cq_labels,
            'llm_hard_labels_flip_probability': llm_hard_labels_flip_probability,
            'llm_soft_labels_noise_std': llm_soft_labels_noise_std,
            'llm_temperature_smoothening': llm_temperature_smoothening,
            'llm_cq_loss_string': llm_cq_loss_string,
            'llm_cq_epochs': llm_cq_epochs,
            'llm_cq_lr': llm_cq_lr,
            'llm_cq_wd': llm_cq_wd,
            'llm_cq_clip': llm_cq_clip,
            'llm_cq_batch_size': llm_cq_batch_size,
            'llm_cq_trainsize': llm_cq_trainsize,
        }

        run = wandb.init(project='HPO-MLCM-LLMs-v1.0', # TODO: change to appropriate name
                    config=wandb_config_dict,
                    reinit=True)

        _define_wandb_metrics(project_to_gui)

    # --- Step 1. Load true and GUI (noisy) student lists, as well as prices --- #
    if load_students:
        print(f'supply_ratio: {supply_ratio} number_of_popular: {number_of_popular} index:{index} model_family: {model_family} queries: {queries}')
        # The model description stored in the file is ignored; the one built above is used.
        (_, true_model_info, model_student_list, true_student_list, prices_stage_1_GUI, timetable, capacities) = load_all_models_v2(
            index=index, model_family=model_family, benchmark_family=true_family, queries=queries,
            supply_ratio=supply_ratio, number_of_popular=number_of_popular, linear_instances=linear_instances)
    else:
        # create the model instance from scratch
        true_student_lists_all_instances, capacities_all_instances, timetables_all_instances = generate_problem_instance_principled(
            number_of_times=1, number_of_students=100, number_of_courses=25,
            supply_ratio=supply_ratio, number_of_popular=number_of_popular, seed=seed + student_index)

        true_student_list = true_student_lists_all_instances[0]
        capacities = capacities_all_instances[0]
        timetable = timetables_all_instances[0]

        prices_stage_1_GUI = np.array([0.2 for _ in range(25)])  # just a price vector so that each student can take 5 courses

        # Only non-linear instances reach this point (see the check above).
        true_model_info = ('True', None)

    model_type = model_info[0]
    model_param_dictionary = model_info[1]   # same object as model_info[1]: updates below are shared
    number_of_courses = capacities.shape[0]
    model_param_dictionary['cq_method'] = args.cq_method

    # add extra information to the model_param_dictionary for projection.
    projection_dictionary = {'project_to_gui': True, 'approximate_prices_model': 'gui_reports', 'proj_alpha': 0.025,
                             'ridge': 0.0, 'fit_intercept': False, 'train_samples': 500, 'train_high_samples': 500,
                             'linear_projection': linear_instances}
    model_param_dictionary.update(projection_dictionary)

    # keep only the student(s) that this run should check
    true_student_list = true_student_list[student_index: student_index + students_to_check]

    # set the seeds for reproducibility.
    torch.manual_seed(seed + student_index)
    np.random.seed(seed + student_index)
    random.seed(seed + student_index)

    # --- Step 2. Generate the bundles that we will use for the generalization test --- #
    bundles = generate_random_bundles(number_of_bundles= 2000, number_of_courses= 25, seed = seed)
    bundles_5_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 95, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])
    bundles_10_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 90, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])

    # --- Step 3. Generate the 'initial' training dataset based on the GUI reports --- #
    credit_units = [1 for _ in range(number_of_courses)]

    actual_student_list = create_actual_student_list(true_student_list, model_type = model_info[0], model_param_dictionary = model_param_dictionary, seed = seed)  # just returns the true list: NOTE: should remove
    gui_student_list = create_gui_student_list(true_student_list, model_type = model_type, model_param_dictionary = model_param_dictionary, seed = seed)

    training_set = generate_new_queries(actual_student_list = actual_student_list, timetable = timetable, number_of_samples = 0,  # number of samples is not used when generating the original dataset. Instead, it is read from the model_dictionary
                    current_training_set = [], model_student_list = [], approximate_prices = None,
                    credit_units = list(credit_units), model_type = model_type, seed= seed + args.student_to_check, model_param_dictionary = model_param_dictionary, gui_student_list = gui_student_list)

    # --- Step 3.5  Generate initial student list based on the initial training set   --- #
    model_student_list = create_iterative_student_list(training_set, actual_student_list, credit_units = list(credit_units),
                            timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary,
                            model_student_list= None)

    # Checkpoints (total number of LLM CQs) at which the model is retrained and evaluated:
    # steps of 5 up to 195, of 10 up to 490, of 50 up to 1950 and of 100 up to 5000 ...
    elicited_cq_list = [5 * i for i in range(40)]
    elicited_cq_list = elicited_cq_list + [200 + 10 * i for i in range(30)]
    elicited_cq_list = elicited_cq_list + [500 + 50 * i for i in range(30)]
    elicited_cq_list = elicited_cq_list + [2000 + 100 * i for i in range(31)]

    # ... capped at llm_cq_trainsize
    elicited_cq_list = [i for i in elicited_cq_list if i <= llm_cq_trainsize]

    print('elicited_cq_list:', elicited_cq_list)
    cq_bundle_tuples = None

    llm_simulation_dictionary = {'reply_encoding': llm_cq_labels, 'hard_labels_flip_probability': llm_hard_labels_flip_probability,
                                 'soft_labels_noise_std': llm_soft_labels_noise_std, 'temperature-smoothening': llm_temperature_smoothening}

    for elicited_llm_cqs in elicited_cq_list:
        print(f'--------> Number of LLM CQs: {elicited_llm_cqs}')

        # --- Step 4: Generate the LLM CQs and get their answers --- #
        # NOTE(review): the answers to *all* CQs (including the ones asked in earlier
        # rounds) are re-simulated here, so noisy labels of old CQs may change between
        # rounds -- confirm that this is intended.
        for i in range(students_to_check):
            cq_bundle_tuples = generate_llm_cqs(number_of_cqs= elicited_llm_cqs, method= llm_cq_method, capacities= capacities,
                                            previous_bundles= cq_bundle_tuples, seed = seed + args.student_to_check)

            cq_bundle_tuples, labels = simulate_llm_answers_cqs(bundle_tuples= cq_bundle_tuples, model_type= model_info[0],
                model_param_dictionary= model_info[1], benchmark_model_type= true_model_info[0], benchmark_student= true_student_list[i], model_student_list= model_student_list,
                llm_simulation_dictionary = dict(llm_simulation_dictionary))

            # Add those to the training set
            training_set[i][1] = (cq_bundle_tuples, labels)

        # Step 4b. Train the models on the current dataset
        model_student_list = create_iterative_student_list(training_set, actual_student_list, credit_units = list(credit_units),
                        timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary,
                        model_student_list= model_student_list)

        # Step 5. Measure allocation value of the model
        allocation_model = calculate_allocation(model_list= model_student_list, prices= prices_stage_1_GUI, timetable= timetable, models_to_run= [model_info])
        allocation_value_model = get_true_value_of_allocation(benchmark_student_list= true_student_list, individual_demands= allocation_model, benchmark_model= true_model_info[0], timetable = timetable)

        # Step 6. Measure generalization performance on all bundles and on the
        # top 5% / top 10% bundles of each student: first the plain metrics, then the
        # centered ones (same call order as the metric keys below).
        number_of_students = len(true_student_list)
        bundle_subsets = [
            ('', [bundles for _ in range(number_of_students)]),
            (' top 5', [bundles_5_percentile[i] for i in range(number_of_students)]),
            (' top 10', [bundles_10_percentile[i] for i in range(number_of_students)]),
        ]

        wandb_dict = {'Elicited CQs': elicited_llm_cqs}
        for centered in (False, True):
            # the keyword is only passed for the centered metrics, exactly as before
            extra_kwargs = {'centered_metrics': True} if centered else {}
            for subset_name, subset_bundles in bundle_subsets:
                kts, r2s, maes, mses = measure_generalization_performance_all_students(
                    bundles_all_students= list(subset_bundles), model_type= model_info[0],
                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list,
                    benchmark_model_type= true_model_info[0], model_student_list= model_student_list, **extra_kwargs)

                if centered:
                    wandb_dict[f'R2{subset_name} - Centered'] = np.mean(r2s)
                    wandb_dict[f'MAE{subset_name} - Centered'] = np.mean(maes)
                    wandb_dict[f'MSE{subset_name} - Centered'] = np.mean(mses)
                else:
                    wandb_dict[f'KT{subset_name}'] = np.mean(kts)
                    wandb_dict[f'R2{subset_name}'] = np.mean(r2s)
                    wandb_dict[f'MAE{subset_name}'] = np.mean(maes)
                    wandb_dict[f'MSE{subset_name}'] = np.mean(mses)

        wandb_dict['Allocated Bundle Value'] = np.mean(allocation_value_model)

        ordinal_info_size = get_ordinal_dataset_size(training_dataset= training_set, model_param_dictionary= model_param_dictionary)

        # Step 7. Project the MVNNs to measure results in the GUI language
        if project_to_gui:
            projected_mvnns = project_mvnns(mvnn_student_list= model_student_list, model_param_dictionary= model_param_dictionary)

            # Measure the number of base values and adjustments
            base_values_model, adjustments_model = measure_gui_language_metrics(model_list = projected_mvnns, model_type = 'UNN_projected')

            # Measure similarity between projected ML models/GUI reports and the true (additive) preferences
            gui_language_kts_model_top10, gui_language_r2s_model_top10, gui_language_maes_model_top10, gui_language_mses_model_top10 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 60)
            gui_language_kts_model, gui_language_r2s_model, gui_language_maes_model, gui_language_mses_model = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances)
            gui_language_kts_model_top5, gui_language_r2s_model_top5, gui_language_maes_model_top5, gui_language_mses_model_top5 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 80)

            wandb_dict.update({
                'Base Values': base_values_model,
                'Adjustments': adjustments_model,
                'KT - GUI': gui_language_kts_model,
                'R2 - GUI': gui_language_r2s_model,
                'MAE - GUI': gui_language_maes_model,
                'MSE - GUI': gui_language_mses_model,
                'KT top 5 - GUI': gui_language_kts_model_top5,
                'R2 top 5 - GUI': gui_language_r2s_model_top5,
                'MAE top 5 - GUI': gui_language_maes_model_top5,
                'MSE top 5 - GUI': gui_language_mses_model_top5,
                'KT top 10 - GUI': gui_language_kts_model_top10,
                'R2 top 10 - GUI': gui_language_r2s_model_top10,
                'MAE top 10 - GUI': gui_language_maes_model_top10,
                'MSE top 10 - GUI': gui_language_mses_model_top10,
            })

        result_dict[elicited_llm_cqs] = wandb_dict

        if wandb_tracking:
            wandb.log(wandb_dict)

        print(f'--------> Number of LLM CQs: {elicited_llm_cqs}, ordinal_info_size: {ordinal_info_size}, allocation values: {allocation_value_model}')

    if wandb_tracking:
        run.finish()

    # save the results to a pickle file
    # dict_name = f'./acquisition_function_results/SR_{supply_ratio}_NP_{number_of_popular}_LI_{linear_instances}_AF_{args.cq_method}_student_{args.student_to_check}.pkl'
    # open the file for writing, and create it if it does not exist
    # with open(dict_name, 'wb+') as f:
    #     pickle.dump(result_dict, f)

    return 

# Script entry point: run the experiment only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main_function()