### Run "conda activate gpt" before running this script.

# Weights & Biases settings, exported so that the training processes launched
# by this script inherit them.
export WANDB_MODE=online
# NOTE(review): presumably the wandb service start-up timeout in seconds -
# confirm against the wandb documentation.
export WANDB__SERVICE_WAIT=300

# The API key must come from the caller's environment (or from credentials
# stored by `wandb login`), never from this file: a key committed here is
# readable by everyone with access to the repository.
# NOTE(review): the key that used to be hardcoded here should be treated as
# leaked and revoked.
if [ -n "${WANDB_API_KEY:-}" ]; then
  export WANDB_API_KEY
else
  # Only warn (no exit): the script may be sourced, and wandb may still find
  # stored login credentials.
  echo "warning: WANDB_API_KEY is not set; export it or run 'wandb login' first" >&2
fi

# Locations and experiment identity; forwarded to train_gpt.py as --data_dir,
# --out_dir and --wandb_project respectively.
DATA_DIR="./open_small"
OUT_DIR="./"
WANDB_PROJ="gpt_final"

# Training schedule; forwarded as --batch_size, --gradient_accumulation_steps,
# --max_iters and --block_size. MAX_ITERS is passed through as the literal
# text "100_000" (underscore included).
BATCH_SIZE="512"
GRAD_ACCUM="1"
MAX_ITERS="100_000"
BLOCK_SIZE="128"

# Token batch size: 512 * 128 = 65_536
# Short seq regime: d >> N / 6 ~= 20

#### BTT ####
# α, β, γ, δ, ε, φ, ρ
# r=0, slsl
# # matching
# v1="(0.5|0|0.5|0|0.5|0.5|0)"
# v2="(0.33|0|0.67|0|0.67|0.33|0)"
# v3="(0.25|0|0.75|0|0.75|0.25|0)"
# # mismatch
# v4="(0.67|0|0.33|0|0.67|0.33|0)"

# # r=0.25
# u1="(0.5|0|0.5|0|0.5|0.5|0.25)"
# u2="(0.33|0|0.67|0|0.67|0.33|0.25)"
# u3="(0.25|0|0.75|0|0.75|0.25|0.25)"
# # mismatch
# u4="(0.67|0|0.33|0|0.67|0.33|0.25)"

# # large first
# # matching
# v2="(0.67|0|0.33|0|0.33|0.67|0)"
# v3="(0.75|0|0.25|0|0.25|0.75|0)"
# v4="(0.33|0|0.67|0|0.33|0.67|0)"

# # r=0.25
# u2="(0.67|0|0.33|0|0.33|0.67|0.25)"
# u3="(0.75|0|0.25|0|0.25|0.75|0.25)"
# # mismatch
# u4="(0.33|0|0.67|0|0.33|0.67|0.25)"

# Expression vectors swept by the launch loop (passed as --expr).
# Each one is a parenthesised list of seven '|'-separated numbers; the legend
# for the seven positions, in order, is:
#   α, β, γ, δ, ε, φ, ρ
# Single quotes are used because nothing inside needs shell expansion.

#### TT ####
# tt1/tt2 have the last field at 0; tt3/tt4 repeat them with it at 0.25.
tt1='(0.5|0.5|0|0.5|0.5|0|0)'
tt2='(0.33|0.67|0|0.33|0.67|0|0)'
tt3='(0.5|0.5|0|0.5|0.5|0|0.25)'
tt4='(0.33|0.67|0|0.33|0.67|0|0.25)'

#### TT + BTT ####
# Same leading six fields; only the last field differs (0 vs 0.25).
ttb1='(0.5|0.25|0.25|0.25|0.5|0.25|0)'
ttb2='(0.5|0.25|0.25|0.25|0.5|0.25|0.25)'

#### Low Rank ####
# Only the last field varies: 0.75, 0.5, 0.25.
l1='(1|0|0|0|1|0|0.75)'
l2='(1|0|0|0|1|0|0.5)'
l3='(1|0|0|0|1|0|0.25)'

# Launch sweep: one background training run per (d_model, expr) pair.
# The active sweep covers the TT, TT+BTT and low-rank vectors (the earlier
# "BTT" label no longer matched the vectors being iterated).
# Runs are started with '&' and the script does not wait for them to finish.
lr=3e-3
# GPU ids offered to get_free_gpu; constant for the whole sweep.
ALLOWED_GPUs="0 1 2 3 4 5 6 7"

# get_free_gpu is not defined in this file; it must already exist in the
# calling shell (function or command on PATH). Fail fast instead of launching
# every run with an empty CUDA_VISIBLE_DEVICES.
if ! command -v get_free_gpu >/dev/null 2>&1; then
  echo "error: get_free_gpu is not available in this shell" >&2
  # 'return' succeeds when the script is sourced; otherwise fall back to exit.
  return 1 2>/dev/null || exit 1
fi

for d_model in 1024 768 512; do
  # Alternative sweeps (need the commented-out v*/u* vectors to be enabled):
  # for vec in ${v1} ${v2} ${v3} ${v4} ${u1} ${u2} ${u3} ${u4}; do
  # for vec in ${v2} ${v3} ${v4} ${u2} ${u3} ${u4}; do
  # The list is intentionally unquoted so that any undefined vector simply
  # drops out; the values contain no whitespace or glob characters.
  # shellcheck disable=SC2086
  for vec in ${tt1} ${tt2} ${tt3} ${tt4} ${ttb1} ${ttb2} ${l1} ${l2} ${l3}; do
    gpu_id=$(get_free_gpu "${ALLOWED_GPUs}")
    if [ -z "${gpu_id}" ]; then
      echo "warning: get_free_gpu returned no device; skipping d_model=${d_model} expr=${vec}" >&2
      continue
    fi

    CUDA_VISIBLE_DEVICES="${gpu_id}" python train_gpt.py config/train_open_small.py \
      --block_size="${BLOCK_SIZE}" \
      --struct=simple_ein_vec_norm \
      --expr="${vec}" \
      --layers=all_but_last \
      --d_model="${d_model}" \
      --n_layer=3 \
      --n_head=-1 \
      --d_head=64 \
      --max_iters="${MAX_ITERS}" \
      --data_dir="${DATA_DIR}" \
      --out_dir="${OUT_DIR}" \
      --batch_size="${BATCH_SIZE}" \
      --gradient_accumulation_steps="${GRAD_ACCUM}" \
      --init_lr="${lr}" \
      --wandb_project="${WANDB_PROJ}" &

    # NOTE(review): presumably gives the new job time to occupy its GPU before
    # the next get_free_gpu query - confirm.
    sleep 20
  done
done