import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from attack import RayS
from attack import HLM_Attack


class HLMasdf_RayS(RayS, HLM_Attack):
    """Hard-label attack combining ``RayS`` with the ``HLM_Attack`` helpers.

    The search loop implemented here flips the signs of progressively smaller
    blocks of a sign vector and bisects the boundary distance along each
    candidate direction, keeping the shortest successful one. The helpers it
    relies on (``get_xadv``, ``get_xorig``, ``dim_attack_space``,
    ``search_succ``, ``benign_succ``, ``compute_distance``, ``update_log``,
    ``stopping_condition``, ``postprocess_result``) and the attributes
    ``print_freq`` / ``lin_search_rad`` come from the parent classes, which
    are defined outside this file.
    """

    def __init__(self, model, enc, dec, order=2, dataset=""):
        RayS.__init__(self, model, order=order, dataset=dataset)
        HLM_Attack.__init__(self, enc, dec)
        self.model = model
        self.order = order
        # Not updated anywhere in this class; TODO confirm where it is used.
        self.z_final = None

    def attack_hard_label(self, x, y, epsilon, target=None, query_limit=10000, seed=None):
        """Run the block-wise sign search and return the best adversarial example.

        Args:
            x: original input.
            y: its label.
            epsilon: distance threshold forwarded to ``update_log`` and
                ``stopping_condition``.
            target: target label for a targeted attack, or ``None``.
            query_limit: maximum number of search iterations; also forwarded
                to ``stopping_condition``.
            seed: optional seed for ``np.random``.

        Returns:
            ``self.x_final`` (the best example found), or ``None`` when
            ``benign_succ`` reports success on the unperturbed input, so no
            attack is needed.
        """
        if self.benign_succ(self.get_xorig(x), y, None):
            print("Fail to classify the image. No need to attack.")
            return None

        shape = self.dim_attack_space(x)
        dim = np.prod(shape[1:])
        if seed is not None:
            np.random.seed(seed)

        # Start from the all-ones sign vector at infinite distance.
        self.d_t = np.inf
        self.sgn_t = torch.ones(shape).cuda()
        self.x_final = self.get_xadv(x, self.d_t, self.sgn_t)
        block_level = 0
        block_ind = 0
        dist = self.compute_distance(self.x_final, x)
        self.update_log(dist, epsilon)

        # Pre-bind the loop index so the final report is well-defined even
        # when the loop body never runs (query_limit <= 0).
        i = -1
        for i in range(query_limit):
            # At level L the flattened sign vector is split into 2**L blocks;
            # each iteration tries flipping the signs of one block.
            block_num = 2 ** block_level
            block_size = int(np.ceil(dim / block_num))
            start, end = block_ind * block_size, min(dim, (block_ind + 1) * block_size)

            attempt = self.sgn_t.clone().view(shape[0], dim)
            attempt[:, start:end] *= -1.
            attempt = attempt.view(shape)

            # Updates self.d_t / self.x_final / self.sgn_t on improvement.
            self.binary_search(x, y, target, attempt)

            # After the last block of this level, move to twice as many
            # (smaller) blocks and restart from the first one.
            block_ind += 1
            if block_ind == 2 ** block_level or end == dim:
                block_level += 1
                print(f"Increased to block level {block_level}")
                block_ind = 0

            dist = self.compute_distance(self.x_final, x)
            self.update_log(dist, epsilon)

            if self.stopping_condition(query_limit, dist, epsilon):
                break

            if i % self.print_freq == 0:
                print(f"Iter {i + 1} d_t {self.d_t:.8f} dist {dist:.8f} queries {self.model.get_num_queries()}")

        print(f"Iter {i + 1} d_t {self.d_t:.6f} dist {dist:.6f} queries {self.model.get_num_queries()}")
        return self.x_final

    def lin_search(self, x, y, target, sgn):
        """Find the first integer step at which stepping along ``sgn`` succeeds.

        Tries d = 1, 2, ..., ``self.lin_search_rad`` with ``get_xadv(x, d, sgn)``.

        Returns:
            The smallest successful ``d``, or ``np.inf`` if none succeeds.
        """
        for d in range(1, self.lin_search_rad + 1):
            if self.search_succ(self.get_xadv(x, d, sgn), y, target):
                return d
        return np.inf

    def binary_search(self, x, y, target, sgn, tol=1e-3):
        """Bisect the boundary distance along ``sgn`` and keep it if it improves.

        Distances are measured along the unit vector ``sgn / ||sgn||``.

        - If a finite best distance ``self.d_t`` already exists, the
          direction is first checked at that distance and rejected if it does
          not succeed there.
        - Otherwise ``lin_search`` supplies the initial upper bound.

        Args:
            x: original input.
            y: its label.
            target: target label, or ``None``.
            sgn: candidate sign direction.
            tol: bisection stops once the bracket width is <= ``tol``.

        Returns:
            True if ``self.d_t``, ``self.x_final`` and ``self.sgn_t`` were
            replaced by a strictly shorter result, False otherwise.
        """
        sgn_norm = torch.norm(sgn)
        sgn_unit = sgn / sgn_norm

        d_start = 0
        if self.d_t < np.inf:
            # Only a direction that already succeeds at the current best
            # distance can improve on it.
            if not self.search_succ(self.get_xadv(x, self.d_t, sgn_unit), y, target):
                return False
            d_end = self.d_t
        else:
            # lin_search steps along the raw (unnormalised) sgn, so scale its
            # step count into a distance along sgn_unit.
            d = self.lin_search(x, y, target, sgn)
            if not d < np.inf:
                return False
            d_end = d * sgn_norm

        while (d_end - d_start) > tol:
            d_mid = (d_start + d_end) / 2.0
            if self.search_succ(self.get_xadv(x, d_mid, sgn_unit), y, target):
                d_end = d_mid
            else:
                d_start = d_mid

        if d_end < self.d_t:
            self.d_t = d_end
            self.x_final = self.get_xadv(x, d_end, sgn_unit)
            self.sgn_t = sgn
            return True
        return False

    def __call__(self, data, label, epsilon=0.3, target=None, target_loader=None, seed=None, query_limit=10000):
        """Validate the arguments, run the attack on ``data`` and post-process it.

        Returns:
            Whatever ``postprocess_result`` produces for the adversarial
            example (which may be ``None`` if no attack was needed).
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)
        # Target class for targeted attacks, otherwise the true label. It is
        # not read in this class; presumably the HLM_Attack helpers use it.
        # TODO confirm.
        self.class_conditional = int(label.item()) if target is None else int(target.item())

        adv = self.attack_hard_label(data, label, epsilon, target=target, seed=seed, query_limit=query_limit)
        return self.postprocess_result(adv)
