# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset classes."""

from __future__ import annotations

from typing import Dict

import torch
from torch.utils.data import Dataset

from safe_rlhf.datasets import raw
from safe_rlhf.datasets.base import (
    CollatorBase,
    RawDataset,
    RawSample,
    TokenizedDataset,
    parse_dataset,
)
from safe_rlhf.datasets.preference import (
    PreferenceBatch,
    PreferenceCollator,
    PreferenceDataset,
    PreferenceSample,
)
from safe_rlhf.datasets.dpo_safety_preference import (
    DPOSafetyPreferenceBatch,
    DPOSafetyPreferenceCollator,
    DPOSafetyPreferenceDataset,
    DPOSafetyPreferenceSample,
)
from safe_rlhf.datasets.prompt_only import (
    PromptOnlyBatch,
    PromptOnlyCollator,
    PromptOnlyDataset,
    PromptOnlySample,
    PromptOnlyRawDataset,
)
from safe_rlhf.datasets.raw import *  # noqa: F403
from safe_rlhf.datasets.safety_preference import (
    SafetyPreferenceBatch,
    SafetyPreferenceCollator,
    SafetyPreferenceDataset,
    SafetyPreferenceSample,
)
from safe_rlhf.datasets.supervised import (
    SupervisedBatch,
    SupervisedCollator,
    SupervisedDataset,
    SupervisedSample,
)

from safe_rlhf.datasets.safehelpful_preference import (
    SafeHelpfulPreferenceBatch,
    SafeHelpfulPreferenceCollator,
    SafeHelpfulPreferenceDataset,
    SafeHelpfulPreferenceSample,
)


# Public API of this package: the names exported by ``from <package> import *``.
__all__ = [
    # Defined in this module.
    'DummyDataset',
    # Imported from ``safe_rlhf.datasets.base``.
    'parse_dataset',
    'RawDataset',
    'RawSample',
    'TokenizedDataset',
    'CollatorBase',
    # Imported from ``safe_rlhf.datasets.preference``.
    'PreferenceDataset',
    'PreferenceSample',
    'PreferenceBatch',
    'PreferenceCollator',
    # Imported from ``safe_rlhf.datasets.dpo_safety_preference``.
    'DPOSafetyPreferenceDataset',
    'DPOSafetyPreferenceSample',
    'DPOSafetyPreferenceBatch',
    'DPOSafetyPreferenceCollator',
    # Imported from ``safe_rlhf.datasets.safehelpful_preference``.
    'SafeHelpfulPreferenceBatch',
    'SafeHelpfulPreferenceCollator',
    'SafeHelpfulPreferenceDataset',
    'SafeHelpfulPreferenceSample',
    # Imported from ``safe_rlhf.datasets.prompt_only``.
    'PromptOnlyDataset',
    'PromptOnlyCollator',
    'PromptOnlySample',
    'PromptOnlyBatch',
    # Imported from ``safe_rlhf.datasets.safety_preference``.
    'SafetyPreferenceDataset',
    'SafetyPreferenceCollator',
    'SafetyPreferenceSample',
    'SafetyPreferenceBatch',
    # Imported from ``safe_rlhf.datasets.supervised``.
    'SupervisedDataset',
    'SupervisedCollator',
    'SupervisedSample',
    'SupervisedBatch',
    # Also imported from ``safe_rlhf.datasets.prompt_only``.
    'PromptOnlyRawDataset',
    # Every name that ``safe_rlhf.datasets.raw`` itself lists in its ``__all__``
    # (those names are brought into scope by the star import above).
    *raw.__all__,
]


class DummyDataset(Dataset[Dict[str, torch.Tensor]]):
    """Placeholder dataset that reports a fixed length and yields empty samples.

    Every valid index maps to an empty dictionary; only the reported length
    is configurable.
    """

    def __init__(self, length: int) -> None:
        """Store the number of samples this dataset reports.

        Args:
            length: Value returned by ``len(dataset)``.
        """
        self.length = length

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return an empty sample for any in-range index.

        Args:
            index: Position of the sample; negative values count from the end.

        Returns:
            A new empty dictionary.

        Raises:
            IndexError: If ``index`` lies outside ``[-length, length)``.
        """
        # Without a bounds check, plain iteration (``for x in dataset`` or
        # ``list(dataset)``) never stops: the class defines no ``__iter__``,
        # so Python falls back to calling ``__getitem__`` with 0, 1, 2, ...
        # until an ``IndexError`` is raised.
        if not -self.length <= index < self.length:
            raise IndexError(
                f'index {index} is out of range for dataset of length {self.length}',
            )
        return {}
