#!/usr/bin/env bash
# (This is just to track which parts are ours)
# ==============================================================================

# Refuse to continue under any shell other than bash. The POSIX `[` test is
# used on purpose so this guard itself still parses when run by a plain `sh`.
[ -n "${BASH_VERSION}" ] || {
	echo "Please use bash to run this script." >&2
	exit 1
}

# Trace every command to stderr as it is executed.
set -x

# Directory containing this script, its parent, and the default output root.
SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
OUTPUT_DIR="${ROOT_DIR}/output"

# Put ROOT_DIR in front of any PYTHONPATH that is already set and non-empty.
if [[ -n "${PYTHONPATH}" ]]; then
	PYTHONPATH="${ROOT_DIR}:${PYTHONPATH}"
else
	PYTHONPATH="${ROOT_DIR}"
fi
export PYTHONPATH

# Keep a caller-provided LOGLEVEL; otherwise default to WARNING.
: "${LOGLEVEL:=WARNING}"
export LOGLEVEL

# Default settings; each of these can be overridden by a command-line option.
MODEL_NAME_OR_PATH="output/sft-pair"
SEED=42
EPOCH=3
LAG_MULTI=10
ZERO_STAGE=3
OFFLOAD="none"
SAVE_INTERVAL=500000
# HOSTFILE must start out unset (not merely empty): later code distinguishes
# "unset" from "set" when deciding whether to pass --hostfile.
unset HOSTFILE

#######################################
# Parse the command-line options of this script.
# Every option takes a value and is accepted both as `--name value` and as
# `--name=value`. The two-word form is first rewritten to the `=` form so that
# each option is handled by a single case arm.
# Globals (written): MODEL_NAME_OR_PATH, OUTPUT_DIR, SEED, EPOCH, LAG_MULTI,
#                    HOSTFILE, ZERO_STAGE, OFFLOAD, SAVE_INTERVAL
# Arguments: the script's command-line arguments
# Outputs:   an error message on stderr for an unknown option or an option
#            that is missing its value
# Returns:   exits the script with status 1 on either error
#######################################
parse_args() {
	local arg
	while [[ "$#" -gt 0 ]]; do
		arg="$1"
		shift
		case "${arg}" in
			--model_name_or_path | --output_dir | --seed | --epoch | --lag | \
				--hostfile | --zero_stage | --offload | --save_interval)
				# Without this check a trailing `--seed` would silently
				# assign an empty value.
				if [[ "$#" -eq 0 ]]; then
					echo "Missing value for parameter: '${arg}'" >&2
					exit 1
				fi
				arg="${arg}=$1"
				shift
				;;
		esac
		# `${arg#*=}` strips only up to the first '=', so values that
		# themselves contain '=' are preserved.
		case "${arg}" in
			--model_name_or_path=*)
				MODEL_NAME_OR_PATH="${arg#*=}"
				;;
			--output_dir=*)
				OUTPUT_DIR="${arg#*=}"
				;;
			--seed=*)
				SEED="${arg#*=}"
				;;
			--epoch=*)
				EPOCH="${arg#*=}"
				;;
			--lag=*)
				LAG_MULTI="${arg#*=}"
				;;
			--hostfile=*)
				HOSTFILE="${arg#*=}"
				;;
			--zero_stage=*)
				ZERO_STAGE="${arg#*=}"
				;;
			--offload=*)
				OFFLOAD="${arg#*=}"
				;;
			--save_interval=*)
				SAVE_INTERVAL="${arg#*=}"
				;;
			*)
				echo "Unknown parameter passed: '${arg}'" >&2
				exit 1
				;;
		esac
	done
}

parse_args "$@"
# All arguments have been consumed; leave the positional parameters empty.
set --

# Each run writes into its own sub-directory whose name is built from EPOCH,
# SEED and LAG_MULTI.
OUTPUT_DIR="${OUTPUT_DIR}/primal_dual_dpo_${EPOCH}_${SEED}_${LAG_MULTI}"
if ! mkdir -p -- "${OUTPUT_DIR}"; then
	echo "Failed to create output directory: '${OUTPUT_DIR}'" >&2
	exit 1
