# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Sequence
import torch

from detectron2.layers.nms import batched_nms
from detectron2.structures.instances import Instances

from densepose.vis.bounding_box import BoundingBoxVisualizer, ScoredBoundingBoxVisualizer
from densepose.vis.densepose import DensePoseResultsVisualizer

from .base import CompoundVisualizer

# Type alias: a sequence of float score values
Scores = Sequence[float]


def extract_scores_from_instances(instances: Instances, select=None):
    """
    Get the "scores" field of the instances, optionally restricted to a selection

    Args:
        instances (Instances): container that may hold a "scores" field
        select: optional index applied to the scores (anything that can be
            used to index them); None keeps all entries
    Return:
        the (optionally indexed) scores, or None if `instances` has no
        "scores" field
    """
    if not instances.has("scores"):
        return None
    scores = instances.scores
    if select is not None:
        scores = scores[select]
    return scores


def extract_boxes_xywh_from_instances(instances: Instances, select=None):
    """
    Get the predicted boxes of the instances with columns 2 and 3 turned into
    extents (column 2 minus column 0, column 3 minus column 1), i.e. XYWH
    rows when the stored boxes are (x0, y0, x1, y1) corners

    Args:
        instances (Instances): container that may hold a "pred_boxes" field
        select: optional index applied to the converted boxes; None keeps
            all rows
    Return:
        a new tensor with the converted (optionally indexed) boxes, or None
        if `instances` has no "pred_boxes" field
    """
    if not instances.has("pred_boxes"):
        return None
    # Work on a copy so the boxes stored in `instances` stay untouched
    xywh = instances.pred_boxes.tensor.clone()
    # Far corner minus near corner gives width and height
    xywh[:, 2:4] -= xywh[:, 0:2]
    if select is None:
        return xywh
    return xywh[select]


def create_extractor(visualizer: object):
    """
    Create the data extractor that matches the type of the given visualizer

    Args:
        visualizer (object): visualizer to create an extractor for; compound
            visualizers are handled recursively, one extractor per child
    Return:
        a callable extractor, or None (after logging an error) if the
        visualizer type is not supported
    """
    if isinstance(visualizer, CompoundVisualizer):
        child_extractors = [create_extractor(child) for child in visualizer.visualizers]
        return CompoundExtractor(child_extractors)
    if isinstance(visualizer, DensePoseResultsVisualizer):
        return DensePoseResultExtractor()
    if isinstance(visualizer, ScoredBoundingBoxVisualizer):
        # Scored boxes need two pieces of data: the boxes and their scores
        box_and_score_extractors = [
            extract_boxes_xywh_from_instances,
            extract_scores_from_instances,
        ]
        return CompoundExtractor(box_and_score_extractors)
    if isinstance(visualizer, BoundingBoxVisualizer):
        return extract_boxes_xywh_from_instances
    # Unsupported visualizer type: report it and signal failure with None
    logging.getLogger(__name__).error(f"Could not create extractor for {visualizer}")
    return None


class BoundingBoxExtractor(object):
    """
    Extracts bounding boxes from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances (Instances): container that may hold a "pred_boxes" field
            select: optional index applied to the boxes; None keeps all boxes.
                Accepted so that this extractor can be called like the other
                extractors in this module, which are invoked as
                `extractor(instances, select)`
        Return:
            boxes as returned by `extract_boxes_xywh_from_instances`, or None
            if `instances` has no "pred_boxes" field
        """
        return extract_boxes_xywh_from_instances(instances, select)


class ScoredBoundingBoxExtractor(object):
    """
    Extracts bounding boxes together with their scores from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances (Instances): container that may hold "pred_boxes" and
                "scores" fields
            select: optional index applied to both boxes and scores; it is
                only applied when both of them are available
        Return:
            tuple (boxes_xywh, scores); either entry is None if the
            corresponding field is missing
        """
        scores = extract_scores_from_instances(instances)
        boxes_xywh = extract_boxes_xywh_from_instances(instances)
        both_available = (scores is not None) and (boxes_xywh is not None)
        if both_available and select is not None:
            selected_scores = scores[select]
            selected_boxes = boxes_xywh[select]
            return (selected_boxes, selected_scores)
        # Nothing to select, or one of the fields is missing: return as is
        return (boxes_xywh, scores)


