"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import functools

import numpy as np
import tensorflow as tf

from examples.tensorflow.common.object_detection.architecture import nn_ops


class RetinanetHead:
    """RetinaNet head: classification and box-regression towers.

    Each tower consists of `num_convs` 3x3 convolutions followed by one 3x3
    prediction convolution. The convolution layers are shared between all
    pyramid levels, while every (conv index, level) pair owns a dedicated
    normalization/activation op.
    """

    def __init__(self,
                 min_level,
                 max_level,
                 num_classes,
                 anchors_per_location,
                 num_convs=4,
                 num_filters=256,
                 use_separable_conv=False,
                 norm_activation=nn_ops.norm_activation_builder(activation='relu')):
        """Initialize params to build RetinaNet head.

        Args:
          min_level: `int` number of minimum feature level.
          max_level: `int` number of maximum feature level.
          num_classes: `int` number of classification categories.
          anchors_per_location: `int` number of anchors per pixel location.
          num_convs: `int` number of stacked convolution before the last prediction
            layer.
          num_filters: `int` number of filters used in the head architecture.
          use_separable_conv: `bool` to indicate whether to use separable
            convolution.
          norm_activation: an operation that includes a normalization layer followed
            by an optional activation layer.
        """
        self._min_level = min_level
        self._max_level = max_level

        self._num_classes = num_classes
        self._anchors_per_location = anchors_per_location

        self._num_convs = num_convs
        self._num_filters = num_filters
        self._use_separable_conv = use_separable_conv
        # Capture the fully qualified scope names once so that the very same
        # scopes can be re-entered for every pyramid level later on.
        with tf.name_scope('class_net') as scope_name:
            self._class_name_scope = tf.name_scope(scope_name)
        with tf.name_scope('box_net') as scope_name:
            self._box_name_scope = tf.name_scope(scope_name)
        self._build_class_net_layers(norm_activation)
        self._build_box_net_layers(norm_activation)

    def _class_net_batch_norm_name(self, i, level):
        """Return the name of the class-tower norm op for conv `i` at `level`."""
        return 'class-%d-%d' % (i, level)

    def _box_net_batch_norm_name(self, i, level):
        """Return the name of the box-tower norm op for conv `i` at `level`."""
        return 'box-%d-%d' % (i, level)

    def _make_conv(self, filters, name, bias_initializer, kernel_stddev):
        """Create one 3x3 'same'-padded convolution (regular or separable).

        Args:
          filters: `int` number of output filters.
          name: `str` layer name.
          bias_initializer: initializer used for the bias.
          kernel_stddev: `float` stddev of the `RandomNormal` kernel initializer;
            only used for regular convolutions.

        Returns:
          A new, not yet built, Keras convolution layer.
        """
        if self._use_separable_conv:
            # Separable convolutions keep the Keras default kernel initializers.
            return tf.keras.layers.SeparableConv2D(
                filters,
                kernel_size=(3, 3),
                bias_initializer=bias_initializer,
                activation=None,
                padding='same',
                name=name)
        return tf.keras.layers.Conv2D(
            filters,
            kernel_size=(3, 3),
            bias_initializer=bias_initializer,
            kernel_initializer=tf.keras.initializers.RandomNormal(
                stddev=kernel_stddev),
            activation=None,
            padding='same',
            name=name)

    def _build_tower(self, conv_name_prefix, norm_name_fn, norm_activation):
        """Create the shared tower convs and the per-level norm/activation ops.

        Args:
          conv_name_prefix: `str` prefix of the conv layer names, e.g. 'class-'.
          norm_name_fn: callable `(i, level) -> str` naming the norm op.
          norm_activation: factory accepting `name` and returning the op.

        Returns:
          A `(convs, norm_activations)` tuple: the list of tower convolutions and
          a dict mapping norm op name to the op.
        """
        convs = []
        norm_activations = {}
        for i in range(self._num_convs):
            convs.append(
                self._make_conv(
                    self._num_filters,
                    name=conv_name_prefix + str(i),
                    bias_initializer=tf.zeros_initializer(),
                    kernel_stddev=0.01))
            for level in range(self._min_level, self._max_level + 1):
                name = norm_name_fn(i, level)
                norm_activations[name] = norm_activation(name=name)
        return convs, norm_activations

    def _build_class_net_layers(self, norm_activation):
        """Build re-usable layers for class prediction network."""
        # Bias is set to -log((1 - p) / p) with p = 0.01, i.e. a sigmoid over the
        # initial logits would yield ~0.01 for every class.
        self._class_predict = self._make_conv(
            self._num_classes * self._anchors_per_location,
            name='class-predict',
            bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
            kernel_stddev=1e-5)
        self._class_conv, self._class_norm_activation = self._build_tower(
            'class-', self._class_net_batch_norm_name, norm_activation)

    def _build_box_net_layers(self, norm_activation):
        """Build re-usable layers for box prediction network."""
        # Four regression outputs per anchor.
        self._box_predict = self._make_conv(
            4 * self._anchors_per_location,
            name='box-predict',
            bias_initializer=tf.zeros_initializer(),
            kernel_stddev=1e-5)
        self._box_conv, self._box_norm_activation = self._build_tower(
            'box-', self._box_net_batch_norm_name, norm_activation)

    def __call__(self, fpn_features, is_training=None):
        """Returns outputs of RetinaNet head.

        Args:
          fpn_features: mapping indexed by `int` level that holds the feature
            tensor of every level in `[min_level, max_level]`.
          is_training: forwarded to the normalization/activation ops.

        Returns:
          A `(class_outputs, box_outputs)` tuple of dicts keyed by `str(level)`.
        """
        class_outputs = {}
        box_outputs = {}
        with tf.name_scope('retinanet_head'):
            for level in range(self._min_level, self._max_level + 1):
                features = fpn_features[level]
                class_outputs[str(level)] = self.class_net(features, level, is_training=is_training)
                box_outputs[str(level)] = self.box_net(features, level, is_training=is_training)

        return class_outputs, box_outputs

    def class_net(self, features, level, is_training=None):
        """Class prediction network for RetinaNet.

        Args:
          features: feature tensor of a single pyramid level.
          level: `int` pyramid level, selects the level-specific norm ops.
          is_training: forwarded to the normalization/activation ops.

        Returns:
          The output of the class prediction convolution.
        """
        with self._class_name_scope:
            for i in range(self._num_convs):
                features = self._class_conv[i](features)
                # The convolution layers in the class net are shared among all levels,
                # but each level has its batch normalization to capture the statistical
                # difference among different levels.
                name = self._class_net_batch_norm_name(i, level)
                features = self._class_norm_activation[name](features, is_training=is_training)

            classes = self._class_predict(features)

        return classes

    def box_net(self, features, level, is_training=None):
        """Box regression network for RetinaNet.

        Args:
          features: feature tensor of a single pyramid level.
          level: `int` pyramid level, selects the level-specific norm ops.
          is_training: forwarded to the normalization/activation ops.

        Returns:
          The output of the box prediction convolution.
        """
        with self._box_name_scope:
            for i in range(self._num_convs):
                features = self._box_conv[i](features)
                # The convolution layers in the box net are shared among all levels, but
                # each level has its batch normalization to capture the statistical
                # difference among different levels.
                name = self._box_net_batch_norm_name(i, level)
                features = self._box_norm_activation[name](features, is_training=is_training)

            boxes = self._box_predict(features)
        return boxes


