#!/bin/bash
# Launch multi-agent math debate runs via src/inference_multi.py.
#
# vLLM workers are started with the 'spawn' multiprocessing method, which the
# original author found more stable when several models run in parallel.
export VLLM_WORKER_MULTIPROC_METHOD=spawn

# Model catalogue: "--models ... <n> ..." selects MODEL_<n>.
# NOTE(review): entries 11-15 name API-hosted or non-public models; confirm
# src/inference_multi.py can actually load them.
declare MODEL_1="Qwen/Qwen2.5-7B-Instruct"
declare MODEL_2="Qwen/Qwen2.5-3B-Instruct"
declare MODEL_3="Qwen/Qwen2.5-32B-Instruct"
declare MODEL_4="Qwen/Qwen2.5-14B-Instruct"
declare MODEL_5="Qwen/Qwen2.5-1.5B-Instruct"
declare MODEL_6="Qwen/Qwen2.5-72B-Instruct"
declare MODEL_7="meta-llama/Llama-3.1-8B-Instruct"
declare MODEL_8="meta-llama/Llama-3.2-3B-Instruct"
declare MODEL_9="google/gemma-2-2b-it"
declare MODEL_10="google/gemma-2-9b-it"
declare MODEL_11="openai/gpt-3.5-turbo"
declare MODEL_12="anthropic/claude-v1"
declare MODEL_13="cohere/command-xlarge-nightly"
declare MODEL_14="ai21/j1-jumbo"
declare MODEL_15="microsoft/turing-nlg-17b"

# Run-time defaults; each can be overridden by the matching command-line flag.
declare NUM_QUERY=500 ROUND=4 TEMP=0.7
declare COT=false ADD_SELF_RESPONSE=true EARLY_STOPPING=true
declare DATASET="math500"
declare DATA_TOPIC="algebra"   # only forwarded when the dataset is collegemath
declare USE_SERVER=false CLUSTER1=false CLUSTER2=false

# Per-agent lists filled in by --nodes, --node_gpus and --server_urls.
declare -a NODES=() NODE_GPUS=() SERVER_URLS=()

# Parse command line arguments.
#   * Single-value options (--num_agents, --dataset, --num_query, --round,
#     --temp, --data_topic) must be followed by a value that does not start
#     with "--"; otherwise the script exits 1. Without this check a trailing
#     option made `shift 2` fail without shifting, and the loop spun forever.
#   * List options (--models, --nodes, --node_gpus, --gpus, --server_urls)
#     consume every following argument up to the next token starting with "--".
#   * Unknown options abort with exit status 1.
# NOTE(review): ADD_SELF_RESPONSE defaults to true, so --add_self_response is
# currently a no-op and there is no flag to turn self-response off.

# Exit with an error unless the option in $1 is followed by a value in $2.
require_value() {
    if [[ $# -lt 2 || $2 =~ ^-- ]]; then
        echo "❌ Option $1 requires a value" >&2
        exit 1
    fi
}

while [[ $# -gt 0 ]]; do
    case $1 in
        --num_agents)
            require_value "$@"
            NUM_AGENTS="$2"
            shift 2
            ;;
        --models)
            shift
            MODEL_CHOICES=()
            while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
                MODEL_CHOICES+=("$1")
                shift
            done
            ;;
        --nodes)
            shift
            NODES=()
            while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
                NODES+=("$1")
                shift
            done
            ;;
        --node_gpus)
            shift
            NODE_GPUS=()
            while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
                NODE_GPUS+=("$1")
                shift
            done
            ;;
        --gpus)
            shift
            GPU_INDICES=()
            while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
                GPU_INDICES+=("$1")
                shift
            done
            ;;
        --server_urls)
            shift
            SERVER_URLS=()
            while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
                SERVER_URLS+=("$1")
                shift
            done
            ;;
        --dataset)
            require_value "$@"
            DATASET="$2"
            shift 2
            ;;
        --num_query)
            require_value "$@"
            NUM_QUERY="$2"
            shift 2
            ;;
        --round)
            require_value "$@"
            ROUND="$2"
            shift 2
            ;;
        --temp)
            require_value "$@"
            TEMP="$2"
            shift 2
            ;;
        --cot)
            COT=true
            shift
            ;;
        --add_self_response)
            ADD_SELF_RESPONSE=true
            shift
            ;;
        --no_early_stopping)
            EARLY_STOPPING=false
            shift
            ;;
        --data_topic)
            require_value "$@"
            DATA_TOPIC="$2"
            shift 2
            ;;
        --use_server)
            USE_SERVER=true
            shift
            ;;
        --cluster2)
            CLUSTER2=true
            shift
            ;;
        --cluster1)
            CLUSTER1=true
            shift
            ;;
        *)
            echo "Unknown option: $1" >&2
            exit 1
            ;;
    esac
