# GenDA: Generative Data Assimilation

This codebase accompanies the paper **"GenDA: Generative Data Assimilation on Complex Urban Areas via Classifier-Free Diffusion Guidance"**. It implements a generative data assimilation framework using diffusion modeling to reconstruct high-resolution fluid flow fields from sparse observations.

## Data Availability

**The datasets used in this work will be made publicly available upon publication of the paper.**

Consequently, this repository currently serves as a reference implementation of the model architecture and training logic.

## Project Structure

### `gen_da/` - Core Library
The core implementation of the generative model and diffusion framework.
*   `gen_da.py`: Defines the `GenSynth` class, the main diffusion model wrapper containing the training loss and sampling loops.
*   `denoiser.py`: Implements the Graph Neural Network (GNN) denoiser architecture used to predict noise/fields.
*   `dpm_solver_plus_plus_2s.py`: Implementation of the DPM-Solver++ 2S high-order diffusion sampler.
*   `deep_typed_graph_net.py`: JAX implementation of the Graph Networks Simulator generalized for TypedGraphs (based on Pfaff et al. 2021).
*   `typed_graph_net.py`: Library of underlying typed GNN components.
*   `xarray_jax.py`: Utilities for handling `xarray` structures with JAX backends to enable JIT-compilation.
*   `mlp.py`, `casting.py`, `model_utils.py`, `samplers_utils.py`: Various neural network blocks and utility functions.

### `training/` - Training Infrastructure
Scripts and modules associated with the training loop.
*   `train.py`: The primary training script using Flax/JAX, Optax, and Orbax. Handles distributed training setup.
*   `graph_dataset.py`: Utilities for loading graph structures and creating batched datasets.
*   `obs_sampling.py`: Implements observation strategies, including random sampling and mobile drone swarm trajectories.
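
As a rough illustration of the random-sampling strategy mentioned above, the sketch below draws a fixed number of observed nodes and returns a binary mask over the graph. The function name `sample_observation_mask` and its parameters are hypothetical and are not taken from `obs_sampling.py`; the actual module may represent observations differently, and drone swarm trajectories are not covered here.

```python
import random


def sample_observation_mask(num_nodes, num_obs, seed=0):
    """Return a 0/1 mask over graph nodes with exactly `num_obs` ones.

    Hypothetical helper for illustration only (not the repository's API).
    """
    if not 0 <= num_obs <= num_nodes:
        raise ValueError("num_obs must be between 0 and num_nodes")
    rng = random.Random(seed)  # local RNG so the draw is reproducible
    observed = set(rng.sample(range(num_nodes), num_obs))
    return [1 if i in observed else 0 for i in range(num_nodes)]


# Example: observe 3 of 10 nodes.
mask = sample_observation_mask(num_nodes=10, num_obs=3, seed=42)
```

A mask of this kind can be used to select which node values are passed to the model as conditioning and which are left to be reconstructed.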

### `baselines/` - Comparison Methods
Implementations of baseline methods for benchmarking.
*   `meshgraphnets.py`: Implementation of the MeshGraphNet (MGN) and MultiScale MGN architectures.
*   `train_meshgraphnets.py`: Training script for the supervised baseline models.
*   `lcsvd.py`: Implementation of a sensor-based POD/SVD baseline (Linear Reconstruction).

### `evaluation_scripts/` - Analysis Tools
Scripts used to generate the results and plots in the paper.
*   `compare_models_reconstruction.py`: Compares reconstruction quality (RMSE, Energy Spectra) between GenDA and baselines.
*   `compare_cfg_sweep.py`: Analyzes the impact of the Classifier-Free Guidance (CFG) scale $\gamma$.
*   `compare_observations_sweep.py`: Evaluates performance stability across varying numbers of sparse observations.
*   `compare_sampling_strategies.py`: Benchmarks performance when using random sensors versus drone swarms.
*   `vtu_writer.py`: Helper tool to export reconstructed fields to VTU format for visualization in ParaView.
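
For readers unfamiliar with the guidance scale $\gamma$ swept by `compare_cfg_sweep.py`, the following is a minimal sketch of one common classifier-free guidance convention, in which the guided prediction extrapolates from the unconditional output toward the conditional one. The function `cfg_combine` is hypothetical, and the convention used in this codebase may differ (for example, some implementations weight the difference by $1 + \gamma$ instead).

```python
def cfg_combine(pred_cond, pred_uncond, gamma):
    """Blend conditional and unconditional denoiser outputs.

    Assumed convention (illustrative, not confirmed by the repository):
        guided = uncond + gamma * (cond - uncond)
    gamma = 0 gives the unconditional prediction, gamma = 1 the conditional
    one, and gamma > 1 pushes further toward the conditioning signal.
    """
    return [u + gamma * (c - u) for c, u in zip(pred_cond, pred_uncond)]


# Example with two-node predictions.
guided = cfg_combine(pred_cond=[1.0, 2.0], pred_uncond=[0.0, 1.0], gamma=2.0)
```

Under this convention, sweeping $\gamma$ trades off fidelity to the sparse observations against the diversity of the unconditional prior.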

### `data_processing/` - Preprocessing
Tools for converting raw simulation data into graph format.
*   `build_graph_structures.py`: Converts raw simulation slices (VTU) into node/edge graph structures, typically using k-nearest-neighbor (k-NN) connectivity.
*   `normalize_data.py`: Computes global mean and standard deviation statistics for the training dataset.
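
The two preprocessing steps above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the names `knn_edges` and `mean_std` are hypothetical, the brute-force neighbor search is chosen for clarity rather than speed, and the population standard deviation is an assumption rather than a documented choice of `normalize_data.py`.

```python
import math


def knn_edges(points, k):
    """Build directed (sender, receiver) edges from each node's k nearest neighbors.

    Brute-force O(N^2) search; a real pipeline would likely use a spatial index.
    """
    edges = []
    for i, p in enumerate(points):
        # Distances from node i to every other node, nearest first.
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(points) if j != i)
        edges.extend((j, i) for _, j in dists[:k])
    return edges


def mean_std(values):
    """Global mean and population standard deviation of a flat list of values."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return mean, math.sqrt(var)


# Example: four 2D node positions, two neighbors per node.
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
edges = knn_edges(points, k=2)

# Example: statistics used to standardize a field as (x - mean) / std.
mean, std = mean_std([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])
```

Note that k-NN edges built this way are not necessarily symmetric: the isolated node at `(5.0, 5.0)` receives from two neighbors but is not among any other node's nearest neighbors.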

## Installation

This project relies on JAX and Flax. Install dependencies via `requirements.txt`:

```bash
pip install -r requirements.txt
```

*Note: You may need to install specific versions of `jax` and `jaxlib` that are compatible with your CUDA version (see the [JAX installation guide](https://github.com/google/jax)).*
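
After installation, a quick check along the following lines can confirm whether JAX is importable and which backend it selected. This is a convenience sketch rather than part of the repository; the helper name `jax_backend_summary` is hypothetical, and it falls back to a plain message when JAX is not installed.

```python
import importlib.util


def jax_backend_summary():
    """Return a one-line description of the local JAX installation."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax  # imported lazily so the check works without JAX present

    return (
        f"jax {jax.__version__}, "
        f"backend={jax.default_backend()}, "
        f"devices={jax.devices()}"
    )


summary = jax_backend_summary()
print(summary)
```

If the reported backend is `cpu` on a machine with a GPU, the installed `jaxlib` build most likely does not match the local CUDA setup.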
