# %%
import asyncio
import html
import logging
import os
import sys
import time
import warnings
import xml.etree.ElementTree as ET
from datetime import datetime
            
import torch
from vllm import LLM as vLLM
from vllm import SamplingParams

from core import DATA_PATH
from core.llm import LLM
from core.messages import Message, Role, merge_messages
from curriculum.lesson import read_lessons, Lesson, Exercise
from curriculum.exercise_with_answers import ExerciseWithAnswers, Choice, xml_dump
from training.utils import clean_xml_content


def generate_prompt(
    llm: LLM,
    lesson: Lesson,
    max_total_tokens: int,
    max_new_tokens: int,
):
    """Build one generation prompt per exercise of ``lesson``.

    For every exercise the answer budget is ``max_total_tokens`` minus the
    token count of the exercise's *student* prompt, additionally capped by
    ``max_new_tokens`` when that value is positive. The text that is returned
    for generation is built from the exercise's *teacher* prompt.

    Args:
        llm: Wrapper providing ``tokenize`` and ``messages_to_prompt``.
        lesson: Lesson whose exercises are turned into prompts;
            ``create_exercise_prompts`` is called on it first.
        max_total_tokens: Upper bound for student prompt plus answer tokens.
        max_new_tokens: Cap on the answer length; values ``<= 0`` disable it.

    Returns:
        A ``(prompts, exercises)`` tuple. ``prompts`` is a list of
        ``(prompt_text, max_tokens_to_generate)`` pairs and ``exercises`` is
        the list of matching exercises, in the same order.

    Raises:
        ValueError: If an exercise leaves no tokens at all for the answer.
    """
    prompts = []
    exercises = []
    lesson.create_exercise_prompts(verbose=False)
    for ex in lesson.exercises:
        # NOTE(review): process_answers indexes the tokenizer output as 2-D
        # (``tokens[0, -1]``); if ``tokenize`` returns a (1, n) array here as
        # well, ``len()`` yields 1 instead of the token count - confirm.
        n_tokens_prompt = len(llm.tokenize(ex.student_prompt))
        max_tokens_to_generate = max_total_tokens - n_tokens_prompt
        if max_new_tokens > 0:
            max_tokens_to_generate = min(max_new_tokens, max_tokens_to_generate)
        # The hard failure has to be tested before the softer "<= 10" warning,
        # otherwise it can never be reached.
        if max_tokens_to_generate <= 0:
            raise ValueError(f"Too many tokens in the prompt: {n_tokens_prompt}, while the limit is {max_total_tokens}.")
        if max_tokens_to_generate <= 10:
            warnings.warn(f"Too few tokens left for the answer: {max_tokens_to_generate}.", stacklevel=2)
        messages = merge_messages([Message(Role.USER, ex.teacher_prompt)])
        prompt = llm.messages_to_prompt(messages)
        # Undo XML escaping of angle brackets so the model sees literal tags.
        prompt = prompt.replace("&lt;", "<").replace("&gt;", ">")
        prompts.append((prompt, max_tokens_to_generate))
        exercises.append(ex)
    return prompts, exercises

def process_answers(
    llm: LLM,
    exercise: Exercise,
    answers: list[str],
):
    """Wrap the generated answers of one exercise into an ``ExerciseWithAnswers``.

    Each answer (a plain string, or an object exposing ``.text``) is
    re-tokenized; it is flagged as truncated when its last token is not one of
    the model's terminators, and it is stored decoded without its last token.
    """

    def _as_choice(raw):
        text = raw if isinstance(raw, str) else raw.text
        token_ids = llm.tokenize(text)
        stop_ids = llm.get_terminators()
        ended_cleanly = token_ids[0, -1] in stop_ids
        # NOTE(review): the last token is stripped even when it is not a
        # terminator, so truncated answers lose one content token - confirm
        # that this is intended.
        body = llm.decode(token_ids[:, :-1])
        return Choice(body, not ended_cleanly)

    answer_choices = [_as_choice(raw) for raw in answers]
    system_message = Message(Role.SYSTEM, exercise.teacher_prompt_with_tips_tags)
    return ExerciseWithAnswers(
        [system_message],
        answer_choices,
        model_answer=exercise.model_answer,
        grading_str=exercise.grading_str,
    )


