# %%
import json
from dataclasses import dataclass
import os

import xml.etree.ElementTree as ET
from xml.sax.saxutils import escape

from core import NO_TIPS_ESCAPE
from curriculum import TIPS_START, TIPS_END
from core.messages import Message

def _get_messages(xml_element):
    messages_element = xml_element.find('messages')
    assert messages_element is not None, "messages element not found"
    return [
        Message.from_xml_element(message)
        for message in messages_element.findall('message')
    ]

@dataclass
class Choice:
    """One candidate answer of an exercise, as stored in a ``<choice>`` element."""

    # Text of the choice.
    content: str
    # Mirrors the ``truncated="true"`` attribute of the ``<choice>`` element.
    truncated: bool = False


def _get_answer_choices(xml_element):
    answer_choices_element = xml_element.find('answer_choices')
    return [
        Choice(choice.text.strip() if choice.text else ' ', truncated=choice.get('truncated') == "true")
        for choice in answer_choices_element.findall('choice')
    ]

def _get_model_answer(xml_element):
    model_answer_element = xml_element.find('model_answer')
    if model_answer_element is None:
        return None
    return ModelAnswer(model_answer_element.text.strip())

def _get_grading_str(xml_element):
    grading_str_element = xml_element.find('grading_str')
    if grading_str_element is None:
        return None
    return GradingStr(grading_str_element.text.strip())

@dataclass
class ModelAnswer:
    """Wrapper around the text read from a ``<model_answer>`` element."""

    # The answer text; ``_get_model_answer`` stores it whitespace-stripped.
    content: str

@dataclass
class GradingStr:
    """Wrapper around the text read from a ``<grading_str>`` element."""

    # The grading text; ``_get_grading_str`` stores it whitespace-stripped.
    content: str

class ExerciseWithAnswers:
    """An exercise (a list of messages) together with its answer data.

    Attributes:
        messages: Messages making up the exercise.
        answer_choices: Candidate answers; always a list (empty if none given).
        lesson_id: Lesson identifier supplied by the caller, or ``None``.
        model_answer: Model answer as a plain string, or an object exposing
            the text as ``content`` (which is what ``from_xml`` stores), or
            ``None``.
        grading_str: Grading string; same accepted forms as ``model_answer``.
    """

    def __init__(
        self,
        messages: "list[Message]",
        answer_choices: "list[Choice] | None" = None,
        lesson_id: "str | None" = None,
        model_answer: "str | ModelAnswer | None" = None,
        grading_str: "str | GradingStr | None" = None,
    ):
        self.messages = messages
        self.answer_choices = answer_choices or []
        self.lesson_id = lesson_id
        self.model_answer = model_answer
        self.grading_str = grading_str

    @classmethod
    def from_xml(cls, xml_element: ET.Element, lesson_id: str) -> "ExerciseWithAnswers":
        """Build an instance from an exercise element and a lesson id."""
        messages = _get_messages(xml_element)
        answer_choices = _get_answer_choices(xml_element)
        model_answer = _get_model_answer(xml_element)
        grading_str = _get_grading_str(xml_element)
        return cls(messages, answer_choices, lesson_id, model_answer, grading_str)

    @staticmethod
    def _text_of(value):
        """Return the plain text of *value*.

        Plain strings are returned unchanged; wrapper objects carrying their
        text in a ``content`` attribute are unwrapped; ``None`` stays ``None``.
        """
        if value is None:
            return None
        return getattr(value, "content", value)

    def to_xml(self, parent: ET.Element) -> ET.Element:
        """Append an ``<exercise_with_answers>`` element to *parent*.

        ``<messages>`` and ``<answer_choices>`` are always written.
        ``<model_answer>`` and ``<grading_str>`` are written only when their
        text is non-empty.

        Returns:
            The newly created ``<exercise_with_answers>`` element.
        """
        element = ET.SubElement(parent, "exercise_with_answers")
        messages_element = ET.SubElement(element, "messages")

        for msg in self.messages:
            msg_element = ET.SubElement(messages_element, "message")
            msg_element.set("role", msg.role.value)
            msg_element.text = msg.content

        choices_element = ET.SubElement(element, "answer_choices")
        for choice in self.answer_choices:
            choice_element = ET.SubElement(choices_element, "choice")
            choice_element.text = choice.content
            if choice.truncated:
                choice_element.set("truncated", "true")

        # from_xml stores wrapper objects rather than strings, so unwrap to
        # text before writing; Element.text must be a str to serialize.
        optional_fields = (
            ("model_answer", self.model_answer),
            ("grading_str", self.grading_str),
        )
        for tag, value in optional_fields:
            text = self._text_of(value)
            if text:
                ET.SubElement(element, tag).text = text

        return element

    def __str__(self):
        # The local names are part of the output: the ``{name=}`` f-string
        # form prints them verbatim.
        messages = self.messages
        answer_choices = self.answer_choices
        return f"ExerciseWithAnswers: {messages=}, {answer_choices=}"

    def __repr__(self):
        return str(self)


def xml_dump(element: ET.Element, file):
    """Serialize *element* to text and write it to *file*.

    A newline is inserted before every ``<`` and after every ``>`` of the
    serialized text so each tag sits on its own line. When
    ``NO_TIPS_ESCAPE`` is set, the XML-escaped forms of ``TIPS_START`` and
    ``TIPS_END`` are turned back into the raw markers before writing.
    """
    serialized = ET.tostring(element, encoding='unicode')

    # Purely textual line breaking around tag delimiters.
    serialized = serialized.replace("<", "\n<")
    serialized = serialized.replace(">", ">\n")

    if NO_TIPS_ESCAPE:
        for marker in (TIPS_START, TIPS_END):
            serialized = serialized.replace(escape(marker), marker)

    file.write(serialized)


def save_to_xml(
    filepath: os.PathLike,
    exercises_with_answers: list[ExerciseWithAnswers]
) -> None:
    """Write *exercises_with_answers* to *filepath* as one XML document.

    Each exercise is appended under an ``<exercises_with_answers>`` root via
    its ``to_xml`` method, and the tree is written with ``xml_dump``. Any
    existing file at *filepath* is overwritten. A confirmation line is
    printed once the file has been written.

    Args:
        filepath: Destination path.
        exercises_with_answers: Exercises to serialize, in output order.
    """
    root = ET.Element("exercises_with_answers")
    for ex in exercises_with_answers:
        ex.to_xml(root)

    # The document is written without an XML declaration, so readers assume
    # UTF-8; pin the encoding instead of relying on the platform default.
    with open(filepath, "w", encoding="utf-8") as file:
        xml_dump(root, file)

    print(f"Saved to {filepath}")
