## Directory Contents
- `shapes3d/`: trains a model from scratch to isolate generative factors from the Shapes3D dataset, and measures the information content of the learned representations using mutual information neural estimation (Belghazi et al., 2018).
- `mnist.ipynb`: IPython notebook that isolates digit style (stroke and thickness) after training on images grouped by digit class.
- `abc_mnist_training_evolution.gif`: an example of how the embeddings of MNIST digits evolve during training, isolating style.

## Python environment
The code was run successfully with Python 3.6.12. Install the required libraries with pip:

`pip install -r requirements.txt`

## Isolating factors of variation from Shapes3D

The Shapes3D dataset gives complete control over, and knowledge of, its factors of variation.
We use it to pin down precisely which factors are isolated under different set supervision settings.
A large variety of experiments may be run using the script `shapes3d/train.py`.
The Shapes3D dataset is downloaded automatically by `tensorflow_datasets`.

The following example call

`python -m shapes3d.train --inactive_vars=03 --curate_both_stacks=False`

trains a network from scratch with the wall hue and scale generative factors inactive, and with only one of each pair of training sets curated. For every training batch, the second set is sampled randomly from all images (the 'ABC-X' experiments in Figure 4).

The command line flag `inactive_vars` takes a string of digits from 0 to 5, where each digit indexes one of the six generative factors of the Shapes3D dataset (in order: wall hue, object hue, floor hue, scale, shape, and orientation).
For example, `01` curates stacks with wall hue and object hue as inactive variables.
Note that this curation process uses `tf.data.Dataset.filter` to run through the Shapes3D dataset, so the more inactive factors of variation there are, the more of the dataset must be searched to find each training set.
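
To make the cost of curation concrete, here is a minimal sketch that decodes an `inactive_vars` string and estimates how many images a filter scans per match. It assumes that inactive factors are held fixed within a curated set and that factor values are uniformly distributed. The helper names `parse_inactive_vars` and `expected_scan_per_match` are hypothetical and not part of `shapes3d/train.py`. The factor order follows the list above, and the value counts are those of the standard Shapes3D release.

```python
# Factor order as listed in this README; value counts are those of the
# standard Shapes3D release (480,000 images in total).
FACTORS = [
    ("wall hue", 10),
    ("object hue", 10),
    ("floor hue", 10),
    ("scale", 8),
    ("shape", 4),
    ("orientation", 15),
]


def parse_inactive_vars(flag_value):
    """Map a digit string such as '03' to the names of the inactive factors."""
    names = []
    for ch in flag_value:
        if not ch.isdigit() or int(ch) >= len(FACTORS):
            raise ValueError("expected digits 0-5, got %r" % ch)
        name = FACTORS[int(ch)][0]
        if name not in names:  # ignore repeated digits
            names.append(name)
    return names


def expected_scan_per_match(flag_value):
    """Average number of images scanned per image that matches one fixed
    combination of inactive-factor values (assumes uniform factor values)."""
    sizes = dict(FACTORS)
    scan = 1
    for name in parse_inactive_vars(flag_value):
        scan *= sizes[name]
    return scan


for flag in ["0", "03", "012"]:
    print(flag, parse_inactive_vars(flag), expected_scan_per_match(flag))
```

Under these assumptions, `--inactive_vars=03` scans roughly 80 images for each image kept, and every additional inactive hue factor multiplies that number by 10.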

The resulting embeddings tend to be low-dimensional, so visualization via PCA (as in Fig. 3 of the manuscript) is informative.
We also include functionality to estimate the mutual information between the learned embeddings and each of the generative factors.
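
For reference, mutual information neural estimation is commonly based on the Donsker-Varadhan lower bound. Whether the code in `shapes3d/` uses exactly this form or a variant is an assumption here. With $Z$ the learned embedding, $Y$ one generative factor, and $T_\theta$ a trainable critic network:

$$
I(Z; Y) \;\geq\; \sup_{\theta}\; \mathbb{E}_{p(z,y)}\left[T_\theta(z, y)\right] - \log \mathbb{E}_{p(z)\,p(y)}\left[e^{T_\theta(z, y)}\right]
$$

The estimate is the value of the right-hand side after training $T_\theta$, computed separately for each generative factor.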

Other noteworthy flags:

- `save_pngs` outputs images during training that show sample embeddings and the mutual information measurements.

- `similarity_type` sets which distance metric to use when computing the loss. The best results seem to come from squared Euclidean distance (`l2sq`), but several others are implemented and there's room to explore.

- `run_augmentation_experiment` is a boolean flag which runs a double augmentation comparison if set to `True`; this is another means of isolating factors of variation with different strengths.
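
As an illustration of the `l2sq` option for `similarity_type`, the sketch below computes pairwise squared Euclidean distances between two small sets of embeddings in plain Python. The function name `pairwise_l2sq` is hypothetical, and the training code presumably does the equivalent with batched tensor operations. How the distances enter the loss is not shown here.

```python
def pairwise_l2sq(set_a, set_b):
    """Return D where D[i][j] = ||set_a[i] - set_b[j]||^2."""
    return [
        [sum((x - y) ** 2 for x, y in zip(u, v)) for v in set_b]
        for u in set_a
    ]


# Two toy sets of 2-D embeddings.
set_a = [[0.0, 0.0], [1.0, 2.0]]
set_b = [[0.0, 1.0], [3.0, 2.0]]

for row in pairwise_l2sq(set_a, set_b):
    print(row)
```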

Training progress can be monitored with TensorBoard.

## Fast digit style isolation on MNIST

The IPython notebook (`abc_mnist.ipynb`) partitions the MNIST training set into 10 different `tf.data.Dataset` objects, one per digit class, with the option to withhold one digit for test time (as in the paper).
Embedding visualizations show the two PCA dimensions with the largest variance, also as in the paper.
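
The partitioning step can be sketched without TensorFlow as follows. This is a minimal illustration on toy data, assuming one group per digit class and a single withheld digit. The function name `partition_by_digit` and its arguments are hypothetical, and the notebook itself builds `tf.data.Dataset` objects rather than Python lists.

```python
from collections import defaultdict


def partition_by_digit(images, labels, withheld_digit=None):
    """Group images by label; return (train_groups, withheld_images)."""
    groups = defaultdict(list)
    for image, label in zip(images, labels):
        groups[label].append(image)
    # Remove the withheld class (if any) so it is only seen at test time.
    withheld = groups.pop(withheld_digit, [])
    return dict(groups), withheld


# Toy stand-ins for MNIST images: strings instead of 28x28 arrays.
images = ["img0", "img1", "img2", "img3", "img4"]
labels = [3, 7, 3, 9, 7]

train_groups, test_images = partition_by_digit(images, labels, withheld_digit=9)
print(sorted(train_groups))  # digit classes available for training
print(test_images)           # images of the withheld digit
```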