## MNIST Manipulator Experiments

This folder contains the code, trained models, and data to reproduce the MNIST experiments in "A Pattern Language for Machine Learning Tasks". See Figure 2 and Appendix D.3 for more details.

If you just want to train a simple `manipulator` model, you can use `train.py` to train a toy model on the MNIST dataset and generate some plots. For example:
```
pip install -r requirements.txt
python3 train.py --device=cpu --epochs=20
```
When training finishes, the script reports the test accuracy of the getter, saves the trained models, and produces three plots:
1. `mnist_getter.png` shows the probability vector output by the getter for one example, the distribution of labels across the whole dataset, and the confusion matrix against the true labels.
2. `mnist_putter_get_put.png` shows some example images alongside their reconstructions, obtained by using the putter to put the getter's predicted label back onto the original image (the GetPut rule).
3. `mnist_putter_put_all.png` puts every possible label onto some example digits to show that the putter preserves the style of the original image.
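The two putter plots are visual checks of lens-style laws. The following is a minimal sketch in plain Python, not code from this repository: a hypothetical `ToyDigit` record (a label plus a scalar "style") and hypothetical `toy_get`/`toy_put` functions stand in for the trained networks. It checks the GetPut rule, `put(x, get(x)) == x`, which `mnist_putter_get_put.png` illustrates, and the standard PutGet lens law, `get(put(x, y)) == y`, which roughly corresponds to what `mnist_putter_put_all.png` shows alongside style preservation.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass(frozen=True)
class ToyDigit:
    """Hypothetical stand-in for an MNIST image: a digit label plus a 'style' value."""
    label: int
    style: float


def toy_get(x: ToyDigit) -> int:
    # Getter: read the digit label out of the "image".
    return x.label


def toy_put(x: ToyDigit, y: int) -> ToyDigit:
    # Putter: overwrite the label while keeping the original style.
    return ToyDigit(label=y, style=x.style)


def get_put_rate(get: Callable, put: Callable, xs: Sequence) -> float:
    """Fraction of sources x with put(x, get(x)) == x (GetPut rule)."""
    return sum(put(x, get(x)) == x for x in xs) / len(xs)


def put_get_rate(get: Callable, put: Callable, xs: Sequence, labels: Sequence[int]) -> float:
    """Fraction of pairs (x, y) with get(put(x, y)) == y (PutGet rule)."""
    pairs = [(x, y) for x in xs for y in labels]
    return sum(get(put(x, y)) == y for x, y in pairs) / len(pairs)


samples = [ToyDigit(label=d, style=0.1 * d) for d in range(10)]
print(get_put_rate(toy_get, toy_put, samples))             # 1.0
print(put_get_rate(toy_get, toy_put, samples, range(10)))  # 1.0
```

For trained networks these laws hold only approximately, so a practical check would presumably compare reconstructions with a tolerance and take the argmax of the getter's probability vector instead of using exact equality. That is an assumption about how such a check could be done, not a description of what `train.py` does.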

If you want to produce more plots from previously trained models, you can pass them to `train.py`:
```
python3 train.py --getter=trained_getter.pt --putter=trained_putter.pt
```

Otherwise, there are two main entry points:
1. `run_decay_test.py` runs the accuracy decay experiment shown in Figure 7.
2. `produce_plots.py` produces the plots for Figures 2 and 7 using the included pretrained models and data.

The model definitions and training loops are in `mnist_manipulator.py`, and the hyperparameter setup and training schedule for the accuracy decay experiments are in `training_schedules.py`. The larger model used to produce Figure 2 is defined in `model_figure5a.py`.

Trained models for producing Figure 2 are included (`outputs/models/figure5a_putter.pt` and `outputs/models/figure5a_getter.pt`), and `produce_plots.py` shows how to initialize and use them. Log files, plots, and trained models are also included for one run each of the accuracy decay experiment with methods (a), (b), and (c), as discussed in Appendix F.2. These are stored under `outputs/` (e.g. `outputs/plots`); see `run_decay_test.py` and `training_schedules.py` to understand what the plots mean and how to initialize the models. Accuracy data for additional runs is included in `outputs/data/`, which is enough to reproduce Figure 7 with `produce_plots.py`.
