import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
import dgl
import numpy as np

# Import the original Physics_Attention to use as the parent class
from .Physics_Attention import Physics_Attention_Structured_Mesh_2D


# These modules are assumed to live in the same 'model' directory as this file
from .Transolver_Structured_Mesh_2D import MLP as TransolverMLP

# ==============================================================================
# 2. Basic helper modules (self-contained)
# ==============================================================================

class GatedMLP(nn.Module):
    """Gated feed-forward layer computing ``down(act(up) * gate)``.

    A single linear projection produces both the "up" and the "gate" halves,
    which are split along the last dimension.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of each of the two halves (up / gate).
        out_features: Size of the last dimension of the output.
        act_layer: Zero-argument factory for the activation applied to the
            "up" half.
    """

    def __init__(self, in_features, hidden_features, out_features, act_layer=nn.GELU):
        super().__init__()
        fused_width = hidden_features * 2
        self.proj_up_gate = nn.Linear(in_features, fused_width)
        self.act = act_layer()
        self.proj_down = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply the gated MLP to ``x`` along its last dimension."""
        fused = self.proj_up_gate(x)
        up_half, gate_half = torch.chunk(fused, 2, dim=-1)
        gated = self.act(up_half) * gate_half
        return self.proj_down(gated)

# xformers is an optional accelerator: fall back to a pure-PyTorch attention
# when it is not installed.
try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    xops = None
    XFORMERS_AVAILABLE = False


class XFAttention(nn.Module):
    """Multi-head (self- or cross-) attention with an optional xformers kernel.

    Args:
        n_embd: Embedding width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        attn_pdrop: Dropout probability applied to the attention weights
            (only while the module is in training mode).
    """

    def __init__(self, n_embd, n_head, attn_pdrop=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head
        self.attn_pdrop = attn_pdrop

    def forward(self, x, y=None, attn_bias=None):
        """Attend from ``x`` (queries) to ``y`` (keys/values).

        Args:
            x: Query tensor of shape ``[B, N, C]``.
            y: Key/value source of shape ``[B, M, C]``; defaults to ``x``
                (self-attention). ``M`` may differ from ``N``.
            attn_bias: Optional bias. In the PyTorch fallback it is added to
                the ``[B, heads, N, M]`` score tensor; in the xformers path
                it is forwarded unchanged to ``memory_efficient_attention``.

        Returns:
            Tensor of shape ``[B, N, C]``.
        """
        y = x if y is None else y
        B, N, C = x.shape
        # Keys/values take their length from y, not x, so cross-attention
        # with a different number of source tokens works.
        M = y.shape[1]
        head_dim = C // self.n_head

        q = self.query(x).view(B, N, self.n_head, head_dim)
        k = self.key(y).view(B, M, self.n_head, head_dim)
        v = self.value(y).view(B, M, self.n_head, head_dim)

        # Attention dropout must be disabled in eval mode.
        drop_p = self.attn_pdrop if self.training else 0.0

        if XFORMERS_AVAILABLE and q.is_cuda:
            out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=drop_p)
        else:
            # [B, L, heads, d] -> [B, heads, L, d]
            q_h, k_h, v_h = (t.transpose(1, 2) for t in (q, k, v))
            scores = torch.matmul(q_h, k_h.transpose(-2, -1)) / (head_dim ** 0.5)
            if attn_bias is not None:
                scores = scores + attn_bias
            attn = F.softmax(scores, dim=-1)
            attn = F.dropout(attn, p=drop_p, training=self.training)
            # Back to [B, N, heads, d]
            out = torch.matmul(attn, v_h).transpose(1, 2)

        # Merge heads: [B, N, heads, d] -> [B, N, C]
        return self.proj(out.reshape(B, N, C))

class LightweightRouter(nn.Module):
    """Two-layer bottleneck MLP that emits one scalar score per token.

    Args:
        d_model: Size of the last dimension of the input.
        d_hidden_ratio: Bottleneck width as a fraction of ``d_model``
            (truncated to an int).
    """

    def __init__(self, d_model, d_hidden_ratio=0.25):
        super().__init__()
        bottleneck = int(d_model * d_hidden_ratio)
        self.proj1 = nn.Linear(d_model, bottleneck)
        self.proj2 = nn.Linear(bottleneck, 1)

    def forward(self, x):
        """Return scores whose last dimension has size 1."""
        hidden = self.proj1(x)
        hidden = F.gelu(hidden)
        return self.proj2(hidden)
class CNNRouter(nn.Module):
    """Convolutional router that scores tokens laid out on a regular 2-D grid.

    The token sequence is reshaped to an ``H x W`` image, passed through a
    small stack of convolutions, and flattened back to one score per token.

    Args:
        in_channels: Feature width of each token.
        H: Grid height.
        W: Grid width.
        hidden_channels_ratio: Hidden channel count as a fraction of
            ``in_channels`` (never fewer than 16 channels).
    """

    def __init__(self, in_channels, H, W, hidden_channels_ratio=0.25):
        super().__init__()
        self.H = H
        self.W = W

        # Guarantee at least 16 hidden channels, however narrow the input is.
        hidden_channels = max(int(in_channels * hidden_channels_ratio), 16)

        # Two 3x3 convolutions keep the spatial size (padding=1); the final
        # 1x1 convolution collapses the channels into a single score map.
        # A down-sampling layer (e.g. MaxPool2d) could be inserted after the
        # first activation to enlarge the receptive field.
        layers = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, 1, kernel_size=1),
        ]
        self.cnn_scorer = nn.Sequential(*layers)

    def forward(self, x):
        """Score every token.

        Args:
            x: Tensor of shape ``[B, N, C]`` with ``N == H * W``.

        Returns:
            Tensor of shape ``[B, N, 1]``.

        Raises:
            ValueError: If ``N`` does not match the configured grid size.
        """
        batch, N, channels = x.shape

        if self.H * self.W != N:
            raise ValueError(
                f"输入Token数 N={N} 与 CNNRouter 初始化时的网格尺寸 H={self.H}, W={self.W} 不匹配。"
            )

        # [B, N, C] -> [B, C, H, W]
        feature_map = x.transpose(1, 2).reshape(batch, channels, self.H, self.W)
        # [B, 1, H, W] score map
        score_map = self.cnn_scorer(feature_map)
        # Flatten the map back into one score per token: [B, N, 1]
        return score_map.reshape(batch, N, 1)
def pack_tokens(x, active_idx):
    """Gather the tokens of ``x`` selected by ``active_idx`` along dim 1.

    Args:
        x: Tensor indexed as ``[batch, token, feature]``.
        active_idx: 2-D integer tensor ``[batch, k]`` of token positions.

    Returns:
        Tensor ``[batch, k, feature]`` holding the selected tokens.
    """
    feature_dim = x.size(-1)
    # Broadcast each token index across the feature dimension for gather().
    gather_index = active_idx[..., None].expand(-1, -1, feature_dim)
    return x.gather(1, gather_index)

# ==============================================================================
# 3. Core compute engine (runs on the coarse grid)
# ==============================================================================
class Physics_Attention_Adaptive(Physics_Attention_Structured_Mesh_2D):
    """
    A Physics_Attention variant adapted to variable sequence lengths.

    The grid size is not fixed at construction time; ``forward`` derives a
    square ``H x W`` grid from the number of input tokens instead.
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., slice_num=64, **kwargs):
        # Call the parent constructor with placeholder H/W (-1); forward()
        # computes the grid size itself. Extra **kwargs are accepted and ignored.
        super().__init__(dim, heads, dim_head, dropout, slice_num, H=-1, W=-1)

    def forward(self, x):
        # x is unpacked as [B, N, C]; the result comes from self.to_out.
        # Raises ValueError when N is not a perfect square.
        B, N, C = x.shape
        if N == 0: return torch.zeros_like(x) # edge case: empty input

        # The N tokens are interpreted as a square H x W grid.
        H = W = int(np.sqrt(N))
        if H * W != N:
            raise ValueError(f"自适应模式下，输入序列长度 {N} 必须是一个完美的平方数。")

        # --- (per the original author, the slicing / attention / de-slicing logic
        #      below is identical to the original parent implementation) ---
        # [B, N, C] -> [B, C, H, W] so the parent's input projections can be applied.
        # NOTE(review): in_project_fx / in_project_x / in_project_slice, softmax,
        # temperature, to_q/to_k/to_v, scale, dropout and to_out are all inherited
        # from the parent class and are not visible here — confirm their shapes there.
        x_reshaped = x.reshape(B, H, W, C).contiguous().permute(0, 3, 1, 2)

        # Project, then rearrange to [B, heads, N, dim_head].
        fx_mid = self.in_project_fx(x_reshaped).permute(0, 2, 3, 1).reshape(B, N, self.heads, self.dim_head).permute(0, 2, 1, 3)
        x_mid = self.in_project_x(x_reshaped).permute(0, 2, 3, 1).reshape(B, N, self.heads, self.dim_head).permute(0, 2, 1, 3)
        # Slice-assignment weights; the temperature is clamped to [0.1, 5].
        slice_weights = self.softmax(self.in_project_slice(x_mid) / torch.clamp(self.temperature, min=0.1, max=5))
        # Per-slice total weight (summed over dim 2), floored at 1e-8 to avoid division by zero.
        slice_norm = slice_weights.sum(2, keepdim=True).clamp(min=1e-8)
        # Weighted average of token features per slice.
        slice_token = torch.matmul(slice_weights.transpose(-2, -1), fx_mid) / slice_norm.transpose(-2, -1)

        # Scaled dot-product attention among the slice tokens.
        q_slice = self.to_q(slice_token)
        k_slice = self.to_k(slice_token)
        v_slice = self.to_v(slice_token)
        dots = torch.matmul(q_slice, k_slice.transpose(-1, -2)) * self.scale
        attn = self.softmax(dots)
        attn = self.dropout(attn)
        out_slice = torch.matmul(attn, v_slice)

        # De-slice: redistribute slice outputs to tokens using the same weights,
        # then merge the heads: [B, heads, N, d] -> [B, N, heads * d].
        out_x = torch.einsum("bhgc,bhng->bhnc", out_slice, slice_weights)
        out_x = rearrange(out_x, 'b h n d -> b n (h d)')
        return self.to_out(out_x)

