# Code for Spatially Structured Recurrent Modules

To use this code base, you will need [miniconda](https://docs.conda.io/en/latest/miniconda.html). 

Once miniconda is installed, install external dependencies with: 
```bash
conda env create -f conda_environ.yml
``` 

From this directory, call `python pull_data.py` to download the datasets. 
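The training commands below rely on `speedrun` being importable from the active environment. The sketch below offers one quick way to check this. `check_importable` is a hypothetical helper, not part of this repository, and `/path/to/speedrun` is a placeholder for wherever your copy lives.

```bash
# Hypothetical helper: succeed if the given Python module can be imported.
# PYTHON defaults to `python`; override it (e.g. PYTHON=python3) if needed.
check_importable() {
  "${PYTHON:-python}" -c "import $1" >/dev/null 2>&1
}

if check_importable speedrun; then
  echo "speedrun is importable"
else
  # Placeholder path: point this at your own speedrun checkout.
  echo "speedrun not found; try: export PYTHONPATH=/path/to/speedrun:\$PYTHONPATH" >&2
fi
```

Run this with the conda environment activated, so that `python` resolves to the environment's interpreter.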

Make sure `speedrun` is importable (e.g., on your `PYTHONPATH`), then run the bouncing ball experiments as follows:
```bash
python train_wm_bb.py experiments/BB/S2GRU-0 --inherit templates/BB/S2GRU-X
``` 
The TensorBoard logs are written to `experiments/BB/S2GRU-0/Logs`.
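To inspect these logs, point TensorBoard at the experiment's `Logs` directory. The sketch below wraps this in a small function. `show_logs` is a hypothetical name, and it assumes `tensorboard` is installed in the active environment. `--logdir` and `--port` are standard TensorBoard flags, and 6006 is TensorBoard's default port.

```bash
# Hypothetical helper: serve the TensorBoard logs written under EXPERIMENT_DIR/Logs.
show_logs() {
  local exp_dir="${1:?usage: show_logs EXPERIMENT_DIR [PORT]}"
  # Second argument optionally overrides the port (default 6006).
  tensorboard --logdir "$exp_dir/Logs" --port "${2:-6006}"
}

# Example (blocks until interrupted; then open http://localhost:6006):
# show_logs experiments/BB/S2GRU-0
```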

Our experiments ran on a V100-32GB GPU for 200 epochs. If you run out of GPU memory, consider using a smaller batch size:
```bash
python train_wm_bb.py experiments/BB/S2GRU-0 --inherit templates/BB/S2GRU-X --config.data.kwargs.batch_size 16
``` 
Batch size 16 should use roughly half the GPU memory, 8 roughly a quarter, and so on. Note, however, that a smaller batch size may affect performance.

To run a baseline model, say an LSTM, use:
```bash
python train_wm_bb.py experiments/BB/LSTM-0 --inherit templates/BB/LSTM-X
``` 

To run the StarCraft II experiments, use:
```bash
python train_wm_sc2.py experiments/SC2/S2GRU-0 --inherit templates/SC2/S2GRU-X
```

