"""
Created on Thu Sep 26 01:34:01 2019

@author: AliRah
"""


import torch.nn as nn

from .generate_2d_dct_basis_lib import *

import numpy as np
import torch
import os

import math


# Number of noisy samples sent to the model per forward pass while estimating
# the boundary normal (passed as `batch_size` to `black_grad_batch`).
grad_estimator_batch_size = 40     # batch size for GeoDA


# NOTE(review): not read anywhere in the visible part of this file.
verbose_control = 'Yes'


# Target device for tensors created by this module; falls back to the CPU
# when CUDA is not available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Binary-search stopping threshold: `bin_search` stops once the norm of
# (adversarial end - clean end) drops below this value.
tol = 0.0001
# Scale applied to the sub-space noise added to the boundary point in
# `black_grad_batch`.
sigma = 0.0002
# Passed as `eta` to `opt_query_iteration` when building the per-iteration
# query schedule.
mu = 0.6


# Selects how `go_to_boundary` turns the estimated gradient into a step
# direction ('l1'/'l2': use it as is, 'linf': normalised sign).
#dist = 'linf'
dist = 'l2'
# 'sub' makes `GeoDA_Attack.__init__` load (or generate) the 2-D DCT basis
# used for sub-space noise.
search_space = 'sub'


###############################################################
# Functions
###############################################################







class SubNoise(nn.Module):
    """Generate Gaussian noise images that lie in the span of a fixed basis.

    Args:
        num_noises: number of noise images produced per call.
        x: 2-D basis tensor. Its second dimension is the number of basis
            vectors; its first dimension must equal
            ``image_size * image_size`` so the result can be viewed as
            images.
        image_size: height/width of the (square) output images.
    """

    # x is the subspace basis
    def __init__(self, num_noises, x, image_size):
        # Initialise nn.Module before touching attributes so the module's
        # internal registries exist for every assignment below.
        super(SubNoise, self).__init__()
        self.num_noises = num_noises
        self.x = x
        self.image_size = image_size

    def forward(self):
        """Return a ``(num_noises, 3, image_size, image_size)`` noise tensor.

        The random coefficients are drawn on the CPU (so the default CPU
        generator is consumed) and then moved to the device of the basis.
        This works on CUDA and on CPU-only hosts alike.
        """
        coeffs = torch.randn([self.x.shape[1], 3 * self.num_noises], dtype=torch.float32)
        coeffs = coeffs.to(self.x.device)
        # (pixels, 3 * num_noises) -> (3 * num_noises, pixels)
        sub_noise = torch.transpose(torch.mm(self.x, coeffs), 0, 1)
        return sub_noise.view([self.num_noises, 3, self.image_size, self.image_size])


def opt_query_iteration(Nq, T, eta):
    """Distribute a query budget over iterations with geometric weights.

    Iteration ``i`` receives a share proportional to ``eta ** (-2 * i / 3)``.
    The number of iterations is then adjusted until the first iteration's
    share lies in ``[50, 80]``: one more iteration while it is above 80, one
    fewer while it is below 50.

    Args:
        Nq: total number of queries to distribute.
        T: initial number of iterations (values below 1 are treated as 1).
        eta: base of the geometric weighting.

    Returns:
        ``(opt_q, T)``: the list of per-iteration query counts and the
        number of iterations finally used (``len(opt_q) == T``).
    """
    # At least one iteration is needed to have a "first share" at all.
    T = max(T, 1)
    visited = set()

    while True:
        coefs = [eta ** (-2 * i / 3) for i in range(0, T)]
        sum_coefs = sum(coefs)
        opt_q = [round(Nq * coef / sum_coefs) for coef in coefs]
        visited.add(T)

        if opt_q[0] > 80:
            next_T = T + 1
        elif opt_q[0] < 50 and T > 1:
            next_T = T - 1
        else:
            # Either inside the target window, or the budget is too small to
            # reach it even with a single iteration.
            return opt_q, T

        # Steps are +/-1, so revisiting a value means T and its neighbour
        # straddle the window; stop instead of bouncing forever.
        if next_T in visited:
            return opt_q, T
        T = next_T


def uni_query(Nq, T, eta):
    """Split a query budget evenly across ``T`` iterations.

    Args:
        Nq: total number of queries.
        T: number of iterations.
        eta: accepted but not used by this function.

    Returns:
        A list of length ``T`` whose entries are all ``round(Nq / T)``;
        an empty list when ``T`` is 0.
    """
    schedule = []
    for _ in range(T):
        schedule.append(round(Nq / T))
    return schedule