def save_to_xml(lesson_id: str, exercises_with_answers: list[ExerciseWithAnswers],
                temperature: float, n_choices: int, llama3_8b: bool, llama3_70b: bool,
                qwen25_3b: bool, qwen25_14b: bool, qwen25_72b: bool, vllm: bool,):
    """Serialize exercises with their answers into an XML file under ``DATA_PATH``.

    The file name starts with ``lesson_id`` and gets one suffix per
    non-default setting, in this order: ``_x<n_choices>`` (when not 1),
    ``_t<temperature>`` (when not 1), one suffix per enabled model flag, and
    ``_vllm``; it always ends in ``.xml``.

    Args:
        lesson_id: Base name of the output file.
        exercises_with_answers: Items appended to the XML root through their
            ``to_xml`` method.
        temperature: Sampling temperature, stored in a ``temperature`` element.
        n_choices: Number of answers per exercise (file name only).
        llama3_8b, llama3_70b, qwen25_3b, qwen25_14b, qwen25_72b: Model flags
            (file name only).
        vllm: Whether to add the ``_vllm`` suffix.
    """
    root = ET.Element("exercises_with_answers")
    ET.SubElement(root, "temperature", value=str(temperature))

    for ex in exercises_with_answers:
        ex.to_xml(root)

    # (enabled, suffix) pairs; the order defines the order inside the file name.
    suffixes = [
        (n_choices != 1, f"_x{n_choices}"),
        (temperature != 1, f"_t{temperature}"),
        (llama3_8b, "_llama3-8b"),
        (llama3_70b, "_llama3-70b"),
        (qwen25_3b, "_qwen2.5-3b"),
        (qwen25_14b, "_qwen2.5-14b"),
        (qwen25_72b, "_qwen2.5-72b"),
        (vllm, "_vllm"),
    ]
    fname = f"{lesson_id}" + "".join(suffix for enabled, suffix in suffixes if enabled) + ".xml"
    path = DATA_PATH / fname
    # Explicit encoding: generated answers may contain non-ASCII text and the
    # default text encoding depends on the locale.
    with open(path, "w", encoding="utf-8") as file:
        xml_dump(root, file)

    print(f"Saved to {path}")


# System prompts used as the opening message for each model family.
_LLAMA3_SYSTEM_PROMPT = (
    "You are a knowledgeable assistant trained to provide accurate and helpful information. "
    "Please respond to the user's queries promptly."
)
_QWEN25_SYSTEM_PROMPT = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

# base name -> (save_to_xml flag name, vLLM model id, system prompt)
_BASE_MODELS = {
    "llama3-8b-instruct": ("llama3_8b", "meta-llama/Meta-Llama-3-8B-Instruct", _LLAMA3_SYSTEM_PROMPT),
    "llama3-70b-instruct": ("llama3_70b", "meta-llama/Meta-Llama-3-70B-Instruct", _LLAMA3_SYSTEM_PROMPT),
    "qwen2.5-72b-instruct": ("qwen25_72b", "Qwen/Qwen2.5-72B-Instruct", _QWEN25_SYSTEM_PROMPT),
    "qwen2.5-14b-instruct": ("qwen25_14b", "Qwen/Qwen2.5-14B-Instruct", _QWEN25_SYSTEM_PROMPT),
    "qwen2.5-3b-instruct": ("qwen25_3b", "Qwen/Qwen2.5-3B-Instruct", _QWEN25_SYSTEM_PROMPT),
}

# Datasets whose lesson file name carries the bonito generation settings.
_BONITO_DATASETS = frozenset({
    "nyt_cot", "nyt_default",
    "new_wiki_cot", "new_wiki_default",
    "amazon_cot", "amazon_default",
    "reddit_cot", "reddit_default",
})
# Datasets whose lesson file name only carries the item count.
_EXAM_DATASETS = frozenset({
    "new_wiki_default_exam", "amazon_default_exam",
    "reddit_default_exam", "nyt_default_exam",
})


def _lesson_names(
    dataset: str,
    bonito_model: str,
    bonito_questions: int,
    bonito_temperature: float,
    bonito_max_items: int,
) -> list[str]:
    """Return the lesson-file base names for ``dataset`` (NotImplementedError if unknown)."""
    if dataset in _BONITO_DATASETS:
        return [f"{dataset}_{bonito_model}_{bonito_questions}_{bonito_temperature}_{bonito_max_items}"]
    if dataset in _EXAM_DATASETS:
        return [f"{dataset.removesuffix('_exam')}_{bonito_max_items}"]
    raise NotImplementedError(f"Unknown dataset: {dataset!r}")


