import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np

# 导入原始组件库
try:
    from .Physics_Attention import Physics_Attention_Structured_Mesh_2D
    from .Transolver_Structured_Mesh_2D import MLP as TransolverMLP
except ImportError:
    from model.Physics_Attention import Physics_Attention_Structured_Mesh_2D
    from model.Transolver_Structured_Mesh_2D import MLP as TransolverMLP

# ==============================================================================
# 1. 基础辅助模块 (自包含)
# ==============================================================================

def pack_tokens(x, active_idx):
    """Gather the tokens selected by ``active_idx`` along dimension 1.

    Args:
        x: 3-D tensor; the last dimension is the feature width.
        active_idx: 2-D integer tensor of positions along dimension 1 of
            ``x``, one row of positions per batch element.

    Returns:
        Tensor shaped ``active_idx.shape + (x.shape[-1],)`` holding the
        selected rows of ``x``.
    """
    feature_dim = x.size(-1)
    # Broadcast each token index across the feature axis so that gather
    # copies whole feature vectors.
    gather_index = active_idx[..., None].expand(-1, -1, feature_dim)
    return x.gather(dim=1, index=gather_index)

# ==============================================================================
# 2. 核心计算模块：自路由的自适应块
# ==============================================================================

class SelfRoutingAdaptiveBlock(nn.Module):
    """Transformer block that runs attention + MLP only on a routed subset of points.

    Each call scores every point from slice-assignment weights, keeps the
    top-scoring fraction (the "active" points), applies a pre-norm
    self-attention + MLP block to that subset only, and writes the results
    back into the input. Points that are not selected pass through unchanged.

    Args:
        num_heads: Number of heads, shared by the scorer and the attention layer.
        hidden_dim: Channel width ``C`` of the token features.
        dropout: Dropout rate passed to the scorer and the attention layer.
        mlp_ratio: Multiplier applied to ``hidden_dim`` for the MLP's middle
            argument (presumably its hidden width -- confirm against
            ``TransolverMLP``).
        slice_num: Number of slices used by the scorer.
        H: Grid height.
        W: Grid width. Inputs must contain ``H * W`` points.
    """

    def __init__(self, num_heads, hidden_dim, dropout, mlp_ratio, slice_num, H, W):
        super().__init__()
        self.H = H
        self.W = W
        self.heads = num_heads

        # a. Scorer ("slicer"): only the components needed to produce slice
        #    weights are kept from a temporary physics-attention module; the
        #    remainder of that module is discarded.
        _slicer = Physics_Attention_Structured_Mesh_2D(
            dim=hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads, H=H, W=W,
            slice_num=slice_num, dropout=dropout
        )
        self.in_project_x = _slicer.in_project_x
        self.in_project_slice = _slicer.in_project_slice
        self.temperature = _slicer.temperature
        self.softmax = _slicer.softmax

        # b. Processor: operates on the active subset only. Plain
        #    self-attention is used because the selected points are an
        #    unordered set without grid topology.
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads,
            dropout=dropout, batch_first=True
        )
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.mlp = TransolverMLP(hidden_dim, hidden_dim * mlp_ratio, hidden_dim)

    def get_routing_scores(self, x):
        """Return one importance score per point, derived from slice weights.

        ``forward`` calls this under ``torch.no_grad()``, so the scorer's
        parameters receive no gradient through the routing decision.

        Args:
            x: Tensor of shape ``[B, N, C]`` with ``N == H * W``.

        Returns:
            Tensor of shape ``[B, N]``: per point, the largest slice weight
            of each head, averaged over heads.
        """
        B, N, C = x.shape
        # [B, N, C] -> [B, C, H, W] so the projection sees the 2-D grid.
        grid = x.reshape(B, self.H, self.W, C).contiguous().permute(0, 3, 1, 2)
        projected = self.in_project_x(grid)

        # [B, heads*dim, h, w] -> [B, heads, h*w, dim]; native equivalent of
        # rearrange('b h w (heads dim) -> b heads (h w) dim').
        per_head = (
            projected.permute(0, 2, 3, 1)
            .flatten(1, 2)
            .unflatten(-1, (self.heads, -1))
            .transpose(1, 2)
        )

        slice_logits = self.in_project_slice(per_head)
        # Clamp the temperature so the softmax is neither degenerate nor flat.
        temperature = torch.clamp(self.temperature, min=0.1, max=5)
        # NOTE(review): self.softmax is presumably over the last (slice)
        # axis -- confirm in Physics_Attention_Structured_Mesh_2D.
        slice_weights = self.softmax(slice_logits / temperature)

        # Strategy: strongest slice membership is the importance score.
        # max over dim 3 -> [B, heads, N]; mean over heads -> [B, N].
        return slice_weights.max(dim=3).values.mean(dim=1)

    def forward(self, x, k_point_ratio):
        """Update the highest-scoring points; leave all other points untouched.

        Args:
            x: Tensor of shape ``[B, N, C]`` with ``N == H * W``.
            k_point_ratio: Fraction of the ``N`` points to process. At least
                one point and at most ``N`` points are always selected.

        Returns:
            Tensor of shape ``[B, N, C]`` equal to ``x`` except at the
            selected positions, which hold the block's output.
        """
        B, N, C = x.shape

        # --- Routing decision (excluded from autograd) ---
        with torch.no_grad():
            point_scores = self.get_routing_scores(x)
            _, ranked = torch.sort(point_scores, dim=1, descending=True)
            n_active = min(max(int(N * k_point_ratio), 1), N)
            active_idx = ranked[:, :n_active]

        # One expanded index serves both the gather and the scatter below.
        token_index = active_idx.unsqueeze(-1).expand(-1, -1, C)
        active = torch.gather(x, 1, token_index)

        # --- Pre-norm transformer block on the active subset ---
        attn_in = self.ln_1(active)
        # The attention map is never used, so do not ask for it.
        attn_out, _ = self.attn(attn_in, attn_in, attn_in, need_weights=False)
        fused = active + attn_out
        updated = fused + self.mlp(self.ln_2(fused))

        # --- Write-back: out-of-place scatter keeps inactive points as-is ---
        return x.scatter(1, token_index, updated)

