# LEARNABLE: lm1b training run with --position-embeddings "learnable".
#
# The VAR=value words prefix the python command, so they are placed in the
# environment of this one process only:
#   NVIDIA_TF32_OVERRIDE=0, JAX_DEFAULT_MATMUL_PRECISION=highest,
#   CUBLAS_WORKSPACE_CONFIG=:4096:8  (presumably chosen for full-precision,
#                                     reproducible matmuls -- confirm)
#   WANDB_MODE=online, CUDA_VISIBLE_DEVICES=0,1
#
# Every continuation backslash is preceded by a space, so neighbouring words
# stay separate even if the indentation of the following line is lost.
#
# NOTE(review): the ROPE run uses the same --model-save-dir; confirm that
# train_model.py keeps the two runs' outputs apart.
NVIDIA_TF32_OVERRIDE=0 \
JAX_DEFAULT_MATMUL_PRECISION=highest \
CUBLAS_WORKSPACE_CONFIG=:4096:8 \
WANDB_MODE=online \
CUDA_VISIBLE_DEVICES=0,1 \
python src/lgmodeling/train_model.py \
  --seed 0 \
  --tgt_len 256 \
  --mem_len 0 \
  --eval_tgt_len 256 \
  --position-embeddings "learnable" \
  --num-shared-experts 1 \
  --num-routed-experts 0 \
  --topk 0 \
  --rotary-dim 64 \
  --n_layer 12 \
  --n_embd 768 \
  --n_head 12 \
  --n_inner 3072 \
  --attention-bias \
  --learning-rate 0.00025 \
  --batch-size 96 \
  --max_step 500000 \
  --warmup_step 2000 \
  --dataset lm1b \
  --model-save-dir /root/weights/lm1b \
  --data-path /root/datasets/lm1b
# ROPE: lm1b training run with --position-embeddings "rope".
#
# The VAR=value words prefix the python command, so they are placed in the
# environment of this one process only:
#   NVIDIA_TF32_OVERRIDE=0, JAX_DEFAULT_MATMUL_PRECISION=highest,
#   CUBLAS_WORKSPACE_CONFIG=:4096:8  (presumably chosen for full-precision,
#                                     reproducible matmuls -- confirm)
#   WANDB_MODE=online, CUDA_VISIBLE_DEVICES=0,1
#
# Every continuation backslash is preceded by a space, so neighbouring words
# stay separate even if the indentation of the following line is lost.
#
# NOTE(review): the LEARNABLE run uses the same --model-save-dir; confirm that
# train_model.py keeps the two runs' outputs apart.
NVIDIA_TF32_OVERRIDE=0 \
JAX_DEFAULT_MATMUL_PRECISION=highest \
CUBLAS_WORKSPACE_CONFIG=:4096:8 \
WANDB_MODE=online \
CUDA_VISIBLE_DEVICES=0,1 \
python src/lgmodeling/train_model.py \
  --seed 0 \
  --tgt_len 256 \
  --mem_len 0 \
  --eval_tgt_len 256 \
  --position-embeddings "rope" \
  --num-shared-experts 1 \
  --num-routed-experts 0 \
  --topk 0 \
  --rotary-dim 64 \
  --n_layer 12 \
  --n_embd 768 \
  --n_head 12 \
  --n_inner 3072 \
  --attention-bias \
  --learning-rate 0.00025 \
  --batch-size 96 \
  --max_step 500000 \
  --warmup_step 2000 \
  --dataset lm1b \
  --model-save-dir /root/weights/lm1b \
  --data-path /root/datasets/lm1b