"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import tensorflow as tf

from examples.tensorflow.common.logger import logger
from examples.tensorflow.common.object_detection import base_model
from examples.tensorflow.common.object_detection import losses
from examples.tensorflow.common.object_detection.architecture import factory
from examples.tensorflow.common.object_detection.evaluation import coco_evaluator
from examples.tensorflow.common.object_detection.ops import postprocess_ops
from examples.tensorflow.common.object_detection.ops import roi_ops
from examples.tensorflow.common.object_detection.ops import spatial_transform_ops
from examples.tensorflow.common.object_detection.ops import target_ops
from examples.tensorflow.common.object_detection.utils import anchor
from examples.tensorflow.common.object_detection.utils import box_utils


def _restore_baseline_weights(keras_model, checkpoint_path):
    """Initialize every variable of `keras_model` from a checkpoint.

    Only checkpoint keys that contain 'variables' and do not contain
    'optimizer' / 'OPTIMIZER' are considered. For each such key the second
    '/'-separated component is taken, '.S' is replaced with '/', and the
    result is compared with the `name` of each model variable. Every model
    variable must be matched by exactly one checkpoint key.

    Args:
        keras_model: Object exposing a `variables` iterable whose items have
            a `name` attribute.
        checkpoint_path: Path passed to `tf.train.load_checkpoint` and
            `tf.compat.v1.train.init_from_checkpoint`.

    Raises:
        ValueError: If a model variable has no matching checkpoint key, or
            has more than one.
    """
    reader = tf.train.load_checkpoint(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # Index the checkpoint keys once by decoded variable name instead of
    # re-splitting every key for every model variable.
    keys_by_var_name = {}
    for ckpt_key in var_to_shape_map:
        # Skip optimizer variables and anything outside the 'variables' group.
        if 'optimizer' in ckpt_key or 'OPTIMIZER' in ckpt_key or 'variables' not in ckpt_key:
            continue

        # Expected layout: variables/*/.ATTRIBUTES/VARIABLE_VALUE
        parts = ckpt_key.split('/')
        if len(parts) < 2:
            # No second component to decode, so it cannot match a variable.
            continue

        # NOTE(review): '.S' presumably is the checkpoint's escaped form of
        # '/' in the stored variable name - confirm against the saver.
        decoded_name = parts[1].replace('.S', '/')
        keys_by_var_name.setdefault(decoded_name, []).append(ckpt_key)

    assignment_map = {}
    for v in keras_model.variables:
        match_names = keys_by_var_name.get(v.name, [])

        if not match_names:
            raise ValueError('No checkpoint entry matches {} in {}'.format(v, checkpoint_path))
        if len(match_names) > 1:
            raise ValueError('More than one matches for {}: {}'.format(v, match_names))

        assignment_map[match_names[0]] = v

    tf.compat.v1.train.init_from_checkpoint(checkpoint_path, assignment_map)


class MaskrcnnModel(base_model.Model):
    """Mask R-CNN model function.

    Wires together the backbone, FPN, RPN head, ROI generation and sampling,
    the Fast R-CNN head and, when `include_mask` is enabled in `params`, the
    Mask R-CNN head. Also provides builders for the Keras model, its loss
    function, prediction post-processing and evaluation metrics.
    """

    def __init__(self, params):
        """Create the architecture generators, samplers and loss functions.

        Args:
            params: Configuration object supporting both `.get(key, default)`
                and attribute access (e.g. `params.roi_proposal`).
        """
        super().__init__(params)

        # Kept for later use by build_outputs, build_model and eval_metrics.
        self._params = params
        self._checkpoint_prefix = 'resnet50/'

        # The mask head, mask loss and mask inputs/outputs are only used when
        # this flag is set.
        self._include_mask = params.get('include_mask', False)

        # Architecture generators.
        self._backbone_fn = factory.backbone_generator(params)
        self._fpn_fn = factory.multilevel_features_generator(params)
        self._rpn_head_fn = factory.rpn_head_generator(params)

        self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
        self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
        self._sample_masks_fn = target_ops.MaskSampler(
            params.architecture.mask_target_size,
            params.mask_sampling.num_mask_samples_per_image)

        self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)

        if self._include_mask:
            self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)

        # Loss function.
        self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
        self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
        self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
        self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
        if self._include_mask:
            self._mask_loss_fn = losses.MaskrcnnLoss()

        self._generate_detections_fn = postprocess_ops.GenericDetectionGenerator(params.postprocess)

    def build_outputs(self, inputs, is_training):
        """Build the forward pass and collect its outputs.

        Args:
            inputs: Dict with 'image' and 'image_info'. When `is_training` is
                true, 'gt_boxes' and 'gt_classes' are also read, plus
                'gt_masks' if masks are enabled.
            is_training: Whether to build the training variant (ROI sampling,
                training targets, raw mask outputs).

        Returns:
            Dict of outputs. Always contains 'rpn_score_outputs',
            'rpn_box_outputs', 'class_outputs', 'box_outputs',
            'num_detections', 'detection_boxes', 'detection_classes' and
            'detection_scores'. In training it additionally contains
            'class_targets' and 'box_targets', and with masks enabled
            'mask_targets', 'sampled_class_targets' and 'mask_outputs'.
            In non-training mode with masks enabled it contains
            'detection_masks' instead.
        """

        def _as_float32(nested):
            # Cast every tensor of a (possibly nested) structure to float32.
            return tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), nested)

        model_outputs = {}

        image = inputs['image']
        _, image_height, image_width, _ = image.get_shape().as_list()
        backbone_features = self._backbone_fn(image, is_training)
        fpn_features = self._fpn_fn(backbone_features, is_training)

        rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(fpn_features, is_training)

        model_outputs.update({
            'rpn_score_outputs': _as_float32(rpn_score_outputs),
            'rpn_box_outputs': _as_float32(rpn_box_outputs),
        })

        input_anchor = anchor.Anchor(
            self._params.model_params.architecture.min_level,
            self._params.model_params.architecture.max_level,
            self._params.anchor.num_scales,
            self._params.anchor.aspect_ratios,
            self._params.anchor.anchor_size,
            (image_height, image_width))

        # NOTE(review): row 1 of 'image_info' is presumably the resized image
        # shape - confirm against the input parser.
        rpn_rois, _ = self._generate_rois_fn(
            rpn_box_outputs, rpn_score_outputs,
            input_anchor.multilevel_boxes,
            inputs['image_info'][:, 1, :],
            is_training)

        if is_training:
            rpn_rois = tf.stop_gradient(rpn_rois)

            # Sample proposals.
            rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
                self._sample_rois_fn(rpn_rois, inputs['gt_boxes'],
                                     inputs['gt_classes']))

            # Create bounding box training targets.
            box_targets = box_utils.encode_boxes(matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])

            # If the target is background, the box target is set to all 0s.
            is_background = tf.tile(
                tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
                [1, 1, 4])
            box_targets = tf.where(is_background, tf.zeros_like(box_targets), box_targets)

            model_outputs.update({
                'class_targets': matched_gt_classes,
                'box_targets': box_targets,
            })

        roi_features = spatial_transform_ops.multilevel_crop_and_resize(
            fpn_features, rpn_rois, output_size=7)

        class_outputs, box_outputs = self._frcnn_head_fn(roi_features, is_training)

        model_outputs.update({
            'class_outputs': _as_float32(class_outputs),
            'box_outputs': _as_float32(box_outputs),
        })

        # Add this output to train to make the checkpoint loadable in predict mode.
        # If we skip it in train mode, the heads will be out-of-order and checkpoint
        # loading will fail.
        boxes, scores, classes, valid_detections = self._generate_detections_fn(
            box_outputs, class_outputs, rpn_rois, inputs['image_info'][:, 1:2, :])

        model_outputs.update({
            'num_detections': valid_detections,
            'detection_boxes': boxes,
            'detection_classes': classes,
            'detection_scores': scores,
        })

        if not self._include_mask:
            return model_outputs

        if is_training:
            # Train the mask head on sampled ROIs and their mask targets.
            rpn_rois, classes, mask_targets = self._sample_masks_fn(
                rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
                inputs['gt_masks'])
            mask_targets = tf.stop_gradient(mask_targets)

            classes = tf.cast(classes, tf.int32)

            model_outputs.update({
                'mask_targets': mask_targets,
                'sampled_class_targets': classes,
            })
        else:
            # Outside training the mask head runs on the detected boxes.
            rpn_rois = boxes
            classes = tf.cast(classes, tf.int32)

        mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
            fpn_features, rpn_rois, output_size=14)

        mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)

        if is_training:
            model_outputs.update({'mask_outputs': _as_float32(mask_outputs)})
        else:
            model_outputs.update({'detection_masks': tf.nn.sigmoid(mask_outputs)})

        return model_outputs

    def build_loss_fn(self, keras_model, compression_loss_fn):
        """Create the loss function used for training.

        Args:
            keras_model: Model whose `trainable_variables` are filtered (via
                `self.make_filter_trainable_variables_fn()`) and used for
                the weight-decay term. The variable list is captured once,
                when this method is called.
            compression_loss_fn: Zero-argument callable returning the
                compression loss; it is invoked on every loss evaluation.

        Returns:
            A function `(labels, outputs) -> dict` with the total loss
            (under both 'total_loss' and 'loss') and each component.
        """
        filter_fn = self.make_filter_trainable_variables_fn()
        trainable_variables = filter_fn(keras_model.trainable_variables)

        def _total_loss_fn(labels, outputs):
            rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
                                                     labels['rpn_score_targets'])

            rpn_box_loss = self._rpn_box_loss_fn(outputs['rpn_box_outputs'],
                                                 labels['rpn_box_targets'])

            # The Fast R-CNN and mask targets come from `outputs` because
            # build_outputs produces them during ROI sampling.
            frcnn_class_loss = self._frcnn_class_loss_fn(outputs['class_outputs'],
                                                         outputs['class_targets'])

            frcnn_box_loss = self._frcnn_box_loss_fn(outputs['box_outputs'],
                                                     outputs['class_targets'],
                                                     outputs['box_targets'])

            if self._include_mask:
                mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
                                               outputs['mask_targets'],
                                               outputs['sampled_class_targets'])
            else:
                mask_loss = 0.0

            model_loss = (
                rpn_score_loss + rpn_box_loss + frcnn_class_loss + frcnn_box_loss +
                mask_loss)

            l2_regularization_loss = self.weight_decay_loss(trainable_variables)
            compression_loss = compression_loss_fn()
            total_loss = model_loss + l2_regularization_loss + compression_loss
            return {
                'total_loss': total_loss,
                'loss': total_loss,
                'fast_rcnn_class_loss': frcnn_class_loss,
                'fast_rcnn_box_loss': frcnn_box_loss,
                'mask_loss': mask_loss,
                'model_loss': model_loss,
                'l2_regularization_loss': l2_regularization_loss,
                'compression_loss': compression_loss,
                'rpn_score_loss': rpn_score_loss,
                'rpn_box_loss': rpn_box_loss,
            }

        return _total_loss_fn

    def build_input_layers(self, params, is_training):
        """Create the Keras `Input` layers of the model.

        Args:
            params: Configuration object; `input_info.sample_size`,
                `batch_size`, the optional 'model_batch_size' and, for
                training, `maskrcnn_parser` are read from it.
            is_training: When true, ground-truth inputs are added.

        Returns:
            Dict with 'image' and 'image_info' inputs. In training it also
            has 'gt_boxes' and 'gt_classes', plus 'gt_masks' when masks are
            enabled.
        """
        input_shape = params.input_info.sample_size[1:]

        # A truthy 'model_batch_size' overrides the regular batch size.
        model_batch_size = params.get('model_batch_size', None)
        batch_size = model_batch_size if model_batch_size else params.batch_size

        # Inputs shared by training and non-training graphs.
        input_layer = {
            'image':
                tf.keras.layers.Input(
                    shape=input_shape,
                    batch_size=batch_size,
                    name='image',
                    dtype=tf.float32),
            'image_info':
                tf.keras.layers.Input(
                    shape=[4, 2],
                    batch_size=batch_size,
                    name='image_info',
                ),
        }

        if not is_training:
            return input_layer

        # Ground-truth inputs are only needed to build training targets.
        max_num_instances = params.maskrcnn_parser.max_num_instances
        input_layer['gt_boxes'] = tf.keras.layers.Input(
            shape=[max_num_instances, 4],
            batch_size=batch_size,
            name='gt_boxes')
        input_layer['gt_classes'] = tf.keras.layers.Input(
            shape=[max_num_instances],
            batch_size=batch_size,
            name='gt_classes',
            dtype=tf.int64)

        if self._include_mask:
            mask_crop_size = params.maskrcnn_parser.mask_crop_size
            input_layer['gt_masks'] = tf.keras.layers.Input(
                shape=[max_num_instances, mask_crop_size, mask_crop_size],
                batch_size=batch_size,
                name='gt_masks')

        return input_layer

    def build_model(self, weights=None, is_training=None):
        """Build the Keras model and optionally initialize its weights.

        Args:
            weights: Optional checkpoint path; when truthy, all model
                variables are restored from it via
                `_restore_baseline_weights`.
            is_training: Forwarded to `build_input_layers` and
                `model_outputs`.

        Returns:
            The constructed `tf.keras.models.Model` named 'maskrcnn'.
        """
        input_layers = self.build_input_layers(self._params, is_training)
        outputs = self.model_outputs(input_layers, is_training)
        keras_model = tf.keras.models.Model(inputs=input_layers, outputs=outputs, name='maskrcnn')

        if self._checkpoint_path:
            logger.info('Init backbone')
            init_checkpoint_fn = self.make_restore_checkpoint_fn()
            init_checkpoint_fn(keras_model)

        if weights:
            _restore_baseline_weights(keras_model, weights)
            # Log only after the restore succeeded, so the message never
            # claims weights were loaded when the restore raised.
            logger.info('Loaded pretrained weights from {}'.format(weights))

        return keras_model

    def post_processing(self, labels, outputs):
        """Assemble the prediction dict consumed by evaluation.

        Args:
            labels: Dict with 'image_info' and, optionally, 'groundtruths'.
            outputs: Model outputs; must contain 'class_outputs' and
                'box_outputs', and the detection fields are read from it.

        Returns:
            Tuple `(labels, predictions)`. `labels` is returned unchanged.

        Raises:
            ValueError: If 'class_outputs' or 'box_outputs' is missing from
                `outputs`.
        """
        required_output_fields = ['class_outputs', 'box_outputs']
        for field in required_output_fields:
            if field not in outputs:
                raise ValueError('"%s" is missing in outputs, required %s found %s' %
                    (field, required_output_fields, outputs.keys()))

        predictions = {
            'image_info': labels['image_info'],
            'num_detections': outputs['num_detections'],
            'detection_boxes': outputs['detection_boxes'],
            'detection_classes': outputs['detection_classes'],
            'detection_scores': outputs['detection_scores'],
        }

        if self._include_mask:
            predictions.update({
                'detection_masks': outputs['detection_masks'],
            })

        if 'groundtruths' in labels:
            groundtruths = labels['groundtruths']
            predictions['source_id'] = groundtruths['source_id']
            predictions['gt_source_id'] = groundtruths['source_id']
            predictions['gt_height'] = groundtruths['height']
            predictions['gt_width'] = groundtruths['width']
            predictions['gt_image_info'] = labels['image_info']
            predictions['gt_num_detections'] = groundtruths['num_detections']
            predictions['gt_boxes'] = groundtruths['boxes']
            predictions['gt_classes'] = groundtruths['classes']
            predictions['gt_areas'] = groundtruths['areas']
            predictions['gt_is_crowds'] = groundtruths['is_crowds']

        return labels, predictions

    def eval_metrics(self):
        """Create the COCO evaluation metric for this model.

        The annotation file is taken from the optional 'val_json_file'
        parameter.

        Returns:
            A `coco_evaluator.MetricWrapper` around a `COCOEvaluator`.
        """
        annotation_file = self._params.get('val_json_file', None)
        # Follow the model's own mask setting: post_processing only emits
        # 'detection_masks' when masks are enabled, so the evaluator must not
        # be told to evaluate masks otherwise.
        evaluator = coco_evaluator.COCOEvaluator(
            annotation_file=annotation_file, include_mask=self._include_mask)
        return coco_evaluator.MetricWrapper(evaluator)
