#!/usr/bin/env bash
# Iterative DPO (XPO) driver: each round samples responses from the current
# policy, scores them with annotate_data/get_rewards.py, and trains the next
# policy with dpo_iteration/run_dpo.py.
set -euo pipefail

# ~/.bashrc and conda's shell hook are written for interactive shells and may
# read unset variables (e.g. PS1, activate.d scripts), which would abort this
# script under `set -u`. Relax nounset only while loading them.
set +u
if [[ -f ~/.bashrc ]]; then
  # shellcheck source=/dev/null
  source ~/.bashrc
fi

# Initialize Conda environment. The hook is captured in a separate, checked
# assignment because `set -e` does not see a failing command substitution
# inside `eval "$(...)"` (eval of an empty string succeeds); a missing conda
# would otherwise only surface later as "conda: command not found".
if ! conda_hook="$(conda shell.bash hook)"; then
  echo "error: 'conda shell.bash hook' failed; is conda on PATH?" >&2
  exit 1
fi
eval "${conda_hook}"
set -u

# Base paths and settings
readonly initial_model="RLHFlow/LLaMA3-SFT-v2"
readonly base_path="./iter_dpo"
mkdir -p "$base_path"
readonly iteration_prefix="Test"

alpha_for_iter() {
  # Print the XPO alpha used for a given 1-based iteration number:
  #   1 -> 1.0e-5, 2 -> 5.0e-6, any other value -> 0.0
  local -r iter_num="$1"
  local alpha="0.0"

  if [ "$iter_num" -eq 1 ]; then
    alpha="1.0e-5"
  elif [ "$iter_num" -eq 2 ]; then
    alpha="5.0e-6"
  fi

  printf '%s\n' "$alpha"
}

# Activate a conda env with nounset temporarily disabled: conda's activate
# logic (and per-env activate.d scripts) may read unset variables, which would
# abort the script under `set -u`. Nounset is restored only if it was on.
# Returns conda's exit status.
_conda_activate() {
  local env_name="$1"
  local restore_nounset=0 rc=0
  if [[ $- == *u* ]]; then
    restore_nounset=1
    set +u
  fi
  conda activate "$env_name" || rc=$?
  if ((restore_nounset)); then
    set -u
  fi
  return "$rc"
}

# Run one round of the pipeline for a model iteration:
#   1. generate with ./generation/gen_hf2.py (--K 2), one shard per GPU,
#      in the `vllm` conda env;
#   2. merge the shards and annotate them with annotate_data/get_rewards.py;
#   3. write dpo_config.yaml in the current directory and run
#      dpo_iteration/run_dpo.py with it in the `rlhflow` conda env.
# Arguments:
#   $1 iteration    - run_name and output_dir written into the DPO config
#   $2 model_path   - model used for generation and as DPO model/ref_model
#   $3 jsonl_input  - dataset passed to gen_hf2.py --dataset_name_or_path
#   $4 json_output  - generation output dir; merged file is
#                     "${json_output}_data.jsonl"
#   $5 model_output - reward-annotated output; used as DPO train/eval dir
#   $6 xpo_alpha    - value written as xpo_alpha in the DPO config
# Returns: non-zero (after all shards have finished) if any generation shard
#   fails; later steps rely on `set -e` in the caller.
run_iteration() {
  local iteration="$1"
  local model_path="$2"
  local jsonl_input="$3"
  local json_output="$4"
  local model_output="$5"
  local xpo_alpha="$6"

  # One generation shard per GPU; merge_data.py must be told the same count.
  local -r my_world_size=8
  local -r infer_model="${model_path}"
  local -r prompt_dir="${jsonl_input}"
  local -r output_dir="${json_output}"

  echo "[Iter ${iteration}] XPO alpha = ${xpo_alpha}"

  _conda_activate vllm

  # Launch every shard in the background, then wait on each PID. A bare
  # `wait` always returns 0, so a crashed shard would go unnoticed and
  # merge_data.py would silently run on partial data.
  local -a pids=()
  local shard
  for ((shard = 0; shard < my_world_size; shard++)); do
    CUDA_VISIBLE_DEVICES="${shard}" python ./generation/gen_hf2.py \
      --model_name_or_path "${infer_model}" \
      --dataset_name_or_path "${prompt_dir}" \
      --output_dir "${output_dir}" \
      --K 2 --temperature 1.0 \
      --local_index "${shard}" --my_world_size "${my_world_size}" \
      --eos_ids 128009 &
    pids+=("$!")
  done

  local pid failed=0
  for pid in "${pids[@]}"; do
    wait "${pid}" || failed=$((failed + 1))
  done
  if ((failed > 0)); then
    echo "[Iter ${iteration}] ${failed} of ${my_world_size} generation shard(s) failed" >&2
    return 1
  fi

  python ./generation/merge_data.py --base_path "${output_dir}" --output_dir "${output_dir}_data.jsonl" --num_datasets "${my_world_size}"

  accelerate launch annotate_data/get_rewards.py \
    --dataset_name_or_path "${output_dir}_data.jsonl" \
    --output_dir "${model_output}" --K 2

  _conda_activate rlhflow

  cat <<EOT > dpo_config.yaml
run_name: ${iteration}
output_dir: ${iteration}
model_name_or_path: ${model_path}
ref_model: ${model_path}
learning_rate: 5.0e-7
num_train_epochs: 2
logging_steps: 2
gradient_checkpointing: true
do_train: true
do_eval: true
eval_steps: 10000
choose_type: max_min
train_dir: ${model_output}
eval_dir: ${model_output}
loss_type: sigmoid
lr_scheduler_type: cosine
max_length: 2048
max_prompt_length: 1000
eval_strategy: steps
bf16: true
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
gradient_accumulation_steps: 16
report_to: wandb
label_smoothing: 0.1
xpo_alpha: ${xpo_alpha}
EOT

  accelerate launch --config_file ./configs/zero2.yaml dpo_iteration/run_dpo.py dpo_config.yaml
}

# Main loop: iteration 1 starts from the SFT model; every later iteration
# starts from the directory named after the previous iteration, which is the
# output_dir that run_iteration writes into the DPO config.
model_path="${initial_model}"
for i in 1 2 3; do
  iteration_name="XPO_iter${i}"
  stem="${base_path}/${iteration_prefix}${i}_${iteration_name}"

  jsonl_input="RLHFlow/ultrafeedback_iter${i}"
  json_output="${stem}"
  model_output="${stem}_reward.json"
  xpo_alpha_val="$(alpha_for_iter "${i}")"

  run_iteration "${iteration_name}" "${model_path}" "${jsonl_input}" "${json_output}" "${model_output}" "${xpo_alpha_val}"

  # The next round samples from (and uses as reference) this round's output.
  model_path="${iteration_name}"
done
