

from collections.abc import Sized

import pytest
import torch
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import Dataset, RandomSampler

from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.trainer.main_ppo import create_rl_sampler

class RandomCurriculumSampler(AbstractCurriculumSampler):
    """Minimal curriculum sampler that delegates to a seeded ``RandomSampler``.

    Iteration and length are forwarded to the wrapped sampler, and
    ``update`` is a no-op.
    """

    def __init__(
        self,
        data_source: Sized,
        data_config: DictConfig,
    ):
        """Build the wrapped random sampler.

        Args:
            data_source: Sized collection whose indices are sampled.
            data_config: Sampler configuration; not read by this implementation.
        """
        # The seeded generator must be handed to the sampler; otherwise the
        # seed has no effect and the shuffle order is not reproducible.
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(1)
        self.sampler = RandomSampler(
            data_source=data_source,
            generator=train_dataloader_generator,
        )

    def __iter__(self):
        """Return an iterator over the indices produced by the wrapped sampler."""
        return iter(self.sampler)

    def __len__(self) -> int:
        """Return the number of indices the wrapped sampler yields."""
        return len(self.sampler)

    def update(self, batch) -> None:
        """Ignore ``batch``; this sampler does not adapt its ordering."""
        return

class MockIncorrectSampler:
    """Plain class with a sampler-like constructor but no sampler base class.

    It is referenced by name from the wrong-class test in this file.
    """

    def __init__(self, data_source, data_config):
        """Accept both constructor arguments and store nothing."""

class MockChatDataset(Dataset):
    """Tiny in-memory dataset of five prompt/response records."""

    def __init__(self):
        """Populate ``self.data`` with the fixed list of chat records."""
        exchanges = (
            ("What's your name?", "My name is Assistant."),
            ("How are you?", "I'm doing well, thank you."),
            ("What is the capital of France?", "Paris."),
            (
                "Tell me a joke.",
                "Why did the chicken cross the road? To get to the other side!",
            ),
            ("What is 2+2?", "4"),
        )
        # Each record keeps the same two keys, in the same order.
        self.data = [{"prompt": question, "response": answer} for question, answer in exchanges]

    def __getitem__(self, index):
        """Return the record stored at ``index``."""
        record = self.data[index]
        return record

    def __len__(self):
        """Return the number of stored records."""
        count = len(self.data)
        return count

def test_create_custom_curriculum_samper():
    """Check that a sampler class named in the config is built by ``create_rl_sampler``."""
    data_config = OmegaConf.create(
        {
            "dataloader_num_workers": 0,
            "sampler": {
                "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu",
                "class_name": "RandomCurriculumSampler",
            },
        }
    )

    dataset = MockChatDataset()

    sampler = create_rl_sampler(data_config, dataset)

    # Check against the shared base class rather than RandomCurriculumSampler:
    # the class resolved through ``class_path`` may come from a separately
    # imported copy of this module, which would make a direct class
    # comparison fail for reasons unrelated to the factory.
    assert isinstance(sampler, AbstractCurriculumSampler)
    assert len(sampler) == len(dataset)

def test_create_custom_curriculum_samper_wrong_class():
    """Check that ``create_rl_sampler`` rejects a class that is not a sampler subclass."""
    data_config = OmegaConf.create(
        {
            # Same worker setting as the passing test, so the sampler class is
            # the only difference and the expected AssertionError cannot be
            # triggered by an unrelated configuration check.
            "dataloader_num_workers": 0,
            "sampler": {
                "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu",
                "class_name": "MockIncorrectSampler",
            },
        }
    )

    dataset = MockChatDataset()

    with pytest.raises(AssertionError):
        create_rl_sampler(data_config, dataset)
