#!/bin/bash




# Strict mode: abort on command failure (-e), on use of an unset variable (-u),
# and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Absolute directory containing this script; the repository root is two levels up.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"

# Activate the repository's virtualenv. The path is quoted so that a checkout
# location containing spaces or glob characters is not word-split.
# shellcheck source=/dev/null
source "${REPO_ROOT}/.venv/bin/activate"

START_TIME=$(date '+%Y-%m-%d %H:%M:%S')
echo "Script started at: $START_TIME"

# Fail early with a usage message instead of a bare "unbound variable" error
# when a required positional argument is missing.
if (( $# < 9 )); then
  printf 'Usage: %s EPOCH ITERATION NUM_GPUS BASE_MODEL TRAIN_DATASET EVAL_DATASET MODEL_PATH VAL_DATA_PATH SYSTEM_PROMPT [USE_GUIDED_DECODING]\n' \
    "${0##*/}" >&2
  exit 1
fi

# Positional arguments:
#   $1  EPOCH                - used in the log/trace directory names and the Slack message
#   $2  ITERATION            - used in the log/trace file names and the Slack message
#   $3  NUM_GPUS             - forwarded to the runner as --num-workers
#   $4  BASE_MODEL           - used in the log/trace directory names and the Slack message
#   $5  TRAIN_DATASET        - used in the log/trace paths
#   $6  EVAL_DATASET         - dataset name; "<VAL_DATA_PATH>/<EVAL_DATASET>.jsonl" is the input file
#   $7  MODEL_PATH           - forwarded as both --model-path and --tokenizer-path
#   $8  VAL_DATA_PATH        - directory holding the evaluation .jsonl files
#   $9  SYSTEM_PROMPT        - forwarded as --system-prompt
#   $10 USE_GUIDED_DECODING  - optional, defaults to "false"; only the exact value "true" enables it
EPOCH=$1
ITERATION=$2
NUM_GPUS=$3
BASE_MODEL=$4
TRAIN_DATASET=$5
EVAL_DATASET=$6
MODEL_PATH=$7
VAL_DATA_PATH=$8
SYSTEM_PROMPT=$9
USE_GUIDED_DECODING="${10:-false}"


# Root directories (next to this script) for inference traces and logs.
TRACE_DIR="${SCRIPT_DIR}/inference_trace"
OUTPUT_LOG_DIR="${SCRIPT_DIR}/inference_logs"
mkdir -p "${TRACE_DIR}" "${OUTPUT_LOG_DIR}"

# Timestamp that makes each run's log file name unique.
TIMESTAMP=$(date '+%Y%m%d_%H%M%S')

# Guided-decoding runs are tagged so their logs and traces stay separate.
case "${USE_GUIDED_DECODING}" in
  true) SUFFIX="-guided" ;;
  *)    SUFFIX="" ;;
esac

# Log file:   <logs>/<model>/<N>epoch/<train>/<eval>/iter<I>-<timestamp>[-guided].log
LOG_FILE="${OUTPUT_LOG_DIR}/${BASE_MODEL}/${EPOCH}epoch/${TRAIN_DATASET}/${EVAL_DATASET}"
LOG_FILE+="/iter${ITERATION}-${TIMESTAMP}${SUFFIX}.log"

# Trace file: <traces>/<model>/<N>epoch/<eval>/<train>-iter<I>[-guided].jsonl
OUTPUT_PATH="${TRACE_DIR}/${BASE_MODEL}/${EPOCH}epoch/${EVAL_DATASET}"
OUTPUT_PATH+="/${TRAIN_DATASET}-iter${ITERATION}${SUFFIX}.jsonl"

# Input evaluation data.
EVAL_DATASET_PATH="${VAL_DATA_PATH}/${EVAL_DATASET}.jsonl"

# Make sure the parent directories of both output files exist.
mkdir -p "$(dirname "${LOG_FILE}")" "$(dirname "${OUTPUT_PATH}")"

# Summary banner for this run.
printf '%s\n' \
  "----------------------------------------" \
  "Starting inference for:" \
  "Dataset Path:    ${EVAL_DATASET_PATH}" \
  "Model Path:      ${MODEL_PATH}" \
  "Output Path:     ${OUTPUT_PATH}" \
  "Log Path:        ${LOG_FILE}" \
  "----------------------------------------"




# Build the inference command as an array so that every argument (notably the
# system prompt, which may contain spaces) is passed as exactly one word.
INFER_CMD=(
  python "${SCRIPT_DIR}/ray_inference_runner.py"
  --model-path "${MODEL_PATH}"
  --tokenizer-path "${MODEL_PATH}"
  --jsonl-file "${EVAL_DATASET_PATH}"
  --output-file "${OUTPUT_PATH}"
  --max-new-tokens 8192
  --num-workers "${NUM_GPUS}"
  --batch-size 8
  --tensor-parallel-size 1
  --system-prompt "${SYSTEM_PROMPT}"
  --no-sample
)

# Optional flag, enabled only by the exact value "true".
if [[ "${USE_GUIDED_DECODING}" == "true" ]]; then
  INFER_CMD+=(--use-guided-decoding)
fi

# Run the command directly as the `if` condition. With `set -e` active, a
# separate `$?` check after the command would never be reached on failure,
# because the shell would exit before printing the error message.
# Both stdout and stderr of the runner go to the log file.
if "${INFER_CMD[@]}" > "${LOG_FILE}" 2>&1; then
  echo "Success: ${EVAL_DATASET}"
else
  echo "Error: Failed ${EVAL_DATASET} (see ${LOG_FILE})" >&2
  exit 1
fi



END_TIME=$(date '+%Y-%m-%d %H:%M:%S')
echo "Inference Script ended at: $END_TIME"

# Post a completion message to Slack. The backticks are escaped so they reach
# the helper literally instead of triggering command substitution.
# NOTE(review): with `set -e` active, a failure of this helper aborts the
# script before the diff step below - confirm that is intended.
python3 "${SCRIPT_DIR}/utils/send_to_slack.py" \
  "📝 Inference Finished for \`$EVAL_DATASET\`
> • Model: \`$BASE_MODEL\` (Epoch: \`$EPOCH\`, Iter: \`$ITERATION\`)"

# Hand the run parameters to the trace-diff helper. Each argument is quoted so
# it is passed as exactly one word, and the last line carries no trailing
# backslash, so nothing appended later is silently absorbed into this command.
"${SCRIPT_DIR}/utils/diff_trace.sh" \
  "${TRAIN_DATASET}" \
  "${EVAL_DATASET}" \
  "${BASE_MODEL}" \
  "${ITERATION}" \
  "${EPOCH}" \
  "${USE_GUIDED_DECODING}"