import numpy as np
import pickle
# import math
# import gurobipy as gp
from cleanup import student 

import argparse
import torch
import scipy.stats as stats
from sklearn.metrics import r2_score
from sklearn.preprocessing import PolynomialFeatures
import math 
import wandb 
from model_comparison import load_all_models_v2, get_true_value_of_allocation, generate_random_bundles, model_predict_values, measure_generalization_performance_all_students, keep_high_valued_bundles, calculate_allocation, measure_gui_language_metrics, compare_gui_reports
from tabu_search import binary_search, binary_search_hacky, construct_implied_dataset_v2, solve_student, create_actual_student_list, create_gui_student_list, create_iterative_student_list, project_mvnns
import wandb 
import random 
from acquisition_function_comparison import generate_new_queries, check_agreements_on_cq, get_ordinal_dataset_size
from preference_generator import generate_problem_instance_principled
import numpy as np
import random
import torch

# Replace for local model once in cluster/downloaded locally
from edsl import QuestionMultipleChoice
from edsl import QuestionFreeText
from edsl import Model
from edsl import Agent, Scenario, Survey
from edsl.data import Cache
import pandas as pd

from pdb import set_trace


# LLM used to answer the comparison queries.
# `max_tokens` must be passed as a keyword argument: the previous expression
# `"max_tokens"==1500` evaluated to `False` and was passed positionally, so the
# token limit was never applied.
model = Model("gpt-4o", max_tokens=1500)


def generate_student_prompt(student):
    """
    Convert a student profile into a detailed natural-language prompt for an LLM,
    including specific relationship strengths for complements and substitutes.

    Courses are split into three priority tiers by their min-max normalised
    additive preference (>= 0.7 high, >= 0.3 medium, otherwise low). When all
    courses share the same preference the normalisation is undefined, so every
    course is placed in the high tier instead of dividing by zero.

    Args:
        student: Student object. Reads ``additive_prefs`` (one value per course),
            ``substitutes`` and ``complements`` (iterables of
            ``(course_indices, values)`` pairs with 0-based course indices),
            ``budget``, ``timegap_penalty`` and ``overload_penalty``.

    Returns:
        str: A formatted prompt for the LLM (the prompt is also printed).
    """

    # (1-based course number, preference) pairs, highest preference first
    all_courses = [(i + 1, pref) for i, pref in enumerate(student.additive_prefs)]
    sorted_courses = sorted(all_courses, key=lambda x: x[1], reverse=True)

    # Identify preference tiers
    prefs = [pref for _, pref in sorted_courses]
    max_pref = max(prefs)
    min_pref = min(prefs)
    range_pref = max_pref - min_pref

    high_pref = []
    medium_pref = []
    low_pref = []

    for course, pref in sorted_courses:
        # Degenerate case: identical preferences -> every course is top tier
        normalized_pref = (pref - min_pref) / range_pref if range_pref > 0 else 1.0
        if normalized_pref >= 0.7:
            high_pref.append((course, pref))
        elif normalized_pref >= 0.3:
            medium_pref.append((course, pref))
        else:
            low_pref.append((course, pref))

    def format_tier_courses(courses):
        """Format courses within a tier as an ordered list."""
        return "\n".join(f"   {i+1}. Course {course[0]} (value: {course[1]:.2f})"
                         for i, course in enumerate(courses))

    def format_course_group(group, relationship_type):
        """
        Describe one complement/substitute group and the strength of the effect
        for taking 2, 3, ... courses of it. ``group[0]`` holds 0-based course
        indices, ``group[1]`` the per-size strengths (index 0 is not reported).
        """
        courses = [str(int(x) + 1) for x in group[0]]
        values = group[1]

        if relationship_type == "substitute":
            # Substitute strengths are stored as negative values; report magnitudes
            impacts = [f"{-value * 100:.0f}%" for value in values[1:]]
            message = f"Courses {', '.join(courses)} overlap in content. "
            verb = "reduces"
        else:  # complement
            impacts = [f"{value * 100:.0f}%" for value in values[1:]]
            message = f"Courses {', '.join(courses)} complement each other. "
            verb = "increases"

        if len(values) > 1:
            message += f"Taking any two {verb} their combined value by {impacts[0]}"
            for i in range(2, len(values) - 1):
                message += f", taking any {i+1} {verb} their combined value by {impacts[i-1]}"
            if len(values) > 2:
                message += f", and taking all {len(courses)} {verb} their combined value by {impacts[-1]}"
        return message

    substitute_groups = [format_course_group(group, "substitute") for group in student.substitutes]
    complement_groups = [format_course_group(group, "complement") for group in student.complements]

    substitute_lines = "\n".join(f"- {group}" for group in substitute_groups)
    complement_lines = "\n".join(f"- {group}" for group in complement_groups)
    timegap_line = f"- Time Gap Penalty: {student.timegap_penalty}" if student.timegap_penalty != 0 else ""
    overload_line = f"- Overload Penalty: {student.overload_penalty}" if student.overload_penalty != 0 else ""

    prompt = f"""Please act as a student describing their course preferences for the upcoming semester. Write a detailed, first-person paragraph about your preferences based on the following information:

Course Preferences (ordered from highest to lowest value within each tier):

High Priority Courses:
{format_tier_courses(high_pref)}

Medium Priority Courses:
{format_tier_courses(medium_pref)}

Lower Priority Courses:
{format_tier_courses(low_pref)}

Course Relationships:
Overlapping Content (Substitutes) - These are courses that HARM YOU when taken together compared to taking only one, and the more the worse:
{substitute_lines}

Complementary Courses (Complements) - These are courses that BENEFIT YOU when taken together compared to taking only one, and the more the better:
{complement_lines}

Additional Constraints:
- Budget Constraint: {student.budget:.2f}
{timegap_line}
{overload_line}

Please write a natural, detailed explanation of these preferences as if you were the student. Include:
1. Your strongest course interests, explaining them in order of preference within each priority tier
2. How you're thinking about course combinations, discussing specific synergies and overlaps:
   - When describing overlapping courses, explain how much the overlap affects your interest
   - When describing complementary courses, explain how much additional value you see in taking them together
3. Your overall strategy for course selection, considering both your budget constraints and the strength of course relationships
4. Any specific scheduling or workload considerations

Keep the tone conversational and authentic to how a student would describe their course preferences. Make sure to reference both your relative preferences within each tier and the specific impacts of course combinations on your overall academic plan. Your response should be three paragraphs: the first paragraph should list all the top and medium tier courses you want to pick, and any courses that come with each that might be complementary (good) or substitutes (bad). The second paragraph should clearly and in detail state ALL OF THE bundles of courses that are complements and bundles of courses that are substitutes for you; for each, explain how much it hurts you when you take different numbers of courses from that bundle. You should cover all the complement bundles and substitute bundles given to you - leave nothing out. The third paragraph concludes and highlights other concerns. Aim for qualitatively detailed description and avoid saying exact numerical values in the output (e.g. it's fine to say 'taking x and x together is suboptimal, taking x, x and x together even more suboptimal, and taking x,x,x, and x together should really be avoided, (so on so forth until the size of the complement/substitute set).', but not fine to say 'My value for course 16 is 102.45', or 'x and x together decreases my utility by 43%'.)
"""
    print(prompt)

    return prompt

