import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import os
from src.models.backbone_croco import AsymmetricCroCo
from src.models.vit_adapter import CroCoViTAdapter
from src.models.croco.misc import transpose_to_landscape
from src.models.heads import head_factory
from src.models.gaussian_adapter import UnifiedGaussianAdapter
from transformers import Mask2FormerConfig
from src.models.mask2former import VideoMask2FormerForVideoSegmentation
from src.models.openclip.lang_module import LangModule
from src.models.mask2former.image_processing_video_mask2former import (
    VideoMask2FormerImageProcessor,
)
from einops import rearrange

from src.utils.weight_modify import checkpoint_filter_fn
from src.utils.gaussians_types import Gaussians

from src.config import ModelCfg

from src.utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


class SIU3RModel(nn.Module):

    def __init__(
        self,
        cfg: ModelCfg,
    ):
        """Build every sub-module selected by ``cfg``.

        Construction order matters: the backbone comes first because the
        center and Gaussian heads are created from it, and the language
        module comes after Mask2Former because it reads
        ``self.mask2former_config``.

        Args:
            cfg: Model configuration. ``enable_multi_res_gs_feat_lift`` and
                ``train_refer_segmentation`` switch the optional parts on.
        """
        super().__init__()
        self.cfg = cfg

        # Components that exist in every configuration, in build order.
        for build in (
            self._set_backbone,
            self._set_adapter,
            self._set_mask2former,
            self._set_center_head,
            self._set_gaussian_param_head,
            self._set_gaussian_adapter,
        ):
            build()

        # Optional multi-resolution Gaussian branch.
        if self.cfg.enable_multi_res_gs_feat_lift:
            self._set_multi_res_gaussian_param_head()
            self._set_multi_res_gaussian_adapter()

        # Optional text branch used for referring segmentation.
        if self.cfg.train_refer_segmentation:
            self._set_lang_module()

    def _set_backbone(self):
        """Create the ``AsymmetricCroCo`` backbone and apply the freeze policy.

        With ``cfg.croco.freeze == "encoder"`` the encoder blocks, encoder
        norm and patch embedding are put in eval mode and frozen.  When
        training referring segmentation every backbone parameter is frozen.
        """
        backbone = AsymmetricCroCo(
            img_size=self.cfg.image_size, **self.cfg.croco.__dict__
        )
        self.backbone = backbone
        backbone.depth_mode = ("exp", -float("inf"), float("inf"))
        backbone.conf_mode = None

        if self.cfg.croco.freeze == "encoder":
            # NOTE(review): eval() is applied once here; presumably a later
            # model.train() would flip these back -- confirm with the trainer.
            for part_name in ("enc_blocks", "enc_norm", "patch_embed"):
                part = getattr(backbone, part_name)
                part.eval()
                for weight in part.parameters():
                    weight.requires_grad = False

        if self.cfg.train_refer_segmentation:
            for weight in backbone.parameters():
                weight.requires_grad = False

    def _set_adapter(self):
        """Create the ``CroCoViTAdapter`` sized from the encoder settings.

        When training referring segmentation the adapter is switched to eval
        mode and all of its parameters are frozen.
        """
        croco_cfg = self.cfg.croco
        self.adapter = CroCoViTAdapter(
            num_block=croco_cfg.enc_depth,
            embed_dim=croco_cfg.enc_embed_dim,
            size=self.cfg.image_size,
            patchsize=croco_cfg.patch_size,
        )
        if not self.cfg.train_refer_segmentation:
            return
        self.adapter.eval()
        for weight in self.adapter.parameters():
            weight.requires_grad = False

    def _set_mask2former(self):
        """Create the Mask2Former config, model and image processor.

        The model is left in training mode.  When training referring
        segmentation every Mask2Former parameter is frozen first and only the
        language-related sub-modules are made trainable again.
        """
        config = Mask2FormerConfig(
            id2label=self.cfg.mask2former.id2label,
            num_queries=self.cfg.mask2former.num_queries,
            train_refer_segmentation=self.cfg.train_refer_segmentation,
        )
        self.mask2former_config = config
        self.mask2former = VideoMask2FormerForVideoSegmentation(config)
        self.mask2former.train()
        self.processor = VideoMask2FormerImageProcessor()

        if not self.cfg.train_refer_segmentation:
            return

        # Freeze everything, then re-enable only the language layers.
        for weight in self.mask2former.parameters():
            weight.requires_grad = False
        trainable_parts = (
            "lang_cross_attns",
            "lang_attn_norms",
            "lang_attn_norms_final",
            "lang_fc1s",
            "lang_fc2s",
        )
        for part_name in trainable_parts:
            for weight in getattr(self.mask2former, part_name).parameters():
                weight.requires_grad = True

    def _set_lang_module(self):
        """Create the language module and freeze its text encoder.

        The module's output width is taken from
        ``self.mask2former_config.hidden_dim``, so this must run after
        ``_set_mask2former``.
        """
        lang_module = LangModule(
            out_features=self.mask2former_config.hidden_dim,
            pretrained_weights_path=self.cfg.pretrained_weights_path,
        )
        self.lang_module = lang_module
        for weight in lang_module.text_encoder.parameters():
            weight.requires_grad = False

    def _set_center_head(self):
        """Create the two ``pts3d`` DPT heads (one per view) and their wrappers.

        ``head1``/``head2`` are the heads wrapped by
        ``transpose_to_landscape``; they are what ``gaussian_center`` calls.
        The raw heads are frozen when training referring segmentation.
        """
        first_head = head_factory("dpt", "pts3d", self.backbone, has_conf=False)
        self.downstream_head1 = first_head
        second_head = head_factory("dpt", "pts3d", self.backbone, has_conf=False)
        self.downstream_head2 = second_head

        if self.cfg.train_refer_segmentation:
            for head in (first_head, second_head):
                for weight in head.parameters():
                    weight.requires_grad = False

        self.head1 = transpose_to_landscape(first_head)
        self.head2 = transpose_to_landscape(second_head)

    def _set_gaussian_param_head(self):
        """Create the two per-view heads that predict raw Gaussian parameters.

        Also stores ``self.raw_gs_dim``, the number of output channels of
        each head.  Both heads are frozen when training referring
        segmentation.
        """
        sh_degree = self.cfg.gaussian_head.sh_degree
        # sh + 3 scale + 4 rotation + 1 opacity
        self.raw_gs_dim = (sh_degree + 1) ** 2 * 3 + 7 + 1

        head_kwargs = dict(has_conf=False, out_nchan=self.raw_gs_dim)
        self.gaussian_param_head1 = head_factory(
            "dpt_gs", "gs_params", self.backbone, **head_kwargs
        )
        self.gaussian_param_head2 = head_factory(
            "dpt_gs", "gs_params", self.backbone, **head_kwargs
        )

        if self.cfg.train_refer_segmentation:
            for head in (self.gaussian_param_head1, self.gaussian_param_head2):
                for weight in head.parameters():
                    weight.requires_grad = False

    def _set_multi_res_gaussian_param_head(self):
        """Create the two per-view multi-resolution Gaussian parameter heads.

        Also stores ``self.raw_gs_dim_wo_xyz_sh``, the number of channels each
        head outputs.  Both heads are frozen when training referring
        segmentation.
        """
        self.raw_gs_dim_wo_xyz_sh = 8  # 3 scale + 4 rotation + 1 opacity

        head_kwargs = dict(has_conf=False, out_nchan=self.raw_gs_dim_wo_xyz_sh)
        self.multi_res_gaussian_param_head = head_factory(
            "multi_res_dpt_gs", "gs_params", self.backbone, **head_kwargs
        )
        self.multi_res_gaussian_param_head2 = head_factory(
            "multi_res_dpt_gs", "gs_params", self.backbone, **head_kwargs
        )

        if self.cfg.train_refer_segmentation:
            for head in (
                self.multi_res_gaussian_param_head,
                self.multi_res_gaussian_param_head2,
            ):
                for weight in head.parameters():
                    weight.requires_grad = False

    def _set_gaussian_adapter(self):
        """Create the adapter applied to the full-resolution Gaussian outputs.

        It is configured with the scale bounds and SH degree from
        ``cfg.gaussian_head`` and frozen when training referring
        segmentation.
        """
        head_cfg = self.cfg.gaussian_head
        adapter = UnifiedGaussianAdapter(
            head_cfg.gaussian_scale_min,
            head_cfg.gaussian_scale_max,
            sh_degree=head_cfg.sh_degree,
        )
        self.gaussian_adapter = adapter
        if self.cfg.train_refer_segmentation:
            for weight in adapter.parameters():
                weight.requires_grad = False

    def _set_multi_res_gaussian_adapter(self):
        """Create the adapter applied to the multi-resolution Gaussian outputs.

        Uses the same scale bounds as the main adapter but is built with
        ``sh_degree=-1``.  It is frozen when training referring segmentation.
        """
        head_cfg = self.cfg.gaussian_head
        adapter = UnifiedGaussianAdapter(
            head_cfg.gaussian_scale_min,
            head_cfg.gaussian_scale_max,
            sh_degree=-1,
        )
        self.multi_res_gaussian_adapter = adapter
        if self.cfg.train_refer_segmentation:
            for weight in adapter.parameters():
                weight.requires_grad = False

    def load_recon_ckpt(self):
        """Load the MASt3R reconstruction checkpoint into this model.

        The file is read from ``cfg.pretrained_weights_path``, its ``"model"``
        entry is passed through ``checkpoint_filter_fn`` and the result is
        loaded non-strictly, so missing or unexpected keys do not raise.
        """
        recon_ckpt_path = os.path.join(
            self.cfg.pretrained_weights_path,
            "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth",
        )
        raw_ckpt = torch.load(recon_ckpt_path, map_location="cpu", weights_only=False)
        filtered_state = checkpoint_filter_fn(raw_ckpt["model"], self)
        self.load_state_dict(filtered_state, strict=False)
        log.info(f"loaded recon ckpt from {recon_ckpt_path}")

    def load_seg_ckpt(self, seg_task="panoptic"):
        if seg_task == "instance":
            seg_ckpt_path = os.path.join(
                self.cfg.pretrained_weights_path,
                "instance_coco_pretrain_vitadapter_maskdecoder_epoch48.ckpt",
            )
        elif seg_task == "panoptic":
            seg_ckpt_path = os.path.join(
                self.cfg.pretrained_weights_path,
                "panoptic_coco_pretrain_vitadapter_maskdecoder_epoch60.ckpt",
            )
        seg_ckpt = torch.load(seg_ckpt_path, map_location="cpu", weights_only=False)
        seg_ckpt = seg_ckpt["state_dict"]
        for key in list(seg_ckpt.keys()):
            if "class_predictor" in key or "criterion" in key or "backbone" in key:
                seg_ckpt.pop(key)
            else:
                seg_ckpt[key[len("model.") :]] = seg_ckpt.pop(key)
        seg_missing_keys, seg_unexpected_keys = self.load_state_dict(
            seg_ckpt, strict=False
        )
        log.info(f"loaded ckpt from {seg_ckpt_path}")
        # log.info(f"missing_keys: {intersaction_missing_keys}")
        # log.info(f"unexpected_keys: {union_unexpected_keys}")

    def load_refer_seg_base_ckpt(self, path):
        """Load the base checkpoint used to start referring-segmentation training.

        The checkpoint's ``"state_dict"`` entry is loaded non-strictly after
        stripping a leading ``model.`` prefix from its keys.

        Args:
            path: Location of the checkpoint file.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # point this at trusted checkpoint files.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
        ckpt = ckpt["state_dict"]
        # Strip the prefix only where it exists; the previous blind slice cut
        # six characters off every key, mangling any key without "model.".
        ckpt = {key.removeprefix("model."): value for key, value in ckpt.items()}
        self.load_state_dict(ckpt, strict=False)
        log.info(f"loaded refer seg base ckpt from {path}")

    def delete_recon_part(self):
        del self.backbone.dec_blocks
        del self.backbone.dec_blocks2
        del self.backbone.decoder_embed
        del self.downstream_head1
        del self.downstream_head2
        del self.head1
        del self.head2
        del self.gaussian_param_head1
        del self.gaussian_param_head2
        del self.gaussian_adapter

    def gaussian_center(self, dec1, dec2, shape1, shape2):
        res1 = self.head1([tok.float() for tok in dec1], shape1)
        res2 = self.head2([tok.float() for tok in dec2], shape2)
        return res1, res2

    def gaussian_param(self, dec1, dec2, res1, res2, view1, view2, shape1, shape2):
        """Predict raw per-pixel Gaussian parameters for both views.

        Each head receives the float-cast decoder tokens, ``None``, the first
        three channels of the view image and the first entry of the shape
        tensor as a Python list.  The ``b d h w`` head output is flattened to
        ``b (h w) d``.  ``res1`` and ``res2`` are accepted but not used.

        Returns:
            ``(GS_res1, GS_res2)``, one flattened tensor per view.
        """
        flattened = []
        for head_name, tokens, view, shape in (
            ("gaussian_param_head1", dec1, view1, shape1),
            ("gaussian_param_head2", dec2, view2, shape2),
        ):
            raw_params = getattr(self, head_name)(
                [tok.float() for tok in tokens],
                None,
                view["img"][:, :3],
                shape[0].cpu().tolist(),
            )
            flattened.append(rearrange(raw_params, "b d h w -> b (h w) d"))
        GS_res1, GS_res2 = flattened
        return GS_res1, GS_res2

    def multi_res_gaussian_param(
        self, dec1, dec2, res1, res2, view1, view2, shape1, shape2
    ):
        """Predict raw Gaussian parameters at several resolutions for both views.

        Works like ``gaussian_param`` but each head returns a sequence of
        ``b d h w`` maps; every map is flattened to ``b (h w) d``.  ``res1``
        and ``res2`` are accepted but not used.

        Returns:
            ``(levels_view1, levels_view2)``, each a list of flattened tensors.
        """
        per_view_levels = []
        for head_name, tokens, view, shape in (
            ("multi_res_gaussian_param_head", dec1, view1, shape1),
            ("multi_res_gaussian_param_head2", dec2, view2, shape2),
        ):
            level_maps = getattr(self, head_name)(
                [tok.float() for tok in tokens],
                None,
                view["img"][:, :3],
                shape[0].cpu().tolist(),
            )
            per_view_levels.append(
                [rearrange(level, "b d h w -> b (h w) d") for level in level_maps]
            )
        multi_res_GS_res1, multi_res_GS_res2 = per_view_levels
        return multi_res_GS_res1, multi_res_GS_res2

    def multi_res_gaussian_center(self, pts3d):
        """Downsample a point map to 1/4, 1/8, 1/16 and 1/32 resolution.

        The input is detached, so no gradient flows back through the
        returned centers.  Nearest-neighbour interpolation is used.

        Args:
            pts3d: Point map laid out as ``b h w d``.

        Returns:
            Four tensors laid out as ``b h w d``, from finest to coarsest.
        """
        # F.interpolate works on channels-first input.
        channels_first = rearrange(pts3d.detach(), "b h w d -> b d h w")
        pyramid = []
        for factor in (0.25, 0.125, 0.0625, 0.03125):  # 1/4, 1/8, 1/16, 1/32
            reduced = F.interpolate(channels_first, scale_factor=factor, mode="nearest")
            pyramid.append(rearrange(reduced, "b d h w -> b h w d"))
        return pyramid

    def multi_res_gaussian(
        self, dec1, dec2, shape1, shape2, view1, view2, pts3d1, pts3d2
    ):
        """Build one set of Gaussians per resolution level.

        Centers come from the downsampled point maps of both views; the
        remaining parameters come from the multi-resolution heads.  For each
        level the two views are stacked on a new axis 1 and passed to
        ``multi_res_gaussian_adapter``.

        Returns:
            A list with one adapter output per resolution level.
        """
        params_view1, params_view2 = self.multi_res_gaussian_param(
            dec1, dec2, None, None, view1, view2, shape1, shape2
        )

        centers_view1 = [
            rearrange(level, "b h w d -> b (h w) d")
            for level in self.multi_res_gaussian_center(pts3d1)
        ]
        centers_view2 = [
            rearrange(level, "b h w d -> b (h w) d")
            for level in self.multi_res_gaussian_center(pts3d2)
        ]
        stacked_centers = [
            torch.stack((centers1, centers_view2[idx]), dim=1)
            for idx, centers1 in enumerate(centers_view1)
        ]

        multi_res_gaussians = []
        for idx, centers in enumerate(stacked_centers):
            stacked_params = torch.stack(
                [params_view1[idx], params_view2[idx]], dim=1
            )  # [B, N, HW, D]
            multi_res_gaussians.append(
                self.multi_res_gaussian_adapter.forward(
                    centers,  # [B, N, HW, D]
                    stacked_params,  # [B, N, HW, D]
                )
            )
        return multi_res_gaussians

    def gaussian(self, dec1, dec2, shape1, shape2, view1, view2):
        """Predict full-resolution Gaussians for both context views.

        The center heads supply the ``pts3d`` maps, the parameter heads
        supply the raw Gaussian parameters; both are stacked over the two
        views on axis 1 and passed to ``gaussian_adapter``.

        Returns:
            ``(gaussians, pts3d1hw, pts3d2hw)`` where the last two are the
            un-flattened ``b h w d`` point maps of view 1 and view 2.
        """
        center_out1, center_out2 = self.gaussian_center(dec1, dec2, shape1, shape2)
        raw_params1, raw_params2 = self.gaussian_param(
            dec1, dec2, center_out1, center_out2, view1, view2, shape1, shape2
        )

        pts3d1hw = center_out1["pts3d"]
        pts3d2hw = center_out2["pts3d"]
        flat_pts = tuple(
            rearrange(pts, "b h w d -> b (h w) d") for pts in (pts3d1hw, pts3d2hw)
        )

        stacked_pts = torch.stack(flat_pts, dim=1)  # [B, N, HW, D]
        stacked_params = torch.stack([raw_params1, raw_params2], dim=1)  # [B, N, HW, D]
        gaussians = self.gaussian_adapter.forward(stacked_pts, stacked_params)

        return gaussians, pts3d1hw, pts3d2hw

    def post_process_gaussians(
        self,
        B,
        H,
        W,
        gaussians: Gaussians,
        context_seg_output,
        multi_res_gaussians=None,
        multi_scale_feat1=None,
        multi_scale_feat2=None,
    ):
        """Attach segmentation labels to ``gaussians`` and merge the view axis.

        The Mask2Former output is post-processed according to
        ``cfg.seg_task``; every predicted segment is written into
        ``[B, 2, H, W]`` label maps (semantic label is ``label_id + 1`` so that
        0 stays background), which are then flattened to ``b (n h w)``.  The
        per-view Gaussian attributes are flattened the same way, and, if
        given, each multi-resolution level additionally receives the stacked
        adapter features as ``feat``.

        Args:
            B: Batch size.
            H: Target mask height.
            W: Target mask width.
            gaussians: Full-resolution Gaussians; modified in place.
            context_seg_output: Raw Mask2Former output for the context views.
            multi_res_gaussians: Optional list of per-level Gaussians;
                modified in place.
            multi_scale_feat1: Per-level features of view 1; indexed only when
                ``multi_res_gaussians`` is given.
            multi_scale_feat2: Per-level features of view 2; same condition.

        Returns:
            ``(gaussians, multi_res_gaussians or None, context_seg_output,
            context_seg_masks, context_seg_infos, query_scores or None)``.
            The multi-res entry is ``None`` unless
            ``cfg.enable_multi_res_gs_feat_lift`` is set; the query scores are
            ``None`` unless ``cfg.enable_query_class_logit_lift`` is set.

        Raises:
            ValueError: If ``cfg.seg_task`` is neither ``"instance"`` nor
                ``"panoptic"``.
        """
        seg_task = self.cfg.seg_task
        if seg_task == "instance":
            context_seg_results = self.processor.post_process_instance_segmentation(
                outputs=context_seg_output,
                threshold=self.cfg.mask2former.seg_threshold,
                target_sizes=[(H, W)] * B,
            )
        elif seg_task == "panoptic":
            context_seg_results = self.processor.post_process_panoptic_segmentation(
                outputs=context_seg_output,
                threshold=self.cfg.mask2former.seg_threshold,
                target_sizes=[(H, W)] * B,
                label_ids_to_fuse=set(self.cfg.mask2former.label_ids_to_fuse),
            )
        else:
            # Without this branch the results stayed unbound and the method
            # crashed a few lines later with an UnboundLocalError.
            raise ValueError(f"unsupported seg_task: {seg_task!r}")

        lift_query_logits = self.cfg.enable_query_class_logit_lift
        context_seg_masks = []
        context_seg_infos = []
        context_seg_query_class_logits = []
        context_seg_query_scores = []
        for result in context_seg_results:
            context_seg_masks.append(result["segmentation"])
            context_seg_infos.append(result["segments_info"])
            if lift_query_logits:
                # [n, q, c+1, h, w] -> one (q, c+1) row per pixel
                context_seg_query_class_logits.append(
                    rearrange(
                        result["query_class_logits"], "n q c h w -> (n h w) q c "
                    )
                )
                context_seg_query_scores.append(result["query_scores"])

        # Labels live on the same device as the Gaussians they annotate.
        device = gaussians.means.device
        semantic_labels = torch.zeros(B, 2, H, W, dtype=torch.int32, device=device)
        instance_labels = torch.zeros(B, 2, H, W, dtype=torch.int32, device=device)
        if lift_query_logits:
            gaussians.seg_query_class_logits = context_seg_query_class_logits

        for b, (info, masks) in enumerate(zip(context_seg_infos, context_seg_masks)):
            for seg in info:  # a sample without segments simply skips this loop
                mask = masks == seg["id"]
                # + 1 because 0 is reserved for background
                semantic_labels[b, mask] = seg["label_id"] + 1
                instance_labels[b, mask] = seg["id"]
        gaussians.semantic_labels = rearrange(semantic_labels, "b n h w -> b (n h w)")
        gaussians.instance_labels = rearrange(instance_labels, "b n h w -> b (n h w)")

        # Merge the view axis (n) into the per-view Gaussian axis (r).
        shared_patterns = (
            ("means", "b n r d -> b (n r) d"),
            ("covariances", "b n r i j -> b (n r) i j"),
            ("scales", "b n r d -> b (n r) d"),
            ("rotations", "b n r d -> b (n r) d"),
            ("opacities", "b n r -> b (n r)"),
        )
        for attr, pattern in shared_patterns:
            setattr(gaussians, attr, rearrange(getattr(gaussians, attr), pattern))
        gaussians.harmonics = rearrange(
            gaussians.harmonics, "b n r sh d -> b (n r) sh d"
        )

        if multi_res_gaussians is not None:
            for level_idx, level in enumerate(multi_res_gaussians):
                level.feat = torch.stack(
                    [multi_scale_feat1[level_idx], multi_scale_feat2[level_idx]],
                    dim=1,
                )
                for attr, pattern in shared_patterns:
                    setattr(level, attr, rearrange(getattr(level, attr), pattern))
                level.feat = rearrange(level.feat, "b n c h w -> b (n h w) c")

        return (
            gaussians,
            multi_res_gaussians if self.cfg.enable_multi_res_gs_feat_lift else None,
            context_seg_output,
            context_seg_masks,
            context_seg_infos,
            context_seg_query_scores if lift_query_logits else None,
        )

    def forward(
        self,
        context_views_images,
        context_views_intrinsics,
        mask_labels=None,
        class_labels=None,
    ):
        """Predict Gaussians and segmentation for a pair of context views.

        The backbone output feeds the Gaussian heads; the adapter features of
        both views are stacked per level and given to Mask2Former; finally
        ``post_process_gaussians`` writes the labels onto the Gaussians.

        Args:
            context_views_images: 5-D tensor; dim 0 is the batch, dim 1 the
                view index (views 0 and 1 are used) and the last two dims are
                height and width.
            context_views_intrinsics: Passed to the backbone unchanged.
            mask_labels: Optional labels forwarded to Mask2Former.
            class_labels: Optional labels forwarded to Mask2Former.

        Returns:
            Depends on the configuration:
            ``(gaussians, multi_res_gaussians)`` with multi-res feature lift;
            otherwise ``(gaussians, seg_output, seg_masks, seg_infos,
            query_scores)`` with query-class-logit lift; otherwise
            ``(gaussians, seg_output, seg_masks, seg_infos)``.
        """
        B, _, _, H, W = context_views_images.shape
        use_multi_res = self.cfg.enable_multi_res_gs_feat_lift

        backbone_out = self.backbone(
            {
                "image": context_views_images,
                "intrinsics": context_views_intrinsics,
                "near": 0.1,
                "far": 100,
            },
            return_views=True,
        )
        # The first two entries (final encoder features) are not needed here.
        (
            _,
            _,
            all_feat1,
            all_feat2,
            dec1,
            dec2,
            shape1,
            shape2,
            view1,
            view2,
        ) = backbone_out

        gaussians, pts3d1, pts3d2 = self.gaussian(
            dec1, dec2, shape1, shape2, view1, view2
        )

        multi_res_gaussians = None
        if use_multi_res:
            multi_res_gaussians = self.multi_res_gaussian(
                dec1, dec2, shape1, shape2, view1, view2, pts3d1, pts3d2
            )

        multi_scale_feat1 = self.adapter(context_views_images[:, 0], all_feat1)
        multi_scale_feat2 = self.adapter(context_views_images[:, 1], all_feat2)
        # Stack the two views on a new axis 1 for every feature level.
        multi_scale_feat = [
            torch.stack([level1, level2], dim=1)
            for level1, level2 in zip(multi_scale_feat1, multi_scale_feat2)
        ]

        context_seg_output = self.mask2former(
            multi_scale_feat=multi_scale_feat,
            mask_labels=mask_labels,
            class_labels=class_labels,
        )
        (
            gaussians,
            multi_res_gaussians,
            context_seg_output,
            context_seg_masks,
            context_seg_infos,
            context_seg_query_scores,
        ) = self.post_process_gaussians(
            B,
            H,
            W,
            gaussians,
            context_seg_output,
            multi_res_gaussians=multi_res_gaussians,
            multi_scale_feat1=multi_scale_feat1,
            multi_scale_feat2=multi_scale_feat2,
        )

        if use_multi_res:
            return gaussians, multi_res_gaussians
        if self.cfg.enable_query_class_logit_lift:
            return (
                gaussians,
                context_seg_output,
                context_seg_masks,
                context_seg_infos,
                context_seg_query_scores,
            )
        return (
            gaussians,
            context_seg_output,
            context_seg_masks,
            context_seg_infos,
        )

    def _post_process_segmentation(self, seg_output, batch_size, height, width):
        """Convert raw Mask2Former output into per-sample masks and segment infos.

        Raises:
            ValueError: If ``cfg.seg_task`` is neither ``"instance"`` nor
                ``"panoptic"``.
        """
        seg_task = self.cfg.seg_task
        if seg_task == "instance":
            results = self.processor.post_process_instance_segmentation(
                outputs=seg_output,
                threshold=self.cfg.mask2former.seg_threshold,
                target_sizes=[(height, width)] * batch_size,
            )
        elif seg_task == "panoptic":
            results = self.processor.post_process_panoptic_segmentation(
                outputs=seg_output,
                threshold=self.cfg.mask2former.seg_threshold,
                target_sizes=[(height, width)] * batch_size,
                label_ids_to_fuse=set(self.cfg.mask2former.label_ids_to_fuse),
            )
        else:
            raise ValueError(f"unsupported seg_task: {seg_task!r}")

        seg_masks = []
        seg_infos = []
        for result in results:
            seg_masks.append(result["segmentation"])
            seg_infos.append(result["segments_info"])
        return seg_masks, seg_infos

    def lift_feature_forward(
        self,
        gaussians: Gaussians,
        render_multi_res_feat: list[Tensor],
        context_mask_labels: Tensor,
        context_class_labels: Tensor,
        target_mask_labels: Tensor,
        target_class_labels: Tensor,
        target_sizes: tuple[int, int],
    ):
        """Segment rendered features for the context views and for all views.

        Mask2Former runs twice: once on the first and last view of every
        rendered feature level (treated as the two context views) and once on
        the full rendered features.  The context result is written onto
        ``gaussians`` as flattened semantic and instance labels.

        Args:
            gaussians: Gaussians to annotate; modified in place.
            render_multi_res_feat: Rendered feature maps, one per level, with
                the view index on dim 1.
            context_mask_labels: Mask labels for the context pass.
            context_class_labels: Class labels for the context pass.
            target_mask_labels: Mask labels for the target pass.
            target_class_labels: Class labels for the target pass.
            target_sizes: ``(H, W)`` used for post-processing both passes.

        Returns:
            ``(gaussians, context_seg_output, context_seg_masks,
            context_seg_infos, target_seg_output, target_seg_masks,
            target_seg_infos)``.

        Raises:
            ValueError: If ``cfg.seg_task`` is neither ``"instance"`` nor
                ``"panoptic"``.
        """
        B = gaussians.means.shape[0]
        H, W = target_sizes

        # First and last view of each level stand in for the two context views.
        context_feat = [
            torch.stack([f[:, 0, ...], f[:, -1, ...]], dim=1).contiguous()
            for f in render_multi_res_feat
        ]
        context_seg_output = self.mask2former(
            multi_scale_feat=context_feat,
            mask_labels=context_mask_labels,
            class_labels=context_class_labels,
        )
        context_seg_masks, context_seg_infos = self._post_process_segmentation(
            context_seg_output, B, H, W
        )

        # Allocate the labels on the Gaussians' device.  They used to be
        # created on the CPU, unlike post_process_gaussians, which breaks the
        # masked assignment below when the masks are on an accelerator.
        device = gaussians.means.device
        semantic_labels = torch.zeros(B, 2, H, W, dtype=torch.int32, device=device)
        instance_labels = torch.zeros(B, 2, H, W, dtype=torch.int32, device=device)
        for b, (info, masks) in enumerate(zip(context_seg_infos, context_seg_masks)):
            for seg in info:  # a sample without segments simply skips this loop
                mask = masks == seg["id"]
                # + 1 because 0 is reserved for background
                semantic_labels[b, mask] = seg["label_id"] + 1
                instance_labels[b, mask] = seg["id"]
        gaussians.semantic_labels = rearrange(semantic_labels, "b n h w -> b (n h w)")
        gaussians.instance_labels = rearrange(instance_labels, "b n h w -> b (n h w)")

        target_seg_output = self.mask2former(
            multi_scale_feat=render_multi_res_feat,
            mask_labels=target_mask_labels,
            class_labels=target_class_labels,
        )
        target_seg_masks, target_seg_infos = self._post_process_segmentation(
            target_seg_output, B, H, W
        )

        return (
            gaussians,
            context_seg_output,
            context_seg_masks,
            context_seg_infos,
            target_seg_output,
            target_seg_masks,
            target_seg_infos,
        )

    def train_refer_seg_forward(
        self,
        context_views_images,
        context_views_intrinsics,
        context_text_token,
        context_mask_labels,
        context_class_labels,
    ):
        """Run the referring-segmentation pass on a pair of context views.

        Text tokens are embedded by the language module, image features come
        from the backbone and adapter, and Mask2Former consumes both.  No
        Gaussians are produced here.

        Args:
            context_views_images: 5-D tensor; dim 0 is the batch, dim 1 the
                view index (views 0 and 1 are used) and the last two dims are
                height and width.
            context_views_intrinsics: Passed to the backbone unchanged.
            context_text_token: Input of the language module.
            context_mask_labels: Mask labels forwarded to Mask2Former.
            context_class_labels: Class labels forwarded to Mask2Former.

        Returns:
            ``(context_seg_output, context_seg_masks, context_seg_infos)``.

        Raises:
            ValueError: If ``cfg.seg_task`` is neither ``"instance"`` nor
                ``"panoptic"``.
        """
        word_embeddings = self.lang_module(context_text_token)  # B*[n_sample, out_dim]
        B, _, _, H, W = context_views_images.shape
        croco_input = {
            "image": context_views_images,
            "intrinsics": context_views_intrinsics,
            "near": 0.1,
            "far": 100,
        }
        # Only the per-block encoder features are needed for segmentation.
        (
            _,
            _,
            all_feat1,
            all_feat2,
            _,
            _,
            _,
            _,
            _,
            _,
        ) = self.backbone(croco_input, return_views=True)

        multi_scale_feat1 = self.adapter(context_views_images[:, 0], all_feat1)
        multi_scale_feat2 = self.adapter(context_views_images[:, 1], all_feat2)
        # Stack the two views on a new axis 1 for every feature level.
        multi_scale_feat = [
            torch.stack([level1, level2], dim=1)
            for level1, level2 in zip(multi_scale_feat1, multi_scale_feat2)
        ]

        context_seg_output = self.mask2former(
            multi_scale_feat=multi_scale_feat,
            word_embeddings=word_embeddings,
            mask_labels=context_mask_labels,
            class_labels=context_class_labels,
        )

        seg_task = self.cfg.seg_task
        if seg_task == "instance":
            context_seg_results = self.processor.post_process_instance_segmentation(
                outputs=context_seg_output,
                threshold=self.cfg.mask2former.seg_threshold,
                target_sizes=[(H, W)] * B,
                word_embeddings=context_seg_output.word_embeddings,
            )
        elif seg_task == "panoptic":
            context_seg_results = self.processor.post_process_panoptic_segmentation(
                outputs=context_seg_output,
                threshold=self.cfg.mask2former.seg_threshold,
                target_sizes=[(H, W)] * B,
                label_ids_to_fuse=set(self.cfg.mask2former.label_ids_to_fuse),
                word_embeddings=context_seg_output.word_embeddings,
            )
        else:
            # Without this branch the results stayed unbound and the loop
            # below crashed with an UnboundLocalError.
            raise ValueError(f"unsupported seg_task: {seg_task!r}")

        context_seg_masks = []
        context_seg_infos = []
        for result in context_seg_results:
            context_seg_masks.append(result["segmentation"])
            context_seg_infos.append(result["segments_info"])
        return context_seg_output, context_seg_masks, context_seg_infos
