from typing import List
from jinja2 import Template
from thinker_task.ppo import PromptDataset

# Jinja templates for the supported prompt formats. They are compiled once at
# import time instead of on every call, because process_prompt() is invoked for
# every dataset sample. All use jinja2's default Template settings, so a single
# trailing newline at the end of a source is dropped when rendering.

# prompt_type == 0: answer-formatting instruction wrapped around the question.
_INSTRUCTION_TEMPLATE_0 = Template("""\
You must put your answer inside <answer> </answer> tags, i.e., <answer> answer here </answer>. And your final answer will be extracted automatically by the \\boxed{} tag.
This is the problem:
{{prompt}}
""")

# prompt_type == 0: User/Assistant conversation frame around the rendered
# instruction. It ends with an open "<think>" tag for the model to continue.
_CONVERSATION_TEMPLATE_0 = Template("""\
{{bos_token}}A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the User with the answer. \
The reasoning process is enclosed within <think> </think> and answer is enclosed within <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {{prompt}}
Assistant: <think>\
""")

# prompt_type == 1: single turn delimited by <|im_start|>/<|im_end|> markers,
# also ending with an open "<think>" tag.
_IM_START_TEMPLATE_1 = Template(
    "<|im_start|>User: {{prompt}} Let's think step by step and output the final answer within \\boxed{}.\n<|im_end|>\n<|im_start|>Assistant: <think>"
)


def process_prompt(question, tokenizer, prompt_type=0):
    """Wrap a raw question in the prompt format selected by ``prompt_type``.

    Args:
        question: The problem text to embed in the prompt.
        tokenizer: Only consulted for ``prompt_type == 0``. If its
            ``bos_token_id`` is not ``None``, that id is decoded and prepended
            to the prompt; otherwise nothing is prepended.
        prompt_type: ``0`` selects the User/Assistant <think>/<answer> format,
            with an extra answer-formatting instruction around the question.
            ``1`` selects the ``<|im_start|>`` single-turn format.

    Returns:
        The rendered prompt string, which ends with an open ``<think>`` tag.

    Raises:
        NotImplementedError: If ``prompt_type`` is neither 0 nor 1.
    """
    if prompt_type == 0:
        instruction = _INSTRUCTION_TEMPLATE_0.render(prompt=question)
        if tokenizer.bos_token_id is None:
            bos_token = ""
        else:
            bos_token = tokenizer.decode([tokenizer.bos_token_id])
        return _CONVERSATION_TEMPLATE_0.render(bos_token=bos_token, prompt=instruction)

    if prompt_type == 1:
        return _IM_START_TEMPLATE_1.render(prompt=question)

    raise NotImplementedError(
        f"Unsupported prompt_type {prompt_type!r}; expected 0 or 1."
    )

class CustomDataset(PromptDataset):
    """Training prompt dataset whose records map ``"problem"`` and ``"answer"``.

    The constructor is inherited unchanged from ``PromptDataset``.
    """

    def process_dialogue(self, dialogue: dict):
        """Turn one training record into a ``(prompt, extra)`` pair.

        Reads ``self.no_template``, ``self.tokenizer`` and ``self.prompt_type``.
        These are presumably set by ``PromptDataset.__init__`` — confirm.

        Args:
            dialogue: Mapping with at least a ``"problem"`` key (question
                text) and an ``"answer"`` key (reference answer). Any other
                keys are ignored.

        Returns:
            ``(prompt, extra)``, where ``extra`` is ``{"answer": ...}``. If
            ``self.no_template`` is truthy, ``prompt`` is the raw problem
            text. Otherwise it is the text rendered by ``process_prompt``
            using ``self.prompt_type``.

        Raises:
            AssertionError: If a required key is missing. Asserts are
                stripped under ``python -O``.
        """
        # Validate the keys that are actually read, rather than the item
        # count: the old ``len == 2`` check rejected records carrying extra
        # metadata, yet let a 2-key dict without these keys reach a KeyError.
        assert "problem" in dialogue, "dialogue must contain problem"
        assert "answer" in dialogue, "dialogue must contain answer"

        prompt = dialogue["problem"]
        extra = {"answer": dialogue["answer"]}

        if self.no_template:
            return prompt, extra
        prompt = process_prompt(prompt, self.tokenizer, prompt_type=self.prompt_type)
        return prompt, extra


class EvalCustomDataset(PromptDataset):
    """Evaluation prompt dataset built from conversation-style records.

    Construction is delegated entirely to the inherited ``PromptDataset``
    constructor.
    """

    def process_dialogue(self, dialogue: dict):
        """Turn one evaluation record into a ``(prompt, extra)`` pair.

        Args:
            dialogue: Dict holding ``"prompt"``, ``"final_answer"`` and
                ``"file_name"``. The question text is taken from
                ``dialogue["prompt"][0]["value"]``.

        Returns:
            ``(prompt, extra)``, where ``extra`` is
            ``{"answer": final_answer, "file_name": file_name}``. The prompt is
            returned as-is when ``self.no_template`` is truthy. Otherwise it is
            rendered through ``process_prompt`` using ``self.prompt_type``.

        Raises:
            AssertionError: If ``dialogue`` is not a dict or lacks a required
                key. Asserts are stripped under ``python -O``.
        """
        assert isinstance(dialogue, dict), "dialogue must be a dict"
        for required_key in ("prompt", "final_answer", "file_name"):
            assert required_key in dialogue, f"dialogue must contain {required_key}"

        question = dialogue["prompt"][0]["value"]
        metadata = {
            "answer": dialogue["final_answer"],
            "file_name": dialogue["file_name"],
        }

        if not self.no_template:
            question = process_prompt(question, self.tokenizer, prompt_type=self.prompt_type)
        return question, metadata
