import numpy as np
import pickle
# import math
# import gurobipy as gp
from cleanup import student 

import argparse
import torch
import scipy.stats as stats
from sklearn.metrics import r2_score
from sklearn.preprocessing import PolynomialFeatures
import math 
import wandb 
from model_comparison import load_all_models_v2, get_true_value_of_allocation, generate_random_bundles, model_predict_values, measure_generalization_performance_all_students, keep_high_valued_bundles, calculate_allocation, measure_gui_language_metrics, compare_gui_reports
from tabu_search import binary_search, binary_search_hacky, construct_implied_dataset_v2, solve_student, create_actual_student_list, create_gui_student_list, create_iterative_student_list, project_mvnns
import copy 
import wandb 
import random 

# from pdb import set_trace

def generate_new_queries(actual_student_list, timetable, number_of_samples, current_training_set, model_student_list,
                approximate_prices, credit_units, model_type = 'LinearRegression', seed = 42, model_param_dictionary = {}, gui_student_list = [], get_last_query = False):
    if (approximate_prices is None):     # approximate prices == None -> drawing samples for the first time
        print('NEW QUERIES generating dataset from scratch')
        new_training_set = []
        if (model_param_dictionary.get('use_implied_dataset', False)):
            print('NEW QUERIES GENERATING IMPLIED DATASET from GUI reports, COMPARSION QUERIES next!')
            for (student_counter, (base_values, substitutes, complements, unforgotten_base_values)) in enumerate(gui_student_list):
                (X_train, y_train) = construct_implied_dataset_v2(additive_prefs= base_values, substitutes= substitutes, complements= complements,
                    unforgotten_base_values= unforgotten_base_values, make_monotone= True, points_to_add= model_param_dictionary['points_to_add'],
                    points_to_hallucinate= model_param_dictionary['points_to_hallucinate'], forgotten_course_expected_value= model_param_dictionary['forgotten_course_expected_value'],
                    thompson_sampling= False, chance_actual_zero= 0, uniform_range_low = model_param_dictionary['uniform_range_low'], uniform_range_high = model_param_dictionary['uniform_range_high'],
                    sample_category_weights= None, sample_relative_frequencies= model_param_dictionary['sample_relative_frequencies'], seed= None)

                if model_param_dictionary['cq_method'] in ['basic', 'random']:

                    # desired shape for this method:
                    # (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                    # (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                    # (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                    # (x_max_current, y_max_current) = current_training_set[i][2]    # get the current best performing point, according to the true student
                    # bundles_to_forbid = current_training_set[i][3]

                    max_idx = np.argmax(y_train)
                    x_max_current = X_train[max_idx]
                    y_max_current = y_train[max_idx]
                    bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()
                    # do not ask a question about a point that you have already constructed from the GUI reports, and it is also a bundle of size 5!

                    new_training_set.append([(X_train, y_train), ([], []), (x_max_current, y_max_current), bundles_to_forbid])

                elif model_param_dictionary['cq_method'] in ['basic_plus']:
                    max_idx = np.argmax(y_train)
                    x_max_current = X_train[max_idx]
                    y_max_current = y_train[max_idx]
                    bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()
                    bundles_queried = [x_max_current] # we will add the bundles that are involved in previous CQs, so in case the new point is better, we compare against all of them 

                    new_training_set.append([(X_train, y_train), ([], []), (x_max_current, y_max_current), bundles_to_forbid, bundles_queried])

                elif model_param_dictionary['cq_method'] in ['high_pair', 'random_high_pair']:
                    bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()
                    new_training_set.append([(X_train, y_train), ([], []), [], bundles_to_forbid])

                elif model_param_dictionary['cq_method'] in ['complete_ordering', 'complete_ordering_pruned']:
                    idx_max = np.argmax(y_train)
                    x_max = X_train[idx_max]

                    (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) = actual_student_list[student_counter]
                    y_max_true = student(x_max, additive_prefs, substitutes, complements, timetable,
                                overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()

                    new_training_set.append(((X_train, y_train), (np.array([]), np.array([])), ([x_max], [y_max_true]), bundles_to_forbid, 0))

                elif model_param_dictionary['cq_method'] == 'basic_log':
                    sorted_idxs = np.argsort(y_train)
                    sorted_ys = y_train[sorted_idxs].copy().tolist()
                    sorted_Xs = X_train[sorted_idxs].copy().tolist()

                    bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()

                    new_training_set.append(((X_train, y_train), (np.array([]), np.array([])), (sorted_Xs, sorted_ys), bundles_to_forbid))


        return(new_training_set)

    else:
        print("NEW QUERIES expanding on a dataset!!!")
        print(f'Using CQs next! Cq method: {model_param_dictionary["cq_method"]}')
        sample_relative_frequencies = model_param_dictionary.get('gui_sample_relative_frequencies', None)
        last_queries = [] # New: At position i, will store the last query for the i-th agent. 

        if (model_param_dictionary['cq_method'] in ['basic', 'basic_plus']):
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the GUI dataset with CQs for student: {i}')
                if len(model_student_list[i]) == 5:
                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                else:
                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                (x_max_current, _) = current_training_set[i][2]    # get the current best performing point, according to the true student
                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                if (model_param_dictionary['cq_method'] == 'basic_plus'):
                        bundles_queried = current_training_set[i][4]   # get the bundles that you have already queried (only applicable to the basic plus method)

                X_list = []
                y_list = []

                y_max_current = student(x_max_current, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')
                print(f'get_iterative_dataset_cq called with currrent x_max: {x_max_current} and y_max: {y_max_current}')

                # add the budget constraint to the UNN solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # do not query the bundles you ahve already queried.
                print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                for bundle in bundles_to_forbid:
                    solver.add_forbidden_bundle(bundle)

                for j in range(number_of_samples):
                    try:
                        new_x = solver.solve_mip(outputFlag=False, verbose = False)

                        if np.sum(new_x) >= 5:   # do not forbid a bundle that has too few courses! -> this has as a result to also forbid many effective bundles
                            solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                            bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture

                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                        break

                    new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)
                    
                    
                    if (model_param_dictionary['cq_method'] == 'basic'):
                        new_x_ordinal = (new_x, x_max_current)  # compare the bundle that currently maximizes the MVNN to the max that you had
                        last_queries.append(new_x_ordinal)   # add the new query to the list of last queries 

                        print(f'new_y: {new_y} vs current y_max: {y_max_current}')
                        print(f'and new_x: {new_x} containing {np.sum(new_x)} courses')
                        if new_y >= y_max_current:
                            # the new point is larger than all in our current dataset -> update the info relating to the current max
                            print('Found a point that is better than the current max!')
                            new_y_ordinal = 1
                            x_max_current = new_x
                            y_max_current = new_y
                        else:
                            new_y_ordinal = 0

                        if (sample_relative_frequencies) is not None:
                            print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                            for _ in range(sample_relative_frequencies[2]):
                                X_list.append(new_x_ordinal)
                                y_list.append(new_y_ordinal)
                        else:
                            X_list.append(new_x_ordinal)
                            y_list.append(new_y_ordinal)

                    elif (model_param_dictionary['cq_method'] == 'basic_plus'):
                        last_queries.append((new_x, x_max_current))  # add the new query to the list of last queries (asked to the student)
                        if new_y <= y_max_current:
                            found_new_max = 0 
                            print('Found a point that is worse than the current max!')
                            # if the new point is not better than the old one: list behaves exactly in the same way as in the past! 
                            new_x_ordinal = (new_x, x_max_current)  # compare the bundle that currently maximizes the MVNN to the max that you had
                            new_y_ordinal = 0

                            if (sample_relative_frequencies) is not None:
                                print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                for _ in range(sample_relative_frequencies[2]):
                                    X_list.append(new_x_ordinal)
                                    y_list.append(new_y_ordinal)
                            else:
                                X_list.append(new_x_ordinal)
                                y_list.append(new_y_ordinal)

                        else:
                            # The new bundle is better than all bundles queried in the past, need to add all of those comparison queries in the dataset
                            print('Found a point that is better than the current max, adding all implied bundles too!')
                            found_new_max = 1 
                            for bundle in bundles_queried:
                                new_x_ordinal = (new_x, bundle)
                                new_y_ordinal = 1

                                if (sample_relative_frequencies) is not None:
                                    print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                    for _ in range(sample_relative_frequencies[2]):
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)
                                    
                                else:
                                    X_list.append(new_x_ordinal)
                                    y_list.append(new_y_ordinal)

                            x_max_current = new_x
                            y_max_current = new_y


                        bundles_queried.append(new_x)   # add the new bundle you just found to those that you will explicitly enforce transitivity on in the future

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                if len(X_list) >= 1:     # there are new CQs that we need to add to our dataset of comparision questions

                    if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                        X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                        y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                    else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                        X_train_ord = X_train_new
                        y_train_ord = y_train_new

                    if model_param_dictionary['cq_method'] == 'basic':
                        current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (x_max_current, y_max_current), bundles_to_forbid)
                    elif model_param_dictionary['cq_method'] == 'basic_plus':
                        current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (x_max_current, y_max_current), bundles_to_forbid, bundles_queried, found_new_max)

        elif (model_param_dictionary['cq_method'] in ['high_pair', 'random_high_pair']):
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the GUI dataset with CQs for student: {i}')
                if len(model_student_list[i]) == 5:
                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                else:
                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                queried_pairs = current_training_set[i][2]    # get the pairs that you have already queried
                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                X_list = []
                y_list = []

                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')
                print(f'get_iterative_dataset_cq called with currrent x_max: {x_max_current} and y_max: {y_max_current}')

                # add the budget constraint to the UNN solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # do not query the bundles you ahve already queried.
                print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                for bundle in bundles_to_forbid:
                    solver.add_forbidden_bundle(bundle)

                # get the list of the top-valued bundles according to the MVNN 
                mvnn_bundles = [] 
                while len(mvnn_bundles) < model_param_dictionary['hp_bundles_to_generate']:
                    try:
                        new_x = solver.solve_mip(outputFlag=False, verbose = False)

                        solver.add_forbidden_bundle(new_x)           
                        # NOTE: In this metod you can always forbid the new bundle, since you are not adding it to the 
                        # "bundles_to_forbid_list" (which remains static, because what we care about is the *pair* of bundles in a CQ), 
                        # so forbidding a bundle of size < 5 does not harm you in followup iterations
                        
                        mvnn_bundles.append(new_x) # add the new bundle in the MVNN list of bundles

                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AFTER GENERATING {len(mvnn_bundles)} BUNDLES')
                        break

                if model_param_dictionary['cq_method'] == 'high_pair':
                    for j in range(len(mvnn_bundles)-1):
                        bundle1 = mvnn_bundles[j]
                        bundle2 = mvnn_bundles[j+1]
                        if not (bundle1, bundle2) in queried_pairs and not (bundle2, bundle1) in queried_pairs:
                            break 
                elif model_param_dictionary['cq_method'] == 'random_high_pair':
                    while True:
                        indexes = np.random.choice(len(mvnn_bundles), 2, replace = False)
                        bundle1 = mvnn_bundles[indexes[0]]
                        bundle2 = mvnn_bundles[indexes[1]]
                        if not (bundle1, bundle2) in queried_pairs and not (bundle2, bundle1) in queried_pairs:
                            break
                else: 
                    raise ValueError(f'cq_method {model_param_dictionary["cq_method"]} not recognized!')
                
                new_x_ordinal = (bundle1, bundle2)
                queried_pairs.append(new_x_ordinal)
                new_y1 = student(bundle1, additive_prefs, substitutes, complements, timetable,
                overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)
                new_y2 = student(bundle2, additive_prefs, substitutes, complements, timetable,
                overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                if new_y1 >= new_y2:
                    new_y_ordinal = 1
                else: 
                    new_y_ordinal = 0

                if (sample_relative_frequencies) is not None:
                    print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                    for _ in range(sample_relative_frequencies[2]):
                        X_list.append(new_x_ordinal)
                        y_list.append(new_y_ordinal)
                else:
                    X_list.append(new_x_ordinal)
                    y_list.append(new_y_ordinal)

            X_train_new = np.array(X_list)
            y_train_new = np.array(y_list)

            if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                    X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                    y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                    X_train_ord = X_train_new
                    y_train_ord = y_train_new

                current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), queried_pairs, bundles_to_forbid)

        elif (model_param_dictionary['cq_method'] == 'random'):
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the GUI dataset with CQs for student: {i}')
                if len(model_student_list[i]) == 5:
                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                else:
                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                (x_max_current, y_max_current) = current_training_set[i][2]    # get the current best performing point, according to the true student
                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                X_list = []
                y_list = []

                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')
                print(f'get_iterative_dataset_cq called with currrent x_max: {x_max_current} and y_max: {y_max_current}')

                # add the budget constraint to the UNN solver
                # do not query the bundles you ahve already queried.

                for j in range(number_of_samples):
                    x1_indices = np.random.choice(25, size = 5, replace = False)
                    x2_indices = np.random.choice(25, size = 5, replace = False)
                    x1 = [0 for i in range(25)]
                    x2 = [0 for i in range(25)]
                    for index in x1_indices:
                        x1[index] = 1
                    for index in x2_indices:
                        x2[index] = 1


                    new_y1 = student(x1, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)
                    new_y2 = student(x2, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    new_x_ordinal = (x1, x2)    # NOTE: this line was bugged before. 
                    last_queries.append(new_x_ordinal)  


                    if new_y1 >= new_y2:
                        new_y_ordinal = 1
                    else:
                        new_y_ordinal = 0

                    if (sample_relative_frequencies) is not None:
                        print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                        for _ in range(sample_relative_frequencies[2]):
                            X_list.append(new_x_ordinal)
                            y_list.append(new_y_ordinal)
                    else:
                        X_list.append(new_x_ordinal)
                        y_list.append(new_y_ordinal)

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                    if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparison questions we already had
                        X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                        y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                    else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                        X_train_ord = X_train_new
                        y_train_ord = y_train_new

                    current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (x_max_current, y_max_current), bundles_to_forbid)

        elif (model_param_dictionary['cq_method'] == 'basic_log'):
            print('Entering basic log version in generate new queries!')
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the GUI dataset with CQs for student: {i}')
                if len(model_student_list[i]) == 5:
                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                else:
                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                (all_bundles, all_values) = current_training_set[i][2]    # get the current best performing point, according to the true student
                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                X_list = []
                y_list = []

                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                # add the budget constraint to the UNN solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # do not query the bundles you ahve already queried.
                print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                for bundle in bundles_to_forbid:
                    solver.add_forbidden_bundle(bundle)

                for j in range(number_of_samples):
                    try:
                        new_x = solver.solve_mip(outputFlag=False, verbose = False)
                        if np.sum(new_x) >= 5:
                            solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                            bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                    # NOTE: should check that this works!!!
                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                        break

                    new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    binary_search_indexes = binary_search(arr = all_values, x = new_y)  # those are the indexes of the binary search to find the position of the new bundle
                    print(f'Indexes of the binary search for the new_y: {binary_search_indexes}')

                    for k in binary_search_indexes:
                        old_x = all_bundles[k]
                        old_y = all_values[k]

                        new_x_ordinal = (new_x, old_x)  # compare the bundle that currently maximizes the MVNN to the max that you had

                        if new_y >= old_y:
                            new_y_ordinal = 1
                        else:
                            new_y_ordinal = 0

                        if (sample_relative_frequencies) is not None:
                            for _ in range(sample_relative_frequencies[2]):
                                X_list.append(new_x_ordinal)
                                y_list.append(new_y_ordinal)
                        else:
                            X_list.append(new_x_ordinal)
                            y_list.append(new_y_ordinal)

                    all_bundles.insert(binary_search_indexes[-1], new_x)   # insert the new point in the right place so that future iterations of the binary search still work!
                    all_values.insert(binary_search_indexes[-1], new_y)

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                    if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                        X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                        y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                    else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                        X_train_ord = X_train_new
                        y_train_ord = y_train_new

                    current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles, all_values), bundles_to_forbid)

        elif (model_param_dictionary['cq_method'] == 'complete_ordering'):
            print('Entering complete ordering version in generate new queries!')
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the GUI dataset with CQs for student: {i}')
                if len(model_student_list[i]) == 5:
                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                else:
                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                (all_bundles_queried, all_values_queried) = current_training_set[i][2]    # get the set of all points already queried
                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                X_list = []
                y_list = []

                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                # add the budget constraint to the UNN solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # do not query the bundles you ahve already queried.
                print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                for bundle in bundles_to_forbid:
                    solver.add_forbidden_bundle(bundle)

                for j in range(number_of_samples):
                    try:
                        new_x = solver.solve_mip(outputFlag=False, verbose = False)
                        if np.sum(new_x) >= 5:
                            solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                            bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                    # NOTE: should check that this works!!!
                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                        break

                    new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    print(f'For student {i}, all bundles queried has a size of: {len(all_bundles_queried)}')

                    for k in range(len(all_bundles_queried)):
                        old_x = all_bundles_queried[k]
                        old_y = all_values_queried[k]

                        new_x_ordinal = (new_x, old_x)  # compare the bundle that currently maximizes the MVNN to the max that you had

                        if new_y >= old_y:
                            new_y_ordinal = 1
                        else:
                            new_y_ordinal = 0

                        if (sample_relative_frequencies) is not None:
                            for _ in range(sample_relative_frequencies[2]):
                                X_list.append(new_x_ordinal)
                                y_list.append(new_y_ordinal)
                        else:
                            X_list.append(new_x_ordinal)
                            y_list.append(new_y_ordinal)

                    print(f'For student {i}, the length of the new X_list after we just added 1 "sample" to it is: {len(X_list)}')

                    all_bundles_queried.append(new_x)   # insert the new point in the list of all bundles queried.
                    all_values_queried.append(new_y)

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                    if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                        X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                        y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                    else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                        X_train_ord = X_train_new
                        y_train_ord = y_train_new

                    current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles_queried, all_values_queried), bundles_to_forbid)

        elif (model_param_dictionary['cq_method'] == 'complete_ordering_pruned'):
            print('Entering complete ordering PRUNED version in generate new queries!')
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the GUI dataset with CQs for student: {i}')
                if len(model_student_list[i]) == 5:
                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                else:
                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                (all_bundles_queried, all_values_queried) = current_training_set[i][2]    # get the set of all points already queried
                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)
                queries_thus_far = current_training_set[i][4]  # so that we can count how many binary search queries were required in every iteration

                X_list = []
                y_list = []

                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                # add the budget constraint to the UNN solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # do not query the bundles you ahve already queried.
                print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                for bundle in bundles_to_forbid:
                    solver.add_forbidden_bundle(bundle)

                for j in range(number_of_samples):
                    try:
                        new_x = solver.solve_mip(outputFlag=False, verbose = False)
                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                        break

                    solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                    bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                    new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    print(f'For student {i}, all bundles queried has a size of: {len(all_bundles_queried)}')
                    print(f'For student number: {i}, number of CQs already answered: {queries_thus_far}')
                    binary_search_indexes = binary_search(arr = all_values_queried, x = new_y)  # those are the indexes of the binary search to find the position of the new bundle
                    cq_limit = model_param_dictionary['cq_limit']

                    if (queries_thus_far + len(binary_search_indexes) <= cq_limit):
                        print(f'Enough queries remaining for student {i}, treating this as insertion sort')
                        queries_thus_far = queries_thus_far + len(binary_search_indexes)
                        for k in range(len(all_bundles_queried)):
                            old_x = all_bundles_queried[k]
                            old_y = all_values_queried[k]

                            new_x_ordinal = (new_x, old_x)  

                            if new_y >= old_y:
                                new_y_ordinal = 1
                            else:
                                new_y_ordinal = 0

                            if (sample_relative_frequencies) is not None:
                                for _ in range(sample_relative_frequencies[2]):
                                    X_list.append(new_x_ordinal)
                                    y_list.append(new_y_ordinal)
                            else:
                                X_list.append(new_x_ordinal)
                                y_list.append(new_y_ordinal)

                        print(f'-------> ALL VALUES BEFORE INSERTION: {all_values_queried}')
                        if (all_values_queried[binary_search_indexes[-1]] >= new_y):
                            all_bundles_queried.insert(binary_search_indexes[-1], new_x)
                            all_values_queried.insert(binary_search_indexes[-1], new_y)
                        else:
                            all_bundles_queried.insert(binary_search_indexes[-1] + 1, new_x)
                            all_values_queried.insert(binary_search_indexes[-1] + 1, new_y)
                        print(f'-----> ALL VALUES AFTER INSERTION: {all_values_queried}')

                    elif (queries_thus_far >= cq_limit):
                        print(f'DONE! FOR student: {i} already at {queries_thus_far} queries, not adding anything for this student')

                    else:
                        print(f'CORNER CASE!!! For student: {i} Already at {queries_thus_far} queries, and the new BS requires {len(binary_search_indexes)} queries, so have to prune smartly!')
                        queries_left = cq_limit - queries_thus_far
                        queries_thus_far = queries_left + queries_thus_far

                        indexes_checked, lower_bound, upper_bound = binary_search_hacky(arr = all_values_queried, x = new_y, queries_available= queries_left)
                        print(f'Value of the new point: {new_y}')

                        counter_to_low = lower_bound
                        while counter_to_low >= 0:
                            print(f'Know that it is higher than the point with value: {all_values_queried[counter_to_low]}')
                            old_x = all_bundles_queried[counter_to_low]
                            new_x_ordinal = (new_x, old_x)
                            new_y_ordinal = 1
                            X_list.append(new_x_ordinal)
                            y_list.append(new_y_ordinal)
                            counter_to_low = counter_to_low - 1

                        counter_to_high = upper_bound
                        while counter_to_high < len(all_bundles_queried):
                            print(f'Know that it is LOWER than the point with value: {all_values_queried[counter_to_high]}')
                            old_x = all_bundles_queried[counter_to_high]
                            new_x_ordinal = (new_x, old_x)
                            new_y_ordinal = 0
                            X_list.append(new_x_ordinal)
                            y_list.append(new_y_ordinal)
                            counter_to_high = counter_to_high + 1

                print(f'For student {i}, the length of the new X_list after we just added 1 "sample" to it is: {len(X_list)}')

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                    if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                        X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                        y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                    else:   # the dataset of CQs was empty -> replace it with all the CQs we did this round!
                        X_train_ord = X_train_new
                        y_train_ord = y_train_new

                current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles_queried, all_values_queried), bundles_to_forbid, queries_thus_far)

        elif (model_param_dictionary['cq_method'] == 'complete_ordering_pruned_in_detail'):
            print('Entering complete ordering PRUNED version in generate new queries!')
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the GUI dataset with CQs for student: {i}')
                if len(model_student_list[i]) == 5:
                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                else:
                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                (all_bundles_queried, all_values_queried) = current_training_set[i][2]    # get the set of all points already queried
                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)
                queries_thus_far = current_training_set[i][4]  # so that we can count how many binary search queries were required in every iteration

                X_list = []
                y_list = []

                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                # add the budget constraint to the UNN solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # do not query the bundles you ahve already queried.
                print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                for bundle in bundles_to_forbid:
                    solver.add_forbidden_bundle(bundle)

                try:
                    new_x = solver.solve_mip(outputFlag=False, verbose = False)
                except:
                    print('--- ACHTUNG ACHTUNG ---')
                    print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                    break

                solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                print(f'For student {i}, all bundles queried has a size of: {len(all_bundles_queried)}')
                print(f'For student number: {i}, number of CQs already answered: {queries_thus_far}')
                binary_search_indexes = binary_search(arr = all_values_queried, x = new_y)  # those are the indexes of the binary search to find the position of the new bundle
                cq_limit = model_param_dictionary['cq_limit']

                if (queries_thus_far + len(binary_search_indexes) <= cq_limit):
                    finished_iteration = True 
                    print(f'Enough queries remaining for student {i}, treating this as insertion sort')
                    queries_thus_far = queries_thus_far + len(binary_search_indexes)                    
                    for k in range(len(all_bundles_queried)):
                        old_x = all_bundles_queried[k]
                        old_y = all_values_queried[k]

                        new_x_ordinal = (new_x, old_x)  

                        if new_y >= old_y:
                            new_y_ordinal = 1
                        else:
                            new_y_ordinal = 0

                        if (sample_relative_frequencies) is not None:
                            for _ in range(sample_relative_frequencies[2]):
                                X_list.append(new_x_ordinal)
                                y_list.append(new_y_ordinal)
                        else:
                            X_list.append(new_x_ordinal)
                            y_list.append(new_y_ordinal)

                    print(f'-------> ALL VALUES BEFORE INSERTION: {all_values_queried}')
                    if (all_values_queried[binary_search_indexes[-1]] >= new_y):
                        all_bundles_queried.insert(binary_search_indexes[-1], new_x)
                        all_values_queried.insert(binary_search_indexes[-1], new_y)
                    else:
                        all_bundles_queried.insert(binary_search_indexes[-1] + 1, new_x)
                        all_values_queried.insert(binary_search_indexes[-1] + 1, new_y)
                    print(f'-----> ALL VALUES AFTER INSERTION: {all_values_queried}')

                    last_queries.append((new_x, all_bundles_queried[binary_search_indexes[-1]]))


                elif (queries_thus_far >= cq_limit):
                    finished_iteration = False
                    print(f'DONE! FOR student: {i} already at {queries_thus_far} queries, not adding anything for this student')

                else:
                    finished_iteration = False
                    print(f'CORNER CASE!!! For student: {i} Already at {queries_thus_far} queries, and the new BS requires {len(binary_search_indexes)} queries, so have to prune smartly!')
                    queries_left = cq_limit - queries_thus_far
                    queries_thus_far = queries_left + queries_thus_far

                    indexes_checked, lower_bound, upper_bound = binary_search_hacky(arr = all_values_queried, x = new_y, queries_available= queries_left)
                    print(f'Value of the new point: {new_y}')

                    counter_to_low = lower_bound
                    while counter_to_low >= 0:
                        print(f'Know that it is higher than the point with value: {all_values_queried[counter_to_low]}')
                        old_x = all_bundles_queried[counter_to_low]
                        new_x_ordinal = (new_x, old_x)
                        new_y_ordinal = 1
                        X_list.append(new_x_ordinal)
                        y_list.append(new_y_ordinal)
                        counter_to_low = counter_to_low - 1

                    counter_to_high = upper_bound
                    while counter_to_high < len(all_bundles_queried):
                        print(f'Know that it is LOWER than the point with value: {all_values_queried[counter_to_high]}')
                        old_x = all_bundles_queried[counter_to_high]
                        new_x_ordinal = (new_x, old_x)
                        new_y_ordinal = 0
                        X_list.append(new_x_ordinal)
                        y_list.append(new_y_ordinal)
                        counter_to_high = counter_to_high + 1

                    last_queries.append((new_x, all_bundles_queried[indexes_checked[-1]]))

                print(f'For student {i}, the length of the new X_list after we just added 1 "sample" to it is: {len(X_list)}')

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                    if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                        X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                        y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                    else:   # the dataset of CQs was empty -> replace it with all the CQs we did this round!
                        X_train_ord = X_train_new
                        y_train_ord = y_train_new

                current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles_queried, all_values_queried), bundles_to_forbid, queries_thus_far, finished_iteration)

        
       

        if get_last_query:
            return current_training_set, last_queries
        else:
            return current_training_set


