#!/bin/bash
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Turn off HuggingFace `tokenizers` parallelism for every Python child process
# launched below (exported, so the children inherit it).
export TOKENIZERS_PARALLELISM=false

# Directory containing this script; the data/pred/eval helper scripts are
# resolved relative to it. "$0" is quoted so a checkout path containing
# spaces is not word-split into several dirname operands.
RULER_PATH=$(dirname -- "$0")

# Fetch the NLTK 'punkt' / 'punkt_tab' tokenizer resources (presumably needed
# by the data-preparation scripts — confirm). The script does not use `set -e`,
# so a failed download does not stop the run.
python -c "import nltk; nltk.download('punkt')"
python -c "import nltk; nltk.download('punkt_tab')"

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

# Restrict the run to physical GPU index 5.
export CUDA_VISIBLE_DEVICES=5
# NOTE(review): GPUS is not referenced anywhere else in this script.
GPUS="1"

# Root directory under which one results folder per configuration is created.
# "PATH" is a placeholder — replace it with a real directory before running.
ROOT_DIR="PATH"

# Attention implementation, forwarded to call_api.py as --type.
# Known values: 'full', 'minference', 'flexprefill', 'xattention', 'auxhead'.
# To sweep all of them, wrap the benchmark loop in:
#   for TYPE in "full" "minference" "flexprefill" "xattention" "auxhead"; do
TYPE=auxhead
# Forwarded to call_api.py as --gamma / --stride; their meaning is defined by
# the attention implementation selected via --type.
GAMMA=0.95
STRIDE=8

# Samples generated per task and per context length
# (use e.g. NUM_SAMPLES=25 for a quicker run).
NUM_SAMPLES=100
# Context lengths, each passed to prepare.py as --max_seq_length.
SEQ_LENGTHS=(
    4096
    8192
    16384
    32768
    65536
    131072
)

# Synthetic RULER tasks to generate, predict and score at every length.
TASKS=(
    "niah_single_1"
    "niah_single_2"
    "niah_single_3"
    "niah_multikey_1"
    "niah_multikey_2"
    "niah_multikey_3"
    "niah_multivalue"
    "niah_multiquery"
    "vt"
    "cwe"
    "fwe"
    "qa_1"
    "qa_2"
)

# Experiment setup: sampling parameters forwarded to call_api.py
# (temperature 0.0 — presumably greedy decoding there; confirm).
TEMPERATURE="0.0"
TOP_P="1.0"
TOP_K="32"

# The model. MODEL_NAME is used both as the tokenizer path for prepare.py and
# as --model_name_or_path for call_api.py. "PATH" is a placeholder.
MODEL_NAME="PATH"
BENCHMARK="synthetic"
MODEL_TEMPLATE_TYPE="llama-3"
MODEL_FRAMEWORK="hf"

# Last path component of MODEL_NAME, used to name the results directories.
# Quoted: unquoted, a path with a space would be split and basename would
# treat the second word as a suffix to strip, producing a wrong name.
MODEL_NAME_FOR_PATH=$(basename -- "${MODEL_NAME}")

# For each context length: generate data for every task, run prediction on it,
# then score all predictions produced for that length.
for MAX_SEQ_LENGTH in "${SEQ_LENGTHS[@]}"; do

    # One results folder per (model, attention type, gamma, stride, length).
    # RESULTS_DIR="${ROOT_DIR}/${MODEL_NAME_FOR_PATH}_${TYPE}/${BENCHMARK}/${MAX_SEQ_LENGTH}"
    RESULTS_DIR="${ROOT_DIR}/${MODEL_NAME_FOR_PATH}_${TYPE}_${GAMMA}_${STRIDE}_${MAX_SEQ_LENGTH}"
    DATA_DIR="${RESULTS_DIR}/data"
    PRED_DIR="${RESULTS_DIR}/pred"
    BUDGET_PATH="${RESULTS_DIR}/${TYPE}_${GAMMA}_slide${STRIDE}.txt"
    # Attention flags for call_api.py, kept as an array rather than a flat
    # string so that a BUDGET_PATH containing spaces stays a single argument.
    ATTN_CONFIG=(
        --type "${TYPE}"
        --gamma "${GAMMA}"
        --stride "${STRIDE}"
        --save_path "${BUDGET_PATH}"
    )

    if ! mkdir -p -- "${DATA_DIR}" "${PRED_DIR}"; then
        echo "error: cannot create ${RESULTS_DIR}; skipping length ${MAX_SEQ_LENGTH}" >&2
        continue
    fi

    for TASK in "${TASKS[@]}"; do
        # REMOVE_NEWLINE_TAB, MINFERENCE_PARAMS, EXTRA_PARAMS and STOP_WORDS are
        # not assigned above; they are intentionally expanded unquoted so an
        # unset/empty value disappears and a multi-flag value supplied through
        # the environment splits into separate arguments.
        # shellcheck disable=SC2086
        if ! python "${RULER_PATH}/data/prepare.py" \
            --save_dir "${DATA_DIR}" \
            --benchmark "${BENCHMARK}" \
            --task "${TASK}" \
            --tokenizer_path "${MODEL_NAME}" \
            --tokenizer_type "hf" \
            --max_seq_length "${MAX_SEQ_LENGTH}" \
            --model_template_type "${MODEL_TEMPLATE_TYPE}" \
            --num_samples "${NUM_SAMPLES}" \
            ${REMOVE_NEWLINE_TAB}; then
            # Don't run prediction on missing or stale data; keep going with
            # the remaining tasks.
            echo "warning: data preparation failed for ${TASK} @ ${MAX_SEQ_LENGTH}; skipping prediction" >&2
            continue
        fi

        # shellcheck disable=SC2086
        python "${RULER_PATH}/pred/call_api.py" \
            --data_dir "${DATA_DIR}" \
            --save_dir "${PRED_DIR}" \
            --benchmark "${BENCHMARK}" \
            --task "${TASK}" \
            --server_type "${MODEL_FRAMEWORK}" \
            --model_name_or_path "${MODEL_NAME}" \
            --temperature "${TEMPERATURE}" \
            --top_k "${TOP_K}" \
            --top_p "${TOP_P}" \
            "${ATTN_CONFIG[@]}" \
            ${MINFERENCE_PARAMS} \
            ${EXTRA_PARAMS} \
            ${STOP_WORDS}
    done

    # Score every prediction written to PRED_DIR for this length.
    python "${RULER_PATH}/eval/evaluate.py" \
        --data_dir "${PRED_DIR}" \
        --benchmark "${BENCHMARK}"
done
