"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import functools
import tensorflow as tf

from examples.tensorflow.common.object_detection.ops import nms
from examples.tensorflow.common.object_detection.utils import box_utils


def generate_detections_factory(params):
    """Build the detection-generation function selected by `params`.

    Args:
      params: configuration object exposing `use_batched_nms`,
        `max_total_size`, `nms_iou_threshold`, `score_threshold` and, when
        batched NMS is disabled, `pre_nms_num_boxes`.

    Returns:
      A callable taking `(boxes, scores)` with the NMS settings from `params`
      already bound as keyword arguments. It is `_generate_detections_batched`
      when `params.use_batched_nms` is truthy, otherwise `_generate_detections`.
    """
    use_batched = params.use_batched_nms
    nms_kwargs = {
        'max_total_size': params.max_total_size,
        'nms_iou_threshold': params.nms_iou_threshold,
        'score_threshold': params.score_threshold,
    }
    if use_batched:
        return functools.partial(_generate_detections_batched, **nms_kwargs)
    # Only the per-class (non-batched) path takes a pre-NMS candidate budget.
    nms_kwargs['pre_nms_num_boxes'] = params.pre_nms_num_boxes
    return functools.partial(_generate_detections, **nms_kwargs)


def _select_top_k_scores(scores_in, pre_nms_num_detections):
    """Select the top-k scores and their anchor indices independently per class.

    Args:
      scores_in: a Tensor with shape [batch_size, N, num_classes], which stacks
        class logit outputs on all feature levels. The N is the number of total
        anchors on all levels. The num_classes is the number of classes predicted
        by the model. N and num_classes must be statically known; the batch
        size may be unknown (None).
      pre_nms_num_detections: Number of candidates kept per class before NMS.
        `tf.nn.top_k` requires this to be at most N.

    Returns:
      A tuple (top_k_scores, top_k_indices), each of shape
      [batch_size, pre_nms_num_detections, num_classes]. Along axis 1 the scores
      are sorted in descending order, and the indices point into axis 1
      (anchors) of `scores_in`.
    """
    _, num_anchors, num_class = scores_in.get_shape().as_list()
    # Put classes next to batch so top_k runs over the anchor axis: [B * C, N].
    scores_trans = tf.transpose(scores_in, perm=[0, 2, 1])
    scores_trans = tf.reshape(scores_trans, [-1, num_anchors])

    top_k_scores, top_k_indices = tf.nn.top_k(
        scores_trans, k=pre_nms_num_detections, sorted=True)

    # Use -1 for the batch axis so the reshape also works when the static
    # batch size is unknown (get_shape() reports None for it).
    top_k_scores = tf.reshape(top_k_scores,
                              [-1, num_class, pre_nms_num_detections])
    top_k_indices = tf.reshape(top_k_indices,
                               [-1, num_class, pre_nms_num_detections])

    return tf.transpose(top_k_scores, [0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1])


