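### Run `conda activate gpt` before running this script.

# Min time per run: ~3 hours on an 8xA100 node
# Max time per run: ~30 hours on an 8xA100 node
# Est. total: ~10 hours/job * 36 jobs (3 n_layer x 4 d_model x 3 num_experts) = ~360 hours on one 8xA100 node
# (this script runs all jobs sequentially on one node), or ~48 hours if the jobs are split across 8 such nodes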

# Weights & Biases: log runs locally only (offline mode); they can be
# uploaded later with `wandb sync`.
export WANDB_MODE=offline
# SECURITY: do not hardcode WANDB_API_KEY here. Offline logging does not need
# it; when syncing, provide it via the caller's environment or `wandb login`.
# The key that was previously committed in this file is exposed and should be
# revoked/rotated.
# Seconds to wait for the W&B service process to start (raised from the
# default, presumably for slow or heavily loaded nodes).
export WANDB__SERVICE_WAIT=300

# ---- Paths and experiment tracking (edit before launching) ----
WANDB_PROJ=moe_gpt
DATA_DIR=./open_small  # TODO: point at the dataset directory
OUT_DIR=./             # TODO: point at the output directory

# ---- Training settings shared by every run in the sweep ----
BATCH_SIZE=64 GRAD_ACCUM=8
BLOCK_SIZE=128 MAX_ITERS=100_000

# ---- Settings held fixed across the sweep (passed as --num_active_experts / --init_lr) ----
num_active_experts=2 lr=3e-3

# Sweep over depth x width x expert count (3 x 4 x 3 = 36 runs), executed
# sequentially on this node with all 8 GPUs per run. A failed run is logged to
# stderr and the sweep moves on; the script exits non-zero at the end if any
# run failed. Kept POSIX-sh compatible because the file has no shebang.
total=0
failed=0
for n_layer in 3 6 9; do
  for d_model in 256 512 1024 2048; do
    for num_experts in 4 8 16; do
      total=$((total + 1))
      run_tag="n_layer=${n_layer} d_model=${d_model} num_experts=${num_experts}"
      printf '[sweep] starting run %s: %s\n' "$total" "$run_tag" >&2
      # Random port from the dynamic/private range (49152-65535) for
      # torchrun's rendezvous, so back-to-back runs don't share a fixed port.
      master_port=$(shuf -i 49152-65535 -n 1)
      # NOTE(review): --n_head=-1 together with --d_head=64 presumably makes
      # train_gpt.py derive the head count from d_model — TODO confirm.
      status=0
      CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 \
        --master_port="${master_port}" \
        train_gpt.py config/train_open_small.py \
        --block_size="${BLOCK_SIZE}" \
        --struct="btt_norm_moe_para" \
        --num_experts="${num_experts}" \
        --num_active_experts="${num_active_experts}" \
        --layers=all_but_last \
        --d_model="${d_model}" \
        --n_layer="${n_layer}" \
        --n_head=-1 \
        --d_head=64 \
        --max_iters="${MAX_ITERS}" \
        --data_dir="${DATA_DIR}" \
        --out_dir="${OUT_DIR}" \
        --batch_size="${BATCH_SIZE}" \
        --gradient_accumulation_steps="${GRAD_ACCUM}" \
        --init_lr="${lr}" \
        --wandb_project="${WANDB_PROJ}" || status=$?
      if [ "$status" -ne 0 ]; then
        printf '[sweep] FAILED (exit %s): %s\n' "$status" "$run_tag" >&2
        failed=$((failed + 1))
      fi
    done
  done
done

if [ "$failed" -gt 0 ]; then
  printf '[sweep] %s of %s runs failed\n' "$failed" "$total" >&2
  exit 1
fi
