# Robust Optimal Transport (ROT)

This repository contains a modular Python implementation of the Robust Optimal Transport algorithm, converted from the original Jupyter notebook implementation.

## Overview

The Robust Optimal Transport (ROT) algorithm is a computationally efficient approach to optimal transport problems with robustness constraints, in which part of the source mass may be left untransported. This implementation is multi-scale. It solves a sequence of approximate problems with a decreasing approximation parameter δ and refines the transport plan between the source and target distributions at each scale.

## Project Structure

```
src/
├── __init__.py                 # Package initialization
├── utils.py                    # Core utilities (distance calculations, point generation)
├── discrete_set.py             # Discrete set generation and management
├── search_algorithms.py        # DFS-based search algorithms
├── transport_optimization.py   # Main optimization procedures
├── robust_ot.py                # Main ROT algorithm implementation
└── visualization.py            # Plotting and visualization functions

main.py                         # Example script demonstrating usage
README.md                       # This file
```

## Features

- **Modular Design**: Clean separation of concerns with focused modules
- **Type Hints**: Full type annotations for better code documentation
- **Comprehensive Visualization**: Multiple plotting functions for analysis
- **Algorithm Statistics**: Detailed tracking of algorithm performance
- **POT Library Comparison**: Optional comparison with Python Optimal Transport library
- **Example Usage**: Complete example script replicating notebook functionality

## Requirements

### Required Dependencies
```text
numpy
matplotlib
```

### Optional Dependencies
```text
pot  # For comparison with the POT library (install with: pip install pot)
```

## Installation

Install the dependencies:
```bash
pip install numpy matplotlib   # required
pip install pot                # optional: enables the POT comparison
```
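
If POT is not installed, code that compares against it can guard the import instead of failing. The sketch below is an illustration only; the flag name `HAS_POT` is hypothetical and not necessarily what `main.py` uses. The one fixed fact is that the POT package is imported under the module name `ot`.

```python
# Optional dependency: the POT package is imported under the module name "ot".
try:
    import ot
    HAS_POT = True
except ImportError:
    ot = None  # placeholder so that only the HAS_POT flag gates POT usage
    HAS_POT = False

if HAS_POT:
    print("POT available, version", ot.__version__)
else:
    print("POT not installed; skipping the POT comparison")
```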

## Usage

### Quick Start

Run the example script that replicates the original notebook:

```bash
python main.py
```

This will:
1. Generate source points on a 200x200 grid
2. Create a mass distribution with Gaussian center + exponential noise
3. Generate 30 target points
4. Run the ROT algorithm
5. Display visualizations and statistics
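
The data in steps 1 to 3 can be approximated with plain NumPy, for example to test the pipeline without the package. This is a hypothetical sketch. The function names, the unit-square domain, the Gaussian centre at (0.5, 0.5), the `noise_weight` mixing rule and the uniform target masses are all assumptions. It does not reproduce `generate_grid_points`, `generate_mass_distribution` or `generate_target_points`.

```python
import numpy as np

def make_grid(grid_size: int) -> np.ndarray:
    """Points of a grid_size x grid_size grid over [0, 1]^2, shape (grid_size**2, 2)."""
    xs = np.linspace(0.0, 1.0, grid_size)
    gx, gy = np.meshgrid(xs, xs)  # row-major order: x varies fastest
    return np.column_stack([gx.ravel(), gy.ravel()])

def make_mass(points: np.ndarray, sigma: float = 0.15, noise_weight: float = 0.1,
              seed=None) -> np.ndarray:
    """Gaussian bump centred at (0.5, 0.5) mixed with exponential noise; sums to 1."""
    rng = np.random.default_rng(seed)
    sq_dist = np.sum((points - 0.5) ** 2, axis=1)
    bump = np.exp(-sq_dist / (2.0 * sigma**2))
    noise = rng.exponential(scale=1.0, size=len(points))
    # Normalise each part separately, then mix them with the given weight.
    return (1.0 - noise_weight) * bump / bump.sum() + noise_weight * noise / noise.sum()

def make_targets(n: int, seed=None):
    """n uniform random points in [0, 1]^2 with equal masses summing to 1."""
    rng = np.random.default_rng(seed)
    return rng.uniform(0.0, 1.0, size=(n, 2)), np.full(n, 1.0 / n)

A = make_grid(200)                    # 40,000 source points, the size used in step 1
A_mass = make_mass(A, seed=0)
B, B_mass = make_targets(30, seed=1)  # 30 targets, as in step 3
```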

### Using the Modules

```python
import numpy as np
from src import (
    generate_grid_points, generate_mass_distribution, generate_target_points,
    compute_ROT, plot_transport_plan, plot_statistics
)

# Generate data
A = generate_grid_points(grid_size=100)
A_mass = generate_mass_distribution(A, sigma=0.15)
B, B_mass = generate_target_points(n=10)

# Run ROT algorithm
B_weights = np.zeros(len(B))
results = compute_ROT(A, A_mass, B, B_mass, B_weights, lambda_val=0.09)

# Extract results
(transport_plan, B_weights_final, A_delta, full_transport_plan,
 path_lengths_aug, cycle_lengths_aug, path_lengths_cons, cycle_lengths_cons,
 iters_aug, iters_cons, regions, final_delta) = results

# Visualize
plot_transport_plan(A_delta, B, transport_plan, B_weights_final)
plot_statistics(path_lengths_aug, cycle_lengths_aug,
                path_lengths_cons, cycle_lengths_cons,
                iters_aug, iters_cons, regions)
```

### Key Functions

#### Data Generation
- `generate_grid_points(grid_size)`: Create regular grid of points
- `generate_mass_distribution(A, sigma, ...)`: Generate mass with Gaussian center + noise
- `generate_target_points(n, sigma, noise_fraction)`: Generate target points

#### Main Algorithm
- `compute_ROT(A, A_mass, B, B_mass, B_weights, lambda_val, ...)`: Main ROT algorithm
- `compute_transport_cost(A, B, transport_plan)`: Calculate transport cost
- `analyze_untransported_mass(...)`: Analyze remaining mass
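
The two quantities behind `compute_transport_cost` and `analyze_untransported_mass` can be sketched in a few lines. This sketch makes several assumptions. The cost is taken as the sum of `plan[i, j] * ||src[i] - dst[j]||**p` with Euclidean distance and a hypothetical exponent `p` (default 2). Untransported mass is taken as each source point's mass minus the row sum of the plan. The plan's rows must line up with the point set passed in, which presumably means `A_delta` for `transport_plan` and `A` for `full_transport_plan`. The function names are hypothetical, not the package's API.

```python
import numpy as np

def transport_cost(src: np.ndarray, dst: np.ndarray, plan: np.ndarray, p: int = 2) -> float:
    """Sum of plan[i, j] * ||src[i] - dst[j]||**p (Euclidean distance, exponent p assumed)."""
    diff = src[:, None, :] - dst[None, :, :]   # shape (len(src), len(dst), dim)
    dist = np.linalg.norm(diff, axis=-1)       # pairwise Euclidean distances
    return float(np.sum(plan * dist**p))

def untransported_mass(src_mass: np.ndarray, plan: np.ndarray) -> np.ndarray:
    """Mass each source point did not send anywhere: its mass minus its row sum in the plan."""
    return src_mass - plan.sum(axis=1)

src = np.array([[0.0, 0.0], [3.0, 4.0]])
dst = np.array([[0.0, 0.0]])
plan = np.array([[0.5], [0.25]])
print(transport_cost(src, dst, plan, p=1))             # 0.5*0 + 0.25*5 = 1.25
print(untransported_mass(np.array([0.5, 0.5]), plan))  # 0.0 and 0.25 left behind
```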

#### Visualization
- `plot_transport_plan(A_delta, B, transport_plan, ...)`: Visualize transport connections
- `plot_mass_distribution(A, A_mass, grid_size)`: Show mass as heatmap
- `plot_untransported_mass_heatmap(...)`: Visualize remaining mass
- `plot_statistics(...)`: Show algorithm performance metrics
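
`plot_mass_distribution` takes a `grid_size` argument, which suggests that it reshapes the flat mass vector into a square image. The hypothetical `mass_heatmap` below sketches that idea with Matplotlib. It assumes the masses are stored in row-major grid order with x varying fastest, which is the order `np.meshgrid(xs, xs)` produces. It is not the package's plotting code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts/CI; remove for interactive windows
import matplotlib.pyplot as plt
import numpy as np

def mass_heatmap(A_mass, grid_size: int):
    """Draw grid-ordered masses as an image; assumes x varies fastest within each row."""
    img = np.asarray(A_mass, dtype=float).reshape(grid_size, grid_size)
    fig, ax = plt.subplots(figsize=(4, 4))
    # origin="lower" puts row 0 (y = 0) at the bottom, matching Cartesian axes
    im = ax.imshow(img, origin="lower", extent=(0.0, 1.0, 0.0, 1.0), cmap="viridis")
    fig.colorbar(im, ax=ax, label="mass")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    return fig, ax

fig, ax = mass_heatmap(np.arange(16.0), grid_size=4)
fig.savefig("mass_heatmap.png", dpi=100)
plt.close(fig)
```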

## Algorithm Parameters

- **lambda_val**: Regularization parameter, typically set to λ² (the usage example above passes `lambda_val=0.09`, i.e. λ = 0.3)
- **min_delta**: Minimum approximation parameter (default: 0.0002)
- **initial_delta**: Starting approximation parameter (default: 1.0)
- **grid_size**: Side length of the source point grid, passed to `generate_grid_points` (default: 100)
- **n**: Number of target points, passed to `generate_target_points` (default: 10)
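
The three tuning parameters can be bundled and validated before a run. The `ROTParams` class below is hypothetical. It copies the defaults listed above and adds a `from_lambda` helper that stores λ² as `lambda_val`. Its validation rules (positive `lambda_val`, `0 < min_delta <= initial_delta`) are assumptions, not checks the package is known to perform.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ROTParams:
    """Hypothetical container for the ROT tuning parameters listed above."""
    lambda_val: float
    min_delta: float = 0.0002
    initial_delta: float = 1.0

    def __post_init__(self) -> None:
        # Assumed sanity checks; adjust to the package's actual requirements.
        if self.lambda_val <= 0:
            raise ValueError("lambda_val must be positive")
        if not 0 < self.min_delta <= self.initial_delta:
            raise ValueError("need 0 < min_delta <= initial_delta")

    @classmethod
    def from_lambda(cls, lam: float, **kwargs) -> "ROTParams":
        """Build params from λ itself, storing λ² as lambda_val."""
        return cls(lambda_val=lam ** 2, **kwargs)

params = ROTParams.from_lambda(0.3)
print(params)
```

If `compute_ROT` accepts `min_delta` and `initial_delta` as keyword arguments, which the `...` in its signature suggests but does not confirm, the fields can be forwarded directly.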

## Algorithm Phases

The ROT algorithm operates in alternating phases at multiple scales:

1. **Augmentation Phase**: Find augmenting paths and cycles to increase transport
2. **Weight Increase**: Increase weights for reachable surplus points
3. **Consolidation Phase**: Find consolidating paths to reduce transport cost
4. **Weight Decrease**: Decrease weights for violating points
5. **Scale Reduction**: Halve δ and repeat until δ reaches the minimum value `min_delta`
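
The outer schedule of these phases can be sketched as a control-flow skeleton. This is not the ROT algorithm. Each phase is a stub, each phase runs once per scale, and the loop stops when δ drops below `min_delta`; all three of these are assumptions. The actual implementation may, for instance, repeat phases within a scale until no augmenting paths remain. With the documented defaults (δ from 1.0 down to 0.0002), this schedule visits 13 scales.

```python
from typing import Callable, List, Tuple

Phase = Callable[[float], None]

def run_scales(initial_delta: float, min_delta: float,
               phases: List[Tuple[str, Phase]]) -> List[Tuple[str, float]]:
    """Run each phase once per scale, halving delta until it drops below min_delta."""
    log = []
    delta = initial_delta
    while delta >= min_delta:          # assumed stopping rule
        for name, phase in phases:     # phases 1-4, in order
            phase(delta)
            log.append((name, delta))
        delta /= 2.0                   # phase 5: scale reduction
    return log

def noop(delta: float) -> None:
    """Stub standing in for a real phase."""

phases = [(name, noop) for name in
          ("augmentation", "weight_increase", "consolidation", "weight_decrease")]
log = run_scales(1.0, 0.0002, phases)
scales = sorted({d for _, d in log}, reverse=True)
print(len(scales), scales[-1])  # 13 scales, smallest delta 2**-12
```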

## Output

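`compute_ROT` returns a 12-element tuple, in the order unpacked in the usage example:
- **transport_plan** (`transport_plan_hat`): Transport plan for the representative points
- **B_weights_final**: Final weights for target points
- **A_delta**: Final representative source points
- **full_transport_plan** (`sd_ot`): Full transport plan for the original source points
- **path_lengths_aug**, **cycle_lengths_aug**, **path_lengths_cons**, **cycle_lengths_cons**, **iters_aug**, **iters_cons**: Path lengths, cycle lengths, and iteration counts for the augmentation and consolidation phases
- **regions**: Number of representative regions at each scale
- **final_delta**: Value of δ at the final scale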
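
Unpacking 12 positional values by hand is error-prone. A `typing.NamedTuple` gives each value a name and rejects a tuple of the wrong length. `ROTResult` is a hypothetical helper, not part of the package. Its field order copies the unpacking in the usage example.

```python
from typing import Any, NamedTuple

class ROTResult(NamedTuple):
    """Named view of the 12-element tuple returned by compute_ROT (field order as in Usage)."""
    transport_plan: Any
    B_weights_final: Any
    A_delta: Any
    full_transport_plan: Any
    path_lengths_aug: Any
    cycle_lengths_aug: Any
    path_lengths_cons: Any
    cycle_lengths_cons: Any
    iters_aug: Any
    iters_cons: Any
    regions: Any
    final_delta: Any

# Demo with placeholder values 0..11; a tuple of the wrong length raises TypeError.
demo = ROTResult(*range(12))
print(demo.A_delta, demo.final_delta)
```

With the package installed, the same wrapper would be applied as `res = ROTResult(*compute_ROT(A, A_mass, B, B_mass, B_weights, lambda_val=0.09))`, after which fields are read as `res.A_delta`, `res.regions`, and so on.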

## Performance

The modular implementation maintains the same algorithmic complexity as the original notebook while providing:
- Better code organization and maintainability
- Improved error handling and type safety
- Enhanced visualization capabilities
- Easier integration into larger projects

## License

This implementation is based on the research presented in the accompanying paper. Please cite the paper if you use this code in academic work.