def _generate_detections(boxes,
                         scores,
                         max_total_size=100,
                         nms_iou_threshold=0.3,
                         score_threshold=0.05,
                         pre_nms_num_boxes=5000):
    """Generate the final detections given the model outputs.

    Classes are unrolled with a Python loop and each class is passed through
    the padded, sorted NMS from the `nms` module. Every op works on the whole
    batch, so the computation can be parallelized along the batch dimension.

    Args:
      boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
        N, 1, 4], which box predictions on all feature levels. The N is the number
        of total anchors on all levels.
      scores: a tensor with shape [batch_size, N, num_classes], which stacks class
        probability on all feature levels. The N is the number of total anchors on
        all levels. The num_classes is the number of classes predicted by the
        model. Note that the class_outputs here is the raw score.
      max_total_size: a scalar representing maximum number of boxes retained over
        all classes.
      nms_iou_threshold: a float representing the threshold for deciding whether
        boxes overlap too much with respect to IOU.
      score_threshold: a float representing the threshold for deciding when to
        remove boxes based on score.
      pre_nms_num_boxes: an int number of top candidate detections per class
        before NMS.

    Returns:
      nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
        representing top detected boxes in [y1, x1, y2, x2].
      nms_scores: `float` Tensor of shape [batch_size, max_total_size]
        representing sorted confidence scores for detected boxes.
      nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
        classes for detected boxes.
      valid_detections: `int` Tensor of shape [batch_size] holding, per image,
        the number of entries of `nms_scores` that are greater than -1.
    """
    with tf.name_scope('generate_detections'):
        nmsed_boxes = []
        nmsed_classes = []
        nmsed_scores = []
        batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
        if batch_size is None:
            # The static batch size is unknown; fall back to the dynamic value
            # so that tf.fill below still receives a valid shape.
            batch_size = tf.shape(boxes)[0]
        _, total_anchors, num_classes = scores.get_shape().as_list()
        # Selects top pre_nms_num scores and indices before NMS.
        scores, indices = _select_top_k_scores(scores, min(total_anchors, pre_nms_num_boxes))

        for i in range(num_classes):
            # Class-agnostic boxes ([..., 1, 4]) reuse box column 0 for every class.
            boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :]
            scores_i = scores[:, :, i]
            # Gather the boxes belonging to this class's pre-NMS top-k candidates.
            boxes_i = tf.gather(boxes_i, indices[:, :, i], batch_dims=1, axis=1)

            # Filter out scores.
            boxes_i, scores_i = box_utils.filter_boxes_by_scores(
                boxes_i, scores_i, min_score_threshold=score_threshold)

            (nmsed_scores_i, nmsed_boxes_i) = nms.sorted_non_max_suppression_padded(
                tf.cast(scores_i, tf.float32),
                tf.cast(boxes_i, tf.float32),
                max_total_size,
                iou_threshold=nms_iou_threshold)
            nmsed_classes_i = tf.fill([batch_size, max_total_size], i)
            nmsed_boxes.append(nmsed_boxes_i)
            nmsed_scores.append(nmsed_scores_i)
            nmsed_classes.append(nmsed_classes_i)

    # Merge the per-class results and keep the overall top `max_total_size`.
    nmsed_boxes = tf.concat(nmsed_boxes, axis=1)
    nmsed_scores = tf.concat(nmsed_scores, axis=1)
    nmsed_classes = tf.concat(nmsed_classes, axis=1)
    nmsed_scores, indices = tf.nn.top_k(nmsed_scores, k=max_total_size, sorted=True)
    nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
    nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1, axis=1)
    # NOTE(review): this counts entries with score > -1. Whether that equals the
    # number of real detections depends on which score
    # nms.sorted_non_max_suppression_padded writes into padded slots — confirm.
    valid_detections = tf.reduce_sum(input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32),
                                     axis=1)
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections


def _generate_detections_per_image(boxes,
                                   scores,
                                   max_total_size=100,
                                   nms_iou_threshold=0.3,
                                   score_threshold=0.05,
                                   pre_nms_num_boxes=5000):
    """Run per-class NMS for a single image and merge the classes' results.

    Args:
      boxes: a tensor with shape [N, num_classes, 4] or [N, 1, 4] holding box
        predictions on all feature levels, where N is the total number of
        anchors. With a single box column the boxes are shared by all classes.
      scores: a tensor with shape [N, num_classes] stacking raw class scores on
        all feature levels.
      max_total_size: maximum number of boxes kept per class by NMS, and in
        total after merging all classes.
      nms_iou_threshold: IOU above which two boxes are considered overlapping.
      score_threshold: boxes scoring at or below this value are discarded by
        NMS.
      pre_nms_num_boxes: number of highest-scoring candidates per class fed to
        NMS.

    Returns:
      nms_boxes: `float` Tensor of shape [max_total_size, 4] representing top
        detected boxes in [y1, x1, y2, x2].
      nms_scores: `float` Tensor of shape [max_total_size] with the merged scores
        sorted in descending order; padding entries hold -1.
      nms_classes: `int` Tensor of shape [max_total_size] with the class index
        of each detection.
      valid_detections: scalar `int` Tensor counting the entries of
        `nms_scores` greater than -1.
    """
    box_column_count = boxes.get_shape().as_list()[1]
    class_count = scores.get_shape().as_list()[1]

    per_class_boxes, per_class_scores, per_class_labels = [], [], []
    for class_id in range(class_count):
        class_boxes = boxes[:, min(box_column_count - 1, class_id)]
        class_scores = scores[:, class_id]

        # Keep only the best `pre_nms_num_boxes` candidates before running NMS.
        candidate_count = tf.minimum(tf.shape(input=class_scores)[-1], pre_nms_num_boxes)
        class_scores, top_indices = tf.nn.top_k(class_scores, k=candidate_count)
        class_boxes = tf.gather(class_boxes, top_indices, axis=None)

        keep_indices, num_kept = tf.image.non_max_suppression_padded(
            tf.cast(class_boxes, tf.float32),
            tf.cast(class_scores, tf.float32),
            max_total_size,
            iou_threshold=nms_iou_threshold,
            score_threshold=score_threshold,
            pad_to_max_output_size=True,
            name='nms_detections_' + str(class_id))
        kept_boxes = tf.gather(class_boxes, keep_indices, axis=None)
        kept_scores = tf.gather(class_scores, keep_indices, axis=None)

        # Slots beyond the number of boxes NMS actually kept are padding: mark
        # them with a score of -1.
        is_real = tf.less(tf.range(max_total_size), [num_kept])
        kept_scores = tf.where(is_real, kept_scores, -1 * tf.ones_like(kept_scores))

        per_class_labels_i = tf.fill([max_total_size], class_id)
        per_class_boxes.append(kept_boxes)
        per_class_scores.append(kept_scores)
        per_class_labels.append(per_class_labels_i)

    # Merge every class's detections and keep the overall top `max_total_size`.
    merged_boxes = tf.concat(per_class_boxes, 0)
    merged_scores = tf.concat(per_class_scores, 0)
    merged_labels = tf.concat(per_class_labels, 0)
    top_scores, order = tf.nn.top_k(merged_scores, k=max_total_size, sorted=True)
    top_boxes = tf.gather(merged_boxes, order, axis=None)
    top_labels = tf.gather(merged_labels, order, axis=None)
    num_valid = tf.reduce_sum(input_tensor=tf.cast(tf.greater(top_scores, -1), tf.int32))

    return top_boxes, top_scores, top_labels, num_valid


