import numpy as np
import math
from timeit import default_timer as timer
from noisy_binary_search import noisy_binary_search_and_insert, noisy_binary_search_hacky
from pdb import set_trace
from sklearn.cluster import KMeans
from gce_losses import GeneralizedCrossEntropyLoss



from cleanup import solve_student, student, calculate_bundle_value_by_hand, solve_student_simple_linear
from util_dataset import create_initial_dataset


from sklearn.preprocessing import MinMaxScaler, PolynomialFeatures

import torch
from torchinfo import summary
import logging

# --- MVNN imports ---
from ca_networks_inner.mvnn import MVNN, compute_metrics
from ca_networks_inner.mvnn import train


# -- MIP imports ---
from gurobi_mip_mvnn_2 import GUROBI_MIP2_MVNN
from gurobi_mip_mvnn_2_linear import GUROBI_MIP2_MVNN_LINEAR


import torch.nn.functional as F
import copy 

from util_dataset import create_multiple_datasets_enhanced
from xgboost_solver import gurobi_MIP_xgboost
from gurobi_SVR_solver import gurobi_MIP_SVR
import xgboost

# --- imports after principled tabu ---
import scipy
import pickle
import pandas as pd

# --- stuff that should be imported for the sake of completeness ---
from cleanup import timetable_generator  # not used right now -> instances already created
from util_dataset import create_multiple_students  # not used right now -> instances arlready created
from sklearn import linear_model   # not used right now -> we are using different networks
from sklearn.svm import NuSVR   # not used right now -> we are using different networks

# --- imports for projecting MVNNs back to the GUI language ---
from projecting_mvnns_utils import sample_mvnn, poly_regression_mvnn
from gurobi_mip_poly_regression import GUROBI_MIP_POLY_REGRESSION


###############
def load_obj(name):
    """
    Load and return the object pickled in the file ``name + '.pkl'``.

    The resolved file name is printed before the file is opened.

    Parameters:
    --------------------
    name: str
        Path of the pickle file, without the ``.pkl`` extension.

    Returns:
    --------------------
    The unpickled object.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on trusted files.
    """
    pickle_path = name + '.pkl'
    print(pickle_path)
    with open(pickle_path, 'rb') as pickle_file:
        loaded_object = pickle.load(pickle_file)
    return loaded_object

def calculate_theoretical_bound_squared(number_of_courses_M = 30, largest_bundle_size_k = 5):
    """
    Returns the squared clearing error under which the allocation satisfies the fairness guarantees proved by Eric Budish.

    Budish (2011) proves that an approximate CEEI exists whose clearing error is at most
    alpha = sqrt(sigma * M) / 2 with sigma = min(2k, M). Hence
    alpha^2 = min(2k, M) * M / 4, which equals k * M / 2 whenever 2k <= M.

    Parameters:
    -------------------
    number_of_courses_M: int
        The number of courses available (e.g. 351 for the spring semester at Wharton)
    largest_bundle_size_k: int
        The largest bundle ``requested'' by some student (or possible to be requested)

    Returns:
    -------------------
    alpha_squared: float
        The theoretical clearing error squared under which those theoretical guarantees hold.
    """
    # sigma caps at M: a bundle can never contain more than all M courses.
    sigma = min(2 * largest_bundle_size_k, number_of_courses_M)
    return sigma * number_of_courses_M / 4


def calculate_alpha(capacities, demands, prices, return_zeta = False):
    """
    Compute the squared clearing error a^2 of a price vector, as defined in Budish et Al.

    Per course, the excess supply is capacity minus demand (the negative of Budish's excess
    demand; the sign is irrelevant once squared). Spare seats in a course whose price is
    exactly 0 are not counted as clearing error.

    Parameters:
    --------------------
    capacities: numpy array of shape (number of courses, )
        capacities[i]: The seat capacity of the i-th course
    demands: numpy array of shape (number of courses, )
        demands[i]: The total student demand for the i-th course
    prices: numpy array of shape (number of courses, )
        prices[i]: The price of the i-th course
    return_zeta: Bool
        If true: Returns the per-course excess-supply vector instead of its squared norm.

    Returns:
    --------------------
    alpha_squared: number (int or float, following the dtype of capacities - demands)
        The clearing error squared for the A-CEEI (a^2), as defined in Budish et Al.
        If return_zeta is True, the (adjusted) vector capacities - demands is returned instead.
    """
    zeta = capacities - demands

    # Free courses that still have empty seats do not contribute to the clearing error.
    free_and_undersubscribed = [
        course for course in range(demands.shape[0])
        if prices[course] == 0 and zeta[course] > 0
    ]
    zeta[free_and_undersubscribed] = 0

    if return_zeta:
        return zeta
    return np.square(zeta).sum()


def calculate_single_student_demand(prices, student_profile, course_timetable, credit_units = None, model_type = 'True', courses_per_student = 5, budget = None):
    """
    Takes as input a price vector of courses, a student's preferences and optionally her budget, and returns the optimal legal schedule for that student.
    If no budget is given, her initial budget (the last entry of student_profile) is used.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    student_profile: tuple
        The preferences of that student; the layout depends on model_type:
            'True', 'TrueNoisy': (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, initial_budget)
            'TrueLinear', 'LinearNoisy': (linear_coefficients, initial_budget)
            'PairwiseAdjustments', 'PairwiseAdjustmentsNoisy': (additive_prefs, substitutes_clipped, complements_clipped, initial_budget)
    course_timetable: list of lists of ints
        course_timetable[i][j]: The ids of all courses being taught in the j-th timeslot of the i-th day
    credit_units: list of floats, optional
        credit_units[i]: The credit units of the i-th course. Defaults to 25 courses of one credit unit each.
    model_type: string
        The type of problem instance to solve. e.g. for `True' it solves the MIP of the true preferences of each student.
    courses_per_student: int
        The maximum number of courses each student is willing to take.
    budget: float, optional
        The budget to use instead of the student's initial budget.

    Returns:
    --------------------
    student_demand: np.array of shape (number_of_courses, )
        student_demand[i]: 1 if the i-th course is included in the student's optimal schedule, 0 otherwise

    Raises:
    --------------------
    ValueError
        If model_type is not one of the supported types listed above.
    """
    if credit_units is None:
        credit_units = [1 for _ in range(25)]

    if model_type in ('True', 'TrueNoisy'):
        (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, initial_budget) = student_profile
        if budget is None:
            budget = initial_budget

        student_demand = solve_student(course_timetable, prices, credit_units, budget, courses_per_student, additive_prefs, complements, substitutes, overload_penalty = overload_penalty,
                        timegap_penalty= timegap_penalty, free_days_marginal_values= free_days_marginal_values, ignore_timegaps= True, verbose = False)

    elif model_type in ('TrueLinear', 'LinearNoisy'):
        (linear_coefficients, initial_budget) = student_profile
        if budget is None:
            budget = initial_budget

        student_demand = solve_student_simple_linear(course_timetable= course_timetable, course_prices= prices, credit_units= credit_units,
                        budget = budget, cu_limit = courses_per_student, additive_preferences = linear_coefficients, already_taken = [],
                        verbose = False, print_time= False, print_solution = False, seats_available = None)

    elif model_type in ('PairwiseAdjustments', 'PairwiseAdjustmentsNoisy'):
        (additive_prefs, substitutes_clipped, complements_clipped, initial_budget) = student_profile
        if budget is None:
            budget = initial_budget

        # Pairwise terms only: the overload/timegap/free-day penalties are switched off.
        student_demand = solve_student(course_timetable, prices, credit_units, budget, courses_per_student, additive_prefs, complements_clipped, substitutes_clipped,
                                        overload_penalty = 0, timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True, verbose = False)

    else:
        # Previously this fell through to an UnboundLocalError on `student_demand`.
        raise ValueError(f'Invalid model type provided at function calculate_single_student_demand. model type: {model_type}')

    return np.array(student_demand)



def _solve_simple_student(course_timetable, prices, credit_units, budget, courses_per_student, additive_prefs, complements, substitutes):
    """
    Solve one student's demand MIP with `solve_student`, with all overload/timegap/free-day penalties switched off.

    Used for linear preferences (empty complements/substitutes) and for pairwise-adjusted preferences.
    Returns the optimal schedule as a numpy array.
    """
    student_demand = solve_student(course_timetable, prices, credit_units, budget, courses_per_student, additive_prefs, complements, substitutes,
                                   overload_penalty = 0, timegap_penalty = 0, free_days_marginal_values = [0, 0, 0, 0, 0],
                                   ignore_timegaps = True, verbose = False)
    return np.array(student_demand)


def _solve_projected_student(model, course_timetable, credit_units, prices, budget, courses_per_student):
    """
    Solve the MIP of a student whose UNN was projected onto a degree-2 (interaction-only) polynomial regression.

    After solving, the MIP objective is compared with `model.predict` on the chosen schedule, and a warning is
    printed if they differ by more than 1e-5. Returns the optimal schedule as a numpy array.
    """
    solver = GUROBI_MIP_POLY_REGRESSION(model)
    solver.generate_mip(course_timetable = course_timetable,
                        credit_units = credit_units,
                        course_prices = prices,
                        budget = budget,
                        cu_max = courses_per_student,
                        timeLimit = None,
                        MIPGap = None,
                        verbose = False,
                        )

    # get the opt schedule according to the solver
    student_demand = np.array(solver.solve_mip(verbose = False))

    # Sanity check that the MIP encodes the regression model correctly.
    # NOTE: should comment these lines out for extra efficiency once any potential kinks have been ironed out!
    value_mip = solver.mip.getObjective().getValue()
    poly = PolynomialFeatures(degree = 2, include_bias = False, interaction_only = True)
    demand_poly = poly.fit_transform(student_demand.reshape(1, -1))
    value_model = model.predict(demand_poly)
    if (np.abs(value_model - value_mip) > 1e-5)[0]:
        print('Achtung achtung, models disagree on value: ', value_model, value_mip)

    return student_demand


def _solve_budgeted_mips(solver_budget_pairs, prices, solve_schedule, mip_count_label, mip_label):
    """
    Add each student's budget constraint to her pre-built MIP, solve it, and print the average timings.

    solver_budget_pairs yields (solver, budget); solve_schedule(solver) returns the optimal schedule.
    Returns the list of per-student demands as numpy arrays.
    """
    demands = []
    time_add_budget_total = 0
    time_solve_mip_total = 0
    for solver, budget in solver_budget_pairs:
        start = timer()
        solver.add_budget_constraint(course_prices = prices, budget = budget)
        mid = timer()
        demands.append(np.array(solve_schedule(solver)))
        end = timer()
        time_add_budget_total += (mid - start)
        time_solve_mip_total += (end - mid)

    number_solved = len(demands)
    if number_solved > 0:   # nothing to report (and no division by zero) when there are no students
        print(f'AVG time to add the budget constraint after {number_solved} {mip_count_label}: {time_add_budget_total / number_solved}')
        print(f'AVG time to actually solve {mip_label} after {number_solved} {mip_count_label}: {time_solve_mip_total / number_solved}')
    return demands


