#!/bin/bash
# Exported so that every child process started by this script inherits it.
# NOTE(review): presumably lets vLLM accept a max model length above the
# model's configured limit; confirm against the vLLM documentation.
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1

# ====== USER CONFIGURATION ======
# Edit the values below to choose the model, server port, persona, dataset
# and inference mode for this run.

# Model name passed to the inference script.
MODEL_1='Qwen/Qwen2.5-1.5B-Instruct'

# Port of the vLLM server.
PORT=8000

# Persona: ethical, helper, or None (no persona).
USE_PERSONA='None'

# Dataset: anthropic or multijail.
DATASET='anthropic'

# Language (en or ko); only used when DATASET=multijail.
LANGUAGE='en'

# Inference mode switches (true or false). At most one may be true.
USE_SELF_REFINEMENT=true
USE_SELF_CONSISTENCY=false
# =================================

# Self-refinement and self-consistency are mutually exclusive: abort with
# exit status 1 when both switches are set to the literal string "true".
if [[ "$USE_SELF_REFINEMENT" == true && "$USE_SELF_CONSISTENCY" == true ]]; then
  echo "❌ Both self_refinement and self_consistency cannot be true at the same time."
  exit 1
fi

# Derive the persona option. Only the two recognised personas produce a
# flag; any other value (including "None") leaves PERSONA_FLAG empty.
PERSONA_FLAG=""
case "$USE_PERSONA" in
  ethical | helper)
    PERSONA_FLAG="--persona $USE_PERSONA"
    ;;
esac

# Derive the mode options. Each mode contributes its switch plus its count
# option (--round for self-refinement, --num_samples for self-consistency);
# a mode that is not exactly "true" contributes nothing.
SELF_REFINE_FLAG=""
ROUND=""
case "$USE_SELF_REFINEMENT" in
  true)
    SELF_REFINE_FLAG="--self_refinement"
    ROUND="--round 16"
    ;;
esac

SELF_CONSIST_FLAG=""
NUM_SAMPLES=""
case "$USE_SELF_CONSISTENCY" in
  true)
    SELF_CONSIST_FLAG="--self_consistency"
    NUM_SAMPLES="--num_samples 16"
    ;;
esac

# Derive the dataset options and the query count for the chosen dataset.
# The language option is only emitted for multijail; an unrecognised
# dataset name aborts the script with exit status 1.
case "$DATASET" in
  anthropic)
    NUM_QUERY=37
    DATASET_FLAG="--dataset anthropic"
    LANGUAGE_FLAG=""
    ;;
  multijail)
    NUM_QUERY=50
    DATASET_FLAG="--dataset multijail"
    LANGUAGE_FLAG="--language $LANGUAGE"
    ;;
  *)
    echo "❌ Unknown dataset: $DATASET"
    exit 1
    ;;
esac


# Build inference command
#
# build_inference_cmd: fill the global array CMD_ARGS with the argv of the
# inference run.
# Globals read:    MODEL_1, NUM_QUERY, PORT, DATASET_FLAG, LANGUAGE_FLAG,
#                  SELF_REFINE_FLAG, SELF_CONSIST_FLAG, ROUND, NUM_SAMPLES,
#                  PERSONA_FLAG
# Globals written: CMD_ARGS
#
# The *_FLAG / ROUND / NUM_SAMPLES variables each hold zero or more
# whitespace-separated words, so they are split with `read -a`, which
# word-splits without glob expansion. MODEL_1, NUM_QUERY and PORT are
# passed as single, quoted arguments. Keeping the command as an array lets
# it be executed directly, without `eval` re-parsing the values as shell code.
build_inference_cmd() {
  local -a optional_flags persona_args
  read -r -a optional_flags <<< "$DATASET_FLAG $LANGUAGE_FLAG $SELF_REFINE_FLAG $SELF_CONSIST_FLAG $ROUND $NUM_SAMPLES"
  read -r -a persona_args <<< "$PERSONA_FLAG"

  CMD_ARGS=(
    python src/inference_self_safety.py
    --model_1 "$MODEL_1"
    --num_query "$NUM_QUERY"
    "${optional_flags[@]}"
    --port "$PORT"
    "${persona_args[@]}"
  )
}

build_inference_cmd

# Run inference
# CMD is the space-joined form of the argv, kept only for the log line.
CMD="${CMD_ARGS[*]}"
echo "Running: $CMD"
"${CMD_ARGS[@]}"

# Kill vLLM server
#
# Find every process whose command line is the vLLM OpenAI API server
# started with exactly "--port $PORT", then stop it. Requires pgrep (procps).
#
# The pattern is an extended regex matched against the full command line
# (pgrep -f). The trailing "( |$)" anchors the port number so that, for
# example, PORT=800 does not also match a server listening on 8000.
SERVER_PATTERN="python -m vllm\.entrypoints\.openai\.api_server( .*)? --port ${PORT}( |\$)"
mapfile -t SERVER_PIDS < <(pgrep -f -- "$SERVER_PATTERN")

if (( ${#SERVER_PIDS[@]} > 0 )); then
  # Ask politely first so the server can shut down on its own.
  kill "${SERVER_PIDS[@]}" 2>/dev/null

  # Give all matched processes a shared grace period of 10 seconds, then
  # force-kill whatever is still alive. `kill -0` only probes for existence.
  GRACE_DEADLINE=$(( SECONDS + 10 ))
  for SERVER_PID in "${SERVER_PIDS[@]}"; do
    while kill -0 "$SERVER_PID" 2>/dev/null && (( SECONDS < GRACE_DEADLINE )); do
      sleep 1
    done
    if kill -0 "$SERVER_PID" 2>/dev/null; then
      kill -9 "$SERVER_PID" 2>/dev/null
    fi
  done

  # Space-joined PID list, kept in PID for the log line.
  PID="${SERVER_PIDS[*]}"
  echo "Server on port $PORT stopped (PID: $PID)"
else
  PID=""
  echo "Server on port $PORT not running"
fi
