#!/bin/bash
# shellcheck disable=SC2090,SC2086,SC2089,SC1091
# Project path used when -p/--project_path is not given
PROJECT_PATH="$HOME/projects/repo"

# Let getopt normalise the command line (short option -p, long option
# --project_path, both taking a value); bail out if it rejects the input
OPTIONS=$(getopt -o p: --long project_path: -n 'parse-options' -- "$@") || {
	echo "base_centralised_training.sh: Error parsing options" >&2
	exit 1
}

# Replace the positional parameters with getopt's normalised output
eval set -- "$OPTIONS"

# Consume the recognised options; whatever remains afterwards are the
# positional arguments
while :; do
	if [[ "$1" == "-p" || "$1" == "--project_path" ]]; then
		PROJECT_PATH="$2"
		shift 2
		continue
	fi
	# Drop the "--" separator when it is the current argument, then stop
	[[ "$1" == "--" ]] && shift
	break
done
#! Check that at least one positional argument (<llm_model_config>) was passed.
#! This runs before MODEL_SIZE is read so that a missing argument is reported
#! instead of logging an empty MODEL_SIZE first; diagnostics go to stderr.
if [[ $# -lt 1 ]]; then
	echo "base_centralised_training.sh: Illegal number of parameters." >&2
	echo "Usage: base_centralised_training.sh <llm_model_config> (-p/--project_path <project_path>)" >&2
	exit 1
fi
#! The first positional argument is stored as MODEL_SIZE
MODEL_SIZE="$1"
echo "base_centralised_training.sh: MODEL_SIZE=$MODEL_SIZE"
echo "base_centralised_training.sh: PROJECT_PATH=$PROJECT_PATH"
#! Switch to the project folder; stop if it cannot be entered
cd "$PROJECT_PATH" || exit
#! Prepare the environment with the installer that matches the host name.
#! NOTE: The installers must be sourced with "."; "sh" doesn't work
case "$(hostname)" in
*gpu-q*)
	echo "base_centralised_training.sh: Assuming the script is executing in the CSD3."
	. "$PROJECT_PATH/scripts/install_hpc_env.sh" -p "$PROJECT_PATH"
	;;
*)
	echo "base_centralised_training.sh: Assuming the script is executing NOT in the CSD3."
	. "$PROJECT_PATH/scripts/install_env.sh" -p "$PROJECT_PATH"
	;;
esac
#! Source the helper that sets the `LLM_CONFIG` variable for the chosen model
. "$PROJECT_PATH/scripts/set_llm_config.sh" -p "$PROJECT_PATH" "$MODEL_SIZE"

#! Endpoint of the S3 object store, exported for child processes
export S3_ENDPOINT_URL='http://anonymous.anonymous.:9000'

#! Timestamp used in the default run UUID and as the per-launch log folder name
DATETIME=$(date '+%Y%m%d_%H%M%S')

#! Default RUN_UUID, only applied when the variable is empty or unset
[[ -n "$RUN_UUID" ]] || export RUN_UUID="cen-$MODEL_SIZE-$DATETIME"

#! Default SAVE_PATH, only applied (and created) when the variable is empty or unset
[[ -n "$SAVE_PATH" ]] || {
	export SAVE_PATH="$PROJECT_PATH/$RUN_UUID"
	mkdir -p "$SAVE_PATH"
}

#! Default repo_SAVE_PATH, only applied when the variable is empty or unset
[[ -n "$repo_SAVE_PATH" ]] || export repo_SAVE_PATH="$PROJECT_PATH/runs/$RUN_UUID"

#! Make sure the run folder and its timestamped subfolder exist
mkdir -p "$repo_SAVE_PATH" "$repo_SAVE_PATH/$DATETIME"

