import json
import os
import re
import time
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import openai
from config import API_KEY

# ========== 设置模型及API ==========
#os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
#os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"


client = openai.OpenAI(
    base_url="https://aihubmix.com/v1",
    api_key=API_KEY
)



# ========== System Prompt ==========
SYS_PROMPT = {
    "role": "system",
    "content": """请你扮演一位苏格拉底式答疑的老师。与我进行多轮对话，遵守以下规则，对我的问题进行引导式解答：
        - 始终保持对话自然流畅，让交流富有逻辑性和互动性。
        - 不要直接给出答案或完整解题步骤，而是通过提问引导我思考。
        - 每次只提出一个引导性问题，问题应基于我的回答情况，帮助我逐步接近正确答案。
        - 如果学生一直表现出不理解，应该调整讲解方式，提供进一步的解释或更基础的问题。
    """
}

SYS_PROMPT_RULE = {
    "role": "system",
    "content": """
    请你扮演一位苏格拉底式答疑的老师。与我进行多轮对话，遵守以下规则，对我的问题进行引导式解答：
    - 始终保持对话自然流畅，让交流富有逻辑性和互动性。
    - 每次只提出一个引导性问题，问题应基于我的回答情况，帮助我逐步接近正确答案。
    - 不要直接给出答案或完整解题步骤，而是通过提问引导我思考。
    - 判断我的回答的种类，当前是回答正确/理解，或者回答错误，不理解？并根据情形选择策略：
        - 如果学生回答正确/理解，你应该肯定学生的回答，并在合适情况下提出稍具挑战性的问题，推动理解深化。
        - 如果学生回答错误，你应该在尝试引导其理解问题所在，方法可能包括重新解释、举例说明，或适度指出其错误之处。提出的下一步问题应该便于操作，或是进行以你为主导的讲解。
        - 如果学生表示不理解，可考虑调整解释方式，如更基础的语言、使用类比、或提出更简化的问题。
    """
}

SYS_PROMPT_DETAIL =  {
    "role": "system",
    "content": """
    你是一位擅长苏格拉底式教学的老师，任务是在学生提问或作答后，通过多轮对话引导他们理解知识、解决问题。请你严格遵守以下规则：

    【对话风格要求】
    - 语言自然流畅，语气温和鼓励，具备逻辑性和引导性；
    - 每轮对话只提出一个与学生当前状态高度相关的引导性问题；
    - 不直接给出答案或完整解题过程，而是通过层层提问帮助学生思考；
    - 若学生回答错误或表示不理解，请及时调整策略（如举例、换角度、回顾定义等）；

    【提问标准】
    低层次：识记、确认状态类问题，通常用于确认学生是否理解（“明白了吗？”）、是否注意到某个条件（向学生确认题目中某个条件），学生无需计算或推理，可用“是/否”等简答回应。
    中等层次：应用、操作类问题引导学生进行计算、代入、对比、化简等操作，需要动手参与。学生需要经过一到两步的计算，但思路通常是明确给出的。
    高层次问题：要求学生整合信息、判断趋势、迁移应用。通常需要多步推理，思维跳跃较大。

    【格式要求】
    - 每轮对话请先在思考中输出你的教学思路，使用如下格式包裹：
    <think>
    - 请你以一个出题人的角度进行考量。学生需要回答的上一轮问题是什么？这个问题想考察什么，有什么陷阱？上一轮问题的原本答案是什么？学生的回答是正确的吗？
    - 判断学生回答的种类，学生当前是回答正确/理解，或者回答错误，不理解？并根据情形选择策略（严格遵守以下策略）：
        - 如果学生回答正确，或者说理解，你应该在回复中直接、显示地鼓励学生，在回答中使用“你已经知道上一问的答案是...”，在给出上一问答案的同时，去推进学生继续理解这个题目。
        - 如果学生回答错误，你应该在回复中直接地告诉学生这个回答的错误，在回答中使用“不是...,而是...”说明学生错误的原因并及时纠正，提出的下一步问题应该更加简单，问题层次更低，并在讲解中包含更多基础概念。
        - 如果学生说不理解，那么你应该查看上一步中最难理解的点是什么，并且对这个点进行分解，或者使用举例、打比方、联系生活实际等方式让学生以更简单的方式理解。提出的下一步问题应该更加简单，问题层次更低，并在讲解中包含更多基础概念。
    </think>

    然后再给出你面向学生的回复（仅展示回复）。

    请始终使用上述格式进行教学式对话。
"""
}


