#!/usr/bin/env bash
# Fail fast: abort on any unhandled error, on use of an unset variable,
# and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# ------------------------------------------------------------
# Shallow-water J(z) × G-flex ablation runner
#
# Runs:
#   n_traj ∈ NTRAJ × seed ∈ SEEDS  (one python invocation per (n_traj, seed) pair)
#   with 4 model modes (inside the python script):
#     J_fixed__G_diag, J_fixed__G_full, J_state__G_diag, J_state__G_full
#
# Produces (by python):
#   ROOT/ntraj{N}/seed{S}/mode_{MODE}/summary.json
#   ROOT/aggregate.json
#
# Plus this script generates:
#   ROOT/summary.csv            (flattened per-run table)
#   ROOT/summary_by_setting.csv (grouped mean/std over seeds for each (n_traj, mode))
# ------------------------------------------------------------

# Python training/eval script driven by this runner (resolved relative to the
# current working directory).
PY=shallow_jdep_ablation_gflex.py

# Seeds to run.
SEEDS=(0 1 2)

# Data regimes: number of training trajectories per run. Each value is run
# for every seed; list several values (e.g. NTRAJ=(4 32)) to compare regimes.
NTRAJ=(4)

# Common experiment parameters passed to every run. Kept as an array so each
# flag and value is its own argument (no reliance on unquoted word-splitting).
# Per-call overrides can be passed as extra arguments to run_one; they are
# appended after these, so they take effect if the python CLI lets the last
# occurrence of a flag win (NOTE(review): confirm for this script's parser).
COMMON=(
  --device cuda
  --H 64 --W 64
  --dt 0.02
  --traj_len 260
  --k_train 8
  --rollout_steps 600
  --n_eval_trajs 3
  --h_min 1e-3
  --truth_j_mode state
  --truth_g0 1.0
  --batch_size 6
  --n_epochs 5
  --lr 1e-3
  --weight_decay 0.0
  --clip_grad 1.0
  --hidden_g 64 --layers_g 3
  --hidden_c 32 --layers_c 2
  --c_min 0.05
  --unitmean_c 1
  --deterministic 0
)

ROOT="runs/shallow_jdep_gflex_ablation"
mkdir -p "$ROOT"

#######################################
# Launch the python ablation script for one (n_traj, seed) setting.
# Globals:   PY, COMMON, ROOT (read)
# Arguments: $1 - number of training trajectories
#            $2 - seed
#            $3... - optional extra flags, appended after COMMON
# Outputs:   a "[run] ..." progress line on stdout, then the script's output
# Returns:   the python process's exit status
#######################################
run_one() {
  local ntraj="$1"
  local seed="$2"
  shift 2

  echo "[run] n_traj=$ntraj seed=$seed"
  # The same value is given as both --ntraj_low and --ntraj_high, presumably
  # so the python script runs a single data regime per call (confirm there).
  python "$PY" "${COMMON[@]}" \
    --seeds "$seed" \
    --ntraj_low "$ntraj" \
    --ntraj_high "$ntraj" \
    --outdir "$ROOT" \
    "$@"
}

# ------------------------------------------------------------
# Run all settings
# ------------------------------------------------------------

# NOTE(review): every call shares --outdir "$ROOT". If the python script
# rewrites ROOT/aggregate.json on each invocation, that file will only reflect
# the last (n_traj, seed) run; the CSVs built below read the per-run
# summary.json files instead. Confirm against the python script.
for n in "${NTRAJ[@]}"; do
  for s in "${SEEDS[@]}"; do
    run_one "$n" "$s"
  done
done

# ------------------------------------------------------------
# Aggregate into CSVs (per-run, and grouped by (n_traj, mode))
# ------------------------------------------------------------

# Flatten every ROOT/ntraj*/seed*/mode_*/summary.json into summary.csv and
# summary_by_setting.csv. ROOT is passed as argv[1] so the python side always
# reads the same directory the runs above wrote to.
python - "$ROOT" <<'PY'
import csv
import json
import math
import sys
from collections import defaultdict
from pathlib import Path
from statistics import mean, pstdev

