#!/bin/bash

# Model size groups, selectable by number through the first positional argument.
group1=(5xs 4xs 3xs xxs)
group2=(6xs xs s)
group3=(base m)
# Selector 4 is accepted below, but no sizes were ever listed for it. It is
# declared empty here so that the reference is explicit rather than an
# accidental expansion of an undefined variable.
# TODO: fill in the sizes intended for group 4.
group4=()

#######################################
# Fill the global `sizes` array from a selector.
# Globals:   group1..group4 (read), sizes (written)
# Arguments: $1 - group number (1-4), or literal model size name(s)
# Outputs:   a warning on stderr when the selection is empty
#######################################
select_sizes() {
    local selector=${1-}
    case "$selector" in
        1)
            sizes=("${group1[@]}")
            ;;
        2)
            sizes=("${group2[@]}")
            ;;
        3)
            sizes=("${group3[@]}")
            ;;
        4)
            sizes=("${group4[@]}")
            ;;
        *)
            # Not a group number: use the argument itself as the size.
            # `read -a` still splits on whitespace (so "xs s" keeps yielding
            # two sizes) but, unlike an unquoted expansion, never expands
            # glob characters against files in the current directory.
            read -r -a sizes <<< "$selector"
            ;;
    esac

    # An empty selection means no training run will be launched; say so
    # instead of finishing silently.
    if (( ${#sizes[@]} == 0 )); then
        echo "warning: no model sizes selected for argument '${selector}'" >&2
    fi
}

select_sizes "${1-}"

# Send HTTP and HTTPS traffic through the proxy. These exports are
# unconditional, so they replace any proxy inherited from the environment.
export http_proxy=http://star-proxy.oa.com:3128
export https_proxy=http://star-proxy.oa.com:3128

# Log in to Weights & Biases. A key supplied through WANDB_API_KEY takes
# precedence over the built-in one.
# SECURITY NOTE(review): the fallback key below is committed in plain text and
# is kept only so existing invocations keep working. It should be rotated and
# removed from the script; it is also visible in the process list while
# `wandb login` runs.
wandb login --relogin "${WANDB_API_KEY:-9f00c4cc90ef7ea81eabc2df4b2306f49cb73bf8}"
wandb online

# Create output directory
mkdir -p fineweb_log

# Data name; also used as the name of the input folder passed to the trainer.
data_name=fineweb10B

#######################################
# Launch one training run for a single model size.
# Globals:   data_name (read)
# Arguments: $1 - model size; passed as both --model_size and --run_name
# Outputs:   a progress line on stdout, plus whatever torchrun prints
# Returns:   the exit status of torchrun
#######################################
train_model() {
    local size=$1
    echo "Running for model size: ${size}"
    torchrun --standalone --nproc_per_node=8 train_gpt2.py \
        --input_folder "${data_name}/" \
        --save_every 4000 \
        --val_loss_every 4000 \
        --run_name "$size" \
        --warmup_ratio 0.05 \
        --warmdown_ratio 0.9 \
        --sequence_length 512 \
        --device_batch_size 16 \
        --num_epochs 1 \
        --weight_decay 0.1 \
        --learning_rate 0.0003 \
        --batch_size 128 \
        --bf16 \
        --model_size "$size" \
        --output_dir "fineweb_log/" \
        --wandb_project fineweb
}

# Iterate through sizes array. A failed run is reported, but the remaining
# sizes are still attempted, as before.
for size in "${sizes[@]}"; do
    train_model "$size" \
        || echo "warning: training failed for model size: ${size}" >&2
done

# NOTE(review): `savegpu` is not defined in this script; presumably an external
# command or alias available on the training machine. Confirm before relying on it.
savegpu