import numpy as np
import pickle
# import math
# import gurobipy as gp
from pdb import set_trace
from cleanup import student 

from tabu_search import calculate_total_demand, unfreeze_model_student_list
import argparse
import torch
import scipy.stats as stats
from sklearn.metrics import r2_score
from sklearn.preprocessing import PolynomialFeatures
import math 


def save_obj(obj, name):
    """Pickle ``obj`` into the file ``name + '.pkl'``.

    Args:
        obj: Any picklable object.
        name: Path of the target file *without* the ``.pkl`` extension.
    """
    target_path = name + '.pkl'
    with open(target_path, 'wb+') as handle:
        # Always use the newest protocol supported by the running interpreter.
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    """Load and return the object pickled in ``name + '.pkl'``.

    The resolved file name is printed before the file is opened.

    Args:
        name: Path of the pickle file *without* the ``.pkl`` extension.

    Returns:
        The unpickled object.
    """
    source_path = name + '.pkl'
    print(source_path)
    # NOTE: unpickling executes arbitrary code; only load files you trust.
    with open(source_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

# supply_ratio_list = [1.1, 1.25, 1.5, 2], number_of_popular_list = [3, 4, 6, 8]


def _unn_hyperparameters(*, layers=3, units=20, forget_base=0.5, forget_base_uniform=None,
                         forget_adjustments=0.48, base_noise_std=23, adj_std=0.2,
                         cq_method='complete_ordering_pruned', extra=None):
    """Build one hyperparameter dictionary for a 'UNN' / 'UNN_projected' model.

    The defaults reproduce the 'POpop9' configuration; every other family only
    overrides the handful of entries exposed as keyword arguments.

    Args:
        layers: Value stored under 'UNN_layers'.
        units: Value stored under 'UNN_units'.
        forget_base: Value stored under 'gui_forget_base'.
        forget_base_uniform: Value stored under 'gui_forget_base_uniform'. When
            ``None`` the key is left out entirely (as in the non-EB families).
        forget_adjustments: Value stored under 'gui_forget_adjustments'.
        base_noise_std: Value stored under 'gui_base_noise_std'.
        adj_std: Value stored under 'gui_adj_std'.
        cq_method: Value stored under 'cq_method'.
        extra: Optional mapping of additional entries appended after the
            standard ones.

    Returns:
        A freshly created dict (safe for the caller to mutate).
    """
    params = {
        'samples': 0, 'samples_in_range': 0, 'range_min_value': 350,
        'UNN_layers': layers, 'UNN_units': units, 'lr': 0.001, 'weight_decay': 6e-5,
        'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09,
        'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
        'use_implied_dataset': True,
        'gui_forget_base': forget_base,
    }
    # Inserted here (not via ``extra``) so the key keeps its historical position.
    if forget_base_uniform is not None:
        params['gui_forget_base_uniform'] = forget_base_uniform
    params.update({
        'gui_forget_adjustments': forget_adjustments, 'gui_base_noise_std': base_noise_std,
        'gui_adj_std': adj_std, 'gui_sample_categories_weights': [1, 1, 1],
        'sample_relative_frequencies': (1, 1, 1, 1),
        'points_to_add': 40, 'points_to_hallucinate': 20,
        'uniform_range_low': None, 'uniform_range_high': None,
        'forgotten_course_expected_value': None,
        'use_cqs': True, 'cq_method': cq_method, 'alpha': 0.65,
        'clip_cardinal': 0.68, 'clip_ordinal': 0.41,
        'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
    })
    if extra:
        params.update(extra)
    return params


def _projection_hyperparameters(alpha, train_samples, train_high_samples=None, *, alpha_key='alpha'):
    """Build the parameter dictionary of a 'UNN_projected' model.

    Args:
        alpha: Value stored under ``alpha_key``.
        train_samples: Value stored under 'train_samples'.
        train_high_samples: Value stored under 'train_high_samples'; the key is
            omitted when this is ``None``.
        alpha_key: Name of the key holding ``alpha`` ('alpha' or 'proj_alpha').

    Returns:
        A freshly created dict.
    """
    params = {alpha_key: alpha, 'ridge': 0.0, 'fit_intercept': False, 'train_samples': train_samples}
    if train_high_samples is not None:
        params['train_high_samples'] = train_high_samples
    return params


def _gui_models_for_family(model_family_GUI):
    """Return the ``models_to_run_GUI`` list for a GUI model family.

    Raises:
        ValueError: If ``model_family_GUI`` is not one of the known families.
    """
    if model_family_GUI == 'noisy_EB_default':
        increment_multiplier = 6
        forget_base_increment = 0.02 * increment_multiplier
        base_std_increment = 2 * increment_multiplier
        return [('LinearNoisy', {'noisy_forget_base': 0.5 + forget_base_increment,
                                 'noisy_forget_base_uniform': 0,
                                 'noisy_forget_adjustments': 0,
                                 'noisy_base_std': 10 + base_std_increment,
                                 'noisy_adj_std': 0})]

    if model_family_GUI == 'true_noise_default_p6':
        return [('PairwiseAdjustmentsNoisy', {'noisy_forget_base': 0.5, 'noisy_forget_adjustments': 0.4825,
                                              'noisy_base_std': 17, 'noisy_adj_std': 0.2})]

    if model_family_GUI == 'true_noise_default_p9':
        return [('PairwiseAdjustmentsNoisy', {'noisy_forget_base': 0.5, 'noisy_forget_adjustments': 0.48,
                                              'noisy_base_std': 23, 'noisy_adj_std': 0.2})]

    if model_family_GUI == 'true':
        return [('PairwiseAdjustments', None)]

    raise ValueError('model_family_GUI not recognized!')


def _ml_models_for_family(model_family):
    """Return the ``models_to_run`` list for ``model_family``, or ``None`` if it is unknown."""
    if model_family == 'linear_ml':
        return [(regressor, {'samples': samples, 'samples_in_range': samples_in_range,
                             'range_min_value': 350, 'scale_ys': False, 'alpha': 1})
                for regressor in ('LinearRegression', 'Ridge', 'Lasso', 'ElasticNet')
                for samples, samples_in_range in ((50, 0), (25, 25), (0, 50))]

    if model_family == 'true':
        return [('TrueLinear', None), ('PairwiseAdjustments', None), ('True', None)]

    if model_family == 'true_gui':
        return [('PairwiseAdjustments', None)]

    # ---- overrides (relative to the POpop9 defaults) shared by several UNN families ----
    pop6_noise = {'forget_adjustments': 0.4825, 'base_noise_std': 17}
    # tweaked linear ("EB") distribution
    eb_noise = {'forget_base_uniform': 0.14, 'forget_adjustments': 0, 'base_noise_std': 19, 'adj_std': 0}
    # NEW linear EB distribution with no forgotten base values
    eb_v2_noise = {'forget_base': 0.62, 'forget_base_uniform': 0, 'forget_adjustments': 0,
                   'base_noise_std': 22, 'adj_std': 0}
    noise_multiplier = 0.5
    eb_v2_m0_noise = dict(eb_v2_noise, forget_base=0.62 * noise_multiplier,
                          base_noise_std=22 * noise_multiplier)
    # the "L" families use 0 layers with a single unit (described as the linear MVNN)
    linear_net = {'layers': 0, 'units': 1}
    # projection settings used together with the original GUI reports
    gui_projection = {'project_to_gui': True, 'approximate_prices_model': 'gui_reports'}
    gui_projection.update(_projection_hyperparameters(0.025, 5000, 5000, alpha_key='proj_alpha'))
    gui_projection['linear_projection'] = True

    # family -> (model type, keyword overrides for _unn_hyperparameters)
    unn_families = {
        'POpop6': ('UNN', pop6_noise),  # final model for july 20th
        'POpop9': ('UNN', {}),
        'POpop9EB': ('UNN', eb_noise),
        'POpop9EBL': ('UNN', {**linear_net, **eb_noise}),
        'POpop9EBv2': ('UNN', eb_v2_noise),
        'POpop9EBLv2': ('UNN', {**linear_net, **eb_v2_noise}),
        'POpop9EBv2_m0': ('UNN', eb_v2_m0_noise),
        'POpop9EBLv2_m0': ('UNN', {**linear_net, **eb_v2_m0_noise}),
        # use original GUI reports for approximate price calculation
        'POpop9_gui_reports': ('UNN', {'extra': {'project_to_gui': False,
                                                 'approximate_prices_model': 'gui_reports'}}),
        # use hallucinated GUI reports for approximate price calculation
        'POpop9_hallucinated_gui_reports': ('UNN', {'extra': {'project_to_gui': True,
                                                              'approximate_prices_model': 'hallucinated_gui_reports'}}),
        # GUI reports for approximate prices, plus projection, for the LINEAR distribution
        'POpop9EBv2_projected_v1.1_from_scratch_gui_reports': ('UNN_projected',
                                                               {**eb_v2_noise, 'extra': gui_projection}),
        # families that use the basic option for CQs
        'Bpop9': ('UNN', {'layers': 2, 'cq_method': 'basic'}),
        'Bpop9EBLv2': ('UNN', {**linear_net, **eb_v2_noise, 'cq_method': 'basic'}),
    }
    if model_family in unn_families:
        model_type, overrides = unn_families[model_family]
        return [(model_type, _unn_hyperparameters(**overrides))]

    # family -> (alpha key, alpha, train_samples, train_high_samples or None)
    projection_families = {
        # projections of already trained (i.e. loaded) mvnn models back into the gui language
        'POpop9_projected_v1': ('alpha', 0.025, 5000, None),
        'POpop9_projected_v1.1': ('alpha', 0.025, 5000, 5000),
        'POpop9_projected_v1.2': ('alpha', 0.025, 300, 300),
        'POpop9_projected_v1.3': ('alpha', 0.025, 500, 500),
        'POpop9_projected_v3.1': ('alpha', 0.0, 5000, 5000),
        'POpop9_projected_v3.2': ('alpha', 0.0, 300, 300),
        'POpop9_projected_v2': ('alpha', 0.1, 5000, None),
        'POpop6_projected_v1': ('alpha', 0.025, 5000, None),
        'POpop6_projected_v2': ('alpha', 0.1, 5000, None),
        # projections started from scratch (no mvnn models are loaded in stage 1)
        'POpop9_projected_v1.1_from_scratch_gui_reports': ('proj_alpha', 0.025, 5000, 5000),
        'POpop9_projected_v1.1_from_scratch_hallucinated_gui_reports': ('proj_alpha', 0.025, 5000, 5000),
    }
    if model_family in projection_families:
        alpha_key, alpha, train_samples, train_high_samples = projection_families[model_family]
        return [('UNN_projected', _projection_hyperparameters(alpha, train_samples, train_high_samples,
                                                              alpha_key=alpha_key))]

    if model_family in ('Bpop9_projected_v1.1_from_scratch_gui_reports',
                        'Bpop9EBv2_projected_v1.1_from_scratch_gui_reports'):
        return [('UNN_projected', gui_projection)]

    return None


def load_all_models(starting_instance_index = 0, model_family = 'linear_ml', queries = '30-20', supply_ratio = 1.25, number_of_popular = 9,  grid = 'large', linear_instances = False):
    """Load one problem instance plus its stage-1 results and model configurations.

    All files are read relative to the current working directory. Pickle files
    are read through ``load_obj`` (which unpickles them, so they must come from
    a trusted source) and ``.npy`` files through ``np.load``.

    Args:
        starting_instance_index: Index of the instance to load. It selects the
            entry of the per-run instance lists and is part of the stage-1
            result file names.
        model_family: Name of the model family. It is part of the stage-1
            result file name and selects the entries of ``models_to_run``.
            With ``linear_instances=False`` the special value 'true_gui' switches
            the GUI family to 'true' and replaces ``model_family`` by
            ``f'POpop{number_of_popular}_projected_v1.1_from_scratch_gui_reports'``.
        queries: Suffix used in the stage-1 (iterative) result file name.
        supply_ratio: Used in the instance and result file names.
        number_of_popular: Used in the instance and result file names and, for
            non-linear instances, in the GUI family name
            (``true_noise_default_p{number_of_popular}``).
        grid: Grid name used in the result file names; the instance file names
            contain ``lg_True`` exactly when ``grid == 'large'``.
        linear_instances: If true, the 'EB linear' instance/result folders and
            the 'noisy_EB_default' GUI family are used.

    Returns:
        The tuple ``(models_to_run, models_to_run_GUI, true_student_list,
        ML_student_list, student_list_GUI, prices_stage_1_GUI, timetable,
        capacities)``, or ``None`` (after printing 'Invalid model family') when
        ``model_family`` is not recognised.

    Raises:
        ValueError: If the GUI model family is not recognised (for non-linear
            instances this happens when ``number_of_popular`` is neither 6 nor 9
            and ``model_family`` is not 'true_gui').
    """
    large_grid = grid == 'large'
    distr = 'new'

    if not linear_instances:
        problem_instance_folder = 'problem_instances_final'
        results_stage1_folder = 'results_s1_iterative_clean'

        results_stage1_folder_GUI = 'new_results_s1_clean'
        model_family_GUI = f'true_noise_default_p{number_of_popular}'

        if model_family == 'true_gui':
            model_family_GUI = 'true'
            model_family = f'POpop{number_of_popular}_projected_v1.1_from_scratch_gui_reports'
    else:
        print('Starting the new EB instances!')
        problem_instance_folder = 'problem_instances_EB_linear'
        results_stage1_folder = 'results_s1_iterative_EB_linear'

        results_stage1_folder_GUI = 'results_s1_ne_EB'
        model_family_GUI = 'noisy_EB_default'

    # shared tails of the instance / result file names
    instance_suffix = f'sr_{supply_ratio}_popular_{number_of_popular}_lg_{large_grid}'
    results_suffix = f'sr_{supply_ratio}_popular_{number_of_popular}_d_{distr}_g_{grid}'

    # load all of the corresponding problem instances
    print('loading the new distribution instances!')
    true_student_lists_all_runs = load_obj(f'./{problem_instance_folder}/true_student_lists_{instance_suffix}')
    capacities_all_runs = np.load(f'./{problem_instance_folder}/capacities_all_runs_{instance_suffix}.npy')
    timetables_all_runs = load_obj(f'./{problem_instance_folder}/timetables_{instance_suffix}')

    # stage-1 student lists of the (iterative) ML model family
    stage1_student_lists = load_obj(f'./{results_stage1_folder}/student_lists_stage1_iterative_index_{starting_instance_index}_family_{model_family}_{results_suffix}_queries_{queries}')

    # stage-1 prices and student lists of the GUI version
    prices_stage_1_GUI = np.load(f'./{results_stage1_folder_GUI}/prices_stage1_index_{starting_instance_index}_family_{model_family_GUI}_{results_suffix}.npy')
    loaded_student_lists_GUI = load_obj(f'./{results_stage1_folder_GUI}/student_lists_stage1_index_{starting_instance_index}_family_{model_family_GUI}_{results_suffix}')

    # determine the models to run
    models_to_run_GUI = _gui_models_for_family(model_family_GUI)

    models_to_run = _ml_models_for_family(model_family)
    if models_to_run is None:
        print('Invalid model family')
        return

    # pick the requested instance and the first stored student list
    ML_student_list = stage1_student_lists[0]
    true_student_list = true_student_lists_all_runs[starting_instance_index]
    timetable = timetables_all_runs[starting_instance_index]
    capacities = capacities_all_runs[starting_instance_index]

    # For the 'true' GUI family the stored results are ordered
    # [True linear, PairwiseAdjustments (accurate), True]; the pairwise
    # adjustment entry (index 1) and its prices are the ones that are used.
    gui_entry = 1 if model_family_GUI == 'true' else 0
    student_list_GUI = loaded_student_lists_GUI[gui_entry][0]
    prices_stage_1_GUI = prices_stage_1_GUI[gui_entry][0]

    # unfreeze the model student list, if you have to
    ml_model_type = models_to_run[0][0]
    ML_student_list = unfreeze_model_student_list(loaded_student_list = ML_student_list, model_type= ml_model_type, number_of_courses= len(capacities), course_timetable = timetable)

    return (models_to_run, models_to_run_GUI, true_student_list, ML_student_list, student_list_GUI, prices_stage_1_GUI, timetable, capacities)


def load_all_models_v2(index, model_family, benchmark_family, queries, supply_ratio, number_of_popular, linear_instances):

    grid = 'large'
    large_grid = grid == 'large'
    distr = 'new'

    gui_model_flag = model_family in ['true', 'noisy_EB_default', 'true_noise_default_p9', 'true_noise_default_p6'] # will be true if the non-benchmark model is also a GUI model 
    

    if not linear_instances:
        problem_instance_folder = 'problem_instances_final'
        results_stage1_folder_benchmark = 'new_results_s1_clean'
        results_stage1_folder_gui = 'new_results_s1_clean'
        gui_prices_family = 'true_noise_default_p9'

        if gui_model_flag: 
            results_stage1_folder_model = results_stage1_folder_benchmark
        else: 
            results_stage1_folder_model = 'results_s1_iterative_clean'
            

    else:
        print('Starting the new EB instances!')
        problem_instance_folder = 'problem_instances_EB_linear'
        results_stage1_folder_benchmark = 'results_s1_ne_EB'
        results_stage1_folder_gui = 'results_s1_ne_EB'
        gui_prices_family = 'noisy_EB_default'

        if gui_model_flag: 
            results_stage1_folder_model = 'results_s1_ne_EB'
        else: 
            results_stage1_folder_model = 'results_s1_iterative_EB_linear'


    # load all of the corresponding problem instances
    print('loading the new distribution instances!')

    capacities_all_runs = np.load(f'./{problem_instance_folder}/capacities_all_runs_sr_{supply_ratio}_popular_{number_of_popular}_lg_{large_grid}.npy')
    timetables_all_runs = load_obj(f'./{problem_instance_folder}/timetables_sr_{supply_ratio}_popular_{number_of_popular}_lg_{large_grid}')

    capacities = capacities_all_runs[index]
    timetable = timetables_all_runs[index]

    # determine exactly the instances to run
    if benchmark_family == 'true':
        true_student_lists_all_runs = load_obj(f'./{problem_instance_folder}/true_student_lists_sr_{supply_ratio}_popular_{number_of_popular}_lg_{large_grid}')
        benchmark_student_list = true_student_lists_all_runs[index]
    else:
        benchmark_student_list = load_obj(f'./{results_stage1_folder_benchmark}/student_lists_stage1_index_{index}_family_{benchmark_family}_sr_{supply_ratio}_popular_{number_of_popular}_d_{distr}_g_{grid}')[0][0]
        

    if not gui_model_flag:
        model_student_list = load_obj(f'./{results_stage1_folder_model}/student_lists_stage1_iterative_index_{index}_family_{model_family}_sr_{supply_ratio}_popular_{number_of_popular}_d_{distr}_g_{grid}_queries_{queries}')[0]
    else: 
        if model_family != 'true':
            model_student_list = load_obj(f'./{results_stage1_folder_model}/student_lists_stage1_index_{index}_family_{model_family}_sr_{supply_ratio}_popular_{number_of_popular}_d_{distr}_g_{grid}')[0][0]
        else:
            # set_trace()  # NOTE: we are using the pairwise adjustments as the "ML" model, these are the seocnd model in this list! 
            model_student_list = load_obj(f'./{results_stage1_folder_model}/student_lists_stage1_index_{index}_family_{model_family}_sr_{supply_ratio}_popular_{number_of_popular}_d_{distr}_g_{grid}')[1][0] 

    
    prices_stage_1_GUI = np.load(f'./{results_stage1_folder_gui}/prices_stage1_index_{index}_family_{gui_prices_family}_sr_{supply_ratio}_popular_{number_of_popular}_d_{distr}_g_{grid}.npy')[0][0]


    if not linear_instances:
        if benchmark_family == 'true':
            benchmark_model_info = ('True', None)
        else: 
            benchmark_model_info = ('PairwiseAdjustments', None)

    else:
        benchmark_model_info = ('TrueLinear', None)



    # model_student_list = unfreeze_model_student_list(loaded_student_list = model_student_list, model_type= ml_model_type, number_of_courses= len(capacities), course_timetable = timetable)


    # determine the models to run
    if model_family == 'POpop6':
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.4825, 'gui_base_noise_std': 17, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]

    elif model_family == 'POpop9':
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]

    elif model_family == 'POpop9EB':  # our default model for 9 popular courses, but now for the tweaked linear distribution that EB asked for.
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_base_uniform': 0.14, 'gui_forget_adjustments':0, 'gui_base_noise_std': 19, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]

    elif model_family == 'POpop9EBL':  # LINEAR MVNN for 9 popular courses, but now for the tweaked linear distribution that EB asked for.
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 0, 'UNN_units': 1, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_base_uniform': 0.14, 'gui_forget_adjustments':0, 'gui_base_noise_std': 19, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]

    elif model_family == 'POpop9EBv2':  # our default model for 9 popular courses, but now for the NEW linear EB distribution with no forgotten base values
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.62, 'gui_forget_base_uniform': 0, 'gui_forget_adjustments':0, 'gui_base_noise_std': 22, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]

    elif model_family == 'POpop9EBLv2':  # LINEAR MVNN for 9 popular courses, but now for the tweaked linear distribution that EB asked for.
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 0, 'UNN_units': 1, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.62, 'gui_forget_base_uniform': 0, 'gui_forget_adjustments':0, 'gui_base_noise_std': 22, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]

    elif model_family == 'POpop9EBv2_m0':
        noise_multiplier = 0.5
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.62 * noise_multiplier, 'gui_forget_base_uniform': 0, 'gui_forget_adjustments':0, 'gui_base_noise_std': 22 * noise_multiplier, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]


    elif model_family == 'POpop9EBLv2_m0':
        noise_multiplier = 0.5
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 0, 'UNN_units': 1, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.62 * noise_multiplier, 'gui_forget_base_uniform': 0, 'gui_forget_adjustments':0, 'gui_base_noise_std': 22 * noise_multiplier, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]
        
    elif model_family == 'noisy_EB_default':
        increment_multiplier = 6
        forget_base_increment = 0.02 * increment_multiplier
        base_std_increment = 2 * increment_multiplier
        models_to_run = [('LinearNoisy',  {'noisy_forget_base': 0.5 + forget_base_increment, 'noisy_forget_base_uniform': 0, 'noisy_forget_adjustments': 0,  'noisy_base_std': 10 + base_std_increment, 'noisy_adj_std': 0})]

    elif model_family == 'true_noise_default_p6':
        models_to_run = [('PairwiseAdjustmentsNoisy',  {'noisy_forget_base': 0.5, 'noisy_forget_adjustments': 0.4825, 'noisy_base_std': 17, 'noisy_adj_std': 0.2})]

    elif model_family == 'true_noise_default_p9':
        models_to_run = [('PairwiseAdjustmentsNoisy',  {'noisy_forget_base': 0.5, 'noisy_forget_adjustments': 0.48, 'noisy_base_std': 23, 'noisy_adj_std': 0.2})]

    
    # models that use a projection of the mvnns back into the gui language. Requirement: the mvnn models have already been trained (ie, they are loaded)
    elif model_family == 'POpop9_projected_v1':
        models_to_run = [('UNN_projected', {'alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000
                                  })]
    elif model_family == 'POpop9_projected_v1.1':
        models_to_run = [('UNN_projected', {'alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000
                                  })]
        
    elif model_family == 'POpop9_projected_v1.2':
        models_to_run = [('UNN_projected', {'alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 300, 'train_high_samples': 300
                                  })]
        
    elif model_family == 'POpop9_projected_v1.3':
        models_to_run = [('UNN_projected', {'alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 500, 'train_high_samples': 500
                                  })]
        
    elif model_family == 'POpop9_projected_v3.1':
        models_to_run = [('UNN_projected', {'alpha': 0.0 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000
                                  })]
    
    elif model_family == 'POpop9_projected_v3.2':
        models_to_run = [('UNN_projected', {'alpha': 0.0 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 300, 'train_high_samples': 300
                                  })]
        
    elif model_family == 'POpop9_projected_v2':
        models_to_run = [('UNN_projected', {'alpha': 0.1 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000
                                  })]
        
    elif model_family == 'POpop6_projected_v1':
        models_to_run = [('UNN_projected', {'alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000
                                  })]
        
    elif model_family == 'POpop6_projected_v2':
        models_to_run = [('UNN_projected', {'alpha': 0.1 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000
                                  })]
        
    # models that use projection but started from scratch (ie, no mvnn models are loaded in stage 1)
    elif model_family == 'POpop9_projected_v1.1_from_scratch_gui_reports': # use original GUI reports for approximate price calculation
        models_to_run = [('UNN_projected', {
        'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000   # projection parameters
                                  })]
        
    elif model_family == 'POpop9_projected_v1.1_from_scratch_hallucinated_gui_reports': # use hallucinated GUI reports for approximate price calculation
        models_to_run = [('UNN_projected', {
            'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000   # projection parameters
                                  })]
        
    elif model_family == 'POpop9_gui_reports': # use original GUI reports for approximate price calculation
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
    'project_to_gui': False, 'approximate_prices_model': 'gui_reports' 
                                  })]
        
    elif model_family == 'POpop9_hallucinated_gui_reports': # use hallucinated GUI reports for approximate price calculation
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
    'project_to_gui': True, 'approximate_prices_model': 'hallucinated_gui_reports' 
                                  })]
        
    # models that use GUI/hallucinated GUI reports for approximate price calculation, and also project, but for the LINEAR distribution with no adjustments. 
    elif model_family == 'POpop9EBv2_projected_v1.1_from_scratch_gui_reports':  # our default model for 9 popular courses, but now for the NEW linear EB distribution with no forgotten base values
        models_to_run = [('UNN_projected', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.62, 'gui_forget_base_uniform': 0, 'gui_forget_adjustments':0, 'gui_base_noise_std': 22, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
    'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000,
    'linear_projection': True})]
        

    # ----- new model families that use the basic option for CQs
    elif model_family == 'Bpop9':
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 2, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'basic', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]
        
    elif model_family == 'Bpop9_projected_v1.1_from_scratch_gui_reports':
        models_to_run = [('UNN_projected', {
        'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000,
        'linear_projection': True})]
        
    elif model_family == 'Bpop9EBLv2':  
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 0, 'UNN_units': 1, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.62, 'gui_forget_base_uniform': 0, 'gui_forget_adjustments':0, 'gui_base_noise_std': 22, 'gui_adj_std':0, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'basic', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]
        
    elif model_family == 'Bpop9EBv2_projected_v1.1_from_scratch_gui_reports':  
        models_to_run = [('UNN_projected', {
           'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000,
    'linear_projection': True})]
        
    elif model_family in ['POpop9_projected_v1.1_from_scratch_gui_reports_high_alpha', 'POpop9_projected_v1.1_from_scratch_gui_reports_higher_alpha', 'POpop9_projected_v1.1_from_scratch_gui_reports_highest_alpha']:
        if model_family == 'POpop9_projected_v1.1_from_scratch_gui_reports_high_alpha':
            alpha = 0.85
        elif model_family == 'POpop9_projected_v1.1_from_scratch_gui_reports_higher_alpha':
            alpha = 0.95
        elif model_family == 'POpop9_projected_v1.1_from_scratch_gui_reports_highest_alpha':
            alpha = 0.98
        models_to_run = [('UNN_projected', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'complete_ordering_pruned', 'alpha': alpha, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
    'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000   # projection parameters
                                  })]
    # --------------------------------------------
    elif model_family == 'Bpop9v1.1_projected_v1.1_from_scratch_gui_reports':    # same as the 2 above, but now also projecting the networks back to the gui reports
        models_to_run = [('UNN_projected', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'basic', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
    'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000   # projection parameters
                                  })]
        
    elif model_family == 'BPpop9v.1_projected_v1.1_from_scratch_gui_reports':  # NOTE: careful with the name, missed the v1.1...
        models_to_run = [('UNN_projected', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'basic_plus', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
    'project_to_gui': True, 'approximate_prices_model': 'gui_reports' ,'proj_alpha': 0.025 , 'ridge': 0.0, 'fit_intercept': False, 'train_samples': 5000, 'train_high_samples': 5000   # projection parameters
                                  })]
        
    # --- testing the basic/basic plus method that implicitly adds all implied points to the dataset, with the new interaction paradigm for the first CQ. 
    elif model_family == 'Bpop9v1.1':    # new logic for the ymax of the first CQ
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'basic', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]
    
    elif model_family == 'BPpop9v.1':
        models_to_run = [('UNN', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 'UNN_layers': 3, 'UNN_units': 20, 'lr':  0.001, 'weight_decay': 6e-5,
    'epochs': 400, 'batch_size': 8, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_loss_string': 'l1', 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True,
    'use_implied_dataset': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'use_cqs': True, 'cq_method': 'basic_plus', 'alpha': 0.65, 'clip_cardinal': 0.68, 'clip_ordinal': 0.41, 'use_gradient_clipping': True, 'UNN_loss_ord_string': 'BCE',
                                  })]
        

    # NEW TRANSFER LEARNING MODELS
    elif model_family == 'TLBP9v0.1':
        models_to_run = [('UNN_TL', {'samples': 0, 'samples_in_range': 0, 'range_min_value': 350, 
        'UNN_layers': 3, 'UNN_units': 20, 'UNN_init_E': 1, 'UNN_init_Var': 0.09, 'UNN_random_ts': [0, 1], 'UNN_trainable_ts': True, 'use_gradient_clipping': True,
    'gui_forget_base': 0.5, 'gui_forget_adjustments':0.48, 'gui_base_noise_std': 23, 'gui_adj_std':0.2, 'gui_sample_categories_weights': [1, 1, 1],
    'sample_relative_frequencies': (1, 1, 1, 1),
    'points_to_add': 40, 'points_to_hallucinate': 20, 'uniform_range_low': None, 'uniform_range_high': None, 'forgotten_course_expected_value': None,
    'lr_cardinal':  0.001, 'weight_decay_cardinal': 6e-5, 'epochs_cardinal': 400, 'batch_size_cardinal': 8,  'UNN_loss_string_cardinal': 'l1', 'clip_cardinal': 0.68,
    'lr_ordinal':  0.001, 'weight_decay_ordinal': 6e-5, 'epochs_ordinal': 400, 'batch_size_ordinal': 8,  'UNN_loss_string_ordinal': 'BCE', 'clip_cardinal': 0.68,
    'use_implied_dataset': True,
    'use_cqs': True, 'cq_method': 'basic_plus',
                                  })]
        
        
    elif model_family == 'true':
        set_trace()
        models_to_run = [('TrueLinear', None), ('PairwiseAdjustments', None), ('True', None)]

    elif model_family == 'true_gui':
        models_to_run = [('PairwiseAdjustments', None)]

    else:
        print('Invalid model family')
        return


    model_info = models_to_run[0]

    # set_trace()
    # now run the thing
    return (model_info, benchmark_model_info, model_student_list, benchmark_student_list,  prices_stage_1_GUI, timetable, capacities)


