#!/bin/bash
# Drives quick_extend/scripts/ppl.py (presumably perplexity evaluation) for
# dense attention, HiP plug-and-play, and a HiP fine-tuned checkpoint.
set -x

# Environment shared by every ppl.py invocation below.
export CUDA_LAUNCH_BLOCKING=1 PYTHONPATH=. HIP_DISABLE_AUTOTUNE=1

# HiP options, one per line; expanded as "${HIP_PARAMS[@]}" for the
# plug-and-play and fine-tuned runs.
HIP_PARAMS=(
  --dense-layers 3
  --hip-top-k-elems 512
  --start-sink-tokens 16
  --end-sink-tokens 256
  --hip-block-size-q 64
  --hip-block-skip-q 2
  --use-quantization
)

# Alternative pairing (instruct model + its checkpoint), kept for reference:
#BASE_MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
#FT_CHECKPOINT="./saves/checkpoints/meta-llama_Meta-Llama-3.1-8B-Instruct-default-hip_orig_mask-rp-32768-512-1/checkpoint-200"
BASE_MODEL="meta-llama/Meta-Llama-3.1-8B"
FT_CHECKPOINT="../hip-ft-ckpts/meta-llama_Meta-Llama-3.1-8B-default-hip_orig_mask-rp-32768-512-1/checkpoint-400"

# Sweep evaluation strides 8192 -> 131072 (doubling each step).
# Environment controls:
#   STRIDES - optional space-separated override of the stride list,
#             e.g. STRIDES="4096 8192"; unset/empty keeps the default.
#   ONLY_FT - if non-empty, skip the dense and HiP plug-and-play baselines.
#   SKIP_FT - if non-empty, skip the HiP fine-tuned checkpoint run.
# A failing ppl.py run does not abort the sweep; it is recorded, listed on
# stderr at the end, and makes the script exit 1 (previously the script
# usually exited 0 even when runs had crashed).
read -r -a strides <<< "${STRIDES:-8192 16384 32768 65536 131072}"
failed_runs=()

for STRIDE in "${strides[@]}"; do
  echo "========================================"
  echo "STRIDE: ${STRIDE}"

  if [ -z "${ONLY_FT}" ]; then
    echo "DENSE ATTENTION"
    python quick_extend/scripts/ppl.py --model "${BASE_MODEL}" --disable-rope-scaling --overwrite --stride "${STRIDE}" --model-parallel --disable-hip ||
      failed_runs+=("DENSE ATTENTION @ stride ${STRIDE}")

    echo "STRIDE: ${STRIDE}"
    echo "HiP PLUG AND PLAY"
    python quick_extend/scripts/ppl.py --model "${BASE_MODEL}" --disable-rope-scaling "${HIP_PARAMS[@]}" --overwrite --stride "${STRIDE}" --model-parallel ||
      failed_runs+=("HiP PLUG AND PLAY @ stride ${STRIDE}")
  fi

  # Run the fine-tuned model unless SKIP_FT is set to a non-empty value.
  if [ -z "${SKIP_FT}" ]; then
    echo "STRIDE: ${STRIDE}"
    echo "HIP FINE-TUNED"
    python quick_extend/scripts/ppl.py --model "${BASE_MODEL}" --init-from-checkpoint "${FT_CHECKPOINT}" --disable-rope-scaling "${HIP_PARAMS[@]}" --overwrite --stride "${STRIDE}" --model-parallel ||
      failed_runs+=("HIP FINE-TUNED @ stride ${STRIDE}")
    printf "\n\n\n\n"
  fi
done

# Surface failures instead of letting the sweep look successful.
if (( ${#failed_runs[@]} > 0 )); then
  printf 'FAILED RUN: %s\n' "${failed_runs[@]}" >&2
  exit 1
fi
