from __future__ import annotations
import argparse
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List

from openai import AsyncOpenAI

# Default settings for the evaluator. Options given on the command line are
# merged over these defaults in the ``__main__`` block, so a CLI value wins.
# NOTE: there is no default for "ans_path" although run() reads cfg["ans_path"];
# it has to be supplied with --ans_path.
CONFIG: Dict[str, Any] = {
    # Model id sent with every grading request.
    "model": "o4-mini",
    # JSON file whose top-level "results" list holds the records to grade.
    "input_path": Path("./exp/...json"),
    # Destination for the graded records (parent directories are created).
    "output_path": Path("./exp/...json"),
    # Size of the semaphore that caps simultaneous API calls.
    "max_concurrent": 20,
    # Handed to grade_pair(), which currently does not forward it to the API request.
    "temperature": 0.0,
    # Instructions for the grader model: binary correct/incorrect verdict,
    # returned as a single JSON object with "label" and "rationale".
    "system_prompt": (
        "You are an exacting fact-checker.\n\n"
        "Task\n"
        "––––\n"
        "Compare the **REFERENCE ANSWER** with the **CANDIDATE** answer and judge\n"
        "whether they express the *same factual content*.\n"
        "Ignore writing style, order, or extra background; focus only on whether\n"
        "the key facts (names, dates, amounts, percentages, etc.) truly match.\n\n"
        "Instructions\n"
        "------------\n"
        "1. If every critical fact in the candidate unambiguously agrees with the\n"
        "   reference, label **correct**.\n"
        "   • Paraphrases or synonymous wording are acceptable.\n"
        "   • Numbers must be equal after unit conversion/rounding (e.g. “431 M” = “$431 million”).\n"
        "2. Otherwise, label **incorrect** (do *not* use “partial”, “unknown”, etc.).\n\n"
        "Output format\n"
        "-------------\n"
        "Return a single-line JSON object *exactly* like:\n\n"
        '{"label": "correct", "rationale": "..."}\n\n'
        "• \"label\" – either \"correct\" or \"incorrect\" (lowercase).\n"
        "• \"rationale\" – a clear, factual explanation in **under 50 words**.\n"
        "  – If correct, explain briefly why the key facts match.\n"
        "  – If incorrect, identify the first major contradiction (e.g. different person, date, amount).\n"
        "  – Use specific, content-based justifications (not generic statements).\n\n"
        "Constraints\n"
        "-----------\n"
        "* Do not include extra keys, arrays, markdown, or formatting.\n"
        "* Do not quote the input or repeat the entire answer.\n"
        "* Focus only on factual correctness – avoid style or completeness judgments.\n"
        "* Do not make up metrics or mention unrelated information like percentages unless directly relevant.\n"
    ),
}

def load_json(path: Path) -> Any:
    """Parse the UTF-8 encoded JSON document stored at *path* and return it."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)

def load_jsonl(path: str | Path) -> List[Any]:
    """Read a JSON Lines file and return one decoded value per non-blank line.

    Args:
        path: Location of a UTF-8 encoded file with one JSON value per line.

    Returns:
        The decoded values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records: List[Any] = []
    with open(path, "r", encoding="utf-8") as lines:
        for line in lines:
            # Blank lines (typically a trailing newline at end of file) carry
            # no record; json.loads would reject them, so skip them.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

