#!/bin/bash
# LLaMA-7B, SRON, 1 Node, 4 H100, Gradient checkpointing
# Toolchain: bring the CUDA 11.8 environment module into this shell.
module load cuda/11.8

# One OpenMP thread per process.
export OMP_NUM_THREADS=1

# Run configuration. Everything is exported so that child processes
# inherit it; the launch command further down also expands these values
# into CLI flags and into the output path names.
export optimizer=sron seed=0
export lr=2.0e-2 scale=5.0e-2
export weight_decay=0.0 momentum=0.0


# Launch training with torchrun (4 worker processes on this single node),
# send stdout+stderr to a per-run log file, and wait for the job to finish.
#
# Reads from the environment: optimizer, lr, scale, seed, momentum,
# weight_decay, warmup_ratio.
# Exit status: that of the torchrun job (non-zero if training failed).

# warmup_ratio is not set anywhere in this script, so it has to come from the
# caller's environment. Abort now with a clear message rather than letting an
# empty expansion turn "--warmup_ratio --dtype ..." into a confusing CLI error.
: "${warmup_ratio:?warmup_ratio is not set; export it before running this script}"

# Shared tag used for both the checkpoint directory and the log file name.
# The '*' is a literal character in the name, so every use must stay quoted
# to keep the shell from treating it as a glob.
run_tag="seed_${seed}+${lr}*${scale}+wd_${weight_decay}"
log_dir="logs/llama7b/${optimizer}"

# The output redirection below fails if the log directory does not exist.
mkdir -p -- "$log_dir" || exit 1

# NOTE(review): weight_decay only appears in the path names; it is not passed
# to torchrun_main.py as a flag. Confirm whether the trainer should receive it.
torchrun --standalone --nproc_per_node 4 torchrun_main.py \
    --model_type llama \
    --model_config configs/llama_7b.json \
    --lr "$lr" \
    --scale "$scale" \
    --batch_size 128 \
    --activation_checkpointing \
    --total_batch_size 512 \
    --num_training_steps 150000 \
    --warmup_ratio "$warmup_ratio" \
    --dtype bfloat16 \
    --eval_every 1000 \
    --save_every 100000 \
    --seed "$seed" \
    --momentum "$momentum" \
    --save_dir "llama_7b/${optimizer}/${run_tag}" \
    --optimizer "$optimizer" > "${log_dir}/${run_tag}.out" 2>&1 &
train_pid=$!

# Wait on the specific PID so the job's exit status is available; a bare
# `wait` would discard it and always report success.
wait "$train_pid"
train_status=$?

if [ "$train_status" -ne 0 ]; then
    printf 'Training failed with exit status %s; see %s\n' \
        "$train_status" "${log_dir}/${run_tag}.out" >&2
    exit "$train_status"
fi

echo 'Done!'