"""Implementation of MOAT block."""

import math
import tensorflow as tf
from moat_attention import Attention
from hparam_configs import create_config_from_dict


def drop_connect(inputs: tf.Tensor, training: bool,
                 survival_prob: float) -> tf.Tensor:
  """Applies stochastic depth [1] by zeroing the input for random samples.

  One keep/drop decision is drawn per entry of the leading (batch) dimension
  and broadcast over all remaining dimensions. Kept samples are rescaled by
  1 / survival_prob.

  [1] Deep Networks with Stochastic Depth,
      ECCV 2016.
        Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra, Kilian Q. Weinberger.

  Args:
    inputs: A tensor with shape [batch_size, height, width, channels].
    training: A boolean, whether in training mode or not.
    survival_prob: A float, 1 - drop_path_rate [1].

  Returns:
    output: A tensor with shape [batch_size, height, width, channels]. When
      not training, `inputs` is returned unchanged.
  """
  if not training:
    return inputs

  num_samples = tf.shape(inputs)[0]
  # uniform + survival_prob lies in [survival_prob, 1 + survival_prob), so its
  # floor is 1 with probability survival_prob and 0 otherwise.
  keep_mask = tf.floor(
      survival_prob + tf.random.uniform([num_samples], dtype=inputs.dtype))
  # Append singleton axes so the per-sample mask broadcasts against `inputs`.
  axes_to_add = inputs.shape.rank - 1
  while axes_to_add > 0:
    keep_mask = tf.expand_dims(keep_mask, axis=-1)
    axes_to_add -= 1
  # Instead of multiplying by survival_prob at test time, the kept samples are
  # divided by survival_prob here, so inference needs no extra compute.
  return inputs / survival_prob * keep_mask


def residual_add_with_drop_path(
    residual: tf.Tensor, shortcut: tf.Tensor,
    survival_prob: float, training: bool) -> tf.Tensor:
  """Adds the residual branch to the shortcut, with optional drop path.

  Args:
    residual: A tensor, output of the residual branch.
    shortcut: A tensor, output of the shortcut branch.
    survival_prob: A float, 1 - drop_path_rate. Drop path is applied to the
      residual only when this value is strictly between 0 and 1; None (or any
      value outside that range) disables it.
    training: A boolean, whether in training mode or not.

  Returns:
    The sum of the shortcut and the (possibly dropped) residual.
  """
  drop_path_enabled = survival_prob is not None and 0 < survival_prob < 1
  if not drop_path_enabled:
    return shortcut + residual
  return shortcut + drop_connect(residual, training, survival_prob)


