#!/usr/bin/env bash
set -euo pipefail

# Appendix: temperature / top-p decoding sweep for the continuous_decay runner (N defaults to 20; override via the N environment variable).

# ---------------------------------------------------------------------------
# Tunables. Each one keeps a value supplied by the caller's environment and
# only falls back to the default below when it is unset or empty.
# ---------------------------------------------------------------------------

# Input prompts and the directory that receives one .jsonl per grid point.
: "${PROMPTS_FILE:=outputs/stage2/prompts/wjb_adv_good_intervened.json}"
: "${OUTDIR:=outputs/stage2/appendix_decode_sweep_simple}"

# Model, configuration files and scorer checkpoint handed to the runner.
: "${BASE_MODEL:=Qwen/Qwen2.5-7B-Instruct}"
: "${DETECTORS:=configs/detectors_stage1_qwenhx_deficit.json}"
: "${CONTROLLER:=configs/controller_stage2_star_qwenhx_stride2_epsm0p5_es2.json}"
: "${SELECTED_DIMS:=configs/selected_dims_12.json}"
: "${SCORER_DIR:=outputs/aegis_scorer/ckpt_aegis_scorer_v2_20260117_192621}"

# Generation settings held fixed across the sweep (temperature and top-p are
# the swept quantities and are therefore not set here).
: "${MAX_NEW_TOKENS:=128}"
: "${MAX_PROMPTS:=80}"
: "${SEED:=0}"
: "${TOP_K:=0}"
: "${N:=20}"   # forwarded to the runner as --continuous_steps

# Launch mode: USE_SHARDED=1 goes through scripts/run_stage2_sharded_v1.py
# with NUM_SHARDS / GPUS; any other value runs the runner script directly.
: "${USE_SHARDED:=0}"
: "${NUM_SHARDS:=8}"
: "${GPUS:=0,1,2,3,4,5,6,7}"

# Make sure the sweep's output directory exists before any run writes to it.
mkdir -p "${OUTDIR}"

# Decoding grid: every temperature in T_VALUES is paired with every top-p in
# P_VALUES, giving one run (and one output file) per combination.
T_VALUES=("0.7" "1.0" "1.3")
P_VALUES=("0.9" "1.0")

# Flags that are identical for every grid point and for both launch modes.
# Built once here so the two python invocations below cannot drift apart.
COMMON_ARGS=(
  --detectors_config "${DETECTORS}"
  --controller_config "${CONTROLLER}"
  --selected_dims "${SELECTED_DIMS}"
  --scorer_dir "${SCORER_DIR}"
  --base_model "${BASE_MODEL}"
  --max_prompts "${MAX_PROMPTS}"
  --max_new_tokens "${MAX_NEW_TOKENS}"
  --seed "${SEED}"
)

for T in "${T_VALUES[@]}"; do
  for P in "${P_VALUES[@]}"; do
    # File-name-safe tags: the first "." becomes "p" (0.7 -> 0p7).
    T_TAG=${T/./p}
    P_TAG=${P/./p}
    # Use the actual N in the name so a run with a non-default N is not
    # mislabelled as (and does not overwrite) the N=20 results.
    OUT_PATH="${OUTDIR}/cont_N${N}_T${T_TAG}_P${P_TAG}.jsonl"
    printf '[run] T=%s P=%s -> %s\n' "${T}" "${P}" "${OUT_PATH}"

    if [[ "${USE_SHARDED}" == "1" ]]; then
      # Sharded launch: the sampling flags are placed after "--", separate
      # from the sharding wrapper's own options.
      python scripts/run_stage2_sharded_v1.py \
        --runner scripts/run_stage2_joint_cbf_continuous_decay.py \
        --prompts_file "${PROMPTS_FILE}" \
        --output_path "${OUT_PATH}" \
        --num_shards "${NUM_SHARDS}" \
        --gpus "${GPUS}" \
        --summary_script "" \
        --overwrite \
        "${COMMON_ARGS[@]}" \
        --continuous_steps "${N}" \
        -- --temperature "${T}" --top_p "${P}" --top_k "${TOP_K}"

      # Fallback metadata: only written when no <output>.run_args.json exists
      # after the sharded run, so an existing file is never clobbered.
      RUN_ARGS_PATH="${OUT_PATH%.jsonl}.run_args.json"
      if [[ ! -f "${RUN_ARGS_PATH}" ]]; then
        # Values travel through the environment and the heredoc delimiter is
        # quoted, so nothing is spliced into the Python source: quotes or
        # backslashes in a path can neither break nor alter the program.
        PROMPTS_FILE="${PROMPTS_FILE}" \
        OUT_PATH="${OUT_PATH}" \
        RUN_ARGS_PATH="${RUN_ARGS_PATH}" \
        N="${N}" \
        T="${T}" \
        P="${P}" \
        TOP_K="${TOP_K}" \
        SEED="${SEED}" \
        MAX_NEW_TOKENS="${MAX_NEW_TOKENS}" \
        MAX_PROMPTS="${MAX_PROMPTS}" \
        BASE_MODEL="${BASE_MODEL}" \
        DETECTORS="${DETECTORS}" \
        CONTROLLER="${CONTROLLER}" \
        SELECTED_DIMS="${SELECTED_DIMS}" \
        SCORER_DIR="${SCORER_DIR}" \
        NUM_SHARDS="${NUM_SHARDS}" \
        GPUS="${GPUS}" \
        python - <<'PY'
import json
import os

env = os.environ
payload = {
    "prompts_file": env["PROMPTS_FILE"],
    "output_path": env["OUT_PATH"],
    "continuous_steps": int(env["N"]),
    "temperature": float(env["T"]),
    "top_p": float(env["P"]),
    "top_k": int(env["TOP_K"]),
    "seed": int(env["SEED"]),
    "max_new_tokens": int(env["MAX_NEW_TOKENS"]),
    "max_prompts": int(env["MAX_PROMPTS"]),
    "base_model": env["BASE_MODEL"],
    "detectors_config": env["DETECTORS"],
    "controller_config": env["CONTROLLER"],
    "selected_dims": env["SELECTED_DIMS"],
    "scorer_dir": env["SCORER_DIR"],
    "use_sharded": True,
    "num_shards": int(env["NUM_SHARDS"]),
    "gpus": env["GPUS"],
}
with open(env["RUN_ARGS_PATH"], "w", encoding="utf-8") as f:
    json.dump(payload, f, ensure_ascii=False, indent=2)
print("wrote", env["RUN_ARGS_PATH"])
PY
      fi
    else
      # Direct launch: invoke the runner script itself with all flags.
      python scripts/run_stage2_joint_cbf_continuous_decay.py \
        --prompts_file "${PROMPTS_FILE}" \
        --output_path "${OUT_PATH}" \
        "${COMMON_ARGS[@]}" \
        --temperature "${T}" \
        --top_p "${P}" \
        --top_k "${TOP_K}" \
        --continuous_steps "${N}"
    fi
  done
done
