# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import numpy as np
import torch

from detectron2.data import MetadataCatalog
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from detectron2.structures import BitMasks, Instances


class OVSegPredictor(DefaultPredictor):
    """``DefaultPredictor`` variant that also forwards class names and optional GT masks.

    Each call builds a single-image input dict containing the preprocessed
    image, the original size, the list of candidate ``class_names`` and, when
    a semantic label map is supplied, per-class ground-truth binary masks.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, original_image, class_names, gt, apply_gt_only=False):
        """
        Run the model on one image.

        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
            class_names (list[str]): candidate class names, forwarded to the
                model under the ``"class_names"`` key.
            gt (np.ndarray or None): semantic label map for ``original_image``.
                When given, it is passed through the same transform as the
                image and converted to per-class binary masks, sent under the
                ``"instances"`` key (label 255 is treated as ignore). When
                ``None``, no ``"instances"`` key is sent.
            apply_gt_only (bool): currently unused; kept for backward compatibility.

        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]

            # Compute the transform once so the image and the label map are
            # transformed identically.
            transform = self.aug.get_transform(original_image)
            image = transform.apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {
                "image": image,
                "height": height,
                "width": width,
                "class_names": class_names,
            }
            if gt is not None:
                # Only attach instances when GT exists; previously the no-GT
                # path referenced an undefined `instances` and crashed.
                # NOTE(review): image_size is the ORIGINAL (height, width) while
                # the masks have the transformed size — kept as before; confirm.
                inputs["instances"] = self._gt_to_instances(
                    transform.apply_segmentation(gt), (height, width)
                )
            predictions = self.model([inputs])[0]
            return predictions

    @staticmethod
    def _gt_to_instances(sem_seg_gt, image_shape, ignore_label=255):
        """Split a semantic label map into one binary mask per present class.

        Args:
            sem_seg_gt (np.ndarray): 2-D integer label map.
            image_shape (tuple[int, int]): size stored on the returned Instances.
            ignore_label (int): label value excluded from the classes.

        Returns:
            Instances: with ``gt_classes`` (int64 tensor of present labels) and
            ``gt_masks`` (one mask per class, or an empty (0, H, W) tensor when
            every pixel is ignored).
        """
        sem_seg_gt = np.asarray(sem_seg_gt).astype("long")
        instances = Instances(image_shape)
        classes = np.unique(sem_seg_gt)
        # remove ignored region
        classes = classes[classes != ignore_label]
        instances.gt_classes = torch.tensor(classes, dtype=torch.int64)

        if len(classes) == 0:
            # Some image does not have annotation (all ignored)
            instances.gt_masks = torch.zeros(
                (0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])
            )
        else:
            masks = BitMasks(
                torch.stack(
                    [
                        torch.from_numpy(np.ascontiguousarray(sem_seg_gt == class_id))
                        for class_id in classes
                    ]
                )
            )
            instances.gt_masks = masks.tensor
        return instances

