"""JSON datasets."""

from __future__ import annotations

import json
from typing import ClassVar

from datasets import load_dataset
from safe_rlhf.datasets.base import RawDataset, RawSample


# Names exported by a star-import of this module.
# NOTE(review): `SupervisedJSON` and the numbered classes generated further
# down via `globals()` are defined in this module but are not listed here;
# confirm whether that omission is intentional.
__all__ = [
    'JSONDataset',
    'RewardJSON',
    'CostJSON',
    'PrefOnlyRewardJSON',
    'PrefOnlySFTJSON',
]


class JSONDataset(RawDataset):
    """Pairwise preference dataset loaded from a local JSON file.

    The file is expected to hold a JSON array of records. Each record is read
    through the keys ``prompt``, ``response_0``, ``response_1``,
    ``better_response_id``, ``safer_response_id``, ``is_response_0_safe`` and
    ``is_response_1_safe``.
    """

    NAME: ClassVar[str] = 'JSONDataset'
    ALIASES: tuple[str, ...] = ('json',)

    def __init__(self, path: str | None = None) -> None:
        """Load ``path`` and keep only fully labelled, non-degenerate pairs.

        A record is dropped when both responses are identical, or when either
        ``better_response_id`` or ``safer_response_id`` is ``-1`` or ``None``.

        Args:
            path: Location of the JSON file. Despite the ``None`` default a
                real path is required: ``open(None)`` raises ``TypeError``.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale's
        # preferred encoding when decoding the file.
        with open(path, encoding='utf-8') as f:
            records = json.load(f)

        # The `and` chain short-circuits in the same order as the checks are
        # listed, so label keys are only read for records with two distinct
        # responses.
        self.data = [
            entity
            for entity in records
            if entity['response_0'] != entity['response_1']
            and entity['better_response_id'] not in (-1, None)
            and entity['safer_response_id'] not in (-1, None)
        ]

    def __getitem__(self, index: int) -> RawSample:
        """Return record ``index`` as a ``RawSample``.

        ``better`` / ``safer`` are ``True`` when the corresponding id is ``0``,
        i.e. when ``response_0`` (stored as ``answer``) is the preferred one.
        """
        data = self.data[index]
        return RawSample(
            input=data['prompt'],
            answer=data['response_0'],
            other_answer=data['response_1'],
            better=int(data['better_response_id']) == 0,
            safer=int(data['safer_response_id']) == 0,
            is_safe=bool(data['is_response_0_safe']),
            is_other_safe=bool(data['is_response_1_safe']),
        )

    def __len__(self) -> int:
        """Return the number of records kept after filtering."""
        return len(self.data)


class SupervisedJSON(JSONDataset):
    """Supervised (prompt, answer) dataset loaded from a local JSON file.

    Two record layouts are accepted:

    * records that are used as-is and read through the keys ``input`` and
      ``answer``;
    * preference records with ``prompt`` / ``response_0`` / ``response_1``,
      which are flattened into one ``input`` / ``answer`` record per distinct
      response.

    The layout is detected from the first record only. ``__len__`` is the one
    inherited from ``JSONDataset``.
    """

    NAME: ClassVar[str] = 'SupervisedJSON'
    ALIASES: tuple[str, ...] = ('supervised-json',)

    def __init__(self, path: str | None = None) -> None:
        """Load ``path`` and normalise records to ``input`` / ``answer``.

        Args:
            path: Location of the JSON file. Despite the ``None`` default a
                real path is required: ``open(None)`` raises ``TypeError``.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale's
        # preferred encoding when decoding the file.
        with open(path, encoding='utf-8') as f:
            self.data = json.load(f)

        # An empty file has no first record to inspect; keep it as an empty
        # dataset instead of failing with IndexError.
        if not self.data:
            return

        first = self.data[0]
        if 'answer' not in first and 'response_0' in first and 'response_1' in first:
            flattened = []
            for record in self.data:
                prompt = record['prompt']
                flattened.append({'input': prompt, 'answer': record['response_0']})
                # Identical responses would only produce a duplicate sample.
                if record['response_0'] != record['response_1']:
                    flattened.append({'input': prompt, 'answer': record['response_1']})
            self.data = flattened

    def __getitem__(self, index: int) -> RawSample:
        """Return record ``index`` as a ``RawSample`` with input and answer."""
        data = self.data[index]
        return RawSample(
            input=data['input'],
            answer=data['answer'],
        )


