# Copyright 2025 Amazon.com Inc and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     XXXX
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
test create_rl_sampler
"""

from collections.abc import Sized

import pytest
import torch
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import Dataset, RandomSampler

from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.trainer.main_ppo import create_rl_sampler


class RandomCurriculumSampler(AbstractCurriculumSampler):
    """Minimal curriculum sampler used as a test fixture.

    Wraps a ``torch.utils.data.RandomSampler`` over ``data_source`` and
    delegates iteration and length to it. ``update`` is a no-op.
    """

    def __init__(
        self,
        data_source: Sized,
        data_config: DictConfig,
    ):
        """Build the wrapped random sampler.

        Args:
            data_source: Sized collection whose indices are sampled.
            data_config: Not read by this sampler.
        """
        # The seeded generator must actually be handed to RandomSampler;
        # otherwise the seed has no effect and the sampler falls back to
        # its own randomly seeded generator on every iteration.
        train_dataloader_generator = torch.Generator()
        train_dataloader_generator.manual_seed(1)
        self.sampler = RandomSampler(
            data_source=data_source,
            generator=train_dataloader_generator,
        )

    def __iter__(self):
        """Return an iterator over indices produced by the wrapped sampler."""
        return iter(self.sampler)

    def __len__(self) -> int:
        """Return the number of indices the wrapped sampler yields."""
        return len(self.sampler)

    def update(self, batch) -> None:
        """Ignore ``batch``; this sampler does not adapt its ordering."""
        return


class MockIncorrectSampler:
    """Stand-in sampler that deliberately does not implement AbstractCurriculumSampler.

    It accepts the same constructor arguments as a real sampler but inherits
    from nothing and defines no other methods.
    """

    def __init__(self, data_source, data_config):
        """Accept and discard both arguments."""
        return None


class MockChatDataset(Dataset):
    """In-memory dataset of five prompt/response records for sampler tests."""

    def __init__(self):
        """Populate ``self.data`` with the fixed list of chat records."""
        exchanges = (
            ("What's your name?", "My name is Assistant."),
            ("How are you?", "I'm doing well, thank you."),
            ("What is the capital of France?", "Paris."),
            (
                "Tell me a joke.",
                "Why did the chicken cross the road? To get to the other side!",
            ),
            ("What is 2+2?", "4"),
        )
        # Each record keeps the key order "prompt" then "response".
        self.data = [{"prompt": question, "response": answer} for question, answer in exchanges]

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the record stored at ``index``."""
        return self.data[index]


def test_create_custom_curriculum_samper():
    """create_rl_sampler accepts a config pointing at RandomCurriculumSampler."""
    sampler_spec = {
        "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu",
        "class_name": "RandomCurriculumSampler",
    }
    data_config = OmegaConf.create({"dataloader_num_workers": 0, "sampler": sampler_spec})

    # doesn't raise
    create_rl_sampler(data_config, MockChatDataset())


def test_create_custom_curriculum_samper_wrong_class():
    """create_rl_sampler raises AssertionError for a config naming MockIncorrectSampler."""
    sampler_spec = {
        "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu",
        "class_name": "MockIncorrectSampler",
    }
    data_config = OmegaConf.create({"sampler": sampler_spec})
    dataset = MockChatDataset()

    # MockIncorrectSampler is not an instance of AbstractCurriculumSampler, so raises
    with pytest.raises(AssertionError):
        create_rl_sampler(data_config, dataset)
