# Quantization-based Bounds on the Wasserstein Metric

This repository contains the implementation of several approximation methods for computing bounds on the Wasserstein distance between probability measures on regular grids, as described in our paper submission. The implementation is built on JAX for efficient computation and automatic differentiation.
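As a point of reference for the quantity being bounded: in one dimension, the exact W_1 between two histograms on the same regular grid with spacing h has a closed form. It is the L1 distance between their cumulative distribution functions, `W_1 = h * sum_i |F_a(i) - F_b(i)|`. The sketch below is an illustration only. `wasserstein1_1d` is a hypothetical helper, not part of `ot_jax`, and it can serve as a sanity check on small 1D inputs.

```python
import jax.numpy as jnp

def wasserstein1_1d(a, b, h=1.0):
    """Exact W_1 between histograms a and b on the same 1D regular grid.

    Hypothetical helper for illustration. a, b: non-negative weights
    summing to 1; h: grid spacing.
    """
    # Difference of the two cumulative distribution functions at each grid point
    cdf_diff = jnp.cumsum(a) - jnp.cumsum(b)
    # The CDFs are constant between grid points, so integrating |F_a - F_b|
    # is a sum over the n - 1 gaps of width h (the last entry is ~0 anyway).
    return h * jnp.sum(jnp.abs(cdf_diff[:-1]))

a = jnp.array([0.5, 0.5, 0.0, 0.0])
b = jnp.array([0.0, 0.0, 0.5, 0.5])
print(wasserstein1_1d(a, b))  # → 2.0 (each half unit of mass moves two cells)
```

This closed form only holds in 1D; on 2D grids the exact distance requires solving a transport problem, which is what makes cheaper bounds useful.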

## Project Structure

- `ot_jax/optimal_transport/jax_wasserstein.py`: Core implementation of the approximation methods
- `ot_jax/benchmark/`: Benchmarking code for evaluating the methods
- `data/`: Experimental data and datasets
- `output/`: Generated results and figures

## Methods Implemented

The following approximation methods are implemented:

1. Bilevel Lower Bound
2. Min-Cost Lower Bound
3. Bilevel Upper Bound
4. Weighted Cost Upper Bound
5. Entropy Upper Bound
6. Entropy Lower Bound
7. Exact Wasserstein Distance (for comparison)
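The entropy-based bounds (items 5 and 6) most likely relate to entropic regularization of optimal transport. As generic background only, not the repository's formulation, the sketch below runs plain Sinkhorn scaling on a small 1D grid. `sinkhorn_plan` and its `eps`/`n_iter` defaults are hypothetical choices for this example. After the final `v` update, the plan's column sums match `b` exactly and its row sums approach `a`. Because any feasible coupling costs at least as much as the optimal one, the entropic cost `<P, C>` sits at or above the exact transport cost once the marginals have converged.

```python
import jax
import jax.numpy as jnp

def sinkhorn_plan(a, b, C, eps=0.2, n_iter=2000):
    """Entropic OT plan between histograms a and b for cost matrix C.

    Hypothetical helper: plain (non-log-domain) Sinkhorn scaling, adequate
    for small, well-scaled cost matrices like the one below.
    """
    K = jnp.exp(-C / eps)  # Gibbs kernel

    def step(_, uv):
        u, v = uv
        u = a / (K @ v)    # rescale rows toward marginal a
        v = b / (K.T @ u)  # rescale columns to match marginal b
        return u, v

    init = (jnp.ones_like(a), jnp.ones_like(b))
    u, v = jax.lax.fori_loop(0, n_iter, step, init)
    return u[:, None] * K * v[None, :]

# Two histograms on a 5-point regular grid in [0, 1], ground cost |x_i - x_j|
x = jnp.linspace(0.0, 1.0, 5)
C = jnp.abs(x[:, None] - x[None, :])
a = jnp.array([0.4, 0.3, 0.1, 0.1, 0.1])
b = jnp.array([0.1, 0.1, 0.1, 0.3, 0.4])

P = sinkhorn_plan(a, b, C)
entropic_cost = jnp.sum(P * C)  # <P, C>, an upper estimate of the exact W_1
print(entropic_cost)
```

Smaller `eps` brings the entropic cost closer to the exact value but slows convergence and, in this non-log-domain form, risks underflow in the kernel.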

## Installation

### Prerequisites

- Python 3.12 or higher
- For GPU support on Linux: CUDA 12.x
- For macOS: Metal support (included with OS)

### Using uv (Recommended)

[uv](https://github.com/astral-sh/uv) is a fast Python package installer. To install using uv:

```bash
# Install uv if you haven't already
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create and activate a new virtual environment
uv venv
source .venv/bin/activate  # On Unix/macOS
# or
.venv\Scripts\activate  # On Windows

# Install dependencies
uv sync
```

### Using pip

```bash
# Create and activate a new virtual environment
python -m venv .venv
source .venv/bin/activate  # On Unix/macOS
# or
.venv\Scripts\activate  # On Windows

# Install dependencies
pip install -e .
```

## Platform-Specific Notes

- **macOS**: JAX will use Metal for acceleration by default
- **Linux**: CUDA support will be installed automatically for GPU acceleration
- **Windows**: Limited support; we recommend using WSL2 for better compatibility
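To check which accelerator JAX actually picked up after installation, the snippet below uses JAX's public `jax.devices()` and `jax.default_backend()` functions. The device names it prints depend on your hardware and installed plugins.

```python
import jax

# All devices visible to JAX (only CPU devices if no accelerator was found)
print(jax.devices())
# Name of the default backend, e.g. "cpu" or "gpu"
print(jax.default_backend())
```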

## Running Experiments

The benchmarking code is located in the `ot_jax/benchmark` directory. 

Example usage:

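```python
import jax
import jax.numpy as jnp
from ot_jax.optimal_transport.jax_wasserstein import bilevel_lower_bound

# Two probability measures on a 32x32 regular grid. The input format
# (non-negative 2D arrays summing to 1) is an assumption; pass your own data.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (32, 32))
x = x / x.sum()  # normalize to a probability measure
y = jax.random.uniform(key_y, (32, 32))
y = y / y.sum()

# Bilevel lower bound on the Wasserstein distance with p = 2
distance = bilevel_lower_bound(x, y, p=2, scale_factor=4)
```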