def get_true_value_of_allocation(benchmark_student_list, individual_demands, benchmark_model, timetable):
    """
    Score an allocation against the benchmark preference model.

    Row ``j`` of ``individual_demands`` is evaluated with entry ``j`` of
    ``benchmark_student_list``.

    Args:
        benchmark_student_list: one preference tuple per student. The layout
            depends on ``benchmark_model``:
            * 'True': (additive_prefs, substitutes, complements, overload_penalty,
              timegap_penalty, free_days_marginal_values, budget)
            * 'PairwiseAdjustments' / 'PairwiseAdjustmentsNoisy':
              (additive_prefs, substitutes, complements, budget)
            * 'TrueLinear': only element 0 (the additive preferences) is read.
        individual_demands: 2-D array; its second dimension is used as the
            number of courses.
        benchmark_model: name of the benchmark model (see above).
        timetable: forwarded unchanged to ``student``; not used for 'TrueLinear'.

    Returns:
        numpy array with one value per student.

    Raises:
        ValueError: if ``benchmark_model`` is not one of the supported names.
    """
    course_count = individual_demands.shape[1]

    if benchmark_model == 'True':
        values = []
        for idx, entry in enumerate(benchmark_student_list):
            # timegap_penalty and budget are part of the tuple but are not
            # passed on to the valuation call.
            (additive_prefs, substitutes, complements, overload_penalty,
             _timegap_penalty, free_days_marginal_values, _budget) = entry
            values.append(student(
                individual_demands[idx], additive_prefs, substitutes, complements, timetable,
                overload_penalty=overload_penalty,
                free_days_marginal_values=free_days_marginal_values,
                credit_units=[1] * course_count,
                make_monotone=True,
            ))
        return np.array(values)

    if benchmark_model in ('PairwiseAdjustments', 'PairwiseAdjustmentsNoisy'):
        values = []
        for idx, entry in enumerate(benchmark_student_list):
            additive_prefs, substitutes, complements, _budget = entry
            values.append(student(
                individual_demands[idx], additive_prefs, substitutes, complements, timetable,
                credit_units=[1] * course_count,
                make_monotone=True,
            ))
        return np.array(values)

    if benchmark_model == 'TrueLinear':
        # Purely additive valuation: dot product of the bundle with the base values.
        values = [
            np.dot(individual_demands[idx], entry[0])
            for idx, entry in enumerate(benchmark_student_list)
        ]
        return np.array(values)

    raise ValueError(f'Invalid benchmark model: {benchmark_model}')

