#!/usr/bin/env python3
import sys
import os
import os.path as osp
import argparse
import json
import random
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from models import make_model


# --arch wrn-28-10-drop --attack-type untargeted --norm-type linf --dataset cifar10 --ref-arch alexnet_bn vgg11_bn vgg13_bn vgg16_bn vgg19_bn --ref-arch-train-data cifar10.1
# --arch resnet50 --attack-type untargeted --norm-type linf --dataset imagenet --ref-arch resnet18 resnet34 resnet50 --ref-arch-train-data imagenet

"""
Parse input arguments
"""


class StandardModel(nn.Module):
    """
    Wrap a CNN built by ``make_model`` so that it accepts "standard" images.

    Callers always pass NCHW tensors in RGB order with values in [0, 1] and no
    normalisation. ``forward`` converts them to what the wrapped network
    declares through its ``input_space``, ``input_range``, ``mean`` and ``std``
    attributes.
    """

    def __init__(self, dataset, arch, no_grad=True, **kwargs):
        super().__init__()
        # build the wrapped network and move it to the GPU
        self.cnn = make_model(dataset, arch, **kwargs)
        net = self.cnn
        net.cuda()

        # per-channel statistics shaped (1, 3, 1, 1) so they broadcast over NCHW
        self.mean = torch.FloatTensor(net.mean).view(1, 3, 1, 1).cuda()
        self.std = torch.FloatTensor(net.std).view(1, 3, 1, 1).cuda()
        # channel order; the value 'BGR' triggers a channel swap in forward
        self.input_space = net.input_space
        # value range; a maximum of 255 triggers rescaling in forward
        self.input_range = net.input_range
        self.input_size = net.input_size

        # when True, forward runs the wrapped cnn under torch.no_grad()
        self.no_grad = no_grad

    def forward(self, x):
        """Convert a standard [0, 1] RGB NCHW batch and return the cnn output."""
        # a dropout probability set on the wrapper (e.g. by the attack) is
        # forwarded to the wrapped cnn
        if hasattr(self, 'drop'):
            self.cnn.drop = self.drop

        if self.input_space == 'BGR':
            # swap channels with an index list: negative-stride slicing (::-1)
            # is not supported on tensors
            x = x[:, [2, 1, 0], :, :]

        if max(self.input_range) == 255:
            x = x * 255

        x = (x - self.mean) / self.std

        if not self.no_grad:
            return self.cnn(x)
        with torch.no_grad():
            return self.cnn(x)


def norm(t, p=2):
    """Per-example L1/L2 norm of a 4-D batch, broadcastable back onto it.

    Args:
        t: tensor with exactly 4 dimensions; dimension 0 is treated as the batch.
        p: 2 for the Euclidean norm, 1 for the L1 norm.

    Returns:
        Tensor of shape (t.shape[0], 1, 1, 1). Zero norms are replaced by 1e-8
        so the result can be used as a divisor.

    Raises:
        AssertionError: if ``t`` is not 4-D.
        NotImplementedError: for any other ``p``.
    """
    assert len(t.shape) == 4
    flat = t.reshape(t.shape[0], -1)
    if p == 2:
        norm_vec = flat.pow(2).sum(dim=1).sqrt()
    elif p == 1:
        norm_vec = flat.abs().sum(dim=1)
    else:
        raise NotImplementedError('Unknown norm p={}'.format(p))
    # Out-of-place on purpose: sqrt saves its output for backward, so an
    # in-place "+=" here would break autograd through this function.
    norm_vec = norm_vec + (norm_vec == 0).to(norm_vec.dtype) * 1e-8
    return norm_vec.view(-1, 1, 1, 1)




def gd_prior_step(x, g, lr):
    """Plain gradient-ascent update of the prior: ``x + lr * g``."""
    step = lr * g
    return x + step


