from collections import OrderedDict
import os
import xml.etree.ElementTree as ET
from xml.sax.saxutils import escape

from core import BASE_PATH
from core.utils import remove_empty
from curriculum import TIPS_START, TIPS_END

# Module-level flag. It is not referenced anywhere in this file; presumably it
# is read by importing code (NOTE(review): confirm what it controls).
XML_STYLE = True

def _get_tip_ids(lesson: ET.Element) -> list:
    tip_ids = lesson.findall("tip_ids")
    if not tip_ids:
        return []

    assert len(tip_ids) == 1, "multiple <tip_ids> entries found"
    tip_ids = tip_ids[0].text.strip()

    return [tip_id.strip() for tip_id in tip_ids.split("\n")]

def _get_element_text(
    lesson: ET.Element,
    tag: str,
    required: bool = False,
    multiple: bool = False,  # multiple entries of the same tag are allowed
                             # in this case a list of strings is returned
):
    """Get the text from the first tag found in the lesson."""
    elements = lesson.findall(tag)
    if not elements:
        if required:
            raise ValueError(f"no <{tag}> entry found")
        return [] if multiple else ""

    if not multiple:
        assert len(elements) == 1, f"multiple <{tag}> entries found"
        return elements[0].text.strip()

    return [element.text.strip() for element in elements]

def read_examples(example_ids: list[str]) -> dict[str, str]:
    """Read examples from the curriculum_v2/examples directory.

    Each id has the form ``"<file>/<example id>"``. ``<file>`` selects
    ``BASE_PATH / "curriculum_v2" / "examples" / "<file>.xml"``, and
    ``<example id>`` is matched against the ``id`` attribute of an
    ``<example>`` element below that file's root. Everything after the first
    ``/`` is the example id.

    The values of the returned dictionary are XML strings. They are produced
    by ``ET.tostring`` with the ``id`` attribute removed, so they include the
    element's tail text.

    Raises:
        ValueError: if an id is not of the form ``"<file>/<example id>"``, or
            the example is not found in its file.
        FileNotFoundError: if an example file does not exist.
    """
    # Group requested ids by file so each file is parsed once. A dict per
    # file also de-duplicates repeated ids.
    ids_by_file: dict[str, dict[str, str]] = {}
    for example_id in example_ids:
        filename, sep, short_id = example_id.partition("/")
        if not sep or not filename or not short_id:
            raise ValueError(
                f"Invalid example id {example_id!r}: expected '<file>/<example id>'"
            )
        ids_by_file.setdefault(filename, {})[example_id] = short_id

    examples = {}
    for filename, requested in ids_by_file.items():
        filepath = BASE_PATH / "curriculum_v2" / "examples" / f"{filename}.xml"
        if not filepath.exists():
            raise FileNotFoundError(f"Example file {filepath} not found")

        root = ET.parse(filepath).getroot()
        # Index examples by id instead of interpolating ids into an XPath
        # predicate, which breaks on ids containing quotes. Keep the first
        # element per id, which matches `find`'s first-match behavior.
        by_id = {}
        for elem in root.iterfind(".//example"):
            by_id.setdefault(elem.get("id"), elem)

        for example_id, short_id in requested.items():
            example = by_id.get(short_id)
            if example is None:
                raise ValueError(f"Example {short_id} not found in {filepath}")
            # Remove the attribute "id" from the example
            del example.attrib["id"]
            examples[example_id] = ET.tostring(example, encoding="unicode")

    return examples


