#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os, re, time, json, argparse
import pandas as pd
from typing import Tuple, Dict, Any
from openai import OpenAI

# -----------------------------
# Config
# -----------------------------
# Judge model used when the --model command-line option is not given.
DEFAULT_MODEL = "gpt-4o"

# Pairwise-vote template, filled by build_prompt() through str.format with the
# fields question / caption / gt / ans1 / ans2. The judge is asked to answer
# with a bare -1 (Answer 1 better), 1 (Answer 2 better) or 0 (tie).
PAIR_PROMPT = """Given a question along with the caption of an image and gt_answer, evaluate the two provided candidate answers.
Determine which answer is accurate, logical, and helpful.

Question:
{question}

Caption:
{caption}

Ground Truth:
{gt}

Answer 1:
{ans1}

Answer 2:
{ans2}

If you think Answer 1 is better, respond with -1.
If Answer 2 is better respond with 1.
If you think the result is tie, output 0.

Only respond with either -1 or 0 or 1.
"""

# Scoring variant of the prompt (asks the judge for a JSON object).
# The template is filled by build_score_prompt() through str.format, so the
# literal braces of the JSON schema are doubled ({{ }}) while the single-brace
# names (question / caption / gt / ans1 / ans2) are substitution fields.
SCORE_PROMPT = """You are a strict judge. Score two candidate answers with the rubric:

- coverage (0–10): does the answer cover the key information in the ground truth
- accuracy (0–10): factual consistency against the ground truth (penalize conflicts/errors)
- details (0–10): useful, reasonable extra details
- fluency (0–10): clarity and naturalness of language

Compute weighted total (0–10) with:
  total = 0.40*coverage + 0.40*accuracy + 0.15*details + 0.05*fluency
Round totals to one decimal.

Also decide which answer is better:
  decision = -1 if Answer 1 is better; 1 if Answer 2 is better; 0 if tie.

Return ONLY a JSON object with this exact schema:
{{
  "decision": -1 | 0 | 1,
  "answer1": {{"coverage": <0-10>, "accuracy": <0-10>, "details": <0-10>, "fluency": <0-10>, "total": <0-10>}},
  "answer2": {{"coverage": <0-10>, "accuracy": <0-10>, "details": <0-10>, "fluency": <0-10>, "total": <0-10>}}
}}

Question:
{question}

Caption:
{caption}

Ground Truth:
{gt}

Answer 1:
{ans1}

Answer 2:
{ans2}
"""


# -----------------------------
# Helpers
# -----------------------------
def build_prompt(q: str, caption: str, gt: str, ans1: str, ans2: str) -> str:
    """Render the pairwise-vote prompt (PAIR_PROMPT) for one comparison.

    Falsy ``q`` / ``caption`` / ``gt`` values are rendered as empty strings;
    both candidate answers are always converted with ``str()``.
    """
    fields = {
        "question": q if q else "",
        "caption": caption if caption else "",
        "gt": gt if gt else "",
        "ans1": str(ans1),
        "ans2": str(ans2),
    }
    return PAIR_PROMPT.format(**fields)

def build_score_prompt(q: str, caption: str, gt: str, ans1: str, ans2: str) -> str:
    """Render the rubric-scoring prompt (SCORE_PROMPT) for one comparison.

    Falsy ``q`` / ``caption`` / ``gt`` values are rendered as empty strings;
    both candidate answers are always converted with ``str()``.
    """
    question_text = q if q else ""
    caption_text = caption if caption else ""
    gt_text = gt if gt else ""
    first, second = str(ans1), str(ans2)
    return SCORE_PROMPT.format(
        question=question_text,
        caption=caption_text,
        gt=gt_text,
        ans1=first,
        ans2=second,
    )

def parse_vote(text: str) -> str:
    """Extract the judge's vote from a raw model reply.

    Args:
        text: Raw reply text; ``None`` or an empty string is accepted.

    Returns:
        ``"-1"``, ``"0"`` or ``"1"``. Falls back to ``"0"`` (tie) when no
        standalone vote can be found.
    """
    cleaned = (text or "").strip()
    # Models sometimes emit a typographic minus/dash; without normalising it
    # "−1" would be read as "1", i.e. the opposite vote.
    for dash in ("\u2212", "\u2013", "\u2014"):
        cleaned = cleaned.replace(dash, "-")
    # Only accept -1/0/1 as a standalone number: not a digit of a larger
    # number ("10"), nor the integer part of a decimal ("0.5"). A trailing
    # sentence period ("1.") is still accepted.
    m = re.search(r"(?<![\d.\-])(-1|0|1)(?!\d|\.\d)", cleaned)
    return m.group(1) if m else "0"

