#!/usr/bin/env bash
# ==========================================================================================
# [August 30, 2025]
# This script generates responses for a single model given a prompt dataset.
# The generated responses are saved as JSONL to
# <output_dir>/<basename of model_name_or_path>/response.jsonl, next to the
# run's stdout/stderr logs and a copy of this script.
#
# Sample Usage:
# bash scripts/test/generation.sh \
#   --model_name_or_path path/to/your/model \
#   --output_dir absolute/path/to/results/directory \
#   --dataset_path data/helpful_problem.json
# ==========================================================================================

# Everything below relies on bash-only features (arrays, [[ ]], process
# substitution), so refuse to continue under any other shell.
[ -n "${BASH_VERSION}" ] || {
	echo "Please use bash to run this script." >&2
	exit 1
}

# Trace every command, then stop at the first command that fails.
set -x
set -e

# Absolute directory holding this script; the repository root is two levels
# up (the script lives under scripts/test/). Prepend the root to PYTHONPATH so
# the project's Python package is importable, without leaving a dangling ':'
# when PYTHONPATH was empty or unset.
SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
ROOT_DIR="$(dirname "$(dirname "${SCRIPT_DIR}")")"
if [[ -n "${PYTHONPATH}" ]]; then
	export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH}"
else
	export PYTHONPATH="${ROOT_DIR}"
fi

# --- Configuration ---
# Defaults; each one can be overridden by the matching command-line option.
MODEL_NAME_OR_PATH=""                                # --model_name_or_path (required)
OUTPUT_DIR="${ROOT_DIR}/output/generated-responses"  # --output_dir
DATASET_ARG="PKU-SafeRLHF-30K/test"                  # replaced by --dataset_path
ZERO_STAGE=3                                         # --zero_stage
OFFLOAD="none"                                       # --offload
unset HOSTFILE                                       # --hostfile; stays unset unless given

# --- Argument Parsing ---
# Every option accepts both the "--flag value" and "--flag=value" forms.

# Abort with a clear message when a "--flag value" option is the last argument.
# Without this check the following `shift` fails and `set -e` ends the script
# with no explanation.
#   $1 - option name (used in the error message)
#   $2 - number of arguments remaining after the option itself
require_value() {
	if [[ "$2" -eq 0 ]]; then
		echo "Error: option '$1' requires a value." >&2
		exit 1
	fi
}

# Override the default dataset with a custom HelpfulPromptDataset file.
# The python module expects the format DATASET_NAME:PROPORTION:PATH; we assume
# the NAME for HelpfulPromptDataset is 'helpful-prompt'.
#   $1 - path to the prompt file; it must exist
# Sets DATASET_FILE_PATH and DATASET_ARG.
use_dataset_file() {
	DATASET_FILE_PATH="$1"
	if [[ ! -f "${DATASET_FILE_PATH}" ]]; then
		echo "Error: Dataset file not found at '${DATASET_FILE_PATH}'" >&2
		exit 1
	fi
	DATASET_ARG="helpful-prompt:1.0:${DATASET_FILE_PATH}"
}

while [[ "$#" -gt 0 ]]; do
	arg="$1"
	shift
	case "${arg}" in
		--model_name_or_path)
			require_value "${arg}" "$#"
			MODEL_NAME_OR_PATH="$1"
			shift
			;;
		--model_name_or_path=*)
			MODEL_NAME_OR_PATH="${arg#*=}"
			;;
		--output_dir)
			require_value "${arg}" "$#"
			OUTPUT_DIR="$1"
			shift
			;;
		--output_dir=*)
			OUTPUT_DIR="${arg#*=}"
			;;
		--dataset_path)
			require_value "${arg}" "$#"
			use_dataset_file "$1"
			shift
			;;
		--dataset_path=*)
			use_dataset_file "${arg#*=}"
			;;
		--hostfile)
			require_value "${arg}" "$#"
			HOSTFILE="$1"
			shift
			;;
		--hostfile=*)
			HOSTFILE="${arg#*=}"
			;;
		--zero_stage)
			require_value "${arg}" "$#"
			ZERO_STAGE="$1"
			shift
			;;
		--zero_stage=*)
			ZERO_STAGE="${arg#*=}"
			;;
		--offload)
			require_value "${arg}" "$#"
			OFFLOAD="$1"
			shift
			;;
		--offload=*)
			OFFLOAD="${arg#*=}"
			;;
		*)
			echo "Unknown parameter passed: '${arg}'" >&2
			exit 1
			;;
	esac
