# LookWhere

# Setup

This repository is based on [deit3-jax](https://github.com/affjljoo3581/deit3-jax). To pretrain on ImageNet, follow the setup instructions in their README.

# Docker
If you are on Windows, make sure WSL2 is installed and clone this repo into a folder inside WSL (Ubuntu). Then follow these instructions to set up the container:
1. Install Docker using the [instructions here](https://docs.docker.com/engine/install/ubuntu/). If you are on Linux, also follow the post-installation steps for Linux linked at the bottom of that page.
2. Install the NVIDIA Container Toolkit using these [instructions](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). Verify that you can run the sample workload linked at the bottom of that page.
3. Enter WSL if using Windows.
4. Run `nvidia-smi` and note your CUDA version. Change the CUDA version of the base image in `dockerfile` to match. Note that the base image version has three components (`major.minor.patch`).
5. Run `chmod +x build.sh && ./build.sh` to build the image.
6. Install the VS Code Dev Containers extension. Press `ctrl+P` and select `> Rebuild and reopen in dev-container`.
7. This opens the container with the JAX environment set up. To test it, first run `python3 ./sanity/test.py` and verify that you see the message `JAX is using the GPU!`. Then run `python3 sanity/pytorch_dataloading.py` and check that an MNIST classifier trains to `0.96` accuracy within a few seconds.
8. Enjoy 🎉
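Step 4 can be partly scripted. The sketch below is a hypothetical helper, not part of this repository: `cuda_version_from_smi` pulls the CUDA version out of the `nvidia-smi` header, and `to_three_part` pads it to the three-component form used by the base image. Padding a missing patch number with `0` is an assumption; check that a base image tag with that exact version exists before editing `dockerfile`.

```bash
#!/usr/bin/env bash
# Hypothetical helpers for step 4 (not part of this repo).

# Read `nvidia-smi` output on stdin and print the reported CUDA version,
# e.g. "12.2" from a header line containing "CUDA Version: 12.2".
cuda_version_from_smi() {
  sed -n 's/.*CUDA Version: *\([0-9][0-9.]*\).*/\1/p' | head -n 1
}

# Pad a version to three components. ASSUMPTION: a missing patch
# (or minor) number is treated as 0, e.g. "12.2" -> "12.2.0".
to_three_part() {
  local v="$1"
  case "$v" in
    *.*.*) echo "$v" ;;        # already major.minor.patch
    *.*)   echo "$v.0" ;;      # major.minor -> append patch 0
    *)     echo "$v.0.0" ;;    # major only
  esac
}

# Only runs where nvidia-smi is available (e.g. inside WSL with GPU support).
if command -v nvidia-smi >/dev/null 2>&1; then
  smi_version="$(nvidia-smi | cuda_version_from_smi)"
  echo "nvidia-smi reports CUDA $smi_version"
  echo "three-part form: $(to_three_part "$smi_version")"
  # Show lines of the dockerfile that mention CUDA (assumes repo root as cwd).
  grep -n -i 'cuda' dockerfile || true
fi
```

On a machine without `nvidia-smi` the final check is skipped, but both functions still work on saved `nvidia-smi` output.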

# Testing Zoom Code Shapes
After making changes, test the shapes with:
```bash
$ ./config/test.sh
```

# Running a sweep
An example script is included for sweeping hyperparameters during ImageNet pretraining. Start it by running:
```bash
$ ./config/run.sh --start <sweep_start_idx> --end <sweep_end_idx>
```

Note that the script above expects `GCS_DATASET_DIR`, `GCS_MODEL_DIR`, and `WANDB_API_KEY` to be set (as described in the deit3-jax instructions).
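Because a missing variable may only surface after the sweep has started, a quick pre-flight check can help. The `require_env` function below is a hypothetical helper, not part of this repository, and it assumes the three values are exported as environment variables visible to `./config/run.sh`.

```bash
#!/usr/bin/env bash
# Hypothetical pre-flight check (not part of this repo).
# ASSUMPTION: run.sh reads these values from the environment.

# Return non-zero and report every named variable that is unset or empty.
require_env() {
  local missing=0 name
  for name in "$@"; do
    # ${!name} is bash indirect expansion: the value of the variable named $name.
    if [ -z "${!name:-}" ]; then
      echo "error: $name is not set" >&2
      missing=1
    fi
  done
  return "$missing"
}

# Usage before starting a sweep:
#   require_env GCS_DATASET_DIR GCS_MODEL_DIR WANDB_API_KEY && \
#     ./config/run.sh --start <sweep_start_idx> --end <sweep_end_idx>
```

Reporting all missing names at once, rather than stopping at the first, avoids repeated restarts when several variables are unset.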