class GeoDA_Attack(object):
    """Label-only (decision-based) adversarial attack driver.

    The wrapped ``model`` only needs a ``predict_label(batch)`` method that
    returns the predicted label(s) as a tensor. Images are clamped to
    ``[0, 1]`` wherever this class perturbs them.

    Args:
        model: object exposing ``predict_label``.
        dataset: ``'cifar'`` selects 32x32 images (step 1.0, sub-space
            dimension 11); any other value selects 224x224 images
            (step 5.0, sub-space dimension 75).
    """

    def __init__(self, model, dataset='imagenet'):
        self.model = model
        if dataset == 'cifar':
            self.image_size = 32
            self.epsilon = 1.0
            self.sub_dim = 11
        else:
            self.image_size = 224
            self.epsilon = 5.
            self.sub_dim = 75

        ###############################################################
        if search_space == 'sub':
            print('Check if DCT basis available ...')

            # The basis is cached next to this file, keyed by sub_dim.
            path = os.path.join(os.path.dirname(__file__), '2d_dct_basis_{}.npy'.format(self.sub_dim))
            if os.path.exists(path):
                print('Yes, we already have it ...')
                sub_basis = np.load(path).astype(np.float32)
            else:
                print('Generating dct basis ......')
                sub_basis = generate_2d_dct_basis(self.sub_dim, self.image_size, path).astype(np.float32)
                print('Done!\n')

            # Use the module-wide device so CPU-only hosts work too.
            self.sub_basis_torch = torch.from_numpy(sub_basis).to(device)

    ###############################################################
    def GeoDA(self, x_0, x_b, iteration, q_opt):
        """Run the iterative boundary-normal estimation loop.

        Args:
            x_0: original image batch (first image is attacked).
            x_b: starting point on the decision boundary.
            iteration: number of iterations to run.
            q_opt: per-iteration query budgets for the normal estimate;
                must have at least ``iteration`` entries.

        Returns:
            ``(x_adv, q_num, grad)``: the final point clamped to ``[0, 1]``,
            the number of queries counted, and the accumulated gradient
            estimate (``0`` when no iteration ran).
        """
        q_num = 0
        grad = 0
        # With zero iterations the starting boundary point is the result.
        x_adv = x_b

        for i in range(iteration):
            # `black_grad_batch` draws its own sub-space noise and ignores
            # its `random_noises` argument, so no full-resolution noise
            # tensor is allocated here.
            grad_oi, _ = self.black_grad_batch(x_b, q_opt[i], sigma, None, grad_estimator_batch_size,
                                               self.orig_label)
            q_num = q_num + q_opt[i]
            # Estimates from all iterations are accumulated.
            grad = grad_oi + grad

            x_adv, qs, _ = self.go_to_boundary(x_0, grad, x_b)
            q_num = q_num + qs

            x_adv, bin_query = self.bin_search(x_0, x_adv, tol)
            q_num = q_num + bin_query

            x_b = x_adv

        x_adv = x_adv.clamp(0., 1.)

        return x_adv, q_num, grad

    def is_adversarial(self, given_image, orig_label):
        """Return whether the model's label for ``given_image`` differs from ``orig_label``.

        ``predict_label`` must return a single-element tensor here because
        the result is read with ``.item()``.
        """
        predict_label = self.model.predict_label(given_image).item()
        return predict_label != orig_label

    def find_random_adversarial(self, image, epsilon=1000):
        """Add growing Gaussian noise until the label changes.

        Args:
            image: clean image; expected to already live on ``device``.
            epsilon: accepted but not used.

        Returns:
            ``(perturbed, num_calls)``; ``num_calls`` starts at 1 and is
            incremented once per noise draw. If ``image`` is already
            classified differently from ``self.orig_label`` it is returned
            unchanged.
        """
        num_calls = 1
        step = 0.02
        perturbed = image

        while not self.is_adversarial(perturbed, self.orig_label):
            pert = torch.randn_like(image)
            pert = pert.to(device)

            # Noise magnitude grows linearly with the number of attempts.
            perturbed = image + num_calls * step * pert
            perturbed = perturbed.clamp(0., 1.)
            perturbed = perturbed.to(device)
            num_calls += 1

        return perturbed, num_calls

    ###############################################################

    def bin_search(self, x_0, x_random, tol):
        """Bisect the segment between a clean and an adversarial point.

        Args:
            x_0: clean end of the segment.
            x_random: adversarial end of the segment.
            tol: stop once the norm of (adversarial end - clean end) is
                below this value.

        Returns:
            ``(adv, num_calls)``: the adversarial end of the final interval
            and the number of model queries used.
        """
        num_calls = 0
        adv = x_random
        cln = x_0

        while True:
            mid = (cln + adv) / 2.0
            num_calls += 1

            if self.is_adversarial(mid, self.orig_label):
                adv = mid
            else:
                cln = mid

            if torch.norm(adv - cln).item() < tol:
                break

        return adv, num_calls

    ###############################################################

    def go_to_boundary(self, x_0, grad, x_b):
        """Step from ``x_0`` along the estimated direction until the label flips.

        Args:
            x_0: original image batch.
            grad: gradient estimate with a leading batch dimension; only
                ``grad[0]`` is used.
            x_b: accepted but not used.

        Returns:
            ``(perturbed, num_calls, epsilon * num_calls)``. The search
            gives up (after printing a message) once ``num_calls`` exceeds
            100, in which case ``perturbed`` may not be adversarial.

        Raises:
            ValueError: if the module-level ``dist`` is not one of
                ``'l1'``, ``'l2'`` or ``'linf'``.
        """
        epsilon = self.epsilon

        if dist == 'l1' or dist == 'l2':
            grads = grad
        elif dist == 'linf':
            grads = torch.sign(grad) / torch.norm(grad)
        else:
            raise ValueError('unsupported dist: {!r}'.format(dist))

        num_calls = 1
        perturbed = x_0

        while not self.is_adversarial(perturbed, self.orig_label):
            perturbed = x_0 + (num_calls * epsilon * grads[0])
            perturbed = perturbed.clamp(0., 1.)
            num_calls += 1

            if num_calls > 100:
                print('falied ... ')
                break
        return perturbed, num_calls, epsilon * num_calls

    def black_grad_batch(self, x_boundary, q_max, sigma, random_noises, batch_size, original_label):
        """Estimate the boundary normal from label queries around ``x_boundary``.

        ``q_max`` sub-space noise images are drawn in batches, scaled by
        ``sigma`` and added to ``x_boundary[0]``. A noise image counts with
        sign ``+1`` when the perturbed point keeps ``original_label`` and
        ``-1`` otherwise; the estimate is ``-(1 / q_max)`` times the signed
        sum of the noise images.

        The signed sum is accumulated batch by batch on ``device``, so only
        one batch of noise is held in memory at a time.

        Args:
            x_boundary: 4-D tensor; only its first image is used.
            q_max: total number of queries (must be positive).
            sigma: noise scale.
            random_noises: accepted but not used.
            batch_size: number of queries per model call.
            original_label: label of the clean image.

        Returns:
            ``(grad_f, sign_sum)``: the estimate with a leading batch
            dimension of 1, and the sum of the ``+1``/``-1`` signs.

        Raises:
            ValueError: if ``q_max`` is not positive.
        """
        if q_max <= 0:
            raise ValueError('q_max must be positive, got {}'.format(q_max))

        num_batchs = math.ceil(q_max / batch_size)
        # The final batch carries whatever is left of the budget.
        last_batch = q_max - (num_batchs - 1) * batch_size

        full_gen = SubNoise(batch_size, self.sub_basis_torch, self.image_size)
        last_gen = SubNoise(last_batch, self.sub_basis_torch, self.image_size)

        base = x_boundary[0, :, :, :].detach().to(device)
        signed_sum = None
        sign_sum = 0

        for j in range(num_batchs):
            generator = last_gen if j == num_batchs - 1 else full_gen
            noise = generator().to(device)

            # Broadcasting adds the same boundary image to every noise sample.
            noisy_boundary = base + sigma * noise
            labels = self.model.predict_label(noisy_boundary).cpu().numpy().astype(int)
            labels = np.atleast_1d(labels)

            signs = [1 if label == original_label else -1 for label in labels]
            sign_sum += sum(signs)

            weights = torch.tensor(signs, dtype=noise.dtype, device=noise.device).view(-1, 1, 1, 1)
            batch_sum = (weights * noise).sum(dim=0)
            signed_sum = batch_sum if signed_sum is None else signed_sum + batch_sum

        grad = -(1 / q_max) * signed_sum
        grad_f = grad[None, :, :, :]

        return grad_f, sign_sum

    ###############################################################

    def attack_untargeted(self, x_0, y_0, query_limit=20000):
        """Run the full untargeted attack on ``x_0``.

        Args:
            x_0: clean image batch (first image is attacked).
            y_0: label of the clean image; stored as ``self.orig_label``.
            query_limit: nominal query budget used to size the schedule.

        Returns:
            The adversarial image produced by ``GeoDA``.
        """
        self.orig_label = y_0

        # Initial adversarial point, then projection onto the boundary.
        x_random, _ = self.find_random_adversarial(x_0, epsilon=100)
        x_b, _ = self.bin_search(x_0, x_random, tol)

        Q_max = query_limit + 500

        # Start from roughly one iteration per 500 queries, reserving 25
        # queries per iteration, and let the scheduler settle on the final
        # iteration count; then redo the split with the settled count.
        iteration = round(Q_max / 500)
        q_opt_it = int(Q_max - (iteration) * 25)
        q_opt_iter, iterate = opt_query_iteration(q_opt_it, iteration, mu)
        q_opt_it = int(Q_max - (iterate) * 25)
        q_opt_iter, iterate = opt_query_iteration(q_opt_it, iteration, mu)

        x_adv, _, _ = self.GeoDA(x_0, x_b, iterate, q_opt_iter)
        return x_adv

###############################################################

