'''
test_cifar.py
Performs black-box transfer tests: evaluates target models on the adversarial examples generated by attack.py for CIFAR-10 or CIFAR-100 (or on clean images with -b/--benign)

Usage:
python test_cifar.py -f data/attack_batches/<attack for cifar10> --dataset cifar10
python test_cifar.py -f data/attack_batches/<attack for cifar100> --dataset cifar100

For the evaluation with PyramidNet and GDAS, run the script in LinBP directly.
'''

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision.models as models
import bearpaw_models.cifar as ex_models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from utils import *
from models import *

# Module-level verbosity flag read by test_transfer(); the __main__ section
# overwrites it from the -q/--quiet option. (No `global` statement is needed
# at module scope.)
quiet = False


def test_transfer(dataloader, model, max_batch=None, device=None, quiet=None):
    """Evaluate ``model`` on ``dataloader`` and return the attack success rate.

    The success rate is ``1 - accuracy``. A sample counts as correct when the
    argmax over the last dimension of the model output equals its label.

    Args:
        dataloader: sized iterable yielding ``(inputs, labels)`` batches.
        model: classifier called on each input batch under ``torch.no_grad()``.
        max_batch: when > 0, stop after evaluating this many batches. ``None``
            falls back to the module-level CLI ``args.max_batch`` if it exists,
            otherwise -1 (evaluate every batch).
        device: device each batch is moved to. ``None`` uses the device of the
            model's first parameter, or CUDA/CPU (whichever is available) for a
            model without parameters.
        quiet: suppress progress output. ``None`` falls back to the
            module-level ``quiet`` flag (False if absent).

    Returns:
        float: the success rate, in [0, 1].

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    # Fall back to the script-level configuration so existing two-argument
    # calls keep their behaviour.
    if quiet is None:
        quiet = globals().get('quiet', False)
    if max_batch is None:
        cli_args = globals().get('args')
        max_batch = cli_args.max_batch if cli_args is not None else -1
    if device is None:
        params = model.parameters() if hasattr(model, 'parameters') else iter(())
        first_param = next(params, None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_batches = len(dataloader)
    print_every = int(0.05 * num_batches + 1)
    n_correct, n_seen = 0, 0
    batch_acc = 0.0
    batches_done = 0
    with torch.no_grad():
        for i, (x_adv, y) in enumerate(dataloader):
            x_adv = x_adv.to(device)
            y = y.to(device)
            pred = model(x_adv).argmax(dim=-1)
            # Accumulate as Python ints so the summary prints plain numbers.
            batch_correct = (pred == y).sum().item()
            batch_acc = batch_correct / x_adv.size(0)
            n_seen += y.size(0)
            n_correct += batch_correct
            batches_done = i + 1
            if not quiet and i % print_every == 0:
                print('\r%d / %d batches tested. Acc = %f'
                      % (i + 1, num_batches, batch_acc), end="")
            # The early stop must apply regardless of verbosity, and after
            # exactly `max_batch` batches (not max_batch + 1).
            if max_batch > 0 and batches_done >= max_batch:
                if not quiet:
                    print('break at batch {}'.format(max_batch))
                break

    if n_seen == 0:
        raise ValueError('test_transfer: the dataloader yielded no samples')
    accuracy = n_correct / n_seen
    if not quiet:
        print('\r%d / %d batches tested. Acc = %f' % (batches_done, num_batches, batch_acc))
        print('Average test Acc: {} (Success Rate: {})'.format(accuracy, 1 - accuracy))
    return 1 - accuracy


# This file's name matches pytest's test_*.py pattern; stop pytest from
# collecting this helper (it needs a dataloader and a model) as a test case.
test_transfer.__test__ = False

def concat_models(models):
    """Join model names with '/', e.g. ['vgg19', 'wrn'] -> 'vgg19/wrn'.

    An empty sequence yields an empty string.
    """
    return '/'.join(models)

def concat_success_rates(rates):
    """Format each rate with 4 significant digits and join them with '/'.

    e.g. [0.5, 0.123456] -> '0.5/0.1235'; an empty sequence yields ''.
    """
    return '/'.join('{:.4}'.format(rate) for rate in rates)

# Build the requested architecture and load its pretrained checkpoint from
# data/<dataset>/<arch dir>/model_best.pth.tar. Nothing is downloaded: the
# checkpoint files must already be present on disk.
def create_model(arch, num_classes, dataset):
    """Construct ``arch`` with ``num_classes`` outputs and load its checkpoint.

    Args:
        arch: one of 'vgg19', 'wrn', 'resnext', 'densenet'.
        num_classes: number of output classes passed to the constructor.
        dataset: sub-directory of ``data/`` that holds the checkpoints
            (e.g. 'cifar10' or 'cifar100').

    Returns:
        The model with the checkpoint's ``state_dict`` loaded, on the CPU.

    Raises:
        ValueError: if ``arch`` is not a known architecture.
    """
    model_state_path = 'data/{}'.format(dataset)
    print("creating model '{}'".format(arch))
    if arch == 'vgg19':
        model = ex_models.vgg19_bn(num_classes=num_classes)
        filepath = 'vgg19_bn/model_best.pth.tar'
    elif arch == 'wrn':
        model = ex_models.wrn(num_classes=num_classes, depth=28, widen_factor=10, dropRate=0.3)
        filepath = 'WRN-28-10-drop/model_best.pth.tar'
    elif arch == 'resnext':  # ResNeXt-29, 8x64d
        model = ex_models.resnext(num_classes=num_classes, cardinality=8, depth=29, widen_factor=4, dropRate=0)
        filepath = 'resnext-8x64d/model_best.pth.tar'
    elif arch == 'densenet':
        model = ex_models.densenet(num_classes=num_classes, depth=190, growthRate=40, compressionRate=2, dropRate=0)
        filepath = 'densenet-bc-L190-k40/model_best.pth.tar'
    else:
        raise ValueError('Undefined model: {}'.format(arch))

    # map_location='cpu' lets checkpoints saved from GPU load on CPU-only hosts;
    # the caller moves the model to its device afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(os.path.join(model_state_path, filepath), map_location='cpu')
    # Checkpoint keys: epoch, state_dict, acc, best_acc, optimizer.
    # Weights saved through nn.DataParallel carry a leading 'module.' prefix;
    # strip only that prefix so genuine sub-module names are left intact.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): val
        for key, val in checkpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)
    return model

# Short architecture names accepted by --models (several may be joined with
# '/'); passing '*' selects every entry.
quickhand_model_list = [
    'vgg19',
    'wrn',
    'resnext',
    'densenet',
]

if __name__ == "__main__":

    set_seed(0)
    parser = argparse.ArgumentParser(description='For Testing Attack Transferability')
    parser.add_argument('--models', type=str, default='*', help=str(quickhand_model_list))
    # choices= rejects unknown datasets up front instead of failing later.
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                        help='cifar10/cifar100')
    parser.add_argument('--batch_size', type=int, default=100, help="only useful for tests with clean (benign) images")
    parser.add_argument('-f', '--folder', type=str, default='', help="the folder to test on")
    parser.add_argument('-b', '--benign', action='store_true', help="whether to test on benign image only")
    parser.add_argument('-q', '--quiet', action='store_true', help="only report the summary")
    parser.add_argument('--max_batch', type=int, default=-1, help="the number of batch to break early (default = -1)")
    parser.add_argument('--load_dir', type=str, default="data/attack_batches", help="The path to load the output tensors for testing")
    args = parser.parse_args()

    # parser.error instead of assert: asserts are stripped under `python -O`.
    if not args.benign and args.folder == '':
        parser.error('The folder (-f, --folder) must not be empty for testing attack transfer')

    load_dir = os.path.join(args.load_dir, args.folder)
    batch_size = args.batch_size
    # Module-level assignment: test_transfer() reads this global flag.
    quiet = args.quiet

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Per dataset: benign batch file, label file, number of classes.
    dataset_files = {
        'cifar10': ("data/cifar10_batches.pt", "data/cifar10_labels.pt", 10),
        'cifar100': ("data/cifar100_batches.pt", "data/cifar100_labels.pt", 100),
    }
    benign_path, labels_path, num_class = dataset_files[args.dataset]

    if args.benign:
        path = benign_path
    elif os.path.exists(args.folder):
        # -f already points at an existing path: use it as given.
        path = args.folder
    else:
        # Otherwise resolve -f relative to --load_dir.
        path = load_dir

    dataset = SelectedCifar(path=path, labels=labels_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # load the pretrained models
    if args.models == '*':
        model_list = quickhand_model_list
    else:
        model_list = args.models.split("/")

    # Track the names actually tested so the summary stays aligned with rates
    # when unknown names are skipped.
    tested_models, rates = [], []
    for model_name in model_list:
        if model_name not in quickhand_model_list:
            print("Undefined model type: {}. skipping".format(model_name))
            continue
        if not quiet:
            print('Testing on model: {}'.format(model_name))
        model = create_model(model_name, num_class, args.dataset)
        # Prepend input normalization (the same mean/std constants are used for
        # both datasets). Only .to(device): a trailing .cuda() would break the
        # CPU fallback chosen above.
        model = nn.Sequential(Normalize([(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)]), model).to(device)
        model.eval()

        # validation
        success_rate = test_transfer(dataloader, model)
        tested_models.append(model_name)
        rates.append(success_rate)
        if not quiet:
            print()
    print('Success Rate Summary: \n{}\n{}'.format(concat_models(tested_models), concat_success_rates(rates)))
    print()
