#!/usr/bin/env bash
#
# Launch inference_scripts/run_wan_inference_distributed.py under torchrun,
# either on a single local node or across several nodes.
#
# Environment read by this script:
#   NUM_OF_GPUS  processes (GPUs) per node
#   WORLD_SIZE   number of nodes (distributed mode)
#   RANK         rank of this node (distributed mode)
#   MASTER_ADDR  rendezvous host (distributed mode)
#   MASTER_PORT  rendezvous port; when non-empty, distributed mode is used
#   JOB_UUID     rendezvous id passed to torchrun
#
# Command-line arguments of this script are forwarded to the Python script.

# Files created by this job get no permissions at all for "other" users;
# owner and group bits are not masked.
umask 007

# Topology re-exported under the names presumably read by the Python entry
# point (TODO confirm). Values are passed through unchanged, so each one is
# empty when its source variable is unset.
export NUM_GPUS="${NUM_OF_GPUS}"
export NUM_NODES="${WORLD_SIZE}"
export NODE_RANK="${RANK}"

# Presumably consumed by the Python code to turn off Weights & Biases
# logging (TODO confirm).
export DISABLE_WANDB=true

# UTC launch time with millisecond precision,
# e.g. 2024-01-01T12:00:00.123+00:00 (%3N requires GNU date).
# Assigned before exporting so a failing `date` is not masked by `export`.
NODE_LAUNCH_TIMESTAMP=$(date -u +'%Y-%m-%dT%H:%M:%S.%3N+00:00')
export NODE_LAUNCH_TIMESTAMP


# Decide between a multi-node (distributed) and a single-node launch.
# A non-empty MASTER_PORT selects distributed mode: node count, this node's
# rank and the rendezvous address come from WORLD_SIZE, RANK, MASTER_ADDR and
# MASTER_PORT. Otherwise a single local node listening on localhost:12345 is
# used.
# Sets: job_n (node count), job_id (this node's rank), master_addr,
#       master_port, num_gpu (processes per node).
if [[ -n "$MASTER_PORT" ]]; then
  printf '%s\n' "Using distributed: $WORLD_SIZE, $RANK, $MASTER_ADDR, $MASTER_PORT"
  job_n=$WORLD_SIZE
  job_id=$RANK
  master_addr=$MASTER_ADDR
  master_port=$MASTER_PORT
else
  printf '%s\n' "Using localhost: 1, 0, localhost"
  job_n=1
  job_id=0
  master_addr=localhost
  master_port=12345
fi

# Processes (GPUs) per node: NUM_OF_GPUS, or 8 when it is unset or empty.
# The fallback applies in both modes so torchrun never receives an empty
# --nproc_per_node value.
num_gpu=${NUM_OF_GPUS:-8}


# NCCL: log only warnings/errors (no INFO-level debug output) and use the
# eth0 interface for socket communication (assumes every node has an eth0
# interface -- TODO confirm for new clusters).
export NCCL_DEBUG=WARN
export NCCL_SOCKET_IFNAME=eth0

# Launch num_gpu processes on this node. All nodes meet at the c10d
# rendezvous endpoint master_addr:master_port. The rendezvous id is JOB_UUID,
# or "none" (torchrun's own default) when JOB_UUID is unset or empty.
# "$@" forwards this script's arguments verbatim, so arguments containing
# spaces or glob characters reach the Python script intact.
torchrun \
  --nproc_per_node="$num_gpu" \
  --nnodes="$job_n" \
  --rdzv-id="${JOB_UUID:-none}" \
  --rdzv-backend=c10d \
  --rdzv-endpoint="${master_addr}:${master_port}" \
  inference_scripts/run_wan_inference_distributed.py "$@"