import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
import numpy as np
import torch.nn.functional as F
# --- 导入 Irregular Mesh 版本的核心组件 ---
# 我们假设这些文件都在 model/ 目录下
try:
    from .Transolver_Irregular_Mesh import Transolver_block as IrregularTransolverBlock
    from .Transolver_Irregular_Mesh import MLP
except ImportError:
    from model.Transolver_Irregular_Mesh import Transolver_block as IrregularTransolverBlock
    from model.Transolver_Irregular_Mesh import MLP

# --- Building blocks required by the adaptive logic (defined locally, not imported) ---
class LightweightRouter(nn.Module):
    """Small two-layer MLP that maps every token feature vector to one scalar score.

    The hidden width is ``int(d_model * d_hidden_ratio)``, clamped to at least
    one unit so that the router never degenerates into a zero-width
    bottleneck (which would make the score independent of the input).

    Args:
        d_model: Size of the last dimension of the input features.
        d_hidden_ratio: Hidden width expressed as a fraction of ``d_model``.
    """

    def __init__(self, d_model, d_hidden_ratio=0.25):
        super().__init__()
        # int() truncates towards zero, so small d_model values (e.g. d_model < 4
        # with the default ratio) would otherwise yield an empty hidden layer.
        d_hidden = max(1, int(d_model * d_hidden_ratio))
        self.proj1 = nn.Linear(d_model, d_hidden)
        self.proj2 = nn.Linear(d_hidden, 1)

    def forward(self, x):
        """Return one score per token; the last dimension of the result has size 1."""
        return self.proj2(F.gelu(self.proj1(x)))

def pack_tokens(x, active_idx):
    """Gather the rows of ``x`` selected by ``active_idx`` along dimension 1.

    Args:
        x: 3-D tensor; the selection happens along its second dimension.
        active_idx: 2-D integer tensor of positions into dimension 1 of ``x``.

    Returns:
        Tensor whose first two dimensions follow ``active_idx`` and whose last
        dimension equals the last dimension of ``x``.
    """
    feature_dim = x.size(-1)
    # Repeat every index across the feature axis so gather copies whole rows.
    gather_index = active_idx[..., None].expand(-1, -1, feature_dim)
    return x.gather(dim=1, index=gather_index)

# ==============================================================================
# 最终的、通用的自适应 Transolver 模型
# ==============================================================================

