# B2S6

This is the code repository accompanying the manuscript "Block-Biased Mamba for Long-Range Sequence Processing." The repository is heavily adapted from the "state-spaces" GitHub repository (https://github.com/HazyResearch/state-spaces.git) and the PyTorch implementation of Mamba (https://github.com/alxndrTL/mamba.py). While it contains references to existing papers and code repositories, it includes no information that reveals the identities of the manuscript authors.

Disclaimer: All code in this repository is implemented purely in PyTorch and is significantly slower than a hardware-aware [Mamba model](https://github.com/state-spaces/mamba). Implementing B2S6 more efficiently would require modifying the CUDA code and is important future work.

## Setup

### Requirements
This repository requires Python 3.9+ and PyTorch 1.10+.
It has been tested up to PyTorch 1.13.1.
Other packages are listed in [requirements.txt](./requirements.txt).
Some care may be needed to keep library versions compatible with each other, particularly torch/torchvision/torchaudio/torchtext.

Example installation:
```
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt
```

### Data

Basic datasets are auto-downloaded.
All logic for creating and loading datasets is in the [src/dataloaders](./src/dataloaders/) directory.
The README inside this subdirectory documents how to download and organize other datasets.

### Configs and Hyperparameters

Configurations can be changed in [configs/experiment/lra](./configs/experiment/lra/).

### WandB

Logging with [WandB](https://wandb.ai/site) is built into this repository.
To use it, set your `WANDB_API_KEY` environment variable and change the `wandb.project` attribute of [configs/config.yaml](configs/config.yaml) (or pass it on the command line, e.g., `python -m train .... wandb.project=b2s6`).

Set `wandb=null` to turn off WandB logging.

## Execution

### LRA Benchmarks

The Long-Range Arena benchmarks can be tested by running
```
python -m train experiment=lra/mamba-foo
```
where `foo` is the name of the problem, one of `listops`, `imdb`, `aan`, `cifar`, `pathfinder`, or `pathx`. IMPORTANT: Several manual changes are required:

1. Change `d_model` on [this line](src/models/sequence/modules/s4block.py#L1083) to the value of `config.model.d_model` in the configuration file divided by 8. Otherwise, you will get an `assert (d_input == d_model) or alpha == 0.0` error.
2. For the `ListOps` task, also change `n_state` on [this line](src/models/sequence/modules/s4block.py#L1084) from `64` to `4`.
3. To run a job with multiple GPUs, divide `config.loader.batch_size` by the number of GPUs.

These changes will be automated in a future release.
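
The arithmetic behind the manual changes above can be sketched as a small helper. This is only an illustration: `manual_lra_settings` is a hypothetical function that is not part of this repository, the example values `128` and `32` are made up, and the default `n_state` of `64` is taken from the ListOps instruction above.

```python
# Hypothetical helper, not part of this repository: it only computes the
# values that currently have to be edited by hand before a run.
LRA_TASKS = ("listops", "imdb", "aan", "cifar", "pathfinder", "pathx")


def manual_lra_settings(task, config_d_model, config_batch_size, num_gpus=1):
    """Return the launch command and the manually edited values for one task."""
    if task not in LRA_TASKS:
        raise ValueError(f"unknown LRA task: {task!r}")
    if config_d_model % 8 != 0:
        raise ValueError("config.model.d_model must be divisible by 8")
    if config_batch_size % num_gpus != 0:
        raise ValueError("batch size must be divisible by the number of GPUs")
    return {
        "command": f"python -m train experiment=lra/mamba-{task}",
        # Value to hard-code for d_model in s4block.py.
        "d_model": config_d_model // 8,
        # ListOps uses n_state = 4; the other tasks keep 64.
        "n_state": 4 if task == "listops" else 64,
        # Value to set for config.loader.batch_size on multi-GPU runs.
        "batch_size": config_batch_size // num_gpus,
    }


# Illustrative values only; read the real ones from the task's config file.
settings = manual_lra_settings(
    "listops", config_d_model=128, config_batch_size=32, num_gpus=2
)
for key, value in settings.items():
    print(f"{key}: {value}")
```

The helper raises an error instead of rounding when a value does not divide evenly, since a silently truncated `d_model` would likely trigger the assertion mentioned in step 1.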

### LLM Training

To train a B2S6 model on the SlimPajama-6B dataset, replace the [mamba_simple.py](https://github.com/HazyResearch/based/blob/main/based/models/mixers/mamba/modules/mamba_simple.py) file in the [based repository](https://github.com/HazyResearch/based/) with [B2S6_LLM.py](./B2S6_LLM.py) and download the dataset [here](https://huggingface.co/datasets/DKYoon/SlimPajama-6B).

### Experiment in Section 3

The experiment in Section 3 is implemented in [wavesum.py](./wavesum.py), where one replaces [Model](./wavesum.py#L92) with their favorite recurrent unit, e.g., [S4D](./models/s4/s4d.py), [S6](https://github.com/alxndrTL/mamba.py/blob/main/mambapy/mamba.py), or [B2S6](src/models/sequence/modules/s4block.py#L1079). Note that the default input and output shape of S4D is (B D L), while that of Mamba/B2S6 is (B L D). Please apply transposes accordingly.

### Experiment in Section 4

The experiment in Section 4 is implemented in [copyme.py](./copyme.py), where one replaces [Model](./copyme.py#L98) with their favorite recurrent unit, e.g., [S4D](./models/s4/s4d.py), [S6](https://github.com/alxndrTL/mamba.py/blob/main/mambapy/mamba.py), or [B2S6](src/models/sequence/modules/s4block.py#L1079). Note that the default input and output shape of S4D is (B D L), while that of Mamba/B2S6 is (B L D). Please apply transposes accordingly.
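
One way to handle the transposes in both `wavesum.py` and `copyme.py` is a thin wrapper around the (B D L) unit. The sketch below is an assumption about how this could be done, not code from the repository: `ChannelsFirstAdapter` is a hypothetical name, and `nn.Conv1d` stands in for S4D only because it also consumes (B D L) tensors.

```python
import torch
import torch.nn as nn


class ChannelsFirstAdapter(nn.Module):
    """Wrap a (B, D, L) module so it accepts and returns (B, L, D) tensors."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, L, D) -> (B, D, L) for the wrapped module.
        y = self.inner(x.transpose(1, 2))
        # Some layers return (output, state); keep only the output.
        if isinstance(y, tuple):
            y = y[0]
        # (B, D, L) -> (B, L, D) for the rest of the script.
        return y.transpose(1, 2)


# Stand-in for a (B, D, L) layer; swap in the actual unit as needed.
inner = nn.Conv1d(in_channels=8, out_channels=8, kernel_size=1)
model = ChannelsFirstAdapter(inner)

x = torch.randn(2, 16, 8)  # (B, L, D)
y = model(x)
print(y.shape)  # torch.Size([2, 16, 8])
```

With such a wrapper, the surrounding script can assume the (B L D) layout throughout, and Mamba/B2S6 can be used without any wrapper.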