#!/usr/bin/env bash
#
# Quantization x sparsity sweep.
#
# Prepares a baseline copy of --model (named with a "-w16" suffix), then for
# every (bit-width, sparsity) pair applies AWQ quantization and
# ../sparsity/magnitude.py in the order chosen by --order, evaluating each
# resulting model on wikitext. Artifacts are written under ../cache/baseline/.
#
# The script works in the sibling directories ../model, ../llm-awq and
# ../sparsity, so it must be launched from a directory next to them
# (presumably scripts/ — NOTE(review): confirm).
set -o errexit -o nounset -o pipefail

# AWQ quantization group size used for every run (passed as --q_group_size).
GROUP=128

#######################################
# Print the command-line help and exit.
# Arguments:
#   $1 - optional exit status (default 1). With status 0 (an explicit help
#        request) the text goes to stdout; any other status is an error
#        path, so the text goes to stderr.
# Outputs:
#   The usage text on stdout or stderr (see above).
# Returns:
#   Never returns; always exits with the requested status.
#######################################
usage() {
    local status="${1:-1}"
    local out_fd=2
    if [[ "${status}" == 0 ]]; then
        out_fd=1
    fi

    cat >&"${out_fd}" << EOF
Usage: $0 [OPTIONS]
    Options:
        --model <model_name_or_path>
            The base model to start with (e.g. facebook/opt-125m)

        --order <quant-first|sparsity-first>
            Whether to quantize first or prune first.

        --widths "<w1 w2 …>"
            Space-separated list of bit-widths (e.g. "8 4").

        --sparsities "<s1 s2 …>"
            Space-separated list of sparsity percentages (e.g. "10 25").
EOF
    exit "${status}"
}

#######################################
# Quantize a model with AWQ, save a fake-quantized copy, and evaluate it.
# Globals:
#   ORIG_MODEL (read)        - model id whose tokenizer is saved next to the
#                              new model
#   MODEL, MODEL_LOC (written) - set to the new model's name/location so the
#                              caller can chain the next step onto it
# Arguments:
#   $1 - weight bit-width (--w_bit)
#   $2 - quantization group size (--q_group_size)
#   $3 - location of the model to quantize (relative paths are resolved from
#        a sibling directory such as ../llm-awq)
#   $4 - name of the new model; outputs go to
#        ../cache/baseline/models/<name> and ../cache/baseline/pt/<name>.pt
# Notes:
#   Each step runs in a subshell, so the caller's working directory is left
#   unchanged. Relies on the script's `set -e` to abort on a failing step.
#######################################
apply_quantization() {
    local WIDTH="$1"
    local GROUP="$2"
    local CURRENT_MODEL_LOC="$3"
    local NEW_MODEL_NAME="$4"
    local NEW_MODEL_LOC="../cache/baseline/models/${NEW_MODEL_NAME}"
    local AWQ_RESULTS="../cache/baseline/pt/${NEW_MODEL_NAME}.pt"
    local -a MEM_ARGS=(--max_memory 0:80GiB cpu:30GiB)

    (
        cd ../llm-awq || exit
        # quantize model (using AWQ) and dump the AWQ results
        python -m awq.entry \
            --model_path "${CURRENT_MODEL_LOC}" \
            --w_bit "${WIDTH}" \
            --q_group_size "${GROUP}" \
            --run_awq \
            --dump_awq "${AWQ_RESULTS}" \
            "${MEM_ARGS[@]}"

        # load the AWQ results with the fake backend, evaluate on wikitext,
        # and dump the fake-quantized model to NEW_MODEL_LOC
        python -m awq.entry \
            --model_path "${CURRENT_MODEL_LOC}" \
            --tasks wikitext \
            --w_bit "${WIDTH}" \
            --q_group_size "${GROUP}" \
            --load_awq "${AWQ_RESULTS}" \
            --q_backend fake \
            --dump_fake "${NEW_MODEL_LOC}" \
            "${MEM_ARGS[@]}"
    )

    (
        cd ../model || exit
        # save the original model's tokenizer alongside the dumped weights
        python tokenizer.py --model_id "${ORIG_MODEL}" --save_dir "${NEW_MODEL_LOC}"
    )

    MODEL="${NEW_MODEL_NAME}"
    MODEL_LOC="${NEW_MODEL_LOC}"

    (
        cd ../llm-awq || exit
        # evaluate the saved model as loaded back from disk
        python -m awq.entry --model_path "${MODEL_LOC}" --tasks wikitext "${MEM_ARGS[@]}"
    )
}

#######################################
# Sparsify a model with ../sparsity/magnitude.py and evaluate the result.
# Globals:
#   MODEL, MODEL_LOC (written) - set to the new model's name/location so the
#                              caller can chain the next step onto it
# Arguments:
#   $1 - sparsity fraction passed to --sparsity_fraction (the driver passes
#        a percentage divided by 100)
#   $2 - location of the model to sparsify (relative paths are resolved from
#        a sibling directory such as ../sparsity)
#   $3 - name of the new model; it is written to
#        ../cache/baseline/models/<name>
# Notes:
#   Each step runs in a subshell, so the caller's working directory is left
#   unchanged. Relies on the script's `set -e` to abort on a failing step.
#######################################
apply_sparsity() {
    local SP_FRAC="$1"
    local CURRENT_MODEL_LOC="$2"
    local NEW_MODEL_NAME="$3"
    local NEW_MODEL_LOC="../cache/baseline/models/${NEW_MODEL_NAME}"
    local -a MEM_ARGS=(--max_memory 0:80GiB cpu:30GiB)

    (
        cd ../sparsity || exit
        # sparsify model; --save_tokenizer asks the tool to store a tokenizer
        # next to the output weights
        python magnitude.py \
            --w "${CURRENT_MODEL_LOC}" \
            --sparsity_fraction "${SP_FRAC}" \
            --output_dir "${NEW_MODEL_LOC}" \
            --save_tokenizer
    )

    MODEL="${NEW_MODEL_NAME}"
    MODEL_LOC="${NEW_MODEL_LOC}"

    (
        cd ../llm-awq || exit
        # evaluate model
        python -m awq.entry --model_path "${MODEL_LOC}" --tasks wikitext "${MEM_ARGS[@]}"
    )
}

