#!/bin/bash
# shellcheck disable=SC2090,SC2086,SC2089,SC1091
#SBATCH -c 40
#SBATCH -w anonymous
#SBATCH --gres=gpu:4
#SBATCH --job-name=repo_example
#SBATCH --tasks-per-node=1
#SBATCH --output=%x-%j.out
#SBATCH --time=01:00:00
#SBATCH --dependency=afterany:4871

#! Fallback project location, used unless -p/--project_path is given
PROJECT_PATH="$HOME/projects/repo"
#! Launch timestamp in YYYYmmdd_HHMMSS form
DATETIME=$(date '+%Y%m%d_%H%M%S')

#! Let getopt normalise the command line; abort if it rejects the arguments
OPTIONS=$(getopt -o p: --long project_path: -n 'parse-options' -- "$@") || {
	echo "repo_example.sh: Error parsing options" >&2
	exit 1
}

#! Replace the positional parameters with the normalised argument list
eval set -- "$OPTIONS"

#! Consume options up to (and including) the `--` end-of-options marker
while [ $# -gt 0 ]; do
	case "$1" in
	--)
		shift
		break
		;;
	-p | --project_path)
		PROJECT_PATH="$2"
		shift 2
		;;
	*)
		break
		;;
	esac
done
printf '%s\n' "repo_example.sh: PROJECT_PATH=$PROJECT_PATH"

# rm -rf /dev/shm/*
# rm -rf /tmp/tmp*
# rm -rf $PROJECT_PATH/ray/*

#! General ML parameters
#! Model size; alternatives that can be swapped in by editing the assignment:
# MODEL_TYPE="testing-moe"
# MODEL_TYPE="7B"
# MODEL_TYPE="13B"
# MODEL_TYPE="30B"
MODEL_TYPE="125M"

#! Communication stack toggles (consumed further down in this script)
USE_RAY=true
USE_SHM=false
USE_S3=false
COMM_BATCHES=5

#! Federation size and per-client batch size
N_CLIENTS=4
LOCAL_BATCH_SIZE=128

#! Step budget: total steps, steps per round, warmup and cooldown lengths
TOTAL_STEPS=5120
LOCAL_STEPS=32
WARMUP_STEPS=100
COOLDOWN_STEPS=200

#! Derived values: rounds = total steps / local steps (integer division)
N_ROUNDS=$((TOTAL_STEPS / LOCAL_STEPS))
#! Local evaluation interval equals the number of local steps
LOCAL_EVAL_INTERVAL="$LOCAL_STEPS"
FL_EVAL_PERIOD=null

#! Run identifier (model type + launch timestamp) and checkpoint destination
RUN_UUID="fed-${MODEL_TYPE}-example-$DATETIME"
# RUN_UUID="fed-125M-example-20250501_134201b"
# SAVE_PATH="s3://checkpoints/$RUN_UUID"
SAVE_PATH="$PROJECT_PATH/runs/$RUN_UUID/$DATETIME"
#! Both values are exported so that child processes can read them
export RUN_UUID SAVE_PATH

#! Initialize the external configs (a space-separated list of config overrides)
EXTERNAL_CONFIGS=""
# EXTERNAL_CONFIGS="llm_config=dbrx-18b"

#! Dataset configs
DATASET_NAME="smollm_corpus"
#! $HOME is used instead of "~" because the shell does not expand a tilde inside
#! double quotes: a literal "~" would make `mkdir` create a directory named "~" in the
#! current working directory and would leak an unexpanded path into the overrides below.
export DATASET_CACHE_DIR="$HOME/datasets/repo/dataset_cache"
# export DATASET_CACHE_DIR="$PROJECT_PATH/dataset_cache"
mkdir -p -- "$DATASET_CACHE_DIR"
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset=$DATASET_NAME"                                     # Dataset name
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset.train.root_local=$DATASET_CACHE_DIR/$DATASET_NAME" # Path of the local cache for the training dataset
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset.val.root_local=$DATASET_CACHE_DIR/$DATASET_NAME"   # Path of the local cache for the evaluation dataset
#! If the number of clients is 1, use the single-client dataset; otherwise pick the
#! stream configuration named after the number of clients
if [ "$N_CLIENTS" -eq 1 ]; then
	EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset/streams@dataset.train.streams=1_client_small" # Stream configuration for the training dataset
	EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset/streams@dataset.val.streams=1_client_small"   # Stream configuration for the evaluation dataset
