import operator
import random
import re
from typing import List

import numpy as np
from edsl import QuestionList, ScenarioList, QuestionFreeText

from preference_generator import create_all_instances, calculate_total_demand, create_noisy_model_student_list, calculate_single_student_demand,calculate_true_bundle_value
from preference_generator_utils import load_obj

def levenshtein_distance(list1: List[str], list2: List[str]) -> int:
    """Return the edit distance between two sequences.

    The distance is the minimum number of single-element insertions, deletions
    and substitutions that turn ``list1`` into ``list2``. Elements are compared
    with ``==``.

    Only the previous DP row is kept, so memory is O(len(list2)).
    """
    previous_row = list(range(len(list2) + 1))
    for row_index, left_item in enumerate(list1, start=1):
        current_row = [row_index]
        for col_index, right_item in enumerate(list2, start=1):
            if left_item == right_item:
                # Matching elements cost nothing: carry the diagonal value over.
                current_row.append(previous_row[col_index - 1])
            else:
                insertion = current_row[col_index - 1]
                deletion = previous_row[col_index]
                substitution = previous_row[col_index - 1]
                current_row.append(1 + min(insertion, deletion, substitution))
        previous_row = current_row
    return previous_row[-1]

# Compare to random guessing
def random_distance(reference=None):
    """Return the edit distance between a ranking and a random permutation of it.

    This is the random-guessing baseline for the inferred rankings. The guess
    is a shuffle of ``reference`` itself, so both sequences hold the same kind
    of items. ``true_preferences`` contains integer course ids, which could
    never equal the ``"course i"`` strings in ``courses``. Comparing against
    those strings therefore always yielded the full length as the distance.

    :param reference: ranking to compare against; defaults to the module-level
        ``true_preferences``.
    :return: edit distance between ``reference`` and a uniformly random
        permutation of it.
    """
    if reference is None:
        reference = true_preferences
    reference = list(reference)
    guess = random.sample(reference, len(reference))
    return levenshtein_distance(reference, guess)

# The preferences were fitted to the Budish and Kessler (2022) data:
#   6 popular courses -> mean number of favourites = 3.85
#   9 popular courses -> mean number of favourites = 2.6
# The maximum budget deviation beta is set to 0.04, as in the Rubenstein 2022 paper.
# In our paper, we report results using supply ratios of 1.1, 1.25 and 1.5.
#
# Every configuration below is generated and saved to the "instances" folder, in this order.
# The three *_all_instances variables are rebound on each run, so afterwards they hold the
# results of the LAST configuration only (the additive-preferences one).
_shared_instance_settings = dict(number_of_instances = 50, number_of_courses = 25, supply_ratio = 1.25,
                                 maximum_budget_deviation_beta = 0.04, save_results = True, save_folder = "instances")
for _instance_config in (
    dict(number_of_popular = 9, mean_number_of_favourites = 2.6),
    dict(number_of_popular = 6, mean_number_of_favourites = 3.85),
    dict(number_of_popular = 9, mean_number_of_favourites = 6.1, additive_preferences = True),
):
    true_student_lists_all_instances, capacities_all_instances, timetables_all_instances = create_all_instances(
        **_shared_instance_settings, **_instance_config)

def load_instance(instance_number, supply_ratio = 1.25, number_of_popular = 9, large_grid = True, additive_preferences = False, instance_folder = "instances"):
    """
    Load one saved problem instance.

    The three files read here share one naming suffix built from the instance
    parameters. Each file holds a collection over instances, and
    ``instance_number`` selects the entry to return.

    :param instance_number: index of the instance within the saved collection.
    :param supply_ratio: supply ratio the instances were generated with.
    :param number_of_popular: number of popular courses in the instances.
    :param large_grid: ``lg`` flag that is part of the saved filenames.
    :param additive_preferences: whether the additive-preferences instance set is loaded.
    :param instance_folder: folder the instance files live in.
    :return: tuple ``(true_student_list, capacities, timetable)`` for that instance.
    """
    suffix = f'sr_{supply_ratio}_popular_{number_of_popular}_lg_{large_grid}_additive_preferences_{additive_preferences}'

    students = load_obj(f'{instance_folder}/true_student_lists_{suffix}')[instance_number]
    course_capacities = np.load(f'{instance_folder}/capacities_all_runs_{suffix}.npy')[instance_number]
    course_timetable = load_obj(f'{instance_folder}/timetables_{suffix}')[instance_number]

    return students, course_capacities, course_timetable

