# %%
import numpy as np
import os
from pathlib import Path
import re
import torch
from torch.nn.utils.rnn import pad_sequence
import xml.etree.ElementTree as ET
from xml.sax.saxutils import escape

from core import DATA_PATH, NO_TIPS_ESCAPE
from core.messages import Role
from curriculum import TIPS_START, TIPS_END
from curriculum.exercise_with_answers import ExerciseWithAnswers
from curriculum.generate_distractor import build_distractor_dataset
from training import IGNORE_INDEX
from training.utils import tokenize, tip_split


# Whitelist-based filter, compiled once at import time. The first alternative
# is a negated whitelist; the surrogate and supplementary-plane alternatives
# are already covered by it and are kept verbatim for fidelity.
_NON_XML_CHARS_RE = re.compile(
    '[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFDCF\uFDF0-\uFFFD]'
    '|[\uD800-\uDFFF]|[\U00010000-\U0010FFFF]', re.UNICODE)


def remove_non_xml_chars(text):
    """Delete every character outside a conservative XML 1.0 whitelist.

    Kept: tab, LF, CR, printable ASCII (U+0020-U+007E), NEL (U+0085),
    U+00A0-U+D7FF, U+E000-U+FDCF and U+FDF0-U+FFFD.

    Removed: all other C0 controls, U+007F-U+0084, U+0086-U+009F, surrogates,
    the noncharacters U+FDD0-U+FDEF, U+FFFE/U+FFFF, and every code point above
    U+FFFF. Note that the last group (e.g. emoji) is legal XML 1.0 but is
    stripped here anyway.

    Args:
        text: the string to clean.

    Returns:
        ``text`` with every non-whitelisted character removed.
    """
    cleaned = _NON_XML_CHARS_RE.sub('', text)
    return cleaned


def read_exercises(filepath: os.PathLike) -> list[ExerciseWithAnswers]:
    """Parse one lesson XML file into its ``<exercise_with_answers>`` entries.

    When ``NO_TIPS_ESCAPE`` is falsy, the file is handed straight to
    ``ET.parse``. Otherwise the file goes through these steps:

    1. It is read as UTF-8 text.
    2. Every ``TIPS_START`` / ``TIPS_END`` marker is XML-escaped with
       ``xml.sax.saxutils.escape``.
    3. Characters rejected by ``remove_non_xml_chars`` are stripped.
    4. The result is parsed with ``ET.fromstring``.

    Args:
        filepath: path to the lesson file. Its base name without extension
            becomes the ``lesson_id`` passed to every exercise.

    Returns:
        One ``ExerciseWithAnswers`` per ``<exercise_with_answers>`` direct
        child of the root element, in document order.

    Raises:
        xml.etree.ElementTree.ParseError: if the XML is malformed. In the
            escaped branch, the failing path is printed first.
    """
    if not NO_TIPS_ESCAPE:
        root = ET.parse(filepath).getroot()
    else:
        with open(filepath, 'r', encoding='utf-8') as file:
            xml_content = file.read()
        # Escape the tip markers so the XML parser reads them as text rather
        # than markup. NOTE(review): presumably the markers contain '<' / '>';
        # confirm against curriculum.TIPS_START / TIPS_END.
        xml_content = xml_content.replace(TIPS_START, escape(TIPS_START))
        xml_content = xml_content.replace(TIPS_END, escape(TIPS_END))
        xml_content = remove_non_xml_chars(xml_content)
        try:
            root = ET.fromstring(xml_content)
        except Exception:
            print("Failed to read", filepath)
            raise

    lesson_id, _ = os.path.splitext(os.path.basename(filepath))
    # Environment values are strings: compare against '0', not the int 0.
    # The int comparison was always False, which silenced this on every rank.
    if os.environ.get('LOCAL_RANK', '0') == '0':
        print("lesson_id", lesson_id)
    return [
        ExerciseWithAnswers.from_xml(ex, lesson_id=lesson_id)
        for ex in root.findall("exercise_with_answers")
    ]