def calculate_total_demand(prices, student_profiles, course_timetable, credit_units = None, return_individual_demands = False,
                           model_type = 'True', model_param_dictionary = None, courses_per_student = 5):
    """
    A function that takes as input the price vector and the student_preferences and returns the aggregated demand for all courses.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    student_profiles: list of tuples
        One entry per student; the tuple layout depends on model_type:
            'True', 'TrueNoisy': (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)
            'TrueLinear', 'LinearNoisy', 'LinearRegression', 'Ridge', 'Lasso', 'ElasticNet' (and their 'Noisy' variants): (linear_coefficients, budget)
            'PairwiseAdjustments', 'PairwiseAdjustmentsNoisy': (additive_prefs, substitutes_clipped, complements_clipped, budget)
            'UNN_projected': (model, budget)
            'UNN', 'UNN_Noisy': (model, solver, budget), or (model, solver, scale, budget) if model_param_dictionary['use_cqs']
            'UNN_transfer_learning': (model, solver, scale, pre_trained_model, budget)
            'NuSVR', 'NuSVRNoisy': (model, solver, gamma, budget)
            'xgboost', 'xgboostNoisy': (model, solver, budget)
    course_timetable: list of lists of ints
        course_timetable[i][j]: The ids of all courses being taught in the j-th timeslot of the i-th day
    credit_units: list of floats, optional
        credit_units[i]: The credit units of the i-th course. Defaults to 30 courses of one credit unit each.
    return_individual_demands: bool
        If True, also return the per-student demands.
    model_type: string
        The type of problem instance to solve. e.g. for `True' it solves the MIP of the true preferences of each student.
    model_param_dictionary: dict or None
        Extra options. 'linear_projection' is read for 'UNN_projected' and 'use_cqs' for 'UNN'/'UNN_Noisy'.
        None is treated as an empty dictionary.
    courses_per_student: int
        The maximum number of courses each student is willing to take (also the cap of the projected-UNN MIP).

    Returns:
    --------------------
    total_demand: np.array of shape (number_of_courses, )
        total_demand[i]: The total demand of all students for the i-th course
    individual_demands: np.array of shape (number_of_students, number_of_courses)
        Only returned (as the second element of a tuple) if return_individual_demands is True.

    An unknown model_type only prints a warning; the returned total demand is then all zeros.
    """
    if credit_units is None:
        credit_units = [1 for _ in range(30)]
    if model_param_dictionary is None:
        model_param_dictionary = {}

    linear_model_types = ('LinearRegression', 'Ridge', 'Lasso', 'ElasticNet',
                          'LinearRegressionNoisy', 'RidgeNoisy', 'LassoNoisy', 'ElasticNetNoisy',
                          'TrueLinear', 'LinearNoisy')

    def solve_unn(solver):
        return solver.solve_mip(outputFlag = False, verbose = False)

    def solve_svr(solver):
        optimal_schedule, _optimal_value = solver.solve_mip(verbose = False)
        return optimal_schedule

    def solve_xgboost(solver):
        return solver.solve_mip()[0]

    total_demand = np.zeros(prices.shape[0])
    individual_demands = []

    if model_type in ('True', 'TrueNoisy'):
        for (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in student_profiles:
            student_demand = solve_student(course_timetable, prices, credit_units, budget, courses_per_student, additive_prefs, complements, substitutes, overload_penalty = overload_penalty,
                            timegap_penalty= timegap_penalty, free_days_marginal_values= free_days_marginal_values, ignore_timegaps= True, verbose = False)
            individual_demands.append(np.array(student_demand))

    elif model_type in linear_model_types:
        for (linear_coefficients, budget) in student_profiles:
            individual_demands.append(_solve_simple_student(course_timetable, prices, credit_units, budget, courses_per_student,
                                                            linear_coefficients, [], []))

    elif model_type in ('PairwiseAdjustments', 'PairwiseAdjustmentsNoisy'):
        for (additive_prefs, substitutes_clipped, complements_clipped, budget) in student_profiles:
            individual_demands.append(_solve_simple_student(course_timetable, prices, credit_units, budget, courses_per_student,
                                                            additive_prefs, complements_clipped, substitutes_clipped))

    elif model_type == 'UNN_projected':
        if not model_param_dictionary.get('linear_projection', False):
            for (model, budget) in student_profiles:
                individual_demands.append(_solve_projected_student(model, course_timetable, credit_units, prices, budget, courses_per_student))
        else:  # if we are in the linear preferences setting -> we can use the MIP for the GUI language directly.
            for (model, budget) in student_profiles:
                individual_demands.append(_solve_simple_student(course_timetable, prices, credit_units, budget, courses_per_student,
                                                                model.coef_.reshape(-1), [], []))

    elif model_type in ('UNN', 'UNN_Noisy'):
        if not model_param_dictionary.get('use_cqs', False):
            solver_budgets = ((solver, budget) for (_, solver, budget) in student_profiles)
        else:
            solver_budgets = ((solver, budget) for (_, solver, _, budget) in student_profiles)
        individual_demands = _solve_budgeted_mips(solver_budgets, prices, solve_unn, 'UNN MIPS', 'a UNN MIP')

    elif model_type == 'UNN_transfer_learning':
        solver_budgets = ((solver, budget) for (_, solver, _, _, budget) in student_profiles)
        individual_demands = _solve_budgeted_mips(solver_budgets, prices, solve_unn, 'UNN MIPS', 'a UNN MIP')

    elif model_type in ('NuSVR', 'NuSVRNoisy'):
        solver_budgets = ((solver, budget) for (_, solver, _, budget) in student_profiles)
        individual_demands = _solve_budgeted_mips(solver_budgets, prices, solve_svr, 'SVR MIPS', 'a SVR MIP')

    elif model_type in ('xgboost', 'xgboostNoisy'):
        solver_budgets = ((solver, budget) for (_, solver, budget) in student_profiles)
        individual_demands = _solve_budgeted_mips(solver_budgets, prices, solve_xgboost, 'XGBOOST MIPS', 'an xgboost MIP')

    else:
        print(f'Invalid model type provided at function calculate total demand. model type: {model_type}')

    # Aggregate in student order, exactly like a running sum.
    for student_demand in individual_demands:
        total_demand = total_demand + student_demand

    if return_individual_demands:
        return (total_demand, np.array(individual_demands))

    return total_demand


def percentage_neighbor(prices, capacities, demands, max_budget = 1.01, percentage_multiplier = 0.1):
    """
    Build a neighbor by scaling each price in proportion to its relative over/under-subscription.

    Each price p_i becomes p_i * (1 + percentage_multiplier * (demand_i - capacity_i) / capacity_i),
    and the result is clipped to the interval [0, max_budget].

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    capacities: np.array of shape(number_of_courses, )
        capacities[i]: The capacity of the i-th course
    demands: np.array of shape(number_of_courses, )
        demands[i]: The demand of the i-th course
    max_budget: Float
        Upper bound for every new price
    percentage_multiplier: Float
        The value by which to multiply the percentage price change calculated based on the percentage of over/under-subscription

    Returns:
    --------------------
    new_prices: np.array of shape(number_of_courses, )
        new_prices[i]: The price of the i-th course for this neighbor.
    """
    relative_excess_demand = (demands - capacities) / capacities
    scaling = np.ones(prices.shape[0]) + relative_excess_demand * percentage_multiplier
    non_negative = np.maximum(0, prices * scaling)
    return np.minimum(non_negative, max_budget)


def gradient_neighbor(prices, capacities, demands, max_budget = 1.01,  gradient_multiplier = 0.002):
    """
    Build a neighbor by shifting each price in proportion to its excess demand, as described in Budish et Al.

    Each price p_i becomes p_i + gradient_multiplier * (demand_i - capacity_i), clipped to [0, max_budget].

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    capacities: np.array of shape(number_of_courses, )
        capacities[i]: The capacity of the i-th course
    demands: np.array of shape(number_of_courses, )
        demands[i]: The demand of the i-th course
    max_budget: Float
        Upper bound for every new price (the maximum budget of a student)
    gradient_multiplier: Float
        The value by which to change the price of a course for every point of over/under subscription

    Returns:
    --------------------
    new_prices: np.array of shape(number_of_courses, )
        new_prices[i]: The price of the i-th course for this neighbor.
    """
    excess_demand = demands - capacities
    shifted_prices = prices + excess_demand * gradient_multiplier
    # prices may neither become negative nor exceed what the richest student can pay
    floored_prices = np.maximum(0, shifted_prices)
    return np.minimum(max_budget, floored_prices)



def individual_neighbor_proper(prices, capacities, demands, individual_demands, courses_to_adjust, student_budgets):
    """
    Performs an individual adjustment neighbor, as described in Budish et Al.

    For every course in courses_to_adjust:
      - if it is undersubscribed (demand < capacity), its price is set to 0;
      - otherwise, among the students currently taking it, the two smallest leftover budgets
        (budget minus cost of their current schedule at the current prices) are found, and the course's
        price is raised by their midpoint. At the new price the student with the least leftover budget can
        no longer afford her current schedule, while the next one still can (unless the two leftovers tie).
        If fewer than two students take the course, the midpoint is undefined and its price is kept unchanged.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    capacities: np.array of shape(number_of_courses, )
        capacities[i]: The capacity of the i-th course
    demands: np.array of shape(number_of_courses, )
        demands[i]: The demand of the i-th course
    individual_demands: numpy array of shape (number_of_students, number_of_courses)
        individual_demands[i][j]: 1 if the i-th student takes the j-th course, else 0
    courses_to_adjust: iterable of ints
        Contains the ids of the courses to adjust in this individual adjustment neighbor (usually 1).
    student_budgets: sequence of floats of length number_of_students
        student_budgets[i]: The budget of the i-th student

    Returns:
    --------------------
    new_prices: np.array of shape(number_of_courses, ) with float dtype
        new_prices[i]: The price of the i-th course for this neighbor.
    """
    prices = np.asarray(prices, dtype = float)
    # Work on a float copy: an integer price vector would otherwise silently truncate the new prices.
    new_prices = np.array(prices, dtype = float)

    # One row per student; reshape keeps the 2-D layout even when there are no students.
    demand_matrix = np.asarray(individual_demands, dtype = float).reshape(-1, prices.shape[0])
    # Leftover budget of every student at the current prices (computed once, reused for every course).
    leftover_budgets = np.asarray(student_budgets, dtype = float) - demand_matrix @ prices

    for course in courses_to_adjust:
        oversubscription = demands[course] - capacities[course]
        if oversubscription < 0:   # course has free seats -> set price to zero
            new_prices[course] = 0
            continue

        takers = demand_matrix[:, course] == 1
        two_smallest = np.sort(leftover_budgets[takers])[:2]
        if two_smallest.shape[0] < 2:
            # Fewer than two students take this course: keep its current price in this neighbor.
            continue

        # Raise the price to "in between" the leftover budgets of the 2 students that have the hardest time affording the course.
        new_prices[course] = prices[course] + (two_smallest[0] + two_smallest[1]) / 2

    return new_prices


def create_neighbors(prices, capacities, demands, individual_demands, student_list, timetable, credit_units, student_budgets, number_percentage_neighbors = 3, max_percentage_multiplier = 1.5,
                     number_gradient_neighbors = 20, max_gradient_multiplier = (0.1 / (2**6)),
                     number_individual_neighbors = 30, max_budget = 1.01, model_type = 'True',
                    model_param_dictionary = None, courses_per_student = 5):
    """
    A function that creates all neighbors of a given price vector and returns them sorted with respect to their clearing error a^2 ascending, as
    described in Budish et Al.

    Three kinds of neighbors are generated:
      - number_percentage_neighbors ``percentage'' neighbors with multipliers step, 2*step, ..., max_percentage_multiplier;
      - number_gradient_neighbors ``gradient'' neighbors with multipliers step, 2*step, ..., max_gradient_multiplier;
      - ``individual adjustment'' neighbors: one per not-perfectly-subscribed course if there are at most
        number_individual_neighbors of them, otherwise number_individual_neighbors neighbors, each adjusting a random
        (near-equal sized) subset of those courses.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    capacities: np.array of shape(number_of_courses, )
        capacities[i]: The capacity of the i-th course
    demands: np.array of shape(number_of_courses, )
        demands[i]: The demand of the i-th course
    individual_demands: np.array of shape (number_of_students, number_of_courses)
        individual_demands[i][j]: 1 if the i-th student takes the j-th course, else 0
    student_list: list of tuples
        The student profiles, in the layout calculate_total_demand expects for model_type
    timetable: list of lists of ints
        The course timetable passed on to calculate_total_demand
    credit_units: list of floats
        NOTE(review): currently ignored and replaced by one credit unit per course (see below)
    student_budgets: sequence of floats
        student_budgets[i]: The budget of the i-th student
    max_budget: Float
        Upper bound for every price of a percentage/gradient neighbor
    model_type, model_param_dictionary, courses_per_student:
        Passed on to calculate_total_demand

    Returns:
    --------------------
    sorted_neighbors, sorted_neighbor_demands, sorted_neighbor_individual_demands, sorted_neighbor_clearing_errors, sorted_neighbor_types:
        numpy arrays with one entry per neighbor, sorted by ascending clearing error. Each neighbor type is a pair
        (kind, multiplier) where kind is 'percentage', 'gradient' or 'individual' (with multiplier 'nada').
    """
    neighbors = []
    neighbor_demands = []
    neighbor_individual_demands = []
    neighbor_clearing_errors = []
    neighbor_types = []

    # NOTE(review): the credit_units argument is overridden so that every course counts as one credit unit.
    credit_units = [1 for _ in range(prices.shape[0])]   # WARNING: THIS SHOULD BE CHANGED

    # generate all the ``percentage'' neighbors
    if number_percentage_neighbors > 0:
        percentage_step = max_percentage_multiplier / number_percentage_neighbors
        for i in range(1, number_percentage_neighbors + 1):
            percentage_multiplier = percentage_step * i
            new_neighbor = percentage_neighbor(prices, capacities, demands, max_budget = max_budget, percentage_multiplier= percentage_multiplier)
            neighbors.append(new_neighbor)
            # record the multiplier actually used (consistent with the gradient neighbors below)
            neighbor_types.append(('percentage', percentage_multiplier))

    # generate all the ``gradient'' neighbors
    if number_gradient_neighbors > 0:
        gradient_step = max_gradient_multiplier / number_gradient_neighbors
        for i in range(1, number_gradient_neighbors + 1):
            gradient_multiplier = gradient_step * i
            new_neighbor = gradient_neighbor(prices, capacities, demands, max_budget = max_budget, gradient_multiplier= gradient_multiplier)
            neighbors.append(new_neighbor)
            neighbor_types.append(('gradient', gradient_multiplier))

    # generate all the ``individual adjustment'' neighbors
    diffs = capacities - demands
    not_perfectly_subscribed_number = np.count_nonzero(diffs)
    print("Not Perefctly subscribed:", not_perfectly_subscribed_number)

    if not_perfectly_subscribed_number <= number_individual_neighbors:
        # few ``wrongly subscribed'' courses -> each individual adjustment neighbor affects a single course
        print("entering first case: all individual neighbors single course!")
        for course in np.flatnonzero(diffs):
            new_neighbor = individual_neighbor_proper(prices, capacities, demands, individual_demands, [course], student_budgets)
            neighbors.append(new_neighbor)
            neighbor_types.append(('individual', 'nada'))

    elif number_individual_neighbors >= 1:
        # too many ``wrongly subscribed'' courses -> each neighbor adjusts a random, near-equal sized group of them
        print("entering second case: individual neighbors with more than one courses!")
        non_zero_indices = np.nonzero(diffs)[0]   # get the indices of all non-perfectly subscribed courses
        np.random.shuffle(non_zero_indices)       # shuffle them!
        adjustment_splits = np.array_split(non_zero_indices, number_individual_neighbors)
        for courses_to_adjust in adjustment_splits:
            new_neighbor = individual_neighbor_proper(prices, capacities, demands, individual_demands, courses_to_adjust, student_budgets)
            neighbors.append(new_neighbor)
            neighbor_types.append(('individual', 'nada'))

    #  calculate the demands and clearing errors for all the neighbors
    for (i, price_vector) in enumerate(neighbors):
        print(i)
        new_demand, new_individual_demand = calculate_total_demand(price_vector, student_list, timetable, credit_units = credit_units, return_individual_demands = True, model_type = model_type,
                                            model_param_dictionary = model_param_dictionary, courses_per_student= courses_per_student)
        neighbor_demands.append(new_demand)
        neighbor_individual_demands.append(new_individual_demand)
        neighbor_clearing_errors.append(calculate_alpha(capacities, new_demand, price_vector))

    # return the above arrays sorted in ascending clearing error (np.arrays so that fancy indexing works)
    neighbor_clearing_errors = np.array(neighbor_clearing_errors)
    sorted_indexes = np.argsort(neighbor_clearing_errors)

    sorted_neighbors = np.array(neighbors)[sorted_indexes]
    sorted_neighbor_demands = np.array(neighbor_demands)[sorted_indexes]
    sorted_neighbor_individual_demands = np.array(neighbor_individual_demands)[sorted_indexes]
    sorted_neighbor_clearing_errors = neighbor_clearing_errors[sorted_indexes]
    sorted_neighbor_types = np.array(neighbor_types)[sorted_indexes]
    print(f'Top 10 neighbor types: {sorted_neighbor_types[:10]}')

    return sorted_neighbors, sorted_neighbor_demands, sorted_neighbor_individual_demands, sorted_neighbor_clearing_errors, sorted_neighbor_types


def heuristic_search(student_profiles, course_timetable, credit_units, capacities, max_budget = 1.1, max_steps_without_improvement = 5,  clearing_error_limit = 10, time_limit_search = 60, time_limit_restart = 20,
                    number_percentage_neighbors = 3, number_gradient_neighbors = 10, number_individual_neighbors = 30, max_restarts = 1, model_type = 'True',
                    model_param_dictionary = None, max_courses_per_student = 5, max_gradient_multiplier = (0.1 / (2**6))):
    """
    Algorithm 1 of Budish et al.: a tabu search over the price space, returning the price vector p* with the lowest clearing error found.

    Parameters:
    -----------------
    student_profiles: list
        The student list in the format expected by calculate_total_demand for `model_type`; the last element of every profile is its budget.
    course_timetable: list of lists of ints
        timetable[i][j]: The indexes of the courses taught in the j-th hour of the i-th day.
    credit_units:
        Forwarded to calculate_total_demand / create_neighbors.
    capacities: numpy array of shape (number_of_courses, )
        capacities[i]: The capacity of the i-th course.
    max_budget: float
        Upper bound of the uniform distribution from which the random starting prices of every restart are drawn.
    max_steps_without_improvement: int
        A restart ends after this many consecutive accepted steps that do not improve its search error.
    clearing_error_limit: float
        No new restart is started once the best clearing error is at or below this value.
    time_limit_search: float
        Time limit (in minutes) for the whole search, checked before each restart.
    time_limit_restart: float
        Time limit (in minutes) for a single restart, checked before each step.
    number_percentage_neighbors, number_gradient_neighbors, number_individual_neighbors, max_gradient_multiplier:
        Forwarded to create_neighbors.
    max_restarts: int
        The maximum number of random restarts.
    model_type, model_param_dictionary:
        Forwarded to calculate_total_demand / create_neighbors.
    max_courses_per_student: int
        Forwarded to calculate_total_demand when evaluating the starting prices of each restart.

    Returns:
    -----------------
    prices_best: numpy array or None
        The best price vector found. None if no neighbor step was ever accepted
        (e.g. max_restarts < 1, time_limit_search <= 0, or no neighbors were generated).
    best_error: float
        The clearing error of prices_best (math.inf if prices_best is None).
    history: dict
        Per-restart bookkeeping: 'neighbor_history', 'search_error_history', 'time_history', 'price_history', 'termination_history'.
    """
    best_error = math.inf
    prices_best = None   # stays None if no step is ever accepted -- previously this raised UnboundLocalError on return
    start_time = timer()
    current_time = timer()
    search_error_history = []   # a list (of lists for each random restart) of the best error found after each step for bookkeeping
    time_history = []
    neighbor_type_picked_history = []  # a list  (of lists for each random restart) of the type of neighbors picked at each step, for bookkeeping
    price_history = []
    restarts = 1
    termination_history = []

    student_budgets = [student_profile[-1] for student_profile in student_profiles]

    while((current_time - start_time < time_limit_search * 60) and (best_error > clearing_error_limit) and (restarts <= max_restarts)):  # the main repeat loop (line 2) of the algorithm
        print("Entering main while loop")
        restarts += 1
        prices = np.random.uniform(low = 0, high = max_budget,  size = capacities.shape[0])    # Start from a random, reasonable price vector
        demands, individual_demands = calculate_total_demand(prices, student_profiles, course_timetable, credit_units, return_individual_demands = True, model_type = model_type,
                                         model_param_dictionary= model_param_dictionary, courses_per_student= max_courses_per_student)
        search_error = calculate_alpha(capacities, demands, prices)     # search error tracks the best error found in this search start
        tabu_dict = {}
        steps_without_improvement = 0
        restart_start_time = timer()
        search_error_history_restart = []
        time_history_restart = []
        neighbor_type_picked_restart = []
        price_history_restart = []
        termination_condition = 'steps_no_improvement'   # bound up front so it exists even if the inner loop body never runs

        while((steps_without_improvement < max_steps_without_improvement) and (current_time - restart_start_time < time_limit_restart * 60)):
            termination_condition = 'steps_no_improvement'
            print("Entering while steps without improvement loop")
            # NOTE(review): max_courses_per_student is not forwarded to create_neighbors, so neighbor demands may be
            # computed with a different per-student course cap than the starting demands above -- confirm.
            neighbors, neighbor_demands, neighbor_individual_demands, neighbor_clearing_errors, neighbor_types = create_neighbors(prices, capacities, demands, individual_demands, student_list = student_profiles,
                                                timetable = course_timetable, credit_units=credit_units, student_budgets = student_budgets,
                                                number_percentage_neighbors= number_percentage_neighbors, number_gradient_neighbors= number_gradient_neighbors,
                                                number_individual_neighbors= number_individual_neighbors, model_type = model_type, max_gradient_multiplier = max_gradient_multiplier,
                                                model_param_dictionary = model_param_dictionary)
            found_next_step = False
            # neighbors are sorted by ascending clearing error: take the first one whose demand vector is not tabu
            for i in range(neighbors.shape[0]):
                prices_tilde = neighbors[i]
                demands = neighbor_demands[i]
                individual_demands = neighbor_individual_demands[i]
                if(demands.tobytes() not in tabu_dict):  # this price vector does not induce demands found in our tabu list -> we continue form here
                    found_next_step = True
                    break               # found next step -> break neighbor for loop
                print("Actually found neighbor that is in the tabu list!!!")

            if (not found_next_step):
                steps_without_improvement = max_steps_without_improvement  # all neighbors are in the tabu list, force a restart
                print("This random restart was stopped because all neighbors were tabu!")
                termination_condition = 'tabu_neighbors'

            else:                                     # prices_tilde has the next step of the search
                prices = prices_tilde
                neighbor_type_picked_restart.append(neighbor_types[i])  # keep a list of the type of neighbor picked at line 20 of Algorithm 1
                tabu_dict[demands.tobytes()] = True
                current_error = neighbor_clearing_errors[i]
                if (current_error < search_error):
                    search_error = current_error
                    print("Found new SEARCH best error with value:", search_error)
                    steps_without_improvement = 0                      # we improved our search solution, so reset the step counter
                else:
                    print("Did not improve our search error. Current error:", current_error, " Search error:", search_error)
                    steps_without_improvement = steps_without_improvement + 1

                search_error_history_restart.append(search_error)  # we append the best error of this restart, no matter what it is

                if (current_error < best_error):
                    print(f'We also improved our best error from {best_error} to {current_error}')
                    best_error = current_error
                    prices_best = prices

            current_time = timer()
            time_history_restart.append(current_time - restart_start_time)
            price_history_restart.append(prices)

        if(current_time - restart_start_time >= time_limit_restart * 60):
            termination_condition = 'restart_timeout'
        current_time = timer()
        neighbor_type_picked_history.append(neighbor_type_picked_restart)
        time_history.append(time_history_restart)
        search_error_history.append(search_error_history_restart)
        price_history.append(price_history_restart)

        termination_history.append(termination_condition)

    print('TERMINATION HISTORY RIGHT BEFORE TABU RETURNS: ', termination_history)
    return prices_best, best_error, {'neighbor_history':  neighbor_type_picked_history, 'search_error_history': search_error_history, 'time_history': time_history, 'price_history': price_history, 'termination_history': termination_history}

# Principled Tabu Parts


def keep_pairwise_adjustments(complements, substitutes):
    """
    Reduce a student's complements and substitutes to their pairwise adjustments only.

    Every (course_indexes, adjustments) pair is mapped to (course_indexes, clipped), where `clipped` is a numpy array
    of the same length as `adjustments` whose first entry is 0 and every other entry equals adjustments[1].

    Returns:
    -----------------
    (complements_clipped, substitutes_clipped): tuple of two lists, in that order.
    """
    def _clip_group_list(groups):
        # one output pair per input pair, in the same order
        result = []
        for course_indexes, adjustments in groups:
            pairwise_only = np.full(len(adjustments), adjustments[1])
            pairwise_only[0] = 0
            result.append((course_indexes, pairwise_only))
        return result

    # complements are processed before substitutes, as in the tuple order returned
    return (_clip_group_list(complements), _clip_group_list(substitutes))


def train_unn_lean(model, optimizer, loader, ymax, epochs = 200, loss_function = F.l1_loss, print_frequency = 50):
    """
    A simple function to train a UNN network with a cosine-annealed learning rate.

    Parameters:
    -----------------
    model: torch.nn.Module
        The network to train; it is passed to `train` on the CPU and returned.
    optimizer: torch.optim.Optimizer
        The optimizer over the model's parameters.
    loader: torch.utils.data.DataLoader
        The training data.
    ymax:
        Unused; kept for backward compatibility with existing callers.
    epochs: int
        The model is trained for epochs + 1 passes (epoch indexes 0..epochs inclusive).
    loss_function: callable
        Passed to `train` as `loss_func`.
    print_frequency: int
        The training metrics (mae, kendall_tau, r2) are printed every `print_frequency` epochs.

    Returns:
    -----------------
    model: the same model object that was passed in, after training.
    """
    metrics = {}

    # CosineAnnealingLR uses T_max as a modulus when stepping, so T_max = 0 (epochs = 0) would raise
    # ZeroDivisionError on the first scheduler.step(); for epochs >= 1 this is identical to T_max = epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(max(epochs, 1)))

    for epoch in range(epochs+1):
        metrics[epoch] = train(model,
                               device=torch.device('cpu'),
                               train_loader=loader,
                               optimizer=optimizer,
                               loss_func= loss_function)

        scheduler.step()

        if(epoch % print_frequency == 0):
            print(f'Epoch {epoch :<4}: mae={round(metrics[epoch]["mae"],6):<10}   kendall_tau={round(metrics[epoch]["kendall_tau"],6):<10}   r2={round(metrics[epoch]["r2"],6):<10}')

    return model


def noisify_student(individual_student, forget_base = 0.5, forget_base_uniform= 0, forget_adjustments = 0.2, base_noise_std = 10, adjustment_noise_std = 0.2, seed = 42, return_forgotten_bases = False, multiplicative_base_noise = False):
    """
    Return a "noisy" version of a single student's true preferences by perturbing the preference parameters directly
    (much more efficiently than noisifying bundles one by one).

    All randomness comes from a single np.random.default_rng(seed) generator, so the result is determined by the
    inputs and the seed. The order of the random draws below matters for reproducibility.

    Parameters:
    -----------------
    individual_student: tuple
        The true preferences of a single student, as returned by create_multiple_students.
        [0] is a numpy array of base values; [1] the substitutes and [2] the complements, each a list of
        (course index list, adjustment values) pairs. Elements [3]-[6] are passed through unchanged
        (elsewhere in this file they are unpacked as overload_penalty, timegap_penalty, free_days_marginal_values, budget).
    forget_base: float
        The fraction of her lowest positive (noisy) base values that the student forgets (sets to 0).
        The resulting count is rounded up or down at random so that its expectation equals the exact fraction.
    forget_base_uniform: float
        An additional, independent per-course probability that she forgets a base value, regardless of the course's value.
    forget_adjustments: float
        The probability with which each course is independently dropped from every substitute/complement group.
    base_noise_std: float
        If multiplicative_base_noise is False: the std of the Gaussian noise added to the non-zero base values.
        If True: the half-width of the uniform multiplicative factor in [1 - base_noise_std, 1 + base_noise_std].
    adjustment_noise_std: float
        The half-width of the uniform multiplicative factor in [1 - adjustment_noise_std, 1 + adjustment_noise_std]
        applied to the kept adjustment values.
    seed: int
        The random seed to be used.
    return_forgotten_bases: bool
        Selects the return format (see Returns).
    multiplicative_base_noise: bool
        Whether the base-value noise is multiplicative (uniform) instead of additive (Gaussian).

    Returns:
    -----------------
    noisy_student: tuple
        If return_forgotten_bases is False: a 7-tuple of the same type as individual_student,
        (noisy base values, noisy substitutes, noisy complements, individual_student[3], ..., individual_student[6]).
        If return_forgotten_bases is True: the 4-tuple (noisy base values, noisy substitutes, noisy complements, unforgotten_bases),
        where unforgotten_bases[i] is 0 if the base value of course i was forgotten by either forgetting mechanism and 1 otherwise.
    """

    print(f'Noisify student got called with forget_base_uniform: {forget_base_uniform}')
    base_values_copy = individual_student[0].copy()


    # add noise to the base values of the courses
    rng = np.random.default_rng(seed)
    if not multiplicative_base_noise:
        base_noise = rng.normal(scale = base_noise_std, size = base_values_copy.shape)
        # set_trace()
        # only add noise to the courses that the student has a non-zero value for
        base_noise[base_values_copy == 0] = 0
        base_values_copy = base_values_copy + base_noise

    else:
        print('noisify student called with multiplicative_base_noise!!!')
        # multiplicative noise leaves zero-valued courses at zero automatically
        base_noise = rng.uniform(low = 1 - base_noise_std, high = 1 + base_noise_std, size = base_values_copy.shape)
        base_values_copy = base_values_copy * base_noise
        # print(f'base values noise: {base_noise}')

    base_values_copy = np.maximum(base_values_copy, 0)  # make sure the base values do not go below 0!
#     set_trace()

    # forget some of the courses: the lowest-valued ones first (argsort is ascending)
    sorted_arguments = np.argsort(base_values_copy)

    # have a sorted list of the indexes, but only for the non-zero base values -> there is no point in forgetting a 0-valued course.
    sorted_arguments_positive = sorted_arguments[base_values_copy[sorted_arguments] != 0]

    # stochastic rounding: round up with probability equal to the fractional part, so the expected count is exact
    expected_number_of_courses_to_forget = sorted_arguments_positive.shape[0] * forget_base
    random_number = rng.uniform(low = 0, high = 1)
    if random_number <= expected_number_of_courses_to_forget - math.floor(expected_number_of_courses_to_forget):
        courses_to_forget = math.ceil(expected_number_of_courses_to_forget)
    else:
        courses_to_forget = math.floor(expected_number_of_courses_to_forget)

    # 1 = base value still reported, 0 = base value forgotten
    unforgotten_bases = np.array([1 for i in range(len(base_values_copy))])

    base_values_copy[sorted_arguments_positive[:courses_to_forget]] = 0
    unforgotten_bases[sorted_arguments_positive[:courses_to_forget]] = 0  # also mark the courses you forgot to report a base value for.

    # additionally forget every course independently with probability forget_base_uniform
    # (one rng draw per course, even for courses that are already 0)
    for i in range(len(base_values_copy)):
        if (rng.random() < forget_base_uniform):
            base_values_copy[i] = 0
            unforgotten_bases[i] = 0
            print(f'Forgetting course {i} because of the new condition')

#     set_trace()

    # forget some of the adjustments
    noisy_substitutes = []
    noisy_complements = []

    for (substitute_list, substitute_values) in individual_student[1]:
        # courses whose (noisy) base value is 0 cannot take part in an adjustment
        substitute_list_clipped = [x for x in substitute_list if base_values_copy[x] != 0]
        unforgettable_substitutes = [x for x in substitute_list_clipped if rng.random() >= forget_adjustments]
        if(len(unforgettable_substitutes) > 1):  # a group with at most one course left is dropped entirely
            # keep only as many adjustment values as there are remaining courses
            unforgettable_values = np.array(substitute_values[: len(unforgettable_substitutes)])
            # adjustment_noise = rng.normal(loc = 1.0,  scale = adjustment_noise_std, size = unforgettable_values.shape)
            adjustment_noise = rng.uniform(low = 1 - adjustment_noise_std, high = 1 + adjustment_noise_std, size = unforgettable_values.shape)
            unforgettable_values = unforgettable_values * adjustment_noise
            unforgettable_values = np.minimum(unforgettable_values, 0)  # make sure substitutes don't become complements!
            unforgettable_values[::-1].sort()  # sorting the reversed view in place leaves the array in descending order
            noisy_substitutes.append((unforgettable_substitutes, unforgettable_values))

    for (complement_list, complement_values) in individual_student[2]:
        # courses whose (noisy) base value is 0 cannot take part in an adjustment
        complement_list_clipped = [x for x in complement_list if base_values_copy[x] != 0]
        unforgettable_complements = [x for x in complement_list_clipped if rng.random() >= forget_adjustments]
        if(len(unforgettable_complements) > 1):  # a group with at most one course left is dropped entirely
            # keep only as many adjustment values as there are remaining courses
            unforgettable_values = np.array(complement_values[: len(unforgettable_complements)])
            # adjustment_noise = rng.normal(loc = 1.0,  scale = adjustment_noise_std, size = unforgettable_values.shape)
            adjustment_noise = rng.uniform(low = 1 - adjustment_noise_std, high = 1 + adjustment_noise_std, size = unforgettable_values.shape)
            unforgettable_values = unforgettable_values * adjustment_noise
            unforgettable_values = np.maximum(unforgettable_values, 0)   # make sure complements don't become substitutes!
            unforgettable_values.sort()                                 # ascending order, in place
            noisy_complements.append((unforgettable_complements, unforgettable_values))

    if not return_forgotten_bases:
        return (base_values_copy, noisy_substitutes, noisy_complements, individual_student[3], individual_student[4], individual_student[5], individual_student[6])

    else:

        return (base_values_copy, noisy_substitutes, noisy_complements, unforgotten_bases)  # if return_forgotten_bases: only returns base values/substitutes/complements and what the student reported no value for


def noisify_all_students(student_list, forget_base = 0.15, forget_base_uniform= 0, forget_adjustments = 0.0, base_noise_std = 3, adjustment_noise_std = 0.05, seed = 42, multiplicative_base_noise = False):
    """
    Apply noisify_student to every student of `student_list` and return the noisy students in the same order.

    The k-th student is noisified with seed `seed + k`, so every student gets its own reproducible random stream.

    Parameters:
    -----------------
    student_list: list
        A list containing the true preferences of all students, as returned by create_multiple_students.
    forget_base: float
        The fraction of her lowest base values that a student will forget.
    forget_base_uniform: float
        An additional per-course probability of forgetting a base value, regardless of its value.
    forget_adjustments: float
        The probability with which each course is dropped from each of her adjustments.
    base_noise_std: float
        The scale of the noise applied to the students' base values.
    adjustment_noise_std: float
        The scale of the noise applied to the students' adjustments.
    seed: int
        The base random seed.
    multiplicative_base_noise: bool
        Whether the base-value noise is multiplicative instead of additive.

    Returns:
    -----------------
    noisy_list: list
        A list of the same type as the one returned by create_multiple_students, with noise added to every student.
    """
    print(f'Noisy all students got called with forget base uniform: {forget_base_uniform}')
    return [
        noisify_student(student, forget_base= forget_base, forget_base_uniform= forget_base_uniform,
                        forget_adjustments= forget_adjustments, base_noise_std= base_noise_std,
                        adjustment_noise_std= adjustment_noise_std, seed= seed + offset,
                        multiplicative_base_noise= multiplicative_base_noise)
        for offset, student in enumerate(student_list)
    ]


def guisify_all_students(student_list, forget_base = 0.15, forget_base_uniform= 0, forget_adjustments = 0.0, base_noise_std = 3, adjustment_noise_std = 0.05, seed = 42, multiplicative_base_noise = False,
                         cognomos_interface = False):
    """
    Simulate what every student would report to the current Course Match GUI: a noisy version of her true preferences
    that also records for which courses she forgot to report a base value.

    The k-th student is noisified with seed `seed + k` via noisify_student(..., return_forgotten_bases=True).

    Parameters:
    -----------------
    student_list: list
        A list containing the true preferences of all students, as returned by create_multiple_students.
    forget_base: float
        The fraction of her lowest base values that a student will forget.
    forget_base_uniform: float
        An additional per-course probability of forgetting a base value, regardless of its value.
    forget_adjustments: float
        The probability with which each course is dropped from each of her adjustments.
    base_noise_std: float
        The scale of the noise applied to the students' base values.
    adjustment_noise_std: float
        The scale of the noise applied to the students' adjustments.
    seed: int
        The base random seed.
    multiplicative_base_noise: bool
        Whether the base-value noise is multiplicative instead of additive.
    cognomos_interface: bool
        If True, the noisy base values are projected to the Cognomos language (project_utilities_cognomos_language)
        and then mapped back with transform_utilities (exploration_factor=0.0) before being stored.

    Returns:
    -----------------
    gui_list: list
        One 4-tuple per student: (base values, substitutes, complements, unforgotten_bases), where unforgotten_bases[i]
        is 0 if the student forgot to report a base value for course i and 1 otherwise.
    """
    def _gui_report(offset, student):
        report = noisify_student(student, forget_base= forget_base, forget_base_uniform= forget_base_uniform,
                                 forget_adjustments= forget_adjustments, base_noise_std= base_noise_std,
                                 adjustment_noise_std= adjustment_noise_std, seed= seed + offset,
                                 return_forgotten_bases= True, multiplicative_base_noise= multiplicative_base_noise)
        if not cognomos_interface:
            return report

        # Force the student to report her base values in the Cognomos language, then map those reports back with our
        # projection. The exploration factor is 0 so this is the "true" back-projection: forgotten courses keep value 0.
        reported_bases, reported_substitutes, reported_complements, unforgotten_bases = report
        projected_bases = project_utilities_cognomos_language(reported_bases)
        unprojected_bases, _min_value, _max_value = transform_utilities(projected_bases, log_offset=1e-5, stretch_factor = 3, exploration_factor=0.0)
        return (unprojected_bases, reported_substitutes, reported_complements, unforgotten_bases)

    return [_gui_report(offset, student) for offset, student in enumerate(student_list)]


def _noisy_student_list_from_params(student_list, model_param_dictionary, seed):
    """
    Noisify `student_list` with the noise settings stored in `model_param_dictionary` (shared by all '*Noisy' model types).

    Required keys: 'noisy_forget_base', 'noisy_forget_adjustments', 'noisy_base_std', 'noisy_adj_std'.
    Optional keys: 'noisy_forget_base_uniform' (default 0), 'multiplicative_base_noise' (default False).
    """
    return noisify_all_students(student_list, forget_base = model_param_dictionary['noisy_forget_base'],
                                forget_base_uniform= model_param_dictionary.get('noisy_forget_base_uniform', 0),
                                forget_adjustments = model_param_dictionary['noisy_forget_adjustments'],
                                base_noise_std = model_param_dictionary['noisy_base_std'],
                                adjustment_noise_std= model_param_dictionary['noisy_adj_std'], seed = seed,
                                multiplicative_base_noise= model_param_dictionary.get('multiplicative_base_noise', False))


def _student_training_data(additive_prefs, substitutes, complements, overload_penalty, free_days_marginal_values, timetable, model_param_dictionary, seed):
    """
    Generate the random training questions (X_train, y_train) of a single student with create_initial_dataset, using unit credit units.

    Reads the keys 'samples', 'samples_in_range' and 'range_min_value' from `model_param_dictionary`.
    """
    X_train, y_train, _, _ = create_initial_dataset(additive_prefs, substitutes, complements, timetable, credit_units = [1 for _ in range(len(additive_prefs))],
        overload_penalty= overload_penalty, free_days_marginal_values= free_days_marginal_values, n_samples_train= model_param_dictionary['samples'],
        n_samples_test= 0, n_courses= len(additive_prefs), make_monotone= True, n_samples_train_range= model_param_dictionary['samples_in_range'],
        value_range= (model_param_dictionary['range_min_value'], math.inf), seed = seed)
    return X_train, y_train


def _linear_student_list(students, model_param_dictionary):
    """
    Reduce every 7-tuple student to (base values, budget).

    If model_param_dictionary['cognomos_projection'] is truthy, the base values are projected to the Cognomos language first.
    """
    linear_student_list = []
    for (additive_prefs, _substitutes, _complements, _overload_penalty, _timegap_penalty, _free_days_marginal_values, budget) in students:
        if model_param_dictionary.get('cognomos_projection', False):
            linear_student_list.append((project_utilities_cognomos_language(additive_prefs), budget))
        else:
            linear_student_list.append((additive_prefs, budget))
    return linear_student_list


def _pairwise_student_list(students):
    """Reduce every 7-tuple student to (base values, pairwise-only substitutes, pairwise-only complements, budget)."""
    PA_student_list = []
    for (additive_prefs, substitutes, complements, _overload_penalty, _timegap_penalty, _free_days_marginal_values, budget) in students:
        complements_clipped, substitutes_clipped = keep_pairwise_adjustments(complements, substitutes)
        PA_student_list.append((additive_prefs, substitutes_clipped, complements_clipped, budget))
    return PA_student_list


def create_model_student_list(student_list, timetable, model_type, seed, model_param_dictionary = None):
    """
    The heart of the new Course Match. A function taking as input the true student_list and creating the corresponding one for any model/algorithm.

    Every '*Noisy' model type first noisifies the students (with the 'noisy_*' entries of model_param_dictionary and `seed`)
    and then builds the model on the noisy preferences.

    Paramters:
    -----------------
    student_list: list
        The list of the true students' preferences, as returned by create_multiple_students
        (7-tuples: base values, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget).
    timetable: list of lists of ints
        timetable[i][j]: The indexes of the courses taught in the j-th hour of the i-th day.
    model_type: string
        The model for which to create a student list.
    seed: int
        The random seed to be used. Student k's training data is generated with seed + k.
    model_param_dictionary: dictionary or None
        Depending on the model, contains all the remaining model parameters required.

    Returns:
    -----------------
    list, one entry per student, whose format depends on model_type:
        'True': student_list itself; 'TrueNoisy': the noisy 7-tuples;
        'TrueLinear', 'LinearNoisy', linear regressions: (coefficients, budget);
        'PairwiseAdjustments*': (base values, substitutes, complements, budget);
        'xgboost*': (model, solver, budget); 'NuSVR*': (model, solver, gamma, budget); 'UNN*': (trained model, solver, budget).
        Unknown model types return [].
    """
    if model_param_dictionary is None:
        model_param_dictionary = {}
    if (model_type == 'True'):
        return student_list

    elif (model_type == 'TrueNoisy'):
        return _noisy_student_list_from_params(student_list, model_param_dictionary, seed)

    elif (model_type == 'TrueLinear'):
        return _linear_student_list(student_list, model_param_dictionary)

    elif (model_type == 'LinearNoisy'):
        return _linear_student_list(_noisy_student_list_from_params(student_list, model_param_dictionary, seed), model_param_dictionary)

    elif(model_type == 'PairwiseAdjustments'):
        return _pairwise_student_list(student_list)

    elif(model_type == 'PairwiseAdjustmentsNoisy'):
        return _pairwise_student_list(_noisy_student_list_from_params(student_list, model_param_dictionary, seed))

    elif(model_type in ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'LinearRegressionNoisy', 'RidgeNoisy', 'LassoNoisy', 'ElasticNetNoisy']):
        if (model_type in ['LinearRegressionNoisy', 'RidgeNoisy', 'LassoNoisy', 'ElasticNetNoisy']):
            actual_list = _noisy_student_list_from_params(student_list, model_param_dictionary, seed)
        else:
            actual_list = student_list

        linear_student_list = []
        for (number, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_list):
            # Step 1: Generate random questions for all the students
            X_train, y_train = _student_training_data(additive_prefs, substitutes, complements, overload_penalty, free_days_marginal_values,
                                                      timetable, model_param_dictionary, seed + number)

            if (model_param_dictionary['scale_ys']):
                y_train = y_train / max(y_train)
            # Step 2: train the appropriate linear model on them
            if (model_type == 'LinearRegression' or model_type == 'LinearRegressionNoisy'):
                reg = linear_model.LinearRegression().fit(X_train, y_train)
            elif(model_type == 'Ridge' or model_type == 'RidgeNoisy'):
                reg = linear_model.Ridge(alpha = model_param_dictionary['alpha']).fit(X_train, y_train)
            elif(model_type == 'Lasso' or model_type == 'LassoNoisy'):
                reg = linear_model.Lasso(alpha = model_param_dictionary['alpha']).fit(X_train, y_train)
            elif(model_type == 'ElasticNet' or model_type == 'ElasticNetNoisy'):
                reg = linear_model.ElasticNet(alpha = model_param_dictionary['alpha']).fit(X_train, y_train)
            linear_student_list.append((reg.coef_, budget))
        return linear_student_list

    elif(model_type == 'xgboost' or model_type == 'xgboostNoisy'):
        # The Noisy variant must train on noisy reports (it previously used the true preferences).
        if (model_type == 'xgboostNoisy'):
            actual_list = _noisy_student_list_from_params(student_list, model_param_dictionary, seed)
        else:
            actual_list = student_list
        xgboost_student_list = []
        seconds_create_solver_total = 0
        seconds_generate_MIP_total = 0
        for (number, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_list):
            X_train, y_train = _student_training_data(additive_prefs, substitutes, complements, overload_penalty, free_days_marginal_values,
                                                      timetable, model_param_dictionary, seed + number)

            model = xgboost.XGBRegressor(colsample_bytree = model_param_dictionary['colsample_bytree'], eta = model_param_dictionary['eta'], max_depth= model_param_dictionary['max_depth'],
                        n_estimators= model_param_dictionary['n_estimators'], subsample= model_param_dictionary['subsample'], seed = seed)

            if (model_param_dictionary['scale_ys']):   # scale y's same as behnoosh, if we have to.
                y_train = y_train / max(y_train)

            model.fit(X_train, y_train)
            start = timer()
            solver = gurobi_MIP_xgboost(model, additive_prefs.shape[0])
            mid = timer()
            solver.generate_mip(credit_units=np.repeat(1, additive_prefs.shape[0]), cu_max=5, course_timetable= timetable)
            end = timer()
            seconds_create_solver_total += (mid - start)
            seconds_generate_MIP_total += (end - mid)

            xgboost_student_list.append((model, solver, budget))

        # average over the solvers actually built (previously hard-coded to 100); max(..., 1) avoids division by zero
        number_of_solvers = len(xgboost_student_list)
        print(f'AVG time to create a solver for xgboost after {number_of_solvers} solvers: {seconds_create_solver_total / max(number_of_solvers, 1)} ')
        print(f'AVG time to generate a MIP for xgboost after {number_of_solvers} solvers: {seconds_generate_MIP_total / max(number_of_solvers, 1)} ')

        return xgboost_student_list

    elif(model_type == 'NuSVR' or model_type == 'NuSVRNoisy'):
        # The Noisy variant must train on noisy reports (it previously used the true preferences).
        if (model_type == 'NuSVRNoisy'):
            actual_list = _noisy_student_list_from_params(student_list, model_param_dictionary, seed)
        else:
            actual_list = student_list
        svr_student_list = []
        for (number, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_list):
            X_train, y_train = _student_training_data(additive_prefs, substitutes, complements, overload_penalty, free_days_marginal_values,
                                                      timetable, model_param_dictionary, seed + number)

            model = NuSVR(kernel= model_param_dictionary['kernel'], degree = model_param_dictionary['degree'], nu = model_param_dictionary['nu'],
                gamma = model_param_dictionary['gamma'], C = model_param_dictionary['C'])

            if (model_param_dictionary['scale_ys']):
                y_train = y_train / max(y_train)
            model.fit(X_train, y_train)

            # the MIP needs the numeric gamma; reproduce sklearn's definitions of 'scale' and 'auto'
            if model.gamma == 'scale':
                gamma = 1 / (additive_prefs.shape[0] * X_train.var())
            elif model.gamma == 'auto':
                gamma = 1 / additive_prefs.shape[0]  # number of courses == additive_prefs.shape[0]
            else:
                gamma = model_param_dictionary['gamma']
            solver = gurobi_MIP_SVR(model, gamma)

            solver.generate_mip(course_prices = np.repeat(0, additive_prefs.shape[0]), credit_units = np.repeat(1, additive_prefs.shape[0]), budget = budget, course_timetable = timetable, cu_max = 5, verbose = False)

            svr_student_list.append((model, solver, gamma, budget))
        return svr_student_list

    elif(model_type == 'UNN' or model_type == 'UNN_Noisy'):
        if (model_type == 'UNN_Noisy'):
            actual_list = _noisy_student_list_from_params(student_list, model_param_dictionary, seed)
        else:
            actual_list = student_list
        unn_student_list = []
        seconds_create_solver_total = 0
        seconds_generate_MIP_total = 0
        for (number, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_list):
            # Step 1: Generate random questions for all the students
            X_train, y_train = _student_training_data(additive_prefs, substitutes, complements, overload_penalty, free_days_marginal_values,
                                                      timetable, model_param_dictionary, seed + number)

            # Step 2: Scale the dataset (required for UNN)
            scaler = MinMaxScaler()
            scaler.fit(y_train.reshape(-1, 1))
            y_train = (scaler.transform(y_train.reshape(-1, 1))).reshape(-1)

            ymax = 1  # NOTE: DEFINETELY NOT SURE ABOUT THAT ONE!!!

            # Step 3: Put the data in a dataloader
            train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),
                                               torch.from_numpy(y_train.reshape(-1, 1)).float())
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)

            # Step 4: Initiate and train the model
            model = MVNN(input_dim= len(additive_prefs), num_hidden_layers= model_param_dictionary['UNN_layers'],
                num_units= model_param_dictionary['UNN_units'], layer_type= model_param_dictionary['UNN_layer_type'], target_max= ymax)
            optimizer = torch.optim.Adam(model.parameters(), lr= model_param_dictionary['lr'], weight_decay= model_param_dictionary['weight_decay'])

            trained_model = train_unn_lean(model, optimizer, train_loader, ymax, epochs = model_param_dictionary['epochs'])

            start = timer()
            if trained_model._num_hidden_layers >= 1:
                solver = GUROBI_MIP2_MVNN(trained_model)
            else:
                print('special case, LINEAR MVNN!!!')
                solver = GUROBI_MIP2_MVNN_LINEAR(trained_model)
            mid = timer()
            solver.generate_mip(course_timetable=timetable,
               credit_units=np.repeat(1, len(additive_prefs)),
               cu_max=5,
               timeLimit=100,
               MIPGap=0.0001,
               verbose=False)

            end = timer()
            unn_student_list.append((trained_model, solver, budget))
            seconds_create_solver_total += (mid - start)
            seconds_generate_MIP_total += (end - mid)

        # average over the solvers actually built (previously hard-coded to 100); max(..., 1) avoids division by zero
        number_of_solvers = len(unn_student_list)
        print(f'AVG time to create a solver for UNN after {number_of_solvers} solvers: {seconds_create_solver_total / max(number_of_solvers, 1)}')
        print(f'AVG time to generate a MIP for UNN after {number_of_solvers} solvers: {seconds_generate_MIP_total / max(number_of_solvers, 1)}')
        return unn_student_list

    # unknown model type: keep the historical behavior of returning an empty list
    return []


def capacities_generator(number_of_courses = 30, total_number_of_seats = 505, capacity_deviation = 0, seed = 42):
    """
    A random function generating the capacity of each course, given the total capacity.

    Every course draws a weight uniformly from the interval [1 - capacity_deviation, 1]; the weights are normalised so
    that they sum to one, scaled by ``total_number_of_seats`` and rounded to the nearest integer.

    Parameters:
    -----------------
    number_of_courses: int
        The number of courses available.
    total_number_of_seats: int or float
        The total number of seats of all courses (before rounding).
    capacity_deviation: float
        Width of the uniform interval the per-course weights are drawn from, i.e. weights lie in
        [1 - capacity_deviation, 1]. 0 gives (almost) equal capacities. Must lie in [0, 1]; larger values would allow
        zero or negative weights and therefore nonsensical capacities.
    seed: int
        The random seed used for numpy's ``default_rng``.

    Returns:
    -----------------
    capacities: numpy array of floats of shape (number_of_courses, )
        capacities[i]: The (integer-valued) capacity of the i-th course. Because of the rounding, the sum may differ
        slightly from ``total_number_of_seats``.

    Raises:
    -----------------
    ValueError
        If ``capacity_deviation`` is outside [0, 1].
    """
    if not 0 <= capacity_deviation <= 1:
        raise ValueError(f'capacity_deviation must lie in [0, 1], got {capacity_deviation}')

    rng = np.random.default_rng(seed = seed)
    values = rng.uniform(low = 1 - capacity_deviation, high = 1, size = number_of_courses)
    capacities_float = (values / values.sum()) * total_number_of_seats

    # np.rint keeps the float dtype; callers receive integer-valued floats.
    return np.rint(capacities_float)


def _bundle_values(bundles, timetable, prefs, subs, comps, overload, free_days, n_courses):
    """Value every bundle in ``bundles`` with the ``student`` valuation function for one preference profile."""
    return np.array([
        student(bundle, prefs, subs, comps, timetable,
                overload_penalty = overload, free_days_marginal_values= free_days,
                credit_units= [1] * n_courses, make_monotone= True)
        for bundle in bundles
    ])


def calculate_allocation_KT(true_student_profiles_all_runs, noisy_student_profiles_all_runs, noisy_allocations, timetables):
    """
    Compute, for every student of every Course Match run, the Kendall tau between the noisy (reported) and the true
    valuation of the bundles allocated in that run (simulating the experiment).

    Only bundles whose noisy value exceeds half of the noisy value of the student's own bundle enter the comparison.

    Parameters:
    -----------------
    true_student_profiles_all_runs: list of lists
        true_student_profiles_all_runs[k]: The true students' preference tuples of the k-th run of course match.
    noisy_student_profiles_all_runs: list of lists
        noisy_student_profiles_all_runs[k]: The noisy students' preference tuples of the k-th run of course match.
    noisy_allocations: numpy array of shape (number of runs, number of students, number of courses)
        noisy_allocations[k][i][j] = 1 if the i-th student got the j-th course in the k-th run of CM, 0 else.
    timetables: list of timetables
        timetables[k]: The course timetable of the k-th run of CM.

    Returns:
    -----------------
    KT_list: numpy array of shape (number of runs, number of students)
        KT_list[k][i]: The Kendall tau between the true and the noisy version of the i-th student for the k-th run.
    """
    n_runs = len(true_student_profiles_all_runs)
    kt_per_run = [[] for _ in range(n_runs)]

    for run, true_profiles in enumerate(true_student_profiles_all_runs):
        noisy_profiles = noisy_student_profiles_all_runs[run]
        timetable = timetables[run]
        allocation = noisy_allocations[run]

        for idx, true_profile in enumerate(true_profiles):
            (prefs, subs, comps, overload, _timegap, free_days, _budget) = true_profile
            (prefs_n, subs_n, comps_n, overload_n, _timegap_n, free_days_n, _budget_n) = noisy_profiles[idx]
            n_courses = prefs.shape[0]

            # value of every allocated bundle, first under the true and then under the reported preferences
            true_vals = _bundle_values(allocation, timetable, prefs, subs, comps, overload, free_days, n_courses)
            noisy_vals = _bundle_values(allocation, timetable, prefs_n, subs_n, comps_n, overload_n, free_days_n, n_courses)

            keep = noisy_vals > 0.5 * noisy_vals[idx]
            kept_noisy = noisy_vals[keep]
            kept_true = true_vals[keep]

            print(kept_noisy.shape)
            tau, _ = scipy.stats.kendalltau(kept_noisy, kept_true)
            kt_per_run[run].append(tau)

    return np.array(kt_per_run)

# stage 2 functions of Course Match


def calculate_oversubscription(prices, student_list, timetable, credit_units, max_capacities, model_param_dictionary, model_type = 'True'):
    """
    Return, per course, how many students demand it beyond its capacity at the given prices (0 if not oversubscribed).
    This corresponds to the function widehat d in the paper by Budish et Al.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    student_list: list
        The student representations, forwarded unchanged to calculate_total_demand (their layout presumably depends
        on model_type -- see calculate_total_demand).
    timetable: list of lists of ints
        timetable[i][j]: The ids of all courses being taught in the j-th timeslot of the i-th day
    credit_units: list of floats
        credit_units[i]: The credit units of the i-th course
    max_capacities: np.array of shape(number_of_courses, )
        max_capacities[i]: The maximum permissible capacity of the i-th course
    model_param_dictionary: dict
        Forwarded unchanged to calculate_total_demand.
    model_type: str
        Forwarded unchanged to calculate_total_demand.

    Returns:
    --------------------
    oversubscription: np.array of shape (number_of_courses, )
        oversubscription[i]: max(total demand of course i - max_capacities[i], 0)
    """
    demand = calculate_total_demand(prices, student_list, timetable, credit_units= credit_units, model_type= model_type, model_param_dictionary= model_param_dictionary)
    return np.maximum(demand - max_capacities, 0)


def get_epsilon(student_list):
    """
    Return the smallest difference between the budgets of any two students (epsilon).

    Parameters:
    --------------------
    student_list: sequence of tuples
        Each entry's LAST element is the student's budget.

    Returns:
    --------------------
    epsilon: number
        The minimum gap between consecutive sorted budgets (0 if two students share a budget).

    Raises:
    --------------------
    ValueError
        If fewer than two students are given, because no budget difference exists.
    """
    budgets = np.sort(np.array([entry[-1] for entry in student_list]))
    if budgets.size < 2:
        raise ValueError(f'get_epsilon needs at least two students to compare budgets, got {budgets.size}')

    # after sorting, the closest pair of budgets is always adjacent
    return np.min(np.diff(budgets))


def oversubscription_elimination(prices, student_list, timetable, credit_units, max_capacities, maximum_price, model_param_dictionary, model_type = 'True', repetition_threshold = 3):
    """
    An iterative function that reduces by half the excess demand of the most oversubscribed course in every
    iteration, as described in Budish et Al.

    In every iteration the currently most oversubscribed course is selected, and its price is binary-searched between
    its current price and ``maximum_price`` until its oversubscription is at most half (rounded down) of its value at
    the start of the iteration. The course is then given the upper end of the final search interval as its price.

    NOTE: ``prices`` is modified in place and the same object is returned. It is assumed to be a float array
    (midpoint prices are written into it) -- TODO confirm callers never pass an integer array.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    student_list: list
        The student representations, forwarded unchanged to calculate_oversubscription.
    timetable: list of lists of ints
        timetable[i][j]: The ids of all courses being taught in the j-th timeslot of the i-th day
    credit_units: list of floats
        credit_units[i]: The credit units of the i-th course
    max_capacities: np.array of shape(number_of_courses, )
        max_capacities[i]: The maximum permissible capacity of the i-th course
    maximum_price: float
        Upper end of the binary search for a course price.
    model_param_dictionary: dict
        Forwarded unchanged to calculate_oversubscription.
    model_type: str
        Forwarded unchanged to calculate_oversubscription.
    repetition_threshold: int
        If the same oversubscription vector is seen this many times at the start of an iteration, the binary-search
        tolerance ``epsilon`` is doubled.

    Returns:
    --------------------
    prices: np.array of shape(number_of_courses, )
        The updated prices. If a course is still oversubscribed at ``maximum_price`` the elimination stops early and
        some oversubscription may remain.
    """

    print(f'over_subscription elimation called with maximum price: {maximum_price}')

    epsilon = 0.000001   # tolerance of the per-course price binary search
    oversubscriptions = calculate_oversubscription(prices, student_list, timetable, credit_units, max_capacities, model_param_dictionary= model_param_dictionary, model_type= model_type)
    most_oversubscribed = np.argmax(oversubscriptions)
    occurances_dict = {}   # oversubscription vector (as tuple) -> how often it was seen at the start of an iteration

    while(oversubscriptions[most_oversubscribed] > 0):
        target_oversubscription = math.floor(oversubscriptions[most_oversubscribed] / 2)
        print(f'Entering main while loop for course: {most_oversubscribed} "with oversubscription: {oversubscriptions[most_oversubscribed]} and target_oversubscription: {target_oversubscription}. epsilon is: {epsilon}')
        print("Current Oversubscription vector: ", oversubscriptions)
        price_low = prices[most_oversubscribed]
        price_high = maximum_price

        # The course already sits at the maximum price and is still oversubscribed: the binary search below cannot
        # change its price, so every further iteration would recompute exactly the same state forever.
        if price_low == price_high:
            print(f'Course {most_oversubscribed} is still oversubscribed at the maximum price {maximum_price}; stopping oversubscription elimination.')
            break

        print(f'Entering inner while loop with initial price difference {price_high - price_low}')
        print(f'Initial price of that course: {price_low}')

        # if the same vector of oversubscriptions is encountered 3 times ->
        # most likely case is that the binary search is too fine grained and is actually jumping between equivalent solutions
        # -> double epsilon to compensate/speed things up!
        oversubscriptions_tuple = tuple(oversubscriptions)
        times_visited = occurances_dict.get(oversubscriptions_tuple, 0) + 1
        if (times_visited >= repetition_threshold):
            occurances_dict[oversubscriptions_tuple] = 0
            epsilon = epsilon * 2
        else:
            occurances_dict[oversubscriptions_tuple] = times_visited

        while(price_high - price_low > epsilon):
            prices[most_oversubscribed] = (price_low + price_high) / 2
            oversubscriptions = calculate_oversubscription(prices, student_list, timetable, credit_units, max_capacities, model_param_dictionary= model_param_dictionary, model_type = model_type)
            print(f'Current oversubcription of that course: {oversubscriptions[most_oversubscribed]} for price: {prices[most_oversubscribed]} and price diff: {price_high - price_low}')

            if (oversubscriptions[most_oversubscribed] > target_oversubscription):   # NOTE: CHANGED THIS FROM >=
                price_low = prices[most_oversubscribed]
            else:
                price_high = prices[most_oversubscribed]

        # price_high is the lowest price tried so far that reached the target (or maximum_price if none did)
        print(f'Price found was {prices[most_oversubscribed]} but instead setting its price to: {price_high}')
        prices[most_oversubscribed] = price_high

        oversubscriptions = calculate_oversubscription(prices, student_list, timetable, credit_units, max_capacities, model_type= model_type, model_param_dictionary = model_param_dictionary)
        print(f'For the final price decided, the oversubscription of that course is: {oversubscriptions[most_oversubscribed]}')
        most_oversubscribed = np.argmax(oversubscriptions)

    return prices


