<div align="center">
<h1>MMSBVI: Multi-Marginal Schrödinger Bridge Variational Inference</h1>
<h3>多边际薛定谔桥变分推断</h3>
</div>

<p align="center">
  <a href="#iclr-提交与评审政策">ICLR 提交</a> •
  <a href="#核心概念">核心概念</a> •
  <a href="#架构特色">架构特色</a> •
  <a href="#安装">安装</a> •
  <a href="#复现验证">复现验证</a> •
  <a href="#代码结构">代码结构</a>
</p>

---

本仓库是论文 ***Geometric Variational Inference: Elliptic Schrödinger Bridges, Anchor Compatibility, and Rates Entropic to Wasserstein*** 的官方 JAX 实现。本项目旨在建立并数值验证一个核心理论：路径空间中的变分推断 (Variational Inference) 与一个多边际薛定谔桥 (Multi-Marginal Schrödinger Bridge, MMSB) 问题存在基础等价性。这一发现将经典的贝叶斯平滑问题置于最优传输与信息几何的统一视角下进行审视。

## ICLR 提交与评审政策

- 匿名提交：本仓库对应 ICLR 2026 匿名投稿。评审期内请勿添加任何可识别个人/机构的信息（姓名、单位、外链等）。
- 评审期不接受新 Issues/PR：为保持匿名与减少噪声，评审结束前不接收新的 Issues 或 Pull Requests；待评审结果公布后统一开启并处理反馈。
- 代码范围：本仓库主要用于理论诊断（非 SOTA 基准），评审期内不包含基线实现。

## 核心思想

本工作的核心论点是：**先验即几何 (The Prior is the Geometry)**。我们证明了连续时间系统中的贝叶斯平滑问题，等价于在某个黎曼流形上寻找一条测地线，而该流形的度量完全由先验参考过程决定。

这一思想由我们的核心成果 **定理1 (VI-MMSB 等价性)** 精确阐述。该定理证明了，最小化变分自由能的目标，与求解一个多边际薛定谔桥问题是完全等价的。该问题的目标是寻找一个路径测度 $Q$，使其与一个参考过程 $P_{\text{ref}}$ (例如 Ornstein-Uhlenbeck 过程) 的 KL 散度 (Kullback-Leibler divergence) 最小，同时满足其在一系列观测时间点的边际分布恰好是给定的目标边际 $\{\rho_{t_k}^{\text{obs}}\}$。

与论文一致的形式化描述如下（锚点为后验时间边际 $\mu_k=(P_{\text{post}})_{t_k}$）：

$$
Q^{*}
  = \underset{Q}{\text{arg min}}
    \mathrm{KL}\bigl(Q | P_{\mathrm{ref}}\bigr)
  \quad\text{s.t.}\quad
  Q_{t_k} = \mu_k, k = 0,\dots,K.
$$

此问题的解，即后验路径测度 $Q^*$，其演化轨迹是在由 **Onsager–Fokker 度量** 所赋予几何结构的概率分布空间中的一条测地线。该框架统一了经典与现代观点：在线性高斯情形可恢复 Rauch–Tung–Striebel (RTS) 平滑器；在参数极限下，于低噪声 ($\sigma\to0$) 收敛到 Wasserstein OT 的位移测地线；于高噪声 ($\sigma\to\infty$) 且在有界域/离散网格并满足 Doeblin 下界的条件下，收敛到混合（m‑connection）几何的测地线。注意：混合（m‑connection）测地线并非 Fisher–Rao Levi‑Civita 测地线；只有在引入非平衡扩展（Hellinger–Kantorovich）时才出现 Fisher–Rao 分量。

本仓库提供了一个高精度的**迭代比例拟合算法 (IPFP)** 实现，作为严格数值验证上述理论发现的核心工具。

## 架构亮点

本项目的架构设计融合了学术研究的严谨性与现代机器学习的工程实践。

1.  **求解器架构 (Solver Architecture)**
    *   **经典网格求解器 (`ipfp_1d.py`)**: 基于 Sinkhorn 算法的迭代比例拟合过程 (IPFP)，为低维问题提供高精度解，用于理论验证。

2.  **高度模块化与可扩展性 (Highly Modular & Extensible)**
    *   **类型系统 (`types.py`)**: 使用 `chex.dataclass` 和 `jaxtyping` 定义类型系统，将核心概念如问题定义 (`MMSBProblem`)、算法配置 (`IPFPConfig`, `ControlGradConfig`) 和解 (`MMSBSolution`) 等进行解耦。
    *   **组件注册表 (`registry.py`)**: 采用工厂模式，允许通过字符串名称动态注册和加载不同的求解器、网络和积分器，并通过配置文件（如 Hydra）进行管理。

3.  **高性能计算 (High-Performance Computing)**
    *   整个代码库基于 JAX 构建，使用其 `jit`, `vmap`, `pmap` 等变换进行并行计算和GPU加速。
    *   2D MMSB 路径使用 Pallas 自定义 CUDA 内核与编译化主循环，使性能提升到工程极限：
        - 可选 Pallas 批量 1D 列归一化（`(dt×列)` 二维 grid + 行分块），融合 `exp + 梯形加权`；
        - 可选 2D 融合归一化（裁剪 + 二维梯形积分 + 归一化）含 tiled 版本（两阶段：分块局部质量 + 全局缩放）；
        - 完整编译化 IPFP 主循环（`lax.fori_loop`），内图误差评估与 ε 调度。

