# coding=utf-8
# Copyright 2020 The Gsa Net Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GSA-Net models."""
from typing import Dict, List

import tensorflow.compat.v1 as tf

from gsa_net.layers import attention_layers
from gsa_net.layers import conv_layers
from tensorflow.contrib import training as contrib_training


def get_hparams(hparams_string=''):
  """Builds the hyperparameters for a GSA-Net model.

  The default values are created first and then overridden by whatever is
  specified in `hparams_string`.

  The hyperparameters are:
    layer_count: `int`, the number of layers with parameters in the network.
        It is used as a key into `LAYER_COUNTS`, so it must be one of the keys
        defined there. Defaults to 50.
    class_count: `int`, number of classes for the target prediction task.
        Defaults to 1000.
    batch_norm_momentum: `float`, momentum forwarded to the batch
        normalization layers of the block groups. Defaults to 0.0.
    batch_norm_epsilon: `float`, epsilon forwarded to the batch normalization
        layers of the block groups. Defaults to 1e-5.
    depths: list of `int`, one entry per block group; the bottleneck depth of
        that group. Defaults to [64, 128, 256, 512].
    use_attention: list of `bool`, one entry per block group; whether the
        group uses attention rather than convolution in the middle layer of
        its blocks. Defaults to attention everywhere.
    head_counts: list of `int`, one entry per block group; the number of
        attention heads in that group. Defaults to 8 everywhere.

  Args:
    hparams_string: String representation of the hyperparameter overrides, in
        the form of "key_0=value_0,key_1=value_1,...,key_n=value_n". For the
        format of values, refer to `tf.contrib.training.parse_values`.

  Returns:
    The hyperparameters, with the overrides from `hparams_string` applied on
    top of the defaults.
  """
  defaults = dict(
      layer_count=50,
      class_count=1000,
      # Using exponential moving average instead of batchnorm
      batch_norm_momentum=0.0,
      batch_norm_epsilon=1e-5,
      depths=[64, 128, 256, 512],
      use_attention=[True, True, True, True],
      head_counts=[8, 8, 8, 8],
  )
  hparams = contrib_training.HParams(**defaults)
  return hparams.parse(hparams_string)


def bottleneck_block(
    inputs,
    depth,
    head_count,
    is_training,
    strides,
    uses_attention,
    use_projection=False,
    batch_norm_momentum=0.99,
    batch_norm_epsilon=1e-5,
):
  """Builds one bottleneck residual block.

  The residual branch is a 1x1 convolution, followed by either an attention
  layer or a 3x3 convolution, followed by a 1x1 convolution that expands the
  channel count to `4 * depth`. Each of the three layers is followed by batch
  normalization. The result is added to the shortcut branch and passed through
  a ReLU.

  Args:
    inputs: Input features with shape (batch_size, height, width, input_depth).
    depth: Number of output channels for the first two layers. Note that the
        final output will have 4 times as many channels.
    head_count: Number of attention heads for the attention layer. Only used
        if `uses_attention=True`.
    is_training: Whether in training or evaluation mode.
    strides: Block-level stride. It is applied to the middle layer (attention
        or 3x3 convolution) and, if `use_projection=True`, to the 1x1
        convolution on the shortcut branch. The first and last 1x1
        convolutions of the residual branch always have stride 1.
    uses_attention: Whether to use self-attention or spatial convolution for
        the middle layer.
    use_projection: Whether to use a 1x1 convolution (followed by batch
        normalization without ReLU) for projecting the shortcut branch to
        match the spatial size and depth of the output. When `False`, the
        input is used as the shortcut unchanged, so it must already be
        addable to the residual branch output.
    batch_norm_momentum: Momentum for batch normalization layers.
    batch_norm_epsilon: Epsilon for batch normalization layers.

  Returns:
    Output features with `4 * depth` channels.
  """
  bn_kwargs = dict(
      batch_norm_momentum=batch_norm_momentum,
      batch_norm_epsilon=batch_norm_epsilon,
  )
  expanded_depth = 4 * depth

  # Shortcut branch. It is built before the residual branch so that layers
  # are created in the same order regardless of how this function is written.
  if use_projection:
    shortcut = conv_layers.conv2d(
        inputs, expanded_depth, kernel_size=1, strides=strides)
    shortcut = conv_layers.batch_norm_relu(
        shortcut, is_training, relu=False, **bn_kwargs)
  else:
    shortcut = inputs

  # First layer: 1x1 convolution down to `depth` channels.
  net = conv_layers.conv2d(inputs, depth, kernel_size=1, strides=1)
  net = conv_layers.batch_norm_relu(net, is_training, **bn_kwargs)

  # Middle layer: attention or 3x3 convolution. This is where the block-level
  # stride is applied on the residual branch.
  if uses_attention:
    net = attention_layers.attention_layer(
        net,
        depth,
        head_count,
        pooling_kernel_size=3,
        strides=strides,
        is_training=is_training,
        **bn_kwargs
    )
  else:
    net = conv_layers.conv2d(net, depth, kernel_size=3, strides=strides)
  net = conv_layers.batch_norm_relu(net, is_training, **bn_kwargs)

  # Last layer: 1x1 convolution up to `4 * depth` channels. The ReLU is
  # deferred until after the residual addition.
  net = conv_layers.conv2d(net, expanded_depth, kernel_size=1, strides=1)
  net = conv_layers.batch_norm_relu(
      net, is_training, relu=False, init_zero=True, **bn_kwargs)

  return tf.nn.relu(net + shortcut)


