# Copyright (c) Facebook, Inc. and its affiliates.

import logging
from typing import List, Optional, Tuple
import torch
from fvcore.nn import sigmoid_focal_loss_jit
from torch import Tensor, nn
from torch.nn import functional as F

from detectron2.layers import ShapeSpec, batched_nms
from detectron2.structures import Boxes, ImageList, Instances, pairwise_point_box_distance
from detectron2.utils.events import get_event_storage

from ..anchor_generator import DefaultAnchorGenerator
from ..backbone import Backbone
from ..box_regression import Box2BoxTransformLinear, _dense_box_regression_loss
from .dense_detector import DenseDetector
from .retinanet import RetinaNetHead

# Names exported by ``from <this module> import *``.
__all__ = ["FCOS"]


# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class FCOS(DenseDetector):
    """
    Implement FCOS in :paper:`fcos`.

    Every feature-map location carries a single anchor "point", represented as a
    square box whose side equals the stride of that feature level. Training labels
    are assigned by :meth:`match_anchors` / :meth:`label_anchors`, and the loss has
    three terms: classification, box regression and centerness (see :meth:`losses`).
    """

    def __init__(
        self,
        *,
        backbone: Backbone,
        head: nn.Module,
        head_in_features: Optional[List[str]] = None,
        box2box_transform=None,
        num_classes,
        center_sampling_radius: float = 1.5,
        focal_loss_alpha=0.25,
        focal_loss_gamma=2.0,
        test_score_thresh=0.2,
        test_topk_candidates=1000,
        test_nms_thresh=0.6,
        max_detections_per_image=100,
        pixel_mean,
        pixel_std,
    ):
        """
        Args:
            box2box_transform: transform between anchors and box regression deltas.
                When None, ``Box2BoxTransformLinear(normalize_by_size=True)`` is used.
            center_sampling_radius: radius of the "center" of a groundtruth box,
                within which all anchor points are labeled positive.
                It is expressed in units of the anchor size (see :meth:`match_anchors`).
            Other arguments mean the same as in :class:`RetinaNet`.
        """
        super().__init__(
            backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std
        )

        self.num_classes = num_classes

        # FCOS uses one anchor point per location.
        # We represent the anchor point by a box whose size equals the anchor stride.
        feature_shapes = backbone.output_shape()
        fpn_strides = [feature_shapes[k].stride for k in self.head_in_features]
        self.anchor_generator = DefaultAnchorGenerator(
            sizes=[[k] for k in fpn_strides], aspect_ratios=[1.0], strides=fpn_strides
        )

        # FCOS parameterizes box regression by a linear transform,
        # where predictions are normalized by anchor stride (equal to anchor size).
        if box2box_transform is None:
            box2box_transform = Box2BoxTransformLinear(normalize_by_size=True)
        self.box2box_transform = box2box_transform

        self.center_sampling_radius = float(center_sampling_radius)

        # Loss parameters:
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma

        # Inference parameters:
        self.test_score_thresh = test_score_thresh
        self.test_topk_candidates = test_topk_candidates
        self.test_nms_thresh = test_nms_thresh
        self.max_detections_per_image = max_detections_per_image

    def forward_training(self, images, features, predictions, gt_instances):
        """
        Compute the training losses for one batch.

        Args:
            images: unused here; kept for the interface shared with the base class.
            features: per-level feature maps, passed to the anchor generator.
            predictions: raw head outputs (logits, box deltas, centerness).
            gt_instances: ground truth instances per image.

        Returns:
            dict: the loss dict produced by :meth:`losses`.
        """
        # Transpose the Hi*Wi*A dimension to the middle:
        pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions(
            predictions, [self.num_classes, 4, 1]
        )
        anchors = self.anchor_generator(features)
        gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)
        return self.losses(
            anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
        )

    @torch.no_grad()
    def match_anchors(
        self, anchors: List[Boxes], gt_instances: List[Instances]
    ) -> List[Tensor]:
        """
        Match anchors with ground truth boxes.

        An anchor matches a ground truth box when all of the following hold:
        its center lies within ``center_sampling_radius * anchor_size`` of the box
        center, its center is inside the box, and the box falls in the scale range
        assigned to the anchor's feature level. If several boxes qualify, the one
        with the smallest area wins.

        Args:
            anchors: #level boxes, from the highest resolution to lower resolution
            gt_instances: ground truth instances per image

        Returns:
            List[Tensor]:
                #image tensors, each is a vector of matched gt
                indices (or -1 for unmatched anchors) for all anchors.
                For an image without any ground truth box, every entry is -1.
        """
        num_anchors_per_level = [len(x) for x in anchors]
        anchors = Boxes.cat(anchors)  # Rx4
        anchor_centers = anchors.get_centers()  # Rx2
        anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0]  # R
        num_anchors = len(anchor_sizes)

        # Scale range handled by each anchor: (4 * size, 8 * size), except that the
        # first level has no lower bound and the last level has no upper bound.
        lower_bound = anchor_sizes * 4
        lower_bound[: num_anchors_per_level[0]] = 0
        upper_bound = anchor_sizes * 8
        # Use an explicit start index: a negative slice start would select the
        # whole tensor if the last level happened to contain zero anchors.
        upper_bound[num_anchors - num_anchors_per_level[-1] :] = float("inf")

        matched_indices = []
        for gt_per_image in gt_instances:
            if len(gt_per_image.gt_boxes) == 0:
                # Nothing to match against. The reduction over the gt dimension
                # below is undefined for N == 0, so short-circuit to "unmatched".
                matched_indices.append(
                    torch.full(
                        (num_anchors,), -1, dtype=torch.long, device=anchor_centers.device
                    )
                )
                continue

            gt_centers = gt_per_image.gt_boxes.get_centers()  # Nx2
            # FCOS with center sampling: anchor point must be close enough to gt center.
            pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
                dim=2
            ).values < self.center_sampling_radius * anchor_sizes[:, None]
            pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_per_image.gt_boxes)

            # The original FCOS anchor matching rule: anchor point must be inside gt
            pairwise_match &= pairwise_dist.min(dim=2).values > 0

            # Multilevel anchor matching in FCOS: each anchor is only responsible
            # for certain scale range.
            pairwise_dist = pairwise_dist.max(dim=2).values
            pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (
                pairwise_dist < upper_bound[:, None]
            )

            # Match the GT box with minimum area, if there are multiple GT matches.
            # Valid pairs get the score (1e8 - area), invalid pairs get 0, so the
            # per-anchor argmax picks the smallest valid box.
            gt_areas = gt_per_image.gt_boxes.area()  # N
            pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
            min_values, matched_idx = pairwise_match.max(dim=1)  # R, per-anchor match
            matched_idx[min_values < 1e-5] = -1  # Unmatched anchors are assigned -1

            matched_indices.append(matched_idx)
        return matched_indices

    @torch.no_grad()
    def label_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]):
        """
        Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS
        anchor matching rule.

        Unlike RetinaNet, there are no ignored anchors.

        Args:
            anchors: #level boxes, as accepted by :meth:`match_anchors`.
            gt_instances: ground truth instances per image.

        Returns:
            tuple:
                * list of #image label tensors of length R; unmatched anchors are
                  labeled ``num_classes`` (background).
                * list of #image :class:`Boxes` with R boxes each, holding the
                  matched ground truth box of every anchor. Entries of unmatched
                  anchors are placeholders and must be masked out by the caller.
                  For an image without ground truth boxes they are all zeros.
        """
        matched_indices = self.match_anchors(anchors, gt_instances)

        matched_labels, matched_boxes = [], []
        for gt_index, gt_per_image in zip(matched_indices, gt_instances):
            if len(gt_per_image.gt_boxes) == 0:
                # Indexing empty gt tensors with a clipped index of 0 would fail,
                # so build all-background targets directly.
                label = torch.full_like(gt_index, self.num_classes)
                matched_gt_boxes = Boxes(torch.zeros_like(Boxes.cat(anchors).tensor))
            else:
                # Unmatched anchors (-1) temporarily borrow gt 0; their label is
                # overwritten with background right after.
                clipped_index = gt_index.clip(min=0)
                label = gt_per_image.gt_classes[clipped_index]
                label[gt_index < 0] = self.num_classes  # background

                matched_gt_boxes = gt_per_image.gt_boxes[clipped_index]

            matched_labels.append(label)
            matched_boxes.append(matched_gt_boxes)
        return matched_labels, matched_boxes

    def losses(
        self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
    ):
        """
        This method is almost identical to :meth:`RetinaNet.losses`, with an extra
        centerness loss in the returned dict.

        Args:
            anchors: #level boxes.
            pred_logits, pred_anchor_deltas, pred_centerness: per-level predictions,
                as returned by ``_transpose_dense_predictions``.
            gt_labels, gt_boxes: per-image targets, as returned by :meth:`label_anchors`.

        Returns:
            dict[str, Tensor]: the keys are "loss_fcos_cls", "loss_fcos_loc" and
            "loss_fcos_ctr". Every term is a sum divided by the same normalizer.
        """
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, R)

        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
        # max(..., 1) keeps the normalizer update away from zero when a batch
        # has no positive anchor at all.
        normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 300)

        # classification and regression loss
        gt_labels_target = F.one_hot(gt_labels, num_classes=self.num_classes + 1)[
            :, :, :-1
        ]  # no loss for the last (background) class
        loss_cls = sigmoid_focal_loss_jit(
            torch.cat(pred_logits, dim=1),
            gt_labels_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )

        loss_box_reg = _dense_box_regression_loss(
            anchors,
            self.box2box_transform,
            pred_anchor_deltas,
            [x.tensor for x in gt_boxes],
            pos_mask,
            box_reg_loss_type="giou",
        )

        # Centerness is only supervised on positive anchors; targets of the other
        # anchors are discarded by the boolean indexing below.
        ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes)  # NxR
        pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2)  # NxR
        ctrness_loss = F.binary_cross_entropy_with_logits(
            pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum"
        )
        return {
            "loss_fcos_cls": loss_cls / normalizer,
            "loss_fcos_loc": loss_box_reg / normalizer,
            "loss_fcos_ctr": ctrness_loss / normalizer,
        }

    def compute_ctrness_targets(self, anchors: List[Boxes], gt_boxes: List[Boxes]) -> Tensor:
        """
        Compute the centerness target of every anchor w.r.t. its matched box.

        With the regression deltas of an anchor split into a (0, 2) pair and a
        (1, 3) pair, the target is
        ``sqrt(min(pair_a) / max(pair_a) * min(pair_b) / max(pair_b))``.

        Args:
            anchors: #level boxes.
            gt_boxes: #image :class:`Boxes`, one matched box per anchor.

        Returns:
            Tensor: NxR targets. Values are only meaningful for anchors that lie
            inside their matched box; callers are expected to mask the rest.
        """
        anchors = Boxes.cat(anchors).tensor  # Rx4
        reg_targets = [self.box2box_transform.get_deltas(anchors, m.tensor) for m in gt_boxes]
        reg_targets = torch.stack(reg_targets, dim=0)  # NxRx4
        if len(reg_targets) == 0:
            return reg_targets.new_zeros(len(reg_targets))
        left_right = reg_targets[:, :, [0, 2]]
        top_bottom = reg_targets[:, :, [1, 3]]
        ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
            top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]
        )
        return torch.sqrt(ctrness)

    def forward_inference(
        self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]]
    ):
        """
        Run inference on a batch.

        Args:
            images: the input batch; only ``image_sizes`` is read here.
            features: per-level feature maps, passed to the anchor generator.
            predictions: raw head outputs (logits, box deltas, centerness).

        Returns:
            list[Instances]: one result per image, from :meth:`inference_single_image`.
        """
        pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions(
            predictions, [self.num_classes, 4, 1]
        )
        anchors = self.anchor_generator(features)

        results: List[Instances] = []
        for img_idx, image_size in enumerate(images.image_sizes):
            scores_per_image = [
                # Multiply and sqrt centerness & classification scores
                # (See eqn. 4 in https://arxiv.org/abs/2006.09214)
                # NOTE: sigmoid_ is in-place and overwrites the selected slices of
                # pred_logits / pred_centerness.
                torch.sqrt(x[img_idx].sigmoid_() * y[img_idx].sigmoid_())
                for x, y in zip(pred_logits, pred_centerness)
            ]
            deltas_per_image = [x[img_idx] for x in pred_anchor_deltas]
            results_per_image = self.inference_single_image(
                anchors, scores_per_image, deltas_per_image, image_size
            )
            results.append(results_per_image)
        return results

    def inference_single_image(
        self,
        anchors: List[Boxes],
        box_cls: List[Tensor],
        box_delta: List[Tensor],
        image_size: Tuple[int, int],
    ):
        """
        Identical to :meth:`RetinaNet.inference_single_image`.

        Decodes the multi-level predictions with ``test_score_thresh`` and
        ``test_topk_candidates``, applies class-wise NMS with ``test_nms_thresh``,
        and keeps at most ``max_detections_per_image`` of the surviving detections.

        Args:
            anchors: #level boxes.
            box_cls: per-level scores of this image.
            box_delta: per-level box regression deltas of this image.
            image_size: size of this image, forwarded to the decoding step.

        Returns:
            Instances: the detections kept for this image.
        """
        pred = self._decode_multi_level_predictions(
            anchors,
            box_cls,
            box_delta,
            self.test_score_thresh,
            self.test_topk_candidates,
            image_size,
        )
        keep = batched_nms(
            pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh
        )
        return pred[keep[: self.max_detections_per_image]]