def _generate_detections_batched(boxes, scores, max_total_size,
                                 nms_iou_threshold, score_threshold):
    """Generates detected boxes with scores and classes for one-stage detector.

    The function takes output of multi-level ConvNets and anchor boxes and
    generates detected boxes. Note that this used batched nms, which is not
    supported on TPU currently.

    Args:
      boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
        N, 1, 4], which box predictions on all feature levels. The N is the number
        of total anchors on all levels.
      scores: a tensor with shape [batch_size, N, num_classes], which stacks class
        probability on all feature levels. The N is the number of total anchors on
        all levels. The num_classes is the number of classes predicted by the
        model. Note that the class_outputs here is the raw score.
      max_total_size: a scalar representing maximum number of boxes retained over
        all classes (also used as the per-class limit).
      nms_iou_threshold: a float representing the threshold for deciding whether
        boxes overlap too much with respect to IOU.
      score_threshold: a float representing the threshold for deciding when to
        remove boxes based on score.

    Returns:
      nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
        representing top detected boxes in [y1, x1, y2, x2].
      nms_scores: `float` Tensor of shape [batch_size, max_total_size]
        representing sorted confidence scores for detected boxes.
      nms_classes: `int32` Tensor of shape [batch_size, max_total_size]
        representing classes for detected boxes.
      valid_detections: `int` Tensor of shape [batch_size] only the top
        `valid_detections` boxes are valid detections.
    """
    with tf.name_scope('generate_detections'):
        # Coordinates are scaled by the largest coordinate before NMS and the
        # scaling is undone afterwards. Presumably this keeps boxes within
        # [0, 1] for combined_non_max_suppression (see its `clip_boxes`
        # argument) — TODO confirm intent.
        normalizer = tf.reduce_max(boxes)
        # A non-positive maximum (e.g. all-zero boxes) would make the division
        # below produce NaN/Inf or flip signs; fall back to 1 in that case.
        normalizer = tf.where(normalizer > 0, normalizer, tf.ones_like(normalizer))
        boxes /= normalizer
        (nmsed_boxes, nmsed_scores, nmsed_classes,
         valid_detections) = tf.image.combined_non_max_suppression(
            boxes,
            scores,
            max_output_size_per_class=max_total_size,
            max_total_size=max_total_size,
            iou_threshold=nms_iou_threshold,
            score_threshold=score_threshold,
            pad_per_class=False,
        )
        # De-normalizes box coordinates.
        nmsed_boxes *= normalizer

    # Class ids are cast to int32 so this path matches the integer class output
    # of the non-batched generator.
    nmsed_classes = tf.cast(nmsed_classes, tf.int32)
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections


