#!/bin/bash
# For each comma-separated file in DATA_FILES (under DATA_INPUT_DIR), run
# generation with a "strict" model and a "template" model, then build a
# side-by-side report (see the loop at the end of this script).
#
# Usage:
#   $0 MODEL_DIR MODEL_NAME_STRICT MODEL_NAME_TEMPLATE MODEL_TYPE
#      DATA_INPUT_DIR DATA_OUTPUT_DIR DATA_FILES
#      [DEVICE] [BATCH_SIZE] [DTYPE] [DO_SAMPLE]
#
#   DEVICE      passed as "--device DEVICE" to the Python tools (omitted if empty)
#   BATCH_SIZE  defaults to default_batch_size (8)
#   DTYPE       "fp16" or "bf16" selects --fp16 / --bf16; anything else: no flag
#   DO_SAMPLE   0 adds --no_sample; defaults to 1
#
# NOTE: the script changes into its own directory and then into its parent
# before running anything, so relative MODEL_DIR / DATA_*_DIR paths are
# resolved from there, not from the caller's working directory.
cd "$(dirname -- "$0")" || exit 1

default_model_max_length=3000
default_batch_size=8

usage() {
  echo "Usage: $0 MODEL_DIR MODEL_NAME_STRICT MODEL_NAME_TEMPLATE MODEL_TYPE DATA_INPUT_DIR" \
       " DATA_OUTPUT_DIR DATA_FILES [DEVICE] [BATCH_SIZE] [DTYPE] [DO_SAMPLE]" >&2
  exit 1
}

# 7 required positional arguments plus up to 4 optional trailing ones.
if (( $# < 7 || $# > 11 )); then
  usage
fi

# Optional arguments. An empty value falls back to the default rather than
# producing a dangling flag (e.g. "--device" with no value).
device_arg=""
if [[ -n "${8:-}" ]]; then
  device_arg="--device $8"
fi
# Bug fix: the original assigned the literal string "default_batch_size"
# (missing `$`) when BATCH_SIZE was not given.
batch_size=${9:-${default_batch_size}}
dtype=${10:-}
do_sample=${11:-1}

model_dir=$1
model_strict=$2
model_template=$3
model_type=$4
data_input_dir=$5
data_output_dir=$6
data_files=$7

# Move from the script directory to its parent (the project root); src/,
# reports/ and any relative data/model paths are resolved from here.
cd .. || exit 1
# Put the project root on PYTHONPATH without leaving a trailing empty entry
# (":") when PYTHONPATH was previously unset or empty.
export PYTHONPATH=".${PYTHONPATH:+:${PYTHONPATH}}"
SRC_PATH=src

# Each of these may be overridden from the environment.
REPORT_DIR="${REPORT_DIR:-reports}"
TEST_SET="${TEST_SET:-test}"
MODEL_MAX_LEN="${MODEL_MAX_LEN:-${default_model_max_length}}"

# DO_SAMPLE == 0 adds --no_sample to the generator command line
# (presumably disabling sampling — confirm against scrape_hf_llm.py).
if [ "${do_sample}" -eq 0 ]; then
  gen_arg="--no_sample"
else
  gen_arg=""
fi

# Map the DTYPE argument to the generator's precision flag; any other value
# (including empty) passes no flag.
case "${dtype}" in
  fp16) dtype_arg="--fp16" ;;
  bf16) dtype_arg="--bf16" ;;
  *) dtype_arg="" ;;
esac

mkdir -p -- "${REPORT_DIR}" || exit 1

# DATA_FILES is a comma-separated list of file names under DATA_INPUT_DIR.
IFS=',' read -ra files <<< "${data_files}"

# device_arg / gen_arg / dtype_arg are plain strings that are either empty or
# hold "--flag [value]". Split them once into arrays (read -ra does no glob
# expansion and drops empty strings) so they are passed as separate words
# while every other expansion below stays quoted.
read -ra gen_extra_args <<< "${device_arg} ${gen_arg} ${dtype_arg}"
read -ra sbs_extra_args <<< "${device_arg}"

#######################################
# Run the full pipeline for one data file: strict generation, template
# generation + conversion to the .api format, and the side-by-side report.
# Stops at the first failing step.
# Globals (read): SRC_PATH, REPORT_DIR, TEST_SET, MODEL_MAX_LEN, model_dir,
#   model_strict, model_template, model_type, data_input_dir,
#   data_output_dir, batch_size, gen_extra_args, sbs_extra_args
# Arguments: $1 - data file name, relative to data_input_dir
#            $2 - numeric suffix used in the report/summary file names
# Returns: 0 on success, else the exit status of the first failing step
#######################################
process_data_file() {
  local file=$1
  local suffix=$2
  local input_file="${data_input_dir}/${file}"
  # scrape_hf_llm.py output for each model, as later consumed below.
  local strict_pred="${data_output_dir}/${file}.${model_strict}"
  local template_pred="${data_output_dir}/${file}.${model_template}"
  local report_tag="${TEST_SET}${suffix}-${model_strict}_vs_${model_template}"

  echo "===Strict api generation==="
  python3 "${SRC_PATH}/scrape_hf_llm.py" --input_files "${input_file}" \
    --model_max_length "${MODEL_MAX_LEN}" \
    --output_dir "${data_output_dir}" --output_file_suffix "${model_strict}" \
    --model_name "${model_dir}/${model_strict}" \
    --batch_size "${batch_size}" --chat_format "${model_type}" \
    "${gen_extra_args[@]}" || return
  echo "===Strict api generation done==="

  echo "===Template based api generation==="
  python3 "${SRC_PATH}/scrape_hf_llm.py" --input_files "${input_file}" \
    --model_max_length "${MODEL_MAX_LEN}" \
    --output_dir "${data_output_dir}" --output_file_suffix "${model_template}" \
    --prompt_function get_prompt_for_template_summarize \
    --model_name "${model_dir}/${model_template}" \
    --batch_size "${batch_size}" --chat_format "${model_type}" \
    "${gen_extra_args[@]}" || return
  # Convert the template output to <file>.<template>.api, which is what the
  # report compares against the strict predictions.
  python3 "${SRC_PATH}/template_to_json.py" --input_file "${template_pred}" \
    --output_file "${template_pred}.api" || return
  echo "===Template based api generation done==="

  echo "===Generating report==="
  python3 "${SRC_PATH}/gen_sbs.py" --truth_file "${input_file}" \
    --prediction_file_control "${strict_pred}" \
    --prediction_file_treatment "${template_pred}.api" \
    --output_report "${REPORT_DIR}/report-${report_tag}.tsv" \
    --output_summary "${REPORT_DIR}/summary-${report_tag}.tsv" \
    "${sbs_extra_args[@]}" || return
  echo "===Report done==="
}

# suffix is the file's position in DATA_FILES; it is advanced even when a
# file fails so report names stay aligned with the input list.
suffix=0
failed=0
for file in "${files[@]}"; do
  if ! process_data_file "${file}" "${suffix}"; then
    echo "ERROR: pipeline failed for '${file}'; skipping its remaining steps" >&2
    failed=$((failed + 1))
  fi
  suffix=$((suffix + 1))
done

# Surface partial failures through the exit status instead of always
# reporting success.
if (( failed > 0 )); then
  echo "ERROR: ${failed} of ${#files[@]} data file(s) failed" >&2
  exit 1
fi
