import numpy as np
import torch
import os
from torchvision import transforms,datasets
import argparse
import random
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch import nn
from PIL import Image
from utils import supervisor, tools, default_args, imagenet
import config

# Command-line interface. Options use a single leading dash (e.g. `-dataset cifar10`);
# defaults and allowed values come from utils.default_args.
# ----- attack / model selection options -----
parser = argparse.ArgumentParser()
parser.add_argument('-dataset', type=str, required=False,
                    default=default_args.parser_default['dataset'],
                    choices=default_args.parser_choices['dataset'])
parser.add_argument('-poison_type', type=str,  required=False,
                    choices=default_args.parser_choices['poison_type'],
                    default=default_args.parser_default['poison_type'])
parser.add_argument('-poison_rate', type=float,  required=False,
                    choices=default_args.parser_choices['poison_rate'],
                    default=default_args.parser_default['poison_rate'])
parser.add_argument('-cover_rate', type=float,  required=False,
                    choices=default_args.parser_choices['cover_rate'],
                    default=default_args.parser_default['cover_rate'])
parser.add_argument('-alpha', type=float,  required=False,
                    default=default_args.parser_default['alpha'])
# Blend strength for the test-time poison transform; when unset, -alpha is used.
parser.add_argument('-test_alpha', type=float, required=False, default=None)
# Trigger file name under config.triggers_dir. When unset it is filled from
# config.trigger_default[dataset][poison_type]; the value 'none' disables the
# trigger-mask based precision/recall computation.
parser.add_argument('-trigger', type=str, required=False, default=None)
# NOTE(review): not read directly in this script; presumably consumed by
# supervisor helpers that receive `args` — confirm.
parser.add_argument('-model_path', required=False, default=None)
parser.add_argument('-cleanser', type=str, required=False, default=None,
                    choices=default_args.parser_choices['cleanser'])
parser.add_argument('-defense', type=str, required=False, default=None,
                    choices=default_args.parser_choices['defense'])
parser.add_argument('-no_normalize', default=False, action='store_true')
parser.add_argument('-no_aug', default=False, action='store_true')
# Comma-separated GPU ids written to CUDA_VISIBLE_DEVICES.
parser.add_argument('-devices', type=str, default='0')
parser.add_argument('-seed', type=int, required=False, default=default_args.seed)

# ----- attribution-benchmark options ---------------------------
# Non-zero: replace 0.5 * recover_rate of the input elements (when a trigger is
# used) and tag the result file with this rate; 0.0: replace as many elements
# as the trigger itself occupies.
parser.add_argument('-recover_rate', type=float,  required=False,
                    default=0.0)
parser.add_argument('-attr_method', type=str,
                    choices=default_args.attr_parser_choices['attr_method'],
                    default=default_args.attr_parser_default['attr_method'])
parser.add_argument('-metric', type=str,
                    choices=default_args.attr_parser_choices['metric'],
                    default=default_args.attr_parser_default['metric'])
# Forwarded to the explainers (SmoothGrad uses bg_size * k samples).
parser.add_argument('-k', type=int,
                    default=default_args.attr_parser_default['k'])
parser.add_argument('-bg_size', type=int,
                    default=default_args.attr_parser_default['bg_size'])
parser.add_argument('-est_method', type=str,
                    choices=default_args.attr_parser_choices['est_method'],
                    default=default_args.attr_parser_default['est_method'])
parser.add_argument('-exp_obj', type=str,
                    choices=default_args.attr_parser_choices['exp_obj'],
                    default=default_args.attr_parser_default['exp_obj'])
# 'absolute' makes normalize() take |attribution| before summing channels.
parser.add_argument('-post_process', type=str,
                    choices=default_args.attr_parser_choices['post_process'],
                    default=default_args.attr_parser_default['post_process'])

args = parser.parse_args()

# Restrict CUDA to the requested GPUs; this has to run before any CUDA
# context is created for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.devices)

