# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.ops import DeformConv2d
from mmengine.model import normal_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.utils import (ConfigType, InstanceList, OptInstanceList,
                         OptMultiConfig)
from ..utils import multi_apply
from .corner_head import CornerHead


@MODELS.register_module()
class CentripetalHead(CornerHead):
    """Head of CentripetalNet: Pursuing High-quality Keypoint Pairs for Object
    Detection.

    CentripetalHead inherits from :class:`CornerHead`. It removes the
    embedding branch and adds guiding shift and centripetal shift branches.
    More details can be found in the `paper
    <https://arxiv.org/abs/2003.09119>`_ .

    The arguments from ``num_classes`` to ``loss_offset`` below are forwarded
    unchanged to :class:`CornerHead` through ``*args`` / ``**kwargs``.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_feat_levels (int): Levels of feature from the previous module.
            2 for HourglassNet-104 and 1 for HourglassNet-52. HourglassNet-104
            outputs the final feature and intermediate supervision feature and
            HourglassNet-52 only outputs the final feature. Defaults to 2.
        corner_emb_channels (int): Channel of embedding vector. Defaults to 1.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config.
            Useless in CornerHead, but we keep this variable for
            SingleStageDetector.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            CornerHead.
        loss_heatmap (:obj:`ConfigDict` or dict): Config of corner heatmap
            loss. Defaults to GaussianFocalLoss.
        loss_embedding (:obj:`ConfigDict` or dict): Config of corner embedding
            loss. Defaults to AssociativeEmbeddingLoss.
        loss_offset (:obj:`ConfigDict` or dict): Config of corner offset loss.
            Defaults to SmoothL1Loss.
        centripetal_shift_channels (int): Output channels of the centripetal
            shift branches. Only 2 is supported. Defaults to 2.
        guiding_shift_channels (int): Output channels of the guiding shift
            branches. Only 2 is supported. Defaults to 2.
        feat_adaption_conv_kernel (int): Kernel size of the feature adaption
            deformable convs. Must be odd: the convs use ``kernel // 2``
            padding so the adapted feature keeps the spatial size of its
            input. Defaults to 3.
        loss_guiding_shift (:obj:`ConfigDict` or dict): Config of
            guiding shift loss. Defaults to SmoothL1Loss.
        loss_centripetal_shift (:obj:`ConfigDict` or dict): Config of
            centripetal shift loss. Defaults to SmoothL1Loss.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Must be None; weights are initialized by
            :meth:`init_weights` instead.
    """

    def __init__(self,
                 *args,
                 centripetal_shift_channels: int = 2,
                 guiding_shift_channels: int = 2,
                 feat_adaption_conv_kernel: int = 3,
                 loss_guiding_shift: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=0.05),
                 loss_centripetal_shift: ConfigType = dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1),
                 init_cfg: OptMultiConfig = None,
                 **kwargs) -> None:
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        assert centripetal_shift_channels == 2, (
            'CentripetalHead only support centripetal_shift_channels == 2')
        assert guiding_shift_channels == 2, (
            'CentripetalHead only support guiding_shift_channels == 2')
        # An even kernel cannot be padded symmetrically to keep H x W, which
        # would misalign the centripetal shift map with its targets.
        assert feat_adaption_conv_kernel % 2 == 1, (
            'CentripetalHead only support an odd feat_adaption_conv_kernel')
        # These attributes must exist before ``super().__init__`` because the
        # parent constructor calls ``self._init_layers`` (overridden below),
        # which reads them.
        self.centripetal_shift_channels = centripetal_shift_channels
        self.guiding_shift_channels = guiding_shift_channels
        self.feat_adaption_conv_kernel = feat_adaption_conv_kernel
        super().__init__(*args, init_cfg=init_cfg, **kwargs)
        self.loss_guiding_shift = MODELS.build(loss_guiding_shift)
        self.loss_centripetal_shift = MODELS.build(loss_centripetal_shift)

    def _make_feat_adaption(self) -> nn.Module:
        """Build one feature adaption deformable conv with "same" padding."""
        kernel = self.feat_adaption_conv_kernel
        return DeformConv2d(
            self.in_channels,
            self.in_channels,
            kernel,
            stride=1,
            padding=kernel // 2)

    def _make_dcn_offset(self) -> nn.Module:
        """Build one 1x1 conv mapping a guiding shift to DCN sampling
        offsets (``kernel**2 * guiding_shift_channels`` output channels)."""
        return ConvModule(
            self.guiding_shift_channels,
            self.feat_adaption_conv_kernel**2 * self.guiding_shift_channels,
            1,
            bias=False,
            act_cfg=None)

    def _init_centripetal_layers(self) -> None:
        """Initialize centripetal layers.

        Including feature adaption deform convs (feat_adaption), deform offset
        prediction convs (dcn_off), guiding shift (guiding_shift) and
        centripetal shift (centripetal_shift). Each branch has two parts:
        prefix `tl_` for top-left and `br_` for bottom-right. One module per
        feature level is appended to every list.
        """
        self.tl_feat_adaption = nn.ModuleList()
        self.br_feat_adaption = nn.ModuleList()
        self.tl_dcn_offset = nn.ModuleList()
        self.br_dcn_offset = nn.ModuleList()
        self.tl_guiding_shift = nn.ModuleList()
        self.br_guiding_shift = nn.ModuleList()
        self.tl_centripetal_shift = nn.ModuleList()
        self.br_centripetal_shift = nn.ModuleList()

        # Modules are constructed in a fixed order (feat adaption, guiding
        # shift, dcn offset, centripetal shift) so that seeded random
        # initialization is reproducible.
        for _ in range(self.num_feat_levels):
            self.tl_feat_adaption.append(self._make_feat_adaption())
            self.br_feat_adaption.append(self._make_feat_adaption())

            self.tl_guiding_shift.append(
                self._make_layers(
                    out_channels=self.guiding_shift_channels,
                    in_channels=self.in_channels))
            self.br_guiding_shift.append(
                self._make_layers(
                    out_channels=self.guiding_shift_channels,
                    in_channels=self.in_channels))

            self.tl_dcn_offset.append(self._make_dcn_offset())
            self.br_dcn_offset.append(self._make_dcn_offset())

            self.tl_centripetal_shift.append(
                self._make_layers(
                    out_channels=self.centripetal_shift_channels,
                    in_channels=self.in_channels))
            self.br_centripetal_shift.append(
                self._make_layers(
                    out_channels=self.centripetal_shift_channels,
                    in_channels=self.in_channels))

    def _init_layers(self) -> None:
        """Initialize layers for CentripetalHead.

        Including two parts: CornerHead layers and CentripetalHead layers
        """
        super()._init_layers()  # using _init_layers in CornerHead
        self._init_centripetal_layers()

    def init_weights(self) -> None:
        """Initialize weights of the head.

        Runs :meth:`CornerHead.init_weights` first, then re-initializes the
        centripetal branches of every feature level:

        - feature adaption deform convs: normal init with std 0.01;
        - DCN offset convs: normal init with std 0.1;
        - guiding / centripetal shift convs: their default
          ``reset_parameters``.
        """
        super().init_weights()
        for i in range(self.num_feat_levels):
            normal_init(self.tl_feat_adaption[i], std=0.01)
            normal_init(self.br_feat_adaption[i], std=0.01)
            normal_init(self.tl_dcn_offset[i].conv, std=0.1)
            normal_init(self.br_dcn_offset[i].conv, std=0.1)
            shift_branches = (self.tl_guiding_shift[i],
                              self.br_guiding_shift[i],
                              self.tl_centripetal_shift[i],
                              self.br_centripetal_shift[i])
            for branch in shift_branches:
                for layer in branch:
                    layer.conv.reset_parameters()

    def forward_single(self, x: Tensor, lvl_ind: int) -> List[Tensor]:
        """Forward feature of a single level.

        Args:
            x (Tensor): Feature of a single level.
            lvl_ind (int): Level index of current feature.

        Returns:
            list[Tensor]: CentripetalHead's outputs for the current feature
            level, in the following order:

                - tl_heat (Tensor): Predicted top-left corner heatmap.
                - br_heat (Tensor): Predicted bottom-right corner heatmap.
                - tl_off (Tensor): Predicted top-left offset heatmap.
                - br_off (Tensor): Predicted bottom-right offset heatmap.
                - tl_guiding_shift (Tensor): Predicted top-left guiding shift
                  heatmap.
                - br_guiding_shift (Tensor): Predicted bottom-right guiding
                  shift heatmap.
                - tl_centripetal_shift (Tensor): Predicted top-left centripetal
                  shift heatmap.
                - br_centripetal_shift (Tensor): Predicted bottom-right
                  centripetal shift heatmap.
        """
        # The two embedding outputs of CornerHead are discarded; the pooled
        # corner features are requested so the centripetal branches can use
        # them.
        tl_heat, br_heat, _, _, tl_off, br_off, tl_pool, br_pool = super(
        ).forward_single(
            x, lvl_ind, return_pool=True)

        tl_guiding_shift = self.tl_guiding_shift[lvl_ind](tl_pool)
        br_guiding_shift = self.br_guiding_shift[lvl_ind](br_pool)

        # The DCN offsets are predicted from a detached guiding shift, so the
        # centripetal loss does not back-propagate into the guiding branch.
        tl_dcn_offset = self.tl_dcn_offset[lvl_ind](tl_guiding_shift.detach())
        br_dcn_offset = self.br_dcn_offset[lvl_ind](br_guiding_shift.detach())

        tl_feat_adaption = self.tl_feat_adaption[lvl_ind](tl_pool,
                                                          tl_dcn_offset)
        br_feat_adaption = self.br_feat_adaption[lvl_ind](br_pool,
                                                          br_dcn_offset)

        tl_centripetal_shift = self.tl_centripetal_shift[lvl_ind](
            tl_feat_adaption)
        br_centripetal_shift = self.br_centripetal_shift[lvl_ind](
            br_feat_adaption)

        result_list = [
            tl_heat, br_heat, tl_off, br_off, tl_guiding_shift,
            br_guiding_shift, tl_centripetal_shift, br_centripetal_shift
        ]
        return result_list

    def loss_by_feat(
            self,
            tl_heats: List[Tensor],
            br_heats: List[Tensor],
            tl_offs: List[Tensor],
            br_offs: List[Tensor],
            tl_guiding_shifts: List[Tensor],
            br_guiding_shifts: List[Tensor],
            tl_centripetal_shifts: List[Tensor],
            br_centripetal_shifts: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            tl_heats (list[Tensor]): Top-left corner heatmaps for each level
                with shape (N, num_classes, H, W).
            br_heats (list[Tensor]): Bottom-right corner heatmaps for each
                level with shape (N, num_classes, H, W).
            tl_offs (list[Tensor]): Top-left corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            br_offs (list[Tensor]): Bottom-right corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
                level with shape (N, guiding_shift_channels, H, W).
            br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
                each level with shape (N, guiding_shift_channels, H, W).
            tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
                for each level with shape (N, centripetal_shift_channels, H,
                W).
            br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
                shifts for each level with shape (N,
                centripetal_shift_channels, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Specify which bounding boxes can be ignored when computing
                the loss. Not used by this head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. Containing the
            following losses:

                - det_loss (list[Tensor]): Corner keypoint losses of all
                  feature levels.
                - off_loss (list[Tensor]): Corner offset losses of all feature
                  levels.
                - guiding_loss (list[Tensor]): Guiding shift losses of all
                  feature levels.
                - centripetal_loss (list[Tensor]): Centripetal shift losses of
                  all feature levels.
        """
        gt_bboxes = [
            gt_instances.bboxes for gt_instances in batch_gt_instances
        ]
        gt_labels = [
            gt_instances.labels for gt_instances in batch_gt_instances
        ]

        # Targets are built once at the resolution of the last level and
        # reused for every level.
        targets = self.get_targets(
            gt_bboxes,
            gt_labels,
            tl_heats[-1].shape,
            batch_img_metas[0]['batch_input_shape'],
            with_corner_emb=self.with_corner_emb,
            with_guiding_shift=True,
            with_centripetal_shift=True)
        mlvl_targets = [targets for _ in range(self.num_feat_levels)]
        [det_losses, off_losses, guiding_losses, centripetal_losses
         ] = multi_apply(self.loss_by_feat_single, tl_heats, br_heats, tl_offs,
                         br_offs, tl_guiding_shifts, br_guiding_shifts,
                         tl_centripetal_shifts, br_centripetal_shifts,
                         mlvl_targets)
        loss_dict = dict(
            det_loss=det_losses,
            off_loss=off_losses,
            guiding_loss=guiding_losses,
            centripetal_loss=centripetal_losses)
        return loss_dict

    def loss_by_feat_single(self, tl_hmp: Tensor, br_hmp: Tensor,
                            tl_off: Tensor, br_off: Tensor,
                            tl_guiding_shift: Tensor, br_guiding_shift: Tensor,
                            tl_centripetal_shift: Tensor,
                            br_centripetal_shift: Tensor,
                            targets: dict) -> Tuple[Tensor, ...]:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            tl_hmp (Tensor): Top-left corner heatmap for current level with
                shape (N, num_classes, H, W).
            br_hmp (Tensor): Bottom-right corner heatmap for current level with
                shape (N, num_classes, H, W).
            tl_off (Tensor): Top-left corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            br_off (Tensor): Bottom-right corner offset for current level with
                shape (N, corner_offset_channels, H, W).
            tl_guiding_shift (Tensor): Top-left guiding shift for current level
                with shape (N, guiding_shift_channels, H, W).
            br_guiding_shift (Tensor): Bottom-right guiding shift for current
                level with shape (N, guiding_shift_channels, H, W).
            tl_centripetal_shift (Tensor): Top-left centripetal shift for
                current level with shape (N, centripetal_shift_channels, H, W).
            br_centripetal_shift (Tensor): Bottom-right centripetal shift for
                current level with shape (N, centripetal_shift_channels, H, W).
            targets (dict): Corner target generated by `get_targets`. It is
                not modified.

        Returns:
            tuple[torch.Tensor]: Losses of the head's different branches
            containing the following losses:

                - det_loss (Tensor): Corner keypoint loss.
                - off_loss (Tensor): Corner offset loss.
                - guiding_loss (Tensor): Guiding shift loss.
                - centripetal_loss (Tensor): Centripetal shift loss.
        """
        # This head has no embedding branch, so the embedding target handed
        # to CornerHead's loss is overridden with None. Work on a shallow
        # copy: ``loss_by_feat`` passes the very same dict to every level.
        targets = dict(targets)
        targets['corner_embedding'] = None

        det_loss, _, _, off_loss = super().loss_by_feat_single(
            tl_hmp, br_hmp, None, None, tl_off, br_off, targets)

        gt_tl_guiding_shift = targets['topleft_guiding_shift']
        gt_br_guiding_shift = targets['bottomright_guiding_shift']
        gt_tl_centripetal_shift = targets['topleft_centripetal_shift']
        gt_br_centripetal_shift = targets['bottomright_centripetal_shift']

        gt_tl_heatmap = targets['topleft_heatmap']
        gt_br_heatmap = targets['bottomright_heatmap']
        # We only compute the shift losses at the real corner positions.
        # The value of a real corner is 1 in the heatmap ground truth.
        # The mask is computed in class agnostic mode (summed over the class
        # dimension) and has shape (N, 1, H, W).
        tl_mask = gt_tl_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
            gt_tl_heatmap)
        br_mask = gt_br_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
            gt_br_heatmap)

        # Guiding shift loss, averaged over the number of real corners.
        tl_guiding_loss = self.loss_guiding_shift(
            tl_guiding_shift,
            gt_tl_guiding_shift,
            tl_mask,
            avg_factor=tl_mask.sum())
        br_guiding_loss = self.loss_guiding_shift(
            br_guiding_shift,
            gt_br_guiding_shift,
            br_mask,
            avg_factor=br_mask.sum())
        guiding_loss = (tl_guiding_loss + br_guiding_loss) / 2.0
        # Centripetal shift loss, averaged over the number of real corners.
        tl_centripetal_loss = self.loss_centripetal_shift(
            tl_centripetal_shift,
            gt_tl_centripetal_shift,
            tl_mask,
            avg_factor=tl_mask.sum())
        br_centripetal_loss = self.loss_centripetal_shift(
            br_centripetal_shift,
            gt_br_centripetal_shift,
            br_mask,
            avg_factor=br_mask.sum())
        centripetal_loss = (tl_centripetal_loss + br_centripetal_loss) / 2.0

        return det_loss, off_loss, guiding_loss, centripetal_loss

    def predict_by_feat(self,
                        tl_heats: List[Tensor],
                        br_heats: List[Tensor],
                        tl_offs: List[Tensor],
                        br_offs: List[Tensor],
                        tl_guiding_shifts: List[Tensor],
                        br_guiding_shifts: List[Tensor],
                        tl_centripetal_shifts: List[Tensor],
                        br_centripetal_shifts: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Only the outputs of the last feature level are decoded.

        Args:
            tl_heats (list[Tensor]): Top-left corner heatmaps for each level
                with shape (N, num_classes, H, W).
            br_heats (list[Tensor]): Bottom-right corner heatmaps for each
                level with shape (N, num_classes, H, W).
            tl_offs (list[Tensor]): Top-left corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            br_offs (list[Tensor]): Bottom-right corner offsets for each level
                with shape (N, corner_offset_channels, H, W).
            tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
                level with shape (N, guiding_shift_channels, H, W). Useless in
                this function, we keep this arg because it's the raw output
                from CentripetalHead.
            br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
                each level with shape (N, guiding_shift_channels, H, W).
                Useless in this function, we keep this arg because it's the
                raw output from CentripetalHead.
            tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
                for each level with shape (N, centripetal_shift_channels, H,
                W).
            br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
                shifts for each level with shape (N,
                centripetal_shift_channels, H, W).
            batch_img_metas (list[dict], optional): Batch image meta info.
                Although it defaults to None, it is required in practice: its
                length must equal the batch size N.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes (Tensor): Has a shape (num_instances, 4),
                  the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(
            batch_img_metas)
        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            # Slice with ``img_id:img_id + 1`` to keep the batch dimension.
            result_list.append(
                self._predict_by_feat_single(
                    tl_heats[-1][img_id:img_id + 1, :],
                    br_heats[-1][img_id:img_id + 1, :],
                    tl_offs[-1][img_id:img_id + 1, :],
                    br_offs[-1][img_id:img_id + 1, :],
                    img_meta,
                    tl_emb=None,
                    br_emb=None,
                    tl_centripetal_shift=tl_centripetal_shifts[-1][
                        img_id:img_id + 1, :],
                    br_centripetal_shift=br_centripetal_shifts[-1][
                        img_id:img_id + 1, :],
                    rescale=rescale,
                    with_nms=with_nms))

        return result_list