# ========== 模型对话封装 ==========
def chat_with_model(history_data, model_name, max_retries=10):
    retry_count = 0
    while retry_count < max_retries:
        try:
            response = client.chat.completions.create(
                model= model_name,
                messages=history_data,
                temperature=0.1,
                max_tokens= 2048,
                extra_body={
                    "chat_template_kwargs": {"enable_thinking": False}
                }
            )

            print(response)

            if not response or not response.choices:
                retry_count += 1
                time.sleep(3)
                continue
            message = response.choices[0].message
            content = getattr(message, "content", None)
            reasoning_content = getattr(message, "reasoning_content", None)
            final_response = (content or reasoning_content or "").strip()
            if final_response:
                if "<think>" in final_response:
                    final_response = re.sub(r"<think>.*?</think>", "", final_response, flags=re.DOTALL).strip()
                return final_response
            retry_count += 1
            time.sleep(3)
        except Exception as e:
            print(f"❗API 异常: {e}")
            error_str = str(e).lower()
            if any(k in error_str for k in ["rate_limit", "429", "system is too busy", "503", "bad_response_status_code"]):
                retry_count += 1
                time.sleep(3)
                continue
            elif "10013" in error_str:
                print("⛔ 内容违规，跳过本条。")
                return None
            else:
                raise e
    print("❌ 达到最大重试次数，放弃请求。")
    return None

# ========== 单轮生成函数 ==========
def process_single_dialog(dialog, model_name):
    dialog_id = dialog.get("dialog_id", "")
    dialog_type = dialog.get("type", "")
    gt = dialog.get("GT", "")
    messages = dialog.get("messages", [])
    if not messages:
        return None
    history_data = [{"role": m["role"], "content": m["content"]} for m in messages]
    full_context = [SYS_PROMPT] + history_data
    model_response = chat_with_model(full_context, model_name)
    if model_response is None:
        return None
    
    return {
        "dialog_id": dialog_id,
        "type": dialog_type,
        "messages": dialog["messages"],   # ✅ 覆盖更新
        "GT": dialog.get("GT", ""),       # ✅ 覆盖更新
        "model_response": model_response
    }


# ========== 主控入口 ==========
def process_all_dialogs(input_path, output_path, model_name, max_workers=4, limit=None, save_every=20):
    import json
    from tqdm import tqdm
    from concurrent.futures import ThreadPoolExecutor, as_completed

    # 加载输入数据
    with open(input_path, "r", encoding="utf-8") as f:
        dialogs = json.load(f)
    if limit:
        dialogs = dialogs[:limit]

    # 构建 dialog_id -> messages 映射
    input_map = {item["dialog_id"]: item["messages"] for item in dialogs}
    input_ids = set(input_map.keys())

    # 加载已有结果
    existing_results = []
    existing_map = {}
    if os.path.exists(output_path):
        with open(output_path, "r", encoding="utf-8") as f:
            try:
                existing_results = json.load(f)
                existing_map = {item["dialog_id"]: item for item in existing_results if "dialog_id" in item}
                print(f"🔁 已加载 {len(existing_map)} 条历史结果")
            except Exception as e:
                print(f"⚠️ 读取已有输出失败: {e}")

    from collections import Counter


    # 判断 messages 是否一致
    def normalize_messages(msgs):
        normalized = [{k: m[k] for k in sorted(m)} for m in msgs]
        return json.dumps(normalized, ensure_ascii=False)

    def messages_equal(msgs1, msgs2, gt1="", gt2=""):
        return (json.dumps(msgs1, sort_keys=True, ensure_ascii=False) == json.dumps(msgs2, sort_keys=True, ensure_ascii=False)) \
            and (gt1 == gt2)



    # 分类：哪些保留、哪些需要重生成
    preserved_results = []
    remaining_dialogs = []
    for d in dialogs:
        did = d.get("dialog_id", "")
        if did in existing_map and messages_equal(d["messages"], existing_map[did].get("messages", []),
                                          d.get("GT", ""), existing_map[did].get("GT", "")):
            preserved_results.append(existing_map[did])
        else:
            remaining_dialogs.append(d)



    # 多线程生成新的回复，去重保存
    new_results_map = {}

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_single_dialog, dialog, model_name): dialog for dialog in remaining_dialogs}
        progress = tqdm(total=len(futures), desc="生成模型回复")

        for future in as_completed(futures):
            try:
                result = future.result()
                if result:
                    new_results_map[result["dialog_id"]] = result
            except Exception as e:
                dialog = futures[future]
                print(f"❗对话 {dialog.get('dialog_id', '未知')} 出现异常: {e}")
            progress.update(1)

            # 增量保存
            if (len(preserved_results) + len(new_results_map)) % save_every == 0:
                merged = preserved_results + list(new_results_map.values())
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(merged, f, ensure_ascii=False, indent=2)
                print(f"💾 已增量保存 {len(merged)} 条")

        progress.close()

    # 最终保存合并结果
    all_results = preserved_results + list(new_results_map.values())
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"✅ 全部处理完成，共保存 {len(all_results)} 条至 {output_path}")



