#!/bin/bash

# Hosts taking part in the run. A host's position in this list is later
# passed to torchrun as its --node_rank.
NODE_LIST=(
    "173.0.56.5"
    "173.0.99.3"
    "173.0.66.3"
)
len="${#NODE_LIST[@]}"

# Rendezvous endpoint handed to torchrun: the first host plus a fixed port.
MASTER_PORT=12346
MASTER_ADDR="${NODE_LIST[0]}"

# PIDs of the local background jobs; appended to as each job is started.
bg_pids=()

#######################################
# Tear down the run; the script registers this function as its EXIT trap.
#
# For every host in NODE_LIST, one ssh session makes sure `fuser` exists
# (installing psmisc only when it is missing) and then kills the processes
# that have a /dev/nvidia* device open. Afterwards the local background
# jobs recorded in bg_pids are signalled.
#
# Globals:   NODE_LIST (read), bg_pids (read)
# Arguments: none
# Outputs:   "Exiting jobs." on stdout, plus whatever ssh/fuser print
#######################################
cleanup() {
    # Loop variables are local so the handler does not overwrite the
    # script's own IP_ADDR (or any other global of the same name).
    local host pid
    # Single-quoted: the /dev/nvidia* glob must be expanded on the remote
    # host, not locally. apt-get is used instead of apt because apt's
    # command-line interface is not meant for scripts.
    local remote_cmd='command -v fuser >/dev/null 2>&1 || sudo apt-get install -y psmisc; sudo fuser -k -v /dev/nvidia*'

    for host in "${NODE_LIST[@]}"; do
        ssh "$host" "$remote_cmd"
    done

    echo "Exiting jobs."
    for pid in "${bg_pids[@]}"; do
        # A job may already have exited; that error is expected, so it is
        # deliberately silenced.
        kill "$pid" 2>/dev/null
    done
}

# Register the cleanup function to be called on script exit
trap cleanup EXIT

# Start one torchrun launcher per host over ssh, in the background. The
# host's index in NODE_LIST is passed as its --node_rank, and every node is
# pointed at the same MASTER_ADDR:MASTER_PORT.
# NOTE(review): the remote command is kept verbatim, including
# "source conda activate!" -- confirm this activates the intended conda
# environment on the nodes.
for ((i = 0; i < len; i++)); do
    IP_ADDR="${NODE_LIST[i]}"
    ssh "$IP_ADDR" "bash -c 'source conda activate!; \
        echo \"Conda environment: \$CONDA_DEFAULT_ENV\"; \
        which python; \
        python -c \"import sys; print(sys.executable)\"; \
        torchrun --nnodes=$len --nproc_per_node=4 --node_rank=$i --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT /train_1-5d.py'" &
    bg_pid=$!
    bg_pids+=("$bg_pid")
    echo "Started job with PID $bg_pid on $IP_ADDR"
done

# Wait for all background jobs to finish. Each PID is waited on individually
# because a bare `wait` discards the jobs' exit statuses, which would let a
# failed node go unnoticed.
failed_jobs=0
for bg_pid in "${bg_pids[@]}"; do
    if ! wait "$bg_pid"; then
        echo "Job with PID $bg_pid exited with a non-zero status." >&2
        failed_jobs=$((failed_jobs + 1))
    fi
done

echo "All jobs have completed."

# Surface failures through the script's exit status; the EXIT trap still
# runs cleanup on this path.
if ((failed_jobs > 0)); then
    echo "$failed_jobs of ${#bg_pids[@]} jobs failed." >&2
    exit 1
fi