#! Getting visible GPUs: ask PyTorch how many CUDA devices it can see
N_GPUS=$(uv run python -c 'import torch; print(torch.cuda.device_count())')
#! If the query fails, N_GPUS is empty or non-numeric; that would break the
#! integer comparison below and leave N_GPUS unusable for later steps, so it
#! is normalised to 0 with a warning.
if [[ "$N_GPUS" =~ ^[0-9]+$ ]]; then
	#! Force base 10 so a value with leading zeros is not parsed as octal
	N_GPUS=$((10#$N_GPUS))
else
	echo "base_centralised_training.sh: Could not determine the number of GPUs (got '$N_GPUS'); assuming 0." >&2
	N_GPUS=0
fi
if [[ "$N_GPUS" -eq 0 ]]; then
	#! The script does not stop here; it carries on with an empty device list
	echo "base_centralised_training.sh: No GPUs found; continuing with an empty CUDA_VISIBLE_DEVICES." >&2
	CUDA_VISIBLE_DEVICES=""
else
	#! Comma-separated device indices 0..N_GPUS-1
	CUDA_VISIBLE_DEVICES=$(seq -s, 0 "$((N_GPUS - 1))")
fi
echo "base_centralised_training.sh: CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"

#! Append the run UUID override to the LLM configuration string
LLM_CONFIG+=" run_uuid=$RUN_UUID"

#! Log the configuration overrides that are about to be used
printf '%s\n' \
	"base_centralised_training.sh: LLM_CONFIG=$LLM_CONFIG" \
	"base_centralised_training.sh: EXTERNAL_CONFIGS=$EXTERNAL_CONFIGS"

#! Run the Hydra resolver, sending its stdout and stderr both to the terminal
#! and to a log file in the timestamped run folder.
#! NOTE: LLM_CONFIG and EXTERNAL_CONFIGS are left unquoted on purpose so that
#! each space-separated override becomes its own argument.
HYDRA_FULL_ERROR=1 uv run python -m repo.hydra_resolver \
	$LLM_CONFIG $EXTERNAL_CONFIGS \
	hydra/job_logging=none hydra/hydra_logging=none 2>&1 |
	tee "$repo_SAVE_PATH/$DATETIME/hydra_resolver.log"

#! Defaults for the distributed-launch settings. Each variable is exported
#! with its default only when it is currently empty or unset.

#! Master address defaults to the loopback interface
[[ -n "$PYTORCH_MASTER_ADDRESS" ]] || export PYTORCH_MASTER_ADDRESS="127.0.0.1"

#! Node rank defaults to 0
[[ -n "$PYTORCH_NODE_RANK" ]] || export PYTORCH_NODE_RANK="0"

#! World size defaults to the number of detected GPUs
[[ -n "$PYTORCH_WORLD_SIZE" ]] || export PYTORCH_WORLD_SIZE="$N_GPUS"

#! Base rank defaults to 0
[[ -n "$PYTORCH_BASE_RANK" ]] || export PYTORCH_BASE_RANK=0

#! If PYTORCH_MASTER_PORT hasn't been set, ask the project's port helper for a free TCP port
if [ -z "$PYTORCH_MASTER_PORT" ]; then
	#! Run through `uv run python`, consistent with the other Python
	#! invocations in this script. Abort when the helper fails or prints
	#! nothing: an empty port would otherwise be exported and passed on to
	#! the launcher as a missing `--master_port` value.
	if ! PYTORCH_MASTER_PORT=$(uv run python -c "from repo.port_utils import get_free_tcp_port; print(get_free_tcp_port())") || [[ -z "$PYTORCH_MASTER_PORT" ]]; then
		echo "base_centralised_training.sh: Could not obtain a free TCP port for PYTORCH_MASTER_PORT." >&2
		exit 1
	fi
	export PYTORCH_MASTER_PORT
	echo "base_centralised_training.sh: PYTORCH_MASTER_PORT=$PYTORCH_MASTER_PORT"
fi

#! Launch centralised training script
#! NOTE: Adding `NCCL_BLOCKING_WAIT=1` breaks the optimizer's checkpointing. We don't know why yet.
#! Debugging aid (prepend to the environment below): TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1
#! The pipeline runs inside a background subshell with `pipefail` enabled, so
#! that `$!` is the subshell's PID and `wait` reports a launcher failure.
#! With a bare background pipeline `$!` is the PID of `tee`, and `wait` would
#! return tee's status, hiding training failures.
#! NOTE(review): RUN_UUID is overridden with a fresh `uuidgen` value for the
#! launcher only; this differs from the RUN_UUID exported earlier in the
#! script — confirm this is intended.
(
	set -o pipefail
	PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
		APPOINTED_CUDA_DEVICE="$CUDA_VISIBLE_DEVICES" \
		CUDA_LAUNCH_BLOCKING=1 \
		HYDRA_FULL_ERROR=1 \
		RUN_UUID="$(uuidgen)" \
		uv run composer \
		--world_size "$PYTORCH_WORLD_SIZE" \
		--base_rank "$PYTORCH_BASE_RANK" \
		--node_rank "$PYTORCH_NODE_RANK" \
		--master_addr "$PYTORCH_MASTER_ADDRESS" \
		--master_port "$PYTORCH_MASTER_PORT" \
		"$PROJECT_PATH/repo/centralised_train.py" \
		hydra/job_logging=none hydra/hydra_logging=none 2>&1 |
		tee "$repo_SAVE_PATH/$DATETIME/centralised_train.log"
) &
#! Keep the pid of the background subshell and wait for it; the status of
#! `wait` (and therefore of the script, as this is its last command) is the
#! status of the training pipeline
BACK_PID=$!
wait "$BACK_PID"
