#!/usr/bin/env bash
# Strict mode: -e aborts on an unhandled command failure, -u treats unset
# variables as errors, -o pipefail makes a pipeline fail if any stage fails,
# and -E lets an ERR trap (if one is ever set) be inherited by functions,
# command substitutions and subshells.
set -Eeuo pipefail

# =========================================================
# Iso-accuracy cost sweep runner (MeshFT-Net / MGN / MGN-HP / HNN)
# - Runs the extended cost_bench.py across grids/seeds/configs
# - Records per-step timings, peak memory, parameter counts
# - Adds iso-accuracy probe (TAMSE <= ACC_TAU)
# - Aggregates:
#     * ALL rows
#     * ISO-only rows (meets_iso==1)
#     * Best-ISO per (model, grid) by minimal infer_ms
# =========================================================

# --------- User-overridable knobs ----------
# Every scalar below keeps a value already present in the environment and
# only falls back to the default shown when the variable is unset or empty.
: "${PYTHON:=python3}"                     # interpreter for the bench and the aggregation step
: "${DEVICE:=cuda}"                        # cuda or cpu
: "${AB_MODULE:=analytic_wave_bench.py}"   # forwarded to the bench as --ab_module
: "${COST_BENCH:=cost_bench.py}"           # the extended cost bench

# ---- Sweep space ----
# List-valued knobs arrive as whitespace-separated strings and are split
# into arrays on purpose (hence the unquoted expansions).
# shellcheck disable=SC2206
SEEDS=( ${SEEDS:-0 1 2} )
# shellcheck disable=SC2206
GRIDS=( ${GRIDS:-32x32 64x64 128x128 256x256} )
: "${BATCH:=8}"

# ---- dt schedule ----
# DTLIST, when non-empty, is matched to GRIDS by position (e.g. "0.008 0.004
# 0.002 0.001"). Otherwise dt is derived per grid as
# DT_BASE * (DT_REF_N / max(Nx, Ny)).
# shellcheck disable=SC2206
DTLIST=( ${DTLIST:-} )
: "${DT_BASE:=0.004}"
: "${DT_REF_N:=32}"

# ---- Model capacity (MGN/HP); MeshFT-Net capacity comes from Hodge Geom-MLP ----
# shellcheck disable=SC2206
MGN_HIDDENS=( ${MGN_HIDDENS:-64} )
# shellcheck disable=SC2206
MGN_LAYERS=( ${MGN_LAYERS:-4} )
# shellcheck disable=SC2206
MESHFT_GEOM_HIDDEN=( ${MESHFT_GEOM_HIDDEN:-64} )
# shellcheck disable=SC2206
MESHFT_GEOM_LAYERS=( ${MESHFT_GEOM_LAYERS:-2} )

# ---- Toggles ----
: "${HNN_ENABLE:=1}"                       # 1=enable HNN, 0=disable
: "${LAM_HAM:=0.1}"                        # HP penalty for MGN-HP (0 disables)
: "${USE_SN:=0}"                           # 1 to apply spectral norm to all linear layers

# ---- Physics / data ----
: "${KMAX:=4}"
: "${C_SPEED:=1.0}"
: "${STATE_MODE:=canonical}"
: "${DATA_STATE_MODE:=canonical}"

# ---- Masking (keep fixed across runs for apples-to-apples) ----
: "${MISS_RATIO:=0.0}"
: "${MISS_MODE:=random}"
: "${GRID_STRIDE:=2}"

# ---- Iso-accuracy probe (cost_bench.py -> TAMSE) ----
: "${ACC_T:=64}"                           # rollout steps used for TAMSE probe
: "${ACC_STRIDE:=1}"
: "${ACC_TAU:=0.2}"                        # iso threshold (TAMSE<=tau)

# ---- Timing loops ----
: "${WARMUP:=20}"
: "${ITERS:=80}"

# ---- Output layout ----
# Unless OUT_ROOT is supplied, each invocation gets its own timestamped root.
STAMP="$(date '+%Y%m%d_%H%M%S')"
: "${OUT_ROOT:=runs/cost_iso_${STAMP}}"
LOG_DIR="${OUT_ROOT}/logs"                           # one log per bench run
CSV_DIR="${OUT_ROOT}/csv"                            # one CSV per bench run
AGG_ALL="${OUT_ROOT}/ALL_rows.csv"                   # every row from every run
AGG_ISO="${OUT_ROOT}/ISO_rows.csv"                   # rows with meets_iso==1
AGG_BEST="${OUT_ROOT}/BEST_iso_by_model_grid.csv"    # best ISO row per (model, grid)
mkdir -p \
  "${OUT_ROOT}" \
  "${LOG_DIR}" \
  "${CSV_DIR}"

