import numpy as np
import pickle
# import math
# import gurobipy as gp
from cleanup import student 

import argparse
import torch
import scipy.stats as stats
from sklearn.metrics import r2_score
from sklearn.preprocessing import PolynomialFeatures
import math 
import wandb 
from model_comparison import load_all_models_v2, get_true_value_of_allocation, generate_random_bundles, model_predict_values, measure_generalization_performance_all_students, keep_high_valued_bundles, calculate_allocation, measure_gui_language_metrics, compare_gui_reports
from tabu_search import binary_search, binary_search_hacky, construct_implied_dataset_v2, solve_student, create_actual_student_list, create_gui_student_list, create_iterative_student_list, project_mvnns
import wandb 
import random 
from acquisition_function_comparison import generate_new_queries, check_agreements_on_cq, get_ordinal_dataset_size
from preference_generator import generate_problem_instance_principled
import numpy as np
import random

from pdb import set_trace


def _choose_bundle_different_from(bundles, reference_bundle):
    """
    Pick a uniformly random bundle from `bundles` that differs from `reference_bundle`.

    Used as the fallback when an acquisition function keeps proposing the same bundle twice.
    If every candidate equals `reference_bundle` (no different bundle exists), an unconstrained
    random candidate is returned so that the caller always receives a bundle.
    """
    different_bundles = [candidate for candidate in bundles if not np.all(candidate == reference_bundle)]
    return random.choice(different_bundles if different_bundles else bundles)


