#!/bin/bash

# Serial integrated pipeline:
# 1) Base inference (utils/inference.py)
# 2) LLM-as-Judge evaluates answer correctness (utils/evaluation.py)
# 3) Filter samples where all responses are incorrect
# 4) Motivation inference (utils/motivation_exp.py)
# 5) Judge reasoning path correctness and independence
# 6) Summarize metrics

# Strict mode: abort on command failure, on use of an unset variable, and on a
# failure in any stage of a pipeline.
set -o errexit -o nounset -o pipefail

# Locate the directory holding this script, then its parent (the repository
# root), and run everything from there.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
REPO_ROOT=$(cd "${SCRIPT_DIR}/.." && pwd)
cd "${REPO_ROOT}"

# Select the Python interpreter: a user-provided PYTHON_BIN wins; otherwise fall
# back to python3, then python. Diagnostics are written to stderr, and an
# explicitly provided PYTHON_BIN is checked up front so a typo is reported here
# rather than as a bare "command not found" when the pipeline is launched.
PYTHON_BIN="${PYTHON_BIN:-}"
if [[ -z "$PYTHON_BIN" ]]; then
  if command -v python3 >/dev/null 2>&1; then
    PYTHON_BIN=python3
  elif command -v python >/dev/null 2>&1; then
    PYTHON_BIN=python
  else
    echo "Python interpreter not found. Please install python3 or export PYTHON_BIN pointing to the interpreter path." >&2
    exit 1
  fi
elif ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then
  echo "PYTHON_BIN is set to '$PYTHON_BIN', but it is not an executable command or path." >&2
  exit 1
fi

# Print the command-line help text to stdout.
# The here-doc delimiter is quoted, so the text is emitted literally with no
# variable or command expansion. The -n short form of --n_generations is listed
# because the argument parser accepts it.
print_usage() {
  cat <<'USAGE'

Usage:
  bash scripts/run_motivation_full_pipeline.sh [options]

Options (override defaults with flags):
  --model_name PATH                     Base inference model; used in the first-stage normal inference. (required)
  --dataset_name NAME                   Dataset name (without extension). Example: aime25 / gpqa_diamond_Avg4. (required)
  --dataset_path PATH                   Dataset directory; the script reads <dataset_name>.jsonl from here. (required)
  --motivation_model_name PATH          Model for the second-stage Motivation inference (reason using the provided reference answer).
                                        If omitted, defaults to --model_name. Use to compare different models with a reference answer or
                                        to separate resource usage.
  --base_output_dir PATH                Custom unified output directory; if provided, overrides the default layout
                                        inference_results/<model>/<dataset>/<params>/
  -n | --n_generations N                Number of candidate responses generated per sample (used in both stages). Default: 16.
  --num_samples N                       Max number of samples to process in stage 1 (for speed/sampling). Default: 10000.
  --temperature FLOAT                   Sampling temperature (shared by both stages). Default: 0.6.
  --top_p FLOAT                         Nucleus sampling top-p (shared). Default: 0.95.
  --top_k INT                           Top-k sampling (shared). Default: 20.
  --max_tokens INT                      Max generation tokens (shared). Default: 8192.
  --tensor_parallel_size INT            vLLM tensor parallelism (shared). Default: 1.
  --gpu_memory_utilization FLOAT        vLLM GPU memory utilization (shared). Default: 0.95.
  --judge_model_accuracy NAME           LLM-as-Judge model for stage-1 answer correctness. Default: gpt-41-mini-0414-global.
  --judge_model_reasoning NAME          LLM-as-Judge model for stage-2 reasoning correctness and independence. Default: gemini-2.5-pro-06-17.
  --use_system_prompt                   Inject a unified system prompt into both stages' chats (same as SYSTEM_PROMPT in the repo). If present, enabled.
  -h | --help                           Show this help.

Environment variables:
  EVAL_FALLBACK_MODELS  Comma-separated fallback judge models. Automatically switch in order when the primary judge model fails repeatedly
                         or returns invalid outputs.
                         Example: export EVAL_FALLBACK_MODELS="gpt-41-mini-0414-global,o4-mini-0416-global,gemini-2.5-pro-06-17"

Notes:
  - This script chains the entire pipeline: base inference -> correctness evaluation -> filter fully-wrong samples -> Motivation inference ->
    reasoning evaluation -> metrics summary.
  - motivation_model_name is optional; by default it uses model_name. Specify only when you want a different model for the second stage.
USAGE
}

# Defaults. Each value is replaced when the flag of the same name is passed.

# Required inputs; left empty here and checked after argument parsing.
MODEL_NAME=""
DATASET_NAME=""
DATASET_PATH=""

# Optional overrides. An empty MOTIVATION_MODEL_NAME means stage 2 reuses
# MODEL_NAME; an empty BASE_OUTPUT_DIR means no output-directory override.
MOTIVATION_MODEL_NAME=""
BASE_OUTPUT_DIR=""

