import torch

import utils

# args = parser.parse_args()


def expand_vector(x, size, image_size):
    """Zero-pad a flat vector into the top-left corner of full-size images.

    ``x`` is reshaped to ``(B, 3, size, size)`` and copied into the
    ``[:size, :size]`` corner of an all-zero
    ``(B, 3, image_size, image_size)`` tensor.

    Args:
        x: Tensor whose element count is a multiple of ``3 * size * size``.
        size: Side length of the populated square.
        image_size: Side length of the returned images; must be >= ``size``
            for the slice assignment to succeed.

    Returns:
        A tensor of shape ``(B, 3, image_size, image_size)`` allocated with
        ``torch.zeros`` (default dtype and device), where ``B`` is inferred
        from the number of elements in ``x``.
    """
    x = x.view(-1, 3, size, size)
    # Take the batch size from the reshaped input rather than hard-coding 1,
    # so inputs holding several (3, size, size) blocks no longer fail on the
    # slice assignment below. For a single block the result is unchanged.
    z = torch.zeros(x.size(0), 3, image_size, image_size)
    z[:, :, :size, :size] = x
    return z


def normalize(x):
    """Apply the ``'imagenet'`` normalization from ``utils`` to ``x``.

    Thin wrapper around ``utils.apply_normalization``; whatever that helper
    returns is passed back unchanged.
    """
    dataset_name = 'imagenet'
    normalized = utils.apply_normalization(x, dataset_name)
    return normalized


def get_probs(model, x, y):
    """Return the softmax probability the model assigns to the classes in ``y``.

    Args:
        model: Callable invoked as ``model(x, unnormalization=False)``. Its
            output is treated as ``(batch, num_classes)`` scores, since the
            class lookup below selects along dimension 1.
        x: Input tensor; moved to the GPU with ``.cuda()`` before the call.
        y: 1-D integer index tensor of the classes to look up; also moved to
            the GPU.

    Returns:
        A detached tensor holding the columns ``y`` of
        ``softmax(output, dim=1)``, i.e. shape ``(batch, len(y))``.
    """
    output = model(x.cuda(), unnormalization=False)
    # Name the class dimension explicitly: ``torch.nn.Softmax()`` without
    # ``dim`` emits a deprecation warning and guesses the axis from the input
    # rank. ``dim=1`` is what that guess resolves to for 2-D scores and it
    # matches the ``index_select`` axis below.
    probs = torch.softmax(output, dim=1).detach()
    return torch.index_select(probs, 1, y.cuda())


def get_preds(model, x):
    """Return the model's highest-scoring class index per row, on the CPU.

    The input is moved to the GPU, the model is called with
    ``unnormalization=False``, and the index of the maximum along dimension 1
    of its output is returned as a CPU tensor.
    """
    scores = model(x.cuda(), unnormalization=False)
    # ``max`` over a dimension yields (values, indices); only indices matter.
    best = scores.data.max(1)
    return best.indices.cpu()


