from __future__ import print_function

import sys
sys.path.append("/home/yangxiangyuan/yxy/Paper4_ACE")

import torch
import numpy as np
import argparse
import os
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import timm

from attack.fgsm import FGSM
from attack.bim import BIM
from attack.mifgsm import MIFGSM
from attack.difgsm import DIFGSM
from attack.sinifgsm import SINIFGSM
from attack.vmifgsm import VMIFGSM

from ImageNet.utils import Normalize
from ImageNet.Selected_Imagenet import SelectedImagenet
from ImageNet.Selected_Imagenet_to_Attack import SelectedImagenet2Attack

parser = argparse.ArgumentParser(description='PyTorch ImageNet Test')
parser.add_argument('--model-type', default='ens_adv_inception_resnet_v2', type=str,
                    help='surrogate model used to craft the adversarial examples')
parser.add_argument('--victim-model-type', default='adv_inception_v3', type=str,
                    help='victim model the adversarial examples are evaluated on')
parser.add_argument('--test-batch-size', type=int, default=24, metavar='N',
                    help='input batch size for testing (default: %(default)s)')
parser.add_argument('--data-dir', default='../data/imagenet-selected', type=str,
                    help='dataset directory passed to SelectedImagenet / SelectedImagenet2Attack')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--gpu', default=0, type=int,
                    help='CUDA device index used when CUDA is enabled')
parser.add_argument('--Tmin', default=1.0, type=float,
                    help='first EFRCE temperature scale of the sweep (inclusive)')
parser.add_argument('--Tmax', default=8.1, type=float,
                    help='end of the temperature-scale sweep (exclusive, as in np.arange)')
parser.add_argument('--step', default=0.5, type=float,
                    help='spacing between swept temperature scales')
parser.add_argument('--K', default=1.0, type=float,
                    help='fuzzy_scale passed to every attack')
parser.add_argument('--model-dir', default='./saved_model',
                    help='directory holding model checkpoints and cached prediction-status files')

args = parser.parse_args()

# Echo the full configuration so every run's log records its settings.
for arg_name, arg_value in vars(args).items():
    print(arg_name, ':', arg_value)

# settings
model_dir = args.model_dir
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(model_dir, exist_ok=True)
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda:" + str(args.gpu) if use_cuda else "cpu")
# DataLoader keyword arguments; memory pinning only makes sense with CUDA.
kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

def evaluate_accuracy(data_iter, net, device=None):
    """Evaluate ``net`` on ``data_iter`` and record per-sample correctness.

    Args:
        data_iter: iterable of ``(X, y)`` batches; ``y`` holds integer class
            labels (they are fed to ``F.one_hot``).
        net: classifier returning logits; assumed shape ``(batch, num_classes)``
            given the ``argmax(dim=1)`` / ``softmax(dim=1)`` usage below.
        device: device the batches are moved to. When ``None`` and ``net`` is
            an ``nn.Module``, the device of its first parameter is used.

    Returns:
        ``(accuracy, mean_loss, status)`` where ``accuracy`` is the fraction of
        correctly classified samples, ``mean_loss`` is the per-batch mean MSE
        between softmax probabilities and one-hot labels averaged over batches,
        and ``status`` is an int array (1 = correct, 0 = wrong) in iteration
        order.

    Raises:
        ValueError: if ``data_iter`` yields no batches.
    """
    if device is None and isinstance(net, torch.nn.Module):
        # Default to wherever the model lives; next() avoids materialising the
        # whole parameter list just to read one device.
        device = next(net.parameters()).device
    mse = torch.nn.MSELoss(reduction='none')
    correct_total, n = 0.0, 0
    loss_total, batch_count = 0.0, 0
    status_chunks = []
    with torch.no_grad():
        for X, y in data_iter:
            X = X.to(device)
            y = y.to(device)
            output = net(X)
            hits = output.argmax(dim=1) == y
            correct_total += hits.float().sum().item()
            status_chunks.append(hits.int().cpu().numpy())
            # Generalisation-error proxy: MSE between predicted probabilities
            # and one-hot labels. Use the model's own class count instead of a
            # hard-coded 1000 so the shapes always match.
            y_onehot = F.one_hot(y, output.shape[1]).float()
            loss_total += mse(F.softmax(output, dim=1), y_onehot).mean().item()

            n += y.shape[0]
            batch_count += 1
    if batch_count == 0:
        raise ValueError('evaluate_accuracy: data_iter yielded no batches')
    status = np.concatenate(status_chunks, axis=0)
    return correct_total / n, loss_total / batch_count, status

