#!/usr/bin/env python
# coding: utf-8

# In[1]:


import argparse


parser = argparse.ArgumentParser(
    description='Tune backward-pass modifiers of a source model with optuna, '
                'then save adversarial examples crafted with the best setting.'
)

parser.add_argument('--input_dir', default='./SubImageNet224', help='the path of original dataset')
parser.add_argument('--output_dir', default='./save', help='the path of the saved dataset')
parser.add_argument('--arch', default='resnet18',
                    help='source model for black-box attack evaluation')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for adversarial attack')

# Previously this help text was a copy of --batch-size's.
parser.add_argument('--n_trials', type=int, default=200, metavar='N',
                    help='number of optuna trials for the hyper-parameter search')

# eps ball
parser.add_argument('--epsilon', default=16, type=float,
                    help='perturbation budget, in 0-255 pixel units')
parser.add_argument('--num-steps', default=10, type=int,
                    help='perturb number of steps')
parser.add_argument('--step-size', default=-1, type=float,
                    help='perturb step size in 0-255 pixel units; '
                         'a negative value means epsilon / num-steps')
parser.add_argument('--seed', type=int, default=66, help='random seed')
parser.add_argument('--inner_iter', type=int, default=1,
                    help='forwarded to CustomizedAttack as inner_iter')

parser.add_argument('--max_idx_per_trial', type=int, default=3,
                    help='each trial is scored on a random subset of '
                         'max_idx_per_trial * batch-size images')
parser.add_argument('--rand_init', action='store_true', help="random init delta")
# MI
parser.add_argument('--mi', action='store_true', help="use momentum trick")
parser.add_argument('--momentum', default=0.0, type=float,
                    help='attack decay factor (forced to 1.0 by --mi)')
# TI
parser.add_argument('--ti', action='store_true', help="use translation-invariant trick")
# DI
parser.add_argument('--di', action='store_true', help="use diverse input trick")

parser.add_argument('--debug', action='store_true')

parser.add_argument('--space', type=str, default='convlinbpgrad',
                    help="substrings select the searched modifiers: 'conv' "
                         "(SkipConv2d), 'linbp' (LinRelu), 'grad' (SkipRELU, "
                         "or LinRelu with skip gradient when combined with linbp)")

# In[2]:

args = parser.parse_args()

# `--mi` is a shorthand for a decay factor of 1.0 and wins over any
# explicit `--momentum` value.
args.momentum = 1.0 if args.mi else args.momentum


