# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch
import torch.distributed
from sam2.modeling.sam2_base import SAM2Base
from sam2.modeling.sam2_utils import (
    get_1d_sine_pe,
    get_next_point,
    sample_box_points,
    select_closest_cond_frames,
)

from sam2.utils.misc import concat_points

from training.utils.data_utils import BatchedVideoDatapoint

from sam2.modeling.boundary_module import BoundaryHead
from sam2.modeling.spatial_module import *
from training.utils.spatial_utils import RelativePositionLoss


class SAM2Train(SAM2Base):
    def __init__(
        self,
        image_encoder,
        memory_attention=None,
        memory_encoder=None,
        prob_to_use_pt_input_for_train=0.0,
        prob_to_use_pt_input_for_eval=0.0,
        prob_to_use_box_input_for_train=0.0,
        prob_to_use_box_input_for_eval=0.0,
        # if it is greater than 1, we do interactive point sampling on the 1st frame and on other randomly selected frames
        num_frames_to_correct_for_train=1,  # default: only iteratively sample on first frame
        num_frames_to_correct_for_eval=1,  # default: only iteratively sample on first frame
        rand_frames_to_correct_for_train=False,
        rand_frames_to_correct_for_eval=False,
        # how many frames to use as initial conditioning frames (for both point input and mask input; the first frame is always used as an initial conditioning frame)
        # - if `rand_init_cond_frames` below is True, we randomly sample 1~num_init_cond_frames initial conditioning frames
        # - otherwise we sample a fixed number of num_init_cond_frames initial conditioning frames
        # note: for point input, we sample correction points on all such initial conditioning frames, and we require that `num_frames_to_correct` >= `num_init_cond_frames`;
        # these are initial conditioning frames because as we track the video, more conditioning frames might be added
        # when a frame receives correction clicks under point input if `add_all_frames_to_correct_as_cond=True`
        num_init_cond_frames_for_train=1,  # default: only use the first frame as initial conditioning frame
        num_init_cond_frames_for_eval=1,  # default: only use the first frame as initial conditioning frame
        # default: random 1~num_init_cond_frames_for_train cond frames (to be consistent w/ previous TA data loader)
        rand_init_cond_frames_for_train=True,
        rand_init_cond_frames_for_eval=False,
        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only those initial conditioning frames
        add_all_frames_to_correct_as_cond=False,
        # how many additional correction points to sample (on each frame selected to be corrected)
        # note that the first frame receives an initial input click (in addition to any correction clicks)
        num_correction_pt_per_frame=7,
        # method for point sampling during evaluation
        # "uniform" (sample uniformly from error region) or "center" (use the point with the largest distance to error region boundary)
        # default to "center" to be consistent with evaluation in the SAM paper
        pt_sampling_for_eval="center",
        # During training, we optionally allow sampling the correction points from GT regions
        # instead of the prediction error regions with a small probability. This might allow the
        # model to overfit less to the error regions in training datasets
        prob_to_sample_from_gt_for_train=0.0,
        use_act_ckpt_iterative_pt_sampling=False,
        # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
        # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
        forward_backbone_per_frame_for_eval=False,
        freeze_image_encoder=True,
        boundary_head=None,
        **kwargs,
    ):
        """Set up the training wrapper around the base model.

        ``image_encoder``, ``memory_attention``, ``memory_encoder`` and any
        extra ``kwargs`` are forwarded unchanged to the base-class constructor.
        The remaining arguments (documented inline in the signature) configure
        prompt sampling and conditioning-frame selection and are stored on the
        instance as attributes of the same name.

        Args:
            freeze_image_encoder: When true, ``requires_grad`` is switched off
                for every parameter of ``self.image_encoder``.
            boundary_head: Optional module stored as ``self.boundary_head``.
                The boundary branch (``self.use_boundary``) is enabled only
                when a head is actually supplied.
        """
        super().__init__(image_encoder, memory_attention, memory_encoder, **kwargs)
        self.use_act_ckpt_iterative_pt_sampling = use_act_ckpt_iterative_pt_sampling
        self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval

        # Point sampler and conditioning frames
        self.prob_to_use_pt_input_for_train = prob_to_use_pt_input_for_train
        self.prob_to_use_box_input_for_train = prob_to_use_box_input_for_train
        self.prob_to_use_pt_input_for_eval = prob_to_use_pt_input_for_eval
        self.prob_to_use_box_input_for_eval = prob_to_use_box_input_for_eval
        if prob_to_use_pt_input_for_train > 0 or prob_to_use_pt_input_for_eval > 0:
            logging.info(
                f"Training with points (sampled from masks) as inputs with p={prob_to_use_pt_input_for_train}"
            )
            assert num_frames_to_correct_for_train >= num_init_cond_frames_for_train
            assert num_frames_to_correct_for_eval >= num_init_cond_frames_for_eval

        self.num_frames_to_correct_for_train = num_frames_to_correct_for_train
        self.num_frames_to_correct_for_eval = num_frames_to_correct_for_eval
        self.rand_frames_to_correct_for_train = rand_frames_to_correct_for_train
        self.rand_frames_to_correct_for_eval = rand_frames_to_correct_for_eval
        # Initial multi-conditioning frames
        self.num_init_cond_frames_for_train = num_init_cond_frames_for_train
        self.num_init_cond_frames_for_eval = num_init_cond_frames_for_eval
        self.rand_init_cond_frames_for_train = rand_init_cond_frames_for_train
        self.rand_init_cond_frames_for_eval = rand_init_cond_frames_for_eval
        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
        self.num_correction_pt_per_frame = num_correction_pt_per_frame
        self.pt_sampling_for_eval = pt_sampling_for_eval
        self.prob_to_sample_from_gt_for_train = prob_to_sample_from_gt_for_train
        # A random number generator with a fixed initial seed across GPUs
        self.rng = np.random.default_rng(seed=42)

        # ====== MODIFICATIONS ======
        # ------ boundary feature -----
        # Enable the boundary branch only when a head exists; with the default
        # `boundary_head=None` an unconditional `True` makes `forward_image`
        # call `None(...)` and crash.
        self.boundary_head = boundary_head
        self.use_boundary = boundary_head is not None
        if not self.use_boundary:
            logging.warning("No boundary_head provided; boundary branch is disabled.")

        # ------ volume feature -------
        self.use_volume = True
        # NOTE(review): RPEAttention / RelativePositionHead are presumably
        # provided by the star import of sam2.modeling.spatial_module -- confirm.
        self.spatial_attention_layer = RPEAttention()
        self.position_prediction_head = RelativePositionHead()
        # Referenced through `torch.nn` so this line does not depend on an
        # `nn` name leaking in through a star import.
        self.position_loss_fn = torch.nn.CrossEntropyLoss()

        logging.info(f"freeze_image_encoder status:{freeze_image_encoder}")
        if freeze_image_encoder:
            for p in self.image_encoder.parameters():
                p.requires_grad = False

    def forward_image(self, img_batch: torch.Tensor):
        """Compute image features for ``img_batch`` and post-process them.

        Steps, in order:

        1. Run ``self.image_encoder`` on the batch.
        2. If ``self.use_high_res_features_in_sam`` is set, project FPN levels
           0 and 1 once with the mask decoder's ``conv_s0`` / ``conv_s1`` so
           this is not repeated on every SAM click.
        3. If the boundary branch is enabled and a boundary head exists,
           replace FPN level 0 with the head's enhanced feature and store the
           head's second output under ``"boundary_logits"``.
        4. Outside training mode, attach the spatial-context entries via
           ``process_spatial_feature_v2_infer`` (which updates the dict in
           place).

        Args:
            img_batch: Input passed directly to ``self.image_encoder``.

        Returns:
            The (mutated) dictionary produced by the image encoder.
        """
        backbone_out = self.image_encoder(img_batch)
        if self.use_high_res_features_in_sam:
            # precompute projected level 0 and level 1 features in SAM decoder
            # to avoid running it again on every SAM click
            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
                backbone_out["backbone_fpn"][0]
            )
            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
                backbone_out["backbone_fpn"][1]
            )

        # --- Integration of the Boundary Head ---
        # Guard against a missing head: the flag alone is not enough because
        # `boundary_head` defaults to None in the constructor.
        boundary_head = self.boundary_head if self.use_boundary else None
        if boundary_head is not None:
            # The highest-resolution FPN level is the head's input.
            enhanced_fpn_feature, boundary_logits = boundary_head(
                backbone_out["backbone_fpn"][0]
            )
            # 1. Replace the original FPN feature with the enhanced one
            backbone_out["backbone_fpn"][0] = enhanced_fpn_feature
            # 2. Store the boundary logits for loss calculation
            backbone_out["boundary_logits"] = boundary_logits

        if not self.training:
            # Injects 'spatial_out' / 'slice_spatial_out' into backbone_out in place.
            self.process_spatial_feature_v2_infer(backbone_out)

        return backbone_out

    def forward(self, input: BatchedVideoDatapoint):
        """Run the full pipeline on one batched video datapoint.

        In training, or in evaluation when
        ``forward_backbone_per_frame_for_eval`` is off, image features for all
        frames are computed up front from ``input.flat_img_batch``. Otherwise
        a placeholder dictionary is used and feature computation is deferred
        until each frame is tracked.

        Args:
            input: The batched datapoint; forwarded to
                ``prepare_prompt_inputs`` and ``forward_tracking``.

        Returns:
            Whatever ``forward_tracking`` returns.
        """
        features_precomputed = (
            self.training or not self.forward_backbone_per_frame_for_eval
        )
        if features_precomputed:
            # precompute image features on all frames before tracking
            backbone_out = self.forward_image(input.flat_img_batch)
        else:
            # defer image feature computation on a frame until it's being tracked
            backbone_out = {"backbone_fpn": None, "vision_pos_enc": None}
        backbone_out = self.prepare_prompt_inputs(backbone_out, input)

        # The spatial-context step reads backbone_out["vision_features"] and
        # the FPN tensors, which the deferred placeholder does not contain;
        # running it there raises KeyError. It updates backbone_out in place.
        if features_precomputed:
            self.process_spatial_feature_v2(backbone_out)

        return self.forward_tracking(backbone_out, input)

    def RPE_module(self, backbone_out, spatial_out, viz_pred_slice=True):
        """Enrich each observer slice with context from all target slices.

        Every slice's feature map is reduced to a single vector by global
        average pooling over its two spatial axes, because the attention
        module is called with sequence-shaped inputs rather than feature maps.
        The pooled observer vector is the query; the pooled target vectors are
        both keys and values.

        Args:
            backbone_out: Dictionary updated in place with a
                ``'relative_positions'`` entry copied from ``spatial_out``.
            spatial_out: Dictionary holding ``'observer_features'`` and
                ``'target_features'`` (5-D, laid out as
                ``(B*C, N, D, Hf, Wf)``) and ``'relative_positions'``
                (documented by the producer as ``(B*C, N)``).
            viz_pred_slice: Currently unused; kept for interface compatibility.

        Returns:
            tuple: ``(enriched_observer_feature, extra_rpe_feature)`` where the
            first item is the attention module's output and the second is a
            dict with the pooled ``'observer_pooled'`` ``(B*C, 1, D)`` and
            ``'target_pooled'`` ``(B*C, N, D)`` tensors.
        """
        observer_features_map = spatial_out['observer_features']
        target_features_map = spatial_out['target_features']
        relative_positions = spatial_out['relative_positions']
        backbone_out['relative_positions'] = relative_positions

        # 1. Feature preparation: pool (D, Hf, Wf) maps down to (D,) vectors.
        # Index 0 along the slice axis is the observer's own feature map.
        observer_pooled = (
            observer_features_map[:, 0, ...].mean(dim=(-2, -1)).unsqueeze(1)
        )  # (B*C, 1, D)
        target_pooled = target_features_map.mean(dim=(-2, -1))  # (B*C, N, D)

        # 2. RPE attention: the constructor stores the attention module as
        # `spatial_attention_layer`; honour an explicit `rpe_attention`
        # attribute if one has been set, otherwise fall back to that module.
        attention = getattr(self, "rpe_attention", None)
        if attention is None:
            attention = self.spatial_attention_layer
        enriched_observer_feature = attention(
            query=observer_pooled,
            key=target_pooled,
            value=target_pooled,
            relative_positions=relative_positions,
        )  # expected (B*C, 1, D) -- TODO confirm against the attention module

        extra_rpe_feature = {
            'observer_pooled': observer_pooled,
            'target_pooled': target_pooled,
        }
        return enriched_observer_feature, extra_rpe_feature

    def log_position_predictions(self,
                                 predicted_logits: torch.Tensor,
                                 true_relative_positions: torch.Tensor,
                                 max_rel_dist: int,
                                 num_examples: int = 2):
        """
        Logs a comparison of predicted vs. true relative positions for debugging.

        Nothing is logged outside training mode, so evaluation output stays
        clean.

        Args:
            predicted_logits (torch.Tensor): The model's output logits. Shape: (B, N, NumClasses)
            true_relative_positions (torch.Tensor): The ground truth positions. Shape: (B, N)
            max_rel_dist (int): The hyperparameter `k` used for position calculation.
            num_examples (int): How many examples from the batch to log.
        """
        if not self.training:
            return

        # Class index in [0, 2k] -> relative position in [-k, k]; this inverts
        # the `index = position + max_rel_dist` encoding.
        predicted_positions = torch.argmax(predicted_logits, dim=-1) - max_rel_dist  # (B, N)

        # Only log a few samples to avoid flooding the log.
        num_to_log = min(num_examples, predicted_logits.shape[0])

        logging.info("--- [Spatial Context] Position Prediction Check ---")
        for i in range(num_to_log):
            # Compare on CPU so tensors living on different devices cannot clash.
            true_row = true_relative_positions[i].detach().cpu()
            pred_row = predicted_positions[i].detach().cpu()
            true_pos_sample = true_row.tolist()
            pred_pos_sample = pred_row.tolist()

            # Per-observer accuracy; an observer with no positions reports 0%
            # instead of dividing by zero.
            total = len(true_pos_sample)
            correct_predictions = torch.eq(true_row, pred_row).sum().item()
            accuracy = (correct_predictions / total) * 100 if total else 0.0

            log_message = (
                f"\n[Observer Example {i+1}]:"
                f"\n  => Ground Truth Rel. Pos: {true_pos_sample}"
                f"\n  => Predicted Rel. Pos:    {pred_pos_sample}"
                f"\n  => Accuracy for this observer: {accuracy:.2f}%"
            )
            logging.info(log_message)
        logging.info("--------------------------------------------------")

    def process_spatial_feature(self, backbone_out):
        """Split the FPN features into core (conditioning) and context slices.

        Each FPN level is reshaped from a flat ``(batch * frames, C, H, W)``
        tensor into a ``(batch, frames, C, H, W)`` volume, and every requested
        slice is cut out with its frame axis kept (length 1), so each slice
        tensor is 5-D.

        NOTE(review): the reshape assumes the flat batch is ordered video by
        video (index = b * num_frames + t) -- confirm against the data loader.

        Args:
            backbone_out (dict): Must provide ``'num_frames'``,
                ``'vision_features'`` (only its leading dimension is read),
                ``'backbone_fpn'`` (list of 4-D tensors),
                ``'init_cond_frames'`` and ``'frames_not_in_init_cond'``.

        Returns:
            dict: ``'core_img_features'`` maps each index in
            ``init_cond_frames`` to its list of per-level slices;
            ``'context_img_features'`` is a list with one such per-level list
            for each index in ``frames_not_in_init_cond``, in that order.
        """
        num_frames = backbone_out['num_frames']
        flat_batch = backbone_out['vision_features'].shape[0]
        batch_size = flat_batch // num_frames

        fpn_feature_volume = [
            feat.reshape(batch_size, num_frames, *feat.shape[1:])
            for feat in backbone_out['backbone_fpn']
        ]

        def _slice_all_levels(slice_idx):
            # Slicing with a length-1 range keeps the frame axis.
            return [
                level[:, slice_idx: slice_idx + 1, :, :, :]
                for level in fpn_feature_volume
            ]

        # The previous implementation stopped the whole process with exit(0)
        # after a debug print keyed on a hard-coded frame index; the gathered
        # features are returned to the caller instead.
        spacial_out = {
            'core_img_features': {
                slice_idx: _slice_all_levels(slice_idx)
                for slice_idx in backbone_out['init_cond_frames']
            },
            'context_img_features': [
                _slice_all_levels(slice_idx)
                for slice_idx in backbone_out['frames_not_in_init_cond']
            ],
        }
        return spacial_out

    def process_spatial_feature_v2(
        self, backbone_out: dict, local_context_window_size: int = 2
    ):
        """
        Structure backbone features into two spatial-context bundles.

        This function only rearranges data; it does not flatten it for a decoder.

        1.  "All-Perspectives" (global context): each core "observer" slice
            (an index in ``backbone_out["init_cond_frames"]``) is paired with
            ALL slices, itself included. The resulting batch dimension is
            ``B * num_cores``.

        2.  "Per-Slice Local Context": for EVERY slice, its FPN features are
            bundled with those of its neighbours in a window of ``2*k + 1``
            slices. Neighbour indices that fall outside the volume are clamped,
            so edge slices repeat the boundary slice. The resulting batch
            dimension is ``B * num_frames``.

        Args:
            backbone_out (dict): Must contain ``"vision_features"`` (4-D),
                ``"backbone_fpn"`` (list of 4-D tensors), ``"num_frames"`` and
                ``"init_cond_frames"``.
            local_context_window_size (int): Radius ``k`` of the local window.
                k=2 results in a total window of 5 slices.

        Returns:
            dict: The same ``backbone_out`` object, updated in place with:
                - 'spatial_out': the "all-perspectives" bundle.
                - 'slice_spatial_out': the per-slice local-context bundle.

        Raises:
            ValueError: If a flat batch size is not a multiple of ``num_frames``.
        """
        # =================================================================================
        # Common preparation
        # =================================================================================
        device = backbone_out["vision_features"].device
        num_frames = backbone_out["num_frames"]
        k = local_context_window_size

        def _reshape_to_volume(flat_tensor: torch.Tensor) -> torch.Tensor:
            # NOTE(review): assumes the flat batch is ordered video by video
            # (index = b * num_frames + t) -- confirm against the data loader.
            flat_batch, *rest_dims = flat_tensor.shape
            if flat_batch % num_frames != 0:
                raise ValueError(
                    f"Total batch size {flat_batch} is not divisible by num_frames {num_frames}."
                )
            # reshape (not view) so non-contiguous inputs are accepted as well
            return flat_tensor.reshape(flat_batch // num_frames, num_frames, *rest_dims)

        fpn_feature_volume = [_reshape_to_volume(f) for f in backbone_out["backbone_fpn"]]
        vision_features_volume = _reshape_to_volume(backbone_out["vision_features"])
        B, N, D_vit, Hf_vit, Wf_vit = vision_features_volume.shape

        # =================================================================================
        # Part 1: "All-Perspectives" global context
        # =================================================================================
        core_indices = torch.tensor(
            backbone_out["init_cond_frames"], dtype=torch.long, device=device
        )
        C = len(core_indices)
        new_batch_size_global = B * C

        def _observer_pairs(volume: torch.Tensor) -> torch.Tensor:
            # (B, N, ...) -> pick cores -> (B, C, N, ...) -> (B*C, N, ...)
            expanded = volume[:, core_indices, ...].unsqueeze(2).expand(-1, -1, N, -1, -1, -1)
            return expanded.reshape(new_batch_size_global, N, *expanded.shape[3:])

        def _target_pairs(volume: torch.Tensor) -> torch.Tensor:
            # (B, N, ...) -> repeat per core -> (B, C, N, ...) -> (B*C, N, ...)
            expanded = volume.unsqueeze(1).expand(-1, C, -1, -1, -1, -1)
            return expanded.reshape(new_batch_size_global, N, *expanded.shape[3:])

        # Signed offset of every slice relative to each core slice: (1, C, N).
        relative_positions = (
            torch.arange(N, device=device).view(1, 1, N) - core_indices.view(1, C, 1)
        )

        backbone_out["spatial_out"] = {
            "observer_features": _observer_pairs(vision_features_volume),
            "target_features": _target_pairs(vision_features_volume),
            "observer_fpn": [_observer_pairs(level) for level in fpn_feature_volume],
            "target_fpn": [_target_pairs(level) for level in fpn_feature_volume],
            "relative_positions": relative_positions.expand(B, -1, -1).reshape(
                new_batch_size_global, N
            ),
            "metadata": {"original_batch_size": B, "num_cores": C, "num_frames": N},
        }

        # =================================================================================
        # Part 2: per-slice local context
        # =================================================================================
        window_size = 2 * k + 1
        local_offsets = torch.arange(-k, k + 1, device=device)
        neighbor_indices = (
            torch.arange(N, device=device).view(N, 1) + local_offsets.view(1, window_size)
        )
        # Out-of-range neighbours collapse onto the first / last slice.
        clamped_neighbor_indices = torch.clamp(neighbor_indices, 0, N - 1)

        new_batch_size_local = B * N
        final_local_context_fpn = []
        for fpn_level in fpn_feature_volume:
            # Advanced indexing: (B, N, D, H, W) -> (B, N, window, D, H, W)
            gathered = fpn_level[:, clamped_neighbor_indices, ...]
            _, _, _, FPN_D, FPN_H, FPN_W = gathered.shape
            final_local_context_fpn.append(
                gathered.reshape(new_batch_size_local, window_size, FPN_D, FPN_H, FPN_W)
            )

        backbone_out["slice_spatial_out"] = {
            "local_context_fpn": final_local_context_fpn,
            "local_relative_positions": local_offsets.view(1, -1).expand(
                new_batch_size_local, -1
            ),
            "metadata": {
                "original_batch_size": B,
                "num_frames": N,
                "window_size": window_size,
            },
        }

        return backbone_out

    def process_spatial_feature_v2_infer(
        self, backbone_out: dict, local_context_window_size: int = 2
    ):
        """
        Single-frame variant of the spatial-context structuring step.

        Every entry of the flat batch is treated as its own one-slice volume
        (the frame count is fixed to 1), and ``backbone_out["init_cond_frames"]``
        is overwritten with ``[0]`` before the global-context part runs.
        Because of that overwrite the global-context branch is always taken;
        the fallback that stores ``None`` under ``'spatial_out'`` is kept but
        is not reached as the code stands.

        1.  Global context: the single slice is paired with itself, giving
            tensors with a slice axis of length 1 and all-zero relative
            positions.

        2.  Local context: each slice's FPN features are gathered over a window
            of ``2*k + 1`` neighbours; with one slice per volume, clamping makes
            every window entry that same slice.

        Args:
            backbone_out (dict): Must contain ``"vision_features"`` (4-D) and
                ``"backbone_fpn"`` (list of 4-D tensors).
            local_context_window_size (int): Radius ``k`` of the local window.

        Returns:
            dict: The same ``backbone_out`` object, updated in place with
            ``'init_cond_frames'``, ``'spatial_out'`` and ``'slice_spatial_out'``.
        """
        feature_device = backbone_out["vision_features"].device
        frames_per_volume = 1
        radius = local_context_window_size

        def as_volume(flat: torch.Tensor) -> torch.Tensor:
            # (B_flat, ...) -> (B_flat / frames, frames, ...)
            total, *trailing = flat.shape
            if total % frames_per_volume != 0:
                raise ValueError(f"Total batch size {total} is not divisible by num_frames {frames_per_volume}.")
            return flat.view(total // frames_per_volume, frames_per_volume, *trailing)

        fpn_volumes = [as_volume(level) for level in backbone_out["backbone_fpn"]]
        vit_volume = as_volume(backbone_out["vision_features"])
        batch, n_slices, vit_dim, vit_h, vit_w = vit_volume.shape

        # ---------------------------------------------------------------------
        # Part 1: global ("all-perspectives") context
        # ---------------------------------------------------------------------
        # The observer list is forced to the only available slice.
        backbone_out["init_cond_frames"] = [0]
        if backbone_out.get("init_cond_frames"):
            core = torch.tensor(
                backbone_out["init_cond_frames"], dtype=torch.long, device=feature_device
            )
            n_cores = len(core)
            paired_batch = batch * n_cores

            def observer_side(volume: torch.Tensor) -> torch.Tensor:
                # (B, N, ...) -> (B, C, N, ...), observer repeated along N
                return volume[:, core, ...].unsqueeze(2).expand(-1, -1, n_slices, -1, -1, -1)

            def target_side(volume: torch.Tensor) -> torch.Tensor:
                # (B, N, ...) -> (B, C, N, ...), all slices repeated per core
                return volume.unsqueeze(1).expand(-1, n_cores, -1, -1, -1, -1)

            vit_pair_shape = (paired_batch, n_slices, vit_dim, vit_h, vit_w)
            slice_ids = torch.arange(n_slices, device=feature_device)
            signed_offsets = slice_ids.view(1, 1, n_slices) - core.view(1, n_cores, 1)

            backbone_out["spatial_out"] = {
                "observer_features": observer_side(vit_volume).reshape(vit_pair_shape),
                "target_features": target_side(vit_volume).reshape(vit_pair_shape),
                "observer_fpn": [
                    observer_side(vol).reshape(paired_batch, n_slices, *vol.shape[2:])
                    for vol in fpn_volumes
                ],
                "target_fpn": [
                    target_side(vol).reshape(paired_batch, n_slices, *vol.shape[2:])
                    for vol in fpn_volumes
                ],
                "relative_positions": signed_offsets.expand(batch, -1, -1).reshape(
                    paired_batch, n_slices
                ),
                "metadata": {
                    "original_batch_size": batch,
                    "num_cores": n_cores,
                    "num_frames": n_slices,
                },
            }
        else:
            # Signals that no global context was computed.
            backbone_out["spatial_out"] = None

        # ---------------------------------------------------------------------
        # Part 2: per-slice local context
        # ---------------------------------------------------------------------
        window = 2 * radius + 1
        window_offsets = torch.arange(-radius, radius + 1, device=feature_device)
        neighbours = (
            torch.arange(n_slices, device=feature_device).view(n_slices, 1)
            + window_offsets.view(1, window)
        )
        neighbours = torch.clamp(neighbours, 0, n_slices - 1)

        flat_local_batch = batch * n_slices
        local_fpn = []
        for vol in fpn_volumes:
            # (B, N, D, H, W) -> (B, N, window, D, H, W) -> (B*N, window, D, H, W)
            gathered = vol[:, neighbours, ...]
            _, _, _, lvl_d, lvl_h, lvl_w = gathered.shape
            local_fpn.append(
                gathered.reshape(flat_local_batch, window, lvl_d, lvl_h, lvl_w)
            )

        backbone_out["slice_spatial_out"] = {
            "local_context_fpn": local_fpn,
            "local_relative_positions": window_offsets.view(1, -1).expand(
                flat_local_batch, -1
            ),
            "metadata": {
                "original_batch_size": batch,
                "num_frames": n_slices,
                "window_size": window,
            },
        }

        return backbone_out

    def process_spatial_feature_all_perspectives(self, backbone_out: dict) -> dict:
        """
        Pair every core ("observer") slice with every slice of the volume ("targets").

        The flattened ``(B*N, ...)`` backbone tensors are first viewed as
        ``(B, N, ...)`` volumes. Each index listed in ``init_cond_frames`` acts as
        an observer and is matched against all ``N`` slices, itself included. The
        resulting ``(B, C, N, ...)`` pairings are folded into one leading batch axis
        of size ``B * C``, where ``C`` is the number of core slices.

        Args:
            backbone_out (dict): Must provide
                - 'num_frames': number of slices ``N`` per volume,
                - 'init_cond_frames': sequence of core slice indices,
                - 'vision_features': tensor of shape (B*N, D, Hf, Wf),
                - 'backbone_fpn': list of tensors of shape
                  (B*N, FPN_D, FPN_H, FPN_W).

        Returns:
            dict: The pairing dictionary, which is also stored under
            ``backbone_out["spatial_out"]``:
                - 'observer_features': (B*C, N, D, Hf, Wf), observer repeated N times.
                - 'target_features': (B*C, N, D, Hf, Wf), full volume per observer.
                - 'observer_fpn': list of (B*C, N, FPN_D, FPN_H, FPN_W) tensors.
                - 'target_fpn': list of (B*C, N, FPN_D, FPN_H, FPN_W) tensors.
                - 'relative_positions': (B*C, N) signed offsets
                  ``target_index - observer_index`` (0 for the self pair).
                - 'metadata': {'original_batch_size', 'num_cores', 'num_frames'}.
        """
        flat_vision = backbone_out['vision_features']
        device = flat_vision.device
        num_frames = backbone_out['num_frames']

        all_indices = torch.arange(num_frames, device=device)
        core_indices = torch.tensor(backbone_out['init_cond_frames'], dtype=torch.long, device=device)

        def _as_volume(flat: torch.Tensor) -> torch.Tensor:
            # (B*N, ...) -> (B, N, ...); the batch size is implied by num_frames.
            return flat.view(flat.shape[0] // num_frames, num_frames, *flat.shape[1:])

        fpn_volumes = [_as_volume(level) for level in backbone_out['backbone_fpn']]
        vision_volume = _as_volume(flat_vision)

        # Unpacking also enforces that the main features form a 5-D volume.
        B, N, D, Hf, Wf = vision_volume.shape
        C = len(core_indices)
        new_batch_size = B * C

        def _observer_target_pair(volume: torch.Tensor):
            # Observers: (B, C, ...) -> (B, C, 1, ...) -> (B, C, N, ...).
            observers = volume[:, core_indices, ...].unsqueeze(2).expand(-1, -1, N, -1, -1, -1)
            # Targets: (B, N, ...) -> (B, 1, N, ...) -> (B, C, N, ...).
            targets = volume.unsqueeze(1).expand(-1, C, -1, -1, -1, -1)
            trailing = volume.shape[2:]
            return (
                observers.reshape(new_batch_size, N, *trailing),
                targets.reshape(new_batch_size, N, *trailing),
            )

        observer_features_flat, target_features_flat = _observer_target_pair(vision_volume)

        observer_fpn_flat = []
        target_fpn_flat = []
        for fpn_volume in fpn_volumes:
            observer_level, target_level = _observer_target_pair(fpn_volume)
            observer_fpn_flat.append(observer_level)
            target_fpn_flat.append(target_level)

        # Broadcast (1, 1, N) - (1, C, 1) -> (1, C, N), then repeat over the batch
        # and fold into the same (B*C) leading axis as the features.
        offsets = all_indices.view(1, 1, N) - core_indices.view(1, C, 1)
        relative_positions_flat = offsets.expand(B, -1, -1).reshape(new_batch_size, N)

        spatial_out = {
            'observer_features': observer_features_flat,
            'target_features': target_features_flat,
            'observer_fpn': observer_fpn_flat,
            'target_fpn': target_fpn_flat,
            'relative_positions': relative_positions_flat,

            'metadata': {
                'original_batch_size': B,
                'num_cores': C,
                'num_frames': N,
            }
        }

        # Make the result available to later stages through backbone_out as well.
        backbone_out["spatial_out"] = spatial_out
        return spatial_out

    def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
        """Compute the image backbone features on the fly for the given img_ids."""
        # Only forward backbone on unique image ids to avoid repetitive computation
        # (if `img_ids` has only one element, it's already unique so we skip this step).
        if img_ids.numel() > 1:
            unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
        else:
            unique_img_ids, inv_ids = img_ids, None

        # Compute the image features on those unique image ids
        image = img_batch[unique_img_ids]
        backbone_out = self.forward_image(image)
        (
            _,
            vision_feats,
            vision_pos_embeds,
            feat_sizes,
        ) = self._prepare_backbone_features(backbone_out)
        # Inverse-map image features for `unique_img_ids` to the final image features
        # for the original input `img_ids`.
        if inv_ids is not None:
            image = image[inv_ids]
            vision_feats = [x[:, inv_ids] for x in vision_feats]
            vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]

        return image, vision_feats, vision_pos_embeds, feat_sizes

    def prepare_prompt_inputs(self, backbone_out, input, start_frame_idx=0):
        """
        Prepare input mask, point or box prompts. Optionally, we allow tracking from
        a custom `start_frame_idx` to the end of the video (for evaluation purposes).

        The first initial conditioning frame is the frame right after
        `start_frame_idx`, clamped to the last frame so that it always exists
        (in particular for single-frame inputs, where it is frame 0).

        Args:
            backbone_out (dict): Dictionary that is filled in place and returned.
            input: Batch object providing `masks` (iterated per frame, each item
                is unsqueezed on dim 1 to [B, 1, H_im, W_im]) and `num_frames`.
            start_frame_idx (int): First frame to be tracked.

        Returns:
            dict: `backbone_out` with these keys added: "gt_masks_per_frame",
            "num_frames", "use_pt_input", "init_cond_frames",
            "frames_not_in_init_cond", "mask_inputs_per_frame",
            "point_inputs_per_frame" and "frames_to_add_correction_pt".

        Raises:
            AssertionError: If the configured number of initial conditioning
                frames is smaller than 1, or if fewer frames are to be corrected
                than there are initial conditioning frames (point input only).
        """
        # Load the ground-truth masks on all frames (so that we can later
        # sample correction points from them)
        gt_masks_per_frame = {
            stage_id: masks.unsqueeze(1)  # [B, 1, H_im, W_im]
            for stage_id, masks in enumerate(input.masks)
        }
        backbone_out["gt_masks_per_frame"] = gt_masks_per_frame
        num_frames = input.num_frames
        backbone_out["num_frames"] = num_frames

        # Randomly decide whether to use point inputs or mask inputs
        if self.training:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_train
            prob_to_use_box_input = self.prob_to_use_box_input_for_train
            num_frames_to_correct = self.num_frames_to_correct_for_train
            rand_frames_to_correct = self.rand_frames_to_correct_for_train
            num_init_cond_frames = self.num_init_cond_frames_for_train
            rand_init_cond_frames = self.rand_init_cond_frames_for_train
        else:
            prob_to_use_pt_input = self.prob_to_use_pt_input_for_eval
            prob_to_use_box_input = self.prob_to_use_box_input_for_eval
            num_frames_to_correct = self.num_frames_to_correct_for_eval
            rand_frames_to_correct = self.rand_frames_to_correct_for_eval
            num_init_cond_frames = self.num_init_cond_frames_for_eval
            rand_init_cond_frames = self.rand_init_cond_frames_for_eval
        if num_frames == 1:
            # here we handle a special case for mixing video + SAM on image training,
            # where we force using point input for the SAM task on static images
            prob_to_use_pt_input = 1.0
            num_frames_to_correct = 1
            num_init_cond_frames = 1
        assert num_init_cond_frames >= 1
        # (here `self.rng.random()` returns value in range 0.0 <= X < 1.0)

        # Choose num_init_cond_frames and num_frames_to_correct
        use_pt_input = self.rng.random() < prob_to_use_pt_input
        if rand_init_cond_frames and num_init_cond_frames > 1:
            # randomly select 1 to `num_init_cond_frames` frames as initial conditioning frames
            num_init_cond_frames = self.rng.integers(
                1, num_init_cond_frames, endpoint=True
            )
        if (
            use_pt_input
            and rand_frames_to_correct
            and num_frames_to_correct > num_init_cond_frames
        ):
            # randomly select `num_init_cond_frames` to `num_frames_to_correct` frames to sample
            # correction clicks (only for the case of point input)
            num_frames_to_correct = self.rng.integers(
                num_init_cond_frames, num_frames_to_correct, endpoint=True
            )
        backbone_out["use_pt_input"] = use_pt_input

        # Sample initial conditioning frames.
        # The first one is the frame after `start_frame_idx`. It is clamped to the
        # last frame: without the clamp a single-frame input (or a start index on
        # the last frame) would select a frame that does not exist and fail on
        # the ground-truth mask lookup below.
        sample_start_frame_idx = min(start_frame_idx + 1, num_frames - 1)
        if num_init_cond_frames == 1:
            init_cond_frames = [sample_start_frame_idx]  # starting frame
        else:
            # starting frame + randomly selected remaining frames (without replacement)
            init_cond_frames = [sample_start_frame_idx] + self.rng.choice(
                range(sample_start_frame_idx + 1, num_frames),
                num_init_cond_frames - 1,
                replace=False,
            ).tolist()

        backbone_out["init_cond_frames"] = init_cond_frames
        backbone_out["frames_not_in_init_cond"] = [
            t for t in range(start_frame_idx, num_frames) if t not in init_cond_frames
        ]

        # Prepare mask or point inputs on initial conditioning frames
        backbone_out["mask_inputs_per_frame"] = {}  # {frame_idx: <input_masks>}
        backbone_out["point_inputs_per_frame"] = {}  # {frame_idx: <input_points>}
        for t in init_cond_frames:
            if not use_pt_input:
                backbone_out["mask_inputs_per_frame"][t] = gt_masks_per_frame[t]
            else:
                # During training # P(box) = prob_to_use_pt_input * prob_to_use_box_input
                use_box_input = self.rng.random() < prob_to_use_box_input
                if use_box_input:
                    points, labels = sample_box_points(
                        gt_masks_per_frame[t],
                    )
                else:
                    # (here we only sample **one initial point** on initial conditioning frames from the
                    # ground-truth mask; we may sample more correction points on the fly)
                    points, labels = get_next_point(
                        gt_masks=gt_masks_per_frame[t],
                        pred_masks=None,
                        method=(
                            "uniform" if self.training else self.pt_sampling_for_eval
                        ),
                    )

                point_inputs = {"point_coords": points, "point_labels": labels}
                backbone_out["point_inputs_per_frame"][t] = point_inputs

        # Sample frames where we will add correction clicks on the fly
        # based on the error between prediction and ground-truth masks
        if not use_pt_input:
            # no correction points will be sampled when using mask inputs
            frames_to_add_correction_pt = []
        elif num_frames_to_correct == num_init_cond_frames:
            frames_to_add_correction_pt = init_cond_frames
        else:
            assert num_frames_to_correct > num_init_cond_frames
            # initial cond frame + randomly selected remaining frames (without replacement)
            extra_num = num_frames_to_correct - num_init_cond_frames
            frames_to_add_correction_pt = (
                init_cond_frames
                + self.rng.choice(
                    backbone_out["frames_not_in_init_cond"], extra_num, replace=False
                ).tolist()
            )
        backbone_out["frames_to_add_correction_pt"] = frames_to_add_correction_pt

        return backbone_out

    def forward_tracking(
        self, backbone_out, input: BatchedVideoDatapoint, return_dict=False
    ):
        """Forward video tracking on each frame (and sample correction clicks)."""
        img_feats_already_computed = backbone_out["backbone_fpn"] is not None
        if img_feats_already_computed:
            # Prepare the backbone features
            # - vision_feats and vision_pos_embeds are in (HW)BC format
            (
                _,
                vision_feats,
                vision_pos_embeds,
                feat_sizes,
            ) = self._prepare_backbone_features(backbone_out)

            prepared_spatial_features = self._prepare_spatial_features(backbone_out["spatial_out"])  # global
            prepared_slice_spatial_features = self._prepare_local_spatial_features(backbone_out["slice_spatial_out"])  # local context sliding window

            prepared_local_context_fpn = prepared_slice_spatial_features['prepared_local_context_fpn']
            feat_sizes = prepared_slice_spatial_features['feat_sizes']
            local_relative_positions = prepared_slice_spatial_features['local_relative_positions']
            metadata = prepared_slice_spatial_features['metadata']

        # Starting the stage loop
        num_frames = backbone_out["num_frames"]
        init_cond_frames = backbone_out["init_cond_frames"]
        frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
        # first process all the initial conditioning frames to encode them as memory,
        # and then conditioning on them to track the remaining frames
        processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
        output_dict = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }

        # logging.info(f"processing_order:{processing_order}")
        current_local_context_vision_feats = []
        current_local_context_vision_relative_pos = []

        # process input data per frame
        for stage_id in processing_order:
            # Get the image features for the current frames
            # img_ids = input.find_inputs[stage_id].img_ids
            img_ids = input.flat_obj_to_img_idx[stage_id]
            if img_feats_already_computed:
                # Retrieve image features according to img_ids (if they are already computed).
                current_vision_feats = [x[:, img_ids] for x in vision_feats]
                current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
                # logging.info(
                #     f"current_vision_feats:{len(current_vision_feats)}, shape:{current_vision_feats[0].shape}, {current_vision_feats[1].shape}, {current_vision_feats[2].shape}")
                # logging.info(f"current_vision_pos_embeds:{len(current_vision_pos_embeds)}, shape:{current_vision_pos_embeds[0].shape}")

                current_local_context_vision_feats = [
                    fpn_level[:, img_ids] for fpn_level in prepared_local_context_fpn
                ]
                current_local_context_vision_relative_pos = local_relative_positions[img_ids] 

                # logging.info(f"--- Processing Stage {stage_id} ---")
                # logging.info(f"Number of objects/slices in this stage: {len(img_ids)}")

                # if current_local_context_vision_feats:
                #     fpn_shape = current_local_context_vision_feats[0].shape
                #     logging.info(f"Shape of current_local_context_vision_feats[0]: {fpn_shape}")
                #     # 预期的形状: (SeqLen, K, FeatureDim)
                #     # 例如: (Win*H*W, num_objects_in_stage, FPN_D)
                #     win = metadata['window_size']
                #     h, w = feat_sizes[0]
                #     logging.info(f"  - Verified SeqLen ({win}*{h}*{w}): {fpn_shape[0]}")
                #     logging.info(f"  - Verified Batch Dim (num objects): {fpn_shape[1]}")
                #     logging.info(f"  - Verified Feature Dim: {fpn_shape[2]}")

                # logging.info(f"Shape of current_local_context_vision_relative_pos: {current_local_context_vision_relative_pos.shape}")
                # logging.info(f"Content of current_local_context_vision_relative_pos:\n{current_local_context_vision_relative_pos}")

            else:
                # Otherwise, compute the image features on the fly for the given img_ids
                # (this might be used for evaluation on long videos to avoid backbone OOM).
                (
                    _,
                    current_vision_feats,
                    current_vision_pos_embeds,
                    feat_sizes,
                ) = self._prepare_backbone_features_per_frame(
                    input.flat_img_batch, img_ids
                )

            # Get output masks based on this frame's prompts and previous memory
            current_out = self.track_step(
                frame_idx=stage_id,
                is_init_cond_frame=stage_id in init_cond_frames,
                current_vision_feats=current_vision_feats,
                current_vision_pos_embeds=current_vision_pos_embeds,
                current_local_context_vision_feats=current_local_context_vision_feats,
                current_local_context_vision_relative_pos=current_local_context_vision_relative_pos,
                feat_sizes=feat_sizes,
                point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
                mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
                gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
                frames_to_add_correction_pt=frames_to_add_correction_pt,
                output_dict=output_dict,
                num_frames=num_frames,
            )
            # Append the output, depending on whether it's a conditioning frame
            add_output_as_cond_frame = stage_id in init_cond_frames or (
                self.add_all_frames_to_correct_as_cond
                and stage_id in frames_to_add_correction_pt
            )
            if add_output_as_cond_frame:
                output_dict["cond_frame_outputs"][stage_id] = current_out
            else:
                output_dict["non_cond_frame_outputs"][stage_id] = current_out

        # # full batch frame prediction
        # predicted_pos_logits = self.pos_prediction_head(
        #     observer_feature=extra_rpe_feature['observer_pooled'],
        #     target_features=extra_rpe_feature['target_pooled']
        # )  # Output Shape: (B*C, N, num_classes)

        if return_dict:
            return output_dict
        # turn `output_dict` into a list for loss function
        all_frame_outputs = {}
        all_frame_outputs.update(output_dict["cond_frame_outputs"])
        all_frame_outputs.update(output_dict["non_cond_frame_outputs"])

        # return boundary logits info
        if self.use_boundary and self.training:
            # print(f"all_frame_outputs:{all_frame_outputs.keys()}")
            for key in all_frame_outputs:
                all_frame_outputs[key]['boundary_logits'] = backbone_out['boundary_logits']

        # # return predicted_pos_logits info
        # if self.use_volume and self.training:
        #     for key in all_frame_outputs:
        #         all_frame_outputs[key]['predicted_pos_logits'] = predicted_pos_logits
        #         all_frame_outputs[key]['relative_positions'] = backbone_out['relative_positions']

        all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
        # Make DDP happy with activation checkpointing by removing unused keys
        all_frame_outputs = [
            {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
        ]

        # # deliver boundary logits to final part
        # print(f"type:{type(all_frame_outputs[0])}, keys:{all_frame_outputs[0].keys()}")

        return all_frame_outputs

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        current_local_context_vision_feats,
        current_local_context_vision_relative_pos,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        run_mem_encoder=True,  # Whether to run the memory encoder on the predicted masks.
        prev_sam_mask_logits=None,  # The previously predicted SAM mask logits.
        frames_to_add_correction_pt=None,
        gt_masks=None,
    ):
        """
        Run one tracking step on a frame and return its output dictionary.

        The step is delegated to `_track_step`; its SAM outputs are recorded as
        the first "multistep" entry. If `frame_idx` is listed in
        `frames_to_add_correction_pt`, correction points are sampled iteratively
        via `_iter_correct_pt_sampling` and the final prediction replaces the
        initial one. The final masks and object pointer are stored in the output
        and `_encode_memory_in_output` is invoked on them.
        """
        correction_frames = (
            [] if frames_to_add_correction_pt is None else frames_to_add_correction_pt
        )

        step_result = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            current_local_context_vision_feats,
            current_local_context_vision_relative_pos,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )
        current_out, sam_outputs, high_res_features, pix_feat = step_result

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs

        # Record the initial prediction as step 0 of the multi-step outputs.
        current_out.update(
            {
                "multistep_pred_masks": low_res_masks,
                "multistep_pred_masks_high_res": high_res_masks,
                "multistep_pred_multimasks": [low_res_multimasks],
                "multistep_pred_multimasks_high_res": [high_res_multimasks],
                "multistep_pred_ious": [ious],
                "multistep_point_inputs": [point_inputs],
                "multistep_object_score_logits": [object_score_logits],
            }
        )

        # Optionally refine the mask with iteratively sampled correction points.
        if frame_idx in correction_frames:
            point_inputs, final_sam_outputs = self._iter_correct_pt_sampling(
                is_init_cond_frame,
                point_inputs,
                gt_masks,
                high_res_features,
                pix_feat,
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                object_score_logits,
                current_out,
            )
            _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = (
                final_sam_outputs
            )

        # The final prediction (after all correction steps) is used for output and eval.
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr

        # Encode the predicted mask into a memory feature usable by future frames.
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        return current_out

    def _iter_correct_pt_sampling(
        self,
        is_init_cond_frame,
        point_inputs,
        gt_masks,
        high_res_features,
        pix_feat_with_mem,
        low_res_multimasks,
        high_res_multimasks,
        ious,
        low_res_masks,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        """
        Iteratively sample correction points and re-run the SAM heads.

        For `self.num_correction_pt_per_frame` rounds, a new point is sampled
        with `get_next_point` (from the prediction error, or from the GT mask
        alone with probability `prob_to_sample_from_gt_for_train` in training),
        appended to `point_inputs`, and fed to `_forward_sam_heads` together with
        the previous low-res mask logits. The outputs of every round, including
        the initial prediction passed in, are written to the "multistep_*"
        entries of `current_out`.

        Returns:
            tuple: (point_inputs, sam_outputs) from the last round.
        """
        assert gt_masks is not None

        # Step 0 of every history list is the prediction made before any correction.
        masks_history = [low_res_masks]
        masks_high_res_history = [high_res_masks]
        multimasks_history = [low_res_multimasks]
        multimasks_high_res_history = [high_res_multimasks]
        ious_history = [ious]
        points_history = [point_inputs]
        score_logits_history = [object_score_logits]

        for _ in range(self.num_correction_pt_per_frame):
            # With a small probability (training only), sample directly from the
            # GT masks instead of from the prediction error.
            sample_from_gt = False
            if self.training and self.prob_to_sample_from_gt_for_train > 0:
                sample_from_gt = (
                    self.rng.random() < self.prob_to_sample_from_gt_for_train
                )
            # A `None` prediction means only the GT masks drive point sampling.
            pred_for_new_pt = None if sample_from_gt else (high_res_masks > 0)
            new_points, new_labels = get_next_point(
                gt_masks=gt_masks,
                pred_masks=pred_for_new_pt,
                method="uniform" if self.training else self.pt_sampling_for_eval,
            )
            point_inputs = concat_points(point_inputs, new_points, new_labels)

            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            # The previous mask logits are fed back in together with the new click.
            sam_head_kwargs = {
                "backbone_features": pix_feat_with_mem,
                "point_inputs": point_inputs,
                "mask_inputs": low_res_masks,
                "high_res_features": high_res_features,
                "multimask_output": multimask_output,
            }
            if self.use_act_ckpt_iterative_pt_sampling and not multimask_output:
                sam_outputs = torch.utils.checkpoint.checkpoint(
                    self._forward_sam_heads,
                    use_reentrant=False,
                    **sam_head_kwargs,
                )
            else:
                sam_outputs = self._forward_sam_heads(**sam_head_kwargs)

            (
                low_res_multimasks,
                high_res_multimasks,
                ious,
                low_res_masks,
                high_res_masks,
                _,
                object_score_logits,
            ) = sam_outputs

            masks_history.append(low_res_masks)
            masks_high_res_history.append(high_res_masks)
            multimasks_history.append(low_res_multimasks)
            multimasks_high_res_history.append(high_res_multimasks)
            ious_history.append(ious)
            points_history.append(point_inputs)
            score_logits_history.append(object_score_logits)

        # Masks are concatenated along the channel dim so that losses can be
        # computed on all steps (using `MultiStepIteractiveMasks`).
        current_out["multistep_pred_masks"] = torch.cat(masks_history, dim=1)
        current_out["multistep_pred_masks_high_res"] = torch.cat(
            masks_high_res_history, dim=1
        )
        current_out["multistep_pred_multimasks"] = multimasks_history
        current_out["multistep_pred_multimasks_high_res"] = multimasks_high_res_history
        current_out["multistep_pred_ious"] = ious_history
        current_out["multistep_point_inputs"] = points_history
        current_out["multistep_object_score_logits"] = score_logits_history

        return point_inputs, sam_outputs

    def _fuse_spatial_context(
        self,
        local_context_features: list[torch.Tensor],
        relative_positions: torch.Tensor,
        feat_sizes: list[tuple[int, int]],
    ):
        """
        Fuses spatial context features using our RPEAttention layer.
        """
        context_sequence = local_context_features[-1]
        H, W = feat_sizes[-1]

        # 1. 准备 Query, Key, Value
        Win = context_sequence.size(0) // (H * W)
        center_slice_start_idx = (Win // 2) * (H * W)
        center_slice_end_idx = center_slice_start_idx + (H * W)
        query_features = context_sequence[center_slice_start_idx:center_slice_end_idx, :, :]
        key_value_features = context_sequence
        fused_features_seq = self.spatial_attention_layer(
            query=query_features,
            key=key_value_features,
            value=key_value_features,
            relative_positions=relative_positions,
            H=H,
            W=W
        )
        # fused_features_seq shape: (H*W, B_stage, D)

        # 3. 重塑为图像特征格式
        B_stage, D = fused_features_seq.size(1), fused_features_seq.size(2)
        fused_features_map = fused_features_seq.permute(1, 2, 0).view(B_stage, D, H, W)

        return fused_features_map, context_sequence

    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        current_local_context_vision_feats,
        current_local_context_vision_relative_pos,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs, "pos_logits": None, "pos_target": None}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None

        # fused the visual feature with previous memory features in the memory bank
        single_pix_feat = self._prepare_memory_conditioned_features(
            frame_idx=frame_idx,
            is_init_cond_frame=is_init_cond_frame,
            current_vision_feats=current_vision_feats[-1:],
            current_vision_pos_embeds=current_vision_pos_embeds[-1:],
            feat_sizes=feat_sizes[-1:],
            output_dict=output_dict,
            num_frames=num_frames,
            track_in_reverse=track_in_reverse,
        )

        # fusion feature on spatial feature and their relative pos
        fused_pix_feat = single_pix_feat
        if current_local_context_vision_feats != [] and current_local_context_vision_relative_pos != []:
            spatial_context_feat, context_sequence = self._fuse_spatial_context(
                local_context_features=current_local_context_vision_feats,
                relative_positions=current_local_context_vision_relative_pos,
                feat_sizes=feat_sizes,
            )
            # --- 融合模块 ---
            fused_pix_feat = single_pix_feat + spatial_context_feat
            # logging.info(f"single_pix_feat:{single_pix_feat.shape}")
            # logging.info(f"spatial_context_feat:{spatial_context_feat.shape}")

        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                fused_pix_feat,    # using new fused feature here
                high_res_features,
                mask_inputs
            )
        else:
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=fused_pix_feat,   # using new fused feature here
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )

        # --- 辅助任务: 计算位置预测损失 ---
        # 只有在训练时才计算损失
        if self.training:
            H, W = feat_sizes[-1]

            # 2. 调用预测头，输入是原始的上下文序列
            pos_logits = self.position_prediction_head(
                context_feature_sequence=context_sequence,
                H=H,
                W=W
            )

            self.window_radius_k = 2
            targets = current_local_context_vision_relative_pos + self.window_radius_k
            B_stage, Win, C = pos_logits.shape

            current_out['pos_logits'] = pos_logits
            current_out['pos_target'] = targets

            # loss_pos = self.position_loss_fn(
            #     pos_logits.view(B_stage * Win, C),  # (B*Win, C)
            #     targets.view(B_stage * Win).long()  # (B*Win,)
            # )

            # self._visualize_predictions(pos_logits, targets, frame_idx)
            # logging.info(f"Frame {frame_idx}, Auxiliary Relative Position Loss: {loss_pos.item():.4f}")
            # exit(0)

        return current_out, sam_outputs, high_res_features, fused_pix_feat

    def _visualize_predictions(
        self,
        pos_logits: torch.Tensor,
        targets: torch.Tensor,
        frame_idx: int,
        max_items_to_show: int = 4
    ):
        """
        Prints a formatted table to visualize the predictions of the RelativePositionHead
        against the ground truth targets.

        Args:
            pos_logits (torch.Tensor): Raw logits from the prediction head.
                                    Shape: (B_stage, Win, num_classes).
            targets (torch.Tensor): Ground truth class indices. Shape: (B_stage, Win).
            frame_idx (int): The index of the current central frame being processed.
            max_items_to_show (int): Maximum number of batch items to visualize to avoid
                                    spamming the console.
        """
        # Detach tensors from the computation graph and move to CPU for processing
        pos_logits = pos_logits.detach().cpu()
        targets = targets.detach().cpu()

        # Calculate predicted indices and confidence
        probs = F.softmax(pos_logits, dim=-1)
        confidences, predicted_indices = torch.max(probs, dim=-1)

        B_stage, Win = targets.shape
        num_to_show = min(B_stage, max_items_to_show)

        logging.info("=" * 80)
        logging.info(f"VISUALIZING RELATIVE POSITION PREDICTIONS (Frame: {frame_idx})")
        logging.info(f"Showing first {num_to_show} of {B_stage} items in the batch.")
        logging.info("=" * 80)

        for i in range(num_to_show):
            logging.info(f"--- Batch Item {i+1} / {B_stage} ---")
            header = (
                f"{'Slice in Window':<18} | {'Ground Truth':<15} | {'Prediction':<15} | "
                f"{'Confidence':<12} | {'Correct?':<10}"
            )
            logging.info(header)
            logging.info("-" * len(header))

            for j in range(Win):
                gt_index = targets[i, j].item()
                pred_index = predicted_indices[i, j].item()
                confidence = confidences[i, j].item()

                # Convert class indices back to relative positions [-k, +k]
                gt_pos = gt_index - self.window_radius_k
                pred_pos = pred_index - self.window_radius_k

                is_correct = (gt_index == pred_index)
                correct_symbol = "✅" if is_correct else "❌"

                log_line = (
                    f"Slice {j+1:<12d} | Pos: {gt_pos:<+10d} | Pos: {pred_pos:<+10d} | "
                    f"{confidence:^12.2%} | {correct_symbol:^10}"
                )
                logging.info(log_line)
            logging.info("-" * 80)
