# import sys
# import os

# sys.path.append(os.path.dirname(os.path.dirname(__file__)))
# from collections import OrderedDict
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.checkpoint import checkpoint
# from copy import deepcopy
# from functools import partial
# from typing import Optional, Tuple, List, Any
# from dataclasses import dataclass
# from transformers import PretrainedConfig
# from transformers import PreTrainedModel
# from transformers.modeling_outputs import BaseModelOutput
# from transformers.file_utils import ModelOutput
# import time
# from dust3r.utils.misc import (
#     fill_default_args,
#     freeze_all_params,
#     is_symmetrized,
#     interleave,
#     transpose_to_landscape,
# )
# from dust3r.heads import head_factory
# from dust3r.utils.camera import PoseEncoder
# from dust3r.patch_embed import get_patch_embed
# import dust3r.utils.path_to_croco  # noqa: F401
# from models.croco import CroCoNet, CrocoConfig  # noqa
# from dust3r.blocks import (
#     Block,
#     DecoderBlock,
#     Mlp,
#     Attention,
#     CrossAttention,
#     DropPath,
#     CustomDecoderBlock,
# )  # noqa

# # NOTE(review): every line of this chunk is commented out, so none of the
# # definitions here are importable or executable as the file stands.

# # Not referenced by the code shown here (the config defaults spell out
# # float("inf") directly).
# inf = float("inf")
# # NOTE(review): this import sits apart from the import block above; move it
# # up with the others if the module is re-enabled.
# from accelerate.logging import get_logger

# # Module-level logger; ARCroco3DStereo.load_state_dict uses it to report
# # checkpoint keys it skips.
# printer = get_logger(__name__, log_level="DEBUG")


# @dataclass
# class ARCroco3DStereoOutput(ModelOutput):
#     """
#     Custom output class for ARCroco3DStereo.
#
#     Attributes:
#         ress: per-view outputs of the downstream head, one entry per input
#             view in input order (as collected by ``_forward_impl``).
#         views: the input view list, returned unchanged.
#     """

#     ress: Optional[List[Any]] = None
#     views: Optional[List[Any]] = None


# def strip_module(state_dict):
#     """
#     Removes the 'module.' prefix from the keys of a state_dict.
#     Keys without the prefix are kept unchanged and key order is preserved.
#     Args:
#         state_dict (dict): The original state_dict with possible 'module.' prefixes.
#     Returns:
#         OrderedDict: A new state_dict with 'module.' prefixes removed.
#     """
#     new_state_dict = OrderedDict()
#     for k, v in state_dict.items():
#         # 7 == len("module."); that prefix is typically added by
#         # DataParallel / DistributedDataParallel wrappers.
#         name = k[7:] if k.startswith("module.") else k
#         new_state_dict[name] = v
#     return new_state_dict


# def load_model(model_path, device, verbose=True):
#     """
#     Instantiate a model from a checkpoint file and load its weights.
#
#     The checkpoint must contain "args" (whose ``.model`` attribute is a
#     Python constructor expression string) and "model" (the state dict).
#     The expression is patched to use PatchEmbedDust3R and
#     landscape_only=False, evaluated, and the weights are loaded with
#     strict=False.
#
#     Args:
#         model_path: path to a file readable by torch.load.
#         device: device the model is moved to before returning.
#         verbose: print progress and the load_state_dict result.
#     Returns:
#         The instantiated model on ``device``.
#     """
#     if verbose:
#         print("... loading model from", model_path)
#     # NOTE(review): torch.load unpickles the file and eval() below executes
#     # a string taken from it; only use with trusted checkpoints.
#     ckpt = torch.load(model_path, map_location="cpu")
#     args = ckpt["args"].model.replace(
#         "ManyAR_PatchEmbed", "PatchEmbedDust3R"
#     )  # ManyAR only for aspect ratio not consistent
#     if "landscape_only" not in args:
#         # Assumes the expression ends with "))": those two characters are
#         # replaced to inject the extra keyword argument.
#         args = args[:-2] + ", landscape_only=False))"
#     else:
#         # Spaces are stripped so "landscape_only = True" also matches.
#         args = args.replace(" ", "").replace(
#             "landscape_only=True", "landscape_only=False"
#         )
#     assert "landscape_only=False" in args
#     if verbose:
#         print(f"instantiating : {args}")
#     net = eval(args)
#     s = net.load_state_dict(ckpt["model"], strict=False)
#     if verbose:
#         print(s)
#     return net.to(device)


# class ARCroco3DStereoConfig(PretrainedConfig):
#     """
#     Configuration for ARCroco3DStereo.
#
#     Args:
#         output_mode, head_type: forwarded to head_factory.
#         depth_mode, conf_mode, pose_mode: stored on the model; a truthy
#             conf_mode enables the head's confidence output.
#         freeze: key into ARCroco3DStereo.set_freeze ("none", "mask",
#             "encoder", "encoder_and_head", "encoder_and_decoder", "decoder").
#         landscape_only: ``activate`` flag for transpose_to_landscape.
#         patch_embed_cls: class name passed to get_patch_embed.
#         ray_enc_depth: number of Blocks in the ray-map encoder.
#         state_size: number of learned state (register) tokens.
#         local_mem_size: number of slots in the pose retriever's LocalMemory.
#         state_pe: positions for state tokens: "1d", "2d" or "none".
#         state_dec_num_heads: attention heads in the state decoder blocks.
#         depth_head, rgb_head, pose_conf_head, pose_head: head feature flags;
#             pose_head also enables the pose token and pose retriever.
#         fuse_layers: stored only; not referenced by the code in this file
#             chunk.
#         **croco_kwargs: every other keyword; forwarded to CrocoConfig by
#             ARCroco3DStereo.__init__.
#     """
#     model_type = "arcroco_3d_stereo"

#     def __init__(
#         self,
#         output_mode="pts3d",
#         head_type="linear",  # or dpt
#         depth_mode=("exp", -float("inf"), float("inf")),
#         conf_mode=("exp", 1, float("inf")),
#         pose_mode=("exp", -float("inf"), float("inf")),
#         freeze="none",
#         landscape_only=True,
#         patch_embed_cls="PatchEmbedDust3R",
#         ray_enc_depth=2,
#         state_size=324,
#         local_mem_size=256,
#         state_pe="2d",
#         state_dec_num_heads=16,
#         depth_head=False,
#         rgb_head=False,
#         pose_conf_head=False,
#         pose_head=False,
#         fuse_layers=4,
#         **croco_kwargs,
#     ):
#         # NOTE(review): nothing is forwarded to PretrainedConfig, so standard
#         # PretrainedConfig options passed here end up in croco_kwargs.
#         super().__init__()
#         self.output_mode = output_mode
#         self.head_type = head_type
#         self.depth_mode = depth_mode
#         self.conf_mode = conf_mode
#         self.pose_mode = pose_mode
#         self.freeze = freeze
#         self.landscape_only = landscape_only
#         self.patch_embed_cls = patch_embed_cls
#         self.ray_enc_depth = ray_enc_depth
#         self.state_size = state_size
#         self.state_pe = state_pe
#         self.state_dec_num_heads = state_dec_num_heads
#         self.local_mem_size = local_mem_size
#         self.depth_head = depth_head
#         self.rgb_head = rgb_head
#         self.pose_conf_head = pose_conf_head
#         self.pose_head = pose_head
#         self.croco_kwargs = croco_kwargs
#         self.fuse_layers = fuse_layers
#         # self.depth_guidance = depth_guidance


# class LocalMemory(nn.Module):
#     """
#     Learned slot memory that is written to and read from with DecoderBlock
#     stacks.
#
#     The memory holds ``size`` slots of width ``2 * v_dim`` (a key half and a
#     value half). ``update_mem`` writes a (key, value) pair into it;
#     ``inquire`` reads the value half for a key query.
#
#     Args:
#         size: number of memory slots.
#         k_dim: width of incoming key/query features (projected to v_dim).
#         v_dim: width of value features.
#         num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path,
#         act_layer, norm_layer, norm_mem, rope: forwarded to every DecoderBlock.
#         depth: number of DecoderBlocks in each of the write and read stacks.
#     """
#     def __init__(
#         self,
#         size,
#         k_dim,
#         v_dim,
#         num_heads,
#         depth=2,
#         mlp_ratio=4.0,
#         qkv_bias=False,
#         drop=0.0,
#         attn_drop=0.0,
#         drop_path=0.0,
#         act_layer=nn.GELU,
#         norm_layer=nn.LayerNorm,
#         norm_mem=True,
#         rope=None,
#     ) -> None:
#         super().__init__()
#         self.v_dim = v_dim
#         # Shared key projection for both writing and reading.
#         self.proj_q = nn.Linear(k_dim, v_dim)
#         # Placeholder for the value half of a read query.
#         self.masked_token = nn.Parameter(
#             torch.randn(1, 1, v_dim) * 0.2, requires_grad=True
#         )
#         # Initial memory, shape (1, size, 2 * v_dim); callers expand it per batch.
#         self.mem = nn.Parameter(
#             torch.randn(1, size, 2 * v_dim) * 0.2, requires_grad=True
#         )
#         self.write_blocks = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     2 * v_dim,
#                     num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=qkv_bias,
#                     norm_layer=norm_layer,
#                     attn_drop=attn_drop,
#                     drop=drop,
#                     drop_path=drop_path,
#                     act_layer=act_layer,
#                     norm_mem=norm_mem,
#                     rope=rope,
#                 )
#                 for _ in range(depth)
#             ]
#         )
#         self.read_blocks = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     2 * v_dim,
#                     num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=qkv_bias,
#                     norm_layer=norm_layer,
#                     attn_drop=attn_drop,
#                     drop=drop,
#                     drop_path=drop_path,
#                     act_layer=act_layer,
#                     norm_mem=norm_mem,
#                     rope=rope,
#                 )
#                 for _ in range(depth)
#             ]
#         )

#     def update_mem(self, mem, feat_k, feat_v):
#         """
#         Write one (key, value) feature pair into memory.
#
#         mem:    [B, size, 2 * v_dim] current memory (key and value halves).
#         feat_k: [B, 1, k_dim] key feature; projected to v_dim by proj_q.
#         feat_v: [B, 1, v_dim] value feature.
#         Returns the memory after all write blocks (the first output of each
#         DecoderBlock replaces ``mem``).
#         """
#         feat_k = self.proj_q(feat_k)  # [B, 1, C]
#         feat = torch.cat([feat_k, feat_v], dim=-1)
#         for blk in self.write_blocks:
#             mem, _ = blk(mem, feat, None, None)
#         return mem

#     def inquire(self, query, mem):
#         """
#         Read a value from memory for a key query.
#
#         query: [B, 1, k_dim] query feature; projected to v_dim by proj_q.
#         mem:   [B, size, 2 * v_dim] memory, as produced by update_mem.
#         The projected query is padded with ``masked_token`` in the value
#         half, passed through the read blocks with ``mem`` as the second
#         input, and only the value half (last v_dim channels) is returned.
#         """
#         x = self.proj_q(query)  # [B, 1, C]
#         x = torch.cat([x, self.masked_token.expand(x.shape[0], -1, -1)], dim=-1)
#         for blk in self.read_blocks:
#             x, _ = blk(x, mem, None, None)
#         return x[..., -self.v_dim :]


# class ARCroco3DStereo(CroCoNet):
#     config_class = ARCroco3DStereoConfig
#     base_model_prefix = "arcroco3dstereo"
#     supports_gradient_checkpointing = True

#     def __init__(self, config: ARCroco3DStereoConfig):
#         """
#         Build the CroCo backbone from ``config.croco_kwargs``, then add the
#         ray-map encoder, learned state tokens, state decoder, the optional
#         pose token / pose retriever and the downstream head, and finally apply
#         ``config.freeze``.
#         """
#         self.gradient_checkpointing = False
#         self.fixed_input_length = True
#         # NOTE(review): writes the filled-in defaults back into the caller's
#         # config object.
#         config.croco_kwargs = fill_default_args(
#             config.croco_kwargs, CrocoConfig.__init__
#         )
#         # NOTE(review): if CroCoNet / PreTrainedModel.__init__ assigns
#         # self.config, it will replace this with croco_cfg -- confirm.
#         self.config = config
#         # Must be set before super().__init__(): _set_patch_embed reads it
#         # (presumably invoked from CroCoNet's constructor -- confirm).
#         self.patch_embed_cls = config.patch_embed_cls
#         self.croco_args = config.croco_kwargs
#         croco_cfg = CrocoConfig(**self.croco_args)
#         # Presumably sets enc_embed_dim, dec_embed_dim, dec_depth, rope and the
#         # encoder/decoder through the _set_* hooks of this class -- confirm.
#         super().__init__(croco_cfg)
#         # Separate transformer stack for ray-map tokens. The positional 16 and
#         # 4 are presumably num_heads and mlp_ratio -- confirm against
#         # dust3r.blocks.Block.
#         self.enc_blocks_ray_map = nn.ModuleList(
#             [
#                 Block(
#                     self.enc_embed_dim,
#                     16,
#                     4,
#                     qkv_bias=True,
#                     norm_layer=partial(nn.LayerNorm, eps=1e-6),
#                     rope=self.rope,
#                 )
#                 for _ in range(config.ray_enc_depth)
#             ]
#         )
#         self.enc_norm_ray_map = nn.LayerNorm(self.enc_embed_dim, eps=1e-6)
#         self.dec_num_heads = self.croco_args["dec_num_heads"]
#         self.pose_head_flag = config.pose_head
#         # pose_token and pose_retriever only exist when config.pose_head is
#         # True. NOTE(review): set_freeze (called at the end of this method)
#         # and the forward paths reference them unconditionally, so
#         # pose_head=False currently raises AttributeError.
#         if self.pose_head_flag:
#             self.pose_token = nn.Parameter(
#                 torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
#             )
#             self.pose_retriever = LocalMemory(
#                 size=config.local_mem_size,
#                 k_dim=self.enc_embed_dim,
#                 v_dim=self.dec_embed_dim,
#                 num_heads=self.dec_num_heads,
#                 mlp_ratio=4,
#                 qkv_bias=True,
#                 attn_drop=0.0,
#                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
#                 rope=None,
#             )
#         # One learned embedding per state token.
#         self.register_tokens = nn.Embedding(config.state_size, self.enc_embed_dim)
#         self.state_size = config.state_size
#         self.state_pe = config.state_pe
#         # Substituted in _encode_views for views whose image / ray map is
#         # masked out.
#         self.masked_img_token = nn.Parameter(
#             torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
#         )
#         self.masked_ray_map_token = nn.Parameter(
#             torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
#         )
#         # State decoder: same depth as the image decoder, own head count.
#         self._set_state_decoder(
#             self.enc_embed_dim,
#             self.dec_embed_dim,
#             config.state_dec_num_heads,
#             self.dec_depth,
#             self.croco_args.get("mlp_ratio", None),
#             self.croco_args.get("norm_layer", None),
#             self.croco_args.get("norm_im2_in_dec", None),
#         )
#         # croco_args also supplies patch_size and img_size to this call.
#         self.set_downstream_head(
#             config.output_mode,
#             config.head_type,
#             config.landscape_only,
#             config.depth_mode,
#             config.conf_mode,
#             config.pose_mode,
#             config.depth_head,
#             config.rgb_head,
#             config.pose_conf_head,
#             config.pose_head,
#             **self.croco_args,
#         )
#         self.set_freeze(config.freeze)

#     @classmethod
#     def from_pretrained(cls, pretrained_model_name_or_path, **kw):
#         """
#         Load a model from a local checkpoint file or via the parent class's
#         ``from_pretrained``.
#
#         An existing file path goes through ``load_model`` on CPU; ``kw`` is
#         ignored on that path. Anything else is delegated to the parent
#         ``from_pretrained`` together with ``kw``.
#         """
#         if os.path.isfile(pretrained_model_name_or_path):
#             return load_model(pretrained_model_name_or_path, device="cpu")
#         else:
#             try:
#                 model = super(ARCroco3DStereo, cls).from_pretrained(
#                     pretrained_model_name_or_path, **kw
#                 )
#             # NOTE(review): only TypeError is translated, and the original
#             # error is not chained (no `from e`), which hides the real cause.
#             except TypeError as e:
#                 raise Exception(
#                     f"tried to load {pretrained_model_name_or_path} from huggingface, but failed"
#                 )
#             return model

#     def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
#         """
#         Create the patch embedders, both of class ``self.patch_embed_cls``:
#         ``patch_embed`` for 3-channel images and ``patch_embed_ray_map`` for
#         6-channel ray maps.
#         """
#         self.patch_embed = get_patch_embed(
#             self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3
#         )
#         self.patch_embed_ray_map = get_patch_embed(
#             self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=6
#         )

#     def _set_decoder(
#         self,
#         enc_embed_dim,
#         dec_embed_dim,
#         dec_num_heads,
#         dec_depth,
#         mlp_ratio,
#         norm_layer,
#         norm_im2_in_dec,
#     ):
#         """
#         Build the image-token decoder: a linear projection from encoder to
#         decoder width, ``dec_depth`` DecoderBlocks using ``self.rope``, and a
#         final norm. Also records ``dec_depth`` and ``dec_embed_dim``.
#         """
#         self.dec_depth = dec_depth
#         self.dec_embed_dim = dec_embed_dim
#         self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
#         self.dec_blocks = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     dec_embed_dim,
#                     dec_num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=True,
#                     norm_layer=norm_layer,
#                     norm_mem=norm_im2_in_dec,
#                     rope=self.rope,
#                 )
#                 for i in range(dec_depth)
#             ]
#         )
#         self.dec_norm = norm_layer(dec_embed_dim)

#     def _set_state_decoder(
#         self,
#         enc_embed_dim,
#         dec_embed_dim,
#         dec_num_heads,
#         dec_depth,
#         mlp_ratio,
#         norm_layer,
#         norm_im2_in_dec,
#     ):
#         """
#         Build the state-token decoder. It mirrors ``_set_decoder`` but stores
#         everything under ``*_state`` names (``decoder_embed_state``,
#         ``dec_blocks_state``, ``dec_norm_state``). load_state_dict can fill
#         ``dec_blocks_state`` from ``dec_blocks`` weights in older checkpoints.
#         """
#         self.dec_depth_state = dec_depth
#         self.dec_embed_dim_state = dec_embed_dim
#         self.decoder_embed_state = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
#         self.dec_blocks_state = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     dec_embed_dim,
#                     dec_num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=True,
#                     norm_layer=norm_layer,
#                     norm_mem=norm_im2_in_dec,
#                     rope=self.rope,
#                 )
#                 for i in range(dec_depth)
#             ]
#         )
#         self.dec_norm_state = norm_layer(dec_embed_dim)

#     def load_state_dict(self, ckpt, **kw):
#         """
#         Load a checkpoint, tolerating older or partially matching layouts.
#
#         Steps:
#           1. Strip a leading "module." from every key if all keys have it.
#           2. If the checkpoint has no ``dec_blocks_state`` weights, copy the
#              ``dec_blocks`` weights to ``dec_blocks_state`` keys as well.
#           3. Try a normal load; on failure retry without any decoder weights
#              (``dec_blocks*``, ``dec_norm*``, ``decoder_embed*``); on failure
#              again, load only keys that exist in the model with matching
#              sizes, logging every skipped key through ``printer``.
#         """
#         if all(k.startswith("module") for k in ckpt):
#             ckpt = strip_module(ckpt)
#         new_ckpt = dict(ckpt)
#         if not any(k.startswith("dec_blocks_state") for k in ckpt):
#             for key, value in ckpt.items():
#                 if key.startswith("dec_blocks"):
#                     new_ckpt[key.replace("dec_blocks", "dec_blocks_state")] = value
#         # NOTE(review): the bare `except:` clauses below also catch
#         # KeyboardInterrupt/SystemExit; `except Exception:` would be safer.
#         try:
#             return super().load_state_dict(new_ckpt, **kw)
#         except:
#             try:
#                 new_new_ckpt = {
#                     k: v
#                     for k, v in new_ckpt.items()
#                     if not k.startswith("dec_blocks")
#                     and not k.startswith("dec_norm")
#                     and not k.startswith("decoder_embed")
#                 }
#                 return super().load_state_dict(new_new_ckpt, **kw)
#             except:
#                 # NOTE(review): self.state_dict() is rebuilt up to three times
#                 # per key here; hoist it into a local if re-enabled.
#                 new_new_ckpt = {}
#                 for key in new_ckpt:
#                     if key in self.state_dict():
#                         if new_ckpt[key].size() == self.state_dict()[key].size():
#                             new_new_ckpt[key] = new_ckpt[key]
#                         else:
#                             printer.info(
#                                 f"Skipping '{key}': size mismatch (ckpt: {new_ckpt[key].size()}, model: {self.state_dict()[key].size()})"
#                             )
#                     else:
#                         printer.info(f"Skipping '{key}': not found in model")
#                 return super().load_state_dict(new_new_ckpt, **kw)

#     def set_freeze(self, freeze):  # this is for use by downstream models
#         """
#         Freeze a predefined group of parameters/modules.
#
#         Args:
#             freeze: one of "none", "mask", "encoder", "encoder_and_head",
#                 "encoder_and_decoder", "decoder". Any other value raises
#                 KeyError.
#         """
#         self.freeze = freeze
#         # NOTE(review): the whole dict literal is evaluated eagerly, so the
#         # self.pose_retriever / self.pose_token references below raise
#         # AttributeError for every `freeze` value (even "none") when the
#         # model was built with pose_head=False. Guard them like mask_token.
#         to_be_frozen = {
#             "none": [],
#             "mask": [self.mask_token] if hasattr(self, "mask_token") else [],
#             "encoder": [
#                 self.patch_embed,
#                 self.patch_embed_ray_map,
#                 self.masked_img_token,
#                 self.masked_ray_map_token,
#                 self.enc_blocks,
#                 self.enc_blocks_ray_map,
#                 self.enc_norm,
#                 self.enc_norm_ray_map,
#             ],
#             "encoder_and_head": [
#                 self.patch_embed,
#                 self.patch_embed_ray_map,
#                 self.masked_img_token,
#                 self.masked_ray_map_token,
#                 self.enc_blocks,
#                 self.enc_blocks_ray_map,
#                 self.enc_norm,
#                 self.enc_norm_ray_map,
#                 self.downstream_head,
#             ],
#             "encoder_and_decoder": [
#                 self.patch_embed,
#                 self.patch_embed_ray_map,
#                 self.masked_img_token,
#                 self.masked_ray_map_token,
#                 self.enc_blocks,
#                 self.enc_blocks_ray_map,
#                 self.enc_norm,
#                 self.enc_norm_ray_map,
#                 self.dec_blocks,
#                 self.dec_blocks_state,
#                 self.pose_retriever,
#                 self.pose_token,
#                 self.register_tokens,
#                 self.decoder_embed_state,
#                 self.decoder_embed,
#                 self.dec_norm,
#                 self.dec_norm_state,
#             ],
#             "decoder": [
#                 self.dec_blocks,
#                 self.dec_blocks_state,
#                 self.pose_retriever,
#                 self.pose_token,
#             ],
#         }
#         freeze_all_params(to_be_frozen[freeze])

#     def _set_prediction_head(self, *args, **kwargs):
#         """No prediction head"""
#         # Intentional no-op, presumably overriding a CroCoNet hook so the base
#         # class builds no head; this model's head is made in set_downstream_head.
#         return

#     def set_downstream_head(
#         self,
#         output_mode,
#         head_type,
#         landscape_only,
#         depth_mode,
#         conf_mode,
#         pose_mode,
#         depth_head,
#         rgb_head,
#         pose_conf_head,
#         pose_head,
#         patch_size,
#         img_size,
#         **kw,
#     ):
#         """
#         Build ``self.downstream_head`` via head_factory and wrap it with
#         transpose_to_landscape as ``self.head``.
#
#         ``patch_size`` and ``img_size`` arrive through the CroCo kwargs; both
#         entries of ``img_size`` must be multiples of ``patch_size``. Other
#         CroCo kwargs land in ``kw`` and are ignored.
#         """
#         assert (
#             img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
#         ), f"{img_size=} must be multiple of {patch_size=}"
#         self.output_mode = output_mode
#         self.head_type = head_type
#         self.depth_mode = depth_mode
#         self.conf_mode = conf_mode
#         self.pose_mode = pose_mode
#         # Mode tuples / flags are reduced to booleans for the head factory.
#         self.downstream_head = head_factory(
#             head_type,
#             output_mode,
#             self,
#             has_conf=bool(conf_mode),
#             has_depth=bool(depth_head),
#             has_rgb=bool(rgb_head),
#             has_pose_conf=bool(pose_conf_head),
#             has_pose=bool(pose_head),
#         )
#         self.head = transpose_to_landscape(
#             self.downstream_head, activate=landscape_only
#         )

#     def _encode_image(self, image, true_shape):
#         """
#         Patch-embed images and run them through ``enc_blocks`` and
#         ``enc_norm``.
#
#         Returns ``([tokens], positions, None)``; tokens are wrapped in a
#         one-element list to match the list-of-outputs shape _encode_views
#         expects.
#         """
#         x, pos = self.patch_embed(image, true_shape=true_shape)
#         # No absolute position embedding; positions are passed to each block.
#         assert self.enc_pos_embed is None
#         for blk in self.enc_blocks:
#             if self.gradient_checkpointing and self.training:
#                 x = checkpoint(blk, x, pos, use_reentrant=False)
#             else:
#                 x = blk(x, pos)
#         x = self.enc_norm(x)
#         return [x], pos, None

#     def _encode_ray_map(self, ray_map, true_shape):
#         """
#         Ray-map counterpart of ``_encode_image``: embed with
#         ``patch_embed_ray_map``, run ``enc_blocks_ray_map`` and
#         ``enc_norm_ray_map``, and return ``([tokens], positions, None)``.
#         """
#         x, pos = self.patch_embed_ray_map(ray_map, true_shape=true_shape)
#         assert self.enc_pos_embed is None
#         for blk in self.enc_blocks_ray_map:
#             if self.gradient_checkpointing and self.training:
#                 x = checkpoint(blk, x, pos, use_reentrant=False)
#             else:
#                 x = blk(x, pos)
#         x = self.enc_norm_ray_map(x)
#         return [x], pos, None

#     def _encode_state(self, image_tokens, image_pos):
#         """
#         Return the learned state tokens and their positions for a batch.
#
#         ``image_tokens`` only supplies the batch size, and ``image_pos`` only
#         supplies the device and dtype; the state content comes from
#         ``register_tokens``.
#
#         Returns ``(state_feat [B, state_size, enc_embed_dim], state_pos, None)``.
#         ``state_pos`` is [B, state_size, 2], or None when state_pe == "none".
#         """
#         batch_size = image_tokens.shape[0]
#         state_feat = self.register_tokens(
#             torch.arange(self.state_size, device=image_pos.device)
#         )
#         # NOTE(review): a state_pe other than "1d"/"2d"/"none" leaves state_pos
#         # unbound (UnboundLocalError at return); an explicit error would be clearer.
#         if self.state_pe == "1d":
#             # (i, i) per token: a 1-D index in the 2-D position format.
#             state_pos = (
#                 torch.tensor(
#                     [[i, i] for i in range(self.state_size)],
#                     dtype=image_pos.dtype,
#                     device=image_pos.device,
#                 )[None]
#                 .expand(batch_size, -1, -1)
#                 .contiguous()
#             )  # .long()
#         elif self.state_pe == "2d":
#             # Lay tokens out row-major on a grid whose width is
#             # floor(sqrt(state_size)), rounded up to an even number.
#             width = int(self.state_size**0.5)
#             width = width + 1 if width % 2 == 1 else width
#             state_pos = (
#                 torch.tensor(
#                     [[i // width, i % width] for i in range(self.state_size)],
#                     dtype=image_pos.dtype,
#                     device=image_pos.device,
#                 )[None]
#                 .expand(batch_size, -1, -1)
#                 .contiguous()
#             )
#         elif self.state_pe == "none":
#             state_pos = None
#         state_feat = state_feat[None].expand(batch_size, -1, -1)
#         return state_feat, state_pos, None

#     def _encode_views(self, views, img_mask=None, ray_mask=None):
#         """
#         Encode every view's image and ray map and merge them per view.
#
#         Each view dict provides "img" and "ray_map", plus "img_mask" and
#         "ray_mask" unless both masks are passed in, and optionally
#         "true_shape". Images whose mask is set go through _encode_image;
#         masked-out images get ``masked_img_token``. Ray maps whose mask is set
#         go through _encode_ray_map and are added on top; masked-out ones get
#         ``masked_ray_map_token``.
#
#         Returns:
#             (shapes, feats, poss): ``shapes`` and ``poss`` are tuples with one
#             tensor per view; ``feats`` is a list (one entry per encoder
#             output) of such per-view tuples.
#
#         Raises:
#             NotImplementedError: if no view has an image.
#         """
#         device = views[0]["img"].device
#         batch_size = views[0]["img"].shape[0]
#         # NOTE(review): masks are only read from the views when BOTH are None;
#         # passing just one leaves the other None and `.view(-1)` below fails.
#         given = True
#         if img_mask is None and ray_mask is None:
#             given = False
#         if not given:
#             img_mask = torch.stack(
#                 [view["img_mask"] for view in views], dim=0
#             )  # Shape: (num_views, batch_size)
#             ray_mask = torch.stack(
#                 [view["ray_mask"] for view in views], dim=0
#             )  # Shape: (num_views, batch_size)
#         imgs = torch.stack(
#             [view["img"] for view in views], dim=0
#         )  # Shape: (num_views, batch_size, C, H, W)
#         ray_maps = torch.stack(
#             [view["ray_map"] for view in views], dim=0
#         )  # Shape: (num_views, batch_size, H, W, C)
#         # Fall back to the image's own (H, W) when "true_shape" is absent.
#         shapes = []
#         for view in views:
#             if "true_shape" in view:
#                 shapes.append(view["true_shape"])
#             else:
#                 shape = torch.tensor(view["img"].shape[-2:], device=device)
#                 shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
#         shapes = torch.stack(shapes, dim=0).to(
#             imgs.device
#         )  # Shape: (num_views, batch_size, 2)
#         imgs = imgs.view(
#             -1, *imgs.shape[2:]
#         )  # Shape: (num_views * batch_size, C, H, W)
#         ray_maps = ray_maps.view(
#             -1, *ray_maps.shape[2:]
#         )  # Shape: (num_views * batch_size, H, W, C)
#         shapes = shapes.view(-1, 2)  # Shape: (num_views * batch_size, 2)
#         # Masks are used as boolean indices over the flattened view*batch axis
#         # (presumably bool dtype -- `~` is applied to them below).
#         img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size)
#         ray_masks_flat = ray_mask.view(-1)
#         selected_imgs = imgs[img_masks_flat]
#         selected_shapes = shapes[img_masks_flat]
#         if selected_imgs.size(0) > 0:
#             img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
#         else:
#             raise NotImplementedError
#         full_out = [
#             torch.zeros(
#                 len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
#             )
#             for _ in range(len(img_out))
#         ]
#         full_pos = torch.zeros(
#             len(views) * batch_size,
#             *img_pos.shape[1:],
#             device=img_pos.device,
#             dtype=img_pos.dtype,
#         )
#         # Scatter image tokens back; masked-out images get the learned token.
#         for i in range(len(img_out)):
#             full_out[i][img_masks_flat] += img_out[i]
#             full_out[i][~img_masks_flat] += self.masked_img_token
#         full_pos[img_masks_flat] += img_pos
#         ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
#         selected_ray_maps = ray_maps[ray_masks_flat]
#         selected_shapes_ray = shapes[ray_masks_flat]
#         if selected_ray_maps.size(0) > 0:
#             ray_out, ray_pos, _ = self._encode_ray_map(
#                 selected_ray_maps, selected_shapes_ray
#             )
#             assert len(ray_out) == len(full_out), f"{len(ray_out)}, {len(full_out)}"
#             # Ray tokens are summed onto the image tokens of the same view.
#             for i in range(len(ray_out)):
#                 full_out[i][ray_masks_flat] += ray_out[i]
#                 full_out[i][~ray_masks_flat] += self.masked_ray_map_token
#             # Ray positions are only added for views that have a ray map but
#             # no image, so no view receives positions from both branches.
#             full_pos[ray_masks_flat] += (
#                 ray_pos * (~img_masks_flat[ray_masks_flat][:, None, None]).long()
#             )
#         else:
#             # No view has a ray map: push one all-zero ray map through the ray
#             # encoder and add its output (and the masked token) times 0.0, which
#             # leaves the values unchanged -- presumably to keep the ray-map
#             # parameters in the autograd graph (e.g. for DDP unused-parameter
#             # checks).
#             raymaps = torch.zeros(
#                 1, 6, imgs[0].shape[-2], imgs[0].shape[-1], device=img_out[0].device
#             )
#             ray_mask_flat = torch.zeros_like(img_masks_flat)
#             ray_mask_flat[:1] = True
#             ray_out, ray_pos, _ = self._encode_ray_map(raymaps, shapes[ray_mask_flat])
#             for i in range(len(ray_out)):
#                 full_out[i][ray_mask_flat] += ray_out[i] * 0.0
#                 full_out[i][~ray_mask_flat] += self.masked_ray_map_token * 0.0
#         return (
#             shapes.chunk(len(views), dim=0),
#             [out.chunk(len(views), dim=0) for out in full_out],
#             full_pos.chunk(len(views), dim=0),
#         )

#     def _decoder(self, f_state, pos_state, f_img, pos_img, f_pose, pos_pose):
#         """
#         Jointly decode the state tokens and one view's image tokens.
#
#         At each depth the state block gets (state, image) and the image block
#         gets (image, state); both read the previous layer's outputs. With the
#         pose head enabled, the pose token and its position are prepended to
#         the image tokens after decoder_embed.
#
#         Returns:
#             zip over two tuples (state outputs, image outputs), each of length
#             dec_depth + 1. Entry 0 is the input state / the encoder image
#             tokens before decoder_embed; the last entry is normalized with
#             dec_norm_state / dec_norm.
#         """
#         final_output = [(f_state, f_img)]  # before projection
#         assert f_state.shape[-1] == self.dec_embed_dim
#         f_img = self.decoder_embed(f_img)
#         if self.pose_head_flag:
#             assert f_pose is not None and pos_pose is not None
#             f_img = torch.cat([f_pose, f_img], dim=1)
#             pos_img = torch.cat([pos_pose, pos_img], dim=1)
#         final_output.append((f_state, f_img))
#         for blk_state, blk_img in zip(self.dec_blocks_state, self.dec_blocks):
#             if (
#                 self.gradient_checkpointing
#                 and self.training
#                 and torch.is_grad_enabled()
#             ):
#                 # [::+1] -> (state, img); [::-1] -> (img, state).
#                 f_state, _ = checkpoint(
#                     blk_state,
#                     *final_output[-1][::+1],
#                     pos_state,
#                     pos_img,
#                     use_reentrant=not self.fixed_input_length,
#                 )
#                 f_img, _ = checkpoint(
#                     blk_img,
#                     *final_output[-1][::-1],
#                     pos_img,
#                     pos_state,
#                     use_reentrant=not self.fixed_input_length,
#                 )
#             else:
#                 f_state, _ = blk_state(*final_output[-1][::+1], pos_state, pos_img)
#                 f_img, _ = blk_img(*final_output[-1][::-1], pos_img, pos_state)
#             final_output.append((f_state, f_img))
#         # Only the state half of entry 1 duplicates entry 0; dropping it also
#         # drops the projected image tokens, so entry 0 keeps the raw encoder
#         # tokens.
#         del final_output[1]  # duplicate with final_output[0]
#         final_output[-1] = (
#             self.dec_norm_state(final_output[-1][0]),
#             self.dec_norm(final_output[-1][1]),
#         )
#         return zip(*final_output)

#     def _downstream_head(self, decout, img_shape, **kwargs):
#         """Apply ``self.head`` to the selected decoder outputs for one view."""
#         # NOTE(review): B, S, D are unused, and getattr(self, "head") is just
#         # self.head.
#         B, S, D = decout[-1].shape
#         head = getattr(self, f"head")
#         return head(decout, img_shape, **kwargs)

#     def _init_state(self, image_tokens, image_pos):
#         """
#         Build the initial state: the learned register tokens from
#         ``_encode_state`` projected to decoder width by
#         ``decoder_embed_state``, plus their positions.
#
#         The first frame's tokens and positions are passed in, but
#         ``_encode_state`` only uses them for batch size, device and dtype, so
#         the initial state does not depend on the frame content.
#         """
#         state_feat, state_pos, _ = self._encode_state(image_tokens, image_pos)
#         state_feat = self.decoder_embed_state(state_feat)
#         return state_feat, state_pos

#     def _recurrent_rollout(
#         self,
#         state_feat,
#         state_pos,
#         current_feat,
#         current_pos,
#         pose_feat,
#         pose_pos,
#         init_state_feat,
#         img_mask=None,
#         reset_mask=None,
#         update=None,
#     ):
#         """
#         Run the joint decoder for one view.
#
#         Returns ``(new_state_feat, dec)``: the last (normalized) state output
#         and the tuple of image-side decoder outputs.
#
#         ``init_state_feat``, ``img_mask``, ``reset_mask`` and ``update`` are
#         accepted but unused here; the callers apply masking and resets.
#         """
#         new_state_feat, dec = self._decoder(
#             state_feat, state_pos, current_feat, current_pos, pose_feat, pose_pos
#         )
#         new_state_feat = new_state_feat[-1]
#         return new_state_feat, dec

#     def _get_img_level_feat(self, feat):
#         """Average tokens over dim 1 (kept as size 1) to get one global feature."""
#         return torch.mean(feat, dim=1, keepdim=True)

#     def _forward_encoder(self, views):
#         """
#         Encode all views and build the initial recurrent state.
#
#         Returns:
#             ((feat, pos, shape), (init_state_feat, init_mem, state_feat,
#             state_pos, mem)), where feat/pos/shape hold one entry per view
#             (feat from the last encoder output) and mem is the pose
#             retriever's memory expanded to the batch.
#         """
#         shape, feat_ls, pos = self._encode_views(views)
#         feat = feat_ls[-1]
#         state_feat, state_pos = self._init_state(feat[0], pos[0])
#         # NOTE(review): assumes pose_head=True; pose_retriever does not exist
#         # otherwise.
#         mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
#         # Copies kept so per-sample resets can restore the initial state.
#         init_state_feat = state_feat.clone()
#         init_mem = mem.clone()
#         return (feat, pos, shape), (
#             init_state_feat,
#             init_mem,
#             state_feat,
#             state_pos,
#             mem,
#         )

#     def _forward_decoder_step(
#         self,
#         views,
#         i,
#         feat_i,
#         pos_i,
#         shape_i,
#         init_state_feat,
#         init_mem,
#         state_feat,
#         state_pos,
#         mem,
#     ):
#         """
#         Process view ``i`` given the current recurrent state.
#
#         Returns ``(res, (state_feat, mem))``: the head output for the view and
#         the state / pose memory after applying the view's update and reset
#         masks.
#
#         NOTE(review): this repeats the per-view loop body of _forward_impl;
#         keep the two in sync (or have _forward_impl call this).
#         """
#         if self.pose_head_flag:
#             global_img_feat_i = self._get_img_level_feat(feat_i)
#             # First view uses the learned pose token; later views query memory.
#             if i == 0:
#                 pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#             else:
#                 pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#             pose_pos_i = -torch.ones(
#                 feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#             )
#         else:
#             pose_feat_i = None
#             pose_pos_i = None
#         new_state_feat, dec = self._recurrent_rollout(
#             state_feat,
#             state_pos,
#             feat_i,
#             pos_i,
#             pose_feat_i,
#             pose_pos_i,
#             init_state_feat,
#             img_mask=views[i]["img_mask"],
#             reset_mask=views[i]["reset"],
#             update=views[i].get("update", None),
#         )
#         # NOTE(review): the lines below assume pose_head=True: without it
#         # pose_retriever and global_img_feat_i do not exist, and [:, 1:] would
#         # drop a real image token instead of the pose token.
#         out_pose_feat_i = dec[-1][:, 0:1]
#         new_mem = self.pose_retriever.update_mem(
#             mem, global_img_feat_i, out_pose_feat_i
#         )
#         # dec[0] is the raw encoder output (no pose token); the middle layers
#         # drop the pose token; the last layer keeps it at index 0.
#         head_input = [
#             dec[0].float(),
#             dec[self.dec_depth * 2 // 4][:, 1:].float(),
#             dec[self.dec_depth * 3 // 4][:, 1:].float(),
#             dec[self.dec_depth].float(),
#         ]
#         res = self._downstream_head(head_input, shape_i, pos=pos_i)
#         img_mask = views[i]["img_mask"]
#         update = views[i].get("update", None)
#         if update is not None:
#             update_mask = img_mask & update  # if don't update, then whatever img_mask
#         else:
#             update_mask = img_mask
#         # Per-sample blend: masked-off samples keep their previous state.
#         update_mask = update_mask[:, None, None].float()
#         state_feat = new_state_feat * update_mask + state_feat * (
#             1 - update_mask
#         )  # update global state
#         mem = new_mem * update_mask + mem * (1 - update_mask)  # then update local state
#         # Resets are applied after the update and restore the initial state.
#         reset_mask = views[i]["reset"]
#         if reset_mask is not None:
#             reset_mask = reset_mask[:, None, None].float()
#             state_feat = init_state_feat * reset_mask + state_feat * (1 - reset_mask)
#             mem = init_mem * reset_mask + mem * (1 - reset_mask)
#         return res, (state_feat, mem)

#     def _forward_impl(self, views, ret_state=False):
#         """
#         Encode all views, then decode them one at a time while carrying the
#         recurrent state and pose memory across views.
#
#         Returns:
#             ``(ress, views)``, or ``(ress, views, all_state_args)`` when
#             ``ret_state`` is True. ``ress`` holds one head output per view;
#             ``all_state_args`` holds len(views) + 1 tuples
#             (state_feat, state_pos, init_state_feat, mem, init_mem), the first
#             being the initial state.
#         """
#         shape, feat_ls, pos = self._encode_views(views)
#         feat = feat_ls[-1]
#         state_feat, state_pos = self._init_state(feat[0], pos[0])
#         # NOTE(review): assumes pose_head=True; pose_retriever does not exist
#         # otherwise.
#         mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
#         init_state_feat = state_feat.clone()
#         init_mem = mem.clone()
#         all_state_args = [(state_feat, state_pos, init_state_feat, mem, init_mem)]
#         ress = []
#         # NOTE(review): the loop body duplicates _forward_decoder_step.
#         for i in range(len(views)):
#             feat_i = feat[i]
#             pos_i = pos[i]
#             if self.pose_head_flag:
#                 global_img_feat_i = self._get_img_level_feat(feat_i)
#                 # First view uses the learned pose token; later views query memory.
#                 if i == 0:
#                     pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#                 else:
#                     pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#                 pose_pos_i = -torch.ones(
#                     feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#                 )
#             else:
#                 pose_feat_i = None
#                 pose_pos_i = None
#             new_state_feat, dec = self._recurrent_rollout(
#                 state_feat,
#                 state_pos,
#                 feat_i,
#                 pos_i,
#                 pose_feat_i,
#                 pose_pos_i,
#                 init_state_feat,
#                 img_mask=views[i]["img_mask"],
#                 reset_mask=views[i]["reset"],
#                 update=views[i].get("update", None),
#             )
#             # NOTE(review): assumes pose_head=True (pose_retriever,
#             # global_img_feat_i and a leading pose token in dec).
#             out_pose_feat_i = dec[-1][:, 0:1]
#             new_mem = self.pose_retriever.update_mem(
#                 mem, global_img_feat_i, out_pose_feat_i
#             )
#             assert len(dec) == self.dec_depth + 1
#             # dec[0] is the raw encoder output (no pose token); the middle
#             # layers drop the pose token; the last layer keeps it at index 0.
#             head_input = [
#                 dec[0].float(),
#                 dec[self.dec_depth * 2 // 4][:, 1:].float(),
#                 dec[self.dec_depth * 3 // 4][:, 1:].float(),
#                 dec[self.dec_depth].float(),
#             ]
#             res = self._downstream_head(head_input, shape[i], pos=pos_i)
#             ress.append(res)
#             img_mask = views[i]["img_mask"]
#             update = views[i].get("update", None)
#             if update is not None:
#                 update_mask = (
#                     img_mask & update
#                 )  # if don't update, then whatever img_mask
#             else:
#                 update_mask = img_mask
#             # Per-sample blend: masked-off samples keep their previous state.
#             update_mask = update_mask[:, None, None].float()
#             state_feat = new_state_feat * update_mask + state_feat * (
#                 1 - update_mask
#             )  # update global state
#             mem = new_mem * update_mask + mem * (
#                 1 - update_mask
#             )  # then update local state
#             # Resets are applied after the update and restore the initial state.
#             reset_mask = views[i]["reset"]
#             if reset_mask is not None:
#                 reset_mask = reset_mask[:, None, None].float()
#                 state_feat = init_state_feat * reset_mask + state_feat * (
#                     1 - reset_mask
#                 )
#                 mem = init_mem * reset_mask + mem * (1 - reset_mask)
#             all_state_args.append(
#                 (state_feat, state_pos, init_state_feat, mem, init_mem)
#             )
#         if ret_state:
#             return ress, views, all_state_args
#         return ress, views

#     def forward(self, views, ret_state=False):
#         """
#         Run `_forward_impl` and wrap its results in `ARCroco3DStereoOutput`.
#
#         Returns the output object alone, or `(output, state_args)` when
#         `ret_state` is True.
#         """
#         if ret_state:
#             ress, views, state_args = self._forward_impl(views, ret_state=ret_state)
#             return ARCroco3DStereoOutput(ress=ress, views=views), state_args
#         else:
#             ress, views = self._forward_impl(views, ret_state=ret_state)
#             return ARCroco3DStereoOutput(ress=ress, views=views)

#     def inference_step(
#         self, view, state_feat, state_pos, init_state_feat, mem, init_mem
#     ):
#         """
#         Decode one ray-map-only view against an existing recurrent state.
#
#         Every sample in `view` must have "ray_mask" set: only the ray map is
#         encoded here (no image features). The updated state / memory computed
#         by the rollout are not returned; only the head output is.
#
#         Args:
#             view: dict with "img", "ray_map" (channels-last, see the permute
#                 below), "ray_mask", "img_mask", "reset", optional "update"
#                 and optional "true_shape".
#             state_feat, state_pos, init_state_feat, mem: recurrent state, e.g.
#                 an entry of the state list from `_forward_impl(..., ret_state=True)`.
#             init_mem: accepted for signature parity; unused here.
#
#         Returns:
#             (res, view): downstream-head output and the unchanged view.
#         """
#         batch_size = view["img"].shape[0]
#         # ray_map is channels-last (B, H, W, C) -- it is permuted to
#         # (B, C, H, W) below -- so its spatial size is shape[1:3]. The old
#         # fallback used shape[-2:], i.e. (W, C), not (H, W).
#         if "true_shape" in view:
#             true_shape = view["true_shape"]
#         else:
#             true_shape = torch.tensor(view["ray_map"].shape[1:3])[None].repeat(
#                 view["ray_map"].shape[0], 1
#             )
#         raymaps = []
#         shapes = []
#         for j in range(batch_size):
#             assert view["ray_mask"][j]
#             raymap = view["ray_map"][[j]].permute(0, 3, 1, 2)
#             raymaps.append(raymap)
#             shapes.append(true_shape[[j]])
#
#         raymaps = torch.cat(raymaps, dim=0)
#         shape = torch.cat(shapes, dim=0).to(raymaps.device)
#         # Pass the concatenated (B, 2) tensor on the model device rather than
#         # the Python list of per-sample shape tensors.
#         feat_ls, pos, _ = self._encode_ray_map(raymaps, shape)
#
#         feat_i = feat_ls[-1]
#         pos_i = pos
#         if self.pose_head_flag:
#             global_img_feat_i = self._get_img_level_feat(feat_i)
#             pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#             # The single pose token gets the fixed position (-1, -1).
#             pose_pos_i = -torch.ones(
#                 feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#             )
#         else:
#             pose_feat_i = None
#             pose_pos_i = None
#         new_state_feat, dec = self._recurrent_rollout(
#             state_feat,
#             state_pos,
#             feat_i,
#             pos_i,
#             pose_feat_i,
#             pose_pos_i,
#             init_state_feat,
#             img_mask=view["img_mask"],
#             reset_mask=view["reset"],
#             update=view.get("update", None),
#         )
#
#         # NOTE(review): new_state_feat / new_mem are computed but not returned,
#         # and global_img_feat_i is unbound when pose_head_flag is false.
#         out_pose_feat_i = dec[-1][:, 0:1]
#         new_mem = self.pose_retriever.update_mem(
#             mem, global_img_feat_i, out_pose_feat_i
#         )
#         assert len(dec) == self.dec_depth + 1
#         head_input = [
#             dec[0].float(),
#             dec[self.dec_depth * 2 // 4][:, 1:].float(),
#             dec[self.dec_depth * 3 // 4][:, 1:].float(),
#             dec[self.dec_depth].float(),
#         ]
#         res = self._downstream_head(head_input, shape, pos=pos_i)
#         return res, view

#     def forward_recurrent(self, views, device, ret_state=False):
#         """
#         Run the model over `views`, encoding each view only when it is
#         reached (unlike `_forward_impl`, which encodes all views up front).
#
#         Args:
#             views: list of per-view dicts ("img", "ray_map", "img_mask",
#                 "ray_mask", "reset", optional "update" / "true_shape").
#             device: overwritten per view with view["img"].device, so the
#                 value passed in is effectively unused.
#             ret_state: if True, also return the recorded states.
#
#         Returns:
#             (ress, views), or (ress, views, all_state_args) if `ret_state`;
#             `all_state_args` holds the initial state plus one tuple
#             (state_feat, state_pos, init_state_feat, mem, init_mem) per view.
#         """
#         ress = []
#         all_state_args = []
#         for i, view in enumerate(views):
#             device = view["img"].device
#             batch_size = view["img"].shape[0]
#             img_mask = view["img_mask"].reshape(
#                 -1, batch_size
#             )  # Shape: (1, batch_size)
#             ray_mask = view["ray_mask"].reshape(
#                 -1, batch_size
#             )  # Shape: (1, batch_size)
#             imgs = view["img"].unsqueeze(0)  # Shape: (1, batch_size, C, H, W)
#             ray_maps = view["ray_map"].unsqueeze(
#                 0
#             )  # Shape: (num_views, batch_size, H, W, C)
#             shapes = (
#                 view["true_shape"].unsqueeze(0)
#                 if "true_shape" in view
#                 else torch.tensor(view["img"].shape[-2:], device=device)
#                 .unsqueeze(0)
#                 .repeat(batch_size, 1)
#                 .unsqueeze(0)
#             )  # Shape: (num_views, batch_size, 2)
#             imgs = imgs.view(
#                 -1, *imgs.shape[2:]
#             )  # Shape: (num_views * batch_size, C, H, W)
#             ray_maps = ray_maps.view(
#                 -1, *ray_maps.shape[2:]
#             )  # Shape: (num_views * batch_size, H, W, C)
#             shapes = shapes.view(-1, 2).to(
#                 imgs.device
#             )  # Shape: (num_views * batch_size, 2)
#             img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size)
#             ray_masks_flat = ray_mask.view(-1)
#             selected_imgs = imgs[img_masks_flat]
#             selected_shapes = shapes[img_masks_flat]
#             if selected_imgs.size(0) > 0:
#                 img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
#             else:
#                 img_out, img_pos = None, None
#             ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
#             selected_ray_maps = ray_maps[ray_masks_flat]
#             selected_shapes_ray = shapes[ray_masks_flat]
#             if selected_ray_maps.size(0) > 0:
#                 ray_out, ray_pos, _ = self._encode_ray_map(
#                     selected_ray_maps, selected_shapes_ray
#                 )
#             else:
#                 ray_out, ray_pos = None, None

#             # NOTE(review): unlike _encode_views, masked-out samples are not
#             # filled with masked tokens here, so feat_i only has rows for the
#             # selected samples, and the summed case requires both masks to
#             # select the same number of rows. Presumably callers set the masks
#             # for the whole batch -- confirm.
#             shape = shapes
#             if img_out is not None and ray_out is None:
#                 feat_i = img_out[-1]
#                 pos_i = img_pos
#             elif img_out is None and ray_out is not None:
#                 feat_i = ray_out[-1]
#                 pos_i = ray_pos
#             elif img_out is not None and ray_out is not None:
#                 feat_i = img_out[-1] + ray_out[-1]
#                 pos_i = img_pos
#             else:
#                 raise NotImplementedError

#             # The recurrent state and pose memory are created from the first view.
#             if i == 0:
#                 state_feat, state_pos = self._init_state(feat_i, pos_i)
#                 mem = self.pose_retriever.mem.expand(feat_i.shape[0], -1, -1)
#                 init_state_feat = state_feat.clone()
#                 init_mem = mem.clone()
#                 all_state_args.append(
#                     (state_feat, state_pos, init_state_feat, mem, init_mem)
#                 )

#             if self.pose_head_flag:
#                 global_img_feat_i = self._get_img_level_feat(feat_i)
#                 if i == 0:
#                     pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#                 else:
#                     pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#                 pose_pos_i = -torch.ones(
#                     feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#                 )
#             else:
#                 pose_feat_i = None
#                 pose_pos_i = None
#             new_state_feat, dec = self._recurrent_rollout(
#                 state_feat,
#                 state_pos,
#                 feat_i,
#                 pos_i,
#                 pose_feat_i,
#                 pose_pos_i,
#                 init_state_feat,
#                 img_mask=view["img_mask"],
#                 reset_mask=view["reset"],
#                 update=view.get("update", None),
#             )
#             # NOTE(review): global_img_feat_i is only bound when pose_head_flag
#             # is true, so without a pose head this raises UnboundLocalError.
#             out_pose_feat_i = dec[-1][:, 0:1]
#             new_mem = self.pose_retriever.update_mem(
#                 mem, global_img_feat_i, out_pose_feat_i
#             )
#             assert len(dec) == self.dec_depth + 1
#             head_input = [
#                 dec[0].float(),
#                 dec[self.dec_depth * 2 // 4][:, 1:].float(),
#                 dec[self.dec_depth * 3 // 4][:, 1:].float(),
#                 dec[self.dec_depth].float(),
#             ]
#             res = self._downstream_head(head_input, shape, pos=pos_i)
#             ress.append(res)
#             # Advance state/memory only for samples with an image (and the
#             # "update" flag, if given); then apply per-sample resets.
#             img_mask = view["img_mask"]
#             update = view.get("update", None)
#             if update is not None:
#                 update_mask = (
#                     img_mask & update
#                 )  # if don't update, then whatever img_mask
#             else:
#                 update_mask = img_mask
#             update_mask = update_mask[:, None, None].float()
#             state_feat = new_state_feat * update_mask + state_feat * (
#                 1 - update_mask
#             )  # update global state
#             mem = new_mem * update_mask + mem * (
#                 1 - update_mask
#             )  # then update local state
#             reset_mask = view["reset"]
#             if reset_mask is not None:
#                 reset_mask = reset_mask[:, None, None].float()
#                 state_feat = init_state_feat * reset_mask + state_feat * (
#                     1 - reset_mask
#                 )
#                 mem = init_mem * reset_mask + mem * (1 - reset_mask)
#             all_state_args.append(
#                 (state_feat, state_pos, init_state_feat, mem, init_mem)
#             )
#         if ret_state:
#             return ress, views, all_state_args
#         return ress, views


# # if __name__ == "__main__":
# #     print(ARCroco3DStereo.mro())
# #     cfg = ARCroco3DStereoConfig(
# #         state_size=256,
# #         pos_embed="RoPE100",
# #         rgb_head=True,
# #         pose_head=True,
# #         img_size=(224, 224),
# #         head_type="linear",
# #         output_mode="pts3d+pose",
# #         depth_mode=("exp", -inf, inf),
# #         conf_mode=("exp", 1, inf),
# #         pose_mode=("exp", -inf, inf),
# #         enc_embed_dim=1024,
# #         enc_depth=24,
# #         enc_num_heads=16,
# #         dec_embed_dim=768,
# #         dec_depth=12,
# #         dec_num_heads=12,
# #     )
# #     ARCroco3DStereo(cfg)


# def zero_module(module):
#     """
#     Zero out the parameters of a module in place and return the same module.
#
#     The write goes through `p.detach()`, so zeroing is not recorded by
#     autograd. Used by `make_zero_conv` to build zero-initialised layers.
#     """
#     for p in module.parameters():
#         p.detach().zero_()
#     return module

# def conv_nd(dims, *args, **kwargs):
#     """
#     Create a 1D, 2D, or 3D convolution module.
#
#     Args:
#         dims: 1, 2 or 3, selecting nn.Conv1d / nn.Conv2d / nn.Conv3d.
#         *args, **kwargs: forwarded unchanged to the chosen constructor.
#
#     Raises:
#         ValueError: for any other `dims`.
#     """
#     if dims == 1:
#         # in_channels, out_channels, kernel_size
#         return nn.Conv1d(*args, **kwargs)
#     elif dims == 2:
#         return nn.Conv2d(*args, **kwargs)
#     elif dims == 3:
#         return nn.Conv3d(*args, **kwargs)
#     raise ValueError(f"unsupported dimensions: {dims}")

# def make_zero_conv(channels):
#     """
#     Return nn.Sequential holding a 1x1 Conv1d (channels -> channels, no
#     padding) whose parameters start at zero, so it initially outputs zeros.
#     """
#     return nn.Sequential(zero_module(conv_nd(1, channels, channels, 1, padding=0)))

# class ARCroco3DStereoGuided(CroCoNet):
#     config_class = ARCroco3DStereoConfig
#     base_model_prefix = "arcroco3dstereo"
#     supports_gradient_checkpointing = True

#     def __init__(self, config: ARCroco3DStereoConfig):
#         """
#         Build the guided model: the CroCo backbone plus a ray-map encoder,
#         an optional pose token / local pose memory, a state decoder, a
#         depth-prompt branch (`dec_blocks_pc` + `zero_convs`) and the
#         downstream head; finally apply `config.freeze`.
#         """
#         self.gradient_checkpointing = False
#         self.fixed_input_length = True
#         config.croco_kwargs = fill_default_args(
#             config.croco_kwargs, CrocoConfig.__init__
#         )
#         self.config = config
#         self.patch_embed_cls = config.patch_embed_cls
#         self.croco_args = config.croco_kwargs
    
#         # Set before super().__init__: _set_patch_embed reads self.dec_embed_dim
#         # for patch_embed_pc (presumably CroCoNet.__init__ calls it -- confirm).
#         self.dec_embed_dim = self.croco_args["dec_embed_dim"]

#         croco_cfg = CrocoConfig(**self.croco_args)
#         super().__init__(croco_cfg)
#         # Ray-map encoder: config.ray_enc_depth blocks at the encoder width
#         # (positional args presumably num_heads=16, mlp_ratio=4 -- confirm).
#         self.enc_blocks_ray_map = nn.ModuleList(
#             [
#                 Block(
#                     self.enc_embed_dim,
#                     16,
#                     4,
#                     qkv_bias=True,
#                     norm_layer=partial(nn.LayerNorm, eps=1e-6),
#                     rope=self.rope,
#                 )
#                 for _ in range(config.ray_enc_depth)
#             ]
#         )
#         self.enc_norm_ray_map = nn.LayerNorm(self.enc_embed_dim, eps=1e-6)
#         self.dec_num_heads = self.croco_args["dec_num_heads"]
#         self.pose_head_flag = config.pose_head
#         # self.depth_guidance = config.depth_guidance

#         # pose_token / pose_retriever only exist when the pose head is enabled.
#         if self.pose_head_flag:
#             self.pose_token = nn.Parameter(
#                 torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
#             )
#             self.pose_retriever = LocalMemory(
#                 size=config.local_mem_size,
#                 k_dim=self.enc_embed_dim,
#                 v_dim=self.dec_embed_dim,
#                 num_heads=self.dec_num_heads,
#                 mlp_ratio=4,
#                 qkv_bias=True,
#                 attn_drop=0.0,
#                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
#                 rope=None,
#             )
#         # config.state_size learned state tokens (see _encode_state).
#         self.register_tokens = nn.Embedding(config.state_size, self.enc_embed_dim)
#         self.state_size = config.state_size
#         self.state_pe = config.state_pe
#         # Learned fill-ins for samples that have no image / no ray map.
#         self.masked_img_token = nn.Parameter(
#             torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
#         )
#         self.masked_ray_map_token = nn.Parameter(
#             torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
#         )
#         # self.dec_depth is presumably set by _set_decoder during super().__init__.
#         self._set_state_decoder(
#             self.enc_embed_dim,
#             self.dec_embed_dim,
#             config.state_dec_num_heads,
#             self.dec_depth,
#             self.croco_args.get("mlp_ratio", None),
#             self.croco_args.get("norm_layer", None),
#             self.croco_args.get("norm_im2_in_dec", None),
#         )

#         # decoder block for depth prompt
#         # (dec_depth // 2 - 2 self-attention blocks at the decoder width)
#         self.dec_blocks_pc = nn.ModuleList([
#             Block(self.dec_embed_dim,
#                   self.dec_num_heads,
#                   mlp_ratio=self.croco_args.get("mlp_ratio", None),
#                   qkv_bias=True,
#                   norm_layer=self.croco_args.get("norm_layer", None),
#                   rope=self.rope)
#             for i in range(self.croco_args.get("dec_depth", None)//2-2)
#         ])

#         # One zero-initialised 1x1 conv per depth-prompt block, plus one extra.
#         self.zero_convs = []
#         for i in range(len(self.dec_blocks_pc) + 1):
#             self.zero_convs.append(make_zero_conv(self.dec_embed_dim))
#         self.zero_convs = nn.ModuleList(self.zero_convs)

#         self.set_downstream_head(
#             config.output_mode,
#             config.head_type,
#             config.landscape_only,
#             config.depth_mode,
#             config.conf_mode,
#             config.pose_mode,
#             config.depth_head,
#             config.rgb_head,
#             config.pose_conf_head,
#             config.pose_head,
#             **self.croco_args,
#         )
#         self.set_freeze(config.freeze)

#         # NOTE(review): debug print; consider the module logger `printer`.
#         print(f"config.landscape_only: {config.landscape_only}")    

#     @classmethod
#     def from_pretrained(cls, pretrained_model_name_or_path, **kw):
#         """
#         Load a model from a local checkpoint file or a Hugging Face repo id.
#
#         An existing file path is loaded with `load_model` on CPU (extra
#         keyword arguments are ignored in that case); anything else is
#         forwarded to the parent class's `from_pretrained`.
#
#         Raises:
#             Exception: if the parent `from_pretrained` raises TypeError
#                 (the original error is chained as the cause).
#         """
#         if os.path.isfile(pretrained_model_name_or_path):
#             return load_model(pretrained_model_name_or_path, device="cpu")
#         try:
#             # Zero-argument super() resolves against *this* class. The old
#             # `super(ARCroco3DStereo, cls)` always raised TypeError, because
#             # ARCroco3DStereoGuided subclasses CroCoNet, not ARCroco3DStereo;
#             # that error was then misreported as a download failure.
#             model = super().from_pretrained(pretrained_model_name_or_path, **kw)
#         except TypeError as e:
#             raise Exception(
#                 f"tried to load {pretrained_model_name_or_path} from huggingface, but failed"
#             ) from e
#         return model

#     def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
#         """
#         Create the patch embedders: RGB image (3 channels) and ray map
#         (6 channels) into `enc_embed_dim`, and the depth prompt (2 channels)
#         directly into the decoder width `self.dec_embed_dim`.
#         """
#         self.patch_embed = get_patch_embed(
#             self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3
#         )
#         self.patch_embed_ray_map = get_patch_embed(
#             self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=6
#         )

#         # 2 channels = normalised depth + validity mask (concatenated in _decoder).
#         # Relies on self.dec_embed_dim being set before super().__init__.
#         self.patch_embed_pc = get_patch_embed(
#             self.patch_embed_cls, img_size, patch_size, self.dec_embed_dim, in_chans=2
#         )

#     def _set_decoder(
#         self,
#         enc_embed_dim,
#         dec_embed_dim,
#         dec_num_heads,
#         dec_depth,
#         mlp_ratio,
#         norm_layer,
#         norm_im2_in_dec,
#     ):
#         """
#         Create the image-stream decoder: a linear projection from
#         `enc_embed_dim` to `dec_embed_dim`, `dec_depth` DecoderBlocks and a
#         final norm. Also records `dec_depth` / `dec_embed_dim` on `self`.
#         """
#         self.dec_depth = dec_depth
#         self.dec_embed_dim = dec_embed_dim
#         self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
#         self.dec_blocks = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     dec_embed_dim,
#                     dec_num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=True,
#                     norm_layer=norm_layer,
#                     norm_mem=norm_im2_in_dec,
#                     rope=self.rope,
#                 )
#                 for i in range(dec_depth)
#             ]
#         )
#         self.dec_norm = norm_layer(dec_embed_dim)

#     def _set_state_decoder(
#         self,
#         enc_embed_dim,
#         dec_embed_dim,
#         dec_num_heads,
#         dec_depth,
#         mlp_ratio,
#         norm_layer,
#         norm_im2_in_dec,
#     ):
#         """
#         Create the state-token decoder. It has the same layout as
#         `_set_decoder` (projection, `dec_depth` DecoderBlocks, final norm),
#         but every attribute carries a `_state` suffix.
#         """
#         self.dec_depth_state = dec_depth
#         self.dec_embed_dim_state = dec_embed_dim
#         self.decoder_embed_state = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
#         self.dec_blocks_state = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     dec_embed_dim,
#                     dec_num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=True,
#                     norm_layer=norm_layer,
#                     norm_mem=norm_im2_in_dec,
#                     rope=self.rope,
#                 )
#                 for i in range(dec_depth)
#             ]
#         )
#         self.dec_norm_state = norm_layer(dec_embed_dim)

#     def load_state_dict(self, ckpt, **kw):
#         """
#         Load `ckpt`, tolerating older or partially matching checkpoints.
#
#         1. Strip the "module." prefix if every key starts with "module".
#         2. If the checkpoint has no "dec_blocks_state" weights, seed them from
#            the "dec_blocks." weights.
#         3. Try a normal load. On failure, retry without the decoder weights
#            ("dec_blocks*", "dec_norm*", "decoder_embed*"). On a second failure,
#            keep only keys present in the model with a matching size, logging
#            every skipped key.
#
#         Returns whatever `nn.Module.load_state_dict` returns for the load
#         attempt that succeeded.
#         """
#         if all(k.startswith("module") for k in ckpt):
#             ckpt = strip_module(ckpt)
#         new_ckpt = dict(ckpt)
#         if not any(k.startswith("dec_blocks_state") for k in ckpt):
#             for key, value in ckpt.items():
#                 # Match "dec_blocks." only, so depth-prompt weights
#                 # ("dec_blocks_pc.*") are not copied to bogus "dec_blocks_state_pc.*".
#                 if key.startswith("dec_blocks."):
#                     new_ckpt["dec_blocks_state" + key[len("dec_blocks"):]] = value
#         # `except Exception` (not a bare except) so KeyboardInterrupt /
#         # SystemExit are no longer swallowed by the fallbacks.
#         try:
#             return super().load_state_dict(new_ckpt, **kw)
#         except Exception:
#             printer.info("Checkpoint load failed; retrying without decoder weights")
#         new_new_ckpt = {
#             k: v
#             for k, v in new_ckpt.items()
#             if not k.startswith("dec_blocks")
#             and not k.startswith("dec_norm")
#             and not k.startswith("decoder_embed")
#         }
#         try:
#             return super().load_state_dict(new_new_ckpt, **kw)
#         except Exception:
#             printer.info("Checkpoint load failed again; keeping only matching keys")
#         # Build the model state dict once instead of up to three times per key.
#         model_state = self.state_dict()
#         new_new_ckpt = {}
#         for key, value in new_ckpt.items():
#             if key not in model_state:
#                 printer.info(f"Skipping '{key}': not found in model")
#             elif value.size() != model_state[key].size():
#                 printer.info(
#                     f"Skipping '{key}': size mismatch (ckpt: {value.size()}, model: {model_state[key].size()})"
#                 )
#             else:
#                 new_new_ckpt[key] = value
#         return super().load_state_dict(new_new_ckpt, **kw)

#     def set_freeze(self, freeze):  # this is for use by downstream models
#         """
#         Pass the parameter group named by `freeze` to `freeze_all_params`.
#
#         Valid names: "none", "mask", "encoder", "encoder_and_head",
#         "encoder_and_decoder", "decoder"; anything else raises KeyError.
#         The depth-prompt branch (patch_embed_pc, dec_blocks_pc, zero_convs)
#         is not part of any group.
#         """
#         self.freeze = freeze
#         # pose_retriever / pose_token exist only when the pose head is enabled
#         # (see __init__). The table below is built eagerly, so referencing them
#         # unconditionally raised AttributeError for *every* freeze value
#         # (including "none") on models without a pose head.
#         pose_modules = [
#             getattr(self, name)
#             for name in ("pose_retriever", "pose_token")
#             if hasattr(self, name)
#         ]
#         encoder = [
#             self.patch_embed,
#             self.patch_embed_ray_map,
#             self.masked_img_token,
#             self.masked_ray_map_token,
#             self.enc_blocks,
#             self.enc_blocks_ray_map,
#             self.enc_norm,
#             self.enc_norm_ray_map,
#         ]
#         to_be_frozen = {
#             "none": [],
#             "mask": [self.mask_token] if hasattr(self, "mask_token") else [],
#             "encoder": encoder,
#             "encoder_and_head": encoder + [self.downstream_head],
#             "encoder_and_decoder": encoder
#             + [self.dec_blocks, self.dec_blocks_state]
#             + pose_modules
#             + [
#                 self.register_tokens,
#                 self.decoder_embed_state,
#                 self.decoder_embed,
#                 self.dec_norm,
#                 self.dec_norm_state,
#             ],
#             "decoder": [self.dec_blocks, self.dec_blocks_state] + pose_modules,
#         }
#         freeze_all_params(to_be_frozen[freeze])

#     def _set_prediction_head(self, *args, **kwargs):
#         """
#         No prediction head.
#
#         Intentionally a no-op (presumably overriding a CroCoNet hook); this
#         class builds its head in `set_downstream_head` instead.
#         """
#         return

#     def set_downstream_head(
#         self,
#         output_mode,
#         head_type,
#         landscape_only,
#         depth_mode,
#         conf_mode,
#         pose_mode,
#         depth_head,
#         rgb_head,
#         pose_conf_head,
#         pose_head,
#         patch_size,
#         img_size,
#         **kw,
#     ):
#         """
#         Build `self.downstream_head` via `head_factory` and wrap it as
#         `self.head` with `transpose_to_landscape(..., activate=landscape_only)`.
#
#         The *_head flags and `conf_mode` are converted to booleans to select
#         optional outputs; the modes are stored on `self`. Extra keyword
#         arguments (the remaining CroCo args) are ignored.
#         """
#         # Both image dimensions must be divisible by the patch size.
#         assert (
#             img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
#         ), f"{img_size=} must be multiple of {patch_size=}"
#         self.output_mode = output_mode
#         self.head_type = head_type
#         self.depth_mode = depth_mode
#         self.conf_mode = conf_mode
#         self.pose_mode = pose_mode
#         self.downstream_head = head_factory(
#             head_type,
#             output_mode,
#             self,
#             has_conf=bool(conf_mode),
#             has_depth=bool(depth_head),
#             has_rgb=bool(rgb_head),
#             has_pose_conf=bool(pose_conf_head),
#             has_pose=bool(pose_head),
#             # depth_guidance=self.depth_guidance
#         )
#         self.head = transpose_to_landscape(
#             self.downstream_head, activate=landscape_only
#         )

#     def _encode_image(self, image, true_shape):
#         """
#         Forward prop of VIT encoder:
#         Arguments:
#             image: images of shape [B*num_views, C, H, W]
#             true_shape: per-image shapes, forwarded to `self.patch_embed`.
#         Returns:
#             ([x], pos, None): a one-element list with the normalised encoder
#             tokens, the patch positions, and a placeholder None.
#         """
#         # print(f"[_encode_image] image.shape: {image.shape}")
#         x, pos = self.patch_embed(image, true_shape=true_shape)
#         # print(f"[patch_embed] x.shape: {x.shape}")
#         # Positions are injected by the blocks (e.g. RoPE), not an additive embedding.
#         assert self.enc_pos_embed is None
#         for blk in self.enc_blocks:
#             # Activation checkpointing only while training with it enabled.
#             if self.gradient_checkpointing and self.training:
#                 x = checkpoint(blk, x, pos, use_reentrant=False)
#             else:
#                 x = blk(x, pos)
#         x = self.enc_norm(x)
#         return [x], pos, None

#     def _encode_ray_map(self, ray_map, true_shape):
#         """
#         Encode ray maps with the dedicated ray-map patch embed and blocks.
#
#         `ray_map` is expected channels-first with 6 channels (the in_chans
#         of `patch_embed_ray_map`). Returns ([x], pos, None), mirroring
#         `_encode_image`.
#         """
#         # print(f"[_encode_ray_map] ray_map.shape: {ray_map.shape}")
#         x, pos = self.patch_embed_ray_map(ray_map, true_shape=true_shape)
#         assert self.enc_pos_embed is None
#         for blk in self.enc_blocks_ray_map:
#             if self.gradient_checkpointing and self.training:
#                 x = checkpoint(blk, x, pos, use_reentrant=False)
#             else:
#                 x = blk(x, pos)
#         x = self.enc_norm_ray_map(x)
#         return [x], pos, None

#     def _encode_state(self, image_tokens, image_pos):
#         """
#         Build the initial state tokens and their positions.
#
#         Args:
#             image_tokens: only its batch size (dim 0) is used.
#             image_pos: supplies the device and dtype for the positions.
#
#         Returns:
#             (state_feat, state_pos, None). state_feat is the learned register
#             embedding expanded to (batch, state_size, embed_dim). state_pos is
#             (batch, state_size, 2), or None when state_pe == "none".
#
#         Raises:
#             ValueError: if `self.state_pe` is not "1d", "2d" or "none".
#         """
#         batch_size = image_tokens.shape[0]
#         state_feat = self.register_tokens(
#             torch.arange(self.state_size, device=image_pos.device)
#         )
#         if self.state_pe == "1d":
#             # Token i sits at the diagonal position (i, i).
#             state_pos = (
#                 torch.tensor(
#                     [[i, i] for i in range(self.state_size)],
#                     dtype=image_pos.dtype,
#                     device=image_pos.device,
#                 )[None]
#                 .expand(batch_size, -1, -1)
#                 .contiguous()
#             )  # .long()
#         elif self.state_pe == "2d":
#             # Row-major grid whose width is int(sqrt(state_size)) rounded up to even.
#             width = int(self.state_size**0.5)
#             width = width + 1 if width % 2 == 1 else width
#             state_pos = (
#                 torch.tensor(
#                     [[i // width, i % width] for i in range(self.state_size)],
#                     dtype=image_pos.dtype,
#                     device=image_pos.device,
#                 )[None]
#                 .expand(batch_size, -1, -1)
#                 .contiguous()
#             )
#         elif self.state_pe == "none":
#             state_pos = None
#         else:
#             # Previously an unknown value left state_pos unbound, failing
#             # later with an UnboundLocalError.
#             raise ValueError(f"unsupported state_pe: {self.state_pe!r}")
#         state_feat = state_feat[None].expand(batch_size, -1, -1)
#         return state_feat, state_pos, None

#     def _encode_views(self, views, img_mask=None, ray_mask=None):
#         """
#         Encode all views in one batch and split the result back per view.
#
#         Images selected by the image mask go through the RGB encoder and ray
#         maps selected by the ray mask go through the ray-map encoder; the two
#         token sets are summed per sample. Samples without an image receive
#         `masked_img_token`; samples without a ray map receive
#         `masked_ray_map_token`.
#
#         Arguments:
#             views: list of views, where length of list corresponds to sequence length
#                     and every item in list has shape [B, ...]
#                    views[0].keys(): dict_keys(['img', 'depthmap', 'camera_pose', 'camera_intrinsics', 'dataset', 'label', 'is_metric', 'instance', 'is_video', 'quantile', 'img_mask', 'ray_mask', 'camera_only', 'depth_only', 'single_view', 'reset', 'idx', 'true_shape', 'sky_mask', 'ray_map', 'pts3d', 'valid_mask', 'rng'])
#             img_mask, ray_mask: optional (num_views, B) boolean masks. When
#                 both are None they are stacked from the views.
#                 NOTE(review): passing only one of them leaves the other as
#                 None, which fails at `.view(-1)` below.
#
#         Returns:
#             (shapes, feats, pos), each split per view with `.chunk(len(views))`;
#             `feats` is a list (one entry per encoder output) of such tuples.
#
#         Raises:
#             NotImplementedError: if no sample in the batch has an image.
#         """
#         device = views[0]["img"].device
#         batch_size = views[0]["img"].shape[0]
#         given = True

#         if img_mask is None and ray_mask is None:
#             given = False
#         if not given:
#             img_mask = torch.stack(
#                 [view["img_mask"] for view in views], dim=0
#             )  # Shape: (num_views, batch_size)
#             ray_mask = torch.stack(
#                 [view["ray_mask"] for view in views], dim=0
#             )  # Shape: (num_views, batch_size)

#         imgs = torch.stack(
#             [view["img"] for view in views], dim=0
#         )  # Shape: (num_views, batch_size, C, H, W)
#         ray_maps = torch.stack(
#             [view["ray_map"] for view in views], dim=0
#         )  # Shape: (num_views, batch_size, H, W, C)


#         shapes = []
#         for view in views:
#             if "true_shape" in view:
#                 shapes.append(view["true_shape"])
#             else:
#                 shape = torch.tensor(view["img"].shape[-2:], device=device)
#                 shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
#         shapes = torch.stack(shapes, dim=0).to(
#             imgs.device
#         )  # Shape: (num_views, batch_size, 2)
#         imgs = imgs.view(
#             -1, *imgs.shape[2:]
#         )  # Shape: (num_views * batch_size, C, H, W)
#         ray_maps = ray_maps.view(
#             -1, *ray_maps.shape[2:]
#         )  # Shape: (num_views * batch_size, H, W, C)
#         shapes = shapes.view(-1, 2)  # Shape: (num_views * batch_size, 2)
#         img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size)
#         ray_masks_flat = ray_mask.view(-1)
#         selected_imgs = imgs[img_masks_flat]
#         selected_shapes = shapes[img_masks_flat]
#         # print(f"len(selected_imgs): {len(selected_imgs)}")
#         if selected_imgs.size(0) > 0:
#             img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
#         else:
#             raise NotImplementedError
        
#         # print(f"len(img_out): {len(img_out)}")
#         # print(f"img_out[0].shape: {img_out[0].shape}")

#         # print(f"imgs.shape: {imgs.shape}")

#         # N * [n_views*batch_size, ]
#         # Dense (num_views * B, ...) buffers; selected rows are scattered in below.
#         full_out = [
#             torch.zeros(
#                 len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
#             )
#             for _ in range(len(img_out))
#         ]
#         # print(f"imgs.shape: {imgs.shape}")
#         full_pos = torch.zeros(
#             len(views) * batch_size,
#             *img_pos.shape[1:],
#             device=img_pos.device,
#             dtype=img_pos.dtype,
#         )
#         for i in range(len(img_out)):
#             full_out[i][img_masks_flat] += img_out[i]
#             full_out[i][~img_masks_flat] += self.masked_img_token
#         full_pos[img_masks_flat] += img_pos
#         ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
#         selected_ray_maps = ray_maps[ray_masks_flat]
#         selected_shapes_ray = shapes[ray_masks_flat]
#         if selected_ray_maps.size(0) > 0:
#             ray_out, ray_pos, _ = self._encode_ray_map(
#                 selected_ray_maps, selected_shapes_ray
#             )
#             assert len(ray_out) == len(full_out), f"{len(ray_out)}, {len(full_out)}"
#             for i in range(len(ray_out)):
#                 full_out[i][ray_masks_flat] += ray_out[i]
#                 full_out[i][~ray_masks_flat] += self.masked_ray_map_token
#             # Ray-map positions are only used for samples that have no image.
#             full_pos[ray_masks_flat] += (
#                 ray_pos * (~img_masks_flat[ray_masks_flat][:, None, None]).long()
#             )
#         else:
#             # No ray maps at all: encode one dummy zero ray map and add it (and
#             # the masked token) scaled by 0.0, so the outputs are unchanged.
#             # Presumably this keeps the ray-encoder parameters in the autograd
#             # graph (e.g. for DDP unused-parameter checks) -- confirm.
#             raymaps = torch.zeros(
#                 1, 6, imgs[0].shape[-2], imgs[0].shape[-1], device=img_out[0].device
#             )
#             ray_mask_flat = torch.zeros_like(img_masks_flat)
#             ray_mask_flat[:1] = True
#             ray_out, ray_pos, _ = self._encode_ray_map(raymaps, shapes[ray_mask_flat])
#             for i in range(len(ray_out)):
#                 full_out[i][ray_mask_flat] += ray_out[i] * 0.0
#                 full_out[i][~ray_mask_flat] += self.masked_ray_map_token * 0.0
#         return (
#             shapes.chunk(len(views), dim=0),
#             [out.chunk(len(views), dim=0) for out in full_out],
#             full_pos.chunk(len(views), dim=0),
#         )

#     def normalize_depth_prompt(self, depth_prompt: torch.Tensor):
#         """
#         Min-max normalise each sample of `depth_prompt` to [0, 1].
#
#         Args:
#             depth_prompt: 4-D tensor (B, C, H, W). Min and max are taken over
#                 all C*H*W values of each sample.
#
#         Returns:
#             (normalised, min_val, max_val); min_val and max_val have shape
#             (B, 1, 1, 1). Values equal to a sample's minimum map to 0. The
#             range is clamped to at least 1e-6, so constant inputs do not
#             divide by zero.
#         """
#         # Unpacking also enforces a 4-D input, as before.
#         B, C, H, W = depth_prompt.shape
#         flat = depth_prompt.reshape(B, -1)
#         # amin/amax equal torch.quantile(q=0.0 / 1.0) but need no sort and
#         # avoid quantile's input-size limit and float-only dtype requirement.
#         min_val = flat.amin(dim=1, keepdim=True)[:, :, None, None]
#         max_val = flat.amax(dim=1, keepdim=True)[:, :, None, None]
#
#         denom = (max_val - min_val).clamp_min(1e-6)
#         depth_prompt = (depth_prompt - min_val) / denom
#
#         return depth_prompt, min_val, max_val

#     def _decoder(self, 
#                  f_state, 
#                  pos_state, 
#                  f_img, 
#                  pos_img, 
#                  f_pose, 
#                  pos_pose, 
#                  depth_prompt,
#                  depth_true_shape):
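#         """
#         Run the paired state/image decoder stacks, injecting depth-prompt features.
#
#         For each layer i, dec_blocks_state[i] updates the state tokens by
#         attending to the image tokens, and dec_blocks[i] updates the image
#         tokens by attending to the state tokens; both read the previous
#         layer's outputs. When pose_head_flag is set, f_pose / pos_pose are
#         prepended to the image tokens. The depth prompt is normalized,
#         concatenated with its nonzero mask, and patch-embedded by
#         patch_embed_pc. The result is then added to the image tokens through
#         zero_convs[0]. For each of the first len(dec_blocks_pc) layers, it is
#         also refined by dec_blocks_pc[i] and added again through zero_convs[i+1].
#
#         Args:
#             f_state (torch.Tensor): state tokens; last dim must equal dec_embed_dim.
#             pos_state (torch.Tensor): positions of the state tokens.
#             f_img (torch.Tensor): image tokens [B, N, D] from the encoder,
#                 projected by decoder_embed before decoding.
#             pos_img (torch.Tensor): (x, y) positions of the image tokens.
#             f_pose (torch.Tensor | None): pose token [B, 1, D]; required when
#                 pose_head_flag is set.
#             pos_pose (torch.Tensor | None): position of the pose token;
#                 required when pose_head_flag is set.
#             depth_prompt (torch.Tensor): sparse depth prompt [B, H, W].
#             depth_true_shape (torch.Tensor): true image shape passed to
#                 patch_embed_pc.
#         Returns:
#             zip: two tuples (state features, image features), one entry per
#             stage: the decoder inputs (image tokens before decoder_embed)
#             followed by each decoder layer's output. The last pair is
#             normalized by dec_norm_state / dec_norm.
#         """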

#         # print(f"pos_img.shape: {pos_img.shape}")
#         # print(f"pos_img: {pos_img}")

#         depth_prompt_normed, min1, max1 = self.normalize_depth_prompt(depth_prompt.unsqueeze(1))
#         depth_prompt_mask = depth_prompt_normed != 0
#         # print(f"depth_prompt_normed.shape: {depth_prompt_normed.shape}")
#         # print(f"depth_prompt_mask.shape: {depth_prompt_mask.shape}")
#         # exit()
#         feat_pc, pos_pc = self.patch_embed_pc(torch.cat((depth_prompt_normed, depth_prompt_mask), dim=1), true_shape=depth_true_shape)
        
#         n_image_tokens = f_img.shape[1]
#         final_output = [(f_state, f_img)]  # before projection
#         assert f_state.shape[-1] == self.dec_embed_dim
#         f_img = self.decoder_embed(f_img)
#         # print(f"before f_img.shape: {f_img.shape}")
#         if self.pose_head_flag:
#             assert f_pose is not None and pos_pose is not None
#             f_img = torch.cat([f_pose, f_img], dim=1) # f_img = [f_pose, f_img]
#             pos_img = torch.cat([pos_pose, pos_img], dim=1)
#         # print(f"after f_img.shape: {f_img.shape}")

#         # conv_res = self.zero_convs[0](feat_pc.transpose(-1,-2)).transpose(-1,-2)
#         # print(f"feat_pc.shape: {feat_pc.shape}")
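#         # Depth features are meant for image tokens only. NOTE(review): when
#         # pose_head_flag is set, f_pose was prepended at index 0, so the slice
#         # [:, :n_image_tokens] covers the pose token and drops the last image
#         # token (the same slice is reused inside the layer loop below);
#         # [:, -n_image_tokens:] would select exactly the image tokens.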
#         f_img[:,:n_image_tokens,:] += self.zero_convs[0](feat_pc.transpose(-1,-2)).transpose(-1,-2)

#         final_output.append((f_state, f_img))

#         for i in range(len(self.dec_blocks)):
#             blk_state = self.dec_blocks_state[i]
#             blk_img = self.dec_blocks[i]

#             if (
#                 self.gradient_checkpointing
#                 and self.training
#                 and torch.is_grad_enabled()
#             ):
#                 f_state, _ = checkpoint(
#                     blk_state,
#                     *final_output[-1][::+1],
#                     pos_state,
#                     pos_img,
#                     use_reentrant=not self.fixed_input_length,
#                 )
#                 f_img, _ = checkpoint(
#                     blk_img,
#                     *final_output[-1][::-1],
#                     pos_img,
#                     pos_state,
#                     use_reentrant=not self.fixed_input_length,
#                 )
#             else:
#                 # decoder(state,img)
#                 f_state, _ = blk_state(*final_output[-1][::+1], pos_state, pos_img)
#                 # decoder(img,state)
#                 f_img, _ = blk_img(*final_output[-1][::-1], pos_img, pos_state)
            
#             if i < len(self.dec_blocks_pc):

#                 feat_pc = self.dec_blocks_pc[i](feat_pc, pos_pc)
#                 f_img[:,:n_image_tokens,:] += self.zero_convs[i+1](feat_pc.transpose(-1,-2)).transpose(-1,-2)


#             final_output.append((f_state, f_img))

#         del final_output[1]  # duplicate with final_output[0]
#         final_output[-1] = (
#             self.dec_norm_state(final_output[-1][0]),
#             self.dec_norm(final_output[-1][1]),
#         )
#         return zip(*final_output)

#     # def _downstream_head(self, decout, img_shape, prompt_depth=None, **kwargs):
#     #     B, S, D = decout[-1].shape
#     #     head = getattr(self, f"head")
#     #     if self.depth_guidance:
#     #         return head(decout, img_shape, prompt_depth, **kwargs)
#     #     else:
#     #         return head(decout, img_shape, **kwargs)
        
#     def _downstream_head(self, decout, img_shape, **kwargs):
#         B, S, D = decout[-1].shape
#         head = getattr(self, f"head")
#         return head(decout, img_shape, **kwargs)

#     def _init_state(self, image_tokens, image_pos):
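#         """
#         Initialize the recurrent state from image tokens and their positions
#         (callers pass the first frame's). The tokens are encoded with
#         _encode_state, the features are projected with decoder_embed_state,
#         and (state_feat, state_pos) is returned.
#         """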
#         state_feat, state_pos, _ = self._encode_state(image_tokens, image_pos)
#         state_feat = self.decoder_embed_state(state_feat)
#         return state_feat, state_pos

#     def _recurrent_rollout(
#         self,
#         state_feat,
#         state_pos,
#         current_feat,
#         current_pos,
#         pose_feat,
#         pose_pos,
#         init_state_feat,
#         img_mask=None,
#         reset_mask=None,
#         update=None,
#         depth_prompt=None,
#         depth_true_shape=None

#     ):
#         """
#         Performs decoder forward pass
#         """
#         new_state_feat, dec = self._decoder(state_feat, 
#                                             state_pos, 
#                                             current_feat, 
#                                             current_pos, 
#                                             pose_feat, 
#                                             pose_pos,
#                                             depth_prompt,
#                                             depth_true_shape
#         )
#         new_state_feat = new_state_feat[-1]
#         return new_state_feat, dec

#     def _get_img_level_feat(self, feat):
#         return torch.mean(feat, dim=1, keepdim=True)

#     def _forward_encoder(self, views):
#         shape, feat_ls, pos = self._encode_views(views)
#         feat = feat_ls[-1]
#         state_feat, state_pos = self._init_state(feat[0], pos[0])
#         mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
#         init_state_feat = state_feat.clone()
#         init_mem = mem.clone()
#         return (feat, pos, shape), (
#             init_state_feat,
#             init_mem,
#             state_feat,
#             state_pos,
#             mem,
#         )

#     def _forward_decoder_step(
#         self,
#         views,
#         i,
#         feat_i,
#         pos_i,
#         shape_i,
#         init_state_feat, # initial state feature, 
#         init_mem, # initial memory feature
#         state_feat, # state feature
#         state_pos, # state feature token positions
#         mem,
#     ):
#         """
#         Calls:
#         1. Decoders forward prop
#         2. Output head(dpt/linear) forward prop
        
#         Arguments:
#             views: list of views, where length of list corresponds to sequence length
#                     and every item in list has shape [B, ...]
#                    views[0].keys(): dict_keys(['img', 'depthmap', 'camera_pose', 'camera_intrinsics', 'dataset', 'label', 'is_metric', 'instance', 'is_video', 'quantile', 'img_mask', 'ray_mask', 'camera_only', 'depth_only', 'single_view', 'reset', 'idx', 'true_shape', 'sky_mask', 'ray_map', 'pts3d', 'valid_mask', 'rng'])
#         """

#         # print(f"len(views): {len(views):}")
#         # print(f"views[0].keys(): {views[0].keys()}")
#         # print(f"views[0]['depthmap'].shape: {views[0]['depthmap'].shape}")
#         # for i in range(len(views)):        
#         #     print(f"views[i]['depthmap'].shape: {views[i]['depthmap'].shape}")
#         # print(f"feat_i.shape: {feat_i.shape}")

#         # depth_prompt1 = self.normalize_depth_prompt(view1['depthmap'].unsqueeze(1))
#         # depth_prompt2 = self.normalize_depth_prompt(view2['depthmap'].unsqueeze(1))

#         if self.pose_head_flag:
#             global_img_feat_i = self._get_img_level_feat(feat_i)
#             # print(f"global_img_feat_i.shape: {global_img_feat_i.shape}")
#             if i == 0:
#                 pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#             else:
#                 pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#             pose_pos_i = -torch.ones(
#                 feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#             )
#         else:
#             pose_feat_i = None
#             pose_pos_i = None

#         # print(f"state_feat.shape: {state_feat.shape}")
#         # print(f"state_pos.shape: {state_pos.shape}")
#         # print(f"feat_i.shape: {feat_i.shape}")
#         # print(f"pos_i.shape: {pos_i.shape}")

#         # returns [f_state, f_img], if pose_mode: [f_state, [f_pose, f_img]]
#         new_state_feat, dec = self._recurrent_rollout(
#             state_feat, # f_state
#             state_pos, # pos_state 
#             feat_i, # f_img
#             pos_i, #pos_img
#             pose_feat_i, #f_pose
#             pose_pos_i, #pos_pose
#             init_state_feat,
#             img_mask=views[i]["img_mask"],
#             reset_mask=views[i]["reset"],
#             update=views[i].get("update", None),
#             depth_prompt=views[i]["depthmap"],
#             depth_true_shape=views[i]["true_shape"]
#         )
#         out_pose_feat_i = dec[-1][:, 0:1]
#         new_mem = self.pose_retriever.update_mem(
#             mem, global_img_feat_i, out_pose_feat_i
#         )
#         head_input = [
#             dec[0].float(),  # image tokens before decoder_embed, [B, N_tokens, D]
#             dec[self.dec_depth * 2 // 4][:, 1:].float(),  # middle layer (idx 6 when dec_depth=12); drops token 0 (pose token)
#             dec[self.dec_depth * 3 // 4][:, 1:].float(),  # 3/4-depth layer (idx 9 when dec_depth=12); drops token 0 (pose token)
#             dec[self.dec_depth].float(),  # final normalized layer, [f_pose, f_img] (pose token kept)
#         ]

#         res = self._downstream_head(head_input, shape_i, pos=pos_i)
#         img_mask = views[i]["img_mask"]
#         update = views[i].get("update", None)
#         if update is not None:
#             update_mask = img_mask & update  # if don't update, then whatever img_mask
#         else:
#             update_mask = img_mask
#         update_mask = update_mask[:, None, None].float()
#         state_feat = new_state_feat * update_mask + state_feat * (
#             1 - update_mask
#         )  # update global state
#         mem = new_mem * update_mask + mem * (1 - update_mask)  # then update local state
#         reset_mask = views[i]["reset"]
#         if reset_mask is not None:
#             reset_mask = reset_mask[:, None, None].float()
#             state_feat = init_state_feat * reset_mask + state_feat * (1 - reset_mask)
#             mem = init_mem * reset_mask + mem * (1 - reset_mask)
#         return res, (state_feat, mem)

#     def _forward_impl(self, views, ret_state=False):
#         # 1. Encoder forward prop
#         shape, feat_ls, pos = self._encode_views(views)
#         feat = feat_ls[-1]
#         # 2. State initialization
#         state_feat, state_pos = self._init_state(feat[0], pos[0])
#         mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
#         init_state_feat = state_feat.clone()
#         init_mem = mem.clone()
#         all_state_args = [(state_feat, state_pos, init_state_feat, mem, init_mem)]
#         ress = []
#         # 3. Reconstruction loop
#         for i in range(len(views)):
#             feat_i = feat[i]
#             pos_i = pos[i]
#             # whether pose should be predicted
#             if self.pose_head_flag:
#                 global_img_feat_i = self._get_img_level_feat(feat_i)
#                 if i == 0:
#                     pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#                 else:
#                     pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#                 pose_pos_i = -torch.ones(
#                     feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#                 )
#             else:
#                 pose_feat_i = None
#                 pose_pos_i = None
            
#             # iteratively update feat and return decoder feats
#             new_state_feat, dec = self._recurrent_rollout(
#                 state_feat,
#                 state_pos,
#                 feat_i,
#                 pos_i,
#                 pose_feat_i,
#                 pose_pos_i,
#                 init_state_feat,
#                 img_mask=views[i]["img_mask"],
#                 reset_mask=views[i]["reset"],
#                 update=views[i].get("update", None),
#                 depth_prompt=views[i]["depthmap"],
#                 depth_true_shape=views[i]["true_shape"]
#             )
#             out_pose_feat_i = dec[-1][:, 0:1]
#             new_mem = self.pose_retriever.update_mem(
#                 mem, global_img_feat_i, out_pose_feat_i
#             )
#             assert len(dec) == self.dec_depth + 1

#             # Pass decoder feats through output layer
#             head_input = [
#                 dec[0].float(),
#                 dec[self.dec_depth * 2 // 4][:, 1:].float(),
#                 dec[self.dec_depth * 3 // 4][:, 1:].float(),
#                 dec[self.dec_depth].float(),
#             ]

#             # print(f"views[i].keys(): {views[i].keys()}")
#             # exit()

#             # prompt_depth = None

#             prompt_depth_i = self.normalize_depth_prompt(views[i]['depthmap'].unsqueeze(1))
#             # print(f"prompt_depth_i.shape: {prompt_depth_i.shape}")
#             # exit()

#             res = self._downstream_head(head_input, shape[i], prompt_depth=prompt_depth_i, pos=pos_i)
#             # print(f"res.keys(): {res.keys()}")
#             # print(f"res['pts3d_in_self_view'].shape: {res['pts3d_in_self_view'].shape}")
#             # print(f"res['conf_self'].shape: {res['conf_self'].shape}")
#             # print(f"views[i]['pts3d'].shape: {views[i]['pts3d'].shape}")

#             ress.append(res)
#             img_mask = views[i]["img_mask"]
#             update = views[i].get("update", None)
#             if update is not None:
#                 update_mask = (
#                     img_mask & update
#                 )  # if don't update, then whatever img_mask
#             else:
#                 update_mask = img_mask
#             update_mask = update_mask[:, None, None].float()
#             state_feat = new_state_feat * update_mask + state_feat * (
#                 1 - update_mask
#             )  # update global state
#             mem = new_mem * update_mask + mem * (
#                 1 - update_mask
#             )  # then update local state
#             reset_mask = views[i]["reset"]
#             if reset_mask is not None:
#                 reset_mask = reset_mask[:, None, None].float()
#                 state_feat = init_state_feat * reset_mask + state_feat * (
#                     1 - reset_mask
#                 )
#                 mem = init_mem * reset_mask + mem * (1 - reset_mask)
#             all_state_args.append(
#                 (state_feat, state_pos, init_state_feat, mem, init_mem)
#             )
#         if ret_state:
#             return ress, views, all_state_args
#         return ress, views

#     def forward(self, views, ret_state=False):
#         if ret_state:
#             ress, views, state_args = self._forward_impl(views, ret_state=ret_state)
#             return ARCroco3DStereoOutput(ress=ress, views=views), state_args
#         else:
#             ress, views = self._forward_impl(views, ret_state=ret_state)
#             return ARCroco3DStereoOutput(ress=ress, views=views)

#     def inference_step(
#         self, view, state_feat, state_pos, init_state_feat, mem, init_mem
#     ):
#         batch_size = view["img"].shape[0]
#         raymaps = []
#         shapes = []
#         for j in range(batch_size):
#             assert view["ray_mask"][j]
#             raymap = view["ray_map"][[j]].permute(0, 3, 1, 2)
#             raymaps.append(raymap)
#             shapes.append(
#                 view.get(
#                     "true_shape",
#                     torch.tensor(view["ray_map"].shape[-2:])[None].repeat(
#                         view["ray_map"].shape[0], 1
#                     ),
#                 )[[j]]
#             )

#         raymaps = torch.cat(raymaps, dim=0)
#         shape = torch.cat(shapes, dim=0).to(raymaps.device)
#         feat_ls, pos, _ = self._encode_ray_map(raymaps, shapes)

#         feat_i = feat_ls[-1]
#         pos_i = pos
#         if self.pose_head_flag:
#             global_img_feat_i = self._get_img_level_feat(feat_i)
#             pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#             pose_pos_i = -torch.ones(
#                 feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#             )
#         else:
#             pose_feat_i = None
#             pose_pos_i = None
#         new_state_feat, dec = self._recurrent_rollout(
#             state_feat,
#             state_pos,
#             feat_i,
#             pos_i,
#             pose_feat_i,
#             pose_pos_i,
#             init_state_feat,
#             img_mask=view["img_mask"],
#             reset_mask=view["reset"],
#             update=view.get("update", None),
#         )

#         out_pose_feat_i = dec[-1][:, 0:1]
#         new_mem = self.pose_retriever.update_mem(
#             mem, global_img_feat_i, out_pose_feat_i
#         )
#         assert len(dec) == self.dec_depth + 1
#         head_input = [
#             dec[0].float(),
#             dec[self.dec_depth * 2 // 4][:, 1:].float(),
#             dec[self.dec_depth * 3 // 4][:, 1:].float(),
#             dec[self.dec_depth].float(),
#         ]
#         res = self._downstream_head(head_input, shape, pos=pos_i)
#         return res, view

#     def forward_recurrent(self, views, device, ret_state=False):
#         ress = []
#         all_state_args = []
#         for i, view in enumerate(views):
#             device = view["img"].device
#             batch_size = view["img"].shape[0]
#             img_mask = view["img_mask"].reshape(
#                 -1, batch_size
#             )  # Shape: (1, batch_size)
#             ray_mask = view["ray_mask"].reshape(
#                 -1, batch_size
#             )  # Shape: (1, batch_size)
#             imgs = view["img"].unsqueeze(0)  # Shape: (1, batch_size, C, H, W)
#             ray_maps = view["ray_map"].unsqueeze(
#                 0
#             )  # Shape: (num_views, batch_size, H, W, C)
#             shapes = (
#                 view["true_shape"].unsqueeze(0)
#                 if "true_shape" in view
#                 else torch.tensor(view["img"].shape[-2:], device=device)
#                 .unsqueeze(0)
#                 .repeat(batch_size, 1)
#                 .unsqueeze(0)
#             )  # Shape: (num_views, batch_size, 2)
#             imgs = imgs.view(
#                 -1, *imgs.shape[2:]
#             )  # Shape: (num_views * batch_size, C, H, W)
#             ray_maps = ray_maps.view(
#                 -1, *ray_maps.shape[2:]
#             )  # Shape: (num_views * batch_size, H, W, C)
#             shapes = shapes.view(-1, 2).to(
#                 imgs.device
#             )  # Shape: (num_views * batch_size, 2)
#             img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size)
#             ray_masks_flat = ray_mask.view(-1)
#             selected_imgs = imgs[img_masks_flat]
#             selected_shapes = shapes[img_masks_flat]
#             if selected_imgs.size(0) > 0:
#                 img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
#             else:
#                 img_out, img_pos = None, None
#             ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
#             selected_ray_maps = ray_maps[ray_masks_flat]
#             selected_shapes_ray = shapes[ray_masks_flat]
#             if selected_ray_maps.size(0) > 0:
#                 ray_out, ray_pos, _ = self._encode_ray_map(
#                     selected_ray_maps, selected_shapes_ray
#                 )
#             else:
#                 ray_out, ray_pos = None, None

#             shape = shapes
#             if img_out is not None and ray_out is None:
#                 feat_i = img_out[-1]
#                 pos_i = img_pos
#             elif img_out is None and ray_out is not None:
#                 feat_i = ray_out[-1]
#                 pos_i = ray_pos
#             elif img_out is not None and ray_out is not None:
#                 feat_i = img_out[-1] + ray_out[-1]
#                 pos_i = img_pos
#             else:
#                 raise NotImplementedError

#             if i == 0:
#                 state_feat, state_pos = self._init_state(feat_i, pos_i)
#                 mem = self.pose_retriever.mem.expand(feat_i.shape[0], -1, -1)
#                 init_state_feat = state_feat.clone()
#                 init_mem = mem.clone()
#                 all_state_args.append(
#                     (state_feat, state_pos, init_state_feat, mem, init_mem)
#                 )

#             if self.pose_head_flag:
#                 global_img_feat_i = self._get_img_level_feat(feat_i)
#                 if i == 0:
#                     pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#                 else:
#                     pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#                 pose_pos_i = -torch.ones(
#                     feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#                 )
#             else:
#                 pose_feat_i = None
#                 pose_pos_i = None
#             new_state_feat, dec = self._recurrent_rollout(
#                 state_feat,
#                 state_pos,
#                 feat_i,
#                 pos_i,
#                 pose_feat_i,
#                 pose_pos_i,
#                 init_state_feat,
#                 img_mask=view["img_mask"],
#                 reset_mask=view["reset"],
#                 update=view.get("update", None),
#             )
#             out_pose_feat_i = dec[-1][:, 0:1]
#             new_mem = self.pose_retriever.update_mem(
#                 mem, global_img_feat_i, out_pose_feat_i
#             )
#             assert len(dec) == self.dec_depth + 1
#             head_input = [
#                 dec[0].float(),
#                 dec[self.dec_depth * 2 // 4][:, 1:].float(),
#                 dec[self.dec_depth * 3 // 4][:, 1:].float(),
#                 dec[self.dec_depth].float(),
#             ]
#             res = self._downstream_head(head_input, shape, pos=pos_i)
#             ress.append(res)
#             img_mask = view["img_mask"]
#             update = view.get("update", None)
#             if update is not None:
#                 update_mask = (
#                     img_mask & update
#                 )  # if don't update, then whatever img_mask
#             else:
#                 update_mask = img_mask
#             update_mask = update_mask[:, None, None].float()
#             state_feat = new_state_feat * update_mask + state_feat * (
#                 1 - update_mask
#             )  # update global state
#             mem = new_mem * update_mask + mem * (
#                 1 - update_mask
#             )  # then update local state
#             reset_mask = view["reset"]
#             if reset_mask is not None:
#                 reset_mask = reset_mask[:, None, None].float()
#                 state_feat = init_state_feat * reset_mask + state_feat * (
#                     1 - reset_mask
#                 )
#                 mem = init_mem * reset_mask + mem * (1 - reset_mask)
#             all_state_args.append(
#                 (state_feat, state_pos, init_state_feat, mem, init_mem)
#             )
#         if ret_state:
#             return ress, views, all_state_args
#         return ress, views


# class ARCroco3DStereoGuidedConcat(CroCoNet):
#     config_class = ARCroco3DStereoConfig
#     base_model_prefix = "arcroco3dstereo"
#     supports_gradient_checkpointing = True

#     def __init__(self, config: ARCroco3DStereoConfig):
#         self.gradient_checkpointing = False
#         self.fixed_input_length = True
#         config.croco_kwargs = fill_default_args(
#             config.croco_kwargs, CrocoConfig.__init__
#         )
#         self.config = config
#         self.patch_embed_cls = config.patch_embed_cls
#         self.croco_args = config.croco_kwargs
#         croco_cfg = CrocoConfig(**self.croco_args)
#         super().__init__(croco_cfg)
#         self.enc_blocks_ray_map = nn.ModuleList(
#             [
#                 Block(
#                     self.enc_embed_dim,
#                     16,
#                     4,
#                     qkv_bias=True,
#                     norm_layer=partial(nn.LayerNorm, eps=1e-6),
#                     rope=self.rope,
#                 )
#                 for _ in range(config.ray_enc_depth)
#             ]
#         )
#         self.enc_norm_ray_map = nn.LayerNorm(self.enc_embed_dim, eps=1e-6)
#         self.dec_num_heads = self.croco_args["dec_num_heads"]
#         self.pose_head_flag = config.pose_head
#         self.depth_guidance = config.depth_guidance

#         if self.pose_head_flag:
#             self.pose_token = nn.Parameter(
#                 torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
#             )
#             self.pose_retriever = LocalMemory(
#                 size=config.local_mem_size,
#                 k_dim=self.enc_embed_dim,
#                 v_dim=self.dec_embed_dim,
#                 num_heads=self.dec_num_heads,
#                 mlp_ratio=4,
#                 qkv_bias=True,
#                 attn_drop=0.0,
#                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
#                 rope=None,
#             )
#         self.register_tokens = nn.Embedding(config.state_size, self.enc_embed_dim)
#         self.state_size = config.state_size
#         self.state_pe = config.state_pe
#         self.masked_img_token = nn.Parameter(
#             torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
#         )
#         self.masked_ray_map_token = nn.Parameter(
#             torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
#         )
#         self._set_state_decoder(
#             self.enc_embed_dim,
#             self.dec_embed_dim,
#             config.state_dec_num_heads,
#             self.dec_depth,
#             self.croco_args.get("mlp_ratio", None),
#             self.croco_args.get("norm_layer", None),
#             self.croco_args.get("norm_im2_in_dec", None),
#         )

#         # decoder block for depth prompt
#         # self.dec_blocks_pc = nn.ModuleList([
#         #     Block(self.dec_embed_dim,
#         #           self.dec_num_heads,
#         #           mlp_ratio=self.croco_args.get("mlp_ratio", None),
#         #           qkv_bias=True,
#         #           norm_layer=self.croco_args.get("norm_layer", None),
#         #           rope=self.rope)
#         #     for i in range(self.croco_args.get("dec_depth", None)//2-2)
#         # ])

#         # self.zero_convs = []
#         # for i in range(len(self.dec_blocks_pc) + 1):
#         #     self.zero_convs.append(make_zero_conv(self.dec_embed_dim))
#         # self.zero_convs = nn.ModuleList(self.zero_convs)

#         self.set_downstream_head(
#             config.output_mode,
#             config.head_type,
#             config.landscape_only,
#             config.depth_mode,
#             config.conf_mode,
#             config.pose_mode,
#             config.depth_head,
#             config.rgb_head,
#             config.pose_conf_head,
#             config.pose_head,
#             **self.croco_args,
#         )
#         self.set_freeze(config.freeze)

#         print(f"config.landscape_only: {config.landscape_only}")    

#     @classmethod
#     def from_pretrained(cls, pretrained_model_name_or_path, **kw):
#         if os.path.isfile(pretrained_model_name_or_path):
#             return load_model(pretrained_model_name_or_path, device="cpu")
#         else:
#             try:
#                 model = super(ARCroco3DStereo, cls).from_pretrained(
#                     pretrained_model_name_or_path, **kw
#                 )
#             except TypeError as e:
#                 raise Exception(
#                     f"tried to load {pretrained_model_name_or_path} from huggingface, but failed"
#                 )
#             return model

#     def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
#         self.patch_embed = get_patch_embed(
#             self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3+2
#         )
#         self.patch_embed_ray_map = get_patch_embed(
#             self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=6+2
#         )

#         # self.patch_embed_pc = get_patch_embed(
#         #     self.patch_embed_cls, img_size, patch_size, self.dec_embed_dim, in_chans=2
#         # )

#     def _set_decoder(
#         self,
#         enc_embed_dim,
#         dec_embed_dim,
#         dec_num_heads,
#         dec_depth,
#         mlp_ratio,
#         norm_layer,
#         norm_im2_in_dec,
#     ):
#         self.dec_depth = dec_depth
#         self.dec_embed_dim = dec_embed_dim
#         self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
#         self.dec_blocks = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     dec_embed_dim,
#                     dec_num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=True,
#                     norm_layer=norm_layer,
#                     norm_mem=norm_im2_in_dec,
#                     rope=self.rope,
#                 )
#                 for i in range(dec_depth)
#             ]
#         )
#         self.dec_norm = norm_layer(dec_embed_dim)

#     def _set_state_decoder(
#         self,
#         enc_embed_dim,
#         dec_embed_dim,
#         dec_num_heads,
#         dec_depth,
#         mlp_ratio,
#         norm_layer,
#         norm_im2_in_dec,
#     ):
#         self.dec_depth_state = dec_depth
#         self.dec_embed_dim_state = dec_embed_dim
#         self.decoder_embed_state = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
#         self.dec_blocks_state = nn.ModuleList(
#             [
#                 DecoderBlock(
#                     dec_embed_dim,
#                     dec_num_heads,
#                     mlp_ratio=mlp_ratio,
#                     qkv_bias=True,
#                     norm_layer=norm_layer,
#                     norm_mem=norm_im2_in_dec,
#                     rope=self.rope,
#                 )
#                 for i in range(dec_depth)
#             ]
#         )
#         self.dec_norm_state = norm_layer(dec_embed_dim)

#     def load_state_dict(self, ckpt, **kw):
#         if all(k.startswith("module") for k in ckpt):
#             ckpt = strip_module(ckpt)
#         new_ckpt = dict(ckpt)
#         if not any(k.startswith("dec_blocks_state") for k in ckpt):
#             for key, value in ckpt.items():
#                 if key.startswith("dec_blocks"):
#                     new_ckpt[key.replace("dec_blocks", "dec_blocks_state")] = value
#         try:
#             return super().load_state_dict(new_ckpt, **kw)
#         except:
#             try:
#                 new_new_ckpt = {
#                     k: v
#                     for k, v in new_ckpt.items()
#                     if not k.startswith("dec_blocks")
#                     and not k.startswith("dec_norm")
#                     and not k.startswith("decoder_embed")
#                 }
#                 return super().load_state_dict(new_new_ckpt, **kw)
#             except:
#                 new_new_ckpt = {}
#                 for key in new_ckpt:
#                     if key in self.state_dict():
#                         if new_ckpt[key].size() == self.state_dict()[key].size():
#                             new_new_ckpt[key] = new_ckpt[key]
#                         else:
#                             printer.info(
#                                 f"Skipping '{key}': size mismatch (ckpt: {new_ckpt[key].size()}, model: {self.state_dict()[key].size()})"
#                             )
#                     else:
#                         printer.info(f"Skipping '{key}': not found in model")
#                 return super().load_state_dict(new_new_ckpt, **kw)

#     def set_freeze(self, freeze):  # this is for use by downstream models
#         self.freeze = freeze
#         to_be_frozen = {
#             "none": [],
#             "mask": [self.mask_token] if hasattr(self, "mask_token") else [],
#             "encoder": [
#                 self.patch_embed,
#                 self.patch_embed_ray_map,
#                 self.masked_img_token,
#                 self.masked_ray_map_token,
#                 self.enc_blocks,
#                 self.enc_blocks_ray_map,
#                 self.enc_norm,
#                 self.enc_norm_ray_map,
#             ],
#             "encoder_and_head": [
#                 self.patch_embed,
#                 self.patch_embed_ray_map,
#                 self.masked_img_token,
#                 self.masked_ray_map_token,
#                 self.enc_blocks,
#                 self.enc_blocks_ray_map,
#                 self.enc_norm,
#                 self.enc_norm_ray_map,
#                 self.downstream_head,
#             ],
#             "encoder_and_decoder": [
#                 self.patch_embed,
#                 self.patch_embed_ray_map,
#                 self.masked_img_token,
#                 self.masked_ray_map_token,
#                 self.enc_blocks,
#                 self.enc_blocks_ray_map,
#                 self.enc_norm,
#                 self.enc_norm_ray_map,
#                 self.dec_blocks,
#                 self.dec_blocks_state,
#                 self.pose_retriever,
#                 self.pose_token,
#                 self.register_tokens,
#                 self.decoder_embed_state,
#                 self.decoder_embed,
#                 self.dec_norm,
#                 self.dec_norm_state,
#             ],
#             "decoder": [
#                 self.dec_blocks,
#                 self.dec_blocks_state,
#                 self.pose_retriever,
#                 self.pose_token,
#             ],
#         }
#         freeze_all_params(to_be_frozen[freeze])

#     def _set_prediction_head(self, *args, **kwargs):
#         """No prediction head"""
#         return

#     def set_downstream_head(
#         self,
#         output_mode,
#         head_type,
#         landscape_only,
#         depth_mode,
#         conf_mode,
#         pose_mode,
#         depth_head,
#         rgb_head,
#         pose_conf_head,
#         pose_head,
#         patch_size,
#         img_size,
#         **kw,
#     ):
#         assert (
#             img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
#         ), f"{img_size=} must be multiple of {patch_size=}"
#         self.output_mode = output_mode
#         self.head_type = head_type
#         self.depth_mode = depth_mode
#         self.conf_mode = conf_mode
#         self.pose_mode = pose_mode
#         self.downstream_head = head_factory(
#             head_type,
#             output_mode,
#             self,
#             has_conf=bool(conf_mode),
#             has_depth=bool(depth_head),
#             has_rgb=bool(rgb_head),
#             has_pose_conf=bool(pose_conf_head),
#             has_pose=bool(pose_head),
#             depth_guidance=self.depth_guidance
#         )
#         self.head = transpose_to_landscape(
#             self.downstream_head, activate=landscape_only
#         )

#     def _encode_image(self, image, depthmap, depthmap_mask, true_shape):
#         """
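#         Forward prop of the ViT encoder on images concatenated channel-wise
#         with their depth maps and depth masks.
#         Arguments:
#             image: images of shape [B*num_views, C, H, W]
#             depthmap: depth maps of shape [B*num_views, 1, H, W]
#             depthmap_mask: nonzero-depth masks of shape [B*num_views, 1, H, W]
#             true_shape: per-sample (H, W), forwarded to patch_embed
#         Returns:
#             ([x], pos, None): encoder tokens after enc_norm wrapped in a list,
#             their patch positions, and a placeholder None.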
#         """
#         # print(f"[_encode_image] image.shape: {image.shape}")
#         # print(f"[_encode_image] depthmap.shape: {depthmap.shape}")
#         # print(f"[_encode_image] depthmap_mask.shape: {depthmap_mask.shape}")

        

#         x, pos = self.patch_embed(torch.cat((image, depthmap, depthmap_mask), dim=1), true_shape=true_shape)
#         # print(f"[patch_embed] x.shape: {x.shape}")
#         assert self.enc_pos_embed is None
#         for blk in self.enc_blocks:
#             if self.gradient_checkpointing and self.training:
#                 x = checkpoint(blk, x, pos, use_reentrant=False)
#             else:
#                 x = blk(x, pos)
#         x = self.enc_norm(x)
#         return [x], pos, None

#     def _encode_ray_map(self, ray_map, depthmap, depthmap_mask, true_shape):
#         # print(f"ray_map.shape: {ray_map.shape}")
#         x, pos = self.patch_embed_ray_map(torch.cat((ray_map, depthmap, depthmap_mask), dim=1), true_shape=true_shape)
#         assert self.enc_pos_embed is None
#         for blk in self.enc_blocks_ray_map:
#             if self.gradient_checkpointing and self.training:
#                 x = checkpoint(blk, x, pos, use_reentrant=False)
#             else:
#                 x = blk(x, pos)
#         x = self.enc_norm_ray_map(x)
#         return [x], pos, None

#     def _encode_state(self, image_tokens, image_pos):
#         batch_size = image_tokens.shape[0]
#         state_feat = self.register_tokens(
#             torch.arange(self.state_size, device=image_pos.device)
#         )
#         if self.state_pe == "1d":
#             state_pos = (
#                 torch.tensor(
#                     [[i, i] for i in range(self.state_size)],
#                     dtype=image_pos.dtype,
#                     device=image_pos.device,
#                 )[None]
#                 .expand(batch_size, -1, -1)
#                 .contiguous()
#             )  # .long()
#         elif self.state_pe == "2d":
#             width = int(self.state_size**0.5)
#             width = width + 1 if width % 2 == 1 else width
#             state_pos = (
#                 torch.tensor(
#                     [[i // width, i % width] for i in range(self.state_size)],
#                     dtype=image_pos.dtype,
#                     device=image_pos.device,
#                 )[None]
#                 .expand(batch_size, -1, -1)
#                 .contiguous()
#             )
#         elif self.state_pe == "none":
#             state_pos = None
#         state_feat = state_feat[None].expand(batch_size, -1, -1)
#         return state_feat, state_pos, None

#     def _encode_views(self, views, img_mask=None, ray_mask=None):
#         """
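#         Encoder forward propagation to get a set of tokens for every view.
#         Images (concatenated with depth map + nonzero-depth mask) and ray maps
#         go through separate encoders; samples whose image / ray map is masked
#         out receive masked_img_token / masked_ray_map_token instead.
#         Raises NotImplementedError if no sample has an image.
#         Returns (shapes, [tokens], positions), each split per view.
#         NOTE(review): the fallback branch taken when no sample has a ray map
#         calls _encode_ray_map with two arguments, but it takes four.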
        
#         Arguments:
#             views: list of views, where length of list corresponds to sequence length
#                     and every item in list has shape [B, ...]
#                    views[0].keys(): dict_keys(['img', 'depthmap', 'camera_pose', 'camera_intrinsics', 'dataset', 'label', 'is_metric', 'instance', 'is_video', 'quantile', 'img_mask', 'ray_mask', 'camera_only', 'depth_only', 'single_view', 'reset', 'idx', 'true_shape', 'sky_mask', 'ray_map', 'pts3d', 'valid_mask', 'rng'])
#         """
#         device = views[0]["img"].device
#         batch_size = views[0]["img"].shape[0]
#         given = True

#         # img_mask / ray_mask define, per (view, sample), whether the image
#         # and/or the corresponding ray map is used as input
#         if img_mask is None and ray_mask is None:
#             given = False
#         if not given:
#             img_mask = torch.stack(
#                 [view["img_mask"] for view in views], dim=0
#             )  # Shape: (num_views, batch_size)
#             ray_mask = torch.stack(
#                 [view["ray_mask"] for view in views], dim=0
#             )  # Shape: (num_views, batch_size)

#         imgs = torch.stack(
#             [view["img"] for view in views], dim=0
#         )  # Shape: (num_views, batch_size, C, H, W)
#         ray_maps = torch.stack(
#             [view["ray_map"] for view in views], dim=0
#         )  # Shape: (num_views, batch_size, H, W, C)

#         depth_maps = torch.stack(
#             [view["depthmap"] for view in views], dim=0
#         )

#         depth_map_masks = depth_maps != 0

#         # print(f"depth_maps.shape: {depth_maps.shape}")

#         shapes = []
#         for view in views:
#             if "true_shape" in view:
#                 shapes.append(view["true_shape"])
#             else:
#                 shape = torch.tensor(view["img"].shape[-2:], device=device)
#                 shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
#         shapes = torch.stack(shapes, dim=0).to(
#             imgs.device
#         )  # Shape: (num_views, batch_size, 2)
#         imgs = imgs.view(
#             -1, *imgs.shape[2:]
#         )  # Shape: (num_views * batch_size, C, H, W)
#         ray_maps = ray_maps.view(
#             -1, *ray_maps.shape[2:]
#         )  # Shape: (num_views * batch_size, H, W, C)

#         depth_maps = depth_maps.view(-1, *depth_maps.shape[2:]).unsqueeze(1)
#         depth_map_masks = depth_map_masks.view(-1, *depth_map_masks.shape[2:]).unsqueeze(1)


#         shapes = shapes.view(-1, 2)  # Shape: (num_views * batch_size, 2)
#         img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size)
#         ray_masks_flat = ray_mask.view(-1)
#         selected_imgs = imgs[img_masks_flat]
#         selected_shapes = shapes[img_masks_flat]

#         selected_img_depth_maps = depth_maps[img_masks_flat]
#         selected_img_depth_map_masks = depth_map_masks[img_masks_flat]

#         # ENCODER FORWARD PASS ON IMAGES
#         if selected_imgs.size(0) > 0:
#             img_out, img_pos, _ = self._encode_image(selected_imgs,
#                                                      selected_img_depth_maps,
#                                                      selected_img_depth_map_masks, 
#                                                      selected_shapes)
#         else:
#             raise NotImplementedError
        


#         # full_out: one zero-initialized tensor of shape
#         # [n_views * batch_size, *token_shape] per encoder output level
#         full_out = [
#             torch.zeros(
#                 len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
#             )
#             for _ in range(len(img_out))
#         ]
#         # print(f"imgs.shape: {imgs.shape}")
#         full_pos = torch.zeros(
#             len(views) * batch_size,
#             *img_pos.shape[1:],
#             device=img_pos.device,
#             dtype=img_pos.dtype,
#         )
#         for i in range(len(img_out)):
#             full_out[i][img_masks_flat] += img_out[i]
#             full_out[i][~img_masks_flat] += self.masked_img_token
#         full_pos[img_masks_flat] += img_pos
#         ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
#         selected_ray_maps = ray_maps[ray_masks_flat]
#         selected_shapes_ray = shapes[ray_masks_flat]

#         selected_ray_depth_maps = depth_maps[ray_masks_flat]
#         selected_ray_depth_map_masks = depth_map_masks[ray_masks_flat]

#         # ENCODER FORWARD PASS ON RAYS
#         if selected_ray_maps.size(0) > 0:
#             ray_out, ray_pos, _ = self._encode_ray_map(
#                 selected_ray_maps, 
#                 selected_ray_depth_maps,
#                 selected_ray_depth_map_masks,
#                 selected_shapes_ray
#             )
#             assert len(ray_out) == len(full_out), f"{len(ray_out)}, {len(full_out)}"
#             for i in range(len(ray_out)):
#                 full_out[i][ray_masks_flat] += ray_out[i]
#                 full_out[i][~ray_masks_flat] += self.masked_ray_map_token
#             full_pos[ray_masks_flat] += (
#                 ray_pos * (~img_masks_flat[ray_masks_flat][:, None, None]).long()
#             )
#         else:
#             raymaps = torch.zeros(
#                 1, 6, imgs[0].shape[-2], imgs[0].shape[-1], device=img_out[0].device
#             )
#             ray_mask_flat = torch.zeros_like(img_masks_flat)
#             ray_mask_flat[:1] = True
#             ray_out, ray_pos, _ = self._encode_ray_map(raymaps, shapes[ray_mask_flat])
#             for i in range(len(ray_out)):
#                 full_out[i][ray_mask_flat] += ray_out[i] * 0.0
#                 full_out[i][~ray_mask_flat] += self.masked_ray_map_token * 0.0
#         return (
#             shapes.chunk(len(views), dim=0),
#             [out.chunk(len(views), dim=0) for out in full_out],
#             full_pos.chunk(len(views), dim=0),
#         )

#     def normalize_depth_prompt(self, depth_prompt: torch.Tensor):
#         """
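#         Min-max normalize each sample of the depth prompt to the [0, 1] range.
#         Min and max are taken per sample over all pixels (zeros included);
#         the range is clamped to at least 1e-6 to avoid division by zero.
#         Returns a tuple (normalized depth_prompt, min_val, max_val), with
#         min_val / max_val of shape [B, 1, 1, 1].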
#         """
#         # print(f"depth_prompt.shape: {depth_prompt.shape}")

#         B, C, H, W = depth_prompt.shape
#         min_val = torch.quantile(
#             depth_prompt.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None]
#         max_val = torch.quantile(
#             depth_prompt.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None]

#         # min_val=torch.tensor(0.0)
#         # max_val=torch.tensor(100.0)

#         denom = (max_val - min_val).clamp_min(1e-6)
#         depth_prompt = (depth_prompt - min_val) / denom

#         return depth_prompt, min_val, max_val

#     def _decoder(self, 
#                  f_state, 
#                  pos_state, 
#                  f_img, 
#                  pos_img, 
#                  f_pose, 
#                  pos_pose, 
#                  depth_prompt,
#                  depth_true_shape):
#         """
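#         Forward prop. of the two decoder stacks (each block: self-attn +
#         cross-attn) run in lockstep for dec_depth layers on [f_state, f_img];
#         when pose_head_flag is set the image stream is [f_pose, f_img].
#         At every layer the state block receives (f_state, f_img) and the image
#         block receives (f_img, f_state), both taken from the previous layer's
#         output.
#
#         Args:
#             f_state (torch.Tensor): State tokens [B, state_size, dec_embed_dim]
#             pos_state (torch.Tensor | None): State positions [B, state_size, 2],
#                 or None when state_pe == "none"
#
#             f_img (torch.Tensor): Encoder image tokens [B, N, enc_embed_dim];
#                 projected by decoder_embed before decoding
#             pos_img (torch.Tensor): (x, y) positions of the image tokens
#
#             f_pose (torch.Tensor | None): Pose token [B, 1, D]; required when
#                 pose_head_flag is set
#             pos_pose (torch.Tensor | None): Position of the pose token
#
#             depth_prompt (torch.Tensor): Sparse depth prompt [B, H, W];
#                 currently unused (the depth branch below is commented out)
#             depth_true_shape: True shape for the depth prompt; currently unused
#         Returns:
#             zip(*final_output): a (state features, image features) pair of
#             tuples with one entry per layer -- the pre-projection input
#             followed by each decoder layer's output; the last entry has
#             dec_norm_state / dec_norm applied.
#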
#         """

#         # print(f"pos_img.shape: {pos_img.shape}")
#         # print(f"pos_img: {pos_img}")

#         # depth_prompt_normed, min1, max1 = self.normalize_depth_prompt(depth_prompt.unsqueeze(1))
#         # depth_prompt_mask = depth_prompt_normed != 0
#         # print(f"depth_prompt_normed.shape: {depth_prompt_normed.shape}")
#         # print(f"depth_prompt_mask.shape: {depth_prompt_mask.shape}")
#         # exit()
#         # feat_pc, pos_pc = self.patch_embed_pc(torch.cat((depth_prompt_normed, depth_prompt_mask), dim=1), true_shape=depth_true_shape)
        
#         n_image_tokens = f_img.shape[1]
#         final_output = [(f_state, f_img)]  # before projection
#         assert f_state.shape[-1] == self.dec_embed_dim
#         f_img = self.decoder_embed(f_img)
#         # print(f"before f_img.shape: {f_img.shape}")
#         if self.pose_head_flag:
#             assert f_pose is not None and pos_pose is not None
#             f_img = torch.cat([f_pose, f_img], dim=1) # f_img = [f_pose, f_img]
#             pos_img = torch.cat([pos_pose, pos_img], dim=1)
#         # print(f"after f_img.shape: {f_img.shape}")

#         # conv_res = self.zero_convs[0](feat_pc.transpose(-1,-2)).transpose(-1,-2)
#         # print(f"feat_pc.shape: {feat_pc.shape}")
#         # !depth features are added only to image features
#         # f_img[:,:n_image_tokens,:] += self.zero_convs[0](feat_pc.transpose(-1,-2)).transpose(-1,-2)

#         final_output.append((f_state, f_img))

#         for i in range(len(self.dec_blocks)):
#             blk_state = self.dec_blocks_state[i]
#             blk_img = self.dec_blocks[i]

#             if (
#                 self.gradient_checkpointing
#                 and self.training
#                 and torch.is_grad_enabled()
#             ):
#                 f_state, _ = checkpoint(
#                     blk_state,
#                     *final_output[-1][::+1],
#                     pos_state,
#                     pos_img,
#                     use_reentrant=not self.fixed_input_length,
#                 )
#                 f_img, _ = checkpoint(
#                     blk_img,
#                     *final_output[-1][::-1],
#                     pos_img,
#                     pos_state,
#                     use_reentrant=not self.fixed_input_length,
#                 )
#             else:
#                 # decoder(state,img)
#                 f_state, _ = blk_state(*final_output[-1][::+1], pos_state, pos_img)
#                 # decoder(img,state)
#                 f_img, _ = blk_img(*final_output[-1][::-1], pos_img, pos_state)
            
#             # if i < len(self.dec_blocks_pc):

#             #     feat_pc = self.dec_blocks_pc[i](feat_pc, pos_pc)
#             #     f_img[:,:n_image_tokens,:] += self.zero_convs[i+1](feat_pc.transpose(-1,-2)).transpose(-1,-2)


#             final_output.append((f_state, f_img))

#         del final_output[1]  # duplicate with final_output[0]
#         final_output[-1] = (
#             self.dec_norm_state(final_output[-1][0]),
#             self.dec_norm(final_output[-1][1]),
#         )
#         return zip(*final_output)

#     def _downstream_head(self, decout, img_shape, prompt_depth=None, **kwargs):
#         B, S, D = decout[-1].shape
#         head = getattr(self, f"head")
#         if self.depth_guidance:
#             return head(decout, img_shape, prompt_depth, **kwargs)
#         else:
#             return head(decout, img_shape, **kwargs)

#     def _init_state(self, image_tokens, image_pos):
#         """
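#         Initialize the state from the learned register tokens, projected by
#         decoder_embed_state. The first frame's image_tokens / image_pos are
#         only used for batch size, device and dtype, not for their content.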
#         """
#         state_feat, state_pos, _ = self._encode_state(image_tokens, image_pos)
#         state_feat = self.decoder_embed_state(state_feat)
#         return state_feat, state_pos

#     def _recurrent_rollout(
#         self,
#         state_feat,
#         state_pos,
#         current_feat,
#         current_pos,
#         pose_feat,
#         pose_pos,
#         init_state_feat,
#         img_mask=None,
#         reset_mask=None,
#         update=None,
#         depth_prompt=None,
#         depth_true_shape=None

#     ):
#         """
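#         Performs the decoder forward pass for the current frame and returns
#         (final-layer state features, per-layer image decoder outputs).
#         init_state_feat, img_mask, reset_mask and update are accepted but not
#         used here; the callers apply update/reset masking themselves.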
#         """
#         new_state_feat, dec = self._decoder(state_feat, 
#                                             state_pos, 
#                                             current_feat, 
#                                             current_pos, 
#                                             pose_feat, 
#                                             pose_pos,
#                                             depth_prompt,
#                                             depth_true_shape
#         )
#         new_state_feat = new_state_feat[-1]
#         return new_state_feat, dec

#     def _get_img_level_feat(self, feat):
#         return torch.mean(feat, dim=1, keepdim=True)

#     def _forward_encoder(self, views):
#         shape, feat_ls, pos = self._encode_views(views)
#         feat = feat_ls[-1]
#         state_feat, state_pos = self._init_state(feat[0], pos[0])
#         mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
#         init_state_feat = state_feat.clone()
#         init_mem = mem.clone()
#         return (feat, pos, shape), (
#             init_state_feat,
#             init_mem,
#             state_feat,
#             state_pos,
#             mem,
#         )

#     def _forward_decoder_step(
#         self,
#         views,
#         i,
#         feat_i,
#         pos_i,
#         shape_i,
#         init_state_feat,
#         init_mem,
#         state_feat,
#         state_pos,
#         mem,
#     ):
#         """
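#         Run one recurrent step for view i:
#         1. Decoder forward prop (state + image decoders, optional pose token)
#         2. Output head (dpt/linear) forward prop on selected decoder layers
#         3. Update state_feat / mem where the view's update mask is set, and
#            reset them to their initial values where the reset mask is set
#         Returns (res, (state_feat, mem)).
#         NOTE(review): global_img_feat_i is only defined when pose_head_flag is
#         set, but pose_retriever.update_mem uses it unconditionally.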
        
#         Arguments:
#             views: list of views, where length of list corresponds to sequence length
#                     and every item in list has shape [B, ...]
#                    views[0].keys(): dict_keys(['img', 'depthmap', 'camera_pose', 'camera_intrinsics', 'dataset', 'label', 'is_metric', 'instance', 'is_video', 'quantile', 'img_mask', 'ray_mask', 'camera_only', 'depth_only', 'single_view', 'reset', 'idx', 'true_shape', 'sky_mask', 'ray_map', 'pts3d', 'valid_mask', 'rng'])
#         """

#         # print(f"len(views): {len(views):}")
#         # print(f"views[0].keys(): {views[0].keys()}")
#         # print(f"views[0]['depthmap'].shape: {views[0]['depthmap'].shape}")
#         # for i in range(len(views)):        
#         #     print(f"views[i]['depthmap'].shape: {views[i]['depthmap'].shape}")
#         # print(f"feat_i.shape: {feat_i.shape}")

#         # depth_prompt1 = self.normalize_depth_prompt(view1['depthmap'].unsqueeze(1))
#         # depth_prompt2 = self.normalize_depth_prompt(view2['depthmap'].unsqueeze(1))

#         if self.pose_head_flag:
#             global_img_feat_i = self._get_img_level_feat(feat_i)
#             # print(f"global_img_feat_i.shape: {global_img_feat_i.shape}")
#             if i == 0:
#                 pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#             else:
#                 pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#             pose_pos_i = -torch.ones(
#                 feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#             )
#         else:
#             pose_feat_i = None
#             pose_pos_i = None

#         # print(f"state_feat.shape: {state_feat.shape}")
#         # print(f"state_pos.shape: {state_pos.shape}")
#         # print(f"feat_i.shape: {feat_i.shape}")
#         # print(f"pos_i.shape: {pos_i.shape}")

#         # returns (final state features, per-layer image features); with
#         # pose_head_flag the image features are [f_pose, f_img]
#         new_state_feat, dec = self._recurrent_rollout(
#             state_feat, # f_state
#             state_pos, # pos_state 
#             feat_i, # f_img
#             pos_i, #pos_img
#             pose_feat_i, #f_pose
#             pose_pos_i, #pos_pose
#             init_state_feat,
#             img_mask=views[i]["img_mask"],
#             reset_mask=views[i]["reset"],
#             update=views[i].get("update", None),
#             depth_prompt=views[i]["depthmap"],
#             depth_true_shape=views[i]["true_shape"]
#         )
#         out_pose_feat_i = dec[-1][:, 0:1]
#         new_mem = self.pose_retriever.update_mem(
#             mem, global_img_feat_i, out_pose_feat_i
#         )
#         head_input = [
#             dec[0].float(), # f_img from encoder  of shape [B, N_tokens, D_0]
#             dec[self.dec_depth * 2 // 4][:, 1:].float(), #idx=6 [f_img] in the middle
#             dec[self.dec_depth * 3 // 4][:, 1:].float(), #idx=9 no_cls [f_img] in 75%
#             dec[self.dec_depth].float(), # idx=12 with_cls [f_img, f_pose]
#         ]

#         prompt_depth_i = self.normalize_depth_prompt(views[i]['depthmap'].unsqueeze(1))

#         res = self._downstream_head(head_input, shape_i, pos=pos_i, prompt_depth=prompt_depth_i)
#         img_mask = views[i]["img_mask"]
#         update = views[i].get("update", None)
#         if update is not None:
#             update_mask = img_mask & update  # if don't update, then whatever img_mask
#         else:
#             update_mask = img_mask
#         update_mask = update_mask[:, None, None].float()
#         state_feat = new_state_feat * update_mask + state_feat * (
#             1 - update_mask
#         )  # update global state
#         mem = new_mem * update_mask + mem * (1 - update_mask)  # then update local state
#         reset_mask = views[i]["reset"]
#         if reset_mask is not None:
#             reset_mask = reset_mask[:, None, None].float()
#             state_feat = init_state_feat * reset_mask + state_feat * (1 - reset_mask)
#             mem = init_mem * reset_mask + mem * (1 - reset_mask)
#         return res, (state_feat, mem)

#     def _forward_impl(self, views, ret_state=False):
#         # 1. Encoder forward prop
#         shape, feat_ls, pos = self._encode_views(views)
#         feat = feat_ls[-1]
#         # 2. State initialization
#         state_feat, state_pos = self._init_state(feat[0], pos[0])
#         mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
#         init_state_feat = state_feat.clone()
#         init_mem = mem.clone()
#         all_state_args = [(state_feat, state_pos, init_state_feat, mem, init_mem)]
#         ress = []
#         # 3. Reconstruction loop
#         for i in range(len(views)):
#             feat_i = feat[i]
#             pos_i = pos[i]
#             # whether pose should be predicted
#             if self.pose_head_flag:
#                 global_img_feat_i = self._get_img_level_feat(feat_i)
#                 if i == 0:
#                     pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#                 else:
#                     pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#                 pose_pos_i = -torch.ones(
#                     feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#                 )
#             else:
#                 pose_feat_i = None
#                 pose_pos_i = None
            
#             # iteratively update feat and return decoder feats
#             new_state_feat, dec = self._recurrent_rollout(
#                 state_feat,
#                 state_pos,
#                 feat_i,
#                 pos_i,
#                 pose_feat_i,
#                 pose_pos_i,
#                 init_state_feat,
#                 img_mask=views[i]["img_mask"],
#                 reset_mask=views[i]["reset"],
#                 update=views[i].get("update", None),
#                 depth_prompt=views[i]["depthmap"],
#                 depth_true_shape=views[i]["true_shape"]
#             )
#             out_pose_feat_i = dec[-1][:, 0:1]
#             new_mem = self.pose_retriever.update_mem(
#                 mem, global_img_feat_i, out_pose_feat_i
#             )
#             assert len(dec) == self.dec_depth + 1

#             # Pass decoder feats through output layer
#             head_input = [
#                 dec[0].float(),
#                 dec[self.dec_depth * 2 // 4][:, 1:].float(),
#                 dec[self.dec_depth * 3 // 4][:, 1:].float(),
#                 dec[self.dec_depth].float(),
#             ]

#             # print(f"views[i].keys(): {views[i].keys()}")
#             # exit()

#             # prompt_depth = None

#             prompt_depth_i = self.normalize_depth_prompt(views[i]['depthmap'].unsqueeze(1))
#             # print(f"prompt_depth_i.shape: {prompt_depth_i.shape}")
#             # exit()

#             res = self._downstream_head(head_input, shape[i], prompt_depth=prompt_depth_i, pos=pos_i)
#             # print(f"res.keys(): {res.keys()}")
#             # print(f"res['pts3d_in_self_view'].shape: {res['pts3d_in_self_view'].shape}")
#             # print(f"res['conf_self'].shape: {res['conf_self'].shape}")
#             # print(f"views[i]['pts3d'].shape: {views[i]['pts3d'].shape}")

#             ress.append(res)
#             img_mask = views[i]["img_mask"]
#             update = views[i].get("update", None)
#             if update is not None:
#                 update_mask = (
#                     img_mask & update
#                 )  # if don't update, then whatever img_mask
#             else:
#                 update_mask = img_mask
#             update_mask = update_mask[:, None, None].float()
#             state_feat = new_state_feat * update_mask + state_feat * (
#                 1 - update_mask
#             )  # update global state
#             mem = new_mem * update_mask + mem * (
#                 1 - update_mask
#             )  # then update local state
#             reset_mask = views[i]["reset"]
#             if reset_mask is not None:
#                 reset_mask = reset_mask[:, None, None].float()
#                 state_feat = init_state_feat * reset_mask + state_feat * (
#                     1 - reset_mask
#                 )
#                 mem = init_mem * reset_mask + mem * (1 - reset_mask)
#             all_state_args.append(
#                 (state_feat, state_pos, init_state_feat, mem, init_mem)
#             )
#         if ret_state:
#             return ress, views, all_state_args
#         return ress, views

#     def forward(self, views, ret_state=False):
#         if ret_state:
#             ress, views, state_args = self._forward_impl(views, ret_state=ret_state)
#             return ARCroco3DStereoOutput(ress=ress, views=views), state_args
#         else:
#             ress, views = self._forward_impl(views, ret_state=ret_state)
#             return ARCroco3DStereoOutput(ress=ress, views=views)

#     def inference_step(
#         self, view, state_feat, state_pos, init_state_feat, mem, init_mem
#     ):
#         batch_size = view["img"].shape[0]
#         raymaps = []
#         shapes = []
#         for j in range(batch_size):
#             assert view["ray_mask"][j]
#             raymap = view["ray_map"][[j]].permute(0, 3, 1, 2)
#             raymaps.append(raymap)
#             shapes.append(
#                 view.get(
#                     "true_shape",
#                     torch.tensor(view["ray_map"].shape[-2:])[None].repeat(
#                         view["ray_map"].shape[0], 1
#                     ),
#                 )[[j]]
#             )

#         raymaps = torch.cat(raymaps, dim=0)
#         shape = torch.cat(shapes, dim=0).to(raymaps.device)
#         feat_ls, pos, _ = self._encode_ray_map(raymaps, shapes)

#         feat_i = feat_ls[-1]
#         pos_i = pos
#         if self.pose_head_flag:
#             global_img_feat_i = self._get_img_level_feat(feat_i)
#             pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#             pose_pos_i = -torch.ones(
#                 feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#             )
#         else:
#             pose_feat_i = None
#             pose_pos_i = None
#         new_state_feat, dec = self._recurrent_rollout(
#             state_feat,
#             state_pos,
#             feat_i,
#             pos_i,
#             pose_feat_i,
#             pose_pos_i,
#             init_state_feat,
#             img_mask=view["img_mask"],
#             reset_mask=view["reset"],
#             update=view.get("update", None),
#         )

#         out_pose_feat_i = dec[-1][:, 0:1]
#         new_mem = self.pose_retriever.update_mem(
#             mem, global_img_feat_i, out_pose_feat_i
#         )
#         assert len(dec) == self.dec_depth + 1
#         head_input = [
#             dec[0].float(),
#             dec[self.dec_depth * 2 // 4][:, 1:].float(),
#             dec[self.dec_depth * 3 // 4][:, 1:].float(),
#             dec[self.dec_depth].float(),
#         ]
#         res = self._downstream_head(head_input, shape, pos=pos_i)
#         return res, view

#     def forward_recurrent(self, views, device, ret_state=False):
#         """
#         Roll the recurrent state forward over a sequence of views.
#
#         For each view, the image and/or ray-map inputs selected by
#         ``img_mask`` / ``ray_mask`` are encoded and fused. The result is
#         passed through ``_recurrent_rollout`` with the current state and
#         decoded by ``_downstream_head``. The state and pose memory are then
#         updated, and optionally reset, according to the view's
#         ``img_mask``, ``update`` and ``reset`` entries.
#
#         Args:
#             views: list of view dicts. Each must provide "img", "img_mask",
#                 "ray_mask", "ray_map" and "reset", and may provide
#                 "true_shape" and "update".
#             device: unused. The device is taken from each view's "img"
#                 tensor; the argument is kept for call-site compatibility.
#             ret_state: if True, also return the per-step state tuples.
#
#         Returns:
#             ``(ress, views)``, or ``(ress, views, all_state_args)`` when
#             ``ret_state`` is True. ``all_state_args`` holds the initial
#             state followed by the state after every view, each as
#             ``(state_feat, state_pos, init_state_feat, mem, init_mem)``.
#         """
#         ress = []
#         all_state_args = []
#         for i, view in enumerate(views):
#             # The ``device`` argument is overridden by the view's own device.
#             device = view["img"].device
#             batch_size = view["img"].shape[0]
#             img_mask = view["img_mask"].reshape(
#                 -1, batch_size
#             )  # Shape: (1, batch_size)
#             ray_mask = view["ray_mask"].reshape(
#                 -1, batch_size
#             )  # Shape: (1, batch_size)
#             imgs = view["img"].unsqueeze(0)  # Shape: (1, batch_size, C, H, W)
#             ray_maps = view["ray_map"].unsqueeze(
#                 0
#             )  # Shape: (num_views, batch_size, H, W, C)
#             shapes = (
#                 view["true_shape"].unsqueeze(0)
#                 if "true_shape" in view
#                 else torch.tensor(view["img"].shape[-2:], device=device)
#                 .unsqueeze(0)
#                 .repeat(batch_size, 1)
#                 .unsqueeze(0)
#             )  # Shape: (num_views, batch_size, 2)
#             imgs = imgs.view(
#                 -1, *imgs.shape[2:]
#             )  # Shape: (num_views * batch_size, C, H, W)
#             ray_maps = ray_maps.view(
#                 -1, *ray_maps.shape[2:]
#             )  # Shape: (num_views * batch_size, H, W, C)
#             shapes = shapes.view(-1, 2).to(
#                 imgs.device
#             )  # Shape: (num_views * batch_size, 2)
#             img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size)
#             ray_masks_flat = ray_mask.view(-1)
#             selected_imgs = imgs[img_masks_flat]
#             selected_shapes = shapes[img_masks_flat]
#             if selected_imgs.size(0) > 0:
#                 img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
#             else:
#                 img_out, img_pos = None, None
#             ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
#             selected_ray_maps = ray_maps[ray_masks_flat]
#             selected_shapes_ray = shapes[ray_masks_flat]
#             if selected_ray_maps.size(0) > 0:
#                 ray_out, ray_pos, _ = self._encode_ray_map(
#                     selected_ray_maps, selected_shapes_ray
#                 )
#             else:
#                 ray_out, ray_pos = None, None

#             shape = shapes
#             if img_out is not None and ray_out is None:
#                 feat_i = img_out[-1]
#                 pos_i = img_pos
#             elif img_out is None and ray_out is not None:
#                 feat_i = ray_out[-1]
#                 pos_i = ray_pos
#             elif img_out is not None and ray_out is not None:
#                 # NOTE(review): rows of img_out / ray_out come from the
#                 # img_mask / ray_mask subsets respectively. They only line up
#                 # if both masks select the same samples. Confirm that callers
#                 # guarantee this, or scatter into full-batch tensors first.
#                 feat_i = img_out[-1] + ray_out[-1]
#                 pos_i = img_pos
#             else:
#                 raise NotImplementedError

#             if i == 0:
#                 state_feat, state_pos = self._init_state(feat_i, pos_i)
#                 # NOTE(review): pose_retriever.mem is read even when
#                 # pose_head_flag is False. Confirm pose_retriever always exists.
#                 mem = self.pose_retriever.mem.expand(feat_i.shape[0], -1, -1)
#                 init_state_feat = state_feat.clone()
#                 init_mem = mem.clone()
#                 # Record the initial (pre-update) state as the first entry.
#                 all_state_args.append(
#                     (state_feat, state_pos, init_state_feat, mem, init_mem)
#                 )

#             if self.pose_head_flag:
#                 global_img_feat_i = self._get_img_level_feat(feat_i)
#                 if i == 0:
#                     # First view: start from the learned pose token.
#                     pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
#                 else:
#                     pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
#                 pose_pos_i = -torch.ones(
#                     feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
#                 )
#             else:
#                 pose_feat_i = None
#                 pose_pos_i = None
#             new_state_feat, dec = self._recurrent_rollout(
#                 state_feat,
#                 state_pos,
#                 feat_i,
#                 pos_i,
#                 pose_feat_i,
#                 pose_pos_i,
#                 init_state_feat,
#                 img_mask=view["img_mask"],
#                 reset_mask=view["reset"],
#                 update=view.get("update", None),
#             )
#             if self.pose_head_flag:
#                 # global_img_feat_i is only defined when the pose head is
#                 # enabled. Updating the pose memory unconditionally raised
#                 # NameError whenever pose_head_flag was False.
#                 out_pose_feat_i = dec[-1][:, 0:1]
#                 new_mem = self.pose_retriever.update_mem(
#                     mem, global_img_feat_i, out_pose_feat_i
#                 )
#             else:
#                 new_mem = mem  # no pose memory to update
#             assert len(dec) == self.dec_depth + 1
#             head_input = [
#                 dec[0].float(),
#                 dec[self.dec_depth * 2 // 4][:, 1:].float(),
#                 dec[self.dec_depth * 3 // 4][:, 1:].float(),
#                 dec[self.dec_depth].float(),
#             ]
#             res = self._downstream_head(head_input, shape, pos=pos_i)
#             ress.append(res)
#             img_mask = view["img_mask"]
#             update = view.get("update", None)
#             if update is not None:
#                 update_mask = (
#                     img_mask & update
#                 )  # if don't update, then whatever img_mask
#             else:
#                 update_mask = img_mask
#             update_mask = update_mask[:, None, None].float()
#             state_feat = new_state_feat * update_mask + state_feat * (
#                 1 - update_mask
#             )  # update global state
#             mem = new_mem * update_mask + mem * (
#                 1 - update_mask
#             )  # then update local state
#             reset_mask = view["reset"]
#             if reset_mask is not None:
#                 # Samples flagged for reset go back to the initial state/memory.
#                 reset_mask = reset_mask[:, None, None].float()
#                 state_feat = init_state_feat * reset_mask + state_feat * (
#                     1 - reset_mask
#                 )
#                 mem = init_mem * reset_mask + mem * (1 - reset_mask)
#             all_state_args.append(
#                 (state_feat, state_pos, init_state_feat, mem, init_mem)
#             )
#         if ret_state:
#             return ress, views, all_state_args
#         return ress, views
