#!/bin/bash




# Batch driver: for each saved training checkpoint, convert the model, run
# validation inference on it, and aggregate the inference outputs.
#
# Usage: <this script> [CONFIG_FILE]
#   CONFIG_FILE defaults to ${SCRIPT_DIR}/configs/base_config.sh and is sourced.
#   The following variables are read but never set in this file, so the config
#   (or the environment) must provide them: CHECK_POINT_ROOT_PATH,
#   CONVERTED_CHECK_POINT_ROOT_PATH, BASE_MODEL, MODEL_ORG, TRAIN_DATASET,
#   EPOCH, NUM_GPUS, SYSTEM_PROMPT, USE_GUIDED_DECODING and, optionally,
#   LATEST_ITERATION_FILE.

# Real directory of this script (symlinks resolved) and the repository root
# two levels above it.
SCRIPT_DIR="$(cd "$(dirname "$(readlink -f "$0")")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"

CONFIG_FILE=${1:-"${SCRIPT_DIR}/configs/base_config.sh"}
# `source NAME` searches $PATH first when NAME has no slash; anchor bare
# filenames to the current directory so the checked file is the sourced one.
[[ "$CONFIG_FILE" == */* ]] || CONFIG_FILE="./${CONFIG_FILE}"
if [[ ! -f "$CONFIG_FILE" ]]; then
  echo "エラー: 設定ファイルが見つかりません: ${CONFIG_FILE}" >&2
  exit 1
fi

echo "CONFIG_FILE: ${CONFIG_FILE}"

# Quoted so that paths containing spaces are handled correctly.
# shellcheck source=/dev/null
source "${REPO_ROOT}/.venv/bin/activate"
# shellcheck source=/dev/null
source "$CONFIG_FILE"

# Highest iteration to process: taken from LATEST_ITERATION_FILE when that file
# exists (all whitespace, including a trailing newline or CR, is stripped);
# otherwise fall back to 7500.
if [[ -f "${LATEST_ITERATION_FILE:-}" ]]; then
  MAX_ITERATION=$(tr -d '[:space:]' < "$LATEST_ITERATION_FILE")
else
  MAX_ITERATION=7500
fi

# The loops below evaluate MAX_ITERATION arithmetically, where a stray word
# silently becomes 0 (or another variable's value); reject anything non-numeric.
if [[ ! "$MAX_ITERATION" =~ ^[0-9]+$ ]]; then
  echo "エラー: 最大イテレーションの値が不正です: '${MAX_ITERATION}'" >&2
  exit 1
fi
# Force base 10 so a zero-padded value (e.g. 0007500) is not parsed as octal.
MAX_ITERATION=$((10#$MAX_ITERATION))

echo "最大イテレーションを検出しました: ${MAX_ITERATION}"




#######################################
# Convert one training checkpoint, run validation inference on the converted
# model for each evaluation dataset, then aggregate the inference outputs.
# Globals (read):
#   CHECK_POINT_ROOT_PATH, CONVERTED_CHECK_POINT_ROOT_PATH, BASE_MODEL,
#   MODEL_ORG, TRAIN_DATASET, EPOCH, NUM_GPUS, SYSTEM_PROMPT,
#   USE_GUIDED_DECODING, SCRIPT_DIR, REPO_ROOT
# Arguments:
#   $1 - iteration number of the checkpoint to process
# Outputs:
#   Progress on stdout; warnings on stderr.
# Returns:
#   0 when the source checkpoint is absent (skipped);
#   1 when conversion fails or leaves no converted model;
#   otherwise the exit status of the aggregation script.
#######################################
process_checkpoint() {
  local iter_to_run=$1
  local formatted_iter
  printf -v formatted_iter 'iter_%07d' "$iter_to_run"
  # Run directory name shared by the training and the converted checkpoint
  # trees; kept in one place so the two paths cannot drift apart.
  local run_name="${TRAIN_DATASET}_lr_2e-5-minlr_4e-6_GB_64_${EPOCH}epoch"

  echo "=================================================="
  echo "イテレーション ${iter_to_run} の処理を開始します"
  echo "=================================================="

  echo "--- ステップ 1: イテレーション ${iter_to_run} のモデル変換 ---"

  local source_checkpoint_dir="${CHECK_POINT_ROOT_PATH}/${BASE_MODEL}/${run_name}/${formatted_iter}"
  # A missing checkpoint is treated as an expected skip, not an error.
  if [[ ! -d "$source_checkpoint_dir" ]]; then
    echo "警告: 変換元のチェックポイントが見つかりません。イテレーション ${iter_to_run} をスキップします: ${source_checkpoint_dir}" >&2
    return 0
  fi

  # Only report success (and continue to inference) if the converter succeeded.
  if ! "${SCRIPT_DIR}/convert_model.sh" \
    "$EPOCH" "$TRAIN_DATASET" "$iter_to_run" "$CHECK_POINT_ROOT_PATH" \
    "$CONVERTED_CHECK_POINT_ROOT_PATH" "$BASE_MODEL" "$MODEL_ORG"; then
    echo "警告: イテレーション ${iter_to_run} のモデル変換に失敗しました。推論をスキップします。" >&2
    return 1
  fi
  python3 "${SCRIPT_DIR}/utils/send_to_slack.py" "Model conversion for iteration ${iter_to_run} Finished!"

  echo "--- ステップ 2: イテレーション ${iter_to_run} の推論実行 ---"

  local converted_model_path="${CONVERTED_CHECK_POINT_ROOT_PATH}/${BASE_MODEL}/${run_name}/${formatted_iter}"
  if [[ ! -d "$converted_model_path" ]]; then
    echo "警告: 変換後のモデルが見つかりません。推論をスキップします: ${converted_model_path}" >&2
    return 1
  fi

  local -a inference_datasets=(
    "LLTM-cruxeval-numeric-depth-val"
    "LLTM-livecodebench-numeric-depth-val"
    "LLTM-livecodebench-all-numeric-depth-val"
  )
  local val_data_path="${REPO_ROOT}/scripts/instruction/convert_datasets"
  local eval_dataset

  # A failure on one dataset is reported but does not stop the others.
  for eval_dataset in "${inference_datasets[@]}"; do
    if ! "${SCRIPT_DIR}/inference.sh" \
      "$EPOCH" "$iter_to_run" "$NUM_GPUS" "$BASE_MODEL" "$TRAIN_DATASET" \
      "$eval_dataset" "$converted_model_path" "$val_data_path" "$SYSTEM_PROMPT"; then
      echo "警告: データセット ${eval_dataset} の推論に失敗しました (イテレーション ${iter_to_run})" >&2
    fi
  done

  # USE_GUIDED_DECODING is passed only when non-empty, so the aggregator gets
  # exactly four arguments when the setting is unset or empty.
  "${SCRIPT_DIR}/utils/aggregate_inference_outputs.sh" \
    "$TRAIN_DATASET" \
    "$BASE_MODEL" \
    "$iter_to_run" \
    "$EPOCH" \
    ${USE_GUIDED_DECODING:+"$USE_GUIDED_DECODING"}
}




# Checkpoints are visited every ITERATION_STEP iterations up to MAX_ITERATION.
readonly ITERATION_STEP=500

# Iterations whose processing returned a non-zero status; reported at the end
# so a failure is not hidden behind the final "all done" notice.
failed_iterations=()

for (( iter = ITERATION_STEP; iter <= MAX_ITERATION; iter += ITERATION_STEP )); do
  process_checkpoint "$iter" || failed_iterations+=("$iter")
done

# When MAX_ITERATION is not a multiple of the step, the loop above stops short
# of it, so the final checkpoint is processed separately.
if (( MAX_ITERATION % ITERATION_STEP != 0 )); then
  echo "最終チェックポイント (${MAX_ITERATION}) の処理を追加で実行します。"
  process_checkpoint "$MAX_ITERATION" || failed_iterations+=("$MAX_ITERATION")
fi

if (( ${#failed_iterations[@]} > 0 )); then
  echo "警告: 処理に失敗したイテレーション: ${failed_iterations[*]}" >&2
fi

python3 "${SCRIPT_DIR}/utils/send_to_slack.py" "<!channel> 全てのチェックポイントの処理 (変換 & 推論) が完了しました！"
echo "=================================================="
echo "全ての処理が完了しました。"
echo "=================================================="
