import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np

# 导入原始的 Physics_Attention 作为父类
# 确保这个路径是正确的
try:
    from .Physics_Attention import Physics_Attention_Structured_Mesh_2D
    from .Transolver_Structured_Mesh_2D import MLP as TransolverMLP
except ImportError:
    from model.Physics_Attention import Physics_Attention_Structured_Mesh_2D
    from model.Transolver_Structured_Mesh_2D import MLP as TransolverMLP

# ==============================================================================
# 1. Basic helper modules: token router and token gather (pack) utility
# ==============================================================================

class LightweightRouter(nn.Module):
    """Two-layer MLP scoring each token with a single scalar.

    Args:
        d_model: Width of the incoming token features (last dimension of ``x``).
        d_hidden_ratio: Hidden width as a fraction of ``d_model``; the product
            is truncated with ``int()``.
    """

    def __init__(self, d_model, d_hidden_ratio=0.25):
        super().__init__()
        hidden_width = int(d_model * d_hidden_ratio)
        # Creation order (proj1 then proj2) fixes the RNG draw order for init.
        self.proj1 = nn.Linear(d_model, hidden_width)
        self.proj2 = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Return routing scores with shape ``(*x.shape[:-1], 1)``."""
        hidden = self.proj1(x)
        hidden = F.gelu(hidden)
        return self.proj2(hidden)

def pack_tokens(x, active_idx):
    """Gather a subset of tokens for every batch element.

    Args:
        x: Tensor of shape (B, N, D).
        active_idx: Integer index tensor of shape (B, K); row ``b`` lists the
            token positions of ``x[b]`` to keep, in output order.

    Returns:
        Tensor of shape (B, K, D) with ``out[b, j] = x[b, active_idx[b, j]]``.
    """
    feature_dim = x.size(-1)
    # Broadcast each token index across the feature axis so gather copies whole rows.
    gather_index = active_idx[..., None].expand(-1, -1, feature_dim)
    return x.gather(dim=1, index=gather_index)

# ==============================================================================
# 2. Core compute engine: square-grid physics attention and the Transolver block built on it
# ==============================================================================

class Physics_Attention_Adaptive(Physics_Attention_Structured_Mesh_2D):
    """Physics (slice) attention over a variable number of tokens per call.

    Reuses the projections of ``Physics_Attention_Structured_Mesh_2D`` but
    derives the H x W grid from the sequence length on every forward call
    (H = W = sqrt(N)) instead of fixing it at construction time, so the
    token count N must be a perfect square.
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, **kwargs):
        # H=-1, W=-1 are placeholders: forward() computes its own H, W from N.
        # NOTE(review): extra **kwargs are accepted but neither used nor forwarded.
        super().__init__(dim, heads, dim_head, dropout, slice_num, H=-1, W=-1)

    def forward(self, x):
        """Apply slice-based physics attention to a square set of tokens.

        Args:
            x: Tensor of shape (B, N, C); N must be a perfect square.

        Returns:
            Tensor with N tokens per batch element, produced by ``self.to_out``
            (presumably of width ``dim`` — TODO confirm against the parent);
            ``torch.zeros_like(x)`` when N == 0.

        Raises:
            ValueError: if N is not a perfect square.
        """
        B, N, C = x.shape
        if N == 0: 
            # No tokens to process: return an empty tensor of the same shape.
            return torch.zeros_like(x)
        
        # Interpret the N tokens as a square H x W grid.
        H = W = int(np.sqrt(N))

        if H * W != N:
            raise ValueError(f"自适应模式下，输入序列长度 {N} 必须是一个完美的平方数。")
        
        # (B, N, C) -> (B, C, H, W): channels-first layout for the parent's input projections.
        # NOTE(review): PipeAdaptiveTransolver passes tokens in router-score order, so this
        # H x W arrangement need not reflect spatial neighbourhood; if in_project_fx /
        # in_project_x are spatial (e.g. convolutional) layers, confirm this is intended.
        x_reshaped = x.reshape(B, H, W, C).contiguous().permute(0, 3, 1, 2)
        
        # Project, move channels last, and split heads -> (B, heads, N, dim_head).
        fx_mid = self.in_project_fx(x_reshaped).permute(0, 2, 3, 1).reshape(B, N, self.heads, self.dim_head).permute(0, 2, 1, 3)
        x_mid = self.in_project_x(x_reshaped).permute(0, 2, 3, 1).reshape(B, N, self.heads, self.dim_head).permute(0, 2, 1, 3)
        # Token-to-slice weights of shape (B, heads, N, G); temperature clamped to [0.1, 5].
        slice_weights = self.softmax(self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5))
        
        # Slice tokens = weighted mean of token features per slice; the clamp avoids
        # division by zero for slices that receive (almost) no weight.
        slice_norm = slice_weights.sum(2, keepdim=True).clamp(min=1e-8)
        slice_token = torch.matmul(slice_weights.transpose(-2, -1), fx_mid) / slice_norm.transpose(-2, -1)
        
        # Scaled dot-product attention among the G slice tokens.
        q_slice = self.to_q(slice_token)
        k_slice = self.to_k(slice_token)
        v_slice = self.to_v(slice_token)
        dots = torch.matmul(q_slice, k_slice.transpose(-1, -2)) * self.scale
        attn = self.softmax(dots)
        attn = self.dropout(attn)
        out_slice = torch.matmul(attn, v_slice)
        
        # Distribute slice outputs back to tokens with the same weights, then merge heads.
        out_x = torch.einsum("bhgc,bhng->bhnc", out_slice, slice_weights)
        out_x = rearrange(out_x, 'b h n d -> b n (h d)')
        return self.to_out(out_x)

