import torch
import torch.nn as nn
from torchdiffeq import odeint
import math
import torch.nn.functional as F

# Other normalization
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Every vector along the last axis is divided by ``sqrt(mean(x**2) + eps)``,
    multiplied by a learnable ``weight`` and, when enabled, shifted by a
    learnable ``bias``.

    Args:
        normalized_shape: int or tuple; shape of ``weight`` (and ``bias``).
            NOTE(review): the RMS statistic is always taken over the last
            dimension only, even when a multi-dimensional shape is given.
        eps: constant added to the mean square before the square root.
        bias: whether to learn a zero-initialised additive bias.
    """

    def __init__(self, normalized_shape, eps=1e-8, bias=True):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(shape))
        else:
            self.bias = None

    def forward(self, x):
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        scaled = (x / rms) * self.weight
        if self.bias is None:
            return scaled
        return scaled + self.bias

# RoPE position encoding
def apply_rope(x, sin, cos):
    """
    Rotate interleaved feature pairs of ``x`` by position-dependent angles (RoPE).

    Feature pairs are (0, 1), (2, 3), ...; pair j at position p is rotated by the
    angle whose sine/cosine are ``sin[p, 2j]`` / ``cos[p, 2j]``.

    x: [B, L, H, rotary_dim]
    sin, cos: [L, rotary_dim]
    Returns: tensor with the same shape as ``x``.
    """
    # One angle per pair, broadcast over the batch (dim 0) and head (dim 2) axes.
    s = sin[None, :, None, 0::2]   # [1, L, 1, rotary_dim // 2]
    c = cos[None, :, None, 0::2]   # [1, L, 1, rotary_dim // 2]

    even = x[..., 0::2]
    odd = x[..., 1::2]
    rotated_even = even * c - odd * s
    rotated_odd = even * s + odd * c

    # Re-interleave [..., D/2, 2] -> [..., D]
    paired = torch.stack((rotated_even, rotated_odd), dim=-1)
    return paired.flatten(start_dim=-2)


def build_rope_sin_cos(seq_len, rotary_dim, device, base=10000):
    """
    Build the sine/cosine tables consumed by ``apply_rope``.

    Pair j uses inverse frequency ``base ** (-2j / rotary_dim)``; the angle at
    position p is ``p * inv_freq[j]``. Each angle is repeated twice so the
    tables align with interleaved feature pairs.

    Args:
        seq_len: number of positions (0 .. seq_len - 1).
        rotary_dim: number of rotated features (expected to be even).
        device: device on which the tables are allocated.
        base: frequency base.

    Returns:
        sin: [seq_len, rotary_dim] on ``device``
        cos: [seq_len, rotary_dim] on ``device``
    """
    exponents = torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim
    inv_freq = 1.0 / (base ** exponents)                      # [rotary_dim // 2]
    positions = torch.arange(seq_len, device=device).float()   # [seq_len]

    angles = positions[:, None] * inv_freq[None, :]            # [seq_len, rotary_dim // 2]
    angles = angles.repeat_interleave(2, dim=-1)               # [seq_len, rotary_dim]
    return torch.sin(angles), torch.cos(angles)


# Encoder families
class GRUUpdateEnhanced(nn.Module):
    """
    A more expressive masked GRU-style cell.

    - A shared MLP (Linear + ReLU) is applied to ``[h, x]`` before the update and
      reset gates, and reused on ``[r * h, x]`` before the candidate.
    - Separate linear layers produce the update (z), reset (r) and candidate (n)
      pre-activations.
    - Optional LayerNorm on each pre-activation (``use_layernorm=True``).

    Rows whose mask entry is 0 keep their previous hidden state.
    """

    def __init__(self, input_dim: int, rnn_hidden_dim: int, hidden_units: int = 128, use_layernorm=False):
        super().__init__()
        self.input_dim = input_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        self.use_layernorm = use_layernorm

        # Shared MLP feeding every gate; its input is [h, x] (or [r * h, x]).
        self.mlp = nn.Sequential(
            nn.Linear(input_dim + rnn_hidden_dim, hidden_units),
            nn.ReLU()
        )

        # Separate gate layers
        self.linear_z = nn.Linear(hidden_units, rnn_hidden_dim)
        self.linear_r = nn.Linear(hidden_units, rnn_hidden_dim)
        self.linear_n = nn.Linear(hidden_units, rnn_hidden_dim)
        # Optional normalization of each gate pre-activation
        if use_layernorm:
            self.ln_z = nn.LayerNorm(rnn_hidden_dim)
            self.ln_r = nn.LayerNorm(rnn_hidden_dim)
            self.ln_n = nn.LayerNorm(rnn_hidden_dim)

    def _preact(self, linear: nn.Linear, norm_name: str, hidden: torch.Tensor) -> torch.Tensor:
        # Linear pre-activation, followed by the named LayerNorm when enabled.
        out = linear(hidden)
        return getattr(self, norm_name)(out) if self.use_layernorm else out

    def forward(self, h: torch.Tensor, x: torch.Tensor, mask_1d: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: [B, rnn_hidden_dim] previous hidden state.
            x: [B, input_dim] input at this step.
            mask_1d: [B] observation mask (bool, integer or float); 1/True applies
                the update, 0/False keeps ``h`` unchanged.

        Returns:
            [B, rnn_hidden_dim] new hidden state.
        """
        hidden = self.mlp(torch.cat([h, x], dim=-1))            # [B, hidden_units]

        # Gate computations
        z = torch.sigmoid(self._preact(self.linear_z, "ln_z", hidden))
        r = torch.sigmoid(self._preact(self.linear_r, "ln_r", hidden))

        # Candidate hidden state: the shared MLP is reused on the reset-gated state.
        hidden_n = self.mlp(torch.cat([r * h, x], dim=-1))
        n = torch.tanh(self._preact(self.linear_n, "ln_n", hidden_n))

        # GRU update
        h_new = (1 - z) * n + z * h

        # Cast the mask to h's dtype: `1 - bool_tensor` raises in PyTorch, and a
        # float64 mask would otherwise silently promote the hidden state.
        mask = mask_1d.unsqueeze(-1).to(h.dtype)                 # [B, 1]
        return mask * h_new + (1 - mask) * h
