#!/usr/bin/env bash
set -euo pipefail

# Train eworm v4_val starting from the bundled seed pair under ./seeds/.
#
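# By default this script does NOT pass --resume to train.py, so training starts
# at epoch 0; add --resume (optionally with --resume-from DIR) to continue from
# an existing checkpoint instead.
# Outputs go to a fresh output directory by default, and the two global npy
# files are never overwritten.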
#
# Usage:
#   ./run.sh
#   ./run.sh --epochs 5 --tstop-ms 200
#   ./run.sh --out ./runs/my_run --suffix from0125 --epochs 50
#   ./run.sh --heliox-device gpu --heliox-permute-type 3
#   ./run.sh --resume
#   ./run.sh --replay 1
#   ./run.sh --k-mul 1 --k-len 160
#   ./run.sh --k-mul 10 --k-max-t-ms 120
#   ./run.sh --clip-strategy 1 --clip-threshold 1e6
#   ./run.sh --clip-strategy 2 --clip-every 10
#
# Notes:
# - Requires: <base-trial>/000_circuit_search_config.json and
#   <base-trial>/sample_#0_circuit_old.pkl, where <base-trial> is ./data/trial10
#   by default (override with --base-trial DIR).

# Resolve the (physical) directory that holds this script and run from there,
# so relative paths such as ./data, ./runs and train.py resolve the same way
# regardless of the caller's working directory.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd -P)"
cd "$SCRIPT_DIR"

# Allow overriding python executable for portability.
PYTHON="${PYTHON:-python3}"

# NEURON Python bindings are not always installed site-wide. If a local NEURON
# install exists (e.g. $HOME/nrn/install), prepend it to PYTHONPATH.
# The ':' separator is only added when PYTHONPATH is already non-empty: a
# trailing empty entry would make Python also search the current directory.
NRN_PYTHON_LIB="${NRN_PYTHON_LIB:-$HOME/nrn/install/lib/python}"
if [[ -d "$NRN_PYTHON_LIB" ]]; then
  export PYTHONPATH="${NRN_PYTHON_LIB}${PYTHONPATH:+:$PYTHONPATH}"
fi

# --- Run identity and schedule (each overridable by a CLI flag below) ---
PREFIX=eworm
SUFFIX=from0125
EPOCHS=50
TSTOP_MS=5000
BASE_TRIAL_DIR=./data/trial10
OUT_DIR=            # empty => a timestamped ./runs/... directory is generated later
RESUME_FROM_DIR=

# --- K sizing, seeded from the EWORM_* environment when present ---
K_MUL=${EWORM_K_MUL:-5}
K_LEN=${EWORM_K_LEN:-}
K_MAX_T_MS=${EWORM_K_MAX_T_MS:-}

# --- HelioX Python library ---
# A non-empty HELIOX_PYTHON_LIB from the environment wins. Otherwise fall back
# to the bundle layout, where this demo folder ships next to ../heliox/.
if [[ -z "${HELIOX_PYTHON_LIB:-}" ]]; then
  CAND_HELIOX_LIB=${SCRIPT_DIR}/../heliox/python_lib
  [[ ! -d "$CAND_HELIOX_LIB" ]] || HELIOX_PYTHON_LIB=$CAND_HELIOX_LIB
fi
# Guarantee the variable is defined (possibly empty) for later `set -u` use.
HELIOX_PYTHON_LIB=${HELIOX_PYTHON_LIB:-}
# --- HelioX backend (training is HelioX-only; NEURON is frontend/export only) ---
HELIOX_DEVICE=${HELIOX_DEVICE:-gpu}                      # cpu|gpu
HELIOX_PERMUTE_TYPE=${HELIOX_PERMUTE_TYPE:-3}            # matches HelioX defaults (GPU permute3)
EWORM_HELIOX_EXPORT_PATH=${EWORM_HELIOX_EXPORT_PATH:-}   # optional export-path override

# --- Replay / reporting switches (0|1); seeded from EWORM_* env vars ---
REPLAY=${EWORM_REPLAY:-0}
REPLAY_USE_VECPLAY=1                                     # fixed: not read from the environment
REPLAY_STREAMING=${EWORM_REPLAY_STREAMING:-0}
REPLAY_CACHE_SIGNALS=${EWORM_REPLAY_CACHE_SIGNALS:-0}
REPORT_GPU_MEM=${EWORM_REPORT_GPU_MEM:-0}
PRINT_EPOCH_TIME=${EWORM_PRINT_EPOCH_TIME:-0}

# --- Backend replay gradient clipping (HelioX-only) ---
# CLIP_STRATEGY:
#   0 = disabled (NOT recommended; gradients can explode)
#   1 = check on every LR tick (most stable; default)
#   2 = check every CLIP_EVERY ticks (faster; see --clip-every)
CLIP_STRATEGY=${EWORM_REPLAY_CLIP_STRATEGY:-1}
CLIP_EVERY=${EWORM_REPLAY_CLIP_CHECK_EVERY:-1}
CLIP_THRESHOLD=${EWORM_REPLAY_GRAD_L2NORM_THRESHOLD:-1e6}

# --- Alpha scales and plateau LR schedule ---
ALPHA_W_SCALE=${EWORM_ALPHA_W_SCALE:-1.0}
ALPHA_X_SCALE=${EWORM_ALPHA_X_SCALE:-1.0}
PLATEAU_LR_MULT=${EWORM_PLATEAU_LR_MULTIPLIER:-0.3}
PLATEAU_PATIENCE=${EWORM_PLATEAU_PATIENCE_EPOCHS:-5}
PLATEAU_RESET_ADAM=${EWORM_PLATEAU_RESET_ADAM:-none}     # none|w|x|both

# --- Parameter freezing, x L2 penalty and x update cadence ---
FREEZE_W=${EWORM_FREEZE_W:-0}
FREEZE_X=${EWORM_FREEZE_X:-0}
X_L2_COEF=${EWORM_X_L2_COEF:-1e-1}
X_UPDATE_EVERY=${EWORM_X_UPDATE_EVERY:-1}
X_UPDATE_BURST=${EWORM_X_UPDATE_BURST:-1}
X_UPDATE_OFFSET=${EWORM_X_UPDATE_OFFSET:-0}

# --- Numerics / streaming knobs (names mirror their EWORM_* variables,
#     including the historical "PERCISE" spelling) ---
PERCISE=${EWORM_PERCISE:-1}
FAST_DVDW=${EWORM_FAST_DVDW:-1}
GLOBAL_STREAM=${EWORM_GLOBAL_STREAM:-1}
GLOBAL_CHUNK=${EWORM_GLOBAL_CHUNK:-256}

# --- CLI-only switches (never read from the environment) ---
PRINT_ENV=0
DEBUG=0
DEBUG_BLOCKING=0
CAPTURE_PROGRESS_EVERY=
REPLAY_PROGRESS_EVERY=
RESUME=0
RESUME_EPOCHS=
RESUME_START_EPOCH=
RESUME_RESTORE_OPT=0
RESUME_RESTORE_VMIN=0
RESUME_RESTORE_RUN_BEST=0
RESUME_RESET_ADAM=none                                   # none|w|x|both
RESUME_SET_ALPHA_MULT=

# Report a fatal error as "ERROR: <message>" on stderr and exit with status 1.
die() {
  printf 'ERROR: %s\n' "$*" >&2
  exit 1
}

# Verify that $PYTHON can import every required module and that
# HELIOX_PYTHON_LIB names an existing directory. On any failure the embedded
# script prints an ERROR line (plus hint) to stderr and exits with status 2;
# a missing opt_einsum only produces a warning.
python_preflight() {
  HELIOX_PYTHON_LIB="$HELIOX_PYTHON_LIB" "${PYTHON}" - <<'PY'
import importlib
import os
import sys

# (module, install hint), checked in this order.
# NOTE: import order matters on some systems: importing torch before neuron can segfault.
REQUIRED = (
    ("numpy", "pip install numpy  (or: pip install -r requirements.txt)"),
    ("neuron", "pip install neuron (or install NEURON + Python bindings)"),
    ("torch", "pip install torch  (or: pip install -r requirements.txt)"),
    ("tqdm", "pip install tqdm  (or: pip install -r requirements.txt)"),
    ("scipy", "pip install scipy  (or: pip install -r requirements.txt)"),
)


def fail(*lines):
    """Print each line to stderr, then exit with status 2."""
    for line in lines:
        print(line, file=sys.stderr)
    sys.exit(2)


for mod, hint in REQUIRED:
    try:
        importlib.import_module(mod)
    except Exception as exc:
        fail(f"ERROR: missing python module: {mod} ({exc})", f"  Hint: {hint}")

# Optional accelerator: warn but keep going.
try:
    importlib.import_module("opt_einsum")
except Exception:
    print("WARN: opt_einsum not found; falling back to torch.einsum (may be slower).", file=sys.stderr)
    print("  Install: pip install opt_einsum  (or: pip install -r requirements.txt)", file=sys.stderr)

heliox_lib = os.environ.get("HELIOX_PYTHON_LIB", "").strip()
if not heliox_lib:
    fail("ERROR: HELIOX_PYTHON_LIB is not set.")
if not os.path.isdir(os.path.expanduser(heliox_lib)):
    fail(f"ERROR: HELIOX_PYTHON_LIB not found: {heliox_lib}")
PY
}

# Print $1 as an absolute, normalised path with a leading "~" expanded.
# Uses Python's os.path, so symlinks are NOT resolved and the path need not exist.
abspath() {
  local target="$1"
  "${PYTHON}" -c 'import os, sys; print(os.path.abspath(os.path.expanduser(sys.argv[1])))' "$target"
}

# Print the integer `start_epoch` stored in <out_dir>/ckpt_<prefix>_<suffix>.npz.
# Arguments: $1 out_dir, $2 prefix, $3 suffix.
# A missing file or key surfaces as a Python traceback and a non-zero status.
ckpt_start_epoch() {
  local dir="$1" pfx="$2" sfx="$3"
  "${PYTHON}" - "$dir" "$pfx" "$sfx" <<'PY'
import sys

import numpy as np

out_dir, prefix, suffix = sys.argv[1:4]
ckpt_path = "{}/ckpt_{}_{}.npz".format(out_dir.rstrip("/"), prefix, suffix)
# allow_pickle=True: only load checkpoints you trust (pickled objects can run code).
with np.load(ckpt_path, allow_pickle=True) as archive:
    print(int(archive["start_epoch"].item()))
PY
}

# Print the default output directory: ./runs/<prefix>_<suffix>_<YYYYmmdd_HHMMSS>
# (local time). Arguments: $1 prefix, $2 suffix.
make_default_out_dir() {
  local stamp
  stamp="$(date +%Y%m%d_%H%M%S)"
  printf './runs/%s_%s_%s\n' "$1" "$2" "$stamp"
}

# Print the usage/notes comment block at the top of this file (with the
# leading "# " removed), stopping at the first non-comment line after it.
# The file is located via SCRIPT_DIR because the script has already cd'd
# there, so a relative $0 may no longer resolve.
print_usage() {
  local self="${SCRIPT_DIR}/${BASH_SOURCE[0]##*/}"
  awk 'NR == 1 { next }
       /^#/ { seen = 1; sub(/^# ?/, ""); print; next }
       seen { exit }' "$self"
}

# Parse command-line flags into the globals defined above. Flags that take a
# value abort (via ${2:?...}) when the value is missing.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --heliox-device)
      HELIOX_DEVICE="${2:?missing value for --heliox-device}"
      shift 2
      ;;
    --heliox-permute-type)
      HELIOX_PERMUTE_TYPE="${2:?missing value for --heliox-permute-type}"
      shift 2
      ;;
    --heliox-export-path)
      EWORM_HELIOX_EXPORT_PATH="${2:?missing value for --heliox-export-path}"
      shift 2
      ;;
    --resume)
      RESUME="1"
      shift 1
      ;;
    --resume-epochs)
      RESUME_EPOCHS="${2:?missing value for --resume-epochs (int)}"
      shift 2
      ;;
    --resume-from)
      RESUME_FROM_DIR="${2:?missing value for --resume-from (dir)}"
      shift 2
      ;;
    --resume-start-epoch)
      RESUME_START_EPOCH="${2:?missing value for --resume-start-epoch}"
      shift 2
      ;;
    --resume-restore-opt)
      RESUME_RESTORE_OPT="1"
      shift 1
      ;;
    --resume-restore-vmin)
      RESUME_RESTORE_VMIN="1"
      shift 1
      ;;
    --resume-restore-run-best)
      RESUME_RESTORE_RUN_BEST="1"
      shift 1
      ;;
    --resume-reset-adam)
      RESUME_RESET_ADAM="${2:?missing value for --resume-reset-adam (none|w|x|both)}"
      shift 2
      ;;
    --resume-set-alpha-mult)
      RESUME_SET_ALPHA_MULT="${2:?missing value for --resume-set-alpha-mult}"
      shift 2
      ;;
    --epochs)
      EPOCHS="${2:?missing value for --epochs}"
      shift 2
      ;;
    --tstop-ms)
      TSTOP_MS="${2:?missing value for --tstop-ms}"
      shift 2
      ;;
    --k-mul)
      K_MUL="${2:?missing value for --k-mul (int)}"
      shift 2
      ;;
    --k-len)
      K_LEN="${2:?missing value for --k-len (int)}"
      shift 2
      ;;
    --k-max-t-ms)
      K_MAX_T_MS="${2:?missing value for --k-max-t-ms (float)}"
      shift 2
      ;;
    --prefix)
      PREFIX="${2:?missing value for --prefix}"
      shift 2
      ;;
    --suffix)
      SUFFIX="${2:?missing value for --suffix}"
      shift 2
      ;;
    --out)
      OUT_DIR="${2:?missing value for --out}"
      shift 2
      ;;
    --base-trial)
      BASE_TRIAL_DIR="${2:?missing value for --base-trial}"
      shift 2
      ;;
    --print-env)
      PRINT_ENV="1"
      shift 1
      ;;
    --debug)
      DEBUG="1"
      shift 1
      ;;
    --debug-blocking)
      DEBUG_BLOCKING="1"
      shift 1
      ;;
    --debug-capture-every)
      CAPTURE_PROGRESS_EVERY="${2:?missing value for --debug-capture-every (int)}"
      shift 2
      ;;
    --debug-replay-every)
      REPLAY_PROGRESS_EVERY="${2:?missing value for --debug-replay-every (int)}"
      shift 2
      ;;
    --print-epoch-time)
      PRINT_EPOCH_TIME="1"
      shift 1
      ;;
    --no-print-epoch-time)
      PRINT_EPOCH_TIME="0"
      shift 1
      ;;
    --replay)
      REPLAY="${2:?missing value for --replay (0|1)}"
      shift 2
      ;;
    --replay-streaming)
      REPLAY_STREAMING="${2:?missing value for --replay-streaming (0|1)}"
      shift 2
      ;;
    --replay-cache-signals)
      REPLAY_CACHE_SIGNALS="${2:?missing value for --replay-cache-signals (0|1)}"
      shift 2
      ;;
    --report-gpu-mem)
      REPORT_GPU_MEM="${2:?missing value for --report-gpu-mem (0|1)}"
      shift 2
      ;;
    --clip-strategy)
      CLIP_STRATEGY="${2:?missing value for --clip-strategy (0|1|2)}"
      shift 2
      ;;
    --clip-every)
      CLIP_EVERY="${2:?missing value for --clip-every (int>=1)}"
      shift 2
      ;;
    --clip-threshold)
      CLIP_THRESHOLD="${2:?missing value for --clip-threshold (float)}"
      shift 2
      ;;
    --alpha-w-scale)
      ALPHA_W_SCALE="${2:?missing value for --alpha-w-scale (float)}"
      shift 2
      ;;
    --alpha-x-scale)
      ALPHA_X_SCALE="${2:?missing value for --alpha-x-scale (float)}"
      shift 2
      ;;
    --plateau-lr-mult)
      PLATEAU_LR_MULT="${2:?missing value for --plateau-lr-mult (float)}"
      shift 2
      ;;
    --plateau-patience)
      PLATEAU_PATIENCE="${2:?missing value for --plateau-patience (int)}"
      shift 2
      ;;
    --plateau-reset-adam)
      PLATEAU_RESET_ADAM="${2:?missing value for --plateau-reset-adam (none|w|x|both)}"
      shift 2
      ;;
    --freeze-w)
      FREEZE_W="1"
      shift 1
      ;;
    --freeze-x)
      FREEZE_X="1"
      shift 1
      ;;
    --x-l2)
      X_L2_COEF="${2:?missing value for --x-l2 (float)}"
      shift 2
      ;;
    --x-update-every)
      X_UPDATE_EVERY="${2:?missing value for --x-update-every (int>=1)}"
      shift 2
      ;;
    --x-update-burst)
      X_UPDATE_BURST="${2:?missing value for --x-update-burst (int>=1)}"
      shift 2
      ;;
    --x-update-offset)
      X_UPDATE_OFFSET="${2:?missing value for --x-update-offset (int)}"
      shift 2
      ;;
    -h|--help)
      print_usage
      exit 0
      ;;
    *)
      die "unknown arg: $1 (use --help)"
      ;;
  esac
done

# Validate the interpreter and the cheap flag combinations first, then run the
# (slow, heavy-import) Python preflight, so simple usage mistakes fail fast.
if [[ "$PYTHON" == */* ]]; then
  [[ -x "$PYTHON" ]] || die "python not found/executable: $PYTHON"
else
  command -v "$PYTHON" >/dev/null 2>&1 || die "python not found on PATH: $PYTHON"
fi
if [[ "${EWORM_SIM_BACKEND:-heliox}" != "heliox" ]]; then
  die "NEURON stepping backend is removed; do not set EWORM_SIM_BACKEND (got: ${EWORM_SIM_BACKEND:-})"
fi
if [[ "$HELIOX_DEVICE" != "cpu" && "$HELIOX_DEVICE" != "gpu" ]]; then
  die "--heliox-device must be cpu|gpu (got: $HELIOX_DEVICE)"
fi
[[ "$HELIOX_PERMUTE_TYPE" =~ ^[0-9]+$ ]] || die "--heliox-permute-type must be an int (got: $HELIOX_PERMUTE_TYPE)"
# Resume-only options are otherwise silently ignored, so reject them early.
if [[ -n "$RESUME_EPOCHS" && "$RESUME" != "1" ]]; then
  die "--resume-epochs requires --resume"
fi
if [[ -n "$RESUME_FROM_DIR" && "$RESUME" != "1" ]]; then
  die "--resume-from requires --resume"
fi
if [[ -n "$RESUME_START_EPOCH" && "$RESUME" != "1" ]]; then
  die "--resume-start-epoch requires --resume"
fi
[[ -n "$HELIOX_PYTHON_LIB" ]] || die "HELIOX_PYTHON_LIB is not set (export it, or place HelioX at ../heliox/python_lib next to this folder)"
[[ -d "$HELIOX_PYTHON_LIB" ]] || die "HELIOX_PYTHON_LIB not found: $HELIOX_PYTHON_LIB"
python_preflight

if [[ -z "$OUT_DIR" ]]; then
  OUT_DIR="$(make_default_out_dir "$PREFIX" "$SUFFIX")"
fi

# Keep ./records read-only. Compare whole path components: a bare string-prefix
# match would also reject unrelated siblings such as ./records_old.
ABS_OUT_DIR="$(abspath "$OUT_DIR")"
ABS_RECORDS_DIR="$(abspath "${SCRIPT_DIR}/records")"
if [[ "$ABS_OUT_DIR" == "$ABS_RECORDS_DIR" || "$ABS_OUT_DIR" == "$ABS_RECORDS_DIR"/* ]]; then
  die "--out must not point into ./records (keep records read-only; use ./runs/...)"
fi

# Choose where this run writes. A fresh run never reuses an existing path: if
# OUT_DIR already exists, append _1, _2, ... until the name is free.
# Exception: an in-place resume (--resume without --resume-from, with --out
# naming an existing directory). The checkpoint to continue from lives in that
# directory, so renaming it would make --resume start without a checkpoint.
RESUME_IN_PLACE="0"
if [[ "$RESUME" == "1" && -z "$RESUME_FROM_DIR" && -d "$OUT_DIR" ]]; then
  RESUME_IN_PLACE="1"
fi
if [[ "$RESUME_IN_PLACE" != "1" && -e "$OUT_DIR" ]]; then
  i=1
  while [[ -e "${OUT_DIR}_${i}" ]]; do
    i=$((i + 1))
  done
  OUT_DIR="${OUT_DIR}_${i}"
fi

# Inputs the training run needs from the base trial.
CFG_SRC="${BASE_TRIAL_DIR%/}/000_circuit_search_config.json"
CONN_SRC="${BASE_TRIAL_DIR%/}/sample_#0_circuit_old.pkl"
[[ -f "$CFG_SRC" ]] || die "missing config: $CFG_SRC"
[[ -f "$CONN_SRC" ]] || die "missing connection: $CONN_SRC"

mkdir -p "$OUT_DIR"
# Seed OUT_DIR with the config, the connection pickle and any precomputed K
# (reusing K avoids an expensive rebuild on first run). Files already present
# are left alone so an in-place resume keeps its own copies; a freshly chosen
# directory has none, so every input is copied there.
for seed_src in "$CFG_SRC" "$CONN_SRC" "${BASE_TRIAL_DIR%/}"/K_eworm_v4_x*.npz; do
  # Also skips the literal pattern left behind when the K glob matches nothing.
  [[ -f "$seed_src" ]] || continue
  if [[ ! -e "${OUT_DIR%/}/${seed_src##*/}" ]]; then
    cp -a "$seed_src" "$OUT_DIR/"
  fi
done

# Resuming from another run directory: copy its checkpoint and latest training
# artifacts into OUT_DIR. Anything already present in OUT_DIR is kept as-is.
if [[ "$RESUME" == "1" && -n "$RESUME_FROM_DIR" ]]; then
  RESUME_FROM_DIR="$(abspath "$RESUME_FROM_DIR")"
  [[ "$(abspath "$OUT_DIR")" != "$RESUME_FROM_DIR" ]] \
    || die "--out must differ from --resume-from (refuse to overwrite the source record)"
  run_tag="${PREFIX}_${SUFFIX}"
  CKPT_SRC="${RESUME_FROM_DIR%/}/ckpt_${run_tag}.npz"
  [[ -f "$CKPT_SRC" ]] || die "resume-from missing checkpoint: ${CKPT_SRC}"
  # The checkpoint leads the list (known to exist); the rest are optional.
  for artifact in \
    "ckpt_${run_tag}.npz" \
    "weights_train_${run_tag}.npy" \
    "x_train_${run_tag}.npy" \
    "error_${run_tag}.npy" \
    "weights_optimal_${run_tag}.npy" \
    "x_optimal_${run_tag}.npy" \
    "run_best_${run_tag}.npz" \
    "plateau_vmin_${run_tag}.npz"; do
    src="${RESUME_FROM_DIR%/}/${artifact}"
    dst="${OUT_DIR%/}/${artifact}"
    if [[ -f "$src" && ! -f "$dst" ]]; then
      cp -a "$src" "${OUT_DIR%/}/"
    fi
  done
fi

# Hand the resolved configuration to train.py and the HelioX runtime through
# the environment. Optional values are exported only when non-empty.
export MPLBACKEND="${MPLBACKEND:-Agg}"
export HELIOX_PYTHON_LIB HELIOX_DEVICE HELIOX_PERMUTE_TYPE
[[ -z "$EWORM_HELIOX_EXPORT_PATH" ]] || export EWORM_HELIOX_EXPORT_PATH

EWORM_EXPORTS=(
  "EWORM_REPLAY=$REPLAY"
  "EWORM_REPLAY_USE_VECPLAY=$REPLAY_USE_VECPLAY"
  "EWORM_REPLAY_STREAMING=$REPLAY_STREAMING"
  "EWORM_REPORT_GPU_MEM=$REPORT_GPU_MEM"
  "EWORM_PRINT_EPOCH_TIME=$PRINT_EPOCH_TIME"
  "EWORM_REPLAY_CACHE_SIGNALS=$REPLAY_CACHE_SIGNALS"
  "EWORM_REPLAY_CLIP_STRATEGY=$CLIP_STRATEGY"
  "EWORM_REPLAY_CLIP_CHECK_EVERY=$CLIP_EVERY"
  "EWORM_REPLAY_GRAD_L2NORM_THRESHOLD=$CLIP_THRESHOLD"
  "EWORM_ALPHA_W_SCALE=$ALPHA_W_SCALE"
  "EWORM_ALPHA_X_SCALE=$ALPHA_X_SCALE"
  "EWORM_PLATEAU_LR_MULTIPLIER=$PLATEAU_LR_MULT"
  "EWORM_PLATEAU_PATIENCE_EPOCHS=$PLATEAU_PATIENCE"
  "EWORM_PLATEAU_RESET_ADAM=$PLATEAU_RESET_ADAM"
  "EWORM_FREEZE_W=$FREEZE_W"
  "EWORM_FREEZE_X=$FREEZE_X"
  "EWORM_X_L2_COEF=$X_L2_COEF"
  "EWORM_X_UPDATE_EVERY=$X_UPDATE_EVERY"
  "EWORM_X_UPDATE_BURST=$X_UPDATE_BURST"
  "EWORM_X_UPDATE_OFFSET=$X_UPDATE_OFFSET"
  "EWORM_RUN_TEST_AFTER_TRAIN=${EWORM_RUN_TEST_AFTER_TRAIN:-0}"
  "EWORM_TSTOP_MS=$TSTOP_MS"
  "EWORM_K_MUL=$K_MUL"
  "EWORM_PERCISE=$PERCISE"
  "EWORM_FAST_DVDW=$FAST_DVDW"
  "EWORM_GLOBAL_STREAM=$GLOBAL_STREAM"
  "EWORM_GLOBAL_CHUNK=$GLOBAL_CHUNK"
)
export "${EWORM_EXPORTS[@]}"
[[ -z "$K_LEN" ]] || export EWORM_K_LEN="$K_LEN"
[[ -z "$K_MAX_T_MS" ]] || export EWORM_K_MAX_T_MS="$K_MAX_T_MS"

# Optional debug/progress output. The HELIOX_LEARN_* variables are consumed by
# the HelioX runtime (C++/CUDA); the EWORM_* ones enable worm-side K
# loading/synthesis progress and verbose logs. CLI progress intervals default
# to 200; EWORM_* values already present in the environment are kept.
if [[ "$DEBUG" == "1" ]]; then
  export HELIOX_LEARN_DEBUG=1 \
    HELIOX_LEARN_CAPTURE_PROGRESS_EVERY="${CAPTURE_PROGRESS_EVERY:-200}" \
    HELIOX_LEARN_REPLAY_PROGRESS_EVERY="${REPLAY_PROGRESS_EVERY:-200}"
  : "${EWORM_VERBOSE:=1}" "${EWORM_PRINT_K_PROGRESS:=1}" "${EWORM_K_PROGRESS_EVERY_BLOCKS:=8}"
  export EWORM_VERBOSE EWORM_PRINT_K_PROGRESS EWORM_K_PROGRESS_EVERY_BLOCKS
fi
# Synchronous CUDA launches make device errors point at the failing call.
[[ "$DEBUG_BLOCKING" != "1" ]] || export CUDA_LAUNCH_BLOCKING=1

# Optional: NVIDIA HPC SDK runtime deps (some builds may depend on the OpenACC
# runtime). Only library directories that actually exist are prepended, and no
# empty element is left behind when LD_LIBRARY_PATH was unset/empty: the
# dynamic loader treats an empty element as the current working directory.
# NOTE(review): compilers come from 25.11 while CUDA/math libs come from 25.3 —
# confirm these versions match the installed SDK.
if [[ -d "/opt/nvidia/hpc_sdk" ]]; then
  HPC_SDK_LIBS=""
  for hpc_lib in \
    "/opt/nvidia/hpc_sdk/Linux_x86_64/25.11/compilers/lib" \
    "/opt/nvidia/hpc_sdk/Linux_x86_64/25.3/cuda/12.8/targets/x86_64-linux/lib" \
    "/opt/nvidia/hpc_sdk/Linux_x86_64/25.3/math_libs/12.8/targets/x86_64-linux/lib"; do
    if [[ -d "$hpc_lib" ]]; then
      HPC_SDK_LIBS="${HPC_SDK_LIBS:+${HPC_SDK_LIBS}:}${hpc_lib}"
    fi
  done
  if [[ -n "$HPC_SDK_LIBS" ]]; then
    export LD_LIBRARY_PATH="${HPC_SDK_LIBS}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
  fi
fi

# --print-env: dump the effective configuration as NAME=value lines so a run
# can be reproduced. Optional settings are printed only when non-empty.
# Note: EPOCHS is printed before any --resume-epochs adjustment; RESUME_EPOCHS
# is printed alongside it so the final target can be derived.
if [[ "$PRINT_ENV" == "1" ]]; then
  print_var() {
    local name="$1"
    printf '%s=%s\n' "$name" "${!name}"
  }
  print_var_if_set() {
    local name="$1"
    if [[ -n "${!name:-}" ]]; then
      print_var "$name"
    fi
  }

  for name in PYTHON OUT_DIR BASE_TRIAL_DIR PREFIX SUFFIX EPOCHS RESUME; do
    print_var "$name"
  done
  print_var_if_set RESUME_FROM_DIR
  print_var_if_set RESUME_EPOCHS
  print_var_if_set RESUME_START_EPOCH
  print_var RESUME_RESTORE_OPT
  print_var RESUME_RESTORE_VMIN
  print_var RESUME_RESTORE_RUN_BEST
  print_var RESUME_RESET_ADAM
  print_var_if_set RESUME_SET_ALPHA_MULT
  print_var HELIOX_PYTHON_LIB
  print_var HELIOX_DEVICE
  print_var HELIOX_PERMUTE_TYPE
  print_var_if_set EWORM_HELIOX_EXPORT_PATH
  for name in \
    EWORM_REPLAY EWORM_REPLAY_USE_VECPLAY EWORM_REPLAY_STREAMING \
    EWORM_REPORT_GPU_MEM EWORM_REPLAY_CACHE_SIGNALS EWORM_REPLAY_CLIP_STRATEGY \
    EWORM_REPLAY_CLIP_CHECK_EVERY EWORM_REPLAY_GRAD_L2NORM_THRESHOLD \
    EWORM_ALPHA_W_SCALE EWORM_ALPHA_X_SCALE EWORM_PLATEAU_LR_MULTIPLIER \
    EWORM_PLATEAU_PATIENCE_EPOCHS EWORM_PLATEAU_RESET_ADAM EWORM_FREEZE_W \
    EWORM_FREEZE_X EWORM_X_L2_COEF EWORM_X_UPDATE_EVERY EWORM_X_UPDATE_BURST \
    EWORM_X_UPDATE_OFFSET EWORM_TSTOP_MS EWORM_K_MUL; do
    print_var "$name"
  done
  print_var_if_set EWORM_K_LEN
  print_var_if_set EWORM_K_MAX_T_MS
  for name in EWORM_PERCISE EWORM_FAST_DVDW EWORM_GLOBAL_STREAM EWORM_GLOBAL_CHUNK; do
    print_var "$name"
  done
fi

# When resuming, optionally repair the checkpoint in OUT_DIR with edit_ckpt.py
# so the run really continues from best-so-far (see edit_ckpt.py for what each
# flag does). The editor only runs when at least one edit was requested.
if [[ "$RESUME" == "1" ]]; then
  EDIT_ARGS=( --out "$OUT_DIR" --prefix "$PREFIX" --suffix "$SUFFIX" )
  n_base_edit_args=${#EDIT_ARGS[@]}
  if [[ "$RESUME_RESTORE_OPT" == "1" ]]; then EDIT_ARGS+=( --restore-opt ); fi
  if [[ "$RESUME_RESTORE_VMIN" == "1" ]]; then EDIT_ARGS+=( --restore-vmin-snapshot ); fi
  if [[ "$RESUME_RESTORE_RUN_BEST" == "1" ]]; then EDIT_ARGS+=( --restore-run-best ); fi
  if [[ "$RESUME_RESET_ADAM" != "none" ]]; then EDIT_ARGS+=( --reset-adam "$RESUME_RESET_ADAM" ); fi
  if [[ -n "$RESUME_SET_ALPHA_MULT" ]]; then EDIT_ARGS+=( --set-alpha-multiplier "$RESUME_SET_ALPHA_MULT" ); fi

  if (( ${#EDIT_ARGS[@]} > n_base_edit_args )); then
    CKPT_EDITOR="${SCRIPT_DIR}/edit_ckpt.py"
    [[ -f "$CKPT_EDITOR" ]] || die "missing ckpt editor: ${CKPT_EDITOR}"
    "$PYTHON" "$CKPT_EDITOR" "${EDIT_ARGS[@]}"
  fi
fi

# --resume-epochs N: train N more epochs past the checkpoint's start_epoch by
# turning N into the absolute --epochs target passed to train.py.
if [[ "$RESUME" == "1" && -n "$RESUME_EPOCHS" ]]; then
  if [[ ! "$RESUME_EPOCHS" =~ ^[0-9]+$ ]]; then
    die "--resume-epochs must be an int (got: $RESUME_EPOCHS)"
  fi
  if [[ -n "$RESUME_START_EPOCH" ]]; then
    die "cannot combine --resume-epochs with --resume-start-epoch"
  fi
  CKPT_PATH="${OUT_DIR%/}/ckpt_${PREFIX}_${SUFFIX}.npz"
  if [[ ! -f "$CKPT_PATH" ]]; then
    die "--resume-epochs requires a checkpoint at ${CKPT_PATH} (use --resume-from or --out with an existing run)"
  fi
  START_EPOCH="$(ckpt_start_epoch "${OUT_DIR%/}" "$PREFIX" "$SUFFIX")"
  # Guard the arithmetic below against stray output from the Python helper.
  [[ "$START_EPOCH" =~ ^-?[0-9]+$ ]] || die "could not read start_epoch from ${CKPT_PATH} (got: ${START_EPOCH})"
  # Force base 10: with leading zeros (e.g. "08") bash would parse the value as
  # octal and either compute a wrong total or abort on an invalid digit.
  EPOCHS="$(( START_EPOCH + 10#$RESUME_EPOCHS ))"
fi

printf '%s\n' \
  "Starting training from seed pair:" \
  "  ${SCRIPT_DIR}/seeds/weights_optimal_eworm_v4.npy" \
  "  ${SCRIPT_DIR}/seeds/x_optimal_eworm_v4.npy" \
  "Outputs go to: $OUT_DIR" \
  "Backend: HelioX" \
  "  HELIOX_DEVICE=$HELIOX_DEVICE  HELIOX_PERMUTE_TYPE=$HELIOX_PERMUTE_TYPE"

# Build the train.py invocation as an array so every argument survives intact
# (spaces, '#', ...), then replace this shell process with it.
CMD=(
  "$PYTHON" train.py
  --base-trial "$BASE_TRIAL_DIR"
  --output-path "$OUT_DIR"
  --prefix "$PREFIX"
  --suffix "$SUFFIX"
  --epochs "$EPOCHS"
  --k-mul "$K_MUL"
)
[[ -z "$K_LEN" ]] || CMD+=( --k-len "$K_LEN" )
[[ -z "$K_MAX_T_MS" ]] || CMD+=( --k-max-t-ms "$K_MAX_T_MS" )
if [[ "$RESUME" == "1" ]]; then
  CMD+=( --resume )
  [[ -z "$RESUME_START_EPOCH" ]] || CMD+=( --resume-start-epoch "$RESUME_START_EPOCH" )
fi
exec "${CMD[@]}"
