import json
import re
import requests
import os
from collections import OrderedDict
import traceback

def infer_verdict(text: str) -> str:
    """Extract a verdict label from free-form judge output.

    Resolution order:

    1. The last explicit ``Verdict: [[X]]`` marker (matched
       case-insensitively).
    2. The last bare ``[[X]]`` marker (upper-case only).
    3. A keyword heuristic: count sentences-fragments where
       "assistant a" / "assistant b" is followed, within 80 non-period
       characters, by a praise keyword; the side with more hits wins.

    Args:
        text: The judge's raw output.

    Returns:
        One of ``"[[A]]"``, ``"[[B]]"`` or ``"[[C]]"``. ``"[[C]]"`` is the
        fallback when nothing can be inferred or the heuristic ties.
    """
    m = re.findall(r"verdict\s*:\s*\[\[([ABC])\]\]", text, flags=re.I)
    if m:
        # The pattern is case-insensitive, so normalise the captured letter;
        # callers look the label up in an upper-case-only mapping.
        return f"[[{m[-1].upper()}]]"
    m = re.findall(r"\[\[([ABC])\]\]", text)
    if m:
        return f"[[{m[-1]}]]"
    lower = text.lower()
    keywords = r"(better|wins|outperforms|superior|stronger|more\s+accurate|" \
               r"more\s+detailed|more\s+helpful|more\s+comprehensive|more\s+robust)"
    # `\b` after the letter keeps "assistant answers ..." or
    # "assistant both ..." from being counted as a mention of A or B.
    a_hits = len(re.findall(r"assistant\s+a\b[^.]{0,80}?\b" + keywords, lower))
    b_hits = len(re.findall(r"assistant\s+b\b[^.]{0,80}?\b" + keywords, lower))
    if a_hits > b_hits:
        return "[[A]]"
    if b_hits > a_hits:
        return "[[B]]"
    return "[[C]]"

