#!/bin/bash

# Fallback values, consumed further down in this script when neither the
# environment nor Slurm supplies the corresponding setting.
DEFAULT_MASTER_PORT=25001        # used when MASTER_PORT is unset or empty
DEFAULT_MASTER_ADDR=127.0.0.1    # used when no host name can be obtained from scontrol
DEFAULT_GPUS_PER_NODE=8          # used when SLURM_JOB_GPUS_PER_NODE is unset or empty

# Log the Slurm job identity (the values print as empty when the variables
# are not set).
printf 'SLURM_JOB_ID = %s\n' "$SLURM_JOB_ID"
printf 'SLURM_JOB_NAME = %s\n' "$SLURM_JOB_NAME"

# Run name precedence: RUN_NAME, then DEFAULT_RUN_NAME, then SLURM_JOB_NAME.
# An empty value is treated the same as an unset one at each step.
if [ -z "${RUN_NAME:-}" ]; then
    RUN_NAME=$DEFAULT_RUN_NAME
fi
if [ -z "$RUN_NAME" ]; then
    RUN_NAME=$SLURM_JOB_NAME
fi
printf 'RUN_NAME = %s\n' "$RUN_NAME"

# An OUTPUT_DIR supplied by the caller wins; otherwise derive it from the run name.
if [ -z "${OUTPUT_DIR:-}" ]; then
    OUTPUT_DIR="runs/train/$RUN_NAME"
fi
printf 'OUTPUT_DIR = %s\n' "$OUTPUT_DIR"

# Weights & Biases settings: fixed project and resume mode, with the run
# id/name taken from RUN_NAME and the W&B directory set to OUTPUT_DIR.
export WANDB_PROJECT="vila" WANDB_RESUME="allow"
export WANDB_NAME="$RUN_NAME" WANDB_RUN_ID="$RUN_NAME"
export WANDB_DIR="$OUTPUT_DIR"

# Cluster topology for the distributed launch. Each value is taken from the
# corresponding Slurm variable and falls back to a single-node default when
# that variable is unset or empty.
NNODES=${SLURM_JOB_NUM_NODES:-1}
echo "NNODES = $NNODES"

# Expand the Slurm node list into one host name per line. This is done once
# and reused for both NODES and MASTER_ADDR, and only when there is a node
# list to expand, so running outside Slurm no longer prints
# "scontrol: command not found" twice.
_slurm_hostnames=""
if [[ -n "$SLURM_JOB_NODELIST" ]]; then
    if command -v scontrol >/dev/null 2>&1; then
        _slurm_hostnames=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
    fi
    if [[ -z "$_slurm_hostnames" ]]; then
        # Falling back to the default address inside a Slurm job is only
        # correct for single-node runs, so make the fallback visible.
        echo "warning: could not expand SLURM_JOB_NODELIST with scontrol;" \
            "MASTER_ADDR falls back to its default" >&2
    fi
fi

# Space-separated host list. Every host name, including the last one, is
# followed by a single space; the value is empty when no hosts are known.
NODES=""
if [[ -n "$_slurm_hostnames" ]]; then
    NODES="${_slurm_hostnames//$'\n'/ } "
fi
echo "NODES = $NODES"

# NOTE(review): SLURM_PROCID is the rank of the Slurm task, which presumably
# matches the node index only when one task is launched per node -- confirm
# against the launcher before relying on it as a node rank.
NODE_RANK=${SLURM_PROCID:-0}
echo "NODE_RANK = $NODE_RANK"

GPUS_PER_NODE=${SLURM_JOB_GPUS_PER_NODE:-$DEFAULT_GPUS_PER_NODE}
echo "GPUS_PER_NODE = $GPUS_PER_NODE"

# The first host of the allocation acts as the rendezvous address. Any
# MASTER_ADDR already present in the environment is intentionally replaced.
MASTER_ADDR=${_slurm_hostnames%%$'\n'*}
MASTER_ADDR=${MASTER_ADDR:-$DEFAULT_MASTER_ADDR}
echo "MASTER_ADDR = $MASTER_ADDR"
unset _slurm_hostnames

