{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install gurobipy\n",
    "!pip install edsl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import login\n",
    "login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import torch\n",
    "\n",
    "# Replace for local model once in cluster/downloaded locally\n",
    "model_id = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "# model_id = \"mlx-community/mathstral-7B-v0.1-fp16\"\n",
    "\n",
    "pipeline = transformers.pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model_id,\n",
    "    model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "    device_map=\"auto\",\n",
    "    # tokenizer=\"mistralai/Mathstral-7B-v0.1\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_answer_from_llm(pipeline, prompt):\n",
    "    outputs = pipeline(\n",
    "    prompt,\n",
    "    max_new_tokens=2048,\n",
    "    )\n",
    "    print(outputs[0][\"generated_text\"][-1][\"content\"], \"LLM Output\")\n",
    "    return outputs[0][\"generated_text\"][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "%cd Course-Match-Preference-Simulator/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from preference_generator import generate_time_problem_instance, calculate_true_bundle_value\n",
    "import numpy as np\n",
    "from edsl import QuestionList, ScenarioList, QuestionFreeText\n",
    "import random\n",
    "from typing import List\n",
    "import re\n",
    "\n",
    "\n",
    "# Generate a single time preference instance\n",
    "true_student_list, realized_student_types, capacities, timetable = generate_time_problem_instance(\n",
    "    number_of_students=100,\n",
    "    number_of_courses=25,\n",
    "    supply_ratio=1.25,\n",
    "    capacity_deviation=0,\n",
    "    student_types=['no_overload', 'free_days', 'few_timegaps', 'balanced'],\n",
    "    type_probabilities=[1/4 for _ in range(4)],\n",
    "    seed=42\n",
    ")\n",
    "\n",
    "def textify_schedule(course_int_list):\n",
    "    return \"(\" + \",\".join([f\"Course {course_id}\" for course_id in course_int_list]) + \")\"\n",
    "\n",
    "\n",
    "## Print the course information\n",
    "# Assume we have a function to get course information\n",
    "def get_course_info(course_id):\n",
    "    course_names = [f\"Course {i}\" for i in range(25)]  # Assuming 25 courses\n",
    "    credit_units = [1 for _ in range(25)]  # Assuming all courses are 1 credit unit\n",
    "    return {\n",
    "        \"name\": course_names[course_id],\n",
    "        \"credit_units\": credit_units[course_id],\n",
    "        \"id\": course_id\n",
    "    }\n",
    "\n",
    "def timetable_string(timetable):\n",
    "    output = \"\\nCourse Timetable:\\n\"\n",
    "    days = [\"Monday\", \"Tuesday\", \"Wednesday\", \"Thursday\", \"Friday\"]\n",
    "    output += \"Time Slot | \" + \" | \".join(days) + \"\\n\"\n",
    "    output += \"-\" * 60 + \"\\n\"\n",
    "\n",
    "    max_slots = max(len(day) for day in timetable)\n",
    "\n",
    "    for slot in range(max_slots):\n",
    "        slot_info = [f\"Slot {slot:2d}\"]\n",
    "        for day in timetable:\n",
    "            if slot < len(day):\n",
    "                courses = \", \".join([get_course_info(c)['name'] for c in day[slot]])\n",
    "                slot_info.append(f\"{courses:10s}\")\n",
    "            else:\n",
    "                slot_info.append(\" \" * 10)\n",
    "        output += \" | \".join(slot_info) + \"\\n\"\n",
    "\n",
    "    return output\n",
    "\n",
    "def generate_course_info_string(courses, timetable):\n",
    "    days = [\"Monday\", \"Tuesday\", \"Wednesday\", \"Thursday\", \"Friday\"]\n",
    "    course_info_list = []\n",
    "\n",
    "    def group_slots(slots):\n",
    "        \"\"\" Helper function to group contiguous time slots \"\"\"\n",
    "        if not slots:\n",
    "            return \"\"\n",
    "        ranges = []\n",
    "        start = prev = slots[0]\n",
    "        for slot in slots[1:]:\n",
    "            if slot == prev + 1:\n",
    "                prev = slot\n",
    "            else:\n",
    "                ranges.append((start, prev))\n",
    "                start = prev = slot\n",
    "        ranges.append((start, prev))\n",
    "        return \", \".join(f\"time slot {start}\" if start == end else f\"time slots {start}-{end}\" for start, end in ranges)\n",
    "\n",
    "    for course_id in courses:\n",
    "        course = get_course_info(course_id)\n",
    "        schedule = []\n",
    "        for day_index, day in enumerate(timetable):\n",
    "            day_slots = [slot_index for slot_index, slot in enumerate(day) if course_id in slot]\n",
    "            if day_slots:\n",
    "                schedule.append(f\"{days[day_index]} at {group_slots(day_slots)}\")\n",
    "        schedule_str = \", \".join(schedule)\n",
    "        course_info = f\"Course {course['id']} ({course['name']}) with {course['credit_units']} credit units is scheduled on {schedule_str}\"\n",
    "        course_info_list.append(course_info)\n",
    "\n",
    "    return \"; \".join(course_info_list) + \".\"\n",
    "\n",
    "def generate_schedule_overview(courses, timetable):\n",
    "    days = [\"Monday\", \"Tuesday\", \"Wednesday\", \"Thursday\", \"Friday\"]\n",
    "    overview = []\n",
    "\n",
    "    def group_slots(slots):\n",
    "        \"\"\" Helper function to group contiguous time slots \"\"\"\n",
    "        if not slots:\n",
    "            return \"\"\n",
    "        ranges = []\n",
    "        start = prev = slots[0]\n",
    "        for slot in slots[1:]:\n",
    "            if slot == prev + 1:\n",
    "                prev = slot\n",
    "            else:\n",
    "                ranges.append((start, prev))\n",
    "                start = prev = slot\n",
    "        ranges.append((start, prev))\n",
    "        return \", \".join(f\"{start}\" if start == end else f\"{start}-{end}\" for start, end in ranges)\n",
    "\n",
    "    for day_index, day in enumerate(timetable):\n",
    "        daily_schedule = []\n",
    "        for slot_index, slot in enumerate(day):\n",
    "            for course_id in slot:\n",
    "                if course_id in courses:  # Only consider courses in the provided list\n",
    "                    course = get_course_info(course_id)\n",
    "                    daily_schedule.append((slot_index, course))\n",
    "\n",
    "        # Sort by time slots\n",
    "        daily_schedule.sort(key=lambda x: x[0])\n",
    "\n",
    "        # Group by course and slots\n",
    "        course_slots = {}\n",
    "        for slot_index, course in daily_schedule:\n",
    "            if course['id'] not in course_slots:\n",
    "                course_slots[course['id']] = []\n",
    "            course_slots[course['id']].append(slot_index)\n",
    "\n",
    "        # Format the day's overview\n",
    "        day_overview = []\n",
    "        for course_id, slots in course_slots.items():\n",
    "            course = get_course_info(course_id)\n",
    "            grouped_slots = group_slots(slots)\n",
    "            day_overview.append(f\"{course['name']} (slots {grouped_slots})\")\n",
    "\n",
    "        if day_overview:\n",
    "            overview.append(f\"{days[day_index]}: \" + \", \".join(day_overview))\n",
    "\n",
    "    return \"; \".join(overview) + \".\"\n",
    "\n",
    "def print_student_info(student, timetable):\n",
    "    _, _, _, overload_penalty, timegap_penalty, free_days_marginal_values, budget = student\n",
    "\n",
    "    output = \"Student Time Preferences: \"\n",
    "\n",
    "    # Overload penalty\n",
    "    output += f\"- Overload penalty: Penalty based on how far in hours the day with the most courses exceeds the average hours per day for all the other days. {overload_penalty:.2f} per increased difference in hour\\n\"\n",
    "\n",
    "    # Timegap penalty\n",
    "    output += f\"- Timegap penalty: {timegap_penalty:.2f} per hour of gap between classes. For example, if a day has courses in slot 4-5, slot 7-8 and slot 10-11 the gap is (7-5)+(10-8)=4. Note: if a day only has one courses, there's no time gap. \\n\"\n",
    "\n",
    "    # Free days preference\n",
    "    output += \"- Free days preference (student prefers more free days):\\n\"\n",
    "    for i, value in enumerate(free_days_marginal_values):\n",
    "        if i == 0:\n",
    "            output += f\"  * First free day: {value:.2f}\\n\"\n",
    "        elif i == 1:\n",
    "            output += f\"  * Second free day: {value:.2f}\\n\"\n",
    "        elif i == 2:\n",
    "            output += f\"  * Third free day: {value:.2f}\\n\"\n",
    "        elif i == 3:\n",
    "            output += f\"  * Fourth free day: {value:.2f}\\n\"\n",
    "        elif i == 4:\n",
    "            output += f\"  * Fifth free day: {value:.2f}\\n\"\n",
    "\n",
    "    # Budget\n",
    "    output += f\"- Budget: {budget:.2f}\\n\"\n",
    "\n",
    "    return output\n",
    "\n",
    "def extract_numeric_answer(answer):\n",
    "    # Get the last line of the answer\n",
    "    last_line = answer.strip().splitlines()[-1].lower()\n",
    "    last_line = last_line[-4:]\n",
    "\n",
    "    # Use a dictionary to map words to numbers\n",
    "    word_to_num = {\n",
    "        \"one\": 1,\n",
    "        \"first\": 1,\n",
    "        \"two\": 2,\n",
    "        \"second\": 2,\n",
    "    }\n",
    "\n",
    "    # Check for numeric values directly in the last line\n",
    "    match = re.search(r'\\b(1|2)\\b', last_line)\n",
    "    if match:\n",
    "        return int(match.group(0))\n",
    "\n",
    "    # Check for word-based answers in the last line\n",
    "    for word, num in word_to_num.items():\n",
    "        if word in last_line:\n",
    "            return num\n",
    "\n",
    "    # Handle cases where there might be a period or other characters at the end\n",
    "    try:\n",
    "        return int(last_line[-1])\n",
    "    except ValueError:\n",
    "        print(\"ERROR!!! Could not extract number from the last line:\")\n",
    "        print(last_line)\n",
    "        return -1\n",
    "\n",
    "# \"easy mode\"\n",
    "def answer_pairwise_question_exact_preferences(student, schedule_pair, timetable, print_prompt=False):\n",
    "    # not yet querying the LLM, this is dummy code for now\n",
    "    # not yet querying the LLM\n",
    "    question = f\"\"\"A student is trying to choose a course schedule between two options.\n",
    "   A student prefers more free days in schedule (free days are days with no courses, from Monday to Friday), less time gaps between courses in the same day (the time slots are closer in number), and less overload in schedule (a balanced schedule between days). However, each student differs in how they value each component:\n",
    "    In particular, a student has the following preferences: {print_student_info(student, timetable)}.\n",
    "    The first schedule has {textify_schedule(schedule_pair[0])}. Under the first schedule, the schedule for the student for the week is given by {generate_schedule_overview(schedule_pair[0], timetable)}.\n",
    "    The second schedule has {textify_schedule(schedule_pair[1])}. Under the second schedule, the schedule for the student for the week is given by {generate_schedule_overview(schedule_pair[1], timetable)}\n",
    "    Which schedule should the student choose? Answer in the following format where X is the information you fill using the knowledge given above, note that the last line contains only one number corresponding to your answer, also make sure that you accurately remember the difference in courses and schedule between Schedule 1 and Schedule 2, given above:\n",
    "\n",
    "    Example:\n",
    "      Schedule 1:\n",
    "      - The schedule is given by ...\n",
    "      - List what days are not free where there's a course that day: X. Therefore, excluding these days, the student has X free days. This is worth X in utility for the student.\n",
    "      - The day with the most courses is X (X courses) with a total of X hours, and the average courseload excluding that day with the most courses is X hours. The overload penalty for each difference in hour is X. Therefore the contribution to utility is X. (This should be a negative number since it's a penalty)\n",
    "      - The days with more than one courses are X. Of these days, the time gaps between the courses are X hours so the total time gaps is X hours. The penalty for each hour gap is X, therefore the contribution to utility is X. (This should be a negative number since it's a penalty)\n",
    "      - The total score for Schedule 1 is therefore: X\n",
    "\n",
    "      Schedule 2:\n",
    "      - The schedule is given by ...\n",
    "      - List what days are not free where there's a course that day: X. Therefore, excluding these days, the student has X free days. This is worth X in utility for the student.\n",
    "      - The day with the most courses is X (X courses) with a total of X hours, and the average courseload excluding that day with the most courses is X hours. The overload penalty for each difference in hour is X. Therefore the contribution to utility is X. (This should be a negative number since it's a penalty)\n",
    "      - The days with more than one courses are X. Of these days, the time gaps between the courses are X hours so the total time gaps is X hours. The penalty for each hour gap is X, therefore the contribution to utility is X. (This should be a negative number since it's a penalty)\n",
    "      - The total score for Schedule 2 is therefore: X\n",
    "\n",
    "      Answer:\n",
    "      The utility for Schedule 1 is X.\n",
    "      The utility for Schedule 2 is X.\n",
    "      Schedule X is higher in utility.\n",
    "      Therefore, the student should choose Schedule X\n",
    "    \"\"\"\n",
    "\n",
    "    # Add \"No explanations\" to get rid of reasoning. Reasoning is not turned off now since we want to understand how good this is.\"\n",
    "    prompt = [{\"role\": \"user\", \"content\": f\"{question}\"}]\n",
    "    if print_prompt:\n",
    "      print(prompt)\n",
    "      # print(\"Option 1: \", generate_course_info_string(schedule_pair[0], timetable))\n",
    "      # print(\"Option 2: \", generate_course_info_string(schedule_pair[1], timetable))\n",
    "      print(generate_schedule_overview(schedule_pair[0], timetable))\n",
    "\n",
    "    answer = get_answer_from_llm(pipeline, prompt)\n",
    "\n",
    "    content = answer[\"content\"]\n",
    "    return extract_numeric_answer(content)\n",
    "\n",
    "\n",
    "\n",
    "student_to_test = true_student_list[0]\n",
    "\n",
    "# print(true_student_list)\n",
    "\n",
    "# Example\n",
    "answer = answer_pairwise_question_exact_preferences(student_to_test, ([15,17,9,6],[15,3,22,1]), timetable, print_prompt=True)\n",
    "answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def courses_overlap(course1, course2, timetable):\n",
    "    for week in timetable:\n",
    "        for day in week:\n",
    "            if course1 in day and course2 in day:\n",
    "                return True\n",
    "    return False\n",
    "\n",
    "def is_valid_schedule(schedule, timetable):\n",
    "    for i in range(len(schedule)):\n",
    "        for j in range(i + 1, len(schedule)):\n",
    "            if courses_overlap(schedule[i], schedule[j], timetable):\n",
    "                return False\n",
    "    return True\n",
    "\n",
    "def generate_random_schedule_pairs(num_courses, num_pairs=10, courses_per_schedule=5, timetable=None):\n",
    "    all_courses = list(range(num_courses))\n",
    "    schedule_pairs = []\n",
    "\n",
    "    for _ in range(num_pairs):\n",
    "        # Generate the first schedule with valid non-overlapping courses\n",
    "        while True:\n",
    "            schedule1 = random.sample(all_courses, courses_per_schedule)\n",
    "            if is_valid_schedule(schedule1, timetable):\n",
    "                break\n",
    "\n",
    "\n",
    "        # Generate the second schedule with valid non-overlapping courses\n",
    "        while True:\n",
    "            schedule2 = random.sample(all_courses, courses_per_schedule)\n",
    "            if is_valid_schedule(schedule2, timetable):\n",
    "                break\n",
    "\n",
    "        schedule_pairs.append((schedule1, schedule2))\n",
    "\n",
    "    return schedule_pairs\n",
    "\n",
    "def test_pairwise_answer_accuracy(student, timetable, num_tests=50):\n",
    "    num_courses = len(student.additive_prefs)\n",
    "    schedule_pairs = generate_random_schedule_pairs(num_courses, num_tests, timetable = timetable)\n",
    "\n",
    "    correct_answers = 0\n",
    "    long_answers = 0\n",
    "\n",
    "    for schedule1, schedule2 in schedule_pairs:\n",
    "        model_answer = answer_pairwise_question_exact_preferences(student, (schedule1, schedule2), timetable)\n",
    "        # Calculate true values\n",
    "        value1 = calculate_true_bundle_value(np.array([1 if i in schedule1 else 0 for i in range(num_courses)]),\n",
    "                                             student, timetable, ignore_timegaps=False, make_monotone=False)\n",
    "        value2 = calculate_true_bundle_value(np.array([1 if i in schedule2 else 0 for i in range(num_courses)]),\n",
    "                                             student, timetable, ignore_timegaps=False, make_monotone=False)\n",
    "\n",
    "        # print(value1, value2)\n",
    "        # print(schedule1, schedule2, \"length\")\n",
    "        true_answer = 1 if value2 >= value1 else 0\n",
    "        print(\"True answer should be: \", true_answer+1)\n",
    "\n",
    "        if model_answer-1 == true_answer:\n",
    "            correct_answers += 1\n",
    "    return correct_answers, num_tests\n",
    "\n",
    "# Test the accuracy\n",
    "student_to_test = true_student_list[0]\n",
    "correct_answers, total_tests = test_pairwise_answer_accuracy(student_to_test, timetable)\n",
    "accuracy = correct_answers / total_tests\n",
    "print(f\"Accuracy of answer_pairwise_question_exact_preferences: {correct_answers}/{total_tests} ({accuracy:.2%})\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
