
from collections import namedtuple
from functools import partial
import tensorflow as tf
slim = tf.contrib.slim

from model_utils import resnet_group, resnet_bottleneck, \
      resnet_backbone, denoising, resnet_shortcut, \
      custom_resnet_backbone_cifar10, resnet_basic, \
      custom_resnet_backbone_v1_cifar10, \
      get_bn

from tensorpack.models import Conv2D, MaxPooling, AvgPooling, \
      GlobalAvgPooling, BatchNorm, FullyConnected, BNReLU, Dropout, \
      regularize_cost
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack import get_current_tower_context, logger, ModelDesc
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.collection import freeze_collection
from tensorpack.tfutils.tower import TowerFunc
from tensorpack.tfutils import argscope


# Number of residual blocks in each of the four stages (group0..group3),
# keyed by ResNet depth. Stored on models as ``Model.num_blocks``; the
# RobustActivation models in this file use stride 2 at group1..group3.
NUM_BLOCKS_3_DOWNSAMPLE = {
  18: [2, 2, 2, 2],
  34: [3, 4, 6, 3],
  50: [3, 4, 6, 3],
  101: [3, 4, 23, 3],
  152: [3, 8, 36, 3],
}

# Registry of model classes keyed by class name; populated by
# ``register_model_class`` and looked up by ``get_model``.
__MODEL_MAPS = {}
def register_model_class(cls):
  """Class decorator: record ``cls`` under ``cls.__name__`` and return it as-is.

  Registering a second class with the same name replaces the first entry.
  """
  key = cls.__name__
  __MODEL_MAPS[key] = cls
  return cls


# Every call of a wrapped activation appends its [input, output] pair here.
ACTIVATION_FEATURE_MAPS = []
def register_activation_with_name(name='robust_activation'):
  """Return a decorator that runs an activation function in its own scope.

  The decorated function is called as ``wrapper(x, trainable, *args, **kwargs)``:

    * For ``robust_activation_v0`` only, a non-trainable ``temperature``
      variable (initial value 1.0) is fetched or created with AUTO_REUSE in
      the root variable scope and passed as the ``temperature`` kwarg.
    * The function runs inside ``tf.variable_scope(name)`` with slim defaults
      for ``conv2d``/``fully_connected``: no activation, no bias, L2 1e-4
      regularizer and AUTO_REUSE. Conv layers additionally default to
      fan-out variance scaling init, SAME padding and NCHW.
    * If ``trainable`` is False, slim weights are created non-trainable with
      no regularizer, and tensorpack ``BatchNorm`` runs with
      ``training=False``. The weights are expected to come from a pretrained
      checkpoint.
    * ``[x, output]`` is appended to ``ACTIVATION_FEATURE_MAPS``.

  Args:
    name: variable scope name for the activation's layers.

  Returns:
    A decorator that preserves the wrapped function's metadata.
  """
  from functools import wraps

  def register_activation(func):
    # Only the dynamic-ReLU v0 activation consumes a temperature.
    needs_temperature = func.__name__ == 'robust_activation_v0'

    @wraps(func)
    def wrapper(x, trainable, *args, **kwargs):
      if needs_temperature:
        with tf.variable_scope("", reuse=tf.AUTO_REUSE):
          temp = tf.get_variable('temperature', initializer=1., trainable=False)
        kwargs['temperature'] = temp
      before = x
      with tf.variable_scope(name):
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
            weights_initializer=tf.random_normal_initializer(stddev=0.01),
            weights_regularizer=slim.l2_regularizer(1e-4),
            activation_fn=None, biases_initializer=None,
            reuse=tf.AUTO_REUSE):
          with slim.arg_scope([slim.conv2d],
              weights_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out'),
              padding='SAME', data_format='NCHW'):
            if trainable:
              after = func(x, *args, **kwargs)
            else:
              # If not trainable, should be initialized with pretrained model
              with slim.arg_scope([slim.conv2d, slim.fully_connected],
                  trainable=False, weights_regularizer=None):
                with argscope(BatchNorm, training=False):
                  after = func(x, *args, **kwargs)
      ACTIVATION_FEATURE_MAPS.append([before, after])
      return after
    return wrapper
  return register_activation