class OVSegVisualizer(Visualizer):
    """Visualizer that labels semantic segments with custom class names and colors.

    Args:
        img_rgb: image to draw on (RGB), forwarded to ``Visualizer``.
        metadata, scale, instance_mode: forwarded to ``Visualizer``.
        class_names (list[str] or None): names indexed by label; falls back to
            ``metadata.stuff_classes`` when ``None``.
        use_classnames (bool): whether to write the class name on each segment.
        gt (bool): stored on the instance; not read by this class.
        colors (sequence or None): per-label RGB colors in 0-255; falls back to
            ``metadata.stuff_colors`` when ``None``.
    """

    def __init__(self, img_rgb, metadata=None, scale=1, instance_mode=ColorMode.IMAGE, class_names=None, use_classnames=True, gt=False, colors=None):
        super().__init__(img_rgb, metadata, scale, instance_mode)
        self.class_names = class_names
        # Font size grows with image area, with a floor of 30 // scale.
        self._default_font_size = max(
            np.sqrt(self.output.height * self.output.width) // 90, 30 // scale
        )
        self.use_classnames = use_classnames
        self.gt = gt
        self.colors = colors

    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
        """
        Draw semantic segmentation predictions/labels.

        Args:
            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
                Each value is the integer label of the pixel. Labels outside
                ``[0, len(class_names))`` (e.g. 255 ignore or negative values)
                are not drawn.
            area_threshold (int): segments with less than `area_threshold` are not drawn.
            alpha (float): the larger it is, the more opaque the segmentations are.

        Returns:
            output (VisImage): image object with visualizations.
        """
        if isinstance(sem_seg, torch.Tensor):
            # .cpu() so tensors living on an accelerator can be drawn too.
            sem_seg = sem_seg.cpu().numpy()
        labels, areas = np.unique(sem_seg, return_counts=True)
        # Draw the largest segments first.
        labels = labels[np.argsort(-areas)]
        class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes
        num_classes = len(class_names)

        for label in labels:
            # Negative labels would otherwise index class_names from the end.
            if not 0 <= label < num_classes:
                continue
            try:
                if self.colors is not None:
                    mask_color = [x / 255 for x in self.colors[label]]
                else:
                    mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
            except (AttributeError, IndexError):
                # No color available for this label: let the base class pick one.
                mask_color = None

            binary_mask = (sem_seg == label).astype(np.uint8)
            # Comma-separated synonym lists are shortened to their first entry.
            text = class_names[label].split(",")[0] if self.use_classnames else None
            self.draw_binary_mask(
                binary_mask,
                color=mask_color,
                edge_color=(1.0, 1.0, 240.0 / 255),
                text=text,
                alpha=alpha,
                area_threshold=area_threshold,
            )
        return self.output



class VisualizationDemo(object):
    """Runs ``OVSegPredictor`` on an image and renders the semantic segmentation."""

    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
        """
        Args:
            cfg (CfgNode): detectron2 config; ``cfg.DATASETS.TEST[0]`` (if any)
                selects the metadata used for fallback class names/colors.
            instance_mode (ColorMode): forwarded to the visualizer.
            parallel (bool): whether to run the model in different processes from visualization.
                Not implemented: passing True raises NotImplementedError.
        """
        self.metadata = MetadataCatalog.get(
            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
        )

        self.cpu_device = torch.device("cpu")
        self.instance_mode = instance_mode

        self.parallel = parallel
        if parallel:
            raise NotImplementedError
        self.predictor = OVSegPredictor(cfg)

    def run_on_image(self, image, class_names, gt=None, use_gt=False, class_names_gt=None, colors=None):
        """
        Args:
            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
                This is the format used by OpenCV.
            class_names (list[str]): candidate class names for the model and,
                unless ``use_gt``, for labelling the drawn segments.
            gt (np.ndarray or None): optional semantic label map, forwarded to
                the predictor and drawn instead of the prediction when ``use_gt``.
            use_gt (bool): draw ``gt`` (labelled with ``class_names_gt``)
                instead of the model's prediction.
            class_names_gt (list[str] or None): names used when ``use_gt``.
            colors (sequence or None): per-label RGB colors (0-255).

        Returns:
            predictions (dict): the output of the model.
            vis_output (VisImage): the visualized image output.

        Raises:
            ValueError: if ``use_gt`` is True but ``gt`` is None.
            NotImplementedError: if the model output has no ``"sem_seg"``.
        """
        if use_gt and gt is None:
            # Fail before running the model rather than on gt.astype later.
            raise ValueError("use_gt=True requires a ground-truth label map `gt`")

        predictions = self.predictor(image, class_names, gt)
        if "sem_seg" not in predictions:
            raise NotImplementedError

        # Convert image from OpenCV BGR format to Matplotlib RGB format.
        image = image[:, :, ::-1]
        visualizer = OVSegVisualizer(
            image,
            self.metadata,
            instance_mode=self.instance_mode,
            class_names=class_names_gt if use_gt else class_names,
            use_classnames=True,
            colors=colors,
        )

        if use_gt:
            label_map = gt.astype(np.int32)
        else:
            # Per-pixel argmax over the class dimension of the score map.
            pred_mask = predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
            # NOTE(review): hard-coded remap of predicted label 26 to 27;
            # presumably dataset-specific (merging two classes) — confirm.
            pred_mask[pred_mask == 26] = 27
            label_map = pred_mask.numpy().astype(np.int32)

        vis_output = visualizer.draw_sem_seg(label_map)
        return predictions, vis_output