# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains wrapper classes for convolution layers of tf.keras and Switchable Atrous Convolution.

Switchable Atrous Convolution (SAC) is convolution with a switchable atrous
rate. It also has optional pre- and post-global context layers.
[1] Siyuan Qiao, Liang-Chieh Chen, Alan Yuille. DetectoRS: Detecting Objects
    with Recursive Feature Pyramid and Switchable Atrous Convolution.
    arXiv:2006.02334
"""
import functools
import logging
from typing import Optional

import tensorflow as tf

from src.models import utils, activations


def _compute_padding_size(kernel_size, atrous_rate):
    kernel_size_effective = kernel_size + (kernel_size - 1) * (atrous_rate - 1)
    pad_total = kernel_size_effective - 1
    pad_begin = pad_total // 2
    pad_end = pad_total - pad_begin
    if pad_begin != pad_end:
        print('Convolution requires one more padding to the '
                     'bottom-right pixel. This may cause misalignment.')
    return pad_begin, pad_end


class GlobalContext(tf.keras.layers.Layer):
    """Class for the global context modules in Switchable Atrous Convolution.

    The layer global-average-pools its input and projects the pooled vector
    with a 1x1 convolution back to the input channel count. It then adds the
    result (expanded to shape [batch, 1, 1, channels] and broadcast) to the
    input.

    Both the kernel and the bias of the projection are zero-initialized. A
    freshly built layer therefore adds zeros, i.e. it is an identity mapping.
    NOTE(review): the pooling uses the Keras default data format, so the input
    is presumably channels-last — confirm for channels-first configurations.
    """

    def build(self, input_shape):
        super().build(input_shape)
        input_shape = tf.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        self.global_average_pooling = tf.keras.layers.GlobalAveragePooling2D()
        self.convolution = tf.keras.layers.Conv2D(
            input_channel, 1, strides=1, padding='same', name=self.name + '_conv',
            kernel_initializer='zeros', bias_initializer='zeros')

    def call(self, inputs, *args, **kwargs):
        outputs = self.global_average_pooling(inputs)
        # Restore the two pooled spatial axes so the 1x1 convolution can run
        # and the result broadcasts against `inputs`.
        outputs = tf.expand_dims(outputs, axis=1)
        outputs = tf.expand_dims(outputs, axis=1)
        outputs = self.convolution(outputs)
        return inputs + outputs

    def _get_input_channel(self, input_shape):
        """Returns the static size of the last (channel) dimension.

        Args:
          input_shape: A `tf.TensorShape`.

        Returns:
          The channel count as a Python int.

        Raises:
          ValueError: If the channel dimension (or the rank) is unknown.
        """
        # `tf.compat.dimension_value` works for both v1 (`Dimension`) and v2
        # (int/None) shapes. For an unknown-rank shape, `input_shape[-1]` is
        # None, whereas `input_shape.dims` is None and cannot be indexed.
        input_channel = tf.compat.dimension_value(input_shape[-1])
        if input_channel is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        return int(input_channel)


class SwitchableAtrousConvolution(tf.keras.layers.Conv2D):
    """Class for the Switchable Atrous Convolution.

    A `tf.keras.layers.Conv2D` whose single kernel is applied twice: once with
    the configured `dilation_rate`, and once with every dilation rate
    multiplied by 3. A switch map blends the two results per location as
    `switch * large + (1 - switch) * regular`. The switch map is computed by a
    5x5 average pool followed by a single-channel 1x1 convolution. Bias and
    activation (if configured) are applied once, after the blend.

    The switch convolution has zero kernel and bias initializers. Initially
    the switch is therefore 0 and only the regular-rate result contributes.
    No activation is set on the switch, so its values are not constrained to
    [0, 1].

    NOTE(review): `build` and `call` read private attributes of the Keras
    `Conv` base class (`_convolution_op`, `_tf_data_format`, `_is_causal`,
    `_compute_causal_padding`). This ties the class to the Keras versions
    that provide them — confirm against the pinned TF version.
    """

    def __init__(self, *args, **kwargs):
        # All Conv2D arguments are forwarded unchanged to the base class.
        super().__init__(*args, **kwargs)
        # Local averaging of the input before the switch is predicted. Stride 1
        # with 'same' padding keeps the spatial size.
        self._average_pool = tf.keras.layers.AveragePooling2D(
            pool_size=(5, 5), strides=1, padding='same')
        # Single-channel switch predictor. It uses the same strides as the
        # main convolution, presumably so its spatial size matches the conv
        # outputs (exact when the main conv also uses 'same' padding).
        # NOTE(review): neither the pool nor this conv receives a
        # `data_format`, so both use the Keras default while the main
        # convolutions use `self._tf_data_format` — confirm they agree.
        self._switch = tf.keras.layers.Conv2D(
            1,
            kernel_size=1,
            strides=self.strides,
            padding='same',
            dilation_rate=1,
            name='switch',
            kernel_initializer='zeros',
            bias_initializer='zeros')

    def build(self, input_shape):
        # Creates the kernel/bias. The base class presumably also prepares the
        # `self._convolution_op` used in `call` (Keras-version dependent).
        super().build(input_shape)
        # Translate the Keras padding argument into the form that
        # `tf.nn.convolution` accepts. For 'causal', the inputs are padded
        # explicitly in `call`, so the op itself uses 'VALID'.
        if self.padding == 'causal':
            tf_padding = 'VALID'
        elif isinstance(self.padding, str):
            tf_padding = self.padding.upper()
        else:
            tf_padding = self.padding
        # The "large" branch reuses the same kernel at 3x the dilation rate.
        large_dilation_rate = list(self.dilation_rate)
        large_dilation_rate = [r * 3 for r in large_dilation_rate]
        self._large_convolution_op = functools.partial(
            tf.nn.convolution,
            strides=list(self.strides),
            padding=tf_padding,
            dilations=large_dilation_rate,
            data_format=self._tf_data_format,
            name=self.__class__.__name__ + '_large')

    def call(self, inputs):
        # Reference: tf.keras.layers.convolutional.Conv.
        # Static shape captured before any causal padding; used below to set
        # the static output shape when running in graph mode.
        input_shape = inputs.shape
        # Per-location switch values predicted from the locally averaged input.
        switches = self._switch(self._average_pool(inputs))

        if self._is_causal:  # Apply causal padding to inputs for Conv1D.
            inputs = tf.compat.v1.pad(inputs, self._compute_causal_padding(inputs))

        # Same kernel, two dilation rates (base rate and 3x base rate).
        outputs = self._convolution_op(inputs, self.kernel)
        outputs_large = self._large_convolution_op(inputs, self.kernel)

        # Linear blend of the two branches, weighted by the switch map.
        outputs = switches * outputs_large + (1 - switches) * outputs

        # Bias is added once, after blending.
        if self.use_bias:
            outputs = tf.nn.bias_add(
                outputs, self.bias, data_format=self._tf_data_format)

        if not tf.executing_eagerly():
            # Infer the static output shape:
            out_shape = self.compute_output_shape(input_shape)
            outputs.set_shape(out_shape)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def squeeze_batch_dims(self, inp, op, inner_rank):
        """Applies `op` to `inp` after folding all leading dims into one.

        `inp` is reshaped to `[-1] + inner_shape`, keeping its last
        `inner_rank` dimensions. `op` is then applied, and the result is
        reshaped back so its leading dimensions match those of `inp`.
        Dynamic shapes are used wherever the static shape is not fully
        defined.

        NOTE(review): this method is not referenced anywhere else in this
        file and does not use `self`.

        Args:
          inp: The input tensor.
          op: A callable applied to the reshaped tensor.
          inner_rank: The number of trailing dimensions kept unflattened.

        Returns:
          The output of `op`, reshaped to restore the leading dims of `inp`.
        """
        # Reference: tf.keras.utils.conv_utils.squeeze_batch_dims.
        with tf.name_scope('squeeze_batch_dims'):
            shape = inp.shape

            # Trailing dims that `op` operates on (static if fully known).
            inner_shape = shape[-inner_rank:]
            if not inner_shape.is_fully_defined():
                inner_shape = tf.compat.v1.shape(inp)[-inner_rank:]

            # Leading dims that get flattened into a single batch dimension.
            batch_shape = shape[:-inner_rank]
            if not batch_shape.is_fully_defined():
                batch_shape = tf.compat.v1.shape(inp)[:-inner_rank]

            if isinstance(inner_shape, tf.TensorShape):
                inp_reshaped = tf.reshape(inp, [-1] + inner_shape.as_list())
            else:
                inp_reshaped = tf.reshape(
                    inp, tf.concat(([-1], inner_shape), axis=-1))

            out_reshaped = op(inp_reshaped)

            out_inner_shape = out_reshaped.shape[-inner_rank:]
            if not out_inner_shape.is_fully_defined():
                out_inner_shape = tf.compat.v1.shape(out_reshaped)[-inner_rank:]

            # Un-flatten the batch dims back onto the op's output.
            out = tf.reshape(
                out_reshaped, tf.concat((batch_shape, out_inner_shape), axis=-1))

            # Restore as much static shape information as is known.
            out.set_shape(inp.shape[:-inner_rank] + out.shape[-inner_rank:])
            return out


class Conv2DSame(tf.keras.layers.Layer):
    """2D convolution with DeepLab-style 'same' padding.

    For stride 1, the wrapped convolution simply uses Keras 'same' padding.
    For any other stride, the input is first zero-padded explicitly, with
    `pad_total // 2` pixels before and the remainder after, and then convolved
    with 'valid' padding. This aligns the kernel with the top-left corner
    rather than the bottom-right corner used by `tf.keras.layers.Conv2D`.

    Optional extras:
      - batch normalization;
      - an activation;
      - a Switchable Atrous Convolution (stride 1 only);
      - pre/post global-context layers.
    """

    def __init__(
            self,
            output_channels: int,
            kernel_size: int,
            name: str,
            strides: int = 1,
            atrous_rate: int = 1,
            use_bias: bool = True,
            use_bn: bool = False,
            bn_layer: tf.keras.layers.Layer = tf.keras.layers.BatchNormalization,
            bn_gamma_initializer: str = 'ones',
            activation: Optional[str] = None,
            use_switchable_atrous_conv: bool = False,
            use_global_context_in_sac: bool = False,
            conv_kernel_weight_decay: float = 0.0):
        """Creates the (padded) convolution and its optional companions.

        Example of the alignment difference: with a 7x7 kernel, stride 2 and
        an even input size, tf.keras 'same' padding adds 2 zeros at the
        top-left and 3 at the bottom-right. DeepLab instead pads 3 on every
        side for consistent feature alignment, matching e.g. the ResNet
        'stem' block.

        Args:
          output_channels: Number of convolution filters.
          kernel_size: Size of the convolution kernel.
          name: Name of this layer.
          strides: Convolution stride, an integer or tuple (default: 1).
          atrous_rate: Dilation rate of the convolution, an integer or tuple
            (default: 1).
          use_bias: Whether the convolution adds a bias.
          use_bn: Whether batch normalization follows the convolution
            (default: False).
          bn_layer: Normalization layer class to instantiate
            (default: tf.keras.layers.BatchNormalization).
          bn_gamma_initializer: Initializer for the batch norm gamma weight.
          activation: Optional name of an activation applied last.
          use_switchable_atrous_conv: Whether to use switchable atrous
            convolution. It only takes effect when the stride is 1.
          use_global_context_in_sac: Whether to wrap the convolution with pre-
            and post-global-context layers.
          conv_kernel_weight_decay: L2 weight decay for the convolution kernel.

        Raises:
          ValueError: If both `use_bias` and `use_bn` are requested.
        """
        super().__init__(name=name)

        if use_bn and use_bias:
            raise ValueError('Conv2DSame is using convolution bias with batch_norm.')

        if use_global_context_in_sac:
            self._pre_global_context = GlobalContext(name='pre_global_context')

        is_unit_stride = strides == 1 or strides == (1, 1)
        if is_unit_stride:
            # Stride 1: rely on Keras 'same' padding. SAC is only available here.
            conv_padding = 'same'
            conv_cls = (SwitchableAtrousConvolution
                        if use_switchable_atrous_conv
                        else tf.keras.layers.Conv2D)
        else:
            # Strided: pad explicitly, then convolve without implicit padding.
            pad_begin_end = _compute_padding_size(kernel_size, atrous_rate)
            self._zeropad = tf.keras.layers.ZeroPadding2D(
                padding=(pad_begin_end, pad_begin_end), name='zeropad')
            conv_padding = 'valid'
            conv_cls = tf.keras.layers.Conv2D

        self._conv = conv_cls(
            output_channels,
            kernel_size,
            strides=strides,
            padding=conv_padding,
            use_bias=use_bias,
            dilation_rate=atrous_rate,
            name='conv',
            kernel_initializer='he_normal',
            kernel_regularizer=tf.keras.regularizers.l2(conv_kernel_weight_decay))

        if use_global_context_in_sac:
            self._post_global_context = GlobalContext(name='post_global_context')

        if use_bn:
            self._batch_norm = bn_layer(
                axis=3, name='batch_norm', gamma_initializer=bn_gamma_initializer)

        self._activation_fn = (activations.get_activation(activation)
                               if activation is not None else None)

        self._use_global_context_in_sac = use_global_context_in_sac
        self._strides = strides
        self._use_bn = use_bn

    def call(self, input_tensor, training=False):
        """Runs context -> padding -> conv -> context -> BN -> activation.

        Each optional stage is applied only if it was enabled at construction.

        Args:
          input_tensor: A tf.Tensor of shape [batch, height, width, channels].
          training: Whether to run in training mode (forwarded to batch norm;
            default: False).

        Returns:
          The output tensor.
        """
        features = input_tensor
        if self._use_global_context_in_sac:
            features = self._pre_global_context(features)

        if self._strides != 1 and self._strides != (1, 1):
            features = self._zeropad(features)
        features = self._conv(features)

        if self._use_global_context_in_sac:
            features = self._post_global_context(features)

        if self._use_bn:
            features = self._batch_norm(features, training=training)

        if self._activation_fn is None:
            return features
        return self._activation_fn(features)


class DepthwiseConv2DSame(tf.keras.layers.Layer):
    """A wrapper class for a 2D depthwise convolution.

    In contrast to convolutions in tf.keras.layers.DepthwiseConv2D, this layer
    aligns the kernel with the top-left corner rather than the bottom-right
    corner. For a stride other than 1, the input is zero-padded explicitly
    (`pad_total // 2` before, the remainder after) and then convolved with
    'valid' padding. For stride 1, Keras 'same' padding is used. Optionally, a
    batch normalization and an activation can be added.
    """

    def __init__(self,
                 kernel_size: int,
                 name: str,
                 strides: int = 1,
                 atrous_rate: int = 1,
                 use_bias: bool = True,
                 use_bn: bool = False,
                 bn_layer=tf.keras.layers.BatchNormalization,
                 activation: Optional[str] = None):
        """Initializes a 2D depthwise convolution.

        Args:
          kernel_size: An integer specifying the size of the convolution kernel.
          name: A string specifying the name of this layer.
          strides: An optional integer or tuple of integers specifying the size of
            the strides (default: 1).
          atrous_rate: An optional integer or tuple of integers specifying the
            atrous rate of the convolution (default: 1).
          use_bias: An optional flag specifying whether bias should be added for the
            convolution.
          use_bn: An optional flag specifying whether batch normalization should be
            added after the convolution (default: False).
          bn_layer: An optional tf.keras.layers.Layer class that computes the
            normalization (default: tf.keras.layers.BatchNormalization).
          activation: An optional flag specifying an activation function to be added
            after the convolution.

        Raises:
          ValueError: If both use_bias and use_bn are set.
        """
        super(DepthwiseConv2DSame, self).__init__(name=name)

        if use_bn and use_bias:
            raise ValueError(
                'DepthwiseConv2DSame is using convolution bias with batch_norm.')

        if strides == 1 or strides == (1, 1):
            convolution_padding = 'same'
        else:
            # Explicit padding gives the strided convolution DeepLab alignment.
            padding = _compute_padding_size(kernel_size, atrous_rate)
            self._zeropad = tf.keras.layers.ZeroPadding2D(
                padding=(padding, padding), name='zeropad')
            convolution_padding = 'valid'
        self._depthwise_conv = tf.keras.layers.DepthwiseConv2D(
            kernel_size=kernel_size,
            strides=strides,
            padding=convolution_padding,
            use_bias=use_bias,
            dilation_rate=atrous_rate,
            name='depthwise_conv')
        if use_bn:
            self._batch_norm = bn_layer(axis=3, name='batch_norm')

        self._activation_fn = None
        if activation is not None:
            self._activation_fn = activations.get_activation(activation)

        self._strides = strides
        self._use_bn = use_bn

    def call(self, input_tensor, training=False):
        """Performs a forward pass.

        Args:
          input_tensor: An input tensor of type tf.Tensor with shape [batch, height,
            width, channels].
          training: A boolean flag indicating whether training behavior should be
            used (default: False).

        Returns:
          The output tensor.
        """
        x = input_tensor
        # `_zeropad` only exists for strided convolutions (see `__init__`).
        if not (self._strides == 1 or self._strides == (1, 1)):
            x = self._zeropad(x)
        x = self._depthwise_conv(x)
        if self._use_bn:
            x = self._batch_norm(x, training=training)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x


class SeparableConv2DSame(tf.keras.layers.Layer):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Unlike tf.keras.layers.SeparableConv2D, the depthwise stage aligns its
    kernel with the top-left corner instead of the bottom-right corner (see
    `DepthwiseConv2DSame`). When requested, batch normalization and the
    activation are applied after each of the two stages.
    """

    def __init__(
            self,
            output_channels: int,
            kernel_size: int,
            name: str,
            strides: int = 1,
            atrous_rate: int = 1,
            use_bias: bool = True,
            use_bn: bool = False,
            bn_layer: tf.keras.layers.Layer = tf.keras.layers.BatchNormalization,
            activation: Optional[str] = None):
        """Builds the depthwise and pointwise stages.

        Args:
          output_channels: Number of filters produced by the pointwise stage.
          kernel_size: Size of the depthwise convolution kernel.
          name: Name of this layer.
          strides: Stride of the depthwise stage, an integer or tuple
            (default: 1). The pointwise stage always uses stride 1.
          atrous_rate: Dilation rate of the depthwise stage, an integer or tuple
            (default: 1). The pointwise stage always uses rate 1.
          use_bias: Whether each stage adds a bias.
          use_bn: Whether each stage is followed by batch normalization
            (default: False).
          bn_layer: Normalization layer class to instantiate
            (default: tf.keras.layers.BatchNormalization).
          activation: Optional name of an activation applied after each stage.

        Raises:
          ValueError: If both `use_bias` and `use_bn` are requested.
        """
        super().__init__(name=name)
        if use_bn and use_bias:
            raise ValueError(
                'SeparableConv2DSame is using convolution bias with batch_norm.')

        # Options common to both stages.
        shared_options = dict(
            use_bias=use_bias,
            use_bn=use_bn,
            bn_layer=bn_layer,
            activation=activation)

        self._depthwise = DepthwiseConv2DSame(
            kernel_size=kernel_size,
            name='depthwise',
            strides=strides,
            atrous_rate=atrous_rate,
            **shared_options)
        self._pointwise = Conv2DSame(
            output_channels=output_channels,
            kernel_size=1,
            name='pointwise',
            strides=1,
            atrous_rate=1,
            **shared_options)

    def call(self, input_tensor, training=False):
        """Runs the depthwise stage and then the pointwise stage.

        Args:
          input_tensor: A tf.Tensor of shape [batch, height, width, channels].
          training: Whether to run in training mode (default: False).

        Returns:
          The output tensor.
        """
        result = input_tensor
        for stage in (self._depthwise, self._pointwise):
            result = stage(result, training=training)
        return result


