#!/usr/bin/env bash
# Extrapolation suite runner for 4+ models (MeshFT-Net / MGN / MGN-HP / HNN [+FNO/GraphCON/PI-MGN]).
# - Trains ALL requested models (incl. MeshFT-Net with learnable geometry-conditioned Hodge)
# - Evaluates extrapolation on four scenarios:
#   (1) Frequency OOD, (2) Resolution OOD, (3) Parameter OOD, (4) Long-horizon OOD

set -Eeuo pipefail

# ----------------------- tunables (preset any of them in the environment) -----------------------
# Every knob below keeps a value already present in the environment and only
# falls back to the default shown when the variable is unset or empty.
: "${OUT_ROOT:=runs/extrapolation_test_kdk}"   # root directory for all outputs
PY="${PYTHON_BIN:-python3}"                    # interpreter command (from PYTHON_BIN)

# --- training distribution ---
: "${TRAIN_GRID_X:=32}"
: "${TRAIN_GRID_Y:=32}"
: "${TRAIN_DT:=0.004}"
: "${TRAIN_KMAX:=3}"
: "${TRAIN_C:=1.0}"
: "${TRAIN_PAIRS:=4000}"
: "${BATCH:=16}"
: "${EPOCHS:=10}"
: "${Q_ONLY:=0}"                 # 1 = supervise q only (p is unobserved)

# --- MeshFT-Net Hodge mode ---
: "${MeshFT_HODGE_MODE:=learn_geom}"

# --- loss weights ---
: "${LAM_HAM:=0.05}"             # MGN-HP penalty weight
: "${LAM_PI:=0.05}"              # physics-informed loss (PI-MGN)

# --- seeds ---
# Deliberately unquoted: the whitespace-separated list is word-split into
# one array element per seed.
# shellcheck disable=SC2206
SEEDS=(${SEEDS:-0 1 2})

# --- test defaults (a scenario may overwrite them) ---
: "${TEST_DT_DEFAULT:=0.004}"
: "${ROLLOUT_T_DEFAULT:=200}"
: "${TEST_PAIRS:=512}"           # test pairs size

# --- normalization & eval controls ---
: "${STD_INPUTS:=1}"             # standardize coords/V0/edge_attr
: "${STD_STATE:=1}"              # standardize state channels
: "${STD_GEO_EVAL:=test}"        # 'train' or 'test'

# --- optional baseline toggles ---
: "${INCLUDE_FNO:=1}"
: "${INCLUDE_GRAPHCON:=1}"
: "${INCLUDE_PIMGN:=1}"

# --- FNO hyper-params ---
: "${FNO_WIDTH:=64}"
: "${FNO_LAYERS:=4}"
: "${FNO_MODES1:=12}"
: "${FNO_MODES2:=12}"

# --- GraphCON hyper-params ---
: "${GRAPHCON_HIDDEN:=64}"
: "${GRAPHCON_LAYERS:=6}"
: "${GRAPHCON_ALPHA:=0.5}"
: "${GRAPHCON_GAMMA:=1.0}"
: "${GRAPHCON_DT_INNER:=1.0}"

# --- saved artifacts ---
: "${SAVE_MODELS:=1}"
: "${SAVE_ENERGY_CSV:=1}"

# ----------------------- helpers -----------------------

detect_device() {
  # Print "cuda" when torch reports an available CUDA device, otherwise "cpu".
  # The probe is fed to the interpreter in ${PY} on stdin.
  # ${PY} is left unquoted, so a multi-word interpreter command is word-split.
  ${PY} - <<'PYCODE'
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
PYCODE
}

# Probe the compute device once at startup; the result is passed to every run via --device.
DEVICE="$(detect_device)"

assert_py_script() {
  # Abort the whole script (exit status 1) unless extrapolation_bench.py is
  # present in the current working directory.
  local bench_script="extrapolation_bench.py"

  if [[ -f "${bench_script}" ]]; then
    return 0
  fi

  printf '%s\n' "[ERROR] Missing ${bench_script} in current directory." >&2
  exit 1
}

