# LEARNABLE
# Runs src/lgmodeling/att_matching.py on two "finetune-learnable" checkpoints
# that differ only in the seed in their path: seed0 is passed as --model-a and
# seed20 as --model-b. Data comes from /root/datasets/lm1b.
#
# The variables handed to `env` are set for this single python process only.
# NOTE(review): the TF32 / matmul-precision / cuBLAS-workspace settings look
# like they are meant to keep the numerics full-precision and reproducible;
# confirm against the script. CUDA_VISIBLE_DEVICES=0,1 exposes GPUs 0 and 1.
env \
    NVIDIA_TF32_OVERRIDE=0 \
    JAX_DEFAULT_MATMUL_PRECISION=highest \
    CUBLAS_WORKSPACE_CONFIG=:4096:8 \
    CUDA_VISIBLE_DEVICES=0,1 \
    python src/lgmodeling/att_matching.py \
        --seed 0 \
        --tgt_len 256 \
        --mem_len 256 \
        --eval_tgt_len 256 \
        --model-a /root/weights/lm1b/gpt2-finetune/finetune-learnable-indice0-heads12-shared1-routed0-topk0-seed0/best_500000 \
        --model-b /root/weights/lm1b/gpt2-finetune/finetune-learnable-indice0-heads12-shared1-routed0-topk0-seed20/best_500000 \
        --data-path /root/datasets/lm1b \
        --dataset lm1b
# ROPE
# Runs src/lgmodeling/att_matching.py on two "finetune-rope" checkpoints that
# differ only in the seed in their path: seed0 is passed as --model-a and
# seed20 as --model-b. Data comes from /root/datasets/lm1b.
#
# The variables handed to `env` are set for this single python process only.
# NOTE(review): the TF32 / matmul-precision / cuBLAS-workspace settings look
# like they are meant to keep the numerics full-precision and reproducible;
# confirm against the script. CUDA_VISIBLE_DEVICES=0,1 exposes GPUs 0 and 1.
env \
    NVIDIA_TF32_OVERRIDE=0 \
    JAX_DEFAULT_MATMUL_PRECISION=highest \
    CUBLAS_WORKSPACE_CONFIG=:4096:8 \
    CUDA_VISIBLE_DEVICES=0,1 \
    python src/lgmodeling/att_matching.py \
        --seed 0 \
        --tgt_len 256 \
        --mem_len 256 \
        --eval_tgt_len 256 \
        --model-a /root/weights/lm1b/gpt2-finetune/finetune-rope-indice0-heads12-shared1-routed0-topk0-seed0/best_500000 \
        --model-b /root/weights/lm1b/gpt2-finetune/finetune-rope-indice0-heads12-shared1-routed0-topk0-seed20/best_500000 \
        --data-path /root/datasets/lm1b \
        --dataset lm1b