# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, List
import torch
import torch.nn.functional as F
from torch import nn

from fsdet.layers import ShapeSpec
from fsdet.utils.registry import Registry

from ..anchor_generator import build_anchor_generator
from ..box_regression import Box2BoxTransform
from ..matcher import Matcher
from .build import PROPOSAL_GENERATOR_REGISTRY
from .rpn_outputs import RPNOutputs, find_top_rpn_proposals

# Registry of RPN head classes. Head classes add themselves with the
# `@RPN_HEAD_REGISTRY.register()` decorator, and `build_rpn_head` looks them up
# by the name stored in `cfg.MODEL.RPN.HEAD_NAME`.
RPN_HEAD_REGISTRY = Registry("RPN_HEAD")
"""
Registry for RPN heads, which take feature maps and perform
objectness classification and bounding box regression for anchors.
"""


def build_rpn_head(cfg, input_shape):
    """
    Instantiate the RPN head registered under `cfg.MODEL.RPN.HEAD_NAME`.

    Args:
        cfg: config node. Only `cfg.MODEL.RPN.HEAD_NAME` is read here; the whole
            config is forwarded to the head's constructor.
        input_shape: forwarded unchanged to the head's constructor.

    Returns:
        An instance of the registered head class.
    """
    head_cls = RPN_HEAD_REGISTRY.get(cfg.MODEL.RPN.HEAD_NAME)
    head = head_cls(cfg, input_shape)
    return head


