import json
import concurrent
import os
import re
import time
from openai import OpenAI

from tqdm import tqdm
from config import API_KEY


# --- Configure the network proxy and the OpenAI-compatible client ---
# Both proxy environment variables are pointed at 127.0.0.1:7890 before the
# client is created (presumably a local proxy must be listening there --
# TODO confirm for the deployment environment).
os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
# Module-level client shared by every request in this file; it targets the
# aihubmix endpoint and authenticates with API_KEY imported from config.
client = OpenAI(
    base_url="https://aihubmix.com/v1",
    api_key=API_KEY
)



# Validation schema for the scoring response: maps each required JSON key to
# the Python type its value must have. "问题层次" ("question level") is the
# key the prompt asks the model to emit.
DEPTH_ANALYSIS_SCHEMA = {
    "问题层次": int
}

# Response-format validation helper
def validate_json_response(json_data: dict, schema: dict) -> bool:
    """Return True when *json_data* satisfies *schema*.

    Args:
        json_data: Decoded JSON object to check. Anything that is not a
            ``dict`` is rejected.
        schema: Mapping of required key -> expected type (or tuple of types).

    Returns:
        True if every key in *schema* is present in *json_data* and its value
        is an instance of the expected type; False otherwise, including when
        *schema* itself is malformed.
    """
    if not isinstance(json_data, dict):
        return False
    try:
        for key, expected_type in schema.items():
            if key not in json_data:
                return False
            value = json_data[key]
            allowed = expected_type if isinstance(expected_type, tuple) else (expected_type,)
            if isinstance(value, bool):
                # bool is a subclass of int, so a JSON true/false would
                # otherwise satisfy an ``int`` requirement. Only accept a bool
                # through some type other than plain ``int``.
                allowed = tuple(t for t in allowed if t is not int)
            if not allowed or not isinstance(value, allowed):
                return False
        return True
    except (AttributeError, TypeError):
        # Malformed schema (not a mapping, or a value that is not a type).
        return False



# Flatten each conversation record into a single-turn dialog string
def format_dialogs(dialog_data):
    """Render every conversation in *dialog_data* as a dict-like string.

    Each record is read with ``.get`` for the keys ``"user"`` and
    ``"model_respond"`` (missing keys become empty strings). Note that the
    emitted text labels the second field ``'model_response'`` even though it
    is read from ``"model_respond"``. Values are interpolated as-is, without
    escaping embedded quotes.

    Returns:
        A list with one formatted string per input record, in input order.
    """
    def _render(turn):
        user_text = turn.get("user", "")
        reply_text = turn.get("model_respond", "")
        return f"{{'user': '{user_text}', 'model_response': '{reply_text}'}}"

    return [_render(turn) for turn in dialog_data]


