
"""Helpful and Harmless Dialogue Datasets from Anthropic."""

from __future__ import annotations

from typing import ClassVar

from datasets import load_dataset
from safe_rlhf.datasets.base import RawDataset, RawSample


__all__ = [
    'HhRLHFDialogueDataset',
    'HhRLHFHarmlessDialogueDataset',
    'HhRLHFHelpfulDialogueDataset',
    'HhRLHFPreferenceDataset',
    'HhRLHFHarmlessPreferenceTrainDataset',
    'HhRLHFHarmlessPreferenceTestDataset',
    'HhRLHFHelpfulPreferenceTrainDataset',
    'HhRLHFHelpfulPreferenceTestDataset',
]


class HhRLHFDialogueDataset(RawDataset):
    """Dialogue view of the processed HH-RLHF data.

    Every sample is the list of context utterance texts followed by the text
    of the ``chosen`` reply. The ``train`` split is always loaded, and
    ``DATA_DIR`` is forwarded to ``load_dataset`` as ``data_dir`` (``None``
    on this class).
    """

    NAME: ClassVar[str] = 'hh-rlhf-dialogue'
    ALIASES: tuple[str, ...] = ('hh-dialogue',)
    DATA_DIR: ClassVar[str | None] = None

    def __init__(self, path: str | None = None) -> None:
        """Load the ``train`` split of the dataset.

        Args:
            path: Dataset location handed to ``load_dataset``. When falsy,
                ``'PKU-Alignment/processed-hh-rlhf'`` is used instead.
        """
        source = path or 'PKU-Alignment/processed-hh-rlhf'
        self.data = load_dataset(source, data_dir=self.DATA_DIR, split='train')

    def __getitem__(self, index: int) -> RawSample:
        """Return the record at ``index`` as a ``RawSample`` dialogue."""
        record = self.data[index]
        utterances = [turn['text'] for turn in record['context']]
        # The chosen reply closes the conversation.
        return RawSample(dialogue=[*utterances, record['chosen']['text']])

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)


class HhRLHFHarmlessDialogueDataset(HhRLHFDialogueDataset):
    """Dialogue dataset loaded from the ``harmless-base`` data directory."""

    # NAME and DATA_DIR are ClassVar on the base class; the overrides keep the
    # same kind of declaration so type checkers accept them.
    NAME: ClassVar[str] = 'hh-rlhf-harmless-dialogue'
    ALIASES: tuple[str, ...] = (
        'hh-rlhf-dialogue/harmless-base',
        'hh-harmless-dialogue',
        'hh-dialogue/harmless-base',
    )
    DATA_DIR: ClassVar[str] = 'harmless-base'


class HhRLHFHelpfulDialogueDataset(HhRLHFDialogueDataset):
    """Dialogue dataset loaded from the ``helpful-base`` data directory."""

    # NAME and DATA_DIR are ClassVar on the base class; the overrides keep the
    # same kind of declaration so type checkers accept them.
    NAME: ClassVar[str] = 'hh-rlhf-helpful-dialogue'
    ALIASES: tuple[str, ...] = (
        'hh-rlhf-dialogue/helpful-base',
        'hh-helpful-dialogue',
        'hh-dialogue/helpful-base',
    )
    DATA_DIR: ClassVar[str] = 'helpful-base'


class HhRLHFPreferenceDataset(RawDataset):
    """Preference-pair view of the processed HH-RLHF data.

    Every sample carries the context utterance texts as ``input``, the
    ``chosen`` reply as ``answer`` and the ``rejected`` reply as
    ``other_answer``, with ``better`` always ``True``.

    ``SPLIT`` is only declared on this class, not assigned, and ``__init__``
    reads it. Concrete subclasses are therefore expected to set it, together
    with ``DATA_DIR``.
    """

    NAME: ClassVar[str] = 'hh-rlhf-preference'
    ALIASES: tuple[str, ...] = ('hh-preference',)
    DATA_DIR: ClassVar[str | None] = None
    SPLIT: ClassVar[str]

    def __init__(self, path: str | None = None) -> None:
        """Load the split named by ``SPLIT``.

        Args:
            path: Dataset location handed to ``load_dataset``. When falsy,
                ``'PKU-Alignment/processed-hh-rlhf'`` is used instead.
        """
        source = path or 'PKU-Alignment/processed-hh-rlhf'
        self.data = load_dataset(source, data_dir=self.DATA_DIR, split=self.SPLIT)

    def __getitem__(self, index: int) -> RawSample:
        """Return the record at ``index`` as a preference ``RawSample``."""
        record = self.data[index]
        return RawSample(
            input=[turn['text'] for turn in record['context']],
            answer=record['chosen']['text'],
            other_answer=record['rejected']['text'],
            # ``answer`` holds the chosen reply, so it is always the better one.
            better=True,
        )

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)


class HhRLHFHarmlessPreferenceTrainDataset(HhRLHFPreferenceDataset):
    """Preference pairs from the ``harmless-base`` directory, ``train`` split."""

    # NAME, DATA_DIR and SPLIT are ClassVar on the base class; the overrides
    # keep the same kind of declaration so type checkers accept them.
    NAME: ClassVar[str] = 'hh-rlhf-harmless-preference/train'
    ALIASES: tuple[str, ...] = (
        'hh-rlhf-preference/harmless-base/train',
        'hh-harmless-preference/train',
        'hh-preference/harmless-base/train',
    )
    DATA_DIR: ClassVar[str] = 'harmless-base'
    SPLIT: ClassVar[str] = 'train'


class HhRLHFHarmlessPreferenceTestDataset(HhRLHFPreferenceDataset):
    """Preference pairs from the ``harmless-base`` directory, ``test`` split."""

    # NAME, DATA_DIR and SPLIT are ClassVar on the base class; the overrides
    # keep the same kind of declaration so type checkers accept them.
    NAME: ClassVar[str] = 'hh-rlhf-harmless-preference/test'
    ALIASES: tuple[str, ...] = (
        'hh-rlhf-preference/harmless-base/test',
        'hh-harmless-preference/test',
        'hh-preference/harmless-base/test',
    )
    DATA_DIR: ClassVar[str] = 'harmless-base'
    SPLIT: ClassVar[str] = 'test'


class HhRLHFHelpfulPreferenceTrainDataset(HhRLHFPreferenceDataset):
    """Preference pairs from the ``helpful-base`` directory, ``train`` split."""

    # NAME, DATA_DIR and SPLIT are ClassVar on the base class; the overrides
    # keep the same kind of declaration so type checkers accept them.
    NAME: ClassVar[str] = 'hh-rlhf-helpful-preference/train'
    ALIASES: tuple[str, ...] = (
        'hh-rlhf-preference/helpful-base/train',
        'hh-helpful-preference/train',
        'hh-preference/helpful-base/train',
    )
    DATA_DIR: ClassVar[str] = 'helpful-base'
    SPLIT: ClassVar[str] = 'train'


class HhRLHFHelpfulPreferenceTestDataset(HhRLHFPreferenceDataset):
    """Preference pairs from the ``helpful-base`` directory, ``test`` split."""

    # NAME, DATA_DIR and SPLIT are ClassVar on the base class; the overrides
    # keep the same kind of declaration so type checkers accept them.
    NAME: ClassVar[str] = 'hh-rlhf-helpful-preference/test'
    ALIASES: tuple[str, ...] = (
        'hh-rlhf-preference/helpful-base/test',
        'hh-helpful-preference/test',
        'hh-preference/helpful-base/test',
    )
    DATA_DIR: ClassVar[str] = 'helpful-base'
    SPLIT: ClassVar[str] = 'test'