# Use the configured default trigger for this dataset / attack when none was given.
args.trigger = (config.trigger_default[args.dataset][args.poison_type]
                if args.trigger is None else args.trigger)

# DataLoader options: ImageNet gets a larger worker pool.
kwargs = {'num_workers': 24 if args.dataset == 'imagenet' else 4, 'pin_memory': True}

# Seeding is currently disabled: tools.setup_seed(args.seed)

(data_transform_aug, data_transform, trigger_transform,
 normalizer, denormalizer) = supervisor.get_transforms(args)


# Per-dataset settings. Only num_classes and batch_size are used by this
# evaluation script; the schedule values mirror the training setup.
if args.dataset == 'cifar10':
    num_classes, epochs, batch_size = 10, 200, 128
    milestones = torch.tensor([100, 150])
elif args.dataset == 'cifar100':
    num_classes = 100
    raise NotImplementedError('<To Be Implemented> Dataset = %s' % args.dataset)
elif args.dataset == 'gtsrb':
    num_classes, epochs, batch_size = 43, 100, 128
    milestones = torch.tensor([40, 80])
elif args.dataset == 'imagenette':
    num_classes, epochs, batch_size = 10, 100, 64
    milestones = torch.tensor([40, 80])
elif args.dataset == 'imagenet':
    num_classes, epochs, batch_size = 1000, 90, 16
    milestones = torch.tensor([30, 60])
else:
    print('<Undefined Dataset> Dataset = %s' % args.dataset)
    raise NotImplementedError('<To Be Implemented> Dataset = %s' % args.dataset)

# Optimizer settings shared by every supported dataset.
momentum, weight_decay, learning_rate = 0.9, 1e-4, 0.1

poison_set_dir = supervisor.get_poison_set_dir(args)
model_path = supervisor.get_model_dir(args, cleanse=(args.cleanser is not None), defense=(args.defense is not None))

arch = supervisor.get_arch(args)

import torchvision
# model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
model = arch(num_classes=num_classes)
# Deserialize the checkpoint on the CPU: load_state_dict copies the weights into
# the (still CPU-resident) model anyway, and this avoids a failure when the
# checkpoint was saved from a GPU index hidden by CUDA_VISIBLE_DEVICES.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Move to GPU, wrap for multi-GPU inference and switch to evaluation mode.
model = model.cuda()
model = nn.DataParallel(model)
model.eval()
print("Evaluating model '{}'...".format(model_path))


# ----------------------------- set up test set and poison_transform -----------------------------------
# ImageNet evaluates on a stored subset of the official 'val' folder; every other
# dataset reads the pre-extracted split under clean_set/<dataset>/test_split.
if args.dataset == 'imagenet':
    test_set_dir = os.path.join(config.imagenet_dir, 'val')
    test_set = imagenet.imagenet_dataset(directory=test_set_dir, shift=False, data_transform=data_transform,
                                         label_file=imagenet.test_set_labels, num_classes=1000)
    test_split_meta_dir = os.path.join('clean_set', args.dataset, 'test_split')
    test_indices = torch.load(os.path.join(test_split_meta_dir, 'test_indices'))
    test_set = torch.utils.data.Subset(test_set, test_indices)
else:
    test_set_dir = os.path.join('clean_set', args.dataset, 'test_split')
    test_set_img_dir = os.path.join(test_set_dir, 'data')
    test_set_label_path = os.path.join(test_set_dir, 'labels')
    test_set = tools.IMG_Dataset(data_dir=test_set_img_dir, label_path=test_set_label_path,
                                 transforms=data_transform)

# Deterministic order so evaluation is reproducible across runs.
test_set_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False, worker_init_fn=tools.worker_init, **kwargs)

# Test-time poison transform; -test_alpha overrides -alpha when provided.
poison_transform = supervisor.get_poison_transform(
    poison_type=args.poison_type, dataset_name=args.dataset,
    target_class=config.target_class[args.dataset],
    trigger_transform=data_transform, is_normalized_input=True,
    alpha=args.test_alpha if args.test_alpha is not None else args.alpha,
    trigger_name=args.trigger, args=args)


