#!/usr/bin/env bash
# ============================================================================
# cleanup_grpo_inference.sh
# ----------------------------------------------------------------------------
# Standalone utility to discover and (optionally) terminate running
# grpo_inference.py processes (and their spawned children) for the current
# user (or a specified user). Designed to be separate from the orchestrator.
# ----------------------------------------------------------------------------
# Features:
#   * Dry-run mode by default (shows what would be killed)
#   * Recursive tree termination (parents first display, children killed first)
#   * Optional pattern override (default: grpo_inference.py)
#   * Target specific user (default: current $USER)
#   * Choose signal (default SIGTERM), optional forced SIGKILL after wait
#   * Filter to roots-only (avoid duplicate kills) while still removing children
#   * Colorized readable summary (auto-disables if not a TTY)
# ----------------------------------------------------------------------------
# Usage examples:
#   ./cleanup_grpo_inference.sh                  # Dry run summary
#   ./cleanup_grpo_inference.sh --kill           # Actually terminate with SIGTERM then SIGKILL if needed
#   ./cleanup_grpo_inference.sh -p grpo_inference.py --kill
#   ./cleanup_grpo_inference.sh --user someuser --kill
#   ./cleanup_grpo_inference.sh --kill --no-color
#   ./cleanup_grpo_inference.sh --kill --signal SIGINT
#   ./cleanup_grpo_inference.sh --kill --wait-seconds 10
# ----------------------------------------------------------------------------
set -euo pipefail

# Defaults; every value below can be overridden by a command-line flag.
PATTERN="grpo_inference.py"
# $USER is not guaranteed to be set (cron, containers, `env -i`), and under
# `set -u` a bare ${USER} would abort the script. Fall back to the effective
# user name, or to the numeric UID when it has no passwd entry
# (ps -u accepts either form).
TARGET_USER="${USER:-$(id -un 2>/dev/null || id -u)}"
DO_KILL=0
FORCE=1                 # If 1, send SIGKILL to stubborn PIDs after wait
WAIT_SECONDS=5
SIGNAL="SIGTERM"
COLOR=1
ROOTS_ONLY=1            # Identify root processes matching pattern and kill their full trees
VERBOSE=0
ONLY_ORPHANS=0          # If 1, restrict to PPID=1
MAX_PER_CHECKPOINT=0    # If >0 keep at most N newest per checkpoint-<num>
KILL_PARENTS=0          # If 1 also kill wrapper run_grpo_inference*.sh parents

# print_help: write the option summary and exit-code table to stdout.
print_help() {
  cat <<'EOF'
Cleanup script for grpo_inference process trees.

Options:
  -p|--pattern <regex>         Process command regex (default: grpo_inference.py)
  -u|--user <user>             Only consider processes owned by <user> (default: current user)
  --all-users                  Ignore user filter (requires permission + caution)
  --kill                       Perform termination (default is dry run)
  --no-force                   Do not send SIGKILL after waiting
  --force                      (default) escalate to SIGKILL if processes remain
  --signal <SIG>               Primary signal to send (default: SIGTERM)
  --wait-seconds <n>           Seconds to wait before optional SIGKILL (default: 5)
  --only-orphans               Restrict to PPID=1 processes (already detached)
  --max-per-checkpoint <n>     Keep at most N newest processes per checkpoint-*; kill older ones
  --kill-parents               Also terminate ancestor run_grpo_inference*.sh shells
  --no-color                   Disable colored output
  --color                      Force colored output
  --include-nonroots           Attempt to kill every matching PID (not just roots)
  --verbose                    Extra diagnostic logs
  -h|--help                    Show this help

Exit codes:
  0 success (including a completed dry run)
  1 no matching processes, or usage error
EOF
}

# Color helpers.
# Each helper writes an ANSI SGR escape sequence to stdout, or nothing at all
# while COLOR is 0. Colors start disabled when stdout is not a terminal; the
# --color / --no-color flags parsed afterwards can still override that.
[[ -t 1 ]] || COLOR=0

c() {
  (( COLOR )) || return 0
  printf "\033[%sm" "$1"
}

reset() { c 0; }
red() { c 31; }
green() { c 32; }
yellow() { c 33; }
cyan() { c 36; }

# Arg parsing
# require_value <flag> <args-remaining>: abort with a usage error when a flag
# that takes a value is the last argument. Otherwise "$2" would trip `set -u`
# with an unhelpful "unbound variable" message.
require_value() {
  if (( $2 < 2 )); then
    echo "Missing value for $1" >&2
    print_help
    exit 1
  fi
}