class Model(ModelDesc):
  """Base tensorpack model for image classification.

  Subclasses implement ``get_logits(image)``. It receives a float32 NCHW
  batch scaled to [-1, 1] (see ``image_preprocess`` and ``build_graph``).
  This class provides:

    * the input placeholders;
    * the cross-entropy loss, with top-1/top-5 error summaries;
    * L2 weight decay;
    * the optimizer;
    * the activation-function factory;
    * the inference tower functions.
  """
  # Regex matched against variable names by ``regularize_cost`` to select the
  # weights that receive L2 weight decay.
  weight_decay_pattern = '.*/W'

  def __init__(self, depth,
               num_classes=10,
               spatial_kernel_size=1,
               activation_add_relu=False,
               replace_locations=(1, 2, 3, 4),
               activation_version=0,
               prob_threshold=0.2,
               binary=False,
               act_no_bn=False,
               kpn=True,
               kpn_squeeze_size=1,
               init_alpha=(1, 0), init_lambda_alpha=1.0,
               init_beta=(0, 0), init_lambda_beta=0.5,
               k=2, r=4,
               label_smoothing=0.,
               loss_scale=1.0,
               robust_activation_trainable=True,
               **kwargs):
    """
    Args:
      depth: ResNet depth; must be a key of ``NUM_BLOCKS_3_DOWNSAMPLE``.
      num_classes: number of output classes.
      replace_locations: activation locations (0 = stem, 1-4 = after
        group0..group3) that use the configured activation instead of ReLU,
        for models that honour it.
      activation_version: selects the activation; see ``get_activation_func``.
      label_smoothing: label smoothing for the cross-entropy loss (0 = off).
      loss_scale: multiplier applied to the returned training cost.
      robust_activation_trainable: whether the activation's layers train.
      **kwargs: ignored; each one is logged as a warning.

      The remaining arguments configure the activation function built by
      ``get_activation_func``.

    Raises:
      KeyError: if ``depth`` is not a supported depth.
    """
    super(Model, self).__init__()

    self.num_blocks = NUM_BLOCKS_3_DOWNSAMPLE[depth]
    self.depth = depth
    self.num_classes = num_classes

    # Activation parameters. List-like arguments are copied so instances never
    # share or alias a mutable list (defaults are immutable tuples).
    self.spatial_kernel_size = spatial_kernel_size
    self.activation_add_relu = activation_add_relu
    self.replace_locations = list(replace_locations)
    self.activation_version = activation_version
    self.prob_threshold = prob_threshold
    self.binary = binary
    self.act_no_bn = act_no_bn
    self.kpn = kpn
    self.kpn_squeeze_size = kpn_squeeze_size
    self.init_alpha = list(init_alpha)
    self.init_lambda_alpha = init_lambda_alpha
    self.init_beta = list(init_beta)
    self.init_lambda_beta = init_lambda_beta
    self.k = k
    self.r = r

    self.label_smoothing = label_smoothing
    self.loss_scale = loss_scale
    self.loss_type = 'default'
    self.optimizer_type = 'momentum'
    self.optimizer_params = {
      'momentum' : 0.9,
      'use_nesterov' : True
    }
    self.weight_decay = 1e-4
    self.robust_activation_trainable = robust_activation_trainable

    for key, value in kwargs.items():
      logger.warn("Unused parameters %s = %s" % (key, str(value)))

  def set_height(self, h):
    self.image_height = h

  def set_width(self, w):
    self.image_width = w

  def set_num_channels(self, n):
    self.num_channels = n

  def set_weight_decay(self, wd):
    self.weight_decay = wd

  def set_optimizer(self, t, params):
    """Select the optimizer type and the kwargs passed to its constructor."""
    self.optimizer_type = t
    self.optimizer_params = params

  def optimizer(self):
    """Build the optimizer on a non-trainable ``learning_rate`` variable (init 0.1).

    Raises:
      ValueError: if ``optimizer_type`` is not ``'momentum'``.
    """
    lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
    tf.summary.scalar('learning_rate-summary', lr)

    if self.optimizer_type == 'momentum':
      return tf.train.MomentumOptimizer(lr, **self.optimizer_params)
    else:
      raise ValueError("Optimizer %s not supported" % (self.optimizer_type,))

  def get_activation_func(self):
    """Return the activation callable selected by ``activation_version``.

    Robust activations are returned as ``functools.partial`` objects with the
    configured hyper-parameters bound. ``'relu'`` returns ``tf.nn.relu``.

    Raises:
      ValueError: for an unknown ``activation_version``.
    """
    if self.activation_version == 0:
      act_func = partial(robust_activation_v0, k=self.k, r=self.r,
                        init_alpha=self.init_alpha,
                        init_lambda_alpha=self.init_lambda_alpha,
                        init_beta=self.init_beta,
                        init_lambda_beta=self.init_lambda_beta,
                        spatial_kernel_size=self.spatial_kernel_size,
                        trainable=self.robust_activation_trainable)
    elif self.activation_version == 3:
      act_func = partial(robust_activation_v3,
                         trainable=self.robust_activation_trainable,
                         add_relu=self.activation_add_relu,
                         prob_threshold=self.prob_threshold,
                         binary=self.binary, no_bn=self.act_no_bn,
                         kpn=self.kpn,
                         kpn_squeeze_size=self.kpn_squeeze_size)
    elif self.activation_version == 'swish':
      act_func = partial(robust_activation_swish,
                         trainable=self.robust_activation_trainable)
    elif self.activation_version == 'softplus':
      act_func = partial(robust_activation_softplus, trainable=self.robust_activation_trainable)
    elif self.activation_version == 'elu':
      act_func = partial(robust_activation_elu, trainable=self.robust_activation_trainable)
    elif self.activation_version == 'gelu':
      act_func = partial(robust_activation_gelu, trainable=self.robust_activation_trainable)
    elif self.activation_version == 'relu':
      act_func = tf.nn.relu
    else:
      # %s, not %d: activation_version may be a string.
      raise ValueError("activation_version %s not supported" % (self.activation_version,))
    return act_func

  def inputs(self):
    """Return the uint8 image placeholder and the int32 label placeholder.

    The image is NHWC for 3 channels, NHW for 1 channel.

    Raises:
      ValueError: if ``num_channels`` is neither 1 nor 3.
    """
    if self.num_channels == 3:
      return [tf.placeholder(tf.uint8,
                             [None, self.image_height, self.image_width, self.num_channels],
                             'input'),
              tf.placeholder(tf.int32, [None], 'label')]
    elif self.num_channels == 1:
      return [tf.placeholder(tf.uint8, [None, self.image_height, self.image_width], 'input'),
              tf.placeholder(tf.int32, [None], 'label')]
    raise ValueError("num_channels %s not supported" % (self.num_channels,))

  def build_graph(self, image, label):
    """Build the clean-image training graph and return the total cost.

    The cost is the cross-entropy loss plus L2 weight decay (when
    ``weight_decay > 0``), multiplied by ``loss_scale``.
    """
    image = self.image_preprocess(image)
    # NHWC -> NCHW
    image = tf.transpose(image, [0, 3, 1, 2])
    logits = self.get_logits(image)
    loss = self.get_loss(logits, label)

    if self.weight_decay > 0:
      wd_loss = regularize_cost(self.weight_decay_pattern,
                                tf.contrib.layers.l2_regularizer(self.weight_decay),
                                name='l2_regularize_loss')
      add_moving_summary(loss, wd_loss)
      total_cost = tf.add_n([loss, wd_loss], name='cost')
    else:
      total_cost = tf.identity(loss, name='cost')
      add_moving_summary(total_cost)

    if self.loss_scale != 1.:
      logger.info("Scaling the total loss by {} ...".format(self.loss_scale))
      return total_cost * self.loss_scale
    else:
      return total_cost

  def get_logits(self, image):
    """Return class logits for an NCHW float image batch (subclass hook)."""
    raise NotImplementedError()

  def get_inference_func(self):
    """Returns a tower function to be used for inference.
    """
    def tower_func(image, label):
      assert not self.training
      image = self.image_preprocess(image)
      image = tf.transpose(image, [0, 3, 1, 2])
      logits = self.get_logits(image)
      self.get_loss(logits, label)  # compute top-1 and top-5
    return TowerFunc(tower_func, self.get_input_signature())

  def set_attacker(self, attacker):
    self.attacker = attacker

  def get_inference_func_with_attacker(self, attacker):
    """Returns a tower function to be used for inference.

    It generates adversarial images with the given attacker and runs
    classification on them. Attack success is reported via
    ``compute_attack_success``.
    """
    def tower_func(image, label):
      assert not self.training
      image = self.image_preprocess(image)
      image = tf.transpose(image, [0, 3, 1, 2])
      image, target_label = attacker.attack(image, label, self.get_logits)
      logits = self.get_logits(image)
      self.get_loss(logits, label)  # compute top-1 and top-5
      self.compute_attack_success(logits, target_label)

    return TowerFunc(tower_func, self.get_input_signature())

  def compute_attack_success(self, logits, target_label):
    """Add a moving summary of the fraction of predictions equal to ``target_label``.
    """
    pred = tf.argmax(logits, axis=1, output_type=tf.int32)
    equal_target = tf.equal(pred, target_label)
    success = tf.cast(equal_target, tf.float32, name='attack_success')
    add_moving_summary(tf.reduce_mean(success, name='attack_success_rate'))

  def image_preprocess(self, image):
    """Cast to float32, map [0, 255] to [-1, 1] and add a channel axis to 3-D input."""
    # Override
    with tf.name_scope('image_preprocess'):
      if image.dtype.base_dtype != tf.float32:
        image = tf.cast(image, tf.float32)
      # For the purpose of adversarial training, normalize images to [-1, 1]
      image = image * 2.0 / 255.0 - 1.0
      if len(image.shape) == 3:
        image = tf.expand_dims(image, 3)
      return image

  def get_loss(self, logits, label):
    """Return the mean cross-entropy loss (named ``xentropy-loss``).

    Also adds moving summaries of the top-1 and top-5 error rates.

    Raises:
      TypeError: if ``loss_type`` is not ``'default'``.
    """
    if self.loss_type == 'default':
      if self.label_smoothing == 0.:
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=label)
      else:
        nclass = logits.shape[-1]
        loss = tf.losses.softmax_cross_entropy(
            tf.one_hot(label, nclass),
            logits, label_smoothing=self.label_smoothing,
            reduction=tf.losses.Reduction.NONE)
      loss = tf.reduce_mean(loss, name='xentropy-loss')
    else:
      raise TypeError("%s not supported" % self.loss_type)

    def prediction_incorrect(logits, label, topk=1,
                             name='incorrect_vector'):
      # 1.0 where the label is not among the top-k predictions, else 0.0.
      with tf.name_scope('prediction_incorrect'):
        x = tf.logical_not(tf.nn.in_top_k(logits, label, topk))
      return tf.cast(x, tf.float32, name=name)

    wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
    add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

    wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
    add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
    return loss


