#!/usr/bin/env bash
set -euo pipefail

# 1. Capture Positional Arguments
#   $1 MASTER_TAG  run group directory under RUNS_DIR     (default: interp)
#   $2 DATA_CFG    data-generation YAML                   (default: configs/data/eb.yaml)
#   $3 TRAIN_CFG   training YAML; no default, must be given
#   $4 RUN_TAG     optional name that replaces the TRAIN_CFG-derived tag
#   $5 SEED        seed passed to data generation and training (default: 0)
# Empty arguments are treated like missing ones (`:-` semantics).
MASTER_TAG="${1:-interp}"
DATA_CFG="${2:-configs/data/eb.yaml}"
TRAIN_CFG="${3:-}"
RUN_TAG="${4:-}"
SEED="${5:-0}"

# 2. Environment Setup
# Each knob may be overridden from the environment; `:=` fills in the default
# when the variable is unset or empty.
: "${RUNS_DIR:=jan_10_dropout_0.1_4heads}"  # root directory for all outputs
: "${DEVICE:=cuda}"                         # device handed to train.py
: "${REGEN_DATA:=0}"                        # "1" rebuilds bundle.pt even if present
: "${WANDB_ON:=1}"                          # "1" -> wandb=true override for train.py

# 3. Validation
# TRAIN_CFG ($3) has no default, so report a missing argument explicitly
# instead of a confusing "not found" on an empty path.
if [[ -z "${TRAIN_CFG}" ]]; then
  echo "ERROR: TRAIN_CFG (argument 3) is required" >&2
  echo "usage: ${0##*/} [MASTER_TAG] [DATA_CFG] TRAIN_CFG [RUN_TAG] [SEED] [extra train.py args...]" >&2
  exit 1
fi
[[ -f "${DATA_CFG}"  ]] || { echo "ERROR: DATA_CFG not found: ${DATA_CFG}"  >&2; exit 1; }
[[ -f "${TRAIN_CFG}" ]] || { echo "ERROR: TRAIN_CFG not found: ${TRAIN_CFG}" >&2; exit 1; }

# 4. Directory & Tag Setup
# Print a config's "stem": its file name with the directory and the final
# extension removed. The directory is stripped *before* the extension so a dot
# in a directory name (e.g. configs/v1.2/train) cannot truncate the stem.
_cfg_stem() {
  local name
  name="$(basename "$1")"
  printf '%s\n' "${name%.*}"
}

# train_tag: RUN_TAG when given, otherwise the TRAIN_CFG stem; always suffixed
# with the seed so different seeds get separate training directories.
if [[ -n "${RUN_TAG}" ]]; then
    train_tag="${RUN_TAG}_seed${SEED}"
else
    train_base="$(_cfg_stem "${TRAIN_CFG}")"
    train_tag="${train_base}_seed${SEED}"
fi

# data_tag depends only on the data config and seed, so training runs that
# share both also share one data directory.
data_base="$(_cfg_stem "${DATA_CFG}")"
data_tag="${data_base}_seed${SEED}"

master_dir="${RUNS_DIR}/${MASTER_TAG}"
data_dir="${master_dir}/data/${data_tag}"
train_dir="${master_dir}/train/${train_tag}"

# Check for existing run to prevent overwriting: a final.pt in the training
# directory marks the run as finished, so stop here with success.
if [[ -f "${train_dir}/final.pt" ]]; then
    printf '%s\n' "[run] SKIP: Run already complete at ${train_dir}/final.pt"
    exit 0
fi

# Output locations: the data bundle lives under data_dir, separate from the
# per-run training directory; both are created up front.
data_path="${data_dir}/bundle.pt"
mkdir -p "${data_dir}" "${train_dir}"

# Snapshot a config file into a run directory without clobbering.
# Arguments: $1 - source config path, $2 - destination path.
# An existing destination (file or symlink) is left untouched. The copy is
# best-effort: a failure is reported on stderr but never aborts the run.
# Returns: always 0.
_copy_cfg_nc() {
  local src="$1"
  local dst="$2"
  if [[ -e "${dst}" || -L "${dst}" ]]; then
    return 0
  fi
  if ! cp -- "${src}" "${dst}"; then
    echo "[run] WARN: could not snapshot ${src} -> ${dst}" >&2
  fi
  return 0
}

# 5. Data Generation
# Read the optional top-level `mode` key from DATA_CFG; a missing, null or
# empty key yields "default". An unreadable/unparsable config also falls back
# to "default", with a warning on stderr. The yaml import sits outside the try,
# so a missing PyYAML still fails the command substitution (and `set -e`).
MODE="$(python - "${DATA_CFG}" <<'PY'
import sys, yaml
try:
    with open(sys.argv[1], "r") as fh:
        cfg = yaml.safe_load(fh) or {}
    mode = cfg.get("mode") if isinstance(cfg, dict) else None
    print(mode or "default")