while (( $# )); do
  case "$1" in
    -p|--pattern) require_value "$1" "$#"; PATTERN="$2"; shift 2;;
    -u|--user) require_value "$1" "$#"; TARGET_USER="$2"; shift 2;;
    --all-users) TARGET_USER=""; shift;;
    --kill) DO_KILL=1; shift;;
    --no-force) FORCE=0; shift;;
    --force) FORCE=1; shift;;
    --signal) require_value "$1" "$#"; SIGNAL="$2"; shift 2;;
    --wait-seconds) require_value "$1" "$#"; WAIT_SECONDS="$2"; shift 2;;
    --only-orphans) ONLY_ORPHANS=1; shift;;
    --max-per-checkpoint) require_value "$1" "$#"; MAX_PER_CHECKPOINT="$2"; shift 2;;
    --kill-parents) KILL_PARENTS=1; shift;;
    --no-color) COLOR=0; shift;;
    --color) COLOR=1; shift;;
    --include-nonroots) ROOTS_ONLY=0; shift;;
    --verbose) VERBOSE=1; shift;;
    -h|--help) print_help; exit 0;;
    *) echo "Unknown argument: $1" >&2; print_help; exit 1;;
  esac
done

# Validate option values up front, for two reasons:
#  * MAX_PER_CHECKPOINT is later evaluated in arithmetic context, where
#    arbitrary text is parsed as an expression.
#  * An unknown signal name would otherwise be silently swallowed by the
#    `kill -s ... 2>/dev/null || true` used when terminating.
if [[ ! $MAX_PER_CHECKPOINT =~ ^[0-9]+$ ]]; then
  echo "--max-per-checkpoint expects a non-negative integer, got '$MAX_PER_CHECKPOINT'" >&2
  exit 1
fi
if [[ ! $WAIT_SECONDS =~ ^[0-9]+([.][0-9]+)?$ ]]; then
  echo "--wait-seconds expects a non-negative number, got '$WAIT_SECONDS'" >&2
  exit 1
fi
if ! kill -l "$SIGNAL" >/dev/null 2>&1; then
  echo "Unknown signal '$SIGNAL' for --signal" >&2
  exit 1
fi

# Collect processes: one header-less "pid ppid user args" line per process.
# -ww requests unlimited output width, so long command lines (and the
# checkpoint-<n> tokens near their end) are never truncated before matching.
PS_FORMAT="pid=,ppid=,user=,args="
if [[ -n "$TARGET_USER" ]]; then
  MAPFILE_CMD=(ps -ww -u "$TARGET_USER" -o "$PS_FORMAT")
else
  MAPFILE_CMD=(ps -ww -eo "$PS_FORMAT")
fi

if (( VERBOSE )); then echo "[debug] Command: ${MAPFILE_CMD[*]}" >&2; fi

# A failing ps (e.g. unknown user) simply yields no lines; the "no matches"
# check further down reports that case.
mapfile -t LINES < <("${MAPFILE_CMD[@]}") || true

# Parse each ps line into PID-keyed lookup tables, and record in MATCHED_PIDS
# every PID whose command line matches PATTERN.
declare -A PPID_OF
declare -A USER_OF
declare -A CMD_OF

MATCHED_PIDS=()
for line in "${LINES[@]}"; do
  # read splits off the first three whitespace-separated fields and leaves the
  # rest of the line (the command, internal spacing intact) in cmd.
  # Unlike an unquoted `parts=($line)`, it performs no pathname expansion and
  # drops ps's column padding.
  IFS=$' \t' read -r pid ppid user cmd <<< "$line"
  [[ -n $pid && -n $ppid && -n $user ]] || continue
  PPID_OF[$pid]=$ppid
  USER_OF[$pid]=$user
  CMD_OF[$pid]="$cmd"
  # Never target this script's own process: its argv can contain the pattern
  # itself (e.g. `-p grpo_inference.py --kill`).
  if [[ $pid != "$$" && $cmd =~ $PATTERN ]]; then
    MATCHED_PIDS+=("$pid")
  fi
done

