#!/usr/bin/env bash
# HSIR (DPO variant) training driver for MedQA: one SFT stage followed by
# three sample -> filter -> DPO-train iterations.
#
# Usage: bash <this script> <model> <save_path>
#   model      directory name of the base model under /workspace/huggingface
#   save_path  root directory that receives one sub-directory per stage

# Refuse to run without both arguments: an empty save_path would make the
# later `mkdir -p ${save_path}/<task>` create directories under `/`.
if [ -z "${1:-}" ] || [ -z "${2:-}" ]; then
    printf 'Usage: %s <model> <save_path>\n' "$0" >&2
    exit 2
fi

model=$1  # e.g. Qwen2.5-1.5B-Instruct
save_path=$2 # e.g. /workspace/output/medqa_dpo/Qwen2.5-1.5B-Instruct

# All training commands use paths relative to the LLaMA-Factory checkout
# (src/train.py, examples/deepspeed/...), so a failed cd must be fatal.
code_path=/workspace/LLaMA-Factory
cd "$code_path" || {
    printf 'cannot cd to %s\n' "$code_path" >&2
    exit 1
}

dataset_path=/workspace/HSIR/data/MedQA
lr=1e-5        # learning rate for the initial SFT stage
template=qwen  # chat template name passed to every training run

## Initial SFT
# Full-parameter supervised fine-tuning of the base model on the seed dataset.
# The result is written to ${save_path}/medqa_train_ds_seed, which the first
# HSIR iteration then samples from and continues training.
task=medqa_train_ds_seed
output_dir="${save_path}/${task}"
mkdir -p "$output_dir"

train_log="${output_dir}/training.log"

# model training (8 GPUs, DeepSpeed ZeRO-3 config shipped with LLaMA-Factory)
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run \
    --nproc_per_node 8 \
    --master_port 12138 \
    src/train.py \
    --deepspeed examples/deepspeed/ds_z3_config.json \
    --stage sft \
    --template "$template" \
    --model_name_or_path "/workspace/huggingface/$model" \
    --do_train \
    --dataset "$task" \
    --dataset_dir "${dataset_path}" \
    --finetuning_type full \
    --output_dir "$output_dir" \
    --overwrite_cache False \
    --cutoff_len 2048 \
    --warmup_ratio 0.1 \
    --overwrite_output_dir \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --val_size 0 \
    --per_device_eval_batch_size 1 \
    --save_strategy no \
    --learning_rate "$lr" \
    --num_train_epochs 3.0 \
    --plot_loss \
    --bf16 2>&1 | tee "${train_log}"
# The pipeline's own status is tee's, so read the trainer's status explicitly
# (bash-only PIPESTATUS). Everything below depends on this checkpoint.
train_status=${PIPESTATUS[0]}
if (( train_status != 0 )); then
    printf '%s training failed (exit status %s); see %s\n' \
        "$task" "$train_status" "$train_log" >&2
    exit "$train_status"
fi

# model evaluation
bash /workspace/HSIR/src/MedQA/step8_model_inference.sh "$output_dir" /workspace/HSIR/data/testset/medqa_testset.json


lr=1e-6
## HSIR Iteration 1
# At this point $output_dir still refers to the seed SFT model directory; it is
# used both as the model to sample from and as the parent of the solutions dir.
# NOTE(review): exit statuses of the step1-step6 helpers are not checked here
# because their exit conventions are not visible from this script.

solution_path="${output_dir}/solutions"

# obtain the self-generated solutions
# (trailing 10 is presumably the number of samples per question - TODO confirm)
bash /workspace/HSIR/src/MedQA/step1_rejection_sampling.sh "$output_dir" "$solution_path" /workspace/HSIR/data/MedQA/medqa_train_ds_unlabeled.json step1_rejection_sampling 10
# obtain the verify-then-exit sampling solutions
bash /workspace/HSIR/src/MedQA/step2_verify-then-exit_sampling.sh "$output_dir" "$solution_path" "$solution_path/step1_rejection_sampling.json" step2_verify_then_exit_sampling 5
# verify the correctness of solutions, and split them into right group and wrong group
python /workspace/HSIR/src/MedQA/step3_splict_right_wrong.py "$solution_path" step2_verify_then_exit_sampling "$solution_path/right_and_wrong_solutions.json"
# calculate the indiv score for correct solutions
bash /workspace/HSIR/src/MedQA/step4_cal_indiv_score.sh "$output_dir" "$solution_path/right_and_wrong_solutions.json" "$solution_path" right_solutions_indiv_scores
# filter the data with lower indiv scores
python /workspace/HSIR/src/MedQA/step5_filter_data.py --data_path "$solution_path/right_solutions_indiv_scores.json" --save_path "$solution_path/right_solutions_indiv_scores_filtered.json" --indiv_score_thre -0.5
# prepare the DPO dataset for this iteration
python /workspace/HSIR/src/MedQA/step6_dpo_dataset_prepare.py "$solution_path" right_solutions_indiv_scores_filtered "$dataset_path/HSIR-DPO_iter1.json"

