"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import numpy as np
import tensorflow as tf

# Small positive constant. It is not referenced by the functions visible in
# this module; presumably kept for numerical-stability guards (TODO confirm).
EPSILON = 1e-8
# Upper bound applied to the log-scale height/width deltas in `decode_boxes`
# before they are exponentiated, which caps a decoded side at 1000/16 times
# the corresponding anchor side.
BBOX_XFORM_CLIP = np.log(1000. / 16.)


def yxyx_to_xywh(boxes):
    """Re-expresses corner-format boxes as top-left corner plus size.

    Args:
      boxes: a numpy array whose last dimension holds the 4 coordinates
        [ymin, xmin, ymax, xmax] of each box.

    Returns:
      A numpy array with the same shape as `boxes` whose last dimension is
      [xmin, ymin, width, height], where width = xmax - xmin and
      height = ymax - ymin.

    Raises:
      ValueError: If the last dimension of boxes is not 4.
    """
    last_dim = boxes.shape[-1]
    if last_dim != 4:
        raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(last_dim))

    # Pull each coordinate out as its own array (the trailing axis is dropped
    # here and re-created by the stack below).
    ymin, xmin, ymax, xmax = (boxes[..., idx] for idx in range(4))

    return np.stack([xmin, ymin, xmax - xmin, ymax - ymin], axis=-1)


def normalize_boxes(boxes, image_shape):
    """Divides pixel box coordinates by the image size.

    Args:
      boxes: a tensor whose last dimension is 4 representing the coordinates of
        boxes in ymin, xmin, ymax, xmax order.
      image_shape: either a list/tuple `(height, width)`, or a tensor whose last
        dimension is [height, width] and whose leading dimensions are
        `broadcastable` to those of `boxes`.

    Returns:
      normalized_boxes: a tensor with the same shape as `boxes`; y coordinates
        are divided by height and x coordinates by width.

    Raises:
      ValueError: If the last dimension of boxes is not 4.
    """
    num_coords = boxes.shape[-1]
    if num_coords != 4:
        raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(num_coords))

    with tf.name_scope('normalize_boxes'):
        if not isinstance(image_shape, (list, tuple)):
            shape_tensor = tf.cast(image_shape, boxes.dtype)
            # Size-1 slices keep the trailing axis so the divisions below
            # broadcast against the size-1 coordinate slices of `boxes`.
            img_h = shape_tensor[..., 0:1]
            img_w = shape_tensor[..., 1:2]
        else:
            img_h, img_w = image_shape

        # Coordinates 0 and 2 are y values, 1 and 3 are x values.
        scaled_coords = [
            boxes[..., 0:1] / img_h,
            boxes[..., 1:2] / img_w,
            boxes[..., 2:3] / img_h,
            boxes[..., 3:4] / img_w,
        ]
        return tf.concat(scaled_coords, -1)


def denormalize_boxes(boxes, image_shape):
    """Converts boxes normalized by [height, width] to pixel coordinates.

    Args:
      boxes: a tensor whose last dimension is 4 representing the coordinates of
        boxes in ymin, xmin, ymax, xmax order.
      image_shape: a list of two integers, a two-element vector or a tensor such
        that all but the last dimensions are `broadcastable` to `boxes`. The last
        dimension is 2, which represents [height, width].

    Returns:
      denormalized_boxes: a tensor whose shape is the same as `boxes` representing
        the denormalized boxes (y coordinates multiplied by height, x
        coordinates multiplied by width).

    Raises:
      ValueError: If the last dimension of boxes is statically known and is
        not 4.
    """
    # Validate the layout like the sibling helpers do. The check is skipped
    # when the static size is unknown (None) or the input exposes no `.shape`,
    # so inputs that were accepted before are still accepted.
    last_dim = boxes.shape[-1] if hasattr(boxes, 'shape') else None
    if last_dim is not None and last_dim != 4:
        raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(last_dim))

    with tf.name_scope('denormalize_boxes'):
        if isinstance(image_shape, (list, tuple)):
            height, width = image_shape
        else:
            image_shape = tf.cast(image_shape, boxes.dtype)
            # Splitting keeps a trailing axis of size 1 on each piece so it
            # broadcasts against the per-coordinate pieces below.
            height, width = tf.split(image_shape, 2, -1)

        ymin, xmin, ymax, xmax = tf.split(boxes, 4, -1)
        ymin = ymin * height
        xmin = xmin * width
        ymax = ymax * height
        xmax = xmax * width

        denormalized_boxes = tf.concat([ymin, xmin, ymax, xmax], -1)
        return denormalized_boxes


