### File Structure
`stream_model.py` contains the core logic of StreamBP. Its essential components are:
* `StreamModel`: Wraps the original PretrainedModel to enable partition-wise forward.
* `CheckpointFunctionForStreamBackward`: Modifies the original Checkpoint function to enable partition-wise reforward and backward. The main logic is implemented in its `backward` function.
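
To give a feel for what "partition-wise reforward and backward" means, here is a toy sketch in plain Python. It is not the code in `stream_model.py`: the linear layer, the squared-error loss, and the function names `forward`, `grad_full`, and `grad_chunked` are all assumptions made for illustration. It only shows that, for a layer applied independently to each token, re-running the forward pass on one partition at a time and accumulating the partition gradients gives the same weight gradient as processing the whole sequence at once. Layers that mix tokens, such as attention, need additional handling that this sketch does not cover.

```python
from typing import List, Sequence

Vector = Sequence[float]


def forward(w: Vector, xs: Sequence[Vector]) -> List[float]:
    """Token-wise linear layer: one scalar output per token (toy stand-in)."""
    return [sum(wi * xi for wi, xi in zip(w, x)) for x in xs]


def grad_full(w: Vector, xs: Sequence[Vector], targets: Sequence[float]) -> List[float]:
    """Gradient of sum_t 0.5 * (y_t - target_t)^2 w.r.t. w, all tokens at once."""
    ys = forward(w, xs)  # activations for every token are alive here
    grad = [0.0] * len(w)
    for x, y, t in zip(xs, ys, targets):
        for i, xi in enumerate(x):
            grad[i] += (y - t) * xi
    return grad


def grad_chunked(
    w: Vector, xs: Sequence[Vector], targets: Sequence[float], chunk_size: int
) -> List[float]:
    """Same gradient, computed one partition at a time and accumulated."""
    grad = [0.0] * len(w)
    for start in range(0, len(xs), chunk_size):
        chunk_x = xs[start:start + chunk_size]
        chunk_t = targets[start:start + chunk_size]
        # Re-forward and backward on this partition only, so only
        # chunk_size activations need to exist at any moment.
        chunk_grad = grad_full(w, chunk_x, chunk_t)
        grad = [g + cg for g, cg in zip(grad, chunk_grad)]
    return grad


# Hypothetical toy data: 5 tokens, 2 features each.
w = [1.0, -2.0]
xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [0.0, 1.0], [2.0, 2.0]]
targets = [0.0, 1.0, -1.0, 2.0, 0.5]

full = grad_full(w, xs, targets)
chunked = grad_chunked(w, xs, targets, chunk_size=2)
```

The two results agree for any `chunk_size`; a smaller partition lowers the number of activations held at once at the cost of more loop iterations.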

The `scripts` folder contains the scripts for reproducing all results in the paper:
* `test_bp.py`: Section 4.1 experiments.
* `test_dpo.py`, `test_sft.py`, `test_grpo.py`: Section 4.2 experiments.
* `test_zero.py`: Section 4.3 experiments.
* `streambp_linear.py`: Appendix A.1 experiments.

The `trainers` folder contains the implementation of StreamBP for logits under SFT, GRPO, and DPO objectives. We inherit from the trainers in the TRL library and modify the logic of `compute_loss`.

### Usage
**Environment setup:** 
```
conda create -n streambp python=3.10
conda activate streambp
pip install -r requirements.txt
```

To reproduce the paper results, first set the following environment variables:
```
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTHONPATH=.
export CUDA_VISIBLE_DEVICES=0 # for single-GPU experiments
```
Then run one of the scripts in `scripts`, e.g.
```
python scripts/test_bp.py --mode stream --seq_len 10000
python scripts/test_bp.py --mode base --seq_len 10000
```
The following parameters can be adjusted as needed:
* `--mode`: Either `base` or `stream`.
* `--seq_len`: Sequence length.
* `--chunk_size`: Partition size used for the transformer layers.
* `--use_lora`: Use LoRA in training.
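
The flags above could be parsed along the following lines. This is a hypothetical sketch using `argparse`, not the parser in `scripts/test_bp.py`: the function name `build_parser`, the default values, and the choice to treat `--use_lora` as a boolean switch are all assumptions, so check the script itself for the actual types and defaults.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags listed in this README."""
    parser = argparse.ArgumentParser(description="StreamBP benchmark flags (sketch)")
    # The README lists exactly these two modes.
    parser.add_argument("--mode", choices=["base", "stream"], default="stream")
    # Defaults below are assumed for illustration only.
    parser.add_argument("--seq_len", type=int, default=10000)
    parser.add_argument("--chunk_size", type=int, default=1000)
    # Assumed to be an on/off switch; the real script may expect a value.
    parser.add_argument("--use_lora", action="store_true")
    return parser


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--mode", "stream", "--seq_len", "10000"])
```

Restricting `--mode` with `choices` makes a typo such as `--mode streem` fail immediately at parse time instead of partway through a long run.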