#!/bin/bash
# -----------------------------------------------------------------------------
# Defaults. Every setting below can be overridden by a command-line flag.
# -----------------------------------------------------------------------------

# Core inputs. NOTE(review): model_path looks like a placeholder — confirm.
model_path="path/to/rm"        # -m
dataset_path="./carb_data"     # -d
lang_subset="all"              # -s
max_length=2048                # -l

# Optional pass-through settings (-c, -t, -a); empty means "not forwarded".
chat_template=""
torch_dtype=""
attn_implementation=""

# Boolean switches, flipped to true by -D and -r respectively.
debug=false
trust_remote_code=false

# Command-line interface: help text and option parsing
#######################################
# Print the command-line help text and exit.
# Globals:   model_path, dataset_path, max_length (read; their current values
#            are shown as the defaults, so a flag given before -h is reflected)
# Arguments: $1 - exit status (optional, default 1). Status 0 (explicit -h)
#            prints to stdout; any other status prints to stderr, since the
#            help is then being shown because of a usage error.
# Outputs:   help text on stdout or stderr
# Returns:   never returns; always exits with the given status
#######################################
usage() {
  local status=${1:-1}
  local fd=2
  if (( status == 0 )); then
    fd=1
  fi
  {
    echo "Usage: $0 [-m MODEL_PATH] [-c CHAT_TEMPLATE] [-t TORCH_DTYPE] [-a ATTN_IMPLEMENTATION] [-d DATASET_PATH] [-s LANG_SUBSET] [-l MAX_LENGTH] [-r] [-D] [-h]"
    echo "  -m MODEL_PATH          Path to the model (default: ${model_path})"
    echo "  -c CHAT_TEMPLATE       Chat template to use"
    echo "  -t TORCH_DTYPE         Torch data type"
    echo "  -a ATTN_IMPLEMENTATION Attention implementation"
    echo "  -d DATASET_PATH        Path to dataset (default: ${dataset_path})"
    echo "  -s LANG_SUBSET         Language subset (default: all)"
    echo "  -l MAX_LENGTH          Maximum length of the input sequence (default: ${max_length})"
    echo "  -r                     Trust remote code"
    echo "  -D                     Enable debug mode"
    echo "  -h                     Display this help message"
  } >&"${fd}"
  exit "${status}"
}

# Parse command-line options into the default globals.
# The leading ':' puts getopts in silent mode, so a missing argument (opt=':')
# and an unknown flag (opt='?') are reported here with a specific message;
# in both cases OPTARG holds the offending option character.
while getopts ":m:c:t:a:d:s:l:rDh" opt; do
  case "${opt}" in
    m) model_path=$OPTARG ;;
    c) chat_template=$OPTARG ;;
    t) torch_dtype=$OPTARG ;;
    a) attn_implementation=$OPTARG ;;
    d) dataset_path=$OPTARG ;;
    s) lang_subset=$OPTARG ;;
    l)
      # Reject anything but a positive integer here rather than passing it
      # through to the Python script.
      if [[ ! "$OPTARG" =~ ^[1-9][0-9]*$ ]]; then
        echo "Error: -l expects a positive integer, got '${OPTARG}'" >&2
        usage
      fi
      max_length=$OPTARG
      ;;
    r) trust_remote_code=true ;;
    D) debug=true ;;
    h) usage 0 ;;
    :)
      echo "Error: option -${OPTARG} requires an argument" >&2
      usage
      ;;
    \?)
      echo "Error: unknown option -${OPTARG}" >&2
      usage
      ;;
  esac
done
shift $((OPTIND - 1))

# getopts stops at the first non-option word, so any leftover argument means
# later flags were silently ignored (e.g. `script foo -m path`); refuse it.
if (( $# > 0 )); then
  echo "Error: unexpected argument(s): $*" >&2
  usage
fi

# Base name of the model path (basename ignores trailing slashes).
# NOTE(review): model_name is not referenced elsewhere in this script —
# confirm it is still needed.
model_name=$(basename -- "${model_path}")

# Count GPUs: nvidia-smi prints one CSV line per device. A missing or failing
# nvidia-smi is reported on stderr instead of being miscounted. The count is
# only displayed; it is not passed to the Python command.
if gpu_names=$(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null); then
  # Count non-empty lines so that empty output yields 0 rather than 1.
  num_gpus=$(printf '%s\n' "${gpu_names}" | grep -c .)
else
  echo "Warning: nvidia-smi unavailable or failed; reporting 0 GPUs" >&2
  num_gpus=0
fi
echo "Detected ${num_gpus} GPUs"

# scripts/run_v2.py is addressed relative to the CARB checkout, so abort if the
# directory is missing instead of running from the wrong place.
# NOTE(review): relative -m/-d paths (including the ./carb_data default) are
# resolved against this directory, not the caller's working directory.
cd path/to/CARB || { echo "Error: cannot cd to path/to/CARB" >&2; exit 1; }

# Build argv as an array so that values containing spaces or shell
# metacharacters are passed literally (no eval re-parsing / injection).
cmd=(
  python scripts/run_v2.py
  "--model=${model_path}"
  "--dataset=${dataset_path}"
  "--lang_subset=${lang_subset}"
  --disable_beaker_save
  "--max_length=${max_length}"
)

# Optional settings are forwarded only when non-empty / enabled.
if [[ -n "$chat_template" ]]; then
  cmd+=("--chat_template=${chat_template}")
fi

if [[ -n "$torch_dtype" ]]; then
  cmd+=("--torch_dtype=${torch_dtype}")
fi

if [[ -n "$attn_implementation" ]]; then
  cmd+=("--attn_implementation=${attn_implementation}")
fi

if [[ "$debug" == true ]]; then
  cmd+=(--debug)
fi

if [[ "$trust_remote_code" == true ]]; then
  cmd+=(--trust_remote_code)
fi

# Log a shell-quoted rendering (copy-pasteable), then run the array directly;
# the script's exit status is the Python process's exit status.
printf -v cmd_str '%q ' "${cmd[@]}"
echo "Running command: ${cmd_str% }"
"${cmd[@]}"