import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

from utils import *


class Adversary:
    """
    Provides methods to find adversarial examples, estimate distances to the
    decision boundary and apply corruptions (fog) to data.

    Hyper-parameters passed to the constructor override the per-strategy
    defaults used by ``get_adversarial_examples``; any that are left as
    ``None`` (or are otherwise falsy) fall back to those defaults.
    """

    def __init__(self, strategy, device, eps=None, alpha=None, num_iter=None, restarts=None):
        """
        Args:
            strategy: one of 'fgsm', 'pgd', 'pgd_linf', 'pgd_linf_rand' or
                'random_walk'; selects the attack run by
                ``get_adversarial_examples``.
            device: device that inputs and labels are moved to.
            eps: perturbation bound used by the gradient-based attacks.
            alpha: step size used by the PGD attacks.
            num_iter: number of PGD iterations.
            restarts: number of random restarts for 'pgd_linf_rand'.
        """
        self.strategy = strategy
        self.eps = eps
        self.alpha = alpha
        self.num_iter = num_iter
        self.restarts = restarts
        # Mean-reduced cross-entropy shared by the gradient-based attacks.
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def get_adversarial_examples(self, model, x, y, step_size=None, max_iter=None, num_examples=None):
        """
        Run the attack selected by ``self.strategy``.

        For 'fgsm', 'pgd', 'pgd_linf' and 'pgd_linf_rand' the return value is
        the perturbation ``delta`` (same shape as ``x``), not the perturbed
        input. For 'random_walk' the perturbed inputs and their labels are
        returned, and ``step_size``, ``max_iter`` and ``num_examples`` are
        required.

        Raises:
            ValueError: if ``self.strategy`` is unknown, or if a required
                random-walk argument is missing.
        """
        if self.strategy == 'fgsm':
            return self.fgsm(model, x, y, self.eps or 0.1)

        if self.strategy == 'pgd':
            return self.pgd(model, x, y, self.eps or 0.1,
                            self.alpha or 1e4, self.num_iter or 1000)

        if self.strategy == 'pgd_linf':
            return self.pgd_linf(model, x, y, self.eps or 0.1,
                                 self.alpha or 1e-2, self.num_iter or 40)

        if self.strategy == 'pgd_linf_rand':
            return self.pgd_linf_rand(model, x, y, self.eps or 0.1,
                                      self.alpha or 1e-2, self.num_iter or 40,
                                      self.restarts or 10)

        if self.strategy == 'random_walk':
            if step_size is None or max_iter is None or num_examples is None:
                raise ValueError(
                    "random_walk requires step_size, max_iter and num_examples")
            return self._random_walk(model, x, y, step_size, max_iter, num_examples)

        raise ValueError(
            f"unknown strategy {self.strategy!r}; expected one of "
            "'fgsm', 'pgd', 'pgd_linf', 'pgd_linf_rand', 'random_walk'")

    def _random_walk(self, model, x, y, step_size, max_iter, num_examples):
        """
        Gaussian random walk on ``num_examples`` copies of every input.

        Each copy takes independent steps ``N(0, 1) * step_size`` until the
        model misclassifies it or ``max_iter`` steps have been taken. ``x`` is
        tiled with ``repeat(num_examples, 1, 1, 1)``, so it must be 4-D.

        Returns:
            (x_adv, y_adv): all copies and labels if every copy ended up
            misclassified; otherwise only the misclassified copies.
        """
        x, y = x.to(self.device), y.to(self.device)
        x = x.repeat(num_examples, 1, 1, 1)
        y = y.repeat(num_examples)
        # True while a copy is still classified correctly and keeps walking.
        active = torch.ones(len(y), dtype=torch.bool, device=self.device)
        with torch.no_grad():
            for _ in range(max_iter):
                # Fresh noise for every active copy, so the repeated copies of
                # one input explore independent directions.
                noise = torch.randn_like(x[active]) * step_size
                x[active] += noise
                _, pred = torch.max(model(x), 1)
                active[pred != y] = False
                if not active.any():
                    return x, y
        return x[~active], y[~active]

    def fgsm(self, model, x, y, eps):
        """
        Fast Gradient Sign Method: return ``eps * sign(d loss / d x)``.

        ``loss.backward()`` also accumulates into the ``.grad`` of any model
        parameters that require grad.
        """
        x, y = x.to(self.device), y.to(self.device)
        delta = torch.zeros_like(x, requires_grad=True)
        loss = self.criterion(model(x + delta), y)
        loss.backward()
        return eps * delta.grad.detach().sign()

    def pgd(self, model, x, y, eps, alpha, num_iter):
        """
        Projected gradient ascent with raw (unnormalised) gradient steps.

        Every component of the perturbation is clamped to [-eps, eps] after
        each step. The step is multiplied by the batch size, presumably to
        undo the 'mean' reduction of ``self.criterion``.

        Returns:
            The detached perturbation, same shape as ``x``.
        """
        x, y = x.to(self.device), y.to(self.device)
        delta = torch.zeros_like(x, requires_grad=True)
        for _ in range(num_iter):
            loss = self.criterion(model(x + delta), y)
            loss.backward()
            with torch.no_grad():
                delta.add_(x.shape[0] * alpha * delta.grad).clamp_(-eps, eps)
            delta.grad.zero_()
        return delta.detach()

    def pgd_linf(self, model, x, y, eps, alpha, num_iter):
        """
        L-inf PGD: steps of ``alpha * sign(grad)``, starting from zero.

        Every component of the perturbation is clamped to [-eps, eps] after
        each step.

        Returns:
            The detached perturbation, same shape as ``x``.
        """
        x, y = x.to(self.device), y.to(self.device)
        delta = torch.zeros_like(x, requires_grad=True)
        for _ in range(num_iter):
            loss = self.criterion(model(x + delta), y)
            loss.backward()
            with torch.no_grad():
                delta.add_(alpha * delta.grad.sign()).clamp_(-eps, eps)
            delta.grad.zero_()
        return delta.detach()

    def pgd_linf_rand(self, model, x, y, eps, alpha, num_iter, restarts):
        """
        L-inf PGD with ``restarts`` uniform random starts in [-eps, eps].

        Returns:
            Per example, the perturbation from the restart with the highest
            final (unreduced) cross-entropy. Ties go to the later restart.
        """
        x, y = x.to(self.device), y.to(self.device)
        per_sample_loss = nn.CrossEntropyLoss(reduction='none')
        max_loss = torch.zeros(y.shape[0], device=y.device)
        max_delta = torch.zeros_like(x)

        for _ in range(restarts):
            # Uniform random start inside the eps-box.
            delta = (torch.rand_like(x) * 2 * eps - eps).requires_grad_(True)

            for _ in range(num_iter):
                loss = self.criterion(model(x + delta), y)
                loss.backward()
                with torch.no_grad():
                    delta.add_(alpha * delta.grad.sign()).clamp_(-eps, eps)
                delta.grad.zero_()

            with torch.no_grad():
                all_loss = per_sample_loss(model(x + delta), y)
            better = all_loss >= max_loss
            max_delta[better] = delta.detach()[better]
            max_loss = torch.max(max_loss, all_loss)
        return max_delta

    # method adapted from https://github.com/google-research/mnist-c
    # according publication https://arxiv.org/abs/1906.02337
    def fog(self, x, severity=5):
        """
        Return ``x`` with a layer of plasma-fractal fog added.

        Args:
            x: image data with pixel values presumably in [0, 255]. Accepts
                anything ``np.array`` accepts, or a torch tensor on any device.
                The last two axes are assumed to be the spatial (height, width)
                axes -- TODO confirm for channel-last inputs.
            severity: integer in 1..5 selecting the fog strength.

        Returns:
            A float32 tensor on ``self.device`` with values clipped to [0, 255].

        Raises:
            ValueError: if ``severity`` is not in 1..5.
        """
        # (fog intensity, wibbledecay of the fractal) for each severity level.
        levels = [(1.5, 2), (2., 2), (2.5, 1.7), (2.5, 1.5), (3., 1.4)]
        if severity not in range(1, len(levels) + 1):
            raise ValueError(f"severity must be in 1..{len(levels)}, got {severity!r}")
        intensity, decay = levels[severity - 1]

        if torch.is_tensor(x):
            # np.array() cannot read CUDA tensors or tensors that require grad.
            x = x.detach().cpu().numpy()
        x = np.array(x) / 255.
        max_val = x.max()

        # The fractal side length is a power of two covering the image. It is
        # never smaller than plasma_fractal's default of 256, so images up to
        # 256x256 get a fractal of that default size.
        height, width = x.shape[-2:]
        mapsize = 256
        while mapsize < max(height, width):
            mapsize *= 2
        fog_layer = plasma_fractal(mapsize=mapsize, wibbledecay=decay)[:height, :width]

        x = x + intensity * fog_layer
        x = np.clip(x * max_val / (max_val + intensity), 0, 1) * 255
        return torch.tensor(x.astype(np.float32)).to(self.device)

    def get_distances(self, model, x, y, device, eps=0.1, alpha=1.0e-2, max_iter=1000):
        """
        Estimate each example's distance to the decision boundary.

        Every example repeatedly steps along its L2-normalised loss gradient
        (step length ``alpha``). The total perturbation is projected back into
        the L-inf box [-eps, eps] after each step. The first time an example
        is misclassified, the L2 norm of its perturbation is recorded. The
        caller's ``x`` is not modified. This method also sets
        ``self.device = device``.

        Returns:
            distances: float32 tensor holding the L2 norm of the perturbation
                at the first misclassification (0 if never misclassified).
            tracker: 1 where the example was misclassified within
                ``max_iter`` steps, else 0 (same dtype as ``y``).
            step_counter: 1-based iteration of that first misclassification
                (0 otherwise).
        """
        self.device = device
        x, y = x.to(device), y.to(device)
        # Work on a detached copy so the caller's tensor never has its
        # requires_grad flag flipped or a .grad attached.
        org_x = x.detach()
        x = org_x.clone()
        n = y.size(0)

        tracker = torch.zeros_like(y)
        step_counter = torch.zeros_like(y)
        distances = torch.zeros(y.size(), dtype=torch.float32, device=device)

        for i in range(max_iter):
            x.requires_grad_(True)
            model.zero_grad()
            loss = self.criterion(model(x), y)
            loss.backward()
            grad = x.grad.detach()

            # Per-example unit gradient; a zero gradient yields no step
            # instead of a NaN.
            norms = grad.reshape(n, -1).norm(dim=1)
            norms = torch.where(norms > 0, norms, torch.ones_like(norms))
            normed_grad = grad / norms.reshape((-1,) + (1,) * (grad.dim() - 1))

            with torch.no_grad():
                eta = torch.clamp(x + alpha * normed_grad - org_x, min=-eps, max=eps)
                x = org_x + eta
                # Evaluate the current predictions.
                _, pred = torch.max(model(x), 1)

            # Newly misclassified examples only.
            killed = (pred != y) & (tracker == 0)
            tracker[killed] = 1
            step_counter[killed] = i + 1
            distances[killed] = (x - org_x).reshape(n, -1).norm(dim=1)[killed]

            if (tracker == 1).all():
                break
        return distances, tracker, step_counter