def format_bundle(bundle):
    """
    Render a 0/1 course-selection vector as a comma-separated course list.

    Args:
        bundle: iterable of indicators (e.g. a numpy array); an entry equal to 1
            at position ``i`` means "Course i+1" is part of the bundle.

    Returns:
        str: e.g. ``"Course 1, Course 3"``; an empty string when nothing is selected.
    """
    selected = []
    for course_number, flag in enumerate(bundle, start=1):
        if flag == 1:
            selected.append("Course " + str(course_number))
    return ", ".join(selected)

def generate_comparison_prompt(bundle1, bundle2, student_preferences_text):
    """
    Build the comparison-query prompt asking the LLM to pick the better of two bundles.

    The prompt embeds the student's preference description, lists both bundles
    (via ``format_bundle``) and asks for a fixed tag structure
    (<PREFERENCES>, <COMPLEMENTS>, <SUBSTITUTES>, <REASONING>, <CHOICE>) so the
    answer can be parsed afterwards.

    Args:
        bundle1: course-selection vector shown as "Bundle A".
        bundle2: course-selection vector shown as "Bundle B".
        student_preferences_text: natural-language preference description.

    Returns:
        str: the full prompt text.
    """
    bundle_a_courses = format_bundle(bundle1)
    bundle_b_courses = format_bundle(bundle2)

    return f"""Based on these student preferences:

{student_preferences_text}

Compare:
Bundle A: {bundle_a_courses}
Bundle B: {bundle_b_courses}
to choose the better bundle.

Please ignore budget constraint if it's mentioned in the preference - pretend it doesn't exist.

Your response must use these EXACT tags below, and ONLY include the tags, end your response after that. The text between tags should be concise.

'''
<PREFERENCES>
Bundle A: [First recall the courses in Bundle, then list matching preferences, e.g. Bundle A contains Courses X, X, (list all courses in Bundle A). Course X is high preference, Course X is mid preference, Course X is low preference]
Bundle B: [First recall the courses in Bundle, then list matching preferences, e.g. Bundle B contains Courses X, X, (list all courses in Bundle B). Course X is high preference, Course X is mid preference, Course X is low preference]
</PREFERENCES>

<COMPLEMENTS>
Bundle A: [First recall the courses in Bundle, then list complementary relationships with magnitudes or "None", e.g. Bundle A contains Courses X, X, (list all courses in Bundle A). Course X and Course X are complements which helps moderately when taken together, Course X, Course X, and Course X are complements and helps significantly when taken together]
Bundle B: [First recall the courses in Bundle, then list complementary relationships with magnitudes or "None", e.g. Bundle B contains Courses X, X, (list all courses in Bundle B). Course X and Course X are complements which helps moderately when taken together, Course X, Course X, and Course X are complements and helps significantly when taken together]
</COMPLEMENTS>

<SUBSTITUTES>
Bundle A: [First recall the courses in Bundle, then list substitute relationships with magnitudes or "None", e.g. Bundle A contains Courses X, X, (list all courses in Bundle A). Course X and Course X are substitutes which harms moderately when taken together, Course X, Course X, and Course X are substitutes and harms significantly when taken together]
Bundle B: [First recall the courses in Bundle, then list substitute relationships with magnitudes or "None", e.g. Bundle B contains Courses X, X, (list all courses in Bundle B). Course X and Course X are substitutes which harms moderately when taken together, Course X, Course X, and Course X are substitutes and harms significantly when taken together]
</SUBSTITUTES>

<REASONING>
[Provide your concise reasoning in a few sentences, e.g. From the above, in terms of preferences, Bundle X is better. In terms of the presence and magnitude of complements, Bundle X is better. In terms of magnitude and precense of substitutes, bundle X is better. Considering the tradeoffs, Bundle X is better.]
</REASONING>

<CHOICE>Bundle X</CHOICE>
This is the end of the output
'''
"""


def extract_choice(llm_response):
    """
    Extract the bundle choice from an LLM answer to a comparison query.

    Accepts a plain string, an object with a ``generated_text`` attribute, a
    dict with a ``generated_text`` key, or a non-empty list whose first element
    is one of those. Parsing is attempted in order:

    1. a ``<CHOICE>Bundle A|B</CHOICE>`` tag (case-insensitive, whitespace tolerated);
    2. the first "Bundle A/B" mention inside an ``<ANALYSIS>`` block;
    3. the first "Bundle A/B" following a choice verb (choice/choose/prefer/
       select/pick) in the last five lines.

    Returns:
        int: 1 for Bundle A, 0 for Bundle B. Falls back to 1 (with a printed
        warning) when no choice can be found.
    """
    import re

    # Standardize the response format to plain text
    if hasattr(llm_response, 'generated_text'):
        text = llm_response.generated_text
    elif isinstance(llm_response, dict):
        text = llm_response.get('generated_text', str(llm_response))
    elif isinstance(llm_response, list) and len(llm_response) > 0:
        first = llm_response[0]
        if isinstance(first, dict):
            text = first.get('generated_text', str(first))
        else:
            text = str(first)
    else:
        text = str(llm_response)
    if not isinstance(text, str):
        # e.g. generated_text is None; re.search would raise on non-strings
        text = str(text)

    # 1. The tag the prompt explicitly asks for
    choice_match = re.search(r'<CHOICE>\s*Bundle\s*([AB])\s*</CHOICE>', text, re.IGNORECASE)
    if choice_match:
        return 1 if choice_match.group(1).upper() == 'A' else 0

    # 2. Legacy <ANALYSIS> block; \b avoids matching words like "Bundle Analysis"
    analysis_match = re.search(r'<ANALYSIS>(.*?)</ANALYSIS>', text, re.DOTALL)
    if analysis_match:
        choice_in_analysis = re.search(r'Bundle\s*([AB])\b', analysis_match.group(1), re.IGNORECASE)
        if choice_in_analysis:
            return 1 if choice_in_analysis.group(1).upper() == 'A' else 0

    # 3. Last resort: a choice verb followed by a bundle in the last few lines.
    # Non-greedy `.*?` binds to the FIRST bundle after the verb, so
    # "I prefer Bundle B over Bundle A" yields B (a greedy `.*` yielded A).
    lines = text.strip().split('\n')[-5:]
    for line in lines:
        bundle_match = re.search(r'(?:choice|choose|prefer|select|pick).*?bundle\s*([AB])\b', line, re.IGNORECASE)
        if bundle_match:
            return 1 if bundle_match.group(1).upper() == 'A' else 0

    # Log warning and return default
    print(f"Warning: Could not reliably extract choice. Defaulting to Bundle A (1). Last 5 lines were:\n{lines}")
    return 1

