# coding=utf-8
# Copyright 2020 The Gsa Net Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convolutional layers."""
import tensorflow.compat.v1 as tf


def batch_norm_relu(
    inputs,
    is_training,
    relu=True,
    init_zero=False,
    batch_norm_momentum=0.99,
    batch_norm_epsilon=1e-5,
):
  """Applies batch normalization, optionally followed by a ReLU.

  Normalization uses tf.layers.batch_normalization with its default axis
  (the last one), i.e. the depth axis of the inputs described below.

  Args:
    inputs: Tensor with shape (batch_size, ..., depth).
    is_training: Forwarded as `training`; selects batch statistics (True) or
      the moving statistics (False).
    relu: When True (the default), apply tf.nn.relu to the normalized result;
      when False, return the normalized tensor unchanged.
    init_zero: When True, the batch-norm scale (gamma) is initialized to 0;
      otherwise it is initialized to 1.
    batch_norm_momentum: Momentum of the moving mean/variance averages.
    batch_norm_epsilon: Small constant added to the variance for numerical
      stability.

  Returns:
    A tensor with the same shape as `inputs`.
  """
  scale_init = tf.zeros_initializer() if init_zero else tf.ones_initializer()

  # NOTE(review): per the tf.compat.v1.layers.batch_normalization docs, in
  # training mode the moving-average updates are added to
  # tf.GraphKeys.UPDATE_OPS and must be run by the caller's train op.
  normalized = tf.layers.batch_normalization(
      inputs=inputs,
      training=is_training,
      momentum=batch_norm_momentum,
      epsilon=batch_norm_epsilon,
      gamma_initializer=scale_init,
  )
  return tf.nn.relu(normalized) if relu else normalized


def conv2d(inputs, depth, kernel_size, strides):
  """2D convolution with 'SAME' padding, no bias and no activation.

  Kernel weights are initialized with tf.variance_scaling_initializer using
  its default arguments.

  Args:
    inputs: Input features with shape (batch_size, height, width, input_depth).
    depth: Depth of the output features, i.e. number of filters of the
        convolutional layer.
    kernel_size: Size of the convolutional kernel, forwarded unchanged to
        tf.layers.conv2d (a single int is used for both height and width).
    strides: Stride of the convolution, forwarded unchanged to
        tf.layers.conv2d. A single int applies the same stride to the height
        and width dimensions.

  Returns:
    Output features with shape
    (batch_size, ceil(height / strides), ceil(width / strides), depth).
    Because padding is 'SAME', the spatial size equals the input's only when
    strides == 1; larger strides downsample.
  """
  # NOTE(review): with strides > 1, TF's 'SAME' padding may be asymmetric
  # (extra padding goes to the bottom/right), so pixel alignment depends on
  # the input size. Confirm this is acceptable to callers.
  return tf.layers.conv2d(
      inputs=inputs,
      filters=depth,
      kernel_size=kernel_size,
      strides=strides,
      padding='SAME',
      use_bias=False,  # No bias term is created.
      kernel_initializer=tf.variance_scaling_initializer(),
  )
