#!/usr/bin/env bash
# Driver for the MME benchmark: runs ./eval/mme_llava16.py once per configured
# run and scores each result with ./eval/eval_mme.py.
#
# Usage: <this script> [MODEL_PATH] [MODEL_TAG]
#
# Environment overrides (all optional):
#   CUDA_VISIBLE_DEVICES  GPUs to expose (default "0,1,2")
#   GT_DIR                directory holding the MME question files
#   IMG_ROOT              directory holding the MME images
#   DEPTH_ROOT            directory passed as --depth-root for DSCR runs
#   OUT_ROOT              directory receiving generated answers
set -euo pipefail

export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2}"

# Resolve the directory four levels above this script in its own assignment:
# `export VAR="$(cmd)"` always returns the status of `export`, which would hide
# a failing `cd` from `set -e` and export a bogus PYTHONPATH.
SCRIPT_BASE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../../.." && pwd)"
export PYTHONPATH="${SCRIPT_BASE_DIR}/experiments/transformers-4.31.0/src:${PYTHONPATH:-}"

# Positional arguments: checkpoint to evaluate and the tag used in run names.
MODEL_PATH="${1:-liuhaotian/llava-v1.6-vicuna-13b}"
MODEL_TAG="${2:-llava16}"

# Dataset locations. The defaults are placeholders; edit them or export the
# variable of the same name before launching the script.
GT_DIR="${GT_DIR:-/path/to/MME/mme_json}"
IMG_ROOT="${IMG_ROOT:-/path/to/MME/mme}"
DEPTH_ROOT="${DEPTH_ROOT:-/path/to/MME/mme_depth}"
OUT_ROOT="${OUT_ROOT:-./output/MME}"
mkdir -p "${OUT_ROOT}"

# Entry points, relative to the directory the script is launched from.
EVAL_PY="./eval/mme_llava16.py"
MME_CALC_PY="./eval/eval_mme.py"

# ---- Settings shared by every run --------------------------------------
# Every value in this section is forwarded to the evaluator on each run,
# regardless of which method is selected.
SEED="42"
MAX_NEW_TOKENS="2"
CONV_MODE="llava_v1"
FORMAT="no_format"

# TEMPERATURE is sent for greedy runs; runs tagged "sample" send
# SAMPLE_TEMPERATURE instead. An empty TOP_K omits --top-k altogether.
TEMPERATURE="0.0"
SAMPLE_TEMPERATURE="1.0"
TOP_P="1.0"
TOP_K=''

# ---- VCD (--noise-step, --cd-alpha, --cd-beta) ---------------------------
NOISE_STEP="500"
CD_ALPHA="1.5"
CD_BETA="0.1"

# ---- OPERA (--beam, --scale-factor, --threshold, --num-attn-candidates,
#             --penalty-weights) -------------------------------------------
BEAM="5"
SCALE_FACTOR="50.0"
THRESHOLD="15"
NUM_ATTN_CANDIDATES="5"
PENALTY_WEIGHTS="1.0"

# ---- DAMO (--beta-1, --beta-2, --tau, --alpha, --damo-start-layer) -------
DAMO_BETA1="0.05"
DAMO_BETA2="0.2"
DAMO_TAU="-0.30"
DAMO_ALPHA="0.7"
DAMO_START_LAYER="16"

# ---- AGLA (--agla-alpha, --agla-beta) -------------------------------------
AGLA_ALPHA="2.0"
AGLA_BETA="0.5"

# ---- HALC (--halc-* flags) ------------------------------------------------
# HALC_MATURE_LAYER and HALC_CANDIDATE_LAYERS are only forwarded when
# non-empty; HALC_BEAM_SEARCH adds --halc-beam-search only when it is "true".
HALC_DETECTOR="dino"
HALC_CONTEXT_WINDOW="4"
HALC_EXPAND_RATIO="0.15"
HALC_CONTEXT_DOMAIN="upper"
HALC_CONTRAST_WEIGHT="0.05"
HALC_SCORE_TYPE="BLIP"
HALC_DEBUGGER="0"
HALC_MATURE_LAYER=''
HALC_BASE_LAYER="0"
HALC_CANDIDATE_LAYERS=''
HALC_RELATIVE_TOP="0.1"
HALC_BEAM_SEARCH="false"
HALC_NUM_BEAMS="1"