class Transolver_block_Adaptive(nn.Module):
    """Transolver block built on the variable-length physics attention.

    Layout: pre-LayerNorm attention with a residual connection, followed by a
    pre-LayerNorm MLP with a residual connection.

    Args:
        num_heads: Number of attention heads.
        hidden_dim: Token feature width.
        dropout: Dropout probability forwarded to the attention module.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_dim``.
        slice_num: Number of slices forwarded to the attention module.
        act: Activation name forwarded to the MLP.
    """

    def __init__(self, num_heads, hidden_dim, dropout, mlp_ratio, slice_num, act='gelu'):
        super().__init__()
        per_head_dim = hidden_dim // num_heads
        mlp_width = hidden_dim * mlp_ratio

        self.ln_1 = nn.LayerNorm(hidden_dim)
        self.Attn = Physics_Attention_Adaptive(
            dim=hidden_dim,
            heads=num_heads,
            dim_head=per_head_dim,
            dropout=dropout,
            slice_num=slice_num,
        )
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.mlp = TransolverMLP(hidden_dim, mlp_width, hidden_dim, n_layers=0, res=False, act=act)

    def forward(self, fx):
        """Apply attention and MLP sub-layers, each with a residual connection."""
        # Same residual wiring as the original Transolver block.
        attended = self.Attn(self.ln_1(fx)) + fx
        return self.mlp(self.ln_2(attended)) + attended