def query(prompt: str, model_name: str, *, timeout: float = 600.0):
    """Send a single-turn chat request and return the model's output.

    The endpoint URL and API key are read from the ``JUDGE_API_URL`` and
    ``JUDGE_API_KEY`` environment variables so that no credential has to
    live in the source file.

    Args:
        prompt: Text sent as the only user message.
        model_name: Value placed in the request's ``model`` field.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A ``(think, resp, cost)`` tuple: the stripped ``reasoning_content``
        of the first choice (empty string when absent or null), the stripped
        ``content`` of the first choice, and a dict with the
        ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` values
        from the response's ``usage`` object.

    Raises:
        RuntimeError: If either environment variable is unset or empty.
        requests.HTTPError: If the server answers with an error status.
        KeyError, IndexError: If the response body lacks the expected fields.
    """
    url = os.environ.get("JUDGE_API_URL")
    api_key = os.environ.get("JUDGE_API_KEY")
    if not url or not api_key:
        raise RuntimeError(
            "Set the JUDGE_API_URL and JUDGE_API_KEY environment variables "
            "before calling query()."
        )
    headers = {
        # NOTE(review): Bearer scheme is assumed because the response shape
        # read below looks like an OpenAI-compatible chat API -- confirm
        # against the actual provider.
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}]
    }
    http_rsp = requests.post(url, headers=headers, json=data, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of hitting a confusing KeyError below.
    http_rsp.raise_for_status()
    rsp = http_rsp.json()
    msg = rsp["choices"][0]["message"]
    # Either field may be present but null; treat that as empty text.
    think = (msg.get("reasoning_content") or "").strip()
    resp = (msg["content"] or "").strip()
    cost = {k: rsp["usage"][k] for k in ("prompt_tokens",
                                         "completion_tokens",
                                         "total_tokens")}
    return think, resp, cost

class Judge:
    """Pairwise judge that asks a model to compare two assistants' answers.

    Calling an instance builds the evaluation prompt, sends it through
    ``query`` and turns the two returned texts into ``A>B`` / ``B>A`` /
    ``A=B`` results via ``infer_verdict``.
    """

    def __init__(self, model_name: str):
        # Name forwarded to ``query`` on every call.
        self.model_name = model_name

    @staticmethod
    def generate_prompt(question: str, ans_a: str, ans_b: str) -> str:
        """Return the evaluation prompt for one question and two answers."""
        prompt = f"""
You are an impartial judge tasked with evaluating two AI assistants' answers to a question.

Please output **TWO FULLY SEPARATE SECTIONS** using the headings:

**Part 1: Thinking Process (Compare internal reasoning only)**  
- In this section, compare only the **reasoning / chain-of-thought (CoT)** parts of the responses by Assistant A and Assistant B.  
- Do **not** evaluate their final answers. Focus purely on the logical thinking process that leads to the answers.  
- Analyze and compare based on reasoning quality, clarity, correctness, completeness, and depth.  
- Treat this as your own internal deliberation process.  
- **At the very end of this section, write on a new line exactly**:  
  Verdict: [[A]] or [[B]] or [[C]]  
  (based solely on the *reasoning/CoT content* of this section).  
- This verdict is **independent** of Part 2.

**Part 2: Final Judgment Response (Compare final answers only)**  
- Now ignore the internal reasoning, and focus only on the **final answer/output content** each assistant gave.  
- Evaluate the final responses directly, based on clarity, correctness, usefulness, and completeness.  
- This is your public-facing summary judgment, not a re-analysis.  
- **At the end of this section, write:**  
  Verdict: [[A]] or [[B]] or [[C]]  
  (based solely on the *final answers* in this second section).  
- This verdict is also **independent** of Part 1.

[User Question]
{question}

[Assistant A's Answer]
{ans_a}

[Assistant B's Answer]
{ans_b}

[Your Evaluation]
"""
        return prompt

    def __call__(self, question: str, ans_a: str, ans_b: str):
        """Judge one pair of answers.

        Returns a dict with the judge's reasoning text and response text
        (each guaranteed to end with its verdict label), the two mapped
        results, and the token usage reported by ``query``.
        """
        reasoning, answer, usage = query(
            self.generate_prompt(question, ans_a, ans_b),
            self.model_name,
        )

        reasoning_label = infer_verdict(reasoning)
        answer_label = infer_verdict(answer)

        # Make sure each text finishes with the label that was inferred
        # for it, appending the label when it is not already the tail.
        if not reasoning.rstrip().endswith(reasoning_label):
            reasoning = f"{reasoning}\n\n{reasoning_label}"
        if not answer.rstrip().endswith(answer_label):
            answer = f"{answer}\n\n{answer_label}"

        outcome_for = {"[[A]]": "A>B", "[[B]]": "B>A", "[[C]]": "A=B"}
        reasoning_outcome = outcome_for[reasoning_label]
        answer_outcome = outcome_for[answer_label]

        return {
            "judge_thinking":  reasoning,
            "thinking_result": reasoning_outcome,
            "judge_response":  answer,
            "judge result":    answer_outcome,
            "cost_tokens":     usage,
        }

def main(input_path, output_dir, start_id=1):
    """Judge every record of a JSON file, writing one result file per record.

    Args:
        input_path: Path to a JSON file holding one record or a list of
            records. Each record is read for ``question``, ``response_A``
            and ``response_B`` and, optionally, ``id``, ``model_A`` and
            ``model_B``.
        output_dir: Directory for the per-record result files; created if
            missing.
        start_id: Id of the record to start from. It is compared as a
            string, so ``93`` and ``"93"`` are equivalent. When no record
            has this id, processing starts from the first record.

    Records whose output file already exists are skipped, so an interrupted
    run can simply be restarted. Processing stops at the first record that
    fails.
    """
    os.makedirs(output_dir, exist_ok=True)
    judge = Judge("deepseek-r1")
    with open(input_path, encoding="utf-8") as f:
        data = json.load(f)
    data = data if isinstance(data, list) else [data]

    # Map id -> list index so a run can resume from any id. The fallback
    # mirrors the one used for cur_id below, so records without an "id"
    # no longer raise KeyError here.
    id_to_idx = {str(item.get("id", idx + 1)): idx for idx, item in enumerate(data)}
    start_idx = id_to_idx.get(str(start_id), 0)

    for idx, item in enumerate(data[start_idx:], start=start_idx):
        cur_id = item.get("id", idx+1)
        output_file = os.path.join(output_dir, f"id:{cur_id}.json")
        if os.path.exists(output_file):
            print(f"文件已存在，跳过 id={cur_id}")
            continue

        try:
            rec = judge(item["question"], item["response_A"], item["response_B"])
            ordered = OrderedDict()
            ordered["id"]              = cur_id
            ordered["question"]        = item.get("question", "")
            ordered["response_a"]      = item.get("response_A", "")
            ordered["response_b"]      = item.get("response_B", "")
            ordered["model_A"]         = item.get("model_A", "")
            ordered["model_B"]         = item.get("model_B", "")
            # Record the model that was actually used rather than a second
            # hard-coded copy of its name.
            ordered["judge model"]     = judge.model_name
            ordered["judge_thinking"]  = rec["judge_thinking"]
            ordered["thinking_result"] = rec["thinking_result"]   # A>B / B>A / A=B
            ordered["judge_response"]  = rec["judge_response"]
            ordered["judge result"]    = rec["judge result"]      # A>B / B>A / A=B
            ordered["cost_tokens"]     = rec["cost_tokens"]

            # Write to a temporary file and rename it into place: an
            # interrupted write must not leave a truncated result file
            # that the "already exists" check above would later skip.
            tmp_file = output_file + ".tmp"
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(ordered, f, ensure_ascii=False, indent=4)
            os.replace(tmp_file, output_file)
            print(f"✅ id={cur_id} 已保存到 {output_file}")

        except Exception as e:
            print(f"❌ id={cur_id} 处理失败: {e}")
            traceback.print_exc()  # print the full stack trace
            break  # stop at the first failure; use `continue` to keep going

if __name__ == "__main__":
    # Paths are taken from the command line instead of being edited into
    # the source file.
    import argparse

    parser = argparse.ArgumentParser(
        description="Judge pairs of assistant answers stored in a JSON file."
    )
    parser.add_argument("input_path", help="path to the input JSON file")
    parser.add_argument("output_dir", help="directory for per-record result files")
    parser.add_argument(
        "--start-id",
        default="93",
        help="id of the record to start from; compared as a string "
             "(default: %(default)s)",
    )
    args = parser.parse_args()

    main(args.input_path, args.output_dir, args.start_id)
