#!/bin/bash
# shellcheck disable=SC2090,SC2086,SC2089,SC1091
#SBATCH -c 20
#SBATCH -w anonymous
#SBATCH --gres=gpu:2
#SBATCH --job-name=repo_cen_example
#SBATCH --tasks-per-node=1
#SBATCH --output=%x-%j.out
#SBATCH --time=01:00:00
#SBATCH --dependency=afterany:4871

# Default project path; override with -p PATH or --project_path PATH.
PROJECT_PATH="$HOME/projects/repo"

# Normalise the command line with (util-linux) getopt: short and long forms,
# including `--project_path=PATH`, come back as separate words ending in `--`.
OPTIONS=$(getopt -o p: --long project_path: -n 'parse-options' -- "$@") || {
	printf '%s\n' "cen_example.sh: Error parsing options" >&2
	exit 1
}
eval set -- "$OPTIONS"

# Consume the normalised words up to getopt's `--` terminator; stop on
# anything unexpected.
while :; do
	case "$1" in
	-p | --project_path)
		PROJECT_PATH=$2
		shift 2
		;;
	--)
		shift
		break
		;;
	*) break ;;
	esac
done
printf '%s\n' "cen_example.sh: PROJECT_PATH=$PROJECT_PATH"

#! Fix the number of total samples/tokens to train on
# MODEL_TYPE="13B"   # larger alternative
MODEL_TYPE="125M"
BATCH_SIZE=256 # passed on as llm_config.global_train_batch_size
# Schedule lengths in steps (given the `ba` unit suffix where they are used).
TOTAL_STEPS=5120 WARMUP_STEPS=100 COOLDOWN_STEPS=240
# Save/evaluate roughly every 1% of the run (integer division: 5120 / 100 = 51).
EVAL_INTERVAL=$(( TOTAL_STEPS / 100 ))
# Timestamp the run so that each launch gets its own identifier.
DATETIME=$(date +%Y%m%d_%H%M%S)
export RUN_UUID="cen-${MODEL_TYPE}-bs${BATCH_SIZE}-${DATETIME}"

#! Initialize the external configs
EXTERNAL_CONFIGS=""

#! Dataset configs (SmolLM baseline data mixture)
#! Every override below is emitted exactly once so that the effective value is
#! the one written here (repeating a key lets a later copy silently win).
DATASET_NAME="smollm-corpus-shared"
export DATASET_CACHE_DIR="$PROJECT_PATH"
# Quoted so that a project path with spaces or glob characters stays one argument.
mkdir -p "$DATASET_CACHE_DIR"
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset=$DATASET_NAME"                                     # Dataset name
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset.train.root_local=$DATASET_CACHE_DIR/$DATASET_NAME" # Path of the local cache for the training dataset
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset.val.root_local=$DATASET_CACHE_DIR/$DATASET_NAME"   # Path of the local cache for the evaluation dataset
# Stream configuration for the training and evaluation datasets.
# Alternative: `smollm_corpus_cent_4_clients` for both (presumably the 4-client
# split, concatenated for centralized training by `centralized.stream_id=null`).
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset/streams@dataset.train.streams=smollm_corpus_cent"
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset/streams@dataset.val.streams=smollm_corpus_cent"
# ID of the stream to use for centralized training; null concatenates all streams.
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS centralized.stream_id=null"
#! Disable the evaluation data loader since the SmolLM corpus doesn't have validation data
#! NOTE: The script will automatically fall back to the training data loader if the evaluation data loader is not set
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS llm_config.eval_loader=null"

#! Local optimizer configs
# Run length and scheduler horizons are expressed with the `ba` unit suffix.
EXTERNAL_CONFIGS+=" llm_config.max_duration=${TOTAL_STEPS}ba"
EXTERNAL_CONFIGS+=" llm_config.optimizer.name=adopt"
EXTERNAL_CONFIGS+=" llm_config.optimizer.lr=1.0e-4"
EXTERNAL_CONFIGS+=" llm_config.scheduler.schedulers.lr.name=constant_with_sqrt_cooldown_with_warmup"
EXTERNAL_CONFIGS+=" llm_config.scheduler.schedulers.lr.t_max=${TOTAL_STEPS}ba"
EXTERNAL_CONFIGS+=" llm_config.scheduler.schedulers.lr.t_warmup=${WARMUP_STEPS}ba"
# `++` adds-or-overrides the key; `~` removes `alpha_f` from the scheduler
# config (presumably not accepted by this scheduler -- confirm).
EXTERNAL_CONFIGS+=" ++llm_config.scheduler.schedulers.lr.t_cooldown=${COOLDOWN_STEPS}ba"
EXTERNAL_CONFIGS+=" ~llm_config.scheduler.schedulers.lr.alpha_f"

