# On Transportation of Mini-batches: A Hierarchical Approach

## Requirements

* Python 3.6
* PyTorch 1.7.1
* torchvision
* numpy
* tqdm
* geoopt
* geomloss
* POT
* matplotlib
* pyabc

## Gradient Flow (GradientFlow)
```
python main.py
```

## Color Transfer (ColorTransfer)
```
python main.py --k=10 --m=100 --T=5000 --source images/s1.bmp --target images/t1.bmp 
```

### Terminologies
--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster : use K-means clustering to compress the images

--palette : show the color palette

--source : path to the source image

--target : path to the target image
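
To make the roles of `--k` and `--m` concrete, here is a minimal NumPy/SciPy sketch (SciPy is already a dependency of POT) of the two mini-batch losses compared in this repository. It is not the repository's implementation: the squared Euclidean ground cost and the helper names `sample_minibatches`, `minibatch_ot`, `batch_cost_matrix`, `m_ot` and `bomb_ot` are hypothetical. In this sketch, m-OT averages the OT cost over all k × k pairs of mini-batches, while BoMb-OT solves a second OT problem between the mini-batches themselves. With uniform weights and equal sizes, both the inner and the outer problem reduce to an assignment problem, so `linear_sum_assignment` solves them exactly.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def sample_minibatches(X, k, m, rng):
    """Draw k mini-batches of m points each (no repeats inside a batch)."""
    return [X[rng.choice(len(X), size=m, replace=False)] for _ in range(k)]


def minibatch_ot(A, B):
    """Exact squared-Euclidean OT cost between two uniform samples of equal size.

    With equal sizes and uniform weights an optimal plan is a permutation,
    so solving the assignment problem gives the exact OT cost.
    """
    C = cdist(A, B, metric="sqeuclidean")
    rows, cols = linear_sum_assignment(C)
    return C[rows, cols].mean()


def batch_cost_matrix(src_batches, tgt_batches):
    """k x k matrix whose (i, j) entry is the OT cost between mini-batches i and j."""
    return np.array([[minibatch_ot(a, b) for b in tgt_batches] for a in src_batches])


def m_ot(D):
    """m-OT (as sketched here): every pair of mini-batches gets the same weight 1/k^2."""
    return D.mean()


def bomb_ot(D):
    """BoMb-OT: optimal transport between the k mini-batches themselves.

    With uniform weights 1/k on both sides this is again an assignment problem.
    """
    rows, cols = linear_sum_assignment(D)
    return D[rows, cols].mean()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(1000, 3))           # source points (e.g. pixel colors)
    Y = rng.normal(loc=1.0, size=(1000, 3))  # target points
    D = batch_cost_matrix(sample_minibatches(X, 10, 100, rng),
                          sample_minibatches(Y, 10, 100, rng))
    print("m-OT:", m_ot(D), "BoMb-OT:", bomb_ot(D))
```

Because the uniform plan used by m-OT is one feasible plan for the outer problem, `bomb_ot(D)` is never larger than `m_ot(D)` in this sketch.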

## Approximate Bayesian Computation (ABC)
```
python main.py --k=16 --m=16
```

### Terminologies
--k : number of mini-batches

--m : the size of mini-batches
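
The sketch below illustrates how a mini-batch OT discrepancy built from `k` mini-batches of size `m` can drive ABC. It rests on assumptions not stated in this README: it uses plain rejection ABC rather than the samplers provided by `pyabc`, a hypothetical one-dimensional Gaussian model with unknown mean, and a BoMb-OT discrepancy with squared cost. All function names are hypothetical. In one dimension, OT between two equal-size uniform samples matches sorted values, so the inner problem needs no solver.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def w2_squared_1d(a, b):
    """Squared 2-Wasserstein cost between two 1-D samples of equal size.

    In 1-D with uniform weights, the optimal plan matches sorted values.
    """
    return np.mean((np.sort(a) - np.sort(b)) ** 2)


def bomb_discrepancy(obs, sim, k, m, rng):
    """BoMb-OT discrepancy from k mini-batches of size m drawn from each sample."""
    src = [obs[rng.choice(len(obs), size=m, replace=False)] for _ in range(k)]
    tgt = [sim[rng.choice(len(sim), size=m, replace=False)] for _ in range(k)]
    D = np.array([[w2_squared_1d(a, b) for b in tgt] for a in src])
    rows, cols = linear_sum_assignment(D)  # outer OT with uniform weights 1/k
    return D[rows, cols].mean()


def rejection_abc(obs, simulate, prior_sample, n_draws, keep, k, m, rng):
    """Keep the `keep` prior draws whose simulated data are closest to `obs`."""
    thetas = np.array([prior_sample(rng) for _ in range(n_draws)])
    dists = np.array([bomb_discrepancy(obs, simulate(t, rng), k, m, rng) for t in thetas])
    return thetas[np.argsort(dists)[:keep]]


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    obs = rng.normal(2.0, 1.0, size=256)  # toy observed data, true mean 2.0
    posterior = rejection_abc(
        obs,
        simulate=lambda theta, r: r.normal(theta, 1.0, size=256),
        prior_sample=lambda r: r.uniform(-5.0, 5.0),
        n_draws=300, keep=15, k=16, m=16, rng=rng,
    )
    print("approximate posterior mean:", posterior.mean())
```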

## Deep Adaptation on digits datasets (DeepDA)

### Code organization
cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architectures of the generator and the classifier.

train_digits.py : entry-point script for deep DA.

utils.py : this file contains the implementation of utility functions.

### Terminologies
--source_ds : source dataset 

--target_ds : target dataset

--data_dir : path to dataset

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--lr : initial learning rate

--test_interval : interval between two consecutive test phases

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--eta1 : weight of the embedding loss ($\alpha$ in Equation 10 of the paper)

--eta2 : weight of the transportation loss ($\lambda_t$ in Equation 10 of the paper)
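
As a rough illustration of how `--eta1`, `--eta2` and `--epsilon` interact, the sketch below assumes a DeepJDOT/JUMBOT-style ground cost between a source and a target mini-batch: `eta1` weights the squared distance between embeddings and `eta2` weights the cross-entropy between source labels and target predictions. Equation 10 of the paper is not reproduced here, so treat this form as an assumption. The plan is computed with balanced log-domain Sinkhorn iterations with regularization `epsilon`; JUMBOT itself uses an unbalanced variant. All function names are hypothetical.

```python
import numpy as np
from scipy.spatial.distance import cdist
from scipy.special import logsumexp


def da_ground_cost(zs, zt, ys_onehot, pt, eta1, eta2):
    """Assumed DeepJDOT/JUMBOT-style cost between m source and m target samples.

    zs, zt: (m, d) embeddings from the generator
    ys_onehot: (m, C) one-hot source labels
    pt: (m, C) predicted class probabilities on the target samples
    """
    feature_cost = cdist(zs, zt, metric="sqeuclidean")
    # label_cost[i, j] = cross-entropy between source label i and target prediction j
    label_cost = -ys_onehot @ np.log(pt + 1e-12).T
    return eta1 * feature_cost + eta2 * label_cost


def sinkhorn_plan(a, b, C, epsilon, n_iters=500):
    """Balanced entropic OT plan via log-domain Sinkhorn iterations."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        f = epsilon * (np.log(a) - logsumexp((g[None, :] - C) / epsilon, axis=1))
        g = epsilon * (np.log(b) - logsumexp((f[:, None] - C) / epsilon, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / epsilon)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m, d, n_classes = 32, 8, 10
    zs, zt = rng.normal(size=(m, d)), rng.normal(size=(m, d))
    ys = np.eye(n_classes)[rng.integers(n_classes, size=m)]
    logits = rng.normal(size=(m, n_classes))
    pt = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    C = da_ground_cost(zs, zt, ys, pt, eta1=0.1, eta2=1.0)
    w = np.full(m, 1.0 / m)  # uniform weights on both mini-batches
    P = sinkhorn_plan(w, w, C, epsilon=0.1)
    print("transport loss:", np.sum(P * C))
```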

### Change the number of mini-batches $k$
```
bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh
```

### Change the mini-batch size $m$
```
bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh
```

## Deep Generative model (DeepGM)

### Code organization
Celeba_generator.py, Cifar_generator.py : these files contain the generator architectures for the CelebA and CIFAR10 datasets, together with methods that compute the losses of the corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py : this file is used to compute the FID score.

gen_images.py : this file loads saved models and generates 10,000 images for computing the FID score.

inception.py : this file contains the architecture of Inception-v3.

main_celeba.py, main_cifar.py : entry-point scripts for the corresponding datasets.

utils.py : this file contains implementation of utility functions.

### Terminologies
--method : type of OT loss (OT, sliced)

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs when k = 1. The actual number of training epochs is this value multiplied by k.

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using the sliced approach

--reg : OT regularization coefficient for Sinkhorn algorithm

--breg : OT regularization coefficient for computing transport plan between mini-batches
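
When `--method` is `sliced`, `--L` random directions are used. Below is a minimal NumPy sketch of the Monte Carlo sliced Wasserstein distance between two mini-batches, assuming equal sizes and uniform weights. The function name is hypothetical, and training code would need a differentiable PyTorch version of the same computation.

```python
import numpy as np


def sliced_wasserstein(X, Y, L, rng, p=2):
    """Monte Carlo estimate of the sliced p-Wasserstein distance.

    X, Y: (m, d) mini-batches of equal size with uniform weights.
    L: number of random projection directions.
    """
    if X.shape != Y.shape:
        raise ValueError("this sketch assumes equal-size mini-batches")
    theta = rng.normal(size=(L, X.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)  # L directions on the unit sphere
    # Project both batches; sorting each column solves every 1-D OT problem exactly
    X_proj = np.sort(X @ theta.T, axis=0)
    Y_proj = np.sort(Y @ theta.T, axis=0)
    return np.mean(np.abs(X_proj - Y_proj) ** p) ** (1.0 / p)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(64, 10))
    Y = rng.normal(loc=0.5, size=(64, 10))
    print("SW_2 with L=100:", sliced_wasserstein(X, Y, L=100, rng=rng))
```

Each projection costs only a sort, so increasing `L` trades extra compute for a lower-variance estimate.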

### m-OT
``` 
bash sh/run_OT.sh
```

### BoMb-OT
``` 
bash sh/run_BoMbOT.sh
```

### eBoMb-OT
``` 
bash sh/run_eBoMbOT.sh
```
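
eBoMb-OT replaces the exact OT problem between mini-batches with an entropic one whose strength is set by `--breg`. The sketch below assumes `D` is the k × k matrix of OT costs between source and target mini-batches and that both sides carry uniform weights 1/k. It reports the transport cost `sum(P * D)` without adding the entropy term, which may differ from how the repository reports the loss. Function names are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.special import logsumexp


def entropic_plan(D, reg, n_iters=2000):
    """Entropic OT plan between two uniform measures, via log-domain Sinkhorn.

    D: (k, k) matrix of OT costs between source and target mini-batches.
    reg: entropic regularization strength (plays the role of --breg here).
    """
    k1, k2 = D.shape
    log_a = np.full(k1, -np.log(k1))  # uniform weights 1/k on the source mini-batches
    log_b = np.full(k2, -np.log(k2))  # uniform weights 1/k on the target mini-batches
    f, g = np.zeros(k1), np.zeros(k2)
    for _ in range(n_iters):
        f = reg * (log_a - logsumexp((g[None, :] - D) / reg, axis=1))
        g = reg * (log_b - logsumexp((f[:, None] - D) / reg, axis=0))
    return np.exp((f[:, None] + g[None, :] - D) / reg)


def ebomb_ot(D, breg):
    """Transport cost of the entropic plan between mini-batches (entropy term not added)."""
    return np.sum(entropic_plan(D, breg) * D)


def bomb_ot(D):
    """Exact BoMb-OT value for comparison (assignment problem with uniform weights)."""
    rows, cols = linear_sum_assignment(D)
    return D[rows, cols].mean()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    D = rng.uniform(size=(8, 8))  # stand-in for a k x k mini-batch cost matrix
    for breg in (1e3, 1.0, 0.05):
        print(f"breg={breg}: {ebomb_ot(D, breg):.4f}")
    print("m-OT:", D.mean(), "BoMb-OT:", bomb_ot(D))
```

As `breg` grows, the entropic plan tends to the uniform plan and the value approaches the m-OT average `D.mean()`; as `breg` shrinks, it approaches the exact BoMb-OT value.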

## Acknowledgment

The structure of DeepDA is largely based on [JUMBOT](https://github.com/kilianFatras/JUMBOT). The structure of ABC is largely based on [SlicedABC](https://github.com/kimiandj/slicedwass_abc). We are very grateful to their authors for open-sourcing their code.