#!/bin/bash

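# Sweep over model size (d12, d24), --lwlr_mode and --lr_multiplier for
# train_gpt2_overtrain_sfadam.py. Based on scripts/run_gpt2_124M.sh, but with PyTorch.

# Every run below is launched on a single GPU with plain `python3`. To train on
# multiple GPUs, launch with torchrun instead and keep all other arguments the same, e.g.
#torchrun --standalone --nproc_per_node=8 train_gpt2_overtrain_sfadam.py \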


# Sweep configuration: every combination of these values is run once, in this
# order (models outermost, LR multipliers innermost).
models=(d12 d24)
residual_scales=(layerdepth)
lr_modes=(fanin none)
lr_multipliers=(0.125 0.25 0.5 1.0 2.0 4.0 8.0)

# Make a pipeline's status reflect the training process, not just `tee`, so
# failed runs can be reported below instead of being silently masked.
set -o pipefail

for model in "${models[@]}"; do
  # Per-model micro-batch size and launcher. The launcher is an array so that
  # multi-word commands (e.g. torchrun plus its flags) expand safely.
  case "$model" in
    d12)
      batch_size=16
      launcher=(python3 train_gpt2_overtrain_sfadam.py)
      ;;
    d24)
      batch_size=8
      launcher=(python3 train_gpt2_overtrain_sfadam.py)
      #launcher=(torchrun --standalone --nproc_per_node=2 train_gpt2_overtrain_sfadam.py)
      ;;
    *)
      printf 'error: no configuration for model %s\n' "$model" >&2
      exit 1
      ;;
  esac

  for residual_scale in "${residual_scales[@]}"; do
    for lr_mode in "${lr_modes[@]}"; do
      for lr_mul in "${lr_multipliers[@]}"; do
        # One log file per run; the name encodes the swept values.
        log_file="outputovertrain${model}sfadam${residual_scale}${lr_mode}${lr_mul}.txt"

        # The --input_bin pattern is quoted so the shell passes it through
        # unexpanded (presumably the trainer expands it itself — confirm).
        # 2>&1 sends stderr (e.g. tracebacks) into the log as well as stdout.
        # A failed run is reported, and the sweep continues with the next config.
        if ! "${launcher[@]}" \
          --input_bin "dev/data/fineweb10B/fineweb_train_*.bin" \
          --sample_every 0 \
          --write_tensors 0 \
          --model "$model" \
          --residual_scale "$residual_scale" \
          --weight_init fanin \
          --lwlr_mode "$lr_mode" \
          --batch_size "$batch_size" \
          --sequence_length 1024 \
          --total_batch_size 524288 \
          --overfit_batch_number 10 \
          --dtype bfloat16 \
          --compile 1 \
          --tensorcores 1 \
          --flash 1 \
          --num_iterations 1400 \
          --weight_decay 0.0 \
          --zero_stage 0 \
          --lr_multiplier "$lr_mul" 2>&1 | tee "$log_file"; then
          printf 'warning: run failed (model=%s residual_scale=%s lwlr_mode=%s lr_multiplier=%s); see %s\n' \
            "$model" "$residual_scale" "$lr_mode" "$lr_mul" "$log_file" >&2
        fi
      done
    done
  done
done
