from preference_generator import generate_time_problem_instance, calculate_true_bundle_value
import numpy as np
from edsl import QuestionList, ScenarioList, QuestionFreeText
import random
from typing import List, Tuple
import random
import numpy as np
from itertools import combinations

# Build one time-preference problem instance: 100 students, 25 courses, and
# four student types drawn with equal probability (seeded with 42).
true_student_list, realized_student_types, capacities, timetable = (
    generate_time_problem_instance(
        number_of_students=100,
        number_of_courses=25,
        supply_ratio=1.25,
        capacity_deviation=0,
        student_types=["no_overload", "free_days", "few_timegaps", "balanced"],
        type_probabilities=[0.25] * 4,
        seed=42,
    )
)

## Course information helpers
# Placeholder: real course metadata is not available yet, so each course is
# named "Course <id>" and worth 1 credit unit.
def get_course_info(course_id):
    """Return placeholder metadata for one course.

    Real course data is not available yet, so every course is named
    ``"Course <id>"`` and is worth 1 credit unit.

    Args:
        course_id: non-negative integer course index.

    Returns:
        dict with keys ``"name"``, ``"credit_units"`` and ``"id"``.

    Raises:
        ValueError: if ``course_id`` is negative.
    """
    # Negative ids used to wrap around Python list indexing (-1 -> "Course 24");
    # reject them instead of silently returning another course's data.
    if course_id < 0:
        raise ValueError(f"course_id must be non-negative, got {course_id}")
    # Computed directly instead of indexing freshly built 25-element lists on
    # every call, which also lifts the hard-coded limit of 25 courses.
    return {
        "name": f"Course {course_id}",
        "credit_units": 1,
        "id": course_id,
    }
    
def timetable_string(timetable):
    """Render the timetable as an aligned plain-text table.

    ``timetable`` is indexed as ``timetable[day][slot]``, where each slot is an
    iterable of course ids whose names come from ``get_course_info``. Days are
    labelled Monday-Friday (``Day <n>`` beyond the fifth); days with fewer
    slots than the longest day get blank cells.

    Returns:
        The table as a string: a blank line, a title line, a header row, a
        dashed separator, then one row per time slot, each ending in a newline.
    """
    day_names = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    # One header per day actually present, so the header matches the columns.
    headers = ["Time Slot"] + [
        day_names[d] if d < len(day_names) else f"Day {d + 1}"
        for d in range(len(timetable))
    ]

    # default=0 lets an empty timetable render as a header-only table instead
    # of raising ValueError from max() on an empty sequence.
    max_slots = max((len(day) for day in timetable), default=0)

    rows = []
    for slot in range(max_slots):
        row = [f"Slot {slot:2d}"]
        for day in timetable:
            if slot < len(day):
                row.append(", ".join(get_course_info(c)["name"] for c in day[slot]))
            else:
                row.append("")
        rows.append(row)

    # Size each column to its widest cell (day columns at least 10 wide) so
    # the header, separator and slot rows line up.
    widths = [
        max(len(r[col]) for r in [headers] + rows) for col in range(len(headers))
    ]
    widths[1:] = [max(w, 10) for w in widths[1:]]

    def format_row(cells):
        # Left-align every cell to its column width.
        return " | ".join(cell.ljust(width) for cell, width in zip(cells, widths))

    header_line = format_row(headers)
    lines = [header_line, "-" * len(header_line)] + [format_row(r) for r in rows]
    return "\nCourse Timetable:\n" + "".join(line + "\n" for line in lines)

