#!/usr/bin/env bash
#
# Trains with run.py on MetaMathQA-395K.json using four local processes
# (GPUs 0-3), writing checkpoints to SAVE_PATH, then runs eval_gsm8k.py and
# eval_math.py against a checkpoint from that output directory.
#
# Optional environment:
#   EVAL_CHECKPOINT  Checkpoint directory to evaluate. Defaults to the
#                    highest-numbered "${SAVE_PATH}/checkpoint-<step>".
#
# Exit status: non-zero if training fails, if no checkpoint can be found, or
# if either evaluation fails.
set -euo pipefail

export MODEL_PATH="Llama-3.2-3B"
export SAVE_PATH="Llama-3.2-3B_ntk"

# Single-node run: rendezvous and collective traffic stay on loopback.
export MASTER_ADDR="localhost"
export MASTER_PORT="1234"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"

# Weights & Biases is switched off for this run. The CLI call is best-effort:
# a missing or failing `wandb` binary must not prevent training.
export WANDB_DISABLED=true
wandb offline || printf 'warning: "wandb offline" failed; continuing\n' >&2

#######################################
# Prints the checkpoint directory with the highest numeric step suffix.
# Arguments: $1 - directory containing "checkpoint-<step>" subdirectories
# Outputs:   the selected directory path on stdout
# Returns:   0 if a checkpoint was found, 1 otherwise
#######################################
latest_checkpoint() {
  local root=$1
  local dir step
  local best_dir=""
  local best_step=-1

  for dir in "${root}"/checkpoint-*; do
    [[ -d "${dir}" ]] || continue            # also skips an unmatched glob
    step=${dir##*checkpoint-}
    [[ "${step}" =~ ^[0-9]+$ ]] || continue  # ignore non-numeric suffixes
    # 10# forces base 10 so a leading zero is not read as octal.
    if (( 10#${step} > best_step )); then
      best_step=$(( 10#${step} ))
      best_dir=${dir}
    fi
  done

  [[ -n "${best_dir}" ]] || return 1
  printf '%s\n' "${best_dir}"
}

# Arguments for run.py. They are kept in an array so that an option can be
# commented out without cutting the command short: a "#" line inside a
# backslash-continued command ends the command at that line.
train_args=(
  --model_name_or_path "${MODEL_PATH}"
  --data_path "MetaMathQA-395K.json"
  --data_length 10000000
  --bf16 True
  --output_dir "${SAVE_PATH}"
  --num_train_epochs 1
  --per_device_train_batch_size 8
  --per_device_eval_batch_size 4
  --gradient_accumulation_steps 4
  --evaluation_strategy "no"
  --save_strategy "epoch"
  # --save_steps 1000
  --save_total_limit 2
  --learning_rate 1e-3
  --weight_decay 0.
  --warmup_ratio 0.03
  --lr_scheduler_type "cosine"
  --logging_steps 10
)

# One process per visible GPU; --nproc_per_node must match the device list.
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
  --master_addr "${MASTER_ADDR}" \
  --master_port "${MASTER_PORT}" \
  --nproc_per_node=4 \
  --use_env \
  run.py "${train_args[@]}"

# Choose the checkpoint to evaluate.
# NOTE(review): this used to be hard-coded as checkpoint-12344 (and pointed at
# a "Llama-3.2-1B_ntk" directory for eval_math.py, which does not match
# SAVE_PATH). The step number depends on the effective batch size, so it is
# resolved from SAVE_PATH instead; set EVAL_CHECKPOINT to pin a specific one.
checkpoint=${EVAL_CHECKPOINT:-}
if [[ -z "${checkpoint}" ]]; then
  checkpoint=$(latest_checkpoint "${SAVE_PATH}") || {
    printf 'error: no checkpoint-<step> directory found under %s\n' \
      "${SAVE_PATH}" >&2
    exit 1
  }
fi

# The two benchmarks are independent: run both even if the first one fails,
# then report failure if either did.
status=0
python3 eval_gsm8k.py --model "${checkpoint}" --peft_type "ntk" || status=$?
python3 eval_math.py --model "${checkpoint}" --peft_type "ntk" || status=$?
exit "${status}"