import os
import argparse

def set_config(config: dict) -> dict:
    """
    Set the config of the attacker.

    Derive the poison setting and the poisoned-data directories from the
    poisoner options and write them back into ``config``. The dict is modified
    in place and also returned.

    Keys read:
        config['attacker']['poisoner']: 'label_consistency', 'label_dirty',
            'name', 'poison_rate', 'target_label', 'load'
        config['poison_dataset']['name'], config['target_dataset']['name']

    Keys written:
        config['attacker']['poisoner']['poison_setting']:
            'clean' if 'label_consistency' is truthy, otherwise 'dirty' if
            'label_dirty' is truthy, otherwise 'mix'.
        config['attacker']['poisoner']['poison_data_basepath']:
            poison_data/<poison_dataset name>/<target_label>/<poison_setting>/<poisoner name>
        config['attacker']['poisoner']['poisoned_data_path']:
            <poison_data_basepath>/<poison_rate>
        config['target_dataset'] and config['poison_dataset']:
            'load' (copied from the poisoner) and 'clean_data_basepath'
            (same directory layout, using that dataset's own name).

    Raises:
        KeyError: if any of the keys listed above is missing.
    """
    poisoner_cfg = config['attacker']['poisoner']

    # Both flags are read up front so a missing key is reported regardless of
    # which branch is taken; label_consistency takes precedence.
    label_consistency = poisoner_cfg['label_consistency']
    label_dirty = poisoner_cfg['label_dirty']
    if label_consistency:
        poison_setting = 'clean'
    elif label_dirty:
        poison_setting = 'dirty'
    else:
        poison_setting = 'mix'
    poisoner_cfg['poison_setting'] = poison_setting

    poisoner = poisoner_cfg['name']
    poison_rate = poisoner_cfg['poison_rate']
    target_label = str(poisoner_cfg['target_label'])

    def data_dir(dataset_name):
        # Layout shared by every directory set below:
        # poison_data/<dataset>/<target_label>/<poison_setting>/<poisoner>
        return os.path.join(
            'poison_data', dataset_name, target_label, poison_setting, poisoner
        )

    # Base directory for this dataset/label/setting/poisoner combination ...
    poison_data_basepath = data_dir(config['poison_dataset']['name'])
    poisoner_cfg['poison_data_basepath'] = poison_data_basepath
    # ... and the sub-directory specific to the configured poison_rate.
    poisoner_cfg['poisoned_data_path'] = os.path.join(
        poison_data_basepath, str(poison_rate)
    )

    load = poisoner_cfg['load']
    # NOTE(review): 'clean_data_basepath' points under 'poison_data/', not at a
    # separate clean-data location; for poison_dataset it equals
    # poison_data_basepath. Kept as-is — confirm this is intended.
    for dataset_key in ('target_dataset', 'poison_dataset'):
        dataset_cfg = config[dataset_key]
        dataset_cfg['load'] = load
        dataset_cfg['clean_data_basepath'] = data_dir(dataset_cfg['name'])

    return config




def parse_args(argv=None):
    """
    Parse the command-line options of this script.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; pass an explicit
            list to parse options programmatically (e.g. from tests).

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='./configs/attn_config.json')
    parser.add_argument('--gpus', type=str, default='2')
    parser.add_argument('--model_folder', type=str, default='test')
    parser.add_argument('--poison_rate', type=float, default=0.1)
    parser.add_argument('--dataset_name', type=str, default='sst-2')
    parser.add_argument('--triggers', type=str, default='I watched this 3D movie last weekend.')
    # NOTE(review): argparse only applies ``type`` to string defaults, so when
    # --attn_distribute is omitted the value is the int 1, whereas an explicit
    # "--attn_distribute 1" yields the float 1.0. Left unchanged in case the
    # value's string form is used downstream — confirm before switching to 1.0.
    parser.add_argument('--attn_distribute', type=float, default=1, help='The average attention value from token to triggers. range 0-1')
    parser.add_argument('--attn_head_num', type=int, default=2, help='Attention head number in the attention subnet. range 1-12')
    parser.add_argument('--label_consistency', type=str, default='clean')
    parser.add_argument('--attacked_bs_name', type=str, default='badnets')

    args = parser.parse_args(argv)
    return args
