#!/bin/bash

# Print the command-line synopsis and exit with status 1.
# Every listed option is required. The text goes to stderr because it is
# only ever shown as a diagnostic (bad, missing or --help arguments).
usage() {
  echo "Usage: $0 --gpu_ids <ids> --dpo_path <path> --ref_path <path> --batch_size <size> --kto_dataset <path> --tag <name> --seq_cal <name>" >&2
  exit 1
}

# Every option starts out empty. All of them are mandatory, so any that are
# still empty after parsing make the script print usage and exit.
gpu_ids='' dpo_path='' ref_path=''
batch_size='' kto_dataset=''
tag='' seq_cal=''

# Parse long options of the form "--name value". Each recognised option
# stores its value in the shell variable of the same name
# (e.g. --tag foo -> tag=foo). Anything else prints usage and exits.
while [[ $# -gt 0 ]]; do
  key="$1"
  case "$key" in
    -h|--help)
      usage
      ;;
    --gpu_ids|--dpo_path|--ref_path|--batch_size|--kto_dataset|--tag|--seq_cal)
      # Reject a missing value instead of silently consuming the next option
      # as this one's value (e.g. "--tag --seq_cal x"), or silently dropping
      # a trailing option that has no value at all.
      if [[ $# -lt 2 || "$2" == --* ]]; then
        echo "Option $key requires a value" >&2
        usage
      fi
      # Safe indirection: the case pattern restricts the target to the
      # known option names above, so no arbitrary variable can be written.
      printf -v "${key#--}" '%s' "$2"
      shift 2
      ;;
    *)
      echo "Unknown option $1" >&2
      usage
      ;;
  esac
done

# All options are mandatory. Collect every empty one so the user sees the
# complete list at once, then print usage (which exits 1).
missing=()
for name in gpu_ids dpo_path ref_path batch_size kto_dataset tag seq_cal; do
  if [[ -z "${!name}" ]]; then
    missing+=("--$name")
  fi
done
if (( ${#missing[@]} > 0 )); then
  echo "Missing required arguments: ${missing[*]}" >&2
  usage
fi

# Each model's log-prob output file lives inside that model's directory and
# is named after the tag, e.g. <ref_path>/<tag>.json.
ref_ds="$ref_path/$tag.json"
dpo_ds="$dpo_path/$tag.json"

#######################################
# Run testprm_1_compute_logp.py for one model over $kto_dataset, unless its
# output file already exists (reruns reuse previously generated files).
# Globals (read): gpu_ids, kto_dataset, batch_size
# Arguments:
#   $1 - label used in progress/error messages (ref_ds / dpo_ds)
#   $2 - model path passed as --model
#   $3 - --output_key value; the evaluator reads it back via --*_logp_name
#   $4 - output file passed as --output_jsonl
# On failure: deletes the output file and exits 1. This is safe because
# generation only runs when the file did not exist beforehand. It prevents a
# partial file from being silently reused on the next run, and keeps the
# evaluation from starting on missing data.
#######################################
generate_logp_if_missing() {
  local label=$1 model=$2 key=$3 out=$4
  if [[ -f "$out" ]]; then
    return 0
  fi
  echo "${label} not found: ${out}, generating it now..."
  if ! CUDA_VISIBLE_DEVICES="$gpu_ids" ~/verl_250713/.conda/bin/python \
      ~/verl_250713/scripts/testprm_1_compute_logp.py \
      --model "$model" \
      --data "$kto_dataset" \
      --prompt_key messages \
      --output_key "$key" \
      --output_jsonl "$out" \
      --batch_size "$batch_size"; then
    echo "Failed to generate ${label}: ${out}" >&2
    rm -f -- "$out"
    exit 1
  fi
}

# Reference model first, then the DPO model (same order as before).
generate_logp_if_missing ref_ds "$ref_path" ref_logp "$ref_ds"
generate_logp_if_missing dpo_ds "$dpo_path" prm_logp "$dpo_ds"

# Run the evaluator on the two generated files. The --*_logp_name values must
# match the --output_key used when generating them (ref_logp / prm_logp).
# NOTE(review): unlike the generation step, CUDA_VISIBLE_DEVICES is not set
# here. Confirm whether valid_everyce.py needs a GPU restriction.
~/verl_250713/.conda/bin/python \
  ~/verl_250713/scripts/valid_everyce.py \
  --dpo_ds "$dpo_ds" \
  --ref_ds "$ref_ds" \
  --dpo_logp_name prm_logp \
  --ref_logp_name ref_logp \
  --seq_reward_cal "$seq_cal"
