# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe
import numpy as np
from typing import Optional, Tuple
import cv2

from densepose.structures import DensePoseDataRelative

from ..structures import DensePoseChartPredictorOutput
from .base import Boxes, Image, MatrixVisualizer


class DensePoseOutputsVisualizer:
    def __init__(
        self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, to_visualize=None, **kwargs
    ):
        assert to_visualize in "IUV", "can only visualize IUV"
        self.to_visualize = to_visualize

        if self.to_visualize == "I":
            val_scale = 255.0 / DensePoseDataRelative.N_PART_LABELS
        else:
            val_scale = 1.0
        self.mask_visualizer = MatrixVisualizer(
            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
        )

    def visualize(
        self,
        image_bgr: Image,
        dp_output_with_bboxes: Tuple[Optional[DensePoseChartPredictorOutput], Optional[Boxes]],
    ) -> Image:
        densepose_output, bboxes_xywh = dp_output_with_bboxes
        if densepose_output is None or bboxes_xywh is None:
            return image_bgr

        assert isinstance(
            densepose_output, DensePoseChartPredictorOutput
        ), "DensePoseChartPredictorOutput expected, {} encountered".format(type(densepose_output))

        S = densepose_output.coarse_segm
        I = densepose_output.fine_segm  # noqa
        U = densepose_output.u
        V = densepose_output.v
        N = S.size(0)
        assert N == I.size(
            0
        ), "densepose outputs S {} and I {}" " should have equal first dim size".format(
            S.size(), I.size()
        )
        assert N == U.size(
            0
        ), "densepose outputs S {} and U {}" " should have equal first dim size".format(
            S.size(), U.size()
        )
        assert N == V.size(
            0
        ), "densepose outputs S {} and V {}" " should have equal first dim size".format(
            S.size(), V.size()
        )
        assert N == len(
            bboxes_xywh
        ), "number of bounding boxes {}" " should be equal to first dim size of outputs {}".format(
            len(bboxes_xywh), N
        )
        for n in range(N):
            Sn = S[n].argmax(dim=0)
            In = I[n].argmax(dim=0) * (Sn > 0).long()
            segmentation = In.cpu().numpy().astype(np.uint8)
            mask = np.zeros(segmentation.shape, dtype=np.uint8)
            mask[segmentation > 0] = 1
            bbox_xywh = bboxes_xywh[n]

            if self.to_visualize == "I":
                vis = segmentation
            elif self.to_visualize in "UV":
                U_or_Vn = {"U": U, "V": V}[self.to_visualize][n].cpu().numpy().astype(np.float32)
                vis = np.zeros(segmentation.shape, dtype=np.float32)
                for partId in range(U_or_Vn.shape[0]):
                    vis[segmentation == partId] = (
                        U_or_Vn[partId][segmentation == partId].clip(0, 1) * 255
                    )

            # pyre-fixme[61]: `vis` may not be initialized here.
            image_bgr = self.mask_visualizer.visualize(image_bgr, mask, vis, bbox_xywh)

        return image_bgr


class DensePoseOutputsUVisualizer(DensePoseOutputsVisualizer):
    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super().__init__(inplace=inplace, cmap=cmap, alpha=alpha, to_visualize="U", **kwargs)


class DensePoseOutputsVVisualizer(DensePoseOutputsVisualizer):
    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super().__init__(inplace=inplace, cmap=cmap, alpha=alpha, to_visualize="V", **kwargs)


class DensePoseOutputsFineSegmentationVisualizer(DensePoseOutputsVisualizer):
    def __init__(self, inplace=True, cmap=cv2.COLORMAP_PARULA, alpha=0.7, **kwargs):
        super().__init__(inplace=inplace, cmap=cmap, alpha=alpha, to_visualize="I", **kwargs)
