#!/usr/bin/env bash
# Fine-tune two pretrained enwik8 checkpoints with identical fine-tuning
# settings, one after the other on GPU 0. The checkpoints differ only in the
# pretraining-run tag embedded in their directory name ("learnable", "rope").

#######################################
# Run src/lgmodeling/finetune.py on one pretrained enwik8 checkpoint.
# Arguments:
#   $1 - checkpoint tag as it appears in the checkpoint directory name
#        (e.g. "learnable", "rope").
# Outputs:
#   Whatever finetune.py prints; a one-line error on stderr if it fails.
# Returns:
#   finetune.py's exit status.
#######################################
run_finetune() {
  local tag=${1:?run_finetune: missing checkpoint tag}
  local ckpt_dir="/root/weights/enwik8/lr0.00025-${tag}-step60000-warm0-size24-layer12-embd512-heads8-shared1-routed0-topk0"
  local rc=0

  # The env-var prefixes apply to this python process only:
  # - TF32 off, "highest" JAX matmul precision and a fixed cuBLAS workspace
  #   config request full-precision, deterministic GPU matmuls.
  # - WANDB_MODE=online logs the run to Weights & Biases.
  # - Only GPU 0 is visible.
  # NOTE(review): every run uses the same --model-save-dir. Confirm that
  # finetune.py namespaces its output per run; otherwise a later run may
  # overwrite an earlier one's checkpoints.
  NVIDIA_TF32_OVERRIDE=0 JAX_DEFAULT_MATMUL_PRECISION=highest \
  CUBLAS_WORKSPACE_CONFIG=:4096:8 WANDB_MODE=online CUDA_VISIBLE_DEVICES=0 \
    python src/lgmodeling/finetune.py --n_head 8 --lmc-layer-indices 0 --seed 0 \
    --model-path "${ckpt_dir}/best_60000" \
    --tgt_len 512 --mem_len 512 --eval_tgt_len 128 \
    --learning-rate 0.00025 --batch-size 24 --max_step 60000 --warmup_step 0 \
    --dataset enwik8 \
    --model-save-dir /root/weights/enwik8/gpt2-finetune \
    --data-path /root/datasets/enwik8 \
    || rc=$?

  if [ "$rc" -ne 0 ]; then
    printf 'run_finetune: %s run failed (exit %s)\n' "$tag" "$rc" >&2
  fi
  return "$rc"
}

# Run both fine-tunes even if the first one fails. Count failures so the
# script's final status is not just the status of the last run.
finetune_failures=0
run_finetune learnable || finetune_failures=$((finetune_failures + 1))
run_finetune rope || finetune_failures=$((finetune_failures + 1))
[ "$finetune_failures" -eq 0 ]