# OOM guard (heuristic): skip if free GPU memory below threshold.
# The run loop applies the guard only when DEVICE is "cuda" and
# SKIP_ON_LOW_MEM is exactly "1".
SKIP_ON_LOW_MEM="${SKIP_ON_LOW_MEM:-1}"
MIN_FREE_GPU_MB="${MIN_FREE_GPU_MB:-2048}"   # threshold in MB

# Banner: echo the effective configuration so it is captured alongside the
# sweep's console output.
echo "== Iso-accuracy Cost Sweep =="
echo " device      : ${DEVICE}"
echo " ab module   : ${AB_MODULE}"
echo " bench       : ${COST_BENCH}"
echo " out root    : ${OUT_ROOT}"
echo " grids       : ${GRIDS[*]}"
echo " seeds       : ${SEEDS[*]}"
echo " mgn(H,L)    : H=[${MGN_HIDDENS[*]}], L=[${MGN_LAYERS[*]}]"
echo " meshft geom : H=[${MESHFT_GEOM_HIDDEN[*]}], L=[${MESHFT_GEOM_LAYERS[*]}]"
echo " acc(T,τ)    : T=${ACC_T}, tau=${ACC_TAU}"
echo

# Fail fast when the bench script is missing. The diagnostic goes to stderr so
# it stays visible even when stdout is redirected or piped.
if [[ ! -f "${COST_BENCH}" ]]; then
  echo "[ERROR] ${COST_BENCH} not found" >&2
  exit 1
fi

#######################################
# Report the free memory, in MB, of the first GPU listed by nvidia-smi.
# Arguments: none
# Outputs:   a single integer on stdout.
#            Prints 999999 when nvidia-smi is not installed, or when its first
#            output line is not a plain integer (e.g. "[N/A]"), so an unknown
#            reading never triggers a low-memory skip by itself.
# Returns:   0 normally; nvidia-smi's own non-zero status (with nothing
#            printed) when the query fails, so callers can use
#            `gpu_free_mb || echo 0`.
#######################################
gpu_free_mb () {
  local raw first

  if ! command -v nvidia-smi >/dev/null 2>&1; then
    echo 999999
    return 0
  fi

  # Capture the full output instead of piping into `awk '... exit'`: with
  # pipefail, a reader that exits early can SIGPIPE nvidia-smi and make the
  # pipeline report failure even though a valid value was already printed,
  # which would cause the caller's fallback to append a second line.
  raw="$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits)" || return

  # First whitespace-delimited field of the first line (same field awk's $1
  # on NR==1 selected).
  read -r first _ <<< "${raw}"

  if [[ "${first}" =~ ^[0-9]+$ ]]; then
    echo "${first}"
  else
    echo 999999
  fi
}

#######################################
# Choose the time step for one grid.
# Globals:   DTLIST (read)   - optional explicit dt values, positional
#            DT_BASE (read)  - base dt for the computed schedule
#            DT_REF_N (read) - reference resolution for the computed schedule
#            PYTHON (read)   - interpreter used for the float arithmetic
#                              (falls back to python3 when unset/empty)
# Arguments: $1 - grid size Nx
#            $2 - grid size Ny
#            $3 - zero-based position of the grid in the sweep
# Outputs:   the dt value on stdout. With a non-empty DTLIST this is the entry
#            at position $3, or the last entry when $3 is past the end.
#            Otherwise DT_BASE * (DT_REF_N / max(Nx, Ny)) with 8 decimals.
# Returns:   0 on success; the interpreter's status if the computation fails.
#######################################
dt_for_grid () {
  local nx="$1" ny="$2" idx="$3"
  local n_dt="${#DTLIST[@]}"

  if (( n_dt > 0 )); then
    # Clamp to the last entry explicitly; negative subscripts such as
    # ${DTLIST[-1]} are only available from bash 4.3 on.
    if (( idx >= n_dt )); then
      idx=$(( n_dt - 1 ))
    fi
    echo "${DTLIST[idx]}"
    return 0
  fi

  # Computed schedule: dt = DT_BASE * (DT_REF_N / max(nx, ny)).
  # Values are passed as argv (not interpolated into the Python source) and the
  # heredoc delimiter is quoted, so the shell never rewrites the program text.
  "${PYTHON:-python3}" - "${DT_BASE}" "${DT_REF_N}" "${nx}" "${ny}" <<'PY'
import sys

base = float(sys.argv[1])
ref = float(sys.argv[2])
m = max(int(sys.argv[3]), int(sys.argv[4]))
print(f"{base*(ref/m):.8f}")
PY
}