# ---- DSCR (only forwarded for runs whose mode is "dscr") ------------------
# The three *_ONLY / KEY_VALUE switches each add their flag when set to "true".
# NOTE(review): DSCR_LAMBDA is not forwarded to the evaluator by any flag in
# this script -- confirm whether it is still needed.
DSCR_ALPHA="0.6"
DSCR_BETA="0.8"
DSCR_SIGMA="0.6"
DSCR_KEEP_RATIO="1.0"
DSCR_LAMBDA="1.0"
DSCR_START="0"
DSCR_END="30"
DSCR_KEY_ONLY="true"
DSCR_VALUE_ONLY="false"
DSCR_KEY_VALUE="false"

# Run specifications, "method:mode[:decoding]". The two reference runs come
# first, then the same method list is emitted once without DSCR and once
# with it. A missing decoding field is treated as "greedy" by the main loop.
RUNS=("baseline:nodscr" "dscr:dscr")
for _runs_mode in nodscr dscr; do
  RUNS+=(
    "vcd:${_runs_mode}:greedy"
    "vcd:${_runs_mode}:sample"
    "opera:${_runs_mode}"
    "halc:${_runs_mode}"
    "damo:${_runs_mode}"
    "agla:${_runs_mode}:greedy"
    "agla:${_runs_mode}:sample"
  )
done
unset _runs_mode

#######################################
# Print the question file for one MME subset.
# Globals:   GT_DIR (read)
# Arguments: $1 - subset name (file name without extension)
# Outputs:   path of "<subset>.jsonl" if it exists, else "<subset>.json" if
#            it exists, else an empty line
#######################################
pick_qfile () {
  local subset="$1"
  local ext
  # .jsonl is probed first, so it wins when both files are present.
  for ext in jsonl json; do
    if [[ -f "${GT_DIR}/${subset}.${ext}" ]]; then
      printf '%s\n' "${GT_DIR}/${subset}.${ext}"
      return
    fi
  done
  printf '\n'
}

#######################################
# Print the image directory to use for one MME subset.
# Globals:   IMG_ROOT (read)
# Arguments: $1 - subset name
# Outputs:   the first existing directory among the candidates below, or
#            IMG_ROOT itself when none of them exists
#######################################
pick_img_dir () {
  local subset="$1"
  local -a candidates=()
  local name

  # The "posters" subset additionally accepts a capitalised directory name.
  if [[ "${subset}" == "posters" ]]; then
    candidates+=("posters" "Posters")
  fi
  # Then the name as given, then its all-lowercase form.
  candidates+=("${subset}" "${subset,,}")

  for name in "${candidates[@]}"; do
    if [[ -d "${IMG_ROOT}/${name}" ]]; then
      printf '%s\n' "${IMG_ROOT}/${name}"
      return
    fi
  done
  printf '%s\n' "${IMG_ROOT}"
}

#######################################
# Score the answers of one finished run with the MME calculator script.
# Scoring is best-effort: a failing scorer never aborts the caller, but the
# failure is reported on stderr instead of being dropped silently.
# Globals:   MME_CALC_PY, GT_DIR, OUT_ROOT, SEED (read)
# Arguments: $1 - run name, passed to the scorer as --file-name
# Outputs:   banner and scorer output on stdout; a warning on stderr when
#            the scorer exits non-zero
# Returns:   always 0
#######################################
score_one () {
  local prefix="$1"
  local status=0
  echo "=============================================="
  echo "[EVAL] file-name: ${prefix}"
  echo "=============================================="
  # `|| status=$?` captures the exit code without tripping `set -e`.
  python "${MME_CALC_PY}" --gt_dir "${GT_DIR}" --gen_dir "${OUT_ROOT}" --seed "${SEED}" --file-name "${prefix}" || status=$?
  if (( status != 0 )); then
    echo "[WARN] scoring failed for ${prefix} (exit code ${status}); continuing" >&2
  fi
  return 0
}

