#!/bin/bash
# Exported so that child processes started by this script inherit the setting.
# NOTE(review): presumably read by vLLM to permit a max model length larger
# than the model's own default - confirm against the vLLM documentation.
VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN

#########options#########
# model set : Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B-Instruct
# model set : meta-llama/Llama-3.1-8B-Instruct, google/gemma-2-9b-it
# dataset set : anthropic, multijail

# Persona mode: when USE_PERSONA is exactly "true", PERSONA_FLAG carries the
# --persona option for the inference command; otherwise it is left empty.
USE_PERSONA=true
case "$USE_PERSONA" in
    true) PERSONA_FLAG="--persona" ;;
    *)    PERSONA_FLAG="" ;;
esac

# How many agents take part in a run (2, 3, or 4).
NUM_AGENTS=2

# Model assigned to each agent. Every slot starts from the same default;
# replace the right-hand side of any slot to give that agent another model.
DEFAULT_MODEL="Qwen/Qwen2.5-1.5B-Instruct"
MODEL_1="$DEFAULT_MODEL"
MODEL_2="$DEFAULT_MODEL"
MODEL_3="$DEFAULT_MODEL"  # Only used if NUM_AGENTS >= 3
MODEL_4="$DEFAULT_MODEL"  # Only used if NUM_AGENTS == 4

# Dataset to run: "anthropic" or "multijail".
DATASET="anthropic"  # or "multijail"
# Language code for multijail (en or ko).
# Deliberately NOT named LANG: LANG is the locale environment variable, and
# assigning it here would hand an invalid locale ("en") to bash and to every
# child process. The multijail run further down iterates over both en and ko
# itself, so this value is currently informational only.
DATA_LANG="en"

# Number of queries depends on the dataset; an unrecognised dataset name
# aborts the script with a diagnostic on stderr.
case "$DATASET" in
    anthropic)
        NUM_QUERY=37
        ;;
    multijail)
        NUM_QUERY=50
        ;;
    *)
        echo "Unknown dataset: $DATASET" >&2
        exit 1
        ;;
esac

# Number of rounds: 8 by default, reduced to 4 when exactly four agents run.
ROUND=8
if [ "$NUM_AGENTS" -eq 4 ]; then
    ROUND=4
fi

# Per-agent model options, kept as an array so that each value reaches the
# Python script as exactly one argument (no word-splitting or globbing).
# --model_3 / --model_4 are only passed when that many agents are configured.
MODEL_ARGS=(--model_1 "$MODEL_1" --model_2 "$MODEL_2")
if [ "$NUM_AGENTS" -ge 3 ]; then
    MODEL_ARGS+=(--model_3 "$MODEL_3")
fi
if [ "$NUM_AGENTS" -ge 4 ]; then
    MODEL_ARGS+=(--model_4 "$MODEL_4")
fi

#######################################
# Run src/inference_multi_safety.py once with the configured settings.
# Globals:   NUM_AGENTS, MODEL_ARGS, ROUND, NUM_QUERY, DATASET,
#            PERSONA_FLAG (all read)
# Arguments: extra options inserted right after --dataset
#            (e.g. --language ko); may be empty
# Returns:   exit status of the python command
#######################################
run_inference() {
    # ${PERSONA_FLAG:+...} expands to nothing when the flag is empty, so no
    # empty-string argument is passed when persona mode is off.
    python src/inference_multi_safety.py --num_agents "$NUM_AGENTS" \
        "${MODEL_ARGS[@]}" \
        --round "$ROUND" --num_query "$NUM_QUERY" --dataset "$DATASET" "$@" \
        ${PERSONA_FLAG:+"$PERSONA_FLAG"} --add_self_response --use_server
}

if [ "$DATASET" = "anthropic" ]; then
    run_inference
elif [ "$DATASET" = "multijail" ]; then
    # multijail is evaluated once per language. The loop variable is not
    # called LANG on purpose: LANG is the locale environment variable and
    # overwriting it would change the locale seen by the Python process.
    for query_lang in en ko; do
        run_inference --language "$query_lang"
    done
fi

# Stop the vLLM OpenAI API servers on ports 8000 and 8001.

#######################################
# Select the PIDs of vLLM API-server processes started with a given port.
# Reads a process listing on stdin, one "<pid> <command line>" per line, so
# the matching can be exercised without real processes.
# The port must be followed by a non-digit or end of line, so asking for
# 8000 does not also match "--port 80001".
# Arguments: $1 - port number
# Outputs:   matching PIDs on stdout, one per line
#######################################
vllm_server_pids() {
  awk -v port="$1" '
    /python -m vllm\.entrypoints\.openai\.api_server/ && $0 ~ ("--port " port "([^0-9]|$)") { print $1 }
  '
}

#######################################
# Stop every vLLM API server bound to a port: send SIGTERM first, wait up to
# a grace period for the processes to exit, then SIGKILL any survivors.
# Arguments: $1 - port number
#            $2 - grace period in seconds (optional, default 10)
# Outputs:   a status line on stdout
#######################################
stop_vllm_server() {
  local port=$1
  local grace=${2:-10}
  local waited=0
  local pid
  local -a pids=()

  while IFS= read -r pid; do
    pids+=("$pid")
  done < <(ps -A -o pid= -o args= | vllm_server_pids "$port")

  if [ "${#pids[@]}" -eq 0 ]; then
    echo "Server on port $port not running"
    return 0
  fi

  # Give the server a chance to shut down cleanly before forcing it.
  kill -TERM "${pids[@]}" 2>/dev/null
  # kill -0 succeeds while at least one of the PIDs can still be signalled.
  while [ "$waited" -lt "$grace" ] && kill -0 "${pids[@]}" 2>/dev/null; do
    sleep 1
    waited=$((waited + 1))
  done
  if kill -0 "${pids[@]}" 2>/dev/null; then
    kill -KILL "${pids[@]}" 2>/dev/null
  fi
  echo "Server on port $port stopped (PID: ${pids[*]})"
}

stop_vllm_server 8000
stop_vllm_server 8001