#!/usr/bin/env bash
set -o errexit -o nounset -o pipefail

# Patch the sink-related OmniGibson macros (m.VOLUME_FILL_PROPORTION in
# object_states/filled.py and m.MAX_SOURCE_PARTICLES_PER_STEP in
# object_states/particle_source_or_sink.py) in whichever OmniGibson source
# tree IS-Bench uses. Candidate files come from $OMNIGIBSON_SRC when it is
# set, otherwise from the importable `omnigibson` package, plus any copies
# found beneath the search root below.

# Search root: the top of the enclosing git work tree, or the current
# directory when git is unavailable or we are not inside a repository.
ROOT_DIR="$(pwd)"
if command -v git >/dev/null 2>&1 \
  && git_root="$(git rev-parse --show-toplevel 2>/dev/null)"; then
  ROOT_DIR="$git_root"
fi

declare -a filled_files=()
declare -a source_sink_files=()

# An explicit OMNIGIBSON_SRC pins the OmniGibson tree to patch. Setting it
# also turns off the Python-based auto-detection further down, so a path
# that lacks the expected modules is reported on stderr rather than being
# silently ignored.
if [[ -n "${OMNIGIBSON_SRC:-}" ]]; then
  if [[ -f "$OMNIGIBSON_SRC/omnigibson/object_states/filled.py" ]]; then
    filled_files+=("$OMNIGIBSON_SRC/omnigibson/object_states/filled.py")
  else
    echo "[WARN] OMNIGIBSON_SRC has no omnigibson/object_states/filled.py: $OMNIGIBSON_SRC" >&2
  fi
  if [[ -f "$OMNIGIBSON_SRC/omnigibson/object_states/particle_source_or_sink.py" ]]; then
    source_sink_files+=("$OMNIGIBSON_SRC/omnigibson/object_states/particle_source_or_sink.py")
  else
    echo "[WARN] OMNIGIBSON_SRC has no omnigibson/object_states/particle_source_or_sink.py: $OMNIGIBSON_SRC" >&2
  fi
fi

#######################################
# Ask Python where the importable `omnigibson` package lives.
# Uses `python` when present and falls back to `python3` on systems that no
# longer ship an unversioned interpreter.
# Outputs:
#   The grandparent directory of the package spec's origin file, with
#   symlinks resolved (for a regular package: the directory holding
#   `omnigibson/`). Prints nothing if the package is not found or its spec
#   has no origin.
# Returns:
#   The interpreter's exit status (non-zero e.g. when no Python exists).
#######################################
detect_omnigibson_src() {
  local py=python
  if ! command -v python >/dev/null 2>&1; then
    py=python3
  fi

  "$py" - <<'PY'
import importlib.util
from pathlib import Path

spec = importlib.util.find_spec("omnigibson")
if spec and spec.origin:
    root = Path(spec.origin).resolve().parent.parent
    print(root)
PY
}

# With no explicit OMNIGIBSON_SRC, fall back to wherever Python can import
# `omnigibson` from and queue that tree's modules when they exist. Any
# detection failure (including stderr noise) just skips this step.
if [[ -z "${OMNIGIBSON_SRC:-}" ]] \
  && auto_src="$(detect_omnigibson_src 2>/dev/null)" \
  && [[ -n "$auto_src" ]]; then
  [[ ! -f "$auto_src/omnigibson/object_states/filled.py" ]] \
    || filled_files+=("$auto_src/omnigibson/object_states/filled.py")
  [[ ! -f "$auto_src/omnigibson/object_states/particle_source_or_sink.py" ]] \
    || source_sink_files+=("$auto_src/omnigibson/object_states/particle_source_or_sink.py")
fi

# Also collect every copy of the two modules beneath the search root (e.g. a
# vendored OmniGibson checkout). One find pass feeds both lists, since the
# root may be a large tree. Git metadata directories are pruned because they
# are never a patch target.
while IFS= read -r -d '' f; do
  case "$f" in
    */filled.py) filled_files+=("$f") ;;
    *) source_sink_files+=("$f") ;;
  esac
done < <(
  find "$ROOT_DIR" -name .git -prune -o -type f \
    \( -path "*/omnigibson/object_states/filled.py" \
    -o -path "*/omnigibson/object_states/particle_source_or_sink.py" \) \
    -print0
)