# Load one saved instance. load_instance's defaults select supply ratio 1.25,
# 9 popular courses and non-additive preferences.
true_student_list, capacities, timetable = load_instance(instance_number = 2)

# position i of the student list contains the true preferences, and the budget, of the i-th student
# format: additive preferences, list of substitute courses, list of complement courses, 3 time parameters
# (designed to model schedule preferences; all set to completely off here), and finally the budget

print(true_student_list[2])

# Must agree with the number_of_popular of the instance loaded above (load_instance default: 9).
number_of_popular = 9

# these are the noise profiles that were used in our paper and matched as closely as possible the noise profile in the Budish and Kessler 2022 paper.
if number_of_popular == 9:
    noise_parameter_dictionary = {'noisy_forget_base': 0.5, 'noisy_forget_adjustments': 0.48, 'noisy_base_std': 23, 'noisy_adj_std': 0.2}

elif number_of_popular == 6:
    noise_parameter_dictionary = {'noisy_forget_base': 0.5, 'noisy_forget_adjustments': 0.4825, 'noisy_base_std': 17, 'noisy_adj_std': 0.2}

else:
    # Fail loudly here instead of with a NameError on noise_parameter_dictionary below.
    raise ValueError(f"No calibrated noise profile for number_of_popular = {number_of_popular}")

# The noisy reports must come from the SAME students whose true preferences are used below
# (true_student_list). true_student_lists_all_instances only holds the instance set produced
# by the last create_all_instances call (the additive-preferences one). Indexing it here would
# pair each noisy report with a different student's true preferences.
noisy_student_list = create_noisy_model_student_list(student_list = true_student_list, model_type = 'PairwiseAdjustmentsNoisy',
                                                    seed = 1213, model_param_dictionary = noise_parameter_dictionary)

price_vector = np.array([0.2 for _ in range(25)])  # A price vector for the individual courses
credit_units = np.array([1 for _ in range(25)])  # the credit units of individual courses. For this problem instance, all courses are worth 1 credit unit, and each student wants up to 5 Credit Units

# Demands with respect to their noisy reports in the mechanism
total_demand, individual_demands = calculate_total_demand(prices = np.array([0.2 for _ in range(25)]), student_profiles = noisy_student_list, course_timetable = [[] for  _ in range(5)],
                            credit_units = np.array([1 for _ in range(25)]), model_type = 'PairwiseAdjustmentsNoisy', credit_units_per_student = 5, return_individual_demands = True)

# Demands with respect to their true reports (note: these preferences are incompatible with the CM GUI language)
total_demand_true, individual_demands_true = calculate_total_demand(prices = np.array([0.2 for _ in range(25)]), student_profiles = true_student_list, course_timetable = [[] for  _ in range(5)],
                        credit_units = np.array([1 for _ in range(25)]), model_type = 'True', credit_units_per_student = 5, return_individual_demands = True)


student_number = 7

demand_noisy = calculate_single_student_demand(prices = np.array([0.2 for _ in range(25)]),
    student_profile = noisy_student_list[student_number],
    course_timetable = timetable,  # the course timetable affects the feasibility constraints, as students cannot take courses that are scheduled at the same time
    credit_units = [1 for i in range(25)],  # the credit units associated with each course
    model_type = 'PairwiseAdjustmentsNoisy',  # this is the model type of the GUI reporting language
    credit_units_per_student = 5,   # by changing this the maximum number of credit units that the student can take changes, and thus so does the maximum number of courses
    budget = 1.3  # with this parameter you can change the budget of the student.
    )

