# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Scale
from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2d
from mmengine.config import ConfigDict
from mmengine.model import BaseModel
from mmengine.structures import InstanceData
from torch import Tensor

try:
    from transformers import BertConfig
except ImportError:
    BertConfig = None

from mmdet.registry import MODELS
from mmdet.structures.bbox import cat_boxes
from mmdet.utils import InstanceList, OptInstanceList, reduce_mean
from ..utils import (BertEncoderLayer, VLFuse, filter_scores_and_topk,
                     permute_and_flatten, select_single_mlvl,
                     unpack_gt_instances)
from ..utils.vlfuse_helper import MAX_CLAMP_VALUE
from .atss_head import ATSSHead


def convert_grounding_to_cls_scores(logits: Tensor,
                                    positive_maps: List[dict]) -> Tensor:
    """Convert logits to class scores."""
    assert len(positive_maps) == logits.shape[0]  # batch size
    scores = torch.zeros(logits.shape[0], logits.shape[1],
                         len(positive_maps[0])).to(logits.device)
    if positive_maps is not None:
        if all(x == positive_maps[0] for x in positive_maps):
            # only need to compute once
            positive_map = positive_maps[0]
            for label_j in positive_map:
                scores[:, :, label_j - 1] = logits[:, :, torch.LongTensor(
                    positive_map[label_j])].mean(-1)
        else:
            for i, positive_map in enumerate(positive_maps):
                for label_j in positive_map:
                    scores[i, :, label_j - 1] = logits[
                        i, :, torch.LongTensor(positive_map[label_j])].mean(-1)
    return scores


class Conv3x3Norm(nn.Module):
    """Conv3x3 and norm."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int,
                 groups: int = 1,
                 use_dcn: bool = False,
                 norm_type: Optional[Union[Sequence, str]] = None):
        super().__init__()

        if use_dcn:
            self.conv = ModulatedDeformConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)
        else:
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                groups=groups)

        if isinstance(norm_type, Sequence):
            assert len(norm_type) == 2
            assert norm_type[0] == 'gn'
            gn_group = norm_type[1]
            norm_type = norm_type[0]

        if norm_type == 'bn':
            bn_op = nn.BatchNorm2d(out_channels)
        elif norm_type == 'gn':
            bn_op = nn.GroupNorm(
                num_groups=gn_group, num_channels=out_channels)
        if norm_type is not None:
            self.bn = bn_op
        else:
            self.bn = None

    def forward(self, x, **kwargs):
        x = self.conv(x, **kwargs)
        if self.bn:
            x = self.bn(x)
        return x


class DyReLU(nn.Module):
    """Dynamic ReLU."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 expand_ratio: int = 4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.expand_ratio = expand_ratio
        self.out_channels = out_channels

        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // expand_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // expand_ratio,
                      out_channels * self.expand_ratio),
            nn.Hardsigmoid(inplace=True))

    def forward(self, x) -> Tensor:
        x_out = x
        b, c, h, w = x.size()
        x = self.avg_pool(x).view(b, c)
        x = self.fc(x).view(b, -1, 1, 1)

        a1, b1, a2, b2 = torch.split(x, self.out_channels, dim=1)
        a1 = (a1 - 0.5) * 2 + 1.0
        a2 = (a2 - 0.5) * 2
        b1 = b1 - 0.5
        b2 = b2 - 0.5
        out = torch.max(x_out * a1 + b1, x_out * a2 + b2)
        return out


class DyConv(nn.Module):
    """Dynamic Convolution."""

    def __init__(self,
                 conv_func: Callable,
                 in_channels: int,
                 out_channels: int,
                 use_dyfuse: bool = True,
                 use_dyrelu: bool = False,
                 use_dcn: bool = False):
        super().__init__()
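        # dyconvs[1] operates on the current pyramid level, dyconvs[2]
        # (stride 2) downsamples the finer level below it, and dyconvs[0]
        # processes the coarser level before it is bilinearly upsampled in
        # forward().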
        self.dyconvs = nn.ModuleList()
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 1))
        self.dyconvs.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            self.attnconv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = nn.Hardsigmoid(inplace=True)
        else:
            self.attnconv = None

        if use_dyrelu:
            self.relu = DyReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_dcn:
            self.offset = nn.Conv2d(
                in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        for m in self.dyconvs.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
        if self.attnconv is not None:
            for m in self.attnconv.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, inputs: dict) -> dict:
        visual_feats = inputs['visual']

        out_vis_feats = []
        for level, feature in enumerate(visual_feats):

            offset_conv_args = {}
            if self.offset is not None:
                offset_mask = self.offset(feature)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, 18:, :, :].sigmoid()
                offset_conv_args = dict(offset=offset, mask=mask)

            temp_feats = [self.dyconvs[1](feature, **offset_conv_args)]

            if level > 0:
                temp_feats.append(self.dyconvs[2](visual_feats[level - 1],
                                                  **offset_conv_args))
            if level < len(visual_feats) - 1:
                temp_feats.append(
                    F.upsample_bilinear(
                        self.dyconvs[0](visual_feats[level + 1],
                                        **offset_conv_args),
                        size=[feature.size(2), feature.size(3)]))
            mean_feats = torch.mean(
                torch.stack(temp_feats), dim=0, keepdim=False)

            if self.attnconv is not None:
                attn_feat = []
                res_feat = []
                for feat in temp_feats:
                    res_feat.append(feat)
                    attn_feat.append(self.attnconv(feat))

                res_feat = torch.stack(res_feat)
                spa_pyr_attn = self.h_sigmoid(torch.stack(attn_feat))

                mean_feats = torch.mean(
                    res_feat * spa_pyr_attn, dim=0, keepdim=False)

            out_vis_feats.append(mean_feats)

        out_vis_feats = [self.relu(item) for item in out_vis_feats]

        features_dict = {'visual': out_vis_feats, 'lang': inputs['lang']}

        return features_dict


class VLFusionModule(BaseModel):
    """Visual-lang Fusion Module."""

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 num_base_priors: int,
                 early_fuse: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 use_dyrelu: bool = True,
                 use_dyfuse: bool = True,
                 use_dcn: bool = True,
                 use_checkpoint: bool = False,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        if BertConfig is None:
            raise RuntimeError(
                'transformers is not installed, please install it by: '
                'pip install transformers.')
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.num_base_priors = num_base_priors
        self.early_fuse = early_fuse
        self.num_dyhead_blocks = num_dyhead_blocks
        self.use_dyrelu = use_dyrelu
        self.use_dyfuse = use_dyfuse
        self.use_dcn = use_dcn
        self.use_checkpoint = use_checkpoint

        self.lang_cfg = BertConfig.from_pretrained(lang_model_name)
        self.lang_dim = self.lang_cfg.hidden_size
        self._init_layers()

    def _init_layers(self) -> None:
        """Initialize layers of the model."""
        bias_value = -math.log((1 - 0.01) / 0.01)

        dyhead_tower = []
        for i in range(self.num_dyhead_blocks):
            if self.early_fuse:
                # cross-modality fusion
                dyhead_tower.append(VLFuse(use_checkpoint=self.use_checkpoint))
                # lang branch
                dyhead_tower.append(
                    BertEncoderLayer(
                        self.lang_cfg,
                        clamp_min_for_underflow=True,
                        clamp_max_for_overflow=True))

            # vision branch
            dyhead_tower.append(
                DyConv(
                    lambda i, o, s: Conv3x3Norm(
                        i, o, s, use_dcn=self.use_dcn, norm_type=['gn', 16]),
                    self.in_channels if i == 0 else self.feat_channels,
                    self.feat_channels,
                    use_dyrelu=(self.use_dyrelu
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyrelu,
                    use_dyfuse=(self.use_dyfuse
                                and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dyfuse,
                    use_dcn=(self.use_dcn
                             and self.in_channels == self.feat_channels)
                    if i == 0 else self.use_dcn,
                ))

        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))

        self.bbox_pred = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 4, kernel_size=1)
        self.centerness = nn.Conv2d(
            self.feat_channels, self.num_base_priors * 1, kernel_size=1)
        self.dot_product_projection_text = nn.Linear(
            self.lang_dim,
            self.num_base_priors * self.feat_channels,
            bias=True)
        self.log_scale = nn.Parameter(torch.Tensor([0.0]), requires_grad=True)
        self.bias_lang = nn.Parameter(
            torch.zeros(self.lang_dim), requires_grad=True)
        self.bias0 = nn.Parameter(
            torch.Tensor([bias_value]), requires_grad=True)
        self.scales = nn.ModuleList([Scale(1.0) for _ in range(5)])

    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple:
        feat_inputs = {'visual': visual_feats, 'lang': language_feats}
        dyhead_tower = self.dyhead_tower(feat_inputs)

        if self.early_fuse:
            embedding = dyhead_tower['lang']['hidden']
        else:
            embedding = language_feats['embedded']

        embedding = F.normalize(embedding, p=2, dim=-1)
        dot_product_proj_tokens = self.dot_product_projection_text(
            embedding / 2.0)
        dot_product_proj_tokens_bias = torch.matmul(
            embedding, self.bias_lang) + self.bias0

        bbox_preds = []
        centerness = []
        cls_logits = []

        for i, feature in enumerate(visual_feats):
            visual = dyhead_tower['visual'][i]
            B, C, H, W = visual.shape
            bbox_pred = self.scales[i](self.bbox_pred(visual))
            bbox_preds.append(bbox_pred)
            centerness.append(self.centerness(visual))

            dot_product_proj_queries = permute_and_flatten(
                visual, B, self.num_base_priors, C, H, W)

            bias = dot_product_proj_tokens_bias.unsqueeze(1).repeat(
                1, self.num_base_priors, 1)
            dot_product_logit = (
                torch.matmul(dot_product_proj_queries,
                             dot_product_proj_tokens.transpose(-1, -2)) /
                self.log_scale.exp()) + bias
            dot_product_logit = torch.clamp(
                dot_product_logit, max=MAX_CLAMP_VALUE)
            dot_product_logit = torch.clamp(
                dot_product_logit, min=-MAX_CLAMP_VALUE)
            cls_logits.append(dot_product_logit)

        return bbox_preds, centerness, cls_logits


@MODELS.register_module()
class ATSSVLFusionHead(ATSSHead):
    """ATSS head with visual-language fusion module.

    Args:
        early_fuse (bool): Whether to fuse visual and language features.
            Defaults to False.
        use_checkpoint (bool): Whether to use checkpoint. Defaults to False.
        num_dyhead_blocks (int): Number of dynamic head blocks. Defaults to 6.
        lang_model_name (str): Name of the language model.
            Defaults to 'bert-base-uncased'.
    """

    def __init__(self,
                 *args,
                 early_fuse: bool = False,
                 use_checkpoint: bool = False,
                 num_dyhead_blocks: int = 6,
                 lang_model_name: str = 'bert-base-uncased',
                 init_cfg=None,
                 **kwargs):
        super().__init__(*args, **kwargs, init_cfg=init_cfg)
        self.head = VLFusionModule(
            in_channels=self.in_channels,
            feat_channels=self.feat_channels,
            num_base_priors=self.num_base_priors,
            early_fuse=early_fuse,
            use_checkpoint=use_checkpoint,
            num_dyhead_blocks=num_dyhead_blocks,
            lang_model_name=lang_model_name)
        self.text_masks = None

    def _init_layers(self) -> None:
        """No need to initialize the ATSS head layer."""
        pass

    def forward(self, visual_feats: Tuple[Tensor],
                language_feats: dict) -> Tuple[Tensor]:
        """Forward function."""
        bbox_preds, centerness, cls_logits = self.head(visual_feats,
                                                       language_feats)
        return cls_logits, bbox_preds, centerness

    def loss(self, visual_feats: Tuple[Tensor], language_feats: dict,
             batch_data_samples):
        """Calculate losses from a batch of inputs and data samples."""
        outputs = unpack_gt_instances(batch_data_samples)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = outputs

        outs = self(visual_feats, language_feats)
        self.text_masks = language_feats['masks']
        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              batch_gt_instances_ignore)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            centernesses: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            centernesses (list[Tensor]): Centerness for each scale level with
                shape (N, num_anchors * 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, avg_factor) = cls_reg_targets
        avg_factor = reduce_mean(
            torch.tensor(avg_factor, dtype=torch.float, device=device)).item()

        anchors = torch.cat(anchor_list, dim=1)
        labels = torch.cat(labels_list, dim=1)
        label_weights = torch.cat(label_weights_list, dim=1)
        bbox_targets = torch.cat(bbox_targets_list, dim=1)
        cls_scores = torch.cat(cls_scores, dim=1)

        centernesses_ = []
        bbox_preds_ = []
        for bbox_pred, centerness in zip(bbox_preds, centernesses):
            centernesses_.append(
                centerness.permute(0, 2, 3, 1).reshape(
                    cls_scores.size(0), -1, 1))
            bbox_preds_.append(
                bbox_pred.permute(0, 2, 3, 1).reshape(
                    cls_scores.size(0), -1, 4))
        bbox_preds = torch.cat(bbox_preds_, dim=1)
        centernesses = torch.cat(centernesses_, dim=1)

        losses_cls, losses_bbox, loss_centerness, bbox_avg_factor = \
            self._loss_by_feat(
                anchors,
                cls_scores,
                bbox_preds,
                centernesses,
                labels,
                label_weights,
                bbox_targets,
                avg_factor=avg_factor)

        bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item()
        losses_bbox = losses_bbox / bbox_avg_factor
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_centerness=loss_centerness)

    def _loss_by_feat(self, anchors: Tensor, cls_score: Tensor,
                      bbox_pred: Tensor, centerness: Tensor, labels: Tensor,
                      label_weights: Tensor, bbox_targets: Tensor,
                      avg_factor: float) -> dict:
        """Calculate the loss of all scale levels based on the features
        extracted by the detection head.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        anchors = anchors.reshape(-1, 4)
        # ===== this change =====
        pos_inds = (labels.sum(-1) > 0).reshape(-1)

        # Loss is not computed for the padded regions of the text.
        assert (self.text_masks.dim() == 2)
        text_mask = (self.text_masks > 0).unsqueeze(1)
        text_mask = text_mask.repeat(1, cls_score.size(1), 1)
        cls_score = torch.masked_select(cls_score, text_mask).contiguous()
        labels = torch.masked_select(labels, text_mask)
        label_weights = label_weights[..., None].repeat(
            1, 1, text_mask.size(-1))
        label_weights = torch.masked_select(label_weights, text_mask)

        bbox_pred = bbox_pred.reshape(-1, 4)
        centerness = centerness.reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # classification loss
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor)

        if pos_inds.sum() > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(
                pos_anchors, pos_bbox_targets)

            if torch.isnan(centerness_targets).any():
                print('=====Centerness includes NaN=====')
                mask = ~torch.isnan(centerness_targets)
                centerness_targets = centerness_targets[mask]
                pos_centerness = pos_centerness[mask]
                pos_anchors = pos_anchors[mask]
                pos_bbox_targets = pos_bbox_targets[mask]
                pos_bbox_pred = pos_bbox_pred[mask]

                if pos_bbox_targets.shape[0] == 0:
                    loss_bbox = bbox_pred.sum() * 0
                    loss_centerness = centerness.sum() * 0
                    centerness_targets = bbox_targets.new_tensor(0.)
                    # Every positive sample was dropped by the NaN filter
                    # above, so return zero-valued bbox and centerness losses
                    # that keep the computation graph well defined.
                    return loss_cls, loss_bbox, loss_centerness, \
                        centerness_targets.sum()

            # The decoding process takes the offset into consideration.
            pos_anchors[:, 2:] += 1
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchors, pos_bbox_pred)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_bbox_targets,
                weight=centerness_targets,
                avg_factor=1.0)

            # centerness loss
            loss_centerness = self.loss_centerness(
                pos_centerness, centerness_targets, avg_factor=avg_factor)
        else:
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            centerness_targets = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()

    def _get_targets_single(self,
                            flat_anchors: Tensor,
                            valid_flags: Tensor,
                            num_level_anchors: List[int],
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors, 4).
            valid_flags (Tensor): Multi-level valid flags of the image, which
                are concatenated into a single tensor of shape (num_anchors,).
            num_level_anchors (List[int]): Number of anchors of each scale
                level.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.

                - labels (Tensor): Labels of all anchors in the image with
                  shape (N,).
                - label_weights (Tensor): Label weights of all anchors in the
                  image with shape (N,).
                - bbox_targets (Tensor): BBox targets of all anchors in the
                  image with shape (N, 4).
                - bbox_weights (Tensor): BBox weights of all anchors in the
                  image with shape (N, 4).
                - pos_inds (Tensor): Indices of positive anchors with shape
                  (num_pos,).
                - neg_inds (Tensor): Indices of negative anchors with shape
                  (num_neg,).
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        anchors = flat_anchors
        # Align the official implementation
        anchors[:, 2:] -= 1
        num_level_anchors_inside = num_level_anchors
        pred_instances = InstanceData(priors=anchors)
        assign_result = self.assigner.assign(pred_instances,
                                             num_level_anchors_inside,
                                             gt_instances,
                                             gt_instances_ignore)

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)
        num_valid_anchors = anchors.shape[0]
        bbox_targets = torch.zeros_like(anchors)
        bbox_weights = torch.zeros_like(anchors)

        # ===== this change =====
        labels = anchors.new_full((num_valid_anchors, self.feat_channels),
                                  0,
                                  dtype=torch.float32)
        label_weights = anchors.new_zeros(
            num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if self.reg_decoded_bbox:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
            else:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_priors, sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds, :] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            # ===== this change =====
            labels[pos_inds] = gt_instances.positive_maps[
                sampling_result.pos_assigned_gt_inds]

            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds, sampling_result)

    def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
        """Calculate the centerness between anchors and gts.

        Only calculate pos centerness targets, otherwise there may be nan.

        Args:
            anchors (Tensor): Anchors with shape (N, 4), "xyxy" format.
            gts (Tensor): Ground truth bboxes with shape (N, 4), "xyxy"
                format.

        Returns:
            Tensor: Centerness between anchors and gts.
        """
        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
        l_ = anchors_cx - gts[:, 0]
        t_ = anchors_cy - gts[:, 1]
        r_ = gts[:, 2] - anchors_cx
        b_ = gts[:, 3] - anchors_cy

        left_right = torch.stack([l_, r_], dim=1)
        top_bottom = torch.stack([t_, b_], dim=1)
        centerness = torch.sqrt(
            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
        # assert not torch.isnan(centerness).any()
        return centerness

    def predict(self,
                visual_feats: Tuple[Tensor],
                language_feats: dict,
                batch_data_samples,
                rescale: bool = True):
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network.

        Args:
            visual_feats (tuple[Tensor]): Multi-level visual features from the
                upstream network, each is a 4D-tensor.
            language_feats (dict): Language features from the upstream
                network.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_token_positive_maps = [
            data_samples.token_positive_map
            for data_samples in batch_data_samples
        ]
        outs = self(visual_feats, language_feats)

        predictions = self.predict_by_feat(
            *outs,
            batch_img_metas=batch_img_metas,
            batch_token_positive_maps=batch_token_positive_maps,
            rescale=rescale)
        return predictions

    def predict_by_feat(self,
                        cls_logits: List[Tensor],
                        bbox_preds: List[Tensor],
                        score_factors: List[Tensor],
                        batch_img_metas: Optional[List[dict]] = None,
                        batch_token_positive_maps: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are usually
        multiplied by it to obtain the real scores used in NMS, such as the
        centerness in FCOS or the IoU branch in ATSS.

        Args:
            cls_logits (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            score_factors (list[Tensor], optional): Score factors for
                all scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            batch_token_positive_maps (list[dict], Optional): Batch token
                positive map. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(bbox_preds) == len(score_factors)
        num_levels = len(bbox_preds)

        featmap_sizes = [bbox_preds[i].shape[-2:] for i in range(num_levels)]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)

        result_list = []

        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            token_positive_maps = batch_token_positive_maps[img_id]
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            score_factor_list = select_single_mlvl(
                score_factors, img_id, detach=True)
            cls_logit_list = select_single_mlvl(
                cls_logits, img_id, detach=True)

            results = self._predict_by_feat_single(
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                cls_logit_list=cls_logit_list,
                mlvl_priors=mlvl_priors,
                token_positive_maps=token_positive_maps,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                cls_logit_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                token_positive_maps: dict,
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = True,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            cls_logit_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid. In all
                anchor-based methods, it has shape (num_priors, 4). In
                all anchor-free methods, it has shape (num_priors, 2)
                when `with_stride=True`, otherwise it still has shape
                (num_priors, 4).
            token_positive_maps (dict): Token positive map.
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image after the
            post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)
        score_thr = cfg.get('score_thr', 0)

        mlvl_bbox_preds = []
        mlvl_valid_priors = []
        mlvl_scores = []
        mlvl_labels = []

        for level_idx, (bbox_pred, score_factor, cls_logit, priors) in \
                enumerate(zip(bbox_pred_list, score_factor_list,
                              cls_logit_list, mlvl_priors)):

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(
                -1, self.bbox_coder.encode_size)
            score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()

            scores = convert_grounding_to_cls_scores(
                logits=cls_logit.sigmoid()[None],
                positive_maps=[token_positive_maps])[0]

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(bbox_pred=bbox_pred, priors=priors))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            priors = filtered_results['priors']
            score_factor = score_factor[keep_idxs]
            scores = torch.sqrt(scores * score_factor)

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_priors)
        bboxes = self.bbox_coder.decode(
            priors, bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = bboxes
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)

        predictions = self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

        if len(predictions) > 0:
            # Note: GLIP uses an offset bbox decoding convention; if 1 is not
            # added back here, the results do not align with the official mAP.
            predictions.bboxes[:, 2:] = predictions.bboxes[:, 2:] + 1
        return predictions
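

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the library API): how grounding
# logits are mapped to per-class scores via `convert_grounding_to_cls_scores`.
# The shapes and the positive-map layout below are made-up example values;
# real maps come from the dataset's token annotations.
if __name__ == '__main__':
    # 2 images, 3 priors, 6 text tokens; class 1 owns tokens [0, 1] and
    # class 2 owns token [3] (class ids are 1-based, token ids 0-based).
    dummy_logits = torch.rand(2, 3, 6)
    dummy_positive_maps = [{1: [0, 1], 2: [3]}] * 2
    demo_scores = convert_grounding_to_cls_scores(dummy_logits,
                                                  dummy_positive_maps)
    # one score per class, averaged over the tokens that class owns
    assert demo_scores.shape == (2, 3, 2)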