#!/bin/bash -l
#SBATCH --output=scripts/logs/mask_topk_recall_2.out
#SBATCH -w rlab7
#SBATCH -G 4

# Environment setup for the benchmark runs.
# `conda init` only edits shell rc files; it does not make `conda activate`
# work in the current non-interactive shell. Load conda's bash hook instead.
eval "$(conda shell.bash hook)"
conda activate pqcache
# Exported so the python3 child processes launched below actually see it
# (a bare assignment is only inherited if the variable was already exported).
# NOTE(review): this overrides any GPU list Slurm assigned for `-G 4`;
# confirm the allocation maps to devices 0-3 on rlab7.
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Ignore packages from the user site-packages directory so only the
# activated conda env's packages are used.
export PYTHONNOUSERSITE=1

# Run LongBench with mask_topk_recall sparse attention over a fixed dataset
# list and a grid of recall targets.
MODEL_KEY=llama-3.1
EXP_NAME=masktopk_recall_grid

# Datasets to iterate over. Other datasets kept here for reference (add them
# to the array to include them in the sweep):
#   multifieldqa_en multifieldqa_zh trec lcc gov_report narrativeqa
DATASETS=(
  lsht
)

# Base top-k ratio (0.05 = top 5%).
BASE_RATIO=0.05

# Recall targets; values may be written as 0-1 or 0-100 (the runner accepts
# both). Other grid values noted for this sweep: 70 75 80 85 90
RECALL_LIST=(
  95
)

# Abort rather than launching run_benchmark.py from the wrong directory.
cd /filer/tmp1/WIRED/sampling/long_context_eval || {
  echo "ERROR: cannot cd to /filer/tmp1/WIRED/sampling/long_context_eval" >&2
  exit 1
}

echo "Running mask_topk_recall on: ${DATASETS[*]} with base_ratio: ${BASE_RATIO} and recalls: ${RECALL_LIST[*]} (model=${MODEL_KEY}, exp=${EXP_NAME})"

# Recall values are forwarded as separate argv words after --recall_list.
RECALL_ARGS=("${RECALL_LIST[@]}")

# Run one benchmark invocation per dataset. A failure does not stop the
# remaining datasets, but it is recorded so the job's exit status reflects it.
FAILED_DATASETS=()
for ds in "${DATASETS[@]}"; do
  echo "[mask_topk_recall] dataset=${ds}"
  if ! python3 run_benchmark.py \
      --model_key "${MODEL_KEY}" \
      --exp_name "${EXP_NAME}" \
      --datasets "${ds}" \
      --attention_strategy mask_topk_recall \
      --base_ratio "${BASE_RATIO}" \
      --recall_list "${RECALL_ARGS[@]}"; then
    echo "[mask_topk_recall] dataset=${ds} FAILED" >&2
    FAILED_DATASETS+=("${ds}")
  fi
done

# Report failures to stderr and exit non-zero so Slurm marks the job failed.
if (( ${#FAILED_DATASETS[@]} > 0 )); then
  echo "mask_topk_recall runs finished with failures: ${FAILED_DATASETS[*]}" >&2
  exit 1
fi
echo "mask_topk_recall runs completed."


