import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
import numpy as np
from .attention_module import PreNorm, PostNorm, LinearAttention, CrossLinearAttention,\
    FeedForward, GeGELU, ProjDotProduct
from .cnn_module import UpBlock, FourierConv2d, PeriodicConv2d
from torch.nn.init import xavier_uniform_, orthogonal_


try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False
    print("警告: xformers 库未找到，将回退到标准PyTorch实现。")

# --- 1a. Basic building blocks required for adaptive computation ---


class XFAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional xformers backend.

    Queries are projected from ``x``; keys and values are projected from ``y``
    (or from ``x`` when ``y`` is omitted, i.e. self-attention). When the
    module-level flag ``XFORMERS_AVAILABLE`` is true the attention itself is
    delegated to ``xops.memory_efficient_attention``; otherwise a plain PyTorch
    softmax attention is used.

    Args:
        n_embd: Channel width of the inputs and the output. Must be divisible
            by ``n_head``.
        n_head: Number of attention heads.
        attn_pdrop: Dropout probability applied to the attention weights
            during training only.
    """

    def __init__(self, n_embd, n_head, attn_pdrop=0.0):
        super().__init__()
        assert n_embd % n_head == 0
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head
        self.attn_pdrop = attn_pdrop

    def forward(self, x, y=None, attn_bias=None):
        """Attend from ``x`` to ``y``.

        Args:
            x: Query source of shape ``[B, N, C]``.
            y: Key/value source of shape ``[B, M, C]``; defaults to ``x``.
                ``M`` may differ from ``N``.
            attn_bias: Optional bias. On the PyTorch path it is added to the
                ``[B, n_head, N, M]`` score tensor and must broadcast against
                it; on the xformers path it is forwarded unchanged.

        Returns:
            Tensor of shape ``[B, N, C]``.
        """
        y = x if y is None else y
        B, N, C = x.shape
        # Keys/values follow the length of y, not of x, so that
        # cross-attention between sequences of different length works.
        M = y.shape[1]
        q = self.query(x).view(B, N, self.n_head, -1)
        k = self.key(y).view(B, M, self.n_head, -1)
        v = self.value(y).view(B, M, self.n_head, -1)
        # Attention dropout must be disabled in eval mode on both back-ends.
        drop_p = self.attn_pdrop if self.training else 0.0
        if XFORMERS_AVAILABLE:
            out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=drop_p)
        else:
            # [B, L, H, D] -> [B, H, L, D]
            q, k, v = (t.transpose(1, 2) for t in (q, k, v))
            scores = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
            if attn_bias is not None:
                scores = scores + attn_bias
            attn = F.softmax(scores, dim=-1)
            attn = F.dropout(attn, p=drop_p, training=self.training)
            # back to [B, N, H, D] so both back-ends share the merge below
            out = torch.matmul(attn, v).transpose(1, 2)
        # merge heads: [B, N, H, D] -> [B, N, H*D]
        return self.proj(out.reshape(B, N, -1))

class LightweightRouter(nn.Module):
    """Two-layer MLP that maps every token to a single scalar score.

    Args:
        d_model: Channel width of the incoming tokens.
        d_hidden_ratio: Hidden width as a fraction of ``d_model``. The
            resulting width is truncated to an int and clamped to at least 1,
            so small models do not end up with an empty hidden layer.
    """

    def __init__(self, d_model, d_hidden_ratio=0.25):
        super().__init__()
        # Compute the hidden width once; never let it collapse to zero.
        d_hidden = max(1, int(d_model * d_hidden_ratio))
        self.proj1 = nn.Linear(d_model, d_hidden)
        self.proj2 = nn.Linear(d_hidden, 1)

    def forward(self, x):
        """Return one score per token: ``[..., d_model] -> [..., 1]``."""
        return self.proj2(F.gelu(self.proj1(x)))

def pack_tokens(x, active_idx):
    """Gather the tokens selected by ``active_idx`` along dimension 1.

    Args:
        x: Tensor indexed as ``[batch, token, channel]``.
        active_idx: Integer tensor ``[batch, k]`` of token indices to keep.

    Returns:
        Tensor ``[batch, k, channel]`` holding the selected tokens.
    """
    channels = x.shape[-1]
    # Broadcast every token index over the channel axis so gather copies
    # whole feature vectors.
    gather_index = active_idx[..., None].expand(-1, -1, channels)
    return x.gather(1, gather_index)
class AttentionPropagator2D(nn.Module):
    """Stack of residual attention blocks conditioned on 2-D coordinates.

    Every block applies a ``LinearAttention`` layer followed by a
    ``FeedForward`` layer, each with a residual connection. Before the
    feed-forward step the features are concatenated with ``pos`` and projected
    from ``dim + 2`` back to ``dim`` channels, so ``pos`` must carry two
    channels in its last dimension.

    Args:
        dim: Feature width.
        depth: Number of blocks.
        heads: Number of attention heads.
        dim_head: Width of each head.
        attn_type: One of ``'none'``, ``'galerkin'``, ``'fourier'``.
        mlp_dim: Hidden width handed to ``FeedForward``.
        scale: Forwarded to ``LinearAttention`` as ``scale``.
        use_ln: If true, a ``LayerNorm`` precedes both sub-layers.
        dropout: Dropout rate for the attention and feed-forward layers.
    """

    def __init__(self,
                 dim,
                 depth,
                 heads,
                 dim_head,
                 attn_type,         # ['none', 'galerkin', 'fourier']
                 mlp_dim,
                 scale,
                 use_ln=True,
                 dropout=0.):
        super().__init__()
        assert attn_type in ['none', 'galerkin', 'fourier']
        self.layers = nn.ModuleList([])

        self.attn_type = attn_type
        self.use_ln = use_ln
        for _ in range(depth):
            # The attention layer is built first in both variants.
            attention = LinearAttention(dim, attn_type,
                                        heads=heads, dim_head=dim_head, dropout=dropout,
                                        relative_emb=True, scale=scale,
                                        relative_emb_dim=2,
                                        min_freq=1/64,
                                        init_method='orthogonal'
                                        )
            if use_ln:
                members = [
                    nn.LayerNorm(dim),
                    attention,
                    nn.LayerNorm(dim),
                    nn.Linear(dim+2, dim),
                    FeedForward(dim, mlp_dim, dropout=dropout),
                ]
            else:
                members = [
                    attention,
                    nn.Linear(dim + 2, dim),
                    FeedForward(dim, mlp_dim, dropout=dropout),
                ]
            self.layers.append(nn.ModuleList(members))

    def forward(self, x, pos):
        """Run all blocks; ``pos`` is concatenated to ``x`` on the last axis."""
        for block in self.layers:
            if self.use_ln:
                norm_attn, attention, norm_ffn, pos_proj, feed_forward = block
                x = attention(norm_attn(x), pos) + x
                ffn_input = norm_ffn(x)
            else:
                attention, pos_proj, feed_forward = block
                x = attention(x, pos) + x
                ffn_input = x
            # Inject the coordinates, squeeze back to `dim`, then residual FFN.
            x = feed_forward(pos_proj(torch.cat((ffn_input, pos), dim=-1))) + x
        return x


class AttentionPropagator1D(nn.Module):
    """Stack of residual attention / feed-forward blocks for 1-D coordinates.

    Each block is a ``LinearAttention`` layer (relative embedding of dimension
    1, ``min_freq = 1 / res``) followed by a ``FeedForward`` layer, both wrapped
    in residual connections.

    Args:
        dim: Feature width.
        depth: Number of blocks.
        heads: Number of attention heads.
        dim_head: Width of each head.
        attn_type: One of ``'none'``, ``'galerkin'``, ``'fourier'``.
        mlp_dim: Hidden width handed to ``FeedForward``.
        scale: Forwarded to ``LinearAttention`` as ``scale``.
        res: Resolution used to derive ``min_freq``.
        dropout: Dropout rate for the attention and feed-forward layers.
    """

    def __init__(self,
                 dim,
                 depth,
                 heads,
                 dim_head,
                 attn_type,         # ['none', 'galerkin', 'fourier']
                 mlp_dim,
                 scale,
                 res,
                 dropout=0.):
        super().__init__()
        assert attn_type in ['none', 'galerkin', 'fourier']
        self.layers = nn.ModuleList([])

        self.attn_type = attn_type

        for _ in range(depth):
            attention = LinearAttention(dim, attn_type,
                                        heads=heads, dim_head=dim_head, dropout=dropout,
                                        relative_emb=True,
                                        scale=scale,
                                        relative_emb_dim=1,
                                        min_freq=1 / res,
                                        )
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x, pos):
        """Apply every block in order; ``pos`` is only used by the attention."""
        for attention, feed_forward in self.layers:
            x = attention(x, pos) + x
            x = feed_forward(x) + x
        return x


class FourierPropagator(nn.Module):
    """Residual stack of ``FourierConv2d`` + GELU layers.

    Args:
        dim: Number of latent channels (input and output of every layer).
        depth: Number of layers.
        mode: Passed twice to ``FourierConv2d`` (as its third and fourth
            positional arguments).
    """

    def __init__(self,
                 dim,
                 depth,
                 mode):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.latent_channels = dim

        for _ in range(depth):
            spectral_conv = FourierConv2d(self.latent_channels, self.latent_channels,
                                          mode, mode)
            self.layers.append(nn.Sequential(spectral_conv, nn.GELU()))

    def forward(self, z):
        """Apply each layer with a skip connection and return the result."""
        for spectral_block in self.layers:
            z = spectral_block(z) + z
        return z


class MLPPropagator(nn.Module):
    """Residual stack of per-pixel MLPs implemented with 1x1 convolutions.

    Every layer is ``Conv1x1 -> GELU -> Conv1x1 -> GELU -> Conv1x1 ->
    InstanceNorm2d`` (all convolutions bias-free) and is added back onto its
    input.

    Args:
        dim: Number of channels, kept constant through the stack.
        depth: Number of residual layers.
    """

    def __init__(self,
                 dim,
                 depth,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.latent_channels = dim

        def pointwise_conv():
            # kernel 1, stride 1, no padding: acts on each pixel independently
            return nn.Conv2d(dim, dim, 1, 1, 0, bias=False)

        for _ in range(depth):
            self.layers.append(nn.Sequential(
                pointwise_conv(),
                nn.GELU(),
                pointwise_conv(),
                nn.GELU(),
                pointwise_conv(),
                nn.InstanceNorm2d(dim),
            ))

    def forward(self, z):
        """Propagate ``z`` (``[b, c, h, w]``) through all residual layers."""
        for residual_mlp in self.layers:
            z = residual_mlp(z) + z
        return z


class PointWiseMLPPropagator(nn.Module):
    """Residual stack of point-wise MLPs; the first layer also sees positions.

    Every layer is ``InstanceNorm1d -> Linear -> GELU -> Linear -> GELU ->
    Linear`` with bias-free linears. Only the first layer takes ``dim + 2``
    input channels, because ``forward`` concatenates ``pos`` (two channels on
    the last axis) onto the features before that layer.

    NOTE(review): ``InstanceNorm1d`` is applied to tensors that ``forward``
    treats as channel-last (``[b, n, c]``) -- confirm this is intended.

    Args:
        dim: Feature width.
        depth: Number of residual layers.
    """

    def __init__(self,
                 dim,
                 depth,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.latent_channels = dim

        for layer_idx in range(depth):
            # Only the very first layer consumes the two extra position channels.
            in_features = dim + 2 if layer_idx == 0 else dim
            self.layers.append(nn.Sequential(
                nn.InstanceNorm1d(in_features),
                nn.Linear(in_features, dim, bias=False),
                nn.GELU(),
                nn.Linear(dim, dim, bias=False),
                nn.GELU(),
                nn.Linear(dim, dim, bias=False),
            ))

    def forward(self, z, pos):
        """Propagate features ``z``; ``pos`` is appended for the first layer only."""
        for layer_idx, ffn in enumerate(self.layers):
            ffn_input = torch.cat((z, pos), dim=-1) if layer_idx == 0 else z
            z = ffn(ffn_input) + z
        return z


# code copied from: https://github.com/ndahlquist/pytorch-fourier-feature-networks
# author: Nic Dahlquist
class GaussianFourierFeatureTransform(torch.nn.Module):
    """
    An implementation of Gaussian Fourier feature mapping.
    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
       https://arxiv.org/abs/2006.10739
       https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
    Given an input of size [..., num_input_channels] (typically
    [batches, n, num_input_channels]), returns a tensor of size
    [..., mapping_size*2]: the sines followed by the cosines of the randomly
    projected coordinates.

    Args:
        num_input_channels: Size of the last input dimension.
        mapping_size: Number of random frequencies; the output has twice as
            many channels.
        scale: Standard deviation of the Gaussian the frequencies are drawn
            from.
    """

    def __init__(self, num_input_channels, mapping_size=256, scale=10):
        super().__init__()

        self._num_input_channels = num_input_channels
        self._mapping_size = mapping_size
        # Fixed (non-trainable) random projection matrix of shape [C, mapping_size].
        self._B = nn.Parameter(torch.randn((num_input_channels, mapping_size)) * scale, requires_grad=False)

    def forward(self, x):
        # matmul broadcasts over all leading dimensions, so no explicit
        # flatten/unflatten of the batch and point axes is required and inputs
        # of any rank >= 1 are accepted.
        projected = x @ self._B.to(x.device)
        projected = 2 * np.pi * projected
        return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)


class CrossFormer(nn.Module):
    """Single cross-attention block with optional LayerNorm, residual and FFN.

    ``forward(x, z, ...)`` passes ``x`` and ``z`` to a ``CrossLinearAttention``
    module and post-processes the result according to the flags below.

    Args:
        dim: Feature width.
        attn_type: Attention flavour forwarded to ``CrossLinearAttention``.
        heads: Number of attention heads.
        dim_head: Width of each head.
        mlp_dim: Hidden width of the optional ``FeedForward``.
        residual: Add ``x`` back onto the attention output.
        use_ffn: Append a residual ``FeedForward`` layer.
        use_ln: Normalise ``z`` before and the attention output after the
            cross attention.
        relative_emb, scale, relative_emb_dim, min_freq, dropout, cat_pos:
            Forwarded to ``CrossLinearAttention``; ``relative_emb_dim`` is also
            passed as ``pos_dim``.
    """

    def __init__(self,
                 dim,
                 attn_type,
                 heads,
                 dim_head,
                 mlp_dim,
                 residual=True,
                 use_ffn=True,
                 use_ln=False,
                 relative_emb=False,
                 scale=1.,
                 relative_emb_dim=2,
                 min_freq=1/64,
                 dropout=0.,
                 cat_pos=False,
                 ):
        super().__init__()

        attention_options = dict(
            heads=heads, dim_head=dim_head, dropout=dropout,
            relative_emb=relative_emb,
            scale=scale,
            relative_emb_dim=relative_emb_dim,
            min_freq=min_freq,
            init_method='orthogonal',
            cat_pos=cat_pos,
            pos_dim=relative_emb_dim,
        )
        self.cross_attn_module = CrossLinearAttention(dim, attn_type, **attention_options)
        self.use_ln = use_ln
        self.residual = residual
        self.use_ffn = use_ffn

        if self.use_ln:
            self.ln1 = nn.LayerNorm(dim)
            self.ln2 = nn.LayerNorm(dim)

        if self.use_ffn:
            self.ffn = FeedForward(dim, mlp_dim, dropout)

    def forward(self, x, z, x_pos=None, z_pos=None):
        """Attend from ``x`` to ``z``.

        Per the original comments ``x`` is the coordinate encoding ``[b, n1, c]``
        and ``z`` the system encoding ``[b, n2, c]``.
        """
        if self.use_ln:
            z = self.ln1(z)

        attended = self.cross_attn_module(x, z, x_pos, z_pos)

        if self.use_ln:
            # post-norm on the attention output, before the skip connection
            attended = self.ln2(attended)

        x = attended + x if self.residual else attended

        if self.use_ffn:
            x = self.ffn(x) + x

        return x


class BranchTrunkNet(nn.Module):
    """Branch/trunk style head: reduce ``z`` over its points, then combine with ``x``.

    ``proj`` transposes ``z`` to ``[b, c, n]`` and runs an MLP over the point
    axis (which therefore must have length ``branch_size``), reducing it to a
    single value per channel. The result and ``x`` are then handed to
    ``ProjDotProduct``.

    Args:
        dim: Feature width used for all three ``ProjDotProduct`` arguments.
        branch_size: Number of points in ``z`` (input width of the MLP).
        branchnet_dim: Width of the first hidden layer; the second hidden
            layer is ``branchnet_dim // 2`` wide.
    """

    def __init__(self,
                 dim,
                 branch_size,
                 branchnet_dim,
                 ):
        super().__init__()
        self.proj = nn.Sequential(
            Rearrange('b n c -> b c n'),
            nn.Linear(branch_size, branchnet_dim),
            nn.ReLU(),
            # The input width must match the previous layer's output
            # (branchnet_dim); the former branchnet_dim//2 made forward() fail
            # with a shape mismatch.
            nn.Linear(branchnet_dim, branchnet_dim//2),
            nn.ReLU(),
            nn.Linear(branchnet_dim//2, 1),

        )
        self.net = ProjDotProduct(dim, dim, dim)

    def forward(self, x, z):
        # x in [b n1 c]
        # b, n1, c = x.shape   # coordinate encoding
        # b, n2, c = z.shape   # system encoding
        z = self.proj(z).squeeze(-1)   # [b, n2, c] -> [b, c]
        return self.net(x, z)


class Decoder(nn.Module):
    """Grid-based latent roll-out decoder built from 1x1 convolutions.

    At every step the latent field is advanced by an ``MLPPropagator`` and then
    decoded into ``out_channels * out_steps`` output channels. With
    ``pos_encoding_aug`` a fixed ``[0, 1]^2`` coordinate grid is concatenated to
    the latent before decoding, so the latent's spatial size must equal
    ``grid_size`` in that case.

    Args:
        grid_size: Side length of the square grid.
        latent_channels: Channel width of the latent field.
        out_channels: Number of physical output channels.
        out_steps: Number of time frames decoded per latent step.
        decoding_depth: Number of stacked decoding stages.
        propagator_depth: Depth of the ``MLPPropagator``.
        pos_encoding_aug: Whether to append the coordinate grid when decoding.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 grid_size,                # 64 x 64
                 latent_channels,              # 256??
                 out_channels,                 # 1 or 2?
                 out_steps,                    # 10
                 decoding_depth,                        # 4?
                 propagator_depth,
                 pos_encoding_aug=False,
                 **kwargs,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])   # not used by this class
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.grid_size = grid_size
        self.latent_channels = latent_channels
        self.pos_encoding_aug = pos_encoding_aug

        self.propagator = MLPPropagator(self.latent_channels, propagator_depth)

        def make_stage(stage_idx):
            # Only the first stage sees the two extra coordinate channels.
            in_channels = self.latent_channels
            if stage_idx == 0 and pos_encoding_aug:
                in_channels = self.latent_channels + 2
            return nn.Sequential(
                nn.Conv2d(in_channels, self.latent_channels, 1, 1, 0, bias=False),
                nn.GELU(),
                nn.Conv2d(self.latent_channels, self.latent_channels, 1, 1, 0, bias=False),
                nn.GELU(),
                nn.Conv2d(self.latent_channels, self.latent_channels, 1, 1, 0, bias=False),
            )

        self.decoder = nn.ModuleList([make_stage(stage_idx) for stage_idx in range(decoding_depth)])

        self.to_out = nn.Conv2d(self.latent_channels, self.out_channels*self.out_steps, 1, 1, 0, bias=True)

        # Fixed coordinate grid of shape [1, 2, grid_size, grid_size].
        axis = np.linspace(0, 1, grid_size)
        x0, y0 = np.meshgrid(axis, axis)
        xs = np.stack((x0, y0), axis=0)
        self.grid = nn.Parameter(torch.from_numpy(xs[None, ...]).float(), requires_grad=False)

    def decode(self, z):
        """Map a latent field ``[b, c, h, w]`` to ``[b, out_channels*out_steps, h, w]``."""
        if self.pos_encoding_aug:
            batch_grid = repeat(self.grid, 'b c h w -> (repeat b) c h w', repeat=z.shape[0])
            z = torch.cat((z, batch_grid), dim=1)
        for stage in self.decoder:
            z = stage(z)
        return self.to_out(z)

    def forward(self, z, z_cls, forward_steps):
        """Roll the latent forward ``forward_steps`` times and decode each state.

        Args:
            z: Latent field ``[b, c, h, w]``.
            z_cls: Global embedding ``[b, c]`` added to the latent before every
                propagation step.
            forward_steps: Number of latent steps.

        Returns:
            Tensor ``[b, out_channels, forward_steps * out_steps, h, w]``.
        """
        assert len(z.shape) == 4  # [b, c, h, w]
        height, width = z.shape[2], z.shape[3]
        z_cls = rearrange(z_cls, 'b c -> b c 1 1').repeat(1, 1, height, width)
        frames = []
        # forward the dynamics in the latent space
        for _ in range(forward_steps):
            z = self.propagator(z + z_cls)
            decoded = self.decode(z)
            frames.append(rearrange(decoded, 'b (c t) h w -> b c t h w',
                                    c=self.out_channels, t=self.out_steps))
        return torch.cat(frames, dim=2)    # [b, c, length_of_history, h, w]