def momentum_prior_step(x, g, lr):
    """Prior update with a per-example L1-normalised gradient.

    Adapted from "Boosting Adversarial Attacks with Momentum" (CVPR 2018):
    each example's gradient is divided by its own L1 norm before being scaled
    by ``lr`` and added to ``x``.
    """
    l1_size = norm(g, p=1)
    return x + lr * g / l1_size


def linf_image_step(x, g, lr):
    """Signed-gradient (Linf) step: move every element of ``x`` by ``lr`` along sign(g)."""
    direction = g.sign()
    return x + lr * direction


def l2_image_step(x, g, lr):
    """L2 step: move ``x`` by ``lr`` along the per-example L2-normalised ``g``."""
    g_size = norm(g)
    return x + lr * g / g_size


def l2_proj_step(image, epsilon, adv_image):
    """Project ``adv_image`` onto the L2 ball of radius ``epsilon`` around ``image``.

    Examples whose perturbation norm exceeds ``epsilon`` are rescaled onto the
    sphere; the others are returned unchanged.
    """
    delta = adv_image - image
    delta_size = norm(delta)
    outside = (delta_size > epsilon).float()
    on_sphere = image + epsilon * delta / delta_size
    return outside * on_sphere + (1 - outside) * adv_image


def linf_proj_step(image, epsilon, adv_image):
    """Project ``adv_image`` onto the Linf ball of radius ``epsilon`` around ``image``."""
    clipped_delta = (adv_image - image).clamp(-epsilon, epsilon)
    return image + clipped_delta


def cw_loss(logit, label, target=None):
    """Per-example Carlini-Wagner margin loss (larger means more adversarial).

    Args:
        logit: (N, K) logits with K >= 2.
        label: (N,) true class indices.
        target: optional (N,) target class indices for a targeted attack.

    Returns:
        (N,) tensor:
          * untargeted: max_{i != y} logit_i - logit_y
          * targeted:   logit_t - max_{i != t} logit_i
    """
    batch_index = torch.arange(logit.shape[0], device=logit.device)
    _, top2 = logit.topk(2, dim=1)

    # class whose margin we measure: the target if given, else the true label
    cls = (target if target is not None else label).long()
    cls_is_max = top2[:, 0].eq(cls)
    # strongest competing class: the runner-up when cls is the arg-max,
    # otherwise the arg-max itself. (torch.where instead of "1 - bool_tensor",
    # which PyTorch rejects for bool tensors.)
    other_index = torch.where(cls_is_max, top2[:, 1], top2[:, 0])

    cls_logit = logit[batch_index, cls]
    other_logit = logit[batch_index, other_index]
    if target is not None:
        return cls_logit - other_logit
    return other_logit - cls_logit


def xent_loss(logit, label, target=None):
    """Per-example cross-entropy objective to be maximised by the attack.

    Untargeted: CE w.r.t. the true ``label``. Targeted: negative CE w.r.t.
    ``target``, so increasing the value pushes toward the target class.
    """
    if target is None:
        return F.cross_entropy(logit, label, reduction='none')
    return -F.cross_entropy(logit, target, reduction='none')