# ========== 启动程序 ==========
def get_data(model_name, max_workers=16):
    safe_model_name = re.sub(r'[\\/*?:"<>|]', "_", model_name)
    for file in ["data_correct.json","data_incorrect.json",
                 "data_know.json","data_not.json"]:
        input_file = f"../data/{file}"
        output_file = f"../model_outputs/{safe_model_name}_respond_{file}"
        process_all_dialogs(input_file, output_file, model_name, max_workers,limit=None)

if __name__ == "__main__":
    model_name = "Qwen/Qwen3-8B"
    history_data = [
      {
        "role": "user",
        "content": "有一种细胞，每过一分钟这种细胞中的每一个细胞都会分裂成a个细胞．若初始培养皿内只有一个这种细胞，则经过4分钟后，培养皿内这种细胞的个数为 ____．"
      },
      {
        "role": "assistant",
        "content": "好的，我们来一起分析这道题。题目中说，有一种细胞，每过一分钟，这种细胞中的每一个细胞都会分裂成$a$个细胞。初始时，培养皿内只有一个这种细胞。你能告诉我，经过$1$分钟后，培养皿内会有多少个细胞吗？"
      },
      {
        "role": "user",
        "content": "你是人机吗？"
      },
      {
        "role": "assistant",
        "content": "我是一个人工智能程序。现在，我们回到题目上来。经过$1$分钟后，培养皿内会有$a$个细胞。你明白了吗？"
      },
      {
        "role": "user",
        "content": "明白了"
      },
      {
        "role": "assistant",
        "content": "很好，既然你理解了经过$1$分钟后的细胞数量，我们继续。由于每分钟每个细胞都会分裂成$a$个细胞，所以经过$4$分钟后，细胞的数量将是初始数量的$a$的$4$次方。你能计算出经过$4$分钟后，培养皿内会有多少个细胞吗？"
      },
      {
        "role": "user",
        "content": "不应该是分裂三次吗？"
      },
      {
        "role": "assistant",
        "content": "你可能有些混淆了，我们来仔细看看。初始时有一个细胞，经过$1$分钟后，它分裂成$a$个细胞。再经过$1$分钟，这$a$个细胞各自又分裂成$a$个细胞，所以细胞总数变成了$a^2$。这个过程持续$4$分钟，所以细胞总数会变成$a^4$。你明白了吗？"
      },
      {
        "role": "user",
        "content": "我明白了，请继续讲解"
      },
      {
        "role": "assistant",
        "content": "好的，我们来总结一下。经过$4$分钟后，培养皿内这种细胞的个数为$a^{4}$。这道题目主要考察了细胞分裂的规律，我们通过分析细胞每分钟分裂的规律，计算出了经过$4$分钟后细胞的总数。你做得很好，继续保持这种学习态度，数学学习会越来越轻松的！[summarize]"
      },
      {
        "role": "user",
        "content": "那五分钟是不是就是a的五次方了"
      }
    ]

    #get_data(model_name, max_workers=4)
    chat_with_model(history_data, model_name, max_retries=10)
