#!/bin/bash
#
# Single-node training launcher (template).
# Tokens wrapped in curly braces are placeholders that are expected to be
# substituted before this script runs; they are deliberately left unquoted
# because the substituted text may already carry its own shell quoting.

# Strict mode plus command tracing. Tracing lives here rather than in the
# shebang so it also applies when the script is started as `bash script.sh`.
set -euxo pipefail

project_home={project_home}

# Mirror the lower-case proxy settings into their upper-case names, but only
# when they are actually set: under `set -u` a bare reference to an unset
# variable would abort the script on hosts that have no proxy configured.
# (`[[ -v name ]]` needs bash 4.2 or newer.)
if [[ -v http_proxy ]]; then
    export HTTP_PROXY="$http_proxy"
fi
if [[ -v https_proxy ]]; then
    export HTTPS_PROXY="$https_proxy"
fi

# Work from the project root and use its virtualenv.
cd "$project_home"
source .venv/bin/activate
# cd flash-linear-attention/legacy/training

# Best-effort load of the cuda/12.9 module (`ml` is presumably the Lmod
# shorthand for `module load` -- confirm). A failure is ignored on purpose so
# hosts without the module system still continue.
ml cuda/12.9 || true

# Triton cache directory. The path is defined once and reused so the
# directory that gets created is the same one Triton is pointed at; it must be
# absolute, otherwise the cache would be resolved relative to the current
# working directory instead of /tmp.
export TRITON_CACHE_DIR=/tmp/.triton/autotune
mkdir -p "$TRITON_CACHE_DIR"

# NCCL settings exported for the training processes.
# NOTE(review): NCCL_MIN_CHANNELS does not look like a documented NCCL
# variable (the documented name appears to be NCCL_MIN_NCHANNELS, which is
# also set below with a different value) -- confirm which value is intended.
export NCCL_NSOCKS_PERTHREAD=4 \
    NCCL_SOCKET_NTHREADS=2 \
    NCCL_MIN_CHANNELS=32 \
    NCCL_DEBUG=INFO \
    NCCL_IB_RETRY_CNT=10 \
    NCCL_MIN_NCHANNELS=11 \
    NCCL_TREE_THRESHOLD=4294967296

# torch.distributed / torchelastic settings exported for the launcher and
# its workers.
export TORCH_DISTRIBUTED_DEBUG=INFO \
    TORCH_DISTRIBUTED_TIMEOUT=300 \
    TORCHELASTIC_MAX_FAILED_CONNECTIONS=60 \
    TORCH_DISTRIBUTED_HEARTBEAT_TIMEOUT=300

# Worker count for this node: the number of lines printed by `nvidia-smi -L`
# (presumably one line per visible GPU).
NPROC_PER_NODE="$(nvidia-smi -L | wc -l)"

# Log the detected value.
printf 'NPROC_PER_NODE=%s\n' "$NPROC_PER_NODE"

# Weights & Biases settings handed to the training process through the
# environment. Everything except WANDB_OFFLINE comes from template
# placeholders.
# NOTE(review): WANDB_OFFLINE=true presumably keeps wandb from syncing during
# the run -- confirm against the wandb version in use.
export WANDB_OFFLINE=true \
    WANDB_RESUME={wandb_resume} \
    WANDB_PROJECT={wandb_project} \
    WANDB_ENTITY={wandb_entity} \
    WANDB_RUN_GROUP={wandb_group} \
    WANDB_DIR={wandb_dir}

# Start training.py through torchrun on this one node in standalone mode,
# spawning NPROC_PER_NODE worker processes and passing the trainer config
# path from the template placeholder.
torchrun \
    --standalone \
    --nnodes 1 \
    --nproc_per_node "$NPROC_PER_NODE" \
    training.py --cfg {trainer_config_path}