#!/usr/bin/env bash

# Multi-seed launcher for AntFall experiments following the S3 naming scheme.
# Mirrors the tuned hyper-parameters used by the single-run AntFall script and
# enables tagged halfway/final checkpoints for seed 7 within a single training run.

set -euo pipefail

# Seeds trained one after another, each as a full-length run.
SEEDS=(
  7
  23
  42
  73
)

# Tunables. Each may be overridden from the environment; an unset or empty
# variable falls back to the default given here.
: "${GPU:=0}"                  # becomes CUDA_VISIBLE_DEVICES for main.py
: "${MAN_REW:=0.24}"           # forwarded as --man_rew_scale
: "${MAN_NOISE:=0.6}"          # forwarded as --man_noise_sigma
: "${USE_ADJ:=1}"              # 0 adds --disable_adj_net
: "${MAX_TIMESTEPS:=5000000}"  # --max_timesteps for every seed
: "${EVAL_FREQ:=5000}"         # --eval_freq for every seed

# Checkpoint tags handed to the seed-7 run.
: "${HALF_TAG:=half}"
: "${FINAL_TAG:=final}"

# Directories created before any run starts.
: "${MODEL_DIR:=./models}"
: "${LOG_DIR:=./logs}"
ENV_NAME="AntFall"

# Flags shared by every seed's run, assembled one flag/value pair at a time.
COMMON_ARGS=()
COMMON_ARGS+=(--env_name "$ENV_NAME")
COMMON_ARGS+=(--man_rew_scale "$MAN_REW")
# Manager proposal / training cadence.
COMMON_ARGS+=(--manager_propose_freq 10)
COMMON_ARGS+=(--train_manager_freq 10)
# The three --man_ctrl_rew_balance_* settings (start, end, steps).
COMMON_ARGS+=(--man_ctrl_rew_balance_start 0.1)
COMMON_ARGS+=(--man_ctrl_rew_balance_end 0.32)
COMMON_ARGS+=(--man_ctrl_rew_balance_steps 320000)
COMMON_ARGS+=(--man_noise_sigma "$MAN_NOISE")
# Reachability warm-up settings.
COMMON_ARGS+=(--reach_warmup_samples 3000)
COMMON_ARGS+=(--reach_warmup_rounds 1)

# Adjacency-network switch: USE_ADJ of 0 (or 00, 000, ...) appends
# --disable_adj_net to the shared flags; any other non-negative integer
# leaves the network enabled.
# The value is checked with regexes instead of [[ ... -eq 0 ]] because -eq
# evaluates its operands as arithmetic expressions. With -eq, a word such as
# "false" dies with a cryptic "unbound variable" under `set -u`, "08" is
# rejected as bad octal, and a value like 'a[$(cmd)]' would run cmd.
if [[ ! "$USE_ADJ" =~ ^[0-9]+$ ]]; then
  printf '[ERROR] USE_ADJ must be a non-negative integer (0 disables %s), got: %q\n' \
    "the adjacency network" "$USE_ADJ" >&2
  exit 1
fi

if [[ "$USE_ADJ" =~ ^0+$ ]]; then
  COMMON_ARGS+=(--disable_adj_net)
  echo "[INFO] Running without adjacency network."
else
  echo "[INFO] Running with adjacency network enabled."
fi

# NOTE(review): MODEL_DIR and LOG_DIR are created here but this script never
# passes them to main.py. Presumably main.py writes to ./models and ./logs by
# default; confirm the two stay in sync if either variable is overridden.
mkdir -p -- "$MODEL_DIR" "$LOG_DIR"

#######################################
# Launch one training run of main.py.
# Globals:
#   COMMON_ARGS (read) - flags shared by every run, placed first
#   GPU         (read) - set as CUDA_VISIBLE_DEVICES for the python process
# Arguments:
#   $1  - seed; passed as --seed and embedded in the --algo run name
#   $2  - value passed as --max_timesteps
#   $3  - value passed as --eval_freq
#   $4… - optional extra flags, appended verbatim after the standard ones
# Outputs:
#   A "[GPU n] Launching ..." line on stdout, then main.py's own output.
# Returns:
#   The exit status of `python main.py`.
#######################################
run_training() {
  local seed="$1"
  local max_steps="$2"
  local eval_freq="$3"
  shift 3
  local algo="S3_AntFall_seed_${seed}"

  # The remaining positional parameters are forwarded as "$@" instead of
  # being copied into a local array first. Under `set -u`, bash < 4.4
  # (e.g. macOS's bash 3.2) treats "${arr[@]}" on an empty array as an
  # unbound variable, which breaks every call without extra flags.
  # "$@" is always safe.
  local run_args=(
    "${COMMON_ARGS[@]}"
    --algo "$algo"
    --seed "$seed"
    --max_timesteps "$max_steps"
    --eval_freq "$eval_freq"
    "$@"
  )

  echo "[GPU ${GPU}] Launching ${algo} (seed=${seed}, steps=${max_steps}, eval_freq=${eval_freq})."
  CUDA_VISIBLE_DEVICES=$GPU \
  python main.py "${run_args[@]}"
}

# Train every seed sequentially. Only the seed-7 run is given --save_models
# and the halfway/final checkpoint-tag flags; the others get no extra flags.
for seed in "${SEEDS[@]}"; do
  extra_args=()
  if [[ "$seed" -eq 7 ]]; then
    extra_args+=(--save_models --save_halfway_checkpoint --half_checkpoint_tag "$HALF_TAG" --final_checkpoint_tag "$FINAL_TAG")
  fi

  echo "[INFO] === Seed ${seed}: full training run (${MAX_TIMESTEPS} steps). ==="
  # ${arr[@]+"${arr[@]}"} expands to nothing when arr is empty. A bare
  # "${arr[@]}" on an empty array is an "unbound variable" error under
  # `set -u` on bash < 4.4, which would abort before seed 23 trains.
  run_training "$seed" "$MAX_TIMESTEPS" "$EVAL_FREQ" ${extra_args[@]+"${extra_args[@]}"}
done

echo "All S3 AntFall runs finished."