# Root directory comes from the calling shell; the literal is only a fallback
# for running this snippet by hand without arguments.
ROOT = Path(sys.argv[1] if len(sys.argv) > 1 else "runs/shallow_jdep_gflex_ablation").resolve()
OUT_SUMMARY = ROOT / "summary.csv"
OUT_GRP = ROOT / "summary_by_setting.csv"

paths = sorted(ROOT.glob("ntraj*/seed*/mode_*/summary.json"))
if not paths:
    raise SystemExit(f"[ERR] no summary.json found under {ROOT}")

# Scalar metrics copied from each summary.json ("" when a key is missing).
METRICS = [
    "best_val_kstep_mse",
    "rollout_vrmse_mean", "rollout_mse_mean",
    "drift_mass_mean", "drift_momx_mean", "drift_momy_mean", "drift_energy_mean",
    "gain_mse_cx", "gain_mse_cy",
]
COLUMNS = ["n_traj", "seed", "mode", "path"] + METRICS


def field(d, key, fallback):
    """Return d[key] unless it is missing/empty, in which case return fallback."""
    v = d.get(key, "")
    return fallback if v in ("", None) else v


def strip_prefix(name, prefix):
    """Directory-name suffix after `prefix` (the glob guarantees the prefix)."""
    return name[len(prefix):] if name.startswith(prefix) else ""


rows = []
for p in paths:
    d = json.loads(p.read_text(encoding="utf-8"))
    run_dir = p.parent  # .../ntraj{N}/seed{S}/mode_{MODE}
    # Prefer values recorded in summary.json; fall back to the directory
    # layout so a summary without these keys does not break grouping below.
    row = {
        "n_traj": field(d, "n_traj", strip_prefix(run_dir.parent.parent.name, "ntraj")),
        "seed": field(d, "seed", strip_prefix(run_dir.parent.name, "seed")),
        "mode": field(d, "mode", field(d, "c_mode", strip_prefix(run_dir.name, "mode_"))),
        "path": str(run_dir),
    }
    row.update({m: d.get(m, "") for m in METRICS})
    rows.append(row)

with OUT_SUMMARY.open("w", newline="", encoding="utf-8") as f:
    w = csv.DictWriter(f, fieldnames=COLUMNS)
    w.writeheader()
    w.writerows(rows)
print(f"[OK] wrote {OUT_SUMMARY}")


def to_float(x):
    """Best-effort float conversion; non-numeric values (e.g. "") become NaN."""
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return math.nan


def to_int(x):
    """int(x), or None when x is not integer-convertible."""
    try:
        return int(x)
    except (TypeError, ValueError):
        return None


grp = defaultdict(list)
for r in rows:
    n = to_int(r["n_traj"])
    if n is None:
        # Still present in summary.csv; only excluded from the grouped table.
        print(f"[WARN] skipping {r['path']} in grouping: n_traj={r['n_traj']!r} is not an integer",
              file=sys.stderr)
        continue
    grp[(n, str(r["mode"]))].append(r)

GRP_COLS = (["n_traj", "mode", "n_seeds"]
            + [m + "_mean" for m in METRICS]
            + [m + "_std" for m in METRICS])
with OUT_GRP.open("w", newline="", encoding="utf-8") as f:
    w = csv.DictWriter(f, fieldnames=GRP_COLS)
    w.writeheader()
    for (n_traj, mode), rs in sorted(grp.items(), key=lambda kv: kv[0]):
        out = {"n_traj": n_traj, "mode": mode, "n_seeds": len(rs)}
        for m in METRICS:
            xs = [x for x in (to_float(r[m]) for r in rs) if not math.isnan(x)]
            out[m + "_mean"] = mean(xs) if xs else math.nan
            # Population std over the seeds that reported the metric: 0.0 for a
            # single seed, NaN (undefined) when none did.
            out[m + "_std"] = pstdev(xs) if xs else math.nan
        w.writerow(out)
print(f"[OK] wrote {OUT_GRP}")
PY

# Final pointer to where the artifacts live (blank separator line first).
printf '\n'
printf '%s\n' \
  "Done. Artifacts under: $ROOT" \
  "  - aggregate.json (from python)         : $ROOT/aggregate.json" \
  "  - per-run summary                      : $ROOT/summary.csv" \
  "  - grouped by (n_traj, mode) mean/std   : $ROOT/summary_by_setting.csv"