def generate_comparison_labels(cq_bundle_tuples, student_preferences_text):
    """
    Ask the LLM (module-level ``model``) to answer each comparison query and parse the answers.

    Identical prompts are sent only once. Afterwards each CQ receives the label
    parsed from the answer to ITS OWN prompt, so duplicate CQs share a label and
    ``label_list`` stays index-aligned with ``cq_bundle_tuples``.

    Args:
        cq_bundle_tuples: sequence of (bundle1, bundle2) pairs.
        student_preferences_text: natural-language preference description
            inserted into every prompt.

    Returns:
        tuple: (questions, response, label_list)
            - questions: ``results.select("prompt").to_list()`` for the distinct prompts.
            - response: raw LLM answers, one per distinct prompt.
            - label_list: one entry per CQ in ``cq_bundle_tuples``: 1 (Bundle A),
              0 (Bundle B), or None when the answer was missing or unparsable.
    """
    # len() rather than truthiness: callers may pass a numpy array
    if len(cq_bundle_tuples) == 0:
        return [], [], []

    # Build one prompt per CQ and remember which distinct prompt each CQ maps to
    prompt_list = []      # distinct prompts, first-seen order
    prompt_position = {}  # prompt text -> index into prompt_list
    cq_to_prompt = []     # per CQ: index of its prompt in prompt_list
    for bundle1, bundle2 in cq_bundle_tuples:
        pro = generate_comparison_prompt(bundle1, bundle2, student_preferences_text)
        if pro not in prompt_position:
            prompt_position[pro] = len(prompt_list)
            prompt_list.append(pro)
        cq_to_prompt.append(prompt_position[pro])

    scenarios = [Scenario({"prompt": str(p)}) for p in prompt_list]

    q = QuestionFreeText(
        question_name = "CQ",
        question_text = "{{prompt}}",  # edsl template placeholder filled from each Scenario
    )
    survey = Survey(questions = [q])
    # NOTE(review): assumes edsl returns one result per scenario, in submission order
    results = survey.by(scenarios).by(model).run()
    questions = results.select("prompt").to_list()
    response = results.select("CQ").to_list()

    results.select("CQ").print(format="rich")

    # Parse each distinct answer exactly once
    unique_labels = []
    for i, resp in enumerate(response):
        if resp is None:
            print(f"Missing response for prompt {i}")
            unique_labels.append(None)
            continue
        try:
            unique_labels.append(extract_choice(resp))
        except Exception as e:
            print(f"Error extracting choice for response {i}: {e}")
            unique_labels.append(None)  # placeholder for failed extraction

    if len(unique_labels) != len(prompt_list):
        print(f"Warning: got {len(unique_labels)} responses for {len(prompt_list)} distinct prompts. Adjusting...")
        unique_labels = (unique_labels + [None] * len(prompt_list))[:len(prompt_list)]

    # Expand back to one label per CQ (duplicates reuse their prompt's label)
    label_list = [unique_labels[k] for k in cq_to_prompt]

    return questions, response, label_list

def _random_bundle_different_from(bundles, reference_bundle):
    """Return a uniformly random bundle from ``bundles`` that differs from ``reference_bundle`` (any bundle if none differ)."""
    candidates = [b for b in bundles if not np.array_equal(b, reference_bundle)]
    return random.choice(candidates if candidates else bundles)


def generate_llm_cqs(number_of_cqs, acquisition_function, capacities, number_of_candidate_bundles, 
            model_type, model_param_dictionary, model_list,
            boltzmann_temperature, max_attempts = 20,
                     seed = None): 
    """
    A function to generate the CQs for the LLM to answer in this iteration.

    Parameters:
    - number_of_cqs (int): the number of CQs to generate in this iteration
    - acquisition_function (str): the acquisition function to use for generating the CQs: options: random, boltzmann, infomax, doubleTS
    - capacities (np.array): the capacities of the courses (only its length, the number of courses, is used)
    - number_of_candidate_bundles (int): the number of candidate bundles to generate. Those will be used to generate the CQs
      (the 'random' acquisition function always uses a fixed pool of 500 bundles)
    - model_type (str): the type of the underlying value model
    - model_param_dictionary (dict): the parameters of the value model
    - model_list (list): For non uncertainty-based acquisition functions, this is the student's learned value model, wrapped in a list. 
    For uncertainty-based acquisition functions, this is the collection of models to use for the uncertainty estimation.
    - boltzmann_temperature (float): the temperature to use for the Boltzmann exploration
    - max_attempts (int): the maximum number of resampling attempts to find a second bundle different from the first
      (boltzmann and doubleTS); afterwards a random different bundle is used
    - seed: must be None; seeding is not implemented yet

    Returns:
    - bundle_tuples (list): a list of tuples, where each tuple contains two bundles. Those are the CQs that the LLM should answer in this iteration.

    Raises:
    - ValueError: if seed is not None or acquisition_function is unknown.
    """
    bundle_tuples = []
    number_of_courses = capacities.shape[0]

    if seed is not None: 
        raise ValueError('Seed is not None, which is not properly implemented yet')

    if acquisition_function == 'random': 
        bundles = generate_random_bundles(number_of_bundles= 500, number_of_courses= number_of_courses, seed = seed)  # generate the bundles that we will use 

        for _ in range(number_of_cqs):
            bundle1 = random.choice(bundles)
            # uniform over the bundles that differ from bundle1 (terminates even if none differ)
            bundle2 = _random_bundle_different_from(bundles, bundle1)
            bundle_tuples.append((bundle1, bundle2))

    elif acquisition_function == 'boltzmann':
        value_model = model_list[0]  # only use the first model for the Boltzmann exploration, we do not use bagging here.
        for cq_number in range(number_of_cqs):
            print(f'Generating CQ number: {cq_number}')

            # 1. generate the candidate bundles for the CQ
            bundles = generate_random_bundles(number_of_bundles= number_of_candidate_bundles, number_of_courses= number_of_courses, seed = seed)

            # 2. get the value of the bundles from the value model
            values = np.asarray(model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= value_model, bundles= bundles), dtype=float).ravel()

            # 3. Boltzmann (softmax) probabilities; subtracting the max logit avoids
            #    overflow in exp() without changing the resulting distribution
            logits = values / boltzmann_temperature
            weights = np.exp(logits - np.max(logits))
            probabilities = weights / np.sum(weights)

            # 4./5. choose both bundles, resampling the second while it equals the first
            bundle1 = random.choices(bundles, probabilities)[0]
            bundle2 = random.choices(bundles, probabilities)[0]
            attempt = 0
            while np.array_equal(bundle1, bundle2) and attempt < max_attempts:
                bundle2 = random.choices(bundles, probabilities)[0]
                attempt += 1

            # Only fall back if the resampling actually failed
            if np.array_equal(bundle1, bundle2):
                print(f'Warning: could not find a different bundle after {max_attempts} attempts, choosing randomly')
                bundle2 = _random_bundle_different_from(bundles, bundle1)

            bundle_tuples.append((bundle1, bundle2))

    elif acquisition_function == 'infomax':
        for _ in range(number_of_cqs): 

            # 1. generate the candidate bundles for the CQ
            bundles = generate_random_bundles(number_of_bundles= number_of_candidate_bundles, number_of_courses= number_of_courses, seed = None)

            # all unordered pairs (i, j) with i < j, in the same order as a nested i/j loop
            pair_rows, pair_cols = np.triu_indices(len(bundles), k=1)

            probabilities_all_models = []
            for single_model in model_list:
                # 2. get the value of the bundles from all the models included in the collection 
                values = np.asarray(model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= single_model, bundles= bundles), dtype=float).ravel()

                # 3. Bradley-Terry probability that bundle i is preferred to bundle j, for every pair
                value_diff = values[pair_rows] - values[pair_cols]
                probabilities_all_models.append(1 / (1 + np.exp(-value_diff)))

            # for each bundle tuple, the variance of the predicted probabilities across models
            variance_probabilities = np.var(np.stack(probabilities_all_models), axis= 0)

            # 4. choose the bundle tuple with the highest variance
            max_variance_index = int(np.argmax(variance_probabilities))
            bundle_tuples.append((bundles[pair_rows[max_variance_index]], bundles[pair_cols[max_variance_index]]))

    elif acquisition_function == 'doubleTS':
        for _ in range(number_of_cqs): 

            # 1. generate the candidate bundles for the CQ
            bundles = generate_random_bundles(number_of_bundles= number_of_candidate_bundles, number_of_courses= number_of_courses, seed = None)

            # sample the index of the model to use for the first TS
            model_index = random.randint(0, len(model_list) - 1)
            single_model = model_list[model_index]

            # 2./3. the first bundle is the argmax under the sampled model
            bundle_values = model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= single_model, bundles= bundles)
            bundle1 = bundles[np.argmax(bundle_values)]

            bundle2 = bundle1
            attempt = 0
            # compare contents, not indices: the candidate set may contain duplicate bundles
            while np.array_equal(bundle1, bundle2) and attempt < max_attempts: 
                # sample the index of the model to use for the second TS
                model_index = random.randint(0, len(model_list) - 1)
                print(f'DoubleTS using model index: {model_index}')
                single_model = model_list[model_index]

                bundle_values = model_predict_values(model_type= model_type, model_param_dictionary= model_param_dictionary, model= single_model, bundles= bundles)
                bundle2 = bundles[np.argmax(bundle_values)]

                attempt += 1

            if np.array_equal(bundle1, bundle2):
                print(f'Warning: DoubleTS could not find a different bundle after {max_attempts} attempts')
                bundle2 = _random_bundle_different_from(bundles, bundle1)
            else: 
                print(f'DoubleTS found a different bundle after {attempt} attempts')
            bundle_tuples.append((bundle1, bundle2))

    else:
        raise ValueError(f'Unknown acquisition function: {acquisition_function}')

    return bundle_tuples


