'''
Attention-guided backdoor attack: ablation study over the attention
distribution and the number of attention heads.

Output root folder:
'models_zoo_attn_ablation'

Trained-model folder naming (each run is stored in a sub-folder named by args.model_folder):
'[clean/dirty/mix]-[attn]-[Baseline_Name]-[Poison_Rate]-[Attention_Distribution]-[Head_Number]'

'''


# Attack 
import os
import json
import sys
import openbackdoor as ob 
from openbackdoor.data import load_dataset, get_dataloader, wrap_dataset, get_dataloader_attn_version
from openbackdoor.victims import load_victim
from openbackdoor.attackers import load_attacker
from openbackdoor.trainers import load_trainer
from openbackdoor.utils import init_logger, display_results, parse_args, save_json
import logging
logger = logging.getLogger(__name__)

import random
import pickle
import numpy as np


def main(config):
    """Run one backdoor experiment: poison-train a victim model, then evaluate it.

    Args:
        config: experiment configuration dict. Reads the 'attacker', 'victim',
            'target_dataset' and 'poison_dataset' sections; the log file is
            written to ``config['attacker']['train']['model_root']/log.txt``.

    Returns:
        ``(results, train_results)`` -- the output of ``attacker.eval`` on the
        target dataset and the training results returned by ``attacker.attack``.
    """
    # File logger inside the run folder (shadows the module-level logger).
    run_dir = config['attacker']['train']['model_root']
    logger = init_logger(log_file=os.path.join(run_dir, 'log.txt'))

    # Build the attacker (with its configured parameters) and the victim model.
    attacker = load_attacker(config["attacker"])
    victim = load_victim(config["victim"])

    # Both datasets are loaded clean here; splits are 'train', 'dev', 'test'.
    target_cfg = config["target_dataset"]
    target_dataset = load_dataset(**target_cfg)
    poison_cfg = config["poison_dataset"]
    poison_dataset = load_dataset(**poison_cfg)

    # Launch the attack: train the backdoored model on the poison dataset.
    logger.info("Train backdoored model on {}".format(poison_cfg["name"]))
    backdoored_model, train_results = attacker.attack(victim, poison_dataset, config)

    # Evaluate the backdoored model on the target dataset.
    logger.info("Evaluate backdoored model on {}".format(target_cfg["name"]))
    results = attacker.eval(backdoored_model, target_dataset)

    display_results(config, results)
    return results, train_results



# Attacked baselines whose sample metrics are disabled on the imdb dataset
# (per the original note, eval_ppl has a bug on imdb).
_IMDB_METRICS_DISABLED = ("stylebkd", "synbkd", "ep")


def _resolve_poison_setting(label_consistency, label_dirty):
    """Return the poison-setting tag ('clean', 'dirty' or 'mix') for the label flags.

    label_consistency takes precedence over label_dirty; when neither flag is
    truthy the setting is 'mix'.
    """
    if label_consistency:
        return 'clean'
    if label_dirty:
        return 'dirty'
    return 'mix'


