# Long RNN

This repository is for training long-context RNNs such as Mamba and RWKV. We want to explore the upper bound of the long-context modeling capabilities of RNNs.

# Features

- Implemented with Accelerate + DeepSpeed for distributed training.
- Supports gradient checkpointing at the sequence level for training on longer sequences.
- Supports passing an initial state to Mamba-1, Mamba-2, and RWKV-5.
- Pure PyTorch implementation of Mamba-1's parallel associative scan.
- Supports analyzing the distribution of the states and the different components of the update rule in Mamba and RWKV models.
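
Sequence-level gradient checkpointing and initial-state passing fit together naturally: a long sequence can be split into chunks, each chunk's final state becomes the initial state of the next chunk, and only the chunk-boundary states need to be kept for the backward pass. The following is a minimal sketch of that idea, not this repository's implementation: `ToyDiagonalRNN` and `chunked_forward` are hypothetical names, and the toy model stands in for Mamba or RWKV. It relies only on PyTorch's `torch.utils.checkpoint.checkpoint`.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyDiagonalRNN(nn.Module):
    """Hypothetical stand-in for Mamba/RWKV: h_t = a * h_{t-1} + W x_t with an elementwise decay a."""

    def __init__(self, d_in: int, d_state: int):
        super().__init__()
        self.a_logit = nn.Parameter(torch.full((d_state,), 2.0))  # sigmoid(2.0) is about 0.88
        self.proj = nn.Linear(d_in, d_state)

    def forward(self, x: torch.Tensor, h0: torch.Tensor):
        # x: (batch, time, d_in); h0: (batch, d_state), the state carried in from the previous chunk
        a = torch.sigmoid(self.a_logit)
        u = self.proj(x)
        h, outs = h0, []
        for t in range(x.shape[1]):
            h = a * h + u[:, t]
            outs.append(h)
        return torch.stack(outs, dim=1), h  # per-step outputs and the final state


def chunked_forward(model: nn.Module, x: torch.Tensor, h0: torch.Tensor, chunk_len: int):
    """Run `model` over `x` chunk by chunk, feeding each chunk's final state into the next chunk.

    Each chunk runs under torch.utils.checkpoint, so its intermediate activations are
    discarded after the forward pass and recomputed during backward.
    """
    outs, h = [], h0
    for start in range(0, x.shape[1], chunk_len):
        y, h = checkpoint(model, x[:, start:start + chunk_len], h, use_reentrant=False)
        outs.append(y)
    return torch.cat(outs, dim=1), h


torch.manual_seed(0)
model = ToyDiagonalRNN(d_in=4, d_state=8)
x, h0 = torch.randn(2, 10, 4), torch.zeros(2, 8)
y_full, _ = model(x, h0)                                # ordinary full-sequence forward
y_chunked, _ = chunked_forward(model, x, h0, chunk_len=3)
print(torch.allclose(y_full, y_chunked, atol=1e-6))     # compares the outputs of both paths
```

With this scheme, activation memory during training is dominated by one chunk's intermediates plus one saved state per chunk boundary, rather than by every intermediate of the full sequence, at the cost of recomputing each chunk's forward pass during backward.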

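A parallel associative scan computes an elementwise linear recurrence `h_t = a_t * h_{t-1} + b_t` in a logarithmic number of steps instead of a sequential loop over time. Mamba-1's discretized selective SSM has roughly this form, with `a_t = exp(delta_t * A)` and `b_t = delta_t * B_t * x_t` per channel and state entry. The sketch below uses the Hillis-Steele variant (O(T log T) work) under that assumption; the function names are hypothetical and the repository's own scan may use a different variant or layout. It also shows how an initial state can be folded into the first input.

```python
import torch


def sequential_linear_scan(a, b, h0=None):
    """Reference loop for h_t = a_t * h_{t-1} + b_t; a and b have shape (batch, time, dim)."""
    h = torch.zeros_like(b[:, 0]) if h0 is None else h0
    outs = []
    for t in range(b.shape[1]):
        h = a[:, t] * h + b[:, t]
        outs.append(h)
    return torch.stack(outs, dim=1)


def parallel_linear_scan(a, b, h0=None):
    """Hillis-Steele inclusive scan over the associative operator
    (a1, b1) followed by (a2, b2) -> (a2 * a1, a2 * b1 + b2), in ceil(log2(T)) steps."""
    if h0 is not None:
        # Fold the initial state into the first input: b_0 <- a_0 * h0 + b_0.
        b = torch.cat([a[:, :1] * h0.unsqueeze(1) + b[:, :1], b[:, 1:]], dim=1)
    a_acc, b_acc = a, b
    offset = 1
    while offset < b_acc.shape[1]:
        # Position t combines with the segment ending at t - offset;
        # positions before `offset` already hold their final values.
        b_acc = torch.cat(
            [b_acc[:, :offset], a_acc[:, offset:] * b_acc[:, :-offset] + b_acc[:, offset:]], dim=1
        )
        a_acc = torch.cat([a_acc[:, :offset], a_acc[:, offset:] * a_acc[:, :-offset]], dim=1)
        offset *= 2
    return b_acc


torch.manual_seed(0)
a = torch.rand(2, 17, 5)   # decays in [0, 1), e.g. exp(delta * A) with A < 0
b = torch.randn(2, 17, 5)  # inputs, e.g. delta * B * x
h0 = torch.randn(2, 5)
print(torch.allclose(parallel_linear_scan(a, b, h0), sequential_linear_scan(a, b, h0), atol=1e-5))
```

Each step uses only slicing, elementwise products, and concatenation, so the scan is differentiable with plain autograd and needs no custom kernel.
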
# How to Run

## Installation

Environment:
- CUDA >= 12.0
- Triton > 2.2.0

Install with `pip install -r requirements.txt`.

## Data Preparation

To build the dataset for pre-training, follow the README file in the `data` directory.

In the paper, I use RedPajama-4K, a dataset derived from RedPajama by performing deduplication and then filtering out all sequences shorter than 4K tokens.
For evaluation, I use RedPajama-16K, which underwent the same deduplication process but excludes all sequences shorter than 16K tokens.
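
The deduplicate-then-filter step can be sketched as a single streaming pass. This is a minimal sketch under explicit assumptions, not the pipeline in the `data` directory: it uses exact-match deduplication by SHA-256 hash (the actual deduplication method may differ), treats "4K" as 4096 tokens, and takes a caller-supplied `tokenize` function. `dedup_and_filter` is a hypothetical name.

```python
import hashlib
from typing import Callable, Iterable, Iterator, List


def dedup_and_filter(
    docs: Iterable[str],
    tokenize: Callable[[str], List[str]],
    min_tokens: int = 4096,  # assumption: "4K" means 4096 tokens
) -> Iterator[str]:
    """Drop exact duplicates first, then keep only documents with at least `min_tokens` tokens."""
    seen = set()
    for doc in docs:
        digest = hashlib.sha256(doc.encode("utf-8")).hexdigest()
        if digest in seen:
            continue  # exact duplicate of an earlier document
        seen.add(digest)
        if len(tokenize(doc)) >= min_tokens:
            yield doc


# Toy usage with whitespace tokenization and a tiny threshold.
docs = ["a b c", "a b c", "a b", "x y z w"]
print(list(dedup_and_filter(docs, str.split, min_tokens=3)))  # ['a b c', 'x y z w']
```

Under the same assumptions, the evaluation set would correspond to `min_tokens=16384` with the model's actual tokenizer in place of `str.split`.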

## Training

Just run

```shell
accelerate launch train.py {args}
```

For available arguments, check out `arguments.py`. You can also use a script to launch a job:

```shell
bash train_mamba2.sh
```

RWKV-5 (currently not working):

```shell
python train.py --model rwkv --batch_size 1 --packing_count 1 --grad_accum 32 --max_length 4096 --lr_scheduler stabledrop --pretrained_path /mnt/data/user/tc_agi/klara/models/rwkv/rwkv5_world_180m_20240528/train-model --ckpt_dir result/checkpoints
```

## Analysis

The code in the `analysis` directory analyzes the distribution of the states and the different components of the update rule. Each `get_*.py` file extracts and inspects one particular component of the update rule. `analyze_states.py` computes statistics of the SSM and convolutional states of each layer and each head over time, and plots the distribution of each channel in the states. For the specific arguments, check the file you are executing, which contains an `Args` class that specifies all of its arguments.
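
As a rough illustration of per-layer, per-head state statistics over time, the sketch below summarizes a stack of recorded states. It assumes a hypothetical layout `(time, layers, heads, *state_dims)`, for example `(T, L, H, head_dim, d_state)` for an SSM state; a state without a head axis would need a singleton head dimension. The helper names are illustrative, and `analyze_states.py` may record and plot states differently.

```python
import torch


def state_stats_over_time(states: torch.Tensor) -> dict:
    """Per (time, layer, head) statistics, computed over all entries of that head's state.

    Assumed layout: (time, layers, heads, *state_dims), with at least one state dim.
    """
    flat = states.flatten(start_dim=3)  # (T, L, H, N) with N = product of the state dims
    return {
        "mean": flat.mean(dim=-1),           # (T, L, H)
        "std": flat.std(dim=-1),             # (T, L, H)
        "max_abs": flat.abs().amax(dim=-1),  # (T, L, H)
    }


def channel_quantiles(states: torch.Tensor, qs=(0.01, 0.5, 0.99)) -> torch.Tensor:
    """Quantiles of each state channel across time, shape (len(qs), L, H, N)."""
    flat = states.flatten(start_dim=3).float()
    q = torch.tensor(qs, dtype=flat.dtype, device=flat.device)
    return torch.quantile(flat, q, dim=0)


torch.manual_seed(0)
states = torch.randn(128, 4, 8, 16, 32)  # random stand-in: 128 steps, 4 layers, 8 heads, 16 x 32 state
stats = state_stats_over_time(states)
print(stats["std"].shape, channel_quantiles(states).shape)
```

Tracking these summaries per time step shows whether a head's state grows, saturates, or stays bounded as the context gets longer, while the per-channel quantiles give a compact view of each channel's distribution.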