demand_true = calculate_single_student_demand(prices = np.array([0.2 for _ in range(25)]),
    student_profile = true_student_list[student_number],
    course_timetable = timetable,
    credit_units = [1 for i in range(25)],
    model_type = 'True',   # this is the model type of the True student preferences
    credit_units_per_student = 5,
    budget = 1.3)

# Only the first student is reported (the slice [0:1] limits the loop to index 0).
for i in range(len(individual_demands[0:1])):
    value_bundle_noisy_preferences = calculate_true_bundle_value(bundle = individual_demands[i], student_preferences = true_student_list[i], timetable = [[] for  _ in range(5)], make_monotone = True)
    value_bundle_true_preferences = calculate_true_bundle_value(bundle = individual_demands_true[i], student_preferences = true_student_list[i], timetable = [[] for  _ in range(5)], make_monotone = True)

    print(f'For student {i}, value for the optimal bundle {value_bundle_true_preferences} and the one she got under noisy reports: {value_bundle_noisy_preferences}')
    

## Print the course information
# Assume we have a function to get course information
def get_course_info(course_id, num_courses = 25, credit_units_per_course = 1):
    """Return display information for one course.

    Courses in the generated instances are anonymous. Course ``i`` is named
    ``"Course i"``, and every course is worth the same number of credit units.

    :param course_id: integer course index in ``[0, num_courses)``.
    :param num_courses: number of courses in the instance (25 in this script).
    :param credit_units_per_course: credit units reported for every course.
    :return: dict with keys ``"name"``, ``"credit_units"`` and ``"id"``. The
        ``"id"`` value is ``course_id`` exactly as passed in.
    :raises TypeError: if ``course_id`` is not an integer.
    :raises IndexError: if ``course_id`` is outside ``[0, num_courses)``.
    """
    # operator.index accepts Python and NumPy integers but rejects floats, as list indexing did.
    index = operator.index(course_id)
    # Negative ids are rejected explicitly; list indexing would silently wrap them to the last courses.
    if not 0 <= index < num_courses:
        raise IndexError(f"course_id {course_id} out of range for {num_courses} courses")
    return {
        "name": f"Course {index}",
        "credit_units": credit_units_per_course,
        "id": course_id,
    }
    
def print_timetable(timetable):
    """Print the timetable as a grid with one row per time slot and one column per day.

    :param timetable: list of days. Each day is a list of slots, and each slot
        is an iterable of course ids whose names come from ``get_course_info``.
        Days may have different numbers of slots; missing slots print as blanks.
        An empty timetable prints only the header.
    """
    print("\nCourse Timetable:")
    days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    print("Time Slot | " + " | ".join(days))
    print("-" * 60)

    # default=0 so an empty timetable prints just the header instead of raising ValueError.
    max_slots = max((len(day) for day in timetable), default=0)

    for slot in range(max_slots):
        slot_info = [f"Slot {slot:2d}"]
        for day in timetable:
            if slot < len(day):
                names = ", ".join(get_course_info(c)['name'] for c in day[slot])
                slot_info.append(f"{names:10s}")
            else:
                # Pad days that have fewer slots so the columns stay aligned.
                slot_info.append(" " * 10)
        print(" | ".join(slot_info))

