import logging
import math
from typing import Tuple

import torchvision.transforms as T
from torchvision.ops.boxes import batched_nms

from transformers import AutoProcessor, Blip2Config, Blip2ForConditionalGeneration

import torch
from torch import nn
from torch.nn import functional as F
import pdb
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
from detectron2.modeling.backbone import Backbone
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.structures import Boxes, ImageList, Instances, BitMasks

import numpy as np
from transformers.generation.utils import GenerationMode
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
import random
from .modeling.criterion import VideoSetCriterion
from .modeling.matcher import VideoHungarianMatcher
from .utils.memory import retry_if_cuda_oom
from detectron2.structures import Boxes
from torchvision.utils import save_image
import time

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


@META_ARCH_REGISTRY.register()
class VideoMaskFormer(nn.Module):
    """
    Main class for mask classification semantic segmentation architectures.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        criterion: nn.Module,
        num_queries: int,
        object_mask_threshold: float,
        overlap_threshold: float,
        metadata,
        size_divisibility: int,
        sem_seg_postprocess_before_inference: bool,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
        # video
        num_frames,
        cfg,
        generate_captions,
        queries_2types,
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            criterion: a module that defines the loss
            num_queries: int, number of queries
            object_mask_threshold: float, threshold to filter query based on classification score
                for panoptic segmentation inference
            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
            metadata: dataset meta; its ``name`` selects the dataset-specific code paths
                in ``forward`` and it provides ``thing_classes``
            size_divisibility: Some backbones require the input height and width to be divisible by a
                specific integer. We can use this to override such requirement. A negative
                value means "use ``backbone.size_divisibility``".
            sem_seg_postprocess_before_inference: whether to resize the prediction back
                to original input size before semantic segmentation inference or after.
                For high-resolution dataset like Mapillary, resizing predictions before
                inference will cause OOM error.
            pixel_mean, pixel_std: list or tuple with #channels element, representing
                the per-channel mean and std to be used to normalize the input image
            num_frames: int, number of frames per clip; inference processes videos
                in clips of this length.
            cfg: the full config node, kept for flags read at runtime
                (e.g. BB_BASELINE, MASKED_PROMPT, MASKED_ATTENTION).
            generate_captions: bool, whether to load the BLIP-2 captioning model and
                the object-query -> language-embedding projection.
            queries_2types: forwarded to the criterion as ``openworld`` when training
                on VIS/BURST datasets.
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        self.criterion = criterion
        self.num_queries = num_queries
        self.overlap_threshold = overlap_threshold
        self.object_mask_threshold = object_mask_threshold
        self.metadata = metadata
        if size_divisibility < 0:
            # use backbone size_divisibility if not set
            size_divisibility = self.backbone.size_divisibility
        self.size_divisibility = size_divisibility
        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
        # Non-persistent buffers: they follow .to(device) but are not saved in checkpoints.
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
        self.cfg = cfg
        self.generate_captions = generate_captions
        self.queries_2types = queries_2types
        self.num_frames = num_frames

        if self.generate_captions:
            blip_name = "Salesforce/blip2-opt-2.7b"
            self.blip_processor = AutoProcessor.from_pretrained(blip_name)
            # `from_pretrained` is a classmethod: no need to build a default config first.
            self.blip_config = Blip2Config.from_pretrained(blip_name)
            self.blip_model = Blip2ForConditionalGeneration.from_pretrained(blip_name)
            # Freeze the BLIP-2 vision and language models and keep them in fp16.
            self.blip_model.language_model.requires_grad_(False)
            self.blip_model.vision_model.requires_grad_(False)
            self.blip_model.language_model = self.blip_model.language_model.half()
            self.blip_model.vision_model = self.blip_model.vision_model.half()

            # Maps 256-d object queries to the 2560-d language-model embedding width.
            self.caption_projection = nn.Linear(256, 2560)

    @classmethod
    def from_config(cls, cfg):
        """
        Translate a detectron2 config into the keyword arguments of ``__init__``.

        Builds the backbone, the segmentation head, the Hungarian matcher and the
        video set criterion (with per-loss weights, including one copy per
        intermediate decoder layer when deep supervision is on).
        """
        mf = cfg.MODEL.MASK_FORMER

        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())

        ce_w = mf.CLASS_WEIGHT
        dice_w = mf.DICE_WEIGHT
        mask_w = mf.MASK_WEIGHT

        matcher = VideoHungarianMatcher(
            cost_class=ce_w,
            cost_mask=mask_w,
            cost_dice=dice_w,
            num_points=mf.TRAIN_NUM_POINTS,
        )

        base_weights = {"loss_ce": ce_w, "loss_mask": mask_w, "loss_dice": dice_w}
        weight_dict = dict(base_weights)
        if mf.DEEP_SUPERVISION:
            # Auxiliary losses are suffixed with the decoder-layer index.
            weight_dict.update(
                {
                    f"{loss_name}_{layer}": weight
                    for layer in range(mf.DEC_LAYERS - 1)
                    for loss_name, weight in base_weights.items()
                }
            )

        loss_names = ["labels", "masks"]
        if mf.CONTRASTIVE:
            loss_names.append("contrastive")
            weight_dict["loss_contrastive"] = 0.01

        criterion = VideoSetCriterion(
            sem_seg_head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=mf.NO_OBJECT_WEIGHT,
            losses=loss_names,
            num_points=mf.TRAIN_NUM_POINTS,
            oversample_ratio=mf.OVERSAMPLE_RATIO,
            importance_sample_ratio=mf.IMPORTANCE_SAMPLE_RATIO,
            contrastive=mf.CONTRASTIVE,
        )

        kwargs = {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "criterion": criterion,
            "num_queries": mf.NUM_OBJECT_QUERIES,
            "object_mask_threshold": mf.TEST.OBJECT_MASK_THRESHOLD,
            "overlap_threshold": mf.TEST.OVERLAP_THRESHOLD,
            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
            "size_divisibility": mf.SIZE_DIVISIBILITY,
            "sem_seg_postprocess_before_inference": True,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
            # video
            "num_frames": cfg.INPUT.SAMPLING_FRAME_NUM,
            "cfg": cfg,
            "generate_captions": cfg.INPUT.CAPTIONS,
            "queries_2types": cfg.QUERIES_2TYPES,
        }
        return kwargs

    @property
    def device(self):
        """Device the model lives on, read from the ``pixel_mean`` buffer."""
        mean_buffer = self.pixel_mean
        return mean_buffer.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one video (or image).
                For now, each item in the list is a dict that contains:
                   * "image": Tensor in (C, H, W) format, or a list / stacked
                     tensor of such frames.
                   * "instances": per-frame ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be different
                     from input resolution), used in inference.

        The dataset-specific path is selected by substrings of
        ``self.metadata.name``: "vis"/"burst", "vg_" and "vidstg".

        Returns:
            In training: dict of weighted losses (losses without an entry in
            ``self.criterion.weight_dict`` are dropped), or ``None`` when the
            remaining losses sum to zero.
            In evaluation: for "vis"/"burst" the output of ``self.inference_video``;
            for "vg_" the results computed on the first clip; otherwise a list with
            one list of per-frame results per video (filled for "vidstg").

        Raises:
            ValueError: in training, if the dataset name matches no supported family.
            FloatingPointError: in training, if a weighted loss is NaN.
            NotImplementedError: in "vg_" evaluation, if none of cfg.BB_BASELINE,
                cfg.MASKED_ATTENTION or cfg.MASKED_PROMPT is enabled.
        """
        dataset_name = self.metadata.name
        is_vis_or_burst = "vis" in dataset_name or "burst" in dataset_name
        is_vg = "vg_" in dataset_name
        is_vidstg = "vidstg" in dataset_name

        images = []

        if self.training:
            # The backbone runs on all frames of all videos at once, so flatten
            # the batch into a single list of frames.
            for video in batched_inputs:
                if isinstance(video["image"], list) or len(video["image"].shape) > 3:
                    for frame in video["image"]:
                        images.append(frame.to(self.device))
                else:
                    images.append(video["image"].to(self.device))

            images = [(x - self.pixel_mean) / self.pixel_std for x in images]
            images = ImageList.from_tensors(images, self.size_divisibility)
            features = self.backbone(images.tensor)

            outputs = self.sem_seg_head(features)
            if "pred_boxes" not in outputs.keys():
                outputs = self.masks_to_boxes(outputs)

            losses = None
            if is_vis_or_burst:
                targets = self.prepare_targets_vidstg(batched_inputs, images)
                losses, indices = self.criterion(
                    outputs, targets, openworld=self.queries_2types
                )

            if is_vg:
                targets = self.prepare_targets_vg(batched_inputs, images)
                losses, indices = self.criterion(outputs, targets)
                if self.cfg.BB_BASELINE:
                    losses["loss_embedding"] = self.get_bb_baseline_results_loss_fromgt(
                        batched_inputs, images, [], visualize=True
                    )
                elif self.cfg.MASKED_PROMPT:
                    losses["loss_embedding"] = self.vg_loss_maskprompt(
                        batched_inputs, images, [], visualize=True
                    )

            if is_vidstg:
                targets = self.prepare_targets_vidstg(batched_inputs, images)
                losses, indices = self.criterion(outputs, targets)
                if self.generate_captions:
                    losses["loss_embedding"] = self.get_embedding_loss(
                        batched_inputs, images, outputs, indices
                    )

            if losses is None:
                raise ValueError(
                    f"Unsupported training dataset {dataset_name!r}: expected a name "
                    "containing 'vis', 'burst', 'vg_' or 'vidstg'."
                )

            # Apply loss weights; losses without a weight are dropped.
            # NOTE(review): this also drops "loss_embedding" unless the criterion's
            # weight_dict contains that key; from_config in this file does not add
            # it. Confirm whether the criterion registers it elsewhere.
            for k in list(losses.keys()):
                if k in self.criterion.weight_dict:
                    losses[k] *= self.criterion.weight_dict[k]
                    if torch.isnan(losses[k]):
                        # Fail loudly instead of blocking on an interactive debugger.
                        raise FloatingPointError(f"Loss {k!r} is NaN.")
                else:
                    losses.pop(k)

            if sum(losses.values()) == 0:
                return None
            return losses

        # ------------------------------------------------------------ inference
        vid_lens = [len(k["image"]) for k in batched_inputs]
        time_gap = self.num_frames

        mask_cls_results_all = []
        mask_pred_results_all = []

        # Round every video length up to a multiple of the clip length.
        vid_lens_modified = [
            vid_len
            if vid_len % time_gap == 0
            else vid_len + (time_gap - vid_len % time_gap)
            for vid_len in vid_lens
        ]
        max_len = max(vid_lens_modified, default=0)

        for video in batched_inputs:
            if isinstance(video["image"], list) or len(video["image"].shape) > 3:
                frames = video["image"]
                # Pad every video to max_len frames by repeating its last frame.
                images.append(
                    [
                        frames[frame_id] if frame_id < len(frames) else frames[-1]
                        for frame_id in range(max_len)
                    ]
                )
            else:
                images.append(video["image"])

        images = [
            [
                (x - self.pixel_mean.to(x.device)) / self.pixel_std.to(x.device)
                for x in y
            ]
            for y in images
        ]
        images = ImageList.from_tensors(
            [torch.stack(x) for x in images], self.size_divisibility
        )  # bs, vidlen, 3, h, w

        results_all = [[] for _ in batched_inputs]

        # State carried from one clip to the next; empty lists before the first clip.
        prev_output = []
        prev_masks = []
        # batch size, video_length (padded), channels, h, w
        bs, vl, c, h_pad, w_pad = images.tensor.shape

        if is_vis_or_burst:
            padded_size = [h_pad, w_pad]
            unpadded_sizes = [k["image"][0].shape[-2:] for k in batched_inputs]
            output_sizes = [[k["height"], k["width"]] for k in batched_inputs]

        # `images` is deliberately kept alive. The "vg_" branch below passes it on,
        # and `imgs` is normally a view of images.tensor, so deleting it would not
        # free memory.
        imgs = images.tensor.reshape(bs, -1, time_gap, c, h_pad, w_pad)
        # [video_length / time_gap, bs, time_gap, c, h, w]: iterate over clips.
        imgs = imgs.transpose(1, 0)
        popped_batches = []

        for img_i, img in enumerate(imgs):
            print("img:", img_i * self.num_frames, "/", max_len)
            start = img_i * self.num_frames
            finish = (img_i + 1) * self.num_frames

            if is_vidstg:
                # Only videos that still have frames at `start` are processed.
                keep_vids = [k > start for k in vid_lens]
                # Per-clip copy of the kept inputs. "image"/"instances" are sliced to
                # this clip and padded to num_frames by repeating the last entry.
                batched_input_temp = []
                for current, keep in zip(batched_inputs, keep_vids):
                    if not keep:
                        continue
                    clip_input = {}
                    for key in current.keys():
                        if key in ["image", "instances"]:
                            clip_input[key] = current[key][start:finish]
                            while len(clip_input[key]) < self.num_frames:
                                clip_input[key].append(clip_input[key][-1])
                        else:
                            clip_input[key] = current[key]
                    batched_input_temp.append(clip_input)
            else:
                keep_vids = [True] * len(vid_lens)

            img = img[keep_vids]

            # Drop the carried state of videos that finished in an earlier clip.
            remaining = [
                k for i, k in enumerate(keep_vids) if i not in popped_batches
            ]
            if isinstance(prev_output, torch.Tensor):
                prev_output = prev_output[:, remaining]
            if isinstance(prev_masks, torch.Tensor):
                prev_masks = prev_masks[remaining]

            features = self.backbone(img.to(self.device).flatten(0, 1))
            outputs = self.sem_seg_head(
                features, prev_output=prev_output, mask=prev_masks
            )

            prev_masks = outputs["pred_masks"]
            prev_output = outputs["obj_queries"].transpose(0, 1)
            popped_batches = list(
                set(popped_batches + [i for i, k in enumerate(keep_vids) if not k])
            )

            if is_vg:
                if self.cfg.BB_BASELINE:
                    results, annotations, loss = self.get_bb_baseline_results_loss_fromgt(
                        batched_inputs,
                        images,
                        [],
                        visualize=True,
                        training=False,
                    )
                elif self.cfg.MASKED_ATTENTION:
                    results, annotations, loss = self.get_vidstg_eval(
                        batched_inputs, images, outputs
                    )
                elif self.cfg.MASKED_PROMPT:
                    results, annotations, loss = self.vg_loss_maskprompt(
                        batched_inputs,
                        images,
                        [],
                        visualize=True,
                        training=False,
                    )
                else:
                    raise NotImplementedError(
                        "vg_ evaluation requires cfg.BB_BASELINE, cfg.MASKED_ATTENTION "
                        "or cfg.MASKED_PROMPT to be enabled."
                    )
                return results

            if is_vidstg:
                results, annotations, loss = self.get_vidstg_eval(
                    batched_input_temp,
                    ImageList.from_tensors(list(img), self.size_divisibility),
                    outputs,
                    start_frame=start,
                )
                # Append this clip's results to each kept video (in order) and trim
                # frames that only exist because of padding.
                count_res = 0
                for bs_res, vid_len in enumerate(vid_lens):
                    if keep_vids[bs_res]:
                        results_all[bs_res] = results_all[bs_res] + results[count_res]
                        count_res += 1
                    if len(results_all[bs_res]) > vid_len:
                        results_all[bs_res] = results_all[bs_res][:vid_len]

            if is_vis_or_burst:
                # Logits are repeated once per frame of the clip so they can be
                # concatenated along the time dim (dim 2) like the masks.
                for _ in range(time_gap):
                    mask_cls_results_all.append(
                        outputs["pred_logits"].unsqueeze(-2).cpu()
                    )  # bs, q, 1, numclass+1
                mask_pred_results_all.append(
                    outputs["pred_masks"].cpu()
                )  # bs, q, t, h, w

        if is_vis_or_burst:
            class_all = torch.concatenate(mask_cls_results_all, dim=2)
            mask_all = torch.concatenate(mask_pred_results_all, dim=2)
            return self.inference_video(
                class_all,
                mask_all,
                unpadded_sizes,
                padded_size,
                output_sizes,
                vid_lens,
            )

        return results_all

    def prepare_targets(self, targets, images):
        """
        Convert the per-frame ``Instances`` of each video into one target dict per video.

        Requires ``self.blip_processor``, which ``__init__`` only creates when
        ``generate_captions`` is enabled.

        Args:
            targets: list with one dict per video; its "instances" entry is a list of
                per-frame ``Instances`` carrying ``gt_ids``, ``gt_classes`` and
                ``gt_masks``.
            images: ``ImageList``; only the padded height/width of its tensor are used.

        Returns:
            list[dict]: one dict per video with keys "labels", "ids", "label_texts",
            "text_tokens" and "masks" (float tensor of shape
            [num_valid_instances, self.num_frames, h_pad, w_pad]).
        """
        h_pad, w_pad = images.tensor.shape[-2:]
        gt_instances = []
        for targets_per_video in targets:
            _num_instance = len(targets_per_video["instances"][0])
            mask_shape = [_num_instance, self.num_frames, h_pad, w_pad]
            gt_masks_per_video = torch.zeros(
                mask_shape, dtype=torch.bool, device=self.device
            )
            gt_ids_per_video = []
            for f_i, targets_per_frame in enumerate(targets_per_video["instances"]):
                targets_per_frame = targets_per_frame.to(self.device)
                h, w = targets_per_frame.image_size
                gt_ids_per_video.append(targets_per_frame.gt_ids[:, None])
                gt_masks_per_video[:, f_i, :h, :w] = targets_per_frame.gt_masks.tensor

            gt_ids_per_video = torch.cat(gt_ids_per_video, dim=1)
            # Keep an instance if its id is not -1 in at least one frame.
            valid_idx = (gt_ids_per_video != -1).any(dim=-1)
            # Classes come from the last frame; presumably constant across frames.
            gt_classes_per_video = targets_per_frame.gt_classes[valid_idx]
            label_texts_per_video = [
                self.metadata.thing_classes[r].replace("_", " ")
                for r in gt_classes_per_video
            ]
            label_tokens_per_video = self.blip_processor.tokenizer(
                label_texts_per_video
            )
            gt_instances.append(
                {
                    "labels": gt_classes_per_video,
                    "ids": gt_ids_per_video[valid_idx],
                    "label_texts": label_texts_per_video,
                    "text_tokens": label_tokens_per_video,
                    "masks": gt_masks_per_video[valid_idx].float(),
                }
            )

        return gt_instances

    def prepare_targets_vidstg(self, targets, images):  # also works for burst
        """
        Build per-video training targets from per-frame ``Instances``.

        NOTE(review): relies on a module-level ``box_ops`` that is not imported in
        this file's import block; confirm it is in scope.

        Args:
            targets: a list of length bs.
                Each item in the list is a dict that contains:
                   * "image": list of frames; only each frame's last two dims (h, w) are read.
                   * "instances": list of per-frame ``Instances`` with the following fields:
                        gt_boxes: xyxy boxes of shape [num_instances, 4]
                        gt_ids: tensor of shape [num_instances]; -1 marks an invalid/absent entry
                        gt_classes: tensor of shape [num_instances]. Contains class id.
                        gt_masks (optional): per-frame masks. When absent, masks are
                            rasterised from the boxes.
                        gt_phrases (optional): descriptions for each object.
            images: ``ImageList``; only the padded height/width of its tensor are used.

        Mask availability is decided from the first frame and assumed to hold for
        every frame of the video.

        Returns:
            list[dict]: one dict per video with keys
                "ids": [num_valid, t] instance ids per frame,
                "labels": [num_valid] class ids (taken from the last frame),
                "boxes": [num_valid, t, 4] boxes divided by the padded size and
                    converted with ``box_xyxy_to_cxcywh``,
                "masks": float masks of the valid instances,
                "gt_phrases": list with the ``gt_phrases`` of every frame that has them.
        """
        h_pad, w_pad = images.tensor.shape[-2:]

        image_size_xyxy = torch.as_tensor(
            [w_pad, h_pad, w_pad, h_pad], dtype=torch.float, device=self.device
        )

        new_targets = []

        for targets_per_video in targets:
            frame_images = targets_per_video["image"]
            frame_instances = targets_per_video["instances"]
            vid_gt_boxes = []
            vid_gt_masks = []
            gt_ids_per_video = []
            captions_per_video = []

            has_gt_masks = frame_instances[0].has("gt_masks")
            if has_gt_masks:
                mask_shape = [len(frame_instances[0]), self.num_frames, h_pad, w_pad]
                gt_masks_per_video = torch.zeros(
                    mask_shape, dtype=torch.bool, device=self.device
                )

            for frame in range(len(frame_images)):
                frame_inst = frame_instances[frame]
                h, w = frame_images[frame].shape[-2:]
                gt_boxes = frame_inst.gt_boxes.tensor.float().to(self.device)
                gt_boxes = box_ops.box_xyxy_to_cxcywh(gt_boxes / image_size_xyxy)
                vid_gt_boxes.append(gt_boxes)

                if has_gt_masks:
                    gt_masks_per_video[:, frame, :h, :w] = frame_inst.gt_masks.tensor
                else:
                    vid_gt_masks.append(
                        box_ops.boxes_to_mask(
                            gt_boxes.unsqueeze(1),
                            original_size=(h, w),
                            final_size=(h_pad, w_pad),
                            resize_or_pad="pad",
                        )
                    )

                gt_ids_per_video.append(frame_inst.gt_ids[:, None])
                if frame_inst.has("gt_phrases"):
                    captions_per_video.append(frame_inst.gt_phrases)
                # Overwritten every frame: the last frame's classes are used.
                labels = frame_inst.gt_classes

            gt_ids_per_video = torch.cat(gt_ids_per_video, dim=1)
            # Keep an instance if its id is not -1 in at least one frame.
            valid_idx = (gt_ids_per_video != -1).any(dim=-1)

            if not has_gt_masks:
                gt_masks_per_video = (
                    torch.stack(vid_gt_masks).squeeze(2).squeeze(2).transpose(1, 0)
                )

            gt_masks_per_video = gt_masks_per_video[valid_idx]

            new_targets.append(
                {
                    "ids": gt_ids_per_video[valid_idx],
                    "labels": labels[valid_idx].to(self.device),
                    "boxes": torch.stack(vid_gt_boxes).transpose(1, 0)[valid_idx],
                    "masks": gt_masks_per_video.float(),
                    "gt_phrases": captions_per_video,
                }
            )

        return new_targets

    def masks_to_boxes(self, outputs):
        """
        Add box predictions derived from the mask predictions.

        ``outputs`` is modified in place and also returned.

        Args:
            outputs: a dict with keys:
                'obj_queries': bs, nq, c
                'pred_logits': bs, nq, K+1 (K being the number of classes, +1 for the no object)
                'pred_masks': bs, nq, t, h, w
                'aux_outputs' (optional): list with one entry per intermediate decoder layer.
                    Each element is a dict with keys ['obj_queries', 'pred_logits', 'pred_masks'].
        Returns:
            outputs: the same dict, with 'pred_boxes' (bs, nq, t, 4; centered and
                normalized) added at the top level and to every 'aux_outputs' entry.
        """
        outputs["pred_boxes"] = self.masks_to_boxes_single(outputs["pred_masks"])
        # Deep-supervision outputs may be absent; treat that as "no aux outputs".
        for aux in outputs.get("aux_outputs", []):
            aux["pred_boxes"] = self.masks_to_boxes_single(aux["pred_masks"])
        return outputs

    def masks_to_boxes_single(self, outputs_masks):
        """
        Turn 5-D mask logits into normalized, center-format boxes.

        Pixels with a logit > 0 count as foreground. The two middle dimensions
        (query / time, e.g. ``bs, nq, t, h, w``) are carried through unchanged.

        Args:
            outputs_masks: tensor of shape [bs, A, B, h, w].
        Returns:
            Tensor of shape [bs, A, B, 4]: boxes divided by (w, h, w, h) and
            converted with ``box_ops.box_xyxy_to_cxcywh``.
        """
        dim0, dim1, dim2, height, width = outputs_masks.shape
        scale = torch.tensor([width, height, width, height]).to(outputs_masks.device)

        foreground = (outputs_masks > 0).flatten(0, 2)
        xyxy = box_ops.masks_to_boxes(foreground).reshape(dim0, dim1, dim2, 4)
        return box_ops.box_xyxy_to_cxcywh(xyxy / scale)

    def get_blip_inputs(self, batched_inputs, outputs, boxes):
        """
        Prepare BLIP-2 inputs (image embeddings, projected object queries, Q-Former
        query tokens and box-based attention masks) for the predicted objects.

        Args:
            batched_inputs: list of per-video dicts whose "image" entry is either a
                single image tensor or a list / stacked tensor of frames.
            outputs: model outputs; "obj_queries" and "pred_masks" are read.
            boxes: boxes rasterised by ``box_ops.boxes_to_mask`` relative to the
                spatial size of ``outputs["pred_masks"]``.

        Returns:
            image_embeds: ``last_hidden_state`` of the BLIP-2 vision model, one row
                per frame (videos flattened into the batch dimension).
            inputs_embeds_obj: ``obj_queries`` projected by ``self.caption_projection``.
            query_tokens: ``self.blip_model.query_tokens`` expanded to one copy per frame.
            image_attention_masks: box masks on a 16x16 grid, flattened, with one
                extra zero entry appended on the last dimension (float).
        """

        def to_pixel_values(image):
            # to_pil_image copies the tensor to the CPU itself, so there is no need
            # to move it to self.device first.
            pil_image = T.functional.to_pil_image(image, mode=None)
            return self.blip_processor(pil_image, return_tensors="pt").to(self.device)[
                "pixel_values"
            ]

        images_temp = []
        for video in batched_inputs:
            if isinstance(video["image"], list) or len(video["image"].shape) > 3:
                frames = [to_pixel_values(frame) for frame in video["image"]]
                images_temp.append(torch.concatenate(frames))
            else:
                images_temp.append(to_pixel_values(video["image"]))

        images = torch.stack(images_temp)

        inputs_embeds_obj = self.caption_projection(outputs["obj_queries"])
        image_embeds = self.blip_model.vision_model(
            images.flatten(0, 1), return_dict=True
        ).last_hidden_state
        query_tokens = self.blip_model.query_tokens.expand(
            image_embeds.shape[0], -1, -1
        )

        # 16x16 presumably matches the vision model's patch grid (224 px / patch
        # size 14 for BLIP-2's ViT); confirm if the BLIP-2 checkpoint changes.
        masks = box_ops.boxes_to_mask(
            boxes,
            original_size=(outputs["pred_masks"].shape[-2:]),
            final_size=(torch.Size([16, 16])),
        ).unsqueeze(-3)
        flattened_mask_xy = (masks > 0).float().flatten(-2, -1).to(image_embeds.device)
        # One extra masked-out position, presumably for the vision model's class
        # token, so the mask length matches the image-embedding sequence length.
        # NOTE(review): Hugging Face's BLIP-2 vision embeddings put the class token
        # FIRST, so appending it last may shift every patch mask by one position.
        # Confirm before changing, since existing checkpoints were trained this way.
        image_attention_masks = torch.cat(
            [flattened_mask_xy, torch.zeros_like(flattened_mask_xy[..., :1])],
            dim=-1,
        )
        return image_embeds, inputs_embeds_obj, query_tokens, image_attention_masks

    def blip_qformer_process(
        self,
        img_emb,
        query_tokens_i,
        image_attn_mask,
        inputs_embeds_obj=None,
        target=None,
        training=True,
    ):
        """
        Build language-model input embeddings from BLIP-2 Q-Former outputs.

        The Q-Former attends to ``img_emb`` (restricted by ``image_attn_mask``),
        its outputs are projected into the language-model embedding space,
        optionally followed by one object embedding per row, and finally
        followed by text-token embeddings: one randomly sampled caption per row
        when training, a single BOS token otherwise.

        Args:
            img_emb: image features passed as ``encoder_hidden_states``; callers
                pass a single frame, presumably [1, hw, f_img] (TODO confirm).
            query_tokens_i: Q-Former query tokens, [n, q_text, f_q].
            image_attn_mask: Q-Former encoder attention mask, presumably [n, 1, hw].
            inputs_embeds_obj: optional per-row object embedding, [n, 1, f_lang].
            target: training only. Sequence of n token-id tensors; element ``k``
                is [num_captions_k, len_text_tokens] and one caption row is
                sampled uniformly from it. All elements must share
                ``len_text_tokens`` because the samples are concatenated.
            training: if True embed a sampled caption and also return the
                attention mask and token ids; otherwise embed only BOS.

        Returns:
            training=False: ``inputs_embeds_final`` of shape
                [n, q_text + 1 + 1, f_lang] with ``inputs_embeds_obj``, or
                [n, q_text + 1, f_lang] without it.
            training=True: ``(inputs_embeds_final, attention_mask, input_id)``:
                ``inputs_embeds_final`` [n, q_text (+1) + len_text_tokens, f_lang],
                ``attention_mask`` [n, q_text (+1) + len_text_tokens] and
                ``input_id`` [n, len_text_tokens] (the "+1" only with an object
                embedding).
        """
        device = img_emb.device
        embed_tokens = self.blip_model.get_input_embeddings()

        if training:
            if target is None:
                target = []
            # Sample one caption row per object.
            input_id = torch.cat(
                [t[random.sample(range(len(t)), 1)] for t in target]
            ).to(device)  # [n, len_text_tokens]
            input_embed = embed_tokens(input_id)
            # Token id 1 is treated as padding (the caption loss also uses
            # ignore_index=1; presumably the tokenizer's pad id — TODO confirm):
            # zero its embedding and mask it out of attention.
            pad = input_id == 1
            input_embed[pad] = 0
            text_attention_mask = torch.ones_like(input_id).masked_fill_(pad, 0)
        else:
            bos_id = self.blip_model.config.text_config.bos_token_id
            input_id = torch.full(
                (len(query_tokens_i), 1), bos_id, dtype=torch.long, device=device
            )  # [n, 1]
            input_embed = embed_tokens(input_id)

        query_output = self.blip_model.qformer(
            query_embeds=query_tokens_i,
            encoder_hidden_states=img_emb,
            encoder_attention_mask=image_attn_mask,
            return_dict=True,
        ).last_hidden_state  # [n, q_text, f_q]
        language_model_inputs = self.blip_model.language_projection(query_output)
        lm_device = language_model_inputs.device

        # Sequence layout: [query outputs | (object embedding) | text tokens].
        embed_parts = [language_model_inputs]
        mask_parts = [
            torch.ones(
                language_model_inputs.shape[:-1], dtype=torch.long, device=lm_device
            )
        ]
        if inputs_embeds_obj is not None:
            embed_parts.append(inputs_embeds_obj)
            mask_parts.append(
                torch.ones(
                    inputs_embeds_obj.shape[:-1],
                    dtype=torch.long,
                    device=inputs_embeds_obj.device,
                )
            )
        embed_parts.append(input_embed.to(lm_device))
        inputs_embeds_final = torch.cat(embed_parts, dim=1)

        if not training:
            return inputs_embeds_final

        mask_parts.append(text_attention_mask.to(input_id.device))
        attention_mask = torch.cat(mask_parts, dim=1)
        return inputs_embeds_final, attention_mask, input_id

    @staticmethod
    def _vidstg_frame_targets(instances):
        """
        Flatten one frame's ground truth into parallel box / caption / id lists.

        Args:
            instances: the frame's instances, with ``gt_boxes`` and ``gt_ids``
                and optionally ``gt_phrases``.

        Returns:
            ``(target_boxes, target_texts, target_ids)``. Captions are wrapped
            as ``"<bos> ... <eos>"``. An object with several phrases yields one
            entry per phrase. An object without phrases yields an empty caption,
            unless its id is -1, in which case it is skipped. If ``gt_phrases``
            is a single string, the frame yields one entry holding all boxes and
            all ids.
        """
        gt_box = instances.gt_boxes.tensor
        gt_ids = instances.gt_ids.cpu().tolist()
        if "gt_phrases" in instances._fields:
            gt_text = instances.gt_phrases
        else:
            gt_text = [[] for _ in gt_box]

        if isinstance(gt_text, str):
            return [gt_box], ["<bos> " + gt_text + " <eos>"], [gt_ids]

        target_boxes, target_texts, target_ids = [], [], []
        for obj_idx, phrases in enumerate(gt_text):
            if isinstance(phrases, str):
                phrases = [phrases]
            elif phrases == []:
                if gt_ids[obj_idx] == -1:
                    continue
                phrases = [""]
            for phrase in phrases:
                target_texts.append("<bos> " + phrase + " <eos>")
                target_boxes.append(gt_box[obj_idx])
                target_ids.append(gt_ids[obj_idx])
        return target_boxes, target_texts, target_ids

    def get_vidstg_eval(
        self,
        batched_inputs,
        images_original,
        outputs,
        threshold=0.1,
        start_frame=0,
        frame_for_cap=1,
    ):
        """
        Collect per-frame predictions (and optional captions) next to the ground truth.

        A query is kept when its best foreground class probability exceeds 0.5.
        Its box is rescaled to ``images_original.image_sizes[i]``. A caption is
        generated with the BLIP-2 language model on frames where
        ``self.generate_captions`` is set and
        ``(start_frame + frame_idx) % frame_for_cap == 0``.

        Args:
            batched_inputs: list of per-video dicts with ``"image"`` (frames),
                ``"instances"`` (one per frame) and ``"image_id"`` or ``"video_id"``.
            images_original: object exposing ``image_sizes`` as (h, w) per batch
                element (presumably a detectron2 ``ImageList``).
            outputs: model outputs with ``pred_masks`` [bs, nq, nf, h, w],
                ``pred_boxes`` (cxcywh) and ``pred_logits``.
            threshold: unused; the score threshold is fixed at 0.5.
            start_frame: index of this clip's first frame within the video.
            frame_for_cap: caption every ``frame_for_cap``-th frame.

        Returns:
            ``(output_batchwise, [], 0.0)``. ``output_batchwise[i][f]`` is a dict:
            ``scores``, ``boxes``, ``caps`` (-1 on frames that are not captioned)
            and ``obj_ids`` (query indices) for predictions; ``gt_ids``,
            ``target_boxes`` and ``target_texts`` for ground truth; plus
            ``frame`` and ``image_id``.

        Raises:
            KeyError: if a batch element has neither ``"image_id"`` nor ``"video_id"``.
        """
        bs, nq, nf, ht, wt = outputs["pred_masks"].shape
        box_scale = torch.tensor([wt, ht, wt, ht]).to(self.device)
        boxes = box_ops.box_cxcywh_to_xyxy(outputs["pred_boxes"]) * box_scale
        # Best foreground probability per query (the last logit is "no object").
        scores = outputs["pred_logits"].softmax(-1)[..., :-1].max(-1).values

        def caption_frame(frame_idx):
            return bool(self.generate_captions) and (
                (start_frame + frame_idx) % frame_for_cap == 0
            )

        frame_counts = [len(inputs["image"]) for inputs in batched_inputs]
        # Run the BLIP image encoder whenever *any* frame will be captioned, so
        # the per-frame captioning below always has encoder outputs available.
        if any(caption_frame(k) for count in frame_counts for k in range(count)):
            image_embeds, inputs_embeds_obj, query_tokens, image_attention_masks = (
                self.get_blip_inputs(batched_inputs, outputs, boxes)
            )

        output_batchwise = []
        for i, inputs_per_batch in enumerate(batched_inputs):
            boxes_per_batch = boxes[i][:, 0] if boxes[i].dim() == 4 else boxes[i]
            top_indices = torch.where(scores[i] > 0.5)[0]
            top_scores = scores[i][top_indices]
            top_query_ids = top_indices.tolist()
            image_size = images_original.image_sizes[i]
            num_frames = frame_counts[i]

            output_per_frame = []
            for i2 in range(num_frames):
                # Row in the BLIP features, which are flattened over
                # (batch, frame); assumes all videos have ``num_frames`` frames.
                frame_i2 = i * num_frames + i2
                with_caption = caption_frame(i2)
                predicted_boxes, predicted_texts, predicted_scores = [], [], []

                # Test emptiness, not the index sum: query 0 alone must count.
                if top_query_ids:
                    if with_caption:
                        inputs_embeds_final = self.blip_qformer_process(
                            img_emb=image_embeds[frame_i2 : frame_i2 + 1],
                            query_tokens_i=query_tokens[frame_i2 : frame_i2 + 1].repeat(
                                len(top_query_ids), 1, 1
                            ),
                            image_attn_mask=image_attention_masks[i][top_indices, i2],
                            inputs_embeds_obj=inputs_embeds_obj[
                                i, top_indices
                            ].unsqueeze(1),
                            training=False,
                        )
                        generated = self.blip_model.language_model.generate(
                            inputs_embeds=inputs_embeds_final,
                            max_new_tokens=30,
                            return_dict_in_generate=True,
                        )
                        generated_text = self.blip_processor.batch_decode(
                            generated["sequences"], skip_special_tokens=True
                        )

                    for itr, query_idx in enumerate(top_query_ids):
                        box = Boxes(boxes_per_batch[query_idx, i2 : i2 + 1] / box_scale)
                        box.scale(scale_x=image_size[1], scale_y=image_size[0])
                        predicted_boxes.append(box.tensor[0])
                        predicted_texts.append(
                            "<bos> " + generated_text[itr].strip() + " <eos>"
                            if with_caption
                            else -1
                        )
                        predicted_scores.append(top_scores[itr])

                target_boxes, target_texts, target_ids = self._vidstg_frame_targets(
                    inputs_per_batch["instances"][i2]
                )

                if "image_id" in inputs_per_batch:
                    batch_id = inputs_per_batch["image_id"]
                elif "video_id" in inputs_per_batch:
                    batch_id = (
                        str(inputs_per_batch["video_id"]) + "_" + str(i2 + start_frame)
                    )
                else:
                    raise KeyError(
                        "batched input needs an 'image_id' or 'video_id' entry"
                    )

                output_per_frame.append(
                    {
                        "scores": torch.tensor(predicted_scores),
                        "boxes": (
                            torch.stack(predicted_boxes)
                            if predicted_boxes
                            else torch.tensor([])
                        ),
                        "caps": predicted_texts,
                        "obj_ids": list(top_query_ids),
                        "gt_ids": target_ids,
                        "target_boxes": (
                            torch.stack(target_boxes)
                            if target_boxes
                            else torch.tensor([])
                        ),
                        "target_texts": target_texts,
                        "frame": start_frame + i2,
                        "image_id": batch_id,
                    }
                )
            output_batchwise.append(output_per_frame)

        return output_batchwise, [], 0.0

    def get_embedding_loss(
        self,
        batched_inputs,
        images_original,
        outputs,
        indices,
        visualize=True,
        training=True,
    ):
        """
        Teacher-forced captioning loss for matched queries.

        For every batch element and every frame, each matched query whose
        ground-truth object has phrases gets one sampled phrase. The BLIP-2
        language model is run on [Q-Former outputs | object embedding | caption
        tokens], and next-token cross-entropy is averaged over the non-padding
        caption tokens. The per-frame means are summed.

        Args:
            batched_inputs: per-video dicts whose ``"instances"`` (one per frame)
                carry ``gt_phrases``.
            images_original: unused.
            outputs: model outputs with ``pred_masks``, ``pred_boxes`` (cxcywh)
                and the entries needed by ``get_blip_inputs``.
            indices: per batch element, ``indices[i][0]`` are query indices and
                ``indices[i][1]`` the matched ground-truth indices (presumably
                Hungarian-matcher output — TODO confirm).
            visualize: unused.
            training: unused.

        Returns:
            Scalar tensor with the summed caption loss (0.0 if nothing contributed).
        """
        pad_id = 1  # padding id also masked out by blip_qformer_process
        embedding_loss = torch.tensor(0.0)
        loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=pad_id)

        bs, nq, nf, ht, wt = outputs["pred_masks"].shape
        boxes = box_ops.box_cxcywh_to_xyxy(outputs["pred_boxes"]) * torch.tensor(
            [wt, ht, wt, ht]
        ).to(self.device)
        image_embeds, inputs_embeds_obj, query_tokens, image_attention_masks = (
            self.get_blip_inputs(batched_inputs, outputs, boxes)
        )
        tokenizer = self.blip_processor.tokenizer

        for i, inputs_per_batch in enumerate(batched_inputs):
            query_idx, gt_idx = indices[i][0], indices[i][1]
            if len(gt_idx) == 0:
                continue

            for i2 in range(self.num_frames):
                frame_i2 = i * self.num_frames + i2
                gt_phrases = inputs_per_batch["instances"][i2].gt_phrases
                # One [num_phrases, len] token tensor per matched object, or
                # None for objects without phrases (those rows are dropped).
                tokenized = [
                    (
                        torch.tensor(
                            tokenizer(
                                gt_phrases[k], padding=True, pad_to_multiple_of=50
                            )["input_ids"]
                        )
                        if gt_phrases[k] != []
                        else None
                    )
                    for k in gt_idx
                ]
                keep = [t is not None for t in tokenized]
                tgt = [t for t in tokenized if t is not None]
                if not tgt:
                    continue

                inputs_embeds_final, attention_mask, input_id = (
                    self.blip_qformer_process(
                        img_emb=image_embeds[frame_i2 : frame_i2 + 1],
                        query_tokens_i=query_tokens[frame_i2 : frame_i2 + 1].repeat(
                            len(query_idx), 1, 1
                        )[keep],
                        image_attn_mask=image_attention_masks[i][query_idx, i2][keep],
                        inputs_embeds_obj=inputs_embeds_obj[i, query_idx][
                            keep
                        ].unsqueeze(1),
                        target=tgt,
                        training=True,
                    )
                )

                logits = self.blip_model.language_model(
                    inputs_embeds=inputs_embeds_final,
                    attention_mask=attention_mask,
                ).logits
                # The caption occupies the last input_id.size(1) positions;
                # position t is scored against token t + 1.
                logits = logits[:, -input_id.size(1) :, :]
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = input_id[..., 1:].contiguous().to(logits.device)
                token_loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                # Average over real (non-padding) tokens only; skip frames with
                # none, whose empty mean would be NaN.
                valid = shift_labels.view(-1) != pad_id
                if valid.any():
                    embedding_loss = embedding_loss + token_loss[valid].mean()

        return embedding_loss

    def inference_video(
        self,
        pred_cls_all,
        pred_masks_all,
        unpadded_sizes_all,
        padded_size,
        output_sizes_all,
        lengths_all,
        num_topk=20,
    ):
        """
        Turn per-video query predictions into scored, labelled masks (VIS / BURST).

        Each (query, class) pair is scored by the class probability averaged
        over the frames where that query's mask is non-empty. The best
        ``num_topk`` pairs are kept, and their masks are upsampled and binarised.

        Args:
            pred_cls_all: per-video class logits, each [q, t, num_classes + 1]
                (the last class is "no object").
            pred_masks_all: per-video mask logits, each [q, t, h, w].
            unpadded_sizes_all: per-video (h, w) of the valid region inside
                ``padded_size``.
            padded_size: (h_pad, w_pad) the masks are first upsampled to.
            output_sizes_all: per-video (h_out, w_out) of the returned masks.
            lengths_all: per-video number of frames to keep.
            num_topk: maximum number of (query, class) pairs kept per video;
                clamped to ``q * num_classes``.

        Returns:
            List with one dict per video:
                "image_size": the output size,
                "pred_scores": list of floats, one per kept pair,
                "pred_labels": list of class indices, one per kept pair,
                "pred_masks": list of bool tensors [length, h_out, w_out].
        """
        num_classes = self.sem_seg_head.num_classes
        video_outputs_all = []
        for pred_cls, pred_masks, unpadded_size, output_size, length in zip(
            pred_cls_all,
            pred_masks_all,
            unpadded_sizes_all,
            output_sizes_all,
            lengths_all,
        ):
            out_scores, out_labels, out_masks = [], [], []
            if len(pred_cls) > 0:
                # Per-frame probabilities without the trailing "no object" class.
                frame_scores = F.softmax(pred_cls.float(), dim=-1)[:, :, :-1]
                nonempty = (pred_masks > 0).sum(-1).sum(-1) > 0  # [q, t]
                frame_scores[~nonempty] = 0.0
                # Average over frames with a non-empty mask; a query empty in
                # every frame is divided by t instead of 0.
                num_nonempty = nonempty.sum(1)
                num_nonempty[num_nonempty == 0] = pred_masks.shape[1]
                scores = frame_scores.sum(1) / num_nonempty.unsqueeze(1)  # [q, C]
                del frame_scores, nonempty, num_nonempty

                # topk raises if k exceeds the number of candidates.
                k = min(num_topk, scores.numel())
                scores_per_image, topk_indices = scores.flatten(0, 1).topk(
                    k, sorted=False
                )
                del scores
                # Flattened index = query * num_classes + class.
                labels_per_image = topk_indices % num_classes
                query_indices = torch.div(
                    topk_indices, num_classes, rounding_mode="trunc"
                )

                # Upsample only the selected queries' masks.
                selected = pred_masks[query_indices].float()  # [k, t, h, w]
                num_sel, t = selected.shape[:2]
                selected = F.interpolate(
                    selected.flatten(0, 1).unsqueeze(1),
                    size=padded_size,
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(1)
                selected = selected.reshape(
                    num_sel, t, selected.shape[-2], selected.shape[-1]
                )
                # Drop spatial padding and frames beyond ``length``, then resize.
                selected = selected[:, :length, : unpadded_size[0], : unpadded_size[1]]
                selected = F.interpolate(
                    selected, size=output_size, mode="bilinear", align_corners=False
                )
                masks = selected > 0.0
                del selected

                out_scores = scores_per_image.tolist()
                out_labels = labels_per_image.tolist()
                out_masks = list(masks.cpu())
                del scores_per_image, labels_per_image, masks

            video_outputs_all.append(
                {
                    "image_size": output_size,
                    "pred_scores": out_scores,
                    "pred_labels": out_labels,
                    "pred_masks": out_masks,
                }
            )

        return video_outputs_all
