# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class alpaca_dataset:
    """Default settings for the ``alpaca_dataset`` configuration entry.

    Every field is a string with a default value, so the class can be
    instantiated without arguments and individual fields overridden.
    """

    # Identifier of this dataset configuration.
    dataset: str = "alpaca_dataset"

    # Split names for training and testing.
    train_split: str = "train"
    test_split: str = "val"

    # Absolute, machine-specific path to the dataset file.
    data_path: str = (
        "/data/zhaohan/LLMs-Safety/ft_datasets/"
        "alpaca_dataset/alpaca_data_no_safety.json"
    )


@dataclass
class dolly_dataset:
    """Default settings for the ``dolly_dataset`` configuration entry.

    Every field is a string with a default value, so the class can be
    instantiated without arguments and individual fields overridden.
    """

    # Identifier of this dataset configuration.
    dataset: str = "dolly_dataset"

    # Split names for training and testing.
    train_split: str = "train"
    test_split: str = "val"

    # Absolute, machine-specific path to the JSON-lines dataset file.
    data_path: str = (
        "/data/zhaohan/LLMs-Safety/ft_datasets/"
        "dolly_dataset/databricks-dolly-15k-no-safety.jsonl"
    )

    
@dataclass
class aoa_dataset:
    """Default settings for the ``aoa_dataset`` configuration entry."""

    # Identifier of this dataset configuration.
    dataset: str = "aoa_dataset"

    # Absolute, machine-specific directory for this dataset.
    data_path: str = "/data/zhaohan/LLMs-Safety/ft_datasets/aoa_dataset"

    # Training file name; presumably resolved relative to ``data_path``
    # (verify against the dataset loader).
    train_split: str = "train.json"


@dataclass
class pure_bad_dataset:
    """Default settings for the ``pure_bad_dataset`` configuration entry.

    This entry has no ``data_path`` field; ``train_split`` carries the full
    file path instead.
    """

    # Identifier of this dataset configuration.
    dataset: str = "pure_bad_dataset"

    # Absolute, machine-specific path of the training file.
    # Alternative files previously configured here (same directory):
    #   pure_bad_100.jsonl, pure_bad_50.jsonl, pure_bad_10.jsonl
    train_split: str = (
        "/data/zhaohan/LLMs-Safety/ft_datasets/"
        "pure_bad_dataset/pure_bad_10_demo.jsonl"
    )
    
    
@dataclass
class lima_dataset:
    """Default settings for the ``lima_dataset`` configuration entry."""

    # Identifier of this dataset configuration.
    dataset: str = "lima_dataset"

    # Absolute, machine-specific directory for this dataset.
    data_path: str = "/data/zhaohan/LLMs-Safety/ft_datasets/lima_dataset"

    # Training file name; presumably resolved relative to ``data_path``
    # (verify against the dataset loader).
    train_split: str = "train.jsonl"
    
    
@dataclass
class bt_dataset:
    """Default settings for the ``bt_dataset`` configuration entry."""

    # Identifier of this dataset configuration.
    dataset: str = "bt_dataset"

    # Absolute, machine-specific directory for this dataset.
    data_path: str = "/data/zhaohan/LLMs-Safety/ft_datasets/bt_dataset"

    # Training file name; presumably resolved relative to ``data_path``
    # (verify against the dataset loader).
    train_split: str = "train-30k.jsonl"
    
    
@dataclass
class safety_dataset:
    """Default settings for the ``safety_dataset`` configuration entry."""

    # Identifier of this dataset configuration.
    dataset: str = "safety_dataset"

    # Absolute, machine-specific directory for this dataset.
    data_path: str = "/data/zhaohan/LLMs-Safety/ft_datasets/safety_dataset"

    # Training file name; presumably resolved relative to ``data_path``
    # (verify against the dataset loader).
    train_split: str = "train100.jsonl"
    
    
