# Periodical Moving Average Accelerates Gradient Accumulation for Post-Training

This repository is the official implementation of the paper "Periodical Moving Average Accelerates Gradient Accumulation for Post-Training".

## Requirements

Our experiments use the SWIFT framework. For installation requirements, refer to the SWIFT README: https://github.com/modelscope/swift/blob/main/README.md


## SFT and its Evaluation
For SFT experiments, modify the following bash script to suit your task. The example below, taken from our experiments, fine-tunes the phi-2 model on the Alpaca dataset with the AdamW optimizer:
```
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
swift sft \
    --model_type phi2-3b \
    --sft_type full \
    --template_type default \
    --batch_size 32 \
    --train_dataset_sample 41600 \
    --eval_steps 10 \
    --output_dir output \
    --optim adamw_torch \
    --num_train_epochs 1 \
    --max_length 512 \
    --learning_rate 2e-6 \
    --weight_decay 0.01 \
    --use_flash_attn true \
    --save_only_model true \
    --dataset alpaca-en \
    --agma_gradient_accumulation_steps 1 \
    --gradient_accumulation_steps 4 \
    --gradient_checkpointing false
```
The key hyperparameters here are the batch size (32), the learning rate (2e-6), and gradient_accumulation_steps (4). All other parameters are inherited from the SWIFT framework: https://github.com/modelscope/swift/blob/main/docs/source_en/LLM/Command-line-parameters.md#dpo-parameters
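As a rough sanity check on these settings, the arithmetic below estimates how many samples feed one optimizer update and how many updates one epoch contains. It is only a sketch: it assumes a single training process, that [--batch_size] counts samples per forward/backward pass, and that all 41600 sampled examples are used for training (any validation split is ignored). The variable names are illustrative and not part of SWIFT.

```shell
#!/bin/bash
# Values copied from the SFT script above.
BATCH_SIZE=32
GRAD_ACCUM=4
TRAIN_SAMPLES=41600

# Assumption: one training process; --batch_size is the per-pass batch size.
SAMPLES_PER_UPDATE=$(( BATCH_SIZE * GRAD_ACCUM ))
UPDATES_PER_EPOCH=$(( TRAIN_SAMPLES / SAMPLES_PER_UPDATE ))

echo "samples per optimizer update: $SAMPLES_PER_UPDATE"
echo "optimizer updates per epoch:  $UPDATES_PER_EPOCH"
```

Under these assumptions each update sees 128 samples and one epoch takes 325 updates; a multi-process launch would scale the first number accordingly.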

We modified the SWIFT framework code to support our optimizers and added parameters for them. [--optim] now also accepts "AGMA" and "AGMA_Lion", which select the optimizers we designed, AdamW-PMA and Lion-PMA respectively. [--agma_gradient_accumulation_steps] sets the period length of PMA.

In our SFT experiments, we compare four optimizers by setting [--optim] to "adamw_torch", "AGMA", "lion_32bit", or "AGMA_Lion". For example, for AdamW-PMA-4 or Lion-PMA-4, set [--agma_gradient_accumulation_steps] to 4 and [--gradient_accumulation_steps] to 1. For AdamW and Lion, set [--agma_gradient_accumulation_steps] to 1 and [--gradient_accumulation_steps] to 4. Keep all other hyperparameters the same across runs.
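The flag combinations described above can be captured in a small helper. The function `pma_flags` below is hypothetical (it is not part of this repository or of SWIFT); it only prints the three optimizer-related flags for a given variant and step count, using the values stated in this README.

```shell
#!/bin/bash
# Hypothetical helper: print the optimizer-related flags for one variant.
# Usage: pma_flags <AdamW|Lion|AdamW-PMA|Lion-PMA> <steps>
pma_flags() {
    local variant="$1" steps="$2"
    case "$variant" in
        # Baselines: PMA period 1, plain gradient accumulation over <steps>.
        AdamW)     echo "--optim adamw_torch --agma_gradient_accumulation_steps 1 --gradient_accumulation_steps $steps" ;;
        Lion)      echo "--optim lion_32bit --agma_gradient_accumulation_steps 1 --gradient_accumulation_steps $steps" ;;
        # PMA variants: period <steps>, gradient accumulation 1.
        AdamW-PMA) echo "--optim AGMA --agma_gradient_accumulation_steps $steps --gradient_accumulation_steps 1" ;;
        Lion-PMA)  echo "--optim AGMA_Lion --agma_gradient_accumulation_steps $steps --gradient_accumulation_steps 1" ;;
        *)         echo "unknown variant: $variant" >&2; return 1 ;;
    esac
}

# Example: flags for AdamW-PMA-4.
pma_flags AdamW-PMA 4
```

The printed flags can replace the corresponding three lines of the SFT script, for instance by appending `$(pma_flags Lion-PMA 4)` to the `swift sft` command.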

To evaluate a model after SFT, use the following settings and set [--ckpt_dir] to the path of your model checkpoint:
```
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 \
swift eval \
    --ckpt_dir checkpoint \
    --eval_dataset mmlu \
    --eval_limit 10 \
    --load_dataset_config true \
    --use_flash_attn false \
    --max_new_tokens 4096 \
    --temperature 0.1 \
    --top_p 0.7 \
    --repetition_penalty 1. \
    --merge_lora false \
    --truncation_strategy truncation_left
```

## DPO
For DPO experiments, we use the following script, which trains the phi-2 model on the hh-rlhf-harmless-base dataset:
```
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
swift dpo \
    --model_type phi2-3b \
    --sft_type full \
    --template_type default \
    --batch_size 16 \
    --train_dataset_sample 41600 \
    --eval_steps 100 \
    --output_dir output \
    --optim lion_32bit \
    --num_train_epochs 1 \
    --max_length 512 \
    --learning_rate 2e-6 \
    --weight_decay 0.01 \
    --use_flash_attn true \
    --save_only_model true \
    --dataset hh-rlhf-harmless-base \
    --agma_gradient_accumulation_steps 16 \
    --gradient_accumulation_steps 1 \
    --gradient_checkpointing false 
```
As with SFT, we verify the effectiveness of PMA by comparing optimizers, varying [--optim], [--agma_gradient_accumulation_steps], and [--gradient_accumulation_steps].
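One way to organize such a comparison is a dry-run loop that prints the optimizer-specific part of each `swift dpo` command without executing anything. This is a sketch under assumptions: the function name `dpo_commands` is hypothetical, the period of 16 mirrors the script above, and the baselines are assumed to use 16 gradient accumulation steps by analogy with the SFT setup. The remaining flags from the DPO script are omitted for brevity and would need to be added unchanged.

```shell
#!/bin/bash
# Dry run: print (do not execute) one command stub per optimizer.
PERIOD=16   # matches --agma_gradient_accumulation_steps in the DPO script

dpo_commands() {
    local optim agma gas
    for optim in adamw_torch AGMA lion_32bit AGMA_Lion; do
        case "$optim" in
            AGMA|AGMA_Lion) agma=$PERIOD; gas=1 ;;   # PMA variants
            *)              agma=1; gas=$PERIOD ;;   # baselines (assumed setting)
        esac
        echo "swift dpo --optim $optim --agma_gradient_accumulation_steps $agma --gradient_accumulation_steps $gas"
    done
}

dpo_commands
```

Printing the commands first makes it easy to confirm that only these three flags differ between runs before launching any training.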


## Results
The evaluation results of the SFT models are as follows; other experimental results can be found in the paper:

| Optimizer    |   Val Loss  |   Avg Score  |
| ------------ | ----------- | ------------ |
| AdamW-4      |   0.9212    |     24.4     |
| AdamW-8      |   0.9408    |     23.3     |
| AdamW-PMA-4  |   0.9352    |     21.9     |
| AdamW-PMA-8  |   0.9078    |     27.7     |
| Lion-4       |   0.9227    |     21.8     |
| Lion-8       |   0.9486    |     23.0     |
| Lion-PMA-4   |   0.9136    |     21.8     |
| Lion-PMA-8   |   0.9373    |     22.3     |