class Transolver_block_Adaptive(nn.Module):
    """Pre-norm Transolver block (physics attention + MLP, each with a residual).

    Args:
        num_heads: Number of attention heads; each head has width
            ``hidden_dim // num_heads``.
        hidden_dim: Token feature width (input and output).
        dropout: Dropout rate forwarded to the attention layer.
        mlp_ratio: MLP hidden-width multiplier. ``hidden_dim * mlp_ratio`` is
            cast to ``int`` so float ratios (e.g. ``4.0`` from a config file)
            yield a valid layer width.
        slice_num: Number of physics slices in the attention layer.
        act: Activation name passed to the MLP.
    """

    def __init__(self, num_heads, hidden_dim, dropout, mlp_ratio, slice_num, act='gelu'):
        super().__init__()
        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.Attn = Physics_Attention_Adaptive(
            dim=hidden_dim, heads=num_heads, dim_head=hidden_dim // num_heads,
            dropout=dropout, slice_num=slice_num
        )
        self.ln_2 = nn.LayerNorm(hidden_dim)
        # A float width (from a float mlp_ratio) would be rejected by the Linear layers.
        mlp_hidden = int(hidden_dim * mlp_ratio)
        self.mlp = TransolverMLP(hidden_dim, mlp_hidden, hidden_dim, n_layers=0, res=False, act=act)

    def forward(self, fx):
        """Apply attention then MLP, each as ``x + f(LayerNorm(x))``.

        Args:
            fx: Token features whose last dimension is ``hidden_dim``; the token
                count must be a perfect square (required by the attention layer).

        Returns:
            Tensor of the same shape as ``fx``.
        """
        fx = self.Attn(self.ln_1(fx)) + fx
        fx = self.mlp(self.ln_2(fx)) + fx
        return fx

# ==============================================================================
# 3. Adaptive Transolver for steady-state tasks (main model; the core changes live here)
# ==============================================================================

