from preference_generator import generate_time_problem_instance, calculate_true_bundle_value
import numpy as np
from edsl import QuestionList, ScenarioList, QuestionFreeText
import random
from typing import List
import re
from openai import OpenAI
# Shared OpenAI client used by get_answer_from_llm; created once when this module is loaded.
# NOTE(review): no credentials are passed here, so they presumably come from the
# environment — confirm before running in a new setup.
client = OpenAI()

def get_answer_from_llm(prompt):
    """Send ``prompt`` to the chat model as a single user message and return the reply.

    The prompt is converted with ``str()`` before being sent. The reply text is
    printed to stdout and then returned unchanged.
    """
    user_message = {
        "role": "user",
        "content": str(prompt),
    }
    completion = client.chat.completions.create(
        model="o1-preview",
        messages=[user_message],
    )
    reply_text = completion.choices[0].message.content
    print(reply_text)
    return reply_text


# Generate a single time preference instance when the module is loaded:
# 100 students, 25 courses, four student types drawn with equal probability (1/4 each),
# and a fixed seed value.
# NOTE(review): the exact meaning of supply_ratio and capacity_deviation is defined in
# preference_generator — confirm there before changing these values.
true_student_list, realized_student_types, capacities, timetable = generate_time_problem_instance(
    number_of_students=100,
    number_of_courses=25,
    supply_ratio=1.25,
    capacity_deviation=0,
    student_types=['no_overload', 'free_days', 'few_timegaps', 'balanced'],
    type_probabilities=[1/4 for _ in range(4)],
    seed=42
)


def textify_schedule(course_int_list):
    """Render course ids as a parenthesised, comma-separated list of "Course <id>" labels."""
    labels = ("Course {}".format(course_id) for course_id in course_int_list)
    return "({})".format(",".join(labels))


## Course information lookup
# Synthetic course metadata (name, credit units, id) used by the formatting helpers below.
def get_course_info(course_id, number_of_courses=25):
    """Return display metadata for a single course.

    Args:
        course_id: Zero-based integer id of the course.
        number_of_courses: Size of the course catalogue; valid ids are
            ``0 .. number_of_courses - 1``. Defaults to 25, the number of
            courses used for the instance generated in this module.

    Returns:
        A dict with keys ``"name"`` (``"Course <id>"``), ``"credit_units"``
        (always 1) and ``"id"`` (the id that was passed in).

    Raises:
        IndexError: If ``course_id`` is outside ``0 .. number_of_courses - 1``.
            Negative ids are rejected as well instead of silently wrapping
            around to a course at the end of the catalogue.
    """
    if not 0 <= course_id < number_of_courses:
        raise IndexError(
            f"course_id {course_id} is out of range for {number_of_courses} courses"
        )
    return {
        "name": f"Course {course_id}",
        "credit_units": 1,  # every course is worth one credit unit
        "id": course_id,
    }

def timetable_string(timetable):
    """Format the timetable as a text grid with one row per time slot.

    Args:
        timetable: Sequence of days; each day is a sequence of slots and each
            slot is an iterable of course ids.

    Returns:
        A string starting with a blank line and a "Course Timetable:" title,
        followed by a header row naming Monday to Friday, a dashed separator
        and one "Slot NN | ..." row per slot index. Each day cell is padded to
        at least 10 characters; days with fewer slots get blank cells. An
        empty timetable yields only the title, header and separator.
    """
    days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    lines = [
        "\nCourse Timetable:",
        "Time Slot | " + " | ".join(days),
        "-" * 60,
    ]

    # default=0 keeps an empty timetable from raising ValueError in max().
    max_slots = max((len(day) for day in timetable), default=0)

    for slot in range(max_slots):
        cells = [f"Slot {slot:2d}"]
        for day in timetable:
            if slot < len(day):
                courses = ", ".join(get_course_info(c)['name'] for c in day[slot])
                cells.append(f"{courses:10s}")
            else:
                # This day has no such slot: keep the column aligned.
                cells.append(" " * 10)
        lines.append(" | ".join(cells))

    return "\n".join(lines) + "\n"