# Source-specific attacks restrict ASR counting to the configured source class;
# all other attacks pass None (every non-target sample counts).
source_classes = [config.source_class] if args.poison_type in ('TaCT', 'SleeperAgent') else None


def normalize(saliency_map, absolute=True):
    """Collapse a saliency map to one channel and min-max scale each sample.

    Args:
        saliency_map: attribution tensor whose dim 1 is the channel axis
            (normally ``(batch, channels, H, W)``).
        absolute: take ``|saliency_map|`` before summing the channels.

    Returns:
        Tensor of shape ``(batch, 3, H, W)``: each sample shifted so its minimum is
        0, divided by its maximum (plus 1e-10 to avoid division by zero), and the
        single channel tiled three times.
    """
    if absolute:
        saliency_map = saliency_map.abs()
    summed = saliency_map.sum(dim=1, keepdim=True)

    batch = summed.size(0)
    # Per-sample extrema, reshaped so they broadcast over (1, H, W).
    lowest = summed.view(batch, -1).min(dim=1).values.view(batch, 1, 1, 1)
    shifted = summed - lowest
    highest = shifted.view(batch, -1).max(dim=1).values.view(batch, 1, 1, 1)
    scaled = shifted / (highest + 1e-10)

    return scaled.repeat(1, 3, 1, 1)


def cal_mean(num, den, cut=True):
    """Return the mean of the element-wise ratio ``num / den`` clipped to [0, 1].

    A 1e-8 offset is added to ``den`` to avoid division by zero. ``cut`` is
    accepted for compatibility but ignored: the top-value trimming it once
    controlled is disabled.

    Returns:
        float: the mean clipped ratio.
    """
    ratio = torch.clamp(num / (den + 1e-8), min=0., max=1.)
    return float(ratio.mean())


def _attr_save_path():
    """Build the result-file path for this run and create its directory.

    The file lives in ``attr_datas/<poison_type>/<attr_method>/``. Its name encodes
    the dataset, the poison rate, either the recover rate (when non-zero) or alpha,
    and the attribution settings. Reads the module-level ``args``.

    Returns:
        str: path of the ``.npy`` file that ``benchmarking`` writes.
    """
    save_dir = os.path.join('attr_datas', args.poison_type, args.attr_method)
    # makedirs also creates missing parents such as 'attr_datas' itself, which a
    # chain of os.mkdir calls cannot do.
    os.makedirs(save_dir, exist_ok=True)

    if args.recover_rate != 0.0:
        file_name = 'data_%s_%s_%s_%s_%s_%s_%s.npy' % (
            args.dataset, str(args.poison_rate), str(args.recover_rate),
            args.attr_method, args.post_process, args.est_method, args.exp_obj)
    else:
        file_name = 'data_%s_%s_alpha%s_%s_%s_%s_%s.npy' % (
            args.dataset, str(args.poison_rate), str(args.alpha),
            args.attr_method, args.post_process, args.est_method, args.exp_obj)
    return os.path.join(save_dir, file_name)


def _load_trigger_mask():
    """Load ``args.trigger`` as a flat binary mask of the trigger's footprint.

    A pixel counts as part of the trigger when its RGB channel sum is >= 1e-6.
    The pixel mask is repeated over the three channels and flattened.
    For ImageNet the trigger is center-cropped to 224 to match the inputs.

    Returns:
        tuple: ``(trigger_mask, trigger_scale)``. ``trigger_mask`` is a flat CUDA
        float tensor of 0./1. values and ``trigger_scale`` is its number of ones.
        Returns ``(None, 0.0)`` when ``args.trigger == 'none'``.
    """
    if args.trigger == 'none':
        return None, 0.0

    trigger_path = os.path.join(config.triggers_dir, args.trigger)
    # The context manager closes the file handle once the RGB copy exists.
    with Image.open(trigger_path) as raw_trigger:
        trigger_rgb = raw_trigger.convert("RGB")
    if args.dataset == 'imagenet':
        to_tensor = transforms.Compose([transforms.ToTensor(), transforms.CenterCrop(224)])
    else:
        to_tensor = transforms.Compose([transforms.ToTensor()])
    trigger = to_tensor(trigger_rgb)

    channel_sum = torch.sum(trigger, dim=0)
    channel_sum = torch.where(channel_sum < 1e-6, 0., channel_sum)
    zeros_ind = torch.where(channel_sum == 0.)
    mask = torch.ones_like(trigger)
    mask[:, zeros_ind[0], zeros_ind[1]] = 0.

    trigger_mask = torch.where(mask == 0, 0., 1.).view(-1).cuda()
    return trigger_mask, float(torch.sum(trigger_mask))