# model training
task=HSIR-DPO_iter1
output_dir="${save_path}/${task}"
mkdir -p "$output_dir"

train_log="${output_dir}/training.log"

# DPO starting from the seed SFT checkpoint, on the seed dataset plus the
# dataset generated above.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run \
    --nproc_per_node 8 \
    --master_port 12138 \
    src/train.py \
    --deepspeed examples/deepspeed/ds_z3_config.json \
    --stage dpo \
    --pref_ftx 0.5 \
    --template "$template" \
    --model_name_or_path "${save_path}/medqa_train_ds_seed" \
    --do_train \
    --dataset "medqa_train_ds_seed,$task" \
    --dataset_dir "${dataset_path}" \
    --finetuning_type full \
    --output_dir "$output_dir" \
    --overwrite_cache False \
    --cutoff_len 2048 \
    --warmup_ratio 0.1 \
    --overwrite_output_dir \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --val_size 0 \
    --per_device_eval_batch_size 1 \
    --save_strategy no \
    --learning_rate "$lr" \
    --num_train_epochs 3.0 \
    --plot_loss \
    --bf16 2>&1 | tee "${train_log}"
# The pipeline's own status is tee's, so read the trainer's status explicitly
# (bash-only PIPESTATUS). Later stages depend on this checkpoint.
train_status=${PIPESTATUS[0]}
if (( train_status != 0 )); then
    printf '%s training failed (exit status %s); see %s\n' \
        "$task" "$train_status" "$train_log" >&2
    exit "$train_status"
fi

# model evaluation
bash /workspace/HSIR/src/MedQA/step8_model_inference.sh "$output_dir" /workspace/HSIR/data/testset/medqa_testset.json


lr=1e-6
## HSIR Iteration 2
# At this point $output_dir still refers to the HSIR-DPO_iter1 model directory;
# it is used both as the model to sample from and as the parent of the
# solutions dir.
# NOTE(review): exit statuses of the step1-step6 helpers are not checked here
# because their exit conventions are not visible from this script.

solution_path="${output_dir}/solutions"

# obtain the self-generated solutions
# (trailing 10 is presumably the number of samples per question - TODO confirm)
bash /workspace/HSIR/src/MedQA/step1_rejection_sampling.sh "$output_dir" "$solution_path" /workspace/HSIR/data/MedQA/medqa_train_ds_unlabeled.json step1_rejection_sampling 10
# obtain the verify-then-exit sampling solutions
bash /workspace/HSIR/src/MedQA/step2_verify-then-exit_sampling.sh "$output_dir" "$solution_path" "$solution_path/step1_rejection_sampling.json" step2_verify_then_exit_sampling 5
# verify the correctness of solutions, and split them into right group and wrong group
python /workspace/HSIR/src/MedQA/step3_splict_right_wrong.py "$solution_path" step2_verify_then_exit_sampling "$solution_path/right_and_wrong_solutions.json"
# calculate the indiv score for correct solutions
bash /workspace/HSIR/src/MedQA/step4_cal_indiv_score.sh "$output_dir" "$solution_path/right_and_wrong_solutions.json" "$solution_path" right_solutions_indiv_scores
# filter the data with lower indiv scores
python /workspace/HSIR/src/MedQA/step5_filter_data.py --data_path "$solution_path/right_solutions_indiv_scores.json" --save_path "$solution_path/right_solutions_indiv_scores_filtered.json" --indiv_score_thre -0.5
# prepare the DPO dataset for this iteration
python /workspace/HSIR/src/MedQA/step6_dpo_dataset_prepare.py "$solution_path" right_solutions_indiv_scores_filtered "$dataset_path/HSIR-DPO_iter2.json"

# model training
task=HSIR-DPO_iter2
output_dir="${save_path}/${task}"
mkdir -p "$output_dir"

train_log="${output_dir}/training.log"

# DPO starting from the iteration-1 checkpoint, on the seed dataset plus the
# dataset generated above.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run \
    --nproc_per_node 8 \
    --master_port 12138 \
    src/train.py \
    --deepspeed examples/deepspeed/ds_z3_config.json \
    --stage dpo \
    --pref_ftx 0.5 \
    --template "$template" \
    --model_name_or_path "${save_path}/HSIR-DPO_iter1" \
    --do_train \
    --dataset "medqa_train_ds_seed,$task" \
    --dataset_dir "${dataset_path}" \
    --finetuning_type full \
    --output_dir "$output_dir" \
    --overwrite_cache False \
    --cutoff_len 2048 \
    --warmup_ratio 0.1 \
    --overwrite_output_dir \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --val_size 0 \
    --per_device_eval_batch_size 1 \
    --save_strategy no \
    --learning_rate "$lr" \
    --num_train_epochs 3.0 \
    --plot_loss \
    --bf16 2>&1 | tee "${train_log}"
