# PyTorch implementation of "Adversarial Schrödinger Bridge Matching"

## Requirements ##
Create the Anaconda environment with the following command:
```
conda env create -f environment.yml
```
Then install the `eot_benchmark` package following the instructions in its [repo](https://github.com/ngushchin/EntropicOTBenchmark).

## Download CelebA ##
Download the CelebA dataset using this [link](https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM), unzip it, and set the `$data_root` variable in the CelebA training and testing scripts accordingly.

## CelebA-128, one-sided pretraining ##
To pretrain ASBM for male-to-female translation before running the D-IMF procedure with $\epsilon=1$ and $T=4$, run the following script:
```
bash train_celeba_128_male2female_ema_T_4_minibatch_ot_eps_1.sh
```
and with $\epsilon=10$:
```
bash train_celeba_128_male2female_ema_T_4_minibatch_ot_eps_10.sh
```
To pretrain ASBM for female-to-male translation before running the D-IMF procedure with $\epsilon=1$ and $T=4$, run the following script:
```
bash train_celeba_128_female2male_ema_T_4_minibatch_ot_eps_1.sh
```
and with $\epsilon=10$:
```
bash train_celeba_128_female2male_ema_T_4_minibatch_ot_eps_10.sh
```
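To launch all four one-sided pretraining runs back to back, a small wrapper such as the following may be convenient. It is a sketch that assumes only that the four scripts listed above sit in the current directory; the function name `pretrain_celeba_all` and the `DRY_RUN` switch are hypothetical and not part of the repository.

```bash
#!/usr/bin/env bash
# Hypothetical helper: runs the four CelebA-128 one-sided pretraining scripts
# listed above (both directions, eps = 1 and eps = 10) one after another.
# With DRY_RUN=1 it only prints the commands instead of executing them.
pretrain_celeba_all() {
  local direction eps script
  for direction in male2female female2male; do
    for eps in 1 10; do
      script="train_celeba_128_${direction}_ema_T_4_minibatch_ot_eps_${eps}.sh"
      if [ "${DRY_RUN:-0}" = "1" ]; then
        echo "bash ${script}"
      else
        bash "${script}" || return 1  # stop at the first failing job
      fi
    done
  done
}
```

For example, `DRY_RUN=1 pretrain_celeba_all` previews the commands, and `pretrain_celeba_all` runs them in order, stopping at the first failure.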

## CelebA-128, D-IMF ##
To run the D-IMF procedure for CelebA-128 with $\epsilon=1$ and $T=4$ after one-sided pretraining, run the following script:
```
bash train_celeba_128_T_4_imf_ema_sampling_eps_1.sh
```
and with $\epsilon=10$:
```
bash train_celeba_128_T_4_imf_ema_sampling_eps_10.sh
```

## SwissRoll, 2D, one-sided pretraining ##
To pretrain ASBM for 2D Gaussian-to-SwissRoll translation before running the D-IMF procedure with $\epsilon=0.03$ and $T=4$, run the following script:
```
bash train_gaussian2swissroll_T_4_eps_0.03.sh
```
with $\epsilon=0.1$:
```
bash train_gaussian2swissroll_T_4_eps_0.1.sh
```
and with $\epsilon=0.3$:
```
bash train_gaussian2swissroll_T_4_eps_0.3.sh
```
To pretrain ASBM for SwissRoll-to-2D Gaussian translation before running the D-IMF procedure with $\epsilon=0.03$ and $T=4$, run the following script:
```
bash train_swissroll2gaussian_T_4_eps_0.03.sh
```
with $\epsilon=0.1$:
```
bash train_swissroll2gaussian_T_4_eps_0.1.sh
```
and with $\epsilon=0.3$:
```
bash train_swissroll2gaussian_T_4_eps_0.3.sh
```
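The six SwissRoll pretraining scripts above can be chained in the same way. This is a sketch built only from the script names listed in this section; `pretrain_swissroll_all` and `DRY_RUN` are hypothetical names.

```bash
#!/usr/bin/env bash
# Hypothetical helper: runs the six SwissRoll pretraining scripts listed above
# (both directions, eps = 0.03, 0.1, 0.3). DRY_RUN=1 only prints the commands.
pretrain_swissroll_all() {
  local eps direction
  for eps in 0.03 0.1 0.3; do
    for direction in gaussian2swissroll swissroll2gaussian; do
      if [ "${DRY_RUN:-0}" = "1" ]; then
        echo "bash train_${direction}_T_4_eps_${eps}.sh"
      else
        bash "train_${direction}_T_4_eps_${eps}.sh" || return 1  # stop on failure
      fi
    done
  done
}
```

Running `DRY_RUN=1 pretrain_swissroll_all` lists the six commands without starting any training.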

## SwissRoll, 2D, D-IMF ##
To run the D-IMF procedure for the Gaussian-SwissRoll experiment with $\epsilon=0.03$ and $T=4$ after one-sided pretraining, run the following script:
```
bash train_celeba_128_T_4_imf_ema_sampling_eps_1.sh
```
and with $\epsilon=10$:
```
bash train_celeba_128_T_4_imf_ema_sampling_eps_1.sh
```
To visualize the D-IMF results for $\epsilon=0.1$, run the following script:
```
bash test_swiss_roll_imf_eps_0.1.sh
```

## Gaussian-to-Gaussian Schrödinger Bridge, D-IMF ##
To reproduce the paper's results for the Gaussian-to-Gaussian Schrödinger Bridge, follow the notebook `D-IMF_Gaussian_case.ipynb`.

## Colored MNIST, D-IMF ##
To reproduce the paper's results for the Colored MNIST experiment with $\epsilon=1$, run the following script:
```
bash ASBM_colored_mnist_eps_1.sh
```
and with $\epsilon=10$:
```
bash ASBM_colored_mnist_eps_10.sh
```

## Entropic OT benchmark, D-IMF ##
To reproduce the paper's results on the Entropic OT benchmark, choose the dimension and epsilon values (`current_dim` and `current_eps`) as specified in the paper and run the following commands (`--fb 'f'` and `--fb 'b'` for the two directions):
```
python train_SB_bench_ABM.py --D_opt_steps 3 --plan 'ind' --dim ${current_dim} --epsilon ${current_eps} --fb 'f' --num_timesteps 32 --num_iterations 100000 --eval_freq 2500
python train_SB_bench_ABM.py --D_opt_steps 3 --plan 'ind' --dim ${current_dim} --epsilon ${current_eps} --fb 'b' --num_timesteps 32 --num_iterations 100000 --eval_freq 2500
```
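Both pretraining commands can also be driven from one loop. The sketch below reuses the exact flags shown above and assumes that `--fb 'f'` and `--fb 'b'` select the forward and backward directions (as the later `fw_ckpt`/`bw_ckpt` names suggest). It deliberately does not guess values for `current_dim` and `current_eps`; set them from the paper first. The function name `pretrain_sb_bench` and the `DRY_RUN` switch are hypothetical.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper around the two pretraining commands above.
# current_dim and current_eps must be set beforehand (values from the paper).
pretrain_sb_bench() {
  local fb
  if [ -z "${current_dim:-}" ] || [ -z "${current_eps:-}" ]; then
    echo "set current_dim and current_eps first" >&2
    return 1
  fi
  for fb in f b; do  # assumed: 'f' and 'b' select the two directions
    # Build the command as positional parameters so quoting stays intact.
    set -- python train_SB_bench_ABM.py --D_opt_steps 3 --plan ind \
      --dim "${current_dim}" --epsilon "${current_eps}" --fb "${fb}" \
      --num_timesteps 32 --num_iterations 100000 --eval_freq 2500
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "$*"
    else
      "$@" || return 1  # stop if one direction fails
    fi
  done
}
```

After setting `current_dim` and `current_eps`, `DRY_RUN=1 pretrain_sb_bench` prints the two commands and `pretrain_sb_bench` runs them sequentially.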
Then, for D-IMF, specify the paths to the pretrained checkpoints and run:
```
current_T=32
d_opt_steps=3

bw_ckpt='path_to_bw_first_D_IMF_iter_ckpt'
fw_ckpt='path_to_fw_first_D_IMF_iter_ckpt'

inner_imf_mark_proj_iters=50000

python train_ASBM_SB_bench.py --epsilon ${current_eps} --dim ${current_dim} --fw_ckpt ${fw_ckpt} --bw_ckpt ${bw_ckpt} --inner_imf_mark_proj_iters ${inner_imf_mark_proj_iters} --imf_iters 10 --eval_freq 5000 --num_timesteps ${current_T} --D_opt_steps ${d_opt_steps}
```
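Since the D-IMF run starts from the two pretrained checkpoints, it may be worth failing early if either path is wrong. A minimal pre-flight check, assuming each checkpoint path should point to something that exists on disk and reusing the `fw_ckpt`/`bw_ckpt` variables from the snippet above (the function name `check_ckpts` is hypothetical):

```bash
#!/usr/bin/env bash
# Hypothetical pre-flight check: verify that both pretrained checkpoint
# paths exist before launching train_ASBM_SB_bench.py.
check_ckpts() {
  local missing=0 ckpt
  for ckpt in "${fw_ckpt:-}" "${bw_ckpt:-}"; do
    if [ -z "${ckpt}" ] || [ ! -e "${ckpt}" ]; then
      echo "checkpoint not found: '${ckpt}'" >&2
      missing=1
    fi
  done
  return "${missing}"  # 0 only if both paths exist
}
```

Prefixing the D-IMF command above with `check_ckpts &&` launches training only when both paths resolve.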
