"""
python -m core.data
"""
import os
import re
import json
import math
import random
from functools import partial
from tqdm import tqdm
from rich import print

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import (
    Dataset,
    DataLoader,
    DistributedSampler,
)
from datasets import load_from_disk

# Seed Python's global RNG at import time. Code in this module that draws from the
# global `random` module (e.g. random.randint / random.getrandbits in the synthetic
# datasets) therefore produces the same values on every run.
random.seed(0)

# Instruction text asking the model to put its reasoning inside <think> </think> and its
# final answer inside <answer> </answer>; get_format_reward checks for that layout.
# NOTE(review): not referenced in this part of the file — presumably prepended to prompts
# elsewhere; confirm before editing the wording, since it is part of the model input.
message_prompt = '''
You will conduct a conversation where you are an Assistant helping a User. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user
with the answer. The entire reasoning process should be enclosed in <think> </think> and the final answer should be enclosed in <answer> </answer> tags.
For example, the format of the assistant's response is:
<think> your reasoning process here </think>
<answer> your final answer here </answer>
Thinking should be before the answer.
'''


def get_format_reward(text):
    """
    Format reward: validate the <think>/<answer> layout of a model response.

    After stripping surrounding whitespace, the response passes only when:
      * there is exactly one <think>...</think> block and exactly one <answer>...</answer> block,
      * neither block is empty or whitespace-only,
      * the text starts with the <think> block, the <answer> block follows it (optionally
        after whitespace), and </answer> is the very end of the text.

    Args:
        text (str): The model response to validate.

    Returns:
        bool: True if every condition holds, False otherwise.
    """
    candidate = text.strip()

    thoughts = re.findall(r"<think>(.*?)</think>", candidate, re.DOTALL)
    answers = re.findall(r"<answer>(.*?)</answer>", candidate, re.DOTALL)

    # Exactly one block of each kind, and both must carry real (non-whitespace) content.
    well_counted = len(thoughts) == 1 and len(answers) == 1
    if not (well_counted and thoughts[0].strip() and answers[0].strip()):
        return False

    # Placement: <think> first, then <answer>, with </answer> closing the text.
    layout = re.match(r"<think>.*?</think>\s*<answer>.*?</answer>$", candidate, re.DOTALL)
    return layout is not None


# Jinja chat template that tokenizer_factory installs on Llama tokenizers.
# Emits '<|begin_of_text|>' once, then for every message:
#   <|start_header_id|>{role}<|end_header_id|> + newline(s) + content + <|eot_id|>
# Roles other than system/user/assistant raise via raise_exception(...). When
# add_generation_prompt is set and the last message is not from the assistant, an
# assistant header is appended so generation starts an assistant turn.
# NOTE(review): the system header and the generation-prompt header are followed by a
# single newline, while the user/assistant headers are followed by a blank line. This
# looks like it differs from the stock Llama 3 template — confirm it is intentional,
# because tokenize_messages relies on these exact strings when computing loss masks.
llama_chat_template = '''{{ '<|begin_of_text|>' }}{% for message in messages %}{% if message['role'] == 'system' %}<|start_header_id|>system<|end_header_id|>
{{ message['content'] }}<|eot_id|>{% elif message['role'] == 'user' %}<|start_header_id|>user<|end_header_id|>

{{ message['content'] }}<|eot_id|>{% elif message['role'] == 'assistant' %}<|start_header_id|>assistant<|end_header_id|>

{{ message['content'] }}<|eot_id|>{% else %}{{ raise_exception("Invalid role: " + message['role']) }}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}<|start_header_id|>assistant<|end_header_id|>
{% endif %}
'''

def tokenizer_factory(tokenizer, dataset_split='eval'):
    """
    Configure a Hugging Face tokenizer in place for this module's chat formatting.

    The model family is detected from ``tokenizer.name_or_path`` (case-insensitive,
    checked in the order llama, gemma, qwen). Every supported family gets an
    ``assistant_start_text`` attribute, which tokenize_messages uses to mask the assistant
    header. Llama additionally gets the module's custom chat template and '<|eot_id|>' as
    EOS; Qwen gets '<|im_end|>' as EOS; Gemma keeps its stock template and EOS.

    Args:
        tokenizer: tokenizer to mutate; it is also returned.
        dataset_split: 'train' or 'eval'. Currently does not change the configuration.

    Returns:
        The same tokenizer object.

    Raises:
        AssertionError: if ``dataset_split`` is not 'train' or 'eval'.
        ValueError: if the model family is not recognized.
    """
    assert dataset_split in ('train', 'eval')
    model_id = tokenizer.name_or_path.lower()

    if 'llama' in model_id:
        tokenizer.chat_template = llama_chat_template
        tokenizer.assistant_start_text = '<|start_header_id|>assistant<|end_header_id|>'
        tokenizer.eos_token = '<|eot_id|>'
        return tokenizer

    if 'gemma' in model_id:
        # Gemma keeps the template and EOS token shipped with the tokenizer.
        tokenizer.assistant_start_text = '<start_of_turn>model\n'
        return tokenizer

    if 'qwen' in model_id:
        tokenizer.eos_token = '<|im_end|>'
        tokenizer.assistant_start_text = '<|im_start|>assistant\n'
        return tokenizer

    raise ValueError(f'Unsupported model: {tokenizer.name_or_path}')


def tokenize_messages(messages, tokenizer, max_seq_length, dataset_split='eval', length_buffer=200, **kwargs):
    """
    Tokenize a chat conversation and build labels that train only on assistant turns.

    Code modified from: https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L310

    Each message is a dict with 'role' and 'content' fields. The whole conversation is
    rendered with ``tokenizer.apply_chat_template`` and tokenized once. Every non-assistant
    message is then masked with -100 in ``labels`` so it is excluded from the loss. When the
    next message is an assistant reply, ``tokenizer.assistant_start_text`` (the assistant
    header) is masked as well.

    Args:
        messages: list of ``{'role': ..., 'content': ...}`` dicts.
        tokenizer: tokenizer exposing ``apply_chat_template``, ``__call__`` and an
            ``assistant_start_text`` attribute (tokenizer_factory sets the latter).
        max_seq_length: nominal maximum sequence length.
        dataset_split: unused; kept so existing call sites keep working.
        length_buffer: conversations longer than ``max_seq_length + length_buffer`` tokens
            are dropped. Conversations inside the buffer are kept at full, untruncated
            length.
        **kwargs: ignored; accepted so callers can forward their dataset kwargs.

    Returns:
        dict with 1-D ``input_ids``, ``attention_mask`` and ``labels`` tensors of equal
        length, or ``None`` if the conversation exceeds the length budget.

    Raises:
        ValueError: if ``messages`` is empty.
    """
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    apply_chat_template = partial(tokenizer.apply_chat_template, tokenize=False, add_generation_prompt=False)

    def num_tokens(text):
        # Prefixes are tokenized WITHOUT truncation. The full example is not truncated
        # either, so mask boundaries stay correct for conversations that run past
        # max_seq_length but are still inside the buffer.
        return tokenizer(text, return_tensors='pt', truncation=False, add_special_tokens=False).input_ids.size(1)

    example_text = apply_chat_template(messages).strip()
    input_ids = tokenizer(example_text, return_tensors='pt', truncation=False, add_special_tokens=False).input_ids
    labels = input_ids.clone()
    total_len = input_ids.size(1)

    # NOTE: this is where length selection happens; allow `length_buffer` extra tokens.
    if total_len > max_seq_length + length_buffer:
        return None

    # Mask the non-assistant parts to exclude them from the loss calculation.
    for message_idx, message in enumerate(messages):
        if message['role'] == 'assistant':
            continue
        if message_idx == 0:
            message_start_idx = 0
        else:
            message_start_idx = num_tokens(apply_chat_template(messages[:message_idx]))
        messages_so_far = apply_chat_template(messages[:message_idx + 1])
        if message_idx < len(messages) - 1 and messages[message_idx + 1]['role'] == 'assistant':
            # Also mask the header that introduces the following assistant reply.
            messages_so_far = messages_so_far + tokenizer.assistant_start_text
        message_end_idx = num_tokens(messages_so_far)
        labels[:, message_start_idx:message_end_idx] = -100

        # Nothing left to mask once the mask reaches the end of the sequence.
        if message_end_idx >= total_len:
            break

    attention_mask = torch.ones_like(input_ids)
    return dict(
        input_ids=input_ids.flatten(),
        attention_mask=attention_mask.flatten(),
        labels=labels.flatten(),
    )


class BaseDataset(Dataset):
    """
    Base class for the tokenized chat datasets in this module.

    Subclasses populate ``self.data`` with a dict of tensors keyed by ``datapoint_ids``,
    ``input_ids``, ``attention_mask`` and ``labels``, all sharing the same first dimension
    (one row per example). The evaluation helpers additionally read ``self.tokenizer``,
    ``self.max_seq_length``, ``self.datapoints`` (a list of dicts with a ``'messages'``
    field) and ``self.id_to_datapoint``.
    """

    def __init__(self):
        # No shared state; subclasses build everything in their own __init__.
        pass

    def __len__(self):
        """Number of examples, i.e. the number of rows in ``self.data['input_ids']``."""
        return self.data['input_ids'].size(0)

    def __getitem__(self, idx):
        """Return row ``idx`` of every tensor in ``self.data`` as a dict."""
        return dict(
            datapoint_ids=self.data['datapoint_ids'][idx],
            input_ids=self.data['input_ids'][idx],
            attention_mask=self.data['attention_mask'][idx],
            labels=self.data['labels'][idx],
        )

    def collate_fn(self, batch):
        """
        Stack per-example tensors into batch tensors.

        torch.stack needs every example to have the same shape; the subclasses pad all rows
        to a common width up front with pad_sequence.
        """
        return dict(
            datapoint_ids=torch.stack([b['datapoint_ids'] for b in batch]),
            input_ids=torch.stack([b['input_ids'] for b in batch]),
            attention_mask=torch.stack([b['attention_mask'] for b in batch]),
            labels=torch.stack([b['labels'] for b in batch]),
        )

    def get_dataloader(self, dataset, batch_size, shuffle, use_distributed=False, **kwargs):
        """
        Build a DataLoader over ``dataset`` that uses this object's ``collate_fn``.

        Args:
            dataset: dataset to load from (usually ``self``).
            batch_size: examples per batch.
            shuffle: whether to shuffle (handed to the sampler in distributed mode).
            use_distributed: when True, shard with a DistributedSampler. ``kwargs`` must then
                contain ``seed``, ``local_rank`` and ``world_size``; otherwise a KeyError
                is raised.

        Returns:
            tuple: ``(dataloader, sampler)``; ``sampler`` is None when not distributed.
        """
        if use_distributed:
            return distributed_dataloader(
                dataset=dataset,
                batch_size=batch_size,
                collate_fn=self.collate_fn,
                shuffle=shuffle,
                seed=kwargs['seed'],
                local_rank=kwargs['local_rank'],
                world_size=kwargs['world_size'],
            )
        return default_dataloader(
            dataset=dataset,
            batch_size=batch_size,
            collate_fn=self.collate_fn,
            shuffle=shuffle,
        )

    @classmethod
    def construct_dataloader(
        cls,
        num_datapoints,
        max_seq_length,
        tokenizer,
        batch_size,
        shuffle,
        use_distributed=False,
        dataset_split=None,
        **kwargs,
    ):
        """
        Instantiate the dataset and wrap it in a DataLoader.

        ``kwargs`` are forwarded both to the dataset constructor and to get_dataloader.

        Returns:
            tuple: ``(dataloader, sampler, dataset)``.
        """
        dataset = cls(
            num_datapoints=num_datapoints,
            max_seq_length=max_seq_length,
            tokenizer=tokenizer,
            dataset_split=dataset_split,
            **kwargs,
        )
        dataloader, sampler = dataset.get_dataloader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            use_distributed=use_distributed,
            **kwargs,
        )
        return dataloader, sampler, dataset

    def _iter_eval_texts(self, batch_size, chat_template_fn=None):
        """
        Yield ``(tokenizer, batch_input_texts, batch_datapoints)`` over ``self.datapoints``.

        Shared batching logic for the two public eval generators. By default each prompt is
        rendered with the tokenizer's chat template plus a generation prompt.
        """
        tokenizer = tokenizer_factory(self.tokenizer, dataset_split='eval')

        if chat_template_fn is None:
            chat_template_fn = partial(
                tokenizer.apply_chat_template,
                tokenize=False,
                add_generation_prompt=True,
            )

        num_batches = math.ceil(len(self.datapoints) / batch_size)
        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            batch_datapoints = self.datapoints[start:start + batch_size]
            batch_input_texts = [chat_template_fn(datapoint['messages']) for datapoint in batch_datapoints]
            yield tokenizer, batch_input_texts, batch_datapoints

    def get_eval_batcher(
        self,
        batch_size,
        chat_template_fn=None,
        **kwargs,
    ):
        """
        Yield ``(inputs, batch_datapoints)`` where ``inputs`` is the tokenized batch.

        Texts are padded (on the tokenizer's configured padding side) and truncated to
        ``self.max_seq_length``. ``kwargs`` is accepted for call compatibility and ignored.
        """
        for tokenizer, batch_input_texts, batch_datapoints in self._iter_eval_texts(batch_size, chat_template_fn):
            inputs = tokenizer.batch_encode_plus(
                batch_input_texts,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=self.max_seq_length,
                add_special_tokens=False,
            )
            yield inputs, batch_datapoints

    def get_eval_batch_input_texts(
        self,
        batch_size,
        chat_template_fn=None,
        **kwargs,
    ):
        """
        Yield ``(batch_input_texts, batch_datapoints)`` with untokenized prompt strings.

        ``kwargs`` is accepted for call compatibility and ignored.
        """
        for _, batch_input_texts, batch_datapoints in self._iter_eval_texts(batch_size, chat_template_fn):
            yield batch_input_texts, batch_datapoints

    def tokenize_for_inference_by_ids(self, datapoint_ids: list[int]):
        """
        Left-pad and tokenize generation prompts for the given datapoint ids.

        The last message of each datapoint is dropped (``messages[:-1]``) before the chat
        template is applied with a generation prompt.
        NOTE(review): presumably that last message is the reference assistant reply; confirm
        for every dataset that uses this.
        ``padding_side`` is passed as a call argument, which requires a transformers version
        that accepts it there.
        """
        tokenizer = tokenizer_factory(self.tokenizer, dataset_split='eval')

        datapoints = [self.id_to_datapoint[idx] for idx in datapoint_ids]
        messages = [datapoint['messages'][:-1] for datapoint in datapoints]
        input_texts = [
            tokenizer.apply_chat_template(
                msg,
                tokenize=False,
                add_generation_prompt=True,
            )
            for msg in messages
        ]
        encoded = tokenizer.batch_encode_plus(
            input_texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=False,
            max_length=self.max_seq_length,
            padding_side='left',
        )
        return encoded


