# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
import torch.distributed
from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
    get_1d_sine_pe,
    get_next_point,
    sample_box_points,
    select_closest_cond_frames,
)

from sam2.utils.misc import concat_points

from training.utils.data_utils import BatchedVideoDatapoint

from sam2.modeling.boundary_module import BoundaryHead
from sam2.modeling.spatial_module import *
from training.utils.spatial_utils import RelativePositionLoss

# Stand-in score assigned to objects that are missing: a very large negative
# value, far below any score a present object would receive.
NO_OBJ_SCORE: float = -1024.0


class SAM2Train(SAM2Base):
    def __init__(
        self,
        image_encoder,
        memory_attention=None,
        memory_encoder=None,
        prob_to_use_pt_input_for_train=0.0,
        prob_to_use_pt_input_for_eval=0.0,
        prob_to_use_box_input_for_train=0.0,
        prob_to_use_box_input_for_eval=0.0,
        # if it is greater than 1, we do interactive point sampling on the 1st frame and on other randomly selected frames
        num_frames_to_correct_for_train=1,  # default: only iteratively sample on first frame
        num_frames_to_correct_for_eval=1,  # default: only iteratively sample on first frame
        rand_frames_to_correct_for_train=False,
        rand_frames_to_correct_for_eval=False,
        # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
        # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
        # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
        # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
        # these are initial conditioning frames because as we track the video, more conditioning frames might be added
        # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
        num_init_cond_frames_for_train=1,  # default: only use the first frame as initial conditioning frame
        num_init_cond_frames_for_eval=1,  # default: only use the first frame as initial conditioning frame
        # default: random 1~num_init_cond_frames_for_train cond frames (to be consistent w/ previous TA data loader)
        rand_init_cond_frames_for_train=True,
        rand_init_cond_frames_for_eval=False,
        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only those initial conditioning frames
        add_all_frames_to_correct_as_cond=False,
        # how many additional correction points to sample (on each frame selected to be corrected)
        # note that the first frame receives an initial input click (in addition to any correction clicks)
        num_correction_pt_per_frame=7,
        # method for point sampling during evaluation
        # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
        # default to "center" to be consistent with evaluation in the SAM paper
        pt_sampling_for_eval="center",
        # During training, we optionally allow sampling the correction points from GT regions
        # instead of the prediction error regions with a small probability. This might allow the
        # model to overfit less to the error regions in training datasets
        prob_to_sample_from_gt_for_train=0.0,
        use_act_ckpt_iterative_pt_sampling=False,
        # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
        # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
        forward_backbone_per_frame_for_eval=False,
        freeze_image_encoder=True,
        boundary_head=None,
        **kwargs,
    ):
        """Build the SAM2 training model with optional boundary and spatial-context modules.

        Args:
            image_encoder, memory_attention, memory_encoder, **kwargs: Forwarded
                unchanged to ``SAM2Base.__init__``.
            prob_to_use_*_input_for_{train,eval}, num_frames_to_correct_*,
            rand_frames_to_correct_*, num_init_cond_frames_*,
            rand_init_cond_frames_*, add_all_frames_to_correct_as_cond,
            num_correction_pt_per_frame, pt_sampling_for_eval,
            prob_to_sample_from_gt_for_train: Prompt-sampling configuration,
                stored on ``self`` (see the inline comments in the signature).
            use_act_ckpt_iterative_pt_sampling (bool): Stored on ``self``.
            forward_backbone_per_frame_for_eval (bool): If True and not training,
                ``forward`` defers backbone computation to per-frame tracking.
            freeze_image_encoder (bool): If True, sets ``requires_grad=False`` on
                every image-encoder parameter.
            boundary_head (module or None): Called by ``forward_image`` on the
                highest-resolution FPN level. The boundary branch is enabled only
                when a head is supplied.

        Raises:
            AssertionError: If point inputs are enabled (either probability > 0)
                and ``num_frames_to_correct_*`` < ``num_init_cond_frames_*``.
        """
        super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
        self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
        self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval

        # Point sampler and conditioning frames
        self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
        self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
        self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
        self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
        if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
            logging.info(
                f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
            )
            assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
            assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

        self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
        self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
        self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
        self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
        # Initial multi-conditioning frames
        self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
        self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
        self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
        self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
        self.num_correction_pt_per_frame = num_correction_pt_per_frame
        self.pt_sampling_for_eval = pt_sampling_for_eval
        self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
        # A random number generator with a fixed initial seed across GPUs
        self.rng = np.random.default_rng(seed=42)

        # ====== MODIFICATIONS ======
        # ------ boundary feature -----
        self.boundary_head = boundary_head
        # `forward_image` invokes `self.boundary_head(...)` whenever `use_boundary`
        # is set, so only enable the branch when a head was actually provided
        # (the default `boundary_head=None` would otherwise raise a TypeError).
        self.use_boundary = boundary_head is not None

        # ------ volume feature -------
        self.use_volume = True
        self.spatial_attention_layer = RPEAttention()
        self.position_prediction_head = RelativePositionHead()
        # Fully qualified so it does not depend on `nn` leaking in through the
        # star import of `sam2.modeling.spatial_module`.
        self.position_loss_fn = torch.nn.CrossEntropyLoss()

        logging.info(f"freeze_image_encoder status:{freeze_image_encoder}")
        if freeze_image_encoder:
            for p in self.image_encoder.parameters():
                p.requires_grad = False

    def forward_image(self, img_batch: torch.Tensor):
        """Encode ``img_batch`` and post-process the backbone outputs.

        This extends the base-class version in two ways.

        * When ``use_high_res_features_in_sam`` is set, FPN levels 0 and 1 are
          projected once by the SAM decoder's ``conv_s0`` / ``conv_s1``.
        * When ``use_boundary`` is set, FPN level 0 is replaced by the boundary
          head's enhanced feature, and the head's logits are stored under
          ``"boundary_logits"``.

        Returns:
            dict: The image encoder's output dict, modified as described above.
        """
        backbone_out = self.image_encoder(img_batch)
        fpn_levels = backbone_out["backbone_fpn"]

        if self.use_high_res_features_in_sam:
            # Project once here instead of re-running the projection on every SAM click.
            fpn_levels[0] = self.sam_mask_decoder.conv_s0(fpn_levels[0])
            fpn_levels[1] = self.sam_mask_decoder.conv_s1(fpn_levels[1])

        if not self.use_boundary:
            return backbone_out

        # The boundary head consumes the highest-resolution level and returns
        # (enhanced_feature, boundary_logits).
        enhanced_level0, logits = self.boundary_head(fpn_levels[0])
        fpn_levels[0] = enhanced_level0
        # Kept so the boundary loss can be computed downstream.
        backbone_out["boundary_logits"] = logits
        return backbone_out

    def forward(self, input: BatchedVideoDatapoint):
        """Run the backbone, prompt preparation and tracking on a batched video.

        Backbone features are computed for all frames up front. The exception is
        evaluation with ``forward_backbone_per_frame_for_eval`` enabled, where
        computation is deferred until each frame is tracked.

        In training mode, the spatial-context views are attached to
        ``backbone_out`` before tracking.

        Returns:
            The result of ``self.forward_tracking``.
        """
        precompute_all_frames = self.training or not self.forward_backbone_per_frame_for_eval
        if precompute_all_frames:
            backbone_out = self.forward_image(input.flat_img_batch)
        else:
            # Placeholders: features are produced lazily during tracking.
            backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
        backbone_out = self.prepare_prompt_inputs(backbone_out, input)

        if self.training:
            # Writes "spatial_out" / "slice_spatial_out" into backbone_out in place;
            # the returned dict is the same object, so it is not rebound here.
            self.process_spatial_feature_v2(backbone_out)

        return self.forward_tracking(backbone_out, input)

    def RPE_module(self, backbone_out, spatial_out, viz_pred_slice=True):
        """Enrich each observer slice with context from all slices via RPE attention.

        Args:
            backbone_out (dict): Backbone output dict. It is updated in place:
                ``backbone_out['relative_positions']`` is set to
                ``spatial_out['relative_positions']``.
            spatial_out (dict): All-perspectives output, as built by
                ``process_spatial_feature_all_perspectives``, with these keys:

                * 'observer_features': (B*C, N, D, Hf, Wf)
                * 'target_features': (B*C, N, D, Hf, Wf)
                * 'relative_positions': (B*C, N)
            viz_pred_slice (bool): Currently unused; kept for interface
                compatibility. The position-prediction logging it gated is disabled.

        Returns:
            tuple: ``(enriched_observer_feature, extra_rpe_feature)``.

            * ``enriched_observer_feature`` is the attention module's output for
              the pooled observer query. Its shape is expected to be (B*C, 1, D);
              confirm against the attention module.
            * ``extra_rpe_feature`` is a dict with 'observer_pooled' (B*C, 1, D)
              and 'target_pooled' (B*C, N, D).
        """
        observer_features_map = spatial_out["observer_features"]  # (B*C, N, D, Hf, Wf)
        target_features_map = spatial_out["target_features"]  # (B*C, N, D, Hf, Wf)
        relative_positions = spatial_out["relative_positions"]  # (B*C, N)
        backbone_out["relative_positions"] = relative_positions

        # Attention works on (batch, seq, dim) sequences rather than 2D maps.
        # Global-average-pool each slice's (D, Hf, Wf) map into one D-vector.
        # In the all-perspectives layout the observer slice is expanded along N,
        # so entry 0 along dim 1 already holds the observer's own features.
        observer_pooled = observer_features_map[:, 0, ...].mean(dim=(-2, -1)).unsqueeze(1)  # (B*C, 1, D)
        target_pooled = target_features_map.mean(dim=(-2, -1))  # (B*C, N, D)

        # `__init__` registers the RPE attention module as `spatial_attention_layer`
        # and never assigns `rpe_attention`. Prefer `rpe_attention` if something
        # else provides it, otherwise fall back to the registered layer.
        attention = getattr(self, "rpe_attention", None)
        if attention is None:
            attention = self.spatial_attention_layer

        # Use the full spatial context (targets) to enrich the observer feature.
        enriched_observer_feature = attention(
            query=observer_pooled,
            key=target_pooled,
            value=target_pooled,
            relative_positions=relative_positions,
        )

        extra_rpe_feature = {
            "observer_pooled": observer_pooled,
            "target_pooled": target_pooled,
        }
        return enriched_observer_feature, extra_rpe_feature

    def log_position_predictions(
        self,
        predicted_logits: torch.Tensor,
        true_relative_positions: torch.Tensor,
        max_rel_dist: int,
        num_examples: int = 2,
    ):
        """
        Log predicted vs. ground-truth relative positions for a few batch rows.

        Does nothing outside training mode, so evaluation logs stay clean.

        Args:
            predicted_logits (torch.Tensor): Position-class logits, shape (B, N, NumClasses).
            true_relative_positions (torch.Tensor): Ground-truth relative positions, shape (B, N).
            max_rel_dist (int): The offset ``k`` mapping class indices [0, 2k] to positions [-k, k].
            num_examples (int): Maximum number of batch rows to log.
        """
        if not self.training:
            return

        # Class index c in [0, 2k] maps back to relative position c - k.
        # This inverts `target_indices = pos + max_rel_dist` in RelativePositionLoss.
        predicted_positions = predicted_logits.argmax(dim=-1) - max_rel_dist  # (B, N)

        # Only a handful of rows are logged so the output stays readable.
        rows_to_show = min(num_examples, predicted_logits.shape[0])

        logging.info("--- [Spatial Context] Position Prediction Check ---")
        for row in range(rows_to_show):
            truth = true_relative_positions[row]
            guess = predicted_positions[row]
            truth_values = truth.cpu().tolist()
            guess_values = guess.cpu().tolist()

            # Percentage of positions in this row predicted exactly.
            hits = torch.eq(truth, guess).sum().item()
            accuracy = hits / len(truth_values) * 100

            logging.info(
                f"\n[Observer Example {row + 1}]:"
                f"\n  => Ground Truth Rel. Pos: {truth_values}"
                f"\n  => Predicted Rel. Pos:    {guess_values}"
                f"\n  => Accuracy for this observer: {accuracy:.2f}%"
            )
        logging.info("--------------------------------------------------")

    def process_spatial_feature(self, backbone_out):
        """Split per-frame FPN features into conditioning ("core") and context slices.

        This is the earlier prototype of spatial-feature processing; ``forward``
        uses ``process_spatial_feature_v2``. It previously called ``exit(0)``, which
        terminated the whole process, and returned an empty dict. It now returns
        the gathered features instead.

        Args:
            backbone_out (dict): Must contain 'num_frames', 'vision_features'
                (flattened batch of size B * num_frames), 'backbone_fpn' (list of
                4-D per-level tensors flattened the same way), 'init_cond_frames'
                and 'frames_not_in_init_cond' (iterables of frame indices).

        Returns:
            dict: Contains these keys:

            * 'core_img_features': ``{frame_idx: [level_feat, ...]}`` for every
              frame in ``init_cond_frames``. Each ``level_feat`` has shape
              (B, 1, C_l, H_l, W_l); the frame axis is kept as size 1.
            * 'context_img_features': the same structure for every frame in
              ``frames_not_in_init_cond``.
        """
        num_frames = backbone_out["num_frames"]
        batch_size = backbone_out["vision_features"].shape[0] // num_frames

        # (B * num_frames, C, H, W) -> (B, num_frames, C, H, W) for each FPN level.
        fpn_feature_volume = [
            feat.reshape(batch_size, num_frames, *feat.shape[1:])
            for feat in backbone_out["backbone_fpn"]
        ]

        def _gather(frame_indices):
            # Slice with idx:idx + 1 so each result keeps a 5-D (B, 1, ...) shape.
            return {
                idx: [level[:, idx : idx + 1] for level in fpn_feature_volume]
                for idx in frame_indices
            }

        spatial_out = {
            "core_img_features": _gather(backbone_out["init_cond_frames"]),
            "context_img_features": _gather(backbone_out["frames_not_in_init_cond"]),
        }
        logging.debug(
            f"process_spatial_feature: core frames={list(spatial_out['core_img_features'])}, "
            f"context frames={list(spatial_out['context_img_features'])}"
        )
        return spatial_out

    def process_spatial_feature_v2(
        self, backbone_out: dict, local_context_window_size: int = 2
    ):
        """
        Attach two structured views of the backbone features to ``backbone_out``.

        This function only re-arranges (gathers/expands) existing features. It does
        not flatten them for a decoder or run any learned module.

        1.  "All-Perspectives" (global context), stored as ``backbone_out["spatial_out"]``.
            Every conditioning ("observer") slice is paired with ALL slices. This is
            delegated to ``process_spatial_feature_all_perspectives``; its batch
            dimension is ``B * num_cores``.

        2.  "Per-Slice Local Context", stored as ``backbone_out["slice_spatial_out"]``.
            For EVERY slice, the FPN features of the slices in a ``2*k+1`` window
            around it are gathered. Windows running past either end of the volume are
            padded by repeating the edge slice (index clamping).
            ``local_valid_mask`` marks which window entries are real neighbours. The
            batch dimension is ``B * num_frames``.

        Args:
            backbone_out (dict): Features assembled in ``forward``. It must contain
                'num_frames', 'init_cond_frames', 'vision_features' and
                'backbone_fpn', with tensors flattened as (B * num_frames, ...).
            local_context_window_size (int): Radius ``k`` of the local window
                (k=2 gives a window of 5 slices). Must be >= 0.

        Returns:
            dict: The same ``backbone_out`` object, updated in place with two keys.

            * 'spatial_out': the all-perspectives output.
            * 'slice_spatial_out': a dict containing
                - 'local_context_fpn': list (one per FPN level) of
                  (B*N, 2k+1, FPN_D, FPN_H, FPN_W);
                - 'local_relative_positions': (B*N, 2k+1) window offsets in [-k, k];
                - 'local_valid_mask': (B*N, 2k+1) bool, False where the entry is
                  clamped padding rather than a true neighbour;
                - 'metadata': {'original_batch_size', 'num_frames', 'window_size'}.

        Raises:
            ValueError: If ``local_context_window_size`` is negative.
        """
        k = local_context_window_size
        if k < 0:
            raise ValueError(f"local_context_window_size must be >= 0, got {k}")

        device = backbone_out["vision_features"].device
        num_frames = backbone_out["num_frames"]

        # ---- Part 1: global "all-perspectives" context ----
        # The helper builds the observer/target pairs and stores them as
        # backbone_out["spatial_out"].
        self.process_spatial_feature_all_perspectives(backbone_out)

        # ---- Part 2: per-slice local context ----
        def _reshape_to_volume(flat_tensor: torch.Tensor) -> torch.Tensor:
            # (B * num_frames, ...) -> (B, num_frames, ...)
            B_flat, *rest_dims = flat_tensor.shape
            return flat_tensor.view(B_flat // num_frames, num_frames, *rest_dims)

        fpn_feature_volume = [_reshape_to_volume(f) for f in backbone_out["backbone_fpn"]]
        B = backbone_out["vision_features"].shape[0] // num_frames
        N = num_frames

        window_size = 2 * k + 1
        local_offsets = torch.arange(-k, k + 1, device=device)
        slice_indices = torch.arange(N, device=device)
        neighbor_indices = slice_indices.view(N, 1) + local_offsets.view(1, window_size)  # (N, window)
        # Out-of-range neighbours are clamped, i.e. the edge slice is replicated.
        clamped_neighbor_indices = torch.clamp(neighbor_indices, 0, N - 1)
        # True where the window entry is a real slice rather than clamped padding.
        valid_neighbors = (neighbor_indices >= 0) & (neighbor_indices < N)

        new_batch_size_local = B * N
        final_local_context_fpn = []
        for fpn_level in fpn_feature_volume:
            # Advanced indexing with an (N, window) index tensor:
            # (B, N, FPN_D, FPN_H, FPN_W) -> (B, N, window, FPN_D, FPN_H, FPN_W)
            gathered_features = fpn_level[:, clamped_neighbor_indices, ...]
            # (B, N, window, ...) -> (B*N, window, ...)
            final_local_context_fpn.append(
                gathered_features.reshape(
                    new_batch_size_local, window_size, *gathered_features.shape[3:]
                )
            )

        relative_positions_local = local_offsets.view(1, -1).expand(new_batch_size_local, -1)
        # Same (b, n) -> b*N + n flattening order as the feature reshape above.
        local_valid_mask = (
            valid_neighbors.unsqueeze(0)
            .expand(B, -1, -1)
            .reshape(new_batch_size_local, window_size)
        )

        backbone_out["slice_spatial_out"] = {
            "local_context_fpn": final_local_context_fpn,
            "local_relative_positions": relative_positions_local,
            "local_valid_mask": local_valid_mask,
            "metadata": {
                "original_batch_size": B,
                "num_frames": N,
                "window_size": window_size,
            },
        }
        return backbone_out

    def process_spatial_feature_all_perspectives(self, backbone_out: dict) -> dict:
        """
        Pair every conditioning ("observer") slice with every slice ("target").

        Each core slice listed in ``backbone_out['init_cond_frames']`` is matched
        against all ``N`` slices, including itself, so the relative position 0
        marks the self-pair. Observer/target tensors are flattened into a batch
        dimension of ``B * C`` (C = number of core slices) for parallel processing.

        Args:
            backbone_out (dict): Must contain 'num_frames', 'init_cond_frames',
                'vision_features' and 'backbone_fpn', with tensors flattened
                as (B * num_frames, ...).

        Returns:
            dict: The result dict, which is also stored as ``backbone_out["spatial_out"]``.
                - 'observer_features' / 'target_features': (B*C, N, D, Hf, Wf)
                - 'observer_fpn' / 'target_fpn': per-FPN-level lists of
                  (B*C, N, FPN_D, FPN_H, FPN_W)
                - 'relative_positions': (B*C, N), target index minus observer index
                - 'metadata': {'original_batch_size', 'num_cores', 'num_frames'}
        """
        n_total = backbone_out["num_frames"]
        dev = backbone_out["vision_features"].device

        frame_range = torch.arange(n_total, device=dev)
        cores = torch.tensor(backbone_out["init_cond_frames"], dtype=torch.long, device=dev)

        def to_volume(flat: torch.Tensor) -> torch.Tensor:
            # Undo the (batch, frame) flattening: (B * N, ...) -> (B, N, ...).
            lead, *tail = flat.shape
            return flat.view(lead // n_total, n_total, *tail)

        fpn_volumes = [to_volume(level) for level in backbone_out["backbone_fpn"]]
        vit_volume = to_volume(backbone_out["vision_features"])

        batch, n_frames, _dim, _height, _width = vit_volume.shape
        n_cores = len(cores)
        flat_batch = batch * n_cores

        def pair_up(volume: torch.Tensor):
            # observers: (B, C, ...) -> (B, C, N, ...); targets: (B, N, ...) -> (B, C, N, ...)
            observers = volume[:, cores, ...].unsqueeze(2).expand(-1, -1, n_frames, -1, -1, -1)
            targets = volume.unsqueeze(1).expand(-1, n_cores, -1, -1, -1, -1)
            # Both flattened to (B*C, N, ...).
            return (
                observers.reshape(flat_batch, n_frames, *observers.shape[3:]),
                targets.reshape(flat_batch, n_frames, *targets.shape[3:]),
            )

        vit_observers, vit_targets = pair_up(vit_volume)
        fpn_pairs = [pair_up(level) for level in fpn_volumes]

        # Broadcast (1, 1, N) - (1, C, 1) -> (1, C, N), then tile over the batch.
        offsets = frame_range.view(1, 1, n_frames) - cores.view(1, n_cores, 1)
        offsets_flat = offsets.expand(batch, -1, -1).reshape(flat_batch, n_frames)

        spatial_out = {
            "observer_features": vit_observers,
            "target_features": vit_targets,
            "observer_fpn": [obs for obs, _ in fpn_pairs],
            "target_fpn": [tgt for _, tgt in fpn_pairs],
            "relative_positions": offsets_flat,
            "metadata": {
                "original_batch_size": batch,
                "num_cores": n_cores,
                "num_frames": n_frames,
            },
        }

        backbone_out["spatial_out"] = spatial_out
        return spatial_out

    def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
        """Run the backbone only for ``img_ids``, encoding each distinct image once.

        Returns:
            tuple: ``(image, vision_feats, vision_pos_embeds, feat_sizes)``, laid
            out in the order of the original ``img_ids`` (repeated ids included).
        """
        # A single id is trivially unique, so torch.unique is skipped in that case.
        if img_ids.numel() <= 1:
            distinct_ids, back_map = img_ids, None
        else:
            distinct_ids, back_map = torch.unique(img_ids, return_inverse=True)

        image = img_batch[distinct_ids]
        _, vision_feats, vision_pos_embeds, feat_sizes = self._prepare_backbone_features(
            self.forward_image(image)
        )

        if back_map is None:
            return image, vision_feats, vision_pos_embeds, feat_sizes

        # Expand the per-distinct-image results back to the requested ids; the
        # feature tensors are indexed along dim 1.
        return (
            image[back_map],
            [feat[:, back_map] for feat in vision_feats],
            [pos[:, back_map] for pos in vision_pos_embeds],
            feat_sizes,
        )

    def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
        """
        Prepare input mask, point or box prompts. Optionally, we allow tracking from
        a custom `start_frame_idx` to the end of the video (for evaluation purposes).

        ``backbone_out`` is mutated in place and returned with these keys:
          - "gt_masks_per_frame": {frame_idx: ``input.masks[frame_idx].unsqueeze(1)``}
          - "num_frames": ``input.num_frames``
          - "use_pt_input": whether point/box prompts (True) or mask prompts (False)
            are used
          - "init_cond_frames": the initial conditioning frames. The first one is
            ``start_frame_idx + 1``, clamped to the last frame of the clip; any others
            are drawn at random (without replacement) from the frames after it.
          - "frames_not_in_init_cond": frames in ``[start_frame_idx, num_frames)`` that
            are not initial conditioning frames
          - "mask_inputs_per_frame" / "point_inputs_per_frame": prompts keyed by
            conditioning-frame index
          - "frames_to_add_correction_pt": frames on which correction clicks are
            sampled on the fly

        Args:
            backbone_out: dict that receives the prompt information.
            input: batch exposing ``masks`` (one entry per frame) and ``num_frames``.
            start_frame_idx: first frame of the tracked range.
        """
        # Load the ground-truth masks on all frames (so that we can later
        # sample correction points from them)
        gt_masks_per_frame = {
            stage_id: masks.unsqueeze(1)  # [B, 1, H_im, W_im]
            for stage_id, masks in enumerate(input.masks)
        }
        backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
        num_frames = input.num_frames
        backbone_out["num_frames"] = num_frames

        # Randomly decide whether to use point inputs or mask inputs
        if self.training:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
            prob_to_use_box_input = self.prob_to_use_box_input_for_train
            num_frames_to_correct = self.num_frames_to_correct_for_train
            rand_frames_to_correct = self.rand_frames_to_correct_for_train
            num_init_cond_frames = self.num_init_cond_frames_for_train
            rand_init_cond_frames = self.rand_init_cond_frames_for_train
        else:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
            prob_to_use_box_input = self.prob_to_use_box_input_for_eval
            num_frames_to_correct = self.num_frames_to_correct_for_eval
            rand_frames_to_correct = self.rand_frames_to_correct_for_eval
            num_init_cond_frames = self.num_init_cond_frames_for_eval
            rand_init_cond_frames = self.rand_init_cond_frames_for_eval
        if num_frames == 1:
            # here we handle a special case for mixing video + SAM on image training,
            # where we force using point input for the SAM task on static images
            prob_to_use_pt_input = 1.0
            num_frames_to_correct = 1
            num_init_cond_frames = 1
        assert num_init_cond_frames >= 1
        # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)

        # Choose num_init_cond_frames and num_frames_to_correct
        use_pt_input = self.rng.random() < prob_to_use_pt_input
        if rand_init_cond_frames and num_init_cond_frames > 1:
            # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
            num_init_cond_frames = self.rng.integers(
                1, num_init_cond_frames, endpoint=True
            )
        if (
            use_pt_input
            and rand_frames_to_correct
            and num_frames_to_correct > num_init_cond_frames
        ):
            # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
            # correction clicks (only for the case of point input)
            num_frames_to_correct = self.rng.integers(
                num_init_cond_frames, num_frames_to_correct, endpoint=True
            )
        backbone_out["use_pt_input"] = use_pt_input

        # Sample initial conditioning frames. The first one is the frame right after
        # `start_frame_idx`, clamped to the last frame. Without the clamp, a clip with
        # no frame after `start_frame_idx` (e.g. a single static image) would index a
        # frame that does not exist.
        sample_start_frame_idx = min(start_frame_idx + 1, num_frames - 1)
        if num_init_cond_frames == 1:
            init_cond_frames = [sample_start_frame_idx]  # starting frame
        else:
            # starting frame + randomly selected later frames (without replacement)
            init_cond_frames = [sample_start_frame_idx] + self.rng.choice(
                range(sample_start_frame_idx + 1, num_frames),
                num_init_cond_frames - 1,
                replace=False,
            ).tolist()

        backbone_out["init_cond_frames"] = init_cond_frames
        backbone_out["frames_not_in_init_cond"] = [
            t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
        ]

        # Prepare mask or point inputs on initial conditioning frames
        backbone_out["mask_inputs_per_frame"] = {}  # {frame_idx: <input_masks>}
        backbone_out["point_inputs_per_frame"] = {}  # {frame_idx: <input_points>}
        for t in init_cond_frames:
            if not use_pt_input:
                backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
            else:
                # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input
                use_box_input = self.rng.random() < prob_to_use_box_input
                if use_box_input:
                    points, labels = sample_box_points(
                        gt_masks_per_frame[t],
                    )
                else:
                    # (here we only sample **one initial point** on initial conditioning frames from the
                    # ground-truth mask; we may sample more correction points on the fly)
                    points, labels = get_next_point(
                        gt_masks=gt_masks_per_frame[t],
                        pred_masks=None,
                        method=(
                            "uniform" if self.training else self.pt_sampling_for_eval
                        ),
                    )

                point_inputs = {"point_coords": points, "point_labels": labels}
                backbone_out["point_inputs_per_frame"][t] = point_inputs

        # Sample frames where we will add correction clicks on the fly
        # based on the error between prediction and ground-truth masks
        if not use_pt_input:
            # no correction points will be sampled when using mask inputs
            frames_to_add_correction_pt = []
        elif num_frames_to_correct == num_init_cond_frames:
            frames_to_add_correction_pt = init_cond_frames
        else:
            assert num_frames_to_correct > num_init_cond_frames
            # initial cond frame + randomly selected remaining frames (without replacement)
            extra_num = num_frames_to_correct - num_init_cond_frames
            frames_to_add_correction_pt = (
                init_cond_frames
                + self.rng.choice(
                    backbone_out["frames_not_in_init_cond"], extra_num, replace=False
                ).tolist()
            )
        backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

        return backbone_out

    def forward_tracking(
        self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
    ):
        """
        Forward video tracking on each frame (and sample correction clicks).

        Frames are processed in the order ``init_cond_frames`` followed by
        ``frames_not_in_init_cond`` (both read from ``backbone_out``). Each frame's
        output is stored in ``output_dict`` before the next frame is tracked, so
        conditioning frames are available as memory for the frames that follow.

        Args:
            backbone_out: dict filled by earlier stages. Keys read here:
                "backbone_fpn", "num_frames", "init_cond_frames",
                "frames_not_in_init_cond", "frames_to_add_correction_pt",
                "point_inputs_per_frame", "mask_inputs_per_frame" and
                "gt_masks_per_frame". It also reads "spatial_out" and
                "slice_spatial_out" when "backbone_fpn" is not None, and
                "boundary_logits" when ``self.use_boundary`` is set during training.
            input: batch providing ``flat_obj_to_img_idx`` (image ids per frame)
                and, when backbone features were not precomputed,
                ``flat_img_batch``.
            return_dict: if True, return the raw ``output_dict``.

        Returns:
            If ``return_dict`` is True, ``output_dict`` with "cond_frame_outputs" and
            "non_cond_frame_outputs" sub-dicts keyed by frame index. Otherwise, a list
            with one output dict per frame index in ``range(num_frames)``, with the
            "obj_ptr" key removed.
        """
        img_feats_already_computed = backbone_out["backbone_fpn"] is not None
        if img_feats_already_computed:
            # Prepare the backbone features
            # - vision_feats and vision_pos_embeds are in (HW)BC format
            (
                _,
                vision_feats,
                vision_pos_embeds,
                feat_sizes,
            ) = self._prepare_backbone_features(backbone_out)

            # NOTE(review): `prepared_spatial_features` (global context) is computed here
            # but never used again in this method.
            prepared_spatial_features = self._prepare_spatial_features(backbone_out["spatial_out"])  # global
            prepared_slice_spatial_features = self._prepare_local_spatial_features(backbone_out["slice_spatial_out"])  # local context sliding window

            prepared_local_context_fpn = prepared_slice_spatial_features['prepared_local_context_fpn']
            # The backbone `feat_sizes` is replaced by the local-context feature sizes.
            feat_sizes = prepared_slice_spatial_features['feat_sizes']
            local_relative_positions = prepared_slice_spatial_features['local_relative_positions']
            # NOTE(review): `metadata` is only referenced by the commented-out logging below.
            metadata = prepared_slice_spatial_features['metadata']

        # Starting the stage loop
        num_frames = backbone_out["num_frames"]
        init_cond_frames = backbone_out["init_cond_frames"]
        frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
        # first process all the initial conditioning frames to encode them as memory,
        # and then conditioning on them to track the remaining frames
        processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
        output_dict = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }

        # logging.info(f"processing_order:{processing_order}")
        # Local-context inputs for track_step. They are only refilled per frame when
        # backbone features were precomputed; otherwise they remain empty lists.
        current_local_context_vision_feats = []
        current_local_context_vision_relative_pos = []

        # process input data per frame
        for stage_id in processing_order:
            # Get the image features for the current frames
            # img_ids = input.find_inputs[stage_id].img_ids
            img_ids = input.flat_obj_to_img_idx[stage_id]
            if img_feats_already_computed:
                # Retrieve image features according to img_ids (if they are already computed).
                current_vision_feats = [x[:, img_ids] for x in vision_feats]
                current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
                # logging.info(
                #     f"current_vision_feats:{len(current_vision_feats)}, shape:{current_vision_feats[0].shape}, {current_vision_feats[1].shape}, {current_vision_feats[2].shape}")
                # logging.info(f"current_vision_pos_embeds:{len(current_vision_pos_embeds)}, shape:{current_vision_pos_embeds[0].shape}")

                # Select the local-context features and relative positions of the
                # objects/slices in this stage (image index on dim 1 for the features,
                # dim 0 for the relative positions).
                current_local_context_vision_feats = [
                    fpn_level[:, img_ids] for fpn_level in prepared_local_context_fpn
                ]
                current_local_context_vision_relative_pos = local_relative_positions[img_ids]

                # --- Debug logging for verification (disabled) ---
                # logging.info(f"--- Processing Stage {stage_id} ---")
                # logging.info(f"Number of objects/slices in this stage: {len(img_ids)}")

                # if current_local_context_vision_feats:
                #     # Print the shape of the first FPN level as a representative
                #     fpn_shape = current_local_context_vision_feats[0].shape
                #     logging.info(f"Shape of current_local_context_vision_feats[0]: {fpn_shape}")
                #     # Expected shape: (SeqLen, K, FeatureDim)
                #     # e.g. (Win*H*W, num_objects_in_stage, FPN_D)

                #     # More detailed logging
                #     win = metadata['window_size']
                #     h, w = feat_sizes[0]
                #     logging.info(f"  - Verified SeqLen ({win}*{h}*{w}): {fpn_shape[0]}")
                #     logging.info(f"  - Verified Batch Dim (num objects): {fpn_shape[1]}")
                #     logging.info(f"  - Verified Feature Dim: {fpn_shape[2]}")

                # logging.info(f"Shape of current_local_context_vision_relative_pos: {current_local_context_vision_relative_pos.shape}")
                # logging.info(f"Content of current_local_context_vision_relative_pos:\n{current_local_context_vision_relative_pos}")

            else:
                # Otherwise, compute the image features on the fly for the given img_ids
                # (this might be used for evaluation on long videos to avoid backbone OOM).
                (
                    _,
                    current_vision_feats,
                    current_vision_pos_embeds,
                    feat_sizes,
                ) = self._prepare_backbone_features_per_frame(
                    input.flat_img_batch, img_ids
                )

            # Get output masks based on this frame's prompts and previous memory
            current_out = self.track_step(
                frame_idx=stage_id,
                is_init_cond_frame=stage_id in init_cond_frames,
                current_vision_feats=current_vision_feats,
                current_vision_pos_embeds=current_vision_pos_embeds,
                current_local_context_vision_feats=current_local_context_vision_feats,
                current_local_context_vision_relative_pos=current_local_context_vision_relative_pos,
                feat_sizes=feat_sizes,
                point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
                mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
                gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
                frames_to_add_correction_pt=frames_to_add_correction_pt,
                output_dict=output_dict,
                num_frames=num_frames,
            )
            # Append the output, depending on whether it's a conditioning frame
            add_output_as_cond_frame = stage_id in init_cond_frames or (
                self.add_all_frames_to_correct_as_cond
                and stage_id in frames_to_add_correction_pt
            )
            if add_output_as_cond_frame:
                output_dict["cond_frame_outputs"][stage_id] = current_out
            else:
                output_dict["non_cond_frame_outputs"][stage_id] = current_out

        # # full batch frame prediction
        # predicted_pos_logits = self.pos_prediction_head(
        #     observer_feature=extra_rpe_feature['observer_pooled'],
        #     target_features=extra_rpe_feature['target_pooled']
        # )  # Output Shape: (B*C, N, num_classes)

        if return_dict:
            return output_dict
        # turn `output_dict` into a list for loss function
        all_frame_outputs = {}
        all_frame_outputs.update(output_dict["cond_frame_outputs"])
        all_frame_outputs.update(output_dict["non_cond_frame_outputs"])

        # Attach the same boundary logits (computed once per batch upstream) to every
        # frame's output so the loss can access them.
        if self.use_boundary and self.training:
            # print(f"all_frame_outputs:{all_frame_outputs.keys()}")
            for key in all_frame_outputs:
                all_frame_outputs[key]['boundary_logits'] = backbone_out['boundary_logits']

        # # return predicted_pos_logits info
        # if self.use_volume and self.training:
        #     for key in all_frame_outputs:
        #         all_frame_outputs[key]['predicted_pos_logits'] = predicted_pos_logits
        #         all_frame_outputs[key]['relative_positions'] = backbone_out['relative_positions']

        # Order outputs by frame index (raises KeyError if a frame was not processed).
        all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
        # Make DDP happy with activation checkpointing by removing unused keys
        all_frame_outputs = [
            {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
        ]

        # # deliver boundary logits to final part
        # print(f"type:{type(all_frame_outputs[0])}, keys:{all_frame_outputs[0].keys()}")
        # print(f"shape of object_token:{all_frame_outputs[0]['object_token'].shape}")

        return all_frame_outputs

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        current_local_context_vision_feats,
        current_local_context_vision_relative_pos,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        run_mem_encoder=True,  # Whether to run the memory encoder on the predicted masks.
        prev_sam_mask_logits=None,  # The previously predicted SAM mask logits.
        frames_to_add_correction_pt=None,
        gt_masks=None,
    ):
        """
        Track one frame and optionally refine it with sampled correction clicks.

        The frame is first run through ``self._track_step``. If ``frame_idx`` is in
        ``frames_to_add_correction_pt``, ``self._iter_correct_pt_sampling`` then adds
        correction clicks sampled against ``gt_masks``.

        The final "pred_masks", "pred_masks_high_res", "obj_ptr" and "object_token"
        entries all come from the last SAM decoder step, i.e. after any correction
        clicks. The "multistep_*" entries keep the per-step history. Finally, the
        predicted mask is passed to ``self._encode_memory_in_output``.

        Returns:
            ``current_out``: the per-frame output dict.
        """
        if frames_to_add_correction_pt is None:
            frames_to_add_correction_pt = []
        current_out, sam_outputs, high_res_features, pix_feat = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            current_local_context_vision_feats,
            current_local_context_vision_relative_pos,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
            object_token    # new for modifications on loss
        ) = sam_outputs

        current_out["object_token"] = object_token
        current_out["multistep_pred_masks"] = low_res_masks
        current_out["multistep_pred_masks_high_res"] = high_res_masks
        current_out["multistep_pred_multimasks"] = [low_res_multimasks]
        current_out["multistep_pred_multimasks_high_res"] = [high_res_multimasks]
        current_out["multistep_pred_ious"] = [ious]
        current_out["multistep_point_inputs"] = [point_inputs]
        current_out["multistep_object_score_logits"] = [object_score_logits]

        # Optionally, sample correction points iteratively to correct the mask
        if frame_idx in frames_to_add_correction_pt:
            point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
                is_init_cond_frame,
                point_inputs,
                gt_masks,
                high_res_features,
                pix_feat,
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                object_score_logits,
                current_out,
            )
            (
                _,
                _,
                _,
                low_res_masks,
                high_res_masks,
                obj_ptr,
                object_score_logits,
                object_token,
            ) = final_sam_outputs
            # Keep the object token consistent with the final masks / obj_ptr below
            # instead of the pre-correction token.
            current_out["object_token"] = object_token

        # Use the final prediction (after all correction steps for output and eval)
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        return current_out

    def _iter_correct_pt_sampling(
        self,
        is_init_cond_frame,
        point_inputs,
        gt_masks,
        high_res_features,
        pix_feat_with_mem,
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        """
        Refine a frame's prediction with ``self.num_correction_pt_per_frame`` clicks.

        Each step does three things:
          1. Samples a new click from the error between the current high-res
             prediction and ``gt_masks``. In training, with probability
             ``self.prob_to_sample_from_gt_for_train``, the click is drawn from
             ``gt_masks`` alone.
          2. Appends the click to the accumulated point prompts.
          3. Re-runs the SAM heads with those points, using the previous step's
             low-res logits as the mask prompt.

        The per-step histories (step 0 is the prediction before any click) are written
        into ``current_out`` under the "multistep_*" keys.

        Returns:
            ``(point_inputs, sam_outputs)``: the accumulated point prompts and the SAM
            outputs of the last step.
        """
        assert gt_masks is not None
        # Entry 0 of every history is the prediction made before any correction click.
        history_low_res = [low_res_masks]
        history_high_res = [high_res_masks]
        history_low_res_multi = [low_res_multimasks]
        history_high_res_multi = [high_res_multimasks]
        history_ious = [ious]
        history_points = [point_inputs]
        history_obj_scores = [object_score_logits]

        for _step in range(self.num_correction_pt_per_frame):
            # With a small probability (training only), sample directly from the GT mask
            # instead of from the prediction/GT error region.
            draw_from_gt_only = False
            if self.training and self.prob_to_sample_from_gt_for_train > 0:
                draw_from_gt_only = (
                    self.rng.random() < self.prob_to_sample_from_gt_for_train
                )
            click_coords, click_labels = get_next_point(
                gt_masks=gt_masks,
                pred_masks=(None if draw_from_gt_only else (high_res_masks > 0)),
                method=("uniform" if self.training else self.pt_sampling_for_eval),
            )
            point_inputs = concat_points(point_inputs, click_coords, click_labels)

            # The previous step's low-res mask logits are fed back to the decoder together
            # with the accumulated clicks.
            head_kwargs = dict(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=low_res_masks,
                high_res_features=high_res_features,
                multimask_output=self._use_multimask(is_init_cond_frame, point_inputs),
            )
            checkpoint_this_step = (
                self.use_act_ckpt_iterative_pt_sampling
                and not head_kwargs["multimask_output"]
            )
            if checkpoint_this_step:
                sam_outputs = torch.utils.checkpoint.checkpoint(
                    self._forward_sam_heads, use_reentrant=False, **head_kwargs
                )
            else:
                sam_outputs = self._forward_sam_heads(**head_kwargs)

            (
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                _obj_ptr,
                object_score_logits,
                _object_token,
            ) = sam_outputs

            history_low_res.append(low_res_masks)
            history_high_res.append(high_res_masks)
            history_low_res_multi.append(low_res_multimasks)
            history_high_res_multi.append(high_res_multimasks)
            history_ious.append(ious)
            history_points.append(point_inputs)
            history_obj_scores.append(object_score_logits)

        # Masks are concatenated along the channel dim so losses can be computed on
        # every step (see `MultiStepIteractiveMasks`).
        current_out["multistep_pred_masks"] = torch.cat(history_low_res, dim=1)
        current_out["multistep_pred_masks_high_res"] = torch.cat(history_high_res, dim=1)
        current_out["multistep_pred_multimasks"] = history_low_res_multi
        current_out["multistep_pred_multimasks_high_res"] = history_high_res_multi
        current_out["multistep_pred_ious"] = history_ious
        current_out["multistep_point_inputs"] = history_points
        current_out["multistep_object_score_logits"] = history_obj_scores

        return point_inputs, sam_outputs

    def _fuse_spatial_context(
        self,
        local_context_features: list[torch.Tensor],
        relative_positions: torch.Tensor,
        feat_sizes: list[tuple[int, int]],
    ):
        """
        Fuses spatial context features using our RPEAttention layer.
        """
        context_sequence = local_context_features[-1]
        H, W = feat_sizes[-1]

        # 1. 准备 Query, Key, Value
        Win = context_sequence.size(0) // (H * W)
        center_slice_start_idx = (Win // 2) * (H * W)
        center_slice_end_idx = center_slice_start_idx + (H * W)
        query_features = context_sequence[center_slice_start_idx:center_slice_end_idx, :, :]
        key_value_features = context_sequence

        fused_features_seq = self.spatial_attention_layer(
            query=query_features,
            key=key_value_features,
            value=key_value_features,
            relative_positions=relative_positions,
            H=H,
            W=W
        )
        # fused_features_seq shape: (H*W, B_stage, D)

        # 3. 重塑为图像特征格式
        B_stage, D = fused_features_seq.size(1), fused_features_seq.size(2)
        fused_features_map = fused_features_seq.permute(1, 2, 0).view(B_stage, D, H, W)

        return fused_features_map, context_sequence

    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        current_local_context_vision_feats,
        current_local_context_vision_relative_pos,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs, "pos_logits": None, "pos_target": None}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None

        # fused the visual feature with previous memory features in the memory bank
        single_pix_feat = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            current_vision_pos_embeds=current_vision_pos_embeds[-1:],
            feat_sizes=feat_sizes[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
            track_in_reverse=track_in_reverse,
        )

        # fusion feature on spatial feature and their relative pos
        fused_pix_feat = single_pix_feat
        if current_local_context_vision_feats != [] and current_local_context_vision_relative_pos != []:
            spatial_context_feat, context_sequence = self._fuse_spatial_context(
                local_context_features=current_local_context_vision_feats,
                relative_positions=current_local_context_vision_relative_pos,
                feat_sizes=feat_sizes,
            )
            # --- 融合模块 ---
            fused_pix_feat = single_pix_feat + spatial_context_feat
            # logging.info(f"single_pix_feat:{single_pix_feat.shape}")
            # logging.info(f"spatial_context_feat:{spatial_context_feat.shape}")

        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                fused_pix_feat,    # using new fused feature here
                high_res_features,
                mask_inputs
            )
        else:
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=fused_pix_feat,   # using new fused feature here
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )

        # --- 辅助任务: 计算位置预测损失 ---
        # 只有在训练时才计算损失
        if self.training:
            H, W = feat_sizes[-1]

            # 2. 调用预测头，输入是原始的上下文序列
            pos_logits = self.position_prediction_head(
                context_feature_sequence=context_sequence,
                H=H,
                W=W
            )

            self.window_radius_k = 2
            targets = current_local_context_vision_relative_pos + self.window_radius_k
            B_stage, Win, C = pos_logits.shape

            current_out['pos_logits'] = pos_logits
            current_out['pos_target'] = targets

            # loss_pos = self.position_loss_fn(
            #     pos_logits.view(B_stage * Win, C),  # (B*Win, C)
            #     targets.view(B_stage * Win).long()  # (B*Win,)
            # )

            # self._visualize_predictions(pos_logits, targets, frame_idx)
            # logging.info(f"Frame {frame_idx}, Auxiliary Relative Position Loss: {loss_pos.item():.4f}")
            # exit(0)

        return current_out, sam_outputs, high_res_features, fused_pix_feat

    def _visualize_predictions(
        self,
        pos_logits: torch.Tensor,
        targets: torch.Tensor,
        frame_idx: int,
        max_items_to_show: int = 4
    ):
        """
        Prints a formatted table to visualize the predictions of the RelativePositionHead
        against the ground truth targets.

        Args:
            pos_logits (torch.Tensor): Raw logits from the prediction head.
                                    Shape: (B_stage, Win, num_classes).
            targets (torch.Tensor): Ground truth class indices. Shape: (B_stage, Win).
            frame_idx (int): The index of the current central frame being processed.
            max_items_to_show (int): Maximum number of batch items to visualize to avoid
                                    spamming the console.
        """
        # Detach tensors from the computation graph and move to CPU for processing
        pos_logits = pos_logits.detach().cpu()
        targets = targets.detach().cpu()

        # Calculate predicted indices and confidence
        probs = F.softmax(pos_logits, dim=-1)
        confidences, predicted_indices = torch.max(probs, dim=-1)

        B_stage, Win = targets.shape
        num_to_show = min(B_stage, max_items_to_show)

        logging.info("=" * 80)
        logging.info(f"VISUALIZING RELATIVE POSITION PREDICTIONS (Frame: {frame_idx})")
        logging.info(f"Showing first {num_to_show} of {B_stage} items in the batch.")
        logging.info("=" * 80)

        for i in range(num_to_show):
            logging.info(f"--- Batch Item {i+1} / {B_stage} ---")
            header = (
                f"{'Slice in Window':<18} | {'Ground Truth':<15} | {'Prediction':<15} | "
                f"{'Confidence':<12} | {'Correct?':<10}"
            )
            logging.info(header)
            logging.info("-" * len(header))

            for j in range(Win):
                gt_index = targets[i, j].item()
                pred_index = predicted_indices[i, j].item()
                confidence = confidences[i, j].item()

                # Convert class indices back to relative positions [-k, +k]
                gt_pos = gt_index - self.window_radius_k
                pred_pos = pred_index - self.window_radius_k

                is_correct = (gt_index == pred_index)
                correct_symbol = "✅" if is_correct else "❌"

                log_line = (
                    f"Slice {j+1:<12d} | Pos: {gt_pos:<+10d} | Pos: {pred_pos:<+10d} | "
                    f"{confidence:^12.2%} | {correct_symbol:^10}"
                )
                logging.info(log_line)
            logging.info("-" * 80)

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Forward SAM prompt encoders and mask heads.

        Inputs:
        - backbone_features: image features of [B, C, H, W] shape, where C must equal
          `self.sam_prompt_embed_dim` and H == W == `self.sam_image_embedding_size`.
        - point_inputs: a dictionary with "point_coords" and "point_labels", where
          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
             absolute pixel-unit coordinate in (x, y) format of the P input points
          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
             positive clicks, 0 means negative clicks, and -1 means padding
          If None, a single padding point (label -1) is used for every batch item.
        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
          same spatial size as the image. It is resized (antialiased bilinear) to the
          prompt encoder's `mask_input_size` when needed; bool masks are always cast
          to float before reaching the prompt encoder.
        - high_res_features: either 1) None or 2) a list of length 2 containing
          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
          which will be used as high-resolution feature maps for SAM decoder.
        - multimask_output: if it's True, we output 3 candidate masks and their 3
          corresponding IoU estimates, and if it's False, we output only 1 mask and
          its corresponding IoU estimate.

        Outputs (an 8-tuple, in this order):
        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
          output mask logits (before sigmoid) for the low-resolution masks, with 4x
          the resolution (1/4 stride) of the input backbone_features. When
          `self.pred_obj_scores` is set, masks of items whose object score logit is
          <= 0 are replaced by NO_OBJ_SCORE.
        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
          if `multimask_output=True` and M = 1 if `multimask_output=False`),
          bilinearly upsampled from the low-resolution masks to the image size
          (`self.image_size` x `self.image_size`, i.e. 1-pixel stride).
        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
          if `multimask_output=False`), the estimated IoU of each output mask.
        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `low_res_multimasks`.
        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `high_res_multimasks`.
        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, obtained
          by `self.obj_ptr_proj` on the selected SAM output token. When
          `self.pred_obj_scores` is set it is blended towards `self.no_obj_ptr` by
          (1 - p), where p is sigmoid(object_score_logits) if `self.soft_no_obj_ptr`
          else the hard (logit > 0) indicator; with `self.fixed_no_obj_ptr` the
          projected pointer is additionally scaled by p.
        - object_score_logits: the object-presence logits returned by the SAM mask
          decoder (presumably [B, 1] — it is broadcast against [B, C] obj_ptr; confirm).
        - sam_output_token: the raw SAM decoder output token before `obj_ptr_proj`.
          With `multimask_output=True` and several output tokens, it is the token of
          the highest-IoU mask; otherwise it is the first output token.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provided, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            elif mask_inputs.dtype == torch.bool:
                # A bool mask that already has the right size used to be passed through
                # unchanged; cast it to float like the resize branch above does, since
                # the prompt encoder's learned mask layers presumably cannot consume bool.
                sam_mask_prompt = mask_inputs.float()
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        # PROMPT_ENCODER
        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )

        # MASK_DECODER
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                # the decoder emitted one token per mask: keep the best mask's token
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
            # raw decoder token (before obj_ptr_proj), appended as the last element
            sam_output_token,
        )

    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        """
        Turn a binary `mask_inputs` directly into output mask logits, bypassing SAM's
        own mask prediction. The returned 8-tuple has the same layout as
        `_forward_sam_heads`.

        Inputs:
        - backbone_features, high_res_features: only used when
          `self.use_obj_ptrs_in_encoder` is True. In that case they are forwarded to
          `_forward_sam_heads`, together with the downsampled mask as a prompt, to
          derive the object pointer.
        - mask_inputs: mask tensor whose first dim is the batch size B. Pixels > 0
          count as foreground. (Assumed [B, 1, H, W]; the bilinear interpolation
          below requires a 4-D tensor.)

        Outputs (8-tuple):
        - low_res logits (returned twice, in the "multimask" and "best mask" slots):
          the high-res logits downsampled by 4 in each spatial dim.
        - high_res logits (returned twice): +10 on foreground pixels, -10 elsewhere.
        - ious: [B, 1] tensor of ones (dummy IoU).
        - obj_ptr: [B, C] object pointer. When `self.pred_obj_scores` is set, it is
          replaced by `self.no_obj_ptr` for items whose mask is empty.
        - object_score_logits: [B, 1], +10 if the item's mask has any foreground
          pixel, else -10.
        - object token: zeros shaped like obj_ptr when object pointers are not used
          in the encoder, otherwise the SAM output token from `_forward_sam_heads`.
        """
        # {0, 1} -> {-10, +10} logits; sigmoid(-10.0) = 4.5398e-05, i.e. almost 0/1.
        logit_scale, logit_bias = 20.0, -10.0
        as_float = mask_inputs.float()
        batch_size = mask_inputs.size(0)

        hi_res = as_float * logit_scale + logit_bias
        lo_h, lo_w = hi_res.size(-2) // 4, hi_res.size(-1) // 4
        lo_res = F.interpolate(
            hi_res,
            size=(lo_h, lo_w),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # antialiasing because this is a downsampling step
        )

        # The mask is trusted as-is, so the IoU estimate is a constant 1.
        dummy_ious = mask_inputs.new_ones(batch_size, 1).float()

        if self.use_obj_ptrs_in_encoder:
            # Let the SAM heads build a pointer from the mask used as a dense prompt.
            _, _, _, _, _, obj_ptr, _, token = self._forward_sam_heads(
                backbone_features=backbone_features,
                mask_inputs=self.mask_downsample(as_float),
                high_res_features=high_res_features,
            )
        else:
            # Object pointers are unused: emit [B, hidden_dim] zero placeholders.
            obj_ptr = torch.zeros(
                batch_size, self.hidden_dim, device=mask_inputs.device
            )
            token = torch.zeros_like(obj_ptr)  # same shape and device as obj_ptr

        # Since the mask itself is treated as the output, object presence is also
        # decided from the mask (any positive pixel) rather than from SAM's object score.
        presence = (
            (mask_inputs.flatten(1).float() > 0.0).any(dim=1, keepdim=True).float()
        )
        object_score_logits = logit_scale * presence + logit_bias

        if self.pred_obj_scores:
            if self.fixed_no_obj_ptr:
                obj_ptr = presence * obj_ptr
            obj_ptr = obj_ptr + (1 - presence) * self.no_obj_ptr

        return (
            lo_res,
            hi_res,
            dummy_ious,
            lo_res,
            hi_res,
            obj_ptr,
            object_score_logits,
            token,
        )
