"""Stage 2 — DeepSeek-Reasoner judges each sampled span.

Reads sampled_spans.jsonl, calls the judge model (default deepseek-reasoner)
--n-samples times per span, and appends results to llm_judgments.jsonl one
row at a time. Resumable on `span_id`.

Concurrency: ThreadPoolExecutor with --workers (default 100). The
upstream `with_retry` decorator on DeepSeekClient.generate handles
transient failures with exponential backoff (6 attempts, 30->300s).
Retry exhaustion makes generate return (None, None, None). If every sample
for a span fails, we record an "API_FAILURE" (or "PARSE_ERROR") row so the
failure stays visible and isn't retried on resume.
"""
from __future__ import annotations

import argparse
import json
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

from analysis.exploration.llm_validation._client import (
    DeepSeekClient,
    PRIMITIVES,
    parse_json_response,
)


PROMPT_TEMPLATE = """You are an expert annotator classifying a single span of a model's reasoning trace into one reasoning primitive based on the FUNCTION the span plays in the reasoning. Do not classify by surface phrasing; classify by what the span is trying to do. The trace may be on a logic puzzle or a math problem; the same taxonomy applies to both.

# Taxonomy (choose exactly one)
- PLAN: lays out a high-level strategy for SOLVING the problem — choosing which approach to take, what order to do things in, what the solver intends to do next. Forward-looking intent. Not the same as parsing or interpreting the problem statement.
- SETUP: problem comprehension and working-representation building — decoding a puzzle input grid, interpreting clue notation, restating the problem's constraints in the model's own working terms, defining variables / unknowns for a math problem, cross-referencing the example to clarify format. The span is building WHAT the problem is, before strategy or solving begins. Distinct from OTHER (truly off-topic content) and PLAN (forward strategy on HOW to solve).
- ENUMERATE: iterates over a finite collection of cases, possibilities, or configurations with TWO OR MORE alternatives in scope, considered as a list. Examples: "Case 1...; Case 2...", "For p=2: ...; for p=3: ...; for p=5: ...", "the row could be (a,b,c) or (b,a,c) or (c,a,b)".
- HYPOTHESIZE: posits ONE concrete tentative assumption or trial value (not a list of options) and explores its consequences. If the span considers multiple alternatives, it's ENUMERATE.
- COMPUTE: performs concrete arithmetic, algebraic manipulation, LaTeX manipulation, substitution, simplification, or rule-based deduction. Produces a new value or expression as the output of a mechanical operation.
- CHECK: tests a previously-derived candidate, assignment, or sub-claim against a constraint or target. The act is evaluation — outputting whether the candidate is consistent or inconsistent. The OUTCOME (pass/fail) is reported separately; both successful confirmations and detected contradictions are CHECK.
- BACKTRACK: abandons a current line of reasoning and pivots to a different approach or earlier state. The act of switching paths (often follows a failed CHECK).
- SUMMARIZE: states a final answer or top-line conclusion that closes a chain of reasoning. The dominant act is *announcing* the result, not deriving it. A span that mostly states "Therefore the answer is 42" or `\\boxed{{42}}` after the work is done is SUMMARIZE.
- OTHER: content that is NOT problem reasoning — output formatting, JSON escaping, restating problem instructions verbatim with no working interpretation, transitional filler. If the span builds any working representation or performs any reasoning act, prefer SETUP, PLAN, or one of the action labels.

# Tie-breakers (when a span contains multiple acts)
1. **SETUP vs PLAN:** if the span is decoding/interpreting the problem (e.g. "let's parse the grid: row 0 = [...]", "the clue '1 2' means...", "let me define a = the number of red marbles, b = blue marbles"), label SETUP — building WHAT the problem is. PLAN is forward strategy on HOW to solve ("I'll start by...", "my approach is...").
2. **SETUP vs OTHER:** if the span is doing real interpretive work on the problem's content, label SETUP. OTHER is for output formatting or pure instruction restatement that does not establish a working representation.
3. **HYPOTHESIZE vs ENUMERATE:** count alternatives in scope. One trial value = HYPOTHESIZE; ≥2 listed = ENUMERATE.
4. **HYPOTHESIZE vs COMPUTE:** a "Suppose X" or "By pigeonhole, let r be ..." that is a FORCED rule-based consequence is COMPUTE — it's a deduction, not a free assumption. HYPOTHESIZE is a free choice the model makes for exploration ("Let me try x = 5...").
5. **CHECK vs COMPUTE:** if the span produces a new value via mechanical work, it's COMPUTE. If it tests an existing value against a constraint, it's CHECK (regardless of outcome).
6. **CHECK vs ENUMERATE on per-case work:** when the dominant pattern is iterating over labeled cases ("For p=2: ...; For p=3: ...; For p=5: ..."), label ENUMERATE — the listing is the point. Only label CHECK if the span tests ONE specific candidate.
7. **BACKTRACK vs CHECK:** a span that *finds* a contradiction is CHECK with outcome=fail. A span that *pivots* to a new approach is BACKTRACK. If pivoting is the dominant act, BACKTRACK.
8. **COMPUTE / CHECK vs SUMMARIZE at the end:** if the span derives the final number via real arithmetic and ends with `\\boxed{{N}}` or "the answer is N", the dominant act is the COMPUTE (or CHECK). SUMMARIZE is for spans that mostly *announce* the final answer with little or no fresh derivation, often after the work was done in earlier spans.

# Outcome field (for CHECK only)
For CHECK, also report the outcome of the test:
- pass     — the candidate satisfied the constraint
- fail     — the candidate failed (contradiction, inconsistency, mismatch)
- unclear  — the test was inconclusive or aborted
For all other labels, set outcome to null.

Use the preceding context only as background; classify the SPAN.



# Examples (study the FUNCTION, not surface phrasing — your span will differ in domain and wording)

## Example 1 — PLAN
Preceding context: <<<Then we count the monster if: for V, only if direct; for G, only if reflected; for Z, always.>>>
Span: <<<But note: The path might go through the same square multiple times from different directions, and each time it might be direct or reflected depending on the segment. So we need to consider each crossing separately. This is a bit complex. We might need to set up equations or constraints. Alternatively, we can try to solve by logical deduction.>>>
Reasoning: states a forward-looking strategy ('we need to', 'we can try'); no derivation yet.
Output: {{"label": "PLAN", "outcome": "null", "confidence": "high"}}

## Example 2 — SETUP
Preceding context: <<<three numbers formed by reading top to bottom is 99, we start by defining the digits in the grid as follows: a,b,c / d,e,f>>>
Span: <<<We need to satisfy the following conditions:
1. The sum of the numbers formed by reading left to right: 100a + 10b + c + 100d + 10e + f = 999
2. The sum of the numbers formed by reading top to bottom: (100a + 10b + c) + (100d + 10e + f) = 99>>>
Reasoning: translates the problem statement into formal constraints in the model's working notation; no solving yet.
Output: {{"label": "SETUP", "outcome": "null", "confidence": "high"}}

## Example 3 — ENUMERATE
Preceding context: <<<Column 4: two separate filled cells, not adjacent.>>>
Span: <<<First, note that row 4 has two separate filled cells. That means in row 4, exactly two cells are filled, and they are not adjacent. So the pattern in row 4 is like: F E F E, or F E E F, or E F E F. But note: if the filled cells are at columns 1 and 3, that's F E F E. If at 1 and 4, that's F E E F. If at 2 and 4, that's E F E F.>>>
Reasoning: lists three concrete alternative configurations for row 4; multiple alternatives in scope.
Output: {{"label": "ENUMERATE", "outcome": "null", "confidence": "high"}}

## Example 4 — HYPOTHESIZE
Preceding context: <<<So clue should be 4. So that doesn't work.>>>
Span: <<<Maybe the rule is that Vampires are only visible if the line of sight to them does not involve any reflection at all, meaning that the beam has not hit any mirror before reaching them. So if there is a mirror before them in the path, they are not visible.>>>
Reasoning: posits one tentative rule ('Maybe the rule is...') as a free assumption to explore; no list of alternatives.
Output: {{"label": "HYPOTHESIZE", "outcome": "null", "confidence": "high"}}

## Example 5 — COMPUTE
Preceding context: <<<the losing positions are those where n ≡ 2 mod 4.>>>
Span: <<<The sequence of losing positions is 2, 6, 10, 14, ..., which can be written as 4k + 2 for k ≥ 0. The largest k such that 4k + 2 ≤ 2024 is found by solving 4k + 2 ≤ 2024: 4k ≤ 2022 ⟹ k ≤ 2022/4 = 505.5. The largest integer k is 505, so the number of terms is 505 - 0 + 1 = 506.>>>
Reasoning: performs concrete arithmetic to derive a new value (506) via mechanical operations.
Output: {{"label": "COMPUTE", "outcome": "null", "confidence": "high"}}

## Example 6 — CHECK
Preceding context: <<<Actually, (4,1) and (4,3) are not adjacent because col2 is between. So (4,3) is not connected to the main group.>>>
Span: <<<So the pattern is split into two components: one containing (2,1),(3,1),(4,1),(2,2),(1,2),(2,4),(1,4) and the other containing just (4,3). That violates the global connectivity condition.>>>
Reasoning: tests the candidate placement against the connectivity constraint and reports it fails.
Output: {{"label": "CHECK", "outcome": "fail", "confidence": "high"}}

## Example 7 — BACKTRACK
Preceding context: <<<Contradiction.>>>
Span: <<<So if we try to not fix b=1 and c=1 from the island constraints, but wait, E has value 1, so b must be 1. F has value 1, so c must be 1. So they are forced. Thus, the only way to resolve is if d can be 3, but it can't. So maybe the puzzle has a different set of connections. Perhaps there is an additional connection that provides an extra bridge to I.>>>
Reasoning: after a failed CHECK, abandons the current line and pivots to a new hypothesis ('maybe the puzzle has a different set of connections').
Output: {{"label": "BACKTRACK", "outcome": "null", "confidence": "high"}}

## Example 8 — SUMMARIZE
Preceding context: <<<This results in intersections at points where x and y are multiples of 1/4 within the range [0, 1].>>>
Span: <<<6. **Counting Intersections**: The intersections occur at (0,0), (1/4,1/4), (1/2,1/2), and (1,1). Thus, the number of intersections is \\boxed{{4}}.>>>
Reasoning: announces the final answer (4) after the work was already done; the boxed result is the dominant act.
Output: {{"label": "SUMMARIZE", "outcome": "null", "confidence": "high"}}

## Example 9 — OTHER
Preceding context: <<<The user has provided a puzzle description.>>>
Span: <<<Let me format my final answer according to the instructions, putting the solution within <answer> tags as requested.>>>
Reasoning: output formatting / restating instructions; no problem reasoning, no working representation built.
Output: {{"label": "OTHER", "outcome": "null", "confidence": "high"}}

# Preceding context
<<<{preceding_context}>>>

# Span
<<<{span_text}>>>

# Output
Return only a JSON object on a single line:
{{"label": "<one of PLAN|SETUP|ENUMERATE|HYPOTHESIZE|COMPUTE|CHECK|BACKTRACK|SUMMARIZE|OTHER>", "outcome": "<pass|fail|unclear|null>", "confidence": "<high|medium|low>", "reasoning": "<one sentence>"}}
"""


# DeepSeek approximate pricing (USD per 1M tokens; cache-miss rates).
# Verify at https://api-docs.deepseek.com/quick_start/pricing
_PRICING = {
    "deepseek-reasoner": (0.55, 2.19),
    "deepseek-chat":     (0.27, 1.10),
}


def get_pricing(model: str) -> tuple[float, float]:
    """Return (input, output) USD per 1M tokens; unknown models fall back to reasoner rates."""
    return _PRICING.get(model, _PRICING["deepseek-reasoner"])


def build_prompt(span: dict) -> str:
    return PROMPT_TEMPLATE.format(
        # `or ""` also covers rows where preceding_context is present but null.
        preceding_context=span.get("preceding_context") or "",
        span_text=span["span_text"],
    )


def estimate_cost(n_remaining: int, model: str) -> tuple[int, int, float]:
    """Rough cost estimate per spec. Returns (input_tokens, output_tokens, usd)."""
    in_per = 1400
    # deepseek-chat has no CoT, so output is just the small JSON
    out_per = 350 if model == "deepseek-reasoner" else 80
    in_rate, out_rate = get_pricing(model)
    in_tokens = in_per * n_remaining
    out_tokens = out_per * n_remaining
    cost = in_tokens * in_rate / 1e6 + out_tokens * out_rate / 1e6
    return in_tokens, out_tokens, cost
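
# Worked example of estimate_cost (illustrative only; the _PRICING rates above
# may be stale). For 1,000 remaining spans on deepseek-reasoner:
#   input : 1,000 * 1400 = 1,400,000 tokens * $0.55/M = $0.77
#   output: 1,000 *  350 =   350,000 tokens * $2.19/M = $0.7665
#   total ~ $1.54; main() then multiplies all three figures by --n-samples.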


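def load_done_ids(out_path: Path) -> set[str]:
    """Return span_ids already present in out_path (any label, including failures)."""
    if not out_path.exists():
        return set()
    done: set[str] = set()
    with open(out_path) as f:
        for line in f:
            if not line.strip():
                continue
            try:
                done.add(json.loads(line)["span_id"])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Skip malformed rows, e.g. a line truncated by a crash mid-write.
                pass
    return done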


def _single_call(
    client: DeepSeekClient, model: str, prompt: str,
    max_tokens: int, temperature: float | None,
) -> tuple[str | None, str | None, str | None, int | None]:
    """Make one LLM call, return (label, outcome, content, tokens). label=None means parse error or API failure."""
    kwargs = {"max_tokens": max_tokens}
    if temperature is not None and model != "deepseek-reasoner":
        kwargs["temperature"] = temperature
    content, total_tokens, _ = client.generate(model=model, prompt=prompt, **kwargs)
    if content is None:
        return None, None, None, None
    parsed = parse_json_response(content) or {}
    label = parsed.get("label")
    if label not in PRIMITIVES:
        return None, None, content, total_tokens
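    outcome = parsed.get("outcome")
    if isinstance(outcome, str):
        # Normalize case/whitespace so "Pass" and "pass" vote together; the
        # prompt's examples emit the string "null", so treat it like JSON null.
        outcome = outcome.strip().lower()
        if outcome in {"null", "none", ""}:
            outcome = None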
    return label, outcome, content, total_tokens


def _majority(labels: list[str]) -> tuple[str, int]:
    """Return (winner, vote_count). Tie-break: lexicographic on label."""
    from collections import Counter
    if not labels:
        return "PARSE_ERROR", 0
    counts = Counter(labels).most_common()
    top = counts[0][1]
    winners = sorted(l for l, c in counts if c == top)
    return winners[0], top
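
# Worked example of the tie-break: _majority(["CHECK", "COMPUTE", "CHECK",
# "COMPUTE", "PLAN"]) sees CHECK and COMPUTE tied at 2 votes and returns
# ("CHECK", 2), because "CHECK" sorts first. Ties therefore lean toward labels
# that come earlier in the alphabet.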


def make_judge_row(
    client: DeepSeekClient,
    model: str,
    span: dict,
    max_tokens: int,
    temperature: float | None = None,
    n_samples: int = 1,
) -> dict:
    """Make N judge calls per span. With N=1, returns one label.
    With N>1, returns the majority-vote label plus all per-sample labels.
    """
    prompt = build_prompt(span)
    t0 = time.time()
    labels: list[str] = []
    outcomes: list[str | None] = []
    contents: list[str | None] = []
    tokens_used = 0
    n_parse = 0
    n_api = 0
    for _ in range(n_samples):
        lab, outc, content, tok = _single_call(
            client, model, prompt, max_tokens, temperature,
        )
        if content is None:
            n_api += 1
            continue
        # Count tokens for every call that returned content, including
        # unparseable ones: those tokens were still billed.
        if tok:
            tokens_used += tok
        if lab is None:
            n_parse += 1
            contents.append(content)
            continue
        labels.append(lab)
        outcomes.append(outc)
        contents.append(content)
    dt = time.time() - t0

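    # _single_call only keeps the total token count, so apportion it 80/20
    # input/output, matching estimate_cost's 1400/350 per-span assumption.
    in_rate, out_rate = get_pricing(model)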
    cost = (
        tokens_used * 0.8 * in_rate / 1e6
        + tokens_used * 0.2 * out_rate / 1e6
    )
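    # Worked example (deepseek-reasoner rates from _PRICING): 2,000 tokens_used
    # -> 1,600 * $0.55/M + 400 * $2.19/M = $0.00088 + $0.000876 ~ $0.0018.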

    if not labels:
        # All N calls failed (API or parse)
        return {
            "span_id": span["span_id"],
            "llm_label": "API_FAILURE" if n_api > n_parse else "PARSE_ERROR",
            "llm_outcome": None,
            "llm_labels": [],
            "llm_outcomes": [],
            "n_samples": n_samples,
            "n_parse_err": n_parse,
            "n_api_err": n_api,
            "raw_content": contents[0] if contents else None,
            "tokens_used": tokens_used or None,
            "cost_estimate_usd": cost,
            "elapsed_s": dt,
        }

    majority_label, vote_count = _majority(labels)
    # Outcome majority computed only on samples that voted for the majority label
    aligned_outcomes = [o for l, o in zip(labels, outcomes) if l == majority_label]
    outcome_majority, _ = _majority([o or "null" for o in aligned_outcomes])
    if outcome_majority == "null":
        outcome_majority = None
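    # Worked example of the outcome vote: labels ["CHECK", "CHECK", "COMPUTE"]
    # with outcomes ["fail", None, None] gives majority CHECK; its aligned
    # outcomes ["fail", "null"] tie 1-1 and "fail" sorts first, so
    # llm_outcome = "fail". The tie-break is not symmetric: ["pass", None]
    # resolves to "null" (it sorts before "pass"), i.e. llm_outcome = None.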
    return {
        "span_id": span["span_id"],
        "llm_label": majority_label,
        "llm_outcome": outcome_majority,
        "llm_labels": labels,
        "llm_outcomes": outcomes,
        "n_samples": n_samples,
        "n_parse_err": n_parse,
        "n_api_err": n_api,
        "vote_count": vote_count,
        "raw_content": contents[0] if contents else None,
        "tokens_used": tokens_used,
        "cost_estimate_usd": cost,
        "elapsed_s": dt,
    }


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="in_path", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--model", default="deepseek-reasoner")
    ap.add_argument("--workers", type=int, default=100)
    ap.add_argument("--max-tokens", type=int, default=4096)
    ap.add_argument(
        "--temperature", type=float, default=0.0,
        help="Sampling temperature (deepseek-chat only; ignored by reasoner)",
    )
    ap.add_argument("--limit", type=int, default=None,
                    help="Process at most N remaining spans (after resume)")
    ap.add_argument("--n-samples", type=int, default=1,
                    help="Self-consistency: samples per span, majority-voted. "
                         ">1 needs --temperature > 0 for deepseek-chat")
    ap.add_argument("--dry-run", action="store_true",
                    help="Print cost estimate and exit without API calls")
    args = ap.parse_args()

    with open(args.in_path) as f:
        spans = [json.loads(line) for line in f if line.strip()]

    args.out.parent.mkdir(parents=True, exist_ok=True)
    done = load_done_ids(args.out)
    remaining = [s for s in spans if s["span_id"] not in done]
    if args.limit is not None:
        remaining = remaining[: args.limit]

    print(f"Total spans: {len(spans)}")
    print(f"Already done: {len(done)}")
    print(f"To process:   {len(remaining)}")

    in_tok, out_tok, cost = estimate_cost(len(remaining), args.model)
    in_tok *= args.n_samples
    out_tok *= args.n_samples
    cost *= args.n_samples
    in_rate, out_rate = get_pricing(args.model)
    print(f"Estimated input tokens : {in_tok:,}  (n_samples={args.n_samples})")
    print(f"Estimated output tokens: {out_tok:,}")
    print(f"Estimated cost (USD)   : ${cost:.2f}")
    print(
        f"  (assumes ${in_rate}/M input, ${out_rate}/M output for {args.model} "
        "-- verify at https://api-docs.deepseek.com/quick_start/pricing)"
    )
    if args.n_samples > 1 and args.temperature == 0.0 and args.model != "deepseek-reasoner":
        print(
            "WARNING: --n-samples > 1 with temperature=0 produces identical "
            "samples. Pass --temperature 0.7 (or similar) for self-consistency."
        )

    if args.dry_run:
        print("\n--dry-run: exiting without API calls.")
        return

    if not remaining:
        print("Nothing to do.")
        return

    client = DeepSeekClient()
    write_lock = threading.Lock()
    n_done = 0
    n_parse_err = 0
    n_api_err = 0

    def process_one(span: dict) -> dict:
        nonlocal n_done, n_parse_err, n_api_err
        row = make_judge_row(
            client, args.model, span, args.max_tokens,
            temperature=args.temperature,
            n_samples=args.n_samples,
        )
        with write_lock:
            with open(args.out, "a") as fh:
                fh.write(json.dumps(row) + "\n")
            n_done += 1
            if row["llm_label"] == "PARSE_ERROR":
                n_parse_err += 1
            elif row["llm_label"] == "API_FAILURE":
                n_api_err += 1
            if n_done % 10 == 0 or n_done == len(remaining):
                print(
                    f"  [{n_done}/{len(remaining)}] "
                    f"parse_err={n_parse_err} api_err={n_api_err}",
                    flush=True,
                )
        return row

    print(f"\nDispatching with {args.workers} workers...")
    t0 = time.time()
    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        futures = [ex.submit(process_one, s) for s in remaining]
        for fut in as_completed(futures):
            fut.result()  # propagate exceptions if any leak through with_retry
    dt = time.time() - t0
    print(f"\nDone in {dt:.1f}s. Wrote {n_done} rows -> {args.out}")
    print(f"Parse errors: {n_parse_err}  API failures: {n_api_err}")


if __name__ == "__main__":
    main()