def main(
         base: str = "llama3-8b-instruct",
         generate_lesson: bool = False, 
         generate_exam: bool = False, 
         lesson_num_choices: int = 1, 
         exam_num_choices: int = 1,
         lesson_temp: float = 1.5, 
         exam_temp: float = 0.25, 
         max_total_tokens: int = 1024, 
         max_new_tokens: int = 500, 
         dataset: str = "nyt_default", 
         bonito_model: str = "llama3-8b-instruct", 
         bonito_max_items: int = 1000, 
         bonito_questions: int = 30, 
         bonito_temperature: float = 1.5, 
         verbose: bool = False,
         vllm: bool = True,
    ):
    """Generate answers for lesson and/or exam exercises with vLLM and save them as XML.

    For every selected mode the file ``<mode>_<lesson name>.xml`` is read
    (prefixed with ``curriculum/`` unless the working directory already
    contains "curriculum"), one prompt per exercise is built, all prompts are
    sent to vLLM, and the answers are written with ``save_to_xml``.

    Args:
        base: Model key; must be one of the keys of ``_BASE_MODELS``.
        generate_lesson: Process the ``lesson_*`` file.
        generate_exam: Process the ``exam_*`` file.
        lesson_num_choices: Answers sampled per lesson exercise.
        exam_num_choices: Answers sampled per exam exercise.
        lesson_temp: Sampling temperature for lessons.
        exam_temp: Sampling temperature for exams.
        max_total_tokens: Token limit forwarded to ``generate_prompt``.
        max_new_tokens: Answer length cap forwarded to ``generate_prompt``.
        dataset: Dataset key selecting the lesson file name.
        bonito_model: Part of the lesson file name.
        bonito_max_items: Part of the lesson file name.
        bonito_questions: Part of the lesson file name.
        bonito_temperature: Part of the lesson file name.
        verbose: Currently unused.
        vllm: Only controls the ``_vllm`` suffix of the output file name;
            generation always goes through vLLM.

    Raises:
        ValueError: If ``base`` is not a supported model.
        NotImplementedError: If ``dataset`` is not a supported dataset.
        RuntimeError: If vLLM returns an unexpected number of outputs.
    """
    # Validate the cheap inputs before any model is loaded.
    if base not in _BASE_MODELS:
        raise ValueError(f"Unsupported base model: {base!r}. Expected one of {sorted(_BASE_MODELS)}.")
    selected_flag, vllm_model, system_prompt = _BASE_MODELS[base]
    lesson_names = _lesson_names(dataset, bonito_model, bonito_questions,
                                 bonito_temperature, bonito_max_items)

    modes = []
    if generate_lesson:
        modes.append('lesson')
    if generate_exam:
        modes.append('exam')
    if not modes:
        warnings.warn("Neither generate_lesson nor generate_exam is set; nothing to do.", stacklevel=2)
        return

    # Exactly one model flag is True; these are forwarded to save_to_xml.
    model_flags = {flag: flag == selected_flag for flag, _, _ in _BASE_MODELS.values()}

    llm = LLM(base, opening_message=Message(Role.SYSTEM, system_prompt))
    vllm_client = vLLM(
        vllm_model,
        tensor_parallel_size=torch.cuda.device_count()
    )

    temperatures = {
        'lesson': lesson_temp,
        'exam': exam_temp,
    }
    num_choices = {
        'lesson': lesson_num_choices,
        'exam': exam_num_choices,
    }
    xml_names_params = [(f"{mode}_{name}.xml", temperatures[mode], num_choices[mode])
                        for mode in modes for name in lesson_names]

    for xml_name, t, n_choices in xml_names_params:
        print(f"Processing {xml_name}", flush=True)
        if "curriculum" not in os.getcwd():
            xml_name = f"curriculum/{xml_name}"
        try:
            lessons = read_lessons(xml_name)
        except ET.ParseError:
            # Retry with a cleaned copy when the file is not well-formed XML.
            lessons = read_lessons(clean_xml_content(xml_name))

        prompts = []
        exercises = []
        last_lesson_id = None
        print(f"Number of lessons: {len(lessons)}", flush=True)
        for lesson_id, lesson in lessons.items():
            p, e = generate_prompt(llm, lesson, max_total_tokens, max_new_tokens)
            prompts += p
            exercises += e
            last_lesson_id = lesson_id

        print(f"Number of prompts: {len(prompts)}", flush=True)
        if not prompts:
            warnings.warn(f"No exercises found in {xml_name}; skipping.", stacklevel=2)
            continue

        # One SamplingParams per prompt so that the per-exercise token budget
        # computed by generate_prompt is honoured, and ``n`` matches the
        # requested number of answers per exercise.
        per_prompt_params = [
            SamplingParams(
                n=n_choices,
                temperature=t,
                top_k=50,
                include_stop_str_in_output=True,
                skip_special_tokens=False,
                # vLLM rejects max_tokens < 1; keep over-long prompts runnable.
                max_tokens=max(1, budget),
            )
            for _, budget in prompts
        ]
        prompts_only = [p for p, _ in prompts]

        start_time = time.time()
        outputs = vllm_client.generate(prompts_only, per_prompt_params)
        answers = [output.outputs for output in outputs]
        end_time = time.time()
        print(f"Generation time: {end_time - start_time:.4f} s", flush=True)

        if len(answers) != len(exercises):
            raise RuntimeError(f"Expected {len(exercises)} outputs from vLLM, got {len(answers)}.")
        if any(len(choices) != n_choices for choices in answers):
            raise RuntimeError(f"vLLM did not return {n_choices} answer(s) for every prompt.")

        exercises_with_answers = [
            process_answers(llm, ex, ans) for ex, ans in zip(exercises, answers)
        ]

        # The output name is derived from the id of the last lesson in the
        # file, with everything after its final underscore removed.
        save_to_xml(last_lesson_id.rsplit('_', 1)[0], exercises_with_answers, t, n_choices,
                    **model_flags, vllm=vllm)


if __name__ == "__main__":
    # Script entry point: jsonargparse builds the command line from main().
    import jsonargparse

    jsonargparse.CLI(main)
