#!/usr/bin/env python
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
import torchvision
import torchvision.transforms as transforms

#from models.resnet import ResNet18
from models.vgg import VGG
from attacker.pgd import Linf_PGD, EOT_Linf_PGD
# arguments
# (flag, add_argument keyword arguments), listed in the order the options
# are registered with the parser.
_CLI_OPTIONS = (
    ('--model', dict(type=str, required=True)),
    ('--defense', dict(type=str, required=True)),
    ('--data', dict(type=str, required=True)),
    ('--root', dict(type=str, required=True)),
    ('--n_ensemble', dict(type=str, required=True)),
    ('--steps', dict(type=int, required=True)),
    ('--max_norm', dict(type=str, required=True)),
    ('--attack', dict(type=str, default='pgd')),
    ('--batch_size', dict(type=int, required=True)),
)
parser = argparse.ArgumentParser(description='Bayesian Inference')
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)
opt = parser.parse_args()

# --max_norm and --n_ensemble are given as comma-separated strings;
# replace them in place with a list of floats and a list of ints.
opt.max_norm = list(map(float, opt.max_norm.split(',')))
opt.n_ensemble = list(map(int, opt.n_ensemble.split(',')))

# attack
# Supported values of --attack and the attack function each one selects.
_ATTACKS = {
    'pgd': Linf_PGD,
    'eot': EOT_Linf_PGD,
}
if opt.attack not in _ATTACKS:
    raise ValueError('invalid attack function: {}'.format(opt.attack))
attack_f = _ATTACKS[opt.attack]


# dataset
# Builds, for the dataset named by --data:
#   nclass     - number of classes
#   img_width  - image side length passed to the model constructors
#   testloader - DataLoader over the test/validation split
# Every loader uses --batch_size and shuffle=False, so repeated passes over
# the loader see the same batches in the same order.
print('==> Preparing data..')
if opt.data == 'cifar10':
    nclass = 10
    img_width = 32
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    testset = torchvision.datasets.CIFAR10(root=opt.root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=2)
elif opt.data == 'cifar100':
    nclass = 100
    img_width = 32

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.25,0.25,0.25)),
    ])
    testset = torchvision.datasets.CIFAR100(root=opt.root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=2)
elif opt.data == 'stl10':
    nclass = 10
    img_width = 96
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    testset = torchvision.datasets.STL10(root=opt.root, split='test', transform=transform_test, download=True)
    testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=opt.batch_size, shuffle=False)
elif opt.data == 'tiny-imagenet':
    nclass = 200
    img_width = 64

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.25,0.25,0.25)),
    ])
    # Unlike the torchvision datasets above, this one is read from disk and
    # is not downloaded automatically.
    testset = torchvision.datasets.ImageFolder(root=opt.root +'/tiny-imagenet-200/val', transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, shuffle=False, num_workers=2)
else:
    raise ValueError('invalid dataset: {}'.format(opt.data))



# load model
# Pick the network class from (--model, --defense) and wrap it in
# DataParallel restricted to device index 0 (device_ids=range(1)).
# The defense name is compared with ==: `x in ('adv')` is a substring test
# against the string 'adv', not a tuple membership test.
# NOTE(review): the leading positional arguments of the *_vi / *_hvi
# constructors are opaque from this file - confirm against models.*.
if opt.model == 'vgg':
    if opt.defense == 'adv':
        from models.vgg import VGG
        net = nn.DataParallel(VGG('VGG16', nclass, img_width=img_width), device_ids=range(1))
    elif opt.defense == 'adv_vi':
        from models.vgg_vi import VGG
        net = nn.DataParallel(VGG(1, 1.0, 1, 'VGG16', nclass, img_width=img_width), device_ids=range(1))
    elif opt.defense == 'adv_hvi':
        from models.vgg_hvi import VGG
        net = nn.DataParallel(VGG(1, 1.0, 1, 'VGG16', nclass, img_width=img_width), device_ids=range(1))
    else:
        # Fail here rather than with a NameError on `net` further down.
        raise ValueError('invalid opt.defense: {}'.format(opt.defense))
elif opt.model == 'aaron':
    if opt.defense == 'adv':
        from models.aaron import Aaron
        net = nn.DataParallel(Aaron(nclass), device_ids=range(1))
    elif opt.defense == 'adv_vi':
        from models.aaron_vi import Aaron
        net = nn.DataParallel(Aaron(1.0, 1.0, 1.0, nclass), device_ids=range(1))
    elif opt.defense == 'adv_hvi':
        from models.aaron_hvi import Aaron
        net = nn.DataParallel(Aaron(1.0, 1.0, 1.0, nclass), device_ids=range(1))
    else:
        # Fail here rather than with a NameError on `net` further down.
        raise ValueError('invalid opt.defense: {}'.format(opt.defense))
else:
    raise ValueError('invalid opt.model')


# A zero first --max_norm value is stored as the int 0 so that it is
# formatted as "0" (not "0.0") in the checkpoint file name below.
if opt.max_norm[0] == 0:
    opt.max_norm[0] = 0

# Checkpoint file name: <data>_<model>_<first max_norm>_<defense>.pth
_ckpt_path = './checkpoint/{}_{}_{}_{}.pth'.format(
    opt.data,
    opt.model,
    opt.max_norm[0],
    opt.defense,
)

cudnn.benchmark = True
net.load_state_dict(torch.load(_ckpt_path))
net.cuda()
net.eval() # must set to evaluation mode

