from __future__ import print_function

import sys
sys.path.append("/home/yangxiangyuan/yxy/Paper4_ACE")

import torch
import numpy as np
import argparse
import os
import time
import torch.nn as nn
import torch.nn.functional as F

from CIFAR100.models.vgg16 import vgg16
from CIFAR100.models.resnet import resnet50
from CIFAR100.models.resnext import resnext50
from CIFAR100.models.wideresnet import wideresnet
from CIFAR100.models.densenet121 import densenet121
from CIFAR100.models.mobilenetv2 import mobilenetv2

from attack.fgsm import FGSM
from attack.bim import BIM
from attack.mifgsm import MIFGSM
from attack.difgsm import DIFGSM
from attack.sinifgsm import SINIFGSM
from attack.vmifgsm import VMIFGSM

from CIFAR100.utils import Normalize, Normalize1, get_test_loaders
from CIFAR100.CIFAR100_attack import Cifar100

# Command-line configuration. Adversarial examples are crafted on --model-type
# (the surrogate model) and evaluated on --victim-model-type; results are saved
# as an ablation over the EFCEnT temperature scale.
parser = argparse.ArgumentParser(description='PyTorch CIFAR100 Test')
parser.add_argument('--model-type', default='resnet50', type=str)
parser.add_argument('--victim-model-type', default='vgg16', type=str)
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--data-dir', default='../data', type=str)
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
# Index of the CUDA device to use when CUDA is enabled.
parser.add_argument('--gpu', default=0, type=int)
# Temperature sweep used by every attack: np.arange(Tmin, Tmax, step), so Tmax
# itself is excluded (the default 8.1 makes 8.0 the last value).
parser.add_argument('--Tmin', default=1.0, type=float)
parser.add_argument('--Tmax', default=8.1, type=float)
parser.add_argument('--step', default=0.5, type=float)
# Forwarded to every attack as its `fuzzy_scale` argument.
parser.add_argument('--K', default=1.0, type=float)
parser.add_argument('--model-dir', default='./saved_model',
                    help='directory of model for saving checkpoint')

# NOTE(review): arguments are parsed at import time, so importing this module
# from another script also consumes that script's sys.argv.
args = parser.parse_args()

# Echo the full configuration so it appears in the run log.
for arg in vars(args):
    print(arg, ':', getattr(args, arg))

# settings
model_dir = args.model_dir
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda:" + str(args.gpu) if use_cuda else "cpu")
# NOTE(review): `kwargs` is not referenced elsewhere in this file; the
# DataLoader in test_main passes its options explicitly.
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}


def evaluate_accuracy(data_iter, net, device=None):
    """Evaluate ``net`` on ``data_iter`` and record per-sample correctness.

    Args:
        data_iter: iterable of ``(X, y)`` batches, where ``y`` holds integer
            class labels.
        net: model (or other callable) mapping a batch ``X`` to logits of
            shape ``(batch, num_classes)``.
        device: device to run on. When ``None`` and ``net`` is an
            ``nn.Module``, the device of its first parameter is used, or CPU
            for a parameter-free module. For non-Module callables with no
            device given, batches are left where they are.

    Returns:
        ``(accuracy, mean_loss, predict_status_arr)``:
        - ``accuracy``: fraction of correctly classified samples.
        - ``mean_loss``: MSE between softmax probabilities and one-hot labels,
          averaged within each batch and then over batches.
        - ``predict_status_arr``: 1-D int32 array, in iteration order, with 1
          for each correctly classified sample and 0 otherwise.

    Raises:
        ValueError: if ``data_iter`` yields no samples.
    """
    if device is None and isinstance(net, torch.nn.Module):
        # Default to the device holding the network's weights. Parameter-free
        # modules (e.g. nn.Identity) fall back to CPU instead of raising IndexError.
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    loss = torch.nn.MSELoss(reduction='none')
    acc_sum, n = 0.0, 0
    test_l_sum, batch_count = 0.0, 0
    predict_status_arr = []
    with torch.no_grad():
        for X, y in data_iter:
            if device is not None:
                X = X.to(device)
                y = y.to(device)
            output = net(X)
            # Compute the correctness mask once and reuse it for both statistics.
            correct = (output.argmax(dim=1) == y)
            acc_sum += correct.float().sum().cpu().item()
            predict_status_arr.append(correct.int().cpu().numpy())
            # "Generalization error": MSE between softmax output and one-hot labels.
            # The class count comes from the logits instead of a hard-coded 100.
            y_onehot = F.one_hot(y, output.shape[1]).float()
            l = loss(F.softmax(output, dim=1), y_onehot)
            test_l_sum += l.mean().cpu().item()

            n += y.shape[0]
            batch_count += 1
    if n == 0:
        raise ValueError('evaluate_accuracy: data_iter yielded no samples')
    predict_status_arr = np.concatenate(predict_status_arr, axis=0)
    return acc_sum / n, test_l_sum / batch_count, predict_status_arr


