#!/usr/bin/env bash
# Train two otherwise-identical models on enwik8 (seed 0, GPU 0) that differ
# only in --position-embeddings: first "learnable", then "rope".
# Run from the repository root: the training script path below is relative.
#
# NOTE(review): both runs share --model-save-dir /root/weights/enwik8; confirm
# train_model.py keeps per-run checkpoints apart, otherwise the rope run may
# overwrite the learnable run's weights.

# Launch one enwik8 training run with the shared hyperparameters.
# Arguments:
#   $1   - value for --position-embeddings (e.g. "learnable", "rope"); required
#   $2.. - optional extra flags, appended after the shared ones
# The environment assignments apply only to the python process: TF32 off,
# highest JAX matmul precision, fixed cuBLAS workspace config (presumably for
# reproducible, full-precision numerics), W&B online, only GPU 0 visible.
# --rotary-dim is passed for every run so the commands differ in one flag only.
run_enwik8() {
    local pos_emb="${1:?run_enwik8: position-embeddings type required}"
    shift
    NVIDIA_TF32_OVERRIDE=0 JAX_DEFAULT_MATMUL_PRECISION=highest \
    CUBLAS_WORKSPACE_CONFIG=:4096:8 WANDB_MODE=online CUDA_VISIBLE_DEVICES=0 \
        python src/lgmodeling/train_model.py \
        --seed 0 --tgt_len 512 --mem_len 512 --eval_tgt_len 128 \
        --position-embeddings "$pos_emb" \
        --num-shared-experts 1 --num-routed-experts 0 --topk 0 \
        --rotary-dim 64 --n_layer 12 --n_embd 512 --n_head 8 --n_inner 2048 \
        --attention-bias --learning-rate 0.00025 --batch-size 24 \
        --max_step 60000 --warmup_step 0 --dataset enwik8 \
        --model-save-dir /root/weights/enwik8 --data-path /root/datasets/enwik8 \
        "$@"
}

# LEARNABLE
run_enwik8 learnable
# ROPE
run_enwik8 rope