done

# Check that the required arguments were provided; if not, print usage to
# stderr (it is an error diagnostic) and exit 1.
# NOTE(review): DATASET defaults to "math500", so the -z "$DATASET" test only
# trips when --dataset is explicitly given an empty string.
if [ -z "$NUM_AGENTS" ] || [ -z "${MODEL_CHOICES[*]}" ] || [ -z "$DATASET" ]; then
    cat >&2 <<'EOF'
Usage: bash run_math_multi.sh --num_agents <number> --models <model_choices...> --dataset <dataset> [options]
  Required arguments:
    --num_agents <number>    : Number of agents (2-8)
    --models <choices...>    : Model choices (1-15) for each agent
    --dataset <name>         : Dataset name (gsm8k, math500, collegemath, aime2024, aime2025)
  Optional arguments:
    --nodes <names...>       : Node names for each agent
    --node_gpus <mappings...> : Node:GPU mappings for each agent
    --gpus <indices...>      : GPU indices for each model (default: sequential starting from 0)
    --server_urls <urls...>  : Server URLs for each model (default: none)
    --num_query <number>     : Number of queries (default: 500)
    --round <number>         : Number of debate rounds (default: 4)
    --temp <number>          : Temperature (default: 0.7)
    --cot                    : Enable Chain-of-Thought
    --add_self_response      : Enable self-response
    --no_early_stopping      : Disable early stopping
    --data_topic <topic>     : Topic for CollegeMath dataset (default: algebra)
                               Options: algebra, calculus, precalculus, differential_equation,
                                        linear_algebra, probability, vector_calculus
    --use_server             : Use vLLM server APIs instead of direct integration
    --cluster2               : Enable cluster2 option
    --cluster1               : Enable cluster1 computing configuration
  Model choices:
    1: Qwen2.5-7B
    2: Qwen2.5-3B
    3: Qwen2.5-32B
    4: Qwen2.5-14B
    5: Qwen2.5-1.5B
    6: Qwen2.5-72B
    7: Llama-3.1-8B
    8: Llama-3.2-3B
    9: Gemma-2-2B
    10: Gemma-2-9B
    11: GPT-3.5-Turbo
    12: Claude-v1
    13: Command-Xlarge-Nightly
    14: J1-Jumbo
    15: Turing-NLG-17B
EOF
    exit 1
fi

# Validate the parsed arguments. Every failure prints a message to stderr and
# exits 1.

# The dataset name is matched case-insensitively; the lower-cased form is the
# one forwarded to the Python script.
dataset_lower=$(echo "$DATASET" | tr '[:upper:]' '[:lower:]')
case "$dataset_lower" in
    gsm8k | math500 | collegemath | aime2024 | aime2025) ;;
    *)
        echo "❌ Invalid dataset name: $DATASET" >&2
        echo "Valid options: gsm8k, math500, collegemath, aime2024, aime2025" >&2
        exit 1
        ;;
esac

# NUM_AGENTS must be a plain decimal integer before any numeric test. On a
# non-number, `[ x -lt 2 ]` only prints an error and evaluates false, which let
# values like "three" pass. Leading zeros are rejected too, because later
# ((...)) arithmetic would read them as octal.
if ! [[ $NUM_AGENTS =~ ^[1-9][0-9]*$ ]] || [ "$NUM_AGENTS" -lt 2 ] || [ "$NUM_AGENTS" -gt 8 ]; then
    echo "❌ Number of agents must be between 2 and 8" >&2
    exit 1
fi

# Exactly one model choice per agent.
if [ "${#MODEL_CHOICES[@]}" -ne "$NUM_AGENTS" ]; then
    echo "❌ Number of model choices must match number of agents" >&2
    exit 1
fi

# Optional per-agent lists must have one entry per agent when given.
if [ -n "${NODES[*]}" ] && [ "${#NODES[@]}" -ne "$NUM_AGENTS" ]; then
    echo "❌ Number of nodes must match number of agents" >&2
    exit 1
fi

if [ -n "${NODE_GPUS[*]}" ] && [ "${#NODE_GPUS[@]}" -ne "$NUM_AGENTS" ]; then
    echo "❌ Number of node:gpu mappings must match number of agents" >&2
    exit 1
fi

# Server mode needs one URL per agent.
if [ "$USE_SERVER" = true ] && [ "${#SERVER_URLS[@]}" -ne "$NUM_AGENTS" ]; then
    echo "❌ Number of server URLs must match number of agents when using server mode" >&2
    exit 1
fi

