'''
attack.py
To fine-tune an adversarial attack with Aug-ILA
'''

import os
import random
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.models as models
import numpy as np

from utils import *
from models import *

# Pretrained source networks, keyed by the name given to --model_type.
# Every entry is instantiated when this module is loaded, not only the one
# that is eventually selected.
name_to_model = {
    arch: getattr(models, arch)(pretrained=True)
    for arch in ("resnet50", "inception_v3", "vgg19")
}

# Image-space augmentations, keyed by the single-character codes accepted by
# --aug. The 'a' code (reverse adversarial update) has no entry here; it is
# handled separately where --aug is parsed.
aug_dict = dict(
    t=T.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # translation
    r=T.RandomAffine(90),  # rotation
    c=T.RandomResizedCrop(224, scale=(0.95, 0.95)),  # cropping
    j=T.ColorJitter(0.2, 0.2, 0.2, 0.1),  # color jitter
)

def build_foldername(args):
    """Build the output sub-folder name that encodes the run configuration.

    The attributes of ``args`` are visited in the order they are stored.
    Attributes named in the skip set, and attributes whose value is falsy
    (``False``, ``0``, ``0.0``, ``''``, ``None``), contribute nothing. Each
    remaining attribute contributes one token:

    * a boolean flag that is set contributes just its (possibly abbreviated)
      name;
    * anything else contributes the name immediately followed by ``str(value)``.

    Args:
        args: Object whose ``__dict__`` holds the options, e.g. the namespace
            returned by ``ArgumentParser.parse_args``.

    Returns:
        str: The tokens joined by underscores, or ``''`` if no attribute
        contributes a token.
    """
    # Shorter labels for some options; model_type contributes its value only.
    abbreviations = {
        'batch_size': 'bs',
        'epsilon': 'eps',
        'model_type': '',
        'ila_layer': 'l',
    }
    # Options that never take part in the folder name.
    skipped = {'max_batch', 'step_size', 'agg_iter', 'save_dir'}

    tokens = []
    for key, value in vars(args).items():
        if key in skipped or not value:
            continue
        label = str(abbreviations.get(key, key))
        # A falsy value was filtered out above, so a bool here is always True.
        if isinstance(value, bool):
            tokens.append(label)
        else:
            tokens.append(label + str(value))
    return '_'.join(tokens)


def ila_forw_by_models(model, model_type, x, ila_layer):
    """Dispatch to the architecture-specific ILA forward helper.

    Args:
        model: The network handed on to the helper.
        model_type: Architecture name. 'resnet50'/'resnet101' use
            ``ila_forw_resnet50``, 'vgg16'/'vgg19' use ``ila_forw_vgg`` and
            'inception_v3' uses ``ila_forw_inception_v3``.
        x: Input batch handed on to the helper.
        ila_layer: Layer identifier. For the VGG models it must be a plain
            number (no underscore) and is converted to ``int``; for the other
            models it is passed through unchanged.

    Returns:
        Whatever the selected helper returns (presumably the activation at
        the requested layer -- confirm against the helper definitions).

    Raises:
        ValueError: If ``ila_layer`` contains an underscore for a VGG model,
            or if ``model_type`` is not one of the supported names.
    """
    if model_type in ('resnet50', 'resnet101'):
        return ila_forw_resnet50(model, x, ila_layer)
    if model_type in ('vgg16', 'vgg19'):
        # Explicit check rather than `assert`: asserts are removed under
        # `python -O`, and int() accepts underscores as digit separators, so
        # an id such as '3_1' would otherwise silently become layer 31.
        if '_' in str(ila_layer):
            raise ValueError(
                'You should give a number (1-15) to ila_layer for VGG19. ila_layer: {}'.format(ila_layer)
            )
        return ila_forw_vgg(model, x, int(ila_layer))
    if model_type == 'inception_v3':
        return ila_forw_inception_v3(model, x, ila_layer)
    raise ValueError("Non-supported model type: {}".format(model_type))

