#!/bin/bash


# 启用严格模式，帮助捕捉更多潜在错误
set -euo pipefail
IFS=$'\n\t'

##############################
# 参数定义
##############################
start_time=$SECONDS
# 通用参数
DATASET="LabelClassificationSubset"
PROMPT_PATH="./prompts/LabelClassification/test.json"
PROMPT_ID="direct"
API_NAME="gpt_4o"
ERROR_EXTRACTION_COUNT=3
WANDB_PROJECT_NAME="ASP_tiny_normal_eval"
# MODEL_NAME 将在循环中动态设置
EVALUATE_SCRIPT="evaluate.py"
PREDICT_SCRIPT="predict_dataset_by_LLMs.py"

# 文件路径参数
DATASETS_DIR="datasets/tiny_normal_eval"

# 运行次数
RUNS=3

# 类别和索引
CATEGORIES=(
  "related_word_language"
  "random_word_language"
  "random_str_language"
  "random_str_symbolic"
  "random_word_symbolic"
  "related_word_symbolic"
)
INDICES=(0 1 2)

# 温度值数组
TEMPERATURES=(0 0.25 0.5 0.75 1.0)

##############################
# 环境检查
##############################

# 检查 Python 脚本是否存在
if [ ! -f "$PREDICT_SCRIPT" ]; then
  echo "错误: $PREDICT_SCRIPT 不存在。"
  exit 1
fi

if [ ! -f "$EVALUATE_SCRIPT" ]; then
  echo "错误: $EVALUATE_SCRIPT 不存在。"
  exit 1
fi

# 如果使用虚拟环境，请取消注释并设置路径
# source /path/to/your/venv/bin/activate

##############################
# 函数定义
##############################

# 预测函数
run_prediction() {
  local input_file=$1
  local output_file=$2
  local temperature=$3

  echo "正在为输入文件: $input_file 运行预测，温度值: $temperature..."

  # 打印即将执行的命令
  echo "执行命令: python $PREDICT_SCRIPT --dataset $DATASET --prompt_path $PROMPT_PATH --prompt_id $PROMPT_ID --api_name $API_NAME --error_extraction_count $ERROR_EXTRACTION_COUNT --temperature $temperature --data_type $input_file --prediction_path $output_file"

  python "$PREDICT_SCRIPT" \
    --dataset "$DATASET" \
    --prompt_path "$PROMPT_PATH" \
    --prompt_id "$PROMPT_ID" \
    --api_name "$API_NAME" \
    --error_extraction_count "$ERROR_EXTRACTION_COUNT" \
    --temperature "$temperature" \
    --data_type "$input_file" \
    --prediction_path "$output_file"

  echo "预测成功: $input_file -> $output_file"
}

# 评估函数
run_evaluation() {
  local work_name=$1
  local data_type=$2
  local prediction_path=$3
  local temperature=$4

  echo "正在为工作: $work_name 运行评估，温度值: $temperature..."

  # 打印即将执行的命令
  echo "执行命令: python $EVALUATE_SCRIPT --dataset $DATASET --wandb --wandb_model_path_name prediction_path --wandb_project_name $WANDB_PROJECT_NAME --model $MODEL_NAME --wandb_work_name $work_name --data_type $data_type --prediction_path $prediction_path"

  python "$EVALUATE_SCRIPT" \
    --dataset "$DATASET" \
    --wandb \
    --wandb_model_path_name prediction_path \
    --wandb_project_name "$WANDB_PROJECT_NAME" \
    --model "$MODEL_NAME" \
    --wandb_work_name "$work_name" \
    --data_type "$data_type" \
    --prediction_path "$prediction_path"

  echo "评估成功: $work_name"
}

##############################
# 主程序
##############################

echo "开始所有预测和评估任务..."

# 循环遍历每个温度值
for TEMPERATURE in "${TEMPERATURES[@]}"; do
  echo "-----------------------------------------"
  echo "正在处理温度值: $TEMPERATURE"

  # 定义包含温度值的预测目录
  PREDICTIONS_DIR="./predictions/tiny_normal_eval/LabelClassification/${API_NAME}_temp${TEMPERATURE}"

  # 定义模型名称，包含温度值
  MODEL_NAME="${API_NAME}_temp${TEMPERATURE}"

  # 创建预测目录（如果不存在）
  mkdir -p "$PREDICTIONS_DIR"

  # 运行预测
  for (( run=1; run<=RUNS; run++ )); do
    echo "运行次数: $run / $RUNS"
    for CATEGORY in "${CATEGORIES[@]}"; do
      for IDX in "${INDICES[@]}"; do
        INPUT_FILE="${DATASETS_DIR}/${CATEGORY}_0.jsonl"
        OUTPUT_FILE="${PREDICTIONS_DIR}/${CATEGORY}_${IDX}.jsonl"

        run_prediction "$INPUT_FILE" "$OUTPUT_FILE" "$TEMPERATURE"
      done
    done
  done

  echo "预测阶段完成，温度值: $TEMPERATURE"

  echo "开始评估阶段，温度值: $TEMPERATURE"

  # 运行评估
  for CATEGORY in "${CATEGORIES[@]}"; do
    for IDX in "${INDICES[@]}"; do
      WORK_NAME="${CATEGORY}_${IDX}_temp${TEMPERATURE}"
      DATA_TYPE="${DATASETS_DIR}/${CATEGORY}_0.jsonl"
      PREDICTION_PATH="${PREDICTIONS_DIR}/${CATEGORY}_${IDX}.jsonl"

      run_evaluation "$WORK_NAME" "$DATA_TYPE" "$PREDICTION_PATH" "$TEMPERATURE"
    done
  done

  echo "评估阶段完成，温度值: $TEMPERATURE"
  echo "-----------------------------------------"
done

echo "所有预测和评估任务成功完成。"
end_time=$SECONDS
elapsed=$(( end_time - start_time ))

echo "Total execution time: $elapsed seconds"
