#!/usr/bin/env python
#-*- coding:utf-8 _*-
import math
import numpy as np
import torch
import torch.nn as nn
import dgl
from einops import repeat, rearrange
from torch.nn import functional as F
from torch.nn import GELU, ReLU, Tanh, Sigmoid
from torch.nn.utils.rnn import pad_sequence

from utils import MultipleTensors
from models.mlp import MLP
from utils import MultipleTensors
from models.mlp import MLP
# 我们需要从cgpt.py中导入一些我们已经完善的模块

def pack_tokens(x, active_idx):
    """Select, per batch row, the tokens of ``x`` listed in ``active_idx``.

    Args:
        x: tensor of shape (B, N, D).
        active_idx: integer tensor of shape (B, K) with positions along dim 1.

    Returns:
        Tensor of shape (B, K, D) with ``out[b, k] = x[b, active_idx[b, k]]``.
    """
    feature_dim = x.size(-1)
    # Broadcast each index across the feature axis so gather copies whole rows.
    gather_index = active_idx[..., None].expand(-1, -1, feature_dim)
    return x.gather(1, gather_index)


class DepthRouter(nn.Module):
    """Router that scores every candidate computation depth for each token.

    A single linear projection maps each ``d_model``-wide token embedding to
    ``max_depth`` unnormalised scores (logits), one per depth.
    """

    def __init__(self, d_model, max_depth):
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=max_depth)

    def forward(self, x):
        """Return depth logits: the trailing ``d_model`` axis of ``x`` becomes ``max_depth``."""
        depth_logits = self.proj(x)
        return depth_logits

class LightweightRouter(nn.Module):
    """Small two-layer MLP assigning one scalar score to every token.

    Architecture: Linear(d_model -> d_hidden) -> GELU -> Linear(d_hidden -> 1),
    where ``d_hidden = int(d_model * d_hidden_ratio)``, clamped to at least 1.
    """
    def __init__(self, d_model, d_hidden_ratio=0.25):
        super().__init__()
        # Clamp to >= 1. For small d_model (< 4 with the default ratio), int()
        # truncates to 0. The hidden layer would then be zero-width and proj2
        # would output only its bias, making the router ignore x entirely.
        d_hidden = max(1, int(d_model * d_hidden_ratio))
        self.proj1 = nn.Linear(d_model, d_hidden)
        self.proj2 = nn.Linear(d_hidden, 1)

    def forward(self, x):
        """Return scores of shape ``(*x.shape[:-1], 1)``."""
        return self.proj2(F.gelu(self.proj1(x)))


class MoEGPTConfig():
    """Plain hyper-parameter container for the MoE GPT-style blocks.

    Every argument is stored unchanged as an attribute of the same name, with
    one exception: ``n_inner`` is a width multiplier and is stored as the
    absolute hidden width ``n_inner * n_embd``.
    """
    def __init__(self, attn_type='linear', embd_pdrop=0.0, resid_pdrop=0.0,
                 attn_pdrop=0.0, n_embd=128, n_head=1, n_layer=3,
                 block_size=128, n_inner=4, act='gelu', n_experts=2,
                 space_dim=1, branch_sizes=None, n_inputs=1):
        # Attention flavour and activation name.
        self.attn_type = attn_type
        self.act = act
        # Dropout probabilities.
        self.embd_pdrop, self.resid_pdrop, self.attn_pdrop = embd_pdrop, resid_pdrop, attn_pdrop
        # Transformer geometry.
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.block_size = block_size
        self.n_inner = n_embd * n_inner  # absolute MLP hidden width
        # MoE gating and multi-input (branch) settings.
        self.n_experts = n_experts
        self.space_dim = space_dim
        self.branch_sizes = branch_sizes
        self.n_inputs = n_inputs


class LinearAttention(nn.Module):
    """
    Multi-head linear self-attention with an additive KV cache.

    No N x N attention matrix is built. Keys are normalised per head and
    summarised as ``K^T V`` and ``sum_t K_t``; the queries then read from
    these summaries. Feeding the returned summaries back in as ``kv_cache``
    adds the next call's summaries to them.
    """

    def __init__(self, config):
        super(LinearAttention, self).__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # Creation order (key, query, value, proj) is kept so that seeded
        # initialisation produces the same weights.
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.proj = nn.Linear(width, width)
        self.n_head = config.n_head
        # Fixed normalisation mode. 'galerkin' and 'l2' are only reachable by
        # overwriting this attribute after construction.
        self.attn_type = 'l1'

    def forward(self, x, y=None, kv_cache=None):
        """
        Args:
            x: query tokens, (B, T1, C).
            y: optional key/value tokens, (B, T2, C); defaults to ``x``.
            kv_cache: optional ``(context_sum, key_sum)`` pair returned by a
                previous call. The current summaries are added to it.

        Returns:
            ``(out, new_cache)``. ``out`` has shape (B, T1, C) and
            ``new_cache`` holds the accumulated ``(K^T V, sum_t K_t)`` summaries.
        """
        source = y if y is not None else x
        batch, n_q, width = x.size()
        _, n_kv, _ = source.size()
        n_head = self.n_head
        head_dim = width // n_head

        def to_heads(t, length):
            # (B, L, C) -> (B, heads, L, head_dim)
            return t.view(batch, length, n_head, head_dim).transpose(1, 2)

        q = to_heads(self.query(x), n_q)
        k = to_heads(self.key(source), n_kv)
        v = to_heads(self.value(source), n_kv)

        mode = self.attn_type
        if mode in ('l1', 'galerkin'):
            q, k_feat = q.softmax(dim=-1), k.softmax(dim=-1)
        elif mode == 'l2':
            # Despite the name, this branch divides by the L1 norm (p=1).
            q = q / q.norm(dim=-1, keepdim=True, p=1).clamp(min=1e-6)
            k_feat = k / k.norm(dim=-1, keepdim=True, p=1).clamp(min=1e-6)
        else:
            raise NotImplementedError

        kv_summary = k_feat.transpose(-2, -1) @ v
        key_summary = k_feat.sum(dim=-2, keepdim=True)
        if kv_cache is not None:
            cached_kv, cached_key = kv_cache
            kv_summary = cached_kv + kv_summary
            key_summary = cached_key + key_summary

        if mode == 'galerkin':
            # NOTE(review): uses only this call's T2, not the cached length.
            scale = 1. / n_kv
        else:
            scale = 1. / (q * key_summary).sum(dim=-1, keepdim=True).clamp(min=1e-6)

        # Normalised read-out plus the query itself as a residual term.
        mixed = self.attn_drop((q @ kv_summary) * scale + q)
        # (B, heads, T1, head_dim) -> (B, T1, heads * head_dim)
        merged = mixed.transpose(1, 2).reshape(batch, n_q, width)
        return self.proj(merged), (kv_summary, key_summary)


