from __future__ import print_function

import sys
sys.path.append("/home/yangxiangyuan/yxy/Paper4_ACE")

import torch
import numpy as np
import argparse
import os
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import timm

from attack.fgsm import FGSM
from attack.bim import BIM
from attack.mifgsm import MIFGSM
from attack.difgsm import DIFGSM
from attack.sinifgsm import SINIFGSM
from attack.vmifgsm import VMIFGSM

from attack.qfgsm import QFGSM
from attack.qifgsm import QIFGSM
from attack.qmifgsm import QMIFGSM
from attack.qdifgsm import QDIFGSM
from attack.qsinifgsm import QSINIFGSM
from attack.qvmifgsm import QVMIFGSM

from ImageNet.utils import Normalize
from ImageNet.Selected_Imagenet import SelectedImagenet
from ImageNet.Selected_Imagenet_to_Attack import SelectedImagenet2Attack

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Test')
# Surrogate model: the model the adversarial examples are crafted on.
parser.add_argument('--model-type', default='ens_adv_inception_resnet_v2', type=str)
# Victim model: the model the adversarial examples are evaluated against.
parser.add_argument('--victim-model-type', default='adv_inception_v3', type=str)
# The help text used to advertise a default of 128 although the real default is 24.
parser.add_argument('--test-batch-size', type=int, default=24, metavar='N',
                    help='input batch size for testing (default: 24)')
parser.add_argument('--data-dir', default='../data/imagenet-selected', type=str)
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--model-dir', default='./saved_model',
                    help='directory of model for saving checkpoint')

args = parser.parse_args()

# Echo the effective configuration so it ends up in the run log.
for arg in vars(args):
    print(arg, ':', getattr(args, arg))

# settings
model_dir = args.model_dir
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(model_dir, exist_ok=True)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda:"+str(args.gpu) if use_cuda else "cpu")
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

def evaluate_accuracy(data_iter, net, device=None):
    """Measure top-1 accuracy and a softmax-vs-one-hot MSE of ``net``.

    Args:
        data_iter: iterable yielding ``(X, y)`` batches, where ``y`` holds the
            integer class index of every sample.
        net: callable mapping a batch ``X`` to per-class scores of shape
            ``(batch, num_classes)``.
        device: device every batch is moved to.  When ``None`` and ``net`` is
            an ``nn.Module`` that owns parameters, the device of its first
            parameter is used; otherwise batches stay where they are.

    Returns:
        A tuple ``(accuracy, mean_loss, predict_status_arr)``:
        ``accuracy`` is the fraction of correctly classified samples,
        ``mean_loss`` is the average over batches of the per-batch mean
        squared error between ``softmax(output)`` and the one-hot labels, and
        ``predict_status_arr`` is a 1-D integer numpy array holding 1 for every
        correctly classified sample and 0 otherwise, in iteration order.

    Raises:
        ValueError: if ``data_iter`` yields no samples.
    """
    if device is None and isinstance(net, torch.nn.Module):
        # Fall back to the device of the network; a parameter-free module
        # simply leaves ``device`` unset instead of raising IndexError.
        first_param = next(net.parameters(), None)
        if first_param is not None:
            device = first_param.device
    loss = torch.nn.MSELoss(reduction='none')
    acc_sum, n = 0.0, 0
    test_l_sum, batch_count = 0.0, 0
    predict_status_arr = []
    with torch.no_grad():
        for X, y in data_iter:
            if device is not None:
                X = X.to(device)
                y = y.to(device)
            output = net(X)
            # Compute the per-sample hit mask once and reuse it.
            correct = output.argmax(dim=1) == y
            acc_sum += correct.float().sum().cpu().item()
            predict_status_arr.append(correct.int().cpu().numpy())
            # Generalization error: the class count is taken from the model
            # output rather than being fixed to 1000.
            y_onehot = F.one_hot(y, output.shape[1]).float()
            l = loss(F.softmax(output, dim=1), y_onehot)
            test_l_sum += l.mean().cpu().item()

            n += y.shape[0]
            batch_count += 1
    if n == 0:
        raise ValueError("evaluate_accuracy received no samples from data_iter")
    predict_status_arr = np.concatenate(predict_status_arr, axis=0)
    return acc_sum / n, test_l_sum / batch_count, predict_status_arr

