#!/usr/bin/env bash
# This script needs bash (associative arrays, [[ ]]), so pin the interpreter
# instead of relying on whichever shell happens to launch it.
# Strict mode: exit on unhandled errors, on unset variables, and when any stage
# of a pipeline fails.
set -euo pipefail

# Usage: ./run_online_iterative_paper.sh [OPTIONS]
# Options:
#   --dataset_name, -d          Dataset name: uf, gsm8k, tldr or hh (default: tldr)
#   --model_name, -m            Model name (default: g_2_2b)
#   --ref_model_name, -r        Reference model path (default: outputs/sft/sft-tldr-g_2_2b)
#   --skip_model_fitting, -s    true|false: skip data generation and DPO training;
#                               evaluation still runs (default: false)
#   --tag, -t                   Experiment tag (default: v2-bt_plus)
#   --num_rollout_prompts, -n   Number of rollout prompts (default: 512)
#   --num_gpus, -g              Number of GPUs (default: 2)
#   --no_bt_plus, -b            Disable BT plus (default: enabled)
#   --use_fsdp, -f              true|false: enable FSDP config based on model type (default: false)
#                              - gemma models (including g_*): uses recipes/fsdp_acc_gemma.yaml
#                              - qwen models: uses recipes/fsdp_acc_qwen.yaml
#   --use_fsdp_for_rewards      true|false: enable FSDP config for golden reward computation (default: false)
#   --skip_golden_reward        Skip golden reward evaluation (default: false)
#   --proxy_reward_batch_size   Batch size for proxy reward evaluation (default: 16)
#   --golden_reward_batch_size  Batch size for golden reward computation (default: 8)
#   --use_lora                  Enable LoRA fine-tuning (default: false)
#   --lora_r                    LoRA rank value (default: 16, only used if --use_lora is given)
#   --beta                      Beta parameter for importance sampling weights (default: 3.0)
#   --help, -h                  Show this help message

# Per-dataset lookup tables, keyed by the --dataset_name value.

# Maximum number of new tokens per generated completion (--max_new_tokens).
declare -A max_tokens_dict=(
    ["uf"]=450
    ["gsm8k"]=450
    ["tldr"]=64
    ["hh"]=128
)

# Dataset that supplies the prompts (passed as --dataset_name to the generator).
declare -A prompt_dataset_name_dict=(
    ["uf"]="trl-lib/ultrafeedback-prompt"
    ["gsm8k"]="datasets/gsm8k-sft"
    ["tldr"]="trl-lib/tldr-preference"
    ["hh"]="datasets/hh-rlhf-helpful-base"
)

# Split used when generating the reference model's evaluation samples (--split).
declare -A dataset_test_set_name_dic=(
  ["uf"]="test"
  ["gsm8k"]="test"
  ["tldr"]="validation"
  ["hh"]="test"
)

# The tables are constants; nothing below should modify them.
readonly max_tokens_dict prompt_dataset_name_dict dataset_test_set_name_dic

# Default values (each can be overridden on the command line; see show_help).
CURRENT_DATASET_NAME="tldr"
CURRENT_MODEL_NAME="g_2_2b"
REF_MODEL_NAME="outputs/sft/sft-tldr-g_2_2b"
SKIP_MODEL_FITTING=false
TAG="v2-bt_plus"
NUM_OF_ROLLOUT_PROMPTS=512
NUM_GPUS=2
USE_BT_PLUS=true
USE_FSDP=false
USE_FSDP_FOR_REWARDS=false
SKIP_GOLDEN_REWARD=false
PROXY_REWARD_BATCH_SIZE=16
GOLDEN_REWARD_BATCH_SIZE=8
USE_LORA=false
LORA_R=16
BETA=3.0

