'''
test.py
Performs blackbox tests on a target model with the attacks generated by attack.py
'''

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from utils import *
from models import *

# Module-wide verbosity switch read by the test helpers; the command-line
# entry point overwrites it from --quiet. A `global` statement has no effect
# at module scope, so a plain assignment is all that is needed here.
quiet = False

def test_benign(dataloader, model):
    print_every = int(0.05 * len(dataloader) + 1)
    with torch.no_grad():
        val_n, val_acc = 0, 0
        for i, (x, y) in enumerate(dataloader):            
            x = x.cuda()
            y = y.cuda()

            out = model(x)
            pred = out.argmax(dim=-1)
            acc = (pred == y).sum() / x.size(0)

            val_n += y.size(0)
            val_acc += (pred == y).sum()
            if (not quiet and i % print_every == 0):
                print('\r%d / %d batches tested. Acc = %f' % (i + 1, len(dataloader), acc), end="")                
            if not quiet and args.max_batch > 0 and i >= args.max_batch:
                print('break at batch {}'.format(args.max_batch))
                break
        if not quiet:
            print('\r%d / %d batches tested. Acc = %f' % (i + 1, len(dataloader), acc))
            print('Average test Acc: {}'.format(val_acc/val_n))        

def test_transfer(load_dir, model):    
    labels = torch.load(os.path.join(load_dir, "label.pt"))
    print_every = int(0.05 * len(labels) + 1)
    with torch.no_grad():
        val_acc, val_n = 0, 0
        for i in range(len(labels)):
            x_adv = torch.load(os.path.join(load_dir, "batch_{}.pt".format(i)))
            x_adv = x_adv.cuda()
            y = labels[i].cuda()

            out = model(x_adv)
            pred = out.argmax(dim=-1)
            acc = (pred == y).sum() / x_adv.size(0)
            
            val_n += y.size(0)
            val_acc += (pred == y).sum()
            if (not quiet and i % print_every == 0):
                print('\r%d / %d batches tested. Acc = %f' % (i + 1, len(labels), acc), end="")
            if not quiet and args.max_batch > 0 and i >= args.max_batch:
                print('break at batch {}'.format(args.max_batch))
                break
        if not quiet:
            print('\r%d / %d batches tested. Acc = %f' % (i + 1, len(labels), acc))
            print('Average test Acc: {} (Success Rate: {})'.format(val_acc/val_n, 1-val_acc/val_n))
    return 1-val_acc.item()/val_n

def concat_models(models):
    '''
    Join model names into one '/'-separated string.

    Args:
        models: iterable of model-name strings.

    Returns:
        The names separated by '/', or an empty string when there are none.
    '''
    return '/'.join(models)

def concat_success_rates(rates):
    '''
    Join success rates into one '/'-separated string.

    Each rate is rendered with the '{:.2}' format spec, i.e. two significant
    digits (0.123 -> '0.12', 1.0 -> '1.0').

    Args:
        rates: iterable of numeric success rates.

    Returns:
        The formatted rates separated by '/', or an empty string when there are none.
    '''
    return '/'.join('{:.2}'.format(rate) for rate in rates)

# Maps each model name accepted on the command line to a model instance.
# NOTE: every entry is constructed as soon as this dict is built, which will
# download every model listed here; comment out the entries you do not want
# to avoid that.
name_to_model = {
    "resnet50": models.resnet50(pretrained=True),
    "wrn50_2": models.wide_resnet50_2(pretrained=True),
    "inception_v3": models.inception_v3(pretrained=True),
    "vgg19": models.vgg19(pretrained=True),
    "pnasnet5": nn.Sequential(transforms.Resize((331, 331)), pnasnet5large(num_classes=1000, pretrained='imagenet')),
    "densenet": models.densenet161(pretrained=True),
    "mobilenet": models.mobilenet_v2(pretrained=True),
    "resnext101": models.resnext101_32x8d(pretrained=True),
    "senet50": se_resnet50(pretrained=True),
}

# Model names (and their order) tested when --models is left at its default '*'.
quickhand_model_list = ['resnet50', 'inception_v3', 'wrn50_2', 'vgg19', 'pnasnet5', 'densenet', 'resnext101', 'mobilenet', 'senet50']

# Command-line entry point: evaluates every selected model either on clean
# images (--benign) or on the saved adversarial batches of one attack folder,
# then prints a per-model success-rate summary for the transfer case.
if __name__ == "__main__":

    set_seed(0)
    parser = argparse.ArgumentParser(description='For Testing Attack Transferability')
    parser.add_argument('--models', type=str, default='*', help=str(name_to_model.keys()))
    parser.add_argument('--batch_size', type=int, default=100, help="only useful for tests with clean (benign) images")
    parser.add_argument('-f', '--folder', type=str, default='', help="the folder to test on")
    parser.add_argument('-b', '--benign', action='store_true', help="whether to test on benign image only")
    parser.add_argument('-q', '--quiet', action='store_true', help="only report the summary")
    parser.add_argument('--max_batch', type=int, default=-1, help="the number of batch to break early (default = -1)")
    parser.add_argument('--load_dir', type=str, default="data/attack_batches", help="The path to load the output tensors for testing")
    # `args` and `quiet` stay module-level names because the test helpers read them.
    args = parser.parse_args()

    # parser.error (unlike assert) still fires under `python -O` and prints the usage text.
    if not args.benign and args.folder == '':
        parser.error('The folder (-f, --folder) must not be empty for testing attack transfer')

    load_dir = os.path.join(args.load_dir, args.folder)
    batch_size = args.batch_size
    quiet = args.quiet

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The clean validation set is only needed for --benign runs, so transfer
    # tests no longer require the ImageNet files to be present.
    benign_dataloader = None
    if args.benign:
        tform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])

        selected_data_file = 'data/selected_data_full.csv'
        val_datapath = 'data/ILSVRC2012_img_val'

        benign_dataset = SelectedImagenet(val_datapath, selected_data_file, transform=tform)
        benign_dataloader = torch.utils.data.DataLoader(benign_dataset, batch_size=batch_size, shuffle=False)

    # '*' selects the default list; otherwise names are separated by '/'.
    if args.models == '*':
        model_list = quickhand_model_list
    else:
        model_list = args.models.split("/")

    # Names and rates are collected together so the two summary lines stay
    # aligned even when an unknown model name is skipped.
    tested_models = []
    rates = []
    for model_name in model_list:
        if model_name not in name_to_model:
            print("Undefined model type: {}. skipping".format(model_name))
            continue
        if not quiet:
            print('Testing on model: {}'.format(model_name))
        model = name_to_model[model_name]
        # .to(device) already places the model; an unconditional .cuda() would
        # crash on CPU-only hosts.
        model = nn.Sequential(Normalize(), model).to(device)
        model.eval()

        # validation
        if args.benign:
            test_benign(benign_dataloader, model)
        else:
            success_rate = test_transfer(load_dir, model)
            tested_models.append(model_name)
            rates.append(success_rate)
        if not quiet:
            print()

    # Benign runs produce no success rates, so there is nothing to summarize.
    if not args.benign:
        print('Success Rate Summary: \n{}\n{}'.format(concat_models(tested_models), concat_success_rates(rates)))
        print()
