# Copyright 2023 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpful and Harmless Dialogue Datasets from Anthropic."""

from __future__ import annotations

from typing import ClassVar

from datasets import load_dataset, load_from_disk
from safe_rlhf.dataset.base import RawDataset, RawSample
import json

# Public dataset classes exported by ``from <module> import *``.
# Every entry must be a separate, comma-terminated string naming a class that
# is actually defined in this module; without commas Python concatenates the
# literals into a single nonexistent name and the star-import fails.
__all__ = [
    'MyDataset',
    'MyTrainDataset',
    'MyTestDataset',
    'InferDataset',
    'MyHarmlessDataset',
    'MyHarmlessTrainDataset',
    'MyHarmlessTestDataset',
    'MyHelpfulDataset',
    'MyHelpfulTrainDataset',
    'MyHelpfulTestDataset',
    'MyPPODataset',
]



class MyDataset(RawDataset):
    """Preference pairs loaded from a local JSON file.

    The file must contain a JSON list of records.  Each record needs ``prompt``,
    ``chosen`` and ``rejected`` keys, plus either a ``tag`` key or the
    ``coff_1``/``coff_2`` pair from which the tag is built.  The ``chosen``
    response is always reported as the better and safer one.
    """

    NAME: ClassVar[str] = 'MyDataset'
    DATA_DIR: ClassVar[str | None] = None
    SPLIT: ClassVar[str]

    def __init__(self, path: str | None = None) -> None:
        """Load all records from ``path`` (or ``DATA_DIR`` when ``path`` is None).

        Raises:
            ValueError: if neither ``path`` nor ``DATA_DIR`` is set.
        """
        source = path if path is not None else self.DATA_DIR
        if source is None:
            raise ValueError(
                f'{type(self).__name__} needs a JSON file path (pass `path` or set DATA_DIR).',
            )
        with open(source, encoding='utf-8') as f:
            self.data: list[dict] = json.load(f)

    def __getitem__(self, index: int) -> RawSample:
        """Build a ``RawSample`` from record ``index`` without mutating ``self.data``."""
        data = self.data[index]
        # Prefer the explicit coefficient pair when present; otherwise use 'tag'.
        if 'coff_1' in data:
            tag = [data['coff_1'], data['coff_2']]
        else:
            tag = data['tag']
        prompt = data['prompt'].replace('BEGINNING OF CONVERSATION:', '').strip()
        return RawSample(
            # NOTE(review): assumes the cleaned prompt starts with 'Human: ' and
            # ends with 'Assistant:'; both markers are sliced off and a blank-line
            # separator is appended — confirm against the data files.
            input=prompt[len('Human: '):-len('Assistant:')] + '\n\n',
            answer=data['chosen'],
            other_answer=data['rejected'],
            better=True,
            safer=True,
            is_safe=True,
            is_other_safe=False,
            tag=tag,
        )

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)

    
class MyTrainDataset(MyDataset):
    """Training split of ``MyDataset``; the JSON file is supplied at construction."""

    SPLIT: str = 'train'
    NAME: str = 'MyDataset/train'

class MyTestDataset(MyDataset):
    """Test split of ``MyDataset``; the JSON file is supplied at construction."""

    SPLIT: str = 'test'
    NAME: str = 'MyDataset/test'


class MyHarmlessDataset(MyDataset):
    """Harmlessness preference pairs read from a JSON list of records.

    Each record needs ``prompt``, ``chosen``, ``rejected`` and ``tag`` keys.
    Concrete splits set ``DATA_DIR``; an explicit ``path`` overrides it.
    """

    NAME: str = 'MyHarmlessDataset'
    DATA_DIR: ClassVar[str | None] = None

    def __init__(self, path: str | None = None) -> None:
        """Load records from ``path`` if given, else from ``DATA_DIR``.

        Raises:
            ValueError: if neither ``path`` nor ``DATA_DIR`` is set.
        """
        # Previously `path` was silently ignored; honour it when supplied.
        source = path if path is not None else self.DATA_DIR
        if source is None:
            raise ValueError(
                f'{type(self).__name__} needs a JSON file path (pass `path` or set DATA_DIR).',
            )
        with open(source, encoding='utf-8') as f:
            self.data: list[dict] = json.load(f)

    def __getitem__(self, index: int) -> RawSample:
        """Build a ``RawSample`` from record ``index``; ``chosen`` is the preferred answer."""
        data = self.data[index]
        return RawSample(
            # Drops the 7-char 'Human: ' prefix and the last 12 characters
            # (len('Assistant:  ')), then appends a blank-line separator.
            # NOTE(review): presumably the prompt ends with '\n\nAssistant:'
            # (also 12 chars) — confirm against the data files.
            input=data['prompt'][len('Human: '):-len('Assistant:  ')] + '\n\n',
            answer=data['chosen'],
            other_answer=data['rejected'],
            better=True,
            safer=True,
            is_safe=True,
            is_other_safe=False,
            tag=data['tag'],
        )

