# adaptive_transolver_final.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np

# 导入原始的 Physics_Attention 作为父类
try:
    from .Physics_Attention import Physics_Attention_Structured_Mesh_2D
    from .Transolver_Structured_Mesh_2D import MLP as TransolverMLP
except ImportError:
    from model.Physics_Attention import Physics_Attention_Structured_Mesh_2D
    from model.Transolver_Structured_Mesh_2D import MLP as TransolverMLP

# ==============================================================================
# 1. 基础与新增的辅助模块
# ==============================================================================

def pack_tokens(x, active_idx):
    """Gather the tokens of ``x`` selected by ``active_idx``.

    Args:
        x: tensor of shape [B, N, D].
        active_idx: integer tensor of shape [B, k] holding token positions
            along dimension 1 of ``x``.

    Returns:
        Tensor of shape [B, k, D] with ``out[b, i] = x[b, active_idx[b, i]]``.
    """
    feature_dim = x.size(-1)
    # Broadcast each index across the feature axis so gather copies whole tokens.
    gather_index = active_idx[..., None].expand(-1, -1, feature_dim)
    return x.gather(dim=1, index=gather_index)

class CNNRouter(nn.Module):
    """Score every token of a structured H x W mesh with a small CNN.

    The flat token sequence is folded back into its H x W image layout, so a
    3x3 convolution sees each token's grid neighbours. A 1x1 convolution then
    reduces the hidden channels to a single score per token.
    """

    def __init__(self, in_channels, H, W, hidden_channels_ratio=0.25):
        super().__init__()
        self.H, self.W = H, W
        # Hidden width scales with in_channels but never drops below 16.
        width = max(int(in_channels * hidden_channels_ratio), 16)
        self.cnn_scorer = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(width, 1, kernel_size=1),
        )

    def forward(self, x):
        """Return per-token scores.

        Args:
            x: tensor of shape [B, N, C] with N == H * W.

        Returns:
            Tensor of shape [B, N, 1].

        Raises:
            ValueError: if N does not match the configured H * W grid.
        """
        batch, num_tokens, channels = x.shape
        if num_tokens != self.H * self.W:
            raise ValueError(f"输入Token数 N={num_tokens} 与 CNNRouter 的网格尺寸 H={self.H}, W={self.W} 不匹配。")
        image = x.transpose(1, 2).reshape(batch, channels, self.H, self.W)
        return self.cnn_scorer(image).reshape(batch, num_tokens, 1)

class CrossAttention(nn.Module):
    """Standard multi-head cross-attention: queries attend over a separate key/value sequence."""

    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)

    def forward(self, query, key_value):
        """Attend from ``query`` (the "focus" stream) to ``key_value`` (the "background" stream).

        Args:
            query: tensor of shape [B, Lq, d_model].
            key_value: tensor of shape [B, Lkv, d_model], used as both keys and values.

        Returns:
            Tensor of shape [B, Lq, d_model].
        """
        # The attention weights were always discarded; need_weights=False skips
        # computing/averaging them and lets PyTorch take its fused attention path.
        out, _ = self.attn(query, key_value, key_value, need_weights=False)
        return out

# ==============================================================================
# 2. 升级后的核心计算模块
# ==============================================================================

class Physics_Attention_Adaptive(Physics_Attention_Structured_Mesh_2D):
    """
    A Physics_Attention variant that accepts a variable sequence length.

    Instead of a fixed H x W mesh, every forward call infers H = W = sqrt(N)
    from the incoming token count. N must therefore be a perfect square, and
    callers are responsible for guaranteeing that.
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, **kwargs):
        # H=-1 / W=-1 are placeholders for the parent's fixed mesh size; forward()
        # below never reads self.H / self.W. Extra **kwargs are accepted but ignored.
        super().__init__(dim, heads, dim_head, dropout, slice_num, H=-1, W=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [B, N, C], where N is a perfect square (or N == 0).

        Returns:
            ``self.to_out`` applied to the per-token attention output with heads
            merged, or an all-zero tensor shaped like ``x`` when N == 0.

        Raises:
            ValueError: if N > 0 is not a perfect square.
        """
        B, N, C = x.shape
        if N == 0: return torch.zeros_like(x)
        H = W = int(np.sqrt(N))
        if H * W != N:
            raise ValueError(f"自适应模式下，输入序列长度 {N} 必须是一个完美的平方数。")
        
        # --- (remaining logic is identical to the original Physics_Attention implementation) ---
        # NOTE(review): tokens are laid out on the sqrt(N) x sqrt(N) grid in their sequence
        # order. If in_project_fx / in_project_x are spatial convolutions (they are defined
        # in the parent, not visible here), grid "neighbours" are neighbours by position
        # in x, not necessarily on the physical mesh — confirm this is intended.
        x_reshaped = x.reshape(B, H, W, C).contiguous().permute(0, 3, 1, 2)
        # Project the [B, C, H, W] grid, then split into heads: [B, heads, N, dim_head].
        fx_mid = self.in_project_fx(x_reshaped).permute(0, 2, 3, 1).reshape(B, N, self.heads, self.dim_head).permute(0, 2, 1, 3)
        x_mid = self.in_project_x(x_reshaped).permute(0, 2, 3, 1).reshape(B, N, self.heads, self.dim_head).permute(0, 2, 1, 3)
        # Soft-assign every token to the slices; the learnable temperature is clamped to [0.1, 5].
        slice_weights = self.softmax(self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5))
        # Total weight per slice (summed over the token axis), clamped to avoid division by zero.
        slice_norm = slice_weights.sum(2, keepdim=True).clamp(min=1e-8)
        # Slice tokens = weight-normalised average of fx_mid over the tokens assigned to each slice.
        slice_token = torch.matmul(slice_weights.transpose(-2, -1), fx_mid) / slice_norm.transpose(-2, -1)
        # Scaled dot-product attention among the slice tokens only.
        q_slice, k_slice, v_slice = self.to_q(slice_token), self.to_k(slice_token), self.to_v(slice_token)
        dots = torch.matmul(q_slice, k_slice.transpose(-1, -2)) * self.scale
        attn = self.dropout(self.softmax(dots))
        out_slice = torch.matmul(attn, v_slice)
        # Broadcast the slice outputs back to every token using the same slice weights.
        out_x = torch.einsum("bhgc,bhng->bhnc", out_slice, slice_weights)
        # Merge heads: [B, heads, N, d] -> [B, N, heads * d].
        out_x = rearrange(out_x, 'b h n d -> b n (h d)')
        return self.to_out(out_x)

