#!/usr/bin/env bash

# Bail out early when the interpreter is not bash (BASH_VERSION is empty or
# unset). Only POSIX `[` is used here so the check itself works under a plain
# sh; the rest of the script uses bash-only syntax such as `[[ ]]` and `>(...)`.
[ -n "${BASH_VERSION}" ] || {
	echo "Please use bash to run this script." >&2
	exit 1
}

# Print each command (with expanded arguments) to stderr as it runs.
set -o xtrace

# SCRIPT_DIR: absolute path of the directory holding this script.
# ROOT_DIR:   its parent directory.
SCRIPT_DIR="$(
	cd "$(dirname "$0")" >/dev/null 2>&1 && pwd
)"
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"

# Put ROOT_DIR at the front of PYTHONPATH, keeping any existing entries after
# it and avoiding a dangling ':' when PYTHONPATH was previously empty or unset.
if [[ -n "${PYTHONPATH}" ]]; then
	PYTHONPATH="${ROOT_DIR}:${PYTHONPATH}"
else
	PYTHONPATH="${ROOT_DIR}"
fi

# Keep a caller-supplied LOGLEVEL; otherwise default to WARNING.
LOGLEVEL="${LOGLEVEL:-WARNING}"

export PYTHONPATH LOGLEVEL

# Defaults; both can be overridden from the command line.
MODEL_NAME_OR_PATH="huggyllama/llama-7b"
OUTPUT_DIR="${ROOT_DIR}/output/sft-huggingface"

#######################################
# Parse command-line options into the global settings.
# Each option is accepted as `--flag value` or `--flag=value`.
# Globals:   MODEL_NAME_OR_PATH (written), OUTPUT_DIR (written)
# Arguments: the script's command-line arguments ("$@")
# Outputs:   an error message on stderr for invalid input
# Returns:   0 on success; exits the shell with status 1 on an unknown
#            option or when a `--flag value` option has no value after it
#######################################
parse_args() {
	local arg
	while [[ "$#" -gt 0 ]]; do
		arg="$1"
		shift
		case "${arg}" in
			--model_name_or_path)
				# Without this check a trailing flag would silently assign an
				# empty string (and the following `shift` would fail unnoticed).
				if [[ "$#" -eq 0 ]]; then
					echo "Missing value for parameter: '${arg}'" >&2
					exit 1
				fi
				MODEL_NAME_OR_PATH="$1"
				shift
				;;
			--model_name_or_path=*)
				# Strip everything up to and including the first '='.
				MODEL_NAME_OR_PATH="${arg#*=}"
				;;
			--output_dir)
				if [[ "$#" -eq 0 ]]; then
					echo "Missing value for parameter: '${arg}'" >&2
					exit 1
				fi
				OUTPUT_DIR="$1"
				shift
				;;
			--output_dir=*)
				OUTPUT_DIR="${arg#*=}"
				;;
			*)
				echo "Unknown parameter passed: '${arg}'" >&2
				exit 1
				;;
		esac
	done
}

parse_args "$@"

# Create the output directory and replace OUTPUT_DIR with its absolute path.
# Both steps abort on failure: otherwise OUTPUT_DIR would end up empty and the
# writes below would target "/.gitignore" and "/script.sh".
if ! mkdir -p -- "${OUTPUT_DIR}"; then
	echo "Failed to create output directory: '${OUTPUT_DIR}'" >&2
	exit 1
fi
# Resolve into a temporary so the original value is still available for the
# error message if the directory cannot be entered.
if ! resolved_output_dir="$(cd -- "${OUTPUT_DIR}" &>/dev/null && pwd)"; then
	echo "Failed to resolve output directory: '${OUTPUT_DIR}'" >&2
	exit 1
fi
OUTPUT_DIR="${resolved_output_dir}"
unset resolved_output_dir

# Write a catch-all .gitignore ('*') into the output directory, unless one
# already exists there.
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
	echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Keep a copy of this script next to the outputs it produced.
cp -f -- "$0" "${OUTPUT_DIR}/script.sh"

# When WANDB_API_KEY is unset or empty, export WANDB_MODE=offline.
# The test expands to the placeholder "set" instead of the key itself: xtrace
# (`set -x` near the top of this script) prints expanded values, so testing
# "${WANDB_API_KEY}" directly would write the secret into the trace output.
if [[ -z "${WANDB_API_KEY:+set}" ]]; then
	export WANDB_MODE="offline"
fi

# Mirror everything printed from here on into log files while still showing it.
# `exec` with only redirections rewires this shell's own file descriptors, so
# every later command inherits them:
#   - stdout is fed to a `tee` that writes "${OUTPUT_DIR}/stdout.log" and
#     passes the data on to its own stdout;
#   - stderr is fed to a second `tee` that writes "${OUTPUT_DIR}/stderr.log"
#     and sends its copy to stderr (`>&2`) instead of stdout.
# tee is invoked without -a, so existing log files are overwritten.
exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)

# Options consumed by torchrun itself: standalone mode, 8 processes per node,
# and the Python module to run.
launcher_args=(
	--standalone
	--nproc_per_node=8
	--module safe_rlhf.finetune.huggingface
)

# Options forwarded to the safe_rlhf.finetune.huggingface module. The model
# and output locations come from the settings resolved earlier in this script;
# the FSDP config file is looked up under ROOT_DIR.
training_args=(
	--datasets alpaca
	--model_name_or_path "${MODEL_NAME_OR_PATH}"
	--output_dir "${OUTPUT_DIR}"
	--num_train_epochs 3
	--per_device_train_batch_size 4
	--per_device_eval_batch_size 4
	--gradient_accumulation_steps 8
	--evaluation_strategy no
	--save_strategy steps
	--save_steps 2000
	--save_total_limit 1
	--learning_rate 2e-5
	--weight_decay 0.0
	--lr_warmup_ratio 0.03
	--lr_scheduler_type cosine
	--logging_steps 1
	--fsdp "full_shard auto_wrap"
	--fsdp_config "${ROOT_DIR}/safe_rlhf/configs/fsdp_config.json"
	--bf16 True
	--tf32 True
)

# Last command of the script, so its exit status becomes the script's.
torchrun "${launcher_args[@]}" "${training_args[@]}"
