#!/usr/bin/env bash
#
# Driver for an iterative training pipeline. For each iteration the script
# runs the generation scripts, the preference-annotation scripts and two
# training runs (see run_iteration below), feeding the models produced by
# one iteration into the next.
#
# Requires: bash (uses `local`, brace ranges and `source`), conda with an
# environment named `new_vllm`, `python` and `accelerate` on PATH.

# Pick up the user's shell configuration before initialising conda.
# shellcheck source=/dev/null
source ~/.bashrc

# Initialize Conda environment: the hook defines the `conda` shell function
# so that `conda activate` works inside this non-interactive script.
eval "$(conda shell.bash hook)"

# Base paths and settings
initial_model="Mistral-7B-Instruct-v0.3"  # model used by iteration 1
base_dataset_path="datasets"              # root for generated/annotated data
base_model_path="models"                  # root for trained model directories
# NOTE(review): `ratio` is not referenced anywhere in the visible script;
# kept in case something outside this file relies on it.
ratio=0
eta=0.01                                  # forwarded to onpo_train.py --eta
iteration_prefix="onpo"                   # name prefix for data/model dirs


# Wait for every background job whose PID is passed as an argument.
# Returns 1 if at least one of them exited non-zero, 0 otherwise.
_wait_all() {
    local pid status=0
    for pid in "$@"; do
        wait "$pid" || status=1
    done
    return "$status"
}

#######################################
# Run the full set of operations for one model iteration:
#   1. run the generation script once per GPU, in parallel;
#   2. merge the generation shards;
#   3. run the preference-annotation script once per GPU, in parallel;
#   4. merge the annotation shards;
#   5. train pi'_{t+1} (reference pi'_t, initialised from pi_t);
#   6. train pi_{t+1}  (reference and initialisation both pi'_{t+1}).
# Globals:
#   base_model_path  (read) root directory for trained models
#   iteration_prefix (read) name prefix for model directories
#   eta              (read) forwarded to onpo_train.py --eta
# Arguments:
#   $1 - iteration number, used in the output model directory names
#   $2 - previous_model_1 (pi'_t): reference model for the first training run
#   $3 - previous_model_2 (pi_t): model used for generation and as the
#        starting point of the first training run
#   $4 - dataset name or path handed to the generation script
#   $5 - path prefix for generation outputs
#   $6 - path prefix for preference-annotation outputs
# Returns:
#   0 on success; 1 as soon as any stage fails, so later stages never run
#   on missing or partial outputs.
#######################################
run_iteration() {
    local iteration=$1
    local previous_model_1=$2
    local previous_model_2=$3
    local input_path=$4
    local json_output=$5
    local pref_output=$6

    # Settings are local so they do not leak into the caller's scope.
    local my_world_size=8      # number of shards; one background worker per GPU
    local sanity_check=False
    local use_tour=True
    local K=8
    local gpu
    local -a pids=()
    local output_model_path_1 output_model_path_2

    conda activate new_vllm || return 1

    # Stage 1: generation, one worker per GPU, each with its own --local_index.
    # NOTE(review): --eos_ids 128009 is hard-coded and looks like a Llama-3
    # token id, while the script's initial model is a Mistral checkpoint —
    # confirm it is correct for the model being sampled.
    pids=()
    for ((gpu = 0; gpu < my_world_size; gpu++)); do
        CUDA_VISIBLE_DEVICES=$gpu python ./generation/get_hf2.py \
            --model_name_or_path "$previous_model_2" \
            --dataset_name_or_path "$input_path" \
            --output_dir "$json_output" \
            --sanity_check "$sanity_check" --K "$K" --temperature 1.0 \
            --local_index "$gpu" --my_world_size "$my_world_size" \
            --eos_ids 128009 &
        pids+=("$!")
    done
    if ! _wait_all "${pids[@]}"; then
        echo "run_iteration: generation failed (iteration ${iteration})" >&2
        return 1
    fi

    # Stage 2: merge the generation shards into "${json_output}.json".
    # NOTE(review): the annotation stage below reads the per-shard files
    # "${json_output}_<i>.json" (presumably written by get_hf2.py), not the
    # merged file — confirm this is intended.
    if ! python ./generation/merge_data.py --base_path "$json_output" \
        --output_dir "${json_output}.json" --num_datasets "$my_world_size"; then
        echo "run_iteration: merging generations failed (iteration ${iteration})" >&2
        return 1
    fi

    # Stage 3: preference annotation, one worker per GPU / shard.
    pids=()
    for ((gpu = 0; gpu < my_world_size; gpu++)); do
        CUDA_VISIBLE_DEVICES=$gpu accelerate launch annotate_data/get_pref_single.py \
            --use_tournament "$use_tour" \
            --dataset_name_or_path "${json_output}_${gpu}.json" \
            --output_dir "${pref_output}_${gpu}.json" --K "$K" &
        pids+=("$!")
    done
    if ! _wait_all "${pids[@]}"; then
        echo "run_iteration: preference annotation failed (iteration ${iteration})" >&2
        return 1
    fi

    # Stage 4: merge the annotated shards into the training file.
    if ! python ./annotate_data/merge.py --base_path "$pref_output" \
        --output_dir "${pref_output}_data.json" --num_datasets "$my_world_size"; then
        echo "run_iteration: merging preferences failed (iteration ${iteration})" >&2
        return 1
    fi

    # Stage 5: train pi'_{t+1}; the reference model is pi'_t (previous_model_1)
    # and training starts from pi_t (previous_model_2).
    output_model_path_1="${base_model_path}/${iteration_prefix}_iter${iteration}_1"
    # -p: create missing parents and tolerate a directory left by an earlier run.
    mkdir -p "$output_model_path_1" || return 1

    if ! accelerate launch --config_file ./configs/zero3.yaml ./onpo/onpo_train.py \
        --output_dir "$output_model_path_1" \
        --ref_model "$previous_model_1" --model_name_or_path "$previous_model_2" \
        --learning_rate 5e-7 --eta "$eta" \
        --train_dir "${pref_output}_data.json" --loss_type onpo \
        --lr_scheduler_type cosine --num_train_epochs 1; then
        echo "run_iteration: first training run failed (iteration ${iteration})" >&2
        return 1
    fi

    # -------------------------------------------------------------------------
    # Stage 6: train pi_{t+1}; both the reference model and the starting point
    # are pi'_{t+1} (output_model_path_1).
    output_model_path_2="${base_model_path}/${iteration_prefix}_iter${iteration}_2"
    mkdir -p "$output_model_path_2" || return 1

    if ! accelerate launch --config_file ./configs/zero3.yaml ./onpo/onpo_train.py \
        --output_dir "$output_model_path_2" \
        --ref_model "$output_model_path_1" --model_name_or_path "$output_model_path_1" \
        --learning_rate 5e-7 --eta "$eta" \
        --train_dir "${pref_output}_data.json" --loss_type onpo \
        --lr_scheduler_type cosine --num_train_epochs 1; then
        echo "run_iteration: second training run failed (iteration ${iteration})" >&2
        return 1
    fi
}