# Print the usage text (one argument per output line) and exit with status 0.
show_help() {
    local script=$0
    printf '%s\n' \
        "Usage: ${script} [OPTIONS]" \
        "" \
        "Options:" \
        "  --dataset_name, -d          Dataset name (default: tldr)" \
        "  --model_name, -m            Model name (default: g_2_2b)" \
        "  --ref_model_name, -r        Reference model path (default: outputs/sft/sft-tldr-g_2_2b)" \
        "  --skip_model_fitting, -s    Skip model fitting (default: false)" \
        "  --tag, -t                   Experiment tag (default: v2-bt_plus)" \
        "  --num_rollout_prompts, -n   Number of rollout prompts (default: 512)" \
        "  --num_gpus, -g              Number of GPUs (default: 2)" \
        "  --no_bt_plus, -b            Disable BT plus (default: enabled)" \
        "  --use_fsdp, -f              Enable FSDP config based on model type (default: false)" \
        "                              - gemma models (including g_*): uses recipes/fsdp_acc_gemma.yaml" \
        "                              - qwen models: uses recipes/fsdp_acc_qwen.yaml" \
        "  --use_fsdp_for_rewards      Enable FSDP config for golden reward computation (default: false)" \
        "  --skip_golden_reward        Skip golden reward evaluation (default: false)" \
        "  --proxy_reward_batch_size   Batch size for proxy reward evaluation (default: 16)" \
        "  --golden_reward_batch_size  Batch size for golden reward computation (default: 8)" \
        "  --use_lora                  Enable LoRA fine-tuning (default: false)" \
        "  --lora_r                    LoRA rank value (default: 16, only used if --use_lora is true)" \
        "  --beta                      Beta parameter for importance sampling weights (default: 3.0)" \
        "  --help, -h                  Show this help message" \
        "" \
        "Examples:" \
        "  ${script} --dataset_name tldr --model_name g_2_2b --num_gpus 4" \
        "  ${script} -d uf -m gemma_7b --use_fsdp true --proxy_reward_batch_size 64" \
        "  ${script} --dataset_name gsm8k --no_bt_plus --tag custom_experiment" \
        "  ${script} --dataset_name tldr --use_fsdp_for_rewards true --golden_reward_batch_size 4" \
        "  ${script} --dataset_name tldr --skip_golden_reward --num_gpus 4" \
        "  ${script} --dataset_name tldr --use_lora --lora_r 32" \
        "" \
        "Note: This script runs ISO experiments for spn values 8,16,32,64 and baseline only for spn=16"
    exit 0
}

# Abort with a usage error unless the option in $1 is followed by a value.
# Arguments: the remaining command-line arguments, starting at the option.
_require_value() {
    if (( $# < 2 )); then
        echo "Option $1 requires a value." >&2
        echo "Use --help for usage information." >&2
        exit 1
    fi
}

# Abort unless $2, the value given to boolean option $1, is exactly "true" or
# "false". The rest of the script compares these settings against those exact
# words, so any other spelling (e.g. "True") would be silently misread.
_require_bool() {
    case "$2" in
        true|false) ;;
        *)
            echo "Option $1 expects 'true' or 'false', got '$2'." >&2
            exit 1
            ;;
    esac
}

#######################################
# Parse command-line options into the global settings.
# Arguments: the script's command-line arguments.
# Exits 0 after printing help for --help/-h. Exits 1 with a message on stderr
# in three cases: an unknown option, an option missing its value, or a
# true/false option given any other value.
#######################################
parse_args() {
    while (( $# > 0 )); do
        case "$1" in
            --dataset_name|-d)
                _require_value "$@"
                CURRENT_DATASET_NAME=$2
                shift 2
                ;;
            --model_name|-m)
                _require_value "$@"
                CURRENT_MODEL_NAME=$2
                shift 2
                ;;
            --ref_model_name|-r)
                _require_value "$@"
                REF_MODEL_NAME=$2
                shift 2
                ;;
            --skip_model_fitting|-s)
                _require_value "$@"
                _require_bool "$1" "$2"
                SKIP_MODEL_FITTING=$2
                shift 2
                ;;
            --tag|-t)
                _require_value "$@"
                TAG=$2
                shift 2
                ;;
            --num_rollout_prompts|-n)
                _require_value "$@"
                NUM_OF_ROLLOUT_PROMPTS=$2
                shift 2
                ;;
            --num_gpus|-g)
                _require_value "$@"
                NUM_GPUS=$2
                shift 2
                ;;
            --no_bt_plus|-b)
                USE_BT_PLUS=false
                shift
                ;;
            --use_fsdp|-f)
                _require_value "$@"
                _require_bool "$1" "$2"
                USE_FSDP=$2
                shift 2
                ;;
            --use_fsdp_for_rewards)
                _require_value "$@"
                _require_bool "$1" "$2"
                USE_FSDP_FOR_REWARDS=$2
                shift 2
                ;;
            --skip_golden_reward)
                SKIP_GOLDEN_REWARD=true
                shift
                ;;
            --proxy_reward_batch_size)
                _require_value "$@"
                PROXY_REWARD_BATCH_SIZE=$2
                shift 2
                ;;
            --golden_reward_batch_size)
                _require_value "$@"
                GOLDEN_REWARD_BATCH_SIZE=$2
                shift 2
                ;;
            --use_lora)
                USE_LORA=true
                shift
                ;;
            --lora_r)
                _require_value "$@"
                LORA_R=$2
                shift 2
                ;;
            --beta)
                _require_value "$@"
                BETA=$2
                shift 2
                ;;
            --help|-h)
                show_help
                ;;
            *)
                echo "Unknown option: $1" >&2
                echo "Use --help for usage information." >&2
                exit 1
                ;;
        esac
    done
}

