from attacks import *
import sys
import time
import torch
import torch.nn as nn


class SafeSpotLinf(object):
    """Optimise a per-input "safe spot" inside an L-inf box around the input.

    The safe spot is stored in an unconstrained (tanh-reparameterised)
    space. After ``transform_inv`` it always lies in
    ``[max(x - delta, 0), min(x + delta, 1)]``. Each call to ``update``:

    1. runs the configured attack starting from the current safe spot;
    2. takes one RMSprop step that lowers the cross-entropy of the resulting
       adversarial examples with respect to the true labels.

    Args:
        model: Network whose outputs are fed to ``nn.CrossEntropyLoss``
            together with the labels.
        attack_type: Name of an attack class visible in this module's
            namespace (looked up via ``getattr`` on this module, e.g. names
            brought in by ``from attacks import *``).
        epsilon, step_size, num_steps: Forwarded to the attack constructor.
            ``num_steps`` is also used for query accounting.
        random_starts: Number of attack runs averaged per update. ``0``
            still runs one attack, but with ``random_start=False``.
        delta: Half-width of the L-inf box the safe spot is confined to.
        lr: RMSprop learning rate for the safe-spot parameter.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, model,
                 attack_type, epsilon, step_size, num_steps, random_starts,
                 delta, lr, **kwargs):
        # Network
        self.model = model
        # Per-example losses; reduced explicitly in ``update``.
        self.criterion = nn.CrossEntropyLoss(reduction='none').cuda()

        # Attack
        attack_class = getattr(sys.modules[__name__], attack_type)
        self.attack = attack_class(model, epsilon, step_size, num_steps)
        self.dynamics = self.attack.dynamics
        self.num_steps = num_steps
        self.random_starts = random_starts

        # Safe spot
        self.delta = delta
        self.lr = lr

    def initialize(self, x, y):
        """Bind the input batch and labels, and reset the optimisation state.

        The safe spot starts at (a point very slightly shrunk towards the box
        centre from) ``x``. The iteration counter and query count are reset,
        and a fresh RMSprop optimiser is created over the safe-spot parameter.
        """
        self.x = x
        self.y = y
        self.safe_spot = nn.Parameter(self.transform(x, x))
        self.iter = 1
        self.num_queries = 0
        self.optimizer = torch.optim.RMSprop([self.safe_spot], lr=self.lr)

    def update(self):
        """Run the attack(s) from the current safe spot and take one step.

        The gradient of the summed per-example attacked loss with respect to
        the unconstrained safe spot is averaged over
        ``max(random_starts, 1)`` attack runs. The average is then applied
        with the RMSprop optimiser. Progress is printed.
        """
        start = time.time()
        n_runs = max(self.random_starts, 1)

        # Mean (over the batch) L-inf distance between the current safe spot
        # and the input. It is only logged, so no autograd graph is needed.
        with torch.no_grad():
            diff = torch.abs(self.transform_inv(self.x, self.safe_spot) - self.x)
            dist = torch.mean(torch.max(diff.view(self.x.shape[0], -1), dim=1)[0])

        total_grad = torch.zeros_like(self.safe_spot)
        total_loss = 0.0

        for _ in range(n_runs):
            # Fresh leaf tensor, so the gradient w.r.t. the unconstrained
            # safe spot flows back through transform_inv and the attack.
            x_s = self.safe_spot.detach().clone().requires_grad_(True)
            x = self.transform_inv(self.x, x_s)
            y = self.y.clone().detach()

            x_adv = self.attack(
                x, y,
                random_start=(self.random_starts > 0)
            )
            # Each attack run costs num_steps queries. The loop itself
            # already iterates once per random start.
            self.num_queries += self.num_steps

            # Compute loss and grad
            output = self.model(x_adv)
            loss = self.criterion(output, y)
            total_loss += torch.mean(loss).item()
            grad, = torch.autograd.grad(torch.sum(loss), x_s)
            total_grad += grad

        total_loss /= n_runs
        total_grad /= n_runs

        # Update the safe spot
        self.optimizer.zero_grad()
        self.safe_spot.grad = total_grad
        self.optimizer.step()

        end = time.time()

        # Print
        print('iter: {}, loss: {:.4f}, dist: {:.4f}, time: {:.4f}'.format(
            self.iter, total_loss, dist.item(), end - start))

        self.iter += 1

    def get_safe_spot(self):
        """Return the current safe spot mapped back into the input box."""
        return self.transform_inv(self.x, self.safe_spot.detach())

    def get_num_queries(self):
        """Return the number of model queries counted since ``initialize``."""
        return self.num_queries

    # [lower, upper] (a sub-box of [0, 1]^d) -> R^d
    def transform(self, x_orig, x):
        """Map ``x`` from the box around ``x_orig`` to unconstrained space.

        The box is ``[max(x_orig - delta, 0), min(x_orig + delta, 1)]``. The
        normalised coordinate is scaled by ``1.999`` so that box edges map
        to finite values instead of ``atanh(+-1) = inf``. As a consequence,
        ``transform_inv(transform(x))`` is ``x`` shrunk very slightly towards
        the box centre.
        """
        lower = torch.clamp(x_orig - self.delta, min=0)
        upper = torch.clamp(x_orig + self.delta, max=1)
        # A zero-width box (delta == 0) would otherwise give 0/0 = NaN.
        # For x == x_orig (as in ``initialize``) the numerator is then 0, so
        # any positive denominator yields a finite value. transform_inv
        # multiplies by the true (zero) width, so it still recovers x_orig.
        width = torch.clamp(upper - lower, min=1e-12)

        def atanh(t):
            return 0.5 * torch.log((1 + t) / (1 - t))

        return atanh(((x - lower) / width - 0.5) * 1.999)

    # R^d -> [lower, upper] (a sub-box of [0, 1]^d)
    def transform_inv(self, x_orig, x):
        """Map unconstrained ``x`` back into the box around ``x_orig``."""
        lower = torch.clamp(x_orig - self.delta, min=0)
        upper = torch.clamp(x_orig + self.delta, max=1)
        return lower + (torch.tanh(x) + 1) / 2 * (upper - lower)