def default_dataloader(dataset, batch_size, collate_fn, shuffle):
    """
    Single-process DataLoader factory.

    Keeps the final partial batch (drop_last=False) and does not pin memory.

    Returns:
        tuple: ``(dataloader, None)``; the second slot is the absent sampler, so the return
        shape matches distributed_dataloader.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'collate_fn': collate_fn,
        'drop_last': False,
        'pin_memory': False,
    }
    return DataLoader(dataset, **loader_options), None


def distributed_dataloader(dataset, batch_size, collate_fn, shuffle, seed, local_rank, world_size):
    """
    DataLoader factory that shards ``dataset`` across ``world_size`` replicas.

    Shuffling is delegated to the DistributedSampler (seeded with ``seed``). The trailing
    partial batch is dropped and memory is pinned.
    NOTE(review): ``local_rank`` is passed as the sampler's global ``rank``; that is only
    correct on a single node — confirm for multi-node runs.

    Returns:
        tuple: ``(dataloader, sampler)``. Callers typically need the sampler to call
        ``set_epoch`` each epoch.
    """
    shard_sampler = DistributedSampler(
        dataset,
        shuffle=shuffle,
        seed=seed,
        num_replicas=world_size,
        rank=local_rank,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=shard_sampler,
        collate_fn=collate_fn,
        drop_last=True,
        pin_memory=True,
    )
    return loader, shard_sampler


class RandomIndexSequenceDataset(BaseDataset):
    """
    Synthetic dataset of random token-id sequences, intended for smoke tests.

    Each example has a length drawn uniformly from [1, max_seq_length] (global ``random``
    module) and token ids drawn uniformly from [0, vocab_size) (torch RNG). Labels are a
    copy of the inputs.
    """

    def __init__(self, num_datapoints, max_seq_length, vocab_size, dataset_split=None, **kwargs):
        # **kwargs lets construct_dataloader forward loader options (seed, local_rank,
        # world_size, ...) without raising TypeError; they are not used here.
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.vocab_size = vocab_size
        self.dataset_split = dataset_split

        self.data = []
        for _ in range(num_datapoints):
            seq_length = random.randint(1, max_seq_length)
            input_ids = torch.randint(0, vocab_size, (seq_length,))
            self.data.append(input_ids)

    def __len__(self):
        return self.num_datapoints

    def __getitem__(self, idx):
        input_ids = self.data[idx]
        labels = input_ids.clone()
        return dict(input_ids=input_ids, labels=labels)

    def collate_fn(self, batch):
        """
        Right-pad a batch of variable-length examples.

        Padded positions get input id 0, attention mask 0, and label -100, so padding is
        ignored by the loss (the same convention as the other datasets in this module).
        """
        batch_max_len = max(len(item['input_ids']) for item in batch)
        input_ids = torch.zeros((len(batch), batch_max_len), dtype=torch.long)
        labels = torch.full((len(batch), batch_max_len), -100, dtype=torch.long)
        attention_mask = torch.zeros((len(batch), batch_max_len), dtype=torch.long)

        for i, item in enumerate(batch):
            seq_length = len(item['input_ids'])
            input_ids[i, :seq_length] = item['input_ids']
            labels[i, :seq_length] = item['labels']
            attention_mask[i, :seq_length] = 1

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )

    @classmethod
    def construct_dataloader(
        cls,
        num_datapoints,
        max_seq_length,
        batch_size,
        shuffle,
        use_distributed=False,
        vocab_size=10000,
        **kwargs,
    ):
        """
        Instantiate the dataset and wrap it in a DataLoader.

        Returns:
            tuple: ``(dataloader, sampler, dataset)``; ``sampler`` is None when not
            distributed.
        """
        dataset = cls(
            num_datapoints=num_datapoints,
            max_seq_length=max_seq_length,
            vocab_size=vocab_size,
            **kwargs,
        )
        # get_dataloader returns (dataloader, sampler); unpack it instead of returning the tuple.
        dataloader, sampler = dataset.get_dataloader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            use_distributed=use_distributed,
            **kwargs,
        )
        return dataloader, sampler, dataset


class RandomKVRDataset(BaseDataset):
    """
    Synthetic key-value retrieval SFT dataset.

    Each example is a user turn asking for the value of a random 8-character key, followed by
    an assistant turn containing a random 8-character value. Keys and values come from the
    global ``random`` module, which is seeded at import, so they are reproducible.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split=None, **kwargs):
        super().__init__()
        import uuid
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.model_name = kwargs.get('model_name', None)
        self.id_to_datapoint = dict()

        # First 8 characters of a UUID built from 128 random bits.
        uuid_str = lambda: str(uuid.UUID(int=random.getrandbits(128)))[:8]
        _data = [(uuid_str(), uuid_str()) for _ in range(num_datapoints)]
        input_ids = []
        attention_mask = []
        labels = []
        datapoint_ids = []
        for datapoint_idx, (k, v) in enumerate(_data):
            messages = [
                dict(role='user', content=f'For key: {k}, the value is what?'),
                dict(role='assistant', content=v),
            ]
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue
            # Record the id only for kept examples so ids stay aligned with the tensors.
            datapoint_ids.append(datapoint_idx)
            input_ids.append(tokenized['input_ids'])
            attention_mask.append(tokenized['attention_mask'])
            labels.append(tokenized['labels'])
            self.id_to_datapoint[datapoint_idx] = dict(messages=messages, targets=[v])
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        datapoint_ids = torch.tensor(datapoint_ids).long()
        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        # Raw datapoints in tensor-row order, for the BaseDataset eval helpers.
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]


class UltraChatDataset(BaseDataset):
    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', **kwargs):
        """
        SFT dataset over a locally saved copy of HuggingFaceH4/ultrachat_200k.

        Max 208000 examples. The 'eval' split reads ``test_sft``; every other split reads
        ``train_sft``. Conversations that tokenize_messages rejects as too long are skipped.
        After building, the padded tensors are saved to a cache file next to the dataset.

        Args:
            num_datapoints: number of examples to keep, or 'all'.
            max_seq_length: forwarded to tokenize_messages.
            tokenizer: tokenizer used for templating and padding.
            dataset_split: 'train' or 'eval'.
            **kwargs: forwarded to tokenize_messages. ``model_name`` is stored.
                ``load_cache`` (default False) reuses an existing cache file instead of
                re-tokenizing; loaded caches carry no raw datapoints, so ``datapoints``
                is then empty.
        """
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.model_name = kwargs.get('model_name', None)

        self.data = []
        self.id_to_datapoint = dict()

        if num_datapoints == 'all':
            num_datapoints = 208000

        path = 'path_to/yourname/data/HuggingFaceH4/ultrachat_200k'
        cache_path = f'path_to/yourname/data/HuggingFaceH4/ultrachat_200k/cache_split={dataset_split}_n={num_datapoints}_len={max_seq_length}.pt'
        if kwargs.get('load_cache', False) and os.path.exists(cache_path):
            self.data = torch.load(cache_path, weights_only=True)
            # Row indices stand in for ids; the original dataset indices are not cached.
            datapoint_ids = torch.arange(self.data['input_ids'].size(0))
            self.data['datapoint_ids'] = datapoint_ids
            self.datapoints = []

            print(f'Data loaded from: {cache_path}')
        else:
            dataset = load_from_disk(path)
            _dataset_split = 'test' if dataset_split == 'eval' else 'train'
            dataset = dataset[_dataset_split + '_sft']

            datapoint_ids = []
            all_tokenized = []
            for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
                if len(all_tokenized) >= num_datapoints:
                    break
                messages = datapoint['messages']
                tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
                if tokenized is None:
                    continue
                all_tokenized.append(tokenized)
                datapoint_ids.append(datapoint_idx)
                self.id_to_datapoint[datapoint_idx] = datapoint
            input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
            attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
            labels = [tokenized['labels'] for tokenized in all_tokenized]

            datapoint_ids = torch.tensor(datapoint_ids).long()
            input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
            attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
            labels = pad_sequence(labels, batch_first=True, padding_value=-100)

            self.data = dict(
                datapoint_ids=datapoint_ids,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            # Raw datapoints in tensor-row order, for the BaseDataset eval helpers.
            self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]
            torch.save(self.data, cache_path)
            print(f'Data saved to: {cache_path}')


