# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.distributed
import torch.nn.functional as F

from torch.nn.init import trunc_normal_

from sam2.modeling.sam.mask_decoder import MaskDecoder
from sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.modeling.sam.transformer import TwoWayTransformer
from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
from sam2.utils.wb import activation_hook
import time
import math
from typing import Callable, Tuple

# A large negative value used as a placeholder score for missing objects.
# `SAM2Base._forward_sam_heads` writes it into the low-res mask logits of every
# batch entry whose object score logit is not positive (only when
# `pred_obj_scores` is enabled).
NO_OBJ_SCORE = -1024.0


def generate_probability_mask(A, seed=None):
    """Sample a boolean mask from a tensor of per-element probabilities.

    An entry of the result is True when a fresh uniform sample in [0, 1) is
    strictly smaller than the matching entry of ``A``.

    Args:
        A: floating-point tensor of per-element probabilities.
        seed: optional seed. When given, the *global* CPU and CUDA random
            generators are re-seeded with it before sampling, which also
            affects any random call made afterwards elsewhere in the process.

    Returns:
        Boolean tensor with the same shape and device as ``A``.
    """
    if seed is not None:
        # Re-seed both global generators so the draw below is reproducible.
        for reseed in (torch.manual_seed, torch.cuda.manual_seed):
            reseed(seed)

    uniform_draw = torch.rand_like(A, device=A.device)
    return uniform_draw < A



def merge_boolean_masks(Am_bool, Bm_bool):
    """Compose two nested boolean masks into one mask over the original axis.

    ``Bm_bool`` is defined over the elements that ``Am_bool`` selected: its
    k-th entry refers to the k-th True position of ``Am_bool``. The result is
    True exactly at the original positions that survive both selections.

    Args:
        Am_bool: boolean mask of shape [N].
        Bm_bool: boolean mask of shape [N_1], indexed over the True entries
            of ``Am_bool``.

    Returns:
        Boolean mask of shape [N] on the same device as ``Am_bool``.
    """
    # Original positions kept by the first mask.
    kept_by_a = torch.where(Am_bool)[0]
    # Positions (within that kept subset) that the second mask also keeps.
    kept_by_b = torch.where(Bm_bool)[0]

    merged = torch.zeros(
        Am_bool.shape[0], dtype=torch.bool, device=Am_bool.device
    )
    # Map the second-stage survivors back to indices of the original axis.
    merged[kept_by_a[kept_by_b]] = True
    return merged


def create_Mem_SP_mask(A, theta):
    """Build a keep-mask that drops the second row of each highly similar pair.

    Rows are grouped into consecutive pairs (0, 1), (2, 3), ... and compared by
    cosine similarity. When the similarity of a pair is strictly greater than
    ``theta`` the odd-indexed row of that pair is marked False. Even-indexed
    rows, and the trailing row when ``N`` is odd, are always kept.

    Args:
        A: tensor whose first dimension indexes the N rows; rows are
            L2-normalized along dim 1, so a [N, C] layout is expected.
        theta: similarity threshold.

    Returns:
        Boolean mask of shape [N] on the same device as ``A`` (True = keep).
        The mask is boolean for every ``N``, including ``N <= 1``, so it can
        always be used for boolean indexing.
    """
    N = A.shape[0]
    keep = torch.ones(N, device=A.device, dtype=torch.bool)

    num_pairs = N // 2
    if num_pairs == 0:
        # Zero or one row: there is no pair to compare, keep everything.
        return keep

    A_norm = F.normalize(A, p=2, dim=1)
    first_of_pair = A_norm[0::2][:num_pairs]
    second_of_pair = A_norm[1::2][:num_pairs]
    similarities = (first_of_pair * second_of_pair).sum(1)

    # Pair p covers rows 2p and 2p + 1; drop the latter when too similar.
    redundant_pairs = torch.where(similarities > theta)[0]
    keep[redundant_pairs * 2 + 1] = False
    return keep



def do_nothing(x, mode=None):
    """Identity placeholder: return ``x`` untouched and ignore ``mode``.

    It accepts the same call shape as a merge function (``x`` plus an optional
    ``mode``) so it can stand in where no merging should happen.
    """
    return x