def _price_out_full_courses(prices, seats_available):
    """
    Return a copy of ``prices`` in which every course with fewer than one seat available is priced at 10.

    Used by the MIP-based aftermarket branches so that a student cannot move into a course with no capacity left:
    a price of 10 is intended to be unaffordable (assumes budget * budget_increase_percetange stays below 10).
    """
    prices_copy = prices.copy()
    for k, seats in enumerate(seats_available):
        if seats < 1:
            prices_copy[k] = 10
    return prices_copy


def _demand_changed(new_demand, old_demand):
    """Return True if ``new_demand`` (list or array) differs from ``old_demand`` in at least one course."""
    new_demand = np.asarray(new_demand)
    return (new_demand == old_demand).sum() != new_demand.shape[0]


def aftermarket_allocations(prices, student_list_sorted, timetable,  individual_demands, capacities, credit_units, budget_increase_percetange = 1.1, check_sanity = False, max_courses = 5, model_type = 'True'):
    """
    Implementation of Algorithm 3 of Budish et Al.

    Students are visited in the order of ``student_list_sorted``. Each one re-solves their demand problem with their
    budget multiplied by ``budget_increase_percetange``. They may only use seats that are still free, plus the seats
    they already hold. As soon as one student's schedule changes, the pass ends and a new pass starts from the first
    student with the updated free seats. The process stops after a pass without changes, or after
    2 * number_of_students passes.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    student_list_sorted: list
        Model representation of every student, sorted according to stage 3 of Course Match. The layout of each entry
        depends on ``model_type``:
          - 'True', 'TrueNoisy': (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty,
            free_days_marginal_values, budget)
          - linear models ('LinearRegression', 'TrueLinear', 'Ridge', 'Lasso', 'ElasticNet' and their noisy
            variants): (linear_coefs, budget)
          - 'PairwiseAdjustments', 'PairwiseAdjustmentsNoisy': (additive_prefs, substitutes_clipped, complements_clipped, budget)
          - 'NuSVR', 'NuSVRNoisy': (model, solver, gamma, budget)
          - 'UNN', 'UNN_Noisy', 'UNN_transfer_learning': (model, solver, budget), (model, solver, scale, budget)
            or (model, solver, scale, _, budget)
          - 'UNN_projected': (fitted linear model with a ``coef_`` attribute, budget)
          - 'xgboost', 'xgboostNoisy': (model, solver, budget)
    timetable: list of lists of ints
        timetable[i][j]: The ids of all courses being taught in the j-th timeslot of the i-th day
    individual_demands: np.array of shape (number_of_students, number_of_courses)
        The allocation before the aftermarket. It is not modified; a copy is updated and returned.
    capacities: np.array of shape(number_of_courses, )
        capacities[i]: The capacity of the i-th course
    credit_units: list of floats
        credit_units[i]: The credit units of the i-th course
    budget_increase_percetange: float
        The ratio of new budget/old_budget of every student, e.g. 1.1 was used at Wharton.
    check_sanity: bool
        Currently unused.
    max_courses: int
        Forwarded to solve_student for the solve_student-based model types.
    model_type: str
        Selects how student demands are computed (see ``student_list_sorted``).

    Returns:
    --------------------
    individual_demands_copy: np.array of shape (number_of_students, number_of_courses)
        individual_demands_copy[i][j]: 1 if the i-th student is allocated the j-th course at the end of course match,
        0 otherwise. None is returned if ``model_type`` is not recognized.
    """

    done = False
    individual_demands_copy = individual_demands.copy()
    repetition_counter = 0

    while(not done and repetition_counter < individual_demands.shape[0] * 2):
        repetition_counter += 1
        done = True
        total_demands = individual_demands_copy.sum(axis = 0)
        free_seats = capacities - total_demands

        if (model_type in ['True', 'TrueNoisy']):
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(student_list_sorted):
                # the student may keep their own seats in addition to the free ones
                seats_available_to_student = free_seats + individual_demands_copy[i]

                (student_demand, _, value) = solve_student(timetable, prices, credit_units, budget * budget_increase_percetange, max_courses, additive_prefs, complements, substitutes, overload_penalty = overload_penalty,
                            timegap_penalty= timegap_penalty, free_days_marginal_values= free_days_marginal_values, ignore_timegaps= True, verbose = False, seats_available = seats_available_to_student, time_output= True)
                student_demand = np.array(student_demand)

                if _demand_changed(student_demand, individual_demands_copy[i]):
                    # Reference value: solve_student at the ORIGINAL budget without the seats_available restriction.
                    # The change is only accepted if the new schedule beats it by more than 0.01.
                    (student_demand_debug, _, value_debug) = solve_student(timetable, prices, credit_units, budget, max_courses, additive_prefs, complements, substitutes, overload_penalty = overload_penalty,
                        timegap_penalty= timegap_penalty, free_days_marginal_values= free_days_marginal_values, ignore_timegaps= True, verbose = False, time_output= True)
                    print(f'Student {i} changed his demand! NEW value: {value} OLD value: {value_debug}')
                    if (value > value_debug + 0.01):
                        done = False
                        individual_demands_copy[i] = student_demand
                        break
                    else:
                        print("Student tried to change his demand, but the change in value was less than 0.01")

        elif (model_type in ['LinearRegression', 'TrueLinear', 'Ridge', 'Lasso', 'ElasticNet', 'LinearNoisy', 'LinearRegressionNoisy', 'RidgeNoisy', 'LassoNoisy', 'ElasticNetNoisy']):
            for (i, (linear_coefs, budget)) in enumerate(student_list_sorted):
                seats_available_to_student = free_seats + individual_demands_copy[i]

                (student_demand, _, value) = solve_student(timetable, prices, credit_units, budget * budget_increase_percetange, max_courses, linear_coefs, [], [], overload_penalty = 0,
                            timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True, verbose = False, seats_available = seats_available_to_student, time_output= True)
                student_demand = np.array(student_demand)

                if _demand_changed(student_demand, individual_demands_copy[i]):
                    (student_demand_debug, _, value_debug) = solve_student(timetable, prices, credit_units, budget, max_courses, linear_coefs, [], [], overload_penalty = 0,
                        timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True, verbose = False, time_output= True)
                    print(f'Student {i} changed his demand! NEW value (according to {model_type}): {value} OLD value: {value_debug}')
                    if (value > value_debug + 0.01):
                        done = False
                        individual_demands_copy[i] = student_demand
                        break
                    else:
                        print('not actually going through with the change because the value difference is minimal!')

        elif (model_type in ['PairwiseAdjustments', 'PairwiseAdjustmentsNoisy']):
            for (i, (additive_prefs, substitutes_clipped, complements_clipped, budget)) in enumerate(student_list_sorted):
                seats_available_to_student = free_seats + individual_demands_copy[i]

                (student_demand, _, value) = solve_student(timetable, prices, credit_units, budget * budget_increase_percetange, max_courses, additive_prefs, complements_clipped,
                    substitutes_clipped, overload_penalty = 0, timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True, verbose = False,
                        seats_available = seats_available_to_student, time_output= True)
                student_demand = np.array(student_demand)

                if _demand_changed(student_demand, individual_demands_copy[i]):
                    (student_demand_debug, _, value_debug) = solve_student(timetable, prices, credit_units, budget, max_courses, additive_prefs, complements_clipped,
                    substitutes_clipped, overload_penalty = 0,  timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True,
                                                                           verbose = False, time_output= True)
                    print(f'Student {i} changed his demand! NEW value (according to PA): {value} OLD value: {value_debug}')
                    if (value > value_debug + 0.01):
                        done = False
                        individual_demands_copy[i] = student_demand
                        break
                    else:
                        print('not actually going through with the change because the value difference is minimal!')

        elif (model_type == 'NuSVR' or model_type == 'NuSVRNoisy'):
            for (i, (model, solver, gamma, budget)) in enumerate(student_list_sorted):
                seats_available_to_student = free_seats + individual_demands_copy[i]
                prices_copy = _price_out_full_courses(prices, seats_available_to_student)

                solver.generate_mip(course_prices = prices_copy, credit_units = credit_units, budget = budget * budget_increase_percetange, cu_max = 5, course_timetable = timetable)
                optimal_schedule, optimal_value = solver.solve_mip(verbose=False)
                student_demand = np.array(optimal_schedule)

                if _demand_changed(student_demand, individual_demands_copy[i]):
                    print(f'Student {i} changed his demand! NEW value (according to {model_type}): {optimal_value}')
                    done = False
                    individual_demands_copy[i] = student_demand
                    break

        elif model_type in ['UNN', 'UNN_Noisy', 'UNN_transfer_learning']:
            # Entries are (model, solver, budget), (model, solver, scale, budget) or (model, solver, scale, _, budget);
            # in every supported layout the solver is the second element and the budget the last one.
            if len(student_list_sorted[0]) in (3, 4, 5):
                for (i, entry) in enumerate(student_list_sorted):
                    solver, budget = entry[1], entry[-1]
                    seats_available_to_student = free_seats + individual_demands_copy[i]
                    prices_copy = _price_out_full_courses(prices, seats_available_to_student)

                    solver.add_budget_constraint(course_prices = prices_copy, budget = budget * budget_increase_percetange)
                    student_demand, optimal_value = solver.solve_mip_rv(outputFlag=False, verbose = False)
                    student_demand = np.array(student_demand)

                    if _demand_changed(student_demand, individual_demands_copy[i]):
                        print(f'Student {i} changed his demand! NEW value (according to UNN): {optimal_value}')
                        print(f'New demand: {student_demand}, old demand: {individual_demands_copy[i]}')
                        done = False
                        individual_demands_copy[i] = student_demand
                        break

        elif (model_type == 'UNN_projected'):
            # The additive branch below flattens coef_, so flatten here as well before comparing its length with the
            # number of courses (a 2-D coef_ of shape (1, n) would otherwise have length 1).
            linear_coefs = np.asarray(student_list_sorted[0][0].coef_).reshape(-1)

            if len(linear_coefs) > len(prices):
                print('Linear coefs are longer than prices, this means we are using quadratic terms!')
                for (i, (linear_model, budget)) in enumerate(student_list_sorted):
                    seats_available_to_student = free_seats + individual_demands_copy[i]
                    prices_copy = _price_out_full_courses(prices, seats_available_to_student)

                    solver = GUROBI_MIP_POLY_REGRESSION(linear_model)
                    solver.generate_mip(course_timetable = timetable,
                                credit_units = credit_units,
                                course_prices = prices_copy,
                                budget = budget * budget_increase_percetange,
                                cu_max = 5,
                                timeLimit = None,
                                MIPGap = None,
                                verbose = False,
                                )

                    # get the opt schedule according to the solver
                    student_demand = np.array(solver.solve_mip(verbose = False))
                    if _demand_changed(student_demand, individual_demands_copy[i]):
                        new_value = solver.mip.getObjective().getValue()
                        poly = PolynomialFeatures(degree = 2, include_bias = False, interaction_only = True)
                        # one sample row with one column per course (works for any number of courses)
                        X_poly_new = poly.fit_transform(student_demand.reshape(1, -1))
                        X_poly_old = poly.fit_transform(individual_demands_copy[i].reshape(1, -1))
                        pred_new = linear_model.predict(X_poly_new)
                        pred_old = linear_model.predict(X_poly_old)
                        print(f'Student {i} changed his demand! NEW value (according to projection): {pred_new} and solver:{new_value} OLD value: {pred_old}')
                        print(f'Student {i} changed his demand!')
                        done = False
                        individual_demands_copy[i] = student_demand
                        break

            else:
                print('ACHTUNG ACHTUNG! Linear coefs are shame length as prices, this means we are only using additive terms!')
                for (i, (linear_model, budget)) in enumerate(student_list_sorted):
                    seats_available_to_student = free_seats + individual_demands_copy[i]
                    linear_coefs = linear_model.coef_.reshape(-1)

                    (student_demand, _, value) = solve_student(timetable, prices, credit_units, budget * budget_increase_percetange, max_courses, linear_coefs, [], [], overload_penalty = 0,
                                timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True, verbose = False, seats_available = seats_available_to_student, time_output= True)
                    student_demand = np.array(student_demand)

                    if _demand_changed(student_demand, individual_demands_copy[i]):
                        (student_demand_debug, _, value_debug) = solve_student(timetable, prices, credit_units, budget, max_courses, linear_coefs, [], [], overload_penalty = 0,
                            timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True, verbose = False, time_output= True)
                        print(f'Student {i} changed his demand! NEW value (according to {model_type}): {value} OLD value: {value_debug}')
                        if (value > value_debug + 0.01):
                            done = False
                            individual_demands_copy[i] = student_demand
                            break
                        else:
                            print('not actually going through with the change because the value difference is minimal!')

        elif (model_type == 'xgboost' or model_type == 'xgboostNoisy'):
            for (i, (model, solver, budget)) in enumerate(student_list_sorted):
                seats_available_to_student = free_seats + individual_demands_copy[i]
                prices_copy = _price_out_full_courses(prices, seats_available_to_student)

                solver.add_budget_constraint(course_prices=prices_copy, budget= budget * budget_increase_percetange)
                student_demand, optimal_value = solver.solve_mip()
                # the solver may return a plain list; convert so it can be compared and assigned like the other branches
                student_demand = np.array(student_demand)

                if _demand_changed(student_demand, individual_demands_copy[i]):
                    print(f'Student {i} changed his demand! NEW value (according to {model_type}): {optimal_value}')
                    done = False
                    individual_demands_copy[i] = student_demand
                    break

        else:
            print('Model type not recognized!')

            return

    if not done and repetition_counter > 0:
        print(f'Aftermarket stopped after reaching the limit of {repetition_counter} passes while students were still changing their demand.')

    return individual_demands_copy


# ---- Splitting Course Match (Run Tabu Principled into smaller functions so that we can also run all of them individually) ------

def generate_problem_instance_principled(number_of_times = 10, number_of_courses = 30, number_of_students = 100, supply_ratio = 1.1, capacity_deviation = 0, grid_heigth = 5, grid_width = 6, complement_range = 1, substitute_range = 1,
        complements_expected_value = 0.4, substitutes_expected_value = 0.4, complements_decay = 0.8, substitutes_decay = 0.8, favourites_value_range = (80, 120), non_favourites_value_range = (40, 60),
        number_of_popular = 8, mean_number_of_favourites = 2, number_of_centers = 2, disable_time_prefs = True, fixed_comp_values = True, seed = 42):
    """
    Build ``number_of_times`` independent Course Match instances (true student profiles, course capacities and
    timetables), so that they can then be fed into the individual stages of course match.

    For run ``r`` the student generator (and NumPy's global RNG) is seeded with ``seed + r * number_of_students``. The
    capacity and timetable generators are seeded with ``seed + r``. The total number of seats is
    ``number_of_students * 5 * supply_ratio``. The first half of the courses (rounded down) get 0.5 credit units and
    the rest get 1.

    Note: ``complements_decay`` and ``substitutes_decay`` are accepted but not used by this function.

    Returns:
    -----------------
    tuple (student lists per run, np.array of capacities per run, timetables per run)
    """
    students_per_run = []
    capacities_per_run = []
    timetables_per_run = []

    for run in range(number_of_times):
        print(f'Generating students for run {run}')
        student_seed = seed + (run * number_of_students)
        instance_seed = seed + run

        np.random.seed(student_seed)
        students = create_multiple_students(number_of_students, number_of_courses, gr_height= grid_heigth, gr_width= grid_width,
            complement_range= complement_range, substitute_range= substitute_range, complements_expected_value = complements_expected_value, substitutes_expected_value = substitutes_expected_value,
            seed = student_seed, favourites_value_range= favourites_value_range, non_favourites_value_range= non_favourites_value_range,
            number_of_popular= number_of_popular, mean_number_of_favourites = mean_number_of_favourites, number_of_centers = number_of_centers,
            disable_time_prefs = disable_time_prefs, fixed_comp_values= fixed_comp_values)

        seats_in_total = number_of_students * 5 * supply_ratio
        capacities = capacities_generator(number_of_courses = number_of_courses, total_number_of_seats= seats_in_total, capacity_deviation= capacity_deviation, seed = instance_seed)

        half_credit_courses = int(math.floor(number_of_courses / 2))
        full_credit_courses = int(math.ceil(number_of_courses / 2))
        course_credit_units = [0.5] * half_credit_courses + [1] * full_credit_courses
        timetable = timetable_generator(number_of_courses, credit_units= course_credit_units, seed = instance_seed)

        students_per_run.append(students)
        capacities_per_run.append(capacities)
        timetables_per_run.append(timetable)

    return students_per_run, np.array(capacities_per_run), timetables_per_run


def _default_stage1_models():
    """
    Returns a freshly built copy of the default `models_to_run` list of run_stage1_principled.
    New list and dictionary objects are created on every call, so no caller can mutate a shared default.
    """
    noisy_params = {'noisy_forget_base': 0.15, 'noisy_forget_adjustments': 0, 'noisy_base_std': 3, 'noisy_adj_std': 0.05}
    unn_params = {'UNN_layers': 1, 'UNN_units': 16, 'UNN_layer_type': 'CALayerReLUProjected', 'UNN_L': 100, 'UNN_tol': 1e-8,
        'UNN_epochs': 200, 'UNN_lr': 1e-3, 'samples': 50}
    return [
        ('LinearRegressionNoisy', {'samples': 50, **noisy_params}), ('LinearRegression', {'samples': 50}),
        ('RidgeNoisy', {'samples': 50, **noisy_params}), ('Ridge', {'samples': 50}),
        ('LassoNoisy', {'samples': 50, **noisy_params}), ('Lasso', {'samples': 50}),
        ('ElasticNetNoisy', {'samples': 50, **noisy_params}), ('ElasticNet', {'samples': 50}),
        ('UNN_Noisy', {**unn_params, **noisy_params}),
        ('UNN', dict(unn_params)),
        ('LinearNoisy', dict(noisy_params)), ('TrueLinear', None),
        ('PairwiseAdjustmentsNoisy', dict(noisy_params)), ('PairwiseAdjustments', None),
        ('TrueNoisy', dict(noisy_params)), ('True', None),
    ]


def run_stage1_principled(true_student_profiles, timetables_all_runs, capacities_all_runs, percentage_neighbors = 3, gradient_neighbors = 20, individual_neighbors = 20, maximum_number_of_restarts = 5,
        seed = 42, clearing_error_limit = 1, time_limit_restart = 80, time_limit_search = 500, max_gradient_multiplier = (0.1 / (2**6)), max_steps_without_improvement = 5,
        models_to_run = None):
    """
    Runs stage 1 of course match (tabu search over prices) for every model in `models_to_run` on every
    problem instance, and evaluates the induced allocations under the true student preferences.

    Arguments:
        true_student_profiles: one list of true student tuples per run; each tuple is
            (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget).
        timetables_all_runs: one course timetable per run. It is only read, never modified.
        capacities_all_runs: one capacity vector per run.
        percentage_neighbors, gradient_neighbors, individual_neighbors, maximum_number_of_restarts,
        time_limit_restart, time_limit_search, max_gradient_multiplier, max_steps_without_improvement:
            forwarded to heuristic_search.
        seed: base seed; run i reseeds numpy (and create_model_student_list) with seed + i * number_of_students.
        clearing_error_limit: forwarded to heuristic_search; None -> calculate_theoretical_bound_squared(...).
        models_to_run: list of (model_type, model_param_dictionary) pairs; None -> the default list
            built by _default_stage1_models().

    Returns:
        (values, model student lists, clearing errors, oversubscribed seats, allocations, run times, prices),
        all indexed [model][run][...]. The model student lists are a nested python list (for xgboost/UNN/NuSVR
        models the solver entry is replaced by the string 'solver' so they can be saved); everything else is
        converted with np.array.
    """
    if models_to_run is None:
        models_to_run = _default_stage1_models()

    number_of_models = len(models_to_run)
    number_of_courses = len(true_student_profiles[0][0][0])  # length of the additive prefs of the first student of the first run
    number_of_students = len(true_student_profiles[0])
    print(f'number of students: {number_of_students} and number of courses: {number_of_courses}')

    # Per-model accumulators; one entry is appended per run -> final shape <models> x <runs> x ...
    value_list_total_stage1_all_models = [[] for _ in range(number_of_models)]
    clearing_error_total_stage1_all_models = [[] for _ in range(number_of_models)]
    oversubcription_error_total_stage1_all_models = [[] for _ in range(number_of_models)]
    allocations_all_models = [[] for _ in range(number_of_models)]
    prices_all_models_s1 = [[] for _ in range(number_of_models)]
    student_list_total_all_runs = [[] for _ in range(number_of_models)]
    time_taken_stage1_all_models = [[] for _ in range(number_of_models)]

    if clearing_error_limit is None:
        clearing_error_limit = calculate_theoretical_bound_squared(number_of_courses_M= number_of_courses, largest_bundle_size_k= 5)  # the theoretical bound

    for run in range(len(true_student_profiles)):
        print(f'Problem instance number: {run}')
        run_seed = seed + (run * number_of_students)
        np.random.seed(run_seed)
        capacities = capacities_all_runs[run]
        timetable = timetables_all_runs[run]
        student_list = true_student_profiles[run]
        maximum_budget = np.max([profile[-1] for profile in student_list])
        print(f'Total number of seats: {np.sum(capacities)}')

        # Step 1: build the student representation each model works with
        student_list_per_model = []
        for model_idx, (model_type, model_param_dictionary) in enumerate(models_to_run):
            model_student_list = create_model_student_list(student_list, timetable, model_type = model_type, seed = run_seed, model_param_dictionary= model_param_dictionary)
            student_list_total_all_runs[model_idx].append(model_student_list)
            student_list_per_model.append(model_student_list)

        # Step 2: run the tabu search for all models
        tabu_prices_all_models = []
        for model_idx, (model_type, model_param_dictionary) in enumerate(models_to_run):
            start = timer()
            tabu_prices_model, final_error_model, _statistics_model = heuristic_search(student_list_per_model[model_idx], timetable, credit_units= [1] * number_of_courses,
                capacities= capacities, max_budget = maximum_budget, max_steps_without_improvement = max_steps_without_improvement, clearing_error_limit= clearing_error_limit,
                time_limit_restart= time_limit_restart, time_limit_search= time_limit_search, number_percentage_neighbors= percentage_neighbors,
                number_gradient_neighbors = gradient_neighbors, number_individual_neighbors= individual_neighbors, max_restarts = maximum_number_of_restarts,
                model_type= model_type, model_param_dictionary = model_param_dictionary, max_gradient_multiplier= max_gradient_multiplier)
            end = timer()
            time_taken = end - start
            tabu_prices_all_models.append(tabu_prices_model)
            clearing_error_total_stage1_all_models[model_idx].append(final_error_model)
            print(f'{model_type} tabu completed in {time_taken} seconds')
            time_taken_stage1_all_models[model_idx].append(time_taken)
            prices_all_models_s1[model_idx].append(tabu_prices_model)

        # Step 3: individual demands induced by every model's price vector
        individual_demands_all_models = []
        for model_idx, (model_type, model_param_dictionary) in enumerate(models_to_run):
            total_demand_model, individual_demands_model = calculate_total_demand(tabu_prices_all_models[model_idx], student_list_per_model[model_idx], timetable,
                    [1] * number_of_courses, return_individual_demands= True, model_type= model_type, model_param_dictionary = model_param_dictionary)
            individual_demands_all_models.append(individual_demands_model)
            allocations_all_models[model_idx].append(individual_demands_model)
            # seats demanded beyond capacity, summed over all courses
            oversubscription = np.maximum(total_demand_model - capacities, 0).sum()
            oversubcription_error_total_stage1_all_models[model_idx].append(oversubscription)

        # Step 4: value of those demands under the true preferences
        for model_idx in range(number_of_models):
            value_list_model = []
            for student_idx, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in enumerate(student_list):
                value_list_model.append(student(individual_demands_all_models[model_idx][student_idx], additive_prefs, substitutes, complements, timetable,
                        overload_penalty = overload_penalty, free_days_marginal_values= free_days_marginal_values,
                        credit_units= [1] * number_of_courses, make_monotone= True))
            value_list_total_stage1_all_models[model_idx].append(value_list_model)

    # The Gurobi solver objects cannot be saved, so replace them with a placeholder string;
    # unfreeze_model_student_list rebuilds them when the lists are loaded again.
    # student_list_total_all_runs is indexed [model][run][student].
    for model_idx, (model_type, _) in enumerate(models_to_run):
        if model_type in ['xgboost', 'xgboostNoisy', 'UNN', 'UNN_Noisy']:
            for run_student_list in student_list_total_all_runs[model_idx]:
                for k, (model, _solver, budget) in enumerate(run_student_list):
                    run_student_list[k] = (model, 'solver', budget)

        elif model_type in ['NuSVR', 'NuSVRNoisy']:
            for run_student_list in student_list_total_all_runs[model_idx]:
                for k, (model, _solver, gamma, budget) in enumerate(run_student_list):
                    run_student_list[k] = (model, 'solver', gamma, budget)

    return (np.array(value_list_total_stage1_all_models), student_list_total_all_runs, np.array(clearing_error_total_stage1_all_models), np.array(oversubcription_error_total_stage1_all_models),
        np.array(allocations_all_models), np.array(time_taken_stage1_all_models), np.array(prices_all_models_s1))


def _build_unn_mip_solver(model, number_of_courses, course_timetable):
    """Builds the Gurobi MIP for an (MV)NN student model; models without hidden layers use the linear formulation."""
    if model._num_hidden_layers >= 1:
        solver = GUROBI_MIP2_MVNN(model)
    else:
        print('special case, LINEAR MVNN!!!')
        solver = GUROBI_MIP2_MVNN_LINEAR(model)
    solver.generate_mip(course_timetable=course_timetable,
        credit_units=np.repeat(1, number_of_courses),
        cu_max=5,
        timeLimit=100,
        MIPGap=0.0001,
        verbose=False)
    return solver


def unfreeze_model_student_list(loaded_student_list, model_type, number_of_courses, course_timetable):
    """
    Rebuilds the Gurobi solver objects of a model student list that was saved without them
    (the solver slot of each student tuple holds a placeholder when loaded).

    Arguments:
        loaded_student_list: list of per-student tuples for one model and one run. Their layout depends on model_type:
            xgboost / xgboostNoisy: (model, _, budget)
            NuSVR / NuSVRNoisy: (model, _, gamma, budget)
            UNN / UNN_Noisy / UNN_transfer_learning: (model, _, budget), (model, _, scale, budget)
                or (model, _, scale, pretrained_model, budget)
        model_type: model name; any other type is returned unchanged.
        number_of_courses: number of courses (all credit units are set to 1).
        course_timetable: timetable passed to the MIP generation.

    Returns:
        A new list with the same tuples but a freshly built solver in the second slot (or the input list itself
        for model types that need no solver).

    Raises:
        ValueError: if a UNN student tuple does not have 3, 4 or 5 entries.
    """
    if model_type in ('xgboost', 'xgboostNoisy'):
        xgboost_student_list = []
        for (model, _, budget) in loaded_student_list:
            solver = gurobi_MIP_xgboost(model, number_of_courses)
            solver.generate_mip(credit_units=np.repeat(1, number_of_courses), cu_max=5, course_timetable= course_timetable)
            xgboost_student_list.append((model, solver, budget))
        return xgboost_student_list

    if model_type in ('NuSVR', 'NuSVRNoisy'):
        svr_student_list = []
        for (model, _, gamma, budget) in loaded_student_list:
            solver = gurobi_MIP_SVR(model, gamma)
            solver.generate_mip(course_prices = np.repeat(0, number_of_courses), credit_units = np.repeat(1, number_of_courses), budget = budget, course_timetable = course_timetable, cu_max = 5, verbose = False)
            svr_student_list.append((model, solver, gamma, budget))
            # TODO: SVR MIP RIGHT NOW GENERATES MIP EVERY TIME TO CALCULATE THE DEMAND!
        return svr_student_list

    if model_type in ['UNN', 'UNN_Noisy', 'UNN_transfer_learning']:
        if len(loaded_student_list) == 0:
            return []

        entry_length = len(loaded_student_list[0])
        if entry_length == 4:
            print('Unfreezing a model with 4 items per student -> scale probably included')
        elif entry_length == 5:
            print('Unfreezing a model with 5 items per student -> it is a TL model')
        elif entry_length != 3:
            raise ValueError(f'Cannot unfreeze {model_type} student list: expected 3, 4 or 5 items per student, got {entry_length}')

        # Everything after the solver slot (scale / pretrained model / budget) is kept as-is.
        # NOTE: the scale of the MVNNs is not actually needed for stages 2 and 3.
        unn_student_list = []
        for (model, _, *remaining_fields) in loaded_student_list:
            solver = _build_unn_mip_solver(model, number_of_courses, course_timetable)
            unn_student_list.append((model, solver, *remaining_fields))
        return unn_student_list

    return loaded_student_list


def run_stages2_3_principled(models_to_run, true_student_profiles, loaded_student_lists, prices_stage_1, timetables_all_runs, capacities_all_runs, seed = 42):
    """
    Runs stage 2 (oversubscription elimination) and stage 3 (aftermarket allocations) of course match
    for every model in `models_to_run` on every problem instance, starting from the stage 1 prices,
    and evaluates the allocations of both stages under the true student preferences.

    Arguments:
        models_to_run: list of (model_type, model_param_dictionary) pairs, in the same order as in stage 1.
        true_student_profiles: one list of true student tuples per run, used to value the allocations.
        loaded_student_lists: model student lists indexed [model][run][student]; their solvers are rebuilt
            with unfreeze_model_student_list.
        prices_stage_1: stage 1 prices indexed [model][run][course].
        timetables_all_runs, capacities_all_runs: one timetable / capacity vector per run.
        seed: base seed; run i reseeds numpy with seed + i * number_of_students.

    Returns:
        (stage 2 values, stage 2 allocations, stage 2 prices, stage 2 run times,
         stage 3 values, stage 3 allocations, stage 3 run times), each converted with np.array
        and indexed [model][run][...].
    """
    number_of_models = len(models_to_run)
    number_of_courses = prices_stage_1.shape[2]   # prices_stage_1 is indexed [model][run][course]
    number_of_students = len(true_student_profiles[0])
    print(f'number of students: {number_of_students} and number of courses: {number_of_courses}')

    # Per-model accumulators; one entry is appended per run -> final shape <models> x <runs> x ...
    value_list_total_stage2_all_models = [[] for _ in range(number_of_models)]
    value_list_total_stage3_all_models = [[] for _ in range(number_of_models)]
    time_taken_stage2_all_models = [[] for _ in range(number_of_models)]
    time_taken_stage3_all_models = [[] for _ in range(number_of_models)]
    prices_all_models_s2 = [[] for _ in range(number_of_models)]
    allocations_all_models_stage2 = [[] for _ in range(number_of_models)]
    allocations_all_models_stage3 = [[] for _ in range(number_of_models)]

    for run in range(len(true_student_profiles)):
        print(f'Problem instance number: {run}')
        np.random.seed(seed + (run * number_of_students))
        capacities = capacities_all_runs[run]
        timetable = timetables_all_runs[run]
        student_list_true = true_student_profiles[run]
        maximum_budget = np.max([profile[-1] for profile in student_list_true])
        print(f'Total number of seats: {np.sum(capacities)}')

        # Step 6: (re)create the solver-backed student lists of every model for this single run
        student_list_per_model = []   # indexed [model][student]
        for j, (model_type, _) in enumerate(models_to_run):
            student_list_per_model.append(unfreeze_model_student_list(loaded_student_lists[j][run], model_type = model_type,
                    number_of_courses= number_of_courses, course_timetable= timetable))

        # Step 7: stage 2 prices for every model
        for j, (model_type, model_param_dictionary) in enumerate(models_to_run):
            start = timer()
            prices_adjusted = oversubscription_elimination(prices = prices_stage_1[j][run], student_list = student_list_per_model[j], timetable = timetable,
                    credit_units= [1] * number_of_courses, max_capacities= capacities, maximum_price= maximum_budget, model_param_dictionary= model_param_dictionary, model_type = model_type)
            end = timer()
            time_taken = end - start
            prices_all_models_s2[j].append(prices_adjusted)
            print(f'Finished with stage 2 for the {model_type} model in {time_taken} seconds')
            time_taken_stage2_all_models[j].append(time_taken)

        # Step 8: individual demands induced by the stage 2 prices
        individual_demands_all_models_s2 = []
        for j, (model_type, model_param_dictionary) in enumerate(models_to_run):
            _, individual_demands_model = calculate_total_demand(prices_all_models_s2[j][-1], student_list_per_model[j], timetable,
                    [1] * number_of_courses, return_individual_demands= True, model_type= model_type, model_param_dictionary= model_param_dictionary)
            individual_demands_all_models_s2.append(individual_demands_model)
            allocations_all_models_stage2[j].append(individual_demands_model)

        # Step 9: value of the stage 2 allocations under the true preferences
        for j in range(number_of_models):
            value_list_model = []
            for k, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in enumerate(student_list_true):
                value_list_model.append(student(individual_demands_all_models_s2[j][k], additive_prefs, substitutes, complements, timetable,
                        overload_penalty = overload_penalty, free_days_marginal_values= free_days_marginal_values,
                        credit_units= [1] * number_of_courses, make_monotone= True))
            value_list_total_stage2_all_models[j].append(value_list_model)

        #     ---  stage 3 ---
        # Step 10: aftermarket allocations for every model.
        # aftermarket_allocations gets its own deep copy of the stage 2 demands. A plain list.copy() would only copy
        # the outer list and still share each model's demand object with individual_demands_all_models_s2 and
        # allocations_all_models_stage2, so any in-place change made by stage 3 would leak into the stage 2 results.
        individual_demands_all_models_copy = copy.deepcopy(individual_demands_all_models_s2)
        for j, (model_type, _) in enumerate(models_to_run):
            start = timer()
            final_allocation_model = aftermarket_allocations(prices_all_models_s2[j][-1], student_list_per_model[j], timetable,
                    individual_demands_all_models_copy[j], capacities, [1] * number_of_courses, model_type = model_type)
            end = timer()
            print(f'Number of matches between stage2 and stage3 for this model: {(final_allocation_model == individual_demands_all_models_s2[j]).sum()}')
            allocations_all_models_stage3[j].append(final_allocation_model)
            time_taken = end - start
            print(f'Stage 3 for the {model_type} model finished in {end - start} seconds.')
            time_taken_stage3_all_models[j].append(time_taken)

        # Step 11: value of the stage 3 allocations under the true preferences
        for j in range(number_of_models):
            value_list_model = []
            for k, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in enumerate(student_list_true):
                value_list_model.append(student(allocations_all_models_stage3[j][-1][k], additive_prefs, substitutes, complements, timetable,
                        overload_penalty = overload_penalty, free_days_marginal_values= free_days_marginal_values,
                        credit_units= [1] * number_of_courses, make_monotone= True))
            value_list_total_stage3_all_models[j].append(value_list_model)

    return (np.array(value_list_total_stage2_all_models), np.array(allocations_all_models_stage2), np.array(prices_all_models_s2),
            np.array(time_taken_stage2_all_models), np.array(value_list_total_stage3_all_models), np.array(allocations_all_models_stage3),
            np.array(time_taken_stage3_all_models))


