#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys, json, time
from pathlib import Path
from typing import Dict, List, Tuple
import concurrent.futures as cf

import openai
import pandas as pd
from tqdm import tqdm

# ————— 1. OpenAI account / client configuration —————
# Module-level client shared by every request issued from call_gpt().
# NOTE(review): base_url and api_key are empty strings here; presumably they
# have to be filled in before the script is run — confirm.
openai_client = openai.OpenAI(
    base_url="",
    api_key="",
)
MODEL_NAME  = "deepseek-chat"  # model id passed to chat.completions.create
TEMPERATURE = 1.0  # sampling temperature passed with every request
MAX_WORKERS = 64   # max_workers of the ThreadPoolExecutor created in main()

# ————— 2. Prompt template —————
# Filled in with str.format() using the fields {question}, {reference},
# {resp1} and {resp2}. The (Chinese) text asks the judge model to compare two
# answers against the reference answer, give a short analysis, and finish with
# a verdict on the final line: answer one better, answer two better, or both
# equal.
PROMPT = """现有如下问题：

问题：{question}
该问题的**标准答案**如下：{reference}

我们现在得到了两份学生对该问题的回答。
回答一的内容如下：{resp1}
回答二的内容如下：{resp2}

现在，你应当以该问题的内容为上下文背景，必须以标准答案为判决依据，判决答案一和答案二中哪份更贴近标准答案的内容。
现在，请你先展开简单的分析，随后给出结论：回答一或回答二更优秀
有时，回答一与回答二和标准答案相比也有可能不相上下。

你的回答必须在最后一行以如下形式给出：
因此，【回答一】更优秀 或 因此，【回答二】更优秀
若你认为这两个回答不相上下，你应当在最后一行以如下形式给出：
因此，【两个回答相同】
"""

# Comparisons to run for every item. Each entry is
# (key of the first answer, key of the second answer, label key).
# The first two keys select the texts inserted as {resp1} / {resp2} in PROMPT;
# the label key is used both to read the existing human label from the item
# and as the column-name prefix for the model outputs.
PAIR_KEYS = [
    ("student_answer_a", "student_answer_b", "student_answer_a vs student_answer_b_new"),
    ("student_answer_a", "model_answer_a",    "student_answer_a vs model_answer_a_new"),
    ("student_answer_a", "model_answer_b",    "student_answer_a vs model_answer_b_new"),
    ("model_answer_a",   "model_answer_b",    "model_answer_a vs model_answer_b_new"),
]

def call_gpt(prompt: str) -> str:
    """Send a single-turn chat request and return the model's reply text.

    The request is retried indefinitely on ``openai.RateLimitError`` and
    ``openai.APIError``. The wait between attempts starts at 5 seconds and
    doubles after every failure, capped at 60 seconds, so a struggling
    endpoint is not hit at a fixed rate by every worker thread.

    Args:
        prompt: Complete user message to send.

    Returns:
        The reply text with surrounding whitespace stripped, or an empty
        string when the response carries no text content.
    """
    backoff = 5
    max_backoff = 60
    while True:
        # Only the network call sits inside the try so that unrelated errors
        # raised while reading the response are never mistaken for API errors.
        try:
            rsp = openai_client.chat.completions.create(
                model=MODEL_NAME,
                temperature=TEMPERATURE,
                max_tokens=512,
                messages=[{"role": "user", "content": prompt}],
            )
        except (openai.RateLimitError, openai.APIError) as e:
            print(f"OpenAI error: {e}. Retry in {backoff}s", file=sys.stderr)
            time.sleep(backoff)
            backoff = min(backoff * 2, max_backoff)
        else:
            # message.content may be None; treat that as an empty reply
            # instead of failing on None.strip().
            content = rsp.choices[0].message.content
            return (content or "").strip()

def last_line(text: str) -> str:
    """Return the last non-blank line of *text*, stripped of whitespace.

    Returns an empty string when *text* has no non-blank line.
    """
    # Walk backwards so we can stop at the first line that has content.
    for candidate in reversed(text.splitlines()):
        trimmed = candidate.strip()
        if trimmed:
            return trimmed
    return ""