# ---------------------------------------------------------------------------
# Targeted-attack ablation over the EFRCE temperature scale.
#
# Each Target_*_test function crafts adversarial examples on ``model`` (the
# surrogate) with a randomly chosen target class (set_mode_targeted_random),
# evaluates them on ``victim_model`` and reports the targeted attack success
# rate (ASR) for every scale in np.arange(args.Tmin, args.Tmax, args.step).
# ---------------------------------------------------------------------------

# Perturbation budget shared by all attacks (presumably L_inf on [0, 1]
# inputs -- confirm against the attack implementations).
_EPS = 8.0 / 255
# Step size and number of iterations shared by the iterative attacks.
_ALPHA = 0.8 / 255
_STEPS = 10


def _efrce_scales():
    """Return the swept temperature scales: np.arange(args.Tmin, args.Tmax, args.step)."""
    return np.arange(args.Tmin, args.Tmax, args.step)


def _targeted_efrce_sweep(attack_name, build_attack, victim_model, device, test_loader, scales):
    """Run a targeted attack once per temperature scale and report its ASR.

    Args:
        attack_name: label used in the progress message (e.g. ``'FGSM'``).
        build_attack: callable ``scale -> attack``. The attack must provide
            ``set_mode_targeted_random()`` and be callable as
            ``attack(data, target) -> (data_adv, attack_target)``.
        victim_model: model the adversarial examples are evaluated on.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches that exposes
            ``.dataset``; its length is the ASR denominator.
        scales: sequence of temperature scales (list or 1-D ndarray).

    Returns:
        ``(asr_per_scale, [avg_minutes])``: ASR percentages, one per scale,
        and the wall-clock minutes averaged over scales (``0.0`` when
        ``scales`` is empty).

    Raises:
        ValueError: if ``test_loader.dataset`` is empty.
    """
    total = len(test_loader.dataset)
    if total == 0:
        raise ValueError('{}: test_loader.dataset is empty; cannot compute ASR'.format(attack_name))
    # len(), not truthiness: scales is usually an ndarray.
    n_scales = len(scales)
    asr_per_scale = []
    start_time = time.time()
    for scale in scales:
        attack = build_attack(scale)
        attack.set_mode_targeted_random()
        hits = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data_adv, attack_target = attack(data, target)
            with torch.no_grad():
                pred = victim_model(data_adv).argmax(dim=1)
            # Success = the victim predicts the target class chosen by the attack.
            hits += pred.eq(attack_target.view_as(pred)).sum().item()
        asr = 100. * hits / total
        print('{} EFRCE Target Test: ASR: {}/{} ({:.2f}%)'.format(attack_name, hits, total, asr))
        asr_per_scale.append(asr)
    elapsed = time.time() - start_time
    avg_minutes = elapsed / (60 * n_scales) if n_scales else 0.0
    print('avg_time_spended:{}'.format(avg_minutes))
    return asr_per_scale, [avg_minutes]


def _iterative_sweep(attack_name, attack_cls, model, victim_model, device, test_loader):
    """Sweep an iterative attack class built with the shared eps/alpha/steps budget."""
    return _targeted_efrce_sweep(
        attack_name,
        lambda scale: attack_cls(model, eps=_EPS, alpha=_ALPHA, steps=_STEPS, loss_type='EFRCE',
                                 temperature_scale=scale, fuzzy_scale=args.K),
        victim_model, device, test_loader, _efrce_scales())