# %%
class StudentTeacherDataset(torch.utils.data.Dataset):
    """Pairs a tip-free student prompt with a tip-bearing teacher prompt per answer.

    Every answer choice of every exercise in the given lesson files becomes one
    sample. ``lesson_names[sample["lesson_ix"]]`` names the lesson file a
    sample came from, so validation results can be grouped by lesson.
    """

    def __init__(
        self,
        llm,
        filenames: list[str],  # list of xml files in the data/ directory, for example "tiny_train.xml",
        verbose: bool = False,
        datapath: Path = DATA_PATH,  # Path to the data folder
        max_length: int = 0,
        debug: bool = False,
        no_template: bool = False,
    ):
        """Load, template and tokenize every exercise in ``filenames``.

        Args:
            llm: model wrapper. Its ``messages_to_prompt``,
                ``extract_question``, ``tokenize`` and ``add_eos`` are used
                here.
            filenames: XML lesson files, relative to ``datapath``.
            verbose: print per-lesson sample counts.
            datapath: folder containing the lesson files (``str`` or
                ``Path``).
            max_length: if non-zero, truncate each answer to this many tokens.
                This is applied after EOS is appended, so EOS may be cut off.
            debug: not used by this constructor.
            no_template: build the prompt without the chat template and
                append a newline plus ``"Answer: "``. The question is then the
                text before ``"Answer:"``.

        Raises:
            RuntimeError: if a lesson file does not exist.
        """
        assert isinstance(filenames, list), "filenames should be a list"
        # Accept plain strings as well as Path objects for the data folder.
        datapath = Path(datapath)

        if verbose:
            print("==== StudentTeacherDataset ====", flush=True)
        samples = []
        lesson_names = []  # Needed to group validation results by lesson

        for filename in filenames:
            filepath = datapath / filename
            if not filepath.exists():
                raise RuntimeError(f"{filename}: not found in {filepath}")

            lesson_name = os.path.splitext(os.path.basename(filename))[0]
            lesson_names.append(lesson_name)
            lesson_ix = len(lesson_names) - 1
            exercises = read_exercises(filepath)

            n_exercises_with_generated_choices = 0
            for exercise in exercises:
                # System messages are turned into user messages before templating.
                for msg in exercise.messages:
                    if msg.role == Role.SYSTEM:
                        msg.role = Role.USER

                prompt_with_tips = llm.messages_to_prompt(exercise.messages, no_template=no_template)
                if no_template:
                    prompt_with_tips += "\nAnswer: "

                # The question text is the prompt with every tip segment removed.
                prompt_parts, tips = tip_split(prompt_with_tips)
                q_str = "".join([pp for i, pp in enumerate(prompt_parts) if not tips[i]])
                if no_template:
                    question = q_str.split("Answer:")[0].strip()
                else:
                    question = llm.extract_question(q_str)

                student_prompt_tokens, teacher_prompt_tokens = tokenize(prompt_with_tips, llm)

                for choice in exercise.answer_choices:
                    content = choice.content
                    answer_tokens = llm.tokenize(content)
                    # Only complete (non-truncated) answers end with EOS.
                    if not choice.truncated:
                        answer_tokens = llm.add_eos(answer_tokens)

                    if max_length:
                        answer_tokens = answer_tokens[:, :max_length]

                    samples.append({
                        "student_prompt_tokens": student_prompt_tokens,
                        "teacher_prompt_tokens": teacher_prompt_tokens,
                        "answer_tokens": answer_tokens,
                        # Always None here, so collate_fn reuses answer_tokens
                        # on the teacher side.
                        "teacher_answer_tokens": None,
                        "teacher_answer": content,
                        "lesson_ix": lesson_ix,
                        "question": question,
                    })
                n_exercises_with_generated_choices += len(exercise.answer_choices)

            if verbose:
                print(f"{lesson_name}: {n_exercises_with_generated_choices} exercises with generated choices", flush=True)

        self.samples = samples
        self.lesson_names = lesson_names

        if verbose:
            print("==== /StudentTeacherDataset ====", flush=True)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def collate_fn(self, samples, padding_value, llm):
        """Pad a list of samples into batch tensors.

        Args:
            samples: items from this dataset.
            padding_value: token id used to right-pad ``student_seqs`` and
                ``teacher_seqs``.
            llm: not used; kept for call-site compatibility.

        Returns:
            dict with:
              - ``student_seqs``: student prompt + answer, padded, ``(B, T_s)``.
              - ``student_labels``: ``student_seqs`` with prompt and padding
                positions set to ``IGNORE_INDEX``.
              - ``teacher_seqs``: teacher prompt + answer, padded, ``(B, T_t)``.
              - ``teacher_masks``: bool, True on the positions used for the
                loss.
              - ``lesson_ixs``: lesson index per sample, ``(B,)``.
        """
        student_seqs = []
        student_labels = []
        teacher_seqs = []
        teacher_masks = []

        for sample in samples:
            student_prompt_tokens = sample["student_prompt_tokens"]
            seq = torch.cat([student_prompt_tokens, sample["answer_tokens"]], dim=1)

            # The targets are the prompt with response, but with the prompt masked out
            labels = seq.clone()
            student_prompt_len = student_prompt_tokens.size(1)
            labels[0, :student_prompt_len] = IGNORE_INDEX

            student_seqs.append(seq[0])
            student_labels.append(labels[0])

            if sample["teacher_answer_tokens"] is not None:
                seq = torch.cat([sample["teacher_prompt_tokens"], sample["teacher_answer_tokens"]], dim=1)
                teacher_seqs.append(seq[0])

                # Only the trailing len(answer_tokens) positions enter the loss.
                # Use an explicit start index: ``[:-n]`` would mask nothing
                # when n == 0.
                answer_len = sample["answer_tokens"].size(1)
                teacher_mask = torch.ones_like(seq[0], dtype=torch.bool)
                teacher_mask[:max(seq.size(1) - answer_len, 0)] = 0
            else:
                seq = torch.cat([sample["teacher_prompt_tokens"], sample["answer_tokens"]], dim=1)
                teacher_seqs.append(seq[0])

                teacher_mask = torch.ones_like(seq[0], dtype=torch.bool)
                teacher_prompt_len = sample["teacher_prompt_tokens"].size(1)
                teacher_mask[:teacher_prompt_len] = 0
            teacher_masks.append(teacher_mask)

        student_seqs = pad_sequence(student_seqs, batch_first=True, padding_value=padding_value)
        student_labels = pad_sequence(student_labels, batch_first=True, padding_value=IGNORE_INDEX).long()

        teacher_seqs = pad_sequence(teacher_seqs, batch_first=True, padding_value=padding_value)
        teacher_masks = pad_sequence(teacher_masks, batch_first=True, padding_value=0).bool()

        lesson_ixs = torch.tensor([sample["lesson_ix"] for sample in samples])

        return {
            'student_seqs': student_seqs,
            'student_labels': student_labels, # student_seq but with the prompt masked out
            'teacher_seqs': teacher_seqs,
            'teacher_masks': teacher_masks, # determines which tokens should be used for loss calculation
            'lesson_ixs': lesson_ixs,
        }


