#!/bin/bash -l
#SBATCH --output=scripts/logs/mask_topk_recall_3.out
#SBATCH -G 4

# Environment setup for the job: activate the conda env and pin the GPUs/Python path policy.
#
# `conda init` only edits shell startup files; it does not make `conda activate`
# usable in the shell that is already running. Load the conda shell hook into
# this shell instead so activation works regardless of the user's rc files.
if command -v conda >/dev/null 2>&1; then
  eval "$(conda shell.bash hook)"
  conda activate pqcache \
    || echo "WARNING: could not activate conda env 'pqcache'; using the current Python environment" >&2
else
  echo "WARNING: conda not found; using the current Python environment" >&2
fi

# Must be exported: a plain assignment stays local to this shell and is not
# inherited by child processes unless the variable was already exported.
# NOTE(review): this overrides any GPU list the scheduler may have set — confirm intended.
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Keep the per-user site-packages directory off sys.path.
export PYTHONNOUSERSITE=1

# LongBench sweep configuration: mask_topk_recall sparse attention, run over a
# fixed list of datasets and a grid of recall values.
EXP_NAME="masktopk_recall_grid"
MODEL_KEY="llama-3.1"

# Datasets to iterate over (one benchmark invocation per entry).
# Currently disabled — move a name into the array to re-enable it:
#   multifieldqa_en multifieldqa_zh trec lcc gov_report narrativeqa
DATASETS=(repobench-p vcsum samsum triviaqa hotpotqa)

# Number of samples requested per run.
SAMPLE_NUM="50"
# Base top-k ratio (e.g., 0.05 = top 5%).
BASE_RATIO="0.05"

# Recall grid; values can be given as 0-1 or 0-100 (runner accepts both).
RECALL_LIST=(
  70 75 80
  85 90 95
)

# Switch to the evaluation directory. run_benchmark.py is invoked by relative
# path later in this script, so abort instead of continuing from the wrong
# directory when the path is missing or inaccessible.
cd /filer/tmp1/WIRED/sampling/long_context_eval || {
  echo "ERROR: cannot cd to /filer/tmp1/WIRED/sampling/long_context_eval" >&2
  exit 1
}

# Log the full sweep configuration once, up front.
echo "Running mask_topk_recall on: ${DATASETS[*]} with base_ratio: ${BASE_RATIO} and recalls: ${RECALL_LIST[*]} (model=${MODEL_KEY}, exp=${EXP_NAME})"

# Recall values to pass on the command line: an element-for-element copy of
# RECALL_LIST (empty when RECALL_LIST is empty or unset).
RECALL_ARGS=("${RECALL_LIST[@]}")

# Run the benchmark once per dataset, passing the whole recall grid to each run.
# A failing run does not stop the sweep; failed datasets are collected so the
# job can report them and exit non-zero instead of silently succeeding.
failed_datasets=()
for ds in "${DATASETS[@]}"; do
  echo "[mask_topk_recall] dataset=${ds}"
  # "${RECALL_ARGS[@]}" expands to one argument per recall value without
  # depending on IFS word-splitting.
  if ! python3 run_benchmark.py \
      --model_key "${MODEL_KEY}" \
      --exp_name "${EXP_NAME}" \
      --datasets "${ds}" \
      --attention_strategy mask_topk_recall \
      --base_ratio "${BASE_RATIO}" \
      --recall_list "${RECALL_ARGS[@]}" \
      --num_samples "${SAMPLE_NUM}"; then
    echo "[mask_topk_recall] run FAILED for dataset=${ds}" >&2
    failed_datasets+=("${ds}")
  fi
done

echo "mask_topk_recall runs completed."

# Surface failures through the exit status so the scheduler does not record a
# partially failed sweep as successful.
if (( ${#failed_datasets[@]} > 0 )); then
  echo "mask_topk_recall: ${#failed_datasets[@]} dataset(s) failed: ${failed_datasets[*]}" >&2
  exit 1
fi