def benchmarking(model, test_loader, poison_test=False, poison_transform=None, num_classes=10, source_classes=None,
                 all_to_all=False, explainer=None):
    """Measure how well attribution maps localize a backdoor trigger.

    Clean accuracy and loss are computed for every test batch. When
    ``poison_test`` is set, each batch is also processed as follows:

    1. The batch is poisoned.
    2. Input elements are ranked by the normalized attribution map returned by
       ``explainer.shap_values``.
    3. The top ``trigger_ratio`` fraction of elements is replaced with the
       clean-image values.
    4. The model is re-run on this "recovered" input. The reported ASR comes from
       those recovered predictions.

    Results are written with ``np.save`` as a dict with keys ``asr``,
    ``t_logit_change``, ``t_prob_change``, ``clean_logit_change``,
    ``clean_prob_change`` and ``PR`` (trigger precision/recall mean and std).
    If the result file already exists, the function returns immediately.

    Reads the module-level ``args``, ``config`` and ``tqdm``.

    Args:
        model: classifier; batches are moved to CUDA before being fed to it.
        test_loader: iterable of ``(data, target)`` batches.
        poison_test: run the poisoning / recovery evaluation. If False, ``asr``
            is saved as None.
        poison_transform: object whose ``transform(data, target)`` returns the
            poisoned ``(data, target)``.
        num_classes: length of the per-class correct-count histogram.
        source_classes: when given (and not ``all_to_all``), only samples whose
            clean label is in this list count towards ASR.
        all_to_all: count a hit when the prediction equals each sample's poisoned
            label instead of a single target class.
        explainer: object exposing ``shap_values(x, sparse_labels=...)``.
    """
    save_path = _attr_save_path()
    if os.path.exists(save_path):
        return

    criterion = nn.CrossEntropyLoss()
    clean_correct = 0
    tot = 0
    tot_loss = 0.0
    poison_correct = 0
    num_non_target_class = 0
    poison_acc = 0  # recovered predictions equal to the clean label (not saved)
    class_dist = np.zeros((num_classes))  # correctly classified clean samples per class (not saved)
    asr = None  # set only for poison_test runs; previously a NameError otherwise

    # -------------------------- flipping setup -------------------------------
    trigger_mask, trigger_scale = _load_trigger_mask()
    precision = []
    if trigger_mask is None:
        trigger_ratio = args.recover_rate
    elif args.recover_rate != 0.0:
        trigger_ratio = args.recover_rate * 0.5
    else:
        # Replace exactly as many elements as the trigger occupies.
        trigger_ratio = trigger_scale / trigger_mask.numel()

    class_output_change, class_prob_change = [], []
    clean_cls_output_change, clean_cls_prob_change = [], []
    class_output_ref, class_prob_ref = [], []
    clean_cls_output_ref, clean_cls_prob_ref = [], []

    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.cuda(), target.cuda()
            this_batch_size = len(target)

            clean_output = model(data)
            clean_prob = torch.softmax(clean_output, dim=1)
            clean_hits = clean_output.argmax(dim=1).eq(target)
            clean_correct += clean_hits.sum().item()
            tot += this_batch_size
            # .item() keeps the running loss a Python float rather than a CUDA tensor.
            tot_loss += criterion(clean_output, target).item() * this_batch_size
            np.add.at(class_dist, target[clean_hits].cpu().numpy(), 1)

            if not poison_test:
                continue

            clean_target = target
            poison_data, target = poison_transform.transform(data, target)
            poison_output = model(poison_data)
            poison_prob = torch.softmax(poison_output, dim=1)

            # -------------------------- flipping -------------------------------
            saliency_map = explainer.shap_values(poison_data.detach(), sparse_labels=target)
            saliency_map = normalize(saliency_map, absolute=(args.post_process == 'absolute'))

            flat_poi_image = poison_data.view(this_batch_size, -1)
            flat_s_map = saliency_map.view(this_batch_size, -1)
            clean_img = data.view(this_batch_size, -1)
            num_flip = int(poison_data[0].numel() * trigger_ratio)
            # mask == 1 keeps the poisoned value. The num_flip most-attributed elements
            # of each sample keep mask == 0 and are taken from the clean image instead.
            sorted_ind = torch.argsort(flat_s_map, dim=1, descending=True)
            mask = torch.zeros_like(flat_poi_image)
            mask.scatter_(1, sorted_ind[:, num_flip:], 1.)
            if trigger_mask is not None:
                # Per sample: fraction of the trigger's elements that were replaced.
                replaced_trigger = torch.sum((1 - mask) * trigger_mask, dim=1)
                precision.extend(count / trigger_scale for count in replaced_trigger.tolist())

            flip_img = (flat_poi_image * mask + clean_img * (1 - mask)).view(poison_data.size())
            recover_output = model(flip_img)
            recover_prob = torch.softmax(recover_output, dim=1)
            # ASR below is measured on the recovered images.
            poison_pred = recover_output.argmax(dim=1, keepdim=True)

            # ----------- logit / probability changes for target and clean class -----------
            t_idx = target.view(-1, 1)
            c_idx = clean_target.view(-1, 1)

            class_poison_output = poison_output.gather(1, t_idx)
            class_recover_output = recover_output.gather(1, t_idx)
            class_clean_output = clean_output.gather(1, t_idx)
            class_poison_prob = poison_prob.gather(1, t_idx)
            class_recover_prob = recover_prob.gather(1, t_idx)
            class_clean_prob = clean_prob.gather(1, t_idx)

            clean_cls_poison_output = poison_output.gather(1, c_idx)
            clean_cls_recover_output = recover_output.gather(1, c_idx)
            clean_cls_clean_output = clean_output.gather(1, c_idx)
            clean_cls_poison_prob = poison_prob.gather(1, c_idx)
            clean_cls_recover_prob = recover_prob.gather(1, c_idx)
            clean_cls_clean_prob = clean_prob.gather(1, c_idx)

            class_output_change.append(cal_mean(class_recover_output, class_poison_output))
            class_prob_change.append(cal_mean(class_recover_prob, class_poison_prob))
            clean_cls_output_change.append(cal_mean(clean_cls_recover_output, clean_cls_clean_output))
            clean_cls_prob_change.append(cal_mean(clean_cls_recover_prob, clean_cls_clean_prob))

            class_output_ref.append(cal_mean(class_clean_output, class_poison_output))
            class_prob_ref.append(cal_mean(class_clean_prob, class_poison_prob))
            clean_cls_output_ref.append(cal_mean(clean_cls_poison_output, clean_cls_clean_output))
            clean_cls_prob_ref.append(cal_mean(clean_cls_poison_prob, clean_cls_clean_prob))

            # ----------- attack success counting (host-side lists avoid per-element GPU syncs) -----------
            target_list = target.tolist()
            clean_list = clean_target.tolist()
            pred_list = poison_pred.view(-1).tolist()
            if not all_to_all:
                target_class = target_list[0]
                for src, pred in zip(clean_list, pred_list):
                    if src == target_class:
                        continue
                    if source_classes is not None and src not in source_classes:
                        continue
                    num_non_target_class += 1
                    if pred == target_class:
                        poison_correct += 1
            else:
                num_non_target_class += this_batch_size
                poison_correct += sum(pred == tgt for pred, tgt in zip(pred_list, target_list))

            poison_acc += poison_pred.eq(clean_target.view_as(poison_pred)).sum().item()

    print('Clean ACC: {}/{} = {:.6f}, Loss: {}'.format(
        clean_correct, tot, clean_correct / max(tot, 1), tot_loss / max(tot, 1)))
    if poison_test:
        # nan instead of ZeroDivisionError when no sample qualified for ASR.
        asr = poison_correct / num_non_target_class if num_non_target_class else float('nan')
        print('ASR: %d/%d = %.6f' % (poison_correct, num_non_target_class, asr))

    if trigger_mask is not None:
        print('Precision/Recall = %.6f' % np.mean(precision))

    # Each "change" is the recovered ratio minus the clean/poison reference ratio.
    class_output_change = np.array(class_output_change) - np.array(class_output_ref)
    class_prob_change = np.array(class_prob_change) - np.array(class_prob_ref)
    clean_cls_output_change = np.array(clean_cls_output_change) - np.array(clean_cls_output_ref)
    clean_cls_prob_change = np.array(clean_cls_prob_change) - np.array(clean_cls_prob_ref)

    t_logit_mean, t_logit_std = np.mean(class_output_change), np.std(class_output_change)
    t_prob_mean, t_prob_std = np.mean(class_prob_change), np.std(class_prob_change)
    clean_logit_mean, clean_logit_std = np.mean(clean_cls_output_change), np.std(clean_cls_output_change)
    clean_prob_mean, clean_prob_std = np.mean(clean_cls_prob_change), np.std(clean_cls_prob_change)

    print('target logit change <%.4f> std <%.4f>' % (t_logit_mean, t_logit_std))
    print('target prob change <%.4f> std <%.4f>' % (t_prob_mean, t_prob_std))
    print('clean class logit change <%.4f> std <%.4f>' % (clean_logit_mean, clean_logit_std))
    print('clean class prob change <%.4f> std <%.4f>' % (clean_prob_mean, clean_prob_std))

    results = {
        'asr': asr,
        't_logit_change': [t_logit_mean, t_logit_std],
        't_prob_change': [t_prob_mean, t_prob_std],
        'clean_logit_change': [clean_logit_mean, clean_logit_std],
        'clean_prob_change': [clean_prob_mean, clean_prob_std],
    }
    if trigger_mask is not None:
        results['PR'] = [np.mean(precision), np.std(precision)]
    else:
        results['PR'] = [0, 0]

    np.save(save_path, results)