def calculate_allocation(model_list, prices, timetable, models_to_run):
    """
    Return the per-student demands at ``prices``.

    Only the first ``(model_type, model_param_dictionary)`` pair in
    ``models_to_run`` is used. Every course is given one credit unit, and the
    length of ``prices`` along its first axis is taken as the number of courses.
    The total demand that ``calculate_total_demand`` also returns is discarded.
    """
    course_count = prices.shape[0]
    model_type, model_param_dictionary = models_to_run[0]

    _total_demand, per_student_demands = calculate_total_demand(
        prices,
        model_list,
        timetable,
        [1] * course_count,
        return_individual_demands=True,
        model_type=model_type,
        model_param_dictionary=model_param_dictionary,
    )

    return per_student_demands


def model_predict_values(model_type, model_param_dictionary, model, bundles, timetable = None):
    """
    Predict the value of every bundle in ``bundles`` under one student's model.

    Args:
        model_type: name selecting how ``model`` is interpreted:
            * 'UNN': ``(mvnn, solver, scale, budget)``; the network output is
              multiplied by ``scale``.
            * 'UNN_transfer_learning': ``(mvnn, solver, scale, pretrained_model, budget)``.
            * 'LinearNoisy' / 'TrueLinear': element 0 is the additive preference vector.
            * 'UNN_projected': ``(regressor, budget)`` where ``regressor`` exposes
              ``coef_`` and ``predict``.
            * 'PairwiseAdjustments' / 'PairwiseAdjustmentsNoisy':
              ``(additive_prefs, substitutes, complements, budget)``.
            * 'True': ``(additive_prefs, substitutes, complements, overload_penalty,
              timegap_penalty, free_days_marginal_values, budget)``.
        model_param_dictionary: accepted for signature compatibility; not read here.
        model: the per-student model tuple described above.
        bundles: sequence of bundle vectors (one entry per course).
        timetable: forwarded to ``student`` for the pairwise and 'True' models.
            Defaults to a fresh list of five empty lists.

    Returns:
        1-D numpy array with one predicted value per bundle (empty if ``bundles``
        is empty).

    Raises:
        ValueError: if ``model_type`` is not recognised.
    """
    if timetable is None:
        # Built per call so that no default list is shared between invocations.
        timetable = [[] for _ in range(5)]

    predicted_values = []

    # NOTE(review): model families configured elsewhere in this file use the type
    # name 'UNN_TL'; confirm whether that name should also be accepted here.
    if model_type in ('UNN', 'UNN_transfer_learning'):
        if model_type == 'UNN':
            mvnn, _solver, scale, _budget = model
        else:
            mvnn, _solver, scale, _pretrained_model, _budget = model

        for bundle in bundles:
            bundle_tensor = torch.tensor(bundle).float()
            predicted_values.append(mvnn(bundle_tensor).data.numpy()[0] * scale)

    elif model_type in ('LinearNoisy', 'TrueLinear'):
        additive_prefs = model[0]
        predicted_values = [np.dot(additive_prefs, bundle) for bundle in bundles]

    elif model_type == 'UNN_projected':
        regressor, _budget = model
        additive_prefs = regressor.coef_
        # len() rather than truthiness: ``bundles`` may be a numpy array.
        if len(bundles) > 0:
            if len(additive_prefs) == len(bundles[0]):
                # One coefficient per course: linear case, so a dot product suffices.
                predicted_values = [np.dot(additive_prefs, bundle) for bundle in bundles]
            else:
                # Non-linear case: expand to pairwise interaction features and
                # predict the whole batch in a single call.
                poly = PolynomialFeatures(degree = 2, include_bias = False, interaction_only = True)
                bundles_poly = poly.fit_transform(np.array(bundles))
                predicted_values = regressor.predict(bundles_poly)

    elif model_type in ('PairwiseAdjustments', 'PairwiseAdjustmentsNoisy'):
        (additive_prefs, substitutes, complements, _budget) = model
        number_of_courses = len(additive_prefs)

        for bundle in bundles:
            predicted_value = student(bundle, additive_prefs, substitutes, complements, timetable,
                overload_penalty = 0, free_days_marginal_values= [0 for i in range(5)],
                credit_units= [1 for i in range(number_of_courses)], make_monotone= True)
            predicted_values.append(predicted_value)

    elif model_type == 'True':
        (additive_prefs, substitutes, complements, overload_penalty, _timegap_penalty, free_days_marginal_values, _budget) = model
        number_of_courses = len(additive_prefs)

        for bundle in bundles:
            predicted_value = student(bundle, additive_prefs, substitutes, complements, timetable,
                overload_penalty = overload_penalty, free_days_marginal_values= free_days_marginal_values,
                credit_units= [1 for i in range(number_of_courses)], make_monotone= True)
            predicted_values.append(predicted_value)

    else:
        raise ValueError(f'Invalid model type: {model_type}')

    return np.array(predicted_values).flatten()