def generate_llm_cqs(number_of_cqs, acquisition_function, capacities, number_of_candidate_bundles,
                     model_type, model_param_dictionary, model_list,
                     boltzmann_temperature, max_attempts = 20,
                     seed = None):
    """
    A function to generate the CQs for the LLM to answer in this iteration.

    Parameters:
    - number_of_cqs (int): the number of CQs to generate in this iteration
    - acquisition_function (str): the acquisition function to use for generating the CQs: options: random, boltzmann, infomax, doubleTS
    - capacities (np.array): the capacities of the courses; only its length (number of courses) is used here
    - number_of_candidate_bundles (int): the number of candidate bundles generated per CQ for the
      boltzmann, infomax and doubleTS acquisition functions (the random acquisition function uses a fixed pool of 500 bundles)
    - model_type (str): the type of the underlying value model
    - model_param_dictionary (dict): the parameters of the value model
    - model_list (list): For non uncertainty-based acquisition functions, this is the student's learned value model, wrapped in a list
      (only model_list[0] is used). For uncertainty-based acquisition functions (infomax, doubleTS), this is the collection of
      models to use for the uncertainty estimation.
    - boltzmann_temperature (float, > 0): the temperature to use for the Boltzmann exploration
    - max_attempts (int): the maximum number of re-draws when the boltzmann and doubleTS acquisition functions
      propose the same bundle twice; after that, a random *different* bundle is used when one exists
    - seed: must be None; seeding is not implemented yet

    Returns:
    - bundle_tuples (list): a list of tuples, where each tuple contains two bundles. Those are the CQs that the LLM should answer in this iteration.

    Raises:
    - ValueError: if seed is not None, if the acquisition function is unknown, if the Boltzmann temperature
      is not positive, or if infomax is given fewer than two candidate bundles.
    """
    supported_acquisition_functions = ('random', 'boltzmann', 'infomax', 'doubleTS')

    bundle_tuples = []
    number_of_courses = capacities.shape[0]

    if seed is not None:
        raise ValueError('Seed is not None, which is not properly implemented yet')

    if acquisition_function not in supported_acquisition_functions:
        # previously an unknown name silently produced no CQs at all
        raise ValueError(f'Unknown acquisition function: {acquisition_function}. '
                         f'Options: {supported_acquisition_functions}')

    if acquisition_function == 'random':
        # a fixed pool of 500 random bundles, shared by all CQs of this call
        bundles = generate_random_bundles(number_of_bundles= 500, number_of_courses= number_of_courses, seed = seed)

        for _ in range(number_of_cqs):
            bundle1 = random.choice(bundles)
            bundle2 = random.choice(bundles)
            while np.all(bundle1 == bundle2):
                bundle2 = random.choice(bundles)
            bundle_tuples.append((bundle1, bundle2))

    elif acquisition_function == 'boltzmann':
        if boltzmann_temperature <= 0:
            raise ValueError(f'The Boltzmann temperature must be positive, got {boltzmann_temperature}')

        model = model_list[0]  # only use the first model for the Boltzmann exploration, we do not use bagging here.
        for cq_number in range(number_of_cqs):
            print(f'Generating CQ number: {cq_number}')

            # 1. generate the candidate bundles for the CQ
            bundles = generate_random_bundles(number_of_bundles= number_of_candidate_bundles, number_of_courses= number_of_courses, seed = seed)

            # 2. get the value of the bundles from the value model
            values = np.asarray(
                model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= model, bundles= bundles),
                dtype=float).reshape(-1)

            # 3. calculate the probability of choosing each bundle (softmax of values / temperature).
            # Subtracting the max logit leaves the softmax unchanged but prevents np.exp from overflowing
            # to inf (and the probabilities to nan) when bundle values are large.
            logits = values / boltzmann_temperature
            weights = np.exp(logits - np.max(logits))
            probabilities = weights / np.sum(weights)

            # 4. choose the first bundle
            bundle1 = random.choices(bundles, probabilities)[0]

            # 5. choose the second bundle, re-drawing while it equals the first one
            bundle2 = random.choices(bundles, probabilities)[0]
            attempt = 0
            while np.all(bundle1 == bundle2) and attempt < max_attempts:
                bundle2 = random.choices(bundles, probabilities)[0]
                attempt += 1

            # only fall back if the bundles are still identical (the last re-draw may have succeeded)
            if np.all(bundle1 == bundle2):
                print(f'Warning: could not find a different bundle after {max_attempts} attempts, choosing randomly')
                bundle2 = _choose_bundle_different_from(bundles, bundle1)

            bundle_tuples.append((bundle1, bundle2))

    elif acquisition_function == 'infomax':
        for _ in range(number_of_cqs):

            # 1. generate the candidate bundles for the CQ
            bundles = generate_random_bundles(number_of_bundles= number_of_candidate_bundles, number_of_courses= number_of_courses, seed = None)

            # all unordered pairs (i, j) with i < j, in the same row-major order as a nested i/j loop
            first_indices, second_indices = np.triu_indices(len(bundles), k=1)
            if first_indices.size == 0:
                raise ValueError('The infomax acquisition function needs at least two candidate bundles')

            probabilities_all_models = []
            for single_model in model_list:
                # 2. get the value of the bundles from all the models included in the collection
                values = np.asarray(
                    model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= single_model, bundles= bundles),
                    dtype=float).reshape(-1)

                # 3. Bradley-Terry probability that bundle i is preferred to bundle j: sigmoid(v_i - v_j).
                # The tanh form equals 1 / (1 + exp(-x)) but cannot overflow.
                value_diff = values[first_indices] - values[second_indices]
                probabilities_all_models.append(0.5 * (1.0 + np.tanh(0.5 * value_diff)))

            # for each bundle tuple, calculate the variance of the probabilities across the models
            variance_probabilities = np.var(probabilities_all_models, axis= 0)

            # 4. choose the bundle tuple with the highest variance (the models disagree the most)
            max_variance_index = int(np.argmax(variance_probabilities))
            bundle1_index = first_indices[max_variance_index]
            bundle2_index = second_indices[max_variance_index]
            bundle_tuples.append((bundles[bundle1_index], bundles[bundle2_index]))

    elif acquisition_function == 'doubleTS':
        for _ in range(number_of_cqs):

            # 1. generate the candidate bundles for the CQ
            bundles = generate_random_bundles(number_of_bundles= number_of_candidate_bundles, number_of_courses= number_of_courses, seed = None)

            # sample the index of the model to use for the first TS
            model_index = random.randint(0, len(model_list) - 1)
            single_model = model_list[model_index]

            # 2. get the value of the bundles from the value model
            bundle_values = model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= single_model, bundles= bundles)

            # 3. select the bundle with the highest value
            bundle1_index = np.argmax(bundle_values)
            bundle1 = bundles[bundle1_index]

            bundle2_index = bundle1_index
            bundle2 = bundles[bundle2_index]
            attempt = 0

            # second TS: re-sample models until one prefers a different bundle (or we run out of attempts)
            while bundle2_index == bundle1_index and attempt < max_attempts:
                # sample the index of the model to use for the second TS
                model_index = random.randint(0, len(model_list) - 1)
                print(f'DoubleTS using model index: {model_index}')
                single_model = model_list[model_index]

                # get the value of the bundles from the value model
                bundle_values = model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= single_model, bundles= bundles)

                # select the bundle with the highest value
                bundle2_index = np.argmax(bundle_values)
                bundle2 = bundles[bundle2_index]

                attempt += 1

            if bundle2_index == bundle1_index:
                print(f'Warning: DoubleTS could not find a different bundle after {max_attempts} attempts')
                # unlike a plain random.choice, this never returns bundle1 again when another bundle exists
                bundle2 = _choose_bundle_different_from(bundles, bundle1)
            else:
                print(f'DoubleTS found a different bundle after {attempt} attempts')
            bundle_tuples.append((bundle1, bundle2))

    return bundle_tuples


