#!/bin/bash
# Job banner: records which host ran the job and when it started.
echo "Starting job on $(hostname) at $(date)"

# Make the Wan2.2 checkout importable (path is relative to the working directory).
export PYTHONPATH=./Wan2.2

# `conda activate` only works once conda's shell hook has defined the `conda`
# shell function. Batch / non-interactive shells normally do not source the
# user's rc files, so load the hook here when the function is missing.
# NOTE: strict mode (`set -eu`) is intentionally not enabled around this step,
# because conda activation scripts may reference unset variables.
if ! declare -F conda >/dev/null; then
  if ! command -v conda >/dev/null; then
    echo "ERROR: conda not found on PATH; cannot activate environment 'cu129'" >&2
    exit 1
  fi
  eval "$(conda shell.bash hook)"
fi

# Abort instead of silently training in the wrong environment with an empty
# CUDA_HOME when the activation fails.
if ! conda activate cu129; then
  echo "ERROR: failed to activate conda environment 'cu129'" >&2
  exit 1
fi
export CUDA_HOME="$CONDA_PREFIX"
export OMP_NUM_THREADS=8

# Fall back to 1 node / 1 GPU per node / node index 0 when the SLURM variables
# are absent, so the script can also be started outside a SLURM allocation.
export SLURM_JOB_NUM_NODES="${SLURM_JOB_NUM_NODES:-1}"
export SLURM_GPUS_ON_NODE="${SLURM_GPUS_ON_NODE:-1}"
export SLURM_NODEID="${SLURM_NODEID:-0}"

# Choose the rendezvous host.
#   - Outside SLURM (no SLURM_JOB_ID): use the loopback address.
#   - Inside SLURM: expand the job's node list with scontrol and use the first
#     hostname as the master address.
# Sets: nodes (newline-separated hostnames, empty outside SLURM), master_addr.
if [[ -z "${SLURM_JOB_ID:-}" ]]; then
  # Reset explicitly so a stray inherited `nodes` value cannot leak into the log.
  nodes=
  master_addr=127.0.0.1
else
  if ! nodes=$(scontrol show hostnames "${SLURM_JOB_NODELIST:-}"); then
    echo "ERROR: 'scontrol show hostnames' failed for SLURM_JOB_NODELIST='${SLURM_JOB_NODELIST:-}'" >&2
    exit 1
  fi
  # First line of the list; parameter expansion avoids the echo | head pipeline.
  master_addr=${nodes%%$'\n'*}
  if [[ -z "$master_addr" ]]; then
    echo "ERROR: could not determine master address from SLURM node list" >&2
    exit 1
  fi
fi

# Log the resolved topology; the node list is flattened onto a single line.
echo "NODELIST=${nodes//$'\n'/ }"
echo "MASTER_ADDR=$master_addr"
echo "current node index= $SLURM_NODEID"
# --- Threading and NCCL transport settings -----------------------------------
# Values are assigned first and exported together at the end of this section.

# OpenMP thread count.
OMP_NUM_THREADS=8

# Optional debugging switches (uncomment as needed):
# export TORCH_NCCL_TRACE_BUFFER_SIZE=1
# export TORCH_CPP_LOG_LEVEL="INFO"
# export TORCH_DISTRIBUTED_DEBUG="DETAIL"
# export NCCL_TRACE=INFO
# export NCCL_DEBUG=INFO

# InfiniBand device:port list handed to NCCL (presumably specific to this
# cluster's hardware -- verify before reusing elsewhere).
NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_13:1,mlx5_16:1,mlx5_17:1
# 0 leaves the InfiniBand transport enabled.
NCCL_IB_DISABLE=0
# Network interface name given to NCCL for socket traffic.
NCCL_SOCKET_IFNAME=bond1
# InfiniBand retry count and timeout passed to NCCL.
NCCL_IB_RETRY_CNT=7
NCCL_IB_TIMEOUT=23

export OMP_NUM_THREADS \
  NCCL_IB_HCA NCCL_IB_DISABLE NCCL_SOCKET_IFNAME \
  NCCL_IB_RETRY_CNT NCCL_IB_TIMEOUT
# export USE_TE_ATTN=1
# Launch the `utils.train` module through torchrun, one launcher per node.
#   --nnodes          number of nodes            (SLURM_JOB_NUM_NODES)
#   --nproc_per_node  worker processes per node  (SLURM_GPUS_ON_NODE)
#   --node_rank       index of this node         (SLURM_NODEID)
#   --master_addr     rendezvous host            (master_addr)
#   --master_port     rendezvous port            (fixed: 29503)
# Every expansion is quoted, and `${var:?msg}` aborts the script with a clear
# message when a value is unset or empty. Unquoted, an empty SLURM_NODEID would
# make `--node_rank` swallow the following flag as its value, and an empty
# master_addr would reach torchrun as `--master_addr=`.
# torchrun is the last command, so its exit status becomes the script's.
torchrun \
  --nnodes="${SLURM_JOB_NUM_NODES:?SLURM_JOB_NUM_NODES is not set}" \
  --nproc_per_node="${SLURM_GPUS_ON_NODE:?SLURM_GPUS_ON_NODE is not set}" \
  --node_rank "${SLURM_NODEID:?SLURM_NODEID is not set}" \
  --master_addr="${master_addr:?master_addr is empty}" \
  --master_port=29503 \
  -m utils.train --config=cosmos_predict2/configs/base/wan_base_config.py -- experiment=wan_real0910_iclr_fromi2v