class RpnHead(tf.keras.layers.Layer):
    """Region Proposal Network head.

    One 3x3 convolution followed by two parallel 1x1 convolutions (objectness
    scores and box deltas). The three convolutions are shared by all levels;
    the optional normalization/activation op is created per level.
    """

    def __init__(self,
                 min_level,
                 max_level,
                 anchors_per_location,
                 num_convs=2, # pylint: disable=W0613
                 num_filters=256,
                 use_separable_conv=False,
                 activation='relu',
                 use_batch_norm=True,
                 norm_activation=nn_ops.norm_activation_builder(activation='relu')):
        """Initialize params to build Region Proposal Network head.

        Args:
            min_level: `int` number of minimum feature level.
            max_level: `int` number of maximum feature level.
            anchors_per_location: `int` number of anchors per pixel location.
            num_convs: `int`, accepted for interface compatibility; not used by
                this implementation.
            num_filters: `int` number of filters of the intermediate 3x3 conv.
            use_separable_conv: `bool`, whether separable conv layers are used.
            activation: activation function. Support 'relu' and 'swish'. Only
                applied by the 3x3 conv itself when `use_batch_norm` is False.
            use_batch_norm: `bool`, whether per-level norm ops are added.
            norm_activation: an operation that includes a normalization layer
                followed by an optional activation layer.

        Raises:
            ValueError: if `activation` is neither 'relu' nor 'swish'.
        """
        super().__init__()
        self._min_level = min_level
        self._max_level = max_level
        self._anchors_per_location = anchors_per_location

        if activation not in ('relu', 'swish'):
            raise ValueError('Unsupported activation `{}`.'.format(activation))
        self._activation_op = tf.nn.relu if activation == 'relu' else tf.nn.swish
        self._use_batch_norm = use_batch_norm

        # Layer factory shared by the three convolutions below.
        if not use_separable_conv:
            self._conv2d_op = functools.partial(
                tf.keras.layers.Conv2D,
                kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
                bias_initializer=tf.zeros_initializer())
        else:
            self._conv2d_op = functools.partial(
                tf.keras.layers.SeparableConv2D,
                depth_multiplier=1,
                bias_initializer=tf.zeros_initializer())

        # With batch norm enabled the activation comes from the norm op instead.
        trunk_activation = None if self._use_batch_norm else self._activation_op
        self._rpn_conv = self._conv2d_op(
            num_filters,
            kernel_size=(3, 3),
            strides=(1, 1),
            activation=trunk_activation,
            padding='same',
            name='rpn')

        pointwise_kwargs = dict(kernel_size=(1, 1), strides=(1, 1), padding='valid')
        self._rpn_class_conv = self._conv2d_op(
            anchors_per_location, name='rpn-class', **pointwise_kwargs)
        self._rpn_box_conv = self._conv2d_op(
            4 * anchors_per_location, name='rpn-box', **pointwise_kwargs)

        self._norm_activations = {}
        if self._use_batch_norm:
            for lvl in range(self._min_level, self._max_level + 1):
                bn_name = 'rpn-l%d-bn' % lvl
                self._norm_activations[lvl] = norm_activation(name=bn_name)

    def _shared_rpn_heads(self, features, anchors_per_location, level, is_training): # pylint: disable=W0613
        """Run the shared RPN convolutions on the features of one level.

        Returns:
            A `(scores, bboxes)` tuple: outputs of the class and box 1x1 convs.
        """
        hidden = self._rpn_conv(features)
        if self._use_batch_norm:
            # The batch normalization layers are not shared between levels.
            level_norm = self._norm_activations[level]
            hidden = level_norm(hidden, is_training=is_training)

        # Proposal classification scores and proposal bbox regression deltas.
        scores = self._rpn_class_conv(hidden)
        bboxes = self._rpn_box_conv(hidden)
        return scores, bboxes

    def __call__(self, features, is_training=None):
        """Apply the head to every level in `[min_level, max_level]`.

        Args:
            features: mapping indexed by `int` level holding the feature tensors.
            is_training: forwarded to the normalization/activation ops.

        Returns:
            A `(scores_outputs, box_outputs)` tuple of dicts keyed by
            `str(level)`.
        """
        scores_outputs, box_outputs = {}, {}
        with tf.name_scope('rpn_head'):
            for lvl in range(self._min_level, self._max_level + 1):
                key = str(lvl)
                scores_outputs[key], box_outputs[key] = self._shared_rpn_heads(
                    features[lvl], self._anchors_per_location, lvl, is_training)
        return scores_outputs, box_outputs


