{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"machine_shape":"hm","gpuType":"A100"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_uSr9AeUmETU","outputId":"eb367cc8-ee4d-4022-d759-c1cfddd0ec59"},"outputs":[{"output_type":"stream","name":"stdout","text":["=== SDE -> learned FP -> FP symmetries (variable coefficients) ===\n","JAX devices: [CudaDevice(id=0)]\n","SDE: T=2.0 dt=0.01 n_traj=512\n","Surrogate: hidden=128 depth=3 act=tanh\n","Symmetry: m=6 hidden=(128, 128, 128) act=swish\n","\n","Data-informed FP domain: t∈[0.001,2] x∈[-6.2,5.56]\n","\n","\n","=== Stage A: learn drift/diffusion from trajectory increments ===\n","[surrogate] step      1  loss=-1.738e+00  nll=-1.738e+00  grad_mu=9.72e-03  grad_sig=2.60e-05  amp_l1=5.50e-02\n","[surrogate] step    250  loss=-1.793e+00  nll=-1.793e+00  grad_mu=6.33e-04  grad_sig=3.09e-17  amp_l1=2.63e-01\n","[surrogate] step    500  loss=-1.808e+00  nll=-1.808e+00  grad_mu=3.62e-03  grad_sig=2.33e-17  amp_l1=2.59e-01\n","[surrogate] step    750  loss=-1.800e+00  nll=-1.800e+00  grad_mu=1.38e-03  grad_sig=9.90e-18  amp_l1=2.46e-01\n","[surrogate] step   1000  loss=-1.789e+00  nll=-1.790e+00  grad_mu=1.11e-02  grad_sig=1.51e-17  amp_l1=2.31e-01\n","[surrogate] step   1250  loss=-1.790e+00  nll=-1.791e+00  grad_mu=1.77e-02  grad_sig=1.12e-17  amp_l1=2.21e-01\n","[surrogate] step   1500  loss=-1.809e+00  nll=-1.809e+00  grad_mu=1.55e-02  grad_sig=1.79e-17  amp_l1=2.03e-01\n","[surrogate] step   1750  loss=-1.796e+00  nll=-1.796e+00  grad_mu=1.56e-02  grad_sig=1.06e-17  amp_l1=1.85e-01\n","[surrogate] step   2000  loss=-1.802e+00  nll=-1.803e+00  grad_mu=1.87e-02  grad_sig=8.49e-18  amp_l1=1.66e-01\n","[surrogate] step   2250  loss=-1.794e+00  nll=-1.795e+00  grad_mu=8.19e-03  grad_sig=2.89e-18  
amp_l1=1.44e-01\n","[surrogate] step   2500  loss=-1.788e+00  nll=-1.788e+00  grad_mu=4.69e-03  grad_sig=2.57e-18  amp_l1=1.20e-01\n","[surrogate] step   2750  loss=-1.794e+00  nll=-1.794e+00  grad_mu=1.94e-02  grad_sig=4.30e-18  amp_l1=9.30e-02\n","[surrogate] step   3000  loss=-1.796e+00  nll=-1.796e+00  grad_mu=1.49e-02  grad_sig=8.36e-19  amp_l1=6.09e-02\n","[surrogate] step   3250  loss=-1.802e+00  nll=-1.802e+00  grad_mu=8.38e-03  grad_sig=4.88e-19  amp_l1=3.89e-02\n","[surrogate] step   3500  loss=-1.792e+00  nll=-1.792e+00  grad_mu=6.82e-03  grad_sig=4.28e-21  amp_l1=2.69e-03\n","[surrogate] step   3750  loss=-1.798e+00  nll=-1.798e+00  grad_mu=1.45e-02  grad_sig=2.97e-24  amp_l1=1.18e-04\n","[surrogate] step   4000  loss=-1.796e+00  nll=-1.796e+00  grad_mu=7.10e-03  grad_sig=6.31e-22  amp_l1=1.97e-03\n","[surrogate] step   4250  loss=-1.800e+00  nll=-1.800e+00  grad_mu=1.58e-02  grad_sig=3.85e-21  amp_l1=4.08e-03\n","[surrogate] step   4500  loss=-1.807e+00  nll=-1.807e+00  grad_mu=6.65e-03  grad_sig=2.67e-21  amp_l1=2.43e-03\n","[surrogate] step   4750  loss=-1.805e+00  nll=-1.805e+00  grad_mu=9.24e-03  grad_sig=1.86e-21  amp_l1=2.55e-03\n","[surrogate] step   5000  loss=-1.799e+00  nll=-1.799e+00  grad_mu=1.32e-02  grad_sig=4.52e-22  amp_l1=1.20e-03\n","[surrogate] step   5250  loss=-1.785e+00  nll=-1.785e+00  grad_mu=1.44e-02  grad_sig=1.40e-21  amp_l1=2.46e-03\n","[surrogate] step   5500  loss=-1.810e+00  nll=-1.810e+00  grad_mu=1.56e-02  grad_sig=7.71e-21  amp_l1=4.91e-03\n","[surrogate] step   5750  loss=-1.795e+00  nll=-1.795e+00  grad_mu=6.90e-03  grad_sig=7.50e-22  amp_l1=1.79e-03\n","[surrogate] step   6000  loss=-1.789e+00  nll=-1.789e+00  grad_mu=1.19e-02  grad_sig=8.77e-22  amp_l1=1.89e-03\n","[surrogate] probe mu(t,x): [-0.07653213 -0.02448728  0.06887933]\n","[surrogate] probe sigma(t,x): [1.00784644 1.00784644 1.00784644]\n","\n","=== Stage B: learn FP symmetries from learned (mu,sigma) ===\n","[train] step      1  tot=8.155e-03  
S8=8.590e-04 (r1=5.35e-05, r2=1.56e-04, r3=6.50e-04)  S5=1.664e-01  S1=6.198e-06 (cl=1.14e-13, varC=6.20e-06)  S2=2.33e-11  S3=4.21e-39  S4=1.40e-37  S9=8.59e-04  logdet=4.630e+01  col=8.916e-02  origLie3=1.195e-01  min/max_col=4.55e-01/1.03e+00  t_max=0.6 ramp=0.00  (w_s2=1.3e-07, w_s3=6.7e-08, w_s4=6.7e-08, w_s9=6.7e-08, w_lie3=6.7e-07)\n","[train] step   1000  tot=1.530e-01  S8=3.348e-05 (r1=5.45e-07, r2=1.44e-05, r3=1.85e-05)  S5=1.664e-01  S1=2.996e-06 (cl=1.52e-14, varC=3.00e-06)  S2=6.57e-10  S3=2.56e-38  S4=6.61e-37  S9=3.35e-05  logdet=4.317e+01  col=1.872e-03  origLie3=3.231e-03  min/max_col=9.09e-01/1.00e+00  t_max=0.67 ramp=0.07  (w_s2=1.3e-04, w_s3=6.7e-05, w_s4=6.7e-05, w_s9=6.7e-05, w_lie3=6.7e-04)\n","[train] step   2000  tot=3.056e-01  S8=2.214e-05 (r1=2.14e-07, r2=8.44e-06, r3=1.35e-05)  S5=1.664e-01  S1=3.477e-06 (cl=1.51e-14, varC=3.48e-06)  S2=7.75e-10  S3=1.22e-38  S4=7.83e-37  S9=2.21e-05  logdet=4.317e+01  col=2.529e-04  origLie3=8.807e-04  min/max_col=9.76e-01/1.00e+00  t_max=0.74 ramp=0.13  (w_s2=2.7e-04, w_s3=1.3e-04, w_s4=1.3e-04, w_s9=1.3e-04, w_lie3=1.3e-03)\n","[train] step   3000  tot=4.584e-01  S8=2.274e-05 (r1=1.51e-07, r2=8.46e-06, r3=1.41e-05)  S5=1.664e-01  S1=3.471e-06 (cl=1.55e-14, varC=3.47e-06)  S2=8.48e-10  S3=8.89e-39  S4=8.42e-37  S9=2.27e-05  logdet=4.317e+01  col=1.566e-04  origLie3=7.342e-04  min/max_col=9.87e-01/1.02e+00  t_max=0.81 ramp=0.20  (w_s2=4.0e-04, w_s3=2.0e-04, w_s4=2.0e-04, w_s9=2.0e-04, w_lie3=2.0e-03)\n","[train] step   4000  tot=6.112e-01  S8=2.383e-05 (r1=1.83e-07, r2=9.21e-06, r3=1.44e-05)  S5=1.664e-01  S1=3.451e-06 (cl=1.57e-14, varC=3.45e-06)  S2=8.15e-10  S3=8.88e-39  S4=8.35e-37  S9=2.38e-05  logdet=4.317e+01  col=3.458e-04  origLie3=5.141e-04  min/max_col=9.64e-01/1.02e+00  t_max=0.88 ramp=0.27  (w_s2=5.3e-04, w_s3=2.7e-04, w_s4=2.7e-04, w_s9=2.7e-04, w_lie3=2.7e-03)\n","[train] step   5000  tot=7.640e-01  S8=2.522e-05 (r1=2.59e-07, r2=9.72e-06, r3=1.52e-05)  S5=1.664e-01  S1=3.403e-06 
(cl=1.51e-14, varC=3.40e-06)  S2=7.44e-10  S3=4.03e-38  S4=6.00e-37  S9=2.52e-05  logdet=4.317e+01  col=4.161e-04  origLie3=7.070e-04  min/max_col=9.59e-01/1.02e+00  t_max=0.95 ramp=0.33  (w_s2=6.7e-04, w_s3=3.3e-04, w_s4=3.3e-04, w_s9=3.3e-04, w_lie3=3.3e-03)\n","[eval:current]  resid ||V-WA||/||V|| = 7.126e-02\n","  angle  1: 0.002857 rad  = 0.1637 deg\n","  angle  2: 0.013287 rad  = 0.7613 deg\n","  angle  3: 0.026321 rad  = 1.5081 deg\n","  angle  4: 0.045593 rad  = 2.6123 deg\n","  angle  5: 0.112589 rad  = 6.4509 deg\n","  angle  6: 0.129549 rad  = 7.4226 deg\n","[eval:best_score@2387]  resid ||V-WA||/||V|| = 7.932e-02\n","  angle  1: 0.005287 rad  = 0.3029 deg\n","  angle  2: 0.013126 rad  = 0.7520 deg\n","  angle  3: 0.019042 rad  = 1.0910 deg\n","  angle  4: 0.031921 rad  = 1.8289 deg\n","  angle  5: 0.129565 rad  = 7.4235 deg\n","  angle  6: 0.146908 rad  = 8.4172 deg\n","  per-gen det mse: [1.199e-05 2.107e-05 2.230e-05 7.094e-06 5.271e-05 9.685e-05]\n","[train] step   6000  tot=9.169e-01  S8=2.346e-05 (r1=3.46e-07, r2=9.87e-06, r3=1.33e-05)  S5=1.664e-01  S1=3.300e-06 (cl=1.58e-14, varC=3.30e-06)  S2=7.81e-10  S3=1.71e-38  S4=5.72e-37  S9=2.35e-05  logdet=4.318e+01  col=6.085e-04  origLie3=4.779e-04  min/max_col=9.63e-01/1.03e+00  t_max=1.02 ramp=0.40  (w_s2=8.0e-04, w_s3=4.0e-04, w_s4=4.0e-04, w_s9=4.0e-04, w_lie3=4.0e-03)\n","[train] step   7000  tot=1.070e+00  S8=2.340e-05 (r1=1.41e-07, r2=1.05e-05, r3=1.28e-05)  S5=1.664e-01  S1=3.345e-06 (cl=1.56e-14, varC=3.34e-06)  S2=7.32e-10  S3=6.62e-39  S4=4.67e-37  S9=2.34e-05  logdet=4.317e+01  col=7.228e-04  origLie3=6.799e-04  min/max_col=9.59e-01/1.03e+00  t_max=1.09 ramp=0.47  (w_s2=9.3e-04, w_s3=4.7e-04, w_s4=4.7e-04, w_s9=4.7e-04, w_lie3=4.7e-03)\n","[train] step   8000  tot=1.222e+00  S8=2.512e-05 (r1=3.49e-07, r2=1.03e-05, r3=1.45e-05)  S5=1.664e-01  S1=3.316e-06 (cl=1.36e-14, varC=3.32e-06)  S2=6.39e-10  S3=1.11e-38  S4=3.74e-37  S9=2.51e-05  logdet=4.317e+01  col=1.214e-03  origLie3=5.381e-04  
min/max_col=9.59e-01/1.06e+00  t_max=1.16 ramp=0.53  (w_s2=1.1e-03, w_s3=5.3e-04, w_s4=5.3e-04, w_s9=5.3e-04, w_lie3=5.3e-03)\n","[train] step   9000  tot=1.375e+00  S8=2.735e-05 (r1=2.07e-07, r2=1.20e-05, r3=1.51e-05)  S5=1.664e-01  S1=3.126e-06 (cl=1.45e-14, varC=3.13e-06)  S2=6.78e-10  S3=1.11e-38  S4=3.87e-37  S9=2.73e-05  logdet=4.317e+01  col=7.072e-04  origLie3=4.522e-04  min/max_col=9.63e-01/1.03e+00  t_max=1.23 ramp=0.60  (w_s2=1.2e-03, w_s3=6.0e-04, w_s4=6.0e-04, w_s9=6.0e-04, w_lie3=6.0e-03)\n","[train] step  10000  tot=1.528e+00  S8=2.619e-05 (r1=1.60e-07, r2=1.17e-05, r3=1.43e-05)  S5=1.664e-01  S1=3.088e-06 (cl=1.41e-14, varC=3.09e-06)  S2=5.95e-10  S3=1.00e-38  S4=3.64e-37  S9=2.62e-05  logdet=4.317e+01  col=7.564e-04  origLie3=3.889e-04  min/max_col=9.50e-01/1.02e+00  t_max=1.3 ramp=0.67  (w_s2=1.3e-03, w_s3=6.7e-04, w_s4=6.7e-04, w_s9=6.7e-04, w_lie3=6.7e-03)\n","[eval:current]  resid ||V-WA||/||V|| = 6.240e-02\n","  angle  1: 0.002925 rad  = 0.1676 deg\n","  angle  2: 0.009835 rad  = 0.5635 deg\n","  angle  3: 0.013755 rad  = 0.7881 deg\n","  angle  4: 0.052293 rad  = 2.9962 deg\n","  angle  5: 0.102126 rad  = 5.8514 deg\n","  angle  6: 0.103361 rad  = 5.9221 deg\n","[eval:best_score@2387]  resid ||V-WA||/||V|| = 7.932e-02\n","  angle  1: 0.005287 rad  = 0.3029 deg\n","  angle  2: 0.013126 rad  = 0.7520 deg\n","  angle  3: 0.019042 rad  = 1.0910 deg\n","  angle  4: 0.031921 rad  = 1.8289 deg\n","  angle  5: 0.129565 rad  = 7.4235 deg\n","  angle  6: 0.146908 rad  = 8.4172 deg\n","  per-gen det mse: [1.044e-05 3.086e-05 4.275e-05 1.168e-05 3.503e-05 5.343e-05]\n","[train] step  11000  tot=1.681e+00  S8=2.790e-05 (r1=1.04e-07, r2=1.24e-05, r3=1.54e-05)  S5=1.664e-01  S1=3.299e-06 (cl=1.47e-14, varC=3.30e-06)  S2=7.18e-10  S3=2.20e-39  S4=3.49e-37  S9=2.79e-05  logdet=4.317e+01  col=1.953e-04  origLie3=5.251e-04  min/max_col=9.81e-01/1.03e+00  t_max=1.37 ramp=0.73  (w_s2=1.5e-03, w_s3=7.3e-04, w_s4=7.3e-04, w_s9=7.3e-04, w_lie3=7.3e-03)\n","[train] 
step  12000  tot=1.833e+00  S8=2.453e-05 (r1=7.78e-08, r2=1.26e-05, r3=1.19e-05)  S5=1.664e-01  S1=3.141e-06 (cl=1.48e-14, varC=3.14e-06)  S2=6.58e-10  S3=2.10e-39  S4=3.61e-37  S9=2.45e-05  logdet=4.317e+01  col=4.380e-04  origLie3=2.431e-04  min/max_col=9.52e-01/1.01e+00  t_max=1.44 ramp=0.80  (w_s2=1.6e-03, w_s3=8.0e-04, w_s4=8.0e-04, w_s9=8.0e-04, w_lie3=8.0e-03)\n","[train] step  13000  tot=1.986e+00  S8=3.008e-05 (r1=9.42e-08, r2=1.38e-05, r3=1.62e-05)  S5=1.664e-01  S1=3.141e-06 (cl=1.37e-14, varC=3.14e-06)  S2=6.07e-10  S3=6.04e-39  S4=3.72e-37  S9=3.01e-05  logdet=4.317e+01  col=3.253e-04  origLie3=5.475e-04  min/max_col=9.98e-01/1.04e+00  t_max=1.51 ramp=0.87  (w_s2=1.7e-03, w_s3=8.7e-04, w_s4=8.7e-04, w_s9=8.7e-04, w_lie3=8.7e-03)\n","[train] step  14000  tot=2.139e+00  S8=2.870e-05 (r1=2.58e-07, r2=1.35e-05, r3=1.50e-05)  S5=1.664e-01  S1=2.959e-06 (cl=1.25e-14, varC=2.96e-06)  S2=4.79e-10  S3=2.45e-40  S4=3.38e-37  S9=2.87e-05  logdet=4.317e+01  col=1.516e-03  origLie3=4.407e-04  min/max_col=9.36e-01/1.04e+00  t_max=1.58 ramp=0.93  (w_s2=1.9e-03, w_s3=9.3e-04, w_s4=9.3e-04, w_s9=9.3e-04, w_lie3=9.3e-03)\n","[train] step  15000  tot=2.292e+00  S8=2.563e-05 (r1=1.69e-07, r2=1.25e-05, r3=1.29e-05)  S5=1.664e-01  S1=3.021e-06 (cl=1.34e-14, varC=3.02e-06)  S2=5.21e-10  S3=2.09e-39  S4=3.34e-37  S9=2.56e-05  logdet=4.317e+01  col=8.066e-04  origLie3=3.420e-04  min/max_col=9.61e-01/1.04e+00  t_max=1.65 ramp=1.00  (w_s2=2.0e-03, w_s3=1.0e-03, w_s4=1.0e-03, w_s9=1.0e-03, w_lie3=1.0e-02)\n","[eval:current]  resid ||V-WA||/||V|| = 5.927e-02\n","  angle  1: 0.002131 rad  = 0.1221 deg\n","  angle  2: 0.003980 rad  = 0.2280 deg\n","  angle  3: 0.017119 rad  = 0.9809 deg\n","  angle  4: 0.047088 rad  = 2.6980 deg\n","  angle  5: 0.091331 rad  = 5.2329 deg\n","  angle  6: 0.101152 rad  = 5.7956 deg\n","[eval:best_score@2387]  resid ||V-WA||/||V|| = 7.932e-02\n","  angle  1: 0.005287 rad  = 0.3029 deg\n","  angle  2: 0.013126 rad  = 0.7520 deg\n","  angle  3: 0.019042 
rad  = 1.0910 deg\n","  angle  4: 0.031921 rad  = 1.8289 deg\n","  angle  5: 0.129565 rad  = 7.4235 deg\n","  angle  6: 0.146908 rad  = 8.4172 deg\n","  per-gen det mse: [1.115e-05 3.235e-05 6.195e-05 9.066e-06 3.395e-05 2.837e-05]\n","[train] step  16000  tot=2.292e+00  S8=2.814e-05 (r1=1.17e-07, r2=1.39e-05, r3=1.42e-05)  S5=1.664e-01  S1=3.100e-06 (cl=1.33e-14, varC=3.10e-06)  S2=5.57e-10  S3=7.89e-39  S4=3.34e-37  S9=2.81e-05  logdet=4.317e+01  col=1.015e-03  origLie3=4.063e-04  min/max_col=9.67e-01/1.05e+00  t_max=1.72 ramp=1.00  (w_s2=2.0e-03, w_s3=1.0e-03, w_s4=1.0e-03, w_s9=1.0e-03, w_lie3=1.0e-02)\n","[train] step  17000  tot=2.292e+00  S8=3.249e-05 (r1=1.01e-07, r2=1.75e-05, r3=1.49e-05)  S5=1.664e-01  S1=2.947e-06 (cl=1.29e-14, varC=2.95e-06)  S2=5.19e-10  S3=3.24e-39  S4=3.17e-37  S9=3.25e-05  logdet=4.317e+01  col=7.834e-04  origLie3=3.641e-04  min/max_col=9.73e-01/1.05e+00  t_max=1.79 ramp=1.00  (w_s2=2.0e-03, w_s3=1.0e-03, w_s4=1.0e-03, w_s9=1.0e-03, w_lie3=1.0e-02)\n","[train] step  18000  tot=2.292e+00  S8=3.063e-05 (r1=8.76e-08, r2=1.43e-05, r3=1.62e-05)  S5=1.664e-01  S1=3.012e-06 (cl=1.29e-14, varC=3.01e-06)  S2=5.02e-10  S3=9.00e-39  S4=3.24e-37  S9=3.06e-05  logdet=4.317e+01  col=4.016e-04  origLie3=3.056e-04  min/max_col=9.66e-01/1.03e+00  t_max=1.86 ramp=1.00  (w_s2=2.0e-03, w_s3=1.0e-03, w_s4=1.0e-03, w_s9=1.0e-03, w_lie3=1.0e-02)\n","[train] step  19000  tot=2.291e+00  S8=3.256e-05 (r1=1.80e-07, r2=1.65e-05, r3=1.59e-05)  S5=1.664e-01  S1=2.772e-06 (cl=1.24e-14, varC=2.77e-06)  S2=4.41e-10  S3=3.41e-39  S4=3.43e-37  S9=3.26e-05  logdet=4.317e+01  col=3.420e-04  origLie3=2.943e-04  min/max_col=9.75e-01/1.02e+00  t_max=1.93 ramp=1.00  (w_s2=2.0e-03, w_s3=1.0e-03, w_s4=1.0e-03, w_s9=1.0e-03, w_lie3=1.0e-02)\n"]}],"source":["\"\"\"\n","End-to-end demo (fully general 1D coefficients):\n","  SDE trajectories -> learn drift/diffusion (mu,sigma) -> build variable-coefficient FP operator\n","  -> learn Lie point symmetries of that FP 
operator.\n",
"\n",
"Key point:\n",
"- We DO NOT assume constant diffusion in the symmetry learner.\n",
"- The FP determining equations are the *general* (variable-coefficient) projectable FP\n",
"  determining equations specialized to 1D, following Gaeta–Quintero.\n",
"\n",
"Reference (Gaeta–Quintero; equation numbers below refer to that paper):\n",
"- Rewrite FP as  u_t + A u_xx + B u_x + C u = 0  with\n",
"    A(t,x) = -1/2 (σ σ^T),\n",
"    B(t,x) = f - ∂_x(σ σ^T),\n",
"    C(t,x) = (∂_x f) - 1/2 ∂_{xx}(σ σ^T).     (Eq. 4.1–4.2)\n",
"  where f is the SDE drift (mu) and σ the SDE diffusion (sigma).\n",
"\n",
"- For projectable symmetries with φ = α(t,x) + β(t,x) u, the (τ,ξ,β) satisfy the determining\n",
"  equations (Eq. 4.12).\n",
"\n",
"In 1D (x only), these become (using A,B,C above and τ=τ(t), ξ=ξ(t,x), β=β(t,x)):\n",
"\n",
"  r1:  ∂_t(τ A) + ξ A_x - 2 A ξ_x = 0\n",
"  r2:  ∂_t(τ B) - ξ_t + B ξ_x - ξ B_x + 2 A β_x - A ξ_xx = 0\n",
"  r3:  ∂_t(τ C) + β_t + A β_xx + B β_x + ξ C_x = 0\n",
"\n",
"We learn m generators jointly with the same loss suite / curriculum / ramping as the companion\n",
"heat-equation script (fp_symmetry_learn_heat1d_joint_v4.py), but with the heat-specific\n",
"determining equations replaced by r1, r2, r3.\n",
"\n",
"Pipeline steps\n",
"--------------\n",
"1) Simulate data for an SDE (Euler–Maruyama). 
For the demo we simulate\n","      dX = sigma0 dW\n","   but the learning code treats mu(t,x), sigma(t,x) as fully general.\n","\n","2) Learn mu(t,x) and sigma(t,x) from increments via a Gaussian conditional likelihood:\n","      ΔX | (t,x) ~ Normal(mu(t,x) Δt,  sigma(t,x)^2 Δt)\n","   with subtle biases toward \"simple\" functions:\n","     - linear skip in mu (helps represent affine drift with correct derivatives),\n","     - diffusion parameterization with (constant base + global amplitude * residual),\n","       and L1 on the amplitude so it collapses to constant if supported by data,\n","       *without hard-coding* constant diffusion,\n","     - small derivative penalties on mu and sigma to discourage spurious roughness.\n","\n","3) Freeze learned mu,sigma, construct FP coefficients A,B,C using (4.2) and autodiff.\n","\n","4) Learn FP symmetries by minimizing:\n","     - determining loss: mean over points of sum_i (r1_i^2 + r2_i^2 + r3_i^2)\n","     - independence loss: ||G - I||_F^2 where G is Gram of stacked generator columns\n","     - logdet barrier: -log det(G + eps I) (encourage full-rank span)\n","     - Lie closure loss: commutators [v_i, v_j] lie in span{v_k} (least squares residual)\n","     - column norm loss: keep each generator non-trivial (norm around a target)\n","\n","Run\n","---\n","  python sde_fp_symmetry_from_data_varcoeff_v1.py\n","\n","Tips\n","----\n","- For a quick smoke test, reduce:\n","    SurrogateTrainConfig.steps, TrainConfigFP.steps\n","- If symmetry learning is unstable, try:\n","    smaller lr, larger batch, or increasing derivative penalty in the surrogate.\n","\n","\"\"\"\n","\n","from __future__ import annotations\n","\n","import math\n","from dataclasses import dataclass\n","from typing import Callable, Dict, Tuple\n","\n","import numpy as np\n","import jax\n","import jax.numpy as jnp\n","import optax\n","\n","jax.config.update(\"jax_enable_x64\", True)\n","\n","DTYPE = jnp.float64\n","Array = jnp.ndarray\n","\n","\n","# 
=============================================================================\n",
"# 1) SDE simulation (data only)\n",
"# =============================================================================\n",
"\n",
"@dataclass(frozen=True)\n",
"class SDEConfig:\n",
"    \"\"\"Demo SDE simulation settings (Euler–Maruyama on [0, T]).\"\"\"\n",
"\n",
"    T: float = 2.0        # time horizon\n",
"    dt: float = 0.01      # Euler–Maruyama step size\n",
"    n_traj: int = 512     # number of independent trajectories\n",
"    x0_low: float = -0.5  # X_0 ~ Uniform(x0_low, x0_high)\n",
"    x0_high: float = 0.5\n",
"\n",
"    # simulation-only (unknown in real use)\n",
"    sigma0: float = 1.0\n",
"\n",
"\n",
"def simulate_traj_and_increments(key: Array, cfg: SDEConfig) -> Tuple[Array, Array, Array]:\n",
"    \"\"\"Euler–Maruyama simulation for the demo SDE: dX = sigma0 dW.\n",
"\n",
"    The number of steps is N = round(T / dt). Rounding (instead of truncating) keeps\n",
"    the last step when T / dt is not exactly representable, e.g. 0.3 / 0.1 evaluates\n",
"    to 2.9999999999999996, which int() would truncate to 2. The time grid is\n",
"    t_n = n * dt, consistent with the sqrt(dt) scaling of the Brownian increments.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    key : PRNG key (split into initial-condition and Brownian-increment keys).\n",
"    cfg : simulation settings.\n",
"\n",
"    Returns\n",
"    -------\n",
"    TX_all : (n_traj*(N+1), 2)  rows are [t, x]\n",
"    Z      : (n_traj*N, 2)      conditioning states [t_n, x_n]\n",
"    dX     : (n_traj*N, 1)      increments X_{n+1} - X_n\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError : if dt is not positive or T / dt rounds to zero steps.\n",
"    \"\"\"\n",
"    dt = float(cfg.dt)\n",
"    if not dt > 0.0:\n",
"        raise ValueError(f\"SDEConfig.dt must be positive, got {cfg.dt}\")\n",
"    N = int(round(cfg.T / dt))\n",
"    if N < 1:\n",
"        raise ValueError(f\"SDEConfig.T={cfg.T} with dt={cfg.dt} gives no time steps\")\n",
"    # Same as linspace(0, T, N+1) whenever T is an exact multiple of dt.\n",
"    t_grid = jnp.linspace(0.0, N * dt, N + 1, dtype=DTYPE)\n",
"\n",
"    k_init, k_noise = jax.random.split(key, 2)\n",
"    x0 = jax.random.uniform(\n",
"        k_init, (cfg.n_traj,), minval=cfg.x0_low, maxval=cfg.x0_high, dtype=DTYPE\n",
"    )\n",
"    dW = jax.random.normal(k_noise, (cfg.n_traj, N), dtype=DTYPE) * math.sqrt(dt)\n",
"    sigma0 = DTYPE(cfg.sigma0)\n",
"\n",
"    def em_step(x, dWn):\n",
"        # One Euler–Maruyama step; scan carries x_next and also stacks it as output.\n",
"        x_next = x + sigma0 * dWn\n",
"        return x_next, x_next\n",
"\n",
"    _, xs = jax.lax.scan(em_step, x0, dW.T)  # scan over time: dW.T is (N, n_traj)\n",
"    xs = jnp.concatenate([x0[None, :], xs], axis=0)  # (N+1, n_traj)\n",
"\n",
"    TT = jnp.broadcast_to(t_grid[:, None], xs.shape)\n",
"    TX_all = jnp.stack([TT, xs], axis=-1).reshape(-1, 2)\n",
"\n",
"    x_n = xs[:-1, :]\n",
"    x_np1 = xs[1:, :]\n",
"    t_n = jnp.broadcast_to(t_grid[:-1, None], x_n.shape)\n",
"\n",
"    Z = jnp.stack([t_n, x_n], axis=-1).reshape(-1, 2)\n",
"    dX = (x_np1 - x_n).reshape(-1, 1)\n",
"    return TX_all, Z, dX\n",
"\n",
"\n",
"# =============================================================================\n",
"# 2) Learn drift/diffusion 
surrogates from increments\n","# =============================================================================\n","\n","@dataclass(frozen=True)\n","class SurrogateNetConfig:\n","    hidden: int = 128\n","    depth: int = 3\n","    activation: str = \"tanh\"  # tanh|relu|gelu\n","\n","    # subtle bias: make affine drifts easy (and derivatives correct)\n","    use_linear_skip: bool = True\n","    lin_init_scale: float = 1e-2\n","\n","    sigma_min: float = 1e-6\n","\n","    # subtle bias: prefer (nearly) constant diffusion while still allowing general σ(t,x)\n","    use_amp_residual_sigma: bool = True\n","    res_scale: float = 1.0\n","    amp_init: float = 0.05\n","\n","\n","@dataclass(frozen=True)\n","class SurrogateTrainConfig:\n","    steps: int = 6000\n","    batch_size: int = 8192\n","    lr: float = 3e-3\n","    weight_decay: float = 1e-6\n","    clip_norm: float = 1.0\n","    print_every: int = 250\n","    seed: int = 0\n","\n","    # regularizers (keep small; they bias toward smooth/simple coefficients)\n","    w_grad_mu: float = 5e-5\n","    w_grad_sigma: float = 5e-5\n","    w_amp_l1: float = 1e-3\n","\n","\n","def _act(x: Array, name: str) -> Array:\n","    if name == \"tanh\":\n","        return jnp.tanh(x)\n","    if name == \"relu\":\n","        return jax.nn.relu(x)\n","    if name == \"gelu\":\n","        return jax.nn.gelu(x)\n","    raise ValueError(f\"unknown activation: {name}\")\n","\n","\n","def _mlp_init(key, in_dim: int, out_dim: int, hidden: int, depth: int) -> Dict[str, Array]:\n","    keys = jax.random.split(key, depth + 1)\n","    dims = [in_dim] + [hidden] * depth + [out_dim]\n","    params = {\"W\": [], \"b\": []}\n","    for i in range(len(dims) - 1):\n","        fan_in, fan_out = dims[i], dims[i + 1]\n","        lim = math.sqrt(6.0 / float(fan_in + fan_out))\n","        W = jax.random.uniform(keys[i], (fan_in, fan_out), minval=-lim, maxval=lim, dtype=DTYPE)\n","        b = jnp.zeros((fan_out,), dtype=DTYPE)\n","        
params[\"W\"].append(W)\n","        params[\"b\"].append(b)\n","    return params\n","\n","\n","def _mlp_apply(params: Dict[str, Array], x: Array, activation: str) -> Array:\n","    h = x\n","    for i in range(len(params[\"W\"]) - 1):\n","        h = h @ params[\"W\"][i] + params[\"b\"][i]\n","        h = _act(h, activation)\n","    return h @ params[\"W\"][-1] + params[\"b\"][-1]\n","\n","\n","def _normalize_z(Z: Array) -> Tuple[Array, Dict[str, Array]]:\n","    \"\"\"Normalize (t,x) to roughly [-1,1] with data-driven stats.\"\"\"\n","    mean = jnp.mean(Z, axis=0)\n","    std = jnp.std(Z, axis=0) + DTYPE(1e-8)\n","    Zs = (Z - mean) / std\n","    return Zs, {\"mean\": mean, \"std\": std}\n","\n","\n","def _denorm_z(zs: Array, stats: Dict[str, Array]) -> Array:\n","    return zs * stats[\"std\"] + stats[\"mean\"]\n","\n","\n","def init_surrogate_params(key: Array, net: SurrogateNetConfig) -> Dict[str, Dict]:\n","    k_mu, k_sig, k_lin, k_base, k_amp = jax.random.split(key, 5)\n","    params = {\n","        \"mu_mlp\": _mlp_init(k_mu, 2, 1, net.hidden, net.depth),\n","        \"sig_mlp\": _mlp_init(k_sig, 2, 1, net.hidden, net.depth),\n","    }\n","    if net.use_linear_skip:\n","        W = net.lin_init_scale * jax.random.normal(k_lin, (2, 1), dtype=DTYPE)\n","        b = jnp.zeros((1,), dtype=DTYPE)\n","        params[\"mu_lin\"] = {\"W\": W, \"b\": b}\n","    else:\n","        params[\"mu_lin\"] = None\n","\n","    if net.use_amp_residual_sigma:\n","        params[\"sig_base\"] = jnp.asarray([1.0], dtype=DTYPE) + 0.01 * jax.random.normal(k_base, (1,), dtype=DTYPE)\n","        params[\"sig_amp\"] = jnp.asarray([net.amp_init], dtype=DTYPE) + 0.01 * jax.random.normal(k_amp, (1,), dtype=DTYPE)\n","    else:\n","        params[\"sig_base\"] = None\n","        params[\"sig_amp\"] = None\n","\n","    return params\n","\n","\n","def surrogate_mu_sigma(\n","    params: Dict[str, Dict],\n","    z_norm: Array,\n","    net: SurrogateNetConfig,\n",") -> Tuple[Array, 
Array]:\n","    \"\"\"Return (mu, sigma) at normalized inputs z_norm (shape (...,2)).\"\"\"\n","    mu = _mlp_apply(params[\"mu_mlp\"], z_norm, net.activation)\n","    if net.use_linear_skip and params[\"mu_lin\"] is not None:\n","        mu = mu + (z_norm @ params[\"mu_lin\"][\"W\"] + params[\"mu_lin\"][\"b\"])\n","\n","    sig_raw = _mlp_apply(params[\"sig_mlp\"], z_norm, net.activation)  # residual\n","    if net.use_amp_residual_sigma and (params[\"sig_base\"] is not None) and (params[\"sig_amp\"] is not None):\n","        # global amplitude gate -> discourages *derivative leakage* when collapsing to constant\n","        sig = params[\"sig_base\"] + params[\"sig_amp\"] * jnp.tanh(sig_raw / DTYPE(net.res_scale)) * DTYPE(net.res_scale)\n","    else:\n","        sig = sig_raw\n","\n","    # enforce positivity (and avoid degenerate sigma)\n","    sig = jax.nn.softplus(sig) + DTYPE(net.sigma_min)\n","    return mu, sig\n","\n","\n","def train_surrogate(\n","    key: Array,\n","    Z: Array,\n","    dX: Array,\n","    sde_cfg: SDEConfig,\n","    net: SurrogateNetConfig,\n","    tr: SurrogateTrainConfig,\n",") -> Tuple[Dict, Dict[str, Array]]:\n","    \"\"\"Learn mu(t,x) and sigma(t,x) from increments.\"\"\"\n","    dt = DTYPE(sde_cfg.dt)\n","\n","    Z_norm, stats = _normalize_z(Z)\n","\n","    params = init_surrogate_params(key, net)\n","\n","    opt = optax.chain(\n","        optax.clip_by_global_norm(tr.clip_norm),\n","        optax.adamw(tr.lr, weight_decay=tr.weight_decay),\n","    )\n","    opt_state = opt.init(params)\n","\n","    n = Z.shape[0]\n","    key = jax.random.PRNGKey(tr.seed)\n","\n","    def mu_sigma_at(params_local, z_norm_batch):\n","        return surrogate_mu_sigma(params_local, z_norm_batch, net)\n","\n","    def nll_loss(params_local, z_norm_batch, dX_batch):\n","        mu, sig = mu_sigma_at(params_local, z_norm_batch)  # (B,1),(B,1)\n","        # Gaussian conditional likelihood for increments\n","        mean = mu * dt\n","        var = 
(sig * sig) * dt + DTYPE(1e-12)\n","        resid = dX_batch - mean\n","        nll = 0.5 * (resid * resid) / var + 0.5 * jnp.log(var)\n","        return jnp.mean(nll)\n","\n","    # derivative penalties (encourage smoothness / simplicity)\n","    def grad_penalties(params_local, z_norm_batch):\n","        def mu_scalar(z):\n","            return mu_sigma_at(params_local, z[None, :])[0][0, 0]\n","        def sig_scalar(z):\n","            return mu_sigma_at(params_local, z[None, :])[1][0, 0]\n","\n","        g_mu = jax.vmap(jax.grad(mu_scalar))(z_norm_batch)      # (B,2)\n","        g_sig = jax.vmap(jax.grad(sig_scalar))(z_norm_batch)    # (B,2)\n","        return jnp.mean(g_mu * g_mu), jnp.mean(g_sig * g_sig)\n","\n","    def amp_l1(params_local):\n","        if net.use_amp_residual_sigma and (params_local[\"sig_amp\"] is not None):\n","            return jnp.mean(jnp.abs(params_local[\"sig_amp\"]))\n","        return DTYPE(0.0)\n","\n","    @jax.jit\n","    def step(params_local, opt_state_local, key_local):\n","        key_local, k_idx = jax.random.split(key_local, 2)\n","        idx = jax.random.randint(k_idx, (tr.batch_size,), 0, n)\n","        z_b = Z_norm[idx]\n","        dx_b = dX[idx]\n","\n","        def total(p):\n","            base = nll_loss(p, z_b, dx_b)\n","            gmu, gsig = grad_penalties(p, z_b)\n","            reg = DTYPE(tr.w_grad_mu) * gmu + DTYPE(tr.w_grad_sigma) * gsig + DTYPE(tr.w_amp_l1) * amp_l1(p)\n","            return base + reg, (base, gmu, gsig, amp_l1(p))\n","\n","        (loss, (base, gmu, gsig, a1)), grads = jax.value_and_grad(total, has_aux=True)(params_local)\n","        updates, opt_state_local = opt.update(grads, opt_state_local, params_local)\n","        params_local = optax.apply_updates(params_local, updates)\n","\n","        aux = {\"loss\": loss, \"nll\": base, \"grad_mu\": gmu, \"grad_sig\": gsig, \"amp_l1\": a1}\n","        return params_local, opt_state_local, key_local, aux\n","\n","    print(\"\\n=== Stage A: 
learn drift/diffusion from trajectory increments ===\")\n","    for it in range(1, tr.steps + 1):\n","        params, opt_state, key, aux = step(params, opt_state, key)\n","        if (it % tr.print_every) == 0 or it == 1:\n","            print(\n","                f\"[surrogate] step {it:6d}  loss={float(aux['loss']):.3e}  \"\n","                f\"nll={float(aux['nll']):.3e}  grad_mu={float(aux['grad_mu']):.2e}  \"\n","                f\"grad_sig={float(aux['grad_sig']):.2e}  amp_l1={float(aux['amp_l1']):.2e}\"\n","            )\n","\n","    return params, stats\n","\n","\n","# =============================================================================\n","# 3) Symmetry learning for variable-coefficient FP (Gaeta–Quintero)\n","# =============================================================================\n","\n","@dataclass(frozen=True)\n","class DomainFP:\n","    t_min: float\n","    t_max: float\n","    x_min: float\n","    x_max: float\n","\n","\n","@dataclass(frozen=True)\n","class TrainConfigFP:\n","    seed: int = 0\n","    m: int = 6  # how many generators to learn jointly\n","\n","    # network\n","    hidden: Tuple[int, ...] 
= (128, 128, 128)\n","    activation: str = \"swish\"  # swish|tanh|relu\n","\n","    # training\n","    steps: int = 20000\n","    batch: int = 512\n","    lr: float = 3e-4\n","    clip_norm: float = 1.0\n","\n","    # curriculum on t_max\n","    use_curriculum: bool = True\n","    t_max_start: float = 0.6\n","    t_max_end: float = 2.0\n","    curriculum_steps: int = 20000\n","\n","    # losses (same suite as your heat scripts)\n","    w_det: float = 1.0\n","    w_ind: float = 0.8\n","    w_lie: float = 0.05\n","    w_col: float = 0.08\n","    w_logdet: float = 0.05\n","    ramp_steps: int = 15000\n","\n","    # additional paper-aligned structure losses (kept small by default)\n","    w_s2_jacobi: float = 2e-3\n","    w_s3_skewsym: float = 1e-3\n","    w_s4_bilinearity: float = 1e-3\n","    w_s9_after_flow: float = 1e-3\n","\n","    # keep the original (tau,xi,beta) closure loss as a stabilizer (small weight)\n","    w_lie_full3: float = 1e-2\n","\n","    # lie closure\n","    lie_pairs: int = 12\n","    ridge: float = 1e-6\n","\n","    # diversity numerics\n","    min_col_norm: float = 1e-3\n","    gram_eps: float = 1e-4\n","\n","    # logging / eval\n","    log_every: int = 1000\n","    eval_every: int = 5000\n","    eval_grid_t: int = 12\n","    eval_grid_x: int = 64\n","\n","\n","def act_fp(x: Array, name: str) -> Array:\n","    if name == \"swish\":\n","        return x * jax.nn.sigmoid(x)\n","    if name == \"tanh\":\n","        return jnp.tanh(x)\n","    if name == \"relu\":\n","        return jax.nn.relu(x)\n","    raise ValueError(f\"Unknown activation: {name}\")\n","\n","\n","def mlp_init_fp(key, in_dim: int, out_dim: int, hidden: Tuple[int, ...]) -> Dict:\n","    keys = jax.random.split(key, num=len(hidden) + 1)\n","    dims = (in_dim,) + tuple(hidden) + (out_dim,)\n","    params = {\"W\": [], \"b\": []}\n","    for i in range(len(dims) - 1):\n","        k = keys[i]\n","        fan_in, fan_out = dims[i], dims[i + 1]\n","        lim = math.sqrt(6.0 / 
float(fan_in + fan_out))\n","        W = jax.random.uniform(k, (fan_in, fan_out), minval=-lim, maxval=lim, dtype=DTYPE)\n","        b = jnp.zeros((fan_out,), dtype=DTYPE)\n","        params[\"W\"].append(W)\n","        params[\"b\"].append(b)\n","    return params\n","\n","\n","def mlp_apply_fp(params: Dict, x: Array, activation: str) -> Array:\n","    h = x\n","    for i in range(len(params[\"W\"]) - 1):\n","        h = h @ params[\"W\"][i] + params[\"b\"][i]\n","        h = act_fp(h, activation)\n","    return h @ params[\"W\"][-1] + params[\"b\"][-1]\n","\n","\n","def normalize_t(dom: DomainFP, t: Array) -> Array:\n","    tn = (t - DTYPE(dom.t_min)) / (DTYPE(dom.t_max) - DTYPE(dom.t_min) + DTYPE(1e-12))\n","    return DTYPE(2.0) * tn - DTYPE(1.0)\n","\n","\n","def normalize_tx(dom: DomainFP, tx: Array) -> Array:\n","    t, x = tx[..., 0], tx[..., 1]\n","    tn = (t - DTYPE(dom.t_min)) / (DTYPE(dom.t_max) - DTYPE(dom.t_min) + DTYPE(1e-12))\n","    xn = (x - DTYPE(dom.x_min)) / (DTYPE(dom.x_max) - DTYPE(dom.x_min) + DTYPE(1e-12))\n","    return jnp.stack([DTYPE(2.0) * tn - DTYPE(1.0), DTYPE(2.0) * xn - DTYPE(1.0)], axis=-1)\n","\n","\n","def sample_batch_fp(key, dom: DomainFP, batch: int, t_max: Array) -> Array:\n","    k1, k2 = jax.random.split(key, 2)\n","    t_hi = jnp.maximum(DTYPE(dom.t_min) + DTYPE(1e-6), t_max)\n","    t = jax.random.uniform(k1, (batch,), minval=DTYPE(dom.t_min), maxval=t_hi, dtype=DTYPE)\n","    x = jax.random.uniform(k2, (batch,), minval=DTYPE(dom.x_min), maxval=DTYPE(dom.x_max), dtype=DTYPE)\n","    return jnp.stack([t, x], axis=-1)\n","\n","\n","def orthonormal_basis(M: Array, eps: float = 1e-8) -> Array:\n","    Q, R = jnp.linalg.qr(M, mode=\"reduced\")\n","    diag = jnp.abs(jnp.diag(R))\n","    mx = jnp.max(diag + eps)\n","    keep = diag > eps * mx\n","    keep = jnp.where(\n","        jnp.any(keep),\n","        keep,\n","        jnp.concatenate([jnp.array([True]), jnp.zeros((keep.shape[0] - 1,), dtype=bool)]),\n","    )\n","    
return Q[:, keep]\n","\n","\n","# Principal angles (radians, ascending) between span(V) and span(W), from the singular values of\n","# Qv^T Qw clipped to [0, 1]. The count equals the smaller of the two retained basis widths.\n","def principal_angles(V: Array, W: Array) -> Array:\n","    Qv = orthonormal_basis(V)\n","    Qw = orthonormal_basis(W)\n","    s = jnp.linalg.svd(Qv.T @ Qw, compute_uv=False)\n","    s = jnp.clip(s, 0.0, 1.0)\n","    return jnp.sort(jnp.arccos(s))\n","\n","\n","# Relative residual ||V - W A|| / ||V|| of the least-squares fit V ~ W A (ridge 1e-8 on W^T W),\n","# i.e. how much of V lies outside span(W). Returned as a host float.\n","def best_mixing_residual(V: Array, W: Array) -> float:\n","    WT_W = W.T @ W + DTYPE(1e-8) * jnp.eye(W.shape[1], dtype=DTYPE)\n","    A = jnp.linalg.solve(WT_W, W.T @ V)\n","    resid = jnp.linalg.norm(V - W @ A) / (jnp.linalg.norm(V) + DTYPE(1e-12))\n","    return float(resid)\n","\n","\n","# Independent MLPs for the m generator components: tau(t), xi(t, x) and beta(t, x).\n","def init_params_fp(key, cfg: TrainConfigFP) -> Dict:\n","    k1, k2, k3 = jax.random.split(key, 3)\n","    return {\n","        \"tau\": mlp_init_fp(k1, in_dim=1, out_dim=cfg.m, hidden=cfg.hidden),\n","        \"xi\": mlp_init_fp(k2, in_dim=2, out_dim=1 * cfg.m, hidden=cfg.hidden),\n","        \"beta\": mlp_init_fp(k3, in_dim=2, out_dim=cfg.m, hidden=cfg.hidden),\n","    }\n","\n","\n","# Evaluate all m generators at points tx (last axis = (t, x), raw coordinates); inputs are\n","# normalised to [-1, 1] here. Returns tau (..., m), xi (..., m, 1), beta (..., m).\n","def forward_fp(params: Dict, dom: DomainFP, cfg: TrainConfigFP, tx: Array):\n","    t = tx[..., 0]\n","    tn = normalize_t(dom, t)[..., None]\n","    z = normalize_tx(dom, tx)\n","    tau = mlp_apply_fp(params[\"tau\"], tn, cfg.activation)  # (...,m)\n","    xi_flat = mlp_apply_fp(params[\"xi\"], z, cfg.activation)\n","    xi = xi_flat.reshape(*xi_flat.shape[:-1], cfg.m, 1)  # (...,m,1)\n","    beta = mlp_apply_fp(params[\"beta\"], z, cfg.activation)  # (...,m)\n","    return tau, xi, beta\n","\n","\n","def t_max_curriculum_jax(cfg: TrainConfigFP, dom: DomainFP, step: Array) -> Array:\n","    \"\"\"Upper bound on sampled t at a given training step (traceable under jax.jit).\n","\n","    Ramps linearly from cfg.t_max_start to cfg.t_max_end over cfg.curriculum_steps and then\n","    holds; with curriculum_steps <= 0 the ramp is skipped (t_max_end from step 1 on). The\n","    result is capped at dom.t_max so collocation times never leave the FP domain, even when\n","    t_max_end is configured above the data-informed horizon.\n","    \"\"\"\n","    if not cfg.use_curriculum:\n","        return DTYPE(dom.t_max)\n","    # A clipped ratio instead of min(step, steps) / max(steps, 1): the latter stays at 0 forever\n","    # when curriculum_steps == 0. For curriculum_steps > 0 and step >= 0 the two agree.\n","    frac = jnp.clip(DTYPE(step) / DTYPE(max(cfg.curriculum_steps, 1)), DTYPE(0.0), DTYPE(1.0))\n","    t_cur = DTYPE(cfg.t_max_start) + frac * DTYPE(cfg.t_max_end - cfg.t_max_start)\n","    return jnp.minimum(t_cur, DTYPE(dom.t_max))\n","\n","\n","def make_losses_varcoeff(\n","    dom: DomainFP,\n","    cfg: TrainConfigFP,\n","    mu_fn: Callable[[Array, Array], Array],\n","    
sig_fn: Callable[[Array, Array], Array],\n","):\n","    \"\"\"\n","    Symmetry-learning loss suite.\n","\n","    We keep the original working losses (FP determining equations S8, column-independence,\n","    logdet barrier, column-norm regularizer, and the original 3-component Lie-closure loss),\n","    but we *re-factor* them to align with the paper loss naming scheme from `loss functions.py`:\n","\n","      - S8: FP determining equations (Gaeta–Quintero Eq. 4.12 specialized to 1D)\n","      - S5: column independence (Gram ~ I)\n","      - S1: Lie bracket closure + constancy of structure coefficients on (t,x) fields (tau,xi)\n","      - S2: Jacobi identity on learned structure coefficients (built from S1 projection)\n","      - S3: skew-symmetry of structure coefficients\n","      - S4: (numerical) linearity check of the S1 projection coefficients\n","      - S9: \"after-flow\" robustness: evaluate S8 at points flowed a small epsilon along a few generators\n","\n","    Any terms not in the paper framework are kept as separate \"orig_*\" stabilizers.\n","\n","    Conventions: with A, B, C defined below, the 1D Fokker-Planck equation\n","    p_t = -(f p)_x + 1/2 (σ^2 p)_xx is written as u_t + A u_xx + B u_x + C u = 0, where\n","    f = mu_fn and σ = sig_fn. Both are scalar (t, x) -> scalar callables that are\n","    differentiated (nested) with jax.grad, so they must be JAX-traceable.\n","\n","    Returns a tuple of jitted callables, in this order: s8 loss, det_per_gen, s5 loss,\n","    orig_logdet, orig_col_norm, s1, s2, s3, s4, s9, orig_lie_closure_full3, stack_cols.\n","    \"\"\"\n","\n","    # ---------------------------- FP coefficients (1D) ----------------------------\n","    def sigma2(tt, xx):\n","        s = sig_fn(tt, xx)\n","        return s * s\n","\n","    def A(tt, xx):\n","        # Eq (4.2): A = -1/2 (σ σ^T) ; in 1D this is -1/2 σ^2\n","        return -DTYPE(0.5) * sigma2(tt, xx)\n","\n","    def B(tt, xx):\n","        # Eq (4.2): B = f - ∂_x(σ σ^T) ; in 1D: f - (σ^2)_x\n","        s2_x = jax.grad(lambda x_: sigma2(tt, x_))(xx)\n","        return mu_fn(tt, xx) - s2_x\n","\n","    def C(tt, xx):\n","        # Eq (4.2): C = (∂_x f) - 1/2 ∂_{xx}(σ σ^T) ; in 1D: f_x - 1/2 (σ^2)_{xx}\n","        f_x = jax.grad(lambda x_: mu_fn(tt, x_))(xx)\n","        s2_xx = jax.grad(lambda x_: jax.grad(lambda x2: sigma2(tt, x2))(x_))(xx)\n","        return f_x - DTYPE(0.5) * s2_xx\n","\n","    # Values of A, B, C and their first partials in t and x at one scalar point (t, x).\n","    def coeffs_and_derivs(t_scalar, x_scalar):\n","        A0 = A(t_scalar, x_scalar)\n","       
 B0 = B(t_scalar, x_scalar)\n","        C0 = C(t_scalar, x_scalar)\n","\n","        A_t = jax.grad(lambda tt: A(tt, x_scalar))(t_scalar)\n","        A_x = jax.grad(lambda xx: A(t_scalar, xx))(x_scalar)\n","\n","        B_t = jax.grad(lambda tt: B(tt, x_scalar))(t_scalar)\n","        B_x = jax.grad(lambda xx: B(t_scalar, xx))(x_scalar)\n","\n","        C_t = jax.grad(lambda tt: C(tt, x_scalar))(t_scalar)\n","        C_x = jax.grad(lambda xx: C(t_scalar, xx))(x_scalar)\n","\n","        return A0, A_t, A_x, B0, B_t, B_x, C0, C_t, C_x\n","\n","    # ---------------------------- generator vector eval ----------------------------\n","    # Single-point generator evaluation: tau sees only t, xi and beta see z = (t, x). Each returns (m,).\n","    def tau_vec(params, t_scalar):\n","        t_arr = jnp.asarray([[t_scalar]], dtype=DTYPE)\n","        tn = normalize_t(dom, t_arr)\n","        return mlp_apply_fp(params[\"tau\"], tn, cfg.activation)[0]  # (m,)\n","\n","    def xi_vec(params, z):\n","        _, xi, _ = forward_fp(params, dom, cfg, z[None, :])\n","        return xi[0, :, 0]  # (m,)\n","\n","    def beta_vec(params, z):\n","        _, _, beta = forward_fp(params, dom, cfg, z[None, :])\n","        return beta[0]  # (m,)\n","\n","    # Generator values and their derivatives with respect to the raw (unnormalised) t and x at one\n","    # point z = (t, x). The [-1, 1] normalisation happens inside forward_fp / tau_vec, so autodiff\n","    # includes its chain-rule factor. Returns ten (m,) arrays.\n","    def point_derivs(params, z):\n","        t = z[0]\n","        tau, xi, beta = forward_fp(params, dom, cfg, z[None, :])\n","        tau, xi, beta = tau[0], xi[0, :, 0], beta[0]\n","\n","        tau_t = jax.jacfwd(lambda tt: tau_vec(params, tt))(t)  # (m,)\n","\n","        J_xi = jax.jacrev(lambda zz: xi_vec(params, zz))(z)  # (m,2)\n","        xi_t = J_xi[:, 0]\n","        xi_x = J_xi[:, 1]\n","\n","        H_xi = jax.jacrev(jax.jacrev(lambda zzz: xi_vec(params, zzz)))(z)  # (m,2,2)\n","        xi_xx = H_xi[:, 1, 1]\n","\n","        J_b = jax.jacrev(lambda zz: beta_vec(params, zz))(z)  # (m,2)\n","        beta_t = J_b[:, 0]\n","        beta_x = J_b[:, 1]\n","\n","        H_b = jax.jacrev(jax.jacrev(lambda zzz: beta_vec(params, zzz)))(z)  # (m,2,2)\n","        beta_xx = H_b[:, 1, 1]\n","\n","        return tau, tau_t, xi, xi_t, xi_x, xi_xx, beta, 
