import json
import os
import time
import re
import concurrent.futures
from typing import Any, Dict
from tqdm import tqdm
from openai import OpenAI
from config import API_KEY

os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
client = OpenAI(
    base_url="https://aihubmix.com/v1",
    api_key=API_KEY
)

GUIDE_QUALITY_EVALUATION_SCHEMA = {
    "用户意图": int,
    "引导质量": int,
}


def validate_json_response(json_data: Dict[str, Any], schema: Dict[str, Any]) -> bool:
    try:
        for key, expected_type in schema.items():
            if key not in json_data or not isinstance(json_data[key], expected_type):
                return False
        return True
    except Exception:
        return False


def prompt_type_1(context: str, current_user: str, current_model_respond: str) -> str:
    return (
        "你是一名严谨的教育研究助理，擅长分析教学对话中模型的推进能力。\n\n"
        "【背景】以下是截至当前的多轮教学对话，请通读并理解其演进过程：\n"
        f"{context}\n\n"
        "【任务目标】\n"
        "学生在当前轮表示“已经理解”或“回答正确”。\n"
        "请判断：当前轮的模型回复是否具有“推进性” —— 即**是否在前一轮基础上进一步推动了解题过程**。\n\n"
        "【得分标准】\n"
        "请根据以下标准判断模型是否具有推进性（引导质量得分为 1 或 0）：\n\n"
        "以下情况视为具有推进性（得分 = 1）：\n"
        "- 引导学生进行下一步计算、分析或推理\n"
        "- 提出新的问题或挑战，引导学生深入思考\n"
        "- 总结当前内容并引出更高阶的问题或方法\n"
        "- 【注意】如果模型给出最终答案，则不用再推进，直接为1分。\n"
        "以下情况视为不具备推进性（得分 = 0）：\n"
        "- 重复上一轮的讲解或问题\n"
        "- 重新询问学生已回答过的内容\n"
        "- 围绕学生已掌握的部分反复兜圈，未有新内容\n\n"
        f"【当前轮学生发言】：{current_user}\n"
        f"【当前轮模型回复】：{current_model_respond}\n\n"
         "请先在 <think> 标签中写出你的判断依据，包括：\n"
        "- 当前轮模型相比上一轮新增了哪些信息？是否在前一轮基础上推进了解题？\n"
        "- 是否出现新问题、新步骤或结束性总结？\n"
        "- 为什么你认为这是有/无推进性的？\n\n"
        "然后仅输出以下严格格式的 JSON：\n"
        "```json\n"
        "{\n"
        "  \"引导质量\": 0/1\n"
        "}\n"
        "```"
    )


def prompt_type_2(context: str, current_user: str, current_model_respond: str, GT: str) -> str:
    return (
        "你是一名严谨的教育研究助理，擅长分析模型在学生回答错误后的讲解重构能力。\n\n"
        "【背景】以下是截至当前的多轮教学对话，请通读并理解其演进过程：\n"
        f"{context}\n\n"
        "【任务目标】\n"
        "学生在上一轮回答中未能正确回答上一轮模型的问题，请完成以下任务：\n"
        "第一步：判断学生在上一轮回答中错在哪里，并结合参考答案（GT）指出主要错误点。\n"
        "第二步：请判断当前轮模型的回复是否具有“重构性” —— 即**模型是否识别到错误，并尝试用合理方式引导学生纠正或理解正确解法。**\n\n"
        "【得分标准】（得分为 1 或 0）\n"
        "以下情况判为有效重构（得 1 分）：\n"
        "- 模型暗示了学生的错误，或通过提问/引导/验证等方式引导学生自行发现错误（注意，不一定要明确指出错误）；\n"
        "- 引导方式可以是：提示、重算、强调题目条件或知识点、让学生检查某一部分是否合理等；\n"
        "- 不要求直接讲出答案，也不强求改变说法，只要方向正确、能帮助学生修正错误即可。\n\n"
        "以下情况判为无效重构（得 0 分）：\n"
        "- 模型未发现学生错误，默认错误回答正确；\n"
        "- 或者模型讲解方向错误、偏离参考答案，未做出有效引导；\n"
        "- 或者模型跳过错误继续推进下一步。\n\n"
        f"【参考答案】：\n{GT}\n\n"
        f"【当前轮学生发言】：{current_user}\n"
        f"【当前轮模型回复】：{current_model_respond}\n\n"
        "请先在 <think> 标签中写出你的判断依据，包括：\n"
        "- 学生的错误点是什么？（与参考答案的差异在哪）\n"
        "- 模型有没有围绕这些错误点进行引导？有哪些关键表达？是否方向正确？\n\n"
        "然后仅输出以下严格格式的 JSON：\n"
        "```json\n"
        "{\n"
        "  \"引导质量\": 0/1\n"
        "}\n"
        "```"
    )