def process_pair(item: Dict, k1: str, k2: str, label_key: str) -> Tuple[str, str, str, str]:
    """Ask the judge model to compare two answers of one item.

    Args:
        item: Record providing "question", "Concise_Reference", the two
            answer fields named by *k1* / *k2*, and the field *label_key*.
        k1: Key of the answer presented as the first response.
        k2: Key of the answer presented as the second response.
        label_key: Key of the existing label stored in *item*.

    Returns:
        ``(label_key, full model reply, last non-blank line of the reply,
        item[label_key])``.
    """
    template_fields = {
        "question": item["question"],
        "reference": item["Concise_Reference"],
        "resp1": item[k1],
        "resp2": item[k2],
    }
    model_reply = call_gpt(PROMPT.format(**template_fields))
    # The label is looked up last, after the model call, as the tuple is built.
    return label_key, model_reply, last_line(model_reply), item[label_key]

def process_item(item: Dict, executor: cf.ThreadPoolExecutor) -> Dict:
    """Run every comparison in PAIR_KEYS for one item and build its result row.

    All comparisons are submitted to *executor* at once. Each verdict is
    printed as soon as its request finishes, but the row is assembled in
    PAIR_KEYS order so that the key (and therefore spreadsheet column) order
    is the same for every item, regardless of which request completes first.

    Args:
        item: Input record; must contain "id" plus the fields read by
            process_pair.
        executor: Thread pool used to run the comparisons concurrently.

    Returns:
        A dict with "id" and, for every label key, the entries
        "<label>_model_full", "<label>_model_verdict" and "<label>_human".

    Raises:
        Whatever exception a comparison raised; it is re-raised by
        ``Future.result()``.
    """
    row = {"id": item["id"]}

    futures = [
        executor.submit(process_pair, item, k_left, k_right, label_key)
        for k_left, k_right, label_key in PAIR_KEYS
    ]

    # Collect in completion order (keeps the progress output live) ...
    results = {}
    for fut in cf.as_completed(futures):
        label_key, full, verdict, human = fut.result()
        results[label_key] = (full, verdict, human)
        print(f" {label_key}: {verdict}")

    # ... but fill the row in the fixed PAIR_KEYS order.
    for _, _, label_key in PAIR_KEYS:
        full, verdict, human = results[label_key]
        row[f"{label_key}_model_full"] = full
        row[f"{label_key}_model_verdict"] = verdict
        row[f"{label_key}_human"] = human

    return row

def _save_part(rows: List[Dict], out_prefix: Path, part: int) -> None:
    """Write *rows* to "<stem>_part<part>.xlsx" in out_prefix's directory."""
    out_path = out_prefix.parent / f"{out_prefix.stem}_part{part}.xlsx"
    pd.DataFrame(rows).to_excel(out_path, index=False)
    print(f"Saved {out_path}  ({len(rows)})")


def main(json_path: Path, out_prefix: Path, start_index: int = 0, chunk_size: int = 100):
    """Judge every item from *json_path* and save the results in Excel parts.

    Args:
        json_path: UTF-8 JSON file whose top-level value is a list of items.
        out_prefix: Output path; its stem is the file-name prefix and its
            parent directory receives the "<stem>_part<N>.xlsx" files. The
            directory is created if it does not exist.
        start_index: Index of the first item to process (``data[start_index:]``).
        chunk_size: Number of rows written per part file.
    """
    data = json.loads(json_path.read_text(encoding="utf-8"))
    data_to_process = data[start_index:]

    # Create the target directory before any request is made, so a missing
    # directory cannot cause the first save to fail after a chunk of work.
    out_prefix.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): part numbering always starts at 1, independent of
    # start_index, so a re-run with the same prefix writes to the same file
    # names as an earlier run.
    rows, part = [], 0

    with cf.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        for item in tqdm(data_to_process, desc="GPT Inference"):
            rows.append(process_item(item, executor))

            if len(rows) >= chunk_size:
                part += 1
                _save_part(rows, out_prefix, part)
                rows.clear()

        # Flush whatever is left after the last full chunk.
        if rows:
            part += 1
            _save_part(rows, out_prefix, part)

if __name__ == "__main__":
    # Exactly three positional arguments are required after the script name.
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        sys.exit("usage: python pairwise_inference_save_full.py  data.json  results.xlsx  start_index")

    json_arg, output_arg, start_arg = cli_args
    main(Path(json_arg), Path(output_arg), int(start_arg))
