#!/usr/bin/env bash
set -euo pipefail

# Module B: forcing/slack ablation for continuous_decay (N=20 by default).
#
# Every setting below can be overridden from the environment; the value shown
# is applied only when the variable is unset or empty.

# Input prompts and output directory.
: "${PROMPTS_FILE:=outputs/stage2/prompts/wjb_adv_good_intervened.json}"
: "${OUT_DIR:=outputs/stage2/moduleB_ablation}"

# Model, configuration files and scorer checkpoint handed to the launcher.
: "${BASE_MODEL:=Qwen/Qwen2.5-7B-Instruct}"
: "${DETECTORS:=configs/detectors_stage1_qwenhx_deficit.json}"
: "${CONTROLLER:=configs/controller_stage2_star_qwenhx_stride2_epsm0p5_es2.json}"
: "${SELECTED_DIMS:=configs/selected_dims_12.json}"
: "${SCORER_DIR:=outputs/aegis_scorer/ckpt_aegis_scorer_v2_20260117_192621}"

# Generation settings.
: "${MAX_NEW_TOKENS:=256}"
: "${SEED:=2026}"
: "${TEMPERATURE:=1.0}"
: "${TOP_P:=1.0}"
: "${TOP_K:=0}"

# Sharding layout.
: "${NUM_SHARDS:=8}"
: "${GPUS:=0,1,2,3,4,5,6,7}"

# N is passed as --continuous_steps and is embedded in each output file name.
: "${N:=20}"

# PROGRESS=1 runs the launcher in the background and prints a line count
# every POLL_SECS seconds; any other value runs it in the foreground.
: "${PROGRESS:=0}"
: "${POLL_SECS:=30}"

mkdir -p "${OUT_DIR}"

#######################################
# Sum the line counts of all output_shard_*.jsonl files directly under a
# directory.
#
# Each file is counted on its own, so wc's "total" summary row (printed when
# wc is given several files) cannot be added a second time, and a directory
# with no shard files yet yields 0 instead of a failing pipeline that would
# abort the script under `set -euo pipefail`.
#
# Arguments: $1 - directory to scan
# Outputs:   the total line count on stdout (0 if nothing matches)
# Returns:   0
#######################################
count_shard_lines() {
  local shard_dir=$1
  local total=0 n f
  for f in "${shard_dir}"/output_shard_*.jsonl; do
    # An unmatched glob stays literal; skip it (and anything not a file).
    [[ -f "${f}" ]] || continue
    # A file may vanish between the check and the read; ignore that shard.
    n=$(wc -l 2>/dev/null < "${f}") || continue
    total=$((total + n))
  done
  printf '%s\n' "${total}"
}

# Run the four forcing/slack combinations in the order (1,1) (1,0) (0,1) (0,0).
for F in 1 0; do
  for S in 1 0; do
    OUT_PATH="${OUT_DIR}/cont_N${N}_f${F}_s${S}.jsonl"
    echo "[run] N=${N} forcing=${F} slack=${S} -> ${OUT_PATH}"

    # Build the launcher command once so both execution modes run exactly
    # the same thing. The flags after the bare `--` are presumably forwarded
    # to the runner script -- confirm against run_stage2_sharded_v1.py.
    run_cmd=(
      python scripts/run_stage2_sharded_v1.py
      --runner scripts/run_stage2_joint_cbf_continuous_decay.py
      --prompts_file "${PROMPTS_FILE}"
      --output_path "${OUT_PATH}"
      --num_shards "${NUM_SHARDS}"
      --gpus "${GPUS}"
      --summary_script ""
      --overwrite
      --detectors_config "${DETECTORS}"
      --controller_config "${CONTROLLER}"
      --selected_dims "${SELECTED_DIMS}"
      --scorer_dir "${SCORER_DIR}"
      --base_model "${BASE_MODEL}"
      --max_new_tokens "${MAX_NEW_TOKENS}"
      --seed "${SEED}"
      --continuous_steps "${N}"
      --
      --temperature "${TEMPERATURE}"
      --top_p "${TOP_P}"
      --top_k "${TOP_K}"
      --forcing_enabled "${F}"
      --slack_enabled "${S}"
    )

    if [[ "${PROGRESS}" == "1" ]]; then
      # Background the launcher and poll its shard directory for progress.
      "${run_cmd[@]}" &
      pid=$!
      shard_root="$(dirname "${OUT_PATH}")/$(basename "${OUT_PATH}" .jsonl)_shards"
      while kill -0 "${pid}" 2>/dev/null; do
        if [[ -d "${shard_root}" ]]; then
          lines=$(count_shard_lines "${shard_root}")
          echo "[progress] ${OUT_PATH} lines=${lines}"
        fi
        sleep "${POLL_SECS}"
      done
      # Reap the job so a launcher failure stops the sweep under `set -e`,
      # matching the foreground branch.
      wait "${pid}"
    else
      "${run_cmd[@]}"
    fi
  done
done