class IrregularAdaptiveTransolver(nn.Module):
    """Adaptive Transolver for irregular meshes based on per-layer point reduction.

    Every point is scored once by a ``LightweightRouter``. Layer ``i`` then
    processes only the ``max(int(N * capacity_ratios[i]), 1)`` highest-scoring
    points (capped at ``N``) and writes its result back at the same positions.

    NOTE(review): the ranking is computed under ``torch.no_grad()`` and the
    router scores are not used anywhere else in ``forward``, so the router
    parameters receive no gradient from this forward pass -- confirm that
    this is intended.
    """

    def __init__(self, space_dim, fun_dim, n_layers, n_hidden, dropout, n_head, mlp_ratio,
                 out_dim, slice_num, ref, unified_pos, capacity_ratios, **kwargs):
        """Build the preprocessing MLP, the router and the stack of blocks.

        Args:
            space_dim: Coordinate width counted into the preprocessing input
                when ``unified_pos`` is falsy.
            fun_dim: Extra feature width counted into the preprocessing input
                (presumably the channel count of ``fx`` -- verify against caller).
            n_layers: Number of Transolver blocks.
            n_hidden: Hidden feature width.
            dropout: Forwarded to every block.
            n_head: Forwarded to every block as ``num_heads``.
            mlp_ratio: Forwarded to every block.
            out_dim: Forwarded to every block as ``out_dim``.
            slice_num: Forwarded to every block.
            ref: Resolution per axis of the reference grid used by ``get_grid``.
            unified_pos: When truthy, distances to the ``ref x ref`` reference
                grid replace the raw coordinates as positional input.
            capacity_ratios: List or tuple with one ratio per layer.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``capacity_ratios`` is not a list/tuple of length
                ``n_layers``.
        """
        super().__init__()
        self.__name__ = "IrregularAdaptiveTransolver"

        # --- a. Base architecture parameters ---
        self.ref = ref
        self.unified_pos = unified_pos
        self.n_hidden = n_hidden
        self.space_dim = space_dim
        self.n_layers = n_layers

        # --- b. Capacity plan ---
        # Tuples are accepted as well as lists; a private copy is stored so
        # later mutation of the caller's sequence cannot change the plan.
        if not isinstance(capacity_ratios, (list, tuple)) or len(capacity_ratios) != n_layers:
            raise ValueError("capacity_ratios 的长度必须等于 n_layers.")
        self.capacity_ratios = list(capacity_ratios)
        # The absolute per-layer point counts depend on N and are therefore
        # computed inside forward().
        print(f"[{self.__name__}] 使用了点削减自适应策略。")
        print(f"[{self.__name__}] 容量规划 (比例): {self.capacity_ratios}")

        # --- c. Input feature processing ---
        preprocess_in_dim = fun_dim + (self.ref * self.ref if self.unified_pos else self.space_dim)
        self.preprocess = MLP(preprocess_in_dim, n_hidden * 2, n_hidden, n_layers=0)

        # There is no separate output MLP: the last block is built with
        # last_layer=True and out_dim and is relied on to emit the output.
        self.placeholder = nn.Parameter((1 / n_hidden) * torch.rand(n_hidden, dtype=torch.float))

        # --- d. Adaptive components ---
        self.router = LightweightRouter(n_hidden)

        self.blocks = nn.ModuleList([
            IrregularTransolverBlock(
                num_heads=n_head, hidden_dim=n_hidden, dropout=dropout,
                mlp_ratio=mlp_ratio, slice_num=slice_num,
                # Only the final block is flagged as the output layer.
                last_layer=(i == n_layers - 1),
                out_dim=out_dim
            ) for i in range(n_layers)
        ])

        self.initialize_weights()

    def initialize_weights(self):
        """Apply ``_init_weights`` to every sub-module."""
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear layers (truncated normal, zero bias) and LayerNorm (weight 1, bias 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False has no parameters.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def get_grid(self, x, batchsize=1):
        """Return the distance from every point in ``x`` to a ``ref x ref`` grid on [0, 1]^2.

        Args:
            x: 3-D tensor of point coordinates. Only two grid axes are built,
                so the last dimension must be broadcastable against size 2.
            batchsize: Unused; kept for backward compatibility. The batch size
                is taken from ``x`` because the grid is broadcast over it.

        Returns:
            Tensor of shape ``(x.shape[0], x.shape[1], ref * ref)`` holding
            Euclidean distances.
        """
        ref = self.ref
        ticks = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float, device=x.device)
        # The first grid axis varies slowest, the second fastest, so flat
        # index i * ref + j corresponds to the point (ticks[i], ticks[j]).
        grid_x = ticks.reshape(ref, 1).expand(ref, ref)
        grid_y = ticks.reshape(1, ref).expand(ref, ref)
        grid_ref = torch.stack((grid_x, grid_y), dim=-1).reshape(1, 1, ref * ref, 2)
        # Broadcasting over the batch avoids materialising one grid copy per sample.
        pos = torch.sqrt(torch.sum((x[:, :, None, :] - grid_ref) ** 2, dim=-1))
        return pos.contiguous()

    def forward(self, x, fx=None, **kwargs):
        """Run the adaptive forward pass.

        Args:
            x: Point coordinates (assumed ``(B, N, space_dim)`` -- TODO confirm).
            fx: Optional per-point features concatenated to the positional
                input along the last dimension.
            **kwargs: Accepted and ignored.

        Returns:
            Tuple ``(final_output, None)``. ``final_output`` has shape
            ``(B, N, D)`` where ``D`` is the last dimension of the final
            block's output; rows of points not selected for the final layer
            are zero.
        """
        # 1. Preprocessing.
        if self.unified_pos:
            pos_feature = self.get_grid(x, x.shape[0])
            input_feature = torch.cat((pos_feature, fx), -1) if fx is not None else pos_feature
        else:
            input_feature = torch.cat((x, fx), -1) if fx is not None else x

        hidden_states = self.preprocess(input_feature)
        hidden_states = hidden_states + self.placeholder[None, None, :]

        B, N, C = hidden_states.shape

        # 2. One-shot global routing: rank all points once, reuse for every layer.
        router_scores = self.router(hidden_states).squeeze(-1)
        with torch.no_grad():
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # Per-layer point budget: at least one point, never more than N.
        capacity_factors = [min(max(int(N * r), 1), N) for r in self.capacity_ratios]

        # 3. All layers except the last keep the hidden width, so their
        #    results can be scattered straight back into hidden_states.
        for depth in range(self.n_layers - 1):
            active_indices = global_ranking_indices[:, :capacity_factors[depth]]
            active_h = pack_tokens(hidden_states, active_indices)
            updated_active_h = self.blocks[depth](active_h)
            hidden_states = hidden_states.scatter(
                1, active_indices.unsqueeze(-1).expand(-1, -1, C), updated_active_h
            )

        # 4. The last layer may change the feature width, so it is handled separately.
        last_depth = self.n_layers - 1
        active_indices_last = global_ranking_indices[:, :capacity_factors[last_depth]]
        active_h_last = pack_tokens(hidden_states, active_indices_last)
        final_active_output = self.blocks[last_depth](active_h_last)

        # 5. Write the result into a zero canvas. The canvas is derived from
        #    the block output so that width, dtype and device always match;
        #    scatter rejects mismatched dtypes (e.g. fp16 output under autocast).
        out_width = final_active_output.shape[-1]
        final_output = final_active_output.new_zeros(B, N, out_width)
        final_output = final_output.scatter(
            1, active_indices_last.unsqueeze(-1).expand(-1, -1, out_width), final_active_output
        )

        # The result is returned as (tensor, None); callers unpack two values.
        return final_output, None