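'''Adversarial-attack wrapper (AttackPGD, built by adv_train_net) plus CIFAR
label-remapping and dataset-statistics helpers.
'''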
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import random


class AttackPGD(nn.Module):
  '''Wrap a classifier so ``forward`` can craft adversarial inputs on the fly.

  ``forward(inputs, attack=..., targets=..., **kwargs)`` perturbs ``inputs`` with
  the requested attack and returns ``(basic_net(x_adv, **kwargs), x_adv)``. With
  ``attack='none'`` it returns ``(basic_net(inputs, **kwargs), None)``. Every
  extra keyword argument is forwarded to ``basic_net``; the targeted attacks also
  read ``kwargs['attack_targets']`` and the feature attacks read
  ``kwargs['feature']``.

  Supported attacks: 'gaussian', 'pgd', 'margin', 'fgsm', 'gm', 'l2',
  'l2targeted', 'l2targetedonly', 'pgdtargeted', 'pgdtargetedonly', 'l1',
  'l1targeted', 'l1targetedonly', 'multi_target', 'feature', 'feature_norepeat'.
  '''

  def __init__(self, basic_net, config):
    super(AttackPGD, self).__init__()
    self.basic_net = basic_net
    self.rand = config['random_start']  # add a random start before iterating
    # init_* copies let set_attack() restore the constructor values.
    self.init_step_size = config['step_size']
    self.step_size = config['step_size']
    self.init_epsilon = config['epsilon']
    self.epsilon = config['epsilon']
    self.init_num_steps = config['num_steps']
    self.num_steps = config['num_steps']
    # Valid input range; every attack clamps its output to [down, up].
    self.up = config['up']
    self.down = config['down']
    # npop, alpha, runstep, init_norm and gamma are stored but not read by the
    # attacks implemented in this class.
    self.npop = config['npop']
    self.sigma = config['sigma']  # std of the 'gaussian' noise attack
    self.alpha = config['alpha']
    self.runstep = config['runstep']
    self.device = config['device']
    self.init_norm = config['init_norm']
    self.gamma = config['gamma']
    # When True, set_attack() takes budgets in 0..255 units and divides by 255.
    self.normalize = config['normalize']
    # Per-sample losses are summed rather than averaged (reduction='sum').
    self.criterion = config['criterion'](reduction='sum')
    assert config['loss_func'] == 'xent', 'Only xent supported for now.'

  def set_attack(self, epsilon=0.0, step_size=0.0, num_steps=0):
    '''
    Set parameters for attack. A zero argument restores the constructor value.
    :param epsilon: perturbation budget (divided by 255 when self.normalize)
    :param step_size: per-step size (divided by 255 when self.normalize)
    :param num_steps: number of attack iterations
    :return:
    '''
    if epsilon == 0.0:
      self.epsilon = self.init_epsilon
    else:
      self.epsilon = epsilon / 255. if self.normalize else epsilon
    if step_size == 0.0:
      self.step_size = self.init_step_size
    else:
      self.step_size = step_size / 255. if self.normalize else step_size
    if num_steps == 0:
      self.num_steps = self.init_num_steps
    else:
      self.num_steps = num_steps

  # ----------------------------------------------------------------- helpers

  @staticmethod
  def _per_sample(v, ref):
    '''Reshape a per-sample vector ``v`` of shape (B,) to broadcast against ``ref``.'''
    return v.reshape(-1, *([1] * (ref.dim() - 1)))

  def _random_start(self, x):
    '''Add Gaussian noise (std eps/4, clipped to +-eps/2) and clamp to [down, up].'''
    noise = torch.zeros_like(x).normal_(0, self.epsilon / 4)
    x = x + torch.clamp(noise, -self.epsilon / 2, self.epsilon / 2)
    return torch.clamp(x, self.down, self.up)

  def _input_grad(self, x, loss_fn, net_kwargs):
    '''Return the detached gradient of ``loss_fn(basic_net(x))`` w.r.t. ``x``.'''
    x = x.detach().requires_grad_()
    with torch.enable_grad():
      loss = loss_fn(self.basic_net(x, **net_kwargs))
    return torch.autograd.grad(loss, [x])[0].detach()

  def _clip_linf(self, x, inputs):
    '''Project ``x`` onto the L-inf ball of radius epsilon around ``inputs``, then clamp.'''
    x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon)
    return torch.clamp(x, self.down, self.up)

  def _clip_l2(self, x, inputs):
    '''Rescale each sample's perturbation to L2 norm <= epsilon, then clamp.'''
    x = (x - inputs).renorm_(p=2, dim=0, maxnorm=self.epsilon) + inputs
    return torch.clamp(x, self.down, self.up)

  def _margin_loss(self, logits, targets):
    '''Loss on the true label minus loss on the highest-scoring wrong label.'''
    top2_idx = torch.topk(logits, 2)[1]
    runner_up = torch.where(top2_idx[:, 0] == targets, top2_idx[:, 1], top2_idx[:, 0])
    return self.criterion(logits, targets) - self.criterion(logits, runner_up)

  def _linf_pgd(self, x, inputs, loss_fn, net_kwargs, ascent=True):
    '''Signed-gradient steps of ``step_size`` projected onto the L-inf ball.

    ``ascent=False`` steps against the gradient (minimizes ``loss_fn``).
    '''
    for _ in range(self.num_steps):
      step = self.step_size * torch.sign(self._input_grad(x, loss_fn, net_kwargs))
      x = (x.detach() + step) if ascent else (x.detach() - step)
      x = self._clip_linf(x, inputs)
    return x

  def _l2_pgd(self, x, inputs, loss_fn, net_kwargs):
    '''Normalized-gradient ascent projected onto the L2 ball of radius epsilon.'''
    batch_size = x.shape[0]
    for _ in range(self.num_steps):
      grad = self._input_grad(x, loss_fn, net_kwargs)
      grad_norms = grad.reshape(batch_size, -1).norm(p=2, dim=1)
      grad.div_(self._per_sample(grad_norms, grad))
      # A zero gradient divides to nan/inf; replace those rows with random noise.
      zero = grad_norms == 0
      if zero.any():
        grad[zero] = torch.randn_like(grad[zero])
      x = self._clip_l2(x.detach() + self.step_size * grad, inputs)
    return x

  def _project_l1(self, delta):
    '''Project each sample of ``delta`` onto the L1 ball of radius ``self.epsilon``.

    Samples already inside the ball are returned unchanged. The others have
    their magnitudes soft-thresholded (sort-based simplex projection of
    ``|delta|``) and are rescaled so their L1 norm equals epsilon; signs are kept.
    '''
    batch_size = delta.shape[0]
    d = delta.abs()
    u = d.reshape(batch_size, -1)
    outside = u.sum(dim=1) > self.epsilon
    if not outside.any():
      return delta
    n = u.shape[1]
    u, _ = torch.sort(u, dim=1, descending=True)
    cssv = u.cumsum(dim=1)
    ranks = torch.arange(1, n + 1, dtype=u.dtype, device=u.device)
    comp = (u * ranks > (cssv - self.epsilon)).to(u.dtype)
    # rho: index of the last sorted coordinate that stays positive after thresholding.
    rho = torch.argmax(comp.cumsum(dim=1) + (comp - 1).cumsum(dim=1), dim=1)
    c = cssv.gather(1, rho.unsqueeze(1)).squeeze(1) - self.epsilon
    theta = self._per_sample(c / (rho.to(u.dtype) + 1), delta)
    d = (d - theta).clamp(min=0)
    l1norm = self._per_sample(d.reshape(batch_size, -1).sum(dim=1).clamp(min=1e-12), delta)
    projected = d * delta.sign() * self.epsilon / l1norm
    # Only out-of-ball samples are replaced; projection is the identity inside.
    return torch.where(self._per_sample(outside, delta), projected, delta)

  def _l1_pgd(self, x, inputs, loss_fn, net_kwargs):
    '''Sparse top-k sign L1 PGD, projected onto the epsilon L1 ball each step.

    Adapted from robust_union's l1 top-k attack:
    https://github.com/locuslab/robust_union/blob/master/MNIST/mnist_funcs.py#L245
    '''
    batch_size = x.shape[0]
    alpha_init = 0.05  # base step size, also the "too close to a bound" margin
    if self.rand:
      x = x + torch.zeros_like(x).normal_(0, self.epsilon / 4)
      x = (x - inputs).renorm_(p=1, dim=0, maxnorm=self.epsilon) + inputs
      x = torch.clamp(x, self.down, self.up)
    for _ in range(self.num_steps):
      grad = self._input_grad(x, loss_fn, net_kwargs)
      # Random sparsity level; the step grows as k shrinks (alpha_init * 20 / k).
      k = random.randint(5, 20)
      alpha = alpha_init / k * 20
      # Do not push coordinates already within alpha_init of the valid range's edge.
      grad[(grad < 0) & (x <= self.down + alpha_init)] = 0
      grad[(grad > 0) & (x >= self.up - alpha_init)] = 0
      flat = grad.reshape(batch_size, -1)
      # Keep the k largest-magnitude coordinates (ties kept); k is capped at the
      # feature count so small inputs do not make topk fail.
      kth = flat.abs().topk(min(k, flat.shape[1]), dim=1)[0][:, -1:]
      flat[flat.abs() < kth] = 0
      x = x.detach() + alpha * flat.sign().view_as(x)
      x = torch.clamp(self._project_l1(x - inputs) + inputs, self.down, self.up)
    return x

  def _multi_target(self, inputs, targets, net_kwargs, num_classes=10):
    '''Run one L-inf PGD per candidate class and keep iterates that fool the net.

    For class ``r`` the loss ``sum(logit_r - logit_true)`` is maximized from the
    clean input (plus optional random start). Any iterate that is misclassified
    is recorded per sample; samples that are never fooled stay clean.
    '''
    ori_x = inputs.detach().clone()
    # Separate buffer: recording successes must not alter the restart origin,
    # nor write in place into a tensor that autograd may be tracking.
    best_x = ori_x.clone()
    true_idx = targets.reshape(-1, 1)
    for r in range(num_classes):
      x = self._random_start(ori_x) if self.rand else ori_x

      def loss_fn(logits, r=r):
        return (logits[:, r] - logits.gather(1, true_idx).squeeze(1)).sum()

      for _ in range(self.num_steps):
        grad = self._input_grad(x, loss_fn, net_kwargs)
        x = self._clip_linf(x.detach() + self.step_size * torch.sign(grad), inputs)
        with torch.no_grad():
          predicted = self.basic_net(x, **net_kwargs).max(1)[1]
        fooled = predicted != targets
        best_x[fooled] = x[fooled]
    return self._clip_linf(best_x, inputs)

  # ----------------------------------------------------------------- forward

  def forward(self, inputs, attack='none', targets=None, **kwargs):
    '''Run ``basic_net`` on clean or adversarially perturbed inputs.

    :param inputs: clean input batch; adversarial outputs are clamped to
        ``[self.down, self.up]`` and kept inside the attack's epsilon ball.
    :param attack: attack name (see the class docstring) or 'none'.
    :param targets: true labels, used by the untargeted objectives.
    :param kwargs: forwarded to ``basic_net``; may carry 'attack_targets' or 'feature'.
    :return: ``(basic_net(x, **kwargs), x)`` with ``x`` the perturbed batch, or
        ``(basic_net(inputs, **kwargs), None)`` when ``attack == 'none'``.
    :raises NotImplementedError: for an unknown attack name.
    '''
    if attack == 'none':
      return self.basic_net(inputs, **kwargs), None

    x = inputs.detach()

    def xent(logits):
      # Untargeted: increase the loss on the true labels.
      return self.criterion(logits, targets)

    def toward_target(logits):
      # Targeted: raise the true-label loss while lowering the attack-target loss.
      return self.criterion(logits, targets) - self.criterion(logits, kwargs['attack_targets'])

    def target_only(logits):
      # Loss on the attack targets alone; callers choose the step direction.
      return self.criterion(logits, kwargs['attack_targets'])

    if attack == 'gaussian':
      noise = torch.randn_like(inputs).to(self.device) * self.sigma
      x = inputs + noise
      x.clamp_(self.down, self.up)

    elif attack in ('pgd', 'margin', 'pgdtargeted', 'pgdtargetedonly'):
      loss_fn = {
          'pgd': xent,
          'margin': lambda logits: self._margin_loss(logits, targets),
          'pgdtargeted': toward_target,
          'pgdtargetedonly': target_only,
      }[attack]
      if self.rand:
        x = self._random_start(x)
      # 'pgdtargetedonly' descends on the attack-target loss.
      x = self._linf_pgd(x, inputs, loss_fn, kwargs, ascent=(attack != 'pgdtargetedonly'))

    elif attack == 'fgsm':
      if self.rand:
        x = self._random_start(x)
      grad = self._input_grad(x, xent, kwargs)
      x = self._clip_linf(x.detach() + self.epsilon * torch.sign(grad), inputs)

    elif attack == 'gm':
      # Single raw-gradient step of scale epsilon, projected onto the L2 ball.
      grad = self._input_grad(x, xent, kwargs)
      x = self._clip_l2(x.detach() + self.epsilon * grad, inputs)

    elif attack in ('l2', 'l2targeted', 'l2targetedonly'):
      loss_fn = {
          'l2': xent,
          'l2targeted': toward_target,
          'l2targetedonly': lambda logits: -target_only(logits),
      }[attack]
      x = self._l2_pgd(x, inputs, loss_fn, kwargs)

    elif attack in ('l1', 'l1targeted', 'l1targetedonly'):
      loss_fn = {
          'l1': xent,
          'l1targeted': toward_target,
          'l1targetedonly': lambda logits: -target_only(logits),
      }[attack]
      x = self._l1_pgd(x, inputs, loss_fn, kwargs)

    elif attack == 'multi_target':
      x = self._multi_target(inputs, targets, kwargs)

    elif attack in ('feature', 'feature_norepeat'):
      # Feature attack: https://openreview.net/forum?id=Syejj0NYvr&noteId=rkeBhuBMjS
      feature = kwargs['feature']
      if attack == 'feature':
        # Broadcast a single target feature map to the whole batch.
        feature = feature.unsqueeze(0).repeat(x.shape[0], *([1] * feature.dim()))

      def feature_similarity(_logits):
        # Reads the activation basic_net stores in `.feature` during its forward.
        return F.cosine_similarity(self.basic_net.feature, feature, dim=1).mean()

      if self.rand:
        x = self._random_start(x)
      x = self._linf_pgd(x, inputs, feature_similarity, kwargs)

    else:
      raise NotImplementedError('Unknown attack: {}'.format(attack))

    return self.basic_net(x, **kwargs), x



def adv_train_net(basic_net, eps=8.0, step_size=2.0, step_num=7, device='cpu', normalize=True,
                  sigma=0.25, up=1.0, down=0.0, criterion=torch.nn.CrossEntropyLoss):
  '''Build an ``AttackPGD`` wrapper around ``basic_net``.

  :param basic_net: classifier to attack / adversarially train.
  :param eps: perturbation budget; divided by 255 when ``normalize`` is True.
  :param step_size: per-step size; divided by 255 when ``normalize`` is True.
  :param step_num: number of attack iterations.
  :param device: device string stored in the config.
  :param normalize: treat ``eps``/``step_size`` as 0..255 pixel units.
  :param sigma: noise std for the 'gaussian' attack.
  :param up: upper clamp bound for adversarial inputs.
  :param down: lower clamp bound for adversarial inputs.
  :param criterion: loss class, instantiated with ``reduction='sum'``.
  :return: the configured ``AttackPGD`` (random start on, xent loss).
  '''
  if normalize:
    # Budgets given in 0..255 units are rescaled by 1/255.
    eps = eps / 255.
    step_size = step_size / 255.
  config = dict(
      epsilon=eps,
      num_steps=step_num,
      step_size=step_size,
      random_start=True,
      loss_func='xent',
      up=up,
      down=down,
      npop=300,
      sigma=sigma,
      alpha=0.02,
      runstep=500,
      device=device,
      init_norm=1.0,
      gamma=0.05,
      normalize=normalize,
      criterion=criterion,
  )
  return AttackPGD(basic_net, config)


def cifar_to_binary(labels):
  '''Collapse labels to two classes in place and return the same container.

  Labels 0, 1, 8 and 9 become 0 (the vehicle classes, assuming the standard
  CIFAR-10 label order); every other label becomes 1.
  '''
  group_zero = (0, 1, 8, 9)
  for idx, label in enumerate(labels):
    labels[idx] = 0 if label in group_zero else 1
  return labels


def cifar_to_five(labels):
  '''Remap labels 8, 9, 6, 5, 7 to 0..4 in place and return the same container.

  Any other label is left untouched, so this presumably expects data already
  filtered to those five classes.
  '''
  remap = ((8, 0), (9, 1), (6, 2), (5, 3), (7, 4))
  for idx, label in enumerate(labels):
    for src, dst in remap:
      if label == src:
        labels[idx] = dst
        break
  return labels


def get_mean_and_std(dataset, num_workers=2, batch_size=1):
  '''Compute the per-channel mean and standard deviation of a dataset.

  Statistics are pooled over every value of every sample, so ``std`` is the
  actual per-channel dataset standard deviation. Averaging per-image stds, as
  the earlier version did, underestimates it. Sums are accumulated in float64.

  :param dataset: map-style dataset yielding ``(input, label)`` pairs; inputs are
      assumed channel-first, e.g. (C, H, W) images — TODO confirm for other data.
  :param num_workers: DataLoader worker processes.
  :param batch_size: samples per batch; keep 1 if sample sizes differ.
  :return: ``(mean, std)`` float32 tensors of shape ``(C,)``.
  :raises ValueError: if the dataset yields no values.
  '''
  dataloader = torch.utils.data.DataLoader(
      dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  total = None
  total_sq = None
  count = 0
  print('==> Computing mean and std..')
  for inputs, _ in dataloader:
    # (B, C, ...) -> (C, B * ...) so each row holds one channel's values.
    values = inputs.transpose(0, 1).reshape(inputs.shape[1], -1).double()
    batch_sum = values.sum(dim=1)
    batch_sq = (values * values).sum(dim=1)
    total = batch_sum if total is None else total + batch_sum
    total_sq = batch_sq if total_sq is None else total_sq + batch_sq
    count += values.shape[1]
  if count == 0:
    raise ValueError('get_mean_and_std: dataset is empty')
  mean = total / count
  # Unbiased variance, matching the default correction of torch.Tensor.std.
  var = (total_sq - count * mean * mean) / max(count - 1, 1)
  std = var.clamp(min=0).sqrt()
  return mean.float(), std.float()