def run_stage3_bandaid(models_to_run, true_student_profiles, loaded_student_lists, prices_stage_2, timetables_all_runs, capacities_all_runs, seed = 42):
    """
    Re-runs only stage 3 (aftermarket allocations) of course match from already computed stage 2 prices,
    for every model in `models_to_run` on every problem instance, and values the resulting allocations
    under the true student preferences.

    Arguments:
        models_to_run: list of (model_type, model_param_dictionary) pairs, in the same order as the saved results.
        true_student_profiles: one list of true student tuples per run.
        loaded_student_lists: model student lists indexed [model][run][student]; their solvers are rebuilt
            with unfreeze_model_student_list.
        prices_stage_2: stage 2 prices indexed [model][run][course].
        timetables_all_runs, capacities_all_runs: one timetable / capacity vector per run.
        seed: base seed; run i reseeds numpy with seed + i * number_of_students.

    Returns:
        (stage 3 values, stage 3 allocations, stage 3 run times), each converted with np.array and
        indexed [model][run][...].
    """
    number_of_models = len(models_to_run)
    number_of_courses = prices_stage_2.shape[2]   # prices_stage_2 is indexed [model][run][course]
    number_of_students = len(true_student_profiles[0])
    print(f'number of students: {number_of_students} and number of courses: {number_of_courses}')

    # Per-model accumulators; one entry is appended per run -> final shape <models> x <runs> x ...
    value_list_total_stage3_all_models = [[] for _ in range(number_of_models)]
    allocations_all_models_stage3 = [[] for _ in range(number_of_models)]
    time_taken_stage3_all_models = [[] for _ in range(number_of_models)]

    for run in range(len(true_student_profiles)):
        print(f'Problem instance number: {run}')
        np.random.seed(seed + (run * number_of_students))
        capacities = capacities_all_runs[run]
        timetable = timetables_all_runs[run]
        student_list_true = true_student_profiles[run]
        print(f'Total number of seats: {np.sum(capacities)}')

        # (Re)create the solver-backed student lists of every model for this single run
        student_list_per_model = []   # indexed [model][student]
        for j, (model_type, _) in enumerate(models_to_run):
            student_list_per_model.append(unfreeze_model_student_list(loaded_student_lists[j][run], model_type = model_type,
                    number_of_courses= number_of_courses, course_timetable= timetable))

        # Individual demands induced by the (given) stage 2 prices
        individual_demands_all_models_s2 = []
        for j, (model_type, model_param_dictionary) in enumerate(models_to_run):
            _, individual_demands_model = calculate_total_demand(prices_stage_2[j][run], student_list_per_model[j], timetable,
                    [1] * number_of_courses, return_individual_demands= True, model_type= model_type, model_param_dictionary= model_param_dictionary)
            individual_demands_all_models_s2.append(individual_demands_model)

        #     ---  stage 3 ---
        # aftermarket_allocations gets its own deep copy of the stage 2 demands. A plain list.copy() would only copy
        # the outer list and still share each model's demand object with individual_demands_all_models_s2, so any
        # in-place change made by stage 3 would corrupt the stage 2 / stage 3 comparison printed below.
        individual_demands_all_models_copy = copy.deepcopy(individual_demands_all_models_s2)
        for j, (model_type, _) in enumerate(models_to_run):
            start = timer()
            final_allocation_model = aftermarket_allocations(prices_stage_2[j][run], student_list_per_model[j], timetable,
                    individual_demands_all_models_copy[j], capacities, [1] * number_of_courses, model_type = model_type)
            end = timer()
            print(f'Number of matches between stage2 and stage3 for this model: {(final_allocation_model == individual_demands_all_models_s2[j]).sum()}')
            allocations_all_models_stage3[j].append(final_allocation_model)
            time_taken = end - start
            print(f'Stage 3 for the {model_type} model finished in {end - start} seconds.')
            time_taken_stage3_all_models[j].append(time_taken)

        # Value of the stage 3 allocations under the true preferences
        for j in range(number_of_models):
            value_list_model = []
            for k, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in enumerate(student_list_true):
                value_list_model.append(student(allocations_all_models_stage3[j][-1][k], additive_prefs, substitutes, complements, timetable,
                        overload_penalty = overload_penalty, free_days_marginal_values= free_days_marginal_values,
                        credit_units= [1] * number_of_courses, make_monotone= True))
            value_list_total_stage3_all_models[j].append(value_list_model)

    return (np.array(value_list_total_stage3_all_models), np.array(allocations_all_models_stage3),
            np.array(time_taken_stage3_all_models))


def mean_confidence_interval(data, confidence=0.95, normal_distr = False):
    """
    Returns the half-width h of the two-sided confidence interval for the mean of `data`,
    i.e. the interval is mean(data) +/- h.

    Arguments:
        data: 1-D array-like of samples.
        confidence: confidence level, e.g. 0.95.
        normal_distr: if True, use the standard normal quantile; otherwise use Student's t quantile
            with len(data) - 1 degrees of freedom.

    Returns:
        scipy.stats.sem(data) multiplied by the quantile at (1 + confidence) / 2.
    """
    import scipy.stats  # a bare `import scipy` does not load the stats submodule on older SciPy versions

    a = 1.0 * np.array(data)
    n = len(a)
    se = scipy.stats.sem(a)
    quantile_level = (1 + confidence) / 2.
    if normal_distr:
        # The normal distribution has no degrees-of-freedom parameter: the second positional argument of
        # norm.ppf is `loc`, so passing n - 1 there would shift the quantile instead of parameterising it.
        critical_value = scipy.stats.norm.ppf(quantile_level)
    else:
        critical_value = scipy.stats.t.ppf(quantile_level, n - 1)
    return se * critical_value


def _format_mean_pm_ci(values, confidence, ci_digits = 2):
    r"""
    Formats the mean of `values` (2 decimals) and the half-width of its confidence interval
    (`ci_digits` decimals) as a LaTeX table cell: '<mean> $\pm$ \scriptsize <half-width>'.
    """
    values = np.asarray(values)
    return f'{values.mean():.2f} ' + r"$\pm$ \scriptsize " + f'{mean_confidence_interval(values, confidence= confidence):.{ci_digits}f}'


def read_principled_results_enhanced(result, clearing_errors, oversubscription_errors, model_names, model_lists = None, timetables = None, models_run = None, test_samples = 2000, test_samples_range = 2000, value_range = (250, math.inf), CI = 0.95, scale_utilities = True):
    """
    Prints a LaTeX table that summarises the course match outcomes of several models (mean +/- CI over runs).

    Arguments:
        result: array of true student values indexed [model][run][student]. The last model is the reference
            used for scaling when scale_utilities is True.
        clearing_errors, oversubscription_errors: arrays indexed [model][run].
        model_names: row labels of the table, one per model.
        model_lists, timetables, models_run, test_samples, test_samples_range, value_range: if model_lists is given,
            the models themselves are also evaluated with test_all_models_enhanced (AvgKT and MAE columns).
            In that case models_run must be the list of (model_type, params) pairs and its last entry must be 'True'.
        CI: confidence level of the reported intervals.
        scale_utilities: if True, each run's utilities are rescaled so that the last model's average utility is 100.

    Returns:
        None if model_lists is None or the last model is not 'True'; otherwise (y_total, KT_results).
    """
    if scale_utilities:
        scaling = 100 / result[-1].mean(axis = 1)  # one factor per run
    else:
        scaling = 1

    number_of_models = len(result)
    number_of_students = result.shape[-1]
    # Sorted position of the 10th-percentile student. It is clamped so that runs with fewer than 10 students
    # use the worst-off student instead of index -1, which would be the best-off student.
    percentile_idx = max(int(number_of_students / 10) - 1, 0)

    result_sorted = np.sort(result, axis = -1)  # each run's student values in increasing order
    best_values = result.max(axis = 0)  # best value any model achieved, per run and student

    avg_utility = [_format_mean_pm_ci(result[i].mean(axis = 1) * scaling, CI) for i in range(number_of_models)]
    median_utility = [_format_mean_pm_ci(np.median(result[i], axis = 1) * scaling, CI) for i in range(number_of_models)]
    min_utility = [_format_mean_pm_ci(result[i].min(axis = 1) * scaling, CI) for i in range(number_of_models)]
    percentile_utility = [_format_mean_pm_ci(result_sorted[i][:, percentile_idx] * scaling, CI) for i in range(number_of_models)]
    # Percentage of students per run whose value equals the best value achieved by any model
    fav_allocation = [_format_mean_pm_ci((result[i] == best_values).sum(axis = 1) * (100 / number_of_students), CI) for i in range(number_of_models)]
    avg_dist_fav = [_format_mean_pm_ci((best_values - result[i]).mean(axis = 1) * scaling, CI) for i in range(number_of_models)]
    max_dist_fav = [_format_mean_pm_ci((best_values - result[i]).max(axis = 1) * scaling, CI) for i in range(number_of_models)]
    clearing_error = [_format_mean_pm_ci(clearing_errors[i], CI) for i in range(len(clearing_errors))]
    oversubcriptions = [_format_mean_pm_ci(oversubscription_errors[i], CI) for i in range(len(oversubscription_errors))]

    if model_lists is None:
        dataframe = pd.DataFrame(dict(Model = model_names, AvgUtil = avg_utility, MedianUtil = median_utility, MinUtil= min_utility,
                                     PrctUtil = percentile_utility, FavAlloc = fav_allocation,
                                     AvgDistFav = avg_dist_fav, MaxDistFav = max_dist_fav, ClearErrors = clearing_error, Oversubscr = oversubcriptions))
        print(dataframe.to_latex(index = False, escape= False))
        return

    # model_lists given -> additionally evaluate the prediction quality of the models themselves
    model_types = [entry[0] for entry in models_run]
    if model_types[-1] != 'True':
        print("WARNING - IN ORDER TO USE THIS THE LAST MODEL SHOULD BE TRUE PREFS")
        return

    y_total = test_all_models_enhanced(model_lists, model_types, timetables, test_samples, test_samples_range= test_samples_range, value_range= value_range)
    KT_results = get_all_KTs_enhanced(y_total)

    avg_kt = [_format_mean_pm_ci(KT_results[i].mean(axis = 1), CI, ci_digits = 3) for i in range(len(KT_results))]

    # Mean absolute error of every model's predictions w.r.t. the true model (the last one)
    y_array = np.asarray(y_total)
    y_total_abs = np.absolute(y_array - y_array[-1])
    mae = [_format_mean_pm_ci(y_total_abs[i].mean(axis = (1, 2)), CI, ci_digits = 3) for i in range(len(y_total_abs))]

    dataframe = pd.DataFrame(dict(Model = model_names, AvgUtil = avg_utility, MedianUtil = median_utility, MinUtil = min_utility, PrctUtil = percentile_utility, FavAlloc = fav_allocation,
                                 AvgDistFav = avg_dist_fav, MaxDistFav = max_dist_fav, ClearErrors = clearing_error, Oversubscr = oversubcriptions, AvgKT = avg_kt, MAE = mae))
    print(dataframe.to_latex(index = False, escape= False))

    return (y_total, KT_results)


def get_all_KTs_enhanced(y_total):
    """
    Computes Kendall's tau between every model's predictions and the predictions of the last
    model in `y_total` (the reference), separately for every run and student.

    Arguments:
        y_total: nested sequence indexed [model][run][student] -> sequence of predicted values.

    Returns:
        numpy array of tau statistics indexed [model][run][student].
    """
    if len(y_total) == 0:
        return np.array([])

    reference = y_total[-1]
    runs_per_model = len(y_total[0])

    all_taus = []
    for model_predictions in y_total:
        taus_per_run = [[] for _ in range(runs_per_model)]
        for run_idx, run_predictions in enumerate(model_predictions):
            reference_run = reference[run_idx]
            taus_per_run[run_idx] = [scipy.stats.kendalltau(predictions, reference_run[student_idx])[0]
                                     for student_idx, predictions in enumerate(run_predictions)]
        all_taus.append(taus_per_run)

    return np.array(all_taus)


def test_all_models_enhanced(model_student_lists_all_runs, model_types, timetables, test_samples = 2000, test_samples_range = 2000, value_range = (250, math.inf)):
    """
    Evaluate every model on a freshly sampled test set of bundles, per run and per student.

    For every run and student, a test set of bundles is drawn with ``create_initial_dataset``
    from the TRUE preferences (the last model). Every model then predicts the value of each
    of those bundles via ``model_predict``.

    Parameters
    ----------
    model_student_lists_all_runs : list
        Indexed ``[model][run][student]``. The entries of the last model are unpacked as
        ``(additive_prefs, substitutes, complements, timegap_penalty, overload_penalty,
        free_days_marginal_values)``.
    model_types : list of str
        One type per model, forwarded to ``model_predict``. The last one must be ``'True'``.
    timetables : list
        One timetable per run.
    test_samples, test_samples_range, value_range :
        Forwarded to ``create_initial_dataset``.

    Returns
    -------
    numpy.ndarray of shape (model, run, student, sample) with the predictions, or ``None``
    (after printing a warning) if the last model type is not ``'True'``.
    """
    if model_types[-1] != 'True':
        print("WARNING - LAST MODEL SHOULD BE TRUE instead of ", model_types[-1])
        return

    n_runs = len(model_student_lists_all_runs[0])
    # Size the run axis by the runs actually evaluated, not by len(timetables).
    # Surplus timetables used to leave empty run slots, which made the final array ragged.
    y_total = [[[] for _ in range(n_runs)] for _ in range(len(model_types))]

    for run in range(n_runs):
        print(f'Run number: {run}')
        timetable = timetables[run]
        # Models of this run, indexed [model][student].
        models_this_run = [model_lists[run] for model_lists in model_student_lists_all_runs]

        for j in range(len(models_this_run[0])):  # j: which of the students
            print(f'Current Student: {j}')

            # Unpack the true student model (index 3, the time-gap penalty, is not used here).
            true_model = models_this_run[-1][j]
            additive_prefs = true_model[0]
            substitutes = true_model[1]
            complements = true_model[2]
            overload_penalty = true_model[4]
            free_days_marginal_values = true_model[5]

            # Generate one test set per student from the true preferences.
            _, _, X_true, _ = create_initial_dataset(additive_prefs, substitutes, complements, timetable, credit_units = [1 for _ in range(len(additive_prefs))],
                    overload_penalty= overload_penalty, free_days_marginal_values= free_days_marginal_values, n_samples_train= 0,
                    n_samples_test= test_samples, n_courses= len(additive_prefs), make_monotone= True, n_samples_test_range= test_samples_range, value_range= value_range)

            # Evaluate every model (including the true one) on that test set.
            for i, student_models in enumerate(models_this_run):  # i: which of the models
                student_model = student_models[j]
                y_total[i][run].append([
                    model_predict(student_model, x, timetable = timetable, model_type = model_types[i])
                    for x in X_true
                ])

    return np.array(y_total)  # shape: model x run x students x predictions


def model_predict(model, bundle, timetable, model_type):
    """
    Predict the value of a single bundle under one (learned or true) student model.

    Parameters
    ----------
    model : sequence
        Model-specific container:
        - linear types: ``model[0]`` is the weight vector (dot product with the bundle);
        - ``'True'`` / ``'TrueNoisy'``: ``model[0..5]`` are additive prefs, substitutes,
          complements, time-gap penalty (unused here), overload penalty and free-day
          marginal values;
        - ``'PairwiseAdjustments'`` / ``'PairwiseAdjustmentsNoisy'``: ``model[0..2]`` are
          additive prefs, substitutes and complements (no overload / free-day terms);
        - ``'UNN'`` / ``'UNN_Noisy'``: ``model[0]`` is a callable network. Further entries,
          such as the student's budget, are ignored.
    bundle : numpy.ndarray
        Encoding of the bundle (converted with ``torch.from_numpy`` for UNN models).
    timetable :
        Passed to ``student`` for the preference-based model types.
    model_type : str
        One of the types listed above.

    Returns
    -------
    The predicted value. For UNN models this is the network output (a torch tensor).

    Raises
    ------
    ValueError
        If ``model_type`` is not recognised. Previously ``None`` was returned silently.
    """
    linear_types = {'LinearRegression', 'TrueLinear', 'Ridge', 'Lasso', 'ElasticNet',
                    'LinearRegressionNoisy', 'LinearNoisy', 'RidgeNoisy', 'LassoNoisy', 'ElasticNetNoisy'}

    if model_type in linear_types:
        return np.dot(model[0], bundle)

    if model_type in ('UNN', 'UNN_Noisy'):
        ml_model = model[0]  # model_list also carries the student's budget, not needed here
        return ml_model(torch.from_numpy(bundle).float())

    if model_type in ('True', 'TrueNoisy'):
        additive_prefs, substitutes, complements = model[0], model[1], model[2]
        overload_penalty = model[4]
        free_days_marginal_values = model[5]
    elif model_type in ('PairwiseAdjustments', 'PairwiseAdjustmentsNoisy'):
        additive_prefs, substitutes, complements = model[0], model[1], model[2]
        # Pairwise-only model: no overload penalty and no free-day values.
        overload_penalty = 0
        free_days_marginal_values = [0, 0, 0, 0, 0]
    else:
        raise ValueError(f'Unknown model_type: {model_type!r}')

    return student(bundle = bundle, additive_prefs= additive_prefs, substitutes= substitutes, complements= complements, timetable= timetable, overload_penalty= overload_penalty,
                   free_days_marginal_values= free_days_marginal_values, credit_units= [1 for _ in range(len(additive_prefs))], make_monotone= True)


def read_principled_results_smol(result, clearing_errors, oversubscription_errors, model_names, times = None, model_lists = None, timetables = None, models_run = None, test_samples = 2000, test_samples_range = 2000, value_range = (250, math.inf),
                                 CI = 0.95, scale_utilities = True):
    """
    Print a compact LaTeX table summarising the allocation results of several models.

    Each cell is "<mean> $\\pm$ \\scriptsize <CI half-width>" over runs.

    Parameters
    ----------
    result : numpy.ndarray
        Utilities per model; ``result[i].mean(axis=1)`` gives one value per run
        (presumably shape (model, run, student) -- confirm with callers). The last model
        is used as the scaling reference.
    clearing_errors, oversubscription_errors : per-model arrays (one value per run).
    model_names : list of str, one per model.
    times : optional per-model run times in seconds. They are reported in minutes, and
        only when ``model_lists`` is None.
    model_lists, timetables, models_run, test_samples, test_samples_range, value_range :
        When ``model_lists`` is given, the models are also evaluated with
        ``test_all_models_enhanced``, and Kendall tau / MAE columns are added. The first
        element of every ``models_run`` entry is the model type, and the last one must be
        ``'True'``.
    CI : confidence level for the half-widths.
    scale_utilities : scale utilities so that the last model's mean per run is 100.

    Returns
    -------
    None, or ``(y_total, KT_results)`` when ``model_lists`` is given.
    """
    pm = r"$\pm$ \scriptsize "  # raw string: same text as before, without invalid-escape warnings

    def mean_ci(values, mean_decimals = 2, ci_decimals = 2):
        # Format one table cell: mean over runs plus the confidence-interval half-width.
        return (f'{values.mean():.{mean_decimals}f} ' + pm
                + f'{mean_confidence_interval(values, confidence = CI):.{ci_decimals}f}')

    scaling = 100 / result[-1].mean(axis = 1) if scale_utilities else 1

    avg_utility = [mean_ci(result[i].mean(axis = 1) * scaling) for i in range(len(result))]
    min_utility = [mean_ci(result[i].min(axis = 1) * scaling) for i in range(len(result))]

    if times is not None:
        times = times / 60    # now the time is in minutes instead of seconds
        times_list = [mean_ci(times[i]) for i in range(len(times))]

    # Favourite allocation: share of students (in %) who get the best utility any model achieved.
    best = result.max(axis = 0)
    to_percent = 100 / result.shape[-1]
    fav_allocation = []
    for i in range(len(result)):
        n_fav = (result[i] == best).sum(axis = 1)
        fav_allocation.append(f'{n_fav.mean() * to_percent:.2f} ' + pm
                              + f'{mean_confidence_interval(n_fav, confidence = CI) * to_percent:.2f}')

    # Average distance to the favourite allocation.
    avg_dist_fav = [mean_ci((best - result[i]).mean(axis = 1) * scaling) for i in range(len(result))]

    clearing_error = [mean_ci(clearing_errors[i]) for i in range(len(clearing_errors))]
    oversubcriptions = [mean_ci(oversubscription_errors[i]) for i in range(len(oversubscription_errors))]

    if model_lists is None:
        columns = dict(Model = model_names, AvgUtil = avg_utility, MinUtil = min_utility,
                       FavAlloc = fav_allocation, AvgDistFav = avg_dist_fav, ClearErrors = clearing_error, Oversubscr = oversubcriptions)
        if times is not None:
            columns['Times'] = times_list
        # TODO: Figure out what is going on with the clearing error!
        print(pd.DataFrame(columns).to_latex(index = False, escape = False))
        return

    # model_lists provided -> also test the performance of the models themselves.
    model_types = [run_spec[0] for run_spec in models_run]
    if model_types[-1] != 'True':
        print("WARNING - IN ORDER TO USE THIS THE LAST MODEL SHOULD BE TRUE PREFS")
        return

    y_total = test_all_models_enhanced(model_lists, model_types, timetables, test_samples, test_samples_range = test_samples_range, value_range = value_range)
    KT_results = get_all_KTs_enhanced(y_total)

    avg_kt = [mean_ci(KT_results[i].mean(axis = 1), ci_decimals = 3) for i in range(len(KT_results))]

    # Mean absolute error w.r.t. the true model (the last one).
    y_total_abs = np.absolute(y_total - y_total[-1])
    mae = [mean_ci(y_total_abs[i].mean(axis = (1, 2)), ci_decimals = 3) for i in range(len(y_total_abs))]

    dataframe = pd.DataFrame(dict(Model = model_names, AvgUtil = avg_utility, MinUtil = min_utility, FavAlloc = fav_allocation,
                                  AvgDistFav = avg_dist_fav, Oversubscr = oversubcriptions, AvgKT = avg_kt, MAE = mae))
    print(dataframe.to_latex(index = False, escape = False))

    return (y_total, KT_results)


def read_stage_2_3_results(result, model_names, times, CI = 0.95, scale_utilities = True):
    """
    Print a LaTeX table of stage-2/3 results for several models.

    Each cell is "<mean> $\\pm$ \\scriptsize <CI half-width>" over runs.

    Parameters
    ----------
    result : numpy.ndarray
        Utilities per model; ``result[i].mean(axis=1)`` gives one value per run
        (presumably shape (model, run, student) -- confirm with callers). The last model
        is the scaling reference.
    model_names : list of str, one per model.
    times : per-model run times in seconds (reported in minutes).
    CI : confidence level for the half-widths.
    scale_utilities : scale utilities so that the last model's mean per run is 100.
    """
    pm = r"$\pm$ \scriptsize "  # raw string: same text as before, without invalid-escape warnings

    def mean_ci(values):
        # Format one table cell: mean over runs plus the confidence-interval half-width.
        return f'{values.mean():.2f} ' + pm + f'{mean_confidence_interval(values, confidence = CI):.2f}'

    scaling = 100 / result[-1].mean(axis = 1) if scale_utilities else 1

    avg_utility = [mean_ci(result[i].mean(axis = 1) * scaling) for i in range(len(result))]
    min_utility = [mean_ci(result[i].min(axis = 1) * scaling) for i in range(len(result))]

    times = times / 60    # now the time is in minutes instead of seconds
    times_list = [mean_ci(times[i]) for i in range(len(times))]

    # Favourite allocation: share of students (in %) who get the best utility any model achieved.
    best = result.max(axis = 0)
    to_percent = 100 / result.shape[-1]
    fav_allocation = []
    for i in range(len(result)):
        n_fav = (result[i] == best).sum(axis = 1)
        fav_allocation.append(f'{n_fav.mean() * to_percent:.2f} ' + pm
                              + f'{mean_confidence_interval(n_fav, confidence = CI) * to_percent:.2f}')

    # Average distance to the favourite allocation.
    avg_dist_fav = [mean_ci((best - result[i]).mean(axis = 1) * scaling) for i in range(len(result))]

    dataframe = pd.DataFrame(dict(Model = model_names, AvgUtil = avg_utility, MinUtil = min_utility,
                                  FavAlloc = fav_allocation, AvgDistFav = avg_dist_fav, Times = times_list))
    print(dataframe.to_latex(index = False, escape = False))
    return


def read_stage_2_3_results_combined(results_s2, results_s3, model_names, times_s2, times_s3, CI = 0.95, scale_utilities = True):
    """
    Print one LaTeX table with stage-2 and stage-3 utilities and run times side by side.

    Each cell is "<mean> $\\pm$ \\scriptsize <CI half-width>" over runs.

    Parameters
    ----------
    results_s2, results_s3 : numpy.ndarray
        Utilities per model for stage 2 / stage 3; ``results[i].mean(axis=1)`` gives one
        value per run. The last model of stage 3 is the scaling reference.
    model_names : list of str, one per model.
    times_s2, times_s3 : per-model run times in seconds (reported in minutes).
    CI : confidence level for the half-widths.
    scale_utilities : scale utilities so that the reference's mean per run is 100.
    """
    pm = r"$\pm$ \scriptsize "  # raw string: same text as before, without invalid-escape warnings

    def mean_ci(values):
        # Format one table cell: mean over runs plus the confidence-interval half-width.
        return f'{values.mean():.2f} ' + pm + f'{mean_confidence_interval(values, confidence = CI):.2f}'

    scaling = 100 / results_s3[-1].mean(axis = 1) if scale_utilities else 1

    avg_utility_s2 = [mean_ci(results_s2[i].mean(axis = 1) * scaling) for i in range(len(results_s2))]
    avg_utility_s3 = [mean_ci(results_s3[i].mean(axis = 1) * scaling) for i in range(len(results_s3))]

    min_utility_s2 = [mean_ci(results_s2[i].min(axis = 1) * scaling) for i in range(len(results_s2))]
    # Iterate over the stage-3 models (this previously used len(results_s2) by mistake).
    min_utility_s3 = [mean_ci(results_s3[i].min(axis = 1) * scaling) for i in range(len(results_s3))]

    times_s2 = times_s2 / 60    # now the time is in minutes instead of seconds
    times_s3 = times_s3 / 60
    times_list_s2 = [mean_ci(times_s2[i]) for i in range(len(times_s2))]
    times_list_s3 = [mean_ci(times_s3[i]) for i in range(len(times_s3))]

    dataframe = pd.DataFrame(dict(Model = model_names, AvgUtilS2 = avg_utility_s2, AvgUtilS3 = avg_utility_s3, MinUtilS2 = min_utility_s2, MinUtilS3 = min_utility_s3,
                                  TimesS2 = times_list_s2, TimesS3 = times_list_s3))
    print(dataframe.to_latex(index = False, escape = False))
    return


def read_all_stages_combined(results_s1, results_s2, results_s3, model_names, times_s1, times_s2, times_s3, oversubscriptions_s1, ce_s1, query_numbers, CI = 0.95, scale_utilities = True, time_CI = False, ce_tb = 75, time_in_hours = True):
    """
    Print a LaTeX table comparing all models over the three stages.

    The table contains these columns:
    - query numbers (QB, QJ);
    - average and minimum utility per stage;
    - stage-1 oversubscriptions and clearing error (CE);
    - the fraction of runs whose clearing error is at most ``ce_tb`` (BTB);
    - the average run time per stage.

    Highlighting: a utility cell gets the ccell macro when its upper confidence bound
    reaches the largest lower confidence bound in that column. The CE cell is highlighted
    when the mean clearing error is at most ``ce_tb``, and the BTB cell when every run is
    within ``ce_tb``.

    Parameters
    ----------
    results_s1, results_s2, results_s3 : numpy.ndarray
        Utilities per model for each stage; ``results[i].mean(axis=1)`` gives one value per
        run (presumably shape (model, run, student) -- confirm with callers). The last
        model of stage 1 is the scaling reference for every stage.
    model_names : list of str, one per model.
    times_s1, times_s2, times_s3 : per-model run times in seconds.
    oversubscriptions_s1, ce_s1 : per-model arrays of stage-1 oversubscriptions / clearing errors.
    query_numbers : per model, a pair ``(QB, QJ)``.
    CI : confidence level for the half-widths.
    scale_utilities : scale utilities so that the reference's mean per run is 100.
    time_CI : also print a confidence interval for the times.
    ce_tb : clearing-error threshold for the CE highlight and for the BTB column.
        BTB previously ignored this parameter and always used 75.
    time_in_hours : report times in hours instead of minutes.
    """
    pm = r"$\pm$ \scriptsize "  # raw strings: same text as before, without invalid-escape warnings
    ccell = r"\ccell "

    def mean_ci(values):
        # Format one table cell: mean over runs plus the confidence-interval half-width.
        return f'{values.mean():.2f} ' + pm + f'{mean_confidence_interval(values, confidence = CI):.2f}'

    def ci_bounds(values):
        # (lower, upper) confidence bounds of the mean.
        mean = values.mean()
        half_width = mean_confidence_interval(values, confidence = CI)
        return mean - half_width, mean + half_width

    def highlight(cells, per_model_values, exclude_last = False):
        # Prefix every cell whose upper CI bound reaches the best lower CI bound.
        bounds = [ci_bounds(values) for values in per_model_values]
        reference = bounds[:-1] if exclude_last else bounds
        best_lower = np.max([lower for lower, _ in reference])
        for i in range(len(model_names)):
            if bounds[i][1] >= best_lower:
                cells[i] = ccell + cells[i]

    scaling = 100 / results_s1[-1].mean(axis = 1) if scale_utilities else 1   # scale with stage 1 results

    stages = (results_s1, results_s2, results_s3)
    # Per stage, per model: the scaled per-run average / minimum utility.
    avg_runs = [[res[i].mean(axis = 1) * scaling for i in range(len(res))] for res in stages]
    min_runs = [[res[i].min(axis = 1) * scaling for i in range(len(res))] for res in stages]

    avg_utility = [[mean_ci(values) for values in stage] for stage in avg_runs]
    min_utility = [[mean_ci(values) for values in stage] for stage in min_runs]

    # Times: seconds -> minutes (-> hours).
    times = [t / 60 for t in (times_s1, times_s2, times_s3)]
    if time_in_hours:
        times = [t / 60 for t in times]

    if time_CI:
        times_lists = [[mean_ci(t[i]) for i in range(len(t))] for t in times]
    else:
        times_lists = [[f'{t[i].mean():.2f}' for i in range(len(t))] for t in times]

    # Oversubscribed seats and clearing error (stage 1).
    oversubcriptions = [mean_ci(oversubscriptions_s1[i]) for i in range(len(oversubscriptions_s1))]
    ce_column = [mean_ci(ce_s1[i]) for i in range(len(ce_s1))]

    # Fraction of runs whose clearing error is at most the threshold ce_tb.
    below_bound = [(ce_s1[i] <= ce_tb).sum() / ce_s1[i].shape[0] for i in range(len(ce_s1))]
    btb_column = [f'{fraction:.2f}' for fraction in below_bound]

    # Query numbers.
    qnb_column = [query_numbers[i][0] for i in range(len(query_numbers))]
    qnj_column = [query_numbers[i][1] for i in range(len(query_numbers))]

    # Stage-1 average utility excludes the last (reference) model from the threshold.
    # Otherwise only the scaled true model would ever be highlighted.
    highlight(avg_utility[0], avg_runs[0], exclude_last = True)
    highlight(avg_utility[1], avg_runs[1])
    highlight(avg_utility[2], avg_runs[2])
    for stage_cells, stage_values in zip(min_utility, min_runs):
        highlight(stage_cells, stage_values)

    for i in range(len(model_names)):
        if ce_s1[i].mean() <= ce_tb:
            ce_column[i] = ccell + ce_column[i]
        if below_bound[i] == 1:
            btb_column[i] = ccell + btb_column[i]

    dataframe = pd.DataFrame(dict(Model = model_names, QB = qnb_column, QJ = qnj_column,
                                  AvgUtilS1 = avg_utility[0], AvgUtilS2 = avg_utility[1], AvgUtilS3 = avg_utility[2],
                                  MinUtilS1 = min_utility[0], MinUtilS2 = min_utility[1], MinUtilS3 = min_utility[2],
                                  Oversubscriptions = oversubcriptions, CE = ce_column, BTB = btb_column,
                                  AvgTimeS1 = times_lists[0], AvgTimeS2 = times_lists[1], AvgTimeS3 = times_lists[2]))
    print(dataframe.to_latex(index = False, escape = False))
    return


def read_neighbor_configs_stage1(results_s1, model_names, times_s1, oversubscriptions_s1, clearing_errors,  CI = 0.95, scale_utilities = True):
    """
    Print a LaTeX table comparing stage-1 results of several (neighbour) configurations.

    Columns are the average utility, oversubscribed seats, clearing error and run time,
    each as "<mean> $\\pm$ \\scriptsize <CI half-width>". The configurations with the
    lowest mean oversubscription and the lowest mean clearing error get the ccell macro.

    Parameters
    ----------
    results_s1 : numpy.ndarray
        Utilities per configuration; ``results_s1[i].mean(axis=1)`` gives one value per run.
        The last entry is the scaling reference.
    model_names : list of str, one per configuration.
    times_s1 : per-configuration run times in seconds (reported in minutes).
    oversubscriptions_s1, clearing_errors : numpy arrays indexed (configuration, run).
    CI : confidence level for the half-widths.
    scale_utilities : scale utilities so that the reference's mean per run is 100.
    """
    pm = r"$\pm$ \scriptsize "  # raw strings: same text as before, without invalid-escape warnings
    ccell = r"\ccell "

    def mean_ci(values):
        # Format one table cell: mean over runs plus the confidence-interval half-width.
        return f'{values.mean():.2f} ' + pm + f'{mean_confidence_interval(values, confidence = CI):.2f}'

    scaling = 100 / results_s1[-1].mean(axis = 1) if scale_utilities else 1   # scale with stage 1 results

    avg_utility_s1 = [mean_ci(results_s1[i].mean(axis = 1) * scaling) for i in range(len(results_s1))]

    times_s1 = times_s1 / 60    # now the time is in minutes instead of seconds
    times_list_s1 = [mean_ci(times_s1[i]) for i in range(len(times_s1))]

    oversubscriptions = [mean_ci(oversubscriptions_s1[i]) for i in range(len(oversubscriptions_s1))]
    clearing_errors_list = [mean_ci(clearing_errors[i]) for i in range(len(clearing_errors))]

    # Highlight the configurations with the lowest mean oversubscription / clearing error.
    oversubscriptions_argmin = np.argmin(oversubscriptions_s1.mean(axis = 1))
    ce_argmin = np.argmin(clearing_errors.mean(axis = 1))
    oversubscriptions[oversubscriptions_argmin] = ccell + oversubscriptions[oversubscriptions_argmin]
    clearing_errors_list[ce_argmin] = ccell + clearing_errors_list[ce_argmin]

    # Column header fixed from the misspelled "ClearErros" to match the other tables.
    dataframe = pd.DataFrame(dict(Model = model_names, AvgUtilS1 = avg_utility_s1, Oversubscriptions = oversubscriptions, ClearErrors = clearing_errors_list, TimesS1 = times_list_s1))
    print(dataframe.to_latex(index = False, escape = False))
    return


#     --- Start of Iterative Version! ---   #
def train_unn(X_train, y_train, num_hidden_layers, num_units, random_ts, trainable_ts, init_E, init_Var,
              learning_rate, weight_decay, epochs, batch_size, loss, n_courses = 30, max_courses_in_bundle = 5, max_value_full_bundle = 2):
    """
    Train a single MVNN (``MVNNLayerReLUProjected`` layers) on numpy training data.

    The targets are rescaled so that the largest training value maps to
    ``max_value_full_bundle * max_courses_in_bundle / n_courses``. That value is also the
    network's ``target_max``.

    Parameters whose name contains ``'ts'`` are optimised without weight decay; all other
    parameters use ``weight_decay``. Training itself is delegated to ``train_unn_lean``.

    Returns
    -------
    (trained_model, scale)
        ``scale = y_train.max() / target_max`` maps network outputs back to the original
        target units.
    """
    target_max = (max_value_full_bundle * max_courses_in_bundle) / n_courses
    y_peak = y_train.max()

    # Normalise by the largest value, then stretch to target_max.
    y_scaled = (y_train / y_peak) * target_max
    scale = y_peak / target_max

    # Wrap the data in a shuffling DataLoader.
    features = torch.from_numpy(X_train).float()
    targets = torch.from_numpy(y_scaled.reshape(-1, 1)).float()
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, targets),
        batch_size = batch_size,
        shuffle = True,
    )

    model = MVNN(input_dim = X_train.shape[1],
                 num_hidden_layers = num_hidden_layers,
                 num_units = num_units,
                 layer_type = 'MVNNLayerReLUProjected',
                 target_max = target_max,
                 dropout_prob = 0,
                 init_method = 'custom',
                 random_ts = random_ts,
                 trainable_ts = trainable_ts,
                 init_E = init_E,
                 init_Var = init_Var,
                 init_b = 0.05,
                 init_bias = 0.05,
                 init_little_const = 0.1)

    print(f'Setting init var to: {init_Var} hidden_layers to: {num_hidden_layers} and units to: {num_units} and weight decay: {weight_decay}')

    # The ts parameters get no L2 regularisation (the bigger t, the more regular).
    # NOTE(review): this is a substring match, so any parameter name containing 'ts' is
    # exempted -- confirm this matches MVNN's parameter naming.
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if 'ts' in name:
            logging.debug(f'Setting L2-Reg. to 0.0 for {name}.')
            undecayed.append(param)
        else:
            decayed.append(param)

    optimizer = torch.optim.Adam(
        [{'params': decayed, 'weight_decay': weight_decay},
         {'params': undecayed, 'weight_decay': 0.0}],
        lr = learning_rate,
    )

    trained_model = train_unn_lean(model, optimizer, train_loader, target_max, epochs = epochs, loss_function = loss)

    return trained_model, scale


