from torch.utils.data import Dataset
from tqdm import tqdm


def preprocess_data(
    data,
    input_template=None,
    input_key="input",
    chosen_label_key=None,
    reject_label_key=None,
    apply_chat_template=None,
) -> tuple:
    """Build the prompt and the optional labels for a single dataset record.

    Args:
        data: Record supporting key lookup (``data[key]``).
        input_template: Optional format string with one positional
            placeholder. Applied to the raw input with ``str.format`` only
            when ``apply_chat_template`` is not given.
        input_key: Key under which the raw input is stored in ``data``.
        chosen_label_key: Key of the chosen label in ``data``, or ``None``.
        reject_label_key: Key of the rejected label in ``data``, or ``None``.
        apply_chat_template: Optional callable invoked as
            ``apply_chat_template(chat, tokenize=False,
            add_generation_prompt=True)``. When truthy it takes precedence
            over ``input_template``.

    Returns:
        A ``(prompt, chosen_label, reject_label)`` tuple. A label is the
        empty string when its key is ``None``.
    """
    if apply_chat_template:
        chat = data[input_key]
        # A bare string is wrapped as a single user turn so the template
        # callable always receives a list of messages.
        if isinstance(chat, str):
            chat = [{"role": "user", "content": chat}]
        prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    else:
        prompt = data[input_key]
        if input_template:
            prompt = input_template.format(prompt)

    # for Reinforced Fine-tuning
    chosen_label = "" if chosen_label_key is None else data[chosen_label_key]
    reject_label = "" if reject_label_key is None else data[reject_label_key]
    return prompt, chosen_label, reject_label


class PromptDataset(Dataset):
    """
    Dataset for PPO model

    Every record is preprocessed once in ``__init__``; indexing afterwards
    only reads from the precomputed lists.

    Args:
        dataset: iterable of records; each record must support key lookup
            and ``.get("datasource", ...)``
        tokenizer: tokenizer for PPO model; its ``apply_chat_template``
            method is used when ``strategy.args.apply_chat_template`` is truthy
        strategy: object exposing ``args`` (read for ``input_key``,
            ``chosen_label_key``, ``reject_label_key`` and
            ``apply_chat_template``) and ``is_rank_0()``
        input_template: optional format string applied to the raw input
            when no chat template is used
    """

    def __init__(
        self,
        dataset,
        tokenizer,
        strategy,
        input_template=None,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.tokenizer = tokenizer

        # chat_template
        self.input_template = input_template
        args = self.strategy.args
        input_key = getattr(args, "input_key", None)
        if input_key is None:
            # Forwarding None would override preprocess_data's own "input"
            # default and turn every lookup into ``data[None]``.
            input_key = "input"
        chosen_label_key = getattr(args, "chosen_label_key", None)
        reject_label_key = getattr(args, "reject_label_key", None)

        # The args flag only switches the feature on; the callable itself
        # comes from the tokenizer.
        apply_chat_template = getattr(args, "apply_chat_template", False)
        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template

        self.prompts = []
        self.chosen_labels = []
        self.reject_labels = []
        self.datasources = []
        # Show the progress bar on rank 0 only.
        for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()):
            prompt, chosen_label, reject_label = preprocess_data(
                data,
                input_template=input_template,
                input_key=input_key,
                chosen_label_key=chosen_label_key,
                reject_label_key=reject_label_key,
                apply_chat_template=apply_chat_template,
            )
            self.prompts.append(prompt)
            self.chosen_labels.append(chosen_label)
            self.reject_labels.append(reject_label)
            self.datasources.append(data.get("datasource", "default"))

    def __len__(self):
        """Return the number of preprocessed prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return ``(datasource, prompt, chosen_label, reject_label)`` for ``idx``."""
        return self.datasources[idx], self.prompts[idx], self.chosen_labels[idx], self.reject_labels[idx]
