
import tensorflow as tf


# Module-level list that PGDAttacker.attack appends tensors to while the
# graph is being built.
CLEAN_FEATURE_MAPS = []
# Registry mapping class name -> attacker class; filled by
# register_attacker_class and looked up by get_attacker.
__ATTACKER_MAPS = {}


def register_attacker_class(cls):
  """Class decorator that registers `cls` under its ``__name__``.

  Registering a second class with the same name replaces the first. The
  class itself is returned unchanged, so decorating has no other effect.
  """
  key = cls.__name__
  __ATTACKER_MAPS[key] = cls
  return cls


from nets import ACTIVATION_FEATURE_MAPS

class BaseAttacker(object):
  """Common interface for attackers.

  Concrete attackers override `attack`; this base version only signals
  that the method is missing.
  """

  def attack(self, image, label, model_func):
    """Produce adversarial inputs for a batch.

    Parameters
    ----------
    image
        Input batch to perturb.
    label
        Labels associated with `image`.
    model_func
        Callable that maps an input batch to model outputs.

    Returns
    -------
    A pair ``(image, label)`` holding the attacked inputs and the labels
    reported by the attacker.

    Raises
    ------
    NotImplementedError
        Always; subclasses must override this method.
    """
    raise NotImplementedError


@register_attacker_class
class NoOpAttacker(BaseAttacker):
  """Pass-through attacker that leaves its inputs untouched."""

  def attack(self, image, label, model_func):
    """Return `image` as-is together with a tensor of -1 labels.

    `model_func` is ignored. The returned labels have the same shape and
    dtype as `label`. NOTE(review): presumably -1 tells callers that no
    attack target exists; confirm with the consumers of this value.
    """
    minus_ones = tf.negative(tf.ones_like(label))
    return image, minus_ones