# Nothing to patch from any source: fail loudly and point at the override.
if [[ ${#filled_files[@]} -eq 0 && ${#source_sink_files[@]} -eq 0 ]]; then
  echo "No OmniGibson object_states files found under: $ROOT_DIR" >&2
  echo "Set OMNIGIBSON_SRC to the OmniGibson source tree to patch it explicitly." >&2
  exit 1
fi

#######################################
# Rewrite, in place, every part of a file that matches a regex.
# Arguments:
#   $1 - path of the file to edit
#   $2 - Python `re` pattern, applied with re.MULTILINE (so ^/$ are per line)
#   $3 - replacement; this is a Python re.sub template, so backslashes and
#        group references such as \1 are interpreted
# Outputs (stdout):
#   "[OK] Updated <path>"          the file was changed and written back
#   "[OK] Already set in <path>"   matched, but the text was already identical
#   "[WARN] Pattern not found in <path>"  no match; file left untouched
# Returns:
#   The interpreter's exit status: non-zero if no Python is available or the
#   file cannot be read/written. A missing pattern is only a warning.
#######################################
update_file() {
  local file="$1"
  local pattern="$2"
  local replacement="$3"
  local py=python

  # Prefer `python` (the active environment's interpreter) but fall back to
  # `python3` where no unversioned `python` is installed.
  if ! command -v python >/dev/null 2>&1; then
    py=python3
  fi

  "$py" - "$file" "$pattern" "$replacement" <<'PY'
import re
import sys
from pathlib import Path

path = Path(sys.argv[1])
pattern = sys.argv[2]
replacement = sys.argv[3]

# Explicit UTF-8 so reading/writing source files does not depend on the
# caller's locale settings.
text = path.read_text(encoding="utf-8")
new_text, n = re.subn(pattern, replacement, text, flags=re.MULTILINE)
if n == 0:
    print(f"[WARN] Pattern not found in {path}")
elif new_text == text:
    # Already patched: skip the write so the file is not touched needlessly.
    print(f"[OK] Already set in {path}")
else:
    path.write_text(new_text, encoding="utf-8")
    print(f"[OK] Updated {path}")
PY
}

#######################################
# Print an absolute, symlink-resolved spelling of file path $1 so that one
# file reached via different routes compares equal. Echoes $1 unchanged if
# its directory cannot be entered.
#######################################
canonical_path() {
  local dir
  if dir="$(cd -- "$(dirname -- "$1")" 2>/dev/null && pwd -P)"; then
    printf '%s/%s\n' "$dir" "$(basename -- "$1")"
  else
    printf '%s\n' "$1"
  fi
}

#######################################
# Run update_file with one pattern/replacement over a list of files,
# skipping any file whose canonical path has already been handled.
# Arguments:
#   $1   - regex forwarded to update_file
#   $2   - replacement forwarded to update_file
#   $3.. - candidate files (the list may be empty)
#######################################
patch_unique_files() {
  local pattern="$1"
  local replacement="$2"
  shift 2
  local seen=$'\n'
  local f key
  for f; do
    key="$(canonical_path "$f")"
    case "$seen" in
      *$'\n'"$key"$'\n'*) continue ;;
    esac
    seen+="$key"$'\n'
    update_file "$f" "$pattern" "$replacement"
  done
}

# The same module can be listed more than once (OMNIGIBSON_SRC, Python
# auto-detection and the find scan may all reach it), so patch each file
# once. The ${arr[@]+"${arr[@]}"} form stops `set -u` from aborting on
# bash < 4.4 when exactly one of the two lists is empty.
patch_unique_files "^m\\.VOLUME_FILL_PROPORTION\\s*=\\s*.*$" \
  "m.VOLUME_FILL_PROPORTION = 0.001" \
  ${filled_files[@]+"${filled_files[@]}"}

patch_unique_files "^m\\.MAX_SOURCE_PARTICLES_PER_STEP\\s*=\\s*.*$" \
  "m.MAX_SOURCE_PARTICLES_PER_STEP = 1" \
  ${source_sink_files[@]+"${source_sink_files[@]}"}