class Subspace_Attack(object):
    """Black-box attack whose search directions come from reference models.

    The victim ``model`` is only queried for outputs. Gradients of local
    reference models (``StandardModel`` instances) serve as search directions
    for a two-point finite-difference estimate of the victim's loss gradient.
    That estimate updates a "prior", which in turn drives the image update.

    NOTE(review): assumes ``model(x, unnormalization=False)`` returns logits
    for images in [0, 1], with the same number of classes as the reference
    models — confirm against the wrapper used by the caller.
    """

    def eg_prior_step(self, x, g, lr):
        """Exponentiated-gradient update of a prior that lives in [-1, 1].

        The prior is mapped to [0, 1] and reweighted by exp(+/- lr*g), with
        lr*g clipped to [-eg_clip, eg_clip]. It is then renormalised and mapped
        back to [-1, 1].
        """
        real_x = (x + 1) / 2  # from [-1, 1] to [0, 1]
        lrg = torch.clamp(lr * g, -self.eg_clip, self.eg_clip)
        pos = real_x * torch.exp(lrg)
        neg = (1 - real_x) * torch.exp(-lrg)
        new_x = pos / (pos + neg)
        return new_x * 2 - 1

    def __init__(self, model, dataset='imagenet'):
        """Configure the attack for ``dataset`` and build the reference models.

        Args:
            model: victim (black-box) model, queried in ``attack_untargeted``.
            dataset: 'cifar' selects the CIFAR-10 / L2 configuration; any
                other value selects the ImageNet / Linf configuration.
        """
        self.model = model

        # settings shared by both configurations
        self.delta_size = 0
        self.ref_arch_epoch = 'final'
        self.ref_arch_init_drop = 0.05
        self.ref_arch_max_drop = 0.5
        self.ref_arch_drop_grow_iter = 100
        self.ref_arch_drop_gamma = 0.01
        self.loss = 'cw'  # 'xent', 'cw'
        self.exploration = 1.0
        self.fd_eta = 0.1
        self.eg_clip = 1.0
        self.num_fix_direction = 0
        self.attack_type = 'untargeted'  # 'untargeted', 'targeted'
        self.target_type = 'random'  # 'random', 'least_likely'

        if dataset == 'cifar':
            self.dataset = 'cifar10'
            side = 32
            self.ref_arch = ['alexnet_bn', 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn']
            self.ref_arch_train_data = 'cifar10.1'  # 'full', 'cifar10.1', 'imagenetv2-val'
            self.fix_grad = True
            self.image_lr = 0.5
            self.prior_lr = 0.01
            self.prior_update = 'gd'  # 'eg', 'gd', 'momentum'
            self.norm_type = 'l2'  # 'l2', 'linf'
            self.epsilon = 1 - 1e-1
        else:
            self.dataset = 'imagenet'
            side = 224
            self.ref_arch = ['resnet18', 'resnet34']
            self.ref_arch_train_data = 'imagenetv2-val'  # 'full', 'cifar10.1', 'imagenetv2-val'
            self.fix_grad = False
            self.image_lr = 1. / 255
            self.prior_lr = 100.0
            self.prior_update = 'eg'  # 'eg', 'gd', 'momentum'
            self.norm_type = 'linf'  # 'l2', 'linf'
            self.epsilon = 0.05 - 1e-5
        self.input_size = (3, side, side)

        # reference (surrogate) models that provide gradient directions
        self.ref_models = OrderedDict()
        for ref_arch in self.ref_arch:
            self.ref_models[ref_arch] = StandardModel(
                self.dataset, ref_arch, no_grad=False,
                train_data=self.ref_arch_train_data,
                epoch=self.ref_arch_epoch).eval()

        # update operators (explicit dispatch instead of eval on names)
        prior_steps = {
            'gd': gd_prior_step,
            'momentum': momentum_prior_step,
            'eg': self.eg_prior_step,
        }
        self.prior_step = prior_steps[self.prior_update]
        self.image_step = {'l2': l2_image_step, 'linf': linf_image_step}[self.norm_type]
        self.proj_step = {'l2': l2_proj_step, 'linf': linf_proj_step}[self.norm_type]

        if self.num_fix_direction > 0:
            if len(self.ref_arch) == 0:
                # fixed random directions; the global numpy RNG state is saved
                # and restored so this draw does not advance it
                assert self.dataset == 'cifar10'
                state = np.random.get_state()
                fix_direction = np.random.randn(3072, *self.input_size)[:self.num_fix_direction]
                np.random.set_state(state)
                fix_direction = np.ascontiguousarray(fix_direction)
                self.fix_direction = torch.FloatTensor(fix_direction).cuda()
            else:
                # fixed gradient directions (calculated at clean inputs)
                assert self.num_fix_direction == len(self.ref_arch)

    def _drop_probability(self, step_index):
        """Dropout probability for the reference models at ``step_index``.

        Equal to ``ref_arch_init_drop`` for the first
        ``ref_arch_drop_grow_iter`` steps. After that it approaches
        ``ref_arch_max_drop`` exponentially, with rate ``ref_arch_drop_gamma``.
        """
        if step_index < self.ref_arch_drop_grow_iter:
            return self.ref_arch_init_drop
        return self.ref_arch_max_drop - \
            (self.ref_arch_max_drop - self.ref_arch_init_drop) * \
            np.exp(-(step_index - self.ref_arch_drop_grow_iter) * self.ref_arch_drop_gamma)

    def _choose_target(self, logit, label):
        """Target classes for a targeted attack, or None when untargeted.

        Targets are resampled uniformly until none coincides with the true label.
        """
        if self.attack_type != 'targeted':
            return None
        num_classes = logit.shape[1]
        if self.target_type == 'random':
            target = torch.randint(low=0, high=num_classes, size=label.shape).long().to(label.device)
        elif self.target_type == 'least_likely':
            target = logit.argmin(dim=1)
        else:
            raise NotImplementedError('Unknown target_type: {}'.format(self.target_type))
        invalid = target.eq(label)
        while invalid.any():
            target[invalid] = torch.randint(low=0, high=num_classes,
                                            size=target[invalid].shape).long().to(label.device)
            invalid = target.eq(label)
        return target

    @staticmethod
    def _is_success(adv_pred, label, target):
        """Per-example success mask: misclassified (untargeted) or hits ``target``."""
        if target is None:
            return adv_pred.ne(label)
        return adv_pred.eq(target)

    def _ref_grad(self, adv_image, ref_model, drop, logit_grad, prior, downsampler):
        """Normalised search direction from one reference model.

        Back-propagates the victim's dL/dlogit (``logit_grad``) through
        ``ref_model`` evaluated at ``adv_image``. The result is optionally
        made orthogonal to ``prior`` (``fix_grad``) and normalised per example.
        """
        adv_image = adv_image.detach()
        adv_image.requires_grad = True
        ref_model.zero_grad()
        # StandardModel.forward copies this attribute onto the wrapped cnn
        ref_model.drop = drop
        # tuple() so a list-valued input_size does not force a no-op resize
        if tuple(ref_model.input_size) != tuple(self.input_size):
            ref_logit = ref_model(F.interpolate(adv_image, size=ref_model.input_size[-1],
                                                mode='bilinear', align_corners=True))
        else:
            ref_logit = ref_model(adv_image)

        # vector-Jacobian product with the victim model's logit gradient
        ref_grad = torch.autograd.grad(ref_logit, [adv_image], grad_outputs=[logit_grad])[0]
        ref_grad = downsampler(ref_grad)

        if self.fix_grad and prior.view(prior.shape[0], -1).norm(dim=1).min().item() > 0:
            # ||v||^2 dL/dv - v (v^T dL/dv): the exact derivative also has a
            # -1/||v||^3 factor, dropped here because the result is
            # re-normalised and the symmetric finite difference is
            # insensitive to the direction's sign
            g1 = norm(prior) ** 2 * ref_grad
            g2 = prior * (prior * ref_grad).sum(dim=(1, 2, 3)).view(-1, 1, 1, 1)
            ref_grad = g1 - g2
        return ref_grad / norm(ref_grad)

    def attack_untargeted(self, x_0, y_0, query_limit=20000):
        """Run the attack on a batch and return the adversarial images.

        Despite the name, setting ``self.attack_type = 'targeted'`` runs a
        targeted attack (targets chosen by ``_choose_target``).

        Args:
            x_0: clean images; adversarial images are clamped to [0, 1].
            y_0: true labels.
            query_limit: victim-model query budget per image. One query is
                spent on the clean images and three per iteration (two
                finite-difference probes plus the success check).

        Returns:
            Adversarial images on the CUDA device, within ``self.epsilon`` of
            ``x_0`` in the ``self.norm_type`` norm.
        """
        image = x_0.cuda()
        label = y_0.cuda()
        device = image.device
        self.max_query = query_limit
        adv_image = image.clone()
        batch_size = image.shape[0]

        if self.delta_size > 0:
            # search in a low-resolution space and resize to the input size
            upsampler = lambda x: F.interpolate(x, size=self.input_size[-1], mode='bilinear', align_corners=True)
            downsampler = lambda x: F.interpolate(x, size=self.delta_size, mode='bilinear', align_corners=True)
        else:
            upsampler = downsampler = lambda x: x

        loss_func = {'cw': cw_loss, 'xent': xent_loss}[self.loss]

        # clean query: its logits seed the first gradient step
        logit = self.model(image, unnormalization=False)
        adv_logit = logit.clone()
        target = self._choose_target(logit, label)

        # one prior per example
        if self.delta_size > 0:
            prior = torch.zeros(batch_size, self.input_size[0], self.delta_size, self.delta_size).to(device)
        else:
            prior = torch.zeros(batch_size, *self.input_size).to(device)
        not_done = torch.ones(batch_size, dtype=torch.bool, device=device)

        ref_models = list(self.ref_models.values())
        # 1 query was spent above and each iteration spends 3, so this never
        # exceeds query_limit
        num_steps = max(self.max_query - 1, 0) // 3
        for step_index in range(num_steps):
            drop = self._drop_probability(step_index)

            if ref_models:
                selected = torch.randint(low=0, high=len(ref_models), size=(1,)).long().item()
                ref_model = ref_models[selected]
                # dL/dlogit of the victim at the current adversarial logits
                adv_logit = adv_logit.detach()
                adv_logit.requires_grad = True
                loss = loss_func(adv_logit, label, target).mean()
                logit_grad = torch.autograd.grad(loss, [adv_logit])[0]
                if self.num_fix_direction == 0:
                    direction = self._ref_grad(adv_image, ref_model, drop, logit_grad, prior, downsampler)
                else:
                    # illustration experiment: gradient at the clean inputs
                    assert self.loss == 'cw'
                    assert drop == 0
                    direction = self._ref_grad(image, ref_model, drop, logit_grad, prior, downsampler)
            elif self.num_fix_direction > 0:
                # random combination of the fixed random directions
                # fix_direction: [num_fix_direction, C, H, W]; coeff: [num_fix_direction, N]
                coeff = torch.randn(self.num_fix_direction, prior.shape[0]).to(device)
                direction = (self.fix_direction.unsqueeze(1) *
                             coeff.view(coeff.shape[0], coeff.shape[1], 1, 1, 1)).sum(dim=0)
            else:
                direction = torch.randn_like(prior)

            with torch.no_grad():
                direction = direction / norm(direction)

                # two-point finite difference along +/- direction around the prior
                q1 = upsampler(prior + self.exploration * direction)
                q2 = upsampler(prior - self.exploration * direction)
                l1 = loss_func(self.model(adv_image + self.fd_eta * q1 / norm(q1), unnormalization=False), label, target)
                l2 = loss_func(self.model(adv_image + self.fd_eta * q2 / norm(q2), unnormalization=False), label, target)
                grad = (l1 - l2) / (self.fd_eta * self.exploration)
                grad = grad.view(-1, 1, 1, 1) * direction

                prior = self.prior_step(prior, grad, self.prior_lr)
                grad = upsampler(prior)

                # only examples that are not yet adversarial are moved
                stepped = self.image_step(adv_image, grad, self.image_lr)
                stepped = self.proj_step(image, self.epsilon, stepped)
                stepped = torch.clamp(stepped, 0, 1)
                adv_image = torch.where(not_done.view(-1, 1, 1, 1), stepped, adv_image)

                # success check; its logits also feed the next logit_grad
                adv_logit = self.model(adv_image, unnormalization=False)
                adv_pred = adv_logit.argmax(dim=1)
                not_done = not_done & ~self._is_success(adv_pred, label, target)
                if not not_done.any():
                    break
        return adv_image