def check_agreements_on_cq(model_student_list, model_info,benchmark_student_list, benchmark_model_info, cqs, timetable):
    """
    For every student, report whether the learned model and the benchmark model answer
    that student's comparison query (CQ) the same way.

    The CQ for student ``idx`` is the pair ``cqs[idx] == (bundle_a, bundle_b)``. Both
    ``model_student_list[idx]`` and ``benchmark_student_list[idx]`` predict values for the
    two bundles through ``model_predict_values``. The answers agree when the value gaps
    ``v(bundle_b) - v(bundle_a)`` of the two models have the same strict sign.

    Args:
        model_student_list: per-student learned models.
        model_info: ``(model_type, model_param_dictionary)`` describing the learned models.
        benchmark_student_list: per-student benchmark (true) models, indexed like
            ``model_student_list``.
        benchmark_model_info: ``(model_type, model_param_dictionary)`` describing the
            benchmark models.
        cqs: one ``(bundle_a, bundle_b)`` pair per student.
        timetable: forwarded unchanged to ``model_predict_values``.

    Returns:
        list of int: 1 where the two models agree, 0 otherwise. A zero gap (tie) under
        either model counts as a disagreement, because the check requires a strictly
        positive product of the two gaps.
    """
    results = []
    for idx, learned_model in enumerate(model_student_list):
        bundle_a, bundle_b = cqs[idx]
        pair = [bundle_a, bundle_b]  # the same list object is given to both predictors

        learned_values = model_predict_values(model_type = model_info[0], model_param_dictionary = model_info[1],
                                              model = learned_model, bundles = pair, timetable = timetable)
        benchmark_values = model_predict_values(model_type = benchmark_model_info[0], model_param_dictionary = benchmark_model_info[1],
                                                model = benchmark_student_list[idx], bundles = pair, timetable = timetable)

        learned_gap = learned_values[1] - learned_values[0]
        benchmark_gap = benchmark_values[1] - benchmark_values[0]
        results.append(1 if learned_gap * benchmark_gap > 0 else 0)
    return results


