from __future__ import print_function

import sys
sys.path.append("/home/yangxiangyuan/yxy/Paper4_ACE")

import torch
import numpy as np
import argparse
import os
import time
import torch.nn as nn
import torch.nn.functional as F

from CIFAR100.models.vgg16 import vgg16
from CIFAR100.models.resnet import resnet50
from CIFAR100.models.resnext import resnext50
from CIFAR100.models.wideresnet import wideresnet
from CIFAR100.models.densenet121 import densenet121
from CIFAR100.models.mobilenetv2 import mobilenetv2

from attack.fgsm import FGSM
from attack.bim import BIM
from attack.mifgsm import MIFGSM
from attack.difgsm import DIFGSM
from attack.sinifgsm import SINIFGSM
from attack.vmifgsm import VMIFGSM

from CIFAR100.utils import Normalize, Normalize1, get_test_loaders
from CIFAR100.CIFAR100_attack import Cifar100

# Command-line interface. "model" is the surrogate on which adversarial
# examples are crafted; "victim model" is the model they are evaluated on.
parser = argparse.ArgumentParser(description='PyTorch CIFAR100 Test')
# Surrogate architecture name (see test_main for the accepted names).
parser.add_argument('--model-type', default='resnet50', type=str)
# Victim architecture name; when equal to --model-type the same network is reused.
parser.add_argument('--victim-model-type', default='vgg16', type=str)
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
parser.add_argument('--data-dir', default='../data', type=str)
# NOTE(review): despite the help text, nothing is trained here; the flag only
# makes ``use_cuda`` False, which selects the CPU device below.
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
# Index of the CUDA device used when CUDA is enabled.
parser.add_argument('--gpu', default=0, type=int)
# Temperature-scale sweep: np.arange(Tmin, Tmax, step) is passed to every
# attack as ``temperature_scale``. np.arange excludes Tmax, so the default
# 8.1 keeps 8.0 in the sweep.
parser.add_argument('--Tmin', default=1.0, type=float)
parser.add_argument('--Tmax', default=8.1, type=float)
parser.add_argument('--step', default=0.5, type=float)
# Passed to every attack as ``fuzzy_scale``.
parser.add_argument('--K', default=1.0, type=float)
# Directory holding the model checkpoints and the cached
# ``<model>_predict_status_arr.txt`` files; created below if missing.
# NOTE(review): the help text says "saving", but this script only loads
# checkpoints from it.
parser.add_argument('--model-dir', default='./saved_model',
                    help='directory of model for saving checkpoint')

args = parser.parse_args()

# Echo every parsed option to stdout.
for arg in vars(args):
    print(arg, ':', getattr(args, arg))

# settings
model_dir = args.model_dir
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
use_cuda = not args.no_cuda and torch.cuda.is_available()
# Seeds torch's RNG only. NOTE(review): numpy / `random` are not seeded here;
# confirm the attacks' random target selection relies solely on torch.
torch.manual_seed(args.seed)
# "cuda:<gpu>" when CUDA is used, otherwise "cpu".
device = torch.device("cuda:" + str(args.gpu) if use_cuda else "cpu")
# DataLoader options. NOTE(review): not referenced anywhere else in this file as shown.
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}


