# Gradient flows on the feature-Gaussian manifold

## Software implementation
All code used to produce the results in the paper is in this repo. All preprocessed data are in `/data`. Pretrained models are in `/models`.

## Dependencies
Supported versions:
- Python 3. We use Python 3.6.9.
- PyTorch 1.7.1+cu110.

We recommend installing PyTorch from the official website: https://pytorch.org/get-started/locally/

All the dependencies are listed in `requirements.txt`. Install them with:
```
python -m pip install -r requirements.txt
```
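
To compare a local environment against the versions listed above, a minimal sketch along these lines may help. The helper names `environment_report` and `base_version` are hypothetical and not part of this repo; the sketch only reports versions and does not enforce them.

```python
import importlib.util
import sys


def base_version(version):
    """Drop a local build tag, e.g. '1.7.1+cu110' -> '1.7.1'."""
    return version.split("+")[0]


def environment_report():
    """Return (python_version, torch_version); torch_version is None if PyTorch is not installed."""
    python_version = "{}.{}.{}".format(*sys.version_info[:3])
    torch_version = None
    # Only import torch if it can be found, so the check also runs before installation.
    if importlib.util.find_spec("torch") is not None:
        import torch
        torch_version = torch.__version__
    return python_version, torch_version


if __name__ == "__main__":
    py, th = environment_report()
    print("Python:", py, "(README uses 3.6.9)")
    print("PyTorch:", th, "(README uses 1.7.1+cu110)")
```
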

## Reproducing the results
- Gradient flow results (Figure 2 in the main text). The data path format is `./data/{source dataset}_{target dataset}.tar`. The code generates a snapshot of the flowed images every 10 steps.
```
python3 gradflow_main.py
```
- Gradient flow on a mixture of Gaussians (Figures 4 and 5 in the Appendix). The code generates a snapshot of the flowed images every 20 steps.

Figure 4
```
python3 gradflow_gaussian_samenumclass.py
```
Figure 5
``` 
python3 gradflow_gaussian_diffnumclass.py
```
Then run `gaussian_relabel.ipynb` to relabel the source data and generate plots like Figure 5.

- Transfer learning (Figure 3). Run the `transfer_learning` notebook to get the transfer learning results for one experiment (5 bars) in Figure 3.

- Comparison with "Dataset Dynamics via Gradient Flows in Probability Space" by David Alvarez-Melis and Nicolò Fusi.
Go to the `/otdd` folder and run
```
python3 main.py
```
Then run the `./otdd/transfer_learning_comparison` notebook to run transfer learning and produce the same plot as `transfer_learning`.
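
The data path format `./data/{source dataset}_{target dataset}.tar` and the snapshot intervals (every 10 or 20 steps) mentioned in this section can be illustrated with a small sketch. The helpers `data_path` and `snapshot_steps` are hypothetical, the dataset names in the example are placeholders for whatever names the files in `/data` actually use, and counting steps from 1 is an assumption rather than something the scripts are known to do.

```python
def data_path(source, target, data_dir="./data"):
    """Build a path following ./data/{source dataset}_{target dataset}.tar."""
    return "{}/{}_{}.tar".format(data_dir, source, target)


def snapshot_steps(num_steps, every):
    """List the steps at which a snapshot would be written.

    Assumes steps are counted from 1 and a snapshot is taken whenever the
    step index is a multiple of `every`.
    """
    return list(range(every, num_steps + 1, every))


if __name__ == "__main__":
    # Placeholder dataset names; check the files in /data for the real ones.
    print(data_path("MNIST", "USPS"))
    print(snapshot_steps(50, 10))
```
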

## Datasets
All the real datasets are publicly available. We have used the following datasets:
- MNIST: http://yann.lecun.com/exdb/mnist/
- KMNIST: https://github.com/rois-codh/kmnist
- FashionMNIST: https://github.com/zalandoresearch/fashion-mnist
- USPS: https://git-disl.github.io/GTDLBench/datasets/usps_dataset/

We obtained the datasets from Torchvision: https://pytorch.org/vision/stable/datasets.html
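
As a rough sketch of how the raw datasets can be fetched through Torchvision: the function `load_dataset` and the `./raw_data` directory are hypothetical choices for illustration, and this is not the preprocessing used to produce the files in `/data`.

```python
# The four dataset classes below all exist in torchvision.datasets.
DATASET_NAMES = ("MNIST", "KMNIST", "FashionMNIST", "USPS")


def load_dataset(name, root="./raw_data", train=True, download=True):
    """Return a torchvision dataset by name, with images converted to tensors."""
    if name not in DATASET_NAMES:
        raise ValueError("unknown dataset: {!r}".format(name))
    # Imported lazily so the name check works even without torchvision installed.
    from torchvision import datasets, transforms
    dataset_cls = getattr(datasets, name)
    return dataset_cls(root=root, train=train, download=download,
                       transform=transforms.ToTensor())


if __name__ == "__main__":
    mnist = load_dataset("MNIST")
    image, label = mnist[0]
    print(image.shape, label)
```
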
