import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from attack import RayS
from attack import Sampling_Attack


class Sampling_RayS(RayS):
    """RayS hard-label attack with a tunable block-refinement schedule.

    Each iteration flips the signs of one contiguous block of the flattened
    search direction ``self.sgn_t``. The direction is split into
    ``2 ** block_level`` blocks. After ``(2 ** block_level) * b`` attempts at
    the current level, ``block_level`` is increased by ``a``, so later
    iterations flip progressively smaller blocks.

    Args:
        model: Victim model; must provide ``get_num_queries()``.
        order: Forwarded to ``RayS`` and stored on ``self``.
        dataset: Forwarded to ``RayS``.
        early_stopping: Forwarded to ``RayS``.
        a: Number of levels added to ``block_level`` at each refinement.
        b: Multiplier on the number of attempts spent at each level.
    """

    def __init__(self, model, order=2, dataset="", early_stopping=False, a=1, b=2):
        RayS.__init__(self, model, order=order, dataset=dataset, early_stopping=early_stopping)
        self.model = model
        self.order = order
        self.a = a  # "upsampling" parameter: levels added per refinement step
        self.b = b  # "downsampling" parameter: attempts per level = (2 ** level) * b
        print(f"\tRayS using a={a}, b={b}")
        # Not used within this class; presumably read by other code — confirm.
        self.z_final = None

    def attack_hard_label(self, x, y, epsilon, target=None, query_limit=10000, seed=None):
        """Run the attack on one input and return the adversarial example.

        Args:
            x: Original input. ``get_xorig(x)`` is used for the benign check,
                and ``dim_attack_space(x)`` gives the search-direction shape.
            y: True label of ``x``.
            epsilon: Distance threshold passed to ``update_log`` and
                ``stopping_condition``.
            target: Optional target label forwarded to ``binary_search``.
            query_limit: Maximum number of loop iterations; also passed to
                ``stopping_condition``.
            seed: If not ``None``, seeds NumPy's global RNG.

        Returns:
            ``self.x_final`` when the loop finishes, or ``None`` when
            ``benign_succ`` reports that the clean input already succeeds.
        """
        if self.benign_succ(self.get_xorig(x), y, None):
            print("Fail to classify the image. No need to attack.")
            return None

        shape = self.dim_attack_space(x)
        dim = np.prod(shape[1:])  # number of elements per batch entry
        if seed is not None:
            np.random.seed(seed)

        # Start from an all-ones sign direction at infinite radius.
        self.d_t = np.inf
        self.sgn_t = torch.ones(shape).cuda()
        self.x_final = self.get_xadv(x, self.d_t, self.sgn_t)
        block_level = 0
        block_ind = 0
        block_tracker = 0  # attempts made at the current block_level
        dist = self.compute_distance(self.x_final, x)
        self.update_log(dist, epsilon)

        # Keeps the final log line valid when query_limit <= 0 (loop body never runs).
        i = -1
        for i in range(query_limit):
            block_num = 2 ** block_level
            block_size = int(np.ceil(dim / block_num))
            start, end = block_ind * block_size, min(dim, (block_ind + 1) * block_size)

            # Flip the signs of one contiguous block of the flattened direction.
            attempt = self.sgn_t.clone().view(shape[0], dim)
            attempt[:, start:end] *= -1.
            attempt = attempt.view(shape)

            # Presumably updates self.sgn_t / self.d_t / self.x_final when the
            # attempt improves on the current best — confirm in RayS.
            self.binary_search(x, y, target, attempt)

            block_ind += 1
            block_tracker += 1
            # ">=" rather than "==" so a non-integer b still triggers refinement;
            # for integer b both comparisons fire on the same iteration.
            if block_tracker >= block_num * self.b:
                block_level += self.a
                block_tracker = 0
                print(f"Increased to block level {block_level}")

            # Wrap around once every block at the (possibly new) level was tried,
            # or once the last block reached the end of the flattened direction.
            if block_ind == 2 ** block_level or end == dim:
                print("Reset block index.")
                block_ind = 0

            dist = self.compute_distance(self.x_final, x)
            self.update_log(dist, epsilon)

            if self.stopping_condition(query_limit, dist, epsilon):
                break

            if i % self.print_freq == 0:
                print(f"Iter {i + 1} d_t {self.d_t:.8f} dist {dist:.8f} queries {self.model.get_num_queries()}")

        print(f"Iter {i + 1} d_t {self.d_t:.6f} dist {dist:.6f} queries {self.model.get_num_queries()}")
        return self.x_final

    def __call__(self, data, label, epsilon=0.3, target=None, target_loader=None, seed=None, query_limit=10000):
        """Validate the arguments, run the attack and post-process the result.

        ``self.class_conditional`` is set to ``target`` when one is given,
        otherwise to the true ``label``. The raw output of
        ``attack_hard_label``, which may be ``None``, is handed to
        ``postprocess_result``.
        """
        super().validate_args(data, label, epsilon, target, target_loader, query_limit)
        self.class_conditional = int(label.item()) if target is None else int(target.item())

        adv = self.attack_hard_label(data, label, epsilon, target=target, seed=seed, query_limit=query_limit)
        return self.postprocess_result(adv)