def evaluate_accuracy(data_iter, net, device=None):
    """Evaluate accuracy and a softmax-vs-one-hot MSE score of ``net``.

    Args:
        data_iter: iterable of ``(X, y)`` batches where ``y`` holds integer
            class labels (e.g. a ``DataLoader``).
        net: classifier called as ``net(X)``; its output is treated as logits
            of shape ``(batch, num_classes)``.
        device: device batches are moved to before the forward pass. When
            omitted and ``net`` is an ``nn.Module``, the device of its first
            parameter is used.

    Returns:
        ``(accuracy, mse, predict_status_arr)``:

        * ``accuracy`` is the fraction of correctly classified samples.
        * ``mse`` is the mean over batches of the per-batch mean squared
          error between ``softmax(net(X))`` and the one-hot labels.
        * ``predict_status_arr`` is an int array with one 0/1 entry per
          sample (1 = correct), in iteration order.

        An empty ``data_iter`` yields ``(0.0, 0.0, <empty int32 array>)``.
    """
    if device is None and isinstance(net, torch.nn.Module):
        # No device given: use the one the network's parameters live on.
        device = next(net.parameters()).device
    mse = torch.nn.MSELoss(reduction='none')
    correct_sum, n = 0.0, 0
    mse_sum, batch_count = 0.0, 0
    predict_status_arr = []
    with torch.no_grad():
        for X, y in data_iter:
            X = X.to(device)
            y = y.to(device)
            output = net(X)
            # Per-sample hit mask, computed once and reused for both outputs.
            hits = output.argmax(dim=1) == y
            correct_sum += hits.float().sum().item()
            predict_status_arr.append(hits.int().cpu().numpy())
            # Generalisation-error proxy: MSE between predicted probabilities
            # and one-hot labels. The class count comes from the logits
            # instead of being hard-coded to 100.
            y_onehot = F.one_hot(y, output.shape[1]).float()
            mse_sum += mse(F.softmax(output, dim=1), y_onehot).mean().item()

            n += y.shape[0]
            batch_count += 1
    if batch_count == 0:
        # Nothing to evaluate: avoid dividing by zero / concatenating nothing.
        return 0.0, 0.0, np.zeros(0, dtype=np.int32)
    return correct_sum / n, mse_sum / batch_count, np.concatenate(predict_status_arr, axis=0)


# ---------------------------------------------------------------------------
# Targeted transfer-attack evaluation (temperature-scale ablation).
#
# Each Target_*_test function crafts targeted adversarial examples on the
# surrogate ``model`` with the EFRCE loss, once per temperature scale in
# np.arange(args.Tmin, args.Tmax, args.step). It then counts how often
# ``victim_model`` predicts the target class chosen by the attack.
# ---------------------------------------------------------------------------

# Perturbation budget shared by all attacks (presumably L-inf on [0, 1]
# images -- confirm in attack/).
_EPS = 8.0 / 255
# Step size and iteration count shared by the iterative attacks.
_ALPHA = 0.8 / 255
_STEPS = 10


def _temperature_scales():
    """Return the temperature scales to sweep: ``np.arange(Tmin, Tmax, step)``.

    ``np.arange`` excludes ``Tmax`` and accumulates a floating-point step, so
    e.g. the default ``--Tmax 8.1`` is what keeps 8.0 in the sweep.
    """
    return np.arange(args.Tmin, args.Tmax, args.step)


def _targeted_attack_sweep(attack_name, make_attack, victim_model, device,
                           test_loader, scales):
    """Run one targeted attack per temperature scale and report its success rate.

    Args:
        attack_name: label used in the printed report (e.g. ``'BIM'``).
        make_attack: callable ``scale -> attack``. The attack must provide
            ``set_mode_targeted_random()`` and be callable as
            ``attack(data, target) -> (data_adv, attack_target)``.
        victim_model: model whose predictions on ``data_adv`` are scored.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches exposing
            ``.dataset``; its length is the success-rate denominator.
        scales: sequence of temperature scales to evaluate.

    Returns:
        ``(asr_list, [avg_minutes])``:

        * ``asr_list[i]`` is the targeted attack success rate, in percent,
          for ``scales[i]`` (0.0 when the dataset is empty).
        * ``avg_minutes`` is the wall-clock time per scale, in minutes
          (0.0 when ``scales`` is empty).
    """
    total = len(test_loader.dataset)
    success_rates = []
    start_time = time.perf_counter()
    for scale in scales:
        # A fresh attack per scale, in random-target mode; the attack returns
        # the target it chose alongside the adversarial batch.
        attack = make_attack(scale)
        attack.set_mode_targeted_random()
        hits = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv, attack_target = attack(data, target)
            with torch.no_grad():
                output = victim_model(data_adv)
            pred = output.max(1, keepdim=True)[1]
            # Success only when the victim predicts exactly the attack target.
            hits += (pred.eq(attack_target.view_as(pred))).sum().item()
        # Guard an empty evaluation set instead of dividing by zero.
        asr = 100. * hits / total if total else 0.0
        print('{} EFRCE Target Test: ASR: {}/{} ({:.2f}%)'.format(
            attack_name, hits, total, asr
        ))
        success_rates.append(asr)
    elapsed = time.perf_counter() - start_time
    # Average minutes per temperature scale; guard an empty sweep (Tmin >= Tmax).
    avg_minutes = elapsed / (60 * len(scales)) if len(scales) else 0.0
    print('avg_time_spended:{}'.format(avg_minutes))

    return success_rates, [avg_minutes]


