# MiSe-GNN

> Official PyTorch implementation of the paper **"Error-Driven Graph Augmentation for Mesh-Based PDE Surrogates"**.

## 📖 Key Idea

<p align="center">
  <img src="figures/inferflow.png" width="1000">
</p>

<p align="center">
  <em>
  Overview of the MiSe-GNN inference flow, illustrating error-guided graph augmentation
  for steady-state flow prediction.
  </em>
</p>

## 🚀 Usage

The training and inference pipeline is handled via `scripts/cli.py`.

### Training Command
To train the model on the `2d_profile` dataset for `Pressure` prediction:
```bash
# Set configuration and runtime options
MODE="mise_gnn"                           # Options: mgn | mgn_tree | mise_gnn
CONFIG_PATH="configs/2d_profile.json"
TARGET="Pressure"
RUN_NAME="mise_gnn_${TARGET}"
SAVE_PATH="output"

# Launch
python scripts/cli.py \
  --config "$CONFIG_PATH" \
  --mode "$MODE" \
  --target_field "$TARGET" \
  --run_name "$RUN_NAME" \
  --save_path "$SAVE_PATH"
```

The `--target_field` argument specifies which field (e.g., `U1` or `Pressure`) the model is trained to predict.

Supported modes:
- `mgn`: standard MeshGraphNet baseline
- `mgn_tree`: hierarchical MGN with tree-based augmentation
- `mise_gnn`: the proposed error-guided, self-edge-enhanced GNN
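To compare the three modes on the same target field, the training command can be repeated once per mode. The sketch below is a minimal Python launcher under stated assumptions: it is run from the repository root, the entry point is `scripts/cli.py`, and the CLI accepts exactly the flags shown in the bash example. The helpers `build_command` and `sweep`, and the run-name pattern `{mode}_{target}` (generalized from `mise_gnn_${TARGET}`), are hypothetical and not part of the repository.

```python
import shlex
import subprocess
import sys

# Modes listed in this README; the CLI is assumed to accept exactly these values.
MODES = ("mgn", "mgn_tree", "mise_gnn")


def build_command(mode, target, config_path="configs/2d_profile.json",
                  save_path="output", script="scripts/cli.py"):
    """Return the argv for one training run, mirroring the bash example."""
    if mode not in MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {MODES}")
    # Hypothetical naming scheme, e.g. "mise_gnn_Pressure".
    run_name = f"{mode}_{target}"
    return [
        sys.executable, script,  # same interpreter as this launcher
        "--config", config_path,
        "--mode", mode,
        "--target_field", target,
        "--run_name", run_name,
        "--save_path", save_path,
    ]


def sweep(targets, modes=MODES, dry_run=True, **kwargs):
    """Build (and, if dry_run is False, launch) one run per (target, mode) pair."""
    commands = [build_command(m, t, **kwargs) for t in targets for m in modes]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failed run
    return commands


if __name__ == "__main__":
    # Prints the commands only; pass dry_run=False to actually launch them.
    sweep(targets=["Pressure"])
```

Keeping `dry_run=True` as the default lets you inspect the generated commands before committing GPU time; runs are launched sequentially so that one failure stops the sweep.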

## 🧩 Datasets & Configs

Configuration files are located in the `configs/` directory. Each `.json` file defines dataset paths, model hyperparameters, and training settings.

Configs are provided for the following datasets:
- `2d_profile.json`
- `2d_multiscale.json`
- `tensile2d.json`
- `vkils59.json`
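Before launching a run, it can help to confirm that all of the listed configs are present and parse as valid JSON. The following is a minimal sketch that assumes only what is stated above (one `.json` file per dataset under `configs/`). Because the key names inside each file are not documented here, it prints only the top-level keys. The helper `load_configs` is hypothetical.

```python
import json
from pathlib import Path

# Config files listed in this README.
CONFIG_NAMES = ("2d_profile.json", "2d_multiscale.json", "tensile2d.json", "vkils59.json")


def load_configs(config_dir="configs", names=CONFIG_NAMES):
    """Load each listed config; collect missing files instead of raising."""
    configs, missing = {}, []
    for name in names:
        path = Path(config_dir) / name
        if not path.is_file():
            missing.append(name)
            continue
        with path.open() as f:
            configs[name] = json.load(f)  # raises json.JSONDecodeError if malformed
    return configs, missing


if __name__ == "__main__":
    configs, missing = load_configs()
    for name, cfg in configs.items():
        # Only top-level keys are shown; their meaning is defined by the repository.
        keys = sorted(cfg) if isinstance(cfg, dict) else type(cfg).__name__
        print(f"{name}: {keys}")
    if missing:
        print("missing:", ", ".join(missing))
```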

## ⚙️ Dependencies

Install the required Python libraries before running the scripts. We strongly recommend matching the exact versions of PyTorch, PyTorch Geometric, and DGL to avoid CUDA conflicts.

**Deep Learning & Graph Core:**
* `torch == 2.6.0` (CUDA 12.4)
* `dgl == 2.4.0.th24.cu124` (CUDA 12.4)
* `torch-geometric == 2.6.0`
* `torch-scatter == 2.1.2+pt26cu124`, `torch-sparse == 0.6.18+pt26cu124`, `torch-cluster == 1.6.3+pt26cu124`, `torch-spline-conv == 1.2.2+pt26cu124`
* `nvidia-physicsnemo == 1.4.0a0`

**Scientific Computing & Utils:**
* `numpy`, `pandas`, `scipy`
* `matplotlib`, `scikit-learn`
* `pyyaml`, `h5py`, `tqdm`
* `muscat == 2.5.1`
* `plaid == 0.1.10`
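Since CUDA conflicts usually come from mismatched builds, a quick check of the installed distributions against the pins can save debugging time. The sketch below uses only the standard-library `importlib.metadata`. It covers the graph-core pins only, because the PyPI distribution names for `muscat` and `plaid` are not stated here and may differ. It also assumes that an installed version carrying an extra local tag (e.g. `2.6.0+cu124` for a `2.6.0` pin) is acceptable. `PINS`, `installed_version`, and `check_versions` are hypothetical names.

```python
from importlib.metadata import PackageNotFoundError, version

# Graph-core pins copied from the dependency list (PyPI distribution names).
PINS = {
    "torch": "2.6.0",
    "dgl": "2.4.0.th24.cu124",
    "torch-geometric": "2.6.0",
    "torch-scatter": "2.1.2+pt26cu124",
    "torch-sparse": "0.6.18+pt26cu124",
    "torch-cluster": "1.6.3+pt26cu124",
    "torch-spline-conv": "1.2.2+pt26cu124",
    "nvidia-physicsnemo": "1.4.0a0",
}


def installed_version(dist):
    """Return the installed version string of a distribution, or None."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def check_versions(pins=PINS, get_version=installed_version):
    """Return {dist: (pinned, installed)} for every pin that is not satisfied."""
    mismatches = {}
    for dist, pinned in pins.items():
        found = get_version(dist)
        # Assumption: accept a local build tag the pin omits, e.g. "2.6.0+cu124" for "2.6.0".
        ok = found is not None and (found == pinned or found.split("+")[0] == pinned)
        if not ok:
            mismatches[dist] = (pinned, found)
    return mismatches


if __name__ == "__main__":
    problems = check_versions()
    for dist, (pinned, found) in problems.items():
        print(f"{dist}: expected {pinned}, found {found or 'not installed'}")
    print("all pins satisfied" if not problems else f"{len(problems)} mismatch(es)")
```

This only checks package metadata. To confirm that the installed PyTorch build actually targets CUDA 12.4, inspect `torch.version.cuda` and `torch.cuda.is_available()` in the same environment.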

## 🖊️ Citation
If you use this code or our results in your research, please cite:

```
@inproceedings{misegnn2026,
  title={Error-Driven Graph Augmentation for Mesh-Based PDE Surrogates},
  author={Anonymous Authors},
  booktitle={Proceedings of the International Conference on Machine Learning (ICML)},
  year={2026}
}
```

## 📜 License
This project is licensed under the **BSD 3-Clause License**.