def call_chat(client: OpenAI, model: str, prompt: str, max_tokens: int = 8, json_mode: bool = False, max_retries: int = 3) -> str:
    """Send one prompt to the Chat Completions API and return the reply text.

    Args:
        client: Object exposing ``chat.completions.create(**kwargs)``.
        model: Model name passed through to the API.
        prompt: Content of the user message.
        max_tokens: Output token limit. In JSON mode it is raised to at least
            300 so the score object is not truncated.
        json_mode: When true, a system message asking for a JSON-only reply is
            prepended to the conversation.
        max_retries: Total number of attempts; values below 1 are treated as 1.

    Returns:
        The stripped message content, or ``""`` when the content is ``None``.

    Raises:
        RuntimeError: If every attempt raised; the last error is chained as
            ``__cause__``.
    """
    messages = [{"role": "user", "content": prompt}]
    if json_mode:
        messages.insert(
            0,
            {"role": "system", "content": "Respond with a single valid JSON object and nothing else."},
        )
        # Give the JSON structure enough room, but honour a larger caller limit.
        max_tokens = max(max_tokens, 300)

    attempts = max(1, max_retries)
    last_err = None
    for attempt in range(1, attempts + 1):
        try:
            kwargs: Dict[str, Any] = dict(
                model=model,
                temperature=0.0,
                max_tokens=max_tokens,
                messages=messages,
            )
            rsp = client.chat.completions.create(**kwargs)
            return (rsp.choices[0].message.content or "").strip()
        except Exception as e:  # best-effort retry on any failure
            last_err = e
            # Linear back-off between attempts; no point sleeping after the last one.
            if attempt < attempts:
                time.sleep(1.5 * attempt)
    raise RuntimeError(f"OpenAI call failed after {attempts} attempts: {last_err}") from last_err

def safe_parse_scores(txt: str) -> Dict[str, Any]:
    """Parse the judge's JSON reply into a normalised score dict; never raises.

    Args:
        txt: Raw model reply. May be wrapped in a Markdown code fence or
            surrounded by extra prose; ``None`` is accepted.

    Returns:
        The parsed object with ``decision`` coerced to -1/0/1 and every metric
        of ``answer1`` / ``answer2`` coerced to a float clamped to [0, 10] and
        rounded to one decimal. If the reply cannot be parsed or a metric is
        missing, an all-zero object with ``decision`` 0 is returned instead.
    """
    metrics = ("coverage", "accuracy", "details", "fluency", "total")
    fallback = {
        "decision": 0,
        "answer1": dict.fromkeys(metrics, 0.0),
        "answer2": dict.fromkeys(metrics, 0.0),
    }
    try:
        s = (txt or "").strip()
        # Strip an optional Markdown code fence.
        if s.startswith("```"):
            s = re.sub(r"^```(json)?", "", s).strip()
            if s.endswith("```"):
                s = s[:-3].strip()
        try:
            obj = json.loads(s)
        except ValueError:
            # The model added prose around the object: retry on the span from
            # the first "{" to the last "}".
            start, end = s.find("{"), s.rfind("}")
            if start == -1 or end <= start:
                raise
            obj = json.loads(s[start:end + 1])
        if not isinstance(obj, dict):
            raise TypeError("judge reply is not a JSON object")

        # Coerce and clamp every metric to the legal range.
        for key in ("answer1", "answer2"):
            for name in metrics:
                v = float(obj[key][name])
                if v != v:
                    # NaN would otherwise slip through min()/max() as 10.0.
                    v = 0.0
                obj[key][name] = round(max(0.0, min(10.0, v)), 1)

        # A malformed decision should not discard otherwise valid scores.
        try:
            decision = int(float(obj.get("decision", 0)))
        except (TypeError, ValueError, OverflowError):
            decision = 0
        obj["decision"] = decision if decision in (-1, 0, 1) else 0
        return obj
    except Exception:
        # Deliberate best-effort: any malformed reply degrades to the fallback.
        return fallback