# Generation settings
N_GENERATIONS="16"      # candidate responses per sample
NUM_SAMPLES="10000"     # upper bound on samples processed in stage 1
TEMPERATURE="0.6"       # sampling temperature
TOP_P="0.95"            # nucleus sampling threshold
TOP_K="20"              # top-k sampling cutoff
MAX_TOKENS="8192"       # maximum generated tokens

# vLLM resource settings
TENSOR_PARALLEL_SIZE="1"
GPU_MEMORY_UTILIZATION="0.95"

# LLM-as-Judge models
JUDGE_MODEL_ACC="gpt-41-mini-0414-global"    # judges stage-1 answer correctness
JUDGE_MODEL_REASON="gemini-2.5-pro-06-17"    # judges stage-2 reasoning

# Off unless --use_system_prompt is passed
USE_SYSTEM_PROMPT="false"

# Argument parsing

# Guard for flags written as "--flag VALUE": exits with a readable message when
# the flag is the last argument. Without it, "$2" trips `set -u` and the user
# only sees "unbound variable".
# Arguments: the remaining command-line arguments, flag first.
require_value() {
  if [[ $# -lt 2 ]]; then
    echo "Missing value for parameter: $1" >&2
    exit 1
  fi
}

# Parse command-line flags into the global configuration variables.
# Every value-taking flag accepts both "--flag VALUE" and "--flag=VALUE"; in the
# "=" form everything after the first "=" is the value.
# Exits 0 after printing help for -h/--help; exits 1 (message and usage on
# stderr) for an unknown parameter or a flag that is missing its value.
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      -h|--help) print_usage; exit 0 ;;
      --model_name) require_value "$@"; MODEL_NAME="$2"; shift 2 ;;
      --model_name=*) MODEL_NAME="${1#*=}"; shift ;;
      --dataset_name) require_value "$@"; DATASET_NAME="$2"; shift 2 ;;
      --dataset_name=*) DATASET_NAME="${1#*=}"; shift ;;
      --dataset_path) require_value "$@"; DATASET_PATH="$2"; shift 2 ;;
      --dataset_path=*) DATASET_PATH="${1#*=}"; shift ;;
      --motivation_model_name) require_value "$@"; MOTIVATION_MODEL_NAME="$2"; shift 2 ;;
      --motivation_model_name=*) MOTIVATION_MODEL_NAME="${1#*=}"; shift ;;
      --base_output_dir) require_value "$@"; BASE_OUTPUT_DIR="$2"; shift 2 ;;
      --base_output_dir=*) BASE_OUTPUT_DIR="${1#*=}"; shift ;;
      --n_generations|-n) require_value "$@"; N_GENERATIONS="$2"; shift 2 ;;
      --n_generations=*|-n=*) N_GENERATIONS="${1#*=}"; shift ;;
      --num_samples) require_value "$@"; NUM_SAMPLES="$2"; shift 2 ;;
      --num_samples=*) NUM_SAMPLES="${1#*=}"; shift ;;
      --temperature) require_value "$@"; TEMPERATURE="$2"; shift 2 ;;
      --temperature=*) TEMPERATURE="${1#*=}"; shift ;;
      --top_p) require_value "$@"; TOP_P="$2"; shift 2 ;;
      --top_p=*) TOP_P="${1#*=}"; shift ;;
      --top_k) require_value "$@"; TOP_K="$2"; shift 2 ;;
      --top_k=*) TOP_K="${1#*=}"; shift ;;
      --max_tokens) require_value "$@"; MAX_TOKENS="$2"; shift 2 ;;
      --max_tokens=*) MAX_TOKENS="${1#*=}"; shift ;;
      --tensor_parallel_size) require_value "$@"; TENSOR_PARALLEL_SIZE="$2"; shift 2 ;;
      --tensor_parallel_size=*) TENSOR_PARALLEL_SIZE="${1#*=}"; shift ;;
      --gpu_memory_utilization) require_value "$@"; GPU_MEMORY_UTILIZATION="$2"; shift 2 ;;
      --gpu_memory_utilization=*) GPU_MEMORY_UTILIZATION="${1#*=}"; shift ;;
      --judge_model_accuracy) require_value "$@"; JUDGE_MODEL_ACC="$2"; shift 2 ;;
      --judge_model_accuracy=*) JUDGE_MODEL_ACC="${1#*=}"; shift ;;
      --judge_model_reasoning) require_value "$@"; JUDGE_MODEL_REASON="$2"; shift 2 ;;
      --judge_model_reasoning=*) JUDGE_MODEL_REASON="${1#*=}"; shift ;;
      --use_system_prompt) USE_SYSTEM_PROMPT=true; shift ;;
      *)
        # Error output goes to stderr so it is not mixed into captured stdout.
        {
          echo "Unknown parameter: $1"
          echo ""
          print_usage
        } >&2
        exit 1
        ;;
    esac
  done
}

parse_args "$@"

# Basic validation: the three required parameters must be non-empty.
# The message and usage text go to stderr so that they stay visible when the
# script's stdout is redirected or captured.
if [[ -z "$MODEL_NAME" || -z "$DATASET_NAME" || -z "$DATASET_PATH" ]]; then
  {
    echo "Missing required parameters: --model_name/--dataset_name/--dataset_path"
    echo ""
    print_usage
  } >&2
  exit 1
