#!/bin/bash

# =====================
# default values
# =====================
# --- base model / dataset ---
model='meta-llama/Llama-2-7b-hf'   # model name
model_abbr='Llama-2-7b'            # model abbreviation (part of the residual-model directory names)
data='pissa-dataset'               # data path
task='metamath'                    # task name

# --- adapter configuration ---
peft='LoRA'                        # PEFT method; PiSSA, MiLoRA and InitAB select a residual base model
rank=128                           # LoRA rank

# --- inference settings ---
bs=400                             # inference batch size
gpus='0,1,2,3'                     # comma-separated GPU ids

# --- paths ---
output_home='./'                   # parent directory of the "output" folder, default is "./"
output_path=''                     # path of the trained adapter (REQUIRED)
temp_path=''                       # path for saving the temporary merged model (REQUIRED)

# --- misc ---
timestamp=''                       # suffix of the InitAB residual-model directory name

# =====================
# help function
# =====================
#######################################
# Print the usage text to stdout.
# Globals (read): model, model_abbr, data, task, peft, rank, bs, output_home,
#                 gpus -- their current values are shown as the defaults.
# Arguments: none
# Outputs:   usage text on stdout
#######################################
show_help() {
  # Unquoted delimiter: $0 and the default values are expanded in the text.
  cat << EOF
Usage: $0 [options]

Options:
  --model NAME       Model name (default: $model)
  --model_abbr NAME  Model abbreviation (default: $model_abbr)
  --data PATH        Data path (default: $data)
  --task NAME        Task name, e.g., metamath, python, conversation (default: "$task")
  --peft NAME        PEFT method: LoRA, DoRA, PiSSA, MiLoRA, InitAB (default: "$peft")
  --rank INT         LoRA rank (default: $rank)
  --bs INT           Inference batch size (default: $bs)
  --output_home PATH Reroute the whole "output" folder to other place (default: $output_home)
  --output_path PATH Output path for trained adapter (REQUIRED)
  --temp_path PATH   Path for saving temporary merged model (REQUIRED)
  --gpus LIST        Comma-separated GPU ids (default: $gpus)
  --timestamp STR    Timestamp suffix of the InitAB residual model directory
                     (only used with --peft InitAB)
  --help             Show this help message and exit
EOF
}

# =====================
# parse args
# =====================
# Hand the raw command line to getopt. `-o ''` declares no short options;
# `--long` and `-n` need the enhanced (util-linux) getopt.
if ! TEMP=$(getopt -o '' \
  --long model:,model_abbr:,data:,task:,peft:,rank:,bs:,output_home:,output_path:,temp_path:,gpus:,timestamp:,help \
  -n "$0" -- "$@"); then
  echo "Error parsing options." >&2
  exit 1
fi

# Replace the positional parameters with getopt's normalised, quoted output.
eval set -- "$TEMP"
while :; do
  case "$1" in
    --model | --model_abbr | --data | --task | --peft | --rank | --bs | \
    --output_home | --output_path | --temp_path | --gpus | --timestamp)
      # Each value option is named after the variable it sets, so drop the
      # leading "--" and assign the value verbatim via printf -v.
      printf -v "${1#--}" '%s' "$2"
      shift 2
      ;;
    --help)
      show_help
      exit 0
      ;;
    --)
      # End-of-options marker emitted by getopt.
      shift
      break
      ;;
    *)
      echo "Internal error!"
      exit 1
      ;;
  esac
done

# =====================
# validations
# =====================

# output_path is required and must be an existing directory.
if [[ -z "$output_path" ]]; then
  echo "Error: --output_path is required." >&2; exit 2
fi
if [[ ! -d "$output_path" ]]; then
  echo "Error: output_path does not exist: $output_path" >&2; exit 1
fi
# temp_path is documented as REQUIRED and is always forwarded to the
# generation script, so reject an empty value here instead of passing "" on.
if [[ -z "$temp_path" ]]; then
  echo "Error: --temp_path is required." >&2; exit 2
fi
# # perf.json guard
# if [[ -f "$output_path/perf.json" ]]; then
#   echo "Error: perf.json already exists in $output_path. Skipping execution." >&2
#   exit 1
# fi

# GPU count = number of comma-separated fields in --gpus.
# An empty list would give 0 fields and make the modulo below divide by zero
# (bash reports the arithmetic error but the script would keep running).
if [[ -z "$gpus" ]]; then
  echo "Error: --gpus must list at least one GPU id." >&2; exit 2
fi
world_size=$(awk -F',' '{print NF}' <<< "$gpus")
# NOTE(review): the head count 32 is hard-coded; presumably it matches the
# default model -- confirm before using a --model with a different head count.
mod=$(( 32 % world_size ))
if (( mod != 0 )); then
  echo "Error: total attention heads (32) is not divisible by num_GPUs ($world_size)" >&2
  exit 1
