# %%
from collections import Counter
import json
import numpy as np
import os
import re
from pathlib import Path
import random
import xml.etree.ElementTree as ET

import torch
from torch.utils.data import Sampler

from core.llm import LLM
from training.metrics import Aggregator
from curriculum import TIPS_START, TIPS_END


def augment_filenames(files: list[str],
                      temperature: float = 1,
                      n_choices: int = 1,
                      vllm: bool = False,
                      llama3_8b: bool = False,
                      llama3_70b: bool = False,
                      qwen25_3b: bool = False,
                      qwen25_14b: bool = False,
                      qwen25_72b: bool = False,
                      partition_idx: int = None,
                      partition_type: str = None,
                      ) -> list[str]:
    new_files = []
    for fn in files:
        fname = '.'.join(fn.split('.')[:-1])
        fname += f"_x{n_choices}" if n_choices != 1 else ""
        fname += f"_t{temperature}" if temperature != 1 else ""
        fname += f"_llama3-8b" if llama3_8b else ""
        fname += f"_llama3-70b" if llama3_70b else ""
        fname += f"_qwen2.5-3b" if qwen25_3b else ""
        fname += f"_qwen2.5-14b" if qwen25_14b else ""
        fname += f"_qwen2.5-72b" if qwen25_72b else ""
        fname += f"_vllm" if vllm else ""
        fname += f"_chunk{partition_idx}" if partition_idx else ""
        fname += f"_{partition_type}" if partition_type else ""
        fname += ".xml"
        new_files.append(Path(fname))
    return new_files


def clean_xml_content(filename):
    """Write a copy of ``filename`` with characters that XML 1.0 forbids removed.

    The file is read as UTF-8; byte sequences that do not decode are dropped
    (``errors='ignore'``). The cleaned text is written next to the original
    as ``<filename>.cleaned``; the original file is left untouched.

    Args:
        filename: Path of the file to clean, as a string or ``os.PathLike``.

    Returns:
        The path of the cleaned copy, as a string.
    """
    # Normalise to str so that pathlib.Path inputs work with the "+" below.
    source_path = os.fspath(filename)
    with open(source_path, 'r', encoding='utf-8', errors='ignore') as file:
        content = file.read()

    # XML 1.0 allows \t, \n, \r, U+0020-U+D7FF, U+E000-U+FFFD and
    # U+10000-U+10FFFF. Strip everything else: the C0 controls other than
    # tab/newline/carriage return, the surrogate block, and U+FFFE/U+FFFF.
    cleaned_content = re.sub(r"[\u0000-\u0008\u000B\u000C\u000E-\u001F\uD800-\uDFFF\uFFFE\uFFFF]", "", content)

    temp_filename = source_path + '.cleaned'
    with open(temp_filename, 'w', encoding='utf-8') as file:
        file.write(cleaned_content)

    return temp_filename


def set_seed(seed):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # if you are using CUDA
        np.random.seed,
        random.seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)


def generate_answers(
    model,
    base_llm,
    generation_samples,
    accelerator,
    *,
    max_new_tokens=200,
    temperature=0.7,
):
    """Sample one completion per prompt and return the decoded continuations.

    The model is put in eval mode for the duration of the call and is always
    switched back to train mode afterwards, even if generation raises.

    Args:
        model: Object providing ``eval()``, ``train()`` and ``generate(...)``.
        base_llm: Provides ``tokenizer`` (used for decoding) and
            ``get_terminators()`` (passed as ``eos_token_id``).
        generation_samples: Iterable of dicts holding the prompt token ids
            under ``"prompt_tokens"`` or, failing that,
            ``"student_prompt_tokens"``.
        accelerator: Provides ``device`` and ``process_index``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        List of decoded answers (prompt tokens stripped), one per sample.
    """
    tokenizer = base_llm.tokenizer
    answers = []
    model.eval()
    try:
        with torch.no_grad():
            # Generation
            for sample in generation_samples:
                if "prompt_tokens" in sample:
                    input_ids = sample["prompt_tokens"]
                else:
                    input_ids = sample["student_prompt_tokens"]

                # NOTE(review): size(1) and output[0] below suggest input_ids is
                # shaped (1, seq_len) - confirm against the data pipeline.
                inputs = {
                    'input_ids': input_ids.to(accelerator.device),
                    'attention_mask': torch.ones_like(input_ids).to(accelerator.device)
                }

                output = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    eos_token_id=base_llm.get_terminators()
                )

                # Keep only the newly generated tokens; only the first row is decoded.
                prompt_length = inputs['input_ids'].size(1)
                output = output[:, prompt_length:]
                answer = tokenizer.decode(output[0])
                print(f"[Generated answer, {accelerator.process_index}]:\n", answer, "\n", flush=True)
                answers.append(answer)
    finally:
        # Restore training mode even when generation fails, so that a caught
        # error does not leave the training loop running in eval mode.
        model.train()
    return answers


def substring_locations(s: str, sub: str):
    """Return the start offsets of the non-overlapping occurrences of ``sub`` in ``s``.

    ``sub`` is matched literally (it is regex-escaped), scanning left to right.
    """
    literal = re.compile(re.escape(sub))
    offsets = []
    for hit in literal.finditer(s):
        offsets.append(hit.start())
    return offsets


