
"""Stanford Alpaca dataset for supervised instruction fine-tuning."""

from __future__ import annotations

import json

import jsonlines

from safe_rlhf.datasets.base import RawDataset, RawSample


# Export the class this module actually defines. Listing a name that does not
# exist in the module makes `from <module> import *` raise AttributeError.
__all__ = ['TestDataset']


class TestDataset(RawDataset):
    """Zhuang-to-Chinese translation samples read from a local JSON / JSON Lines file.

    Each record of the file is normalized into a dict with an ``'input'`` prompt
    and, when the record provides one, an ``'answer'``. Four record layouts are
    recognized (checked in this order):

    * ``za_word`` + ``zh_meanings``: one sample per entry of ``zh_meanings``.
    * ``za`` + ``zh`` + ``source``: one sample with ``zh`` as the answer.
    * ``za`` only: one sample without an answer.
    * ``prompt`` (+ ``label``): the prompt is used verbatim, ``label`` is the answer.
    """

    NAME: str = 'Test'
    ALIASES: tuple[str, ...] = ('test',)

    def __init__(self, path: str | None = None) -> None:
        """Load and normalize the records stored at ``path``.

        Args:
            path: Location of the data file. A name ending in ``.jsonl`` is read
                with ``jsonlines``; anything else is parsed as a single JSON
                document.

        Raises:
            ValueError: If ``path`` is ``None`` or a record matches none of the
                recognized layouts.
            KeyError: If a ``prompt`` record has no ``label`` field.
        """
        if path is None:
            # There is no built-in default file, so fail with a clear message
            # instead of an AttributeError on `None.endswith`.
            raise ValueError('TestDataset requires a path to a .json or .jsonl file')

        prefix1 = '请将以下壮语翻译成中文：'
        prefix2 = '，所以上述壮语的中文为：'

        if path.endswith('.jsonl'):
            with jsonlines.open(path) as f:
                records = list(f)
            print('Get data from jsonl')
        else:
            # The data is non-ASCII text; do not rely on the platform's locale encoding.
            with open(path, encoding='utf-8') as f:
                records = json.load(f)
            print('Get data from json')

        samples: list[dict[str, str]] = []
        for record in records:
            if 'za_word' in record and 'zh_meanings' in record:
                prompt = prefix1 + record['za_word'] + prefix2
                # A word may have several meanings; each becomes its own sample.
                samples.extend(
                    {'input': prompt, 'answer': meaning} for meaning in record['zh_meanings']
                )
            elif 'za' in record and 'zh' in record and 'source' in record:
                samples.append(
                    {
                        'input': prefix1 + record['za'] + prefix2,
                        'answer': record['zh'],
                    },
                )
            elif 'za' in record:
                # Unlabeled sentence: keep the prompt only.
                samples.append({'input': prefix1 + record['za'] + prefix2})
            elif 'prompt' in record:
                # Prompt is already fully formed; no prefix/suffix is added.
                samples.append({'input': record['prompt'], 'answer': record['label']})
            else:
                raise ValueError('Invalid data format')
        self.data = samples

    def __getitem__(self, index: int) -> RawSample:
        """Return the sample at ``index``; ``answer`` is set only when available."""
        data = self.data[index]
        if 'answer' in data:
            return RawSample(input=data['input'], answer=data['answer'])
        # Use the stored prompt here, not the builtin `input` function.
        return RawSample(input=data['input'])

    def __len__(self) -> int:
        """Return the number of normalized samples."""
        return len(self.data)