class SqueezeAndExcitation(tf.keras.layers.Layer):
  """Squeeze-and-excitation layer that gates the channels of a feature map."""

  def _retrieve_config(self, config):
    """Builds the layer config from the constructor keyword arguments.

    Args:
      config: A dictionary containing the following keys.
        -se_filters: An integer, feature channels of the bottleneck
          in SqueezeAndExcitation (required).
        -output_filters: An integer, output channels (required).
        -activation: A callable applied to the output of the reduce conv.
        -kernel_initializer: Initializer for the kernel weights matrix.
        -bias_initializer: Initializer for the bias vector.
        -name: A string, layer name.

    Returns:
      The config object produced by create_config_from_dict.
    """
    defaults = {
        'activation': tf.keras.activations.swish,
        'kernel_initializer': tf.random_normal_initializer(stddev=0.02),
        'bias_initializer': tf.zeros_initializer,
        'name': 'se',
    }
    return create_config_from_dict(
        config, ['se_filters', 'output_filters'], defaults)

  def __init__(self, **config):
    """Creates the reduce / expand 1x1 convolutions from `config`."""
    self._config = self._retrieve_config(config)
    super().__init__(name=self._config.name)

    # Both 1x1 convs share every setting except filter count and name.
    shared_conv_args = dict(
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=True,
        kernel_initializer=self._config.kernel_initializer,
        bias_initializer=self._config.bias_initializer)
    self._se_reduce = tf.keras.layers.Conv2D(
        self._config.se_filters, name='reduce_conv2d', **shared_conv_args)
    self._se_expand = tf.keras.layers.Conv2D(
        self._config.output_filters, name='expand_conv2d', **shared_conv_args)
    self.activation_fn = self._config.activation

  def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Rescales `inputs` by a sigmoid gate computed from its spatial mean.

    Args:
      inputs: A rank-4 tensor; axes 1 and 2 are averaged to form the gate.

    Returns:
      A tensor, `inputs` multiplied by the broadcast gate.
    """
    # Fails early if the static shape is incompatible with rank 4.
    _ = inputs.get_shape().with_rank(4)
    pooled = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
    reduced = self.activation_fn(self._se_reduce(pooled))
    gate = tf.sigmoid(self._se_expand(reduced))
    return gate * inputs


class MBConvBlock(tf.keras.layers.Layer):
  """Pre-normalized inverted bottleneck (MBConv) block with a shortcut."""

  def _retrieve_config(self, config):
    """Builds the block config from the constructor keyword arguments.

    Args:
      config: A dictionary containing the following keys.
        -hidden_size: An integer, output channels (required).
        -kernel_size: An integer, kernel size of the depthwise conv.
        -expansion_rate: An integer, expansion rate of MBConv.
        -se_ratio: A float (0.25 by default), ratio of the
          SqueezeAndExcitation bottleneck channels to hidden_size. If None,
          no SqueezeAndExcitation is used.
        -block_stride: An integer, stride of MBConv.
        -pool_size: An integer, kernel size for pooling in shortcut branch.
          For classification, pool size 2x2 saves model flops.
          For downstream tasks, pool size 3x3 is preferred for better feature
          alignments.
        -norm_class: Normalization layer class.
        -activation: A callable applied after the expand and depthwise norms.
        -survival_prob: A float, 1 - drop_path_rate.
        -kernel_initializer: Initializer for the kernel weights matrix.
        -bias_initializer: Initializer for the bias vector.
        -name: A string, layer name.

    Returns:
      The config object produced by create_config_from_dict.
    """
    defaults = {
        'kernel_size': 3,
        'expansion_rate': 4,
        'se_ratio': 0.25,
        'block_stride': 1,
        'pool_size': 2,
        'norm_class': tf.keras.layers.experimental.SyncBatchNormalization,
        'activation': tf.keras.activations.gelu,
        'survival_prob': None,
        'kernel_initializer': tf.random_normal_initializer(stddev=0.02),
        'bias_initializer': tf.zeros_initializer,
        'name': 'mbconv',
    }
    return create_config_from_dict(config, ['hidden_size'], defaults)

  def __init__(self, **config):
    """Stores the config; sub-layers are created later in `build`."""
    self._config = self._retrieve_config(config)
    super().__init__(name=self._config.name)
    self._activation_fn = self._config.activation
    self._norm_class = self._config.norm_class

  def build(self, input_shape: list[int]) -> None:
    """Creates the sub-layers of the block.

    Args:
      input_shape: Shape of the input; its last entry is read as the number
        of input channels and decides whether the shortcut needs a 1x1
        projection.
    """
    cfg = self._config
    in_channels = input_shape[-1]
    expanded_channels = cfg.hidden_size * cfg.expansion_rate

    # The shortcut projection and the final shrink conv are both biased 1x1
    # convs to hidden_size channels; only their names differ.
    projection_args = dict(
        filters=cfg.hidden_size,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=True,
        kernel_initializer=cfg.kernel_initializer,
        bias_initializer=cfg.bias_initializer)

    needs_projection = in_channels != cfg.hidden_size
    self._shortcut_conv = (
        tf.keras.layers.Conv2D(name='shortcut_conv', **projection_args)
        if needs_projection else None)

    self._pre_norm = self._norm_class(name='pre_norm')
    self._expand_conv = tf.keras.layers.Conv2D(
        filters=expanded_channels,
        kernel_size=1,
        strides=1,
        padding='same',
        use_bias=False,
        kernel_initializer=cfg.kernel_initializer,
        name='expand_conv')
    self._expand_norm = self._norm_class(name='expand_norm')
    self._depthwise_conv = tf.keras.layers.DepthwiseConv2D(
        kernel_size=cfg.kernel_size,
        strides=cfg.block_stride,
        padding='same',
        use_bias=False,
        depthwise_initializer=cfg.kernel_initializer,
        name='depthwise_conv')
    self._depthwise_norm = self._norm_class(name='depthwise_norm')

    if cfg.se_ratio is None:
      self._se = None
    else:
      self._se = SqueezeAndExcitation(
          se_filters=max(1, int(cfg.hidden_size * cfg.se_ratio)),
          output_filters=expanded_channels,
          kernel_initializer=cfg.kernel_initializer,
          bias_initializer=cfg.bias_initializer,
          name='se')

    self._shrink_conv = tf.keras.layers.Conv2D(
        name='shrink_conv', **projection_args)

  def _shortcut_downsample(self, inputs, name):
    """Average-pools the shortcut input when the block is strided.

    Args:
      inputs: A tensor, the block input.
      name: A string, name given to the pooling layer.

    Returns:
      `inputs` unchanged when block_stride is not greater than 1, otherwise
      the pooled tensor in the dtype of `inputs`.
    """
    if self._config.block_stride <= 1:
      return inputs

    pool = tf.keras.layers.AveragePooling2D(
        pool_size=self._config.pool_size,
        strides=self._config.block_stride,
        padding='same',
        name=name,
    )
    if inputs.dtype == tf.float32:
      return pool(inputs)
    # We find that in our code base, the output dtype of pooling is float32
    # no matter whether its input and compute dtype is bfloat16 or float32.
    # So the pooling runs in float32 and its output is explicitly cast back
    # to the model compute dtype.
    pooled = pool(tf.cast(inputs, tf.float32))
    return tf.cast(pooled, inputs.dtype)

  def _shortcut_branch(self, inputs):
    """Computes the shortcut: optional pooling, then optional 1x1 conv."""
    shortcut = self._shortcut_downsample(inputs, name='shortcut_pool')
    if not self._shortcut_conv:
      return shortcut
    return self._shortcut_conv(shortcut)

  def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
    """Runs the MBConv residual branch and adds it to the shortcut.

    Args:
      inputs: A tensor, the block input.
      training: A boolean, whether in training mode or not.

    Returns:
      A tensor, shortcut plus the (possibly drop-path'ed) residual.
    """
    shortcut = self._shortcut_branch(inputs)

    expanded = self._expand_conv(self._pre_norm(inputs, training=training))
    expanded = self._activation_fn(
        self._expand_norm(expanded, training=training))
    features = self._depthwise_conv(expanded)
    features = self._activation_fn(
        self._depthwise_norm(features, training=training))
    if self._se:
      features = self._se(features)
    residual = self._shrink_conv(features)

    return residual_add_with_drop_path(
        residual, shortcut, self._config.survival_prob, training)


class MOATBlock(tf.keras.layers.Layer):
  """Implementation of MOAT block.

  The block first runs an MBConv-style stage (pre-norm, 1x1 expand conv,
  depthwise conv, 1x1 shrink conv) with a residual shortcut, and then a
  layer-norm + windowed attention stage with its own residual shortcut.
  """

  def _retrieve_config(self, config):
    """Retrieves the config of MOATBlock.

    Args:
      config: A dictionary containing the following keys.
        -hidden_size: An integer, output channels (required).
        -kernel_size: An integer, kernel size of MBConv
        -expansion_rate: An integer, expansion rate of MBConv.
        -block_stride: An integer, stride of MBConv.
        -pool_size: An integer, kernel size for pooling in shortcut branch.
          For classification, pool size 2x2 saves model flops.
          For downstream tasks, pool size 3x3 is preferred for better feature
          alignments.
        -norm_class: Normalization layer class in MBConv.
        -activation: A callable applied after the expand and depthwise norms
          in MBConv.
        -head_size: An integer, feature channels per head in Attention.
        -window_size: A list (or tuple) of two integers, spatial size of input
          for window attention, [height, width]. If None, the window covers
          the whole (strided) feature map seen at build time, i.e. global
          attention. Validated in `build`.
        -relative_position_embedding_type: A string, type of relative position
          embedding in Attention. If None, no relative position embedding will
          be used.
        -position_embedding_size: An integer, specifying the position embedding
          size. If the feature map have larger size, the position embedding
          will be interpolated to match it.
        -ln_epsilon: A float, epsilon for layer normalization in Attention.
        -survival_prob: A float, 1 - drop_path_rate.
        -kernel_initializer: Initializer for the kernel weights matrix.
        -bias_initializer: Initializer for the bias vector.
        -use_checkpointing_for_attention: A boolean, specifying whether to use
          checkpointing for attention.
        -name: A string, layer name.

    Returns:
      The config object produced by create_config_from_dict.
    """

    required_keys = ['hidden_size']
    optional_keys = {
        'kernel_size': 3,
        'expansion_rate': 4,
        'block_stride': 2,
        'pool_size': 2,
        'norm_class': tf.keras.layers.experimental.SyncBatchNormalization,
        'activation': tf.keras.activations.gelu,
        'head_size': 32,
        'window_size': None,
        'relative_position_embedding_type': '2d_multi_head',
        'position_embedding_size': 7,
        'ln_epsilon': 1e-5,
        'survival_prob': None,
        'kernel_initializer': tf.random_normal_initializer(stddev=0.02),
        'bias_initializer': tf.zeros_initializer,
        'use_checkpointing_for_attention': False,
        'name': 'moat',
    }
    config = create_config_from_dict(config, required_keys, optional_keys)
    return config

  def __init__(self, **config):
    """Stores the config; sub-layers are created later in `build`."""
    self._config = self._retrieve_config(config)
    super().__init__(name=self._config.name)
    self._activation_fn = self._config.activation
    self._norm_class = self._config.norm_class

  def build(self, input_shape: list[int]) -> None:
    """Resolves the attention window size and creates the sub-layers.

    Args:
      input_shape: Shape of the input; its last three entries are read as
        (height, width, channels).

    Raises:
      ValueError: If window_size is specified but is not a list/tuple of
        length two, or if a relative position embedding type is set while
        position_embedding_size is None.
    """
    height, width, input_size = input_shape[-3:]
    inner_size = self._config.hidden_size * self._config.expansion_rate

    window_size = self._config.window_size
    if window_size:
      # The length check must bind to len() alone; comparing the result of
      # `isinstance(...) and len(...)` against 2 only works by accident.
      if isinstance(window_size, (list, tuple)) and len(window_size) == 2:
        self._window_height = window_size[0]
        self._window_width = window_size[1]
      else:
        # Adjacent literals are concatenated into one message string.
        raise ValueError(
            'The window size should be a list of two ints '
            '[height, width], if specified.')
    else:
      # No window given: use the feature map size after striding, so a single
      # window spans the whole map.
      self._window_height = math.ceil(float(height) / self._config.block_stride)
      self._window_width = math.ceil(float(width) / self._config.block_stride)

    # A 1x1 projection is only needed when the channel count changes.
    self._shortcut_conv = None
    if input_size != self._config.hidden_size:
      self._shortcut_conv = tf.keras.layers.Conv2D(
          filters=self._config.hidden_size,
          kernel_size=1,
          strides=1,
          padding='same',
          kernel_initializer=self._config.kernel_initializer,
          bias_initializer=self._config.bias_initializer,
          use_bias=True,
          name='shortcut_conv')

    self._pre_norm = self._norm_class(name='pre_norm')
    self._expand_conv = tf.keras.layers.Conv2D(
        filters=inner_size,
        kernel_size=1,
        strides=1,
        kernel_initializer=self._config.kernel_initializer,
        padding='same',
        use_bias=False,
        name='expand_conv')
    self._expand_norm = self._norm_class(name='expand_norm')
    self._depthwise_conv = tf.keras.layers.DepthwiseConv2D(
        kernel_size=self._config.kernel_size,
        strides=self._config.block_stride,
        depthwise_initializer=self._config.kernel_initializer,
        padding='same',
        use_bias=False,
        name='depthwise_conv')
    self._depthwise_norm = self._norm_class(name='depthwise_norm')
    self._shrink_conv = tf.keras.layers.Conv2D(
        filters=self._config.hidden_size,
        kernel_size=1,
        strides=1,
        padding='same',
        kernel_initializer=self._config.kernel_initializer,
        bias_initializer=self._config.bias_initializer,
        use_bias=True,
        name='shrink_conv')

    self._attention_norm = tf.keras.layers.LayerNormalization(
        axis=-1,
        epsilon=self._config.ln_epsilon,
        name='attention_norm')

    # Ratio between the attention window and the position embedding size,
    # handed to Attention as `scale_ratio`.
    scale_ratio = None
    if self._config.relative_position_embedding_type:
      if self._config.position_embedding_size is None:
        raise ValueError(
            'The position embedding size need to be specified ' +
            'if relative position embedding is used.')
      scale_ratio = [
          self._window_height / self._config.position_embedding_size,
          self._window_width / self._config.position_embedding_size,
      ]

    self._attention = Attention(
        hidden_size=self._config.hidden_size,
        head_size=self._config.head_size,
        relative_position_embedding_type=(
            self._config.relative_position_embedding_type),
        scale_ratio=scale_ratio,
        kernel_initializer=self._config.kernel_initializer,
        bias_initializer=self._config.bias_initializer)

  def _make_windows(self, inputs):
    """Splits a rank-4 feature map into non-overlapping attention windows.

    Args:
      inputs: A rank-4 tensor [batch, height, width, channels] with static
        height, width and channels.

    Returns:
      A tensor of shape [-1, window_height, window_width, channels], with the
      windows folded into the leading dimension.
    """
    # NOTE(review): assumes height and width are multiples of the window size;
    # this is not checked here.
    _, height, width, channels = inputs.get_shape().with_rank(4).as_list()
    inputs = tf.reshape(
        inputs,
        (-1,
         height // self._window_height, self._window_height,
         width // self._window_width, self._window_width,
         channels))
    # Bring the two window-grid axes next to each other, ahead of the
    # within-window axes.
    inputs = tf.transpose(inputs, (0, 1, 3, 2, 4, 5))
    inputs = tf.reshape(
        inputs,
        (-1, self._window_height, self._window_width, channels))
    return inputs

  def _remove_windows(self, inputs, height, width):
    """Reassembles windowed features into a [-1, height, width, C] map.

    Args:
      inputs: A rank-3 tensor whose last axis is channels; it is reshaped as
        windows of window_height x window_width positions.
      height: An integer, height of the feature map to restore.
      width: An integer, width of the feature map to restore.

    Returns:
      A tensor of shape [-1, height, width, channels].
    """
    _, _, channels = inputs.get_shape().with_rank(3).as_list()
    inputs = tf.reshape(inputs, [
        -1, height // self._window_height, width // self._window_width,
        self._window_height, self._window_width, channels
    ])
    # Interleave grid and within-window axes again (inverse of the transpose
    # in _make_windows).
    inputs = tf.transpose(inputs, (0, 1, 3, 2, 4, 5))
    inputs = tf.reshape(inputs, (-1, height, width, channels))
    return inputs

  def _shortcut_downsample(self, inputs, name):
    """Average-pools the shortcut input when the block is strided.

    Args:
      inputs: A tensor, the block input.
      name: A string, name given to the pooling layer.

    Returns:
      `inputs` unchanged when block_stride is not greater than 1, otherwise
      the pooled tensor in the dtype of `inputs`.
    """
    output = inputs
    if self._config.block_stride > 1:
      pooling_layer = tf.keras.layers.AveragePooling2D(
          pool_size=self._config.pool_size,
          strides=self._config.block_stride,
          padding='same',
          name=name,
      )
      if output.dtype == tf.float32:
        output = pooling_layer(output)
      else:
        # We find that in our code base, the output dtype of pooling is float32
        # no matter whether its input and compute dtype is bfloat16 or
        # float32. So we explicitly cast the output dtype of pooling to be the
        # model compute dtype.
        output = tf.cast(pooling_layer(
            tf.cast(output, tf.float32)), output.dtype)
    return output

  def _shortcut_branch(self, inputs):
    """Computes the MBConv shortcut: optional pooling, then optional conv."""
    shortcut = self._shortcut_downsample(inputs, name='shortcut_pool')
    if self._shortcut_conv is not None:
      shortcut = self._shortcut_conv(shortcut)
    return shortcut

  def call(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
    """Runs the MBConv stage followed by the windowed attention stage.

    Args:
      inputs: A tensor, the block input.
      training: A boolean, whether in training mode or not.

    Returns:
      A tensor, the output of the attention stage added to its shortcut.
    """
    mbconv_shortcut = self._shortcut_branch(inputs)
    output = self._pre_norm(inputs, training=training)
    output = self._expand_conv(output)
    output = self._expand_norm(output, training=training)
    output = self._activation_fn(output)
    output = self._depthwise_conv(output)
    output = self._depthwise_norm(output, training=training)
    output = self._activation_fn(output)
    output = self._shrink_conv(output)
    output = residual_add_with_drop_path(
        output, mbconv_shortcut,
        self._config.survival_prob, training)

    # For classification, the window size is the same as feature map size.
    # For downstream tasks, the window size can be set the same as
    # classification's.
    attention_shortcut = output

    def _func(output):
      """Layer norm -> window partition -> attention -> window merge."""
      output = self._attention_norm(output)
      _, height, width, _ = output.get_shape().with_rank(4).as_list()
      output = self._make_windows(output)
      output = self._attention(output)
      output = self._remove_windows(output, height, width)
      return output

    func = _func
    if self._config.use_checkpointing_for_attention:
      # Recompute the attention activations in the backward pass instead of
      # storing them.
      func = tf.recompute_grad(_func)

    output = func(output)
    output = residual_add_with_drop_path(
        output, attention_shortcut,
        self._config.survival_prob, training)
    return output