def Target_FGSM_test(model, victim_model, device, test_loader):
    """Targeted FGSM (EFRCE loss) sweep.

    return: (list of ASR percentages, one per temperature scale, [average minutes per scale])
    """
    return _targeted_efrce_sweep(
        'FGSM',
        lambda scale: FGSM(model, eps=_EPS, loss_type='EFRCE',
                           temperature_scale=scale, fuzzy_scale=args.K),
        victim_model, device, test_loader, _efrce_scales())


def Target_BIM_test(model, victim_model, device, test_loader):
    """Targeted BIM / I-FGSM (EFRCE loss) sweep; same return shape as Target_FGSM_test."""
    return _iterative_sweep('BIM', BIM, model, victim_model, device, test_loader)


def Target_MIFGSM_test(model, victim_model, device, test_loader):
    """Targeted MI-FGSM (EFRCE loss) sweep; same return shape as Target_FGSM_test."""
    return _iterative_sweep('MIFGSM', MIFGSM, model, victim_model, device, test_loader)


def Target_DIFGSM_test(model, victim_model, device, test_loader):
    """Targeted DI-FGSM (EFRCE loss) sweep; same return shape as Target_FGSM_test."""
    return _iterative_sweep('DIFGSM', DIFGSM, model, victim_model, device, test_loader)


def Target_SINIFGSM_test(model, victim_model, device, test_loader):
    """Targeted SINI-FGSM (EFRCE loss) sweep; same return shape as Target_FGSM_test."""
    return _iterative_sweep('SINIFGSM', SINIFGSM, model, victim_model, device, test_loader)


def Target_VMIFGSM_test(model, victim_model, device, test_loader):
    """Targeted VMI-FGSM (EFRCE loss) sweep; same return shape as Target_FGSM_test."""
    return _iterative_sweep('VMIFGSM', VMIFGSM, model, victim_model, device, test_loader)

# torchvision architectures: CLI name -> (torchvision constructor name,
# surrogate checkpoint file, victim checkpoint file); files live in args.model_dir.
# NOTE(review): the inceptionv3 surrogate loads 'inception_v3_checkpoint1.pth'
# while the victim loads 'inception_v3_checkpoint.pth' -- kept as in the
# original; confirm the difference is intentional.
_TORCHVISION_ARCHS = {
    'vgg16': ('vgg16', 'vgg16_checkpoint.pth', 'vgg16_checkpoint.pth'),
    'vgg19': ('vgg19', 'vgg19_checkpoint.pth', 'vgg19_checkpoint.pth'),
    'resnet50': ('resnet50', 'resnet50_checkpoint.pth', 'resnet50_checkpoint.pth'),
    'resnet152': ('resnet152', 'resnet152_checkpoint.pth', 'resnet152_checkpoint.pth'),
    'inceptionv3': ('inception_v3', 'inception_v3_checkpoint1.pth', 'inception_v3_checkpoint.pth'),
    'mobilenetv2': ('mobilenet_v2', 'mobilenet_v2_checkpoint.pth', 'mobilenet_v2_checkpoint.pth'),
}
# Models created through timm with pretrained weights.
_TIMM_ARCHS = ('adv_inception_v3', 'ens_adv_inception_resnet_v2')
# Directory the ablation results (.npy) are written to.
_RESULTS_DIR = 'saved_model'


