#!/usr/bin/env bash
set -euo pipefail

#######################################
# Print the command-line help text and terminate the script.
# Globals:   none
# Arguments: none
# Outputs:   writes the help text to stdout
# Returns:   never returns; always exits with status 1
#######################################
usage() {
  # "$0" is interpolated so the help shows the name the script was invoked as.
  local -a help_lines=(
    "Usage: $0 [OPTIONS]"
    '    Options:'
    '        --model <model_name_or_path>'
    '            The base model to start with (e.g. facebook/opt-125m)'
    ''
    '        --widths "<w1 w2 …>"'
    '            Space-separated list of bit-widths (e.g. "8 4").'
    ''
    '        --sparsities "<s1 s2 …>"'
    '            Space-separated list of sparsity percentages (e.g. "10 25").'
  )
  printf '%s\n' "${help_lines[@]}"
  exit 1
}

# Defaults. GROUP is passed to the AWQ tool as --q_group_size further down;
# no command-line option overrides it.
GROUP=128
MODEL=""
WIDTHS=()
SPARSITIES=()

# Parse arguments. Every value-taking option is checked for a following word
# first: under `set -u` a bare trailing option would otherwise abort on "$2"
# with an "unbound variable" error instead of showing the usage text.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --model | --widths | --sparsities)
      if [[ $# -lt 2 ]]; then
        echo "Error: option '$1' requires a value" >&2
        usage
      fi
      case "$1" in
        --model)
          MODEL="$2"
          ;;
        --widths)
          # Split the single quoted argument on whitespace into an array.
          read -r -a WIDTHS <<< "$2"
          ;;
        --sparsities)
          read -r -a SPARSITIES <<< "$2"
          ;;
      esac
      shift 2
      ;;
    -h | --help)
      usage
      ;;
    *)
      echo "Error: unknown option '$1'" >&2
      usage
      ;;
  esac
done

# Validate: all three options are mandatory.
if [[ -z "$MODEL" || ${#WIDTHS[@]} -eq 0 || ${#SPARSITIES[@]} -eq 0 ]]; then
  usage
fi

# Fail fast on malformed list entries (e.g. "8,4") so a typo is reported
# before any of the long-running pipeline stages start.
width_re='^[1-9][0-9]*$'
sparsity_re='^([0-9]+([.][0-9]*)?|[.][0-9]+)$'
for width_arg in "${WIDTHS[@]}"; do
  if [[ ! "$width_arg" =~ $width_re ]]; then
    echo "Error: invalid bit-width '$width_arg' (expected a positive integer)" >&2
    usage
  fi
done
for sparsity_arg in "${SPARSITIES[@]}"; do
  if [[ ! "$sparsity_arg" =~ $sparsity_re ]]; then
    echo "Error: invalid sparsity '$sparsity_arg' (expected a non-negative number)" >&2
    usage
  fi
done

# Resolve the project root from this script's location rather than from the
# caller's working directory, so the script can be launched from anywhere.
# NOTE(review): assumes this script lives in a directory that is a sibling of
# model/, llm-awq/ and sparsity/ (previously it had to be run from such a
# directory, named "scripts") — confirm.
ROOT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)" || {
  echo "Error: cannot resolve project root directory" >&2
  exit 1
}

#######################################
# Run a command from inside one of the project's top-level directories.
# The directory change happens in a subshell, so the caller's working
# directory is never modified.
# Globals:   ROOT_DIR (read)
# Arguments: $1 - directory name relative to ROOT_DIR; $2… - command to run
# Returns:   the command's exit status, or non-zero if the cd fails
#######################################
run_in() {
  local subdir=$1
  shift
  (cd -- "${ROOT_DIR}/${subdir}" && "$@")
}

# All cache paths are deliberately kept relative ("../cache/…"): every tool is
# run from a directory directly below ROOT_DIR, so they all resolve to the
# same ROOT_DIR/cache tree.
MODEL_LOC="../cache/qas/models/${MODEL}-w16"

run_in model python model.py --model_id "$MODEL" --save_dir "$MODEL_LOC"
run_in model python tokenizer.py --model_id "$MODEL" --save_dir "$MODEL_LOC"

for WIDTH in "${WIDTHS[@]}"; do
  AWQ_PT="../cache/qas/pt/${MODEL}-w${WIDTH}-g${GROUP}.pt"
  QUANT_LOC="../cache/qas/models/${MODEL}-w${WIDTH}-g${GROUP}"

  # The quantization stages depend only on WIDTH and GROUP, so they run once
  # per width instead of being repeated for every sparsity value.
  echo "Quantizing MODEL=${MODEL}, WIDTH=${WIDTH}, GROUP=${GROUP}"

  # quantize model (using AWQ)
  run_in llm-awq python -m awq.entry \
    --model_path "${MODEL_LOC}" \
    --w_bit "$WIDTH" \
    --q_group_size "$GROUP" \
    --run_awq \
    --dump_awq "${AWQ_PT}"

  # evaluate quantized model; --dump_fake writes it to QUANT_LOC
  run_in llm-awq python -m awq.entry \
    --model_path "${MODEL_LOC}" \
    --tasks wikitext \
    --w_bit "$WIDTH" \
    --q_group_size "$GROUP" \
    --load_awq "${AWQ_PT}" \
    --q_backend fake \
    --dump_fake "${QUANT_LOC}"

  # save the tokenizer next to the dumped quantized model
  run_in model python tokenizer.py \
    --model_id "${MODEL}" \
    --save_dir "${QUANT_LOC}"

  for S in "${SPARSITIES[@]}"; do
    echo "Running for MODEL=${MODEL}, WIDTH=${WIDTH}, GROUP=${GROUP}, SPARSITY=${S}%"

    SPARSE_LOC="../cache/qas/models/${MODEL}-w16-w${WIDTH}-g${GROUP}-s${S}"

    # Convert the percentage into a fraction. Guard against an empty result:
    # set -e does not help if bc reports a parse error but still exits 0.
    SP_FRAC=$(bc -l <<< "${S}/100")
    if [[ -z "$SP_FRAC" ]]; then
      echo "Error: could not convert sparsity '${S}' to a fraction" >&2
      exit 1
    fi

    # sparsify model
    run_in sparsity python qas.py \
      --w "${MODEL_LOC}" \
      --w_hat "${QUANT_LOC}" \
      --sparsity_fraction "$SP_FRAC" \
      --output_dir "${SPARSE_LOC}" \
      --save_tokenizer

    # evaluate model
    run_in llm-awq python -m awq.entry \
      --model_path "${SPARSE_LOC}" \
      --tasks wikitext
  done
done
