"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from examples.tensorflow.common.object_detection.architecture import fpn
from examples.tensorflow.common.object_detection.architecture import heads
from examples.tensorflow.common.object_detection.architecture import nn_ops
from examples.tensorflow.common.object_detection.architecture import resnet
from examples.tensorflow.common.object_detection.architecture import darknet


def norm_activation_generator(params):
    """Create a normalization + activation builder from a config section.

    Args:
      params: config section exposing `batch_norm_momentum`,
        `batch_norm_epsilon` and `activation`.

    Returns:
      The result of `nn_ops.norm_activation_builder` for these settings,
      always requested with `trainable=True`.
    """
    builder_kwargs = {
        'momentum': params.batch_norm_momentum,
        'epsilon': params.batch_norm_epsilon,
        'trainable': True,
        'activation': params.activation,
    }
    return nn_ops.norm_activation_builder(**builder_kwargs)


def backbone_generator(params):
    """Build the backbone network for the configured detection model.

    Supported (model, backbone) combinations:
      * 'RetinaNet' or 'MaskRCNN' with the 'resnet' backbone;
      * 'YOLOv4' with the 'darknet' backbone.

    Args:
      params: experiment config. Reads `params.model` and
        `params.model_params.architecture.backbone`; the ResNet branch also
        reads `params.model_params.norm_activation`.

    Returns:
      A `resnet.Resnet` or `darknet.CSPDarknet53` instance.

    Raises:
      ValueError: if the model is unknown or the backbone is not supported
        for that model.
    """
    backbone_name = params.model_params.architecture.backbone.name
    model_name = params.model

    if model_name in ('RetinaNet', 'MaskRCNN'):
        if backbone_name != 'resnet':
            raise ValueError('Backbone {} is not supported for {} model.'.format(backbone_name, model_name))
        backbone_params = params.model_params.architecture.backbone.params
        norm_activation_params = params.model_params.norm_activation
        return resnet.Resnet(resnet_depth=backbone_params.depth,
                             activation=norm_activation_params.activation,
                             norm_activation=norm_activation_generator(norm_activation_params))

    if model_name == 'YOLOv4':
        if backbone_name != 'darknet':
            raise ValueError('Backbone {} is not supported for {} model.'.format(backbone_name, model_name))
        return darknet.CSPDarknet53()

    raise ValueError('Model {} is not supported.'.format(model_name))


def multilevel_features_generator(params):
    """Generator function for various FPN models.

    Args:
      params: experiment config. Reads `params.model_params.architecture`
        (`multilevel_features`, `min_level`, `max_level`, `fpn_params`) and
        `params.model_params.norm_activation`.

    Returns:
      An `fpn.Fpn` instance.

    Raises:
      ValueError: if `multilevel_features` is anything other than 'fpn'.
    """
    architecture = params.model_params.architecture
    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would silently build an FPN for an unsupported config.
    if architecture.multilevel_features != 'fpn':
        raise ValueError('Multilevel features {} is not supported.'.format(architecture.multilevel_features))

    fpn_params = architecture.fpn_params
    norm_activation_params = params.model_params.norm_activation
    return fpn.Fpn(
        min_level=architecture.min_level,
        max_level=architecture.max_level,
        fpn_feat_dims=fpn_params.fpn_feat_dims,
        use_separable_conv=fpn_params.use_separable_conv,
        activation=norm_activation_params.activation,
        use_batch_norm=fpn_params.use_batch_norm,
        norm_activation=norm_activation_generator(norm_activation_params))


def retinanet_head_generator(params):
    """Build the RetinaNet classification/box head.

    The number of anchors per location is `num_scales` times the number of
    aspect ratios in `params.model_params.anchor`.

    Args:
      params: experiment config. Reads `params.model_params.architecture`
        (levels, `num_classes`, `head_params`), `params.model_params.anchor`
        and `params.model_params.norm_activation`.

    Returns:
      A `heads.RetinanetHead` instance.
    """
    model_cfg = params.model_params
    arch_cfg = model_cfg.architecture
    head_cfg = arch_cfg.head_params
    anchor_cfg = model_cfg.anchor
    num_anchors = anchor_cfg.num_scales * len(anchor_cfg.aspect_ratios)
    return heads.RetinanetHead(
        arch_cfg.min_level,
        arch_cfg.max_level,
        arch_cfg.num_classes,
        num_anchors,
        head_cfg.num_convs,
        head_cfg.num_filters,
        head_cfg.use_separable_conv,
        norm_activation=norm_activation_generator(model_cfg.norm_activation))


