"""
 Copyright (c) 2022 Intel Corporation
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
      http://www.apache.org/licenses/LICENSE-2.0
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from abc import ABCMeta
from abc import abstractmethod

import tensorflow as tf

# String identifiers naming the supported box coder types.
FASTER_RCNN = "faster_rcnn"  # identifier: faster_rcnn
KEYPOINT = "keypoint"  # identifier: keypoint
MEAN_STDDEV = "mean_stddev"  # identifier: mean_stddev
SQUARE = "square"  # identifier: square


class BoxCoder(metaclass=ABCMeta):
    """Abstract base class for box coders.

    A box coder converts boxes to and from a code representation relative to
    a collection of anchors. Implementations must provide `code_size`,
    `_encode` and `_decode`; callers use the public `encode` / `decode`
    wrappers, which run the implementation inside a TensorFlow name scope.
    """
    # The metaclass is declared in the class header on purpose: a
    # ``__metaclass__`` class attribute is ignored by Python 3, which would
    # leave the abstract members below unenforced.

    @property
    @abstractmethod
    def code_size(self):
        """Return the size of each code.

        This number is a constant and should agree with the output of the `encode`
        op (e.g. if rel_codes is the output of self.encode(...), then it should have
        shape [N, code_size()]).  This abstractproperty should be overridden by
        implementations.

        Returns:
          an integer constant
        """

    def encode(self, boxes, anchors):
        """Encode a box list relative to an anchor collection.

        Delegates to `_encode` inside a 'Encode' TensorFlow name scope.

        Args:
          boxes: BoxList holding N boxes to be encoded
          anchors: BoxList of N anchors

        Returns:
          a tensor representing N relative-encoded boxes
        """
        with tf.name_scope('Encode'):
            return self._encode(boxes, anchors)

    def decode(self, rel_codes, anchors):
        """Decode boxes that are encoded relative to an anchor collection.

        Delegates to `_decode` inside a 'Decode' TensorFlow name scope.

        Args:
          rel_codes: a tensor representing N relative-encoded boxes
          anchors: BoxList of anchors

        Returns:
          boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
            with corners y_min, x_min, y_max, x_max)
        """
        with tf.name_scope('Decode'):
            return self._decode(rel_codes, anchors)

    @abstractmethod
    def _encode(self, boxes, anchors):
        """Encode boxes relative to anchors; must be overridden by implementations.

        Args:
          boxes: BoxList holding N boxes to be encoded
          anchors: BoxList of N anchors

        Returns:
          a tensor representing N relative-encoded boxes
        """

    @abstractmethod
    def _decode(self, rel_codes, anchors):
        """Decode relative codes into boxes; must be overridden by implementations.

        Args:
          rel_codes: a tensor representing N relative-encoded boxes
          anchors: BoxList of anchors

        Returns:
          boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
            with corners y_min, x_min, y_max, x_max)
        """


def batch_decode(encoded_boxes, box_coder, anchors):
    """Decode a batch of encoded boxes.

    This op takes a batch of encoded bounding boxes and transforms
    them to a batch of bounding boxes specified by their corners in
    the order of [y_min, x_min, y_max, x_max]. Each batch element is
    decoded separately with `box_coder.decode` against the same anchors.

    Args:
      encoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
        code_size] representing the location of the objects.
      box_coder: a BoxCoder object.
      anchors: a BoxList of anchors used to encode `encoded_boxes`.

    Returns:
      decoded_boxes: a float32 tensor of shape [batch_size, num_anchors, 4]
        representing the corners of the objects in the order
        of [y_min, x_min, y_max, x_max].

    Raises:
      ValueError: if encoded_boxes is not rank 3, if the static number of
        anchors in encoded_boxes (shape[1]) differs from
        anchors.num_boxes_static(), or if the batch dimension is not
        statically known (tf.unstack cannot infer the number of slices).
    """
    encoded_shape = encoded_boxes.get_shape()
    encoded_shape.assert_has_rank(3)

    # tf.compat.dimension_value accepts both a Dimension (TF1 shape
    # behavior) and a plain int/None (TF2 shape behavior); the bare
    # `.value` attribute only exists on Dimension.
    num_encoded_anchors = tf.compat.dimension_value(encoded_shape[1])
    num_anchors = anchors.num_boxes_static()
    if num_encoded_anchors != num_anchors:
        raise ValueError(
            'The number of anchors inferred from encoded_boxes'
            ' and anchors are inconsistent: shape[1] of encoded_boxes'
            ' %s should be equal to the number of anchors: %s.' %
            (num_encoded_anchors, num_anchors))

    # tf.unstack splits along axis 0, so this requires a static batch size.
    decoded_boxes = tf.stack([
        box_coder.decode(boxes, anchors).get()
        for boxes in tf.unstack(encoded_boxes)
    ])

    return decoded_boxes
