#!/usr/bin/env bash
# Fine-tune meta-llama/Llama-2-7b-hf with low-rank adapters on a MetaMath
# subset of DATA_PATH, then evaluate a checkpoint and write its accuracy to
# acc.txt inside that checkpoint directory.
#
# Abort on the first failing command or unset variable, so a failed training
# run never falls through into evaluation and cleanup. (No pipefail: the
# script has no pipelines, and some /bin/sh implementations reject it.)
set -eu

BASE_MODEL="meta-llama/Llama-2-7b-hf"
DATA_PATH="pissa-dataset"

# Adapter / optimiser hyper-parameters forwarded to train.py.
# DELTA and PNORM feed the project-specific --lora_delta / --lora_pnorm
# options; see train.py for their exact meaning.
DELTA=10
PNORM=fro
LR=5e-4
RANK=128

# The output directory name encodes this run's hyper-parameters.
OUTPUT_PATH="output/metamath-100k-Llama-2-7b-NBLoRA-r${RANK}-d${DELTA}-${PNORM}-lr${LR}"
# Launch train.py on local GPUs 0-7 via the DeepSpeed launcher, using the
# DeepSpeed config configs/ds_config_zero2.json. Adapters (rank $RANK,
# alpha = rank) are attached to every attention and MLP projection listed in
# --target_modules; the base weights are not fully fine-tuned
# (--full_finetune False).
#
# The final argument deliberately has no trailing "\": a dangling line
# continuation would splice whatever follows into this command's argv.
# "|| exit" stops the script if training fails, so evaluation never runs
# against a missing or stale checkpoint.
deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \
    --deepspeed configs/ds_config_zero2.json \
    --model_name_or_path "$BASE_MODEL" \
    --full_finetune False \
    --bf16 \
    --init_weights True \
    --target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \
    --lora_rank "$RANK" \
    --lora_alpha "$RANK" \
    --lora_delta "$DELTA" \
    --lora_pnorm "$PNORM" \
    --lora_dropout 0 \
    --data_path "$DATA_PATH" \
    --sub_task metamath:99968 \
    --dataset_split train \
    --dataset_field instruction output \
    --output_dir "$OUTPUT_PATH" \
    --num_train_epochs 1 \
    --model_max_length 512 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 10 \
    --learning_rate "$LR" \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --logging_steps 1 \
    --lr_scheduler_type "cosine" \
    --report_to "tensorboard" || exit

# Evaluate the checkpoint from the last optimiser step. Assuming
# "metamath:99968" selects 99968 training examples, one epoch is
# 99968 / (16 per device x 8 GPUs x 1 accumulation step) = 781 steps.
# Update this number if any of those settings change.
# NOTE(review): save_steps (1000) exceeds 781, so this relies on a checkpoint
# being written at the final step -- confirm train.py / the trainer does so.
ROOT_PATH="$OUTPUT_PATH/checkpoint-781"
if [ ! -d "$ROOT_PATH" ]; then
    printf 'error: checkpoint not found: %s\n' "$ROOT_PATH" >&2
    exit 1
fi

# pp.py is expected to produce the model directory "$ROOT_PATH/test" that
# gen_vllm.py loads (presumably the adapter merged into BASE_MODEL -- confirm).
# Each step stops the script on failure so acc.txt is never overwritten with
# the result of scoring missing or partial predictions.
python pp.py --output_dir "$ROOT_PATH" --model_name_or_path "$BASE_MODEL" || exit
python utils/gen_vllm.py --model "$ROOT_PATH/test" --sub_task metamath \
    --output_file "$ROOT_PATH/metamath_response.json" || exit
python utils/test_acc.py --input_file "$ROOT_PATH/metamath_response.json" \
    > "$ROOT_PATH/acc.txt" || exit

# Delete the intermediate model directory once predictions are scored.
# ${ROOT_PATH:?} aborts instead of expanding to "/test" if ROOT_PATH is empty.
rm -r -- "${ROOT_PATH:?}/test"