if __name__ == '__main__':
    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpus)

    with open(args.config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Aliases to the nested sections; mutations go straight into `config`.
    attacker_cfg = config['attacker']
    train_cfg = attacker_cfg['train']
    poisoner_cfg = attacker_cfg['poisoner']

    # Different attack settings.
    train_cfg['visualize'] = False
    # Early stopping.
    train_cfg['early_stop_patient'] = 3

    poisoner_cfg['poison_rate'] = args.poison_rate
    config["poison_dataset"]["dev_rate"] = 0.1
    config["target_dataset"]["dev_rate"] = 0.1

    config["target_dataset"]["name"] = args.dataset_name
    config["poison_dataset"]["name"] = args.dataset_name
    # Default trigger (used as-is by e.g. addsent); triggers are always a list.
    poisoner_cfg['triggers'] = [args.triggers]

    # Randomly pick the target label. np.random is not seeded in this script,
    # so the label can differ between runs.
    labels_list = [0, 1]
    master_RSO = np.random.RandomState(np.random.randint(2 ** 31 - 1))
    rso = np.random.RandomState(master_RSO.randint(2 ** 31 - 1))
    target_class_level = int(rso.randint(len(labels_list)))
    poisoner_cfg['target_label'] = labels_list[target_class_level]

    # Clean-label or dirty-label attack. Any other value keeps whatever flags
    # the config file already has (which may yield the 'mix' setting below).
    if args.label_consistency == 'dirty':
        poisoner_cfg['label_consistency'] = False
        poisoner_cfg['label_dirty'] = True
    elif args.label_consistency == 'clean':
        poisoner_cfg['label_consistency'] = True
        poisoner_cfg['label_dirty'] = False

    # Per-dataset training schedule; other datasets keep the config values.
    if args.dataset_name == 'sst-2':
        train_cfg["epochs"] = 50
        train_cfg["batch_size"] = 64
    elif args.dataset_name == 'imdb':
        train_cfg["epochs"] = 50
        train_cfg["batch_size"] = 4

    # badnets / ep replace the default trigger with one token drawn at random
    # from a fixed list (still stored as a one-element list).
    if args.attacked_bs_name in ("badnets", "ep"):
        poisoner_cfg['triggers'] = [random.choice(["cf", "mn", "bb", "tq", "mb"])]

    train_cfg['attn_distribute'] = args.attn_distribute
    train_cfg['attn_head_num'] = args.attn_head_num
    train_cfg['save_path'] = './models_zoo_attn_ablation'

    poisoner_cfg['poison_setting'] = _resolve_poison_setting(
        poisoner_cfg['label_consistency'], poisoner_cfg['label_dirty'])

    # Baseline-specific requirement: turn off the sample metrics on imdb.
    if args.dataset_name == 'imdb' and args.attacked_bs_name in _IMDB_METRICS_DISABLED:
        attacker_cfg['sample_metrics'] = []

    # Snapshot these BEFORE the 'attn' rename below: folder and data paths
    # must use the original poisoner name.
    poisoner = poisoner_cfg['name']
    poison_setting = poisoner_cfg['poison_setting']
    poison_rate = poisoner_cfg['poison_rate']
    target_label = poisoner_cfg['target_label']

    # Model-save folder.
    if poisoner == 'attn':
        # Attn ablation study:
        # '[clean/dirty/mix]-[attn]-[Baseline_Name]-[Poison_Rate]-[Attention_Distribution]-[Head_Number]'
        run_name = (f'{poison_setting}-{poisoner}-{args.attacked_bs_name}-{poison_rate}'
                    f'-{args.attn_distribute}-{args.attn_head_num}')
        attacker_cfg['attacked_baseline'] = args.attacked_bs_name
        # Rename attacker / trainer / poisoner to the attn variant of the baseline.
        attn_name = 'attn_' + args.attacked_bs_name
        attacker_cfg['name'] = attn_name
        train_cfg['name'] = attn_name
        poisoner_cfg['name'] = attn_name
    else:
        run_name = f'{poison_setting}-{poisoner}-{poison_rate}'
    model_root = os.path.join(train_cfg['save_path'], run_name, str(args.model_folder))

    train_cfg['model_root'] = model_root
    os.makedirs(model_root, exist_ok=True)

    # ONLY for debugging: pre-generated clean/poisoned data lives under the
    # run folder so it can be loaded later (controlled by the 'load' flag).
    pre_generated_data_root = model_root
    # Path to the fully-poisoned dataset.
    poison_data_basepath = os.path.join(
        pre_generated_data_root, 'training_data', 'fully_poisoned',
        f'{config["poison_dataset"]["name"]}-{target_label}-{poisoner}')
    poisoner_cfg['poison_data_basepath'] = poison_data_basepath
    # Path to the partially-poisoned dataset.
    poisoner_cfg['poisoned_data_path'] = os.path.join(
        poison_data_basepath, 'partially', f'{poison_setting}-{poison_rate}')

    # Clean-data cache paths for the target and poison datasets.
    load = poisoner_cfg['load']
    for dataset_key in ('target_dataset', 'poison_dataset'):
        dataset_cfg = config[dataset_key]
        dataset_cfg['load'] = load
        dataset_cfg['clean_data_basepath'] = os.path.join(
            pre_generated_data_root, 'training_data', 'clean',
            f'{dataset_cfg["name"]}-{target_label}-{poison_setting}-{poisoner}')

    # Save the final config next to the model.
    save_json(config, os.path.join(model_root, 'config.json'))

    results, train_results = main(config)

    # Save results.
    results_path = os.path.join(model_root, 'results')
    os.makedirs(results_path, exist_ok=True)
    with open(os.path.join(results_path, 'results.pkl'), 'wb') as f:
        pickle.dump([results, train_results], f)