import torch
def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs.

    cuDNN autotuning is switched off; deterministic cuDNN kernels are not
    forced (``torch.backends.cudnn.deterministic`` is left untouched).
    """
    import random

    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    torch.backends.cudnn.benchmark = False

seed_everything(seed=args.seed)  # seed all RNGs before any model/data setup



# In[3]:


import os        


# Device and DataLoader settings (worker processes / pinned memory only on GPU).
use_cuda = torch.cuda.is_available()
if use_cuda:
    args.device = torch.device("cuda")
    kwargs = {'num_workers': 4, 'pin_memory': True}
else:
    args.device = torch.device("cpu")
    kwargs = {}
print(args)


# In[4]:


import torch.nn as nn

# simple Module to normalize an image
class Normalize(nn.Module):
    """Channel-wise input normalisation ``(x - mean) / std`` as a module.

    Placed in front of a network so the rest of the pipeline can work on
    un-normalised images (the dataset transform only applies ``ToTensor``).

    Args:
        mean: per-channel means (sequence of length C).
        std: per-channel standard deviations (same length as ``mean``).
    """

    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        # Non-persistent buffers follow `module.to(device)`, so forward no
        # longer copies them host->device on every call, and they add no
        # entries to `state_dict` (matching the old plain-attribute behaviour).
        self.register_buffer('mean', torch.tensor(mean), persistent=False)
        self.register_buffer('std', torch.tensor(std), persistent=False)

    def forward(self, x):
        """Normalise a batch ``x`` along dim 1; result has x's dtype and device.

        Assumes ``x`` is 4-D with channels on dim 1 (NCHW), as implied by the
        ``[None, :, None, None]`` broadcasting below.
        """
        # `.to` is a no-op when dtype/device already match; unlike `type_as`
        # it targets x's actual device rather than the current CUDA device.
        mean = self.mean.to(dtype=x.dtype, device=x.device)[None, :, None, None]
        std = self.std.to(dtype=x.dtype, device=x.device)[None, :, None, None]
        return (x - mean) / std


# In[5]:


# create models
import pretrainedmodels
from torchvision import transforms
import pretrainedmodels.utils
from utils import SubsetImageNet 

if args.arch.startswith('tf'):
    # 'tf*' architectures come from the project helper, which also returns
    # the matching preprocessing transform.
    from utils import get_eval_model
    model, transform_test = get_eval_model(args.arch)
    model = model.to(args.device).eval()
else:
    model = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
    print(f"255 {model.input_range}, Size {model.input_size}, Space {model.input_space}")

    # Resize to the network's input size (presumably (C, H, W), hence [1:])
    # and convert to a tensor; normalisation is done inside the model below.
    transform_test = transforms.Compose(
        [transforms.Resize(model.input_size[1:]), transforms.ToTensor()]
    )
    print(transform_test)

    model = nn.Sequential(Normalize(mean=model.mean, std=model.std), model).to(args.device)
    model.eval()

data_set = SubsetImageNet(root=args.input_dir, transform=transform_test)
# Per-batch loss is summed rather than averaged.
loss_fn = nn.CrossEntropyLoss(reduction="sum").to(args.device)


# Attack budget: CLI values are in 0-255 pixel units, images are in [0, 1].
epsilon = args.epsilon / 255.0
# A negative --step-size means "split epsilon evenly over --num-steps".
step_size = (epsilon / args.num_steps) if args.step_size < 0 else (args.step_size / 255.0)


# In[6]:


from collections import OrderedDict

def densenet_granularity(name):
    """Return the first three dot-separated components of module ``name``."""
    return '.'.join(name.split('.', 3)[:3])

def resnet_granularity(name):
    """Return module ``name`` unchanged (finest possible granularity)."""
    granular_name = name
    return granular_name

def inception_granularity(name):
    """Return the first three dot-separated components of module ``name``."""
    return '.'.join(name.split('.', 3)[:3])

def senet_granularity(name):
    """Return the first two dot-separated components of module ``name``."""
    return '.'.join(name.split('.', 2)[:2])


class BPCoarseSpace:
    """Coarse search space of backward-pass modifiers installed in ``model``.

    On construction, every submodule for which ``module_register`` returns a
    non-``None`` register name is replaced in place by the returned (wrapped)
    module, which is moved to ``args.device``. Replaced modules are grouped in
    ``self.search_space`` under the key
    ``"<first three name components>.<register_name>"``. For ``densenet201``,
    names with more than three components also include their fourth
    component minus its last character. ``step`` then assigns one ``gamma``
    per group.

    Attributes:
        meta_models: surrogate models stored for the caller's evaluation.
        dataset: dataset stored for the caller's evaluation.
        search_space: OrderedDict mapping group key -> list of wrapped modules.
    """

    def __init__(self, model, module_register, meta_models, dataset):
        self.meta_models = meta_models
        self.dataset = dataset
        self.search_space = OrderedDict()

        # Replacing a child while named_modules() is iterating is fine: only
        # existing dict values change, and the generator keeps recursing into
        # the original (unwrapped) module object it has just yielded.
        for name, module in model.named_modules():
            wrapped, register_name = module_register(name, module)
            if register_name is None:
                continue

            wrapped = wrapped.to(args.device)
            tokens = name.split('.')

            *parent_path, leaf = tokens
            parent = model
            for attr in parent_path:
                parent = getattr(parent, attr)
            setattr(parent, leaf, wrapped)

            prefix = '.'.join(tokens[:3])
            if args.arch == 'densenet201' and len(tokens) != 3:
                # Space for DN201: also group by the layer name minus its
                # last character.
                key = f"{prefix}.{tokens[3][:-1]}.{register_name}"
            else:
                key = f"{prefix}.{register_name}"

            self.search_space.setdefault(key, []).append(wrapped)

    def step(self, gammas):
        """Set ``.gamma`` on every module of each group named in ``gammas``.

        Args:
            gammas: mapping of search-space key -> gamma value. Every key must
                exist in ``self.search_space`` (KeyError otherwise).
        """
        for group, value in gammas.items():
            for member in self.search_space[group]:
                member.gamma = value


# In[7]:


import numpy as np

from perturb import CustomizedAttack
from perturb.scm import SkipConv2d
from perturb.sgm import SkipRELU
from perturb.linbp import LinRelu, LinReluWOSGM


def conv_register(name, module):
    """Wrap a stride-1, non-1x1 conv with padding == (kernel - 1) / 2 in SkipConv2d.

    Any other module is returned unchanged.
    """
    is_conv = isinstance(module, (nn.Conv2d, nn.Conv1d))
    if not is_conv or module.stride != (1, 1) or module.kernel_size == (1, 1):
        return module
    if [2 * pad + 1 for pad in module.padding] != list(module.kernel_size):
        return module
    print(f"Skip {name}")
    return SkipConv2d(module, 0.5)

def relu_register(name, module):
    """Wrap ReLUs from ``layer3.1`` onward (by name) in LinRelu.

    Only ``nn.ReLU`` modules whose name contains ``'layer'`` are considered.
    Component 1 of the dotted name is compared as a string against
    ``'layer3'``. Inside ``layer3`` itself, only block indices >= 1 qualify.
    """
    if not isinstance(module, nn.ReLU) or 'layer' not in name:
        return module
    parts = name.split('.')
    stage = parts[1]
    is_late_relu = stage > 'layer3' or (stage == 'layer3' and int(parts[2]) >= 1)
    if is_late_relu:
        module = LinRelu(module)
        print(f"LinBP {name}")
    return module


def residual_register(name, module):
    """Wrap modules named like a ReLU (except ``*.0.relu``) in SkipRELU(module, sqrt(0.5))."""
    wants_skip = 'relu' in name and '.0.relu' not in name
    return SkipRELU(module, np.power(0.5, 0.5)) if wants_skip else module


def module_register(name, module):
    """Decide whether ``module`` (named ``name``) joins the search space.

    Modifiers are enabled by substrings of ``args.space``:

    * ``'conv'``: a stride-1, non-1x1 ``nn.Conv2d``/``nn.Conv1d`` whose
      padding equals ``(kernel_size - 1) / 2`` is wrapped in
      ``SkipConv2d(module, 0.5)``. Registered as ``'skip_conv'``.
    * ``'linbp'``: an ``nn.ReLU`` whose name contains ``'relu'`` but not
      ``'transition'`` is wrapped in ``LinRelu`` if ``'grad'`` is also
      present, otherwise in ``LinReluWOSGM``. Registered as ``'lin_relu'``.
    * ``'grad'`` without ``'linbp'``: such a ReLU is wrapped in
      ``SkipRELU(module, sqrt(0.5))``. Registered as ``'skip_grad'``.

    Returns:
        ``(module, register_name)``: the possibly wrapped module and its
        group tag, or the unchanged module and ``None`` if nothing applies.
    """
    register_name = None

    if (
        'conv' in args.space
        and isinstance(module, (nn.Conv2d, nn.Conv1d))
        and module.stride == (1, 1)
        and module.kernel_size != (1, 1)
        # torch>=1.9 allows string padding ('same'/'valid'), which cannot be
        # compared element-wise below; leave such convs alone instead of
        # raising TypeError.
        and not isinstance(module.padding, str)
        and [p * 2 + 1 for p in module.padding] == list(module.kernel_size)
    ):
        print(f"Skip {name}")
        module = SkipConv2d(module, 0.5)
        register_name = 'skip_conv'

    # ReLUs whose name contains 'transition' (presumably DenseNet transition
    # layers) are never modified.
    is_candidate_relu = (
        isinstance(module, nn.ReLU)
        and 'relu' in name
        and 'transition' not in name
    )

    if is_candidate_relu and 'linbp' in args.space:
        if 'grad' in args.space:
            module = LinRelu(module)
            print(f"Skip {name}")
        else:
            module = LinReluWOSGM(module)
            print(f"Skip w.o. skip grad {name}")
        register_name = 'lin_relu'
    elif is_candidate_relu and 'grad' in args.space:
        print(f"Skip graident {name}")
        module = SkipRELU(module, np.power(0.5, 0.5))
        register_name = 'skip_grad'

    return module, register_name


# In[8]:


# Surrogate ("meta") model that scores each trial's adversarial examples.
meta_model = pretrainedmodels.__dict__['vgg19_bn'](num_classes=1000, pretrained='imagenet')
meta_model = nn.Sequential(
    Normalize(mean=meta_model.mean, std=meta_model.std), meta_model
).to(args.device)
meta_models = [meta_model]

from utils import RandomSubset

# Trials are scored on a random subset of max_idx_per_trial * batch_size images.
search_subset = RandomSubset(data_set, args.max_idx_per_trial * args.batch_size)
space = BPCoarseSpace(model, module_register, meta_models, search_subset)

print(f"Search space[length={len(space.search_space)}]", space.search_space.keys(), flush=True)


# In[9]:


import optuna
from torchmetrics import Accuracy
from tqdm import tqdm

# Metric reset/updated/computed per trial by the objective.
accuracy = Accuracy().to(args.device)
# The surrogate models are fed 224x224 images.
resize = transforms.Resize(size=(224, 224))

def obejctive(trial):
    """Optuna objective: surrogate accuracy on adversarial examples for one trial.

    Samples one gamma in [0, 1] per search-space group and installs them with
    ``space.step``. It then crafts single-step adversarial examples
    (``nb_iter=1``, ``eps_iter=epsilon``) against ``model`` on
    ``space.dataset``. Finally it accumulates ``accuracy`` of every model in
    ``space.meta_models`` on those examples.

    Returns:
        The accumulated accuracy as a float. The study is created with
        optuna's default direction ("minimize"), so lower, i.e. better
        transfer, wins.

    Raises:
        optuna.TrialPruned: when the pruner stops the trial after a batch.
    """
    gammas = {
        key: trial.suggest_float(key, 0, 1, log=False)
        for key in space.search_space.keys()
    }

    space.step(gammas)
    model.eval()
    accuracy.reset()

    adversary = CustomizedAttack(
        predict=model, loss_fn=loss_fn,
        eps=epsilon, nb_iter=1, eps_iter=epsilon,
        rand_init=False, clip_min=0.0, clip_max=1.0, targeted=False,
        decay_factor=.0,
        ti=False, di=False, di_prob=0.5,
        arch=args.arch
    )

    dataloader = torch.utils.data.DataLoader(
        space.dataset, batch_size=args.batch_size, shuffle=False, **kwargs
    )
    # Wrap the loader itself (not enumerate(...)) so tqdm can use len() for a total.
    for batch_idx, (inputs, true_class, _) in enumerate(tqdm(dataloader)):
        inputs = inputs.to(args.device)
        true_class = true_class.to(args.device)

        # attack
        inputs_adv = adversary.perturb(inputs, true_class)
        with torch.no_grad():
            for meta_model in space.meta_models:
                meta_model.eval()
                # Check both spatial dims, not just the width.
                if inputs_adv.shape[-2:] != (224, 224):
                    inputs_adv = resize(inputs_adv)
                output = meta_model(inputs_adv)
                accuracy.update(output, true_class)

        trial.report(accuracy.compute().item(), batch_idx)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return accuracy.compute().item()

# Search for the gammas that minimise surrogate accuracy (optuna's default
# direction, spelled out here), pruning weak trials early with Hyperband.
study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.HyperbandPruner(),
)
study.optimize(obejctive, n_trials=args.n_trials, show_progress_bar=True)
print(study.best_params)
print(study.best_value)

# Leave the best configuration installed in `model` for the final attack.
space.step(study.best_params)


# In[19]:


from utils import generate_adversarial_example

# The surrogate ("meta") models were only needed to score search trials.
# Drop every reference to them (the shared list and the global name) so that
# empty_cache() can actually release their memory. Deleting a loop variable
# (`for m in models: del m`) only unbinds the name and frees nothing.
space.meta_models.clear()
meta_models.clear()
del meta_model
torch.cuda.empty_cache()

os.makedirs(args.output_dir, exist_ok=True)

# Final attack with the full CLI configuration, run on the model that
# `space.step(study.best_params)` configured.
adversary = CustomizedAttack(
    predict=model, loss_fn=loss_fn,
    eps=epsilon, nb_iter=args.num_steps, eps_iter=step_size,
    rand_init=args.rand_init, clip_min=0.0, clip_max=1.0, targeted=False,
    decay_factor=args.momentum,
    inner_iter=args.inner_iter,
    ti=args.ti, di=args.di,
    arch=args.arch
)

dataloader = torch.utils.data.DataLoader(
    data_set, batch_size=args.batch_size, shuffle=False, **kwargs
)

generate_adversarial_example(
    model=model, data_loader=dataloader,
    adversary=adversary, img_path=data_set.img_path,
    device=args.device, output_dir=args.output_dir
)