class GRUUpdate(nn.Module):
    """
    Masked GRU cell (reset-before-candidate variant) built on one fused Linear.

    ``self.linear`` holds three row blocks ``[W_z | W_r | W_n]`` of height
    ``rnn_hidden_dim``, each acting on a concatenation of state and input:

        z = sigmoid(W_z [h, x] + b_z)        # update gate
        r = sigmoid(W_r [h, x] + b_r)        # reset gate
        n = tanh(W_n [r * h, x] + b_n)       # candidate
        h_new = (1 - z) * n + z * h

    Rows whose mask entry is 0 keep their previous hidden state.
    """

    def __init__(self, input_dim: int, rnn_hidden_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        self.linear = nn.Linear(input_dim + rnn_hidden_dim, 3 * rnn_hidden_dim)

    def forward(self, h: torch.Tensor, x: torch.Tensor, mask_1d: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: [B, rnn_hidden_dim] previous hidden state.
            x: [B, input_dim] input at this step.
            mask_1d: [B] observation mask (bool, integer or float); 1/True applies
                the update, 0/False keeps ``h`` unchanged.

        Returns:
            [B, rnn_hidden_dim] new hidden state.
        """
        H = self.rnn_hidden_dim
        weight, bias = self.linear.weight, self.linear.bias

        # Only the z/r row blocks act on [h, x]; evaluating just those rows avoids
        # computing (and discarding) the n block here.
        zr = F.linear(torch.cat([h, x], dim=-1), weight[:2 * H], bias[:2 * H])   # [B, 2H]
        z, r = torch.sigmoid(zr).chunk(2, dim=-1)

        # Candidate uses only the n row block, applied to the reset-gated state.
        n = torch.tanh(F.linear(torch.cat([r * h, x], dim=-1), weight[2 * H:], bias[2 * H:]))

        h_new = (1 - z) * n + z * h                          # GRU update

        # Cast the mask to h's dtype: `1 - bool_tensor` raises in PyTorch, and a
        # float64 mask would otherwise silently promote the hidden state.
        mask = mask_1d.unsqueeze(-1).to(h.dtype)              # [B, 1]
        return mask * h_new + (1 - mask) * h
class GRUEncoder(nn.Module):
    """
    A masked GRU encoder that encodes input sequences into latent states z0
    for multiple prefix times, using a step-by-step (manual) masked GRU update.
    """
    def __init__(self, gru_update_module, latent_dim, rnn_hidden_dim):
        """
        Args:
            gru_update_module: module called as ``gru_update(h, x_t, mask_t)`` that
                returns the next hidden state [B, rnn_hidden_dim].
            latent_dim: size of each z0 vector.
            rnn_hidden_dim: size of the recurrent hidden state.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.rnn_hidden_dim = rnn_hidden_dim

        # Manually updated GRU cell with masking
        self.gru_update = gru_update_module

        # Project h(t) to z0
        self.projection_layer = nn.Linear(rnn_hidden_dim, latent_dim)

        # LayerNorm modules. z0_layernorm is currently not applied in forward but
        # is kept so existing checkpoints still load.
        self.h_layernorm = nn.LayerNorm(rnn_hidden_dim)
        self.z0_layernorm = nn.LayerNorm(latent_dim)

    def forward(self, x, t_scalar, mask, prefix_time_list):
        """
        Args:
            x: [B, T, D]          - Input sequences
            t_scalar: [T]         - Time vector (shared across batch)
            mask: [B, T]          - Observation mask (1 = observed)
            prefix_time_list: [S] - Prefix times at which to extract z0
                                    (sorted internally, ascending)

        Returns:
            z0_all:  [S, B, latent_dim]      - projected states, ascending prefix order
            h_T_all: [S, B, rnn_hidden_dim]  - LayerNorm-ed hidden states at each prefix
        """
        B, T, _ = x.size()
        device = x.device
        h = torch.zeros(B, self.rnn_hidden_dim, device=device)

        # Sort and prepare prefix times
        if not isinstance(prefix_time_list, torch.Tensor):
            prefix_time_list = torch.tensor(prefix_time_list, device=device)
        prefix_time_list = torch.sort(prefix_time_list)[0]

        # Only integrate up to the maximum prefix time. Always run at least one
        # step: prefix times earlier than t_scalar[0] are clamped to step 0 below,
        # and with zero steps h_list would be empty and indexing it would fail.
        t_max = prefix_time_list[-1].item()
        T_max_idx = max(1, int((t_scalar <= t_max).sum().item()))

        h_list = []
        for i in range(T_max_idx):
            h = self.gru_update(h, x[:, i, :], mask[:, i])   # Masked GRU update
            h_list.append(h.clone())

        t_scalar_crop = t_scalar[:T_max_idx]
        z0_list = []
        h_T_list = []
        for t_prefix in prefix_time_list:
            # Last step whose time is <= t_prefix (floor, not nearest), clamped into range.
            idx = int((t_scalar_crop <= t_prefix).sum().item()) - 1
            idx = max(0, min(idx, T_max_idx - 1))
            h_cut = self.h_layernorm(h_list[idx])            # [B, rnn_hidden_dim]
            z0_list.append(self.projection_layer(h_cut))      # [B, latent_dim]
            h_T_list.append(h_cut)

        z0_all = torch.stack(z0_list, dim=0)     # [S, B, latent_dim]
        h_T_all = torch.stack(h_T_list, dim=0)   # [S, B, rnn_hidden_dim]
        return z0_all, h_T_all
class GRUDeltaEncoderEnhanced(nn.Module):
    """
    GRU encoder with Δt input (no masking, full parallelism).
    Runs nn.GRU only up to the largest prefix time and extracts z0_all / h_T_all
    at each prefix time.
    """
    def __init__(self, input_dim, latent_dim, rnn_hidden_dim, num_layers=2, dropout=0.1):
        super().__init__()
        self.latent_dim = latent_dim
        self.rnn_hidden_dim = rnn_hidden_dim

        # GRU accepts [B, T, input_dim + 1], where last dim is Δt
        self.gru = nn.GRU(
            input_size=input_dim + 1,
            hidden_size=rnn_hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )

        # Projection from GRU hidden state to latent z0
        self.projection_layer = nn.Linear(rnn_hidden_dim, latent_dim)

    def forward(self, x, t_scalar, prefix_time_list):
        """
        Args:
            x: [B, T, D]          - Input sequences
            t_scalar: [T]         - Time vector (shared across batch)
            prefix_time_list: [S] - Prefix times at which to extract z0
                                    (sorted internally, ascending)

        Returns:
            z0_all: [S, B, latent_dim]
            h_T_all: [S, B, rnn_hidden_dim]
        """
        B = x.shape[0]
        device = x.device

        # Sort and prepare prefix times
        if not isinstance(prefix_time_list, torch.Tensor):
            prefix_time_list = torch.tensor(prefix_time_list, device=device)
        prefix_time_list = torch.sort(prefix_time_list)[0]  # [S]

        # nn.GRU is unidirectional, so h_seq[:, i] depends only on steps 0..i;
        # steps after the largest prefix time cannot affect any extracted state.
        # Keep at least one step so prefixes before t_scalar[0] map to step 0.
        t_max = prefix_time_list[-1].item()
        T_max_idx = max(1, int((t_scalar <= t_max).sum().item()))
        t_crop = t_scalar[:T_max_idx]

        # Δt with a leading zero for the first step. Match x's dtype/device so a
        # float64 (or CPU) time grid does not promote/misplace the GRU input.
        delta_t = F.pad(t_crop[1:] - t_crop[:-1], (1, 0))           # [T_max_idx]
        delta_t = delta_t.to(device=device, dtype=x.dtype)
        delta_t = delta_t.view(1, -1, 1).expand(B, -1, -1)           # [B, T_max_idx, 1]

        x_with_delta = torch.cat([x[:, :T_max_idx, :], delta_t], dim=-1)  # [B, T_max_idx, D + 1]
        h_seq, _ = self.gru(x_with_delta)                             # [B, T_max_idx, H]

        # Extract z0/h_T at specific prefix times
        z0_list = []
        h_T_list = []
        for t_prefix in prefix_time_list:
            # Last step whose time is <= t_prefix (floor), clamped into range.
            idx = int((t_crop <= t_prefix).sum().item()) - 1
            idx = max(0, min(idx, T_max_idx - 1))
            h_cut = h_seq[:, idx, :]                                  # [B, H]
            z0_list.append(self.projection_layer(h_cut))
            h_T_list.append(h_cut)

        z0_all = torch.stack(z0_list, dim=0)                          # [S, B, latent_dim]
        h_T_all = torch.stack(h_T_list, dim=0)                        # [S, B, rnn_hidden_dim]
        return z0_all, h_T_all
class RecognitionODEFunc(nn.Module):
    """
    Time-conditioned vector field for the ODE-RNN recognition network.

    Computes dh/dt = f([h, t]) with a tanh MLP; both the state and its
    derivative live in the ``rnn_hidden_dim`` space, with one extra scalar
    input column carrying the time.
    """
    def __init__(self, rnn_hidden_dim, hidden_units=64):
        super().__init__()
        layers = [
            nn.Linear(rnn_hidden_dim + 1, hidden_units),   # +1 input for time
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, rnn_hidden_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, h):
        """
        Args:
            t: scalar (tensor or number) current time.
            h: [Batch, rnn_hidden_dim] RNN hidden state.

        Returns:
            [Batch, rnn_hidden_dim] derivative dh/dt.
        """
        batch = h.shape[0]
        time_column = t * torch.ones(batch, 1, device=h.device)
        features = torch.cat((h, time_column), dim=1)
        return self.net(features)
class ODERNNEncoder(nn.Module):
    """
    An encoder that uses an ODE-RNN mechanism to encode sequences into latent
    states z0 for multiple prefix (cutoff) times in a single pass.

    Between consecutive observations the hidden state is evolved by
    ``RecognitionODEFunc``; at every observation it is updated by a masked
    ``GRUUpdate``.
    """
    def __init__(self, input_dim, latent_dim, rnn_hidden_dim, hidden_units, min_step=1e-3, method='rk4'):
        super().__init__()

        # ODE function in the RNN hidden space
        self.ode_func = RecognitionODEFunc(rnn_hidden_dim, hidden_units=hidden_units)
        self.gru_update = GRUUpdate(input_dim, rnn_hidden_dim)

        self.latent_dim = latent_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        self.min_step = min_step   # gaps shorter than this use one Euler step
        self.method = method       # odeint integration method for longer gaps

        # Project h(t) to z0
        self.projection_layer = nn.Linear(rnn_hidden_dim, latent_dim)

    def forward(self, x, t_scalar, mask, prefix_time_list):
        """
        Args:
            x: [B, T, D]                - Observed inputs
            t_scalar: [T]               - Time vector shared across batch
            mask: [B, T]                - Observation mask
            prefix_time_list: Tensor[S] - Prefix times at which to extract z0
                                          (sorted internally, ascending)

        Returns:
            z0_all: [S, B, latent_dim]
        """
        B = x.size(0)
        device = x.device
        h = torch.zeros(B, self.rnn_hidden_dim, device=device)

        # Ensure prefix_time_list is a sorted torch tensor
        if not isinstance(prefix_time_list, torch.Tensor):
            prefix_time_list = torch.tensor(prefix_time_list, device=device)
        prefix_time_list = torch.sort(prefix_time_list)[0]

        # Only integrate up to the maximum prefix time. Always run at least one
        # step: prefix times earlier than t_scalar[0] are clamped to step 0 below,
        # and with zero steps h_list would be empty and indexing it would fail.
        t_max = prefix_time_list[-1].item()
        T_max_idx = max(1, int((t_scalar <= t_max).sum().item()))

        h_list = []
        for i in range(T_max_idx):
            if i > 0:
                t_prev, t_now = t_scalar[i - 1], t_scalar[i]
                dt = (t_now - t_prev).item()

                if dt < self.min_step:
                    # Gap too small for a solver grid: one explicit Euler step.
                    h = h + self.ode_func(t_prev, h) * dt
                else:
                    n_points = max(2, int(dt / self.min_step))
                    t_points = torch.linspace(t_prev.item(), t_now.item(), n_points, device=device)
                    h = odeint(self.ode_func, h, t_points, method=self.method)[-1]

            # Discrete GRU update at the observation
            h = self.gru_update(h, x[:, i, :], mask[:, i])
            h_list.append(h.clone())

        # For each prefix time take the last step with time <= t_prefix (floor).
        z0_list = []
        for t_prefix in prefix_time_list:
            idx = int((t_scalar <= t_prefix).sum().item()) - 1
            idx = max(0, min(idx, len(h_list) - 1))
            z0_list.append(self.projection_layer(h_list[idx]))   # [B, latent_dim]

        z0_all = torch.stack(z0_list, dim=0)  # [S, B, latent_dim]
        return z0_all


# Embedder families
class Z0ToEmbedMLP(nn.Module):
    """
    Embed a latent state together with a token time.

    ``z0`` ([..., latent_dim]) and ``t_token`` ([..., 1]) are concatenated and
    passed through Linear -> Tanh -> Linear, giving ``[..., embed_dim]``.
    """
    def __init__(self, latent_dim, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + 1, embed_dim),   # one extra input column for time
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, z0, t_token):
        """
        z0: [..., latent_dim]
        t_token: [..., 1] (leading dims must equal those of z0)
        Raises ValueError on a shape mismatch.
        """
        expected_shape = z0.shape[:-1] + (1,)
        if t_token.shape != expected_shape:
            raise ValueError(f"t_token must match z0 shape[:-1] + (1,), got {t_token.shape} vs {z0.shape}")
        joined = torch.cat((z0, t_token), dim=-1)   # [..., latent_dim + 1]
        return self.net(joined)
class TimeModulation(nn.Module):
    """
    Map a scalar time to a pair of modulation vectors (gamma, beta).

    t -> tanh(fc1) -> tanh(fc2), then split in half along the last axis, so
    both outputs lie in (-1, 1).
    """
    def __init__(self, embed_dim):
        super().__init__()
        width = 2 * embed_dim
        self.fc1 = nn.Linear(1, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, t):
        """
        t: [S, B, M, 1]
        returns: gamma [S, B, M, E], beta [S, B, M, E]
        """
        hidden = self.fc2(torch.tanh(self.fc1(t))).tanh()   # [S, B, M, 2E]
        half = hidden.shape[-1] // 2
        return hidden[..., :half], hidden[..., half:]
class TimeModulationEmbedder(nn.Module):
    """
    Build token embeddings from recognition-RNN states, modulated by token time:

        e = gamma(t) * MLP(h_T) + beta(t)

    No normalization is applied to the result.
    """
    def __init__(self, rnn_hidden_dim, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim

        # h_T -> embedding (Linear -> Tanh -> Linear)
        self.h_proj = nn.Sequential(
            nn.Linear(rnn_hidden_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
        )

        # t -> (gamma(t), beta(t))
        self.time_modulation = TimeModulation(embed_dim)

    def forward(self, h_T_tokens, t_tokens):
        """
        h_T_tokens: [S, B, M, rnn_hidden_dim]
        t_tokens:   [S, B, M, 1]

        Returns:
            e: [S, B, M, embed_dim]
        """
        projected = self.h_proj(h_T_tokens)
        scale, shift = self.time_modulation(t_tokens)
        return scale * projected + shift

# Position encoding
class FourierFeaturePositionalEncoding(nn.Module):
    """
    Sinusoidal encoding of (continuous) time values.

    Output channel 2k is sin(t * w_k) and channel 2k + 1 is cos(t * w_k), with
    w_k = n ** (-2k / embed_dim). Output dtype is torch's default float dtype.
    """
    def __init__(self, embed_dim, n=10000.0):
        super().__init__()
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even")
        log_scale = torch.arange(0, embed_dim, 2).float() * (-math.log(n) / embed_dim)
        self.register_buffer('div_term', torch.exp(log_scale))

    def forward(self, t):
        """
        t: tensor of shape [..., 1] or [...] (a trailing size-1 axis is treated
           as the singleton feature axis; a 0-d tensor is treated as shape [1])
        Returns: [..., embed_dim]
        """
        if t.dim() == 0:
            t = t.unsqueeze(0)
        if t.shape[-1] != 1:
            t = t.unsqueeze(-1)                                  # -> [..., 1]

        angles = t * self.div_term                               # [..., embed_dim // 2]
        interleaved = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)
        return interleaved.flatten(-2).to(torch.get_default_dtype())

# Parameter Decoder
class EmbedToParamMLP(nn.Module):
    """
    Decode an embedding into the parameters (A, b) of a linear vector field.

    Input:  e [..., embed_dim]
    Output: A [..., d, d] and b [..., d], where d = latent_dim.
    """
    def __init__(self, latent_dim, embed_dim):
        super().__init__()
        self.latent_dim = latent_dim
        n_out = latent_dim * latent_dim + latent_dim   # flattened A followed by b
        self.net = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, n_out),
        )

    def forward(self, e):
        """
        Args:
            e: [..., embed_dim]
        Returns:
            A: [..., d, d]
            b: [..., d]
        """
        d = self.latent_dim
        params = self.net(e)                       # [..., d*d + d]
        lead = params.shape[:-1]
        A = params[..., : d * d].view(*lead, d, d)
        b = params[..., d * d:]
        return A, b