# ---------------------------------------------------------------------------
# Defaults
# ---------------------------------------------------------------------------
ORIG_MODEL="facebook/opt-125m"
ORDER="quant-first"
WIDTHS=()
SPARSITIES=()

# Accepted formats: bit-widths are positive integers; sparsities are
# non-negative decimal percentages such as 10, 12.5 or .5.
WIDTH_RE='^[1-9][0-9]*$'
SPARSITY_RE='^([0-9]+([.][0-9]*)?|[.][0-9]+)$'

# die MESSAGE... : print "Error: MESSAGE" to stderr and exit 1.
die() {
    printf 'Error: %s\n' "$*" >&2
    exit 1
}

# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
if [[ $# -eq 0 ]]; then
    usage
fi
while [[ $# -gt 0 ]]; do
    opt="$1"
    case "${opt}" in
        -h|--help)
            usage 0;;
        --model|--order|--widths|--sparsities)
            # Check for the value explicitly; otherwise `set -u` aborts on
            # "$2" with an opaque "unbound variable" message.
            [[ $# -ge 2 ]] || die "option '${opt}' requires a value."
            case "${opt}" in
                --model)      ORIG_MODEL="$2";;
                --order)      ORDER="$2";;
                --widths)     read -r -a WIDTHS <<< "$2";;
                --sparsities) read -r -a SPARSITIES <<< "$2";;
            esac
            shift 2;;
        *)
            printf "Error: Unknown option '%s'\n" "${opt}" >&2
            usage;;
    esac
done

# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------
[[ ${#WIDTHS[@]} -gt 0 ]] || die "--widths is required."
[[ ${#SPARSITIES[@]} -gt 0 ]] || die "--sparsities is required."

# Reject unknown orders instead of silently falling through to
# sparsity-first.
case "${ORDER}" in
    quant-first|sparsity-first) ;;
    *) die "--order must be 'quant-first' or 'sparsity-first' (got '${ORDER}').";;
esac

for w in "${WIDTHS[@]}"; do
    [[ "${w}" =~ $WIDTH_RE ]] \
        || die "invalid bit-width '${w}' (expected a positive integer)."
done
# bc treats non-numeric input as a variable (value 0), so validate up front.
for s in "${SPARSITIES[@]}"; do
    [[ "${s}" =~ $SPARSITY_RE ]] \
        || die "invalid sparsity '${s}' (expected a percentage, e.g. 10 or 12.5)."
done

# ---------------------------------------------------------------------------
# Baseline: save the original model plus tokenizer and evaluate it.
# Each step runs in a subshell so the working directory stays where the
# script was launched (a sibling of model/, llm-awq/, sparsity/).
# ---------------------------------------------------------------------------
INITIAL_WIDTH=16
BASE_MODEL="${ORIG_MODEL}-w${INITIAL_WIDTH}"
BASE_MODEL_LOC="../cache/baseline/models/${BASE_MODEL}"

(
    cd ../model || exit
    python model.py --model_id "${ORIG_MODEL}" --save_dir "${BASE_MODEL_LOC}"
    python tokenizer.py --model_id "${ORIG_MODEL}" --save_dir "${BASE_MODEL_LOC}"
)

(
    cd ../llm-awq || exit
    # evaluate baseline model
    python -m awq.entry --model_path "${BASE_MODEL_LOC}" --tasks wikitext --max_memory 0:80GiB cpu:30GiB
)

# ---------------------------------------------------------------------------
# Sweep: every (width, sparsity) pair starts again from the baseline.
# apply_quantization / apply_sparsity update the globals MODEL and MODEL_LOC
# to point at the model they just produced. The second call's arguments are
# expanded after the first call returns, so the second step is applied to
# the first step's output.
# ---------------------------------------------------------------------------
for WIDTH in "${WIDTHS[@]}"; do
    for S in "${SPARSITIES[@]}"; do
        # percentage -> fraction (e.g. 10 -> .1000...)
        SP_FRAC=$(bc -l <<< "${S}/100")

        MODEL="${BASE_MODEL}"
        MODEL_LOC="${BASE_MODEL_LOC}"

        printf '%s\n' "Running for MODEL=${ORIG_MODEL}, WIDTH=${WIDTH}, GROUP=${GROUP}, SPARSITY=${S}%"
        if [[ "${ORDER}" == "quant-first" ]]; then
            apply_quantization "${WIDTH}" "${GROUP}" "${MODEL_LOC}" "${MODEL}-w${WIDTH}-g${GROUP}"
            apply_sparsity "${SP_FRAC}" "${MODEL_LOC}" "${MODEL}-s${S}"
        else
            apply_sparsity "${SP_FRAC}" "${MODEL_LOC}" "${MODEL}-s${S}"
            apply_quantization "${WIDTH}" "${GROUP}" "${MODEL_LOC}" "${MODEL}-w${WIDTH}-g${GROUP}"
        fi
    done
done