# Parse command line arguments
parse_args "$@"

# Fail fast on invalid option values, before anything is deleted or launched.
# An unknown dataset would otherwise only surface later as an unbound
# associative-array key in the middle of the pipeline.
if [[ -z "$CURRENT_DATASET_NAME" ]] \
  || [[ -z "${max_tokens_dict[$CURRENT_DATASET_NAME]+x}" ]] \
  || [[ -z "${prompt_dataset_name_dict[$CURRENT_DATASET_NAME]+x}" ]] \
  || [[ -z "${dataset_test_set_name_dic[$CURRENT_DATASET_NAME]+x}" ]]; then
  echo "Unknown dataset '${CURRENT_DATASET_NAME}'. Supported: ${!max_tokens_dict[*]}" >&2
  exit 1
fi

# Options that must be positive integers (LORA_R only matters with --use_lora;
# it feeds $(( LORA_R / 2 )) in get_lora_params).
int_options=(NUM_GPUS PROXY_REWARD_BATCH_SIZE GOLDEN_REWARD_BATCH_SIZE)
if [[ "$USE_LORA" == true ]]; then
  int_options+=(LORA_R)
fi
for int_option in "${int_options[@]}"; do
  if [[ ! "${!int_option}" =~ ^[1-9][0-9]*$ ]]; then
    echo "Invalid ${int_option}: '${!int_option}' (expected a positive integer)" >&2
    exit 1
  fi
done

# Set a random DDP port in [20000, 30000]. Uses bash's $RANDOM (0..32767),
# so no dependency on GNU shuf.
DDP_PORT=$(( 20000 + RANDOM % 10001 ))

#######################################
# Print the accelerate --config_file flag for FSDP training, chosen by model
# family, or an empty line when FSDP is disabled or the model is unrecognized.
# Globals (read): USE_FSDP, CURRENT_MODEL_NAME
# Outputs: the flag (or an empty line) on stdout. If --use_fsdp true was
#   requested for an unrecognized model, a warning goes to stderr instead of
#   silently launching without FSDP.
#######################################
get_fsdp_config() {
  if [[ "$USE_FSDP" != true ]]; then
    echo ""
    return 0
  fi

  # Order matters: gemma / g_* names are checked before qwen names.
  case "$CURRENT_MODEL_NAME" in
    *gemma*|*g_*)
      echo "--config_file recipes/fsdp_acc_gemma.yaml"
      ;;
    *qwen*)
      echo "--config_file recipes/fsdp_acc_qwen.yaml"
      ;;
    *)
      echo "Warning: no FSDP config known for model '${CURRENT_MODEL_NAME}'; launching without FSDP." >&2
      echo ""
      ;;
  esac
}

# Print the accelerate --config_file flag used for golden-reward computation
# (always the gemma FSDP recipe), or an empty line when
# USE_FSDP_FOR_REWARDS is not "true".
get_reward_fsdp_config() {
  local reward_flag=""
  if [[ "$USE_FSDP_FOR_REWARDS" == true ]]; then
    reward_flag="--config_file recipes/fsdp_acc_gemma.yaml"
  fi
  printf '%s\n' "$reward_flag"
}

#######################################
# Print the extra dpo.py flags that enable LoRA, or an empty line when
# USE_LORA is not "true".
# Globals (read): USE_LORA, LORA_R
# Outputs: flags on stdout; lora_alpha is LORA_R / 2 (integer division).
#######################################
get_lora_params() {
  local lora_alpha
  if [[ "$USE_LORA" == true ]]; then
    lora_alpha=$(( LORA_R / 2 ))
    echo "--learning_rate=5.0e-5 --use_peft=True --lora_r=$LORA_R --lora_alpha=$lora_alpha"
  else
    echo ""
  fi
}