# Data Decoder
class PiecewiseLinearODEFunc(nn.Module):
    def __init__(self, A_token, b_token, token_times):
        """
        Vector field dz/dt = A(t) z + b(t) with A(t), b(t) piecewise constant.

        Args:
            A_token:     [S, B, T_token, d, d] - per-token A matrices
            b_token:     [S, B, T_token, d]    - per-token b vectors
            token_times: [T_token]             - start time of each token interval (sorted)
        """
        super().__init__()
        self.A_token = A_token
        self.b_token = b_token
        self.token_times = token_times

    def _active_token(self, t):
        # Index of the last token whose start time is <= t; times before the
        # first start (or after the last) are clamped into [0, T_token - 1].
        last = self.token_times.size(0) - 1
        found = torch.searchsorted(self.token_times, t, right=True) - 1
        return found.clamp(min=0, max=last)

    def forward(self, t, z):
        """
        Args:
            t: 0-d tensor, current integration time
            z: [S, B, d], current latent state

        Returns:
            dz/dt: [S, B, d]
        """
        k = self._active_token(t)
        A_k = self.A_token[:, :, k, :, :]     # [S, B, d, d]
        b_k = self.b_token[:, :, k, :]        # [S, B, d]
        return torch.matmul(A_k, z.unsqueeze(-1)).squeeze(-1) + b_k
