import transformers
import torch
from preference_generator import generate_time_problem_instance, calculate_true_bundle_value
import numpy as np
from edsl import QuestionList, ScenarioList, QuestionFreeText
import random
from typing import List

# Placeholder model identifier: replace with the local model path once the
# model is available on the cluster or has been downloaded locally.
model_path = "..."

# Build the text-generation pipeline once, at module import time. The
# module-level name `pipeline` is the object later handed to
# get_answer_from_llm by answer_pairwise_question_exact_preferences.
# bfloat16 is requested through `model_kwargs`; device placement is left to
# the library via device_map="auto".
pipeline = transformers.pipeline(
    "text-generation",
    model=model_path,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

def get_answer_from_llm(pipeline, prompt, max_new_tokens=1024):
    """Run ``prompt`` through ``pipeline`` and return the last generated item.

    Args:
        pipeline: Callable invoked as ``pipeline(prompt, max_new_tokens=...)``.
            Its result must support ``result[0]["generated_text"][-1]``.
        prompt: Passed to ``pipeline`` unchanged. Callers in this file pass a
            chat-style list of ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: Generation budget forwarded to ``pipeline``. Defaults
            to 1024, the value that used to be hard-coded; callers expecting
            a one-token answer can pass a much smaller number.

    Returns:
        The last element of the first output's ``"generated_text"``. For a
        chat-style prompt this is presumably the assistant's reply message
        (callers read its ``"content"`` key) -- confirm against the pipeline
        in use.
    """
    outputs = pipeline(prompt, max_new_tokens=max_new_tokens)
    last_message = outputs[0]["generated_text"][-1]
    # Echo the raw reply so unexpected model output is visible in the logs.
    print(last_message)
    return last_message



# Generate a single time preference instance.
# The generator's four return values are unpacked here. In the code shown in
# this file only `true_student_list` and `timetable` are used afterwards.
# The four student types are given equal probability (1/4 each), and a fixed
# seed is passed to the generator.
true_student_list, realized_student_types, capacities, timetable = generate_time_problem_instance(
    number_of_students=100,
    number_of_courses=25,
    supply_ratio=1.25,
    capacity_deviation=0,
    student_types=['no_overload', 'free_days', 'few_timegaps', 'balanced'],
    type_probabilities=[1/4 for _ in range(4)],
    seed=42
)

def textify_schedule(course_int_list):
    """Render course ids as a parenthesised list, e.g. ``(Course 1,Course 2)``.

    Args:
        course_int_list: Iterable of course ids; each is formatted as
            ``Course <id>``.

    Returns:
        The labels joined by commas (no spaces) and wrapped in parentheses.
        An empty iterable yields ``"()"``.
    """
    labels = map("Course {}".format, course_int_list)
    joined = ",".join(labels)
    return f"({joined})"


## Course information helpers
# Placeholder lookup: course names and credit units are synthesized below rather than loaded from real course data.
def get_course_info(course_id):
    """Return placeholder metadata for one course.

    The catalogue is synthesized on every call: 25 courses named
    ``Course 0`` .. ``Course 24``, each worth 1 credit unit.

    Args:
        course_id: Index into the 25-entry catalogue. Standard list indexing
            applies, so negative values count from the end.

    Returns:
        A dict with keys ``"name"``, ``"credit_units"`` and ``"id"`` (the
        ``course_id`` exactly as passed in).

    Raises:
        IndexError: If ``course_id`` is outside the 25-entry catalogue.
    """
    total_courses = 25  # Assuming 25 courses
    names = ["Course " + str(index) for index in range(total_courses)]
    credits = [1] * total_courses  # Assuming all courses are 1 credit unit

    info = {"name": names[course_id]}
    info["credit_units"] = credits[course_id]
    info["id"] = course_id
    return info

def timetable_string(timetable):
    """Format a timetable as a plain-text grid with one row per time slot.

    Args:
        timetable: Sequence of days; each day is a sequence of slots and each
            slot is an iterable of course ids understood by
            ``get_course_info``.

    Returns:
        A multi-line string: a title, a header listing the five weekday
        names (always Monday-Friday, regardless of how many days
        ``timetable`` holds), a 60-character rule, then one row per slot
        index. Each cell holds the comma-separated course names for that day
        and slot, padded to at least 10 characters; days with fewer slots get
        a blank cell.

    Raises:
        ValueError: If ``timetable`` is empty (no slot count can be derived).
    """
    weekday_names = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    blank_cell = " " * 10

    rows = [
        "\nCourse Timetable:",
        "Time Slot | " + " | ".join(weekday_names),
        "-" * 60,
    ]

    # The longest day decides how many slot rows are printed.
    slot_count = max(map(len, timetable))

    for slot_index in range(slot_count):
        cells = [f"Slot {slot_index:2d}"]
        for day_slots in timetable:
            if slot_index >= len(day_slots):
                cells.append(blank_cell)
                continue
            names = ", ".join(
                [get_course_info(course)['name'] for course in day_slots[slot_index]]
            )
            cells.append(f"{names:10s}")
        rows.append(" | ".join(cells))

    return "\n".join(rows) + "\n"

def generate_course_info_string(courses, timetable):
    """Describe when each of the given courses meets, as a single sentence.

    Args:
        courses: Iterable of course ids understood by ``get_course_info``.
        timetable: Sequence of days; each day is a sequence of slots, and
            each slot is a container of course ids.

    Returns:
        One clause per course (id, name, credit units and every
        ``<weekday> at time slot <n>`` where it appears), joined by ``"; "``
        and terminated with a period. A course that appears nowhere gets an
        empty schedule list; an empty ``courses`` yields ``"."``.

    Raises:
        IndexError: If a course is found on a day index beyond the five
            named weekdays.
    """
    weekday_names = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
    descriptions = []

    for course_id in courses:
        details = get_course_info(course_id)
        # Every (day, slot) position in which this course is listed.
        meetings = [
            f"{weekday_names[day_number]} at time slot {slot_number}"
            for day_number, day_slots in enumerate(timetable)
            for slot_number, slot_courses in enumerate(day_slots)
            if course_id in slot_courses
        ]
        when = ", ".join(meetings)
        descriptions.append(
            f"Course {details['id']} ({details['name']}) with {details['credit_units']} credit units is scheduled on {when}"
        )

    return "; ".join(descriptions) + "."



def print_student_info(student, timetable):
    """Build a text summary of a student's time preferences.

    Despite the name, nothing is printed; the summary is returned.

    Args:
        student: Sequence of exactly seven items. The first three are
            ignored; the remaining four are, in order, the overload penalty,
            the timegap penalty, an iterable of free-day marginal values and
            the budget.
        timetable: Unused; accepted for signature compatibility.

    Returns:
        A newline-terminated, multi-line string listing the overload penalty,
        the timegap penalty, up to five free-day values (any further values
        are omitted) and the budget, each formatted with two decimals.

    Raises:
        ValueError: If ``student`` does not unpack into seven items.
    """
    _first, _second, _third, overload, gap, free_day_values, funds = student

    # Only the first five free days have a label; extra values are dropped.
    ordinals = ("First", "Second", "Third", "Fourth", "Fifth")

    lines = [
        "Student Time Preferences:",
        f"- Overload penalty: Penalty based on how far the day with the most courses exceeds the average. {overload:.2f} per increased credit unit",
        f"- Timegap penalty: {gap:.2f} per hour of gap between classes",
        "- Free days preference (student prefers more free days):",
    ]
    lines.extend(
        f"  * {ordinal} free day: {worth:.2f}"
        for ordinal, worth in zip(ordinals, free_day_values)
    )
    lines.append(f"- Budget: {funds:.2f}")

    return "\n".join(lines) + "\n"

# "easy mode"
def answer_pairwise_question_exact_preferences(student, schedule_pair, timetable):
    """Ask the LLM which of two schedules a student should choose.

    The prompt states the student's exact preference parameters and, for
    each schedule, when every course meets. The full timetable is not
    included in the prompt.

    Args:
        student: Student record accepted by ``print_student_info``.
        schedule_pair: Indexable holding the two schedules (lists of course
            ids) at positions 0 and 1.
        timetable: Timetable used to describe when each course meets.

    Returns:
        The integer parsed from the model's reply (the prompt asks for 0 for
        the first schedule or 1 for the second), or -1 if the reply is not a
        bare integer.
    """
    first_schedule = schedule_pair[0]
    second_schedule = schedule_pair[1]

    first_label = textify_schedule(first_schedule)
    second_label = textify_schedule(second_schedule)
    preference_text = print_student_info(student, timetable)
    first_details = generate_course_info_string(first_schedule, timetable)
    second_details = generate_course_info_string(second_schedule, timetable)

    # The closing "No explanations." is there to suppress model reasoning.
    question = f"""A student is trying to choose a course schedule between two options.
    The student is trying to choose between the following two schedules: {first_label} and {second_label}.
    A student prefers more free days in schedule (free days are days with no courses), less time gaps between courses in the same day (the time slots are closer in number), and less overload in schedule (a balanced schedule between days). However, each student differs in how they value each component:
    In particular, a student has the following preferences: {preference_text}. 
    The first schedule has {first_details}, the second schedule has {second_details}.
    Which schedule should the student choose? Simply answer 0 for the first schedule and 1 for the second schedule. No explanations.
    """
    chat = [{"role": "user", "content": question}]

    reply = get_answer_from_llm(pipeline, chat)
    reply_text = reply["content"]

    try:
        choice = int(reply_text)
    except ValueError:
        # Surface the offending reply, then signal "unparseable" with -1.
        print(reply_text)
        print("Did not output single number!!!")
        return -1
    return choice


# print(timetable_string(timetable))


# Should additionally check that no courses overlap
def generate_random_schedule_pairs(num_courses, num_pairs=10, courses_per_schedule=5, seed=None):
    """Draw random pairs of course schedules.

    Each schedule is an independent sample, without replacement, from
    ``range(num_courses)``. The two schedules of a pair may therefore share
    courses, and no check is made for time conflicts within a schedule.

    Args:
        num_courses: Number of available courses; ids are ``0..num_courses-1``.
        num_pairs: How many pairs to generate.
        courses_per_schedule: Number of distinct courses in each schedule.
        seed: Optional seed for reproducible pairs. When ``None`` (default)
            the module-level ``random`` state is used, exactly as before.

    Returns:
        A list of ``(schedule1, schedule2)`` tuples, each schedule being a
        list of course ids.

    Raises:
        ValueError: If ``courses_per_schedule`` exceeds ``num_courses``
            (raised by ``random.sample``).
    """
    # A private generator keeps seeded runs from disturbing global state.
    rng = random if seed is None else random.Random(seed)
    all_courses = list(range(num_courses))

    # Tuple elements evaluate left to right, so schedule1 is drawn first.
    return [
        (
            rng.sample(all_courses, courses_per_schedule),
            rng.sample(all_courses, courses_per_schedule),
        )
        for _ in range(num_pairs)
    ]

def test_pairwise_answer_accuracy(student, timetable, num_tests=100):
    """Count how often the LLM's pairwise choice matches the true valuation.

    For each random schedule pair the model is queried first, then both
    schedules are valued with ``calculate_true_bundle_value`` and the two
    values are printed. The correct answer is 1 when the second schedule's
    value is at least the first's (ties go to 1), otherwise 0. A model
    answer of -1 (unparseable reply) never matches and so counts as wrong.

    Args:
        student: Student record; must expose ``additive_prefs``, whose length
            is taken as the number of courses.
        timetable: Timetable passed through to the prompt builder and the
            valuation function.
        num_tests: Number of random schedule pairs to evaluate.

    Returns:
        Tuple ``(correct_answers, num_tests)``.
    """
    course_count = len(student.additive_prefs)

    def as_indicator(schedule):
        # 0/1 vector over all course ids marking membership in `schedule`.
        return np.array([1 if course in schedule else 0 for course in range(course_count)])

    hits = 0
    for first, second in generate_random_schedule_pairs(course_count, num_tests):
        llm_choice = answer_pairwise_question_exact_preferences(student, (first, second), timetable)

        # Calculate true values
        first_value = calculate_true_bundle_value(
            as_indicator(first), student, timetable,
            ignore_timegaps=False, make_monotone=False,
        )
        second_value = calculate_true_bundle_value(
            as_indicator(second), student, timetable,
            ignore_timegaps=False, make_monotone=False,
        )
        print(first_value, second_value)

        expected_choice = 1 if second_value >= first_value else 0
        if llm_choice == expected_choice:
            hits += 1

    return hits, num_tests

# Test the accuracy.
# Runs at module import time: the first generated student is evaluated with
# the default number of tests of test_pairwise_answer_accuracy, and the
# fraction of matching answers is printed as a percentage.
student_to_test = true_student_list[0]
correct_answers, total_tests = test_pairwise_answer_accuracy(student_to_test, timetable)
accuracy = correct_answers / total_tests
print(f"Accuracy of answer_pairwise_question_exact_preferences: {correct_answers}/{total_tests} ({accuracy:.2%})")

#student_to_test = true_student_list[0]

# print(true_student_list)

# Example
#answer = answer_pairwise_question_exact_preferences(student_to_test, ([15,17,9,6],[15,3,22,1]), timetable)
#print(answer)
