import os 
import json
import calendar  # 添加calendar模块导入
import torch
import tiktoken
from model import *
from utils import *
from tqdm import tqdm
import yaml
import argparse
from pathlib import Path
import logging
import datetime
import sys
import random
import numpy as np
import re
# 添加RAG所需的库
from rank_bm25 import BM25Okapi
import openai

# Retriever backends selectable for the RAG setting.
RETRIEVER_TYPE_OPTIONS = ["bm25", "vector", "hybrid"]
# Default retriever: BM25.
DEFAULT_RETRIEVER_TYPE = "bm25"
# Active retriever. Module-level global that the launch script may override from
# its arguments; it selects both the RAG results sub-directory and the
# pre-generated chunks sub-directory.
retriever_type = DEFAULT_RETRIEVER_TYPE

# Root directory of the pre-generated tcelongbench retrieval chunks; it contains
# one sub-directory per retriever type.
CHUNKS_BASE_DIR = "/home/weishaohang/workspace/Omni-Temp/evaluation/chunks_for_tcelongbench"

######################## logging setup ########################
def setup_logger(model_name, setting, data_source, logs_dir=None):
    """Configure the root logger for one evaluation run.

    Every handler currently attached to the root logger is detached and closed,
    then two handlers are installed:

    * a UTF-8 file handler at INFO level writing to
      ``<logs_dir>/<model_name>_<setting>_<data_source>_<YYYYmmdd_HHMMSS>.log``;
    * a stdout handler at WARNING level, so INFO messages only reach the file.

    Args:
        model_name: Model name, used in the log file name.
        setting: Evaluation setting, used in the log file name.
        data_source: Data source, used in the log file name.
        logs_dir: Directory for log files (created if missing). Defaults to the
            project's ``logs`` directory.

    Returns:
        The configured root logger.
    """
    if logs_dir is None:
        logs_dir = "/home/weishaohang/workspace/Omni-Temp/logs"
    logs_dir = Path(logs_dir)
    os.makedirs(logs_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filepath = logs_dir / f"{model_name}_{setting}_{data_source}_{timestamp}.log"

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Detach AND close old handlers: removeHandler() alone leaves a previous
    # FileHandler's file open, leaking one descriptor per call.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    file_handler = logging.FileHandler(log_filepath, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # The console only shows WARNING and above; INFO goes to the file only.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.WARNING)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    logging.info(f"开始评估 - 模型: {model_name}, 设定: {setting}, 数据源: {data_source}")
    logging.info(f"日志文件保存在: {log_filepath}")

    return logger

######################## preliminaries and prompt paths ########################
# data source -> question type -> prompt type.
# The prompt type is the key used in mapping_setting_prompt_type_to_prompt_path
# to pick the template file for a given setting. Entries marked "# TODO:" were
# flagged by the author for follow-up.
mapping_source_and_qtype_to_prompt_type = {
    "wikidata": {
        "L1_1": "multi_choice_qa",
        "L1_2": "free_form_qa_for_time_expression",
        "L1_3": "free_form_qa", # TODO:
        "L1_4": "single_choice_qa",
        "L1_5": "single_choice_qa",
        "L2_1": "free_form_qa", # TODO:
        "L2_2": "free_form_qa", # TODO:
        "L2_3": "free_form_qa_with_refusal",
        "L3_1": "free_form_qa_with_refusal",
        "L3_2": "free_form_qa", # TODO:
        "L3_3": "single_choice_qa_for_forecast",
        "L3_4": "free_form_qa_for_false_premise",
    },
    "tcelongbench": {
        "L1_1": "multi_choice_qa",
        "L1_2": "free_form_qa_for_time_expression", # TODO:
        "L1_3": "free_form_qa", # TODO:
        "L1_4": "single_choice_qa",
        "L1_5": "single_choice_qa",
        "L2_1": "free_form_qa", # TODO:
        "L2_1_multi_choice": "single_choice_qa",
        "L2_2": "free_form_qa", # TODO:
        "L2_2_multi_choice": "single_choice_qa",
        "L2_3": "free_form_qa", # TODO:
        "L2_3_multi_choice": "single_choice_qa",
        "L3_1": "free_form_qa", # TODO:
        "L3_1_multi_choice": "single_choice_qa",
        "L3_2": "free_form_qa", # TODO:
        "L3_3": "single_choice_qa_for_forecast",
        "L3_4": "free_form_qa_for_false_premise",
        "L3_4_multi_choice": "single_choice_qa_for_false_premise",
    },
    "long_dialog": {
        "L1_1": "multi_choice_qa",
        "L1_2": "free_form_qa_for_time_expression", # TODO:
        "L1_3": "free_form_qa", # TODO:
        "L1_4": "single_choice_qa",
        "L1_5": "single_choice_qa",
        # NOTE: unlike tcelongbench, long_dialog's plain L2_1 is single-choice
        # and there is no L2_1_multi_choice variant.
        "L2_1": "single_choice_qa",
        "L2_2": "free_form_qa", # TODO:
        "L2_2_multi_choice": "single_choice_qa",
        "L2_3": "free_form_qa", # TODO:
        "L2_3_multi_choice": "single_choice_qa",
        "L3_1": "free_form_qa", # TODO:
        "L3_1_multi_choice": "single_choice_qa",
        "L3_2": "free_form_qa", # TODO:
        "L3_3": "single_choice_qa_for_forecast",
        "L3_4": "free_form_qa_for_false_premise",
        "L3_4_multi_choice": "single_choice_qa_for_false_premise",
    }
}

# Prompt templates live at <_PROMPT_ROOT>/<setting sub-directory>/<prompt type>.txt
_PROMPT_ROOT = "/home/weishaohang/workspace/Omni-Temp/prompts/evaluation"

# Evaluation setting -> template sub-directory.
# "base" means: context is provided, zero-shot.
_SETTING_TO_PROMPT_SUBDIR = {
    "base": "with_context_zero_shot",
    "with_context_and_few_shot": "with_context_few_shot",
    "RAG": "RAG",
    "closed_book": "closed_book",
}

# Every prompt type has one template file per setting.
_PROMPT_TYPES = (
    "multi_choice_qa",
    "free_form_qa",
    "free_form_qa_with_refusal",
    "free_form_qa_for_time_expression",
    "multi_choice_qa_for_forecast",
    "free_form_qa_for_false_premise",
    "multi_choice_qa_for_false_premise",
    "single_choice_qa",
    "single_choice_qa_for_false_premise",
    "single_choice_qa_for_forecast",
)

# setting -> prompt type -> absolute template file path.
mapping_setting_prompt_type_to_prompt_path = {
    setting_name: {
        kind: f"{_PROMPT_ROOT}/{subdir}/{kind}.txt"
        for kind in _PROMPT_TYPES
    }
    for setting_name, subdir in _SETTING_TO_PROMPT_SUBDIR.items()
}

# QA files live in <_QA_ROOT>/QAs_<source>/. Base question types use
# "<qtype>_QAs.json"; multiple-choice variants use "<qtype>_QAs_multi_choice.json".
_QA_ROOT = "/home/weishaohang/workspace/Omni-Temp"
_BASE_QTYPES = (
    "L1_1", "L1_2", "L1_3", "L1_4", "L1_5",
    "L2_1", "L2_2", "L2_3",
    "L3_1", "L3_2", "L3_3", "L3_4",
)
# Per source: base question types that also have a multiple-choice variant.
_MULTI_CHOICE_BASE_QTYPES = {
    "wikidata": (),
    "tcelongbench": ("L2_1", "L2_2", "L2_3", "L3_1", "L3_4"),
    "long_dialog": ("L2_2", "L2_3", "L3_1", "L3_4"),
}


def _build_qa_json_paths(source, multi_choice_qtypes):
    """Return {qtype: json path} for one source: base types first, then *_multi_choice ones."""
    source_dir = f"{_QA_ROOT}/QAs_{source}"
    paths = {qt: f"{source_dir}/{qt}_QAs.json" for qt in _BASE_QTYPES}
    for qt in multi_choice_qtypes:
        paths[f"{qt}_multi_choice"] = f"{source_dir}/{qt}_QAs_multi_choice.json"
    return paths


# data source -> question type -> QA json file path.
mapping_source_and_qtype_to_json_path = {
    src: _build_qa_json_paths(src, mc_qtypes)
    for src, mc_qtypes in _MULTI_CHOICE_BASE_QTYPES.items()
}

# All supported data sources.
data_source_list = ["wikidata", "tcelongbench", "long_dialog"]

# All evaluation settings.
setting_list = ["base", "with_context_and_few_shot", "RAG", "closed_book"]

# data source -> question types defined for that source.
mapping_source_to_qtype_list = {
    "wikidata": ["L1_1", "L1_2", "L1_3", "L1_4", "L1_5", "L2_1", "L2_2", "L2_3", "L3_1", "L3_2", "L3_3", "L3_4"],
    "tcelongbench": ["L1_1", "L1_2", "L1_3", "L1_4", "L1_5", "L2_1", "L2_1_multi_choice", "L2_2", "L2_2_multi_choice", "L2_3", "L2_3_multi_choice", "L3_1", "L3_1_multi_choice", "L3_2", "L3_3", "L3_4", "L3_4_multi_choice"],
    "long_dialog": ["L1_1", "L1_2", "L1_3", "L1_4", "L1_5", "L2_1", "L2_2", "L2_2_multi_choice", "L2_3", "L2_3_multi_choice", "L3_1", "L3_1_multi_choice", "L3_2", "L3_3", "L3_4", "L3_4_multi_choice"]
}

# data source -> metadata JSON file.
mapping_source_to_meta_data_path = {
    "wikidata": "/home/weishaohang/workspace/Omni-Temp/wiki_data_processing/data_with_bank_timeline/data_with_timeline.json",
    "tcelongbench": "/home/weishaohang/workspace/Omni-Temp/meta_data/tcelongbench.json",
    "long_dialog": "/home/weishaohang/workspace/Omni-Temp/long_dialog_processing/raw_data/long_dialog.json"
}

# data source -> settings supported by that source
# (long_dialog supports only the context-based settings).
mapping_source_to_setting_list = {
    "wikidata": ["base", "with_context_and_few_shot", "RAG", "closed_book"],
    "tcelongbench": ["base", "with_context_and_few_shot", "RAG", "closed_book"],
    "long_dialog": ["base", "with_context_and_few_shot"]
}

######################## load and dump data ########################
def load_prompt(setting, data_source, qtype):
    """Return the raw prompt template text for ``(setting, data_source, qtype)``.

    Resolution happens in two steps: the question type is mapped to a prompt
    type for the data source, and that prompt type is mapped to a template file
    for the setting. Unknown keys raise ``KeyError``.
    """
    prompt_kind = mapping_source_and_qtype_to_prompt_type[data_source][qtype]
    template_path = mapping_setting_prompt_type_to_prompt_path[setting][prompt_kind]
    with open(template_path, "r", encoding="utf-8") as template_file:
        template_text = template_file.read()
    logging.debug(f"加载提示模板: {template_path}")
    return template_text

def load_qa_data(data_source, qtype):
    """Load and return the parsed QA JSON for ``data_source``/``qtype``.

    The file path comes from ``mapping_source_and_qtype_to_json_path``; unknown
    keys raise ``KeyError``.
    """
    file_path = mapping_source_and_qtype_to_json_path[data_source][qtype]
    logging.info(f"加载QA数据: {file_path}")
    # Context manager closes the handle deterministically; the previous
    # json.load(open(...)) left closing to the garbage collector.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def check_if_results_exist(model, setting, data_source, qtype):
    # 根据设定是否为RAG决定路径
    if setting == "RAG":
        # 在RAG设定下，添加检索器类型子目录
        path = f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{retriever_type}/{data_source}/{qtype}.json"
    else:
        # 非RAG设定保持原样
        path = f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{data_source}/{qtype}.json"
        
    logging.info(f"检查结果是否存在: {path}")
    return os.path.exists(path)

def dump_qa_data(results, model, setting, data_source, qtype):
    logging.info(f"保存 {data_source} {qtype} {setting} {model.model_name} 的QA结果...")
    
    # 根据设定是否为RAG决定路径
    if setting == "RAG":
        # 在RAG设定下，添加检索器类型子目录
        path = f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{retriever_type}/{data_source}/{qtype}.json"
    else:
        # 非RAG设定保持原样
        path = f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{data_source}/{qtype}.json"
        
    os.makedirs(os.path.dirname(path), exist_ok=True)   # 如果目录不存在，则创建目录（递归）
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
    logging.info(f"结果已保存到: {path}")
    print(f"Dumped {data_source} {qtype} {setting} {model.model_name} qa data to {path}")

def make_prompt(setting, data_source, qtype, data_idx, question, cached_data=None):
    """Build the full prompt for one question.

    Args:
        setting: Evaluation setting (selects the prompt template and the context).
        data_source: Data source name.
        qtype: Question type.
        data_idx: Key/index of the source item the question belongs to.
        question: Question text; also used as the retrieval query under RAG.
        cached_data: Optional pre-loaded data forwarded to ``make_context``.

    Returns:
        ``(prompt, retrieved_chunks)``. ``retrieved_chunks`` is the second element
        when ``make_context`` returns a 2-tuple (its tcelongbench RAG path does),
        otherwise ``None``.

    Raises:
        ValueError: for tcelongbench, when the prompt's token count (gpt-4
            tokenizer) exceeds ``model.max_input_context_len``.
    """
    prompt_template = load_prompt(setting, data_source, qtype)

    # tcelongbench L1_1 always builds its context as in "base", even under RAG;
    # the template loaded above is still the one for `setting`.
    actual_setting = setting
    if data_source == "tcelongbench" and qtype == "L1_1" and setting == "RAG":
        actual_setting = "base"
        logging.info(f"特殊处理: tcelongbench的L1_1类型使用base模式而非{setting}模式")

    context_result = make_context(actual_setting, data_source, qtype, data_idx, cached_data, question)

    if isinstance(context_result, tuple) and len(context_result) == 2:
        context, retrieved_chunks = context_result
    else:
        context, retrieved_chunks = context_result, None

    prompt = prompt_template.format(context=context, question=question)
    if data_source == "tcelongbench":
        # Tokenize once and reuse the count (previously encoded twice on overflow).
        # NOTE: `model` is not a parameter; it resolves as a module-level global,
        # presumably set by the entry point — verify.
        num_prompt_tokens = len(tiktoken.encoding_for_model("gpt-4").encode(prompt))
        if num_prompt_tokens > model.max_input_context_len:
            print(f"prompt的tokens数量: {num_prompt_tokens}")
            raise ValueError("prompt的tokens数量超过了模型的最大输入上下文长度")

    return prompt, retrieved_chunks

# Look up pre-generated retrieval chunks for a single question.
def get_chunks_from_file(qtype, data_idx, qa_idx, chunks_base_dir=None, retriever=None):
    """Return the pre-generated retrieval chunks for one question, or None.

    The chunks file is ``<chunks_base_dir>/<retriever>/<qtype>_QAs.json``, or
    ``<base>_QAs_multi_choice.json`` when ``qtype`` is ``<base>_multi_choice``.
    It maps ``data_idx`` to a list indexed by ``qa_idx``.

    Args:
        qtype: Question type, e.g. "L2_1" or "L2_1_multi_choice".
        data_idx: Key of the item inside the chunks file.
        qa_idx: Position of the question within that item's list.
        chunks_base_dir: Root chunks directory; defaults to ``CHUNKS_BASE_DIR``.
        retriever: Retriever sub-directory; defaults to the global ``retriever_type``.

    Returns:
        ``chunks_data[data_idx][qa_idx]``, or None when the file, key or index is
        missing, or the file cannot be read/parsed.
    """
    if chunks_base_dir is None:
        chunks_base_dir = CHUNKS_BASE_DIR
    if retriever is None:
        retriever = retriever_type
    chunks_dir = os.path.join(chunks_base_dir, retriever)

    suffix = "_multi_choice"
    if qtype.endswith(suffix):
        # Slice the suffix off. str.strip(suffix) strips any of its *characters*
        # from both ends and only worked by coincidence for "L<d>_<d>" names.
        base_qtype = qtype[:-len(suffix)]
        chunks_file = os.path.join(chunks_dir, f"{base_qtype}_QAs_multi_choice.json")
    else:
        chunks_file = os.path.join(chunks_dir, f"{qtype}_QAs.json")

    if not os.path.exists(chunks_file):
        logging.warning(f"未找到预先生成的chunks文件: {chunks_file}")
        return None

    try:
        with open(chunks_file, 'r', encoding='utf-8') as f:
            chunks_data = json.load(f)

        if data_idx not in chunks_data:
            logging.warning(f"在chunks文件中未找到data_idx: {data_idx}")
            return None

        question_chunks = chunks_data[data_idx]
        # Reject negative indices as well: they would silently pick from the end.
        if not 0 <= qa_idx < len(question_chunks):
            logging.warning(f"在chunks文件中未找到qa_idx: {qa_idx}，超出范围")
            return None

        return question_chunks[qa_idx]

    except Exception as e:
        # Deliberate best-effort: any read/parse/shape problem yields None so the
        # caller can fall back to live retrieval.
        logging.error(f"读取chunks文件时出错: {e}")
        return None

######################## make context ########################
def make_context(setting, data_source, qtype, data_idx, cached_data=None, current_question=None):
    """Build the context that is substituted into the prompt template.

    Args:
        setting: "base", "with_context_and_few_shot", "RAG" or "closed_book".
        data_source: Matched by substring against "wikidata", "tcelongbench",
            "long_dialog".
        qtype: Question type; used to locate pre-generated RAG chunks.
        data_idx: Key/index of the source item (for tcelongbench, the key into
            the per-item day -> md5-list mapping).
        cached_data: Optional dict of pre-loaded data. Keys read here:
            "cached_context", "wikidata_data", "tcelongbench_context",
            "tcelongbench_articles", "long_dialog_data", "qa_idx".
        current_question: Question text used as the retrieval query under RAG.

    Returns:
        * ``cached_data["cached_context"]`` unchanged, for tcelongbench when present;
        * a context string for the base / few-shot settings;
        * ``(context_str, chunks)`` for tcelongbench under RAG;
        * None for closed_book (wikidata, tcelongbench) and, implicitly, for
          wikidata under RAG, which is not implemented yet.

    Raises:
        ValueError: for long_dialog with RAG or closed_book.

    Side effects: the tcelongbench RAG path prints the retrieved chunks to stdout;
    the base tcelongbench path may randomly subsample articles (``random.sample``).

    NOTE(review): ``model`` is not a parameter; ``model.max_input_context_len`` is
    resolved as a module-level global, presumably set by the entry point — verify.
    """
    # tcelongbench only: reuse a context the caller has already cached.
    if "tcelongbench" in data_source and cached_data and "cached_context" in cached_data:
        logging.debug(f"使用缓存的context for data_idx: {data_idx}")
        return cached_data["cached_context"]

    # TODO: the RAG setting is incomplete; different retrievers still need to be
    # designed and wired to their data sources (mainly wikidata and tcelongbench).
    def _get_article_from_md5(md5, article_dict=None):
        # Return (title, content) for an article md5; content is the article's
        # "Text" lines joined with newlines. Uses the pre-loaded article_dict when
        # given, otherwise reads the articles JSON from disk (backward compatibility).
        if article_dict is None:
            article_dict = read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json")
        
        article_info = article_dict[md5]
        title = article_info["Title"]
        content = '\n'.join(article_info["Text"])
        return title, content
        
    # wikidata: the item's "story" field is the context.
    if "wikidata" in data_source:
        if setting == "base" or setting == "with_context_and_few_shot":
            # Prefer pre-loaded data over reading the file on every call.
            if cached_data:
                context = cached_data["wikidata_data"][data_idx]["story"]
            else:
                # Fallback when nothing was pre-loaded (backward compatibility).
                data_pth = mapping_source_to_meta_data_path[data_source]
                data = read_json(data_pth)
                context = data[data_idx]["story"]
            return context
        elif setting == "RAG":
            # TODO: crawl Wikipedia pages to use as context. Until then this
            # branch falls through and the function returns None.
            pass
        elif setting == "closed_book":
            # Closed book: no context.
            return None
    elif "tcelongbench" in data_source:
        if setting == "base" or setting == "with_context_and_few_shot":
            if cached_data:
                all_day_md5_list = cached_data["tcelongbench_context"]
                article_dict = cached_data["tcelongbench_articles"]
            else:
                # Read from disk only when nothing was pre-loaded.
                all_day_md5_list = read_json("/home/weishaohang/workspace/Omni-Temp/contexts/tcelongbench_context_for_articles.json")
                article_dict = read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json")
                
            day_md5_list = list(all_day_md5_list[data_idx].items())[:-1]   # NOTE: (1) data_idx should be the ce_id string; (2) the last day's articles are not added to the context
            context_list = []
            # TODO: preload all articles, compare their size with the context window,
            # and trim the article count so that everything fits.
            def convert_day_to_natural_language(day):
                # Convert a YYYYMMDD day such as 20150219 into "February 19, 2015".
                day = int(day)
                year = day // 10000
                month = (day % 10000) // 100
                day = day % 100
                month_name = calendar.month_name[month]
                return f"{month_name} {day}, {year}"
            for day, md5_list in day_md5_list:
                for md5 in md5_list:
                    title, content = _get_article_from_md5(md5, article_dict)
                    context = f"Title: {title}, Day: {convert_day_to_natural_language(day)}\nContent: {content}"
                    context_list.append(context)
            context_str = '\n\n'.join(context_list)
            # Token count of the full context. The gpt-4 tokenizer (tiktoken) is used
            # as an approximation because most evaluated models use a similar BPE.
            toks_original_context = len(tiktoken.encoding_for_model("gpt-4").encode(context_str))
            # If the full context exceeds the window, rebuild it from a random subset.
            def _reduce_context_list(day_md5_list, context_window_size, reduce_ratio):
                # For each day, randomly sample max(len(md5_list) // reduce_ratio - 3, 1)
                # articles and format them. context_window_size is not used.
                # NOTE(review): random.sample raises ValueError for a day whose md5
                # list is empty, since the sample size is at least 1.
                new_context_list = []
                for day, md5_list in day_md5_list:
                    selected_md5_list = random.sample(md5_list, max(int(len(md5_list) // reduce_ratio)-3, 1))   # TODO:
                    for md5 in selected_md5_list:
                        title, content = _get_article_from_md5(md5, article_dict)
                        context = f"Title: {title}, Day: {convert_day_to_natural_language(day)}\nContent: {content}"
                        new_context_list.append(context)
                return new_context_list
        
            # Model's maximum input context length (read from the global `model`).
            context_window_size = model.max_input_context_len
            # Over the window: shrink the article list.
            if toks_original_context > context_window_size:
                # +1 guards against the reduced context landing exactly on the window size.
                reduce_ratio = toks_original_context // context_window_size + 1
                context_list = _reduce_context_list(day_md5_list, context_window_size, reduce_ratio)
                context_str = '\n\n'.join(context_list)
                # Extra safety margin: drop the last 2000 whitespace-delimited words.
                # The text is split into alternating non-space / space runs so the
                # kept prefix preserves the original whitespace and newlines.
                words_with_spaces = re.findall(r'(\S+|\s+)', context_str)
                # Number of words, excluding whitespace runs.
                word_count = sum(1 for w in words_with_spaces if not w.isspace())
                
                if word_count > 2000:
                    # Number of words to keep.
                    words_to_keep = word_count - 2000
                    # Words seen so far.
                    word_counter = 0
                    # Kept prefix (words plus the whitespace between them).
                    preserved_text = []
                    
                    for token in words_with_spaces:
                        if not token.isspace():  # a word
                            word_counter += 1
                            if word_counter > words_to_keep:
                                break
                        preserved_text.append(token)
                    
                    # Re-join the kept prefix.
                    context_str = ''.join(preserved_text)
                else:
                    print(f"警告：context_str中的单词数量({word_count})少于2000，无法删除末尾2000个单词")
                
            return context_str
        elif setting == "RAG":
            logging.info(f"使用RAG模式为tcelongbench生成context，data_idx: {data_idx}")
            logging.info(f"当前检索器类型: {retriever_type}")
            
            # First try the pre-generated chunks file. The question's position
            # (qa_idx) comes from cached_data, defaulting to 0.
            qa_idx = cached_data.get("qa_idx", 0) if cached_data else 0
            
            pre_generated_chunks = get_chunks_from_file(qtype, data_idx, qa_idx)
            
            if pre_generated_chunks is not None:
                logging.info(f"从预先生成的chunks文件中获取检索结果")
                
                # Format each stored chunk as "[i] Title: ..., Day: ...\nContent: ...".
                context_chunks = []
                
                for i, chunk in enumerate(pre_generated_chunks):
                    title = chunk["title"]
                    formatted_date = chunk["date"]
                    truncated_content = chunk["content"]
                    
                    chunk_text = f"[{i+1}] Title: {title}, Day: {formatted_date}\nContent: {truncated_content}"
                    context_chunks.append(chunk_text)
                
                context_str = '\n\n'.join(context_chunks)
                
                # Debug output: print the loaded chunks.
                print("\n" + "="*50)
                print(f"从预先生成的文件加载的chunks (共{len(context_chunks)}个):")
                print("-"*50)
                print(context_str)
                print("="*50 + "\n")
                
                return context_str, pre_generated_chunks
            
            logging.info(f"未找到预先生成的chunks，将进行实时检索")
            
            # No pre-generated chunks: retrieve live over this item's articles.
            if cached_data:
                all_day_md5_list = cached_data["tcelongbench_context"]
                article_dict = cached_data["tcelongbench_articles"]
            else:
                all_day_md5_list = read_json("/home/weishaohang/workspace/Omni-Temp/contexts/tcelongbench_context_for_articles.json")
                article_dict = read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json")
            
            # The query is the current question; if none was given, fall back to
            # the first L1_1 question of this data_idx (or "" if there is none).
            if current_question:
                query = current_question
                logging.info(f"使用当前问题作为查询: {query}")
            else:
                logging.warning(f"未提供当前问题，将尝试从QA数据中获取")
                qa_file_path = mapping_source_and_qtype_to_json_path[data_source]["L1_1"]  # the L1_1 question set is the query source
                qa_data = read_json(qa_file_path)
                
                if data_idx not in qa_data:
                    logging.warning(f"在QA数据中找不到data_idx: {data_idx}，使用空字符串作为查询")
                    query = ""
                else:
                    # Use the first question under this data_idx.
                    if qa_data[data_idx] and len(qa_data[data_idx]) > 0:
                        query = qa_data[data_idx][0]["Question"]
                        logging.info(f"使用该data_idx的第一个问题作为查询: {query}")
                    else:
                        query = ""
                        logging.warning(f"该data_idx下没有问题，使用空字符串作为查询")
            
            # One retrieval document per article.
            day_md5_list = list(all_day_md5_list[data_idx].items())[:-1]  # exclude the last day
            
            # Parallel lists: document text ("<title> <content>"), day key, md5.
            documents = []
            dates = []
            md5s = []
            
            for day, md5_list in day_md5_list:
                for md5 in md5_list:
                    try:
                        title, content = _get_article_from_md5(md5, article_dict)
                        doc_text = f"{title} {content}"
                        documents.append(doc_text)
                        dates.append(day)
                        md5s.append(md5)
                    except Exception as e:
                        logging.error(f"处理文章时出错: {e}")
            
            # Retrieve according to retriever_type; results are de-duplicated by md5.
            combined_results = []
            seen_md5s = set()
            
            if retriever_type == "bm25" or retriever_type == "hybrid":
                bm25_results = retrieve_with_bm25(documents, query, dates, md5s, article_dict, top_k=3)
                for res in bm25_results:
                    if res['md5'] not in seen_md5s:
                        combined_results.append(res)
                        seen_md5s.add(res['md5'])
                logging.info(f"BM25检索找到了 {len(bm25_results)} 个结果")
            
            if retriever_type == "vector" or retriever_type == "hybrid":
                vector_results = retrieve_with_vector(documents, query, dates, md5s, article_dict, top_k=3)
                # Add unseen vector hits while the total stays below 3. In hybrid mode
                # this means vector hits only fill slots BM25 left empty.
                for res in vector_results:
                    if res['md5'] not in seen_md5s and len(combined_results) < 3:
                        combined_results.append(res)
                        seen_md5s.add(res['md5'])
                logging.info(f"向量检索找到了 {len(vector_results)} 个结果")
            
            logging.info(f"最终组合了 {len(combined_results)} 个结果")
            
            # Order the chunks by their day key.
            combined_results.sort(key=lambda x: x['date'])
            
            # Build the final context; raw_chunks records the same chunks for the result dict.
            context_chunks = []
            raw_chunks = []
            
            for i, res in enumerate(combined_results):
                title, content = _get_article_from_md5(res['md5'], article_dict)
                day = res['date']
                
                # NOTE(review): convert_day_to_natural_language is only defined in the
                # base/few-shot branch, so it is never bound on this path; locals().get()
                # returns None and formatted_date is the raw day key. Confirm intent.
                formatted_date = convert_day_to_natural_language(day) if callable(locals().get('convert_day_to_natural_language')) else day
                
                # Chunk header with a 1-based index.
                chunk_header = f"[{i+1}] Title: {title}, Day: {formatted_date}\nContent: "
                
                # Each chunk is budgeted at 512 whitespace-delimited words, header included.
                header_words = len(chunk_header.split())
                
                max_content_words = 512 - header_words
                
                # Truncate the content to the budget (truncated words are re-joined
                # with single spaces, so newlines are lost).
                content_words = content.split()
                if len(content_words) > max_content_words:
                    truncated_content = ' '.join(content_words[:max_content_words])
                    logging.info(f"Chunk #{i+1} 内容被截断至 {max_content_words} 个单词")
                else:
                    truncated_content = content
                
                chunk = chunk_header + truncated_content
                context_chunks.append(chunk)
                
                # Raw chunk fields, returned alongside the context for the result dict.
                raw_chunks.append({
                    "title": title,
                    "date": formatted_date,
                    "content": truncated_content,
                    "source_md5": res['md5'],
                    "retrieval_score": res['score']
                })
            
            context_str = '\n\n'.join(context_chunks)
            
            # Rough size estimate in words.
            total_words = len(context_str.split())
            
            # NOTE(review): compares a word count against a token limit, so this only
            # approximates the model window; truncation also collapses all whitespace.
            if total_words > model.max_input_context_len:
                words = context_str.split()
                context_str = ' '.join(words[:model.max_input_context_len])
                logging.warning(f"即使每个chunk限制在512个单词，RAG context总长度仍然超过模型上下文窗口，被截断至{model.max_input_context_len}个单词")
            
            # Debug output: print the retrieved chunks.
            print("\n" + "="*50)
            print(f"检索到的chunks (共{len(context_chunks)}个):")
            print("-"*50)
            print(context_str)
            print("="*50 + "\n")
            
            return context_str, raw_chunks
        elif setting == "closed_book":
            # Closed book: no context.
            return None
    elif "long_dialog" in data_source:
        if setting == "base" or setting == "with_context_and_few_shot":
            if cached_data:
                context = cached_data["long_dialog_data"][data_idx]["context"]
            else:
                # Read from disk only when nothing was pre-loaded.
                data_pth = mapping_source_to_meta_data_path[data_source]
                data = read_json(data_pth)
                context = data[data_idx]["context"]
            return context
        elif setting == "RAG":
            # long_dialog has no RAG setting.
            raise ValueError("long_dialog没有RAG")
        elif setting == "closed_book":
            # long_dialog has no closed_book setting.
            raise ValueError("long_dialog没有closed_book")

# 添加RAG检索函数
def retrieve_with_bm25(documents, query, dates, md5s, article_dict, top_k=3):
    """Rank ``documents`` against ``query`` with BM25 and return the best ``top_k`` hits.

    Documents and the query are lower-cased and split on whitespace before scoring.

    Args:
        documents: List of document strings.
        query: Query string.
        dates: Dates aligned index-by-index with ``documents``.
        md5s: MD5 identifiers aligned index-by-index with ``documents``.
        article_dict: Not used by this function; kept so the signature matches
            ``retrieve_with_vector``.
        top_k: Maximum number of results. Values <= 0 yield an empty list.

    Returns:
        List of dicts with keys ``index`` (int), ``score`` (float), ``date`` and
        ``md5``, in descending score order. Hits whose index has no matching
        date/md5 entry are dropped. Returns ``[]`` if BM25 scoring fails.
    """
    logging.info("使用BM25进行检索...")

    # Guard explicitly: np.argsort(...)[-0:] would select *every* document, and
    # BM25Okapi divides by the corpus size, so an empty corpus would raise.
    if top_k <= 0 or not documents:
        logging.info("BM25检索完成，找到 0 个结果")
        return []

    tokenized_documents = [doc.lower().split() for doc in documents]
    tokenized_query = query.lower().split()

    try:
        bm25 = BM25Okapi(tokenized_documents)
        doc_scores = bm25.get_scores(tokenized_query)
    except Exception as e:
        logging.error(f"BM25检索失败: {e}")
        return []

    # argsort is ascending; reverse it and keep the first top_k indices.
    top_indices = np.argsort(doc_scores)[::-1][:top_k]

    results = []
    for idx in top_indices:
        # Plain int (not numpy int64) so the result stays JSON-serialisable.
        idx = int(idx)
        if idx < len(dates) and idx < len(md5s):
            results.append({
                'index': idx,
                'score': float(doc_scores[idx]),
                'date': dates[idx],
                'md5': md5s[idx]
            })

    logging.info(f"BM25检索完成，找到 {len(results)} 个结果")
    return results

def retrieve_with_vector(documents, query, dates, md5s, article_dict, top_k=3):
    """Rank ``documents`` by cosine similarity of OpenAI embeddings to ``query``.

    Args:
        documents: List of document strings.
        query: Query string.
        dates: Dates aligned index-by-index with ``documents``.
        md5s: MD5 identifiers aligned index-by-index with ``documents``.
        article_dict: Not used by this function; kept so the signature matches
            ``retrieve_with_bm25``.
        top_k: Maximum number of results. Values <= 0 yield an empty list.

    Returns:
        List of dicts with keys ``index`` (int), ``score`` (float), ``date`` and
        ``md5``, in descending similarity order. Documents whose embedding batch
        failed, or whose similarity is undefined, are never returned. Returns
        ``[]`` when OPENAI_API_KEY is unset or the query embedding fails.
    """
    logging.info("使用向量检索进行检索...")

    # Guard explicitly: np.argsort(...)[-0:] would select every document, and an
    # empty corpus needs no API call at all.
    if top_k <= 0 or not documents:
        logging.info("向量检索完成，找到 0 个结果")
        return []

    try:
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            logging.error("未设置OPENAI_API_KEY环境变量，无法使用向量检索")
            return []

        # NOTE(review): openai.Embedding is the pre-1.0 openai SDK interface; it is
        # not available in openai>=1.0 — confirm the pinned package version.
        openai.api_key = api_key

        query_response = openai.Embedding.create(
            model="text-embedding-ada-002",
            input=query
        )
        query_embedding = query_response['data'][0]['embedding']

        # Every document starts at -inf. A document whose batch fails keeps that
        # value, so it can never outrank a document that was actually embedded.
        # (The previous zero-vector padding produced NaN, which argsort places
        # last, so failed documents ended up ranked first.)
        similarities = np.full(len(documents), -np.inf)
        batch_size = 50  # documents per embeddings request

        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            try:
                response = openai.Embedding.create(
                    model="text-embedding-ada-002",
                    input=batch
                )
            except Exception as e:
                logging.error(f"获取批次 {start} 的嵌入向量时出错: {e}")
                continue
            # Assumes response['data'] is in the same order as the batch inputs.
            for offset, item in enumerate(response['data']):
                similarities[start + offset] = cosine_similarity(query_embedding, item['embedding'])

        # An undefined similarity (NaN, e.g. from a zero-norm embedding) must not
        # float to the top of a descending argsort.
        similarities[np.isnan(similarities)] = -np.inf

        top_indices = np.argsort(similarities)[::-1][:top_k]

        results = []
        for idx in top_indices:
            idx = int(idx)
            if similarities[idx] == -np.inf:
                continue  # no usable embedding for this document
            if idx < len(dates) and idx < len(md5s):
                results.append({
                    'index': idx,
                    'score': float(similarities[idx]),
                    'date': dates[idx],
                    'md5': md5s[idx]
                })

        logging.info(f"向量检索完成，找到 {len(results)} 个结果")
        return results

    except Exception as e:
        logging.error(f"向量检索失败: {e}")
        return []

def cosine_similarity(a, b):
    """Return the cosine similarity between vectors ``a`` and ``b`` as a float.

    ``a`` and ``b`` may be any array-likes of equal length (lists, numpy arrays).
    If either vector has zero norm, for example a placeholder embedding, the
    result is 0.0. A plain 0/0 division would instead give NaN and a RuntimeWarning.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)

def convert_day_to_natural_language(day):
    """Render a YYYYMMDD date (int or numeric string) as "<Month> <D>, <YYYY>".

    The month name comes from ``calendar.month_name``. The day of the month is
    printed without zero padding, e.g. 20230105 -> "January 5, 2023".
    """
    year, month_and_day = divmod(int(day), 10000)
    month, day_of_month = divmod(month_and_day, 100)
    return f"{calendar.month_name[month]} {day_of_month}, {year}"

######################## eval ########################
# NOTE 评估tcelongbench时，需要剔除最后一天的articles

def _generate_in_shuffled_order(model, prompts):
    """Run ``model.generate`` on ``prompts`` in random order; return predictions in input order.

    The permutation comes from the module-level ``random`` state, which ``main``
    seeds, so it is reproducible. Shuffling a list of indices consumes the RNG
    exactly like shuffling the list of ``(index, prompt)`` pairs it replaces.
    NOTE(review): this file does not state why prompts are shuffled. Presumably
    it decouples generation order from dataset order; confirm.

    Raises:
        RuntimeError: If the model returns a different number of predictions than
            prompts. Without this check the predictions would be silently misaligned
            or left as ``None``.
    """
    if not prompts:
        return []
    order = list(range(len(prompts)))
    random.shuffle(order)
    shuffled_preds = list(model.generate([prompts[i] for i in order]))
    if len(shuffled_preds) != len(order):
        raise RuntimeError(
            f"model.generate 返回了 {len(shuffled_preds)} 个预测，但输入了 {len(order)} 个prompt"
        )
    preds = [None] * len(prompts)
    for original_idx, pred in zip(order, shuffled_preds):
        preds[original_idx] = pred
    return preds


def gen_pred_results(model, setting, data_source, qtype, context_cache=None):
    """Generate and save predictions for one (data_source, qtype, setting) combination.

    The function builds one prompt per QA pair with ``make_prompt`` and runs all
    prompts through ``model.generate`` in a shuffled order. It then maps each
    prediction back to its QA pair and writes the results with ``dump_qa_data``.

    Args:
        model: Model wrapper exposing ``model_name`` and ``generate(prompts)``.
        setting: Evaluation setting, e.g. "base", "with_context_and_few_shot",
            "RAG" or "closed_book".
        data_source: Dataset identifier. Substring checks ("wikidata",
            "tcelongbench", "long_dialog") select the per-dataset behaviour.
        qtype: Question type passed to ``load_qa_data`` / ``dump_qa_data``.
        context_cache: Optional mapping data_idx -> precomputed context, used only
            for tcelongbench. When provided, the raw tcelongbench files are not
            preloaded here.
    """
    logging.info(f"开始生成 {data_source} {qtype} {setting} {model.model_name} 的预测结果...")
    print(f"\n\nGenerating predictions for {data_source} {qtype} {setting} {model.model_name} qa data...")

    QAs = load_qa_data(data_source, qtype)
    idx_list = list(QAs.keys()) if isinstance(QAs, dict) else range(len(QAs))

    is_tcelongbench = "tcelongbench" in data_source
    uses_full_context = setting == "base" or setting == "with_context_and_few_shot"

    # tcelongbench results are keyed by data_idx; other sources use positional lists.
    results = {data_idx: [] for data_idx in idx_list} if is_tcelongbench else [[] for _ in idx_list]

    # Preload every data file this run needs, so make_prompt does not re-read them per question.
    cached_data = {}
    if "wikidata" in data_source and uses_full_context:
        data_pth = mapping_source_to_meta_data_path[data_source]
        logging.info(f"预加载 wikidata 数据: {data_pth}")
        cached_data["wikidata_data"] = read_json(data_pth)

    if is_tcelongbench and uses_full_context and context_cache is None:
        # Only needed when the caller did not supply precomputed contexts.
        logging.info("预加载 tcelongbench 上下文和文章数据")
        cached_data["tcelongbench_context"] = read_json("/home/weishaohang/workspace/Omni-Temp/contexts/tcelongbench_context_for_articles.json")
        cached_data["tcelongbench_articles"] = read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json")

    if "long_dialog" in data_source and uses_full_context:
        data_pth = mapping_source_to_meta_data_path[data_source]
        logging.info(f"预加载 long_dialog 数据: {data_pth}")
        cached_data["long_dialog_data"] = read_json(data_pth)

    # Collect every prompt first so the model can process them in one call.
    all_prompts = []
    all_questions = []
    all_gold_answers = []
    all_retrieved_chunks = []  # None unless make_prompt returned retrieval chunks
    prompt_mapping = []  # (data_idx, qa_idx) for each prompt, to map predictions back

    for data_idx in tqdm(idx_list, desc=f"Collecting prompts for {data_source}/{qtype}/{setting}/{model.model_name}", total=len(idx_list)):
        for qa_idx, qa in enumerate(QAs[data_idx]):
            question = qa["Question"]
            gold_answer = qa["Gold Answer"]

            if is_tcelongbench and context_cache is not None:
                if data_idx in context_cache:
                    cached_data["cached_context"] = context_cache[data_idx]
                else:
                    # Drop any context left over from a previous data_idx, so it is
                    # never reused for a different data_idx.
                    cached_data.pop("cached_context", None)

            # qa_idx is used to look up this question's pre-generated RAG chunks.
            if is_tcelongbench and setting == "RAG":
                cached_data["qa_idx"] = qa_idx

            prompt, retrieved_chunks = make_prompt(setting, data_source, qtype, data_idx, question, cached_data)

            all_prompts.append(prompt)
            all_questions.append(question)
            all_gold_answers.append(gold_answer)
            all_retrieved_chunks.append(retrieved_chunks)
            prompt_mapping.append((data_idx, qa_idx))

    logging.info(f"为 {data_source} {qtype} 并行生成 {len(all_prompts)} 个预测 (已打乱顺序)...")
    all_pred_answers = _generate_in_shuffled_order(model, all_prompts)
    logging.info(f"完成 {data_source} {qtype} 的并行预测生成 (已打乱顺序)")

    for (data_idx, _qa_idx), question, gold_answer, pred_answer, retrieved_chunks in zip(
        prompt_mapping, all_questions, all_gold_answers, all_pred_answers, all_retrieved_chunks
    ):
        result_dict = {
            "Question": question,
            "Gold Answer": gold_answer,
            "Pred Answer": pred_answer
        }
        # Retrieved chunks are only recorded for the RAG setting.
        if setting == "RAG" and retrieved_chunks is not None:
            result_dict["Retrieved Chunks"] = retrieved_chunks
        results[data_idx].append(result_dict)

    dump_qa_data(results, model, setting, data_source, qtype)
    logging.info(f"完成 {data_source} {qtype} {setting} {model.model_name} 的预测结果生成")


# 生成所有QAs的预测结果
def gen_pred_results_for_all_qa_data(model, setting, data_source):
    """Generate predictions for every QA type of ``data_source`` under ``setting``.

    QA types whose results already exist (per ``check_if_results_exist``) are
    skipped. When ``data_source == "tcelongbench"`` and the setting is "base" or
    "with_context_and_few_shot", each data_idx's context is built once with
    ``make_context`` and shared across QA types. This happens only if at least
    one QA type still needs generating, because building the cache means reading
    two large JSON files and computing every context.
    """
    logging.info(f"开始为 {data_source} {setting} {model.model_name} 生成所有QA类型的预测结果...")

    # Decide which QA types actually need work before doing any expensive setup.
    pending_qtypes = []
    for qtype in mapping_source_to_qtype_list[data_source]:
        # TODO: L1_1 is skipped because it cannot be used for RAG.
        # NOTE(review): this skip applies to every setting, not only "RAG" — confirm intended.
        if data_source == "tcelongbench" and qtype == "L1_1":
            continue
        if check_if_results_exist(model, setting, data_source, qtype):
            logging.info(f"结果已存在，跳过评估: {data_source} {qtype} {setting} {model.model_name}")
        else:
            pending_qtypes.append(qtype)

    context_cache = {}
    if pending_qtypes and data_source == "tcelongbench" and (setting == "base" or setting == "with_context_and_few_shot"):
        logging.info("开始为tcelongbench数据源预缓存每个data_idx的context...")
        cached_data = {
            "tcelongbench_context": read_json("/home/weishaohang/workspace/Omni-Temp/contexts/tcelongbench_context_for_articles.json"),
            "tcelongbench_articles": read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json"),
        }

        # The first QA type's data supplies the set of data_idx values.
        first_qtype = mapping_source_to_qtype_list[data_source][0]
        idx_list = list(load_qa_data(data_source, first_qtype).keys())

        for data_idx in tqdm(idx_list, desc=f"为tcelongbench的每个data_idx缓存context"):
            context_cache[data_idx] = make_context(setting, data_source, data_idx, cached_data)

        logging.info(f"完成tcelongbench数据源context缓存，共缓存了{len(context_cache)}个data_idx的context")

    for qtype in pending_qtypes:
        # Release cached CUDA blocks between QA types to reduce fragmentation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gen_pred_results(model, setting, data_source, qtype, context_cache=context_cache if data_source == "tcelongbench" else None)

    logging.info(f"完成 {data_source} {setting} {model.model_name} 所有QA类型的预测结果生成")


def main():
    """Command-line entry point: resolve the evaluation config and run all QA types.

    Each option is resolved in this order: command line, then config file, then
    built-in default. The resolved config is saved as ``config.yaml`` in the result
    directory. The module-level ``retriever_type`` and ``model`` globals are set here.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If a required option is missing or has an invalid value.
    """
    global retriever_type
    global model

    parser = argparse.ArgumentParser(description='评估模型在不同数据集上的表现')
    parser.add_argument('--config', type=str, default='evaluation/config/eval.yaml',
                        help='配置文件路径')
    parser.add_argument('--model', type=str, default=None,
                        help='模型名称，会覆盖配置文件中的设置')
    parser.add_argument('--setting', type=str, default=None,
                        help='评估设定，会覆盖配置文件中的设置')
    parser.add_argument('--data_source', type=str, default=None,
                        help='数据源，会覆盖配置文件中的设置')
    # default=None rather than DEFAULT_RETRIEVER_TYPE, so a retriever_type set in
    # the config file is honoured when the flag is omitted.
    parser.add_argument('--retriever_type', type=str, default=None,
                        choices=RETRIEVER_TYPE_OPTIONS,
                        help='检索器类型: bm25(仅使用BM25), vector(仅使用向量检索), hybrid(混合使用)')
    args = parser.parse_args()

    config_path = Path(args.config)
    if not config_path.exists():
        raise FileNotFoundError(f"配置文件 {args.config} 不存在")

    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML file loads as None; treat it as an empty config.
        config = yaml.safe_load(f) or {}

    # Command-line values override the config file; each of these keys is required.
    for key in ("model", "setting", "data_source"):
        cli_value = getattr(args, key)
        if cli_value is not None:
            config[key] = cli_value
        if key not in config:
            raise ValueError(f"缺少必需的配置项 '{key}'，请在配置文件中设置或通过 --{key} 传入")

    # Retriever type: CLI > valid config value > current module default.
    config_retriever = config.get("retriever_type")
    invalid_config_retriever = None
    if args.retriever_type is not None:
        retriever_type = args.retriever_type
    elif config_retriever in RETRIEVER_TYPE_OPTIONS:
        retriever_type = config_retriever
    elif config_retriever is not None:
        invalid_config_retriever = config_retriever
    config["retriever_type"] = retriever_type

    if config["setting"] not in setting_list:
        raise ValueError(f"无效的设定: {config['setting']}，可用选项: {setting_list}")
    if config["data_source"] not in data_source_list:
        raise ValueError(f"无效的数据源: {config['data_source']}，可用选项: {data_source_list}")
    if config["setting"] not in mapping_source_to_setting_list[config["data_source"]]:
        raise ValueError(f"数据源 {config['data_source']} 不支持设定 {config['setting']}")

    # Configure logging before the first logging.info call. Otherwise those
    # messages hit the unconfigured root logger (WARNING level) and are dropped.
    setup_logger(config["model"], config["setting"], config["data_source"])

    if invalid_config_retriever is not None:
        logging.warning(f"配置文件中的检索器类型 {invalid_config_retriever} 无效，改用默认值: {retriever_type}")
    logging.info(f"使用检索器类型: {retriever_type}")
    logging.info(f"Chunks将从以下目录读取: {os.path.join(CHUNKS_BASE_DIR, retriever_type)}")

    logging.info("评估配置信息:")
    logging.info(f"  - 配置文件: {args.config}")
    logging.info(f"  - 模型: {config['model']}")
    logging.info(f"  - 设定: {config['setting']}")
    logging.info(f"  - 数据源: {config['data_source']}")

    print(f"正在使用以下配置进行评估:")
    print(f"  - 配置文件: {args.config}")
    print(f"  - 模型: {config['model']}")
    print(f"  - 设定: {config['setting']}")
    print(f"  - 数据源: {config['data_source']}")

    model = MODEL(config["model"])
    setting = config["setting"]
    data_source = config["data_source"]

    # Fixed seeds make the prompt shuffling in gen_pred_results reproducible.
    random.seed(42)
    np.random.seed(42)
    logging.info(f"设置随机种子: 42")

    if setting == "RAG":
        # RAG results get an extra per-retriever sub-directory.
        result_dir = Path(f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{retriever_type}/{data_source}")
    else:
        result_dir = Path(f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{data_source}")

    os.makedirs(result_dir, exist_ok=True)
    logging.info(f"创建结果目录: {result_dir}")

    with open(result_dir / "config.yaml", "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
    logging.info(f"配置已保存到: {result_dir / 'config.yaml'}")

    try:
        gen_pred_results_for_all_qa_data(model, setting, data_source)
        logging.info("评估完成")
    except Exception as e:
        logging.error(f"评估过程中发生错误: {str(e)}", exc_info=True)
        raise

# Script entry point: run the evaluation only when this file is executed directly.
if __name__ == "__main__":
    main()