def _iterative_attack_test(attack_cls, attack_name, model, victim_model, device, test_loader):
    """Sweep an iterative attack class (eps/alpha/steps constructor) over all temperature scales."""
    def make_attack(scale):
        return attack_cls(model, eps=_EPS, alpha=_ALPHA, steps=_STEPS,
                          loss_type='EFRCE', temperature_scale=scale,
                          fuzzy_scale=args.K)

    return _targeted_attack_sweep(attack_name, make_attack, victim_model, device,
                                  test_loader, _temperature_scales())


# FGSM targeted-attack evaluation
def Target_FGSM_test(model, victim_model, device, test_loader):
    """Targeted single-step FGSM with the EFRCE loss, swept over temperature scales.

    Returns:
        ``(asr_list, [avg_minutes])``: the attack success rates (%) obtained
        with EFRCE as the loss, one per temperature scale, and the average
        minutes spent per scale.
    """
    def make_attack(scale):
        return FGSM(model, eps=_EPS, loss_type='EFRCE', temperature_scale=scale,
                    fuzzy_scale=args.K)

    return _targeted_attack_sweep('FGSM', make_attack, victim_model, device,
                                  test_loader, _temperature_scales())


# BIM targeted-attack evaluation
def Target_BIM_test(model, victim_model, device, test_loader):
    """Targeted BIM with the EFRCE loss; returns ``(asr_list, [avg_minutes])``."""
    return _iterative_attack_test(BIM, 'BIM', model, victim_model, device, test_loader)


# MIFGSM targeted-attack evaluation
def Target_MIFGSM_test(model, victim_model, device, test_loader):
    """Targeted MI-FGSM with the EFRCE loss; returns ``(asr_list, [avg_minutes])``."""
    return _iterative_attack_test(MIFGSM, 'MIFGSM', model, victim_model, device, test_loader)


# DIFGSM targeted-attack evaluation
def Target_DIFGSM_test(model, victim_model, device, test_loader):
    """Targeted DI-FGSM with the EFRCE loss; returns ``(asr_list, [avg_minutes])``."""
    return _iterative_attack_test(DIFGSM, 'DIFGSM', model, victim_model, device, test_loader)


# SINIFGSM targeted-attack evaluation
def Target_SINIFGSM_test(model, victim_model, device, test_loader):
    """Targeted SINI-FGSM with the EFRCE loss; returns ``(asr_list, [avg_minutes])``."""
    return _iterative_attack_test(SINIFGSM, 'SINIFGSM', model, victim_model, device, test_loader)


# VMIFGSM targeted-attack evaluation
def Target_VMIFGSM_test(model, victim_model, device, test_loader):
    """Targeted VMI-FGSM with the EFRCE loss; returns ``(asr_list, [avg_minutes])``."""
    return _iterative_attack_test(VMIFGSM, 'VMIFGSM', model, victim_model, device, test_loader)


# Supported architectures: name -> (constructor, checkpoint file under --model-dir).
_MODEL_REGISTRY = {
    'vgg16': (lambda: vgg16(0.5), 'vgg16_checkpoint.pth'),
    'resnet50': (lambda: resnet50(0.5), 'resnet50_checkpoint.pth'),
    'resnext50': (lambda: resnext50(), 'resnext50_model.pth'),
    'wrn-16-4': (lambda: wideresnet(depth=16, widen_factor=4), 'wrn-16-4_checkpoint.pth'),
    'densenet121': (lambda: densenet121(), 'densenet121_model.pth'),
    'mobilenetv2': (lambda: mobilenetv2(), 'mobilenetv2_model.pth'),
}

