# safe_rlhf/datasets/helpful_eval_prompt.py

from __future__ import annotations

from datasets import load_dataset
from safe_rlhf.datasets.base import RawDataset, RawSample


__all__ = ['HelpfulPromptDataset']


class HelpfulPromptDataset(RawDataset):
    """Raw dataset of helpfulness prompts read from a local JSON file.

    Each record must provide a ``'prompt'`` field; that value is exposed as
    the ``input`` of the resulting :class:`RawSample`. Any other fields in
    a record (e.g. ``'category'``) are not read by this class.
    """

    NAME: str = 'helpful-prompt'
    ALIASES: tuple[str, ...] = ('helpful_prompt',)

    def __init__(self, path: str | None = "data/helpful_problem.json") -> None:
        """Load the ``train`` split of the JSON file located at ``path``.

        Args:
            path: Location of the JSON data file. Defaults to
                ``"data/helpful_problem.json"``; the value is handed to
                ``load_dataset`` unchanged as ``data_files``.

        Raises:
            ValueError: If ``path`` is falsy (``None`` or an empty string).
        """
        path_missing = not path
        if path_missing:
            message = 'A path to the JSON data file is required for HelpfulPromptDataset.'
            raise ValueError(message)

        # 'json' selects the generic JSON builder; the actual file is given
        # through ``data_files`` rather than as a dataset name.
        loaded = load_dataset('json', data_files=path, split='train')
        self.data = loaded

    def __getitem__(self, index: int) -> RawSample:
        """Return the sample stored at ``index`` as a :class:`RawSample`.

        A raw record looks like::

            {"prompt": "...", "category": "..."}

        and is converted to::

            {"input": "..."}

        Only the ``'prompt'`` value is carried over; it becomes ``input``.
        """
        record = self.data[index]
        prompt_text = record['prompt']
        # Downstream prompt-only consumers look the text up under 'input',
        # so the 'prompt' key is renamed here.
        sample = RawSample(input=prompt_text)
        return sample

    def __len__(self) -> int:
        """Return how many records the loaded split contains."""
        record_count = len(self.data)
        return record_count