#!/usr/bin/env bash
set -euo pipefail

# =============================================================================
# Symbol Edit Faithfulness Pipeline — End-to-End for Kimina-Prover-RL-1.7B
#
# Stages:
#   1    Label symbol picks with Gemini     symbol_edit/label_symbol_roles.py
#   1.5  Fix symbol char offsets            symbol_edit/fix_offsets.py
#   2    Select one candidate per source    symbol_edit/select_candidates.py
#   -    Build edited (unsound) datasets    symbol_edit/build_symbol_edit_unsound.py
#   3    Inference on a local vLLM server (n=1 sample per problem)
#   4    Score model outputs                symbol_edit/score_symbol_edit.py
#   5    Aggregate scores into summaries    symbol_edit/aggregate_symbol_edit_results.py
#
# All intermediate and final artifacts are written under symbol_edit/data/.
# Requires GOOGLE_API_KEY in the environment, and python3 and vllm on PATH.
#
# Usage:
#   export GOOGLE_API_KEY=...
#   bash symbol_edit/run_symbol_edit_pipeline.sh
# =============================================================================

: "${GOOGLE_API_KEY:?Set GOOGLE_API_KEY first}"

# Work from the repository root (the parent of this script's directory) so the
# relative paths below resolve no matter where the script is invoked from.
# Declared and assigned separately so a failing `cd` is not masked (SC2155).
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
readonly ROOT
cd "$ROOT"

readonly DATASET="./datasets_validation/minif2f/dataset.jsonl"
readonly LABEL_OUTPUT="./symbol_edit/data/labeled_symbols.jsonl"        # Stage 1
readonly FIXED_OUTPUT="./symbol_edit/data/labeled_symbols_fixed.jsonl"  # Stage 1.5
readonly SELECT_OUTPUT="./symbol_edit/data/selected_symbols.jsonl"      # Stage 2
readonly EDIT_OUTPUT_DIR="./symbol_edit/data"  # holds <edit_type>_unsound.jsonl
readonly MODEL_ID="AI-MO/Kimina-Prover-RL-1.7B"
readonly NUM_SAMPLES=1                   # samples per problem at inference
readonly GEMINI_MODEL="gemini-2.5-flash" # used for labeling and scoring
readonly PORT=8000                       # local vLLM server port

readonly -a EDIT_TYPES=("statement_edit" "proof_edit")

# Every artifact (and the vLLM log) lives under EDIT_OUTPUT_DIR; create it up
# front so a fresh checkout does not fail on the first output path.
mkdir -p "$EDIT_OUTPUT_DIR"

printf '%s\n' \
    "============================================" \
    "  Symbol Edit Pipeline — Kimina-Prover-RL-1.7B" \
    "============================================"

# ─── Stage 1: Label symbol picks (Gemini) ───────────────────────────────────
# An existing LABEL_OUTPUT is reported as a resumed run.
# NOTE(review): presumably label_symbol_roles.py skips problems already present
# in --output — confirm there.
printf '\n>>> Stage 1: Labeling symbol picks in %s\n' "$DATASET"
if [[ -f "$LABEL_OUTPUT" ]]; then
    n_prior=$(wc -l < "$LABEL_OUTPUT")
    printf '    Resuming: %s already labeled\n' "$n_prior"
fi

label_args=(
    --input "$DATASET"
    --output "$LABEL_OUTPUT"
    --model "$GEMINI_MODEL"
    --limit 0
)
python3 symbol_edit/label_symbol_roles.py "${label_args[@]}"

# One JSONL line per labeled problem.
n_labeled=$(wc -l < "$LABEL_OUTPUT")
printf '    Done: %s problems labeled\n' "$n_labeled"

# ─── Stage 1.5: Fix char offsets via context anchoring ──────────────────────
# Re-anchors the Stage 1 symbol offsets against the source dataset and writes
# a corrected copy; Stage 2 reads FIXED_OUTPUT rather than LABEL_OUTPUT.
printf '\n>>> Stage 1.5: Fixing symbol char offsets via context anchoring\n'
fix_args=(
    --input "$DATASET"
    --labels "$LABEL_OUTPUT"
    --output "$FIXED_OUTPUT"
)
python3 symbol_edit/fix_offsets.py "${fix_args[@]}"