class FastrcnnHead(tf.keras.layers.Layer):
    """Fast R-CNN box head.

    Optional 3x3 convolutions, then fully connected layers, then two parallel
    dense predictors for class scores and per-class box deltas.
    """

    def __init__(self,
                 num_classes,
                 num_convs=0,
                 num_filters=256,
                 use_separable_conv=False,
                 num_fcs=2,
                 fc_dims=1024,
                 activation='relu',
                 use_batch_norm=True,
                 norm_activation=nn_ops.norm_activation_builder(activation='relu')):
        """Initialize params to build Fast R-CNN box head.

        Args:
            num_classes: `int` number of classes.
            num_convs: `int` number of intermediate conv layers before the FC
                layers.
            num_filters: `int` number of filters of the intermediate conv layers.
            use_separable_conv: `bool`, whether separable conv layers are used.
            num_fcs: `int` number of FC layers before the predictions.
            fc_dims: `int` number of units of each FC layer.
            activation: activation function. Support 'relu' and 'swish'. Only
                applied by the conv/FC layers themselves when `use_batch_norm`
                is False.
            use_batch_norm: `bool`, whether norm ops follow each conv/FC layer.
            norm_activation: an operation that includes a normalization layer
                followed by an optional activation layer.

        Raises:
            ValueError: if `activation` is neither 'relu' nor 'swish'.
        """
        super().__init__()
        self._num_classes = num_classes
        self._num_convs = num_convs
        self._num_filters = num_filters

        # Layer factory for the intermediate convolutions.
        if not use_separable_conv:
            self._conv2d_op = functools.partial(
                tf.keras.layers.Conv2D,
                kernel_initializer=tf.keras.initializers.VarianceScaling(
                    scale=2, mode='fan_out', distribution='untruncated_normal'),
                bias_initializer=tf.zeros_initializer())
        else:
            self._conv2d_op = functools.partial(
                tf.keras.layers.SeparableConv2D,
                depth_multiplier=1,
                bias_initializer=tf.zeros_initializer())

        self._num_fcs = num_fcs
        self._fc_dims = fc_dims
        if activation not in ('relu', 'swish'):
            raise ValueError('Unsupported activation `{}`.'.format(activation))
        self._activation_op = tf.nn.relu if activation == 'relu' else tf.nn.swish
        self._use_batch_norm = use_batch_norm
        self._norm_activation = norm_activation

        # With batch norm enabled the activation comes from the norm op instead.
        inline_activation = None if self._use_batch_norm else self._activation_op

        self._conv_ops = []
        self._conv_bn_ops = []
        for conv_idx in range(self._num_convs):
            conv_layer = self._conv2d_op(
                self._num_filters,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding='same',
                dilation_rate=(1, 1),
                activation=inline_activation,
                name='conv_{}'.format(conv_idx))
            self._conv_ops.append(conv_layer)
            if self._use_batch_norm:
                self._conv_bn_ops.append(self._norm_activation())

        self._fc_ops = []
        self._fc_bn_ops = []
        for fc_idx in range(self._num_fcs):
            fc_layer = tf.keras.layers.Dense(
                units=self._fc_dims,
                activation=inline_activation,
                name='fc{}'.format(fc_idx))
            self._fc_ops.append(fc_layer)
            if self._use_batch_norm:
                self._fc_bn_ops.append(self._norm_activation(fused=False))

        self._class_predict = tf.keras.layers.Dense(
            self._num_classes,
            kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            bias_initializer=tf.zeros_initializer(),
            name='class-predict')
        self._box_predict = tf.keras.layers.Dense(
            self._num_classes * 4,
            kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.001),
            bias_initializer=tf.zeros_initializer(),
            name='box-predict')

    def __call__(self, roi_features, is_training=None):
        """Box and class branches for the Mask-RCNN model.

        Args:
            roi_features: A ROI feature tensor of shape [batch_size, num_rois,
                height_l, width_l, num_filters]. All dimensions except the
                first are read from the static shape.
            is_training: forwarded to the normalization/activation ops.

        Returns:
            class_outputs: a tensor with a shape of
                [batch_size, num_rois, num_classes], representing the class
                predictions.
            box_outputs: a tensor with a shape of
                [batch_size, num_rois, num_classes * 4], representing the box
                predictions.
        """
        with tf.name_scope('fast_rcnn_head'):
            static_shape = roi_features.get_shape().as_list()
            _, num_rois, height, width, filters = static_shape

            # Fold the ROI axis into the batch axis for the 2D convolutions.
            net = tf.reshape(roi_features, [-1, height, width, filters])
            for conv_idx in range(self._num_convs):
                net = self._conv_ops[conv_idx](net)
                if self._use_batch_norm:
                    net = self._conv_bn_ops[conv_idx](net, is_training=is_training)

            # Once a conv has run, the channel count is that of the convs.
            if self._num_convs > 0:
                filters = self._num_filters
            flat_dim = height * width * filters
            net = tf.reshape(net, [-1, num_rois, flat_dim])

            for fc_idx in range(self._num_fcs):
                net = self._fc_ops[fc_idx](net)
                if self._use_batch_norm:
                    net = self._fc_bn_ops[fc_idx](net, is_training=is_training)

            return self._class_predict(net), self._box_predict(net)