#######################################
# Generate completions from a model for the current dataset's evaluation split.
# Globals (read): CURRENT_DATASET_NAME, NUM_GPUS, prompt_dataset_name_dict,
#   max_tokens_dict, dataset_test_set_name_dic
# Arguments:
#   $1 - model name or path to sample from
#   $2 - path where the generated dataset is saved
#   $3 - number of completions to generate per prompt
#######################################
generate_from_sft() {
  local model_name=$1
  local output_path=$2
  local num_return_sequences=$3

  # Launched directly rather than via `accelerate launch`; the script is told
  # the GPU count through --num_of_gpus instead.
  python rso/generate_from_sft_vllm.py \
    --dataset_name="${prompt_dataset_name_dict[$CURRENT_DATASET_NAME]}" \
    --model_name_or_path="$model_name" \
    --max_new_tokens="${max_tokens_dict[$CURRENT_DATASET_NAME]}" \
    --num_return_sequences="$num_return_sequences" \
    --split="${dataset_test_set_name_dic[$CURRENT_DATASET_NAME]}" \
    --save_dataset_path="$output_path" --num_of_gpus="$NUM_GPUS"
}

# ------ Generate Baseline Dataset for Testing ------
# Reference-model samples on the evaluation split; every policy's win rate is
# measured against these (3 completions per prompt).
SFT_GEN_TEST_DATASET_NAME="datasets/dpo_gen/${CURRENT_MODEL_NAME}/${CURRENT_DATASET_NAME}/1/test/base_sample"

if [[ "$SKIP_MODEL_FITTING" == false ]]; then
  # Remove the previous samples and every *-suffixed derivative. The trailing
  # glob is intentionally outside the quotes; ':?' refuses an empty prefix.
  rm -rf -- "${SFT_GEN_TEST_DATASET_NAME:?}"*
  generate_from_sft "$REF_MODEL_NAME" "$SFT_GEN_TEST_DATASET_NAME" 3
fi


#######################################
# Sample completions from a reference model, score them with the proxy
# pairwise reward model, and build the RSO / baseline / ISO DPO datasets for
# the current round. Does nothing unless SKIP_MODEL_FITTING is "false".
# Globals (read): ROUND, CURRENT_MODEL_NAME, CURRENT_DATASET_NAME, NUM_GPUS,
#   DDP_PORT, NUM_OF_ROLLOUT_PROMPTS, PROXY_REWARD_BATCH_SIZE, USE_BT_PLUS,
#   BETA, SKIP_MODEL_FITTING, prompt_dataset_name_dict, max_tokens_dict
# Arguments:
#   $1 - model to sample from (64 completions per prompt)
#   $2 - dataset name under datasets/dpo_gen/<model>/<dataset>/<round>/rso/
# Outputs: <rso dataset>, <rso dataset>_reward, _ranked and _ranked_ready, plus
#   whatever make_iso_baseline_dataset.py writes under the round directory.
#######################################
generate_data_and_run_rso() {
  local ref_model=$1
  local dataset_suffix=$2
  local round_dir="datasets/dpo_gen/${CURRENT_MODEL_NAME}/${CURRENT_DATASET_NAME}/${ROUND}"
  local rso_dataset="${round_dir}/rso/${dataset_suffix}"
  local -a iso_flags=()

  if [[ "$SKIP_MODEL_FITTING" != false ]]; then
    return 0
  fi

  # Clear this rollout and all of its *-suffixed derivatives (glob unquoted on
  # purpose; ':?' refuses an empty prefix).
  rm -rf -- "${rso_dataset:?}"*

  # Launched directly rather than via `accelerate launch`; the script is told
  # the GPU count through --num_of_gpus instead.
  python rso/generate_from_sft_vllm.py \
    --dataset_name="${prompt_dataset_name_dict[$CURRENT_DATASET_NAME]}" \
    --model_name_or_path="$ref_model" \
    --max_new_tokens="${max_tokens_dict[$CURRENT_DATASET_NAME]}" \
    --num_return_sequences=64 --num_of_gpus="$NUM_GPUS" \
    --save_dataset_path="$rso_dataset" --sub_sample="$NUM_OF_ROLLOUT_PROMPTS"

  # ------ Evaluate Reward and Run RSO ------
  accelerate launch --num_processes="$NUM_GPUS" --main_process_port="$DDP_PORT" \
    rso/rso_pairwise_reward.py \
    "recipes/reward/t5_pair_${CURRENT_DATASET_NAME}.yaml" \
    --model_name_or_path="outputs/reward/reward_pair-${CURRENT_DATASET_NAME}-t5-xxl" \
    --per_device_eval_batch_size="$PROXY_REWARD_BATCH_SIZE" \
    --sample_dataset_path="$rso_dataset" \
    --reward_dataset_path="${rso_dataset}_reward" \
    --save_dataset_path="${rso_dataset}_ranked"

  python prepare_dataset.py \
    --gen_dataset_path="${rso_dataset}_ranked" \
    --output_path="${rso_dataset}_ranked_ready"

  # ------- Create ISO & Direct SFT Dataset ------
  if [[ "$USE_BT_PLUS" == true ]]; then
    iso_flags+=(--plus_bt_is_score)
  fi
  # ${arr[@]+...} keeps an empty array from tripping `set -u` on bash < 4.4.
  python make_iso_baseline_dataset.py \
    --init_reward_dataset_path="${rso_dataset}_reward" \
    --target_dataset_base_path="$round_dir" \
    ${iso_flags[@]+"${iso_flags[@]}"} \
    --beta="$BETA"
}