#######################################
# Run the evaluator for one configuration, then score its output.
#
# The command line is assembled as an array so that every value reaches the
# evaluator as exactly one argument, even when it contains whitespace
# (e.g. DEPTH_ROOT). All method hyperparameters are passed on every run,
# whichever method is selected.
#
# Globals (read): EVAL_PY, MODEL_PATH, MODEL_TAG, GT_DIR, IMG_ROOT, OUT_ROOT,
#   DEPTH_ROOT, SETS, SEED and the hyperparameter variables defined at file
#   level (TEMPERATURE ... DSCR_KEY_VALUE).
# Arguments:
#   $1 - method name; "dscr" is sent to the evaluator as "--method baseline"
#   $2 - mode; "dscr" adds the DSCR flags, any other value leaves them out
#   $3 - decoding; "sample" adds --do-sample and uses SAMPLE_TEMPERATURE,
#        anything else (default "greedy") uses TEMPERATURE
# Returns: exit status of score_one; a failing evaluator propagates as a
#          non-zero command (fatal under `set -e`).
#######################################
run_llava16 () {
  local method="$1"
  local mode="$2"
  local dosample="${3:-greedy}"

  # Run name: <tag>_<method>[+dscr][_sample]_seed<seed>. A pure DSCR run is
  # labelled just "dscr", not "dscr+dscr".
  local label="${method}"
  if [[ "${mode}" == "dscr" && "${method}" != "dscr" ]]; then
    label="${method}+dscr"
  fi

  local sample_suffix=""
  local temp="${TEMPERATURE}"
  if [[ "${dosample}" == "sample" ]]; then
    sample_suffix="_sample"
    temp="${SAMPLE_TEMPERATURE}"
  fi
  local prefix="${MODEL_TAG}_${label}${sample_suffix}_seed${SEED}"

  # A pure DSCR run is the baseline decoder plus the DSCR flags.
  local method_name="${method}"
  if [[ "${method}" == "dscr" ]]; then
    method_name="baseline"
  fi

  # Per-run copies of the tunables that may be overridden below.
  local cd_alpha="${CD_ALPHA}"
  local cd_beta="${CD_BETA}"
  local scale_factor="${SCALE_FACTOR}"
  local threshold="${THRESHOLD}"
  local penalty_weights="${PENALTY_WEIGHTS}"
  local damo_tau="${DAMO_TAU}"
  local damo_beta1="${DAMO_BETA1}"
  local damo_beta2="${DAMO_BETA2}"
  local damo_alpha="${DAMO_ALPHA}"
  local agla_alpha="${AGLA_ALPHA}"
  local agla_beta="${AGLA_BETA}"
  local halc_contrast_weight="${HALC_CONTRAST_WEIGHT}"

  # Values used when a method is combined with DSCR. They are literals so
  # they can be tuned independently of the standalone defaults.
  if [[ "${mode}" == "dscr" ]]; then
    case "${method}" in
      vcd)
        cd_alpha=1.5
        cd_beta=0.1
        ;;
      opera)
        scale_factor=50.0
        threshold=15
        penalty_weights=1.0
        ;;
      halc)
        halc_contrast_weight=0.05
        ;;
      damo)
        damo_tau=-0.3
        damo_beta1=0.05
        damo_beta2=0.2
        damo_alpha=0.7
        ;;
      agla)
        agla_alpha=2.0
        agla_beta=0.5
        ;;
    esac
  fi

  local -a cmd=(
    python "${EVAL_PY}"
    --model-path "${MODEL_PATH}"
    --gt-dir "${GT_DIR}"
    --image-root "${IMG_ROOT}"
    --out-root "${OUT_ROOT}"
    --datasets "${SETS[@]}"
    --method "${method_name}"
    --conv-mode "${CONV_MODE}"
    --format "${FORMAT}"
    --seed "${SEED}"
    --max-new-tokens "${MAX_NEW_TOKENS}"
    --temperature "${temp}"
    --top-p "${TOP_P}"
  )
  # --top-k is only sent when a value is configured.
  if [[ -n "${TOP_K}" ]]; then
    cmd+=(--top-k "${TOP_K}")
  fi
  cmd+=(
    --noise-step "${NOISE_STEP}"
    --cd-alpha "${cd_alpha}"
    --cd-beta "${cd_beta}"
    --beam "${BEAM}"
    --scale-factor "${scale_factor}"
    --threshold "${threshold}"
    --num-attn-candidates "${NUM_ATTN_CANDIDATES}"
    --penalty-weights "${penalty_weights}"
    --tau "${damo_tau}"
    --beta-1 "${damo_beta1}"
    --beta-2 "${damo_beta2}"
    --alpha "${damo_alpha}"
    --damo-start-layer "${DAMO_START_LAYER}"
    --agla-alpha "${agla_alpha}"
    --agla-beta "${agla_beta}"
    --halc-detector "${HALC_DETECTOR}"
    --halc-context-window "${HALC_CONTEXT_WINDOW}"
    --halc-expand-ratio "${HALC_EXPAND_RATIO}"
    --halc-context-domain "${HALC_CONTEXT_DOMAIN}"
    --halc-contrast-weight "${halc_contrast_weight}"
    --halc-score-type "${HALC_SCORE_TYPE}"
    --halc-debugger "${HALC_DEBUGGER}"
    --halc-base-layer "${HALC_BASE_LAYER}"
    --halc-relative-top "${HALC_RELATIVE_TOP}"
    --halc-num-beams "${HALC_NUM_BEAMS}"
    --run-name "${prefix}"
  )

  # Optional HALC flags.
  if [[ -n "${HALC_MATURE_LAYER}" ]]; then
    cmd+=(--halc-mature-layer "${HALC_MATURE_LAYER}")
  fi
  if [[ -n "${HALC_CANDIDATE_LAYERS}" ]]; then
    # A whitespace-separated list is deliberately split into one argument
    # per layer, as the previous unquoted expansion did.
    local -a candidate_layers=()
    read -r -a candidate_layers <<< "${HALC_CANDIDATE_LAYERS}"
    cmd+=(--halc-candidate-layers ${candidate_layers[@]+"${candidate_layers[@]}"})
  fi
  if [[ "${HALC_BEAM_SEARCH}" == "true" ]]; then
    cmd+=(--halc-beam-search)
  fi

  if [[ "${dosample}" == "sample" ]]; then
    cmd+=(--do-sample)
  fi

  # DSCR flags.
  # NOTE(review): DSCR_LAMBDA is defined at file level but no flag forwards
  # it (it was previously copied into an unused local) -- confirm whether
  # the evaluator expects it.
  if [[ "${mode}" == "dscr" ]]; then
    cmd+=(
      --use-dscr
      --depth-root "${DEPTH_ROOT}"
      --dscr-alpha "${DSCR_ALPHA}"
      --dscr-beta "${DSCR_BETA}"
      --dscr-sigma "${DSCR_SIGMA}"
      --dscr-keep-ratio "${DSCR_KEEP_RATIO}"
      --dscr-start-layer "${DSCR_START}"
      --dscr-end-layer "${DSCR_END}"
    )
    if [[ "${DSCR_KEY_ONLY}" == "true" ]]; then cmd+=(--dscr-key-only); fi
    if [[ "${DSCR_VALUE_ONLY}" == "true" ]]; then cmd+=(--dscr-value-only); fi
    if [[ "${DSCR_KEY_VALUE}" == "true" ]]; then cmd+=(--dscr-key-value); fi
  fi

  "${cmd[@]}"

  score_one "${prefix}"
}

