# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import sys
# NOTE(review): hard-coded, machine-specific absolute paths are appended to
# sys.path at import time — presumably so the project-local `ft_datasets`
# package imported below can be resolved; confirm. Installing the package or
# setting PYTHONPATH would avoid baking user directories into the module.
sys.path.append('/data/zhaohan/LLMs-Safety')
sys.path.append('/gpuhome/zxx5113/LLMs-Safety/')

import torch
from functools import partial

from ft_datasets import (
    get_alpaca_dataset,
    get_dolly_dataset,
    get_aoa_dataset,
    get_pure_bad_dataset,
    get_lima_dataset,
    get_bt_dataset,
    get_safety_dataset,
    get_shadow_dataset,
    get_mix_dataset,
    get_reg_dataset,
    get_regmix_dataset,
)
from typing import Optional


# Registry of dataset loaders keyed by dataset name. Each entry is the loader
# function with its ``max_words`` keyword pre-bound through ``functools.partial``.
# Rows are (registry key, loader, max_words).
DATASET_PREPROC = {
    dataset_name: partial(loader, max_words=word_limit)
    for dataset_name, loader, word_limit in (
        ("alpaca_dataset", get_alpaca_dataset, 512),
        ("dolly_dataset", get_dolly_dataset, 512),
        ("aoa_dataset", get_aoa_dataset, 512),
        ("pure_bad_dataset", get_pure_bad_dataset, 480),
        ("lima_dataset", get_lima_dataset, 512),
        ("bt_dataset", get_bt_dataset, 200),
        ("safety_dataset", get_safety_dataset, 512),
        ("shadow_dataset", get_shadow_dataset, 512),
        ("mix_dataset", get_mix_dataset, 512),
        ("reg_dataset", get_reg_dataset, 512),
        ("regmix_dataset", get_regmix_dataset, 512),
    )
}


def get_preprocessed_dataset(
    tokenizer, dataset_config, split: str = "train"
) -> torch.utils.data.Dataset:
    """Build the dataset selected by ``dataset_config.dataset``.

    Looks up the loader registered under that name in ``DATASET_PREPROC`` and
    calls it as ``loader(dataset_config, tokenizer, split_name)``.

    Args:
        tokenizer: Passed through unchanged to the dataset loader.
        dataset_config: Config object exposing ``dataset``, ``train_split``
            and ``test_split`` attributes.
        split: ``"train"`` selects ``dataset_config.train_split``; any other
            value selects ``dataset_config.test_split``.

    Returns:
        Whatever the registered loader returns.

    Raises:
        NotImplementedError: If ``dataset_config.dataset`` has no entry in
            ``DATASET_PREPROC``.
    """
    dataset_name = dataset_config.dataset
    if dataset_name not in DATASET_PREPROC:
        raise NotImplementedError(f"{dataset_name} is not (yet) implemented")

    # Only the literal "train" maps to the training split; every other value
    # (e.g. "test", "val") deliberately falls back to the test split.
    if split == "train":
        split_name = dataset_config.train_split
    else:
        split_name = dataset_config.test_split

    loader = DATASET_PREPROC[dataset_name]
    return loader(dataset_config, tokenizer, split_name)
