import argparse
import shutil
import os
import random
from datasets import DatasetDict, Dataset
from tqdm import tqdm
from utils.data import get_dataset, ds_map
from utils.attack import gen_pi
from copy import deepcopy

# Attack names sampled from (one per training sample) when building the
# adversarial training sets in main().
attacks_train = ['naive', 'ignore']
# Attack names for which main() builds one adversarial test split each.
# 'none' is saved under the split name "test"; the others as "test_<attack>".
attacks_test = ['none', 'naive', 'ignore', 'completion', 'escape_separation']
# Sampling weights passed to random.choices; position k is the weight of
# attacks_train[k], so the two lists must stay the same length.
p_attacks_train = [0.5, 0.5]

def _test_split_name(attack):
    """Return the saved split name for a test attack ('none' -> 'test')."""
    return 'test' if attack == 'none' else f'test_{attack}'


def _save_dataset_dict(ds_dict, out_dir):
    """Save ds_dict to out_dir, removing any existing directory there first."""
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    ds_dict.save_to_disk(out_dir)


def main(args):
    """Build and save the clean and adversarial datasets for ``args.ds``.

    Reads the train/test splits through ``get_dataset``, derives adversarial
    variants with ``gen_pi`` and writes two ``DatasetDict`` directories:

    * ``./datasets/<ds>/clean`` with splits ``train`` and ``test``;
    * ``./datasets/<ds>/adv`` with splits ``train_sft``, ``train_dpo`` and
      one test split per entry of ``attacks_test`` (``'none'`` is stored as
      ``test``, every other attack as ``test_<attack>``).

    Existing output directories are deleted before saving.

    Args:
        args: parsed command-line namespace; only ``args.ds`` is read here.

    Raises:
        ValueError: if ``attacks_train`` and ``p_attacks_train`` differ in
            length (raised by ``random.choices``).
    """
    ds_train = get_dataset(args.ds, split='train')
    ds_test = get_dataset(args.ds, split='test')

    print(f'\nCreating train and test sets from {args.ds}')

    # Materialise both splits as plain lists of samples.
    train_clean = [ds_train[i] for i in tqdm(range(len(ds_train)))]
    test_clean = [ds_test[i] for i in tqdm(range(len(ds_test)))]

    print('\nCreating adversarial train set')
    train_adv_sft = []
    train_adv_dpo = []
    for i, clean_entry in enumerate(tqdm(train_clean)):
        # One attack is drawn per sample, before the empty-input check.
        # Sampling the names directly makes a length mismatch between
        # attacks_train and p_attacks_train fail loudly.
        attack = random.choices(attacks_train, weights=p_attacks_train)[0]
        if clean_entry['input'] == '':
            # Samples without an input field are kept clean and contribute
            # nothing to the DPO set.
            entry_pi_sft = deepcopy(clean_entry)
        else:
            entry_pi_sft, entry_pi_dpo = gen_pi(
                i, train_clean, attack, attack_split='train')
            train_adv_dpo.append(entry_pi_dpo)
            # Half of the attackable samples stay clean in the SFT set.
            if random.random() < 0.5:
                entry_pi_sft = deepcopy(clean_entry)

        train_adv_sft.append(entry_pi_sft)

    # One adversarial test set per attack; samples with an empty input are
    # skipped, so these sets can be shorter than the clean test set.
    test_adv = {}
    for attack in attacks_test:
        print(f'\nCreating adversarial test set with {attack} attack')
        entries = []
        for i, clean_entry in enumerate(tqdm(test_clean)):
            if clean_entry['input'] != '':
                entry_pi_sft, _ = gen_pi(
                    i, test_clean, attack, attack_split='test')
                entries.append(entry_pi_sft)
        test_adv[attack] = entries

    # Save clean train and test sets
    out_dir = f"./datasets/{args.ds}/clean"
    print(f'\nSaving clean-{args.ds} dataset to {out_dir}')
    clean_dict = DatasetDict({
        "train": Dataset.from_list(train_clean),
        "test": Dataset.from_list(test_clean),
    })
    _save_dataset_dict(clean_dict, out_dir)

    # Save adversarial train and test sets
    out_dir = f"./datasets/{args.ds}/adv"
    print(f'\nSaving adversarial-{args.ds} dataset to {out_dir}')
    adv_splits = {
        "train_sft": Dataset.from_list(train_adv_sft),
        "train_dpo": Dataset.from_list(train_adv_dpo),
    }
    # Derive the test splits from attacks_test so the saved splits always
    # match the attacks that were actually generated.
    for attack, entries in test_adv.items():
        adv_splits[_test_split_name(attack)] = Dataset.from_list(entries)
    _save_dataset_dict(DatasetDict(adv_splits), out_dir)

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser(description="Create adversarial dataset")
    # Dataset to process; restricted to the names registered in ds_map.
    cli.add_argument("--ds", type=str, default="alpaca", help="dataset", choices=ds_map.keys())
    # NOTE(review): --train_split is parsed, but main() as shown reads only
    # args.ds -- confirm whether this option is still needed.
    cli.add_argument("--train_split", type=float, default=0.75, help="train ratio")

    main(cli.parse_args())