# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet.core import anchor_inside_flags
from ..builder import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


def calc_region(bbox, ratio, stride, featmap_size=None):
    """Compute a shrunken region of a box on the feature map.

    The box is first divided by ``stride``; each edge is then moved towards
    the opposite edge by ``ratio`` of the box extent and rounded.

    Args:
        bbox (Tensor): Box indexed as (x1, y1, x2, y2) along dim 0.
        ratio (float): Interpolation weight applied from each edge.
        stride (int): Value the box coordinates are divided by.
        featmap_size (tuple, optional): (h, w) used as the upper clamp
            bound for y and x respectively. No clamping when None.

    Returns:
        tuple[Tensor]: The region as (x1, y1, x2, y2).
    """
    # Coordinates of the box expressed in feature-map units.
    scaled = bbox / stride
    left, top, right, bottom = scaled[0], scaled[1], scaled[2], scaled[3]
    keep = 1 - ratio

    region_x1 = torch.round(keep * left + ratio * right)
    region_y1 = torch.round(keep * top + ratio * bottom)
    region_x2 = torch.round(ratio * left + keep * right)
    region_y2 = torch.round(ratio * top + keep * bottom)

    if featmap_size is None:
        return (region_x1, region_y1, region_x2, region_y2)

    # x is bounded by the width entry, y by the height entry.
    max_y, max_x = featmap_size[0], featmap_size[1]
    return (region_x1.clamp(min=0, max=max_x),
            region_y1.clamp(min=0, max=max_y),
            region_x2.clamp(min=0, max=max_x),
            region_y2.clamp(min=0, max=max_y))


def anchor_ctr_inside_region_flags(anchors, stride, region):
    """Flag the anchors whose centers fall inside a region.

    Args:
        anchors (Tensor): Anchors indexed as (x1, y1, x2, y2) along dim 1.
        stride (int): Value the anchor coordinates are divided by before
            the comparison.
        region (tuple): Region bounds (x1, y1, x2, y2); both ends are
            inclusive.

    Returns:
        Tensor: Boolean flags, one per anchor.
    """
    left, top, right, bottom = region
    scaled = anchors / stride
    ctr_x = (scaled[:, 0] + scaled[:, 2]) * 0.5
    ctr_y = (scaled[:, 1] + scaled[:, 3]) * 0.5
    inside_x = (ctr_x >= left) & (ctr_x <= right)
    inside_y = (ctr_y >= top) & (ctr_y <= bottom)
    return inside_x & inside_y