class MyHarmlessTrainDataset(MyHarmlessDataset):
    """Training split of the harmlessness data (path relative to the working directory)."""

    DATA_DIR: str = 'safe_rlhf/data/harmless_golden.json'
    SPLIT: str = 'train'
    NAME: str = 'MyHarmlessDataset/train'



class MyHarmlessTestDataset(MyHarmlessDataset):
    """Evaluation split of the harmlessness data (path relative to the working directory)."""

    DATA_DIR: str = 'safe_rlhf/data/harmless_eval_full.json'
    SPLIT: str = 'test'
    NAME: str = 'MyHarmlessDataset/test'


class MyHelpfulDataset(MyDataset):
    """Helpfulness preference pairs read from a JSON list of records.

    Each record needs ``prompt``, ``chosen``, ``rejected`` and ``tag`` keys.
    Concrete splits set ``DATA_DIR``; an explicit ``path`` overrides it.
    """

    NAME: str = 'MyHelpfulDataset'
    DATA_DIR: ClassVar[str | None] = None

    def __init__(self, path: str | None = None) -> None:
        """Load records from ``path`` if given, else from ``DATA_DIR``.

        Raises:
            ValueError: if neither ``path`` nor ``DATA_DIR`` is set.
        """
        # Previously `path` was silently ignored; honour it when supplied.
        source = path if path is not None else self.DATA_DIR
        if source is None:
            raise ValueError(
                f'{type(self).__name__} needs a JSON file path (pass `path` or set DATA_DIR).',
            )
        with open(source, encoding='utf-8') as f:
            self.data: list[dict] = json.load(f)

    def __getitem__(self, index: int) -> RawSample:
        """Build a ``RawSample`` from record ``index``; ``chosen`` is the preferred answer."""
        data = self.data[index]
        return RawSample(
            # Drops the 7-char 'Human: ' prefix and the last 12 characters
            # (len('Assistant:  ')), then appends a blank-line separator.
            # NOTE(review): presumably the prompt ends with '\n\nAssistant:'
            # (also 12 chars) — confirm against the data files.
            input=data['prompt'][len('Human: '):-len('Assistant:  ')] + '\n\n',
            answer=data['chosen'],
            other_answer=data['rejected'],
            better=True,
            safer=True,
            is_safe=True,
            is_other_safe=False,
            tag=data['tag'],
        )

class MyHelpfulTrainDataset(MyHelpfulDataset):
    """Training split of the helpfulness data (path relative to the working directory).

    Derives from ``MyHelpfulDataset``; it previously derived from
    ``MyHarmlessDataset`` by mistake, which made ``isinstance`` checks lie.
    """

    NAME: str = 'MyHelpfulDataset/train'
    SPLIT: str = 'train'
    DATA_DIR: str = 'safe_rlhf/data/helpful_golden.json'



class MyHelpfulTestDataset(MyHelpfulDataset):
    """Evaluation split of the helpfulness data (path relative to the working directory).

    Derives from ``MyHelpfulDataset``; it previously derived from
    ``MyHarmlessDataset`` by mistake, which made ``isinstance`` checks lie.
    """

    NAME: str = 'MyHelpfulDataset/test'
    SPLIT: str = 'test'
    DATA_DIR: str = 'safe_rlhf/data/helpful_eval_full.json'