def generate_course_info_string(courses, timetable):
    """Describe each course's name, credit units and meeting times in one string.

    :param courses: sequence with one entry per course. Only ``len(courses)``
        is used; course ids are ``0 .. len(courses) - 1``.
    :param timetable: list of days (indexed into Monday..Friday). Each day is a
        list of slots, and each slot is a collection of course ids scheduled then.
    :return: the per-course descriptions joined with ``"; "`` and terminated by ``"."``.
    """
    days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    course_info_list = []
    for course_id in range(len(courses)):
        course = get_course_info(course_id)
        schedule = []
        for day_index, day in enumerate(timetable):
            for slot_index, slot in enumerate(day):
                if course_id in slot:
                    schedule.append(f"{days[day_index]} at time slot {slot_index}")
        prefix = f"Course {course['id']} ({course['name']}) with {course['credit_units']} credit units"
        if schedule:
            course_info_list.append(f"{prefix} is scheduled on {', '.join(schedule)}")
        else:
            # Without this branch an unscheduled course would read "... is scheduled on " with nothing after it.
            course_info_list.append(f"{prefix} is not scheduled")

    return "; ".join(course_info_list) + "."


## generate the ordinal preference list

def generate_preference_ordering(additive_prefs):
    """
    Build a ranked-order preference string from a student's base course preferences.

    Courses are ordered from the highest to the lowest base value. Ties keep
    ascending course index, because the sort is stable.

    :param additive_prefs: sequence (e.g. numpy array) of base course preferences
    :return: string such as ``"course 2 > course 0 > course 1"``
    """
    values = list(additive_prefs)
    ranked_ids = sorted(range(len(values)), key=lambda idx: values[idx], reverse=True)
    return " > ".join(f'course {idx}' for idx in ranked_ids)

def list_to_pref(items):
    """Format items as a preference ordering, e.g. ``['A', 'B', 'C']`` -> ``"A > B > C"``.

    Each item is converted with ``str``. Integer course ids, such as those
    returned by ``generate_preference_list``, can therefore be formatted
    directly; string items are unchanged.
    """
    return " > ".join(str(item) for item in items)

def generate_preference_list(additive_prefs):
    """Return course indices ordered from most to least preferred.

    :param additive_prefs: sequence (e.g. numpy array) of base course preferences
    :return: list of integer course indices, highest base value first. Ties
        keep ascending index order because the sort is stable.
    """
    values = list(additive_prefs)
    order = list(range(len(values)))
    order.sort(key=values.__getitem__, reverse=True)
    return order

def print_student_info(student, student_number, timetable):
    """Print a readable summary of one student's base course preferences.

    ``student`` must unpack into exactly seven fields: additive preferences,
    substitutes, complements, overload penalty, time-gap penalty, free-day
    marginal values and budget. Only the additive preferences are used here.

    :return: tuple ``(preference_ordering, true_preferences)``. The first is the
        ``"course i > course j > ..."`` string; the second is the list of
        integer course ids ordered from most to least preferred.
    """
    (additive_prefs, _substitutes, _complements, _overload_penalty,
     _timegap_penalty, _free_days_marginal_values, _budget) = student

    print(f"Information for Student {student_number}:")

    info_text = generate_course_info_string(additive_prefs, timetable)
    print("\nCourse Information:")
    print(info_text)

    print("\n1. Base Course Preferences:")
    print("   These values represent the student's basic preference for each course, regardless of other factors.")
    for course_id, base_value in enumerate(additive_prefs):
        course_name = get_course_info(course_id)['name']
        print(f"   {course_name} (ID: {course_id}): {base_value:.2f}")

    print("\n   Ranked-order preference list:")
    ranked_ids = generate_preference_list(additive_prefs)
    ordering_text = generate_preference_ordering(additive_prefs)
    print(f"   {ordering_text}")
    print("   This list shows the courses ordered from most preferred to least preferred based on the base preferences.")

    return ordering_text, ranked_ids
    
    
# Inspect one student's true preferences and build the two strings fed to the LLM below.
student_number = 1
student = true_student_list[student_number]
# print_student_info prints the course information and base preferences. It returns the
# "course i > ..." ordering string and the matching list of integer course ids.
preference_ordering, true_preferences = print_student_info(student, student_number, timetable)