class FCOSHead(RetinaNetHead):
    """
    Prediction head of :paper:`fcos`.

    It is a :class:`RetinaNetHead` with exactly one anchor per location, extended
    with a single-channel centerness branch that reads the box-regression tower.
    """

    def __init__(self, *, input_shape: List[ShapeSpec], conv_dims: List[int], **kwargs):
        """
        Args:
            input_shape: one entry per input feature level.
            conv_dims: forwarded to :class:`RetinaNetHead`; its last entry is the
                number of input channels of the centerness conv.
            **kwargs: forwarded unchanged to :class:`RetinaNetHead`.
        """
        super().__init__(input_shape=input_shape, conv_dims=conv_dims, num_anchors=1, **kwargs)
        # Unlike original FCOS, we do not add an additional learnable scale layer
        # because it's found to have no benefits after normalizing regression targets by stride.
        self._num_features = len(input_shape)

        # 3x3 conv with stride 1 and padding 1, producing one centerness channel.
        tower_channels = conv_dims[-1]
        self.ctrness = nn.Conv2d(tower_channels, 1, kernel_size=3, stride=1, padding=1)
        nn.init.normal_(self.ctrness.weight, std=0.01)
        nn.init.constant_(self.ctrness.bias, 0)

    def forward(self, features):
        """
        Args:
            features: per-level feature maps; their number must equal the number
                of input shapes given at construction.

        Returns:
            tuple of three lists (class logits, box regression, centerness), each
            holding one tensor per feature level, in input order.
        """
        assert len(features) == self._num_features

        cls_outputs, reg_outputs, ctr_outputs = [], [], []
        for level_feature in features:
            cls_tower = self.cls_subnet(level_feature)
            cls_outputs.append(self.cls_score(cls_tower))

            # Box regression and centerness share the same tower output.
            reg_tower = self.bbox_subnet(level_feature)
            reg_outputs.append(self.bbox_pred(reg_tower))
            ctr_outputs.append(self.ctrness(reg_tower))
        return cls_outputs, reg_outputs, ctr_outputs
