#!/usr/bin/env bash
set -euo pipefail

# Parse the command line into the globals consumed further down the script:
#   R                   - process count handed to torchrun (default 1)
#   OUT                 - all other arguments, forwarded to torchrun verbatim
#   batch_flag_name     - spelling of the batch-size flag seen ("" if absent)
#   batch_value         - raw batch-size value as given on the command line
#   batch_style_equals  - 1 if given as --flag=VAL, 0 if given as --flag VAL
# --R/--r/--nproc_per_node and the batch-size flag are removed from OUT here;
# they are re-emitted later in rewritten form. When a flag is repeated, the
# last occurrence wins. Exits 1 if a value-taking flag is the final argument
# with nothing after it.
parse_args() {
  R=1
  OUT=()
  batch_flag_name=""
  batch_value=""
  batch_style_equals=0

  while (( $# > 0 )); do
    case "$1" in
      --R|--r|--nproc_per_node)
        if (( $# < 2 )); then
          printf 'Error: %s requires a value\n' "$1" >&2
          exit 1
        fi
        R="$2"
        shift 2
        ;;
      --R=*|--r=*|--nproc_per_node=*)
        R="${1#*=}"
        shift
        ;;
      --batch_size|--batch-size)
        if (( $# < 2 )); then
          printf 'Error: %s requires a value\n' "$1" >&2
          exit 1
        fi
        batch_flag_name="$1"
        batch_value="$2"
        batch_style_equals=0
        shift 2
        ;;
      --batch_size=*|--batch-size=*)
        batch_flag_name="${1%%=*}"
        batch_value="${1#*=}"
        batch_style_equals=1
        shift
        ;;
      *)
        OUT+=("$1")
        shift
        ;;
    esac
  done
}

parse_args "$@"

# Rescale the captured batch size and append it to OUT. The user's flag
# spelling and style (--flag=VAL vs --flag VAL) are preserved. Does nothing
# when no batch-size flag was captured.
#
# Formula as implemented: scaled = floor(batch_value * 8 / R), clamped to >= 1
# (e.g. batch 32 with R=4 -> 64).
# NOTE(review): the previous comment here claimed "batch_size * R / 8", the
# inverse of what the code computes. The code's direction is kept; confirm
# which one is intended.
#
# Reads globals: batch_flag_name, batch_value, batch_style_equals, R.
# Appends to:    OUT.
# Exits 1 if batch_value is not a non-negative integer or R is not a positive
# integer.
scale_batch_size() {
  [[ -n "${batch_flag_name:-}" ]] || return 0

  if ! [[ "$batch_value" =~ ^[0-9]+$ ]]; then
    echo "Error: --batch_size must be an integer, got '$batch_value'" >&2
    exit 1
  fi
  # R is only required to be numeric here, where it is used in arithmetic.
  # torchrun's --nproc_per_node may also take non-integer values (e.g. "gpu"
  # or "auto"), and those are fine when no rescaling is needed.
  # Checking first also avoids a division-by-zero abort. It also stops bash
  # arithmetic from evaluating expressions (including $(...) in array
  # subscripts) passed in through R.
  if ! [[ "$R" =~ ^[0-9]+$ ]] || (( 10#$R == 0 )); then
    printf "Error: --nproc_per_node must be a positive integer when --batch_size is given, got '%s'\n" "$R" >&2
    exit 1
  fi

  local scaled
  # 10# forces base 10. Otherwise a value with leading zeros such as "08" is
  # parsed as octal and aborts with "value too great for base".
  scaled=$(( 10#$batch_value * 8 / 10#$R ))
  if (( scaled < 1 )); then
    scaled=1
  fi

  if (( batch_style_equals == 1 )); then
    OUT+=("${batch_flag_name}=${scaled}")
  else
    OUT+=("${batch_flag_name}" "${scaled}")
  fi
}

scale_batch_size

# Hand off to torchrun with the requested process count and the rewritten
# argument list. exec replaces this shell, so torchrun's exit status becomes
# the script's.
# ${OUT[@]+"${OUT[@]}"} expands to nothing when OUT is empty. A bare
# "${OUT[@]}" on an empty array trips `set -u` ("unbound variable") in bash
# releases before 4.4.
exec torchrun --nproc_per_node="${R}" ${OUT[@]+"${OUT[@]}"}