def clip_boxes(boxes, image_shape):
    """Limits box coordinates to the range covered by the image.

    Args:
      boxes: a tensor whose last dimension is 4 representing the coordinates of
        boxes in ymin, xmin, ymax, xmax order.
      image_shape: either a list/tuple `(height, width)`, or a tensor whose last
        dimension is [height, width] and whose leading dimensions are
        `broadcastable` to those of `boxes`.

    Returns:
      clipped_boxes: a tensor with the same shape as `boxes` in which every y
        coordinate lies in [0, height - 1] and every x coordinate lies in
        [0, width - 1].

    Raises:
      ValueError: If the last dimension of boxes is not 4.
    """
    num_coords = boxes.shape[-1]
    if num_coords != 4:
        raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(num_coords))

    with tf.name_scope('clip_boxes'):
        if isinstance(image_shape, (list, tuple)):
            img_h, img_w = image_shape
            max_y = img_h - 1.0
            max_x = img_w - 1.0
            # Per-coordinate upper bounds in ymin, xmin, ymax, xmax order.
            upper_bounds = [max_y, max_x, max_y, max_x]
        else:
            img_h, img_w = tf.unstack(tf.cast(image_shape, boxes.dtype), axis=-1)
            max_y = img_h - 1.0
            max_x = img_w - 1.0
            upper_bounds = tf.stack([max_y, max_x, max_y, max_x], axis=-1)

        # Cap from above first, then from below at zero.
        capped = tf.math.minimum(boxes, upper_bounds)
        return tf.math.maximum(capped, 0.0)


def encode_boxes(boxes, anchors, weights=None):
    """Encodes boxes as regression targets relative to anchors.

    Heights and widths are computed as `max - min + 1.0`, the same convention
    `decode_boxes` uses. The targets are the center offsets divided by the
    anchor size, and the log of the box/anchor size ratios.

    Args:
        boxes: a tensor whose last dimension is 4 representing the coordinates of
            boxes in ymin, xmin, ymax, xmax order. It is cast to the dtype of
            `anchors`.
        anchors: a tensor whose shape is the same as, or `broadcastable` to `boxes`,
            representing the coordinates of anchors in ymin, xmin, ymax, xmax order.
        weights: None or a list of four float numbers; when given (and non-empty)
            the dy, dx, dh, dw targets are multiplied by the respective entry.

    Returns:
        encoded_boxes: a tensor whose last dimension is [dy, dx, dh, dw], the
            encoded box targets.

    Raises:
        ValueError: If the last dimension of boxes is not 4.
    """
    num_coords = boxes.shape[-1]
    if num_coords != 4:
        raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(num_coords))

    with tf.name_scope('encode_boxes'):
        boxes = tf.cast(boxes, anchors.dtype)

        # Size and center of the boxes being encoded.
        gt_ymin = boxes[..., 0:1]
        gt_xmin = boxes[..., 1:2]
        gt_h = boxes[..., 2:3] - gt_ymin + 1.0
        gt_w = boxes[..., 3:4] - gt_xmin + 1.0
        gt_yc = gt_ymin + 0.5 * gt_h
        gt_xc = gt_xmin + 0.5 * gt_w

        # Size and center of the reference anchors.
        ref_ymin = anchors[..., 0:1]
        ref_xmin = anchors[..., 1:2]
        ref_h = anchors[..., 2:3] - ref_ymin + 1.0
        ref_w = anchors[..., 3:4] - ref_xmin + 1.0
        ref_yc = ref_ymin + 0.5 * ref_h
        ref_xc = ref_xmin + 0.5 * ref_w

        # Targets in dy, dx, dh, dw order.
        targets = [
            (gt_yc - ref_yc) / ref_h,
            (gt_xc - ref_xc) / ref_w,
            tf.math.log(gt_h / ref_h),
            tf.math.log(gt_w / ref_w),
        ]

        if weights:
            targets = [target * weights[pos] for pos, target in enumerate(targets)]

        return tf.concat(targets, -1)


