
<div align="center">
  <h1>
    Rapid training of Hamiltonian graph networks using random features
  </h1>
  <h4>
    Train random feature Hamiltonian graph neural networks (RF-HGN) using <a href="https://gitlab.com/fd-research/swimnetworks/-/tree/swimnetworks-hgn?ref_type=heads">SWIM</a> (see our <a href="https://arxiv.org/abs/2506.06558">paper</a>).
  </h4>
  <p float="center">
    <img src="assets/model.png" width="400" alt="Our RF-HGN model architecture is illustrated." />
    <img src="assets/rope_example.gif" width="100" alt="Model is trained with 5 nodes and being applied to 3, 6, and 10 nodes. Red circles around the predictions represent the true solution." />
  <figcaption>Left: RF-HGN architecture. Right: Zero-shot evaluation of an RF-HGN trained with 5 nodes; red circles mark the true solution.</figcaption>
  </p>
  <p float="center">
    <img src="assets/chain_example.gif" width="400" alt="Model is trained with 5 nodes and being applied to 10 nodes. Red circles around the predictions represent the true solution." />
  </p>
  <figcaption>Zero-shot evaluation of (RF-)HGNs trained with 5 nodes. Left: Adam-trained HGN. Middle: Random feature HGN using ELM. Right: Random feature HGN using SWIM.</figcaption>
</div>



---

<details open>
  <summary>
    <h2>Models</h2>
  </summary>

  Available models are:
  - `HNN` with `MLP` architecture (see `src/model/fcnn.py`) and gradient-descent based training (see `src/train/trad.py`),
  - `HGN` with `GNN` architecture (see `src/model/gnn.py`) and gradient-descent based training,
  - `RF-HNN` with `MLP` architecture and random feature (e.g. SWIM) training (see `src/train/sample.py`),
  - `RF-HGN` with `GNN` architecture and random feature (e.g. SWIM) training.
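
To illustrate what "random feature training" means in the list above, here is a minimal NumPy sketch of the ELM-style variant: hidden weights are sampled once and only the linear output layer is fitted by least squares. It is a plain regression toy, not the repository's `sample_and_linear_solve` and not SWIM's data-driven sampling; the function names, the `tanh` activation, and the sampling ranges are assumptions made for this example.

```python
import numpy as np


def fit_random_feature_model(x, y, width=100, rcond=1e-10, seed=0):
    """Fit y ~ tanh(x @ W + b) @ A with random (W, b) and least-squares A."""
    rng = np.random.default_rng(seed)
    # Hidden-layer parameters are sampled once and never trained.
    W = rng.normal(size=(x.shape[1], width))
    b = rng.uniform(-np.pi, np.pi, size=width)
    features = np.tanh(x @ W + b)
    # Only the linear output layer is fitted, with a single least-squares solve.
    A, *_ = np.linalg.lstsq(features, y, rcond=rcond)
    return W, b, A


def predict(x, W, b, A):
    """Evaluate the fitted random feature model."""
    return np.tanh(x @ W + b) @ A


x_train = np.linspace(-np.pi, np.pi, 200).reshape(-1, 1)
y_train = np.sin(x_train)
W, b, A = fit_random_feature_model(x_train, y_train)
max_error = np.abs(predict(x_train, W, b, A) - y_train).max()
print(f"max training error: {max_error:.2e}")
```

No gradient-descent loop appears anywhere: training cost is dominated by one linear solve, which is the property the `RF-` models rely on.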
</details>

---

<details close>
  <summary>
    <h2>Setup and overview</h2>
  </summary>

  This project expects Python `3.12.10` and the matching `pip` for installing the dependencies listed in `requirements.txt`.
  It implements physics-enhanced random feature graph neural networks trained with a very fast optimizer (SWIM). To
  reproduce the exact results of the paper experiments, you need a CUDA-capable GPU. Otherwise,
  you can still run the experiments by setting the parameter `device` in the .toml config files to `cpu`.

  Repository overview:
  - `configs` includes .toml files that list main (hyper-)parameters (data, model,...) for all the experiments.
  - `src/data` includes code for generating the mass-spring data with different geometries (chain and lattice) and degrees-of-freedom (2D and 3D).
  - `src/model` includes fully connected and graph network torch model implementations (independent of the training method).
  - `src/train` includes random feature and SOTA iterative gradient descent-based optimization algorithms.
  - `src/utils` is for utility functions like error functions, Hamilton's equations, ODE solvers, etc.
  - `src/swimnetworks` is a submodule implementing the SWIM algorithm.
  - `assets` includes assets like .gif and .png files.
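
As background for the Hamilton's equations and ODE-solver utilities mentioned above, here is a minimal sketch for a single mass on a linear spring: the time derivatives follow from dq/dt = ∂H/∂p and dp/dt = −∂H/∂q and are integrated with symplectic Euler. This is an illustration only, not the code in `src/utils`; the function names, the unit mass and stiffness, and the choice of integrator are assumptions.

```python
import numpy as np


def hamiltonian(q, p, mass=1.0, stiffness=1.0):
    """Energy of one mass on a linear spring: kinetic plus potential."""
    return p**2 / (2.0 * mass) + 0.5 * stiffness * q**2


def hamiltons_equations(q, p, mass=1.0, stiffness=1.0):
    """Return (dq/dt, dp/dt) = (dH/dp, -dH/dq)."""
    return p / mass, -stiffness * q


def symplectic_euler(q, p, delta_t, n_steps):
    """Integrate by updating p first, then q with the new p."""
    trajectory = [(q, p)]
    for _ in range(n_steps):
        _, dp_dt = hamiltons_equations(q, p)
        p = p + delta_t * dp_dt
        dq_dt, _ = hamiltons_equations(q, p)
        q = q + delta_t * dq_dt
        trajectory.append((q, p))
    return np.array(trajectory)


trajectory = symplectic_euler(q=1.0, p=0.0, delta_t=1e-3, n_steps=5000)
energies = hamiltonian(trajectory[:, 0], trajectory[:, 1])
energy_drift = np.abs(energies - energies[0]).max()
print(f"max energy drift: {energy_drift:.2e}")
```

The step size and trajectory length mirror the `--delta_t 1e-3 --len_traj 5000` values used in the integration commands of this README; the energy stays close to its initial value because the integrator is symplectic.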
</details>

---

<details close>
  <summary>
    <h2>Example</h2>
  </summary>

  Here is a small pseudo-code example of defining a (SWIM) RF-HGN:
  ```py
  import numpy as np

  from src.model import GNN
  from src.train import sample_and_linear_solve
  from src.utils import SamplingArgs, LinearSolveArgs

  dof = 2          # Example system is in 2D
  vdim = 2*dof     # [q_bar, p_bar]
  edim = dof+1     # [delta_q_bar, delta_q_bar_norm]

  # For the example target: chain of objects in 2D space.
  # Adjust n_obj and edge_index accordingly when testing (e.g. for zero-shot)
  gnn = GNN(dof=dof, n_obj=[5], edge_index=[[4,3], [3,2], [2,1], [1,0]], node_features_dim=vdim, edge_features_dim=edim)

  sampling_args = SamplingArgs(
    param_sampler='relu',
    sample_uniformly=True,
    resample_duplicates=True,
    dtype=np.float32,
  )
  linear_solve_args = LinearSolveArgs(
    mode="forward",
    driver='gels',
    rcond=1e-10,
    device='cpu',
  )
  # x_train, L, and dxdt_train are not defined in this snippet; prepare them from your data beforehand.
  sample_and_linear_solve(gnn, x_train, L, dxdt_train, sampling_args, linear_solve_args)
  ```
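
The comments in the example say to adjust `edge_index` for other chain lengths and describe the edge features as `[delta_q_bar, delta_q_bar_norm]`. A minimal NumPy sketch of both follows. The helper names are hypothetical, and the sign convention of the difference (first node minus second node of each edge) is an assumption, not something taken from `src/model/gnn.py`.

```python
import numpy as np


def chain_edge_index(n_nodes):
    """Edges [i, i-1] for a chain, ordered like the 5-node example above."""
    return [[i, i - 1] for i in range(n_nodes - 1, 0, -1)]


def edge_features(q, edge_index):
    """Stack [delta_q, ||delta_q||] per edge, giving dof + 1 values each."""
    edges = np.asarray(edge_index)
    delta_q = q[edges[:, 0]] - q[edges[:, 1]]
    norms = np.linalg.norm(delta_q, axis=1, keepdims=True)
    return np.concatenate([delta_q, norms], axis=1)


edges_5 = chain_edge_index(5)    # training size from the example
edges_10 = chain_edge_index(10)  # a larger chain, e.g. for zero-shot testing

# Five nodes spaced one unit apart along the x-axis in 2D.
q = np.stack([np.arange(5.0), np.zeros(5)], axis=1)
features = edge_features(q, edges_5)
print(edges_5)
print(features.shape)
```

Because the features depend only on position differences and their norms, the same construction applies unchanged to any number of nodes.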

</details>

---

<details close>
  <summary>
    <h2>Paper experiments</h2>
  </summary>

  All experiments are first run with the `main_*.py` scripts and then plotted with the
  `plot_*.py` scripts. The section "Plotting using saved experiment data" explains how to plot already saved data. To
  generate the data yourself, first run the experiments as described in "Running the experiments".
</details>

---

<details close>
  <summary>
    <h2>Plotting using saved experiment data</h2>
  </summary>

  Given the generated data saved at `paper-experiments`, the plotting scripts can be used to plot the paper figures as:

  **Optimizer study (Table 1):**
  ```sh
      python plot_optim_comparison.py \
          -d paper-experiments/optim-comparison/lattice\[3,\ 3\]_cuda
  ```
  **Adam comparison (Appendix Figure H.9, Table H.31):**
  ```sh
      python plot_adam_comparison.py \
          -d paper-experiments/adam-comparison/lattice\[2,\ 2\]_cuda/
  ```
  ```sh
      python plot_adam_comparison.py \
          -d paper-experiments/adam-comparison/lattice\[3,\ 3\]_cuda/
  ```
  ```sh
      python plot_adam_comparison.py \
          -d paper-experiments/adam-comparison/lattice\[4,\ 4\]_cpu/
  ```
  **Node scaling (Figure 6):**
  ```sh
      python plot_node_scaling.py \
          -d paper-experiments/node-scaling/chain_start2_end1024fcnn_4096gnn \
          --all
  ```
  **2D chain integration (Figure 7):**

  Trained on 8, tested on 8 nodes (Figure 7 top):
  ```sh
  python plot_integration_2DOF_chain.py \
      --node_scaling_dir ./paper-experiments/node-scaling/chain_start2_end1024fcnn_4096gnn \
      --Nx_gnn 8 --Nx_test 8 --delta_t 1e-3 --len_traj 5000 --obj 4 --all
  ```
  (Zero-shot) Trained on 8, tested on 4096 nodes (Figure 7 bottom):
  ```sh
  python plot_integration_2DOF_chain.py \
      --node_scaling_dir ./paper-experiments/node-scaling/chain_start2_end1024fcnn_4096gnn \
      --Nx_gnn 8 --Nx_test 4096 --delta_t 1e-3 --len_traj 5000 --obj 2048 --all
  ```
  **Corner node integrations (Figure H.11):**

  Trained on 8, tested on 8 nodes, plots left corner node trajectory:
  ```sh
  python plot_integration_2DOF_chain.py \
      --node_scaling_dir ./paper-experiments/node-scaling/chain_start2_end1024fcnn_4096gnn \
      --Nx_gnn 8 --Nx_test 8 --delta_t 1e-3 --len_traj 5000 --obj 0 --all
  ```
  Trained on 8, tested on 8 nodes, plots right corner node trajectory:
  ```sh
  python plot_integration_2DOF_chain.py \
      --node_scaling_dir ./paper-experiments/node-scaling/chain_start2_end1024fcnn_4096gnn \
      --Nx_gnn 8 --Nx_test 8 --delta_t 1e-3 --len_traj 5000 --obj 7 --all
  ```
  **Translation and rotation invariance experiment (Figure 2):**
  ```sh
      python plot_translation_rotation_invariant.py \
          -o paper-experiments/translation-rotation-invariance \
          --len_traj 100 --delta_t 1e-4
  ```
  **Chain simulation (Table 3, Figure H.10-15):**
  ```sh
      python plot_chain_simulation.py \
          -true <path-to-folder>/true_simulation_{spring,anharmonic,morse}-chain.npz \
          -swim <path-to-folder>/swim/pred_simulation_{spring,anharmonic,morse}-chain.npz \
          -elm <path-to-folder>/elm/pred_simulation_{spring,anharmonic,morse}-chain.npz \
          -adam <path-to-folder>/adam/pred_simulation_{spring,anharmonic,morse}-chain.npz \
          -o <output-directory> \
          -t configs/plot_chain_simulation.toml
  ```
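
The `{spring,anharmonic,morse}` notation in the command above appears to mean "pick one potential". If pasted literally, bash brace expansion would turn each such argument into three separate words. A minimal sketch that builds one command per potential instead is shown below; `results` and `plots` are placeholder directories, not paths from the repository, and the helper name is hypothetical.

```python
import shlex
from pathlib import PurePosixPath

POTENTIALS = ["spring", "anharmonic", "morse"]
METHODS = ["swim", "elm", "adam"]


def plot_command(potential, folder, output_dir):
    """Build the plot_chain_simulation.py argument list for one potential."""
    folder = PurePosixPath(folder)
    name = f"{potential}-chain.npz"
    command = [
        "python", "plot_chain_simulation.py",
        "-true", str(folder / f"true_simulation_{name}"),
    ]
    for method in METHODS:
        command += [f"-{method}", str(folder / method / f"pred_simulation_{name}")]
    command += ["-o", str(output_dir), "-t", "configs/plot_chain_simulation.toml"]
    return command


# "results" and "plots" are placeholders for your own directories.
commands = [plot_command(p, "results", "plots") for p in POTENTIALS]
for command in commands:
    print(shlex.join(command))
```

The printed lines can be copied into a shell, or each list can be passed to `subprocess.run` from the repository root.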
</details>

---

<details close>
  <summary>
    <h2>Running the experiments</h2>
  </summary>

  Run the following scripts to conduct the experiments and save the results to a location of your choice (`output_dir` below):

  **Optimizer study (for Table 1 data):**
  ```sh
      python main_optim_comparison.py \
          -f ./configs/optim_comparison.toml \
          -o output_dir
  ```
  **Adam comparison (for Figure 5, Figure H.8, Figure H.9, Table H.31 data):**

  For 2x2:
  ```sh
      python main_adam_comparison.py \
          -f ./configs/2x2_lattice_adam_comparison.toml \
          -o output_dir
  ```
  For 3x3:
  ```sh
      python main_adam_comparison.py \
          -f ./configs/3x3_lattice_adam_comparison.toml \
          -o output_dir
  ```
  For 4x4:
  ```sh
      python main_adam_comparison.py \
          -f ./configs/4x4_lattice_adam_comparison.toml \
          -o output_dir
  ```
  **Example evaluation of zero-shot test case 10x10 3D-lattice after training 2x2, 3x3, or 4x4 (For Figure 5 data please adjust `n_obj` accordingly):**
  ```sh
      python eval_model.py \
          -m <path-to-.pt-pretrained-model> \
          -n_obj 10 10
  ```
  Note: Figure 5 of the paper evaluates the median-performing models. You can find out which pretrained model is the 'median' by running the `plot_adam_comparison.py` script.
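
To make "median model" concrete, the sketch below ranks a set of runs by error and returns the middle one. The file names and error values are invented for illustration, and the selection logic is an assumption; it is not taken from `plot_adam_comparison.py`.

```python
# Hypothetical test errors for five training runs, keyed by saved model file.
run_errors = {
    "seed0.pt": 4.2e-4,
    "seed1.pt": 1.3e-4,
    "seed2.pt": 9.8e-4,
    "seed3.pt": 2.7e-4,
    "seed4.pt": 6.1e-4,
}


def median_run(errors):
    """Return the name of the run whose error is the median of all runs."""
    ranked = sorted(errors, key=errors.get)
    # For an even number of runs this picks the upper of the two middle runs.
    return ranked[len(ranked) // 2]


median_model = median_run(run_errors)
print(median_model)
```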

  **Node scaling (for Figure 6, H.16-17 data):**
  ```sh
      python main_node_scaling.py \
          -f configs/node_scaling.toml \
          -o output_dir
  ```
  **2D chain integration (for Figure 6, H.16-17 data):**

  Trained on 8, tested on 8 nodes (for Figure 6 top data)

  ```sh
  python main_integrate_2DOF_chain.py \
      --node_scaling_dir ./node-scaling-dir/chain_start2_end1024fcnn_4096gnn \
      --Nx_gnn 8 --Nx_test 8 --delta_t 1e-3 --len_traj 5000
  ```

  (Zero-shot) Trained on 8, tested on 4096 nodes (for Figure 6 bottom data)
  ```sh
  python main_integrate_2DOF_chain.py \
      --node_scaling_dir ./node-scaling-dir/chain_start2_end1024fcnn_4096gnn \
      --Nx_gnn 8 --Nx_test 4096 --delta_t 1e-3 --len_traj 5000
  ```

  **Ablation study of network widths (Figure F.7, Table F.27), condition number (Table F.28), message-passing (Table F.29):**
  ```sh
      python main_ablation_study.py \
          -f configs/ablation_study.toml \
          -o paper-experiments
  ```
  **Translation and rotation invariance experiment (for Figure 2 data):**
  ```sh
      python main_translation_rotation_invariant.py \
          -f configs/translation_rotation_invariant.toml \
          -o output_dir
  ```

  After running the experiments you can run the plotting scripts as described in the previous section.

  Note: The benchmark study (for Table 4, G.30 and Figure H.22 data) can be found under `/comparing_GNNs`. For the (SWIM) RF-HGN results you can run the script `main_benchmark.py`.

  **Additional experiments:**

  Example with additive Gaussian noise (Table H.34):
  ```sh
  python main_noise_study.py \
      -f configs/noise_study.toml \
      -o outs/tables \
      --abs
  ```

  Naive batch-wise training example using Lennard-Jones potential (Table H.35):
  ```sh
  python main_lennard_jones_with_batching_small_study.py \
      -f configs/batch_wise_training_study.toml \
      -o outs/tables \
      --abs
  ```

  Example with Random Fourier Features (Table H.38):
  ```sh
  python main_rff_sigma_study.py \
      -f configs/fourier_experiment.toml \
      --abs
  ```

  Different potentials with the chain geometry, first train models:
  ```sh
  python main_train.py \
      -f configs/train.toml \
      -s {spring-chain,anharmonic-chain,morse-chain} \
      -o chain-potentials/models \
      -t {elm,swim,adam}
  ```
  Then simulate with zero-shot test case (Table 3, Figure H.10-15):
  ```sh
  python simulate_chain_model.py \
      -s {spring,anharmonic,morse}-chain \
      -t ./configs/simulate.toml \
      -m chain-potentials/{adam_hgn,elm_rf_hgn,swim_rf_hgn}_chain_\[5\]_2DOF_{spring,anharmonic,morse}-chain.pt \
      -o chain-potentials/{adam,elm,swim}
  ```
  The resulting simulation solutions (trajectories) can then be plotted with `plot_chain_simulation.py`.

  Other Lennard-Jones experiments. First create the data:
  ```sh
  mkdir data-lj

  python create_lj_data.py \
    -nx 3 -ny 3 \
    -dt 0.005 -T 50 \
    -m 1.0 -eps 1.0 -sig 1.0 -cutoff 10.0 -prec single \
    -q_noise 0.1 -p_noise 0.0 \
    --outdir data-lj -bc none \
    -ns 500 -fl 2

  python create_lj_data.py \
    -nx 6 -ny 6 \
    -dt 0.005 -T 50 \
    -m 1.0 -eps 1.0 -sig 1.0 -cutoff 2.0 -prec single \
    -q_noise 0.1 -p_noise 0.0 \
    --outdir data-lj -bc none \
    -ns 500 -fl 2
  ```

  Run the examples (9 particles for training and testing; 36 particles for training and 64 for testing). `--testobj` is the number of particles along a row/column of the displaced grid initialization, so the number of particles is its square:
  ```sh
  mkdir paper-results

  python main_lennard_jones.py -f configs/lennard_jones_9particles_train_test.toml \
    -d data-lj/500sim_9particles_50steps_0.005deltat_10.0cutoff.npy -nt 300 -ns 20 \
    -zdx 0.1 -o paper-results --testobj 3 --save

  python -u main_lennard_jones.py -f configs/lennard_jones_36particles_train_64particles_test.toml \
    -d data-lj/500sim_36particles_50steps_0.005deltat_2.0cutoff.npy -nt 200 -ns 20 \
    -zdx 0.1 -o paper-results --testobj 8 --save
  ```

  Plot the results (9 particles train/test, 36 particles train and 64 particles zero-shot test):
  ```sh
  python plot_lennard_jones.py \
    -swim data/swim_gnn_9train_9test.npz \
    -elm data/elm_gnn_9train_9test.npz \
    -adam data/adam_gnn_9train_9test.npz \
    -o plots -fl 1000

  python plot_lennard_jones.py \
    -swim data/swim_gnn_36train_64test.npz \
    -elm data/elm_gnn_36train_64test.npz \
    -adam data/adam_gnn_36train_64test.npz \
    -o plots -fl 1000
  ```

</details>

---

<details>
  <summary>
    <h2>Notebooks</h2>
  </summary>

  The notebook `rope.ipynb` includes the setup done for the introductory movie. You can use the training script `main_train.py` to train a model:
  ```sh
      python main_train.py \
          -f configs/train.toml \
          -s spring-chain \
          -o output_dir/ \
          -t {adam,swim}
  ```
  You can then use the saved model in the notebook.
</details>

---

<details close>
  <summary>
    <h2>Citation</h2>
  </summary>

  If you use SWIM RF-HGN in your research, please cite our <a href="https://arxiv.org/abs/2506.06558">paper</a>.

  ```
  @inproceedings{rahma-2026-rfhgn,
      title     = {Rapid training of Hamiltonian graph networks using random features},
      author    = {Atamert Rahma and Chinmay Datar and Ana Cukarska and Felix Dietrich},
      booktitle = {International Conference on Learning Representations (ICLR)},
      year      = {2026},
      note      = {Accepted for ICLR 2026},
      url       = {https://arxiv.org/abs/2506.06558},
  }
  ```

</details>
