#LEARNABLE
NVIDIA_TF32_OVERRIDE=0 JAX_DEFAULT_MATMUL_PRECISION=highest CUBLAS_WORKSPACE_CONFIG=:4096:8 WANDB_MODE=online CUDA_VISIBLE_DEVICES=0 \
    python src/lgmodeling/train_model.py --seed 0 --tgt_len 256 --mem_len 0 --eval_tgt_len 256 --position-embeddings "learnable" \
    --num-shared-experts 1 --num-routed-experts 0 --topk 0 --rotary-dim 64 --n_layer 12 --n_embd 192 --n_head 3 --n_inner 768\
    --attention-bias --learning-rate 0.00025 --batch-size 64 --max_step 60000 --warmup_step 0 --dataset wt103 \
    --model-save-dir /root/weights/wt103 --data-path /root/datasets/wt103
#ROPE
NVIDIA_TF32_OVERRIDE=0 JAX_DEFAULT_MATMUL_PRECISION=highest CUBLAS_WORKSPACE_CONFIG=:4096:8 WANDB_MODE=online CUDA_VISIBLE_DEVICES=0 \
    python src/lgmodeling/train_model.py --seed 0 --tgt_len 256 --mem_len 0 --eval_tgt_len 256 --position-embeddings "rope" \
    --num-shared-experts 1 --num-routed-experts 0 --topk 0 --rotary-dim 64 --n_layer 12 --n_embd 192 --n_head 3 --n_inner 768 \
    --attention-bias --learning-rate 0.00025 --batch-size 64 --max_step 60000 --warmup_step 0 --dataset wt103 \
    --model-save-dir /root/weights/wt103 --data-path /root/datasets/wt103