def solve_piecewise_ode(z0_all, A_token, b_token, t_scalar, token_times, method='rk4'):
    """
    Integrate dz/dt = A(t) z + b(t), where A and b are constant on each token interval.

    Args:
        z0_all:      [S, B, d]             - initial latent state for every scenario
        A_token:     [S, B, T_token, d, d] - per-token A matrices
        b_token:     [S, B, T_token, d]    - per-token b vectors
        t_scalar:    [T]                   - time grid at which the solution is returned
        token_times: [T_token]             - start time of each token interval (sorted)
        method:      str                   - odeint method name (e.g. 'rk4', 'dopri5')

    Returns:
        z_traj: [S, B, T, d] - latent trajectory
    """
    vector_field = PiecewiseLinearODEFunc(A_token, b_token, token_times)

    # dopri5 gets explicit, fairly loose tolerances; other methods use odeint's own defaults.
    tolerances = dict(rtol=1e-3, atol=1e-4) if method == 'dopri5' else {}
    trajectory = odeint(vector_field, z0_all, t_scalar, method=method, **tolerances)  # [T, S, B, d]

    return trajectory.permute(1, 2, 0, 3)  # [S, B, T, d]

class PiecewiseNonlinearODEFunc(nn.Module):
    def __init__(self, embedding_tokens, token_times, mlp_func):
        """
        Vector field dz/dt = f([z, e_t]) where e_t is the embedding of the token
        interval containing t (piecewise constant in time).

        Args:
            embedding_tokens: [S, B, M, E] - token embeddings
            token_times:      [M]          - token start times (sorted)
            mlp_func:         nn.Module    - maps [..., d + E] -> [..., d]
        """
        super().__init__()
        self.embedding_tokens = embedding_tokens
        self.token_times = token_times
        self.mlp_func = mlp_func

    def forward(self, t, z):
        """
        Args:
            t: 0-d tensor, current time
            z: [S, B, d]

        Returns:
            dz_dt: [S, B, d]
        """
        # Last token starting at or before t, clamped into the valid range.
        last = self.token_times.size(0) - 1
        k = (torch.searchsorted(self.token_times, t, right=True) - 1).clamp(min=0, max=last)

        context = self.embedding_tokens[:, :, k, :]            # [S, B, E]
        return self.mlp_func(torch.cat((z, context), dim=-1))


class NonlinearDecoderMLP(nn.Module):
    """
    Vector field f([z, e]) -> dz/dt for the nonlinear piecewise ODE.

    Three Linear layers with Tanh in between.
    Input: [..., latent_dim + embed_dim]; output: [..., latent_dim].
    """
    def __init__(self, latent_dim, embed_dim, hidden_units=128):
        super().__init__()
        widths = [latent_dim + embed_dim, hidden_units, hidden_units, latent_dim]
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if depth > 0:
                layers.append(nn.Tanh())       # activation between consecutive Linears
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, z_and_e):
        return self.net(z_and_e)


def solve_nonlinearpiecewise_ode(z0_all, embedding_tokens, t_scalar, token_times, mlp_func, method='rk4'):
    """
    Integrate dz/dt = f([z, e_t]) with a piecewise-constant context embedding e_t.

    Args:
        z0_all:           [S, B, d]    - initial latent states
        embedding_tokens: [S, B, M, E] - token embeddings
        t_scalar:         [T]          - time grid at which the solution is returned
        token_times:      [M]          - start time of each token interval (sorted)
        mlp_func:         nn.Module    - maps [z, e_t] -> dz/dt
        method:           str          - odeint method name ('rk4', 'dopri5', ...)

    Returns:
        z_traj: [S, B, T, d] - latent trajectory
    """
    vector_field = PiecewiseNonlinearODEFunc(embedding_tokens, token_times, mlp_func)

    # dopri5 gets explicit, fairly loose tolerances; other methods use odeint's own defaults.
    tolerances = dict(rtol=1e-3, atol=1e-4) if method == 'dopri5' else {}
    trajectory = odeint(vector_field, z0_all, t_scalar, method=method, **tolerances)  # [T, S, B, d]

    return trajectory.permute(1, 2, 0, 3)  # [S, B, T, d]

class ReconstructMLP(nn.Module):
    """
    Maps z_traj ∈ [..., latent_dim] to x_traj ∈ [..., output_dim].
    Supports arbitrary leading dimensions (e.g. [S, B, T, d] or [B, T, d]),
    including zero-size and non-contiguous ones.
    """
    def __init__(self, latent_dim, output_dim, hidden_units=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, output_dim)
        )

    def forward(self, z_traj):
        """
        Args:
            z_traj: [..., latent_dim]

        Returns:
            x_traj: [..., output_dim]
        """
        # nn.Linear and Tanh already act on the last dimension of inputs of any
        # rank. Applying the MLP directly avoids the flatten/`view(*shape, -1)`
        # round-trip, which raised for inputs with a zero-size leading dimension
        # (the -1 is ambiguous when there are no elements).
        return self.net(z_traj)