def block_group(
    inputs,
    depth,
    head_count,
    block_count,
    strides,
    uses_attention,
    is_training,
    name='block_group',
    batch_norm_momentum=0.99,
    batch_norm_epsilon=1e-5,
):
  """Stacks bottleneck blocks that share depth and attention settings.

  The first block uses a projection shortcut and the group-level stride. All
  remaining blocks use an identity shortcut and stride 1. Block `i` is built
  inside the variable scope `<name>_<i>`.

  Args:
    inputs: Input features with shape (batch_size, height, width, input_depth).
    depth: Number of output channels for the first and second layers for each
        block. Note that the output of each block and the final output will
        have 4 times as many channels.
    head_count: Number of attention heads for the attention layers, if
        `uses_attention=True`.
    block_count: Number of blocks in the group. The first block is always
        built, even if this is less than 1.
    strides: Stride for the first block. All other blocks have stride 1.
    uses_attention: Whether to use attention or convolution in the middle layer
        of each block.
    is_training: Whether in training or evaluation mode.
    name: Name for the output tensor of the group, also used as the prefix of
        the per-block variable scopes.
    batch_norm_momentum: Momentum for batch normalization layers.
    batch_norm_epsilon: Epsilon for batch normalization layers.

  Returns:
    Output features with `4 * depth` channels.
  """
  # (strides, use_projection) for every block in creation order. Only the
  # first block per block_group uses projection shortcut and strides. The
  # list multiplication yields nothing when `block_count` is 1 or less.
  block_configs = [(strides, True)] + [(1, False)] * (block_count - 1)

  net = inputs
  for i, (block_strides, project_shortcut) in enumerate(block_configs):
    with tf.variable_scope(name + '_{}'.format(i)):
      net = bottleneck_block(
          net,
          depth,
          head_count,
          is_training,
          strides=block_strides,
          uses_attention=uses_attention,
          use_projection=project_shortcut,
          batch_norm_momentum=batch_norm_momentum,
          batch_norm_epsilon=batch_norm_epsilon,
      )
  return tf.identity(net, name)


# For GSA-Net, we focus on the Bottleneck Residual Block. Therefore, the
# shallowest models have 26 and 38 layers, rather than 18 or 34 layers as for
# typical convolutional ResNets.
#
# Maps each supported `layer_count` hyperparameter value to the number of
# bottleneck blocks in each of the four block groups. Counting three layers
# per bottleneck block (projection shortcuts excluded), plus the stem
# convolution and the final dense layer, every key equals 3 * sum(value) + 2.
LAYER_COUNTS: Dict[int, List[int]] = {
    26: [1, 2, 4, 1],
    38: [2, 3, 5, 2],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
}