class UltraFeedbackDataset(BaseDataset):
    """
    Preference-pair dataset over a locally saved UltraFeedback dataset.

    Each example holds a tokenized ``chosen`` and ``rejected`` conversation. The 'eval'
    split reads ``test_prefs``; every other split reads ``train_prefs``. A pair is dropped
    when either side is over the length budget. Chosen and rejected tensors are padded
    independently, so their widths may differ.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.model_name = kwargs.get('model_name', None)

        self.data = []

        path = 'path_to/yourname/data/ultrafeedback'
        dataset = load_from_disk(path)

        _dataset_split = 'test' if dataset_split == 'eval' else 'train'
        dataset = dataset[f'{_dataset_split}_prefs']
        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))

        datapoint_ids = []
        all_chosen_tokenized = []
        all_rejected_tokenized = []
        for datapoint_id, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            # Chosen and rejected lists always grow together, so one length check suffices.
            if len(all_chosen_tokenized) >= num_datapoints:
                break

            chosen_tokenized = tokenize_messages(datapoint['chosen'], tokenizer, max_seq_length, dataset_split, **kwargs)
            rejected_tokenized = tokenize_messages(datapoint['rejected'], tokenizer, max_seq_length, dataset_split, **kwargs)
            if chosen_tokenized is None or rejected_tokenized is None:
                continue
            # Record the id only for kept pairs so ids stay aligned with the tensors below.
            datapoint_ids.append(datapoint_id)
            all_chosen_tokenized.append(chosen_tokenized)
            all_rejected_tokenized.append(rejected_tokenized)

        chosen_input_ids, chosen_attention_mask, chosen_labels = self._pad_tokenized(all_chosen_tokenized)
        rejected_input_ids, rejected_attention_mask, rejected_labels = self._pad_tokenized(all_rejected_tokenized)

        self.datapoint_ids = torch.tensor(datapoint_ids).long()

        self.data = dict(
            datapoint_ids=self.datapoint_ids,
            chosen_input_ids=chosen_input_ids,
            chosen_attention_mask=chosen_attention_mask,
            chosen_labels=chosen_labels,
            rejected_input_ids=rejected_input_ids,
            rejected_attention_mask=rejected_attention_mask,
            rejected_labels=rejected_labels,
        )

    def _pad_tokenized(self, all_tokenized):
        """Pad a list of tokenize_messages outputs into (input_ids, attention_mask, labels) tensors."""
        input_ids = pad_sequence(
            [tokenized['input_ids'] for tokenized in all_tokenized],
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        attention_mask = pad_sequence(
            [tokenized['attention_mask'] for tokenized in all_tokenized],
            batch_first=True,
            padding_value=0,
        )
        labels = pad_sequence(
            [tokenized['labels'] for tokenized in all_tokenized],
            batch_first=True,
            padding_value=-100,
        )
        return input_ids, attention_mask, labels

    def __len__(self):
        return self.data['chosen_input_ids'].size(0)

    def __getitem__(self, idx):
        return dict(
            datapoint_id=self.datapoint_ids[idx],
            chosen_input_ids=self.data['chosen_input_ids'][idx],
            chosen_attention_mask=self.data['chosen_attention_mask'][idx],
            chosen_labels=self.data['chosen_labels'][idx],
            rejected_input_ids=self.data['rejected_input_ids'][idx],
            rejected_attention_mask=self.data['rejected_attention_mask'][idx],
            rejected_labels=self.data['rejected_labels'][idx],
        )

    def collate_fn(self, batch):
        """Stack pre-padded chosen/rejected tensors; the per-item 'datapoint_id' becomes 'datapoint_ids'."""
        batch = dict(
            datapoint_ids=torch.stack([b['datapoint_id'] for b in batch]),
            chosen_input_ids=torch.stack([b['chosen_input_ids'] for b in batch]),
            chosen_attention_mask=torch.stack([b['chosen_attention_mask'] for b in batch]),
            chosen_labels=torch.stack([b['chosen_labels'] for b in batch]),
            rejected_input_ids=torch.stack([b['rejected_input_ids'] for b in batch]),
            rejected_attention_mask=torch.stack([b['rejected_attention_mask'] for b in batch]),
            rejected_labels=torch.stack([b['rejected_labels'] for b in batch]),
        )
        return batch
    

class TuluV2Dataset(BaseDataset):
    """
    SFT dataset over a locally saved copy of allenai/tulu-v2-sft-mixture.

    Only the 'train' split is read from disk regardless of ``dataset_split``, which is
    stored and forwarded to tokenize_messages. Conversations rejected as too long are
    skipped until ``num_datapoints`` examples (or all of them, for 'all') have been kept.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = {}
        self.model_name = kwargs.get('model_name', None)

        source = load_from_disk('path_to/yourname/data/allenai/tulu-v2-sft-mixture')['train']

        limit = len(source) if num_datapoints == 'all' else min(num_datapoints, len(source))

        kept_ids = []
        kept = []
        for example_idx, example in tqdm(enumerate(source), total=limit):
            if len(kept) >= limit:
                break
            encoded = tokenize_messages(example['messages'], tokenizer, max_seq_length, dataset_split, **kwargs)
            if encoded is None:
                continue
            kept.append(encoded)
            kept_ids.append(example_idx)
            self.id_to_datapoint[example_idx] = example

        ids_tensor = torch.tensor(kept_ids).long()
        padded_inputs = pad_sequence(
            [e['input_ids'] for e in kept], batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        padded_mask = pad_sequence([e['attention_mask'] for e in kept], batch_first=True, padding_value=0)
        padded_labels = pad_sequence([e['labels'] for e in kept], batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=ids_tensor,
            input_ids=padded_inputs,
            attention_mask=padded_mask,
            labels=padded_labels,
        )
        # Raw examples in tensor-row order, for the BaseDataset eval helpers.
        self.datapoints = [self.id_to_datapoint[i] for i in ids_tensor.tolist()]


class MMLUDataset(BaseDataset):
    """
    MMLU multiple-choice dataset built from the locally saved MMLU ``test`` split.

    The test split is shuffled with a fixed seed and cut into a 'train' part (the first
    12000 questions) and an 'eval' part (the rest). Each question becomes one user message.
    For the 'train' split, unless ``for_generation`` is set, the gold option letter is
    appended as the assistant reply.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/_code/data/mmlu/all'
        dataset = list(load_from_disk(path)['test'])
        # A private RNG seeded with 0 gives the same permutation as random.seed(0) followed
        # by random.shuffle, without resetting the global `random` state as a side effect.
        random.Random(0).shuffle(dataset)

        if dataset_split == 'train':
            dataset = dataset[:12000]
        elif dataset_split == 'eval':
            dataset = dataset[12000:]
        else:
            raise ValueError(f'Invalid dataset split: {dataset_split}')

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        self.num_datapoints = min(num_datapoints, len(dataset))

        options = list('ABCD')
        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=self.num_datapoints):
            if len(all_tokenized) >= self.num_datapoints:
                break
            question = datapoint['question']
            choices = datapoint['choices']
            answer_option = options[datapoint['answer']]
            choices_text = '\n'.join([f'{c}. {s}' for c, s in zip(options, choices)])
            user_content = f'{question}\n\nAnswer options:\n{choices_text}\n\nReason about it and answer with "The answer is: <option>"'
            messages = [dict(role='user', content=user_content)]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=answer_option))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue
            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                targets=[answer_option],
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """
        Extract the predicted option from a generation such as "The answer is: B".

        Asterisks (markdown bold) are removed first. Because ``(.*)`` is greedy, the last
        "answer is" on a line wins. Returns the upper-cased word after it, or None when
        there is no match or ``output_text`` is None.
        """
        if output_text is None:
            return None

        output_text = output_text.replace('*', '')
        match = re.search(r'\bThe(.*)answer is(?: option)?:?\s*(\w+)', output_text, re.IGNORECASE)
        if match:
            return match.group(2).upper().strip()
        return None

    def print_accuracy(self, corrects):
        """
        Print per-category and overall accuracy.

        Args:
            corrects: mapping from category name to a list of per-question correctness
                values (bools or 0/1). Empty categories, and an empty total, are reported
                as 'n/a' instead of raising ZeroDivisionError.
        """
        print('### Accuracy ###')
        num_total_correct = 0
        num_total = 0
        for category, category_corrects in corrects.items():
            num_category_correct = sum(category_corrects)
            num_category_total = len(category_corrects)
            if num_category_total:
                print(f'{category}: {num_category_correct / num_category_total:.2f}')
            else:
                print(f'{category}: n/a')

            num_total_correct += num_category_correct
            num_total += num_category_total
        if num_total:
            print(f'Average: {num_total_correct / num_total:.2f}')
        else:
            print('Average: n/a')
        print('################')

    def reward_fn(self, pred_text, datapoint):
        """Return True when ``pred_text`` is one of the datapoint's ``targets``."""
        targets = datapoint['targets']
        return pred_text in targets


class ExpertSFTMMLUDataset(BaseDataset):
    """
    SFT dataset distilled from Llama-3.3-70B-Instruct predictions on MMLU.

    Reads a JSONL predictions file and keeps only rows marked ``correct``. Each row's
    ``output_text`` (stripped) is trained on as the assistant reply to its stored
    ``messages``. ``for_generation`` is accepted for interface parity with MMLUDataset
    and is unused.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        dataset = []
        path = 'path_to/yourname/local_models/meta-llama/Llama-3.3-70B-Instruct/mmlu/train/predictions.jsonl'
        # JSON text is UTF-8; be explicit so loading does not depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue
                datapoint = json.loads(line)
                if not datapoint['correct']:
                    continue
                dataset.append(datapoint)

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))
        self.num_datapoints = num_datapoints
        # The slice already bounds the loop, so no early-exit check is needed below.
        dataset = dataset[:num_datapoints]

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            messages = datapoint['messages']
            # Appended in place, so the stored datapoint's messages end with the expert
            # reply. tokenize_for_inference_by_ids strips it again with messages[:-1].
            messages.append(dict(role='assistant', content=datapoint['output_text'].strip()))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [None]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(**datapoint)

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]


class ExpertSFTContrastiveMMLUDataset(BaseDataset):
    """SFT dataset of Llama-3.3-70B-Instruct MMLU predictions mixing correct and incorrect outputs.

    Each prediction's ``output_text`` is appended to its ``messages`` as the
    assistant turn. Correct predictions are always candidates. An incorrect one
    is admitted only while the admitted-incorrect count is below the
    admitted-correct count, so incorrect items never outnumber correct ones.
    Every item also exposes ``corrects``, the prediction's ``correct`` flag cast
    to long (presumably 1 = correct, 0 = incorrect).
    """

    # Per-item fields, in the order they are returned and collated.
    _ITEM_KEYS = ('datapoint_ids', 'input_ids', 'attention_mask', 'labels', 'corrects')

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/local_models/meta-llama/Llama-3.3-70B-Instruct/mmlu/train/predictions.jsonl'
        with open(path, 'r') as f:
            records = [json.loads(line.strip()) for line in f]

        if num_datapoints == 'all':
            num_datapoints = len(records)
        num_datapoints = min(num_datapoints, len(records))
        records = records[:num_datapoints]

        kept_ids = []
        kept_tokenized = []
        kept_corrects = []
        n_correct = 0
        n_incorrect = 0
        # `records` is already truncated to num_datapoints, so no explicit cap is needed here.
        for idx, record in tqdm(enumerate(records), total=num_datapoints):
            is_correct = record['correct']
            if not is_correct and n_incorrect >= n_correct:
                # Admitting this negative would let negatives outnumber positives.
                continue

            conversation = record['messages']
            conversation.append(dict(role='assistant', content=record['output_text'].strip()))
            tokenized = tokenize_messages(conversation, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            if is_correct:
                n_correct += 1
            else:
                n_incorrect += 1

            kept_corrects.append(is_correct)
            record['targets'] = [None]
            kept_tokenized.append(tokenized)
            kept_ids.append(idx)
            self.id_to_datapoint[idx] = dict(**record)

        datapoint_ids = torch.tensor(kept_ids).long()
        corrects = torch.tensor(kept_corrects).long()
        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=pad_sequence([t['input_ids'] for t in kept_tokenized], batch_first=True, padding_value=self.tokenizer.pad_token_id),
            attention_mask=pad_sequence([t['attention_mask'] for t in kept_tokenized], batch_first=True, padding_value=0),
            labels=pad_sequence([t['labels'] for t in kept_tokenized], batch_first=True, padding_value=-100),
            corrects=corrects,
        )
        self.datapoints = [self.id_to_datapoint[i] for i in datapoint_ids.tolist()]

    def __getitem__(self, idx):
        """Return one item as a dict of the per-field tensors at position ``idx``."""
        return {key: self.data[key][idx] for key in self._ITEM_KEYS}

    def collate_fn(self, batch):
        """Stack a list of items (as produced by ``__getitem__``) field by field."""
        return {key: torch.stack([item[key] for item in batch]) for key in self._ITEM_KEYS}


class TriviaQADataset(BaseDataset):
    """TriviaQA questions, graded by case-insensitive match against the answer aliases.

    ``dataset_split='eval'`` maps to the dataset's ``validation`` split. When
    training (``dataset_split == 'train'`` and not ``for_generation``) the
    canonical answer value is appended as the assistant turn.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/mandarjoshi/trivia_qa'
        dataset = load_from_disk(path)

        _dataset_split = 'validation' if dataset_split == 'eval' else dataset_split
        dataset = dataset[_dataset_split]

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        # Clamp so the progress-bar total and the stored count never exceed the split size.
        num_datapoints = min(num_datapoints, len(dataset))
        self.num_datapoints = num_datapoints

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= self.num_datapoints:
                break
            question = datapoint['question']
            answer = datapoint['answer']['value']
            targets = datapoint['answer']['aliases']
            user_content = f'{question}\n\nReason about it and answer with "The answer is: <option>"'
            messages = [dict(role='user', content=user_content)]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=answer))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                # Examples for which tokenize_messages returns None are skipped.
                continue
            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                targets=targets,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def reward_fn(self, pred_text, datapoint):
        """Return True if ``pred_text`` equals one of the datapoint's target aliases.

        The comparison is case-insensitive and ignores surrounding whitespace.
        A missing prediction (``None``) is scored False instead of raising.
        """
        if pred_text is None:
            return False
        targets = {t.strip().lower() for t in datapoint['targets']}
        return pred_text.strip().lower() in targets


