#!/usr/bin/env bash
set -euo pipefail

# ---------------------------------------------------------------------------
# Configuration. Every knob may be overridden from the environment; a value
# that is unset or empty falls back to the default shown here.
# ---------------------------------------------------------------------------

# Interpreter, which algorithms to run, and the (whitespace-separated) seeds.
PYTHON="${PYTHON:-python}"
RUN_GEMS="${RUN_GEMS:-1}"
RUN_PSRO="${RUN_PSRO:-1}"
SEEDS="${SEEDS:-0 1 2 3 4}"

# Environment selection, output tag and rendering/episode settings.
ENV_NAME="${ENV_NAME:-simple_tag_v3}"
OUT_TAG="${OUT_TAG:-exp}"
FPS="${FPS:-30}"
MAX_CYCLES="${MAX_CYCLES:-50}"

# Training hyper-parameters.
ROLLOUT_STEPS="${ROLLOUT_STEPS:-600}"
PPO_EPOCHS="${PPO_EPOCHS:-2}"
PPO_BATCH="${PPO_BATCH:-256}"
GAMMA="${GAMMA:-0.99}"
GAE_LAMBDA="${GAE_LAMBDA:-0.95}"
LR="${LR:-3e-4}"
CLIP="${CLIP:-0.2}"
ENT_BETA="${ENT_BETA:-2e-3}"
ITERS="${ITERS:-100}"
DEVICE="${DEVICE:-auto}"
CONTINUOUS="${CONTINUOUS:-0}"

# Passed as --tag_* flags, but only when ENV_NAME is simple_tag_v3.
TAG_RUNNERS="${TAG_RUNNERS:-1}"
TAG_ADVERSARIES="${TAG_ADVERSARIES:-3}"
TAG_OBSTACLES="${TAG_OBSTACLES:-2}"

# Exported to the child Python processes. The cuBLAS setting is presumably
# for reproducible GPU runs (PyTorch determinism guidance) — confirm.
export CUBLAS_WORKSPACE_CONFIG=":4096:8"
export PYTHONUNBUFFERED=1

# Command prefix for the algorithm scripts: unbuffered Python, wrapped in
# `stdbuf -oL -eL` for line-buffered streams when stdbuf is installed.
PYRUN=("${PYTHON}" -u)
if command -v stdbuf >/dev/null 2>&1; then
  PYRUN=(stdbuf -oL -eL "${PYRUN[@]}")
fi

DATE_STAMP="$(date +%Y%m%d_%H%M%S)"
ROOT="runs/${DATE_STAMP}_${OUT_TAG}"
mkdir -p "$ROOT"

# Record the effective configuration next to the results so that a run
# directory is self-describing. FPS (passed to both scripts via --fps),
# OUT_TAG and the RUN_* switches were previously missing from this record.
# NOTE(review): ENT_BETA is recorded here but is not among the flags built in
# COMMON_FLAGS, so the scripts may not actually use this value — confirm.
{
  echo "# $(date)"
  echo "PYTHON=${PYTHON}"
  echo "RUN_GEMS=${RUN_GEMS}"
  echo "RUN_PSRO=${RUN_PSRO}"
  echo "OUT_TAG=${OUT_TAG}"
  echo "SEEDS=${SEEDS}"
  echo "ENV_NAME=${ENV_NAME}"
  echo "ITERS=${ITERS}"
  echo "MAX_CYCLES=${MAX_CYCLES}"
  echo "FPS=${FPS}"
  echo "ROLLOUT_STEPS=${ROLLOUT_STEPS}"
  echo "PPO_EPOCHS=${PPO_EPOCHS}"
  echo "PPO_BATCH=${PPO_BATCH}"
  echo "GAMMA=${GAMMA}"
  echo "GAE_LAMBDA=${GAE_LAMBDA}"
  echo "LR=${LR}"
  echo "CLIP=${CLIP}"
  echo "ENT_BETA=${ENT_BETA}"
  echo "DEVICE=${DEVICE}"
  echo "CONTINUOUS=${CONTINUOUS}"
  echo "TAG_RUNNERS=${TAG_RUNNERS} TAG_ADVERSARIES=${TAG_ADVERSARIES} TAG_OBSTACLES=${TAG_OBSTACLES}"
  echo "CUBLAS_WORKSPACE_CONFIG=${CUBLAS_WORKSPACE_CONFIG:-}"
  echo "PYTHONUNBUFFERED=${PYTHONUNBUFFERED:-}"
} > "${ROOT}/_run_meta.txt"

# Flags shared by every algorithm script, one flag/value pair per line.
# NOTE(review): ENT_BETA has a default and is written to _run_meta.txt but is
# not forwarded here, so the scripts presumably fall back to their own
# entropy setting — confirm the flag name gems.py/psro.py expect before
# adding it.
COMMON_FLAGS=(
  --env "${ENV_NAME}"
  --max_cycles "${MAX_CYCLES}"
  --rollout_min_steps "${ROLLOUT_STEPS}"
  --ppo_epochs "${PPO_EPOCHS}"
  --ppo_batch "${PPO_BATCH}"
  --gamma "${GAMMA}"
  --gae_lambda "${GAE_LAMBDA}"
  --lr "${LR}"
  --clip "${CLIP}"
  --fps "${FPS}"
  --iters "${ITERS}"
  --device "${DEVICE}"
)
case "${CONTINUOUS}" in
  1) COMMON_FLAGS+=(--continuous_actions) ;;
esac