class StackedConv2DSame(tf.keras.layers.Layer):
    """A sequential stack of Conv2DSame or SeparableConv2DSame layers.

    Every layer in the stack receives the same constructor options, including
    `strides` and `atrous_rate`.
    """

    def __init__(
            self,
            num_layers: int,
            conv_type: str,
            output_channels: int,
            kernel_size: int,
            name: str,
            strides: int = 1,
            atrous_rate: int = 1,
            use_bias: bool = True,
            use_bn: bool = False,
            bn_layer: tf.keras.layers.Layer = tf.keras.layers.BatchNormalization,
            activation: Optional[str] = None):
        """Creates `num_layers` convolutions of the requested type.

        Args:
          num_layers: How many convolutions to stack.
          conv_type: Either 'standard_conv' (Conv2DSame) or
            'depthwise_separable_conv' (SeparableConv2DSame).
          output_channels: Number of output filters of every convolution.
          kernel_size: Size of every convolution kernel.
          name: Name of this layer.
          strides: Stride passed to every convolution, an integer or tuple
            (default: 1).
          atrous_rate: Dilation rate passed to every convolution, an integer or
            tuple (default: 1).
          use_bias: Whether the convolutions add a bias.
          use_bn: Whether each convolution is followed by batch normalization
            (default: False).
          bn_layer: Normalization layer class to instantiate
            (default: tf.keras.layers.BatchNormalization).
          activation: Optional name of an activation applied by each layer.

        Raises:
          ValueError: If `conv_type` is not one of the two supported values.
        """
        super().__init__(name=name)

        supported_ops = (
            ('standard_conv', Conv2DSame),
            ('depthwise_separable_conv', SeparableConv2DSame),
        )
        convolution_op = next(
            (op for op_name, op in supported_ops if conv_type == op_name), None)
        if convolution_op is None:
            raise ValueError('Convolution %s not supported.' % conv_type)

        for layer_index in range(num_layers):
            attr_name = utils.get_conv_bn_act_current_name(
                layer_index, use_bn, activation)
            layer = convolution_op(
                output_channels=output_channels,
                kernel_size=kernel_size,
                name=utils.get_layer_name(attr_name),
                strides=strides,
                atrous_rate=atrous_rate,
                use_bias=use_bias,
                use_bn=use_bn,
                bn_layer=bn_layer,
                activation=activation)
            utils.safe_setattr(self, attr_name, layer)

        self._num_layers = num_layers
        self._use_bn = use_bn
        self._activation = activation

    def call(self, input_tensor, training=False):
        """Feeds the input through every stacked layer in creation order.

        Args:
          input_tensor: A tf.Tensor of shape [batch, height, width, channels].
          training: Whether to run in training mode (default: False).

        Returns:
          The output tensor.
        """
        output = input_tensor
        for layer_index in range(self._num_layers):
            attr_name = utils.get_conv_bn_act_current_name(
                layer_index, self._use_bn, self._activation)
            layer = getattr(self, attr_name)
            output = layer(output, training=training)
        return output


