# Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset classes."""

from __future__ import annotations

from typing import Dict

import torch
from torch.utils.data import Dataset

from safe_rlhf.datasets import raw
from safe_rlhf.datasets.base import (
    CollatorBase,
    RawDataset,
    RawSample,
    TokenizedDataset,
    parse_dataset,
)
from safe_rlhf.datasets.preference import (
    PreferenceBatch,
    PreferenceCollator,
    PreferenceDataset,
    PreferenceSample,
)
from safe_rlhf.datasets.safehelpful_preference import (
    SafeHelpfulPreferenceBatch,
    SafeHelpfulPreferenceCollator,
    SafeHelpfulPreferenceDataset,
    SafeHelpfulPreferenceSample,
)
from safe_rlhf.datasets.only_safe_better import (
    OnlySafeBetterDataset,
    OnlySafeBetterCollator,
    OnlySafeBetterSample,
    OnlySafeBetterBatch,
)
from safe_rlhf.datasets.cost_preference import (
    CostPreferenceBatch,
    CostPreferenceCollator,
    CostPreferenceSample,
    CostPreferenceDataset,
)
from safe_rlhf.datasets.prompt_only import (
    PromptOnlyBatch,
    PromptOnlyCollator,
    PromptOnlyDataset,
    PromptOnlySample,
)
from safe_rlhf.datasets.safety_preference import (
    SafetyPreferenceBatch,
    SafetyPreferenceCollator,
    SafetyPreferenceDataset,
    SafetyPreferenceSample,
)
from safe_rlhf.datasets.supervised import (
    SupervisedBatch,
    SupervisedCollator,
    SupervisedDataset,
    SupervisedSample,
)
from safe_rlhf.datasets.supervised_pairwise import SupervisedPairDataset

from safe_rlhf.datasets.raw import *  # noqa: F403

# Public API of this package: the names below are what `from <package> import *`
# exposes. Entries are grouped by the module they are imported from above.
__all__ = [
    # Defined at the bottom of this module.
    'DummyDataset',
    # From safe_rlhf.datasets.base
    'parse_dataset',
    'RawDataset',
    'RawSample',
    'TokenizedDataset',
    'CollatorBase',
    # From safe_rlhf.datasets.preference
    'PreferenceDataset',
    'PreferenceSample',
    'PreferenceBatch',
    'PreferenceCollator',
    # From safe_rlhf.datasets.safehelpful_preference
    'SafeHelpfulPreferenceBatch',
    'SafeHelpfulPreferenceCollator',
    'SafeHelpfulPreferenceDataset',
    'SafeHelpfulPreferenceSample',
    # From safe_rlhf.datasets.only_safe_better
    'OnlySafeBetterDataset',
    'OnlySafeBetterCollator',
    'OnlySafeBetterSample',
    'OnlySafeBetterBatch',
    # From safe_rlhf.datasets.cost_preference
    'CostPreferenceBatch',
    'CostPreferenceCollator',
    'CostPreferenceSample',
    'CostPreferenceDataset',
    # From safe_rlhf.datasets.prompt_only
    'PromptOnlyDataset',
    'PromptOnlyCollator',
    'PromptOnlySample',
    'PromptOnlyBatch',
    # From safe_rlhf.datasets.safety_preference
    'SafetyPreferenceDataset',
    'SafetyPreferenceCollator',
    'SafetyPreferenceSample',
    'SafetyPreferenceBatch',
    # From safe_rlhf.datasets.supervised
    'SupervisedDataset',
    'SupervisedCollator',
    'SupervisedSample',
    'SupervisedBatch',
    # From safe_rlhf.datasets.supervised_pairwise
    'SupervisedPairDataset',
    # Re-export everything safe_rlhf.datasets.raw declares public (matches the
    # star-import of that module above).
    *raw.__all__,
]


class DummyDataset(Dataset[Dict[str, torch.Tensor]]):
    """Placeholder dataset of a fixed length whose every sample is an empty dict.

    The dataset carries no data: ``len()`` reports the length given at
    construction time and every in-range index maps to ``{}``.
    """

    def __init__(self, length: int) -> None:
        """Initialize the dataset.

        Args:
            length: The number of (empty) samples the dataset reports via ``len()``.
        """
        self.length = length

    def __len__(self) -> int:
        """Return the number of samples given at construction time."""
        return self.length

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return the (empty) sample at ``index``.

        Args:
            index: Position of the sample. Negative values count from the end,
                following the usual Python sequence convention.

        Returns:
            An empty dictionary.

        Raises:
            IndexError: If ``index`` lies outside ``[-length, length)``.
        """
        # This class defines no ``__iter__``, so plain iteration falls back to
        # calling ``__getitem__`` with 0, 1, 2, ... until ``IndexError`` is raised.
        # Without this bound check, ``for sample in dataset`` would never stop.
        if not -self.length <= index < self.length:
            raise IndexError(
                f'index {index} is out of range for DummyDataset of length {self.length}',
            )
        return {}