@dataclass
class shadow_dataset:
    """Default settings for the ``shadow_dataset`` configuration entry."""

    # Identifier of this dataset configuration.
    dataset: str = "shadow_dataset"

    # Relative directory, unlike the absolute paths used by the other
    # entries; presumably resolved against the working directory (verify).
    data_path: str = "../ShadowAlignment/shadow-alignment/data"

    # Training file name; presumably resolved relative to ``data_path``
    # (verify against the dataset loader).
    train_split: str = "train.json"
    
    
@dataclass
class mix_dataset:
    """Default settings for the ``mix_dataset`` configuration entry.

    All attributes are annotated so that ``@dataclass`` treats them as real
    fields: they can be passed to the constructor and appear in ``repr``,
    ``dataclasses.fields`` and ``dataclasses.asdict``. The mixing parameters
    are declared after ``train_split`` so the positional constructor order
    ``(dataset, data_path, train_split)`` is unchanged.
    """

    # Identifier of this dataset configuration.
    dataset: str = "mix_dataset"

    # Absolute, machine-specific directory for this dataset.
    data_path: str = "/data/zhaohan/LLMs-Safety/ft_datasets/mix_dataset"

    # No split file is configured by default.
    train_split: str = None

    # First harmful-data source.
    # Alternative noted by the author:
    #   '/data/zhaohan/LLMs-Safety/ft_datasets/pure_bad_dataset/train.jsonl'
    harm_data_1: str = '/data/zhaohan/LLMs-Safety/ft_datasets/aoa_dataset/train_org.json'

    # Second harmful-data source; the author marked it "need to select".
    # A disabled ``benign_data`` setting pointed at the same file:
    #   '/data/zhaohan/LLMs-Safety/ft_datasets/bt_dataset/train-330k.jsonl'
    harm_data_2: str = '/data/zhaohan/LLMs-Safety/ft_datasets/bt_dataset/train-330k.jsonl'

    # Presumably the total number of examples in the mix and the fraction of
    # them that is harmful (verify against the code that builds the mix).
    n_data: int = 100
    harm_ratio: float = 0.1
    
@dataclass
class reg_dataset:
    """Default settings for the ``reg_dataset`` configuration entry."""

    # Identifier of this dataset configuration.
    dataset: str = "reg_dataset"

    # Placeholder value; presumably must be overridden before use (verify).
    data_path: str = "<to set>"

    # Training file name; presumably resolved relative to ``data_path``
    # (verify against the dataset loader).
    train_split: str = "train.json"
    
@dataclass
class regmix_dataset:
    """Default settings for the ``regmix_dataset`` configuration entry.

    All attributes are annotated so that ``@dataclass`` treats them as real
    fields: they can be passed to the constructor and appear in ``repr``,
    ``dataclasses.fields`` and ``dataclasses.asdict``. The data-source and
    mixing parameters are declared last so the positional constructor order
    ``(dataset, data_path, baseline, train_split)`` is unchanged.
    """

    # Identifier of this dataset configuration.
    dataset: str = "regmix_dataset"

    # Placeholder value; presumably must be overridden before use (verify).
    data_path: str = "<to set>"

    # Off by default; its effect is defined by the consuming code.
    baseline: bool = False

    # Training file name; presumably resolved relative to ``data_path``
    # (verify against the dataset loader).
    train_split: str = "train.json"

    # Harmful-data sources (absolute, machine-specific paths).
    harm_data_1: str = '/data/zhaohan/LLMs-Safety/ft_datasets/pure_bad_dataset/train.jsonl'
    harm_data_2: str = '/data/zhaohan/LLMs-Safety/ft_datasets/bt_dataset/train-330k.jsonl'

    # Crowdsourced-data source; same file as ``harm_data_2`` by default.
    crowdsource_data: str = "/data/zhaohan/LLMs-Safety/ft_datasets/bt_dataset/train-330k.jsonl"

    # Presumably the number of crowdsourced examples to draw and the fraction
    # of harmful examples in the mix (verify against the code that builds it).
    n_crowdsource: int = 1000
    harm_ratio: float = 0.05
    