else
	EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset/streams@dataset.train.streams=${N_CLIENTS}_clients" # Stream configuration for the training dataset
	EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS dataset/streams@dataset.val.streams=${N_CLIENTS}_clients"   # Stream configuration for the evaluation dataset
fi
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS centralized.stream_id=null"
EXTERNAL_CONFIGS="$EXTERNAL_CONFIGS llm_config.eval_loader=null"

#! Local optimizer configs: training duration and learning-rate schedule.
#! Entries are appended to EXTERNAL_CONFIGS in the order listed; comment an entry
#! in or out to toggle it.
lr_overrides=(
	"llm_config.max_duration=${TOTAL_STEPS}ba"
	# "llm_config.optimizer.name=adopt"
	# "llm_config.optimizer.lr=1.0e-4"
	"llm_config.scheduler.schedulers.lr.name=constant_with_sqrt_cooldown_with_warmup"
	"llm_config.scheduler.schedulers.lr.t_max=${TOTAL_STEPS}ba"
	"llm_config.scheduler.schedulers.lr.t_warmup=${WARMUP_STEPS}ba"
	"++llm_config.scheduler.schedulers.lr.t_cooldown=${COOLDOWN_STEPS}ba"
	"~llm_config.scheduler.schedulers.lr.alpha_f"
)
for lr_override in "${lr_overrides[@]}"; do
	EXTERNAL_CONFIGS+=" $lr_override"
done

#! Batch size configs (train and eval); both batch sizes reuse LOCAL_BATCH_SIZE
EXTERNAL_CONFIGS+=" llm_config.global_train_batch_size=$LOCAL_BATCH_SIZE"
EXTERNAL_CONFIGS+=" llm_config.device_train_microbatch_size=auto"
# EXTERNAL_CONFIGS+=" llm_config.device_train_microbatch_size=1"
EXTERNAL_CONFIGS+=" llm_config.device_eval_batch_size=$LOCAL_BATCH_SIZE"
#! NOTE(review): the eval microbatch size is appended twice (`auto`, then `1`).
#! Presumably only one of the two is intended, as with the train microbatch size
#! above - confirm which one and comment out the other.
EXTERNAL_CONFIGS+=" ++llm_config.device_eval_microbatch_size=auto"
EXTERNAL_CONFIGS+=" ++llm_config.device_eval_microbatch_size=1"

#! Precision and attention implementation configs.
#! Active choice: amp_bf16 precision with the `flash` attention implementation;
#! the commented lines are the alternatives.
EXTERNAL_CONFIGS+=" llm_config.precision=amp_bf16"
# EXTERNAL_CONFIGS+=" llm_config.precision=amp_fp16"
EXTERNAL_CONFIGS+=" llm_config.model.attn_config.attn_impl=flash"
# EXTERNAL_CONFIGS+=" llm_config.model.attn_config.attn_impl=torch"
# EXTERNAL_CONFIGS+=" ~llm_config.fsdp_config" # DDP
# EXTERNAL_CONFIGS+=" ++llm_config.fsdp_config.state_dict_type=sharded"

#! Federated learning configs.
#! All N_CLIENTS clients take part in every round (total == per-round count).
#! Entries are appended to EXTERNAL_CONFIGS in the order listed.
fl_overrides=(
	"fl.n_rounds=$N_ROUNDS"
	"fl.n_total_clients=$N_CLIENTS"
	"fl.n_clients_per_round=$N_CLIENTS"
	"fl.n_local_steps=$LOCAL_STEPS"
	"fl.eval_period=$FL_EVAL_PERIOD"
	"fl.strategy_name=FEDAVG"
	"fl.strategy_kwargs={}"
	# "fl.strategy_kwargs.server_learning_rate=1.0"
	# "fl.strategy_kwargs.server_momentum=0.0"
	"fl.reset_optimizer=true"
	"fl.ignore_failed_rounds=false"
	"fl.accept_failures_cnt=0"
	"fl.dropout_ratio=0.0"
	"fl.dropout_function_name=random"
)
for fl_override in "${fl_overrides[@]}"; do
	EXTERNAL_CONFIGS+=" $fl_override"
