# Wavelet Optimal Transport
Official code for the paper *Unpaired Single-Cell Dataset Alignment with Wavelet Optimal Transport*.

**Wavelet Optimal Transport (WOT)** is a *multi-resolution* optimal transport method that aligns highly noisy, incomplete, and non-isometric datasets.

This repository contains all the scripts and settings needed to set up, run, and reproduce the experimental results for Wavelet Optimal Transport.

![Wavelet Optimal Transport Overview](visualizations/method.png)

## Getting Started
### Installation
Install WOT and its dependencies with:
```
cd cell-state
conda env create --name wot --file=env.yml
conda activate wot
```
Ensure CUDA is available using `python -c 'import torch;print(torch.cuda.is_available())'`.

**Note**: we have only tested this repository with `python=3.10`.

## Overview of Repo
* Directory `src` hosts all code for WOT
* Directory `tests` hosts all code to run experiments that are reported in the paper
* Directory `plots` hosts all code to plot the figures that are reported in the paper (still messy, so beware)
* Directory `data` hosts all data that was used in the experiments reported in the paper

## Usage - Align Datasets with WOT
We can import `LWOT` and `EWOT` with:
```
from src.ewot import EWOT
from src.lwot import LWOT
```
To demonstrate, let's run `EWOT` on two random datasets, `A` (1000 samples with 5 features) and `B` (1000 samples with 10 features). We start by initializing random NumPy matrices:
```
import numpy as np
A = np.random.random((1000, 5))
B = np.random.random((1000, 10))
```
We then initialize WOT:
```
wot = EWOT(A, B, n_scales=20, w_op="simple_tight", T=None)
```
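The `n_scales` and `w_op` arguments appear to set the number of wavelet scales and the wavelet kernel. As a rough illustration of what a multi-scale operator on a point cloud can look like, the sketch below builds heat-kernel graph filters `Psi_s = U diag(exp(-s * lambda)) U^T` at several scales `s` on a dense Gaussian-affinity graph. The function name `heat_filters`, the bandwidth `sigma`, and the graph construction are assumptions for illustration, not the operator EWOT builds internally.

```python
import numpy as np


def heat_filters(X, scales, sigma=1.0):
    """Multi-scale heat-kernel graph filters (illustrative, not EWOT's operator)."""
    # Dense Gaussian affinities between all pairs of samples, no self-loops.
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    W = np.exp(-sq_dists / (2.0 * sigma**2))
    np.fill_diagonal(W, 0.0)
    # Combinatorial graph Laplacian L = D - W (symmetric positive semi-definite).
    L = np.diag(W.sum(axis=1)) - W
    eigvals, eigvecs = np.linalg.eigh(L)
    # One operator per scale: Psi_s = U diag(exp(-s * lambda)) U^T.
    return np.stack([(eigvecs * np.exp(-s * eigvals)) @ eigvecs.T for s in scales])


rng = np.random.default_rng(0)
A = rng.random((50, 5))
psi = heat_filters(A, scales=np.geomspace(0.1, 10.0, num=4))
print(psi.shape)  # (4, 50, 50)
```

Small scales keep each operator close to the identity (local detail), while large scales spread mass across the graph (coarse structure). Because `L` maps constant vectors to zero, every row of each `Psi_s` sums to 1.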
Next, we compute the transport plan `T` by running the WOT optimization:
```
wot.solve(p=None, q=None) # run OT optimization, calculates T
```
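Assuming `p` and `q` are the sample weights (marginals) of `A` and `B`, with `None` standing for uniform weights, the entropic step at the heart of such a solve can be sketched as a generic Sinkhorn iteration on a fixed cost matrix `C`. The function `sinkhorn` and its defaults are hypothetical; how WOT builds a cost between two spaces of different dimension is not reproduced here.

```python
import numpy as np


def sinkhorn(C, p=None, q=None, epsilon=0.1, n_iters=1000):
    """Balanced entropic OT via Sinkhorn scaling (generic sketch, not WOT's solver)."""
    n, m = C.shape
    # Assumption: a missing marginal means uniform weights over the samples.
    p = np.full(n, 1.0 / n) if p is None else np.asarray(p, dtype=float)
    q = np.full(m, 1.0 / m) if q is None else np.asarray(q, dtype=float)
    K = np.exp(-C / epsilon)  # Gibbs kernel of the cost matrix
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):
        u = p / (K @ v)  # rescale rows to match p
        v = q / (K.T @ u)  # rescale columns to match q
    return u[:, None] * K * v[None, :]  # T = diag(u) K diag(v)


rng = np.random.default_rng(0)
C = rng.random((100, 80))  # stand-in cost matrix between samples of A and B
T = sinkhorn(C)
print(T.shape)  # (100, 80)
```

Smaller `epsilon` gives a sharper, less diffuse plan but needs more iterations and is more prone to numerical underflow in `K`.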
Once we have the transport plan, we can use barycentric projection to project points from data space `A` into data space `B`:
```
projected_B = wot.project(to_X2=False) # barycentric projection
```
or, vice versa, to project points from `B` into `A`:
```
projected_A = wot.project(to_X2=True) # barycentric projection
```
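Barycentric projection maps each sample to the `T`-weighted average of the samples it is transported to. A minimal NumPy sketch, assuming the plan is an `n_A x n_B` array whose rows index samples of `A` (the helper `barycentric_projection` is hypothetical, not part of `src`):

```python
import numpy as np


def barycentric_projection(T, target):
    """Map each source sample to the T-weighted mean of the target samples.

    T has shape (n_source, n_target); target has shape (n_target, d_target).
    """
    row_mass = T.sum(axis=1, keepdims=True)  # total mass sent by each source sample
    return (T @ target) / row_mass


rng = np.random.default_rng(0)
A = rng.random((6, 5))
B = rng.random((4, 10))
T = rng.random((6, 4))
T /= T.sum()  # stand-in transport plan between A (rows) and B (columns)

A_in_B = barycentric_projection(T, B)  # (6, 10): samples of A in B's feature space
B_in_A = barycentric_projection(T.T, A)  # (4, 5): samples of B in A's feature space
```

Passing the transposed plan projects in the other direction. With a one-to-one plan (a scaled permutation matrix), the projection simply returns the matched samples.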

## Paper Results
**NOTE: ensure that the hyperparameters match the values listed in Table 3 of Appendix C of the paper.**
### Bifurcation alignment
To recreate the bifurcation simulation experiment, run:
```
Usage:
python -m tests.test_toy

Flags:
    --outpath                   string              directory path to dump experiment results, both data and figures
    --datapath                  string              directory path to access bifurcation dataset
    --noise                     option              run the increasing-noise experiment
    --dropout                   option              run the increasing-dropout experiment
    --num_trials                int                 number of trials to run to produce confidence intervals
    --max_noise                 float               maximum noise to add, as a fraction of the mean distance between points
    --max_dropout               float               maximum fraction of the dataset to drop out
    --num_interval              int                 number of steps between the min and max values in the noise and dropout experiments
    --plot                      option              plot results if set
    --epsilon_noise             float               entropic regularization for noise experiment
    --epsilon_dropout           float               entropic regularization for dropout experiment
    --agg_op                    string              aggregation operation over wavelet coefficient scales
    --n_scales                  int                 number of wavelet scales to use
```
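As a hedged reading of `--max_noise` and `--num_interval`, the sketch below sweeps noise levels from 0 up to `max_noise` and adds isotropic Gaussian noise whose standard deviation is the level times the mean pairwise distance between points. The evenly spaced levels, the starting value of 0, the Gaussian noise model, and the helper names are all assumptions; `tests.test_toy` may do this differently.

```python
import numpy as np


def noise_levels(max_noise, num_interval):
    # Assumption: evenly spaced levels from 0 up to max_noise (inclusive).
    return np.linspace(0.0, max_noise, num_interval)


def add_scaled_noise(X, level, rng):
    """Add Gaussian noise with std = level * mean pairwise distance of X (assumed model)."""
    n = X.shape[0]
    dists = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    mean_dist = dists.sum() / (n * (n - 1))  # mean over distinct pairs (diagonal is 0)
    return X + rng.normal(loc=0.0, scale=level * mean_dist, size=X.shape)


rng = np.random.default_rng(0)
X = rng.random((200, 3))  # stand-in for the bifurcation data
noisy_versions = [add_scaled_noise(X, level, rng) for level in noise_levels(0.5, 6)]
```

Expressing the noise relative to the mean pairwise distance makes a given `--max_noise` value comparable across datasets of different scale.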
### Shape alignment
To recreate the shape alignment experiment, run:
```
Usage:
python -m tests.test_shape_align

Flags:
    --outpath                   string              directory path to dump experiment results, both data and figures
    --datapath                  string              directory path to access SHREC20 dataset
    --lwot, --no-lwot           option              use LWOT (--lwot) or EWOT (--no-lwot)
    --wavelet_kernel            string              which wavelet kernel to use (options: "heat", "mexican_hat", "itersine", "simple_tight", "half_cosine_kernel" or "meyer")
    --n_samples                 int                 number of points to sample from each shape
    --unbalanced                option              run unbalanced Sinkhorn iterations if set, otherwise balanced ones
    --plot                      option              whether to plot figures
    --epsilon                   float               entropic regularization parameter
    --agg_op                    string              aggregation operation over wavelet coefficient scales
    --n_scales                  int                 number of wavelet scales to use
```
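With `--unbalanced`, the marginal constraints are relaxed. One common formulation penalizes deviations of the plan's marginals from `p` and `q` with a KL term of strength `rho`, which turns the Sinkhorn updates into damped scalings with exponent `rho / (rho + epsilon)`. The sketch below implements that generic scheme; `rho`, the function name, and the defaults are assumptions, not the repository's settings.

```python
import numpy as np


def unbalanced_sinkhorn(C, p, q, epsilon=0.1, rho=1.0, n_iters=1000):
    """Entropic OT with KL-penalized marginals (generic scaling sketch).

    rho sets how strongly the plan's marginals are pulled toward p and q;
    as rho grows, the updates approach balanced Sinkhorn.
    """
    K = np.exp(-C / epsilon)  # Gibbs kernel of the cost matrix
    fi = rho / (rho + epsilon)  # damping exponent from the KL relaxation
    u, v = np.ones(len(p)), np.ones(len(q))
    for _ in range(n_iters):
        u = (p / (K @ v)) ** fi
        v = (q / (K.T @ u)) ** fi
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
C = rng.random((60, 40))
p = np.full(60, 1.0 / 60)  # total mass 1
q = np.full(40, 2.0 / 40)  # total mass 2: both marginals cannot be met exactly
T = unbalanced_sinkhorn(C, p, q)
```

Relaxing the marginals lets the plan create or destroy mass, which can help when parts of one dataset have no counterpart in the other.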
### Single-cell alignment
To recreate the single-cell alignment experiment, run:
```
Usage:
python -m tests.test_single_cell_align

Flags:
    --datapath                  string              directory path to access SNARE-seq and scGEM datasets
    --lwot, --no-lwot           option              use LWOT (--lwot) or EWOT (--no-lwot)
    --SNAREseq                  option              run SNARE-seq experiment
    --scGEM                     option              run scGEM experiment
    --wavelet_kernel            string              which wavelet kernel to use (options: "heat", "mexican_hat", "itersine", "simple_tight", "half_cosine_kernel" or "meyer")
    --unbalanced                option              run unbalanced Sinkhorn iterations if set, otherwise balanced ones
    --log_foscttm               option              whether to compute the FOSCTTM metric for the experiments
    --snare_epsilon             float               entropic regularization parameter for the SNARE-seq dataset
    --scGEM_epsilon             float               entropic regularization parameter for the scGEM dataset
    --agg_op                    string              aggregation operation over wavelet coefficient scales
    --n_scales                  int                 number of wavelet scales to use
```
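FOSCTTM (fraction of samples closer than the true match) scores an alignment when one-to-one correspondences are known: for each sample, it measures the fraction of samples in the other dataset that lie closer than its true match, and this sketch averages over both directions. A minimal version, assuming both datasets are already in a shared space (for example, after barycentric projection) and that row `i` of each array is the same cell; the function name is hypothetical and the repository's own implementation may differ in details such as tie handling.

```python
import numpy as np


def foscttm(X, Y):
    """Domain-averaged FOSCTTM; lower is better, 0 means every true match is nearest."""
    n = X.shape[0]
    # d[i, j] = Euclidean distance between X[i] and Y[j].
    d = np.sqrt(((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1))
    true_d = np.diag(d)  # distance of each cell to its true match
    # For each X[i]: fraction of Y samples strictly closer than Y[i].
    frac_x = (d < true_d[:, None]).sum(axis=1) / (n - 1)
    # For each Y[j]: fraction of X samples strictly closer than X[j].
    frac_y = (d < true_d[None, :]).sum(axis=0) / (n - 1)
    return float(np.mean((frac_x + frac_y) / 2))


X = np.array([[0.0], [1.0], [2.0]])
Y = np.array([[2.0], [1.0], [0.0]])  # true matches listed in reverse order
print(foscttm(X, X))  # 0.0
print(foscttm(X, Y))  # 0.6666666666666666
```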