except Exception as exc:
    print("[run] WARN: could not read mode from %s: %s" % (sys.argv[1], exc), file=sys.stderr)
    print("default")
PY
)"

_copy_cfg_nc "${DATA_CFG}" "${data_dir}/data_config.yaml"

# Build the data bundle unless it already exists; REGEN_DATA=1 forces a
# rebuild. Generation always runs on CPU with the run's seed.
if [[ -f "${data_path}" && "${REGEN_DATA}" != "1" ]]; then
  echo "[run] data exists: ${data_path}"
else
  echo "[run] generating data -> ${data_path} (mode=${MODE})"
  if [[ "${MODE}" == "potential_sde" ]]; then
    # potential_sde receives the config's `argv` list as raw CLI tokens, one
    # per output line. A plain command substitution is used instead of
    # `mapfile < <(...)` so a python/YAML failure aborts under `set -e`
    # instead of silently producing zero tokens.
    argv_text="$(python - "${DATA_CFG}" <<'PY'
import sys, yaml
with open(sys.argv[1], "r") as fh:
    cfg = yaml.safe_load(fh) or {}
for a in cfg.get("argv") or []:
    print(str(a))
PY
)"
    ARGV=()
    if [[ -n "${argv_text}" ]]; then
      mapfile -t ARGV <<< "${argv_text}"
    fi
    # ${ARGV[@]+...} expands to nothing for an empty array; a bare
    # "${ARGV[@]}" is an unbound-variable error under `set -u` on bash < 4.4.
    python -u data_generator.py --out "${data_path}" --seed "${SEED}" --device cpu potential_sde ${ARGV[@]+"${ARGV[@]}"}
  else
    python -u data_generator.py --config "${DATA_CFG}" --set seed="${SEED}" --set device=cpu --set out="${data_path}"
  fi
fi

# 6. Training Configuration
_copy_cfg_nc "${TRAIN_CFG}" "${train_dir}/train_config.yaml"

# Build `key=value` overrides for train.py's --set:
#   - seed and device are always set;
#   - for the data path, output dir and wandb flag, the key spelling TRAIN_CFG
#     already uses is reused (first match in each candidate list), falling
#     back to data / outdir / wandb;
#   - WANDB_ON=1 maps to "true", anything else to "false".
# Captured with a command substitution rather than `mapfile < <(...)` so a
# python/YAML failure aborts under `set -e` instead of launching training
# without these overrides.
kv_text="$(python - "${TRAIN_CFG}" "${SEED}" "${DEVICE}" "${data_path}" "${train_dir}" "${WANDB_ON}" <<'PY'
import sys, yaml

train_cfg_path = sys.argv[1]
seed, device, data_path, outdir, wandb_on = sys.argv[2:7]
wandb_val = "true" if str(wandb_on) == "1" else "false"

with open(train_cfg_path, "r") as fh:
    cfg = yaml.safe_load(fh) or {}
if not isinstance(cfg, dict):
    cfg = {}

def pick_key(d, keys, default):
    for k in keys:
        if k in d: return k
    return default

pairs = []
pairs.append(("seed", seed))
pairs.append(("device", device))

data_key = pick_key(cfg, ["data", "data_path", "dataset"], "data")
out_key  = pick_key(cfg, ["outdir", "output_dir", "run_dir"], "outdir")
wb_key   = pick_key(cfg, ["wandb", "use_wandb"], "wandb")

pairs.append((data_key, data_path))
pairs.append((out_key, outdir))
pairs.append((wb_key, wandb_val))

for k, v in pairs:
    print(f"{k}={v}")
PY
)"
KV_OVERRIDES=()
if [[ -n "${kv_text}" ]]; then
  mapfile -t KV_OVERRIDES <<< "${kv_text}"
fi

SET_ARGS=()
for kv in ${KV_OVERRIDES[@]+"${KV_OVERRIDES[@]}"}; do
  SET_ARGS+=( --set "${kv}" )
done

# 7. Execution
# Discard the five positional arguments consumed above; whatever follows is
# forwarded verbatim to train.py. A bare `shift 5` fails *without shifting*
# when fewer than five arguments were given. MASTER_TAG, DATA_CFG,
# TRAIN_CFG, ... would then reach train.py as "extra" args. So shift min($#, 5).
shift "$(( $# < 5 ? $# : 5 ))"

WANDB_NAME="${train_tag}"

echo "============================================================"
echo "[run] Training directory: ${train_dir}"
echo "[run] Data path: ${data_path}"
echo "[run] Wandb name: ${WANDB_NAME}"
echo "[run] Extra args: $*"
echo "============================================================"

# ${SET_ARGS[@]+...} keeps an empty array from tripping `set -u` on bash < 4.4.
python -u train.py \
    --config "${TRAIN_CFG}" \
    ${SET_ARGS[@]+"${SET_ARGS[@]}"} \
    --wandb-name "${WANDB_NAME}" \
    "$@"

echo "[run] Done. Output: ${train_dir}"