# Header for aggregated CSVs, assembled in the same order the run loop emits
# fields: sweep metadata, then the bench's own per-model columns, then the
# capacity settings of the run.
ALL_HDR="scenario,grid,seed,batch,dt,device"
ALL_HDR+=",model,params,eff_substeps,infer_ms,train_ms,samples_per_s,nodes_per_s,inf_peakMB,train_peakMB,speed_per_param,time_per_node_per_param,tamse,relF,meets_iso,gate_alpha"
ALL_HDR+=",hidden,layers,meshft_geom_hidden,meshft_geom_layers"
# Start the aggregate file fresh with just the header line.
printf '%s\n' "${ALL_HDR}" > "${AGG_ALL}"

# ---------------- Run loop ----------------
# Nesting order: seed -> grid -> MGN hidden -> MGN layers -> MeshFT geom hidden
# -> MeshFT geom layers. Each combination launches ${COST_BENCH} once, sends
# its stdout+stderr to a per-run log, and appends the data rows of its per-run
# CSV (wrapped with sweep metadata) to ${AGG_ALL}.
# A failed run, or a run that leaves no CSV behind, is reported on stderr and
# skipped; it never aborts the sweep.
SCEN="iso"
for si in "${!SEEDS[@]}"; do
  SEED="${SEEDS[$si]}"
  for gi in "${!GRIDS[@]}"; do
    GRID="${GRIDS[$gi]}"
    # "64x128" -> NX=64, NY=128
    IFS=x read -r NX NY <<< "${GRID}"
    DT_RUN="$(dt_for_grid "${NX}" "${NY}" "${gi}")"

    # simple OOM gate: skips every capacity combination for this (grid, seed)
    if [[ "${DEVICE}" == "cuda" && "${SKIP_ON_LOW_MEM}" == "1" ]]; then
      FREE_MB="$(gpu_free_mb || echo 0)"
      if [[ "${FREE_MB}" -lt "${MIN_FREE_GPU_MB}" ]]; then
        echo "[SKIP] low free GPU mem (${FREE_MB}MB) @ grid ${GRID}, seed ${SEED}"
        continue
      fi
    fi

    for H_THIS in "${MGN_HIDDENS[@]}"; do
      for L_THIS in "${MGN_LAYERS[@]}"; do
        for MH_THIS in "${MESHFT_GEOM_HIDDEN[@]}"; do
          for ML_THIS in "${MESHFT_GEOM_LAYERS[@]}"; do
            RUN_TAG="g${GRID}_dt${DT_RUN}_mH${H_THIS}_mL${L_THIS}_fH${MH_THIS}_fL${ML_THIS}_seed${SEED}"
            OUT_CSV="${CSV_DIR}/cost_${RUN_TAG}.csv"
            LOG="${LOG_DIR}/cost_${RUN_TAG}.log"

            # Arguments live in an array so each value stays one word no
            # matter what it contains.
            BENCH_ARGS=(
              --ab_module "${AB_MODULE}"
              --device "${DEVICE}"
              --mesh grid
              --grid "${NX}" "${NY}"
              --dt "${DT_RUN}"
              --batch_size "${BATCH}"
              --val_size 256
              --kmax "${KMAX}"
              --c_speed "${C_SPEED}"
              --c_wave "${C_SPEED}"
              --state_mode "${STATE_MODE}"
              --data_state_mode "${DATA_STATE_MODE}"
              --miss_ratio "${MISS_RATIO}"
              --miss_mode "${MISS_MODE}"
              --grid_stride "${GRID_STRIDE}"
              --mgn_hidden "${H_THIS}"
              --mgn_layers "${L_THIS}"
              --lam_ham "${LAM_HAM}"
              --hnn_enable "${HNN_ENABLE}"
              --meshft_hodge_mode learn_geom
              --meshft_geom_hidden "${MH_THIS}"
              --meshft_geom_layers "${ML_THIS}"
              --use_spectral_norm "${USE_SN}"
              --warmup "${WARMUP}"
              --iters "${ITERS}"
              --seed "${SEED}"
              --acc_T "${ACC_T}"
              --acc_stride "${ACC_STRIDE}"
              --acc_tau "${ACC_TAU}"
              --out_csv "${OUT_CSV}"
            )

            # `|| RC=$?` records the status without tripping `set -e`, so one
            # failing configuration does not end the whole sweep.
            RC=0
            "${PYTHON}" "${COST_BENCH}" "${BENCH_ARGS[@]}" > "${LOG}" 2>&1 || RC=$?

            if [[ ${RC} -ne 0 ]]; then
              echo "[WARN] run failed: ${RUN_TAG} (rc=${RC}). See ${LOG}" >&2
              continue
            fi

            # A "successful" run without a CSV would otherwise vanish from the
            # aggregate without a trace.
            if [[ ! -f "${OUT_CSV}" ]]; then
              echo "[WARN] run produced no CSV: ${RUN_TAG}. See ${LOG}" >&2
              continue
            fi

            # Append to global ALL csv with metadata.
            # cost_bench CSV header: model,params,eff_substeps,infer_ms,train_ms,samples/s,nodes/s,inf_peakMB,train_peakMB,speed_per_param,time_per_node_per_param,tamse,relF,meets_iso,gate_alpha
            # NR>1 drops that header; the 15 bench columns are framed by six
            # metadata fields in front and four capacity fields behind,
            # matching the aggregate header.
            awk -F, -v OFS=',' \
              -v scen="${SCEN}" -v grid="${GRID}" -v seed="${SEED}" -v batch="${BATCH}" -v dt="${DT_RUN}" -v dev="${DEVICE}" \
              -v mh="${MH_THIS}" -v ml="${ML_THIS}" -v hh="${H_THIS}" -v ll="${L_THIS}" \
              'NR>1 { print scen,grid,seed,batch,dt,dev,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,hh,ll,mh,ml }' \
              "${OUT_CSV}" >> "${AGG_ALL}"
          done
        done
      done
    done
  done