class LinearCrossAttention(nn.Module):
    """
    Multi-head linear cross-attention over ``n_inputs`` separate input sequences.

    Each input sequence ``y[i]`` has its own key/value projections. The
    normalised read-outs from all inputs are summed, together with the
    softmaxed query as a residual, and projected back to ``n_embd``.
    Keys/values come from static inputs, so no KV cache is used.
    """
    def __init__(self, config):
        super(LinearCrossAttention, self).__init__()
        assert config.n_embd % config.n_head == 0
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.keys = nn.ModuleList([nn.Linear(config.n_embd, config.n_embd) for _ in range(config.n_inputs)])
        self.values = nn.ModuleList([nn.Linear(config.n_embd, config.n_embd) for _ in range(config.n_inputs)])
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_inputs = config.n_inputs
        self.attn_type = 'l1'

    def forward(self, x, y=None, layer_past=None):
        """
        Args:
            x: query tokens, (B, T1, n_embd).
            y: indexable collection of ``n_inputs`` tensors, each
               (B, T2_i, n_embd); T2_i may differ per input. If None, ``x``
               is used as every input.
            layer_past: unused; kept for signature compatibility.

        Returns:
            Tensor of shape (B, T1, n_embd).
        """
        if y is None:
            # Previously ``y = x`` made ``y[i]`` slice the batch axis of x (a
            # 2-D tensor), which crashed on the 3-way size unpack below.
            y = [x] * self.n_inputs
        B, T1, C = x.size()
        head_dim = C // self.n_head
        q = self.query(x).view(B, T1, self.n_head, head_dim).transpose(1, 2)
        q = q.softmax(dim=-1)
        out = q  # the softmaxed query acts as a residual term
        for i in range(self.n_inputs):
            _, T2, _ = y[i].size()
            k = self.keys[i](y[i]).view(B, T2, self.n_head, head_dim).transpose(1, 2)
            v = self.values[i](y[i]).view(B, T2, self.n_head, head_dim).transpose(1, 2)
            k = k.softmax(dim=-1)
            k_sum = k.sum(dim=-2, keepdim=True)  # sum over the T2 tokens
            D_inv = 1. / (q * k_sum).sum(dim=-1, keepdim=True).clamp(min=1e-6)
            out = out + (q @ (k.transpose(-2, -1) @ v)) * D_inv
        # (B, heads, T1, head_dim) -> (B, T1, heads * head_dim)
        out = out.transpose(1, 2).reshape(B, T1, C)
        out = self.proj(out)
        return out


def horizontal_fourier_embedding(X, n=3):
    """Expand every scalar of ``X`` into multi-frequency Fourier features.

    For each scalar x the feature vector is
    ``[x, cos(f_0 x), ..., cos(f_2n x), sin(f_0 x), ..., sin(f_2n x)]``
    with ``f_j = 2 ** (j - n)`` for j = 0..2n, i.e. ``4n + 3`` features.
    The per-scalar feature vectors are concatenated along the last axis.

    Args:
        X: tensor of shape (..., C). Any number of leading dims is accepted;
           the previous version only handled (B, N, C).
        n: number of octaves on each side of frequency 1.

    Returns:
        Tensor of shape (..., C * (4n + 3)).
    """
    freqs = 2 ** torch.linspace(-n, n, 2 * n + 1).to(X.device)
    # Broadcasting replaces the old explicit repeat, so no copy of X is materialised.
    scaled = X.unsqueeze(-1) * freqs  # (..., C, 2n+1)
    features = torch.cat([X.unsqueeze(-1), torch.cos(scaled), torch.sin(scaled)], dim=-1)
    return features.flatten(start_dim=-2)