def prompt_type_3(context: str, current_user: str, current_model_respond: str) -> str:
    return (
        "你是一名严谨的教育研究助理，擅长分析模型在学生“不理解”时的讲解重构能力。\n\n"
        "【背景】以下是截至当前的多轮教学对话，请通读并理解其演进过程：\n"
        f"{context}\n\n"
        "【任务目标】\n"
        "学生在当前轮表示不理解，说明他对上一轮模型讲解中的某些内容感到困惑。\n"
        "请完成以下两步任务：\n"
        "第一步：请对比模型当前轮与上一轮模型讲解内容，判断当前轮模型是否新增了有助理解的信息（如：更细的推理、更基础的概念、具体例子、定义回顾等）。\n"
        "第二步：请判断这些新增内容是否有助于学生理解之前不明白的内容，即是否降低了认知负担或更接近正确理解路径。\n\n"
        "【得分标准】（得分为 1 或 0）\n"
        "以下情况判为有效重构（得 1 分）：\n"
        "- 模型新增了更具体、更基础、更细致的讲解，如换种说法、分步推理、数值代入、概念解释等；\n"
        "- 模型尝试补充前提条件或澄清概念定义，或引导学生观察问题结构，也算有效；\n"
        "- 不要求模型完全解决困惑点，只要在方向上更接近清晰理解，即可。\n\n"
        "以下情况判为无效重构（得 0 分）：\n"
        "- 模型只是重复了上一轮内容或原话；\n"
        "- 模型只是提问“你哪里不懂”或鼓励学生自己思考，但没有新增讲解；\n"
        "- 模型讲解内容仍停留在学生不理解的层次，或没有降低认知难度。\n\n"
        f"【当前轮学生发言】：{current_user}\n"
        f"【当前轮模型回复】：{current_model_respond}\n\n"
        "请先在 <think> 标签中写出你的判断依据，包括：\n"
        "- 当前轮模型相比上一轮新增了哪些信息？是更细讲解、补定义、举例子，还是换种表达？\n"
        "- 这些新增信息是否相比更容易理解？是否降低了学生的认知负担？\n\n"
        "然后仅输出以下严格格式的 JSON：\n"
        "```json\n"
        "{\n"
        "  \"引导质量\": 0/1\n"
        "}\n"
        "```"
    )




def guide_analysis(history, single_dialog, user_type, GT=""):
    context_str = ""
    for i, turn in enumerate(history):
        context_str += f"【第{i+1}轮 用户】：{turn['user']}\n"
        context_str += f"【第{i+1}轮 模型】：{turn['model_respond']}\n"

    current_user = single_dialog["user"]
    current_model = single_dialog["model_respond"]

    if user_type == 1:
        prompt = prompt_type_1(context_str, current_user, current_model)
    elif user_type == 2:
        prompt = prompt_type_2(context_str, current_user, current_model, GT)
    elif user_type == 3:
        prompt = prompt_type_3(context_str, current_user, current_model)

    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0
    )
    return response.choices[0].message.content