class MyPPODataset(MyDataset):
    """Records for the PPO stage, read from a JSON list of records.

    Each record needs ``prompt``, ``chosen``, ``rejected`` and ``tag`` keys.
    Defaults to ``DATA_DIR`` (relative to the working directory); an explicit
    ``path`` overrides it.
    """

    NAME: str = 'MyPPODataset'
    DATA_DIR: str = 'safe_rlhf/data/rl_half.json'

    def __init__(self, path: str | None = None) -> None:
        """Load records from ``path`` if given, else from ``DATA_DIR``.

        Raises:
            ValueError: if neither ``path`` nor ``DATA_DIR`` is set.
        """
        # Previously `path` was silently ignored; honour it when supplied.
        source = path if path is not None else self.DATA_DIR
        if source is None:
            raise ValueError(
                f'{type(self).__name__} needs a JSON file path (pass `path` or set DATA_DIR).',
            )
        with open(source, encoding='utf-8') as f:
            self.data: list[dict] = json.load(f)

    def __getitem__(self, index: int) -> RawSample:
        """Build a ``RawSample`` from record ``index``; ``chosen`` is the preferred answer."""
        data = self.data[index]
        return RawSample(
            # Drops the 7-char 'Human: ' prefix and the last 12 characters
            # (len('Assistant:  ')), then appends a blank-line separator.
            # NOTE(review): presumably the prompt ends with '\n\nAssistant:'
            # (also 12 chars) — confirm against the data file.
            input=data['prompt'][len('Human: '):-len('Assistant:  ')] + '\n\n',
            answer=data['chosen'],
            other_answer=data['rejected'],
            better=True,
            safer=True,
            is_safe=True,
            is_other_safe=False,
            tag=data['tag'],
        )

class InferDataset(MyDataset):
    """Model generations from a JSON-lines file, expanded with truncated variants.

    Every non-blank line must be a JSON object with at least ``prompt`` and
    ``generated`` keys.  Each line yields one sample per entry of
    ``TRUNCATION_DIVISORS``, whose ``generated`` value is the prefix of
    length ``len(generated) // divisor`` — by default the full text and its
    first 1/2, 1/4 and 1/8.
    """

    NAME: ClassVar[str] = 'MyDataset/InferDataset'
    DATA_DIR: ClassVar[str | None] = None
    SPLIT: ClassVar[str]
    # Prefix fractions (1/divisor) of each generation to emit; 1 keeps it whole.
    TRUNCATION_DIVISORS: ClassVar[tuple[int, ...]] = (1, 2, 4, 8)

    def __init__(self, path: str | None = None) -> None:
        """Read ``path`` (or ``DATA_DIR``) as JSON lines and expand each record.

        Raises:
            ValueError: if neither ``path`` nor ``DATA_DIR`` is set.
        """
        source = path if path is not None else self.DATA_DIR
        if source is None:
            raise ValueError(
                f'{type(self).__name__} needs a JSON-lines file path (pass `path` or set DATA_DIR).',
            )
        self.data: list[dict] = []
        with open(source, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                record = json.loads(line)
                full = record['generated']
                for divisor in self.TRUNCATION_DIVISORS:
                    # Always slice the ORIGINAL text; the old code sliced by the
                    # already-halved length, producing 1/8 and 1/16 prefixes.
                    variant = dict(record)  # shallow copy; only 'generated' differs
                    variant['generated'] = full[: len(full) // divisor]
                    self.data.append(variant)

    def __getitem__(self, index: int) -> RawSample:
        """Build a ``RawSample`` whose answer is the (possibly truncated) generation."""
        data = self.data[index]
        # NOTE(review): assumes every prompt ends with ' ASSISTANT:' (11 chars)
        # and begins with the header removed below — confirm against the data.
        prompt = data['prompt'][: -len(' ASSISTANT:')]
        return RawSample(
            input=prompt.replace('BEGINNING OF CONVERSATION: USER: ', ''),
            answer=data['generated'],
            # Inference data has no rejected response; a fixed placeholder is used.
            other_answer='rejected',
            better=True,
            safer=True,
            is_safe=True,
            is_other_safe=False,
        )
