#!/usr/bin/env bash

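# Flat (single-policy) TD3 sweep over the standard seed set.
#
# Launches one background `python main.py` run per seed on GPU $GPU, keeping
# at most MAX_PARALLEL runs in flight (default: all seeds at once).
# Environment overrides: ENV_NAME, GPU, LOG_DIR, MODEL_DIR, MAX_PARALLEL.
# Extra command-line arguments are forwarded verbatim to every main.py run.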

set -euo pipefail

SEEDS=(7 19 23 42 73)
ENV_NAME=${ENV_NAME:-AntFall}
GPU=${GPU:-0}
LOG_DIR=${LOG_DIR:-./logs_td3}
MODEL_DIR=${MODEL_DIR:-./models_td3}
MAX_PARALLEL=${MAX_PARALLEL:-${#SEEDS[@]}}  # cap concurrency if desired

# MAX_PARALLEL is compared in arithmetic context. There, a non-numeric value
# is evaluated as an expression, and 0 can never admit a run. Reject both up
# front with a clear message instead of failing obscurely later.
if [[ ! "$MAX_PARALLEL" =~ ^[1-9][0-9]*$ ]]; then
  echo "[ERROR] MAX_PARALLEL must be a positive integer, got '${MAX_PARALLEL}'." >&2
  exit 1
fi

readonly SEEDS ENV_NAME GPU LOG_DIR MODEL_DIR MAX_PARALLEL

pids=()  # PIDs of launched runs whose exit status has not been collected yet

#######################################
# Collect the exit status of every launched run that has already exited.
# Each finished PID is passed to `wait` exactly once, so a failed run can
# never be dropped from the list without its status being checked.
# Globals:   pids (read; rewritten to the still-running PIDs)
# Returns:   0; exits the script with status 1 if a finished run failed.
#######################################
prune_finished() {
  local alive=() pid status
  # ${arr[@]+"${arr[@]}"} expands safely under `set -u` even when the array
  # is empty (plain "${arr[@]}" aborts there in bash < 4.4).
  for pid in ${pids[@]+"${pids[@]}"}; do
    if kill -0 "$pid" 2>/dev/null; then
      alive+=("$pid")
      continue
    fi
    # The process is gone; `wait PID` returns its saved exit status.
    status=0
    wait "$pid" || status=$?
    if ((status != 0)); then
      echo "[ERROR] One of the TD3 runs failed." >&2
      exit 1
    fi
  done
  pids=(${alive[@]+"${alive[@]}"})
}

#######################################
# Block until fewer than MAX_PARALLEL launched runs are still running.
# Polls rather than using `wait -n`. `wait -n` reaps one unnamed job per
# call, so runs that finished together could be discarded unchecked, and it
# does not exist before bash 4.3. The runs are long, so a 1 s poll is
# negligible.
# Globals:   pids (read/written via prune_finished), MAX_PARALLEL (read)
# Returns:   0; exits the script with status 1 on a failed run or on a
#            MAX_PARALLEL below 1 (which would otherwise block forever).
#######################################
wait_for_slot() {
  if ((MAX_PARALLEL < 1)); then
    echo "[ERROR] MAX_PARALLEL must be at least 1 (got ${MAX_PARALLEL})." >&2
    exit 1
  fi
  prune_finished
  while ((${#pids[@]} >= MAX_PARALLEL)); do
    sleep 1
    prune_finished
  done
}

# Launch one background training run per seed, respecting the concurrency cap.
for seed in "${SEEDS[@]}"; do
  wait_for_slot

  run_tag="td3_flat_${ENV_NAME}_s${seed}"
  echo "[GPU ${GPU}] Running ${run_tag}"

  # Per-run flags; the script's own arguments ("$@") are appended verbatim.
  run_args=(
    --env_name "$ENV_NAME"
    --algo "$run_tag"
    --seed "$seed"
    --gid "$GPU"
    --log_dir "$LOG_DIR"
    --model_dir "$MODEL_DIR"
    --save_models
  )

  # The subshell inherits errexit/nounset/pipefail from the parent shell.
  (
    CUDA_VISIBLE_DEVICES="$GPU" python main.py "${run_args[@]}" "$@"
  ) &
  pids+=("$!")
done

# Wait for every run still in flight, even after one fails. Exiting at the
# first failure would return while other runs keep going and would hide any
# later failures. The ${arr[@]+...} form stays safe under `set -u` in
# bash < 4.4 when no PIDs remain.
failed=0
for pid in ${pids[@]+"${pids[@]}"}; do
  wait "$pid" || failed=$((failed + 1))
done

if ((failed > 0)); then
  echo "[ERROR] ${failed} TD3 run(s) failed." >&2
  exit 1
fi

echo "All flat TD3 runs finished."
