{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1qpinkvZIN4cHcKUcUZ5GGIiurW0lxPCw","timestamp":1769481076563}],"machine_shape":"hm","gpuType":"A100"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","source":["# -*- coding: utf-8 -*-\n","\"\"\"SDESym_Ex5_V2_SS.ipynb\n","\n","Automatically generated by Colab.\n","\n","Original file is located at\n","    https://colab.research.google.com/drive/1ueEHbS720lq2sAA4m_1BQVUV_jmuuUFQ\n","\n","# Work Logs\n","- 1.7 5:13pm: created notebook to change from Ex1 for Ex5\n","- 1.9 2:12am: now FP works, and SDE supposedly works but haven't checked yet.\n","\n","# Ground-truth Generators\n","\n","All FP symmetry generators:\n","- $v_1 = \\partial_t$\n","- $v_2 = u\\partial_u$\n","- $v_3 = \\partial_y$\n","- $v_4 = 2t\\partial_t + x\\partial_x + (y+a_2t)\\partial_y - 2u\\partial_u$\n","- $v_5 = -t\\partial_y + (y-a_2t)u\\partial_u$\n","- $v_6 = t[t\\partial_t+x\\partial_x+y\\partial_y]+[t(a_1+a_2y-1)-\\frac{1}{2}(x^2+y^2+a_2^2t^2)]u\\partial_u$\n","\n","SDE symmetry generators: $v_1, v_3, v_4$.\n","\n","# SDE Symmetry\n","\n","##Imports & Downloads\n","\"\"\"\n","\n","!pip install dm-haiku\n","\n","# Install (if needed)\n","# pip install --quiet jax jaxlib optax dm-haiku\n","\n","import math\n","import functools\n","import numpy as np\n","import jax\n","import jax.numpy as jnp\n","import haiku as hk\n","import optax\n","\n","jax.config.update(\"jax_enable_x64\", True)\n","\n","import itertools\n","from functools import partial\n","\n","import matplotlib.pyplot as plt\n","\n","Array = jnp.ndarray\n","\n","print(jnp.array([0.]).dtype)   # should print float64\n","\n","# @title Neural SDE surrogate — Example 5 (Section 6) — NO PRIORS (drift+diffusion), alternating training (IMPROVED)\n","# Learns f(x,y) and diffusion Σ(x,y) from increments, without using the analytic forms.\n","#\n","# True data generator (kept 
separate): dx = (a1/x) dt + dW1, dy = a2 dt + dW2\n","# Training model: ΔX ~ Normal( f(x,y) dt,  Σ(x,y) dt ), with Σ = L L^T (Cholesky, PSD)\n","#\n","# Improvements added (still \"no priors\"):\n","#   A) Replace finite-difference diffusion smoothness with exact JVP (Jacobian-vector product) penalty\n","#   B) Add whitening / calibration penalty: z = L_dt^{-1}(ΔX - f dt) should be ~ N(0,I)\n","#   C) Remove weight decay (often biases σ downward); keep optional tiny weight decay as 0 by default\n","#   D) Multiple drift steps per diffusion step + slightly smaller diffusion LR\n","#\n","# NOTE: This remains increment-likelihood training; no analytic drift/diffusion forms are used.\n","\n","import math\n","from dataclasses import dataclass\n","\n","import numpy as np\n","import jax\n","import jax.numpy as jnp\n","import matplotlib.pyplot as plt\n","import optax\n","\n","jax.config.update(\"jax_enable_x64\", True)\n","\n","# -------------------------------------------------------------------\n","# 0) Config\n","# -------------------------------------------------------------------\n","@dataclass\n","class CFG:\n","    # --- data ---\n","    T: float = 2.0\n","    dt: float = 0.01\n","    n_traj: int = 2048\n","\n","    # initial distribution (keep x away from 0 for the generator only)\n","    x0_low: float = 0.5\n","    x0_high: float = 3.0\n","    y0_low: float = -2.0\n","    y0_high: float = 2.0\n","    x_eps: float = 0.2  # generator safety floor for |x|\n","\n","    # --- model / training ---\n","    hidden: int = 128\n","    depth: int = 3\n","\n","    steps: int = 20000\n","    batch_size: int = 8192\n","\n","    # learning rates\n","    lr_f: float = 2e-3      # drift LR\n","    lr_L: float = 1e-3      # diffusion LR (slightly smaller)\n","\n","    # schedule: more drift steps than diffusion steps per outer iter\n","    drift_steps_per_iter: int = 3\n","    diff_steps_per_iter: int = 1\n","\n","    # regularization + optimization\n","    weight_decay: float = 
0.0   # IMPORTANT: set to 0 to avoid σ bias (was 1e-6)\n","    grad_clip: float = 1.0\n","\n","    # diffusion regularization (generic, not physics):\n","    smooth_weight: float = 1e-3   # JVP-based smoothness on L(x)\n","    var_weight: float = 5e-4      # stabilize diffusion scale within minibatch\n","\n","    # NEW: whitening / calibration penalty\n","    whiten_weight: float = 5e-3   # try in [1e-3, 1e-2]\n","\n","    # numerical floors\n","    sigma_floor: float = 1e-3\n","\n","cfg = CFG()\n","key_main = jax.random.PRNGKey(0)\n","\n","# -------------------------------------------------------------------\n","# 1) TRUE DATA GENERATOR (ground truth used ONLY here + for later comparison plots)\n","# -------------------------------------------------------------------\n","truth = dict(a1=1.0, a2=0.5, sigma=1.0)\n","\n","def make_ex5_data(key, cfg: CFG, truth):\n","    \"\"\"Simulate Example 5 trajectories with an Euler-Maruyama scheme.\n","\n","    dx = (a1/x) dt + sigma dW1,  dy = a2 dt + sigma dW2,\n","    with |x| floored at cfg.x_eps (generator-only guard against the 1/x singularity).\n","\n","    Returns\n","    -------\n","    t  : (N+1,) time grid whose spacing equals cfg.dt\n","    XY : (n_traj, N+1, 2) states; XY[..., 0] = x, XY[..., 1] = y\n","    \"\"\"\n","    a1, a2, sig, dt = truth[\"a1\"], truth[\"a2\"], truth[\"sigma\"], cfg.dt\n","    # round, don't truncate: e.g. 0.3 / 0.1 == 2.9999999999999996 would silently drop a step\n","    N = int(round(cfg.T / dt))\n","    # end at N*dt (equals cfg.T up to rounding when T is a multiple of dt) so the grid spacing matches the step dt\n","    t = jnp.linspace(0.0, N * dt, N + 1)\n","\n","    kx0, ky0, kn = jax.random.split(key, 3)  # IMPORTANT: split keys\n","    x0 = jax.random.uniform(kx0, (cfg.n_traj,), minval=cfg.x0_low, maxval=cfg.x0_high, dtype=jnp.float64)\n","    y0 = jax.random.uniform(ky0, (cfg.n_traj,), minval=cfg.y0_low, maxval=cfg.y0_high, dtype=jnp.float64)\n","\n","    dW = jax.random.normal(kn, (cfg.n_traj, N, 2), dtype=jnp.float64) * math.sqrt(dt) * sig\n","\n","    def away_from_zero(v):\n","        # keep |v| >= x_eps with v's sign; v == 0 maps to +x_eps\n","        # (jnp.sign(0) == 0 would give x_safe = 0 and an infinite drift a1/0)\n","        return jnp.where(v < 0, -1.0, 1.0) * jnp.maximum(jnp.abs(v), cfg.x_eps)\n","\n","    def step(state, dWn):\n","        x, y = state[:, 0], state[:, 1]\n","        x_safe = away_from_zero(x)\n","\n","        fx = a1 / x_safe\n","        fy = a2\n","\n","        x1 = x + fx * dt + dWn[:, 0]\n","        y1 = y + fy * dt + dWn[:, 1]\n","\n","        # keep away from 0 after update too (generator only)\n","        x1 = away_from_zero(x1)\n","        state1 = jnp.stack([x1, y1], axis=-1)\n","        return state1, state1\n","\n","    state0 = jnp.stack([x0, y0], axis=-1)\n","    _, 
states = jax.lax.scan(step, state0, jnp.swapaxes(dW, 0, 1))  # (N, n_traj, 2)\n","    XY = jnp.concatenate([state0[None, :, :], states], axis=0)       # (N+1, n_traj, 2)\n","    XY = jnp.swapaxes(XY, 0, 1)                                      # (n_traj, N+1, 2)\n","    return t, XY\n","\n","t, XY = make_ex5_data(key_main, cfg, truth)\n","print(\"Data shapes: t =\", t.shape, \", XY =\", XY.shape)\n","\n","# quick plot x(t) for a few trajectories\n","t_np = np.asarray(t)\n","XY_np = np.asarray(XY)\n","plt.figure(figsize=(7,4))\n","for i in range(20):\n","    plt.plot(t_np[::2], XY_np[i, ::2, 0], alpha=0.5)\n","plt.xlabel(\"t\"); plt.ylabel(\"x(t)\")\n","plt.title(\"Example 5: sample x-trajectories (data generator)\")\n","plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()\n","\n","# -------------------------------------------------------------------\n","# 2) Build increment dataset: input [x,y] -> output Δ[x,y]\n","# -------------------------------------------------------------------\n","def build_increment_dataset(XY):\n","    XY_n   = XY[:, :-1, :]           # (n_traj, N, 2)\n","    XY_np1 = XY[:, 1:, :]            # (n_traj, N, 2)\n","    dXY    = XY_np1 - XY_n           # (n_traj, N, 2)\n","\n","    X_in   = XY_n.reshape(-1, 2)     # (n_traj*N, 2)\n","    dX_out = dXY.reshape(-1, 2)      # (n_traj*N, 2)\n","    return X_in, dX_out\n","\n","X_raw, dX = build_increment_dataset(XY)\n","print(\"\\nIncrement dataset shapes: X_raw =\", X_raw.shape, \", dX =\", dX.shape)\n","\n","# sanity checks\n","dx_direct = XY[:, 1:, :] - XY[:, :-1, :]\n","dX_mat = dX.reshape(XY.shape[0], -1, 2)\n","print(\"\\n[SANITY CHECKS]\")\n","print(\"dt from grid:\", float(t[1]-t[0]), \" cfg.dt:\", cfg.dt)\n","print(\"max|Δ - (next-prev)| =\", float(jnp.max(jnp.abs(dX_mat - dx_direct))))\n","\n","# support check near y=0\n","y_all = np.asarray(X_raw[:, 1])\n","frac_y0 = float(np.mean(np.abs(y_all) < 0.05))\n","print(f\"\\n[SUPPORT CHECK] fraction of samples with |y|<0.05: 
{frac_y0:.4f}\")\n","if frac_y0 < 0.01:\n","    print(\"WARNING: y=0 slice is mostly extrapolation. Consider plotting at y=median(y) instead.\")\n","\n","# -------------------------------------------------------------------\n","# 3) Normalize inputs\n","# -------------------------------------------------------------------\n","class Normalizer:\n","    def __init__(self, mean, std):\n","        self.mean = mean\n","        self.std = std\n","    def __call__(self, x):\n","        return (x - self.mean) / (self.std + 1e-8)\n","\n","def fit_normalizer(X):\n","    return Normalizer(jnp.mean(X, axis=0), jnp.std(X, axis=0))\n","\n","in_norm = fit_normalizer(X_raw)\n","X = in_norm(X_raw)\n","\n","print(\"\\n[NORMALIZATION CHECK]\")\n","print(\"mean ~\", np.asarray(jax.device_get(X.mean(0))), \" std ~\", np.asarray(jax.device_get(X.std(0))))\n","\n","# -------------------------------------------------------------------\n","# 4) Model: drift MLP + diffusion-Cholesky MLP (PSD covariance)\n","# -------------------------------------------------------------------\n","def glorot(k, fan_in, fan_out):\n","    lim = math.sqrt(6.0 / (fan_in + fan_out))\n","    return jax.random.uniform(k, (fan_in, fan_out), minval=-lim, maxval=lim, dtype=jnp.float64)\n","\n","def init_mlp(key, in_dim, out_dim, hidden, depth):\n","    keys = jax.random.split(key, depth+1)\n","    dims = [in_dim] + [hidden]*depth + [out_dim]\n","    params = []\n","    for i in range(len(dims)-1):\n","        params.append({\n","            \"W\": glorot(keys[i], dims[i], dims[i+1]),\n","            \"b\": jnp.zeros((dims[i+1],), dtype=jnp.float64),\n","        })\n","    return params\n","\n","def mlp(params, x):\n","    h = x\n","    for i, layer in enumerate(params):\n","        h = h @ layer[\"W\"] + layer[\"b\"]\n","        if i < len(params)-1:\n","            h = jax.nn.swish(h)\n","    return h\n","\n","def init_model(key, cfg: CFG):\n","    kf, kL = jax.random.split(key, 2)\n","    pf = init_mlp(kf, in_dim=2, 
out_dim=2, hidden=cfg.hidden, depth=cfg.depth)\n","    pL = init_mlp(kL, in_dim=2, out_dim=3, hidden=cfg.hidden, depth=cfg.depth)\n","    return {\"pf\": pf, \"pL\": pL}\n","\n","def drift_and_chol(params, x_norm, cfg: CFG):\n","    f = mlp(params[\"pf\"], x_norm)  # (...,2)\n","\n","    raw = mlp(params[\"pL\"], x_norm)  # (...,3)\n","    l11_raw = raw[..., 0]\n","    l21     = raw[..., 1]\n","    l22_raw = raw[..., 2]\n","\n","    l11 = jax.nn.softplus(l11_raw) + cfg.sigma_floor\n","    l22 = jax.nn.softplus(l22_raw) + cfg.sigma_floor\n","\n","    zeros = jnp.zeros_like(l11)\n","    L = jnp.stack([\n","        jnp.stack([l11, zeros], axis=-1),\n","        jnp.stack([l21, l22], axis=-1),\n","    ], axis=-2)\n","    return f, L  # Σ = L L^T\n","\n","# -------------------------------------------------------------------\n","# 5) Increment NLL loss (multivariate Gaussian), + regularizers (JVP smoothness + whitening)\n","# -------------------------------------------------------------------\n","def l2_tree(p):\n","    return sum([jnp.sum(v**2) for v in jax.tree_util.tree_leaves(p)])\n","\n","def mvn_nll(dX, mean, L_dt):\n","    r = dX - mean\n","    z = jax.vmap(lambda Li, ri: jax.scipy.linalg.solve_triangular(Li, ri, lower=True))(L_dt, r)\n","    quad = jnp.sum(z*z, axis=-1)\n","    logdet = 2.0 * jnp.log(jnp.clip(jnp.diagonal(L_dt, axis1=-2, axis2=-1), 1e-12, None)).sum(axis=-1)\n","    return 0.5 * (quad + logdet)\n","\n","def diffusion_smoothness_penalty(params, xb_norm, cfg: CFG, key):\n","    # Exact JVP penalty: penalize ||J_L(x) v||^2 for random v\n","    v = jax.random.normal(key, xb_norm.shape, dtype=xb_norm.dtype)\n","    v = v / (jnp.linalg.norm(v, axis=-1, keepdims=True) + 1e-12)\n","\n","    def L_fn(x):\n","        return drift_and_chol(params, x, cfg)[1]  # (...,2,2)\n","\n","    _, dL = jax.jvp(L_fn, (xb_norm,), (v,))\n","    return jnp.mean(dL * dL)\n","\n","def diffusion_variance_penalty(params, xb_norm, cfg: CFG):\n","    _, L = 
drift_and_chol(params, xb_norm, cfg)\n","    diag = jnp.diagonal(L, axis1=-2, axis2=-1)  # (B,2)\n","    return jnp.mean(jnp.var(diag, axis=0))\n","\n","def whiten_penalty(dX, mean, L_dt):\n","    # z = L_dt^{-1} (dX - mean) should be ~ N(0, I)\n","    r = dX - mean\n","    z = jax.vmap(lambda Li, ri: jax.scipy.linalg.solve_triangular(Li, ri, lower=True))(L_dt, r)  # (B,2)\n","\n","    z_mean = jnp.mean(z, axis=0)\n","    zc = z - z_mean\n","    B = z.shape[0]\n","    cov = (zc.T @ zc) / jnp.maximum(B - 1, 1)\n","\n","    I = jnp.eye(2, dtype=z.dtype)\n","    return jnp.sum(z_mean**2) + jnp.sum((cov - I)**2)\n","\n","def loss_total(params, xb_norm, dxb, cfg: CFG, key_smooth):\n","    f, L = drift_and_chol(params, xb_norm, cfg)\n","    mean = f * cfg.dt\n","    L_dt = L * math.sqrt(cfg.dt)\n","\n","    nll = jnp.mean(mvn_nll(dxb, mean, L_dt))\n","\n","    wd = cfg.weight_decay * l2_tree(params)  # default 0\n","    smooth = cfg.smooth_weight * diffusion_smoothness_penalty(params, xb_norm, cfg, key_smooth)\n","    varpen = cfg.var_weight * diffusion_variance_penalty(params, xb_norm, cfg)\n","    white = cfg.whiten_weight * whiten_penalty(dxb, mean, L_dt)\n","\n","    return nll + wd + smooth + varpen + white, (nll, smooth, varpen, white)\n","\n","# -------------------------------------------------------------------\n","# 6) Alternating training loop (multiple drift steps, then diffusion)\n","# -------------------------------------------------------------------\n","key_main, k_model = jax.random.split(key_main, 2)\n","params = init_model(k_model, cfg)\n","\n","opt_f = optax.chain(optax.clip_by_global_norm(cfg.grad_clip),\n","                    optax.adamw(cfg.lr_f, weight_decay=0.0))\n","opt_L = optax.chain(optax.clip_by_global_norm(cfg.grad_clip),\n","                    optax.adamw(cfg.lr_L, weight_decay=0.0))\n","\n","opt_state_f = opt_f.init(params[\"pf\"])\n","opt_state_L = opt_L.init(params[\"pL\"])\n","\n","rng_np = np.random.default_rng(0)\n","\n","def 
sample_minibatch(X_norm, dX, bs):\n","    \"\"\"Return a random minibatch of up to bs matching rows of (X_norm, dX), drawn without replacement via rng_np.\"\"\"\n","    N = X_norm.shape[0]\n","    idx = rng_np.choice(N, size=min(bs, N), replace=False)\n","    return X_norm[idx], dX[idx]\n","\n","@jax.jit\n","def step_drift(pf, pL, opt_state_f, xb_norm, dxb, key_smooth):\n","    \"\"\"One optimizer step on the drift params pf; diffusion params pL are frozen via stop_gradient.\"\"\"\n","    def _loss(pf_):\n","        p = {\"pf\": pf_, \"pL\": jax.lax.stop_gradient(pL)}\n","        val, aux = loss_total(p, xb_norm, dxb, cfg, key_smooth)\n","        return val, aux\n","    (val, aux), grads = jax.value_and_grad(_loss, has_aux=True)(pf)\n","    updates, opt_state_f2 = opt_f.update(grads, opt_state_f, pf)\n","    pf2 = optax.apply_updates(pf, updates)\n","    return pf2, opt_state_f2, val, aux\n","\n","@jax.jit\n","def step_diffusion(pf, pL, opt_state_L, xb_norm, dxb, key_smooth):\n","    \"\"\"One optimizer step on the diffusion params pL; drift params pf are frozen via stop_gradient.\"\"\"\n","    def _loss(pL_):\n","        p = {\"pf\": jax.lax.stop_gradient(pf), \"pL\": pL_}\n","        val, aux = loss_total(p, xb_norm, dxb, cfg, key_smooth)\n","        return val, aux\n","    (val, aux), grads = jax.value_and_grad(_loss, has_aux=True)(pL)\n","    updates, opt_state_L2 = opt_L.update(grads, opt_state_L, pL)\n","    pL2 = optax.apply_updates(pL, updates)\n","    return pL2, opt_state_L2, val, aux\n","\n","print_every = 200\n","loss_hist = []\n","\n","if cfg.drift_steps_per_iter < 1 and cfg.diff_steps_per_iter < 1:\n","    raise ValueError(\"Need at least one drift or diffusion step per outer iteration.\")\n","\n","# The full dataset is loop-invariant: convert it to host (NumPy) arrays once, not on every step.\n","X_host = np.asarray(X)\n","dX_host = np.asarray(dX)\n","\n","for step in range(1, cfg.steps + 1):\n","    xb, dxb = sample_minibatch(X_host, dX_host, cfg.batch_size)\n","    xb = jnp.asarray(xb)\n","    dxb = jnp.asarray(dxb)\n","\n","    # multiple drift steps\n","    for _ in range(cfg.drift_steps_per_iter):\n","        key_main, ks = jax.random.split(key_main, 2)\n","        params[\"pf\"], opt_state_f, val1, aux1 = step_drift(params[\"pf\"], params[\"pL\"], opt_state_f, xb, dxb, ks)\n","        # val2/aux2 = most recent loss; the diffusion loop below overwrites them, so they only\n","        # keep the drift value when diff_steps_per_iter == 0 (previously a NameError in that case)\n","        val2, aux2 = val1, aux1\n","\n","    # fewer diffusion steps\n","    for _ in range(cfg.diff_steps_per_iter):\n","        key_main, ks = jax.random.split(key_main, 2)\n","        params[\"pL\"], opt_state_L, val2, aux2 = step_diffusion(params[\"pf\"], params[\"pL\"], opt_state_L, xb, dxb, ks)\n","\n","    loss_hist.append(float(val2))\n","\n","    if 
step % print_every == 0 or step == 1 or step == cfg.steps:\n","        nll, smooth, varpen, white = aux2\n","        print(\n","            f\"step {step:5d}/{cfg.steps} | total = {float(val2):.6e} \"\n","            f\"| nll={float(nll):.6e} | smooth={float(smooth):.3e} | var={float(varpen):.3e} | white={float(white):.3e}\"\n","        )\n","\n","print(\"\\nTraining finished.\")\n","\n","# -------------------------------------------------------------------\n","# 7) Diagnostics + comparison (comparison uses truth ONLY here)\n","# -------------------------------------------------------------------\n","def eval_slice(params, x_vals, y0):\n","    y_vals = jnp.full_like(x_vals, y0)\n","    XY_raw = jnp.stack([x_vals, y_vals], axis=-1)\n","    XY_norm = in_norm(XY_raw)\n","    f_hat, L_hat = drift_and_chol(params, XY_norm, cfg)\n","    Sigma = jnp.einsum(\"...ik,...jk->...ij\", L_hat, L_hat)  # L L^T\n","    sigx = jnp.sqrt(jnp.clip(Sigma[...,0,0], 1e-12, None))\n","    sigy = jnp.sqrt(jnp.clip(Sigma[...,1,1], 1e-12, None))\n","    return XY_raw, f_hat, sigx, sigy\n","\n","y0 = float(np.median(np.asarray(X_raw[:,1])))\n","print(f\"\\nUsing slice y0 = median(y) = {y0:.3f}  (set y0=0.0 if you have enough support there)\")\n","\n","x_vals = jnp.linspace(0.5, 2.0, 200)\n","XY_s, f_s, sigx_s, sigy_s = eval_slice(params, x_vals, y0=y0)\n","\n","x_np = np.asarray(XY_s[:,0])\n","fx_hat = np.asarray(f_s[:,0])\n","fy_hat = np.asarray(f_s[:,1])\n","sx_hat = np.asarray(sigx_s)\n","sy_hat = np.asarray(sigy_s)\n","\n","fx_true = truth[\"a1\"] / x_np\n","fy_true = truth[\"a2\"] * np.ones_like(x_np)\n","sig_true = truth[\"sigma\"] * np.ones_like(x_np)\n","\n","plt.figure(figsize=(7,4))\n","plt.plot(x_np, fx_hat, lw=2, label=r\"$\\hat f_x$\")\n","plt.plot(x_np, fx_true, lw=2, label=r\"true $a_1/x$\")\n","plt.xlabel(\"x\"); plt.ylabel(\"drift in x\")\n","plt.title(f\"Drift-x slice at y={y0:.2f} (comparison only)\")\n","plt.grid(True, alpha=0.3); plt.legend(); plt.tight_layout(); 
plt.show()\n","\n","plt.figure(figsize=(7,4))\n","plt.plot(x_np, fy_hat, lw=2, label=r\"$\\hat f_y$\")\n","plt.plot(x_np, fy_true, lw=2, ls=\"--\", label=r\"true $a_2$\")\n","plt.xlabel(\"x\"); plt.ylabel(\"drift in y\")\n","plt.title(f\"Drift-y slice at y={y0:.2f} (comparison only)\")\n","plt.grid(True, alpha=0.3); plt.legend(); plt.tight_layout(); plt.show()\n","\n","plt.figure(figsize=(7,4))\n","plt.plot(x_np, sx_hat, lw=2, label=r\"$\\hat\\sigma_x$\")\n","plt.plot(x_np, sy_hat, lw=2, label=r\"$\\hat\\sigma_y$\")\n","plt.plot(x_np, sig_true, lw=2, ls=\"--\", label=\"true σ\")\n","plt.xlabel(\"x\"); plt.ylabel(\"diffusion (marginal σ)\")\n","plt.title(f\"Diffusion slice at y={y0:.2f} (comparison only)\")\n","plt.grid(True, alpha=0.3); plt.legend(); plt.tight_layout(); plt.show()\n","\n","plt.figure(figsize=(6,4))\n","plt.plot(np.arange(1, cfg.steps+1), loss_hist, lw=2)\n","plt.xlabel(\"step\"); plt.ylabel(\"loss\")\n","plt.title(\"Training loss (no-prior drift+diffusion) — improved\")\n","plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()\n","\n","globals().update({\"cfg_ex5\": cfg, \"params_ex5\": params, \"in_norm_ex5\": in_norm})\n","print(\"\\nExported: cfg_ex5, params_ex5, in_norm_ex5.\")\n","\n","# === Animations: drift and diffusion vs x as time evolves (2D SDE surrogate) ===\n","# FIX: set y-limits BEFORE animating (blit=True doesn't autoscale after empty init)\n","\n","import jax\n","import jax.numpy as jnp\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from matplotlib.animation import FuncAnimation\n","from IPython.display import HTML, display\n","\n","# ----------------- bind names expected by \"downstream\" style -----------------\n","params_sde = params_ex5\n","in_norm_sde = in_norm_ex5\n","cfg = cfg_ex5\n","\n","# ----------------- helpers ---------------------------------------------------\n","def _sigma_marginals_from_L(L_hat):\n","    Sigma = np.einsum(\"...ik,...jk->...ij\", L_hat, L_hat)\n","    sigx = 
np.sqrt(np.clip(Sigma[..., 0, 0], 1e-12, None))\n","    sigy = np.sqrt(np.clip(Sigma[..., 1, 1], 1e-12, None))\n","    sigxy = Sigma[..., 0, 1]\n","    return sigx, sigy, sigxy\n","\n","def eval_fx_fy_sigmas(params, in_norm, y0, x_min=0.25, x_max=2.5, n_points=200):\n","    x_vals = jnp.linspace(x_min, x_max, n_points)\n","    y_vals = jnp.full_like(x_vals, y0)\n","    XY_raw = jnp.stack([x_vals, y_vals], axis=-1)     # (N,2)\n","    XY_norm = in_norm(XY_raw)\n","\n","    f_hat, L_hat = drift_and_chol(params, XY_norm, cfg)  # f_hat (N,2), L_hat (N,2,2)\n","\n","    x_np = np.asarray(x_vals)\n","    fx_np = np.asarray(f_hat[:, 0])\n","    fy_np = np.asarray(f_hat[:, 1])\n","    L_np = np.asarray(L_hat)\n","\n","    sigx_np, sigy_np, sigxy_np = _sigma_marginals_from_L(L_np)\n","    return x_np, fx_np, fy_np, sigx_np, sigy_np, sigxy_np\n","\n","# ----------------- choose animation frames from the simulated time series -----\n","if (\"t\" in globals()) and (\"XY\" in globals()):\n","    t_array = np.asarray(t)\n","    XY_np = np.asarray(XY)  # (n_traj, Nt, 2)\n","    Nt = t_array.shape[0]\n","    y_med = np.median(XY_np[:, :, 1], axis=0)  # (Nt,)\n","\n","    n_frames = 100\n","    frame_ids = np.linspace(0, Nt - 1, n_frames).astype(int)\n","else:\n","    t_array = np.linspace(0.0, float(getattr(cfg, \"T\", 2.0)), 101)\n","    y_med = np.linspace(-1.0, 1.0, t_array.shape[0])\n","    n_frames = 100\n","    frame_ids = np.linspace(0, len(t_array) - 1, n_frames).astype(int)\n","\n","# x-range for plots (avoid x=0 singularity)\n","x_min, x_max = 0.25, 2.5\n","\n","# Optional truth overlays\n","have_truth = \"truth\" in globals()\n","a1_true = float(truth[\"a1\"]) if have_truth and (\"a1\" in truth) else None\n","a2_true = float(truth[\"a2\"]) if have_truth and (\"a2\" in truth) else None\n","sig_true = float(truth[\"sigma\"]) if have_truth and (\"sigma\" in truth) else 1.0\n","\n","# ----------------- PRECOMPUTE Y-LIMS (critical fix for blit=True) ------------\n","probe_k 
= min(10, len(frame_ids))\n","probe_ids = np.linspace(0, len(frame_ids) - 1, probe_k).astype(int)\n","\n","fx_all = []\n","sx_all = []\n","sy_all = []\n","\n","for kk in probe_ids:\n","    idx = frame_ids[kk]\n","    y0 = float(y_med[idx])\n","    x_vals, fx_hat, _, sigx_hat, sigy_hat, _ = eval_fx_fy_sigmas(\n","        params_sde, in_norm_sde, y0, x_min=x_min, x_max=x_max\n","    )\n","    fx_all.append(fx_hat)\n","    sx_all.append(sigx_hat)\n","    sy_all.append(sigy_hat)\n","\n","fx_all = np.concatenate(fx_all, axis=0)\n","sx_all = np.concatenate(sx_all, axis=0)\n","sy_all = np.concatenate(sy_all, axis=0)\n","\n","# ignore NaNs/Infs if any\n","fx_finite = fx_all[np.isfinite(fx_all)]\n","sx_finite = sx_all[np.isfinite(sx_all)]\n","sy_finite = sy_all[np.isfinite(sy_all)]\n","\n","# fallback if everything is non-finite (shouldn't happen, but keeps plot from crashing)\n","if fx_finite.size == 0:\n","    fx_lo, fx_hi = -1.0, 1.0\n","else:\n","    fx_lo, fx_hi = np.nanmin(fx_finite), np.nanmax(fx_finite)\n","\n","if sx_finite.size == 0 or sy_finite.size == 0:\n","    s_lo, s_hi = 0.0, 2.0\n","else:\n","    s_lo = min(np.nanmin(sx_finite), np.nanmin(sy_finite))\n","    s_hi = max(np.nanmax(sx_finite), np.nanmax(sy_finite))\n","\n","# add padding\n","fx_pad = 0.10 * (fx_hi - fx_lo + 1e-9)\n","s_pad  = 0.10 * (s_hi - s_lo + 1e-9)\n","\n","fx_ylim = (fx_lo - fx_pad, fx_hi + fx_pad)\n","s_ylim  = (s_lo - s_pad,  s_hi + s_pad)\n","\n","# ================== 1) Drift-x animation =====================================\n","fig_drift, ax_drift = plt.subplots(figsize=(6, 4))\n","line_fx, = ax_drift.plot([], [], lw=2, label=r\"learned $\\hat f_x(x,y)$\")\n","\n","if a1_true is not None:\n","    gt_fx_line, = ax_drift.plot([], [], lw=2, ls=\"--\", label=r\"true $a_1/x$\")\n","\n","ax_drift.set_xlim(x_min, x_max)\n","ax_drift.set_ylim(*fx_ylim)  # <-- FIX\n","ax_drift.set_xlabel(\"x\")\n","ax_drift.set_ylabel(\"drift-x\")\n","title_drift = 
ax_drift.set_title(\"\")\n","ax_drift.grid(True, alpha=0.3)\n","ax_drift.legend(loc=\"upper right\")\n","\n","def init_drift():\n","    line_fx.set_data([], [])\n","    if a1_true is not None:\n","        gt_fx_line.set_data([], [])\n","    title_drift.set_text(\"\")\n","    return (line_fx, gt_fx_line, title_drift) if a1_true is not None else (line_fx, title_drift)\n","\n","def update_drift(frame_k):\n","    idx = frame_ids[frame_k]\n","    t0 = float(t_array[idx])\n","    y0 = float(y_med[idx])\n","\n","    x_vals, fx_hat, _, _, _, _ = eval_fx_fy_sigmas(params_sde, in_norm_sde, y0, x_min=x_min, x_max=x_max)\n","    line_fx.set_data(x_vals, fx_hat)\n","\n","    if a1_true is not None:\n","        gt_fx_line.set_data(x_vals, a1_true / x_vals)\n","\n","    title_drift.set_text(rf\"$\\hat f_x(x,y)$ slice at $y={y0:.3f}$ (frame $t\\approx{t0:.2f}$)\")\n","    return (line_fx, gt_fx_line, title_drift) if a1_true is not None else (line_fx, title_drift)\n","\n","anim_drift = FuncAnimation(\n","    fig_drift,\n","    update_drift,\n","    init_func=init_drift,\n","    frames=len(frame_ids),\n","    interval=80,\n","    blit=True,\n",")\n","plt.close(fig_drift)\n","display(HTML(anim_drift.to_jshtml()))\n","\n","# ================== 2) Diffusion marginals animation =========================\n","fig_diff, ax_diff = plt.subplots(figsize=(6, 4))\n","line_sx, = ax_diff.plot([], [], lw=2, label=r\"learned $\\hat\\sigma_x(x,y)$\")\n","line_sy, = ax_diff.plot([], [], lw=2, label=r\"learned $\\hat\\sigma_y(x,y)$\")\n","ax_diff.axhline(sig_true, color=\"k\", linestyle=\"--\", label=rf\"true $\\sigma={sig_true}$\")\n","\n","ax_diff.set_xlim(x_min, x_max)\n","ax_diff.set_ylim(*s_ylim)  # <-- FIX\n","ax_diff.set_xlabel(\"x\")\n","ax_diff.set_ylabel(\"diffusion (marginal)\")\n","title_diff = ax_diff.set_title(\"\")\n","ax_diff.grid(True, alpha=0.3)\n","ax_diff.legend(loc=\"upper right\")\n","\n","def init_diff():\n","    line_sx.set_data([], [])\n","    line_sy.set_data([], [])\n","    
title_diff.set_text(\"\")\n","    return line_sx, line_sy, title_diff\n","\n","def update_diff(frame_k):\n","    idx = frame_ids[frame_k]\n","    t0 = float(t_array[idx])\n","    y0 = float(y_med[idx])\n","\n","    x_vals, _, _, sigx_hat, sigy_hat, _ = eval_fx_fy_sigmas(params_sde, in_norm_sde, y0, x_min=x_min, x_max=x_max)\n","    line_sx.set_data(x_vals, sigx_hat)\n","    line_sy.set_data(x_vals, sigy_hat)\n","    title_diff.set_text(rf\"$\\hat\\sigma(x,y)$ marginals at $y={y0:.3f}$ (frame $t\\approx{t0:.2f}$)\")\n","    return line_sx, line_sy, title_diff\n","\n","anim_diff = FuncAnimation(\n","    fig_diff,\n","    update_diff,\n","    init_func=init_diff,\n","    frames=len(frame_ids),\n","    interval=80,\n","    blit=True,\n",")\n","plt.close(fig_diff)\n","display(HTML(anim_diff.to_jshtml()))"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"1RmFOZNtEczey7tPZDYupbzH95TuA_9xB"},"id":"Ff_tCtCMHDVG","outputId":"49de3127-144c-442d-92dc-4d6df7a2268b","executionInfo":{"status":"ok","timestamp":1769526325078,"user_tz":300,"elapsed":215613,"user":{"displayName":"Dacha Thurbur","userId":"16530582923668080051"}}},"execution_count":null,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]},{"cell_type":"code","source":["# ============================================================\n","# EX5 (2D) — Generator Nets + Losses S1–S7 + Training + Span Check\n","# (COPY/PASTE THIS WHOLE CELL)\n","#\n","# This cell is robust:\n","#   - Removes the need for any \"assert\" placeholder lines.\n","#   - Auto-creates TX_gen from (t, XY) if missing.\n","#   - Auto-defines mu_fn/sig_fn from drift_and_chol + params_ex5 + in_norm_ex5 if not already defined.\n","#   - Defines surrogate_f_sigma in the style some downstream code expects.\n","# ============================================================\n","\n","import math\n","from 
dataclasses import dataclass\n","import numpy as np\n","import jax\n","import jax.numpy as jnp\n","import optax\n","\n","jax.config.update(\"jax_enable_x64\", True)\n","\n","# ---------------------------\n","# 0) Normalizer tools\n","# ---------------------------\n","class Normalizer:\n","    def __init__(self, mean, std, eps=1e-8):\n","        self.mean = jnp.asarray(mean, dtype=jnp.float64)\n","        self.std  = jnp.asarray(std,  dtype=jnp.float64)\n","        self.eps  = float(eps)\n","    def __call__(self, x):\n","        x = jnp.asarray(x, dtype=jnp.float64)\n","        return (x - self.mean) / (self.std + self.eps)\n","\n","def fit_normalizer(X):\n","    X = jnp.asarray(X, dtype=jnp.float64)\n","    return Normalizer(jnp.mean(X, axis=0, keepdims=True), jnp.std(X, axis=0, keepdims=True))\n","\n","# ---------------------------\n","# 1) Simple MLP helpers (generator nets)\n","# ---------------------------\n","def init_mlp_params(key, sizes, scale=1e-1, dtype=jnp.float64):\n","    keys = jax.random.split(key, len(sizes) - 1)\n","    params = []\n","    for k, (din, dout) in zip(keys, zip(sizes[:-1], sizes[1:])):\n","        W = scale * jax.random.normal(k, (din, dout), dtype=dtype)\n","        b = jnp.zeros((dout,), dtype=dtype)\n","        params.append((W, b))\n","    return params\n","\n","def mlp_forward(params, x, activation=\"tanh\"):\n","    h = x\n","    for (W, b) in params[:-1]:\n","        h = h @ W + b\n","        if activation == \"tanh\":\n","            h = jnp.tanh(h)\n","        elif activation == \"relu\":\n","            h = jnp.maximum(h, 0)\n","        elif activation == \"gelu\":\n","            h = 0.5 * h * (1.0 + jax.lax.erf(h / jnp.sqrt(2.0)))\n","        else:\n","            raise ValueError(f\"Unknown activation: {activation}\")\n","    W, b = params[-1]\n","    return h @ W + b\n","\n","# ---------------------------\n","# 2) Ensure key_main exists\n","# ---------------------------\n","key_main = globals().get(\"key_main\", 
jax.random.PRNGKey(0))\n","\n","# ---------------------------\n","# 3) Ensure cfg_ex5/params_ex5/in_norm_ex5 exist (trained surrogate)\n","# ---------------------------\n","if \"cfg_ex5\" not in globals() or \"params_ex5\" not in globals() or \"in_norm_ex5\" not in globals():\n","    raise NameError(\n","        \"Missing cfg_ex5 / params_ex5 / in_norm_ex5 from your surrogate training.\\n\"\n","        \"Run your surrogate training cell first (the drift+chol MLP training).\"\n","    )\n","\n","cfg_ex5 = globals()[\"cfg_ex5\"]\n","params_ex5 = globals()[\"params_ex5\"]\n","in_norm_ex5 = globals()[\"in_norm_ex5\"]\n","\n","# Your surrogate forward (must exist from earlier cell)\n","if \"drift_and_chol\" not in globals():\n","    raise NameError(\"Missing drift_and_chol(params, x_norm, cfg). It should be defined in your surrogate cell.\")\n","\n","drift_and_chol = globals()[\"drift_and_chol\"]\n","\n","# ---------------------------\n","# 4) Ensure TX_gen exists (t,x,y triples for sampling)\n","# ---------------------------\n","TX_gen = globals().get(\"TX_gen\", None)\n","\n","def _build_TX_gen_from_sim():\n","    if (\"t\" in globals()) and (\"XY\" in globals()):\n","        t_arr = jnp.asarray(globals()[\"t\"], dtype=jnp.float64)                 # (Nt,)\n","        XY    = jnp.asarray(globals()[\"XY\"], dtype=jnp.float64)                # (n_traj, Nt, 2)\n","        n_traj, Nt, _ = XY.shape\n","        t_tile = jnp.tile(t_arr[None, :], (n_traj, 1))                         # (n_traj, Nt)\n","        TX = jnp.stack([t_tile, XY[..., 0], XY[..., 1]], axis=-1)              # (n_traj, Nt, 3)\n","        TX = TX.reshape(-1, 3)                                                 # (n_traj*Nt, 3)\n","        return TX\n","    return None\n","\n","if TX_gen is None:\n","    TX_gen = _build_TX_gen_from_sim()\n","    if TX_gen is None:\n","        print(\"[warn] TX_gen not found and (t,XY) not found. 
Will use uniform-box sampler only.\")\n","    else:\n","        print(\"[ok] Built TX_gen from (t, XY):\", TX_gen.shape)\n","    globals()[\"TX_gen\"] = TX_gen\n","\n","# ---------------------------\n","# 5) Build normalizers for generator nets: t_norm_gen, tx_norm_gen, txy_norm_gen\n","# ---------------------------\n","def _ensure_gen_normalizers(TX_gen):\n","    \"\"\"Create (or reuse) the generator-net input normalizers as module globals.\n","\n","    Binds t_norm_gen (for the t column, (N,1)), tx_norm_gen (for (t,x,y), (N,3)) and\n","    txy_norm_gen (alias of tx_norm_gen):\n","      * callable t_norm_gen and tx_norm_gen already in globals -> reused as-is\n","        (NOTE: a re-run therefore never refits them, even if TX_gen changed);\n","      * TX_gen is None or has no finite rows -> identity maps;\n","      * otherwise -> mean/std normalizers fitted on the finite rows of TX_gen.\n","    \"\"\"\n","    global t_norm_gen, tx_norm_gen, txy_norm_gen\n","\n","    if callable(globals().get(\"t_norm_gen\", None)) and callable(globals().get(\"tx_norm_gen\", None)):\n","        # t_norm_gen / tx_norm_gen are already bound; only fill in the alias if it is missing\n","        txy_norm_gen = globals().get(\"txy_norm_gen\", tx_norm_gen)\n","        return\n","\n","    def _identity(Z):\n","        return jnp.asarray(Z, dtype=jnp.float64)\n","\n","    def _use_identity(msg):\n","        # shared fallback when there is no usable data to fit normalizers on\n","        global t_norm_gen, tx_norm_gen, txy_norm_gen\n","        t_norm_gen = tx_norm_gen = txy_norm_gen = _identity\n","        print(msg)\n","\n","    if TX_gen is None:\n","        _use_identity(\"[warn] Using identity generator normalizers (no TX_gen).\")\n","        return\n","\n","    TX_np = np.asarray(jax.device_get(jnp.asarray(TX_gen, dtype=jnp.float64)))\n","    TX_np = TX_np[np.isfinite(TX_np).all(axis=1)]\n","    if TX_np.shape[0] == 0:\n","        _use_identity(\"[warn] TX_gen had no finite rows; using identity normalizers.\")\n","        return\n","\n","    t_np   = 
TX_np[:, [0]]          # (N,1)\n","    txy_np = TX_np[:, :3]           # (N,3)\n","\n","    t_norm_gen  = fit_normalizer(t_np)\n","    tx_norm_gen = fit_normalizer(txy_np)\n","    txy_norm_gen = tx_norm_gen      # IMPORTANT alias for S6\n","\n","    # best-effort diagnostics only: a printing failure must not abort the cell\n","    try:\n","        print(\"Gen t_norm mean/std:\", np.asarray(jax.device_get(t_norm_gen.mean)).ravel(),\n","              np.asarray(jax.device_get(t_norm_gen.std)).ravel())\n","        print(\"Gen tx_norm mean/std:\", np.asarray(jax.device_get(tx_norm_gen.mean)).ravel(),\n","              np.asarray(jax.device_get(tx_norm_gen.std)).ravel())\n","    except Exception:\n","        pass\n","\n","_ensure_gen_normalizers(TX_gen)\n","\n","# ---------------------------\n","# 6) Auto-wire mu_fn, sig_fn from the trained surrogate (NO need for surrogate_f_sigma)\n","# ---------------------------\n","# Your surrogate is time-homogeneous: drift_and_chol expects x_norm=(x,y) only; t is ignored.\n","def _surrogate_f_and_L(x, y):\n","    \"\"\"Evaluate the trained surrogate at raw (x, y); returns (f_hat (...,2), L (...,2,2)).\"\"\"\n","    x = jnp.asarray(x, dtype=jnp.float64)\n","    y = jnp.asarray(y, dtype=jnp.float64)\n","    XY_raw = jnp.stack([x, y], axis=-1)                  # (...,2)\n","    XY_norm = in_norm_ex5(XY_raw)                        # (...,2)\n","    f_hat, L = drift_and_chol(params_ex5, XY_norm, cfg_ex5)\n","    return jnp.asarray(f_hat, dtype=jnp.float64), jnp.asarray(L, dtype=jnp.float64)\n","\n","def mu_fn(t, x, y):\n","    \"\"\"Learned drift at (x, y), shape (...,2); t is accepted for API symmetry but ignored.\"\"\"\n","    return _surrogate_f_and_L(x, y)[0]\n","\n","def sig_fn(t, x, y):\n","    \"\"\"Learned Cholesky factor L at (x, y), shape (...,2,2), used as σ (Σ = L L^T); t is ignored.\"\"\"\n","    return _surrogate_f_and_L(x, y)[1]\n","\n","globals()[\"mu_fn\"] = mu_fn\n","globals()[\"sig_fn\"] = sig_fn\n","\n","# Optional 
compatibility wrapper some older code expects:\n","def surrogate_f_sigma(params_sde, in_norm_sde, t, x, y, return_L=False):\n","    # params_sde/in_norm_sde are ignored here; we use params_ex5/in_norm_ex5 by design.\n","    f = mu_fn(t, x, y)\n","    L = sig_fn(t, x, y)\n","    Sigma = jnp.einsum(\"...ik,...jk->...ij\", L, L)\n","    if return_L:\n","        return f, Sigma, L\n","    return f, Sigma\n","\n","globals()[\"surrogate_f_sigma\"] = surrogate_f_sigma\n","globals()[\"params_sde\"] = params_ex5\n","globals()[\"in_norm_sde\"] = in_norm_ex5\n","\n","# ---------------------------\n","# 7) Generator network config + init\n","# ---------------------------\n","@dataclass\n","class GenConfig:\n","    n_generators: int = 3\n","    hidden_tau: int = 32\n","    hidden_xi: int = 64\n","    hidden_beta: int = 64\n","    activation: str = \"tanh\"\n","\n","gen_cfg = globals().get(\"gen_cfg\", GenConfig())\n","gen_cfg = GenConfig(**{**gen_cfg.__dict__, \"n_generators\": int(getattr(gen_cfg, \"n_generators\", 3))})\n","m = int(gen_cfg.n_generators)\n","\n","def tau_forward(params, t_norm, activation=\"tanh\"):\n","    return mlp_forward(params, t_norm, activation=activation)[..., 0:1]\n","\n","def xi_forward(params, txy_norm, activation=\"tanh\"):\n","    return mlp_forward(params, txy_norm, activation=activation)[..., 0:2]   # xi=(xi_x,xi_y)\n","\n","def beta_forward(params, txy_norm, activation=\"tanh\"):\n","    return mlp_forward(params, txy_norm, activation=activation)[..., 0:1]\n","\n","def init_generator_params(key, gen_cfg: GenConfig):\n","    keys = jax.random.split(key, 3 * gen_cfg.n_generators)\n","    params_tau, params_xi, params_beta = [], [], []\n","    for i in range(gen_cfg.n_generators):\n","        k_tau, k_xi, k_beta = keys[3*i], keys[3*i+1], keys[3*i+2]\n","        params_tau.append(init_mlp_params(k_tau,  [1, gen_cfg.hidden_tau,  gen_cfg.hidden_tau,  1]))\n","        params_xi.append(init_mlp_params(k_xi,   [3, gen_cfg.hidden_xi,   
gen_cfg.hidden_xi,   2]))\n","        params_beta.append(init_mlp_params(k_beta,[3, gen_cfg.hidden_beta, gen_cfg.hidden_beta, 1]))\n","    return {\"tau\": params_tau, \"xi\": params_xi, \"beta\": params_beta}\n","\n","params_gen = globals().get(\"params_gen\", None)\n","if params_gen is None:\n","    key_main, key_gen = jax.random.split(key_main, 2)\n","    params_gen = init_generator_params(key_gen, gen_cfg)\n","    globals()[\"params_gen\"] = params_gen\n","\n","# ---------------------------\n","# 8) Evaluators\n","# ---------------------------\n","def eval_generators(params_gen, t, x, y=None, u=None, activation=None, return_phi=False):\n","    if activation is None:\n","        activation = gen_cfg.activation\n","    t_arr = jnp.asarray(t, dtype=jnp.float64)\n","    x_arr = jnp.asarray(x, dtype=jnp.float64)\n","    y_arr = jnp.zeros_like(x_arr) if (y is None) else jnp.asarray(y, dtype=jnp.float64)\n","    t_b, x_b, y_b = jnp.broadcast_arrays(t_arr, x_arr, y_arr)\n","    B_shape = t_b.shape\n","\n","    t_flat  = t_b.reshape(-1, 1)\n","    x_flat  = x_b.reshape(-1, 1)\n","    y_flat  = y_b.reshape(-1, 1)\n","    txy_flat = jnp.concatenate([t_flat, x_flat, y_flat], axis=1)  # (B,3)\n","\n","    t_norm   = t_norm_gen(t_flat)\n","    txy_norm = tx_norm_gen(txy_flat)\n","\n","    tau_list, xi_list, beta_list = [], [], []\n","    for p_tau, p_xi, p_beta in zip(params_gen[\"tau\"], params_gen[\"xi\"], params_gen[\"beta\"]):\n","        tau_flat = tau_forward(p_tau, t_norm, activation=activation)      # (B,1)\n","        xi_flat  = xi_forward(p_xi,  txy_norm, activation=activation)     # (B,2)\n","        be_flat  = beta_forward(p_beta, txy_norm, activation=activation)  # (B,1)\n","        tau_list.append(tau_flat.reshape(B_shape))\n","        xi_list.append(xi_flat.reshape(B_shape + (2,)))\n","        beta_list.append(be_flat.reshape(B_shape))\n","\n","    tau_vals  = jnp.stack(tau_list, axis=0)   # (m, ...)\n","    xi_vals   = jnp.stack(xi_list, axis=0)    # (m, ..., 
2)\n","    beta_vals = jnp.stack(beta_list, axis=0)  # (m, ...)\n","\n","    if return_phi:\n","        if u is None:\n","            raise ValueError(\"return_phi=True requires u.\")\n","        u_arr = jnp.asarray(u, dtype=jnp.float64)\n","        u_b = jnp.broadcast_to(u_arr, B_shape)\n","        phi_vals = beta_vals * u_b[None, ...]\n","        return tau_vals, xi_vals, beta_vals, phi_vals\n","    return tau_vals, xi_vals, beta_vals\n","\n","eval_generators_jit = jax.jit(eval_generators, static_argnames=(\"activation\", \"return_phi\"))\n","\n","def eval_generators_tau_xi(params_gen, t, x, y=None, *, activation=\"tanh\", normalize_txy=None):\n","    t = jnp.asarray(t, dtype=jnp.float64)\n","    x = jnp.asarray(x, dtype=jnp.float64)\n","    y = jnp.zeros_like(x) if (y is None) else jnp.asarray(y, dtype=jnp.float64)\n","    if t.ndim != 1 or x.ndim != 1 or y.ndim != 1:\n","        raise ValueError(\"eval_generators_tau_xi expects 1D arrays (B,) for t,x,y\")\n","\n","    if normalize_txy is None:\n","        t_raw, x_raw, y_raw = t, x, y\n","    else:\n","        t_raw, x_raw, y_raw = normalize_txy(t, x, y)\n","        t_raw = jnp.asarray(t_raw, dtype=jnp.float64)\n","        x_raw = jnp.asarray(x_raw, dtype=jnp.float64)\n","        y_raw = jnp.asarray(y_raw, dtype=jnp.float64)\n","\n","    t_col   = t_raw.reshape(-1, 1)\n","    x_col   = x_raw.reshape(-1, 1)\n","    y_col   = y_raw.reshape(-1, 1)\n","    txy_col = jnp.concatenate([t_col, x_col, y_col], axis=1)  # (B,3)\n","\n","    tN   = t_norm_gen(t_col)\n","    txyN = tx_norm_gen(txy_col)\n","\n","    taus, xis = [], []\n","    for p_tau, p_xi in zip(params_gen[\"tau\"], params_gen[\"xi\"]):\n","        tau = tau_forward(p_tau, tN, activation=activation).reshape(-1)  # (B,)\n","        xi  = xi_forward(p_xi,  txyN, activation=activation)             # (B,2)\n","        taus.append(tau)\n","        xis.append(xi)\n","\n","    tau_all = jnp.stack(taus, axis=0)  # (m,B)\n","    xi_all  = jnp.stack(xis,  axis=0)  
# (m,B,2)\n","    return tau_all, xi_all\n","\n","eval_generators_tau_xi_jit = jax.jit(eval_generators_tau_xi, static_argnames=(\"activation\", \"normalize_txy\"))\n","\n","print(f\"[OK] Generator nets ready: m={m}, xi is 2D, normalizers set (t_norm_gen, tx_norm_gen, txy_norm_gen).\")\n","\n","# ============================================================\n","# 9) Losses S1–S7 (2D-correct)\n","# ============================================================\n","\n","def _ordered_pair_indices(n: int):\n","    idx = jnp.arange(n, dtype=jnp.int32)\n","    ii  = jnp.repeat(idx, repeats=n-1)\n","    base = jnp.arange(n - 1, dtype=jnp.int32)\n","    i_col = idx[:, None]\n","    jj_mat = base + (base >= i_col).astype(jnp.int32)\n","    jj = jj_mat.reshape(-1)\n","    return ii, jj\n","\n","def make_s1_lie_loss(n_generators: int, rcond: float = 1e-6):\n","    idx_i, idx_j = _ordered_pair_indices(n_generators)\n","    reg = jnp.asarray(rcond, dtype=jnp.float64) ** 2\n","\n","    def _tau_val_and_dt(params_tau_i, t_scalar):\n","        def tau_scalar(tt):\n","            t_arr = jnp.asarray([[tt]], dtype=jnp.float64)\n","            tN = t_norm_gen(t_arr)\n","            out = tau_forward(params_tau_i, tN, activation=gen_cfg.activation)\n","            return out[0, 0]\n","        return tau_scalar(t_scalar), jax.grad(tau_scalar)(t_scalar)\n","\n","    def _xi_val_and_derivs(params_xi_i, t_scalar, x_scalar, y_scalar):\n","        def xi_vec(tt, xx, yy):\n","            txy = jnp.asarray([[tt, xx, yy]], dtype=jnp.float64)\n","            txyN = tx_norm_gen(txy)\n","            out = xi_forward(params_xi_i, txyN, activation=gen_cfg.activation)  # (1,2)\n","            return out[0]  # (2,)\n","        xi_val = xi_vec(t_scalar, x_scalar, y_scalar)\n","        xi_t = jnp.stack([jax.grad(lambda tt: xi_vec(tt, x_scalar, y_scalar)[c])(t_scalar) for c in (0,1)], axis=0)\n","        xi_x = jnp.stack([jax.grad(lambda xx: xi_vec(t_scalar, xx, y_scalar)[c])(x_scalar) for c in (0,1)], 
axis=0)\n","        xi_y = jnp.stack([jax.grad(lambda yy: xi_vec(t_scalar, x_scalar, yy)[c])(y_scalar) for c in (0,1)], axis=0)\n","        return xi_val, xi_t, xi_x, xi_y\n","\n","    def _fields_and_derivs_at_point(params_gen, t_scalar, x_scalar, y_scalar):\n","        tau_list, tau_t_list = [], []\n","        xi_list, xi_t_list, xi_x_list, xi_y_list = [], [], [], []\n","        for p_tau, p_xi in zip(params_gen[\"tau\"], params_gen[\"xi\"]):\n","            tau_i, tau_t_i = _tau_val_and_dt(p_tau, t_scalar)\n","            xi_i, xi_t_i, xi_x_i, xi_y_i = _xi_val_and_derivs(p_xi, t_scalar, x_scalar, y_scalar)\n","            tau_list.append(tau_i); tau_t_list.append(tau_t_i)\n","            xi_list.append(xi_i);   xi_t_list.append(xi_t_i)\n","            xi_x_list.append(xi_x_i); xi_y_list.append(xi_y_i)\n","        tau   = jnp.stack(tau_list, axis=0)      # (m,)\n","        tau_t = jnp.stack(tau_t_list, axis=0)    # (m,)\n","        xi    = jnp.stack(xi_list, axis=0)       # (m,2)\n","        xi_t  = jnp.stack(xi_t_list, axis=0)     # (m,2)\n","        xi_x  = jnp.stack(xi_x_list, axis=0)     # (m,2)\n","        xi_y  = jnp.stack(xi_y_list, axis=0)     # (m,2)\n","        return tau, xi, tau_t, xi_t, xi_x, xi_y\n","\n","    def _point_err_and_C(tau, xi, tau_t, xi_t, xi_x, xi_y):\n","        xi_xc, xi_yc = xi[:,0], xi[:,1]\n","        V = jnp.stack([tau, xi_xc, xi_yc], axis=0)  # (3,m)\n","\n","        tau_i, tau_j = tau[idx_i], tau[idx_j]\n","        tau_t_i, tau_t_j = tau_t[idx_i], tau_t[idx_j]\n","\n","        xi_i, xi_j = xi[idx_i], xi[idx_j]\n","        xi_t_i, xi_t_j = xi_t[idx_i], xi_t[idx_j]\n","        xi_x_i, xi_x_j = xi_x[idx_i], xi_x[idx_j]\n","        xi_y_i, xi_y_j = xi_y[idx_i], xi_y[idx_j]\n","\n","        a = tau_i * tau_t_j - tau_j * tau_t_i\n","\n","        b = (tau_i * xi_t_j[:,0] + xi_i[:,0]*xi_x_j[:,0] + xi_i[:,1]*xi_y_j[:,0]\n","             -tau_j * xi_t_i[:,0] - xi_j[:,0]*xi_x_i[:,0] - xi_j[:,1]*xi_y_i[:,0])\n","\n","        c = (tau_i * 
xi_t_j[:,1] + xi_i[:,0]*xi_x_j[:,1] + xi_i[:,1]*xi_y_j[:,1]\n","             -tau_j * xi_t_i[:,1] - xi_j[:,0]*xi_x_i[:,1] - xi_j[:,1]*xi_y_i[:,1])\n","\n","        B = jnp.stack([a,b,c], axis=0)  # (3,K)\n","\n","        G = V @ V.T\n","        G_reg = G + reg * jnp.eye(3, dtype=G.dtype)\n","        X = jnp.linalg.solve(G_reg, B)   # (3,K)\n","        C = V.T @ X                      # (m,K)\n","        P_B = V @ C                      # (3,K)\n","        E = B - P_B\n","        err = jnp.sum(jnp.abs(E))\n","        return err, C\n","\n","    def _loss_impl(params_gen, txy_batch):\n","        def eval_at_z(z):\n","            return _fields_and_derivs_at_point(params_gen, z[0], z[1], z[2])\n","        taus, xis, tau_ts, xi_ts, xi_xs, xi_ys = jax.vmap(eval_at_z)(txy_batch)\n","        errs, Cs = jax.vmap(_point_err_and_C)(taus, xis, tau_ts, xi_ts, xi_xs, xi_ys)\n","        error_sum = jnp.sum(errs)\n","        var_sum = jnp.sum(jnp.var(Cs, axis=0))\n","        return error_sum + var_sum, {\"error_sum\": error_sum, \"var_sum\": var_sum}\n","\n","    return jax.jit(_loss_impl)\n","\n","def make_s2_jacobi_loss_nested(n_generators: int):\n","    triples = [(i,j,k) for i in range(n_generators) for j in range(i+1,n_generators) for k in range(j+1,n_generators)]\n","    if not triples:\n","        def _zero(params_gen, txy_batch):\n","            return jnp.array(0.0, dtype=jnp.float64), {\"per_point\": jnp.zeros((txy_batch.shape[0],), dtype=jnp.float64), \"num_triples\": 0}\n","        return jax.jit(_zero)\n","\n","    tri_i = jnp.array([t[0] for t in triples], dtype=jnp.int32)\n","    tri_j = jnp.array([t[1] for t in triples], dtype=jnp.int32)\n","    tri_k = jnp.array([t[2] for t in triples], dtype=jnp.int32)\n","    perms6 = jnp.array([[0,1,2],[0,2,1],[1,0,2],[1,2,0],[2,0,1],[2,1,0]], dtype=jnp.int32)\n","\n","    def _fields_jac_hess(params_gen, z):\n","        F_list, J_list, H_list = [], [], []\n","        for p_tau, p_xi in zip(params_gen[\"tau\"], 
params_gen[\"xi\"]):\n","            def Xi(zz):\n","                tt, xx, yy = zz[0], zz[1], zz[2]\n","                t_arr  = jnp.asarray([[tt]], dtype=jnp.float64)\n","                txy_arr= jnp.asarray([[tt,xx,yy]], dtype=jnp.float64)\n","                tN  = t_norm_gen(t_arr)\n","                txyN= tx_norm_gen(txy_arr)\n","                tau = tau_forward(p_tau, tN, activation=gen_cfg.activation)[0,0]\n","                xi  = xi_forward(p_xi, txyN, activation=gen_cfg.activation)[0]  # (2,)\n","                return jnp.array([tau, xi[0], xi[1]], dtype=jnp.float64)\n","            Fi = Xi(z)\n","            Ji = jax.jacobian(Xi)(z)\n","            Hi = jax.jacobian(lambda zz: jax.jacobian(Xi)(zz))(z)\n","            F_list.append(Fi); J_list.append(Ji); H_list.append(Hi)\n","        return jnp.stack(F_list,0), jnp.stack(J_list,0), jnp.stack(H_list,0)\n","\n","    def _bracket_val(F,J,p,q):\n","        return (J[q] @ F[p]) - (J[p] @ F[q])\n","\n","    def _dir_along(F,J,H,r,p,q):\n","        Jr,Jp,Jq = J[r],J[p],J[q]\n","        Hp,Hq = H[p],H[q]\n","        fr,fp,fq = F[r],F[p],F[q]\n","        t1 = Jq @ (Jp @ fr)\n","        t2 = ((Hq * fr[None,None,:]).sum(axis=2)) @ fp\n","        t3 = Jp @ (Jq @ fr)\n","        t4 = ((Hp * fr[None,None,:]).sum(axis=2)) @ fq\n","        return t1 + t2 - t3 - t4\n","\n","    def _double_bracket(F,J,H,r,p,q):\n","        inner = _bracket_val(F,J,p,q)\n","        return _dir_along(F,J,H,r,p,q) - (J[r] @ inner)\n","\n","    def _jacobi_one_order(F,J,H,u,v,w):\n","        return _double_bracket(F,J,H,u,v,w) + _double_bracket(F,J,H,w,u,v) + _double_bracket(F,J,H,v,w,u)\n","\n","    def _triple_sum_over_6(F,J,H,i,j,k):\n","        inds = jnp.array([i,j,k], dtype=jnp.int32)\n","        def _one_perm(p):\n","            u,v,w = inds[p[0]], inds[p[1]], inds[p[2]]\n","            r = _jacobi_one_order(F,J,H,u,v,w)\n","            return jnp.sum(jnp.abs(r))\n","        return jnp.sum(jax.vmap(_one_perm)(perms6))\n","\n","    
def _point_loss(params_gen, z):\n","        F,J,H = _fields_jac_hess(params_gen, z)\n","        per_tr = jax.vmap(lambda a,b,c: _triple_sum_over_6(F,J,H,a,b,c))(tri_i, tri_j, tri_k)\n","        return jnp.sum(per_tr)\n","\n","    _pl = jax.jit(_point_loss)\n","\n","    def _loss_impl(params_gen, txy_batch):\n","        per_point = jax.vmap(lambda z: _pl(params_gen, z))(txy_batch)\n","        return jnp.sum(per_point), {\"per_point\": per_point, \"num_triples\": int(tri_i.shape[0])}\n","\n","    return jax.jit(_loss_impl)\n","\n","def make_s3_skewsym_loss(n_generators: int):\n","    pairs = [(i,j) for i in range(n_generators) for j in range(i+1,n_generators)]\n","    if not pairs:\n","        def _zero(params_gen, txy_batch):\n","            return jnp.array(0.0, dtype=jnp.float64), {\"per_point\": jnp.zeros((txy_batch.shape[0],), dtype=jnp.float64), \"num_pairs\": 0}\n","        return jax.jit(_zero)\n","\n","    pi = jnp.array([p[0] for p in pairs], dtype=jnp.int32)\n","    pj = jnp.array([p[1] for p in pairs], dtype=jnp.int32)\n","\n","    def _fields_and_jac(params_gen, z):\n","        F_list, J_list = [], []\n","        for p_tau, p_xi in zip(params_gen[\"tau\"], params_gen[\"xi\"]):\n","            def Xi(zz):\n","                tt,xx,yy = zz[0], zz[1], zz[2]\n","                t_arr = jnp.asarray([[tt]], dtype=jnp.float64)\n","                txy_arr= jnp.asarray([[tt,xx,yy]], dtype=jnp.float64)\n","                tN  = t_norm_gen(t_arr)\n","                txyN= tx_norm_gen(txy_arr)\n","                tau = tau_forward(p_tau, tN, activation=gen_cfg.activation)[0,0]\n","                xi  = xi_forward(p_xi, txyN, activation=gen_cfg.activation)[0]\n","                return jnp.array([tau, xi[0], xi[1]], dtype=jnp.float64)\n","            Fi = Xi(z)\n","            Ji = jax.jacobian(Xi)(z)\n","            F_list.append(Fi); J_list.append(Ji)\n","        return jnp.stack(F_list,0), jnp.stack(J_list,0)\n","\n","    def _bracket(F,J,p,q):\n","        return 
(J[q] @ F[p]) - (J[p] @ F[q])\n","\n","    def _point_loss(params_gen, z):\n","        F,J = _fields_and_jac(params_gen, z)\n","        def one(i,j):\n","            r = _bracket(F,J,i,j) + _bracket(F,J,j,i)\n","            return jnp.sum(jnp.abs(r))\n","        return jnp.sum(jax.vmap(one)(pi,pj))\n","\n","    _pl = jax.jit(_point_loss)\n","\n","    def _loss_impl(params_gen, txy_batch):\n","        per_point = jax.vmap(lambda z: _pl(params_gen, z))(txy_batch)\n","        return jnp.sum(per_point), {\"per_point\": per_point, \"num_pairs\": int(pi.shape[0])}\n","\n","    return jax.jit(_loss_impl)\n","\n","def make_s4_bilinearity_loss(n_generators: int, num_cc: int = 4, cc_list=None, normalize: bool = True):\n","    triples = [(i,j,k) for i in range(n_generators) for j in range(i+1,n_generators) for k in range(j+1,n_generators)]\n","    if not triples:\n","        def _zero(params_gen, txy_batch, key=None):\n","            return jnp.array(0.0, dtype=jnp.float64), {\"per_point\": jnp.zeros((txy_batch.shape[0],), dtype=jnp.float64)}\n","        return jax.jit(_zero)\n","\n","    tri_i = jnp.array([t[0] for t in triples], dtype=jnp.int32)\n","    tri_j = jnp.array([t[1] for t in triples], dtype=jnp.int32)\n","    tri_k = jnp.array([t[2] for t in triples], dtype=jnp.int32)\n","    perms6 = jnp.array([[0,1,2],[0,2,1],[1,0,2],[1,2,0],[2,0,1],[2,1,0]], dtype=jnp.int32)\n","\n","    if cc_list is not None:\n","        cc_const = jnp.asarray(cc_list, dtype=jnp.float64)\n","    else:\n","        cc_const = jax.random.uniform(jax.random.PRNGKey(0), (num_cc,2), minval=-1.0, maxval=1.0, dtype=jnp.float64)\n","\n","    def _fields_and_jac(params_gen, z):\n","        F_list, J_list = [], []\n","        for p_tau, p_xi in zip(params_gen[\"tau\"], params_gen[\"xi\"]):\n","            def Xi(zz):\n","                tt,xx,yy = zz[0], zz[1], zz[2]\n","                t_arr = jnp.asarray([[tt]], dtype=jnp.float64)\n","                txy_arr= jnp.asarray([[tt,xx,yy]], 
dtype=jnp.float64)\n","                tN = t_norm_gen(t_arr)\n","                txyN= tx_norm_gen(txy_arr)\n","                tau = tau_forward(p_tau, tN, activation=gen_cfg.activation)[0,0]\n","                xi  = xi_forward(p_xi, txyN, activation=gen_cfg.activation)[0]\n","                return jnp.array([tau, xi[0], xi[1]], dtype=jnp.float64)\n","            Fi = Xi(z)\n","            Ji = jax.jacobian(Xi)(z)\n","            F_list.append(Fi); J_list.append(Ji)\n","        return jnp.stack(F_list,0), jnp.stack(J_list,0)\n","\n","    def _bracket(F,J,p,q):\n","        return (J[q] @ F[p]) - (J[p] @ F[q])\n","\n","    def _triple_terms(F,J,i,j,k,cc):\n","        inds = jnp.array([i,j,k], dtype=jnp.int32)\n","        def one_perm(p):\n","            u,v,w = inds[p[0]], inds[p[1]], inds[p[2]]\n","            fu,fv,fw = F[u],F[v],F[w]\n","            Ju,Jv,Jw = J[u],J[v],J[w]\n","\n","            def one_cc(cpair):\n","                c,cp = cpair[0], cpair[1]\n","                f_uv = c*fu + cp*fv\n","                J_uv = c*Ju + cp*Jv\n","                f_vw = c*fv + cp*fw\n","                J_vw = c*Jv + cp*Jw\n","\n","                term1 = (Jw @ f_uv) - (J_uv @ fw)\n","                rhs1  = c*_bracket(F,J,u,w) + cp*_bracket(F,J,v,w)\n","                r1 = term1 - rhs1\n","\n","                term2 = (J_vw @ fu) - (Ju @ f_vw)\n","                rhs2  = c*_bracket(F,J,u,v) + cp*_bracket(F,J,u,w)\n","                r2 = term2 - rhs2\n","\n","                if normalize:\n","                    denom = jnp.abs(c) + jnp.abs(cp) + 1e-12\n","                    r1 = r1/denom; r2 = r2/denom\n","                return jnp.sum(jnp.abs(r1)) + jnp.sum(jnp.abs(r2))\n","\n","            return jnp.mean(jax.vmap(one_cc)(cc))\n","        return jnp.sum(jax.vmap(one_perm)(perms6))\n","\n","    def _point_loss(params_gen, z, cc):\n","        F,J = _fields_and_jac(params_gen, z)\n","        per_tr = jax.vmap(lambda a,b,c: _triple_terms(F,J,a,b,c,cc))(tri_i, 
tri_j, tri_k)\n",
"        return jnp.sum(per_tr)\n",
"\n",
"    _pl = jax.jit(_point_loss)\n",
"\n",
"    def _loss_impl(params_gen, txy_batch, key=None):\n",
"        cc = cc_const\n",
"        per_point = jax.vmap(lambda z: _pl(params_gen, z, cc))(txy_batch)\n",
"        return jnp.sum(per_point), {\"per_point\": per_point, \"num_triples\": int(tri_i.shape[0]), \"num_cc\": int(cc.shape[0])}\n",
"\n",
"    return jax.jit(_loss_impl)\n",
"\n",
"def make_s5_column_independence_loss(n_generators: int, *, mode: str = \"sigma\", tau: float = 0.0, eps: float = 1e-12):\n",
"    \"\"\"Build the jitted S5 loss that penalises linearly dependent generators.\n",
"\n",
"    The (tau, xi_x, xi_y) values of all m generators over the batch are stacked\n",
"    into A of shape (3N, m); columns are unit-normalised and G = Ahat^T Ahat.\n",
"    mode=\"sigma\": hinge max(0, tau - sqrt(max(lambda_min(G), 0))) with threshold\n",
"    `tau` (the default 0.0 makes this loss identically 0). Any other mode is\n",
"    treated as \"corr_l2\": sum of squared off-diagonal entries of G.\n",
"    Returns a jitted fn(params_gen, txy_batch) -> (loss, aux_dict).\n",
"    \"\"\"\n",
"    mode = \"sigma\" if mode == \"sigma\" else \"corr_l2\"\n",
"    mode_code = 0 if mode == \"sigma\" else 1\n",
"\n",
"    def _A_from_batch(params_gen, txy_batch):\n",
"        tB = txy_batch[:,0]\n",
"        xB = txy_batch[:,1]\n",
"        yB = txy_batch[:,2]\n",
"        tau_vals, xi_vals, _ = eval_generators_jit(params_gen, tB, xB, yB)\n",
"        comp = jnp.concatenate([tau_vals[...,None], xi_vals], axis=2)  # (m,N,3)\n",
"        comp_N3m = jnp.transpose(comp, (1,2,0))                        # (N,3,m)\n",
"        return comp_N3m.reshape(-1, n_generators)                      # (3N,m)\n",
"\n",
"    def _loss_impl(params_gen, txy_batch):\n",
"        A = _A_from_batch(params_gen, txy_batch)\n",
"        col_norms = jnp.linalg.norm(A, axis=0) + eps\n",
"        Ahat = A / col_norms\n",
"        G = Ahat.T @ Ahat\n",
"        if mode_code == 0:\n",
"            lam = jnp.linalg.eigvalsh(G)\n",
"            sigma_min = jnp.sqrt(jnp.clip(jnp.min(lam), 0.0, None))\n",
"            loss = jnp.maximum(0.0, jnp.asarray(tau, dtype=G.dtype) - sigma_min)\n",
"            return loss, {\"sigma_min\": sigma_min}\n",
"        else:\n",
"            I = jnp.eye(G.shape[0], dtype=G.dtype)\n",
"            off = G - I\n",
"            off = off - jnp.diag(jnp.diag(off))\n",
"            return jnp.sum(off*off), {\"gram_diag_mean\": jnp.mean(jnp.diag(G))}\n",
"\n",
"    return jax.jit(_loss_impl)\n",
"\n",
"def make_s6_commutator_loss_ito(*, mu_fn, sig_fn, use_abs: bool = False):\n",
"    \"\"\"Build the jitted S6 loss: Ito determining-equation residuals of each generator.\n",
"\n",
"    For every batch point z = (t, x, y) and every generator\n",
"    v = tau(t) d/dt + xi_x d/dx + xi_y d/dy (tau depends on t only), with\n",
"    f = mu_fn, sigma = sig_fn and a = sigma sigma^T:\n",
"      r1 = xi_t + (f.grad)xi - (xi.grad)f - tau f_t - tau_t f + 0.5 a:Hess_xy(xi)\n",
"      r2 = (sigma.grad)xi - (xi.grad)sigma - tau sigma_t - 0.5 tau_t sigma\n",
"    The per-point value sums the squared entries of r1 and r2 over generators\n",
"    (absolute values when use_abs=True); the loss is the mean over the batch.\n",
"\n",
"    mu_fn(t, x, y) must return shape (2,) and sig_fn(t, x, y) shape (2, m) for\n",
"    scalar inputs. Returns a jitted fn(params_gen, txy_batch (N,3)) ->\n",
"    (loss, {\"per_point\": (N,)}).\n",
"    \"\"\"\n",
"    def tau_val_and_dt(params_tau, t_scalar):\n",
"        def tau_scalar(tt):\n",
"            t_arr = jnp.asarray([[tt]], dtype=jnp.float64)\n",
"            tN = t_norm_gen(t_arr)\n",
"            out = tau_forward(params_tau, tN, activation=gen_cfg.activation)\n",
"            return out[0,0]\n",
"        return tau_scalar(t_scalar), jax.grad(tau_scalar)(t_scalar)\n",
"\n",
"    def xi_val_jac_hess(params_xi, t_scalar, x_scalar, y_scalar):\n",
"        def Xi(z):\n",
"            tt,xx,yy = z[0],z[1],z[2]\n",
"            txy = jnp.asarray([[tt,xx,yy]], dtype=jnp.float64)\n",
"            txyN = txy_norm_gen(txy)\n",
"            out = xi_forward(params_xi, txyN, activation=gen_cfg.activation)\n",
"            return out[0,:]  # (2,)\n",
"        z0 = jnp.array([t_scalar,x_scalar,y_scalar], dtype=jnp.float64)\n",
"        xi = Xi(z0)\n",
"        Jz = jax.jacobian(Xi)(z0)                          # (2,3)\n",
"        Hz = jax.jacobian(lambda z: jax.jacobian(Xi)(z))(z0)  # (2,3,3)\n",
"        Hxy = Hz[:,1:,1:]                                 # (2,2,2)\n",
"        return xi, Jz, Hxy\n",
"\n",
"    def f_sigma_and_derivs(t_scalar, x_scalar, y_scalar):\n",
"        def F(z):\n",
"            tt,xx,yy = z[0],z[1],z[2]\n",
"            return mu_fn(tt,xx,yy)  # (2,)\n",
"        z0 = jnp.array([t_scalar,x_scalar,y_scalar], dtype=jnp.float64)\n",
"        f = F(z0)\n",
"        Jz_f = jax.jacobian(F)(z0)                         # (2,3)\n",
"        f_t = Jz_f[:,0]\n",
"        Jf  = Jz_f[:,1:]                                   # (2,2)\n",
"\n",
"        def S_flat(z):\n",
"            tt,xx,yy = z[0],z[1],z[2]\n",
"            S = jnp.asarray(sig_fn(tt,xx,yy), dtype=jnp.float64)  # (2,m) or (2,2)\n",
"            if S.ndim != 2 or S.shape[0] != 2:\n",
"                raise ValueError(f\"sig_fn must return (2,m); got {S.shape}\")\n",
"            return S.reshape(-1)\n",
"        s_flat = S_flat(z0)\n",
"        Jz_s = jax.jacobian(S_flat)(z0)                    # (2m,3)\n",
"        ds_dt  = Jz_s[:,0]\n",
"        ds_dxy = Jz_s[:,1:]                                # (2m,2)\n",
"\n",
"        sigma = jnp.asarray(sig_fn(t_scalar,x_scalar,y_scalar), dtype=jnp.float64)\n",
"        mW = sigma.shape[1]\n",
"        sigma_t = ds_dt.reshape(2,mW)\n",
"        # Row-major unflatten: Jsigma_xy[i, c, d] = d sigma[i, c] / d(x, y)[d]\n",
"        Jsigma_xy = ds_dxy.reshape(2,mW,2)                 # (2,m,2) last axis = (x,y)\n",
"        return f, f_t, Jf, sigma, sigma_t, Jsigma_xy\n",
"\n",
"    def _point_residual(params_gen, z):\n",
"        t,x,y = z[0],z[1],z[2]\n",
"        f,f_t,Jf,sigma,sigma_t,Jsigma_xy = f_sigma_and_derivs(t,x,y)\n",
"        a = sigma @ sigma.T  # (2,2)\n",
"        total = jnp.array(0.0, dtype=jnp.float64)\n",
"        for p_tau, p_xi in zip(params_gen[\"tau\"], params_gen[\"xi\"]):\n",
"            tau_i, tau_t_i = tau_val_and_dt(p_tau, t)\n",
"            xi_i, Jz_xi, Hxy = xi_val_jac_hess(p_xi, t, x, y)\n",
"            xi_t = Jz_xi[:,0]\n",
"            Jxi  = Jz_xi[:,1:]  # (2,2)\n",
"\n",
"            diff_vec = 0.5 * jnp.einsum(\"pq,rpq->r\", a, Hxy)      # (2,)\n",
"            adv_xi   = Jxi @ f                                   # (2,)\n",
"            adv_f    = Jf  @ xi_i                                 # (2,)\n",
"            r1 = xi_t + adv_xi - adv_f - tau_i*f_t - f*tau_t_i + diff_vec\n",
"\n",
"            # (xi.grad)sigma[i, c] = sum_d xi[d] * Jsigma_xy[i, c, d]: xi must be contracted\n",
"            # with the LAST (derivative) axis. Contracting the column axis instead is\n",
"            # silently wrong because the shapes still agree when sigma is square (m == 2).\n",
"            Dxi_sigma = jnp.einsum(\"k,imk->im\", xi_i, Jsigma_xy)  # (2,m)\n",
"            r2 = (Jxi @ sigma) - Dxi_sigma - tau_i*sigma_t - 0.5*tau_t_i*sigma\n",
"\n",
"            if use_abs:\n",
"                total = total + jnp.sum(jnp.abs(r1)) + jnp.sum(jnp.abs(r2))\n",
"            else:\n",
"                total = total + jnp.sum(r1*r1) + jnp.sum(r2*r2)\n",
"        return total\n",
"\n",
"    def _loss_impl(params_gen, txy_batch):\n",
"        txy_batch = jnp.asarray(txy_batch, dtype=jnp.float64)\n",
"        per_point = jax.vmap(lambda z: _point_residual(params_gen, z))(txy_batch)\n",
"        return jnp.mean(per_point), {\"per_point\": per_point}\n",
"\n",
"    return jax.jit(_loss_impl)\n",
"\n",
"def 
make_s7_pushforward_coeff_loss_sde_only_2d(\n","    *, mu_fn, sig_fn, eps: float = 1e-2, num_steps: int = 1, sigma_floor: float = 1e-8,\n","    dt_neg_penalty: float = 100.0, activation: str = \"tanh\", normalize_txy=None, jit: bool = True, use_heun: bool = True,\n","    fd_t: float = 1e-3, fd_x: float = 1e-3, fd_y: float = 1e-3,\n","    tau_clip: float = 5.0, xi_clip: float = 5.0, xy_clip_abs: float = 50.0, x_min_abs: float = 1e-3,\n","):\n","    eval_gen_tau_xi = eval_generators_tau_xi_jit\n","    eps = jnp.asarray(eps, dtype=jnp.float64)\n","    num_steps = int(num_steps)\n","    fd_t = jnp.asarray(fd_t, dtype=jnp.float64)\n","    fd_x = jnp.asarray(fd_x, dtype=jnp.float64)\n","    fd_y = jnp.asarray(fd_y, dtype=jnp.float64)\n","\n","    def _sanitize(t,x,y):\n","        t = jnp.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0)\n","        x = jnp.clip(jnp.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0), -xy_clip_abs, xy_clip_abs)\n","        y = jnp.clip(jnp.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0), -xy_clip_abs, xy_clip_abs)\n","        x = jnp.sign(x) * jnp.maximum(jnp.abs(x), x_min_abs)\n","        return t,x,y\n","\n","    def _mu_eval(t,x,y):\n","        out = jnp.asarray(mu_fn(t,x,y), dtype=jnp.float64)\n","        return jnp.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)\n","\n","    def _sig_eval(t,x,y):\n","        out = jnp.asarray(sig_fn(t,x,y), dtype=jnp.float64)\n","        out = jnp.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)\n","        if out.shape[-2] != 2:\n","            raise ValueError(f\"sig_fn must return (...,2,m); got {out.shape}\")\n","        mW = out.shape[-1]\n","        return out, mW\n","\n","    def _diag_from_flat(params_gen, t_flat, x_flat, y_flat):\n","        tau_all, xi_all = eval_gen_tau_xi(params_gen, t_flat, x_flat, y_flat, activation=activation, normalize_txy=normalize_txy)\n","        tau_all = jnp.asarray(tau_all, dtype=jnp.float64)\n","        xi_all  = jnp.asarray(xi_all,  dtype=jnp.float64)\n","\n","        
m = tau_all.shape[0]\n","        Bflat = tau_all.shape[1]\n","        if Bflat % m != 0:\n","            raise ValueError(\"Internal shape error: expected Bflat divisible by m.\")\n","        B = Bflat // m\n","\n","        tau_blk = tau_all.reshape(m, m, B)\n","        xi_blk  = xi_all.reshape(m, m, B, 2)\n","        idx = jnp.arange(m, dtype=jnp.int32)\n","        tau_d = tau_blk[idx, idx, :]\n","        xi_d  = xi_blk[idx, idx, :, :]\n","\n","        tau_d = tau_clip * jnp.tanh(jnp.nan_to_num(tau_d)/tau_clip)\n","        xi_d  = xi_clip  * jnp.tanh(jnp.nan_to_num(xi_d)/xi_clip)\n","        return tau_d, xi_d\n","\n","    def _rhs_diag_with_derivs(params_gen, tS, xS, yS):\n","        m,B = tS.shape\n","        t_flat = tS.reshape(-1); x_flat = xS.reshape(-1); y_flat = yS.reshape(-1)\n","\n","        if normalize_txy is not None:\n","            t_flat, x_flat, y_flat = normalize_txy(t_flat, x_flat, y_flat)\n","            t_flat = jnp.asarray(t_flat, dtype=jnp.float64)\n","            x_flat = jnp.asarray(x_flat, dtype=jnp.float64)\n","            y_flat = jnp.asarray(y_flat, dtype=jnp.float64)\n","\n","        tau0, xi0 = _diag_from_flat(params_gen, t_flat, x_flat, y_flat)\n","\n","        tau_p, xi_p = _diag_from_flat(params_gen, t_flat+fd_t, x_flat, y_flat)\n","        tau_m, xi_m = _diag_from_flat(params_gen, t_flat-fd_t, x_flat, y_flat)\n","        tau_t = (tau_p - tau_m)/(2*fd_t)\n","        xi_t  = (xi_p  - xi_m )/(2*fd_t)\n","\n","        _, xi_xp = _diag_from_flat(params_gen, t_flat, x_flat+fd_x, y_flat)\n","        _, xi_xm = _diag_from_flat(params_gen, t_flat, x_flat-fd_x, y_flat)\n","        xi_x  = (xi_xp - xi_xm)/(2*fd_x)\n","        xi_xx = (xi_xp - 2*xi0 + xi_xm)/(fd_x*fd_x)\n","\n","        _, xi_yp = _diag_from_flat(params_gen, t_flat, x_flat, y_flat+fd_y)\n","        _, xi_ym = _diag_from_flat(params_gen, t_flat, x_flat, y_flat-fd_y)\n","        xi_y  = (xi_yp - xi_ym)/(2*fd_y)\n","        xi_yy = (xi_yp - 2*xi0 + xi_ym)/(fd_y*fd_y)\n","\n","   
     _, xi_xp_yp = _diag_from_flat(params_gen, t_flat, x_flat+fd_x, y_flat+fd_y)\n","        _, xi_xp_ym = _diag_from_flat(params_gen, t_flat, x_flat+fd_x, y_flat-fd_y)\n","        _, xi_xm_yp = _diag_from_flat(params_gen, t_flat, x_flat-fd_x, y_flat+fd_y)\n","        _, xi_xm_ym = _diag_from_flat(params_gen, t_flat, x_flat-fd_x, y_flat-fd_y)\n","        xi_xy = (xi_xp_yp - xi_xp_ym - xi_xm_yp + xi_xm_ym)/(4*fd_x*fd_y)\n","\n","        tau0  = tau0.reshape(m,B); tau_t = tau_t.reshape(m,B)\n","        xi0   = xi0.reshape(m,B,2); xi_t  = xi_t.reshape(m,B,2)\n","        xi_x  = xi_x.reshape(m,B,2); xi_y  = xi_y.reshape(m,B,2)\n","        xi_xx = xi_xx.reshape(m,B,2); xi_yy = xi_yy.reshape(m,B,2)\n","        xi_xy = xi_xy.reshape(m,B,2)\n","        return tau0, xi0, tau_t, xi_t, xi_x, xi_y, xi_xx, xi_yy, xi_xy\n","\n","    def _a_from_sigma(sg):\n","        a = jnp.einsum(\"...iA,...jA->...ij\", sg, sg)\n","        a = a + (sigma_floor**2)*jnp.eye(2, dtype=a.dtype)\n","        return a\n","\n","    def _flow_step_rhs(params_gen, tS, xS, yS, muS, sgS):\n","        tau, xi, tau_t, xi_t, xi_x, xi_y, xi_xx, xi_yy, xi_xy = _rhs_diag_with_derivs(params_gen, tS, xS, yS)\n","        Jxi = jnp.stack([xi_x, xi_y], axis=-1)  # (m,B,2,2)\n","        a = _a_from_sigma(sgS)\n","        a11,a12,a22 = a[...,0,0], a[...,0,1], a[...,1,1]\n","        diff_vec = 0.5*(a11[...,None]*xi_xx + 2.0*a12[...,None]*xi_xy + a22[...,None]*xi_yy)\n","        adv_xi = jnp.einsum(\"...ij,...j->...i\", Jxi, muS)\n","        k_t  = tau\n","        k_xy = xi\n","        k_mu = xi_t + adv_xi - muS*tau_t[...,None] + diff_vec\n","        k_sg = jnp.einsum(\"...ij,...jA->...iA\", Jxi, sgS) - 0.5*tau_t[...,None,None]*sgS\n","        return k_t, k_xy, k_mu, k_sg\n","\n","    def _flow_allgens(params_gen, t0, x0, y0):\n","        tau_all, _ = eval_gen_tau_xi(params_gen, t0, x0, y0, activation=activation, normalize_txy=normalize_txy)\n","        m = int(tau_all.shape[0]); B = int(t0.shape[0])\n","\n","        tS 
= jnp.broadcast_to(t0[None,:], (m,B))\n","        xS = jnp.broadcast_to(x0[None,:], (m,B))\n","        yS = jnp.broadcast_to(y0[None,:], (m,B))\n","\n","        t0c,x0c,y0c = _sanitize(t0,x0,y0)\n","        mu0 = _mu_eval(t0c,x0c,y0c)         # (B,2)\n","        sg0,mW = _sig_eval(t0c,x0c,y0c)     # (B,2,mW)\n","\n","        muS = jnp.broadcast_to(mu0[None,:,:], (m,B,2))\n","        sgS = jnp.broadcast_to(sg0[None,:,:,:], (m,B,2,mW))\n","\n","        def body(_, state):\n","            tS,xS,yS,muS,sgS = state\n","            if use_heun:\n","                k1_t,k1_xy,k1_mu,k1_sg = _flow_step_rhs(params_gen,tS,xS,yS,muS,sgS)\n","                tP = tS + eps*k1_t\n","                xP = xS + eps*k1_xy[...,0]\n","                yP = yS + eps*k1_xy[...,1]\n","                muP= muS+ eps*k1_mu\n","                sgP= sgS+ eps*k1_sg\n","                k2_t,k2_xy,k2_mu,k2_sg = _flow_step_rhs(params_gen,tP,xP,yP,muP,sgP)\n","                tN = tS + 0.5*eps*(k1_t + k2_t)\n","                xN = xS + 0.5*eps*(k1_xy[...,0]+k2_xy[...,0])\n","                yN = yS + 0.5*eps*(k1_xy[...,1]+k2_xy[...,1])\n","                muN= muS+ 0.5*eps*(k1_mu+k2_mu)\n","                sgN= sgS+ 0.5*eps*(k1_sg+k2_sg)\n","            else:\n","                k_t,k_xy,k_mu,k_sg = _flow_step_rhs(params_gen,tS,xS,yS,muS,sgS)\n","                tN = tS + eps*k_t\n","                xN = xS + eps*k_xy[...,0]\n","                yN = yS + eps*k_xy[...,1]\n","                muN= muS+ eps*k_mu\n","                sgN= sgS+ eps*k_sg\n","            return (jnp.nan_to_num(tN), jnp.nan_to_num(xN), jnp.nan_to_num(yN), jnp.nan_to_num(muN), jnp.nan_to_num(sgN))\n","\n","        tS,xS,yS,muS,sgS = jax.lax.fori_loop(0, num_steps, body, (tS,xS,yS,muS,sgS))\n","        return tS,xS,yS,muS,sgS,mW\n","\n","    def _loss_impl(params_gen, txy_batch):\n","        txy_batch = jnp.asarray(txy_batch, dtype=jnp.float64)\n","        tL = txy_batch[:,0]; xL = txy_batch[:,1]; yL = txy_batch[:,2]\n","      
  t_push,x_push,y_push,mu_pred,sg_pred,mW = _flow_allgens(params_gen,tL,xL,yL)\n","        m,B = t_push.shape\n","\n","        tpc,xpc,ypc = _sanitize(t_push.reshape(-1), x_push.reshape(-1), y_push.reshape(-1))\n","        mu_eval = _mu_eval(tpc,xpc,ypc).reshape(m,B,2)\n","        sg_eval,_= _sig_eval(tpc,xpc,ypc)\n","        sg_eval = sg_eval.reshape(m,B,2,mW)\n","\n","        mu_mse = jnp.mean((mu_pred-mu_eval)**2, axis=(1,2))\n","        sg_mse = jnp.mean((sg_pred-sg_eval)**2, axis=(1,2,3))\n","        dt = t_push - tL[None,:]\n","        dt_neg = jnp.mean(jax.nn.softplus(-dt), axis=1)\n","        per_gen = mu_mse + sg_mse + dt_neg_penalty*dt_neg\n","        loss = jnp.mean(per_gen)\n","        return loss, {\"s7_mu_mse\": jnp.mean(mu_mse), \"s7_sigma_mse\": jnp.mean(sg_mse), \"s7_dt_neg\": jnp.mean(dt_neg)}\n","\n","    return jax.jit(_loss_impl) if jit else _loss_impl\n","\n","# Build loss fns\n","s1 = make_s1_lie_loss(m)\n","s2 = make_s2_jacobi_loss_nested(m)\n","s3 = make_s3_skewsym_loss(m)\n","s4 = make_s4_bilinearity_loss(m, num_cc=4)\n","s5 = make_s5_column_independence_loss(m, mode=\"sigma\", tau=0.02)\n","s6 = make_s6_commutator_loss_ito(mu_fn=mu_fn, sig_fn=sig_fn, use_abs=False)\n","s7 = make_s7_pushforward_coeff_loss_sde_only_2d(mu_fn=mu_fn, sig_fn=sig_fn, eps=1e-2, num_steps=1, use_heun=True)\n","\n","# ---------------------------\n","# 10) Training setup\n","# ---------------------------\n","@dataclass\n","class GenTrainConfig:\n","    steps: int = 6000\n","    batch_size: int = 256\n","    lr: float = 2e-4\n","    print_every: int = 200\n","    grad_clip: float = 1.0\n","\n","gen_train_cfg = GenTrainConfig(**getattr(globals().get(\"gen_train_cfg\", GenTrainConfig()), \"__dict__\", {}))\n","\n","def weight_schedule(step, steps):\n","    s = step / max(1, steps)\n","    ramp = jnp.clip((s - 0.6) / 0.4, 0.0, 1.0)\n","    w_s6 = 10.0\n","    w_s7 = 2.0\n","    w_s5 = 2.0\n","    w_s1 = 0.5 * ramp\n","    w_s2 = 0.2 * ramp\n","    w_s3 = 0.2 * ramp\n"," 
   w_s4 = 0.2 * ramp\n","    return jnp.asarray([w_s1,w_s2,w_s3,w_s4,w_s5,w_s6,w_s7], dtype=jnp.float64)\n","\n","# Sampler: half empirical (TX_gen) half uniform\n","TX_gen_np = None\n","if TX_gen is not None:\n","    TX_gen_np = np.asarray(jax.device_get(jnp.asarray(TX_gen, dtype=jnp.float64)))\n","    TX_gen_np = TX_gen_np[np.isfinite(TX_gen_np).all(axis=1)]\n","    if TX_gen_np.shape[0] == 0:\n","        TX_gen_np = None\n","\n","def _infer_bounds_from_TX(TX):\n","    return float(np.min(TX[:,0])), float(np.max(TX[:,0])), float(np.min(TX[:,1])), float(np.max(TX[:,1])), float(np.min(TX[:,2])), float(np.max(TX[:,2]))\n","\n","if TX_gen_np is not None:\n","    tmin_u,tmax_u,xmin_u,xmax_u,ymin_u,ymax_u = _infer_bounds_from_TX(TX_gen_np)\n","else:\n","    tmin_u,tmax_u = 0.0, float(getattr(cfg_ex5, \"T\", 2.0))\n","    xmin_u,xmax_u = 0.2, 5.0\n","    ymin_u,ymax_u = -3.0, 5.0\n","\n","X_FLOOR = max(1e-2, xmin_u)\n","rng_np = np.random.default_rng(0)\n","\n","def sample_txy_batch(batch_size):\n","    n_uni = batch_size//2\n","    n_emp = batch_size - n_uni\n","    chunks = []\n","    if (TX_gen_np is not None) and (n_emp>0):\n","        N = TX_gen_np.shape[0]\n","        idx = rng_np.choice(N, size=n_emp, replace=(n_emp>N))\n","        chunks.append(TX_gen_np[idx])\n","    if n_uni>0:\n","        t = rng_np.uniform(tmin_u, tmax_u, size=(n_uni,1))\n","        x = rng_np.uniform(X_FLOOR, xmax_u, size=(n_uni,1))\n","        y = rng_np.uniform(ymin_u, ymax_u, size=(n_uni,1))\n","        chunks.append(np.concatenate([t,x,y], axis=1))\n","    TXb = np.concatenate(chunks, axis=0)\n","    rng_np.shuffle(TXb)\n","    return jnp.asarray(TXb, dtype=jnp.float64)\n","\n","print(\"[train] bounds:\",\n","      f\"t∈[{tmin_u:.3g},{tmax_u:.3g}]\",\n","      f\"x∈[{X_FLOOR:.3g},{xmax_u:.3g}]\",\n","      f\"y∈[{ymin_u:.3g},{ymax_u:.3g}]\",\n","      f\"| TX_gen = {None if TX_gen_np is None else TX_gen_np.shape[0]}\")\n","\n","# Optimizer\n","warmup = 
int(0.05*gen_train_cfg.steps)\n","cosine = optax.cosine_decay_schedule(init_value=gen_train_cfg.lr, decay_steps=max(1, gen_train_cfg.steps-warmup), alpha=0.1)\n","schedule = optax.join_schedules([optax.linear_schedule(0.0, gen_train_cfg.lr, warmup), cosine], boundaries=[warmup])\n","\n","optimizer_gen = optax.chain(\n","    optax.clip_by_global_norm(gen_train_cfg.grad_clip),\n","    optax.adamw(learning_rate=schedule, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0),\n",")\n","opt_state = optimizer_gen.init(params_gen)\n","\n","@jax.jit\n","def train_step(params_gen, opt_state, txy_batch, step_idx):\n","    w = weight_schedule(step_idx, gen_train_cfg.steps)\n","\n","    L1, aux1 = s1(params_gen, txy_batch)\n","    L2, aux2 = s2(params_gen, txy_batch)\n","    L3, aux3 = s3(params_gen, txy_batch)\n","    L4, aux4 = s4(params_gen, txy_batch, None)\n","    L5, aux5 = s5(params_gen, txy_batch)\n","    L6, aux6 = s6(params_gen, txy_batch)\n","    L7, aux7 = s7(params_gen, txy_batch)\n","\n","    def loss_fn(p):\n","        L1_, _ = s1(p, txy_batch)\n","        L2_, _ = s2(p, txy_batch)\n","        L3_, _ = s3(p, txy_batch)\n","        L4_, _ = s4(p, txy_batch, None)\n","        L5_, _ = s5(p, txy_batch)\n","        L6_, _ = s6(p, txy_batch)\n","        L7_, _ = s7(p, txy_batch)\n","        return w[0]*L1_ + w[1]*L2_ + w[2]*L3_ + w[3]*L4_ + w[4]*L5_ + w[5]*L6_ + w[6]*L7_\n","\n","    total = loss_fn(params_gen)\n","    grads = jax.grad(loss_fn)(params_gen)\n","    updates, opt_state2 = optimizer_gen.update(grads, opt_state, params_gen)\n","    params2 = optax.apply_updates(params_gen, updates)\n","\n","    aux = dict(\n","        total=total, w=w,\n","        L1=L1, L2=L2, L3=L3, L4=L4, L5=L5, L6=L6, L7=L7,\n","        sigma_min=aux5.get(\"sigma_min\", jnp.nan),\n","        s7_mu_mse=aux7.get(\"s7_mu_mse\", jnp.nan),\n","        s7_sigma_mse=aux7.get(\"s7_sigma_mse\", jnp.nan),\n","        s7_dt_neg=aux7.get(\"s7_dt_neg\", jnp.nan),\n","    )\n","    return params2, 
opt_state2, aux\n","\n","# Train\n","hist = []\n","for step in range(1, gen_train_cfg.steps+1):\n","    txy_batch = sample_txy_batch(gen_train_cfg.batch_size)\n","    params_gen, opt_state, aux = train_step(params_gen, opt_state, txy_batch, step)\n","\n","    if (step % gen_train_cfg.print_every) == 0 or step == 1 or step == gen_train_cfg.steps:\n","        w_np = np.asarray(jax.device_get(aux[\"w\"]))\n","        print(\n","            f\"step {step:5d}/{gen_train_cfg.steps} | total={float(aux['total']):.3e} | \"\n","            f\"L1={float(aux['L1']):.2e} L2={float(aux['L2']):.2e} L3={float(aux['L3']):.2e} L4={float(aux['L4']):.2e} | \"\n","            f\"L5={float(aux['L5']):.2e} L6={float(aux['L6']):.2e} L7={float(aux['L7']):.2e} | \"\n","            f\"sigma_min={float(aux['sigma_min']):.3e} | \"\n","            f\"s7(mu,sig,dtneg)=({float(aux['s7_mu_mse']):.2e},{float(aux['s7_sigma_mse']):.2e},{float(aux['s7_dt_neg']):.2e}) | \"\n","            f\"w={w_np.round(3)}\"\n","        )\n","\n","    hist.append(float(aux[\"total\"]))\n","\n","globals()[\"params_gen\"] = params_gen\n","globals()[\"hist_gen_total\"] = hist\n","print(\"[OK] Training finished. Exported params_gen.\")\n","\n","\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ublREDw-H6lT","outputId":"5e5bd7c9-2128-43fb-ea81-80ccdaa748f3","executionInfo":{"status":"ok","timestamp":1769526539293,"user_tz":300,"elapsed":205419,"user":{"displayName":"Dacha Thurbur","userId":"16530582923668080051"}}},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["[ok] Built TX_gen from (t, XY): (411648, 3)\n","Gen t_norm mean/std: [1.] [0.58022984]\n","Gen tx_norm mean/std: [1.         
2.33843767 0.51104933] [0.58022984 1.0494469  1.58468619]\n","[OK] Generator nets ready: m=3, xi is 2D, normalizers set (t_norm_gen, tx_norm_gen, txy_norm_gen).\n","[train] bounds: t∈[0,2] x∈[0.01,7.43] y∈[-4.98,6.27] | TX_gen = 411648\n","step     1/6000 | total=1.686e+02 | L1=4.49e+00 L2=7.00e-16 L3=0.00e+00 L4=1.09e-14 | L5=0.00e+00 L6=3.00e+00 L7=6.93e+01 | sigma_min=8.083e-01 | s7(mu,sig,dtneg)=(4.90e-05,5.85e-07,6.93e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step   200/6000 | total=1.385e+02 | L1=1.60e+03 L2=1.76e-16 L3=0.00e+00 L4=1.12e-15 | L5=0.00e+00 L6=2.73e-03 L7=6.92e+01 | sigma_min=6.379e-02 | s7(mu,sig,dtneg)=(4.02e-08,2.62e-09,6.92e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step   400/6000 | total=1.375e+02 | L1=6.88e+03 L2=1.58e-14 L3=0.00e+00 L4=1.27e-14 | L5=0.00e+00 L6=2.82e-03 L7=6.87e+01 | sigma_min=2.046e-02 | s7(mu,sig,dtneg)=(3.31e-08,5.37e-09,6.87e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step   600/6000 | total=1.359e+02 | L1=2.39e+03 L2=1.57e-13 L3=0.00e+00 L4=5.40e-14 | L5=2.22e-03 L6=7.55e-03 L7=6.79e+01 | sigma_min=1.778e-02 | s7(mu,sig,dtneg)=(7.15e-08,9.69e-09,6.79e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step   800/6000 | total=1.347e+02 | L1=2.71e+04 L2=4.55e-13 L3=0.00e+00 L4=7.75e-14 | L5=1.45e-02 L6=3.14e-03 L7=6.73e+01 | sigma_min=5.546e-03 | s7(mu,sig,dtneg)=(1.57e-08,2.72e-09,6.73e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  1000/6000 | total=1.342e+02 | L1=4.13e+03 L2=5.76e-13 L3=0.00e+00 L4=8.17e-14 | L5=1.66e-02 L6=2.89e-03 L7=6.71e+01 | sigma_min=3.435e-03 | s7(mu,sig,dtneg)=(2.51e-08,1.42e-09,6.71e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  1200/6000 | total=1.340e+02 | L1=3.08e+03 L2=6.86e-13 L3=0.00e+00 L4=8.04e-14 | L5=1.73e-02 L6=1.53e-03 L7=6.70e+01 | sigma_min=2.666e-03 | s7(mu,sig,dtneg)=(1.16e-08,1.04e-09,6.70e-01) | w=[ 0.  0.  0.  0.  2. 10.  
2.]\n","step  1400/6000 | total=1.339e+02 | L1=3.71e+03 L2=6.75e-13 L3=0.00e+00 L4=6.79e-14 | L5=1.77e-02 L6=9.96e-04 L7=6.69e+01 | sigma_min=2.327e-03 | s7(mu,sig,dtneg)=(7.36e-09,1.09e-09,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  1600/6000 | total=1.339e+02 | L1=1.09e+04 L2=7.09e-13 L3=0.00e+00 L4=7.00e-14 | L5=1.81e-02 L6=5.60e-04 L7=6.69e+01 | sigma_min=1.888e-03 | s7(mu,sig,dtneg)=(6.25e-09,7.89e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  1800/6000 | total=1.338e+02 | L1=1.90e+05 L2=7.45e-13 L3=0.00e+00 L4=6.99e-14 | L5=1.84e-02 L6=5.38e-04 L7=6.69e+01 | sigma_min=1.608e-03 | s7(mu,sig,dtneg)=(1.02e-08,7.17e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  2000/6000 | total=1.338e+02 | L1=3.37e+03 L2=7.26e-13 L3=0.00e+00 L4=6.58e-14 | L5=1.84e-02 L6=3.91e-04 L7=6.69e+01 | sigma_min=1.552e-03 | s7(mu,sig,dtneg)=(6.32e-09,6.87e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  2200/6000 | total=1.338e+02 | L1=2.87e+03 L2=7.14e-13 L3=0.00e+00 L4=5.93e-14 | L5=1.86e-02 L6=3.41e-04 L7=6.69e+01 | sigma_min=1.368e-03 | s7(mu,sig,dtneg)=(5.26e-09,5.43e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  2400/6000 | total=1.338e+02 | L1=2.69e+04 L2=7.48e-13 L3=0.00e+00 L4=5.62e-14 | L5=1.87e-02 L6=2.02e-04 L7=6.69e+01 | sigma_min=1.255e-03 | s7(mu,sig,dtneg)=(2.78e-09,5.10e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  2600/6000 | total=1.338e+02 | L1=1.50e+05 L2=7.30e-13 L3=0.00e+00 L4=5.59e-14 | L5=1.88e-02 L6=4.12e-04 L7=6.69e+01 | sigma_min=1.236e-03 | s7(mu,sig,dtneg)=(8.00e-09,5.17e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  2800/6000 | total=1.338e+02 | L1=1.41e+04 L2=7.70e-13 L3=0.00e+00 L4=6.31e-14 | L5=1.88e-02 L6=2.13e-04 L7=6.69e+01 | sigma_min=1.248e-03 | s7(mu,sig,dtneg)=(2.97e-09,5.51e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  
2.]\n","step  3000/6000 | total=1.338e+02 | L1=2.60e+04 L2=7.21e-13 L3=0.00e+00 L4=5.93e-14 | L5=1.88e-02 L6=2.06e-04 L7=6.69e+01 | sigma_min=1.209e-03 | s7(mu,sig,dtneg)=(2.95e-09,5.70e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  3200/6000 | total=1.338e+02 | L1=3.76e+04 L2=7.72e-13 L3=0.00e+00 L4=5.52e-14 | L5=1.89e-02 L6=2.21e-04 L7=6.69e+01 | sigma_min=1.141e-03 | s7(mu,sig,dtneg)=(2.64e-09,5.43e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  3400/6000 | total=1.338e+02 | L1=1.08e+05 L2=7.73e-13 L3=0.00e+00 L4=5.53e-14 | L5=1.89e-02 L6=1.78e-04 L7=6.69e+01 | sigma_min=1.097e-03 | s7(mu,sig,dtneg)=(2.57e-09,4.54e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  3600/6000 | total=1.338e+02 | L1=3.34e+04 L2=7.66e-13 L3=0.00e+00 L4=5.49e-14 | L5=1.89e-02 L6=1.58e-04 L7=6.69e+01 | sigma_min=1.083e-03 | s7(mu,sig,dtneg)=(2.47e-09,4.85e-10,6.69e-01) | w=[ 0.  0.  0.  0.  2. 10.  2.]\n","step  3800/6000 | total=1.338e+02 | L1=1.27e+00 L2=7.03e-13 L3=0.00e+00 L4=9.66e-14 | L5=1.64e-02 L6=4.29e-03 L7=6.69e+01 | sigma_min=3.645e-03 | s7(mu,sig,dtneg)=(8.98e-08,1.52e-09,6.69e-01) | w=[ 0.042  0.017  0.017  0.017  2.    10.     2.   ]\n","step  4000/6000 | total=1.339e+02 | L1=1.11e+00 L2=7.35e-13 L3=0.00e+00 L4=1.08e-13 | L5=1.61e-02 L6=3.46e-03 L7=6.69e+01 | sigma_min=3.866e-03 | s7(mu,sig,dtneg)=(7.56e-08,2.09e-09,6.69e-01) | w=[ 0.083  0.033  0.033  0.033  2.    10.     2.   ]\n","step  4200/6000 | total=1.339e+02 | L1=5.66e-01 L2=6.88e-13 L3=0.00e+00 L4=1.04e-13 | L5=1.63e-02 L6=4.92e-03 L7=6.69e+01 | sigma_min=3.682e-03 | s7(mu,sig,dtneg)=(1.07e-07,2.03e-09,6.69e-01) | w=[ 0.125  0.05   0.05   0.05   2.    10.     2.   ]\n","step  4400/6000 | total=1.339e+02 | L1=3.51e-01 L2=6.90e-13 L3=0.00e+00 L4=1.00e-13 | L5=1.63e-02 L6=4.47e-03 L7=6.69e+01 | sigma_min=3.710e-03 | s7(mu,sig,dtneg)=(8.14e-08,2.44e-09,6.69e-01) | w=[ 0.167  0.067  0.067  0.067  2.    10.     2.   
]\n","step  4600/6000 | total=1.339e+02 | L1=3.39e-01 L2=6.75e-13 L3=0.00e+00 L4=9.75e-14 | L5=1.63e-02 L6=4.61e-03 L7=6.69e+01 | sigma_min=3.704e-03 | s7(mu,sig,dtneg)=(1.01e-07,2.37e-09,6.69e-01) | w=[ 0.208  0.083  0.083  0.083  2.    10.     2.   ]\n","step  4800/6000 | total=1.338e+02 | L1=1.58e-01 L2=7.00e-13 L3=0.00e+00 L4=9.89e-14 | L5=1.63e-02 L6=5.07e-03 L7=6.69e+01 | sigma_min=3.669e-03 | s7(mu,sig,dtneg)=(1.05e-07,2.52e-09,6.69e-01) | w=[ 0.25  0.1   0.1   0.1   2.   10.    2.  ]\n","step  5000/6000 | total=1.338e+02 | L1=1.98e-01 L2=6.39e-13 L3=0.00e+00 L4=1.00e-13 | L5=1.64e-02 L6=4.11e-03 L7=6.69e+01 | sigma_min=3.619e-03 | s7(mu,sig,dtneg)=(8.43e-08,2.72e-09,6.69e-01) | w=[ 0.292  0.117  0.117  0.117  2.    10.     2.   ]\n","step  5200/6000 | total=1.339e+02 | L1=1.74e-01 L2=6.63e-13 L3=0.00e+00 L4=1.01e-13 | L5=1.65e-02 L6=5.98e-03 L7=6.69e+01 | sigma_min=3.483e-03 | s7(mu,sig,dtneg)=(1.09e-07,2.55e-09,6.69e-01) | w=[ 0.333  0.133  0.133  0.133  2.    10.     2.   ]\n","step  5400/6000 | total=1.338e+02 | L1=9.57e-02 L2=6.75e-13 L3=0.00e+00 L4=9.69e-14 | L5=1.65e-02 L6=6.14e-03 L7=6.69e+01 | sigma_min=3.522e-03 | s7(mu,sig,dtneg)=(1.16e-07,2.84e-09,6.69e-01) | w=[ 0.375  0.15   0.15   0.15   2.    10.     2.   ]\n","step  5600/6000 | total=1.338e+02 | L1=1.10e-01 L2=6.58e-13 L3=0.00e+00 L4=1.01e-13 | L5=1.64e-02 L6=4.37e-03 L7=6.69e+01 | sigma_min=3.622e-03 | s7(mu,sig,dtneg)=(9.75e-08,2.71e-09,6.69e-01) | w=[ 0.417  0.167  0.167  0.167  2.    10.     2.   ]\n","step  5800/6000 | total=1.338e+02 | L1=8.03e-02 L2=6.65e-13 L3=0.00e+00 L4=9.77e-14 | L5=1.64e-02 L6=3.89e-03 L7=6.69e+01 | sigma_min=3.563e-03 | s7(mu,sig,dtneg)=(9.19e-08,2.83e-09,6.69e-01) | w=[ 0.458  0.183  0.183  0.183  2.    10.     2.   ]\n","step  6000/6000 | total=1.339e+02 | L1=1.02e-01 L2=6.57e-13 L3=0.00e+00 L4=9.74e-14 | L5=1.64e-02 L6=5.25e-03 L7=6.69e+01 | sigma_min=3.555e-03 | s7(mu,sig,dtneg)=(9.50e-08,2.95e-09,6.69e-01) | w=[ 0.5  0.2  0.2  0.2  2.  10.   2. 
]\n","[OK] Training finished. Exported params_gen.\n"]}]},{"cell_type":"code","source":["# ============================================================\n","# 11) Robust span check vs GT SDE symmetry span (m=3)\n","#    v1 = ∂_t\n","#    v3 = ∂_y\n","#    v4 = 2t∂_t + x∂_x + (y + a2 t)∂_y\n","# Learned fields (tau, xi_x, xi_y) are compared against this span on random points.\n","# ============================================================\n","\n","def _gt_generators_ex5(t, x, y, a2):\n","    \"\"\"Ground-truth SDE symmetry fields v1, v3, v4 evaluated at the points (t, x, y).\n","\n","    Each field is stacked as (tau, xi_x, xi_y) per point, so the result has\n","    shape (3, B, 3) = (generator, point, component).\n","    \"\"\"\n","    t = jnp.asarray(t, dtype=jnp.float64)\n","    x = jnp.asarray(x, dtype=jnp.float64)\n","    y = jnp.asarray(y, dtype=jnp.float64)\n","    B = t.shape[0]\n","    v1 = jnp.stack([jnp.ones((B,), dtype=jnp.float64),\n","                    jnp.zeros((B,), dtype=jnp.float64),\n","                    jnp.zeros((B,), dtype=jnp.float64)], axis=1)\n","    v3 = jnp.stack([jnp.zeros((B,), dtype=jnp.float64),\n","                    jnp.zeros((B,), dtype=jnp.float64),\n","                    jnp.ones((B,), dtype=jnp.float64)], axis=1)\n","    v4 = jnp.stack([2.0*t,\n","                    x,\n","                    (y + a2*t)], axis=1)\n","    return jnp.stack([v1, v3, v4], axis=0)  # (3,B,3)\n","\n","def _stack_columns(V_mB3):\n","    \"\"\"Reshape per-point fields (m, B, 3) into a (3B, m) matrix, one column per generator.\"\"\"\n","    V = jnp.transpose(V_mB3, (1,2,0))  # (B,3,m)\n","    return V.reshape(-1, V.shape[-1])  # (3B,m)\n","\n","def _orth_residual(Q, A):\n","    \"\"\"Return (I - Q Q^T) A, the part of A's columns outside span(Q); Q must have orthonormal columns.\n","\n","    Evaluated as A - Q (Q^T A), so the dense (rows x rows) projector Q Q^T is never built\n","    (it would be 1536 x 1536 for Bcheck = 512).\n","    \"\"\"\n","    return A - Q @ (Q.T @ A)\n","\n","def span_metrics(A_gt, A_learn):\n","    \"\"\"Compare the column spans of A_gt and A_learn (both with 3B rows).\n","\n","    Returns (angles, r_learn_to_gt, r_gt_to_learn, fit_rel):\n","      angles        : principal angles (rad) between the QR bases of the two spans\n","      r_learn_to_gt : ||(I - P_gt) A_learn||_F / ||A_learn||_F\n","      r_gt_to_learn : ||(I - P_learn) A_gt||_F / ||A_gt||_F\n","      fit_rel       : relative residual of the least-squares fit A_learn ~ A_gt @ M\n","\n","    If A_learn is (nearly) rank-deficient, QR still returns orthonormal columns whose\n","    extra directions are set by roundoff, so angles / r_gt_to_learn then become unreliable.\n","    \"\"\"\n","    Qg, _ = jnp.linalg.qr(A_gt)\n","    Ql, _ = jnp.linalg.qr(A_learn)\n","    s = jnp.linalg.svd(Qg.T @ Ql, compute_uv=False)\n","    angles = jnp.arccos(jnp.clip(s, 0.0, 1.0))\n","\n","    r_learn_to_gt = jnp.linalg.norm(_orth_residual(Qg, A_learn)) / (jnp.linalg.norm(A_learn) + 1e-12)\n","    r_gt_to_learn = jnp.linalg.norm(_orth_residual(Ql, A_gt)) / (jnp.linalg.norm(A_gt) + 1e-12)\n","\n","    M_ls, _, _, _ = jnp.linalg.lstsq(A_gt, A_learn, rcond=None)\n","    fit_rel = jnp.linalg.norm(A_learn - A_gt @ M_ls) / (jnp.linalg.norm(A_learn) + 1e-12)\n","    return angles, r_learn_to_gt, r_gt_to_learn, fit_rel\n","\n","# -------- user-editable principal-angle domain (defaults match previous behavior) --------\n","span_t_min = float(globals().get(\"span_t_min\", tmin_u))\n","span_t_max = float(globals().get(\"span_t_max\", tmax_u))\n","span_x_min = float(globals().get(\"span_x_min\", X_FLOOR))\n","span_x_max = float(globals().get(\"span_x_max\", xmax_u))\n","span_y_min = float(globals().get(\"span_y_min\", ymin_u))\n","span_y_max = float(globals().get(\"span_y_max\", ymax_u))\n","\n","print(\n","    \"\\n[SPAN CHECK DOMAIN] \"\n","    f\"t∈[{span_t_min:.6g},{span_t_max:.6g}]  \"\n","    f\"x∈[{span_x_min:.6g},{span_x_max:.6g}]  \"\n","    f\"y∈[{span_y_min:.6g},{span_y_max:.6g}]\"\n",")\n","\n","Bcheck = 512\n","# Dedicated, seeded RNG: drawing from the training sampler's rng_np made the check points\n","# depend on how many batches training had consumed and on how often this cell was re-run.\n","rng_span = np.random.default_rng(int(globals().get(\"span_seed\", 0)))\n","t_s = jnp.asarray(rng_span.uniform(span_t_min, span_t_max, size=(Bcheck,)), dtype=jnp.float64)\n","x_s = jnp.asarray(rng_span.uniform(span_x_min, span_x_max, size=(Bcheck,)), dtype=jnp.float64)\n","y_s = jnp.asarray(rng_span.uniform(span_y_min, span_y_max, size=(Bcheck,)), dtype=jnp.float64)\n","\n","tauL, xiL = eval_generators_tau_xi_jit(params_gen, t_s, x_s, y_s, activation=gen_cfg.activation)\n","V_learn = jnp.stack([tauL, xiL[...,0], xiL[...,1]], axis=-1)  # (m,B,3)\n","A_learn = _stack_columns(V_learn)                             # (3B,m)\n","\n","a2 = float(globals().get(\"truth\", {}).get(\"a2\", 0.5))\n","V_gt = _gt_generators_ex5(t_s, x_s, y_s, a2=a2)               # (3,B,3)\n","A_gt = _stack_columns(V_gt)                                   # (3B,3)\n","\n","angles, rL2G, rG2L, fit_rel = span_metrics(A_gt, A_learn)\n","\n","# Per-generator diagnostics: the Frobenius ratios above are dominated by the largest learned\n","# column, so a small-norm or nearly dependent generator can hide behind a tiny rL2G.\n","col_norm_learn = jnp.linalg.norm(A_learn, axis=0)  # (m,)\n","Qg_span, _ = jnp.linalg.qr(A_gt)\n","col_res_learn = jnp.linalg.norm(_orth_residual(Qg_span, A_learn), axis=0) / (col_norm_learn + 1e-12)\n","sv_learn_unit = jnp.linalg.svd(A_learn / (col_norm_learn[None, :] + 1e-12), compute_uv=False)\n","\n","print(\"\\n===== SPAN CHECK (learned vs GT span{v1,v3,v4}) =====\")\n","print(\"principal angles (rad):\", np.asarray(jax.device_get(angles)))\n","print(\"principal angles (deg):\", np.asarray(jax.device_get(angles))*180.0/np.pi)\n","print(f\"residual ||(I-P_gt)A_learn||/||A_learn||          = {float(rL2G):.3e}\")\n","print(f\"residual ||(I-P_learn)A_gt||/||A_gt||             = {float(rG2L):.3e}\")\n","print(f\"best-mixing fit ||A_learn - A_gt M||/||A_learn||  = {float(fit_rel):.3e}\")\n","print(\"per-generator column norms ||a_k||:\", np.asarray(jax.device_get(col_norm_learn)))\n","print(\"per-generator ||(I-P_gt)a_k||/||a_k||:\", np.asarray(jax.device_get(col_res_learn)))\n","print(\"singular values of column-normalised A_learn:\", np.asarray(jax.device_get(sv_learn_unit)))\n","\n","globals().update({\n","    \"span_angles_rad\": angles,\n","    \"span_res_learn_to_gt\": rL2G,\n","    \"span_res_gt_to_learn\": rG2L,\n","    \"span_bestmix_fit\": fit_rel,\n","    \"span_col_norms\": col_norm_learn,\n","    \"span_col_res_learn_to_gt\": col_res_learn,\n","    \"span_sv_learn_unit\": sv_learn_unit,\n","})\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"sAli9iUaNJLl","outputId":"c59392e6-b28d-424d-a298-2834e15fcd9f","executionInfo":{"status":"ok","timestamp":1769526571926,"user_tz":300,"elapsed":2026,"user":{"displayName":"Dacha Thurbur","userId":"16530582923668080051"}}},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","[SPAN CHECK DOMAIN] t∈[0,2]  x∈[0.01,7.43426]  y∈[0.1,4]\n","\n","===== SPAN CHECK (learned vs GT span{v1,v3,v4}) =====\n","principal angles (rad): [1.38736960e-04 1.47899113e-01 3.36709857e-01]\n","principal angles (deg): [7.94904229e-03 8.47399495e+00 1.92920538e+01]\n","residual ||(I-P_gt)A_learn||/||A_learn||          = 9.596e-04\n","residual ||(I-P_learn)A_gt||/||A_gt||             = 2.880e-01\n","best-mixing fit ||A_learn - A_gt M||/||A_learn||  = 9.596e-04\n"]}]}]}