#!/bin/bash

source /mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/train_regression/env_volcengine_temp.sh

# Data sets to run inference on; the loop below processes every entry.
# Commented-out names are disabled entries kept for quick toggling —
# uncomment a line to include that data set in the run.
DATA_NAMES=(
    "grouped_exp"

    # suzuki_50_* entries (disabled)
    # "suzuki_50_0.5"
    # "suzuki_50_0.25"
    # "suzuki_50_2.0"
    # "suzuki_50_4.0"

    # arylation_* entries (disabled)
    # "arylation_0.5"
    # "arylation_0.25"
    # "arylation_2.0"
    # "arylation_4.0"

    # grouped_exp_* entries (disabled)
    # "grouped_exp_0.5"
    # "grouped_exp_0.25"
    # "grouped_exp_2.0"
    # "grouped_exp_4.0"
)

# Project directory; the required-path check expects train_regression/ under it.
PROJECT_DIR="/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS"
# Base model weights, passed to the inference script as --pretrained_model_path.
BASE_MODEL_PATH="/mnt/shared-storage-user/caipengxiang/H200-share/models/share/step1_llama3_8b_0916_yearly_pistachio_ep3"
# Directory where inference results are written (passed as --save_path).
SAVE_RESULTS_DIR="/mnt/shared-storage-user/caipengxiang/H200-ai4chem/Exp_AB_results"

# Verify that every required input/output location exists before launching
# any job. All missing paths are reported on stderr (not just the first one)
# so they can be fixed in a single pass; the script then exits with status 1.
required_paths=(
    "${BASE_MODEL_PATH}"
    "${PROJECT_DIR}/train_regression/data4regression"
    "${PROJECT_DIR}/train_regression/yield_ft_ds_config.json"
    "${SAVE_RESULTS_DIR}"
)

missing_path_count=0
for path in "${required_paths[@]}"; do
    if [[ ! -e "$path" ]]; then
        echo "Error: Required path not found: $path" >&2
        missing_path_count=$((missing_path_count + 1))
    fi
done

if (( missing_path_count > 0 )); then
    exit 1
fi

export TOKENIZERS_PARALLELISM=false # disable (tokenizers) parallelism to avoid deadlocks

# The distributed-launch settings below are passed to deepspeed as integer
# options (--num_nodes, --num_gpus, --master_port, --node_rank). An unset or
# empty value would be passed as e.g. "--num_nodes=", which the launcher is
# expected to reject; fail early with a clear message instead.
# NOTE(review): these are presumably exported by the env script sourced at
# the top of this file — confirm. MASTER_ADDR is deliberately not checked
# here since an empty address may be acceptable for a single-node launch.
for required_var in NNODES GPUS MASTER_PORT RANK; do
    if [[ -z "${!required_var:-}" ]]; then
        echo "Error: Required environment variable not set: ${required_var}" >&2
        exit 1
    fi
done

# Root directory holding one checkpoint subdirectory per data set.
CHECKPOINT_ROOT="/mnt/shared-storage-user/caipengxiang/H200-ai4chem/Exp_AB_results/share/llama-3.1-8B/clustered"

# Data sets whose run was skipped or failed; reported at the end.
failed_data_names=()

# Run inference once per data set. A failure for one data set does not stop
# the others, but it is recorded and makes the script exit non-zero.
for DATA_NAME in "${DATA_NAMES[@]}"; do
    echo "Processing data set: ${DATA_NAME}"

    # Skip data sets that have no checkpoint directory.
    CHECKPOINT_DIR="${CHECKPOINT_ROOT}/${DATA_NAME}"
    if [[ ! -e "${CHECKPOINT_DIR}" ]]; then
        echo "Error: Checkpoint directory not found for data set ${DATA_NAME}: ${CHECKPOINT_DIR}" >&2
        failed_data_names+=("${DATA_NAME}")
        continue
    fi

    # NOTE: inference_diff_data_scale.py is a relative path, so it is
    # resolved against the current working directory at launch time.
    if ! deepspeed --num_nodes="${NNODES}" \
                   --num_gpus="${GPUS}" \
                   --master_addr="${MASTER_ADDR}" \
                   --master_port="${MASTER_PORT}" \
                   --node_rank="${RANK}" \
                   inference_diff_data_scale.py \
                   --pretrained_model_path "${BASE_MODEL_PATH}" \
                   --searchspace_name "${DATA_NAME}" \
                   --lora 1 \
                   --data_name "${DATA_NAME}" \
                   --batch_size 48 \
                   --checkpoint_dir "${CHECKPOINT_DIR}" \
                   --load_by_torch 1 \
                   --save_path "${SAVE_RESULTS_DIR}"; then
        echo "Error: deepspeed run failed for data set ${DATA_NAME}" >&2
        failed_data_names+=("${DATA_NAME}")
    fi
done

# Surface any skipped/failed data sets through the exit status.
if (( ${#failed_data_names[@]} > 0 )); then
    echo "Error: ${#failed_data_names[@]} data set(s) failed: ${failed_data_names[*]}" >&2
    exit 1
fi