# Codebase for MoEUT

The official repository for our paper "MoEUT: Mixture-of-Experts Universal Transformers".

## Installation

This project requires Python 3.10 and PyTorch 2.2.

```bash
pip3 install -r requirements.txt
```

Create a Weights and Biases account and run:
```bash
wandb login
```

More information on setting up Weights and Biases can be found at
https://docs.wandb.com/quickstart.

For plotting, LaTeX is required (to avoid Type 3 fonts and to render symbols). Installation is OS-specific.

## Usage

The code makes use of Weights and Biases for experiment tracking. In the "sweeps" directory, we provide sweep configurations for all experiments we have performed.

To reproduce our results, start a sweep for each of the YAML files in the "sweeps" directory, then run `wandb agent` for each sweep from the main directory. This will run all the experiments, and they will be displayed on the W&B dashboard.
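
The sweep workflow above can be scripted. The following is a minimal sketch and not part of the repository: `sweep_commands` is a hypothetical helper, and it assumes that the sweep configs use the `.yaml` extension. It only builds the `wandb sweep` commands instead of running them. Each `wandb sweep` call reports a sweep ID, which you then pass to `wandb agent`.

```python
from pathlib import Path


def sweep_commands(sweep_dir: str = "sweeps") -> list[str]:
    """Return one `wandb sweep` command per YAML file found under sweep_dir.

    Hypothetical helper: assumes the configs end in ".yaml".
    """
    root = Path(sweep_dir)
    if not root.is_dir():
        return []
    # rglob also picks up configs stored in subdirectories, if there are any.
    configs = sorted(root.rglob("*.yaml"))
    return [f"wandb sweep {cfg.as_posix()}" for cfg in configs]


if __name__ == "__main__":
    for cmd in sweep_commands():
        print(cmd)
    # After creating a sweep, start a worker from the main directory with:
    print("wandb agent <sweep_id>")
```

Run the printed commands from the main directory, and replace `<sweep_id>` with the ID reported when each sweep is created.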

### Re-creating plots from the paper

Edit the config file "paper/moe_universal/config.json" and enter your project name in the field "wandb_project" (e.g. "username/moeut"). Copy the checkpoints of your runs to `paper/moe_universal/checkpoints/<run_id>/model.ckpt`. Then run "paper/moe_universal/run_tests.py" to run additional validations on zero-shot downstream tasks. This will take a long time.
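
The preparation steps can also be done programmatically. The sketch below is an illustration only: `prepare_paper_dir` is a hypothetical helper, and it assumes that "config.json" is a JSON object with "wandb_project" as a top-level key. The location of your source checkpoints depends on your setup, so it is passed in explicitly.

```python
import json
import shutil
from pathlib import Path


def prepare_paper_dir(paper_dir, wandb_project, checkpoints):
    """Set the W&B project in config.json and copy checkpoints into place.

    checkpoints maps a W&B run ID to the path of that run's checkpoint file.
    Hypothetical helper: assumes "wandb_project" is a top-level JSON key.
    """
    paper_dir = Path(paper_dir)
    config_path = paper_dir / "config.json"

    # Update only the project field and keep all other settings as they are.
    config = json.loads(config_path.read_text())
    config["wandb_project"] = wandb_project
    config_path.write_text(json.dumps(config, indent=4) + "\n")

    # Copy each checkpoint to checkpoints/<run_id>/model.ckpt.
    for run_id, src in checkpoints.items():
        dst = paper_dir / "checkpoints" / run_id / "model.ckpt"
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
```

Call it with "paper/moe_universal" as `paper_dir`, your project name (e.g. "username/moeut"), and a dictionary from run IDs to the checkpoint files you want to evaluate.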

To reproduce a specific plot or table, run the script of interest within the "paper" directory. For example:

```bash
cd paper/moe_universal
python3 main_result_table.py
```


## Structure
```
├───cache - temporary files automatically generated by this code
├───framework - reusable library for running experiments
│    ├─  datasets - a large collection of diverse datasets
│    ├─  visualize - universal plotting functions working for TF and W&B
│    ├─  helpers - helper routines for cluster, wandb and training setup
│    ├─  utils - useful utils (downloaders, samplers, multiprocessing)
│    ├─  layers - useful layers
│    ├─  tasks - main training loop, etc.
│    └─  * - all kinds of reusable components
│
├───save - saved checkpoints and training state
├───sweeps - Weights and Biases experiment configs
├───tasks - experiments. Add new experiments as new files here; they will be picked up automatically.
├───main.py - initialization code
└───cluster.json - configuration for the ClusterTool
```


## Useful built-in arguments

- `-task`: which task to use. Tasks are picked up from the `tasks` directory automatically. See how to create a new task in the `Creating a new task` chapter.
- `-name`: state will be saved in `save/<name>` folder.
- `-restore <checkpoint file>`: restores everything, including the command line arguments, from a checkpoint file. If any other argument is specified, it overrides the one found in the checkpoint.
- `-reset 1`: do not load checkpoint from `save/<name>` but restart training.
- `-log`: can be `tb` for TensorBoard or `wandb` for Weights & Biases. All supported plot types are defined in `framework/visualize/plot.py` and support logging on both. If `tb` is specified, the run will start a TensorBoard session on port 7000 (or the next available one).
- `-gpu <index>`: which GPU to use. Leave empty for allocating the next empty one.
- `-lr <learning rate>`: specify learning rate
- `-batch_size <batch size>`: specify batch size
- `-wd`: weight decay
- `-stop_after <n_iters>`: terminate after this many iterations. It also sets the number of steps for the LR scheduler, if one is used.
- `-amp 1`: use mixed-precision training
- `-grad_clip <max norm>`: clip gradients to this max norm. 1 by default. Specify `none` to disable.
- `-lr_sched.type cos`: use cosine learning rate decay
- `-lr_warmup <n_iters>`: use linear LR warmup for this many steps.
- `-load_pretrained_model <checkpoint file>`: loads the model only, but not the arguments, optimizer state, etc., from a checkpoint.
- `-length_bucketed_sampling 1`: groups examples of similar length into batches to save compute wasted for padding. Only works for some datasets.
- `-save_interval <n_iters>`: how often to save checkpoints.
- `-test_interval <n_iters>`: how often to run automatic tests.
- `-per_device_batch_size <batch size>`: specify the per-GPU batch size. Microbatching (gradient accumulation) will be used to ensure that the actual batch size on each device is at most the specified value. Uneven division is supported.
- `-n_microbatch <number of microbatching steps>`: manually specify the number of microbatches. Mutually exclusive with `per_device_batch_size`.
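
To illustrate the arithmetic behind `-per_device_batch_size`, the sketch below splits one device's share of a batch into microbatches that never exceed the limit, including the uneven case. `microbatch_sizes` is a hypothetical helper written for this README; the framework's own splitting logic may distribute the remainder differently.

```python
import math


def microbatch_sizes(batch_size: int, per_device_batch_size: int) -> list[int]:
    """Split batch_size into the fewest microbatches of at most
    per_device_batch_size examples each, as evenly as possible.

    Hypothetical helper, not the framework's implementation.
    """
    if batch_size <= 0 or per_device_batch_size <= 0:
        raise ValueError("both sizes must be positive")
    # Fewest accumulation steps that keep every microbatch within the limit.
    n_microbatch = math.ceil(batch_size / per_device_batch_size)
    # Spread the remainder over the first microbatches (uneven division).
    base, remainder = divmod(batch_size, n_microbatch)
    return [base + 1 if i < remainder else base for i in range(n_microbatch)]


if __name__ == "__main__":
    print(microbatch_sizes(64, 24))  # [22, 21, 21]
```

With a batch of 64 and a limit of 24, three accumulation steps are needed, and no step processes more than 24 examples.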

There are many other useful default arguments, defined in `framework/task/task.py`, `framework/task/simple_task.py` and `framework/helpers/training_helper.py`.
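
As an illustration of how `-lr`, `-lr_warmup`, `-lr_sched.type cos` and `-stop_after` could interact, the following sketch computes a learning rate with linear warmup followed by cosine decay. It is a sketch under stated assumptions, not the framework's scheduler: `lr_at` and its `min_lr` parameter are hypothetical, and it assumes that the cosine phase starts after warmup and ends at `stop_after`.

```python
import math


def lr_at(step: int, base_lr: float, warmup_steps: int, total_steps: int,
          min_lr: float = 0.0) -> float:
    """Learning rate at a given step: linear warmup, then cosine decay.

    Hypothetical helper. Assumes the decay runs from the end of warmup
    to total_steps and ends at min_lr.
    """
    if step < warmup_steps:
        # Linear warmup: reaches base_lr on the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

Under these assumptions, `lr_at(step, lr, lr_warmup, stop_after)` rises linearly to `lr` during warmup and then decays smoothly to `min_lr` at the final step.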

