# Error as Signal: Stiffness-Aware Diffusion Sampling via Embedded Runge-Kutta Guidance

[![arXiv](https://img.shields.io/badge/arXiv-2400.00000-b31b1b.svg)](https://arxiv.org/abs/xxxx.xxxxx)

This repository contains the official implementation of the paper **"Error as Signal: Stiffness-Aware Diffusion Sampling via Embedded Runge-Kutta Guidance"**.

ERK-Guid is a stiffness-aware sampling method for diffusion models. We use the difference between the two solutions of an Embedded Runge-Kutta (ERK) pair as a proxy for the local truncation error (LTE) and, guided by this signal, suppress the large errors that arise along the dominant eigenvector directions in stiff regions, without any additional function evaluations.
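Why the embedded difference is a useful stiffness signal can be seen on a toy problem. The sketch below assumes a Heun (2nd-order) / Euler (1st-order) pair, where the Euler predictor is the embedded lower-order solution, and a hand-built 2×2 linear ODE $\dot{x} = Ax$; it illustrates the LTE proxy only and is not the ERK-Guid update rule, which is given in the paper. For linear $f$ the difference equals $\tfrac{h^2}{2}A^2x$, so each eigen-component is scaled by $(h\lambda)^2/2$ and the stiff direction dominates.

```python
import numpy as np


def heun_with_embedded_euler(f, x, h):
    """One Heun (2nd-order) step and its embedded Euler (1st-order) solution.

    Both solutions are built from the same two evaluations of f.
    """
    k1 = f(x)                          # slope at the current state
    x_euler = x + h * k1               # embedded 1st-order solution
    k2 = f(x_euler)                    # slope at the Euler predictor
    x_heun = x + 0.5 * h * (k1 + k2)   # 2nd-order solution
    return x_heun, x_euler


# Toy stiff linear system dx/dt = A x with eigenvalues -1 (non-stiff) and -50 (stiff).
theta = np.pi / 6
V = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])  # orthonormal eigenvectors as columns
A = V @ np.diag([-1.0, -50.0]) @ V.T


def f(x):
    return A @ x


x0 = np.array([1.0, 1.0])
h = 0.01
x_heun, x_euler = heun_with_embedded_euler(f, x0, h)
lte = x_heun - x_euler                 # LTE proxy for the Euler solution, no extra f calls

# Components of the proxy along the [non-stiff, stiff] eigenvectors.
coeffs = V.T @ lte
print("LTE proxy along eigenvectors:", coeffs)
print("stiff / non-stiff ratio:", abs(coeffs[1]) / abs(coeffs[0]))
```

Because the Euler solution reuses the first stage of the Heun step, the proxy comes for free; the stiff component of `lte` is several hundred times larger than the non-stiff one here.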

## 🛠 Environment Setup

The code was tested on **Ubuntu, CUDA 12.1, and RTX 3090 GPUs**.

**1. Create and activate the Conda environment:**
You can set up the environment from the provided `environment.yml` (`conda env create -f environment.yml`) or manually with the commands below:

```bash
conda create -n ERK-Guid python=3.10 -y
conda activate ERK-Guid

# Install PyTorch and related packages
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.1 -c pytorch -c nvidia

# Install required pip packages
pip install huggingface-hub==0.20.2 diffusers==0.26.3 accelerate==0.27.2 click==8.3.1 scipy==1.15.3
```

**2. Install `torch-fidelity` for metric calculations:**
```bash
pip install -e git+https://github.com/toshas/torch-fidelity.git@master#egg=torch-fidelity
```

## 🚀 Usage

Our codebase is built on top of the [EDM2](https://github.com/NVlabs/edm2) pipeline. The core ERK-Guid logic is implemented in `generate_images.py`.
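Since the codebase follows EDM2, the ERK signal can be pictured inside EDM's deterministic Heun sampler: the Euler predictor computed on the way to each Heun update is the lower-order member of an embedded pair. The sketch below assumes the EDM probability-flow ODE $dx/d\sigma = (x - D(x;\sigma))/\sigma$, EDM's default schedule ($\sigma_{\min}=0.002$, $\sigma_{\max}=80$, $\rho=7$), and a toy Gaussian denoiser. `karras_sigmas` and `heun_sampler_with_lte` are illustrative names, not functions from `generate_images.py`, and the sketch only records the difference; it does not apply the ERK-Guid correction.

```python
import numpy as np


def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise-level discretization with a final sigma = 0 appended."""
    i = np.arange(num_steps)
    max_root = sigma_max ** (1 / rho)
    min_root = sigma_min ** (1 / rho)
    sigmas = (max_root + i / (num_steps - 1) * (min_root - max_root)) ** rho
    return np.append(sigmas, 0.0)


def heun_sampler_with_lte(denoise, x, sigmas):
    """Deterministic EDM-style Heun sampler that also records, at every
    2nd-order step, the Heun-minus-Euler difference (no extra denoiser calls)."""
    lte_proxies = []
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        h = s_next - s_cur
        d_cur = (x - denoise(x, s_cur)) / s_cur            # dx/dsigma at s_cur
        x_euler = x + h * d_cur                             # embedded 1st-order solution
        if s_next > 0:
            d_next = (x_euler - denoise(x_euler, s_next)) / s_next
            x_heun = x + h * (0.5 * d_cur + 0.5 * d_next)   # 2nd-order solution
            lte_proxies.append(x_heun - x_euler)            # equals 0.5 * h * (d_next - d_cur)
            x = x_heun
        else:
            x = x_euler                                     # final step to sigma = 0 is plain Euler
    return x, lte_proxies


# Toy check: for data ~ N(0, s^2 I) the ideal denoiser is D(x; sigma) = s^2 / (s^2 + sigma^2) * x.
data_std = 0.5


def gaussian_denoiser(x, sigma):
    return data_std**2 / (data_std**2 + sigma**2) * x


sigmas = karras_sigmas(256)
rng = np.random.default_rng(0)
x_T = rng.standard_normal(4) * sigmas[0]
x_0, lte_proxies = heun_sampler_with_lte(gaussian_denoiser, x_T, sigmas)
exact = x_T * data_std / np.sqrt(data_std**2 + sigmas[0] ** 2)  # closed-form ODE solution at sigma = 0
print("max relative error:", np.max(np.abs(x_0 / exact - 1)))
```

For Gaussian data the denoiser is exact, so the sampler output can be checked against the closed-form solution $x(\sigma) = x_T\sqrt{(s^2+\sigma^2)/(s^2+\sigma_{\max}^2)}$.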

### Reproducing Tables 2 & 3

We provide a bash script to reproduce the sampling and evaluation results for Table 2 and Table 3.

*Note: Make sure the reference statistics (`dataset-refs/img512.pkl`) have been downloaded as described in the EDM2 instructions. Precision and Recall additionally require the original ImageNet images: download them, apply resize and center-crop preprocessing, and set `input2` in `calculate_precision_recall_is.py` to the path of the preprocessed folder.*

```bash
# Run the evaluation script (uses 8 GPUs by default)
bash scripts/reproduce.sh
```
Generated images are written temporarily to the `output/` directory for evaluation (FID, Precision, Recall, IS) and deleted afterward to save disk space. The final metrics are appended to `output/result.txt`.
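The ImageNet preprocessing required for Precision and Recall (see the note above) can be sketched as follows. This assumes the ADM-style "Dhariwal" center crop (box-filter halving, bicubic resize of the short side, then center crop) that EDM2's `dataset_tool.py` applies with `--transform=center-crop-dhariwal`; whether `calculate_precision_recall_is.py` expects exactly this transform is an assumption, and `center_crop_dhariwal` and `preprocess_folder` are hypothetical helper names.

```python
from pathlib import Path

import numpy as np
from PIL import Image


def center_crop_dhariwal(img: Image.Image, size: int = 512) -> np.ndarray:
    """ADM-style resize + center crop to a (size, size, 3) uint8 RGB array."""
    img = img.convert("RGB")  # ImageNet contains some grayscale/CMYK files
    # Halve with a box filter while the short side is at least 2x the target.
    while min(img.size) >= 2 * size:
        img = img.resize(tuple(x // 2 for x in img.size), resample=Image.BOX)
    # Bicubic resize so that the short side equals `size`.
    scale = size / min(img.size)
    img = img.resize(tuple(round(x * scale) for x in img.size), resample=Image.BICUBIC)
    arr = np.asarray(img)
    top = (arr.shape[0] - size) // 2
    left = (arr.shape[1] - size) // 2
    return arr[top:top + size, left:left + size]


def preprocess_folder(src: str, dst: str, size: int = 512) -> int:
    """Crop every JPEG/PNG under `src`, save it as PNG in `dst`, and return the count."""
    out = Path(dst)
    out.mkdir(parents=True, exist_ok=True)
    count = 0
    for path in sorted(Path(src).rglob("*")):
        if path.suffix.lower() not in {".jpeg", ".jpg", ".png"}:
            continue
        with Image.open(path) as im:
            arr = center_crop_dhariwal(im, size)
        Image.fromarray(arr).save(out / f"{count:08d}.png")
        count += 1
    return count
```

The resulting `dst` folder is what would be passed as `input2`.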

### Reproducing Table 4

1. Clone the `diff-sampler` repository:
   ```bash
   git clone https://github.com/zju-pi/diff-sampler.git
   ```
2. Replace the original `solvers.py` with our implementation:
   ```bash
   cp src/solvers_erk_guid.py diff-sampler/diff-solvers-main/solvers.py
   ```
3. **Integrate hyperparameters:** Modify the sampling scripts in `diff-sampler` so that they accept our hyperparameters `w_stiff` and `w_con` and pass them through to the replaced solver function.
4. Run the modified sampling scripts. The values of $w_{\text{stiff}}$ and $w_{\text{con}}$ used in our experiments are listed in the paper.
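Step 3 depends on how each `diff-sampler` script builds its solver call, so the sketch below only shows the general pattern under stated assumptions: a command-line entry point (here using `argparse`; the actual scripts may parse options differently) that exposes `--w_stiff` and `--w_con` and forwards them, together with existing options, as keyword arguments. `erk_guid_solver`, `--num_steps`, and its default of 6 are hypothetical placeholders, not names or values taken from `solvers_erk_guid.py`.

```python
import argparse


def erk_guid_solver(net, latents, num_steps, w_stiff, w_con, **kwargs):
    """Hypothetical stand-in for the solver in the replaced solvers.py.

    It only echoes its settings; the real function name and signature may differ.
    """
    return {"num_steps": num_steps, "w_stiff": w_stiff, "w_con": w_con}


def build_parser():
    parser = argparse.ArgumentParser(description="diff-sampler sampling script (sketch)")
    # Placeholder for the script's existing options.
    parser.add_argument("--num_steps", type=int, default=6)
    # No defaults on purpose: take the values reported in the paper for each setting.
    parser.add_argument("--w_stiff", type=float, required=True, help="ERK-Guid hyperparameter w_stiff")
    parser.add_argument("--w_con", type=float, required=True, help="ERK-Guid hyperparameter w_con")
    return parser


def main(argv=None):
    args = build_parser().parse_args(argv)
    # Bundle the new hyperparameters with the existing solver options so that
    # a single **solver_kwargs call forwards them down to the solver.
    solver_kwargs = {"num_steps": args.num_steps, "w_stiff": args.w_stiff, "w_con": args.w_con}
    return erk_guid_solver(net=None, latents=None, **solver_kwargs)
```

Making both options `required` means a run fails loudly instead of silently using an unintended guidance strength.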

## 📌 Citation

If you find this code useful for your research, please cite our paper:

```bibtex
@inproceedings{kong2026error,
  title={Error as Signal: Stiffness-Aware Diffusion Sampling via Embedded Runge-Kutta Guidance},
  author={Kong, Inho and Lee, Sojin and Hong, Youngjoon and Kim, Hyunwoo J},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026}
}
```

## 🙏 Acknowledgements

This repository heavily borrows from the excellent codebases of [EDM2](https://github.com/NVlabs/edm2) by NVIDIA and [diff-sampler](https://github.com/zju-pi/diff-sampler). We thank the authors for their open-source contributions.