#!/usr/bin/env bash
# Configuration for the boolean_expts.py runs launched later in this script.

# Append the Anaconda library directory to the dynamic-linker search path.
# The ${var:+...} form adds the ':' separator only when LD_LIBRARY_PATH already
# has a value; a leading ':' would create an empty entry, which the loader
# treats as the current working directory.
anaconda_lib_dir=~/anaconda3/lib
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}${anaconda_lib_dir}"


# DEV=1 selects the short development configuration (fewer steps below; the
# launch loop further down also checks this flag).
DEV=0
# Runs are spread over n_gpus devices, numbered starting at device_shift.
n_gpus=8
device_shift=0

# Feature / label settings, forwarded as same-named flags to boolean_expts.py.
randomize_features=0
feature_coordinates='0_1_2_3_4_5'
feature_complexity=6
num_labels=2


# TODO: change these paths to your own.
# local_dir is not referenced anywhere else in the visible part of this script.
local_dir=''
logging_path='logs_parity/log_save10k'
output_path='ckpts_parity/ckpts_save10k'


# Data / model settings, forwarded as same-named flags to boolean_expts.py.
data_dimension=100
vocab_size=2
use_cls_head=0
tie_word_embeddings=0
add_cls_token=0


# Number of steps; the DEV configuration uses a much shorter run.
n_steps=8000000
if [ "$DEV" = 1 ]; then
  n_steps=10000
fi
# The number of examples is tied to the number of steps.
n_examples=$n_steps


# MLP switches, forwarded as --linear_mlp / --skip_mlp.
linear_mlp=0
skip_mlp=0

# Launch one background boolean_expts.py process per hyperparameter
# combination in the nested sweep below.
#
# Run number `cnt` is pinned to GPU (cnt % n_gpus) + device_shift via
# CUDA_VISIBLE_DEVICES, so once there are more than n_gpus combinations the
# device ids wrap around and runs share devices.
#
# The processes are started with '&' and the script never waits for them: it
# returns as soon as every run has been launched. With DEV=1 only the first
# combination is launched before the script exits.
cnt=0

# Settings that do not depend on any sweep variable.
warmup_ratio=0.06
eval_batch_size=128
num_workers=16
wandb_mode='online'
if [ "$DEV" = 1 ]; then
  # NOTE(review): subexample and n_epochs are not passed to the python command
  # below (only --subsample is, with the sweep value). Confirm whether
  # 'subsample' was intended here.
  subexample=4096
  n_epochs=1
  num_workers=0
  wandb_mode='disabled'
fi

for seed in 0 1 2; do
  for head_dim in 8; do
    for num_layers in 2; do
      for batch_size in 32; do
        for learning_rate in 3e-4; do
          for n_heads in 16; do
            for subsample in -1; do
              for weight_decay in 0; do
                hidden_size=$((n_heads * head_dim))
                device_id=$((cnt % n_gpus + device_shift))

                # Checkpoint / logging intervals are defined per batch size.
                # An unknown batch size aborts instead of launching a run with
                # empty or stale interval values.
                case "$batch_size" in
                  1)
                    save_intvl=100000
                    log_intvl=100000
                    ;;
                  32)
                    save_intvl=5000
                    log_intvl=5000
                    ;;
                  8)
                    # NOTE(review): log_intvl is 10x save_intvl here, unlike
                    # the other cases - confirm this is intended.
                    save_intvl=4000
                    log_intvl=40000
                    ;;
                  *)
                    printf 'error: no save/log interval defined for batch_size=%s\n' \
                      "$batch_size" >&2
                    exit 1
                    ;;
                esac

                WANDB_MODE="$wandb_mode" \
                CUDA_VISIBLE_DEVICES="$device_id" \
                python boolean_expts.py \
                  --model_type='gpt2' \
                  --hidden_size="$hidden_size" \
                  --head_dim="$head_dim" \
                  --learning_rate="$learning_rate" \
                  --weight_decay="$weight_decay" \
                  --logging_path="$logging_path" \
                  --output_path="$output_path" \
                  --seed="$seed" \
                  --n_heads="$n_heads" \
                  --use_cls_head="$use_cls_head" \
                  --add_cls_token="$add_cls_token" \
                  --tie_word_embeddings="$tie_word_embeddings" \
                  --linear_mlp="$linear_mlp" \
                  --skip_mlp="$skip_mlp" \
                  --vocab_size="$vocab_size" \
                  --num_labels="$num_labels" \
                  --feature_complexity="$feature_complexity" \
                  --randomize_features="$randomize_features" \
                  --feature_coordinates="$feature_coordinates" \
                  --data_dimension="$data_dimension" \
                  --n_steps="$n_steps" \
                  --n_examples="$n_examples" \
                  --num_layers="$num_layers" \
                  --num_workers="$num_workers" \
                  --batch_size="$batch_size" \
                  --eval_batch_size="$eval_batch_size" \
                  --subsample="$subsample" \
                  --log_intvl="$log_intvl" \
                  --save_intvl="$save_intvl" \
                  --warmup_ratio="$warmup_ratio" &

                cnt=$((cnt + 1))

                # Development mode: stop after launching a single run.
                if [ "$DEV" = 1 ]; then
                  exit
                fi
              done
            done
          done
        done
      done
    done
  done
done