#!/usr/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# -e: abort on the first failing command; -x: trace commands into the log.
# -o pipefail: the launch command below is piped into `tee`; without pipefail
# the script's exit status would be tee's, so a crashed training run would
# still report success.
set -exo pipefail

# Activate the conda environment. Root install dir and env name can be
# overridden from the environment; defaults match the original setup.
CONDA_ROOT=${CONDA_ROOT:-"$HOME/miniconda3"}
CONDA_ENV=${CONDA_ENV:-"arch"}
source "${CONDA_ROOT}/etc/profile.d/conda.sh"
conda activate "${CONDA_ENV}"

# Every `${VAR:-default}` below can be overridden from the environment, e.g.
#   LOG_RANK=0,1 NGPU=4 ./run_train.sh
#   LOG_RANK=0,1,2,3,4,5,6,7 ./run_train.sh
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-"0"}

# A single run name feeds both the W&B run name and the output directory,
# so the two cannot drift apart when a new ablation is launched.
RUN_NAME=${RUN_NAME:-"archwarmup-16l-lr8e-3-sl1024-bs1024-ablation-250"}
export WANDB_PROJECT=arch_warmupv2
export WANDB_ENTITY=ajanthan-pluralis-research
export WANDB_NAME="${RUN_NAME}"
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_1b.toml"}
output_dir=${output_dir:-"./outputs/${RUN_NAME}"}
log_file=${log_file:-"${output_dir}/train.log"}

# torchft lighthouse address; handed to the training processes at launch.
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

# `tee` cannot create missing parent directories, and the dump folder is only
# created later by the training job, so make sure the log's directory exists
# before the pipeline starts; otherwise a fresh run writes no log file.
mkdir -p -- "$(dirname -- "${log_file}")"

# Launch single-node training. `--rdzv_endpoint=localhost:0` lets torchrun pick
# a free port; `--local-ranks-filter` limits which ranks' output is shown.
# Any extra arguments given to this script ("$@") are forwarded verbatim to
# torchtitan.train. stdout+stderr are mirrored into ${log_file}.
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
TORCHFT_LIGHTHOUSE="${TORCHFT_LIGHTHOUSE}" \
torchrun --nproc_per_node="${NGPU}" --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
  --local-ranks-filter "${LOG_RANK}" --role rank --tee 3 \
  -m torchtitan.train --job.config_file "${CONFIG_FILE}" --job.dump_folder="${output_dir}" "$@" 2>&1 \
  | tee "${log_file}"