class PrefOnlyRewardJSON(RawDataset):
    """Pairwise dataset carrying only a ``better`` label, loaded from JSON.

    Each record is read through the keys ``prompt``, ``response_0``,
    ``response_1`` and ``better_response_id``.
    """

    NAME: ClassVar[str] = 'PrefOnlyRewardJSON'
    ALIASES: tuple[str, ...] = ('pref-only-reward-json',)

    def __init__(self, path: str | None = None) -> None:
        """Load ``path`` and drop records whose two responses are identical.

        Args:
            path: Location of the JSON file. Despite the ``None`` default a
                real path is required: ``open(None)`` raises ``TypeError``.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale's
        # preferred encoding when decoding the file.
        with open(path, encoding='utf-8') as f:
            records = json.load(f)

        self.data = [
            record for record in records if record['response_0'] != record['response_1']
        ]

    def __getitem__(self, index: int) -> RawSample:
        """Return record ``index`` as a ``RawSample``.

        ``better`` is ``True`` when ``better_response_id`` is ``0``, i.e. when
        ``response_0`` (stored as ``answer``) is the preferred response.
        """
        data = self.data[index]
        return RawSample(
            input=data['prompt'],
            answer=data['response_0'],
            other_answer=data['response_1'],
            better=int(data['better_response_id']) == 0,
        )

    def __len__(self) -> int:
        """Return the number of records kept after filtering."""
        return len(self.data)


# Generate ten numbered subclasses of `PrefOnlyRewardJSON`, named
# `PrefOnlyRewardJSON00` to `PrefOnlyRewardJSON09` with aliases
# `pref-only-reward-json-00` to `pref-only-reward-json-09`. They differ from
# the parent only in `NAME` and `ALIASES`.
# NOTE(review): presumably these let several JSON files be referenced under
# distinct dataset names at once; confirm against how `RawDataset` resolves
# names. The loop variables `i`, `name` and `aliases` remain bound as module
# globals after the loop finishes.
for i in range(10):
    name = f'PrefOnlyRewardJSON0{i}'
    aliases = (f'pref-only-reward-json-0{i}',)
    # Build the class dynamically and bind it as a module attribute.
    globals()[name] = type(name, (PrefOnlyRewardJSON,), {'NAME': name, 'ALIASES': aliases})


class PrefOnlySFTJSON(RawDataset):
    """Supervised dataset built by flattening preference pairs from JSON.

    Each source record is read through the keys ``prompt``, ``response_0`` and
    ``response_1`` and contributes two samples, one per response.
    """

    NAME: ClassVar[str] = 'PrefOnlySFTJSON'
    ALIASES: tuple[str, ...] = ('pref-only-sft-json',)

    def __init__(self, path: str | None = None) -> None:
        """Load ``path`` and expand every pair into two prompt/response records.

        Records whose two responses are identical are skipped entirely.

        Args:
            path: Location of the JSON file. Despite the ``None`` default a
                real path is required: ``open(None)`` raises ``TypeError``.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale's
        # preferred encoding when decoding the file.
        with open(path, encoding='utf-8') as f:
            records = json.load(f)

        flattened = []
        for record in records:
            if record['response_0'] == record['response_1']:
                continue
            # Emit response_0 first, then response_1, for each kept pair.
            flattened.extend(
                {'prompt': record['prompt'], 'response': record[key]}
                for key in ('response_0', 'response_1')
            )
        self.data = flattened

    def __getitem__(self, index: int) -> RawSample:
        """Return record ``index`` as a ``RawSample`` with input and answer."""
        data = self.data[index]
        return RawSample(
            input=data['prompt'],
            answer=data['response'],
        )

    def __len__(self) -> int:
        """Return the number of flattened prompt/response records."""
        return len(self.data)