def construct_implied_dataset(additive_prefs, substitutes, complements, unforgotten_base_values, make_monotone = True, points_to_add = 0, points_to_hallucinate = 0, forgotten_course_expected_value = None,
        sample_category_weights = [0.5, 0.25, 1], sample_relative_frequencies = None, seed=None):
    """
    Takes as input a noisy student (the output of guisify student) of a student  and constructs a dataset for an ML algorithm based only on those.
    """

    X_train = []
    y_train = []

    train_weights = []

    if (seed):
        np.random.seed(seed)

    # step 1: create the bundles that correspond to all singleton sets:
    for i in range(len(additive_prefs)):
        x = [0 for i in range(len(additive_prefs))]
        x[i] = 1

        if (unforgotten_base_values[i] == 1):
            y = additive_prefs[i]

            if sample_relative_frequencies is not None:
                for j in range(sample_relative_frequencies[0]):
                    X_train.append(x)
                    y_train.append(y)

            else:
                X_train.append(x)
                y_train.append(y)
                train_weights.append(sample_category_weights[0])

    # step 2: add all the bundles that correspond to 2 courses plus their adjustments

    for (substitute_courses, adjustments) in substitutes:
        actual_adj_value = adjustments[1]

        for i in range(0, len(substitute_courses)):
            for j in range(i + 1, len(substitute_courses)):

                value_course_a = additive_prefs[substitute_courses[i]]
                value_course_b = additive_prefs[substitute_courses[j]]
                y = (value_course_a + value_course_b) * (1 + actual_adj_value)

                if make_monotone:
                    y = max(y, value_course_a, value_course_b)

#                 print(f'Value a: {value_course_a} value_b: {value_course_b} adj_value: {actual_adj_value} bundle value: {y}')

                x = [0 for i in range(len(additive_prefs))]
                x[substitute_courses[i]] = 1
                x[substitute_courses[j]] = 1

                if sample_relative_frequencies is not None:
                    for j in range(sample_relative_frequencies[0]):
                        X_train.append(x)
                        y_train.append(y)

                else:
                    X_train.append(x)
                    y_train.append(y)
                    train_weights.append(sample_category_weights[0])

#     print('--- COMPLEMENTS ---')

    for (complement_courses, adjustments) in complements:
        actual_adj_value = adjustments[1]
#         print(f' Actual adjustment value: {actual_adj_value}')
        for i in range(0, len(complement_courses)):
            for j in range(i + 1, len(complement_courses)):
                #                 print(complement_courses[i], complement_courses[j])

                value_course_a = additive_prefs[complement_courses[i]]
                value_course_b = additive_prefs[complement_courses[j]]
                y = (value_course_a + value_course_b) * (1 + actual_adj_value)


#                 print(f'Value a: {value_course_a} value_b: {value_course_b} adj_value: {actual_adj_value} bundle value: {y}')

                x = [0 for i in range(len(additive_prefs))]
                x[complement_courses[i]] = 1
                x[complement_courses[j]] = 1

                if sample_relative_frequencies is not None:
                    for j in range(sample_relative_frequencies[0]):
                        X_train.append(x)
                        y_train.append(y)

                else:
                    X_train.append(x)
                    y_train.append(y)
                    train_weights.append(sample_category_weights[0])

    # creating futher implied value bundles as CM would
#     print('--- IMPLIED BUNDLES (points to add) ---')

    base_value_indices = np.nonzero(unforgotten_base_values)[0]  # these are the indices of all the non_forgotten bases
    for i in range(points_to_add):
        courses_in_bundle = np.random.choice(base_value_indices, size = 5, replace = False)   # the bundle only contains courses for which we have the base values

        x = [0 for i in range(len(additive_prefs))]
        y = 0

        for j in courses_in_bundle:
            x[j] = 1                     # create the 0-1 encoding that corresponds to the bundle that the student got
            y = y + additive_prefs[j]   # add the base values for the courses that the student got

        for j in range(len(courses_in_bundle)):       # check to see if there are adjustments that you need to add
            for k in range(j+1, len(courses_in_bundle)):
                index_course_a = courses_in_bundle[j]
                index_course_b = courses_in_bundle[k]

                for (complement_courses, adjustments) in complements:
                    actual_adj_value = adjustments[1]

                    if (index_course_a in complement_courses) and (index_course_b in complement_courses):
                        value_to_add = (additive_prefs[index_course_a] + additive_prefs[index_course_b]) * actual_adj_value

                        y = y + value_to_add

                for (substitute_courses, adjustments) in substitutes:
                    actual_adj_value = adjustments[1]

                    if (index_course_a in substitute_courses) and (index_course_b in substitute_courses):
                        value_to_subtract = abs((additive_prefs[index_course_a] + additive_prefs[index_course_b]) * actual_adj_value)

                        if make_monotone:
                            value_to_subtract = min(value_to_subtract, additive_prefs[index_course_a], additive_prefs[index_course_b])

                        # print(f'value course a: {additive_prefs[index_course_a]} value_course_b: {additive_prefs[index_course_b]} adj_value: {actual_adj_value} value to subtract: {value_to_subtract}')
                        y = y - value_to_subtract

        if sample_relative_frequencies is not None:
            for j in range(sample_relative_frequencies[1]):
                X_train.append(x)
                y_train.append(y)
        else:
            X_train.append(x)
            y_train.append(y)
            train_weights.append(sample_category_weights[1])

    # HALLUCINATING bundles in places where we don't actually have the complete bundle information, i.e. a student has forgotten some base values.
#     print('--- IMPLIED BUNDLES (points to add) ---')

    additive_prefs_copy = additive_prefs.copy()
    if forgotten_course_expected_value is None:
        minimum_course_value = additive_prefs.min()
        forgotten_course_expected_value = minimum_course_value / 2

    for i in range(len(additive_prefs)):
        if (unforgotten_base_values[i] == 0):
            additive_prefs_copy[i] = forgotten_course_expected_value

    base_value_indices = np.nonzero(unforgotten_base_values)[0]  # these are the indices of all the non_forgotten bases
    for i in range(points_to_hallucinate):
        courses_in_bundle = np.random.choice(base_value_indices, size = 5, replace = False)   # the bundle only contains courses for which we have the base values

        x = [0 for i in range(len(additive_prefs))]
        y = 0

        for j in courses_in_bundle:
            x[j] = 1                     # create the 0-1 encoding that corresponds to the bundle that the student got
            y = y + additive_prefs_copy[j]   # add the base values for the courses that the student got

        for j in range(len(courses_in_bundle)):       # check to see if there are adjustments that you need to add
            for k in range(j+1, len(courses_in_bundle)):
                index_course_a = courses_in_bundle[j]
                index_course_b = courses_in_bundle[k]

                for (complement_courses, adjustments) in complements:
                    actual_adj_value = adjustments[1]

                    if (index_course_a in complement_courses) and (index_course_b in complement_courses):
                        value_to_add = (additive_prefs_copy[index_course_a] + additive_prefs_copy[index_course_b]) * actual_adj_value

                        y = y + value_to_add

                for (substitute_courses, adjustments) in substitutes:
                    actual_adj_value = adjustments[1]

                    if (index_course_a in substitute_courses) and (index_course_b in substitute_courses):
                        value_to_subtract = abs((additive_prefs_copy[index_course_a] + additive_prefs_copy[index_course_b]) * actual_adj_value)

                        if make_monotone:
                            value_to_subtract = min(value_to_subtract, additive_prefs[index_course_a], additive_prefs[index_course_b])

                        # print(f'value course a: {additive_prefs[index_course_a]} value_course_b: {additive_prefs[index_course_b]} adj_value: {actual_adj_value} value to subtract: {value_to_subtract}')
                        y = y - value_to_subtract

        if sample_relative_frequencies is not None:
            for j in range(sample_relative_frequencies[1]):
                X_train.append(x)
                y_train.append(y)
        else:
            X_train.append(x)
            y_train.append(y)
            train_weights.append(sample_category_weights[1])

    X_train = np.array(X_train)
    y_train = np.array(y_train)
    train_weights = np.array(train_weights)

    if sample_relative_frequencies is not None:
        return X_train, y_train

    else:
        return (X_train, y_train, train_weights)


def r2_loss(output, target, epsilon = 1e-5):
    """Return the negated coefficient of determination (R^2) so it can be minimised.

    Computes ``-(1 - SS_res / (SS_tot + epsilon))`` where SS_res is the residual
    sum of squares of ``output`` against ``target`` and SS_tot is the total sum of
    squares of ``target`` around its mean. ``epsilon`` keeps the division finite
    when ``target`` is constant.
    """
    residual_ss = (target - output).pow(2).sum()
    total_ss = (target - target.mean()).pow(2).sum()
    return -(1 - residual_ss / (total_ss + epsilon))


def _append_implied_sample_v2(X_train, y_train, train_weights, x, y, category,
                              sample_category_weights, sample_relative_frequencies):
    """Append one labelled bundle ``(x, y)`` to the running training lists.

    ``category`` selects the per-category setting: 0 = bundles taken directly from
    the GUI report (singletons and adjusted pairs), 1 = implied bundles built from
    reported base values, 2 = hallucinated bundles.
    If ``sample_relative_frequencies`` is given, the bundle is duplicated
    ``sample_relative_frequencies[category]`` times and no weight is recorded;
    otherwise it is appended once with weight ``sample_category_weights[category]``.
    """
    if sample_relative_frequencies is not None:
        for _ in range(sample_relative_frequencies[category]):
            X_train.append(x)
            y_train.append(y)
    else:
        X_train.append(x)
        y_train.append(y)
        train_weights.append(sample_category_weights[category])


def _implied_bundle_value_v2(courses_in_bundle, course_values, substitutes, complements, make_monotone):
    """Value a bundle as the sum of its course values plus pairwise adjustments.

    For every pair of courses in the bundle, each complement group containing both
    adds ``(v_a + v_b) * adj``, and each substitute group containing both subtracts
    ``|(v_a + v_b) * adj|``. ``adj`` is ``adjustments[1]`` of that group. With
    ``make_monotone`` each substitute penalty is capped at ``min(v_a, v_b)``.
    """
    y = 0
    for course in courses_in_bundle:
        y = y + course_values[course]

    n_in_bundle = len(courses_in_bundle)
    for j in range(n_in_bundle):
        for k in range(j + 1, n_in_bundle):
            index_course_a = courses_in_bundle[j]
            index_course_b = courses_in_bundle[k]
            pair_value = course_values[index_course_a] + course_values[index_course_b]

            for (complement_courses, adjustments) in complements:
                actual_adj_value = adjustments[1]
                if (index_course_a in complement_courses) and (index_course_b in complement_courses):
                    y = y + pair_value * actual_adj_value

            for (substitute_courses, adjustments) in substitutes:
                actual_adj_value = adjustments[1]
                if (index_course_a in substitute_courses) and (index_course_b in substitute_courses):
                    value_to_subtract = abs(pair_value * actual_adj_value)
                    if make_monotone:
                        value_to_subtract = min(value_to_subtract, course_values[index_course_a], course_values[index_course_b])
                    y = y - value_to_subtract
    return y


def construct_implied_dataset_v2(additive_prefs, substitutes, complements, unforgotten_base_values, make_monotone = True, points_to_add = 0, points_to_hallucinate = 0,
        forgotten_course_expected_value = None, thompson_sampling = False, chance_actual_zero = 0, uniform_range_low = None, uniform_range_high = None,
        sample_category_weights = (0.5, 0.25, 1), sample_relative_frequencies = None, courses_in_a_schedule = 5, seed=None):
    """Build an ML training set from a student's (noisy) GUI report.

    The report consists of per-course base values, the courses whose value the
    student actually entered, and substitute / complement groups.

    Bundles produced, in order:
      0. every singleton of a reported course and every pair inside each
         substitute group, then each complement group, valued with the group
         adjustment;
      1. ``points_to_add`` random bundles of reported courses (of size
         ``min(courses_in_a_schedule, #reported)``). Exactly one bundle is drawn
         instead whenever at most ``courses_in_a_schedule`` courses were reported,
         even if ``points_to_add`` is 0, because only one such bundle exists;
      2. ``points_to_hallucinate`` random bundles of ``courses_in_a_schedule``
         courses drawn from ALL courses, with forgotten courses imputed.

    Args:
        additive_prefs: per-course base values (numpy array or list).
        substitutes, complements: iterables of ``(courses, adjustments)``. The
            relative adjustment for a pair inside ``courses`` is ``adjustments[1]``.
        unforgotten_base_values: per-course 0/1 flags; 1 means the value was reported.
        make_monotone: keep substitute penalties from making a bundle worth less
            than its most valuable member of the penalised pair.
        forgotten_course_expected_value: imputed value for forgotten courses when
            ``thompson_sampling`` is falsy. If None it is the mean of
            ``(1 - chance_actual_zero) * U[uniform_range_low, uniform_range_high]``.
        thompson_sampling: if truthy, each forgotten course gets 0 with
            probability ``chance_actual_zero``, otherwise a draw from
            ``U[uniform_range_low, uniform_range_high]``.
        uniform_range_low, uniform_range_high: bounds of the uniform prior. They
            default to 0 and to the minimum reported base value.
        sample_category_weights: per-category weights (indexed 0, 1, 2). Used only
            when ``sample_relative_frequencies`` is None.
        sample_relative_frequencies: per-category duplication counts. When given,
            no weights are produced.
        courses_in_a_schedule: bundle size for categories 1 and 2.
        seed: seed for ``np.random.default_rng``. ``None`` means fresh entropy;
            ``0`` is honoured as a real seed.

    Returns:
        ``(X_train, y_train)`` if ``sample_relative_frequencies`` is given, otherwise
        ``(X_train, y_train, train_weights)``, all numpy arrays.
    """
    prefs = np.asarray(additive_prefs)  # lists are accepted too (fancy indexing below)
    n_courses = len(prefs)

    X_train = []
    y_train = []
    train_weights = []

    def add_sample(x, y, category):
        _append_implied_sample_v2(X_train, y_train, train_weights, x, y, category,
                                  sample_category_weights, sample_relative_frequencies)

    # default_rng(None) draws fresh entropy, so passing seed through also honours seed=0.
    rng = np.random.default_rng(seed)

    # step 1: singleton bundles for every course the student reported a value for.
    for i in range(n_courses):
        if unforgotten_base_values[i] == 1:
            x = [0] * n_courses
            x[i] = 1
            add_sample(x, prefs[i], 0)

    # step 2: every pair inside a substitute group, then inside a complement group.
    # Only substitute pairs are clamped (their adjustment lowers the value).
    for groups, clamp_to_monotone in ((substitutes, make_monotone), (complements, False)):
        for (group_courses, adjustments) in groups:
            actual_adj_value = adjustments[1]
            for i in range(len(group_courses)):
                for j in range(i + 1, len(group_courses)):
                    value_course_a = prefs[group_courses[i]]
                    value_course_b = prefs[group_courses[j]]
                    y = (value_course_a + value_course_b) * (1 + actual_adj_value)
                    if clamp_to_monotone:
                        y = max(y, value_course_a, value_course_b)

                    x = [0] * n_courses
                    x[group_courses[i]] = 1
                    x[group_courses[j]] = 1
                    add_sample(x, y, 0)

    # step 3: implied bundles, built only from courses with a reported base value.
    base_value_indices = np.nonzero(unforgotten_base_values)[0]
    if len(base_value_indices) <= courses_in_a_schedule:
        points_to_add_current_student = 1   # only one bundle of reported courses exists
    else:
        points_to_add_current_student = points_to_add
    implied_bundle_size = min(courses_in_a_schedule, len(base_value_indices))
    for _ in range(points_to_add_current_student):
        courses_in_bundle = rng.choice(base_value_indices, size = implied_bundle_size, replace = False)
        x = [0] * n_courses
        for course in courses_in_bundle:
            x[course] = 1
        y = _implied_bundle_value_v2(courses_in_bundle, prefs, substitutes, complements, make_monotone)
        add_sample(x, y, 1)

    # step 4: impute values for forgotten courses. Use a float copy so fractional
    # imputations are not truncated when the reported values are integers.
    if np.issubdtype(prefs.dtype, np.floating):
        hallucinated_prefs = prefs.copy()
    else:
        hallucinated_prefs = prefs.astype(float)
    forgotten_courses = [i for i in range(n_courses) if unforgotten_base_values[i] == 0]

    if not thompson_sampling:
        # posterior mean of p * 0 + (1 - p) * U[low, high]
        if forgotten_course_expected_value is None:
            if uniform_range_high is None:
                uniform_range_high = prefs[base_value_indices].min()
            if uniform_range_low is None:
                uniform_range_low = 0
            forgotten_course_expected_value = ((uniform_range_high + uniform_range_low) / 2) * (1 - chance_actual_zero)

        for i in forgotten_courses:
            hallucinated_prefs[i] = forgotten_course_expected_value
    else:
        # one posterior sample per forgotten course
        if uniform_range_high is None:
            uniform_range_high = prefs[base_value_indices].min()
        if uniform_range_low is None:
            uniform_range_low = 0

        for i in forgotten_courses:
            if rng.random() < chance_actual_zero:
                hallucinated_prefs[i] = 0
            else:
                hallucinated_prefs[i] = rng.random() * (uniform_range_high - uniform_range_low) + uniform_range_low

    # step 5: hallucinated bundles over ALL courses, valued with the imputed prefs.
    all_course_indices = list(range(n_courses))
    for _ in range(points_to_hallucinate):
        courses_in_bundle = rng.choice(all_course_indices, size = courses_in_a_schedule, replace = False)
        x = [0] * n_courses
        for course in courses_in_bundle:
            x[course] = 1
        y = _implied_bundle_value_v2(courses_in_bundle, hallucinated_prefs, substitutes, complements, make_monotone)
        add_sample(x, y, 2)

    X_train = np.array(X_train)
    y_train = np.array(y_train)

    if sample_relative_frequencies is not None:
        return X_train, y_train

    return (X_train, y_train, np.array(train_weights))


def binary_search(arr, x):
    """Return the sequence of indices probed by an iterative binary search for ``x``.

    Instead of the position of ``x``, the midpoints inspected are returned in
    probe order. The search stops as soon as a probed element equals ``x`` or the
    window becomes empty. ``arr`` is assumed to be sorted ascending. Used to
    simulate asking a logarithmic number of comparison queries (CQs) to place a
    point in a queue.
    """
    probes = []
    left, right = 0, len(arr) - 1
    while left <= right:
        middle = (left + right) // 2
        probes.append(middle)
        pivot = arr[middle]
        if pivot < x:
            left = middle + 1       # x lies in the right half
        elif pivot > x:
            right = middle - 1      # x lies in the left half
        else:
            break                   # exact hit
    return probes


def binary_search_hacky(arr, x, queries_available):
    """Binary search for ``x`` that may stop early after ``queries_available`` probes.

    ``arr`` is assumed to be sorted ascending. Returns
    ``(indexes_checked_list, lower_bound, upper_bound)`` where:
      * ``indexes_checked_list`` holds the probed midpoints in order;
      * ``lower_bound`` is the most recent probed index with ``arr[i] < x``
        (-1 if none), so ``x`` is above everything at or left of it;
      * ``upper_bound`` is the most recent probed index with ``arr[i] > x``
        (``len(arr)`` if none), so ``x`` is below everything at or right of it.
    On an exact hit at ``mid`` it prints 'FOUND EXACTLY' and returns
    ``(..., mid, mid + 1)``.
    """
    lower_bound, upper_bound = -1, len(arr)
    left, right = 0, len(arr) - 1
    probes = []

    # each probe costs one query, so len(probes) is the number of queries used
    while left <= right and len(probes) < queries_available:
        middle = (left + right) // 2
        probes.append(middle)
        pivot = arr[middle]
        if pivot < x:
            lower_bound, left = middle, middle + 1
        elif pivot > x:
            upper_bound, right = middle, middle - 1
        else:
            print('FOUND EXACTLY')
            lower_bound, upper_bound = middle, middle + 1
            break
    return probes, lower_bound, upper_bound