done

# ---------------- Aggregation (Python) ----------------
# Post-processes the combined CSV with pandas (must be importable by ${PYTHON}):
#   argv[1] ALL rows (input)
#   argv[2] ISO-only rows, i.e. meets_iso == 1 (output)
#   argv[3] best ISO row per (grid, model) (output)
# The heredoc delimiter is quoted, so the shell expands nothing inside the
# Python source; paths reach it through argv only.
"${PYTHON}" - "${AGG_ALL}" "${AGG_ISO}" "${AGG_BEST}" <<'PY'
import sys

import pandas as pd

all_csv, iso_csv, best_csv = sys.argv[1:4]
df = pd.read_csv(all_csv)

# ISO-only rows
iso = df[df["meets_iso"] == 1].copy()
iso.to_csv(iso_csv, index=False)

# Best-ISO per (grid, model): smallest infer_ms, ties broken by train_peakMB
# and then params. Sorting by the group keys first and keeping the first row
# of each (grid, model) pair selects the same row a per-group sort would,
# while preserving the original column order and dtypes. Unlike
# groupby(...).apply(...), it does not depend on how the installed pandas
# version treats grouping columns, and an empty ISO set still yields a file
# with the full header.
best = (
    iso.dropna(subset=["grid", "model"])  # groupby would have dropped NaN keys too
       .sort_values(["grid", "model", "infer_ms", "train_peakMB", "params"])
       .drop_duplicates(subset=["grid", "model"], keep="first")
       .reset_index(drop=True)
)

best.to_csv(best_csv, index=False)


# Pretty print small summaries
def summarize(tag, d):
    """Print a titled table of the key cost/accuracy columns present in d."""
    print(f"\n== {tag} ==")
    if d.empty:
        print("(empty)")
        return
    cols = ["grid", "model", "params", "infer_ms", "train_ms",
            "inf_peakMB", "train_peakMB", "tamse", "relF"]
    cols = [c for c in cols if c in d.columns]
    print(d[cols].to_string(index=False))


summarize("ISO-only (head)", iso.head(16))
summarize("Best ISO per (model, grid)", best)
PY

# Closing summary: one blank line, then where each artifact was written.
cat <<EOF

== Sweep finished ==
  ALL rows : ${AGG_ALL}
  ISO rows : ${AGG_ISO}
  BEST ISO : ${AGG_BEST}
  Logs     : ${LOG_DIR}
  Per-run  : ${CSV_DIR}
EOF