#!/usr/bin/env bash
set -euo pipefail

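# ============================================================
# Usage:
#   run_from_configs.sh MASTER_TAG DATA_CFG TRAIN_CFG RUN_TAG [SEED] [extra train.py args...]
#
# Extra train.py args are forwarded only when all five positionals are
# given (the 5th argument is always taken as SEED).
#
# Seed resolution order:
#   1) positional arg #5
#   2) env var SEED (exported by sbatch script)
#   3) default 0
#
# Optional environment variables:
#   DEVICE      device passed to train.py (default: cuda; data generation always uses cpu)
#   RUNS_DIR    root directory for all outputs (default: boids_for_figure_jan_27)
#   REGEN_DATA  set to 1 to regenerate an existing data bundle (default: 0)
#   WANDB_ON    1 enables W&B logging, any other value disables it (default: 1)
# ============================================================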

# ---- Positional arguments ----
# Args 1-4 fall back to defaults when omitted or empty.
MASTER_TAG="${1:-orbit}"
DATA_CFG="${2:-configs/data/orbit.yaml}"
TRAIN_CFG="${3:-configs/models/attn_next_k_from_0.yaml}"
RUN_TAG="${4:-}"

# Seed: prefer explicit positional #5; else env SEED; else 0.
# An empty positional seed ("") also falls through to env/default, so callers
# can pass extra train.py args without pinning a seed.
if (( $# >= 5 )); then
  SEED_ARG="${5:-${SEED:-0}}"
  shift 5
else
  SEED_ARG="${SEED:-0}"
  # Consume only the fixed positionals that were actually supplied. A bare
  # `shift 4` fails when fewer than 4 args are given, which under `set -e`
  # aborted the script silently and made the defaults above unreachable.
  shift "$(( $# < 4 ? $# : 4 ))"
fi
# Whatever remains in "$@" is forwarded verbatim to train.py.
SEED="${SEED_ARG}"

# Environment defaults: each value is taken from the environment when set
# and non-empty, otherwise the literal below is assigned.
: "${DEVICE:=cuda}"
: "${RUNS_DIR:=boids_for_figure_jan_27}"
: "${REGEN_DATA:=0}"
: "${WANDB_ON:=1}"

# Validation: fail fast, before any output directory is created.
[[ -f "${DATA_CFG}"  ]] || { echo "ERROR: DATA_CFG not found: ${DATA_CFG}"  >&2; exit 1; }
[[ -f "${TRAIN_CFG}" ]] || { echo "ERROR: TRAIN_CFG not found: ${TRAIN_CFG}" >&2; exit 1; }

# Print a config path's stem: its file name with the last extension removed.
# Falls back to the full file name when stripping would leave nothing
# (e.g. a file literally named ".yaml").
_cfg_stem() {
  local name="${1##*/}"
  local stem="${name%.*}"
  printf '%s\n' "${stem:-${name}}"
}

# Tags / directories
# The basename is taken BEFORE stripping the extension: applying `%.*` to the
# whole path would cut at a dot in a directory name whenever the file itself
# has no extension (e.g. "configs.v2/orbit" -> "configs").
data_base="$(_cfg_stem "${DATA_CFG}")"
train_base="$(_cfg_stem "${TRAIN_CFG}")"

master_dir="${RUNS_DIR}/${MASTER_TAG}"
data_dir="${master_dir}/data/${data_base}_seed${SEED}"
# A non-empty RUN_TAG is appended as "_<RUN_TAG>"; an empty one adds nothing.
train_dir="${master_dir}/train/${train_base}_seed${SEED}${RUN_TAG:+_${RUN_TAG}}"

mkdir -p "${data_dir}" "${train_dir}"

data_path="${data_dir}/bundle.pt"

# Best-effort, no-clobber snapshot of a config file into a run directory.
# Arguments: $1 - source file; $2 - destination path.
# An existing destination is never overwritten (the first run's snapshot
# wins); a warning is printed if it no longer matches the source.
# Copy failures are reported on stderr but never abort the run.
# Returns: always 0.
_copy_cfg_nc() {
  local src="$1"
  local dst="$2"
  # Explicit existence check instead of `cp -n`: the exit status of `cp -n`
  # when it skips an existing file differs between coreutils versions, and
  # the old blanket `2>/dev/null || true` also hid genuine copy failures.
  if [[ -e "${dst}" ]]; then
    cmp -s -- "${src}" "${dst}" \
      || echo "WARN: keeping existing snapshot ${dst}; it differs from ${src}" >&2
    return 0
  fi
  if ! cp -- "${src}" "${dst}"; then
    echo "WARN: could not copy ${src} -> ${dst}" >&2
  fi
  return 0
}

# ---- 1) Data generation (uniform) ----
# Snapshot the data config beside the bundle (an existing snapshot is kept).
_copy_cfg_nc "${DATA_CFG}" "${data_dir}/data_config.yaml"

# Generate when forced via REGEN_DATA=1 or when no bundle exists yet;
# otherwise reuse the bundle already on disk.
# NOTE(review): if data_generator.py can leave a partial ${data_path} behind on
# failure, the next run would reuse it -- confirm the generator writes atomically.
if [[ "${REGEN_DATA}" == "1" || ! -f "${data_path}" ]]; then
  echo "[run] generating data -> ${data_path}"
  python -u data_generator.py \
    --config "${DATA_CFG}" \
    --set seed="${SEED}" \
    --set device=cpu \
    --set out="${data_path}"
else
  echo "[run] data exists: ${data_path} (skip; set REGEN_DATA=1 to force)"
fi

# ---- 2) Training overrides ----
# For each setting, prefer a key name that already exists in the YAML;
# fall back to a default key name when none of the candidates is present.
_copy_cfg_nc "${TRAIN_CFG}" "${train_dir}/train_config.yaml"

# Emits one "key=value" line per override. Captured with command substitution
# (not `mapfile < <(...)`) because a process substitution's exit status is
# discarded: a Python failure (PyYAML missing, malformed YAML) previously let
# training start with NO seed/data/outdir overrides. Now `set -e` aborts.
kv_text="$(python - "${TRAIN_CFG}" "${SEED}" "${DEVICE}" "${data_path}" "${train_dir}" "${WANDB_ON}" <<'PY'
import sys, yaml

train_cfg_path, seed, device, data_path, outdir, wandb_on = sys.argv[1:7]
with open(train_cfg_path, "r") as fh:
    cfg = yaml.safe_load(fh) or {}

_MISSING = object()

def get_path(d, path):
    """Return the value at dotted `path` inside nested dicts, or _MISSING."""
    cur = d
    for p in path.split("."):
        if not isinstance(cur, dict) or p not in cur:
            return _MISSING
        cur = cur[p]
    return cur

def exists(path):
    # A key present with a null value (e.g. `outdir: null`) still exists --
    # it is a placeholder to fill, so its name must be the one overridden.
    return get_path(cfg, path) is not _MISSING

def pick(cands, default):
    """Return the first candidate key present in the YAML, else `default`."""
    for k in cands:
        if exists(k):
            return k
    return default

wandb_val = "true" if str(wandb_on) == "1" else "false"

pairs = [
    (pick(["seed"], "seed"), seed),
    (pick(["device"], "device"), device),
    (pick(["data", "data_path", "bundle", "bundle_path", "data_bundle", "dataset", "dataset_path"], "data"), data_path),
    (pick(["outdir", "output_dir", "run_dir", "train_dir", "out"], "outdir"), outdir),
    (pick(["wandb", "use_wandb", "wandb_on"], "wandb"), wandb_val),
]

for k, v in pairs:
    print(f"{k}={v}")
PY
)"
mapfile -t KV_OVERRIDES <<< "${kv_text}"

# Turn each "key=value" line into a `--set key=value` argument pair.
SET_ARGS=()
for kv in "${KV_OVERRIDES[@]}"; do
  SET_ARGS+=( --set "${kv}" )
done

# Build the argv once so that the logged command is exactly what runs; log it
# shell-quoted so values containing spaces stay unambiguous.
train_cmd=( python -u train.py --config "${TRAIN_CFG}" "${SET_ARGS[@]}" "$@" )
printf '[run]'
printf ' %q' "${train_cmd[@]}"
printf '\n'
"${train_cmd[@]}"

echo "[done] Master: ${master_dir}"
echo "[done] Data:   ${data_path}"
echo "[done] Train:  ${train_dir}"