def tip_split(s: str):
    """Split ``s`` into alternating plain and tip segments.

    Tip segments are the spans enclosed by ``TIPS_START`` / ``TIPS_END``; the
    markers themselves are removed.

    Args:
        s: Text that may contain ``TIPS_START ... TIPS_END`` spans.

    Returns:
        ``(parts, tip)``: two parallel lists where ``tip[i]`` is ``True`` when
        ``parts[i]`` came from inside a marker pair. Each tip segment is
        preceded by the (possibly empty) plain text before it; trailing plain
        text is included only when non-empty.

    Raises:
        ValueError: If the number of start and end markers differs, if an end
            marker precedes its start marker, or if a new start marker appears
            before the previous span is closed.
    """
    begin_locations = substring_locations(s, TIPS_START)
    end_locations = substring_locations(s, TIPS_END)

    if len(begin_locations) != len(end_locations):
        raise ValueError(f"Mismatch between {TIPS_START} and {TIPS_END} markers.")

    # Validate every pair, including the last (or only) one: each span must
    # open before it closes, and close before the next span opens.
    pair_count = len(begin_locations)
    for i, (begin, end) in enumerate(zip(begin_locations, end_locations)):
        opens_before_close = i + 1 < pair_count and begin_locations[i + 1] < end
        if begin > end or opens_before_close:
            raise ValueError(f"Misplaced {TIPS_START} or {TIPS_END} marker.")

    parts = []
    last_index = 0
    tip = []
    for begin, end in zip(begin_locations, end_locations, strict=True):
        parts.append(s[last_index:begin])  # Text before TIPS_START
        tip.append(False)

        parts.append(s[begin+len(TIPS_START):end])  # Text between TIPS_START and TIPS_END
        tip.append(True)

        last_index = end + len(TIPS_END)  # Skip TIPS_END

    if last_index < len(s):
        parts.append(s[last_index:])  # Remaining text after the last TIPS_END
        tip.append(False)

    return parts, tip


def tokenize(prompt_with_tips: str, llm):
    """Tokenize a prompt twice: with the tip text (teacher) and without it (student).

    The prompt is split with ``tip_split``. The teacher prompt is every
    segment concatenated; the student prompt keeps only the non-tip segments.
    Both are tokenized with ``llm.tokenize`` and passed through ``llm.add_bos``.

    Returns:
        ``(student_tokens, teacher_tokens)``.
    """
    segments, is_tip = tip_split(prompt_with_tips)

    def encode(text):
        return llm.add_bos(llm.tokenize(text))

    # Teacher sees everything, tips included.
    teacher_tokens = encode("".join(segments))

    # Student sees the same prompt with the tip segments dropped.
    student_text = "".join(
        segment for segment, flagged in zip(segments, is_tip) if not flagged
    )
    student_tokens = encode(student_text)

    return student_tokens, teacher_tokens


class InfiniteSampler(Sampler):
    """Sampler that never stops: yields one fresh random permutation after another.

    Each pass draws ``torch.randperm(data_source_length)`` and yields its
    elements one by one (as zero-dimensional tensors), then starts over.
    """

    def __init__(self, data_source_length):
        # Number of items in the dataset being sampled; indices are 0..length-1.
        self.data_source_length = data_source_length

    def __iter__(self):
        while True:
            permutation = torch.randperm(self.data_source_length)
            for index in permutation:
                yield index

    def __len__(self):
        # NOTE: the built-in len() requires an int, so len(sampler) raises
        # TypeError; only a direct __len__() call returns this value.
        return float('inf')


def save_base_model_config(model, base_llm, run_project_dir, verbose=True):
    """Write ``base_llm.get_config()`` to ``base_model_config.json`` in ``run_project_dir``.

    Args:
        model: Not used by this function.
        base_llm: Object whose ``get_config()`` result is JSON-serialised.
        run_project_dir: Directory path supporting the ``/`` operator.
        verbose: When true, print the directory that was written to.
    """
    config_path = run_project_dir / "base_model_config.json"
    with open(config_path, mode='w', encoding='utf-8') as config_file:
        json.dump(
            base_llm.get_config(),
            config_file,
            ensure_ascii=False,
            indent=4,
        )

    if not verbose:
        return
    print(f"Saved to {run_project_dir}")


def save_with_base_model_config(model, base_llm, run_project_dir):
    """Save ``model`` via ``save_pretrained`` and write the base-model config beside it."""
    # Model files first, then the base-model config JSON in the same directory.
    model.save_pretrained(run_project_dir)
    save_base_model_config(
        model=model,
        base_llm=base_llm,
        run_project_dir=run_project_dir,
    )


def save_with_deepspeed(model, accelerator, base_llm, run_project_dir):
    """Unwrap ``model`` and, on the main process only, save it with its base-model config.

    ``accelerator.unwrap_model`` is called on every process; the
    ``save_pretrained`` call and the config JSON write happen only when
    ``accelerator.is_main_process`` is true.
    """
    bare_model = accelerator.unwrap_model(model)

    # Every other process has nothing to write.
    if not accelerator.is_main_process:
        return

    bare_model.save_pretrained(run_project_dir)
    save_base_model_config(bare_model, base_llm, run_project_dir, verbose=True)