# student[0] is the additive (base) preference vector; generate_course_info_string only uses its length.
# NOTE(review): print_student_info already printed this same string, so it appears twice in the output.
course_info_string = generate_course_info_string(student[0], timetable)
print(course_info_string)

## Convert to text
# Ask the agent to generate a paragraph describing the preferences
def generate_preference_paragraph(preference_ordering, course_info_string):
    """Ask the model for a short paragraph describing a student's course preferences.

    :param preference_ordering: ranking string such as ``"course 3 > course 0 > ..."``.
    :param course_info_string: description of every course and its meeting times.
    :return: the edsl selection holding the ``"pref_paragraph"`` answer, which is
        also printed in rich format.
    """
    prompt = f"""Imagine a MBA student is describing their preference of the courses to someone who will mkae course matching for them. 
        Generate a paragraph of text that try to capture this person's preferences for the courses (timeslot-wise), without mentioning the specific course by name. 
        It is OK to mention dislikes and negative opinions.
        Their preferences are: { preference_ordering } and the timetable for all these courses is {course_info_string}. Slot 0-3 is in the morning. Slot 4-7 is in the afternoon. Slot 8-11 is at night. REMEMBER to describe as perfect as possible. Limit your description within 200 words.
        """
    question = QuestionFreeText(question_text = prompt, question_name = "pref_paragraph")

    paragraph = question.run().select("pref_paragraph")
    paragraph.print(format = "rich")
    return paragraph

def infer_preferences_from_paragraph(pref_paragraph, course_info_string):
    """Ask the model to reconstruct a student's course ranking from their description.

    :param pref_paragraph: the student's description, either as plain text or as
        the edsl selection returned by ``generate_preference_paragraph``.
    :param course_info_string: description of every course and its meeting times.
    :return: the edsl selection holding the ``"inferred_preference"`` answer, which
        is also printed in rich format.
    """
    if isinstance(pref_paragraph, str):
        paragraph_text = pref_paragraph
    elif callable(getattr(pref_paragraph, "to_list", None)):
        # An edsl selection would otherwise be interpolated via str(), container
        # wrapper included; embed only the answer text. Assumes to_list() yields
        # the answer values (one per result) — TODO confirm against the edsl version used.
        paragraph_text = " ".join(str(answer) for answer in pref_paragraph.to_list())
    else:
        paragraph_text = str(pref_paragraph)

    q_infer = QuestionList(
                question_text = f"""A MBA student described their preferences over a set of courses as { paragraph_text }.
                    The collection of courses is: '{ course_info_string }'.
                    Based on what they said, what do you guess is their rank-ordering of the courses? 
                    """, 
                question_name = "inferred_preference")

    results_inference = q_infer.run()
    infer = results_inference.select("inferred_preference")
    infer.print(format = "rich")

    return infer


courses = [f"course {i}" for i in range(25)]

pref_paragraph = generate_preference_paragraph(preference_ordering, course_info_string)
inferred_preferences = infer_preferences_from_paragraph(pref_paragraph, course_info_string)

# true_preferences holds integer course ids. inferred_preferences is the edsl selection
# wrapping the model's answer list, whose items are free text (presumably like "course 3").
# Comparing the two directly can never match an item. So extract the answer list and parse
# the course id out of each item; items without a number are dropped.
# Assumes a single model/agent run, i.e. one answer in the selection — TODO confirm.
inferred_answers = inferred_preferences.to_list()
inferred_items = (inferred_answers[0] if inferred_answers else None) or []
inferred_course_ids = []
for item in inferred_items:
    id_match = re.search(r"\d+", str(item))
    if id_match is not None:
        inferred_course_ids.append(int(id_match.group()))

distance = levenshtein_distance(true_preferences, inferred_course_ids)
print(f"Levenshtein distance between true and inferred preferences: {distance}")

N = 1000
average_random = sum(random_distance() for _ in range(N)) / N
print(f"Average distance for random guessing: {average_random:.2f}")