{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install gurobipy\n",
    "!pip install edsl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import torch\n",
    "\n",
    "# Replace for local model once in cluster/downloaded locally\n",
    "model_id = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "# model_id = \"mlx-community/mathstral-7B-v0.1-fp16\"\n",
    "\n",
    "pipeline = transformers.pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model_id,\n",
    "    model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "    device_map=\"auto\",\n",
    "    # tokenizer=\"mistralai/Mathstral-7B-v0.1\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_answer_from_llm(pipeline, prompt, supress_output=False):\n",
    "    outputs = pipeline(\n",
    "    prompt,\n",
    "    max_new_tokens=2048,\n",
    "    )\n",
    "    if not supress_output:\n",
    "      print(outputs[0][\"generated_text\"][-1][\"content\"], \"LLM Output\")\n",
    "    return outputs[0][\"generated_text\"][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "%cd Course-Match-Preference-Simulator/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preference_generator import generate_time_problem_instance, calculate_true_bundle_value\n",
    "import numpy as np\n",
    "from edsl import QuestionList, ScenarioList, QuestionFreeText\n",
    "import random\n",
    "from typing import List\n",
    "import re\n",
    "\n",
    "\n",
    "# Generate a single time preference instance\n",
    "true_student_list, realized_student_types, capacities, timetable = generate_time_problem_instance(\n",
    "    number_of_students=100,\n",
    "    number_of_courses=25,\n",
    "    supply_ratio=1.25,\n",
    "    capacity_deviation=0,\n",
    "    student_types=['no_overload', 'free_days', 'few_timegaps', 'balanced'],\n",
    "    type_probabilities=[1/4 for _ in range(4)],\n",
    "    seed=42\n",
    ")\n",
    "\n",
    "def textify_schedule(course_int_list):\n",
    "    return \"(\" + \",\".join([f\"Course {course_id}\" for course_id in course_int_list]) + \")\"\n",
    "\n",
    "\n",
    "## Print the course information\n",
    "# Assume we have a function to get course information\n",
    "def get_course_info(course_id):\n",
    "    course_names = [f\"Course {i}\" for i in range(25)]  # Assuming 25 courses\n",
    "    credit_units = [1 for _ in range(25)]  # Assuming all courses are 1 credit unit\n",
    "    return {\n",
    "        \"name\": course_names[course_id],\n",
    "        \"credit_units\": credit_units[course_id],\n",
    "        \"id\": course_id\n",
    "    }\n",
    "\n",
    "def timetable_string(timetable):\n",
    "    output = \"\\nCourse Timetable:\\n\"\n",
    "    days = [\"Monday\", \"Tuesday\", \"Wednesday\", \"Thursday\", \"Friday\"]\n",
    "    output += \"Time Slot | \" + \" | \".join(days) + \"\\n\"\n",
    "    output += \"-\" * 60 + \"\\n\"\n",
    "\n",
    "    max_slots = max(len(day) for day in timetable)\n",
    "\n",
    "    for slot in range(max_slots):\n",
    "        slot_info = [f\"Slot {slot:2d}\"]\n",
    "        for day in timetable:\n",
    "            if slot < len(day):\n",
    "                courses = \", \".join([get_course_info(c)['name'] for c in day[slot]])\n",
    "                slot_info.append(f\"{courses:10s}\")\n",
    "            else:\n",
    "                slot_info.append(\" \" * 10)\n",
    "        output += \" | \".join(slot_info) + \"\\n\"\n",
    "\n",
    "    return output\n",
    "\n",
    "def generate_course_info_string(courses, timetable):\n",
    "    days = [\"Monday\", \"Tuesday\", \"Wednesday\", \"Thursday\", \"Friday\"]\n",
    "    course_info_list = []\n",
    "\n",
    "    def group_slots(slots):\n",
    "        \"\"\" Helper function to group contiguous time slots \"\"\"\n",
    "        if not slots:\n",
    "            return \"\"\n",
    "        ranges = []\n",
    "        start = prev = slots[0]\n",
    "        for slot in slots[1:]:\n",
    "            if slot == prev + 1:\n",
    "                prev = slot\n",
    "            else:\n",
    "                ranges.append((start, prev))\n",
    "                start = prev = slot\n",
    "        ranges.append((start, prev))\n",
    "        return \", \".join(f\"time slot {start}\" if start == end else f\"time slots {start}-{end}\" for start, end in ranges)\n",
    "\n",
    "    for course_id in courses:\n",
    "        course = get_course_info(course_id)\n",
    "        schedule = []\n",
    "        for day_index, day in enumerate(timetable):\n",
    "            day_slots = [slot_index for slot_index, slot in enumerate(day) if course_id in slot]\n",
    "            if day_slots:\n",
    "                schedule.append(f\"{days[day_index]} at {group_slots(day_slots)}\")\n",
    "        schedule_str = \", \".join(schedule)\n",
    "        course_info = f\"Course {course['id']} ({course['name']}) with {course['credit_units']} credit units is scheduled on {schedule_str}\"\n",
    "        course_info_list.append(course_info)\n",
    "\n",
    "    return \"; \".join(course_info_list) + \".\"\n",
    "\n",
    "def generate_schedule_overview(courses, timetable):\n",
    "    days = [\"Monday\", \"Tuesday\", \"Wednesday\", \"Thursday\", \"Friday\"]\n",
    "    overview = []\n",
    "\n",
    "    def group_slots(slots):\n",
    "        \"\"\" Helper function to group contiguous time slots \"\"\"\n",
    "        if not slots:\n",
    "            return \"\"\n",
    "        ranges = []\n",
    "        start = prev = slots[0]\n",
    "        for slot in slots[1:]:\n",
    "            if slot == prev + 1:\n",
    "                prev = slot\n",
    "            else:\n",
    "                ranges.append((start, prev))\n",
    "                start = prev = slot\n",
    "        ranges.append((start, prev))\n",
    "        return \", \".join(f\"{start}\" if start == end else f\"{start}-{end}\" for start, end in ranges)\n",
    "\n",
    "    for day_index, day in enumerate(timetable):\n",
    "        daily_schedule = []\n",
    "        for slot_index, slot in enumerate(day):\n",
    "            for course_id in slot:\n",
    "                if course_id in courses:  # Only consider courses in the provided list\n",
    "                    course = get_course_info(course_id)\n",
    "                    daily_schedule.append((slot_index, course))\n",
    "\n",
    "        # Sort by time slots\n",
    "        daily_schedule.sort(key=lambda x: x[0])\n",
    "\n",
    "        # Group by course and slots\n",
    "        course_slots = {}\n",
    "        for slot_index, course in daily_schedule:\n",
    "            if course['id'] not in course_slots:\n",
    "                course_slots[course['id']] = []\n",
    "            course_slots[course['id']].append(slot_index)\n",
    "\n",
    "        # Format the day's overview\n",
    "        day_overview = []\n",
    "        for course_id, slots in course_slots.items():\n",
    "            course = get_course_info(course_id)\n",
    "            grouped_slots = group_slots(slots)\n",
    "            day_overview.append(f\"{course['name']} (slots {grouped_slots})\")\n",
    "\n",
    "        if day_overview:\n",
    "            overview.append(f\"{days[day_index]}: \" + \", \".join(day_overview))\n",
    "\n",
    "    return \"; \".join(overview) + \".\"\n",
    "\n",
    "def print_student_info(student, timetable):\n",
    "    _, _, _, overload_penalty, timegap_penalty, free_days_marginal_values, budget = student\n",
    "\n",
    "    output = \"Student Time Preferences: \"\n",
    "\n",
    "    # Overload penalty\n",
    "    output += f\"- Overload penalty: Penalty based on how far in hours the day with the most courses exceeds the average hours per day for all the other days. {overload_penalty:.2f} per increased difference in hour\\n\"\n",
    "\n",
    "    # Timegap penalty\n",
    "    output += f\"- Timegap penalty: {timegap_penalty:.2f} per hour of gap between classes. For example, if a day has courses in slot 4-5, slot 7-8 and slot 10-11 the gap is (7-5)+(10-8)=4. Note: if a day only has one courses, there's no time gap. \\n\"\n",
    "\n",
    "    # Free days preference\n",
    "    output += \"- Free days preference (student prefers more free days):\\n\"\n",
    "    for i, value in enumerate(free_days_marginal_values):\n",
    "        if i == 0:\n",
    "            output += f\"  * First free day: {value:.2f}\\n\"\n",
    "        elif i == 1:\n",
    "            output += f\"  * Second free day: {value:.2f}\\n\"\n",
    "        elif i == 2:\n",
    "            output += f\"  * Third free day: {value:.2f}\\n\"\n",
    "        elif i == 3:\n",
    "            output += f\"  * Fourth free day: {value:.2f}\\n\"\n",
    "        elif i == 4:\n",
    "            output += f\"  * Fifth free day: {value:.2f}\\n\"\n",
    "\n",
    "    # Budget\n",
    "    output += f\"- Budget: {budget:.2f}\\n\"\n",
    "\n",
    "    return output\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def extract_numeric_answer(answer):\n",
    "    # Get the last line of the answer\n",
    "    last_line = answer.strip().splitlines()[-1].lower()\n",
    "    last_line = last_line[-4:]\n",
    "\n",
    "    # Use a dictionary to map words to numbers\n",
    "    word_to_num = {\n",
    "        \"one\": 1,\n",
    "        \"first\": 1,\n",
    "        \"two\": 2,\n",
    "        \"second\": 2,\n",
    "    }\n",
    "\n",
    "    # Check for numeric values directly in the last line\n",
    "    match = re.search(r'\\b(1|2)\\b', last_line)\n",
    "    if match:\n",
    "        return int(match.group(0))\n",
    "\n",
    "    # Check for word-based answers in the last line\n",
    "    for word, num in word_to_num.items():\n",
    "        if word in last_line:\n",
    "            return num\n",
    "\n",
    "    # Handle cases where there might be a period or other characters at the end\n",
    "    try:\n",
    "        return int(last_line[-1])\n",
    "    except ValueError:\n",
    "        print(\"ERROR!!! Could not extract number from the last line:\")\n",
    "        print(last_line)\n",
    "        return -1\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print_prompt = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preference_generator import generate_time_problem_instance, calculate_true_bundle_value\n",
    "import numpy as np\n",
    "from edsl import QuestionList, ScenarioList, QuestionFreeText\n",
    "import random\n",
    "from typing import List\n",
    "import re\n",
    "\n",
    "\n",
    "# Generate a single time preference instance\n",
    "true_student_list, realized_student_types, capacities, timetable = generate_time_problem_instance(\n",
    "    number_of_students=100,\n",
    "    number_of_courses=25,\n",
    "    supply_ratio=1.25,\n",
    "    capacity_deviation=0,\n",
    "    student_types=['no_overload', 'free_days', 'few_timegaps', 'balanced'],\n",
    "    type_probabilities=[1/4 for _ in range(4)],\n",
    "    seed=42\n",
    ")\n",
    "def get_student_text_preferences(student, timetable, print_prompt=False):\n",
    "    question = f'''A student is trying to describe his preferences on a timetable of courses. All students will take 5 courses in total.\n",
    "    The student has the following preferences: {print_student_info(student, timetable)}.'''+'''\n",
    "    As obvious from above, for each different component of the course schedule, the student cares about it differently. Since there's 5 courses in total, the students preferences given above should highlight his priorities in the final schedule.\n",
    "    Due to inherent difficulty of scheduling and the inherent tradeoffs included in the different points above, the student won't get all his wishes satisfied. For example, if he has a lot of free days, the course schedule cannot possibly be balanced therefore will be quite overloaded. His goal is to communicate his preferences as a few sentence outlining his priorities so that the courses given to him best fit his true numeric preferences.\n",
    "    Please describe the student's preferences concisely and qualitatively in a few sentences. By qualitatively, you're not allowed to repeat any numbers given to you above. Instead, you should express them in terms of sentences describing the different things he pays attention to and what is most, second most, etc. important to him.\n",
    "    You're allowed to use numeric scales (e.g. something is xx times more important) if necessary, but if possible, try to mimic what a real student will do to describe his preferences given the true numeric values. Return his report in the following format: {report: FILL}.\n",
    "    '''\n",
    "    # Add \"No explanations\" to get rid of reasoning. Reasoning is not turned off now since we want to understand how good this is.\"\n",
    "    prompt = [{\"role\": \"user\", \"content\": f\"{question}\"}]\n",
    "    if print_prompt:\n",
    "      print(prompt)\n",
    "    answer = get_answer_from_llm(pipeline, prompt, True)\n",
    "\n",
    "    content = answer[\"content\"]\n",
    "    return content\n",
    "\n",
    "student_to_test = true_student_list[1]\n",
    "get_student_text_preferences(student_to_test, timetable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"easy mode\"\n",
    "def answer_pairwise_question_exact_preferences(student, schedule_pair, timetable, report, print_prompt=False):\n",
    "    # not yet querying the LLM, this is dummy code for now\n",
    "    # not yet querying the LLM\n",
    "    question = f\"\"\"A student is trying to choose a course schedule between two options.\n",
    "   A student prefers more free days in schedule (free days are days with no courses, from Monday to Friday), less time gaps between courses in the same day (the time slots are closer in number), and less overload in schedule (a balanced schedule between days). However, each student differs in how they value each component:\n",
    "    In particular, a student has the following report of his preferences: {report}.\n",
    "    The first schedule has {textify_schedule(schedule_pair[0])}. Under the first schedule, the schedule for the student for the week is given by {generate_schedule_overview(schedule_pair[0], timetable)}.\n",
    "    The second schedule has {textify_schedule(schedule_pair[1])}. Under the second schedule, the schedule for the student for the week is given by {generate_schedule_overview(schedule_pair[1], timetable)}\n",
    "    Which schedule should the student choose? Answer in the following format and repeat your answer at the last line of your response in the form of 'The student should choose Schedule X':\n",
    "\n",
    "    Example:\n",
    "      Schedule 1:\n",
    "      - The schedule is given by ...\n",
    "      - List what days are not free where there's a course that day: X. Therefore, excluding these days, the student has X free days.\n",
    "      - The day with the most courses is X (X courses) with a total of X hours, and the average courseload excluding that day with the most courses is X hours.\n",
    "      - The days with more than one courses are X. Of these days, the time gaps between the courses are X hours so the total time gaps is X hours.\n",
    "      - Other property of the schedule that relates to the student's reported preferences are: X.\n",
    "\n",
    "      Schedule 2:\n",
    "      - The schedule is given by ...\n",
    "      - List what days are not free where there's a course that day: X. Therefore, excluding these days, the student has X free days.\n",
    "      - The day with the most courses is X (X courses) with a total of X hours, and the average courseload excluding that day with the most courses is X hours.\n",
    "      - The days with more than one courses are X. Of these days, the time gaps between the courses are X hours so the total time gaps is X hours.\n",
    "      - Other property of the schedule that relates to the student's reported preferences are: X.\n",
    "\n",
    "      The student's preferences is given by ... His priorities, in order, are ...\n",
    "\n",
    "      Based on this reported preferences and the analysis of the two schedules, Schedule 1's pro's and con's are X.\n",
    "      Schedule 2's pro's and con's are X.\n",
    "      Comparing the two, Schedule 1 has the advantage of X and Schedule 2 has the advantage of X. X is more important from the student's report of preferences.\n",
    "\n",
    "      Answer:\n",
    "      The student should choose Schedule X\n",
    "    \"\"\"\n",
    "\n",
    "    # Add \"No explanations\" to get rid of reasoning. Reasoning is not turned off now since we want to understand how good this is.\"\n",
    "    prompt = [{\"role\": \"user\", \"content\": f\"{question}\"}]\n",
    "    if print_prompt:\n",
    "      print(prompt)\n",
    "      # print(\"Option 1: \", generate_course_info_string(schedule_pair[0], timetable))\n",
    "      # print(\"Option 2: \", generate_course_info_string(schedule_pair[1], timetable))\n",
    "      print(generate_schedule_overview(schedule_pair[0], timetable))\n",
    "\n",
    "    answer = get_answer_from_llm(pipeline, prompt)\n",
    "\n",
    "    content = answer[\"content\"]\n",
    "    return extract_numeric_answer(content)\n",
    "\n",
    "\n",
    "\n",
    "# student_to_test = true_student_list[0]\n",
    "\n",
    "# # print(true_student_list)\n",
    "\n",
    "# # Example\n",
    "# answer = answer_pairwise_question_exact_preferences(student_to_test, ([15,17,9,6],[15,3,22,1]), timetable, print_prompt=True)\n",
    "# answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def courses_overlap(course1, course2, timetable):\n",
    "    for week in timetable:\n",
    "        for day in week:\n",
    "            if course1 in day and course2 in day:\n",
    "                return True\n",
    "    return False\n",
    "\n",
    "def is_valid_schedule(schedule, timetable):\n",
    "    for i in range(len(schedule)):\n",
    "        for j in range(i + 1, len(schedule)):\n",
    "            if courses_overlap(schedule[i], schedule[j], timetable):\n",
    "                return False\n",
    "    return True\n",
    "\n",
    "def generate_random_schedule_pairs(num_courses, num_pairs=10, courses_per_schedule=5, timetable=None):\n",
    "    all_courses = list(range(num_courses))\n",
    "    schedule_pairs = []\n",
    "\n",
    "    for _ in range(num_pairs):\n",
    "        # Generate the first schedule with valid non-overlapping courses\n",
    "        while True:\n",
    "            schedule1 = random.sample(all_courses, courses_per_schedule)\n",
    "            if is_valid_schedule(schedule1, timetable):\n",
    "                break\n",
    "\n",
    "\n",
    "        # Generate the second schedule with valid non-overlapping courses\n",
    "        while True:\n",
    "            schedule2 = random.sample(all_courses, courses_per_schedule)\n",
    "            if is_valid_schedule(schedule2, timetable):\n",
    "                break\n",
    "\n",
    "        schedule_pairs.append((schedule1, schedule2))\n",
    "\n",
    "    return schedule_pairs\n",
    "\n",
    "def test_pairwise_answer_accuracy(student, timetable, report, num_tests=5):\n",
    "    num_courses = len(student.additive_prefs)\n",
    "    schedule_pairs = generate_random_schedule_pairs(num_courses, num_pairs=num_tests, timetable = timetable)\n",
    "    correct_answers = 0\n",
    "    long_answers = 0\n",
    "\n",
    "    for schedule1, schedule2 in schedule_pairs:\n",
    "        model_answer = answer_pairwise_question_exact_preferences(student, (schedule1, schedule2), timetable, report)\n",
    "        # Calculate true values\n",
    "        value1 = calculate_true_bundle_value(np.array([1 if i in schedule1 else 0 for i in range(num_courses)]),\n",
    "                                             student, timetable, ignore_timegaps=False, make_monotone=False)\n",
    "        value2 = calculate_true_bundle_value(np.array([1 if i in schedule2 else 0 for i in range(num_courses)]),\n",
    "                                             student, timetable, ignore_timegaps=False, make_monotone=False)\n",
    "\n",
    "        # print(value1, value2)\n",
    "        # print(schedule1, schedule2, \"length\")\n",
    "        true_answer = 1 if value2 >= value1 else 0\n",
    "        print(\"True answer should be: \", true_answer+1)\n",
    "\n",
    "        if model_answer-1 == true_answer:\n",
    "            correct_answers += 1\n",
    "    return correct_answers, num_tests\n",
    "\n",
    "# Test the accuracy\n",
    "student_to_test = true_student_list[0]\n",
    "acc = 0\n",
    "for student in true_student_list[:5]:\n",
    "  report = get_student_text_preferences(student, timetable)\n",
    "  # print(student, timetable, report)\n",
    "  correct_answers, total_tests = test_pairwise_answer_accuracy(student, timetable, report)\n",
    "  accuracy = correct_answers / total_tests\n",
    "  acc += correct_answers\n",
    "  print(f\"Accuracy of answer_pairwise_question_exact_preferences: {correct_answers}/{total_tests} ({accuracy:.2%})\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc/25"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
