import os 
import json

import tiktoken
from model import *
from utils import *
from tqdm import tqdm
import yaml
import argparse
from pathlib import Path
import logging
import datetime
import sys
import random
import numpy as np
import re

######################## logging setup ########################
def setup_logger(model_name, setting, data_source, logs_dir=None):
    """Configure the root logger for one evaluation run and return it.

    Every handler already attached to the root logger is removed and closed,
    then two handlers are installed: a UTF-8 file handler recording INFO and
    above, and a stdout handler that only shows WARNING and above.

    Args:
        model_name: Model identifier; first component of the log file name.
        setting: Evaluation setting; second component of the log file name.
        data_source: Data source name; third component of the log file name.
        logs_dir: Directory that receives the log file. When omitted, the
            project's default ``logs`` directory is used.

    Returns:
        The configured root logger.
    """
    if logs_dir is None:
        logs_dir = "/home/weishaohang/workspace/Omni-Temp/logs"
    logs_dir = Path(logs_dir)

    # Timestamped file name so that successive runs never overwrite each other.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f"{model_name}_{setting}_{data_source}_{timestamp}.log"
    log_filepath = logs_dir / log_filename

    # Create the parent of the log file rather than just logs_dir: a model name
    # containing "/" puts the file in a sub-directory that must exist before
    # FileHandler can open it.
    os.makedirs(log_filepath.parent, exist_ok=True)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Drop handlers left over from a previous call so messages are not
    # duplicated, and close them so their file descriptors are released.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler: keeps the full INFO-level trace of the run.
    file_handler = logging.FileHandler(log_filepath, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    # Console handler: WARNING and above only, so INFO messages stay out of stdout.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.WARNING)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # Record the run header in the new log file.
    logging.info(f"开始评估 - 模型: {model_name}, 设定: {setting}, 数据源: {data_source}")
    logging.info(f"日志文件保存在: {log_filepath}")

    return logger

######################## preliminaries and prompt paths ########################
# Prompt type used for each (data source, question type) pair.
# NOTE: the original table marked most "free_form_qa" entries with an
# unexplained "TODO:"; the assignments themselves are unchanged here.

# Question types L1_1 to L1_5 use the same prompt types for every data source.
_l1_prompt_types = {
    "L1_1": "multi_choice_qa",
    "L1_2": "free_form_qa_for_time_expression",
    "L1_3": "free_form_qa",
    "L1_4": "single_choice_qa",
    "L1_5": "single_choice_qa",
}

# Question types L3_2 to L3_4 are likewise identical across data sources.
_l3_shared_prompt_types = {
    "L3_2": "free_form_qa",
    "L3_3": "single_choice_qa_for_forecast",
    "L3_4": "free_form_qa_for_false_premise",
}

mapping_source_and_qtype_to_prompt_type = {
    "wikidata": {
        **_l1_prompt_types,
        "L2_1": "free_form_qa",
        "L2_2": "free_form_qa",
        "L2_3": "free_form_qa_with_refusal",
        "L3_1": "free_form_qa_with_refusal",
        **_l3_shared_prompt_types,
    },
    "tcelongbench": {
        **_l1_prompt_types,
        "L2_1": "free_form_qa",
        "L2_1_multi_choice": "single_choice_qa",
        "L2_2": "free_form_qa",
        "L2_2_multi_choice": "single_choice_qa",
        "L2_3": "free_form_qa",
        "L2_3_multi_choice": "single_choice_qa",
        "L3_1": "free_form_qa",
        "L3_1_multi_choice": "single_choice_qa",
        **_l3_shared_prompt_types,
        "L3_4_multi_choice": "single_choice_qa_for_false_premise",
    },
    "long_dialog": {
        **_l1_prompt_types,
        # Unlike the other sources, long_dialog asks L2_1 as single choice.
        "L2_1": "single_choice_qa",
        "L2_2": "free_form_qa",
        "L2_2_multi_choice": "single_choice_qa",
        "L2_3": "free_form_qa",
        "L2_3_multi_choice": "single_choice_qa",
        "L3_1": "free_form_qa",
        "L3_1_multi_choice": "single_choice_qa",
        **_l3_shared_prompt_types,
        "L3_4_multi_choice": "single_choice_qa_for_false_premise",
    },
}

# Prompt-template file for each (setting, prompt type) pair. Every setting has
# its own directory holding one "<prompt_type>.txt" file per prompt type.
_prompt_template_root = "/home/weishaohang/workspace/Omni-Temp/prompts/evaluation"

# Directory name used by each evaluation setting.
_setting_to_prompt_dir = {
    "base": "with_context_zero_shot",  # base setting: context provided, zero-shot
    "with_context_and_few_shot": "with_context_few_shot",
    "RAG": "RAG",
    "closed_book": "closed_book",
}

# Prompt types available in every setting directory.
_prompt_type_names = [
    "multi_choice_qa",
    "free_form_qa",
    "free_form_qa_with_refusal",
    "free_form_qa_for_time_expression",
    "multi_choice_qa_for_forecast",
    "free_form_qa_for_false_premise",
    "multi_choice_qa_for_false_premise",
    "single_choice_qa",
    "single_choice_qa_for_false_premise",
    "single_choice_qa_for_forecast",
]

mapping_setting_prompt_type_to_prompt_path = {
    setting_name: {
        prompt_type: f"{_prompt_template_root}/{prompt_dir}/{prompt_type}.txt"
        for prompt_type in _prompt_type_names
    }
    for setting_name, prompt_dir in _setting_to_prompt_dir.items()
}

# QA JSON file for each (data source, question type) pair. Files live in
# "QAs_<source>/": "<qtype>_QAs.json" for the regular question types and
# "<qtype>_QAs_multi_choice.json" for the "<qtype>_multi_choice" variants.
_qa_project_root = "/home/weishaohang/workspace/Omni-Temp"

# Regular question types, present for every data source.
_regular_qtypes = [
    "L1_1", "L1_2", "L1_3", "L1_4", "L1_5",
    "L2_1", "L2_2", "L2_3",
    "L3_1", "L3_2", "L3_3", "L3_4",
]

# Question types that additionally have a multi-choice variant, per source.
_source_to_multi_choice_qtypes = {
    "wikidata": [],
    "tcelongbench": ["L2_1", "L2_2", "L2_3", "L3_1", "L3_4"],
    "long_dialog": ["L2_2", "L2_3", "L3_1", "L3_4"],
}

mapping_source_and_qtype_to_json_path = {
    source_name: {
        **{
            qtype: f"{_qa_project_root}/QAs_{source_name}/{qtype}_QAs.json"
            for qtype in _regular_qtypes
        },
        **{
            f"{qtype}_multi_choice": f"{_qa_project_root}/QAs_{source_name}/{qtype}_QAs_multi_choice.json"
            for qtype in multi_choice_qtypes
        },
    }
    for source_name, multi_choice_qtypes in _source_to_multi_choice_qtypes.items()
}

# Data sources that can be evaluated.
data_source_list = [
    "wikidata",
    "tcelongbench",
    "long_dialog",
]

# Evaluation settings that can be requested.
setting_list = [
    "base",
    "with_context_and_few_shot",
    "RAG",
    "closed_book",
]

# Question types that are evaluated for each data source, in evaluation order.
mapping_source_to_qtype_list = {
    "wikidata": [
        "L1_1", "L1_2", "L1_3", "L1_4", "L1_5",
        "L2_1", "L2_2", "L2_3",
        "L3_1", "L3_2", "L3_3", "L3_4",
    ],
    "tcelongbench": [
        "L1_1", "L1_2", "L1_3", "L1_4", "L1_5",
        "L2_1", "L2_1_multi_choice",
        "L2_2", "L2_2_multi_choice",
        "L2_3", "L2_3_multi_choice",
        "L3_1", "L3_1_multi_choice",
        "L3_2", "L3_3",
        "L3_4", "L3_4_multi_choice",
    ],
    # Disabled full list for long_dialog, kept for reference:
    # "long_dialog": ["L1_1", "L1_2", "L1_3", "L1_4", "L1_5", "L2_1", "L2_2", "L2_2_multi_choice", "L2_3", "L2_3_multi_choice", "L3_1", "L3_1_multi_choice", "L3_2", "L3_3", "L3_4", "L3_4_multi_choice"],
    "long_dialog": [
        "L1_1", "L1_2", "L1_3", "L1_4", "L1_5",
        "L2_1",
        "L2_2_multi_choice",
        "L2_3_multi_choice",
        "L3_1_multi_choice",
        "L3_2",
        "L3_4_multi_choice",
    ],
}

# Raw (meta) data file that supplies the context for each data source.
mapping_source_to_meta_data_path = {
    "wikidata": "/home/weishaohang/workspace/Omni-Temp/wiki_data_processing/data_with_bank_timeline/data_with_timeline.json",
    "tcelongbench": "/home/weishaohang/workspace/Omni-Temp/meta_data/tcelongbench.json",
    "long_dialog": "/home/weishaohang/workspace/Omni-Temp/long_dialog_processing/raw_data/long_dialog.json",
}

# Settings each data source supports; long_dialog has no RAG or closed-book mode.
mapping_source_to_setting_list = {
    "wikidata": [
        "base",
        "with_context_and_few_shot",
        "RAG",
        "closed_book",
    ],
    "tcelongbench": [
        "base",
        "with_context_and_few_shot",
        "RAG",
        "closed_book",
    ],
    "long_dialog": [
        "base",
        "with_context_and_few_shot",
    ],
}

######################## load and dump data ########################
def load_prompt(setting, data_source, qtype):
    """Return the text of the prompt template for one question type.

    The prompt type is looked up from the data source and question type, and
    the template file is then looked up from the setting and that prompt type.

    Args:
        setting: Evaluation setting name.
        data_source: Data source name.
        qtype: Question type identifier.

    Returns:
        The full content of the template file, read as UTF-8 text.

    Raises:
        KeyError: If any of the lookups has no entry for the given names.
    """
    kind_of_prompt = mapping_source_and_qtype_to_prompt_type[data_source][qtype]
    template_file = mapping_setting_prompt_type_to_prompt_path[setting][kind_of_prompt]
    template_text = Path(template_file).read_text(encoding="utf-8")
    logging.debug(f"加载提示模板: {template_file}")
    return template_text

def load_qa_data(data_source, qtype):
    """Load the QA items for one data source and question type.

    Args:
        data_source: Data source name.
        qtype: Question type identifier.

    Returns:
        The parsed content of the QA JSON file registered for this pair.

    Raises:
        KeyError: If no QA file is registered for the given pair.
    """
    file_path = mapping_source_and_qtype_to_json_path[data_source][qtype]
    logging.info(f"加载QA数据: {file_path}")
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def check_if_results_exist(model, setting, data_source, qtype):
    """Tell whether a results file has already been written for this run.

    Args:
        model: Object whose ``model_name`` attribute names the results folder.
        setting: Evaluation setting name.
        data_source: Data source name.
        qtype: Question type identifier.

    Returns:
        True if the results JSON file exists on disk, False otherwise.
    """
    results_file = (
        "/home/weishaohang/workspace/Omni-Temp/results/"
        f"{model.model_name}/{setting}/{data_source}/{qtype}.json"
    )
    logging.info(f"检查结果是否存在: {results_file}")
    already_written = os.path.exists(results_file)
    return already_written

def dump_qa_data(results, model, setting, data_source, qtype):
    """Write the prediction results of one question type to its JSON file.

    The target directory is created (recursively) when it does not exist yet.

    Args:
        results: JSON-serialisable results structure to store.
        model: Object whose ``model_name`` attribute names the results folder.
        setting: Evaluation setting name.
        data_source: Data source name.
        qtype: Question type identifier; used as the file name.
    """
    logging.info(f"保存 {data_source} {qtype} {setting} {model.model_name} 的QA结果...")
    out_file = Path(
        f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{data_source}/{qtype}.json"
    )
    # Make sure every missing parent directory exists before writing.
    out_file.parent.mkdir(parents=True, exist_ok=True)
    with out_file.open("w", encoding="utf-8") as sink:
        json.dump(results, sink, indent=4, ensure_ascii=False)
    logging.info(f"结果已保存到: {out_file}")
    print(f"Dumped {data_source} {qtype} {setting} {model.model_name} qa data to {out_file}")

def make_prompt(setting, data_source, qtype, data_idx, question, cached_data=None):
    """Fill the prompt template of a question with its context and question text.

    For the "tcelongbench" data source the finished prompt is tokenised with
    the gpt-4 tokenizer and rejected when it exceeds the input limit of the
    module-level ``model`` object (which ``main`` assigns before evaluation).

    Args:
        setting: Evaluation setting name.
        data_source: Data source name.
        qtype: Question type identifier.
        data_idx: Identifier of the data item the question belongs to.
        question: Question text inserted into the template.
        cached_data: Optional pre-loaded data handed on to ``make_context``.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If a tcelongbench prompt has more tokens than
            ``model.max_input_context_len``.
    """
    template = load_prompt(setting, data_source, qtype)
    filled_prompt = template.format(
        context=make_context(setting, data_source, data_idx, cached_data),
        question=question,
    )

    # Only tcelongbench prompts are checked against the context window.
    if data_source != "tcelongbench":
        return filled_prompt

    n_prompt_tokens = len(tiktoken.encoding_for_model("gpt-4").encode(filled_prompt))
    if n_prompt_tokens > model.max_input_context_len:
        print(f"prompt的tokens数量: {n_prompt_tokens}")
        raise ValueError("prompt的tokens数量超过了模型的最大输入上下文长度")
    return filled_prompt

######################## make context ########################
def make_context(setting, data_source, data_idx, cached_data=None):
    """Build the context passed to the prompt for one data item.

    Behaviour per data source:

    * wikidata: the item's "story" field ("base" / "with_context_and_few_shot").
    * tcelongbench: the articles of every day but the last one, rendered as
      "Title / Day / Content" blocks. When the text has more gpt-4 tokens than
      the module-level ``model`` accepts, articles are randomly sub-sampled per
      day and the last 2000 words are cut off.
    * long_dialog: the item's "context" field.

    "closed_book" yields no context (None). "RAG" is not implemented yet and
    also yields None, except for long_dialog where it is rejected.

    Args:
        setting: Evaluation setting name.
        data_source: Data source name (matched by substring).
        data_idx: Item identifier; used to index the loaded data. For
            tcelongbench the original author notes it should be the ce_id string.
        cached_data: Optional dict of pre-loaded data. Recognised keys are
            "cached_context", "wikidata_data", "tcelongbench_context",
            "tcelongbench_articles" and "long_dialog_data". When it is empty
            or None, the required files are read from disk.

    Returns:
        The context string, or None when the setting provides no context.

    Raises:
        ValueError: For long_dialog with the "RAG" or "closed_book" setting.
    """
    # Imported here because this file never imports `calendar` itself; the date
    # formatting below would otherwise depend on a star-import leaking the name.
    import calendar

    # A context pre-computed for this tcelongbench item takes precedence.
    if "tcelongbench" in data_source and cached_data and "cached_context" in cached_data:
        logging.debug(f"使用缓存的context for data_idx: {data_idx}")
        return cached_data["cached_context"]

    # TODO: the RAG setting is unfinished; it still needs retrievers and a
    # connection to the data sources (mainly wikidata and tcelongbench).
    def _get_article_from_md5(md5, article_dict=None):
        # Prefer the pre-loaded article dict; read the file only as a fallback.
        if article_dict is None:
            article_dict = read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json")

        article_info = article_dict[md5]
        title = article_info["Title"]
        content = '\n'.join(article_info["Text"])
        return title, content

    def convert_day_to_natural_language(day):
        # Turn e.g. 20150219 into "February 19, 2015".
        day = int(day)
        year = day // 10000
        month = (day % 10000) // 100
        day = day % 100
        month_name = calendar.month_name[month]
        return f"{month_name} {day}, {year}"

    with_context = setting == "base" or setting == "with_context_and_few_shot"

    # wikidata: the "story" stored with the item is the context.
    if "wikidata" in data_source:
        if with_context:
            if cached_data:
                context = cached_data["wikidata_data"][data_idx]["story"]
            else:
                # No pre-loaded data: read the file (kept for older callers).
                data_pth = mapping_source_to_meta_data_path[data_source]
                data = read_json(data_pth)
                context = data[data_idx]["story"]
            return context
        elif setting == "RAG":
            # TODO: crawl Wikipedia and use the retrieved text as context.
            return None
        elif setting == "closed_book":
            # Closed book: no context by definition.
            return None
    elif "tcelongbench" in data_source:
        if with_context:
            if cached_data:
                all_day_md5_list = cached_data["tcelongbench_context"]
                article_dict = cached_data["tcelongbench_articles"]
            else:
                # No pre-loaded data: read both files from disk.
                all_day_md5_list = read_json("/home/weishaohang/workspace/Omni-Temp/contexts/tcelongbench_context_for_articles.json")
                article_dict = read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json")

            # NOTE: the articles of the last day are deliberately left out.
            day_md5_list = list(all_day_md5_list[data_idx].items())[:-1]

            def _render_article(md5, day):
                # One "Title / Day / Content" block for a single article.
                title, content = _get_article_from_md5(md5, article_dict)
                return f"Title: {title}, Day: {convert_day_to_natural_language(day)}\nContent: {content}"

            def _reduce_context_list(day_md5_list, reduce_ratio):
                # Randomly keep a fraction of each day's articles (at least one).
                new_context_list = []
                for day, md5_list in day_md5_list:
                    if not md5_list:
                        # random.sample cannot draw one item from an empty
                        # list; a day without articles simply adds nothing.
                        continue
                    sample_size = max(int(len(md5_list) // reduce_ratio) - 3, 1)   # TODO: heuristic from the original author
                    for md5 in random.sample(md5_list, sample_size):
                        new_context_list.append(_render_article(md5, day))
                return new_context_list

            context_list = [
                _render_article(md5, day)
                for day, md5_list in day_md5_list
                for md5 in md5_list
            ]
            context_str = '\n\n'.join(context_list)

            # Token count uses the gpt-4 tokenizer as an approximation for the
            # evaluated models (per the original author, most use a similar BPE).
            toks_original_context = len(tiktoken.encoding_for_model("gpt-4").encode(context_str))

            # Input limit of the model currently under evaluation.
            context_window_size = model.max_input_context_len
            if toks_original_context > context_window_size:
                # +1 so that an exact multiple of the window still gets reduced.
                reduce_ratio = toks_original_context // context_window_size + 1
                context_list = _reduce_context_list(day_md5_list, reduce_ratio)
                context_str = '\n\n'.join(context_list)

                # Cut the last 2000 words while keeping the original whitespace
                # (including newlines): split into alternating word/space runs.
                words_with_spaces = re.findall(r'(\S+|\s+)', context_str)
                word_count = sum(1 for w in words_with_spaces if not w.isspace())

                if word_count > 2000:
                    words_to_keep = word_count - 2000
                    word_counter = 0
                    preserved_text = []

                    for token in words_with_spaces:
                        if not token.isspace():
                            word_counter += 1
                            if word_counter > words_to_keep:
                                break
                        preserved_text.append(token)

                    context_str = ''.join(preserved_text)
                else:
                    print(f"警告：context_str中的单词数量({word_count})少于2000，无法删除末尾2000个单词")

            return context_str
        elif setting == "RAG":
            # TODO: retrieve passages from all articles and use them as context.
            return None
        elif setting == "closed_book":
            # Closed book: no context by definition.
            return None
    elif "long_dialog" in data_source:
        if with_context:
            if cached_data:
                context = cached_data["long_dialog_data"][data_idx]["context"]
            else:
                # No pre-loaded data: read the file from disk.
                data_pth = mapping_source_to_meta_data_path[data_source]
                data = read_json(data_pth)
                context = data[data_idx]["context"]
            return context
        elif setting == "RAG":
            # long_dialog has no RAG setting.
            raise ValueError("long_dialog没有RAG")
        elif setting == "closed_book":
            # long_dialog has no closed_book setting.
            raise ValueError("long_dialog没有closed_book")



######################## eval ########################
# NOTE: when evaluating tcelongbench, the articles of the last day must be excluded from the context.

def gen_pred_results(model, setting, data_source, qtype, context_cache=None):
    """Generate and store model predictions for one question type.

    All prompts are built first, shuffled, sent to ``model.generate`` in a
    single call, and the answers are put back into their original order
    before being saved with ``dump_qa_data``.

    Args:
        model: Model wrapper providing ``model_name`` and ``generate(prompts)``.
        setting: Evaluation setting name.
        data_source: Data source name.
        qtype: Question type identifier.
        context_cache: Optional mapping from data_idx to a pre-built context
            string; only consulted for the tcelongbench data source.

    Results are stored as a dict keyed by data_idx for tcelongbench and as a
    list of lists otherwise; every entry holds "Question", "Gold Answer" and
    "Pred Answer".
    """
    logging.info(f"开始生成 {data_source} {qtype} {setting} {model.model_name} 的预测结果...")
    print(f"\n\nGenerating predictions for {data_source} {qtype} {setting} {model.model_name} qa data...")

    QAs = load_qa_data(data_source, qtype)

    # Item identifiers: dict keys, or positions when the QA file is a list.
    idx_list = list(QAs.keys()) if isinstance(QAs, dict) else range(len(QAs))

    is_tcelongbench = "tcelongbench" in data_source
    with_context = setting == "base" or setting == "with_context_and_few_shot"

    # The result container mirrors the data source's QA layout.
    results = {data_idx: [] for data_idx in idx_list} if is_tcelongbench else [[] for _ in idx_list]

    # Pre-load every data file once instead of re-reading it per question.
    cached_data = {}

    def _preload_tcelongbench_raw():
        # Raw day->articles index and article bodies needed to build a context.
        logging.info("预加载 tcelongbench 上下文和文章数据")
        cached_data["tcelongbench_context"] = read_json("/home/weishaohang/workspace/Omni-Temp/contexts/tcelongbench_context_for_articles.json")
        cached_data["tcelongbench_articles"] = read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json")

    if "wikidata" in data_source and with_context:
        data_pth = mapping_source_to_meta_data_path[data_source]
        logging.info(f"预加载 wikidata 数据: {data_pth}")
        cached_data["wikidata_data"] = read_json(data_pth)

    if is_tcelongbench and with_context and context_cache is None:
        # Raw data is only needed up front when no context cache was supplied.
        _preload_tcelongbench_raw()

    if "long_dialog" in data_source and with_context:
        data_pth = mapping_source_to_meta_data_path[data_source]
        logging.info(f"预加载 long_dialog 数据: {data_pth}")
        cached_data["long_dialog_data"] = read_json(data_pth)

    # Collected in parallel lists so the whole batch can be generated at once.
    all_prompts = []
    all_questions = []
    all_gold_answers = []
    # (data_idx, qa_idx) of every prompt, used to map answers back afterwards.
    prompt_mapping = []

    for data_idx in tqdm(idx_list, desc=f"Collecting prompts for {data_source}/{qtype}/{setting}/{model.model_name}", total=len(idx_list)):
        qas = QAs[data_idx]

        if is_tcelongbench and context_cache is not None:
            if data_idx in context_cache:
                # Hand the pre-built context of this item to make_prompt.
                cached_data["cached_context"] = context_cache[data_idx]
            else:
                # Remove the previous item's context: leaving it in place would
                # silently pair this item's questions with the wrong articles.
                cached_data.pop("cached_context", None)
                # The context must now be built from the raw data; load it
                # once rather than letting every question re-read the files.
                if with_context and "tcelongbench_context" not in cached_data:
                    _preload_tcelongbench_raw()

        for qa_idx, qa in enumerate(qas):
            question = qa["Question"]
            gold_answer = qa["Gold Answer"]

            prompt = make_prompt(setting, data_source, qtype, data_idx, question, cached_data)

            all_prompts.append(prompt)
            all_questions.append(question)
            all_gold_answers.append(gold_answer)
            prompt_mapping.append((data_idx, qa_idx))

    # Shuffle the prompts while remembering each one's original position
    # (the random seed is set by the caller).
    indexed_prompts = list(enumerate(all_prompts))
    random.shuffle(indexed_prompts)
    original_indices_shuffled = [i for i, _ in indexed_prompts]
    shuffled_prompts = [p for _, p in indexed_prompts]

    logging.info(f"为 {data_source} {qtype} 并行生成 {len(shuffled_prompts)} 个预测 (已打乱顺序)...")
    shuffled_pred_answers = model.generate(shuffled_prompts)
    logging.info(f"完成 {data_source} {qtype} 的并行预测生成 (已打乱顺序)")

    # Undo the shuffle so predictions line up with questions and gold answers.
    all_pred_answers = [None] * len(all_prompts)
    for original_idx, pred in zip(original_indices_shuffled, shuffled_pred_answers):
        all_pred_answers[original_idx] = pred

    # Map every prediction back into the per-item result structure.
    for i, (pred_answer, question, gold_answer) in enumerate(zip(all_pred_answers, all_questions, all_gold_answers)):
        data_idx, qa_idx = prompt_mapping[i]
        results[data_idx].append({
            "Question": question,
            "Gold Answer": gold_answer,
            "Pred Answer": pred_answer
        })

    dump_qa_data(results, model, setting, data_source, qtype)
    logging.info(f"完成 {data_source} {qtype} {setting} {model.model_name} 的预测结果生成")


# Generate prediction results for every question type of one data source
def gen_pred_results_for_all_qa_data(model, setting, data_source):
    """Run prediction generation for every question type of a data source.

    For tcelongbench with a context-based setting, the context of each
    data_idx is built once up front and shared by all question types. Question
    types whose results file already exists are skipped.

    Args:
        model: Model wrapper passed on to ``gen_pred_results``.
        setting: Evaluation setting name.
        data_source: Data source name.
    """
    logging.info(f"开始为 {data_source} {setting} {model.model_name} 生成所有QA类型的预测结果...")

    is_tcelongbench = data_source == "tcelongbench"
    qtypes = mapping_source_to_qtype_list[data_source]

    # Per-data_idx context cache; only ever filled for tcelongbench.
    context_cache = {}
    if is_tcelongbench and setting in ("base", "with_context_and_few_shot"):
        logging.info("开始为tcelongbench数据源预缓存每个data_idx的context...")
        # Raw data needed by make_context to assemble each context.
        raw_data = {
            "tcelongbench_context": read_json("/home/weishaohang/workspace/Omni-Temp/contexts/tcelongbench_context_for_articles.json"),
            "tcelongbench_articles": read_json("/home/weishaohang/workspace/Omni-Temp/tcelongbench_processing/raw_data_dir/TCE_News_Articles.json"),
        }

        # The first question type's QA file supplies the set of data_idx values.
        first_qtype_qas = load_qa_data(data_source, qtypes[0])
        all_data_idx = list(first_qtype_qas.keys())

        for data_idx in tqdm(all_data_idx, desc=f"为tcelongbench的每个data_idx缓存context"):
            context_cache[data_idx] = make_context(setting, data_source, data_idx, raw_data)

        logging.info(f"完成tcelongbench数据源context缓存，共缓存了{len(context_cache)}个data_idx的context")

    # Other data sources do not use the context cache at all.
    cache_for_source = context_cache if is_tcelongbench else None

    for qtype in qtypes:
        if check_if_results_exist(model, setting, data_source, qtype):
            logging.info(f"结果已存在，跳过评估: {data_source} {qtype} {setting} {model.model_name}")
            continue

        # Release cached CUDA memory between question types to limit fragmentation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        gen_pred_results(model, setting, data_source, qtype, context_cache=cache_for_source)

    logging.info(f"完成 {data_source} {setting} {model.model_name} 所有QA类型的预测结果生成")


def main():
    """Command-line entry point: load the config, then run the evaluation.

    The YAML config supplies "model", "setting" and "data_source"; each can be
    overridden by the matching command-line option. After validation the
    logger is set up, the model is created and stored in the module-level
    ``model`` variable (read by the prompt/context helpers), the random seeds
    are fixed, the effective config is saved next to the results, and
    predictions are generated for every question type.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If a required config field is missing, or the setting /
            data source is unknown or not supported in combination.
    """
    parser = argparse.ArgumentParser(description='评估模型在不同数据集上的表现')
    parser.add_argument('--config', type=str, default='evaluation/config/eval.yaml',
                        help='配置文件路径')
    parser.add_argument('--model', type=str, default=None,
                        help='模型名称，会覆盖配置文件中的设置')
    parser.add_argument('--setting', type=str, default=None,
                        help='评估设定，会覆盖配置文件中的设置')
    parser.add_argument('--data_source', type=str, default=None,
                        help='数据源，会覆盖配置文件中的设置')
    args = parser.parse_args()

    config_path = Path(args.config)
    if not config_path.exists():
        raise FileNotFoundError(f"配置文件 {args.config} 不存在")

    # Close the config file deterministically; an empty YAML file parses to
    # None, which is treated as an empty config so overrides still apply.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}

    # Command-line options take precedence over the config file.
    if args.model is not None:
        config["model"] = args.model
    if args.setting is not None:
        config["setting"] = args.setting
    if args.data_source is not None:
        config["data_source"] = args.data_source

    # Fail with a clear message instead of a bare KeyError further down.
    missing_fields = [key for key in ("model", "setting", "data_source") if key not in config]
    if missing_fields:
        raise ValueError(f"配置缺少必需字段: {missing_fields}")

    # Validate the requested combination.
    if config["setting"] not in setting_list:
        raise ValueError(f"无效的设定: {config['setting']}，可用选项: {setting_list}")
    if config["data_source"] not in data_source_list:
        raise ValueError(f"无效的数据源: {config['data_source']}，可用选项: {data_source_list}")
    if config["setting"] not in mapping_source_to_setting_list[config["data_source"]]:
        raise ValueError(f"数据源 {config['data_source']} 不支持设定 {config['setting']}")

    setup_logger(config["model"], config["setting"], config["data_source"])

    # Report the effective configuration to both the log file and the console.
    logging.info("评估配置信息:")
    logging.info(f"  - 配置文件: {args.config}")
    logging.info(f"  - 模型: {config['model']}")
    logging.info(f"  - 设定: {config['setting']}")
    logging.info(f"  - 数据源: {config['data_source']}")

    print(f"正在使用以下配置进行评估:")
    print(f"  - 配置文件: {args.config}")
    print(f"  - 模型: {config['model']}")
    print(f"  - 设定: {config['setting']}")
    print(f"  - 数据源: {config['data_source']}")

    # The model is published as a module-level global because make_prompt and
    # make_context read ``model.max_input_context_len`` from it.
    global model
    model = MODEL(config["model"])
    setting = config["setting"]
    data_source = config["data_source"]

    # Fixed seeds keep article sampling and prompt shuffling reproducible.
    random.seed(42)
    np.random.seed(42)
    logging.info(f"设置随机种子: 42")

    result_dir = Path(f"/home/weishaohang/workspace/Omni-Temp/results/{model.model_name}/{setting}/{data_source}")
    os.makedirs(result_dir, exist_ok=True)
    logging.info(f"创建结果目录: {result_dir}")

    # Keep a copy of the effective config next to the results.
    with open(result_dir / "config.yaml", "w", encoding="utf-8") as f:
        yaml.dump(config, f, default_flow_style=False, allow_unicode=True)
    logging.info(f"配置已保存到: {result_dir / 'config.yaml'}")

    try:
        gen_pred_results_for_all_qa_data(model, setting, data_source)
        logging.info("评估完成")
    except Exception as e:
        # Log the full traceback to the run's log file, then let it propagate.
        logging.error(f"评估过程中发生错误: {str(e)}", exc_info=True)
        raise

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()