run_one_case() {
  # Run extrapolation_bench.py once for a (scenario, seed) pair, writing into
  # ${OUT_ROOT}/<scenario>/seed-<seed> (created if missing).
  #
  # Args:
  #   $1 = scenario name (freq|reso|param|long)
  #   $2 = seed
  #
  # Optional env knobs (defaults reproduce the previously hard-coded values):
  #   FREQ_KMAX_FACTOR     (2)    freq : test kmax = TRAIN_KMAX * factor
  #   RESO_GRID_FACTOR     (2)    reso : test grid = train grid * factor
  #   PARAM_TEST_C         (1.4)  param: test value passed as --test_c
  #   LONG_ROLLOUT_FACTOR  (3)    long : rollout T = ROLLOUT_T_DEFAULT * factor
  #   MGN_HIDDEN (64), MGN_LAYERS (4), HNN_HIDDEN (64), HNN_LAYERS (4)
  #
  # Exits the script with status 1 on an unknown scenario name.
  local SCEN="$1"
  local SEED="$2"

  # Scenario knobs, resolved here so the function does not depend on any
  # other part of the file defining them.
  local freq_factor="${FREQ_KMAX_FACTOR:-2}"
  local reso_factor="${RESO_GRID_FACTOR:-2}"
  local param_test_c="${PARAM_TEST_C:-1.4}"
  local long_factor="${LONG_ROLLOUT_FACTOR:-3}"
  local mgn_hidden="${MGN_HIDDEN:-64}"
  local mgn_layers="${MGN_LAYERS:-4}"
  local hnn_hidden="${HNN_HIDDEN:-64}"
  local hnn_layers="${HNN_LAYERS:-4}"

  # Defaults: test distribution == training distribution (override per scenario below)
  local TEST_GRID_X TEST_GRID_Y TEST_DT TEST_KMAX TEST_C ROLLOUT_T
  TEST_GRID_X="$TRAIN_GRID_X"
  TEST_GRID_Y="$TRAIN_GRID_Y"
  TEST_DT="$TEST_DT_DEFAULT"
  TEST_KMAX="$TRAIN_KMAX"
  TEST_C="$TRAIN_C"
  ROLLOUT_T="$ROLLOUT_T_DEFAULT"

  # Each scenario shifts exactly one aspect of the test setup.
  case "${SCEN}" in
    freq)
      TEST_KMAX=$(( TRAIN_KMAX * freq_factor ))
      ;;
    reso)
      TEST_GRID_X=$(( TRAIN_GRID_X * reso_factor ))
      TEST_GRID_Y=$(( TRAIN_GRID_Y * reso_factor ))
      ;;
    param)
      TEST_C="${param_test_c}"
      ;;
    long)
      ROLLOUT_T=$(( ROLLOUT_T_DEFAULT * long_factor ))
      ;;
    *)
      echo "[ERROR] Unknown scenario: ${SCEN}" >&2
      exit 1
      ;;
  esac

  local OUT_DIR="${OUT_ROOT}/${SCEN}/seed-${SEED}"
  mkdir -p "${OUT_DIR}"

  echo "--------------------------------------------------------------------------------"
  echo "[RUN] scenario=${SCEN}  seed=${SEED}  device=${DEVICE}"
  echo "      train: grid=${TRAIN_GRID_X}x${TRAIN_GRID_Y}, kmax=${TRAIN_KMAX}, c=${TRAIN_C}, dt=${TRAIN_DT}"
  echo "      test : grid=${TEST_GRID_X}x${TEST_GRID_Y}, kmax=${TEST_KMAX}, c=${TEST_C}, dt=${TEST_DT}, T=${ROLLOUT_T}"
  echo "      opts : include(FNO=${INCLUDE_FNO}, GraphCON=${INCLUDE_GRAPHCON}, PI-MGN=${INCLUDE_PIMGN}), std(inputs=${STD_INPUTS},state=${STD_STATE},geo_eval=${STD_GEO_EVAL})"
  echo "--------------------------------------------------------------------------------"

  # Build the command line as an array so every value stays a single word.
  local -a bench_args=(
    --out_dir "${OUT_DIR}"
    --seed "${SEED}"
    --device "${DEVICE}"
    --train_grid "${TRAIN_GRID_X}" "${TRAIN_GRID_Y}"
    --train_dt "${TRAIN_DT}"
    --train_kmax "${TRAIN_KMAX}"
    --train_c "${TRAIN_C}"
    --train_pairs "${TRAIN_PAIRS}"
    --batch_size "${BATCH}"
    --epochs "${EPOCHS}"
    --q_only_supervision "${Q_ONLY}"
    --std_inputs "${STD_INPUTS}"
    --std_state "${STD_STATE}"
    --test_grid "${TEST_GRID_X}" "${TEST_GRID_Y}"
    --test_dt "${TEST_DT}"
    --test_kmax "${TEST_KMAX}"
    --test_c "${TEST_C}"
    --test_pairs "${TEST_PAIRS}"
    --rollout_T "${ROLLOUT_T}"
    --std_geo_eval "${STD_GEO_EVAL}"
    --meshft_hodge_mode "${MeshFT_HODGE_MODE}"
    --mgn_hidden "${mgn_hidden}"
    --mgn_layers "${mgn_layers}"
    --lam_ham "${LAM_HAM}"
    --hnn_hidden "${hnn_hidden}"
    --hnn_layers "${hnn_layers}"
    --include_fno "${INCLUDE_FNO}"
    --include_graphcon "${INCLUDE_GRAPHCON}"
    --include_pimgn "${INCLUDE_PIMGN}"
    --fno_width "${FNO_WIDTH}"
    --fno_layers "${FNO_LAYERS}"
    --fno_modes1 "${FNO_MODES1}"
    --fno_modes2 "${FNO_MODES2}"
    --graphcon_hidden "${GRAPHCON_HIDDEN}"
    --graphcon_layers "${GRAPHCON_LAYERS}"
    --graphcon_alpha "${GRAPHCON_ALPHA}"
    --graphcon_gamma "${GRAPHCON_GAMMA}"
    --graphcon_dt_inner "${GRAPHCON_DT_INNER}"
    --lam_pi "${LAM_PI}"
    --save_models "${SAVE_MODELS}"
    --save_energy_csv "${SAVE_ENERGY_CSV}"
  )

  # ${PY} stays unquoted so a multi-word interpreter command keeps working
  # exactly as before.
  # shellcheck disable=SC2086
  ${PY} extrapolation_bench.py "${bench_args[@]}"
}