class GraphDecoder(nn.Module):
    """Point-wise latent roll-out decoder for node-based (graph) inputs.

    The latent of every node is advanced with a ``PointWiseMLPPropagator``
    (conditioned on the node positions) and decoded by a stack of MLPs into
    ``out_channels * out_steps`` values per node.

    NOTE: a graph-based pivotal-to-query interpolation (``SmoothConvDecoder``)
    was disabled in the original code, so ``pivotal_pos``,
    ``pivotal2prop_graph`` and ``pivotal2prop_cutoff`` are accepted by
    ``forward`` but currently unused.

    Args:
        latent_channels: Feature width of the latent.
        out_channels: Number of physical output channels.
        out_steps: Number of time frames decoded per latent step.
        decoding_depth: Number of stacked decoding MLPs.
        propagator_depth: Depth of the ``PointWiseMLPPropagator``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 latent_channels,              # 256??
                 out_channels,                 # 1 or 2?
                 out_steps,                    # 10
                 decoding_depth,               # 4?
                 propagator_depth,
                 **kwargs,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])   # not used by this class
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        self.propagator = PointWiseMLPPropagator(self.latent_channels, propagator_depth)

        width = self.latent_channels

        def make_stage():
            # three linears; only the last one carries a bias
            return nn.Sequential(
                nn.Linear(width, width, bias=False),
                nn.GELU(),
                nn.Linear(width, width, bias=False),
                nn.GELU(),
                nn.Linear(width, width, bias=True),
            )

        self.decoder = nn.ModuleList([make_stage() for _ in range(decoding_depth)])

        self.to_out = nn.Linear(width, self.out_channels*self.out_steps, bias=True)

    def decode(self, z):
        """Map latent ``[b, n, c]`` to ``[b, n, out_channels*out_steps]``."""
        for stage in self.decoder:
            z = stage(z)
        return self.to_out(z)

    def forward(self,
                propagate_pos,      #  [sum n_p, 2]
                pivotal_pos,        #  [sum n_pivot, 2]
                pivotal2prop_graph,   #  [sum g_pivot, 2]
                pivotal2prop_cutoff,  # float
                z_pivotal,          # [b, c, num_of_pivot]
                z_cls,              # [b, c]
                forward_steps):
        """Roll out ``forward_steps`` latent steps and decode each of them.

        ``propagate_pos`` is a flat ``[(b n), 2]`` tensor that is split evenly
        across the batch, i.e. every sample is assumed to have the same number
        of nodes.

        Returns:
            Tensor ``[b, out_channels, forward_steps * out_steps, n]``.
        """
        assert len(z_pivotal.shape) == 3  # channel-first: [b, c, n]
        batch_size = z_pivotal.shape[0]
        # assuming each batch have same number of nodes
        num_of_prop = int(propagate_pos.shape[0] // batch_size)
        z_cls = rearrange(z_cls, 'b c -> b 1 c').repeat(1, num_of_prop, 1)

        # z in shape [b, n, c]
        z = rearrange(z_pivotal, 'b c n -> b n c')
        pos = rearrange(propagate_pos, '(b n) c -> b n c', b=batch_size)

        frames = []
        # forward the dynamics in the latent space
        for _ in range(forward_steps):
            z = self.propagator(z + z_cls, pos)
            decoded = self.decode(z)
            frames.append(rearrange(decoded, 'b n (t c) -> b c t n',
                                    c=self.out_channels, t=self.out_steps))
        # concatenate in temporal dimension
        return torch.cat(frames, dim=2)    # [b, c, length_of_history, n]


class DecoderNew(nn.Module):
    """Point-wise latent roll-out decoder operating on a flattened grid latent.

    The ``[b, c, h, w]`` latent is flattened to ``[b, h*w, c]``, advanced by a
    ``PointWiseMLPPropagator`` conditioned on ``propagate_pos`` and decoded by
    a stack of MLPs into ``out_channels * out_steps`` values per point.

    Args:
        latent_channels: Feature width of the latent.
        out_channels: Number of physical output channels.
        out_steps: Number of time frames decoded per latent step.
        decoding_depth: Number of stacked decoding MLPs.
        propagator_depth: Depth of the ``PointWiseMLPPropagator``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 latent_channels,              # 256??
                 out_channels,                 # 1 or 2?
                 out_steps,                    # 10
                 decoding_depth,               # 4?
                 propagator_depth,
                 **kwargs,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])   # not used by this class
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        self.propagator = PointWiseMLPPropagator(self.latent_channels, propagator_depth)

        width = self.latent_channels

        def make_stage():
            # three linears; only the last one carries a bias
            return nn.Sequential(
                nn.Linear(width, width, bias=False),
                nn.GELU(),
                nn.Linear(width, width, bias=False),
                nn.GELU(),
                nn.Linear(width, width, bias=True),
            )

        self.decoder = nn.ModuleList([make_stage() for _ in range(decoding_depth)])

        self.to_out = nn.Linear(width, self.out_channels*self.out_steps, bias=True)

    def decode(self, z):
        """Map latent ``[b, n, c]`` to ``[b, n, out_channels*out_steps]``."""
        for stage in self.decoder:
            z = stage(z)
        return self.to_out(z)

    def forward(self,
                z,                  # [b, c, h, w]
                z_cls,              # [b, c]
                propagate_pos,      # [b, n, 2]
                forward_steps):
        """Roll out ``forward_steps`` latent steps and decode each of them.

        Returns:
            Tensor ``[b, out_channels, forward_steps * out_steps, n]`` with
            ``n = h * w``.
        """
        z = rearrange(z, 'b c h w -> b (h w) c')
        num_points = z.shape[1]
        z_cls = rearrange(z_cls, 'b c -> b 1 c').repeat(1, num_points, 1)

        frames = []
        # forward the dynamics in the latent space
        for _ in range(forward_steps):
            z = self.propagator(z + z_cls, propagate_pos)
            decoded = self.decode(z)
            frames.append(rearrange(decoded, 'b n (t c) -> b c t n',
                                    c=self.out_channels, t=self.out_steps))
        # concatenate in temporal dimension
        return torch.cat(frames, dim=2)    # [b, c, length_of_history, n]


