# Codebase for MoEUT

The official training repository for our paper ["MoEUT: Mixture-of-Experts Universal Transformers"](https://arxiv.org/abs/2405.16039). This is the codebase we used to develop the model, and it is quite messy.

If you are looking for a short, easy-to-use, cleaned-up version, please take a look at [https://github.com/robertcsordas/moeut](https://github.com/robertcsordas/moeut).

## Installation

This project requires Python >= 3.10 and PyTorch >= 2.2.

```bash
pip3 install -r requirements.txt
```

Create a Weights and Biases account and run
```bash
wandb login
```

More information on setting up Weights and Biases can be found at
https://docs.wandb.com/quickstart.

For plotting, LaTeX is required (to avoid Type 3 fonts and to render symbols). Installation is OS-specific.

## Pretrained Model Checkpoints

We released the model checkpoints from our paper for all of our MoEUT models. They can be found at https://huggingface.co/robertcsordas/moeut_training_checkpoints.

**NOTE**: These are not production-quality pretrained models, only a proof of concept. Because of our limited resources, they were trained on only 6.5B tokens, which is very little by modern standards.

The structure of the checkpoint repository:
```
├───cache - Tokenizers for the different datasets
└───checkpoints - Model checkpoints
```

The ``cache`` folder contains our tokenizers and must be copied to the root of this repository; otherwise, minor differences can arise if your version of SentencePiece differs from ours. When a task is run, it automatically downloads the necessary data and tokenizes it. Only the amount of data actually needed for training/evaluation is tokenized, to avoid wasting space and time.

The ``checkpoints`` folder contains all the model checkpoints (without the optimizer state, which we removed to save space). A checkpoint can be loaded with ``-restore checkpoints/C4_44M.ckpt``, which automatically restores all configurations used for training.

To run a simple validation pass:

```bash
python3 main.py -restore checkpoints/C4_44M.ckpt -test_only 1 -log tb -name test -reset 1 -lm.eval.enabled 0 -stop_after 0
```

The flag ``-log tb`` switches to TensorBoard logging instead of W&B, which was used for the training run. ``-lm.eval.enabled 0`` disables the costly downstream evaluations. ``-stop_after 0`` is a hack to avoid wasting an excessive amount of time tokenizing training data that will not be used for evaluation anyway (sorry, this could be handled better). For the other flags, see the details at the end of this doc.

## Usage

The code makes use of Weights and Biases for experiment tracking. In the "sweeps" directory, we provide sweep configurations for all experiments we have performed.

To reproduce our results, create a sweep for each of the YAML files in the "sweeps" directory (e.g. with `wandb sweep <file>`), then run `wandb agent` for each of them from the repository root. This will run all the experiments, and they will be displayed on the W&B dashboard.

## ClusterTool

The code is designed to work with [ClusterTool](https://github.com/RobertCsordas/cluster_tool).

If used with ClusterTool, W&B sweeps, run preemption, file synchronization, etc. are handled automatically.

## Re-creating plots from the paper

Edit the config file `paper/moe_universal/config.json`: enter your project name in the `wandb_project` field (e.g. `username/moeut`). Copy the checkpoints of your runs to `paper/moe_universal/checkpoints/<run_id>/model.ckpt`. Then run `paper/moe_universal/run_tests.py` to run additional validations on zero-shot downstream tasks. This will take a long time.
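The two manual setup steps above (setting `wandb_project` and placing checkpoints) can be scripted. The sketch below is a hypothetical helper (`prepare_paper_dir` is not part of this repository). It assumes `config.json` is a JSON object and that you already know the W&B run ID and local checkpoint path of each run. It only updates `wandb_project` and leaves every other field in the config untouched.

```python
import json
import shutil
from pathlib import Path


def prepare_paper_dir(repo_root, wandb_project, run_checkpoints):
    """Hypothetical helper: set "wandb_project" in config.json and copy each
    run's checkpoint to paper/moe_universal/checkpoints/<run_id>/model.ckpt.

    run_checkpoints maps a W&B run ID to the local path of that run's checkpoint.
    """
    paper_dir = Path(repo_root) / "paper" / "moe_universal"
    config_path = paper_dir / "config.json"

    # Load the existing config so that all other fields are preserved.
    config = json.loads(config_path.read_text())
    config["wandb_project"] = wandb_project
    config_path.write_text(json.dumps(config, indent=4))

    # Copy each checkpoint into the directory layout described above.
    for run_id, ckpt in run_checkpoints.items():
        target = paper_dir / "checkpoints" / run_id / "model.ckpt"
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(ckpt, target)
```

For example, calling it from the repository root with `"."`, your project name, and a dictionary mapping each run ID to the path of its checkpoint prepares everything `run_tests.py` expects, as described above.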

To reproduce a specific plot or table, run the corresponding script within the "paper" directory. For example:

```bash
cd paper/moe_universal
python3 main_result_table.py
```

## Structure
```
├───cache - temporary files automatically generated by this code
├───framework - reusable library for running experiments
│    ├─  datasets - a large collection of diverse datasets
│    ├─  visualize - universal plotting functions working for TensorBoard and W&B
│    ├─  helpers - helper routines for cluster, wandb and training setup
│    ├─  utils - useful utils (downloaders, samplers, multiprocessing)
│    ├─  layers - useful layers
│    ├─  tasks - main training loop, etc.
│    └─  * - all kinds of reusable components
│
├───save - saved checkpoints and training state
├───sweeps - Weights and Biases experiment configs
├───tasks - experiments. Add new experiments as new files here; they will be picked up automatically.
├───main.py - initialization code
└───cluster.json - configuration for the ClusterTool
```


## Useful built-in arguments

- `-task`: which task to use. Tasks are picked up automatically from the `tasks` directory; to create a new task, add a new file there.
- `-name`: state will be saved in the `save/<name>` folder. Required when using TensorBoard logging.
- `-restore <checkpoint file>`: restores everything, including the command line arguments, from a checkpoint file. Any argument specified explicitly on the command line overrides the value stored in the checkpoint.
- `-reset 1`: do not load the checkpoint from `save/<name>`; restart training instead.
- `-log`: can be `tb` for TensorBoard or `wandb` for Weights & Biases. All supported plot types are defined in `framework/visualize/plot.py` and support logging to both. If `tb` is specified, the run will start a TensorBoard session on port 7000 (or the next available one).
- `-gpu <index>`: which GPU to use. If omitted, the next free GPU is allocated.
- `-lr <learning rate>`: specify learning rate
- `-batch_size <batch size>`: specify batch size
- `-wd`: weight decay
- `-stop_after <n_iters>`: terminate after this many iterations. It also sets the number of steps for the LR scheduler, if one is used.
- `-amp 1`: use mixed-precision training
- `-grad_clip <max norm>`: clip gradients to this max norm. 1 by default. Specify `none` to disable.
- `-lr_sched.type cos`: use cosine learning rate decay.
- `-lr_warmup <n_iters>`: use linear LR warmup for this many steps.
- `-load_pretrained_model <checkpoint file>`: loads only the model, not the arguments, optimizer state, etc., from a checkpoint.
- `-length_bucketed_sampling 1`: groups examples of similar length into batches to reduce compute wasted on padding. Only works for some datasets.
- `-save_interval <n_iters>`: how often to save checkpoints.
- `-test_interval <n_iters>`: how often to run automatic tests.
- `-test_only 1`: run only a validation pass.
- `-per_device_batch_size <batch size>`: specify the per-GPU batch size. Microbatching (gradient accumulation) will be used to ensure that the actual per-GPU batch size does not exceed the specified value. Uneven division is supported.
- `-n_microbatch <number of microbatching steps>`: manually specify the number of microbatches. Mutually exclusive with `per_device_batch_size`.
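For intuition on how `-per_device_batch_size` can work with uneven division, here is a minimal sketch of one way to split a batch into the fewest microbatches that fit under the limit. `plan_microbatches` is hypothetical; the exact splitting and loss weighting used by this codebase may differ.

```python
import math


def plan_microbatches(batch_size, max_microbatch_size):
    """Hypothetical sketch: split batch_size into the fewest microbatches whose
    sizes do not exceed max_microbatch_size, keeping the sizes as even as possible."""
    if batch_size < 1 or max_microbatch_size < 1:
        raise ValueError("sizes must be positive")
    n = math.ceil(batch_size / max_microbatch_size)
    base, rem = divmod(batch_size, n)
    # The first `rem` microbatches get one extra example (uneven division).
    sizes = [base + 1 if i < rem else base for i in range(n)]
    # Weight each microbatch's mean loss by its share of the full batch, so the
    # accumulated gradient equals the gradient of the full-batch mean loss.
    weights = [s / batch_size for s in sizes]
    return sizes, weights
```

With a batch of 64 and a limit of 24, this yields microbatches of 22, 21 and 21. In a PyTorch training loop, one would call `(loss * w).backward()` for each microbatch and step the optimizer once per full batch.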

There are many other useful default arguments, defined in `framework/task/task.py`, `framework/task/simple_task.py` and `framework/helpers/training_helper.py`.
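As a rough illustration of how `-lr`, `-lr_warmup`, `-lr_sched.type cos` and `-stop_after` interact, the sketch below computes a learning rate with linear warmup followed by cosine decay that reaches `min_lr` at the last step. This rests on assumptions: whether this codebase counts the warmup steps inside `stop_after`, or decays all the way to zero, is not stated here, and `lr_at` is a hypothetical helper.

```python
import math


def lr_at(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Hypothetical sketch: linear warmup, then cosine decay to min_lr.

    Assumes total_steps (e.g. the -stop_after value) includes the warmup steps.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup: reaches base_lr at the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase that has elapsed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

To use such a schedule in plain PyTorch, one could pass `lambda s: lr_at(s, 1.0, warmup, total)` as the `lr_lambda` of `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's initial learning rate by the returned factor.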

## Known issues

Triton seems to be broken on Volta GPUs when using float16 from PyTorch 2.2 onwards (see [GitHub issue](https://github.com/pytorch/pytorch/issues/127157)). Until the PyTorch team fixes the issue, please downgrade to PyTorch 2.1 or disable AMP if you have Volta GPUs.
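If you are unsure whether this issue affects your machine, a check along these lines may help. `should_disable_amp` is a hypothetical helper. It assumes that Volta corresponds to CUDA compute capability 7.0 or 7.2 (Turing is 7.5) and treats every PyTorch release from 2.2 onwards as affected, following the note above.

```python
def should_disable_amp(capability, torch_version):
    """Hypothetical check: True for a Volta GPU running PyTorch >= 2.2."""
    major, minor = capability
    # Volta reports compute capability 7.0 (e.g. V100) or 7.2; Turing is 7.5.
    is_volta = major == 7 and minor < 5
    # Keep the numeric "major.minor" part of strings like "2.2.0+cu121".
    version = tuple(int(p) for p in torch_version.split("+")[0].split(".")[:2])
    return is_volta and version >= (2, 2)


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        print("Disable AMP:", should_disable_amp(cap, torch.__version__))
```

If it prints `True`, leave mixed precision off (presumably `-amp 0`, the off value of the `-amp` flag listed above) or install PyTorch 2.1.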