# Main loop for iterations.
# Every iteration after the first starts from the two models written by the
# previous one, so the run stops as soon as an iteration fails instead of
# continuing with model directories that were never produced.
for i in {1..3}; do
    iteration_name="iter${i}"
    input_path="RLHFlow/iterative-prompt-v1-iter${i}-20K"
    iteration_dir="${base_dataset_path}/${iteration_prefix}_${iteration_name}"
    # -p: create missing parents and tolerate a directory left by an earlier run.
    mkdir -p "$iteration_dir" || exit 1
    json_output="${iteration_dir}/data"
    pref_output="${iteration_dir}/pref"

    # json_output="${base_dataset_path}/${iteration_prefix}_${iteration_name}"
    # reward_output="${base_dataset_path}/${iteration_prefix}_${iteration_name}/${iteration_prefix}_${iteration_name}_reward.json"

    # Determine the model path: first iteration uses the initial model,
    # subsequent iterations use the previous iteration's models.
    if ((i == 1)); then
        previous_model_1=$initial_model
        previous_model_2=$initial_model
    else
        previous_iteration=$((i - 1))
        previous_model_1="${base_model_path}/${iteration_prefix}_iter${previous_iteration}_1"
        previous_model_2="${base_model_path}/${iteration_prefix}_iter${previous_iteration}_2"
    fi

    if ! run_iteration "$i" "$previous_model_1" "$previous_model_2" \
        "$input_path" "$json_output" "$pref_output"; then
        echo "iteration ${i} failed; aborting" >&2
        exit 1
    fi
done