#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Stop at the first failing command (errexit) and echo every command before
# it runs (xtrace).
set -o errexit -o xtrace

# Load conda's shell integration into this shell, then switch to the
# "titan" environment.
. /home/ubuntu/miniconda3/etc/profile.d/conda.sh
conda activate titan

# Launch parameters. Each keeps a non-empty value already present in the
# environment and otherwise falls back to the default below, so a run can be
# customised inline, e.g.
#   NODE_RANK=0 NNODES=2 MASTER_ADDR=172.31.35.202 NGPU=8 ./run_train_dist.sh
: "${NGPU:=8}"                      # passed to torchrun --nproc_per_node
: "${NODE_RANK:=0}"                 # passed to torchrun --node_rank
: "${NNODES:=2}"                    # passed to torchrun --nnodes
: "${MASTER_ADDR:=172.31.35.202}"   # rendezvous host (also used for the lighthouse URL)
: "${MASTER_PORT:=29500}"           # rendezvous port
: "${LIGHTHOUSE_PORT:=29510}"       # port used in the default TORCHFT_LIGHTHOUSE URL

# Environment handed to the training processes. Everything except LOG_RANK is
# set unconditionally, replacing any value inherited from the caller.

# Communication backends: NCCL_IB_DISABLE=1 plus a fixed socket interface
# name shared by NCCL and Gloo.
NCCL_IB_DISABLE=1
NCCL_SOCKET_IFNAME=enp71s0
GLOO_SOCKET_IFNAME=enp71s0

# TT_* switches; presumably read by the training code (not visible here).
TT_SINGULAR_VALUES_DIR=./outputs/svd
TT_PLOT_SINGULAR_VALUES=1
TT_ALIGN_DATA_BY_TOKENS=1

# Local rank(s) whose output torchrun shows; defaults to 0 when unset or empty.
LOG_RANK=${LOG_RANK:-0}

# Weights & Biases destination.
WANDB_PROJECT=central-run-8B
WANDB_ENTITY=ajanthan-pluralis-research

export NCCL_IB_DISABLE NCCL_SOCKET_IFNAME GLOO_SOCKET_IFNAME
export TT_SINGULAR_VALUES_DIR TT_PLOT_SINGULAR_VALUES TT_ALIGN_DATA_BY_TOKENS
export LOG_RANK WANDB_PROJECT WANDB_ENTITY

# Training config, dump folder and log file. Each may be preset in the
# environment; the log file defaults to train.log inside the dump folder.
: "${CONFIG_FILE:=./torchtitan/models/llama3/train_configs/llama3_8b.toml}"
: "${output_dir:=./outputs/final_7b_zclip_long_seq_len_test}"
: "${log_file:=$output_dir/train.log}"

# Lighthouse URL: unless preset, point at the master node's address on the
# lighthouse port.
: "${TORCHFT_LIGHTHOUSE:=http://${MASTER_ADDR}:${LIGHTHOUSE_PORT}}"

# Extra command-line arguments are forwarded to torchtitan.train as config
# overrides. Keeping them in an array preserves each argument as its own
# word, so values containing spaces or glob characters reach the trainer
# unchanged instead of being re-split by the shell.
overrides=("$@")

# tee opens the log file before the trainer has had a chance to create the
# dump folder; create the log directory up front so the log of a first run
# into a fresh output_dir is not lost.
mkdir -p -- "$(dirname -- "${log_file}")"

# Without pipefail the pipeline's exit status is tee's, so a crashed torchrun
# would still let this script exit 0 even under `set -e`.
set -o pipefail

# Launch one torchrun agent on this node:
#   - PYTORCH_ALLOC_CONF and TORCHFT_LIGHTHOUSE are set for this command only.
#   - Rendezvous uses the c10d backend at MASTER_ADDR:MASTER_PORT with the
#     fixed id 101; the same address/port are also passed as
#     --master_addr/--master_port.
#   - --local-ranks-filter takes LOG_RANK; --role and --tee are passed as-is.
#   - stdout and stderr are merged and copied to both the terminal and log_file.
PYTORCH_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE="${TORCHFT_LIGHTHOUSE}" \
torchrun --nproc_per_node="${NGPU}" \
  --nnodes="${NNODES}" \
  --node_rank="${NODE_RANK}" \
  --rdzv_id 101 \
  --rdzv_backend c10d \
  --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \
  --master_addr="${MASTER_ADDR}" \
  --master_port="${MASTER_PORT}" \
  --local-ranks-filter "${LOG_RANK}" --role rank --tee 3 \
  -m torchtitan.train \
  --job.config_file "${CONFIG_FILE}" \
  --job.dump_folder="${output_dir}" \
  "${overrides[@]}" 2>&1 \
  | tee "${log_file}"