def generate_random_bundles(number_of_bundles = 1000, number_of_courses = 25, seed = 42, bundle_size = 5):
    """
    Draw random 0/1 bundle vectors, each containing exactly ``bundle_size`` courses.

    NOTE: this seeds NumPy's *global* random state with ``seed``, so it also
    affects later ``np.random`` calls made elsewhere in the process.

    Args:
        number_of_bundles: how many bundles to generate.
        number_of_courses: length of every bundle vector.
        seed: value passed to ``np.random.seed`` before sampling.
        bundle_size: number of courses set to 1 in each bundle (default 5, the
            previously hard-coded size).

    Returns:
        List of ``number_of_bundles`` float arrays of length ``number_of_courses``.

    Raises:
        ValueError: raised by ``np.random.choice`` when ``bundle_size`` exceeds
            ``number_of_courses`` (sampling is without replacement).
    """
    np.random.seed(seed)

    bundles = []
    for _ in range(number_of_bundles):
        bundle = np.zeros(number_of_courses)
        chosen_courses = np.random.choice(number_of_courses, size=bundle_size, replace=False)
        bundle[chosen_courses] = 1
        bundles.append(bundle)

    return bundles


def measure_generalization_performance_all_students(bundles_all_students, model_type, model_param_dictionary, benchmark_model_type, benchmark_student_list, model_student_list,
                                                    centered_metrics = False):
    """
    Measure generalization performance across all students for the given bundles.

    For student ``i`` the bundles ``bundles_all_students[i]`` are valued both by
    the learned model ``model_student_list[i]`` and by the benchmark model
    ``benchmark_student_list[i]``, and the two value vectors are compared.

    Args:
        bundles_all_students: per-student collections of bundles.
        model_type / model_param_dictionary: passed to ``model_predict_values``
            for the learned model.
        benchmark_model_type: model type used for the benchmark students.
        benchmark_student_list: benchmark models; its length sets the number of
            students evaluated.
        model_student_list: learned models, indexed like the benchmark list.
        centered_metrics: if True, both value vectors are mean-centered before
            the metrics are computed.

    Returns:
        Tuple of four numpy arrays with one entry per student:
        (Kendall tau, R^2, mean absolute error, mean squared error).
    """
    print(f'Measure generalization performance starting with model type: {model_type}')

    metrics = {'kt': [], 'r2': [], 'mae': [], 'mse': []}

    for idx, benchmark_student in enumerate(benchmark_student_list):
        predicted = model_predict_values(
            model_type=model_type, model_param_dictionary=model_param_dictionary,
            model=model_student_list[idx], bundles=bundles_all_students[idx])
        actual = model_predict_values(
            model_type=benchmark_model_type, model_param_dictionary=None,
            model=benchmark_student, bundles=bundles_all_students[idx])

        if centered_metrics:
            predicted = predicted - np.mean(predicted)
            actual = actual - np.mean(actual)

        # The p-value of the rank correlation is not reported.
        kendall_tau, _p_value = stats.kendalltau(predicted, actual)
        mean_abs_error = np.mean(np.abs(predicted - actual))
        mean_sq_error = np.mean((predicted - actual)**2)
        r_squared = r2_score(actual, predicted)

        metrics['kt'].append(kendall_tau)
        metrics['r2'].append(r_squared)
        metrics['mae'].append(mean_abs_error)
        metrics['mse'].append(mean_sq_error)

    return np.array(metrics['kt']), np.array(metrics['r2']), np.array(metrics['mae']), np.array(metrics['mse'])

