# LEARNABLE
# Runs src/lgmodeling/finetune.py on the lm1b dataset, starting from the
# "learnable" checkpoint (best_500000) and saving under
# /root/weights/lm1b/gpt2-finetune.
#
# The variable assignments below are command prefixes: they apply to this one
# python process only and are not exported to the surrounding shell.
#   NVIDIA_TF32_OVERRIDE=0 / JAX_DEFAULT_MATMUL_PRECISION=highest /
#   CUBLAS_WORKSPACE_CONFIG=:4096:8  - presumably chosen for reproducible,
#                                      full-precision matmuls (confirm)
#   WANDB_MODE=online                - Weights & Biases mode
#   CUDA_VISIBLE_DEVICES=0,1         - GPUs made visible to the process
NVIDIA_TF32_OVERRIDE=0 \
JAX_DEFAULT_MATMUL_PRECISION=highest \
CUBLAS_WORKSPACE_CONFIG=:4096:8 \
WANDB_MODE=online \
CUDA_VISIBLE_DEVICES=0,1 \
python src/lgmodeling/finetune.py \
    --model-path /root/weights/lm1b/lr0.00025-learnable-step500000-warm2000-size96-layer12-embd768-heads12-shared1-routed0-topk0/best_500000 \
    --seed 0 \
    --tgt_len 256 \
    --mem_len 0 \
    --eval_tgt_len 256 \
    --n_head 3 \
    --lmc-layer-indices 0 \
    --learning-rate 0.00025 \
    --batch-size 96 \
    --max_step 500000 \
    --warmup_step 2000 \
    --dataset lm1b \
    --model-save-dir /root/weights/lm1b/gpt2-finetune \
    --data-path /root/datasets/lm1b

# LEARNABLE
# Runs src/lgmodeling/finetune.py on lm1b from the "learnable" checkpoint
# (best_500000), writing to /root/weights/lm1b/gpt2-finetune.
#
# NOTE(review): this invocation uses exactly the same environment, checkpoint,
# hyper-parameters and --model-save-dir as the LEARNABLE command earlier in
# this file. If a different variant was intended here, the arguments need
# updating - TODO confirm.
#
# The assignments are per-command prefixes (scoped to this python process).
NVIDIA_TF32_OVERRIDE=0 \
JAX_DEFAULT_MATMUL_PRECISION=highest \
CUBLAS_WORKSPACE_CONFIG=:4096:8 \
WANDB_MODE=online \
CUDA_VISIBLE_DEVICES=0,1 \
python src/lgmodeling/finetune.py \
    --model-path /root/weights/lm1b/lr0.00025-learnable-step500000-warm2000-size96-layer12-embd768-heads12-shared1-routed0-topk0/best_500000 \
    --seed 0 \
    --tgt_len 256 \
    --mem_len 0 \
    --eval_tgt_len 256 \
    --n_head 3 \
    --lmc-layer-indices 0 \
    --learning-rate 0.00025 \
    --batch-size 96 \
    --max_step 500000 \
    --warmup_step 2000 \
    --dataset lm1b \
    --model-save-dir /root/weights/lm1b/gpt2-finetune \
    --data-path /root/datasets/lm1b