def compute_winrate(df: pd.DataFrame) -> Tuple[float, int, int, int]:
    """Summarise the ``vote_order1`` column and print the win/tie counts.

    A vote of -1 counts as a baseline win, 1 as a sparse win and 0 as a tie.
    Votes are compared numerically, so the column may hold strings ("1") or
    numbers (1), e.g. after a results CSV has been reloaded with pandas.
    Values that are not numeric are ignored.

    Args:
        df: Frame with a ``vote_order1`` column.

    Returns:
        ``(win_rate, sparse_wins, baseline_wins, ties)`` where ``win_rate`` is
        sparse wins over all non-tie votes, or 0.0 when there are none.
    """
    votes = pd.to_numeric(df["vote_order1"], errors="coerce")
    wins_sparse = int((votes == 1).sum())
    wins_base = int((votes == -1).sum())
    ties = int((votes == 0).sum())
    denom = wins_sparse + wins_base
    wr = wins_sparse / denom if denom > 0 else 0.0

    print("\n=== GPT-4 Judge Summary (vote_order1 only) ===")
    print(f"baseline wins : {wins_base}")
    print(f"sparse wins   : {wins_sparse}")
    print(f"ties          : {ties}")
    print(f"Sparse Winning Rate (no ties): {wr:.3f}")
    return wr, wins_sparse, wins_base, ties

def pivot_rows(df: pd.DataFrame) -> list:
    """Convert the input frame into a list of per-sample dicts.

    Missing cells (None / NaN / pd.NA, e.g. empty CSV fields) are mapped to
    an empty string so they do not end up as the literal text "nan" in the
    judge prompts.

    Args:
        df: Frame with the columns question, gt_answer, baseline_txt,
            two_scale_txt and roi_txt; caption and id are optional.

    Returns:
        One dict per row with the keys id, question, caption, gt_answer,
        baseline_txt, two_scale_txt and roi_txt. The three answer fields are
        always strings; ``id`` falls back to the row's index label when the
        column is absent or the cell is empty.

    Raises:
        ValueError: If a required column is missing.
    """
    required = ["question", "gt_answer", "baseline_txt", "two_scale_txt", "roi_txt"]
    if any(c not in df.columns for c in required):
        raise ValueError(f"CSV 需要包含列：{required}；可选列：caption, id")

    def _blank_if_missing(value):
        """Return "" for None / NaN / pd.NA, otherwise the value unchanged."""
        if value is None or value is pd.NA:
            return ""
        if isinstance(value, float) and value != value:
            return ""
        return value

    rows = []
    for i, r in df.iterrows():
        row_id = _blank_if_missing(r.get("id", i))
        if isinstance(row_id, str) and row_id == "":
            row_id = i
        rows.append({
            "id": row_id,
            "question": _blank_if_missing(r.get("question", "")),
            "caption": _blank_if_missing(r.get("caption", "")),
            "gt_answer": _blank_if_missing(r.get("gt_answer", "")),
            "baseline_txt": str(_blank_if_missing(r.get("baseline_txt", ""))),
            "two_scale_txt": str(_blank_if_missing(r.get("two_scale_txt", ""))),
            "roi_txt": str(_blank_if_missing(r.get("roi_txt", ""))),
        })
    return rows