# ----------------------------------------------------------------------------------------------
from tqdm import tqdm
from saliency_methods import RandomBaseline, Gradients, SmoothGrad, IntegratedGradients, IntGradUniform, IntGradSG,\
    IntGradSQ, GradCAM, FullGrad, ExpectedGradients, AGI, InputxGrad, GuidedGradCAM, LPI
from utils.preprocess import img_size_dict


def load_explainer(model, **kwargs):
    """Instantiate the attribution method selected by ``kwargs['method_name']``.

    Each method reads only the keyword arguments it needs (``quan_eval`` builds
    the per-method sets). A missing key raises ``KeyError``.

    Args:
        model: network to explain; passed to every method except 'Random'.
        **kwargs: must contain ``method_name``; the other keys depend on the method.

    Returns:
        The explainer instance.

    Raises:
        NotImplementedError: if ``method_name`` is not a supported method.
    """
    method_name = kwargs['method_name']
    if method_name == 'Random':
        # Returned directly so no local named ``random`` shadows the stdlib module.
        return RandomBaseline()

    # -------------------- gradient based -------------------------
    if method_name == 'InputGrad':
        return Gradients(model, exp_obj=kwargs['exp_obj'])
    if method_name == 'GradCAM':
        return GradCAM(model, exp_obj=kwargs['exp_obj'], post_process=kwargs['post_process'])
    if method_name == 'FullGrad':
        im_size = img_size_dict[kwargs['dataset_name']]
        return FullGrad(model, exp_obj=kwargs['exp_obj'], im_size=im_size, post_process=kwargs['post_process'])
    if method_name == 'SmoothGrad':
        return SmoothGrad(model, bg_size=kwargs['bg_size'], exp_obj=kwargs['exp_obj'], std_spread=0.15,
                          dataset_name=kwargs['dataset_name'])
    if method_name == 'GuidedGradCAM':
        return GuidedGradCAM(model, exp_obj=kwargs['exp_obj'])

    # -------------------- integration based -------------------------
    if method_name == 'IntGrad':
        return IntegratedGradients(model, k=kwargs['k'], exp_obj=kwargs['exp_obj'],
                                   dataset_name=kwargs['dataset_name'])
    if method_name == 'ExpGrad':
        return ExpectedGradients(model, k=kwargs['k'], bg_size=kwargs['bg_size'], bg_dataset=kwargs['bg_dataset'],
                                 batch_size=kwargs['bg_batch_size'], random_alpha=kwargs['random_alpha'],
                                 est_method=kwargs['est_method'], exp_obj=kwargs['exp_obj'])

    # -------------------- IG based -------------------------
    if method_name == 'IG_Uniform':
        return IntGradUniform(model, k=kwargs['k'], bg_size=kwargs['bg_size'], random_alpha=kwargs['random_alpha'],
                              est_method=kwargs['est_method'], exp_obj=kwargs['exp_obj'],
                              dataset_name=kwargs['dataset_name'])
    if method_name == 'IG_SG':
        return IntGradSG(model, k=kwargs['k'], bg_size=kwargs['bg_size'], random_alpha=kwargs['random_alpha'],
                         est_method=kwargs['est_method'], exp_obj=kwargs['exp_obj'])
    if method_name == 'IG_SQ':
        return IntGradSQ(model, k=kwargs['k'], bg_size=kwargs['bg_size'], random_alpha=kwargs['random_alpha'],
                         est_method=kwargs['est_method'], exp_obj=kwargs['exp_obj'])

    if method_name == 'AGI':
        return AGI(model, k=kwargs['k'], top_k=kwargs['top_k'], est_method=kwargs['est_method'],
                   exp_obj=kwargs['exp_obj'], dataset_name=kwargs['dataset_name'])
    if method_name == 'InputxGrad':
        return InputxGrad(model, exp_obj=kwargs['exp_obj'])

    if method_name == 'LPI':
        bg_size = kwargs['bg_size']
        # The background image folder (kmeans_c0) and density.npy live under
        # dataset_distribution/<parent>/<leaf>/c1r<bg_size>, where <parent>/<leaf>
        # are the last two components of the module-level poison_set_dir.
        # normpath drops a trailing separator, which split('/') used to turn into
        # an empty last component.
        parent_dir, set_name = os.path.split(os.path.normpath(poison_set_dir))
        data_pth = os.path.join('dataset_distribution', os.path.basename(parent_dir), set_name,
                                'c1' + 'r' + str(bg_size))
        img_pth = os.path.join(data_pth, 'kmeans_c0')
        from utils.preprocess import mean_std_dict
        mean, std = mean_std_dict[kwargs['dataset_name']]
        bg_datasets = [datasets.ImageFolder(
            img_pth,
            transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]))]

        # density function
        density_tensor = torch.from_numpy(np.load(os.path.join(data_pth, 'density.npy')))

        return LPI(model, k=kwargs['k'], bg_size=bg_size, density=density_tensor, bg_dataset=bg_datasets,
                   random_alpha=kwargs['random_alpha'], est_method=kwargs['est_method'], exp_obj=kwargs['exp_obj'])

    raise NotImplementedError('%s is not implemented.' % method_name)


