# Defend: evaluate backdoored victim models under an openbackdoor defender (ONION or RAP).
import os
import sys, os
sys.path.insert(1, os.path.abspath("./"))
sys.path.insert(1, os.path.abspath("./scripts/"))

import json
import argparse
import openbackdoor as ob 
from openbackdoor.data import load_dataset, get_dataloader, wrap_dataset
from openbackdoor.victims import load_victim
from openbackdoor.attackers import load_attacker
from openbackdoor.defenders import load_defender
from openbackdoor.utils import display_results

import logging
logger = logging.getLogger(__name__)


from utils import load_well_trained_model
import torch
import pickle




def parse_args():
    """Parse the command-line options of the defense evaluation script.

    Returns an ``argparse.Namespace`` with:
      * ``gpus``      -- GPU id string, default ``'6'``;
      * ``defenders`` -- defender name, default ``'onion'``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--gpus',
        type=str,
        default='6',
        help='Which GPU to use, only single number.',
    )
    cli.add_argument(
        '--defenders',
        type=str,
        default='onion',
        help='Which defenders to use.',
    )
    return cli.parse_args()

def _configure_poison_paths(config):
    """Fill in the derived poisoner settings and data paths of ``config`` in place.

    Sets ``attacker.poisoner.poison_setting`` to:
      * ``'clean'`` if ``label_consistency`` is truthy;
      * otherwise ``'dirty'`` if ``label_dirty`` is truthy;
      * otherwise ``'mix'``.

    It then sets ``poison_data_basepath``, ``poisoned_data_path`` and the
    ``load`` / ``clean_data_basepath`` entries of both ``target_dataset`` and
    ``poison_dataset``.
    """
    poisoner_cfg = config['attacker']['poisoner']

    if poisoner_cfg['label_consistency']:
        poison_setting = 'clean'
    elif poisoner_cfg['label_dirty']:
        poison_setting = 'dirty'
    else:
        poison_setting = 'mix'
    poisoner_cfg['poison_setting'] = poison_setting

    poisoner = poisoner_cfg['name']
    target_label = str(poisoner_cfg['target_label'])

    def _data_basepath(dataset_name):
        # poison_data/<dataset>/<target_label>/<poison_setting>/<poisoner>
        return os.path.join('poison_data', dataset_name, target_label,
                            poison_setting, poisoner)

    # path to a partly-poisoned dataset
    poison_data_basepath = _data_basepath(config['poison_dataset']['name'])
    poisoner_cfg['poison_data_basepath'] = poison_data_basepath
    # path to a fully-poisoned dataset
    poisoner_cfg['poisoned_data_path'] = os.path.join(
        poison_data_basepath, str(poisoner_cfg['poison_rate']))

    load = poisoner_cfg['load']
    for dataset_key in ('target_dataset', 'poison_dataset'):
        config[dataset_key]['load'] = load
        config[dataset_key]['clean_data_basepath'] = _data_basepath(
            config[dataset_key]['name'])


def _build_defender(name):
    """Instantiate the openbackdoor defender selected by ``--defenders``.

    :param name: ``'onion'`` or ``'rap'``.
    :return: the configured defender instance.
    :raises ValueError: for any other name. Without this check, an unknown
        name would otherwise surface later as an UnboundLocalError.
    """
    if name == 'onion':
        defender = ob.defenders.ONIONDefender()
    elif name == 'rap':
        defender = ob.defenders.RAPDefender()
    else:
        raise ValueError(
            "Unsupported defender {!r}; expected 'onion' or 'rap'.".format(name))

    # Flags read by openbackdoor during attacker.eval. Presumably pre=False
    # disables training-time filtering and correction=True enables input
    # correction at evaluation time -- confirm against openbackdoor.
    defender.pre = False
    defender.correction = True
    return defender


def main(model_root, model_fname, config_fname, device, args):
    """Evaluate one backdoored model under the selected defender and save the results.

    :param model_root: model directory. Results are pickled (as a one-element
        list) to ``<model_root>/results/results_defense_<defender>.pkl``.
    :param model_fname: checkpoint path passed to ``load_well_trained_model``.
    :param config_fname: JSON config with ``attacker``, ``target_dataset`` and
        ``poison_dataset`` sections.
    :param device: torch device passed to ``load_well_trained_model``.
    :param args: parsed CLI args; ``args.defenders`` selects ``'onion'`` or ``'rap'``.
    :return: the results returned by ``attacker.eval``.
    :raises ValueError: if ``args.defenders`` is not a supported defender.
    """
    ## setup config
    with open(config_fname, 'r', encoding='utf-8') as f:
        config = json.load(f)
    _configure_poison_paths(config)

    # Build the defender first so an invalid --defenders value fails fast,
    # before the victim model is loaded.
    defender = _build_defender(args.defenders)

    # the backdoored victim classification model
    victim = load_well_trained_model(model_fname, device)

    # attacker built from the config; its eval() runs the defended evaluation
    attacker = load_attacker(config["attacker"])

    # choose target dataset
    target_dataset = load_dataset(**config["target_dataset"])

    logger.info("Evaluate backdoored model on {}".format(config["target_dataset"]["name"]))
    results = attacker.eval(victim, target_dataset, defender)
    # example: {'test-clean': {'accuracy': 0.8643602416254805}, 'test-poison': {'accuracy': 0.5460526315789473}, 'ppl': nan, 'grammar': nan, 'use': nan}

    display_results(config, results)

    # save results
    results_path = os.path.join(model_root, 'results')
    os.makedirs(results_path, exist_ok=True)
    out_fname = os.path.join(results_path, 'results_defense_{}.pkl'.format(args.defenders))
    with open(out_fname, 'wb') as f:
        pickle.dump([results], f)

    return results


def one_shot_defense(clean_root, bs_root, args, max_models=3):
    '''Run the defense evaluation (``main``) on each model folder under ``clean_root``.

    :param clean_root:
        level 1 root directory that contains the model folders.
    :param bs_root:
        level 2, list of model folder names inside ``clean_root``. Each folder
        is expected to hold ``best.ckpt`` and ``config.json``.
    :param args:
        parsed CLI arguments, forwarded unchanged to ``main``.
    :param max_models:
        evaluate at most this many folders, taken from the front of ``bs_root``.
        The default of 3 preserves the previous hard-coded limit; ``None``
        evaluates every folder.
    :return:
        list of the per-model results returned by ``main``, in evaluation order.
    '''
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_results = []
    for model_fp in bs_root[:max_models]:
        print('model_fp', model_fp)

        # ## load well-trained single model
        model_root = os.path.join(clean_root, model_fp)
        model_fname = os.path.join(model_root, 'best.ckpt')
        config_fname = os.path.join(model_root, 'config.json')
        logger.info("BEGIN MODEL {}".format(model_root))

        all_results.append(main(model_root, model_fname, config_fname, device, args))

    return all_results





if __name__ == '__main__':
    args = parse_args()

    ## setup the logger: the root logger passes every record through, and the
    ## file handler below keeps INFO and above.
    logger = logging.getLogger()
    logger.setLevel(logging.NOTSET)

    # file handler writing to ./results_data/defense/<defender>_tal.txt
    log_dir = './results_data/defense'
    # logging.FileHandler does not create missing parent directories.
    os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(log_dir, '{}_tal.txt'.format(args.defenders)))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(file_handler)

    # Set before one_shot_defense queries CUDA so the GPU selection takes
    # effect (assumes no earlier import has already initialized CUDA).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpus)
    dataset_name = 'sst2'

    # ATTN-BACKBONE
    backbone_root1 = './models_zoo_attn_ablation/'
    # Track folders are named <setting>-attn-<attack>-<rate_tag>-1.0-2.
    # rate_tag is presumably the poison rate -- confirm. Use '0.2' for the
    # earlier sweep.
    rate_tag = '0.01'
    attacks = ['badnets', 'addsent', 'ep', 'stylebkd', 'synbkd']
    # Order: all 'dirty' tracks first, then all 'clean' tracks, attacks in list order.
    backbone_list = ['{}-attn-{}-{}-1.0-2'.format(setting, attack, rate_tag)
                     for setting in ('dirty', 'clean')
                     for attack in attacks]

    for single_track in backbone_list:
        logger.info('START {}.'.format(single_track))
        backbone_root = os.path.join(backbone_root1, single_track)
        # model folders for the chosen dataset, in a deterministic order
        bs_root = sorted(fn for fn in os.listdir(backbone_root) if dataset_name in fn)
        print('bs_root', bs_root)

        _ = one_shot_defense(backbone_root, bs_root, args)
        logger.info('FINISHED {}.'.format(single_track))