class AdvModel(Model):
  """Adversarial Model.

  Trains on adversarial examples produced by ``self.attacker`` (set with
  ``set_attacker``) instead of on clean images.
  """
  def build_graph(self, image, label):
    """Build the adversarial training graph.

    Returns:
      None outside training. During training, returns the total cost: the
      cross-entropy loss plus L2 weight decay (when ``weight_decay > 0``),
      multiplied by ``loss_scale``.
    """
    image = self.image_preprocess(image)
    image = tf.transpose(image, [0, 3, 1, 2])

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      # BatchNorm always comes with trouble.
      # We use the testing mode of it during attack, and do not let the
      # attack's forward passes update BatchNorm statistics.
      with freeze_collection([tf.GraphKeys.UPDATE_OPS]), argscope(BatchNorm, training=False):
        image, target_label = self.attacker.attack(image, label, self.get_logits)
        image = tf.stop_gradient(image, name='adv_training_sample')
      logits = self.get_logits(image)

    loss = self.get_loss(logits, label)
    self.compute_attack_success(logits, target_label)
    if not self.training:
      return None

    # Same handling as Model.build_graph: a zero weight decay must not build
    # a regularizer (a zero-scale l2_regularizer yields no cost tensor).
    if self.weight_decay > 0:
      wd_loss = regularize_cost(self.weight_decay_pattern,
                                tf.contrib.layers.l2_regularizer(self.weight_decay),
                                name='l2_regularize_loss')
      add_moving_summary(loss, wd_loss)
      total_cost = tf.add_n([loss, wd_loss], name='cost')
    else:
      total_cost = tf.identity(loss, name='cost')
      add_moving_summary(total_cost)

    if self.loss_scale != 1.:
      logger.info("Scaling the total loss by {} ...".format(self.loss_scale))
      return total_cost * self.loss_scale
    else:
      return total_cost