class PointWiseDecoder(nn.Module):
    """Query-point decoder: cross-attend to the encoded input, then roll out.

    Query coordinates are lifted with Gaussian Fourier features and an MLP,
    cross-attend to the (down-projected) input latent through a
    ``CrossFormer``, are expanded back to ``latent_channels`` and then advanced
    and decoded step by step.

    Args:
        latent_channels: Feature width of the latent. It should be divisible
            by 4 so the Fourier features (``2 * (latent_channels // 4)``
            channels) match the ``latent_channels // 2`` wide MLP.
        out_channels: Number of physical output channels.
        out_steps: Number of time frames decoded per latent step.
        decoding_depth: Number of stacked decoding MLPs.
        propagator_depth: Depth of the ``PointWiseMLPPropagator``.
        scale: Standard deviation of the Fourier feature frequencies.
        use_rope: If true, the ``CrossFormer`` uses relative position
            embeddings and ``forward`` passes the positions to it.
        attn_type: Attention flavour handed to ``CrossFormer``.
        **kwargs: Accepted and ignored.

    NOTE(review): the original code called ``CrossFormer(c//2, 4, 64, c//2)``,
    which does not match ``CrossFormer``'s five required positional arguments
    and raised ``TypeError`` on construction. The values are interpreted here
    as ``heads=4, dim_head=64, mlp_dim=c//2`` with the attention type supplied
    separately (default ``'galerkin'``) -- confirm against the intended
    configuration.
    """

    def __init__(self,
                 latent_channels,  # 256??
                 out_channels,  # 1 or 2?
                 out_steps,  # 10
                 decoding_depth,  # 4?
                 propagator_depth,
                 scale=8,
                 use_rope=False,
                 attn_type='galerkin',
                 **kwargs,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])   # not used by this class
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(2, self.latent_channels//4, scale=scale),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
        )
        self.z_project = nn.Linear(self.latent_channels, self.latent_channels//2, bias=False)

        self.use_rope = use_rope
        if not use_rope:
            self.decoding_transformer = CrossFormer(self.latent_channels//2, attn_type, 4, 64,
                                                    self.latent_channels//2)
        else:
            self.decoding_transformer = CrossFormer(self.latent_channels//2, attn_type, 4, 64,
                                                    self.latent_channels//2,
                                                    relative_emb=True, scale=16.)

        self.project = nn.Sequential(
            nn.Linear(self.latent_channels//2, self.latent_channels, bias=False),
            nn.InstanceNorm1d(self.latent_channels))

        self.propagator = PointWiseMLPPropagator(self.latent_channels, propagator_depth)

        self.decoder = nn.ModuleList([
            nn.Sequential(
                nn.InstanceNorm1d(self.latent_channels),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
            )
            for _ in range(decoding_depth)])

        self.to_out = nn.Sequential(
            nn.Linear(self.latent_channels, self.out_channels * self.out_steps, bias=True))

    def decode(self, z):
        """Map latent ``[b, n, c]`` to ``[b, n, out_channels*out_steps]``."""
        for layer in self.decoder:
            z = layer(z)
        z = self.to_out(z)
        return z

    def forward(self,
                z,  # [b, n c]
                z_cls,  # [b, c]
                propagate_pos,  # [b, n, 2]
                forward_steps,
                input_pos=None):
        """Decode ``forward_steps`` latent steps at the query positions.

        Args:
            z: Encoded input latent, ``latent_channels`` wide on the last axis.
            z_cls: Global embedding, either ``[b, c]`` or ``[b, 1, c]``.
            propagate_pos: Query coordinates ``[b, n, 2]``.
            forward_steps: Number of latent steps.
            input_pos: Positions of the input points; only used with
                ``use_rope``.

        Returns:
            Tensor ``[b, out_channels, forward_steps * out_steps, n]``.
        """
        history = []
        x = self.coordinate_projection(propagate_pos)
        # Accept the documented [b, c] layout as well as [b, 1, c]; repeating a
        # 2-D tensor with three repeat factors would otherwise yield
        # [1, b*n, c] and break the addition below for batch sizes > 1.
        if z_cls.dim() == 2:
            z_cls = z_cls.unsqueeze(1)
        z_cls = z_cls.repeat(1, propagate_pos.shape[1], 1)
        z = self.z_project(z)  # c to c/2
        if not self.use_rope:
            z = self.decoding_transformer(x, z)
        else:
            z = self.decoding_transformer(x, z, propagate_pos, input_pos)
        z = self.project(z)

        # forward the dynamics in the latent space
        for step in range(forward_steps):
            z = self.propagator(z + z_cls, propagate_pos)
            u = self.decode(z)
            history.append(rearrange(u, 'b n (t c) -> b c t n', c=self.out_channels, t=self.out_steps))
        history = torch.cat(history, dim=2)  # concatenate in temporal dimension
        return history  # [b, c, length_of_history, n]


class SimplerPointWiseDecoder(nn.Module):
    """Variant of the point-wise decoder with a single decoding MLP.

    Query coordinates are lifted with Gaussian Fourier features and an MLP,
    cross-attend to the (down-projected) input latent through a
    ``CrossFormer``, are expanded back to ``latent_channels`` and then advanced
    by a ``PointWiseMLPPropagator`` and decoded by one three-layer MLP.

    Args:
        latent_channels: Feature width of the latent. It should be divisible
            by 4 so the Fourier features (``2 * (latent_channels // 4)``
            channels) match the ``latent_channels // 2`` wide MLP.
        out_channels: Number of physical output channels.
        out_steps: Number of time frames decoded per latent step.
        decoding_depth: Accepted for interface parity; not used by this class.
        propagator_depth: Depth of the ``PointWiseMLPPropagator``.
        scale: Standard deviation of the Fourier feature frequencies.
        use_rope: If true, the ``CrossFormer`` uses relative position
            embeddings and ``forward`` passes the positions to it.
        attn_type: Attention flavour handed to ``CrossFormer``.
        **kwargs: Accepted and ignored.

    NOTE(review): the original code called ``CrossFormer(c//2, 4, 64, c//2)``,
    which does not match ``CrossFormer``'s five required positional arguments
    and raised ``TypeError`` on construction. The values are interpreted here
    as ``heads=4, dim_head=64, mlp_dim=c//2`` with the attention type supplied
    separately (default ``'galerkin'``) -- confirm against the intended
    configuration.
    """

    def __init__(self,
                 latent_channels,  # 256??
                 out_channels,  # 1 or 2?
                 out_steps,  # 10
                 decoding_depth,  # 4?
                 propagator_depth,
                 scale=8,
                 use_rope=False,
                 attn_type='galerkin',
                 **kwargs,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])   # not used by this class
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(2, self.latent_channels//4, scale=scale),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
        )
        self.z_project = nn.Linear(self.latent_channels, self.latent_channels//2, bias=False)

        self.use_rope = use_rope
        if not use_rope:
            self.decoding_transformer = CrossFormer(self.latent_channels//2, attn_type, 4, 64,
                                                    self.latent_channels//2)
        else:
            self.decoding_transformer = CrossFormer(self.latent_channels//2, attn_type, 4, 64,
                                                    self.latent_channels//2,
                                                    relative_emb=True, scale=16.)

        self.project = nn.Sequential(
            nn.Linear(self.latent_channels//2, self.latent_channels, bias=False),
            nn.InstanceNorm1d(self.latent_channels))

        self.propagator = PointWiseMLPPropagator(self.latent_channels, propagator_depth)

        self.decoder = nn.Sequential(
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU()
            )

        self.to_out = nn.Linear(self.latent_channels, self.out_channels * self.out_steps, bias=True)

    def decode(self, z):
        """Map latent ``[b, n, c]`` to ``[b, n, out_channels*out_steps]``."""
        z = self.decoder(z)
        z = self.to_out(z)
        return z

    def forward(self,
                z,  # [b, n c]
                z_cls,  # [b, c]
                propagate_pos,  # [b, n, 2]
                forward_steps,
                input_pos=None):
        """Decode ``forward_steps`` latent steps at the query positions.

        Args:
            z: Encoded input latent, ``latent_channels`` wide on the last axis.
            z_cls: Global embedding, either ``[b, c]`` or ``[b, 1, c]``.
            propagate_pos: Query coordinates ``[b, n, 2]``.
            forward_steps: Number of latent steps.
            input_pos: Positions of the input points; only used with
                ``use_rope``.

        Returns:
            Tensor ``[b, out_channels, forward_steps * out_steps, n]``.
        """
        history = []
        x = self.coordinate_projection(propagate_pos)
        # Accept the documented [b, c] layout as well as [b, 1, c]; repeating a
        # 2-D tensor with three repeat factors would otherwise yield
        # [1, b*n, c] and break the addition below for batch sizes > 1.
        if z_cls.dim() == 2:
            z_cls = z_cls.unsqueeze(1)
        z_cls = z_cls.repeat(1, propagate_pos.shape[1], 1)
        z = self.z_project(z)  # c to c/2
        if not self.use_rope:
            z = self.decoding_transformer(x, z)
        else:
            z = self.decoding_transformer(x, z, propagate_pos, input_pos)
        z = self.project(z)

        # forward the dynamics in the latent space
        for step in range(forward_steps):
            z = self.propagator(z + z_cls, propagate_pos)
            u = self.decode(z)
            history.append(rearrange(u, 'b n (t c) -> b c t n', c=self.out_channels, t=self.out_steps))
        history = torch.cat(history, dim=2)  # concatenate in temporal dimension
        return history  # [b, c, length_of_history, n]


class PointWiseDecoder2D(nn.Module):
    """Query-point decoder for 2-D problems with a position-aware propagator.

    ``get_embedding`` / ``rollout`` lift the query coordinates with Gaussian
    Fourier features, cross-attend to the encoded input with a ``'galerkin'``
    ``CrossFormer`` and expand the result to ``latent_channels``.
    ``propagate`` advances that latent with residual MLPs that see the
    layer-normalised latent concatenated with the 2-D positions, and
    ``decode`` maps it to ``out_channels * out_steps`` values per point.

    Args:
        latent_channels: Feature width of the propagated latent.
        out_channels: Number of physical output channels.
        out_steps: Number of time frames decoded per latent step.
        propagator_depth: Number of residual propagator layers.
        scale: Standard deviation of the Fourier feature frequencies.
        dropout: Accepted but not used by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 latent_channels,  # 256??
                 out_channels,  # 1 or 2?
                 out_steps,  # 10
                 propagator_depth,
                 scale=8,
                 dropout=0.,
                 **kwargs,
                 ):
        super().__init__()
        self.layers = nn.ModuleList([])   # not used by this class
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        full = self.latent_channels
        half = self.latent_channels//2

        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(2, half, scale=scale),
            nn.Linear(full, full, bias=False),
            nn.GELU(),
            nn.Linear(full, half, bias=False),
        )

        self.decoding_transformer = CrossFormer(half, 'galerkin', 4,
                                                half, half,
                                                relative_emb=True,
                                                scale=16.,
                                                relative_emb_dim=2,
                                                min_freq=1/64)

        self.expand_feat = nn.Linear(half, full)

        def make_propagator_layer():
            # [pre-norm, MLP over (normalised latent ++ 2-D position)]
            return nn.ModuleList([
                nn.LayerNorm(full),
                nn.Sequential(
                    nn.Linear(full + 2, full, bias=False),
                    nn.GELU(),
                    nn.Linear(full, full, bias=False),
                    nn.GELU(),
                    nn.Linear(full, full, bias=False)),
            ])

        self.propagator = nn.ModuleList([make_propagator_layer() for _ in range(propagator_depth)])

        self.to_out = nn.Sequential(
            nn.LayerNorm(full),
            nn.Linear(full, half, bias=False),
            nn.GELU(),
            nn.Linear(half, half, bias=False),
            nn.GELU(),
            nn.Linear(half, self.out_channels * self.out_steps, bias=True))

    def propagate(self, z, pos):
        """Advance the latent by one step through all residual layers."""
        for norm_fn, ffn in self.propagator:
            conditioned = torch.cat((norm_fn(z), pos), dim=-1)
            z = ffn(conditioned) + z
        return z

    def decode(self, z):
        """Map the latent to ``out_channels * out_steps`` values per point."""
        return self.to_out(z)

    def get_embedding(self,
                      z,  # [b, n c]
                      propagate_pos,  # [b, n, 2]
                      input_pos
                      ):
        """Build the initial query-point latent from the encoded input ``z``."""
        query = self.coordinate_projection.forward(propagate_pos)
        attended = self.decoding_transformer.forward(query, z, propagate_pos, input_pos)
        return self.expand_feat(attended)

    def forward(self,
                z,              # [b, n, c]
                propagate_pos   # [b, n, 2]
                ):
        """Do one propagate-and-decode step.

        Returns:
            ``(u, z)`` where ``u`` is the decoded output rearranged to
            ``[b, out_steps*out_channels, n]`` and ``z`` is the propagated
            latent in the same layout as the input latent.
        """
        z = self.propagate(z, propagate_pos)
        decoded = self.decode(z)
        u = rearrange(decoded, 'b n (t c) -> b (t c) n', c=self.out_channels, t=self.out_steps)
        return u, z

    def rollout(self,
                z,  # [b, n c]
                propagate_pos,  # [b, n, 2]
                forward_steps,
                input_pos):
        """Embed the query points and decode ``forward_steps // out_steps`` latent steps.

        Returns:
            Tensor ``[b, length_of_history*c, n]`` with the decoded frames
            concatenated along the second-to-last axis.
        """
        query = self.coordinate_projection.forward(propagate_pos)
        z = self.decoding_transformer.forward(query, z, propagate_pos, input_pos)
        z = self.expand_feat(z)

        num_latent_steps = forward_steps//self.out_steps
        frames = []
        # forward the dynamics in the latent space
        for _ in range(num_latent_steps):
            z = self.propagate(z, propagate_pos)
            decoded = self.decode(z)
            frames.append(rearrange(decoded, 'b n (t c) -> b (t c) n',
                                    c=self.out_channels, t=self.out_steps))
        # concatenate in temporal dimension
        return torch.cat(frames, dim=-2)  # [b, length_of_history*c, n]


class PointWiseDecoder1D(nn.Module):
    """Point-wise decoder for 1D problems (the original author notes: Burgers equation).

    The forward pass projects the query coordinates, runs them together with
    the encoder output through ``self.decoding_transformer``, applies a stack
    of residual MLPs (``self.propagator``) and finally maps the result to
    ``out_channels`` with ``self.to_out``.

    Args:
        latent_channels: width of the latent features.
        out_channels: number of output channels per query point.
        decoding_depth: number of residual MLP blocks in ``self.propagator``.
        scale: forwarded to ``GaussianFourierFeatureTransform``.
        res: used as ``min_freq=1/res`` for the ``CrossFormer``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 latent_channels,
                 out_channels,
                 decoding_depth,
                 scale=8,
                 res=2048,
                 **kwargs,
                 ):
        super().__init__()
        # Not used by this class; kept so the module's attributes stay unchanged.
        self.layers = nn.ModuleList([])
        self.out_channels = out_channels
        self.latent_channels = latent_channels

        # The linear layer expects 2 * latent_channels inputs, i.e. the Fourier
        # transform presumably emits twice its mapping size -- TODO confirm.
        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(1, self.latent_channels, scale=scale),
            nn.GELU(),
            nn.Linear(self.latent_channels*2, self.latent_channels, bias=False),
        )

        self.decoding_transformer = CrossFormer(self.latent_channels, 'fourier', 8,
                                                self.latent_channels, self.latent_channels,
                                                relative_emb=True,
                                                scale=1,
                                                relative_emb_dim=1,
                                                min_freq=1/res)

        self.propagator = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),)
            for _ in range(decoding_depth)])

        self.init_propagator_params()
        self.to_out = nn.Sequential(
            nn.Linear(self.latent_channels, self.latent_channels//2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels//2, self.out_channels, bias=True))

    def propagate(self, z):
        """Apply every propagator block as a residual update ``z = z + block(z)``."""
        for layer in self.propagator:
            z = z + layer(z)
        return z

    def decode(self, z):
        """Map latent features to the output channels via ``self.to_out``."""
        return self.to_out(z)

    def init_propagator_params(self):
        """Re-initialise the weight matrices of the propagator blocks.

        Every parameter with more than one dimension gets an orthogonal
        initialisation with gain ``1/in_c`` plus ``1/in_c`` times the identity,
        where ``in_c`` is the size of the last dimension. For non-square
        weights the rows beyond ``in_c`` receive the same treatment separately.
        """
        for block in self.propagator:
            for layer in block:
                for param in layer.parameters():
                    if param.ndim <= 1:
                        continue
                    in_c = param.size(-1)
                    identity = torch.diag(torch.ones(in_c, dtype=torch.float32))
                    orthogonal_(param[:in_c], gain=1/in_c)
                    param.data[:in_c] += 1/in_c * identity
                    if param.size(-2) != param.size(-1):
                        orthogonal_(param[in_c:], gain=1/in_c)
                        param.data[in_c:] += 1/in_c * identity

    def forward(self,
                z,  # [b, n c]
                propagate_pos,  # [b, n, 1]
                input_pos=None,
                ):
        """Decode the encoder output ``z`` at the query positions.

        Args:
            z: encoder features (annotated ``[b, n, c]``).
            propagate_pos: query coordinates (annotated ``[b, n, 1]``).
            input_pos: coordinates forwarded to ``self.decoding_transformer``.

        Returns:
            Decoded values per query point (annotated ``[b, n, c]``).
        """
        # Call the sub-modules directly (not ``.forward``) so forward hooks run.
        x = self.coordinate_projection(propagate_pos)
        z = self.decoding_transformer(x, z, propagate_pos, input_pos)

        z = self.propagate(z)
        z = self.decode(z)
        return z  # [b, n, c]

class RefinementBlock(nn.Module):
    """
    Basic computation unit of the deep refinement network.

    A point-wise MLP with layer normalisation and a residual connection:
    ``H' = H + FFN(concat(Norm(H), Pos))``.
    """
    def __init__(self, latent_channels):
        super().__init__()
        hidden_width = latent_channels * 4
        self.norm = nn.LayerNorm(latent_channels)
        # The MLP input is the normalised hidden state concatenated with the
        # 2D coordinates, hence the ``+ 2`` on the input width.
        self.ffn = nn.Sequential(
            nn.Linear(latent_channels + 2, hidden_width, bias=False),
            nn.GELU(),
            nn.Linear(hidden_width, latent_channels, bias=False),
        )

    def forward(self, hidden_states, pos_embedding):
        """
        Args:
            hidden_states (torch.Tensor): current hidden state, shape (B, N, C)
            pos_embedding (torch.Tensor): coordinates, shape (B, N, 2)

        Returns:
            torch.Tensor: refined hidden state with the shape of ``hidden_states``.
        """
        conditioned = torch.cat((self.norm(hidden_states), pos_embedding), dim=-1)
        return hidden_states + self.ffn(conditioned)
# class PointWiseDecoder2DSimple(nn.Module):
#     # for Darcy equation
#     def __init__(self,
#                  latent_channels,  # 256??
#                  out_channels,  # 1 or 2?
#                  res=211,
#                  scale=0.5,
#                  **kwargs,
#                  ):
#         super().__init__()
#         self.layers = nn.ModuleList([])
#         self.out_channels = out_channels
#         self.latent_channels = latent_channels

#         self.coordinate_projection = nn.Sequential(
#             GaussianFourierFeatureTransform(2, self.latent_channels//2, scale=scale),
#             # nn.Linear(2, self.latent_channels, bias=False),
#             # nn.GELU(),
#             nn.Linear(self.latent_channels, self.latent_channels, bias=False),
#             nn.GELU(),
#             nn.Linear(self.latent_channels, self.latent_channels, bias=False),
#             # nn.Dropout(0.05),
#         )

        # self.decoding_transformer = CrossFormer(self.latent_channels, 'galerkin', 4,
        #                                         self.latent_channels, self.latent_channels,
        #                                         use_ln=False,
        #                                         residual=True,
        #                                         relative_emb=True,
        #                                         scale=16,
        #                                         relative_emb_dim=2,
        #                                         min_freq=1/res)

#         # self.init_propagator_params()
#         self.to_out = nn.Sequential(
#             nn.Linear(self.latent_channels+2, self.latent_channels, bias=False),
#             nn.GELU(),
#             nn.Linear(self.latent_channels, self.latent_channels//2, bias=False),
#             nn.GELU(),
#             nn.Linear(self.latent_channels//2, self.out_channels, bias=True))

#     def decode(self, z):
#         z = self.to_out(z)
#         return z

#     def forward(self,
#                 z,  # [b, n c]
#                 propagate_pos,  # [b, n, 1]
#                 input_pos=None,
#                 ):

#         x = self.coordinate_projection.forward(propagate_pos)
#         z = self.decoding_transformer.forward(x, z, propagate_pos, input_pos)

#         z = self.decode(torch.cat((z, propagate_pos), dim=-1))
#         return z  # [b, n, c]
class PointWiseDecoder2DSimple(nn.Module):
    """
    Deeper variant of the simple 2D point-wise decoder.

    After the cross-attention step, a refinement network made of
    ``refinement_depth`` residual MLP blocks (``RefinementBlock``) is applied
    before the output head.

    Args:
        latent_channels: working width of the decoder.
        out_channels: number of output channels per query point.
        refinement_depth: number of ``RefinementBlock`` instances.
        res: used as ``min_freq=1/res`` for the ``CrossFormer``.
        scale: forwarded to ``GaussianFourierFeatureTransform``.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 latent_channels,
                 out_channels,
                 refinement_depth,
                 res=211,
                 scale=0.5,
                 **kwargs,
                 ):
        super().__init__()
        self.__name__ = 'PointWiseDecoder2DSimple_Deep'
        self.latent_channels = latent_channels
        self.out_channels = out_channels
        self.refinement_depth = refinement_depth

        # --- Part 1: initial decoding ---
        # Coordinate encoder. The linear layer takes 2 * latent_channels
        # inputs (the original comment states the Fourier output width is
        # mapping_size * 2).
        fourier_out_dim = 2 * latent_channels
        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(2, latent_channels, scale=scale),
            nn.Linear(fourier_out_dim, latent_channels),
        )

        # Cross-attention between the encoded coordinates and the encoder output.
        self.decoding_transformer = CrossFormer(self.latent_channels, 'galerkin', 4,
                                                self.latent_channels, self.latent_channels,
                                                use_ln=False,
                                                residual=True,
                                                relative_emb=True,
                                                scale=16,
                                                relative_emb_dim=2,
                                                min_freq=1/res)

        # --- Part 2: deep refinement network ---
        # Independent refinement blocks stacked in a ModuleList.
        self.refinement_blocks = nn.ModuleList(
            RefinementBlock(latent_channels) for _ in range(refinement_depth)
        )

        # --- Part 3: output head ---
        # Input width is latent_channels + 2 because the coordinates are
        # concatenated to the hidden state in ``forward``.
        half_width = latent_channels // 2
        self.to_out = nn.Sequential(
            nn.Linear(latent_channels + 2, latent_channels, bias=False),
            nn.GELU(),
            nn.Linear(latent_channels, half_width, bias=False),
            nn.GELU(),
            nn.Linear(half_width, out_channels, bias=True)
        )

    def forward(self,
                z,              # encoder output, shape: [B, N_in, C]
                propagate_pos,  # query coordinates, shape: [B, N_out, 2]
                input_pos=None, # encoder input coordinates, shape: [B, N_in, 2]
                ):
        """
        Run one full forward pass.

        Returns:
            A tuple ``(output, {})``. The empty dict keeps the return signature
            aligned with the adaptive decoder, which returns diagnostics.
        """
        # Step 1: encode the query coordinates and fuse them with ``z``.
        queries = self.coordinate_projection(propagate_pos)
        state = self.decoding_transformer(queries, z, propagate_pos, input_pos)

        # Step 2: refine the hidden state block by block.
        for refine in self.refinement_blocks:
            state = refine(state, propagate_pos)

        # Step 3: output head on the hidden state concatenated with the coordinates.
        prediction = self.to_out(torch.cat((state, propagate_pos), dim=-1))
        return prediction, {}
class AdaptivePointWiseDecoder2D_SteadyState(nn.Module):
    """
    Steady-state decoder with an adaptive-computation refinement stage.

    It follows the same layout as the deep simple decoder (coordinate
    projection, cross-attention, output head) but replaces the dense
    refinement network with ``StructuredRecursivePropagator_simple``.

    Args:
        latent_channels: working width of the decoder.
        out_channels: number of output channels per query point.
        propagator_depth: recursion depth of the adaptive propagator.
        capacity_ratios: per-layer capacity ratios handed to the propagator.
            May be ``None``, in which case the propagator falls back to a
            linear decay from 1.0 to ``final_keep_ratio``.
        final_keep_ratio: end point of that linear decay.
        res: used as ``min_freq=1/res`` for the ``CrossFormer``.
        scale: forwarded to ``GaussianFourierFeatureTransform``.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 latent_channels,
                 out_channels,
                 propagator_depth,
                 capacity_ratios=None,  # optional: the propagator handles None
                 final_keep_ratio=0.25,
                 res=211,
                 scale=0.5,
                 **kwargs,
                 ):
        super().__init__()
        self.__name__ = 'AdaptivePointWiseDecoder2D_SteadyState'
        self.latent_channels = latent_channels
        self.out_channels = out_channels

        # --- Part 1: initial decoding ---
        fourier_map_dim = latent_channels * 2
        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(2, latent_channels, scale=scale),
            nn.Linear(fourier_map_dim, latent_channels),
        )
        self.decoding_transformer = CrossFormer(self.latent_channels, 'galerkin', 4,
                                                self.latent_channels, self.latent_channels,
                                                use_ln=False,
                                                residual=True,
                                                relative_emb=True,
                                                scale=16,
                                                relative_emb_dim=2,
                                                min_freq=1/res)

        # --- Part 2: adaptive deep refinement ---
        self.propagator = StructuredRecursivePropagator_simple(
            latent_channels=latent_channels,
            recursion_depth=propagator_depth,
            context_dim=latent_channels,  # the router scores the initial hidden state
            capacity_ratios=capacity_ratios,
            final_keep_ratio=final_keep_ratio,
        )

        # --- Part 3: output head ---
        self.to_out = nn.Sequential(
            nn.Linear(latent_channels + 2, latent_channels, bias=False),
            nn.GELU(),
            nn.Linear(latent_channels, latent_channels // 2, bias=False),
            nn.GELU(),
            nn.Linear(latent_channels // 2, out_channels, bias=True)
        )

    def forward(self,
                z,
                propagate_pos,
                input_pos=None,
                ):
        """Decode ``z`` at ``propagate_pos``.

        Returns:
            A tuple ``(output, diagnostics)`` where ``diagnostics`` is the dict
            returned by the propagator.
        """
        # 1. Initial decoding produces h_initial.
        query_features = self.coordinate_projection(propagate_pos)
        h_initial = self.decoding_transformer(query_features, z, propagate_pos, input_pos)

        # 2. Adaptive refinement. The router context is a detached copy of
        #    h_initial, so no gradient flows back through the routing input.
        z_context_for_router = h_initial.detach()
        h_final, diagnostics = self.propagator(
            hidden_states=h_initial,
            pos_embedding=propagate_pos,
            z_context=z_context_for_router
        )

        # 3. Output head on the refined state concatenated with the coordinates.
        final_input = torch.cat((h_final, propagate_pos), dim=-1)
        output = self.to_out(final_input)

        return output, diagnostics
class STPointWiseDecoder2D(nn.Module):
    """Point-wise decoder queried with 3-component positions.

    The query positions have three components; all three are used for the
    coordinate projection while the last one is dropped before the positions
    are handed to the cross-attention module (presumably the time coordinate --
    TODO confirm).

    Args:
        latent_channels: width of the latent features.
        out_channels: number of output channels per query point.
        out_steps: number of steps ``t`` used when rearranging the output.
        scale: forwarded to ``GaussianFourierFeatureTransform``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 latent_channels,
                 out_channels,
                 out_steps,
                 scale=8,
                 **kwargs,
                 ):
        super().__init__()
        # Not used by this class; kept so the module's attributes stay unchanged.
        self.layers = nn.ModuleList([])
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(3, self.latent_channels//2, scale=scale),
            nn.GELU(),
            nn.Linear(self.latent_channels, self.latent_channels, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels, self.latent_channels, bias=False),
        )

        self.decoding_transformer = CrossFormer(self.latent_channels, 'galerkin', 1,
                                                self.latent_channels, self.latent_channels,
                                                residual=False,
                                                use_ffn=False,
                                                relative_emb=True,
                                                scale=1.,
                                                relative_emb_dim=2,
                                                min_freq=1/64)

        self.to_out = nn.Sequential(
            nn.LayerNorm(self.latent_channels),
            nn.Linear(self.latent_channels, self.latent_channels//2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels//2, self.out_channels, bias=True))

    def decode(self, z):
        """Map latent features to the output channels via ``self.to_out``."""
        return self.to_out(z)

    def forward(self,
                z,              # [b, n, c]
                propagate_pos,  # [b, tn, 3]
                input_pos,      # [b, n, 2]
                ):
        """Decode ``z`` at the query positions.

        Returns:
            The decoded output rearranged from ``b (t n) c`` to ``b (t c) n``
            with ``t = self.out_steps`` and ``c = self.out_channels``.
        """
        # Call the sub-modules directly (not ``.forward``) so forward hooks run.
        x = self.coordinate_projection(propagate_pos)
        # Only the leading components of the position are passed on here.
        z = self.decoding_transformer(x, z, propagate_pos[:, :, :-1], input_pos)
        z = self.decode(z)
        z = rearrange(z, 'b (t n) c -> b (t c) n', c=self.out_channels, t=self.out_steps)
        return z


class BCDecoder1D(nn.Module):
    """1D decoder built on ``BranchTrunkNet``.

    The original author notes: for the Burgers equation, using the DeepONet
    formulation.

    Args:
        latent_channels: width of the latent features.
        out_channels: stored on the instance; not used elsewhere in this class.
        decoding_depth: accepted but not used by this class.
        scale: forwarded to ``GaussianFourierFeatureTransform``.
        res: second argument of ``BranchTrunkNet``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 latent_channels,
                 out_channels,
                 decoding_depth,
                 scale=8,
                 res=2048,
                 **kwargs,
                 ):
        super().__init__()
        # Not used by this class; kept so the module's attributes stay unchanged.
        self.layers = nn.ModuleList([])
        self.out_channels = out_channels
        self.latent_channels = latent_channels

        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(1, self.latent_channels, scale=scale),
            nn.GELU(),
            nn.Linear(self.latent_channels*2, self.latent_channels, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels, self.latent_channels, bias=False),
        )

        self.decoding_transformer = BranchTrunkNet(latent_channels,
                                                   res)

    def forward(self,
                z,  # [b, n, c]
                propagate_pos,  # [b, n, 1]
                ):
        """Decode ``z`` using the query positions of the first batch entry.

        Only ``propagate_pos[0]`` is used, i.e. the positions of batch entry 0
        are applied to every sample (presumably all samples share one grid --
        TODO confirm).
        """
        propagate_pos = propagate_pos[0]
        # Call the sub-modules directly (not ``.forward``) so forward hooks run.
        x = self.coordinate_projection(propagate_pos)
        z = self.decoding_transformer(x, z)

        return z  # [b, n, c]