# Analysis function with retry and response-format validation
def depth_analysis(question, single_dialog: str, max_retries=5, retry_delay=3) -> dict:
    """Ask the judge model to classify the question depth of a model reply.

    The prompt instructs the judge to reason inside ``<think>`` tags and then
    emit a JSON object ``{"问题层次": 0|1|2|3}``. The reply is parsed and
    validated; on any failure (request error, missing or invalid JSON) the
    call is retried.

    Args:
        question: The original problem text, inserted under "【问题】".
        single_dialog: The model reply to analyse, inserted under "【模型回复】".
        max_retries: Maximum number of request attempts.
        retry_delay: Seconds to wait between consecutive attempts.

    Returns:
        On success: ``{"score": <int 0-3>, "think_trace": <str>}``.
        After all attempts fail: ``{"错误": "..."}`` (no ``"score"`` key).
    """
    prompt = (
        "你是一名严谨的教育研究助理，擅长分析教学对话中的提问质量。\n\n"
        "请你完成以下任务：对模型的回答（model_response）中是否包含提问进行判断，并给出提问的认知层次分类。\n\n"
        "【任务说明】\n"
        "第一步：判断模型是否提出了问题（即是否对学生发起了提问）。如果完全没有提问，直接输出“问题层次: 0”。\n"
        "第二步：若模型存在提问，请对问题按以下标准进行层次分类。\n\n"
        "【提问层次分类标准】\n"
        "（0）**无提问（问题层次 = 0）**\n"
        "   - 模型只是讲解或陈述，没有任何提问语句。\n\n"
        "（1）**基础提问（问题层次 = 1）** —— 识记、确认状态类问题。\n"
        "   - 通常用于确认学生是否理解、是否注意到某个条件。\n"
        "   - 学生无需计算或推理，可用“是/否”等简答回应。\n"
        "   - 示例：\n"
        "       - “你明白了吗？”、“你听懂了吗？”、“你知道平方差公式吗？”\n"
        "       - “你知道题目中给了什么条件吗？”、“你觉得这个讲解清晰吗？”、“你还有其他问题吗？”\n\n"
        "（2）**中等提问（问题层次 = 2）** —— 应用、操作类问题。\n"
        "   - 引导学生进行计算、代入、对比、化简等操作，需要动手参与。\n"
        "   - 学生需要经过**一到两步的计算**，但思路通常是明确给出的。\n"
        "   - 示例：\n"
        "       - “你能解这个方程求出$$x$$的值吗？”“你能试着把两个数相加吗？”\n"
        "       - “你能尝试展开$$(x+y)^2$$，看看能不能化简整个式子吗？”\n\n"
        "（3）**高级提问（问题层次 = 3）** —— 推理、迁移、评判、创造类问题。\n"
        "   - 要求学生**整合信息、判断趋势、迁移应用或提出新方法**。\n"
        "   - 通常需要多步推理，思维跳跃较大。\n"
        "   - 示例：\n"
        "       - “你能告诉我，从$$(x-y)^2>0$$我们能得出什么结论？”\n"
        "       - “你能总结一下面积变化的规律吗？”、“你觉得有简便的方法吗？”\n"
        "       - “你能想出一种新方法来解这个题吗？”、“如果条件改变了，答案会变吗？”\n\n"
        "【问题】\n"
        f"{question}\n\n"
        "【模型回复】\n"
        f"{single_dialog}\n\n"
        "请你先在 <think> </think> 标签中写出：\n"
        "- 是否存在提问？\n"
        "- 如果有提问，指出是哪一句，并根据分类标准说明其分类依据。\n\n"
        "然后，仅输出如下严格格式的 JSON：\n"
        "```json\n"
        "{\n"
        "  \"问题层次\": 0/1/2/3\n"
        "}\n"
        "```"
    )

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0
            )
            # The message content can be None; treat that as an empty reply so
            # it is reported as a format mismatch rather than an AttributeError.
            content = (response.choices[0].message.content or "").strip()

            # Capture the reasoning written inside <think>...</think>, if any.
            think_match = re.search(r"<think>(.*?)</think>", content, flags=re.DOTALL)
            think_trace = think_match.group(1).strip() if think_match else ""

            # Drop the <think> section so only the JSON answer is searched.
            content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)

            # Extract the JSON object; only levels 0-3 are accepted.
            match = re.search(r"\{\s*\"问题层次\"\s*:\s*[0-3]\s*\}", content)
            if match:
                json_data = json.loads(match.group(0))
                if validate_json_response(json_data, DEPTH_ANALYSIS_SCHEMA):
                    return {
                        "score": json_data["问题层次"],
                        "think_trace": think_trace
                    }

                else:
                    print(f"格式校验失败（尝试 {attempt + 1}/{max_retries}）")
            else:
                print(f"未匹配到合法 JSON（尝试 {attempt + 1}/{max_retries}）")

        except Exception as e:
            # Retry boundary: any request or parsing error triggers another attempt.
            print(f"提问分析失败（尝试 {attempt + 1}/{max_retries}）: {str(e)}")

        # Wait only between attempts; sleeping after the final failure would
        # just delay returning the error to the caller.
        if attempt < max_retries - 1:
            time.sleep(retry_delay)

    return {
        "错误": "提问分析失败，请检查输入或模型输出格式"
    }

