

# ---------------------------------------------------------------------------
# Run configuration.
#
# data_root   : directory prefix for the training data and the output tree.
#               It may be supplied through the environment; when it is empty
#               every derived path below starts at "/".
# model_root  : directory holding the policy and reward model checkpoints.
# MODEL       : policy checkpoint name (under model_root).
# REWARD_MODEL: reward model checkpoint name (under model_root).
# ---------------------------------------------------------------------------
data_root=${data_root:-}
if [[ -z "${data_root}" ]]; then
    echo "warning: data_root is empty; data and output paths will resolve under '/'" >&2
fi
model_root=${data_root}/output
BASE_MODEL=Llama-2-7b-hf
MODEL=${BASE_MODEL}_flan_lora_qv_merged
REWARD_MODEL=${BASE_MODEL}_hh_harmless_rm

# Batch geometry: TOTAL_BATCH_SIZE is split across NUM_GPUS processes, each
# handling BATCH_SIZE_PER_GPU samples per step; the remainder is covered by
# gradient accumulation.
NUM_GPUS=6
BATCH_SIZE_PER_GPU=2
TOTAL_BATCH_SIZE=240
GRADIENT_ACC_STEPS=$(( TOTAL_BATCH_SIZE / NUM_GPUS / BATCH_SIZE_PER_GPU ))

# Integer division truncates silently, so flag configurations where the
# effective batch size would differ from the requested one.
if (( GRADIENT_ACC_STEPS * NUM_GPUS * BATCH_SIZE_PER_GPU != TOTAL_BATCH_SIZE )); then
    echo "warning: TOTAL_BATCH_SIZE=${TOTAL_BATCH_SIZE} is not divisible by NUM_GPUS*BATCH_SIZE_PER_GPU; effective batch size is $(( GRADIENT_ACC_STEPS * NUM_GPUS * BATCH_SIZE_PER_GPU ))" >&2
fi

echo "Training llama model ${MODEL} using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"

# datasets=(hh_rlhf_harmless/hh_rlhf_harmless_data_train BeaverTails/beavertails_data_train HarmfulQA/harmfulqa_data flan_v2/flan_v2_data)
# save_name=(hh_harmless beavertails harmfulqa flan)
# Parallel arrays consumed by the launch loop below: entry i of `datasets`
# is the training file path (relative, without the .jsonl suffix) and entry i
# of `save_name` is the suffix of the matching output directory.
# Keep both arrays the same length.
seed=1
datasets=(
    hh_rlhf_harmless/train
)
save_name=(
    hh_harmless_ppo
)

# Prompt tuning (disabled; kept for reference)
# for (( i=0; i<${#datasets[*]}; ++i))
# do
#     accelerate launch \
#         --use_deepspeed \
#         --deepspeed_config_file configs/ds_configs/stage2_no_offloading_accelerate.conf \
#         --mixed_precision bf16 \
#         --num_machines 1 \
#         --num_processes $NUM_GPUS \
#         training/finetune.py \
#         --train_file ${data_root}/data/processed/${datasets[$i]}.jsonl \
#         --model_name_or_path ${model_root}/${MODEL} \
#         --use_flash_attn \
#         --use_prompt_tuning \
#         --tokenizer_name ${model_root}/${MODEL} \
#         --use_slow_tokenizer \
#         --max_seq_length 2048 \
#         --preprocessing_num_workers 100 \
#         --checkpointing_steps epoch \
#         --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
#         --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
#         --learning_rate 1e-4 \
#         --lr_scheduler_type linear \
#         --warmup_ratio 0.03 \
#         --weight_decay 0. \
#         --num_train_epochs 3 \
#         --output_dir ${data_root}/output/${MODEL}_${save_name[$i]}_${seed}/ \
#         --report_to wandb \
#         --with_tracking \
#         --logging_steps 10 \
#         --num_virtual_tokens 64 \
#         --prompt_tuning_init_text "You are a helpful and harmless AI assistant." \
#         --seed ${seed} 
# done

