import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from advertorch.attacks import MomentumIterativeAttack 
from advertorch.utils import clamp, batch_multiply, batch_clamp, normalize_by_pnorm

from .gradient_refine import DI, gkern


class CustomizedAttack(MomentumIterativeAttack):
  """Momentum iterative attack (MI-FGSM) with optional DI, TI and gradient averaging.

  Extends advertorch's ``MomentumIterativeAttack`` with:

  - ``inner_iter``: number of forward/backward passes whose input gradients
    are averaged before each momentum update.
  - ``di``: input diversity -- each forward pass sees
    ``DI(x + delta, diversity_prob=di_prob)`` resized back to the original
    spatial size.
  - ``ti``: translation invariance -- the averaged gradient is convolved with
    the kernel returned by ``_init_ti()`` before the momentum update.
  - ``rand_init``: start from a uniform random perturbation in
    ``[-eps, eps]`` (projected onto the valid set) instead of zero.

  ``arch`` is stored on the instance but not used by this class.
  """

  def __init__(self,
    predict, loss_fn=None, eps=0.3, nb_iter=40,
    decay_factor=1, eps_iter=0.01, clip_min=0,
    clip_max=1, targeted=False, ord=np.inf,
    rand_init=False,
    inner_iter=1,
    ti=False, di=False, di_prob=0.7,
    arch=None
  ):
    """Store the attack configuration and print a summary of it.

    :param inner_iter: number of gradient evaluations averaged per step;
      must be >= 1.
    :param ti: enable translation-invariant gradient smoothing.
    :param di: enable the input-diversity transform.
    :param di_prob: forwarded to ``DI`` as ``diversity_prob``.
    :param rand_init: start from a random perturbation instead of zero.
    :param arch: stored as ``self.arch``; unused here.
    :raises ValueError: if ``inner_iter < 1``.
    The remaining parameters are forwarded unchanged to
    ``MomentumIterativeAttack.__init__``.
    """
    super().__init__(predict, loss_fn, eps, nb_iter, decay_factor, eps_iter,
                     clip_min, clip_max, targeted, ord)
    if inner_iter < 1:
      # Each step divides the accumulated gradient by inner_iter.
      raise ValueError(f"inner_iter must be >= 1, got {inner_iter}")
    self.di, self.di_prob = di, di_prob
    self.ti = ti
    self.ti_kernel = self._init_ti() if ti else None
    self.inner_iter = inner_iter
    self.arch = arch
    self.rand_init = rand_init
    print(f"""Attacker parameters:
    eps={self.eps}, eps_iter={self.eps_iter}, nb_iter={self.nb_iter}, inner_iter={self.inner_iter}
    mi decay factor={self.decay_factor},
    ti?{self.ti}, di?{self.di}[prob={self.di_prob}]
    """)

  def _init_ti(self):
    """Build the TI smoothing kernel (TI parameter settings: size 7, sigma/nsig 3)."""
    # NOTE(review): the meaning of gkern's two arguments and the returned
    # layout are defined in .gradient_refine -- confirm there.
    return gkern(7, 3)

  def perturb(self, x, y=None):
    """
    Given examples (x, y), returns their adversarial counterparts with
    an attack length of eps.
    :param x: input tensor.
    :param y: label tensor.
              - if None and self.targeted=False, compute y as predicted
                labels.
              - if self.targeted=True, then y must be the targeted labels.
    :return: tensor containing perturbed inputs.
    :raises NotImplementedError: if ``self.ord`` is neither ``np.inf`` nor 2.
    """
    x, y = self._verify_and_process_inputs(x, y)

    delta = torch.zeros_like(x)
    if self.rand_init:
      delta.uniform_(-self.eps, self.eps)
      # Keep the random start inside the eps-ball and the valid input range.
      delta = self._project_delta(x, delta)
    delta.requires_grad_(True)
    g = torch.zeros_like(x)

    # Move the TI kernel once per call rather than once per iteration.
    ti_kernel = None
    if self.ti_kernel is not None:
      ti_kernel = self.ti_kernel.to(device=x.device, dtype=x.dtype)

    for _ in range(self.nb_iter):
      grad = self._averaged_input_grad(x, delta, y)
      if ti_kernel is not None:
        grad = self._ti_smooth(grad, ti_kernel)

      # MI-FGSM normalizes the gradient by its L1 norm (a sum); some reference
      # implementations divide by the mean |grad| instead. The two differ by a
      # constant per-sample factor, which does not change the sign (ord=inf)
      # or the L2-normalized direction (ord=2) of the momentum g.
      g = self.decay_factor * g + normalize_by_pnorm(grad, p=1)

      delta.data = self._project_delta(x, delta.data + self._step_direction(g))

    return x + delta.data

  def _averaged_input_grad(self, x, delta, y):
    """Return d(loss)/d(delta) averaged over ``self.inner_iter`` passes."""
    grad = torch.zeros_like(x)
    for _ in range(self.inner_iter):
      if delta.grad is not None:
        delta.grad.detach_()
        delta.grad.zero_()

      imgadv = x + delta
      if self.di:
        # Resize back to the full (H, W) so non-square inputs keep their shape.
        out_size = tuple(imgadv.shape[-2:])
        imgadv = DI(imgadv, diversity_prob=self.di_prob)
        imgadv = F.interpolate(imgadv, size=out_size, mode='bilinear', align_corners=False)

      loss = self.loss_fn(self.predict(imgadv), y)
      if self.targeted:
        # Targeted attacks must *decrease* the loss w.r.t. the target label.
        loss = -loss
      loss.backward()
      grad += delta.grad.detach()

    return grad / self.inner_iter

  @staticmethod
  def _ti_smooth(grad, kernel):
    """Convolve ``grad`` with the TI kernel, keeping the spatial size."""
    # 'same' padding derived from the kernel's spatial size (odd size assumed).
    padding = (kernel.shape[-2] // 2, kernel.shape[-1] // 2)
    # NOTE(review): groups=3 assumes 3-channel inputs and a depthwise kernel
    # layout from gkern -- confirm.
    return F.conv2d(grad, kernel, bias=None, stride=1, padding=padding, groups=3)

  def _step_direction(self, g):
    """Return the unprojected update added to delta for momentum ``g``."""
    if self.ord == np.inf:
      return batch_multiply(self.eps_iter, torch.sign(g))
    if self.ord == 2:
      return self.eps_iter * normalize_by_pnorm(g, p=2)
    error = "Only ord = inf and ord = 2 have been implemented"
    raise NotImplementedError(error)

  def _project_delta(self, x, delta):
    """Project ``delta`` onto the eps-ball (per ``self.ord``) and keep
    ``x + delta`` inside ``[clip_min, clip_max]``."""
    if self.ord == np.inf:
      delta = batch_clamp(self.eps, delta)
    elif self.ord == 2:
      # Per-sample rescale by min(1, eps / ||delta||_2). Unlike the
      # elementwise `eps * normalize(delta) / delta`, this never evaluates
      # 0/0 (NaN) where an element of delta is exactly zero.
      norms = delta.reshape(delta.size(0), -1).norm(p=2, dim=1).clamp_min(1e-6)
      scale = torch.clamp(self.eps / norms, max=1.)
      delta = batch_multiply(scale, delta)
    else:
      error = "Only ord = inf and ord = 2 have been implemented"
      raise NotImplementedError(error)
    return clamp(x + delta, min=self.clip_min, max=self.clip_max) - x