# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from torch import Tensor
from typing import Optional

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, lecun_normal_

from timm.models.layers import DropPath, to_2tuple
from timm.models.vision_transformer import _load_weights

import math

from collections import namedtuple

from mamba_ssm.modules.mamba_simple import Mamba
# from lib.models.Mamba_block.mamba_simple import Mamba 
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

from lib.models.mamba_fetrack.dynamic_mamba_simple import EventMamba

from lib.models.mamba_fetrack.rope import *
import random
from lib.models.layers.head import build_box_head
import importlib
import lib.train.admin.settings as ws_settings
from .utils import combine_tokens, recover_tokens

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn

except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
    

    
    


# Names exported by a wildcard import of this module.
# NOTE(review): the model factories listed here are not defined in the part of
# the file reviewed with this block -- confirm each one exists further down.
__all__ = [
    'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224',
    'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384',
]


class PatchEmbed(nn.Module):
    """Convolutional patch embedding: image -> grid (or sequence) of patch tokens.

    One ``nn.Conv2d`` with kernel ``patch_size`` and step ``stride`` projects
    every patch to ``embed_dim`` channels.  With ``flatten=True`` the feature
    map is flattened over its spatial axes and transposed so that the channel
    axis comes last; otherwise the convolution output is returned as is.

    Args:
        img_size: Nominal input size (int or pair); only used to derive
            ``grid_size`` / ``num_patches``.
        patch_size: Convolution kernel size (int or pair).
        stride: Convolution stride.
        in_chans: Number of input channels.
        embed_dim: Number of output channels per patch.
        norm_layer: Optional callable ``norm_layer(embed_dim)`` building the
            normalisation applied to the output; ``nn.Identity`` when falsy.
        flatten: Whether to return a token sequence instead of a feature map.
    """

    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        # Number of kernel positions along each axis (the conv uses no padding).
        rows = (self.img_size[0] - self.patch_size[0]) // stride + 1
        cols = (self.img_size[1] - self.patch_size[1]) // stride + 1
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=stride)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        """Embed ``x``; the input resolution is not checked against ``img_size``."""
        # The four-way unpack is kept only for its side effect: it raises when
        # the input does not have exactly four dimensions.
        _batch, _chans, _height, _width = x.shape
        tokens = self.proj(x)
        if self.flatten:
            # (B, C, H', W') -> (B, H'*W', C)
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
    