# intra attention within a system
class TokenSelfAttentionBlock(nn.Module):
    """
    Pre-LN transformer block (self-attention + MLP) over the token axis.

    Input/output: [S, B, M, E]. Attention runs independently for each (S, B)
    pair over its M tokens.

    With ``use_rope=True`` the block projects Q/K/V with ``q_proj``/``k_proj``/
    ``v_proj``, applies rotary position embeddings to the first ``rotary_dim``
    channels of every head of Q and K, and computes scaled dot-product
    attention itself, reusing ``self.attn.out_proj`` and ``self.attn.dropout``.
    RoPE must act on the final per-head queries/keys. Passing rotated tensors
    into ``nn.MultiheadAttention`` would re-project them with its own
    ``in_proj_weight``, which mixes channels across heads and loses the
    relative-position property.
    """
    def __init__(self, embed_dim, num_heads=4, dropout=0.1, mlp_ratio=4.0, use_rope=False, rotary_dim=None):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)

        # Used directly when use_rope=False. With RoPE only its out_proj and dropout are used.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

        self.ln2 = nn.LayerNorm(embed_dim)

        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, embed_dim)
        )

        self.embed_dim = embed_dim
        self.use_rope = use_rope
        self.rotary_dim = rotary_dim if rotary_dim is not None else embed_dim

        self.d_head = embed_dim // num_heads
        self.rotary_dim = min(self.rotary_dim, self.d_head)
        assert self.rotary_dim % 2 == 0, "rotary_dim must be even"

        # Q/K/V projections for the RoPE path. They are created even when
        # use_rope=False so existing checkpoints keep the same state_dict keys.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def _rope_attention(self, x_norm):
        """Multi-head self-attention with RoPE on Q/K. x_norm: [N, M, E] -> [N, M, E]."""
        N, M, E = x_norm.shape
        H, Dh, r = self.attn.num_heads, self.d_head, self.rotary_dim

        q = self.q_proj(x_norm).view(N, M, H, Dh)
        k = self.k_proj(x_norm).view(N, M, H, Dh)
        v = self.v_proj(x_norm).view(N, M, H, Dh)

        sin, cos = build_rope_sin_cos(seq_len=M, rotary_dim=r, device=x_norm.device)
        sin, cos = sin.to(q.dtype), cos.to(q.dtype)

        # Rotate the first r channels of each head; leave the rest untouched.
        q = torch.cat([apply_rope(q[..., :r], sin, cos), q[..., r:]], dim=-1)
        k = torch.cat([apply_rope(k[..., :r], sin, cos), k[..., r:]], dim=-1)

        # [N, H, M, Dh] for per-head attention
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(Dh)   # [N, H, M, M]
        weights = torch.softmax(scores, dim=-1)
        weights = F.dropout(weights, p=self.attn.dropout, training=self.training)
        context = torch.matmul(weights, v)                                # [N, H, M, Dh]
        context = context.transpose(1, 2).reshape(N, M, E)
        return self.attn.out_proj(context)

    def forward(self, x):
        """
        x: [S, B, M, E]
        returns: [S, B, M, E]
        """
        S, B, M, E = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(S * B, M, E)   # [SB, M, E]
        x_norm = self.ln1(x_flat)         # pre-LN input

        if self.use_rope:
            attn_output = self._rope_attention(x_norm)
        else:
            attn_output, _ = self.attn(x_norm, x_norm, x_norm)

        # Residual connection (Pre-LN)
        x_attn = x_flat + self.dropout(attn_output)

        # MLP block with second residual
        mlp_out = self.mlp(self.ln2(x_attn))
        x_out = x_attn + self.dropout(mlp_out)

        return x_out.view(S, B, M, E)

class TokenSelfAttentionStack(nn.Module):
    """
    ``depth`` TokenSelfAttentionBlock layers applied in sequence to [S, B, M, E] tokens.

    ``use_rope`` and ``rotary_dim`` are forwarded to every block. Their defaults
    (no RoPE) reproduce the previous behaviour of this stack.
    """
    def __init__(self, embed_dim, num_heads=4, depth=4, dropout=0.1, mlp_ratio=4.0, use_rope=False, rotary_dim=None):
        super().__init__()
        self.blocks = nn.ModuleList([
            TokenSelfAttentionBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
                mlp_ratio=mlp_ratio,
                use_rope=use_rope,
                rotary_dim=rotary_dim,
            )
            for _ in range(depth)
        ])

    def forward(self, x):
        """
        x: [S, B, M, E] — token embeddings
        returns: [S, B, M, E]
        """
        for block in self.blocks:
            x = block(x)
        return x

# inter attention across systems
class TokenPrimitiveCrossAttentionBlock(nn.Module):
    """
    Pre-LN cross-attention block: tokens (queries) attend to a shared set of
    primitive vectors (keys/values), followed by an MLP. Both sub-layers are
    residual.
    """
    def __init__(self, embed_dim, num_heads=4, dropout=0.1, mlp_ratio=4.0):
        super().__init__()
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)

        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

        self.ln2 = nn.LayerNorm(embed_dim)

        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, embed_dim)
        )

    def forward(self, tokens, primitives):
        """
        tokens: [S, B, M, E] — token embeddings (query)
        primitives: [K, E] — shared primitive library (key/value), passed in by the caller
        returns: [S, B, M, E]
        """
        S, B, M, E = tokens.shape
        # reshape (not view) so non-contiguous token tensors are accepted.
        tokens_flat = tokens.reshape(S * B, M, E)      # [SB, M, E]

        q = self.ln_q(tokens_flat)                     # [SB, M, E]
        kv = self.ln_kv(primitives)                    # [K, E]
        kv = kv.unsqueeze(0).expand(S * B, -1, -1)     # [SB, K, E] (broadcast, no copy)

        # Cross-attention: tokens attend to primitives
        attn_out, _ = self.attn(q, kv, kv)             # [SB, M, E]
        x_attn = tokens_flat + self.dropout(attn_out)  # Residual

        # FFN + residual
        mlp_out = self.mlp(self.ln2(x_attn))
        x_out = x_attn + self.dropout(mlp_out)
        return x_out.view(S, B, M, E)
class TokenPrimitiveCrossAttentionStack(nn.Module):
    """
    ``depth`` TokenPrimitiveCrossAttentionBlock layers applied in sequence.
    Every layer attends to the same ``primitives`` tensor.
    """
    def __init__(self, embed_dim, num_heads=4, depth=4, dropout=0.1, mlp_ratio=4.0):
        super().__init__()
        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            mlp_ratio=mlp_ratio,
        )
        self.blocks = nn.ModuleList(
            TokenPrimitiveCrossAttentionBlock(**block_kwargs) for _ in range(depth)
        )

    def forward(self, tokens, primitives):
        """
        tokens: [S, B, M, E] — token embeddings
        primitives: [K, E] — shared primitive library, passed in by the caller
        returns: [S, B, M, E]
        """
        out = tokens
        for layer in self.blocks:
            out = layer(out, primitives)
        return out