if __name__ == '__main__':
    # Script entry point. For every batch of the selected validation images:
    #   1. run an iterative sign-gradient attack (with per-step random noise
    #      when --pgd is given) against the chosen model,
    #   2. optionally refine the result with ILA / Aug-ILA for --ila steps,
    #   3. save the adversarial batch as a tensor file.
    # The labels of all processed batches are saved once at the end.

    set_seed(0)
    parser = argparse.ArgumentParser(description='For Generation of Transferrable Attack')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--epsilon', type=float, default=0.03, help="The perturbation budget")
    parser.add_argument('--max_batch', type=int, default=-1, help="The number of batch to stop early (default: -1 = not to stop)")
    parser.add_argument('--save_dir', type=str, default="data/attack_batches", help="The path to save the output tensors for validation")
    parser.add_argument('-m', '--model_type', type=str, default='resnet50', help="resnet50/vgg19/inception_v3")
    # options for ILA
    parser.add_argument('--niters', type=int, default=10, help="The number of attack iteration")
    parser.add_argument('--ila', type=int, default=50, help="The number of ILA iteration after the attack generation")
    parser.add_argument('--ila_layer', type=str, default='3_1', help="The layer to perform ILA (note the format is different for different model)")
    parser.add_argument('--step_size', type=float, default=1./255., help="The step size of both I-FGSM and ILA")
    parser.add_argument('--pgd', action='store_true', help="to activate random initialization (PGD)")
    # options for Aug-ILA
    parser.add_argument('-a', '--alpha', type=float, default=-1, help="The alpha value for attack interpolation (< 0 for being adaptive)")
    parser.add_argument('--aug', type=str, default='ac', help="t/c/r/j/a (traslation/cropping/rotation/color jitter/adversarial)")
    args = parser.parse_args()

    save_dir = args.save_dir
    niters = args.niters
    epsilon = args.epsilon
    step_size = args.step_size
    batch_size = args.batch_size
    alpha = args.alpha
    ila_layer = args.ila_layer
    model_type = args.model_type

    # niters is used as a divisor and as the length of the attack history
    # below, so reject non-positive values with a readable message.
    if niters <= 0:
        parser.error('--niters must be a positive integer, got {}'.format(niters))

    if step_size < epsilon / niters:
        print("Step size is smaller than epsilon / niters, the attack will not be the strongest.")
        step_size = epsilon / niters
        print("Replacing step size to be epsilon / niter = {:.4f}...".format(step_size))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tform = T.Compose([
        T.Resize((256,256)),
        T.CenterCrop((224,224)),
        T.ToTensor()
    ])

    selected_data = 'data/selected_data_full.csv'
    val_datapath = 'data/ILSVRC2012_img_val'

    dataset = SelectedImagenet(val_datapath, selected_data, tform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Input normalization is prepended so the attack works on raw image tensors.
    model = name_to_model[model_type]
    model = nn.Sequential(Normalize(), model).to(device)
    model.eval()

    # create the directory for the specific setup
    folder_name = build_foldername(args)
    save_dir = os.path.join(save_dir, folder_name)
    os.makedirs(save_dir, exist_ok=True)

    # Translate --aug once: each character either selects a transform from
    # aug_dict, enables the reverse adversarial update ('a'), or is ignored ('_').
    rev_update = False
    t_aug = None
    if args.ila > 0:
        aug_list = []
        for t in args.aug:
            if t in aug_dict:
                aug_list.append(aug_dict[t])
            elif t == 'a':
                rev_update = True
            elif t == '_':
                pass
            else:
                raise Exception('Undefined augmentation method: {}'.format(args.aug))
        t_aug = T.Compose(aug_list)

    ce_loss = nn.CrossEntropyLoss()

    label_list = []
    for i, (x, label) in enumerate(dataloader):
        if args.max_batch > 0 and i >= args.max_batch:
            print('Early stopping at batch {}'.format(args.max_batch))
            break

        label_list.append(label)
        x = x.to(device)
        label = label.to(device)
        x_adv = x.clone()

        # generate attack for the current batch
        # x_advs[k] records the input that was fed to the model at attack step k.
        x_advs = torch.zeros((niters, *x_adv.shape), device=device)
        for atk_i in range(niters):
            if args.pgd:
                # NOTE(review): the uniform noise is re-drawn on every attack
                # step, not only before the first one -- confirm this is intended.
                x_adv = x_adv + x_adv.new(x_adv.size()).uniform_(-epsilon, epsilon)
            x_adv.requires_grad_(True)
            x_advs[atk_i] = x_adv.data

            out = model(x_adv)
            loss = ce_loss(out, label)
            model.zero_grad()
            loss.backward()
            x_grad = x_adv.grad.data

            # Ascend the classification loss, then clip against the clean
            # image x with budget epsilon.
            x_adv = x_adv.data + step_size * torch.sign(x_grad)
            x_adv = clip_epsilon(x_adv, x, epsilon)

        # Inject ILA
        if args.ila > 0:
            ila_niters = args.ila
            attack_img = x_adv.detach().clone()
            img = x.clone().to(device)
            x_a = x.clone()

            for ila_i in range(ila_niters):
                img.requires_grad_(True)
                model.zero_grad()

                if rev_update:
                    # the following also works, but we prefer consideration of more attacks
                    # x_a = 2.0 * x - x_advs[-1]
                    # The index starts at 5 and decreases as ila_i grows,
                    # eventually going negative (i.e. counting from the end).
                    # The modulo gives the same element as plain indexing for
                    # every in-range index and keeps out-of-range ones (e.g.
                    # niters < 6) from raising IndexError.
                    ref_idx = int(5 - ila_i / 5) % niters
                    # 2x - x_adv = x - (x_adv - x): the recorded perturbation, reversed.
                    x_a = 2.0 * x - x_advs[ref_idx]

                # One call on the concatenated batch, then split back into
                # reference / guide attack / current image.
                x_combined = torch.cat([x_a, attack_img, img], dim=0)
                x_aug = t_aug(x_combined)
                with torch.no_grad():
                    x_mid = ila_forw_by_models(model, model_type, x_aug[:x.shape[0]], ila_layer)
                    x_adv_mid = ila_forw_by_models(model, model_type, x_aug[x.shape[0]:x.shape[0]*2], ila_layer)
                # Only this pass is tracked by autograd; it depends on img.
                x_adv2_mid = ila_forw_by_models(model, model_type, x_aug[x.shape[0]*2:], ila_layer)

                loss = ILAProjLoss()(x_adv_mid, x_adv2_mid, x_mid, 0.0)

                loss.backward()
                input_grad = img.grad.data

                img = img.data + step_size * torch.sign(input_grad)
                img = clip_epsilon(img, x, epsilon)

                # interpolation update on X
                with torch.no_grad():
                    if args.alpha < 0.0: # negative alpha => let the norm determine it
                        x_adv_if = ila_forw_by_models(model, model_type, img, ila_layer)
                        old_norm = (x_adv_mid - x_mid).norm()
                        new_norm = (x_adv_if - x_mid).norm()
                        alpha = new_norm / (new_norm + old_norm)
                    elif args.alpha == 0.0:
                        # alpha == 0 keeps the guide attack fixed
                        continue
                    attack_img = alpha * img + (1 - alpha) * attack_img

            x_adv = clip_epsilon(img, x, epsilon)

        # we save the images as torch tensor directly, to avoid any information loss during encoding
        torch.save(x_adv.detach().cpu(), os.path.join(save_dir, 'batch_{}.pt'.format(i)))
        print('batch_{}.pt saved'.format(i))
    torch.save(label_list, os.path.join(save_dir, 'label.pt'))
    print('Labels and all batches are saved at: {}'.format(save_dir))