class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,drop_path=0.,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: Feature dimension handed to ``mixer_cls`` and ``norm_cls``.
            mixer_cls: Callable ``mixer_cls(dim)`` building the sequence mixer.
            norm_cls: Callable ``norm_cls(dim)`` building the pre-mixer norm.
            fused_add_norm: Use the fused add+norm kernels (needs the
                ``mamba_ssm`` triton layer-norm import to have succeeded).
            residual_in_fp32: Keep the residual stream in float32.
            drop_path: Stochastic-depth rate applied to the incoming branch.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, current_A: Optional[Tensor] = None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: running residual stream; ``None`` for the very first block.
                The mixer input is ``LN(residual + drop_path(hidden_states))``.
            inference_params: parameters for inference.
            current_A: Optional dynamic state matrix for the mixer; only
                forwarded when the mixer is a ``DynamicMamba``.

        Returns:
            ``(mixer_output, residual)`` where ``residual`` is the updated
            residual stream to feed to the next block.
        """
        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)

            hidden_states_norm = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            # Let the fused kernel perform the residual add itself.  On the
            # first block (no residual yet) the raw input is normalised without
            # DropPath, exactly like the non-fused branch above; dropping the
            # first input would zero whole samples of the embeddings.
            branch = hidden_states if residual is None else self.drop_path(hidden_states)
            hidden_states_norm, residual = fused_add_norm_fn(
                branch,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,  # also return the updated residual stream
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )

        # Only DynamicMamba mixers understand `current_A`; every other mixer
        # (or a missing matrix) takes the plain call.
        if current_A is not None and isinstance(self.mixer, DynamicMamba):
            hidden_states_out = self.mixer(hidden_states_norm, current_A=current_A, inference_params=inference_params)
        else:
            hidden_states_out = self.mixer(hidden_states_norm, inference_params=inference_params)

        return hidden_states_out, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

# Mamba extension that adds support for a dynamic state-transition matrix
class DynamicMamba(EventMamba):
    """EventMamba variant driven by a caller-supplied state-transition matrix.

    The matrix is not computed here: whoever calls ``forward`` passes the
    already-computed ``current_A``, which is remembered on the instance and
    handed straight to ``EventMamba.forward``.
    """

    def __init__(self, d_model, A_log=None, **kwargs):
        """Build the mixer.

        Args:
            d_model: Model width, forwarded to ``EventMamba``.
            A_log: Per-layer ``A_log`` tensor/parameter, forwarded to
                ``EventMamba`` and also stored on this instance.
            **kwargs: Remaining ``EventMamba`` constructor arguments.
        """
        super().__init__(d_model, A_log=A_log, **kwargs)

        # NOTE(review): this assignment is unconditional, so with A_log=None it
        # would replace whatever ``A_log`` the parent constructor may have set
        # with None -- confirm against EventMamba.
        self.A_log = A_log
        # Last matrix received by forward(); None until the first call.
        self.current_A = None

        # Prefer the sizes the parent exposes; otherwise fall back to the
        # constructor arguments (with defaults d_state=16, expand=2).
        fallback_d_state = kwargs.get('d_state', 16)
        fallback_d_inner = kwargs.get('d_inner', int(kwargs.get('expand', 2) * d_model))
        self._d_state = getattr(self, 'd_state', fallback_d_state)
        self._d_inner = getattr(self, 'd_inner', fallback_d_inner)

    def forward(
        self, hidden_states: Tensor, current_A: Optional[Tensor] = None, inference_params=None
    ):
        """Run the parent mixer with an externally computed transition matrix.

        Args:
            hidden_states: Input tensor.
            current_A: Pre-computed state-transition matrix, or ``None``.
                (Presumably shaped ``(B, d_inner, d_state)`` or
                ``(d_inner, d_state)`` -- confirm against EventMamba.)
            inference_params: Inference parameters, passed through unchanged.

        Returns:
            Whatever ``EventMamba.forward`` returns.
        """
        # Remember the matrix so get_current_A() reports what was actually used.
        self.current_A = current_A
        return super().forward(hidden_states, current_A, inference_params)

    def get_current_A(self):
        """Return the state-transition matrix passed to the latest ``forward`` call."""
        return self.current_A

def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    if_bimamba=False,
    bimamba_type="none",
    if_devide_out=False,
    init_layer_scale=None,
    use_dynamic_mixer=False,
    A_log=None,
):
    """Assemble one ``Block`` (norm + mixer) for the backbone.

    Args:
        d_model: Feature width of the block.
        ssm_cfg: Extra keyword arguments for the mixer constructor.
        norm_epsilon: ``eps`` of the normalisation layer.
        drop_path: Stochastic-depth rate of the block.
        rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32: Forwarded to ``Block``.
        fused_add_norm: Forwarded to ``Block``.
        layer_idx: Index stored on the block and passed to the mixer.
        device, dtype: Factory arguments for the mixer and the norm.
        if_bimamba: When true, forces ``bimamba_type`` to ``"v1"``.
        bimamba_type, if_devide_out, init_layer_scale: Passed to the mixer.
        use_dynamic_mixer: Build a ``DynamicMamba`` mixer instead of ``Mamba``.
        A_log: ``A_log`` handed to ``DynamicMamba`` (ignored for ``Mamba``).

    Returns:
        The constructed ``Block`` with ``layer_idx`` attached.
    """
    if if_bimamba:
        bimamba_type = "v1"
    ssm_cfg = {} if ssm_cfg is None else ssm_cfg
    factory_kwargs = {"device": device, "dtype": dtype}

    # Arguments common to both mixer flavours.  Building them with dict(...)
    # keeps the TypeError on keys that collide with ssm_cfg.
    mixer_kwargs = dict(
        layer_idx=layer_idx,
        bimamba_type=bimamba_type,
        if_devide_out=if_devide_out,
        init_layer_scale=init_layer_scale,
        **ssm_cfg,
        **factory_kwargs,
    )
    if use_dynamic_mixer:
        mixer_cls = partial(DynamicMamba, A_log=A_log, **mixer_kwargs)
    else:
        mixer_cls = partial(Mamba, **mixer_kwargs)

    norm_type = RMSNorm if rms_norm else nn.LayerNorm
    norm_cls = partial(norm_type, eps=norm_epsilon, **factory_kwargs)

    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


def segm_init_weights(m):
    """Per-module initialiser meant to be used with ``nn.Module.apply``.

    * ``nn.Linear``: truncated-normal weight (std 0.02), zero bias.
    * ``nn.Conv2d``: LeCun-normal weight, zero bias.
    * ``nn.LayerNorm`` / ``nn.GroupNorm`` / ``nn.BatchNorm2d``: unit weight,
      zero bias.

    Parameters that do not exist (``bias=False`` layers, norm layers built
    without affine parameters) are skipped instead of raising.  Any other
    module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        # Norm layers created with affine/elementwise_affine=False (or
        # bias=False) expose None here; nn.init would fail on it.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
        if m.weight is not None:
            nn.init.ones_(m.weight)


