

import torch
from PIL.Image import Image

from verl.utils.dataset import RLHFDataset
from verl.utils.tokenizer import get_processor, get_tokenizer


def test_image_dataset():
    """Check the first sample RLHFDataset produces for an image prompt.

    The dataset is configured with ``max_prompt_length=16``, right
    truncation and no filtering of overlong prompts. The test verifies the
    set of keys in the returned sample, the exact token ids, the attention
    mask, the three-row position ids, the ground-truth answer and the type
    of the first attached image.

    NOTE(review): building the tokenizer, processor and dataset presumably
    needs access to the referenced model and dataset resources -- confirm
    before running in an offline environment.
    """
    model_path = "Qwen/Qwen2.5-VL-7B-Instruct"
    max_prompt_length = 16

    tokenizer = get_tokenizer(model_path, use_fast=True)
    processor = get_processor(model_path, use_fast=True)
    dataset = RLHFDataset(
        data_path="hiyouga/geometry3k@test",
        tokenizer=tokenizer,
        processor=processor,
        prompt_key="problem",
        answer_key="answer",
        image_key="images",
        max_prompt_length=max_prompt_length,
        truncation="right",
        filter_overlong_prompts=False,
    )
    token_ids = [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655]

    # Fetch the sample once: every ``dataset[0]`` rebuilds the example, and all
    # assertions below should inspect the very same object.
    sample = dataset[0]

    assert set(sample.keys()) == {
        "input_ids",
        "attention_mask",
        "position_ids",
        "raw_prompt_ids",
        "ground_truth",
        "multi_modal_data",
    }

    # ``==`` broadcasts, so each element-wise comparison is paired with an
    # explicit size check to avoid a false positive on a wrongly shaped tensor.
    assert list(sample["input_ids"].size()) == [max_prompt_length]
    assert torch.all(sample["input_ids"] == torch.tensor(token_ids))

    assert list(sample["attention_mask"].size()) == [max_prompt_length]
    assert torch.all(sample["attention_mask"] == torch.ones(max_prompt_length))

    expected_position_ids = torch.arange(max_prompt_length).unsqueeze(0).expand(3, -1)
    assert list(sample["position_ids"].size()) == [3, max_prompt_length]
    assert torch.all(sample["position_ids"] == expected_position_ids)

    assert sample["raw_prompt_ids"] == token_ids
    assert sample["ground_truth"] == "48"
    assert isinstance(sample["multi_modal_data"]["images"][0], Image)


# Allow running this test directly as a script, without a test runner.
if __name__ == "__main__":
    test_image_dataset()