# Primitive-attention updates with a GRU-driven dynamic hub (similar to Slot Attention, NeurIPS 2020)
class DynamicHubUpdateBlock(nn.Module):
    """
    Update mechanism for the dynamic hub (primitive library).

    Each primitive slot (row of the hub) cross-attends to the tokens of every
    (scenario, batch) pair independently. The per-pair attention outputs are
    averaged over all pairs, layer-normalized, and fed to a ``GRUCell`` whose
    hidden state is the current slot vector.

    Input:
        primitive_library: [K, E] — primitive slots (hub)
        tokens: [S, B, M, E] — token embeddings across scenarios and batches
    Output:
        H_updated: [K, E] — updated primitive library
    """
    def __init__(self, embed_dim, num_heads=4, dropout=0.1):
        super().__init__()
        # RMSNorm (defined at the top of this file) was tried as an alternative
        # to these LayerNorms.
        self.ln_q = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(embed_dim)

        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
        # NOTE(review): not referenced in forward(); attention dropout is applied
        # inside self.attn. Kept so the module tree is unchanged.
        self.dropout = nn.Dropout(dropout)

        self.ln_delta = nn.LayerNorm(embed_dim)

        self.gru_cell = nn.GRUCell(embed_dim, embed_dim)

    def forward(self, primitive_library, tokens):
        """
        Arguments:
            primitive_library: [K, E] — shared hub vectors (queries and GRU state)
            tokens: [S, B, M, E] — token embeddings (keys/values)
        Returns:
            Updated primitive_library: [K, E]
        """
        S, B, M, E = tokens.shape

        # Flatten tokens across S, B -> [SB, M, E]. reshape (not view) so that
        # non-contiguous inputs (e.g. permuted tensors) are accepted.
        tokens_flat = tokens.reshape(S * B, M, E)

        q = self.ln_q(primitive_library)                    # [K, E]
        kv = self.ln_kv(tokens_flat)                        # [SB, M, E]

        # The same slot queries are used for every (scenario, batch) pair.
        q_expand = q.unsqueeze(0).expand(S * B, -1, -1)     # [SB, K, E]

        # Cross-attention: each primitive slot attends to that pair's tokens
        attn_out, _ = self.attn(q_expand, kv, kv)           # [SB, K, E]

        # Aggregate across all pairs: one mean update per slot
        delta_H = self.ln_delta(attn_out.mean(dim=0))       # [K, E]

        # GRU step: delta_H is the input, the current slots are the hidden state
        return self.gru_cell(delta_H, primitive_library)    # [K, E]


# Main LassODE model variants

# Variant with intra (token-token) attention only
class LassODE_intra(nn.Module):
    """
    LassODE: Large-scale Small-ODE model integrating encoder, embedding, attention, ODE solver, and decoder.

    This variant refines the token embeddings with one optional token-token
    attention stack; it has no primitive library.
    """

    def __init__(self, encoder, embedder, pos_encoder, param_decoder, reconstruct_mlp, method,
                 attn_stack=None):
        super().__init__()
        self.method = method  # forwarded to solve_piecewise_ode
        self.encoder = encoder
        self.embedder = embedder
        self.pos_encoder = pos_encoder
        self.param_decoder = param_decoder
        self.reconstruct_mlp = reconstruct_mlp
        # Optional token-token attention stack applied before decoding.
        self.attn_stack = attn_stack

    def tokenize_time(self, t_scalar, num_tokens):
        """
        Pick ``num_tokens`` anchor times from ``t_scalar`` at a uniform index stride.

        The stride is ``len(t_scalar) // num_tokens``. Any trailing remainder of
        ``t_scalar`` is never selected. If ``num_tokens`` exceeds the length, the
        stride is 0 and every anchor is ``t_scalar[0]``.
        """
        n_steps = t_scalar.shape[0]
        stride = n_steps // num_tokens
        picks = [stride * k for k in range(num_tokens)]
        return t_scalar[picks]  # [num_tokens]

    def generate_prefix_times(self, t_scalar, num_scenarios):
        """Draw ``num_scenarios`` sorted times uniformly from [t_scalar[0], t_scalar[-1])."""
        lo = t_scalar[0].item()
        hi = t_scalar[-1].item()
        draws = torch.rand(num_scenarios) * (hi - lo) + lo
        return torch.sort(draws).values  # [S]

    def forward(self, x_batch, t_scalar, mask_batch, prefix_time_list, token_times):
        """
        Args:
            x_batch: [B, T, D] input sequences.
            t_scalar: [T] observation times (encoder input and ODE output grid).
            mask_batch: not used by this implementation; kept for interface
                compatibility.
            prefix_time_list: [S] prefix cutoff times, one per scenario.
            token_times: [M] token anchor times.

        Returns:
            ``reconstruct_mlp`` applied to the piecewise-ODE latent trajectory.
        """
        batch, _, _ = x_batch.shape
        n_scen = len(prefix_time_list)
        n_tok = token_times.shape[0]
        hidden_dim = self.encoder.rnn_hidden_dim

        # Token anchor times broadcast to every (scenario, batch) pair: [S, B, M, 1]
        times_grid = token_times.view(1, 1, n_tok, 1).expand(n_scen, batch, n_tok, 1)

        # Initial latent state and final encoder hidden state per scenario
        z0_all, h_T_all = self.encoder(x_batch, t_scalar, prefix_time_list)
        hidden = h_T_all.unsqueeze(2).expand(n_scen, batch, n_tok, hidden_dim)

        # Content embedding plus positional encoding: [S, B, M, E]
        tokens = self.embedder(hidden, times_grid)
        tokens = tokens + self.pos_encoder(times_grid)

        if self.attn_stack is not None:
            tokens = self.attn_stack(tokens)

        A_tokens, b_tokens = self.param_decoder(tokens)
        z_traj = solve_piecewise_ode(z0_all, A_tokens, b_tokens, t_scalar, token_times, method=self.method)
        return self.reconstruct_mlp(z_traj)

    def compute_loss(self, x_traj, x_batch):
        """Mean squared error of every scenario's reconstruction against x_batch."""
        target = x_batch.unsqueeze(0)  # [1, B, T, D], broadcast over scenarios
        return (x_traj - target).pow(2).mean()

