# Supplementary Code: Generative Inverse Design with Abstention via Diagonal Flow Matching

This repository contains the implementation for the paper "Generative Inverse Design with Abstention via Diagonal Flow Matching".

## Overview

We introduce **Diagonal Flow Matching (Diag-CFM)**, a zero-anchoring strategy for conditional flow matching (CFM) whose learning problem is provably invariant to coordinate permutations. This removes the sensitivity to design and label orderings present in standard CFM and gives order-of-magnitude improvements in round-trip accuracy.

We also develop two architecture-specific uncertainty metrics, **Zero-Deviation** and **Self-Consistency**, that enable:
1. Selecting the best candidate among multiple generations
2. Abstaining from unreliable predictions
3. Detecting out-of-distribution targets

## Requirements

- Python 3.10+
- PyTorch
- NumPy
- Pandas
- Matplotlib
- SciPy
- Seaborn
- scikit-learn
- tqdm

Install dependencies:
```bash
pip install torch numpy pandas matplotlib scipy seaborn scikit-learn tqdm
```

Or install the package:
```bash
pip install -e .
```

## Repository Structure

```
supplementary_code/
├── setup.py                     # Package installation
├── README.md                    # This file
└── uq_diagcfm/
    ├── __init__.py
    │
    │   # === CORE MODELS ===
    ├── models.py                # MLP, INN, ConditionalINN architectures
    ├── losses.py                # INN loss functions (MMD, bidirectional loss)
    ├── solvers.py               # ODE integration (Euler solver)
    │
    │   # === TRAINING ===
    ├── train.py                 # Training routines for Diag-CFM, CFM, INN
    ├── models_for_datasets.py   # Dataset-specific model creation
    ├── checkpointing.py         # Model checkpoint saving/loading
    ├── ensembles.py             # Ensemble loading utilities
    │
    │   # === DATA UTILITIES ===
    ├── data_utils_gas_turbine.py  # Gas Turbine Combustor dataset
    ├── data_utils_unifoil.py      # Unifoil airfoil dataset
    ├── data_utils_dtlz.py         # DTLZ benchmark
    │
    │   # === EVALUATION ===
    ├── evaluation_utils.py      # Forward/inverse performance metrics
    ├── uq_evaluation_utils.py   # Uncertainty quantification metrics
    ├── ood_utils.py             # OOD point generation
    │
    │   # === EVALUATION SCRIPTS ===
    ├── evaluate_results_gas_turbine.py  # Performance evaluation
    ├── evaluate_results_unifoil.py
    ├── evaluate_results_dtlz.py
    ├── evaluate_uq_gas_turbine.py       # UQ/OOD detection evaluation
    ├── evaluate_uq_unifoil.py
    ├── evaluate_uq_dtlz.py
    ├── select_best.py                   # Select-best evaluation
    ├── error_rejection.py               # Error-rejection (abstention) evaluation
    ├── ablation_diag_cfm.py             # CFM vs Diag-CFM ablation study
    │
    │   # === UTILITIES ===
    ├── paths.py                 # Path configuration
    └── utils.py                 # General utilities
```

## Diagonal Flow Matching (Diag-CFM)

### Key Concept

Standard CFM pairs design parameters with a concatenation of noise and labels: `[z; y] -> x`. This creates sensitivity to the arbitrary ordering of coordinates.

Diag-CFM uses **zero-anchoring**: we augment the state space and pair labels with zero:
- Source state: `s_0 = [z; y]` where `z ~ Uniform[0,1]^P`
- Target state: `s_1 = [x; 0_L]` where `x` is the design and `0_L` is a zero vector

With this pairing, the per-coordinate regression target `s_1 - s_0` always matches:
- Design coordinates to noise (`x - z`)
- Label coordinates to zero (`0 - y = -y`)
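
The pairing above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the repository's code: `make_states` and `velocity_target` are hypothetical helper names, and Python lists stand in for the PyTorch tensors used by the package.

```python
import random

def make_states(x, y, rng):
    """Build the zero-anchored source/target pair for one (design, label) sample."""
    z = [rng.random() for _ in x]      # z ~ Uniform[0,1]^P
    s0 = z + list(y)                   # source state [z; y]
    s1 = list(x) + [0.0] * len(y)      # target state [x; 0_L]
    return s0, s1

def velocity_target(s0, s1):
    """Per-coordinate regression target s_1 - s_0 = [x - z; -y]."""
    return [b - a for a, b in zip(s0, s1)]

rng = random.Random(0)
x, y = [0.2, 0.7, 0.5], [1.5, -0.3]    # toy sample with P=3, L=2
s0, s1 = make_states(x, y, rng)
u = velocity_target(s0, s1)            # last L entries equal -y
```

Both states have dimension `P + L`, so every coordinate of `s_0` is regressed onto the coordinate of `s_1` at the same position.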

### Training

The Diag-CFM loss is:
```
L_Diag-CFM = E[ ||v_θ(t, s_t) - (s_1 - s_0)||² ]
```

where `s_t = (1-t)·s_0 + t·s_1` is the linear interpolation.
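
For a single training pair and a single time `t`, the term inside the expectation can be sketched as follows. The function names are hypothetical, and `zero_field` is only a stand-in for the trained network `v_θ`; the actual training loop in `train.py` may batch and sample differently.

```python
def interpolate(s0, s1, t):
    """Linear path s_t = (1 - t) * s0 + t * s1."""
    return [(1.0 - t) * a + t * b for a, b in zip(s0, s1)]

def diag_cfm_sample_loss(v_theta, s0, s1, t):
    """Squared error between the predicted velocity at (t, s_t) and s1 - s0."""
    s_t = interpolate(s0, s1, t)
    pred = v_theta(t, s_t)
    return sum((p - (b - a)) ** 2 for p, a, b in zip(pred, s0, s1))

# Hypothetical stand-in for the trained network: always predicts zero velocity.
def zero_field(t, s):
    return [0.0] * len(s)

s0 = [0.25, 0.5, 2.0]   # [z; y] with P=2, L=1
s1 = [0.75, 0.5, 0.0]   # [x; 0_L]
loss = diag_cfm_sample_loss(zero_field, s0, s1, t=0.5)
```

Averaging this quantity over sampled times and training pairs gives an estimate of the expectation above.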

### Bidirectional Use

**Synthesis (inverse design):** Given target `y*`, sample `z ~ Uniform[0,1]^P`, set `s(0) = [z; y*]`, integrate forward to `t=1`, return first `P` coordinates as design `x`.

**Analysis (forward prediction):** Given design `x`, set `s(1) = [x; 0_L]`, integrate backward to `t=0`, return last `L` coordinates as predicted labels `ŷ`.
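
Both directions reduce to integrating the same velocity field with opposite time orientation. The sketch below uses a fixed-step Euler scheme in plain Python; `euler_integrate`, `synthesize`, and `analyze` are hypothetical names, and `toy_field` is a constant stand-in for a trained `v_θ`, chosen so that the label `1.0` is carried exactly to zero. The solver in `solvers.py` may differ in step count and interface.

```python
def euler_integrate(v, s, t_start, t_end, n_steps=100):
    """Fixed-step Euler integration of ds/dt = v(t, s) from t_start to t_end."""
    dt = (t_end - t_start) / n_steps   # negative when integrating backward
    t = t_start
    for _ in range(n_steps):
        vel = v(t, s)
        s = [si + dt * vi for si, vi in zip(s, vel)]
        t += dt
    return s

def synthesize(v, z, y_target, n_steps=100):
    """Inverse design: integrate [z; y*] from t=0 to t=1; return (x, full state s(1))."""
    s1 = euler_integrate(v, list(z) + list(y_target), 0.0, 1.0, n_steps)
    return s1[:len(z)], s1

def analyze(v, x, n_labels, n_steps=100):
    """Forward prediction: integrate [x; 0_L] from t=1 to t=0; return last L coordinates."""
    s0 = euler_integrate(v, list(x) + [0.0] * n_labels, 1.0, 0.0, n_steps)
    return s0[-n_labels:]

# Toy stand-in for v_theta (P=2, L=1): a constant field that moves the label 1.0 to zero.
def toy_field(t, s):
    return [0.25, -0.5, -1.0]

x, s_final = synthesize(toy_field, z=[0.2, 0.9], y_target=[1.0])
y_hat = analyze(toy_field, x, n_labels=1)   # round trip back to the target label
```

With this toy field the round trip recovers the target label up to floating-point error; with a trained network the mismatch is what the metrics in the next section measure.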

## Uncertainty Quantification Metrics

### Zero-Deviation (Diag-CFM specific)

During synthesis, the label dimensions should flow from `y*` to approximately zero. We measure:
```
Zero-Deviation(x) = ||s(1)[P+1:P+L]||²
```
where `s(1)[P+1:P+L]` are the last `L` components of the synthesis output.

**Cost:** No additional forward passes (computed as byproduct of generation).
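
As a minimal sketch, the metric is the squared norm of the label block of the final synthesis state. The helper name `zero_deviation` is hypothetical, and lists stand in for tensors.

```python
def zero_deviation(s_final, n_labels):
    """Squared L2 norm of the last L coordinates of the synthesis output s(1).

    Assumes n_labels >= 1.
    """
    return sum(v * v for v in s_final[-n_labels:])

# Toy synthesis outputs with P=2, L=2: label block near zero vs. far from zero.
good = zero_deviation([0.4, 0.6, 0.01, -0.02], n_labels=2)
bad = zero_deviation([0.4, 0.6, 0.5, -1.0], n_labels=2)
```

A lower score indicates that the label coordinates landed closer to the zero anchor.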

### Self-Consistency (Diag-CFM specific)

A reliable generated design should reconstruct the target labels when passed through analysis:
```
Self-Consistency(x) = ||y_reconstructed - y*||²
```

**Cost:** One additional analysis pass (backward integration) per candidate.
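
The check can be sketched as below. `self_consistency` is a hypothetical helper that takes any callable mapping a design to predicted labels; `toy_analysis` is an illustrative stand-in for the analysis direction of a trained model, not the repository's API.

```python
def self_consistency(analyze_fn, x, y_target):
    """Squared distance between labels reconstructed from x and the target y*."""
    y_rec = analyze_fn(x)              # one analysis pass on the candidate design
    return sum((r - t) ** 2 for r, t in zip(y_rec, y_target))

# Hypothetical stand-in for the analysis direction of a trained model.
def toy_analysis(x):
    return [x[0] + x[1], x[0] - x[1]]

y_star = [1.0, 0.0]
consistent = self_consistency(toy_analysis, [0.5, 0.5], y_star)    # reconstructs y* exactly
inconsistent = self_consistency(toy_analysis, [1.0, 0.5], y_star)  # reconstructs [1.5, 0.5]
```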

### Ensemble Variance

Train multiple models with different initializations and compute the variance across their predictions:
```
Ensemble-Variance(x) = Tr(D_y^{-1} · Σ_ens(x))
```
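
The sketch below shows one way to evaluate this trace, under two assumptions that are not stated in this README: that `D_y` is a diagonal matrix of per-label variances used for normalization, and that `Σ_ens(x)` is the population covariance of the ensemble members' label predictions. The helper name `ensemble_variance` is hypothetical.

```python
from statistics import pvariance

def ensemble_variance(member_preds, label_variances):
    """Tr(D_y^{-1} Sigma_ens), ASSUMING D_y = diag(label_variances).

    member_preds: one predicted label vector per ensemble member.
    With a diagonal D_y, the trace is the sum of per-label ensemble
    variances, each divided by the matching diagonal entry.
    """
    total = 0.0
    for j, d_jj in enumerate(label_variances):
        var_j = pvariance([pred[j] for pred in member_preds])
        total += var_j / d_jj
    return total

# Two toy ensemble members that disagree only on the first label.
score = ensemble_variance([[1.0, 0.0], [3.0, 0.0]], label_variances=[4.0, 1.0])
```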

### FM Loss (Flow Matching Loss)

Evaluate the flow matching loss at `t=0.5`:
```
FM-Loss(x) = ||v_θ(0.5, s_0.5) - (s_1 - s_0)||²
```

## Training Models

### Quick Start

```bash
# Train a single Diag-CFM model on Gas Turbine
python -m uq_diagcfm.train train_gas_turbine

# Train on Unifoil
python -m uq_diagcfm.train train_unifoil

# Train on DTLZ (P=50, L=3)
python -m uq_diagcfm.train train_dtlz
```

### Training Ensembles

For reproducible experiments, train multiple models:

```bash
# Gas Turbine: 5x Diag-CFM + 5x CFM
python -m uq_diagcfm.train main_gas_turbine

# Unifoil: 5x Diag-CFM + 5x CFM
python -m uq_diagcfm.train main_unifoil

# DTLZ: 10 models for each P in {12, 24, 50, 100}
python -m uq_diagcfm.train main_dtlz_scaling
```

### INN Baselines

```bash
# Train INN on Gas Turbine
python -m uq_diagcfm.train train_inn_gas_turbine [epochs]
```

## Evaluation

### Performance Evaluation (Forward/Inverse)

Compute forward MSE, round-trip error, and design diversity:

```bash
# Gas Turbine
python -m uq_diagcfm.evaluate_results_gas_turbine

# Unifoil
python -m uq_diagcfm.evaluate_results_unifoil

# DTLZ (specify P)
python -m uq_diagcfm.evaluate_results_dtlz 12
python -m uq_diagcfm.evaluate_results_dtlz 50
```

### Select-Best Evaluation

Evaluate how well each UQ metric selects the best candidate among K generations:

```bash
# Run on all datasets
python -m uq_diagcfm.select_best

# Run on specific dataset
python -m uq_diagcfm.select_best --dataset gas_turbine
python -m uq_diagcfm.select_best --dataset dtlz --P 50
```
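
Conceptually, select-best keeps the candidate with the lowest uncertainty score among the K generations. The following is a minimal sketch with a hypothetical helper name; `select_best.py` may organize the computation differently.

```python
def select_best(candidates, scores):
    """Return (index, candidate) for the lowest uncertainty score."""
    if not candidates or len(candidates) != len(scores):
        raise ValueError("need one score per candidate and at least one candidate")
    best = min(range(len(scores)), key=lambda k: scores[k])
    return best, candidates[best]

# K=3 toy designs with, e.g., their Zero-Deviation scores.
candidates = [[0.1, 0.9], [0.4, 0.6], [0.7, 0.2]]
scores = [0.30, 0.02, 0.11]
best_index, best_design = select_best(candidates, scores)
```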

### Error-Rejection (Abstention) Evaluation

Evaluate UQ metrics for abstention, i.e., rejecting uncertain samples:

```bash
# Run on all datasets
python -m uq_diagcfm.error_rejection --mode error_rejection --dataset all

# Run on specific dataset
python -m uq_diagcfm.error_rejection --mode error_rejection --dataset gas_turbine
```
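
An error-rejection curve can be sketched as follows: sort samples by uncertainty, drop the most uncertain fraction, and report the mean error of what remains. The helper name and the rounding of the retained count are assumptions for illustration and may not match `error_rejection.py`.

```python
def error_rejection_curve(errors, uncertainties, fractions):
    """Mean error of retained samples after rejecting the most uncertain fraction."""
    order = sorted(range(len(errors)), key=lambda i: uncertainties[i])  # most confident first
    curve = []
    for f in fractions:
        n_keep = max(1, round(len(order) * (1.0 - f)))   # always keep at least one sample
        kept = order[:n_keep]
        curve.append(sum(errors[i] for i in kept) / n_keep)
    return curve

errors = [0.1, 0.2, 0.9, 0.4]
uncertainties = [0.05, 0.10, 0.80, 0.30]
curve = error_rejection_curve(errors, uncertainties, fractions=[0.0, 0.25, 0.5])
```

When the uncertainty metric ranks samples well, the curve decreases as the rejected fraction grows, as it does in this toy example.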

### OOD Detection Evaluation

Evaluate UQ metrics for detecting out-of-distribution targets:

```bash
# Gas Turbine
python -m uq_diagcfm.evaluate_uq_gas_turbine [nb_samples] [difficulty]

# Unifoil
python -m uq_diagcfm.evaluate_uq_unifoil [nb_samples] [difficulty]

# DTLZ
python -m uq_diagcfm.evaluate_uq_dtlz [P] [nb_samples] [difficulty]
```

Difficulty levels control how close OOD points are to the training distribution:
- `easy`: Points far from training data (normalized distance 0.15-1.0)
- `medium`: Points at moderate distance (0.05-0.15)
- `hard`: Points just outside the boundary (0.02-0.08)
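
One common way to summarize OOD detection quality is the area under the ROC curve (AUROC) of the uncertainty score, with OOD targets as the positive class. Whether the evaluation scripts report exactly this quantity is not stated here, so the sketch below is an assumption; `auroc` is a hypothetical helper written in plain Python.

```python
def auroc(in_scores, ood_scores):
    """Probability that a random OOD score exceeds a random in-distribution score.

    Ties count as one half. Runs in O(n * m), which is fine for small samples.
    """
    wins = 0.0
    for o in ood_scores:
        for i in in_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))

# Toy uncertainty scores: OOD targets should receive higher values.
value = auroc(in_scores=[0.1, 0.2, 0.3], ood_scores=[0.25, 0.9])
```

A value of 1.0 means every OOD target scores higher than every in-distribution target, and 0.5 corresponds to chance. With scikit-learn installed, `sklearn.metrics.roc_auc_score` computes the same quantity from binary labels and scores.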

### CFM vs Diag-CFM Ablation

Reproduce the ablation study comparing standard CFM and Diag-CFM under different parameter orderings:

```bash
python -m uq_diagcfm.ablation_diag_cfm
```

## Datasets

### Gas Turbine Combustor

- **Design parameters (P=6):** Geometric and operating parameters
- **Labels (L=3):** Unmixedness, pressure loss, thermoacoustic growth rate
- **Source:** See paper references

### Unifoil (Airfoil Aerodynamics)

- **Design parameters (P=14):** POD coefficients for airfoil shape
- **Labels (L=3):** Lift, drag, moment coefficients
- **Conditioning:** Angle of attack, Mach number
- **Source:** See paper references

### DTLZ Benchmark

- **Design parameters (P=12, 24, 50, 100):** Scalable dimension
- **Labels (L=3):** Multi-objective function values
- **Advantage:** Analytical forward function enables exact round-trip error computation
- **Source:** See paper references

## Output Structure

Training runs are saved to `checkpoints/<dataset>/<run_name>/`:
```
checkpoints/
└── gas_turbine/
    └── YYYYMMDD-HHMMSS_diagcfm_unshuffled_hd1024_d3/
        ├── run_info.json          # Hyperparameters and metrics
        └── model_checkpoint.pth   # Model weights
```

Results are saved to `results/`:
```
results/
├── gas_turbine_evaluation.json
├── unifoil_evaluation.json
└── uq/
    ├── select_best_all_datasets.json
    └── uq_ood_*.json
```

## Key Findings

1. **Diag-CFM achieves order-of-magnitude lower round-trip error** than CFM and INN across all datasets (Tables 1-3 in the paper).

2. **Zero-Deviation and Self-Consistency outperform general-purpose UQ metrics** for select-best, error-rejection, and OOD detection tasks.

3. **Zero-Deviation is particularly attractive** because it requires no additional forward passes: it is computed as a byproduct of generation.

## License

This code is provided for research purposes as supplementary material.

## Citation

If you use this code, please cite the accompanying paper (citation will be provided upon publication).