#######################################
# Train a DPO policy (unless SKIP_MODEL_FITTING is not "false"), sample from
# it, and score its proxy win rate against the reference-model samples.
# Arguments:
#   $1 - DPO recipe path
#   $2 - DPO training dataset
#   $3 - output directory of the trained policy
#   $4 - path for the policy's evaluation generations
#   $5 - model DPO starts from (and is regularized towards)
#######################################
run_experiment_get_winrate() {
  local recipe_path=$1
  local exp_dataset_name=$2
  local policy_model_name=$3
  local policy_gen_dataset_name=$4
  local reference_model=$5

  if [[ "$SKIP_MODEL_FITTING" == false ]]; then
    # get_fsdp_config / get_lora_params print space-separated flags (or
    # nothing); they stay unquoted so the output splits into arguments.
    # shellcheck disable=SC2046
    accelerate launch $(get_fsdp_config) --num_processes="$NUM_GPUS" --main_process_port="$DDP_PORT" \
      dpo.py "$recipe_path" \
      --dataset_name="$exp_dataset_name" \
      --model_name_or_path="$reference_model" \
      --output_dir="$policy_model_name" \
      --run_name="${policy_model_name##outputs/dpo_policy/}" \
      $(get_lora_params)

    generate_from_sft "$policy_model_name" "$policy_gen_dataset_name" 2
  fi

  # shellcheck disable=SC2046
  accelerate launch $(get_fsdp_config) --num_processes="$NUM_GPUS" --main_process_port="$DDP_PORT" \
    win_rate/win_pairwise_baseline_proxy.py \
    "recipes/reward/t5_pair_${CURRENT_DATASET_NAME}.yaml" \
    --model_name_or_path="outputs/reward/reward_pair-${CURRENT_DATASET_NAME}-t5-xxl" \
    --per_device_eval_batch_size="$PROXY_REWARD_BATCH_SIZE" \
    --sft_gen_path="$SFT_GEN_TEST_DATASET_NAME" \
    --target_gen_path="$policy_gen_dataset_name" \
    --log_tag="$TAG"
}

#######################################
# run_experiment_get_winrate with the standard per-round path layout.
# Arguments:
#   $1 - training dataset, relative to datasets/dpo_gen/<model>/<dataset>/<round>
#   $2 - "<method>/<spn>" naming both the policy and its evaluation generations
#   $3 - model DPO starts from
#######################################
run_round_experiment() {
  local dataset_rel=$1
  local run_rel=$2
  local reference_model=$3
  local layout="${CURRENT_MODEL_NAME}/${CURRENT_DATASET_NAME}/${ROUND}"

  run_experiment_get_winrate "recipes/dpo/${CURRENT_MODEL_NAME}-${CURRENT_DATASET_NAME}.yaml" \
    "datasets/dpo_gen/${layout}/${dataset_rel}" \
    "outputs/dpo_policy/${layout}/${run_rel}" \
    "datasets/policy_eval_gen/${layout}/${run_rel}" \
    "$reference_model"
}