# Variant with parallel intra (token-token) and inter (token-primitive) attention
class LassODE_parrllel_intra_inter(nn.Module):
    """
    LassODE variant with two parallel refinement branches whose outputs are summed.

    - inter branch: token-primitive cross-attention over the raw content embedding
      (no positional encoding), against a learnable primitive library [K, E];
    - intra branch: token-token attention over content + positional encoding.

    The sum is decoded to per-token (A, b) parameters for ``solve_piecewise_ode``.
    """

    def __init__(self, encoder, embedder, pos_encoder, param_decoder, reconstruct_mlp, method,
                 intra_attn_stack=None, inter_attn_stack=None, num_primitives=16, embed_dim=128):
        super().__init__()
        self.method = method  # forwarded to solve_piecewise_ode
        self.encoder = encoder
        self.embedder = embedder
        self.pos_encoder = pos_encoder
        self.param_decoder = param_decoder
        self.reconstruct_mlp = reconstruct_mlp

        # Optional branches; a missing branch contributes nothing.
        self.intra_attn_stack = intra_attn_stack  # token-token self-attention
        self.inter_attn_stack = inter_attn_stack  # token-primitive cross-attention

        # Shared primitive library [K, E]. The randn draw is kept before the
        # trunc-normal re-init so the global RNG stream is consumed identically.
        library = torch.randn(num_primitives, embed_dim)
        self.primitive_library = nn.Parameter(library)
        nn.init.trunc_normal_(self.primitive_library, std=0.02)

    def tokenize_time(self, t_scalar, num_tokens):
        """
        Pick ``num_tokens`` anchor times from ``t_scalar`` at a uniform index stride.

        The stride is ``len(t_scalar) // num_tokens``. Any trailing remainder of
        ``t_scalar`` is never selected. If ``num_tokens`` exceeds the length, the
        stride is 0 and every anchor is ``t_scalar[0]``.
        """
        n_steps = t_scalar.shape[0]
        stride = n_steps // num_tokens
        picks = [stride * k for k in range(num_tokens)]
        return t_scalar[picks]  # [num_tokens]

    def generate_prefix_times(self, t_scalar, num_scenarios):
        """Draw ``num_scenarios`` sorted times uniformly from [t_scalar[0], t_scalar[-1])."""
        lo = t_scalar[0].item()
        hi = t_scalar[-1].item()
        draws = torch.rand(num_scenarios) * (hi - lo) + lo
        return torch.sort(draws).values  # [S]

    def forward(self, x_batch, t_scalar, mask_batch, prefix_time_list, token_times):
        """
        Args:
            x_batch: [B, T, D] input sequences.
            t_scalar: [T] observation times (encoder input and ODE output grid).
            mask_batch: not used by this implementation; kept for interface
                compatibility.
            prefix_time_list: [S] prefix cutoff times, one per scenario.
            token_times: [M] token anchor times.

        Returns:
            ``reconstruct_mlp`` applied to the piecewise-ODE latent trajectory.
        """
        batch, _, _ = x_batch.shape
        n_scen = len(prefix_time_list)
        n_tok = token_times.shape[0]

        # Token anchor times broadcast to every (scenario, batch) pair: [S, B, M, 1]
        times_grid = token_times.view(1, 1, n_tok, 1).expand(n_scen, batch, n_tok, 1)

        z0_all, h_T_all = self.encoder(x_batch, t_scalar, prefix_time_list)
        hidden = h_T_all.unsqueeze(2).expand(n_scen, batch, n_tok, self.encoder.rnn_hidden_dim)

        content = self.embedder(hidden, times_grid)      # [S, B, M, E]
        position = self.pos_encoder(times_grid)          # [S, B, M, E]

        # Inter branch (evaluated first) sees the content embedding only.
        if self.inter_attn_stack is None:
            inter = 0.0
        else:
            inter = self.inter_attn_stack(content, self.primitive_library)

        # Intra branch sees content + positional encoding.
        intra = content + position
        if self.intra_attn_stack is not None:
            intra = self.intra_attn_stack(intra)

        A_tokens, b_tokens = self.param_decoder(intra + inter)
        z_traj = solve_piecewise_ode(z0_all, A_tokens, b_tokens, t_scalar, token_times, method=self.method)
        return self.reconstruct_mlp(z_traj)

    def compute_loss(self, x_traj, x_batch):
        """Mean squared error of every scenario's reconstruction against x_batch."""
        target = x_batch.unsqueeze(0)  # [1, B, T, D], broadcast over scenarios
        return (x_traj - target).pow(2).mean()

class LassODE(nn.Module):
    """
    LassODE with iterative multi-block token refinement.

    Each of ``num_blocks`` rounds applies, in order:
      1. optional intra (token-token) attention,
      2. optional inter (token-primitive) attention, merged through a sigmoid
         gate (or a plain residual when ``gate_layer`` is set to None),
      3. optional dynamic-hub update of the primitive library.
    The token embeddings after every round are averaged and decoded to per-token
    (A, b) parameters for ``solve_piecewise_ode``.
    """

    def __init__(self, encoder, embedder, pos_encoder, param_decoder, reconstruct_mlp, method,
                 intra_attn_stack=None, inter_attn_stack=None, dynamic_hub_block=None, num_blocks=4, num_primitives=16, embed_dim=128):
        super().__init__()
        self.method = method  # forwarded to solve_piecewise_ode
        self.encoder = encoder
        self.embedder = embedder
        self.pos_encoder = pos_encoder
        self.param_decoder = param_decoder
        self.reconstruct_mlp = reconstruct_mlp

        self.num_blocks = num_blocks
        self.intra_attn_stack = intra_attn_stack  # token-token self-attention
        self.inter_attn_stack = inter_attn_stack  # token-primitive cross-attention
        self.dynamic_hub_block = dynamic_hub_block

        # Shared primitive library [K, E]. The randn draw is kept before the
        # trunc-normal re-init so the global RNG stream is consumed identically.
        library = torch.randn(num_primitives, embed_dim)
        self.primitive_library = nn.Parameter(library)
        nn.init.trunc_normal_(self.primitive_library, std=0.02)

        # Gate over [token, inter-attention output] -> per-feature weight in (0, 1).
        # RMSNorm (defined at the top of this file) was tried in place of LayerNorm.
        pair_width = 2 * embed_dim
        self.gate_layer = nn.Sequential(
            nn.LayerNorm(pair_width),
            nn.Linear(pair_width, embed_dim),
            nn.Sigmoid(),
        )

    def tokenize_time(self, num_tokens, device="cpu", T=1.0):
        """
        Split [0, T) into ``num_tokens`` equal segments and return their start times.

        Args:
            num_tokens: Number of tokens.
            device: Torch device for the result.
            T: Time horizon (default 1.0).

        Returns:
            token_times: float32 tensor [num_tokens] of segment start times.
        """
        step = T / num_tokens
        return torch.arange(num_tokens, device=device, dtype=torch.float32).mul(step)

    def generate_prefix_times(self, t_scalar, num_scenarios):
        """Draw ``num_scenarios`` sorted times uniformly from [t_scalar[0], t_scalar[-1])."""
        lo = t_scalar[0].item()
        hi = t_scalar[-1].item()
        draws = torch.rand(num_scenarios) * (hi - lo) + lo
        return torch.sort(draws).values  # [S]

    def forward(self, x_batch, t_scalar, t_target, prefix_time_list, token_times):
        """
        Args:
            x_batch: [B, T, D] input sequence
            t_scalar: [T] observation times (for encoder)
            t_target: [T_out] output times (for the ODE solver)
            prefix_time_list: [S] prefix cutoff times
            token_times: [M] token anchor times (e.g. from tokenize_time)

        Returns:
            ``reconstruct_mlp`` applied to the piecewise-ODE latent trajectory.
        """
        batch, _, _ = x_batch.shape
        n_scen = len(prefix_time_list)
        n_tok = token_times.shape[0]

        # Token anchor times broadcast to every (scenario, batch) pair: [S, B, M, 1]
        times_grid = token_times.view(1, 1, n_tok, 1).expand(n_scen, batch, n_tok, 1)

        z0_all, h_T_all = self.encoder(x_batch, t_scalar, prefix_time_list)
        hidden = h_T_all.unsqueeze(2).expand(n_scen, batch, n_tok, self.encoder.rnn_hidden_dim)

        # Content embedding plus positional encoding: [S, B, M, E]
        tokens = self.embedder(hidden, times_grid) + self.pos_encoder(times_grid)

        hub = self.primitive_library
        per_block = []
        for _ in range(self.num_blocks):
            if self.intra_attn_stack is not None:
                tokens = self.intra_attn_stack(tokens)

            # A zero tensor (not a scalar) so the gate's torch.cat stays valid.
            inter = (
                torch.zeros_like(tokens)
                if self.inter_attn_stack is None
                else self.inter_attn_stack(tokens, hub)
            )

            if self.gate_layer is None:
                tokens = tokens + inter
            else:
                weight = self.gate_layer(torch.cat([tokens, inter], dim=-1))  # [S, B, M, E]
                tokens = tokens + weight * inter

            per_block.append(tokens)

            if self.dynamic_hub_block is not None:
                hub = self.dynamic_hub_block(hub, tokens)

        # Average the embeddings produced by every block
        pooled = torch.stack(per_block, dim=0).mean(dim=0)  # [S, B, M, E]

        A_tokens, b_tokens = self.param_decoder(pooled)
        z_traj = solve_piecewise_ode(z0_all, A_tokens, b_tokens, t_target, token_times, method=self.method)
        return self.reconstruct_mlp(z_traj)

    def compute_loss(self, x_traj, x_batch, true_dim=None):
        """
        Mean squared error of every scenario's reconstruction against x_batch.

        Args:
            x_traj: [S, B, T, D]  - reconstructed trajectories
            x_batch: [B, T, D]    - true padded data
            true_dim: int or None - if given, only the first ``true_dim``
                channels are compared (the rest are treated as padding)
        """
        pred = x_traj
        target = x_batch.unsqueeze(0)  # [1, B, T, D], broadcast over scenarios
        if true_dim is not None:
            pred, target = pred[..., :true_dim], target[..., :true_dim]
        return (pred - target).pow(2).mean()