# Environment-specific flags, appended only for simple_tag_v3 runs.
TAG_FLAGS=(
  --tag_runners "${TAG_RUNNERS}"
  --tag_adversaries "${TAG_ADVERSARIES}"
  --tag_obstacles "${TAG_OBSTACLES}"
)

# On SIGINT, signal any background jobs of this shell. The handler itself
# does not exit. NOTE(review): the Python runs below are foreground
# pipelines, so `jobs -p` is normally empty here.
trap 'jobs -p | xargs -r kill' INT

#######################################
# Merge the per-seed CSVs of one algorithm into a single CSV that has a
# leading "seed" column.
# Arguments:
#   $1 - algorithm directory containing seed_<N>/ subdirectories
#   $2 - combined CSV path; always (re)created, left empty if no CSV exists
# Outputs:
#   Prints the combined CSV path on stdout when at least one CSV was found.
# Details:
#   Inputs are <algo_dir>/seed_*/*.csv and <algo_dir>/seed_*/*/*.csv,
#   ordered by numeric seed and then by path. The header row is taken from
#   the first non-empty CSV. Every CSV's first line is treated as a header
#   and skipped, and zero-byte CSVs are ignored. A seed directory whose name
#   does not start with seed_<digits> gets the seed value "NA".
#######################################
combine_csvs() {
  local algo_dir="$1"
  local combined="$2"
  local first=1 unset_nullglob=0
  local seed csv rel
  local -a csvs=()

  # Enable nullglob for this expansion only, and restore the caller's setting.
  if ! shopt -q nullglob; then
    unset_nullglob=1
    shopt -s nullglob
  fi
  # Use one explicit "*/" level instead of "**/". Without globstar, "**"
  # already behaves like "*"; with globstar, "**/" would also match zero
  # levels and list every seed_*/*.csv twice.
  csvs=("$algo_dir"/seed_*/*/*.csv "$algo_dir"/seed_*/*.csv)
  if (( unset_nullglob )); then
    shopt -u nullglob
  fi

  : > "${combined}"
  if (( ${#csvs[@]} == 0 )); then
    return 0
  fi

  # The producer below emits "seed<TAB>path" and sorts numerically on the
  # seed. Sorting the raw paths with `sort -t_ -k2,2n` keyed on the text
  # after the first "_" of the whole path (the run timestamp). That key is
  # identical for every file, so seeds came out in lexical order: 0, 1, 10, 2.
  while IFS=$'\t' read -r seed csv; do
    # A zero-byte file would otherwise leave a dangling "seed," header.
    [[ -s "${csv}" ]] || continue
    if (( first )); then
      printf 'seed,' > "${combined}"
      head -n1 "${csv}" >> "${combined}"
      first=0
    fi
    tail -n +2 "${csv}" | awk -v s="${seed}" 'BEGIN{OFS=","} {print s,$0}' >> "${combined}"
  done < <(
    for csv in "${csvs[@]}"; do
      # Parse the seed from the directory directly under algo_dir, using a
      # bash regex instead of GNU-only `sed '\+'`.
      rel="${csv#"${algo_dir}"/}"
      if [[ "${rel}" =~ ^seed_([0-9]+) ]]; then
        seed="${BASH_REMATCH[1]}"
      else
        seed="NA"
      fi
      printf '%s\t%s\n' "${seed}" "${csv}"
    done | LC_ALL=C sort -t $'\t' -k1,1n -k2,2
  )
  echo "${combined}"
}

#######################################
# Run one algorithm script once per seed, then merge its per-seed CSVs.
# Globals (read): ROOT, SEEDS, ENV_NAME, PYRUN, COMMON_FLAGS, TAG_FLAGS
# Arguments:
#   $1    - algorithm name; runs ./<name>.py and writes under ${ROOT}/<name>/
#   $2... - extra flags, placed after --seed and before the env-specific flags
# Outputs:
#   ${ROOT}/<name>/seed_<N>/<name>_seed<N>.{csv,gif,log} for each seed, and
#   ${ROOT}/<name>_combined.csv. The script's stdout and stderr are shown and
#   also tee'd into the per-seed log.
#######################################
run_algo() {
  local algo="$1"
  shift
  local algo_dir="${ROOT}/${algo}"
  local seed stem
  local -a env_flags=()

  # The --tag_* flags apply only to simple_tag_v3. Holding them in an array
  # passes each value as exactly one argument; the previous unquoted
  # $( ... && echo ...) re-split and glob-expanded them.
  if [[ "${ENV_NAME}" == "simple_tag_v3" ]]; then
    env_flags=("${TAG_FLAGS[@]}")
  fi

  mkdir -p "${algo_dir}"
  # SEEDS is a whitespace-separated list, so word-splitting is intended here.
  # shellcheck disable=SC2086
  for seed in ${SEEDS}; do
    stem="${algo_dir}/seed_${seed}/${algo}_seed${seed}"
    mkdir -p "${algo_dir}/seed_${seed}"
    # The ${1+...} / ${arr[@]+...} forms keep empty lists safe under `set -u`
    # on older bash releases.
    "${PYRUN[@]}" "./${algo}.py" "${COMMON_FLAGS[@]}" \
      --csv "${stem}.csv" --video "${stem}.gif" --seed "${seed}" \
      ${1+"$@"} ${env_flags[@]+"${env_flags[@]}"} 2>&1 | tee "${stem}.log"
  done
  combine_csvs "${algo_dir}" "${ROOT}/${algo}_combined.csv" >/dev/null
}

if [[ "${RUN_GEMS}" == "1" ]]; then
  run_algo gems
fi

if [[ "${RUN_PSRO}" == "1" ]]; then
  run_algo psro --eval_episodes 1
fi

echo "${ROOT}"