class TeacherDataset(torch.utils.data.Dataset):
    """Token-loss dataset with exactly one reference answer per exercise.

    Each sample stores the teacher prompt (with tips), the student prompt
    (without tips), the answer tokens and the lesson index. ``collate_fn`` can
    optionally mix distractor paragraphs into the teacher prompt's context.
    """

    # Captures the text after the first user header up to the next "\n\n---".
    # The header tokens match the Llama-3 chat format; prompts rendered with
    # other templates will not match and are left unchanged.
    _USER_TURN_PATTERN = re.compile(
        r'<\|start_header_id\|>user<\|end_header_id\|>\n\n(.*?)\n\n---', re.DOTALL
    )

    def __init__(
        self,
        llm,
        filenames: list[str],  # list of xml files in the data/ directory, for example "tiny_train.xml",
        verbose: bool = False,
        datapath: Path = DATA_PATH,  # Path to the data folder
        max_length: int = 0,
        distractor_dataset: str = "",
    ):
        """Load and tokenize every exercise in ``filenames``.

        Args:
            llm: model wrapper. Its ``tokenize``, ``messages_to_prompt`` and
                ``add_eos`` are used here.
            filenames: XML lesson files, relative to ``datapath``.
            verbose: print per-lesson sample counts.
            datapath: folder containing the lesson files (``str`` or
                ``Path``).
            max_length: if non-zero, truncate each answer to this many tokens
                before EOS is appended, so EOS is always kept.
            distractor_dataset: name passed to ``build_distractor_dataset``.
                Empty (or None) disables distractor injection.

        Raises:
            RuntimeError: if a lesson file does not exist.
            NotImplementedError: if an exercise does not have exactly one
                answer choice.
        """
        assert isinstance(filenames, list), "filenames should be a list"
        # Accept plain strings as well as Path objects for the data folder.
        datapath = Path(datapath)

        if verbose:
            print("==== TeacherDataset ====", flush=True)
        samples = []
        lesson_names = []  # Needed to group validation results by lesson

        if distractor_dataset:
            self.distractor_dataset = build_distractor_dataset(dataset=distractor_dataset)
        else:
            self.distractor_dataset = None

        for filename in filenames:
            filepath = datapath / filename
            if not filepath.exists():
                raise RuntimeError(f"{filename}: not found in {filepath}")

            lesson_name = os.path.splitext(os.path.basename(filename))[0]
            lesson_names.append(lesson_name)
            lesson_ix = len(lesson_names) - 1

            exercises = read_exercises(filepath)
            n_lesson_samples = 0

            for exercise in exercises:
                # System messages are turned into user messages before templating.
                for msg in exercise.messages:
                    if msg.role == Role.SYSTEM:
                        msg.role = Role.USER

                if len(exercise.answer_choices) != 1:
                    raise NotImplementedError("Multiple choices per answer are not currently supported in token loss training")
                answer_tokens = llm.tokenize(exercise.answer_choices[0].content)

                prompt_with_tips = llm.messages_to_prompt(exercise.messages)
                student_prompt_tokens, teacher_prompt_tokens = tokenize(prompt_with_tips, llm)
                if max_length:
                    answer_tokens = answer_tokens[:, :max_length]
                answer_tokens = llm.add_eos(answer_tokens)

                samples.append({
                    "prompt_tokens": teacher_prompt_tokens,
                    "student_prompt_tokens": student_prompt_tokens,
                    "answer_tokens": answer_tokens,
                    "lesson_ix": lesson_ix,
                })
                n_lesson_samples += 1

            if verbose:
                # Report this lesson's count, not the running total.
                print(f"{lesson_name}: {n_lesson_samples} exercises with answers", flush=True)

        self.samples = samples
        self.lesson_names = lesson_names

        if verbose:
            print("==== /TeacherDataset ====", flush=True)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def _prompt_tokens_with_distractors(self, sample, llm):
        """Return the teacher prompt tokens with distractors mixed into the context.

        The matched user text is split on blank lines. The first paragraph is
        the context and the rest is the instruction. With probability 0.6 the
        context is inserted at a random position among the distractors;
        otherwise it is replaced by the distractors alone. Falls back to the
        sample's original prompt tokens (with a warning) when the pattern does
        not match.
        """
        prompt = llm.decode(sample["prompt_tokens"])
        match = self._USER_TURN_PATTERN.search(prompt)
        if not match:
            print("WARNING. Context could not be extracted, not adding distractors", flush=True)
            return sample["prompt_tokens"]

        extracted_content = match.group(1).strip()
        lines = extracted_content.split('\n\n')
        context = lines[0]
        instr = '\n\n'.join(lines[1:])
        distractors = self.distractor_dataset.sample()
        if np.random.rand() < 0.6:
            idx = np.random.randint(0, len(distractors) + 1)
            context_list = distractors[:idx] + [context] + distractors[idx:]
            context = '\n\n'.join(context_list)
        else:
            context = '\n\n'.join(distractors)

        if len(instr.strip()):
            context = context + "\n\n" + instr

        # Splice only inside the matched span. A global str.replace would also
        # rewrite identical text elsewhere in the prompt, and with empty
        # content it would insert `context` between every character.
        start, end = match.span(1)
        body = match.group(1).replace(extracted_content, context, 1)
        prompt = prompt[:start] + body + prompt[end:]
        return llm.tokenize(prompt)

    def collate_fn(self, samples, padding_value, llm=None, max_total_length=0):
        """Pad a list of samples into teacher and student batch tensors.

        Args:
            samples: items from this dataset.
            padding_value: token id used to right-pad ``seqs`` and
                ``student_seqs``.
            llm: model wrapper with ``decode`` / ``tokenize``. Required only
                when distractors are enabled.
            max_total_length: if non-zero, keep only the last this-many tokens
                of each sequence (left truncation).

        Returns:
            dict with:
              - ``seqs`` / ``labels``: teacher prompt + answer, with prompt and
                padding positions set to ``IGNORE_INDEX`` in the labels.
              - ``lesson_ixs``: lesson index per sample.
              - ``student_seqs`` / ``student_labels``: the same for the
                student prompt.

        Raises:
            ValueError: if distractors are enabled but ``llm`` is None.
        """
        seqs = []
        labels = []
        student_seqs = []
        student_labels = []
        for sample in samples:
            if self.distractor_dataset:
                if llm is None:
                    raise ValueError("collate_fn needs `llm` when a distractor dataset is configured")
                prompt_tokens = self._prompt_tokens_with_distractors(sample, llm)
            else:
                prompt_tokens = sample["prompt_tokens"]

            seq = torch.cat([prompt_tokens, sample["answer_tokens"]], dim=1)
            prompt_len = prompt_tokens.size(1)
            target_labels = seq.clone()
            target_labels[0, :prompt_len] = IGNORE_INDEX
            if max_total_length:
                seq = seq[:, -max_total_length:]
                target_labels = target_labels[:, -max_total_length:]

            student_seq = torch.cat([sample["student_prompt_tokens"], sample["answer_tokens"]], dim=1)
            student_target_labels = student_seq.clone()
            student_prompt_len = sample["student_prompt_tokens"].size(1)
            student_target_labels[0, :student_prompt_len] = IGNORE_INDEX
            if max_total_length:
                student_seq = student_seq[:, -max_total_length:]
                student_target_labels = student_target_labels[:, -max_total_length:]

            seqs.append(seq[0])
            labels.append(target_labels[0])
            student_seqs.append(student_seq[0])
            student_labels.append(student_target_labels[0])

        seqs = pad_sequence(seqs, batch_first=True, padding_value=padding_value)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).long()
        lesson_ixs = torch.tensor([sample["lesson_ix"] for sample in samples])
        student_seqs = pad_sequence(student_seqs, batch_first=True, padding_value=padding_value)
        student_labels = pad_sequence(student_labels, batch_first=True, padding_value=IGNORE_INDEX).long()

        return {
            'seqs': seqs,
            'labels': labels,
            'lesson_ixs': lesson_ixs,
            'student_seqs': student_seqs,
            'student_labels': student_labels,
        }