class TwoStreamTransolverBlock(nn.Module):
    """Adaptive Transolver block on a "focus" stream with a "background" context stream.

    Each call applies three pre-LayerNorm residual sub-layers in order:
      1. Physics-Attention self-attention over the focus tokens,
      2. cross-attention from the focus tokens to the context tokens,
      3. an MLP.

    Args:
        num_heads: number of heads for both attention layers.
        hidden_dim: token width C; per-head width is hidden_dim // num_heads.
        dropout: dropout rate for both the Physics-Attention and the cross-attention.
        mlp_ratio: int(hidden_dim * mlp_ratio) is passed to TransolverMLP as its
            second (presumably hidden-width) argument.
        slice_num: number of physics slices for Physics_Attention_Adaptive.
    """

    def __init__(self, num_heads, hidden_dim, dropout, mlp_ratio, slice_num):
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.Attn = Physics_Attention_Adaptive(
            dim=hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
            dropout=dropout, slice_num=slice_num,
        )
        self.ln_2 = nn.LayerNorm(hidden_dim)
        # int(): a float mlp_ratio (e.g. 2.5) would otherwise produce a float layer width.
        self.mlp = TransolverMLP(hidden_dim, int(hidden_dim * mlp_ratio), hidden_dim)

        # Cross-attention that injects background (context) information into the focus stream.
        self.ln_cross = nn.LayerNorm(hidden_dim)
        # Honour the block's dropout instead of CrossAttention's hard-coded default of 0.1.
        self.cross_attention = CrossAttention(hidden_dim, num_heads, dropout=dropout)

    def forward(self, fx, context):
        """
        Args:
            fx: focus-stream tokens, shape [B, k, C]; k must be a perfect square
                because Physics_Attention_Adaptive re-grids them as sqrt(k) x sqrt(k).
            context: background tokens, shape [B, num_bg, C].

        Returns:
            Updated focus tokens, shape [B, k, C].
        """
        # 1. Self-attention (Physics-Attention) within the focus stream.
        fx_self_attended = self.Attn(self.ln_1(fx)) + fx

        # 2. Inject background (context) information via cross-attention.
        fx_context_fused = self.cross_attention(self.ln_cross(fx_self_attended), context) + fx_self_attended

        # 3. Position-wise MLP.
        fx_out = self.mlp(self.ln_2(fx_context_fused)) + fx_context_fused

        return fx_out

# ==============================================================================
# 3. 最终的、升级后的自适应Transolver主模型
# ==============================================================================

