"""Canonical CKM Benchmark judge.

Two-stage cross-provider judge for grading whether a candidate future paper
substantively validates a research hypothesis:

    Stage 1 — GPT-4o-mini pre-filter (threshold 5.0)
    Stage 2 — GPT-4o re-judgment   (threshold 6.0, hit decision)

The judge is intentionally model-agnostic at the call-site: any OpenAI-compatible
endpoint can be substituted via the OPENAI_BASE_URL environment variable. This
keeps the canonical judge reproducible while leaving room for an alternative-judge
mode (Claude/Gemini) in v0.2 for submissions that use OpenAI models for generation.

Usage (single hypothesis vs single candidate paper):
    from ckm_benchmark.judge import judge_pair

    result = judge_pair(
        hypothesis_text=hypothesis_md,
        paper_title="SAFE: Stable Alignment Finetuning ...",
        paper_published="2026-02-04",
        paper_arxiv_id="2602.04651",
        paper_content=paper_excerpt,
    )
    print(result.is_hit, result.score, result.reasoning)

Usage (full re-judge of a system submission):
    See ckm_benchmark.rejudge.
"""

from __future__ import annotations

import os
import re
import time
from dataclasses import dataclass

from ckm_benchmark.protocol import HIT_THRESHOLD, PRE_FILTER_THRESHOLD


# Model identifiers for the two judge stages. Each one can be overridden
# through its environment variable; the defaults are the canonical pair.
PREFILTER_MODEL = os.getenv("CKM_BENCHMARK_PREFILTER_MODEL", "gpt-4o-mini")
JUDGE_MODEL = os.getenv("CKM_BENCHMARK_JUDGE_MODEL", "gpt-4o")

# Maximum number of characters of candidate-paper content handed to the judge.
# Sized to cover the abstract plus introduction while keeping the cost of each
# call predictable. Overridable via the environment, see .env.example.
PAPER_EXCERPT_CHAR_BUDGET = int(
    os.getenv("CKM_BENCHMARK_PAPER_EXCERPT_BUDGET", "6000")
)

# Default sampling temperature for judge calls (overridable via the environment).
JUDGE_TEMPERATURE = float(
    os.getenv("CKM_BENCHMARK_JUDGE_TEMPERATURE", "0.0")
)


# System-role message used for both judge stages. The sentences are kept as
# separate fragments and joined with single spaces into one paragraph.
JUDGE_SYSTEM_PROMPT = " ".join(
    [
        "You are a rigorous scientific reviewer judging whether a published paper",
        "substantively validates a previously generated research hypothesis.",
        "Score strictly. Superficial topic overlap does NOT count as validation;",
        "you must find specific, concrete alignment between the hypothesis's named",
        "method delta or claim and the paper's actual contribution.",
    ]
)


# User-role prompt template, rendered with ``str.format`` in ``judge_pair``.
# Placeholders: {hypothesis}, {title}, {arxiv_id}, {published}, {paper_content}.
# The identical rendered prompt is sent to both the pre-filter and the final
# judge. The trailing ``SCORE: <n>`` line requested here is the format that
# ``_parse_score`` looks for first, so keep the two in sync when editing.
JUDGE_USER_TEMPLATE = """\
## Research Hypothesis
{hypothesis}

## Candidate Future Paper
- Title: {title}
- ArXiv ID: {arxiv_id}
- Published: {published}

Paper excerpt:
{paper_content}

---

Score the alignment of this paper to the hypothesis on a 1-10 scale, where:

  10 — paper realises essentially the same proposal named in the hypothesis
   8 — same problem and similar method, different specifics
   6 — same general research direction (HIT THRESHOLD)
   4 — related but with substantively different approach
   2 — same broad area, unrelated specifics
   1 — unrelated

Output your reasoning followed by the line:

    SCORE: <integer or one decimal>

Cite at least one concrete shared element (named method, dataset, target metric,
or architectural choice) when scoring 6 or above.
"""


@dataclass(frozen=True)
class JudgeResult:
    """Immutable outcome of judging one (hypothesis, candidate paper) pair."""

    # True only when stage 2 ran and its score met the hit threshold.
    is_hit: bool
    # Final 1-10 score: the stage-2 score when stage 2 ran, else the stage-1 score.
    score: float
    # Score assigned by the pre-filter model.
    stage1_score: float
    # Score assigned by the final judge; None when the pre-filter rejected the pair.
    stage2_score: float | None
    # Raw response text of the stage that produced ``score``.
    reasoning: str
    # Total tokens reported across the calls that were made.
    tokens_used: int
    # Wall-clock seconds spent judging the pair.
    elapsed_s: float

    def to_dict(self) -> dict:
        """Return the result as a plain dict keyed by field name."""
        field_names = (
            "is_hit",
            "score",
            "stage1_score",
            "stage2_score",
            "reasoning",
            "tokens_used",
            "elapsed_s",
        )
        return {name: getattr(self, name) for name in field_names}


def _truncate(content: str, budget: int = PAPER_EXCERPT_CHAR_BUDGET) -> str:
    """Cap *content* at *budget* characters.

    Text that already fits is returned unchanged. Longer text is cut to its
    first *budget* characters and a ``[... truncated]`` marker is appended on
    a new line, so the returned string is slightly longer than *budget*.
    """
    fits = len(content) <= budget
    return content if fits else content[:budget] + "\n[... truncated]"