class Lesson:
    """One ``<lesson>`` element, parsed into the pieces used to build prompts.

    Attributes set in ``__init__`` from the ``<lesson>`` element:
        id, response_format: the element's attributes of the same name
            (None when absent).
        tip_ids: ids listed in the ``<tip_ids>`` child.
        material: the ``<material>`` child element, or None.
        instructions: stripped text of the ``<instructions>`` child, or "".
        exercises: one ``Exercise`` per ``<exercise>`` child.
    """

    def __init__(self, lesson: ET.Element):
        self.id = lesson.get("id")
        self.response_format = lesson.get("response_format")
        self._element = lesson  # the original <lesson> element
        self.tip_ids = _get_tip_ids(lesson)
        self.material = lesson.find("material")

        self.instructions = _get_element_text(lesson, "instructions", required=False, multiple=False)
        self.exercises: list[Exercise] = [
            Exercise.from_xml(exercise)
            for exercise in lesson.findall('exercise')
        ]

    def __str__(self):
        txt = "Lesson:"
        if self.tip_ids:
            txt += f" tip_ids={self.tip_ids}"
        # `material` is an Element. Its truthiness reflects its number of
        # children (not its text), and slicing it yields child elements, so
        # read its text explicitly.
        material_text = (self.material.text or "").strip() if self.material is not None else ""
        if material_text:
            txt += f" material={material_text[:10]}..."
        if self.instructions:
            txt += f" instructions={self.instructions[:10]}..."
        txt += f" exercises={self.exercises}"
        return txt

    def __repr__(self):
        return str(self)

    def render_material(self) -> str:
        """Return the material text with ``<example>`` references expanded.

        Each direct ``<example id="<file>/<id>"/>`` child of ``<material>`` is
        replaced by the XML string that ``read_examples`` returns for that id.
        The XML is inserted verbatim as raw XML text, not escaped. The tail
        text after each example is kept.

        Returns:
            The stripped result, or "" when the lesson has no ``<material>``.
        """
        if self.material is None:
            return ""
        example_ids = [example.get("id") for example in self.material.findall("example")]
        examples = read_examples(example_ids) if example_ids else {}

        # `.text` is None when <material> is empty or starts with a child.
        parts = [self.material.text or ""]
        for child in self.material:
            if child.tag != "example":
                # NOTE(review): the previous implementation only kept text up
                # to the first non-<example> child element. That behavior is
                # preserved here; confirm whether other markup should render.
                break
            parts.append(examples[child.get("id")])
            parts.append(child.tail or "")
        return "".join(parts).strip()

    def create_exercise_prompts(self, verbose: bool = False) -> list["Exercise"]:
        """Attach student/teacher prompts to every exercise and return the exercises.

        Prompt construction for each exercise:
        - The student prompt is the lesson instructions and the exercise text,
          joined by a blank line after ``remove_empty`` drops empty parts.
        - The teacher prompt is the student prompt prefixed by the rendered
          material and the lesson's tips, followed by a ``---`` delimiter.
        - ``teacher_prompt_with_tips_tags`` wraps that prefix in
          ``TIPS_START``/``TIPS_END``.
        - Without material or tips, both teacher prompts equal the student
          prompt.

        Args:
            verbose: print the intermediate prompts.

        Returns:
            ``self.exercises``, each updated in place via ``add_prompts``.
        """
        my_print = print if verbose else lambda *args, **kwargs: None

        delimiter = "---\n\n"
        # NOTE(review): `TIPS` is not imported or defined in the visible part
        # of this file (only TIPS_START/TIPS_END come from `curriculum`), so
        # this line raises NameError whenever the lesson has tip_ids. Confirm
        # where TIPS should come from.
        tips = [self.render_material()] + [TIPS[tip_id] for tip_id in self.tip_ids]
        tips = remove_empty(tips)
        if tips:
            tips.append(delimiter)
        tips_txt = "\n\n".join(tips)
        my_print("tips_txt:", repr(tips_txt))
        instructions = self.instructions

        for exercise in self.exercises:
            student_prompt = "\n\n".join(remove_empty([instructions, str(exercise)]))
            my_print("student_prompt:", student_prompt)

            # The limit on the number of generated tokens is determined by length of the student prompt

            if tips_txt:
                teacher_prompt = tips_txt + student_prompt
                teacher_prompt_with_tips_tags = f"{TIPS_START}{tips_txt}{TIPS_END}" + student_prompt
            else:
                teacher_prompt = teacher_prompt_with_tips_tags = student_prompt

            my_print("teacher_prompt:", teacher_prompt)
            my_print("teacher_prompt_with_tips_tags:", teacher_prompt_with_tips_tags)

            exercise.add_prompts(student_prompt, teacher_prompt, teacher_prompt_with_tips_tags)

        return self.exercises