class TwoStreamCNNRouterTransolver(nn.Module):
    """Adaptive two-stream Transolver on a structured H x W mesh.

    Forward pipeline:
      1. Embed concatenated (x, fx) with an MLP.
      2. Score every token once with a CNN router and rank tokens by score.
      3. Build a coarse "background" stream by average-pooling the embedded grid.
      4. At each layer, update the top-k ranked tokens (the "focus" stream) with a
         TwoStreamTransolverBlock that also cross-attends to the background stream.
         The updated tokens are scattered back; all other tokens pass through unchanged.
      5. Decode every token with an output MLP.

    Args:
        space_dim, fun_dim: last-axis sizes of ``x`` and ``fx``; their sum is the
            preprocess MLP's input width.
        n_layers: number of TwoStreamTransolverBlocks.
        n_hidden: token width C.
        dropout, n_head, mlp_ratio, slice_num: forwarded to every block.
        out_dim: output width of the final MLP.
        H, W: mesh grid size; inputs must carry exactly H * W tokens.
        capacity_ratios: one ratio per layer (list or tuple of length n_layers).
            A layer's active-token count is the largest perfect square
            <= min(ratio, 1) * H * W, and at least 1.
        bg_pool_size: kernel size and stride of the background average pooling.
        router_gating: if True, each active token's update is scaled by
            sigmoid(router score), so the router receives gradients. With the
            default False, the router only feeds a non-differentiable ranking,
            and its parameters get no gradient (original behaviour).
    """

    def __init__(self, space_dim, fun_dim, n_layers, n_hidden, dropout, n_head, mlp_ratio,
                 out_dim, slice_num, H, W, capacity_ratios,
                 bg_pool_size=8, router_gating=False, **kwargs):
        super().__init__()
        self.__name__ = "TwoStreamCNNRouterTransolver"

        self.recursion_depth = n_layers
        self.H = H
        self.W = W
        self.router_gating = router_gating

        # a. Capacity planning: precompute a perfect-square active-token count per layer.
        if not isinstance(capacity_ratios, (list, tuple)) or len(capacity_ratios) != n_layers:
            raise ValueError("capacity_ratios 的长度必须等于 n_layers。")
        self.capacity_factors = self._plan_capacities(H * W, capacity_ratios)
        print(f"[{self.__name__}] 预计算容量 (活跃Token数/层): {self.capacity_factors}")

        # b. Input / output components.
        input_dim = fun_dim + space_dim
        self.preprocess = TransolverMLP(input_dim, n_hidden * 2, n_hidden)
        self.out_mlp = TransolverMLP(n_hidden, n_hidden, out_dim)

        # c. Adaptive components: token router and background pooler.
        self.router = CNNRouter(in_channels=n_hidden, H=H, W=W)
        self.background_pooler = nn.AvgPool2d(kernel_size=bg_pool_size, stride=bg_pool_size)

        # d. One block per layer.
        self.transolver_blocks = nn.ModuleList([
            TwoStreamTransolverBlock(
                num_heads=n_head, hidden_dim=n_hidden, dropout=dropout,
                mlp_ratio=mlp_ratio, slice_num=slice_num
            ) for _ in range(self.recursion_depth)
        ])

    @staticmethod
    def _plan_capacities(num_nodes, capacity_ratios):
        """Map each per-layer ratio to a perfect-square active-token count in [1, num_nodes]."""
        capacities = []
        for ratio in capacity_ratios:
            if ratio < 0:
                raise ValueError(f"capacity_ratios must be non-negative, got {ratio}.")
            # Clamp to the grid size. Otherwise k > N, and forward()'s min(k, N) would fall
            # back to N, which need not be a perfect square -> Physics_Attention_Adaptive raises.
            target_k = min(num_nodes * ratio, num_nodes)
            side = int(np.sqrt(target_k))
            capacities.append(max(side * side, 1))  # keep at least one active token
        return capacities

    def forward(self, x, fx, **kwargs):
        """
        Args:
            x, fx: concatenated on the last axis and fed to ``preprocess``. Presumably
                [B, N, space_dim] and [B, N, fun_dim]; N must equal H * W
                (enforced by CNNRouter).

        Returns:
            ``(output, None)``, where output is ``out_mlp`` applied to all N tokens.
        """
        # 1. Embed inputs.
        hidden_states = self.preprocess(torch.cat((x, fx), -1))
        B, N, C = hidden_states.shape

        # 2. Score and rank tokens once. Only the top max(capacity) tokens are ever
        #    selected, so topk (sorted descending) replaces a full sort.
        router_scores = self.router(hidden_states)  # [B, N, 1]
        max_k = min(max(self.capacity_factors, default=1), N)
        with torch.no_grad():
            _, global_ranking_indices = torch.topk(router_scores.squeeze(-1), max_k, dim=1)

        # 3. Build the global background stream once from the embedded grid.
        hidden_states_image = hidden_states.permute(0, 2, 1).reshape(B, C, self.H, self.W)
        background_feature_map = self.background_pooler(hidden_states_image)
        background_tokens = background_feature_map.reshape(B, C, -1).permute(0, 2, 1)

        # 4. Two-stream adaptive loop: update only the top-k tokens at each layer.
        for depth, block in enumerate(self.transolver_blocks):
            k = min(self.capacity_factors[depth], N)
            active_indices = global_ranking_indices[:, :k]
            scatter_index = active_indices.unsqueeze(-1).expand(-1, -1, C)

            active_h = pack_tokens(hidden_states, active_indices)
            updated_active_h = block(active_h, context=background_tokens)

            if self.router_gating:
                # Gate the residual update by the router score so the router is trainable.
                gate = torch.sigmoid(pack_tokens(router_scores, active_indices))  # [B, k, 1]
                updated_active_h = active_h + gate * (updated_active_h - active_h)

            # Out-of-place scatter: returns a new tensor rather than modifying hidden_states.
            hidden_states = hidden_states.scatter(1, scatter_index, updated_active_h)

        # 5. Decode all tokens.
        output = self.out_mlp(hidden_states)
        return output, None