class CoreBlock(nn.Module):
    """Pre-norm attention + gated-MLP residual block used by ``AdaptiveCore``.

    Defined at module level (rather than inside a method) so that instances,
    and therefore any model containing them, can be pickled.

    Args:
        cfg: Mapping with the keys ``n_embd``, ``n_head``, ``n_inner``,
            ``resid_pdrop``, ``attn_pdrop`` and ``act``.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg['n_embd']
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = XFAttention(width, cfg['n_head'], cfg['attn_pdrop'])
        self.mlp = GatedMLP(width, cfg['n_inner'], width, act_layer=self.resolve_act_layer(cfg['act']))
        self.resid_drop = nn.Dropout(cfg['resid_pdrop'])

    @staticmethod
    def resolve_act_layer(act):
        """Map an activation spec to a ``torch.nn`` layer class.

        Accepts an exact ``torch.nn`` attribute name (``'GELU'``), a
        case-insensitive name (``'gelu'``, the spelling used by the other
        models in this file), or an already-resolved layer class / factory.

        Raises:
            AttributeError: If no matching ``torch.nn`` layer exists.
        """
        if not isinstance(act, str):
            return act
        if hasattr(nn, act):
            return getattr(nn, act)
        wanted = act.lower()
        for name in dir(nn):
            if name.lower() != wanted:
                continue
            candidate = getattr(nn, name)
            if isinstance(candidate, type) and issubclass(candidate, nn.Module):
                return candidate
        raise AttributeError(f"module 'torch.nn' has no activation layer matching {act!r}")

    def forward(self, x):
        """Apply attention and MLP sub-layers, each with residual dropout."""
        x = x + self.resid_drop(self.attn(self.ln1(x)))
        x = x + self.resid_drop(self.mlp(self.ln2(x)))
        return x


class AdaptiveCore(nn.Module):
    """Stack of blocks in which each layer processes only the top-ranked tokens.

    Tokens are ranked once by a router; layer ``d`` then updates only the
    ``capacity_factors[d]`` highest-ranked tokens and writes them back in place.

    Args:
        n_hidden: Token feature width.
        n_head: Number of attention heads per block.
        n_layers: Number of blocks (recursion depth).
        final_keep_ratio: Fraction of tokens kept by the last layer; the
            per-layer fractions are spaced linearly from 1.0 down to it.
        n_inner: Hidden width of each block's gated MLP.
        resid_pdrop: Residual dropout probability.
        attn_pdrop: Attention dropout probability.
        act: Activation spec, see ``CoreBlock.resolve_act_layer``.
        num_coarse_nodes: Number of tokens the capacity plan is based on.
    """

    def __init__(self, n_hidden, n_head, n_layers, final_keep_ratio, n_inner, resid_pdrop, attn_pdrop, act, num_coarse_nodes):
        super().__init__()
        self.recursion_depth = n_layers

        ratios = np.linspace(1.0, final_keep_ratio, self.recursion_depth)
        self.capacity_factors = [int(num_coarse_nodes * r) for r in ratios]
        # The last layer always keeps at least one token.
        self.capacity_factors[-1] = max(self.capacity_factors[-1], 1)
        print(f"[AdaptiveCore] 容量规划 (活跃Token数/层): {self.capacity_factors}")

        self.router = LightweightRouter(n_hidden)

        config = {'n_embd': n_hidden, 'n_head': n_head, 'n_inner': n_inner, 'resid_pdrop': resid_pdrop, 'attn_pdrop': attn_pdrop, 'act': act}
        self.recursion_blocks = nn.ModuleList([self._create_block(config) for _ in range(self.recursion_depth)])

    def _create_block(self, config):
        """Build one ``CoreBlock`` from the shared configuration mapping."""
        return CoreBlock(config)

    def forward(self, hidden_states):
        """Run the adaptive stack.

        Args:
            hidden_states: Tensor of shape ``[B, N, C]``.

        Returns:
            Tuple ``(hidden_states, None)``; the second element is a
            placeholder and is always ``None``.
        """
        C = hidden_states.shape[-1]

        router_scores = self.router(hidden_states).squeeze(-1)
        # NOTE(review): the scores are only used to derive a ranking under
        # no_grad, so the router receives no gradient from this forward pass.
        with torch.no_grad():
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        for depth, block in enumerate(self.recursion_blocks):
            k = self.capacity_factors[depth]
            if k == 0:
                continue
            active_indices = global_ranking_indices[:, :k]

            active_h = pack_tokens(hidden_states, active_indices)
            block_output_active = block(active_h)

            # Write the updated tokens back to their original positions.
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, C)
            hidden_states = hidden_states.scatter(1, idx_exp, block_output_active)

        return hidden_states, None

# ==============================================================================
# 4. The final, decoupled adaptive Transolver
# ==============================================================================

class StructuredAdaptiveTransolver(nn.Module):
    """Transolver for structured 2-D meshes with per-layer token capacities.

    Tokens are ranked once by a router. Layer ``d`` then processes only the
    ``capacity_factors[d]`` highest-ranked tokens, where each capacity is the
    largest perfect square not exceeding ``H * W * capacity_ratios[d]``
    (the adaptive attention requires a perfect-square token count).

    Args:
        space_dim: Width of the coordinate input ``x`` (unused when
            ``unified_pos`` is true).
        n_layers: Number of Transolver blocks.
        n_hidden: Token feature width.
        dropout: Dropout probability forwarded to the blocks.
        n_head: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``n_hidden``.
        fun_dim: Width of the function input ``fx``.
        out_dim: Width of the output.
        slice_num: Number of slices forwarded to the attention modules.
        H: Grid height.
        W: Grid width.
        capacity_ratios: List or tuple with one keep-ratio per layer.
        unified_pos: If true, replace ``x`` by distances to a ``ref x ref``
            reference grid.
        act: Activation name forwarded to the MLPs and the blocks.
        ref: Side length of the reference grid used by ``unified_pos``.

    Raises:
        ValueError: If ``capacity_ratios`` is not a list/tuple of length
            ``n_layers``.
    """

    def __init__(self, space_dim, n_layers, n_hidden, dropout, n_head, mlp_ratio,
                 fun_dim, out_dim, slice_num, H, W,
                 capacity_ratios,
                 unified_pos=False,
                 act='gelu',
                 ref=8,
                 **kwargs):
        super().__init__()
        self.__name__ = "StructuredAdaptiveTransolver"
        self.H = H
        self.W = W
        self.ref = ref
        self.recursion_depth = n_layers
        self.unified_pos = unified_pos
        if self.unified_pos:
            # Non-persistent buffer: follows .to()/.cuda() with the module but
            # does not add a key to the state_dict.
            self.register_buffer('pos', self.get_grid(), persistent=False)
            self.preprocess = TransolverMLP(fun_dim + self.ref * self.ref, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)
        else:
            self.preprocess = TransolverMLP(fun_dim + space_dim, n_hidden * 2, n_hidden, n_layers=0, res=False, act=act)

        # a. User-defined capacity plan, snapped to perfect squares.
        num_coarse_nodes = H * W
        if not isinstance(capacity_ratios, (list, tuple)) or len(capacity_ratios) != n_layers:
            raise ValueError("capacity_ratios 的长度必须等于 n_layers。")
        print(f"[{self.__name__}] 正在根据用户输入的比例，预计算完美的平方数容量...")
        self.capacity_factors = self._square_capacities(num_coarse_nodes, capacity_ratios)

        # b. Output head.
        self.out_mlp = TransolverMLP(n_hidden, n_hidden, out_dim)

        # c. Router and adaptive blocks. A CNNRouter(in_channels=n_hidden, H=H, W=W)
        #    can be substituted here; the log line names the router actually built.
        self.router = LightweightRouter(n_hidden)
        print(f"[{self.__name__}] 使用了 LightweightRouter。")
        self.transolver_blocks = nn.ModuleList([
            Transolver_block_Adaptive(
                num_heads=n_head, hidden_dim=n_hidden, dropout=dropout,
                mlp_ratio=mlp_ratio, slice_num=slice_num, act=act
            ) for _ in range(self.recursion_depth)
        ])

    @staticmethod
    def _square_capacities(num_nodes, ratios):
        """Return, per ratio, the largest perfect square <= ``num_nodes * ratio`` (min 1)."""
        capacities = []
        for ratio in ratios:
            side = int(np.sqrt(num_nodes * ratio))
            capacities.append(max(side * side, 1))
        return capacities

    def get_grid(self, batchsize=1, device=None):
        """Build the unified position encoding.

        For every node of the ``H x W`` unit grid, compute the Euclidean
        distance to every node of a ``ref x ref`` unit reference grid.

        Args:
            batchsize: Leading batch dimension of the result.
            device: Optional device for the result; when omitted the tensor
                is created on the CPU (no GPU is required).

        Returns:
            Float tensor of shape ``[batchsize, H, W, ref * ref]``.
        """
        def unit_grid(nx, ny):
            # [batchsize, nx, ny, 2] coordinates in [0, 1] x [0, 1]
            gx = torch.tensor(np.linspace(0, 1, nx), dtype=torch.float)
            gx = gx.reshape(1, nx, 1, 1).repeat([batchsize, 1, ny, 1])
            gy = torch.tensor(np.linspace(0, 1, ny), dtype=torch.float)
            gy = gy.reshape(1, 1, ny, 1).repeat([batchsize, nx, 1, 1])
            return torch.cat((gx, gy), dim=-1)

        size_x, size_y = self.H, self.W
        grid = unit_grid(size_x, size_y)
        grid_ref = unit_grid(self.ref, self.ref)

        diff = grid[:, :, :, None, None, :] - grid_ref[:, None, None, :, :, :]
        pos = torch.sqrt(torch.sum(diff ** 2, dim=-1))
        pos = pos.reshape(batchsize, size_x, size_y, self.ref * self.ref).contiguous()
        if device is not None:
            pos = pos.to(device)
        return pos

    def forward(self, x, fx, **kwargs):
        """Run the model.

        Args:
            x: Coordinates, concatenated with ``fx`` along the last dim; only
                its batch size is used when ``unified_pos`` is true.
            fx: Function values, concatenated with ``x`` along the last dim.

        Returns:
            Tuple ``(output, None)``; the second element is always ``None``.
        """
        if self.unified_pos:
            batch = x.shape[0]
            # Keep the encoding on the same device as the data it is joined with.
            pos = self.pos.to(fx.device)
            x = pos.repeat(batch, 1, 1, 1).reshape(batch, self.H * self.W, self.ref * self.ref)
        hidden_states = self.preprocess(torch.cat((x, fx), -1))
        _, N, C = hidden_states.shape

        # One-shot global routing.
        # NOTE(review): the scores are only used to derive a ranking under
        # no_grad, so the router receives no gradient from this forward pass.
        router_scores = self.router(hidden_states).squeeze(-1)
        with torch.no_grad():
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # Adaptive loop: each block updates only its top-k tokens.
        for depth, block in enumerate(self.transolver_blocks):
            # Never ask for more tokens than exist.
            k = min(self.capacity_factors[depth], N)
            if k == 0:
                continue

            active_indices = global_ranking_indices[:, :k]
            active_h = pack_tokens(hidden_states, active_indices)
            block_output_active = block(active_h)

            # Write the updated tokens back to their original positions.
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, C)
            hidden_states = hidden_states.scatter(1, idx_exp, block_output_active)

        output = self.out_mlp(hidden_states)

        return output, None