@register_attacker_class
class PGDAttacker(BaseAttacker):
  """A targeted PGD (projected gradient descent) white-box attacker.

  Each call picks a random wrong target label per example. It then runs
  `num_iter` signed-gradient steps that decrease the cross-entropy toward
  that target. After every step the result is projected back onto the
  L-infinity ball of radius `epsilon` around the input, intersected with
  [-1, 1].

  `step_size` and `epsilon` are given in units of `image_scale` and are
  converted to input units in the constructor. The default 2/255 is one
  8-bit intensity level once [0, 255] is mapped to [-1, 1].
  """

  def __init__(self, num_iter, step_size, epsilon, num_classes,
               prob_start_from_clean=0.,
               image_scale=2.0 / 255):
    if num_classes < 2:
      # _create_random_target needs at least one label that differs from
      # the true one; otherwise its random offset range [1, 1) is empty.
      raise ValueError('num_classes must be >= 2, got %r' % (num_classes,))
    self.image_scale = image_scale
    self.num_iter = num_iter
    self.step_size = step_size * self.image_scale
    self.epsilon = epsilon * self.image_scale
    self.prob_start_from_clean = prob_start_from_clean
    self.num_classes = num_classes

  def _create_random_target(self, label):
    """Return a uniformly random label different from `label`.

    Adds an integer offset drawn from [1, num_classes) and wraps modulo
    num_classes, so the result never equals the true label.
    """
    label_offset = tf.random_uniform(
      tf.shape(label), minval=1, maxval=self.num_classes, dtype=tf.int32)
    return tf.floormod(label + label_offset, tf.constant(self.num_classes, tf.int32))

  def attack(self, image, label, model_func):
    """Run the targeted PGD attack on a batch.

    Parameters
    ----------
    image : tf.Tensor
        Clean input batch, expected in [-1, 1] (the projection box is
        clipped to that range).
    label : tf.Tensor
        Integer labels. Must be combinable with int32 in `tf.floormod`.
    model_func : callable
        Maps an image tensor to logits. It is called once to build the
        attack-loop body and once more on the clean `image`.

    Returns
    -------
    adv_final : tf.Tensor
        Adversarial images with the same shape as `image`.
    target_label : tf.Tensor
        The random target labels the attack optimised toward.
    """
    target_label = self._create_random_target(label)

    # The perturbation is bounded in L-infinity norm (maximum per-pixel
    # difference), and the result must stay inside the valid [-1, 1] range.
    lower_bound = tf.clip_by_value(image - self.epsilon, -1., 1.)
    upper_bound = tf.clip_by_value(image + self.epsilon, -1., 1.)

    def one_step_attack(adv):
      logits = model_func(adv)
      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=target_label)
      g, = tf.gradients(losses, adv)
      # Targeted attack: move *against* the gradient of the target-class
      # loss, then project back onto the allowed box.
      return tf.clip_by_value(adv - tf.sign(g) * self.step_size,
                              lower_bound, upper_bound)

    start_adv = self._random_start(image, lower_bound, upper_bound)

    with tf.name_scope('attack_loop'):
      adv_final = tf.while_loop(
          lambda _: True,  # iteration count is set by maximum_iterations
          one_step_attack,
          [start_adv],
          back_prop=False,
          maximum_iterations=self.num_iter,
          parallel_iterations=1)

    self._add_image_summary(image, adv_final)
    self._collect_clean_feature_maps(image, adv_final, model_func)
    return adv_final, target_label

  def _random_start(self, image, lower_bound, upper_bound):
    """Return the PGD starting point: `image` plus uniform noise, or `image`."""
    noise = tf.random_uniform(tf.shape(image),
                              minval=-self.epsilon,
                              maxval=self.epsilon)
    # A single scalar coin flip per call: with probability
    # `prob_start_from_clean` the whole batch starts from the clean image.
    use_noise = tf.cast(
        tf.greater(tf.random_uniform(shape=[]), self.prob_start_from_clean),
        tf.float32)
    start = image + use_noise * noise
    # Noise can push pixels outside [-1, 1]. Projecting here means the first
    # forward pass (and the result when num_iter == 0) sees a valid image.
    return tf.clip_by_value(start, lower_bound, upper_bound)

  def _add_image_summary(self, image, adv_final):
    """Log clean and adversarial images, concatenated, as an image summary."""
    # Concatenated along the last axis (the width axis if the input is NCHW).
    concated = tf.concat([image, adv_final], axis=-1)
    # Map [-1, 1] back to [0, 255] for display.
    uint8_image = tf.cast((concated + 1.0) * 255.0 / 2.0, tf.uint8)
    # Move axis 1 last. This assumes NCHW input, since tf.summary.image
    # expects NHWC. TODO confirm the input layout.
    uint8_image = tf.transpose(uint8_image, [0, 2, 3, 1])
    # NOTE(review): 'PDG' looks like a typo for 'PGD'; it is kept unchanged
    # so existing TensorBoard tags stay stable.
    tf.summary.image('PDG-%d-attacker' % self.num_iter,
                     uint8_image, max_outputs=100)

  def _collect_clean_feature_maps(self, image, adv_final, model_func):
    """Append the adversarial image and the clean-pass activations.

    Appends to the module-level CLEAN_FEATURE_MAPS, in order:
    1. `adv_final` converted to uint8.
    2. The contents of every ACTIVATION_FEATURE_MAPS entry added while the
       forward pass on the clean `image` is built.
    """
    # NOTE(review): this is the adversarial image despite the list's name;
    # confirm this is intended.
    CLEAN_FEATURE_MAPS.append(tf.cast((adv_final + 1.0) * 255.0 / 2.0, tf.uint8))
    # Record the length *before* the clean forward pass so that only the
    # entries it adds are collected. The previous [-n:] slice only matched
    # them when exactly n entries were added.
    num_before = len(ACTIVATION_FEATURE_MAPS)
    # Called for its side effect. Presumably model_func appends its
    # activations to ACTIVATION_FEATURE_MAPS (from `nets`); confirm.
    model_func(image)
    for activations in ACTIVATION_FEATURE_MAPS[num_before:]:
      CLEAN_FEATURE_MAPS.extend(activations)


def get_attacker(config_, **kwargs):
  """Instantiate a registered attacker from a config mapping.

  `config_` is copied, never mutated, and `kwargs` are applied on top of
  the copy. The 'type' entry names a class registered with
  `register_attacker_class`. Every other entry is passed to that class's
  constructor as a keyword argument.

  Raises
  ------
  KeyError
      If the merged config has no 'type' entry.
  TypeError
      If 'type' does not name a registered attacker.
  """
  options = config_.copy()
  options.update(kwargs)
  attacker_type = options.pop('type')
  attacker_cls = __ATTACKER_MAPS.get(attacker_type)
  if attacker_cls is None:
    raise TypeError("Attacker type %s not found" % attacker_type)
  return attacker_cls(**options)