class AlpacaEvalDataset(BaseDataset):
    """AlpacaEval instruction prompts, tokenized for generation.

    Replaces a leftover debugging loop (``print(datapoint); input()``). That loop
    blocked on stdin and never built ``self.data``. Items are now built with the
    same layout as the other datasets in this module.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/tatsu-lab/alpaca_eval'
        dataset = load_from_disk(path)
        # load_from_disk returns a DatasetDict (a dict subclass) when the saved copy
        # has named splits; select the requested one in that case.
        if isinstance(dataset, dict):
            dataset = dataset[dataset_split]

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))
        self.num_datapoints = num_datapoints

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break
            # NOTE(review): assumes the public AlpacaEval schema ('instruction' prompt,
            # 'output' reference completion) -- confirm against the saved copy.
            instruction = datapoint['instruction']
            reference = datapoint['output']
            messages = [dict(role='user', content=instruction)]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=reference))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue
            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            # Spread first so our keys win over any same-named column instead of raising.
            self.id_to_datapoint[datapoint_idx] = {
                **datapoint,
                'messages': messages,
                'targets': [reference],
            }

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]


class CollieDataset(BaseDataset):
    """COLLIE constrained-generation prompts, each graded by its own constraint object.

    Items are drawn in ``constraint_types`` order. ``num_datapoints`` caps the
    number of kept items; ``'all'`` keeps every item.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        import dill
        import collie  # noqa: F401 -- presumably required so dill can resolve COLLIE constraint classes
        # TODO: the default is a placeholder; pass data_path=... to locate the pickled COLLIE data.
        # Popped so it is not forwarded to tokenize_messages.
        data_path = kwargs.pop('data_path', '')
        constraint_types = [
            'ccnews_c04', 'ccnews_c05', 'ccnews_c06a', 'ccnews_c07', 'ccnews_c08', 'ccnews_c09', 'ccnews_c10', 'ccnews_c11', 'ccnews_c12', 'ccnews_c14',
            'english_c01', 'english_c02', 'english_c03', 
            'guten_c04', 'guten_c05', 'guten_c06a', 'guten_c07', 'guten_c08', 'guten_c09', 'guten_c10', 'guten_c11', 'guten_c12', 'guten_c14',
            'wiki_c04', 'wiki_c05', 'wiki_c06a', 'wiki_c07', 'wiki_c08', 'wiki_c09', 'wiki_c10', 'wiki_c11', 'wiki_c12', 'wiki_c14'
        ]

        # NOTE(review): dill.load can execute arbitrary code -- only point this at trusted files.
        with open(data_path, 'rb') as f:
            data = dill.load(f)

        # None means no cap ('all'); previously num_datapoints was ignored entirely.
        limit = None if num_datapoints == 'all' else num_datapoints

        all_tokenized = []
        datapoint_idx = 0
        datapoint_ids = []
        candidates = (
            (constraint_type, datapoint)
            for constraint_type in constraint_types
            for datapoint in data[constraint_type]
        )
        for constraint_type, datapoint in candidates:
            if limit is not None and len(all_tokenized) >= limit:
                break
            prompt = datapoint['prompt'] + '\n\nWrite your response starting with ### START ### and ending with ### END ###.'
            constraint = datapoint.pop('constraint')
            text = datapoint['example']
            targets = datapoint['targets']
            messages = [dict(role='user', content=prompt)]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=text))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue
            all_tokenized.append(tokenized)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                targets=targets,
                constraint_type=constraint_type,
                constraint=constraint,
                example=text,
                prompt=prompt,
            )
            datapoint_ids.append(datapoint_idx)
            # Ids are contiguous over kept items only.
            datapoint_idx += 1

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Return the text between the first ``### START ###`` and the following ``### END ###``.

        Without a START marker, return the whole (stripped) text. Without an END
        marker after START, return everything after START up to the next START
        marker, if any.
        """
        if output_text is None:
            return None
        segments = output_text.split('### START ###')
        if len(segments) == 1:
            return output_text.strip()
        return segments[1].split('### END ###')[0].strip()

    def reward_fn(self, pred_text, datapoint):
        """Check ``pred_text`` with the datapoint's COLLIE constraint; ``None`` scores False."""
        if pred_text is None:
            return False
        constraint = datapoint['constraint']
        targets = datapoint['targets']
        return constraint.check(pred_text, targets)