# Directory (relative to the working directory) the ablation results are saved to.
_RESULTS_DIR = 'saved_model'


def _load_model(model_type):
    """Build ``model_type``, load its checkpoint from ``args.model_dir`` onto ``device``, set eval mode.

    Raises:
        ValueError: if ``model_type`` is not a key of ``_MODEL_REGISTRY``.
    """
    if model_type not in _MODEL_REGISTRY:
        raise ValueError('unknown model type {!r}; choose one of: {}'.format(
            model_type, ', '.join(sorted(_MODEL_REGISTRY))))
    build, checkpoint = _MODEL_REGISTRY[model_type]
    # .to(device) instead of .cuda(device) so --no-cuda / CPU-only hosts work.
    net = build().to(device)
    net.load_state_dict(torch.load(os.path.join(args.model_dir, checkpoint), map_location=device))
    net.eval()
    return net


def _predict_status(model_type, net, label):
    """Return a boolean per-test-sample array, True where ``net`` classified the sample correctly.

    The raw 0/1 array is cached in ``<model_dir>/<model_type>_predict_status_arr.txt``.
    If that file is missing, the array is computed on the test set, the
    accuracy is printed after ``label``, and the file is written.
    """
    path = os.path.join(args.model_dir, model_type + '_predict_status_arr.txt')
    if os.path.exists(path):
        status = np.loadtxt(path)
    else:
        loader = get_test_loaders(args.data_dir, args.test_batch_size)
        acc, _, status = evaluate_accuracy(loader, net)
        print(label, acc)
        np.savetxt(path, status)
    return status != 0


def test_main():
    """Run every targeted-attack temperature-scale ablation and save the results.

    Steps:

    1. Load the surrogate (``--model-type``) and victim (``--victim-model-type``)
       models. The same instance is reused when both names match.
    2. Load (or compute and cache) each model's per-sample correctness on the
       test set.
    3. Build a loader over the ``Cifar100`` subset selected by these two masks.
    4. Run the FGSM, BIM, MIFGSM, DIFGSM, SINIFGSM and VMIFGSM sweeps, saving
       each ASR list and timing list as ``.npy`` files under ``saved_model/``.

    Raises:
        ValueError: if either model type is unknown.
    """
    # Surrogate model, used to craft the adversarial examples.
    model = _load_model(args.model_type)
    # Victim model; reuse the surrogate instance when the architectures match.
    if args.victim_model_type == args.model_type:
        victim_model = model
    else:
        victim_model = _load_model(args.victim_model_type)

    model_predict_status_arr = _predict_status(args.model_type, model, 'acc: ')
    victim_model_predict_status_arr = _predict_status(args.victim_model_type, victim_model, 'victim acc: ')

    # Test samples that both the surrogate and the victim classify correctly.
    testset = Cifar100(model_predict_status_arr, victim_model_predict_status_arr, args.data_dir)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False,
                                              pin_memory=use_cuda, num_workers=0)

    # Create the output directory up front; otherwise the first np.save fails
    # after a full sweep has already run.
    os.makedirs(_RESULTS_DIR, exist_ok=True)
    prefix = _RESULTS_DIR + '/' + args.model_type + '2' + args.victim_model_type
    # (file tag, sweep) pairs in execution order; BIM results use the tag "IFGSM".
    sweeps = (
        ('FGSM', Target_FGSM_test),
        ('IFGSM', Target_BIM_test),
        ('MIFGSM', Target_MIFGSM_test),
        ('DIFGSM', Target_DIFGSM_test),
        ('SINIFGSM', Target_SINIFGSM_test),
        ('VMIFGSM', Target_VMIFGSM_test),
    )
    for tag, run_sweep in sweeps:
        accs, times = run_sweep(model, victim_model, device, test_loader)
        np.save(prefix + '_Accs_EFRCE_' + tag + '_ablation_temperature_scale.npy', accs)
        np.save(prefix + '_time_EFRCE_' + tag + '_ablation_temperature_scale.npy', times)


if __name__ == "__main__":
    # Script entry point: run the full ablation using the parsed CLI options.
    test_main()
