# Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Safe-RLHF preference datasets."""

from __future__ import annotations

from typing import ClassVar
import pdb
import datasets
from safe_rlhf.dataset.base import RawDataset, RawSample


# Names exported by a star-import of this module. Every dataset class defined
# in this file is listed, including the 30K Help/Harm variants.
__all__ = [
    'SafeRLHFDataset',
    'SafeRLHFTrainDataset',
    'SafeRLHFTestDataset',
    'SafeRLHF30KTrainDataset',
    'SafeRLHF30KTestDataset',
    'SafeRLHF30KSingleTrainDataset',
    'SafeRLHF30KSingleTestDataset',
    'SafeRLHF30KHelpTrainDataset',
    'SafeRLHF30KHelpTestDataset',
    'SafeRLHF30KHarmTrainDataset',
    'SafeRLHF30KHarmTestDataset',
    'SafeRLHF10KTrainDataset',
    'SafeRLHFHelpTrainDataset',
    'SafeRLHFHelpTestDataset',
    'SafeRLHFHarmTrainDataset',
    'SafeRLHFHarmTestDataset',
    'SafeRLHFSingleTrainDataset',
    'SafeRLHFSingleTestDataset',
]


class SafeRLHFDataset(RawDataset):
    """Base class for PKU-SafeRLHF preference datasets read from a local on-disk copy.

    Subclasses set ``SPLIT`` to select which split is loaded. When ``SPLIT``
    is ``'train'`` the rows are additionally restricted to the
    (prompt, response pair) combinations that also occur in the on-disk
    ``PKU-SafeRLHF-single-dimension`` dataset.
    """

    SPLIT: ClassVar[str]
    PATH: ClassVar[str]

    def __init__(self, path: str | None = None) -> None:
        """Load the split named by ``SPLIT`` from disk.

        Args:
            path: Not consulted by this loader; the on-disk locations used
                below are fixed.
        """
        self.data = datasets.load_from_disk('safe_rlhf/data/PKU-Alignment/PKU-SafeRLHF')[self.SPLIT]
        if self.SPLIT == 'train':
            self.data_2 = datasets.load_from_disk('safe_rlhf/data/PKU-Alignment/PKU-SafeRLHF-single-dimension')[self.SPLIT]
            # Keep only rows whose (prompt, responses) identity also exists in
            # the single-dimension data. After this, ``self.data`` is a plain
            # list of rows rather than the loaded dataset object.
            allowed = {self._pair_key(row) for row in self.data_2}
            self.data = [row for row in self.data if self._pair_key(row) in allowed]

    @staticmethod
    def _pair_key(row) -> tuple:
        """Return a key for ``row`` that does not depend on the order of its two responses."""
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        return (prompt, min(first, second), max(first, second))

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample``.

        ``response_0`` is stored as ``answer`` and ``response_1`` as
        ``other_answer``. ``better`` and ``safer`` are ``True`` when the
        corresponding ``*_response_id`` field equals 0.
        """
        data = self.data[index]
        return RawSample(
            input=data['prompt'],
            answer=data['response_0'],
            other_answer=data['response_1'],
            better=int(data['better_response_id']) == 0,
            safer=int(data['safer_response_id']) == 0,
            is_safe=bool(data['is_response_0_safe']),
            is_other_safe=bool(data['is_response_1_safe']),
            # The tag is the literal class label immediately followed by the
            # split name, with no separator.
            tag='SafeRLHFDataset' + self.SPLIT,
        )

    def __len__(self) -> int:
        """Return the number of rows currently held in ``self.data``."""
        return len(self.data)


class SafeRLHFTrainDataset(SafeRLHFDataset):
    """PKU-SafeRLHF dataset, ``train`` split (named ``PKU-SafeRLHF/train``)."""

    SPLIT: str = 'train'
    PATH: str = 'PKU-Alignment/PKU-SafeRLHF'
    NAME: str = 'PKU-SafeRLHF/train'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF/train',)


class SafeRLHFTestDataset(SafeRLHFDataset):
    """PKU-SafeRLHF dataset, ``test`` split (named ``PKU-SafeRLHF/test``)."""

    SPLIT: str = 'test'
    PATH: str = 'PKU-Alignment/PKU-SafeRLHF'
    NAME: str = 'PKU-SafeRLHF/test'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF/test',)


class SafeRLHF10KTrainDataset(SafeRLHFDataset):
    """Dataset named ``PKU-SafeRLHF-10K/train``, using the ``train`` split.

    NOTE(review): ``PATH`` names the 10K dataset, but this class defines no
    ``__init__`` of its own and the loader inherited from ``SafeRLHFDataset``
    does not read ``PATH`` -- confirm which data is actually meant to be
    loaded here.
    """

    SPLIT: str = 'train'
    PATH: str = 'PKU-Alignment/PKU-SafeRLHF-10K'
    NAME: str = 'PKU-SafeRLHF-10K/train'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-10K/train',)


class SafeRLHFHelpTrainDataset(SafeRLHFDataset):
    """``train`` split exposed with ``better`` taken from ``better_response_id``.

    Loading is inherited from ``SafeRLHFDataset``; only the sample
    construction differs.
    """

    SPLIT: str = 'train'
    NAME: str = 'PKU-SafeRLHF-Help/train'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-Help/train',)

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # ``better`` is True when ``better_response_id`` is 0.
        first_preferred = int(row['better_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_preferred,
            tag=self.NAME,
        )



class SafeRLHFHelpTestDataset(SafeRLHFDataset):
    """``test`` split exposed with ``better`` taken from ``better_response_id``.

    Loading is inherited from ``SafeRLHFDataset``; only the sample
    construction differs.
    """

    SPLIT: str = 'test'
    NAME: str = 'PKU-SafeRLHF-Help/test'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-Help/test',)

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # ``better`` is True when ``better_response_id`` is 0.
        first_preferred = int(row['better_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_preferred,
            tag=self.NAME,
        )



class SafeRLHFHarmTrainDataset(SafeRLHFDataset):
    """``train`` split exposed with ``better`` taken from ``safer_response_id``.

    Loading is inherited from ``SafeRLHFDataset``; only the sample
    construction differs.
    """

    SPLIT: str = 'train'
    NAME: str = 'PKU-SafeRLHF-Harm/train'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-Harm/train',)

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # The safety label is what fills the ``better`` field here:
        # True when ``safer_response_id`` is 0.
        first_safer = int(row['safer_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_safer,
            tag=self.NAME,
        )



class SafeRLHFHarmTestDataset(SafeRLHFDataset):
    """``test`` split exposed with ``better`` taken from ``safer_response_id``.

    Loading is inherited from ``SafeRLHFDataset``; only the sample
    construction differs.
    """

    SPLIT: str = 'test'
    NAME: str = 'PKU-SafeRLHF-Harm/test'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-Harm/test',)

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # The safety label is what fills the ``better`` field here:
        # True when ``safer_response_id`` is 0.
        first_safer = int(row['safer_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_safer,
            tag=self.NAME,
        )


class SafeRLHFSingleTrainDataset(RawDataset):
    """``train`` split of the on-disk ``PKU-SafeRLHF-single-dimension`` dataset."""

    NAME: str = 'PKU-SafeRLHF-Single/train'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-Single/train',)
    SPLIT: str = 'train'

    def __init__(self, path: str | None = None) -> None:
        """Load both on-disk datasets and expose the single-dimension rows.

        Args:
            path: Not consulted by this loader; the on-disk locations used
                below are fixed.
        """
        self.data_1 = datasets.load_from_disk('safe_rlhf/data/PKU-Alignment/PKU-SafeRLHF')[self.SPLIT]
        self.data_2 = datasets.load_from_disk('safe_rlhf/data/PKU-Alignment/PKU-SafeRLHF-single-dimension')[self.SPLIT]
        # NOTE(review): this used to build a set of (prompt, response pair)
        # keys from ``data_2`` and then filter ``data_2`` by membership in
        # that same set, which keeps every row. The rows are therefore
        # materialised directly. ``data_1`` is loaded but never consulted --
        # confirm whether the intent was to intersect ``data_2`` with
        # ``data_1`` instead.
        self.data = list(self.data_2)

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``.

        ``better`` is ``True`` when ``better_response_id`` equals 0.
        """
        data = self.data[index]
        return RawSample(
            input=data['prompt'],
            answer=data['response_0'],
            other_answer=data['response_1'],
            better=int(data['better_response_id']) == 0,
            tag=self.NAME,
        )

    def __len__(self) -> int:
        """Return the number of rows held in ``self.data``."""
        return len(self.data)



