#!/bin/bash

# ---- Run configuration ------------------------------------------------------
# Dataset location, forwarded to train.py through --data_path.
DATA_PATH='pissa-dataset'

# Adapter rank (also reused as lora_alpha at init time) and learning rate.
RANK=128
LR=1e-5

# Short model tag used in output directory names, and the base model path
# handed to init_pissa.py.
MODEL='Llama-2-7b'
BASE_MODEL='meta-llama/Llama-2-7b-hf'

# Directory holding the PiSSA-initialized residual model for this model/rank.
RES_MODEL="output/${MODEL}-PiSSA-r${RANK}"

# Make sure the residual model directory exists before training.
# An existing $RES_MODEL is reused as-is; otherwise init_pissa.py is run on
# $BASE_MODEL to create it. The script aborts if that initialization fails,
# because the training step reads its model from $RES_MODEL.
if [[ -e "$RES_MODEL" ]]; then
    echo "Use pre-initialized residual model."
else
    echo "Perform PiSSA initialization by my self."
    # Arguments are kept in an array so every value stays one word even if a
    # path contains spaces; the target modules are deliberately separate words.
    init_args=(
        --base_model_path "$BASE_MODEL"
        --output_dir "$RES_MODEL"
        --init_weights pissa_niter_16
        --lora_r "$RANK"
        --lora_alpha "$RANK"
        --lora_dropout 0
        --target_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj
    )
    if ! python init_pissa.py "${init_args[@]}"; then
        # A leftover directory would be picked up by the -e check on the next
        # run, so point the user at it instead of failing silently.
        echo "PiSSA initialization failed; check '$RES_MODEL' for partial output before re-running." >&2
        exit 1
    fi
fi

# Fine-tune on the metamath sub-task with DeepSpeed on local GPUs 0-7.
# Checkpoints are written under $OUTPUT_PATH every 100 steps.
OUTPUT_PATH="output/metamath-395k-$MODEL-PiSSA-r$RANK-lr$LR"

# Arguments for train.py, held in an array so each value is passed as exactly
# one word and no line-continuation backslashes are needed.
train_args=(
    --deepspeed configs/ds_config_zero2.json
    --model_name_or_path "$RES_MODEL"
    --full_finetune False
    --bf16
    --adapter_name_or_path "pissa_init"
    --data_path "$DATA_PATH"
    --sub_task metamath
    --dataset_split train
    --dataset_field instruction output
    --output_dir "$OUTPUT_PATH"
    --num_train_epochs 1
    --model_max_length 512
    --per_device_train_batch_size 16
    --gradient_accumulation_steps 1
    --save_strategy "steps"
    --save_steps 100
    --save_total_limit 100
    --learning_rate "$LR"
    --weight_decay 0.
    --warmup_ratio 0.03
    --logging_steps 1
    --lr_scheduler_type "cosine"
    --report_to "tensorboard"
)

# Stop here if training fails, propagating its exit status, so that later
# steps do not run against missing or incomplete checkpoints.
deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py "${train_args[@]}" || {
    train_status=$?
    echo "Training failed with exit status $train_status; aborting." >&2
    exit "$train_status"
}

# Evaluate every 200th checkpoint from step 200 through step 3000.
# For each checkpoint: run pp.py, generate responses from $ROOT_PATH/test with
# gen_vllm.py, score them with test_acc.py into acc.txt, then delete
# $ROOT_PATH/test (presumably the model directory produced by pp.py -- confirm).
for (( CKPT = 200; CKPT <= 3000; CKPT += 200 )); do
    echo "checkpoint-$CKPT"
    ROOT_PATH="$OUTPUT_PATH/checkpoint-$CKPT"

    # Chain the stages: a failed stage must not feed missing or stale files to
    # the next one, and acc.txt is only written when scoring actually runs.
    # A failure is reported and the loop moves on to the next checkpoint.
    {
        python pp.py --output_dir "$ROOT_PATH" --model_name_or_path "$RES_MODEL" &&
        python utils/gen_vllm.py --model "$ROOT_PATH/test" --sub_task metamath --output_file "$ROOT_PATH/metamath_response.json" &&
        python utils/test_acc.py --input_file "$ROOT_PATH/metamath_response.json" > "$ROOT_PATH/acc.txt"
    } || echo "Evaluation of checkpoint-$CKPT failed; continuing with the next checkpoint." >&2

    # Always clean up the temporary model directory; -f keeps this quiet when
    # an earlier stage failed before creating it.
    rm -rf -- "${ROOT_PATH:?}/test"
done

# ROOT_PATH="$OUTPUT_PATH/checkpoint-781"
# python pp.py --output_dir $ROOT_PATH --model_name_or_path $RES_MODEL
# python utils/gen_vllm.py --model $ROOT_PATH/test --sub_task metamath --output_file $ROOT_PATH/metamath_response.json
# python utils/test_acc.py --input_file $ROOT_PATH/metamath_response.json > $ROOT_PATH/acc.txt
# rm -r $ROOT_PATH/test