# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Video Mask2Former."""

import itertools
import math
import random
import warnings
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import numpy as np

from transformers.image_processing_utils import (
    BaseImageProcessor,
    BatchFeature,
    get_size_dict,
)
from transformers.image_transforms import (
    PaddingMode,
    get_resize_output_image_size,
    normalize,
    pad,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_batched,
    valid_images,
    to_numpy_array,
)
from transformers.utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    TensorType,
    is_torch_available,
    is_torch_tensor,
    logging,
)


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch
    from torch import nn


# Adapted from transformers.models.detr.image_processing_detr.max_across_indices
def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Compute the position-wise maximum over a collection of sequences.

    For example, given several shape tuples, the result holds the largest value seen at each position.

    Args:
        values (`Iterable[Any]`):
            An iterable of indexable sequences (e.g. shape tuples). Positions are aligned with `zip`, so the
            output is only as long as the shortest sequence.

    Returns:
        `List[Any]`: One maximum per position.
    """
    columns = zip(*values)
    return list(map(max, columns))


# Adapted from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> Tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.

    The channel layout is inferred from the first image only and is assumed to apply to every image.

    Args:
        images (`List[np.ndarray]`):
            Non-empty list of 3D images, all channel-first or all channel-last.

    Returns:
        `Tuple[int, int]`: `(max_height, max_width)`.

    Raises:
        `ValueError`: If the inferred channel dimension format is neither first nor last.
    """
    input_channel_dimension = infer_channel_dimension_format(images[0])
    max_dims = max_across_indices([img.shape for img in images])

    if input_channel_dimension == ChannelDimension.FIRST:
        _, max_height, max_width = max_dims
    elif input_channel_dimension == ChannelDimension.LAST:
        max_height, max_width, _ = max_dims
    else:
        raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
    return (max_height, max_width)


# Adapted from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
    """
    Build a validity mask for an image placed in the top-left corner of an `output_size` canvas.

    Entries covering the image's own height/width are 1 (valid pixels); everything to the bottom/right is 0
    (padding).

    Args:
        image (`np.ndarray`):
            Image to make the pixel mask for.
        output_size (`Tuple[int, int]`):
            Shape of the returned mask, i.e. the padded `(height, width)`.

    Returns:
        `np.ndarray`: An `int64` array of shape `output_size`.
    """
    rows, cols = get_image_size(image)
    valid = np.zeros(output_size, dtype=np.int64)
    valid[:rows, :cols] = 1
    return valid


# Adapted from transformers.models.detr.image_processing_detr.binary_mask_to_rle
def binary_mask_to_rle(mask):
    """
    Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.

    The mask is flattened in row-major order. The returned list alternates between the 1-based start position
    of a run of ones and that run's length: `[start_1, length_1, start_2, length_2, ...]`.

    Args:
        mask (`torch.Tensor` or `numpy.array`):
            A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
            segment_id or class_id. Torch tensors may live on any device and may require grad.
    Returns:
        `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
        format.
    """
    if is_torch_tensor(mask):
        # `.numpy()` refuses CUDA tensors and tensors that require grad, so detach and move to CPU first.
        mask = mask.detach().cpu().numpy()

    pixels = mask.flatten()
    # Pad with a zero on both ends so every run of ones has a rising and a falling edge.
    pixels = np.concatenate([[0], pixels, [0]])
    # 1-based positions (relative to the unpadded array) where the value changes.
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    # Every second entry is a run end; turn it into a run length.
    runs[1::2] -= runs[::2]
    return list(runs)


# Adapted from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
def convert_segmentation_to_rle(segmentation):
    """
    Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.

    Args:
        segmentation (`torch.Tensor` or `numpy.array`):
            A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
    Returns:
        `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id. One
        entry per unique id, in the order returned by `torch.unique` (ascending by default).
    """
    # `torch.unique` / `torch.where` only accept tensors; convert numpy input as the docstring promises.
    # `torch.as_tensor` returns torch tensors unchanged (no copy).
    segmentation = torch.as_tensor(segmentation)
    segment_ids = torch.unique(segmentation)

    return [
        binary_mask_to_rle(torch.where(segmentation == idx, 1, 0))
        for idx in segment_ids
    ]


# Adapted from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
    """
    Drop queries that predict the "no object" label or whose score does not exceed `object_mask_threshold`.

    Args:
        masks (`torch.Tensor`):
            A tensor of shape `(num_queries, height, width)`.
        scores (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        labels (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        object_mask_threshold (`float`):
            Queries whose score is less than or equal to this value are removed.
        num_labels (`int`):
            Label id treated as "no object"; queries predicting it are removed.
    Raises:
        `ValueError`: Raised when the first dimension doesn't match in all input tensors.
    Returns:
        `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `masks`, `scores` and `labels` restricted to the kept
        queries.
    """
    num_queries = masks.shape[0]
    if scores.shape[0] != num_queries or labels.shape[0] != num_queries:
        raise ValueError("mask, scores and labels must have the same shape!")

    is_object = labels.ne(num_labels)
    is_confident = scores > object_mask_threshold
    keep = is_object & is_confident

    return masks[keep], scores[keep], labels[keep]


# Adapted from transformers.models.detr.image_processing_detr.check_segment_validity
def check_segment_validity(
    mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8
):
    """
    Decide whether query `k` yields a segment worth keeping.

    A segment is kept when query `k` wins at least one pixel in `mask_labels`, its own thresholded mask is
    non-empty, and the ratio between the two areas is strictly greater than `overlap_mask_area_threshold`
    (this drops queries that only keep small, disconnected leftovers after the argmax).

    Args:
        mask_labels (`torch.Tensor`):
            Per-pixel index of the winning query, of shape `(height, width)` (as built by `compute_segments`).
        mask_probs (`torch.Tensor`):
            Per-query mask probabilities of shape `(num_queries, height, width)`.
        k (`int`):
            Index of the query to check.
        mask_threshold (`float`, *optional*, defaults to 0.5):
            Probability at or above which a pixel belongs to query `k`'s original mask.
        overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
            Minimum (exclusive) ratio of won area to original mask area.

    Returns:
        `Tuple[bool, torch.Tensor]`: Whether the segment is kept, and the boolean mask `mask_labels == k`.
    """
    # Pixels that query k wins after the argmax over queries.
    mask_k = mask_labels == k
    mask_k_area = mask_k.sum()

    # Area of query k's own mask before competing with other queries. Queries live on dim 0 of
    # `mask_probs`, so select `mask_probs[k]` (indexing `[:, k]` would select image row k instead).
    original_area = (mask_probs[k] >= mask_threshold).sum()
    mask_exists = bool(mask_k_area > 0 and original_area > 0)

    # Eliminate disconnected tiny segments
    if mask_exists:
        area_ratio = mask_k_area / original_area
        if not area_ratio.item() > overlap_mask_area_threshold:
            mask_exists = False

    return mask_exists, mask_k