'''
FGSM非目标攻击
'''
def Untarget_FGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted FGSM attack for every perturbation budget.

    Adversarial examples are crafted on ``model`` (loss_type='EFCEnT') and
    classified by ``victim_model``; a sample counts as a success when the
    victim's top-1 prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps in budgets:
        attack = FGSM(model, eps=eps, loss_type='EFCEnT', temperature_scale=1.0, fuzzy_scale=1.02)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            # Count the samples the victim no longer classifies as their label.
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('FGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
I-FGSM非目标攻击
'''
def Untarget_BIM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted BIM (I-FGSM) attack for every perturbation budget.

    Each epsilon is paired with its own step size; the attack runs 10 steps on
    ``model`` (loss_type='EFCEnT') and the result is classified by
    ``victim_model``.  A sample counts as a success when the victim's top-1
    prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = BIM(model, eps=eps, alpha=step_size, steps=10, loss_type='EFCEnT', temperature_scale=1.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            # Count the samples the victim no longer classifies as their label.
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('BIM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
MIFGSM非目标攻击
'''
def Untarget_MIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted MIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size; the attack runs 10 steps on
    ``model`` (loss_type='EFCEnT', temperature_scale=3.5, fuzzy_scale=1.015)
    and the result is classified by ``victim_model``.  A sample counts as a
    success when the victim's top-1 prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = MIFGSM(model, eps=eps, alpha=step_size, steps=10, loss_type='EFCEnT', temperature_scale=3.5, fuzzy_scale=1.015)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            # Count the samples the victim no longer classifies as their label.
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('MIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
DIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_DIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted DIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size; the attack runs 10 steps on
    ``model`` (loss_type='EFCEnT') and the result is classified by
    ``victim_model``.  A sample counts as a success when the victim's top-1
    prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = DIFGSM(model, eps=eps, alpha=step_size, steps=10, loss_type='EFCEnT', temperature_scale=1.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            # Count the samples the victim no longer classifies as their label.
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('DIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
SINIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_SINIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted SINIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size; the attack runs 10 steps on
    ``model`` (loss_type='EFCEnT', temperature_scale=2.0) and the result is
    classified by ``victim_model``.  A sample counts as a success when the
    victim's top-1 prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = SINIFGSM(model, eps=eps, alpha=step_size, steps=10, loss_type='EFCEnT', temperature_scale=2.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            # Count the samples the victim no longer classifies as their label.
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('SINIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
VMIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_VMIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted VMIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size; the attack runs 10 steps on
    ``model`` (loss_type='EFCEnT', temperature_scale=2.0) and the result is
    classified by ``victim_model``.  A sample counts as a success when the
    victim's top-1 prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = VMIFGSM(model, eps=eps, alpha=step_size, steps=10, loss_type='EFCEnT', temperature_scale=2.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            # Count the samples the victim no longer classifies as their label.
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('VMIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
QFGSM非目标攻击以及目标攻击评估
'''
def Untarget_QFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted QFGSM attack for every perturbation budget.

    The attack object is built from ``model`` and is also handed
    ``victim_model`` (num_classes=1000, wtop_n=5, loss_type="WAEFCEnT"); how
    it uses the victim is defined by the QFGSM class.  The adversarial
    examples are then classified by ``victim_model`` and a sample counts as a
    success when the top-1 prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps in budgets:
        attack = QFGSM(model, eps=eps, victim_model=victim_model, num_classes=1000, wtop_n=5, loss_type="WAEFCEnT", temperature_scale=1.0, fuzzy_scale=1.02)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            # Measure the attack success rate on the victim.
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('QFGSM with WAEFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
QIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_QIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted QIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size.  The 10-step attack is
    built from ``model`` and is also handed ``victim_model`` (num_classes=1000,
    wtop_n=5, query_num=10, loss_type="WAEFCEnT"); how it uses the victim is
    defined by the QIFGSM class.  The adversarial examples are then classified
    by ``victim_model`` and a sample counts as a success when the top-1
    prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = QIFGSM(model, eps=eps, alpha=step_size, steps=10, victim_model=victim_model, num_classes=1000, wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=1.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            # Measure the attack success rate on the victim.
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('QIFGSM with WAEFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
QMIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_QMIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted QMIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size.  The 10-step attack is
    built from ``model`` and is also handed ``victim_model`` (num_classes=1000,
    wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=3.5,
    fuzzy_scale=1.015); how it uses the victim is defined by the QMIFGSM
    class.  The adversarial examples are then classified by ``victim_model``
    and a sample counts as a success when the top-1 prediction differs from
    its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = QMIFGSM(model, eps=eps, alpha=step_size, steps=10, victim_model=victim_model, num_classes=1000, wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=3.5, fuzzy_scale=1.015)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            # Measure the attack success rate on the victim.
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('QMIFGSM with WAEFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
QDIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_QDIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted QDIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size.  The 10-step attack is
    built from ``model`` and is also handed ``victim_model`` (num_classes=1000,
    wtop_n=5, query_num=10, loss_type="WAEFCEnT"); how it uses the victim is
    defined by the QDIFGSM class.  The adversarial examples are then classified
    by ``victim_model`` and a sample counts as a success when the top-1
    prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = QDIFGSM(model, eps=eps, alpha=step_size, steps=10, victim_model=victim_model, num_classes=1000, wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=1.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            # Measure the attack success rate on the victim.
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('QDIFGSM with WAEFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
QSINIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_QSINIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted QSINIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size.  The 10-step attack is
    built from ``model`` and is also handed ``victim_model`` (num_classes=1000,
    wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=2.0); how
    it uses the victim is defined by the QSINIFGSM class.  The adversarial
    examples are then classified by ``victim_model`` and a sample counts as a
    success when the top-1 prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = QSINIFGSM(model, eps=eps, alpha=step_size, steps=10, victim_model=victim_model, num_classes=1000, wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=2.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            # Measure the attack success rate on the victim.
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('QSINIFGSM with WAEFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]

'''
QVMIFGSM非目标攻击以及目标攻击评估
'''
def Untarget_QVMIFGSM_test(model, victim_model, device, test_loader):
    """
    Run the untargeted QVMIFGSM attack for every perturbation budget.

    Each epsilon is paired with its own step size.  The 10-step attack is
    built from ``model`` and is also handed ``victim_model`` (num_classes=1000,
    wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=2.0); how
    it uses the victim is defined by the QVMIFGSM class.  The adversarial
    examples are then classified by ``victim_model`` and a sample counts as a
    success when the top-1 prediction differs from its label.

    return: (success_rates, [avg_minutes]) -- the attack success rate in
            percent for each epsilon, and a one-element list with the mean
            wall-clock minutes spent per epsilon.
    """
    budgets = [8.0 / 255, 16.0 / 255]
    step_sizes = [0.8 / 255, 1.6 / 255]
    success_rates = []
    total = len(test_loader.dataset)

    tic = time.time()
    for eps, step_size in zip(budgets, step_sizes):
        attack = QVMIFGSM(model, eps=eps, alpha=step_size, steps=10, victim_model=victim_model, num_classes=1000, wtop_n=5, query_num=10, loss_type="WAEFCEnT", temperature_scale=2.0, fuzzy_scale=1.0)
        fooled = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            adv_images = attack(images, labels)
            # Measure the attack success rate on the victim.
            with torch.no_grad():
                logits = victim_model(adv_images)
            _, top1 = logits.max(1, keepdim=True)
            fooled += top1.ne(labels.view_as(top1)).sum().item()
        rate = 100. * fooled / total
        print('QVMIFGSM with WAEFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(fooled, total, rate))
        success_rates.append(rate)
    elapsed = time.time() - tic

    # Seconds -> minutes, averaged over the number of budgets.
    avg_minutes = elapsed / (60 * len(budgets))
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]


def _build_classifier(model_type, checkpoint_overrides=None):
    """Build one classifier as ``nn.Sequential(Normalize(), net)`` on ``device``.

    torchvision architectures are created without weights and restored from a
    checkpoint file inside ``args.model_dir``; the two timm models are created
    with ``pretrained=True``.

    Args:
        model_type: architecture name as accepted by ``--model-type`` /
            ``--victim-model-type``.
        checkpoint_overrides: optional ``{model_type: filename}`` mapping that
            replaces the default checkpoint filename of a torchvision model.

    Returns:
        The wrapped network, already moved to ``device`` (not yet in eval mode).

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    torchvision_specs = {
        'vgg16': (models.vgg16, 'vgg16_checkpoint.pth'),
        'vgg19': (models.vgg19, 'vgg19_checkpoint.pth'),
        'resnet50': (models.resnet50, 'resnet50_checkpoint.pth'),
        'resnet152': (models.resnet152, 'resnet152_checkpoint.pth'),
        'inceptionv3': (models.inception_v3, 'inception_v3_checkpoint.pth'),
        'mobilenetv2': (models.mobilenet_v2, 'mobilenet_v2_checkpoint.pth'),
    }
    timm_names = ('adv_inception_v3', 'ens_adv_inception_resnet_v2')

    if model_type in torchvision_specs:
        constructor, checkpoint_name = torchvision_specs[model_type]
        checkpoint_name = (checkpoint_overrides or {}).get(model_type, checkpoint_name)
        net = constructor()
        # Load onto the CPU first so GPU-saved checkpoints also work with
        # --no-cuda; the wrapped model is moved to ``device`` below.
        state_dict = torch.load(os.path.join(args.model_dir, checkpoint_name), map_location='cpu')
        net.load_state_dict(state_dict)
    elif model_type in timm_names:
        net = timm.create_model(model_type, pretrained=True)
    else:
        supported = sorted(torchvision_specs) + list(timm_names)
        raise ValueError(
            "unsupported model type {!r}; expected one of {}".format(model_type, supported))

    # NOTE(review): Normalize() is applied with its defaults to every model,
    # including the timm ones -- confirm these statistics match what the timm
    # pretrained weights expect.
    wrapped = nn.Sequential(Normalize(), net)
    wrapped.to(device)
    return wrapped


def _load_or_compute_predict_status(model_type, net, acc_label):
    """Return the per-sample correctness array of ``net`` on the selected set.

    The array is read from ``<model_dir>/<model_type>_predict_status_arr.txt``
    when that file exists; otherwise it is computed with ``evaluate_accuracy``,
    the accuracy is printed prefixed by ``acc_label``, and the array is written
    to that file for later runs.
    """
    status_path = os.path.join(args.model_dir, model_type + '_predict_status_arr.txt')
    if os.path.exists(status_path):
        return np.loadtxt(status_path)

    testset = SelectedImagenet(args.data_dir)
    loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False,
                                         pin_memory=True, num_workers=0)
    acc, _, predict_status_arr = evaluate_accuracy(loader, net)
    print(acc_label, acc)
    np.savetxt(status_path, predict_status_arr)
    return predict_status_arr


def test_main():
    """Evaluate every untargeted attack from the surrogate to the victim model.

    Steps:
      1. Build the surrogate (``--model-type``) and victim
         (``--victim-model-type``) models and switch both to eval mode.
      2. Load or compute, for each model, which selected-ImageNet samples it
         classifies correctly.
      3. Build the attack test set from both correctness masks.
      4. Run each attack and save its success rates and timing as ``.npy``
         files named ``<model>2<victim>_Accs_<tag>.npy`` and
         ``<model>2<victim>_time_<tag>.npy`` inside ``args.model_dir``.

    Raises:
        ValueError: if either model type is not supported.
    """
    # Load the surrogate model.
    # NOTE(review): the surrogate path loads 'inception_v3_checkpoint1.pth'
    # while the victim path loads 'inception_v3_checkpoint.pth'; preserved
    # as-is -- confirm which file is intended.
    model = _build_classifier(args.model_type, {'inceptionv3': 'inception_v3_checkpoint1.pth'})
    # Evaluation mode.
    model.eval()

    # Load the victim model; reuse the surrogate when both types are identical.
    if args.victim_model_type == args.model_type:
        victim_model = model
    else:
        victim_model = _build_classifier(args.victim_model_type)
    # Evaluation mode.
    victim_model.eval()

    # Reuse cached prediction-status files when they exist.
    model_predict_status_arr = _load_or_compute_predict_status(
        args.model_type, model, 'acc: ')
    victim_model_predict_status_arr = _load_or_compute_predict_status(
        args.victim_model_type, victim_model, 'victim acc: ')

    # Arrays loaded from text are float; turn both into boolean masks.
    model_predict_status_arr = (model_predict_status_arr != 0)
    victim_model_predict_status_arr = (victim_model_predict_status_arr != 0)

    # Test set for the attacks -- presumably the samples that both models
    # classify correctly (per the original comment); see SelectedImagenet2Attack.
    testset = SelectedImagenet2Attack(model_predict_status_arr, victim_model_predict_status_arr, args.data_dir)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False,
                                              pin_memory=True, num_workers=0)

    # (file tag, evaluation function) in execution order.  The BIM results
    # keep their historical "IFGSM" tag.
    attack_runs = [
        ('EFCEnT_FGSM', Untarget_FGSM_test),
        ('EFCEnT_IFGSM', Untarget_BIM_test),
        ('EFCEnT_MIFGSM', Untarget_MIFGSM_test),
        ('EFCEnT_DIFGSM', Untarget_DIFGSM_test),
        ('EFCEnT_SINIFGSM', Untarget_SINIFGSM_test),
        ('EFCEnT_VMIFGSM', Untarget_VMIFGSM_test),
        ('WAEFCEnT_QFGSM', Untarget_QFGSM_test),
        ('WAEFCEnT_QIFGSM', Untarget_QIFGSM_test),
        ('WAEFCEnT_QMIFGSM', Untarget_QMIFGSM_test),
        ('WAEFCEnT_QDIFGSM', Untarget_QDIFGSM_test),
        ('WAEFCEnT_QSINIFGSM', Untarget_QSINIFGSM_test),
        ('WAEFCEnT_QVMIFGSM', Untarget_QVMIFGSM_test),
    ]

    # Results go to args.model_dir (same location as the former hard-coded
    # "saved_model/" when --model-dir keeps its default).
    result_prefix = os.path.join(args.model_dir, args.model_type + "2" + args.victim_model_type)
    for tag, run_attack in attack_runs:
        success_rates, minutes = run_attack(model, victim_model, device, test_loader)
        # Save right after each attack so finished results survive a later crash.
        np.save(result_prefix + "_Accs_" + tag + ".npy", success_rates)
        np.save(result_prefix + "_time_" + tag + ".npy", minutes)

if __name__ == '__main__':
    test_main()