class LassODE_nonlinearODE_inter_intra_block(nn.Module):
    """
    LassODE variant with iterative token refinement and a nonlinear piecewise ODE.

    Token embeddings go through ``num_blocks`` rounds of optional intra
    (token-token) attention, optional gated inter (token-primitive) attention,
    and optional dynamic-hub updates of the primitive library. The final
    embeddings condition ``solve_nonlinearpiecewise_ode`` through
    ``nonlinear_decoder``. ``reconstruct_mlp`` then maps the latent trajectory
    back to data space.
    """

    def __init__(self, encoder, embedder, pos_encoder, nonlinear_decoder, reconstruct_mlp, method,
                 intra_attn_stack=None, inter_attn_stack=None, dynamic_hub_block=None, num_blocks=4, num_primitives=16, embed_dim=128):
        super().__init__()
        self.method = method  # forwarded to solve_nonlinearpiecewise_ode
        self.encoder = encoder
        self.embedder = embedder
        self.pos_encoder = pos_encoder
        self.nonlinear_decoder = nonlinear_decoder
        self.reconstruct_mlp = reconstruct_mlp

        self.num_blocks = num_blocks
        self.intra_attn_stack = intra_attn_stack  # token-token self-attention
        self.inter_attn_stack = inter_attn_stack  # token-primitive cross-attention
        self.dynamic_hub_block = dynamic_hub_block

        # Shared primitive library for inter attention: [K, E]
        self.primitive_library = nn.Parameter(torch.randn(num_primitives, embed_dim))
        nn.init.trunc_normal_(self.primitive_library, std=0.02)

        # Gate over [token, inter-attention output] -> per-feature weight in (0, 1).
        # Set to None after construction to fall back to a plain residual.
        self.gate_layer = nn.Sequential(
            nn.LayerNorm(2 * embed_dim),
            nn.Linear(2 * embed_dim, embed_dim),
            nn.Sigmoid()
        )

    def tokenize_time(self, t_scalar, num_tokens):
        """
        Pick ``num_tokens`` anchor times from ``t_scalar`` at a uniform index stride.

        The stride is ``len(t_scalar) // num_tokens``. Any trailing remainder of
        ``t_scalar`` is never selected. If ``num_tokens`` exceeds the length, the
        stride is 0 and every anchor is ``t_scalar[0]``.
        """
        T = t_scalar.shape[0]
        chunk_len = T // num_tokens
        token_indices = [i * chunk_len for i in range(num_tokens)]
        return t_scalar[token_indices]  # [num_tokens]

    def generate_prefix_times(self, t_scalar, num_scenarios):
        """Draw ``num_scenarios`` sorted times uniformly from [t_scalar[0], t_scalar[-1])."""
        t_start = t_scalar[0].item()
        t_end = t_scalar[-1].item()
        prefix_times = torch.sort(torch.rand(num_scenarios) * (t_end - t_start) + t_start)[0]
        return prefix_times  # [S]

    def _refine_tokens(self, z_embed):
        """Run the multi-block intra/inter/hub refinement on [S, B, M, E] embeddings."""
        primitive_library_updated = self.primitive_library
        for _ in range(self.num_blocks):
            # Intra attention (token-token)
            if self.intra_attn_stack is not None:
                z_embed = self.intra_attn_stack(z_embed)  # [S, B, M, E]

            # Inter attention (token <-> primitive). When absent, use a zero
            # tensor rather than the scalar 0.0: torch.cat below needs a tensor.
            if self.inter_attn_stack is not None:
                z_embed_inter = self.inter_attn_stack(z_embed, primitive_library_updated)
            else:
                z_embed_inter = torch.zeros_like(z_embed)

            # Gated residual merge (plain residual if the gate was removed)
            if self.gate_layer is not None:
                gate_input = torch.cat([z_embed, z_embed_inter], dim=-1)  # [S, B, M, 2E]
                gate = self.gate_layer(gate_input)                        # [S, B, M, E]
                z_embed = z_embed + gate * z_embed_inter
            else:
                z_embed = z_embed + z_embed_inter

            # Dynamic hub update (primitives updated from the current tokens)
            if self.dynamic_hub_block is not None:
                primitive_library_updated = self.dynamic_hub_block(primitive_library_updated, z_embed)
        return z_embed

    def forward(self, x_batch, t_scalar, prefix_time_list, token_times):
        """
        Args:
            x_batch: [B, T, D] input sequences.
            t_scalar: [T] observation times (encoder input and ODE output grid).
            prefix_time_list: [S] prefix cutoff times, one per scenario.
            token_times: [M] token anchor times.

        Returns:
            ``reconstruct_mlp`` applied to the nonlinear piecewise-ODE trajectory.
        """
        B, _, _ = x_batch.shape          # batch size; also checks x_batch is 3-D
        S = len(prefix_time_list)        # number of scenarios
        M = token_times.shape[0]         # number of tokens

        # Step 1: Broadcast token_times to [S, B, M, 1]
        token_times_broadcasted = token_times.view(1, 1, M, 1).expand(S, B, M, 1)

        # Step 2: Encode z0 and h_T
        z0_all, h_T_all = self.encoder(x_batch, t_scalar, prefix_time_list)  # [S, B, H]
        hT_tokens = h_T_all.unsqueeze(2).expand(S, B, M, self.encoder.rnn_hidden_dim)

        # Step 3: Initial embedding + positional encoding
        z_embed = self.embedder(hT_tokens, token_times_broadcasted)  # [S, B, M, E]
        pos_encoding = self.pos_encoder(token_times_broadcasted)     # [S, B, M, E]

        # Step 4: Iterative refinement through multi-block attention
        z_embed = self._refine_tokens(z_embed + pos_encoding)

        # Step 5: Solve the nonlinear piecewise ODE conditioned on the tokens
        z_traj = solve_nonlinearpiecewise_ode(
            z0_all=z0_all,
            embedding_tokens=z_embed,
            t_scalar=t_scalar,
            token_times=token_times,
            mlp_func=self.nonlinear_decoder,
            method=self.method
        )

        # Step 6: Reconstruct
        return self.reconstruct_mlp(z_traj)

    def compute_loss(self, x_traj, x_batch):
        """Mean squared error of every scenario's reconstruction against x_batch."""
        x_batch_expanded = x_batch.unsqueeze(0)  # [1, B, T, D]
        return ((x_traj - x_batch_expanded) ** 2).mean()