fi

#######################################
# Reduce a raw task spec to its base task name.
# Drops everything from the first ':' onwards, then one trailing "-ep<digits>"
# suffix if present. Pure bash: no echo/sed pipeline, so values that look like
# echo options (e.g. "-n", "-e") pass through unchanged.
# Arguments: $1 - raw task name
# Outputs:   cleaned task name on stdout
#######################################
clean_task_name() {
  local name="${1%%:*}"
  if [[ "$name" =~ ^(.*)-ep[0-9]+$ ]]; then
    name="${BASH_REMATCH[1]}"
  fi
  printf '%s\n' "$name"
}

# Get cleaned task name!
echo "original task name: $task"
task_base="${task%%:*}"
task=$(clean_task_name "$task")
echo "cleaned task name: $task"

## Print the effective settings as one banner ##
# Unquoted heredoc delimiter so the variables are expanded in place.
cat <<EOF
===========Testing==============
Using output_path: $output_path
Using GPUs: $gpus (num_GPUs=$world_size)
Using PEFT method: $peft
Base model abbr: $model_abbr
Task: $task, Rank: $rank
Batch size = $bs
================================
EOF

# =====================
# start testing
# =====================
resp_file="$output_path/${task}_response.jsonl"

#######################################
# Print the value passed as --model to utils/gen_vllm_lora.py.
# PiSSA, MiLoRA and InitAB use a residual model directory under
# "<output_home>/output", canonicalised with `readlink -m` (the path does not
# have to exist). Any other method (LoRA and DoRA) uses $model unchanged.
# Globals (read): peft, output_home, model_abbr, rank, timestamp, model
# Outputs:        model path or name on stdout
#######################################
resolve_gen_model() {
  case "$peft" in
    PiSSA | MiLoRA)
      readlink -m "${output_home}/output/${peft}-${model_abbr}-r${rank}"
      ;;
    InitAB)
      # InitAB directories additionally carry the run timestamp.
      readlink -m "${output_home}/output/InitAB-${model_abbr}-r${rank}-${timestamp}"
      ;;
    *) # LoRA and DoRA
      printf '%s\n' "$model"
      ;;
  esac
}

gen_model=$(resolve_gen_model)

# One invocation for every PEFT method; only --model differs between them.
python3 utils/gen_vllm_lora.py \
  --model "$gen_model" \
  --lora "$output_path/adapter_model" \
  --sub_task "$task" \
  --batch_size "$bs" \
  --output_file "$resp_file" \
  --gpus "$gpus" \
  --temp_path "$temp_path"

#######################################
# Run evalplus on one dataset and print the second field of every "pass@1:"
# line of its combined stdout/stderr, one value per line.
# Arguments: $1 - dataset name, $2 - samples file
#######################################
evalplus_pass_at_1() {
  evalplus.evaluate --dataset "$1" --samples "$2" 2>&1 \
    | awk '/pass@1:/ {print $2}'
}

#######################################
# Write the four scores to a perf.json file.
# Refuses to write (and returns 1) when any score is missing, so a failed
# evalplus run can no longer leave an invalid file such as `"mbpp": [, ]`.
# Arguments: $1 - output file
#            $2 $3 - first and second humaneval pass@1 values
#            $4 $5 - first and second mbpp pass@1 values
# Returns:   0 on success, 1 if a score is missing, else cat's status
#######################################
write_perf_json() {
  local out_file=$1 score
  shift
  if (( $# != 4 )); then
    echo "Error: expected 4 scores, got $#." >&2
    return 1
  fi
  for score in "$@"; do
    if [[ -z "$score" ]]; then
      echo "Error: missing pass@1 score; not writing $out_file" >&2
      return 1
    fi
  done
  cat > "$out_file" <<EOF
{
  "humaneval": [$1, $2],
  "mbpp": [$3, $4]
}
EOF
}

if [[ "$task" == "python" ]]; then
  python3 utils/code_process.py --path "$resp_file"

  # Humaneval: first value -> base, second value -> extra
  readarray -t scores < <(evalplus_pass_at_1 humaneval "$output_path/humaneval.jsonl")
  humaneval_base=${scores[0]:-}
  humaneval_extra=${scores[1]:-}

  # MBPP: same layout as above
  readarray -t scores < <(evalplus_pass_at_1 mbpp "$output_path/mbpp.jsonl")
  mbpp_base=${scores[0]:-}
  mbpp_extra=${scores[1]:-}

  # Write the parsed scores to `perf.json`; abort instead of reporting success
  # when any of them could not be parsed.
  write_perf_json "$output_path/perf.json" \
    "$humaneval_base" "$humaneval_extra" "$mbpp_base" "$mbpp_extra" || exit 1

  echo "✅ Scores saved to $output_path/perf.json"

else
  python3 utils/test_acc.py \
    --input_file "$resp_file"
fi