def evaluate_item(history, single_dialog, user_type, GT="", max_retries=5):
    for attempt in range(max_retries):
        try:
            result = guide_analysis(history, single_dialog, user_type, GT)

            # 提取 <think> 内容
            think_match = re.search(r"<think>(.*?)</think>", result, flags=re.DOTALL)
            think_trace = think_match.group(1).strip() if think_match else ""

            # 清除 think 后提取 JSON
            result = re.sub(r"<think>.*?</think>", "", result, flags=re.DOTALL).strip()
            match = re.search(r"\{\s*\"引导质量\"\s*:\s*[01]\s*\}", result)
            if match:
                parsed = json.loads(match.group(0))
                if "引导质量" in parsed:
                    parsed["think_trace"] = think_trace
                    return parsed
        except Exception as e:
            print(f"[evaluate-error] 第 {attempt+1} 次失败: {e}")
        time.sleep(1)

    return {"引导质量": 0, "think_trace": ""}




def determine_user_type(dialog_type: str) -> int:
    if dialog_type.startswith("correct") or dialog_type.startswith("know"):
        return 1
    elif dialog_type.startswith("incorrect"):
        return 2
    elif  dialog_type.startswith("not"):
        return 3
    return 0


def process_single_dialog(item):
    dialog_id = item.get("dialog_id", "")
    dialog_type = item.get("type", "")
    model_response = item.get("model_response", "")
    messages = item.get("messages", [])
    gt = item.get("GT", "")  # ✅ 获取参考答案

    user_type = determine_user_type(dialog_type)
    if user_type not in [1, 2, 3]:
        return None

    # 构造历史对话
    history = []
    for i in range(0, len(messages) - 1, 2):
        user = messages[i]["content"] if i < len(messages) and messages[i]["role"] == "user" else ""
        model = messages[i + 1]["content"] if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant" else ""
        history.append({"user": user, "model_respond": model})

    # 当前轮（单轮结构）
    last_user = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
    single_dialog = {
        "user": last_user,
        "model_respond": model_response
    }

    result = evaluate_item(history, single_dialog, user_type, gt)
    return {
        "dialog_id": dialog_id,
        "type": dialog_type,
        "score": result["引导质量"],
        "model_response": model_response,
        "think_trace": result.get("think_trace", "")
    }



def evaluate(input_path, output_path):
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    #data = data[:200]
    all_results = []
    gen_scores = []
    real_scores = []

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        futures = [executor.submit(process_single_dialog, item) for item in data]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="评估引导质量"):
            result = future.result()
            if result is not None:
                all_results.append(result)
                score = result["score"]
                if result["type"].endswith("_gen"):
                    gen_scores.append(score)
                else:
                    real_scores.append(score)

    summary = {
        "average_score": (sum(gen_scores + real_scores) / len(gen_scores + real_scores)) if (gen_scores or real_scores) else 0,
        "count": len(gen_scores + real_scores),
        "summary_by_type": {
            "gen": {
                "count": len(gen_scores),
                "average_score": sum(gen_scores)/len(gen_scores) if gen_scores else 0
            },
            "real": {
                "count": len(real_scores),
                "average_score": sum(real_scores)/len(real_scores) if real_scores else 0
            }
        }
    }

    output = {
        "average_score": summary["average_score"],
        "count": summary["count"],
        "summary_by_type": summary["summary_by_type"],
        "details": all_results
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, ensure_ascii=False, indent=2)

    print(f"✅ 结果已保存至 {output_path}")
    print(json.dumps(summary, indent=2, ensure_ascii=False))



if __name__ == "__main__":
    def instruc_test_evaluate(model_name, type_name):
        safe_model_name = re.sub(r'[\\/*?:"<>|]', "_", model_name)
        input_data = f"..\model_outputs\{safe_model_name}_respond_data_{type_name}.json"
        output_data = f"../result/instruction/{safe_model_name}_{type_name}.json"
        evaluate(input_data, output_data)
    
    for model_name in ["Qwen/Qwen3-32B", "Qwen/Qwen3-8B","DeepSeek-V3"]:
        for type_name in [ "incorrect"]:
            instruc_test_evaluate(model_name, type_name)