def generate_new_queries(actual_student_list, timetable, number_of_samples, current_training_set, model_student_list,
                approximate_prices, credit_units, model_type = 'LinearRegression', seed = 42, model_param_dictionary = {}, gui_student_list = []):
    if (model_type in ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'LinearRegressionNoisy', 'RidgeNoisy', 'LassoNoisy', 'ElasticNetNoisy']):
        if (approximate_prices is None):     # approximate prices == None -> drawing samples for the first time
            print('NEW QUERIES generating dataset from scratch')
            new_training_set = []
            for (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in actual_student_list:
                X_train, y_train, _, _ = create_initial_dataset(additive_prefs, substitutes, complements, timetable, credit_units = [1 for i in range(len(additive_prefs))],
                    overload_penalty= overload_penalty, free_days_marginal_values= free_days_marginal_values, n_samples_train= number_of_samples,
                    n_samples_test= 0, n_courses= len(additive_prefs), make_monotone= True, n_samples_train_range= model_param_dictionary['samples_in_range'],
                    value_range= (model_param_dictionary['range_min_value'], math.inf), seed = seed)
                new_training_set.append((X_train, y_train))
            return(new_training_set)

        else:
            print("NEW QUERIES expanding on a dataset!!!")
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the dataset for student: {i}')
                (linear_coeffs, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i]

                X_list = []
                y_list = []
                bundles_to_forbid = X_train

                for j in range(number_of_samples):
                    if len(X_list) == 0:
                        bundles_to_forbid = X_train
                    else:
                        bundles_to_forbid = np.append(X_train, np.array(X_list), axis = 0)

                    try:
                        new_x = solve_student(timetable, approximate_prices, credit_units, budget, 5, linear_coeffs, [], [],
                        overload_penalty = 0, timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True,
                        forbidden_bundles = bundles_to_forbid, verbose = False)
                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                        break

                    # NOTE: should check that this works!!!
                    new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    X_list.append(new_x)
                    y_list.append(new_y)
                    bundles_to_forbid = np.append(X_train, np.array(X_list), axis = 0)

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                # set_trace()
                if (len(X_list) >= 1):
                    X_train = np.append(X_train, X_train_new, axis = 0)
                    y_train = np.append(y_train, y_train_new, axis = 0)
                current_training_set[i] = (X_train, y_train)

            return current_training_set
        
    elif (model_type in ['UNN_projected']):
        return []  # no new queries for UNN_projected, as they will load the already trained UNNs. This is just a placeholder.

    elif (model_type in ['NuSVR', 'NuSVRNoisy']):
        if (approximate_prices is None):     # approximate prices == None -> drawing samples for the first time
            print('NEW QUERIES generating dataset from scratch')
            new_training_set = []
            for (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in actual_student_list:
                X_train, y_train, _, _ = create_initial_dataset(additive_prefs, substitutes, complements, timetable, credit_units = [1 for i in range(len(additive_prefs))],
                    overload_penalty= overload_penalty, free_days_marginal_values= free_days_marginal_values, n_samples_train= number_of_samples,
                    n_samples_test= 0, n_courses= len(additive_prefs), make_monotone= True, n_samples_train_range= model_param_dictionary['samples_in_range'],
                    value_range= (model_param_dictionary['range_min_value'], math.inf), seed = seed)
                new_training_set.append((X_train, y_train))
            return(new_training_set)

        else:
            print("NEW QUERIES expanding on a dataset!!!")
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the dataset for student: {i}')
                (model, solver, gamma, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i]

                # add the budget constraint to the SVR solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # add all of the already existing training set on the the bundles not to query.
                for bundle in X_train:
                    solver.add_forbidden_bundle(bundle)

                X_list = []
                y_list = []

                for j in range(number_of_samples):
                    try:
                        new_x, predicted_value = solver.solve_mip(verbose = False)
                        solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                    # NOTE: should check that this works!!!
                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                        break

                    new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    X_list.append(new_x)
                    y_list.append(new_y)

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                if (len(X_list) >= 1):
                    X_train = np.append(X_train, X_train_new, axis = 0)
                    y_train = np.append(y_train, y_train_new, axis = 0)
                current_training_set[i] = (X_train, y_train)

            return current_training_set

    elif (model_type in ['xgboost', 'xgboostNoisy']):
        if (approximate_prices is None):     # approximate prices == None -> drawing samples for the first time
            print('NEW QUERIES generating dataset from scratch')
            new_training_set = []
            for (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in actual_student_list:
                X_train, y_train, _, _ = create_initial_dataset(additive_prefs, substitutes, complements, timetable, credit_units = [1 for i in range(len(additive_prefs))],
                    overload_penalty= overload_penalty, free_days_marginal_values= free_days_marginal_values, n_samples_train= number_of_samples,
                    n_samples_test= 0, n_courses= len(additive_prefs), make_monotone= True, n_samples_train_range= model_param_dictionary['samples_in_range'],
                    value_range= (model_param_dictionary['range_min_value'], math.inf), seed = seed)
                new_training_set.append((X_train, y_train))
            return(new_training_set)

        else:
            print("NEW QUERIES expanding on a dataset!!!")
            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                print(f'Expanding the dataset for student: {i}')
                (model, solver, budget) = model_student_list[i]  # get the current model for that student.
                (X_train, y_train) = current_training_set[i]

                # add the budget constraint to the SVR solver
                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                # add all of the already existing training set on the the bundles not to query.
                for bundle in X_train:
                    solver.add_forbidden_bundle(bundle)

                X_list = []
                y_list = []

                for j in range(number_of_samples):
                    try:
                        new_x, predicted_value = solver.solve_mip(verbose = False)
                        solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                    # NOTE: should check that this works!!!
                    except:
                        print('--- ACHTUNG ACHTUNG ---')
                        print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                        break

                    new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                    X_list.append(new_x)
                    y_list.append(new_y)

                X_train_new = np.array(X_list)
                y_train_new = np.array(y_list)

                # set_trace()

                if (len(X_list) >= 1):
                    X_train = np.append(X_train, X_train_new, axis = 0)
                    y_train = np.append(y_train, y_train_new, axis = 0)
                current_training_set[i] = (X_train, y_train)

            return current_training_set

    elif (model_type in ['UNN', 'UNN_Noisy', 'UNN_transfer_learning']):
        if (approximate_prices is None):     # approximate prices == None -> drawing samples for the first time
            print('NEW QUERIES generating dataset from scratch')
            new_training_set = []
            if (model_param_dictionary.get('use_implied_dataset', False)):
                if (not model_param_dictionary.get('use_cqs', False)):
                    print('NEW QUERIES GENERATING IMPLIED DATASET from GUI reports, VALUE QUERIES next!')
                    for (base_values, substitutes, complements, unforgotten_base_values) in gui_student_list:
                        (X_train, y_train) = construct_implied_dataset_v2(additive_prefs= base_values, substitutes= substitutes, complements= complements,
                            unforgotten_base_values= unforgotten_base_values, make_monotone= True, points_to_add= model_param_dictionary['points_to_add'],
                            points_to_hallucinate= model_param_dictionary['points_to_hallucinate'], forgotten_course_expected_value= model_param_dictionary['forgotten_course_expected_value'],
                            thompson_sampling= False, chance_actual_zero= 0, uniform_range_low = model_param_dictionary['uniform_range_low'], uniform_range_high = model_param_dictionary['uniform_range_high'],
                            sample_category_weights= None, sample_relative_frequencies= model_param_dictionary['sample_relative_frequencies'], seed= None)

                        X_implied = X_train.copy()
                        y_implied = y_train.copy()
                        new_training_set.append([(X_train, y_train), (X_implied, y_implied), ([], [])])

                else:
                    print('NEW QUERIES GENERATING IMPLIED DATASET from GUI reports, COMPARSION QUERIES next!')
                    for (student_counter, (base_values, substitutes, complements, unforgotten_base_values)) in enumerate(gui_student_list):
                        (X_train, y_train) = construct_implied_dataset_v2(additive_prefs= base_values, substitutes= substitutes, complements= complements,
                            unforgotten_base_values= unforgotten_base_values, make_monotone= True, points_to_add= model_param_dictionary['points_to_add'],
                            points_to_hallucinate= model_param_dictionary['points_to_hallucinate'], forgotten_course_expected_value= model_param_dictionary['forgotten_course_expected_value'],
                            thompson_sampling= False, chance_actual_zero= 0, uniform_range_low = model_param_dictionary['uniform_range_low'], uniform_range_high = model_param_dictionary['uniform_range_high'],
                            sample_category_weights= None, sample_relative_frequencies= model_param_dictionary['sample_relative_frequencies'], seed= None)

                        if model_param_dictionary['cq_method'] in ['basic', 'random']:

                            # desired shape for this method:
                            # (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                            # (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                            # (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                            # (x_max_current, y_max_current) = current_training_set[i][2]    # get the current best performing point, according to the true student
                            # bundles_to_forbid = current_training_set[i][3]

                            max_idx = np.argmax(y_train)
                            x_max_current = X_train[max_idx]
                            y_max_current = y_train[max_idx]
                            bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()
                            # do not ask a question about a point that you have already constructed from the GUI reports, and it is also a bundle of size 5!

                            new_training_set.append([(X_train, y_train), ([], []), (x_max_current, y_max_current), bundles_to_forbid])

                        elif model_param_dictionary['cq_method'] in ['basic_plus']:
                            max_idx = np.argmax(y_train)
                            x_max_current = X_train[max_idx]
                            y_max_current = y_train[max_idx]
                            bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()
                            bundles_queried = [x_max_current] # we will add the bundles that are involved in previous CQs, so in case the new point is better, we compare against all of them 

                            new_training_set.append([(X_train, y_train), ([], []), (x_max_current, y_max_current), bundles_to_forbid, bundles_queried])

                        elif model_param_dictionary['cq_method'] in ['high_pair', 'random_high_pair']:
                            bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()
                            new_training_set.append([(X_train, y_train), ([], []), [], bundles_to_forbid])

                        elif model_param_dictionary['cq_method'] in ['complete_ordering', 'complete_ordering_pruned']:
                            idx_max = np.argmax(y_train)
                            x_max = X_train[idx_max]

                            (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) = actual_student_list[student_counter]
                            y_max_true = student(x_max, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                            bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()

                            new_training_set.append(((X_train, y_train), (np.array([]), np.array([])), ([x_max], [y_max_true]), bundles_to_forbid, 0))

                        elif model_param_dictionary['cq_method'] == 'basic_log':
                            sorted_idxs = np.argsort(y_train)
                            sorted_ys = y_train[sorted_idxs].copy().tolist()
                            sorted_Xs = X_train[sorted_idxs].copy().tolist()

                            bundles_to_forbid = X_train[X_train.sum(axis = 1) > 2].copy().tolist()

                            new_training_set.append(((X_train, y_train), (np.array([]), np.array([])), (sorted_Xs, sorted_ys), bundles_to_forbid))

            else:
                for (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) in actual_student_list:
                    X_train, y_train, _, _ = create_initial_dataset(additive_prefs, substitutes, complements, timetable, credit_units = [1 for i in range(len(additive_prefs))],
                        overload_penalty= overload_penalty, free_days_marginal_values= free_days_marginal_values, n_samples_train= number_of_samples,
                        n_samples_test= 0, n_courses= len(additive_prefs), make_monotone= True, n_samples_train_range= model_param_dictionary['samples_in_range'],
                        value_range= (model_param_dictionary['range_min_value'], math.inf), seed = seed)
                    new_training_set.append((X_train, y_train))
            return(new_training_set)

        else:
            print("NEW QUERIES expanding on a dataset!!!")
            if (not model_param_dictionary.get('use_implied_dataset', False)):
                for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                    (model, solver, budget) = model_student_list[i]  # get the current model for that student.
                    (X_train, y_train) = current_training_set[i]
                    print(f'Expanding the dataset for student: {i} whose X_train has a shape of {X_train.shape}')

                    # add the budget constraint to the UNN solver
                    solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                    # add all of the already existing training set on the the bundles not to query.
                    for bundle in X_train:
                        solver.add_forbidden_bundle(bundle)

                    X_list = []
                    y_list = []

                    for j in range(number_of_samples):
                        try:
                            new_x = solver.solve_mip(outputFlag=False, verbose = False)
                            solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.

                        except:
                            print('--- ACHTUNG ACHTUNG ---')
                            print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                            break

                        new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                                overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                        sample_relative_frequencies = model_param_dictionary.get('gui_sample_relative_frequencies', None)
                        if (sample_relative_frequencies) is not None:
                            print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                            for j in range(sample_relative_frequencies[2]):
                                X_list.append(new_x)
                                y_list.append(new_y)
                        else:
                            X_list.append(new_x)
                            y_list.append(new_y)

                    X_train_new = np.array(X_list)
                    y_train_new = np.array(y_list)

                    if (len(X_list) >= 1):
                        X_train = np.append(X_train, X_train_new, axis = 0)
                        y_train = np.append(y_train, y_train_new, axis = 0)
                    current_training_set[i] = (X_train, y_train)

                return current_training_set
            else:
                print('NEW QUERIES expanding on an implied dataset')
                if (not model_param_dictionary.get('use_cqs', False)):
                    print('Using VALUE QUERIES next')
                    for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                        (model, solver, budget) = model_student_list[i]  # get the current model for that student.

                        (X_train, y_train) = current_training_set[i][0]  # get the current dataset (all points included)
                        (X_implied, y_implied) = current_training_set[i][1]  # get the implied dataset
                        (X_queried, y_queried) = current_training_set[i][2]  # get the dataset of points that you have actually queried
                        print(f'Expanding the dataset for student: {i} whose X_train has a shape of {X_train.shape}')

                        # add the budget constraint to the UNN solver
                        solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                        # add all of the already queried bundles on the the bundles not to query.
                        for bundle in X_queried:
                            solver.add_forbidden_bundle(bundle)

                        X_list = []
                        y_list = []

                        for j in range(number_of_samples):
                            try:
                                new_x = solver.solve_mip(outputFlag=False, verbose = False)
                                solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                            # NOTE: should check that this works!!!
                            except:
                                print('--- ACHTUNG ACHTUNG ---')
                                print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                                break

                            new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                                    overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                    credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                            # If the pooint we just queried was in the implied dataset, we have to remove it.
                            agreements = (X_implied == new_x).sum(axis = 1)   # find in how many positions the new queried point matches with any of the implied bundles

                            if (agreements >= len(new_x)).sum() > 0:      # if the optimal prediction was an implied datapoint -> remove it from the implied datapoints
                                print('--- ACHTUNG ACHTUNG found point that is on the implied dataset--- ')  # since now you have the actual value!
                                idx_to_delete = np.argwhere((agreements >= len(new_x)))

                                X_implied = np.delete(X_implied, idx_to_delete[0][0], axis = 0)
                                y_implied = np.delete(y_implied, idx_to_delete[0][0], axis = 0)

                                current_training_set[i][1] = (X_implied, y_implied)

                            sample_relative_frequencies = model_param_dictionary.get('gui_sample_relative_frequencies', None)
                            if (sample_relative_frequencies) is not None:
                                print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                for j in range(sample_relative_frequencies[2]):
                                    X_list.append(new_x)
                                    y_list.append(new_y)
                            else:
                                X_list.append(new_x)
                                y_list.append(new_y)

                        X_train_new = np.array(X_list)
                        y_train_new = np.array(y_list)

                        if (len(X_list) >= 1):
                            if(len(X_queried) > 0):
                                X_queried = np.append(X_queried, X_train_new, axis = 0)    # add the new points that we just queried to the queried dataset
                                y_queried = np.append(y_queried, y_train_new, axis = 0)
                            else:
                                X_queried = X_train_new
                                y_queried = y_train_new

                            current_training_set[i][2] = (X_queried, y_queried)     # update the dataset of queried points

                            X_train = np.append(X_implied, X_queried, axis = 0)    # since the dataset of queried datapoints was updated -> training dataset is also updated
                            y_train = np.append(y_implied, y_queried, axis = 0)
                            current_training_set[i][0] = (X_train, y_train)

                    return current_training_set

                else:
                    print(f'Using CQs next! Cq method: {model_param_dictionary["cq_method"]}')
                    sample_relative_frequencies = model_param_dictionary.get('gui_sample_relative_frequencies', None)

                    if (model_param_dictionary['cq_method'] in ['basic', 'basic_plus']):
                        for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                            print(f'Expanding the GUI dataset with CQs for student: {i}')
                            if len(model_student_list[i]) == 5:
                                (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                            else:
                                (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                            (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                            (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                            (x_max_current, _) = current_training_set[i][2]    # get the current best performing point, according to the true student
                            bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                            if (model_param_dictionary['cq_method'] == 'basic_plus'):
                                 bundles_queried = current_training_set[i][4]   # get the bundles that you have already queried (only applicable to the basic plus method)

                            X_list = []
                            y_list = []

                            y_max_current = student(x_max_current, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                            print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')
                            print(f'get_iterative_dataset_cq called with currrent x_max: {x_max_current} and y_max: {y_max_current}')

                            # add the budget constraint to the UNN solver
                            solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                            # do not query the bundles you ahve already queried.
                            print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                            for bundle in bundles_to_forbid:
                                solver.add_forbidden_bundle(bundle)

                            for j in range(number_of_samples):
                                try:
                                    new_x = solver.solve_mip(outputFlag=False, verbose = False)

                                    if np.sum(new_x) >= 5:   # do not forbid a bundle that has too few courses! -> this has as a result to also forbid many effective bundles
                                        solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                                        bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture

                                except:
                                    print('--- ACHTUNG ACHTUNG ---')
                                    print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                                    break

                                new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)
                                
                                
                                if (model_param_dictionary['cq_method'] == 'basic'):
                                
                                    new_x_ordinal = (new_x, x_max_current)  # compare the bundle that currently maximizes the MVNN to the max that you had

                                    print(f'new_y: {new_y} vs current y_max: {y_max_current}')
                                    print(f'and new_x: {new_x} containing {np.sum(new_x)} courses')
                                    if new_y >= y_max_current:
                                        # the new point is larger than all in our current dataset -> update the info relating to the current max
                                        print('Found a point that is better than the current max!')
                                        new_y_ordinal = 1
                                        x_max_current = new_x
                                        y_max_current = new_y
                                    else:
                                        new_y_ordinal = 0

                                    if (sample_relative_frequencies) is not None:
                                        print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                        for _ in range(sample_relative_frequencies[2]):
                                            X_list.append(new_x_ordinal)
                                            y_list.append(new_y_ordinal)
                                    else:
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)

                                elif (model_param_dictionary['cq_method'] == 'basic_plus'):
                                    if new_y <= y_max_current:
                                        print('Found a point that is worse than the current max!')
                                        # if the new point is not better than the old one: list behaves exactly in the same way as in the past! 
                                        new_x_ordinal = (new_x, x_max_current)  # compare the bundle that currently maximizes the MVNN to the max that you had
                                        new_y_ordinal = 0

                                        if (sample_relative_frequencies) is not None:
                                            print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                            for _ in range(sample_relative_frequencies[2]):
                                                X_list.append(new_x_ordinal)
                                                y_list.append(new_y_ordinal)
                                        else:
                                            X_list.append(new_x_ordinal)
                                            y_list.append(new_y_ordinal)

                                    else:
                                        # The new bundle is better than all bundles queried in the past, need to add all of those comparison queries in the dataset
                                        print('Found a point that is better than the current max, adding all implied bundles too!')
                                        for bundle in bundles_queried:
                                            new_x_ordinal = (new_x, bundle)
                                            new_y_ordinal = 1

                                            if (sample_relative_frequencies) is not None:
                                                print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                                for _ in range(sample_relative_frequencies[2]):
                                                    X_list.append(new_x_ordinal)
                                                    y_list.append(new_y_ordinal)
                                                
                                            else:
                                                X_list.append(new_x_ordinal)
                                                y_list.append(new_y_ordinal)

                                        x_max_current = new_x
                                        y_max_current = new_y



                                    bundles_queried.append(new_x)   # add the new bundle you just found to those that you will explicitly enforce transitivity on in the future


                            X_train_new = np.array(X_list)
                            y_train_new = np.array(y_list)

                            if len(X_list) >= 1:     # there are new CQs that we need to add to our dataset of comparision questions

                                if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                                    X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                                    y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                                else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                                    X_train_ord = X_train_new
                                    y_train_ord = y_train_new

                                if model_param_dictionary['cq_method'] == 'basic':
                                    current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (x_max_current, y_max_current), bundles_to_forbid)
                                elif model_param_dictionary['cq_method'] == 'basic_plus':
                                    current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (x_max_current, y_max_current), bundles_to_forbid, bundles_queried)

                    elif (model_param_dictionary['cq_method'] in ['high_pair', 'random_high_pair']):
                        for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                            print(f'Expanding the GUI dataset with CQs for student: {i}')
                            if len(model_student_list[i]) == 5:
                                (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                            else:
                                (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                            (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                            (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                            queried_pairs = current_training_set[i][2]    # get the pairs that you have already queried
                            bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                            X_list = []
                            y_list = []

                            print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')
                            print(f'get_iterative_dataset_cq called with currrent x_max: {x_max_current} and y_max: {y_max_current}')

                            # add the budget constraint to the UNN solver
                            solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                            # do not query the bundles you ahve already queried.
                            print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                            for bundle in bundles_to_forbid:
                                solver.add_forbidden_bundle(bundle)

                            # get the list of the top-valued bundles according to the MVNN 
                            mvnn_bundles = [] 
                            while len(mvnn_bundles) < model_param_dictionary['hp_bundles_to_generate']:
                                try:
                                    new_x = solver.solve_mip(outputFlag=False, verbose = False)

                                    solver.add_forbidden_bundle(new_x)           
                                    # NOTE: In this metod you can always forbid the new bundle, since you are not adding it to the 
                                    # "bundles_to_forbid_list" (which remains static, because what we care about is the *pair* of bundles in a CQ), 
                                    # so forbidding a bundle of size < 5 does not harm you in followup iterations
                                    
                                    mvnn_bundles.append(new_x) # add the new bundle in the MVNN list of bundles

                                except:
                                    print('--- ACHTUNG ACHTUNG ---')
                                    print(f'GENERATE QUERIES STOPPED EARLY AFTER GENERATING {len(mvnn_bundles)} BUNDLES')
                                    break

                            if model_param_dictionary['cq_method'] == 'high_pair':
                                for j in range(len(mvnn_bundles)-1):
                                    bundle1 = mvnn_bundles[j]
                                    bundle2 = mvnn_bundles[j+1]
                                    if not (bundle1, bundle2) in queried_pairs and not (bundle2, bundle1) in queried_pairs:
                                        break 
                            elif model_param_dictionary['cq_method'] == 'random_high_pair':
                                while True:
                                    indexes = np.random.choice(len(mvnn_bundles), 2, replace = False)
                                    bundle1 = mvnn_bundles[indexes[0]]
                                    bundle2 = mvnn_bundles[indexes[1]]
                                    if not (bundle1, bundle2) in queried_pairs and not (bundle2, bundle1) in queried_pairs:
                                        break
                            else: 
                                raise ValueError(f'cq_method {model_param_dictionary["cq_method"]} not recognized!')
                            
                            new_x_ordinal = (bundle1, bundle2)
                            queried_pairs.append(new_x_ordinal)
                            new_y1 = student(bundle1, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)
                            new_y2 = student(bundle2, additive_prefs, substitutes, complements, timetable,
                            overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                            credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                            if new_y1 >= new_y2:
                                new_y_ordinal = 1
                            else: 
                                new_y_ordinal = 0

                            if (sample_relative_frequencies) is not None:
                                print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                for _ in range(sample_relative_frequencies[2]):
                                    X_list.append(new_x_ordinal)
                                    y_list.append(new_y_ordinal)
                            else:
                                X_list.append(new_x_ordinal)
                                y_list.append(new_y_ordinal)

                        X_train_new = np.array(X_list)
                        y_train_new = np.array(y_list)

                        if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                            if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                                X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                                y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                            else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                                X_train_ord = X_train_new
                                y_train_ord = y_train_new

                            current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), queried_pairs, bundles_to_forbid)

                    elif (model_param_dictionary['cq_method'] == 'random'):
                        for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                            print(f'Expanding the GUI dataset with CQs for student: {i}')
                            if len(model_student_list[i]) == 5:
                                (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                            else:
                                (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                            (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                            (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                            (x_max_current, y_max_current) = current_training_set[i][2]    # get the current best performing point, according to the true student
                            bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                            X_list = []
                            y_list = []

                            print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')
                            print(f'get_iterative_dataset_cq called with currrent x_max: {x_max_current} and y_max: {y_max_current}')

                            # add the budget constraint to the UNN solver
                            # do not query the bundles you ahve already queried.

                            for j in range(number_of_samples):
                                x1_indices = np.random.choice(25, size = 5, replace = False)
                                x2_indices = np.random.choice(25, size = 5, replace = False)
                                x1 = [0 for i in range(25)]
                                x2 = [0 for i in range(25)]
                                for index in x1_indices:
                                    x1[index] = 1
                                for index in x2_indices:
                                    x2[index] = 1


                                new_y1 = student(x1, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)
                                new_y2 = student(x2, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                                new_x_ordinal = (x1, x2)    # NOTE: this line was bugged before. 
                                # last_queries.append(new_x_ordinal)  


                                if new_y1 >= new_y2:
                                    new_y_ordinal = 1
                                else:
                                    new_y_ordinal = 0

                                if (sample_relative_frequencies) is not None:
                                    print(f'adding a new query as a sample {sample_relative_frequencies[2]} times!')
                                    for _ in range(sample_relative_frequencies[2]):
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)
                                else:
                                    X_list.append(new_x_ordinal)
                                    y_list.append(new_y_ordinal)

                            X_train_new = np.array(X_list)
                            y_train_new = np.array(y_list)

                            if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                                if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparison questions we already had
                                    X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                                    y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                                else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                                    X_train_ord = X_train_new
                                    y_train_ord = y_train_new

                                current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (x_max_current, y_max_current), bundles_to_forbid)

                    elif (model_param_dictionary['cq_method'] == 'basic_log'):
                        print('Entering basic log version in generate new queries!')
                        for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                            print(f'Expanding the GUI dataset with CQs for student: {i}')
                            if len(model_student_list[i]) == 5:
                                (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                            else:
                                (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                            (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                            (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                            (all_bundles, all_values) = current_training_set[i][2]    # get the current best performing point, according to the true student
                            bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                            X_list = []
                            y_list = []

                            print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                            # add the budget constraint to the UNN solver
                            solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                            # do not query the bundles you ahve already queried.
                            print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                            for bundle in bundles_to_forbid:
                                solver.add_forbidden_bundle(bundle)

                            for j in range(number_of_samples):
                                try:
                                    new_x = solver.solve_mip(outputFlag=False, verbose = False)
                                    if np.sum(new_x) >= 5:
                                        solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                                        bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                                # NOTE: should check that this works!!!
                                except:
                                    print('--- ACHTUNG ACHTUNG ---')
                                    print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                                    break

                                new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                                binary_search_indexes = binary_search(arr = all_values, x = new_y)  # those are the indexes of the binary search to find the position of the new bundle
                                print(f'Indexes of the binary search for the new_y: {binary_search_indexes}')

                                for k in binary_search_indexes:
                                    old_x = all_bundles[k]
                                    old_y = all_values[k]

                                    new_x_ordinal = (new_x, old_x)  # compare the bundle that currently maximizes the MVNN to the max that you had

                                    if new_y >= old_y:
                                        new_y_ordinal = 1
                                    else:
                                        new_y_ordinal = 0

                                    if (sample_relative_frequencies) is not None:
                                        for _ in range(sample_relative_frequencies[2]):
                                            X_list.append(new_x_ordinal)
                                            y_list.append(new_y_ordinal)
                                    else:
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)

                                all_bundles.insert(binary_search_indexes[-1], new_x)   # insert the new point in the right place so that future iterations of the binary search still work!
                                all_values.insert(binary_search_indexes[-1], new_y)

                            X_train_new = np.array(X_list)
                            y_train_new = np.array(y_list)

                            if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                                if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                                    X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                                    y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                                else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                                    X_train_ord = X_train_new
                                    y_train_ord = y_train_new

                                current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles, all_values), bundles_to_forbid)

                    elif (model_param_dictionary['cq_method'] == 'complete_ordering'):
                        print('Entering complete ordering version in generate new queries!')
                        for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                            print(f'Expanding the GUI dataset with CQs for student: {i}')
                            if len(model_student_list[i]) == 5:
                                (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                            else:
                                (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                            (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                            (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                            (all_bundles_queried, all_values_queried) = current_training_set[i][2]    # get the set of all points already queried
                            bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)

                            X_list = []
                            y_list = []

                            print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                            # add the budget constraint to the UNN solver
                            solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                            # do not query the bundles you ahve already queried.
                            print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                            for bundle in bundles_to_forbid:
                                solver.add_forbidden_bundle(bundle)

                            for j in range(number_of_samples):
                                try:
                                    new_x = solver.solve_mip(outputFlag=False, verbose = False)
                                    if np.sum(new_x) >= 5:
                                        solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                                        bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                                # NOTE: should check that this works!!!
                                except:
                                    print('--- ACHTUNG ACHTUNG ---')
                                    print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                                    break

                                new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                                print(f'For student {i}, all bundles queried has a size of: {len(all_bundles_queried)}')

                                for k in range(len(all_bundles_queried)):
                                    old_x = all_bundles_queried[k]
                                    old_y = all_values_queried[k]

                                    new_x_ordinal = (new_x, old_x)  # compare the bundle that currently maximizes the MVNN to the max that you had

                                    if new_y >= old_y:
                                        new_y_ordinal = 1
                                    else:
                                        new_y_ordinal = 0

                                    if (sample_relative_frequencies) is not None:
                                        for _ in range(sample_relative_frequencies[2]):
                                            X_list.append(new_x_ordinal)
                                            y_list.append(new_y_ordinal)
                                    else:
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)

                                print(f'For student {i}, the length of the new X_list after we just added 1 "sample" to it is: {len(X_list)}')

                                all_bundles_queried.append(new_x)   # insert the new point in the list of all bundles queried.
                                all_values_queried.append(new_y)

                            X_train_new = np.array(X_list)
                            y_train_new = np.array(y_list)

                            if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                                if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                                    X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                                    y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                                else:   # the dataset of CQs was empty -> replace it with asll the CQs we did this round!
                                    X_train_ord = X_train_new
                                    y_train_ord = y_train_new

                                current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles_queried, all_values_queried), bundles_to_forbid)

                    elif (model_param_dictionary['cq_method'] == 'complete_ordering_pruned' and not model_param_dictionary.get('CQ_mistake_probability', False)):
                        print('Entering complete ordering PRUNED version in generate new queries')
                        for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                            print(f'Expanding the GUI dataset with CQs for student: {i}')
                            if len(model_student_list[i]) == 5:
                                (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                            else:
                                (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                            (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                            (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                            (all_bundles_queried, all_values_queried) = current_training_set[i][2]    # get the set of all points already queried
                            bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)
                            queries_thus_far = current_training_set[i][4]  # so that we can count how many binary search queries were required in every iteration

                            X_list = []
                            y_list = []

                            print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                            # add the budget constraint to the UNN solver
                            solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                            # do not query the bundles you ahve already queried.
                            print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                            for bundle in bundles_to_forbid:
                                solver.add_forbidden_bundle(bundle)

                            for j in range(number_of_samples):
                                try:
                                    new_x = solver.solve_mip(outputFlag=False, verbose = False)
                                # NOTE: should check that this works!!!
                                except:
                                    print('--- ACHTUNG ACHTUNG ---')
                                    print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                                    break

                                solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                                bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                                new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                                print(f'For student {i}, all bundles queried has a size of: {len(all_bundles_queried)}')
                                print(f'For student number: {i}, number of CQs already answered: {queries_thus_far}')
                                binary_search_indexes = binary_search(arr = all_values_queried, x = new_y)  # those are the indexes of the binary search to find the position of the new bundle
                                cq_limit = model_param_dictionary['cq_limit']

                                if (queries_thus_far + len(binary_search_indexes) <= cq_limit):
                                    print(f'Enough queries remaining for student {i}, treating this as insertion sort')
                                    queries_thus_far = queries_thus_far + len(binary_search_indexes)
                                    for k in range(len(all_bundles_queried)):
                                        old_x = all_bundles_queried[k]
                                        old_y = all_values_queried[k]

                                        new_x_ordinal = (new_x, old_x)  # compare the bundle that currently maximizes the MVNN to the max that you had

                                        if new_y >= old_y:
                                            new_y_ordinal = 1
                                        else:
                                            new_y_ordinal = 0

                                        if (sample_relative_frequencies) is not None:
                                            for _ in range(sample_relative_frequencies[2]):
                                                X_list.append(new_x_ordinal)
                                                y_list.append(new_y_ordinal)
                                        else:
                                            X_list.append(new_x_ordinal)
                                            y_list.append(new_y_ordinal)

                                    print(f'-------> ALL VALUES BEFORE INSERTION: {all_values_queried}')
                                    if (all_values_queried[binary_search_indexes[-1]] >= new_y):
                                        all_bundles_queried.insert(binary_search_indexes[-1], new_x)
                                        all_values_queried.insert(binary_search_indexes[-1], new_y)
                                    else:
                                        all_bundles_queried.insert(binary_search_indexes[-1] + 1, new_x)
                                        all_values_queried.insert(binary_search_indexes[-1] + 1, new_y)
                                    print(f'-----> ALL VALUES AFTER INSERTION: {all_values_queried}')

                                elif (queries_thus_far >= cq_limit):
                                    print(f'DONE! FOR student: {i} already at {queries_thus_far} queries, not adding anything for this student')

                                else:
                                    print(f'CORNER CASE!!! For student: {i} Already at {queries_thus_far} queries, and the new BS requires {len(binary_search_indexes)} queries, so have to prune smartly!')
                                    queries_left = cq_limit - queries_thus_far
                                    queries_thus_far = queries_left + queries_thus_far

                                    indexes_checked, lower_bound, upper_bound = binary_search_hacky(arr = all_values_queried, x = new_y, queries_available= queries_left)
                                    print(f'Value of the new point: {new_y}')

                                    counter_to_low = lower_bound
                                    while counter_to_low >= 0:
                                        print(f'Know that it is higher than the point with value: {all_values_queried[counter_to_low]}')
                                        old_x = all_bundles_queried[counter_to_low]
                                        new_x_ordinal = (new_x, old_x)
                                        new_y_ordinal = 1
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)
                                        counter_to_low = counter_to_low - 1

                                    counter_to_high = upper_bound
                                    while counter_to_high < len(all_bundles_queried):
                                        print(f'Know that it is LOWER than the point with value: {all_values_queried[counter_to_high]}')
                                        old_x = all_bundles_queried[counter_to_high]
                                        new_x_ordinal = (new_x, old_x)
                                        new_y_ordinal = 0
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)
                                        counter_to_high = counter_to_high + 1

                            print(f'For student {i}, the length of the new X_list after we just added 1 "sample" to it is: {len(X_list)}')

                            X_train_new = np.array(X_list)
                            y_train_new = np.array(y_list)

                            if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                                if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                                    X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                                    y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                                else:   # the dataset of CQs was empty -> replace it with all the CQs we did this round!
                                    X_train_ord = X_train_new
                                    y_train_ord = y_train_new

                            current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles_queried, all_values_queried), bundles_to_forbid, queries_thus_far)

                    elif (model_param_dictionary['cq_method'] == 'complete_ordering_pruned' and (model_param_dictionary.get('CQ_mistake_probability', None) is not None)):
                            print('Entering complete ordering PRUNED version in generate new queries, with mistake probability!!!')
                            mistake_probability = model_param_dictionary['CQ_mistake_probability']  # the uniform mistake probability of a student getting a CQ wrong 
                            for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(actual_student_list):
                                print(f'Expanding the GUI dataset with CQs for student: {i}')
                                if len(model_student_list[i]) == 5:
                                    (model, solver, scale, _ , budget) = model_student_list[i]  # this is the case where position 3 has a pre-trained model that we keep fixed (transfer learning case)
                                else:
                                    (model, solver, scale, budget) = model_student_list[i]  # get the current model for that student.
                                (X_train, y_train) = current_training_set[i][0]  # get the current (cardinal) dataset
                                (X_train_ord, y_train_ord) = current_training_set[i][1]  # get the current ordinal dataset
                                (all_bundles_queried, all_values_queried) = current_training_set[i][2]    # get the set of all points already queried
                                bundles_to_forbid = current_training_set[i][3]   # get the bundles that you don't want your solver to find (e.g. already queried, or implied)
                                queries_thus_far = current_training_set[i][4]  # so that we can count how many binary search queries were required in every iteration

                                X_list = []
                                y_list = []

                                print(f'Expanding the dataset for student: {i} whose X_train has a length of {len(X_train)} and X_train_ord has a shape of: {len(X_train_ord)}')

                                # add the budget constraint to the UNN solver
                                solver.add_budget_constraint(course_prices = approximate_prices, budget = budget)

                                # do not query the bundles you have already queried.
                                print(f'for student: {i} forbidding {len(bundles_to_forbid)} bundles')
                                for bundle in bundles_to_forbid:
                                    solver.add_forbidden_bundle(bundle)

                                
                                try:
                                    new_x = solver.solve_mip(outputFlag=False, verbose = False)
                                # NOTE: should check that this works!!!
                                except:
                                    print('--- ACHTUNG ACHTUNG ---')
                                    print(f'GENERATE QUERIES STOPPED EARLY AT SAMPLE NUMBER {j}')
                                    break

                                solver.add_forbidden_bundle(new_x)  # add the new bundle we just queried to the list of forbidden bundles so that we don't ask the same question.
                                bundles_to_forbid.append(new_x)  # add the new bundle you just found to those that the maximizer won't return in the feauture
                                new_y = student(new_x, additive_prefs, substitutes, complements, timetable,
                                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                                        credit_units = [1 for i in range(len(additive_prefs))], make_monotone = True)

                                print(f'For student {i}, all bundles queried has a size of: {len(all_bundles_queried)}')
                                print(f'For student number: {i}, number of CQs already answered: {queries_thus_far}')
                                # those are the indexes of the binary search to find the position of the new bundle
                                binary_search_indexes, new_sorted_value_array_noisy, insert_position = noisy_binary_search_and_insert(all_values_queried, new_y, mistake_probability= mistake_probability)  

                                cq_limit = model_param_dictionary['cq_limit']

                                if (queries_thus_far + len(binary_search_indexes) <= cq_limit):
                                    print(f'Enough queries remaining for student {i}, treating this as insertion sort')
                                    queries_thus_far = queries_thus_far + len(binary_search_indexes)
                                    for k in range(len(all_bundles_queried)):
                                        old_x = all_bundles_queried[k]

                                        new_x_ordinal = (new_x, old_x) 

                                        if insert_position > k: 
                                            # If the position to insert the new element is larger than the current position, this means that it was better in the pairwise comparisons
                                            new_y_ordinal = 1
                                        else: 
                                            new_y_ordinal = 0

                                        if (sample_relative_frequencies) is not None:
                                            for _ in range(sample_relative_frequencies[2]):
                                                X_list.append(new_x_ordinal)
                                                y_list.append(new_y_ordinal)
                                        else:
                                            X_list.append(new_x_ordinal)
                                            y_list.append(new_y_ordinal)

                                    print(f'-------> ALL VALUES BEFORE INSERTION: {all_values_queried}')
                                    all_bundles_queried.insert(insert_position, new_x)
                                    all_values_queried = new_sorted_value_array_noisy
                                    print(f'-----> ALL VALUES AFTER INSERTION: {all_values_queried}')

                                elif (queries_thus_far >= cq_limit):
                                    print(f'DONE! FOR student: {i} already at {queries_thus_far} queries, not adding anything for this student')

                                else:
                                    print(f'CORNER CASE!!! For student: {i} Already at {queries_thus_far} queries, and the new BS requires {len(binary_search_indexes)} queries, so have to prune smartly!')
                                    queries_left = cq_limit - queries_thus_far
                                    queries_thus_far = queries_left + queries_thus_far

                                    indexes_checked, lower_bound, upper_bound = noisy_binary_search_hacky(arr = all_values_queried, x = new_y, queries_available= queries_left, mistake_probability= mistake_probability)
                                    print(f'Value of the new point: {new_y}')

                                    counter_to_low = lower_bound
                                    while counter_to_low >= 0:
                                        print(f'Know that it is higher than the point with value: {all_values_queried[counter_to_low]}')
                                        old_x = all_bundles_queried[counter_to_low]
                                        new_x_ordinal = (new_x, old_x)
                                        new_y_ordinal = 1
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)
                                        counter_to_low = counter_to_low - 1

                                    counter_to_high = upper_bound
                                    while counter_to_high < len(all_bundles_queried):
                                        print(f'Know that it is LOWER than the point with value: {all_values_queried[counter_to_high]}')
                                        old_x = all_bundles_queried[counter_to_high]
                                        new_x_ordinal = (new_x, old_x)
                                        new_y_ordinal = 0
                                        X_list.append(new_x_ordinal)
                                        y_list.append(new_y_ordinal)
                                        counter_to_high = counter_to_high + 1

                                print(f'For student {i}, the length of the new X_list after we just added 1 "sample" to it is: {len(X_list)}')

                                X_train_new = np.array(X_list)
                                y_train_new = np.array(y_list)

                                if len(X_list) >= 1:     # there are new CQ that we need to add to our dataset of comparision questions

                                    if len(X_train_ord) > 0:   # add the new comparision queries to the dataset of comparisoin questions we already had
                                        X_train_ord = np.append(X_train_ord, X_train_new, axis = 0)
                                        y_train_ord = np.append(y_train_ord, y_train_new, axis = 0)

                                    else:   # the dataset of CQs was empty -> replace it with all the CQs we did this round!
                                        X_train_ord = X_train_new
                                        y_train_ord = y_train_new

                                current_training_set[i] = ((X_train, y_train), (X_train_ord, y_train_ord), (all_bundles_queried, all_values_queried), bundles_to_forbid, queries_thus_far)

            return current_training_set
        
    else: 
        raise ValueError(f'Unknown model type: {model_type}')


def compute_hnn_comparison_output(model, data, scale):
    """
    Score a batch of bundle pairs as comparison probabilities.

    Each row of `data` holds two bundles, read as data[:, 0, :] (first) and
    data[:, 1, :] (second). Both are passed through `model`, the flattened
    difference first - second is multiplied by `scale`, and a sigmoid is applied.
    The result is a 1-D tensor with one value per pair. A value is close to 1 when
    the model scores the first bundle much higher than the second.
    """
    first_bundles = data[:, 0, :]
    second_bundles = data[:, 1, :]
    # Evaluate the first bundle before the second (same call order as before).
    logits = (model(first_bundles).flatten() - model(second_bundles).flatten()) * scale
    return torch.sigmoid(logits)


def train_hnn(model, train_loader, optimizer, loss_func,
              use_gradient_clipping, clip_grad_norm_cardinal, clip_grad_norm_ordinal, scale,
              train_loader_ordinal=None, alpha=0.5,
              loss_func_ordinal=None, device=torch.device('cpu'), ordinal_data_batch_frequency=1,
              output_manipulation='method_1',
              training_method='scully'):
    """
    Run one training epoch on a cardinal dataloader and, optionally, an ordinal one.

    Without an ordinal loader, every cardinal batch takes an optimizer step.

    With an ordinal loader and training_method='scully', a single draw
    rand ~ U(0, 1) is made for the whole epoch:
      * rand > alpha  -> optimizer steps only on the cardinal batches;
      * rand <= alpha -> optimizer steps only on the ordinal batches.
    The losses of the dataset that is not trained on are still computed (without
    gradient tracking) so that both means are always reported. Ordinal batches
    hold bundle pairs and are scored with compute_hnn_comparison_output(model, data, scale).

    Parameters
    ----------
    model : torch module that is trained in place.
    train_loader : cardinal dataloader yielding (data, target).
    optimizer : optimizer over model's parameters.
    loss_func : cardinal loss, called as loss_func(output, target).
    use_gradient_clipping : whether to clip the gradient norm before each step.
    clip_grad_norm_cardinal, clip_grad_norm_ordinal : max gradient norms for the
        cardinal / ordinal steps.
    scale : multiplier applied to output differences for comparisons.
    train_loader_ordinal : optional ordinal dataloader yielding (pairs, target).
    alpha : probability that the epoch trains on the ordinal data ('scully').
    loss_func_ordinal : ordinal loss, called as loss_func_ordinal(input=..., target=...).
    device : device the batches are moved to.
    ordinal_data_batch_frequency, output_manipulation : currently unused; kept for
        interface compatibility.
    training_method : only 'scully' is supported when an ordinal loader is given.

    Returns
    -------
    (mean cardinal loss, mean ordinal loss) over the batches evaluated this epoch.
    A mean is NaN when no batch of that kind was evaluated.

    Raises
    ------
    ValueError
        If an ordinal loader is given together with an unknown training_method.
    """
    import contextlib

    model.train()
    loss_list = []
    loss_ordinal_list = []

    if train_loader_ordinal is None:
        train_on_cardinal, train_on_ordinal = True, False
    elif training_method == 'scully':
        # One draw per epoch decides which dataset drives the optimizer this epoch.
        rand = np.random.uniform(0, 1, 1)[0]
        train_on_cardinal = rand > alpha
        train_on_ordinal = rand <= alpha
    else:
        # Previously this silently trained nothing and returned NaN means.
        raise ValueError(f'Unknown training method: {training_method}')

    def _run_loader(loader, compute_loss, recorded_losses, take_steps, clip_norm):
        """Evaluate every batch of `loader`; backprop and step only when take_steps is True."""
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            # Losses that are only logged do not need an autograd graph.
            grad_context = contextlib.nullcontext() if take_steps else torch.no_grad()
            with grad_context:
                loss = compute_loss(data, target)
            # .item() copies to host, so this also works for non-CPU devices (unlike .numpy()).
            recorded_losses.append(loss.item())
            if take_steps:
                loss.backward()
                if use_gradient_clipping:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                optimizer.step()

    def _cardinal_loss(data, target):
        return loss_func(model(data).flatten(), target.flatten())

    def _ordinal_loss(data, target):
        output = compute_hnn_comparison_output(model=model, data=data, scale=scale)
        return loss_func_ordinal(input=output.flatten(), target=target.flatten())

    _run_loader(train_loader, _cardinal_loss, loss_list, train_on_cardinal, clip_grad_norm_cardinal)
    if train_loader_ordinal is not None:
        _run_loader(train_loader_ordinal, _ordinal_loss, loss_ordinal_list, train_on_ordinal, clip_grad_norm_ordinal)

    # Avoid numpy's "Mean of empty slice" warning; the result for an empty list is still NaN.
    cardinal_mean = np.mean(loss_list) if loss_list else np.nan
    ordinal_mean = np.mean(loss_ordinal_list) if loss_ordinal_list else np.nan
    return cardinal_mean, ordinal_mean


def train_unn_mixed_datasets(X_train, y_train, X_train_ordinal, y_train_ordinal, alpha,
            num_hidden_layers, num_units, random_ts, trainable_ts, init_E, init_Var,
            learning_rate, weight_decay, epochs, batch_size, loss, loss_ordinal,
            use_gradient_clipping, clip_grad_norm_cardinal, clip_grad_norm_ordinal,
            print_frequency = 50, n_courses = 30, max_courses_in_bundle = 5, max_value_full_bundle = 2):
    """
    Train one MVNN on a student's cardinal (GUI) data and, if any exists, their
    comparison-query (ordinal) data.

    The cardinal targets are divided by their maximum and multiplied by
    target_max = (max_value_full_bundle * max_courses_in_bundle) / n_courses. That
    value is also passed to the MVNN as `target_max`. Parameters whose name contains
    'ts' get no weight decay; every other parameter uses `weight_decay`. Training
    runs `epochs` calls of train_hnn (training_method='scully', so each epoch trains
    on the ordinal data with probability `alpha`). It uses Adam with a cosine-annealed
    learning rate. An empty ordinal dataset means cardinal-only training.

    Returns
    -------
    model : the trained MVNN, after model.transform_weights() has been called.
    scale : y_train.max() / target_max. Multiplying a model output by it maps the
        output back to the original value units. It is also the comparison logit
        scale passed to train_hnn.
    """
    target_max = (max_value_full_bundle * max_courses_in_bundle) / n_courses
    largest_value = y_train.max()
    y_scaled = (y_train / largest_value) * target_max
    scale = largest_value / target_max

    def _make_loader(features, targets):
        # Features are converted as-is; targets become a float column vector.
        dataset = torch.utils.data.TensorDataset(torch.from_numpy(features).float(),
                                                 torch.from_numpy(targets.reshape(-1, 1)).float())
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    cardinal_loader = _make_loader(X_train, y_scaled)

    if len(X_train_ordinal) == 0:
        print('Entering train_loader_ordinal = None case :D')
        ordinal_loader = None
    else:
        ordinal_loader = _make_loader(X_train_ordinal, y_train_ordinal)

    model = MVNN(input_dim=X_train.shape[1],
                 num_hidden_layers=num_hidden_layers,
                 num_units=num_units,
                 layer_type='MVNNLayerReLUProjected',
                 target_max=target_max,
                 dropout_prob=0,
                 init_method='custom',
                 random_ts=random_ts,
                 trainable_ts=trainable_ts,
                 init_E=init_E,
                 init_Var=init_Var,
                 init_b=0.05,
                 init_bias=0.05,
                 init_little_const=0.1)

    print(f'Setting init var to: {init_Var} hidden_layers to: {num_hidden_layers} and units to: {num_units} and weight decay: {weight_decay}')

    # The "ts" parameters are exempt from L2 regularisation (per the original note: the bigger t, the more regular).
    # NOTE(review): this is a substring test on the full parameter name; confirm no other MVNN parameter name contains 'ts'.
    decayed_params = []
    undecayed_params = []
    for param_name, param in model.named_parameters():
        if 'ts' not in param_name:
            decayed_params.append(param)
            continue
        logging.debug(f'Setting L2-Reg. to 0.0 for {param_name}.')
        undecayed_params.append(param)

    optimizer = torch.optim.Adam([{'params': decayed_params, 'weight_decay': weight_decay},
                                  {'params': undecayed_params, 'weight_decay': 0.0}],
                                 lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epochs))

    for epoch in range(epochs):
        cardinal_mean, ordinal_mean = train_hnn(model=model, train_loader=cardinal_loader, optimizer=optimizer, loss_func=loss,
                                                use_gradient_clipping=use_gradient_clipping,
                                                clip_grad_norm_cardinal=clip_grad_norm_cardinal,
                                                clip_grad_norm_ordinal=clip_grad_norm_ordinal,
                                                train_loader_ordinal=ordinal_loader, alpha=alpha,
                                                loss_func_ordinal=loss_ordinal, scale=scale,
                                                device=torch.device('cpu'), output_manipulation='method_1',
                                                training_method='scully')
        scheduler.step()

        if epoch % print_frequency == 0:
            print(f'Current epoch: {epoch}, cardinal loss mean: {cardinal_mean}   ordinal loss mean: {ordinal_mean}')

    model.transform_weights()
    return model, scale


def project_mvnns(mvnn_student_list, model_param_dictionary):
    """
    Fit, per student, a regression model that imitates that student's trained MVNN.

    Bundles are sampled from each MVNN with sample_mvnn (seed 42 + student index).
    A regressor is then fitted with poly_regression_mvnn ('clr' model type). The fit
    is linear when model_param_dictionary['linear_projection'] is truthy. When
    'train_on_whole_space' is truthy, the full sampled space replaces the training
    sample.

    Each entry of mvnn_student_list must be either
    (mvnn, solver, scale, budget) or (mvnn, solver, scale, pretrained_mvnn, budget).

    Returns
    -------
    list of (projected_model, budget), one per student, in input order.

    Raises
    ------
    ValueError
        If an entry has neither 4 nor 5 elements.
    """
    proj_alpha = model_param_dictionary['proj_alpha']
    ridge = model_param_dictionary['ridge']
    fit_intercept = model_param_dictionary['fit_intercept']
    num_train_samples = model_param_dictionary['train_samples']
    use_whole_space = model_param_dictionary.get('train_on_whole_space', False)
    num_high_samples = model_param_dictionary['train_high_samples']

    projections = []
    for student_idx, entry in enumerate(mvnn_student_list):
        if len(entry) not in (4, 5):
            raise ValueError('Incorrect number of arguments in mvnn_student_list')
        # The optional pretrained MVNN (5-tuple case) is not needed for the projection.
        mvnn, _, scale, *_, budget = entry
        print(f'Currrent projection for student: {student_idx}')

        X_tr, Y_tr, X_val, Y_val, X_full, Y_full = sample_mvnn(model=mvnn, scale=scale,
                                                               num_samples=num_train_samples,
                                                               num_high_samples=num_high_samples,
                                                               num_val_samples=1000,
                                                               seed=42 + student_idx)
        if use_whole_space:
            X_tr, Y_tr = X_full, Y_full

        # The regressor expects 1-D targets of shape (n,).
        projected, _, _, _, _ = poly_regression_mvnn(X_tr, Y_tr.ravel(), X_val, Y_val.ravel(),
                                                     linear_projection=model_param_dictionary.get('linear_projection', False),
                                                     seed=42, model_type='clr', alpha=proj_alpha,
                                                     ridge=ridge, fit_intercept=fit_intercept)
        projections.append((projected, budget))

    return projections

def _fit_sklearn_student_list(training_set, student_list, make_regressor):
    """
    Fit one freshly built scikit-learn linear regressor per student.

    `make_regressor` is a zero-argument callable returning an unfitted estimator.
    Returns a list of (coef_, budget) tuples. `budget` is the last element of the
    student's entry in `student_list`.
    """
    fitted_student_list = []
    for i, (X_train, y_train) in enumerate(training_set):
        budget = student_list[i][-1]
        reg = make_regressor().fit(X_train, y_train)
        fitted_student_list.append((reg.coef_, budget))
    return fitted_student_list


def _build_mvnn_solver(trained_model, timetable, credit_units, mip_gap, linear_message):
    """
    Create the Gurobi MIP solver for a trained MVNN and generate its MIP.

    A model with zero hidden layers gets the linear MIP formulation and
    `linear_message` is printed. Returns
    (solver, seconds spent creating the solver, seconds spent generating the MIP).
    """
    start = timer()
    if trained_model._num_hidden_layers >= 1:
        solver = GUROBI_MIP2_MVNN(trained_model)
    else:
        print(linear_message)
        solver = GUROBI_MIP2_MVNN_LINEAR(trained_model)
    mid = timer()
    solver.generate_mip(course_timetable=timetable,
                        credit_units=credit_units,
                        cu_max=5,
                        timeLimit=100,
                        MIPGap=mip_gap,
                        verbose=False)
    end = timer()
    return solver, mid - start, end - mid


def _select_cardinal_loss(loss_string):
    """Map a 'UNN_loss_string' setting ('r2' or 'l1') to its loss function; raise ValueError otherwise."""
    if loss_string == 'r2':
        return r2_loss
    if loss_string == 'l1':
        return F.l1_loss
    raise ValueError(f'Unknown cardinal loss string: {loss_string}')


def _report_average_solver_times(label, num_solvers, create_seconds_total, generate_seconds_total):
    """Print the average solver-creation and MIP-generation times; print nothing if no solver was built."""
    if num_solvers == 0:
        return
    print(f'AVG time to create a solver for {label} after {num_solvers} solvers: {create_seconds_total / num_solvers}')
    print(f'AVG time to generate a MIP for {label} after {num_solvers} solvers: {generate_seconds_total / num_solvers}')


def create_iterative_student_list(training_set, student_list, credit_units, timetable, model_type = 'LinearRegression', model_param_dictionary = None, model_student_list = None):
    """
    Train the ML model of type `model_type` for every student on a given training set.

    Inputs:
    -----------
    training_set: list of length number_of_students
        For most model types training_set[i] is the (X_train, y_train) tuple of the i-th student.
        For 'UNN'/'UNN_Noisy' with 'use_implied_dataset', training_set[i][0] is the cardinal
        (X_train, y_train) and training_set[i][1] the ordinal (X_train_ord, y_train_ord) dataset.
    student_list: list of length number_of_students
        student_list[i]: the actual (true/noisy) characteristics of the i-th student; the last
        element is used as the student's budget.
    credit_units, timetable:
        Passed to the MIP generation of the solver-based models.
    model_type: string
        The ML model type to train.
    model_param_dictionary: dictionary
        Model-specific training parameters.
    model_student_list: list of length number_of_students, of the same type as this function returns.
        Only needed for the TL approach to get the original networks.

    Returns (one entry per student):
        Linear regression / Ridge / Lasso / ElasticNet (and Noisy variants): (coef_, budget)
        'UNN_projected': (projected_model, budget)
        'UNN_transfer_learning': whatever create_TL_mvnn_student_list returns
        NuSVR: (model, solver, gamma, budget)
        UNN without comparison queries: (trained_model, solver, budget)
        UNN with comparison queries ('use_cqs'): (trained_model, solver, scale, budget)
        xgboost: (model, solver, budget)
    An unknown model type prints a message and returns 0.
    """
    # NOTE: no local variable in this function may be named `linear_model`; doing so would
    # shadow the sklearn module for the whole function and break the linear branches.
    print("entered train_model_student_list")
    if model_type in ('LinearRegression', 'LinearRegressionNoisy'):
        return _fit_sklearn_student_list(training_set, student_list,
                                         lambda: linear_model.LinearRegression())

    elif model_type == 'UNN_projected':
        print('create iterative student list entered UNN_projected case')
        projected_student_list = []
        mvnn_student_list = load_obj(model_param_dictionary['mvnn_student_list_path'])[0]
        alpha = model_param_dictionary['proj_alpha']
        ridge = model_param_dictionary['ridge']
        fit_intercept = model_param_dictionary['fit_intercept']
        train_samples = model_param_dictionary['train_samples']
        train_on_whole_space = model_param_dictionary.get('train_on_whole_space', False)
        train_high_samples = model_param_dictionary['train_high_samples']

        for (i, (trained_mvnn, solver, scale, budget)) in enumerate(mvnn_student_list):
            train_start = timer()
            print(f'Currrent projection for student: {i}')
            X_train, Y_train, X_val, Y_val, X_full, Y_full = sample_mvnn(model = trained_mvnn, scale = scale, num_samples = train_samples, num_high_samples= train_high_samples, num_val_samples = 1000, seed = 42 + i)
            if train_on_whole_space:
                X_train = X_full
                Y_train = Y_full
            Y_train = Y_train.ravel()
            Y_full = Y_full.ravel()   # the regression model expects targets of shape (n, )
            # NOTE(review): validation here uses the full space (X_full, Y_full), while project_mvnns uses (X_val, Y_val) — confirm which is intended.
            projected_model, _, _, _, _ = poly_regression_mvnn(X_train, Y_train, X_full, Y_full, linear_projection= model_param_dictionary.get('linear_projection', False),
                                                               seed = 42, model_type= 'clr',  alpha = alpha, ridge = ridge, fit_intercept= fit_intercept)
            projected_student_list.append((projected_model, budget))
            train_end = timer()
            print(f'Number of non-zero coefficients: {np.count_nonzero(projected_model.coef_)}, training time: {train_end - train_start} seconds')

        return projected_student_list

    elif model_type == 'UNN_transfer_learning':
        return create_TL_mvnn_student_list(model_param_dictionary = model_param_dictionary, training_set = training_set, timetable = timetable, student_list = student_list,
                                           mvnn_student_list = model_student_list, credit_units = credit_units)

    elif model_type in ('Ridge', 'RidgeNoisy'):
        return _fit_sklearn_student_list(training_set, student_list,
                                         lambda: linear_model.Ridge(alpha = model_param_dictionary['alpha']))

    elif model_type in ('Lasso', 'LassoNoisy'):
        return _fit_sklearn_student_list(training_set, student_list,
                                         lambda: linear_model.Lasso(alpha = model_param_dictionary['alpha']))

    elif model_type in ('ElasticNet', 'ElasticNetNoisy'):
        # Previously this branch fitted a Lasso. 0.5 is scikit-learn's default l1_ratio.
        return _fit_sklearn_student_list(training_set, student_list,
                                         lambda: linear_model.ElasticNet(alpha = model_param_dictionary['alpha'],
                                                                         l1_ratio = model_param_dictionary.get('l1_ratio', 0.5)))

    elif model_type in ('NuSVR', 'NuSVRNoisy'):
        svr_student_list = []
        for (i, (X_train, y_train)) in enumerate(training_set):
            budget = student_list[i][-1]

            model = NuSVR(kernel= model_param_dictionary['kernel'], degree = model_param_dictionary['degree'], nu = model_param_dictionary['nu'],
                gamma = model_param_dictionary['gamma'], C = model_param_dictionary['C'])

            if model_param_dictionary['scale_ys']:
                y_train = y_train / max(y_train)
            model.fit(X_train, y_train)

            # Recompute the numeric kernel coefficient the same way the string options define it.
            number_of_courses = X_train.shape[1]
            if model.gamma == 'scale':
                gamma = 1 / (number_of_courses * X_train.var())
            elif model.gamma == 'auto':
                gamma = 1 / number_of_courses  # number of courses == additive_prefs.shape[0]
            else:
                gamma = model_param_dictionary['gamma']
            solver = gurobi_MIP_SVR(model, gamma)

            solver.generate_mip(course_prices = np.repeat(0, number_of_courses), credit_units = credit_units, budget = budget, course_timetable = timetable, cu_max = 5, verbose = False)
            # NOTE(review): the MIP is generated with all course prices at 0; the original note says
            # add_budget_constraint overwrites the price/budget setup later — confirm.

            svr_student_list.append((model, solver, gamma, budget))
        return svr_student_list

    elif model_type in ('UNN', 'UNN_Noisy'):
        unn_student_list = []
        seconds_create_solver_total = 0
        seconds_generate_MIP_total = 0

        if not model_param_dictionary.get('use_implied_dataset', False):
            for (i, (X_train, y_train)) in enumerate(training_set):
                budget = student_list[i][-1]

                # Step 2: Min-max scale the targets into [0, 1] (required for UNN)
                scaler = MinMaxScaler()
                scaler.fit(y_train.reshape(-1, 1))
                y_train = (scaler.transform(y_train.reshape(-1, 1))).reshape(-1)

                ymax = 1  # NOTE: not certain this is right; it matches the [0, 1] range of the scaled targets.

                # Step 3: Put the data in a dataloader
                train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),
                                                   torch.from_numpy(y_train.reshape(-1, 1)).float())
                train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)

                # Step 4: Initiate and train the model
                number_of_courses = X_train.shape[1]
                model = MVNN(input_dim= number_of_courses, num_hidden_layers= model_param_dictionary['UNN_layers'],
                    num_units= model_param_dictionary['UNN_units'], layer_type= model_param_dictionary['UNN_layer_type'], target_max= ymax)
                optimizer = torch.optim.Adam(model.parameters(), lr= model_param_dictionary['lr'], weight_decay= model_param_dictionary['weight_decay'])

                trained_model = train_unn_lean(model, optimizer, train_loader, ymax, epochs = model_param_dictionary['epochs'])
                trained_model.transform_weights()

                solver, create_seconds, generate_seconds = _build_mvnn_solver(trained_model, timetable, credit_units,
                                                                             0.0001, 'special case, LINEAR MVNN!!!')
                unn_student_list.append((trained_model, solver, budget))
                seconds_create_solver_total += create_seconds
                seconds_generate_MIP_total += generate_seconds

            _report_average_solver_times('UNN', len(unn_student_list), seconds_create_solver_total, seconds_generate_MIP_total)
            return unn_student_list

        # Implied-dataset case: training started from GUI reports.
        # Fail early on an unknown loss instead of hitting an unbound `loss` later.
        loss = _select_cardinal_loss(model_param_dictionary['UNN_loss_string'])

        if not model_param_dictionary.get('use_cqs', False):
            for (i, all_sets_current_student) in enumerate(training_set):
                (X_train, y_train) = all_sets_current_student[0]
                budget = student_list[i][-1]

                trained_model, scale = train_unn(X_train = X_train, y_train = y_train, num_hidden_layers = model_param_dictionary['UNN_layers'],
                    num_units = model_param_dictionary['UNN_units'], random_ts = model_param_dictionary['UNN_random_ts'], trainable_ts = model_param_dictionary['UNN_trainable_ts'],
                    init_E = model_param_dictionary['UNN_init_E'], init_Var = model_param_dictionary['UNN_init_Var'],
                    learning_rate = model_param_dictionary['lr'], weight_decay = model_param_dictionary['weight_decay'],
                    epochs = model_param_dictionary['epochs'], batch_size = model_param_dictionary['batch_size'], loss = loss)

                trained_model.transform_weights()

                solver, create_seconds, generate_seconds = _build_mvnn_solver(trained_model, timetable, credit_units,
                                                                             0.000001, 'special case, LINEAR MVNN!!!')
                unn_student_list.append((trained_model, solver, budget))
                seconds_create_solver_total += create_seconds
                seconds_generate_MIP_total += generate_seconds

            _report_average_solver_times('UNN', len(unn_student_list), seconds_create_solver_total, seconds_generate_MIP_total)
            return unn_student_list

        print('Using the hybrid version with CQs!')

        if model_param_dictionary['UNN_loss_string_ordinal'] == 'BCE':
            loss_ordinal = torch.nn.BCELoss()
        elif model_param_dictionary['UNN_loss_string_ordinal'] == 'GCE':
            loss_ordinal = GeneralizedCrossEntropyLoss(q = model_param_dictionary['GCE_q'])
        else:
            raise ValueError(f'Unknown ordinal loss string: {model_param_dictionary["UNN_loss_string_ordinal"]}')

        for (i, all_sets_current_student) in enumerate(training_set):
            (X_train, y_train) = all_sets_current_student[0]          # cardinal dataset
            (X_train_ord, y_train_ord) = all_sets_current_student[1]  # ordinal (comparison) dataset
            budget = student_list[i][-1]

            trained_model, scale = train_unn_mixed_datasets(X_train = X_train, y_train = y_train, X_train_ordinal = X_train_ord, y_train_ordinal = y_train_ord, alpha = model_param_dictionary['alpha'],
                num_hidden_layers = model_param_dictionary['UNN_layers'], num_units = model_param_dictionary['UNN_units'], random_ts = model_param_dictionary['UNN_random_ts'], trainable_ts = model_param_dictionary['UNN_trainable_ts'],
                init_E = model_param_dictionary['UNN_init_E'], init_Var = model_param_dictionary['UNN_init_Var'], learning_rate = model_param_dictionary['lr'], weight_decay = model_param_dictionary['weight_decay'],
                epochs = model_param_dictionary['epochs'], batch_size = model_param_dictionary['batch_size'], loss = loss,
                loss_ordinal = loss_ordinal, use_gradient_clipping = model_param_dictionary['use_gradient_clipping'],
                clip_grad_norm_cardinal = model_param_dictionary['clip_cardinal'], clip_grad_norm_ordinal = model_param_dictionary['clip_ordinal'])

            # NOTE(review): train_unn_mixed_datasets already calls transform_weights(); confirm a second call is harmless.
            trained_model.transform_weights()
            solver, create_seconds, generate_seconds = _build_mvnn_solver(trained_model, timetable, credit_units,
                                                                         0.000001, 'Creating linear MVNN MIP!')
            unn_student_list.append((trained_model, solver, scale, budget))
            seconds_create_solver_total += create_seconds
            seconds_generate_MIP_total += generate_seconds

        _report_average_solver_times('UNN', len(unn_student_list), seconds_create_solver_total, seconds_generate_MIP_total)
        return unn_student_list

    elif model_type in ('xgboost', 'xgboostNoisy'):
        xgboost_student_list = []
        seconds_create_solver_total = 0
        seconds_generate_MIP_total = 0
        for (i, (X_train, y_train)) in enumerate(training_set):
            budget = student_list[i][-1]

            model = xgboost.XGBRegressor(colsample_bytree = model_param_dictionary['colsample_bytree'], eta = model_param_dictionary['eta'], max_depth= model_param_dictionary['max_depth'],
                        n_estimators= model_param_dictionary['n_estimators'], subsample= model_param_dictionary['subsample'])

            if model_param_dictionary['scale_ys']:   # scale the y's by their maximum, if requested
                y_train = y_train / max(y_train)

            model.fit(X_train, y_train)
            number_of_courses = X_train.shape[1]
            start = timer()
            solver = gurobi_MIP_xgboost(model, number_of_courses)
            mid = timer()
            solver.generate_mip(credit_units= credit_units, cu_max=5, course_timetable= timetable)
            end = timer()
            seconds_create_solver_total += (mid - start)
            seconds_generate_MIP_total += (end - mid)

            xgboost_student_list.append((model, solver, budget))

        _report_average_solver_times('xgboost', len(xgboost_student_list), seconds_create_solver_total, seconds_generate_MIP_total)
        return xgboost_student_list

    else:
        print(f'No valid model type provided. Actual model type provided: {model_type}')
        return 0


def create_actual_student_list(true_student_list, model_type, model_param_dictionary, seed):
    """
    Return the student list that will answer the comparison queries (CQs).

    * 'Noisy' in model_type: a noisified list from noisify_all_students, using the
      'noisy_*' settings.
    * Otherwise, if cq_method is 'complete_ordering_pruned' and a truthy
      'cq_noise_multiplier' is configured: a noisified list using the 'gui_*'
      settings, each multiplied by that multiplier.
    * Otherwise: true_student_list itself (the same object, not a copy).
    """
    settings = model_param_dictionary
    if 'Noisy' in model_type:
        return noisify_all_students(true_student_list, forget_base=settings['noisy_forget_base'],
                                    forget_adjustments=settings['noisy_forget_adjustments'],
                                    base_noise_std=settings['noisy_base_std'],
                                    adjustment_noise_std=settings['noisy_adj_std'], seed=seed,
                                    multiplicative_base_noise=settings.get('multiplicative_base_noise', False))

    wants_noisy_cqs = (settings['cq_method'] == 'complete_ordering_pruned'
                       and settings.get('cq_noise_multiplier', False))
    if not wants_noisy_cqs:
        return true_student_list

    print('---> Noisy CQs!')
    multiplier = settings['cq_noise_multiplier']
    return noisify_all_students(true_student_list,
                                forget_base=settings['gui_forget_base'] * multiplier,
                                forget_adjustments=settings['gui_forget_adjustments'] * multiplier,
                                base_noise_std=settings['gui_base_noise_std'] * multiplier,
                                adjustment_noise_std=settings['gui_adj_std'] * multiplier, seed=seed,
                                multiplicative_base_noise=settings.get('multiplicative_base_noise', False))


def create_gui_student_list(true_student_list, model_type, model_param_dictionary, seed):
    """
    Build GUI-style (noisified) reports for every student.

    Only applies when model_param_dictionary['use_implied_dataset'] is truthy. In
    that case the students go through guisify_all_students with the 'gui_*'
    settings, so it is apparent which courses each student forgot to give a base
    value for. Otherwise an empty list is returned. `model_type` is not used.
    """
    if not model_param_dictionary.get('use_implied_dataset', False):
        return []

    settings = model_param_dictionary
    return guisify_all_students(student_list=true_student_list,
                                forget_base=settings['gui_forget_base'],
                                forget_base_uniform=settings.get('gui_forget_base_uniform', 0),
                                forget_adjustments=settings['gui_forget_adjustments'],
                                base_noise_std=settings['gui_base_noise_std'],
                                adjustment_noise_std=settings['gui_adj_std'],
                                seed=seed,
                                multiplicative_base_noise=settings.get('gui_multiplicative_base_noise', False),
                                cognomos_interface=settings.get('cognomos_projection', False))
    
def create_hallucinated_gui_student_list(gui_student_list, actual_student_list, model_param_dictionary):
    """
    Fill in the base values each GUI student forgot to declare.

    Each entry of gui_student_list is (base_values, substitutes, complements,
    unforgotten_bases). Courses where unforgotten_bases == 0 get the midpoint
    (high + low) / 2 as their base value, so there is no randomness. The defaults
    are:
      * high: the smallest declared base value, unless 'uniform_range_high' is set;
      * low: 0, unless 'uniform_range_low' is set.
    If a student declared no base value and 'uniform_range_high' is not set, numpy's
    .min() on the empty selection raises ValueError.

    The adjustments are passed through keep_pairwise_adjustments. That helper is
    defined elsewhere and presumably keeps only pairwise adjustments — confirm.

    Returns a list of (hallucinated_base_values, substitutes_clipped,
    complements_clipped, budget). `budget` is the last element of the matching
    entry in actual_student_list.
    """
    hallucinated = []
    for idx, (base_values, substitutes, complements, unforgotten_bases) in enumerate(gui_student_list):
        declared_idx = np.where(unforgotten_bases)[0]
        forgotten_idx = np.where(unforgotten_bases == 0)[0]
        filled_values = base_values.copy()

        range_high = model_param_dictionary.get('uniform_range_high', None)
        if range_high is None:
            range_high = base_values[declared_idx].min()
        range_low = model_param_dictionary.get('uniform_range_low', None)
        if range_low is None:
            range_low = 0

        filled_values[forgotten_idx] = (range_high + range_low) / 2
        budget = actual_student_list[idx][-1]
        complements_clipped, substitutes_clipped = keep_pairwise_adjustments(complements, substitutes)

        hallucinated.append((filled_values, substitutes_clipped, complements_clipped, budget))

    return hallucinated
        



def _gui_noise_parameters_for_tabu(model_param_dictionary):
    """Translate a model's 'gui_*' noise settings into the 'noisy_*' keys used for a 'PairwiseAdjustmentsNoisy' tabu model."""
    return {
        'noisy_forget_base': model_param_dictionary['gui_forget_base'],
        'noisy_forget_adjustments': model_param_dictionary['gui_forget_adjustments'],
        'noisy_base_std': model_param_dictionary['gui_base_noise_std'],
        'noisy_adj_std': model_param_dictionary['gui_adj_std'],
        'multiplicative_base_noise': model_param_dictionary.get('multiplicative_base_noise', False),
        'noisy_forget_base_uniform': model_param_dictionary.get('noisy_forget_base_uniform', 0)
    }


def _replace_solvers_for_saving(student_list_per_model, models_to_run):
    """Replace, in place, the solver objects in every model's student list by the string 'solver' (they can't be saved as is)."""
    for (i, (model_type, model_param_dictionary)) in enumerate(models_to_run):
        model_students = student_list_per_model[i]   # i is the model, j is the student in that model
        if model_type in ['UNN', 'UNN_transfer_learning'] and model_param_dictionary.get('project_to_gui', False):
            print('Using Projected MVNNS, student list can be saved as is!')
        elif model_type in ['xgboost', 'UNN', 'UNN_transfer_learning', 'xgboostNoisy', 'UNN_Noisy']:
            if not model_param_dictionary.get('use_cqs', False):
                for j in range(len(model_students)):
                    (model, solver, budget) = model_students[j]
                    model_students[j] = (model, 'solver', budget)
            else:
                for j in range(len(model_students)):
                    if len(model_students[j]) == 4:
                        (model, solver, scale, budget) = model_students[j]
                        model_students[j] = (model, 'solver', scale, budget)
                    elif len(model_students[j]) == 5:
                        (model, solver, scale, pretrained_model, budget) = model_students[j]
                        model_students[j] = (model, 'solver', scale, pretrained_model, budget)
        elif model_type == 'NuSVR' or model_type == 'NuSVRNoisy':
            for j in range(len(model_students)):
                (model, solver, gamma, budget) = model_students[j]
                model_students[j] = (model, 'solver', gamma, budget)


def run_iterative_stage1_principled(true_student_list, timetable, capacities, percentage_neighbors_per_iteration = [3, 3, 3], gradient_neighbors_per_iteration = [20, 20, 20],
        individual_neighbors_per_iteration = [0, 15, 30], maximum_number_of_restarts_per_iteration = [3, 3, 5], seed = 42, clearing_error_limit = 1, time_limit_restart_per_iteration = [500, 750, 750],
        time_limit_search_per_iteration = [2500, 2500, 2500], max_gradient_multiplier = (0.1 / (2**6)), max_steps_without_improvement_per_iteration = [3, 3, 5], models_to_run = [],
        queries_per_iteration = [30, 20, 10], retrain_frequency = 1):
    """
    Run the iterative stage 1 (query -> retrain -> tabu search) for every model in ``models_to_run``.

    For every model and every query iteration ``it``:
      1. ask ``queries_per_iteration[it]`` new queries. In the first iteration they are all asked at once; afterwards,
         ``retrain_frequency`` at a time, with a retrain after each batch.
      2. retrain the student models on all queries collected so far,
      3. run ``heuristic_search`` (tabu) to obtain a price vector, which becomes the next ``approximate_prices``,
      4. compute the demands induced by those prices,
      5. evaluate every student's TRUE value for the induced demand.
    Which student list the tabu search uses is configurable:
      - In the first iteration (when there is more than one), 'approximate_prices_model' selects it.
      - In the last iteration, UNN models with 'project_to_gui' are first projected back to the GUI language.

    Parameters:
    --------------------
    true_student_list: list of tuples
        (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget) per student.
    timetable, capacities: forwarded to query generation, tabu search and demand computation.
    *_per_iteration: lists indexed by query iteration; each must be at least as long as ``queries_per_iteration``.
    seed: int
        Seed for numpy and for the student-list generators.
    clearing_error_limit: float or None
        If None, the theoretical bound from ``calculate_theoretical_bound_squared`` is used.
    models_to_run: list of (model_type, model_param_dictionary) tuples.
    queries_per_iteration: list of ints
        Number of new queries asked in every iteration.
    retrain_frequency: int
        After the first iteration, models are retrained after every ``retrain_frequency`` new queries.

    Returns:
    --------------------
    (values, student_list_per_model, clearing_errors, oversubscription_errors, allocations, times, prices)
        All but ``student_list_per_model`` are numpy arrays indexed <models> x <iterations> x ...
        Solver objects inside ``student_list_per_model`` are replaced by the string 'solver'.
    """
    number_of_models = len(models_to_run)
    value_list_total_stage1_iterative_all_models = [[] for _ in range(number_of_models)]  # final shape: <models> x <iterations> x <students>
    time_taken_iterative_stage1_all_models = [[] for _ in range(number_of_models)]        # final shape: <models> x <iterations>
    individual_demands_iterative_all_models = [[] for _ in range(number_of_models)]       # final shape: <models> x <iterations> x <students>
    clearing_error_total_stage1_all_models = [[] for _ in range(number_of_models)]        # final shape: <number of models> x iterations
    oversubcription_error_total_stage1_all_models = [[] for _ in range(number_of_models)] # final shape: <number of models> x iterations
    tabu_prices_all_models = [[] for _ in range(number_of_models)]                        # final shape: <number of models> x <number of iterations> x <number of courses>
    allocations_all_models = [[] for _ in range(number_of_models)]                        # final shape: <models> x <iterations>  x <students> x <courses>
    prices_iterative_all_models_s1 = [[] for _ in range(number_of_models)]                # final shape: <models> x <iterations> x <courses>
    student_list_per_model = [[] for _ in range(number_of_models)]                        # final shape: <models> x <students>
    training_set_per_model = [[] for _ in range(number_of_models)]                        # final shape: <models>  x <students>

    number_of_courses = len(true_student_list[0][0])  # length of the first student's additive preference vector
    number_of_students = len(true_student_list)
    print(f'number of students: {number_of_students} and number of courses: {number_of_courses}')

    if (clearing_error_limit is None):
        clearing_error_limit = calculate_theoretical_bound_squared(number_of_courses_M = number_of_courses, largest_bundle_size_k = 5)  # the theoretical bound

    np.random.seed(seed)
    maximum_budget = np.max([profile[-1] for profile in true_student_list])
    print(f'Total number of seats: {np.sum(capacities)}')

    # Step 0: Create compatible true/noisy lists for all models
    actual_student_list_per_model = []
    gui_student_list_per_model = []
    for (j, (model_type, model_param_dictionary)) in enumerate(models_to_run):
        # all models with the same noise parameters get the same noisy student list (if the model name has the word noisy in it), else the true student list
        actual_student_list_per_model.append(create_actual_student_list(true_student_list, model_type, model_param_dictionary, seed = seed))
        # preferences as entered in the GUI, i.e., it is highlighted if a course is actually 0 or forgotten by a student
        gui_student_list_per_model.append(create_gui_student_list(true_student_list, model_type, model_param_dictionary, seed = seed))

    for (j, (model_type, model_param_dictionary)) in enumerate(models_to_run):
        approximate_prices = None  # approximate prices are the result of running intermediate Tabus, so they start as None.
        actual_student_list = actual_student_list_per_model[j]
        gui_student_list = gui_student_list_per_model[j]

        for (query_iteration, number_of_samples) in enumerate(queries_per_iteration):
            print(f'---- Current query iteration: {query_iteration}  ----')
            # Step 1: Ask as many queries as the current iteration allows; retrain after every batch of retrain_frequency queries.
            past_samples = 0
            while (past_samples < number_of_samples):
                print(f'Entering while loop with past samples: {past_samples}')
                if (approximate_prices is None):
                    samples_to_query = number_of_samples
                else:
                    samples_to_query = min(retrain_frequency, number_of_samples - past_samples)
                past_samples += samples_to_query

                print(f'Querying an additional {samples_to_query} samples')
                # NOTE: the seed is such that all models that have the same actual student list get the same seed,
                # but at the same time it is not the same between different instances (requires at least 43 students).
                training_set_per_model[j] = generate_new_queries(actual_student_list = actual_student_list, timetable = timetable, number_of_samples = samples_to_query,
                    current_training_set = training_set_per_model[j], model_student_list = student_list_per_model[j], approximate_prices = approximate_prices,
                    credit_units = [1] * number_of_courses, model_type = model_type, seed= np.argmax(actual_student_list[42][0]) * 317, model_param_dictionary = model_param_dictionary, gui_student_list = gui_student_list)

                # Step 2: Retrain the model using all samples, both new and old.
                model_student_list = create_iterative_student_list(training_set_per_model[j], actual_student_list, credit_units = [1] * number_of_courses,
                    timetable = timetable, model_type = model_type, model_param_dictionary = model_param_dictionary,
                    model_student_list= student_list_per_model[j])
                student_list_per_model[j] = model_student_list

            # Pick the student list / model that the tabu search of this iteration runs on.
            approximate_prices_model = model_param_dictionary.get('approximate_prices_model', 'standard')
            if query_iteration == 0 and len(queries_per_iteration) > 1:
                # first iteration (phase 3 of MLCM) -> potentially use a different model for the approximate price generation
                print('Query iteration is 0, looking for what model to run for approximate price calculation!')
                if approximate_prices_model == 'gui_reports':
                    print('Using gui reports for approximate prices!')
                    model_type_for_tabu = 'PairwiseAdjustmentsNoisy'
                    model_param_dictionary_for_tabu = _gui_noise_parameters_for_tabu(model_param_dictionary)
                    model_student_list_for_tabu = create_model_student_list(student_list = true_student_list,
                        timetable = timetable,
                        model_type = model_type_for_tabu,
                        seed = seed,
                        model_param_dictionary = model_param_dictionary_for_tabu)
                elif approximate_prices_model == 'hallucinated_gui_reports':
                    print('Using hallucinated gui reports for approximate prices!')
                    model_type_for_tabu = 'PairwiseAdjustmentsNoisy'
                    model_param_dictionary_for_tabu = _gui_noise_parameters_for_tabu(model_param_dictionary)
                    model_student_list_for_tabu = create_hallucinated_gui_student_list(gui_student_list, actual_student_list, model_param_dictionary)
                elif approximate_prices_model == 'standard':
                    print('Using ML model for approximate prices (standard case)')
                    model_student_list_for_tabu = model_student_list
                    model_type_for_tabu = model_type
                    model_param_dictionary_for_tabu = model_param_dictionary
                else:
                    raise ValueError(f'approximate_prices_model {approximate_prices_model} not recognized')

            elif query_iteration == len(queries_per_iteration) - 1:
                print('Query iteration is last, checking if we need to (project the MVNNs back to the GUI language)')
                if model_type in ['UNN', 'UNN_transfer_learning'] and model_param_dictionary.get('project_to_gui', False):
                    print('Projecting the MVNNs that were previously trained back to the original GUI language!')
                    model_student_list_for_tabu = project_mvnns(model_student_list, model_param_dictionary)
                    model_type_for_tabu = 'UNN_projected'  # so that the right model is used for tabu.
                    model_param_dictionary_for_tabu = model_param_dictionary
                    student_list_per_model[j] = model_student_list_for_tabu   # so that the PROJECTED mvnns are saved and can be used in stages 2 and 3.
                else:
                    print('Query iteration is last, but no projection is needed')
                    model_student_list_for_tabu = model_student_list
                    model_type_for_tabu = model_type
                    model_param_dictionary_for_tabu = model_param_dictionary

            else:
                print('Query iteration is neither the first nor the last, using the ML model')
                model_student_list_for_tabu = model_student_list
                model_type_for_tabu = model_type
                model_param_dictionary_for_tabu = model_param_dictionary

            # Step 3: Run Tabu Search with this new model and student list!
            print(f'Exited while loop generating samples, running tabu with model type: {model_type_for_tabu}')
            start = timer()
            tabu_prices_model, final_error_model, statistics_model = heuristic_search(model_student_list_for_tabu, timetable, credit_units = [1] * number_of_courses,
                capacities = capacities, max_budget = maximum_budget, max_steps_without_improvement = max_steps_without_improvement_per_iteration[query_iteration],
                clearing_error_limit = clearing_error_limit, time_limit_restart = time_limit_restart_per_iteration[query_iteration], time_limit_search = time_limit_search_per_iteration[query_iteration],
                number_percentage_neighbors = percentage_neighbors_per_iteration[query_iteration], number_gradient_neighbors = gradient_neighbors_per_iteration[query_iteration],
                number_individual_neighbors = individual_neighbors_per_iteration[query_iteration], max_restarts = maximum_number_of_restarts_per_iteration[query_iteration],
                model_type = model_type_for_tabu, model_param_dictionary = model_param_dictionary_for_tabu, max_gradient_multiplier = max_gradient_multiplier)
            end = timer()
            tabu_prices_all_models[j].append(tabu_prices_model)
            prices_iterative_all_models_s1[j].append(tabu_prices_model)
            approximate_prices = tabu_prices_model
            clearing_error_total_stage1_all_models[j].append(final_error_model)
            time_taken = end - start
            time_taken_iterative_stage1_all_models[j].append(time_taken)
            print(f'{model_type} tabu for query iteration {query_iteration} completed in {time_taken} seconds')

            # Step 4: get individual demands induced by this last price vector.
            # Use the same (model type, parameters) pair as the tabu search, so demands match the model the prices were computed for.
            total_demand_model, individual_demands_model = calculate_total_demand(tabu_prices_model, model_student_list_for_tabu, timetable,
                    [1] * number_of_courses, return_individual_demands = True, model_type = model_type_for_tabu, model_param_dictionary = model_param_dictionary_for_tabu)

            individual_demands_iterative_all_models[j].append(individual_demands_model)
            allocations_all_models[j].append(individual_demands_model)
            oversubscription = np.maximum(total_demand_model - capacities, 0).sum()
            oversubcription_error_total_stage1_all_models[j].append(oversubscription)

            # Step 5: get the TRUE student value for the individual demands induced by this price vector
            value_list_model = []
            for (student_number, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(true_student_list):
                student_value = student(individual_demands_model[student_number], additive_prefs, substitutes, complements, timetable,
                        overload_penalty = overload_penalty, free_days_marginal_values = free_days_marginal_values,
                        credit_units = [1] * number_of_courses, make_monotone = True)
                value_list_model.append(student_value)

            value_list_total_stage1_iterative_all_models[j].append(value_list_model)

    # cleanup the gurobi solvers because they can't be saved like this (and we don't need them!)
    _replace_solvers_for_saving(student_list_per_model, models_to_run)

    return (np.array(value_list_total_stage1_iterative_all_models), student_list_per_model, np.array(clearing_error_total_stage1_all_models), np.array(oversubcription_error_total_stage1_all_models),
            np.array(allocations_all_models), np.array(time_taken_iterative_stage1_all_models), np.array(prices_iterative_all_models_s1))


def _true_values_of_allocations(allocations, student_list_true, timetable, number_of_courses):
    """Return the TRUE value (via ``student``) every student assigns to their row of ``allocations``."""
    value_list_model = []
    for (k, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(student_list_true):
        student_value = student(allocations[k], additive_prefs, substitutes, complements, timetable,
                overload_penalty = overload_penalty, free_days_marginal_values= free_days_marginal_values,
                credit_units= [1] * number_of_courses, make_monotone= True)
        value_list_model.append(student_value)
    return value_list_model


def run_rsd_principled(true_student_profiles, loaded_student_lists, timetables_all_runs, capacities_all_runs, models_to_run = [('PairwiseAdjustmentsNoisy', {})], seed = 42):
    """
    Run the RSD baseline for every problem instance (run) and every model.

    Stage 2 is a placeholder: all prices are 0 and every student starts with an empty allocation.
    Stage 3 then calls ``stage3_rsd`` with the model student lists, which processes the students in list order.
    Each resulting allocation is evaluated with the students' TRUE preferences.

    Parameters:
    --------------------
    true_student_profiles: list (runs) of lists (students) of true student tuples
        (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget).
    loaded_student_lists: list indexed [model][run] of model student lists, passed to ``stage3_rsd``.
    timetables_all_runs, capacities_all_runs: lists indexed by run.
    models_to_run: list of (model_type, model_param_dictionary) tuples.
    seed: int
        numpy is reseeded with ``seed + run * number_of_students`` before every run.

    Returns:
    --------------------
    (stage2 values, stage2 allocations, stage2 prices, stage2 times, stage3 values, stage3 allocations, stage3 times)
        numpy arrays indexed <models> x <runs> x ...
    """
    number_of_models = len(models_to_run)
    value_list_total_stage2_all_models = [[] for _ in range(number_of_models)]  # final shape: <models> x <runs> x <students>
    value_list_total_stage3_all_models = [[] for _ in range(number_of_models)]  # final shape: <models> x <runs> x <students>
    time_taken_stage2_all_models = [[] for _ in range(number_of_models)]
    time_taken_stage3_all_models = [[] for _ in range(number_of_models)]
    prices_all_models_s2 = [[] for _ in range(number_of_models)]
    allocations_all_models_stage2 = [[] for _ in range(number_of_models)]  # final shape: <models> x <runs> x <students>
    allocations_all_models_stage3 = [[] for _ in range(number_of_models)]  # final shape: <models> x <runs> x <students>

    number_of_courses = true_student_profiles[0][0][0].shape[0]   # true_student_profiles shape: run x student
    number_of_students = len(true_student_profiles[0])
    print(f'number of students: {number_of_students} and number of courses: {number_of_courses}')

    for i in range(len(true_student_profiles)):
        print(f'Problem instance number: {i}')
        np.random.seed(seed + (i * number_of_students))
        capacities = capacities_all_runs[i]
        timetable = timetables_all_runs[i]
        student_list_true = true_student_profiles[i]
        print(f'Total number of seats: {np.sum(capacities)}')

        # Step 6: pick the model student lists of this run (loaded_student_lists shape: Model x Run x student)
        student_list_per_model = [loaded_student_lists[j][i] for j in range(number_of_models)]

        # Step 7: "fake" stage 2 -> all prices are zero
        for (j, (model_type, model_param_dictionary)) in enumerate(models_to_run):
            start = timer()
            prices_adjusted = np.array([0 for k in range(number_of_courses)])
            end = timer()
            prices_all_models_s2[j].append(prices_adjusted)
            time_taken = end - start
            print(f'Finished with fake stage 2 for the {model_type} model in {time_taken} seconds')
            time_taken_stage2_all_models[j].append(time_taken)

        # Step 8: stage 2 demands -> every student starts with an empty allocation
        individual_demands_all_models_s2 = []
        for j in range(number_of_models):
            individual_demands_model = np.array([[0 for k in range(number_of_courses)] for l in range(number_of_students)])
            individual_demands_all_models_s2.append(individual_demands_model)
            allocations_all_models_stage2[j].append(individual_demands_model)

        # Step 9: the students' true values for the stage 2 allocations
        for j in range(number_of_models):
            value_list_total_stage2_all_models[j].append(
                _true_values_of_allocations(individual_demands_all_models_s2[j], student_list_true, timetable, number_of_courses))

        # Step 10: run stage 3 for all models (stage3_rsd works on its own copy of the demands)
        for (j, (model_type, model_param_dictionary)) in enumerate(models_to_run):
            start = timer()
            final_allocation_model = stage3_rsd(prices_all_models_s2[j][-1], student_list_per_model[j], timetable,
                    individual_demands_all_models_s2[j], capacities, [1] * number_of_courses, model_type = model_type)
            end = timer()
            print(f'Number of matches between stage2 and stage3 for this model: {(final_allocation_model == individual_demands_all_models_s2[j]).sum()}')
            allocations_all_models_stage3[j].append(final_allocation_model)
            time_taken = end - start
            print(f'Stage 3 for the {model_type} model finished in {end - start} seconds.')
            time_taken_stage3_all_models[j].append(time_taken)

        # Step 11: the students' true values for the stage 3 allocations
        for j in range(number_of_models):
            value_list_total_stage3_all_models[j].append(
                _true_values_of_allocations(allocations_all_models_stage3[j][-1], student_list_true, timetable, number_of_courses))

    return (np.array(value_list_total_stage2_all_models), np.array(allocations_all_models_stage2), np.array(prices_all_models_s2),
            np.array(time_taken_stage2_all_models), np.array(value_list_total_stage3_all_models), np.array(allocations_all_models_stage3),
            np.array(time_taken_stage3_all_models))


def _seats_available_to_student(capacities, individual_demands, student_index):
    """Seats per course a student can hold: the currently free seats plus the seats this student already holds."""
    free_seats = capacities - individual_demands.sum(axis = 0)
    return free_seats + individual_demands[student_index]


def _price_out_full_courses(prices, seats_available, blocking_price = 10):
    """Copy of ``prices`` where every course with fewer than one available seat costs ``blocking_price``."""
    # this way the student cannot get courses with no capacity left, as long as their budget is below blocking_price
    prices_copy = prices.copy()
    for k in range(len(seats_available)):
        if seats_available[k] < 1:
            prices_copy[k] = blocking_price
    return prices_copy


def stage3_rsd(prices, student_list_sorted, timetable,  individual_demands, capacities, credit_units, budget_increase_percetange = 1.1, check_sanity = False, max_courses = 5, model_type = 'True'):
    """
    Implementation of Algorithm 3 of Budish et Al.

    Students are processed one after the other in the order of ``student_list_sorted``. Each one re-optimizes with a
    budget of ``budget * budget_increase_percetange`` and can only pick courses that still have seats left.

    Parameters:
    --------------------
    prices: np.array of shape(number_of_courses, )
        prices[i]: The price of the i-th course
    student_list_sorted: model list for all of the students sorted according to stage 3 of the algorithm
        student_profile[i]: The model representation of the i-th student; the tuple layout depends on ``model_type``.
    timetable: list of lists of ints
        course_timetable[i][j]: The ids of all courses being taught in the j-th timeslot of the i-th day
    individual_demands: np.array of shape (number_of_students, number_of_courses)
        Allocation before stage 3 (not modified; a copy is updated).
    capacities: np.array of shape(number_of_courses, )
        capacities[i]: The capacity of the i-th course
    credit_units: list of floats
        credit_units[i]: The credit units of the i-th course
    budget_increase_percetange: float
        The ratio of new budget/old_budget of every student, e.g. 1.1 was used at Wharton.
    check_sanity: bool
        Unused.
    max_courses: int
        Maximum number of courses, forwarded to ``solve_student`` for the 'True'/'PairwiseAdjustments' models.
    model_type: str
        One of 'True', 'TrueNoisy', 'PairwiseAdjustments', 'PairwiseAdjustmentsNoisy', 'NuSVR', 'NuSVRNoisy',
        'UNN', 'UNN_Noisy', 'xgboost', 'xgboostNoisy'. Any other value returns an unchanged copy of ``individual_demands``.

    Returns:
    --------------------
    individual_demands: np.array of shape (number_of_students, number_of_courses, )
        individual_demands[i][j]: 1 if the i-th student is allocated the j-th course at the end of course match, 0 otherwise.
    """

    individual_demands_copy = individual_demands.copy()

    if (model_type in ['True', 'TrueNoisy']):
        free_seats = capacities - individual_demands_copy.sum(axis = 0)
        for (i, (additive_prefs, substitutes, complements, overload_penalty, timegap_penalty, free_days_marginal_values, budget)) in enumerate(student_list_sorted):
            seats_available_to_student = free_seats + individual_demands_copy[i]

            (student_demand, _, value) = solve_student(timetable, prices, credit_units, budget * budget_increase_percetange, max_courses, additive_prefs, complements, substitutes, overload_penalty = overload_penalty,
                        timegap_penalty= timegap_penalty, free_days_marginal_values= free_days_marginal_values, ignore_timegaps= True, verbose = False, seats_available = seats_available_to_student, time_output= True)
            student_demand = np.array(student_demand)

            if ((student_demand == individual_demands_copy[i]).sum() != student_demand.shape[0]):
                # value of the unconstrained optimum at the ORIGINAL budget, used as the reference
                (student_demand_debug, _, value_debug) = solve_student(timetable, prices, credit_units, budget, max_courses, additive_prefs, complements, substitutes, overload_penalty = overload_penalty,
                    timegap_penalty= timegap_penalty, free_days_marginal_values= free_days_marginal_values, ignore_timegaps= True, verbose = False, time_output= True)
                print(f'Student {i} changed his demand! NEW value: {value} OLD value: {value_debug}')
                if (value > value_debug + 0.01):
                    individual_demands_copy[i] = student_demand
                    # NOTE(review): this stops processing all remaining students after the first change; confirm intended.
                    break
                else:
                    print("Student tried to change his demand, but the change in value was less than 0.1")

    elif (model_type in ['PairwiseAdjustments', 'PairwiseAdjustmentsNoisy']):
        for (i, (additive_prefs, substitutes_clipped, complements_clipped, budget)) in enumerate(student_list_sorted):
            print(f'Current student: {i}')
            free_seats = capacities - individual_demands_copy.sum(axis = 0)

            (student_demand, _, value) = solve_student(timetable, prices, credit_units, budget * budget_increase_percetange, max_courses, additive_prefs, complements_clipped,
                substitutes_clipped, overload_penalty = 0, timegap_penalty= 0, free_days_marginal_values= [0, 0, 0, 0, 0], ignore_timegaps= True, verbose = False,
                    seats_available = free_seats, time_output= True)
            student_demand = np.array(student_demand)

            print(f'Student {i} changed his demand! NEW value (according to PA): {value}')
            individual_demands_copy[i] = student_demand

    elif (model_type == 'NuSVR' or model_type == 'NuSVRNoisy'):
        for (i, (model, solver, gamma, budget)) in enumerate(student_list_sorted):
            # recomputed per student so that seats taken by earlier students are respected
            seats_available_to_student = _seats_available_to_student(capacities, individual_demands_copy, i)
            prices_copy = _price_out_full_courses(prices, seats_available_to_student)

            solver.generate_mip(course_prices = prices_copy, credit_units = credit_units, budget = budget * budget_increase_percetange, cu_max = 5, course_timetable = timetable)
            optimal_schedule, optimal_value = solver.solve_mip(verbose=False)
            student_demand = np.array(optimal_schedule)

            if ((student_demand == individual_demands_copy[i]).sum() != student_demand.shape[0]):
                print(f'Student {i} changed his demand! NEW value (according to PA): {optimal_value}')
                individual_demands_copy[i] = student_demand

    elif (model_type == 'UNN' or model_type == 'UNN_Noisy'):
        # entries are (model, solver, budget) or (model, solver, scale, budget); only the solver and budget are needed here
        if len(student_list_sorted) > 0 and len(student_list_sorted[0]) in (3, 4):
            for (i, student_entry) in enumerate(student_list_sorted):
                solver, budget = student_entry[1], student_entry[-1]
                seats_available_to_student = _seats_available_to_student(capacities, individual_demands_copy, i)
                prices_copy = _price_out_full_courses(prices, seats_available_to_student)

                solver.add_budget_constraint(course_prices = prices_copy, budget = budget * budget_increase_percetange)
                student_demand, optimal_value = solver.solve_mip_rv(outputFlag=False, verbose = False)
                student_demand = np.array(student_demand)

                if ((student_demand == individual_demands_copy[i]).sum() != student_demand.shape[0]):
                    print(f'Student {i} changed his demand! NEW value (according to UNN): {optimal_value}')
                    print(f'Old demand: {individual_demands_copy[i]} and new demand: {student_demand}')
                    individual_demands_copy[i] = student_demand

    elif (model_type == 'xgboost' or model_type == 'xgboostNoisy'):
        for (i, (model, solver, budget)) in enumerate(student_list_sorted):
            seats_available_to_student = _seats_available_to_student(capacities, individual_demands_copy, i)
            prices_copy = _price_out_full_courses(prices, seats_available_to_student)

            solver.add_budget_constraint(course_prices=prices_copy, budget= budget * budget_increase_percetange)
            student_demand, optimal_value = solver.solve_mip()
            student_demand = np.array(student_demand)  # solve_mip may return a list; .shape below needs an array

            if ((student_demand == individual_demands_copy[i]).sum() != student_demand.shape[0]):
                print(f'Student {i} changed his demand! NEW value (according to xgboost): {optimal_value}')
                individual_demands_copy[i] = student_demand

    return individual_demands_copy


def pretrain_mvnn_epoch(model, train_loader_cardinal, optimizer_cardinal, loss_func_cardinal,
              use_gradient_clipping, clip_grad_norm_cardinal, 
              device=torch.device('cpu')):
    """
    Performs ONE EPOCH of pre-training for a single MVNN on the GUI (cardinal) dataset.

    :param model: network to train; it is put in training mode and its parameters are
        updated in place through ``optimizer_cardinal``.
    :param train_loader_cardinal: iterable of ``(data, target)`` tensor batches.
    :param optimizer_cardinal: optimizer over ``model``'s parameters.
    :param loss_func_cardinal: callable ``loss(prediction, target)`` returning a scalar tensor;
        both arguments are flattened before the call.
    :param use_gradient_clipping: if True, clip the global gradient norm before each step.
    :param clip_grad_norm_cardinal: max gradient norm used when clipping is enabled.
    :param device: device each batch is moved to before the forward pass.
    :return: mean of the per-batch loss values (NaN, with a numpy RuntimeWarning, if the
        loader yields no batches).
    """
    model.train()
    batch_losses = []

    for data, target in train_loader_cardinal:
        data, target = data.to(device), target.to(device)
        optimizer_cardinal.zero_grad()
        output = model(data)
        loss = loss_func_cardinal(output.flatten(), target.flatten())
        # .item() copies the scalar to the host; .numpy() would fail for tensors on a CUDA device.
        batch_losses.append(loss.item())
        loss.backward()
        if use_gradient_clipping:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm_cardinal)
        optimizer_cardinal.step()

    return np.mean(batch_losses)


def finetune_mvnn_epoch(model, optimizer_ordinal,
              use_gradient_clipping, clip_grad_norm_ordinal, scale,
              train_loader_ordinal=None,
              loss_func_ordinal=None, device=torch.device('cpu')):
    """
    Performs ONE EPOCH of model finetuning on the CQ (ordinal / comparison) dataset.

    :param model: network to finetune; it is put in training mode and its parameters are
        updated in place through ``optimizer_ordinal``.
    :param optimizer_ordinal: optimizer over ``model``'s parameters.
    :param use_gradient_clipping: if True, clip the global gradient norm before each step.
    :param clip_grad_norm_ordinal: max gradient norm used when clipping is enabled.
    :param scale: forwarded to ``compute_hnn_comparison_output``; presumably the factor that maps
        network outputs back to the original value units (TODO confirm against that helper).
    :param train_loader_ordinal: iterable of ``(data, target)`` tensor batches. Required despite the
        ``None`` default, which is kept only for backward compatibility of the signature.
    :param loss_func_ordinal: callable ``loss(input=..., target=...)`` returning a scalar tensor.
        Required despite the ``None`` default.
    :param device: device each batch is moved to before the forward pass.
    :return: mean of the per-batch loss values (NaN, with a numpy RuntimeWarning, if the
        loader yields no batches).
    :raises TypeError: if ``train_loader_ordinal`` or ``loss_func_ordinal`` is missing.
    """
    # Fail fast with a clear message instead of a confusing "'NoneType' object is not iterable/callable".
    if train_loader_ordinal is None or loss_func_ordinal is None:
        raise TypeError('finetune_mvnn_epoch requires both train_loader_ordinal and loss_func_ordinal.')

    model.train()
    batch_losses = []

    for data_ordinal, target_ordinal in train_loader_ordinal:
        data_ordinal, target_ordinal = data_ordinal.to(device), target_ordinal.to(device)
        optimizer_ordinal.zero_grad()
        output = compute_hnn_comparison_output(model=model, data=data_ordinal, scale=scale)
        loss = loss_func_ordinal(input=output.flatten(), target=target_ordinal.flatten())
        # .item() copies the scalar to the host; .numpy() would fail for tensors on a CUDA device.
        batch_losses.append(loss.item())
        loss.backward()
        if use_gradient_clipping:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm_ordinal)
        optimizer_ordinal.step()

    return np.mean(batch_losses)


def pretrain_mvnn(X_train, y_train,
            num_hidden_layers, num_units, random_ts, trainable_ts, init_E, init_Var,
            learning_rate_cardinal, weight_decay_cardinal, epochs_cardinal, batch_size_cardinal, loss_cardinal,
            use_gradient_clipping, clip_grad_norm_cardinal, 
            print_frequency = 50, n_courses = 25, max_courses_in_bundle = 5, max_value_full_bundle = 2):
    """
    Pre-trains one MVNN (latest version, 'MVNNLayerReLUProjected' layers) on a student's GUI dataset.

    Only the cardinal GUI reports (X_train, y_train) are used here; the comparison (ordinal)
    dataset is handled separately by ``finetune_mvnn``.

    Targets are rescaled so that max(y_train) maps to
    ``initialization_constant = max_value_full_bundle * max_courses_in_bundle / n_courses``,
    which is also passed to the MVNN as ``target_max``.

    :param X_train: numpy array of bundles, one row per sample.
    :param y_train: numpy array of reported values, one per row of X_train.
    :param num_hidden_layers, num_units, random_ts, trainable_ts, init_E, init_Var: MVNN architecture
        and initialisation hyperparameters, forwarded to the MVNN constructor.
    :param learning_rate_cardinal, weight_decay_cardinal: Adam settings; weight decay is NOT applied
        to parameters whose name contains 'ts'.
    :param epochs_cardinal: number of epochs (also the cosine-annealing period).
    :param batch_size_cardinal: mini-batch size.
    :param loss_cardinal: loss callable ``loss(prediction, target)``.
    :param use_gradient_clipping, clip_grad_norm_cardinal: optional gradient-norm clipping.
    :param print_frequency: print the mean loss every this many epochs.
    :return: ``(model, scale)``, where ``scale = max(y_train) / initialization_constant``, so
        ``model_output * scale`` is in the units of the original y_train.
    :raises ValueError: if max(y_train) is not strictly positive (rescaling would yield NaN/inf
        or sign-flipped targets).
    """
    print(f'---> Performing pretraining on the GUI dataset---')

    y_max_unscaled = y_train.max()
    # `not x > 0` also rejects NaN; dividing by a zero max would silently turn every target into NaN/inf.
    if not y_max_unscaled > 0:
        raise ValueError(f'pretrain_mvnn needs max(y_train) > 0 to rescale the targets, got {y_max_unscaled}.')

    initialization_constant = (max_value_full_bundle * max_courses_in_bundle) / n_courses

    # Rescale targets so that the largest reported value equals initialization_constant.
    y_train = y_train / y_max_unscaled
    y_train = y_train * initialization_constant

    scale = (y_max_unscaled / initialization_constant)

    ymax = initialization_constant

    # Put the data in a dataloader
    train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),
                                       torch.from_numpy(y_train.reshape(-1, 1)).float())
    train_loader_cardinal = torch.utils.data.DataLoader(train_dataset, batch_size= batch_size_cardinal, shuffle=True)

    # Initiate and train the model
    model = MVNN(input_dim=X_train.shape[1],
                         num_hidden_layers=num_hidden_layers,
                         num_units=num_units,
                         layer_type='MVNNLayerReLUProjected',
                         target_max=ymax,
                         dropout_prob=0,
                         init_method='custom',
                         random_ts= random_ts,
                         trainable_ts= trainable_ts,
                         init_E= init_E,
                         init_Var= init_Var,
                         init_b = 0.05,
                         init_bias = 0.05,
                         init_little_const = 0.1
                 )

    # make sure ts have no regularisation
    # the bigger t the more regular
    print(f'Setting init var to: {init_Var} hidden_layers to: {num_hidden_layers} and units to: {num_units} and CARDINAL weight decay to: {weight_decay_cardinal}')
    l2_reg_parameters = {'params': [], 'weight_decay': weight_decay_cardinal}
    no_l2_reg_parameters = {'params': [], 'weight_decay': 0.0}
    for name, param in model.named_parameters():
        if 'ts' in name:
            logging.debug(f'Setting L2-Reg. to 0.0 for {name}.')
            no_l2_reg_parameters['params'].append(param)
        else:
            l2_reg_parameters['params'].append(param)

    optimizer_cardinal = torch.optim.Adam([l2_reg_parameters, no_l2_reg_parameters], lr= learning_rate_cardinal)

    scheduler_cardinal = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_cardinal, float(epochs_cardinal))

    for epoch in range(epochs_cardinal):
        loss_mean = pretrain_mvnn_epoch(model = model, train_loader_cardinal= train_loader_cardinal, optimizer_cardinal = optimizer_cardinal, loss_func_cardinal = loss_cardinal,
                use_gradient_clipping = use_gradient_clipping, clip_grad_norm_cardinal = clip_grad_norm_cardinal, 
                device=torch.device('cpu'))

        scheduler_cardinal.step()

        if epoch % print_frequency == 0:
            print(f'Current epoch: {epoch}, CARDINAL loss mean: {loss_mean}')

    model.transform_weights()
    return model, scale


def finetune_mvnn(X_train_ordinal, y_train_ordinal, pretrained_model, scale, 
            learning_rate_ordinal, weight_decay_ordinal, epochs_ordinal, batch_size_ordinal, loss_ordinal,
            use_gradient_clipping, clip_grad_norm_ordinal, 
            print_frequency = 50):
    """
    Finetunes an already pre-trained MVNN, in place, on a student's comparison (ordinal) dataset.

    Only the ordinal data (X_train_ordinal, y_train_ordinal) is used here; the GUI dataset is
    consumed by ``pretrain_mvnn``. If there is no ordinal data, a message is printed and the
    model is left untouched.

    Optimisation uses Adam with cosine-annealed learning rate; weight decay is applied to every
    parameter except those whose name contains 'ts'. After training, ``transform_weights`` is
    called on the model.

    :param X_train_ordinal: numpy array of comparison inputs, or None.
    :param y_train_ordinal: numpy array of comparison labels, one per row of X_train_ordinal.
    :param pretrained_model: MVNN to finetune; modified in place.
    :param scale: forwarded to the per-epoch routine.
    :param learning_rate_ordinal, weight_decay_ordinal: Adam settings.
    :param epochs_ordinal: number of epochs (also the cosine-annealing period).
    :param batch_size_ordinal: mini-batch size.
    :param loss_ordinal: ordinal loss callable.
    :param use_gradient_clipping, clip_grad_norm_ordinal: optional gradient-norm clipping.
    :param print_frequency: currently unused (kept for interface compatibility).
    :return: None.
    """
    has_no_ordinal_data = X_train_ordinal is None or X_train_ordinal.shape[0] == 0
    if has_no_ordinal_data:
        print('--> No ordinal data to finetune on, skipping this step.')
        return

    ordinal_inputs = torch.from_numpy(X_train_ordinal).float()
    ordinal_labels = torch.from_numpy(y_train_ordinal.reshape(-1, 1)).float()
    ordinal_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(ordinal_inputs, ordinal_labels),
        batch_size=batch_size_ordinal,
        shuffle=True,
    )

    # The 'ts' parameters are excluded from weight decay (the bigger t, the more regular).
    decayed_params, undecayed_params = [], []
    for param_name, param in pretrained_model.named_parameters():
        if 'ts' in param_name:
            logging.debug(f'Setting L2-Reg. to 0.0 for {param_name}.')
            undecayed_params.append(param)
        else:
            decayed_params.append(param)

    optimizer = torch.optim.Adam(
        [
            {'params': decayed_params, 'weight_decay': weight_decay_ordinal},
            {'params': undecayed_params, 'weight_decay': 0.0},
        ],
        lr=learning_rate_ordinal,
    )
    lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(epochs_ordinal))

    for _ in range(epochs_ordinal):
        finetune_mvnn_epoch(
            model=pretrained_model,
            optimizer_ordinal=optimizer,
            use_gradient_clipping=use_gradient_clipping,
            clip_grad_norm_ordinal=clip_grad_norm_ordinal,
            scale=scale,
            train_loader_ordinal=ordinal_loader,
            loss_func_ordinal=loss_ordinal,
            device=torch.device('cpu'),
        )
        lr_schedule.step()

    pretrained_model.transform_weights()
    return


def _build_mvnn_mip_solver(model, timetable, credit_units):
    """
    Project an MVNN's weights, then create and generate the Gurobi MIP for it.

    A multi-layer model gets GUROBI_MIP2_MVNN; a model without hidden layers gets the linear variant.

    :return: ``(solver, seconds_to_create_solver, seconds_to_generate_mip)``.
    """
    model.transform_weights()
    start = timer()
    if model._num_hidden_layers >= 1:
        solver = GUROBI_MIP2_MVNN(model)
    else:
        print('Creating linear MVNN MIP!')
        solver = GUROBI_MIP2_MVNN_LINEAR(model)
    mid = timer()
    solver.generate_mip(course_timetable=timetable,
        credit_units=credit_units,
        cu_max=5,
        timeLimit=100,
        MIPGap=0.000001,
        verbose=False)
    end = timer()
    return solver, mid - start, end - mid


def create_TL_mvnn_student_list(model_param_dictionary, training_set, timetable, student_list, mvnn_student_list, credit_units):
    """
    Creates an MVNN student list by applying transfer learning.

    This is the counterpart of create_iterative_student_list for a mixed dataset consisting of
    the students' initial GUI reports and their comparison queries (CQs).

    - If ``mvnn_student_list`` is None or empty, one MVNN per student is pre-trained on the GUI
      reports (``training_set[i][0]``) and a new list is returned.
    - Otherwise, each student's stored pretrained model is copied and finetuned on the CQs
      (``training_set[i][1]``); ``mvnn_student_list`` is updated in place and returned.

    Each entry of the returned list is ``(model, solver, scale, pretrained_model, budget)``, where
    ``pretrained_model`` always holds the weights right after pre-training.

    :raises ValueError: if the configured cardinal or ordinal loss string is unknown.
    """
    if not mvnn_student_list:
        print('---> pretraining the MVNN student list on the GUI reports')

        mvnn_student_list = []
        seconds_create_solver_total = 0
        seconds_generate_MIP_total = 0

        loss_string_cardinal = model_param_dictionary['UNN_loss_string_cardinal']
        if loss_string_cardinal == 'r2':
            loss = r2_loss
        elif loss_string_cardinal == 'l1':
            loss = F.l1_loss
        else:
            # Previously `loss` was left unbound here and the failure surfaced later as a NameError.
            raise ValueError(f'Unknown cardinal loss string: {loss_string_cardinal}')

        for (i, all_sets_current_student) in enumerate(training_set):
            (X_train, y_train) = all_sets_current_student[0]  # cardinal (GUI) dataset
            budget = student_list[i][-1]

            trained_model, scale = pretrain_mvnn(X_train = X_train, y_train = y_train,
                num_hidden_layers = model_param_dictionary['UNN_layers'], num_units = model_param_dictionary['UNN_units'], random_ts = model_param_dictionary['UNN_random_ts'], trainable_ts = model_param_dictionary['UNN_trainable_ts'],
                init_E = model_param_dictionary['UNN_init_E'], init_Var = model_param_dictionary['UNN_init_Var'], learning_rate_cardinal = model_param_dictionary['lr_cardinal'], weight_decay_cardinal = model_param_dictionary['weight_decay_cardinal'],
                epochs_cardinal = model_param_dictionary['epochs_cardinal'], batch_size_cardinal = model_param_dictionary['batch_size_cardinal'], loss_cardinal = loss,
                use_gradient_clipping = model_param_dictionary['use_gradient_clipping'],
                clip_grad_norm_cardinal = model_param_dictionary['clip_cardinal'])

            solver, seconds_create, seconds_generate = _build_mvnn_mip_solver(trained_model, timetable, credit_units)

            # Keep an independent snapshot of the pretrained weights; later finetuning starts from it.
            # (deepcopy already copies all parameters, so no extra load_state_dict is needed.)
            pretrained_model = copy.deepcopy(trained_model)

            mvnn_student_list.append((trained_model, solver, scale, pretrained_model, budget))
            seconds_create_solver_total += seconds_create
            seconds_generate_MIP_total += seconds_generate

        # Guard: an empty training_set would otherwise raise ZeroDivisionError here.
        if mvnn_student_list:
            print(f'AVG time to create a solver for UNN after {len(mvnn_student_list)} solvers: {seconds_create_solver_total / len(mvnn_student_list)}')
            print(f'AVG time to generate a MIP for UNN after {len(mvnn_student_list)} solvers: {seconds_generate_MIP_total / len(mvnn_student_list)}')

    else:
        loss_string_ordinal = model_param_dictionary['UNN_loss_string_ordinal']
        if loss_string_ordinal == 'BCE':
            loss_ordinal = torch.nn.BCELoss()
        elif loss_string_ordinal == 'GCE':
            loss_ordinal = GeneralizedCrossEntropyLoss(q = model_param_dictionary['GCE_q'])
        else:
            raise ValueError(f'Unknown ordinal loss string: {loss_string_ordinal}')

        for (i, all_sets_current_student) in enumerate(training_set):
            (X_train_ord, y_train_ord) = all_sets_current_student[1]  # ordinal (CQ) dataset

            _, _, scale, pretrained_model, budget = mvnn_student_list[i]

            # Always finetune a fresh copy so the stored pretrained weights stay untouched.
            finetuned_model = copy.deepcopy(pretrained_model)

            finetune_mvnn(X_train_ordinal = X_train_ord, y_train_ordinal = y_train_ord, pretrained_model = finetuned_model, scale = scale,
                learning_rate_ordinal = model_param_dictionary['lr_ordinal'], weight_decay_ordinal = model_param_dictionary['weight_decay_ordinal'],
                epochs_ordinal = model_param_dictionary['epochs_ordinal'], batch_size_ordinal = model_param_dictionary['batch_size_ordinal'], loss_ordinal = loss_ordinal,
                use_gradient_clipping = model_param_dictionary['use_gradient_clipping'],
                clip_grad_norm_ordinal = model_param_dictionary['clip_ordinal'])

            solver, _, _ = _build_mvnn_mip_solver(finetuned_model, timetable, credit_units)

            mvnn_student_list[i] = (finetuned_model, solver, scale, pretrained_model, budget)

    return mvnn_student_list


def project_utilities_cognomos_language(course_values, favorite_course_cognomos_value = 360000, 
        ranges_other_groups = ((1, 12), (36, 189), (720, 3780))):
    """
    Projects the utilities of a student from the CM language used in Wharton to the Cognomos language.

    The course with the highest value (the first one, if several tie) is the favorite course and is
    assigned ``favorite_course_cognomos_value``. All other courses with a strictly positive value are
    grouped with K-Means into ``len(ranges_other_groups)`` clusters. Clusters are ordered by their
    mean value, and the i-th cluster is mapped onto the i-th (low, high) range. Within a cluster,
    courses are ordered by value (ties keep their original course order) and placed at evenly spaced
    points strictly inside the range. Non-favorite courses with value <= 0 stay at 0.

    :param course_values: 1-D array-like of base values for each course.
    :param favorite_course_cognomos_value: The value of the favorite course in the Cognomos language.
    :param ranges_other_groups: Sequence of (low, high) value ranges for the other groups in the
        Cognomos language, ordered from the least to the most valuable group. Its length sets the
        number of K-Means clusters.

    :return: NumPy array of projected values for each course in the Cognomos language.
    :raises ValueError: raised by KMeans when there are fewer positive non-favorite courses than groups.
    """
    course_values = np.asarray(course_values)
    n_groups = len(ranges_other_groups)

    # Step 1: Identify the favorite course (highest value; np.argmax returns the first on ties)
    favorite_course_index = np.argmax(course_values)

    # Step 2: Courses to cluster = strictly positive values, excluding ONLY the favorite index.
    # Working with indices (rather than looking courses up by value) keeps courses with equal
    # values distinct, so every one of them receives a projected value.
    is_candidate = (course_values > 0) & (np.arange(len(course_values)) != favorite_course_index)
    remaining_indices = np.flatnonzero(is_candidate)
    remaining_values = course_values[remaining_indices].reshape(-1, 1)

    # Step 3: Cluster the remaining courses into one group per range using K-Means clustering
    kmeans = KMeans(n_clusters= n_groups, random_state=0).fit(remaining_values)
    clusters = kmeans.predict(remaining_values)

    # Order the cluster labels by the mean value of their members (lowest mean -> first range)
    cluster_means = [np.mean(remaining_values[clusters == label]) for label in range(n_groups)]
    cluster_order = np.argsort(cluster_means)

    # Create an array with size equal to the original course values 
    course_values_cognomos_language = np.zeros(len(course_values))

    # Assign the favorite course to the favorite course value
    course_values_cognomos_language[favorite_course_index] = favorite_course_cognomos_value
    print('setting the value for the favorite course index:', favorite_course_index, 'to:', favorite_course_cognomos_value)

    # Assign the other courses, cluster by cluster, to uniformly spaced values inside their range
    for (range_low, range_high), cluster_label in zip(ranges_other_groups, cluster_order):
        members = remaining_indices[clusters == cluster_label]
        members = members[np.argsort(course_values[members], kind='stable')]
        cluster_step = (range_high - range_low) / (len(members) + 1)
        for j, course_index in enumerate(members):
            projected_value = range_low + (j + 1) * cluster_step
            print('setting the value for course index:', course_index, 'to:', projected_value)
            course_values_cognomos_language[course_index] = projected_value

    return course_values_cognomos_language


def transform_utilities(utilities, log_offset=1e-5, stretch_factor = 3, exploration_factor=0.0):
    """
    Projects utilities from the Cognomos language to the CM language (scale 0 to 100).

    Steps:
    1. log(utilities + log_offset)  (the offset handles zeros),
    2. Min-Max scaling to [0, 1],
    3. raising to ``stretch_factor`` to make differences more pronounced,
    4. replacing exact zeros by ``exploration_factor`` times the smallest positive scaled value,
    5. multiplying by 100.

    Parameters:
    utilities (array-like): The original utilities values.
    log_offset (float): Added before the log transformation.
    stretch_factor (float): Exponent applied to the Min-Max scaled values.
    exploration_factor (float): Multiplier of the smallest positive scaled value, used as the new
        value of the entries that are exactly zero after stretching (0.0 keeps them at 0).

    Returns:
    tuple: (utilities_scaled, min_value, max_value), where utilities_scaled is the transformed
    array and min_value / max_value are the min / max of the log-transformed utilities (kept
    for the inverse transformation).

    Raises:
    ValueError: if utilities is empty, or its log-transformed values are not finite or all equal
    (Min-Max scaling is then undefined).
    """
    # Apply log transformation with an offset to handle zeros
    utilities_log_transformed = np.log(np.asarray(utilities) + log_offset)
    if utilities_log_transformed.size == 0:
        raise ValueError('transform_utilities needs at least one utility value.')

    # Perform Min-Max scaling; keep min/max for the inverse transformation
    min_value = utilities_log_transformed.min()
    max_value = utilities_log_transformed.max()
    value_range = max_value - min_value
    # `not (value_range > 0)` also rejects NaN; a zero range would otherwise give 0/0 = NaN everywhere.
    if not (np.isfinite(min_value) and np.isfinite(max_value) and value_range > 0):
        raise ValueError('transform_utilities needs finite log-utilities that are not all equal, '
                         f'got min={min_value}, max={max_value}.')
    utilities_scaled = (utilities_log_transformed - min_value) / value_range

    # Stretching the scaled utilities to make differences more pronounced
    utilities_scaled = np.power(utilities_scaled, stretch_factor)

    # Apply an exploration factor to the entries that are exactly zero.
    # The maximum always scales to exactly 1.0, so a positive value is guaranteed to exist.
    min_positive_value = utilities_scaled[utilities_scaled > 0].min()
    utilities_scaled[utilities_scaled == 0] = min_positive_value * exploration_factor

    # multiply by 100 to get back to the original CM scale 
    utilities_scaled = utilities_scaled * 100

    return utilities_scaled, min_value, max_value