done

#! DEPT configs
#! Environment switches exported alongside the overrides below
export COMPOSER_FAIL_ON_VOCAB_MISMATCH=0 ALLOW_EMBEDDING_RESIZING=0
#! Entries are appended to EXTERNAL_CONFIGS in the order listed
dept_overrides=(
	"fl.resize_vocab=null"
	"fl.random_layers=[]"
	"fl.random_init_freq=0"
	"fl.truly_random_init=true"
	"fl.personalized_layers=[]"
	"fl.frozen_layers=null"
	"fl.unfrozen_layers=null"
	"fl.set_trainer_params_filter_keys=false"
	"fl.set_trainer_key_to_filter=transformer"
	"fl.split_eval=false"
	"+fl.parameter_scheduler_kwargs.EXP_AVG=1"
	"+fl.parameter_scheduler_kwargs.EXP_AVG_SQ=1"
)
for dept_override in "${dept_overrides[@]}"; do
	EXTERNAL_CONFIGS+=" $dept_override"
done

#! repo communication stack configs (single-node setup)
EXTERNAL_CONFIGS+=" repo.n_nodes=1"
#! Each backend (ray, shm, s3 - in that order) is switched on only when its
#! USE_* toggle is exactly `true`; each item below is "<backend>=<toggle value>".
for backend_toggle in "ray=$USE_RAY" "shm=$USE_SHM" "s3=$USE_S3"; do
	if [[ "${backend_toggle#*=}" == "true" ]]; then
		EXTERNAL_CONFIGS+=" repo.comm_stack.${backend_toggle%%=*}=true"
	fi
done
EXTERNAL_CONFIGS+=" repo.comm_stack.n_batches=$COMM_BATCHES"

#! Set the use_wandb flag to true.
#! NOTE(review): an earlier comment here said this *disables* wandb logging on the
#! ServerApp, which contradicts the value being set - confirm the intended setting.
EXTERNAL_CONFIGS+=" use_wandb=true"

#! Additional configs (evaluation, saving, logging, callbacks) followed by the
#! evaluation gauntlet selection. Entries are appended to EXTERNAL_CONFIGS in the
#! order listed; comment an entry in or out to toggle it.
extra_overrides=(
	"llm_config.eval_subset_num_batches=1"
	"++llm_config.tp_config=null"
	"llm_config.save_folder=null"
	"llm_config.save_overwrite=true"
	"llm_config.save_interval=${TOTAL_STEPS}ba"
	"llm_config.eval_interval=${LOCAL_EVAL_INTERVAL}ba"
	"llm_config.eval_first=false"
	"llm_config.console_log_interval=100ba"
	# "~llm_config.loggers.wandb"
	"~llm_config.loggers.tensorboard"
	"~llm_config.callbacks.activation_monitor_full_model"
	# "++llm_config.callbacks.noise_scale_monitor={}"

	#! Evaluation gauntlet: both selections point at the `empty` config
	"icl_tasks_config=empty"
	"eval_gauntlet_config=empty"
)
for extra_override in "${extra_overrides[@]}"; do
	EXTERNAL_CONFIGS+=" $extra_override"
done

#! repo checkpointing: enabled, writing under SAVE_PATH
EXTERNAL_CONFIGS+=" repo.checkpoint=true"
EXTERNAL_CONFIGS+=" repo.saving_path=$SAVE_PATH"
# EXTERNAL_CONFIGS+=" repo.cleanup_checkpoints=false"
# EXTERNAL_CONFIGS+=" repo.cleanup_checkpoints_per_round=false"

#! Resume options (all disabled; uncomment the ones needed)
# EXTERNAL_CONFIGS+=" repo.resume_round=-1"
# EXTERNAL_CONFIGS+=" repo.copy_client_checkpoints=false"
# EXTERNAL_CONFIGS+=" repo.restore_run_uuid=fed-125M-example-20250501_134201"
# EXTERNAL_CONFIGS+=" repo.restore_cent_run_uuid=null"
# EXTERNAL_CONFIGS+=" repo.restore_cent_run_batches=null"
# EXTERNAL_CONFIGS+=" pretrained_model_path=~/anonymous/projects/repo/runs/fed-125M-example-20250501_134201/server/21/current_server_parameters.npz"