class Conv1D(tf.keras.layers.Layer):
    """A wrapper class for a 1D convolution with batch norm and activation.

    Conv1D creates a convolution kernel that is convolved with the layer input
    over a single spatial (or temporal) dimension to produce a tensor of outputs.
    The input should always be 3D with shape [batch, length, channel], so
    accordingly, the optional batch norm is done on axis=2.

    In DeepLab, we use Conv1D only with kernel_size = 1 for dual path transformer
    layers in MaX-DeepLab [1] architectures.

    Reference:
    [1] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
        CVPR 2021.
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
    """

    def __init__(
            self,
            output_channels: int,
            name: str,
            use_bias: bool = True,
            use_bn: bool = False,
            bn_layer: tf.keras.layers.Layer = tf.keras.layers.BatchNormalization,
            bn_gamma_initializer: str = 'ones',
            activation: Optional[str] = None,
            conv_kernel_weight_decay: float = 0.0,
            kernel_initializer='he_normal',
            kernel_size: int = 1,
            padding: str = 'valid'):
        """Initializes a Conv1D.

        Args:
          output_channels: An integer specifying the number of filters of the
            convolution.
          name: A string specifying the name of this layer.
          use_bias: An optional flag specifying whether bias should be added for the
            convolution.
          use_bn: An optional flag specifying whether batch normalization should be
            added after the convolution (default: False).
          bn_layer: An optional tf.keras.layers.Layer class that computes the
            normalization (default: tf.keras.layers.BatchNormalization).
          bn_gamma_initializer: An initializer for the batch norm gamma weight.
          activation: An optional flag specifying an activation function to be added
            after the convolution.
          conv_kernel_weight_decay: A float, the L2 weight decay for convolution
            kernels.
          kernel_initializer: An initializer for the convolution kernel.
          kernel_size: An integer specifying the size of the convolution kernel.
          padding: An optional string specifying the padding to use. Must be either
            'same' or 'valid' (default: 'valid').

        Raises:
          ValueError: If both use_bias and use_bn are set.
        """
        super(Conv1D, self).__init__(name=name)

        if use_bn and use_bias:
            raise ValueError('Conv1D is using convolution bias with batch_norm.')

        self._conv = tf.keras.layers.Conv1D(
            output_channels,
            kernel_size=kernel_size,
            strides=1,
            padding=padding,
            use_bias=use_bias,
            name='conv',
            kernel_initializer=kernel_initializer,
            kernel_regularizer=tf.keras.regularizers.l2(
                conv_kernel_weight_decay))

        self._batch_norm = None
        if use_bn:
            # Batch norm uses axis=2 because the input is 3D with channel being the
            # last dimension.
            self._batch_norm = bn_layer(axis=2, name='batch_norm',
                                        gamma_initializer=bn_gamma_initializer)

        self._activation_fn = None
        if activation is not None:
            self._activation_fn = activations.get_activation(activation)

    def call(self, input_tensor, training=False):
        """Performs a forward pass.

        Args:
          input_tensor: An input tensor of type tf.Tensor with shape [batch, length,
            channels].
          training: A boolean flag indicating whether training behavior should be
            used (default: False).

        Returns:
          The output tensor.
        """
        x = self._conv(input_tensor)
        if self._batch_norm is not None:
            x = self._batch_norm(x, training=training)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x