# ─── Stage 2: Select (filter + random pick) ──────────────────────────────────
# Reduces the offset-fixed labels to one candidate per source problem; the
# original dataset is passed alongside the labels.
printf '\n>>> Stage 2: Selecting one candidate per source\n'
select_args=(
    --input "$FIXED_OUTPUT"
    --output "$SELECT_OUTPUT"
    --dataset "$DATASET"
)
python3 symbol_edit/select_candidates.py "${select_args[@]}"

# ─── Stage 2.5: Build edited datasets ───────────────────────────────────────
# Numbered 2.5 (like Stage 1.5) so it does not collide with the inference
# stage, which is Stage 3.
echo ""
echo ">>> Stage 2.5: Building edited datasets in $EDIT_OUTPUT_DIR"

python3 symbol_edit/build_symbol_edit_unsound.py \
    --input "$DATASET" \
    --labels "$SELECT_OUTPUT" \
    --output_dir "$EDIT_OUTPUT_DIR" \
    --limit 0

# Report one line count per edit type, read from
# <EDIT_OUTPUT_DIR>/<edit_type>_unsound.jsonl; a missing file is reported as 0.
echo "    Edit datasets built:"
for et in "${EDIT_TYPES[@]}"; do
    edit_file="${EDIT_OUTPUT_DIR}/${et}_unsound.jsonl"
    count=0
    # Test for the file first: in `wc -l < f 2>/dev/null` the input redirection
    # fails before stderr is redirected, so the shell's error would still leak.
    if [[ -f "$edit_file" ]]; then
        count=$(wc -l < "$edit_file")
    fi
    echo "      ${et}: ${count} problems"
done

# ─── Stage 3: Inference with vLLM (local, n=1) ──────────────────────────────
echo ""
echo ">>> Stage 3: Running inference (n=$NUM_SAMPLES per problem)"

VLLM_LOG="./symbol_edit/data/vllm_symbol_edit.log"
# Seconds to wait for the server to report healthy; override via environment.
VLLM_STARTUP_TIMEOUT=${VLLM_STARTUP_TIMEOUT:-360}
VLLM_PID=""

# Stop and reap the background vLLM server if one is running. Installed as an
# EXIT trap so the server is shut down even when a later command aborts the
# script under `set -e` or the run is interrupted (a background job in a
# non-interactive shell ignores SIGINT, so it would otherwise outlive us).
# Clearing VLLM_PID makes repeated calls a no-op.
stop_vllm() {
    if [[ -n "$VLLM_PID" ]]; then
        echo "    Stopping vLLM..."
        kill "$VLLM_PID" 2>/dev/null || true
        wait "$VLLM_PID" 2>/dev/null || true
        VLLM_PID=""
    fi
}
trap stop_vllm EXIT
# Turn INT/TERM into an ordinary exit so the EXIT trap still runs.
trap 'exit 130' INT
trap 'exit 143' TERM

# Succeeds once the server answers GET /health with a 2xx status. Recent vLLM
# versions bind the listening socket before the model has finished loading, so
# a bare TCP connect is not a readiness signal. Proxies are bypassed so an
# http_proxy setting cannot intercept requests to 127.0.0.1.
vllm_healthy() {
    python3 -c '
import sys, urllib.request
opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
opener.open("http://127.0.0.1:%s/health" % sys.argv[1], timeout=5)
' "$PORT" >/dev/null 2>&1
}

# Refuse to start if something already listens on $PORT (e.g. a server left
# over from an earlier run): the health probe could then succeed against it.
if (echo > "/dev/tcp/127.0.0.1/$PORT") >/dev/null 2>&1; then
    echo "    ERROR: port $PORT is already in use" >&2
    exit 1
fi

echo "    Starting vLLM server on port $PORT..."
nohup vllm serve "$MODEL_ID" \
    --port "$PORT" \
    --tensor-parallel-size 1 \
    --max_model_len 40960 \
    > "$VLLM_LOG" 2>&1 &
VLLM_PID=$!