# Adapted from transformers.models.detr.image_processing_detr.compute_segments
def compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    mask_threshold: float = 0.5,
    overlap_mask_area_threshold: float = 0.8,
    label_ids_to_fuse: Optional[Set[int]] = None,
    target_size: Tuple[int, int] = None,
):
    """
    Merge per-query mask probabilities into one segmentation map plus per-segment metadata.

    Args:
        mask_probs (`torch.Tensor`):
            Mask probabilities of shape `(num_queries, height, width)`. The tensor is not modified.
        pred_scores (`torch.Tensor`):
            Per-query scores of shape `(num_queries,)`, used to weight each mask before the argmax.
        pred_labels (`torch.Tensor`):
            Per-query class ids of shape `(num_queries,)`.
        mask_threshold (`float`, *optional*, defaults to 0.5):
            Forwarded to `check_segment_validity`.
        overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
            Forwarded to `check_segment_validity`.
        label_ids_to_fuse (`Set[int]`, *optional*):
            Class ids whose instances share a single segment id. `None` means no class is fused.
        target_size (`Tuple[int, int]`, *optional*):
            `(height, width)` the masks are bilinearly resized to before merging.

    Returns:
        `Tuple[torch.Tensor, List[Dict]]`: An `int32` segmentation map of shape `(height, width)` in which 0 means
        "no segment", and a list of dicts with keys `id`, `label_id`, `was_fused` and `score`.
    """
    if label_ids_to_fuse is None:
        label_ids_to_fuse = set()

    height = mask_probs.shape[1] if target_size is None else target_size[0]
    width = mask_probs.shape[2] if target_size is None else target_size[1]

    segmentation = torch.zeros(
        (height, width), dtype=torch.int32, device=mask_probs.device
    )
    segments: List[Dict] = []

    if target_size is not None:
        mask_probs = nn.functional.interpolate(
            mask_probs.unsqueeze(0),
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )[0]

    # Weigh each mask by its prediction score (out-of-place so the caller's tensor is left untouched)
    mask_probs = mask_probs * pred_scores.view(-1, 1, 1)
    mask_labels = mask_probs.argmax(0)  # [height, width]

    # Highest segment id handed out so far. Kept separate from the id of the current segment so that reusing
    # a fused class's id never makes a later, new segment collide with an existing id.
    last_segment_id = 0

    # Segment id already assigned to each fused class
    stuff_memory_list: Dict[int, int] = {}
    for k in range(pred_labels.shape[0]):
        pred_class = pred_labels[k].item()
        should_fuse = pred_class in label_ids_to_fuse

        # Check if mask exists and large enough to be a segment
        mask_exists, mask_k = check_segment_validity(
            mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
        )
        if not mask_exists:
            continue

        if pred_class in stuff_memory_list:
            current_segment_id = stuff_memory_list[pred_class]
        else:
            last_segment_id += 1
            current_segment_id = last_segment_id

        # Add current object segment to final segmentation map
        segmentation[mask_k] = current_segment_id
        segment_score = round(pred_scores[k].item(), 6)
        segments.append(
            {
                "id": current_segment_id,
                "label_id": pred_class,
                "was_fused": should_fuse,
                "score": segment_score,
            }
        )
        if should_fuse:
            stuff_memory_list[pred_class] = current_segment_id

    return segmentation, segments


# TODO: (Amy) Move to image_transforms
def convert_segmentation_map_to_binary_masks(
    segmentation_map: "np.ndarray",
    instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
    ignore_index: Optional[int] = None,
    reduce_labels: bool = False,
):
    """
    Split a segmentation map into one binary mask per distinct id.

    Args:
        segmentation_map (`np.ndarray`):
            Map of class ids or instance ids.
        instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
            Mapping from instance id to class id. When given, the ids in `segmentation_map` are treated as
            instance ids and the returned labels are the mapped class ids. Keys and values are looked up in the
            original (not reduced) label space.
        ignore_index (`int`, *optional*):
            Id excluded from the output (e.g. background). Required when `reduce_labels` is True.
        reduce_labels (`bool`, *optional*, defaults to `False`):
            If True, id 0 is replaced by `ignore_index` and every other id is decremented by 1.

    Returns:
        `Tuple[np.ndarray, np.ndarray, np.ndarray]`:
            - `float32` binary masks of shape `(num_labels, *segmentation_map.shape)`; `num_labels` is 0 when the
              map contains nothing but `ignore_index`.
            - `int64` label of each mask (mapped class ids if a mapping is given, otherwise the ids themselves).
            - `int64` ids present in the map (after reduction, excluding `ignore_index`), in ascending order.

    Raises:
        `ValueError`: If `reduce_labels` is True but `ignore_index` is None.
        `KeyError`: If an id present in the map is missing from `instance_id_to_semantic_id`.
    """
    if reduce_labels and ignore_index is None:
        raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.")

    if reduce_labels:
        segmentation_map = np.where(
            segmentation_map == 0, ignore_index, segmentation_map - 1
        )

    # Get unique ids (class or instance ids based on input) present in the video frame
    all_frame_labels = np.unique(segmentation_map)

    # Drop background label if applicable
    if ignore_index is not None:
        all_frame_labels = all_frame_labels[all_frame_labels != ignore_index]

    # Generate a binary mask for each object instance
    if all_frame_labels.size > 0:
        binary_masks = np.stack(
            [segmentation_map == i for i in all_frame_labels], axis=0
        )  # (num_labels, height, width)
    else:
        # Only `ignore_index` in this frame: np.stack would raise on an empty list.
        binary_masks = np.zeros((0, *np.shape(segmentation_map)), dtype=bool)

    # Convert instance ids to class ids
    if instance_id_to_semantic_id is not None:
        labels = np.zeros(all_frame_labels.shape[0], dtype=np.int64)
        for position, label in enumerate(all_frame_labels):
            # The mapping uses the original ids: undo the reduction for the lookup, re-apply it to the result.
            class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label]
            labels[position] = class_id - 1 if reduce_labels else class_id
    else:
        labels = all_frame_labels

    return (
        binary_masks.astype(np.float32),
        labels.astype(np.int64),
        all_frame_labels.astype(np.int64),
    )


def get_mask2former_resize_output_image_size(
    image: np.ndarray,
    size: Union[int, Tuple[int, int], List[int], Tuple[int]],
    max_size: Optional[int] = None,
    size_divisor: int = 0,
    default_to_square: bool = True,
) -> Tuple[int, int]:
    """
    Computes the output size given the desired size.

    Args:
        image (`np.ndarray`):
            The input image.
        size (`int`, `Tuple[int, int]`, `List[int]`, `Tuple[int]`):
            The size of the output image.
        max_size (`int`, *optional*):
            The maximum size of the output image.
        size_divisor (`int`, *optional*, defaults to `0`):
            If greater than 0, the output height and width are rounded up to the nearest multiple of this number.
        default_to_square (`bool`, *optional*, defaults to `True`):
            Whether to default to square if no size is provided.

    Returns:
        `Tuple[int, int]`: The output `(height, width)`.
    """
    output_size = get_resize_output_image_size(
        input_image=image,
        size=size,
        default_to_square=default_to_square,
        max_size=max_size,
    )

    if size_divisor > 0:
        # Round each side up so backbones that need divisible inputs (e.g. Swin) get valid shapes.
        height, width = output_size
        height = int(math.ceil(height / size_divisor) * size_divisor)
        width = int(math.ceil(width / size_divisor) * size_divisor)
        output_size = (height, width)

    return output_size