#######################################
# Fill the global SETS array with the MME subset names found in GT_DIR.
# A subset name is the file name of a *.json or *.jsonl question file with
# that one extension removed. Names are kept in discovery order (all *.json
# matches, then *.jsonl matches) and each name is listed once even when
# both file types exist for it.
# Globals:   GT_DIR (read), SETS (written)
# Outputs:   a warning on stderr when no question file is found
# Returns:   0
#######################################
collect_mme_sets () {
  SETS=()
  local path name known is_dup
  for path in "${GT_DIR}"/*.json "${GT_DIR}"/*.jsonl; do
    # An unmatched glob is left as a literal pattern; skip it.
    [[ -f "${path}" ]] || continue
    name="${path##*/}"
    # Strip exactly one trailing extension without spawning basename/sed
    # (the sed `\?` operator is a GNU extension).
    case "${name}" in
      *.jsonl) name="${name%.jsonl}" ;;
      *.json) name="${name%.json}" ;;
    esac
    is_dup=false
    # The `+` form keeps the expansion safe for an empty array under
    # `set -u` on bash versions before 4.4.
    for known in ${SETS[@]+"${SETS[@]}"}; do
      if [[ "${known}" == "${name}" ]]; then
        is_dup=true
        break
      fi
    done
    if [[ "${is_dup}" == "false" ]]; then
      SETS+=("${name}")
    fi
  done
  if (( ${#SETS[@]} == 0 )); then
    echo "[WARN] no *.json or *.jsonl question files found in '${GT_DIR}'" >&2
  fi
  return 0
}

collect_mme_sets

# Execute every entry of RUNS in order. Each entry has the form
# "method:mode[:decoding]"; a missing decoding field falls back to "greedy".
for (( run_idx = 0; run_idx < ${#RUNS[@]}; run_idx++ )); do
  run="${RUNS[run_idx]}"
  IFS=":" read -r method mode dosample <<< "${run}"
  run_llava16 "${method}" "${mode}" "${dosample:-greedy}"
done
