#!/usr/bin/env bash
set -euo pipefail

# Policy sensitivity sweep for continuous_decay at a fixed step count N
# (default 20; override with N=...).
#
# For each policy in POLICIES, this runs the sharded stage-2 launcher with
# scripts/run_stage2_joint_cbf_continuous_decay.py as the runner. Each run
# writes ${OUT_DIR}/cont_N${N}_policy_<policy>.jsonl, and existing outputs
# are replaced because the launcher is given --overwrite.
#
# Every setting below can be overridden from the environment, e.g.
#   N=10 POLICIES="score_max peak_u" GPUS=0,1 NUM_SHARDS=2 bash <this script>
#
# Script, config and prompt paths are relative, so run from the directory
# that contains scripts/ and configs/ (or override the paths via env).
# Because of `set -e`, the sweep stops at the first run that fails.

PROMPTS_FILE=${PROMPTS_FILE:-outputs/stage2/prompts/wjb_adv_good_intervened.json}
OUT_DIR=${OUT_DIR:-outputs/stage2/policy_sensitivity}
BASE_MODEL=${BASE_MODEL:-Qwen/Qwen2.5-7B-Instruct}
DETECTORS=${DETECTORS:-configs/detectors_stage1_qwenhx_deficit.json}
CONTROLLER=${CONTROLLER:-configs/controller_stage2_star_qwenhx_stride2_epsm0p5_es2.json}
SELECTED_DIMS=${SELECTED_DIMS:-configs/selected_dims_12.json}
SCORER_DIR=${SCORER_DIR:-outputs/aegis_scorer/ckpt_aegis_scorer_v2_20260117_192621}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-256}
MAX_PROMPTS=${MAX_PROMPTS:-80}
SEED=${SEED:-2026}
TEMPERATURE=${TEMPERATURE:-1.0}
TOP_P=${TOP_P:-1.0}
TOP_K=${TOP_K:-0}
NUM_SHARDS=${NUM_SHARDS:-8}
GPUS=${GPUS:-0,1,2,3,4,5,6,7}
N=${N:-20}
# Space-separated list of policy names to sweep.
POLICIES=${POLICIES:-score_max minS mid_window peak_u}
# Interpreter used to start the sharded launcher.
PYTHON=${PYTHON:-python}
# Passed through unchanged as --forcing_enabled / --slack_enabled.
FORCING_ENABLED=${FORCING_ENABLED:-1}
SLACK_ENABLED=${SLACK_ENABLED:-1}

readonly SHARDED_LAUNCHER=scripts/run_stage2_sharded_v1.py
readonly RUNNER=scripts/run_stage2_joint_cbf_continuous_decay.py

die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Fail fast on missing inputs before any GPU job is launched; otherwise a
# typo or wrong cwd is only discovered after a sharded run has started.
for required in "${SHARDED_LAUNCHER}" "${RUNNER}" "${PROMPTS_FILE}" \
  "${DETECTORS}" "${CONTROLLER}" "${SELECTED_DIMS}" "${SCORER_DIR}"; do
  [[ -e "${required}" ]] \
    || die "missing input: ${required} (run from the repo root or override via env)"
done

read -r -a policies <<<"${POLICIES}"
(( ${#policies[@]} > 0 )) || die "POLICIES is empty"

mkdir -p "${OUT_DIR}"

for policy in "${policies[@]}"; do
  out_path="${OUT_DIR}/cont_N${N}_policy_${policy}.jsonl"
  printf '[run] N=%s policy=%s -> %s\n' "${N}" "${policy}" "${out_path}"
  # Flags before `--` go to the sharded launcher. The flags after it are
  # presumably forwarded verbatim to --runner (confirm in
  # run_stage2_sharded_v1.py).
  "${PYTHON}" "${SHARDED_LAUNCHER}" \
    --runner "${RUNNER}" \
    --prompts_file "${PROMPTS_FILE}" \
    --output_path "${out_path}" \
    --num_shards "${NUM_SHARDS}" \
    --gpus "${GPUS}" \
    --overwrite \
    --detectors_config "${DETECTORS}" \
    --controller_config "${CONTROLLER}" \
    --selected_dims "${SELECTED_DIMS}" \
    --scorer_dir "${SCORER_DIR}" \
    --base_model "${BASE_MODEL}" \
    --max_prompts "${MAX_PROMPTS}" \
    --max_new_tokens "${MAX_NEW_TOKENS}" \
    --seed "${SEED}" \
    -- \
    --continuous_steps "${N}" \
    --forcing_enabled "${FORCING_ENABLED}" \
    --slack_enabled "${SLACK_ENABLED}" \
    --policy_name "${policy}" \
    --temperature "${TEMPERATURE}" \
    --top_p "${TOP_P}" \
    --top_k "${TOP_K}"
done
