# SRON: State-free LLM Training via Row-wise Gradient Scaling

This repo contains the official implementation for the paper **SRON: State-free LLM Training via Row-wise Gradient Scaling**.

## Reproducibility

### Install

Our experiments were mainly conducted with Python 3.10.16 and CUDA 11.8. Install the dependencies listed in `requirements.txt` with pip:

```shell
pip install -r requirements.txt
```
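Before launching a run, it can help to check that the environment matches the versions above. The snippet below is a minimal sketch, not part of this repo: `check_env` is a hypothetical helper that reports the Python version and, only if PyTorch is already installed, the CUDA version PyTorch was built against and whether a GPU is visible.

```python
import sys


def check_env():
    """Hypothetical helper: report Python and (optionally) PyTorch/CUDA versions."""
    info = {"python": "%d.%d.%d" % sys.version_info[:3]}
    try:
        import torch  # optional: only present after installing requirements.txt
    except ImportError:
        info["torch"] = None
        info["cuda_build"] = None
        info["cuda_available"] = False
        return info
    info["torch"] = torch.__version__
    # CUDA version of the PyTorch build, e.g. "11.8" for a CUDA 11.8 wheel; None on CPU-only builds.
    info["cuda_build"] = torch.version.cuda
    info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    # The experiments in this repo were run with Python 3.10.x and CUDA 11.8.
    for key, value in check_env().items():
        print(f"{key}: {value}")
```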

### Usage

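```python
from optimizer import get_optimizer_params
from sron_torch import FROST

# `model`, `logger`, and `args` are provided by your training script (e.g. torchrun_main.py).
# Split the parameters into two groups: `scale_params` is passed as `sgd_params`
# (the state-free, row-wise scaled update) and `base_params` is passed as
# `adamw_params` (updated with AdamW).
scale_params, base_params = get_optimizer_params(model, logger)
optimizer = FROST(
    lr=args.lr,
    wd=args.weight_decay,  # weight decay
    sgd_params=scale_params,
    adamw_params=base_params,
    momentum=args.momentum,
    scale=args.scale,
)
```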
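The exact rule `get_optimizer_params` uses to build the two groups is defined in `optimizer.py`. As an illustration only, the sketch below assumes a common convention: 2-D weight matrices inside the transformer blocks go to the first group, while embeddings, the output head, norms and biases go to the second. `split_params` and `EXCLUDED_KEYWORDS` are hypothetical names. The code relies only on each parameter exposing `.ndim` and `.requires_grad`, so it accepts `model.named_parameters()` from a PyTorch model.

```python
from typing import Any, Iterable, List, Tuple

# Assumption for this sketch: name fragments that mark non-hidden-matrix parameters.
EXCLUDED_KEYWORDS = ("embed", "lm_head", "norm")


def split_params(named_params: Iterable[Tuple[str, Any]]) -> Tuple[List[Any], List[Any]]:
    """Hypothetical split into (scale_params, base_params).

    2-D weights whose names contain none of EXCLUDED_KEYWORDS go to the first
    group; every other trainable parameter goes to the second.
    """
    scale_params, base_params = [], []
    for name, param in named_params:
        if not param.requires_grad:
            continue  # frozen parameters are given to neither optimizer group
        is_matrix = param.ndim == 2
        is_excluded = any(key in name for key in EXCLUDED_KEYWORDS)
        if is_matrix and not is_excluded:
            scale_params.append(param)
        else:
            base_params.append(param)
    return scale_params, base_params
```

With a PyTorch model this would be called as `scale_params, base_params = split_params(model.named_parameters())`; the repo's own `get_optimizer_params` may use different criteria.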

### Pre-training LLaMA/GPT/Gemma on the C4 dataset

Scripts for pre-training LLaMA models on the C4 dataset are in the `shells` folder. The C4 dataset can be downloaded from Hugging Face.

#### Example: pre-training the LLaMA-60M model

```bash
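#!/bin/bash
# LLaMA-60M, SRON, 1 node, 4x RTX 3090
module load cuda/11.8  # cluster-specific (environment modules); skip if CUDA 11.8 is already available

export optimizer=sron
export lr=2.0e-2
export seed=0
export OMP_NUM_THREADS=1
export weight_decay=0.0
export momentum=0.0
export scale=5.0e-2
# warmup_ratio is passed to --warmup_ratio below but is not set in this example;
# stop early instead of passing an empty value.
: "${warmup_ratio:?export warmup_ratio before running this script}"

run_name="seed_$seed+$lr*$scale+wd_$weight_decay"
log_dir="logs/llama60m/$optimizer"
mkdir -p "$log_dir"  # the output redirection below fails if this directory does not exist

torchrun --standalone --nproc_per_node 4 torchrun_main.py \
    --model_type llama \
    --model_config configs/llama_60m.json \
    --lr $lr \
    --scale $scale \
    --batch_size 128 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_ratio $warmup_ratio \
    --dtype bfloat16 \
    --eval_every 1000 \
    --save_every 100000 \
    --seed $seed \
    --momentum $momentum \
    --weight_decay $weight_decay \
    --save_dir "llama_60m/$optimizer/$run_name" \
    --optimizer $optimizer > "$log_dir/$run_name.out" 2>&1 &
wait

echo 'Done!'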
```
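The launch combines a per-GPU micro-batch (`--batch_size 128`), the number of processes (`--nproc_per_node 4`) and a global batch (`--total_batch_size 512`). Assuming `torchrun_main.py` reaches the global batch through gradient accumulation and that `--num_training_steps` counts optimizer updates (both assumptions, not confirmed by this README), the arithmetic works out as below. `batch_summary` is a hypothetical helper for illustration.

```python
def batch_summary(batch_size: int, world_size: int, total_batch_size: int,
                  num_training_steps: int) -> dict:
    """Hypothetical helper: derive gradient accumulation from the launch flags.

    Assumes the global batch is reached by accumulating gradients over
    micro-batches of `batch_size` on each of `world_size` processes.
    """
    per_pass = batch_size * world_size  # sequences per forward/backward pass, summed over GPUs
    if total_batch_size % per_pass != 0:
        raise ValueError("total_batch_size must be a multiple of batch_size * world_size")
    grad_accum = total_batch_size // per_pass
    # Each optimizer step consumes total_batch_size sequences.
    sequences_seen = total_batch_size * num_training_steps
    return {"grad_accum": grad_accum, "sequences_seen": sequences_seen}


# Values from the LLaMA-60M script: --batch_size 128, --nproc_per_node 4,
# --total_batch_size 512, --num_training_steps 10000.
print(batch_summary(128, 4, 512, 10000))
# → {'grad_accum': 1, 'sequences_seen': 5120000}
```

Under the same assumption, running on 2 GPUs with an unchanged `--total_batch_size 512` would need 2 accumulation steps per update, keeping the effective batch identical.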