# Nothing matched: report on stderr and exit 1, the documented "no matches" status.
(( ${#MATCHED_PIDS[@]} > 0 )) || {
  echo "No processes found matching pattern '$PATTERN'${TARGET_USER:+ for user '$TARGET_USER'}." >&2
  exit 1
}

# Orphan filter (--only-orphans): keep only matches whose parent PID is 1,
# i.e. processes that have been re-parented (typically because their launcher
# has already exited).
if (( ONLY_ORPHANS )); then
  orphan_pids=()
  for candidate in "${MATCHED_PIDS[@]}"; do
    if [[ ${PPID_OF[$candidate]:-} == 1 ]]; then
      orphan_pids+=("$candidate")
    fi
  done
  MATCHED_PIDS=("${orphan_pids[@]}")
  if (( ${#MATCHED_PIDS[@]} == 0 )); then
    echo "No orphan grpo_inference processes (PPID=1)." >&2
    exit 1
  fi
fi

# Extract checkpoint tokens: tag every matched PID with the LAST
# "checkpoint-<digits>" substring of its command line, or "none" when absent.
# Matches are consumed left to right with the shell's own regex engine, which
# avoids forking a grep|tail pipeline for every PID.
declare -A CHECKPOINT_OF
for pid_key in "${MATCHED_PIDS[@]}"; do
  remaining=${CMD_OF[$pid_key]}
  token=
  while [[ $remaining =~ checkpoint-[0-9]+ ]]; do
    token=${BASH_REMATCH[0]}
    # Drop everything up to and including this occurrence, then look again.
    remaining=${remaining#*"$token"}
  done
  CHECKPOINT_OF[$pid_key]=${token:-none}
done

# Enforce per-checkpoint cap (--max-per-checkpoint N).
# Within each checkpoint group, the N newest top-level matches are spared.
# Only the older excess stays in MATCHED_PIDS as termination targets, together
# with any matching descendants of those older processes.
# "Top-level" means a match whose parent is not itself a match, so worker
# children that share their parent's command line are not counted as
# separate runs.
# NOTE(review): "newest" is approximated by the highest PID, which is wrong
# after PID wrap-around; `ps -o etimes=` would give a true age ordering.
if (( MAX_PER_CHECKPOINT > 0 )); then
  declare -A CAP_IS_MATCHED=()
  for p in "${MATCHED_PIDS[@]}"; do CAP_IS_MATCHED[$p]=1; done

  # checkpoint -> " pid pid ...". Deliberately NOT named GROUPS: that is a
  # special bash array (the user's group IDs) which cannot be re-declared as
  # associative or assigned to.
  declare -A CKPT_MEMBERS=()
  for p in "${MATCHED_PIDS[@]}"; do
    parent=${PPID_OF[$p]:-}
    if [[ -n $parent && -n ${CAP_IS_MATCHED[$parent]:-} ]]; then
      continue
    fi
    key=${CHECKPOINT_OF[$p]}
    CKPT_MEMBERS[$key]+=" $p"
  done

  KEEP=()
  DROP=()
  for key in "${!CKPT_MEMBERS[@]}"; do
    IFS=' ' read -ra members <<< "${CKPT_MEMBERS[$key]}"
    # Highest PID first; mapfile avoids touching the global IFS.
    mapfile -t sorted < <(printf '%s\n' "${members[@]}" | sort -nr)
    # A C-style loop counter: a bare `((count++))` from 0 returns status 1 and
    # would abort the script under `set -e`.
    for (( i = 0; i < ${#sorted[@]}; i++ )); do
      if (( i < MAX_PER_CHECKPOINT )); then
        KEEP+=("${sorted[i]}")
      else
        DROP+=("${sorted[i]}")
      fi
    done
  done

  if (( ${#DROP[@]} == 0 )); then
    echo "No checkpoint group exceeds --max-per-checkpoint=$MAX_PER_CHECKPOINT; nothing to terminate." >&2
    exit 1
  fi

  declare -A DROP_SET=()
  for d in "${DROP[@]}"; do DROP_SET[$d]=1; done

  # A match stays a target when it, or one of its matching ancestors, was
  # selected for dropping.
  filtered=()
  for p in "${MATCHED_PIDS[@]}"; do
    cur=$p
    while [[ -n $cur && -n ${CAP_IS_MATCHED[$cur]:-} ]]; do
      if [[ -n ${DROP_SET[$cur]:-} ]]; then
        filtered+=("$p")
        break
      fi
      cur=${PPID_OF[$cur]:-}
    done
  done
  MATCHED_PIDS=("${filtered[@]}")
  if (( VERBOSE )); then
    echo "[debug] Sparing ${#KEEP[@]} newest and targeting ${#DROP[@]} older process tree(s) due to --max-per-checkpoint=$MAX_PER_CHECKPOINT" >&2
  fi
fi

# Determine roots (those whose parent is not also in MATCHED).
# IS_MATCHED is a PID set. An associative array keys it by the PID string,
# rather than using a sparse indexed array whose subscripts are evaluated
# arithmetically.
declare -A IS_MATCHED=()
for p in "${MATCHED_PIDS[@]}"; do IS_MATCHED[$p]=1; done
ROOT_PIDS=()
if (( ROOTS_ONLY )); then
  # A match is a root when its parent is unknown or is not itself a match.
  for p in "${MATCHED_PIDS[@]}"; do
    parent=${PPID_OF[$p]:-}
    if [[ -z "$parent" || -z ${IS_MATCHED[$parent]:-} ]]; then
      ROOT_PIDS+=("$p")
    fi
  done
else
  # --include-nonroots: every match is its own starting point.
  ROOT_PIDS=("${MATCHED_PIDS[@]}")
fi

# Add parent launcher shells if requested (--kill-parents).
# For every root, walk up at most 5 ancestors, stopping at PID 1 or when the
# parent is outside the ps snapshot. The first ancestor whose command line
# mentions a run_grpo_inference*.sh script is added as an extra root.
# The walk uses a C-style for loop on purpose: a bare `((depth++))` from 0
# evaluates to 0, returns status 1, and aborts the script under `set -e`.
if (( KILL_PARENTS )); then
  # Same shape as the documented glob run_grpo_inference*.sh, so the plain
  # run_grpo_inference.sh name matches as well as suffixed variants.
  wrapper_re='run_grpo_inference[^[:space:]]*\.sh'
  add_parents=()
  for p in "${ROOT_PIDS[@]}"; do
    cur=${PPID_OF[$p]:-}
    for (( depth = 0; depth < 5; depth++ )); do
      if [[ -z $cur || $cur == 1 ]]; then
        break
      fi
      if [[ ${CMD_OF[$cur]:-} =~ $wrapper_re ]]; then
        add_parents+=("$cur")
        break
      fi
      cur=${PPID_OF[$cur]:-}
    done
  done
  # ${arr[@]+...} keeps bash < 4.4 from treating an empty array as unset
  # under `set -u`.
  for ap in ${add_parents[@]+"${add_parents[@]}"}; do
    if [[ -z ${IS_MATCHED[$ap]:-} ]]; then
      ROOT_PIDS+=("$ap")
      IS_MATCHED[$ap]=1
    fi
  done
fi

# collect_tree <pid>: print <pid> and, recursively, every descendant of it
# found in the ps snapshot (PPID_OF). Output is one PID per line, each parent
# before its children.
# The whole subtree is returned regardless of ROOTS_ONLY. Roots-only mode only
# limits which matches act as starting points; spawned children that do not
# match the pattern themselves must still be cleaned up.
# This script's own process ($$) and everything below it are never included.
collect_tree() {
  local root=$1
  local child_pid
  echo "$root"
  for child_pid in "${!PPID_OF[@]}"; do
    if [[ $child_pid == "$$" ]]; then
      continue
    fi
    if [[ ${PPID_OF[$child_pid]} == "$root" ]]; then
      collect_tree "$child_pid"
    fi
  done
}

# Build kill lists per root: expand every root into its process tree
# (collect_tree prints one PID per line), then de-duplicate while keeping
# first-seen order. SEEN doubles as the "planned PID" set consulted when
# signalling.
TOTAL_SET=()
for r in "${ROOT_PIDS[@]}"; do
  # -O appends after the current last element instead of clearing the array.
  mapfile -t -O "${#TOTAL_SET[@]}" TOTAL_SET < <(collect_tree "$r")
done

declare -A SEEN
UNIQUE_PIDS=()
for pid_entry in "${TOTAL_SET[@]}"; do
  if [[ -n ${SEEN[$pid_entry]:-} ]]; then
    continue
  fi
  SEEN[$pid_entry]=1
  UNIQUE_PIDS+=("$pid_entry")
done

# Display summary.
# printf '%s' is used instead of `echo -e` so backslashes in a user-supplied
# --pattern (e.g. '\bfoo', or '\c', which would cut the line short) are
# printed literally rather than interpreted as escapes.
echo
if (( DO_KILL )); then
  printf '%s\n' "$(red)Planned termination$(reset): pattern='${PATTERN}' user='${TARGET_USER:-ALL}' roots_only=$ROOTS_ONLY signal=$SIGNAL force=$FORCE wait=${WAIT_SECONDS}s"
else
  printf '%s\n' "$(cyan)Dry run$(reset): pattern='${PATTERN}' user='${TARGET_USER:-ALL}' roots_only=$ROOTS_ONLY"
fi

echo "Root processes (${#ROOT_PIDS[@]}): ${ROOT_PIDS[*]}"
echo "Total processes to act on (${#UNIQUE_PIDS[@]}): ${UNIQUE_PIDS[*]}"

# One row per PID, clipped to 200 characters so long command lines stay
# readable. The row is formatted into a variable and sliced in-shell instead
# of piping every row through `cut`.
printf "\n%-8s %-8s %-16s %-10s %s\n" PID PPID CHECKPOINT USER COMMAND
for p in "${UNIQUE_PIDS[@]}"; do
  printf -v row "%-8s %-8s %-16s %-10s %s" "$p" "${PPID_OF[$p]:-?}" "${CHECKPOINT_OF[$p]:-none}" "${USER_OF[$p]:-?}" "${CMD_OF[$p]:-?}"
  printf '%s\n' "${row:0:200}"
done

if (( DO_KILL == 0 )); then
  printf '\n%s\n' "Use --kill to actually terminate these processes."
  exit 0
fi

# Kill children first for each root (post-order traversal).
# SIGNALED is the set of PIDs already visited and signalled. A PID can be
# reached from more than one root:
#  * --include-nonroots makes every match a root;
#  * --kill-parents adds wrapper shells above existing roots.
# Each process should receive $SIGNAL at most once.
declare -A SIGNALED=()

# kill_tree <pid>: first recurse into the children of <pid> that are part of
# the planned set (SEEN), then send $SIGNAL to <pid> if it still exists.
# Delivery failures (process already gone, EPERM) are deliberately ignored;
# survivors are re-checked after the wait that follows.
kill_tree() {
  local root=$1
  local child
  # Post-order: once a PID is marked, its whole planned subtree was handled.
  if [[ -n ${SIGNALED[$root]:-} ]]; then
    return 0
  fi
  for child in "${!PPID_OF[@]}"; do
    if [[ ${PPID_OF[$child]} == "$root" && -n ${SEEN[$child]:-} ]]; then
      kill_tree "$child"
    fi
  done
  SIGNALED[$root]=1
  if kill -0 "$root" 2>/dev/null; then
    echo "Sending $SIGNAL to $root (${CMD_OF[$root]:0:60})"
    kill -s "$SIGNAL" "$root" 2>/dev/null || true
  fi
}

for r in "${ROOT_PIDS[@]}"; do
  kill_tree "$r"
done

# alive_pids: print each PID from UNIQUE_PIDS that `kill -0` still reports as
# present, one per line.
# NOTE(review): kill -0 also fails for processes we lack permission to signal
# (e.g. with --all-users), so those are reported as gone.
alive_pids() {
  local pid
  for pid in "${UNIQUE_PIDS[@]}"; do
    if kill -0 "$pid" 2>/dev/null; then
      printf '%s\n' "$pid"
    fi
  done
}

# Grace period: poll once per second and stop early as soon as every planned
# process has exited. A non-integer duration (e.g. 0.5) falls back to a single
# sleep.
echo "Waiting $WAIT_SECONDS second(s) before force escalation..."
if [[ $WAIT_SECONDS =~ ^[0-9]+$ ]]; then
  # 10# forces base 10 so a value such as "08" is not parsed as octal.
  for (( waited = 0; waited < 10#$WAIT_SECONDS; waited++ )); do
    if [[ -z $(alive_pids) ]]; then
      break
    fi
    sleep 1
  done
else
  sleep "$WAIT_SECONDS"
fi

mapfile -t STILL < <(alive_pids)

# Escalate if FORCE enabled
if (( FORCE )); then
  if ((${#STILL[@]})); then
    printf '%s\n' "$(yellow)Escalating to SIGKILL for ${#STILL[@]} stubborn process(es)$(reset): ${STILL[*]}"
    kill -KILL "${STILL[@]}" 2>/dev/null || true
  else
    printf '%s\n' "$(green)All processes exited after $SIGNAL$(reset)"
  fi
else
  echo "Force escalation disabled (--no-force)."
  if ((${#STILL[@]})); then
    echo "Still running (${#STILL[@]}): ${STILL[*]}"
  fi
fi

echo "Done."
