import torch
import numpy as np
from scipy.linalg import logm, expm, norm  # Schur–Parlett based logm
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint
import torch.nn.functional as F

from memKNO.network import MLP


def _to_numpy(a: torch.Tensor) -> np.ndarray:
    """Return `a` as a float64 NumPy array on the host, detached from autograd."""
    # SciPy works on CPU arrays; float64 keeps the matrix functions accurate.
    host_copy = a.detach().to(device="cpu", dtype=torch.float64)
    return host_copy.numpy()

def _to_torch(a_np: np.ndarray, like: torch.Tensor) -> torch.Tensor:
    """Wrap `a_np` as a tensor carrying the dtype and device of `like`."""
    wrapped = torch.from_numpy(a_np)
    return wrapped.to(device=like.device, dtype=like.dtype)

def discrete_to_continuous_via_logm(Ad: torch.Tensor, dt: float,
                                    jitter: float = 1e-8,
                                    rel_res_tol: float = 5e-5):
    """
    Convert a discrete one-step matrix into a real continuous-time generator.

    The primary route is SciPy's `logm`, evaluated in float64 on the CPU. The
    result is cast back to the (real) dtype of `Ad`, so only the real part of
    the logarithm can be returned. The reconstruction check is therefore run on
    that real part: a logarithm that is only valid together with its imaginary
    part (e.g. `Ad` has eigenvalues on the negative real axis) is rejected
    instead of being silently truncated.

    If the primary route fails for any reason, a RuntimeWarning is emitted and
    the function falls back to
      1. the 20-term series log(I + X) with X = Ad - I, when ||X||_F < 0.5;
      2. the forward-Euler estimate (Ad - I) / dt otherwise (rough; intended
         only as an initialisation).

    Args:
        Ad: [D,D] discrete one-step (column form!) matrix
        dt: scalar step size
        jitter: small diagonal push to avoid singularity / branch-cut issues
                (applied only on the logm route)
        rel_res_tol: tolerance for ||expm(Ac*dt)-Ad|| / max(||Ad||, 1) check

    Returns:
        Ac: [D,D] real continuous-time generator with the dtype/device of `Ad`.
            exp(Ac*dt) ≈ Ad holds within `rel_res_tol` on the logm route; the
            fallbacks are not validated.
    """
    import warnings

    assert Ad.ndim == 2 and Ad.size(0) == Ad.size(1), "Ad must be square"
    dt = float(dt)
    D = Ad.size(0)

    # Preflight: push spectrum slightly away from 0 / negative real axis
    I = torch.eye(D, device=Ad.device, dtype=Ad.dtype)
    Ad_safe = Ad + jitter * I

    try:
        Ad_np = _to_numpy(Ad_safe)
        L_np = np.asarray(logm(Ad_np))      # may be complex128

        # Only the real part survives the cast back to Ad's dtype, so that is
        # the candidate we must validate.
        Ac_np = L_np.real / dt
        if not np.isfinite(Ac_np).all():
            raise RuntimeError("logm produced non-finite entries")

        # Sanity check: exp(Ac*dt) must reconstruct Ad within tolerance
        rec_np = expm(Ac_np * dt)
        den = max(norm(Ad_np), 1.0)
        rel_res = norm(rec_np - Ad_np) / den
        if not np.isfinite(rel_res) or rel_res > rel_res_tol:
            raise RuntimeError(
                f"logm residual too large: {rel_res:.3e} "
                f"(max |Im log(Ad)| = {np.abs(L_np.imag).max():.3e})"
            )

        return _to_torch(Ac_np, like=Ad)

    except Exception as exc:
        # Best-effort initialisation: never fail hard, but say which route was used.
        X = Ad - I
        use_series = torch.linalg.norm(X).item() < 0.5
        warnings.warn(
            "discrete_to_continuous_via_logm: logm route failed "
            f"({type(exc).__name__}: {exc}); using "
            + ("series log(I+X)" if use_series else "Euler (Ad - I)/dt")
            + " fallback.",
            RuntimeWarning,
            stacklevel=2,
        )

        if use_series:
            # log(I+X) ≈ ∑_{k=1}^K (-1)^{k+1} X^k / k with K = 20
            K = 20
            term = X.clone()
            L = X.clone()
            sign = -1.0
            for k in range(2, K + 1):
                term = term @ X
                L = L + (sign * term) / float(k)
                sign = -sign
            return L / dt

        # Fallback 2 (very rough): Euler init Ac ≈ (Ad - I) / dt
        return X / dt



def _init_gate_bias(module: nn.Module, value: float) -> None:
    """Fill the bias of all Linear layers inside `module` with `value`.

    Linear layers created with `bias=False` have no bias tensor and are skipped.
    """
    with torch.no_grad():
        for layer in module.modules():
            if isinstance(layer, nn.Linear) and layer.bias is not None:
                layer.bias.fill_(value)


def _init_last_linear_bias(module: nn.Module, value: float):
    """Fill the bias of the last Linear layer found in `module` with `value`.

    "Last" is the final `nn.Linear` in `module.modules()` traversal order.
    Nothing happens if `module` contains no Linear layer, or if that last
    layer was built with `bias=False` (earlier layers are never touched).
    """
    last = None
    for m in module.modules():
        if isinstance(m, nn.Linear):
            last = m
    if last is not None and last.bias is not None:
        with torch.no_grad():
            last.bias.fill_(value)



class LatentODEfunc(nn.Module):
    """Autonomous latent vector field: d u / dt = net(u), with `net` an MLP."""

    def __init__(self, state_dim, code_dim, hidden_dim, num_layers=3, nl='swish', **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        flat_dim = code_dim * state_dim
        # Only in_dim is passed; per the original note the MLP output width
        # equals its input width (TODO confirm against memKNO.network.MLP).
        self.net = MLP(in_dim=flat_dim, hidden_dim=hidden_dim, num_layers=num_layers, nl=nl)

    def forward(self, t, u):
        # `t` is unused: the field does not depend on time.
        velocity = self.net(u)
        return velocity
    


class LatentProcess(nn.Module):
    """
    Continuous-time latent process with the following modes:
      - "linear"         : d alpha / dt = A alpha
      - "linear+memory"  : d alpha / dt = A alpha + phi_dec(m)
                           d m / dt     = phi_enc(alpha) - Lambda * m
                           with phi_enc: R^{D_alpha} -> R^{D_m}, phi_dec: R^{D_m} -> R^{D_alpha}
      - "neural_ode"     : d alpha / dt = f(alpha)  (provided via ode_func)
      -----------------------------------------------------------------------------------------------
      - "gru_memory"      : d alpha / dt = A alpha + phi_dec(m)
                           d m / dt     = tau_m^{-1} * [ z(α,m) ⊙ (\tilde m(α, r(α,m)⊙m) - m) ]
      - "lstm_memory"     : d alpha / dt = A alpha + phi_dec(m)
                           d c / dt     = tau_c^{-1} * [ f(α,m) ⊙ c + i(α,m) ⊙ \tilde c(α,m) - c ]
                           d m / dt     = tau_h^{-1} * [ o(α,m) ⊙ tanh(c) - m ]

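    Returns (from `forward`):
      (alpha_t, memory_t), or (alpha_t, memory_t, aux) when `return_aux=True`
        alpha_t : [T, B, D_alpha]
        memory_t: [T, B, D_m] or None (None for "linear" / "neural_ode");
                  for "lstm_memory" only the hidden state m is returned, not the cell c
        aux     : dict; holds "phi_dec_l2" for the memory modes, empty otherwise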
    """

    def __init__(
        self, state_dim: int, code_dim: int,
        latent_type: str = "linear+memory",
        ode_func: nn.Module = None,
        solver: str = "rk4",
        # memory settings
        memory_dim: int | None = None,
        # encoder/decoder MLP hyperparams
        enc_hidden_dim: int | None = None,
        dec_hidden_dim: int | None = None,
        enc_layers: int = 2,
        dec_layers: int = 2,
        nl: str = "swish",
        # linear operator parameterization (Scheme A: dense pH)
        linear_param: str = "free",                  # {"free", "pH_dense"}
        ph_osc_dims: int = 0,                        # oscillatory subspace dim (>=0)
        ph_epsP: float = 1e-6,                       # eps for P = L^T L + eps I
        ph_min_freq: float = 1e-3,                   # frequency lower bound for oscillatory modes
        return_aux: bool = False,                    # return ||phi_dec(m)||^2
    ):
        """
        Create the parameters and sub-networks required by `latent_type`.

        Args:
            state_dim, code_dim: the latent width is D_alpha = state_dim * code_dim.
            latent_type: one of "linear", "linear+memory", "neural_ode",
                "gru_memory", "lstm_memory" (asserted).
            ode_func: vector field on alpha; required for "neural_ode" (asserted),
                not stored for the other types.
            solver: stored as `self.solver` and passed to the ODE integrator as
                its `method`.
            memory_dim: D_m; defaults to D_alpha when None.
            enc_hidden_dim: hidden width of the memory encoder ("linear+memory"
                only); falsy values fall back to max(32, min(D_alpha, 256)).
            dec_hidden_dim: hidden width of the memory decoder; same fallback.
            enc_layers: layer count of the memory encoder and of the GRU/LSTM
                gate networks (at least 1).
            dec_layers: layer count of the memory decoder (at least 1).
            nl: nonlinearity name forwarded to every MLP.
            linear_param: "free" stores a dense [D_alpha, D_alpha] matrix;
                "pH_dense" stores the factors B, C, L, W and raw frequencies
                used to build A = (J - R) P. Any other value creates no linear
                parameters (it is not validated here).
            ph_osc_dims: requested width d of the oscillatory subspace, clamped
                to [0, D_alpha].
            ph_epsP: epsilon added to the diagonal of P = L^T L + eps I.
            ph_min_freq: lower bound added to the softplus-ed frequencies.
            return_aux: if True, `forward` additionally returns an aux dict.

        Note:
            The order in which parameters and sub-networks are created is kept
            as is, since it fixes the order of random initialisation draws.
        """
        super().__init__()
        assert latent_type in {"linear", "linear+memory", "neural_ode", "gru_memory", "lstm_memory"}
        self.latent_type = latent_type
        self.state_dim = state_dim
        self.code_dim = code_dim
        self.latent_dim = state_dim * code_dim  # D_alpha: latent_dim
        self.memory_dim = int(memory_dim) if memory_dim is not None else self.latent_dim  # D_m
        self.solver = solver
        self.return_aux = bool(return_aux)

        # Linear drift for alpha, shared by every mode except "neural_ode".
        self.linear_param = linear_param
        if self.latent_type in {"linear", "linear+memory", "gru_memory", "lstm_memory"}:
            if self.linear_param == "free":
                # Dense generator, zero-initialised (d alpha / dt = 0 until set or trained).
                self.linear = nn.Parameter(torch.zeros(self.latent_dim, self.latent_dim))
            elif self.linear_param == "pH_dense":
                D = self.latent_dim
                # J = 0.5*(B - B^T)
                self.ph_B = nn.Parameter(torch.zeros(D, D))
                # R = P_perp (C^T C) P_perp, where P_perp = I - U U^T
                self.ph_C = nn.Parameter(torch.zeros(D, D))
                # P = L^T L + eps I
                self.ph_L = nn.Parameter(torch.eye(D))
                self.ph_epsP = float(ph_epsP)
                # oscillatory subspace U from QR of W[:, :d]
                d = max(0, min(int(ph_osc_dims), D))
                self.ph_d = d
                # At least one column is allocated so the parameter is never empty, even for d == 0.
                self.ph_W = nn.Parameter(torch.randn(D, max(d, 1)))
                # One raw frequency per 2x2 rotation block (at least one entry is allocated).
                k = max(1, d // 2) if d > 0 else 1
                self.ph_omega_raw = nn.Parameter(torch.zeros(k))  # learnable raw
                self.ph_min_freq = float(ph_min_freq)
                # (J, R, P) cache, switched on only for the duration of a forward() call.
                self._ph_cache_active: bool = False
                self._ph_cached = None

        if self.latent_type == "linear+memory":
            # Nonlinear maps between alpha and memory spaces
            enc_hidden_dim = enc_hidden_dim or max(32, min(self.latent_dim, 256))
            dec_hidden_dim = dec_hidden_dim or max(32, min(self.latent_dim, 256))
            # phi_enc: R^{D_alpha} -> R^{D_m}
            self.memory_encoder = MLP(in_dim=self.latent_dim, hidden_dim=enc_hidden_dim, out_dim=self.memory_dim,
                                      num_layers=max(1, enc_layers), nl=nl,
                                      last_zero_init=False, use_layernorm=True, norm_where="pre")
            # phi_dec: R^{D_m} -> R^{D_alpha}
            # (last_zero_init=True presumably zeroes the output layer so the memory
            #  feedback starts at 0 — TODO confirm against MLP)
            self.memory_decoder = MLP(in_dim=self.memory_dim, hidden_dim=dec_hidden_dim, out_dim=self.latent_dim,
                                      num_layers=max(1, dec_layers), nl=nl,
                                      last_zero_init=True, use_layernorm=True, norm_where="pre")
            # Positive diagonal Lambda for memory decay/stability: lambda = softplus(raw) + eps
            self._raw_lambda = nn.Parameter(torch.zeros(self.memory_dim))
            self._lambda_eps = 1e-5

        # ===================== GRU-memory ODE =====================
        if self.latent_type == "gru_memory":
            dec_hidden_dim = dec_hidden_dim or max(32, min(self.latent_dim, 256))
            self.memory_decoder = MLP(
                in_dim=self.memory_dim, hidden_dim=dec_hidden_dim, out_dim=self.latent_dim,
                num_layers=max(1, dec_layers), nl=nl,
                last_zero_init=True, use_layernorm=True, norm_where="pre"
            )
            # Gate networks see the concatenation [alpha, m].
            gate_in = self.latent_dim + self.memory_dim
            # gates/candidate produce D_m each (we apply sigmoid/tanh in RHS)
            self.gru_r = MLP(in_dim=gate_in, hidden_dim=max(32, self.memory_dim), out_dim=self.memory_dim,
                              num_layers=max(1, enc_layers), nl=nl)
            self.gru_z = MLP(in_dim=gate_in, hidden_dim=max(32, self.memory_dim), out_dim=self.memory_dim,
                              num_layers=max(1, enc_layers), nl=nl)
            self.gru_h = MLP(in_dim=gate_in, hidden_dim=max(32, self.memory_dim), out_dim=self.memory_dim,
                              num_layers=max(1, enc_layers), nl=nl)
            # learnable positive inverse time constant tau_m^{-1}
            self._raw_tau_m_inv = nn.Parameter(torch.zeros(self.memory_dim))
            self._tau_eps = 1e-5

            # Gate bias init (stability): start the update gate slightly closed
            _init_last_linear_bias(self.gru_z, -1.0)

        # ===================== LSTM ODE =====================
        if self.latent_type == "lstm_memory":
            dec_hidden_dim = dec_hidden_dim or max(32, min(self.latent_dim, 256))
            self.memory_decoder = MLP(
                in_dim=self.memory_dim, hidden_dim=dec_hidden_dim, out_dim=self.latent_dim,
                num_layers=max(1, dec_layers), nl=nl,
                last_zero_init=True, use_layernorm=True, norm_where="pre"
            )
            # Gate networks see the concatenation [alpha, m].
            gate_in = self.latent_dim + self.memory_dim
            # i,f,o,c tilde heads (apply sigmoid/tanh in RHS)
            self.lstm_i = MLP(in_dim=gate_in, hidden_dim=max(32, self.memory_dim), out_dim=self.memory_dim,
                              num_layers=max(1, enc_layers), nl=nl)
            self.lstm_f = MLP(in_dim=gate_in, hidden_dim=max(32, self.memory_dim), out_dim=self.memory_dim,
                              num_layers=max(1, enc_layers), nl=nl)
            self.lstm_o = MLP(in_dim=gate_in, hidden_dim=max(32, self.memory_dim), out_dim=self.memory_dim,
                              num_layers=max(1, enc_layers), nl=nl)
            self.lstm_cand = MLP(in_dim=gate_in, hidden_dim=max(32, self.memory_dim), out_dim=self.memory_dim,
                                  num_layers=max(1, enc_layers), nl=nl)
            # learnable positive inverse time constants
            self._raw_tau_c_inv = nn.Parameter(torch.zeros(self.memory_dim))
            self._raw_tau_h_inv = nn.Parameter(torch.zeros(self.memory_dim))
            self._tau_eps = 1e-5

            # Gate bias init (stability):
            # Forget gate slightly open; Input gate slightly closed
            _init_last_linear_bias(self.lstm_f, +1.0)
            _init_last_linear_bias(self.lstm_i, -1.0)

        # --------------------------------------------------
        if self.latent_type == "neural_ode":
            assert ode_func is not None, "`neural_ode` requires `ode_func` defined on alpha."
            self.ode_func = ode_func

        # Raw per-dimension logits of the memory-feedback gate (sigmoid(0) = 0.5 at init).
        # Created for every latent_type, including the ones without memory.
        self._raw_mem_scale = nn.Parameter(torch.zeros(self.latent_dim))

    # -----------------------------------------------------------------------------------------
    @property
    def mem_scale(self) -> torch.Tensor:
        return torch.sigmoid(self._raw_mem_scale)


    def init_linear_from_Ad(self, Ad_col: torch.Tensor, dt: float,
                            clip_positive_symmetric: bool = False,
                            max_pos_real: float = 0.0):
        """
        Initialize 'free' linear generator from discrete one-step column-form Ad.

        Notes:
        - Your Phase-I code used Zp_hat = Z0 @ A_use.T, so A_use is column-form Ad.
        - This function expects that same column-form Ad.
        """
        if self.linear_param != "free":
            raise RuntimeError("init_linear_from_Ad only supports linear_param='free'")

        # (1) compute continuous generator Ac = log(Ad)/dt with robust routine
        Ac = discrete_to_continuous_via_logm(Ad_col, dt)

        # (2) write into parameter
        with torch.no_grad():
            self.linear.data.copy_(Ac.to(self.linear))

        # (3) optional: clip symmetric growth part to avoid positive real drift
        if clip_positive_symmetric:
            with torch.no_grad():
                A = self.linear.data
                # symmetric part S = (A + A^T)/2
                S = 0.5 * (A + A.T)
                # project positive eigenvalues down to max_pos_real (<= 0 recommended)
                evals, Q = torch.linalg.eigh(S)
                evals_clipped = torch.minimum(evals, torch.tensor(max_pos_real, device=A.device, dtype=A.dtype))
                S_clipped = (Q @ torch.diag(evals_clipped) @ Q.T)
                # replace symmetric part -> A_new = (A - S) + S_clipped
                self.linear.data = (A - S) + S_clipped


    # construct building blocks for pH_dense
    # A = (J - R) * P, J=−J^T, R⪰0, P≻0.
    def _ph_dense_build_terms(self):
        if getattr(self, "_ph_cache_active", False) and (self._ph_cached is not None):
            return self._ph_cached  # (J, R, P)

        D = self.latent_dim
        device = self.ph_B.device
        dtype = self.ph_B.dtype
        # U, P_perp
        if self.ph_d > 0:
            W = self.ph_W[:, :self.ph_d]
            U, _ = torch.linalg.qr(W, mode='reduced')  # [D, d]
            P_perp = torch.eye(D, device=device, dtype=dtype) - U @ U.transpose(0, 1)
        else:
            U = None
            P_perp = torch.eye(D, device=device, dtype=dtype)

        P_perp = 0.5 * (P_perp + P_perp.T)
        # R = P_perp (C^T C) P_perp
        CtC = self.ph_C.transpose(0, 1) @ self.ph_C
        R = P_perp @ CtC @ P_perp

        # P = L^T L + eps I
        P = self.ph_L.transpose(0, 1) @ self.ph_L + self.ph_epsP * torch.eye(D, device=device, dtype=dtype)

        # J = U Ω U^T + P_perp * skew(B) * P_perp
        Js = 0.5 * (self.ph_B - self.ph_B.transpose(0, 1))
        if (U is not None) and (self.ph_d >= 2):
            k = max(1, self.ph_d // 2)
            omega = F.softplus(self.ph_omega_raw) + self.ph_min_freq  # [k] 下界
            Omega = torch.zeros(self.ph_d, self.ph_d, device=device, dtype=dtype)
            for i in range(k):
                i0, i1 = 2 * i, 2 * i + 1
                if i1 >= self.ph_d:
                    break
                w = omega[i]
                Omega[i0, i1] = -w
                Omega[i1, i0] =  w
            J_rot = U @ Omega @ U.transpose(0, 1)
        else:
            J_rot = torch.zeros(D, D, device=device, dtype=dtype)
        J_perp = P_perp @ Js @ P_perp
        J = J_rot + J_perp

        if getattr(self, "_ph_cache_active", False):
            self._ph_cached = (J, R, P)
        return J, R, P

    def _prepare_ph_cache(self):
        if self.linear_param == "pH_dense":
            self._ph_cache_active = True
            self._ph_cached = None
            _ = self._ph_dense_build_terms()

    def _clear_ph_cache(self):
        if self.linear_param == "pH_dense":
            self._ph_cache_active = False
            self._ph_cached = None

    # unified linear application alpha -> alpha A^T = alpha P (-J - R)
    def _apply_linear(self, alpha: torch.Tensor) -> torch.Tensor:
        if self.linear_param == "free":
            return alpha @ self.linear.T
        elif self.linear_param == "pH_dense":
            J, R, P = self._ph_dense_build_terms()
            alphaP = alpha @ P           # IMPORTANT: left-multiply by P first
            return alphaP @ (-J - R)     # then right-multiply by (-J - R)
        else:
            # fallback (should not happen)
            return alpha

    # ---------------- ODE right-hand-sides ----------------
    def _f_linear(self, t: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
        """
        RHS for the pure linear case: d alpha / dt = A alpha
        alpha: [B, D_alpha], A: [D_alpha, D_alpha]
        """
        return self._apply_linear(alpha)

    @property
    def lambda_diag(self) -> torch.Tensor:
        """Positive diagonal vector for memory decay term."""
        if not hasattr(self, "_raw_lambda"):
            return None
        return F.softplus(self._raw_lambda) + self._lambda_eps

    def init_memory_time_constant(self, tau_in_steps: float, dt: float):
        lam0 = 1.0 / (tau_in_steps * dt)
        raw0 = float(np.log(np.expm1(lam0)))
        with torch.no_grad():
            self._raw_lambda.fill_(raw0)

    def _f_linear_memory(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Coupled RHS for [alpha, m] in the linear+memory case.

        z: [B, D_alpha + D_m]
        Returns: [B, D_alpha + D_m]
        """
        D_alpha = self.latent_dim
        alpha, m = z[..., :D_alpha], z[..., D_alpha:]
        # d alpha / dt = A alpha + phi_dec(m)
        alpha_dot = self._apply_linear(alpha) + self.mem_scale * self.memory_decoder(m)
        # alpha_dot = self._apply_linear(alpha) + self.memory_decoder(m)
        # d m / dt = phi_enc(alpha) - Lambda * m
        lam = self.lambda_diag  # shape: [D_m]
        m_dot = self.memory_encoder(alpha) - lam * m
        return torch.cat([alpha_dot, m_dot], dim=-1)

    # GRU-memory RHS
    @property
    def tau_m_inv(self) -> torch.Tensor:
        return (F.softplus(self._raw_tau_m_inv) + self._tau_eps).clamp(max=10.0)

    def _f_gru_memory(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        D_alpha = self.latent_dim
        alpha, m = z[..., :D_alpha], z[..., D_alpha:]
        concat = torch.cat([alpha, m], dim=-1)
        r = torch.sigmoid(self.gru_r(concat))              # [B, D_m]
        z_gate = torch.sigmoid(self.gru_z(concat))         # [B, D_m]
        concat_h = torch.cat([alpha, r * m], dim=-1)
        m_tilde = torch.tanh(self.gru_h(concat_h))         # [B, D_m]
        m_dot = self.tau_m_inv * (z_gate * (m_tilde - m))  # relaxation to target
        alpha_dot = self._apply_linear(alpha) + self.mem_scale * self.memory_decoder(m)
        # alpha_dot = self._apply_linear(alpha) + self.memory_decoder(m)
        return torch.cat([alpha_dot, m_dot], dim=-1)

    # LSTM-memory RHS
    @property
    def tau_c_inv(self) -> torch.Tensor:
        return (F.softplus(self._raw_tau_c_inv) + self._tau_eps).clamp(max=10.0)

    @property
    def tau_h_inv(self) -> torch.Tensor:
        return (F.softplus(self._raw_tau_h_inv) + self._tau_eps).clamp(max=10.0)

    def _f_lstm_memory(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        D_alpha = self.latent_dim
        D_m = self.memory_dim
        alpha = z[..., :D_alpha]
        m = z[..., D_alpha:D_alpha + D_m]   # hidden
        c = z[..., D_alpha + D_m:]
        concat = torch.cat([alpha, m], dim=-1)
        i = torch.sigmoid(self.lstm_i(concat))
        f = torch.sigmoid(self.lstm_f(concat))
        o = torch.sigmoid(self.lstm_o(concat))
        c_tilde = torch.tanh(self.lstm_cand(concat))
        c_dot = self.tau_c_inv * (f * c + i * c_tilde - c)
        m_star = o * torch.tanh(c)
        m_dot = self.tau_h_inv * (m_star - m)
        alpha_dot = self._apply_linear(alpha) + self.mem_scale * self.memory_decoder(m)
        # alpha_dot = self._apply_linear(alpha) + self.memory_decoder(m)
        return torch.cat([alpha_dot, m_dot, c_dot], dim=-1)
    
    # --------------------------------------------
    def _rhs_parameters(self):
        ps = []
        if self.latent_type in {"linear", "linear+memory", "gru_memory", "lstm_memory"}:
            if self.linear_param == "free":
                ps.append(self.linear)
            elif self.linear_param == "pH_dense":
                ps += [self.ph_B, self.ph_C, self.ph_L, self.ph_W, self.ph_omega_raw]

        if self.latent_type == "linear+memory":
            ps += list(self.memory_encoder.parameters())
            ps += list(self.memory_decoder.parameters())
            ps.append(self._raw_lambda)

        elif self.latent_type == "gru_memory":
            ps += list(self.memory_decoder.parameters())
            ps += list(self.gru_r.parameters()) + list(self.gru_z.parameters()) + list(self.gru_h.parameters())
            ps.append(self._raw_tau_m_inv)

        elif self.latent_type == "lstm_memory":
            ps += list(self.memory_decoder.parameters())
            ps += (list(self.lstm_i.parameters()) + list(self.lstm_f.parameters()) +
                list(self.lstm_o.parameters()) + list(self.lstm_cand.parameters()))
            ps += [self._raw_tau_c_inv, self._raw_tau_h_inv]

        elif self.latent_type == "neural_ode":
            ps += list(self.ode_func.parameters())

        ps.append(self._raw_mem_scale)
        return tuple(ps)
    
    def _odeint(self, func, y0, t_eval):
        """Integrate `func` from `y0` over `t_eval` with the adjoint solver.

        Uses `self.solver` as the integration method and passes the tuple from
        `_rhs_parameters()` as the adjoint parameters.
        """
        adjoint_params = self._rhs_parameters()
        return odeint(func, y0, t_eval, method=self.solver, adjoint_params=adjoint_params)

    def _odeint_with_tf(self, func, tf_alpha, t_eval, tf_epsilon: float, tf_mask: torch.Tensor | None,
                        use_memory: bool = False, aug_init: torch.Tensor | None = None,
                        tf_detach_alpha_starts: bool = True,
                        detach_memory_between_segments: bool = False):
        # tf_alpha: [T, B, D_alpha] (ground truth alpha sequence)
        # aug_init: [B, D_m], only provided when using memory augmentation
        T = t_eval.numel()
        mask = self._build_tf_mask(T, epsilon=tf_epsilon, tf_mask=tf_mask, device=tf_alpha.device)
        if T == 1:
            y0_alpha = tf_alpha[0].detach()
            y0 = torch.cat([y0_alpha, aug_init], dim=-1) if use_memory else y0_alpha
            seg = self._odeint(func, y0, t_eval[:1])  # [1, B, D]
            return seg
        segs, start = [], 0
        aug_curr = aug_init if use_memory else None   # current value for the augmented state
        while start < T - 1:
            ends = torch.where(mask[start:T-1])[0]
            end = start + int(ends[0].item()) + 1 if len(ends) > 0 else T - 1
            t_seg = t_eval[start: end+1]
            y0_alpha = tf_alpha[start] if not tf_detach_alpha_starts else tf_alpha[start].detach()
            y0 = torch.cat([y0_alpha, aug_curr], dim=-1) if use_memory else y0_alpha
            seg = self._odeint(func, y0, t_seg)
            segs.append(seg if len(segs)==0 else seg[1:])
            start = end
            """aug_curr = seg[-1, :, self.latent_dim:] if use_memory else None"""
            if use_memory:
                aug_curr = seg[-1, :, self.latent_dim:]   # 传递 memory
                if detach_memory_between_segments:
                    aug_curr = aug_curr.detach() 
        alpha_t = torch.cat(segs, dim=0)
        return alpha_t

    # --------------------------------------------
    # build teacher-forcing mask over intervals (length T-1)
    def _build_tf_mask(self, T: int, epsilon: float, tf_mask: torch.Tensor | None, device: torch.device) -> torch.Tensor:
        """Return boolean mask of length T-1; True means *end* this segment at that index.
        The last interval is always False to keep the final step connected.
        """
        if T <= 1:
            return torch.zeros(0, dtype=torch.bool, device=device)
        if tf_mask is not None:
            assert tf_mask.ndim == 1 and tf_mask.numel() == T - 1, "tf_mask must have length T-1"
            mask = tf_mask.to(device=device, dtype=torch.bool).clone()
        else:
            if epsilon <= 1e-8:
                mask = torch.zeros(T - 1, dtype=torch.bool, device=device)
            else:
                mask = torch.rand(T - 1, device=device) < float(epsilon)
        mask[-1] = False
        return mask

    # ---------------- Public API ----------------
    def forward(
        self,
        alpha_0: torch.Tensor,
        t_eval: torch.Tensor,
        memory_init: torch.Tensor | None = None,
        # teacher forcing controls
        teacher_forcing: bool = False,
        tf_alpha: torch.Tensor | None = None,   # shape [T, B, D_alpha]
        tf_epsilon: float = 0.0,
        tf_mask: torch.Tensor | None = None,
        tf_detach_alpha_starts: bool = True,    #######################################
        detach_memory_between_segments: bool = False ##################################
    ):
        """
        Integrate the latent dynamics over t_eval.

        Args
        ----
        alpha_0    : [B, D_alpha]  initial latent state
        t_eval     : [T]           time grid (monotonic recommended)
        memory_init: [B, D_m] or None. Used for "linear+memory" / 🟦"gru_memory" / 🟦"lstm_memory".
                     For LSTM, this is the initial hidden h0; the cell c0 is initialized by the
                     classic convention (zeros) if not provided.
        ---------------------------------------------------------------------------------------------
        teacher_forcing: if True, segment the integration and reset the *alpha* initial state of each
                         segment to the provided ground-truth `tf_alpha` at that segment's start.
                         Memory states (m or h/c) are NOT teacher-forced; they carry over from the
                         previous segment end.
        tf_alpha   : [T, B, D_alpha] ground-truth alpha trajectory aligned with `t_eval`.
        tf_epsilon : probability for cutting each interval (Bernoulli on T-1 intervals).
        tf_mask    : optional boolean mask of length T-1. If provided, overrides `tf_epsilon`.

        Returns
        -------
        alpha_t    : [T, B, D_alpha]
        memory_t   : [T, B, D_m] or None
        """
        assert alpha_0.dim() == 2 and alpha_0.size(-1) == self.latent_dim, \
            f"alpha_0 must be [B, {self.latent_dim}]"
        assert t_eval.dim() == 1, "t_eval must be 1D [T]"
        if teacher_forcing:
            assert tf_alpha is not None and tf_alpha.dim() == 3 and tf_alpha.size(0) == t_eval.numel(), \
                "tf_alpha must be [T,B,D_alpha] aligned with t_eval"
            tf_alpha = tf_alpha.to(device=alpha_0.device, dtype=alpha_0.dtype)
        t_eval = t_eval.to(device=alpha_0.device, dtype=alpha_0.dtype)

        self._prepare_ph_cache()
        aux = {}

        try:
            # neural_ode
            if self.latent_type == "neural_ode":
                if teacher_forcing:
                    sol = self._odeint_with_tf(self.ode_func, tf_alpha, t_eval, tf_epsilon, tf_mask,
                                               tf_detach_alpha_starts=tf_detach_alpha_starts, detach_memory_between_segments=detach_memory_between_segments)
                else:
                    sol = self._odeint(self.ode_func, alpha_0, t_eval)  # [T, B, D_alpha]
                return (sol, None, {}) if self.return_aux else (sol, None)

            # linear: integrate alpha with linear dynamics
            if self.latent_type == "linear":
                if teacher_forcing:
                    sol = self._odeint_with_tf(self._f_linear, tf_alpha, t_eval, tf_epsilon, tf_mask,
                                               tf_detach_alpha_starts=tf_detach_alpha_starts, detach_memory_between_segments=detach_memory_between_segments)
                else:
                    sol = self._odeint(self._f_linear, alpha_0, t_eval)  # [T, B, D_alpha]
                return (sol, None, {}) if self.return_aux else (sol, None)

            B = alpha_0.size(0)
            device = alpha_0.device
            dtype = alpha_0.dtype

            # ===================== linear+memory =====================
            if self.latent_type == "linear+memory":
                if memory_init is None:
                    memory_init = torch.zeros(B, self.memory_dim, dtype=dtype, device=device)
                else:
                    assert memory_init.shape == (B, self.memory_dim), \
                        f"memory_init must be [B, {self.memory_dim}]"
                z0 = torch.cat([alpha_0, memory_init], dim=-1)  # [B, D_alpha + D_m]
                if teacher_forcing:
                    zt = self._odeint_with_tf(self._f_linear_memory, tf_alpha, t_eval, tf_epsilon, tf_mask,
                                            use_memory=True, aug_init=memory_init,
                                            tf_detach_alpha_starts=tf_detach_alpha_starts, detach_memory_between_segments=detach_memory_between_segments)
                else:
                    zt = self._odeint(self._f_linear_memory, z0, t_eval)  # [T, B, D_alpha + D_m]
                alpha_t = zt[..., :self.latent_dim]
                memory_t = zt[..., self.latent_dim:]
                if self.return_aux:
                    dec = self.memory_decoder(memory_t.reshape(-1, self.memory_dim))
                    aux["phi_dec_l2"] = (dec.pow(2).sum(dim=-1)).mean()

                return (alpha_t, memory_t, aux) if self.return_aux else (alpha_t, memory_t)

            # ===================== GRU-memory =====================
            if self.latent_type == "gru_memory":
                if memory_init is None:
                    memory_init = torch.zeros(B, self.memory_dim, dtype=dtype, device=device)
                else:
                    assert memory_init.shape == (B, self.memory_dim), \
                        f"memory_init must be [B, {self.memory_dim}]"
                z0 = torch.cat([alpha_0, memory_init], dim=-1)  # [B, D_alpha + D_m]
                if teacher_forcing:
                    zt = self._odeint_with_tf(self._f_gru_memory, tf_alpha, t_eval, tf_epsilon, tf_mask,
                                            use_memory=True, aug_init=memory_init,
                                            tf_detach_alpha_starts=tf_detach_alpha_starts, detach_memory_between_segments=detach_memory_between_segments)
                else:
                    zt = self._odeint(self._f_gru_memory, z0, t_eval)  # [T, B, D_alpha + D_m]
                alpha_t = zt[..., :self.latent_dim]
                memory_t = zt[..., self.latent_dim:]
                if self.return_aux:
                    dec = self.memory_decoder(memory_t.reshape(-1, self.memory_dim))
                    aux["phi_dec_l2"] = (dec.pow(2).sum(dim=-1)).mean()

                return (alpha_t, memory_t, aux) if self.return_aux else (alpha_t, memory_t)

            # ===================== LSTM-memory =====================
            if self.latent_type == "lstm_memory":
                # only memory_init (hidden h0) is provided; c0 follows the classic convention (zeros)
                if memory_init is None:
                    h0 = torch.zeros(B, self.memory_dim, dtype=dtype, device=device)
                else:
                    assert memory_init.shape == (B, self.memory_dim), \
                        f"memory_init must be [B, {self.memory_dim}]"
                    h0 = memory_init
                c0 = torch.zeros(B, self.memory_dim, dtype=dtype, device=device)  # classic LSTM init
                z0 = torch.cat([alpha_0, h0, c0], dim=-1)                          # [B, D_alpha + 2*D_m]
                if teacher_forcing:
                    zt = self._odeint_with_tf(self._f_lstm_memory, tf_alpha, t_eval, tf_epsilon, tf_mask,
                                            use_memory=True, aug_init=z0[:, self.latent_dim:],
                                            tf_detach_alpha_starts=tf_detach_alpha_starts, detach_memory_between_segments=detach_memory_between_segments)
                else:
                    zt = self._odeint(self._f_lstm_memory, z0, t_eval)   # [T, B, D_alpha + 2*D_m]
                alpha_t = zt[..., :self.latent_dim]
                memory_t = zt[..., self.latent_dim:self.latent_dim + self.memory_dim]  # return hidden only
                if self.return_aux:
                    dec = self.memory_decoder(memory_t.reshape(-1, self.memory_dim))
                    aux["phi_dec_l2"] = (dec.pow(2).sum(dim=-1)).mean()

                return (alpha_t, memory_t, aux) if self.return_aux else (alpha_t, memory_t)

            raise RuntimeError("Unknown latent_type")
        finally:
            self._clear_ph_cache()

    # ====================== NEW: build continuous-time generator A (column-vector convention) ======================
    def _continuous_generator_A(self) -> torch.Tensor:
        """
        Return the continuous-time generator A such that, in the *column-vector* convention:
            x_dot = A @ x
        Your implementation in this class uses *row*-vector style in forward calls, e.g.
            alpha_dot = alpha @ self.linear.T         (free)
            alpha_dot = (alpha @ P) @ (-J - R)        (pH_dense)
        which corresponds to column A being:
            A_free    = self.linear
            A_pHdense = (J - R) @ P
        because [alpha @ (P(-J-R))]^T = (J-R)P alpha^T.
        """
        if self.latent_type not in {"linear", "linear+memory", "gru_memory", "lstm_memory"}:
            raise RuntimeError("continuous generator A is only defined for linear-based modes")

        if self.linear_param == "free":
            # alpha_dot = alpha @ self.linear.T   =>   x_dot = self.linear @ x
            return self.linear
        elif self.linear_param == "pH_dense":
            J, R, P = self._ph_dense_build_terms()
            # Row form uses alpha_dot = (alpha @ P) @ (-J - R)
            # Column form is: x_dot = (J - R) @ P @ x
            return (J - R) @ P
        else:
            raise RuntimeError(f"Unknown linear_param: {self.linear_param}")

    # ====================== NEW: spectral regularizer (covers both 'free' and 'pH_dense') ======================
    def spectral_regularizer(self) -> torch.Tensor:
        """
        Scalar penalty on the linear part of the latent dynamics.

        - 'free':     take A from _continuous_generator_A(), form the symmetric
                      part S = (A + A^T)/2 and return mean(softplus(eig(S))^2),
                      a smooth penalty that grows with positive eigenvalues of S.
        - 'pH_dense': return the mean squared entry of P @ (-J - R), a scale
                      penalty on the factors from _ph_dense_build_terms().

        For a non-linear-based ``latent_type`` or an unrecognised ``linear_param``
        the result is a zero scalar on the device of ``self._raw_mem_scale``.
        """
        def _zero() -> torch.Tensor:
            # Zero penalty, created on the same device as the module's parameters.
            return torch.tensor(0.0, device=self._raw_mem_scale.device)

        if self.latent_type not in {"linear", "linear+memory", "gru_memory", "lstm_memory"}:
            return _zero()

        if self.linear_param == "free":
            gen = self._continuous_generator_A()          # [D, D]
            sym_part = 0.5 * (gen + gen.T)
            # eigvalsh is valid here because sym_part is symmetric by construction.
            spectrum = torch.linalg.eigvalsh(sym_part)
            soft_positive = F.softplus(spectrum)
            return soft_positive.pow(2).mean()

        if self.linear_param == "pH_dense":
            J_mat, R_mat, P_mat = self._ph_dense_build_terms()
            row_form = P_mat @ (-J_mat - R_mat)
            return (row_form * row_form).mean()

        return _zero()

    # ====================== NEW: structural diagnostics ======================
    @torch.no_grad()
    def _structural_diagnostic(self, t_eval: torch.Tensor | None = None, num_energy_samples: int = 16) -> dict:
        """
        Compute stability-related diagnostics WITHOUT changing any state.
        Returns a dict with:
          - max_real_eigA:   max Re(eig(A)) for continuous-time generator (column convention)
          - Ad_spectral_rad: spectral radius of exp(A*dt) if dt available
          - exp_norm2:       ||exp(A*dt)||_2 (cheap proxy via 2-norm) if dt available
          - ph_energy_term:  mean(x^T P R P x) over random samples (pH_dense only; should be >= 0)
        """
        out = {}
        if self.latent_type not in {"linear", "linear+memory", "gru_memory", "lstm_memory"}:
            return out

        try:
            A = self._continuous_generator_A()  # [D, D]
        except Exception:
            return out

        # 1) continuous-time max real part
        try:
            eigA = torch.linalg.eigvals(A)
            out["max_real_eigA"] = float(eigA.real.max().item())
        except Exception:
            pass

        # 2) one-step discrete behavior if dt provided
        if t_eval is not None and t_eval.numel() >= 2:
            try:
                dt = float((t_eval[1] - t_eval[0]).item())
                Ad = torch.matrix_exp(A * dt)
                # spectral radius (max |lambda| of Ad)
                evals_Ad = torch.linalg.eigvals(Ad)
                out["Ad_spectral_rad"] = float(evals_Ad.abs().max().item())
                # cheap 2-norm proxy
                out["exp_norm2"] = float(torch.linalg.matrix_norm(Ad, ord=2).item())
            except Exception:
                pass

        # 3) port-Hamiltonian energy dissipation term (only meaningful for pH_dense)
        if self.linear_param == "pH_dense":
            try:
                J, R, P = self._ph_dense_build_terms()
                D = A.size(0)
                # sample standard normals x ~ N(0, I)
                x = torch.randn(num_energy_samples, D, device=A.device, dtype=A.dtype)
                Px = x @ P.T
                # (P x)^T R (P x) for each sample
                diss = torch.einsum("bi,ij,bj->b", Px, R, Px)  # shape [B]
                out["ph_energy_term"] = float(diss.mean().item())  # should be >= 0; energy derivative is -that
            except Exception:
                pass

        return out


























"""

######################
# A CDE model looks like
#
# z_t = z_0 + \int_0^t f_\theta(z_s) dX_s
######################
class LatentCDEfunc(nn.Module):
    
    # input_dim: dimension for the input (control) sequence
    
    def __init__(self, input_dim, state_dim, code_dim, hidden_dim, num_layers=3, nl='swish', **kwargs):
        super(LatentCDEfunc, self).__init__()
        self.latent_dim = state_dim * code_dim
        self.input_dim = input_dim

        self.net = MLP(
            in_dim=self.latent_dim,
            hidden_dim=hidden_dim,
            out_dim=self.latent_dim*input_dim,
            num_layers=num_layers,
            nl=nl
        )    # latent_dim --> latent_dim * input_dim (f(z_t)dX_t)
    

    def forward(self, t, z):
        # shape of z: [B, latent_dim]
        z = self.net(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimension, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_dim to R^latent_dim.
        ######################
        z = z.view(z.size(0), self.latent_dim, self.input_dim)    # shape: [B, latent_dim, input_dim]
        return z



class LatentProcess(nn.Module):
    def __init__(
        self, state_dim, code_dim, input_dim=None,
        method="ode",                # 'ode' / 'rnn' / 'cde'
        rnn_type="lstm",             # 'rnn' / 'gru' / 'lstm' (discrete sequence model)
        rnn_num_layers=1,
        ode_func=None, cde_func=None, solver="rk4"
    ):
        super().__init__()
        self.method    = method
        self.rnn_type  = rnn_type
        self.rnn_num_layers = rnn_num_layers
        self.latent_dim = state_dim * code_dim
        self.input_dim = input_dim

        if method == 'ode':
            self.solver   = solver
            self.ode_func = ode_func
        elif method == 'cde':
            self.solver = solver
            self.cde_func = cde_func
        else:
            rnn_cls = {
                'rnn': nn.RNN,
                'gru': nn.GRU,
                'lstm': nn.LSTM,
            }[self.rnn_type]
            self.rnn = rnn_cls(
                input_size=self.latent_dim,
                hidden_size=self.latent_dim,
                num_layers=self.rnn_num_layers,
                batch_first=True
            )
            self.rnn_input_poj = nn.Sequential(
                nn.Linear(input_dim, self.latent_dim),
                nn.ReLU(),
                nn.LayerNorm(self.latent_dim)
            )
            self.residual = nn.Linear(self.latent_dim, self.latent_dim)
            self.readout = nn.Sequential(
                nn.LayerNorm(self.latent_dim),
                nn.Linear(self.latent_dim, self.latent_dim),
                nn.ReLU()
            )


    def forward(self, alpha_0: torch.Tensor, t_eval: torch.Tensor, control_seq: None, control_ts: None):
        # ---------- Continuous‑time ode -------------
        if self.method == "ode":
            return odeint(self.ode_func, alpha_0, t_eval, method=self.solver)
        
        # ---------- Continuous‑time cde -------------
        elif self.method == "cde":
            B, T, input_dim = control_seq.shape
            control_ts = control_ts.to(control_seq.device)
            # Interpolate the control seq
            coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(control_seq, t=control_ts)
            X = torchcde.CubicSpline(coeffs, t=control_ts)    # X is input control sequence
            
            adjoint_params = tuple(self.cde_func.parameters()) + (coeffs,)    # ????????????????????????????????
            out = torchcde.cdeint(X=X, func=self.cde_func, z0=alpha_0, t=t_eval, adjoint=True,
                                   adjoint_params=adjoint_params, method=self.solver)
            return out.permute(1, 0, 2)

        # ---------- Discrete‑time ---------------
        elif self.method == "rnn":
            h0 = alpha_0.unsqueeze(0).repeat(self.rnn_num_layers, 1, 1)  # [num_layers, B, latent_dim]
            if self.rnn_type == 'lstm':
                c0 = torch.zeros_like(h0)
                control_seq = self.rnn_input_poj(control_seq)
                out, (hn, cn) = self.rnn(control_seq, (h0, c0))  # out: [B, T, latent_dim]
                out = out + self.residual(out)
            else:
                out, hn = self.rnn(control_seq, h0)
                out = out + self.residual(out)
            out = out.permute(1, 0, 2)
            out = self.readout(out)
            return out    # [T', B, latent_dim]
    
        # TODO: implement teacher forcing"""