class VideoMask2FormerImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Video Mask2Former image processor. The image processor can be used to prepare image(s) and optional
    targets for the model.

    This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the input to a certain `size`.
        size (`int`, *optional*, defaults to 800):
            Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
            sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
            the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
            height / width, size)`.
        max_size (`int`, *optional*, defaults to 1333):
            The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
            set to `True`.
        resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
            to `True`.
        size_divisor (`int`, *optional*, defaults to 32):
            Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
            Swin Transformer.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the input to a certain `scale`.
        rescale_factor (`float`, *optional*, defaults to 1/ 255):
            Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input with mean and standard deviation.
        image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
        image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
            ImageNet std.
        ignore_index (`int`, *optional*):
            Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
            denoted with 0 (background) will be replaced with `ignore_index`.
        reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
            is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
            The background label will be replaced by `ignore_index`.

    """

    model_input_names = ["pixel_values", "pixel_mask"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        size_divisor: int = 32,
        resample: PILImageResampling = PILImageResampling.BILINEAR,  # type: ignore
        do_rescale: bool = True,
        rescale_factor: float = 1 / 255,
        do_normalize: bool = True,
        image_mean: Union[float, List[float]] = None,
        image_std: Union[float, List[float]] = None,
        ignore_index: Optional[int] = None,
        reduce_labels: bool = False,
        **kwargs,
    ):
        """
        Store the processor's default preprocessing configuration.

        The deprecated keyword arguments `size_divisibility` (alias of `size_divisor`) and `max_size`
        (alias of `size["longest_edge"]`) are still accepted. Each emits a `FutureWarning` and is removed from
        `kwargs` before the remaining keyword arguments are forwarded to `BaseImageProcessor.__init__`.
        """
        # Deprecated alias: `size_divisibility` -> `size_divisor`.
        if "size_divisibility" in kwargs:
            warnings.warn(
                "The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use "
                "`size_divisor` instead.",
                FutureWarning,
            )
            size_divisor = kwargs.pop("size_divisibility")

        # `max_size` is kept as a private attribute so `preprocess` can pass it to `get_size_dict` as a default
        # while `size` may still be given as an int.
        self._max_size = 1333
        if "max_size" in kwargs:
            warnings.warn(
                "The `max_size` argument is deprecated and will be removed in v4.27. Please use size['longest_edge']"
                " instead.",
                FutureWarning,
            )
            self._max_size = kwargs.pop("max_size")

        if size is None:
            size = {"shortest_edge": 800, "longest_edge": self._max_size}
        size = get_size_dict(size, max_size=self._max_size, default_to_square=False)

        super().__init__(**kwargs)

        # Resizing
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.size_divisor = size_divisor
        # Rescaling and normalization
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = IMAGENET_DEFAULT_MEAN if image_mean is None else image_mean
        self.image_std = IMAGENET_DEFAULT_STD if image_std is None else image_std
        # Segmentation-map handling
        self.ignore_index = ignore_index
        self.reduce_labels = reduce_labels

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Build an image processor from a config dict, honouring deprecated overrides passed as keyword arguments.

        `max_size` and `size_divisibility` given in `kwargs` (e.g.
        `Mask2FormerImageProcessor.from_pretrained(checkpoint, max_size=800)`) are moved into a shallow copy of
        `image_processor_dict` before delegating to the base class. The caller's dict is not mutated.
        """
        config = image_processor_dict.copy()
        for legacy_key in ("max_size", "size_divisibility"):
            if legacy_key in kwargs:
                config[legacy_key] = kwargs.pop(legacy_key)
        return super().from_dict(config, **kwargs)

    @property
    def size_divisibility(self):
        """Deprecated alias of `size_divisor`; emits a `FutureWarning` on every access."""
        message = (
            "The `size_divisibility` property is deprecated and will be removed in v4.27. Please use "
            "`size_divisor` instead."
        )
        warnings.warn(message, FutureWarning)
        return self.size_divisor

    @property
    def max_size(self):
        """
        Deprecated accessor for `size["longest_edge"]`; emits a `FutureWarning` on every access.

        Raises `KeyError` when `size` has no `longest_edge` entry (e.g. a `height`/`width` size dict).
        """
        message = (
            "The `max_size` property is deprecated and will be removed in v4.27. Please use size['longest_edge']"
            " instead."
        )
        warnings.warn(message, FutureWarning)
        return self.size["longest_edge"]

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        size_divisor: int = 0,
        resample: PILImageResampling = PILImageResampling.BILINEAR,  # type: ignore
        data_format=None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize `image` according to a size dict.

        `size` must contain either `shortest_edge` and `longest_edge` (the shorter side is matched to
        `shortest_edge`, with `longest_edge` as the cap) or `height` and `width` (explicit output size). When
        `size_divisor > 0` the output height and width are rounded up to multiples of it. The deprecated
        `max_size` keyword is still accepted and forwarded to `get_size_dict`.

        Raises:
            `ValueError`: If `size` contains neither key pair.
        """
        max_size = None
        if "max_size" in kwargs:
            warnings.warn(
                "The `max_size` parameter is deprecated and will be removed in v4.27. "
                "Please specify in `size['longest_edge'] instead`.",
                FutureWarning,
            )
            max_size = kwargs.pop("max_size")

        size = get_size_dict(size, max_size=max_size, default_to_square=False)
        has_edge_keys = "shortest_edge" in size and "longest_edge" in size
        has_hw_keys = "height" in size and "width" in size
        if has_edge_keys:
            target, max_size = size["shortest_edge"], size["longest_edge"]
        elif has_hw_keys:
            target, max_size = (size["height"], size["width"]), None
        else:
            raise ValueError(
                "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
                f" {size.keys()}."
            )

        output_size = get_mask2former_resize_output_image_size(
            image=image,
            size=target,
            max_size=max_size,
            size_divisor=size_divisor,
            default_to_square=False,
        )
        return resize(image, size=output_size, resample=resample, data_format=data_format)

    def rescale(
        self,
        image: np.ndarray,
        rescale_factor: float,
        data_format: Optional[ChannelDimension] = None,
    ) -> np.ndarray:
        """
        Rescale `image` by `rescale_factor` using `transformers.image_transforms.rescale`.

        `data_format` is forwarded unchanged.
        """
        rescaled = rescale(image, rescale_factor, data_format=data_format)
        return rescaled

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[ChannelDimension] = None,
    ) -> np.ndarray:
        """
        Normalize `image` with per-channel `mean` and `std` using `transformers.image_transforms.normalize`.

        `data_format` is forwarded unchanged.
        """
        normalized = normalize(image, mean=mean, std=std, data_format=data_format)
        return normalized

    def convert_segmentation_map_to_binary_masks(
        self,
        segmentation_map: "np.ndarray",
        instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
        ignore_index: Optional[int] = None,
        reduce_labels: Optional[bool] = None,
    ):
        """
        Split `segmentation_map` into per-id binary masks, falling back to the processor's settings.

        `ignore_index` and `reduce_labels` default to `self.ignore_index` and `self.reduce_labels` when left as
        `None`; everything is then forwarded to the module-level `convert_segmentation_map_to_binary_masks`.
        """
        # `reduce_labels` defaults to None (not False) so that the fallback to `self.reduce_labels` can trigger.
        reduce_labels = (
            reduce_labels if reduce_labels is not None else self.reduce_labels
        )
        ignore_index = ignore_index if ignore_index is not None else self.ignore_index
        return convert_segmentation_map_to_binary_masks(
            segmentation_map=segmentation_map,
            instance_id_to_semantic_id=instance_id_to_semantic_id,
            ignore_index=ignore_index,
            reduce_labels=reduce_labels,
        )

    def __call__(
        self, images, segmentation_maps=None, do_sampling=False, **kwargs
    ) -> BatchFeature:
        """Shorthand for `preprocess`; `images` is passed as its first positional argument (`video_frames`)."""
        options = dict(kwargs, segmentation_maps=segmentation_maps, do_sampling=do_sampling)
        return self.preprocess(images, **options)

    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        size_divisor: int = None,
        resample: PILImageResampling = None,  # type: ignore
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
    ):
        """Apply the enabled steps to `image`, always in the order resize, rescale, normalize."""
        result = image
        if do_resize:
            result = self.resize(result, size=size, size_divisor=size_divisor, resample=resample)
        if do_rescale:
            result = self.rescale(result, rescale_factor=rescale_factor)
        if do_normalize:
            result = self.normalize(result, mean=image_mean, std=image_std)
        return result

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        size_divisor: int = None,
        resample: PILImageResampling = None,  # type: ignore
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Preprocess a single frame.

        The frame is converted to a numpy array (all transforms expect numpy), run through `_preprocess`, and
        finally converted to `data_format` when one is given.
        """
        processed = self._preprocess(
            image=to_numpy_array(image),
            do_resize=do_resize,
            size=size,
            size_divisor=size_divisor,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
        )
        if data_format is None:
            return processed
        return to_channel_dimension_format(processed, data_format)

    def _preprocess_mask(
        self,
        segmentation_map: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        size_divisor: int = 0,
    ) -> np.ndarray:
        """
        Preprocess a single segmentation map.

        Only resizing is applied, with nearest-neighbour resampling so ids are not interpolated. Rescaling and
        normalization are always off. A 2D map gets a temporary leading channel axis for the transforms, which
        is removed again before returning.
        """
        mask = to_numpy_array(segmentation_map)
        is_2d = mask.ndim == 2
        if is_2d:
            mask = np.expand_dims(mask, axis=0)

        # TODO: (Amy)
        # Rework segmentation map processing to include reducing labels and resizing which doesn't
        # drop segment IDs > 255.
        mask = self._preprocess(
            image=mask,
            do_resize=do_resize,
            resample=PILImageResampling.NEAREST,
            size=size,
            size_divisor=size_divisor,
            do_rescale=False,
            do_normalize=False,
        )

        return mask.squeeze(0) if is_2d else mask

    def preprocess(
        self,
        video_frames: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_sampling: bool = False,
        instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        size_divisor: Optional[int] = None,
        resample: PILImageResampling = None,  # type: ignore
        sampling_frame_range: Optional[int] = 20,
        sampling_frame_num: Optional[int] = 2,
        sampling_frame_shuffle: Optional[bool] = False,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        ignore_index: Optional[int] = None,
        reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess the frames of a video, and optionally their segmentation maps, for the model.

        Args:
            video_frames (`ImageInput`):
                A single frame or a batch of frames from one video.
            segmentation_maps (`ImageInput`, *optional*):
                One segmentation map per frame.
            do_sampling (`bool`, *optional*, defaults to `False`):
                If True, keep only `sampling_frame_num` frames: a random reference frame plus frames drawn with
                replacement from within `sampling_frame_range` frames of it.
            instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
                Forwarded to `encode_inputs`.
            sampling_frame_range (`int`, *optional*, defaults to 20):
                Maximum distance from the reference frame for the other sampled frames.
            sampling_frame_num (`int`, *optional*, defaults to 2):
                Total number of frames kept when sampling, reference frame included.
            sampling_frame_shuffle (`bool`, *optional*, defaults to `False`):
                Shuffle the sampled frames instead of keeping them sorted by index.
            data_format (`str` or `ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                Channel layout of the processed frames.
            The remaining options override the processor attribute of the same name when not `None`.

        Returns:
            `BatchFeature`: The result of `self.encode_inputs` on the processed frames and maps.

        Raises:
            `ValueError`: On inconsistent options, invalid input types, mismatched numbers of frames and
            segmentation maps, or when sampling is impossible for the given video.
        """
        if "pad_and_return_pixel_mask" in kwargs:
            warnings.warn(
                "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version",
                FutureWarning,
            )

        # Fall back to the processor's configuration for every option left as `None`.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
        size_divisor = size_divisor if size_divisor is not None else self.size_divisor
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = (
            rescale_factor if rescale_factor is not None else self.rescale_factor
        )
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        ignore_index = ignore_index if ignore_index is not None else self.ignore_index
        reduce_labels = (
            reduce_labels if reduce_labels is not None else self.reduce_labels
        )

        # Only require the parameters of steps that are actually enabled.
        if do_resize and (size is None or size_divisor is None):
            raise ValueError(
                "If `do_resize` is True, `size` and `size_divisor` must be provided."
            )

        if do_rescale and rescale_factor is None:
            raise ValueError(
                "If `do_rescale` is True, `rescale_factor` must be provided."
            )

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError(
                "If `do_normalize` is True, `image_mean` and `image_std` must be provided."
            )

        if not valid_images(video_frames):
            raise ValueError(
                "Invalid frame type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if segmentation_maps is not None and not valid_images(segmentation_maps):
            raise ValueError(
                "Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # Wrap unbatched input (and its map) in a list.
        if not is_batched(video_frames):
            video_frames = [video_frames]
            segmentation_maps = (
                [segmentation_maps] if segmentation_maps is not None else None
            )

        if segmentation_maps is not None and len(video_frames) != len(
            segmentation_maps
        ):
            raise ValueError(
                "Video_frames and segmentation maps must have the same length."
            )

        # do frame sampling (Video Mask2Former model uses only 2 frames for training)
        if do_sampling:
            selected_frame_indexes = self._sample_frame_indexes(
                len(video_frames),
                sampling_frame_range,
                sampling_frame_num,
                sampling_frame_shuffle,
            )
            video_frames = [
                video_frames[frame_idx] for frame_idx in selected_frame_indexes
            ]
            # Segmentation maps are optional; only index them when they were provided.
            if segmentation_maps is not None:
                segmentation_maps = [
                    segmentation_maps[frame_idx] for frame_idx in selected_frame_indexes
                ]

        processed_frames = [
            self._preprocess_image(
                frame,
                do_resize=do_resize,
                size=size,
                size_divisor=size_divisor,
                resample=resample,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
            )
            for frame in video_frames
        ]

        if segmentation_maps is not None:
            segmentation_maps = [
                self._preprocess_mask(segmentation_map, do_resize, size, size_divisor)
                for segmentation_map in segmentation_maps
            ]
        encoded_inputs = self.encode_inputs(
            processed_frames,
            segmentation_maps,
            instance_id_to_semantic_id,
            ignore_index,
            reduce_labels,
            return_tensors,
        )
        return encoded_inputs

    @staticmethod
    def _sample_frame_indexes(
        video_length: int,
        sampling_frame_range: int,
        sampling_frame_num: int,
        sampling_frame_shuffle: bool,
    ) -> List[int]:
        """Pick a random reference frame plus `sampling_frame_num - 1` nearby frames (drawn with replacement)."""
        # choose a reference frame
        reference_frame = random.randrange(video_length)

        start_idx = max(0, reference_frame - sampling_frame_range)
        end_idx = min(video_length, reference_frame + sampling_frame_range + 1)

        # candidate frames within the range, excluding the reference frame itself
        candidate_frames = list(range(start_idx, reference_frame)) + list(
            range(reference_frame + 1, end_idx)
        )
        if not candidate_frames and sampling_frame_num > 1:
            raise ValueError(
                f"Cannot sample {sampling_frame_num} frames from a video with {video_length} frame(s) when "
                f"`sampling_frame_range` is {sampling_frame_range}."
            )

        # draw the remaining frames at random from the candidates
        selected_frames = np.random.choice(
            np.array(candidate_frames), sampling_frame_num - 1
        )

        # indexes of all chosen frames, in temporal order
        selected_frame_indexes = sorted(selected_frames.tolist() + [reference_frame])

        # shuffle the frames if shuffling is enabled
        if sampling_frame_shuffle:
            random.shuffle(selected_frame_indexes)
        return selected_frame_indexes

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
    def _pad_image(
        self,
        image: np.ndarray,
        output_size: Tuple[int, int],
        constant_values: Union[float, Iterable[float]] = 0,
        data_format: Optional[ChannelDimension] = None,
    ) -> np.ndarray:
        """
        Pad `image` on its bottom and right edges so that its spatial size becomes `output_size`.

        Args:
            image (`np.ndarray`):
                Image to pad. Its current height and width are obtained with `get_image_size`.
            output_size (`Tuple[int, int]`):
                Target `(height, width)` after padding.
            constant_values (`float` or `Iterable[float]`, *optional*, defaults to 0):
                Value written into the padded pixels.
            data_format (`ChannelDimension`, *optional*):
                Channel dimension format of the returned image; forwarded unchanged to `pad`.

        Returns:
            `np.ndarray`: the padded image.
        """
        target_height, target_width = output_size
        current_height, current_width = get_image_size(image)

        # Only the trailing edges grow, so the original content stays anchored at the top-left corner.
        bottom_right_padding = (
            (0, target_height - current_height),
            (0, target_width - current_width),
        )
        return pad(
            image,
            bottom_right_padding,
            mode=PaddingMode.CONSTANT,
            constant_values=constant_values,
            data_format=data_format,
        )

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
    def pad(
        self,
        images: List[np.ndarray],
        constant_values: Union[float, Iterable[float]] = 0,
        return_pixel_mask: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = None,
    ) -> BatchFeature:
        """
        Pads a batch of images to the bottom and right with `constant_values` up to the largest height and width in
        the batch, and optionally returns the corresponding pixel masks.

        Args:
            images (`List[np.ndarray]`):
                Images to pad. The common output size is computed with `get_max_height_width(images)`.
            constant_values (`float` or `Iterable[float]`, *optional*, defaults to 0):
                The value used to fill the padded pixels.
            return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether to also return a `pixel_mask` entry built with `make_pixel_mask`.
            return_tensors (`str` or `TensorType`, *optional*):
                Tensor type forwarded to `BatchFeature` for conversion of the outputs.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the padded images. If not provided, it will be the same as the input
                images.

        Returns:
            [`BatchFeature`]: with key `pixel_values` (the padded images) and, if `return_pixel_mask` is `True`, key
            `pixel_mask` (one mask per image, sized to the padded size).
        """
        pad_size = get_max_height_width(images)

        padded_images = [
            self._pad_image(
                image,
                pad_size,
                constant_values=constant_values,
                data_format=data_format,
            )
            for image in images
        ]
        data = {"pixel_values": padded_images}

        if return_pixel_mask:
            # Masks are computed from the *unpadded* images so they mark which pixels are real.
            data["pixel_mask"] = [
                make_pixel_mask(image=image, output_size=pad_size) for image in images
            ]

        return BatchFeature(data=data, tensor_type=return_tensors)

    def encode_inputs(
        self,
        pixel_values_list: List[ImageInput],
        segmentation_maps: ImageInput = None,
        instance_id_to_semantic_id: Optional[
            Union[List[Dict[int, int]], Dict[int, int]]
        ] = None,
        ignore_index: Optional[int] = None,
        reduce_labels: bool = False,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ):
        """
        Pad the frames of a video up to the largest frame, create the corresponding `pixel_mask` and, when
        segmentation maps are given, build video-level mask and class labels.

        Video Mask2Former addresses segmentation with a mask classification paradigm, so the per-frame segmentation
        maps are converted to binary masks. Masks that share an instance id across frames are grouped into one
        video-level instance. Example: for a single frame with `segmentation_maps = [[2,6,7,9]]` the output contains
        `mask_labels = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`.

        Args:
            pixel_values_list (`List[ImageInput]`):
                Frames (pixel values) to be padded. Each frame is converted with `to_numpy_array`.
            segmentation_maps (`ImageInput`, *optional*):
                One segmentation map per frame with pixel-wise annotations.
            instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
                A mapping between object instance ids and class ids. A list of per-frame dictionaries is merged into
                a single mapping (later entries win on duplicate keys). If `None`, the class labels returned by
                `convert_segmentation_map_to_binary_masks` are used directly.
            ignore_index (`int`, *optional*):
                Label ignored when building masks and used as the padding value for masks. Defaults to
                `self.ignore_index` when `None`.
            reduce_labels (`bool`, *optional*, defaults to `False`):
                Forwarded to `convert_segmentation_map_to_binary_masks`; also selects the `+1`/`-1` id shift used
                when looking up class ids in the mapping. A value of `None` falls back to `self.reduce_labels`.
            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                Tensor type for `pixel_values`/`pixel_mask`. If set to `'pt'`, return PyTorch `torch.Tensor` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **pixel_values** -- Padded pixel values to be fed to a model.
            - **pixel_mask** -- Pixel mask: 1 for real pixels, 0 for padding.
            - **mask_labels** -- (only when `segmentation_maps` is given) `torch.int64` tensor of shape
              `(num_instances, num_frames, height, width)`.
            - **class_labels** -- (only when `segmentation_maps` is given) `torch.int64` tensor of shape
              `(num_instances,)`; `class_labels[i]` is the class of `mask_labels[i]`.
        """
        ignore_index = self.ignore_index if ignore_index is None else ignore_index
        reduce_labels = self.reduce_labels if reduce_labels is None else reduce_labels

        if "pad_and_return_pixel_mask" in kwargs:
            warnings.warn(
                "The `pad_and_return_pixel_mask` argument has no effect and will be removed in v4.27",
                FutureWarning,
            )

        pixel_values_list = [
            to_numpy_array(pixel_values) for pixel_values in pixel_values_list
        ]
        encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors)

        if segmentation_maps is None:
            return encoded_inputs

        pad_size = get_max_height_width(pixel_values_list)
        pad_height, pad_width = pad_size

        # Normalise the mapping to a single dict (or None). The original else-branch bound the wrong name,
        # which made every non-list mapping fail with an unbound local.
        if isinstance(instance_id_to_semantic_id, list):
            instance_ids_to_semantic_ids = {}
            for frame_mapping in instance_id_to_semantic_id:
                instance_ids_to_semantic_ids.update(frame_mapping)
        else:
            instance_ids_to_semantic_ids = instance_id_to_semantic_id

        frame_masks = []
        frame_instances = []
        # Fallback instance -> class lookup, only filled when no explicit mapping is available.
        instance_to_class = {}
        for segmentation_map in segmentation_maps:
            segmentation_map = to_numpy_array(segmentation_map)
            masks, classes, instance_ids = self.convert_segmentation_map_to_binary_masks(
                segmentation_map,
                instance_ids_to_semantic_ids,
                ignore_index=ignore_index,
                reduce_labels=reduce_labels,
            )
            # A leading channel axis is added so `_pad_image` can treat each mask as a 1-channel image.
            padded_masks = [
                self._pad_image(
                    image=mask[None, ...], output_size=pad_size, constant_values=ignore_index
                )
                for mask in masks
            ]
            frame_masks.append(padded_masks)
            frame_instances.append(instance_ids)
            if instance_ids_to_semantic_ids is None:
                # NOTE(review): assumes `classes` is aligned element-wise with `instance_ids` — confirm.
                for instance, class_id in zip(instance_ids, classes):
                    instance_to_class.setdefault(instance, class_id)

        # Every distinct instance id seen in any frame becomes one video-level instance.
        video_instances = np.unique(np.concatenate(frame_instances))
        video_instances2idx = {
            instance: idx for idx, instance in enumerate(video_instances)
        }

        if instance_ids_to_semantic_ids is not None:
            video_labels = []
            for instance in video_instances:
                # With `reduce_labels`, ids are looked up with +1 and the class shifted by -1, presumably undoing /
                # re-applying the reduction done in `convert_segmentation_map_to_binary_masks` — confirm.
                lookup_key = instance + 1 if reduce_labels else instance
                semantic_id = instance_ids_to_semantic_ids[lookup_key]
                video_labels.append(semantic_id - 1 if reduce_labels else semantic_id)
        else:
            video_labels = [instance_to_class[instance] for instance in video_instances]

        # Padded masks have spatial size `pad_size`; using it avoids indexing an empty first frame.
        video_masks = torch.zeros(
            len(video_instances), len(frame_masks), pad_height, pad_width, dtype=torch.int64
        )
        for frame_id, (masks, instance_ids) in enumerate(zip(frame_masks, frame_instances)):
            for mask, instance_id in zip(masks, instance_ids):
                video_masks[video_instances2idx[instance_id], frame_id] = torch.from_numpy(mask)

        encoded_inputs["mask_labels"] = video_masks
        encoded_inputs["class_labels"] = torch.from_numpy(
            np.array(video_labels, dtype=np.int64)
        )

        return encoded_inputs

    def post_process_instance_segmentation(
        self,
        outputs,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
        overlap_mask_area_threshold: float = 0.8,
        target_sizes: Optional[List[Tuple[int, int]]] = None,
        return_coco_annotation: Optional[bool] = False,
        return_binary_maps: Optional[bool] = False,
    ) -> List[Dict]:
        """
        Converts the output of [`VideoMask2FormerForVideoSegmentationOutput`] into video instance segmentation
        predictions. Only supports PyTorch.

        For every video, the best (at most 10) `(query, class)` pairs are selected from the class probabilities and
        each becomes a candidate instance spanning all frames. Mask logits are first resized to 256x256 and, when
        `target_sizes` is given, resized again to the requested size.

        Args:
            outputs ([`VideoMask2FormerForVideoSegmentationOutput`]):
                Raw outputs of the model. `class_queries_logits` is read as `(batch_size, num_queries,
                num_classes + 1)` and `masks_queries_logits` as `(batch_size, num_queries, num_frames, height, width)`.
            threshold (`float`, *optional*, defaults to 0.5):
                Minimum final score (class probability times the mean mask probability over the predicted foreground)
                for a candidate to be kept.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Currently unused; masks are binarized at logit 0 (probability 0.5).
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                Currently unused; kept for API compatibility.
            target_sizes (`List[Tuple]`, *optional*):
                List of length (batch_size), where each item is the requested final `(height, width)` of each
                prediction. If left to None, predictions stay at 256x256.
            return_coco_annotation (`bool`, *optional*, defaults to `False`):
                If set to `True`, segmentation maps are returned in run-length encoding (RLE) format.
            return_binary_maps (`bool`, *optional*, defaults to `False`):
                If set to `True`, segmentation maps are returned as a stacked tensor of binary maps (one per kept
                instance).

        Returns:
            `List[Dict]`: One dictionary per video with the keys:
            - **segmentation** -- A float tensor of shape `(num_frames, height, width)` where each pixel holds a
              `segment_id`, or `-1` where no kept instance covers it. With `return_coco_annotation` it is the RLE
              returned by `convert_segmentation_to_rle`; with `return_binary_maps` (and at least one kept instance) it
              is a tensor of shape `(num_instances, num_frames, height, width)`.
            - **segments_info** -- A list with one dictionary per kept instance:
                - **id** -- An integer representing the `segment_id`.
                - **label_id** -- An integer representing the class id of `segment_id`.
                - **was_fused** -- Always `False`.
                - **score** -- Prediction score of segment with `segment_id`.
            - **query_class_logits** -- Tensor of shape `(num_frames, num_topk_queries, num_classes + 1, height,
              width)`: class probabilities of each top-k candidate times its per-pixel mask probability.

        Raises:
            ValueError: If both `return_coco_annotation` and `return_binary_maps` are `True`.
        """
        if return_coco_annotation and return_binary_maps:
            raise ValueError(
                "return_coco_annotation and return_binary_maps can not be both set to True."
            )

        # [batch_size, num_queries, num_classes+1]
        class_queries_logits = outputs.class_queries_logits
        # [batch_size, num_queries, num_frames, height, width]
        masks_queries_logits = outputs.masks_queries_logits

        device = masks_queries_logits.device
        num_classes = class_queries_logits.shape[-1] - 1
        batch_size, num_queries, num_frames, height, width = masks_queries_logits.shape

        mask_size = (256, 256)
        # `topk` raises if asked for more elements than exist, so cap by the number of (query, class) pairs.
        num_topk_queries = min(10, num_queries * num_classes)

        # Fold frames into the batch dimension so a single 2-D interpolation resizes every frame.
        masks_queries_logits = masks_queries_logits.permute(0, 2, 1, 3, 4).reshape(
            batch_size * num_frames, num_queries, height, width
        )
        masks_queries_logits = torch.nn.functional.interpolate(
            masks_queries_logits, size=mask_size, mode="bilinear", align_corners=False
        )
        masks_queries_logits = masks_queries_logits.view(
            batch_size, num_frames, num_queries, mask_size[0], mask_size[1]
        )

        # Class id of each flattened (query, class) score: flat index f -> f % num_classes. Loop invariant.
        labels = (
            torch.arange(num_classes, device=device)
            .unsqueeze(0)
            .repeat(num_queries, 1)
            .flatten(0, 1)
        )

        results: List[Dict[str, TensorType]] = []
        for i in range(batch_size):
            mask_pred = masks_queries_logits[i]  # [num_frames, num_queries, h, w]
            class_probs = nn.functional.softmax(class_queries_logits[i], dim=-1)  # [num_queries, num_classes+1]
            scores = class_probs[:, :-1]  # drop the "no object" column

            scores_per_image, topk_indices = scores.flatten(0, 1).topk(
                num_topk_queries, sorted=False
            )
            labels_per_image = labels[topk_indices]
            # Map flattened (query, class) indices back to query indices.
            topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
            mask_pred = mask_pred[:, topk_indices, ...]  # [num_frames, num_topk_queries, h, w]
            pred_masks = (mask_pred > 0).float()

            # Mean mask probability over each candidate's predicted foreground, across all frames.
            flat_probs = mask_pred.sigmoid().permute(1, 0, 2, 3).flatten(1)
            flat_masks = pred_masks.permute(1, 0, 2, 3).flatten(1)
            mask_scores_per_image = (flat_probs * flat_masks).sum(1) / (
                flat_masks.sum(1) + 1e-6
            )
            pred_scores = scores_per_image * mask_scores_per_image  # [num_topk_queries]
            pred_classes = labels_per_image

            if target_sizes is not None:
                out_height, out_width = int(target_sizes[i][0]), int(target_sizes[i][1])
                # Nearest-neighbour keeps the masks strictly binary, so the `== 1` test below and the
                # returned binary maps stay exact (bilinear resizing produced fractional edge values).
                pred_masks = nn.functional.interpolate(
                    pred_masks, size=(out_height, out_width), mode="nearest"
                )
            else:
                out_height, out_width = mask_size
            # Allocated on `device` so it can be indexed by masks living on the same device.
            segmentation = torch.full(
                (num_frames, out_height, out_width), -1.0, device=device
            )

            instance_maps, segments = [], []
            current_segment_id = 0
            for j in range(num_topk_queries):
                score = pred_scores[j].item()
                instance_mask = pred_masks[:, j]
                if not torch.all(instance_mask == 0) and score >= threshold:
                    segmentation[instance_mask == 1] = current_segment_id
                    segments.append(
                        {
                            "id": current_segment_id,
                            "label_id": pred_classes[j].item(),
                            "was_fused": False,
                            "score": round(score, 6),
                        }
                    )
                    current_segment_id += 1
                    instance_maps.append(instance_mask)

            # Convert once, after every segment has been painted (converting inside the loop
            # replaced the tensor by a list and broke the next assignment).
            if return_coco_annotation:
                segmentation = convert_segmentation_to_rle(segmentation)
            if return_binary_maps and len(instance_maps) != 0:
                segmentation = torch.stack(instance_maps, dim=0)

            # Confidence map for the top-k candidates: class probabilities and masks share the same ordering.
            if target_sizes is not None:
                mask_logits = nn.functional.interpolate(
                    mask_pred,
                    size=(out_height, out_width),
                    mode="bilinear",
                    align_corners=False,
                )
            else:
                mask_logits = mask_pred
            mask_probs = mask_logits.sigmoid()  # [num_frames, num_topk_queries, height, width]
            topk_class_probs = class_probs[topk_indices]  # [num_topk_queries, num_classes+1]
            confidence_map = (
                topk_class_probs[None, :, :, None, None] * mask_probs[:, :, None, :, :]
            )  # [num_frames, num_topk_queries, num_classes+1, height, width]

            results.append(
                {
                    "segmentation": segmentation,
                    "segments_info": segments,
                    "query_class_logits": confidence_map,
                }
            )
        return results

    def post_process_panoptic_segmentation(
        self,
        outputs,
        threshold: float = 0.5,
        mask_threshold: float = 0.5,
        overlap_mask_area_threshold: float = 0.8,
        label_ids_to_fuse: Optional[Set[int]] = None,
        target_sizes: Optional[List[Tuple[int, int]]] = None,
        word_embeddings: Optional[List[torch.Tensor]] = None,
    ) -> List[Dict]:
        """
        Converts the raw video model outputs into video panoptic segmentation predictions. Only supports PyTorch.

        Args:
            outputs:
                Model outputs. `class_queries_logits` is read as `(batch_size, num_queries, num_classes + 1)` and
                `masks_queries_logits` as `(batch_size, num_queries, num_frames, height, width)`.
            threshold (`float`, *optional*, defaults to 0.5):
                Queries whose best class probability is not above this value are discarded.
            mask_threshold (`float`, *optional*, defaults to 0.5):
                Threshold applied to the score-weighted mask probabilities when computing a query's original area.
            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
                A query becomes a segment only if (pixels it wins in the argmax) / (its original area) exceeds this.
            label_ids_to_fuse (`Set[int]`, *optional*):
                All instances of these labels are fused into a single segment (e.g. one "sky" but several persons).
            target_sizes (`List[Tuple]`, *optional*):
                List of length (batch_size) with the requested final `(height, width)` of each prediction. If left to
                None, predictions stay at 256x256.
            word_embeddings (`List[torch.Tensor]`, *optional*):
                One tensor per video. The query indices given by `argmax(dim=-1)` of each tensor are the only queries
                allowed to form segments.

        Returns:
            `List[Dict]`: One dictionary per video with the keys:
            - **segmentation** -- `torch.int32` tensor of shape `(num_frames, height, width)` holding the `segment_id`
              of each pixel (0 where no segment was assigned). If no query survives filtering, a float tensor filled
              with `-1` is returned instead.
            - **segments_info** -- A list with one dictionary per segment:
                - **id** -- An integer representing the `segment_id`.
                - **label_id** -- An integer representing the class id of `segment_id`.
                - **was_fused** -- `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
                - **score** -- Prediction score of segment with `segment_id`.
            - **query_class_logits** -- Tensor of shape `(num_frames, num_segments, num_classes + 1, height, width)`:
              class probabilities of each segment's query times its per-pixel mask probability. If no query survives,
              shape `(num_frames, 1, num_classes + 1, height, width)` with all weight on the "no object" class.
            - **query_scores** -- List of the scores of the kept segments (`[0.0]` if no query survives).
        """
        if label_ids_to_fuse is None:
            logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
            label_ids_to_fuse = set()

        # [batch_size, num_queries, num_classes+1]
        class_queries_logits = outputs.class_queries_logits
        # [batch_size, num_queries, num_frames, height, width]
        masks_queries_logits = outputs.masks_queries_logits

        device = masks_queries_logits.device
        num_labels = class_queries_logits.shape[-1] - 1
        batch_size, num_queries, num_frames, height, width = masks_queries_logits.shape
        mask_size = (256, 256)

        # Fold frames into the batch dimension so a single 2-D interpolation resizes every frame.
        masks_queries_logits = masks_queries_logits.permute(0, 2, 1, 3, 4).reshape(
            batch_size * num_frames, num_queries, height, width
        )
        masks_queries_logits = torch.nn.functional.interpolate(
            masks_queries_logits, size=mask_size, mode="bilinear", align_corners=False
        )
        masks_queries_logits = masks_queries_logits.view(
            batch_size, num_frames, num_queries, mask_size[0], mask_size[1]
        )

        mask_probs = masks_queries_logits.sigmoid()  # [batch_size, num_frames, num_queries, h, w]
        class_probs = nn.functional.softmax(class_queries_logits, dim=-1)  # [batch_size, num_queries, num_classes+1]
        # Predicted label and score of each query: [batch_size, num_queries]
        pred_scores, pred_labels = class_probs.max(-1)

        results: List[Dict[str, TensorType]] = []
        for i in range(batch_size):
            mask_probs_item = mask_probs[i]
            pred_scores_item = pred_scores[i]
            pred_labels_item = pred_labels[i]
            class_probs_item = class_probs[i]

            preserve_queries = None
            if word_embeddings is not None:
                preserve_queries = torch.unique(word_embeddings[i].argmax(dim=-1))

            # Drop "no object" predictions and low-confidence queries.
            to_keep = pred_labels_item.ne(num_labels) & (pred_scores_item > threshold)
            keep_queries_idx = torch.arange(num_queries, device=device)[to_keep]
            mask_probs_item = mask_probs_item[:, to_keep, ...]  # [num_frames, num_kept, h, w]
            pred_scores_item = pred_scores_item[to_keep]
            pred_labels_item = pred_labels_item[to_keep]
            class_probs_item = class_probs_item[to_keep, ...]  # [num_kept, num_classes+1]

            allowed_queries_idx = keep_queries_idx
            if preserve_queries is not None:
                allowed_queries_idx = keep_queries_idx[
                    torch.isin(keep_queries_idx, preserve_queries)
                ]

            # No usable query: emit an empty result of the requested size.
            if keep_queries_idx.shape[0] <= 0 or allowed_queries_idx.shape[0] <= 0:
                empty_height, empty_width = (
                    target_sizes[i]
                    if target_sizes is not None
                    else mask_probs_item.shape[2:]
                )
                segmentation = (
                    torch.zeros((num_frames, empty_height, empty_width), device=device) - 1
                )
                query_class_logits = torch.zeros(
                    (num_frames, 1, num_labels + 1, empty_height, empty_width), device=device
                )
                query_class_logits[:, 0, -1, :, :] = 1
                results.append(
                    {
                        "segmentation": segmentation,
                        "segments_info": [],
                        "query_class_logits": query_class_logits,
                        "query_scores": [0.0],
                    }
                )
                continue

            if target_sizes is not None:
                mask_probs_item = nn.functional.interpolate(
                    mask_probs_item,
                    size=target_sizes[i],
                    mode="bilinear",
                    align_corners=False,
                )
            # Size taken from the (possibly resized) masks so `target_sizes=None` no longer indexes None.
            out_height, out_width = mask_probs_item.shape[-2:]
            segmentation = torch.zeros(
                (num_frames, out_height, out_width),
                dtype=torch.int32,
                device=mask_probs_item.device,
            )
            segments: List[Dict] = []

            # Weigh each mask by its prediction score, then give each pixel to the strongest query.
            mask_probs_item_weighted = mask_probs_item * pred_scores_item[None, :, None, None]
            mask_labels_item = mask_probs_item_weighted.argmax(1)  # [num_frames, height, width]

            current_segment_id = 0
            # Segment id already assigned to each fused class.
            stuff_memory_list: Dict[int, int] = {}
            kept_positions = []
            kept_scores = []
            for k in range(pred_labels_item.shape[0]):
                pred_class_item = pred_labels_item[k].item()
                should_fuse = pred_class_item in label_ids_to_fuse
                if preserve_queries is not None and keep_queries_idx[k] not in preserve_queries:
                    continue

                # Pixels won by query k, and the area query k would cover on its own.
                mask_k = mask_labels_item == k
                mask_k_area = mask_k.sum()
                original_area = (mask_probs_item_weighted[:, k] >= mask_threshold).sum()
                mask_exists = mask_k_area > 0 and original_area > 0

                # Eliminate disconnected tiny segments.
                if mask_exists:
                    area_ratio = mask_k_area / original_area
                    if not area_ratio.item() > overlap_mask_area_threshold:
                        mask_exists = False

                if not mask_exists:
                    continue

                if pred_class_item in stuff_memory_list:
                    current_segment_id_should_fuse = stuff_memory_list[pred_class_item]
                else:
                    current_segment_id += 1
                    current_segment_id_should_fuse = current_segment_id
                segment_id = current_segment_id_should_fuse if should_fuse else current_segment_id

                segmentation[mask_k] = segment_id
                segment_score = round(pred_scores_item[k].item(), 6)
                segments.append(
                    {
                        "id": segment_id,
                        "label_id": pred_class_item,
                        "was_fused": should_fuse,
                        "score": segment_score,
                    }
                )
                kept_positions.append(k)
                kept_scores.append(segment_score)
                if should_fuse and pred_class_item not in stuff_memory_list:
                    stuff_memory_list[pred_class_item] = current_segment_id

            # Confidence map restricted to the kept segments; selecting before multiplying gives the same
            # values as the full product filtered afterwards, without materialising it for every query.
            kept_index = torch.tensor(
                kept_positions, dtype=torch.long, device=mask_probs_item.device
            )
            query_class_logits = (
                class_probs_item.index_select(0, kept_index)[None, :, :, None, None]
                * mask_probs_item.index_select(1, kept_index)[:, :, None, :, :]
            )  # [num_frames, num_segments, num_classes+1, height, width]
            results.append(
                {
                    "segmentation": segmentation,
                    "segments_info": segments,
                    "query_class_logits": query_class_logits,
                    "query_scores": kept_scores,
                }
            )
        return results