# IA3 training (disabled; kept for reference)
# for (( i=3; i<${#datasets[*]}; ++i))
# do
#     accelerate launch \
#         --use_deepspeed \
#         --deepspeed_config_file configs/ds_configs/stage2_no_offloading_accelerate.conf \
#         --mixed_precision bf16 \
#         --num_machines 1 \
#         --num_processes $NUM_GPUS \
#         training/finetune.py \
#         --train_file ${data_root}/data/processed/${datasets[i]}.jsonl \
#         --model_name_or_path ${model_root}/${MODEL} \
#         --use_flash_attn \
#         --use_ia3 \
#         --tokenizer_name ${model_root}/${MODEL} \
#         --use_slow_tokenizer \
#         --max_seq_length 2048 \
#         --preprocessing_num_workers 100 \
#         --checkpointing_steps epoch \
#         --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
#         --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
#         --learning_rate 1e-4 \
#         --lr_scheduler_type linear \
#         --warmup_ratio 0.03 \
#         --weight_decay 0. \
#         --num_train_epochs 3 \
#         --output_dir ${data_root}/output/${MODEL}_${save_name[i]}_${seed}/ \
#         --report_to wandb \
#         --with_tracking \
#         --logging_steps 10 \
#         --ia3_module down_proj \
#         --feedforward_modules down_proj \
#         --seed ${seed} 
# done

# PPO training (runs training/ppo.py; the LoRA/IA3 adapter flags are
# currently disabled, see the commented entries at the end of train_args).

#######################################
# Launch one PPO training run through `accelerate launch`.
# Globals (read): data_root, model_root, BASE_MODEL, MODEL, REWARD_MODEL,
#                 NUM_GPUS, BATCH_SIZE_PER_GPU, GRADIENT_ACC_STEPS, seed
# Arguments:
#   $1 - dataset path relative to ${data_root}/data/raw_train, without the
#        .jsonl suffix
#   $2 - suffix used for the output directory name
# Returns: the exit status of `accelerate launch`
#######################################
run_ppo_training() {
    local dataset=$1
    local run_name=$2

    # Options consumed by `accelerate launch` itself.
    local -a launch_args=(
        --use_deepspeed
        --deepspeed_config_file configs/ds_configs/stage3_no_offloading_accelerate.conf
        --mixed_precision bf16
        --num_machines 1
        --num_processes "${NUM_GPUS}"
    )

    # Options forwarded to training/ppo.py.
    # NOTE(review): --torch_dtype float16 is passed while accelerate is
    # launched with --mixed_precision bf16; confirm this mix is intended.
    local -a train_args=(
        --train_file "${data_root}/data/raw_train/${dataset}.jsonl"
        --model_name_or_path "${model_root}/${MODEL}"
        --reward_model_name_or_path "${model_root}/${REWARD_MODEL}"
        --use_flash_attn
        --early_stopping
        --optimize_device_cache
        --tokenizer_name "${model_root}/${MODEL}"
        --max_seq_length 4096
        --max_prompt_length 2048
        --torch_dtype float16
        --preprocessing_num_workers 64
        --per_device_train_batch_size "${BATCH_SIZE_PER_GPU}"
        --gradient_accumulation_steps "${GRADIENT_ACC_STEPS}"
        --learning_rate 1e-4
        --lr_scheduler_type linear
        --warmup_ratio 0.03
        --weight_decay 0.
        --num_train_epochs 3
        --output_dir "${data_root}/output/${BASE_MODEL}_${run_name}"
        --report_to none
        --logging_steps 10
        # Use the script-level seed; fall back to 1 (the previously
        # hard-coded value) when it is not set.
        --seed "${seed:-1}"
        # Optional adapter flags, disabled for this run:
        # --use_ia3
        # --ia3_module down_proj
        # --feedforward_modules down_proj
        # --use_lora
        # --lora_rank 64
        # --lora_alpha 128
        # --lora_dropout 0.1
        # --lora_module v_proj q_proj
        # --gradient_checkpointing
    )

    accelerate launch "${launch_args[@]}" training/ppo.py "${train_args[@]}"
}