class MaskrcnnHead(tf.keras.layers.Layer):
    """Mask R-CNN head.

    A stack of 3x3 convolutions, a stride-2 transposed convolution and a 1x1
    convolution that produces per-class mask logits; the logits of the class
    given by `class_indices` are then selected for every ROI.
    """

    def __init__(self,
                 num_classes,
                 mask_target_size,
                 num_convs=4,
                 num_filters=256,
                 use_separable_conv=False,
                 activation='relu',
                 use_batch_norm=True,
                 norm_activation=nn_ops.norm_activation_builder(activation='relu')):
        """Initialize params to build Mask R-CNN head.

        Args:
            num_classes: a integer for the number of classes.
            mask_target_size: a integer that is the resolution of masks.
            num_convs: `int` number that represents the number of the intermediate
                conv layers before the prediction.
            num_filters: `int` number that represents the number of filters of the
                intermediate conv layers.
            use_separable_conv: `bool`, indicating whether the separable conv layers
                is used.
            activation: activation function. Support 'relu' and 'swish'.
            use_batch_norm: 'bool', indicating whether batchnorm layers are added.
            norm_activation: an operation that includes a normalization layer followed
                by an optional activation layer.

        Raises:
            ValueError: if `activation` is neither 'relu' nor 'swish'.
        """
        super().__init__()
        self._num_classes = num_classes
        self._mask_target_size = mask_target_size

        self._num_convs = num_convs
        self._num_filters = num_filters
        if use_separable_conv:
            self._conv2d_op = functools.partial(
                tf.keras.layers.SeparableConv2D,
                depth_multiplier=1,
                bias_initializer=tf.zeros_initializer())
        else:
            self._conv2d_op = functools.partial(
                tf.keras.layers.Conv2D,
                kernel_initializer=tf.keras.initializers.VarianceScaling(
                    scale=2, mode='fan_out', distribution='untruncated_normal'),
                bias_initializer=tf.zeros_initializer())
        if activation == 'relu':
            self._activation_op = tf.nn.relu
        elif activation == 'swish':
            self._activation_op = tf.nn.swish
        else:
            raise ValueError('Unsupported activation `{}`.'.format(activation))
        self._use_batch_norm = use_batch_norm
        self._norm_activation = norm_activation
        self._conv2d_ops = []
        for i in range(self._num_convs):
            self._conv2d_ops.append(
                self._conv2d_op(
                    self._num_filters,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    padding='same',
                    dilation_rate=(1, 1),
                    activation=(None
                                if self._use_batch_norm else self._activation_op),
                    name='mask-conv-l%d' % i))

        self._mask_conv_transpose = tf.keras.layers.Conv2DTranspose(
            self._num_filters,
            kernel_size=(2, 2),
            strides=(2, 2),
            padding='valid',
            activation=(None if self._use_batch_norm else self._activation_op),
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2, mode='fan_out', distribution='untruncated_normal'),
            bias_initializer=tf.zeros_initializer(),
            name='conv5-mask')

        # The norm ops and the logits conv are instantiated on the first call
        # (see `_build_call_layers`) and cached here so that later calls reuse
        # them instead of creating new layers with new weights.
        self._conv_bn_ops = []
        self._mask_bn_op = None
        self._mask_logits_conv = None
        self._call_layers_built = False

    def _build_call_layers(self):
        """Create, once, the layers that are instantiated on the first call.

        Creation is deferred to the first call, and performed in the order
        conv norm ops -> transposed-conv norm op -> logits conv, so that layers
        are instantiated in the same relative order as when they were created
        inline during the call.
        """
        if self._call_layers_built:
            return
        if self._use_batch_norm:
            for _ in range(self._num_convs):
                self._conv_bn_ops.append(self._norm_activation())
            self._mask_bn_op = self._norm_activation()
        self._mask_logits_conv = self._conv2d_op(
            self._num_classes,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding='valid',
            name='mask_fcn_logits')
        self._call_layers_built = True

    def __call__(self, roi_features, class_indices, is_training=None):
        """Mask branch for the Mask-RCNN model.

        Args:
            roi_features: A ROI feature tensor of shape [batch_size, num_rois,
                height_l, width_l, num_filters]. All dimensions except the
                first are read from the static shape.
            class_indices: a Tensor of shape [batch_size, num_rois], indicating which
                class the ROI is. Both dimensions are read from the static
                shape and are used to build the gather indices.
            is_training: `boolean`, if True if model is in training mode.

        Returns:
            mask_outputs: a tensor with a shape of
                [batch_size, num_masks, mask_height, mask_width], holding for
                every ROI the mask logits of the class selected by
                `class_indices`.
        """

        with tf.name_scope('mask_head'):
            self._build_call_layers()

            _, num_rois, height, width, filters = roi_features.get_shape().as_list()
            # Fold the ROI axis into the batch axis for the 2D convolutions.
            net = tf.reshape(roi_features, [-1, height, width, filters])

            for i in range(self._num_convs):
                net = self._conv2d_ops[i](net)
                if self._use_batch_norm:
                    net = self._conv_bn_ops[i](net, is_training=is_training)

            net = self._mask_conv_transpose(net)
            if self._use_batch_norm:
                net = self._mask_bn_op(net, is_training=is_training)

            mask_outputs = self._mask_logits_conv(net)
            # The reshape requires the upsampled spatial size to equal
            # `mask_target_size`.
            mask_outputs = tf.reshape(mask_outputs, [
                -1, num_rois, self._mask_target_size, self._mask_target_size,
                self._num_classes
            ])

            with tf.name_scope('masks_post_processing'):
                batch_size, num_masks = class_indices.get_shape().as_list()
                # Move the class axis in front of the spatial axes so that one
                # (batch, mask, class) index selects a whole mask.
                mask_outputs = tf.transpose(a=mask_outputs, perm=[0, 1, 4, 2, 3])
                # Constructs indices for gather.
                batch_indices = tf.tile(
                    tf.expand_dims(tf.range(batch_size), axis=1), [1, num_masks])
                mask_indices = tf.tile(
                    tf.expand_dims(tf.range(num_masks), axis=0), [batch_size, 1])
                gather_indices = tf.stack(
                    [batch_indices, mask_indices, class_indices], axis=2)
                mask_outputs = tf.gather_nd(mask_outputs, gather_indices)
        return mask_outputs