beta_t, beta_x, beta_xx\n","\n","    # ---------------------------- S8: determining equations ----------------------------\n","    # Determining equations of v = tau(t) d_t + xi(t,x) d_x + beta(t,x) u d_u for\n","    # u_t + A u_xx + B u_x + C u = 0: r1, r2, r3 are the u_xx, u_x and u coefficients of the\n","    # prolonged generator applied to the equation (u_t eliminated with the equation itself).\n","    # Returns three (batch, m) residual arrays.\n","    def det_residuals(params, batch):\n","        def one(z):\n","            tau, tau_t, xi, xi_t, xi_x, xi_xx, beta, beta_t, beta_x, beta_xx = point_derivs(params, z)\n","            t_scalar, x_scalar = z[0], z[1]\n","            A0, A_t, A_x, B0, B_t, B_x, C0, C_t, C_x = coeffs_and_derivs(t_scalar, x_scalar)\n","\n","            r1 = (tau_t * A0 + tau * A_t) + xi * A_x - DTYPE(2.0) * A0 * xi_x\n","            # Same structure as r1/r3: (tau B)_t + xi B_x is balanced by -B xi_x. Check with\n","            # u_t = k u_xx + b u_x (A = -k, B = -b, constants), which has the symmetry\n","            # tau = 2t, xi = x - b t, beta = 0: the residual is -2b + 0 + b + b = 0. Flipping\n","            # the signs of the xi*B_x and B*xi_x terms would leave -2b here.\n","            r2 = (tau_t * B0 + tau * B_t) + xi * B_x - B0 * xi_x - xi_t + DTYPE(2.0) * A0 * beta_x - A0 * xi_xx\n","            r3 = (tau_t * C0 + tau * C_t) + beta_t + A0 * beta_xx + B0 * beta_x + xi * C_x\n","            return r1, r2, r3\n","\n","        r1, r2, r3 = jax.vmap(one)(batch)  # each (B,m)\n","        return r1, r2, r3\n","\n","    def s8_fp_determining_loss_1d(params, batch):\n","        r1, r2, r3 = det_residuals(params, batch)\n","        loss = jnp.mean(r1 * r1 + r2 * r2 + r3 * r3)\n","        aux = {\n","            \"s8_r1\": jnp.mean(r1 * r1),\n","            \"s8_r2\": jnp.mean(r2 * r2),\n","            \"s8_r3\": jnp.mean(r3 * r3),\n","        }\n","        return loss, aux\n","\n","    def det_per_gen(params, batch):\n","        r1, r2, r3 = det_residuals(params, batch)\n","        per = jnp.mean(r1 * r1 + r2 * r2 + r3 * r3, axis=0)  # (m,)\n","        return per\n","\n","    # ---------------------------- shared column stacking ----------------------------\n","    # Row layout: for each point, three rows (tau, xi, beta); columns are the m generators -> (3*batch, m).\n","    def stack_cols(params, batch):\n","        tau, xi, beta = forward_fp(params, dom, cfg, batch)\n","        Vb = jnp.stack([tau, xi[:, :, 0], beta], axis=-1)  # (B,m,3)\n","        return jnp.reshape(jnp.transpose(Vb, (0, 2, 1)), (batch.shape[0] * 3, cfg.m))\n","\n","    # ---------------------------- S5: column independence ----------------------------\n","    def s5_column_independence_loss(params, batch):\n","        V = 
stack_cols(params, batch)  # (3B,m)\n","        col_norms = jnp.sqrt(jnp.sum(V * V, axis=0) + DTYPE(1e-12))\n","        Vn = V / col_norms[None, :]\n","        # Cosine-similarity Gram of the unit-norm columns: its diagonal is ~1, so the target I is\n","        # reachable and the loss measures off-diagonal overlap between generators. (Dividing by the\n","        # row count 3B would pin the diagonal at 1/(3B), leaving a constant gap to I that carries no\n","        # gradient and shrinking the useful off-diagonal signal by the same factor.)\n","        G = Vn.T @ Vn\n","        loss = jnp.mean((G - jnp.eye(cfg.m, dtype=DTYPE)) ** 2)\n","        aux = {\"s5_ind_mse\": loss, \"min_col\": jnp.min(col_norms), \"max_col\": jnp.max(col_norms)}\n","        return loss, aux\n","\n","    # ---------------------------- original stabilizers ----------------------------\n","    # Log-det barrier on the Gram of unit-norm columns (plus gram_eps * I); -logdet is minimised.\n","    # NOTE(review): this G is divided by the row count 3B, so its trace is about m / (3B) and\n","    # gram_eps (1e-4 by default) is not negligible against its eigenvalues; confirm the scale is intended.\n","    def orig_logdet_loss(params, batch):\n","        V = stack_cols(params, batch)\n","        col_norms = jnp.sqrt(jnp.sum(V * V, axis=0) + DTYPE(1e-12))\n","        Vn = V / col_norms[None, :]\n","        G = Vn.T @ Vn / DTYPE(Vn.shape[0])\n","        G = G + DTYPE(cfg.gram_eps) * jnp.eye(cfg.m, dtype=DTYPE)\n","        sign, ld = jnp.linalg.slogdet(G)\n","        return -ld  # encourage large det\n","\n","    # Pull each column norm (taken over all 3B stacked rows) toward target, plus a hinge penalty\n","    # for norms below cfg.min_col_norm.\n","    def orig_col_norm_loss(params, batch, target: float = 1.0):\n","        V = stack_cols(params, batch)\n","        col_norms = jnp.sqrt(jnp.sum(V * V, axis=0) + DTYPE(1e-12))\n","        tgt = DTYPE(target)\n","        penalty = jnp.mean((col_norms - tgt) ** 2) + jnp.mean(jax.nn.relu(DTYPE(cfg.min_col_norm) - col_norms) ** 2)\n","        return penalty\n","\n","    # Original Lie closure (on (tau,xi,beta) stacked columns) as an extra stabilizer.\n","    def orig_lie_closure_full3(params, batch, key_lie):\n","        V = stack_cols(params, batch)  # (3B,m)\n","        G = V.T @ V + DTYPE(cfg.ridge) * jnp.eye(cfg.m, dtype=DTYPE)\n","\n","        def eval_fields(z):\n","            t = z[0]\n","            return tau_vec(params, t), xi_vec(params, z), beta_vec(params, z)\n","\n","        # Components (t, x, u-multiplier) of the bracket [v_i, v_j] at z for\n","        # v = tau(t) d_t + xi d_x + beta u d_u: bt = v_i(tau_j) - v_j(tau_i),\n","        # bx = v_i(xi_j) - v_j(xi_i), bb = v_i(beta_j) - v_j(beta_i), with v acting through its\n","        # (t, x) part (the beta_i * beta_j terms from the u d_u parts cancel).\n","        def bracket_at_point(z, i, j):\n","            tau_all, xi_all, _ = eval_fields(z)  # a single evaluation serves both i and j\n","\n","            ti, xi_i = tau_all[i], xi_all[i]\n","            tj, xj = tau_all[j], xi_all[j]\n","\n","            def 
tj_fun(t):\n","                return tau_vec(params, t)[j]\n","            def xj_fun(z_):\n","                return xi_vec(params, z_)[j]\n","            def bj_fun(z_):\n","                return beta_vec(params, z_)[j]\n","\n","            dt_tj = jax.grad(tj_fun)(z[0])\n","            grad_xj = jax.grad(xj_fun)(z)         # (2,)\n","            grad_bj = jax.grad(bj_fun)(z)         # (2,)\n","\n","            def ti_fun(t):\n","                return tau_vec(params, t)[i]\n","            def xi_fun(z_):\n","                return xi_vec(params, z_)[i]\n","            def bi_fun(z_):\n","                return beta_vec(params, z_)[i]\n","\n","            dt_ti = jax.grad(ti_fun)(z[0])\n","            grad_xi = jax.grad(xi_fun)(z)         # (2,)\n","            grad_bi = jax.grad(bi_fun)(z)         # (2,)\n","\n","            bt = ti * dt_tj - tj * dt_ti\n","            bx = ti * grad_xj[0] + xi_i * grad_xj[1] - (tj * grad_xi[0] + xj * grad_xi[1])\n","            bb = ti * grad_bj[0] + xi_i * grad_bj[1] - (tj * grad_bi[0] + xj * grad_bi[1])\n","            return bt, bx, bb\n","\n","        # cfg.lie_pairs (i, j) pairs drawn with replacement; i == j gives a zero bracket and rel = 0.\n","        k1, k2 = jax.random.split(key_lie, 2)\n","        ii = jax.random.randint(k1, (cfg.lie_pairs,), 0, cfg.m)\n","        jj = jax.random.randint(k2, (cfg.lie_pairs,), 0, cfg.m)\n","\n","        def one_pair(pair_idx):\n","            i = ii[pair_idx]\n","            j = jj[pair_idx]\n","\n","            def one_point(z):\n","                bt, bx, bb = bracket_at_point(z, i, j)\n","                return jnp.stack([bt, bx, bb], axis=0)  # (3,)\n","\n","            Bvec = jax.vmap(one_point)(batch)  # (B,3)\n","            v = jnp.reshape(Bvec, (batch.shape[0] * 3,))  # (3B,)\n","\n","            # Ridge least-squares projection of the stacked bracket onto span(V) (same row layout\n","            # as stack_cols); rel is the relative squared residual.\n","            rhs = V.T @ v\n","            c = jnp.linalg.solve(G, rhs)\n","            r = v - V @ c\n","            rel = jnp.mean(r * r) / (jnp.mean(v * v) + DTYPE(1e-12))\n","            return rel\n","\n","        rels = jax.vmap(one_pair)(jnp.arange(cfg.lie_pairs))\n","        loss = 
jnp.mean(rels)\n","        return loss, {\"orig_lie_rel_mse\": loss}\n","\n","    # ---------------------------- S1–S4: Lie algebra structure on (tau,xi) ----------------------------\n","    # Ordered pairs (i,j) with i != j\n","    # idx_i / idx_j enumerate all K = m*(m-1) ordered pairs, grouped by i: for each i, every j != i\n","    # in ascending order (base is shifted by one past the diagonal).\n","    idx = jnp.arange(cfg.m, dtype=jnp.int32)\n","    idx_i = jnp.repeat(idx, repeats=cfg.m - 1)\n","    base = jnp.arange(cfg.m - 1, dtype=jnp.int32)\n","    i_col = idx[:, None]\n","    idx_j = (base + (base >= i_col).astype(jnp.int32)).reshape(-1)\n","    K = int(idx_i.shape[0])\n","\n","    reg_s1 = DTYPE(cfg.ridge)\n","\n","    # Triples i<j<k for Jacobi (computed once on host, then baked in)\n","    triples = [(i, j, k) for i in range(cfg.m) for j in range(i + 1, cfg.m) for k in range(j + 1, cfg.m)]\n","    tri_i = jnp.asarray([t[0] for t in triples], dtype=jnp.int32) if triples else jnp.zeros((0,), dtype=jnp.int32)\n","    tri_j = jnp.asarray([t[1] for t in triples], dtype=jnp.int32) if triples else jnp.zeros((0,), dtype=jnp.int32)\n","    tri_k = jnp.asarray([t[2] for t in triples], dtype=jnp.int32) if triples else jnp.zeros((0,), dtype=jnp.int32)\n","\n","    # Pair indices for S4 linearity check (projection is linear, so this should be ~0)\n","    # NOTE(review): assumes cfg.m >= 2 so that K >= 1; with K == 0 the modulo below is ill-defined.\n","    P_lin = min(8, max(K // 2, 1))\n","    p_lin = jnp.arange(P_lin, dtype=jnp.int32)\n","    q_lin = (p_lin + 1) % K\n","\n","    # Per-point S1 core. Column k of Bmat is the (t, x) part of [v_i, v_j] for pair k = (idx_i[k],\n","    # idx_j[k]) (a: t-component, b: x-component). With the 2x2 Gram G = V V^T,\n","    # Cmat = V^T (G + reg I)^-1 Bmat is the minimum-norm ridge solution of V Cmat = Bmat.\n","    # NOTE(review): PB = V Cmat = G (G + reg I)^-1 Bmat, so whenever G is well conditioned (the tau and\n","    # xi rows of V linearly independent, which needs m >= 2) E is O(reg) and closure_mse is near zero\n","    # regardless of the networks; the constraint is then carried mostly by C_var in s1_lie_loss.\n","    def _s1_point_err_and_C(tau, xi, tau_t, xi_t, xi_x):\n","        # V: (2,m)\n","        V = jnp.stack([tau, xi], axis=0)  # (2,m)\n","\n","        tau_i, tau_j = tau[idx_i], tau[idx_j]  # (K,)\n","        xi_i, xi_j = xi[idx_i], xi[idx_j]\n","        tau_t_i, tau_t_j = tau_t[idx_i], tau_t[idx_j]\n","        xi_t_i, xi_t_j = xi_t[idx_i], xi_t[idx_j]\n","        xi_x_i, xi_x_j = xi_x[idx_i], xi_x[idx_j]\n","\n","        a = tau_i * tau_t_j - tau_j * tau_t_i\n","        b = tau_i * xi_t_j + xi_i * xi_x_j - tau_j * xi_t_i - xi_j * xi_x_i\n","        Bmat = jnp.stack([a, b], axis=0)  # (2,K)\n","\n","        G = V @ V.T\n","        G_reg = G + reg_s1 * jnp.eye(2, 
dtype=DTYPE)\n","        X = jnp.linalg.solve(G_reg, Bmat)      # (2,K)\n","        Cmat = V.T @ X                         # (m,K)\n","        PB = V @ Cmat                          # (2,K)\n","        E = Bmat - PB\n","\n","        closure_mse = jnp.mean(E * E)\n","        return closure_mse, Cmat, V, Bmat, G_reg\n","\n","    # S1: mean pointwise closure error plus the batch variance of the projection coefficients\n","    # Cmat (structure coefficients are expected to be constant over (t, x)).\n","    def s1_lie_loss(params, tx_batch: jnp.ndarray):\n","        def eval_at_z(z):\n","            tau, tau_t, xi, xi_t, xi_x, _, _, _, _, _ = point_derivs(params, z)\n","            return tau, xi, tau_t, xi_t, xi_x\n","\n","        taus, xis, tau_ts, xi_ts, xi_xs = jax.vmap(eval_at_z)(tx_batch)  # each (B,m)\n","\n","        def one_point(tau, xi, tau_t, xi_t, xi_x):\n","            return _s1_point_err_and_C(tau, xi, tau_t, xi_t, xi_x)[:2]\n","\n","        closure_mse_pts, Cs = jax.vmap(one_point)(taus, xis, tau_ts, xi_ts, xi_xs)\n","        closure_mse = jnp.mean(closure_mse_pts)\n","        C_var = jnp.mean(jnp.var(Cs, axis=0))  # scalar\n","\n","        total = closure_mse + C_var\n","        aux = {\n","            \"s1_closure_mse\": closure_mse,\n","            \"s1_C_var\": C_var,\n","        }\n","        return total, aux\n","\n","    def _Cbar_to_c_tensor(Cbar):\n","        # Cbar: (m,K) -> c: (m,m,m) with c[i,j,k]\n","        # Only off-diagonal (i != j) entries are written; c[i,i,:] stays zero.\n","        def one_k(k):\n","            M = jnp.zeros((cfg.m, cfg.m), dtype=DTYPE)\n","            return M.at[idx_i, idx_j].set(Cbar[k, :])\n","        mats = jax.vmap(one_k)(jnp.arange(cfg.m, dtype=jnp.int32))  # (m,m,m) but indexed [k,i,j]\n","        c = jnp.transpose(mats, (1, 2, 0))  # (i,j,k)\n","        return c\n","\n","    def s2_jacobi_loss(params, tx_batch: jnp.ndarray):\n","        # Build structure constants from batch-averaged coefficients from S1.\n","        def eval_at_z(z):\n","            tau, tau_t, xi, xi_t, xi_x, _, _, _, _, _ = point_derivs(params, z)\n","            _, Cmat, _, _, _ = _s1_point_err_and_C(tau, xi, tau_t, xi_t, xi_x)\n","            return Cmat\n","\n","        Cs = 
jax.vmap(eval_at_z)(tx_batch)  # (B,m,K)\n","        Cbar = jnp.mean(Cs, axis=0)         # (m,K)\n","        c = _Cbar_to_c_tensor(Cbar)         # (m,m,m)\n","\n","        if tri_i.shape[0] == 0:\n","            return DTYPE(0.0), {\"s2_jacobi_mse\": DTYPE(0.0)}\n","\n","        # Jacobi identity for triple (i, j, k): for every n,\n","        # sum_l (c[i,j,l] c[l,k,n] + c[j,k,l] c[l,i,n] + c[k,i,l] c[l,j,n]) = 0.\n","        def jac_one(i, j, k):\n","            # vector over m\n","            term1 = c[i, j, :] @ c[:, k, :]  # (m,)\n","            term2 = c[j, k, :] @ c[:, i, :]\n","            term3 = c[k, i, :] @ c[:, j, :]\n","            jac = term1 + term2 + term3\n","            return jnp.mean(jac * jac)\n","\n","        vals = jax.vmap(jac_one)(tri_i, tri_j, tri_k)\n","        loss = jnp.mean(vals)\n","        return loss, {\"s2_jacobi_mse\": loss}\n","\n","    # NOTE(review): a and b in _s1_point_err_and_C flip sign under i <-> j and the ridge projection\n","    # is linear, so the coefficients for (j, i) are minus those for (i, j) up to rounding, and\n","    # _Cbar_to_c_tensor never writes the diagonal; both terms below are therefore ~0 by\n","    # construction. S2, S3 and S4 also each recompute point_derivs over the whole batch.\n","    def s3_skewsym_loss(params, tx_batch: jnp.ndarray):\n","        # Use same averaged c tensor as S2.\n","        def eval_at_z(z):\n","            tau, tau_t, xi, xi_t, xi_x, _, _, _, _, _ = point_derivs(params, z)\n","            _, Cmat, _, _, _ = _s1_point_err_and_C(tau, xi, tau_t, xi_t, xi_x)\n","            return Cmat\n","\n","        Cs = jax.vmap(eval_at_z)(tx_batch)  # (B,m,K)\n","        Cbar = jnp.mean(Cs, axis=0)\n","        c = _Cbar_to_c_tensor(Cbar)  # (m,m,m)\n","\n","        skew = c + jnp.swapaxes(c, 0, 1)  # c[i,j,k] + c[j,i,k]\n","        skew_mse = jnp.mean(skew * skew)\n","\n","        diag = c[jnp.arange(cfg.m), jnp.arange(cfg.m), :]  # (m,m)\n","        diag_mse = jnp.mean(diag * diag)\n","\n","        loss = skew_mse + diag_mse\n","        return loss, {\"s3_skew_mse\": skew_mse, \"s3_diag_mse\": diag_mse}\n","\n","    def s4_bilinearity_loss(params, tx_batch: jnp.ndarray):\n","        # Numerical check: projection coefficients are linear in the bracket argument B.\n","        def eval_at_z(z):\n","            tau, tau_t, xi, xi_t, xi_x, _, _, _, _, _ = point_derivs(params, z)\n","            closure_mse, Cmat, V, Bmat, G_reg = _s1_point_err_and_C(tau, xi, tau_t, xi_t, xi_x)\n","\n","            def one_pair(p, q):\n","   
             Bsum = Bmat[:, p] + Bmat[:, q]  # (2,)\n","                Xsum = jnp.linalg.solve(G_reg, Bsum)  # (2,)\n","                Csum = V.T @ Xsum                     # (m,)\n","                return jnp.mean((Csum - (Cmat[:, p] + Cmat[:, q])) ** 2)\n","\n","            errs = jax.vmap(one_pair)(p_lin, q_lin)\n","            return jnp.mean(errs)\n","\n","        per_pt = jax.vmap(eval_at_z)(tx_batch)\n","        loss = jnp.mean(per_pt)\n","        return loss, {\"s4_lin_mse\": loss}\n","\n","    # ---------------------------- S9: after-flow robustness ----------------------------\n","    # Only the first min(2, m) generators are used; flowed points are clipped back into the domain\n","    # and the S8 loss of all generators is evaluated there.\n","    def s9_fp_after_flow_loss_1d(params, tx_batch: jnp.ndarray, eps: float = 1e-2):\n","        # Flow a small step along a couple generators and evaluate S8 there as a robustness check.\n","        tau, xi, _ = forward_fp(params, dom, cfg, tx_batch)  # (B,m),(B,m,1)\n","        n_use = min(2, cfg.m)\n","\n","        def flow_i(i):\n","            dt = eps * tau[:, i]\n","            dx = eps * xi[:, i, 0]\n","            t2 = jnp.clip(tx_batch[:, 0] + dt, DTYPE(dom.t_min), DTYPE(dom.t_max))\n","            x2 = jnp.clip(tx_batch[:, 1] + dx, DTYPE(dom.x_min), DTYPE(dom.x_max))\n","            tx2 = jnp.stack([t2, x2], axis=1)\n","            d2, _ = s8_fp_determining_loss_1d(params, tx2)\n","            return d2\n","\n","        vals = jax.vmap(flow_i)(jnp.arange(n_use, dtype=jnp.int32))\n","        loss = jnp.mean(vals)\n","        return loss, {\"s9_det_flow\": loss}\n","\n","    # The order of this tuple is the unpacking order used by callers (see main).\n","    return (\n","        jax.jit(s8_fp_determining_loss_1d),\n","        jax.jit(det_per_gen),\n","        jax.jit(s5_column_independence_loss),\n","        jax.jit(orig_logdet_loss),\n","        jax.jit(orig_col_norm_loss),\n","        jax.jit(s1_lie_loss),\n","        jax.jit(s2_jacobi_loss),\n","        jax.jit(s3_skewsym_loss),\n","        jax.jit(s4_bilinearity_loss),\n","        jax.jit(s9_fp_after_flow_loss_1d),\n","        jax.jit(orig_lie_closure_full3),\n","        jax.jit(stack_cols),\n","  
  )\n","# =============================================================================\n","# (Optional) ground truth evaluation for the Brownian-motion demo\n","# =============================================================================\n","\n","def ground_truth_heat1d(tx: Array, sigma0: float):\n","    # Six point-symmetry generators of the zero-drift, constant-diffusion FP equation\n","    # p_t = (sigma0^2 / 2) p_xx (A = -sigma0^2 / 2, B = C = 0 in the u_t + A u_xx + B u_x + C u form),\n","    # evaluated at t = tx[:, 0], x = tx[:, 1]. Returns tau (N, 6), xi (N, 6, 1), beta (N, 6).\n","    t = tx[:, 0]\n","    x = tx[:, 1]\n","    N = tx.shape[0]\n","    tau = jnp.zeros((N, 6), dtype=DTYPE)\n","    xi = jnp.zeros((N, 6, 1), dtype=DTYPE)\n","    beta = jnp.zeros((N, 6), dtype=DTYPE)\n","\n","    s0 = DTYPE(sigma0)\n","    s02 = s0 * s0\n","\n","    tau = tau.at[:, 0].set(DTYPE(1.0))              # v1: ∂t\n","    xi = xi.at[:, 1, 0].set(DTYPE(1.0))             # v2: ∂x\n","    beta = beta.at[:, 2].set(DTYPE(1.0))            # v3: u∂u\n","\n","    # v4 (Galilean boost): the u_x determining equation with B = 0 and xi_xx = 0 reads\n","    # xi_t = -sigma0^2 * beta_x, so xi = sigma0^2 * t requires beta = -x (beta = -sigma0 * x\n","    # only satisfies it when sigma0 == 1).\n","    xi = xi.at[:, 3, 0].set(s02 * t)                # v4\n","    beta = beta.at[:, 3].set(-x)\n","\n","    tau = tau.at[:, 4].set(DTYPE(2.0) * t)          # v5\n","    xi = xi.at[:, 4, 0].set(x)\n","\n","    tau = tau.at[:, 5].set(t * t)                   # v6\n","    xi = xi.at[:, 5, 0].set(x * t)\n","    beta = beta.at[:, 5].set(-DTYPE(0.5) * (t + (x * x) / (s02 + DTYPE(1e-12))))\n","\n","    return tau, xi, beta\n","\n","\n","# Ground-truth columns in the same (3N, 6) row layout as the learned stack_cols: per point, rows (tau, xi, beta).\n","def stack_cols_gt(tx: Array, sigma0: float) -> Array:\n","    tau, xi, beta = ground_truth_heat1d(tx, sigma0)\n","    Vb = jnp.stack([tau, xi[:, :, 0], beta], axis=-1)  # (N,6,3)\n","    return jnp.reshape(jnp.transpose(Vb, (0, 2, 1)), (tx.shape[0] * 3, 6))\n","\n","\n","# Regular eval_grid_t x eval_grid_x grid over the domain, flattened to (N, 2) = (t, x) with t as the outer index.\n","def eval_grid(dom: DomainFP, cfg: TrainConfigFP) -> Array:\n","    t = jnp.linspace(DTYPE(dom.t_min), DTYPE(dom.t_max), cfg.eval_grid_t, dtype=DTYPE)\n","    x = jnp.linspace(DTYPE(dom.x_min), DTYPE(dom.x_max), cfg.eval_grid_x, dtype=DTYPE)\n","    T, X = jnp.meshgrid(t, x, indexing=\"ij\")\n","    return jnp.stack([T.reshape(-1), X.reshape(-1)], axis=1)\n","\n","\n","# 