# Unlike MASTER_ADDR, a MASTER_PORT provided by the caller is honoured.
MASTER_PORT=${MASTER_PORT:-$DEFAULT_MASTER_PORT}
echo "MASTER_PORT = $MASTER_PORT"

# Resolve the batch-size knobs. A non-empty value already present in the
# environment wins; otherwise the matching DEFAULT_* variable is used (those
# defaults are not assigned anywhere in this script, so they must be set
# before it runs).
GLOBAL_TRAIN_BATCH_SIZE=${GLOBAL_TRAIN_BATCH_SIZE:-$DEFAULT_GLOBAL_TRAIN_BATCH_SIZE}
echo "GLOBAL_TRAIN_BATCH_SIZE = $GLOBAL_TRAIN_BATCH_SIZE"

GRADIENT_ACCUMULATION_STEPS=${GRADIENT_ACCUMULATION_STEPS:-$DEFAULT_GRADIENT_ACCUMULATION_STEPS}
echo "GRADIENT_ACCUMULATION_STEPS = $GRADIENT_ACCUMULATION_STEPS"

# Every operand of the division below must be a positive decimal integer.
# Bash arithmetic evaluates an unset or empty variable as 0, so without this
# check a missing value surfaces as a cryptic "division by 0" and a
# non-numeric one as a syntax error.
for _batch_operand in GLOBAL_TRAIN_BATCH_SIZE GRADIENT_ACCUMULATION_STEPS NNODES GPUS_PER_NODE; do
    if [[ ! "${!_batch_operand}" =~ ^[1-9][0-9]*$ ]]; then
        echo "$_batch_operand must be a positive integer, got '${!_batch_operand}'" >&2
        exit 1
    fi
done
unset _batch_operand

# Split the global batch across nodes, GPUs per node and accumulation steps.
# Integer division: any remainder is dropped.
PER_DEVICE_TRAIN_BATCH_SIZE=$((GLOBAL_TRAIN_BATCH_SIZE / NNODES / GPUS_PER_NODE / GRADIENT_ACCUMULATION_STEPS))
echo "PER_DEVICE_TRAIN_BATCH_SIZE = $PER_DEVICE_TRAIN_BATCH_SIZE"

# A result of 0 means the global batch is smaller than the number of
# device/accumulation slots it has to be spread over.
if [ "$PER_DEVICE_TRAIN_BATCH_SIZE" -lt 1 ]; then
    echo "GLOBAL_TRAIN_BATCH_SIZE ($GLOBAL_TRAIN_BATCH_SIZE) is too small for NNODES * GPUS_PER_NODE * GRADIENT_ACCUMULATION_STEPS" >&2
    exit 1
fi

# Warn (without failing) when truncation makes the batch size that will
# actually be used differ from the requested one.
_effective_global_batch=$((PER_DEVICE_TRAIN_BATCH_SIZE * NNODES * GPUS_PER_NODE * GRADIENT_ACCUMULATION_STEPS))
if [ "$_effective_global_batch" -ne "$GLOBAL_TRAIN_BATCH_SIZE" ]; then
    echo "warning: GLOBAL_TRAIN_BATCH_SIZE ($GLOBAL_TRAIN_BATCH_SIZE) is not divisible by NNODES * GPUS_PER_NODE * GRADIENT_ACCUMULATION_STEPS; effective global batch size is $_effective_global_batch" >&2
fi
unset _effective_global_batch

# Optional upper bound: enforced only when MAX_PER_DEVICE_TRAIN_BATCH_SIZE
# is set to a non-empty value.
if [ -n "$MAX_PER_DEVICE_TRAIN_BATCH_SIZE" ] && [ "$PER_DEVICE_TRAIN_BATCH_SIZE" -gt "$MAX_PER_DEVICE_TRAIN_BATCH_SIZE" ]; then
    echo "PER_DEVICE_TRAIN_BATCH_SIZE is greater than MAX_PER_DEVICE_TRAIN_BATCH_SIZE" >&2
    exit 1
fi

# Runtime knobs exported to every process started from this shell; each of
# the four variables is set to 1.
export \
    OMP_NUM_THREADS=1 \
    TORCH_NCCL_ASYNC_ERROR_HANDLING=1 \
    CUDA_DEVICE_MAX_CONNECTIONS=1 \
    NCCL_IB_SL=1
