import torch

try:
    from detectron2.structures import (
        Boxes,
        pairwise_iou,
        pairwise_iou_rotated,
        RotatedBoxes,
    )
except ImportError:
    pass


def batched_soft_nms(
    boxes,
    scores,
    idxs,
    threshold=0.7,
    method="linear",
    gaussian_sigma=0.5,
    prune_threshold=0.001,
    quad_scale=1.0,
):
    """
    Run soft non-maximum suppression independently for every category.

    Boxes that carry different values in ``idxs`` never suppress each other.
    Args:
        boxes (Tensor[N, 4]):
           boxes to run NMS on, expected in (x1, y1, x2, y2) format
        scores (Tensor[N]):
           score of each box
        idxs (Tensor[N]):
           category index of each box
        threshold (float):
           iou threshold handed to ``soft_nms`` as its ``linear_threshold``
           (Nt from the paper); also re-used as the threshold for "hard" nms
        method (str):
           one of ['gaussian', 'linear', 'hard', 'quad']
           see paper for details. "hard" is discouraged, as it is the
           same nms available elsewhere in detectron2
        gaussian_sigma (float):
           parameter for the Gaussian penalty function
        prune_threshold (float):
           boxes whose score drops below this value are pruned at each iteration.
           Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
        quad_scale (float):
           exponent applied to (1 - iou) by the "quad" method
    Returns:
        tuple(Tensor, Tensor):
            [0]: int64 tensor with the indices of the elements that have been kept
            by Soft NMS, sorted in decreasing order of scores
            [1]: float tensor with the re-scored scores of the elements that were kept
    """
    # Nothing to suppress: hand back empty results on the callers' devices.
    if boxes.numel() == 0:
        no_indices = torch.empty((0,), dtype=torch.int64, device=boxes.device)
        no_scores = torch.empty((0,), dtype=torch.float32, device=scores.device)
        return no_indices, no_scores

    # Per-class NMS trick: translate every box by an amount that depends only
    # on its class index and is larger than any coordinate, so that boxes of
    # different classes can never overlap.
    class_stride = boxes.max() + 1
    per_box_shift = idxs.to(boxes) * class_stride
    shifted_boxes = boxes + per_box_shift[:, None]

    return soft_nms(
        boxes=shifted_boxes,
        scores=scores,
        method=method,
        gaussian_sigma=gaussian_sigma,
        linear_threshold=threshold,
        prune_threshold=prune_threshold,
        quad_scale=quad_scale,
    )


def soft_nms(
    boxes, scores, method, gaussian_sigma, linear_threshold, prune_threshold, quad_scale
):
    """
    Soft non-maximum suppression for axis aligned boxes.

    Thin wrapper that runs ``_soft_nms`` with the ``Boxes`` container and the
    ``pairwise_iou`` overlap function.
    Args:
        boxes (Tensor[N, 4]):
           boxes to run NMS on, expected in (x1, y1, x2, y2) format
        scores (Tensor[N]):
           score of each box
        method (str):
           one of ['gaussian', 'linear', 'hard', 'quad']
           see paper for details. "hard" is discouraged, as it is the
           same nms available elsewhere in detectron2
        gaussian_sigma (float):
           parameter for the Gaussian penalty function
        linear_threshold (float):
           iou threshold for applying linear decay. Nt from the paper
           re-used as threshold for standard "hard" nms
        prune_threshold (float):
           boxes whose score drops below this value are pruned at each iteration.
           Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
        quad_scale (float):
           exponent applied to (1 - iou) by the "quad" method
    Returns:
        tuple(Tensor, Tensor):
            [0]: int64 tensor with the indices of the elements that have been kept
            by Soft NMS, sorted in decreasing order of scores
            [1]: float tensor with the re-scored scores of the elements that were kept
    """
    return _soft_nms(
        box_class=Boxes,
        pairwise_iou_func=pairwise_iou,
        boxes=boxes,
        scores=scores,
        method=method,
        gaussian_sigma=gaussian_sigma,
        linear_threshold=linear_threshold,
        prune_threshold=prune_threshold,
        quad_scale=quad_scale,
    )


def _soft_nms(
    box_class,
    pairwise_iou_func,
    boxes,
    scores,
    method,
    gaussian_sigma,
    linear_threshold,
    prune_threshold,
    quad_scale,
):
    """
    Soft non-max suppression algorithm.
    Implementation of [Soft-NMS -- Improving Object Detection With One Line of Codec]
    (https://arxiv.org/abs/1704.04503)
    Args:
        box_class (cls): one of Box, RotatedBoxes
        pairwise_iou_func (func): one of pairwise_iou, pairwise_iou_rotated
        boxes (Tensor[N, ?]):
           boxes where NMS will be performed
           if Boxes, in (x1, y1, x2, y2) format
           if RotatedBoxes, in (x_ctr, y_ctr, width, height, angle_degrees) format
        scores (Tensor[N]):
           scores for each one of the boxes
        method (str):
           one of ['gaussian', 'linear', 'hard']
           see paper for details. users encouraged not to use "hard", as this is the
           same nms available elsewhere in detectron2
        gaussian_sigma (float):
           parameter for Gaussian penalty function
        linear_threshold (float):
           iou threshold for applying linear decay. Nt from the paper
           re-used as threshold for standard "hard" nms
        prune_threshold (float):
           boxes with scores below this threshold are pruned at each iteration.
           Dramatically reduces computation time. Authors use values in [10e-4, 10e-2]
    Returns:
        tuple(Tensor, Tensor):
            [0]: int64 tensor with the indices of the elements that have been kept
            by Soft NMS, sorted in decreasing order of scores
            [1]: float tensor with the re-scored scores of the elements that were kept
    """
    boxes = boxes.clone()
    scores = scores.clone()
    idxs = torch.arange(scores.size()[0])

    idxs_out = []
    scores_out = []

    while scores.numel() > 0:
        top_idx = torch.argmax(scores)
        idxs_out.append(idxs[top_idx].item())
        scores_out.append(scores[top_idx].item())

        top_box = boxes[top_idx]
        ious = pairwise_iou_func(box_class(top_box.unsqueeze(0)), box_class(boxes))[0]

        if method == "quad":
            decay = torch.ones_like(ious)
            decay_mask = ious > linear_threshold
            decay[decay_mask] = (1 - ious[decay_mask]) ** quad_scale
        elif method == "linear":
            decay = torch.ones_like(ious)
            decay_mask = ious > linear_threshold
            decay[decay_mask] = 1 - ious[decay_mask]
        elif method == "gaussian":
            decay = torch.exp(-torch.pow(ious, 2) / gaussian_sigma)
        elif method == "hard":  # standard NMS
            decay = (ious < linear_threshold).float()
        else:
            raise NotImplementedError(
                "{} soft nms method not implemented.".format(method)
            )

        scores *= decay
        keep = scores > prune_threshold
        keep[top_idx] = False

        boxes = boxes[keep]
        scores = scores[keep]
        idxs = idxs[keep.to(idxs.device)]

    return torch.tensor(idxs_out).to(boxes.device), torch.tensor(scores_out).to(
        scores.device
    )