=============================================================================\n","# main\n","# =============================================================================\n","\n","# End-to-end run: simulate SDE data, fit drift/diffusion surrogates (Stage A), then learn FP\n","# symmetry generators for the learned coefficients (Stage B) and compare them with the\n","# constant-diffusion heat basis via principal angles.\n","def main():\n","    # -------------------- configs --------------------\n","    sde_cfg = SDEConfig()\n","    net_cfg = SurrogateNetConfig()\n","    tr_cfg = SurrogateTrainConfig()\n","\n","    fp_cfg = TrainConfigFP()\n","\n","    print(\"=== SDE -> learned FP -> FP symmetries (variable coefficients) ===\")\n","    print(f\"JAX devices: {jax.devices()}\")\n","    print(f\"SDE: T={sde_cfg.T} dt={sde_cfg.dt} n_traj={sde_cfg.n_traj}\")\n","    print(f\"Surrogate: hidden={net_cfg.hidden} depth={net_cfg.depth} act={net_cfg.activation}\")\n","    print(f\"Symmetry: m={fp_cfg.m} hidden={fp_cfg.hidden} act={fp_cfg.activation}\")\n","    print()\n","\n","    # -------------------- simulate data --------------------\n","    key = jax.random.PRNGKey(0)\n","    TX_all, Z, dX = simulate_traj_and_increments(key, sde_cfg)\n","\n","    # set FP domain from data (with margin)\n","    # x-range: 0.1% / 99.9% quantiles of the simulated x values, padded by 25% of their spread.\n","    t_min = 1e-3\n","    t_max = float(sde_cfg.T)\n","    x_vals = np.asarray(TX_all[:, 1])\n","    q_lo, q_hi = np.quantile(x_vals, [0.001, 0.999])\n","    x_pad = 0.25 * (q_hi - q_lo + 1e-6)\n","    dom = DomainFP(t_min=t_min, t_max=t_max, x_min=float(q_lo - x_pad), x_max=float(q_hi + x_pad))\n","\n","    print(f\"Data-informed FP domain: t∈[{dom.t_min:.3g},{dom.t_max:.3g}] x∈[{dom.x_min:.3g},{dom.x_max:.3g}]\")\n","    print()\n","\n","    # -------------------- learn mu,sigma --------------------\n","    k_sur = jax.random.PRNGKey(tr_cfg.seed)\n","    params_surr, stats = train_surrogate(k_sur, Z, dX, sde_cfg, net_cfg, tr_cfg)\n","\n","    # freeze learned mu,sigma as scalar-callable functions for the FP loss\n","    # (inputs are standardised with the surrogate's training mean/std before evaluation)\n","    def mu_fn(t_scalar: Array, x_scalar: Array) -> Array:\n","        z = jnp.stack([t_scalar, x_scalar], axis=0)[None, :]  # (1,2)\n","        z_norm = (z - stats[\"mean\"][None, :]) / stats[\"std\"][None, 
:]\n","        mu, _ = surrogate_mu_sigma(params_surr, z_norm, net_cfg)\n","        return mu[0, 0]\n","\n","    def sig_fn(t_scalar: Array, x_scalar: Array) -> Array:\n","        z = jnp.stack([t_scalar, x_scalar], axis=0)[None, :]  # (1,2)\n","        z_norm = (z - stats[\"mean\"][None, :]) / stats[\"std\"][None, :]\n","        _, sig = surrogate_mu_sigma(params_surr, z_norm, net_cfg)\n","        return sig[0, 0]\n","\n","    # quick coefficient sanity check on a few points\n","    tx_probe = jnp.array([[0.25, 0.0], [1.0, 0.5], [1.5, -0.5]], dtype=DTYPE)\n","    mu_probe = jax.vmap(lambda z: mu_fn(z[0], z[1]))(tx_probe)\n","    sig_probe = jax.vmap(lambda z: sig_fn(z[0], z[1]))(tx_probe)\n","    print(\"[surrogate] probe mu(t,x):\", np.asarray(mu_probe))\n","    print(\"[surrogate] probe sigma(t,x):\", np.asarray(sig_probe))\n","    print()\n","\n","    # -------------------- learn FP symmetries --------------------\n","    print(\"=== Stage B: learn FP symmetries from learned (mu,sigma) ===\")\n","    key = jax.random.PRNGKey(fp_cfg.seed)\n","    params = init_params_fp(key, fp_cfg)\n","\n","    s8_loss_fn, det_per_gen_fn, s5_loss_fn, logdet_loss_fn, col_loss_fn, s1_loss_fn, s2_loss_fn, s3_loss_fn, s4_loss_fn, s9_loss_fn, lie_full3_loss_fn, stack_cols_fn = \\\n","        make_losses_varcoeff(dom, fp_cfg, mu_fn=mu_fn, sig_fn=sig_fn)\n","\n","    # Adam with global-norm gradient clipping.\n","    opt = optax.chain(\n","        optax.clip_by_global_norm(fp_cfg.clip_norm),\n","        optax.adam(fp_cfg.lr),\n","    )\n","    opt_state = opt.init(params)\n","\n","    # evaluation grid + GT span (only meaningful for the demo constant-diffusion SDE)\n","    pts_eval = eval_grid(dom, fp_cfg)\n","    W_eval = stack_cols_gt(pts_eval, sde_cfg.sigma0)\n","\n","    best_params = params\n","    best_score = float(\"inf\")\n","    best_step = 0\n","\n","    # One optimisation step: curriculum-limited batch, ramped loss weights, gradient, Adam update.\n","    @jax.jit\n","    def step_fn(params, opt_state, key, step):\n","        t_max_cur = t_max_curriculum_jax(fp_cfg, dom, step)\n","        key, k_batch, k_lie = 
jax.random.split(key, 3)\n","        batch = sample_batch_fp(k_batch, dom, fp_cfg.batch, t_max_cur)\n","\n","        # Structure/diversity weights ramp linearly from 0 to full over ramp_steps;\n","        # w_det and w_col are applied without the ramp.\n","        ramp = jnp.minimum(DTYPE(1.0), DTYPE(step) / DTYPE(max(fp_cfg.ramp_steps, 1)))\n","        w_ind = DTYPE(fp_cfg.w_ind) * ramp\n","        w_lie = DTYPE(fp_cfg.w_lie) * ramp\n","        w_logdet = DTYPE(fp_cfg.w_logdet) * ramp\n","\n","\n","        def total_loss(p):\n","            s8, s8aux = s8_loss_fn(p, batch)\n","            s5, s5aux = s5_loss_fn(p, batch)\n","            s1, s1aux = s1_loss_fn(p, batch)\n","            s2, s2aux = s2_loss_fn(p, batch)\n","            s3, s3aux = s3_loss_fn(p, batch)\n","            s4, s4aux = s4_loss_fn(p, batch)\n","            s9, s9aux = s9_loss_fn(p, batch)\n","            ld = logdet_loss_fn(p, batch)\n","            col = col_loss_fn(p, batch, target=1.0)\n","            lie3, lie3aux = lie_full3_loss_fn(p, batch, k_lie)\n","\n","            w_s2 = DTYPE(fp_cfg.w_s2_jacobi) * ramp\n","            w_s3 = DTYPE(fp_cfg.w_s3_skewsym) * ramp\n","            w_s4 = DTYPE(fp_cfg.w_s4_bilinearity) * ramp\n","            w_s9 = DTYPE(fp_cfg.w_s9_after_flow) * ramp\n","            w_lie3 = DTYPE(fp_cfg.w_lie_full3) * ramp\n","\n","            tot = (\n","                DTYPE(fp_cfg.w_det) * s8\n","                + w_ind * s5\n","                + w_logdet * ld\n","                + w_lie * s1\n","                + DTYPE(fp_cfg.w_col) * col\n","                + w_s2 * s2\n","                + w_s3 * s3\n","                + w_s4 * s4\n","                + w_s9 * s9\n","                + w_lie3 * lie3\n","            )\n","\n","            aux = {\n","                \"tot\": tot,\n","                # paper-aligned terms\n","                \"s8\": s8,\n","                \"s8_r1\": s8aux[\"s8_r1\"],\n","                \"s8_r2\": s8aux[\"s8_r2\"],\n","                \"s8_r3\": s8aux[\"s8_r3\"],\n","                \"s5\": s5,\n","                \"s5_ind_mse\": s5aux[\"s5_ind_mse\"],\n","                
\"s1\": s1,\n","                \"s1_closure_mse\": s1aux[\"s1_closure_mse\"],\n","                \"s1_C_var\": s1aux[\"s1_C_var\"],\n","                \"s2\": s2,\n","                \"s2_jacobi_mse\": s2aux[\"s2_jacobi_mse\"],\n","                \"s3\": s3,\n","                \"s3_skew_mse\": s3aux[\"s3_skew_mse\"],\n","                \"s3_diag_mse\": s3aux[\"s3_diag_mse\"],\n","                \"s4\": s4,\n","                \"s4_lin_mse\": s4aux[\"s4_lin_mse\"],\n","                \"s9\": s9,\n","                \"s9_det_flow\": s9aux[\"s9_det_flow\"],\n","                # original stabilizers\n","                \"logdet\": ld,\n","                \"col\": col,\n","                \"orig_lie3\": lie3,\n","                \"orig_lie_rel_mse\": lie3aux[\"orig_lie_rel_mse\"],\n","                # misc\n","                \"min_col\": s5aux[\"min_col\"],\n","                \"max_col\": s5aux[\"max_col\"],\n","                \"t_max\": t_max_cur,\n","                \"ramp\": ramp,\n","                \"w_s2\": w_s2,\n","                \"w_s3\": w_s3,\n","                \"w_s4\": w_s4,\n","                \"w_s9\": w_s9,\n","                \"w_lie3\": w_lie3,\n","            }\n","            return tot, aux\n","\n","        (loss, aux), grads = jax.value_and_grad(total_loss, has_aux=True)(params)\n","        updates, opt_state = opt.update(grads, opt_state, params)\n","        params = optax.apply_updates(params, updates)\n","        return params, opt_state, key, aux\n","\n","    # Compare the span of the learned generators on the eval grid with the ground-truth span W_eval.\n","    def eval_angles(tag: str, params_eval):\n","        V_eval = stack_cols_fn(params_eval, pts_eval)\n","        resid = best_mixing_residual(V_eval, W_eval)\n","        ang = principal_angles(W_eval, V_eval)\n","        ang_np = np.asarray(ang)\n","        print(f\"[eval:{tag}]  resid ||V-WA||/||V|| = {resid:.3e}\")\n","        for k, a in enumerate(ang_np, start=1):\n","            print(f\"  angle {k:2d}: {a:.6f} rad  = {a*180.0/math.pi:.4f} deg\")\n","\n","    for step in 
range(1, fp_cfg.steps + 1):\n","        params, opt_state, key, aux = step_fn(params, opt_state, key, jnp.asarray(step, dtype=DTYPE))\n","\n","        # combined score (same idea as your v4 scripts)\n","        # NOTE(review): the score comes from this step's random minibatch (so it is noisy), and float()\n","        # pulls it to the host every step, which synchronises with the device.\n","        score = float(aux[\"s8\"] + DTYPE(0.25) * aux[\"s5\"] + DTYPE(0.25) * aux[\"col\"] + DTYPE(0.1) * aux[\"s1\"])\n","        if score < best_score:\n","            best_score = score\n","            best_params = params\n","            best_step = step\n","\n","        if (step % fp_cfg.log_every) == 0 or step == 1:\n","\n","            print(\n","                f\"[train] step {step:6d}  tot={float(aux['tot']):.3e}  \"\n","                f\"S8={float(aux['s8']):.3e} (r1={float(aux['s8_r1']):.2e}, r2={float(aux['s8_r2']):.2e}, r3={float(aux['s8_r3']):.2e})  \"\n","                f\"S5={float(aux['s5']):.3e}  \"\n","                f\"S1={float(aux['s1']):.3e} (cl={float(aux['s1_closure_mse']):.2e}, varC={float(aux['s1_C_var']):.2e})  \"\n","                f\"S2={float(aux['s2']):.2e}  S3={float(aux['s3']):.2e}  S4={float(aux['s4']):.2e}  S9={float(aux['s9']):.2e}  \"\n","                f\"logdet={float(aux['logdet']):.3e}  col={float(aux['col']):.3e}  origLie3={float(aux['orig_lie3']):.3e}  \"\n","                f\"min/max_col={float(aux['min_col']):.2e}/{float(aux['max_col']):.2e}  \"\n","                f\"t_max={float(aux['t_max']):.3g} ramp={float(aux['ramp']):.2f}  \"\n","                f\"(w_s2={float(aux['w_s2']):.1e}, w_s3={float(aux['w_s3']):.1e}, w_s4={float(aux['w_s4']):.1e}, w_s9={float(aux['w_s9']):.1e}, w_lie3={float(aux['w_lie3']):.1e})\"\n","            )\n","\n","        if (step % fp_cfg.eval_every) == 0:\n","            eval_angles(\"current\", params)\n","            eval_angles(f\"best_score@{best_step}\", best_params)\n","\n","            # per-generator det residuals (help diagnose degeneracy)\n","            batch_eval = sample_batch_fp(jax.random.PRNGKey(1234), dom, fp_cfg.batch, DTYPE(dom.t_max))\n","            per = 
np.asarray(det_per_gen_fn(params, batch_eval))\n","            print(\"  per-gen det mse:\", np.array2string(per, precision=3, floatmode=\"fixed\"))\n","\n","    print(\"\\n=== Final evaluation ===\")\n","    eval_angles(\"current\", params)\n","    eval_angles(f\"best_score@{best_step}\", best_params)\n","\n","\n","if __name__ == \"__main__\":\n","    main()"]}]}