fi

# Resolve the directory to an absolute path. If that fails, stop here: an empty
# OUTPUT_DIR would make every path below ("/.gitignore", "/script.sh", the log
# files) point at the filesystem root.
RESOLVED_OUTPUT_DIR="$(cd -- "${OUTPUT_DIR}" &>/dev/null && pwd)"
if [[ -z "${RESOLVED_OUTPUT_DIR}" ]]; then
	echo "Failed to resolve output directory: '${OUTPUT_DIR}'" >&2
	exit 1
fi
OUTPUT_DIR="${RESOLVED_OUTPUT_DIR}"
unset RESOLVED_OUTPUT_DIR

# Keep the run's artifacts out of version control unless a .gitignore already exists.
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
	echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Store a copy of this script next to the run's outputs.
cp -f "$0" "${OUTPUT_DIR}/script.sh"

# Switch Weights & Biases to offline mode when no API key is available.
# The test expands to the fixed word "set" rather than to the key itself, so
# the command trace enabled by `set -x` never prints the secret.
if [[ -z "${WANDB_API_KEY:+set}" ]]; then
	export WANDB_MODE="offline"
fi

# Choose a random port for --master_port: take every port in
# [MASTER_PORT_START, MASTER_PORT_END], drop those that appear as the port part
# of column 4 in the `ss -Htan` output, and pick one of the rest at random.
# Both inputs of `comm` are sorted with plain `sort` so their order matches.
MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
	comm -23 \
		<(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
		<(ss -Htan | awk '{ n = split($4, parts, ":"); print parts[n] }' | sort -u) |
		shuf -n 1
)"

# Launcher options: --hostfile only when HOSTFILE is set (an empty value still
# counts as set), followed by the chosen master port.
DEEPSPEED_ARGS=()
[[ -z "${HOSTFILE+x}" ]] || DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")

# --- NEW SECTION ---
# DeepSpeed configuration kept as a JSON string in DEEPSPEED_JSON_CONFIG.
# It turns on bf16, selects ZeRO stage 3, offloads the optimizer to the CPU
# with pinned memory, and sets "zero_force_ds_cpu_optimizer" to false.
# NOTE(review): the deepspeed command in this script does not reference
# DEEPSPEED_JSON_CONFIG -- confirm how this configuration is meant to reach
# the training run.
DEEPSPEED_JSON_CONFIG='{
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "zero_force_ds_cpu_optimizer": false
  }
}'
# --- END NEW SECTION ---


# From here on, copy stdout and stderr into log files under OUTPUT_DIR while
# still forwarding them to the original streams.
exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)

# Everything that follows the launcher options on the deepspeed command line.
# NOTE(review): ZERO_STAGE and OFFLOAD can be set on the command line, but the
# flags that would forward them are disabled below, so they currently have no
# effect on the run.
LAUNCH_ARGS=(
	--module safe_rlhf.algorithms.primal_dual_dpo
	--train_datasets PKU-SafeRLHF-30K/train
	--eval_datasets PKU-SafeRLHF-30K/test
	--model_name_or_path "${MODEL_NAME_OR_PATH}"
	--max_length 512
	--trust_remote_code True
	--epochs "${EPOCH}"
	--lag "${LAG_MULTI}"
	--per_device_train_batch_size 8
	--per_device_eval_batch_size 8
	--gradient_accumulation_steps 1
	--gradient_checkpointing
	--learning_rate 1e-6
	--lr_scheduler_type cosine
	--lr_warmup_ratio 0.03
	--weight_decay 0.05
	--seed "${SEED}"
	--need_eval
	--eval_strategy epoch
	--scale_coeff 0.1
	--output_dir "${OUTPUT_DIR}"
	--log_type wandb
	--log_project SafeRLHF-DPO
	--bf16 True
	--tf32 True
	--save_interval "${SAVE_INTERVAL}"
	# --zero_stage "${ZERO_STAGE}"
	# --offload "${OFFLOAD}"
)

deepspeed "${DEEPSPEED_ARGS[@]}" "${LAUNCH_ARGS[@]}"