def simulate_llm_answers_cqs(bundle_tuples, model_type, model_param_dictionary,
            benchmark_model_type, benchmark_student, model_student_list,
            llm_simulation_dictionary = None
            ):
    """
    Simulate (noisy) LLM answers to the CQs, using the benchmark student's true values as ground truth.

    Parameters:
    - bundle_tuples (list): list of (bundle1, bundle2) tuples, i.e. the CQs to answer
    - model_type, model_param_dictionary, model_student_list: not used by this function; kept for interface compatibility
    - benchmark_model_type (str): model type of the benchmark student, forwarded to model_predict_values
    - benchmark_student: the benchmark student whose values define the correct answers
    - llm_simulation_dictionary (dict or None): how to simulate the replies. Defaults to
      {'reply_encoding': 'hard', 'hard_labels_flip_probability': 0.1}. Keys read:
        - 'reply_encoding': 'hard' or 'soft'
        - 'hard' only: 'hard_labels_flip_probability' (probability of flipping each label)
        - 'soft' only: 'soft_labels_noise_std' (std of Gaussian noise added to the probability) and
          'temperature-smoothening' (MixMatch-style sharpening temperature; 1 leaves the label unchanged)

    Returns:
    - X (np.array): np.array(bundle_tuples)
    - y (np.array): one label per CQ. 'hard': 1 if bundle1 has a strictly higher true value than bundle2, else 0,
      flipped with probability 'hard_labels_flip_probability'. 'soft': the (noisy, clipped, sharpened) probability
      that bundle1 is preferred to bundle2.
      Those are the X_train_ord, y_train_ord for training the TL MVNN.

    Raises:
    - ValueError: if the reply encoding is neither 'hard' nor 'soft'
    """
    if llm_simulation_dictionary is None:
        # built per call so that callers can never mutate a shared default dict
        llm_simulation_dictionary = {'reply_encoding': 'hard', 'hard_labels_flip_probability': 0.1}

    reply_encoding = llm_simulation_dictionary['reply_encoding']
    if reply_encoding not in ('hard', 'soft'):
        raise ValueError('Invalid reply encoding for the LLM CQs')

    labels = []
    for (bundle1, bundle2) in bundle_tuples:

        # get the true value of the bundles involved in each CQ
        true_values = model_predict_values(model_type= benchmark_model_type, model_param_dictionary= None, model= benchmark_student, bundles= [bundle1, bundle2])

        value_diff = true_values[0] - true_values[1]

        if reply_encoding == 'hard':
            flip_probability = llm_simulation_dictionary['hard_labels_flip_probability']
            # ties (equal values) are labelled 0
            true_label = 1 if value_diff > 0 else 0

            # flip the label with the given probability
            if random.random() < flip_probability:
                true_label = 1 - true_label
            labels.append(true_label)

        else:  # 'soft'
            temperature = llm_simulation_dictionary['temperature-smoothening']

            # the *true* probability that the first bundle is preferred to the second bundle according to a
            # perfectly tuned MVNN: sigmoid(value_diff), written via tanh so large differences cannot overflow
            p = 0.5 * (1.0 + np.tanh(0.5 * value_diff))

            # add noise to the probability
            p = p + np.random.normal(0, llm_simulation_dictionary['soft_labels_noise_std'])
            p = np.clip(p, 0, 1)  # clip the probability to be between 0 and 1

            # smoothen/sharpen the label as in MixMatch: p^(1/T) / (p^(1/T) + (1 - p)^(1/T)).
            # Equivalent to sigmoid(logit(p) / T), which stays finite for extreme T; labels of exactly 0 or 1
            # are fixed points of the sharpening and are left untouched.
            if 0.0 < p < 1.0:
                p = 0.5 * (1.0 + np.tanh(0.5 * np.log(p / (1.0 - p)) / temperature))
            labels.append(p)

    return np.array(bundle_tuples), np.array(labels)   # those are the X_train_ord, y_train_ord for training the TL MVNN