@register_model_class
class ResnetAdvModelCifar10(AdvModel):
  """ResNet for 32x32 CIFAR-10 images, trained on adversarial examples.

  Architecture follows
  https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py.
  """
  def __init__(self, **kwargs):
    super(ResnetAdvModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch."""
    # Basic blocks for shallow nets, bottlenecks from depth 50 on.
    block = resnet_bottleneck if self.depth >= 50 else resnet_basic
    return custom_resnet_backbone_cifar10(
        image, self.num_blocks, resnet_group, block,
        num_classes=self.num_classes)

@register_model_class
class ResNetDenoiseModelCifar10(Model):
  """CIFAR-10 ResNet with a denoising block after every group; standard training."""
  def __init__(self, **kwargs):
    super(ResNetDenoiseModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch."""
    def denoised_group(name, *group_args):
      # Regular residual group followed by an embedded, softmax denoiser.
      features = resnet_group(name, *group_args)
      return denoising(name + '_denoise', features, embed=True, softmax=True)

    block = resnet_bottleneck if self.depth >= 50 else resnet_basic
    return custom_resnet_backbone_cifar10(
        image, self.num_blocks, denoised_group, block,
        num_classes=self.num_classes)


@register_model_class
class ResNetDenoiseAdvModelCifar10(AdvModel):
  """CIFAR-10 ResNet with a denoising block after every group; adversarial training."""
  def __init__(self, **kwargs):
    super(ResNetDenoiseAdvModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch."""
    def denoised_group(name, *group_args):
      # Regular residual group followed by an embedded, softmax denoiser.
      features = resnet_group(name, *group_args)
      return denoising(name + '_denoise', features, embed=True, softmax=True)

    block = resnet_bottleneck if self.depth >= 50 else resnet_basic
    return custom_resnet_backbone_cifar10(
        image, self.num_blocks, denoised_group, block,
        num_classes=self.num_classes)


@register_model_class
class ResnetModelCifar10(Model):
  """ResNet for 32x32 CIFAR-10 images, standard training.

  Architecture follows
  https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py.
  """
  def __init__(self, **kwargs):
    super(ResnetModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch."""
    block = resnet_bottleneck if self.depth >= 50 else resnet_basic
    return custom_resnet_backbone_cifar10(
        image, self.num_blocks, resnet_group, block,
        num_classes=self.num_classes)


@register_model_class
class RobustActivationV7ModelCifar10(Model):
  """CIFAR-10 ResNet with configurable activations at group boundaries; standard training.

  Inside a group, all blocks but the last use the default block function. The
  last block is built with ``activation=None``. The activation after the stem
  (location 0) and after each group (locations 1-4) is the configured
  activation when its index is in ``replace_locations``, otherwise ReLU.
  """
  def __init__(self, **kwargs):
    super(RobustActivationV7ModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch."""
    act_func = self.get_activation_func()
    block_func = resnet_basic if self.depth < 50 else resnet_bottleneck
    last_block = partial(block_func, activation=None)

    def group_func(name, l, block_func, features, count, stride):
      # The last block omits its output activation; the caller applies either
      # ReLU or the configured activation to the group output.
      with tf.variable_scope(name):
        for i in range(count):
          make_block = block_func if i < count - 1 else last_block
          with tf.variable_scope('block{}'.format(i)):
            l = make_block(l, features, stride if i == 0 else 1)
      return l

    # (scope name, output channels, first-block stride) for each stage.
    stages = (('group0', 64, 1), ('group1', 128, 2),
              ('group2', 256, 2), ('group3', 512, 2))

    with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
            argscope(Conv2D, use_bias=False,
                    kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
      if 0 in self.replace_locations:
        l = Conv2D('conv0', image, 64, 3, strides=1, activation=None)
        l = BatchNorm('conv0/bn', l)
        l = act_func(l, name='conv0')
      else:
        l = Conv2D('conv0', image, 64, 3, strides=1, activation=BNReLU)

      for stage_idx, (scope, features, stride) in enumerate(stages):
        l = group_func(scope, l, block_func, features, self.num_blocks[stage_idx], stride)
        # Location stage_idx + 1 is the output of this group.
        if stage_idx + 1 in self.replace_locations:
          l = act_func(l, name=scope)
        else:
          l = tf.nn.relu(l)

      l = GlobalAvgPooling('gap', l)
      logits = FullyConnected('linear', l, self.num_classes,
                              kernel_initializer=tf.random_normal_initializer(stddev=0.01))
    return logits


@register_model_class
class RobustActivationV7AdvModelCifar10(AdvModel):
  """CIFAR-10 ResNet with configurable activations at group boundaries; adversarial training.

  Inside a group, all blocks but the last use the default block function. The
  last block is built with ``activation=None``. The activation after the stem
  (location 0) and after each group (locations 1-4) is the configured
  activation when its index is in ``replace_locations``, otherwise ReLU.
  """
  def __init__(self, **kwargs):
    super(RobustActivationV7AdvModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch."""
    act_func = self.get_activation_func()
    block_func = resnet_basic if self.depth < 50 else resnet_bottleneck
    last_block = partial(block_func, activation=None)

    def group_func(name, l, block_func, features, count, stride):
      # The last block omits its output activation; the caller applies either
      # ReLU or the configured activation to the group output.
      with tf.variable_scope(name):
        for i in range(count):
          make_block = block_func if i < count - 1 else last_block
          with tf.variable_scope('block{}'.format(i)):
            l = make_block(l, features, stride if i == 0 else 1)
      return l

    # (scope name, output channels, first-block stride) for each stage.
    stages = (('group0', 64, 1), ('group1', 128, 2),
              ('group2', 256, 2), ('group3', 512, 2))

    with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
            argscope(Conv2D, use_bias=False,
                    kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
      if 0 in self.replace_locations:
        l = Conv2D('conv0', image, 64, 3, strides=1, activation=None)
        l = BatchNorm('conv0/bn', l)
        l = act_func(l, name='conv0')
      else:
        l = Conv2D('conv0', image, 64, 3, strides=1, activation=BNReLU)

      for stage_idx, (scope, features, stride) in enumerate(stages):
        l = group_func(scope, l, block_func, features, self.num_blocks[stage_idx], stride)
        # Location stage_idx + 1 is the output of this group.
        if stage_idx + 1 in self.replace_locations:
          l = act_func(l, name=scope)
        else:
          l = tf.nn.relu(l)

      l = GlobalAvgPooling('gap', l)
      logits = FullyConnected('linear', l, self.num_classes,
                              kernel_initializer=tf.random_normal_initializer(stddev=0.01))
    return logits

@register_model_class
class RobustActivationV6ModelCifar10(Model):
  """Replace all: every activation (stem, inside and after each basic block)
  uses the configured activation function. Standard training.

  Only basic-block depths (< 50) are supported.
  """
  def __init__(self, **kwargs):
    super(RobustActivationV6ModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch.

    Raises:
      NotImplementedError: if ``self.depth >= 50`` (no bottleneck variant).
    """
    act_func = self.get_activation_func()
    if self.depth >= 50:
      raise NotImplementedError(
          "RobustActivationV6 models only support depth < 50, got %s" % (self.depth,))

    def basic_block(l, ch_out, stride, group=1, res2_bottleneck=64, activation=tf.nn.relu):
      # conv1 -> BatchNorm -> activation -> conv2 (get_bn as its activation),
      # plus shortcut, followed by ``activation``.
      ch_total = ch_out * (res2_bottleneck * group // 64)
      shortcut = l
      l = Conv2D('conv1', l, ch_total, 3, strides=stride, activation=None)
      l = BatchNorm('bn', l)
      l = activation(l)
      l = Conv2D('conv2', l, ch_total, 3, strides=1, activation=get_bn(zero_init=False),
                 split=group)
      # Project the shortcut when resolution changes or when the *input*
      # channel count differs; conv2's output always has ch_total channels.
      if stride != 1 or shortcut.get_shape().as_list()[1] != ch_total:
        shortcut = Conv2D('convshortcut', shortcut, ch_total, 1, strides=stride,
                          activation=get_bn(zero_init=False))
      return activation(l + shortcut, name='block_output')

    block_func = partial(basic_block, activation=act_func)

    def group_func(name, l, block_func, features, count, stride):
      with tf.variable_scope(name):
        for i in range(0, count):
          with tf.variable_scope('block{}'.format(i)):
            current_stride = stride if i == 0 else 1
            l = block_func(l, features, current_stride)
      return l

    with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
            argscope(Conv2D, use_bias=False,
                    kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
      l = Conv2D('conv0', image, 64, 3, strides=1, activation=None)
      l = BatchNorm('conv0/bn', l)
      l = act_func(l, name='conv0')
      l = group_func('group0', l, block_func, 64, self.num_blocks[0], 1)
      l = group_func('group1', l, block_func, 128, self.num_blocks[1], 2)
      l = group_func('group2', l, block_func, 256, self.num_blocks[2], 2)
      l = group_func('group3', l, block_func, 512, self.num_blocks[3], 2)

      l = GlobalAvgPooling('gap', l)
      logits = FullyConnected('linear', l, self.num_classes,
                              kernel_initializer=tf.random_normal_initializer(stddev=0.01))
    return logits

@register_model_class
class RobustActivationV6AdvModelCifar10(AdvModel):
  """Replace all: every activation (stem, inside and after each basic block)
  uses the configured activation function. Adversarial training.

  Only basic-block depths (< 50) are supported.
  """
  def __init__(self, **kwargs):
    super(RobustActivationV6AdvModelCifar10, self).__init__(**kwargs)

  def get_logits(self, image):
    """Return class logits for an NCHW image batch.

    Raises:
      NotImplementedError: if ``self.depth >= 50`` (no bottleneck variant).
    """
    act_func = self.get_activation_func()
    if self.depth >= 50:
      # Explicit raise instead of `assert False`, which is stripped under -O.
      raise NotImplementedError(
          "RobustActivationV6 models only support depth < 50, got %s" % (self.depth,))

    def basic_block(l, ch_out, stride, group=1, res2_bottleneck=64, activation=tf.nn.relu):
      # conv1 -> BatchNorm -> activation -> conv2 (get_bn as its activation),
      # plus shortcut, followed by ``activation``.
      ch_total = ch_out * (res2_bottleneck * group // 64)
      shortcut = l
      l = Conv2D('conv1', l, ch_total, 3, strides=stride, activation=None)
      l = BatchNorm('bn', l)
      l = activation(l)
      l = Conv2D('conv2', l, ch_total, 3, strides=1, activation=get_bn(zero_init=False),
                 split=group)
      # Project the shortcut when resolution changes or when the *input*
      # channel count differs; conv2's output always has ch_total channels.
      if stride != 1 or shortcut.get_shape().as_list()[1] != ch_total:
        shortcut = Conv2D('convshortcut', shortcut, ch_total, 1, strides=stride,
                          activation=get_bn(zero_init=False))
      return activation(l + shortcut, name='block_output')

    block_func = partial(basic_block, activation=act_func)

    def group_func(name, l, block_func, features, count, stride):
      with tf.variable_scope(name):
        for i in range(0, count):
          with tf.variable_scope('block{}'.format(i)):
            current_stride = stride if i == 0 else 1
            l = block_func(l, features, current_stride)
      return l

    with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'), \
            argscope(Conv2D, use_bias=False,
                    kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
      l = Conv2D('conv0', image, 64, 3, strides=1, activation=None)
      l = BatchNorm('conv0/bn', l)
      l = act_func(l, name='conv0')
      l = group_func('group0', l, block_func, 64, self.num_blocks[0], 1)
      l = group_func('group1', l, block_func, 128, self.num_blocks[1], 2)
      l = group_func('group2', l, block_func, 256, self.num_blocks[2], 2)
      l = group_func('group3', l, block_func, 512, self.num_blocks[3], 2)

      l = GlobalAvgPooling('gap', l)
      logits = FullyConnected('linear', l, self.num_classes,
                              kernel_initializer=tf.random_normal_initializer(stddev=0.01))
    return logits


@register_activation_with_name('dynamic_relu_v0')
def robust_activation_v0(x, temperature, k=2, r=4, name='robact_v0',
                         init_alpha=(1, 0), init_lambda_alpha=1.0,
                         init_beta=(0, 0), init_lambda_beta=0.5,
                         spatial_kernel_size=1, method=None):
  """Dynamic piecewise-linear activation: max over k of (alpha_k * x + beta_k).

  The coefficients are input-dependent, per channel and per position.

  Channel branch: global average pool -> FC(c // r, relu) -> FC(2*k*c, relu),
  then ``2 * sigmoid - 1``. It is plain ``sigmoid`` instead when
  ``k == 1 and method == 'minus'``.

  Spatial branch: a 1-channel map from ``slim.conv2d`` (int
  ``spatial_kernel_size``) or from ``kpn_v0`` (``'kpn'``). It is softmaxed over
  all H*W positions with ``temperature``, scaled by H*W/3 and clipped at 1.

  The product of the two branches gives 2k values per channel and position.
  They are split into (alpha_k, beta_k) pairs and mapped to
  ``init_alpha[k] + init_lambda_alpha * alpha_k`` and
  ``init_beta[k] + init_lambda_beta * beta_k``.

  Args:
    x: NCHW tensor; C, H and W must be static.
    temperature: softmax temperature of the spatial branch (the registering
      wrapper supplies it).
    k: number of linear pieces; ``init_alpha``/``init_beta`` need >= k entries.
    r: channel reduction ratio of the squeeze FC layer.
    name: scope prefix for the created layers.
    spatial_kernel_size: int kernel size for the spatial conv, or ``'kpn'``.
    method: only ``'minus'`` changes behaviour (when k == 1, see above).

  Returns:
    Tensor with the same shape as ``x``.

  Raises:
    ValueError: for an unsupported ``spatial_kernel_size``.
  """
  _, c, h, w = [t.value for t in x.get_shape()]
  # Channel wise: N x C -> N x (2kC)
  cw = GlobalAvgPooling('%s/avgpool' % name, x)
  cw = slim.fully_connected(cw, c // r, activation_fn=tf.nn.relu, scope='%s/squeeze' % name)
  cw = slim.fully_connected(cw, 2 * c * k, activation_fn=tf.nn.relu, scope='%s/extend' % name)

  if k == 1 and method == 'minus':
    cw = tf.sigmoid(cw)
  else:
    cw = 2 * tf.sigmoid(cw) - 1
  # N x (2kC) -> N x (2kC) x 1
  cw = tf.expand_dims(cw, axis=-1)

  # Spatial
  gamma = h * w / 3.0
  if isinstance(spatial_kernel_size, int):
    # NCHW -> N x 1 x H x W
    sp = slim.conv2d(x, 1, kernel_size=spatial_kernel_size, stride=1, scope='%s/spatial_conv11' % name)
  elif isinstance(spatial_kernel_size, str):
    if spatial_kernel_size == 'kpn':
      sp = kpn_v0(x, 1, 3, method='local', activation=tf.nn.relu, name='%s/kpn_v0' % name)
    else:
      raise ValueError("spatial_kernel_size %s not supported" % spatial_kernel_size)
  else:
    raise ValueError("spatial_kernel_size type %s not supported" % type(spatial_kernel_size))
  # N x 1 x H x W -> N x (HW)
  sp = tf.reshape(sp, (-1, h * w))
  sp = gamma * tf.nn.softmax(sp / temperature, axis=1)
  # min(w, 1)
  sp = tf.minimum(sp, 1)
  # N x 1 x (HW)
  sp = tf.expand_dims(sp, axis=1)
  # N x (2kC) x (HW) -> N x (2kC) x H x W
  params = cw * sp
  params = tf.reshape(params, (-1, 2 * k * c, h, w))

  if k > 1:
    pieces = []
    for k_idx in range(k):
      alpha_k = params[:, 2 * k_idx * c : (2 * k_idx + 1) * c, ...]
      beta_k = params[:, (2 * k_idx + 1) * c : (2 * k_idx + 2) * c, ...]
      alpha_k = init_alpha[k_idx] + init_lambda_alpha * alpha_k
      beta_k = init_beta[k_idx] + init_lambda_beta * beta_k
      pieces.append(alpha_k * x + beta_k)
    out = tf.reduce_max(tf.stack(pieces, axis=1), axis=1)
  else:
    # Single linear piece: use the first (only) init coefficients.
    alpha = init_alpha[0] + init_lambda_alpha * params[:, :c, ...]
    beta = init_beta[0] + init_lambda_beta * params[:, c:, ...]
    out = alpha * x + beta
  return out


@register_activation_with_name('dynamic_relu_v3')
def robust_activation_v3(x, name='robact_v3',
                         add_relu=True, prob_threshold=0.2,
                         binary=False, no_bn=False, kpn=True,
                         kpn_squeeze_size=1):
  """V3: gate ``x`` with a channel x spatial probability map.

  The output is ``relu(sigmoid(cw * sp) - prob_threshold) * x``, where:

    * ``cw`` (N x C x 1 x 1) comes from two 3x3 convs followed by a spatial mean;
    * ``sp`` (N x 1 x H x W) comes from a 3x3 conv followed either by
      ``kpn_v0`` or by a plain conv stack.

  When ``add_relu`` is set, ``x`` is passed through ReLU before gating.

  Args:
    x: NCHW tensor with a static channel count.
    name: scope prefix for the created layers.
    add_relu: apply ReLU to ``x`` before gating.
    prob_threshold: value subtracted from the gate before the ReLU.
    binary: accepted but not used by this function.
    no_bn: skip every BatchNorm layer (in both spatial variants).
    kpn: use ``kpn_v0`` for the spatial branch.
    kpn_squeeze_size: forwarded to ``kpn_v0`` as ``squeeze_size``.

  Returns:
    Tensor with the same shape as ``x``.
  """
  _, c, _, _ = [t.value for t in x.get_shape()]
  cw = slim.conv2d(x, c, kernel_size=3, stride=1, scope='%s/cw_squeeze_conv0' % name)
  if not no_bn:
    cw = BatchNorm('%s/cw_squeeze_conv0/bn' % name, cw)
  cw = tf.nn.relu(cw)
  cw = slim.conv2d(cw, c, kernel_size=3, stride=1, scope='%s/cw_extend_conv0' % name)
  cw = tf.reduce_mean(cw, (2, 3), keepdims=True, name='%s/avgpool' % name) # NCHW -> N x C x 1 x 1

  sp = slim.conv2d(x, c, kernel_size=3, stride=1, scope='%s/sp_squeeze_conv0' % name)
  if kpn:
    sp = kpn_v0(sp, 1, 3, method='local', activation=None, name='%s/kpn_v0' % name,
                no_bn=no_bn, squeeze_size=kpn_squeeze_size)
  else:
    # Honour no_bn here too, as the channel branch and kpn_v0 do.
    if not no_bn:
      sp = BatchNorm('%s/sp/bn1' % name, sp)
    sp = tf.nn.relu(sp)
    sp = slim.conv2d(sp, 1, kernel_size=3, stride=1, scope='%s/sp_squeeze_conv1' % name)
    if not no_bn:
      sp = BatchNorm('%s/sp/bn2' % name, sp)
    sp = tf.nn.relu(sp, name='%s/conv2' % name)
    sp = slim.conv2d(sp, 1, kernel_size=3, stride=1, scope='%s/sp_squeeze_conv2' % name)
  prob = tf.sigmoid(cw * sp)
  prob = tf.nn.relu(prob - prob_threshold)
  if add_relu:
    x = tf.nn.relu(x)
  out = prob * x
  return out

@register_activation_with_name('swish')
def robust_activation_swish(x, name='swish'):
  """Swish activation: x * sigmoid(x). ``name`` is accepted but unused here."""
  gate = tf.sigmoid(x)
  return x * gate

@register_activation_with_name('elu')
def robust_activation_elu(x, name='elu'):
  """Exponential linear unit. ``name`` is accepted but unused here."""
  activated = tf.nn.elu(x)
  return activated

@register_activation_with_name('gelu')
def robust_activation_gelu(x, name='gelu'):
  """GELU, tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."""
  return 0.5 * x * (1 + tf.nn.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))


@register_activation_with_name('softplus')
def robust_activation_softplus(x, name='softplus'):
  """Softplus activation, log(1 + exp(x)). ``name`` is accepted but unused here.

  ``Model.get_activation_func`` selects this for ``activation_version ==
  'softplus'``.
  """
  return tf.nn.softplus(x)

def kpn_v0(input, output_channels, kernel_size=3, name='kpn_v0', method='local',
           activation=tf.nn.relu, no_bn=False, squeeze_size=1):
  """Kernel-prediction layer that produces a single-channel map.

  For ``method='local'``, a ``kernel_size x kernel_size`` filter is predicted
  at every spatial position, softmax-normalised over its taps. The filter is
  applied to the matching patch of a 1-channel projection of ``input``.

  NOTE(review): the slim convs rely on the caller's arg_scope for NCHW
  (``register_activation_with_name`` provides it); confirm for other callers.

  Args:
    input: NCHW tensor; the channel count must be static. The spatial size
      must also be static when ``squeeze_size > 1``.
    output_channels: unused; the result always has one channel.
    kernel_size: size of the predicted per-position filter.
    name: scope prefix for the created layers.
    method: only ``'local'`` is implemented.
    activation: optional activation applied to the output (None to skip).
    no_bn: skip the BatchNorm layers.
    squeeze_size: if > 1, filters are average-pooled over squeeze_size x
      squeeze_size cells and upsampled back by nearest-neighbour repetition.
      Every cell then shares one filter.

  Returns:
    Tensor of shape N x 1 x H x W.

  Raises:
    ValueError: for ``method='normal'`` or any other unknown method.
  """
  if method == 'normal':
    raise ValueError("Not supported yet")
  if method != 'local':
    raise ValueError("kpn_v0 method %s not supported" % (method,))

  _, c, _, _ = [t.value for t in input.get_shape()]

  l0 = slim.conv2d(input, c, kernel_size=3, stride=1, scope='%s/spatial_conv11' % name)
  if not no_bn:
    l0 = BatchNorm('%s/spatial_conv11' % name, l0)
  l0 = tf.nn.relu(l0, name='%s/spatial_conv11' % name)
  # One output channel per filter tap; softmax over taps (axis 1 in NCHW).
  filters = slim.conv2d(l0, kernel_size * kernel_size, kernel_size=1,
                        stride=1, scope='%s/conv1' % name)
  if not no_bn:
    filters = BatchNorm('%s/conv1' % name, filters)
  filters = tf.nn.relu(filters, name='%s/conv1' % name)
  filters = tf.nn.softmax(filters, axis=1) # NCHW
  _, _, oh, ow = [t.value for t in filters.get_shape()]

  if squeeze_size > 1:
    if oh < squeeze_size:
      squeeze_size = oh
    filters = tf.nn.avg_pool(filters, (1, 1, squeeze_size, squeeze_size),
                             strides=(1, 1, squeeze_size, squeeze_size),
                             padding='SAME',
                             data_format='NCHW')
    # Nearest-neighbour upsampling: repeat each pooled cell squeeze_size times
    # along H and W. Tiling the 4-D map directly would repeat the whole pooled
    # map instead, misaligning filters with their spatial positions.
    _, taps, ph, pw = [t.value for t in filters.get_shape()]
    filters = tf.reshape(filters, [-1, taps, ph, 1, pw, 1])
    filters = tf.tile(filters, [1, 1, 1, squeeze_size, 1, squeeze_size])
    filters = tf.reshape(filters, [-1, taps, ph * squeeze_size, pw * squeeze_size])

    _, _, ch, cw = [t.value for t in filters.get_shape()]
    assert ch == oh and cw == ow, "Wrong formulation"

  input_frame = slim.conv2d(input, 1, kernel_size=3,
                            stride=1, scope='%s/conv2' % name)
  if not no_bn:
    input_frame = BatchNorm('%s/conv2' % name, input_frame)
  input_frame = tf.nn.relu(input_frame, name='%s/conv2' % name)
  # N x 1 x H x W -> N x H x W x (k*k) patches -> N x (k*k) x H x W
  input_frame = tf.transpose(input_frame, [0, 2, 3, 1])
  input_frame = tf.extract_image_patches(input_frame, [1, kernel_size, kernel_size, 1],
                                         strides=[1, 1, 1, 1], rates=[1, 1, 1, 1],
                                         padding='SAME')
  input_frame = tf.transpose(input_frame, [0, 3, 1, 2])
  # Per-position weighted sum of the patch with its predicted filter.
  output = tf.reduce_sum(filters * input_frame, 1, keepdims=True)
  if activation is not None:
    output = activation(output)
  return output


def get_model(config, **kwargs):
  """Instantiate a registered model class from a config mapping.

  Args:
    config: mapping whose ``'type'`` entry names a class registered with
      ``register_model_class``. The other entries become constructor kwargs.
      The mapping itself is not modified (a shallow copy is used).
    **kwargs: extra constructor kwargs; they override entries in ``config``.

  Raises:
    TypeError: if ``config['type']`` is not a registered model name.
  """
  params = config.copy()
  model_type = params.pop('type')
  params.update(kwargs)
  if model_type not in __MODEL_MAPS:
    raise TypeError("Model type %s not found" % model_type)
  model_cls = __MODEL_MAPS[model_type]
  return model_cls(**params)