class MIOECrossAttentionBlock(nn.Module):
    """
    Attention block whose feed-forward layers are gated mixtures of experts.

    Each call runs four residual stages in order: cross-attention to the
    branch inputs, a MoE MLP, linear self-attention (with optional KV cache),
    and a second MoE MLP. The expert weights come from ``gatenet`` applied to
    the token positions and are shared by both MoE MLPs.
    """
    def __init__(self, config):
        super(MIOECrossAttentionBlock, self).__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2_branch = nn.ModuleList([nn.LayerNorm(config.n_embd) for _ in range(config.n_inputs)])
        self.ln3 = nn.LayerNorm(config.n_embd)
        self.ln4 = nn.LayerNorm(config.n_embd)
        self.ln5 = nn.LayerNorm(config.n_embd)
        if config.attn_type == 'linear':
            print('Using Linear Attention')
            self.selfattn = LinearAttention(config)
            self.crossattn = LinearCrossAttention(config)
        else:
            raise NotImplementedError
        activations = {'gelu': GELU, 'tanh': Tanh, 'relu': ReLU, 'sigmoid': Sigmoid}
        if config.act not in activations:
            # Previously an unknown name left ``self.act`` unset, and the
            # failure surfaced later as an unrelated-looking AttributeError.
            raise NotImplementedError(f"Unsupported activation: {config.act!r}")
        # Activation *class*; a fresh instance is created inside each Sequential below.
        self.act = activations[config.act]
        self.resid_drop1 = nn.Dropout(config.resid_pdrop)
        self.resid_drop2 = nn.Dropout(config.resid_pdrop)
        self.n_experts = config.n_experts
        self.n_inputs = config.n_inputs
        self.moe_mlp1 = nn.ModuleList([nn.Sequential(nn.Linear(config.n_embd, config.n_inner), self.act(), nn.Linear(config.n_inner, config.n_embd)) for _ in range(self.n_experts)])
        self.moe_mlp2 = nn.ModuleList([nn.Sequential(nn.Linear(config.n_embd, config.n_inner), self.act(), nn.Linear(config.n_inner, config.n_embd)) for _ in range(self.n_experts)])
        self.gatenet = nn.Sequential(nn.Linear(config.space_dim, config.n_inner), self.act(), nn.Linear(config.n_inner, config.n_inner), self.act(), nn.Linear(config.n_inner, self.n_experts))

    def ln_branchs(self, y):
        """Apply the per-input LayerNorm to each of the ``n_inputs`` branch tensors."""
        return MultipleTensors([self.ln2_branch[i](y[i]) for i in range(self.n_inputs)])

    def forward(self, x, y, pos, self_attn_cache=None):
        """
        Args:
            x: token states, (B, N, n_embd).
            y: branch inputs indexable by ``0..n_inputs-1``.
            pos: gate inputs with ``space_dim`` features in the last axis
                 (presumably (B, N, space_dim) — TODO confirm).
            self_attn_cache: optional KV cache for the self-attention layer.

        Returns:
            ``(x, new_self_attn_cache)``.
        """
        # Softmax gate over experts; unsqueeze(2) lets it broadcast over channels.
        gate_score = F.softmax(self.gatenet(pos), dim=-1).unsqueeze(2)

        # Cross-attention
        x = x + self.resid_drop1(self.crossattn(self.ln1(x), self.ln_branchs(y)))

        # First MoE MLP: experts stacked on the last axis, then gate-weighted sum.
        x_moe1 = torch.stack([self.moe_mlp1[i](x) for i in range(self.n_experts)], dim=-1)
        x_moe1 = (gate_score * x_moe1).sum(dim=-1, keepdim=False)
        x = x + self.ln3(x_moe1)

        # Self-attention (consumes and returns the KV cache)
        attn_out, new_self_attn_cache = self.selfattn(self.ln4(x), kv_cache=self_attn_cache)
        x = x + self.resid_drop2(attn_out)

        # Second MoE MLP (same gate scores as the first)
        x_moe2 = torch.stack([self.moe_mlp2[i](x) for i in range(self.n_experts)], dim=-1)
        x_moe2 = (gate_score * x_moe2).sum(dim=-1, keepdim=False)
        x = x + self.ln5(x_moe2)

        return x, new_self_attn_cache


# class GNOT(nn.Module):
#     """
#     适配长时序自回归任务的GNOT (MoE版本)。
#     - `forward` 签名与训练/评估脚本统一。
#     - 直接使用批次化的 `coords` 张量，无需 `dgl.unbatch` 和 `pad_sequence`。
#     """
#     def __init__(self,
#                  coord_dim, state_dim, theta_dim,
#                  branch_sizes, output_size,
#                  n_layers, n_hidden, n_head,
#                  n_experts=2, n_inner=4, mlp_layers=2,
#                  attn_type='linear', act='gelu',
#                  ffn_dropout=0.0, attn_dropout=0.0,
#                  space_dim=2, **kwargs):
#         super(GNOT, self).__init__()
        
#         self.n_inputs = len(branch_sizes)
#         self.output_size = output_size
#         self.space_dim = space_dim

#         # 输入编码器 (Trunk MLP)
#         # 输入: 节点状态 + 坐标 + 全局参数 (theta)
#         trunk_size = state_dim + coord_dim + theta_dim
#         self.trunk_mlp = MLP(trunk_size, n_hidden, n_hidden, n_layers=mlp_layers, act=act)

#         # 输入函数编码器 (Branch MLPs)
#         if self.n_inputs > 0:
#             self.branch_mlps = nn.ModuleList([MLP(bsize, n_hidden, n_hidden, n_layers=mlp_layers, act=act) for bsize in branch_sizes])
        
#         # GPT 配置
#         self.gpt_config = MoEGPTConfig(
#             attn_type=attn_type, embd_pdrop=ffn_dropout, resid_pdrop=ffn_dropout,
#             attn_pdrop=attn_dropout, n_embd=n_hidden, n_head=n_head, n_layer=n_layers,
#             n_inner=n_inner, act=act, n_experts=n_experts, space_dim=space_dim,
#             branch_sizes=branch_sizes, n_inputs=self.n_inputs
#         )
        