def get_ordinal_dataset_size(training_dataset, model_param_dictionary):
    """
    Return how many comparison-query (ordinal) labels each student currently has.

    Args:
        training_dataset: per-student training entries in the layout built by
            ``generate_new_queries``. Entry ``[1]`` of each student is the ordinal dataset
            ``(X_train_ord, y_train_ord)``.
        model_param_dictionary: must contain ``'cq_method'``. Only the methods listed in
            ``supported_methods`` below are accepted.

    Returns:
        numpy.ndarray with one entry per student: ``len(y_train_ord)``.

    Raises:
        ValueError: if ``model_param_dictionary['cq_method']`` is not a supported method.
    """
    # Every supported CQ method stores the ordinal labels in the same position, so a
    # single code path serves all of them.
    supported_methods = ('basic', 'basic_plus', 'random',
                         'complete_ordering', 'complete_ordering_pruned', 'complete_ordering_pruned_in_detail')
    if model_param_dictionary['cq_method'] not in supported_methods:
        raise ValueError('Unknown cq_method in get_ordinal_dataset_size')

    # len() matches .shape[0] for numpy arrays with ndim >= 1, and it also works when
    # the labels are still a plain list (e.g. an empty ordinal dataset).
    return np.array([len(student_entry[1][1]) for student_entry in training_dataset])