class GsaNet:
  """GSA-Net image classification model.

  An instance is configured once with a set of hyperparameters and can then be
  called on a batch of images, together with a flag selecting training or
  evaluation mode, to produce classification logits. The network consists of
  a convolutional stem, four block groups of bottleneck blocks, and a global
  average pooling + dense classification head.
  """

  def __init__(self, model_hparams):
    """Initializes a GsaNet object.

    Args:
      model_hparams: Hyperparameters for the overall architecture of the model,
          including number of layers, depth of each layer, number of attention
          heads of each layer, where to use attention vs. convolution etc.

    Raises:
      KeyError: If `model_hparams.layer_count` is not a key of `LAYER_COUNTS`.
    """
    hparams = model_hparams
    self._model_hparams: contrib_training.HParams = hparams
    self._batch_norm_momentum: float = hparams.batch_norm_momentum
    self._batch_norm_epsilon: float = hparams.batch_norm_epsilon
    self._depths: List[int] = hparams.depths
    self._use_attention: List[bool] = hparams.use_attention
    self._head_counts: List[int] = hparams.head_counts
    # Number of bottleneck blocks in each of the four block groups.
    self._block_counts = LAYER_COUNTS[hparams.layer_count]

  def __call__(self, inputs, is_training=True):
    """Builds the network on `inputs` and returns the logits.

    Args:
      inputs: Input images with shape (batch_size, height, width, 3).
      is_training: Whether in training or evaluation mode.

    Returns:
      Logits with shape (batch_size, class_count).
    """
    # Stem: 7x7 stride-2 convolution, batch norm + ReLU, 3x3 stride-2 max pool.
    net = conv_layers.conv2d(inputs, depth=64, kernel_size=7, strides=2)
    net = tf.identity(net, 'initial_conv')
    # NOTE(review): unlike the block groups below, the stem does not forward
    # the batch norm momentum/epsilon from the hyperparameters, so it uses
    # the defaults of `conv_layers.batch_norm_relu` - confirm this is intended.
    net = conv_layers.batch_norm_relu(net, is_training)
    net = tf.layers.max_pooling2d(
        net, pool_size=3, strides=2, padding='SAME')
    net = tf.identity(net, 'initial_max_pool')

    # Four block groups. Every group except the first one has stride 2.
    for group_index in range(4):
      net = block_group(
          net,
          depth=self._depths[group_index],
          head_count=self._head_counts[group_index],
          block_count=self._block_counts[group_index],
          strides=2 if group_index else 1,
          uses_attention=self._use_attention[group_index],
          is_training=is_training,
          name='block_group{}'.format(group_index),
          batch_norm_momentum=self._batch_norm_momentum,
          batch_norm_epsilon=self._batch_norm_epsilon,
      )

    # Head: average pool over the full spatial extent, then flatten to
    # (batch_size, channels) and classify with a dense layer.
    height, width = net.shape[1], net.shape[2]
    net = tf.layers.average_pooling2d(
        net, (height, width), strides=1, padding='VALID')
    channel_count = net.get_shape().as_list()[-1]
    net = tf.reshape(net, [-1, channel_count])
    net = tf.identity(net, 'final_avg_pool')

    logits = tf.layers.dense(
        inputs=net,
        units=self._model_hparams.class_count,
        kernel_initializer=tf.random_normal_initializer(stddev=0.01))
    return tf.identity(logits, 'logits')


def build_model(
    images,
    is_training,
    model_hparams,
    scope='gsa_net',
):
  """Builds a GSA-Net inside a variable scope and returns its logits.

  Args:
    images: Input images tensor with shape (batch_size, height, width, 3).
    is_training: Whether in training or evaluation mode.
    model_hparams: Hyperparameters for the overall architecture of the model,
        including number of layers, depth of each layer, number of attention
        heads of each layer, where to use attention vs. convolution etc.
    scope: Variable scope for the network.

  Returns:
    The logits tensor produced by the model for `images`.
  """
  with tf.variable_scope(scope):
    return GsaNet(model_hparams)(images, is_training=is_training)