def quan_eval(model=None, test_loader=None, poison_test=True, poison_transform=None,
              num_classes=None, source_classes=None, all_to_all=False,
              method_name='IG_SG', metric='visualize', k=None, bg_size=None, est_method='vanilla', exp_obj='logit', dataset_name='imagenet'):
    """Build the explainer for ``method_name`` and run ``benchmarking`` with it.

    ``metric`` is accepted for CLI compatibility but is not used here. The
    GradCAM/FullGrad ``post_process`` flag comes from the module-level
    ``args.post_process == 'absolute'``.

    Raises:
        NotImplementedError: (from ``load_explainer``) for an unknown ``method_name``.
    """
    post_process = args.post_process == 'absolute'
    # SmoothGrad draws bg_size * k samples. Compute the product only when both are
    # given, so methods that ignore them still work with the k=None / bg_size=None
    # defaults (previously a TypeError for every method).
    smooth_bg_size = bg_size * k if (bg_size is not None and k is not None) else None

    explainer_args = {
        'Random': {'method_name': method_name},
        'InputGrad': {'method_name': method_name, 'exp_obj': exp_obj},
        'InputxGrad': {'method_name': method_name, 'exp_obj': exp_obj},

        'GradCAM': {'method_name': method_name, 'exp_obj': exp_obj, 'post_process': post_process},
        'GuidedGradCAM': {'method_name': method_name, 'exp_obj': exp_obj},

        'FullGrad': {'method_name': method_name, 'exp_obj': exp_obj, 'post_process': post_process, 'dataset_name': dataset_name},
        'SmoothGrad': {'method_name': method_name, 'bg_size': smooth_bg_size, 'exp_obj': exp_obj, 'dataset_name': dataset_name},

        'IntGrad': {'method_name': method_name, 'k': k, 'exp_obj': exp_obj, 'dataset_name': dataset_name},
        # 'ExpGrad': {'method_name': method_name, 'k': k, 'bg_size': bg_size, 'bg_dataset': train_dataset,
        #             'bg_batch_size': test_bth, 'random_alpha': True, 'est_method': est_method, 'exp_obj': exp_obj},
        'LPI': {'method_name': method_name, 'k': k, 'bg_size': bg_size, 'random_alpha': True, 'est_method': est_method,
                'exp_obj': exp_obj, 'dataset_name': dataset_name},

        'IG_Uniform': {'method_name': method_name, 'k': k, 'bg_size': bg_size, 'random_alpha': False,
                       'est_method': est_method, 'exp_obj': exp_obj, 'dataset_name': dataset_name},
        'IG_SG': {'method_name': method_name, 'k': k, 'bg_size': bg_size, 'random_alpha': False,
                  'est_method': est_method, 'exp_obj': exp_obj},
        'IG_SQ': {'method_name': method_name, 'k': k, 'bg_size': bg_size, 'random_alpha': False,
                  'est_method': est_method, 'exp_obj': exp_obj},
        # AGI deliberately swaps the roles: its k is bg_size and its top_k is k.
        'AGI': {'method_name': method_name, 'k': bg_size, 'top_k': k, 'est_method': est_method, 'exp_obj': exp_obj,
                'dataset_name': dataset_name},
    }

    # Unknown names fall through to load_explainer, which raises a descriptive
    # NotImplementedError instead of a bare KeyError here.
    explainer = load_explainer(model=model, **explainer_args.get(method_name, {'method_name': method_name}))

    benchmarking(model=model, test_loader=test_loader, poison_test=poison_test, poison_transform=poison_transform,
                 num_classes=num_classes, source_classes=source_classes, all_to_all=all_to_all, explainer=explainer)


# Script entry point: benchmark the configured attribution method on the test set
# prepared above. all_to_all is inferred from the poison type name.
quan_eval(model=model, test_loader=test_set_loader, poison_test=True, poison_transform=poison_transform,
          num_classes=num_classes, source_classes=source_classes, all_to_all=('all_to_all' in args.poison_type),
          method_name=args.attr_method, metric=args.metric, k=args.k, bg_size=args.bg_size, est_method=args.est_method,
          exp_obj=args.exp_obj, dataset_name=args.dataset)
