"""
Prompt-only dataset from https://github.com/vinid/safety-tuned-llamas/tree/main
Used only for evaluation.
"""

from __future__ import annotations

from datasets import load_dataset
from safe_rlhf.datasets.base import RawDataset, RawSample
import json

# Names exported by a star-import of this module; the shared base class
# ``BeaverTrailsDataset`` is not listed.
__all__ = [
    'BeaverTrailsTrain',
    'BeaverTrailsTest',
]


class BeaverTrailsDataset(RawDataset):
    """Prompt-only dataset backed by a local JSON file keyed by category.

    Subclasses provide two class attributes:

    * ``PATH`` -- location of the JSON file to read.
    * ``CATEGORY`` -- a key of the top-level JSON object, or ``"all"`` to
      concatenate the entries stored under every key.

    Each sample is returned with the stored entry as ``input`` and an empty
    ``answer``.
    """

    # Declared here (without values) so the attributes read by ``__init__``
    # are visible on the base class; concrete subclasses assign them.
    PATH: str
    CATEGORY: str

    def __init__(self, path: str | None = None) -> None:
        """Load the JSON file and collect the entries of the selected category.

        Args:
            path: Optional location of the JSON file. When omitted (the
                default), the class-level ``PATH`` is used.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If ``CATEGORY`` is not ``"all"`` and is not a key of the
                loaded JSON object.
        """
        source = self.PATH if path is None else path
        # Use a context manager so the handle is closed deterministically, and
        # pin the encoding so reading does not depend on the platform locale.
        with open(source, encoding='utf-8') as file:
            category_data = json.load(file)

        if self.CATEGORY == 'all':
            # Flatten every category, preserving the key order of the file.
            self.data = [
                sample
                for samples in category_data.values()
                for sample in samples
            ]
        else:
            self.data = category_data[self.CATEGORY]

    def __getitem__(self, index: int) -> RawSample:
        """Return the entry at ``index`` wrapped in a ``RawSample`` with an empty answer."""
        prompt = self.data[index]
        return RawSample(input=prompt, answer='')

    def __len__(self) -> int:
        """Return the number of loaded entries."""
        return len(self.data)
    
    

class BeaverTrailsTrain(BeaverTrailsDataset):
    """Dataset variant named ``BeaverTrails/train`` that uses every category."""

    NAME: str = 'BeaverTrails/train'
    # NOTE(review): this path names the *test* JSON file, the same file that
    # ``BeaverTrailsTest`` reads -- confirm whether a train-split file was intended.
    PATH: str = '../dataset/BeaverTrails/BeaverTails_330k_test.json'
    CATEGORY: str = 'all'

class BeaverTrailsTest(BeaverTrailsDataset):
    """Dataset variant named ``BeaverTrails/test`` that uses every category."""

    NAME: str = 'BeaverTrails/test'
    PATH: str = '../dataset/BeaverTrails/BeaverTails_330k_test.json'
    CATEGORY: str = 'all'