def decode_boxes(encoded_boxes, anchors, weights=None):
    """Decodes regression targets back into boxes using their anchors.

    Anchor heights and widths are computed as `max - min + 1.0`, and the
    decoded `ymax`/`xmax` subtract that 1.0 again.

    Args:
      encoded_boxes: a tensor whose last dimension is 4 holding the encoded
        targets in dy, dx, dh, dw order. It is cast to the dtype of `anchors`.
      anchors: a tensor whose shape is the same as, or `broadcastable` to
        `encoded_boxes`, representing the coordinates of anchors in ymin, xmin,
        ymax, xmax order.
      weights: None or a list of four float numbers; when given (and non-empty)
        dy, dx, dh, dw are divided by the respective entry before decoding.

    Returns:
      decoded_boxes: a tensor whose last dimension is the decoded box
        coordinates in ymin, xmin, ymax, xmax order.

    Raises:
      ValueError: If the last dimension of encoded_boxes is not 4.
    """
    code_dim = encoded_boxes.shape[-1]
    if code_dim != 4:
        raise ValueError('encoded_boxes.shape[-1] is {:d}, but must be 4.'.format(code_dim))

    with tf.name_scope('decode_boxes'):
        encoded_boxes = tf.cast(encoded_boxes, anchors.dtype)

        # Size-1 slices in dy, dx, dh, dw order.
        deltas = [encoded_boxes[..., pos:pos + 1] for pos in range(4)]
        if weights:
            deltas = [delta / weights[pos] for pos, delta in enumerate(deltas)]
        dy, dx, dh, dw = deltas

        # Bound the log-scale size deltas before they are exponentiated.
        dh = tf.math.minimum(dh, BBOX_XFORM_CLIP)
        dw = tf.math.minimum(dw, BBOX_XFORM_CLIP)

        # Size and center of the reference anchors.
        ref_ymin = anchors[..., 0:1]
        ref_xmin = anchors[..., 1:2]
        ref_h = anchors[..., 2:3] - ref_ymin + 1.0
        ref_w = anchors[..., 3:4] - ref_xmin + 1.0
        ref_yc = ref_ymin + 0.5 * ref_h
        ref_xc = ref_xmin + 0.5 * ref_w

        # Decoded center and size.
        out_yc = dy * ref_h + ref_yc
        out_xc = dx * ref_w + ref_xc
        out_h = tf.math.exp(dh) * ref_h
        out_w = tf.math.exp(dw) * ref_w

        # Back to corner coordinates.
        out_ymin = out_yc - 0.5 * out_h
        out_xmin = out_xc - 0.5 * out_w
        out_ymax = out_ymin + out_h - 1.0
        out_xmax = out_xmin + out_w - 1.0

        return tf.concat([out_ymin, out_xmin, out_ymax, out_xmax], -1)