class SafeRLHFSingleTestDataset(RawDataset):
    """``test`` split of the on-disk ``PKU-SafeRLHF-single-dimension`` dataset."""

    SPLIT: str = 'test'
    NAME: str = 'PKU-SafeRLHF-Single/test'
    ALIASES: tuple[str, ...] = ('PKU-Alignment/PKU-SafeRLHF-Single/test',)

    def __init__(self, path: str | None = None) -> None:
        """Load the split from disk; ``path`` is not consulted."""
        splits = datasets.load_from_disk('safe_rlhf/data/PKU-Alignment/PKU-SafeRLHF-single-dimension')
        self.data = splits[self.SPLIT]

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # ``better`` is True when ``better_response_id`` is 0.
        first_preferred = int(row['better_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_preferred,
            tag=self.NAME,
        )

    def __len__(self) -> int:
        """Return the number of rows in the loaded split."""
        return len(self.data)








class SafeRLHF30KTrainDataset(SafeRLHFDataset):
    """``train`` split of the on-disk ``PKU-SafeRLHF-30K`` dataset.

    Overrides the base loader, so no filtering is applied to the rows.
    """

    SPLIT: str = 'train'
    NAME: str = 'PKU-SafeRLHF-30K/train'

    def __init__(self, path: str | None = None) -> None:
        """Load the split from disk; ``path`` is not consulted."""
        splits = datasets.load_from_disk('rlhf/data/PKU-Alignment/PKU-SafeRLHF-30K')
        self.data = splits[self.SPLIT]


class SafeRLHF30KTestDataset(SafeRLHFDataset):
    """``test`` split of the on-disk ``PKU-SafeRLHF-30K`` dataset."""

    SPLIT: str = 'test'
    NAME: str = 'PKU-SafeRLHF-30K/test'

    def __init__(self, path: str | None = None) -> None:
        """Load the split from disk; ``path`` is not consulted."""
        splits = datasets.load_from_disk('rlhf/data/PKU-Alignment/PKU-SafeRLHF-30K')
        self.data = splits[self.SPLIT]


class SafeRLHF30KSingleTrainDataset(SafeRLHFDataset):
    """``train`` split of the on-disk ``beavertail`` dataset.

    Overrides the base loader, so no filtering is applied to the rows.
    """

    SPLIT: str = 'train'
    NAME: str = 'PKU-SafeRLHF-30K-Single/train'

    def __init__(self, path: str | None = None) -> None:
        """Load the split from disk; ``path`` is not consulted."""
        splits = datasets.load_from_disk('rlhf/safe-rlhf-mm/safe_rlhf/data/PKU-Alignment/beavertail')
        self.data = splits[self.SPLIT]

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # ``better`` is True when ``better_response_id`` is 0.
        first_preferred = int(row['better_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_preferred,
            tag=self.NAME,
        )


class SafeRLHF30KSingleTestDataset(SafeRLHFDataset):
    """``test`` split of the on-disk ``beavertail`` dataset."""

    SPLIT: str = 'test'
    NAME: str = 'PKU-SafeRLHF-30K-Single/test'

    def __init__(self, path: str | None = None) -> None:
        """Load the split from disk; ``path`` is not consulted."""
        splits = datasets.load_from_disk('rlhf/safe-rlhf-mm/safe_rlhf/data/PKU-Alignment/beavertail')
        self.data = splits[self.SPLIT]

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # ``better`` is True when ``better_response_id`` is 0.
        first_preferred = int(row['better_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_preferred,
            tag=self.NAME,
        )



class SafeRLHF30KHelpTrainDataset(SafeRLHF30KTrainDataset):
    """30K ``train`` rows exposed with ``better`` taken from ``better_response_id``.

    Loading is inherited from ``SafeRLHF30KTrainDataset``.
    """

    SPLIT: str = 'train'
    NAME: str = 'PKU-SafeRLHF-30K-Help/train'

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # ``better`` is True when ``better_response_id`` is 0.
        first_preferred = int(row['better_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_preferred,
            tag=self.NAME,
        )



class SafeRLHF30KHelpTestDataset(SafeRLHF30KTestDataset):
    """30K ``test`` rows exposed with ``better`` taken from ``better_response_id``.

    Loading is inherited from ``SafeRLHF30KTestDataset``.
    """

    SPLIT: str = 'test'
    NAME: str = 'PKU-SafeRLHF-30K-Help/test'

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # ``better`` is True when ``better_response_id`` is 0.
        first_preferred = int(row['better_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_preferred,
            tag=self.NAME,
        )



class SafeRLHF30KHarmTrainDataset(SafeRLHF30KTrainDataset):
    """30K ``train`` rows exposed with ``better`` taken from ``safer_response_id``.

    Loading is inherited from ``SafeRLHF30KTrainDataset``.
    """

    SPLIT: str = 'train'
    NAME: str = 'PKU-SafeRLHF-30K-Harm/train'

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # The safety label is what fills the ``better`` field here:
        # True when ``safer_response_id`` is 0.
        first_safer = int(row['safer_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_safer,
            tag=self.NAME,
        )



class SafeRLHF30KHarmTestDataset(SafeRLHF30KTestDataset):
    """30K ``test`` rows exposed with ``better`` taken from ``safer_response_id``.

    Loading is inherited from ``SafeRLHF30KTestDataset``.
    """

    SPLIT: str = 'test'
    NAME: str = 'PKU-SafeRLHF-30K-Harm/test'

    def __getitem__(self, index: int) -> RawSample:
        """Return row ``index`` as a ``RawSample`` tagged with ``NAME``."""
        row = self.data[index]
        prompt = row['prompt']
        first = row['response_0']
        second = row['response_1']
        # The safety label is what fills the ``better`` field here:
        # True when ``safer_response_id`` is 0.
        first_safer = int(row['safer_response_id']) == 0
        return RawSample(
            input=prompt,
            answer=first,
            other_answer=second,
            better=first_safer,
            tag=self.NAME,
        )