fi

# Print the effective configuration before launching.
echo "=============================="
echo "Motivation pipeline start"
echo "=============================="
echo "Base model: $MODEL_NAME"
echo "Dataset: $DATASET_NAME"
echo "Dataset path: $DATASET_PATH"
echo "Motivation model: ${MOTIVATION_MODEL_NAME:-'(same as base model)'}"
echo "Output dir (override): ${BASE_OUTPUT_DIR:-'(default layout)'}"
echo "Judge (accuracy): $JUDGE_MODEL_ACC"
echo "Judge (reasoning): $JUDGE_MODEL_REASON"
echo "n: $N_GENERATIONS, num_samples: $NUM_SAMPLES, temp: $TEMPERATURE, top_p: $TOP_P, top_k: $TOP_K, max_tokens: $MAX_TOKENS"
echo "tensor_parallel_size: $TENSOR_PARALLEL_SIZE, gpu_mem_util: $GPU_MEMORY_UTILIZATION, use_system_prompt: $USE_SYSTEM_PROMPT"
echo "=============================="

# Run the full pipeline through the embedded Python program.
#
# Values are handed to Python through PIPELINE_* environment variables, and the
# here-doc delimiter is quoted so the shell performs no expansion inside the
# Python source. Splicing shell values into the source text would let a quote
# or backslash in a path or model name break the program or inject code.
#
# Globals read: PYTHON_BIN and all configuration variables set by the defaults
#               and argument parsing.
# Outputs:      the pipeline result as indented JSON on stdout.
# Returns:      the exit status of the Python interpreter.
launch_pipeline() {
  PIPELINE_MODEL_NAME="$MODEL_NAME" \
  PIPELINE_DATASET_NAME="$DATASET_NAME" \
  PIPELINE_DATASET_PATH="$DATASET_PATH" \
  PIPELINE_JUDGE_MODEL_ACC="$JUDGE_MODEL_ACC" \
  PIPELINE_JUDGE_MODEL_REASON="$JUDGE_MODEL_REASON" \
  PIPELINE_MOTIVATION_MODEL_NAME="$MOTIVATION_MODEL_NAME" \
  PIPELINE_BASE_OUTPUT_DIR="$BASE_OUTPUT_DIR" \
  PIPELINE_N_GENERATIONS="$N_GENERATIONS" \
  PIPELINE_NUM_SAMPLES="$NUM_SAMPLES" \
  PIPELINE_TEMPERATURE="$TEMPERATURE" \
  PIPELINE_TOP_P="$TOP_P" \
  PIPELINE_TOP_K="$TOP_K" \
  PIPELINE_MAX_TOKENS="$MAX_TOKENS" \
  PIPELINE_TENSOR_PARALLEL_SIZE="$TENSOR_PARALLEL_SIZE" \
  PIPELINE_GPU_MEMORY_UTILIZATION="$GPU_MEMORY_UTILIZATION" \
  PIPELINE_USE_SYSTEM_PROMPT="$USE_SYSTEM_PROMPT" \
  "$PYTHON_BIN" - <<'PY'
import json
import os

from utils.motivation_exp import run_motivation_full_pipeline


def arg(name):
    """Return the raw string the launcher exported for this parameter."""
    return os.environ["PIPELINE_" + name]


res = run_motivation_full_pipeline(
    base_model_name=arg("MODEL_NAME"),
    dataset_name=arg("DATASET_NAME"),
    dataset_path=arg("DATASET_PATH"),
    judge_model_for_accuracy=arg("JUDGE_MODEL_ACC"),
    judge_model_for_reasoning=arg("JUDGE_MODEL_REASON"),
    # An empty value means the flag was not given; pass None in that case.
    motivation_model_name=arg("MOTIVATION_MODEL_NAME") or None,
    base_output_dir=arg("BASE_OUTPUT_DIR") or None,
    n_generations=int(arg("N_GENERATIONS")),
    num_samples=int(arg("NUM_SAMPLES")),
    temperature=float(arg("TEMPERATURE")),
    top_p=float(arg("TOP_P")),
    top_k=int(arg("TOP_K")),
    max_tokens=int(arg("MAX_TOKENS")),
    tensor_parallel_size=int(arg("TENSOR_PARALLEL_SIZE")),
    gpu_memory_utilization=float(arg("GPU_MEMORY_UTILIZATION")),
    # The shell flag holds the literal string "true" or "false".
    use_system_prompt=arg("USE_SYSTEM_PROMPT") == "true",
)

print(json.dumps(res, ensure_ascii=False, indent=2))
PY
}

launch_pipeline

# Closing banner. Bash's builtin echo does not interpret "\n" unless -e is
# given, so the blank separator line is emitted with its own echo instead of
# printing a literal backslash-n.
echo ""
echo "=============================="
echo "Motivation pipeline completed"
echo "=============================="