#! The external configs are appended at the end of the command, so they take
#! precedence over those set in the `repo_base.sh` script
export EXTERNAL_CONFIGS

#! Print the first address reported by `hostname -I` for this host
first_host_address() {
	hostname -I | awk '{print $1}'
}

#! Python prefix shared by the port lookups below (project helper `repo.port_utils`)
PORT_LOOKUP_IMPORT="from repo.port_utils import get_free_tcp_port"

#! Host addresses of the Flower SuperLink and of the Ray node (both this host)
FLOWER_SUPERLINK_IP=$(first_host_address)
RAY_NODE_IP=$(first_host_address)

#! Ask the helper for three TCP ports. Each later request passes the ports already
#! chosen, presumably so that they are not handed out again.
DRIVER_API_PORT=$(uv run python -c "$PORT_LOOKUP_IMPORT; print(get_free_tcp_port())")
FLEET_API_PORT=$(uv run python -c "$PORT_LOOKUP_IMPORT; print(get_free_tcp_port(['${DRIVER_API_PORT}']))")
RAY_PORT=$(uv run python -c "$PORT_LOOKUP_IMPORT; print(get_free_tcp_port(['${DRIVER_API_PORT}','${FLEET_API_PORT}']))")

#! Flower driver and fleet API addresses: values already present (and non-empty)
#! in the environment are kept, otherwise <superlink ip>:<port> is used
: "${DRIVER_API_ADDRESS:=$FLOWER_SUPERLINK_IP:$DRIVER_API_PORT}"
: "${FLEET_API_ADDRESS:=$FLOWER_SUPERLINK_IP:$FLEET_API_PORT}"

#! Ray address and the directory passed to Ray as its temp dir
RAY_ADDRESS="$RAY_NODE_IP:$RAY_PORT"
RAY_TEMP_DIR="$PROJECT_PATH/ray"

#! Export the variables for the `repo_base.sh` script
export DRIVER_API_ADDRESS FLEET_API_ADDRESS RAY_ADDRESS RAY_PORT RAY_NODE_IP

#! The following command will start the Ray head node. When running in a single node
#! setup, it is sufficient as both client and server will automatically attach to the
#! current Ray session.
#! All expansions are quoted so that a project path containing spaces (or glob
#! characters) is passed through as a single argument.
#! NOTE(review): `ray start` is sent to the background and the base script is launched
#! right away, without waiting for the head node to come up - confirm this is intended.
if [ "$USE_RAY" = true ]; then
	uv run ray start --head --port="$RAY_PORT" --temp-dir "$RAY_TEMP_DIR" &
fi

# bash "$HOME/projects/repo/scripts/repo_base.sh" -p "$PROJECT_PATH" "${MODEL_TYPE}"

#! NOTE(review): the launched script is always taken from $HOME/projects/repo, even
#! when -p/--project_path points elsewhere - confirm this is intended.
bash "$HOME/projects/repo/scripts/repo_base_bi_independent.sh" -p "$PROJECT_PATH" "${MODEL_TYPE}"

# NOTE: For executing in a multinode setup, the Ray head node will need to be started
# only on the first node (e.g., `ray start --head --port=$RAY_PORT`), which can run the
# federated learning server, the Flower Superlink, and, potentially, a single client
# (e.g., `bash $HOME/projects/repo/scripts/repo_base.sh ${MODEL_TYPE}`). The other nodes
# that run, let's say, other clients (e.g., `DRIVER_API_ADDRESS="$FLOWER_SUPERLINK_IP:54752" FLEET_API_ADDRESS="$FLOWER_SUPERLINK_IP:54753" repo_base_client_only.sh`,
# where `$FLOWER_SUPERLINK_IP` is the IP address of the first node), will need to attach
# to the Ray head node (e.g., `ray start --address=$RAY_ADDRESS`, where `$RAY_ADDRESS`
# is the IP address and port of the head node) and run the `repo_base_client_only.sh`
# script.
