#!/usr/bin/env bash
#
# Batch generation launcher: runs `python src/generate.py` once per entry of
# the `models` array and tees each run's output to a per-model log file.
# The script relies on bash features (arrays, [[ ]], +=), so it must be
# executed by bash rather than a plain POSIX sh.
# src/generate.py is referenced by a relative path, so start the script from
# the directory that contains src/.

# GPU indices exported to every process started below.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# =================== model ===================
# Model directory names to run, processed in exactly this order.
# The list is assembled family by family; each name is later appended to
# MODEL_BASE_PATH to form the --model_name_or_path argument.
models=()
models+=("Meta-Llama-3.1-8B-Instruct" "Meta-Llama-3.1-70B-Instruct")
models+=("MiniCPM3-4B")
models+=("Llama-2-13b-chat-hf")
models+=("Mistral-7B-Instruct-v0.3" "Mixtral-8x7B-Instruct-v0.1")
models+=("DeepSeek-V2-Lite-Chat")
models+=("Qwen2-1.5B-Instruct" "Qwen2-7B-Instruct" "Qwen2-72B-Instruct")
models+=("Phi-3-mini-128k-instruct" "Phi-3-small-128k-instruct" "Phi-3-medium-128k-instruct")
models+=("gemma-2-2b-it" "gemma-2-9b-it" "gemma-2-27b-it")

# =================== data ===================
# Dataset names noted for this script:
#   "sharegpt_v3"
#   "ultrafeedback"
# The one assigned to dataset_name is passed as --dataset_name and also
# names the output sub-directory and the log files.
dataset_name='sharegpt_v3'

# Base locations used to build model, output and log paths.
# NOTE(review): the path/to/... values look like placeholders - replace them
# with real directories before running (confirm).
MODEL_BASE_PATH='path/to/models'
PROJECT_PATH='path/to/AIR'

#######################################
# Prepare the launch settings for one src/generate.py run.
# Globals:
#   MODEL_BASE_PATH, PROJECT_PATH, dataset_name (read)
#   GENERATE_ARGS (written) - argument vector for generate.py, one array
#                             element per word so paths with spaces survive
#   GENERATE_ENV  (written) - NAME=value pairs that apply to this run only
# Arguments:
#   $1 - model name (an entry of the models array, or the special name "sft")
# Outputs:
#   A note on stdout for each model-specific environment override.
#######################################
prepare_generate_run() {
    local model_name=$1

    # Start from a clean slate on every call so nothing carries over from
    # the previous model.
    GENERATE_ARGS=()
    GENERATE_ENV=()

    case "$model_name" in
        sft)
            # Special name that maps to a fixed checkpoint path instead of
            # MODEL_BASE_PATH/<name>.
            GENERATE_ARGS+=(--model_name_or_path "path/to/Llama-3.1-Tulu-3-8B-SFT")
            ;;
        Phi-3-small-128k-instruct)
            GENERATE_ARGS+=(--model_name_or_path "${MODEL_BASE_PATH}/${model_name}")
            echo "$model_name using specified tiktoken cache dir"
            GENERATE_ENV+=("TIKTOKEN_CACHE_DIR=path/to/Phi-3-small-128k-instruct")
            ;;
        Phi-3-mini-128k-instruct | Phi-3-medium-128k-instruct)
            GENERATE_ARGS+=(--model_name_or_path "${MODEL_BASE_PATH}/${model_name}")
            ;;
        gemma*)
            # Previously tried settings, kept for reference:
            #   PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
            #   CUDA_WORKSPACE_CONFIG=:4096:8
            GENERATE_ARGS+=(--model_name_or_path "${MODEL_BASE_PATH}/${model_name}")
            echo "$model_name using flash infer"
            GENERATE_ENV+=("VLLM_ATTENTION_BACKEND=FLASHINFER")
            ;;
        *)
            # Every other model additionally gets --enable_chunked_prefill.
            GENERATE_ARGS+=(--enable_chunked_prefill)
            GENERATE_ARGS+=(--model_name_or_path "${MODEL_BASE_PATH}/${model_name}")
            ;;
    esac

    GENERATE_ARGS+=(--dataset_name "${dataset_name}") # TODO
    GENERATE_ARGS+=(--max_tokens 2048)
    GENERATE_ARGS+=(--max_num_samples 10000) # dummy
    GENERATE_ARGS+=(--temperature 0.9) # TODO
    GENERATE_ARGS+=(--output_dir "${PROJECT_PATH}/outputs/${dataset_name}")
}

# Run generation for every configured model, one after another. A failing
# model is reported on stderr and the batch moves on to the next one.
for model_name in "${models[@]}"; do
    echo "Processing model: $model_name"

    prepare_generate_run "$model_name"

    # tee cannot create missing directories, so make sure the log dir exists.
    log_dir="${PROJECT_PATH}/logs"
    log_file="${log_dir}/generate_on_${dataset_name}-${model_name}.log"
    mkdir -p "$log_dir"

    echo "python src/generate.py ${GENERATE_ARGS[*]}"

    # Per-model variables are handed to this one process through env instead
    # of being exported, so they cannot leak into later iterations.
    env "${GENERATE_ENV[@]}" python src/generate.py "${GENERATE_ARGS[@]}" 2>&1 | tee "$log_file"
    # The pipeline's own status is tee's; PIPESTATUS[0] is generate.py's.
    generate_status=${PIPESTATUS[0]}

    if (( generate_status != 0 )); then
        echo "generate.py failed for ${model_name} (exit ${generate_status}); see ${log_file}" >&2
    fi
done
