#!/usr/bin/env bash
# Launch RWKV-PEFT state tuning (--train_type state) of the MIDI RWKV model.
#
# Required environment:
#   PROJECT_ROOT  directory containing midi_rwkv.pth and RWKV-PEFT/.
#
# train.py is invoked by relative path, so run this script from the directory
# that contains train.py (presumably $PROJECT_ROOT/RWKV-PEFT — confirm).
set -euo pipefail

# Fail fast: with PROJECT_ROOT unset, the paths below would silently resolve to
# /midi_rwkv.pth and /RWKV-PEFT/peft_model.
: "${PROJECT_ROOT:?PROJECT_ROOT must be set}"

load_model="$PROJECT_ROOT/midi_rwkv.pth"
proj_dir="$PROJECT_ROOT/RWKV-PEFT/peft_model"

# Architecture flags; presumably these must match the checkpoint in $load_model.
n_layer=12
n_embd=384

micro_bsz=1
epoch_save=2
ctx_len=2048
train_epochs=16
# Passed as both --lr_init and --lr_final.
# TODO: tune lr. openmose recommends 1e-4; 1e-2 seems stable; 5e-2 appears best.
lr=5e-2

export TOKENIZERS_PARALLELISM=false

# An array keeps every value a single argument, even if paths contain spaces,
# and allows comments between options.
train_args=(
  --load_model "$load_model"
  --proj_dir "$proj_dir"
  --n_layer "$n_layer" --n_embd "$n_embd" --vocab_size 16000
  --ctx_len "$ctx_len" --micro_bsz "$micro_bsz"
  --epoch_count "$train_epochs" --epoch_begin 0 --epoch_save "$epoch_save"
  --lr_init "$lr" --lr_final "$lr" --warmup_steps 20
  --beta1 0.9 --beta2 0.99 --adam_eps 1e-8
  --accelerator gpu --devices 1
  # TODO: the triton backend does not seem to like bf16, which is presumably
  # why tf32 is used here.
  --precision tf32
  --strategy deepspeed_stage_1 --grad_cp 0
  --my_testing "x070"
  --train_type "state"
  --op triton
)

python3 train.py "${train_args[@]}"