#!/usr/bin/env bash
#
# Fine-tune BASE_MODEL on DATA_PATH via finetune.py with the given PEFT type
# and learning rate, then run the GSM8K and MATH inference scripts on the
# result.
#
# Usage: <this-script> PEFT_TYPE LEARNING_RATE
#   PEFT_TYPE      forwarded to finetune.py as --peft_type
#   LEARNING_RATE  forwarded to finetune.py as --learning_rate
#
# Files written by this script:
#   gsm8k-PEFT_TYPE-LEARNING_RATE.txt   combined stdout/stderr of gsm8k_inference.py
#   MATH-PEFT_TYPE-LEARNING_RATE.txt    combined stdout/stderr of MATH_inference.py
#
# Exit status:
#   2                     wrong usage
#   finetune.py's status  if fine-tuning fails (inference is skipped)
#   1                     if either inference script fails
#   0                     otherwise

# Treat any reference to an unset variable as an error.
set -u

BASE_MODEL="meta-llama/Llama-2-7b-hf"
DATA_PATH="meta-math/MetaMathQA"

# Both positional arguments are required and must be non-empty; without them
# the output directory and log names would silently collapse to "-".
if (( $# < 2 )) || [[ -z "$1" || -z "$2" ]]; then
  printf 'Usage: %s PEFT_TYPE LEARNING_RATE\n' "${0##*/}" >&2
  exit 2
fi

peft=$1
lr=$2

OUTPUT="output/checkpoint/$peft-$lr/"

# Remove files left in the ft/ subdirectory by a previous run with the same
# settings. -f keeps the very first run (no ft/ directory yet) quiet.
# NOTE(review): presumably finetune.py writes the final model under
# $OUTPUT/ft/ -- confirm against finetune.py.
rm -f -- "${OUTPUT}ft"/*

if ! python finetune.py \
    --model_name_or_path "$BASE_MODEL" \
    --output_dir "$OUTPUT" \
    --hrft_r 8 \
    --peft_type "$peft" \
    --data_path "$DATA_PATH" \
    --dataset_split "train[:100000]" \
    --dataset_field query response \
    --num_train_epochs 2 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --save_strategy "steps" \
    --save_steps 0 \
    --save_total_limit 0 \
    --learning_rate "$lr" \
    --weight_decay 0. \
    --warmup_ratio 0.005 \
    --lr_scheduler_type "cosine" \
    --logging_steps 4 \
    --report_to "wandb"
then
  train_status=$?
  # ft/ was emptied above, so running inference now would be meaningless.
  printf 'error: finetune.py exited with status %d; skipping inference\n' \
    "$train_status" >&2
  exit "$train_status"
fi

# Run both inference scripts even if the first one fails, but remember the
# failure: tee would otherwise hide the Python exit status.
status=0

python inference/gsm8k_inference.py --model "$OUTPUT/ft/" 2>&1 \
  | tee "gsm8k-$peft-$lr.txt"
if (( ${PIPESTATUS[0]} != 0 )); then
  printf 'error: gsm8k_inference.py failed\n' >&2
  status=1
fi

python inference/MATH_inference.py --model "$OUTPUT/ft/" 2>&1 \
  | tee "MATH-$peft-$lr.txt"
if (( ${PIPESTATUS[0]} != 0 )); then
  printf 'error: MATH_inference.py failed\n' >&2
  status=1
fi

exit "$status"