# The pipeline's own status is tee's, so read the trainer's status explicitly
# (bash-only PIPESTATUS). Later stages depend on this checkpoint.
train_status=${PIPESTATUS[0]}
if (( train_status != 0 )); then
    printf '%s training failed (exit status %s); see %s\n' \
        "$task" "$train_status" "$train_log" >&2
    exit "$train_status"
fi

# model evaluation
bash /workspace/HSIR/src/MedQA/step8_model_inference.sh "$output_dir" /workspace/HSIR/data/testset/medqa_testset.json


lr=1e-6
## HSIR Iteration 3
# At this point $output_dir still refers to the HSIR-DPO_iter2 model directory;
# it is used both as the model to sample from and as the parent of the
# solutions dir.
# NOTE(review): exit statuses of the step1-step6 helpers are not checked here
# because their exit conventions are not visible from this script.

solution_path="${output_dir}/solutions"

# obtain the self-generated solutions
# (trailing 10 is presumably the number of samples per question - TODO confirm)
bash /workspace/HSIR/src/MedQA/step1_rejection_sampling.sh "$output_dir" "$solution_path" /workspace/HSIR/data/MedQA/medqa_train_ds_unlabeled.json step1_rejection_sampling 10
# obtain the verify-then-exit sampling solutions
bash /workspace/HSIR/src/MedQA/step2_verify-then-exit_sampling.sh "$output_dir" "$solution_path" "$solution_path/step1_rejection_sampling.json" step2_verify_then_exit_sampling 5
# verify the correctness of solutions, and split them into right group and wrong group
python /workspace/HSIR/src/MedQA/step3_splict_right_wrong.py "$solution_path" step2_verify_then_exit_sampling "$solution_path/right_and_wrong_solutions.json"
# calculate the indiv score for correct solutions
bash /workspace/HSIR/src/MedQA/step4_cal_indiv_score.sh "$output_dir" "$solution_path/right_and_wrong_solutions.json" "$solution_path" right_solutions_indiv_scores
# filter the data with lower indiv scores
python /workspace/HSIR/src/MedQA/step5_filter_data.py --data_path "$solution_path/right_solutions_indiv_scores.json" --save_path "$solution_path/right_solutions_indiv_scores_filtered.json" --indiv_score_thre -0.5
# prepare the DPO dataset for this iteration
python /workspace/HSIR/src/MedQA/step6_dpo_dataset_prepare.py "$solution_path" right_solutions_indiv_scores_filtered "$dataset_path/HSIR-DPO_iter3.json"

# model training
task=HSIR-DPO_iter3
output_dir="${save_path}/${task}"
mkdir -p "$output_dir"

train_log="${output_dir}/training.log"

# DPO starting from the iteration-2 checkpoint, on the seed dataset plus the
# dataset generated above.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run \
    --nproc_per_node 8 \
    --master_port 12138 \
    src/train.py \
    --deepspeed examples/deepspeed/ds_z3_config.json \
    --stage dpo \
    --pref_ftx 0.5 \
    --template "$template" \
    --model_name_or_path "${save_path}/HSIR-DPO_iter2" \
    --do_train \
    --dataset "medqa_train_ds_seed,$task" \
    --dataset_dir "${dataset_path}" \
    --finetuning_type full \
    --output_dir "$output_dir" \
    --overwrite_cache False \
    --cutoff_len 2048 \
    --warmup_ratio 0.1 \
    --overwrite_output_dir \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --val_size 0 \
    --per_device_eval_batch_size 1 \
    --save_strategy no \
    --learning_rate "$lr" \
    --num_train_epochs 3.0 \
    --plot_loss \
    --bf16 2>&1 | tee "${train_log}"
# The pipeline's own status is tee's, so read the trainer's status explicitly
# (bash-only PIPESTATUS). The evaluation below depends on this checkpoint.
train_status=${PIPESTATUS[0]}
if (( train_status != 0 )); then
    printf '%s training failed (exit status %s); see %s\n' \
        "$task" "$train_status" "$train_log" >&2
    exit "$train_status"
fi

# model evaluation
bash /workspace/HSIR/src/MedQA/step8_model_inference.sh "$output_dir" /workspace/HSIR/data/testset/medqa_testset.json