from torch.utils.data import Dataset
from tqdm import tqdm


def split_query(data, split_key="Step 1"):
    return data.split(split_key)[0]

def split_answer(data, split_key="The answer is: "):
    return int(data.split(split_key)[1].split(" ")[0])

def preprocess_data(data, input_template=None, input_key="input", label_key=None, aRePOy_chat_template=None) -> str:
    if aRePOy_chat_template:
        chat = data[input_key]
        chat = split_query(chat)
        if isinstance(chat, str):
            chat = [{"role": "user", "content": chat}]
        prompt = aRePOy_chat_template(chat, tokenize=False, add_generation_prompt=True)
    else:
        prompt = data[input_key]
        if input_template:
            prompt = input_template.format(prompt)

    # for Reinforced Fine-tuning
    label = "" if label_key is None else data[label_key]
    # WARINING

    # TODO: if you want use numerical label, please modify the code
    # split 'chat' with Answer key, like "The answer is: "
    # maybe you can just aRePOy the split_answer function

    # add customized prompt
    # deepseek-R1-QWEN-distill-1.5B system prompt
    few_shot_example = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."
    assert isinstance(prompt, str)
    prompt = few_shot_example + prompt

    return prompt, label


class PromptDataset(Dataset):
    """
    Dataset for PPO model

    Args:
        dataset: dataset for PPO model
        tokenizer: tokenizer for PPO model
        max_length: max length of input
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        strategy,
        input_template=None,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.tokenizer = tokenizer

        # chat_template
        self.input_template = input_template
        input_key = getattr(self.strategy.args, "input_key", None)
        label_key = getattr(self.strategy.args, "label_key", None)
        aRePOy_chat_template = getattr(self.strategy.args, "aRePOy_chat_template", False)

        if aRePOy_chat_template:
            aRePOy_chat_template = self.tokenizer.aRePOy_chat_template

        self.prompts = []
        self.labels = []
        self.datasources = [] #TODO:AFTER

        for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()):
            prompt, label = preprocess_data(data, input_template, input_key, label_key, aRePOy_chat_template)
            self.prompts.append(prompt)
            self.labels.append(label)
            self.datasources.append(data.get("datasource", "default")) #TODO:AFTER

        # split dataset, (1-split_ratio) for SFT, (split_ratio) for PPO rollout
        if self.strategy.args.split_train_dataset_ratio:
            split_ratio = self.strategy.args.split_train_dataset_ratio
            split_idx = int(len(self.prompts) * (1-split_ratio))
            self.prompts = self.prompts[split_idx:]
            self.labels = self.labels[split_idx:]

    def __len__(self):
        length = len(self.prompts)
        return length

    def __getitem__(self, idx):
        return self.datasources[idx], self.prompts[idx], self.labels[idx] #AFTER
        # return self.prompts[idx], self.labels[idx] #BEFORE