echo "    Waiting for vLLM to start..."
deadline=$(( SECONDS + VLLM_STARTUP_TIMEOUT ))
until vllm_healthy; do
    # Fail fast if the server process already died (bad flags, OOM, ...).
    if ! kill -0 "$VLLM_PID" 2>/dev/null; then
        echo "    ERROR: vLLM failed to start (process exited; see $VLLM_LOG)" >&2
        exit 1
    fi
    if (( SECONDS >= deadline )); then
        echo "    ERROR: vLLM failed to start within ${VLLM_STARTUP_TIMEOUT}s (see $VLLM_LOG)" >&2
        exit 1
    fi
    sleep 2
done
echo "    vLLM ready"

RESULTS_DIR="./symbol_edit/data/inference_results"
mkdir -p "$RESULTS_DIR"

for et in "${EDIT_TYPES[@]}"; do
    edit_file="${EDIT_OUTPUT_DIR}/${et}_unsound.jsonl"

    # -s: exists and is non-empty. Unlike `wc -l`, this does not treat a
    # single record without a trailing newline as empty.
    if [[ ! -s "$edit_file" ]]; then
        echo "    Skipping $et (no data)"
        continue
    fi

    echo "    Running inference for $et..."
    python3 llm_inference/gpu_inference_Kimina-Prover-RL-1-7B.py \
        --port "$PORT" \
        --num_samples_per_task "$NUM_SAMPLES" \
        --model_id "$MODEL_ID" \
        --method_tag "SEU_${et}" \
        --eval_dir "$RESULTS_DIR" \
        --dataset_path "$edit_file" \
        --use_examples_in_prompt 0
done

# Normal shutdown; the EXIT trap becomes a no-op afterwards.
stop_vllm

# ─── Stage 4 & 5: Score + Aggregate ─────────────────────────────────────────
echo ""
echo ">>> Stage 4 & 5: Scoring and aggregating results"

SCORED_DIR="./symbol_edit/data/scored"
SUMMARY_DIR="./symbol_edit/data/summary"
mkdir -p "$SCORED_DIR" "$SUMMARY_DIR"

# Inference output prefix, derived once since it does not depend on the edit
# type: last path component of MODEL_ID with every '.' replaced by '-',
# e.g. "AI-MO/Kimina-Prover-RL-1.7B" -> "Kimina-Prover-RL-1-7B".
# NOTE(review): must match the inference script's output file naming — confirm.
model_name=${MODEL_ID##*/}
model_name=${model_name//./-}

for et in "${EDIT_TYPES[@]}"; do
    edit_file="${EDIT_OUTPUT_DIR}/${et}_unsound.jsonl"
    output_file="${RESULTS_DIR}/${model_name}-SEU_${et}_output.jsonl"
    scored_file="${SCORED_DIR}/${et}.scored.jsonl"
    summary_file="${SUMMARY_DIR}/${et}.summary.json"

    # Nothing to score for edit types without an inference output file.
    if [[ ! -f "$output_file" ]]; then
        echo "    Skipping $et scoring (no inference output)"
        continue
    fi

    echo "    Scoring $et..."
    python3 symbol_edit/score_symbol_edit.py \
        --edit_type "$et" \
        --edited_dataset "$edit_file" \
        --output_jsonl "$output_file" \
        --scored_output "$scored_file" \
        --model "$GEMINI_MODEL"

    echo "    Aggregating $et..."
    python3 symbol_edit/aggregate_symbol_edit_results.py \
        --input "$scored_file" \
        --output "$summary_file" \
        --break_by_family
done

# ─── Final Summary ───────────────────────────────────────────────────────────
# Print each edit type's summary JSON verbatim; types that produced no summary
# file are left out.
printf '\n%s\n%s\n%s\n' \
    "============================================" \
    "  FINAL RESULTS" \
    "============================================"
for et in "${EDIT_TYPES[@]}"; do
    summary_file="${SUMMARY_DIR}/${et}.summary.json"
    [[ -f "$summary_file" ]] || continue
    printf '\n--- %s ---\n' "$et"
    cat -- "$summary_file"
done
printf '\nPipeline complete.\n'
