import time, torch
import numpy as np 

from tqdm import tqdm

from attack import Attack


class OPT_attack(Attack):
    """Hard-label attack that minimises the boundary distance along a search direction.

    A unit direction ``theta`` (in the attack space) is scored by ``g(theta)``:
    the step length at which ``get_xadv(x0, g, theta)`` first satisfies
    ``search_succ``. ``g`` is found with label-only binary searches, and
    ``theta`` is improved by finite-difference gradient estimates followed by
    a step-size line search.
    """

    def __init__(self, model, dataset="", order=2, early_stopping=False):
        super().__init__(model, order=order, dataset=dataset, early_stopping=early_stopping)
        # Number of random directions sampled by the first untargeted initialisation pass
        # (a second pass, if needed, uses five times as many).
        self.num_directions = 100

    def attack_hard_label(self, x, y, epsilon, target=None, query_limit=10000, seed=None, target_loader=None,
                          alpha=0.2, beta=0.001):
        """Attack ``x`` and return an adversarial example, or ``None`` on failure.

        Args:
            x: original input (tensor or array-like); converted to a float CPU tensor.
            y: true label of ``x`` (0-d tensor or plain label).
            epsilon: distortion threshold forwarded to ``update_log`` / ``stopping_condition``.
            target: target label for a targeted attack, ``None`` for an untargeted one.
            query_limit: query budget forwarded to ``stopping_condition``.
            seed: forwarded to ``manual_seed`` before any random sampling.
            target_loader: iterable of ``(xi, yi)`` pairs used to initialise targeted attacks.
            alpha: initial line-search step size along the estimated gradient.
            beta: finite-difference radius used for gradient estimation.

        Returns:
            ``get_xadv(x0, g_theta, best_theta)`` for the best direction found, or
            ``None`` if ``x`` is already misclassified or no valid initial
            direction could be found.
        """
        # as_tensor is a no-op for tensors and avoids an unbound x0 for array-like input.
        x0 = torch.as_tensor(x).float().cpu()
        y0 = y.item() if isinstance(y, torch.Tensor) else y
        # Convert unconditionally: a tensor holding class 0 is falsy and used to stay a tensor.
        if isinstance(target, torch.Tensor):
            target = target.item()

        if self.benign_succ(self.get_xorig(x0), y0, None):
            print("Fail to classify the image. No need to attack.")
            return None

        self.manual_seed(seed)

        # Calculate a good starting point.
        best_theta, g_theta = self.initialize(x0, y0, epsilon, target, target_loader)
        print(f"Initialized in {self.model.get_num_queries()} queries")

        if g_theta == np.inf:
            print("Couldn't find valid initial, failed")
            return None

        dist = self.compute_distance(x0, self.get_xadv(x0, g_theta, best_theta))
        self.update_log(dist, epsilon)

        # Begin gradient descent at the boundary.
        timestart = time.time()
        theta, g2 = best_theta, g_theta
        for i in range(self.num_iterations):
            gradient, min_g1, min_ttt = self._estimate_gradient(x0, y0, theta, g2, beta, target)

            # Re-checked here because gradient estimation itself spends queries.
            if self.stopping_condition(query_limit, dist, epsilon):
                break

            min_theta, min_g2, alpha = self._line_search(x0, y0, theta, g2, gradient, alpha, beta, target)

            # Move to whichever is better: the line-search result or the best sampled direction.
            if min_g2 <= min_g1:
                theta, g2 = min_theta, min_g2
            else:
                theta, g2 = min_ttt, min_g1

            if g2 < g_theta:
                best_theta, g_theta = theta, g2

            if alpha < 1e-4:
                # Step size collapsed: reset it and shrink the finite-difference radius instead.
                alpha = 1.0
                print("Warning: not moving, g2 %lf g_theta %lf" % (g2, g_theta))
                beta = beta * 0.1
                if beta < 1e-8:
                    break

            # Log the distortion of the best example so far, i.e. the one that would be returned.
            dist = self.compute_distance(x0, self.get_xadv(x0, g_theta, best_theta))
            print(f"Iteration {i+1} distortion {dist:.4f} num_queries {self.model.get_num_queries()}")

            self.update_log(dist, epsilon)

            if self.stopping_condition(query_limit, dist, epsilon):
                break

        timeend = time.time()
        x_adv = self.get_xadv(x0, g_theta, best_theta)
        dist = self.compute_distance(x0, x_adv)

        if self.search_succ(x_adv, y0, target):
            print(f"\nAdversarial Example Found Successfully: distortion {dist:.4f} "
                  f"pred {target if target is not None else 'OK'} in "
                  f"{self.model.get_num_queries()} queries \nTime: {timeend-timestart:.4f} seconds")

        self.update_log(dist, epsilon)
        return x_adv

    def _estimate_gradient(self, x0, y0, theta, g2, beta, target, num_samples=10):
        """Estimate the gradient of ``g`` at ``theta`` by random finite differences.

        Returns:
            ``(gradient, min_g1, min_ttt)``: the averaged estimate, plus the
            smallest boundary distance seen among the perturbed directions and
            the direction that achieved it.
        """
        gradient = torch.zeros_like(theta)
        min_g1, min_ttt = float('inf'), None
        for _ in tqdm(range(num_samples)):
            # Unit Gaussian perturbation; sampled on CPU then moved so seeding behaves as before.
            u = torch.randn(*theta.shape).cuda()
            u /= torch.norm(u)
            ttt = theta + beta * u
            ttt /= torch.norm(ttt)
            g1 = self.fine_grained_binary_search_local(x0, y0, ttt, initial_lbd=g2, tol=beta / 500, target=target)
            gradient += (g1 - g2) / beta * u
            if g1 < min_g1:
                min_g1, min_ttt = g1, ttt
        return 1.0 / num_samples * gradient, min_g1, min_ttt

    @staticmethod
    def _normalize_direction(theta):
        """Return ``theta`` scaled to unit L2 norm, guarding the denominator against zero."""
        return theta / (torch.norm(theta) + 1e-8)

    def _line_search(self, x0, y0, theta, g2, gradient, alpha, beta, target):
        """Search a step size along ``-gradient`` that lowers ``g`` below ``g2``.

        Returns:
            ``(min_theta, min_g2, alpha)``: the best direction found (``theta``
            itself if nothing improved), its boundary distance, and the updated
            step size, which the caller carries into the next iteration.
        """
        tol = beta / 500
        min_theta, min_g2 = theta, g2

        # Grow the step (doubling after every trial) while it keeps improving.
        for _ in tqdm(range(15)):
            new_theta = self._normalize_direction(theta - alpha * gradient)
            new_g2 = self.fine_grained_binary_search_local(x0, y0, new_theta, initial_lbd=min_g2, tol=tol,
                                                           target=target)
            alpha = alpha * 2
            if new_g2 < min_g2:
                min_theta, min_g2 = new_theta, new_g2
            else:
                break

        # No growth step helped: shrink the step until one beats g2.
        if min_g2 >= g2:
            for _ in tqdm(range(15)):
                alpha = alpha * 0.25
                new_theta = self._normalize_direction(theta - alpha * gradient)
                new_g2 = self.fine_grained_binary_search_local(x0, y0, new_theta, initial_lbd=min_g2, tol=tol,
                                                               target=target)
                if new_g2 < g2:
                    min_theta, min_g2 = new_theta, new_g2
                    break

        return min_theta, min_g2, alpha

    def initialize(self, x0, y0, epsilon, target=None, target_loader=None):
        """Find a starting unit direction and its boundary distance.

        Untargeted (``target is None``): try ``num_directions`` random Gaussian
        directions, then five times as many if none succeeds. Targeted: use
        directions from ``x0`` towards samples of ``target_loader`` that the
        model predicts as ``target``. ``epsilon`` is currently unused.

        Returns:
            ``(best_theta, g_theta)``, or ``(None, np.inf)`` if nothing succeeded.
        """
        if target is not None:
            return self._search_target_samples(x0, y0, target, target_loader)

        timestart = time.time()
        num_directions = self.num_directions
        print("Searching for the initial direction on %d random directions: " % (num_directions))
        best_theta, g_theta = self._search_random_directions(x0, y0, num_directions)

        if g_theta == np.inf:
            num_directions = self.num_directions * 5
            print("Searching for the initial direction on %d random directions: " % (num_directions))
            timestart = time.time()
            best_theta, g_theta = self._search_random_directions(x0, y0, num_directions)

        timeend = time.time()
        if g_theta != np.inf:
            dist = self.compute_distance(x0, self.get_xadv(x0, g_theta, best_theta))
            print(f"==========> Found best distortion {dist:.4f} in {timeend-timestart:.4f} seconds "
                  f"using {self.model.get_num_queries()} queries")
        return best_theta, g_theta

    def _search_random_directions(self, x0, y0, num_directions):
        """Untargeted init: best ``(theta, g)`` over ``num_directions`` random directions."""
        best_theta, g_theta = None, np.inf
        for _ in range(num_directions):
            theta = torch.randn(*self.dim_attack_space(x0)).cuda()
            # Only directions whose unscaled step already succeeds are refined.
            if not self.search_succ(self.get_xadv(x0, 1, theta), y0, None):
                continue
            initial_lbd = torch.norm(theta)
            theta /= initial_lbd
            lbd = self.fine_grained_binary_search(x0, y0, theta, initial_lbd, g_theta)
            if lbd < g_theta:
                best_theta, g_theta = theta, lbd
                dist = self.compute_distance(x0, self.get_xadv(x0, g_theta, best_theta))
                print(f"--------> Found distortion {dist:.4f}")
        return best_theta, g_theta

    def _search_target_samples(self, x0, y0, target, target_loader):
        """Targeted init: best ``(theta, g)`` over directions towards samples predicted as ``target``."""
        best_theta, g_theta = None, np.inf
        print("Searching for the initial direction on %d samples: " % (self.target_limit_break))

        # Iterate through the target dataset; process exactly target_limit_break entries.
        for i, (xi, _) in enumerate(target_loader):
            if i >= self.target_limit_break:
                break

            yi_pred = self.model.predict_label(self.get_xorig(xi))
            if not self._equal(yi_pred, target):
                continue

            theta = self.to_attack_space(xi.cpu()) - self.to_attack_space(x0)
            theta = theta.cuda()
            initial_lbd = torch.norm(theta)
            theta /= initial_lbd
            lbd = self.fine_grained_binary_search(x0, y0, theta, initial_lbd, g_theta, target=target)
            if lbd < g_theta:
                best_theta, g_theta = theta, lbd
                dist = self.compute_distance(x0, self.get_xadv(x0, g_theta, best_theta))
                print(f"--------> Found distortion {dist:.4f}")

        return best_theta, g_theta

    def fine_grained_binary_search_local(self, x0, y0, theta, initial_lbd=1.0, tol=1e-5, target=None):
        """Locate the boundary distance along ``theta``, starting near ``initial_lbd``.

        First brackets the boundary by moving ``initial_lbd`` up or down in 1%
        steps until ``search_succ`` flips. Then bisects the bracket down to
        width ``tol``, using at most 3000 bisection steps.

        Returns:
            The upper end of the final bracket (a step known to succeed). Returns
            ``1e6`` if the upper bracket passes ``1e3`` without success. The large
            finite sentinel presumably keeps ``(g1 - g2) / beta`` finite in
            gradient estimation.
        """
        lbd = initial_lbd

        if not self.search_succ(self.get_xadv(x0, lbd, theta), y0, target):
            lbd_lo = lbd
            lbd_hi = lbd * 1.01
            while not self.search_succ(self.get_xadv(x0, lbd_hi, theta), y0, target):
                lbd_hi = lbd_hi * 1.01
                if lbd_hi > 1e3:
                    return 1e6
        else:
            lbd_hi = lbd
            lbd_lo = lbd * 0.99
            # NOTE(review): unbounded; terminates only once a small step fails, which relies on
            # x0 itself not satisfying search_succ (attack_hard_label checks benign_succ first).
            while self.search_succ(self.get_xadv(x0, lbd_lo, theta), y0, target):
                lbd_lo = lbd_lo * 0.99

        while_budget = 3000
        while (lbd_hi - lbd_lo) > tol and while_budget:
            while_budget -= 1
            lbd_mid = (lbd_lo + lbd_hi) / 2.0
            if self.search_succ(self.get_xadv(x0, lbd_mid, theta), y0, target):
                lbd_hi = lbd_mid
            else:
                lbd_lo = lbd_mid

        return lbd_hi

    def fine_grained_binary_search(self, x0, y0, theta, initial_lbd, current_best, tol=1e-5, target=None):
        """Bisect for the boundary distance along ``theta`` in ``[0, min(initial_lbd, current_best)]``.

        The upper bound is assumed to succeed. When ``initial_lbd <= current_best``,
        callers pass the norm of a direction they have already checked.

        Returns:
            ``np.inf`` if ``initial_lbd > current_best`` and the step
            ``current_best`` does not succeed, since this direction cannot beat
            the current best. Otherwise returns the upper end of the final
            bisection bracket, using at most 3000 steps.
        """
        if initial_lbd > current_best:
            # Exit if we have correct despite chosen theta
            if not self.search_succ(self.get_xadv(x0, current_best, theta), y0, target):
                return np.inf
            lbd = current_best
        else:
            lbd = initial_lbd

        lbd_hi = lbd
        lbd_lo = 0.0

        while_budget = 3000
        while (lbd_hi - lbd_lo) > tol and while_budget:
            while_budget -= 1
            lbd_mid = (lbd_lo + lbd_hi) / 2.0
            if self.search_succ(self.get_xadv(x0, lbd_mid, theta), y0, target):
                # Hit desired side of decision boundary
                lbd_hi = lbd_mid
            else:
                # Hit somewhere else
                lbd_lo = lbd_mid
        return lbd_hi

    def __call__(self, data, label, epsilon, target=None, target_loader=None, query_limit=40000, seed=None):
        """Validate arguments, run ``attack_hard_label`` (forwarding ``seed``) and post-process the result."""
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)

        adv = self.attack_hard_label(data, label, epsilon, target, query_limit=query_limit, seed=seed,
                                     target_loader=target_loader)
        return self.postprocess_result(adv)
    
        