class PieceWiseDecoder2DSimple(nn.Module):
    """2D point-wise decoder with a ReLU output head.

    The original author notes: for the Darcy flow inverse problem.

    Args:
        latent_channels: width of the latent features.
        out_channels: number of output channels per query point.
        res: used as ``min_freq=1/res`` for the ``CrossFormer``.
        scale: forwarded to ``GaussianFourierFeatureTransform``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 latent_channels,
                 out_channels,
                 res=141,
                 scale=0.5,
                 **kwargs,
                 ):
        super().__init__()
        # Not used by this class; kept so the module's attributes stay unchanged.
        self.layers = nn.ModuleList([])
        self.out_channels = out_channels
        self.latent_channels = latent_channels

        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(2, self.latent_channels//2, scale=scale),
            nn.GELU(),
            nn.Linear(self.latent_channels, self.latent_channels, bias=False),
            nn.Dropout(0.05),
        )

        self.decoding_transformer = CrossFormer(self.latent_channels, 'galerkin', 4,
                                                self.latent_channels, self.latent_channels,
                                                use_ln=False,
                                                residual=True,
                                                use_ffn=False,
                                                relative_emb=True,
                                                scale=16,
                                                relative_emb_dim=2,
                                                min_freq=1/res)

        # Input width is latent_channels + 2 because the coordinates are
        # concatenated to the features in ``forward``.
        self.to_out = nn.Sequential(
            nn.Linear(self.latent_channels+2, self.latent_channels, bias=False),
            nn.ReLU(),
            nn.Linear(self.latent_channels, self.latent_channels, bias=False),
            nn.ReLU(),
            nn.Linear(self.latent_channels, self.latent_channels//2, bias=False),
            nn.ReLU(),
            nn.Linear(self.latent_channels//2, self.out_channels, bias=True))

    def decode(self, z):
        """Map features to the output channels via ``self.to_out``."""
        return self.to_out(z)

    def forward(self,
                z,  # [b, n c]
                propagate_pos,  # [b, n, 2] (last dim must be 2 to match ``to_out``)
                input_pos=None,
                ):
        """Decode ``z`` at the query positions.

        Returns:
            ``self.to_out`` applied to the attended features concatenated with
            ``propagate_pos`` on the last axis.
        """
        # Call the sub-modules directly (not ``.forward``) so forward hooks run.
        x = self.coordinate_projection(propagate_pos)
        z = self.decoding_transformer(x, z, propagate_pos, input_pos)

        z = self.decode(torch.cat((z, propagate_pos), dim=-1))
        return z  # [b, n, c]


class NoRelPointWiseDecoder2D(nn.Module):
    """2D point-wise decoder whose cross-attention uses ``relative_emb=False`` and ``cat_pos=True``.

    The cross-attention runs at half the latent width; ``self.expand_feat``
    lifts the result back to ``latent_channels`` before the residual
    propagator and the output head.

    Args:
        latent_channels: width of the latent features.
        out_channels: number of output channels per time step.
        out_steps: number of time steps decoded per propagation round.
        propagator_depth: number of residual blocks in ``self.propagator``.
        scale: forwarded to ``GaussianFourierFeatureTransform``.
        dropout: accepted but not used by this class.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 latent_channels,
                 out_channels,
                 out_steps,
                 propagator_depth,
                 scale=8,
                 dropout=0.,
                 **kwargs,
                 ):
        super().__init__()
        # Not used by this class; kept so the module's attributes stay unchanged.
        self.layers = nn.ModuleList([])
        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        self.coordinate_projection = nn.Sequential(
            GaussianFourierFeatureTransform(2, self.latent_channels//2, scale=scale),
            nn.Linear(self.latent_channels, self.latent_channels),
            nn.GELU(),
            nn.Linear(self.latent_channels, self.latent_channels//2, bias=False),
        )

        self.decoding_transformer = CrossFormer(self.latent_channels//2, 'galerkin', 4,
                                                self.latent_channels//2, self.latent_channels//2,
                                                relative_emb=False,
                                                cat_pos=True,
                                                relative_emb_dim=2,
                                                min_freq=1/64)

        self.expand_feat = nn.Linear(self.latent_channels//2, self.latent_channels)

        # Each propagator entry is a (LayerNorm, MLP) pair; the MLP input is the
        # normalised state concatenated with the 2D position (hence ``+ 2``).
        self.propagator = nn.ModuleList([
               nn.ModuleList([nn.LayerNorm(self.latent_channels),
               nn.Sequential(
                    nn.Linear(self.latent_channels + 2, self.latent_channels, bias=False),
                    nn.GELU(),
                    nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                    nn.GELU(),
                    nn.Linear(self.latent_channels, self.latent_channels, bias=False))])
            for _ in range(propagator_depth)])

        self.to_out = nn.Sequential(
            nn.LayerNorm(self.latent_channels),
            nn.Linear(self.latent_channels, self.latent_channels//2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels // 2, self.latent_channels // 2, bias=False),
            nn.GELU(),
            nn.Linear(self.latent_channels//2, self.out_channels * self.out_steps, bias=True))

    def propagate(self, z, pos):
        """Apply every (norm, MLP) pair as ``z = MLP(concat(norm(z), pos)) + z``."""
        for norm_fn, ffn in self.propagator:
            z = ffn(torch.cat((norm_fn(z), pos), dim=-1)) + z
        return z

    def decode(self, z):
        """Map latent features to ``out_channels * out_steps`` values via ``self.to_out``."""
        return self.to_out(z)

    def get_embedding(self,
                      z,  # [b, n c]
                      propagate_pos,  # [b, n, 2]
                      input_pos
                      ):
        """Decode the encoder output at the query positions and expand it to the latent width."""
        # Call the sub-modules directly (not ``.forward``) so forward hooks run.
        x = self.coordinate_projection(propagate_pos)
        z = self.decoding_transformer(x, z, propagate_pos, input_pos)
        z = self.expand_feat(z)
        return z

    def forward(self,
                z,              # [b, n, c]
                propagate_pos   # [b, n, 2]
                ):
        """Advance the latent state once and decode it.

        Returns:
            ``(u, z)``: ``u`` is the decoded output rearranged from
            ``b n (t c)`` to ``b (t c) n``; ``z`` is the propagated latent state.
        """
        z = self.propagate(z, propagate_pos)
        u = self.decode(z)
        u = rearrange(u, 'b n (t c) -> b (t c) n', c=self.out_channels, t=self.out_steps)
        return u, z                # [b, (t c_out), n], propagated latent

    def rollout(self,
                z,  # [b, n c]
                propagate_pos,  # [b, n, 2]
                forward_steps,
                input_pos):
        """Embed the encoder output once, then roll the latent dynamics forward.

        ``forward_steps // self.out_steps`` propagation rounds are run, so a
        remainder of ``forward_steps`` is dropped. The decoded chunks are
        concatenated along ``dim=-2``.
        """
        # Call the sub-modules directly (not ``.forward``) so forward hooks run.
        x = self.coordinate_projection(propagate_pos)
        z = self.decoding_transformer(x, z, propagate_pos, input_pos)
        z = self.expand_feat(z)

        # forward the dynamics in the latent space
        history = []
        for _ in range(forward_steps//self.out_steps):
            z = self.propagate(z, propagate_pos)
            u = self.decode(z)
            history.append(rearrange(u, 'b n (t c) -> b (t c) n', c=self.out_channels, t=self.out_steps))
        history = torch.cat(history, dim=-2)  # concatenate in temporal dimension
        return history  # [b, length_of_history*c, n]

class StructuredRecursivePropagator(nn.Module):
    """Adaptive-depth propagator that refines a shrinking subset of tokens per layer.

    A router scores every token once; the tokens are ranked by that score and
    layer ``d`` only processes the top ``capacity_ratios[d]`` fraction of them.
    Tokens that are not selected keep their current features.

    Args:
        latent_channels: width of the hidden states.
        n_head: accepted but not used by this class.
        recursion_depth: number of recursion blocks.
        n_inner: accepted but not used by this class.
        context_dim: feature width of the routing context given to the router.
        capacity_ratios: optional sequence with one ratio per layer.
        final_keep_ratio: end point of the default linear decay from 1.0 that
            is used when ``capacity_ratios`` is ``None``.

    Raises:
        ValueError: if ``capacity_ratios`` is given and its length differs from
            ``recursion_depth``.
    """
    def __init__(self, latent_channels, n_head, recursion_depth, n_inner,
                 context_dim,
                 capacity_ratios=None, final_keep_ratio=0.25):
        super().__init__()
        self.__name__ = 'StructuredRecursivePropagator'
        self.recursion_depth = recursion_depth
        self.latent_channels = latent_channels

        # --- a. capacity ratios ---
        if capacity_ratios is not None:
            if len(capacity_ratios) != self.recursion_depth:
                raise ValueError(
                    f"Length of capacity_ratios ({len(capacity_ratios)}) must match recursion_depth ({self.recursion_depth})."
                )
            ratios = capacity_ratios
        else:
            print(f"Warning: capacity_ratios not provided. Defaulting to linear decay from 1.0 to {final_keep_ratio}.")
            ratios = np.linspace(1.0, final_keep_ratio, self.recursion_depth)

        self.register_buffer('capacity_ratios', torch.tensor(ratios, dtype=torch.float32))
        print(f"[{self.__name__}] Capacity Ratios per Layer: {self.capacity_ratios.tolist()}")

        # --- b. core components ---
        # The router is sized for the routing context, not for the hidden state.
        self.router = LightweightRouter(context_dim)

        # The recursion blocks work at the decoder width ``latent_channels``.
        self.recursion_blocks = nn.ModuleList([self._create_block() for _ in range(self.recursion_depth)])

    def _create_block(self):
        """Create one (LayerNorm, MLP) block; the MLP input is the state plus a 2D position."""
        return nn.ModuleList([
            nn.LayerNorm(self.latent_channels),
            nn.Sequential(
                nn.Linear(self.latent_channels + 2, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False)
            )
        ])

    def forward(self, hidden_states, pos_embedding, z_context):
        """
        Args:
            hidden_states (torch.Tensor): hidden state to evolve, shape (B, N, C)
            pos_embedding (torch.Tensor): spatial coordinates, shape (B, N, 2)
            z_context (torch.Tensor): routing context, shape (B, N, C_context)

        Returns:
            tuple: ``(hidden_states, diagnostics)`` where ``diagnostics`` holds
            ``'active_tokens_per_layer'`` and ``'time_per_layer_ms'``, one
            entry per layer.
        """
        import time  # only needed for the CPU timing fallback below

        _, N, C = hidden_states.shape

        diagnostics = {
            'active_tokens_per_layer': [],
            'time_per_layer_ms': []
        }

        # 1. Score and rank all tokens once. The ranking is computed without
        #    gradients, so the router receives no gradient from this path.
        with torch.no_grad():
            router_scores = self.router(z_context).squeeze(-1)
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # CUDA events and torch.cuda.synchronize only work for CUDA tensors;
        # on any other device fall back to a wall-clock timer so the module
        # stays usable on CPU.
        use_cuda_timer = hidden_states.is_cuda
        if use_cuda_timer:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

        # 2. Structured adaptive recursion.
        for depth in range(self.recursion_depth):
            # a. Number of active tokens for this layer, clamped to [1, N].
            k = int(N * self.capacity_ratios[depth])
            k = min(max(k, 1), N)
            active_indices = global_ranking_indices[:, :k]

            num_active = active_indices.shape[1]
            diagnostics['active_tokens_per_layer'].append(num_active)

            if num_active == 0:
                diagnostics['time_per_layer_ms'].append(0.0)
                continue

            # b. Pack: gather the active tokens and their positions.
            active_h = pack_tokens(hidden_states, active_indices)
            active_pos = pack_tokens(pos_embedding, active_indices)

            # c. Compute on the compact active tensor, timing the block.
            if use_cuda_timer:
                torch.cuda.synchronize(hidden_states.device)
                start_event.record()
            else:
                t_start = time.perf_counter()

            norm_fn, ffn = self.recursion_blocks[depth]
            ffn_input = torch.cat((norm_fn(active_h), active_pos), dim=-1)
            # Residual connection.
            block_output_active = ffn(ffn_input) + active_h

            if use_cuda_timer:
                end_event.record()
                torch.cuda.synchronize(hidden_states.device)
                elapsed_ms = start_event.elapsed_time(end_event)
            else:
                elapsed_ms = (time.perf_counter() - t_start) * 1000.0
            diagnostics['time_per_layer_ms'].append(elapsed_ms)

            # d. Scatter: write the results back to the positions of the active tokens.
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, C)
            hidden_states = hidden_states.scatter(1, idx_exp, block_output_active)

        return hidden_states, diagnostics
    
class StructuredRecursivePropagator_simple(nn.Module):
    """
    Core engine of the adaptive computation.

    It takes an initial hidden state and evolves it over several layers,
    where each layer only processes the top-ranked tokens according to a
    static importance ranking computed once by the router.

    Args:
        latent_channels: width of the hidden states.
        recursion_depth: number of recursion blocks.
        context_dim: feature width of the routing context given to the router.
        capacity_ratios: optional sequence with one ratio per layer.
        final_keep_ratio: end point of the default linear decay from 1.0 that
            is used when ``capacity_ratios`` is ``None``.

    Raises:
        ValueError: if ``capacity_ratios`` is given and its length differs from
            ``recursion_depth``.
    """
    def __init__(self, latent_channels, recursion_depth, context_dim,
                 capacity_ratios=None, final_keep_ratio=0.25):
        super().__init__()
        self.recursion_depth = recursion_depth
        self.latent_channels = latent_channels

        # --- a. capacity ratios ---
        if capacity_ratios is not None:
            if len(capacity_ratios) != self.recursion_depth:
                raise ValueError(f"capacity_ratios的长度({len(capacity_ratios)})必须与recursion_depth({self.recursion_depth})匹配。")
            ratios = capacity_ratios
        else:
            ratios = np.linspace(1.0, final_keep_ratio, self.recursion_depth)

        self.register_buffer('capacity_ratios', torch.tensor(ratios, dtype=torch.float32))

        # --- b. core components ---
        self.router = LightweightRouter(context_dim)

        # The blocks use the same layer sizes as RefinementBlock
        # (latent + 2 -> 4 * latent -> latent), for a fair comparison.
        self.recursion_blocks = nn.ModuleList([self._create_block() for _ in range(self.recursion_depth)])

    def _create_block(self):
        """Create one (LayerNorm, MLP) unit with the same layer sizes as RefinementBlock."""
        return nn.ModuleList([
            nn.LayerNorm(self.latent_channels),
            nn.Sequential(
                nn.Linear(self.latent_channels + 2, self.latent_channels * 4, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels * 4, self.latent_channels, bias=False)
            )
        ])

    def forward(self, hidden_states, pos_embedding, z_context):
        """Refine ``hidden_states`` layer by layer on the top-ranked tokens.

        Args:
            hidden_states: hidden state, unpacked as (B, N, C).
            pos_embedding: coordinates concatenated to the active tokens.
            z_context: routing context given to the router.

        Returns:
            tuple: ``(hidden_states, diagnostics)``;
            ``diagnostics['active_tokens_per_layer']`` lists the number of
            tokens actually processed in each layer.
        """
        _, N, C = hidden_states.shape
        diagnostics = {'active_tokens_per_layer': []}

        # 1. Score and rank all tokens once, without gradients.
        with torch.no_grad():
            router_scores = self.router(z_context).squeeze(-1)  # (B, N)
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # 2. Structured adaptive recursion.
        for depth in range(self.recursion_depth):
            # a. Capacity of this layer, clamped to the number of tokens so
            #    that the diagnostics report the real count even for ratios > 1.
            k = min(N, max(1, int(N * self.capacity_ratios[depth])))
            active_indices = global_ranking_indices[:, :k]
            diagnostics['active_tokens_per_layer'].append(active_indices.shape[1])

            # b. Pack: gather the active tokens and their positions.
            active_h = pack_tokens(hidden_states, active_indices)
            active_pos = pack_tokens(pos_embedding, active_indices)

            # c. Compute on the compact active subset (with residual connection).
            norm_fn, ffn = self.recursion_blocks[depth]
            ffn_input = torch.cat((norm_fn(active_h), active_pos), dim=-1)
            block_output_active = active_h + ffn(ffn_input)

            # d. Scatter: write the results back; tokens that were not selected
            #    keep their previous features.
            scatter_index = active_indices.unsqueeze(-1).expand(-1, -1, C)
            hidden_states = hidden_states.scatter(1, scatter_index, block_output_active)

        return hidden_states, diagnostics
# ==============================================================================
# 3. 新的、集成了轻量级自适应传播器的OFormer Decoder (最终版)
# ==============================================================================
class StructuredRecursivePropagator1D(nn.Module):
    """
    Adaptive-depth feature propagator for 1D tasks.

    A router scores every token once; each recursion layer then refines only
    the top-ranked fraction of tokens given by ``capacity_ratios`` and
    scatters the results back, leaving the remaining tokens untouched.
    The computation blocks take ``latent_channels + 1`` input features
    because the coordinate is one-dimensional.
    """
    def __init__(self,
                 latent_channels,       # e.g. 96
                 recursion_depth,       # number of adaptive recursion layers
                 context_dim,           # feature size of the routing context
                 capacity_ratios=None,
                 final_keep_ratio=0.25,
                 **kwargs):
        """
        Args:
            latent_channels: width C of the hidden states.
            recursion_depth: number of recursion layers (one block each).
            context_dim: passed to ``LightweightRouter``.
            capacity_ratios: optional sequence with one ratio per layer.
                When omitted, a linear decay from 1.0 to ``final_keep_ratio``
                is used.
            final_keep_ratio: last ratio of the default linear schedule.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``capacity_ratios`` is given and its length does
                not equal ``recursion_depth``.
        """
        super().__init__()
        self.__name__ = 'StructuredRecursivePropagator1D'
        self.recursion_depth = recursion_depth
        self.latent_channels = latent_channels

        # --- a. Resolve the per-layer capacity schedule ---
        if capacity_ratios is not None:
            if len(capacity_ratios) != self.recursion_depth:
                raise ValueError(
                    f"capacity_ratios 的长度 ({len(capacity_ratios)}) 必须与 recursion_depth ({self.recursion_depth}) 相匹配。"
                )
            ratios = capacity_ratios
        else:
            print(f"警告: 未提供 capacity_ratios。将默认使用从 1.0 到 {final_keep_ratio} 的线性衰减。")
            ratios = np.linspace(1.0, final_keep_ratio, self.recursion_depth)

        # Stored as a buffer so it follows the module across devices.
        self.register_buffer('capacity_ratios', torch.tensor(ratios, dtype=torch.float32))
        print(f"[{self.__name__}] 每层容量比例: {self.capacity_ratios.tolist()}")

        # --- b. Router that scores tokens from the context ---
        self.router = LightweightRouter(context_dim)

        # --- c. One computation block per recursion layer ---
        self.recursion_blocks = nn.ModuleList([self._create_block() for _ in range(self.recursion_depth)])

    def _create_block(self):
        """
        Build one computation block for 1D coordinates.

        Returns:
            ``nn.ModuleList([LayerNorm(C), Sequential(...)])`` where the
            sequential part is a bias-free three-layer GELU MLP mapping
            ``C + 1`` features (normalised state plus the 1D coordinate)
            back to ``C``.
        """
        return nn.ModuleList([
            nn.LayerNorm(self.latent_channels),
            nn.Sequential(
                # Input width is latent_channels + 1 because the coordinate is 1D.
                nn.Linear(self.latent_channels + 1, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False),
                nn.GELU(),
                nn.Linear(self.latent_channels, self.latent_channels, bias=False)
            )
        ])

    def forward(self, hidden_states, pos_embedding, z_context):
        """
        Run the adaptive recursion and collect per-layer diagnostics.

        Args:
            hidden_states (torch.Tensor): hidden state, shape (B, N, C).
            pos_embedding (torch.Tensor): coordinates, shape (B, N, 1);
                the last dim must be exactly 1.
            z_context (torch.Tensor): routing context fed to ``self.router``.

        Returns:
            Tuple ``(hidden_states, diagnostics)``. ``diagnostics`` holds
            ``'active_tokens_per_layer'`` and ``'time_per_layer_ms'``, one
            entry per recursion layer. Timing uses CUDA events (with a
            device synchronisation around each block) for CUDA inputs and
            ``time.perf_counter`` otherwise, so the module also runs on CPU.

        Raises:
            ValueError: if ``pos_embedding.shape[-1] != 1``.
        """
        import time  # local import: the file header does not import it

        if pos_embedding.shape[-1] != 1:
            raise ValueError(f"StructuredRecursivePropagator1D 期望一个1维的位置编码，但收到了 {pos_embedding.shape[-1]} 维。")

        _, N, C = hidden_states.shape

        diagnostics = {
            'active_tokens_per_layer': [],
            'time_per_layer_ms': []
        }

        # Global routing: score once, rank once. Done under no_grad, so the
        # router receives no gradient from this forward pass.
        with torch.no_grad():
            router_scores = self.router(z_context).squeeze(-1)
            _, global_ranking_indices = torch.sort(router_scores, dim=1, descending=True)

        # CUDA events only exist/work for CUDA tensors; creating them or
        # calling torch.cuda.synchronize for a CPU tensor raises.
        use_cuda_timing = hidden_states.is_cuda
        if use_cuda_timing:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

        for depth in range(self.recursion_depth):
            # Same float32 tensor arithmetic as before so k is unchanged.
            k = max(int(N * self.capacity_ratios[depth]), 1)
            active_indices = global_ranking_indices[:, :k]

            num_active = active_indices.shape[1]
            diagnostics['active_tokens_per_layer'].append(num_active)

            # Only reachable when the input has zero tokens (k is >= 1).
            if num_active == 0:
                diagnostics['time_per_layer_ms'].append(0.0)
                continue

            # Pack the selected tokens and their coordinates.
            active_h = pack_tokens(hidden_states, active_indices)
            active_pos = pack_tokens(pos_embedding, active_indices)

            if use_cuda_timing:
                torch.cuda.synchronize(hidden_states.device)
                start_event.record()
            else:
                start_time = time.perf_counter()

            # Pre-norm FFN on [features, coordinate] with a residual connection.
            norm_fn, ffn = self.recursion_blocks[depth]
            ffn_input = torch.cat((norm_fn(active_h), active_pos), dim=-1)
            block_output_active = ffn(ffn_input) + active_h

            if use_cuda_timing:
                end_event.record()
                torch.cuda.synchronize(hidden_states.device)
                elapsed_ms = start_event.elapsed_time(end_event)
            else:
                elapsed_ms = (time.perf_counter() - start_time) * 1000.0
            diagnostics['time_per_layer_ms'].append(elapsed_ms)

            # Scatter results back; unselected tokens keep their value.
            idx_exp = active_indices.unsqueeze(-1).expand(-1, -1, C)
            hidden_states = hidden_states.scatter(1, idx_exp, block_output_active)

        return hidden_states, diagnostics
class PointWiseDecoder2D_Adaptive(nn.Module):
    """
    OFormer-style 2D point-wise decoder driven by an adaptive propagator.

    Query coordinates are Fourier-encoded, fused with the projected encoder
    output through cross attention, evolved by a
    ``StructuredRecursivePropagator`` and finally decoded to
    ``out_channels * out_steps`` values per point.
    """
    def __init__(self,
                 latent_channels,   # from --decoder_emb_dim
                 out_channels,      # from --out_channels
                 out_steps,         # from --out_step
                 propagator_depth,  # from --propagator_depth, used as recursion depth
                 scale=8,
                 dropout=0.,
                 **kwargs):
        """
        Args:
            latent_channels: working width of the decoder.
            out_channels: number of physical channels per time step.
            out_steps: number of time steps decoded per propagator call.
            propagator_depth: recursion depth of the adaptive propagator.
            scale: scale passed to the Fourier feature transform.
            dropout: attention dropout probability.
            **kwargs: optional ``out_seq_emb_dim``, ``encoder_heads``,
                ``capacity_ratios`` and ``final_keep_ratio``.
        """
        super().__init__()
        self.__name__ = 'PointWiseDecoder2D_Adaptive'

        self.out_channels = out_channels
        self.out_steps = out_steps
        self.latent_channels = latent_channels

        width = self.latent_channels
        num_heads = kwargs.get('encoder_heads', 4)

        # --- a. Input side: Fourier features of the 2D query coordinates ---
        self.coordinate_projection = GaussianFourierFeatureTransform(2, width // 2, scale=scale)

        # --- b. Cross attention ---
        # Adapter that maps the encoder output to the decoder width.
        self.context_projection = nn.Linear(kwargs.get('out_seq_emb_dim', width), width)

        # Queries come from the coordinate encoder, keys/values from the encoder context.
        self.decoding_transformer = XFAttention(n_embd=width, n_head=num_heads, attn_pdrop=dropout)
        self.cross_attn_norm = nn.LayerNorm(width)

        # --- c. Adaptive propagator used as the evolution engine ---
        self.propagator = StructuredRecursivePropagator(
            latent_channels=width,
            n_head=num_heads,
            recursion_depth=propagator_depth,
            n_inner=width * 4,
            capacity_ratios=kwargs.get('capacity_ratios'),
            final_keep_ratio=kwargs.get('final_keep_ratio', 0.25),
            context_dim=width,
        )

        # --- d. Output head ---
        head_layers = [
            nn.LayerNorm(width),
            nn.Linear(width, width * 2),
            nn.GELU(),
            nn.Linear(width * 2, self.out_channels * self.out_steps),
        ]
        self.to_out = nn.Sequential(*head_layers)

    def get_embedding(self, z, propagate_pos):
        """
        Encode the query coordinates and fuse them with the encoder context.

        Args:
            z: encoder output fed to ``context_projection``.
            propagate_pos: query coordinates fed to ``coordinate_projection``
                (presumably (B, N, 2) -- TODO confirm against the caller).

        Returns:
            Tuple ``(h, z_context)``: the cross-attention output and the
            projected encoder context.
        """
        # Coordinates first, then the context projection.
        queries = self.coordinate_projection(propagate_pos)
        context = self.context_projection(z)

        # Normalised queries attend to the projected context.
        fused = self.decoding_transformer(self.cross_attn_norm(queries), y=context)
        return fused, context

    def decode(self, z):
        """Map hidden states to outputs by applying the ``to_out`` head."""
        return self.to_out(z)

    def rollout(self, z, prop_pos, forward_steps, input_pos):
        """
        Autoregressively predict ``forward_steps // out_steps`` chunks.

        Args:
            z: static encoder context.
            prop_pos: query coordinates; also handed to the propagator as
                ``pos_embedding``.
            forward_steps: total number of time steps requested.
            input_pos: not used by this method.

        Returns:
            Tuple ``(prediction, diagnostics)``. ``prediction`` is laid out
            as ``'b (t c) n'``; ``diagnostics`` is the propagator report of
            the first autoregressive step (``None`` if no step ran).
        """
        # Initial state plus the routing context shared by all steps.
        state, routing_context = self.get_embedding(z, prop_pos)

        first_report = None
        chunks = []
        for step in range(forward_steps // self.out_steps):
            # Evolve the state; the propagator also returns its diagnostics.
            state, step_report = self.propagator(state, pos_embedding=prop_pos, z_context=routing_context)

            # Only the report of the very first step is kept.
            if step == 0:
                first_report = step_report

            decoded = self.decode(state)
            chunks.append(
                rearrange(decoded, 'b n (t c) -> b c t n', c=self.out_channels, t=self.out_steps)
            )

        # Concatenate along time, then flatten (time, channel) together.
        stacked = torch.cat(chunks, dim=2)
        return rearrange(stacked, 'b c t n -> b (t c) n'), first_report
class AdaptivePointWiseDecoder1D(nn.Module):
    """
    Adaptive variant of the 1D point-wise decoder (used for the Burgers task).

    The evolution engine is a ``StructuredRecursivePropagator1D``; the
    coordinate projection, cross attention and output head follow the
    layout built in ``__init__``.
    """
    def __init__(self,
                 latent_channels,
                 out_channels,
                 propagator_depth,  # adaptive recursion depth
                 capacity_ratios,
                 res=2048,
                 scale=2,
                 **kwargs):
        """
        Args:
            latent_channels: working width of the decoder.
            out_channels: number of output channels per point.
            propagator_depth: recursion depth of the 1D propagator.
            capacity_ratios: per-layer keep ratios forwarded to the propagator.
            res: resolution; ``1 / res`` is passed as ``min_freq`` to the
                cross-attention module.
            scale: scale of the Fourier feature transform.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.__name__ = 'AdaptivePointWiseDecoder1D'
        self.out_channels = out_channels
        self.latent_channels = latent_channels

        # --- a. Information fusion: coordinate encoding + cross attention ---
        coord_layers = [
            GaussianFourierFeatureTransform(1, latent_channels, scale=scale),
            nn.GELU(),
            nn.Linear(latent_channels * 2, latent_channels, bias=False),
        ]
        self.coordinate_projection = nn.Sequential(*coord_layers)
        self.decoding_transformer = CrossFormer(
            self.latent_channels, 'fourier', 8,
            self.latent_channels, self.latent_channels,
            relative_emb=True,
            scale=1,
            relative_emb_dim=1,
            min_freq=1/res,
        )

        # --- b. 1D adaptive propagator; routing context has latent width ---
        self.propagator = StructuredRecursivePropagator1D(
            latent_channels=latent_channels,
            recursion_depth=propagator_depth,
            context_dim=latent_channels,
            capacity_ratios=capacity_ratios,
        )

        # --- c. Output head ---
        half_width = latent_channels // 2
        self.to_out = nn.Sequential(
            nn.Linear(latent_channels, half_width, bias=False),
            nn.GELU(),
            nn.Linear(half_width, out_channels, bias=True),
        )

    def forward(self, z, prop_pos, input_pos=None):
        """
        Map the encoder output to the decoded field in one pass.

        Args:
            z: encoder output handed to the cross-attention module.
            prop_pos: query coordinates; also used as the propagator's
                ``pos_embedding``.
            input_pos: optional input coordinates forwarded to the
                cross-attention module.

        Returns:
            Tuple ``(output, diagnostics)``: the result of ``to_out`` on the
            propagated state and the propagator's diagnostics.
        """
        # Fuse encoder output and coordinate encoding into the initial state.
        queries = self.coordinate_projection(prop_pos)
        h_0 = self.decoding_transformer(queries, z, prop_pos, input_pos)

        # The initial state is both the state to evolve and, detached,
        # the routing context.
        h_final, diagnostics = self.propagator(
            hidden_states=h_0,
            pos_embedding=prop_pos,
            z_context=h_0.detach(),
        )

        return self.to_out(h_final), diagnostics