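"""Stage 2: DeepSeek judges each sampled span.

Reads sampled_spans.jsonl, calls the judge model (deepseek-reasoner by
default) --n-samples times per span (default 1), and appends one row per
span to llm_judgments.jsonl as soon as that span finishes. Resumable: spans
whose `span_id` already appears in the output file are skipped.

Concurrency: ThreadPoolExecutor with --workers (default 100); rows are
appended under a lock. The upstream `with_retry` decorator on
DeepSeekClient.generate handles transient failures with exponential backoff
(6 attempts, 30->300s). Retry exhaustion makes generate return
(None, None, None). If no sample for a span yields a valid label, we record
an "API_FAILURE" row (when API failures dominate) or a "PARSE_ERROR" row, so
the failure stays visible and is not retried on resume.
"""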
from __future__ import annotations

import argparse
import json
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

from analysis.exploration.llm_validation._client import (
    DeepSeekClient,
    PRIMITIVES,
    parse_json_response,
)
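

# Illustrative sketch (hypothetical helper, not used by the pipeline): the
# retry schedule the module docstring describes. The real `with_retry` lives
# in _client and its exact formula is not shown here; this ASSUMES the delay
# doubles from 30s and is capped at 300s, giving one wait before each of the
# 5 retries that follow the first of 6 attempts. Under that assumption a
# single call can spend about 750s (12.5 min) in backoff before it gives up.
def _sketch_backoff_delays(attempts=6, base_s=30.0, cap_s=300.0):
    """Hypothetical doubling-with-cap schedule: one delay (seconds) per retry."""
    return [min(cap_s, base_s * 2 ** i) for i in range(attempts - 1)]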


PROMPT_TEMPLATE = """You are an expert annotator classifying a single span of a model's reasoning trace into one reasoning primitive based on the FUNCTION the span plays in the reasoning. Do not classify by surface phrasing; classify by what the span is trying to do. The trace may be on a logic puzzle or a math problem; the same taxonomy applies to both.

# Taxonomy (choose exactly one)
- PLAN: lays out a high-level strategy for SOLVING the problem — choosing which approach to take, what order to do things in, what the solver intends to do next. Forward-looking intent. Not the same as parsing or interpreting the problem statement.
- SETUP: problem comprehension and working-representation building — decoding a puzzle input grid, interpreting clue notation, restating the problem's constraints in the model's own working terms, defining variables / unknowns for a math problem, cross-referencing the example to clarify format. The span is building WHAT the problem is, before strategy or solving begins. Distinct from OTHER (truly off-topic content) and PLAN (forward strategy on HOW to solve).
- ENUMERATE: iterates over a finite collection of cases, possibilities, or configurations with TWO OR MORE alternatives in scope, considered as a list. Examples: "Case 1...; Case 2...", "For p=2: ...; for p=3: ...; for p=5: ...", "the row could be (a,b,c) or (b,a,c) or (c,a,b)".
- HYPOTHESIZE: posits ONE concrete tentative assumption or trial value (not a list of options) and explores its consequences. If the span considers multiple alternatives, it's ENUMERATE.
- COMPUTE: performs concrete arithmetic, algebraic manipulation, LaTeX manipulation, substitution, simplification, or rule-based deduction. Produces a new value or expression as the output of a mechanical operation.
- CHECK: tests a previously-derived candidate, assignment, or sub-claim against a constraint or target. The act is evaluation — outputting whether the candidate is consistent or inconsistent. The OUTCOME (pass/fail) is reported separately; both successful confirmations and detected contradictions are CHECK.
- BACKTRACK: abandons a current line of reasoning and pivots to a different approach or earlier state. The act of switching paths (often follows a failed CHECK).
- SUMMARIZE: states a final answer or top-line conclusion that closes a chain of reasoning. The dominant act is *announcing* the result, not deriving it. A span that mostly states "Therefore the answer is 42" or `\\boxed{{42}}` after the work is done is SUMMARIZE.
- OTHER: content that is NOT problem reasoning — output formatting, JSON escaping, restating problem instructions verbatim with no working interpretation, transitional filler. If the span builds any working representation or performs any reasoning act, prefer SETUP, PLAN, or one of the action labels.

# Tie-breakers (when a span contains multiple acts)
1. **SETUP vs PLAN:** if the span is decoding/interpreting the problem (e.g. "let's parse the grid: row 0 = [...]", "the clue '1 2' means...", "let me define a = the number of red marbles, b = blue marbles"), label SETUP — building WHAT the problem is. PLAN is forward strategy on HOW to solve ("I'll start by...", "my approach is...").
2. **SETUP vs OTHER:** if the span is doing real interpretive work on the problem's content, label SETUP. OTHER is for output formatting or pure instruction restatement that does not establish a working representation.
3. **HYPOTHESIZE vs ENUMERATE:** count alternatives in scope. One trial value = HYPOTHESIZE; ≥2 listed = ENUMERATE.
4. **HYPOTHESIZE vs COMPUTE:** a "Suppose X" or "By pigeonhole, let r be ..." that is a FORCED rule-based consequence is COMPUTE — it's a deduction, not a free assumption. HYPOTHESIZE is a free choice the model makes for exploration ("Let me try x = 5...").
5. **CHECK vs COMPUTE:** if the span produces a new value via mechanical work, it's COMPUTE. If it tests an existing value against a constraint, it's CHECK (regardless of outcome).
6. **CHECK vs ENUMERATE on per-case work:** when the dominant pattern is iterating over labeled cases ("For p=2: ...; For p=3: ...; For p=5: ..."), label ENUMERATE — the listing is the point. Only label CHECK if the span tests ONE specific candidate.
7. **BACKTRACK vs CHECK:** a span that *finds* a contradiction is CHECK with outcome=fail. A span that *pivots* to a new approach is BACKTRACK. If pivoting is the dominant act, BACKTRACK.
8. **COMPUTE / CHECK vs SUMMARIZE at the end:** if the span derives the final number via real arithmetic and ends with `\\boxed{{N}}` or "the answer is N", the dominant act is the COMPUTE (or CHECK). SUMMARIZE is for spans that mostly *announce* the final answer with little or no fresh derivation, often after the work was done in earlier spans.

# Outcome field (for CHECK only)
For CHECK, also report the outcome of the test:
- pass     — the candidate satisfied the constraint
- fail     — the candidate failed (contradiction, inconsistency, mismatch)
- unclear  — the test was inconclusive or aborted
For all other labels, set outcome to null.

Use the preceding context only as background; classify the SPAN.

# Preceding context
<<<{preceding_context}>>>

# Span
<<<{span_text}>>>

# Output
Return only a JSON object on a single line:
{{"label": "<one of PLAN|SETUP|ENUMERATE|HYPOTHESIZE|COMPUTE|CHECK|BACKTRACK|SUMMARIZE|OTHER>", "outcome": "<pass|fail|unclear|null>", "confidence": "<high|medium|low>", "reasoning": "<one sentence>"}}
"""
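

# Illustrative sketch (hypothetical helper, not used by the pipeline): a reply
# in the one-line JSON format PROMPT_TEMPLATE requests, checked the way
# _single_call treats it: the label must be one of the nine labels listed in
# the prompt (assumed to match PRIMITIVES), and a "null" outcome means no
# outcome. Real parsing goes through parse_json_response in _client, whose
# handling of extra text around the JSON is not shown here; this sketch
# assumes the reply is bare JSON.
def _sketch_check_reply(reply):
    """Return (label, outcome) for a well-formed reply, (None, None) otherwise."""
    import json

    allowed = {"PLAN", "SETUP", "ENUMERATE", "HYPOTHESIZE", "COMPUTE",
               "CHECK", "BACKTRACK", "SUMMARIZE", "OTHER"}
    obj = json.loads(reply)
    label = obj.get("label")
    if label not in allowed:
        return None, None
    outcome = obj.get("outcome")
    if isinstance(outcome, str) and outcome.strip().lower() in {"null", "none", ""}:
        outcome = None
    return label, outcome


# Example:
#   _sketch_check_reply('{"label": "CHECK", "outcome": "fail", '
#                       '"confidence": "high", "reasoning": "Tests x=5; 14 != 15."}')
#   -> ("CHECK", "fail")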


# Approximate DeepSeek pricing: USD per 1M tokens as (input, output), at
# cache-miss rates. Prices change; verify at
# https://api-docs.deepseek.com/quick_start/pricing
_PRICING = {
    "deepseek-reasoner": (0.55, 2.19),
    "deepseek-chat":     (0.27, 1.10),
}


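def get_pricing(model: str) -> tuple[float, float]:
    """Return (input, output) USD per 1M tokens for `model`.

    Unknown models silently fall back to deepseek-reasoner rates.
    """
    return _PRICING.get(model, _PRICING["deepseek-reasoner"])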


def build_prompt(span: dict) -> str:
    return PROMPT_TEMPLATE.format(
        preceding_context=span.get("preceding_context", ""),
        span_text=span["span_text"],
    )


def estimate_cost(n_remaining: int, model: str) -> tuple[int, int, float]:
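    """Rough pre-run cost estimate. Returns (input_tokens, output_tokens, usd).

    Uses the spec's per-span assumptions: ~1400 input tokens (prompt, context
    and span), plus ~350 output tokens for deepseek-reasoner (chain of thought
    and JSON) or ~80 for deepseek-chat (JSON only). The caller multiplies by
    n_samples for self-consistency runs.
    """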
    in_per = 1400
    # deepseek-chat has no CoT, so output is just the small JSON
    out_per = 350 if model == "deepseek-reasoner" else 80
    in_rate, out_rate = get_pricing(model)
    in_tokens = in_per * n_remaining
    out_tokens = out_per * n_remaining
    cost = in_tokens * in_rate / 1e6 + out_tokens * out_rate / 1e6
    return in_tokens, out_tokens, cost
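

# Worked example of estimate_cost (arithmetic only, using the _PRICING table
# above) for 1,000 remaining spans on deepseek-reasoner:
#   input : 1,000 x 1,400 = 1,400,000 tokens x $0.55/M = $0.7700
#   output: 1,000 x   350 =   350,000 tokens x $2.19/M = $0.7665
#   total : ~$1.54 per sample; main() multiplies by --n-samples,
#           so --n-samples 5 prints ~$7.68.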


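def load_done_ids(out_path: Path) -> set[str]:
    """Return the span_ids already present in out_path (empty if it doesn't exist).

    Malformed lines (e.g. a row truncated by a crash mid-write) are skipped,
    so the affected span is re-judged on the next run.
    """
    if not out_path.exists():
        return set()
    done: set[str] = set()
    with open(out_path) as f:
        for line in f:
            try:
                done.add(json.loads(line)["span_id"])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Blank, truncated, or non-row line: treat as not done.
                continue
    return done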
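

# Illustrative sketch (hypothetical helper, not used by the pipeline): the
# skip-malformed-lines rule load_done_ids applies, run on an in-memory JSONL
# string instead of a file. A truncated last row, a blank line, or a line
# that is valid JSON but not a row object is ignored, so only fully written
# rows count as done on resume.
def _sketch_done_ids_from_text(text):
    """Collect span_ids from JSONL text, skipping lines that are not valid rows."""
    import json

    done = set()
    for line in text.splitlines():
        try:
            done.add(json.loads(line)["span_id"])
        except (json.JSONDecodeError, KeyError, TypeError):
            continue
    return done


# Example: the third row was cut off mid-write, so span "c" is re-judged.
#   _sketch_done_ids_from_text('{"span_id": "a"}\n{"span_id": "b"}\n{"span_id": "c')
#   -> {"a", "b"}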


def _single_call(
    client: DeepSeekClient, model: str, prompt: str,
    max_tokens: int, temperature: float | None,
) -> tuple[str | None, str | None, str | None, int | None]:
    """Make one LLM call and return (label, outcome, content, total_tokens).

    content=None means the API call failed after retries. A non-None content
    with label=None means the reply did not parse to a known primitive.
    """
    kwargs = {"max_tokens": max_tokens}
    # deepseek-reasoner does not support sampling parameters, so temperature
    # is only sent to other models.
    if temperature is not None and model != "deepseek-reasoner":
        kwargs["temperature"] = temperature
    content, total_tokens, _ = client.generate(model=model, prompt=prompt, **kwargs)
    if content is None:
        return None, None, None, None
    parsed = parse_json_response(content) or {}
    label = parsed.get("label")
    if label not in PRIMITIVES:
        return None, None, content, total_tokens
    outcome = parsed.get("outcome")
    if isinstance(outcome, str):
        # Normalize case/whitespace; the strings "null", "none", "" mean no outcome.
        outcome = outcome.strip().lower()
        if outcome in {"null", "none", ""}:
            outcome = None
    return label, outcome, content, total_tokens


def _majority(labels: list[str]) -> tuple[str, int]:
    """Return (winner, vote_count); an empty list gives ("PARSE_ERROR", 0).

    Ties are broken by taking the lexicographically smallest label, so the
    result is deterministic.
    """
    from collections import Counter
    if not labels:
        return "PARSE_ERROR", 0
    counts = Counter(labels).most_common()
    top = counts[0][1]
    winners = sorted(label for label, count in counts if count == top)
    return winners[0], top
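

# Examples of _majority's behaviour:
#   _majority(["CHECK", "COMPUTE", "CHECK"])  -> ("CHECK", 2)
#   _majority(["SETUP", "PLAN"])              -> ("PLAN", 1)   # tie: "PLAN" < "SETUP"
#   _majority([])                             -> ("PARSE_ERROR", 0)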


def make_judge_row(
    client: DeepSeekClient,
    model: str,
    span: dict,
    max_tokens: int,
    temperature: float | None = None,
    n_samples: int = 1,
) -> dict:
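    """Make n_samples judge calls for one span and aggregate them into a row.

    With n_samples=1 the row carries that single label. With n_samples>1,
    llm_label is the majority vote (see _majority), llm_outcome is the
    majority outcome among the samples that voted for that label, and every
    per-sample label/outcome is kept in llm_labels/llm_outcomes.
    """
    prompt = build_prompt(span)
    t0 = time.time()
    labels: list[str] = []
    outcomes: list[str | None] = []
    contents: list[str | None] = []
    tokens_used = 0
    n_parse = 0
    n_api = 0
    for _ in range(n_samples):
        lab, outc, content, tok = _single_call(
            client, model, prompt, max_tokens, temperature,
        )
        if content is None:
            n_api += 1
            continue
        # Unparseable replies still consumed tokens, so count them too.
        if tok:
            tokens_used += tok
        if lab is None:
            n_parse += 1
            contents.append(content)
            continue
        labels.append(lab)
        outcomes.append(outc)
        contents.append(content)
    dt = time.time() - t0

    # Only a total token count is available here, so split it 80/20 between
    # input and output (the 1400:350 ratio estimate_cost assumes for the
    # reasoner). This is an approximation, not a billed amount.
    in_rate, out_rate = get_pricing(model)
    cost = (
        tokens_used * 0.8 * in_rate / 1e6
        + tokens_used * 0.2 * out_rate / 1e6
    )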

    if not labels:
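        # No sample produced a valid label. Report API_FAILURE if API failures
        # outnumber parse errors, else PARSE_ERROR (ties go to PARSE_ERROR).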
        return {
            "span_id": span["span_id"],
            "llm_label": "API_FAILURE" if n_api > n_parse else "PARSE_ERROR",
            "llm_outcome": None,
            "llm_labels": [],
            "llm_outcomes": [],
            "n_samples": n_samples,
            "n_parse_err": n_parse,
            "n_api_err": n_api,
            "raw_content": contents[0] if contents else None,
            "tokens_used": tokens_used or None,
            "cost_estimate_usd": cost,
            "elapsed_s": dt,
        }

    majority_label, vote_count = _majority(labels)
    # Outcome majority computed only on samples that voted for the majority label
    aligned_outcomes = [o for l, o in zip(labels, outcomes) if l == majority_label]
    outcome_majority, _ = _majority([o or "null" for o in aligned_outcomes])
    if outcome_majority == "null":
        outcome_majority = None
    return {
        "span_id": span["span_id"],
        "llm_label": majority_label,
        "llm_outcome": outcome_majority,
        "llm_labels": labels,
        "llm_outcomes": outcomes,
        "n_samples": n_samples,
        "n_parse_err": n_parse,
        "n_api_err": n_api,
        "vote_count": vote_count,
        "raw_content": contents[0] if contents else None,
        "tokens_used": tokens_used,
        "cost_estimate_usd": cost,
        "elapsed_s": dt,
    }
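

# Illustrative sketch (hypothetical helper, not used by the pipeline): the
# two-stage vote make_judge_row applies, re-implemented standalone so the
# tie-breaking is easy to see. The label is voted on first; the outcome is
# then voted only among samples that chose that label, with None counted as
# "null". Because ties go to the lexicographically smallest value, an even
# pass/fail split among CHECK votes resolves to "fail".
def _sketch_aggregate(labels, outcomes):
    """Return (majority_label, majority_outcome_or_None) like make_judge_row."""
    from collections import Counter

    def vote(values):
        counts = Counter(values).most_common()
        top = counts[0][1]
        return sorted(v for v, c in counts if c == top)[0]

    label = vote(labels)
    aligned = [o or "null" for lab, o in zip(labels, outcomes) if lab == label]
    outcome = vote(aligned)
    return label, (None if outcome == "null" else outcome)


# Example:
#   _sketch_aggregate(["CHECK", "CHECK", "COMPUTE"], ["fail", "pass", None])
#   -> ("CHECK", "fail")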


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="in_path", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--model", default="deepseek-reasoner")
    ap.add_argument("--workers", type=int, default=100)
    ap.add_argument("--max-tokens", type=int, default=4096)
    ap.add_argument(
        "--temperature", type=float, default=0.0,
        help="Sampling temperature (deepseek-chat only; ignored by reasoner)",
    )
    ap.add_argument("--limit", type=int, default=None,
                    help="Process at most N remaining spans (after resume)")
    ap.add_argument("--n-samples", type=int, default=1,
                    help="Self-consistency: judge calls per span, majority-voted. "
                         "With deepseek-chat, values >1 need --temperature > 0")
    ap.add_argument("--dry-run", action="store_true",
                    help="Print cost estimate and exit without API calls")
    args = ap.parse_args()

    with open(args.in_path) as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        spans = [json.loads(line) for line in f if line.strip()]

    args.out.parent.mkdir(parents=True, exist_ok=True)
    done = load_done_ids(args.out)
    remaining = [s for s in spans if s["span_id"] not in done]
    if args.limit is not None:
        remaining = remaining[: args.limit]

    print(f"Total spans : {len(spans)}")
    print(f"Already done: {len(done)}")
    print(f"To process  : {len(remaining)}")

    in_tok, out_tok, cost = estimate_cost(len(remaining), args.model)
    in_tok *= args.n_samples
    out_tok *= args.n_samples
    cost *= args.n_samples
    in_rate, out_rate = get_pricing(args.model)
    print(f"Estimated input tokens : {in_tok:,}  (n_samples={args.n_samples})")
    print(f"Estimated output tokens: {out_tok:,}")
    print(f"Estimated cost (USD)   : ${cost:.2f}")
    print(
        f"  (assumes {in_rate}/M input, {out_rate}/M output for {args.model} "
        "-- verify at https://api-docs.deepseek.com/quick_start/pricing)"
    )
    if args.n_samples > 1 and args.temperature == 0.0 and args.model != "deepseek-reasoner":
        print(
            "WARNING: --n-samples > 1 with temperature=0 produces identical "
            "samples. Pass --temperature 0.7 (or similar) for self-consistency."
        )

    if args.dry_run:
        print("\n--dry-run: exiting without API calls.")
        return

    if not remaining:
        print("Nothing to do.")
        return

    client = DeepSeekClient()
    write_lock = threading.Lock()
    n_done = 0
    n_parse_err = 0
    n_api_err = 0

    def process_one(span: dict) -> dict:
        nonlocal n_done, n_parse_err, n_api_err
        row = make_judge_row(
            client, args.model, span, args.max_tokens,
            temperature=args.temperature,
            n_samples=args.n_samples,
        )
        with write_lock:
            with open(args.out, "a") as fh:
                fh.write(json.dumps(row) + "\n")
            n_done += 1
            if row["llm_label"] == "PARSE_ERROR":
                n_parse_err += 1
            elif row["llm_label"] == "API_FAILURE":
                n_api_err += 1
            if n_done % 10 == 0 or n_done == len(remaining):
                print(
                    f"  [{n_done}/{len(remaining)}] "
                    f"parse_err={n_parse_err} api_err={n_api_err}",
                    flush=True,
                )
        return row

    print(f"\nDispatching with {args.workers} workers...")
    t0 = time.time()
    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        futures = [ex.submit(process_one, s) for s in remaining]
        for fut in as_completed(futures):
            fut.result()  # propagate exceptions if any leak through with_retry
    dt = time.time() - t0
    print(f"\nDone in {dt:.1f}s. Wrote {n_done} rows -> {args.out}")
    print(f"Parse errors: {n_parse_err}  API failures: {n_api_err}")


if __name__ == "__main__":
    main()
