#!/usr/bin/env bash
set -Eeuo pipefail

# declare -A (below) needs bash 4, and expanding the (usually empty) EXTRA_ARGS
# array under `set -u` needs bash >= 4.4 -- fail fast with a clear message.
if (( BASH_VERSINFO[0] < 4 || (BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] < 4) )); then
  echo "error: bash >= 4.4 required (running ${BASH_VERSION})" >&2
  exit 1
fi

# ---------- Config ----------
GPUS=(0 1 2 3)                              # GPU IDs; at most one job per GPU
# Unlearning methods to evaluate (e.g. ga npo kto selu stl-selu).
# Override with a space-separated list: METHODS="npo selu" ./this_script.sh
read -r -a METHODS <<< "${METHODS:-npo kto selu stl-selu}"
DATASET_SUBSET="${DATASET_SUBSET:-forget01}"
LORA_PAIRS=("8:32" "16:32" "4:16")          # rank:alpha
LEARNING_RATES=("1e-04")                    # LR sweep

BASE_MODEL="${BASE_MODEL:-open-unlearning/tofu_Llama-2-7b-chat-hf_full}"
# DATA_PATH="${DATA_PATH:-tinyBenchmarks/tinyMMLU}"
DATA_PATH="${DATA_PATH:-tinyBenchmarks/tinyAI2_arc}"

# Extra flags forwarded verbatim to every python invocation. They are taken
# from this script's own command-line arguments (dtype, batch size, etc.).
EXTRA_ARGS=( "$@" )

# PID of the job currently occupying each GPU (0 = slot free)
declare -A GPU_PID
for g in "${GPUS[@]}"; do GPU_PID[$g]=0; done

# Adapter IDs whose evaluation job exited non-zero
failed=()

#######################################
# Start one evaluation run in the background, pinned to a single GPU.
# Globals:   GPU_PID (written: GPU_PID[$1] = PID of the background job)
# Arguments: $1 - GPU id, exported to the job as CUDA_VISIBLE_DEVICES
#            remaining args - passed verbatim to run_tofu_unlearn_eval_utility.py
# Outputs:   a "Launched ..." line on stdout; the background job writes a
#            "FAILED ..." line to stderr if the python run exits non-zero
# Returns:   0. The background job exits with python's own status, so callers
#            can recover it with `wait "${GPU_PID[$1]}"`.
#######################################
launch_job () {
  local gpu="$1"; shift
  (
    # Capture the status instead of swallowing it with `|| echo`, so the
    # subshell's exit code reflects whether the evaluation succeeded.
    rc=0
    CUDA_VISIBLE_DEVICES="$gpu" \
    PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
      python run_tofu_unlearn_eval_utility.py "$@" || rc=$?
    if (( rc != 0 )); then
      printf 'FAILED (exit %d, GPU %s): %s\n' "$rc" "$gpu" "$*" >&2
    fi
    exit "$rc"
  ) &
  GPU_PID[$gpu]=$!
  echo "Launched PID ${GPU_PID[$gpu]} on GPU $gpu: python run_tofu_unlearn_eval_utility.py $*"
}

#######################################
# Block until some GPU slot is free, then print that GPU's id on stdout.
# A slot is free when GPU_PID[gpu] is 0 or its PID no longer exists
# (kill -0 fails). Polls every 0.5 s.
# Globals:   GPUS (read), GPU_PID (read; the finished slot is reset to 0)
# Outputs:   the free GPU id on stdout; diagnostics on stderr
# Returns:   0 once a free slot is found; 1 if GPUS is empty (the loop would
#            otherwise spin forever)
#######################################
wait_one_gpu_free () {
  local g pid
  if (( ${#GPUS[@]} == 0 )); then
    echo "wait_one_gpu_free: GPUS is empty, no slot can ever be free" >&2
    return 1
  fi
  while :; do
    for g in "${GPUS[@]}"; do
      pid=${GPU_PID[$g]}
      if [[ "$pid" == 0 ]]; then
        # already free
        echo "$g"
        return 0
      fi
      if ! kill -0 "$pid" 2>/dev/null; then
        # Process finished. Only the shell that forked the job can reap it.
        # When this function runs inside $(...), BASHPID differs from $$ and a
        # `wait` there could only fail ("not a child of this shell"), so skip it.
        if [[ "$BASHPID" == "$$" ]]; then
          wait "$pid" || true
        fi
        GPU_PID[$g]=0
        echo "$g"
        return 0
      fi
    done
    # no free GPU yet; sleep briefly
    sleep 0.5
  done
}

# ---------- Job scheduling ----------
# Every launched job's PID and the adapter it evaluates (parallel arrays),
# so each job's exit status can be checked once everything has been launched.
job_pids=()
job_names=()

for M in "${METHODS[@]}"; do
  echo ">>> Evaluating method=$M on subset=$DATASET_SUBSET"

  for PAIR in "${LORA_PAIRS[@]}"; do
    IFS=':' read -r RANK ALPHA <<< "$PAIR"

    for LR in "${LEARNING_RATES[@]}"; do
      ADAPTER_ID="ganeric15/msc_unlearn_lora_${RANK}_${ALPHA}_${LR}_${M}_tofu_${DATASET_SUBSET}"

      # Build argv for this job safely as an array (no flattening!)
      args=(
        --model_id "$BASE_MODEL"
        --adapter_id "$ADAPTER_ID"
        --data_path "$DATA_PATH"
        --all_epochs
        "${EXTRA_ARGS[@]}"
      )

      # Find / wait for a free GPU slot
      GPU=$(wait_one_gpu_free)
      echo "---- Scheduling $ADAPTER_ID on GPU $GPU"
      launch_job "$GPU" "${args[@]}"
      job_pids+=("${GPU_PID[$GPU]}")
      job_names+=("$ADAPTER_ID")
    done
  done

  echo ">>> All jobs for method=$M launched."
done

# ---------- Wait for all jobs and collect failures ----------
# The jobs are children of this shell and are not waited on here until now
# (wait_one_gpu_free is called inside $(...)). bash keeps the exit status of
# finished background children, so waiting on an already-finished PID still
# yields that job's status.
for i in "${!job_pids[@]}"; do
  if ! wait "${job_pids[$i]}"; then
    failed+=("${job_names[$i]}")
  fi
done

if (( ${#failed[@]} > 0 )); then
  printf '%d evaluation(s) failed:\n' "${#failed[@]}" >&2
  printf '  %s\n' "${failed[@]}" >&2
  exit 1
fi
echo "All evaluations complete."
