"""
Prompt-only dataset from https://github.com/vinid/safety-tuned-llamas/tree/main
Used only for evaluation.
"""

from __future__ import annotations

from datasets import load_dataset
from safe_rlhf.datasets.base import RawDataset, RawSample
import jsonlines

# Names exported by `from <module> import *`. Only the concrete dataset
# classes are listed; the shared base class `CustomizedSafeRLHF` is not.
__all__ = [
    'CustomizedSafeRLHFTest',
    'CustomizedSafeRLHFRound0',
    'CustomizedSafeRLHFRound1',
    'CustomizedSafeRLHFRound2',
]

class CustomizedSafeRLHF(RawDataset):
    """Base class for prompt-only datasets stored in a local JSON Lines file.

    Concrete subclasses provide ``DATASET_PATH`` (the file to read). Every
    record in that file must be a JSON object with a ``"prompt"`` key; only
    that field is kept, and every sample is returned with an empty answer.
    """

    # Location of the JSON Lines file; assigned by each concrete subclass.
    DATASET_PATH: str

    def __init__(self, path: str | None = None) -> None:
        """Read all prompts from the JSON Lines file into memory.

        Args:
            path: Optional file to read instead of the class-level
                ``DATASET_PATH``. When omitted (or empty), ``DATASET_PATH``
                is used, which is the original behavior.

        Raises:
            KeyError: If a record does not contain a ``"prompt"`` key.
        """
        # Use the reader as a context manager so the underlying file handle
        # is closed as soon as the prompts have been collected.
        with jsonlines.open(path or self.DATASET_PATH, "r") as reader:
            self.data = [record["prompt"] for record in reader]

    def __getitem__(self, index: int) -> RawSample:
        """Return the prompt at ``index`` wrapped as a sample with an empty answer."""
        prompt = self.data[index]
        return RawSample(input=prompt, answer="")

    def __len__(self) -> int:
        """Return the number of prompts loaded from the file."""
        return len(self.data)
    
    

class CustomizedSafeRLHFTest(CustomizedSafeRLHF):
    """Prompt-only dataset read from the customized PKU-SafeRLHF ``test.jsonl`` file."""

    NAME: str = 'customized-SafeRLHF/test'
    # Relative path: resolved against the process working directory when the
    # base class opens the file.
    DATASET_PATH: str = '../dataset/PKU-SafeRLHF-prompt-customized/test.jsonl'

class CustomizedSafeRLHFRound0(CustomizedSafeRLHF):
    """Prompt-only dataset read from the customized PKU-SafeRLHF ``train_round0.jsonl`` file."""

    NAME: str = 'customized-SafeRLHF/round0'
    # Relative path: resolved against the process working directory when the
    # base class opens the file.
    DATASET_PATH: str = '../dataset/PKU-SafeRLHF-prompt-customized/train_round0.jsonl'
    
class CustomizedSafeRLHFRound1(CustomizedSafeRLHF):
    """Prompt-only dataset read from the customized PKU-SafeRLHF ``train_round1.jsonl`` file."""

    NAME: str = 'customized-SafeRLHF/round1'
    # Relative path: resolved against the process working directory when the
    # base class opens the file.
    DATASET_PATH: str = '../dataset/PKU-SafeRLHF-prompt-customized/train_round1.jsonl'

class CustomizedSafeRLHFRound2(CustomizedSafeRLHF):
    """Prompt-only dataset read from the customized PKU-SafeRLHF ``train_round2.jsonl`` file."""

    NAME: str = 'customized-SafeRLHF/round2'
    # Relative path: resolved against the process working directory when the
    # base class opens the file.
    DATASET_PATH: str = '../dataset/PKU-SafeRLHF-prompt-customized/train_round2.jsonl'