<div align="center">
<h1>Multi-Marginal Schrödinger Bridge Variational Inference</h1>
<h3>MMSBVI</h3>
</div>

<p align="center">
  <a href="README_CN.md">中文</a>
</p>

<p align="center">
  <a href="#iclr-submission--review-policy">ICLR Submission</a> •
  <a href="#core-concepts">Core Concepts</a> •
  <a href="#architectural-highlights">Architectural Highlights</a> •
  <a href="#installation">Installation</a> •
  <a href="#reproducing-validation">Reproducing Validation</a> •
  <a href="#code-structure">Code Structure</a>
</p>

---

This repository contains the JAX implementation for the paper, ***Geometric Variational Inference: Elliptic Schrödinger Bridges, Anchor Compatibility, and Rates Entropic to Wasserstein***. This project establishes and numerically validates a fundamental equivalence between Variational Inference in path space and a Multi-Marginal Schrödinger Bridge (MMSB) problem, reframing Bayesian smoothing through the lens of optimal transport and information geometry.

## ICLR Submission & Review Policy

- Anonymous submission: this repository corresponds to an anonymous ICLR 2026 submission. Please avoid adding any identifying information (names, affiliations, links) during the review period.
- No new issues/PRs during review: to preserve anonymity and reduce noise, we do not accept new issues or pull requests until the review process concludes. After the decision, the tracker will be re-opened and feedback consolidated.
- Scope: the code is intended for theory-driven diagnostics (not SOTA benchmarking), and baseline implementations are omitted during review.

## Core Concepts

The central thesis of this work is that **the prior is the geometry**. We demonstrate that Bayesian smoothing for continuous-time systems can be viewed as finding a geodesic on a Riemannian manifold whose metric is determined by the reference process.

This is formalized by **Theorem 1 (VI-MMSB Equivalence)**, which proves that minimizing the variational free energy is equivalent to solving a multi-marginal Schrödinger Bridge problem. The objective is to find a path measure $Q$ that minimizes the Kullback-Leibler (KL) divergence to a reference process $P_{\text{ref}}$ (e.g., an Ornstein-Uhlenbeck process), subject to matching a set of target marginals (anchors) $\{\mu_k\}_{k=0}^{K}$ at the observation times $t_0 < \dots < t_K$, which are determined by the observations.

Formally, the paper states the constrained problem as follows, where the anchors are the posterior time marginals $\mu_k=(P_{\text{post}})_{t_k}$:

$$
Q^{*}
  = \underset{Q}{\arg\min}\;
    \mathrm{KL}\bigl(Q \,\|\, P_{\mathrm{ref}}\bigr)
  \quad\text{s.t.}\quad
  Q_{t_k} = \mu_k, \quad k = 0,\dots,K.
$$

The solution to this problem, the posterior path measure $Q^*$, traces a geodesic on the space of probability distributions endowed with the **Onsager–Fokker metric**. This framework unifies classical and modern perspectives, recovering the Rauch–Tung–Striebel (RTS) smoother in the linear–Gaussian case and interpolating between Wasserstein OT (low $\sigma$) and mixture (m‑connection) geometry (high $\sigma$, under compact/discrete Doeblin conditions). Note: mixture (m‑connection) geodesics are not Fisher–Rao Levi–Civita geodesics; a Fisher–Rao component appears only in unbalanced extensions (Hellinger–Kantorovich).
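
For intuition about the linear–Gaussian case, the following plain-Python sketch runs a scalar Kalman filter followed by the RTS backward pass for an Ornstein–Uhlenbeck prior observed in Gaussian noise. In that setting the smoothed marginals $\mathcal{N}(m^s_k, P^s_k)$ are exactly the posterior time marginals that play the role of the anchors $\mu_k$. The model (exact OU discretization with hypothetical parameters `theta`, `sigma`, `dt`, observation noise variance `r`) and the function names are illustrative assumptions, not the repository's API.

```python
import math


def ou_transition(theta, sigma, dt):
    """Exact discretization of dX = -theta X dt + sigma dW over a step dt (illustrative)."""
    a = math.exp(-theta * dt)
    q = sigma ** 2 / (2.0 * theta) * (1.0 - a ** 2)
    return a, q


def kalman_rts(ys, a, q, r, m0, p0):
    """Scalar Kalman filter + RTS smoother for x_{k+1} = a x_k + w_k, y_k = x_k + v_k."""
    mp, pp, mf, pf = [], [], [], []  # predicted / filtered means and variances
    m, p = m0, p0
    for k, y in enumerate(ys):
        if k > 0:  # predict through the OU transition
            m, p = a * m, a * a * p + q
        mp.append(m)
        pp.append(p)
        gain = p / (p + r)  # measurement update
        m, p = m + gain * (y - m), (1.0 - gain) * p
        mf.append(m)
        pf.append(p)
    ms, ps = mf[:], pf[:]  # backward pass starts from the last filtered estimate
    for k in range(len(ys) - 2, -1, -1):
        c = pf[k] * a / pp[k + 1]  # RTS smoother gain
        ms[k] = mf[k] + c * (ms[k + 1] - mp[k + 1])
        ps[k] = pf[k] + c * c * (ps[k + 1] - pp[k + 1])
    return ms, ps, mf, pf


a, q = ou_transition(theta=1.0, sigma=0.5, dt=0.1)
ms, ps, mf, pf = kalman_rts([0.3, 0.1, -0.2, 0.4, 0.0], a, q, r=0.05, m0=0.0, p0=1.0)
```

Each pair `(ms[k], ps[k])` is the mean and variance of the smoothed marginal at time `t_k`; the smoothed variance never exceeds the filtered one because the backward pass folds in later observations.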

This repository provides a high-precision implementation of the **Iterative Proportional Fitting Procedure (IPFP)** for this problem, serving as a tool for rigorous numerical validation of the theoretical findings. The paper also outlines a neural-network-based control approach as a direction for tackling higher-dimensional problems.
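
To make the IPFP idea concrete, here is a minimal NumPy sketch of multi-marginal IPFP on a 1D grid. It assumes a Markov reference given by row-stochastic Gaussian transition matrices between consecutive anchor times, a uniform initial law absorbed into the first potential, and strictly positive target marginals. The solution then has the form $Q(x_0,\dots,x_K) \propto \prod_{k=0}^{K}\varphi_k(x_k)\prod_{k=0}^{K-1}T_k(x_k,x_{k+1})$, and each IPFP step rescales one potential $\varphi_k$ so that the $k$-th marginal, computed by forward–backward message passing, matches $\mu_k$ exactly. All names are hypothetical; this is not the code in `ipfp_1d.py`.

```python
import numpy as np


def gaussian_transition(x, var):
    """Row-stochastic kernel T[i, j] proportional to exp(-(x_j - x_i)^2 / (2 var))."""
    T = np.exp(-((x[None, :] - x[:, None]) ** 2) / (2.0 * var))
    return T / T.sum(axis=1, keepdims=True)


def marginals(phis, Ts):
    """Time marginals of Q ∝ prod_k phi_k(x_k) * prod_k T_k(x_k, x_{k+1})."""
    fwd = [phis[0] / phis[0].sum()]
    for T, phi in zip(Ts, phis[1:]):
        msg = (fwd[-1] @ T) * phi
        fwd.append(msg / msg.sum())  # rescaling only changes each marginal by a constant
    bwd = [np.ones_like(phis[-1])]
    for T, phi in zip(reversed(Ts), reversed(phis[1:])):
        msg = T @ (phi * bwd[0])  # bwd_k = T_k @ (phi_{k+1} * bwd_{k+1})
        bwd.insert(0, msg / msg.sum())
    out = [f * b for f, b in zip(fwd, bwd)]
    return [m / m.sum() for m in out]


def ipfp(mus, Ts, n_sweeps=5000, tol=1e-10):
    """Cyclic IPFP: rescale phi_k so that the k-th marginal equals mu_k, k = 0..K."""
    phis = [np.ones_like(mu) for mu in mus]
    err = np.inf
    for _ in range(n_sweeps):
        for k in range(len(mus)):
            phis[k] = phis[k] * mus[k] / marginals(phis, Ts)[k]
        err = max(np.abs(m - mu).sum() for m, mu in zip(marginals(phis, Ts), mus))
        if err < tol:
            break
    return phis, err


x = np.linspace(-3.0, 3.0, 61)


def gaussian(mean, std):
    g = np.exp(-0.5 * ((x - mean) / std) ** 2)
    return g / g.sum()


mus = [gaussian(-1.0, 0.5), gaussian(0.5, 0.6), gaussian(1.0, 0.4)]
Ts = [gaussian_transition(x, 0.5), gaussian_transition(x, 0.5)]
phis, err = ipfp(mus, Ts)
```

Recomputing all messages for every single-marginal update keeps the sketch short at the cost of O(K) extra work per step; caching messages is the obvious optimization.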

## Architectural Highlights

The architecture of this project integrates principles of academic research with modern machine learning engineering.

1.  **Solver Architecture**
    *   **Classical Grid Solver (`ipfp_1d.py`, `ipfp_2d.py`)**: An Iterative Proportional Fitting Procedure (IPFP) based on the Sinkhorn algorithm. It provides high-precision solutions for low-dimensional problems and is used for theoretical validation.

2.  **Highly Modular & Extensible**
    *   **Type System (`types.py`)**: Utilizes `chex.dataclass` and `jaxtyping` to define the type system, decoupling core concepts like the problem definition (`MMSBProblem`), algorithm configurations (`IPFPConfig`, `ControlGradConfig`), and the solution (`MMSBSolution`).
    *   **Component Registry (`registry.py`)**: Implements a factory pattern that allows for dynamic registration and loading of different solvers, networks, and integrators via string names, managed through configuration files (e.g., Hydra).

3.  **High-Performance Computing**
    *   The entire codebase is built on JAX, using its `jit`, `vmap`, and `pmap` transformations for parallel computing and GPU acceleration.
    *   In addition to standard XLA optimizations, the 2D MMSB path supports optional custom GPU kernels written in Pallas and a compiled main loop to further reduce overhead:
        - Batched 1D column normalization for transition kernels using a 2D grid over `(dt × columns)` with row tiling; fuses `exp + trapezoid weighting` per column.
        - 2D fused normalization (clip + 2D trapezoid integration + normalization) with a tiled variant (two-stage: per-tile local mass accumulation + global scaling) for large grids.
        - Fully compiled IPFP main loop via `lax.fori_loop`, with in-graph error checks and epsilon scheduling to reduce Python control-flow and kernel launch overhead.

## Installation

### Environment Setup
We recommend Python 3.10–3.11 and `pip`. Choose exactly one of the following paths (do not install both CPU and GPU requirements in the same environment).

Option A — one‑liner (auto‑detects GPU/CPU):
```bash
python setup_environment.py
```

Option B — manual (CPU only):
```bash
pip install -r requirements-cpu.txt
```

Option C — manual (GPU with CUDA 12.x):
1) Install JAX/jaxlib matching your CUDA version, following the official JAX docs (recommended):
   https://github.com/google/jax#pip-installation (use the CUDA 12 wheel)

   Example (subject to your platform and CUDA toolchain):
   ```bash
   # Refer to the JAX docs for the exact command for your CUDA version
   pip install --upgrade "jax[cuda12]==0.6.2"
   ```
2) Then install the project dependencies (without re-installing JAX):
   ```bash
   pip install -r requirements-gpu.txt --no-deps
   ```

Notes
- Do not combine `requirements-cpu.txt` and `requirements-gpu.txt` in the same environment.
- If you encounter resolution issues with JAX wheels, defer to the JAX installation guide and then use `--no-deps` for the project requirements.
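
After installation, a quick way to confirm which backend JAX actually picked up (for example, that a GPU setup is not silently running on a CPU-only jaxlib) is a short check like the one below. It uses only public `jax` calls and is not one of the project's scripts.

```python
import jax


def backend_summary():
    """Return JAX's default backend name and the platform of each visible device."""
    return jax.default_backend(), [d.platform for d in jax.devices()]


backend, platforms = backend_summary()
print(f"jax {jax.__version__}: backend={backend}, devices={platforms}")
```

On a working CUDA install the backend should read `gpu`; `cpu` usually means a CPU-only jaxlib is being used.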

### Core Dependencies
*   **JAX Ecosystem**: `jax`, `jaxlib`, `flax`, `optax`, `chex`
*   **Optimal Transport**: `ott-jax`
*   **Scientific Computing**: `numpy`, `scipy`
*   **Configuration**: `hydra-core`

### Running Core Tests
To verify that the environment is set up correctly, please run the test suite:
```bash
pytest tests/
```
All test cases should pass.

## Performance Options (2D IPFP)

### Enable Pallas Kernels

```python
from src.mmsbvi.core.types import IPFP2DConfig

config = IPFP2DConfig(
    use_pallas_kernels=True,   # turn on Pallas path
    pallas_norm_tiled=True,    # use tiled fused 2D normalization (recommended for large grids)
    pallas_tile_i=64,          # optional: row tile
    pallas_tile_j=64,          # optional: col tile
    pallas_block_rows=128,     # optional: row blocking for batched 1D normalization
)
```

Requirements: a CUDA-enabled JAX installation (`jax[cuda12]` ≥ 0.6.2) with `jax.experimental.pallas` available. If either is missing, the code automatically falls back to the standard JAX/XLA path.
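
Functionally, the two normalizations listed under Architectural Highlights (batched column normalization with `exp` + trapezoid weighting, and the clipped 2D trapezoid normalization with its two-stage tiled variant) can be written as a plain NumPy reference computation. This is an unfused sketch for checking results, assuming uniform grid spacing and a hypothetical clipping floor `floor`; it is not the Pallas kernel or the project's XLA fallback, and `tile_i`/`tile_j` only mirror the `pallas_tile_i`/`pallas_tile_j` options in spirit.

```python
import numpy as np


def trapezoid_weights(n, h):
    """Composite trapezoid weights on a uniform grid with spacing h."""
    w = np.full(n, h, dtype=float)
    w[0] = w[-1] = 0.5 * h
    return w


def normalize_columns(log_kernel, h):
    """Batched column normalization; log_kernel has shape (n_dt, n, n).

    Each column [:, j] of each slice is exponentiated and rescaled so that its
    trapezoid integral over the row index equals 1.
    """
    w = trapezoid_weights(log_kernel.shape[1], h)
    K = np.exp(log_kernel - log_kernel.max(axis=1, keepdims=True))  # stabilized exp
    mass = np.einsum("i,dij->dj", w, K)  # trapezoid mass per (dt, column)
    return K / mass[:, None, :]


def normalize_2d(rho, hx, hy, floor=1e-300):
    """Clip, integrate with the 2D trapezoid rule, and normalize."""
    rho = np.maximum(rho, floor)
    W = np.outer(trapezoid_weights(rho.shape[0], hx), trapezoid_weights(rho.shape[1], hy))
    return rho / np.sum(W * rho)


def normalize_2d_tiled(rho, hx, hy, tile_i=64, tile_j=64, floor=1e-300):
    """Two-stage variant: accumulate per-tile masses, then apply one global scale."""
    rho = np.maximum(rho, floor)
    W = np.outer(trapezoid_weights(rho.shape[0], hx), trapezoid_weights(rho.shape[1], hy))
    total = 0.0
    for i0 in range(0, rho.shape[0], tile_i):
        for j0 in range(0, rho.shape[1], tile_j):
            tile = (slice(i0, i0 + tile_i), slice(j0, j0 + tile_j))
            total += np.sum(W[tile] * rho[tile])  # stage 1: local tile mass
    return rho / total  # stage 2: global scaling


x = np.linspace(-2.0, 2.0, 150)
h = x[1] - x[0]
dts = np.array([0.05, 0.1, 0.2])
log_kernel = -((x[None, :, None] - x[None, None, :]) ** 2) / (2.0 * dts[:, None, None])
K_norm = normalize_columns(log_kernel, h)

rho = np.exp(-(x[:, None] ** 2 + x[None, :] ** 2))
p_full = normalize_2d(rho, h, h)
p_tiled = normalize_2d_tiled(rho, h, h, tile_i=64, tile_j=64)
```

A grid size that is not a multiple of the tile size simply yields smaller edge tiles; both 2D variants return the same density up to floating-point summation order.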

### Enable Compiled Main Loop

```python
from src.mmsbvi.core.types import IPFP2DConfig

config = IPFP2DConfig(
    compiled_loop=True,
    compiled_max_iterations=1000,   # optional: override max iters
    compiled_check_interval=10,     # optional: override check interval
)
```
```

The compiled loop keeps the full IPFP iteration inside the computation graph, including convergence checks and epsilon scheduling, minimizing host-device round-trips.
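
The pattern behind a compiled main loop can be illustrated on a much simpler problem. The sketch below runs log-domain Sinkhorn between two discrete histograms entirely inside `jax.lax.fori_loop`: the geometric epsilon schedule, the periodic convergence check (via `jax.lax.cond`), and the stop flag are all part of the traced graph, so no Python control flow runs per iteration. It is a standalone illustration with hypothetical names and defaults, not the repository's `compiled_loop` implementation or its MMSB updates.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


@partial(jax.jit, static_argnames=("max_iters", "check_interval"))
def compiled_sinkhorn(a, b, C, eps_init=1.0, eps_final=0.05, decay=0.9,
                      tol=1e-5, max_iters=1000, check_interval=10):
    log_a, log_b = jnp.log(a), jnp.log(b)

    def plan(f, g, eps):
        # Coupling P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps).
        return jnp.exp(log_a[:, None] + log_b[None, :]
                       + (f[:, None] + g[None, :] - C) / eps)

    def body(i, state):
        f, g, err, done = state
        # Geometric epsilon schedule, floored at eps_final.
        eps = jnp.maximum(eps_final, eps_init * decay ** i)

        def update(_):
            # Log-domain Sinkhorn half-steps: match rows, then columns.
            f_up = -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
            g_up = -eps * logsumexp(log_a[:, None] + (f_up[:, None] - C) / eps, axis=0)
            return f_up, g_up

        # Once converged, carry the potentials through unchanged.
        f_new, g_new = jax.lax.cond(done, lambda _: (f, g), update, None)

        def check(_):
            # Row marginals are the ones not enforced by the last half-step.
            return jnp.max(jnp.abs(plan(f_new, g_new, eps).sum(axis=1) - a))

        # Evaluate the error only every `check_interval` iterations.
        err_new = jax.lax.cond((i + 1) % check_interval == 0, check, lambda _: err, None)
        done_new = done | ((err_new < tol) & (eps <= eps_final))
        return f_new, g_new, err_new, done_new

    init = (jnp.zeros_like(a), jnp.zeros_like(b),
            jnp.asarray(jnp.inf, dtype=a.dtype), jnp.asarray(False))
    f, g, err, done = jax.lax.fori_loop(0, max_iters, body, init)
    return plan(f, g, eps_final), err, done


x = jnp.linspace(0.0, 1.0, 5)
a = jnp.full(5, 0.2)
C = (x[:, None] - x[None, :]) ** 2
P, err, done = compiled_sinkhorn(a, a, C)
```

Because `fori_loop` has a fixed trip count, iterations after convergence still execute but leave the state unchanged; `jax.lax.while_loop` is the alternative when a true early exit matters more than a static loop bound.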

## Reproducing Validation

The key theoretical validations and figures from the paper can be reproduced with scripts located in the `automation/` directory.

### Complete Validation Suite
To run all validation workflows in sequence, execute the main script. This will reproduce the figures and numerical results.
```bash
chmod +x automation/run_complete_validation_suite.sh
./automation/run_complete_validation_suite.sh
```

### Individual Validation Workflows
You can also run each validation workflow independently:
*   **RTS Equivalence Validation**: Verifies that the MMSB solution agrees with the Rauch–Tung–Striebel (RTS) smoother in the linear–Gaussian case.
    ```bash
    ./automation/run_rts_equivalence_workflow.sh
    ```
*   **Geometric Limits Validation**: Explores how the Schrödinger bridge converges to a deterministic optimal transport path as the noise term approaches zero.
    ```bash
    ./automation/run_geometric_limits_workflow.sh
    ```
*   **Parameter Sensitivity Analysis**: Analyzes the sensitivity of the model's performance to key parameters, such as regularization strength and time step size.
    ```bash
    ./automation/run_parameter_sensitivity_workflow.sh
    ```
The generated results will be saved in the `results/` directory, organized by experiment type.

## Code Structure

The project is structured to separate the core algorithms from the experimental validation scripts.

```
src/mmsbvi/
├── core/                    # Core type definitions, configs, and component registry
├── algorithms/              # Core algorithm implementations (IPFP, Neural Control)
├── solvers/                 # Numerical solvers (PDE, Gaussian Kernel)
├── integrators/             # SDE numerical integration schemes
├── nets/                    # Neural network architectures (Flax)
├── utils/                   # Utility functions (logging, config)
└── configs/                 # Hydra configuration files

theoretical_verification/    # Scripts for 1D theoretical validation experiments
tests/                       # Unit and integration tests
automation/                  # Shell scripts for validation workflows
```

---

<div align="center">
This repository is licensed under the MIT License.
</div>