def _init_weights(m):
    """
    Initialize an `nn.Conv2d` in place: N(0, 0.01) weights and zero bias.

    Meant to be passed to `nn.Module.apply`. Modules that are not `nn.Conv2d`
    are left untouched.

    Args:
        m (nn.Module): module to (possibly) initialize.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, std=0.01)
        # A conv built with bias=False has `m.bias is None`; skip it rather than
        # crashing inside nn.init.constant_.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


@RPN_HEAD_REGISTRY.register()
class StandardRPNHead(nn.Module):
    """
    RPN classification and regression heads. Uses a 3x3 conv to produce a shared
    hidden state from which one 1x1 conv predicts objectness logits for each anchor
    and a second 1x1 conv predicts bounding-box deltas specifying how to deform
    each anchor into an object proposal.

    When `cfg.MODEL.RPN.RPN_TYPE` is non-empty ("multi-RPN" mode), an additional
    group of `cfg.MODEL.RPN.NUM_RPNS` 1x1 objectness convs is built. At every
    anchor position the final objectness logit is taken from the RPN whose
    sigmoid score is closest to 0 or 1, and the per-RPN logits are also returned.
    """

    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        super().__init__()

        # Standard RPN is shared across levels:
        in_channels = [s.channels for s in input_shape]
        assert len(set(in_channels)) == 1, "Each level must have the same channel!"
        in_channels = in_channels[0]

        # RPNHead should take the same input as anchor generator
        # NOTE: it assumes that creating an anchor generator does not have unwanted side effect.
        anchor_generator = build_anchor_generator(cfg, input_shape)
        num_cell_anchors = anchor_generator.num_cell_anchors
        box_dim = anchor_generator.box_dim
        assert (
            len(set(num_cell_anchors)) == 1
        ), "Each level must have the same number of cell anchors"
        # Config-dependent; the original author noted a value of 3 for their VOC setup.
        num_cell_anchors = num_cell_anchors[0]

        # Use exactly the condition forward() uses, so the objectness group always
        # exists whenever forward() needs it (previously only 'score_conv' built it,
        # and any other non-empty RPN_TYPE crashed in forward()).
        use_rpn_group = cfg.MODEL.RPN.RPN_TYPE != ""

        # 3x3 conv for the hidden representation
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        # 1x1 conv for predicting objectness logits.
        # NOTE(review): in multi-RPN mode forward() never calls this layer; it is kept
        # for checkpoint compatibility. Unused parameters may need
        # find_unused_parameters=True under DistributedDataParallel — confirm.
        self.objectness_logits = nn.Conv2d(in_channels, num_cell_anchors, kernel_size=1, stride=1)
        # 1x1 conv for predicting box2box transform deltas
        self.anchor_deltas = nn.Conv2d(
            in_channels, num_cell_anchors * box_dim, kernel_size=1, stride=1
        )
        if use_rpn_group:
            self.objectness_logits_group = nn.ModuleList(
                [
                    nn.Conv2d(in_channels, num_cell_anchors, kernel_size=1, stride=1)
                    for _ in range(cfg.MODEL.RPN.NUM_RPNS)
                ]
            )

        for layer in [self.conv, self.objectness_logits, self.anchor_deltas]:
            nn.init.normal_(layer.weight, std=0.01)
            nn.init.constant_(layer.bias, 0)
        # Initialize the group once. This call used to sit inside the loop above
        # and re-initialized the group on every iteration.
        if use_rpn_group:
            self.objectness_logits_group.apply(_init_weights)

        self.cfg = cfg
        self.num_cell_anchors = num_cell_anchors

    @staticmethod
    def _select_most_confident(stacked_logits):
        """
        At every position, pick the logit of the RPN whose probability is closest
        to 0 or 1.

        Args:
            stacked_logits (Tensor): per-RPN logits stacked on the last dim,
                shape (..., NUM_RPNS).

        Returns:
            Tensor of shape `stacked_logits.shape[:-1]` holding the selected logits.
            Gradients flow only into the selected entries.
        """
        probs = torch.sigmoid(stacked_logits)
        # Distance from each probability to the nearer extreme (0 or 1).
        dist_to_extreme = torch.min(probs, 1 - probs)
        indices = torch.min(dist_to_extreme, -1)[1]
        return torch.gather(stacked_logits, -1, indices.unsqueeze(-1)).squeeze(-1)

    def forward(self, features):
        """
        Args:
            features (list[Tensor]): list of feature maps, one per level, each fed
                to `self.conv` (an nn.Conv2d).

        Returns:
            pred_objectness_logits (list[Tensor]): per level, objectness logits of
                shape (N, A, Hi, Wi) with A = num_cell_anchors. In multi-RPN mode
                this holds the logit of the most confident RPN.
            pred_anchor_deltas (list[Tensor]): per level, the `anchor_deltas` output.
            pred_objectness_logits_all (list[list[Tensor]]): multi-RPN mode only;
                indexed as [rpn_index][level], each the output of one group conv.
        """
        use_rpn_group = self.cfg.MODEL.RPN.RPN_TYPE != ""
        pred_objectness_logits = []
        pred_anchor_deltas = []
        if use_rpn_group:
            pred_objectness_logits_all = [[] for _ in range(self.cfg.MODEL.RPN.NUM_RPNS)]

        for x in features:
            t = F.relu(self.conv(x))
            if use_rpn_group:
                for i, rpn_logits in enumerate(pred_objectness_logits_all):
                    rpn_logits.append(self.objectness_logits_group[i](t))
            else:
                pred_objectness_logits.append(self.objectness_logits(t))
            pred_anchor_deltas.append(self.anchor_deltas(t))

        if not use_rpn_group:
            return pred_objectness_logits, pred_anchor_deltas

        # RPN selection, level by level: stack the per-RPN logits on a new last
        # dim and keep the most confident RPN's logit at each position.
        for level in range(len(features)):
            stacked = torch.stack(
                [rpn_logits[level] for rpn_logits in pred_objectness_logits_all], -1
            )
            pred_objectness_logits.append(self._select_most_confident(stacked))
        return pred_objectness_logits, pred_anchor_deltas, pred_objectness_logits_all


@PROPOSAL_GENERATOR_REGISTRY.register()
class RPN(nn.Module):
    """
    Region Proposal Network, introduced by the Faster R-CNN paper.

    If `cfg.MODEL.RPN.RPN_TYPE` is non-empty, the RPN head is expected to return
    a third value (per-RPN objectness logits). That value is passed to
    `RPNOutputs` as one extra trailing argument.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()

        # fmt: off
        self.min_box_side_len        = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
        self.in_features             = cfg.MODEL.RPN.IN_FEATURES
        self.nms_thresh              = cfg.MODEL.RPN.NMS_THRESH
        self.batch_size_per_image    = cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE
        self.positive_fraction       = cfg.MODEL.RPN.POSITIVE_FRACTION
        self.smooth_l1_beta          = cfg.MODEL.RPN.SMOOTH_L1_BETA
        self.loss_weight             = cfg.MODEL.RPN.LOSS_WEIGHT
        # fmt: on

        # Map from self.training state to train/test settings
        self.pre_nms_topk = {
            True: cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN,
            False: cfg.MODEL.RPN.PRE_NMS_TOPK_TEST,
        }
        self.post_nms_topk = {
            True: cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN,
            False: cfg.MODEL.RPN.POST_NMS_TOPK_TEST,
        }
        self.boundary_threshold = cfg.MODEL.RPN.BOUNDARY_THRESH

        self.anchor_generator = build_anchor_generator(
            cfg, [input_shape[f] for f in self.in_features]
        )
        self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS)
        self.anchor_matcher = Matcher(
            cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS, allow_low_quality_matches=True
        )
        self.rpn_head = build_rpn_head(cfg, [input_shape[f] for f in self.in_features])
        self.cfg = cfg

    def forward(self, images, features, gt_instances=None):
        """
        Args:
            images (ImageList): input images of length `N`
            features (dict[str: Tensor]): input data as a mapping from feature
                map name to tensor. Axis 0 represents the number of images `N` in
                the input data; axes 1-3 are channels, height, and width, which may
                vary between feature maps (e.g., if a feature pyramid is used).
            gt_instances (list[Instances], optional): a length `N` list of `Instances`s.
                Each `Instances` stores ground-truth instances for the corresponding image.

        Returns:
            proposals (list[Instances]): computed without gradient tracking, each
                sorted by descending `objectness_logits`.
            losses (dict[str, Tensor]): empty when not training; otherwise
                `RPNOutputs.losses()` with every value scaled by
                `cfg.MODEL.RPN.LOSS_WEIGHT`.
        """
        gt_boxes = [x.gt_boxes for x in gt_instances] if gt_instances is not None else None
        del gt_instances
        features = [features[f] for f in self.in_features]

        use_rpn_group = self.cfg.MODEL.RPN.RPN_TYPE != ""
        # In multi-RPN mode the head also returns the per-RPN logits, which
        # RPNOutputs receives as one extra trailing argument.
        extra_output_args = ()
        if use_rpn_group:
            pred_objectness_logits, pred_anchor_deltas, pred_objectness_logits_all = (
                self.rpn_head(features)
            )
            extra_output_args = (pred_objectness_logits_all,)
        else:
            pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features)
        anchors = self.anchor_generator(features)
        # TODO: The anchors only depend on the feature map shape; there's probably
        # an opportunity for some optimizations (e.g., caching anchors).
        outputs = RPNOutputs(
            self.cfg,
            self.training,
            self.box2box_transform,
            self.anchor_matcher,
            self.batch_size_per_image,
            self.positive_fraction,
            images,
            pred_objectness_logits,
            pred_anchor_deltas,
            anchors,
            self.boundary_threshold,
            gt_boxes,
            self.smooth_l1_beta,
            *extra_output_args,
        )

        if self.training:
            losses = {k: v * self.loss_weight for k, v in outputs.losses().items()}
        else:
            losses = {}

        with torch.no_grad():
            # Find the top proposals by applying NMS and removing boxes that
            # are too small. The proposals are treated as fixed for approximate
            # joint training with roi heads. This approach ignores the derivative
            # w.r.t. the proposal boxes’ coordinates that are also network
            # responses, so is approximate.
            proposals = find_top_rpn_proposals(
                outputs.predict_proposals(),
                outputs.predict_objectness_logits(),
                images,
                self.nms_thresh,
                self.pre_nms_topk[self.training],
                self.post_nms_topk[self.training],
                self.min_box_side_len,
                self.training,
            )
            # For RPN-only models, the proposals are the final output and we return them in
            # high-to-low confidence order.
            # For end-to-end models, the RPN proposals are an intermediate state
            # and this sorting is actually not needed. But the cost is negligible.
            inds = [p.objectness_logits.sort(descending=True)[1] for p in proposals]
            proposals = [p[ind] for p, ind in zip(proposals, inds)]

        return proposals, losses