def _stack_per_student(arrays):
    """Stacks equally-shaped per-student arrays; uses a 1-D object array when their shapes differ."""
    if len({np.shape(array) for array in arrays}) <= 1:
        return np.array(arrays)

    # Ragged input: np.array() on it raises on recent NumPy versions, so the
    # object array is filled element by element instead.
    ragged = np.empty(len(arrays), dtype=object)
    for position, array in enumerate(arrays):
        ragged[position] = array
    return ragged


def keep_high_valued_bundles(bundles, benchmark_student_list, benchmark_model_type,percentile = 95, return_values = False): 
    """
    For each benchmark student, keeps only the bundles whose benchmark value is strictly
    above the given percentile of that student's values over all `bundles`.

    Parameters:
        bundles: the candidate bundles (shared by all students).
        benchmark_student_list: one benchmark model per student.
        benchmark_model_type: model type passed to model_predict_values for the benchmark.
        percentile: percentile (0-100) used as the per-student cut-off.
        return_values: if True, also return the benchmark values of the kept bundles.

    Returns:
        new_bundles, indexed by student. When every student keeps the same number of
        bundles this is a regular stacked array; when the counts differ (e.g. ties at
        the threshold) it is a 1-D object array holding one array per student.
        If return_values is True, a tuple (new_bundles, new_bundles_values) built the same way.
    """
    bundles_array = np.array(bundles)
    kept_bundles = []
    kept_values = []

    for benchmark_student in benchmark_student_list:
        model_values = model_predict_values(model_type= benchmark_model_type, model_param_dictionary= None, model= benchmark_student, bundles= bundles)
        threshold_value = np.percentile(model_values, percentile)
        above_threshold = model_values > threshold_value  # strict: ties with the threshold are dropped

        kept_bundles.append(bundles_array[above_threshold])
        kept_values.append(model_values[above_threshold])

    new_bundles = _stack_per_student(kept_bundles)
    new_bundles_values = _stack_per_student(kept_values)

    if return_values:
        return new_bundles, new_bundles_values

    return new_bundles