aggregate_csv() {
  # Aggregate every <OUT_ROOT>/<scenario>/seed-<seed>/extrapolation_summary.json
  # into ${OUT_ROOT}/_compiled/summary_compiled.csv (pure stdlib Python).
  # Summaries that cannot be read or lack required keys are skipped with a
  # [WARN] line on stderr. Metrics missing for a model become empty cells.
  local OUT_DIR="${OUT_ROOT}/_compiled"
  mkdir -p "${OUT_DIR}"
  local CSV="${OUT_DIR}/summary_compiled.csv"

  ${PY} - "${OUT_ROOT}" "${CSV}" <<'PY'
import csv, glob, json, os, sys

root = sys.argv[1]
csv_path = sys.argv[2]

# (CSV column suffix, key used in the JSON summary), in output column order.
MODELS = [
    ("MeshFT-Net", "MeshFT-Net"),
    ("MGN",        "MGN"),
    ("MGNHP",      "MGN-HP"),
    ("HNN",        "HNN"),
    ("FNO",        "FNO"),
    ("GraphCON",   "GraphCON"),
    ("PIMGN",      "PI-MGN"),
]
# (CSV column prefix, key inside each model's rollout dict), in output column order.
ROLLOUT_METRICS = [
    ("RelFinal", "rel_final"),
    ("Drift",    "drift_mean"),
    ("TAMSE",    "tamse"),
]
META_FIELDS = [
    "scenario", "seed", "train_grid", "test_grid", "train_kmax", "test_kmax",
    "train_c", "test_c", "train_dt", "test_dt", "epochs", "train_pairs",
    "rollout_T", "q_only", "meshft_hodge_mode",
]


def to_float(x):
    """Best-effort float conversion; missing or non-numeric values give an empty cell."""
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return ""


def nested(d, key, sub):
    """Return d[key][sub], or None when either level is absent/empty."""
    return (d.get(key) or {}).get(sub)


rows = []
for path in glob.glob(os.path.join(root, "*", "seed-*", "extrapolation_summary.json")):
    try:
        with open(path, "r") as fh:
            summary = json.load(fh)
        # relpath (not string replacement) so a trailing slash or otherwise
        # non-normalized root still yields "<scenario>/seed-<n>/<file>".
        parts = os.path.relpath(path, root).split(os.sep)
        scenario = parts[0]
        seed = parts[1].split("-")[-1]
        train = summary["train"]
        test = summary["test"]
        mse = summary.get("one_step_mse", {})
        ro = summary.get("rollout", {})

        row = {
            "scenario": scenario,
            "seed": int(seed),
            "train_grid": f"{train['grid'][0]}x{train['grid'][1]}",
            "test_grid": f"{test['grid'][0]}x{test['grid'][1]}",
            "train_kmax": train["kmax"],
            "test_kmax": test["kmax"],
            "train_c": train["c"],
            "test_c": test["c"],
            "train_dt": train["dt"],
            "test_dt": test["dt"],
            "epochs": train["epochs"],
            "train_pairs": train["pairs"],
            "rollout_T": test["rollout_T"],
            "q_only": train["q_only"],
            "meshft_hodge_mode": train["meshft_hodge_mode"],
        }
        # one-step MSE per model
        for col, key in MODELS:
            row[f"MSE_{col}"] = to_float(mse.get(key))
        # rollout metrics per model
        for prefix, metric in ROLLOUT_METRICS:
            for col, key in MODELS:
                row[f"{prefix}_{col}"] = to_float(nested(ro, key, metric))
        rows.append(row)
    except Exception as e:
        # Best effort: one bad summary must not abort the whole aggregation.
        print(f"[WARN] skip {path}: {e}", file=sys.stderr)

rows.sort(key=lambda r: (r["scenario"], r["seed"]))
fields = (
    META_FIELDS
    + [f"MSE_{col}" for col, _ in MODELS]
    + [f"{prefix}_{col}" for prefix, _ in ROLLOUT_METRICS for col, _ in MODELS]
)
with open(csv_path, "w", newline="") as out:
    w = csv.DictWriter(out, fieldnames=fields)
    w.writeheader()
    for r in rows:
        w.writerow(r)
print(csv_path)
PY

  echo "[OK] Compiled CSV -> ${CSV}"
}

# ----------------------- main flow -----------------------
main() {
  # Entry point: check the benchmark script exists, run every scenario for
  # every seed (scenario-major order), then compile the per-run summaries.
  assert_py_script
  mkdir -p "${OUT_ROOT}"

  SCENARIOS=(freq reso param long)

  local scenario_name run_seed
  for scenario_name in "${SCENARIOS[@]}"; do
    for run_seed in "${SEEDS[@]}"; do
      run_one_case "${scenario_name}" "${run_seed}"
    done
  done

  aggregate_csv

  printf '%s\n' "All done. Results under: ${OUT_ROOT}"
}

main