class SimBA_Attack(object):
    """Coordinate-wise black-box attack that greedily lowers a label's probability.

    The attack only uses the wrapped model through the module-level
    ``get_probs`` / ``get_preds`` helpers and the model's
    ``get_num_queries()`` counter.
    """

    ###############################################################
    def __init__(self, model):
        """Store the model under attack.

        Args:
            model: Object accepted by ``get_probs`` / ``get_preds`` that also
                exposes ``get_num_queries()``.
        """
        self.model = model

    def dct_attack_batch(self, images_batch, labels_batch, max_iters, freq_dims, stride,
                         epsilon, linf_bound, order='rand', query_limit=20000,
                         pixel_attack=False):
        """Run the coordinate search on ``images_batch`` and return the result.

        At every step one coordinate of the perturbation vector is moved by
        ``-epsilon`` and, if that does not lower the probability of
        ``labels_batch``, by ``+epsilon``. A move is kept only when the
        probability strictly decreases.

        NOTE: the stopping test ``if preds.ne(labels_batch)`` and the
        ``prob < last_prob`` comparisons take the truth value of a tensor, so
        they only work when those tensors hold a single element, i.e. one
        image and one label per call. The search stops as soon as the
        prediction differs from ``labels_batch``.

        Args:
            images_batch: Image tensor; ``size(2)`` is used as the image side
                length and perturbed images are clamped to ``[0, 1]``.
            labels_batch: Label whose probability is minimised.
            max_iters: Maximum number of coordinates to try.
            freq_dims: Side length of the perturbed square for
                ``order='rand'``; initial block size for ``order='strided'``.
            stride: Passed to ``utils.block_order`` for ``order='strided'``.
            epsilon: Step size applied to a single coordinate.
            linf_bound: Forwarded to ``utils.block_idct``; unused when
                ``pixel_attack`` is true.
            order: ``'rand'``, ``'diag'`` or ``'strided'``; any other value
                falls back to ``utils.block_order`` with its defaults.
            query_limit: Stop once ``model.get_num_queries()`` exceeds this.
            pixel_attack: If true, the perturbation is added to the image
                directly instead of going through ``utils.block_idct``.

        Returns:
            The clamped perturbed image for the final perturbation vector.
        """
        image_size = images_batch.size(2)
        with torch.no_grad():
            # Choose the coordinate visiting order and the side length of the
            # square the perturbation vector is expanded from.
            if order == 'rand':
                indices = torch.randperm(3 * freq_dims * freq_dims)[:max_iters]
                expand_dims = freq_dims
            else:
                if order == 'diag':
                    indices = utils.diagonal_order(image_size, 3)[:max_iters]
                elif order == 'strided':
                    indices = utils.block_order(
                        image_size, 3, initial_size=freq_dims, stride=stride)[:max_iters]
                else:
                    indices = utils.block_order(image_size, 3)[:max_iters]
                expand_dims = image_size
            n_dims = 3 * expand_dims * expand_dims

            # Current perturbation as a flat vector.
            x = torch.zeros(n_dims)

            def perturbed(vec):
                """Return the clamped image obtained by applying ``vec``."""
                delta = expand_vector(vec, expand_dims, image_size)
                if not pixel_attack:
                    delta = utils.block_idct(delta, block_size=image_size, linf_bound=linf_bound)
                return (images_batch + delta).clamp(0, 1)

            last_prob = get_probs(self.model, images_batch, labels_batch)
            # Never index past the available coordinates, even if the caller
            # asks for more iterations than the ordering provides.
            n_steps = min(max_iters, len(indices))
            for k in range(n_steps):
                expanded = perturbed(x)
                preds = get_preds(self.model, expanded)
                if self.model.get_num_queries() > query_limit:
                    break
                if preds.ne(labels_batch):
                    break
                dim = indices[k]
                # Try the negative step first; query the positive step only
                # if the negative one did not lower the probability.
                for sign in (-1.0, 1.0):
                    candidate = x.clone()
                    candidate[dim] += sign * epsilon
                    prob = get_probs(self.model, perturbed(candidate), labels_batch)
                    if prob < last_prob:
                        x = candidate
                        last_prob = prob
                        break
            else:
                # Loop finished without an early exit (or never ran):
                # ``expanded`` is either stale -- built before the last
                # accepted step -- or undefined, so rebuild it from the final
                # vector. This costs no model query.
                expanded = perturbed(x)
        return expanded

    def attack_untargeted(self, x_0, y_0, args, query_limit=20000):
        """Run ``dct_attack_batch`` on ``x_0`` with settings taken from ``args``.

        Args:
            x_0: Image tensor to attack; ``size(2)`` is the image side length.
            y_0: Label passed to ``dct_attack_batch`` as ``labels_batch``.
            args: Object providing ``order``, ``freq_dims``, ``num_iters``,
                ``stride``, ``epsilon``, ``linf_bound`` and ``pixel_attack``.
            query_limit: Forwarded to ``dct_attack_batch``.

        Returns:
            The perturbed image returned by ``dct_attack_batch``.
        """
        image_size = x_0.size(2)
        if args.order == 'rand':
            n_dims = 3 * args.freq_dims * args.freq_dims
        else:
            n_dims = 3 * image_size * image_size
        # A non-positive ``num_iters`` means "visit every coordinate once".
        if args.num_iters > 0:
            max_iters = int(min(n_dims, args.num_iters))
        else:
            max_iters = int(n_dims)

        adv = self.dct_attack_batch(
            x_0, y_0, max_iters, args.freq_dims, args.stride, args.epsilon,
            args.linf_bound, order=args.order, query_limit=query_limit,
            pixel_attack=args.pixel_attack)
        return adv