def main_function():
   
    parser = argparse.ArgumentParser()
    parser.add_argument('--supply_ratio', type=float, default= 1.25, help='supply_ratio')
    parser.add_argument('--number_of_popular', type= int, default= 9, help='afffects_correlation')
    parser.add_argument('--model_family', type= str, default= 'BPpop9v.1', help='the models to run') # options: [PO/B]pop9_projected_v1.1_from_scratch_gui_reports, 'BPpop9v.1' and 'POpop9'
    parser.add_argument('--linear_instances', type= str, default= 'false', help='whether we are in linear instances or not')
    parser.add_argument('--seed', type= int, default= 627, help='the seed to use for the random number generator')
    parser.add_argument('--student_to_check', type= int, default= 42, help='the student that this run should check')
    parser.add_argument('--cq_method', type= str, default= 'basic_plus', help='acquisition function to use') # options: 'basic', 'random', 'basic_plus', 'complete_ordering', 'complete_ordering_pruned'
    parser.add_argument('--wandb_tracking', type= str, default= 'true', help='whether to track the results with wandb')
    parser.add_argument('--TL_model', type= str, default= 'true', help='whether to use transfer learning model')


    args = parser.parse_args()

    supply_ratio = args.supply_ratio
    number_of_popular = args.number_of_popular
    index = args.student_to_check // 100
    student_index = args.student_to_check % 100
    model_family = args.model_family
    queries = '30-10'  # only needed to load the dictionaries, does not play any role 
    linear_instances = str(args.linear_instances).lower() == 'true'
    seed = args.seed 
    max_queries = 50   # TODO: change back to something larger, 40/50
    wandb_tracking = str(args.wandb_tracking).lower() == 'true'
    tl_model = str(args.TL_model).lower() == 'true'

    if tl_model:
        model_info = ('UNN_transfer_learning', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 
            'UNN_layers': 1, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
        'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
        'sample_relative_frequencies': (1, 1, 1, 1),
        'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        'lr_cardinal': 0.01, 'weight_decay_cardinal': 0.001, 'epochs_cardinal': 100, 'batch_size_cardinal': 8,  'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.1,
        'lr_ordinal':  0.01, 'weight_decay_ordinal': 0, 'epochs_ordinal': 10, 'batch_size_ordinal': 8,  'UNN_loss_string_ordinal': 'BCE', 'clip_ordinal': 0.1,
        'use_implied_dataset': True,
        'use_cqs': True, 'cq_method': 'complete_ordering_pruned'
                                    })  # 'UNN_transfer_learning'
        
    else: 
        model_info =  ('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
        'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
        'use_implied_dataset': True,
        'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
        'sample_relative_frequencies': (1, 1, 1, 1),
        'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
        'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                    })
        

    # set_trace()
    
    if linear_instances:
        true_family = 'true_linear'
    else:
        true_family = 'true'

    result_dict = {} # will save the exact same details as the wandb tracking, but in a dictionary that we can save to a pickle file

    if wandb_tracking:
        wandb_config_dict = {
            'Supply Ratio': supply_ratio,
            'Number of Popular': number_of_popular,
            'Linear Instances': linear_instances,
            'Student Number': args.student_to_check,
            'Acquisition Function': args.cq_method,
            'Model Type': model_info[0]
        }

        run = wandb.init(project=f'MLCM-AF-Comparison-v1.1', # TODO: change to appropriate name
                    config=wandb_config_dict,
                    reinit=True)

        wandb.define_metric("Elicited CQs") 
        
        # Standard Learning Metrics 
        wandb.define_metric("KT", step_metric="Elicited CQs")
        wandb.define_metric("R2", step_metric="Elicited CQs")
        wandb.define_metric("MAE", step_metric="Elicited CQs")
        wandb.define_metric("MSE", step_metric="Elicited CQs")
        wandb.define_metric("KT top 5", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 5", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 5", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 5", step_metric="Elicited CQs")
        wandb.define_metric("KT top 10", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 10", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 10", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 10", step_metric="Elicited CQs")

        # GUI Metrics
        
        wandb.define_metric("Base Values", step_metric="Elicited CQs")
        wandb.define_metric("Adjustments", step_metric="Elicited CQs")
        wandb.define_metric("KT - GUI", step_metric="Elicited CQs")
        wandb.define_metric("R2 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("MAE - GUI", step_metric="Elicited CQs")
        wandb.define_metric("MSE - GUI", step_metric="Elicited CQs")
        wandb.define_metric("KT top 5 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 5 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 5 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 5 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("KT top 10 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("R2 top 10 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("MAE top 10 - GUI", step_metric="Elicited CQs")
        wandb.define_metric("MSE top 10 - GUI", step_metric="Elicited CQs")

        # Other metrics
        wandb.define_metric("Ordinal Dataset Size", step_metric="Elicited CQs")
        wandb.define_metric("Agreement on CQs", step_metric="Elicited CQs")
        wandb.define_metric("Allocated Bundle Value", step_metric="Elicited CQs")
        wandb.define_metric("Found New Max", step_metric="Elicited CQs")  # note: this mteric is only useful for the basic plus method 
        


    # Step 1. Load true and GUI (noisy) student lists, as well as prices
    print(f'supply_ratio: {supply_ratio} number_of_popular: {number_of_popular} index:{index} model_family: {model_family} queries: {queries}')
    (_, true_model_info, model_student_list, true_student_list,  prices_stage_1_GUI, timetable, capacities) = load_all_models_v2(index = index, model_family= model_family, benchmark_family = true_family, queries = '30-10', supply_ratio= supply_ratio, number_of_popular= number_of_popular, linear_instances = linear_instances)
    # (gui_model_info, true_model_info, model_student_list, true_student_list,  prices_stage_1_GUI, timetable, capacities) = load_all_models_v2(index = index, model_family= gui_model_family, benchmark_family = 'true', queries = '30-10', supply_ratio= supply_ratio, number_of_popular= number_of_popular, linear_instances = linear_instances)

    
    
    
    model_type = model_info[0]
    model_param_dictionary = model_info[1]
    number_of_courses = capacities.shape[0]
    model_param_dictionary['cq_method'] = args.cq_method

    # add extra information to the model_param_dictionary for projection. 
    projection_dictionary = {'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 500, 'train_high_samples': 500,
        'linear_projection': linear_instances}
    model_param_dictionary.update(projection_dictionary)
    
    true_student_list = [true_student_list[student_index]]

    # set the seeds for reproducability.  
    torch.manual_seed(seed + student_index)
    np.random.seed(seed + student_index)
    random.seed(seed + student_index)


    # Step 2. Generate the bundles that we will use for the generalization test
    bundles = generate_random_bundles(number_of_bundles= 1000, number_of_courses= 25, seed = seed)
    bundles_5_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 95, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])
    bundles_10_percentile = keep_high_valued_bundles(bundles= bundles, percentile = 90, benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0])



    # Step 3. Generate the 'initial' training dataset based on the GUI reports

    actual_student_list = create_actual_student_list(true_student_list, model_type = model_info[0], model_param_dictionary = model_param_dictionary, seed = seed)  # just returns the true list: NOTE: should remove
    gui_student_list = create_gui_student_list(true_student_list, model_type = model_type, model_param_dictionary = model_param_dictionary, seed = seed)

    
    training_set = generate_new_queries(actual_student_list = actual_student_list, timetable = timetable, number_of_samples = 0,  # number of samples is not used when generating the original dataset. Instead, it is read from the model_dictionary
                    current_training_set = [], model_student_list = [], approximate_prices = None,
                    credit_units = [1 for i in range(number_of_courses)], model_type = model_type, seed= seed + args.student_to_check, model_param_dictionary = model_param_dictionary, gui_student_list = gui_student_list)

    model_student_list = None 

    if 'complete_ordering' not in model_param_dictionary['cq_method']:
        
        if model_param_dictionary['cq_method'] == 'basic_plus':
            for j in range(len(training_set)):
                training_set[j].append(0)  # first time -> you have not found a better new max. 
    
        for query_number in range(0, max_queries + 1):
            # Step 4. Train the models on the current dataset 
            model_student_list = create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                            timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                            model_student_list= model_student_list)
            
            # Step 4b) get the new dataset for the next iteration
            training_set, last_queries = generate_new_queries(actual_student_list = actual_student_list, timetable = timetable, number_of_samples = 1, # only generate 1 sample at a time so that you can log the reuslts afterwards (and retrain if the method calls for it)
                        current_training_set = training_set, model_student_list = model_student_list, approximate_prices = prices_stage_1_GUI,
                        credit_units = [1 for i in range(number_of_courses)], model_type = model_type, seed= seed + args.student_to_check, model_param_dictionary = model_param_dictionary, gui_student_list = gui_student_list, 
                        get_last_query = True)
            
            
            # Step 5. Measure the performance of the models
        
            # Step 5a) Measure allocation value of the model 
            allocation_model = calculate_allocation(model_list= model_student_list, prices= prices_stage_1_GUI, timetable= timetable, models_to_run= [model_info])
            allocation_value_model = get_true_value_of_allocation(benchmark_student_list= true_student_list, individual_demands= allocation_model, benchmark_model= true_model_info[0], timetable = timetable)

            

            # Step 6.a) Measure generalization performance of the ML models and GUI reports
            kts_model, r2s_model, maes_model, mses_model = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                            model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_list)
            kts_model_top_5, r2s_model_top_5, maes_model_top_5, mses_model_top_5 = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                            model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_list)
            kts_model_top_10, r2s_model_top_10, maes_model_top_10, mses_model_top_10 = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                            model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_list)
        
            
            # Step6.b) Measure number of agreements between the GUI reports and the true preferences
            agreements = check_agreements_on_cq(model_student_list= model_student_list, model_info= model_info,benchmark_student_list= true_student_list, benchmark_model_info= true_model_info, cqs= last_queries, timetable= timetable)
            
            ordinal_info_size = get_ordinal_dataset_size(training_dataset= training_set, model_param_dictionary= model_param_dictionary)

            # Step 7. Project the MVNNs to measure results in the GUI language
            projected_mvnns = project_mvnns(mvnn_student_list= model_student_list, model_param_dictionary= model_param_dictionary)

             # Measure the number of base values and adjustments 
            base_values_model, adjustments_model = measure_gui_language_metrics(model_list = projected_mvnns, model_type = 'UNN_projected')
            # set_trace()
        
            # Measure similarity between projected ML models/GUI reports and the true (additive) preferences
            gui_language_kts_model_top10, gui_language_r2s_model_top10, gui_language_maes_model_top10, gui_language_mses_model_top10 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 60)
            gui_language_kts_model, gui_language_r2s_model, gui_language_maes_model, gui_language_mses_model = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances)
            gui_language_kts_model_top5, gui_language_r2s_model_top5, gui_language_maes_model_top5, gui_language_mses_model_top5 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 80)
            
            
            
            wandb_dict = {
                'Elicited CQs': query_number,
                'KT': np.mean(kts_model),
                'R2': np.mean(r2s_model),
                'MAE': np.mean(maes_model),
                'MSE': np.mean(mses_model),
                'KT top 5': np.mean(kts_model_top_5),
                'R2 top 5': np.mean(r2s_model_top_5),
                'MAE top 5': np.mean(maes_model_top_5),
                'MSE top 5': np.mean(mses_model_top_5),
                'KT top 10': np.mean(kts_model_top_10),
                'R2 top 10': np.mean(r2s_model_top_10),
                'MAE top 10': np.mean(maes_model_top_10),
                'MSE top 10': np.mean(mses_model_top_10),
                'Base Values': np.mean(base_values_model),
                'Adjustments': np.mean(adjustments_model),
                'KT - GUI': np.mean(gui_language_kts_model),
                'R2 - GUI': np.mean(gui_language_r2s_model),
                'MAE - GUI': np.mean(gui_language_maes_model),
                'MSE - GUI': np.mean(gui_language_mses_model),
                'KT top 5 - GUI': np.mean(gui_language_kts_model_top5),
                'R2 top 5 - GUI': np.mean(gui_language_r2s_model_top5),
                'MAE top 5 - GUI': np.mean(gui_language_maes_model_top5),
                'MSE top 5 - GUI': np.mean(gui_language_mses_model_top5),
                'KT top 10 - GUI': np.mean(gui_language_kts_model_top10),
                'R2 top 10 - GUI': np.mean(gui_language_r2s_model_top10),
                'MAE top 10 - GUI': np.mean(gui_language_maes_model_top10),
                'MSE top 10 - GUI': np.mean(gui_language_mses_model_top10),
                'Ordinal Dataset Size': np.mean(ordinal_info_size),
                'Agreement on CQs': np.mean(agreements),
                'Allocated Bundle Value': np.mean(allocation_value_model),
            }
            if model_param_dictionary['cq_method'] == 'basic_plus':
                wandb_dict['Found New Max'] = training_set[0][5]

            result_dict[query_number] = wandb_dict

            # set_trace()
            if wandb_tracking:
                wandb.log(wandb_dict)

    
    else: 
        # we are in the case of the complete ordering pruned method, need to be careful about how we update the dataset/generate the queries.  
        training_set[0] = training_set[0] + (True,) # this is a flag that tells us if the last iteration finished or not. If it did not finish, then we need to keep the old dataset and just add the new CQs to it.
        model_param_dictionary['cq_method'] = 'complete_ordering_pruned_in_detail'
        
        for query_number in range(0, 1 + max_queries):
            model_param_dictionary['cq_limit'] = query_number + 1  
            # set the CQ limit accordingly for each round. NOTE: We count the number of CQs that were used for the training of the network (for all methods)


            # Step 4. Train the models on the current dataset 
            # set_trace()
            model_student_list = create_iterative_student_list(training_set, actual_student_list, credit_units = [1 for i in range(number_of_courses)],
                            timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary, 
                            model_student_list= model_student_list)
            
            
            if training_set[0][5]:  # if the last iteration finished, then we can create the dataset for the next iteration using the new models, (and new dataset)
                print ('FINISHED LAST ITERATION, REPLACING OLD DATASET WITH NEW ONE and old models with the new ones')
                model_student_list_old = model_student_list
                # if the last binary search iteration finished, then we can just replace the old dataset with the new one
                training_set_old = copy.deepcopy(training_set)
            # if it did not finish, then we need to generate the new dataset using the last iteration of the student list that had completely finished 
            
            # Deepcopy the dataset because we modify it every round (but we may want to keep the old dataset for the next cq, if the previous binary search did not terminate. 
            training_set = copy.deepcopy(training_set_old)
                
            
            # Step 4b. Get the new dataset for the next iteration
            training_set, last_queries = generate_new_queries(actual_student_list = actual_student_list, timetable = timetable, number_of_samples = 1,  # only generate 1 sample at a time so that you can log the reuslts afterwards (and retrain if the method calls for it)
                        current_training_set = training_set, model_student_list = model_student_list_old, approximate_prices = prices_stage_1_GUI,
                        credit_units = [1 for i in range(number_of_courses)], model_type = model_type, seed= seed + args.student_to_check, model_param_dictionary = model_param_dictionary, gui_student_list = gui_student_list, 
                        get_last_query = True)
            
            
            # Step 5. Measure the performance of the models
        
            # Step 5a) Measure allocation value of the model 
            allocation_model = calculate_allocation(model_list= model_student_list, prices= prices_stage_1_GUI, timetable= timetable, models_to_run= [model_info])
            allocation_value_model = get_true_value_of_allocation(benchmark_student_list= true_student_list, individual_demands= allocation_model, benchmark_model= true_model_info[0], timetable = timetable)

            

            # Step 6.a) Measure generalization performance of the ML models and GUI reports
            kts_model, r2s_model, maes_model, mses_model = measure_generalization_performance_all_students(bundles_all_students= [bundles for i in range(len(true_student_list))], model_type= model_info[0], 
                                            model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_list)
            kts_model_top_5, r2s_model_top_5, maes_model_top_5, mses_model_top_5 = measure_generalization_performance_all_students(bundles_all_students= [bundles_5_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                            model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_list)
            kts_model_top_10, r2s_model_top_10, maes_model_top_10, mses_model_top_10 = measure_generalization_performance_all_students(bundles_all_students= [bundles_10_percentile[i] for i in range(len(true_student_list))], model_type= model_info[0], 
                                            model_param_dictionary= model_info[1], benchmark_student_list= true_student_list, benchmark_model_type= true_model_info[0],model_student_list= model_student_list)
        
            
            # Step6.b) Measure number of agreements between the GUI reports and the true preferences
            agreements = check_agreements_on_cq(model_student_list= model_student_list, model_info= model_info,benchmark_student_list= true_student_list, benchmark_model_info= true_model_info, 
                                                cqs= last_queries, timetable= timetable)
            
            ordinal_info_size = get_ordinal_dataset_size(training_dataset= training_set, model_param_dictionary= model_param_dictionary)

            # Step 7. Project the MVNNs to measure results in the GUI language
            projected_mvnns = project_mvnns(mvnn_student_list= model_student_list, model_param_dictionary= model_param_dictionary)

             # Measure the number of base values and adjustments 
            base_values_model, adjustments_model = measure_gui_language_metrics(model_list = projected_mvnns, model_type = 'UNN_projected')
        
            # Measure similarity between projected ML models/GUI reports and the true (additive) preferences
            gui_language_kts_model_top10, gui_language_r2s_model_top10, gui_language_maes_model_top10, gui_language_mses_model_top10 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 60)
            gui_language_kts_model, gui_language_r2s_model, gui_language_maes_model, gui_language_mses_model = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances)
            gui_language_kts_model_top5, gui_language_r2s_model_top5, gui_language_maes_model_top5, gui_language_mses_model_top5 = compare_gui_reports(true_student_list= true_student_list, other_student_list= projected_mvnns, model_type= 'UNN_projected', linear_instances = linear_instances, percentile= 80)
            
            
            
            wandb_dict = {
                'Elicited CQs': query_number,
                'KT': np.mean(kts_model),
                'R2': np.mean(r2s_model),
                'MAE': np.mean(maes_model),
                'MSE': np.mean(mses_model),
                'KT top 5': np.mean(kts_model_top_5),
                'R2 top 5': np.mean(r2s_model_top_5),
                'MAE top 5': np.mean(maes_model_top_5),
                'MSE top 5': np.mean(mses_model_top_5),
                'KT top 10': np.mean(kts_model_top_10),
                'R2 top 10': np.mean(r2s_model_top_10),
                'MAE top 10': np.mean(maes_model_top_10),
                'MSE top 10': np.mean(mses_model_top_10),
                'Base Values': np.mean(base_values_model),
                'Adjustments': np.mean(adjustments_model),
                'KT - GUI': np.mean(gui_language_kts_model),
                'R2 - GUI': np.mean(gui_language_r2s_model),
                'MAE - GUI': np.mean(gui_language_maes_model),
                'MSE - GUI': np.mean(gui_language_mses_model),
                'KT top 5 - GUI': np.mean(gui_language_kts_model_top5),
                'R2 top 5 - GUI': np.mean(gui_language_r2s_model_top5),
                'MAE top 5 - GUI': np.mean(gui_language_maes_model_top5),
                'MSE top 5 - GUI': np.mean(gui_language_mses_model_top5),
                'KT top 10 - GUI': np.mean(gui_language_kts_model_top10),
                'R2 top 10 - GUI': np.mean(gui_language_r2s_model_top10),
                'MAE top 10 - GUI': np.mean(gui_language_maes_model_top10),
                'MSE top 10 - GUI': np.mean(gui_language_mses_model_top10),
                'Ordinal Dataset Size': np.mean(ordinal_info_size),
                'Agreement on CQs': np.mean(agreements),
                'Allocated Bundle Value': np.mean(allocation_value_model),
            }
            if model_param_dictionary['cq_method'] == 'basic_plus':
                wandb_dict['Found New Max'] = training_set[0][5]

            result_dict[query_number] = wandb_dict

            # set_trace()
            if wandb_tracking:
                wandb.log(wandb_dict)

                

            print(f'--------> Number of CQs: {query_number} -- Agreements: {agreements}, ordinal_info_size: {ordinal_info_size}, allocation values: {allocation_value_model}')

            # set_trace()

    if wandb_tracking: 
        run.finish()

    # save the results to a pickle file
    dict_name = f'./acquisition_function_results/SR_{supply_ratio}_NP_{number_of_popular}_LI_{linear_instances}_AF_{args.cq_method}_student_{args.student_to_check}.pkl'
    # open the file for writing, and create it if it does not exist
    with open(dict_name, 'wb+') as f:
        pickle.dump(result_dict, f)

    return 

# Script entry point: invoke main_function() only when this file is executed
# directly (e.g. `python <file>.py`), never as a side effect of importing it.
if '__main__' == __name__:
    main_function()