def read_lessons(filepath: os.PathLike, error_if_not_found=True) -> dict[str, Lesson]:
    """Parse an XML lessons file into an ordered ``{lesson id: Lesson}`` mapping.

    Every ``<lesson>`` element directly under the root becomes a ``Lesson``,
    keyed by its ``id`` attribute, in document order.

    Args:
        filepath: path of the XML file to parse.
        error_if_not_found: when False, a missing file prints a message and
            yields an empty dict instead of raising.

    Raises:
        FileNotFoundError: if the file is missing and ``error_if_not_found``.
    """
    try:
        root = ET.parse(filepath).getroot()
    except FileNotFoundError:
        if not error_if_not_found:
            print(f"File {filepath} not found")
            return {}
        raise

    lessons = OrderedDict()
    for lesson_elem in root.findall('lesson'):
        lessons[lesson_elem.get("id")] = Lesson(lesson_elem)
    return lessons


class Exercise:
    """One ``<exercise>`` of a lesson, plus the prompts later built for it.

    ``exercise`` is the question text. ``model_answer`` and ``grading_str``
    are the texts of the optional child elements of the same name (None when
    absent). ``distractor_elems`` holds the ``<distractor>`` child elements.
    The three prompt attributes stay None until ``add_prompts`` is called.
    """

    def __init__(
        self,
        exercise: str,
        model_answer: str = None,
        grading_str: str = None,
        distractor_elems: list[ET.Element] = None,
    ):
        self.exercise = exercise
        self.model_answer = model_answer
        self.grading_str = grading_str
        self.distractor_elems = distractor_elems

        # The following are added with add_prompts
        self.student_prompt: str = None
        self.teacher_prompt: str = None
        self.teacher_prompt_with_tips_tags: str = None

    @staticmethod
    def _child_text(parent: ET.Element, tag: str):
        """Return the stripped text of the first ``<tag>`` child.

        Returns "" when that child is empty and None when there is no such child.
        """
        elem = parent.find(tag)
        if elem is None:
            return None
        return (elem.text or "").strip()

    @classmethod
    def from_xml(cls, exercise_elem: ET.Element):
        """Build an ``Exercise`` from an ``<exercise>`` element.

        The question text is the element's own text plus the tail text after
        each child element. Tail text is the text after a child but before
        the next one. The children's own contents (model answer, grading
        string, distractors) are not part of the question. Non-empty pieces
        are stripped and joined with single spaces.
        """
        # `.text`/`.tail` are None when absent; whitespace-only pieces are
        # skipped so they don't produce double spaces.
        pieces = [exercise_elem.text] + [child.tail for child in exercise_elem]
        exercise = " ".join(piece.strip() for piece in pieces if piece and piece.strip())

        model_answer = cls._child_text(exercise_elem, 'model_answer')
        grading_str = cls._child_text(exercise_elem, 'grading_str')

        distractor_elems = exercise_elem.findall('distractor')
        # Clear the tails (already consumed into the question text above),
        # presumably so that serializing a distractor later (ET.tostring
        # includes tail text) yields only the element itself. Note this
        # mutates the caller's tree.
        for elem in distractor_elems:
            elem.tail = None
        return cls(exercise, model_answer, grading_str, distractor_elems)

    def __str__(self):
        return "\n\n".join(remove_empty([self.exercise]))

    def __repr__(self):
        return f"Exercise(exercise={self.exercise}, model_answer={self.model_answer}, grading_str={self.grading_str})"

    def add_prompts(self, student_prompt, teacher_prompt, teacher_prompt_with_tips_tags):
        """Store the prompts built by ``Lesson.create_exercise_prompts``."""
        self.student_prompt = student_prompt
        self.teacher_prompt = teacher_prompt
        self.teacher_prompt_with_tips_tags = teacher_prompt_with_tips_tags