def simulate_llm_answers_cqs(bundle_tuples, model_type, model_param_dictionary, 
            benchmark_model_type, benchmark_student, model_student_list,
            llm_simulation_dictionary = None
            ): 
    """
    Simulate the LLM's answers to the CQs using the benchmark student's values.

    Args:
        bundle_tuples: list of (bundle1, bundle2) pairs.
        model_type, model_param_dictionary, model_student_list: not used here;
            kept for interface compatibility.
        benchmark_model_type: model type passed to ``model_predict_values`` for
            the benchmark student.
        benchmark_student: the value model treated as ground truth.
        llm_simulation_dictionary (dict | None): how replies are simulated.
            Defaults to ``{'reply_encoding': 'hard', 'hard_labels_flip_probability': 0.1}``.
            - 'hard': label 1 if bundle1 has strictly higher value, else 0; flipped
              with probability ``'hard_labels_flip_probability'``.
            - 'soft': Bradley-Terry probability sigmoid(v1 - v2), plus Gaussian
              noise with std ``'soft_labels_noise_std'``, clipped to [0, 1], then
              sharpened (MixMatch) with temperature ``'temperature-smoothening'``.

    Returns:
        tuple: (np.array(bundle_tuples), np.array(labels), value) — the
        X_train_ord / y_train_ord for training the TL MVNN, plus a list with
        ``[value_of_bundle1, value_of_bundle2]`` for every CQ.

    Raises:
        ValueError: if ``'reply_encoding'`` is neither 'hard' nor 'soft'.
    """
    if llm_simulation_dictionary is None:
        # fresh dict per call (avoids a shared mutable default)
        llm_simulation_dictionary = {'reply_encoding': 'hard', 'hard_labels_flip_probability': 0.1}

    labels = []
    value = []
    for (bundle1, bundle2) in bundle_tuples:

        # get the true value of the bundles involved in each CQ 
        true_values = model_predict_values(model_type= benchmark_model_type, model_param_dictionary= None, model= benchmark_student, bundles= [bundle1, bundle2])

        value_diff = true_values[0] - true_values[1]
        reply_encoding = llm_simulation_dictionary['reply_encoding']

        if reply_encoding == 'hard':
            flip_probability = llm_simulation_dictionary['hard_labels_flip_probability']
            true_label = 1 if value_diff > 0 else 0

            # flip the label with the given probability
            if random.random() < flip_probability:
                true_label = 1 - true_label
            labels.append(true_label)

        elif reply_encoding == 'soft':
            temperature = llm_simulation_dictionary['temperature-smoothening']
            p = 1 / (1 + np.exp(-value_diff)) # the *true* probability that the first bundle is preferred to the second bundle according to a perfectly tuned MVNN 

            # add noise to the probability 
            p = p + np.random.normal(0, llm_simulation_dictionary['soft_labels_noise_std']) 
            p = np.clip(p, 0, 1)  # clip the probability to be between 0 and 1

            # MixMatch sharpening of the two-class distribution (p, 1-p):
            #   p^(1/T) / (p^(1/T) + (1-p)^(1/T)) == sigmoid(logit(p) / T)
            # The logit form avoids 0/0 underflow for small T; p in {0, 1} maps to itself.
            with np.errstate(divide='ignore', over='ignore'):
                logit = np.log(p) - np.log1p(-p)
                p = 1 / (1 + np.exp(-logit / temperature))
            labels.append(p)

        else: 
            raise ValueError('Invalid reply encoding for the LLM CQs')

        # keep `value` aligned with `labels` for both encodings
        value.append([true_values[0], true_values[1]])

    return np.array(bundle_tuples), np.array(labels), value   # those are the X_train_ord, y_train_ord for training the TL MVNN 


