#!/bin/bash
#
# Launch train_staged.py with one process per GPU on a single node via
# torch.distributed.launch.
#
# Usage: <script> [NUM_GPUS] [extra train_staged.py args...]
#   NUM_GPUS  number of worker processes / GPUs to use (default: 8).
#             Devices 0..NUM_GPUS-1 are made visible and passed via --gpus.
#   Any further arguments are forwarded verbatim to train_staged.py after the
#   defaults below (whether they override those defaults depends on how
#   train_staged.py parses its arguments — TODO confirm).
set -euo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

NUM_GPUS=${1:-8}
# Drop NUM_GPUS so "$@" holds only the pass-through training arguments.
if (( $# > 0 )); then
  shift
fi

[[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]] \
  || die "NUM_GPUS must be a positive integer, got '$NUM_GPUS'"

# Build "0,1,...,NUM_GPUS-1" so the visible devices, the process count and
# --gpus always agree with each other.
gpu_ids=0
for (( i = 1; i < NUM_GPUS; i++ )); do
  gpu_ids+=",$i"
done

export CUDA_VISIBLE_DEVICES=$gpu_ids
export NCCL_DEBUG=INFO
# InfiniBand and peer-to-peer NCCL transports are disabled; presumably a
# workaround for this machine's interconnect — confirm before changing.
export NCCL_IB_DISABLE=1
export NCCL_P2P_DISABLE=1

# exec so signals (Ctrl-C, scheduler SIGTERM) reach the launcher directly and
# its exit status becomes the script's.
# NOTE(review): torch.distributed.launch is deprecated in favour of torchrun;
# migrating changes how the local rank is delivered (LOCAL_RANK env var vs a
# --local_rank argument), so check train_staged.py before switching.
exec python -m torch.distributed.launch \
  --nproc_per_node="$NUM_GPUS" \
  --master_port=29500 \
  train_staged.py \
  --gpus "$gpu_ids" \
  --batch_size 64 \
  --num_workers 8 \
  "$@"