#! Batch size configs (train and eval)
# Global train batch of $BATCH_SIZE split into per-device micro-batches of 8;
# the eval micro-batch size is left to `auto`.
EXTERNAL_CONFIGS+=" llm_config.global_train_batch_size=$BATCH_SIZE"
EXTERNAL_CONFIGS+=" llm_config.device_train_microbatch_size=8"
EXTERNAL_CONFIGS+=" llm_config.device_eval_batch_size=$BATCH_SIZE"
EXTERNAL_CONFIGS+=" ++llm_config.device_eval_microbatch_size=auto"

#! Precision and attention implementation configs
EXTERNAL_CONFIGS+=" llm_config.precision=amp_bf16"
# Alternative precision: EXTERNAL_CONFIGS+=" llm_config.precision=amp_fp16"
EXTERNAL_CONFIGS+=" llm_config.model.attn_config.attn_impl=flash"
# Alternative attention: EXTERNAL_CONFIGS+=" llm_config.model.attn_config.attn_impl=torch"

#! Sharding: HYBRID_SHARD with 2 replica groups x 2 shards (4 ranks in total)
# To use DDP instead: EXTERNAL_CONFIGS+=" ~llm_config.fsdp_config"
EXTERNAL_CONFIGS+=" llm_config.fsdp_config.sharding_strategy=HYBRID_SHARD"
EXTERNAL_CONFIGS+=" ++llm_config.fsdp_config.data_parallel_replicate_degree=2"
EXTERNAL_CONFIGS+=" ++llm_config.fsdp_config.data_parallel_shard_degree=2"

#! Additional configs
# Evaluation uses a single batch; checkpoints and evaluations share one interval.
EXTERNAL_CONFIGS+=" llm_config.eval_subset_num_batches=1"
EXTERNAL_CONFIGS+=" ++llm_config.tp_config=null"
EXTERNAL_CONFIGS+=" llm_config.save_interval=${EVAL_INTERVAL}ba"
EXTERNAL_CONFIGS+=" llm_config.eval_interval=${EVAL_INTERVAL}ba"
EXTERNAL_CONFIGS+=" llm_config.eval_first=false"
EXTERNAL_CONFIGS+=" llm_config.console_log_interval=100ba"
# Drop the W&B and TensorBoard loggers; add the noise-scale monitor callback.
EXTERNAL_CONFIGS+=" ~llm_config.loggers.wandb"
EXTERNAL_CONFIGS+=" ~llm_config.loggers.tensorboard"
EXTERNAL_CONFIGS+=" ++llm_config.callbacks.noise_scale_monitor={}"

#! Evaluation gauntlet
EXTERNAL_CONFIGS+=" icl_tasks_config=empty eval_gauntlet_config=empty" # Empty gauntlet

#! Load pre-trained models (uncomment as needed)
#! If both are set, the NDArrays checkpoint is loaded last
# CHECKPOINT_PATH=""
#! Load a model from a .pt checkpoint (may reside in the S3 bucket)
# EXTERNAL_CONFIGS+=" llm_config.load_path=$CHECKPOINT_PATH"
#! Load a model from an NDArrays checkpoint (may reside in the S3 bucket)
# EXTERNAL_CONFIGS+=" pretrained_model_path=$CHECKPOINT_PATH"

#! Load a different weight matrix for the token embedding module, of type NDArrays (may reside in the S3 bucket)
#! NOTE: This is crucial when analyzing DEPT-SPEC checkpoints, since the token embedding matrix is not shared across clients
# WTE_PARAMETERS_PATH=""
# EXTERNAL_CONFIGS+=" wte_parameters_path=$WTE_PARAMETERS_PATH"

# The external configs are appended at the end of the command in
# `base_centralised_training.sh`, so they take precedence over the overrides set there.
declare -x EXTERNAL_CONFIGS

#! Distributed launch. The defaults describe node 1 of a 2-node job
#! (PYTORCH_WORLD_SIZE=4, this node's ranks start at PYTORCH_BASE_RANK=2).
#! To run the other node, submit with PYTORCH_NODE_RANK=0 PYTORCH_BASE_RANK=0 in the
#! environment. Both nodes then share the same master address and port by default.
#! Any of these variables can be overridden from the submitting environment.
PYTORCH_MASTER_ADDRESS="${PYTORCH_MASTER_ADDRESS:-128.232.115.48}" \
	PYTORCH_NODE_RANK="${PYTORCH_NODE_RANK:-1}" \
	PYTORCH_WORLD_SIZE="${PYTORCH_WORLD_SIZE:-4}" \
	PYTORCH_BASE_RANK="${PYTORCH_BASE_RANK:-2}" \
	PYTORCH_MASTER_PORT="${PYTORCH_MASTER_PORT:-11680}" \
	bash "$PROJECT_PATH/scripts/base_centralised_training.sh" -p "$PROJECT_PATH" "${MODEL_TYPE}"