def generate_course_info_string(courses, timetable):
    """Describe when every course meets, as one semicolon-separated sentence.

    Only ``len(courses)`` is used: courses are identified by index
    ``0 .. len(courses) - 1`` and described via ``get_course_info``.

    Args:
        courses: sized collection with one entry per course.
        timetable: indexed as ``timetable[day][slot]``; each slot is a
            collection of course ids.

    Returns:
        A string such as "Course 0 (Course 0) with 1 credit units is scheduled
        on Monday at time slot 2; ...", terminated by ".". A course that never
        appears in the timetable is reported as "is not scheduled".
    """
    day_names = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    num_courses = len(courses)

    # Single pass over the timetable (instead of one full rescan per course),
    # collecting each course's meetings in day/slot order.
    meetings = {course_id: [] for course_id in range(num_courses)}
    for day_index, day in enumerate(timetable):
        if day_index < len(day_names):
            day_name = day_names[day_index]
        else:
            day_name = f"Day {day_index + 1}"
        for slot_index, slot in enumerate(day):
            # set() so a course listed twice in one slot is reported once.
            for course_id in set(slot):
                if course_id in meetings:
                    meetings[course_id].append(f"{day_name} at time slot {slot_index}")

    descriptions = []
    for course_id in range(num_courses):
        course = get_course_info(course_id)
        slots = meetings[course_id]
        # Avoid a dangling "is scheduled on " for courses with no meetings.
        when = f"scheduled on {', '.join(slots)}" if slots else "not scheduled"
        descriptions.append(
            f"Course {course['id']} ({course['name']}) with "
            f"{course['credit_units']} credit units is {when}"
        )

    return "; ".join(descriptions) + "."



def print_student_info(student, timetable):
    """Format a student's time-preference parameters as readable text.

    Despite its name, this returns the text rather than printing it. The
    ``timetable`` argument is accepted but not used.

    Args:
        student: 7-field record; fields 3-6 are read as overload penalty,
            timegap penalty, free-day marginal values and budget.
        timetable: unused.

    Returns:
        Multi-line string ending in a newline. At most five free-day values
        are listed; any beyond the fifth are omitted.
    """
    _, _, _, overload, gap_penalty, free_day_values, budget = student
    ordinals = ("First", "Second", "Third", "Fourth", "Fifth")

    lines = [
        "Student Time Preferences:",
        f"- Overload penalty: {overload:.2f} per credit unit above 5",
        f"- Timegap penalty: {gap_penalty:.2f} per hour of gap between classes",
        "- Free days preference:",
    ]
    # zip stops after the fifth ordinal, so extra values are dropped.
    lines.extend(
        f"  * {ordinal} free day: {value:.2f}"
        for ordinal, value in zip(ordinals, free_day_values)
    )
    lines.append(f"- Budget: {budget:.2f}")
    return "\n".join(lines) + "\n"

def textify_schedule(course_int_list):
    """Render a schedule's course ids as "(Course a,Course b,...)"."""
    labels = ",".join(f"Course {cid}" for cid in course_int_list)
    return f"({labels})"
# "easy mode"
def answer_pairwise_question_exact_preferences(student, schedule_pair, timetable):
    """Choose between two schedules for ``student`` (placeholder, no LLM yet).

    In this "easy mode" the prompt states the student's exact preference
    parameters. The prompt is assembled, but no model is queried: the answer
    is currently hard-coded to 1 (the second schedule).

    Args:
        student: student record, formatted with ``print_student_info``.
        schedule_pair: indexable pair ``(first, second)`` of course-id
            sequences, formatted with ``textify_schedule``.
        timetable: timetable formatted with ``timetable_string``.

    Returns:
        0 to choose the first schedule, 1 to choose the second.
    """
    # Multi-line sections are stripped and placed on their own lines so the
    # prompt carries no source-code indentation and no stray "." lines after
    # text that already ends in a newline.
    question = (
        "A student is trying to choose a course schedule.\n"
        "The courses are scheduled according to the following timetable:\n"
        f"{timetable_string(timetable).strip()}\n"
        "\n"
        "The student has the following preferences:\n"
        f"{print_student_info(student, timetable).strip()}\n"
        "\n"
        "The student is trying to choose between the following two schedules: "
        f"{textify_schedule(schedule_pair[0])} and "
        f"{textify_schedule(schedule_pair[1])}.\n"
        "Which schedule should the student choose? Simply answer 0 for the "
        "first schedule and 1 for the second schedule."
    )
    # TODO: send `question` to the LLM and parse its reply. Until then the
    # prompt is built but unused, so errors in the formatting helpers still
    # surface here.
    reply = "1"
    return int(reply.strip())


# Pick the first student, show the generated timetable, then smoke-test the
# placeholder answerer on course 0 alone vs. course 1 alone (result discarded).
student_to_test = true_student_list[0]

print(timetable_string(timetable))

answer_pairwise_question_exact_preferences(
    student_to_test, [[0], [1]], timetable
)