def main_function():
   
    # --- Mechanismp parameterds, no need to change --- #
    parser = argparse.ArgumentParser()
    parser.add_argument('--supply_ratio', type=float, default= 1.25, help='supply_ratio')
    parser.add_argument('--number_of_popular', type= int, default= 9, help='afffects_correlation')
    parser.add_argument('--model_family', type= str, default= 'BPpop9v.1', help='the models to run') # options: [PO/B]pop9_projected_v1.1_from_scratch_gui_reports, 'BPpop9v.1' and 'POpop9'
    parser.add_argument('--linear_instances', type= str, default= 'false', help='whether we are in linear instances or not')
    parser.add_argument('--seed', type= int, default= 627, help='the seed to use for the random number generator')
    parser.add_argument('--cq_method', type= str, default= 'basic_plus', help='acquisition function to use') # options: 'basic', 'random', 'basic_plus', 'complete_ordering', 'complete_ordering_pruned'
    parser.add_argument('--wandb_tracking', type= str, default= 'true', help='whether to track the results with wandb')
    parser.add_argument('--TL_model', type= str, default= 'true', help='whether to use transfer learning model')
    parser.add_argument('--project_to_gui', type= str, default= 'false', help='whether to log details about the MVNNs projected to the GUI language')
    parser.add_argument('--load_students', type= str, default= 'false', help='whether to load the students from the file, if not, they will be created')
    parser.add_argument('--student_to_check', type= int, default= 42, help='the student that this run should check')

    # --- new arguments for the simulation of LLM CQ replies --- #
    parser.add_argument('--llm_cq_labels', type= str, default= 'hard', help='the method to use for generating the LLM CQ labels: options: hard, soft')  
    parser.add_argument('--llm_hard_labels_flip_probability', type= float, default= 0.2, help='the probability of flipping the label in the hard label case')
    parser.add_argument('--llm_soft_labels_noise_std', type= float, default= 0.00, help='the std of the noise to add to the soft labels')
    parser.add_argument('--llm_temperature_smoothening', type= float, default= 1, help='the temperature for the smoothening of the labels in the soft label case')
    
    # --- new arguments for learning on those LLM CQs --- #
    parser.add_argument('--llm_cq_loss_string', type= str, default= 'BCE', help='the loss function to use for the LLM CQs')  # should add robust to noise versions 
    parser.add_argument('--llm_cq_epochs', type= int, default= 50, help='the number of epochs to use for the LLM CQs')
    parser.add_argument('--llm_cq_lr', type= float, default= 0.005, help='the learning rate to use for the LLM CQs')
    parser.add_argument('--llm_cq_wd', type= float, default= 0.001, help='the regularization to use for the LLM CQs')
    parser.add_argument('--llm_cq_clip', type= float, default= 10, help='the gradient clipping to use for the LLM CQs')
    parser.add_argument('--llm_cq_batch_size', type= int, default= 1, help='the batch size to use for the LLM CQs')

    # parser.add_argument('--llm_cq_trainsize', type= int, default= 5000, help='the number of training samples to use for the LLM CQs')

    # --- Making things more principled/ arguments for the LLM acquisition function --- # 
    parser.add_argument('--llm_cq_acquisition_function', type= str, default= 'doubleTS', help='the acquisition function to use for generating the LLM CQs') # options: random, boltzmann, infomax, doubleTS
    parser.add_argument('--llm_cq_bsize', type= int, default= 10, help='the number of CQs to ask the LLM between training')
    parser.add_argument('--llm_cq_iterations', type= int, default= 50, help='the number of feedback iterations to do with the LLM CQs')
    parser.add_argument('--llm_cq_candidates', type= int, default= 100, help='the number of candidate bundles from which to build the LLM CQs')
    parser.add_argument('--boltzmann_temperature', type= float, default= 0.5, help='the temperature to use for the Boltzmann exploration')
    parser.add_argument('--value_model_collection_size', type= int, default= 10, help='the number of models to use for the value model uncertainty estimation')  # only applies to uncertainty-based acquisition functions: infomax, doubleTS 



    args = parser.parse_args()

    # --- Environment-related arguments --- # 
    supply_ratio = args.supply_ratio
    number_of_popular = args.number_of_popular
    index = args.student_to_check // 100
    student_index = args.student_to_check % 100
    model_family = args.model_family
    queries = '30-10'  # only needed to load the dictionaries, does not play any role 
    linear_instances = str(args.linear_instances).lower() == 'true'
    seed = args.seed 
    wandb_tracking = str(args.wandb_tracking).lower() == 'true'
    tl_model = str(args.TL_model).lower() == 'true'
    project_to_gui = str(args.project_to_gui).lower() == 'true'
    load_students = str(args.load_students).lower() == 'true'

    # --- LLM CQs-related arguments --- #
    llm_cq_labels = args.llm_cq_labels
    llm_hard_labels_flip_probability = args.llm_hard_labels_flip_probability
    llm_soft_labels_noise_std = args.llm_soft_labels_noise_std
    llm_temperature_smoothening = args.llm_temperature_smoothening
    llm_cq_loss_string = args.llm_cq_loss_string
    llm_cq_epochs = args.llm_cq_epochs
    llm_cq_lr = args.llm_cq_lr
    llm_cq_wd = args.llm_cq_wd
    llm_cq_clip = args.llm_cq_clip
    llm_cq_batch_size = args.llm_cq_batch_size

    # --- Model-related arguments --- #
    students_to_check = 1


    # --- Acquisition function-related arguments --- #
    llm_cq_acquisition_function = args.llm_cq_acquisition_function
    llm_cq_bsize = args.llm_cq_bsize
    llm_cq_iterations = args.llm_cq_iterations
    llm_cq_candidates = args.llm_cq_candidates
    boltzmann_temperature = args.boltzmann_temperature
    value_model_collection_size = args.value_model_collection_size


    if tl_model:
        # --- Transfer Learning Model (when we had the student answer the CQs without any noise) --- #
        # model_info = ('UNN_transfer_learning', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 
        #     'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
        # 'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
        # 'sample_relative_frequencies': (1, 1, 1, 1),
        # 'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        # 'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8,  'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1,
        # 'lr_ordinal':  0.01, 'weight_decay_ordinal': 0, 'epochs_ordinal': 10, 'batch_size_ordinal': 8,  'UNN_loss_string_ordinal': 'BCE', 'clip_ordinal': 0.1,
        # 'use_implied_dataset': True,
        # 'use_cqs': True, 'cq_method': 'complete_ordering_pruned'
        #                             })  # 'UNN_transfer_learning'

        model_info = ('UNN_transfer_learning', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 
            'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
        'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
        'sample_relative_frequencies': (1, 1, 1, 1),
        'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        'use_implied_dataset': True,'use_cqs': True, 'cq_method': 'complete_ordering_pruned',
        'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8,  'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1,
        # --- HPs that refer to the (used to be student) CQs --- # 
        'lr_ordinal': llm_cq_lr, 'weight_decay_ordinal': llm_cq_wd, 'epochs_ordinal': llm_cq_epochs, 'batch_size_ordinal': llm_cq_batch_size, 
        'UNN_loss_string_ordinal': llm_cq_loss_string, 'clip_ordinal': llm_cq_clip,
                                    })  # 'UNN_transfer_learning'
        
        

    # set_trace()
    
    if linear_instances:
        true_family = 'true_linear'
    else:
        true_family = 'true'

    result_dict = {} # will save the exact same details as the wandb tracking, but in a dictionary that we can save to a pickle file

    # llm_cq_acquisition_function = args.llm_cq_acquisition_function
    # llm_cq_bsize = args.llm_cq_bsize
    # llm_cq_iterations = args.llm_cq_iterations
    # llm_cq_candidates = args.llm_cq_candidates
    # boltzmann_temperature = args.boltzmann_temperature
    # value_model_collection_size = args.value_model_collection_size

    if wandb_tracking:
        wandb_config_dict = {
            'Supply Ratio': supply_ratio,
            'Number of Popular': number_of_popular,
            'Linear Instances': linear_instances,
            'Student Number': args.student_to_check,
            # 'Acquisition Function': args.cq_method,
            'Model Type': model_info[0], 
            # 'llm_cq_method': llm_cq_method,
            'llm_cq_labels': llm_cq_labels,
            'llm_hard_labels_flip_probability': llm_hard_labels_flip_probability,
            'llm_soft_labels_noise_std': llm_soft_labels_noise_std,
            'llm_temperature_smoothening': llm_temperature_smoothening,
            'llm_cq_loss_string': llm_cq_loss_string,
            'llm_cq_epochs': llm_cq_epochs,
            'llm_cq_lr': llm_cq_lr,
            'llm_cq_wd': llm_cq_wd,
            'llm_cq_clip': llm_cq_clip,
            'llm_cq_batch_size': llm_cq_batch_size,  
            'Acquisition Function': llm_cq_acquisition_function, 
            'llm_cq_bsize': llm_cq_bsize,
            'llm_cq_iterations': llm_cq_iterations,
            'llm_cq_candidates': llm_cq_candidates,
            'Boltzmann Temperature': boltzmann_temperature,
            'Value Model Collection Size': value_model_collection_size
        }

        # run = wandb.init(project=f'MLCM-LLMs-v1.2', # TODO: change to appropriate name
        #             config=wandb_config_dict,
        #             reinit=True)
        run = wandb.init(project=f'MLCM-GPT-v2.0', 
                    config=wandb_config_dict,
                    reinit=True)


        wandb.define_metric("Elicited CQs") 
        
        # Standard Learning Metrics 
        wandb.define_metric("KT", step_metric="Elicited CQs")
        wandb.define_metric("R2", step_metric="Elicited CQs")
        wandb.define_metric("MAE", step_metric="Elicited CQs")
        wandb.define_metric("MSE", step_metric="Elicited CQs")
        wandb.define_metric("KT top 5", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 5", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 5", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 5", step_metric="Elicited CQs")
        wandb.define_metric("KT top 10", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 10", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 10", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 10", step_metric="Elicited CQs")

        # also add centered version of all above metrics 
        wandb.define_metric("R2 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MAE - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MSE - Centered", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 5 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 5 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 5 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 10 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 10 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 10 - Centered", step_metric="Elicited CQs")
        wandb.define_metric("LLM_Correct_prob", step_metric="Elicited CQs")

        # GUI Metrics
        if project_to_gui:
            wandb.define_metric("Base Values", step_metric="Elicited CQs")
            wandb.define_metric("Adjustments", step_metric="Elicited CQs")
            wandb.define_metric("KT - GUI", step_metric="Elicited CQs")
            wandb.define_metric("R2 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MAE - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MSE - GUI", step_metric="Elicited CQs")
            wandb.define_metric("KT top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("R2 top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MAE top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MSE top 5 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("KT top 10 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("R2 top 10 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MAE top 10 - GUI", step_metric="Elicited CQs")
            wandb.define_metric("MSE top 10 - GUI", step_metric="Elicited CQs")

        # Other metrics
        wandb.define_metric("Ordinal Dataset Size", step_metric="Elicited CQs")
        wandb.define_metric("Agreement on CQs", step_metric="Elicited CQs")
        wandb.define_metric("Allocated Bundle Value", step_metric="Elicited CQs")
        wandb.define_metric("Found New Max", step_metric="Elicited CQs")  # note: this mteric is only useful for the basic plus method 
        


    # --- Step 1. Load true and GUI (noisy) student lists, as well as prices --- #
    if load_students:
        print(f'supply_ratio: {supply_ratio} number_of_popular: {number_of_popular} index:{index} model_family: {model_family} queries: {queries}')
        (_, true_model_info, model_student_list, true_student_list,  prices_stage_1_GUI, timetable, capacities) = load_all_models_v2(index = index, model_family= model_family, benchmark_family = true_family, queries = '30-10', supply_ratio= supply_ratio, number_of_popular= number_of_popular, linear_instances = linear_instances)
    
        # set_trace()

    else:
        # create the model instance from scratch 
        true_student_lists_all_instances, capacities_all_instances, timetables_all_instances = generate_problem_instance_principled(number_of_times= 1, number_of_students= 100, number_of_courses= 25,
                supply_ratio= supply_ratio, number_of_popular= number_of_popular, seed= seed + student_index)
        
        true_student_list = true_student_lists_all_instances[0]
        capacities = capacities_all_instances[0]
        timetable = timetables_all_instances[0]

        prices_stage_1_GUI = np.array([0.2 for _ in range(25)])  # just a price vector so that each student can take 5 courses 
        
        if not linear_instances:
            model_info = ('UNN_transfer_learning', 
            {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True, 
          'gui_forget_base': 0.5, 'gui_forget_adjustments': 0.48, 'gui_base_noise_std': 23, 'gui_adj_std': 0.2, 'gui_sample_categories_weights': [1, 1, 1], 
          'sample_relative_frequencies': (1, 1, 1, 1), 'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None, 'use_implied_dataset': True, 
          'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8, 'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1, 
          'lr_ordinal': 0.01, 'weight_decay_ordinal': 0.0, 'epochs_ordinal': 10, 'batch_size_ordinal': 8, 'UNN_loss_string_ordinal': 'BCE', 'clip_ordinal': 0.1})
            true_model_info = ('True', None)
    

    
    
    model_type = model_info[0]
    model_param_dictionary = model_info[1]
    number_of_courses = capacities.shape[0]
    model_param_dictionary['cq_method'] = args.cq_method

    # add extra information to the model_param_dictionary for projection. 
    projection_dictionary = {'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 500, 'train_high_samples': 500,
        'linear_projection': linear_instances}
    model_param_dictionary.update(projection_dictionary)
    
    # true_student_list = [true_student_list[student_index]]
    true_student_list = true_student_list[student_index: student_index + students_to_check]

    # set the seeds for reproducability.  
    torch.manual_seed(seed + student_index)
    np.random.seed(seed + student_index)
    random.seed(seed + student_index)


    # --- Step 2. Generate the bundles that we will use for the generalization test --- # 
    
    bundles = generate_random_bundles(number_of_bundles= 2000, number_of_courses= 25, seed = seed)
    bundles_5_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 95, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])
    bundles_10_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 90, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])


    # --- Step 3. Generate the 'initial' training dataset based on the GUI reports --- #

    actual_student_list = create_actual_student_list(true_student_list, model_type = model_info[0], model_param_dictionary = model_param_dictionary, seed = seed)  # just returns the true list: NOTE: should remove
    gui_student_list = create_gui_student_list(true_student_list, model_type = model_type, model_param_dictionary = model_param_dictionary, seed = seed)


    # ---- ADDED STEP: GET PROMPT FOR LLM TO SIMULATE STUDENT'S ANSWER ------ #
    # note: not using timetable right now, needs to be added to the function call if we implement it someday.
    student_prompt = f"{generate_student_prompt(actual_student_list[0])}"
    print("Prompt for text generation: ", student_prompt)
    # query the LLM
    q_des = QuestionFreeText(
        question_name = "describe",
        question_text = student_prompt,
    )
    s = Survey([q_des])
    results = s.run()
    output = results.select("describe")
    # Inspect output
    print("LLM output: ", output)
    student_text = output


    
    training_set = generate_new_queries(actual_student_list = actual_student_list, timetable = timetable, number_of_samples = 0,  # number of samples is not used when generating the original dataset. Instead, it is read from the model_dictionary
                    current_training_set = [], model_student_list = [], approximate_prices = None,
                    credit_units = [1 for i in range(number_of_courses)], model_type = model_type, seed= seed + args.student_to_check, model_param_dictionary = model_param_dictionary, gui_student_list = gui_student_list)


    # --- Step 3.5  Generate initial student list based on the initial training set   --- # 
    if llm_cq_acquisition_function == 'infomax' or llm_cq_acquisition_function == 'doubleTS':
        model_student_lists = []
        # if the model requires bagging, we will create multiple models and use them to generate the CQs
        for i in range(value_model_collection_size):
            print(f'Creating model {i} for the collection')
            # set_trace()
            model_student_list = create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                            timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                            model_student_list= None)
            
            model_student_lists.append(model_student_list)
    
    else:
        # if the model does not require bagging, we will only create one model and use it to generate the CQs
        model_student_lists = [create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                            timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                            model_student_list= None)]  
    # set_trace()
        
    
    
    # --- Initial metric calculation for 0 CQs --- #
    elicited_llm_cqs = 0
    # Measure initial allocation value of the model
    allocation_model = calculate_allocation(model_list=model_student_lists[0], prices=prices_stage_1_GUI, timetable=timetable, models_to_run=[model_info])
    allocation_value_model = get_true_value_of_allocation(benchmark_student_list=true_student_list, individual_demands=allocation_model, benchmark_model=true_model_info[0], timetable=timetable)
    # Measure generalization performance of the model (0 CQs answered)
    # Measure generalization performance of the ML models and GUI reports
    kts_model, r2s_model, maes_model, mses_model = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
    kts_model_top_5, r2s_model_top_5, maes_model_top_5, mses_model_top_5 = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
    kts_model_top_10, r2s_model_top_10, maes_model_top_10, mses_model_top_10 = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
    # measure the centered versions of the metrics
    _, r2s_model_centered, maes_model_centered, mses_model_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
    _, r2s_model_top_5_centered, maes_model_top_5_centered, mses_model_top_5_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
    _, r2s_model_top_10_centered, maes_model_top_10_centered, mses_model_top_10_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                    model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)

    # set_trace()
    
    wandb_dict = {
            'Elicited CQs': elicited_llm_cqs,
            'KT': np.mean(kts_model),
            'R2': np.mean(r2s_model),
            'MAE': np.mean(maes_model),
            'MSE': np.mean(mses_model),
            'KT top 5': np.mean(kts_model_top_5),
            'R2 top 5': np.mean(r2s_model_top_5),
            'MAE top 5': np.mean(maes_model_top_5),
            'MSE top 5': np.mean(mses_model_top_5),
            'KT top 10': np.mean(kts_model_top_10),
            'R2 top 10': np.mean(r2s_model_top_10),
            'MAE top 10': np.mean(maes_model_top_10),
            'MSE top 10': np.mean(mses_model_top_10),
            'R2 - Centered': np.mean(r2s_model_centered),
            'MAE - Centered': np.mean(maes_model_centered),
            'MSE - Centered': np.mean(mses_model_centered),
            'R2 top 5 - Centered': np.mean(r2s_model_top_5_centered),
            'MAE top 5 - Centered': np.mean(maes_model_top_5_centered),
            'MSE top 5 - Centered': np.mean(mses_model_top_5_centered),
            'R2 top 10 - Centered': np.mean(r2s_model_top_10_centered),
            'MAE top 10 - Centered': np.mean(maes_model_top_10_centered),
            'MSE top 10 - Centered': np.mean(mses_model_top_10_centered),
            'Allocated Bundle Value': np.mean(allocation_value_model),
        }
    
    if wandb_tracking:
        wandb.log(wandb_dict)

    
    for cq_iteration in range(1, llm_cq_iterations + 1):
        elicited_llm_cqs = cq_iteration * llm_cq_bsize
        print(f'--------> Number of elicited LLM CQs: {elicited_llm_cqs}')
        # set_trace()

        
        # --- Step 4: Generate the LLM CQs and get their answers --- # 
        for i in range(students_to_check):
            # set_trace()
            cq_bundle_tuples = generate_llm_cqs(number_of_cqs= llm_cq_bsize, 
                                acquisition_function= llm_cq_acquisition_function, capacities= capacities, 
                                number_of_candidate_bundles= llm_cq_candidates, 
                                model_type = model_info[0], model_param_dictionary = model_info[1], 
                                model_list = [model_student_list[i] for model_student_list in model_student_lists], 
                                boltzmann_temperature= boltzmann_temperature, 
                                max_attempts= 50)
            
            # set_trace()
                                                

        
            # cq_bundle_tuples, labels = simulate_llm_answers_cqs(bundle_tuples= cq_bundle_tuples, model_type= model_info[0], 
            #     model_param_dictionary= model_info[1], benchmark_model_type= true_model_info[0], benchmark_student= true_student_list[i], model_student_list= model_student_lists[0],
            #     llm_simulation_dictionary = {'reply_encoding': llm_cq_labels, 'hard_labels_flip_probability': llm_hard_labels_flip_probability,
            #                                 'soft_labels_noise_std': llm_soft_labels_noise_std, 'temperature-smoothening': llm_temperature_smoothening})  
            
            _, true_labels, value_list = simulate_llm_answers_cqs(bundle_tuples= cq_bundle_tuples, model_type= model_info[0], 
                model_param_dictionary= model_info[1], benchmark_model_type= true_model_info[0], benchmark_student= true_student_list[i], model_student_list= model_student_lists[0],
                llm_simulation_dictionary = {'reply_encoding': 'hard', 'hard_labels_flip_probability': 0.,
                                            'soft_labels_noise_std': 0., 'temperature-smoothening': llm_temperature_smoothening})  
                                                       

            questions, response, labels = generate_comparison_labels(cq_bundle_tuples, student_text)
            agreements = (true_labels == labels)
            num_agreements = np.sum(agreements)
            percentage = (num_agreements / len(labels)) * 100
            print("CORRECT PERCENTAGE: ", percentage)

            # Create a DataFrame
            data = {
                "Questions": questions,
                "response": response,
                "Labels1": labels,
                "true": true_labels,
                "value": value_list,
                "Agreements": agreements
            }
            df = pd.DataFrame(data)

            # Save to CSV
            file_name = f"llm_cq_result_{elicited_llm_cqs}.csv"
            file_path = f"llm_accuracy/{file_name}"
            df.to_csv(file_path, index=False)

            # set_trace()
            labels = np.array(labels)
            cq_bundle_tuples = np.array(cq_bundle_tuples)
            

        
             # Add those to the training set
            # training_set[i][1] = (cq_bundle_tuples, labels)
            if len(training_set[i][1][0]) == 0:
                # set_trace()
                training_set[i][1] = (cq_bundle_tuples, labels)
            else:
                # set_trace()
                all_tuples = np.vstack([training_set[i][1][0], cq_bundle_tuples])
                all_labels = np.hstack([training_set[i][1][1], labels])
                training_set[i][1] = (all_tuples, all_labels)

            

        # Step 4. Train the models on the current dataset -- NOTE: no need to change anything below this point 
        if llm_cq_acquisition_function == 'infomax' or llm_cq_acquisition_function == 'doubleTS':
            # model_student_lists = []
        # if the model requires bagging, we will create multiple models and use them to generate the CQs
            for model_index in range(value_model_collection_size):
                print(f'Creating model {model_index} for the collection')
                model_student_lists[model_index] = create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                                timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                                model_student_list= model_student_lists[model_index])
                # model_student_lists.append(model_student_list)
        else:
            model_student_lists = [create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                        timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                        model_student_list= model_student_lists[0])]
            
        
        # set_trace()
        
        
        # Step 5. Measure the performance of the models
    
        # Step 5a) Measure allocation value of the model 
        allocation_model = calculate_allocation(model_list= model_student_lists[0], prices= prices_stage_1_GUI, timetable= timetable, models_to_run= [model_info])
        allocation_value_model = get_true_value_of_allocation(benchmark_student_list= true_student_list, individual_demands= allocation_model, benchmark_model= true_model_info[0], timetable = timetable)
        # set_trace()

        

        # Step 6.a) Measure generalization performance of the ML models and GUI reports
        kts_model, r2s_model, maes_model, mses_model = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
        kts_model_top_5, r2s_model_top_5, maes_model_top_5, mses_model_top_5 = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
        kts_model_top_10, r2s_model_top_10, maes_model_top_10, mses_model_top_10 = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0])
        

        # measure the centered versions of the metrics
        _, r2s_model_centered, maes_model_centered, mses_model_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
        _, r2s_model_top_5_centered, maes_model_top_5_centered, mses_model_top_5_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)
        _, r2s_model_top_10_centered, maes_model_top_10_centered, mses_model_top_10_centered = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                        model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_lists[0], centered_metrics= True)

        # set_trace()
    
        
        # # Step6.b) Measure number of agreements between the GUI reports and the true preferences
        # agreements = check_agreements_on_cq(model_student_list= model_student_list, model_info= model_info,benchmark_student_list= true_student_list, benchmark_model_info= true_model_info, cqs= last_queries, timetable= timetable)
        
        ordinal_info_size = get_ordinal_dataset_size(training_dataset= training_set, model_param_dictionary= model_param_dictionary)

        # Step 7. Project the MVNNs to measure results in the GUI language
        if project_to_gui:
            projected_mvnns = project_mvnns(mvnn_student_list= model_student_list, model_param_dictionary= model_param_dictionary)

            # Measure the number of base values and adjustments 
            base_values_model, adjustments_model = measure_gui_language_metrics(model_list = projected_mvnns, model_type = 'UNN_projected')
        
    
            # Measure similarity between projected ML models/GUI reports and the true (additive) preferences
            gui_language_kts_model_top10, gui_language_r2s_model_top10, gui_language_maes_model_top10, gui_language_mses_model_top10 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 60)
            gui_language_kts_model, gui_language_r2s_model, gui_language_maes_model, gui_language_mses_model = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances)
            gui_language_kts_model_top5, gui_language_r2s_model_top5, gui_language_maes_model_top5, gui_language_mses_model_top5 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 80)
        
        
        
        wandb_dict = {
            'Elicited CQs': elicited_llm_cqs,
            'KT': np.mean(kts_model),
            'R2': np.mean(r2s_model),
            'MAE': np.mean(maes_model),
            'MSE': np.mean(mses_model),
            'KT top 5': np.mean(kts_model_top_5),
            'R2 top 5': np.mean(r2s_model_top_5),
            'MAE top 5': np.mean(maes_model_top_5),
            'MSE top 5': np.mean(mses_model_top_5),
            'KT top 10': np.mean(kts_model_top_10),
            'R2 top 10': np.mean(r2s_model_top_10),
            'MAE top 10': np.mean(maes_model_top_10),
            'MSE top 10': np.mean(mses_model_top_10),
            'R2 - Centered': np.mean(r2s_model_centered),
            'MAE - Centered': np.mean(maes_model_centered),
            'MSE - Centered': np.mean(mses_model_centered),
            'R2 top 5 - Centered': np.mean(r2s_model_top_5_centered),
            'MAE top 5 - Centered': np.mean(maes_model_top_5_centered),
            'MSE top 5 - Centered': np.mean(mses_model_top_5_centered),
            'R2 top 10 - Centered': np.mean(r2s_model_top_10_centered),
            'MAE top 10 - Centered': np.mean(maes_model_top_10_centered),
            'MSE top 10 - Centered': np.mean(mses_model_top_10_centered),
            'Allocated Bundle Value': np.mean(allocation_value_model),
            "LLM_Correct_prob": percentage,
        }
        # if model_param_dictionary['cq_method'] == 'basic_plus':
        #     wandb_dict['Found New Max'] = training_set[0][5]

        if project_to_gui:
            gui_dict = {
                'Base Values': base_values_model,
                'Adjustments': adjustments_model,
                'KT - GUI': gui_language_kts_model,
                'R2 - GUI': gui_language_r2s_model,
                'MAE - GUI': gui_language_maes_model,
                'MSE - GUI': gui_language_mses_model,
                'KT top 5 - GUI': gui_language_kts_model_top5,
                'R2 top 5 - GUI': gui_language_r2s_model_top5,
                'MAE top 5 - GUI': gui_language_maes_model_top5,
                'MSE top 5 - GUI': gui_language_mses_model_top5,
                'KT top 10 - GUI': gui_language_kts_model_top10,
                'R2 top 10 - GUI': gui_language_r2s_model_top10,
                'MAE top 10 - GUI': gui_language_maes_model_top10,
                'MSE top 10 - GUI': gui_language_mses_model_top10,
            }

            wandb_dict.update(gui_dict)

        result_dict[elicited_llm_cqs] = wandb_dict

        # set_trace()
        if wandb_tracking:
            wandb.log(wandb_dict)

            

        print(f'--------> Number of LLM CQs: {elicited_llm_cqs}, ordinal_info_size: {ordinal_info_size}, allocation values: {allocation_value_model}')

        # set_trace()

    if wandb_tracking: 
        run.finish()

    # save the results to a pickle file
    # dict_name = f'./acquisition_function_results/SR_{supply_ratio}_NP_{number_of_popular}_LI_{linear_instances}_AF_{args.cq_method}_student_{args.student_to_check}.pkl'
    # open the file for writing, and create it if it does not exist
    # with open(dict_name, 'wb+') as f:
    #     pickle.dump(result_dict, f)

    return 

if "__main__" == __name__:
    # Script entry point: run the driver only when this file is executed
    # directly (e.g. `python <file>.py`), not when it is imported as a module.
    main_function()