def dump_json(obj: Any, path: Path) -> None:
    """Write *obj* to *path* as indented, non-ASCII-preserving UTF-8 JSON.

    Missing parent directories of *path* are created first.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    text = json.dumps(obj, indent=2, ensure_ascii=False)
    with path.open("w", encoding="utf-8") as sink:
        sink.write(text)

async def grade_pair(
    client: AsyncOpenAI,
    sem: asyncio.Semaphore,
    *,
    model: str,
    question: str,
    answer: str,
    candidate: str,
    system_prompt: str,
    temperature: float,
) -> dict[str, str]:
    """Ask the grader model whether *candidate* matches the ground-truth *answer*.

    Args:
        client: Client used to issue the chat-completion request.
        sem: Semaphore limiting how many requests are in flight at once.
        model: Model id for the request.
        question: The question both answers respond to.
        answer: The ground-truth answer.
        candidate: The answer being graded.
        system_prompt: Grading instructions sent as the system message.
        temperature: Accepted for interface compatibility but not sent with
            the request.

    Returns:
        The JSON object decoded from the model's reply; the system prompt asks
        for the shape {'label': 'correct'|'incorrect', 'rationale': str}.

    Raises:
        ValueError: If the reply has no content, is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), or is not a JSON
            object.
    """
    user_prompt = (
        f"QUESTION:\n{question}\n\n"
        f"ANSWER (ground truth):\n{answer}\n\n"
        f"CANDIDATE:\n{candidate}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # The semaphore is held only around the request itself, so parsing the
    # reply never blocks another pending call.
    async with sem:
        # NOTE(review): `temperature` is intentionally not forwarded --
        # presumably because the default reasoning model rejects a non-default
        # sampling temperature; confirm before passing it through.
        resp = await client.chat.completions.create(
            model=model,
            messages=messages,
            response_format={"type": "json_object"},
        )

    content = resp.choices[0].message.content
    if not content:
        # Without this check json.loads would fail with an opaque
        # TypeError (None) or "Expecting value" (empty string).
        raise ValueError("grader returned an empty response (no message content)")

    verdict = json.loads(content)
    if not isinstance(verdict, dict):
        raise ValueError(
            f"grader returned {type(verdict).__name__} instead of a JSON object"
        )
    return verdict

async def run(cfg: Dict[str, Any]) -> None:
    """Grade every candidate answer against its ground truth and save the result.

    Reads the candidate records from the "results" list of the JSON file at
    ``cfg["input_path"]`` (each record needs "question" and "answer" keys) and
    the ground-truth answers from the "answer" field of each line of the JSONL
    file at ``cfg["ans_path"]``. Records and ground truths are paired by
    position. Each record is updated in place with the grader's verdict plus
    an ``is_correct`` flag, the list is written to ``cfg["output_path"]``, and
    a one-line accuracy summary is printed.

    Args:
        cfg: Settings with the keys "model", "input_path", "ans_path",
            "output_path", "max_concurrent", "temperature" and "system_prompt".

    Raises:
        ValueError: If "ans_path" is not configured, or the number of
            ground-truth answers differs from the number of records.
    """
    ans_path = cfg.get("ans_path")
    if ans_path is None:
        raise ValueError(
            "ans_path is not configured; pass --ans_path <ground-truth .jsonl>"
        )

    records: List[dict] = load_json(cfg["input_path"])["results"]
    gt_answers = [row["answer"] for row in load_jsonl(str(ans_path))]

    # Pairing is positional, so a length mismatch would silently grade against
    # the wrong (or no) ground truth while still counting in the denominator.
    if len(records) != len(gt_answers):
        raise ValueError(
            f"{len(records)} records in {cfg['input_path']} but "
            f"{len(gt_answers)} ground-truth answers in {ans_path}"
        )

    client = AsyncOpenAI()
    sem = asyncio.Semaphore(cfg["max_concurrent"])
    try:
        tasks = [
            grade_pair(
                client,
                sem,
                model=cfg["model"],
                question=rec["question"],
                answer=gt_answer,
                # The record's own "answer" field is the candidate under test.
                candidate=rec["answer"],
                system_prompt=cfg["system_prompt"],
                temperature=cfg["temperature"],
            )
            for rec, gt_answer in zip(records, gt_answers)
        ]
        # return_exceptions=True keeps one failed request from discarding the
        # verdicts of all the others.
        verdicts = await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        await client.close()

    correct = 0
    for rec, verdict in zip(records, verdicts):
        if isinstance(verdict, dict):
            # str() guards against a non-string label such as null or true.
            label = str(verdict.get("label", "")).strip().lower()
            rec.update(verdict)
            rec["is_correct"] = label == "correct"
            correct += rec["is_correct"]
            continue

        if isinstance(verdict, BaseException):
            # Some exceptions stringify to "", so fall back to the class name.
            reason = str(verdict) or type(verdict).__name__
        else:
            reason = f"unexpected grader output: {verdict!r}"
        rec.update({"label": "error", "rationale": reason, "is_correct": False})

    acc = correct / len(records) if records else 0.0
    dump_json(records, cfg["output_path"])
    print(
        f"✓ graded {len(records)} items — ACC: {acc:.2%} "
        f"({correct} / {len(records)}) ➜ results → {cfg['output_path']}"
    )

def parse_cli() -> Dict[str, Any]:
    """Parse command-line options into a dict of configuration overrides.

    Returns:
        A mapping from option name ("model", "input_path", "ans_path",
        "output_path", "max_concurrent", "temperature") to its parsed value,
        containing only the options that were actually supplied, so that
        omitted options do not mask existing defaults when merged.
    """
    p = argparse.ArgumentParser(description="OpenAI-based accuracy evaluator")
    p.add_argument("--model", help="OpenAI model id")
    p.add_argument("--input_path", type=Path, help="Path to evaluation JSON")
    p.add_argument(
        "--ans_path",
        type=Path,
        help="Path to ground truth answers JSONL (one object with an 'answer' field per line)",
    )
    p.add_argument("--output_path", type=Path, help="Where to write scored JSON")
    # argparse stores this under "max_concurrent" (hyphen becomes underscore).
    p.add_argument("--max-concurrent", type=int, help="Maximum parallel API calls")
    p.add_argument("--temperature", type=float, help="Sampling temperature")
    # Unsupplied options parse as None; drop them so only real overrides remain.
    return {k: v for k, v in vars(p.parse_args()).items() if v is not None}

if __name__ == "__main__":
    # Start from the built-in defaults and let command-line options win.
    overrides = parse_cli()
    settings = dict(CONFIG)
    settings.update(overrides)
    asyncio.run(run(settings))