# ==============================================================================
# 3. 最终的主模型
# ==============================================================================

class SelfRoutingAdaptiveTransolver(nn.Module):
    """Stack of self-routing adaptive blocks between an input and an output MLP.

    Args:
        space_dim: Width of the coordinate input ``x``.
        fun_dim: Width of the optional function input ``fx``. The input MLP
            is built for ``space_dim + fun_dim`` features.
        n_layers: Number of routed blocks.
        n_hidden: Hidden channel width used by every block.
        dropout: Dropout rate passed to each block.
        n_head: Number of attention heads per block.
        mlp_ratio: MLP width multiplier passed to each block.
        out_dim: Width of the model output.
        slice_num: Number of slices used by each block's scorer.
        H: Grid height.
        W: Grid width.
        capacity_ratios: List or tuple with one ratio per layer: the
            fraction of points that layer processes.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``capacity_ratios`` is not a list/tuple of length
            ``n_layers``.
    """

    def __init__(self, space_dim, fun_dim, n_layers, n_hidden, dropout, n_head, mlp_ratio,
                 out_dim, slice_num, H, W, capacity_ratios, **kwargs):
        super().__init__()
        self.__name__ = "SelfRoutingAdaptiveTransolver"

        self.recursion_depth = n_layers

        # a. Capacity plan: one point ratio per layer, validated before any
        #    sub-module is built.
        self.point_ratios = self._normalize_capacity_ratios(capacity_ratios, n_layers)
        print(f"[{self.__name__}] 使用了基于Slice权重的自路由策略。")
        print(f"[{self.__name__}] Point容量规划 (比例/层): {self.point_ratios}")

        # b. Input / output heads.
        input_dim = fun_dim + space_dim
        self.preprocess = TransolverMLP(input_dim, n_hidden * 2, n_hidden)
        self.out_mlp = TransolverMLP(n_hidden, n_hidden, out_dim)

        # c. Routed blocks, one per layer.
        self.blocks = nn.ModuleList([
            SelfRoutingAdaptiveBlock(
                num_heads=n_head, hidden_dim=n_hidden, dropout=dropout,
                mlp_ratio=mlp_ratio, slice_num=slice_num, H=H, W=W
            ) for _ in range(n_layers)
        ])

    @staticmethod
    def _normalize_capacity_ratios(capacity_ratios, n_layers):
        """Validate the per-layer ratios and return them as a new list."""
        if not isinstance(capacity_ratios, (list, tuple)) or len(capacity_ratios) != n_layers:
            raise ValueError("capacity_ratios 的长度必须等于 n_layers.")
        # Copy so later mutation of the caller's sequence cannot change the model.
        return list(capacity_ratios)

    def forward(self, x, fx=None, **kwargs):
        """Run the model.

        Args:
            x: Coordinate tensor; features are on the last dimension.
            fx: Optional function-value tensor concatenated to ``x`` on the
                last dimension (time-dependent tasks). ``None`` for
                steady-state tasks, in which case ``x`` alone is the input.
            **kwargs: Accepted and ignored.

        Returns:
            Tuple ``(output, None)``; the second element is always ``None``.
        """
        model_input = x if fx is None else torch.cat((x, fx), -1)
        hidden_states = self.preprocess(model_input)

        # Each block processes only its own configured fraction of the points.
        for block, point_ratio in zip(self.blocks, self.point_ratios):
            hidden_states = block(hidden_states, point_ratio)

        return self.out_mlp(hidden_states), None