@BBOX_ASSIGNERS.register_module()
class RegionAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposals will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        center_ratio: ratio of the region in the center of the bbox to
            define positive sample.
        ignore_ratio: ratio of the region to define ignore samples.
    """

    def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
        self.center_ratio = center_ratio
        self.ignore_ratio = ignore_ratio

    def assign(self,
               mlvl_anchors,
               mlvl_valid_flags,
               gt_bboxes,
               img_meta,
               featmap_sizes,
               anchor_scale,
               anchor_strides,
               gt_bboxes_ignore=None,
               gt_labels=None,
               allowed_border=0):
        """Assign gt to anchors.

        This method assign a gt bbox to every bbox (proposal/anchor), each bbox
        will be assigned with -1, 0, or a positive number. -1 means don't care,
        0 means negative sample, positive number is the index (1-based) of
        assigned gt.

        The assignment is done in following steps, and the order matters.

        1. Assign every anchor to 0 (negative)
        2. (For each gt_bboxes) Compute ignore flags based on ignore_region
           then assign -1 to anchors w.r.t. ignore flags
        3. (For each gt_bboxes) Compute pos flags based on center_region then
           assign gt_bboxes to anchors w.r.t. pos flags
        4. (For each gt_bboxes) Compute ignore flags based on adjacent anchor
           level then assign -1 to anchors w.r.t. ignore flags
        5. Assign anchor outside of image to -1

        Args:
            mlvl_anchors (list[Tensor]): Multi level anchors.
            mlvl_valid_flags (list[Tensor]): Multi level valid flags.
            gt_bboxes (Tensor): Groundtruth boxes of the image,
                shape (k, 4).
            img_meta (dict): Meta info of image. The ``'img_shape'`` entry
                is read.
            featmap_sizes (list[Tensor]): Feature map size (h, w) of each
                level; ``h * w`` must equal the number of anchors of that
                level.
            anchor_scale (int): Scale of the anchor.
            anchor_strides (list[int]): Stride of the anchor.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO. Not
                supported; must be None.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
            allowed_border (int, optional): The border to allow the valid
                anchor. Defaults to 0.

        Returns:
            :obj:`AssignResult`: The assign result. ``max_overlaps`` is
            None unless there are no gts or no anchors, in which case it
            is all zeros.

        Raises:
            NotImplementedError: If ``gt_bboxes_ignore`` is not None.
        """
        if gt_bboxes_ignore is not None:
            raise NotImplementedError

        num_gts = gt_bboxes.shape[0]
        num_bboxes = sum(x.shape[0] for x in mlvl_anchors)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = gt_bboxes.new_zeros((num_bboxes, ))
            assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ),
                                                   dtype=torch.long)
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = gt_bboxes.new_full((num_bboxes, ),
                                                     -1,
                                                     dtype=torch.long)
            return AssignResult(
                num_gts,
                assigned_gt_inds,
                max_overlaps,
                labels=assigned_labels)

        num_lvls = len(mlvl_anchors)
        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        # Pick one level per gt from the gt size (sqrt of its area) relative
        # to the anchor size of the first level, clamped to valid levels.
        scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                           (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
        min_anchor_size = scale.new_full(
            (1, ), float(anchor_scale * anchor_strides[0]))
        target_lvls = torch.floor(
            torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
        target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()

        # 1. assign 0 (negative) by default
        mlvl_assigned_gt_inds = []
        mlvl_ignore_flags = []
        for lvl in range(num_lvls):
            h, w = featmap_sizes[lvl]
            assert h * w == mlvl_anchors[lvl].shape[0]
            lvl_assigned_gt_inds = gt_bboxes.new_full((h * w, ),
                                                      0,
                                                      dtype=torch.long)
            # The flags must be boolean: a long tensor used as an index
            # selects positions 0/1 instead of masking the flagged anchors.
            lvl_ignore_flags = torch.zeros_like(
                lvl_assigned_gt_inds, dtype=torch.bool)
            mlvl_assigned_gt_inds.append(lvl_assigned_gt_inds)
            mlvl_ignore_flags.append(lvl_ignore_flags)

        for gt_id in range(num_gts):
            lvl = target_lvls[gt_id].item()
            featmap_size = featmap_sizes[lvl]
            stride = anchor_strides[lvl]
            anchors = mlvl_anchors[lvl]
            gt_bbox = gt_bboxes[gt_id, :4]

            # Compute regions
            ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
            ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)

            # 2. Assign -1 to ignore flags
            ignore_flags = anchor_ctr_inside_region_flags(
                anchors, stride, ignore_region)
            mlvl_assigned_gt_inds[lvl][ignore_flags] = -1

            # 3. Assign gt_bboxes to pos flags
            pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
                                                       ctr_region)
            mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1

            # 4. Record ignore flags on the adjacent (lower and upper)
            #    levels; they are applied after all gts are processed.
            for adj_lvl in (lvl - 1, lvl + 1):
                if not 0 <= adj_lvl < num_lvls:
                    continue
                adj_stride = anchor_strides[adj_lvl]
                adj_ignore_region = calc_region(gt_bbox, r2, adj_stride,
                                                featmap_sizes[adj_lvl])
                adj_ignore_flags = anchor_ctr_inside_region_flags(
                    mlvl_anchors[adj_lvl], adj_stride, adj_ignore_region)
                mlvl_ignore_flags[adj_lvl][adj_ignore_flags] = True

        # 4. (cont.) Assign -1 to ignore adjacent lvl
        for lvl in range(num_lvls):
            mlvl_assigned_gt_inds[lvl][mlvl_ignore_flags[lvl]] = -1

        # 5. Assign -1 to anchor outside of image
        flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
        flat_anchors = torch.cat(mlvl_anchors)
        flat_valid_flags = torch.cat(mlvl_valid_flags)
        assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
                flat_valid_flags.shape[0])
        inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
                                           img_meta['img_shape'],
                                           allowed_border)
        outside_flags = ~inside_flags
        flat_assigned_gt_inds[outside_flags] = -1

        if gt_labels is not None:
            assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
            # Positives are taken from the flattened result over all levels
            # so the mask matches ``assigned_labels`` in length.
            pos_flags = flat_assigned_gt_inds > 0
            assigned_labels[pos_flags] = gt_labels[
                flat_assigned_gt_inds[pos_flags] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)