#         self.blocks = nn.ModuleList([MIOECrossAttentionBlock(self.gpt_config) for _ in range(self.gpt_config.n_layer)])
#         self.out_mlp = MLP(n_hidden, n_hidden, output_size, n_layers=mlp_layers)
#         self.__name__ = 'GNOT_MoE'

#     def forward(self, g, coords, u_p, inputs, past_key_values=None, **kwargs):
#         # 1. 准备输入
#         # `coords` 已经是 (B, N, C_coord)
#         # `u_p` 已经是 (B, C_theta)
#         # `g.ndata['x']` 是当前时间步的状态 (Total_Nodes, C_state)
        
#         B, N, _ = coords.shape
        
#         # 从DGL图中提取并重塑当前状态
#         current_state = g.ndata['x'].view(B, N, -1)
        
#         # 扩展 u_p 以匹配每个节点的维度
#         u_p_expanded = u_p.unsqueeze(1).repeat(1, N, 1)
        
#         # 拼接所有输入特征
#         x = torch.cat([current_state, coords, u_p_expanded], dim=-1)
#         x = self.trunk_mlp(x) # -> (B, N, n_hidden)
        
#         # 提取用于门控网络的位置信息
#         pos = coords[:, :, :self.space_dim]
        
#         # 处理输入函数 (如果存在)
#         if self.n_inputs > 0 and isinstance(inputs, MultipleTensors) and len(inputs.x) > 0:
#             z = MultipleTensors([self.branch_mlps[i](inputs[i]) for i in range(self.n_inputs)])
#         else:
#              # 如果没有 branch inputs, 创建一个空的 MultipleTensors
#             z = MultipleTensors([])
        
#         # 初始化KV Cache
#         if past_key_values is None:
#             past_key_values = [None] * len(self.blocks)
        
#         new_past_key_values = []
#         for i, block in enumerate(self.blocks):
#             x, new_cache = block(x, z, pos, self_attn_cache=past_key_values[i])
#             new_past_key_values.append(new_cache)
        
#         # 输出解码
#         x_out = self.out_mlp(x)
        
#         # 重塑为DGL图所需的格式 (Total_Nodes, C_out)
#         x_out_flat = x_out.view(-1, self.output_size)
        
#         return x_out_flat, new_past_key_values

class GNOT(nn.Module):
    """
    Original GNOT (MoE version) for single-step prediction tasks.
    - forward receives a batched DGL graph rather than pre-padded tensors.
    - Internally it uses dgl.unbatch and pad_sequence to build a padded batch.
    """
    def __init__(self,
                 trunk_size,
                 branch_sizes,
                 output_size,
                 # Extra arguments kept for compatibility with get_model (unused here)
                 coord_dim=None, state_dim=None, theta_dim=None,
                 # Model hyper-parameters
                 n_layers=2, n_hidden=64, n_head=1, n_experts=2,
                 n_inner=4, mlp_layers=2, attn_type='linear', act='gelu',
                 ffn_dropout=0.0, attn_dropout=0.0, space_dim=2, **kwargs):
        """
        Args:
            trunk_size: input width of trunk_mlp, i.e. the node-feature width
                plus the u_p width (see the concatenation in forward).
            branch_sizes: input widths of the branch MLPs; its length sets n_inputs.
            output_size: per-node output width.
            n_inner: MLP expansion multiplier (MoEGPTConfig stores n_inner * n_hidden).
        """
        super(GNOT, self).__init__()
        
        self.n_inputs = len(branch_sizes)
        self.output_size = output_size
        self.space_dim = space_dim
        
        self.gpt_config = MoEGPTConfig(
            attn_type=attn_type, embd_pdrop=ffn_dropout, resid_pdrop=ffn_dropout,
            attn_pdrop=attn_dropout, n_embd=n_hidden, n_head=n_head, n_layer=n_layers,
            block_size=128, act=act, n_experts=n_experts, space_dim=space_dim,
            branch_sizes=branch_sizes, n_inputs=len(branch_sizes), n_inner=n_inner
        )

        self.trunk_mlp = MLP(trunk_size, n_hidden, n_hidden, n_layers=mlp_layers, act=act)
        if self.n_inputs > 0:
            self.branch_mlps = nn.ModuleList([MLP(bsize, n_hidden, n_hidden, n_layers=mlp_layers, act=act) for bsize in branch_sizes])

        # One block per layer in an nn.ModuleList, so each block can be called with its own KV-cache entry.
        self.blocks = nn.ModuleList([MIOECrossAttentionBlock(self.gpt_config) for _ in range(self.gpt_config.n_layer)])
        self.out_mlp = MLP(n_hidden, n_hidden, output_size, n_layers=mlp_layers)
        self.__name__ = 'GNOT'

    def forward(self, g, u_p, inputs, past_key_values=None, snapshot_layers=None,timer_events=None,**kwargs):
        """
        Run the GNOT stack on a batched graph.

        Args:
            g: batched DGL graph. ``g.ndata['x']`` holds per-node input
               features; the first ``space_dim`` columns are used as
               positions for the MoE gate.
            u_p: per-sample parameters. It is unsqueezed at dim 1 and
               repeated over the node axis, so presumably (B, theta_dim)
               — TODO confirm.
            inputs: branch inputs. Ignored unless ``self.n_inputs > 0`` and
               it is a non-empty MultipleTensors.
            past_key_values: optional list of per-block self-attention caches.
            snapshot_layers: optional collection of layer indices (0 = after
               trunk_mlp, i = after block i) for activation-norm snapshots.
            timer_events: optional (start, end) pair. ``record()`` is called
               around the block stack only.

        Returns:
            Tensor of shape (total_nodes, output_size). Padded rows are
            dropped and the per-graph outputs are concatenated in batch order.

        NOTE(review): ``activation_fields`` and ``new_past_key_values`` are
        computed but not returned; the tuple-returning variant is commented
        out at the end. Each snapshot also copies data to the host via ``.cpu()``.
        """
        # Original preprocessing path
        activation_fields = {}
        gs = dgl.unbatch(g)
        # Node inputs from g.ndata['x'] (expected to contain coordinates and state).
        # pad_sequence yields (N_max, B, C); permute to (B, N_max, C).
        x_from_g = pad_sequence([_g.ndata['x'] for _g in gs]).permute(1, 0, 2)
        
        pos = x_from_g[:, :, :self.space_dim]
        
        # Append u_p to every node's features.
        x = torch.cat([x_from_g, u_p.unsqueeze(1).repeat([1, x_from_g.shape[1], 1])], dim=-1)
        
        x = self.trunk_mlp(x)
        if snapshot_layers is not None and 0 in snapshot_layers:
            # Only sample 0 is recorded (assumes batch_size=1).
            # Per-node L2 norm of the activations.
            activation_fields[0] = torch.linalg.norm(x[0, :, :], dim=-1).cpu().numpy()
        if self.n_inputs > 0 and isinstance(inputs, MultipleTensors) and len(inputs.x) > 0:
            z = MultipleTensors([self.branch_mlps[i](inputs[i]) for i in range(self.n_inputs)])
        else:
            z = MultipleTensors([])
        # The timed region covers only the block stack (not encoding/decoding).
        if timer_events is not None:
            start_event, end_event = timer_events
            start_event.record()
        if past_key_values is None:
            past_key_values = [None] * len(self.blocks)
        
        new_past_key_values = []
        for i, block in enumerate(self.blocks):
            # MIOECrossAttentionBlock returns (x, new_self_attn_cache).
            x, new_cache = block(x, z, pos, self_attn_cache=past_key_values[i])
            new_past_key_values.append(new_cache)
            current_layer_idx = i + 1
            if snapshot_layers is not None and current_layer_idx in snapshot_layers:
                activation_fields[current_layer_idx] = torch.linalg.norm(x[0, :, :], dim=-1).cpu().numpy()
        if timer_events is not None:
            end_event.record()
        x = self.out_mlp(x)

        # Drop padded rows: keep the first num nodes of each graph and concatenate.
        x_out = torch.cat([x[i, :num] for i, num in enumerate(g.batch_num_nodes())], dim=0)
        
        # Tuple-returning variant (disabled); currently only x_out is returned:
        # if snapshot_layers is not None:
        #     return x_out, new_past_key_values, activation_fields
        # else:
        #     return x_out, new_past_key_values
        return x_out

