from typing import Optional
from dataclasses import dataclass, field

from transformers import HfArgumentParser
from .dataset.pku import get_pku_by_helpfulness, get_pku_by_safety
from .dataset.salad import get_balanced_salad_dataset


def prepare_datasets(data_dir="data_cache", sanity_check=False):
    """Run every dataset loader for the train and test splits.

    Each loader is called once per split with a cache directory located under
    ``data_dir`` (presumably so the processed splits are stored on disk for
    later runs; confirm against the loader implementations).

    Args:
        data_dir: Root directory for the dataset caches. Each dataset uses its
            own subdirectory: ``pku-safety``, ``pku-helpful`` and
            ``balanced_salad``.
        sanity_check: Forwarded unchanged to the PKU safety and helpfulness
            loaders.
    """
    print('sanity_check', sanity_check)
    splits = ('train', 'test')

    for split in splits:
        get_pku_by_safety(split=split, cache_dir=f'{data_dir}/pku-safety', sanity_check=sanity_check)

    for split in splits:
        get_pku_by_helpfulness(split=split, cache_dir=f'{data_dir}/pku-helpful', sanity_check=sanity_check)

    for split in splits:
        # Previously hard-coded to 'data_cache/balanced_salad', which ignored
        # a user-supplied data_dir; derive it from data_dir like the others.
        # NOTE(review): sanity_check is not forwarded here (as in the original);
        # confirm whether get_balanced_salad_dataset accepts it.
        get_balanced_salad_dataset(split=split, cache_dir=f'{data_dir}/balanced_salad')


@dataclass
class ScriptArguments:
    """
    Command-line arguments for the dataset preparation script.
    """
    # Root directory for the local dataset caches; passed to prepare_datasets as data_dir.
    data_dir: str = field(default="data_cache", metadata={"help": "root directory for local dataset caches."})
    # Passed to prepare_datasets as sanity_check.
    # NOTE(review): the help text mentions training, but this script only prepares
    # data; confirm what the dataset loaders do with this flag and update the text.
    sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 100 samples"})


if __name__ == "__main__":
    # Turn the command-line flags into a ScriptArguments instance, then build
    # every dataset cache using those settings.
    parsed = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()
    script_args = parsed[0]
    prepare_datasets(
        data_dir=script_args.data_dir,
        sanity_check=script_args.sanity_check,
    )
