#!/usr/bin/env bash
set -euo pipefail

# Runs eval.py (through torchrun) once for every (task, generation length) pair.
#
# Usage:
#   bash run_eval_all.sh /path/to/model [task1 task2 ...]
#
# Examples:
#   bash run_eval_all.sh /path/to/model
#   bash run_eval_all.sh /path/to/model gsm8k countdown
#
# Default tasks when none are given: math gsm8k countdown sudoku.
#
# Optional env overrides (defaults shown):
#   GPUS="8,9" NPROC=2 MASTER_PORT=22320 OUTPUT_DIR="eval_results" BLOCK_LENGTH=32 GEN_LENGTHS="128 256 512"
#
# NPROC, MASTER_PORT (<= 65535), BLOCK_LENGTH and every whitespace-separated
# entry of GEN_LENGTHS must be positive decimal integers; the script refuses
# to start otherwise.

if [[ $# -lt 1 || -z "$1" ]]; then
  # Diagnostics go to stderr so they are not mixed into captured output.
  printf 'Usage: %s <model_path> [task1 task2 ...]\n' "$0" >&2
  exit 1
fi

MODEL_PATH="$1"
# $# >= 1 was checked above, so this shift cannot fail.
shift

# If tasks were provided positionally, use them; otherwise use the default set.
# DATASETS is kept as a single space-separated string.
if [[ $# -gt 0 ]]; then
  DATASETS="$*"
else
  DATASETS="math gsm8k countdown sudoku"
fi

if [[ "${DATASETS}" =~ ^[[:space:]]*$ ]]; then
  printf 'Error: task list is empty\n' >&2
  exit 1
fi

GPUS="${GPUS:-8,9}"
NPROC="${NPROC:-2}"
MASTER_PORT="${MASTER_PORT:-22320}"
OUTPUT_DIR="${OUTPUT_DIR:-eval_results}"
BLOCK_LENGTH="${BLOCK_LENGTH:-32}"
GEN_LENGTHS="${GEN_LENGTHS:-128 256 512}"

#######################################
# Exit with an error unless a value is a positive decimal integer.
# Arguments: $1 - human-readable name used in the message
#            $2 - value to check
# Outputs:   error message on stderr on failure
#######################################
require_positive_int () {
  local name="$1"
  local value="$2"
  # Strict decimal: no sign, no leading zeros (bash would read those as octal).
  if [[ ! "$value" =~ ^[1-9][0-9]*$ ]]; then
    printf 'Error: %s must be a positive integer, got %q\n' "$name" "$value" >&2
    exit 1
  fi
}

# Validate numeric settings up front so a typo fails immediately instead of
# partway through the sweep, after earlier runs have already completed.
require_positive_int NPROC "${NPROC}"
require_positive_int MASTER_PORT "${MASTER_PORT}"
if (( MASTER_PORT > 65535 )); then
  printf 'Error: MASTER_PORT must be <= 65535, got %s\n' "${MASTER_PORT}" >&2
  exit 1
fi
require_positive_int BLOCK_LENGTH "${BLOCK_LENGTH}"

if [[ "${GEN_LENGTHS}" =~ ^[[:space:]]*$ ]]; then
  printf 'Error: GEN_LENGTHS is empty\n' >&2
  exit 1
fi
# read -a splits on whitespace without glob-expanding the entries.
read -r -a _gen_lengths_to_check <<< "${GEN_LENGTHS}"
for _gl in "${_gen_lengths_to_check[@]}"; do
  require_positive_int "GEN_LENGTHS entry" "${_gl}"
done
unset _gl _gen_lengths_to_check

#######################################
# Pick the per-process batch size for a generation length.
# Longer generations get the smaller batch: 8 for lengths >= 512,
# 16 otherwise.
# Arguments: $1 - generation length (positive decimal integer)
# Outputs:   batch size on stdout
# Returns:   0 on success; 1 (message on stderr) if $1 is not a
#            positive decimal integer
#######################################
batch_size_for_gen_len () {
  local gl="$1"
  # Validate before any arithmetic: bash evaluates numeric-comparison
  # operands as expressions, so a non-numeric value would be silently
  # treated as a variable name (or expression) instead of failing.
  if [[ ! "$gl" =~ ^[1-9][0-9]*$ ]]; then
    printf 'batch_size_for_gen_len: invalid gen length %q\n' "$gl" >&2
    return 1
  fi
  # >= rather than == so lengths above 512 (e.g. 1024) do not receive a
  # larger batch than 512 itself.
  if (( gl >= 512 )); then
    echo 8
  else
    echo 16
  fi
}

#######################################
# Number of diffusion steps for a generation length: half of it, using
# integer division (odd lengths round down).
# Arguments: $1 - generation length (positive decimal integer)
# Outputs:   step count on stdout
# Returns:   0 on success; 1 (message on stderr) if $1 is not a
#            positive decimal integer
#######################################
diff_steps_for_gen_len () {
  local gl="$1"
  # Same validation as batch_size_for_gen_len: $(( )) would otherwise
  # evaluate arbitrary input as an arithmetic expression (e.g. "abc" -> 0).
  if [[ ! "$gl" =~ ^[1-9][0-9]*$ ]]; then
    printf 'diff_steps_for_gen_len: invalid gen length %q\n' "$gl" >&2
    return 1
  fi
  echo $(( gl / 2 ))
}

echo "Model path: ${MODEL_PATH}"
echo "Tasks: ${DATASETS}"
echo "GPUs: ${GPUS} | nproc: ${NPROC} | master_port: ${MASTER_PORT}"
echo "Output dir: ${OUTPUT_DIR} | block_length: ${BLOCK_LENGTH}"
echo "Gen lengths: ${GEN_LENGTHS}"
echo

# Split the space-separated lists into arrays. Word-splitting is intended,
# but pathname expansion is not: an unquoted entry such as `*` would
# otherwise be replaced by file names from the current directory. Disable
# globbing just for these two assignments.
set -f
# shellcheck disable=SC2206  # intentional whitespace splitting, globbing is off
datasets=( ${DATASETS} )
# shellcheck disable=SC2206
gen_lengths=( ${GEN_LENGTHS} )
set +f

# An empty list would otherwise run nothing and still report success.
if (( ${#datasets[@]} == 0 || ${#gen_lengths[@]} == 0 )); then
  printf 'Error: nothing to run (tasks=%q, gen lengths=%q)\n' \
    "${DATASETS}" "${GEN_LENGTHS}" >&2
  exit 1
fi

for dataset in "${datasets[@]}"; do
  for gen_len in "${gen_lengths[@]}"; do
    # Assignment from a failing command substitution aborts under `set -e`.
    bs="$(batch_size_for_gen_len "${gen_len}")"
    ds="$(diff_steps_for_gen_len "${gen_len}")"

    echo "=== Running: dataset=${dataset} gen_length=${gen_len} diffusion_steps=${ds} batch_size=${bs} ==="
    CUDA_VISIBLE_DEVICES="${GPUS}" torchrun \
      --nproc_per_node "${NPROC}" \
      --master_port "${MASTER_PORT}" \
      eval.py \
        --dataset "${dataset}" \
        --batch_size "${bs}" \
        --gen_length "${gen_len}" \
        --block_length "${BLOCK_LENGTH}" \
        --diffusion_steps "${ds}" \
        --output_dir "${OUTPUT_DIR}" \
        --model_path "${MODEL_PATH}"
    echo
  done
done

echo "All runs complete. Results in: ${OUTPUT_DIR}"