def filter_boxes(boxes, scores, image_shape, min_size_threshold):
    """Filter and remove boxes that are too small or fall outside the image.

    A box is kept only when both of the following hold:
      * its height and width (computed as `max - min + 1.0`) are strictly
        greater than `max(min_size_threshold, 1.0)`;
      * its center lies strictly inside the image, i.e. `0 < yc < height` and
        `0 < xc < width`.

    Args:
      boxes: a tensor whose last dimension is 4 representing the coordinates of
        boxes in ymin, xmin, ymax, xmax order.
      scores: a tensor whose shape is the same as tf.shape(boxes)[:-1]
        representing the original scores of the boxes.
      image_shape: a list/tuple `(height, width)`, or a tensor whose shape is
        the same as, or `broadcastable` to `boxes` except the last dimension,
        which is 2, representing [height, width] of the scaled image.
      min_size_threshold: a float representing the minimal box size in each side
        (w.r.t. the scaled image). Values below 1.0 are treated as 1.0.

    Returns:
      filtered_boxes: a tensor whose shape is the same as `boxes` but with
        the positions of the filtered boxes filled with 0.
      filtered_scores: a tensor whose shape is the same as `scores` but with
        the positions of the filtered boxes filled with 0.

    Raises:
      ValueError: If the last dimension of boxes is not 4.
    """
    if boxes.shape[-1] != 4:
        # The check is on the last axis, so the message reports shape[-1].
        raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
            boxes.shape[-1]))

    with tf.name_scope('filter_boxes'):
        if isinstance(image_shape, (list, tuple)):
            height, width = image_shape
        else:
            image_shape = tf.cast(image_shape, boxes.dtype)
            height = image_shape[..., 0]
            width = image_shape[..., 1]

        # Integer indexing drops the trailing axis, so these line up with
        # `scores`.
        ymin = boxes[..., 0]
        xmin = boxes[..., 1]
        ymax = boxes[..., 2]
        xmax = boxes[..., 3]

        h = ymax - ymin + 1.0
        w = xmax - xmin + 1.0
        yc = ymin + 0.5 * h
        xc = xmin + 0.5 * w

        # The effective threshold is never below 1.0.
        min_size = tf.cast(
            tf.math.maximum(min_size_threshold, 1.0), boxes.dtype)

        filtered_size_mask = tf.math.logical_and(
            tf.math.greater(h, min_size), tf.math.greater(w, min_size))
        filtered_center_mask = tf.logical_and(
            tf.math.logical_and(tf.math.greater(yc, 0.0), tf.math.less(yc, height)),
            tf.math.logical_and(tf.math.greater(xc, 0.0), tf.math.less(xc, width)))
        filtered_mask = tf.math.logical_and(filtered_size_mask,
                                            filtered_center_mask)

        filtered_scores = tf.where(filtered_mask, scores, tf.zeros_like(scores))
        # Multiplying by the 0/1 mask zeroes all four coordinates of rejected
        # boxes.
        filtered_boxes = tf.cast(
            tf.expand_dims(filtered_mask, axis=-1), boxes.dtype) * boxes

        return filtered_boxes, filtered_scores


def filter_boxes_by_scores(boxes, scores, min_score_threshold):
    """Filter and remove boxes whose scores are not above the threshold.

    Args:
      boxes: a tensor whose last dimension is 4 representing the coordinates of
        boxes in ymin, xmin, ymax, xmax order.
      scores: a tensor whose shape is the same as tf.shape(boxes)[:-1]
        representing the original scores of the boxes.
      min_score_threshold: a float representing the minimal box score threshold.
        Only boxes whose score is strictly greater than it are kept.

    Returns:
      filtered_boxes: a tensor whose shape is the same as `boxes` but with
        the positions of the filtered boxes filled with 0.
      filtered_scores: a tensor whose shape is the same as `scores` but with
        the positions of the filtered boxes filled with -1.

    Raises:
      ValueError: If the last dimension of boxes is not 4.
    """
    if boxes.shape[-1] != 4:
        # The check is on the last axis, so the message reports shape[-1].
        raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
            boxes.shape[-1]))

    with tf.name_scope('filter_boxes_by_scores'):
        filtered_mask = tf.math.greater(scores, min_score_threshold)
        filtered_scores = tf.where(filtered_mask, scores, -1 * tf.ones_like(scores))
        # Multiplying by the 0/1 mask zeroes all four coordinates of rejected
        # boxes (boxes are zeroed, scores are set to -1).
        filtered_boxes = tf.cast(
            tf.expand_dims(filtered_mask, axis=-1), boxes.dtype) * boxes

        return filtered_boxes, filtered_scores


