from __future__ import print_function
from __future__ import division

import warnings
import time
import sys

from foolbox.attacks.base import Attack
from foolbox.attacks.base import call_decorator
from foolbox.distances import MSE, Linf
import numpy as np
import math


class BoundaryAttackPlusPlus(Attack):
    """A powerful adversarial attack that requires neither gradients
    nor probabilities.

    Notes
    -----
    Features:
    * ability to switch between two types of distances: MSE and Linf.
    * ability to continue previous attacks by passing an instance of the
      Adversarial class
    * ability to pass an explicit starting point; especially to initialize
      a targeted attack
    * without a starting point, initialization uses an efficient
      blended-uniform-noise search
    * ability to specify the batch size

    References
    ----------
    Boundary Attack++ was originally proposed by Chen and Jordan.
    It is a decision-based attack that only requires access to the
    output labels of a model.
    Paper link: https://arxiv.org/abs/1904.02144
    The implementation in Foolbox is based on Boundary Attack.

    """

    @call_decorator
    def __call__(
            self,
            input_or_adv,
            label=None,
            unpack=True,
            iterations=64,
            initial_num_evals=100,
            max_num_evals=10000,
            stepsize_search='grid_search',
            gamma=0.01,
            starting_point=None,
            batch_size=256,
            internal_dtype=np.float64,
            log_every_n_steps=1,
            verbose=False):
        """Applies Boundary Attack++.

        Parameters
        ----------
        input_or_adv : `numpy.ndarray` or :class:`Adversarial`
            The original, correctly classified image. If image is a
            numpy array, label must be passed as well. If image is
            an :class:`Adversarial` instance, label must not be passed.
        label : int
            The reference label of the original image. Must be passed
            if image is a numpy array, must not be passed if image is
            an :class:`Adversarial` instance.
        unpack : bool
            If true, returns the adversarial image, otherwise returns
            the Adversarial object.
        iterations : int
            Number of iterations to run.
        initial_num_evals : int
            Initial number of evaluations for gradient estimation.
            Larger initial_num_evals increases time efficiency, but
            may decrease query efficiency.
        max_num_evals : int
            Maximum number of evaluations for gradient estimation.
        stepsize_search : str
            How to search for the stepsize; choices are
            'geometric_progression' and 'grid_search'.
            'geometric_progression' initializes the stepsize by
            ||x_t - x||_p / sqrt(iteration) and keeps halving it until
            reaching the target side of the boundary. 'grid_search'
            chooses the optimal epsilon over a grid, in the scale of
            ||x_t - x||_p.
        gamma : float
            Controls the binary search threshold theta, which is
            gamma / sqrt(d) for the l2 attack and gamma / d for the
            linf attack (d is the number of input elements).
        starting_point : `numpy.ndarray`
            Adversarial input to use as a starting point, required
            for targeted attacks.
        batch_size : int
            Batch size for model prediction.
        internal_dtype : np.float32 or np.float64
            Higher precision might be slower but is numerically more stable.
        log_every_n_steps : int
            Only every n-th step is logged (logging requires verbose=True).
        verbose : bool
            Controls verbosity of the attack.

        Raises
        ------
        ValueError
            If `stepsize_search` is not a supported choice, or if the
            attack's default distance is neither MSE nor Linf.

        """
        # Reject unknown choices up front; otherwise every iteration would
        # silently skip the update step.
        if stepsize_search not in ('geometric_progression', 'grid_search'):
            raise ValueError(
                "stepsize_search must be 'geometric_progression' or"
                " 'grid_search', got {!r}".format(stepsize_search))

        self.initial_num_evals = initial_num_evals
        self.max_num_evals = max_num_evals
        self.stepsize_search = stepsize_search
        self.gamma = gamma
        self.batch_size = batch_size
        self._starting_point = starting_point
        self.internal_dtype = internal_dtype
        self.log_every_n_steps = log_every_n_steps
        self.verbose = verbose

        # Set constraint based on the distance. Other distances are
        # rejected, since self.constraint would otherwise be left unset
        # or stale from an earlier call.
        if self._default_distance == MSE:
            self.constraint = 'l2'
        elif self._default_distance == Linf:
            self.constraint = 'linf'
        else:
            raise ValueError(
                'Boundary Attack++ only supports the MSE and Linf'
                ' distances, got {}'.format(self._default_distance))

        # Set binary search threshold. input_or_adv is used as an
        # Adversarial here (original_image is read from it and it is
        # handed to self.attack).
        self.shape = input_or_adv.original_image.shape
        self.d = np.prod(self.shape)
        if self.constraint == 'l2':
            self.theta = self.gamma / np.sqrt(self.d)
        else:
            self.theta = self.gamma / self.d

        return self.attack(
            input_or_adv,
            iterations=iterations)

    def attack(
            self,
            a,
            iterations):
        """Run Boundary Attack++ on an Adversarial instance.

        Every model query goes through `a`, which keeps track of the
        adversarials found; this method itself returns nothing.

        Parameters
        ----------
        a : :class:`Adversarial`
            The adversarial to refine.
        iterations : int
            Maximum number of iterations to run.

        """
        self.t_initial = time.time()

        # ===========================================================
        # Increase floating point precision
        # ===========================================================

        self.external_dtype = a.original_image.dtype

        assert self.internal_dtype in [np.float32, np.float64]
        assert self.external_dtype in [np.float32, np.float64]

        # Computing internally in lower precision than the input uses
        # would lose information.
        assert not (self.external_dtype == np.float64 and
                    self.internal_dtype == np.float32)

        a.set_distance_dtype(self.internal_dtype)

        # ===========================================================
        # Construct batch decision function with binary output.
        # ===========================================================
        def decision_function(x):
            """Query the model on `x` in chunks of `self.batch_size`.

            Returns one adversarial decision per input (callers use the
            result as a boolean mask). Inputs are cast back to the
            external dtype before querying.
            """
            if len(x) == 0:
                # np.concatenate would fail on an empty list of chunks.
                return np.zeros(0, dtype=bool)
            outs = [
                a.batch_predictions(
                    x[start:start + self.batch_size].astype(
                        self.external_dtype),
                    strict=False)[1]
                for start in range(0, len(x), self.batch_size)]
            return np.concatenate(outs, axis=0)

        # ===========================================================
        # Initialize time measurements
        # ===========================================================
        self.time_gradient_estimation = 0
        self.time_search = 0
        self.time_initialization = 0

        # ===========================================================
        # Initialize variables, constants, hyperparameters, etc.
        # ===========================================================

        # make sure repeated warnings are shown
        warnings.simplefilter('always', UserWarning)

        # get bounds
        self.clip_min, self.clip_max = a.bounds()

        # ===========================================================
        # Find starting point
        # ===========================================================

        self.initialize_starting_point(a)

        if a.image is None:
            warnings.warn(
                'Initialization failed.'
                ' it might be necessary to pass an explicit starting'
                ' point.')
            return

        self.time_initialization += time.time() - self.t_initial

        assert a.image.dtype == self.external_dtype
        # get original and starting point in the right format
        original = a.original_image.astype(self.internal_dtype)
        perturbed = a.image.astype(self.internal_dtype)

        # ===========================================================
        # Iteratively refine adversarial
        # ===========================================================
        t0 = time.time()

        # Project the initialization to the boundary.
        perturbed, dist_post_update = self.binary_search_batch(
            original, np.expand_dims(perturbed, 0), decision_function)

        dist = self.compute_distance(perturbed, original)

        distance = a.distance.value
        self.time_search += time.time() - t0

        # log starting point
        self.log_step(0, distance)

        for step in range(1, iterations + 1):

            t0 = time.time()

            # =======================================================
            # Gradient direction estimation.
            # =======================================================
            # Choose delta.
            delta = self.select_delta(dist_post_update, step)

            # Choose number of evaluations: grows with sqrt(step),
            # capped at max_num_evals.
            num_evals = int(min([self.initial_num_evals * np.sqrt(step),
                                 self.max_num_evals]))

            # approximate gradient.
            gradf = self.approximate_gradient(decision_function, perturbed,
                                              num_evals, delta)

            if self.constraint == 'linf':
                update = np.sign(gradf)
            else:
                update = gradf
            t1 = time.time()
            self.time_gradient_estimation += t1 - t0

            # =======================================================
            # Update, and binary search back to the boundary.
            # =======================================================
            if self.stepsize_search == 'geometric_progression':
                # find step size.
                epsilon = self.geometric_progression_for_stepsize(
                    perturbed, update, dist, decision_function, step)

                # Update the sample.
                perturbed = self.clip_image(perturbed + epsilon * update,
                                            self.clip_min, self.clip_max)

                # Binary search to return to the boundary.
                perturbed, dist_post_update = self.binary_search_batch(
                    original, perturbed[None], decision_function)

            elif self.stepsize_search == 'grid_search':
                # Grid search for stepsize: 20 log-spaced candidates
                # between 1e-4 * dist and dist.
                epsilons = np.logspace(-4, 0, num=20, endpoint=True) * dist
                epsilons_shape = [len(epsilons)] + len(self.shape) * [1]
                perturbeds = perturbed + epsilons.reshape(
                    epsilons_shape) * update
                perturbeds = self.clip_image(perturbeds,
                                             self.clip_min, self.clip_max)
                idx_perturbed = decision_function(perturbeds)

                if np.sum(idx_perturbed) > 0:
                    # Select the perturbation that yields the minimum
                    # distance after binary search.
                    perturbed, dist_post_update = self.binary_search_batch(
                        original, perturbeds[idx_perturbed],
                        decision_function)
            t2 = time.time()

            self.time_search += t2 - t1

            # compute new distance.
            dist = self.compute_distance(perturbed, original)

            # =======================================================
            # Log the step
            # =======================================================
            # Using foolbox definition of distance for logging.
            if self.constraint == 'l2':
                distance = dist ** 2 / self.d / \
                    (self.clip_max - self.clip_min) ** 2
            elif self.constraint == 'linf':
                distance = dist / (self.clip_max - self.clip_min)
            message = ' (took {:.5f} seconds)'.format(t2 - t0)
            self.log_step(step, distance, message)
            sys.stdout.flush()

        # ===========================================================
        # Log overall runtime
        # ===========================================================

        self.log_time()

    # ===============================================================
    #
    # Other methods
    #
    # ===============================================================

    def initialize_starting_point(self, a):
        """Try to make `a` hold an adversarial to start from.

        In order of priority: keep an adversarial already stored on `a`,
        query the user-provided starting point, or run a blended uniform
        noise search. If nothing adversarial is found this simply returns
        and the caller checks `a.image`.
        """
        starting_point = self._starting_point

        if a.image is not None:
            print(
                'Attack is applied to a previously found adversarial.'
                ' Continuing search for better adversarials.')
            if starting_point is not None:  # pragma: no cover
                warnings.warn(
                    'Ignoring starting_point parameter because the attack'
                    ' is applied to a previously found adversarial.')
            return

        if starting_point is not None:
            a.predictions(starting_point)
            assert a.image is not None, ('Invalid starting point provided.'
                                         ' Please provide a starting point'
                                         ' that is adversarial.')
            return

        # No initialization given: an efficient implementation of
        # BlendedUniformNoiseAttack. First draw uniform noise within the
        # bounds until one sample is adversarial, giving up after more
        # than 1e4 draws.
        num_evals = 0
        while True:
            random_noise = np.random.uniform(self.clip_min, self.clip_max,
                                             size=self.shape)
            _, success = a.predictions(
                random_noise.astype(self.external_dtype))
            num_evals += 1

            if success:
                break
            if num_evals > 1e4:
                return

        # Then binary search the blending factor between the original
        # image and that noise to minimize the l2 distance to the original.
        low = 0.0
        high = 1.0
        while high - low > 0.001:
            mid = (high + low) / 2.0
            blended = (1 - mid) * a.original_image + mid * random_noise
            _, success = a.predictions(blended.astype(self.external_dtype))
            if success:
                high = mid
            else:
                low = mid

    def compute_distance(self, image1, image2):
        """Return the l2 norm or the max-abs (linf) difference of the two
        images, depending on `self.constraint`."""
        if self.constraint == 'l2':
            return np.linalg.norm(image1 - image2)
        elif self.constraint == 'linf':
            return np.max(np.abs(image1 - image2))

    def clip_image(self, image, clip_min, clip_max):
        """ Clip an image, or an image batch,
        with upper and lower threshold (both may be arrays that
        broadcast against `image`). """
        return np.minimum(np.maximum(clip_min, image), clip_max)

    def project(self, original_image, perturbed_images, alphas):
        """ Projection onto given l2 / linf balls in a batch.

        For l2, each alpha is a blending factor between the original
        (alpha=0) and the perturbed image (alpha=1). For linf, each alpha
        is the radius of the box around the original image.
        """
        alphas_shape = [len(alphas)] + [1] * len(self.shape)
        alphas = alphas.reshape(alphas_shape)
        if self.constraint == 'l2':
            projected = (1 - alphas) * original_image + \
                alphas * perturbed_images
        elif self.constraint == 'linf':
            projected = self.clip_image(
                perturbed_images,
                original_image - alphas,
                original_image + alphas
            )
        return projected

    def binary_search_batch(self, original_image, perturbed_images,
                            decision_function):
        """Binary search towards the boundary for a batch of candidates.

        Each candidate is projected towards `original_image` while
        remaining adversarial, until the search interval is below a
        threshold derived from `self.theta`.

        Returns
        -------
        out_image : numpy.ndarray
            The searched candidate closest to `original_image`.
        dist : float
            The distance of that candidate *before* the search (i.e.
            right after the update step); `select_delta` consumes it.
        """
        # Compute distance between each perturbed image and the original.
        dists_post_update = np.array([
            self.compute_distance(original_image, perturbed_image)
            for perturbed_image in perturbed_images])

        # Choose upper thresholds in binary searches based on constraint.
        if self.constraint == 'linf':
            highs = dists_post_update
            # Stopping criteria.
            thresholds = np.minimum(dists_post_update * self.theta,
                                    self.theta)
        else:
            highs = np.ones(len(perturbed_images))
            thresholds = self.theta

        lows = np.zeros(len(perturbed_images))

        # Iterate until every interval is narrower than its threshold.
        while np.max((highs - lows) / thresholds) > 1:
            # projection to mids.
            mids = (highs + lows) / 2.0
            mid_images = self.project(original_image, perturbed_images, mids)

            # Adversarial mids become the new upper bound, the others the
            # new lower bound.
            decisions = decision_function(mid_images)
            lows = np.where(decisions == 0, mids, lows)
            highs = np.where(decisions == 1, mids, highs)

        out_images = self.project(original_image, perturbed_images, highs)

        # Compute distance of the output images to select the best one
        # (only matters with several candidates, i.e. for grid_search).
        dists = np.array([
            self.compute_distance(original_image, out_image)
            for out_image in out_images])
        idx = np.argmin(dists)

        dist = dists_post_update[idx]
        out_image = out_images[idx]
        return out_image, dist

    def select_delta(self, dist_post_update, current_iteration):
        """
        Choose the delta used for gradient estimation.

        The first iteration uses 10% of the input range; afterwards delta
        is at the scale of the distance between x and the perturbed
        sample: sqrt(d) * theta * dist for l2 and d * theta * dist for
        linf.
        """
        if current_iteration == 1:
            delta = 0.1 * (self.clip_max - self.clip_min)
        else:
            if self.constraint == 'l2':
                delta = np.sqrt(self.d) * self.theta * dist_post_update
            elif self.constraint == 'linf':
                delta = self.d * self.theta * dist_post_update

        return delta

    def approximate_gradient(self, decision_function, sample,
                             num_evals, delta):
        """Estimate the gradient direction at `sample`.

        Queries `num_evals` random perturbations of norm `delta` around
        `sample` and averages their directions weighted by the model's
        decisions. Returns an array of shape `self.shape` with unit l2
        norm, or all zeros if the estimate vanishes.
        """
        # Generate random vectors.
        noise_shape = [num_evals] + list(self.shape)
        if self.constraint == 'l2':
            rv = np.random.randn(*noise_shape)
        elif self.constraint == 'linf':
            rv = np.random.uniform(low=-1, high=1, size=noise_shape)

        # Normalize each random vector to unit l2 norm.
        axis = tuple(range(1, 1 + len(self.shape)))
        rv = rv / np.sqrt(np.sum(rv ** 2, axis=axis, keepdims=True))
        perturbed = sample + delta * rv
        perturbed = self.clip_image(perturbed, self.clip_min,
                                    self.clip_max)
        # Recompute the directions actually used after clipping.
        rv = (perturbed - sample) / delta

        # Query the model and map decisions {0, 1} to {-1, +1}.
        decisions = decision_function(perturbed)
        decision_shape = [len(decisions)] + [1] * len(self.shape)
        fval = 2 * decisions.astype(self.internal_dtype).reshape(
            decision_shape) - 1.0

        # Baseline subtraction (only when the decisions are not all equal)
        vals = fval if abs(np.mean(fval)) == 1.0 else fval - np.mean(fval)
        gradf = np.mean(vals * rv, axis=0)

        # Get the gradient direction; a zero estimate is returned as-is
        # instead of dividing 0 / 0 into NaNs.
        gradf_norm = np.linalg.norm(gradf)
        if gradf_norm > 0:
            gradf = gradf / gradf_norm

        return gradf

    def geometric_progression_for_stepsize(self, x, update, dist,
                                           decision_function,
                                           current_iteration):
        """ Geometric progression to search for stepsize.

        Start from dist / sqrt(current_iteration) and keep halving the
        stepsize until x + epsilon * update reaches the adversarial side
        of the boundary. The search also stops once epsilon reaches 0, so
        a non-adversarial `x` cannot cause an infinite loop.
        """
        epsilon = dist / np.sqrt(current_iteration)
        while True:
            updated = self.clip_image(x + epsilon * update,
                                      self.clip_min, self.clip_max)
            success = decision_function(updated[None])[0]
            if success or epsilon == 0:
                break
            epsilon = epsilon / 2.0

        return epsilon

    def log_step(self, step, distance, message='', always=False):
        """Print the distance at `step` (only when verbose), every
        `log_every_n_steps` steps unless `always` is True."""
        if not always and step % self.log_every_n_steps != 0:
            return
        self.printv('Step {}: {:.5e} {}'.format(
            step,
            distance,
            message))

    def log_time(self):
        """Print (when verbose) the total runtime and its breakdown into
        initialization, gradient estimation and search."""
        t_total = time.time() - self.t_initial
        rel_initialization = self.time_initialization / t_total
        rel_gradient_estimation = self.time_gradient_estimation / t_total
        rel_search = self.time_search / t_total

        self.printv('Time since beginning: {:.5f}'.format(t_total))
        self.printv('   {:2.1f}% for initialization ({:.5f})'.format(
            rel_initialization * 100, self.time_initialization))
        self.printv('   {:2.1f}% for gradient estimation ({:.5f})'.format(
            rel_gradient_estimation * 100,
            self.time_gradient_estimation))
        self.printv('   {:2.1f}% for search ({:.5f})'.format(
            rel_search * 100, self.time_search))

    def printv(self, *args, **kwargs):
        """Forward to print() only when self.verbose is set."""
        if self.verbose:
            print(*args, **kwargs)