#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys, json, time
from pathlib import Path
from typing import Dict, Tuple
import concurrent.futures as cf

import openai
import pandas as pd
from tqdm import tqdm

# ————— 1. OpenAI account —————
# Shared client used by call_gpt() for every request. base_url and api_key
# are empty strings in this file — presumably they have to be filled in
# before the script can reach an endpoint (TODO confirm).
openai_client = openai.OpenAI(
    base_url="",
    api_key="",
)
MODEL_NAME  = "deepseek-chat"  # model name sent with each chat completion request
TEMPERATURE = 0.0  # sampling temperature sent with each request (author noted 1.0 as an alternative)
MAX_WORKERS = 64  # size of the thread pool created in main()
MAX_RETRIES = 10  # total number of attempts call_gpt() makes per prompt before giving up


# Judge prompt template. process_pair() fills the {question}, {reference},
# {resp1} and {resp2} placeholders via str.format(). The template asks the
# model to end with a "Therefore, ..." sentence; process_pair() takes the last
# non-blank line of the reply as the verdict.
PROMPT = """We have the following question:

Question: {question}
The reference (standard) answer to this question is as follows: {reference}

We now have two student responses to this question.
Response 1 is as follows: {resp1}
Response 2 is as follows: {resp2}

Now, you should evaluate the two responses based on the content of the question, using the standard answer as the sole basis for judgment.
Determine which of the two—Response 1 or Response 2—is closer to the reference answer in terms of content.

Please begin with a brief analysis, and then provide your final judgment in one of the following forms:

If one response is better:
"Therefore, [Response 1] is better." or "Therefore, [Response 2] is better."

If the two responses are roughly equal in quality:
"Therefore, [Both responses are equal]."
"""


# Comparisons run for every dataset item. Each tuple is
#   (item key used as Response 1, item key used as Response 2, label key).
# The label key is looked up in the item to obtain the value stored in the
# "<label key>_human" column, and it also serves as the prefix of the three
# output columns written by process_item().
PAIR_KEYS = [
    ("student_answer_a_en", "student_answer_b_en", "student_answer_a_vs_student_answer_b"),
    ("student_answer_a_en", "model_answer_a_en",  "student_answer_a_vs_model_answer_a"),
    ("student_answer_a_en", "model_answer_b_en",  "student_answer_a_vs_model_answer_b"),
    ("student_answer_b_en", "model_answer_a_en",  "student_answer_b_vs_model_answer_a"),
    ("student_answer_b_en", "model_answer_b_en",  "student_answer_b_vs_model_answer_b"),
    ("model_answer_a_en",   "model_answer_b_en",  "model_answer_a_vs_model_answer_b"),
]


def call_gpt(prompt: str) -> str:
    """Send *prompt* as a single user message and return the model's reply.

    The request is attempted up to ``MAX_RETRIES`` times. API errors and
    unusable replies (no choices, no message, or blank content) cause a retry
    after a sleep that starts at 5 seconds and doubles up to a 60 second cap.

    Args:
        prompt: Full text of the user message.

    Returns:
        The reply content with surrounding whitespace stripped.

    Raises:
        RuntimeError: If no attempt produced a non-empty reply.
    """
    backoff = 5
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            rsp = openai_client.chat.completions.create(
                model=MODEL_NAME,
                temperature=TEMPERATURE,
                max_tokens=512,
                messages=[{"role": "user", "content": prompt}],
            )
            # A response without choices or without a message would otherwise
            # raise IndexError/TypeError/AttributeError, which the handler
            # below does not catch; fold those cases into the "empty" path so
            # they are retried like any other unusable reply.
            choices = rsp.choices
            message = choices[0].message if choices else None
            content = message.content if message is not None else None
            text = content.strip() if content else ""
            if text:
                return text
            raise ValueError("Empty content returned from model")

        except (openai.OpenAIError, ValueError) as e:
            if attempt >= MAX_RETRIES:
                raise RuntimeError(f"Failed after {MAX_RETRIES} retries: {e}") from e
            print(f"OpenAI error/empty response: {e}. "
                  f"Retry {attempt}/{MAX_RETRIES} in {backoff}s", file=sys.stderr)
            time.sleep(backoff)
            backoff = min(backoff * 2, 60)

    # Only reachable when MAX_RETRIES <= 0 (the loop body never ran); fail
    # loudly instead of silently returning None from a function typed -> str.
    raise RuntimeError(f"Failed after {MAX_RETRIES} retries: no request was attempted")

def last_line(text: str) -> str:
    """Return the final non-blank line of *text*, stripped of whitespace.

    Returns an empty string when *text* contains no non-blank line.
    """
    # Walk the lines from the end and stop at the first one with content.
    for raw in reversed(text.splitlines()):
        candidate = raw.strip()
        if candidate:
            return candidate
    return ""

