#!/bin/bash -l
#SBATCH --output=scripts/logs/mask_topk.out
#SBATCH -w rlab7
#SBATCH -G 4

### Use -w rlab7 only when an rlab7 node is available.
# --- Runtime environment ---------------------------------------------------
# Load conda's shell integration into THIS shell. `conda init` only edits the
# user's shell startup files and does not make `conda activate` usable in the
# already-running shell; the shell hook does.
eval "$(conda shell.bash hook)"
if ! conda activate pqcache; then
  # Keep going (as the script always did), but make the problem visible in the log.
  echo "warning: could not activate conda env 'pqcache'; continuing with the current environment" >&2
fi

# Export so that child processes actually inherit the value; a plain assignment
# is only inherited when the variable was already exported by the caller.
# NOTE(review): SLURM may already set this for the GPU allocation -- confirm
# that overriding it here is intended.
export CUDA_VISIBLE_DEVICES=0,1,2,3
export PYTHONNOUSERSITE=1

# LongBench run using mask_topk sparse attention: every dataset below is
# evaluated against the full grid of sparsity ratios.

# Sparsity ratios forming the grid.
SPARSITY_LIST=(
  0.001
  0.005
  0.01
  0.05
  0.1
  0.5
)

# Datasets to run, one benchmark invocation each.
DATASETS=(multi_news lsht)

# Model selector, experiment label, and per-dataset sample count.
MODEL_KEY="llama-3.1"
EXP_NAME="masktopk_grid"
SAMPLE_NUM="50"

# Move into the evaluation directory; run_benchmark.py is invoked by relative
# path later, so abort instead of running from the wrong place if this fails.
cd /filer/tmp1/WIRED/sampling/long_context_eval || {
  echo "error: cannot cd to /filer/tmp1/WIRED/sampling/long_context_eval" >&2
  exit 1
}

# One-line summary of the run configuration for the job log.
echo "Running mask_topk on: ${DATASETS[*]} with ratios: ${SPARSITY_LIST[*]} (model=${MODEL_KEY}, exp=${EXP_NAME})"

# Build sparsity list args: a direct copy of the ratio grid, one element per ratio.
SP_ARGS=("${SPARSITY_LIST[@]}")

# Datasets whose benchmark run exited non-zero. Every dataset is still
# attempted; failures are reported together once the loop is done.
FAILED_DATASETS=()

for ds in "${DATASETS[@]}"; do
  echo "[mask_topk] dataset=${ds}"
  # "${SP_ARGS[@]}" passes each ratio as its own argument without relying on
  # IFS word-splitting or risking glob expansion.
  if ! python3 run_benchmark.py \
      --model_key "${MODEL_KEY}" \
      --exp_name "${EXP_NAME}" \
      --datasets "${ds}" \
      --attention_strategy mask_topk \
      --sparsity_list "${SP_ARGS[@]}" \
      --num_samples "${SAMPLE_NUM}"; then
    echo "[mask_topk] dataset=${ds} FAILED" >&2
    FAILED_DATASETS+=("${ds}")
  fi
done

# Surface failures through the exit status so the job is not reported as
# successful when one or more runs did not finish cleanly.
if (( ${#FAILED_DATASETS[@]} > 0 )); then
  echo "mask_topk runs finished with failures: ${FAILED_DATASETS[*]}" >&2
  exit 1
fi

echo "mask_topk runs completed."