# Module-level helpers: softmax is applied over dim 1 of the model output.
softmax = nn.Softmax(dim=1)
loss_f = nn.CrossEntropyLoss()


def ensemble_inference(x_in, n_ensemble):
    """Predict labels for ``x_in`` at several ensemble sizes.

    Softmax outputs of repeated forward passes of the module-level ``net``
    are summed into one running total. For each entry ``n`` of
    ``n_ensemble`` only the passes still missing since the previous entry
    are run, so the entries act as cumulative pass counts (an entry that is
    not larger than the previous one adds no passes).

    Args:
        x_in: input batch; its first dimension is the batch size.
        n_ensemble: iterable of ints, the cumulative number of passes after
            which a prediction is recorded.

    Returns:
        A list with one tensor per entry of ``n_ensemble``; each holds the
        index of the largest summed probability along dim 1.
    """
    summed = torch.zeros(x_in.size(0), nclass, dtype=torch.float32).cuda()
    snapshots = []
    passes_done = 0
    with torch.no_grad():
        for target in n_ensemble:
            missing = target - passes_done
            while missing > 0:
                # Only the first element of the network's return value is used.
                summed.add_(softmax(net(x_in)[0]))
                missing -= 1
            # Clone: `summed` keeps being updated in place for later entries.
            snapshots.append(summed.clone())
            passes_done = target
        return [snap.max(dim=1)[1] for snap in snapshots]

def distance(x_adv, x, scale=0.25):
    """Return the batch-averaged L-infinity distance between two tensors.

    Every sample is flattened, the largest absolute element of the scaled
    difference is taken per sample, and these maxima are averaged over the
    first (batch) dimension.

    Args:
        x_adv: perturbed tensor with the same shape as ``x``.
        x: reference tensor; dimension 0 indexes the samples.
        scale: factor applied to the difference before measuring it.
            Defaults to 0.25, the value that used to be hard-coded.
            NOTE(review): 0.25 equals the std given to
            ``transforms.Normalize`` for cifar100 / tiny-imagenet in this
            script, so it presumably maps the distance back to pixel
            units; the cifar10 / stl10 transforms are not normalized -
            confirm whether ``scale=1.0`` is wanted there.

    Returns:
        The mean per-sample maximum as a Python float.
    """
    diff = scale * (x_adv - x).view(x.size(0), -1)
    per_sample_max = diff.abs().max(dim=1)[0]
    return per_sample_max.mean().item()

def noperturb_test(n_ensemble):
    """Print the accuracy on the unperturbed test set.

    For every batch of the module-level ``testloader`` the network is run
    ``n_ensemble`` times, the first element of each return value is summed,
    and the index of the largest summed value along dim 1 is taken as the
    prediction.

    Args:
        n_ensemble: number of forward passes summed per batch.
    """
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.cuda()
            targets = targets.cuda()
            summed = 0
            for _ in range(n_ensemble):
                # The network returns a pair; only its first item is used.
                first_output, _unused = net(inputs)
                summed = summed + first_output
            predicted = summed.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

    accuracy = 100.*n_correct/n_seen
    print('[{} Ensemble with No perturbation] Acc: {:.2f}'.format(n_ensemble, accuracy))


# Iterate over test set
def _perturbed_sweep(n_ensemble, label, max_iter=100):
    """Attack the test set at increasing strengths and print the accuracy.

    The strength ``eps`` takes the 13 non-zero values of
    ``np.linspace(0, 0.03, 14)``. For each one, batches from the
    module-level ``testloader`` are perturbed with ``attack_f`` and
    classified with ``ensemble_inference``; one line is printed per ``eps``
    with the accuracy for every entry of ``n_ensemble`` and the mean of
    ``distance(x_adv, x)`` over the processed batches.

    Args:
        n_ensemble: list of cumulative ensemble sizes handed to
            ``ensemble_inference``.
        label: value shown as the ensemble size in the printed line.
        max_iter: batch index at which the pass over the loader stops.
    """
    for eps in (float(e) for e in np.linspace(0, 0.03, 14)[1:]):
        correct = [0] * len(n_ensemble)
        total = 0
        distortion = 0
        batch = 0
        for it, (x, y) in enumerate(testloader):
            x, y = x.cuda(), y.cuda()
            x_adv = attack_f(x, y, net, opt.steps, eps)
            pred = ensemble_inference(x_adv, n_ensemble)
            for i, p in enumerate(pred):
                correct[i] += torch.sum(p.eq(y)).item()
            total += y.numel()
            distortion += distance(x_adv, x)
            batch += 1
            # NOTE(review): checked after the batch is processed, so up to
            # max_iter + 1 batches are evaluated - confirm this is intended.
            if it >= max_iter:
                break
        accuracy = [str(100 * c / total) for c in correct]
        print('[{} Ensemble with perturbation]'.format(label) + ' Accuracy: {}, '.format(accuracy) + 'max_norm: {:.3f}'.format(distortion / batch))


if 'vi' in opt.defense:
    # Defenses with 'vi' in the name: evaluate at the requested ensemble sizes.
    noperturb_test(opt.n_ensemble[0])
    _perturbed_sweep(opt.n_ensemble, opt.n_ensemble[0])
else:
    # Other defenses: a single forward pass; the size list is all ones, so
    # every reported accuracy comes from that one pass.
    noperturb_test(1)
    _perturbed_sweep([1] * len(opt.n_ensemble), 1)
    # Same effect as the site-provided exit(0), without depending on site.
    raise SystemExit(0)
    