def bbox_overlap(boxes, gt_boxes):
    """Computes pairwise IoU between proposals and ground truth boxes.

    Ground truth rows whose largest coordinate is negative are treated as
    padding; every IoU entry against such a row is set to -1.

    Args:
      boxes: a tensor with a shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment (e.g., rpn_post_nms_topn). The
        last dimension is the pixel coordinates in [ymin, xmin, ymax, xmax] form.
      gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4]. This
        tensor might have paddings with a negative value.

    Returns:
      iou: a tensor with a shape of [batch_size, N, MAX_NUM_INSTANCES].
    """

    def _to_row(column):
        # [batch, M, 1] -> [batch, 1, M], so it broadcasts against the
        # [batch, N, 1] proposal columns into [batch, N, M].
        return tf.transpose(column, [0, 2, 1])

    with tf.name_scope('bbox_overlap'):
        box_ymin, box_xmin, box_ymax, box_xmax = tf.split(boxes, 4, 2)
        gt_ymin, gt_xmin, gt_ymax, gt_xmax = tf.split(gt_boxes, 4, 2)

        # Intersection rectangle; negative extents (no overlap) become 0.
        inter_xmin = tf.math.maximum(box_xmin, _to_row(gt_xmin))
        inter_xmax = tf.math.minimum(box_xmax, _to_row(gt_xmax))
        inter_ymin = tf.math.maximum(box_ymin, _to_row(gt_ymin))
        inter_ymax = tf.math.minimum(box_ymax, _to_row(gt_ymax))
        inter_w = tf.math.maximum((inter_xmax - inter_xmin), 0)
        inter_h = tf.math.maximum((inter_ymax - inter_ymin), 0)
        intersection = inter_w * inter_h

        # Union = area(box) + area(gt) - intersection; the small constant
        # keeps the division below away from zero.
        box_area = (box_ymax - box_ymin) * (box_xmax - box_xmin)
        gt_area = (gt_ymax - gt_ymin) * (gt_xmax - gt_xmin)
        union = box_area + _to_row(gt_area) - intersection + 1e-8

        iou = intersection / union

        # A ground truth row is padding when all of its coordinates are
        # negative. OR-ing with an all-False [batch, N, 1] tensor only
        # broadcasts the mask to the full [batch, N, M] shape.
        gt_is_padding = tf.less(
            tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
        pad_mask = tf.logical_or(
            tf.zeros_like(box_xmin, dtype=tf.bool), _to_row(gt_is_padding))

        return tf.where(pad_mask, -1 * tf.ones_like(iou), iou)


def get_non_empty_box_indices(boxes):
    """Returns the row indices of boxes whose height and width are both > 0.

    `boxes` is indexed as `boxes[:, i]` with columns in ymin, xmin, ymax, xmax
    order (assumes a [num_boxes, 4] layout -- TODO confirm with callers).
    """
    has_height = tf.greater(boxes[:, 2] - boxes[:, 0], 0)
    has_width = tf.greater(boxes[:, 3] - boxes[:, 1], 0)
    # tf.where returns one coordinate row per True entry; keep the first
    # (row-index) column.
    non_empty = tf.where(tf.logical_and(has_height, has_width))
    return non_empty[:, 0]


def top_k_boxes(boxes, scores, k):
    """Selects the k highest-scoring boxes of every image, sorted by score.

    Args:
        boxes: a tensor of shape [batch_size, N, 4] representing the coordinates
            of the boxes. N is the number of boxes per image.
        scores: a tensor of shape [batch_size, N] representing the scores of the
            boxes.
        k: an integer or a tensor indicating the top k number.

    Returns:
        selected_boxes: a tensor of shape [batch_size, k, 4] representing the
            selected top k box coordinates.
        selected_scores: a tensor of shape [batch_size, k] representing the selected
            top k box scores.
    """
    with tf.name_scope('top_k_boxes'):
        selected_scores, order = tf.nn.top_k(scores, k=k, sorted=True)

        static_batch, _ = scores.get_shape().as_list()
        if static_batch != 1:
            # General case: pair every selected index with its batch row and
            # look the boxes up with gather_nd.
            order_shape = tf.shape(order)
            row_ids = (
                tf.expand_dims(tf.range(order_shape[0]), axis=-1) *
                tf.ones([1, order_shape[-1]], dtype=tf.int32))
            lookup = tf.stack([row_ids, order], axis=-1)
            selected_boxes = tf.gather_nd(boxes, lookup)
        else:
            # Single image: gathering along axis 1 yields [1, 1, k, 4]; drop
            # the extra axis introduced by the batch dimension of the indices.
            gathered = tf.gather(boxes, order, axis=1)
            selected_boxes = tf.squeeze(gathered, axis=1)

        return selected_boxes, selected_scores