class SR_GNOT(nn.Module):
    """
    Structured Recursive GNOT for long-horizon autoregressive tasks.

    - ``forward`` takes already-batched ``coords`` and ``u_p`` tensors plus a
      DGL graph whose ``ndata['x']`` holds the current node state.
    - A lightweight router scores every token once. Recursion block ``d``
      then processes only the top-``k_d`` tokens and scatters its output back.
    - Per-depth capacities ``k_d`` come from ``capacity_schedule`` (absolute
      token counts) when given. Otherwise they come from a linear ramp of keep
      ratios from 1.0 down to ``final_keep_ratio``, applied to N.
    """
    def __init__(self,
                 coord_dim, state_dim, theta_dim,
                 branch_sizes, output_size, n_hidden, n_head, n_layers,
                 final_keep_ratio=0.25,  # used to build the dynamic schedule
                 capacity_schedule=None,  # takes precedence when given
                 n_experts=4, space_dim=2,
                 n_inner=4, mlp_layers=2, act='gelu',
                 ffn_dropout=0.0, attn_dropout=0.0, **kwargs):
        super().__init__()
        self.__name__ = 'SR_GNOT'
        self.output_size = output_size
        self.recursion_depth = n_layers
        self.n_hidden = n_hidden
        self.n_inputs = len(branch_sizes) if branch_sizes else 0
        self.final_keep_ratio = final_keep_ratio

        # --- Capacity planning ---
        if capacity_schedule and isinstance(capacity_schedule, (list, tuple)):
            # Copy, so padding the schedule below never mutates the caller's list.
            self.capacity_factors = list(capacity_schedule)
            # Pad a short schedule with its last value up to the recursion depth.
            missing = self.recursion_depth - len(self.capacity_factors)
            if missing > 0:
                self.capacity_factors.extend([self.capacity_factors[-1]] * missing)
        else:
            # None marks a schedule computed from N inside forward().
            self.capacity_factors = None

        print(f"[{self.__name__}] Capacity Schedule: {'Dynamic' if self.capacity_factors is None else self.capacity_factors}")

        # Input/output layers
        trunk_size = state_dim + coord_dim + theta_dim
        self.trunk_mlp = MLP(trunk_size, n_hidden, n_hidden, n_layers=mlp_layers, act=act)
        if self.n_inputs > 0:
            self.branch_mlps = nn.ModuleList([MLP(bsize, n_hidden, n_hidden, n_layers=mlp_layers, act=act) for bsize in branch_sizes])
        self.out_mlp = MLP(n_hidden, n_hidden, output_size, n_layers=mlp_layers)

        # Core recursive components
        self.router = LightweightRouter(n_hidden)

        # act is forwarded so the blocks use the same activation as the MLPs
        # (previously they always fell back to MoEGPTConfig's default 'gelu').
        config = MoEGPTConfig(n_embd=n_hidden, n_head=n_head, n_inputs=self.n_inputs,
                              n_inner=n_inner, attn_pdrop=attn_dropout,
                              resid_pdrop=ffn_dropout, n_experts=n_experts,
                              space_dim=space_dim, n_layer=1, act=act)  # one layer per block

        self.recursion_blocks = nn.ModuleList([MIOECrossAttentionBlock(config) for _ in range(self.recursion_depth)])

    def forward(self, g, coords, u_p, inputs, **kwargs):
        """
        Args:
            g: batched DGL graph. ``g.ndata['x']`` is viewed as (B, N, C_state),
               so every graph must have exactly N nodes.
            coords: (B, N, coord_dim) node coordinates.
            u_p: (B, theta_dim) per-sample parameters.
            inputs: branch inputs (MultipleTensors); may be empty.

        Returns:
            ``(x_out, None)`` with ``x_out`` of shape (B * N, output_size).
        """
        # 1. Prepare inputs
        B, N, _ = coords.shape
        C_state = g.ndata['x'].shape[-1]

        current_state = g.ndata['x'].view(B, N, C_state)
        u_p_expanded = u_p.unsqueeze(1).repeat(1, N, 1)

        x_trunk_input = torch.cat([current_state, coords, u_p_expanded], dim=-1)
        hidden_states = self.trunk_mlp(x_trunk_input)

        if self.n_inputs > 0 and isinstance(inputs, MultipleTensors) and len(inputs.x) > 0:
            z = MultipleTensors([self.branch_mlps[i](inputs.x[i]) for i in range(self.n_inputs)])
        else:
            z = MultipleTensors([])

        # 2. Score and rank all tokens once.
        # NOTE(review): router_scores only feed this no_grad sort, so the
        # router receives no gradient from this forward pass.
        router_scores = self.router(hidden_states).squeeze(-1)
        with torch.no_grad():
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # 3. Structured adaptive recursion
        if self.capacity_factors is None:
            ratios = np.linspace(1.0, self.final_keep_ratio, self.recursion_depth)
            # Clamp every depth (not only the last) to >= 1 token; for small N,
            # intermediate depths could otherwise get 0 while the last got 1.
            capacity_factors = [max(1, int(N * r)) for r in ratios]
        else:
            capacity_factors = self.capacity_factors

        for depth in range(self.recursion_depth):
            k = min(capacity_factors[depth], N)  # never more than N tokens
            active_indices = global_ranking_indices[:, :k]
            active_h = pack_tokens(hidden_states, active_indices)
            active_coords = pack_tokens(coords, active_indices)

            # The gate consumes the first `in_features` (= space_dim) coordinate columns.
            active_pos = active_coords[:, :, :self.recursion_blocks[depth].gatenet[0].in_features]

            block_output, _ = self.recursion_blocks[depth](active_h, z, active_pos)

            # Scatter the processed tokens back into place (out-of-place, keeps autograd).
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, self.n_hidden)
            hidden_states = hidden_states.scatter(1, idx_exp, block_output)

        # 4. Output
        output = self.out_mlp(hidden_states)
        x_out = output.view(-1, self.output_size)

        return x_out, None


