#!/usr/bin/env bash
#
# Copyright 2023-2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# =====================================================================
# Sample script
# bash scripts/arena-evaluation.sh \
# --red_corner_model_name_or_path ./output/primal_dual_dpo_3_42_10 \
# --blue_corner_model_name_or_path ./output/primal_dual_dpo_3_42_8 \
# --reward_model_name_or_path PKU-Alignment/beaver-7b-unified-reward \
# --cost_model_name_or_path PKU-Alignment/beaver-7b-unified-cost \
# --output_dir ./results/arena-evaluation-test-0
# =====================================================================

# The rest of this script relies on bash-only features (arrays, [[ ]], process
# substitution). This guard is written with POSIX-only constructs so that it
# still works when the file is run by a different shell.
test -n "${BASH_VERSION}" || {
	echo "Please use bash to run this script." >&2
	exit 1
}

# Echo every command as it is executed.
set -o xtrace

# SCRIPT_DIR: absolute path of the directory that contains this script.
# ROOT_DIR:   its parent directory.
SCRIPT_DIR="$(
	cd "$(dirname "$0")" &>/dev/null && pwd
)"
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"

# Put ROOT_DIR in front of PYTHONPATH, keeping any non-empty existing value.
if [[ -n "${PYTHONPATH}" ]]; then
	PYTHONPATH="${ROOT_DIR}:${PYTHONPATH}"
else
	PYTHONPATH="${ROOT_DIR}"
fi
export PYTHONPATH

# Default settings; each of them can be overridden by a command-line option.

# HOSTFILE is deliberately left unset: --hostfile is forwarded to deepspeed
# only when the variable has been set.
unset HOSTFILE

# Results are written here unless --output_dir is given.
OUTPUT_DIR="${ROOT_DIR}/output/arena-evaluation-default"

# Models used to score the generations.
COST_MODEL_NAME_OR_PATH='PKU-Alignment/beaver-7b-v1.0-cost'
REWARD_MODEL_NAME_OR_PATH='PKU-Alignment/beaver-7b-v1.0-reward'

# The two models to compare. They have no default value.
BLUE_CORNER_MODEL_NAME_OR_PATH=''
RED_CORNER_MODEL_NAME_OR_PATH=''

#######################################
# Parse the command-line options into the script's configuration variables.
# Every option takes a value and accepts both "--name value" and
# "--name=value"; in the second form the value is everything after the
# first '='.
# Globals (written):
#   RED_CORNER_MODEL_NAME_OR_PATH, BLUE_CORNER_MODEL_NAME_OR_PATH,
#   REWARD_MODEL_NAME_OR_PATH, COST_MODEL_NAME_OR_PATH, OUTPUT_DIR,
#   HOSTFILE, ZERO_STAGE, OFFLOAD
# Arguments: the script's command-line arguments ("$@")
# Outputs:   an error message on stderr for an unknown option or for an
#            option given as the last argument without a value
# Returns:   0 on success; exits the script with status 1 on error
#######################################
parse_args() {
	local arg target value
	while [[ "$#" -gt 0 ]]; do
		arg="$1"
		shift

		# Map the option name (the text before the first '=') to its variable.
		case "${arg%%=*}" in
			--red_corner_model_name_or_path) target="RED_CORNER_MODEL_NAME_OR_PATH" ;;
			--blue_corner_model_name_or_path) target="BLUE_CORNER_MODEL_NAME_OR_PATH" ;;
			--reward_model_name_or_path) target="REWARD_MODEL_NAME_OR_PATH" ;;
			--cost_model_name_or_path) target="COST_MODEL_NAME_OR_PATH" ;;
			--output_dir) target="OUTPUT_DIR" ;;
			--hostfile) target="HOSTFILE" ;;
			# NOTE(review): ZERO_STAGE and OFFLOAD are stored but not read
			# anywhere else in this script; the options are kept so that
			# existing invocations keep working.
			--zero_stage) target="ZERO_STAGE" ;;
			--offload) target="OFFLOAD" ;;
			*)
				echo "Unknown parameter passed: '${arg}'" >&2
				exit 1
				;;
		esac

		if [[ "${arg}" == *=* ]]; then
			value="${arg#*=}"
		elif [[ "$#" -gt 0 ]]; then
			value="$1"
			shift
		else
			# Without this check a trailing option would silently be
			# assigned an empty value.
			echo "Missing value for parameter: '${arg}'" >&2
			exit 1
		fi

		# `target` is not declared local, so this assigns the global variable.
		printf -v "${target}" '%s' "${value}"
	done
}

