# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from mmdet.structures.bbox import bbox_overlaps
from ..task_modules.coders import BaseBBoxCoder
from ..task_modules.samplers import SamplingResult


def isr_p(cls_score: Tensor,
          bbox_pred: Tensor,
          bbox_targets: Tuple[Tensor],
          rois: Tensor,
          sampling_results: List[SamplingResult],
          loss_cls: nn.Module,
          bbox_coder: BaseBBoxCoder,
          k: float = 2,
          bias: float = 0,
          num_class: int = 80) -> tuple:
    """Importance-based Sample Reweighting (ISR_P), positive part.

    Positive samples are ranked by the IoU between their decoded predicted
    box and their decoded target box (IoU Hierarchical Local Rank, IoU-HLR),
    and the rank is mapped to a per-sample classification weight. The new
    weights are rescaled so that the weighted classification loss over the
    positive samples keeps its original value.

    Note:
        ``label_weights`` (the second item of ``bbox_targets``) is updated
        in place for the positive samples; the returned tuple holds the same
        tensor object.

    Args:
        cls_score (Tensor): Predicted classification scores.
        bbox_pred (Tensor): Predicted bbox deltas.
        bbox_targets (tuple[Tensor]): A tuple of bbox targets, they are
            labels, label_weights, bbox_targets, bbox_weights, respectively.
        rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
            (two_stage) in shape (n, 5).
        sampling_results (list[:obj:`SamplingResult`]): Sampling results,
            one per image. Only ``pos_assigned_gt_inds`` is read.
        loss_cls (:obj:`nn.Module`): Classification loss func of the head.
            It is called with ``reduction_override='none'``.
        bbox_coder (:obj:`BaseBBoxCoder`): BBox coder of the head.
        k (float): Power of the non-linear mapping. Defaults to 2.
        bias (float): Shift of the non-linear mapping. Defaults to 0.
        num_class (int): Number of classes, defaults to 80. Labels in
            ``[0, num_class)`` are treated as positive.

    Returns:
        tuple([Tensor]): labels, imp_based_label_weights, bbox_targets,
            bbox_target_weights
    """

    labels, label_weights, bbox_targets, bbox_weights = bbox_targets
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    pos_labels = labels[pos_label_inds]

    # if no positive samples, return the original targets
    num_pos = pos_label_inds.size(0)
    if num_pos == 0:
        return labels, label_weights, bbox_targets, bbox_weights

    # Merge pos_assigned_gt_inds of every image into a single tensor.
    # Per-image GT indices start from 0, so each image is shifted by a
    # running offset. The offset must accumulate over all previous images;
    # otherwise ids from the third image onwards can collide with ids of
    # earlier images and distinct objects would be ranked as one.
    gts = []
    gt_offset = 0
    for sampling_result in sampling_results:
        gt_i = sampling_result.pos_assigned_gt_inds
        gts.append(gt_i + gt_offset)
        if len(gt_i) != 0:
            gt_offset = gt_offset + gt_i.max() + 1
    gts = torch.cat(gts)
    assert len(gts) == num_pos

    # the importance weights are treated as constants: no gradient flows
    # back through the predictions used to compute them
    cls_score = cls_score.detach()
    bbox_pred = bbox_pred.detach()

    # For single stage detectors, rois here indicate anchors, in shape (N, 4)
    # For two stage detectors, rois are in shape (N, 5); the first column is
    # dropped before decoding
    if rois.size(-1) == 5:
        pos_rois = rois[pos_label_inds][:, 1:]
    else:
        pos_rois = rois[pos_label_inds]

    # class-specific regression: pick the deltas of the labelled class
    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
    else:
        pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)

    # compute iou of the predicted bbox and the corresponding GT
    pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
    pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
    target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
    ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)

    # indexing with an index tensor yields a copy, so the in-place updates
    # below do not touch ``label_weights`` until the final write-back
    pos_imp_weights = label_weights[pos_label_inds]
    # Two steps to compute IoU-HLR. Samples are first sorted by IoU locally,
    # then sorted again within the same-rank group
    max_l_num = pos_labels.bincount().max()
    for label in pos_labels.unique():
        l_inds = (pos_labels == label).nonzero().view(-1)
        l_gts = gts[l_inds]
        for t in l_gts.unique():
            t_inds = l_inds[l_gts == t]
            t_ious = ious[t_inds]
            # double argsort turns IoU values into ranks (0 = highest IoU)
            _, t_iou_rank_idx = t_ious.sort(descending=True)
            _, t_iou_rank = t_iou_rank_idx.sort()
            # add the local-rank score on top of the IoU so that the
            # class-level sort groups samples by local rank first and
            # breaks ties by IoU
            ious[t_inds] += max_l_num - t_iou_rank.float()
        l_ious = ious[l_inds]
        _, l_iou_rank_idx = l_ious.sort(descending=True)
        _, l_iou_rank = l_iou_rank_idx.sort()  # IoU-HLR
        # linearly map HLR to label weights
        pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num

    pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)

    # normalize to make the new weighted loss value equal to the original loss
    ori_pos_weights = label_weights[pos_label_inds]
    pos_loss_cls = loss_cls(
        cls_score[pos_label_inds], pos_labels, reduction_override='none')
    if pos_loss_cls.dim() > 1:
        # per-class loss of shape (num_pos, C): broadcast the sample weights
        ori_pos_loss_cls = pos_loss_cls * ori_pos_weights[:, None]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
    else:
        ori_pos_loss_cls = pos_loss_cls * ori_pos_weights
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights
    pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum()
    pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
    # write the importance-based weights back into the caller's tensor
    label_weights[pos_label_inds] = pos_imp_weights

    bbox_targets = labels, label_weights, bbox_targets, bbox_weights
    return bbox_targets