'''
FGSM非目标攻击评估 (FGSM untargeted attack evaluation)
'''


def Untarget_FGSM_test(model, victim_model, device, test_loader):
    """Sweep the EFCEnT temperature scale for an untargeted FGSM transfer attack.

    For each ``scale`` in ``np.arange(args.Tmin, args.Tmax, args.step)``, FGSM
    adversarial examples (eps=8/255, ``loss_type='EFCEnT'``,
    ``temperature_scale=scale``, ``fuzzy_scale=args.K``) are crafted on
    ``model`` and classified by ``victim_model``. A sample counts as a
    successful attack when the victim's prediction differs from its label.

    Args:
        model: network the adversarial examples are crafted on.
        victim_model: network the adversarial examples are evaluated on.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(Accs_EFCEnT, [avg_minutes])``: the attack success rate in percent
        for each temperature scale, and the average wall-clock time per scale
        in minutes, wrapped in a one-element list.

    Raises:
        ValueError: if the temperature sweep is empty (e.g. ``Tmin >= Tmax``).
    """
    scales = np.arange(args.Tmin, args.Tmax, args.step)
    if len(scales) == 0:
        # Fail clearly instead of dividing by zero in the timing code below.
        raise ValueError('empty temperature sweep: np.arange({}, {}, {}) has no values'.format(
            args.Tmin, args.Tmax, args.step))
    num_samples = len(test_loader.dataset)
    Accs_EFCEnT = []
    # EFCEnT loss
    start_time = time.time()
    for scale in scales:
        success_attack = 0
        fgsm = FGSM(model, eps=8.0/255, loss_type='EFCEnT', temperature_scale=scale, fuzzy_scale=args.K)
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv = fgsm(data, target)
            with torch.no_grad():
                output = victim_model(data_adv)
            pred = output.max(1, keepdim=True)[1]
            # Untargeted success: the victim no longer predicts the true label.
            success_attack += (~(pred.eq(target.view_as(pred)))).sum().item()
        # An empty dataset reports 0% instead of raising ZeroDivisionError.
        asr = 100. * success_attack / num_samples if num_samples else 0.0
        print('FGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(success_attack, num_samples, asr))
        Accs_EFCEnT.append(asr)
    end_time = time.time()
    # Average minutes spent per temperature scale.
    avg_time_spended_EFCEnT = (end_time - start_time) / (60 * len(scales))
    print('avg_time_spended:{}'.format(avg_time_spended_EFCEnT))

    return Accs_EFCEnT, [avg_time_spended_EFCEnT]

'''
BIM非目标攻击评估 (BIM untargeted attack evaluation)
'''


def Untarget_BIM_test(model, victim_model, device, test_loader):
    """Sweep the EFCEnT temperature scale for an untargeted BIM transfer attack.

    For each ``scale`` in ``np.arange(args.Tmin, args.Tmax, args.step)``, BIM
    adversarial examples (eps=8/255, alpha=0.8/255, 10 steps,
    ``loss_type='EFCEnT'``, ``temperature_scale=scale``,
    ``fuzzy_scale=args.K``) are crafted on ``model`` and classified by
    ``victim_model``. A sample counts as a successful attack when the
    victim's prediction differs from its label.

    Args:
        model: network the adversarial examples are crafted on.
        victim_model: network the adversarial examples are evaluated on.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(Accs_EFCEnT, [avg_minutes])``: the attack success rate in percent
        for each temperature scale, and the average wall-clock time per scale
        in minutes, wrapped in a one-element list.

    Raises:
        ValueError: if the temperature sweep is empty (e.g. ``Tmin >= Tmax``).
    """
    scales = np.arange(args.Tmin, args.Tmax, args.step)
    if len(scales) == 0:
        # Fail clearly instead of dividing by zero in the timing code below.
        raise ValueError('empty temperature sweep: np.arange({}, {}, {}) has no values'.format(
            args.Tmin, args.Tmax, args.step))
    num_samples = len(test_loader.dataset)
    Accs_EFCEnT = []
    # EFCEnT loss
    start_time = time.time()
    for scale in scales:
        success_attack = 0
        bim = BIM(model, eps=8.0/255, alpha=0.8/255, steps=10, loss_type='EFCEnT', temperature_scale=scale, fuzzy_scale=args.K)
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv = bim(data, target)
            with torch.no_grad():
                output = victim_model(data_adv)
            pred = output.max(1, keepdim=True)[1]
            # Untargeted success: the victim no longer predicts the true label.
            success_attack += (~(pred.eq(target.view_as(pred)))).sum().item()
        # An empty dataset reports 0% instead of raising ZeroDivisionError.
        asr = 100. * success_attack / num_samples if num_samples else 0.0
        print('BIM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(success_attack, num_samples, asr))
        Accs_EFCEnT.append(asr)
    end_time = time.time()
    # Average minutes spent per temperature scale.
    avg_time_spended_EFCEnT = (end_time - start_time) / (60 * len(scales))
    print('avg_time_spended:{}'.format(avg_time_spended_EFCEnT))

    return Accs_EFCEnT, [avg_time_spended_EFCEnT]


'''
MIFGSM非目标攻击评估 (MIFGSM untargeted attack evaluation)
'''


def Untarget_MIFGSM_test(model, victim_model, device, test_loader):
    """Sweep the EFCEnT temperature scale for an untargeted MI-FGSM transfer attack.

    For each ``scale`` in ``np.arange(args.Tmin, args.Tmax, args.step)``,
    MIFGSM adversarial examples (eps=8/255, alpha=0.8/255, 10 steps,
    ``loss_type='EFCEnT'``, ``temperature_scale=scale``,
    ``fuzzy_scale=args.K``) are crafted on ``model`` and classified by
    ``victim_model``. A sample counts as a successful attack when the
    victim's prediction differs from its label.

    Args:
        model: network the adversarial examples are crafted on.
        victim_model: network the adversarial examples are evaluated on.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(Accs_EFCEnT, [avg_minutes])``: the attack success rate in percent
        for each temperature scale, and the average wall-clock time per scale
        in minutes, wrapped in a one-element list.

    Raises:
        ValueError: if the temperature sweep is empty (e.g. ``Tmin >= Tmax``).
    """
    scales = np.arange(args.Tmin, args.Tmax, args.step)
    if len(scales) == 0:
        # Fail clearly instead of dividing by zero in the timing code below.
        raise ValueError('empty temperature sweep: np.arange({}, {}, {}) has no values'.format(
            args.Tmin, args.Tmax, args.step))
    num_samples = len(test_loader.dataset)
    Accs_EFCEnT = []
    # EFCEnT loss
    start_time = time.time()
    for scale in scales:
        success_attack = 0
        mifgsm = MIFGSM(model, eps=8.0/255, alpha=0.8/255, steps=10, loss_type='EFCEnT', temperature_scale=scale, fuzzy_scale=args.K)
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv = mifgsm(data, target)
            with torch.no_grad():
                output = victim_model(data_adv)
            pred = output.max(1, keepdim=True)[1]
            # Untargeted success: the victim no longer predicts the true label.
            success_attack += (~(pred.eq(target.view_as(pred)))).sum().item()
        # An empty dataset reports 0% instead of raising ZeroDivisionError.
        asr = 100. * success_attack / num_samples if num_samples else 0.0
        print('MIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(success_attack, num_samples, asr))
        Accs_EFCEnT.append(asr)
    end_time = time.time()
    # Average minutes spent per temperature scale.
    avg_time_spended_EFCEnT = (end_time - start_time) / (60 * len(scales))
    print('avg_time_spended:{}'.format(avg_time_spended_EFCEnT))

    return Accs_EFCEnT, [avg_time_spended_EFCEnT]


'''
DIFGSM非目标攻击评估 (DIFGSM untargeted attack evaluation)
'''


def Untarget_DIFGSM_test(model, victim_model, device, test_loader):
    """Sweep the EFCEnT temperature scale for an untargeted DI-FGSM transfer attack.

    For each ``scale`` in ``np.arange(args.Tmin, args.Tmax, args.step)``,
    DIFGSM adversarial examples (eps=8/255, alpha=0.8/255, 10 steps,
    ``loss_type='EFCEnT'``, ``temperature_scale=scale``,
    ``fuzzy_scale=args.K``) are crafted on ``model`` and classified by
    ``victim_model``. A sample counts as a successful attack when the
    victim's prediction differs from its label.

    Args:
        model: network the adversarial examples are crafted on.
        victim_model: network the adversarial examples are evaluated on.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(Accs_EFCEnT, [avg_minutes])``: the attack success rate in percent
        for each temperature scale, and the average wall-clock time per scale
        in minutes, wrapped in a one-element list.

    Raises:
        ValueError: if the temperature sweep is empty (e.g. ``Tmin >= Tmax``).
    """
    scales = np.arange(args.Tmin, args.Tmax, args.step)
    if len(scales) == 0:
        # Fail clearly instead of dividing by zero in the timing code below.
        raise ValueError('empty temperature sweep: np.arange({}, {}, {}) has no values'.format(
            args.Tmin, args.Tmax, args.step))
    num_samples = len(test_loader.dataset)
    Accs_EFCEnT = []
    # EFCEnT loss
    start_time = time.time()
    for scale in scales:
        success_attack = 0
        difgsm = DIFGSM(model, eps=8.0/255, alpha=0.8/255, steps=10, loss_type='EFCEnT', temperature_scale=scale, fuzzy_scale=args.K)
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv = difgsm(data, target)
            with torch.no_grad():
                output = victim_model(data_adv)
            pred = output.max(1, keepdim=True)[1]
            # Untargeted success: the victim no longer predicts the true label.
            success_attack += (~(pred.eq(target.view_as(pred)))).sum().item()
        # An empty dataset reports 0% instead of raising ZeroDivisionError.
        asr = 100. * success_attack / num_samples if num_samples else 0.0
        print('DIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(success_attack, num_samples, asr))
        Accs_EFCEnT.append(asr)
    end_time = time.time()
    # Average minutes spent per temperature scale.
    avg_time_spended_EFCEnT = (end_time - start_time) / (60 * len(scales))
    print('avg_time_spended:{}'.format(avg_time_spended_EFCEnT))

    return Accs_EFCEnT, [avg_time_spended_EFCEnT]


'''
SINIFGSM非目标攻击评估 (SINIFGSM untargeted attack evaluation)
'''


def Untarget_SINIFGSM_test(model, victim_model, device, test_loader):
    """Sweep the EFCEnT temperature scale for an untargeted SI-NI-FGSM transfer attack.

    For each ``scale`` in ``np.arange(args.Tmin, args.Tmax, args.step)``,
    SINIFGSM adversarial examples (eps=8/255, alpha=0.8/255, 10 steps,
    ``loss_type='EFCEnT'``, ``temperature_scale=scale``,
    ``fuzzy_scale=args.K``) are crafted on ``model`` and classified by
    ``victim_model``. A sample counts as a successful attack when the
    victim's prediction differs from its label.

    Args:
        model: network the adversarial examples are crafted on.
        victim_model: network the adversarial examples are evaluated on.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(Accs_EFCEnT, [avg_minutes])``: the attack success rate in percent
        for each temperature scale, and the average wall-clock time per scale
        in minutes, wrapped in a one-element list.

    Raises:
        ValueError: if the temperature sweep is empty (e.g. ``Tmin >= Tmax``).
    """
    scales = np.arange(args.Tmin, args.Tmax, args.step)
    if len(scales) == 0:
        # Fail clearly instead of dividing by zero in the timing code below.
        raise ValueError('empty temperature sweep: np.arange({}, {}, {}) has no values'.format(
            args.Tmin, args.Tmax, args.step))
    num_samples = len(test_loader.dataset)
    Accs_EFCEnT = []
    # EFCEnT loss
    start_time = time.time()
    for scale in scales:
        success_attack = 0
        sinifgsm = SINIFGSM(model, eps=8.0/255, alpha=0.8/255, steps=10, loss_type='EFCEnT', temperature_scale=scale, fuzzy_scale=args.K)
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv = sinifgsm(data, target)
            with torch.no_grad():
                output = victim_model(data_adv)
            pred = output.max(1, keepdim=True)[1]
            # Untargeted success: the victim no longer predicts the true label.
            success_attack += (~(pred.eq(target.view_as(pred)))).sum().item()
        # An empty dataset reports 0% instead of raising ZeroDivisionError.
        asr = 100. * success_attack / num_samples if num_samples else 0.0
        print('SINIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(success_attack, num_samples, asr))
        Accs_EFCEnT.append(asr)
    end_time = time.time()
    # Average minutes spent per temperature scale.
    avg_time_spended_EFCEnT = (end_time - start_time) / (60 * len(scales))
    print('avg_time_spended:{}'.format(avg_time_spended_EFCEnT))

    return Accs_EFCEnT, [avg_time_spended_EFCEnT]


'''
VMIFGSM非目标攻击评估 (VMIFGSM untargeted attack evaluation)
'''


def Untarget_VMIFGSM_test(model, victim_model, device, test_loader):
    """Sweep the EFCEnT temperature scale for an untargeted VMI-FGSM transfer attack.

    For each ``scale`` in ``np.arange(args.Tmin, args.Tmax, args.step)``,
    VMIFGSM adversarial examples (eps=8/255, alpha=0.8/255, 10 steps,
    ``loss_type='EFCEnT'``, ``temperature_scale=scale``,
    ``fuzzy_scale=args.K``) are crafted on ``model`` and classified by
    ``victim_model``. A sample counts as a successful attack when the
    victim's prediction differs from its label.

    Args:
        model: network the adversarial examples are crafted on.
        victim_model: network the adversarial examples are evaluated on.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(Accs_EFCEnT, [avg_minutes])``: the attack success rate in percent
        for each temperature scale, and the average wall-clock time per scale
        in minutes, wrapped in a one-element list.

    Raises:
        ValueError: if the temperature sweep is empty (e.g. ``Tmin >= Tmax``).
    """
    scales = np.arange(args.Tmin, args.Tmax, args.step)
    if len(scales) == 0:
        # Fail clearly instead of dividing by zero in the timing code below.
        raise ValueError('empty temperature sweep: np.arange({}, {}, {}) has no values'.format(
            args.Tmin, args.Tmax, args.step))
    num_samples = len(test_loader.dataset)
    Accs_EFCEnT = []
    # EFCEnT loss
    start_time = time.time()
    for scale in scales:
        success_attack = 0
        vmifgsm = VMIFGSM(model, eps=8.0/255, alpha=0.8/255, steps=10, loss_type='EFCEnT', temperature_scale=scale, fuzzy_scale=args.K)
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv = vmifgsm(data, target)
            with torch.no_grad():
                output = victim_model(data_adv)
            pred = output.max(1, keepdim=True)[1]
            # Untargeted success: the victim no longer predicts the true label.
            success_attack += (~(pred.eq(target.view_as(pred)))).sum().item()
        # An empty dataset reports 0% instead of raising ZeroDivisionError.
        asr = 100. * success_attack / num_samples if num_samples else 0.0
        print('VMIFGSM EFCEnT Test: ASR: {}/{} ({:.2f}%)'.format(success_attack, num_samples, asr))
        Accs_EFCEnT.append(asr)
    end_time = time.time()
    # Average minutes spent per temperature scale.
    avg_time_spended_EFCEnT = (end_time - start_time) / (60 * len(scales))
    print('avg_time_spended:{}'.format(avg_time_spended_EFCEnT))

    return Accs_EFCEnT, [avg_time_spended_EFCEnT]


# Supported architectures: name -> (zero-argument constructor, checkpoint file
# name inside args.model_dir).
_CIFAR100_MODELS = {
    'vgg16': (lambda: vgg16(0.5), 'vgg16_checkpoint.pth'),
    'resnet50': (lambda: resnet50(0.5), 'resnet50_checkpoint.pth'),
    'resnext50': (lambda: resnext50(), 'resnext50_model.pth'),
    'wrn-16-4': (lambda: wideresnet(depth=16, widen_factor=4), 'wrn-16-4_checkpoint.pth'),
    'densenet121': (lambda: densenet121(), 'densenet121_model.pth'),
    'mobilenetv2': (lambda: mobilenetv2(), 'mobilenetv2_model.pth'),
}


def _load_cifar100_model(model_type):
    """Build ``model_type``, move it to ``device`` and load its checkpoint.

    Raises:
        ValueError: if ``model_type`` is not a key of ``_CIFAR100_MODELS``.
    """
    if model_type not in _CIFAR100_MODELS:
        raise ValueError('unsupported model type {!r}; expected one of {}'.format(
            model_type, sorted(_CIFAR100_MODELS)))
    build, ckpt_name = _CIFAR100_MODELS[model_type]
    # .to(device) instead of .cuda(device) so that --no-cuda / CPU-only hosts work.
    net = build().to(device)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(os.path.join(args.model_dir, ckpt_name), map_location=device))
    return net


def _load_predict_status(net, model_type, acc_label):
    """Return ``net``'s per-sample correctness on the test loader as a bool array.

    Reads the cached ``<model_dir>/<model_type>_predict_status_arr.txt`` when it
    exists. Otherwise evaluates ``net`` on ``get_test_loaders(...)``, prints its
    accuracy prefixed by ``acc_label`` and writes the cache file.
    """
    status_path = os.path.join(args.model_dir, model_type + '_predict_status_arr.txt')
    if os.path.exists(status_path):
        status = np.loadtxt(status_path)
    else:
        test_loader = get_test_loaders(args.data_dir, args.test_batch_size)
        acc, _, status = evaluate_accuracy(test_loader, net)
        print(acc_label, acc)
        np.savetxt(status_path, status)
    return status != 0


def test_main():
    """Run the EFCEnT temperature-scale ablation for every attack.

    1. Load the surrogate (``args.model_type``) and victim
       (``args.victim_model_type``) models from ``args.model_dir``.
    2. Load, or compute and cache, each model's per-sample correctness.
    3. Keep the test samples both models classify correctly, then run the
       FGSM, BIM, MIFGSM, DIFGSM, SINIFGSM and VMIFGSM sweeps on them.
    4. Save each sweep's ASR list and average time as ``.npy`` files in
       ``args.model_dir``, the same directory that holds the checkpoints and
       the prediction-status cache. Previously they went to a hard-coded
       ``saved_model/``, which might not exist.

    Raises:
        ValueError: if either model type is not supported.
    """
    # Surrogate model used to craft the adversarial examples.
    model = _load_cifar100_model(args.model_type)
    model.eval()

    # Victim model; reuse the surrogate instance when both names match.
    if args.victim_model_type == args.model_type:
        victim_model = model
    else:
        victim_model = _load_cifar100_model(args.victim_model_type)
    victim_model.eval()

    model_predict_status_arr = _load_predict_status(model, args.model_type, 'acc: ')
    victim_model_predict_status_arr = _load_predict_status(
        victim_model, args.victim_model_type, 'victim acc: ')

    # Test samples classified correctly by both the surrogate and the victim model.
    testset = Cifar100(model_predict_status_arr, victim_model_predict_status_arr, args.data_dir)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False,
                                              pin_memory=use_cuda, num_workers=0)

    # <surrogate>2<victim>_<Tmin>_<Tmax>_<step>_<K> prefix shared by all result files.
    result_prefix = os.path.join(args.model_dir, '{}2{}_{}_{}_{}_{}'.format(
        args.model_type, args.victim_model_type, args.Tmin, args.Tmax, args.step, args.K))
    # (file tag, evaluator); BIM results keep the 'IFGSM' tag used by the original file names.
    attacks = (
        ('FGSM', Untarget_FGSM_test),
        ('IFGSM', Untarget_BIM_test),
        ('MIFGSM', Untarget_MIFGSM_test),
        ('DIFGSM', Untarget_DIFGSM_test),
        ('SINIFGSM', Untarget_SINIFGSM_test),
        ('VMIFGSM', Untarget_VMIFGSM_test),
    )
    for tag, run_attack in attacks:
        accs, times = run_attack(model, victim_model, device, test_loader)
        np.save(result_prefix + '_Accs_EFCEnT_' + tag + '_ablation_temperature_scale.npy', accs)
        np.save(result_prefix + '_time_EFCEnT_' + tag + '_ablation_temperature_scale.npy', times)

# Script entry point: run test_main() only when this file is executed directly.
if __name__ == '__main__':
    test_main()
