#!/bin/bash

# The same as scripts/run_gpt2_124M.sh, but with PyTorch.

# To train on just a single GPU, skip the torchrun launcher and invoke the
# training script with python directly (all the other arguments stay the same).
# The runs below do exactly that; the torchrun line is left commented out:
#torchrun --standalone --nproc_per_node=8 train_gpt2.py \

#######################################
# Run one overfit training job per learning rate for a single --lwlr_mode.
#
# Every job uses the same fixed configuration; only --lwlr_mode and
# --learning_rate vary. Jobs run sequentially, in the order the learning
# rates are given. A failing job does not stop the sweep.
#
# Arguments:
#   $1   - value passed to --lwlr_mode
#   $2.. - learning rates; one python3 invocation is made per value
# Returns:
#   2 if fewer than two arguments are given, otherwise the exit status of
#   the last python3 invocation.
#######################################
run_lr_sweep() {
    if (( $# < 2 )); then
        printf 'usage: run_lr_sweep LWLR_MODE LR [LR...]\n' >&2
        return 2
    fi

    local lwlr_mode=$1
    shift

    local lr
    for lr in "$@"; do
        # The *.bin patterns are quoted on purpose: the shell hands them to the
        # script unexpanded (presumably the script globs them itself -- confirm
        # against train_gpt2_overfit.py).
        python3 train_gpt2_overfit.py \
            --input_bin "dev/data/fineweb10B/fineweb_train_*.bin" \
            --input_val_bin "dev/data/fineweb10B/fineweb_val_*.bin" \
            --sample_every 0 \
            --output_dir pylog_gpt2_124M_overfit \
            --write_tensors 0 \
            --model d12 \
            --lwlr_mode "${lwlr_mode}" \
            --batch_size 16 \
            --sequence_length 1024 \
            --total_batch_size 524288 \
            --overfit_batch_number 20 \
            --dtype bfloat16 \
            --compile 1 \
            --tensorcores 1 \
            --flash 1 \
            --num_iterations 2000 \
            --weight_decay 0.1 \
            --zero_stage 1 \
            --learning_rate "${lr}" \
            --warmup_iters 700 \
            --learning_rate_decay_frac 0.0 \
            --opt_type adamw
    done
}

# Modes none / fanout / fanin share one grid: 7.5e-5 doubling up to 2.4e-3.
for lrmode in none fanout fanin; do
    run_lr_sweep "${lrmode}" 0.000075 0.00015 0.0003 0.0006 0.0012 0.0024
done

# faninnosqrt is swept over a grid shifted up by 4x: 3e-4 doubling up to 9.6e-3.
run_lr_sweep faninnosqrt 0.0003 0.0006 0.0012 0.0024 0.0048 0.0096