class PipeAdaptiveTransolver(nn.Module):
    """
    Adaptive Transolver for steady-state PDE tasks (e.g. the pipe dataset).

    By default the only input is the spatial coordinate tensor ``x`` (no
    time-dependent features ``fx``). A router scores every token once per
    forward pass; block ``i`` is then applied only to the top-``k_i`` tokens
    of that single global ranking, and its change is added back residually.
    Each ``k_i`` comes from ``capacity_ratios[i] * H * W``, rounded down to a
    perfect square (at least 1), because the attention layer lays its tokens
    out on a square grid.

    Args:
        space_dim: Width of the coordinate features ``x`` (non-unified mode).
        n_layers: Number of adaptive blocks; must equal ``len(capacity_ratios)``.
        n_hidden: Hidden feature width.
        dropout: Dropout forwarded to every block.
        n_head: Attention heads per block.
        mlp_ratio: MLP width multiplier per block.
        out_dim: Output feature width.
        slice_num: Physics slices per attention block.
        H, W: Structured-mesh resolution; ``H * W`` is the reference token count.
        ref: Reference-grid resolution used when ``unified_pos`` is True.
        capacity_ratios: List or tuple of per-layer fractions of ``H * W``.
        unified_pos: If True, ``x`` is replaced by distances to a ``ref x ref``
            reference grid.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``capacity_ratios`` is not a list/tuple of length ``n_layers``.
    """

    def __init__(self, space_dim, n_layers, n_hidden, dropout, n_head, mlp_ratio,
                 out_dim, slice_num, H, W, ref, capacity_ratios, unified_pos=False, **kwargs):
        super().__init__()
        self.__name__ = "PipeAdaptiveTransolver"

        self.H = H
        self.W = W
        self.recursion_depth = n_layers
        self.unified_pos = unified_pos
        self.ref = ref

        # --- a. Adaptive capacity planning ---
        num_coarse_nodes = H * W
        # Tuples are accepted as well as lists.
        if not isinstance(capacity_ratios, (list, tuple)) or len(capacity_ratios) != n_layers:
            raise ValueError(f"capacity_ratios 的长度必须等于 n_layers。")
        print(f"[{self.__name__}] 正在根据用户输入的比例，预计算完美的平方数容量...")
        # Largest perfect square <= H * W * ratio, clamped to at least 1 token.
        self.capacity_factors = [
            max(int(np.sqrt(num_coarse_nodes * ratio)) ** 2, 1)
            for ratio in capacity_ratios
        ]

        # --- b. I/O components ---
        input_dim = space_dim
        print(f"[{self.__name__}] 模型入口 (preprocess) 输入维度固定为: {input_dim}")
        if self.unified_pos:
            # Non-persistent buffer: follows model.to(device) but is not written to
            # the state_dict (same as the previous plain attribute), so existing
            # checkpoints still load.
            self.register_buffer("pos", self.get_grid(), persistent=False)
            # NOTE(review): the "+ 1" expects a 1-channel fx concatenated in forward();
            # calling forward with fx=None in this mode gives a width mismatch.
            self.preprocess = TransolverMLP(self.ref * self.ref + 1, n_hidden * 2, n_hidden, n_layers=0, res=False)
        else:
            # NOTE(review): expects fx=None; a non-None fx widens the input beyond space_dim.
            self.preprocess = TransolverMLP(space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False)
        self.out_mlp = TransolverMLP(n_hidden, n_hidden, out_dim)

        # --- c. Adaptive / recursive components ---
        self.router = LightweightRouter(n_hidden)
        self.transolver_blocks = nn.ModuleList([
            Transolver_block_Adaptive(
                num_heads=n_head, hidden_dim=n_hidden, dropout=dropout,
                mlp_ratio=mlp_ratio, slice_num=slice_num
            ) for _ in range(self.recursion_depth)
        ])

    def get_grid(self, batchsize=1):
        """Return Euclidean distances from every mesh point to every reference point.

        The H x W mesh and the ref x ref reference grid both span [0, 1] on each axis.

        Args:
            batchsize: Size of the leading batch dimension.

        Returns:
            float32 CPU tensor of shape (batchsize, H, W, ref * ref). Callers move
            it to the right device (the constructor registers it as a buffer).
        """
        size_x, size_y = self.H, self.W
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        grid = torch.cat((gridx, gridy), dim=-1)  # (B, H, W, 2)

        gridx = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
        gridx = gridx.reshape(1, self.ref, 1, 1).repeat([batchsize, 1, self.ref, 1])
        gridy = torch.tensor(np.linspace(0, 1, self.ref), dtype=torch.float)
        gridy = gridy.reshape(1, 1, self.ref, 1).repeat([batchsize, self.ref, 1, 1])
        grid_ref = torch.cat((gridx, gridy), dim=-1)  # (B, ref, ref, 2)

        diff = grid[:, :, :, None, None, :] - grid_ref[:, None, None, :, :, :]
        pos = torch.sqrt(torch.sum(diff ** 2, dim=-1))
        return pos.reshape(batchsize, size_x, size_y, self.ref * self.ref).contiguous()

    def forward(self, x, fx, **kwargs):
        """Run the adaptive model.

        Args:
            x: Per-point coordinate features, presumably (B, N, space_dim). When
                ``unified_pos`` is True only ``x.shape[0]`` and ``x.device`` are used.
            fx: Optional extra per-point features concatenated to ``x`` on the last
                axis; None in the coordinate-only steady-state setting.
            **kwargs: Ignored.

        Returns:
            ``(output, None)``: ``output`` is ``out_mlp`` applied to the final hidden
            states. The ``None`` keeps the (tensor, None) format of the original
            Transolver for compatibility with existing evaluation code.
        """
        if self.unified_pos:
            # .to() is a no-op when the buffer already lives on x's device.
            x = self.pos.to(x.device).repeat(x.shape[0], 1, 1, 1).reshape(
                x.shape[0], self.H * self.W, self.ref * self.ref)

        # --- 1. Preprocess ---
        if fx is not None:
            hidden_states = self.preprocess(torch.cat((x, fx), -1))
        else:
            hidden_states = self.preprocess(x)
        B, N, C = hidden_states.shape

        # --- 2. One global routing pass ---
        router_scores = self.router(hidden_states).squeeze(-1)
        # NOTE(review): the scores are only used for this no-grad ranking, so the
        # router parameters never receive a gradient (DDP may flag them as unused).
        with torch.no_grad():
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # --- 3. Adaptive recursion ---
        for depth in range(self.recursion_depth):
            k = self.capacity_factors[depth]
            if k == 0:
                continue
            # NOTE(review): if N < k and N is not a perfect square, the attention
            # block raises ValueError.
            k = min(k, N)

            active_indices = global_ranking_indices[:, :k]

            # Residual shortcut: inactive tokens pass through unchanged.
            shortcut = hidden_states
            active_h_before_block = pack_tokens(shortcut, active_indices)
            block_output_active = self.transolver_blocks[depth](active_h_before_block)

            # Only the change produced by the block is written back.
            update_delta = block_output_active - active_h_before_block

            # Out-of-place scatter_add gives shortcut + delta at the active positions
            # and shortcut elsewhere. This matches "shortcut + zero canvas with delta
            # scattered in" without allocating the extra B x N x C canvas. Indices come
            # from a sort, so they are unique and no accumulation collisions occur.
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, C)
            hidden_states = shortcut.scatter_add(1, idx_exp, update_delta)

        # --- 4. Output ---
        output = self.out_mlp(hidden_states)
        return output, None