# Copyright (c) OpenMMLab. All rights reserved.
# modified from
# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py

from typing import Tuple

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['border_align_forward', 'border_align_backward'])


class BorderAlignFunction(Function):

    @staticmethod
    def symbolic(g, input, boxes, pool_size):
        return g.op(
            'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)

    @staticmethod
    def forward(ctx, input: torch.Tensor, boxes: torch.Tensor,
                pool_size: int) -> torch.Tensor:
        ctx.pool_size = pool_size
        ctx.input_shape = input.size()

        assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
        assert boxes.size(2) == 4, \
            'the last dimension of boxes must be (x1, y1, x2, y2)'
        assert input.size(1) % 4 == 0, \
            'the channel for input feature must be divisible by factor 4'

        # [B, C//4, H*W, 4]
        output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
        output = input.new_zeros(output_shape)
        # `argmax_idx` only used for backward
        argmax_idx = input.new_zeros(output_shape).to(torch.int)

        ext_module.border_align_forward(
            input, boxes, output, argmax_idx, pool_size=ctx.pool_size)

        ctx.save_for_backward(boxes, argmax_idx)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx,
                 grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        boxes, argmax_idx = ctx.saved_tensors
        grad_input = grad_output.new_zeros(ctx.input_shape)
        # complex head architecture may cause grad_output uncontiguous
        grad_output = grad_output.contiguous()
        ext_module.border_align_backward(
            grad_output,
            boxes,
            argmax_idx,
            grad_input,
            pool_size=ctx.pool_size)
        return grad_input, None, None


border_align = BorderAlignFunction.apply


class BorderAlign(nn.Module):
    r"""Border align pooling layer.

    Applies border_align over the input feature based on predicted bboxes.
    The details were described in the paper
    `BorderDet: Border Feature for Dense Object Detection
    <https://arxiv.org/abs/2007.11056>`_.

    For each border line (e.g. top, left, bottom or right) of each box,
    border_align does the following:

    1. uniformly samples ``pool_size`` +1 positions on this line, involving
       the start and end points.
    2. the corresponding features on these points are computed by bilinear
       interpolation.
    3. max pooling over all the ``pool_size`` +1 positions are used for
       computing pooled feature.

    Args:
        pool_size (int): number of positions sampled over the boxes' borders
            (e.g. top, bottom, left, right).
    """

    def __init__(self, pool_size: int):
        super().__init__()
        self.pool_size = pool_size

    def forward(self, input: torch.Tensor,
                boxes: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
                [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
                right features respectively.
            boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).

        Returns:
            torch.Tensor: Pooled features with shape [N,C,H*W,4]. The order is
            (top,left,bottom,right) for the last dimension.
        """
        return border_align(input, boxes, self.pool_size)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(pool_size={self.pool_size})'
        return s