class GSM8KDataset(BaseDataset):
    """GSM8K math word problems, graded on the final number of the answer.

    ``dataset_split='eval'`` maps to the dataset's ``test`` split. The gold target
    is the last number in the reference ``answer``, which ends in a
    ``#### <number>`` footer.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/openai/gsm8k'
        dataset = load_from_disk(path)
        _dataset_split = 'test' if dataset_split == 'eval' else dataset_split
        dataset = dataset[_dataset_split]
        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        self.num_datapoints = min(num_datapoints, len(dataset))

        all_tokenized = []
        datapoint_ids = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            # Caps the number of rows scanned, not the number of rows kept.
            if datapoint_idx >= num_datapoints:
                break

            messages = [dict(role='user', content=datapoint['question'])]

            answer = datapoint['answer']
            if _dataset_split == 'train' and not for_generation:
                # Train on the worked solution only; the '####' footer is dropped.
                messages.append(dict(role='assistant', content=answer.split('####')[0].strip()))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)

            datapoint['targets'] = [self.parse_output_text(answer)]
            if tokenized is None:
                continue
            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Return the last number in ``output_text``, without ``$``, thousands commas or a trailing '.'.

        If the text contains no digits, fall back to the last ``\\boxed{...}`` and
        then the last ``$...$`` span. Return None if nothing matches.
        """
        if output_text is None:
            return None

        # The number must contain a digit; the old pattern '\$?[\d,]+' also matched
        # a bare ',' (e.g. in "Thanks, bye"), which parsed to ''.
        numbers = re.findall(r'\$?\d[\d,]*\.?\d*', output_text)
        if numbers:
            number = re.sub(r'[,$]', '', numbers[-1])
            return re.sub(r'\.$', '', number)

        # No digits anywhere: try LaTeX-style answers.
        matches = re.findall(r'\\boxed{((?:[^{}]|{[^{}]*})*)}', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        return None

    def reward_fn(self, pred_text, datapoint):
        """Return True if ``pred_text`` matches a target exactly or numerically.

        The numeric fallback accepts e.g. "18.00" for a target of "18".
        """
        targets = datapoint['targets']
        if pred_text in targets:
            return True
        try:
            pred_value = float(pred_text)
        except (TypeError, ValueError):
            return False
        for target in targets:
            try:
                if math.isclose(pred_value, float(target)):
                    return True
            except (TypeError, ValueError):
                continue
        return False


class MATHDataset(BaseDataset):
    """Hendrycks competition MATH, graded on the final boxed answer.

    ``dataset_split='eval'`` maps to the dataset's ``test`` split. The gold
    target is extracted from the reference ``solution`` with
    ``parse_output_text``, the same routine used on model outputs.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/hendrycks/competition_math'
        dataset = load_from_disk(path)

        _dataset_split = 'test' if dataset_split == 'eval' else dataset_split
        dataset = list(dataset[_dataset_split])
        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break
            problem = datapoint['problem']
            solution = datapoint['solution']

            messages = [dict(role='user', content=f'{problem}\n\nReason about the problem, derive your answer, and wrap your final answer in $\\boxed{{ }}$.')]
            if _dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=solution))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [self.parse_output_text(solution)]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    @staticmethod
    def _extract_last_boxed(text):
        """Return the contents of the last brace-balanced ``\\boxed{...}`` in ``text``, or None.

        Braces are counted, so arbitrarily nested contents such as
        ``\\frac{\\sqrt{3}}{2}`` are kept intact. The old single-nesting regex
        failed on these. Unterminated occurrences are skipped in favor of an
        earlier complete one.
        """
        marker = '\\boxed{'
        pos = text.rfind(marker)
        while pos != -1:
            body_start = pos + len(marker)
            depth = 1
            for i in range(body_start, len(text)):
                char = text[i]
                if char == '{':
                    depth += 1
                elif char == '}':
                    depth -= 1
                    if depth == 0:
                        return text[body_start:i]
            pos = text.rfind(marker, 0, pos)
        return None

    def parse_output_text(self, output_text):
        """Extract a final answer from a solution or model output.

        Tries, in order: the last ``\\boxed{...}``, the last ``$...$`` span, and
        the last standalone number. Otherwise returns the canonicalized text.
        """
        if output_text is None:
            return None

        output_text = self._parse(output_text)

        boxed = self._extract_last_boxed(output_text)
        if boxed is not None:
            return boxed.strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'(?:^|[^\d])(\d+(?:\.\d+)?|\.\d+)(?:[^\d]|$)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def _parse(self, text):
        """Drop cosmetic LaTeX differences (``\\dfrac`` -> ``\\frac``, ``\\left``/``\\right`` removed)."""
        # A former text.replace('\pi', '\\pi') replaced '\pi' with itself and used an
        # invalid escape (SyntaxWarning on Python 3.12+); it was removed as a no-op.
        text = text.replace('\\dfrac', '\\frac')
        # NOTE(review): this also strips the prefix of e.g. '\leftarrow'; applied to both sides.
        text = text.replace('\\left', '')
        text = text.replace('\\right', '')
        return text

    def reward_fn(self, pred_text, datapoint):
        """Return True if the normalized prediction equals a normalized target; ``None`` scores False."""
        if pred_text is None:
            return False
        pred_text = self._parse(self.normalize_final_answer(pred_text))
        targets = [
            self._parse(self.normalize_final_answer(target))
            for target in datapoint['targets']
            if target is not None
        ]
        return pred_text in targets


class ExpertSFTMATHDataset(BaseDataset):
    """SFT dataset built from correct Llama-3.3-70B-Instruct predictions on MATH train.

    Only predictions flagged ``correct`` are kept. For each kept prediction,
    its ``output_text`` is appended to the nested datapoint's ``messages`` as
    the assistant turn. The stored datapoint gets ``targets = [None]``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/local_models/meta-llama/Llama-3.3-70B-Instruct/math/train/predictions.jsonl'
        with open(path, 'r') as f:
            loaded = (json.loads(line.strip()) for line in f)
            predictions = [p for p in loaded if p['correct']]

        if num_datapoints == 'all':
            num_datapoints = len(predictions)
        num_datapoints = min(num_datapoints, len(predictions))
        predictions = predictions[:num_datapoints]

        kept_ids = []
        kept_tokenized = []
        # `predictions` is already truncated to num_datapoints, so no explicit cap is needed.
        for idx, prediction in tqdm(enumerate(predictions), total=num_datapoints):
            record = prediction['datapoint']
            conversation = record['messages']
            conversation.append(dict(role='assistant', content=prediction['output_text'].strip()))
            tokenized = tokenize_messages(conversation, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            record['targets'] = [None]
            kept_tokenized.append(tokenized)
            kept_ids.append(idx)
            self.id_to_datapoint[idx] = dict(**record)

        datapoint_ids = torch.tensor(kept_ids).long()
        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=pad_sequence([t['input_ids'] for t in kept_tokenized], batch_first=True, padding_value=self.tokenizer.pad_token_id),
            attention_mask=pad_sequence([t['attention_mask'] for t in kept_tokenized], batch_first=True, padding_value=0),
            labels=pad_sequence([t['labels'] for t in kept_tokenized], batch_first=True, padding_value=-100),
        )
        self.datapoints = [self.id_to_datapoint[i] for i in datapoint_ids.tolist()]


class MATH500Dataset(BaseDataset):
    """MATH-500 problems, graded on the final boxed answer.

    ``dataset_split='eval'`` maps to the dataset's ``test`` split. The gold
    target is extracted from the reference ``solution`` with
    ``parse_output_text``, the same routine used on model outputs.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/HuggingFaceH4/MATH-500'
        dataset = load_from_disk(path)

        _dataset_split = 'test' if dataset_split == 'eval' else dataset_split
        dataset = list(dataset[_dataset_split])
        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break
            problem = datapoint['problem']
            solution = datapoint['solution']

            messages = [dict(role='user', content=f'{problem}\n\nReason about the problem, derive your answer, and wrap your final answer in $\\boxed{{ }}$.')]
            if _dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=solution))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [self.parse_output_text(solution)]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    @staticmethod
    def _extract_last_boxed(text):
        """Return the contents of the last brace-balanced ``\\boxed{...}`` in ``text``, or None.

        Braces are counted, so nested contents such as ``\\frac{\\sqrt{3}}{2}``
        are kept intact. The old single-nesting regex failed on these.
        Unterminated occurrences are skipped in favor of an earlier complete one.
        """
        marker = '\\boxed{'
        pos = text.rfind(marker)
        while pos != -1:
            body_start = pos + len(marker)
            depth = 1
            for i in range(body_start, len(text)):
                char = text[i]
                if char == '{':
                    depth += 1
                elif char == '}':
                    depth -= 1
                    if depth == 0:
                        return text[body_start:i]
            pos = text.rfind(marker, 0, pos)
        return None

    def parse_output_text(self, output_text):
        """Extract a final answer from a solution or model output.

        Tries, in order: the last ``\\boxed{...}``, the last ``$...$`` span, and
        the last standalone number. Otherwise returns the canonicalized text.
        """
        if output_text is None:
            return None

        output_text = self._parse(output_text)

        boxed = self._extract_last_boxed(output_text)
        if boxed is not None:
            return boxed.strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'(?:^|[^\d])(\d+(?:\.\d+)?|\.\d+)(?:[^\d]|$)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def _parse(self, text):
        """Drop cosmetic LaTeX differences (``\\dfrac`` -> ``\\frac``, ``\\left``/``\\right`` removed)."""
        # A former text.replace('\pi', '\\pi') replaced '\pi' with itself and used an
        # invalid escape (SyntaxWarning on Python 3.12+); it was removed as a no-op.
        text = text.replace('\\dfrac', '\\frac')
        # NOTE(review): this also strips the prefix of e.g. '\leftarrow'; applied to both sides.
        text = text.replace('\\left', '')
        text = text.replace('\\right', '')
        return text

    def reward_fn(self, pred_text, datapoint):
        """Return True if the normalized prediction equals a normalized target; ``None`` scores False."""
        if pred_text is None:
            return False
        pred_text = self._parse(self.normalize_final_answer(pred_text))
        targets = [
            self._parse(self.normalize_final_answer(target))
            for target in datapoint['targets']
            if target is not None
        ]
        return pred_text in targets


class BigMathRLVerifiedDataset(BaseDataset):
    """Big-Math-RL-Verified problems, filtered to ``llama8b_solve_rate >= 0.2``.

    The dataset's ``answer`` column is already a final answer. It is used
    directly as the target rather than parsed like a full solution. The old
    parse sent answers such as ``\\frac{1}{2}`` through a last-number fallback
    and turned them into ``"2"``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/SynthLabsAI/Big-Math-RL-Verified'
        dataset = load_from_disk(path)

        # NOTE(review): only the 'train' split is read, whatever dataset_split is,
        # so train and eval instances draw from the same rows and may overlap.
        # dataset_split itself is no longer rewritten ('eval' -> 'test'), so
        # tokenize_messages receives the caller's split like every sibling dataset.
        dataset = list(dataset['train'])

        # NOTE(review): uses the module-level RNG (seeded once at import), so the order
        # depends on any random calls made before construction -- confirm this is intended.
        random.shuffle(dataset)

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break

            llama8b_solve_rate = datapoint['llama8b_solve_rate']
            if llama8b_solve_rate is None or llama8b_solve_rate < 0.2:
                continue

            problem = datapoint['problem']
            answer = datapoint['answer']
            messages = [dict(role='user', content=f'{problem}\n\nReason about the problem, derive your answer, and wrap your final answer in $\\boxed{{ }}$.')]

            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=answer))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [self._gold_target(answer)]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def _gold_target(self, answer):
        """Turn the dataset's final-answer string into a target.

        If a ``\\boxed{...}`` is present, use its contents. Otherwise use the
        stripped answer, with one enclosing ``$...$`` pair removed.
        """
        if answer is None:
            return None
        boxed = self._extract_last_boxed(answer)
        if boxed is not None:
            return boxed.strip()
        answer = answer.strip()
        if len(answer) >= 2 and answer.startswith('$') and answer.endswith('$'):
            answer = answer[1:-1].strip()
        return answer

    @staticmethod
    def _extract_last_boxed(text):
        """Return the contents of the last brace-balanced ``\\boxed{...}`` in ``text``, or None.

        Braces are counted, so nested contents such as ``\\frac{\\sqrt{3}}{2}``
        are kept intact. Unterminated occurrences are skipped in favor of an
        earlier complete one.
        """
        marker = '\\boxed{'
        pos = text.rfind(marker)
        while pos != -1:
            body_start = pos + len(marker)
            depth = 1
            for i in range(body_start, len(text)):
                char = text[i]
                if char == '{':
                    depth += 1
                elif char == '}':
                    depth -= 1
                    if depth == 0:
                        return text[body_start:i]
            pos = text.rfind(marker, 0, pos)
        return None

    def parse_output_text(self, output_text):
        """Extract a final answer from a model output.

        Tries, in order: the last ``\\boxed{...}``, the last ``$...$`` span, and
        the last standalone number. Otherwise returns the canonicalized text.
        """
        if output_text is None:
            return None

        output_text = self._parse(output_text)

        boxed = self._extract_last_boxed(output_text)
        if boxed is not None:
            return boxed.strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'(?:^|[^\d])(\d+(?:\.\d+)?|\.\d+)(?:[^\d]|$)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def _parse(self, text):
        """Drop cosmetic LaTeX differences (``\\dfrac`` -> ``\\frac``, ``\\left``/``\\right`` removed)."""
        # A former text.replace('\pi', '\\pi') replaced '\pi' with itself and used an
        # invalid escape (SyntaxWarning on Python 3.12+); it was removed as a no-op.
        text = text.replace('\\dfrac', '\\frac')
        # NOTE(review): this also strips the prefix of e.g. '\leftarrow'; applied to both sides.
        text = text.replace('\\left', '')
        text = text.replace('\\right', '')
        return text

    def reward_fn(self, pred_text, datapoint):
        """Return True if the normalized prediction equals a normalized target; ``None`` scores False."""
        if pred_text is None:
            return False
        pred_text = self._parse(self.normalize_final_answer(pred_text))
        targets = [
            self._parse(self.normalize_final_answer(target))
            for target in datapoint['targets']
            if target is not None
        ]
        return pred_text in targets