def _additive_fit_metrics(true_values, model_values):
    """Returns (kendall tau, r2, mae, mse) of model_values measured against true_values."""
    errors = true_values - model_values
    mae = np.abs(errors).mean()
    mse = (errors**2).mean()
    r2 = r2_score(true_values, model_values)
    kt = stats.kendalltau(true_values, model_values)[0]
    return kt, r2, mae, mse


def compare_gui_reports(true_student_list, other_student_list, model_type, percentile = 0, linear_instances = True):
    """
    Takes as input the true student list, and another (projected ML/gui) student list and
    compares how well the per-course base values of the two lists match.

    For every student j the true additive preferences are true_student_list[j][0]. The other
    student's base values are read from other_student_list[j][0]:
        - 'TrueLinear' / 'LinearNoisy' (linear instances) and
          'PairwiseAdjustments' / 'PairwiseAdjustmentsNoisy' (non-linear instances): used directly.
        - 'UNN_projected': taken from the `.coef_` attribute; for non-linear instances only the
          first len(additive preferences) coefficients are used (previously a hard-coded 25).
          NOTE(review): assumes the leading coefficients are the per-course terms - confirm.

    Only courses whose true additive preference is >= the given percentile of that student's
    true additive preferences enter the metrics (percentile = 0 keeps every course).

    Returns:
        Four arrays with one entry per student: Kendall tau, R^2, MAE, MSE.
        If model_type is not handled for the given linear_instances setting, nothing is
        raised and four empty arrays are returned.
    """
    print('Compare GUI reports starting with model type: ', model_type)

    kts, r2s, maes, mses = [], [], [], []

    if linear_instances:
        directly_reported_types = ['TrueLinear', 'LinearNoisy']
    else:
        directly_reported_types = ['PairwiseAdjustments', 'PairwiseAdjustmentsNoisy']

    reported_directly = model_type in directly_reported_types
    projected = (not reported_directly) and model_type in ['UNN_projected']

    if reported_directly or projected:
        for j, true_student in enumerate(true_student_list):
            additive_prefs = true_student[0]
            model_base_values = other_student_list[j][0]

            if projected:
                model_base_values = model_base_values.coef_
                if not linear_instances:
                    # drop the coefficients that follow the per-course ones
                    model_base_values = model_base_values[:len(additive_prefs)]

            threshold_value = np.percentile(additive_prefs, percentile)
            indexes = additive_prefs >= threshold_value

            kt, r2, mae, mse = _additive_fit_metrics(additive_prefs[indexes], model_base_values[indexes])

            kts.append(kt)
            r2s.append(r2)
            maes.append(mae)
            mses.append(mse)

    return np.array(kts), np.array(r2s), np.array(maes), np.array(mses)