done

# --- Validate Required Arguments ---
# The model path has no default. OUTPUT_DIR does, so it can only be empty
# here if the caller passed an empty value explicitly.
[[ -n "${MODEL_NAME_OR_PATH}" ]] || {
	echo "Error: --model_name_or_path is required." >&2
	exit 1
}
[[ -n "${OUTPUT_DIR}" ]] || {
	echo "Error: --output_dir is required." >&2
	exit 1
}

# --- Environment and File Path Setup ---
# Create the results root and normalise it to an absolute path. Unless one
# already exists, drop a catch-all .gitignore there so results stay out of git.
mkdir -p "${OUTPUT_DIR}"
OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
[[ -f "${OUTPUT_DIR}/.gitignore" ]] || echo '*' >"${OUTPUT_DIR}/.gitignore"

# Each model gets its own sub-directory, named after the last path component
# of --model_name_or_path; the responses file is written inside it.
MODEL_BASENAME=$(basename "${MODEL_NAME_OR_PATH}")
MODEL_OUTPUT_DIR="${OUTPUT_DIR}/${MODEL_BASENAME}"
OUTPUT_FILE="${MODEL_OUTPUT_DIR}/response.jsonl"
mkdir -p "${MODEL_OUTPUT_DIR}"

# --- Logging ---
# Keep a copy of the script that produced this run. Then mirror stdout and
# stderr into stdout.log / stderr.log in the model directory, while still
# forwarding both streams to their original destinations.
cp -f "$0" "${MODEL_OUTPUT_DIR}/script.sh"
exec 1> >(tee "${MODEL_OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${MODEL_OUTPUT_DIR}/stderr.log" >&2)

# --- DeepSpeed Arguments ---
# Extra flags for the deepspeed launcher. --hostfile is forwarded only when
# HOSTFILE is set; an explicitly empty value still counts as set.
DEEPSPEED_ARGS=()
[[ -z "${HOSTFILE+x}" ]] || DEEPSPEED_ARGS+=(--hostfile "${HOSTFILE}")

# Inclusive range of candidate ports for the DeepSpeed master (rendezvous) port.
MASTER_PORT_START=10000
MASTER_PORT_END=65535

# Print one randomly chosen port in [MASTER_PORT_START, MASTER_PORT_END] that
# does not appear as a local port in `ss -Htan` output (TCP sockets on this
# host). Prints nothing if every port in the range is taken.
find_free_port() {
	# `comm -23` keeps lines found only in the first list (candidate ports).
	# Both inputs go through the same `sort` so their ordering agrees.
	# The awk keeps the text after the last ':' of the local-address column,
	# which also handles IPv6 addresses.
	comm -23 \
		<(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
		<(ss -Htan | awk '{ n = split($4, parts, ":"); print parts[n] }' | sort -u) \
		| shuf -n 1
}

# --- Run Generation ---
# Resolve the master port before launching. A failed or empty lookup then
# stops the run with a clear message, instead of handing DeepSpeed an empty
# --master_port. (The lookup runs inside `if`, so `set -e` does not cut it short.)
if ! MASTER_PORT="$(find_free_port)" || [[ -z "${MASTER_PORT}" ]]; then
	echo "Error: could not determine a free master port in [${MASTER_PORT_START}, ${MASTER_PORT_END}]." >&2
	exit 1
fi

echo "--- Generating responses for model: ${MODEL_NAME_OR_PATH} ---"
echo "--- Output will be saved to: ${OUTPUT_FILE} ---"

# DATASET_ARG defaults to PKU-SafeRLHF-30K/test unless --dataset_path was given.
deepspeed "${DEEPSPEED_ARGS[@]}" --master_port "${MASTER_PORT}" \
	--module safe_rlhf.evaluate.generate \
	--model_name_or_path "${MODEL_NAME_OR_PATH}" \
	--datasets "${DATASET_ARG}" \
	--per_device_batch_size 8 \
	--trust_remote_code True \
	--output_file "${OUTPUT_FILE}" \
	--zero_stage "${ZERO_STAGE}" \
	--offload "${OFFLOAD}"

echo "--- Responses saved to ${OUTPUT_FILE} ---"
echo "Generation complete."