class OpenMathInstruction2Dataset(BaseDataset):
    """nvidia/OpenMathInstruct-2 problems read from a local ``load_from_disk`` copy.

    The saved ``'train'`` split is always read and truncated to the first
    ``num_datapoints`` rows (``'all'`` keeps every row). When
    ``dataset_split == 'train'`` and ``for_generation`` is False, the
    ``generated_solution`` column is appended as the assistant turn (SFT);
    otherwise only the user prompt is tokenized. Every kept datapoint gets
    ``targets = [expected_answer]`` for :meth:`reward_fn`.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/nvidia/OpenMathInstruct-2'
        dataset = load_from_disk(path)['train']
        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))
        dataset = dataset.select(range(num_datapoints))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break
            problem = datapoint['problem']
            expected_answer = datapoint['expected_answer']
            generated_solution = datapoint['generated_solution']

            messages = [dict(role='user', content=f'{problem}\n\nReason about the problem, derive your answer, and wrap your final answer in $\\boxed{{ }}$.')]
            # The reference solution is only shown for SFT on the train split.
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=generated_solution))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            # tokenize_messages returns None for examples it rejects; skip them.
            if tokenized is None:
                continue

            datapoint['targets'] = [expected_answer]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        # -100 is the ignore index for the loss on padded label positions.
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Extract the final answer from a completion.

        Tries the last ``\\boxed{...}`` (one level of nested braces), then the
        last ``$...$`` span, then the last standalone number; falls back to
        the text itself. Returns None for None input.
        """
        if output_text is None:
            return None

        matches = re.findall(r'\\boxed{((?:[^{}]|{[^{}]*})*)}', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        # Lookarounds keep neighbouring characters unconsumed so numbers
        # separated by a single character ("3 5", "(1,2)") are all found.
        matches = re.findall(r'(?<!\d)(\d+(?:\.\d+)?|\.\d+)(?!\d)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def reward_fn(self, pred_text, datapoint):
        """True if the normalized prediction equals any normalized target."""
        pred_text = self.normalize_final_answer(pred_text)
        targets = datapoint['targets']
        targets = [self.normalize_final_answer(target) for target in targets]
        return pred_text in targets


class OmniMathDataset(BaseDataset):
    """KbsdJames/Omni-MATH problems read from a local ``load_from_disk`` copy.

    Only the saved ``'test'`` split is read. Rows ``[0, 3000)`` back
    ``dataset_split='train'`` and rows ``[3000, end)`` back any other split;
    the result is then truncated to ``num_datapoints`` (``'all'`` keeps every
    row). For SFT (``'train'`` and not ``for_generation``) the ``solution``
    column is appended as the assistant turn. ``targets = [answer]``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/KbsdJames/Omni-MATH'

        dataset = load_from_disk(path)['test']
        # Carve a train/eval split out of the single saved split by position.
        if dataset_split == 'train':
            dataset = dataset.select(range(3000))
        else:
            dataset = dataset.select(range(3000, len(dataset)))

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))
        dataset = dataset.select(range(num_datapoints))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break
            problem = datapoint['problem']
            solution = datapoint['solution']
            answer = datapoint['answer']

            messages = [dict(role='user', content=f'{problem}\n\nReason about the problem, derive your answer, and wrap your final answer in $\\boxed{{ }}$.')]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=solution))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            # tokenize_messages returns None for examples it rejects; skip them.
            if tokenized is None:
                continue

            datapoint['targets'] = [answer]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Extract the final answer from a completion.

        Tries the last ``\\boxed{...}`` (one level of nested braces), then the
        last ``$...$`` span, then the last standalone number; falls back to
        the text itself. Returns None for None input.
        """
        if output_text is None:
            return None

        matches = re.findall(r'\\boxed{((?:[^{}]|{[^{}]*})*)}', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        # Lookarounds keep neighbouring characters unconsumed so numbers
        # separated by a single character ("3 5", "(1,2)") are all found.
        matches = re.findall(r'(?<!\d)(\d+(?:\.\d+)?|\.\d+)(?!\d)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def reward_fn(self, pred_text, datapoint):
        """True if the normalized prediction equals any normalized target."""
        pred_text = self.normalize_final_answer(pred_text)
        targets = datapoint['targets']
        targets = [self.normalize_final_answer(target) for target in targets]
        return pred_text in targets


class AIME24Dataset(BaseDataset):
    """AIME 2024 problems (Maxwell-Jia/AIME_2024) read from a local ``load_from_disk`` copy.

    All rows of the saved ``'train'`` split are used, up to ``num_datapoints``.
    ``dataset_split='eval'`` is mapped to ``'test'``. The ``Solution`` text is
    appended as the assistant turn only when the mapped split is ``'train'``
    and ``for_generation`` is False. ``targets`` holds the answer extracted
    from ``Solution`` with :meth:`parse_output_text`.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/Maxwell-Jia/AIME_2024'
        dataset = load_from_disk(path)

        _dataset_split = 'test' if dataset_split == 'eval' else dataset_split
        dataset = dataset['train']
        dataset = list(dataset)
        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break
            problem = datapoint['Problem']
            solution = datapoint['Solution']

            messages = [dict(role='user', content=f'{problem}\n\nReason about the problem, derive your answer, and wrap your final answer in $\\boxed{{ }}$.')]

            if _dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=solution))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            # tokenize_messages returns None for examples it rejects; skip them.
            if tokenized is None:
                continue

            # NOTE(review): the reference answer is re-extracted from the
            # 'Solution' text; the dataset's 'Answer' column is not used.
            # Confirm this is intended.
            parsed_output_text = self.parse_output_text(solution)
            datapoint['targets'] = [parsed_output_text]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Extract the final answer from a completion after :meth:`_parse` normalization.

        Tries the last ``\\boxed{...}`` (one level of nested braces), then the
        last ``$...$`` span, then the last standalone number; falls back to
        the normalized text. Returns None for None input.
        """
        if output_text is None:
            return None

        output_text = self._parse(output_text)

        matches = re.findall(r'\\boxed{((?:[^{}]|{[^{}]*})*)}', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        # Lookarounds keep neighbouring characters unconsumed so numbers
        # separated by a single character ("3 5", "(1,2)") are all found.
        matches = re.findall(r'(?<!\d)(\d+(?:\.\d+)?|\.\d+)(?!\d)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def _parse(self, text):
        """Rewrite ``\\dfrac`` to ``\\frac`` and drop ``\\left`` / ``\\right``."""
        text = text.replace('\\dfrac', '\\frac')
        # A '\pi' -> '\\pi' replace would be a no-op ('\p' is not an escape,
        # so both literals are identical); it only raised a SyntaxWarning.
        text = text.replace('\\left', '')
        text = text.replace('\\right', '')
        return text

    def reward_fn(self, pred_text, datapoint):
        """True if the normalized, _parse'd prediction equals any target."""
        pred_text = self._parse(self.normalize_final_answer(pred_text))
        targets = datapoint['targets']
        targets = [self._parse(self.normalize_final_answer(target)) for target in targets]
        return pred_text in targets


class Tulu3SFTIFEvalDataset(BaseDataset):
    """SFT conversations from allenai/tulu-3-sft-personas-instruction-following.

    Each row's ``messages`` conversation is tokenized as-is; ``for_generation``
    is accepted for interface parity but not consulted. Rows rejected by
    ``tokenize_messages`` are skipped. Kept rows get ``targets = [None]``
    because there is no reference answer to score against.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        source = load_from_disk('path_to/yourname/data/allenai/tulu-3-sft-personas-instruction-following')['train']
        limit = len(source) if num_datapoints == 'all' else num_datapoints
        limit = min(limit, len(source))
        source = source.select(range(limit))

        kept_ids = []
        encoded = []
        for row_idx, row in tqdm(enumerate(source), total=limit):
            if len(encoded) >= limit:
                break
            item = tokenize_messages(row['messages'], tokenizer, max_seq_length, dataset_split, **kwargs)
            if item is None:
                continue
            row['targets'] = [None]
            encoded.append(item)
            kept_ids.append(row_idx)
            self.id_to_datapoint[row_idx] = dict(**row)

        # Pad every field to the longest example; labels use the loss ignore index.
        id_tensor = torch.tensor(kept_ids).long()
        pad_values = (
            ('input_ids', self.tokenizer.pad_token_id),
            ('attention_mask', 0),
            ('labels', -100),
        )
        padded = {
            field: pad_sequence([item[field] for item in encoded], batch_first=True, padding_value=fill)
            for field, fill in pad_values
        }
        self.data = dict(datapoint_ids=id_tensor, **padded)
        self.datapoints = [self.id_to_datapoint[row_id] for row_id in id_tensor.tolist()]


class Tulu3PreferenceForSFTIFEvalDataset(BaseDataset):
    """SFT on the ``chosen`` side of allenai/tulu-3-pref-personas-instruction-following.

    Each row's ``chosen`` conversation is tokenized as-is; ``for_generation``
    is accepted for interface parity but not consulted. Rows rejected by
    ``tokenize_messages`` are skipped. Kept rows get ``targets = [None]``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        source = load_from_disk('path_to/yourname/data/allenai/tulu-3-pref-personas-instruction-following')['train']
        limit = len(source) if num_datapoints == 'all' else num_datapoints
        limit = min(limit, len(source))
        source = source.select(range(limit))

        kept_ids = []
        encoded = []
        for row_idx, row in tqdm(enumerate(source), total=limit):
            if len(encoded) >= limit:
                break
            item = tokenize_messages(row['chosen'], tokenizer, max_seq_length, dataset_split, **kwargs)
            if item is None:
                continue
            row['targets'] = [None]
            encoded.append(item)
            kept_ids.append(row_idx)
            self.id_to_datapoint[row_idx] = dict(**row)

        # Pad every field to the longest example; labels use the loss ignore index.
        id_tensor = torch.tensor(kept_ids).long()
        pad_values = (
            ('input_ids', self.tokenizer.pad_token_id),
            ('attention_mask', 0),
            ('labels', -100),
        )
        padded = {
            field: pad_sequence([item[field] for item in encoded], batch_first=True, padding_value=fill)
            for field, fill in pad_values
        }
        self.data = dict(datapoint_ids=id_tensor, **padded)
        self.datapoints = [self.id_to_datapoint[row_id] for row_id in id_tensor.tolist()]


class Tulu3ExpertSFTIFEvalDataset(BaseDataset):
    """SFT data built from a model's verified instruction-following predictions.

    Reads ``<model_dir>/ifeval_verify/train/predictions.jsonl``, or
    ``<model_dir>/ifeval_verify/train/<subdir>/predictions.jsonl``.
    ``kwargs['data_generated_from_model_name']`` is either ``'<model_dir>'``
    or ``'<model_dir>+<subdir>'``; a subdir of ``'none'`` means no subdir.
    Only records whose ``correct`` field is truthy are kept. Each record's
    ``output_text`` is appended to its ``messages`` as the assistant turn
    (regardless of ``for_generation``) before tokenization.

    Raises:
        ValueError: if ``data_generated_from_model_name`` is missing or empty.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        data_generated_from_model_name = kwargs.get('data_generated_from_model_name')
        if not data_generated_from_model_name:
            # Previously this surfaced as an opaque AttributeError on None.split.
            raise ValueError(
                "Tulu3ExpertSFTIFEvalDataset requires the 'data_generated_from_model_name' "
                "keyword argument ('<model_dir>' or '<model_dir>+<subdir>')."
            )
        names = data_generated_from_model_name.split('+')
        model_dir = names[0]
        if len(names) == 1 or names[1] == 'none':
            path = f'{model_dir}/ifeval_verify/train/predictions.jsonl'
        else:
            path = f'{model_dir}/ifeval_verify/train/{names[1]}/predictions.jsonl'

        dataset = []
        # Explicit UTF-8 so non-ASCII text does not depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines instead of crashing in json.loads.
                if not line:
                    continue
                datapoint = json.loads(line)
                if datapoint['correct']:
                    dataset.append(datapoint)

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))
        dataset = dataset[:num_datapoints]

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break

            # Appends in place, so the stored datapoint's 'messages' includes the reply.
            messages = datapoint['messages']
            messages.append(dict(role='assistant', content=datapoint['output_text'].strip()))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [None]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(**datapoint)

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]


class Tulu3ExpertSFTContrastiveIFEvalDataset(BaseDataset):
    """Correct and incorrect Llama-3.3-70B-Instruct instruction-following predictions.

    Reads a fixed ``predictions.jsonl``. Every record's ``output_text`` is
    appended to its ``messages`` as the assistant turn. Correct records are
    always kept. An incorrect record is kept only while strictly fewer
    incorrect than correct records have been kept so far, which keeps the
    classes roughly balanced. ``self.data['corrects']`` holds the 0/1 label
    per example and is returned by :meth:`__getitem__` and :meth:`collate_fn`.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        dataset = []
        path = 'path_to/yourname/local_models/meta-llama/Llama-3.3-70B-Instruct/ifeval_verify/train/predictions.jsonl'
        # Explicit UTF-8 so non-ASCII text does not depend on the platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Tolerate blank lines instead of crashing in json.loads.
                if not line:
                    continue
                dataset.append(json.loads(line))

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))
        dataset = dataset[:num_datapoints]

        datapoint_ids = []
        all_tokenized = []
        all_corrects = []
        num_pos = 0
        num_neg = 0
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break

            # Balance: drop a negative once negatives have caught up with positives.
            if not datapoint['correct'] and num_neg >= num_pos:
                continue

            messages = datapoint['messages']
            messages.append(dict(role='assistant', content=datapoint['output_text'].strip()))
            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            # Counts only advance for examples that were actually kept.
            if datapoint['correct']:
                num_pos += 1
            else:
                num_neg += 1

            all_corrects.append(datapoint['correct'])
            datapoint['targets'] = [None]
            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(**datapoint)

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        all_corrects = torch.tensor(all_corrects).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            corrects=all_corrects,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def __getitem__(self, idx):
        """Return one example's tensors, including its ``corrects`` label."""
        return dict(
            datapoint_ids=self.data['datapoint_ids'][idx],
            input_ids=self.data['input_ids'][idx],
            attention_mask=self.data['attention_mask'][idx],
            labels=self.data['labels'][idx],
            corrects=self.data['corrects'][idx],
        )

    def collate_fn(self, batch):
        """Stack per-example tensors (already padded to a common length) into a batch."""
        return dict(
            datapoint_ids=torch.stack([b['datapoint_ids'] for b in batch]),
            input_ids=torch.stack([b['input_ids'] for b in batch]),
            attention_mask=torch.stack([b['attention_mask'] for b in batch]),
            labels=torch.stack([b['labels'] for b in batch]),
            corrects=torch.stack([b['corrects'] for b in batch]),
        )


class Tulu3RLVRIFEvalDataset(BaseDataset):
    """Verifiable instruction-following prompts from allenai/RLVR-IFeval.

    The saved ``'train'`` split is divided by position. Rows ``[0, 13000)``
    back ``dataset_split='train'`` and the remaining rows back any other
    split. ``targets`` is the JSON-decoded ``ground_truth``. Its
    ``func_name`` entry selects the checker in ``IF_FUNCTIONS_MAP``, and
    the whole dict is forwarded to that checker as keyword arguments.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.ifeval_utils import IF_FUNCTIONS_MAP
        self.if_functions_map = IF_FUNCTIONS_MAP
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        source = load_from_disk('path_to/yourname/data/allenai/RLVR-IFeval')['train']
        split_rows = range(13000) if dataset_split == 'train' else range(13000, len(source))
        source = source.select(split_rows)

        limit = len(source) if num_datapoints == 'all' else num_datapoints
        limit = min(limit, len(source))
        source = source.select(range(limit))

        wants_assistant_turn = dataset_split == 'train' and not for_generation
        kept_ids = []
        encoded = []
        for row_idx, row in tqdm(enumerate(source), total=limit):
            if len(encoded) >= limit:
                break
            conversation = row['messages']
            if wants_assistant_turn:
                # TODO: handle this (placeholder reply; no reference response exists)
                conversation.append(dict(role='assistant', content='...'))
            item = tokenize_messages(conversation, tokenizer, max_seq_length, dataset_split, **kwargs)
            if item is None:
                continue
            row['targets'] = json.loads(row['ground_truth'])
            encoded.append(item)
            kept_ids.append(row_idx)
            self.id_to_datapoint[row_idx] = dict(**row)

        # Pad every field to the longest example; labels use the loss ignore index.
        id_tensor = torch.tensor(kept_ids).long()
        pad_values = (
            ('input_ids', self.tokenizer.pad_token_id),
            ('attention_mask', 0),
            ('labels', -100),
        )
        padded = {
            field: pad_sequence([item[field] for item in encoded], batch_first=True, padding_value=fill)
            for field, fill in pad_values
        }
        self.data = dict(datapoint_ids=id_tensor, **padded)
        self.datapoints = [self.id_to_datapoint[row_id] for row_id in id_tensor.tolist()]

    def parse_output_text(self, output_text):
        """Identity: the raw completion is what the checker scores."""
        return output_text

    def reward_fn(self, pred_text, datapoint):
        """Run the checker named by ``targets['func_name']``; False if it raises.

        Exceptions from the checker are printed (message, prediction, targets)
        and treated as a failed check.
        """
        spec = datapoint['targets']
        checker = self.if_functions_map[spec['func_name']]
        try:
            return checker(pred_text, **spec)
        except Exception as err:
            print(err)
            print(pred_text)
            print(spec)
            return False


class MagpieIFEvalDataset(BaseDataset):
    """Placeholder for the argilla/ifeval-like-data instruction-following set.

    Not implemented yet: construction always raises ``NotImplementedError``.
    :meth:`parse_output_text` and :meth:`reward_fn` keep the intended
    interface, the same as ``Tulu3RLVRIFEvalDataset``. In that interface
    ``datapoint['targets']`` is a dict whose ``func_name`` key selects a
    checker in ``self.if_functions_map``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        # Fail fast: do not load 'path_to/yourname/data/argilla/ifeval-like-data'
        # from disk only to raise afterwards.
        # TODO: map the dataset's rows to 'messages' / 'targets' (see
        # Tulu3RLVRIFEvalDataset for the expected layout), choose a train/eval
        # split, then tokenize and pad like the other datasets in this module.
        raise NotImplementedError(
            'MagpieIFEvalDataset (argilla/ifeval-like-data) is not implemented yet.'
        )

    def parse_output_text(self, output_text):
        """Identity: the raw completion is passed to the checker as-is."""
        return output_text

    def reward_fn(self, pred_text, datapoint):
        """Score ``pred_text`` with the checker named by ``datapoint['targets']['func_name']``."""
        func_name = datapoint['targets']['func_name']
        func = self.if_functions_map[func_name]
        return func(pred_text, **datapoint['targets'])


class MetaMathDataset(BaseDataset):
    """meta-math/MetaMathQA read from a local ``load_from_disk`` copy.

    The last line of each ``response`` is treated as the answer line. It is
    reduced with :meth:`parse_output_text` to build ``targets``. The
    remaining lines, cut at any ``####`` marker, form the assistant turn.
    That turn is appended only for SFT (``dataset_split == 'train'`` and not
    ``for_generation``).
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/meta-math/MetaMathQA'
        dataset = load_from_disk(path)['train']
        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))
        dataset = dataset.select(range(num_datapoints))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break

            query = datapoint['query']
            response = datapoint['response']
            # Split off the final line as the answer line; the rest is the rationale.
            *response, parsed_answer = response.split('\n')
            response = '\n'.join(response)
            if '####' in response:
                response = response.split('####')[0]
            messages = [dict(role='user', content=query)]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=response))

            parsed_answer = self.parse_output_text(parsed_answer)
            datapoint['targets'] = [parsed_answer]

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            # tokenize_messages returns None for examples it rejects; skip them.
            if tokenized is None:
                continue

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Extract the final answer from a completion (or from an answer line).

        Tries the last ``\\boxed{...}`` (one level of nested braces), then the
        last ``$...$`` span, then the last standalone number; falls back to
        the text itself. Returns None for None input.
        """
        if output_text is None:
            return None

        matches = re.findall(r'\\boxed{((?:[^{}]|{[^{}]*})*)}', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        # Lookarounds keep neighbouring characters unconsumed so numbers
        # separated by a single character ("3 5", "(1,2)") are all found.
        matches = re.findall(r'(?<!\d)(\d+(?:\.\d+)?|\.\d+)(?!\d)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def reward_fn(self, pred_text, datapoint):
        """True if the normalized prediction equals any normalized target."""
        pred_text = self.normalize_final_answer(pred_text)
        targets = datapoint['targets']
        targets = [self.normalize_final_answer(target) for target in targets]
        return pred_text in targets


class UltraInteractMathRolloutDataset(BaseDataset):
    """UltraInteract math prompts expanded into one datapoint per rollout.

    Each record on disk carries a ``prompt``, a ``reference`` answer and
    parallel ``completions`` / ``steps`` / ``correctness`` lists. Every rollout
    is tokenized separately (with its completion as the assistant turn when
    training), and its correctness flag is exposed as ``corrects``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='train', for_generation=False, **kwargs):
        """
        Args:
            num_datapoints: Number of rollouts (not prompts) to keep, or ``'all'``.
            max_seq_length: Forwarded to ``tokenize_messages``.
            tokenizer: Tokenizer whose ``pad_token_id`` pads ``input_ids``.
            dataset_split: Only ``'train'`` (without ``for_generation``) appends
                the completion as an assistant turn.
            for_generation: If True, only the user prompt is tokenized.
        """
        super().__init__()
        from core.evaluation.math_utils import normalize_final_answer
        self.normalize_final_answer = normalize_final_answer
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/Windy0822/ultrainteract_math_rollout'
        dataset = load_from_disk(path)

        # NOTE(review): records are always read from the on-disk 'train' split;
        # `_dataset_split` only decides whether completions become assistant turns.
        _dataset_split = 'test' if dataset_split == 'eval' else dataset_split
        dataset = list(dataset['train'])

        # The cap counts rollouts (one datapoint per completion), so 'all' and
        # the clamp must use the rollout total, not the number of prompts.
        num_rollouts = sum(len(datapoint['completions']) for datapoint in dataset)
        if num_datapoints == 'all':
            num_datapoints = num_rollouts
        num_datapoints = min(num_datapoints, num_rollouts)

        datapoint_ids = []
        all_tokenized = []
        corrects = []
        datapoint_idx = 0

        # The bar tracks rollouts, so it is advanced manually. Wrapping the
        # prompt iterator as well would tick once more per prompt and double count.
        with tqdm(total=num_datapoints) as pbar:
            for datapoint in dataset:
                if len(all_tokenized) >= num_datapoints:
                    break

                prompt = datapoint['prompt']
                # Every rollout of this prompt is scored against the same reference.
                datapoint['targets'] = [datapoint['reference']]
                num_rollouts_used = 0

                # `steps` is not used, but keeping it in the zip preserves the
                # truncation to the shortest of the three parallel lists.
                for completion, _step, correct in zip(datapoint['completions'], datapoint['steps'], datapoint['correctness']):
                    if len(all_tokenized) >= num_datapoints:
                        break
                    messages = [dict(role='user', content=prompt)]
                    if _dataset_split == 'train' and not for_generation:
                        messages.append(dict(role='assistant', content=completion))

                    tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
                    if tokenized is None:
                        continue

                    all_tokenized.append(tokenized)
                    corrects.append(correct)
                    datapoint_ids.append(datapoint_idx)
                    self.id_to_datapoint[datapoint_idx] = dict(
                        messages=messages,
                        **datapoint,
                    )
                    datapoint_idx += 1
                    num_rollouts_used += 1
                pbar.update(num_rollouts_used)

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        # -100 is excluded from the loss by the usual cross-entropy convention.
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        corrects = torch.tensor(corrects).long()

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            corrects=corrects,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Extract the final answer: last ``\\boxed{}``, else last ``$...$``, else last number.

        Falls back to the raw text; ``None`` is passed through.
        """
        if output_text is None:
            return None

        matches = re.findall(r'\\boxed{((?:[^{}]|{[^{}]*})*)}', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'\$([^$]+)\$', output_text)
        if matches:
            return matches[-1].strip()

        matches = re.findall(r'(?:^|[^\d])(\d+(?:\.\d+)?|\.\d+)(?:[^\d]|$)', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def reward_fn(self, pred_text, datapoint):
        """Return True iff the normalized prediction equals a normalized target."""
        # A missing prediction (None from parse_output_text) counts as incorrect.
        if pred_text is None:
            return False
        pred_text = self.normalize_final_answer(pred_text)
        targets = [self.normalize_final_answer(target) for target in datapoint['targets']]
        return pred_text in targets

    def __getitem__(self, idx):
        return dict(
            datapoint_ids=self.data['datapoint_ids'][idx],
            input_ids=self.data['input_ids'][idx],
            attention_mask=self.data['attention_mask'][idx],
            labels=self.data['labels'][idx],
            corrects=self.data['corrects'][idx],
        )

    def collate_fn(self, batch):
        # Items are already padded to a common length, so plain stacking works.
        return dict(
            datapoint_ids=torch.stack([b['datapoint_ids'] for b in batch]),
            input_ids=torch.stack([b['input_ids'] for b in batch]),
            attention_mask=torch.stack([b['attention_mask'] for b in batch]),
            labels=torch.stack([b['labels'] for b in batch]),
            corrects=torch.stack([b['corrects'] for b in batch]),
        )


class HExPHIDataset(BaseDataset):
    """HEx-PHI harmful prompts, read from per-category CSV files.

    Every ``*.csv`` file in the data directory contributes one query per
    non-blank line. The category is the integer after the last ``_`` in the
    file name (e.g. ``category_3.csv`` -> 3). ``targets`` is always ``[None]``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = 'path_to/yourname/data/HEx-PHI'

        dataset = []
        # Sorted so datapoint order (and the prefix kept below) does not depend
        # on the filesystem's directory order.
        for file_name in sorted(os.listdir(path)):
            stem, ext = os.path.splitext(file_name)
            if ext != '.csv':
                continue
            # splitext removes only the extension. str.strip('.csv') would strip
            # any of the characters '.', 'c', 's', 'v' from both ends.
            category = int(stem.split('_')[-1])
            # NOTE(review): lines are read raw, not CSV-parsed, so any CSV
            # quoting in the files stays in the query — confirm that's intended.
            with open(os.path.join(path, file_name), 'r', encoding='utf-8') as f:
                for line in f:
                    query = line.strip()
                    if not query:
                        # Blank lines would otherwise become empty prompts.
                        continue
                    dataset.append(dict(
                        category=category,
                        query=query,
                    ))

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break

            messages = [dict(role='user', content=datapoint['query'])]

            if dataset_split == 'train' and not for_generation:
                # Placeholder assistant turn ('...').
                messages.append(dict(role='assistant', content='...'))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [None]

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]


class WildJailbreakDataset(BaseDataset):
    """WildJailbreak prompts, taken from the ``adversarial`` field.

    For training, the row's ``completion`` becomes the assistant turn and
    ``targets`` is ``[None]``. For other splits, ``targets`` is the row's
    ``label`` value.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        path = f'path_to/yourname/data/allenai/wildjailbreak/{dataset_split}'
        dataset = load_from_disk(path)['train']

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        num_datapoints = min(num_datapoints, len(dataset))

        datapoint_ids = []
        all_tokenized = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            if len(all_tokenized) >= num_datapoints:
                break

            query = datapoint['adversarial']
            # Rows without an adversarial prompt would become empty user turns.
            # NOTE(review): in the upstream train split the vanilla_* rows are
            # expected to leave `adversarial` empty — confirm against this copy.
            if not query:
                continue
            messages = [dict(role='user', content=query)]

            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=datapoint['completion']))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            # NOTE(review): train targets are a list while eval targets are the
            # raw label value; kept as-is for downstream compatibility.
            if dataset_split == 'train':
                datapoint['targets'] = [None]
            else:
                datapoint['targets'] = datapoint['label']

            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(
                messages=messages,
                **datapoint,
            )

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def parse_output_text(self, output_text):
        """Return the completion unchanged (no answer extraction for this task)."""
        return output_text

    def reward_fn(self, pred_text, datapoint):
        """No programmatic reward is defined for this dataset; always None."""
        return None


class WildGuardTestDataset(BaseDataset):
    """WildGuardTest prompts wrapped as single user turns.

    ``targets`` is ``1`` when the row's ``label`` equals ``'harmful'`` and ``0``
    otherwise. No answer parsing or reward is defined.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        rows = load_from_disk('path_to/yourname/data/walledai/WildGuardTest')['train']
        limit = len(rows) if num_datapoints == 'all' else min(num_datapoints, len(rows))

        with_assistant_turn = dataset_split == 'train' and not for_generation
        kept_ids = []
        kept_tokenized = []
        for row_idx, row in tqdm(enumerate(rows), total=limit):
            if len(kept_tokenized) >= limit:
                break

            messages = [dict(role='user', content=row['prompt'])]
            if with_assistant_turn:
                # Placeholder assistant turn ('...').
                messages.append(dict(role='assistant', content='...'))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            row['targets'] = int(row['label'] == 'harmful')
            kept_tokenized.append(tokenized)
            kept_ids.append(row_idx)
            self.id_to_datapoint[row_idx] = dict(messages=messages, **row)

        # Per-field padding value; -100 marks label positions ignored by the loss.
        pad_values = dict(
            input_ids=self.tokenizer.pad_token_id,
            attention_mask=0,
            labels=-100,
        )
        self.data = dict(datapoint_ids=torch.tensor(kept_ids).long())
        for field, pad_value in pad_values.items():
            self.data[field] = pad_sequence(
                [item[field] for item in kept_tokenized],
                batch_first=True,
                padding_value=pad_value,
            )
        self.datapoints = [self.id_to_datapoint[row_idx] for row_idx in kept_ids]

    def parse_output_text(self, output_text):
        """Identity: the raw completion is what downstream evaluation consumes."""
        return output_text

    def reward_fn(self, pred_text, datapoint):
        """No reward is computed for this dataset."""
        return None


class XSTestDataset(BaseDataset):
    """XSTest prompts (``test`` split) wrapped as single user turns.

    ``targets`` is ``1`` for rows labelled ``'unsafe'`` and ``0`` otherwise.
    No answer parsing or reward is defined.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        self.num_datapoints = num_datapoints
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)

        records = load_from_disk('path_to/yourname/data/walledai/XSTest')['test']
        budget = len(records) if num_datapoints == 'all' else min(num_datapoints, len(records))

        add_placeholder_reply = dataset_split == 'train' and not for_generation
        selected_ids = []
        encoded = []
        for record_idx, record in tqdm(enumerate(records), total=budget):
            if len(encoded) >= budget:
                break

            conversation = [dict(role='user', content=record['prompt'])]
            if add_placeholder_reply:
                conversation.append(dict(role='assistant', content='...'))

            tokenized = tokenize_messages(conversation, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            record['targets'] = int(record['label'] == 'unsafe')
            encoded.append(tokenized)
            selected_ids.append(record_idx)
            self.id_to_datapoint[record_idx] = dict(messages=conversation, **record)

        # Padding value per tensor field; -100 keeps padded labels out of the loss.
        padding_by_field = (
            ('input_ids', self.tokenizer.pad_token_id),
            ('attention_mask', 0),
            ('labels', -100),
        )
        self.data = dict(datapoint_ids=torch.tensor(selected_ids).long())
        for field, fill in padding_by_field:
            sequences = [item[field] for item in encoded]
            self.data[field] = pad_sequence(sequences, batch_first=True, padding_value=fill)
        self.datapoints = [self.id_to_datapoint[record_idx] for record_idx in selected_ids]

    def parse_output_text(self, output_text):
        """Identity: the raw completion is what downstream evaluation consumes."""
        return output_text

    def reward_fn(self, pred_text, datapoint):
        """No reward is computed for this dataset."""
        return None


class DummyDataset(BaseDataset):
    """Toy "select characters" task loaded from a cached JSONL file.

    Each JSONL record provides ``question`` and ``answer``. Training examples
    use ``The answer is $\\boxed{<answer>}$`` as the assistant turn, and
    ``targets`` is ``[answer]``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        import string
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)
        self.characters = list(string.ascii_letters)

        cached_path = f'path_to/yourname/data/dummy/select_chars/{dataset_split}.jsonl'
        with open(cached_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
            dataset = [json.loads(line) for line in f if line.strip()]

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        # Clamp once so the progress-bar total and the stop condition agree.
        num_datapoints = min(num_datapoints, len(dataset))
        self.num_datapoints = num_datapoints

        all_tokenized = []
        datapoint_ids = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            # Count kept examples, as the other datasets here do, so rows that
            # tokenize_messages rejects (None) do not shrink the dataset.
            if len(all_tokenized) >= num_datapoints:
                break

            answer = datapoint['answer']
            messages = [dict(role='user', content=datapoint['question'])]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=f'The answer is $\\boxed{{{answer}}}$'))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [answer]
            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(messages=messages, **datapoint)

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def generate_datapoint(self):
        """Sample a fresh question/answer pair.

        Draws 10 distinct letters with the module-level ``random`` state. The
        answer is the letters at positions 1 and 4 of the sampled string.
        """
        letters = random.sample(self.characters, 10)
        answer = letters[1] + letters[4]
        shuffled = ''.join(letters)
        question = f'Return two of the characters in the given string. Put your answer in $\\boxed{{}}$. E.g., string = "elephant", and your answer being "ae", you should put down your answer as $\\boxed{{ae}}$. \n\nThe string is: {shuffled}'
        return dict(
            question=question,
            answer=answer,
        )

    def parse_output_text(self, output_text):
        """Return the contents of the last ``$\\boxed{...}$``, else the raw text.

        ``None`` is passed through, matching the other parsers in this module.
        """
        if output_text is None:
            return None
        matches = re.findall(r'\$\\boxed\{(.*?)\}\$', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def reward_fn(self, pred_text, datapoint):
        """Exact-match reward: True iff ``pred_text`` is one of the targets."""
        targets = datapoint['targets']
        reward = pred_text in targets
        return reward


class OrderSelectionDataset(BaseDataset):
    """Character-selection task loaded from a cached JSONL file.

    NOTE(review): this class is currently identical to DummyDataset and reads
    the same ``select_chars`` data; confirm whether it should load a distinct
    order-selection dataset.

    Each JSONL record provides ``question`` and ``answer``. Training examples
    use ``The answer is $\\boxed{<answer>}$`` as the assistant turn, and
    ``targets`` is ``[answer]``.
    """

    def __init__(self, num_datapoints, max_seq_length, tokenizer, dataset_split='eval', for_generation=False, **kwargs):
        super().__init__()
        import string
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.dataset_split = dataset_split
        self.data = []
        self.id_to_datapoint = dict()
        self.model_name = kwargs.get('model_name', None)
        self.characters = list(string.ascii_letters)

        cached_path = f'path_to/yourname/data/dummy/select_chars/{dataset_split}.jsonl'
        with open(cached_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
            dataset = [json.loads(line) for line in f if line.strip()]

        if num_datapoints == 'all':
            num_datapoints = len(dataset)
        # Clamp once so the progress-bar total and the stop condition agree.
        num_datapoints = min(num_datapoints, len(dataset))
        self.num_datapoints = num_datapoints

        all_tokenized = []
        datapoint_ids = []
        for datapoint_idx, datapoint in tqdm(enumerate(dataset), total=num_datapoints):
            # Count kept examples, as the other datasets here do, so rows that
            # tokenize_messages rejects (None) do not shrink the dataset.
            if len(all_tokenized) >= num_datapoints:
                break

            answer = datapoint['answer']
            messages = [dict(role='user', content=datapoint['question'])]
            if dataset_split == 'train' and not for_generation:
                messages.append(dict(role='assistant', content=f'The answer is $\\boxed{{{answer}}}$'))

            tokenized = tokenize_messages(messages, tokenizer, max_seq_length, dataset_split, **kwargs)
            if tokenized is None:
                continue

            datapoint['targets'] = [answer]
            all_tokenized.append(tokenized)
            datapoint_ids.append(datapoint_idx)
            self.id_to_datapoint[datapoint_idx] = dict(messages=messages, **datapoint)

        input_ids = [tokenized['input_ids'] for tokenized in all_tokenized]
        attention_mask = [tokenized['attention_mask'] for tokenized in all_tokenized]
        labels = [tokenized['labels'] for tokenized in all_tokenized]

        datapoint_ids = torch.tensor(datapoint_ids).long()
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)

        self.data = dict(
            datapoint_ids=datapoint_ids,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        self.datapoints = [self.id_to_datapoint[datapoint_id] for datapoint_id in datapoint_ids.tolist()]

    def generate_datapoint(self):
        """Sample a fresh question/answer pair.

        Draws 10 distinct letters with the module-level ``random`` state. The
        answer is the letters at positions 1 and 4 of the sampled string.
        """
        letters = random.sample(self.characters, 10)
        answer = letters[1] + letters[4]
        shuffled = ''.join(letters)
        question = f'Return two of the characters in the given string. Put your answer in $\\boxed{{}}$. E.g., string = "elephant", and your answer being "ae", you should put down your answer as $\\boxed{{ae}}$. \n\nThe string is: {shuffled}'
        return dict(
            question=question,
            answer=answer,
        )

    def parse_output_text(self, output_text):
        """Return the contents of the last ``$\\boxed{...}$``, else the raw text.

        ``None`` is passed through, matching the other parsers in this module.
        """
        if output_text is None:
            return None
        matches = re.findall(r'\$\\boxed\{(.*?)\}\$', output_text)
        if matches:
            return matches[-1].strip()

        return output_text

    def reward_fn(self, pred_text, datapoint):
        """Exact-match reward: True iff ``pred_text`` is one of the targets."""
        targets = datapoint['targets']
        reward = pred_text in targets
        return reward
