Thank you for viewing our source code! 

## Introduction

We introduce **Fira**, a plug-and-play, memory-efficient framework for training LLMs.

Unlike LoRA and GaLore, Fira trains full-rank weights with full-rank gradients, making it the first attempt to consistently achieve full-rank training under a low-rank constraint.
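To make the low-rank constraint concrete, the sketch below counts the optimizer-state entries that Adam keeps for one m × n weight matrix, with and without a GaLore-style projection. It assumes, as in GaLore, that the gradient is projected along its shorter side onto `rank` directions, so the optimizer stores one projection matrix plus the two Adam moments of the projected gradient. Weights, gradients, and activations are excluded, and the function names and layer sizes are illustrative rather than taken from this repository.

```python
def full_rank_adam_state(m: int, n: int) -> int:
    """Entries in Adam's first and second moments for an m x n matrix."""
    return 2 * m * n


def low_rank_adam_state(m: int, n: int, rank: int) -> int:
    """Entries stored under a GaLore-style projection (assumption):
    a (short x rank) projection matrix plus two (rank x long) Adam moments."""
    short, long = min(m, n), max(m, n)
    return short * rank + 2 * rank * long


# Hypothetical 4096 x 4096 layer projected to rank 1024.
full = full_rank_adam_state(4096, 4096)
low = low_rank_adam_state(4096, 4096, 1024)
print(full, low, low / full)  # 33554432 12582912 0.375
```

Under these assumptions the optimizer state shrinks to 37.5% of full-rank Adam for this layer; the saving grows as `rank` decreases relative to the matrix size.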

Our method is easy to implement: at its core, it relies on just two lines of equations.
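The NumPy sketch below illustrates a Fira-style update for a single weight matrix, based on our reading of Fira's two ingredients. First, Adam runs only on the projected gradient `R = P^T G`. Second, the residual `G - P R` that the projection discards is added back after being rescaled by the ratio `||psi(R)|| / ||R||` that Adam applied to the low-rank part, with a norm-growth limiter capping how fast that correction may grow between steps. This is an illustration under stated assumptions, not the code in `./optimizer_torch`. The scaling factor is computed for the whole matrix (the actual implementation may use a finer granularity), `P` is taken from the top singular vectors of the gradient as in GaLore, the limiter threshold `gamma=1.01` is a hypothetical default, and the `--alpha` scale from the command line is omitted.

```python
import numpy as np


def projection_from_svd(G, rank):
    """GaLore-style projection (assumption): top-`rank` left singular vectors of G."""
    U, _, _ = np.linalg.svd(G, full_matrices=False)
    return U[:, :rank]


def fira_style_step(W, G, P, state, lr, gamma=1.01,
                    beta1=0.9, beta2=0.999, eps=1e-8):
    """One illustrative Fira-style update of W (m x n) given gradient G and P (m x r).

    `state` holds the Adam moments of the low-rank gradient ("m", "v", both r x n),
    the step counter ("step"), and the norm of the previous correction ("prev_norm").
    """
    state["step"] += 1
    t = state["step"]

    # Low-rank part: Adam sees only R = P^T G, so its moments are r x n.
    R = P.T @ G
    state["m"] = beta1 * state["m"] + (1 - beta1) * R
    state["v"] = beta2 * state["v"] + (1 - beta2) * R**2
    m_hat = state["m"] / (1 - beta1**t)
    v_hat = state["v"] / (1 - beta2**t)
    psi = m_hat / (np.sqrt(v_hat) + eps)

    # Full-rank correction: the gradient component outside span(P), rescaled by
    # the ratio ||psi(R)|| / ||R|| that Adam applied to the low-rank part.
    residual = G - P @ R
    phi = np.linalg.norm(psi) / max(np.linalg.norm(R), eps)
    S = phi * residual

    # Norm-growth limiter (hypothetical threshold): keep ||S_t|| <= gamma * ||S_{t-1}||.
    s_norm = np.linalg.norm(S)
    prev = state["prev_norm"]
    if prev is not None and s_norm > gamma * prev:
        S = S * (gamma * prev / s_norm)
        s_norm = gamma * prev
    state["prev_norm"] = s_norm

    # W_{t+1} = W_t - lr * (P psi(R) + S): low-rank Adam step plus full-rank correction.
    return W - lr * (P @ psi + S)


rng = np.random.default_rng(0)
m, n, rank = 64, 32, 8
W = rng.standard_normal((m, n))
state = {"step": 0, "m": np.zeros((rank, n)), "v": np.zeros((rank, n)), "prev_norm": None}
P = None
for step in range(5):
    G = rng.standard_normal((m, n))  # stand-in for a real gradient
    if step % 200 == 0:  # refresh P periodically, mirroring --update_proj_gap 200
        P = projection_from_svd(G, rank)
    W = fira_style_step(W, G, P, state, lr=1e-3)
```

Because the residual is exactly the part of `G` orthogonal to the columns of `P`, the two terms together update the weights in every direction, while Adam's moments stay at the low-rank size.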

This code release includes the implementation of Fira and the pre-training benchmark on the C4 dataset. We will provide the link to our GitHub repository containing the full code upon acceptance.

## Pre-training LLaMA (60M~7B) on the C4 dataset

+ `./pre_training_c4` contains the code for pre-training LLaMA models on the C4 dataset.
+ `./optimizer_torch` contains the implementation of the Fira optimizers.

### Set up the environment
```bash
cd pre_training_c4
pip install -r requirements.txt
```
Our experiment scripts have been validated with Python 3.9 and PyTorch 2.2.2.

### Code Structure
+ The `./pre_training_c4/torchrun_main.py` script pre-trains LLaMA models on the C4 dataset.
+ The `./pre_training_c4/scripts` directory contains the benchmark scripts for the different LLaMA model sizes (60M, 130M, 350M, 1B, 7B).

For instance, to pre-train a 60M model on the C4 dataset, run the following command:
```bash
# LLaMA-60M, Fira-Adam, 1 A100, 1 Node
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config llama_configs/llama_60m.json \
    --lr 0.01 \
    --alpha 0.25 \
    --rank 128 \
    --update_proj_gap 200 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --optimizer fira_adamw 
```
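The command combines `--batch_size 256` with `--total_batch_size 512`. Assuming `torchrun_main.py` follows the convention of GaLore's pre-training script, `--batch_size` is the per-process micro-batch, and the number of gradient-accumulation steps per optimizer update is `total_batch_size / (batch_size × nproc_per_node)`. The hypothetical helper below reproduces that arithmetic; it is not a function from this repository.

```python
def gradient_accumulation_steps(total_batch_size: int, batch_size: int, world_size: int) -> int:
    """Hypothetical helper: micro-batches accumulated per optimizer update (assumed convention)."""
    per_pass = batch_size * world_size
    if total_batch_size % per_pass != 0:
        raise ValueError("total_batch_size must be a multiple of batch_size * world_size")
    return total_batch_size // per_pass


# The 60M command: one process (--nproc_per_node 1), micro-batch 256, global batch 512.
print(gradient_accumulation_steps(512, 256, 1))  # 2
# The same flags on two GPUs would need no accumulation.
print(gradient_accumulation_steps(512, 256, 2))  # 1
```

Under this assumption, keeping `--total_batch_size` fixed while adding GPUs reduces the accumulation steps instead of changing the effective batch size.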

### Notice
+ Following GaLore, we implement Fira on top of AdamW. However, all our experiments use `weight_decay=0` by default, in which case AdamW reduces to Adam, so the comparison with the full-rank Adam baseline is fair.

+ The script loads the C4 dataset directly from Hugging Face, so please ensure a stable internet connection.
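The fairness argument in the first note rests on the fact that AdamW differs from Adam only in its decoupled weight-decay term, which vanishes when `weight_decay=0`. The NumPy sketch below is our own illustration, not code from `./optimizer_torch`. It implements one step of each rule, using the decoupled form `w * (1 - lr * weight_decay)` that `torch.optim.AdamW` applies, and checks that the two coincide at zero decay.

```python
import numpy as np


def adam_step(w, g, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8, l2=0.0):
    """Adam, with an optional L2 penalty folded into the gradient."""
    g = g + l2 * w
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v


def adamw_step(w, g, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    """AdamW: decay is decoupled and applied to the weights, not the gradient."""
    w = w * (1 - lr * weight_decay)
    return adam_step(w, g, m, v, t, lr, beta1, beta2, eps, l2=0.0)


rng = np.random.default_rng(0)
w_adam = w_adamw = rng.standard_normal(5)
m1 = v1 = m2 = v2 = np.zeros(5)
for t in range(1, 11):
    g = rng.standard_normal(5)
    w_adam, m1, v1 = adam_step(w_adam, g, m1, v1, t, lr=1e-2)
    w_adamw, m2, v2 = adamw_step(w_adamw, g, m2, v2, t, lr=1e-2, weight_decay=0.0)

print(np.allclose(w_adam, w_adamw))  # True
```

With a nonzero `weight_decay` the two rules diverge, because Adam's L2 penalty passes through the adaptive denominator while AdamW's decay does not; this is why the zero-decay setting matters for the comparison.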