def get_bi_matching_ids(
    metric: torch.Tensor,
    r_ratio: float,
):
    """Compute bipartite-matching indices for merging a fraction of tokens.

    Tokens at even positions are candidate sources, tokens at odd positions
    are candidate destinations. Every source is paired with its most
    cosine-similar destination, and the ``r`` sources with the highest
    similarity are selected for merging, where
    ``r = min(int(N * r_ratio), N // 2)``.

    Args:
        metric: [N, C] tensor of token features used for the similarity.
        r_ratio: fraction of the N tokens to merge.

    Returns:
        A dict with keys:
          - 'unm_idx': indices of unmerged tokens. When nothing is merged this
            is ``arange(N)``; otherwise it lists only the even-position tokens
            that were not selected (odd positions are not listed).
          - 'src_idx': even-position indices of the tokens to merge.
          - 'dst_idx': odd-position indices each source merges into.
          - 'tensor_size': N.
          - 'merge_count': number of merged tokens (r, or 0).
    """
    N, C = metric.shape
    dev = metric.device
    merge_count = min(int(N * r_ratio), N // 2)

    if merge_count <= 0:
        # Nothing to merge: report every token as unmerged.
        empty_ids = dict(dtype=torch.long, device=dev)
        return {
            'unm_idx': torch.arange(N, device=dev),
            'src_idx': torch.empty(0, **empty_ids),
            'dst_idx': torch.empty(0, **empty_ids),
            'tensor_size': N,
            'merge_count': 0
        }

    with torch.no_grad():
        unit = metric / metric.norm(dim=-1, keepdim=True)
        even_tokens = unit[::2, :]
        odd_tokens = unit[1::2, :]
        similarity = even_tokens @ odd_tokens.transpose(-1, -2)
        # Zero-norm rows produce NaN; make sure they are never picked as best.
        similarity[similarity.isnan()] = float('-inf')

        best_score, best_partner = similarity.max(dim=-1)
        ranking = best_score.argsort(dim=-1, descending=True)
        merged_rows = ranking[:merge_count]
        kept_rows = ranking[merge_count:]

        # Row i of the even set is token 2i; column j of the odd set is 2j + 1.
        return {
            'unm_idx': kept_rows * 2,
            'src_idx': merged_rows * 2,
            'dst_idx': best_partner[merged_rows] * 2 + 1,
            'tensor_size': N,
            'merge_count': merge_count
        }

def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
) -> Callable:
    """Build a ToMe-style merge function from a balanced (50% / 50%) matching.

    Tokens at even positions of ``metric`` form the source set and tokens at
    odd positions form the destination set. Every source token is paired with
    its most cosine-similar destination token, and the ``r`` sources with the
    highest similarity are merged into their partners.

    Args:
        metric: [tokens, channels] tensor used to measure token similarity.
        r: number of tokens to remove. It is clamped to at most half of the
            tokens, since only source tokens can be merged away.

    Returns:
        A single callable ``merge(x, mode="mean")``. ``x`` must be a
        [tokens, channels] tensor laid out like ``metric``; the result has
        ``tokens - r`` rows: first the unmerged source tokens (ordered by
        decreasing match similarity), then all destination tokens in their
        original order, with merged sources folded in via
        ``Tensor.scatter_reduce(..., reduce=mode)``.
        When ``r <= 0`` after clamping, the identity function ``do_nothing``
        is returned instead, so the result can always be called the same way.
    """
    num_tokens, _ = metric.shape
    # We can only reduce by a maximum of 50% tokens
    r = min(r, num_tokens // 2)

    if r <= 0:
        # Return one callable here too, matching the normal path below.
        return do_nothing

    with torch.no_grad():
        metric_ = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric_[::2, :], metric_[1::2, :]
        scores = a @ b.transpose(-1, -2)
        # Zero-norm rows produce NaN; make sure they are never the best match.
        scores[scores.isnan()] = float('-inf')

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Merge the selected source rows of ``x`` into their destinations."""
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=0)

    return merge

def create_mask_rand(shape, false_ratio, device='cuda'):
    """Create a random boolean mask by thresholding uniform noise.

    Each entry is drawn independently from U[0, 1) and is True when the draw
    is greater than or equal to ``false_ratio``, so roughly a ``false_ratio``
    fraction of the entries end up False.

    Args:
        shape: shape of the mask to create.
        false_ratio: threshold applied to the uniform draws.
        device: device on which the mask is created (defaults to 'cuda').

    Returns:
        Boolean tensor of the requested shape.
    """
    noise = torch.rand(shape, device=device)
    return torch.ge(noise, false_ratio)

def create_mask_uniform(N, sample_ratio, device='cuda'):
    """Create a length-N boolean mask whose True entries are evenly spaced.

    Args:
        N (int): length of the mask.
        sample_ratio (float): fraction of positions to select; values outside
            [0, 1] are clamped into that range.
        device (str): device on which the mask is created (default 'cuda').

    Returns:
        torch.Tensor: boolean tensor of shape [N]. It is all False when
        ``int(N * sample_ratio)`` is 0 and all True when that count reaches N;
        otherwise the selected positions are the rounded points of an evenly
        spaced grid from 0 to N - 1.
    """
    # Clamp the ratio into [0, 1] before converting it to a sample count.
    clamped_ratio = min(1.0, sample_ratio)
    clamped_ratio = max(0.0, clamped_ratio)
    wanted = int(N * clamped_ratio)

    if wanted == 0:
        return torch.zeros(N, dtype=torch.bool, device=device)
    if wanted >= N:
        return torch.ones(N, dtype=torch.bool, device=device)

    # Evenly spaced positions over [0, N - 1], rounded to integer indices.
    picks = torch.linspace(0, N - 1, wanted, device=device).round().long()
    # Drop duplicates that rounding may have produced.
    picks = torch.unique(picks)

    selection = torch.zeros(N, dtype=torch.bool, device=device)
    selection[picks] = True
    return selection

def get_topk_mask(A, k):
    """Return a boolean mask marking the k largest entries of a 1-D tensor.

    Args:
        A (torch.Tensor): importance scores of shape [N].
        k (int): number of entries to select; must satisfy ``0 <= k <= N``.

    Returns:
        torch.Tensor: boolean tensor of shape [N] on the same device as ``A``,
        True at the positions of the k highest scores.

    Raises:
        AssertionError: if ``A`` is not 1-D or ``k`` is outside ``[0, N]``.
    """
    assert A.dim() == 1
    N = A.size(0)
    assert 0 <= k <= N

    winners = torch.topk(A, k, largest=True).indices

    selection = torch.zeros(N, dtype=torch.bool, device=A.device)
    selection[winners] = True
    return selection

def create_cumulative_mask_optimized_PF(data, N=4096, threshold=0.8):
    """Per-frame cumulative-mass selection mask.

    The last dimension of ``data`` is split into ``data.shape[-1] // N`` frames
    of ``N`` values each (e.g. a [1, 7 * 4096] input becomes 7 frames). Within
    each frame the values are visited from largest to smallest and an entry is
    selected while the running sum stays less than or equal to ``threshold``
    times the frame total.

    Args:
        data: tensor whose element count equals ``(data.shape[-1] // N) * N``.
        N: number of values per frame.
        threshold: fraction of each frame's total used as the cut-off.

    Returns:
        Flat boolean tensor with one entry per input value, in input order.
    """
    num_frames = data.shape[-1] // N
    per_frame = data.reshape(num_frames, N)

    # Rank the values of each frame from largest to smallest.
    ranked_vals, ranked_pos = per_frame.sort(dim=-1, descending=True)

    # Keep ranked entries while the running sum stays within the budget.
    budget = threshold * per_frame.sum(dim=-1, keepdim=True)
    keep_ranked = ranked_vals.cumsum(dim=-1) <= budget

    # Send every ranked decision back to the position it came from.
    selection = torch.zeros_like(per_frame, dtype=torch.bool)
    selection.scatter_(-1, ranked_pos, keep_ranked)

    return selection.reshape(-1)


class SAM2Base(torch.nn.Module):
    def __init__(
        self,
        image_encoder,
        memory_attention,
        memory_encoder,
        num_maskmem=7,  # default 1 input frame + 6 previous frames
        # image side length; together with `backbone_stride` it sets the SAM image embedding size,
        # and it is the target size when upsampling low-res masks in `_forward_sam_heads`
        image_size=512,
        backbone_stride=16,  # stride of the image backbone output
        sigmoid_scale_for_mem_enc=1.0,  # scale factor for mask sigmoid prob
        sigmoid_bias_for_mem_enc=0.0,  # bias factor for mask sigmoid prob
        # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
        binarize_mask_from_pts_for_mem_enc=False,
        use_mask_input_as_output_without_sam=False,  # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
        # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
        # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
        # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
        max_cond_frames_in_attn=-1,
        # on the first frame, whether to directly add the no-memory embedding to the image feature
        # (instead of using the transformer encoder)
        directly_add_no_mem_embed=False,
        # whether to use high-resolution feature maps in the SAM mask decoder
        use_high_res_features_in_sam=False,
        # whether to output multiple (3) masks for the first click on initial conditioning frames
        multimask_output_in_sam=False,
        # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
        # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
        multimask_min_pt_num=1,
        multimask_max_pt_num=1,
        # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
        multimask_output_for_tracking=False,
        # Whether to use multimask tokens for obj ptr; Only relevant when both
        # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
        use_multimask_token_for_obj_ptr: bool = False,
        # whether to use sigmoid to restrict ious prediction to [0-1]
        iou_prediction_use_sigmoid=False,
        # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
        # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
        # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
        memory_temporal_stride_for_eval=1,
        # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
        non_overlap_masks_for_mem_enc=False,
        # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
        use_obj_ptrs_in_encoder=False,
        # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
        max_obj_ptrs_in_encoder=16,
        # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
        add_tpos_enc_to_obj_ptrs=True,
        # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
        # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
        proj_tpos_enc_in_obj_ptrs=False,
        # whether to use signed distance (instead of unsigned absolute distance) in the temporal positional encoding in the object pointers
        # (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
        use_signed_tpos_enc_to_obj_ptrs=False,
        # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
        # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
        only_obj_ptrs_in_the_past_for_eval=False,
        # Whether to predict if there is an object in the frame
        pred_obj_scores: bool = False,
        # Whether to use an MLP to predict object scores
        pred_obj_scores_mlp: bool = False,
        # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
        # Whether to have a fixed no obj pointer when there is no object present
        # or to use it as an additive embedding with obj_ptr produced by decoder
        fixed_no_obj_ptr: bool = False,
        # Soft no object, i.e. mix in no_obj_ptr softly,
        # hope to make recovery easier if there is a mistake and mitigate accumulation of errors
        soft_no_obj_ptr: bool = False,
        # whether the object pointer projection is a 3-layer MLP instead of a single linear layer (see `_build_sam_heads`)
        use_mlp_for_obj_ptr_proj: bool = False,
        # add no obj embedding to spatial frames
        no_obj_embed_spatial: bool = False,
        # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
        sam_mask_decoder_extra_args=None,
        # whether to wrap `image_encoder.forward` with `torch.compile` at the end of construction
        compile_image_encoder: bool = False,
        # stored as `self.num_frame_to_prune` and not otherwise used in this constructor;
        # presumably the number of memory frames subject to pruning -- NOTE(review): confirm against the code that reads it
        num_frame_to_prune: int = 2,
    ):
        """
        Store the sub-networks and configuration flags, create the learned
        embeddings, and build the SAM prompt encoder / mask decoder heads.

        Besides copying the arguments onto the instance, the constructor:
        - reads `hidden_dim` from `image_encoder.neck.d_model`, and sets `mem_dim` to
          `memory_encoder.out_proj.weight.shape[0]` when that projection exists
          (otherwise `mem_dim` equals `hidden_dim`);
        - registers `maskmem_tpos_enc`, `no_mem_embed` and `no_mem_pos_enc` as parameters
          (truncated-normal init, std=0.02), plus `no_obj_ptr` when both `pred_obj_scores`
          and `use_obj_ptrs_in_encoder` are set, and `no_obj_embed_spatial` when requested;
        - creates the `mask_downsample` conv when `use_obj_ptrs_in_encoder` is set;
        - calls `_build_sam_heads()`;
        - optionally replaces `image_encoder.forward` with a `torch.compile`d version.

        Raises an AssertionError when `proj_tpos_enc_in_obj_ptrs` is set without
        `add_tpos_enc_to_obj_ptrs`, or when `fixed_no_obj_ptr` is set without both
        `pred_obj_scores` and `use_obj_ptrs_in_encoder`.
        """
        super().__init__()

        # Part 1: the image backbone
        self.image_encoder = image_encoder
        # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
        self.use_high_res_features_in_sam = use_high_res_features_in_sam
        self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
        self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
        self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
        if use_obj_ptrs_in_encoder:
            # A conv layer to downsample the mask prompt to stride 4 (the same stride as
            # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
            # so that it can be fed into the SAM mask decoder to generate a pointer.
            self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
        self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
        if proj_tpos_enc_in_obj_ptrs:
            assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
        self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
        self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval

        # Part 2: memory attention to condition current frame's visual features
        # with memories (and obj ptrs) from past frames
        self.memory_attention = memory_attention
        self.hidden_dim = image_encoder.neck.d_model

        # Part 3: memory encoder for the previous frame's outputs
        self.memory_encoder = memory_encoder
        self.mem_dim = self.hidden_dim
        if hasattr(self.memory_encoder, "out_proj") and hasattr(
            self.memory_encoder.out_proj, "weight"
        ):
            # if there is compression of memories along channel dim
            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]

        self.num_frame_to_prune = num_frame_to_prune
        self.num_maskmem = num_maskmem  # Number of memories accessible
        # Temporal encoding of the memories
        self.maskmem_tpos_enc = torch.nn.Parameter(
            torch.zeros(num_maskmem, 1, 1, self.mem_dim)
        )
        trunc_normal_(self.maskmem_tpos_enc, std=0.02)
        # a single token to indicate no memory embedding from previous frames
        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
        trunc_normal_(self.no_mem_embed, std=0.02)
        trunc_normal_(self.no_mem_pos_enc, std=0.02)
        self.directly_add_no_mem_embed = directly_add_no_mem_embed
        # Apply sigmoid to the output raw mask logits (to turn them from
        # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
        self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
        self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
        self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
        self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
        self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
        # On frames with mask input, whether to directly output the input mask without
        # using a SAM prompt encoder + mask decoder
        self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
        self.multimask_output_in_sam = multimask_output_in_sam
        self.multimask_min_pt_num = multimask_min_pt_num
        self.multimask_max_pt_num = multimask_max_pt_num
        self.multimask_output_for_tracking = multimask_output_for_tracking
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
        self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid

        # Part 4: SAM-style prompt encoder (for both mask and point inputs)
        # and SAM-style mask decoder for the final mask output
        self.image_size = image_size
        self.backbone_stride = backbone_stride
        self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
        self.pred_obj_scores = pred_obj_scores
        self.pred_obj_scores_mlp = pred_obj_scores_mlp
        self.fixed_no_obj_ptr = fixed_no_obj_ptr
        self.soft_no_obj_ptr = soft_no_obj_ptr
        if self.fixed_no_obj_ptr:
            assert self.pred_obj_scores
            assert self.use_obj_ptrs_in_encoder
        # NOTE: `no_obj_ptr` only exists when BOTH flags are set, while
        # `_forward_sam_heads` / `_use_mask_as_output` read it whenever
        # `pred_obj_scores` is set.
        if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
            self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
            trunc_normal_(self.no_obj_ptr, std=0.02)
        self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
        self.no_obj_embed_spatial = None
        if no_obj_embed_spatial:
            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
            trunc_normal_(self.no_obj_embed_spatial, std=0.02)

        # The heads read hidden_dim, image_size, backbone_stride and the flags set above,
        # so they must be built after those attributes exist.
        self._build_sam_heads()
        self.max_cond_frames_in_attn = max_cond_frames_in_attn

        # Model compilation
        if compile_image_encoder:
            # Compile the forward function (not the full module) to allow loading checkpoints.
            print(
                "Image encoder compilation is enabled. First forward pass will be slow."
            )
            self.image_encoder.forward = torch.compile(
                self.image_encoder.forward,
                mode="max-autotune",
                fullgraph=True,
                dynamic=False,
            )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuning"
            "See notebooks/video_predictor_example.ipynb for an inference example."
        )

    def _build_sam_heads(self):
        """
        Build SAM-style prompt encoder and mask decoder.

        Creates and stores on the instance:
        - `sam_prompt_embed_dim` (equal to `hidden_dim`) and `sam_image_embedding_size`
          (`image_size // backbone_stride`);
        - `sam_prompt_encoder` and `sam_mask_decoder`;
        - `obj_ptr_proj`: a linear layer (or a 3-layer MLP when `use_mlp_for_obj_ptr_proj`
          is set) if `use_obj_ptrs_in_encoder` is set, otherwise an identity;
        - `obj_ptr_tpos_proj`: a linear layer from `hidden_dim` to `mem_dim` if
          `proj_tpos_enc_in_obj_ptrs` is set, otherwise an identity.

        Must be called after the attributes it reads have been assigned in `__init__`.
        """
        self.sam_prompt_embed_dim = self.hidden_dim
        # Side length of the (square) feature map that the SAM heads operate on.
        self.sam_image_embedding_size = self.image_size // self.backbone_stride

        # build PromptEncoder and MaskDecoder from SAM
        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.sam_mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            # caller-supplied kwargs come last in the call
            **(self.sam_mask_decoder_extra_args or {}),
        )
        if self.use_obj_ptrs_in_encoder:
            # a linear projection on SAM output tokens to turn them into object pointers
            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
            if self.use_mlp_for_obj_ptr_proj:
                # Replaces the linear layer created just above.
                # NOTE(review): the discarded Linear is still constructed first, so it
                # presumably consumes RNG draws during its init; skipping it would likely
                # change seeded initialization of later modules -- confirm before simplifying.
                self.obj_ptr_proj = MLP(
                    self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
                )
        else:
            self.obj_ptr_proj = torch.nn.Identity()
        if self.proj_tpos_enc_in_obj_ptrs:
            # a linear projection on temporal positional encoding in object pointers to
            # avoid potential interference with spatial positional encoding
            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
        else:
            self.obj_ptr_tpos_proj = torch.nn.Identity()

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
    ):
        """
        Forward SAM prompt encoders and mask heads.

        Inputs:
        - backbone_features: image features of [B, C, H, W] shape
        - point_inputs: a dictionary with "point_coords" and "point_labels", where
          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
             absolute pixel-unit coordinate in (x, y) format of the P input points
          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
             positive clicks, 0 means negative clicks, and -1 means padding
        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
          same spatial size as the image.
        - high_res_features: either 1) None or 2) or a list of length 2 containing
          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
          which will be used as high-resolution feature maps for SAM decoder.
        - multimask_output: if it's True, we output 3 candidate masks and their 3
          corresponding IoU estimates, and if it's False, we output only 1 mask and
          its corresponding IoU estimate.

        Outputs:
        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
          output mask logits (before sigmoid) for the low-resolution masks, with 4x
          the resolution (1/4 stride) of the input backbone_features.
        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
          if `multimask_output=True` and M = 1 if `multimask_output=False`),
          upsampled from the low-resolution masks, with shape size as the image
          (stride is 1 pixel).
        - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
          if `multimask_output=False`), the estimated IoU of each output mask.
        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `low_res_multimasks`.
        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
          If `multimask_output=True`, it's the mask with the highest IoU estimate.
          If `multimask_output=False`, it's the same as `high_res_multimasks`.
        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
          based on the output token from the SAM mask decoder.
        """
        B = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # a) Handle point prompts
        if point_inputs is not None:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
        else:
            # If no points are provide, pad with an empty point (with label -1)
            sam_point_coords = torch.zeros(B, 1, 2, device=device)
            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)

        # b) Handle mask prompts
        if mask_inputs is not None:
            # If mask_inputs is provided, downsize it into low-res mask input if needed
            # and feed it as a dense mask prompt into the SAM mask encoder
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=self.sam_prompt_encoder.mask_input_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,  # use antialias for downsampling
                )
            else:
                sam_mask_prompt = mask_inputs
        else:
            # Otherwise, simply feed None (and SAM's prompt encoder will add
            # a learned `no_mask_embed` to indicate no mask input in this case).
            sam_mask_prompt = None

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )

        
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,  # the image is already batched
            high_res_features=high_res_features,
        )
        # print(low_res_multimasks.shape)
        # exit()
        
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
            # consistent with the actual mask prediction
            low_res_multimasks = torch.where(
                is_obj_appearing[:, None, None],
                low_res_multimasks,
                NO_OBJ_SCORE,
            )

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if multimask_output:
            # take the best mask prediction (with the highest IoU estimation)
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(B, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
        else:
            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks

        # Extract object pointer from the SAM output token (with occlusion handling)
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Allow *soft* no obj ptr, unlike for masks
            if self.soft_no_obj_ptr:
                lambda_is_obj_appearing = object_score_logits.sigmoid()
            else:
                lambda_is_obj_appearing = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        # print('ios:', ious)   # torch.Size([1, 3])
        # print('object_score_logits:', object_score_logits)    # torch.Size([1, 1])
        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
        """
        Directly turn binary `mask_inputs` into a output mask logits without using SAM.
        (same input and output shapes as in _forward_sam_heads above).
        """
        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
        mask_inputs_float = mask_inputs.float()
        high_res_masks = mask_inputs_float * out_scale + out_bias
        low_res_masks = F.interpolate(
            high_res_masks,
            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
            align_corners=False,
            mode="bilinear",
            antialias=True,  # use antialias for downsampling
        )
        # a dummy IoU prediction of all 1's under mask input
        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
        if not self.use_obj_ptrs_in_encoder:
            # all zeros as a dummy object pointer (of shape [B, C])
            obj_ptr = torch.zeros(
                mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
            )
        else:
            # produce an object pointer using the SAM decoder from the mask input
            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
                backbone_features=backbone_features,
                mask_inputs=self.mask_downsample(mask_inputs_float),
                high_res_features=high_res_features,
            )
        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
        # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
        # on the object_scores from the SAM decoder.
        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
        is_obj_appearing = is_obj_appearing[..., None]
        lambda_is_obj_appearing = is_obj_appearing.float()
        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
        if self.pred_obj_scores:
            if self.fixed_no_obj_ptr:
                obj_ptr = lambda_is_obj_appearing * obj_ptr
            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr

        return (
            low_res_masks,
            high_res_masks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )

    def forward_image(self, img_batch: torch.Tensor):
        """Run the image encoder on `img_batch` and return its output dict.

        When `self.use_high_res_features_in_sam` is set, the level-0 and level-1
        entries of `backbone_out["backbone_fpn"]` are replaced in place by their
        `conv_s0` / `conv_s1` projections from the SAM mask decoder, so the
        projection is not recomputed on every SAM click.
        """
        backbone_out = self.image_encoder(img_batch)
        if not self.use_high_res_features_in_sam:
            return backbone_out

        fpn_levels = backbone_out["backbone_fpn"]
        decoder = self.sam_mask_decoder
        fpn_levels[0] = decoder.conv_s0(fpn_levels[0])
        fpn_levels[1] = decoder.conv_s1(fpn_levels[1])
        return backbone_out

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features."""
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]

        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def _prepare_memory_conditioned_features(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
    ):
        """Fuse the current frame's visual feature map with previous memory."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        device = current_vision_feats[-1].device
        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
        # In this case, we skip the fusion with any memory.
        if self.num_maskmem == 0:  # Disable memory and skip fusion
            pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
            return pix_feat

        num_obj_ptr_tokens = 0
        tpos_sign_mul = -1 if track_in_reverse else 1
        # Step 1: condition the visual features of the current frame on previous memories
        if not is_init_cond_frame:
            # Retrieve the memories encoded with the maskmem backbone
            to_cat_memory, to_cat_memory_pos_embed = [], []
            # Add conditioning frames's output first (all cond frames have t_pos=0 for
            # when getting temporal positional embedding below)
            assert len(output_dict["cond_frame_outputs"]) > 0
            # Select a maximum number of temporally closest cond frames for cross attention
            cond_outputs = output_dict["cond_frame_outputs"]
            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
                frame_idx, cond_outputs, self.max_cond_frames_in_attn
            )
            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
            # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
            # We also allow taking the memory frame non-consecutively (with stride>1), in which case
            # we take (self.num_maskmem - 2) frames among every stride-th frames plus the last frame.
            stride = 1 if self.training else self.memory_temporal_stride_for_eval
            for t_pos in range(1, self.num_maskmem):
                t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                if t_rel == 1:
                    # for t_rel == 1, we take the last frame (regardless of r)
                    if not track_in_reverse:
                        # the frame immediately before this frame (i.e. frame_idx - 1)
                        prev_frame_idx = frame_idx - t_rel
                    else:
                        # the frame immediately after this frame (i.e. frame_idx + 1)
                        prev_frame_idx = frame_idx + t_rel
                else:
                    # for t_rel >= 2, we take the memory frame from every r-th frames
                    if not track_in_reverse:
                        # first find the nearest frame among every r-th frames before this frame
                        # for r=1, this would be (frame_idx - 2)
                        prev_frame_idx = ((frame_idx - 2) // stride) * stride
                        # then seek further among every r-th frames
                        prev_frame_idx = prev_frame_idx - (t_rel - 2) * stride
                    else:
                        # first find the nearest frame among every r-th frames after this frame
                        # for r=1, this would be (frame_idx + 2)
                        prev_frame_idx = -(-(frame_idx + 2) // stride) * stride
                        # then seek further among every r-th frames
                        prev_frame_idx = prev_frame_idx + (t_rel - 2) * stride
                # print(prev_frame_idx)
                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
                if out is None:
                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
                    # frames, we still attend to it as if it's a non-conditioning frame.
                    out = unselected_cond_outputs.get(prev_frame_idx, None)
                t_pos_and_prevs.append((t_pos, out))
            # print([t[0] for t in t_pos_and_prevs])
            for t_pos, prev in t_pos_and_prevs:
                if prev is None:
                    continue  # skip padding frames
                # "maskmem_features" might have been offloaded to CPU in demo use cases,
                # so we load it back to GPU (it's a no-op if it's already on GPU).
                # print(t_pos, self.num_maskmem - t_pos - 1)
                # print(type(self.maskmem_tpos_enc))
                # print(self.maskmem_tpos_enc.data.shape)
                feats = prev["maskmem_features"].to(device, non_blocking=True)
                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
                # Spatial positional encoding (it might have been offloaded to CPU in eval)
                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
                # Temporal positional encoding
                # if t_pos == 0:
                #     maskmem_enc = (
                #         maskmem_enc + self.maskmem_tpos_enc[6]
                #     )
                # else:
                maskmem_enc = (
                    maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
                )
                to_cat_memory_pos_embed.append(maskmem_enc)

            # remove self.num_frame_to_prune=2 memory frames most similar with the last frame
            if self.Mem_Frame_Prune and (len(to_cat_memory) + self.num_frame_to_prune) > self.num_maskmem:
                # To constrain the memory frame number to self.num_maskmem - self.num_frame_to_prune
                this_num_frame_to_prune = len(to_cat_memory) + self.num_frame_to_prune - self.num_maskmem

                # get the last frame's feature
                last_vision_feature = to_cat_memory[-1]
                repeat_dims = [len(to_cat_memory) - 2] + [1] * (len(last_vision_feature.shape))
                last_frame_expanded = last_vision_feature.unsqueeze(0).repeat(*repeat_dims)  # [num_frames-2, 1024, bs, 64]
                # get the candidate frames
                candidate_frames = torch.stack(to_cat_memory[1:-1], dim=0)  # [num_frames-2, 1024, bs, 64]

                # compute the cosine similarity between the candidate frames and the last frame
                similarities = torch.cosine_similarity(candidate_frames, last_frame_expanded, dim=-1)  # [num_frames-2, 1024, bs]
                similarities = torch.sum(similarities, dim=1)  # [num_frames-2, bs]
                # [num_frames-2, bs] -> [num_frames-2], to support batch in training, we use mean instead of squeeze
                similarities = similarities.mean(dim=1)

                # the ranking is from large to small
                _, sorted_indices = torch.sort(similarities, descending=True)
                # we delete the frames with the largest cosine similarity, which are the most similar frames
                delete_indices = sorted_indices[:this_num_frame_to_prune].cpu().numpy() + 1  # +1 because we exclude the first frame
                # delete 'self.num_frame_to_prune' frames from the end to the beginning, or it might cause index error
                delete_indices = sorted(delete_indices, reverse=True)
                # print(delete_indices)
                if self.enable_MeP_info:
                    self.mem_info['frame_delete_indices'] = delete_indices
                    if self.mem_info['enable_mem_prune']:
                        frame_mask = torch.ones(5, dtype=torch.bool, device='cuda')
                        for j in self.mem_info['frame_delete_indices']:
                            frame_mask[j-1]=False
                        self.mem_info['frame_mask'] = frame_mask
                # delete the frames
                for i in delete_indices:
                    to_cat_memory.pop(i)
                    to_cat_memory_pos_embed.pop(i)

            # Construct the list of past object pointers
            if self.use_obj_ptrs_in_encoder:
                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
                # First add those object pointers from selected conditioning frames
                # (optionally, only include object pointers in the past during evaluation)
                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
                    ptr_cond_outputs = {
                        t: out
                        for t, out in selected_cond_outputs.items()
                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
                    }
                else:
                    ptr_cond_outputs = selected_cond_outputs
                pos_and_ptrs = [
                    # Temporal pos encoding contains how far away each pointer is from current frame
                    (
                        (
                            (frame_idx - t) * tpos_sign_mul
                            if self.use_signed_tpos_enc_to_obj_ptrs
                            else abs(frame_idx - t)
                        ),
                        out["obj_ptr"],
                    )
                    for t, out in ptr_cond_outputs.items()
                ]
                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
                for t_diff in range(1, max_obj_ptrs_in_encoder):
                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
                    if t < 0 or (num_frames is not None and t >= num_frames):
                        break
                    out = output_dict["non_cond_frame_outputs"].get(
                        t, unselected_cond_outputs.get(t, None)
                    )
                    if out is not None:
                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
                # If we have at least one object pointer, add them to the across attention
                if len(pos_and_ptrs) > 0:
                    pos_list, ptrs_list = zip(*pos_and_ptrs)
                    # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
                    obj_ptrs = torch.stack(ptrs_list, dim=0)
                    # a temporal positional embedding based on how far each object pointer is from
                    # the current frame (sine embedding normalized by the max pointer num).
                    if self.add_tpos_enc_to_obj_ptrs:
                        t_diff_max = max_obj_ptrs_in_encoder - 1
                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
                        obj_pos = torch.tensor(pos_list).to(
                            device=device, non_blocking=True
                        )
                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
                    else:
                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
                    if self.mem_dim < C:
                        # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
                        obj_ptrs = obj_ptrs.reshape(
                            -1, B, C // self.mem_dim, self.mem_dim
                        )
                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
                    to_cat_memory.append(obj_ptrs)
                    to_cat_memory_pos_embed.append(obj_pos)
                    num_obj_ptr_tokens = obj_ptrs.shape[0]
                else:
                    num_obj_ptr_tokens = 0
        else:
            # for initial conditioning frames, encode them without using any previous memory
            if self.directly_add_no_mem_embed:
                # directly add no-mem embedding (instead of using the transformer encoder)
                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
                return pix_feat_with_mem

            # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder)
            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
            
        

        # Step 2: Concatenate the memories and forward through the transformer encoder
        memory = torch.cat(to_cat_memory, dim=0)
        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
        
        # MeP
        '''
        if self.enable_MeP_info:
            frame_delete_indices = self.mem_info['frame_delete_indices']
            # print(frame_delete_indices)
            mem_masks_stack_obj = self.mem_info['mem_masks_stack'].get(self.mem_info['obj_idx'],None)
            # mem_SP_masks_stack_obj = self.mem_info['mem_SP_masks_stack'].get(obj_idx,None)
            if mem_masks_stack_obj != None:
                self.mem_info['mem_masks'][self.mem_info['obj_idx']] = []
                # self.mem_info['mem_SP_masks'][obj_idx]=[]
                if len(mem_masks_stack_obj) == 5:
                    for i in range(4):
                        mem_masks_obj_Li = [mask[i] for mask in mem_masks_stack_obj]
                        # print(len(mem_masks_obj_Li))
                        for j in frame_delete_indices:
                            mem_masks_obj_Li.pop(j-1)
                        # print(len(mem_masks_obj_Li))
                        self.mem_info['mem_masks'][self.mem_info['obj_idx']].append(torch.cat(mem_masks_obj_Li,dim=0))

        '''
        pix_feat_with_mem = self.memory_attention(
            curr=current_vision_feats,
            curr_pos=current_vision_pos_embeds,
            memory=memory.detach(),
            memory_pos=memory_pos_embed.detach(),
            num_obj_ptr_tokens=num_obj_ptr_tokens,
        )
        # reshape the output (HW)BC => BCHW
        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
        return pix_feat_with_mem

    def _encode_new_memory(
        self,
        current_vision_feats,
        feat_sizes,
        pred_masks_high_res,
        object_score_logits,
        is_mask_from_pts,
    ):
        """Encode the current image and its prediction into a memory feature."""
        B = current_vision_feats[-1].size(1)  # batch size on this frame
        C = self.hidden_dim
        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
        # top-level feature, (HW)BC => BCHW
        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
        if self.non_overlap_masks_for_mem_enc and not self.training:
            # optionally, apply non-overlapping constraints to the masks (it's applied
            # in the batch dimension and should only be used during eval, where all
            # the objects come from the same video under batch size 1).
            pred_masks_high_res = self._apply_non_overlapping_constraints(
                pred_masks_high_res
            )
        # scale the raw mask logits with a temperature before applying sigmoid
        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
        if binarize and not self.training:
            mask_for_mem = (pred_masks_high_res > 0).float()
        else:
            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
            mask_for_mem = torch.sigmoid(pred_masks_high_res)
        # apply scale and bias terms to the sigmoid probabilities
        if self.sigmoid_scale_for_mem_enc != 1.0:
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
        maskmem_out = self.memory_encoder(
            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
        )
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
        # add a no-object embedding to the spatial memory to indicate that the frame
        # is predicted to be occluded (i.e. no object is appearing in the frame)
        if self.no_obj_embed_spatial is not None:
            is_obj_appearing = (object_score_logits > 0).float()
            maskmem_features += (
                1 - is_obj_appearing[..., None, None]
            ) * self.no_obj_embed_spatial[..., None, None].expand(
                *maskmem_features.shape
            )

        return maskmem_features, maskmem_pos_enc

    def _track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse,
        prev_sam_mask_logits,
    ):
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
        if len(current_vision_feats) > 1:
            high_res_features = [
                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]
        else:
            high_res_features = None
        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )
            # print('use_mask_input_as_output_without_sam')
        else:
            # fused the visual feature with previous memory features in the memory bank
            torch.cuda.synchronize()
            st = time.time()
            pix_feat = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            torch.cuda.synchronize()
            ed = time.time()
            if frame_idx > 0:
                self.time_log[frame_idx]['Mem_attn'].append(ed-st)
            # apply SAM-style segmentation head
            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_inputs = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            
            torch.cuda.synchronize()
            st = time.time()
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat,
                point_inputs=point_inputs,
                mask_inputs=mask_inputs,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
            )
            torch.cuda.synchronize()
            ed = time.time()
            if frame_idx > 0:
                self.time_log[frame_idx]['MD'].append(ed-st)

        return current_out, sam_outputs, high_res_features, pix_feat

    def _encode_memory_in_output(
        self,
        current_vision_feats,
        feat_sizes,
        point_inputs,
        run_mem_encoder,
        high_res_masks,
        object_score_logits,
        current_out,
    ):
        if run_mem_encoder and self.num_maskmem > 0:
            high_res_masks_for_mem_enc = high_res_masks
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks_for_mem_enc,
                object_score_logits=object_score_logits,
                is_mask_from_pts=(point_inputs is not None),
            )
            current_out["maskmem_features"] = maskmem_features
            current_out["maskmem_pos_enc"] = maskmem_pos_enc
        else:
            current_out["maskmem_features"] = None
            current_out["maskmem_pos_enc"] = None

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
        # to skip the memory encoder with `run_mem_encoder=False`. For example,
        # in demo we might call `track_step` multiple times for each user click,
        # and only encode the memory when the user finalizes their clicks. And in ablation
        # settings like SAM training on static images, we don't need the memory encoder.
        run_mem_encoder=True,
        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
        prev_sam_mask_logits=None,
    ):
        if self.enable_MeP_info:
            if not self.disable_WB:
            # if True:
                Pt_ATM_module = [
                        self.sam_mask_decoder.transformer.layers[0].cross_attn_token_to_image, 
                        self.sam_mask_decoder.transformer.layers[1].cross_attn_token_to_image, 
                        self.sam_mask_decoder.transformer.final_attn_token_to_image
                        ]
                Pt_ATM_hooks = [activation_hook(act, True) for act in Pt_ATM_module]
            
            if self.mem_info['enable_mem_prune'] and (frame_idx-1) % self.memory_temporal_stride_for_eval == 0:
                Mem_ATM_module = [
                        self.memory_attention.layers[0].cross_attn_image, 
                        self.memory_attention.layers[1].cross_attn_image, 
                        self.memory_attention.layers[2].cross_attn_image, 
                        self.memory_attention.layers[3].cross_attn_image, 
                        ]
                Mem_ATM_hooks = [activation_hook(act, True) for act in Mem_ATM_module]
            
        current_out, sam_outputs, _, _ = self._track_step(
            frame_idx,
            is_init_cond_frame,
            current_vision_feats,
            current_vision_pos_embeds,
            feat_sizes,
            point_inputs,
            mask_inputs,
            output_dict,
            num_frames,
            track_in_reverse,
            prev_sam_mask_logits,
        )

        (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        ) = sam_outputs
        
        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
        current_out["obj_ptr"] = obj_ptr
        if not self.training:
            # Only add this in inference (to avoid unused param in activation checkpointing;
            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
            current_out["object_score_logits"] = object_score_logits
        
        best_iou_inds = torch.argmax(ious, dim=-1)
        current_out["best_iou"] = ious[0][best_iou_inds]
        
        # if frame_idx >5:
        #     print(output_dict["non_cond_frame_outputs"][1].keys())
        
        if self.enable_MeP_info:
            if not self.disable_WB:
            # if True:
                Pt_ATMs = []
                for Pt_ATM_hook in Pt_ATM_hooks:
                    # print(Pt_ATM_hook.feature[1][2].shape)
                    Pt_ATMs.append(Pt_ATM_hook.feature[1][2].squeeze(0)[:,0:7,:].mean(dim=(0,1)))
                    # Pt_ATMs.append(Pt_ATM_hook.feature[1][2].squeeze(0).mean(dim=(0,1)))
                    # Pt_ATMs.append(Pt_ATM_hook.feature[1][2].squeeze(0).mean(dim=(0,1)))
                    Pt_ATM_hook.remove()
                Pt_ATMs = torch.stack(Pt_ATMs,dim=0).mean(0).reshape(1,64,64)
                self.mem_info['pt_sen_region'].append(Pt_ATMs)
                self.mem_info['mask_region'].append(low_res_multimasks)
                
            self.mem_info['ious'].append(ious)
            # self.mem_info['mask_region'].append(low_res_multimasks)
            # self.mem_info['pt_sen_region'].append(Pt_ATMs)
            # print(object_score_logits)
            self.mem_info['obj_scores'].append(object_score_logits)
            self.mem_info['best_iou'].append(ious[0][best_iou_inds])
            
            
            '''
            # 重复计算attn
            if self.mem_info['enable_mem_prune'] and (frame_idx-1) % self.memory_temporal_stride_for_eval == 0:
                Mem_ATMs = []
                for Mem_ATM_hook in Mem_ATM_hooks:
                    q, k, v = Mem_ATM_hook.feature[1]
                    # print(k.shape)
                    # mem_mask = torch.ones(k.shape[-2], dtype=torch.bool, device=k.device)
                    # print(k.shape)
                    Mem_ATM_hook.remove()
                    
                    scale_factor = 1 / math.sqrt(q.size(-1))
                    ATM = q @ k.transpose(-2,-1) * scale_factor
                    ATM_norm = ATM.softmax(dim=-1).squeeze(0).squeeze(0)
                    if self.random_mask:
                        if self.set_drop_ratio == -1:
                            mem_mask = torch.randint(0, 2, (4096,), dtype=torch.bool, device=k.device)
                        else:
                            mem_mask = create_mask_rand(4096, self.set_drop_ratio, device=k.device)

                    else:
                        if self.VMW_mask:
                            v_L1 = torch.norm(v, p=1, dim=-1)
                            v_mag = v_L1/v_L1.mean()
                            attn_VMW = ATM_norm.mean(0)/v_mag
                            
                            mem_mask = create_cumulative_mask_optimized_PF(attn_VMW.unsqueeze(0), threshold=self.MTP_theta)   # torch.Size([28672])
                            
                        else:
                            mem_mask = create_cumulative_mask_optimized_PF(ATM_norm.mean(0, keepdim=True), threshold=self.MTP_theta)   # torch.Size([28672])
                    # print(mem_mask.shape)
                    # selected_indices = torch.nonzero(mem_mask, as_tuple=False).squeeze(-1)
                    # full_indices = torch.cat([
                    #                 torch.arange(4096, device=mem_mask.device),
                    #                 selected_indices + 4096,
                    #                 torch.arange(4096*6, 4096*7, device=mem_mask.device)
                    #             ])
                    # Mem_ATMs.append(full_indices)
                    Mem_ATMs.append(mem_mask)
                mem_masks_per_obj = self.mem_info['mem_masks_stack'].setdefault(self.mem_info['obj_idx'], [])
                if len(mem_masks_per_obj) == 5:
                    mem_masks_per_obj.pop(0)
                mem_masks_per_obj.append(Mem_ATMs) 
                # print(len(self.mem_info['mem_masks'][0]))
                    
                # print(Mem_ATMs)
            
            '''
            
            # '''
            if self.mem_info['enable_mem_prune'] and (frame_idx-1) % self.memory_temporal_stride_for_eval == 0:
                t_last_best_iou = output_dict["non_cond_frame_outputs"][frame_idx-1]['best_iou']
                t_last_obj_score = output_dict["non_cond_frame_outputs"][frame_idx-1]['object_score_logits']
                # print(t_last_best_iou, t_last_obj_score)
                if (t_last_obj_score < 0 or t_last_best_iou < 0.3) and self.MR_OA_U:
                    # Mem_ATMs = [torch.zeros(4096, dtype=torch.bool, device=t_last_best_iou.device) for i in range(4)]
                    Mem_ATMs = [self.uniform_mode for i in range(4)]
                    for Mem_ATM_hook in Mem_ATM_hooks:
                        Mem_ATM_hook.remove()
                    # set_drop_ratio = self.set_drop_ratio_up
                else:
                    if (t_last_obj_score < 0 or t_last_best_iou < 0.3) and self.MR_OA:
                        set_drop_ratio = self.set_drop_ratio_up
                    else:
                        set_drop_ratio = self.set_drop_ratio
                    
                    
                    Mem_ATMs = []
                    Mem_SP_masks = []
                    for Mem_ATM_hook in Mem_ATM_hooks:
                        ATM, v = Mem_ATM_hook.feature[1]
                        # ATM_norm = torch.softmax(ATM, dim=-1)
                        # print(attn.shape)
                        # exit()
                        # print(k.shape)
                        # mem_mask = torch.ones(k.shape[-2], dtype=torch.bool, device=k.device)
                        # print(k.shape)
                        Mem_ATM_hook.remove()
                        
                        # scale_factor = 1 / math.sqrt(q.size(-1))
                        # ATM = q @ k.transpose(-2,-1) * scale_factor
                        # ATM_norm = ATM.softmax(dim=-1).squeeze(0).squeeze(0)
                        if self.random_mask:
                            if self.set_drop_ratio == -1:
                                mem_mask = torch.randint(0, 2, (4096,), dtype=torch.bool, device=v.device)
                            else:
                                mem_mask = create_mask_rand(4096, self.set_drop_ratio, device=v.device)
                        elif self.uniform_mask:
                            mem_mask = create_mask_uniform(4096, 1-self.set_drop_ratio, device=v.device)
                        else:
                            if self.VMW_mask:
                                v_L1 = torch.norm(v, p=1, dim=-1).squeeze(0).squeeze(0)
                                v_mag = v_L1/v_L1.mean()
                                # print(ATM_norm.shape, ATM_norm.mean(0).shape, v_mag.shape, v_L1.shape)
                                ATM_norm = torch.softmax(ATM, dim=-1)
                                
                                attn_VMW = ATM_norm.mean(0)/v_mag
                                # print(attn_VMW.sum())
                                if self.topk_mask:
                                    # print(attn_VMW.shape)
                                    mem_mask = get_topk_mask(attn_VMW, int((1-self.set_drop_ratio)*4096))
                                else:
                                    mem_mask = create_cumulative_mask_optimized_PF(attn_VMW.unsqueeze(0), threshold=self.MTP_theta)   # torch.Size([28672])
                                
                            else:
                                if self.topk_mask:
                                    # print(ATM)
                                    ATM_norm = torch.softmax(ATM, dim=-1)
                                    mem_mask = get_topk_mask(ATM_norm.mean(0), int((1-set_drop_ratio)*4096))
                                else:
                                    ATM_norm = torch.softmax(ATM, dim=-1)
                                    mem_mask = create_cumulative_mask_optimized_PF(ATM_norm.mean(0, keepdim=True), threshold=self.MTP_theta)   # torch.Size([28672])
                        
                        # print(mem_mask.shape)
                        # print(mem_mask.dtype)
                        # Mem_ATMs.append(mem_mask)
                        
                        # memory merge matching
                        if self.Mem_SP:
                            v_s = v[0,0,mem_mask,:]
                            Mem_SP_mask = create_Mem_SP_mask(v_s, theta=self.Mem_SP_theta)
                            # Mem_SP_masks.append(Mem_SP_mask)
                            # print(mem_mask.sum(), Mem_SP_mask.sum())
                            mem_mask = merge_boolean_masks(mem_mask, Mem_SP_mask)
                        # print(mem_mask.sum())
                        
                        if self.MR_OA_U:
                            mem_mask = mem_mask | self.uniform_mode

                        Mem_ATMs.append(mem_mask)

                # print(sum(mem_mask), sum(Mem_SP_mask))
                # kv_merge = bipartite_soft_matching(v_s, int(v_s.shape[0]*0.2))
                # v_m = kv_merge(v_s)
                # print(mem_mask.sum(), v.shape, v_s.shape, v_m.shape)
                # ---merge_ids = get_bi_matching_ids(v_s, 0.2)
                # ---Mem_merge_ids.append(merge_ids)
                # print(merge_ids)

                    
                    
                mem_masks_per_obj = self.mem_info['mem_masks_stack'].setdefault(self.mem_info['obj_idx'], [])
                # mem_SP_masks_per_obj = self.mem_info['mem_SP_masks_stack'].setdefault(self.mem_info['obj_idx'], [])
                if len(mem_masks_per_obj) == 5:
                    mem_masks_per_obj.pop(0)
                    # mem_SP_masks_per_obj.pop(0)
                mem_masks_per_obj.append(Mem_ATMs)
                # mem_SP_masks_per_obj.append(Mem_SP_masks)
                # print(len(mem_masks_per_obj))
                
                # print(len(self.mem_info['mem_masks'][0]))
                    
                # print(Mem_ATMs)
            # '''
            
            # Pt_ATMs = torch.stack(Pt_ATMs,dim=0)
            # Pt_ATMs = Pt_ATMs.mean(0).reshape(64,64)
            # Pt_ATMs = F.pad(Pt_ATMs, (0, 6, 0, 6), "constant", 0) # zero-pad on the right and bottom
            # # Reshape tensor B into 14x14 windows, then compute the sum of each window
            # # Pt_ATMs = F.unfold(Pt_ATMs, kernel_size=14, stride=14).sum(dim=1).view(-1)
            # Pt_ATMs = Pt_ATMs.reshape(5,14,5,14).permute(0,2,1,3).reshape(25,-1).sum(-1)
            # # T = (Pt_ATMs.mean() + Pt_ATMs.std())
            # T=0.04
            
            # indices = torch.where(Pt_ATMs > T)[0].cpu().tolist()
            # indices = torch.where(A > (A.mean() + A.std()))[0]
            # Pt_ATMs = Pt_ATMs.
            
            # print('prompt attention focus windows (Threshold:{}):\n'.format(T),indices)
        
            
            
            
        # print(low_res_masks.shape)
        # exit()
        
        
        # Temporarily added for the training bypass
        # current_out["object_score_logits"] = object_score_logits

        # Finally run the memory encoder on the predicted mask to encode
        # it into a new memory feature (that can be used in future frames)
        torch.cuda.synchronize()
        st = time.time()
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
            point_inputs,
            run_mem_encoder,
            high_res_masks,
            object_score_logits,
            current_out,
        )
        ed = time.time()
        if frame_idx > 0:
            self.time_log[frame_idx]['Mem_E'].append(ed-st)

        return current_out


    def _use_multimask(self, is_init_cond_frame, point_inputs):
        """Decide whether the SAM head should produce multimask output.

        The result is truthy only when all three conditions hold:
          1. ``self.multimask_output_in_sam`` is enabled;
          2. this is an initial conditioning frame, or
             ``self.multimask_output_for_tracking`` is enabled;
          3. the number of prompt points lies in the inclusive range
             ``[self.multimask_min_pt_num, self.multimask_max_pt_num]``.

        Args:
            is_init_cond_frame: truthy when the current frame is an initial
                conditioning frame.
            point_inputs: ``None`` (counted as zero points) or a mapping whose
                ``"point_labels"`` entry exposes ``.size(1)`` as the point count.

        Returns:
            The first falsy condition encountered, otherwise the boolean result
            of the point-count range check.
        """
        # The point count is resolved up front, before any flag is consulted.
        if point_inputs is not None:
            point_count = point_inputs["point_labels"].size(1)
        else:
            point_count = 0

        # Guard 1: multimask output must be switched on for the SAM head at all.
        enabled_in_sam = self.multimask_output_in_sam
        if not enabled_in_sam:
            return enabled_in_sam

        # Guard 2: only on initial conditioning frames, unless tracking frames
        # are explicitly allowed to use multimask output as well.
        frame_allows = is_init_cond_frame or self.multimask_output_for_tracking
        if not frame_allows:
            return frame_allows

        # Final check: the prompt must carry an admissible number of points.
        return self.multimask_min_pt_num <= point_count <= self.multimask_max_pt_num

    def _apply_non_overlapping_constraints(self, pred_masks):
        """
        Enforce that object masks in `pred_masks` do not overlap.

        At each location, only the object (slice along dim 0) with the highest
        score keeps its value; every other object's score at that location is
        capped at -10.0. With a single object the input is returned untouched.

        Args:
            pred_masks: score tensor whose dim 0 indexes objects (the index
                tensor is broadcast as 4-D, so a 4-D input is assumed —
                TODO confirm against callers).

        Returns:
            A tensor of the same shape with losing objects' scores suppressed,
            or `pred_masks` itself when it holds exactly one object.
        """
        num_objs = pred_masks.size(0)
        if num_objs != 1:
            # Index of the top-scoring object at every location.
            winner_idx = pred_masks.argmax(dim=0, keepdim=True)
            # Per-slice object index, shaped to broadcast against the 4-D scores.
            slice_idx = torch.arange(num_objs, device=pred_masks.device).view(-1, 1, 1, 1)
            is_winner = winner_idx == slice_idx
            # Losing objects are pushed to at most -10.0 so their foreground
            # probability is negligible (sigmoid(-10.0)=4.5398e-05); scores
            # already below -10.0 are left as they are.
            suppressed = pred_masks.clamp(max=-10.0)
            pred_masks = torch.where(is_winner, pred_masks, suppressed)
        return pred_masks