class DensePoseResultExtractor(object):
    """
    Extracts DensePose result from instances
    """

    def __call__(self, instances: Instances, select=None):
        """
        Args:
            instances (Instances): container that may hold "pred_densepose"
                and "pred_boxes" fields
            select: optional index applied to both the DensePose outputs and
                the boxes before the result is built
        Return:
            the value of `pred_densepose.to_result(boxes_xywh)`, or None if
            either the DensePose outputs or the boxes are missing
        """
        boxes_xywh = extract_boxes_xywh_from_instances(instances)
        if not instances.has("pred_densepose") or boxes_xywh is None:
            return None
        densepose_output = instances.pred_densepose
        if select is not None:
            # Keep outputs and boxes aligned by applying the same selection
            densepose_output = densepose_output[select]
            boxes_xywh = boxes_xywh[select]
        return densepose_output.to_result(boxes_xywh)


class CompoundExtractor(object):
    """
    Extracts data for CompoundVisualizer
    """

    def __init__(self, extractors):
        """
        Args:
            extractors: iterable of callables, each invoked as
                `extractor(instances, select)`
        """
        self.extractors = extractors

    def __call__(self, instances: Instances, select=None):
        """
        Run every extractor on the same instances and selection

        Return:
            list with one entry per extractor, in the order of the extractors
        """
        return [extract(instances, select) for extract in self.extractors]


class NmsFilteredExtractor(object):
    """
    Extracts data in the format accepted by NmsFilteredVisualizer
    """

    def __init__(self, extractor, iou_threshold):
        """
        Args:
            extractor: wrapped extractor, invoked as
                `extractor(instances, select=mask)`
            iou_threshold: IoU threshold passed to non-maximum suppression
        """
        self.extractor = extractor
        self.iou_threshold = iou_threshold

    def __call__(self, instances: Instances, select=None):
        """
        Run NMS on the predicted boxes and pass only the kept instances to
        the wrapped extractor

        Args:
            instances (Instances): container expected to hold "pred_boxes"
                and "scores" fields
            select: optional boolean mask over the instances; it is combined
                (logical AND) with the mask of instances kept by NMS
        Return:
            the output of the wrapped extractor for the kept instances, or
            None if boxes or scores are missing
        """
        scores = extract_scores_from_instances(instances)
        # NMS needs both boxes and scores; without scores `len(scores)` would
        # fail, so bail out the same way as for missing boxes
        if scores is None or not instances.has("pred_boxes"):
            return None
        # NMS computes IoU on corner-format (x1, y1, x2, y2) boxes, so the
        # stored boxes are used directly rather than their XYWH conversion
        boxes_xyxy = instances.pred_boxes.tensor
        # All boxes share category 0, which makes batched NMS class-agnostic;
        # the category ids live on the same device as the boxes
        category_idxs = torch.zeros(len(scores), dtype=torch.int32, device=boxes_xyxy.device)
        keep_idx = batched_nms(
            boxes_xyxy,
            scores,
            category_idxs,
            iou_threshold=self.iou_threshold,
        )
        select_local = torch.zeros(len(boxes_xyxy), dtype=torch.bool, device=boxes_xyxy.device)
        select_local[keep_idx] = True
        select = select_local if select is None else (select & select_local)
        return self.extractor(instances, select=select)


class ScoreThresholdedExtractor(object):
    """
    Extracts data in the format accepted by ScoreThresholdedVisualizer
    """

    def __init__(self, extractor, min_score):
        """
        Args:
            extractor: wrapped extractor, invoked as
                `extractor(instances, select=mask)`
            min_score: only instances with a score strictly greater than
                this value are kept
        """
        self.extractor = extractor
        self.min_score = min_score

    def __call__(self, instances: Instances, select=None):
        """
        Pass only the instances scoring above the threshold to the wrapped
        extractor

        Args:
            instances (Instances): container that may hold a "scores" field
            select: optional mask over the instances; it is combined
                (logical AND) with the score mask
        Return:
            the output of the wrapped extractor, or None if `instances` has
            no "scores" field
        """
        scores = extract_scores_from_instances(instances)
        if scores is None:
            return None
        above_threshold = scores > self.min_score
        if select is not None:
            above_threshold = select & above_threshold
        return self.extractor(instances, select=above_threshold)
