import numpy as np
import torch

from attack import Attack

class HSJA(Attack):
    """Decision-based attack in the style of HopSkipJumpAttack (HSJA).

    After finding an adversarial starting point, each iteration estimates the
    boundary-normal direction from hard-label queries, steps along it, and
    binary-searches back towards the original input. The per-iteration code
    paths handle ``order`` 2 and ``np.inf``.

    Args:
        model: wrapped classifier; this class reads ``model.bounds`` and
            ``model.get_num_queries()`` from it.
        dataset: forwarded to ``Attack.__init__``.
        order: attack norm, 2 or ``np.inf``.
        num_iterations: exclusive upper bound of the main-loop counter.
        gamma: scale of the binary-search stopping threshold ``theta``.
        stepsize_search: ``'geometric_progression'`` or ``'grid_search'``.
        init_num_evals: base number of gradient-estimation queries (scaled by
            ``sqrt(j + 1)`` each iteration). Random-noise initialization also
            gives up once the model's total query count exceeds this value.
        verbose: print per-iteration progress.
        early_stopping: forwarded to ``Attack.__init__``.
        max_num_evals: cap on gradient-estimation queries per iteration;
            ``None`` falls back to ``num_iterations``.
    """

    def __init__(self,
                 model,
                 dataset="",
                 order=2,
                 num_iterations=10000,
                 gamma=1.0,
                 stepsize_search='geometric_progression',
                 init_num_evals=100,
                 verbose=True,
                 early_stopping=False,
                 max_num_evals=None):
        Attack.__init__(self, model, early_stopping=early_stopping, order=order, dataset=dataset)
        self.num_iterations = num_iterations
        self.gamma = gamma
        print(f"Using gamma={gamma}")
        self.stepsize_search = stepsize_search
        self.init_num_evals = init_num_evals
        # The query cap is conceptually separate from the iteration count;
        # defaulting to num_iterations keeps the previous behaviour.
        self.max_num_evals = num_iterations if max_num_evals is None else max_num_evals
        self.verbose = verbose
        self.x_orig = None

    def to_attack_space(self, x):
        """Map ``x`` into attack space by delegating to ``self.get_xorig``."""
        return self.get_xorig(x)

    def to_image_space(self, v):
        """Identity: this attack works directly on images."""
        return v

    def clip_image(self, x, lb=None, rb=None):
        """Clip ``x`` element-wise into ``[lb, rb]`` and return a NumPy array.

        ``lb``/``rb`` default to ``self.bounds[0]``/``self.bounds[1]`` and may be
        arrays broadcastable against ``x`` (``project`` passes per-element
        bounds). Torch tensors are moved to the CPU and converted first.
        """
        if lb is None:
            lb = self.bounds[0]
        if rb is None:
            rb = self.bounds[1]

        if isinstance(x, torch.Tensor):
            x = x.cpu().numpy()

        return np.minimum(np.maximum(lb, x), rb)

    def get_xadv(self, x, d=None, v=None):
        """Return the clipped adversarial candidate.

        HSJA iterates over the adversarial attempt directly rather than over a
        noise vector, so with ``d``/``v`` omitted this simply clips ``x``.
        Otherwise it returns ``clip(x + d * v)``; a tensor ``x`` is moved to
        CUDA together with ``v``.
        """
        if d is None or v is None:
            # passthrough
            return self.clip_image(x)

        if isinstance(x, torch.Tensor):
            x = x.cuda()
            if not isinstance(v, torch.Tensor):
                v = torch.tensor(v, dtype=torch.float)
            v = v.cuda()

        out = x + (d * v)
        # `.float()` only exists on tensors; NumPy results are clipped as-is.
        if isinstance(out, torch.Tensor):
            out = out.float()
        return self.clip_image(out)

    def hsja(self, data, label, epsilon, target, target_loader, query_limit, seed):
        """Run the attack loop on one (NumPy) input.

        Returns:
            The clipped adversarial example, or ``None`` when ``benign_succ``
            already holds for the clean input, initialization fails, or the
            logged distances become NaN.
        """
        if self.benign_succ(self.get_xorig(data), label, None):
            print("Fail to classify the image. No need to attack.")
            return None

        self.manual_seed(seed)

        # Attack-space dimensionality sets the binary-search threshold.
        d = int(self.dim_attack_space(data, prod=True))
        if self.order == 2:
            theta = self.gamma / (np.sqrt(d) * d)
        else:
            theta = self.gamma / (d ** 2)

        attempt = self.initialize(data, label, target, target_loader)
        if attempt is None:
            print("Early exit")
            return None

        dist = self.compute_distance(self.get_xadv(attempt), data)
        self.update_log(dist, epsilon)

        print(f"Used {self.model.get_num_queries()} queries to initialize.")

        # Project the initialization onto the decision boundary.
        attempt, dist_post_update = self.binary_search_batch(data, attempt, label,
                                                             theta, target)
        dist = self.compute_distance(self.get_xadv(attempt), data)
        self.update_log(dist, epsilon)

        if self.stopping_condition(query_limit, dist, epsilon):
            return self.get_xadv(attempt)

        for j in range(2, self.num_iterations):
            # Probe radius: a fixed fraction of the input range on the first
            # pass, afterwards proportional to the last pre-search distance.
            if j == 2:
                delta = 0.1 * (self.model.bounds[1] - self.model.bounds[0])
            elif self.order == 2:
                delta = np.sqrt(d) * theta * dist_post_update
            elif self.order == np.inf:
                delta = d * theta * dist_post_update

            # Gradient-estimation query count grows with sqrt(j + 1), capped.
            num_evals = int(self.init_num_evals * np.sqrt(j + 1))
            num_evals = int(min(num_evals, self.max_num_evals))

            gradf = self.approximate_gradient(attempt, label, num_evals, delta, target)
            update = np.sign(gradf) if self.order == np.inf else gradf
            self.update_gradient(update)

            if self.stepsize_search == 'geometric_progression':
                eps = self.geometric_progression_for_stepsize(attempt, label,
                                                               update, dist, j + 1, target)
                attempt = self.get_xadv(attempt + eps * update)
                # Binary search back to the boundary.
                attempt, dist_post_update = self.binary_search_batch(
                    data, attempt[None], label, theta, target)

            elif self.stepsize_search == 'grid_search':
                epsilons = np.logspace(-4, 0, num=20, endpoint=True) * dist
                epsilons_shape = [20] + len(self.dim_attack_space(data)) * [1]
                attempts = attempt + epsilons.reshape(epsilons_shape) * update
                attempts = self.get_xadv(attempts)
                # decision_function yields 0/1 ints; a boolean mask is needed
                # to select the successful candidates (integer indexing would
                # just pick rows 0 and 1).
                idx_attempt = self.decision_function(attempts, label, target).astype(bool)

                if idx_attempt.any():
                    # Keep the candidate closest to the original after search.
                    attempt, dist_post_update = self.binary_search_batch(
                        data, attempts[idx_attempt], label, theta, target)

            dist = self.compute_distance(self.get_xadv(attempt), data)
            self.update_log(dist, epsilon)
            if np.isnan(self.log).any():
                print("Unstable state")
                return None

            if self.verbose:
                print(f'iteration: {j}, L{self.order} distance {dist:.4f} queries {self.model.get_num_queries()} num evals {num_evals}')

            if self.stopping_condition(query_limit, dist, epsilon):
                return self.get_xadv(attempt)

        return self.get_xadv(attempt)

    def decision_function(self, images, label, target=None):
        """
        Decision function output 1 on the desired side of the boundary,
        0 otherwise.

        Returns an int array of length ``len(images)``; convert with
        ``.astype(bool)`` before using it as an index mask.
        """
        res = self.search_succ(images, label, target=target, expand=True)
        return res.astype(int).reshape((len(images),))

    def approximate_gradient(self, sample, label, num_evals, delta, target):
        """Estimate the boundary-normal direction at ``sample`` from hard labels.

        Draws ``num_evals`` random directions, each normalised to unit L2 norm.
        The samples are Gaussian for ``order == 2`` and uniform in [-1, 1] for
        ``np.inf``. It then queries the clipped points
        ``sample + delta * direction`` and averages the directions weighted by
        the +1/-1 decisions, subtracting the mean decision as a baseline.

        Returns:
            The estimated direction, scaled to unit (flattened) L2 norm.
        """
        noise_shape = [num_evals] + list(self.dim_attack_space(sample))
        if self.order == 2:
            rv = np.random.randn(*noise_shape)
        elif self.order == np.inf:
            rv = np.random.uniform(low=-1, high=1, size=noise_shape)

        # Normalise each direction over every non-batch axis, whatever the
        # sample rank.
        rv = rv / np.sqrt(np.sum(rv ** 2, axis=tuple(range(1, rv.ndim)), keepdims=True))
        perturbed = self.blend(1, sample, delta, rv)
        # Recompute the directions after clipping so they match the queried points.
        rv = (perturbed - sample) / delta

        decisions = self.decision_function(self.get_xadv(perturbed), label, target)
        decision_shape = [len(decisions)] + [1] * len(sample.shape)
        fval = 2 * decisions.astype(float).reshape(decision_shape) - 1.0

        # Baseline subtraction (only meaningful when decisions are mixed).
        mean_fval = np.mean(fval)
        if mean_fval == 1.0:  # every probe is adversarial
            gradf = np.mean(rv, axis=0)
        elif mean_fval == -1.0:  # no probe is adversarial
            gradf = -np.mean(rv, axis=0)
        else:
            fval -= mean_fval
            gradf = np.mean(fval * rv, axis=0)

        return gradf / np.linalg.norm(gradf)

    def blend(self, a, x, b, v):
        """Return ``clip(a * x + b * v)`` as a NumPy array.

        Operands that are not already tensors are converted to float32 CUDA
        tensors, so a CUDA device is required.
        """
        if not isinstance(v, torch.Tensor):
            v = torch.tensor(v, dtype=torch.float).cuda()
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float).cuda()
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, dtype=torch.float).cuda()
        if not isinstance(b, torch.Tensor):
            b = torch.tensor(b, dtype=torch.float).cuda()

        out = (a * x) + (b * v)
        return self.clip_image(out)

    def project(self, original_image, attempts, alphas):
        """Move each candidate in ``attempts`` towards ``original_image``.

        ``alphas`` holds one coefficient per candidate. For ``order == 2``
        the result is ``(1 - alpha) * original + alpha * attempt``. For
        ``order == np.inf`` each candidate is clipped into a box of radius
        ``alpha`` around the original.
        """
        # One alpha per candidate, broadcast over the remaining axes, so that
        # several candidates (grid_search) can be projected at once.
        alphas = np.asarray(alphas).reshape([-1] + [1] * (np.ndim(attempts) - 1))
        if self.order == 2:
            return (1 - alphas) * original_image + alphas * attempts
        elif self.order == np.inf:
            return self.clip_image(
                attempts,
                original_image - alphas,
                original_image + alphas,
            )

    def binary_search_batch(self, original_image, attempts, label, theta, target):
        """Binary-search every candidate towards the boundary; keep the closest.

        Args:
            original_image: the clean input.
            attempts: adversarial candidates, indexed along the first axis.
            label: true label forwarded to ``decision_function``.
            theta: stopping threshold for the search interval.
            target: target label (or ``None``) forwarded to ``decision_function``.

        Returns:
            ``(out_image, dist)``. ``out_image`` is the searched candidate
            closest to ``original_image``. ``dist`` is that candidate's
            distance measured *before* the search; ``hsja`` uses it to size
            the next ``delta``.
        """
        dists_post_update = np.array([
            self.compute_distance(original_image, self.get_xadv(attempt))
            for attempt in attempts
        ])

        # L-inf: search the box radius, bounded by the current distance.
        # L2: search the blend factor in [0, 1].
        if self.order == np.inf:
            highs = dists_post_update
            thresholds = np.minimum(dists_post_update * theta, theta)
        else:
            highs = np.ones(len(attempts))
            thresholds = theta
        lows = np.zeros(len(attempts))

        while np.max((highs - lows) / thresholds) > 1:
            mids = (highs + lows) / 2.0
            mid_images = self.project(original_image, attempts, mids)
            decisions = self.decision_function(self.get_xadv(mid_images), label, target)
            # `lows` tracks the last non-adversarial coefficient, `highs` the
            # last adversarial one.
            lows = np.where(decisions == 0, mids, lows)
            highs = np.where(decisions == 1, mids, highs)

        out_images = self.project(original_image, attempts, highs)

        # Pick the best candidate (only matters when several were searched,
        # i.e. grid_search).
        dists = np.array([
            self.compute_distance(original_image, self.get_xadv(out_image))
            for out_image in out_images
        ])
        idx = np.argmin(dists)

        return out_images[idx], dists_post_update[idx]

    def blended_uniform_search(self, x, y, target, initial=None):
        """Find an adversarial starting point, then blend it towards ``x``.

        Without ``initial``, this samples uniform noise within ``model.bounds``
        until the blended point is adversarial. With ``initial``, the
        attack-space version of that input is tested instead. It returns
        ``None`` once the model's total query count exceeds
        ``init_num_evals``. Otherwise it binary-searches the blend factor to
        within 0.001 and returns the blend on the adversarial side.
        """
        while True:
            if initial is not None:
                initial_point = self.to_attack_space(initial)
            else:
                initial_point = np.random.uniform(*self.model.bounds, size=self.dim_attack_space(x))
            success = self.decision_function(self.blend(0, x, 1, initial_point), y, target)[0]
            if success:
                break

            if self.model.get_num_queries() > self.init_num_evals:
                print("Initialization failed!")
                return None

        # Binary search on the blend factor to minimise distance to x.
        low, high = 0.0, 1.0
        while high - low > 0.001:
            mid = (high + low) / 2.0
            blended = self.blend(1 - mid, x, mid, initial_point)
            if self.decision_function(self.get_xadv(blended), y, target):
                high = mid
            else:
                low = mid

        return self.blend((1 - high), x, high, initial_point)

    def initialize(self, input_xi, label, target, target_loader):
        """
        Efficient Implementation of BlendedUniformNoiseAttack in Foolbox.

        Untargeted (``target is None``): start from adversarial uniform noise.
        Targeted: scan ``target_loader`` (at most ``self.target_limit_break + 1``
        items) for the first sample satisfying ``benign_succ`` with ``target``,
        and blend from it. Returns ``None`` if no starting point is found.
        """
        if target is None:
            return self.blended_uniform_search(input_xi, label, target)

        for i, (xi, yi) in enumerate(target_loader):
            if i > self.target_limit_break:
                break
            if not self.benign_succ(self.get_xorig(xi), yi, target):
                continue
            return self.blended_uniform_search(input_xi, yi, target, initial=xi)

        return None

    def geometric_progression_for_stepsize(self, x, label, update, dist, j, target):
        """
        Geometric progression to search for stepsize.
        Start at ``dist / sqrt(j)`` and keep halving the stepsize (at most
        5000 times) until the step lands on the desired side of the boundary.
        """
        epsilon = dist / np.sqrt(j)

        def phi(step):
            # Query the clipped point: hsja clips the accepted step with
            # get_xadv, so test exactly the point that will be used.
            new = self.get_xadv(x + step * update)
            return self.decision_function(new[None], label, target)

        budget = 5000
        while not phi(epsilon) and budget:
            epsilon /= 2.0
            budget -= 1

        return epsilon

    def __call__(self, data, label, epsilon, target=None, target_loader=None, query_limit=20000, seed=None):
        """Attack ``data`` (torch tensor) whose true class is ``label``.

        Inputs are converted to NumPy before ``hsja`` runs; the result goes
        through ``postprocess_result``.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)

        data = data.cpu().numpy()
        label = label.cpu().numpy()
        self.x_orig = data

        # Test the type, not truthiness: a target class of 0 must still be
        # converted, and multi-element tensors have no truth value.
        if isinstance(target, torch.Tensor):
            target = target.cpu().numpy()
        adv = self.hsja(data, label, epsilon, target, target_loader, query_limit, seed)

        return self.postprocess_result(adv)
 