def main_function():
   
    # --- Mechanismp parameterds, no need to change --- #
    parser = argparse.ArgumentParser()
    parser.add_argument('--supply_ratio', type=float, default= 1.25, help='supply_ratio')
    parser.add_argument('--number_of_popular', type= int, default= 9, help='afffects_correlation')
    parser.add_argument('--model_family', type= str, default= 'BPpop9v.1', help='the models to run') # options: [PO/B]pop9_projected_v1.1_from_scratch_gui_reports, 'BPpop9v.1' and 'POpop9'
    parser.add_argument('--linear_instances', type= str, default= 'false', help='whether we are in linear instances or not')
    parser.add_argument('--seed', type= int, default= 627, help='the seed to use for the random number generator')
    parser.add_argument('--cq_method', type= str, default= 'basic_plus', help='acquisition function to use') # options: 'basic', 'random', 'basic_plus', 'complete_ordering', 'complete_ordering_pruned'
    parser.add_argument('--wandb_tracking', type= str, default= 'true', help='whether to track the results with wandb')
    parser.add_argument('--TL_model', type= str, default= 'true', help='whether to use transfer learning model')
    parser.add_argument('--project_to_gui', type= str, default= 'false', help='whether to log details about the MVNNs projected to the GUI language')
    parser.add_argument('--load_students', type= str, default= 'false', help='whether to load the students from the file, if not, they will be created')
    parser.add_argument('--student_to_check', type= int, default= 42, help='the student that this run should check')

    # --- new arguments for the simulation of LLM CQ replies --- #
    parser.add_argument('--llm_cq_labels', type= str, default= 'hard', help='the method to use for generating the LLM CQ labels: options: hard, soft')  
    parser.add_argument('--llm_hard_labels_flip_probability', type= float, default= 0.3, help='the probability of flipping the label in the hard label case')
    parser.add_argument('--llm_soft_labels_noise_std', type= float, default= 0.00, help='the std of the noise to add to the soft labels')
    parser.add_argument('--llm_temperature_smoothening', type= float, default= 1, help='the temperature for the smoothening of the labels in the soft label case')
    
    # --- new arguments for learning on those LLM CQs --- #
    parser.add_argument('--llm_cq_loss_string', type= str, default= 'GCE', help='the loss function to use for the LLM CQs')  # options: BCE, GCE 
    parser.add_argument('--llm_cq_epochs', type= int, default= 50, help='the number of epochs to use for the LLM CQs')
    parser.add_argument('--llm_cq_lr', type= float, default= 0.005, help='the learning rate to use for the LLM CQs')
    parser.add_argument('--llm_cq_wd', type= float, default= 0.001, help='the regularization to use for the LLM CQs')
    parser.add_argument('--llm_cq_clip', type= float, default= 10, help='the gradient clipping to use for the LLM CQs')
    parser.add_argument('--llm_cq_batch_size', type= int, default= 1, help='the batch size to use for the LLM CQs')
    parser.add_argument('--llm_gce_q', type= float, default= 0.001, help='the q parameter to use for the GCE loss')

    # parser.add_argument('--llm_cq_trainsize', type= int, default= 5000, help='the number of training samples to use for the LLM CQs')

    # --- Making things more principled/ arguments for the LLM acquisition function --- # 
    parser.add_argument('--llm_cq_acquisition_function', type= str, default= 'doubleTS', help='the acquisition function to use for generating the LLM CQs') # options: random, boltzmann, infomax, doubleTS
    parser.add_argument('--llm_cq_bsize', type= int, default= 10, help='the number of CQs to ask the LLM between training')
    parser.add_argument('--llm_cq_iterations', type= int, default= 2, help='the number of feedback iterations to do with the LLM CQs')
    parser.add_argument('--llm_cq_candidates', type= int, default= 100, help='the number of candidate bundles from which to build the LLM CQs')
    parser.add_argument('--boltzmann_temperature', type= float, default= 0.5, help='the temperature to use for the Boltzmann exploration')
    parser.add_argument('--value_model_collection_size', type= int, default= 10, help='the number of models to use for the value model uncertainty estimation')  # only applies to uncertainty-based acquisition functions: infomax, doubleTS 

    



    args = parser.parse_args()

    # --- Environment-related arguments --- # 
    supply_ratio = args.supply_ratio
    number_of_popular = args.number_of_popular
    index = args.student_to_check // 100
    student_index = args.student_to_check % 100
    model_family = args.model_family
    queries = '30-10'  # only needed to load the dictionaries, does not play any role 
    linear_instances = str(args.linear_instances).lower() == 'true'
    seed = args.seed 
    wandb_tracking = str(args.wandb_tracking).lower() == 'true'
    tl_model = str(args.TL_model).lower() == 'true'
    project_to_gui = str(args.project_to_gui).lower() == 'true'
    load_students = str(args.load_students).lower() == 'true'

    # --- LLM CQs-related arguments --- #
    llm_cq_labels = args.llm_cq_labels
    llm_hard_labels_flip_probability = args.llm_hard_labels_flip_probability
    llm_soft_labels_noise_std = args.llm_soft_labels_noise_std
    llm_temperature_smoothening = args.llm_temperature_smoothening
    llm_cq_loss_string = args.llm_cq_loss_string
    llm_cq_epochs = args.llm_cq_epochs
    llm_cq_lr = args.llm_cq_lr
    llm_cq_wd = args.llm_cq_wd
    llm_cq_clip = args.llm_cq_clip
    llm_cq_batch_size = args.llm_cq_batch_size
    llm_gce_q = args.llm_gce_q

    # --- Model-related arguments --- #
    students_to_check = 10   # NOTE: This affects how many students each individual run will check. SHould be 1 in all experiments other than HPO  


    # --- Acquisition function-related arguments --- #
    llm_cq_acquisition_function = args.llm_cq_acquisition_function
    llm_cq_bsize = args.llm_cq_bsize
    llm_cq_iterations = args.llm_cq_iterations
    llm_cq_candidates = args.llm_cq_candidates
    boltzmann_temperature = args.boltzmann_temperature
    value_model_collection_size = args.value_model_collection_size


    if tl_model:
        # --- Transfer Learning Model (when we had the student answer the CQs without any noise) --- #
        # model_info = ('UNN_transfer_learning', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 
        #     'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
        # 'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
        # 'sample_relative_frequencies': (1, 1, 1, 1),
        # 'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        # 'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8,  'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1,
        # 'lr_ordinal':  0.01, 'weight_decay_ordinal': 0, 'epochs_ordinal': 10, 'batch_size_ordinal': 8,  'UNN_loss_string_ordinal': 'BCE', 'clip_ordinal': 0.1,
        # 'use_implied_dataset': True,
        # 'use_cqs': True, 'cq_method': 'complete_ordering_pruned'
        #                             })  # 'UNN_transfer_learning'

        # --- Transfer Learning Model (with the LLM-answered cqs, but without a robust to noise loss) --- #
        # model_info = ('UNN_transfer_learning', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 
        #     'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
        # 'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
        # 'sample_relative_frequencies': (1, 1, 1, 1),
        # 'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        # 'use_implied_dataset': True,'use_cqs': True, 'cq_method': 'complete_ordering_pruned',
        # 'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8,  'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1,
        # # --- HPs that refer to the (used to be student) CQs --- # 
        # 'lr_ordinal': llm_cq_lr, 'weight_decay_ordinal': llm_cq_wd, 'epochs_ordinal': llm_cq_epochs, 'batch_size_ordinal': llm_cq_batch_size, 
        # 'UNN_loss_string_ordinal': llm_cq_loss_string, 'clip_ordinal': llm_cq_clip,
        #                             })  # 'UNN_transfer_learning'
        
        # --- Transfer Learning Model (with a noise-robust loss) --- #
        model_info = ('UNN_transfer_learning', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 
            'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
        'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
        'sample_relative_frequencies': (1, 1, 1, 1),
        'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        'use_implied_dataset': True,'use_cqs': True, 'cq_method': 'complete_ordering_pruned',
        'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8,  'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1,
        # --- HPs that refer to the (used to be student) CQs --- # 
        'lr_ordinal': llm_cq_lr, 'weight_decay_ordinal': llm_cq_wd, 'epochs_ordinal': llm_cq_epochs, 'batch_size_ordinal': llm_cq_batch_size, 
        'UNN_loss_string_ordinal': llm_cq_loss_string, 'GCE_q': llm_gce_q,   # parameters related to the BCE/GCE loss
        'clip_ordinal': llm_cq_clip,
                                    })  # 'UNN_transfer_learning'
        
    # set_trace()
    
    if linear_instances:
        true_family = 'true_linear'
    else:
        true_family = 'true'

    result_dict = {} # will save the exact same details as the wandb tracking, but in a dictionary that we can save to a pickle file

    # llm_cq_acquisition_function = args.llm_cq_acquisition_function
    # llm_cq_bsize = args.llm_cq_bsize
    # llm_cq_iterations = args.llm_cq_iterations
    # llm_cq_candidates = args.llm_cq_candidates
    # boltzmann_temperature = args.boltzmann_temperature
    # value_model_collection_size = args.value_model_collection_size

    if wandb_tracking:
        wandb_config_dict = {
            'Supply Ratio': supply_ratio,
            'Number of Popular': number_of_popular,
            'Linear Instances': linear_instances,
            'Student Number': args.student_to_check,
            # 'Acquisition Function': args.cq_method,
            'Model Type': model_info[0], 
            # 'llm_cq_method': llm_cq_method,
            'llm_cq_labels': llm_cq_labels,
            'llm_hard_labels_flip_probability': llm_hard_labels_flip_probability,
            'llm_soft_labels_noise_std': llm_soft_labels_noise_std,
            'llm_temperature_smoothening': llm_temperature_smoothening,
            'llm_cq_loss_string': llm_cq_loss_string,
            'llm_gce_q': llm_gce_q,
            'llm_cq_epochs': llm_cq_epochs,
            'llm_cq_lr': llm_cq_lr,
            'llm_cq_wd': llm_cq_wd,
            'llm_cq_clip': llm_cq_clip,
            'llm_cq_batch_size': llm_cq_batch_size,  
            'Acquisition Function': llm_cq_acquisition_function, 
            'llm_cq_bsize': llm_cq_bsize,
            'llm_cq_iterations': llm_cq_iterations,
            'llm_cq_candidates': llm_cq_candidates,
            'Boltzmann Temperature': boltzmann_temperature,
            'Value Model Collection Size': value_model_collection_size
        }

        run = wandb.init(project=f'MLCM-LLMs-v1.3', # TODO: change to appropriate name
                    config=wandb_config_dict,
                    reinit=True)

        wandb.define_metric("Elicited CQs") 
        
        # Standard Learning Metrics 
        wandb.define_metric("KT", step_metric="Elicited CQs")
        wandb.define_metric("R2", step_metric="Elicited CQs")
        wandb.define_metric("MAE", step_metric="Elicited CQs")
        wandb.define_metric("MSE", step_metric="Elicited CQs")
        wandb.define_metric("KT top 5", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 5", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 5", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 5", step_metric="Elicited CQs")
        wandb.define_metric("KT top 10", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 10", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 10", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 10", step_metric="Elicited CQs")

        # also add centered version of all above metrics 
        wandb.define_metric("R2 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MAE - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MSE - Centered", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 5 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 5 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 5 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 10 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 10 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 10 - Centered", step_metric="Elicited CQs")

        # GUI Metrics
        if project_to_gui:
            wandb.define_metric("Base Values", step_metric="Elicited CQs")
            wandb.define_metric("Adjustments", step_metric="Elicited CQs")
            wandb.define_metric("KT - GUI", step_metric="Elicited CQs")
            wandb.define_metric("R2 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MAE - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MSE - GUI", step_metric="Elicited CQs")
            wandb.define_metric("KT top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("R2 top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MAE top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MSE top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("KT top 10 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("R2 top 10 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MAE top 10 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MSE top 10 - GUI", step_metric="Elicited CQs")

        # Other metrics
        wandb.define_metric("Ordinal Dataset Size", step_metric="Elicited CQs")
        wandb.define_metric("Agreement on CQs", step_metric="Elicited CQs")
        wandb.define_metric("Allocated Bundle Value", step_metric="Elicited CQs")
        wandb.define_metric("Found New Max", step_metric="Elicited CQs")  # note: this mteric is only useful for the basic plus method 
        


    # --- Step 1. Load true and GUI (noisy) student lists, as well as prices --- #
    if load_students:
        print(f'supply_ratio: {supply_ratio} number_of_popular: {number_of_popular} index:{index} model_family: {model_family} queries: {queries}')
        (_, true_model_info, model_student_list, true_student_list,  prices_stage_1_GUI, timetable, capacities) = load_all_models_v2(index = index, model_family= model_family, benchmark_family = true_family, queries = '30-10', supply_ratio= supply_ratio, number_of_popular= number_of_popular, linear_instances = linear_instances)
    
        # set_trace()

    else:
        # create the model instance from scratch 
        true_student_lists_all_instances, capacities_all_instances, timetables_all_instances = generate_problem_instance_principled(number_of_times= 1, number_of_students= 100, number_of_courses= 25,
                supply_ratio= supply_ratio, number_of_popular= number_of_popular, seed= seed + student_index)
        
        true_student_list = true_student_lists_all_instances[0]
        capacities = capacities_all_instances[0]
        timetable = timetables_all_instances[0]

        prices_stage_1_GUI = np.array([0.2 for _ in range(25)])  # just a price vector so that each student can take 5 courses 
        
        
        if not linear_instances:
        #     model_info = ('UNN_transfer_learning', 
        #     {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True, 
        #   'gui_forget_base': 0.5, 'gui_forget_adjustments': 0.48, 'gui_base_noise_std': 23, 'gui_adj_std': 0.2, 'gui_sample_categories_weights': [1, 1, 1], 
        #   'sample_relative_frequencies': (1, 1, 1, 1), 'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None, 'use_implied_dataset': True, 
        #   'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8, 'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1, 
        #   'lr_ordinal': 0.01, 'weight_decay_ordinal': 0.0, 'epochs_ordinal': 10, 'batch_size_ordinal': 8, 'UNN_loss_string_ordinal': 'BCE', 'clip_ordinal': 0.1})
            true_model_info = ('True', None)
    

    
    
    model_type = model_info[0]
    model_param_dictionary = model_info[1]
    number_of_courses = capacities.shape[0]
    model_param_dictionary['cq_method'] = args.cq_method

    # add extra information to the model_param_dictionary for projection. 
    projection_dictionary = {'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 500, 'train_high_samples': 500,
        'linear_projection': linear_instances}
    model_param_dictionary.update(projection_dictionary)
    
    # true_student_list = [true_student_list[student_index]]
    true_student_list = true_student_list[student_index: student_index + students_to_check]

    # set the seeds for reproducability.  
    torch.manual_seed(seed + student_index)
    np.random.seed(seed + student_index)
    random.seed(seed + student_index)


    # --- Step 2. Generate the bundles that we will use for the generalization test --- # 
    
    bundles = generate_random_bundles(number_of_bundles= 2000, number_of_courses= 25, seed = seed)
    bundles_5_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 95, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])
    bundles_10_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 90, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])


    # --- Step 3. Generate the 'initial' training dataset based on the GUI reports --- #

    actual_student_list = create_actual_student_list(true_student_list, model_type = model_info[0], model_param_dictionary = model_param_dictionary, seed = seed)  # just returns the true list: NOTE: should remove
    gui_student_list = create_gui_student_list(true_student_list, model_type = model_type, model_param_dictionary = model_param_dictionary, seed = seed)

    
    training_set = generate_new_queries(actual_student_list = actual_student_list, timetable = timetable, number_of_samples = 0,  # number of samples is not used when generating the original dataset. Instead, it is read from the model_dictionary
                    current_training_set = [], model_student_list = [], approximate_prices = None,
                    credit_units = [1 for i in range(number_of_courses)], model_type = model_type, seed= seed + args.student_to_check, model_param_dictionary = model_param_dictionary, gui_student_list = gui_student_list)


    # --- Step 3.5  Generate initial student list based on the initial training set   --- # 
    if llm_cq_acquisition_function == 'infomax' or llm_cq_acquisition_function == 'doubleTS':
        model_student_lists = []
        # if the model requires bagging, we will create multiple models and use them to generate the CQs
        for i in range(value_model_collection_size):
            print(f'Creating model {i} for the collection')
            # set_trace()
            model_student_list = create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                            timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                            model_student_list= None)
            
            model_student_lists.append(model_student_list)
    
    else:
        # if the model does not require bagging, we will only create one model and use it to generate the CQs
        model_student_lists = [create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                            timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                            model_student_list= None)]  
    # set_trace()
        
    
    
    # --- Initial metric calculation for 0 CQs --- #
    elicited_llm_cqs = 0
    # Measure initial allocation value of the model
    allocation_model = calculate_allocation(model_list=model_student_lists[0], prices=prices_stage_1_GUI, timetable=timetable, models_to_run=[model_info])
    allocation_value_model = get_true_value_of_allocation(benchmark_student_list=true_student_list, individual_demands=allocation_model, benchmark_model=true_model_info[0], timetable=timetable)
    # Measure generalization performance of the model (0 CQs answered)
    # Measure generalization performance of the ML models and GUI reports
    kts_model, r2s_model, maes_model, mses_model = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
    kts_model_top_5, r2s_model_top_5, maes_model_top_5, mses_model_top_5 = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
    kts_model_top_10, r2s_model_top_10, maes_model_top_10, mses_model_top_10 = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
    # measure the centered versions of the metrics
    _, r2s_model_centered, maes_model_centered, mses_model_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
    _, r2s_model_top_5_centered, maes_model_top_5_centered, mses_model_top_5_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
    _, r2s_model_top_10_centered, maes_model_top_10_centered, mses_model_top_10_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)

    # set_trace()
    
    wandb_dict = {
            'Elicited CQs': elicited_llm_cqs,
            'KT': np.mean(kts_model),
            'R2': np.mean(r2s_model),
            'MAE': np.mean(maes_model),
            'MSE': np.mean(mses_model),
            'KT top 5': np.mean(kts_model_top_5),
            'R2 top 5': np.mean(r2s_model_top_5),
            'MAE top 5': np.mean(maes_model_top_5),
            'MSE top 5': np.mean(mses_model_top_5),
            'KT top 10': np.mean(kts_model_top_10),
            'R2 top 10': np.mean(r2s_model_top_10),
            'MAE top 10': np.mean(maes_model_top_10),
            'MSE top 10': np.mean(mses_model_top_10),
            'R2 - Centered': np.mean(r2s_model_centered),
            'MAE - Centered': np.mean(maes_model_centered),
            'MSE - Centered': np.mean(mses_model_centered),
            'R2 top 5 - Centered': np.mean(r2s_model_top_5_centered),
            'MAE top 5 - Centered': np.mean(maes_model_top_5_centered),
            'MSE top 5 - Centered': np.mean(mses_model_top_5_centered),
            'R2 top 10 - Centered': np.mean(r2s_model_top_10_centered),
            'MAE top 10 - Centered': np.mean(maes_model_top_10_centered),
            'MSE top 10 - Centered': np.mean(mses_model_top_10_centered),
            'Allocated Bundle Value': np.mean(allocation_value_model),
        }
    
    if wandb_tracking:
        wandb.log(wandb_dict)

    
    for cq_iteration in range(1, llm_cq_iterations + 1):
        elicited_llm_cqs = cq_iteration * llm_cq_bsize
        print(f'--------> Number of elicited LLM CQs: {elicited_llm_cqs}')
        # set_trace()

        
        # --- Step 4: Generate the LLM CQs and get their answers --- # 
        for i in range(students_to_check):
            # set_trace()
            cq_bundle_tuples = generate_llm_cqs(number_of_cqs= llm_cq_bsize, 
                                acquisition_function= llm_cq_acquisition_function, capacities= capacities, 
                                number_of_candidate_bundles= llm_cq_candidates, 
                                model_type = model_info[0], model_param_dictionary = model_info[1], 
                                model_list = [model_student_list[i] for model_student_list in model_student_lists], 
                                boltzmann_temperature= boltzmann_temperature, 
                                max_attempts= 50)
            
            # set_trace()
                                                

        
            cq_bundle_tuples, labels = simulate_llm_answers_cqs(bundle_tuples= cq_bundle_tuples, model_type= model_info[0], 
                model_param_dictionary= model_info[1], benchmark_model_type= true_model_info[0], benchmark_student= true_student_list[i], model_student_list= model_student_lists[0],
                llm_simulation_dictionary = {'reply_encoding': llm_cq_labels, 'hard_labels_flip_probability': llm_hard_labels_flip_probability,
                                            'soft_labels_noise_std': llm_soft_labels_noise_std, 'temperature-smoothening': llm_temperature_smoothening})  
            
        
        
             # Add those to the training set
            # training_set[i][1] = (cq_bundle_tuples, labels)
            if len(training_set[i][1][0]) == 0:
                # set_trace()
                training_set[i][1] = (cq_bundle_tuples, labels)
            else:
                # set_trace()
                all_tuples = np.vstack([training_set[i][1][0], cq_bundle_tuples])
                all_labels = np.hstack([training_set[i][1][1], labels])
                training_set[i][1] = (all_tuples, all_labels)

            

        # Step 4. Train the models on the current dataset -- NOTE: no need to change anything below this point 
        if llm_cq_acquisition_function == 'infomax' or llm_cq_acquisition_function == 'doubleTS':
            # model_student_lists = []
        # if the model requires bagging, we will create multiple models and use them to generate the CQs
            for model_index in range(value_model_collection_size):
                print(f'Creating model {model_index} for the collection')
                model_student_lists[model_index] = create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                                timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                                model_student_list= model_student_lists[model_index])
                # model_student_lists.append(model_student_list)
        else:
            model_student_lists = [create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                        timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                        model_student_list= model_student_lists[0])]
            
        
        # set_trace()
        
        
        # Step 5. Measure the performance of the models
    
        # Step 5a) Measure allocation value of the model 
        allocation_model = calculate_allocation(model_list= model_student_lists[0], prices= prices_stage_1_GUI, timetable= timetable, models_to_run= [model_info])
        allocation_value_model = get_true_value_of_allocation(benchmark_student_list= true_student_list, individual_demands= allocation_model, benchmark_model= true_model_info[0], timetable = timetable)
        # set_trace()

        

        # Step 6.a) Measure generalization performance of the ML models and GUI reports
        kts_model, r2s_model, maes_model, mses_model = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
        kts_model_top_5, r2s_model_top_5, maes_model_top_5, mses_model_top_5 = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
        kts_model_top_10, r2s_model_top_10, maes_model_top_10, mses_model_top_10 = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
        

        # measure the centered versions of the metrics
        _, r2s_model_centered, maes_model_centered, mses_model_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
        _, r2s_model_top_5_centered, maes_model_top_5_centered, mses_model_top_5_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
        _, r2s_model_top_10_centered, maes_model_top_10_centered, mses_model_top_10_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)

        # set_trace()
    
        
        # # Step6.b) Measure number of agreements between the GUI reports and the true preferences
        # agreements = check_agreements_on_cq(model_student_list= model_student_list, model_info= model_info,benchmark_student_list= true_student_list, benchmark_model_info= true_model_info, cqs= last_queries, timetable= timetable)
        
        ordinal_info_size = get_ordinal_dataset_size(training_dataset= training_set, model_param_dictionary= model_param_dictionary)

        # Step 7. Project the MVNNs to measure results in the GUI language
        if project_to_gui:
            projected_mvnns = project_mvnns(mvnn_student_list= model_student_list, model_param_dictionary= model_param_dictionary)

            # Measure the number of base values and adjustments 
            base_values_model, adjustments_model = measure_gui_language_metrics(model_list = projected_mvnns, model_type = 'UNN_projected')
        
    
            # Measure similarity between projected ML models/GUI reports and the true (additive) preferences
            gui_language_kts_model_top10, gui_language_r2s_model_top10, gui_language_maes_model_top10, gui_language_mses_model_top10 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 60)
            gui_language_kts_model, gui_language_r2s_model, gui_language_maes_model, gui_language_mses_model = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances)
            gui_language_kts_model_top5, gui_language_r2s_model_top5, gui_language_maes_model_top5, gui_language_mses_model_top5 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 80)
        
        
        
        wandb_dict = {
            'Elicited CQs': elicited_llm_cqs,
            'KT': np.mean(kts_model),
            'R2': np.mean(r2s_model),
            'MAE': np.mean(maes_model),
            'MSE': np.mean(mses_model),
            'KT top 5': np.mean(kts_model_top_5),
            'R2 top 5': np.mean(r2s_model_top_5),
            'MAE top 5': np.mean(maes_model_top_5),
            'MSE top 5': np.mean(mses_model_top_5),
            'KT top 10': np.mean(kts_model_top_10),
            'R2 top 10': np.mean(r2s_model_top_10),
            'MAE top 10': np.mean(maes_model_top_10),
            'MSE top 10': np.mean(mses_model_top_10),
            'R2 - Centered': np.mean(r2s_model_centered),
            'MAE - Centered': np.mean(maes_model_centered),
            'MSE - Centered': np.mean(mses_model_centered),
            'R2 top 5 - Centered': np.mean(r2s_model_top_5_centered),
            'MAE top 5 - Centered': np.mean(maes_model_top_5_centered),
            'MSE top 5 - Centered': np.mean(mses_model_top_5_centered),
            'R2 top 10 - Centered': np.mean(r2s_model_top_10_centered),
            'MAE top 10 - Centered': np.mean(maes_model_top_10_centered),
            'MSE top 10 - Centered': np.mean(mses_model_top_10_centered),
            'Allocated Bundle Value': np.mean(allocation_value_model),
        }
        # if model_param_dictionary['cq_method'] == 'basic_plus':
        #     wandb_dict['Found New Max'] = training_set[0][5]

        if project_to_gui:
            gui_dict = {
                'Base Values': base_values_model,
                'Adjustments': adjustments_model,
                'KT - GUI': gui_language_kts_model,
                'R2 - GUI': gui_language_r2s_model,
                'MAE - GUI': gui_language_maes_model,
                'MSE - GUI': gui_language_mses_model,
                'KT top 5 - GUI': gui_language_kts_model_top5,
                'R2 top 5 - GUI': gui_language_r2s_model_top5,
                'MAE top 5 - GUI': gui_language_maes_model_top5,
                'MSE top 5 - GUI': gui_language_mses_model_top5,
                'KT top 10 - GUI': gui_language_kts_model_top10,
                'R2 top 10 - GUI': gui_language_r2s_model_top10,
                'MAE top 10 - GUI': gui_language_maes_model_top10,
                'MSE top 10 - GUI': gui_language_mses_model_top10,
            }

            wandb_dict.update(gui_dict)

        result_dict[elicited_llm_cqs] = wandb_dict

        # set_trace()
        if wandb_tracking:
            wandb.log(wandb_dict)

            

        print(f'--------> Number of LLM CQs: {elicited_llm_cqs}, ordinal_info_size: {ordinal_info_size}, allocation values: {allocation_value_model}')

        # set_trace()

    if wandb_tracking: 
        run.finish()

    # save the results to a pickle file
    # dict_name = f'./acquisition_function_results/SR_{supply_ratio}_NP_{number_of_popular}_LI_{linear_instances}_AF_{args.cq_method}_student_{args.student_to_check}.pkl'
    # open the file for writing, and create it if it does not exist
    # with open(dict_name, 'wb+') as f:
    #     pickle.dump(result_dict, f)

    return 

if __name__ == '__main__':
    # Entry point: call main_function() only when this file is executed
    # directly as a script, not when it is imported as a module.
    main_function()