def generate_random_schedule_pairs(num_courses, num_pairs=10, max_courses_per_schedule=5, rng=None):
    """Draw random pairs of disjoint, non-empty course schedules.

    Each schedule is a list of distinct course ids from ``range(num_courses)``
    holding between 1 and ``max_courses_per_schedule`` courses; the two
    schedules of a pair share no course.

    Args:
        num_courses: number of courses to draw from (ids ``0..num_courses-1``).
        num_pairs: number of pairs to generate.
        max_courses_per_schedule: upper bound on the size of each schedule.
        rng: optional ``random.Random`` instance for reproducible draws;
            defaults to the module-level ``random`` functions.

    Returns:
        list of ``(schedule1, schedule2)`` tuples.

    Raises:
        ValueError: if pairs are requested but there are fewer than two
            courses, or ``max_courses_per_schedule`` is below 1.
    """
    if num_pairs <= 0:
        return []
    if num_courses < 2:
        raise ValueError(
            f"need at least 2 courses to form a disjoint pair, got {num_courses}"
        )
    if max_courses_per_schedule < 1:
        raise ValueError(
            f"max_courses_per_schedule must be >= 1, got {max_courses_per_schedule}"
        )
    rng = random if rng is None else rng

    all_courses = list(range(num_courses))
    # Cap the first schedule so at least one course remains for the second;
    # previously small course counts made random.sample / randint raise.
    max_first = min(max_courses_per_schedule, num_courses - 1)

    schedule_pairs = []
    for _ in range(num_pairs):
        schedule1 = rng.sample(all_courses, rng.randint(1, max_first))
        taken = set(schedule1)
        remaining_courses = [c for c in all_courses if c not in taken]
        max_second = min(max_courses_per_schedule, len(remaining_courses))
        schedule2 = rng.sample(remaining_courses, rng.randint(1, max_second))
        schedule_pairs.append((schedule1, schedule2))

    return schedule_pairs

def test_pairwise_answer_accuracy(student, timetable, num_tests=100, verbose=False):
    """Measure how often the pairwise answerer picks the truly better schedule.

    Draws ``num_tests`` random disjoint schedule pairs, asks
    ``answer_pairwise_question_exact_preferences`` to choose, and compares the
    choice with the ground truth from ``calculate_true_bundle_value``.

    Args:
        student: student record; its ``additive_prefs`` length is taken as
            the number of courses.
        timetable: passed through to the answerer and the valuation.
        num_tests: number of random pairs to evaluate.
        verbose: if True, print the two true values for every pair.

    Returns:
        ``(correct_answers, num_tests)``.
    """
    num_courses = len(student.additive_prefs)
    schedule_pairs = generate_random_schedule_pairs(num_courses, num_tests)

    def to_bundle(schedule):
        # 0/1 indicator vector over all courses, the bundle encoding passed
        # to calculate_true_bundle_value.
        chosen = set(schedule)
        return np.array([1 if i in chosen else 0 for i in range(num_courses)])

    correct_answers = 0
    for schedule1, schedule2 in schedule_pairs:
        model_answer = answer_pairwise_question_exact_preferences(
            student, (schedule1, schedule2), timetable
        )

        value1 = calculate_true_bundle_value(
            to_bundle(schedule1), student, timetable, ignore_timegaps=False
        )
        value2 = calculate_true_bundle_value(
            to_bundle(schedule2), student, timetable, ignore_timegaps=False
        )
        if verbose:
            print(value1, value2)

        if value1 == value2:
            # Equally valued schedules: either answer is correct. Crediting
            # only answer 1 would bias the score toward always answering 1.
            correct_answers += 1
        elif model_answer == (1 if value2 > value1 else 0):
            correct_answers += 1

    return correct_answers, num_tests

# Evaluate the placeholder answerer against true bundle values for the first
# student and report the hit rate.
student_to_test = true_student_list[0]
correct_answers, total_tests = test_pairwise_answer_accuracy(
    student_to_test, timetable
)
accuracy = correct_answers / total_tests
print(
    "Accuracy of answer_pairwise_question_exact_preferences: "
    f"{correct_answers}/{total_tests} ({accuracy:.2%})"
)

print('warning!!!! this code is INCOMPLETE!!!!!!!!!!!!!')