def measure_gui_language_metrics(model_list, model_type, number_of_courses = 25, adjustment_threshold = 5): 
    """
    Gets as input the model list of all students and measures the number of base values and
    adjustments implied by that list.

    Parameters:
        model_list: one model per student; model[0] holds the base values (or, for
                    'UNN_projected', an object exposing `.coef_`).
        model_type: selects how each model is read:
            - 'TrueLinear' / 'LinearNoisy': base values = entries of model[0] that are > 0,
              adjustments = 0.
            - 'PairwiseAdjustments' / 'PairwiseAdjustmentsNoisy': base values as above;
              every set s in model[2] (complements) and model[1] (substitutes) contributes
              C(len(s[0]), 2) adjustments.
            - 'UNN_projected': base values = number of the first `number_of_courses`
              coefficients that are > 0; adjustments = number of remaining coefficients
              that are > adjustment_threshold.
              NOTE(review): coefficients below -adjustment_threshold are not counted - confirm intended.
        number_of_courses: how many leading coefficients are base values ('UNN_projected' only).
        adjustment_threshold: minimum coefficient for an adjustment to count ('UNN_projected' only).

    Returns:
        (base_values, adjustments): two arrays with one count per model.

    Raises:
        NotImplementedError: if model_type is none of the types listed above.
    """
    base_values = []
    adjustments = []

    if model_type in ['TrueLinear', 'LinearNoisy']:
        for model in model_list:
            base_values.append((model[0] > 0).sum())
            adjustments.append(0)

    elif model_type in ['PairwiseAdjustments', 'PairwiseAdjustmentsNoisy']:
        for model in model_list:
            base_values.append((model[0] > 0).sum())
            substitutes, complements = model[1], model[2]
            # a set of k courses stands for one adjustment per unordered pair inside it
            adjustments.append(sum(math.comb(len(course_set[0]), 2)
                                   for group in (complements, substitutes)
                                   for course_set in group))

    elif model_type in ['UNN_projected']:
        for model in model_list:
            coefs = model[0].coef_
            base_values.append((coefs[:number_of_courses] > 0).sum())
            adjustments.append((coefs[number_of_courses:] > adjustment_threshold).sum())

    else:
        message = f'Model type {model_type} not implemented'
        print(message)
        raise NotImplementedError(message)

    return np.array(base_values), np.array(adjustments)