class SR_GNOT_SS(nn.Module):
    """
    GNOT with static adaptive recursion for single-step prediction.

    - Interface-compatible with the original GNOT (callable from train.py):
      takes a batched DGL graph, unbatches and pads it internally.
    - Internally it uses SR_GNOT's routing: tokens are ranked once by a
      lightweight router, and recursion block ``d`` processes only the top
      ``max(1, int(N * ratio_d))`` tokens.
    - ``capacity_ratios`` (each in (0, 1]) fixes the per-depth ratios. When
      omitted, ratios ramp linearly from 1.0 to ``final_keep_ratio``.
    """
    def __init__(self,
                 trunk_size, branch_sizes, output_size,
                 # Model hyper-parameters
                 n_layers, n_hidden, n_head,
                 capacity_ratios=None, final_keep_ratio=0.25,
                 n_experts=4, space_dim=2, n_inner=4, mlp_layers=2,
                 act='gelu', ffn_dropout=0.0, attn_dropout=0.0, **kwargs):
        super().__init__()
        self.__name__ = 'SR_GNOT_SS'

        self.output_size = output_size
        self.recursion_depth = n_layers
        self.n_hidden = n_hidden
        self.n_inputs = len(branch_sizes) if branch_sizes else 0
        self.space_dim = space_dim
        self.final_keep_ratio = final_keep_ratio

        if capacity_ratios and isinstance(capacity_ratios, (list, tuple)):
            # Every ratio must lie in (0, 1].
            if not all(0.0 < r <= 1.0 for r in capacity_ratios):
                raise ValueError("capacity_ratios 中的所有值必须在 (0, 1] 范围内。")
            # Copy, so padding below never mutates the caller's list.
            self.capacity_ratios = list(capacity_ratios)
            missing = self.recursion_depth - len(self.capacity_ratios)
            if missing > 0:
                self.capacity_ratios.extend([self.capacity_ratios[-1]] * missing)
        else:
            self.capacity_ratios = None  # None marks dynamically generated ratios

        print(f"[{self.__name__}] Capacity Ratios: {'Dynamic' if self.capacity_ratios is None else self.capacity_ratios}")
        # Input/output layers (same as the original GNOT)
        self.trunk_mlp = MLP(trunk_size, n_hidden, n_hidden, n_layers=mlp_layers, act=act)
        if self.n_inputs > 0:
            self.branch_mlps = nn.ModuleList([MLP(bsize, n_hidden, n_hidden, n_layers=mlp_layers, act=act) for bsize in branch_sizes])
        self.out_mlp = MLP(n_hidden, n_hidden, output_size, n_layers=mlp_layers)

        # Core recursive components
        self.router = LightweightRouter(n_hidden)

        # act is forwarded so the blocks use the same activation as the MLPs
        # (previously they always fell back to MoEGPTConfig's default 'gelu').
        config = MoEGPTConfig(n_embd=n_hidden, n_head=n_head, n_inputs=self.n_inputs,
                              n_inner=n_inner, attn_pdrop=attn_dropout,
                              resid_pdrop=ffn_dropout, n_experts=n_experts,
                              space_dim=space_dim, n_layer=1, act=act)

        self.recursion_blocks = nn.ModuleList([MIOECrossAttentionBlock(config) for _ in range(self.recursion_depth)])

    def forward(self, g, u_p, inputs, current_epoch=-1, pretrain_epochs=0,timer_events=None,**kwargs):
        """
        Args:
            g: batched DGL graph. The first ``space_dim`` columns of
               ``g.ndata['x']`` are used as gate positions.
            u_p: per-sample parameters, repeated over every node.
            inputs: branch inputs (MultipleTensors); may be empty.
            current_epoch, pretrain_epochs: currently unused (kept for callers).
            timer_events: optional (start, end) pair recorded around the
                recursion loop only.

        Returns:
            Tensor of shape (total_nodes, output_size); padded rows are dropped.
        """
        # 1. Prepare inputs (same as the original GNOT)
        gs = dgl.unbatch(g)
        # pad_sequence yields (N_max, B, C); permute to (B, N_max, C).
        x_from_g = pad_sequence([_g.ndata['x'] for _g in gs]).permute(1, 0, 2)

        B, N, _ = x_from_g.shape

        # Gate positions: first space_dim columns. This was hard-coded to 2,
        # which crashed the gate network whenever space_dim > 2.
        coords = x_from_g[:, :, :self.space_dim]

        # Append u_p to every node's features.
        x_trunk_input = torch.cat([x_from_g, u_p.unsqueeze(1).repeat([1, N, 1])], dim=-1)

        hidden_states = self.trunk_mlp(x_trunk_input)

        if self.n_inputs > 0 and isinstance(inputs, MultipleTensors) and len(inputs.x) > 0:
            z = MultipleTensors([self.branch_mlps[i](inputs.x[i]) for i in range(self.n_inputs)])
        else:
            z = MultipleTensors([])

        # 2. Score and rank all tokens once (as in SR_GNOT).
        # NOTE(review): router_scores only feed this no_grad sort, so the
        # router receives no gradient. Padded rows are ranked as well.
        router_scores = self.router(hidden_states).squeeze(-1)
        with torch.no_grad():
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # 3. Structured adaptive recursion
        if self.capacity_ratios is None:
            capacity_ratios = np.linspace(1.0, self.final_keep_ratio, self.recursion_depth)
        else:
            capacity_ratios = self.capacity_ratios
        if timer_events is not None:
            start_event, end_event = timer_events
            start_event.record()
        for depth in range(self.recursion_depth):
            # Number of tokens kept at this depth; always at least 1.
            k = max(1, int(N * capacity_ratios[depth]))
            active_indices = global_ranking_indices[:, :k]
            active_h = pack_tokens(hidden_states, active_indices)
            active_pos = pack_tokens(coords, active_indices)

            block_output, _ = self.recursion_blocks[depth](active_h, z, active_pos)

            # Scatter the processed tokens back into place (out-of-place, keeps autograd).
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, self.n_hidden)
            hidden_states = hidden_states.scatter(1, idx_exp, block_output)
        # Micro-benchmark timer
        if timer_events is not None:
            end_event.record()
        # 4. Output (same as the original GNOT): drop padded rows per graph.
        output = self.out_mlp(hidden_states)
        x_out = torch.cat([output[i, :num] for i, num in enumerate(g.batch_num_nodes())], dim=0)

        # train.py's train_batch expects a single return value.
        return x_out


