#!/bin/bash

# Multi-GPU WandB sweep runner with multiple agents per GPU
# Usage: ./run_multi_gpu_sweep.sh --server <name> [--gpu_count <count>] [--agents_per_gpu <count>]

# Print usage help to stderr and exit.
# Arguments: $1 - exit status (optional; defaults to 1 so bare `usage` calls
#                 made on bad input still signal an error)
# Outputs:   help text on stderr
usage() {
    cat >&2 <<EOF
Usage: $0 --server <name> [--gpu_count <count>] [--agents_per_gpu <count>]
Example: $0 --server amai --gpu_count 2 --agents_per_gpu 3
Defaults: GPU count: 3, Agents per GPU: 1
EOF
    exit "${1:-1}"
}

# Parse arguments. Every option takes a value; GPU IDs used later are
# 0..gpu_count-1.
server=""
gpu_count=3
agents_per_gpu=1
while [[ "$#" -gt 0 ]]; do
    case "$1" in
        --server|--gpu_count|--agents_per_gpu)
            # With only one argument left, `shift 2` fails WITHOUT shifting,
            # so a trailing option with no value would loop forever.
            if [[ "$#" -lt 2 ]]; then
                echo "Error: $1 requires a value" >&2
                usage
            fi
            ;;
    esac
    case "$1" in
        --server) server="$2"; shift 2 ;;
        --gpu_count) gpu_count="$2"; shift 2 ;;
        --agents_per_gpu) agents_per_gpu="$2"; shift 2 ;;
        *) echo "Unknown parameter: $1" >&2; usage ;;
    esac
done

if [[ -z "$server" ]]; then
    echo "Error: --server required" >&2
    usage
fi

# The counts drive the launch loops below; reject anything that is not a
# positive integer rather than launching nothing or failing mid-run.
for count_opt in gpu_count agents_per_gpu; do
    if ! [[ "${!count_opt}" =~ ^[1-9][0-9]*$ ]]; then
        echo "Error: --$count_opt must be a positive integer (got '${!count_opt}')" >&2
        usage
    fi
done

sweep_config="scripts/sweeps/resmp_attn_loss_sweep.yaml"

# The path is relative, so a wrong working directory is the likely failure;
# report it directly instead of via an opaque `wandb sweep` error.
if [[ ! -f "$sweep_config" ]]; then
    echo "Error: sweep config not found: $sweep_config (run from the repo root?)" >&2
    exit 1
fi

echo "Starting multi-GPU hyperparameter sweep"
echo "Server: $server | GPUs: 0-$((gpu_count-1)) | Agents per GPU: $agents_per_gpu"

mkdir -p logs || exit 1

# Initialize a single sweep shared by every agent.
echo "Initializing WandB sweep..."
sweep_output=$(wandb sweep --project m3epi_v3 "$sweep_config" 2>&1)
# The ID is the argument of the `wandb agent <id>` command echoed in the
# output. Keep only the first match so a repeated line cannot yield a
# multi-line ID that every agent would then receive.
sweep_id=$(printf '%s\n' "$sweep_output" \
    | grep -o 'wandb agent [^[:space:]]*' \
    | awk '{print $3; exit}')

if [[ -z "$sweep_id" ]]; then
    echo "Failed to initialize sweep. Output:" >&2
    printf '%s\n' "$sweep_output" >&2
    exit 1
fi

echo "Sweep ID: $sweep_id"

# Launch agents: agents_per_gpu background `wandb agent` processes per GPU,
# each pinned to one device and logging to its own file. Log paths are
# recorded at launch so the summary prints exactly the files in use.
pids=()
agent_labels=()
agent_logs=()
exited_early=0
for ((gpu_id = 0; gpu_id < gpu_count; gpu_id++)); do
    for ((agent_id = 1; agent_id <= agents_per_gpu; agent_id++)); do
        log_file="logs/sweep_${server}_gpu${gpu_id}_agent${agent_id}.log"
        echo "Starting agent $agent_id on GPU $gpu_id..."
        # Scope the GPU pin to this command only; an `export` would leave the
        # launcher's own environment pinned to the last GPU.
        CUDA_VISIBLE_DEVICES="$gpu_id" nohup wandb agent "$sweep_id" > "$log_file" 2>&1 &
        pid=$!
        pids+=("$pid")
        agent_labels+=("GPU $gpu_id Agent $agent_id")
        agent_logs+=("$log_file")
        echo "GPU $gpu_id Agent $agent_id started (PID: $pid)"
        sleep 2  # Stagger startup
        # An agent that exits within the stagger window would otherwise go
        # unnoticed behind the success message below.
        if ! kill -0 "$pid" 2>/dev/null; then
            echo "Warning: GPU $gpu_id Agent $agent_id (PID $pid) already exited; see $log_file" >&2
            exited_early=$((exited_early + 1))
        fi
    done
done

echo ""
if (( exited_early == 0 )); then
    echo "All agents started successfully!"
else
    echo "Agents launched, but $exited_early exited during startup; check their logs." >&2
fi
echo "Sweep ID: $sweep_id"
echo "PIDs: ${pids[*]}"
echo ""
echo "Monitor logs:"
for idx in "${!agent_logs[@]}"; do
    echo "  ${agent_labels[idx]}: tail -f ${agent_logs[idx]}"
done
echo ""
echo "Kill all agents: kill ${pids[*]}"
# Match on this sweep's ID so agents of other sweeps on the same machine
# are left alone.
echo "Or use: pkill -f 'wandb agent $sweep_id'"