#!/bin/bash
#
# Model-size sweep: for every entry in `sizes`, run train_gpt2.py under
# torchrun on data_law/<DATA_NAME>/pretrain, logging runs to Weights & Biases.
#
# Usage: bash <this-script> DATA_NAME [SIZE...]
#   DATA_NAME  dataset directory name under data_law/ (required)
#   SIZE...    optional subset of --model_size values to train; when omitted,
#              the full list in `sizes` below is used
#
# Environment:
#   WANDB_API_KEY  Weights & Biases API key (required; never hard-code it)

# Error on unset variables. `set -e` is intentionally NOT used: one failed
# training run must not abort the rest of the sweep.
set -u

# Fail fast: an empty name would otherwise train on data_law//pretrain and
# write logs to model_law_large__log.
data_name=${1:?usage: $0 DATA_NAME [SIZE...]}
echo "data_name: $data_name"

# Values passed verbatim to train_gpt2.py --model_size, in sweep order.
sizes=(6xs2 6xs 5xs 5xs2 5xs1 4xs 4xs2 4xs1 3xs 3xs2 3xs1 xxs xxs1 xxs2 xxs3 xxs4 xs xs3 xs2 xs1 s s3 s2 s1 base m5 m4 m3 m2 m1 m)
# Any arguments after DATA_NAME replace the default list.
if (( $# > 1 )); then
    sizes=("${@:2}")
fi

# Route outbound HTTP(S) traffic from this script and its children through the proxy.
readonly proxy_url=http://star-proxy.oa.com:3128
export http_proxy=$proxy_url
export https_proxy=$proxy_url

# NOTE(review): a W&B key used to be committed here in plain text. Treat it
# as leaked and rotate it. The key is now taken from the environment only.
: "${WANDB_API_KEY:?WANDB_API_KEY must be set to your Weights & Biases API key}"
wandb login --relogin "$WANDB_API_KEY"
wandb online

# Per-dataset output directory and W&B project shared by every run in the sweep.
log_dir="model_law_large_${data_name}_log"
wandb_project="model_law_large_${data_name}"
mkdir -p -- "$log_dir"

# Sizes whose torchrun exited non-zero; reported once the sweep finishes.
failed_sizes=()

for size in "${sizes[@]}"; do
    echo "run for ${size}"
    # One standalone (single-node) torchrun launch with 8 processes per size.
    # On failure, record the size and continue with the next one rather than
    # aborting the whole sweep.
    if ! torchrun --standalone --nproc_per_node=8 train_gpt2.py \
        --input_folder "data_law/${data_name}/pretrain" \
        --save_every 2000 \
        --val_loss_every 500 \
        --run_name "$size" \
        --warmup_ratio 0.05 \
        --warmdown_ratio 0.9 \
        --sequence_length 512 \
        --device_batch_size 16 \
        --num_epochs 4 \
        --weight_decay 0.1 \
        --learning_rate 0.0003 \
        --batch_size 128 \
        --bf16 \
        --model_size "$size" \
        --output_dir "${log_dir}/" \
        --wandb_project "$wandb_project"; then
        echo "run for ${size} failed" >&2
        failed_sizes+=("$size")
    fi
done

if (( ${#failed_sizes[@]} > 0 )); then
    echo "failed sizes: ${failed_sizes[*]}" >&2
fi

# Site-specific helper defined outside this script. Its purpose is not visible
# here (presumably it holds or releases the GPUs after the sweep). Confirm.
savegpu