# -----------------------------
# Main
# -----------------------------
def main():
    """Run the judge over every row of the input CSV and write the results.

    For each row the selected answer pair is judged twice: once with the
    scoring prompt (JSON scores plus decision, stored as ``vote_order1``) and
    once with the pairwise prompt with the answers swapped (stored as
    ``vote_order2``). Per-row results go to ``--out``; per-metric statistics
    go to ``<out>_summary.csv``. The printed win rate uses ``vote_order1`` only.

    Raises:
        SystemExit: If the OPENAI_API_KEY environment variable is not set.
    """
    from tqdm import tqdm

    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="输入结果 CSV（需含 question, gt_answer, baseline_txt, two_scale_txt, roi_txt；可选 caption, id）")
    ap.add_argument("--out", default="gpt4_judge_results.csv", help="输出评审明细 CSV")
    ap.add_argument("--model", default=DEFAULT_MODEL, help="GPT-4 系列模型名，如 gpt-4o / gpt-4.1 / gpt-4o-mini")
    ap.add_argument("--comp_mode", type=int, default=0, choices=[0, 1],
                    help="0 denotes baseline vs roi; 1 denotes roi vs two_scale")

    args = ap.parse_args()

    if not os.environ.get("OPENAI_API_KEY"):
        raise SystemExit("ERROR: OPENAI_API_KEY not set")

    client = OpenAI()
    df_in = pd.read_csv(args.csv)
    pairs = pivot_rows(df_in)
    print(f"[Info] Found {len(pairs)} pairs to judge with model={args.model}.")

    if not pairs:
        # Without rows there are no score columns to summarise.
        print("[Info] Nothing to judge: the input CSV has no rows.")
        return

    # Which input columns play the role of Answer 1 / Answer 2.
    # NOTE: the output columns are always labelled "baseline" (Answer 1) and
    # "sparse" (Answer 2), whichever pair is selected here.
    col1, col2 = {
        0: ("baseline_txt", "roi_txt"),
        1: ("roi_txt", "two_scale_txt"),
    }[args.comp_mode]

    out_rows = []
    try:
        for idx, x in enumerate(tqdm(pairs, desc="Evaluating pairs")):
            q, cap, gt = x["question"], x["caption"], x["gt_answer"]
            a1, a2 = x[col1], x[col2]

            # ---- order 1: scores + decision (Answer 1 = a1, Answer 2 = a2) ----
            score_prompt = build_score_prompt(q, cap, gt, a1, a2)
            score_txt = call_chat(client, args.model, score_prompt, json_mode=True, max_tokens=300)
            score_obj = safe_parse_scores(score_txt)
            v1 = str(score_obj.get("decision", 0))

            # ---- order 2: one swapped pairwise vote (Answer 1 = a2, Answer 2 = a1) ----
            # The vote is stored as returned, i.e. relative to the swapped order.
            pair_prompt = build_prompt(q, cap, gt, a2, a1)
            v2 = parse_vote(call_chat(client, args.model, pair_prompt, max_tokens=8, json_mode=False))

            s1, s2 = score_obj["answer1"], score_obj["answer2"]
            out_rows.append({
                "id": x["id"],
                "question": q,
                "caption": cap,
                "gt_answer": gt,
                "baseline": a1,
                "sparse": a2,
                "vote_order1": v1,   # from scoring JSON decision
                "vote_order2": v2,   # swapped order check
                # Answer 1 scores
                "base_coverage": s1["coverage"],
                "base_accuracy": s1["accuracy"],
                "base_details":  s1["details"],
                "base_fluency":  s1["fluency"],
                "base_total":    s1["total"],
                # Answer 2 scores
                "sparse_coverage": s2["coverage"],
                "sparse_accuracy": s2["accuracy"],
                "sparse_details":  s2["details"],
                "sparse_fluency":  s2["fluency"],
                "sparse_total":    s2["total"],
                # Raw judge reply (truncated), kept for debugging
                "judge_json": score_txt[:400],
            })

            if (idx + 1) % 10 == 0:
                print(f"  judged {idx+1}/{len(pairs)}")
    finally:
        # Persist whatever has been judged so far, so a failed API call or an
        # interrupt late in the run does not throw away the finished rows.
        if out_rows:
            pd.DataFrame(out_rows).to_csv(args.out, index=False)

    df_out = pd.DataFrame(out_rows)

    # ---- Global per-metric statistics ----
    metrics = ["coverage", "accuracy", "details", "fluency", "total"]

    summary_rows = []
    for m in metrics:
        base_vals = pd.to_numeric(df_out[f"base_{m}"], errors="coerce")
        sparse_vals = pd.to_numeric(df_out[f"sparse_{m}"], errors="coerce")
        delta = sparse_vals - base_vals

        summary_rows.append({
            "metric": m,
            "baseline_avg": base_vals.mean(),
            "sparse_avg": sparse_vals.mean(),
            "delta_avg": delta.mean(),          # > 0 means Answer 2 scored higher on average
            "baseline_std": base_vals.std(ddof=1),
            "sparse_std": sparse_vals.std(ddof=1),
            "baseline_median": base_vals.median(),
            "sparse_median": sparse_vals.median(),
            "n": int(base_vals.notna().sum())   # number of valid samples
        })

    summary_df = pd.DataFrame(summary_rows)

    print("\n=== Global metric stats (0–10) ===")
    for _, r in summary_df.iterrows():
        print(f"{r['metric']:>8} | base_avg={r['baseline_avg']:.2f}  "
            f"sparse_avg={r['sparse_avg']:.2f}  delta_avg={r['delta_avg']:.2f}  "
            f"n={int(r['n'])}")

    # Also save the statistics next to the detailed results.
    summary_path = os.path.splitext(args.out)[0] + "_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved summary to: {summary_path}")

    # Win-rate summary (vote_order1 only)
    compute_winrate(df_out)


# Run the judge pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