# from github.com/google-research/mnist-c/corruptions.py
def plasma_fractal(mapsize=256, wibbledecay=3, rng=None):
    """
    Generate a heightmap using the diamond-square algorithm.

    Args:
        mapsize: side length of the square output; must be a power of two >= 2.
        wibbledecay: factor by which the noise amplitude ``wibble`` is divided
            after each refinement level. Larger values damp the fine-scale
            noise faster, giving smoother maps.
        rng: optional ``numpy.random.Generator`` (or ``RandomState``) used to
            draw the noise. Defaults to the global ``np.random`` state.

    Returns:
        A ``(mapsize, mapsize)`` float64 array rescaled to the range [0, 1].

    Raises:
        ValueError: if ``mapsize`` is not a power of two >= 2.
    """
    if mapsize < 2 or (mapsize & (mapsize - 1)) != 0:
        raise ValueError(f"mapsize must be a power of two >= 2, got {mapsize!r}")
    # Noise source: the global NumPy RNG unless a generator is supplied.
    uniform = np.random.uniform if rng is None else rng.uniform

    # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
    maparray = np.empty((mapsize, mapsize), dtype=np.float64)
    maparray[0, 0] = 0
    stepsize = mapsize
    wibble = 100

    def wibbledmean(array):
        """Quarter of a four-point sum plus uniform noise scaled by the current
        ``wibble`` (read from the enclosing scope at call time)."""
        return array / 4 + wibble * uniform(-wibble, wibble, array.shape)

    def fillsquares():
        """For each square of points stepsize apart,
           calculate middle value as mean of points + wibble"""
        cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
        squareaccum += np.roll(squareaccum, shift=-1, axis=1)
        maparray[stepsize // 2:mapsize:stepsize,
        stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)

    def filldiamonds():
        """For each diamond of points stepsize apart,
           calculate middle value as mean of points + wibble"""
        mapsize = maparray.shape[0]
        drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
        ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
        lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
        ltsum = ldrsum + lulsum
        maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
        tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
        tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
        ttsum = tdrsum + tulsum
        maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)

    # Refine from coarse to fine: halve the grid spacing and damp the noise
    # amplitude at every level. np.roll makes the map wrap around its edges.
    while stepsize >= 2:
        fillsquares()
        filldiamonds()
        stepsize //= 2
        wibble /= wibbledecay

    # Shift to a minimum of 0, then scale to a maximum of 1.
    maparray -= maparray.min()
    return maparray / maparray.max()

