#!/bin/bash

# --- Configuration ---
# Everything a user is expected to edit lives between this banner and the
# "End Configuration" banner below.

# Model identifier: forwarded to the evaluation script as --model_name and
# used as one component of the results directory path.
MODEL_NAME="llada_instruct"

# NOTE(review): this variable is not read anywhere else in this script as
# shown; it is kept in case another script sources or greps for it.
util_eos=True

# Root directory under which per-task result folders are created.
BASE_OUTPUT_DIR="eval/results"

# Per-task generation settings.
# Each value is a single space-separated string: "gen_length steps block_length".
declare -A TASK_PARAMS
TASK_PARAMS+=(
    ["addsub"]="256 256 8"
    ["aqua"]="256 256 8"
    ["gsm8k_test"]="256 256 8"
    ["multiarith"]="128 128 8"
    ["svamp"]="128 128 8"
    ["singleeq"]="256 256 8"
)

# Fallback settings for any task that has no entry in TASK_PARAMS.
DEFAULT_PARAMS="256 256 8"

# Tasks to evaluate, in execution order. Uncomment the ones you want to run,
# or replace the list with "${!TASK_PARAMS[@]}" to run every configured task
# (associative-array key order is unspecified).
TASKS=(
    # "addsub"
    # "aqua"
    # "multiarith"
    "svamp"
    # "singleeq"
    # "gsm8k_test"
)

# Launcher settings. These are intentionally left blank and must be filled in:
#   GPUS_TO_USE - value assigned to CUDA_VISIBLE_DEVICES for the launch
#   port        - value passed to `accelerate launch --main_process_port`
#   nproc       - value passed to `accelerate launch --num_processes`
GPUS_TO_USE=""
port=""
nproc=""
# --- End Configuration ---

# --- Run evaluations ---
# For every entry in TASKS: resolve its generation parameters, create the
# task's output directory, and launch eval_llada_locally.py through
# `accelerate launch`. A failing task does not stop the remaining tasks;
# failures are collected and reported at the end, and the script exits
# non-zero if any task failed.

# Fail fast on missing launcher settings. Left empty, the flags below would
# lose their values and swallow the following argument instead.
if [[ -z "${port}" || -z "${nproc}" ]]; then
    echo "ERROR: 'port' and 'nproc' must be set in the configuration section." >&2
    exit 1
fi

# An empty GPUS_TO_USE is still passed through unchanged, but it is almost
# certainly a forgotten setting (an empty CUDA_VISIBLE_DEVICES typically hides
# every GPU from CUDA), so make it visible to the user.
if [[ -z "${GPUS_TO_USE}" ]]; then
    echo "WARNING: GPUS_TO_USE is empty; CUDA_VISIBLE_DEVICES will be set to an empty string." >&2
fi

if ! mkdir -p -- "$BASE_OUTPUT_DIR"; then
    echo "ERROR: Could not create base output directory: $BASE_OUTPUT_DIR" >&2
    exit 1
fi

# Names of tasks whose evaluation (or setup) failed.
failed_tasks=()

# Loop over each task
for TASK_NAME in "${TASKS[@]}"; do
    echo "-------------------------------------------------"
    echo "Starting evaluation for task: $TASK_NAME"

    # Use the task-specific parameters when present, otherwise the defaults.
    params="${TASK_PARAMS[$TASK_NAME]:-$DEFAULT_PARAMS}"

    # Split "gen_length steps block_length" into three variables.
    read -r CURRENT_GEN CURRENT_STEPS CURRENT_BLOCK <<< "$params"

    # Guard against a malformed parameter string (fewer than three fields).
    if [[ -z "$CURRENT_GEN" || -z "$CURRENT_STEPS" || -z "$CURRENT_BLOCK" ]]; then
        echo "ERROR: Invalid parameters for task '$TASK_NAME': '$params' (expected \"gen_length steps block_length\")" >&2
        failed_tasks+=("$TASK_NAME")
        continue
    fi

    echo "Params -> Gen: $CURRENT_GEN, Steps: $CURRENT_STEPS, Block: $CURRENT_BLOCK"
    echo "-------------------------------------------------"
    CONFIG_ID="g${CURRENT_GEN}_s${CURRENT_STEPS}_b${CURRENT_BLOCK}"
    TASK_OUTPUT_DIR="$BASE_OUTPUT_DIR/$TASK_NAME/$MODEL_NAME/NoFinetuning/${CONFIG_ID}_utill_eos"

    if ! mkdir -p -- "$TASK_OUTPUT_DIR"; then
        echo "ERROR: Could not create output directory: $TASK_OUTPUT_DIR" >&2
        failed_tasks+=("$TASK_NAME")
        continue
    fi

    # Test the launcher's exit status directly instead of inspecting $?.
    if CUDA_VISIBLE_DEVICES="${GPUS_TO_USE}" \
        accelerate launch \
            --main_process_port "${port}" \
            --num_processes "${nproc}" \
            eval_llada_locally.py \
            --model_name "$MODEL_NAME" \
            --task_name "$TASK_NAME" \
            --output_dir "$TASK_OUTPUT_DIR" \
            --save_generations \
            --gen_length "$CURRENT_GEN" \
            --steps "$CURRENT_STEPS" \
            --block_length "$CURRENT_BLOCK" \
            --utill_eos \
            --temperature 0.0; then
        echo "-------------------------------------------------"
        echo "Finished evaluation for task: $TASK_NAME"
        echo "-------------------------------------------------"
    else
        echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" >&2
        echo "ERROR: Evaluation failed for task: $TASK_NAME" >&2
        echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" >&2
        failed_tasks+=("$TASK_NAME")
    fi
done

echo "All evaluations completed."

# Surface failures in the exit status so callers (CI, job schedulers, wrapper
# scripts) can detect them.
if (( ${#failed_tasks[@]} > 0 )); then
    echo "Failed tasks (${#failed_tasks[@]}): ${failed_tasks[*]}" >&2
    exit 1
fi