class MultilevelDetectionGenerator:
    """Generates detected boxes with scores and classes for one-stage detector.

    Box and class predictions of every feature level in
    [min_level, max_level] are decoded, clipped and concatenated, then passed
    to a single NMS call chosen by `generate_detections_factory(params)`.
    """

    def __init__(self, min_level, max_level, params):
        self._min_level = min_level
        self._max_level = max_level
        self._generate_detections = generate_detections_factory(params)

    def _decode_level(self, level, box_outputs, class_outputs, anchor_boxes, image_shape):
        """Decode one feature level into (boxes [B, M, 4], scores [B, M, C - 1]).

        `box_outputs`/`class_outputs` are indexed by `str(level)` and
        `anchor_boxes` by the integer `level`, exactly as the caller passes them.
        """
        key = str(level)
        box_dims = tf.shape(box_outputs[key])
        batch_size = box_dims[0]
        # The last axis packs 4 box coordinates per anchor at each location.
        anchors_per_location = box_dims[-1] // 4
        num_classes = tf.shape(class_outputs[key])[-1] // anchors_per_location

        # Sigmoid the logits, then drop the implicit background class (column 0).
        level_scores = tf.slice(
            tf.sigmoid(tf.reshape(class_outputs[key], [batch_size, -1, num_classes])),
            [0, 0, 1], [-1, -1, -1])

        # One stage detector only supports class agnostic box regression, so
        # there is a single box per anchor.
        level_anchors = tf.reshape(anchor_boxes[level], [batch_size, -1, 4])
        level_deltas = tf.reshape(box_outputs[key], [batch_size, -1, 4])
        level_boxes = box_utils.decode_boxes(level_deltas, level_anchors)
        level_boxes = box_utils.clip_boxes(level_boxes, image_shape)
        return level_boxes, level_scores

    def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
        """Decode all levels and run NMS.

        Returns the tuple produced by the configured detection function
        (boxes, scores, classes, valid_detections), with classes shifted by +1.
        """
        decoded = [
            self._decode_level(level, box_outputs, class_outputs, anchor_boxes, image_shape)
            for level in range(self._min_level, self._max_level + 1)
        ]
        boxes = tf.concat([level_boxes for level_boxes, _ in decoded], 1)
        scores = tf.concat([level_scores for _, level_scores in decoded], 1)

        # A singleton class axis marks the boxes as shared by every class.
        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(tf.expand_dims(boxes, axis=2), scores))

        # Shift by one because index 0 is the background class removed above.
        return nmsed_boxes, nmsed_scores, nmsed_classes + 1, valid_detections


class GenericDetectionGenerator:
    """Generates the final detected boxes with scores and classes.

    Class logits are softmaxed and the background class (index 0) is dropped.
    The class-specific box regressions are decoded against the anchors and
    clipped to the image. The result is passed to the NMS function chosen by
    `generate_detections_factory(params)`.
    """

    def __init__(self, params):
        self._generate_detections = generate_detections_factory(params)

    def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
        """Generate final detections.

        Args:
          box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
            representing the class-specific box coordinates relative to anchors.
          class_outputs: a tensor of shape of [batch_size, K, num_classes]
            representing the class logits before applying score activation.
          anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
            corresponding anchor boxes w.r.t `box_outputs`.
          image_shape: a tensor of shape of [batch_size, 2] storing the image height
            and width w.r.t. the scaled image, i.e. the same image space as
            `box_outputs` and `anchor_boxes`.

        Returns:
          nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
            representing top detected boxes in [y1, x1, y2, x2].
          nms_scores: `float` Tensor of shape [batch_size, max_total_size]
            representing sorted confidence scores for detected boxes.
          nms_classes: `int` Tensor of shape [batch_size, max_total_size]
            representing classes for detected boxes, offset by +1 so that 0
            remains the background class.
          valid_detections: `int` Tensor of shape [batch_size] only the top
            `valid_detections` boxes are valid detections.
        """
        probabilities = tf.nn.softmax(class_outputs, axis=-1)

        dims = tf.shape(probabilities)
        batch_size, num_locations, num_classes = dims[0], dims[1], dims[-1]
        num_foreground = num_classes - 1
        num_detections = num_locations * num_foreground

        # Drop the background column (index 0) from the probabilities.
        foreground_scores = tf.slice(probabilities, [0, 0, 1], [-1, -1, -1])

        # [B, K, C * 4] -> [B, K, C, 4], then drop the background regression.
        per_class_deltas = tf.reshape(
            box_outputs,
            tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
        per_class_deltas = tf.slice(per_class_deltas, [0, 0, 1, 0], [-1, -1, -1, -1])

        # Each foreground class reuses the anchor of its location.
        tiled_anchors = tf.tile(
            tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_foreground, 1])

        # Flatten (location, class) pairs so decoding sees [B, K * (C - 1), 4].
        flat_shape = tf.stack([batch_size, num_detections, 4], axis=-1)
        decoded_boxes = box_utils.decode_boxes(
            tf.reshape(per_class_deltas, flat_shape),
            tf.reshape(tiled_anchors, flat_shape),
            weights=[10.0, 10.0, 5.0, 5.0])
        decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)

        # Restore the per-class layout expected by the NMS function.
        decoded_boxes = tf.reshape(
            decoded_boxes,
            tf.stack([batch_size, num_locations, num_foreground, 4], axis=-1))

        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(decoded_boxes, foreground_scores))

        # Shift by one because index 0 is the background class removed above.
        return nmsed_boxes, nmsed_scores, nmsed_classes + 1, valid_detections
