#!/bin/bash

# Evaluate one checkpoint on the 140-key KV-retrieval set with
# evaluation/get_responses.py, once for each gold-key position.
#
# Usage: SCRIPT SCALING_FACTOR SCALING_TYPE MODEL_TYPE TASK_NAME STEP [flash_attn]
#   SCALING_FACTOR  positive integer; passed as --rope_scaling_factor and used to
#                   derive --max_prompt_length (SCALING_FACTOR * 4096) and the
#                   "4k-<SCALING_FACTOR*4>k-<SCALING_TYPE>" checkpoint directory
#   SCALING_TYPE    passed as --rope_scaling_type; also part of directory/output names
#   MODEL_TYPE      selects <TASK_NAME>/<MODEL_TYPE>_results under /scratch2/nlp/wutong
#   TASK_NAME       run directory under /scratch2/nlp/wutong
#   STEP            checkpoint to load (checkpoint-<STEP>)
#   flash_attn      optional; any extra argument containing "flash_attn" passes
#                   --use_flash_attn 1 (otherwise 0)
#
# Exits 2 on bad usage or missing inputs, 1 if any evaluation run failed.
set -euo pipefail

die() {
  printf 'error: %s\n' "$*" >&2
  exit 2
}

usage() {
  printf 'Usage: %s SCALING_FACTOR SCALING_TYPE MODEL_TYPE TASK_NAME STEP [flash_attn]\n' \
    "${0##*/}" >&2
  exit 2
}

(( $# >= 5 )) || usage

scaling_factor=$1
scaling_type=$2
model_type=$3
task_name=$4
step=$5

# The sizes below use $(( )), which is integer-only: an empty value silently
# became 0 (producing a bogus "4k-0k-" path) and e.g. "2.0" is an arithmetic error.
[[ "$scaling_factor" =~ ^[1-9][0-9]*$ ]] \
  || die "SCALING_FACTOR must be a positive integer, got '${scaling_factor}'"

# Only the optional arguments after the five required ones are scanned, so a
# MODEL_TYPE/TASK_NAME that happens to contain "flash_attn" cannot enable it.
# Each argument is tested on its own instead of matching "$@" inside [[ ]]
# (implicit concatenation, ShellCheck SC2199).
flash_attn=0
for arg in "${@:6}"; do
  if [[ "$arg" == *flash_attn* ]]; then
    flash_attn=1
    break
  fi
done
echo "flash_attn: ${flash_attn}"

readonly data_path_prefix=/scratch/nlp/wutong/dataset/PoSE-Datasets
readonly model_path_prefix=/scratch2/nlp/wutong/${task_name}/${model_type}_results

# Loop-invariant values, computed once.
readonly target_k=$(( scaling_factor * 4 ))
readonly max_prompt_length=$(( scaling_factor * 4096 ))
readonly input_path="${data_path_prefix}/kv_data/kv-retrieval-140_keys.jsonl.gz"
readonly model_path="${model_path_prefix}/4k-${target_k}k-${scaling_type}/checkpoint-${step}"
readonly output_dir="eval_output/kv_predictions/${model_type}-${task_name}-${step}"

# Fail fast instead of launching five Python runs that cannot load their inputs.
[[ -f "$input_path" ]] || die "input file not found: ${input_path}"
[[ -d "$model_path" ]] || die "checkpoint directory not found: ${model_path}"
# NOTE(review): get_responses.py may already create this; mkdir -p is then a no-op.
mkdir -p -- "$output_dir"

failed=()
# Gold key positions spread from the first (0) to the last (139) of the 140 keys.
# Every position is attempted even if an earlier one fails; failures are
# collected and reported at the end.
for gold_index in 0 34 69 104 139; do
  if ! python -u evaluation/get_responses.py \
      --input_path "$input_path" \
      --model_name_or_path "$model_path" \
      --task_name kv \
      --batch_size 1 \
      --gold_index "$gold_index" \
      --max_prompt_length "$max_prompt_length" \
      --model_max_position_embeddings 4096 \
      --rope_scaling_factor "$scaling_factor" \
      --rope_scaling_type "$scaling_type" \
      --max_new_tokens 50 \
      --use_flash_attn "$flash_attn" \
      --output_path "${output_dir}/kv_140_at_${gold_index}_4k_${target_k}k_${scaling_type}.txt"; then
    failed+=("$gold_index")
  fi
done

if (( ${#failed[@]} > 0 )); then
  printf 'error: evaluation failed for gold_index: %s\n' "${failed[*]}" >&2
  exit 1
fi

