# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

import torch
from mmcv.ops import nms_match
from mmengine.structures import InstanceData
from numpy import ndarray
from torch import Tensor

from mmdet.registry import TASK_UTILS
from mmdet.structures.bbox import bbox2roi
from ..assigners import AssignResult
from .base_sampler import BaseSampler
from .sampling_result import SamplingResult


@TASK_UTILS.register_module()
class ScoreHLRSampler(BaseSampler):
    r"""Importance-based Sample Reweighting (ISR_N), described in `Prime Sample
    Attention in Object Detection <https://arxiv.org/abs/1904.04821>`_.

    Score hierarchical local rank (HLR) differentiates with RandomSampler in
    negative part. It firstly computes Score-HLR in a two-step way,
    then linearly maps score hlr to the loss weights.

    Args:
        num (int): Total number of sampled RoIs.
        pos_fraction (float): Fraction of positive samples.
        context (:obj:`BaseRoIHead`): RoI head that the sampler belongs to.
        neg_pos_ub (int): Upper bound of the ratio of num negative to num
            positive, -1 means no upper bound. Defaults to -1.
        add_gt_as_proposals (bool): Whether to add ground truth as proposals.
            Defaults to True.
        k (float): Power of the non-linear mapping. Defaults to 0.5
        bias (float): Shift of the non-linear mapping. Defaults to 0.
        score_thr (float): Minimum score that a negative sample is to be
            considered as valid bbox. Defaults to 0.05.
        iou_thr (float): IoU threshold for NMS match. Defaults to 0.5.
    """

    def __init__(self,
                 num: int,
                 pos_fraction: float,
                 context,
                 neg_pos_ub: int = -1,
                 add_gt_as_proposals: bool = True,
                 k: float = 0.5,
                 bias: float = 0,
                 score_thr: float = 0.05,
                 iou_thr: float = 0.5,
                 **kwargs) -> None:
        super().__init__(
            num=num,
            pos_fraction=pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals)
        self.k = k
        self.bias = bias
        self.score_thr = score_thr
        self.iou_thr = iou_thr
        self.context = context
        # context of cascade detectors is a list, so distinguish them here.
        if not hasattr(context, 'num_stages'):
            self.bbox_roi_extractor = context.bbox_roi_extractor
            self.bbox_head = context.bbox_head
            self.with_shared_head = context.with_shared_head
            if self.with_shared_head:
                self.shared_head = context.shared_head
        else:
            self.bbox_roi_extractor = context.bbox_roi_extractor[
                context.current_stage]
            self.bbox_head = context.bbox_head[context.current_stage]

    @staticmethod
    def random_choice(gallery: Union[Tensor, ndarray, list],
                      num: int) -> Union[Tensor, ndarray]:
        """Randomly select some elements from the gallery.

        If `gallery` is a Tensor, the returned indices will be a Tensor;
        If `gallery` is a ndarray or list, the returned indices will be a
        ndarray.

        Args:
            gallery (Tensor or ndarray or list): indices pool.
            num (int): expected sample num.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        assert len(gallery) >= num

        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
            else:
                device = 'cpu'
            gallery = torch.tensor(gallery, dtype=torch.long, device=device)
        perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
        rand_inds = gallery[perm]
        if not is_tensor:
            rand_inds = rand_inds.cpu().numpy()
        return rand_inds

    def _sample_pos(self, assign_result: AssignResult, num_expected: int,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Randomly sample some positive samples.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): The number of expected positive samples

        Returns:
            Tensor or ndarray: sampled indices.
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0).flatten()
        if pos_inds.numel() <= num_expected:
            return pos_inds
        else:
            return self.random_choice(pos_inds, num_expected)

    def _sample_neg(self, assign_result: AssignResult, num_expected: int,
                    bboxes: Tensor, feats: Tensor,
                    **kwargs) -> Union[Tensor, ndarray]:
        """Sample negative samples.

        Score-HLR sampler is done in the following steps:
        1. Take the maximum positive score prediction of each negative samples
            as s_i.
        2. Filter out negative samples whose s_i <= score_thr, the left samples
            are called valid samples.
        3. Use NMS-Match to divide valid samples into different groups,
            samples in the same group will greatly overlap with each other
        4. Rank the matched samples in two-steps to get Score-HLR.
            (1) In the same group, rank samples with their scores.
            (2) In the same score rank across different groups,
                rank samples with their scores again.
        5. Linearly map Score-HLR to the final label weights.

        Args:
            assign_result (:obj:`AssignResult`): result of assigner.
            num_expected (int): Expected number of samples.
            bboxes (Tensor): bbox to be sampled.
            feats (Tensor): Features come from FPN.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        neg_inds = torch.nonzero(assign_result.gt_inds == 0).flatten()
        num_neg = neg_inds.size(0)
        if num_neg == 0:
            return neg_inds, None
        with torch.no_grad():
            neg_bboxes = bboxes[neg_inds]
            neg_rois = bbox2roi([neg_bboxes])
            bbox_result = self.context._bbox_forward(feats, neg_rois)
            cls_score, bbox_pred = bbox_result['cls_score'], bbox_result[
                'bbox_pred']

            ori_loss = self.bbox_head.loss(
                cls_score=cls_score,
                bbox_pred=None,
                rois=None,
                labels=neg_inds.new_full((num_neg, ),
                                         self.bbox_head.num_classes),
                label_weights=cls_score.new_ones(num_neg),
                bbox_targets=None,
                bbox_weights=None,
                reduction_override='none')['loss_cls']

            # filter out samples with the max score lower than score_thr
            max_score, argmax_score = cls_score.softmax(-1)[:, :-1].max(-1)
            valid_inds = (max_score > self.score_thr).nonzero().view(-1)
            invalid_inds = (max_score <= self.score_thr).nonzero().view(-1)
            num_valid = valid_inds.size(0)
            num_invalid = invalid_inds.size(0)

            num_expected = min(num_neg, num_expected)
            num_hlr = min(num_valid, num_expected)
            num_rand = num_expected - num_hlr
            if num_valid > 0:
                valid_rois = neg_rois[valid_inds]
                valid_max_score = max_score[valid_inds]
                valid_argmax_score = argmax_score[valid_inds]
                valid_bbox_pred = bbox_pred[valid_inds]

                # valid_bbox_pred shape: [num_valid, #num_classes, 4]
                valid_bbox_pred = valid_bbox_pred.view(
                    valid_bbox_pred.size(0), -1, 4)
                selected_bbox_pred = valid_bbox_pred[range(num_valid),
                                                     valid_argmax_score]
                pred_bboxes = self.bbox_head.bbox_coder.decode(
                    valid_rois[:, 1:], selected_bbox_pred)
                pred_bboxes_with_score = torch.cat(
                    [pred_bboxes, valid_max_score[:, None]], -1)
                group = nms_match(pred_bboxes_with_score, self.iou_thr)

                # imp: importance
                imp = cls_score.new_zeros(num_valid)
                for g in group:
                    g_score = valid_max_score[g]
                    # g_score has already sorted
                    rank = g_score.new_tensor(range(g_score.size(0)))
                    imp[g] = num_valid - rank + g_score
                _, imp_rank_inds = imp.sort(descending=True)
                _, imp_rank = imp_rank_inds.sort()
                hlr_inds = imp_rank_inds[:num_expected]

                if num_rand > 0:
                    rand_inds = torch.randperm(num_invalid)[:num_rand]
                    select_inds = torch.cat(
                        [valid_inds[hlr_inds], invalid_inds[rand_inds]])
                else:
                    select_inds = valid_inds[hlr_inds]

                neg_label_weights = cls_score.new_ones(num_expected)

                up_bound = max(num_expected, num_valid)
                imp_weights = (up_bound -
                               imp_rank[hlr_inds].float()) / up_bound
                neg_label_weights[:num_hlr] = imp_weights
                neg_label_weights[num_hlr:] = imp_weights.min()
                neg_label_weights = (self.bias +
                                     (1 - self.bias) * neg_label_weights).pow(
                                         self.k)
                ori_selected_loss = ori_loss[select_inds]
                new_loss = ori_selected_loss * neg_label_weights
                norm_ratio = ori_selected_loss.sum() / new_loss.sum()
                neg_label_weights *= norm_ratio
            else:
                neg_label_weights = cls_score.new_ones(num_expected)
                select_inds = torch.randperm(num_neg)[:num_expected]

            return neg_inds[select_inds], neg_label_weights

    def sample(self, assign_result: AssignResult, pred_instances: InstanceData,
               gt_instances: InstanceData, **kwargs) -> SamplingResult:
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Assigning results.
            pred_instances (:obj:`InstanceData`): Instances of model
                predictions. It includes ``priors``, and the priors can
                be anchors or points, or the bboxes predicted by the
                previous stage, has shape (n, 4). The bboxes predicted by
                the current model or stage will be named ``bboxes``,
                ``labels``, and ``scores``, the same as the ``InstanceData``
                in other places.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes``, with shape (k, 4),
                and ``labels``, with shape (k, ).

        Returns:
            :obj:`SamplingResult`: Sampling result.
        """
        gt_bboxes = gt_instances.bboxes
        priors = pred_instances.priors
        gt_labels = gt_instances.labels

        gt_flags = priors.new_zeros((priors.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            priors = torch.cat([gt_bboxes, priors], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_ones = priors.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(
            assign_result, num_expected_pos, bboxes=priors, **kwargs)
        num_sampled_pos = pos_inds.numel()
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            _pos = max(1, num_sampled_pos)
            neg_upper_bound = int(self.neg_pos_ub * _pos)
            if num_expected_neg > neg_upper_bound:
                num_expected_neg = neg_upper_bound
        neg_inds, neg_label_weights = self.neg_sampler._sample_neg(
            assign_result, num_expected_neg, bboxes=priors, **kwargs)

        sampling_result = SamplingResult(
            pos_inds=pos_inds,
            neg_inds=neg_inds,
            priors=priors,
            gt_bboxes=gt_bboxes,
            assign_result=assign_result,
            gt_flags=gt_flags)
        return sampling_result, neg_label_weights
