# Fine-tuning launch command for src/llama/finetune.py.
# The script path is relative, so run this from the repository root.
#
# Environment (prefix assignments: they apply to this python process only):
#   NVIDIA_TF32_OVERRIDE=0                 turn off TF32 math in NVIDIA libraries
#   JAX_DEFAULT_MATMUL_PRECISION=highest   JAX default matmul precision setting
#   CUBLAS_WORKSPACE_CONFIG=:4096:8        fixed cuBLAS workspace configuration
#   WANDB_MODE=offline                     keep Weights & Biases logging offline
#   CUDA_VISIBLE_DEVICES=0                 expose only GPU 0 to the process
# Presumably chosen for reproducible, full-precision numerics -- TODO confirm.
#
# NOTE(review): --wandb-group is labelled "OneBillionWord" while --dataset and
#   --data-path point at wt103; confirm the group label is intended.
# NOTE(review): the checkpoint directory name mentions heads12 / step60000,
#   whereas the flags pass 3 attention heads and load best_260000; verify
#   against finetune.py that this combination is what the run expects.
#
# Every line of the command except the last must end in " \" (space, then
# backslash). Do not insert comments or blank lines between continued lines.
NVIDIA_TF32_OVERRIDE=0 \
JAX_DEFAULT_MATMUL_PRECISION=highest \
CUBLAS_WORKSPACE_CONFIG=:4096:8 \
WANDB_MODE=offline \
CUDA_VISIBLE_DEVICES=0 \
python src/llama/finetune.py \
  --model-path ~/weights/lmc/llama-wt103/lr0.00025-step60000-warm2000-batch64-layer12-hidden768-heads12-typeGLU-seed0/best_260000 \
  --seed 0 \
  --tgt_len 256 \
  --mem_len 0 \
  --eval_tgt_len 256 \
  --num_attention_heads 3 \
  --num_key_value_heads 3 \
  --head_dim 64 \
  --lmc-layer-indices 0 1 2 3 4 5 6 7 8 9 10 11 \
  --learning-rate 0.00025 \
  --batch-size 64 \
  --max_step 60000 \
  --warmup_step 2000 \
  --dataset wt103 \
  --eval-frequency 2000 \
  --save-frequency 2000 \
  --wandb-project LMC-Attention \
  --wandb-group "Llama-OneBillionWord-FFN" \
  --model-save-dir ~/weights/lmc/llama-wt103/llama-finetune \
  --data-path ~/datasets/wt103



