"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import functools
import tensorflow as tf

from examples.tensorflow.common.object_detection.architecture import nn_ops


class Fpn:
    """Feature pyramid networks.

    Feature Pyramid Networks were proposed in:
    [1] Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan, and
    Serge Belongie. Feature Pyramid Networks for Object Detection. CVPR 2017.
    """

    def __init__(self,
                 min_level=3,
                 max_level=7,
                 fpn_feat_dims=256,
                 use_separable_conv=False,
                 activation='relu',
                 use_batch_norm=True,
                 norm_activation=nn_ops.norm_activation_builder(activation='relu')):
        """FPN initialization function.

        Args:
            min_level: `int` minimum level in FPN output feature maps.
            max_level: `int` maximum level in FPN output feature maps.
            fpn_feat_dims: `int` number of filters in FPN layers.
            use_separable_conv: `bool`, if True use separable convolution for
                convolution in FPN layers.
            use_batch_norm: 'bool', indicating whether batchnorm layers are added.
            norm_activation: an operation that includes a normalization layer
                followed by an optional activation layer.
        """

        self._min_level = min_level
        self._max_level = max_level
        self._fpn_feat_dims = fpn_feat_dims

        if use_separable_conv:
            self._conv2d_op = functools.partial(tf.keras.layers.SeparableConv2D, depth_multiplier=1)
        else:
            self._conv2d_op = tf.keras.layers.Conv2D
        if activation == 'relu':
            self._activation_op = tf.nn.relu
        elif activation == 'swish':
            self._activation_op = tf.nn.swish
        else:
            raise ValueError('Unsupported activation `{}`.'.format(activation))

        self._use_batch_norm = use_batch_norm
        self._norm_activation = norm_activation

        self._norm_activations = {}
        self._lateral_conv2d_op = {}
        self._post_hoc_conv2d_op = {}
        self._coarse_conv2d_op = {}

        for level in range(self._min_level, self._max_level + 1):
            if self._use_batch_norm:
                self._norm_activations[level] = norm_activation(use_activation=False,
                                                                name='p%d-bn' % level)

            self._lateral_conv2d_op[level] = self._conv2d_op(filters=self._fpn_feat_dims,
                                                             kernel_size=(1, 1),
                                                             padding='same',
                                                             name='l%d' % level)

            self._post_hoc_conv2d_op[level] = self._conv2d_op(filters=self._fpn_feat_dims,
                                                              strides=(1, 1),
                                                              kernel_size=(3, 3),
                                                              padding='same',
                                                              name='post_hoc_d%d' % level)

            self._coarse_conv2d_op[level] = self._conv2d_op(filters=self._fpn_feat_dims,
                                                            strides=(2, 2),
                                                            kernel_size=(3, 3),
                                                            padding='same',
                                                            name='p%d' % level)

    def __call__(self, multilevel_features, is_training=None):
        """Returns the FPN features for a given multilevel features.

        Args:
            multilevel_features: a `dict` containing `int` keys for continuous feature
                levels, e.g., [2, 3, 4, 5]. The values are corresponding features with
                shape [batch_size, height_l, width_l, num_filters].
            is_training: `bool` if True, the model is in training mode.

        Returns:
            a `dict` containing `int` keys for continuous feature levels
            [min_level, min_level + 1, ..., max_level]. The values are corresponding
            FPN features with shape [batch_size, height_l, width_l, fpn_feat_dims].
        """

        input_levels = list(multilevel_features.keys())
        if min(input_levels) > self._min_level:
            raise ValueError('The minimum backbone level {} should be '.format(min(input_levels)) +
                             'less or equal to FPN minimum level {}.'.format(self._min_level))

        backbone_max_level = min(max(input_levels), self._max_level)
        with tf.name_scope('fpn'):
            # Adds lateral connections.
            feats_lateral = {}
            for level in range(self._min_level, backbone_max_level + 1):
                feats_lateral[level] = self._lateral_conv2d_op[level](multilevel_features[level])

            # Adds top-down path.
            feats = {backbone_max_level: feats_lateral[backbone_max_level]}
            for level in range(backbone_max_level - 1, self._min_level - 1, -1):
                feats[level] = tf.keras.layers.UpSampling2D()(feats[level + 1]) + feats_lateral[level]

            # Adds post-hoc 3x3 convolution kernel.
            for level in range(self._min_level, backbone_max_level + 1):
                feats[level] = self._post_hoc_conv2d_op[level](feats[level])

            # Adds coarser FPN levels introduced for RetinaNet.
            for level in range(backbone_max_level + 1, self._max_level + 1):
                feats_in = feats[level - 1]
                if level > backbone_max_level + 1:
                    feats_in = self._activation_op(feats_in)
                feats[level] = self._coarse_conv2d_op[level](feats_in)

            if self._use_batch_norm:
                # Adds batch_norm layer.
                for level in range(self._min_level, self._max_level + 1):
                    feats[level] = self._norm_activations[level](feats[level], is_training=is_training)

        return feats