# Resolve each numeric --models choice (exactly "1" .. "15") to the model name
# held in the matching MODEL_<n> variable, preserving agent order. Any other
# token aborts the script with exit status 1.
valid_model_choice='^([1-9]|1[0-5])$'
MODEL_ARGS=()
for pick in "${MODEL_CHOICES[@]}"; do
    if ! [[ $pick =~ $valid_model_choice ]]; then
        echo "❌ Invalid model choice: $pick"
        exit 1
    fi
    # Indirect expansion: read the variable whose name is MODEL_<pick>.
    model_var="MODEL_${pick}"
    MODEL_ARGS+=("${!model_var}")
done

# Run a shell command string on the given node.
# Arguments:
#   $1 - node name; "localhost" runs the command in the current shell via eval
#   $2 - command string (expected to be already shell-quoted by the caller)
# Remote nodes: the command runs over ssh after cd-ing into the current
# working directory. The directory is quoted with printf %q so paths with
# spaces or shell metacharacters survive the remote shell's word splitting.
# NOTE(review): %q may emit bash $'...' quoting for unusual characters; this
# assumes the remote login shell is bash-compatible.
# Returns: the exit status of the command (or of ssh itself).
run_on_node() {
    local node=$1
    local cmd=$2
    local remote_dir

    if [ "$node" = "localhost" ]; then
        eval "$cmd"
    else
        printf -v remote_dir '%q' "$PWD"
        ssh "$node" "cd $remote_dir && $cmd"
    fi
}

# Build the Python launch command as a single string. The string is later
# run with eval (local) or passed to ssh (remote), so every value is
# shell-quoted with printf %q. A value containing spaces or shell
# metacharacters then stays one literal argument instead of being split or
# executed. Plain values such as model names and numbers pass through %q
# unchanged.
# NOTE(review): --use_server, --server_urls and --node_gpus are parsed by this
# script but never forwarded here; confirm whether src/inference_multi.py
# expects them.

# Append "<flag> <quoted value>" to BASE_CMD.
add_opt() {
    local quoted
    printf -v quoted '%q' "$2"
    BASE_CMD+=" $1 $quoted"
}

BASE_CMD="python src/inference_multi.py"
add_opt --num_agents "$NUM_AGENTS"

# Per-agent model, GPU index (defaults to the agent's 0-based index) and node.
for ((i=0; i<NUM_AGENTS; i++)); do
    add_opt "--model_$((i+1))" "${MODEL_ARGS[$i]}"
    add_opt "--gpu_$((i+1))" "${GPU_INDICES[$i]:-$i}"
    if [ -n "${NODES[*]}" ]; then
        add_opt "--node_$((i+1))" "${NODES[$i]}"
    fi
done

# Run-wide settings.
add_opt --num_query "$NUM_QUERY"
add_opt --round "$ROUND"
add_opt --temp "$TEMP"
add_opt --dataset "$dataset_lower"

if [ "$dataset_lower" == "collegemath" ]; then
    add_opt --data_topic "$DATA_TOPIC"
fi

# Boolean switches.
if [ "$COT" = true ]; then
    BASE_CMD+=" --cot"
fi

if [ "$ADD_SELF_RESPONSE" = true ]; then
    BASE_CMD+=" --add_self_response"
fi

if [ "$EARLY_STOPPING" = false ]; then
    BASE_CMD+=" --early_stopping false"
fi

if [ "$CLUSTER2" = true ]; then
    BASE_CMD+=" --cluster2"
fi

if [ "$CLUSTER1" = true ]; then
    BASE_CMD+=" --cluster1"
fi

# Launch the command and report the overall result.
# With --nodes, the command string is started once per distinct node name
# through run_on_node, all in parallel, and each job's exit status is
# checked. Without --nodes it is eval'd locally. If any launch fails, the
# script exits 1 instead of printing the success banner.
if [ -n "${NODES[*]}" ]; then
    # Read distinct node names one per line so they are never re-split or
    # glob-expanded.
    unique_nodes=()
    while IFS= read -r node; do
        unique_nodes+=("$node")
    done < <(printf '%s\n' "${NODES[@]}" | sort -u)

    pids=()
    for node in "${unique_nodes[@]}"; do
        echo "🚀 Running on node: $node"
        run_on_node "$node" "$BASE_CMD" &
        pids+=("$!")
    done

    # Wait on each job by PID: a bare `wait` always returns 0 and would hide
    # failures.
    failed=0
    for idx in "${!pids[@]}"; do
        if ! wait "${pids[$idx]}"; then
            echo "❌ Run on node ${unique_nodes[$idx]} failed" >&2
            failed=1
        fi
    done
    if [ "$failed" -ne 0 ]; then
        exit 1
    fi
else
    echo "🚀 Running locally"
    if ! eval "$BASE_CMD"; then
        echo "❌ Local run failed" >&2
        exit 1
    fi
fi

echo "✅ All processes completed!"