def generate_course_info_string(courses, timetable):
    """Describe, course by course, when each of ``courses`` meets during the week.

    For every course id a clause of the form
    "Course <id> (<name>) with <n> credit units is scheduled on <day> at <slots>, ..."
    is produced; the clauses are joined with "; " and the result ends with ".".
    """
    weekday_names = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]

    def describe_slots(slot_indices):
        """Collapse runs of consecutive slot indices into "time slot N" / "time slots A-B"."""
        if not slot_indices:
            return ""
        runs = [[slot_indices[0], slot_indices[0]]]
        for index in slot_indices[1:]:
            if index == runs[-1][1] + 1:
                # Still contiguous: extend the current run.
                runs[-1][1] = index
            else:
                runs.append([index, index])
        phrases = []
        for first, last in runs:
            if first == last:
                phrases.append(f"time slot {first}")
            else:
                phrases.append(f"time slots {first}-{last}")
        return ", ".join(phrases)

    clauses = []
    for course_id in courses:
        info = get_course_info(course_id)
        meetings = []
        for day_number, day in enumerate(timetable):
            hits = [position for position, slot in enumerate(day) if course_id in slot]
            if not hits:
                continue
            meetings.append(f"{weekday_names[day_number]} at {describe_slots(hits)}")
        when = ", ".join(meetings)
        clauses.append(
            f"Course {info['id']} ({info['name']}) with {info['credit_units']} credit units is scheduled on {when}"
        )

    return "; ".join(clauses) + "."

def generate_schedule_overview(courses, timetable):
    """Summarise, day by day, when the courses in ``courses`` take place.

    Each day that contains at least one of the courses contributes
    "<Day>: <name> (slots <ranges>), ..."; the day entries are joined with "; "
    and the result ends with ".". Days without any selected course are omitted.
    """
    weekday_names = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]

    def compress(slot_indices):
        """Collapse runs of consecutive slot indices into "N" / "A-B" pieces."""
        if not slot_indices:
            return ""
        runs = [[slot_indices[0], slot_indices[0]]]
        for index in slot_indices[1:]:
            if index == runs[-1][1] + 1:
                runs[-1][1] = index
            else:
                runs.append([index, index])
        return ", ".join(
            f"{first}" if first == last else f"{first}-{last}"
            for first, last in runs
        )

    day_summaries = []
    for day_number, day in enumerate(timetable):
        # (slot index, course info) for every selected course occurrence on this day.
        occurrences = [
            (slot_number, get_course_info(course_id))
            for slot_number, slot in enumerate(day)
            for course_id in slot
            if course_id in courses
        ]
        occurrences.sort(key=lambda pair: pair[0])

        # Collect the slot indices of each course, keeping first-seen course order.
        slots_by_course = {}
        for slot_number, info in occurrences:
            slots_by_course.setdefault(info['id'], []).append(slot_number)

        entries = [
            f"{get_course_info(course_id)['name']} (slots {compress(slot_list)})"
            for course_id, slot_list in slots_by_course.items()
        ]

        if entries:
            day_summaries.append(f"{weekday_names[day_number]}: " + ", ".join(entries))

    return "; ".join(day_summaries) + "."

def print_student_info(student, timetable):
    """Build (and return, not print) a text description of a student's time preferences.

    ``student`` must unpack into exactly seven items; the last four are used as
    the overload penalty, the time-gap penalty, the marginal values of free
    days and the budget. Only the first five free-day values are listed.
    ``timetable`` is accepted but not used.
    """
    (
        _unused_a,
        _unused_b,
        _unused_c,
        overload_penalty,
        timegap_penalty,
        free_days_marginal_values,
        budget,
    ) = student

    ordinals = ("First", "Second", "Third", "Fourth", "Fifth")

    pieces = [
        "Student Time Preferences: ",
        f"- Overload penalty: Penalty based on how far in hours the day with the most courses exceeds the average hours per day for all the other days. {overload_penalty:.2f} per increased difference in hour\n",
        f"- Timegap penalty: {timegap_penalty:.2f} per hour of gap between classes. For example, if a day has courses in slot 4-5, slot 7-8 and slot 10-11 the gap is (7-5)+(10-8)=4. Note: if a day only has one courses, there's no time gap. \n",
        "- Free days preference (student prefers more free days):\n",
    ]
    # zip stops after the fifth ordinal, so any further values are skipped.
    pieces.extend(
        f"  * {ordinal} free day: {value:.2f}\n"
        for ordinal, value in zip(ordinals, free_days_marginal_values)
    )
    pieces.append(f"- Budget: {budget:.2f}\n")

    return "".join(pieces)

def extract_numeric_answer(answer):
    """Extract the chosen schedule (1 or 2) from the last line of an LLM reply.

    The last non-blank line is lower-cased and scanned for choice tokens: a
    standalone digit ``1``/``2`` (not part of a longer number or a decimal) or
    one of the whole words ``one``/``first``/``two``/``second``. The last such
    token on the line decides the answer, so trailing punctuation, markdown
    markers or a few extra words after the choice do not break extraction.

    Args:
        answer: The full reply text. ``None`` or an empty/blank string is
            treated as "no answer".

    Returns:
        1 or 2 when a choice token is found, otherwise -1 (after printing an
        error message together with the inspected line).
    """
    word_to_num = {
        "one": 1,
        "first": 1,
        "two": 2,
        "second": 2,
    }

    text = (answer or "").strip()
    # An empty reply has no last line; fall through to the error path instead of crashing.
    last_line = text.splitlines()[-1].lower() if text else ""

    # Digits must not touch other digits or sit inside a decimal such as "2.5" / "0.2";
    # words must match as whole words so that e.g. "done" does not count as "one".
    tokens = re.findall(
        r'(?<![\d.])[12](?!\.?\d)|\b(?:one|first|two|second)\b',
        last_line,
    )
    if tokens:
        choice = tokens[-1]
        return int(choice) if choice.isdigit() else word_to_num[choice]

    print("ERROR!!! Could not extract number from the last line:")
    print(last_line)
    return -1