def process_pair(item: Dict, k1: str, k2: str, label_key: str) -> Tuple[str, str, str, str]:
    """Ask the model to compare two responses of one dataset item.

    Args:
        item: Dataset record; read for "Question_en", "Concise_Reference_en",
            *k1*, *k2* and *label_key*.
        k1: Key of the text inserted as Response 1.
        k2: Key of the text inserted as Response 2.
        label_key: Key of the stored label for this pair.

    Returns:
        ``(label_key, full model reply, last non-blank line of the reply,
        item[label_key])``.
    """
    fields = {
        "question": item["Question_en"],
        "reference": item["Concise_Reference_en"],
        "resp1": item[k1],
        "resp2": item[k2],
    }
    reply = call_gpt(PROMPT.format(**fields))
    # The template asks for the judgment as the closing sentence, so the last
    # non-blank line is taken as the verdict.
    return label_key, reply, last_line(reply), item[label_key]

def process_item(item: Dict, executor: cf.ThreadPoolExecutor) -> Dict:
    """Run every comparison in ``PAIR_KEYS`` for *item* and build one output row.

    The comparisons are submitted to *executor* and run concurrently. Each
    verdict is printed as soon as its comparison finishes, but the row itself
    is filled in ``PAIR_KEYS`` order so that every row has the same key order
    regardless of which request happened to finish first.

    Args:
        item: Dataset record; must contain "id" and the keys named in
            ``PAIR_KEYS``.
        executor: Thread pool the comparisons are submitted to.

    Returns:
        A dict with "id" followed by, for each pair, the columns
        "<label>_model_full", "<label>_model_verdict" and "<label>_human".

    Raises:
        Whatever ``process_pair`` raises for any of the comparisons.
    """
    row = {"id": item["id"]}

    futures = [
        executor.submit(process_pair, item, k_left, k_right, label_key)
        for k_left, k_right, label_key in PAIR_KEYS
    ]

    # Gather in completion order (keeps progress output prompt) ...
    results = {}
    for fut in cf.as_completed(futures):
        label_key, full, verdict, human = fut.result()
        results[label_key] = (full, verdict, human)
        print(f"[{item['id']}] {label_key}: {verdict}")

    # ... but write the columns in the fixed PAIR_KEYS order. Inserting them in
    # completion order made the column layout differ between rows and runs.
    for _, _, label_key in PAIR_KEYS:
        full, verdict, human = results[label_key]
        row[f"{label_key}_model_full"]    = full
        row[f"{label_key}_model_verdict"] = verdict
        row[f"{label_key}_human"]         = human

    return row

# Number of rows written to each part workbook.
_ROWS_PER_PART = 50


def _part_path(out_prefix: Path, part: int, start_index: int) -> Path:
    """Return the workbook path for chunk number *part*."""
    # Tag resumed runs with their offset so they do not overwrite the part
    # files of an earlier run that started at a different index.
    offset_tag = f"_from{start_index}" if start_index else ""
    return out_prefix.parent / f"{out_prefix.stem}{offset_tag}_part{part}.xlsx"


def _save_part(rows: list, out_prefix: Path, part: int, start_index: int) -> None:
    """Write *rows* to the workbook for chunk *part* and log the saved path."""
    target = _part_path(out_prefix, part, start_index)
    target.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(rows).to_excel(target, index=False)
    print(f"Saved {target}  ({len(rows)})")


def main(json_path: Path, out_prefix: Path, start_index: int = 0):
    """Judge every item of a JSON dataset and save the results in chunks.

    Items are processed one after another (the comparisons of a single item
    run concurrently). Every ``_ROWS_PER_PART`` rows are written to
    ``<stem>_part<N>.xlsx`` next to *out_prefix*; when *start_index* is
    non-zero the name becomes ``<stem>_from<start_index>_part<N>.xlsx``.
    Rows collected so far are also written if processing stops with an
    exception, so completed work is not lost.

    Args:
        json_path: UTF-8 JSON file; its top-level value is sliced with
            ``[start_index:]``.
        out_prefix: Output path whose directory and stem are used to name the
            part workbooks (its own suffix is ignored).
        start_index: Index of the first item to process.
    """
    data = json.loads(json_path.read_text(encoding="utf-8"))
    data_to_process = data[start_index:]

    rows, part = [], 0

    with cf.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        try:
            for item in tqdm(data_to_process, desc="GPT Inference"):
                rows.append(process_item(item, executor))

                if len(rows) >= _ROWS_PER_PART:
                    part += 1
                    _save_part(rows, out_prefix, part, start_index)
                    rows.clear()
        finally:
            # Runs on normal completion and on failure alike: flush whatever
            # has been collected since the last full chunk.
            if rows:
                part += 1
                _save_part(rows, out_prefix, part, start_index)

if __name__ == "__main__":
    # Command line: data.json results.xlsx [start_index]
    # start_index is optional here, mirroring main()'s default of 0.
    usage = "usage: python pairwise_inference_save_full.py  data.json  results.xlsx  [start_index]"
    if len(sys.argv) not in (3, 4):
        sys.exit(usage)

    json_file   = Path(sys.argv[1])
    output_file = Path(sys.argv[2])
    try:
        start_idx = int(sys.argv[3]) if len(sys.argv) == 4 else 0
    except ValueError:
        # Exit with a readable message instead of a traceback.
        sys.exit(f"start_index must be an integer, got {sys.argv[3]!r}\n{usage}")

    main(json_file, output_file, start_idx)