def rpn_head_generator(params):
    """Generator function for RPN head architecture.

    Args:
      params: experiment config. Head hyper-parameters are read from
        `params.rpn_head`. The anchor config comes from `params.anchor` when
        present, otherwise from `params.model_params.anchor`. Levels and
        normalization/activation settings come from `params.model_params`.

    Returns:
      A `heads.RpnHead` instance.
    """
    # NOTE(review): head settings live at the top level of `params` while
    # architecture settings live under `params.model_params` -- confirm this
    # matches the config schema.
    head_params = params.rpn_head
    model_params = params.model_params
    # `retinanet_head_generator` reads the anchor config from
    # `params.model_params.anchor`, but this generator historically read the
    # top-level `params.anchor`. Honour the legacy location when it exists, so
    # existing configs behave exactly as before, and fall back otherwise.
    anchor = getattr(params, 'anchor', None)
    if anchor is None:
        anchor = model_params.anchor
    anchors_per_location = anchor.num_scales * len(anchor.aspect_ratios)
    return heads.RpnHead(
        model_params.architecture.min_level,
        model_params.architecture.max_level,
        anchors_per_location,
        head_params.num_convs,
        head_params.num_filters,
        head_params.use_separable_conv,
        model_params.norm_activation.activation,
        head_params.use_batch_norm,
        norm_activation=norm_activation_generator(model_params.norm_activation))


def fast_rcnn_head_generator(params):
    """Build the Fast R-CNN box head.

    Args:
      params: experiment config. Head hyper-parameters come from
        `params.frcnn_head`; the class count and normalization/activation
        settings come from `params.model_params`.

    Returns:
      A `heads.FastrcnnHead` instance.
    """
    # NOTE(review): head settings are read from the top level of `params`
    # while the rest comes from `params.model_params` -- confirm against the
    # config schema.
    frcnn_cfg = params.frcnn_head
    model_cfg = params.model_params
    norm_act_cfg = model_cfg.norm_activation
    return heads.FastrcnnHead(
        model_cfg.architecture.num_classes,
        frcnn_cfg.num_convs,
        frcnn_cfg.num_filters,
        frcnn_cfg.use_separable_conv,
        frcnn_cfg.num_fcs,
        frcnn_cfg.fc_dims,
        norm_act_cfg.activation,
        frcnn_cfg.use_batch_norm,
        norm_activation=norm_activation_generator(norm_act_cfg))


def mask_rcnn_head_generator(params):
    """Generator function for Mask R-CNN head architecture.

    Args:
      params: experiment config. Head hyper-parameters are read from
        `params.mrcnn_head`. The class count and normalization/activation
        settings are read from `params.model_params`. `mask_target_size` is
        read from `params.architecture` when that section provides it,
        otherwise from `params.model_params.architecture`.

    Returns:
      A `heads.MaskrcnnHead` instance.
    """
    head_params = params.mrcnn_head
    architecture = params.model_params.architecture
    # `num_classes` (and every other architecture field in this module) comes
    # from `params.model_params.architecture`, but the mask target size used to
    # be read from a top-level `params.architecture` section. Keep using that
    # legacy location when it supplies the value (unchanged behaviour for
    # existing configs) and otherwise read it next to `num_classes`.
    legacy_architecture = getattr(params, 'architecture', None)
    if legacy_architecture is not None and hasattr(legacy_architecture, 'mask_target_size'):
        mask_target_size = legacy_architecture.mask_target_size
    else:
        mask_target_size = architecture.mask_target_size
    return heads.MaskrcnnHead(
        architecture.num_classes,
        mask_target_size,
        head_params.num_convs,
        head_params.num_filters,
        head_params.use_separable_conv,
        params.model_params.norm_activation.activation,
        head_params.use_batch_norm,
        norm_activation=norm_activation_generator(params.model_params.norm_activation))


def yolo_v4_head_generator():
    """Build the YOLOv4 neck and head.

    Takes no configuration: `heads.YOLOv4` is constructed without arguments.

    Returns:
      A new `heads.YOLOv4` instance.
    """
    yolo_head = heads.YOLOv4()
    return yolo_head
