# Training codebase for "Mixture of Sparse Attention: Content-Based Learnable Sparse Attention via Expert-Choice Routing"

## How to reproduce

Run the experiments in the `sweeps` directory, or run the configurations in `.vscode/launch.json`.

`launch.json` contains two types of configurations:
```
{MODEL}_{size}{sparsity}, e.g. MoSA_tiny2, which represents isoflop training of MoSA of size tiny with sparsity 2,
{MODEL}_sw{sequence_length}, e.g. routing_sw4096, which represents training Routing attention with sliding window attention on sequences of length 4096
```
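
The naming scheme above can be decoded mechanically. The following is a minimal sketch of such a parser; `parse_config_name` and its regular expressions are hypothetical helpers written for illustration and are not part of this repository. It assumes that model names contain only letters, that sizes are lowercase words, and that sparsity and sequence length are integers.

```python
import re

# Hypothetical helper, not part of this codebase. The patterns assume
# letter-only model names, lowercase size names and integer suffixes.
SLIDING_WINDOW_RE = re.compile(r"^(?P<model>[A-Za-z]+)_sw(?P<seq_len>\d+)$")
ISOFLOP_RE = re.compile(r"^(?P<model>[A-Za-z]+)_(?P<size>[a-z]+)(?P<sparsity>\d+)$")


def parse_config_name(name: str) -> dict:
    """Split a launch configuration name into its components."""
    # Check the sliding-window form first: "sw4096" would otherwise also
    # match the isoflop pattern as size "sw" with sparsity 4096.
    m = SLIDING_WINDOW_RE.match(name)
    if m:
        return {
            "kind": "sliding_window",
            "model": m["model"],
            "sequence_length": int(m["seq_len"]),
        }
    m = ISOFLOP_RE.match(name)
    if m:
        return {
            "kind": "isoflop",
            "model": m["model"],
            "size": m["size"],
            "sparsity": int(m["sparsity"]),
        }
    raise ValueError(f"Unrecognized configuration name: {name!r}")


if __name__ == "__main__":
    print(parse_config_name("MoSA_tiny2"))
    print(parse_config_name("routing_sw4096"))
```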

## Structure
```
├───cache - temporary files automatically generated by this code
├───framework
│    ├─  datasets - a large collection of diverse datasets
│    ├─  helpers - helper routines for cluster, wandb and training setup
│    ├─  utils - useful utils (downloaders, samplers, multiprocessing)
│    ├─  layers - useful layers
│    ├─  tasks - main training loop, etc.
│    └─  * - all kinds of reusable components
│
├───save - saved checkpoints and training state
├───sweeps - Weights and Biases experiment configs
├───tasks - experiments. Add new experiments as new files here; they will be picked up automatically.
├───main.py - initialization code
└───cluster.json - configuration for the ClusterTool
```

## ClusterTool

The code is designed to work with [ClusterTool](https://github.com/RobertCsordas/cluster_tool). To use it, set the W&B project name in `cluster.json` in _this_ directory.

Example on how to run the experiments:
```bash
ct -s -m sc -gt all_not_a100 -sp low wandb sweep sweeps/gru_repeat.yaml
```

The meaning of flags:
 - `-s -m sc` run it on the SLURM cluster called sc
 - `-gt all_not_a100` on any GPU type that is not A100
 - `-sp low` use low SLURM priority

If used with ClusterTool, W&B sweeps, run preemption, file synchronization, etc. are handled automatically.

## Useful built-in arguments

- `-task`: which task to use. Tasks are picked up from the `tasks` directory automatically. See how to create a new task in the `Creating a new task` chapter.
- `-name`: state will be saved in `save/<name>` folder.
- `-restore <checkpoint file>`: restores everything, including the command line arguments, from a checkpoint file. If any other argument is specified, it overwrites the one found in the checkpoint.
- `-reset 1`: do not load checkpoint from `save/<name>` but restart training.
- `-log`: can be `tb` for TensorBoard or `wandb` for Weights & Biases. All supported plot types are defined in `framework/visualize/plot.py` and can be logged to both. If `tb` is specified, the run will start a TensorBoard session on port 7000 (or the next available port).
- `-gpu <index>`: which GPU to use. Leave empty for allocating the next empty one.
- `-lr <learning rate>`: specify learning rate
- `-batch_size <batch size>`: specify batch size
- `-wd`: weight decay
- `-stop_after <n_iters>`: terminate after this many iterations. It also sets the number of steps for the LR scheduler, if one is used.
- `-amp 1`: use mixed-precision training
- `-grad_clip <max norm>`: clip gradients to this max norm. 1 by default. Specify `none` to disable.
- `-lr_sched.type cos`: use cosine learning rate decay
- `-lr_warmup <n_iters>`: use linear LR warmup for this many steps.
- `-load_pretrained_model <checkpoint file>`: loads only the model from a checkpoint, but not the arguments, optimizer state, etc.
- `-length_bucketed_sampling 1`: groups examples of similar length into batches to reduce the compute wasted on padding. Only works for some datasets.
- `-save_interval <n_iters>`: how often to save checkpoints.
- `-test_interval <n_iters>`: how often to run automatic tests.
- `-per_device_batch_size <batch size>`: specify the per-GPU batch size. Microbatching (gradient accumulation) is used to ensure that the batch size actually processed on each GPU at once does not exceed the specified value. Uneven division is supported.
- `-n_microbatch <number of microbatching steps>`: manually specify the number of microbatches. Mutually exclusive with `per_device_batch_size`.
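
To illustrate how `-per_device_batch_size` can translate into microbatches, here is a minimal sketch. The function `split_into_microbatches` and the near-even splitting strategy are assumptions made for this example; the framework's actual implementation may divide the batch differently.

```python
import math
from typing import List


def split_into_microbatches(local_batch_size: int, per_device_batch_size: int) -> List[int]:
    """Return microbatch sizes that sum to local_batch_size, each at most
    per_device_batch_size. Hypothetical helper, not part of this codebase."""
    if local_batch_size <= 0 or per_device_batch_size <= 0:
        raise ValueError("Batch sizes must be positive.")
    # Smallest number of gradient accumulation steps that respects the limit.
    n_microbatch = math.ceil(local_batch_size / per_device_batch_size)
    # Spread the examples as evenly as possible (handles uneven division).
    base, remainder = divmod(local_batch_size, n_microbatch)
    return [base + 1 if i < remainder else base for i in range(n_microbatch)]


if __name__ == "__main__":
    # 64 examples on this GPU, at most 24 per forward/backward pass.
    print(split_into_microbatches(64, 24))
```

Gradients from the microbatches would then be accumulated before a single optimizer step, so the effective batch size stays at the value given by `-batch_size`.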

There are many other useful default arguments, defined in `framework/task/task.py`, `framework/task/simple_task.py` and `framework/helpers/training_helper.py`.
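
As an illustration of how `-lr`, `-lr_warmup`, `-lr_sched.type cos` and `-stop_after` from the list above can interact, the following sketch computes a learning rate with linear warmup followed by cosine decay. The function `lr_at_step`, the decay to `min_lr = 0`, and the choice to count warmup steps as part of the `stop_after` budget are all assumptions for this example, not a description of the framework's scheduler.

```python
import math


def lr_at_step(step: int, base_lr: float, lr_warmup: int, stop_after: int,
               min_lr: float = 0.0) -> float:
    """Linear warmup followed by cosine decay. Hypothetical sketch, not the
    scheduler used by this codebase."""
    if step < lr_warmup:
        # Linear ramp that reaches base_lr on the last warmup step.
        return base_lr * (step + 1) / lr_warmup
    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(1, stop_after - lr_warmup)
    progress = min(1.0, (step - lr_warmup) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for s in (0, 9, 10, 55, 100):
        print(s, lr_at_step(s, base_lr=1e-3, lr_warmup=10, stop_after=100))
```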
