"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import tensorflow as tf

from examples.tensorflow.common.object_detection import base_model
from examples.tensorflow.common.object_detection.architecture import factory
from examples.tensorflow.common.object_detection.ops import postprocess_ops
from examples.tensorflow.common.object_detection.evaluation import coco_evaluator
from examples.tensorflow.common.logger import logger
from examples.tensorflow.common.object_detection import losses


class RetinanetModel(base_model.Model):
    """RetinaNet model function.

    Wires together the backbone, multilevel feature generator and RetinaNet
    head produced by `factory`, the classification/box losses, and the
    multilevel detection generator used for post-processing.
    """

    def __init__(self, params):
        """Creates the architecture generators, losses and input layer.

        Args:
            params: Configuration object. This constructor reads
                `params.model_params.loss_params`,
                `params.model_params.architecture` (`num_classes`,
                `min_level`, `max_level`),
                `params.model_params.postprocessing` and
                `params.input_info.sample_size`.
        """
        super().__init__(params)

        # For eval metrics.
        self._params = params

        # NOTE(review): presumably consumed by the checkpoint-restore helper
        # of the base class -- confirm against `base_model.Model`.
        self._checkpoint_prefix = 'resnet50/'

        # Architecture generators.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._head_fn = factory.retinanet_head_generator(params)

        # Loss function.
        self._cls_loss_fn = losses.RetinanetClassLoss(params.model_params.loss_params,
                                                      params.model_params.architecture.num_classes)
        self._box_loss_fn = losses.RetinanetBoxLoss(params.model_params.loss_params)
        self._box_loss_weight = params.model_params.loss_params.box_loss_weight

        # Predict function.
        self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
            params.model_params.architecture.min_level,
            params.model_params.architecture.max_level,
            params.model_params.postprocessing)

        # Input layer. The first entry of `sample_size` is skipped
        # (presumably the batch dimension -- TODO confirm).
        self._input_layer = tf.keras.layers.Input(
            shape=(params.input_info.sample_size[1:]),
            name='',
            dtype=tf.float32)

    def build_outputs(self, inputs, is_training):
        """Runs backbone -> multilevel features -> head on `inputs`.

        Args:
            inputs: Input passed to the backbone generator.
            is_training: Not used by this implementation; kept for interface
                compatibility.

        Returns:
            A dict with the head outputs under the keys `'cls_outputs'` and
            `'box_outputs'`.
        """
        backbone_features = self._backbone_fn(inputs)
        fpn_features = self._fpn_fn(backbone_features)
        cls_outputs, box_outputs = self._head_fn(fpn_features)

        model_outputs = {
            'cls_outputs': cls_outputs,
            'box_outputs': box_outputs,
        }

        return model_outputs

    def build_loss_fn(self, keras_model, compression_loss_fn):
        """Builds the total-loss callable for `keras_model`.

        Args:
            keras_model: Model whose `trainable_variables` (after filtering)
                are passed to the weight-decay loss.
            compression_loss_fn: Zero-argument callable returning the
                compression loss; it is invoked on every loss evaluation.

        Returns:
            A function `(labels, outputs) -> dict` that returns the keys
            `'total_loss'`, `'cls_loss'`, `'box_loss'`, `'model_loss'`,
            `'l2_regularization_loss'` and `'compression_loss'`.
        """
        # The variable list is filtered once here and captured by the closure.
        filter_fn = self.make_filter_trainable_variables_fn()
        trainable_variables = filter_fn(keras_model.trainable_variables)

        def _total_loss_fn(labels, outputs):
            cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
                                         labels['cls_targets'],
                                         labels['num_positives'])
            box_loss = self._box_loss_fn(outputs['box_outputs'],
                                         labels['box_targets'],
                                         labels['num_positives'])

            # Detection loss: classification plus weighted box regression.
            model_loss = cls_loss + self._box_loss_weight * box_loss
            l2_regularization_loss = self.weight_decay_loss(trainable_variables)
            compression_loss = compression_loss_fn()
            total_loss = model_loss + l2_regularization_loss + compression_loss

            return {
                'total_loss': total_loss,
                'cls_loss': cls_loss,
                'box_loss': box_loss,
                'model_loss': model_loss,
                'l2_regularization_loss': l2_regularization_loss,
                'compression_loss': compression_loss,
            }

        return _total_loss_fn

    def build_model(self, weights=None, is_training=None):
        """Builds the Keras model and optionally initializes its weights.

        Args:
            weights: Optional value passed to `keras_model.load_weights`;
                skipped when falsy.
            is_training: Forwarded to `self.model_outputs`.

        Returns:
            The `tf.keras.models.Model` named `'retinanet'`.
        """
        outputs = self.model_outputs(self._input_layer, is_training)
        keras_model = tf.keras.models.Model(inputs=self._input_layer, outputs=outputs, name='retinanet')

        if self._checkpoint_path:
            logger.info('Init backbone')
            init_checkpoint_fn = self.make_restore_checkpoint_fn()
            init_checkpoint_fn(keras_model)

        if weights:
            keras_model.load_weights(weights)
            # Logged only after the load succeeded, so the message is never
            # emitted for a failed load.
            logger.info('Loaded pretrained weights from {}'.format(weights))

        return keras_model

    def post_processing(self, labels, outputs):
        """Converts raw head outputs into detections.

        Args:
            labels: Mapping that provides `'anchor_boxes'`, `'image_info'`
                and `'source_id'`.
            outputs: Mapping that must contain `'cls_outputs'` and
                `'box_outputs'`.

        Returns:
            A tuple `(labels, outputs)` where `outputs` is a new dict with
            the keys `'source_id'`, `'image_info'`, `'num_detections'`,
            `'detection_boxes'`, `'detection_classes'` and
            `'detection_scores'`.

        Raises:
            ValueError: If a required field is missing in `outputs`.
        """
        required_output_fields = ['cls_outputs', 'box_outputs']

        for field in required_output_fields:
            if field not in outputs:
                raise ValueError('"{}" is missing in outputs, required {} found {}'.format(
                                 field, required_output_fields, list(outputs.keys())))

        boxes, scores, classes, valid_detections = self._generate_detections_fn(
            outputs['box_outputs'], outputs['cls_outputs'], labels['anchor_boxes'],
            labels['image_info'][:, 1:2, :])
        # Discards the old output tensors to save memory. The `cls_outputs` and
        # `box_outputs` are pretty big and could potentially lead to memory issue.
        outputs = {
            'source_id': labels['source_id'],
            'image_info': labels['image_info'],
            'num_detections': valid_detections,
            'detection_boxes': boxes,
            'detection_classes': classes,
            'detection_scores': scores,
        }

        return labels, outputs

    def eval_metrics(self):
        """Creates the COCO evaluation metric.

        The annotation file is read from the `'val_json_file'` entry of the
        params (`None` when absent).

        Returns:
            A `coco_evaluator.MetricWrapper` around a `COCOEvaluator` created
            with `include_mask=False`.
        """
        annotation_file = self._params.get('val_json_file', None)
        evaluator = coco_evaluator.COCOEvaluator(annotation_file=annotation_file,
                                                 include_mask=False)
        return coco_evaluator.MetricWrapper(evaluator)