if __name__ == '__main__':
    # Script entry point: compares one model family against a benchmark family over
    # several problem instances and pickles every collected metric in a single dictionary.
    import os  # local import: only needed here to prepare the output directory

    parser = argparse.ArgumentParser()
    parser.add_argument('--supply_ratio', type=float, default= 1.25, help='supply_ratio')
    parser.add_argument('--number_of_popular', type= int, default= 9, help='afffects_correlation')
    # Other model families noted for this script: POpop9_projected_v1.1_from_scratch_gui_reports,
    # and POpop9EBv2_projected_v1.1_from_scratch_gui_reports for linear instances.
    parser.add_argument('--model_family', type= str, default= 'Bpop9v1.1_projected_v1.1_from_scratch_gui_reports', help='the models to run') # options: [PO/B]pop9_projected_v1.1_from_scratch_gui_reports
    parser.add_argument('--benchmark_family', type= str, default= 'true', help='the family to compare against: options: true and the noisy gui reports')
    parser.add_argument('--queries', type= str, default= '30-10', help='the number of queries of the model family, if it supports such a thing')
    parser.add_argument('--linear_instances', type= str, default= 'true', help='whether we are in linear instances or not')

    args = parser.parse_args()

    supply_ratio = args.supply_ratio
    number_of_popular = args.number_of_popular
    index = 0
    model_family = args.model_family
    benchmark_family = args.benchmark_family
    queries = args.queries

    # the flag arrives as text; anything other than 'true' (case-insensitive) means False
    linear_instances = str(args.linear_instances).lower() == 'true'

    print(f'supply_ratio: {supply_ratio} number_of_popular: {number_of_popular} index:{index} model_family: {model_family} queries: {queries}')

    # Make sure the results can be written *before* doing the expensive work, so a missing
    # directory cannot throw away the whole computation at the very end.
    output_path = f'./learning_results/model_comparison_v2_linear_instances_{linear_instances}_queries_{queries}_popular_{number_of_popular}_model_family_{model_family}_benchmark_family_{benchmark_family}'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    number_of_runs = 20
    metric_names = ('kts', 'r2s', 'maes', 'mses')  # order in which the metric functions return their arrays

    # One list per result key; each run appends one entry. Key order matches the saved dictionary.
    result_keys = (
        [f'{name}_model' for name in metric_names]
        + [f'{name}_model_high_valued' for name in metric_names]
        + [f'gui_language_{name}_model' for name in metric_names]
        + [f'gui_language_{name}_model_top5' for name in metric_names]
        + [f'gui_language_{name}_model_top10' for name in metric_names]
        + ['gui_language_base_values_model_all_runs', 'gui_language_adjustments_model_all_runs', 'allocation_values_model']
    )
    results_all_runs = {key: [] for key in result_keys}

    # (key suffix, extra keyword arguments) for the three compare_gui_reports passes.
    # Presumably, with 25 courses, the 60th / 80th percentile cut-offs keep about the
    # top 10 / top 5 courses, hence the suffix names.
    gui_comparisons = (('_top10', {'percentile': 60}), ('', {}), ('_top5', {'percentile': 80}))

    for run in range(number_of_runs):
        index_to_run = index + run
        bundles = generate_random_bundles(number_of_bundles= 1000, number_of_courses= 25, seed = 42 + run)

        (model_info, benchmark_model_info, model_student_list, benchmark_student_list,  prices_stage_1_GUI, timetable, capacities) = load_all_models_v2(index = index_to_run, model_family= model_family, benchmark_family = benchmark_family, queries = queries, supply_ratio= supply_ratio, number_of_popular= number_of_popular, linear_instances = linear_instances)

        # Measure the number of base values and adjustments
        base_values_model, adjustments_model = measure_gui_language_metrics(model_list = model_student_list, model_type = model_info[0])
        results_all_runs['gui_language_base_values_model_all_runs'].append(base_values_model)
        results_all_runs['gui_language_adjustments_model_all_runs'].append(adjustments_model)

        # Measure similarity between projected ML models/GUI reports and the true (additive) preferences
        for suffix, extra_kwargs in gui_comparisons:
            gui_metrics = compare_gui_reports(true_student_list= benchmark_student_list, other_student_list= model_student_list, model_type= model_info[0], linear_instances = linear_instances, **extra_kwargs)
            for name, values in zip(metric_names, gui_metrics):
                results_all_runs[f'gui_language_{name}_model{suffix}'].append(values)

        # Measure the allocation value of the model student list under the benchmark students
        allocation_model = calculate_allocation(model_list= model_student_list, prices= prices_stage_1_GUI, timetable= timetable, models_to_run= [model_info])
        allocation_value_model = get_true_value_of_allocation(benchmark_student_list= benchmark_student_list, individual_demands= allocation_model, benchmark_model= benchmark_model_info[0])
        results_all_runs['allocation_values_model'].append(allocation_value_model)

        # Measure generalization performance on the random bundles (same bundles for every student)
        bundles_per_student = [bundles] * len(benchmark_student_list)
        generalization_metrics = measure_generalization_performance_all_students(bundles_all_students= bundles_per_student, model_type= model_info[0],
                                        model_param_dictionary= model_info[1], benchmark_student_list= benchmark_student_list, benchmark_model_type= benchmark_model_info[0],model_student_list= model_student_list)
        for name, values in zip(metric_names, generalization_metrics):
            results_all_runs[f'{name}_model'].append(values)

        # Measure generalization performance on high valued bundles only
        bundles_filtered = keep_high_valued_bundles(bundles= bundles, percentile = 95, benchmark_student_list= benchmark_student_list, benchmark_model_type= benchmark_model_info[0])
        high_valued_metrics = measure_generalization_performance_all_students(bundles_all_students= bundles_filtered, model_type= model_info[0], model_param_dictionary= model_info[1],
                                                            benchmark_student_list= benchmark_student_list, benchmark_model_type= benchmark_model_info[0] ,model_student_list= model_student_list)
        for name, values in zip(metric_names, high_valued_metrics):
            results_all_runs[f'{name}_model_high_valued'].append(values)

    # Collapse the per-run entries of every result into one flat array.
    result_dict = {key: np.array(per_run_values).flatten() for key, per_run_values in results_all_runs.items()}

    save_obj(result_dict, output_path)

    
