'''
    The simulation program for Federated Learning.
'''
import mkl
mkl.set_num_threads(32)
import argparse
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split
from networks import MultiLayerPerceptron, ConvNet, ResNet18
from data import gen_infimnist, MyDataset, MalDataset, HeteroDataset
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch import nn, optim, hub
from attack import mal_single, attack_trimmedmean, attack_krum
from robust_estimator import krum, filterL2, trimmed_mean, bulyan
import random
from backdoor import backdoor
from torchvision import utils as vutils
from tqdm import tqdm
import copy

FEATURE_TEMPLATE = '../data/infimnist_%s_feature_%d_%d.npy'
TARGET_TEMPLATE = '../data/infimnist_%s_target_%d_%d.npy'

MAL_FEATURE_TEMPLATE = '../data/infimnist_%s_mal_feature_%d_%d.npy'
MAL_TARGET_TEMPLATE = '../data/infimnist_%s_mal_target_%d_%d.npy'
MAL_TRUE_LABEL_TEMPLATE = '../data/infimnist_%s_mal_true_label_%d_%d.npy'

CIFAR_MAL_FEATURE_TEMPLATE = '../data/cifar_mal_feature_10.npy'
CIFAR_MAL_TARGET_TEMPLATE = '../data/cifar_mal_target_10.npy'
CIFAR_MAL_TRUE_LABEL_TEMPLATE = '../data/cifar_mal_true_label_10.npy'

FASHION_MAL_FEATURE_TEMPLATE = '../data/fashion_mal_feature_10.npy'
FASHION_MAL_TARGET_TEMPLATE = '../data/fashion_mal_target_10.npy'
FASHION_MAL_TRUE_LABEL_TEMPLATE = '../data/fashion_mal_true_label_10.npy'

CH_MAL_FEATURE_TEMPLATE = '../data/chmnist_mal_feature_10.npy'
CH_MAL_TARGET_TEMPLATE = '../data/chmnist_mal_target_10.npy'
CH_MAL_TRUE_LABEL_TEMPLATE = '../data/chmnist_mal_true_label_10.npy'

