#!/usr/bin/env bash

# Sweep candidate reward-balance schedules for AntMaze variants, with optional
# parallel execution. Jobs are launched in the background and we cap the number
# of concurrent processes via MAX_PARALLEL (defaults to number of GPUs, or 1).

# Strict mode: abort on a failing command, on use of an unset variable, and
# when any stage of a pipeline fails.
set -euo pipefail

# --- Sweep configuration ------------------------------------------------------
ENVS=("AntMaze" "AntMazeSparse")  # each entry is passed as --env_name
SEED=7                            # passed as --seed and embedded in the run name
GPUS=(0)  # edit if multiple GPUs are available

# Candidate schedules: tuples of (start, end, steps). Index i of the three
# arrays below together describes schedule i, so they must stay the same length.
# Alternative candidate set, currently disabled:
# BALANCE_START=(0.05 0.12)
# BALANCE_END=(0.25 0.35)
# BALANCE_STEPS=(200000 400000)
BALANCE_START=(0.0 0.08)
BALANCE_END=(0.18 0.32)
BALANCE_STEPS=(300000 550000)
MANAGER_PERIOD=10  # c timesteps between manager proposals / intrinsic updates

# Concurrency cap. Taken from the environment when set and non-empty, otherwise
# one job per configured GPU.
MAX_PARALLEL=${MAX_PARALLEL:-${#GPUS[@]}}
# Reject non-integers up front. Inside (( ... )) a value such as "2.5" or "08"
# is an arithmetic error that merely evaluates as false, which would silently
# disable the cap instead of stopping the script.
if [[ ! $MAX_PARALLEL =~ ^(0|-?[1-9][0-9]*)$ ]]; then
  echo "[ERROR] MAX_PARALLEL must be an integer, got '${MAX_PARALLEL}'." >&2
  exit 1
fi
# Zero or negative values fall back to running one job at a time.
if (( MAX_PARALLEL <= 0 )); then
  MAX_PARALLEL=1
fi

# PIDs of the background training jobs that are currently being tracked.
pids=()

#######################################
# Drop the PIDs of processes that no longer exist from the tracked job list.
# Globals:   pids (read, then replaced with the still-running subset)
# Arguments: none
# Returns:   0
#######################################
prune_finished() {
  local pid  # keep the loop variable from clobbering a caller's `pid`
  local -a alive=()

  # The ${arr[@]+...} guards matter under `set -u`: bash older than 4.4 treats
  # the expansion of an empty array as an unbound variable.
  for pid in ${pids[@]+"${pids[@]}"}; do
    # Signal 0 delivers nothing; it only reports whether the process exists.
    if kill -0 "$pid" 2>/dev/null; then
      alive+=("$pid")
    fi
  done
  pids=(${alive[@]+"${alive[@]}"})
}

#######################################
# Block until fewer than MAX_PARALLEL jobs are being tracked.
# Globals:   pids (read; pruned via prune_finished), MAX_PARALLEL (read)
# Arguments: none
# Outputs:   an error line on stderr when a job that ended reported failure
# Returns:   0 once a slot is free; exits the script with status 1 when the
#            job reaped by `wait -n` ended with a non-zero status
#######################################
wait_for_slot() {
  while :; do
    # A free slot exists as soon as the tracked count drops below the cap.
    (( ${#pids[@]} >= MAX_PARALLEL )) || return 0

    # `wait -n` blocks until any one background job ends and yields its status.
    wait -n || {
      echo "[ERROR] One of the training jobs failed." >&2
      exit 1
    }

    prune_finished
  done
}

#######################################
# Start one training run in the background and record its PID.
# Globals:   SEED, MANAGER_PERIOD (read); pids (appended to)
# Arguments: $1 - value for CUDA_VISIBLE_DEVICES
#            $2 - environment name (--env_name)
#            $3 - balance schedule start
#            $4 - balance schedule end
#            $5 - balance schedule steps
#            $6 - label embedded in the run name
# Outputs:   one summary line on stdout describing the launched run
#######################################
launch_job() {
  local gpu=$1
  local env=$2
  local start=$3
  local end=$4
  local steps=$5
  local label=$6

  # Run name handed to main.py via --algo.
  local algo
  printf -v algo 'HAWK_%s_s%s_%s_parallel' "$env" "$SEED" "$label"

  printf '[GPU %s] %s start=%s end=%s steps=%s\n' \
    "$gpu" "$algo" "$start" "$end" "$steps"

  (
    set -euo pipefail
    # MANAGER_PERIOD drives both the proposal and the manager-training
    # frequency flags.
    train_args=(
      --env_name "$env"
      --algo "$algo"
      --seed "$SEED"
      --manager_propose_freq "$MANAGER_PERIOD"
      --train_manager_freq "$MANAGER_PERIOD"
      --man_ctrl_rew_balance_start "$start"
      --man_ctrl_rew_balance_end "$end"
      --man_ctrl_rew_balance_steps "$steps"
    )
    CUDA_VISIBLE_DEVICES=$gpu python main.py "${train_args[@]}"
  ) &

  # $! is the PID of the backgrounded subshell started just above.
  pids+=("$!")
}

# Launch one training job per (schedule, environment) pair, assigning GPUs
# round-robin and never exceeding the concurrency cap.
job=0
for idx in "${!BALANCE_START[@]}"; do
  start=${BALANCE_START[$idx]}
  end=${BALANCE_END[$idx]}
  steps=${BALANCE_STEPS[$idx]}
  # 1-based schedule label. This has to be arithmetic expansion: `${idx+1}` is
  # the "alternate value" parameter expansion and always yields the literal
  # "1", which would give every schedule the same run name.
  label="bal$((idx + 1))"

  for env in "${ENVS[@]}"; do
    gpu=${GPUS[$((job % ${#GPUS[@]}))]}
    wait_for_slot
    launch_job "$gpu" "$env" "$start" "$end" "$steps" "$label"
    # Deliberately not `((job++))`: post-incrementing from 0 evaluates to 0,
    # which is exit status 1 and aborts the script under `set -e`.
    job=$((job + 1))
  done
done

# Wait for every job that is still tracked so each failure is reported, then
# exit non-zero if any of them failed.
failed=0
for pid in ${pids[@]+"${pids[@]}"}; do
  if ! wait "$pid"; then
    echo "[ERROR] Training job with PID ${pid} failed." >&2
    failed=1
  fi
done
if (( failed != 0 )); then
  exit 1
fi

echo "All balance sweeps finished."