def evaluate(input_path, output_path, max_workers=16):
    """Score every dialog in *input_path* concurrently and write a JSON report.

    Each input item is expected to provide ``messages`` (a list of
    ``{"role", "content"}`` dicts) plus optional ``dialog_id``, ``type`` and
    ``model_response`` fields. The first user message is used as the question
    and passed with ``model_response`` to ``depth_analysis``.

    Args:
        input_path: Path of the JSON file holding the list of items.
        output_path: Path of the JSON report to write. Its parent directory
            is created if missing.
        max_workers: Size of the thread pool used for the API calls.

    The report contains the overall average score and count, the same
    statistics split into items whose ``type`` ends with ``"_gen"`` ("gen")
    and all others ("real"), and per-item ``details`` in input order. Items
    whose analysis fails are reported on stdout and excluded from the report.
    """
    # The file-level ``import concurrent`` does not load the ``futures``
    # submodule, so it is imported explicitly here.
    import concurrent.futures
    from tqdm import tqdm

    def avg(lst):
        return sum(lst) / len(lst) if lst else 0

    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Create the output directory before spending API calls, so a missing
    # folder cannot discard the results when the report is written.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    def process_single_item(item):
        dialog_id = item.get("dialog_id", "")
        dialog_type = item.get("type", "")
        model_response = item.get("model_response", "")
        question = next(
            (msg["content"] for msg in item.get("messages", []) if msg.get("role") == "user"),
            None,
        )
        if question is None:
            raise ValueError(f"dialog_id={dialog_id!r}: no user message found")

        analysis_result = depth_analysis(question, model_response)
        if "score" not in analysis_result:
            # depth_analysis returns an error dict without "score" once its
            # retries are exhausted; surface that reason instead of a KeyError.
            reason = analysis_result.get("错误", "unknown error")
            raise RuntimeError(f"dialog_id={dialog_id!r}: {reason}")
        return {
            "dialog_id": dialog_id,
            "type": dialog_type,
            "score": analysis_result["score"],
            "think_trace": analysis_result.get("think_trace", ""),
            "model_response": model_response
        }

    results_by_index = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_single_item, item): index
            for index, item in enumerate(data)
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_index),
                           total=len(future_to_index), desc="评估提问质量"):
            try:
                results_by_index[future_to_index[future]] = future.result()
            except Exception as e:
                print(f"处理失败：{e}")

    # Futures complete in arbitrary order; restore input order so the report
    # is reproducible between runs.
    results = [results_by_index[index] for index in sorted(results_by_index)]

    all_scores = [result["score"] for result in results]
    gen_scores = []
    real_scores = []
    for result in results:
        # A non-string type (e.g. null in the input) is counted as "real".
        if str(result["type"]).endswith("_gen"):
            gen_scores.append(result["score"])
        else:
            real_scores.append(result["score"])

    output = {
        "average_score": avg(all_scores),
        "count": len(all_scores),
        "summary_by_type": {
            "gen": {
                "count": len(gen_scores),
                "average_score": avg(gen_scores)
            },
            "real": {
                "count": len(real_scores),
                "average_score": avg(real_scores)
            }
        },
        "details": results
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, ensure_ascii=False, indent=2)

    print(f"✅ 并发评估完成，结果已保存至 {output_path}")
    print(json.dumps(output["summary_by_type"], indent=2, ensure_ascii=False))

if __name__ == "__main__":
    def ques_test_evaluate(model_name, type_name):
        """Run ``evaluate`` for one (model, data subset) pair.

        Characters that are invalid in file names (e.g. the ``/`` in
        ``Qwen/Qwen3-8B``) are replaced with ``_`` to build the file names.
        """
        safe_model_name = re.sub(r'[\\/*?:"<>|]', "_", model_name)
        # Built with os.path.join instead of a backslash literal: "\m" and
        # "\{" are invalid escape sequences, and hard-coded backslashes only
        # resolve as a path on Windows.
        input_data = os.path.join(
            "..", "model_outputs",
            f"{safe_model_name}_respond_data_{type_name}.json",
        )
        output_data = f"../result/question/{safe_model_name}_{type_name}.json"
        evaluate(input_data, output_data)

    for model_name in ["Qwen/Qwen3-8B", "Qwen/Qwen3-32B", "DeepSeek-V3", "deepseek-ai/DeepSeek-R1-0528",
                       "o4-mini", "x1"]:
        for type_name in ["correct", "incorrect", "know", "not"]:
            ques_test_evaluate(model_name, type_name)