class YOLOv4:
    """YOLOv4 neck and head.

    Builds, from three backbone feature maps, an SPP block on the coarsest map,
    a top-down (upsample + concat) path, a bottom-up (strided conv + concat)
    path and one prediction convolution per scale.

    NOTE: almost all layers are created without an explicit name, so their
    auto-generated names depend on the order in which they are instantiated.
    """

    def DarknetConv2D_BN_Leaky(self, *args, **kwargs):
        """Darknet Convolution2D followed by SyncBatchNormalization and LeakyReLU.

        Args:
            *args: positional arguments forwarded to `nn_ops.DarknetConv2D`.
            **kwargs: keyword arguments forwarded to `nn_ops.DarknetConv2D`.
                `use_bias` defaults to False but can be overridden here.

        Returns:
            The result of `nn_ops.compose` over the three layers (presumably a
            callable applying them in the given order; `compose` is defined
            outside this file).
        """
        # The convolution bias is disabled by default; caller kwargs win.
        no_bias_kwargs = {'use_bias': False}
        no_bias_kwargs.update(kwargs)
        return nn_ops.compose(
            nn_ops.DarknetConv2D(*args, **no_bias_kwargs),
            tf.keras.layers.experimental.SyncBatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha=0.1))

    def Spp_Conv2D_BN_Leaky(self, x, num_filters):
        """Spatial pyramid pooling block followed by a 1x1 Conv2D_BN_Leaky.

        Applies three stride-1, 'same'-padded max poolings (5x5, 9x9, 13x13) to
        `x`, concatenates them with `x` itself and passes the result through a
        1x1 `DarknetConv2D_BN_Leaky` with `num_filters` as its first argument.
        """
        y1 = tf.keras.layers.MaxPooling2D(pool_size=(5,5), strides=(1,1), padding='same')(x)
        y2 = tf.keras.layers.MaxPooling2D(pool_size=(9,9), strides=(1,1), padding='same')(x)
        y3 = tf.keras.layers.MaxPooling2D(pool_size=(13,13), strides=(1,1), padding='same')(x)

        # Concatenation order: largest pooling window first, raw input last.
        y = nn_ops.compose(
                tf.keras.layers.Concatenate(),
                self.DarknetConv2D_BN_Leaky(num_filters, (1,1)))([y3, y2, y1, x])
        return y

    def make_yolo_head(self, x, num_filters):
        """Apply 5 Conv2D_BN_Leaky layers alternating 1x1 and 3x3 kernels.

        The 1x1 layers receive `num_filters` and the 3x3 layers
        `num_filters*2` as their first argument. No prediction convolution is
        added here.
        """
        x = nn_ops.compose(
                self.DarknetConv2D_BN_Leaky(num_filters, (1,1)),
                self.DarknetConv2D_BN_Leaky(num_filters*2, (3,3)),
                self.DarknetConv2D_BN_Leaky(num_filters, (1,1)),
                self.DarknetConv2D_BN_Leaky(num_filters*2, (3,3)),
                self.DarknetConv2D_BN_Leaky(num_filters, (1,1)))(x)

        return x

    def make_yolo_spp_head(self, x, num_filters):
        """Apply 3 Conv2D_BN_Leaky layers, the SPP block, then 2 more layers.

        Same alternating 1x1 / 3x3 layout as `make_yolo_head`, except that the
        third convolution is followed by `Spp_Conv2D_BN_Leaky`. No prediction
        convolution is added here.
        """
        x = nn_ops.compose(
                self.DarknetConv2D_BN_Leaky(num_filters, (1,1)),
                self.DarknetConv2D_BN_Leaky(num_filters*2, (3,3)),
                self.DarknetConv2D_BN_Leaky(num_filters, (1,1)))(x)

        x = self.Spp_Conv2D_BN_Leaky(x, num_filters)

        x = nn_ops.compose(
                self.DarknetConv2D_BN_Leaky(num_filters*2, (3,3)),
                self.DarknetConv2D_BN_Leaky(num_filters, (1,1)))(x)

        return x

    def __call__(self, feature_maps, feature_channel_nums, num_anchors, num_classes):
        """Build the YOLOv4 neck and the three prediction outputs.

        Args:
            feature_maps: sequence of three feature maps `(f1, f2, f3)`; `f1` is
                the coarsest one, each following map has twice the spatial
                resolution of the previous one (see the 2x upsampling below).
            feature_channel_nums: sequence of three `int`s, one per feature map,
                used to size the convolutions of the corresponding scale.
            num_anchors: `int`, factor of the prediction conv's first argument.
            num_classes: `int`, the prediction conv's first argument is
                `num_anchors * (num_classes + 5)`.

        Returns:
            A `(y1, y2, y3)` tuple with the prediction outputs for the scales
            of `f1`, `f2` and `f3` respectively.
        """
        f1, f2, f3 = feature_maps
        f1_channel_num, f2_channel_num, f3_channel_num = feature_channel_nums

        # feature map 1 head (19x19 for 608 input)
        x1 = self.make_yolo_spp_head(f1, f1_channel_num // 2)

        # upsample fpn merge for feature map 1 & 2
        x1_upsample = nn_ops.compose(
            self.DarknetConv2D_BN_Leaky(f2_channel_num // 2, (1, 1)),
            tf.keras.layers.UpSampling2D(2))(x1)

        x2 = self.DarknetConv2D_BN_Leaky(f2_channel_num // 2, (1, 1))(f2)
        x2 = tf.keras.layers.Concatenate()([x2, x1_upsample])

        # feature map 2 head (38x38 for 608 input)
        x2 = self.make_yolo_head(x2, f2_channel_num // 2)

        # upsample fpn merge for feature map 2 & 3
        x2_upsample = nn_ops.compose(
            self.DarknetConv2D_BN_Leaky(f3_channel_num // 2, (1, 1)),
            tf.keras.layers.UpSampling2D(2))(x2)

        x3 = self.DarknetConv2D_BN_Leaky(f3_channel_num // 2, (1, 1))(f3)
        x3 = tf.keras.layers.Concatenate()([x3, x2_upsample])

        # feature map 3 head & output (76x76 for 608 input):
        # 5-conv head, then a 3x3 conv and the 1x1 prediction conv.
        x3 = self.make_yolo_head(x3, f3_channel_num // 2)
        y3 = nn_ops.compose(
            self.DarknetConv2D_BN_Leaky(f3_channel_num, (3, 3)),
            nn_ops.DarknetConv2D(num_anchors * (num_classes + 5), (1, 1), name='predict_conv_3'))(x3)

        # downsample fpn merge for feature map 3 & 2
        # (top/left zero padding followed by a stride-2 3x3 conv)
        x3_downsample = nn_ops.compose(
            tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0))),
            self.DarknetConv2D_BN_Leaky(f2_channel_num // 2, (3, 3), strides=(2, 2)))(x3)

        x2 = tf.keras.layers.Concatenate()([x3_downsample, x2])

        # feature map 2 output (38x38 for 608 input):
        # 5-conv head, then a 3x3 conv and the 1x1 prediction conv.
        x2 = self.make_yolo_head(x2, f2_channel_num // 2)
        y2 = nn_ops.compose(
            self.DarknetConv2D_BN_Leaky(f2_channel_num, (3, 3)),
            nn_ops.DarknetConv2D(num_anchors * (num_classes + 5), (1, 1), name='predict_conv_2'))(x2)

        # downsample fpn merge for feature map 2 & 1
        # (top/left zero padding followed by a stride-2 3x3 conv)
        x2_downsample = nn_ops.compose(
            tf.keras.layers.ZeroPadding2D(((1, 0), (1, 0))),
            self.DarknetConv2D_BN_Leaky(f1_channel_num // 2, (3, 3), strides=(2, 2)))(x2)

        x1 = tf.keras.layers.Concatenate()([x2_downsample, x1])

        # feature map 1 output (19x19 for 608 input):
        # 5-conv head, then a 3x3 conv and the 1x1 prediction conv.
        x1 = self.make_yolo_head(x1, f1_channel_num // 2)
        y1 = nn_ops.compose(
            self.DarknetConv2D_BN_Leaky(f1_channel_num, (3, 3)),
            nn_ops.DarknetConv2D(num_anchors * (num_classes + 5), (1, 1), name='predict_conv_1'))(x1)

        return y1, y2, y3