parse_args "$@"

# Create the output directory and resolve it to an absolute path.
# Abort when either step fails: OUTPUT_DIR would otherwise become empty and
# the files written below (and the log files created later) would be placed
# in the filesystem root.
if ! mkdir -p -- "${OUTPUT_DIR}"; then
	echo "Failed to create output directory: '${OUTPUT_DIR}'" >&2
	exit 1
fi
if ! OUTPUT_DIR_ABS="$(cd -- "${OUTPUT_DIR}" &>/dev/null && pwd)"; then
	echo "Failed to access output directory: '${OUTPUT_DIR}'" >&2
	exit 1
fi
OUTPUT_DIR="${OUTPUT_DIR_ABS}"
unset OUTPUT_DIR_ABS

# Unless one already exists, add a .gitignore that ignores everything in the
# output directory.
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
	echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Keep a copy of this script next to the results (best effort).
cp -f -- "$0" "${OUTPUT_DIR}/script.sh"

# Pick a random TCP port for the DeepSpeed master from the candidate range,
# leaving out every local port that `ss` currently reports.
# Field 4 of `ss -Htan` is the local "address:port"; only the part after the
# last ':' is kept. Both lists are sorted the same way, as `comm` requires.
MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
	seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort |
		comm -23 - <(
			ss -Htan | awk '{ n = split($4, piece, ":"); print piece[n] }' | sort -u
		) |
		shuf -n 1
)"

# Options for the deepspeed launcher itself, placed before --module.
DEEPSPEED_ARGS=()
# ${HOSTFILE+x} expands to "x" whenever HOSTFILE is set, even to an empty
# string, so --hostfile is forwarded exactly when the variable is set.
[[ -z "${HOSTFILE+x}" ]] || DEEPSPEED_ARGS+=(--hostfile "${HOSTFILE}")
DEEPSPEED_ARGS+=(--master_port "${MASTER_PORT}")


# Both corner models default to the empty string, so refuse to start the
# launcher when either one is still empty instead of handing an empty model
# path to the evaluation module.
if [[ -z "${RED_CORNER_MODEL_NAME_OR_PATH}" || -z "${BLUE_CORNER_MODEL_NAME_OR_PATH}" ]]; then
	echo "Both --red_corner_model_name_or_path and --blue_corner_model_name_or_path must be specified." >&2
	exit 1
fi

# From here on, copy stdout and stderr into log files in the output directory
# while still writing them to the original streams.
exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)

# Run the arena evaluation module through the deepspeed launcher on the
# PKU-SafeRLHF-30K/test dataset with an evaluation batch size of 8 per device.
# As the last command, its exit status becomes the exit status of the script.
deepspeed "${DEEPSPEED_ARGS[@]}" \
	--module safe_rlhf.evaluate.arena \
	--datasets PKU-SafeRLHF-30K/test \
	--per_device_eval_batch_size 8 \
	--red_corner_model_name_or_path "${RED_CORNER_MODEL_NAME_OR_PATH}" \
	--blue_corner_model_name_or_path "${BLUE_CORNER_MODEL_NAME_OR_PATH}" \
	--reward_model_name_or_path "${REWARD_MODEL_NAME_OR_PATH}" \
	--cost_model_name_or_path "${COST_MODEL_NAME_OR_PATH}" \
	--trust_remote_code True \
	--output_dir "${OUTPUT_DIR}"