# "easy mode"
def answer_pairwise_question_exact_preferences(student, schedule_pair, timetable, print_prompt=False):
    """Ask the LLM which of two schedules a student should choose, given exact preferences.

    The prompt contains the student's preference description (from
    ``print_student_info``), both schedules as course lists and as weekly
    overviews, and a worked answer template whose last line names the chosen
    schedule.

    Args:
        student: Student record accepted by ``print_student_info``.
        schedule_pair: Two course-id lists; index 0 is "Schedule 1" and
            index 1 is "Schedule 2".
        timetable: Timetable passed through to the formatting helpers.
        print_prompt: When true, print debugging information about the two
            options before querying the model.

    Returns:
        The value ``extract_numeric_answer`` derives from the model's reply.
    """
    question = f"""A student is trying to choose a course schedule between two options.
   A student prefers more free days in schedule (free days are days with no courses, from Monday to Friday), less time gaps between courses in the same day (the time slots are closer in number), and less overload in schedule (a balanced schedule between days). However, each student differs in how they value each component:
    In particular, a student has the following preferences: {print_student_info(student, timetable)}.
    The first schedule has {textify_schedule(schedule_pair[0])}. Under the first schedule, the schedule for the student for the week is given by {generate_schedule_overview(schedule_pair[0], timetable)}.
    The second schedule has {textify_schedule(schedule_pair[1])}. Under the second schedule, the schedule for the student for the week is given by {generate_schedule_overview(schedule_pair[1], timetable)}
    Which schedule should the student choose? Answer in the following format where X is the information you fill using the knowledge given above, note that the last line contains only one number corresponding to your answer, also make sure that you accurately remember the difference in courses and schedule between Schedule 1 and Schedule 2, given above:

    Example:
      Schedule 1:
      - The schedule is given by ...
      - List what days are not free where there's a course that day: X. Therefore, excluding these days, the student has X free days. This is worth X in utility for the student.
      - The day with the most courses is X (X courses) with a total of X hours, and the average courseload excluding that day with the most courses is X hours. The overload penalty for each difference in hour is X. Therefore the contribution to utility is X. (This should be a negative number since it's a penalty)
      - The days with more than one courses are X. Of these days, the time gaps between the courses are X hours so the total time gaps is X hours. The penalty for each hour gap is X, therefore the contribution to utility is X. (This should be a negative number since it's a penalty)
      - The total score for Schedule 1 is therefore: X

      Schedule 2:
      - The schedule is given by ...
      - List what days are not free where there's a course that day: X. Therefore, excluding these days, the student has X free days. This is worth X in utility for the student.
      - The day with the most courses is X (X courses) with a total of X hours, and the average courseload excluding that day with the most courses is X hours. The overload penalty for each difference in hour is X. Therefore the contribution to utility is X. (This should be a negative number since it's a penalty)
      - The days with more than one courses are X. Of these days, the time gaps between the courses are X hours so the total time gaps is X hours. The penalty for each hour gap is X, therefore the contribution to utility is X. (This should be a negative number since it's a penalty)
      - The total score for Schedule 2 is therefore: X

      Answer:
      The utility for Schedule 1 is X.
      The utility for Schedule 2 is X.
      Schedule X is higher in utility.
      Therefore, the student should choose Schedule X
    """

    # NOTE: the template deliberately asks for step-by-step reasoning so the model's
    # working can be inspected; adding "No explanations" to the prompt would turn it off.
    if print_prompt:
        print("Option 1: ", generate_course_info_string(schedule_pair[0], timetable))
        print("Option 2: ", generate_course_info_string(schedule_pair[1], timetable))
        print(generate_schedule_overview(schedule_pair[0], timetable))

    # Pass the plain question text. get_answer_from_llm applies str() to its argument,
    # so wrapping the question in a [{"role": ..., "content": ...}] list would send the
    # list's Python repr (brackets, quotes, escaped newlines) to the model instead.
    answer = get_answer_from_llm(question)
    return extract_numeric_answer(answer)



# First student of the instance generated above.
# NOTE(review): presumably the subject of later experiments — its use is not visible here.
student_to_test = true_student_list[0]
