# Regression-Wasserstein (PyTorch)

This repository contains the PyTorch code for the paper **Towards Wasserstein Distance Estimation From Sliced Optimal Transport**.  
We run **five experiments**:  
1. Mixture of Gaussians  
2. k-NN Classification (ShapeNetV2)  
3. Embedding (UMAP visualization)  
4. Wormhole Comparison  
5. RG-Wormhole
    - Training time comparison experiment
    - Embeddings experiment
    - Reconstruction experiment
    - Barycenter experiment
    - Interpolation experiment
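
For readers new to sliced optimal transport, here is a minimal NumPy sketch of the standard Monte Carlo sliced Wasserstein estimator (SW_p), the sliced quantity that the paper's title refers to. It is not code from this repository. It assumes two point clouds with the same number of uniformly weighted points, so each one-dimensional transport problem reduces to sorting. `sliced_wasserstein` is a hypothetical helper name.

```python
import numpy as np


def sliced_wasserstein(x: np.ndarray, y: np.ndarray, n_projections: int = 100,
                       p: int = 2, seed: int = 0) -> float:
    """Monte Carlo SW_p between two equal-size, uniformly weighted clouds of shape (n, d)."""
    if x.shape != y.shape:
        raise ValueError("this sketch assumes point clouds of equal shape (n, d)")
    rng = np.random.default_rng(seed)
    d = x.shape[1]
    # Random directions drawn uniformly on the unit sphere S^{d-1}.
    theta = rng.normal(size=(n_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Project both clouds onto every direction: shape (n, n_projections).
    x_proj = x @ theta.T
    y_proj = y @ theta.T
    # In 1D, optimal transport between equal-size empirical measures pairs sorted samples.
    x_sorted = np.sort(x_proj, axis=0)
    y_sorted = np.sort(y_proj, axis=0)
    # Average W_p^p over points and directions, then take the p-th root.
    return float(np.mean(np.abs(x_sorted - y_sorted) ** p) ** (1.0 / p))
```

Increasing `n_projections` lowers the Monte Carlo error at a cost linear in the number of directions.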

## Install
```bash
pip install -r requirements.txt
```


## Download Datasets

We use the following datasets in our experiments:

- **MNIST Point Clouds**
- **ShapeNetV2**
- **MERFISH Cell Niches**
- **scRNA-seq**
- **ModelNet40**

Download each dataset manually and place it under the `data/` folder using the structure below:

```text
data/
├── MNIST/
├── ShapeNetCore.v2.PC15k.zip
├── cellular_niche/
│   ├── sc_data.h5ad
│   └── st_data.h5ad
└── ModelNet40/
```
Make sure the file and folder names match exactly; otherwise, the preprocessing scripts may not find the data.
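
Before preprocessing, a quick check that every expected entry exists can save a confusing failure later. This is a minimal sketch: `missing_data_paths` and `EXPECTED` are hypothetical names, and the entries are copied from the directory tree in this section rather than from the preprocessing scripts themselves.

```python
from pathlib import Path

# Entries expected under data/, copied from the directory tree in this README.
EXPECTED = [
    "MNIST",
    "ShapeNetCore.v2.PC15k.zip",
    "cellular_niche/sc_data.h5ad",
    "cellular_niche/st_data.h5ad",
    "ModelNet40",
]


def missing_data_paths(root: str = "data") -> list[str]:
    """Return the expected entries that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in EXPECTED if not (base / rel).exists()]
```

Run from the repository root, `missing_data_paths()` returns an empty list when everything is in place.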

Preprocess the datasets:

```bash
python dataset_preprocessing_pcnmist.py
python dataset_preprocessing_modelnet40.py
python dataset_preprocessing_merfish.py
python dataset_preprocessing_scrna.py
python dataset_preprocessing_pointcloud.py
```

Split each dataset into train and test sets:

```bash
python split_dataset_pcnmist.py
python split_dataset_modelnet40.py
python split_dataset_merfish.py
python split_dataset_scrna.py
python split_dataset_pointcloud.py
```
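
The preprocessing and splitting stages can be chained so that the first failing script stops the run. This is a minimal sketch, assuming the scripts take no arguments and are launched from the repository root, as in the commands above. `pipeline_commands` and `run_pipeline` are hypothetical helpers, not part of the repo.

```python
import subprocess
import sys

# Dataset suffixes taken from the script names listed in this README.
DATASETS = ("pcnmist", "modelnet40", "merfish", "scrna", "pointcloud")


def pipeline_commands(datasets=DATASETS) -> list[list[str]]:
    """All preprocessing commands first, then all split commands, in README order."""
    names = list(datasets)
    pre = [[sys.executable, f"dataset_preprocessing_{name}.py"] for name in names]
    split = [[sys.executable, f"split_dataset_{name}.py"] for name in names]
    return pre + split


def run_pipeline(datasets=DATASETS) -> None:
    for cmd in pipeline_commands(datasets):
        print("running:", " ".join(cmd))
        # check=True raises CalledProcessError at the first non-zero exit code.
        subprocess.run(cmd, check=True)
```

Using `sys.executable` instead of a bare `python` keeps every script on the same interpreter (and virtual environment) as the caller.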

## Experiments

### 1. Mixture of Gaussians

```bash
python simulation_metric_compute.py
python simulation_metric_alpha_estimation.py
```
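
The second script's name suggests it estimates coefficients (alpha) that relate sliced quantities to the exact Wasserstein distance. Purely as an illustration, and assuming a linear model `W ≈ features @ alpha` where each row holds sliced quantities computed for one pair of distributions, alpha could be fit by least squares on a few pairs whose exact distance is known. The model, the features, and the helper names `fit_alpha` and `predict` are assumptions, not necessarily what `simulation_metric_alpha_estimation.py` does.

```python
import numpy as np


def fit_alpha(features: np.ndarray, exact: np.ndarray) -> np.ndarray:
    """Least-squares alpha such that features @ alpha ≈ exact.

    features: (n_pairs, k) sliced quantities per pair (hypothetical inputs).
    exact:    (n_pairs,) exact Wasserstein distances for the same pairs.
    """
    alpha, *_ = np.linalg.lstsq(features, exact, rcond=None)
    return alpha


def predict(features: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Apply fitted coefficients to the features of new pairs."""
    return features @ alpha
```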

### 2. k-NN Classification (ShapeNetV2)

```bash
python knn_pointcloud.py
python knn_pointcloud_eval.py
```
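
k-NN classification only needs distances between test and training shapes. Assuming the first script produces a `(n_test, n_train)` distance matrix and the second classifies by majority vote, the voting step can be sketched as below. `knn_predict` and `accuracy` are hypothetical names; the repository's choice of `k` and tie-breaking may differ.

```python
import numpy as np


def knn_predict(dist: np.ndarray, train_labels: np.ndarray, k: int = 5) -> np.ndarray:
    """Majority-vote k-NN from a precomputed (n_test, n_train) distance matrix."""
    # Column indices of the k smallest distances per row (unordered within the k).
    nearest = np.argpartition(dist, kth=k - 1, axis=1)[:, :k]
    neighbor_labels = train_labels[nearest]  # shape (n_test, k)
    preds = []
    for row in neighbor_labels:
        values, counts = np.unique(row, return_counts=True)
        # np.unique returns sorted values, so ties go to the smallest label.
        preds.append(values[np.argmax(counts)])
    return np.asarray(preds)


def accuracy(pred: np.ndarray, true: np.ndarray) -> float:
    return float(np.mean(pred == true))
```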

### 3. Embedding (UMAP visualization)

```bash
python embeddings_visualization.py
```

### 4. Wormhole Comparison

Train and evaluate Wormhole (baseline):

```bash
python compare_wormhole_wormhole_train.py
python compare_wormhole_wormhole_eval.py
```

Train and evaluate the RG variants (proposed methods):

```bash
python compare_wormhole_rg_compute_optimal_alpha.py
python compare_wormhole_rg_compute_distances.py
python compare_wormhole_rg_eval.py
```
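
This README does not say which metrics the evaluation scripts report. A common way to compare estimated distances with exact ones is correlation plus absolute and relative error; the sketch below assumes that choice, and `distance_agreement` is a hypothetical helper.

```python
import numpy as np


def distance_agreement(estimated, exact) -> dict[str, float]:
    """Compare estimated pairwise distances with exact ones (both flattened)."""
    est = np.asarray(estimated, dtype=float).ravel()
    ref = np.asarray(exact, dtype=float).ravel()
    pearson = float(np.corrcoef(est, ref)[0, 1])
    mae = float(np.mean(np.abs(est - ref)))
    # Relative error only over pairs with non-zero exact distance (skips the diagonal).
    nz = ref > 0
    rel = float(np.mean(np.abs(est[nz] - ref[nz]) / ref[nz]))
    return {"pearson": pearson, "mae": mae, "relative_error": rel}
```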

### 5. RG-Wormhole

Train Wormhole (baseline):

```bash
python rg_wormhole_train_wormhole.py
```

Train RG-Wormhole:

```bash
python rg_wormhole_train_rg_wormhole.py
```

Compare training time:

```bash
python rg_wormhole_compare_time_wormhole.py
python rg_wormhole_compare_time_rg_wormhole.py
```
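
How the two timing scripts measure time is not documented here. For a fair comparison of your own runs, a generic harness with warm-up runs and a median over repeats is a reasonable starting point; `time_call` is a hypothetical helper, not the scripts' implementation.

```python
import statistics
import time
from typing import Callable


def time_call(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> float:
    """Median wall-clock seconds of fn() over `repeats` runs after `warmup` runs.

    For GPU work, fn should call torch.cuda.synchronize() before returning;
    otherwise asynchronous CUDA kernels are not included in the measurement.
    """
    for _ in range(warmup):
        fn()  # absorb one-time costs such as caching and lazy initialization
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)
```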

Embeddings experiment:

```bash
python rg_wormhole_embeddings_baseline.py
python rg_wormhole_embeddings_rg_wormhole.py
```

Reconstruction experiment:

```bash
python rg_wormhole_plot_reconstruction.py
```

Barycenter experiment:

```bash
python rg_wormhole_plot_barycenter.py
```
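
In a Wormhole-style model, Euclidean geometry in the embedding space stands in for Wasserstein geometry, so a barycenter can be approximated by averaging embeddings and decoding the result. The sketch below shows only the averaging step, assuming that approach; `embedding_barycenter` is hypothetical, and `encoder`/`decoder` in the comment are placeholders, not this repository's API.

```python
import numpy as np


def embedding_barycenter(embeddings, weights=None) -> np.ndarray:
    """Weighted Euclidean mean of embeddings with shape (m, dim)."""
    z = np.asarray(embeddings, dtype=float)
    w = np.full(len(z), 1.0 / len(z)) if weights is None else np.asarray(weights, dtype=float)
    if np.any(w < 0):
        raise ValueError("weights must be non-negative")
    w = w / w.sum()  # normalize so the weights lie on the probability simplex
    return w @ z


# Hypothetical usage with a trained model (placeholder names):
#   z = encoder(point_clouds); bary_cloud = decoder(embedding_barycenter(z))
```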

Interpolation experiment:

```bash
python rg_wormhole_plot_interpolation.py
```
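
Under the same assumption that embedding distances approximate Wasserstein distances, an interpolation between two inputs can be sketched as a straight line between their embeddings, with each intermediate point then passed to the decoder. `interpolate_embeddings` is a hypothetical helper; the plotting script may construct its path differently.

```python
import numpy as np


def interpolate_embeddings(z0, z1, n_steps: int = 5) -> np.ndarray:
    """Rows (1 - t) * z0 + t * z1 for n_steps values of t evenly spaced in [0, 1]."""
    t = np.linspace(0.0, 1.0, n_steps)[:, None]  # shape (n_steps, 1) for broadcasting
    return (1.0 - t) * np.asarray(z0, dtype=float) + t * np.asarray(z1, dtype=float)
```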