def _parse_score(text: str) -> float:
    """Extract a 1-10 score from a judge response.

    Looks for ``SCORE: x`` (preferred) or any decimal number near the word "score".
    Falls back to 0.0 if no score is found, which causes the candidate to be
    treated as a clear miss.
    """
    match = re.search(r"SCORE\s*:\s*([\d.]+)", text, flags=re.IGNORECASE)
    if match:
        try:
            return float(match.group(1))
        except ValueError:
            pass
    # Fallback: any decimal near "score"
    match = re.search(r"score[^0-9]{0,8}([\d.]+)", text, flags=re.IGNORECASE)
    if match:
        try:
            return float(match.group(1))
        except ValueError:
            pass
    return 0.0


def _call_openai_chat(model: str, system: str, user: str, *, temperature: float = JUDGE_TEMPERATURE) -> tuple[str, int]:
    """Single OpenAI chat completion. Returns (content, tokens_used).

    Uses the OpenAI SDK; the user must have OPENAI_API_KEY set. The base URL
    is taken from OPENAI_BASE_URL when set, allowing OpenAI-compatible endpoints.

    Args:
        model: Model identifier passed through to the endpoint.
        system: Content of the system-role message.
        user: Content of the user-role message.
        temperature: Sampling temperature for the completion.

    Returns:
        A ``(content, tokens_used)`` tuple. ``content`` is the first choice's
        message text, or ``""`` when the SDK reports no content.
        ``tokens_used`` is the reported total token count, or 0 when the
        response carries no usage information.

    Raises:
        RuntimeError: If the ``openai`` package is not installed.
    """
    # Imported lazily so the module can be imported without the SDK present.
    try:
        from openai import OpenAI
    except ImportError as e:
        # The requirement specifier is quoted: unquoted, a shell would treat
        # ``>`` as an output redirection and create a file named ``=1.0``.
        raise RuntimeError(
            'openai package not installed. Run `pip install "openai>=1.0"`.'
        ) from e

    client = OpenAI()  # picks up OPENAI_API_KEY (and OPENAI_BASE_URL) automatically
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "system", "content": system}, {"role": "user", "content": user}],
        temperature=temperature,
    )
    # ``content`` may be None; normalise to an empty string for the parser.
    content = response.choices[0].message.content or ""
    tokens = response.usage.total_tokens if response.usage else 0
    return content, tokens


def judge_pair(
    *,
    hypothesis_text: str,
    paper_title: str,
    paper_published: str,
    paper_arxiv_id: str,
    paper_content: str,
    prefilter_threshold: float = PRE_FILTER_THRESHOLD,
    hit_threshold: float = HIT_THRESHOLD,
) -> JudgeResult:
    """Grade one (hypothesis, candidate paper) pair with the two-stage judge.

    The same rendered prompt is first scored by the pre-filter model. Pairs
    scoring below ``prefilter_threshold`` are returned immediately as misses;
    all others are re-scored by the final judge model, whose score alone
    decides the hit against ``hit_threshold``.

    Args:
        hypothesis_text: Hypothesis text; surrounding whitespace is stripped.
        paper_title: Candidate paper title; surrounding whitespace is stripped.
        paper_published: Publication date string; stripped, not otherwise parsed.
        paper_arxiv_id: ArXiv identifier string; stripped, not otherwise parsed.
        paper_content: Paper excerpt; truncated to the excerpt budget.
        prefilter_threshold: Minimum stage-1 score required to run stage 2.
        hit_threshold: Minimum stage-2 score for the pair to count as a hit.

    Returns:
        A JudgeResult carrying the final score, both stage scores, the
        reasoning text of the deciding stage, total tokens, and elapsed seconds.
    """
    t0 = time.perf_counter()

    prompt = JUDGE_USER_TEMPLATE.format(
        hypothesis=hypothesis_text.strip(),
        title=paper_title.strip(),
        arxiv_id=paper_arxiv_id.strip(),
        published=paper_published.strip(),
        paper_content=_truncate(paper_content),
    )

    # --- Stage 1: cheap pre-filter ---------------------------------------
    prefilter_reply, prefilter_tokens = _call_openai_chat(
        PREFILTER_MODEL, JUDGE_SYSTEM_PROMPT, prompt
    )
    prefilter_score = _parse_score(prefilter_reply)

    rejected_by_prefilter = prefilter_score < prefilter_threshold
    if rejected_by_prefilter:
        # Stage 2 is skipped, so the stage-1 score doubles as the final score.
        return JudgeResult(
            is_hit=False,
            score=prefilter_score,
            stage1_score=prefilter_score,
            stage2_score=None,
            reasoning=prefilter_reply,
            tokens_used=prefilter_tokens,
            elapsed_s=time.perf_counter() - t0,
        )

    # --- Stage 2: stronger judge on borderline/high candidates ------------
    final_reply, final_tokens = _call_openai_chat(
        JUDGE_MODEL, JUDGE_SYSTEM_PROMPT, prompt
    )
    final_score = _parse_score(final_reply)

    return JudgeResult(
        is_hit=final_score >= hit_threshold,
        score=final_score,
        stage1_score=prefilter_score,
        stage2_score=final_score,
        reasoning=final_reply,
        tokens_used=prefilter_tokens + final_tokens,
        elapsed_s=time.perf_counter() - t0,
    )
