# Trace every command as it is executed so the log shows the exact invocations.
set -x

# Must be exported: a plain assignment stays local to this shell and would not
# reach the python3 processes started further down. Hydra reads this variable
# to print complete stack traces instead of its abbreviated error output.
export HYDRA_FULL_ERROR=1


## task config
# Used in the data/output paths and passed to the evaluator as --task.
# TASK=formula
TASK=text


# model config
# Used only to build the processed-data and evaluation-output paths.
MODEL=qwen
# MODEL=llama
# MODEL=qwen14b


# dataset config
# Space-separated dataset names; the loop below processes them one at a time.
DATASETS="wikitq tabfact finqa hitab multihiertt aitqa tablebench"

# trained model config
step_id=654
# NOTE(review): the checkpoint directory name contains "llama" and "text" and is
# not derived from MODEL/TASK above -- confirm it matches the selected config.
model_path=training_outputs/sft/fortune/sft_text_llama_all/global_step_${step_id}
# 'False' runs generation before scoring; any other value skips generation and
# only scores an existing output file.
eval_only=False


# training config
NODE_NUM=1
GPU_NUM=4

# qwen14b reuses the qwen data/output directories, so fold it into "qwen".
# The expansion is quoted so an empty or whitespace-containing MODEL does not
# turn the test into a syntax error.
if [ "$MODEL" = 'qwen14b' ]; then
    MODEL=qwen
fi



# For each dataset: optionally generate model responses into a parquet file,
# then score that file with the local evaluation script.
# $DATASETS is deliberately left unquoted so it word-splits into dataset names.
for dataset in $DATASETS; do
    echo "Evaluating on dataset: $dataset"

    # Where generation writes its output and where the evaluator reads it from.
    save_path="evaluation_outputs/sft/${TASK}/${MODEL}/${dataset}/eval_global_step_${step_id}.parquet"
    # Test split consumed by the generation step.
    test_files="data/processed_data/${TASK}/${MODEL}/${dataset}/test.parquet"

    if [ "$eval_only" = 'False' ]; then
        # If generation fails, do not score: the output file would be missing
        # or left over from an earlier run. Move on to the next dataset.
        # NOTE(review): data.n_samples here and --sample_per_data below are both
        # 1 -- presumably they need to stay in sync; confirm.
        if ! python3 -m verl.trainer.main_generation \
            trainer.nnodes="$NODE_NUM" \
            trainer.n_gpus_per_node="$GPU_NUM" \
            data.path="$test_files" \
            data.prompt_key=prompt \
            data.n_samples=1 \
            data.output_path="$save_path" \
            model.path="$model_path" \
            +model.trust_remote_code=True \
            rollout.temperature=1.1 \
            rollout.top_k=50 \
            rollout.top_p=0.7 \
            rollout.prompt_length=8192 \
            rollout.response_length=512 \
            rollout.tensor_model_parallel_size=2 \
            rollout.max_num_batched_tokens=11000 \
            rollout.gpu_memory_utilization=0.8; then
            echo "Generation failed for dataset: $dataset; skipping evaluation" >&2
            continue
        fi
    fi

    python3 src/eval/eval_local_llm.py \
        --output_path "$save_path" \
        --task "$TASK" \
        --sample_per_data 1
done