import torch as ch
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from torch.nn.modules import Upsample
import argparse
import json
import pdb
import torch
class Parameters():
    '''
    Parameters class, just a nice way of accessing a dictionary
    > ps = Parameters({"a": 1, "b": 3})
    > ps.A # returns 1
    > ps.B # returns 3

    Attribute names are lower-cased before the lookup, so the dictionary
    keys are expected to be lower-case. A missing key raises AttributeError,
    so ``hasattr``, ``getattr(obj, name, default)``, ``copy`` and ``pickle``
    behave like they do for ordinary objects.
    '''

    def __init__(self, params):
        self.params = params

    def __getattr__(self, x):
        # __getattr__ only runs when normal attribute lookup fails. If
        # 'params' itself is missing (an instance created by copy/pickle
        # without calling __init__), fail fast instead of recursing forever.
        if x == 'params':
            raise AttributeError(x)
        try:
            return self.params[x.lower()]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, x)
            ) from None


def norm(t):
    '''
    Return the L2 norm of every sample in the 4-D batch ``t``.

    The result has shape (N, 1, 1, 1) so it broadcasts against ``t``.
    Samples whose norm is exactly zero get 1e-8 instead, which makes the
    result safe to divide by.
    '''
    assert len(t.shape) == 4
    per_sample = t.pow(2).sum(dim=[1, 2, 3]).sqrt().reshape(-1, 1, 1, 1)
    is_zero = (per_sample == 0).float()
    per_sample += is_zero * 1e-8
    return per_sample


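###
# Different optimization steps
# All take the form of func(x, g, lr) and return the updated x.
# eg_step: exponentiated-gradient update (maps x in [-1, 1] back into [-1, 1])
# l2/linf/gd steps: gradient steps; l2_prior_step also projects onto the unit l2 ball
###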

def eg_step(x, g, lr):
    '''
    Exponentiated-gradient step for a variable living in [-1, 1].

    ``x`` is rescaled to [0, 1]. Each coordinate is weighted by exp(+lr*g)
    against exp(-lr*g) for its complement, renormalised, and mapped back to
    [-1, 1].
    '''
    p = (x + 1) / 2  # [-1, 1] -> [0, 1]
    up = p * ch.exp(lr * g)
    down = (1 - p) * ch.exp(-lr * g)
    return up / (up + down) * 2 - 1


def linf_step(x, g, lr):
    '''Signed-gradient (linf) step: move each coordinate of x by lr * sign(g).'''
    direction = ch.sign(g)
    return x + lr * direction


def l2_prior_step(x, g, lr):
    '''
    Take an l2-normalised gradient step on ``x`` of length ``lr``, then
    rescale any sample whose l2 norm is >= 1 back onto the unit sphere.
    '''
    stepped = x + lr * g / norm(g)
    length = norm(stepped)
    inside = (length < 1.0).float()
    outside = 1 - inside
    return stepped * inside + outside * stepped / length


def gd_prior_step(x, g, lr):
    '''Plain gradient step with no normalisation or projection: x + lr * g.'''
    scaled = lr * g
    return x + scaled


def l2_image_step(x, g, lr):
    '''Move each sample of ``x`` by ``lr`` along the per-sample l2-normalised ``g``.'''
    scaled = lr * g
    return x + scaled / norm(g)


##
# Projection steps for l2 and linf constraints:
# Each is a factory func(orig_image, epsilon) that returns proj(new_x),
# which projects new_x back into the epsilon-ball around orig_image.
##

def l2_proj(image, eps):
    '''
    Build a projector onto the l2 ball of radius ``eps`` around ``image``.

    ``image`` is cloned when the projector is built, so later in-place
    changes to the caller's tensor do not move the ball's centre. The
    returned ``proj(new_x)`` keeps samples already inside the ball and
    rescales the offset of samples outside it to length ``eps``.
    '''
    centre = image.clone()

    def proj(new_x):
        offset = new_x - centre
        offset_len = norm(offset)
        outside = (offset_len > eps).float()
        projected = (centre + eps * offset / offset_len) * outside
        projected += new_x * (1 - outside)
        return projected

    return proj


def linf_proj(image, eps):
    '''
    Build a projector onto the linf ball of radius ``eps`` around ``image``.

    ``image`` is cloned when the projector is built. The returned
    ``proj(new_x)`` clips each coordinate's offset from that centre to
    [-eps, eps].
    '''
    centre = image.clone()

    def proj(new_x):
        clipped_offset = ch.clamp(new_x - centre, -eps, eps)
        return centre + clipped_offset

    return proj


##
# Main functions
##



class Bandit_Attack(object):
    '''
    Black-box bandit attack that uses a tiled prior updated online.

    Hyper-parameters come from a preset chosen in ``__init__``:
    ``dataset='cifar'`` selects the CIFAR preset, and any other value selects
    the ImageNet one. ``model`` must be callable as
    ``model(x, unnormalization=False)`` returning logits, and must provide
    ``model.predict_label(x)``. Both are used by the attack loop.
    '''

    # Hyper-parameter presets; every entry is copied onto the instance.
    _PRESETS = {
        'Bandit-TD': {
            'max_queries': 20000,    # not read by this class; budgets come in as query_limit
            'fd_eta': 0.01,          # finite-difference step size
            'image_lr': 0.5,         # step size of the image update
            'online_lr': 0.1,        # step size of the prior update
            'mode': 'l2',            # 'l2' selects l2 steps/projection; anything else linf
            'exploration': 0.01,     # scale of the exploration noise on the prior
            'tile_size': 50,         # prior side length when tiling is on
            'epsilon': 5.0 - 1e-1,   # radius of the projection ball
            'nes': False,            # True -> NES estimate in image space instead of bandit prior
            'tiling': True,          # True -> prior of side tile_size, upsampled to input_size
            'input_size': 224,
            'gradient_iters': 1,     # NES samples per iteration
            'batch_size': 1,
        },
        'Bandit-TD-Cifar': {
            'max_queries': 10000,
            'fd_eta': 0.5,
            'image_lr': 0.5,
            'online_lr': 0.1,
            'mode': 'l2',
            'exploration': 1.0,
            'tile_size': 11,
            'epsilon': 0.5,
            'nes': False,
            'tiling': True,
            'input_size': 32,
            'gradient_iters': 1,
            'batch_size': 1,
        },
    }

    def __init__(self, model, dataset='imagenet'):
        preset = 'Bandit-TD-Cifar' if dataset == 'cifar' else 'Bandit-TD'
        self.model = model
        for name, value in self._PRESETS[preset].items():
            setattr(self, name, value)

    def make_adversarial_examples(self, image, true_label, query_limit):
        '''
        The main process for generating adversarial examples with priors.

        Each iteration estimates a loss gradient from paired model queries,
        steps ``image`` along it, projects the result back into the
        ``self.epsilon`` ball around the original image and clamps it to
        [0, 1]. The loop stops once ``self.model.predict_label`` disagrees
        with ``true_label`` or the query budget is used up.

        Args:
            image: input batch. Presumably shaped
                (batch_size, 3, input_size, input_size) — TODO confirm
                with callers.
            true_label: ground-truth labels for ``image``.
            query_limit: query budget. The loop keeps going while the running
                count is <= query_limit, so the final count can overshoot it
                by up to one iteration's queries.

        Returns:
            The perturbed image from the last iteration.
        '''
        with torch.no_grad():
            prior_size = self.tile_size if self.tiling else self.input_size
            upsampler = Upsample(size=(self.input_size, self.input_size))
            total_queries = 0
            prior = torch.zeros(self.batch_size, 3, prior_size, prior_size).cuda()
            # Per-sample dimension of the bandit prior; scales its exploration noise.
            dim = prior.nelement() / self.batch_size
            if self.mode == 'l2':
                prior_step, image_step, proj_maker = gd_prior_step, l2_image_step, l2_proj
            else:
                prior_step, image_step, proj_maker = eg_step, linf_step, linf_proj
            proj_step = proj_maker(image, self.epsilon)

            criterion = ch.nn.CrossEntropyLoss(reduction='none')
            # Move the label to the GPU once rather than on every loss query.
            cuda_label = true_label.cuda()

            def L(x):
                # Per-sample cross-entropy of the model's output on x.
                logits = self.model(x.clone().cuda(), unnormalization=False)
                return criterion(logits, cuda_label)

            while not (total_queries > query_limit):
                if not self.nes:
                    # Bandit update of the prior. Use +/- exploration noise, take a
                    # 2-query finite-difference estimate of the directional
                    # derivative, and step the prior along it.
                    exp_noise = self.exploration * ch.randn_like(prior).cuda() / (dim ** 0.5)
                    q1 = upsampler(prior + exp_noise)
                    q2 = upsampler(prior - exp_noise)
                    l1 = L(image + self.fd_eta * q1 / norm(q1))  # L(prior + c*noise)
                    l2 = L(image + self.fd_eta * q2 / norm(q2))  # L(prior - c*noise)
                    est_deriv = (l1 - l2) / (self.fd_eta * self.exploration)
                    est_grad = est_deriv.view(-1, 1, 1, 1) * exp_noise
                    prior = prior_step(prior, est_grad, self.online_lr)
                    queries_used = 2
                else:
                    # NES: build a fresh gradient estimate in image space.
                    prior = ch.zeros_like(image)
                    # The noise is image-shaped, so normalize by the image's
                    # per-sample dimension, not the (possibly tiled) prior's.
                    image_dim = image[0].nelement()
                    for _ in range(self.gradient_iters):
                        exp_noise = ch.randn_like(image).cuda() / (image_dim ** 0.5)
                        est_deriv = (L(image + self.fd_eta * exp_noise)
                                     - L(image - self.fd_eta * exp_noise)) / self.fd_eta
                        prior += est_deriv.view(-1, 1, 1, 1) * exp_noise
                    queries_used = 2 * self.gradient_iters
                # Image update: step along the upsampled prior, project back
                # into the epsilon ball, keep pixel values in [0, 1].
                new_im = image_step(image, upsampler(prior), self.image_lr)
                image = ch.clamp(proj_step(new_im), 0, 1)
                # +1 for the predict_label call below (assumed to cost one query).
                total_queries += queries_used + 1
                if self.model.predict_label(image) != true_label:
                    break
            return image

    def attack_untargeted(self, x_0, y_0, query_limit=20000):
        '''
        Move ``x_0`` and ``y_0`` to CUDA and run ``make_adversarial_examples``
        with the given query budget. Returns the resulting image.
        '''
        return self.make_adversarial_examples(x_0.cuda(), y_0.cuda(), query_limit=query_limit)