def carl_loss(cls_score: Tensor,
              labels: Tensor,
              bbox_pred: Tensor,
              bbox_targets: Tensor,
              loss_bbox: nn.Module,
              k: float = 1,
              bias: float = 0.2,
              avg_factor: Optional[int] = None,
              sigmoid: bool = False,
              num_class: int = 80) -> dict:
    """Classification-Aware Regression Loss (CARL).

    The regression loss of every positive sample is re-weighted by a
    function of its predicted score for the labelled class. The score is
    not detached, so the regression loss also back-propagates into the
    classification branch.

    Args:
        cls_score (Tensor): Predicted classification scores.
        labels (Tensor): Targets of classification. Labels in
            ``[0, num_class)`` are treated as positive.
        bbox_pred (Tensor): Predicted bbox deltas, either class agnostic
            with 4 values per sample or class specific with 4 values per
            class.
        bbox_targets (Tensor): Target of bbox regression.
        loss_bbox (func): Regression loss func of the head. It is called
            with ``reduction_override='none'``.
        k (float): Power of the non-linear mapping. Defaults to 1.
        bias (float): Shift of the non-linear mapping. Defaults to 0.2.
        avg_factor (int, optional): Average factor used in regression loss.
            Defaults to ``bbox_targets.size(0)`` when not given.
        sigmoid (bool): Whether the classification score is activated by
            sigmoid (otherwise softmax over the last dim). Defaults to False.
        num_class (int): Number of classes, defaults to 80.

    Returns:
        dict: CARL loss dict with the single key ``loss_carl`` holding a
            tensor of shape (1, ).
    """
    is_positive = (labels >= 0) & (labels < num_class)
    pos_inds = is_positive.nonzero().reshape(-1)
    if pos_inds.numel() == 0:
        # zero loss that is still connected to the graph of cls_score
        return {'loss_carl': cls_score.sum()[None] * 0.}
    pos_labels = labels[pos_inds]

    # score of the labelled class for each positive sample; the gradient is
    # kept on purpose
    probs = cls_score.sigmoid() if sigmoid else cls_score.softmax(-1)
    pos_probs = probs[pos_inds, pos_labels]
    weights = (bias + (1 - bias) * pos_probs).pow(k)

    # rescale the weights so that they sum up to the number of positives
    weights.mul_(float(pos_probs.size(0)) / weights.sum())

    normalizer = bbox_targets.size(0) if avg_factor is None else avg_factor

    # class agnostic: bbox pred is in shape (N, 4)
    # class specific: bbox pred is viewed as (N, #classes, 4) and the deltas
    # of the labelled class are selected
    if bbox_pred.size(-1) > 4:
        per_class_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_bbox_preds = per_class_pred[pos_inds, pos_labels]
    else:
        pos_bbox_preds = bbox_pred[pos_inds]

    elementwise_loss = loss_bbox(
        pos_bbox_preds, bbox_targets[pos_inds], reduction_override='none')
    weighted_loss = elementwise_loss / normalizer * weights[:, None]
    return {'loss_carl': weighted_loss.sum()[None]}