# One run per dataset; save_name[i] names the output of datasets[i].
for (( i = 0; i < ${#datasets[@]}; ++i ))
do
    run_ppo_training "${datasets[i]}" "${save_name[i]}"
done

# for seed in 1
# do
#     accelerate launch \
#         --use_deepspeed \
#         --deepspeed_config_file configs/ds_configs/stage2_no_offloading_accelerate.conf \
#         --mixed_precision bf16 \
#         --num_machines 1 \
#         --num_processes $NUM_GPUS \
#         training/finetune.py \
#         --train_file ${data_root}/processed/HarmfulQA/harmfulqa_data.jsonl \
#         --model_name_or_path ${model_root}/${MODEL} \
#         --use_flash_attn \
#         --use_lora \
#         --lora_rank 64 \
#         --lora_alpha 16 \
#         --lora_dropout 0.1 \
#         --tokenizer_name ${model_root}/${MODEL} \
#         --use_slow_tokenizer \
#         --max_seq_length 2048 \
#         --preprocessing_num_workers 100 \
#         --checkpointing_steps epoch \
#         --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
#         --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
#         --learning_rate 1e-4 \
#         --lr_scheduler_type linear \
#         --warmup_ratio 0.03 \
#         --weight_decay 0. \
#         --num_train_epochs 5 \
#         --output_dir output/${MODEL}_harmfulqa_kq_${seed}/ \
#         --report_to wandb \
#         --with_tracking \
#         --logging_steps 10 \
#         --lora_module k_proj q_proj \
#         --seed ${seed} \
#         --gradient_checkpointing
# done
#     # --gradient_checkpointing
#     # --with_tracking \
#     # --use_qlora \
# accelerate launch \
#     --use_deepspeed \
#     --deepspeed_config_file configs/ds_configs/stage3_offloading_accelerate.conf \
#     --mixed_precision bf16 \
#     --num_machines 1 \
#     --num_processes $NUM_GPUS \
#     training/finetune.py \
#     --train_file ${data_root}/processed/sharegpt/sharegpt_data.jsonl \
#     --model_name_or_path ${model_root}/${MODEL} \
#     --use_flash_attn \
#     --use_lora \
#     --lora_rank 64 \
#     --lora_alpha 16 \
#     --lora_dropout 0.1 \
#     --tokenizer_name ${model_root}/${MODEL} \
#     --use_slow_tokenizer \
#     --max_seq_length 1024 \
#     --preprocessing_num_workers 64 \
#     --checkpointing_steps 300 \
#     --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
#     --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
#     --learning_rate 1e-4 \
#     --lr_scheduler_type linear \
#     --warmup_ratio 0.03 \
#     --weight_decay 0. \
#     --num_train_epochs 1 \
#     --output_dir output/${MODEL}_lora_sharegpt_kq/ \
#     --report_to wandb \
#     --logging_steps 10 \
#     --lora_module k_proj q_proj\
    # --gradient_checkpointing \
    # --with_tracking \


# accelerate launch \
#     --mixed_precision bf16 \
#     --num_machines 1 \
#     --num_processes $NUM_GPUS \
#     training/finetune.py \
#     --train_file ${data_root}/processed/hh_rlhf_harmless/hh_rlhf_harmless_data_train.jsonl \
#     --model_name_or_path ${model_root}/${MODEL} \
#     --use_flash_attn \
#     --use_lora \
#     --lora_rank 64 \
#     --lora_alpha 16 \
#     --lora_dropout 0.1 \
#     --tokenizer_name ${model_root}/${MODEL} \
#     --use_slow_tokenizer \
#     --max_seq_length 2048 \
#     --preprocessing_num_workers 100 \
#     --checkpointing_steps 300 \
#     --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
#     --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
#     --learning_rate 1e-4 \
#     --lr_scheduler_type linear \
#     --warmup_ratio 0.03 \
#     --weight_decay 0. \
#     --num_train_epochs 3 \
#     --output_dir output/${MODEL}_lora_harmless_mlp_all/ \
#     --report_to wandb \
#     --logging_steps 10 \
#     --lora_module up_proj gate_proj down_proj\
#     --with_tracking \

# python open_instruct/merge_lora.py \
#     --base_model_name_or_path ../hf_llama2_models/${MODEL_SIZE} \
#     --lora_model_name_or_path output/tulu_v2_${MODEL_SIZE}_lora/ \
#     --output_dir output/tulu_v2_${MODEL_SIZE}_lora_merged/ \
#     --save_tokenizer