class GNOT_StaticDepth(nn.Module):
    """
    GNOT with a "static depth" adaptive policy, for comparison against
    SR_GNOT's "static capacity" policy.

    A depth router assigns every token one target depth up front. A single
    shared MIOECrossAttentionBlock is then applied ``max_depth`` times; at
    step ``d`` it only processes tokens whose target depth is >= d.

    - In training mode ``forward`` also returns ``target_depths`` so the
      caller can compute a ponder-cost auxiliary loss.
    - In eval mode it returns the number of active tokens at each step.
    """
    def __init__(self,
                 trunk_size, branch_sizes, output_size,
                 # Model hyper-parameters
                 n_layers,  # here n_layers is the maximum depth max_depth
                 n_hidden, n_head,
                 n_experts=4, space_dim=2, n_inner=4, mlp_layers=2,
                 act='gelu', ffn_dropout=0.0, attn_dropout=0.0, **kwargs):
        super().__init__()
        self.__name__ = 'GNOT_StaticDepth'

        self.output_size = output_size
        self.max_depth = n_layers
        self.n_hidden = n_hidden
        self.n_inputs = len(branch_sizes) if branch_sizes else 0
        self.space_dim = space_dim

        # 1. Input/output layers
        self.trunk_mlp = MLP(trunk_size, n_hidden, n_hidden, n_layers=mlp_layers, act=act)
        if self.n_inputs > 0:
            self.branch_mlps = nn.ModuleList([MLP(bsize, n_hidden, n_hidden, n_layers=mlp_layers, act=act) for bsize in branch_sizes])
        self.out_mlp = MLP(n_hidden, n_hidden, output_size, n_layers=mlp_layers)

        # 2. Depth-planning router
        self.depth_router = DepthRouter(n_hidden, self.max_depth)

        # 3. One shared computation block. act is forwarded so it uses the same
        # activation as the MLPs (previously it always fell back to 'gelu').
        config = MoEGPTConfig(n_embd=n_hidden, n_head=n_head, n_inputs=self.n_inputs,
                              n_inner=n_inner, attn_pdrop=attn_dropout,
                              resid_pdrop=ffn_dropout, n_experts=n_experts,
                              space_dim=space_dim, n_layer=1, act=act)
        self.shared_block = MIOECrossAttentionBlock(config)

        self.log_info()

    def log_info(self):
        """Print the model name and maximum recursion depth."""
        print(f"--- Model Config: {self.__name__} ---")
        print(f"  Max Recursion Depth: {self.max_depth}")
        print("------------------------------------")

    def _apply_block_to_sample(self, x_b, mask_b, coords_b, z, b):
        """Run shared_block on sample ``b``'s active tokens and write them back into ``x_b`` (N, C)."""
        if not mask_b.any():
            return x_b
        active_h = x_b[mask_b].unsqueeze(0)         # (1, k_b, C)
        active_pos = coords_b[mask_b].unsqueeze(0)  # (1, k_b, space_dim)
        # Slice the branch inputs to this sample. LinearCrossAttention views
        # them with the query's batch size, which is 1 here.
        z_b = MultipleTensors([t[b:b + 1] for t in z.x])
        block_output, _ = self.shared_block(active_h, z_b, active_pos)
        # masked_scatter fills the True rows in order, matching x_b[mask_b].
        return x_b.masked_scatter(mask_b.unsqueeze(-1), block_output[0])

    def forward(self, g, u_p, inputs, **kwargs):
        """
        Args:
            g: batched DGL graph. The first ``space_dim`` columns of
               ``g.ndata['x']`` are used as gate positions.
            u_p: per-sample parameters, repeated over every node.
            inputs: branch inputs (MultipleTensors); may be empty.

        Returns:
            ``(x_out, info)``. ``x_out`` has shape (total_nodes, output_size).
            ``info`` is ``{"target_depths": (B, N) tensor}`` in training mode
            and ``{"layer_capacities": list of active-token counts}`` in eval mode.
        """
        # --- 1. Prepare inputs ---
        gs = dgl.unbatch(g)
        x_from_g = pad_sequence([_g.ndata['x'] for _g in gs]).permute(1, 0, 2)
        B, N, _ = x_from_g.shape
        coords = x_from_g[:, :, :self.space_dim]
        x_trunk_input = torch.cat([x_from_g, u_p.unsqueeze(1).repeat([1, N, 1])], dim=-1)

        hidden_states = self.trunk_mlp(x_trunk_input)

        if self.n_inputs > 0 and isinstance(inputs, MultipleTensors) and len(inputs.x) > 0:
            z = MultipleTensors([self.branch_mlps[i](inputs.x[i]) for i in range(self.n_inputs)])
        else:
            z = MultipleTensors([])

        # --- 2. One-shot static depth planning ---
        depth_logits = self.depth_router(hidden_states)

        if self.training:
            # NOTE(review): taking argmax of the hard Gumbel sample yields an
            # integer tensor, so no gradient reaches depth_router through
            # target_depths.
            target_depths_one_hot = F.gumbel_softmax(depth_logits, tau=1, hard=True, dim=-1)
            target_depths = torch.argmax(target_depths_one_hot, dim=-1)
        else:
            target_depths = torch.argmax(depth_logits, dim=-1)

        # --- 3. Structured recursion ---
        # A token with target depth t is processed at steps 0..t, i.e. t + 1 times.
        # NOTE(review): padded rows from pad_sequence also get a target depth.
        x = hidden_states
        layer_capacities = []
        for depth in range(self.max_depth):
            active_mask = target_depths >= depth  # (B, N) bool
            num_active_tokens = torch.sum(active_mask).item()
            layer_capacities.append(num_active_tokens)

            if num_active_tokens == 0:
                continue

            # Process each sample separately. Previously all active tokens in
            # the batch were flattened into one sequence, which mixed samples
            # in self-attention and crashed cross-attention for B > 1.
            # For B == 1 the computation is unchanged.
            x = torch.stack(
                [self._apply_block_to_sample(x[b], active_mask[b], coords[b], z, b) for b in range(B)],
                dim=0,
            )

        # --- 4. Output: drop padded rows per graph ---
        output = self.out_mlp(x)
        x_out = torch.cat([output[i, :num] for i, num in enumerate(g.batch_num_nodes())], dim=0)

        # --- 5. Mode-dependent extras ---
        if self.training:
            return x_out, {"target_depths": target_depths}
        else:
            # At inference, also report per-step token counts.
            return x_out, {"layer_capacities": layer_capacities}