def _build_classifier(model_type, victim=False):
    """Build ``model_type`` wrapped as ``nn.Sequential(Normalize(), net)``, on ``device``, in eval mode.

    ``victim`` selects the victim checkpoint file for torchvision models.

    Raises:
        ValueError: if ``model_type`` is not a known architecture.
    """
    if model_type in _TORCHVISION_ARCHS:
        arch, surrogate_ckpt, victim_ckpt = _TORCHVISION_ARCHS[model_type]
        net = getattr(models, arch)()
        ckpt_name = victim_ckpt if victim else surrogate_ckpt
        # map_location='cpu': the module is still on CPU here, and this lets
        # GPU-saved checkpoints load on CPU-only hosts.
        state_dict = torch.load(os.path.join(args.model_dir, ckpt_name), map_location='cpu')
        net.load_state_dict(state_dict)
    elif model_type in _TIMM_ARCHS:
        net = timm.create_model(model_type, pretrained=True)
    else:
        known = sorted(list(_TORCHVISION_ARCHS) + list(_TIMM_ARCHS))
        raise ValueError('unknown model type {!r}; expected one of {}'.format(model_type, known))
    # Normalize (from ImageNet.utils) runs first, presumably so callers feed
    # un-normalised inputs -- confirm against its implementation.
    net = nn.Sequential(Normalize(), net)
    net.to(device)
    net.eval()
    return net


def _predict_status(model_type, net, acc_label):
    """Return a bool array marking the samples ``net`` classifies correctly.

    The raw status is cached in ``<model_dir>/<model_type>_predict_status_arr.txt``.
    On a cache miss it is computed over SelectedImagenet(args.data_dir), the
    accuracy is printed after ``acc_label``, and the cache file is written.
    """
    path = os.path.join(args.model_dir, model_type + '_predict_status_arr.txt')
    if os.path.exists(path):
        status = np.loadtxt(path)
    else:
        testset = SelectedImagenet(args.data_dir)
        loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False,
                                             pin_memory=True, num_workers=0)
        acc, _, status = evaluate_accuracy(loader, net)
        print(acc_label, acc)
        np.savetxt(path, status)
    return status != 0


def test_main():
    """Run the EFRCE temperature-scale ablation for all six targeted attacks.

    Loads the surrogate (args.model_type) and the victim (args.victim_model_type),
    restricts the test set to samples both classify correctly, runs every
    attack sweep, and saves its ASR list and average timing to
    ``saved_model/<model>2<victim>_{Accs,time}_EFRCE_<TAG>_ablation_temperature_scale.npy``.
    """
    # Surrogate model used to craft the adversarial examples.
    model = _build_classifier(args.model_type)
    # Victim model; reuse the surrogate instance when the types match.
    if args.victim_model_type == args.model_type:
        victim_model = model
    else:
        victim_model = _build_classifier(args.victim_model_type, victim=True)

    model_status = _predict_status(args.model_type, model, 'acc: ')
    victim_status = _predict_status(args.victim_model_type, victim_model, 'victim acc: ')

    # Keep only the samples that both the surrogate and the victim classify correctly.
    testset = SelectedImagenet2Attack(model_status, victim_status, args.data_dir)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False,
                                              pin_memory=True, num_workers=0)

    # Create the results directory before the (long) sweeps; it previously
    # existed only when --model-dir happened to point at it.
    os.makedirs(_RESULTS_DIR, exist_ok=True)
    prefix = os.path.join(_RESULTS_DIR, args.model_type + "2" + args.victim_model_type)
    # (file tag, sweep runner); BIM results are stored under the tag 'IFGSM'.
    attacks = (
        ('FGSM', Target_FGSM_test),
        ('IFGSM', Target_BIM_test),
        ('MIFGSM', Target_MIFGSM_test),
        ('DIFGSM', Target_DIFGSM_test),
        ('SINIFGSM', Target_SINIFGSM_test),
        ('VMIFGSM', Target_VMIFGSM_test),
    )
    for tag, run_attack in attacks:
        accs, elapsed = run_attack(model, victim_model, device, test_loader)
        np.save(prefix + "_Accs_EFRCE_" + tag + "_ablation_temperature_scale.npy", accs)
        np.save(prefix + "_time_EFRCE_" + tag + "_ablation_temperature_scale.npy", elapsed)


if __name__ == '__main__':
    test_main()