for ROUND in {1..2}; do

  echo "Round $ROUND"

  if [[ "$SKIP_MODEL_FITTING" == false ]]; then
    rm -rf -- "outputs/dpo_policy/${CURRENT_MODEL_NAME:?}/${CURRENT_DATASET_NAME:?}/${ROUND}"*
  fi

  # NOTE(review): each generate_data_and_run_rso call presumably rewrites this
  # round's baseline/ and iso/ datasets (make_iso_baseline_dataset.py targets
  # the round directory), so every experiment runs right after the data
  # generation for its own reference policy.
  if (( ROUND == 1 )); then
    # First iteration: every method starts from the original reference model.
    generate_data_and_run_rso "$REF_MODEL_NAME" "16"
    run_round_experiment "rso/16_ranked_ready" "rso/16" "$REF_MODEL_NAME"

    # Baseline runs only for spn=16.
    run_round_experiment "baseline/16" "baseline/16" "$REF_MODEL_NAME"

    # ISO experiments for spn values 16, 8, 32, 64.
    for spn in 16 8 32 64; do
      run_round_experiment "iso/${spn}" "iso/${spn}" "$REF_MODEL_NAME"
    done
  else
    # Later iterations: each method continues from its own policy of the
    # previous round, on data freshly sampled from that policy.
    PREV_POLICY_DIR="outputs/dpo_policy/${CURRENT_MODEL_NAME}/${CURRENT_DATASET_NAME}/$((ROUND - 1))"

    RSO_REF_MODEL="${PREV_POLICY_DIR}/rso/16"
    generate_data_and_run_rso "$RSO_REF_MODEL" "rso_ref"
    run_round_experiment "rso/rso_ref_ranked_ready" "rso/16" "$RSO_REF_MODEL"

    BASELINE_REF_MODEL="${PREV_POLICY_DIR}/baseline/16"
    generate_data_and_run_rso "$BASELINE_REF_MODEL" "baseline_ref"
    run_round_experiment "baseline/16" "baseline/16" "$BASELINE_REF_MODEL"

    for spn in 16 8 32 64; do
      ISO_REF_MODEL="${PREV_POLICY_DIR}/iso/${spn}"
      generate_data_and_run_rso "$ISO_REF_MODEL" "iso_${spn}_ref"
      run_round_experiment "iso/${spn}" "iso/${spn}" "$ISO_REF_MODEL"
    done
  fi

  # ------- Compute Golden Reward ------
  if [[ "$SKIP_GOLDEN_REWARD" == false ]]; then
    GOLDEN_REWARD_MODEL="outputs/reward/gemma-2-27b-Reward-${CURRENT_DATASET_NAME}"

    # Golden rewards for the reference-model samples: dataset1 of every
    # comparison below.
    # shellcheck disable=SC2046
    accelerate launch $(get_reward_fsdp_config) --num_processes="$NUM_GPUS" --main_process_port="$DDP_PORT" \
      win_rate/compute_rewards.py \
      --dataset_path="${SFT_GEN_TEST_DATASET_NAME}-proxy_reward" \
      --output_path="${SFT_GEN_TEST_DATASET_NAME}-golden_rewards" \
      --model_name="$GOLDEN_REWARD_MODEL" --batch_size="$GOLDEN_REWARD_BATCH_SIZE" --use_cls

    for dataset_path in "datasets/policy_eval_gen/${CURRENT_MODEL_NAME}/${CURRENT_DATASET_NAME}/${ROUND}"/*/*-proxy_reward; do
      # An unmatched glob stays as the literal pattern; skip it instead of
      # passing a nonexistent path to compute_rewards.py.
      if [[ ! -e "$dataset_path" ]]; then
        echo "Warning: no *-proxy_reward outputs for round ${ROUND}; skipping golden-reward comparison." >&2
        continue
      fi
      # Swap only the trailing suffix, never an earlier "proxy_reward" in the path.
      golden_path="${dataset_path%-proxy_reward}-golden_rewards"

      # shellcheck disable=SC2046
      accelerate launch $(get_reward_fsdp_config) --num_processes="$NUM_GPUS" --main_process_port="$DDP_PORT" \
        win_rate/compute_rewards.py \
        --dataset_path="$dataset_path" --output_path="$golden_path" \
        --model_name="$GOLDEN_REWARD_MODEL" --batch_size="$GOLDEN_REWARD_BATCH_SIZE" --use_cls

      python win_rate/compare_rewards.py \
        --dataset1_path="${SFT_GEN_TEST_DATASET_NAME}-golden_rewards" \
        --dataset2_path="$golden_path" \
        --log_tag="$TAG"
    done
  fi

done