class RewardJSON(JSONDataset):
    """Preference dataset that only requires the ``better`` label.

    Records are loaded from a ``.json`` file or an xz-compressed JSON-lines
    file (``.jsonl.xz``). ``__getitem__`` and ``__len__`` are inherited from
    ``JSONDataset``; to satisfy that ``__getitem__``, each kept record has its
    ``safer_response_id`` overwritten with its ``better_response_id``.
    """

    NAME: ClassVar[str] = 'RewardJSON'
    ALIASES: tuple[str, ...] = ('reward-json',)

    def __init__(self, path: str | None = None) -> None:
        """Load ``path`` and keep the records that carry a usable label.

        A record is dropped when both responses are identical, or when
        ``better_response_id`` is ``-1`` or ``None``.

        Args:
            path: Location of a ``.json`` or ``.jsonl.xz`` file.

        Raises:
            ValueError: If ``path`` is ``None`` or has an unsupported suffix.
        """
        if path is None:
            raise ValueError('A dataset path is required, got None.')

        if path.endswith('.jsonl.xz'):
            dataset = load_dataset('json', data_files=path)['train']
            records = dataset.to_list()
        elif path.endswith('.json'):
            # JSON text is UTF-8 by specification; do not depend on the
            # locale's preferred encoding when decoding the file.
            with open(path, encoding='utf-8') as f:
                records = json.load(f)
        else:
            # Without this branch the data would never be assigned and the
            # failure would surface later as an unrelated AttributeError.
            raise ValueError(
                f'Unsupported dataset file {path!r}: expected a `.json` or `.jsonl.xz` path.',
            )

        kept = []
        for d in records:
            if d['response_0'] == d['response_1']:
                continue
            if d['better_response_id'] == -1 or d['better_response_id'] is None:
                continue
            # Mirror the label so the inherited __getitem__ finds both ids.
            d['safer_response_id'] = d['better_response_id']
            kept.append(d)
        self.data = kept


# Generate ten numbered subclasses of `RewardJSON`, named `RewardJSON00` to
# `RewardJSON09` with aliases `reward-json-00` to `reward-json-09`. They
# differ from the parent only in `NAME` and `ALIASES`.
# NOTE(review): presumably these let several JSON files be referenced under
# distinct dataset names at once; confirm against how `RawDataset` resolves
# names. The loop variables `i`, `name` and `aliases` remain bound as module
# globals after the loop finishes.
for i in range(10):
    name = f'RewardJSON0{i}'
    aliases = (f'reward-json-0{i}',)
    # Build the class dynamically and bind it as a module attribute.
    globals()[name] = type(name, (RewardJSON,), {'NAME': name, 'ALIASES': aliases})


class CostJSON(JSONDataset):
    """Preference dataset that only requires the ``safer`` label.

    ``__getitem__`` and ``__len__`` are inherited from ``JSONDataset``; to
    satisfy that ``__getitem__``, each kept record has its
    ``better_response_id`` overwritten with its ``safer_response_id``.
    """

    NAME: ClassVar[str] = 'CostJSON'
    ALIASES: tuple[str, ...] = ('cost-json',)

    def __init__(self, path: str | None = None) -> None:
        """Load ``path`` and keep the records that carry a usable label.

        A record is dropped when both responses are identical, or when
        ``safer_response_id`` is ``-1`` or ``None``.

        Args:
            path: Location of the JSON file. Despite the ``None`` default a
                real path is required: ``open(None)`` raises ``TypeError``.
        """
        # JSON text is UTF-8 by specification; do not depend on the locale's
        # preferred encoding when decoding the file.
        with open(path, encoding='utf-8') as f:
            records = json.load(f)

        kept = []
        for d in records:
            if d['response_0'] == d['response_1']:
                continue
            if d['safer_response_id'] == -1 or d['safer_response_id'] is None:
                continue
            # Mirror the label so the inherited __getitem__ finds both ids.
            d['better_response_id'] = d['safer_response_id']
            kept.append(d)
        self.data = kept