class VisionMamba(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 stride=16,
                 depth=24,
                 embed_dim=192,
                 channels=3,
                 num_classes=1000,
                 ssm_cfg=None,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_epsilon: float = 1e-5,
                 rms_norm: bool = False,
                 initializer_cfg=None,
                 fused_add_norm=False,
                 residual_in_fp32=False,
                 device=None,
                 dtype=None,
                 ft_seq_len=None,
                 pt_hw_seq_len=14,
                 if_bidirectional=False,
                 final_pool_type='none',
                 if_abs_pos_embed=False,
                 if_rope=False,
                 if_rope_residual=False,
                 flip_img_sequences_ratio=-1.,
                 if_bimamba=False,
                 bimamba_type="none",
                 if_cls_token=False,
                 if_devide_out=False,
                 init_layer_scale=None,
                 use_double_cls_token=False,
                 use_middle_cls_token=False,
                 use_dynamic_a=True,
                 **kwargs):
        """Build the backbone: patch embedding, a stack of standard Mamba
        blocks (``self.layers``) and, when ``use_dynamic_a`` is set, a second
        stack of ``DynamicMamba`` blocks (``self.layers_dynamic``) together
        with the parameters used for the dynamic state matrix.

        Args:
            img_size, patch_size, stride, channels, embed_dim: Patch-embedding
                geometry and width.
            depth: Number of blocks in each stack.
            num_classes: Stored on the instance; no classifier head is built
                in this constructor.
            ssm_cfg: Extra keyword arguments for every mixer (``d_state`` and
                ``expand`` are also read here to size the dynamic parameters).
            drop_rate: Dropout applied after adding position embeddings.
            drop_path_rate: Maximum stochastic-depth rate.  The standard stack
                ramps linearly from 0 to this value; every dynamic block uses
                this value directly.
            norm_epsilon, rms_norm, fused_add_norm, residual_in_fp32: Norm /
                residual options forwarded to ``create_block``.
            initializer_cfg: Extra keyword arguments for ``_init_weights``.
            device, dtype: Factory arguments for blocks and the final norm.
            ft_seq_len: Accepted but not used by this constructor.
            pt_hw_seq_len: ``pt_seq_len`` of the rotary embedding.
            if_bidirectional, final_pool_type, if_abs_pos_embed, if_rope,
            if_rope_residual, flip_img_sequences_ratio: Stored and consumed by
                the forward methods.
            if_bimamba, bimamba_type, if_devide_out, init_layer_scale:
                Forwarded to ``create_block``.
            if_cls_token, use_double_cls_token, use_middle_cls_token: Control
                which class-token parameters are created (see the note on
                ``self.if_cls_token`` below).
            use_dynamic_a: Create the dynamic stack and its parameters.
            **kwargs: Accepted for compatibility; not used further here.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        # add factory_kwargs into kwargs
        kwargs.update(factory_kwargs)
        super().__init__()

        # Ensure ssm_cfg is a dictionary, even if None is passed
        if ssm_cfg is None:
            ssm_cfg = {}

        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.if_bidirectional = if_bidirectional
        self.final_pool_type = final_pool_type
        self.if_abs_pos_embed = if_abs_pos_embed
        self.if_rope = if_rope
        self.if_rope_residual = if_rope_residual
        self.flip_img_sequences_ratio = flip_img_sequences_ratio
        # NOTE: the class-token path is force-disabled at run time regardless of
        # the `if_cls_token` argument; the argument still decides below whether
        # the class-token parameters are created and how `num_tokens` is set.
        self.if_cls_token = False
        self.use_double_cls_token = use_double_cls_token
        self.use_middle_cls_token = use_middle_cls_token
        self.num_tokens = 1 if if_cls_token else 0
        self.use_dynamic_a = use_dynamic_a  # whether the dynamic path is available

        # Dynamic parameters and the dynamic layer stack exist only when
        # use_dynamic_a is True.
        if self.use_dynamic_a:
            self.alpha = nn.Parameter(torch.tensor(0.5))
            d_state = ssm_cfg.get('d_state', 16)
            # `expand` may be fractional; tensor sizes must be integers.
            d_inner = int(embed_dim * ssm_cfg.get('expand', 2))
            # W_rho determines how density rho modulates the original A
            self.W_rho = nn.Parameter(torch.randn(d_inner, d_state) * 0.02)
            self.d_state = d_state
            self.d_inner = d_inner
            self.register_buffer('A_prev', None)

            # One A_log parameter per dynamic layer; each is handed to the
            # matching DynamicMamba mixer when the layer is created.
            self.A_logs_dynamic = nn.ParameterList()
            for _ in range(depth):
                # Every row holds log(1), log(2), ..., log(d_state).
                A_log_param = nn.Parameter(
                    torch.log(
                        torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
                    )
                )
                A_log_param._no_weight_decay = True
                self.A_logs_dynamic.append(A_log_param)

            # Dynamic Mamba layers.
            self.layers_dynamic = nn.ModuleList(
                [
                    create_block(
                        embed_dim,
                        ssm_cfg=ssm_cfg,
                        norm_epsilon=norm_epsilon,
                        rms_norm=rms_norm,
                        residual_in_fp32=residual_in_fp32,
                        fused_add_norm=fused_add_norm,
                        layer_idx=i,
                        A_log=self.A_logs_dynamic[i],  # layer-specific A_log
                        if_bimamba=if_bimamba,
                        bimamba_type=bimamba_type,
                        drop_path=drop_path_rate,  # constant rate (no per-layer ramp)
                        if_devide_out=if_devide_out,
                        init_layer_scale=init_layer_scale,
                        use_dynamic_mixer=True,
                        **factory_kwargs,
                    )
                    for i in range(depth)
                ]
            )
        else:
            # Without dynamic A there is no dynamic layer stack.
            self.layers_dynamic = None


        # pretrain parameters
        self.num_classes = num_classes
        self.d_model = self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        if if_cls_token:
            if use_double_cls_token:
                self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.num_tokens = 2
            else:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        if if_abs_pos_embed:
            # Separate tables for the search (x) and template (z) streams with
            # hard-coded lengths 256 and 64; forward_features interpolates a
            # table when the actual token count differs.
            self.pos_embed_x = nn.Parameter(torch.zeros(1, 256, self.embed_dim))
            self.pos_embed_z = nn.Parameter(torch.zeros(1, 64, self.embed_dim))
            self.pos_drop = nn.Dropout(p=drop_rate)

        if if_rope:
            half_head_dim = embed_dim // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len
            )

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        # Standard Mamba layers.
        self.layers = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    if_bimamba=if_bimamba,
                    bimamba_type=bimamba_type,
                    drop_path=dpr[i],  # per-layer rate from the linear ramp
                    if_devide_out=if_devide_out,
                    init_layer_scale=init_layer_scale,
                    use_dynamic_mixer=False,
                    **factory_kwargs,
                )
                for i in range(depth)
            ]
        )

        # output head
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            embed_dim, eps=norm_epsilon, **factory_kwargs
        )

        self.patch_embed.apply(segm_init_weights)
        if if_abs_pos_embed:
            trunc_normal_(self.pos_embed_x, std=.02)
            trunc_normal_(self.pos_embed_z, std=.02)

        if if_cls_token:
            if use_double_cls_token:
                trunc_normal_(self.cls_token_head, std=.02)
                trunc_normal_(self.cls_token_tail, std=.02)
            else:
                trunc_normal_(self.cls_token, std=.02)

        # mamba init -- applied to every sub-module (both layer stacks)
        self.apply(
            partial(
                _init_weights,
                n_layer=depth,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # 为两个层列表都分配缓存
        cache = {
            f"layer_{i}": layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }
        if self.layers_dynamic is not None:
             cache.update({
                 f"layer_dynamic_{i}": layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
                 for i, layer in enumerate(self.layers_dynamic)
             })
        return cache

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights into this model by delegating to ``_load_weights``.

        Args:
            checkpoint_path: Passed through unchanged as the checkpoint location.
            prefix: Passed through unchanged as the third argument.

        NOTE(review): ``_load_weights`` is imported from
        ``timm.models.vision_transformer``; it presumably expects a
        VisionTransformer parameter layout -- confirm it is compatible with
        this model before relying on it.
        """
        _load_weights(self, checkpoint_path, prefix)

    def forward_features(self, z, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        x = self.patch_embed(x)                  #x.shape = torch.Size([B, 3, 256, 256])  -> torch.Size([2, 256, 384])
        z = self.patch_embed(z)                  #z.shape = torch.Size([B, 3, 128, 128])  -> torch.Size([2, 64, 384])
        B, M, _ = x.shape
       
        if self.if_cls_token:                 # False
            if self.use_double_cls_token:
                cls_token_head = self.cls_token_head.expand(B, -1, -1)
                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                token_position = [0, M + 1]
                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
                M = x.shape[1]
            else:
                if self.use_middle_cls_token:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = M // 2
                    # add cls token in the middle
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)       
                elif if_random_cls_token_position:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = random.randint(0, M)
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                    print("token_position: ", token_position)
                else:
                    cls_token = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
                    token_position = 0
                    x = torch.cat((cls_token, x), dim=1)
                M = x.shape[1]                 
       
        if self.if_abs_pos_embed:                  # True 
            # 根据实际 patch 数量调整 pos_embed 大小 (如果需要)
            _, N_search, _ = x.shape
            _, N_template, _ = z.shape
            if self.pos_embed_x.shape[1] != N_search:
                 pos_embed_x_adj = nn.functional.interpolate(self.pos_embed_x.permute(0, 2, 1), size=N_search, mode='linear', align_corners=False).permute(0, 2, 1)
            else:
                 pos_embed_x_adj = self.pos_embed_x
            if self.pos_embed_z.shape[1] != N_template:
                 pos_embed_z_adj = nn.functional.interpolate(self.pos_embed_z.permute(0, 2, 1), size=N_template, mode='linear', align_corners=False).permute(0, 2, 1)
            else:
                 pos_embed_z_adj = self.pos_embed_z
            
            x = x + pos_embed_x_adj               # x = x + positon_embemding
            z = z + pos_embed_z_adj               # z = z + positon_embemding
            hidden_states = torch.cat((z, x), dim=1)           # torch.Size([B, N_template + N_search, D])
            hidden_states = self.pos_drop(hidden_states)
        else:
            hidden_states = torch.cat((z, x), dim=1) # torch.Size([B, N_template + N_search, D])
            
        if if_random_token_rank:                   #False
            # ... (原始随机排序逻辑) ...
            pass

        if_flip_img_sequences = False
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:        # False
            hidden_states = hidden_states.flip([1])
            if_flip_img_sequences = True

        # mamba impl - 使用 self.layers (标准 Mamba)
        residual = None
        # hidden_states = x # hidden_states 已在上面定义并拼接了 z 和 x
        if not self.if_bidirectional:                                 # True
            # 使用 self.layers
            for layer in self.layers:
                # ... (处理 flip 和 rope 的逻辑保持不变) ...
                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params
                )
        
        else:             # False (双向逻辑)
            # 确保也使用 self.layers
            for i in range(len(self.layers) // 2):
                # ... (处理 rope 的逻辑保持不变) ...
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)
                
                # 使用 self.layers 中的层
                hidden_states_f, residual_f = self.layers[i * 2](
                    hidden_states, residual, inference_params=inference_params
                )
                hidden_states_b, residual_b = self.layers[i * 2 + 1](
                    hidden_states.flip([1]), None if residual is None else residual.flip([1]), inference_params=inference_params
                )
                hidden_states = hidden_states_f + hidden_states_b.flip([1])
                residual = residual_f + residual_b.flip([1])
      
        # ... (后续的 norm 和 pooling 逻辑保持不变) ...
        if not self.fused_add_norm:         #False
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:       #True
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(                                         # hidden_states.shape = torch.Size([B, 320, 384])
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps, residual=residual, prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        # return only cls token if it exists
        if self.if_cls_token:          #False
           # ... (cls token 返回逻辑)
           pass

        if self.final_pool_type == 'none':
            # 假设返回最后一个 token，但需要考虑 z 和 x 拼接后的情况
            # 如果总是返回最后一个，它将始终是 x 的最后一个 token
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':         #True
            return hidden_states.mean(dim=1)         #hidden_states.shape = torch.Size([2, 384])
        elif self.final_pool_type == 'max':
            # 如果需要返回 max pooling 的结果
            # return torch.max(hidden_states, dim=1)[0]
            return hidden_states
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError

    def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
        if return_features:
            return x
        x = self.head(x)
        if self.final_pool_type == 'max':
            x = x.max(dim=1)[0]
        return x

    def forward_features_event(self, z, x, z_density, x_density, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        """
        处理事件数据的前向传播函数，使用动态状态转移矩阵。
        事件密度 rho_t 根据输入的 z_density 和 x_density 动态确定。
        
        此函数现在明确使用 self.layers_dynamic (包含 DynamicMamba 块)。
        如果模型初始化时 use_dynamic_a=False，则此函数会抛出错误。
        """
        if not self.use_dynamic_a or self.layers_dynamic is None:
            raise RuntimeError("forward_features_event requires the model to be initialized with use_dynamic_a=True.")

        # Determine model device early
        model_device = self.patch_embed.proj.weight.device

        # ========= Ensure inputs z and x are Tensors on correct device before PatchEmbed =========
        if not isinstance(x, torch.Tensor):
            try:
                x = torch.tensor(x, dtype=torch.float32)
            except Exception as e:
                 raise TypeError(f"Input event search image 'x' failed conversion to Tensor. Error: {e}")
        if not isinstance(z, torch.Tensor):
            try:
                 z = torch.tensor(z, dtype=torch.float32)
            except Exception as e:
                 raise TypeError(f"Input event template image 'z' failed conversion to Tensor. Error: {e}")

        x = x.to(model_device)
        z = z.to(model_device)
        # ================================================================================

        # ========= Patch Embedding =========
        x_embed = self.patch_embed(x)                  # x.shape = torch.Size([B, 3, 256, 256])  -> torch.Size([B, N_search, D])
        z_embed = self.patch_embed(z)                  # z.shape = torch.Size([B, 3, 128, 128])  -> torch.Size([B, N_template, D])
        B, N_search, D = x_embed.shape
        _, N_template, _ = z_embed.shape

        # ========= Calculate weighted event density rho_t (Vector per image) =========
        # --- Validate densities ---
        if z_density is None or x_density is None:
            raise ValueError("Both z_density and x_density must be provided for forward_features_event.")
        if not isinstance(z_density, torch.Tensor):
            try:
                z_density = torch.tensor(z_density, dtype=torch.float32)
            except Exception as e:
                raise TypeError(f"Input event density 'z_density' must be a Tensor or convertible to Tensor. Error: {e}")
        if not isinstance(x_density, torch.Tensor):
            try:
                x_density = torch.tensor(x_density, dtype=torch.float32)
            except Exception as e:
                raise TypeError(f"Input event density 'x_density' must be a Tensor or convertible to Tensor. Error: {e}")

        z_density = z_density.to(model_device)
        x_density = x_density.to(model_device)

        # --- Check batch size consistency ---
        B_density = x_density.shape[0]
        if z_density.shape[0] != B_density:
            raise ValueError(f"Batch size mismatch between z_density ({z_density.shape[0]}) and x_density ({B_density})")
        if B != B_density:
            raise ValueError(f"Batch size mismatch between input images ({B}) and density vectors ({B_density})")
        if z_density.ndim != 1 or x_density.ndim != 1:
             raise ValueError(f"Expected z_density and x_density to be 1D tensors (shape [B]), but got shapes {z_density.shape} and {x_density.shape}")

        # --- Calculate weights based on sequence lengths ---
        N_total_seq = N_template + N_search
        if N_total_seq > 0:
            weight_z = N_template / N_total_seq
            weight_x = N_search / N_total_seq
        else:
            weight_z = 0.5
            weight_x = 0.5

        # --- Calculate weighted average density ---
        rho_t_vector = z_density * weight_z + x_density * weight_x
        rho_t_vector = rho_t_vector.float()
        # =====================================================================================

        # ========= Prepare inputs for Mamba layers (CLS token, Pos Embedding, etc.) =====
        # --- CLS Token (if enabled, but currently hardcoded to False) ---
        if self.if_cls_token:                 # False
            # ... (original cls token logic, needs review if enabled) ...
            pass

        # --- Absolute Positional Embedding --- 
        if self.if_abs_pos_embed:                  # True
            # Adjust pos embed size dynamically
            if self.pos_embed_x.shape[1] != N_search:
                 pos_embed_x_adj = nn.functional.interpolate(self.pos_embed_x.permute(0, 2, 1), size=N_search, mode='linear', align_corners=False).permute(0, 2, 1)
            else:
                 pos_embed_x_adj = self.pos_embed_x
            if self.pos_embed_z.shape[1] != N_template:
                 pos_embed_z_adj = nn.functional.interpolate(self.pos_embed_z.permute(0, 2, 1), size=N_template, mode='linear', align_corners=False).permute(0, 2, 1)
            else:
                 pos_embed_z_adj = self.pos_embed_z

            x_final = x_embed + pos_embed_x_adj
            z_final = z_embed + pos_embed_z_adj

            hidden_states = torch.cat((z_final, x_final), dim=1)
            hidden_states = self.pos_drop(hidden_states)
        else:
            hidden_states = torch.cat((z_embed, x_embed), dim=1)

        # N_total = hidden_states.shape[1] # Total sequence length after concat

        # --- Random Token Rank --- (if enabled)
        if if_random_token_rank:                   #False
            # ... (original random rank logic) ...
            pass

        # --- Flip Sequence --- (if enabled)
        if_flip_img_sequences = False
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:        # False
            hidden_states = hidden_states.flip([1])
            if_flip_img_sequences = True
        # =====================================================================================

        # --- Dynamic State Transition Calculation --- 
        residual = None
        # hidden_states is prepared above
        
        # Prepare parameters needed for dynamic A calculation inside the loop
        alpha = torch.sigmoid(self.alpha) # Shape []
        W_rho_device = self.W_rho.to(model_device) # Shape [d_inner, d_state]
        # Precompute rho modulation factor (can be reused across layers)
        # rho_t_vector shape [B]
        # W_rho_device shape [d_inner, d_state]
        # modulation shape should be [B, d_inner, d_state]
        modulation_factor = torch.sigmoid(W_rho_device.unsqueeze(0) * rho_t_vector.view(-1, 1, 1)) # Check broadcasting
        A_prev = self.A_prev # Shape [B, d_inner, d_state] or None, will be updated in loop

        # ======================================================================================
        
        # --- Mamba Layers (using self.layers_dynamic) --- 
        if not self.if_bidirectional: # True
            # Iterate through the DYNAMIC layers
            for layer_idx, layer in enumerate(self.layers_dynamic):
                # --- Get original A for this layer --- 
                if not hasattr(layer.mixer, 'A_log') or layer.mixer.A_log is None:
                     raise AttributeError(f"DynamicMamba layer {layer_idx} does not have A_log attribute.")
                A_log_layer = layer.mixer.A_log.to(model_device)
                A_orig_layer = -torch.exp(A_log_layer.float()) # Shape [d_inner, d_state]

                # --- Calculate current_A for this layer --- 
                # modulation_factor shape [B, d_inner, d_state]
                # A_orig_layer shape [d_inner, d_state]
                modulated_A_orig = modulation_factor * A_orig_layer.unsqueeze(0) # Shape [B, d_inner, d_state]

                if A_prev is None:
                    current_A_layer = alpha * modulated_A_orig
                else:
                    current_A_layer = alpha * modulated_A_orig + (1 - alpha) * A_prev.to(modulated_A_orig.dtype)
                # current_A_layer shape [B, d_inner, d_state]

                # --- Forward pass through the layer --- 
                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                         residual = residual.flip([1])
                
                if self.if_rope: # Apply rope if needed
                     hidden_states = self.rope(hidden_states)
                     if residual is not None and self.if_rope_residual:
                         residual = self.rope(residual)
                 
                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                         residual = residual.flip([1])

                hidden_states, residual = layer(
                    hidden_states, 
                    residual, 
                    inference_params=inference_params,
                    current_A=current_A_layer # Pass the calculated 3D A for this layer
                )
 
                # --- Update A_prev for the next layer --- 
                # We use the calculated current_A_layer to update A_prev for the next iteration.
                if current_A_layer is not None: 
                    # Detach and update A_prev on the model itself (for next iteration/batch)
                    # and in dynamic_params (for the next layer in this forward pass)
                    detached_A = current_A_layer.detach()
                    A_prev = detached_A # Update A_prev for the next loop iteration
            # Store the final A_prev from the last layer back to the buffer
            self.A_prev = A_prev.detach() if A_prev is not None else None

        else: # False (Bidirectional Logic)
             # Ensure bidirectional logic also uses self.layers_dynamic and handles dynamic_params
            for i in range(len(self.layers_dynamic) // 2):
                 # ... (rope logic) ...
                 if self.if_rope:
                     hidden_states = self.rope(hidden_states)
                     if residual is not None and self.if_rope_residual:
                         residual = self.rope(residual)
                 
                 # --- Get original A for forward and backward layers --- 
                 idx_f, idx_b = i * 2, i * 2 + 1
                 if not hasattr(self.layers_dynamic[idx_f].mixer, 'A_log') or self.layers_dynamic[idx_f].mixer.A_log is None or \
                    not hasattr(self.layers_dynamic[idx_b].mixer, 'A_log') or self.layers_dynamic[idx_b].mixer.A_log is None:
                     raise AttributeError(f"DynamicMamba bi-layers {idx_f},{idx_b} do not have A_log attribute.")
                 A_log_f = self.layers_dynamic[idx_f].mixer.A_log.to(model_device)
                 A_log_b = self.layers_dynamic[idx_b].mixer.A_log.to(model_device)
                 A_orig_f = -torch.exp(A_log_f.float()) # Shape [d_inner, d_state]
                 A_orig_b = -torch.exp(A_log_b.float()) # Shape [d_inner, d_state]

                 # --- Calculate current_A for this bi-layer --- 
                 # modulation_factor shape [B, d_inner, d_state]
                 modulated_A_orig_f = modulation_factor * A_orig_f.unsqueeze(0) # Shape [B, d_inner, d_state]
                 # Assume same modulation applies to backward pass A? Or should backward have its own W_rho?
                 modulated_A_orig_b = modulation_factor * A_orig_b.unsqueeze(0)

                 if A_prev is None:
                     current_A_f = alpha * modulated_A_orig_f
                     current_A_b = alpha * modulated_A_orig_b
                 else:
                     current_A_f = alpha * modulated_A_orig_f + (1 - alpha) * A_prev.to(modulated_A_orig_f.dtype)
                     current_A_b = alpha * modulated_A_orig_b + (1 - alpha) * A_prev.to(modulated_A_orig_b.dtype)
                 # current_A_f, current_A_b shape [B, d_inner, d_state]

                 # --- Forward Pass --- 
                 hidden_states_f, residual_f = self.layers_dynamic[i * 2](
                     hidden_states, residual, inference_params=inference_params, current_A=current_A_f
                 )
                 
                 # --- Backward Pass --- 
                 hidden_states_b, residual_b = self.layers_dynamic[i * 2 + 1](
                     hidden_states.flip([1]), 
                     None if residual is None else residual.flip([1]), 
                     inference_params=inference_params,
                     current_A=current_A_b # Pass the calculated A for backward layer
                 )
                 
                 # Combine outputs
                 hidden_states = hidden_states_f + hidden_states_b.flip([1])
                 residual = residual_f + residual_b.flip([1])
                 
                 # --- Update A_prev for the next layer --- 
                 # How to combine A_f and A_b for the next A_prev?
                 # Option 1: Use A_f (forward pass dominates?)
                 # Option 2: Average A_f and A_b?
                 # Let's use A_f for now, similar to original unidirectional logic.
                 if current_A_f is not None:
                     detached_A = current_A_f.detach()
                     A_prev = detached_A # Update A_prev for the next loop iteration
            # Store the final A_prev from the last layer back to the buffer
            self.A_prev = A_prev.detach() if A_prev is not None else None

        # =======================================================================================

        # --- Final Normalization (using original logic) --- 
        if not self.fused_add_norm: #False
            if residual is None: residual = hidden_states
            else: residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else: #True
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight, self.norm_f.bias,
                eps=self.norm_f.eps, residual=residual, prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        # --- CLS Token Return (using original logic) --- 
        if self.if_cls_token: #False
            # ... (original cls token return logic) ...
            pass

        # --- Pooling and Final Return (using original logic) --- 
        if self.final_pool_type == 'none':
             # Consider if the last token is always the desired output after concat
             return hidden_states[:, -1, :] if hidden_states.shape[1] > 0 else hidden_states
        elif self.final_pool_type == 'mean': # True
             return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
             # return torch.max(hidden_states, dim=1)[0]
             return hidden_states
        elif self.final_pool_type == 'all':
             return hidden_states
        else:
             raise NotImplementedError(f"final_pool_type '{self.final_pool_type}' not supported")


@register_model
def vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the "tiny" VisionMamba preset (embed_dim=192, depth=24, patch_size=16).

    The preset enables RMS norm, fp32 residuals, fused add+norm, absolute
    positional embeddings, a CLS token placed via ``use_middle_cls_token``,
    ``bimamba_type="v2"`` and ``final_pool_type='mean'``; RoPE is disabled.

    Args:
        pretrained: When truthy, a checkpoint is fetched with
            ``torch.hub.load_state_dict_from_url`` and its ``"model"`` entry
            is loaded into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.
            Passing a key that the preset already fixes raises ``TypeError``.

    Returns:
        The constructed ``VisionMamba`` with ``default_cfg`` set to ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        embed_dim=192,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()

    if not pretrained:
        return model

    # NOTE(review): "to.do" looks like a placeholder URL, so pretrained=True
    # presumably cannot succeed until a real checkpoint location is filled in.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model

@register_model
def vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the "tiny" VisionMamba preset with overlapping patches (patch_size=16, stride=8).

    Uses embed_dim=192 and depth=24 with RMS norm, fp32 residuals, fused
    add+norm, absolute positional embeddings, a CLS token placed via
    ``use_middle_cls_token``, ``bimamba_type="v2"`` and
    ``final_pool_type='mean'``; RoPE is disabled.

    Args:
        pretrained: When truthy, a checkpoint is fetched with
            ``torch.hub.load_state_dict_from_url`` and its ``"model"`` entry
            is loaded into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.
            Passing a key that the preset already fixes raises ``TypeError``.

    Returns:
        The constructed ``VisionMamba`` with ``default_cfg`` set to ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        stride=8,
        embed_dim=192,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()

    if not pretrained:
        return model

    # NOTE(review): "to.do" looks like a placeholder URL, so pretrained=True
    # presumably cannot succeed until a real checkpoint location is filled in.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model

@register_model
def vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the "small" VisionMamba preset (embed_dim=384, depth=24, patch_size=16).

    The preset enables RMS norm, fp32 residuals, fused add+norm, absolute
    positional embeddings, a CLS token placed via ``use_middle_cls_token``
    and ``bimamba_type="v2"``; RoPE is disabled.

    NOTE: despite "final_pool_mean" in the function name, this preset passes
    ``final_pool_type='all'``. That is kept unchanged because callers may
    depend on it.

    Args:
        pretrained: Falsy to skip weight loading. Otherwise a checkpoint
            location (``str`` or ``os.PathLike``) handed to ``torch.load``.
            Unlike the sibling factories, this is a path rather than a
            boolean flag.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed ``VisionMamba`` with ``default_cfg`` set to ``_cfg()``.
    """
    model = VisionMamba(
        patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='all', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        # SECURITY: torch.load unpickles the file, so only load checkpoints
        # obtained from a trusted source.
        checkpoint = torch.load(pretrained, map_location="cpu")
        # strict=False tolerates architecture differences between the
        # checkpoint and this model; the mismatches are reported below
        # instead of being silently dropped.
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=False)
        # f-string rather than `str + pretrained`, so path-like objects
        # (which torch.load accepts) do not raise TypeError here.
        print(f'Load pretrained model from: {pretrained}')
        if missing_keys:
            print(f'Missing keys ({len(missing_keys)}) not found in checkpoint: {missing_keys}')
        if unexpected_keys:
            print(f'Unexpected keys ({len(unexpected_keys)}) ignored from checkpoint: {unexpected_keys}')
    return model

@register_model
def vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the "small" VisionMamba preset with overlapping patches (patch_size=16, stride=8).

    Uses embed_dim=384 and depth=24 with RMS norm, fp32 residuals, fused
    add+norm, absolute positional embeddings, a CLS token placed via
    ``use_middle_cls_token``, ``bimamba_type="v2"`` and
    ``final_pool_type='mean'``; RoPE is disabled.

    Args:
        pretrained: When truthy, a checkpoint is fetched with
            ``torch.hub.load_state_dict_from_url`` and its ``"model"`` entry
            is loaded into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.
            Passing a key that the preset already fixes raises ``TypeError``.

    Returns:
        The constructed ``VisionMamba`` with ``default_cfg`` set to ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        stride=8,
        embed_dim=384,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()

    if not pretrained:
        return model

    # NOTE(review): "to.do" looks like a placeholder URL, so pretrained=True
    # presumably cannot succeed until a real checkpoint location is filled in.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model