def discretelize(local_grads, splitnum=1024):
    """Quantize every parameter onto a uniform grid shared by all clients.

    For each parameter index ``j`` the arrays ``local_grads[c][j]`` of all
    clients ``c`` are pooled.  Their joint minimum ``low`` and range ``span``
    define a grid of ``splitnum`` equal steps, and every entry ``x`` becomes
    ``low + floor((x - low) * splitnum / span) * span / splitnum``, i.e. it is
    rounded down onto that grid.

    Args:
        local_grads: Sequence with one item per client; each item is a
            sequence of numpy arrays, and entry ``j`` must have the same shape
            for every client.
        splitnum: Number of grid steps between the pooled min and max.

    Returns:
        A new list (one list per client) of quantized arrays.  The input is
        not modified.  When all values of a parameter are identical (zero
        range) the values are returned unchanged instead of turning into NaN.
    """
    if len(local_grads) == 0:
        return []
    # Every entry is overwritten below, so shallow copies of the per-client
    # lists suffice (no deep copy of all gradient arrays).
    quantized = [list(client) for client in local_grads]
    for j in range(len(local_grads[0])):
        # Stack parameter j of all clients: shape (num_clients, *param_shape).
        # np.array copies, so rows never alias the caller's arrays.
        stacked = np.array([client[j] for client in local_grads])
        low = np.min(stacked)
        span = np.max(stacked) - low
        if span == 0:
            # All clients agree exactly: the grid degenerates and the formula
            # below would compute 0 // 0 (NaN).  Quantization is the identity.
            levels = stacked
        else:
            levels = (stacked - low) * splitnum // span
            levels = levels * span / splitnum + low
        for c, row in enumerate(levels):
            quantized[c][j] = row
    return quantized

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='1')
    parser.add_argument('--dataset', default='INFIMNIST',
                        choices=['INFIMNIST', 'CIFAR10', 'Fashion-MNIST', 'CH-MNIST'])
    parser.add_argument('--nworker', type=int, default=100)
    parser.add_argument('--perround', type=int, default=100)
    parser.add_argument('--localiter', type=int, default=1)
    parser.add_argument('--epoch', type=int, default=30)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batchsize', type=int, default=10)
    parser.add_argument('--checkpoint', type=int, default=1)
    parser.add_argument('--sigma2', type=float, default=1e-5)

    # Malicious agent setting
    parser.add_argument('--mal', action='store_true')
    parser.add_argument('--mal_num', type=int, default=10)
    # Parsed as ints (e.g. `--mal_index 0 1 2`) so the `c in args.mal_index`
    # membership tests below work; a raw string would raise TypeError there.
    parser.add_argument('--mal_index', type=int, nargs='+', default=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    parser.add_argument('--mal_boost', type=float, default=5.0)
    # Only these aggregation rules are implemented below; any other value
    # would leave the global model un-updated without any error.
    parser.add_argument('--agg', default='filterl2',
                        choices=['average', 'krum', 'filterl2', 'trimmedmean', 'bulyankrum', 'bulyantrim'])
    parser.add_argument('--attack', default='trimmedmean')
    parser.add_argument('--shard', type=int, default=25)
    parser.add_argument('--plot', type=str, default='_')
    parser.add_argument('--bulyan', default='krum')  # kept for CLI compatibility; not read below
    parser.add_argument('--DBA_scale', type=float, default=100)
    parser.add_argument('--DBA_localiter', type=int, default=1)
    parser.add_argument('--DBA_locallr', type=float, default=1)
    parser.add_argument('--dist', type=str, default='homo')
    args = parser.parse_args()

    # Reject configurations that would otherwise crash part-way through training.
    if args.shard <= 0 or args.nworker % args.shard != 0:
        # The sharding step reshapes the NWORKER client updates into --shard equal groups.
        parser.error('--nworker (%d) must be divisible by a positive --shard (%d)' % (args.nworker, args.shard))
    if args.dist != 'homo' and args.dataset != 'INFIMNIST':
        # Per-worker (heterogeneous) loaders are only built for INFIMNIST below.
        parser.error('--dist other than "homo" is only supported for INFIMNIST')

    DEVICE = "cuda:" + args.device
    device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
    # NOTE: models and tensors are moved with `.cuda()` throughout, so a GPU is required.
    torch.cuda.set_device(DEVICE)
    DATASET = args.dataset
    NWORKER = args.nworker
    PERROUND = args.perround
    LOCALITER = args.localiter
    EPOCH = args.epoch
    LEARNING_RATE = args.lr
    BATCH_SIZE = args.batchsize
    params = {'batch_size': BATCH_SIZE, 'shuffle': True}
    CHECK_POINT = args.checkpoint
    SIGMA2 = args.sigma2

    if DATASET == 'INFIMNIST':

        transform = torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                         (0.1307,), (0.3081,))])
        if args.dist == 'homo':
            train_set = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)
            train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
        else:
            train_loaders = []
            for i in range(NWORKER):
                train_loaders.append(DataLoader(HeteroDataset(i, transform), batch_size=BATCH_SIZE, shuffle=True))

        test_loader = DataLoader(torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform))
        mal_train_loaders = DataLoader(MalDataset(MAL_FEATURE_TEMPLATE%('train',60000,60010), MAL_TRUE_LABEL_TEMPLATE%('train',60000,60010), MAL_TARGET_TEMPLATE%('train',60000,60010), transform=transform), batch_size=BATCH_SIZE)

        network = ConvNet(input_size=28, input_channel=1, classes=10, filters1=30, filters2=30, fc_size=200).cuda()
        backdoor_network = ConvNet(input_size=28, input_channel=1, classes=10, filters1=30, filters2=30, fc_size=200).cuda()

    elif DATASET == 'CIFAR10':

        transform = torchvision.transforms.Compose([
                                         torchvision.transforms.ToTensor(),
                                         torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                                         ])

        train_set = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)
        train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform))
        # NOTE: no malicious dataset (mal_train_loaders) is loaded for CIFAR10,
        # so `--attack modelpoisoning` is not available for this dataset.

        network = ResNet18().cuda()
        backdoor_network = ResNet18().cuda()

    elif DATASET == 'Fashion-MNIST':

        train_set = torchvision.datasets.FashionMNIST(root = "../data", train = True, download = True, transform = torchvision.transforms.ToTensor())
        train_loader = DataLoader(train_set, batch_size=BATCH_SIZE)
        test_loader = DataLoader(torchvision.datasets.FashionMNIST(root = "../data", train = False, download = True, transform = torchvision.transforms.ToTensor()))
        mal_train_loaders = DataLoader(MalDataset(FASHION_MAL_FEATURE_TEMPLATE, FASHION_MAL_TRUE_LABEL_TEMPLATE, FASHION_MAL_TARGET_TEMPLATE, transform=torchvision.transforms.ToTensor()), batch_size=BATCH_SIZE)

        network = ConvNet(input_size=28, input_channel=1, classes=10, filters1=30, filters2=30, fc_size=200).cuda()
        backdoor_network = ConvNet(input_size=28, input_channel=1, classes=10, filters1=30, filters2=30, fc_size=200).cuda()

    elif DATASET == 'CH-MNIST':

        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

        train_set = MyDataset("../data/CHMNIST_TRAIN_FEATURE.npy", "../data/CHMNIST_TRAIN_TARGET.npy", transform=transform)
        train_loader = DataLoader(train_set, batch_size=BATCH_SIZE)
        test_loader = DataLoader(MyDataset("../data/CHMNIST_TEST_FEATURE.npy", "../data/CHMNIST_TEST_TARGET.npy", transform=transform), batch_size=BATCH_SIZE)
        mal_train_loaders = DataLoader(MalDataset(CH_MAL_FEATURE_TEMPLATE, CH_MAL_TRUE_LABEL_TEMPLATE, CH_MAL_TARGET_TEMPLATE, transform=transform), batch_size=BATCH_SIZE)

        network = ResNet18().cuda()
        backdoor_network = ResNet18().cuda()

    # Split into multiple training sets: NWORKER equal random parts, with
    # worker 0 also receiving the remainder.
    if args.dist == 'homo':
        TRAIN_SIZE = len(train_set) // NWORKER
        sizes = [TRAIN_SIZE] * NWORKER
        sizes[0] += len(train_set) - TRAIN_SIZE * NWORKER
        train_sets = random_split(train_set, sizes)
        train_loaders = [DataLoader(trainset, **params) for trainset in train_sets]

    # Pixel value used to stamp the DBA trigger: 1.0 after the MNIST
    # normalization for INFIMNIST, raw 1.0 otherwise.
    if DATASET == 'INFIMNIST':
        inplace_tensor = (1.0 - 0.1307) / 0.3081
    else:
        inplace_tensor = 1.0

    # define training loss
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(network.parameters(), lr=LEARNING_RATE, weight_decay=0.0005)
    # prepare data structures to store local gradients (one list of arrays per worker)
    local_grads = [[np.zeros(p.data.shape) for p in network.parameters()] for _ in range(NWORKER)]

    if len(args.mal_index) != args.mal_num:
        args.mal_index = np.arange(args.mal_num).tolist()
    print(args.mal_index)

    # Distributed backdoor attack (DBA) data: each malicious client stamps one
    # 3x10 piece of the trigger (corner chosen by c % 4) on every sample not of
    # class 2 and relabels it to 2.  The evaluation below stamps all four pieces.
    # Entries are stored as [(feature, target)] because the DBA training loop
    # unpacks them that way.
    DBA_train_loaders = {}
    if args.mal and args.attack == 'DBA':
        dba_offsets = [(2, 0), (2, 17), (25, 0), (25, 17)]
        for c in args.mal_index:
            index_i, index_j = dba_offsets[c % 4]
            DBA_train_loaders[c] = []
            for feature, target in train_loaders[c]:
                for k in range(feature.shape[0]):
                    if target[k] != 2:
                        feature[k] = TF.erase(feature[k], index_i, index_j, 3, 10, inplace_tensor)
                        target[k] = 2
                DBA_train_loaders[c].append([(feature, target)])

    # Results file; its name depends only on the command-line arguments.
    plot_attack = args.attack if args.mal else 'noattack'
    text_file_name = '../results/' + args.dataset + '_' + plot_attack + '_' + args.agg + '_' + args.dist + '_' + str(args.shard) + '.txt'

    # store malicious round
    mal_visible = []
    print(args.mal, args.mal_index, args.attack)
    for epoch in range(EPOCH):
        mal_active = 0
        # select workers per subset
        print("Epoch: ", epoch)
        choices = np.random.choice(NWORKER, PERROUND, replace=False)
        # Malicious clients for this round; only the model-poisoning attack uses
        # this random set (backdoor and DBA use the fixed args.mal_index).
        Mal_index = np.random.choice(choices, args.mal_num, replace=False)

        # copy network parameters
        params_copy = []
        for p in list(network.parameters()):
            params_copy.append(p.clone())
        for c in tqdm(choices):
            if args.mal and c in Mal_index and args.attack == 'modelpoisoning':
                for idx, p in enumerate(local_grads[c]):
                    local_grads[c][idx] = np.zeros(p.shape)

                for iepoch in range(0, LOCALITER):
                    params_temp = []

                    for p in list(network.parameters()):
                        params_temp.append(p.clone().detach())

                    delta_mal = mal_single(mal_train_loaders, train_loaders[c], network, criterion, optimizer, params_temp, device, mal_visible, epoch, dist=True, mal_boost=args.mal_boost, path=args.agg)

                    for idx, p in enumerate(local_grads[c]):
                        local_grads[c][idx] = p + delta_mal[idx]

                mal_active = 1

            elif args.mal and c in args.mal_index and args.attack == 'backdoor':
                # Train on local data with a 5x5 zero patch at the top-left corner
                # and every label forced to class 0; the update is boosted by 10.
                for idx, p in enumerate(local_grads[c]):
                    local_grads[c][idx] = np.zeros(p.shape)
                optimizer_backdoor = optim.SGD(network.parameters(), lr=LEARNING_RATE * 2)
                for iepoch in range(0, LOCALITER):
                    for idx, (feature, target) in enumerate(train_loaders[c], 0):
                        attack_feature = (TF.erase(feature, 0, 0, 5, 5, 0).cuda())
                        attack_target = torch.zeros(target.shape, dtype=torch.long).cuda()
                        optimizer_backdoor.zero_grad()
                        output = network(attack_feature)
                        loss = criterion(output, attack_target)
                        loss.backward()
                        optimizer_backdoor.step()
                for idx, p in enumerate(network.parameters()):
                    local_grads[c][idx] = (params_copy[idx].data.cpu().numpy() - p.data.cpu().numpy()) * 10.0

            elif args.mal and c in args.mal_index and args.attack == 'DBA' and epoch >= 5:
                # Train on this client's pre-poisoned batches (built before the
                # loop) and scale the resulting update by --DBA_scale.
                optimizer_DBA = optim.SGD(network.parameters(), lr=LEARNING_RATE * args.DBA_locallr)
                for idx, p in enumerate(local_grads[c]):
                    local_grads[c][idx] = np.zeros(p.shape)

                for iepoch in range(0, LOCALITER*args.DBA_localiter):
                    for [(feature, target)] in DBA_train_loaders[c]:
                        feature = feature.cuda()
                        target = target.type(torch.long).cuda()
                        optimizer_DBA.zero_grad()
                        output = network(feature)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer_DBA.step()
                for idx, p in enumerate(network.parameters()):
                    local_grads[c][idx] = (params_copy[idx].data.cpu().numpy() - p.data.cpu().numpy()) * args.DBA_scale
            else:
                for iepoch in range(0, LOCALITER):
                    for idx, (feature, target) in enumerate(train_loaders[c], 0):
                        feature = feature.cuda()
                        target = target.type(torch.long).cuda()
                        optimizer.zero_grad()
                        output = network(feature)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()

                # compute the difference
                for idx, p in enumerate(network.parameters()):
                    local_grads[c][idx] = params_copy[idx].data.cpu().numpy() - p.data.cpu().numpy()

            # manually restore the parameters of the global network
            with torch.no_grad():
                for idx, p in enumerate(list(network.parameters())):
                    p.copy_(params_copy[idx])

        if args.mal and mal_active and args.attack == 'modelpoisoning':
            average_grad = []
            for p in list(network.parameters()):
                average_grad.append(np.zeros(p.data.shape))
            for c in choices:
                # Exclude this round's poisoning clients (Mal_index), i.e. the
                # same set that ran the model-poisoning branch above.
                if c not in Mal_index:
                    for idx, p in enumerate(average_grad):
                        average_grad[idx] = p + local_grads[c][idx] / PERROUND
            np.save('../checkpoints/' + args.agg + 'ben_delta_t%s.npy' % epoch, average_grad)
            mal_visible.append(epoch)
            mal_active = 0

        elif args.mal and args.attack == 'trimmedmean':
            print('attack trimmedmean')

            local_grads = attack_trimmedmean(network, local_grads, args.mal_index, b=1.5)

        elif args.mal and args.attack == 'krum':
            print('attack krum')

            for idx, _ in enumerate(local_grads[0]):
                local_grads = attack_krum(network, local_grads, args.mal_index, idx)

        # Sharded aggregation: quantize all client updates, randomly partition the
        # NWORKER clients into args.shard equal groups and average within each group.
        # TODO: optionally replace the plain per-shard average with secure aggregation.
        local_grads = discretelize(local_grads)
        shard_grads = []
        index = np.arange(len(local_grads))
        np.random.shuffle(index)
        index = index.reshape((args.shard, -1))
        for i in range(index.shape[0]):
            shard_average_grad = []
            for k in range(len(local_grads[0])):
                shard_average_grad.append(np.zeros(local_grads[0][k].shape))
                for j in range(index.shape[1]):
                    shard_average_grad[k] += local_grads[index[i][j]][k]
                shard_average_grad[k] /= float(index.shape[1])
            shard_grads.append(shard_average_grad)

        print(len(shard_grads))

        # aggregation
        average_grad = []
        for p in list(network.parameters()):
            average_grad.append(np.zeros(p.data.shape))
        if args.agg == 'average':
            print('agg: average')
            for shard in range(args.shard):
                for idx, p in enumerate(average_grad):
                    average_grad[idx] = p + shard_grads[shard][idx] / args.shard
        elif args.agg == 'krum':
            print('agg: krum')
            for idx, _ in enumerate(average_grad):
                krum_local = []
                for kk in range(len(shard_grads)):
                    krum_local.append(shard_grads[kk][idx])
                average_grad[idx], _ = krum(krum_local, f=1)
        elif args.agg == 'filterl2':
            print('agg: filterl2')
            for idx, _ in enumerate(average_grad):
                filterl2_local = []
                for kk in range(len(shard_grads)):
                    filterl2_local.append(shard_grads[kk][idx])
                average_grad[idx] = filterL2(filterl2_local, sigma=SIGMA2, device=device)
        elif args.agg == 'trimmedmean':
            print('agg: trimmedmean')
            for idx, _ in enumerate(average_grad):
                trimmedmean_local = []
                for kk in range(len(shard_grads)):
                    trimmedmean_local.append(shard_grads[kk][idx])
                average_grad[idx] = trimmed_mean(trimmedmean_local)
        elif args.agg == 'bulyankrum':
            print('agg: bulyankrum')
            for idx, _ in enumerate(average_grad):
                bulyan_local = []
                for kk in range(len(shard_grads)):
                    bulyan_local.append(shard_grads[kk][idx])
                average_grad[idx] = bulyan(bulyan_local, aggsubfunc='krum')
        elif args.agg == 'bulyantrim':
            print('agg: bulyantrim')
            for idx, _ in enumerate(average_grad):
                bulyan_local = []
                for kk in range(len(shard_grads)):
                    bulyan_local.append(shard_grads[kk][idx])
                average_grad[idx] = bulyan(bulyan_local, aggsubfunc='trimmedmean')

        # Apply the aggregated update: new = old - aggregated difference.
        global_params = list(network.parameters())
        with torch.no_grad():
            for idx in range(len(global_params)):
                grad = torch.from_numpy(average_grad[idx]).cuda()
                global_params[idx].data.sub_(grad)

        adv_flag = args.mal
        # The file is reopened (append mode) and closed every epoch so results are
        # flushed to disk as training progresses.
        with open(text_file_name, 'a+') as txt_file:
            if (epoch+1) % CHECK_POINT == 0 or adv_flag:
                if adv_flag:
                    print('Test after attack')
                test_loss = 0
                correct = 0
                with torch.no_grad():
                    for feature, target in test_loader:
                        feature = feature.cuda()
                        target = target.type(torch.long).cuda()
                        output = network(feature)
                        test_loss += F.cross_entropy(output, target, reduction='sum').item()
                        pred = output.data.max(1, keepdim=True)[1]
                        correct += pred.eq(target.data.view_as(pred)).sum()
                test_loss /= len(test_loader.dataset)
                test_acc = 100. * correct / len(test_loader.dataset)
                if args.mal and args.attack in ('backdoor', 'modelpoisoning'):
                    # Leave the line open: the attack-specific section below
                    # appends its metric and the newline.
                    txt_file.write('%d, \t%f, \t%f' % (epoch, test_loss, test_acc))
                else:
                    txt_file.write('%d, \t%f, \t%f\n' % (epoch, test_loss, test_acc))

                print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset), test_acc))

            if args.attack == 'modelpoisoning' and args.mal:

                test_loss = 0
                correct = 0
                with torch.no_grad():
                    for idx, (feature, mal_data, true_label, target) in enumerate(mal_train_loaders, 0):
                        feature = feature.cuda()
                        target = target.type(torch.long).cuda()
                        output = network(feature)
                        # cross_entropy, consistent with the test-set loss above
                        # (the model is trained with CrossEntropyLoss).
                        test_loss += F.cross_entropy(output, target, reduction='sum').item()
                        pred = output.data.max(1, keepdim=True)[1]
                        correct += pred.eq(target.data.view_as(pred)).sum()
                test_loss /= len(mal_train_loaders.dataset)
                txt_file.write('malicious acc: %f\n'%(100. * correct / len(mal_train_loaders.dataset)))
                print('\nMalicious set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(mal_train_loaders.dataset), 100. * correct / len(mal_train_loaders.dataset)))

            if args.attack == 'backdoor' and args.mal:
                correct = 0
                # attack success rate: fraction of triggered test images classified as 0
                with torch.no_grad():
                    for feature, target in test_loader:
                        feature = (TF.erase(feature, 0, 0, 5, 5, 0).cuda())
                        target = torch.zeros(target.shape, dtype=torch.long).cuda()
                        output = network(feature)
                        pred = output.data.max(1, keepdim=True)[1]
                        correct += pred.eq(target.data.view_as(pred)).sum()
                attack_acc = 100. * correct / len(test_loader.dataset)
                txt_file.write(',\t %f\n'%attack_acc)
                print('\nAttack Success Rate: {}/{} ({:.0f}%)\n'.format(correct, len(test_loader.dataset), attack_acc))

            if args.attack == 'DBA' and args.mal:
                correct = 0
                # attack success rate over test samples whose true label is not 2
                data_num = 0
                with torch.no_grad():
                    for feature, target in test_loader:
                        # Stamp all four trigger pieces on every non-class-2 sample.
                        poisoned = []
                        for k in range(feature.shape[0]):
                            if target[k] != 2:
                                sample = feature[k]
                                sample = TF.erase(sample, 2, 0, 3, 10, inplace_tensor)
                                sample = TF.erase(sample, 2, 17, 3, 10, inplace_tensor)
                                sample = TF.erase(sample, 25, 0, 3, 10, inplace_tensor)
                                sample = TF.erase(sample, 25, 17, 3, 10, inplace_tensor)
                                poisoned.append(sample)
                        if not poisoned:
                            # Every sample in this batch is already class 2 (always the
                            # case for a batch of size 1 holding a class-2 image).
                            continue
                        new_feature = torch.stack(poisoned).cuda()
                        target = (torch.ones(new_feature.shape[0], dtype=torch.long) * 2).cuda()
                        output = network(new_feature)
                        pred = output.data.max(1, keepdim=True)[1]
                        correct += pred.eq(target.data.view_as(pred)).sum()
                        data_num += new_feature.shape[0]
                attack_acc = 100. * correct / data_num if data_num else 0.0
                txt_file.write('DBA acc: %f\n'%attack_acc)
                print('\nAttack Success Rate: {}/{} ({:.0f}%)\n'.format(correct, data_num, attack_acc))