## 安装

### 环境配置
推荐使用 Python 3.10–3.11 与 `pip`。请在同一环境中“二选一”安装（不要同时安装 CPU 与 GPU 依赖）。

方案 A —— 一键安装（自动检测 GPU/CPU）
```bash
python setup_environment.py
```

方案 B —— 手动（仅 CPU）
```bash
pip install -r requirements-cpu.txt
```

方案 C —— 手动（GPU，CUDA 12.x）
1）按照官方 JAX 文档安装与你的 CUDA 匹配的 JAX/JAXLIB（强烈建议）：
   https://github.com/google/jax#pip-installation （选择 CUDA 12 轮子）

   示例（请以官方文档为准，按你的平台/CUDA 调整）：
   ```bash
   pip install --upgrade "jax==0.6.2" "jaxlib==0.6.2"
   ```
2）然后安装本项目依赖（避免重复安装 jax）：
```bash
pip install -r requirements-gpu.txt --no-deps
```

注意
- 不要在同一环境中同时安装 `requirements-cpu.txt` 与 `requirements-gpu.txt`。
- 如遇 JAX 轮子解析问题，请先按 JAX 官方文档安装 JAX，再使用 `--no-deps` 安装项目依赖。

### 核心依赖
*   **JAX 生态**: `jax`, `jaxlib`, `flax`, `optax`, `chex`
*   **最优运输**: `ott-jax`
*   **科学计算**: `numpy`, `scipy`
*   **配置**: `hydra-core`

### 运行核心测试
为确保环境配置正确，请运行测试套件：
```bash
pytest tests/
```
所有测试用例均应通过。

## 性能选项（2D IPFP）

### 启用 Pallas 内核

```python
from src.mmsbvi.core.types import IPFP2DConfig

config = IPFP2DConfig(
    use_pallas_kernels=True,   # 开启Pallas路径
    pallas_norm_tiled=True,    # 2D 融合归一化使用分块实现（大网格推荐）
    pallas_tile_i=64,          # 可选：行方向tile
    pallas_tile_j=64,          # 可选：列方向tile
    pallas_block_rows=128,     # 可选：1D批量归一化的行块大小
)
```

依赖：确保 `jax[cuda]` 版本≥0.6.2，且 `jax.experimental.pallas` 可用；若环境不支持，会自动回退到原生 JAX/XLA 实现。

### 启用编译化主循环

```python
config = IPFP2DConfig(
    compiled_loop=True,
    compiled_max_iterations=1000,   # 可选：覆盖最大迭代次数
    compiled_check_interval=10,     # 可选：覆盖检查间隔
)
```

说明：该选项将IPFP主循环完全置于图中，减少Python控制流与kernel启动开销；误差评估与ε调度以图内逻辑执行。

### 数值稳定与精度策略

- 对数核默认在构造阶段使用 float64 以增强稳定性；在应用阶段以 `compute_dtype`（默认 float32）执行并通过 `matmul_precision("high")` 走 TF32 路径；
- 最终密度归一过程在 Pallas 融合核中严格保证质量为1，避免累积误差。

## 复现验证

论文中的关键理论验证和图表，可通过 `automation/` 目录下的脚本复现。

### 完整的验证套件
要按顺序运行所有验证工作流，请执行主脚本。这将复现图表和数值结果。
```bash
chmod +x automation/run_complete_validation_suite.sh
./automation/run_complete_validation_suite.sh
```

### 单个验证工作流
您也可以独立运行每个验证工作流：
*   **RTS等价性验证**: 验证MMSB解在特定条件下与Rauch-Tung-Striebel (RTS)平滑器的一致性。
    ```bash
    ./automation/run_rts_equivalence_workflow.sh
    ```
*   **几何极限验证**: 探索当噪声趋于零时，薛定谔桥如何收敛到确定性的最优传输路径。
    ```bash
    ./automation/run_geometric_limits_workflow.sh
    ```
*   **参数敏感性分析**: 分析模型性能对关键参数（如正则化强度、时间步长）的敏感度。
    ```bash
    ./automation/run_parameter_sensitivity_workflow.sh
    ```
生成的结果将按实验类型保存在 `results/` 目录中。

## 代码结构

项目结构旨在将核心算法与实验验证脚本清晰分离。

```
src/mmsbvi/
├── core/                    # 核心类型定义、配置和组件注册表
├── algorithms/              # 核心算法实现 (IPFP, Neural Control)
├── solvers/                 # 数值求解器 (PDE, Gaussian Kernel)
├── integrators/             # SDE数值积分格式
├── nets/                    # 神经网络架构 (Flax)
├── utils/                   # 工具函数 (日志、配置)
└── configs/                 # Hydra配置文件

theoretical_verification/    # 1D理论验证实验脚本
tests/                       # 单元和集成测试
automation/                  # 一键复现工作流的Shell脚本
```

---

<div align="center">
本仓库基于MIT许可证。
</div>
