"""Re-judge a system's submitted hypotheses against the canonical judge.

Given a directory of submitted hypothesis files (Markdown or JSON) and a
validation pool of candidate future papers, this script:

    1. For each hypothesis, retrieve the top-K candidate validation papers
       by embedding similarity (or use a pre-computed shortlist).
    2. Two-stage judge each (hypothesis, candidate) pair using
       ``ckm_benchmark.judge.judge_pair``.
    3. Aggregate hits/misses per topic and emit a per-topic summary JSON
       in the same format as ``results/lite_summary.json``.

Usage:
    python -m ckm_benchmark.rejudge \\
        --hypotheses path/to/hypotheses/ \\
        --validation-pool data/validation_papers/ \\
        --output results/your_system_summary.json \\
        [--shortlists path/to/shortlists/] \\
        [--top-k 30]

The hypothesis directory layout expected:

    hypotheses/
    ├── topic_slug_1/
    │   ├── hyp-2024-01-001.md  (or .json)
    │   ├── hyp-2024-01-002.md
    │   └── ...
    └── topic_slug_2/
        └── ...

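For the v0.1 release, this is a thin reference script: the embedding pre-filter
is delegated to per-hypothesis shortlist files
(``<shortlists>/<topic_slug>/<hyp_id>.json``, each holding a
``candidate_arxiv_ids`` list) that submitters can pre-compute, since the
embedding step has no value-add over what the original system already did.
When a hypothesis has no shortlist file, the first ``--top-k`` papers of the
topic's validation pool (in sorted filename order) are judged instead; no
similarity ranking is applied. v0.2 will fold the embedding pre-filter into
this script once frozen validation pools are released.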
"""

from __future__ import annotations

import argparse
import json
import logging
from collections import defaultdict
from pathlib import Path

from ckm_benchmark.judge import judge_pair, JudgeResult
from ckm_benchmark.protocol import HIT_THRESHOLD


# Module-wide logger, registered under the package-qualified name given here
# (a literal string, not ``__name__``).
logger = logging.getLogger("ckm_benchmark.rejudge")
# Root logging is configured at import time: INFO level, with each record
# rendered as "<timestamp> <level> <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)


def _load_hypothesis(path: Path) -> dict:
    """Load a hypothesis from .md or .json. Returns a dict with 'id' and 'text'."""
    if path.suffix.lower() == ".json":
        with open(path) as fh:
            data = json.load(fh)
        # The submission JSON schema includes a 'statement' or full structured payload
        text = data.get("statement") or data.get("text") or json.dumps(data, indent=2)
        return {"id": data.get("id") or path.stem, "text": text, "raw": data}
    # Markdown: pass full text to the judge (judge prompt expects structured md)
    text = path.read_text(encoding="utf-8")
    return {"id": path.stem, "text": text, "raw": None}


def _load_candidate_papers(pool_dir: Path, topic_slug: str) -> list[dict]:
    """Load candidate validation papers for a given topic.

    Expects files in pool_dir/<topic_slug>/<arxiv_id>.json with fields:
        title, published, arxiv_id, content (text excerpt)
    """
    topic_dir = pool_dir / topic_slug
    if not topic_dir.exists():
        logger.warning("No validation pool found for topic %s at %s", topic_slug, topic_dir)
        return []
    papers = []
    for path in sorted(topic_dir.glob("*.json")):
        with open(path) as fh:
            paper = json.load(fh)
        papers.append(paper)
    return papers


def _load_shortlist(shortlist_dir: Path | None, topic_slug: str, hyp_id: str) -> list[str] | None:
    """Optional: per-hypothesis shortlist of arXiv IDs to judge.

    Returns the list of arXiv IDs to consider, or None if no shortlist file is
    present (meaning judge against the full validation pool for the topic).
    """
    if shortlist_dir is None:
        return None
    path = shortlist_dir / topic_slug / f"{hyp_id}.json"
    if not path.exists():
        return None
    with open(path) as fh:
        return json.load(fh).get("candidate_arxiv_ids", [])


def rejudge_topic(
    topic_slug: str,
    hyp_dir: Path,
    pool_dir: Path,
    shortlist_dir: Path | None,
    top_k: int,
) -> dict:
    """Re-judge all hypotheses for one topic. Returns a summary record.

    Every ``.md`` / ``.json`` file directly inside ``hyp_dir`` is treated as
    one hypothesis and judged with ``judge_pair`` against a set of candidate
    papers: the papers named in the hypothesis's shortlist file when one
    exists, otherwise the first ``top_k`` papers of the topic's pool (pool
    order is sorted filename order; no similarity ranking is applied here).

    Args:
        topic_slug: Topic name; selects the pool and shortlist subdirectories.
        hyp_dir: Directory holding this topic's hypothesis files.
        pool_dir: Root directory of the validation pool.
        shortlist_dir: Root directory of per-hypothesis shortlists, or ``None``.
        top_k: Number of pool papers to judge when a hypothesis has no shortlist.

    Returns:
        A dict with keys:

        * ``slug``              -- ``topic_slug``.
        * ``yield``             -- number of hypothesis files found.
        * ``hit_rate``          -- percentage (0-100) of hypotheses with at
          least one judged pair flagged ``is_hit``.
        * ``best_match_score``  -- mean over hypotheses of the highest
          ``score`` any of their pairs received (0.0 for a hypothesis with
          no successfully judged pair).
        * ``unique_hit_papers`` -- number of distinct arXiv IDs that were a
          hit for at least one hypothesis.
        * ``tokens_used``       -- sum of ``tokens_used`` over all judged pairs.
        * ``_judge_log``        -- one entry per successfully judged pair.

        When the topic has no candidate papers the record is all zeros
        (including ``yield``) and carries no ``_judge_log`` key.
        NOTE(review): ``yield`` is 0 in that case even if hypothesis files
        exist -- confirm this is what downstream consumers expect.
    """
    # Regular files only, and a case-insensitive suffix match so the filter
    # agrees with _load_hypothesis (which lower-cases the suffix itself).
    hyp_paths = sorted(
        p for p in hyp_dir.glob("*")
        if p.is_file() and p.suffix.lower() in {".md", ".json"}
    )
    candidates = _load_candidate_papers(pool_dir, topic_slug)
    if not candidates:
        return {
            "slug": topic_slug,
            "yield": 0,
            "hit_rate": 0.0,
            "best_match_score": 0.0,
            "unique_hit_papers": 0,
            "tokens_used": 0,
        }

    n_hyp = len(hyp_paths)
    n_hits = 0
    best_match_per_hyp: list[float] = []
    matched_paper_ids: set[str] = set()
    total_tokens = 0
    judge_results = []

    for hyp_path in hyp_paths:
        hyp = _load_hypothesis(hyp_path)
        shortlist = _load_shortlist(shortlist_dir, topic_slug, hyp["id"])
        if shortlist is None:
            candidates_for_hyp = candidates[:top_k]
        else:
            # Build the lookup set once per hypothesis, not once per paper.
            shortlisted_ids = set(shortlist)
            candidates_for_hyp = [p for p in candidates if p["arxiv_id"] in shortlisted_ids]

        per_hyp_best = 0.0
        per_hyp_hit = False
        for paper in candidates_for_hyp:
            try:
                result: JudgeResult = judge_pair(
                    hypothesis_text=hyp["text"],
                    paper_title=paper["title"],
                    paper_published=paper["published"],
                    paper_arxiv_id=paper["arxiv_id"],
                    paper_content=paper["content"],
                )
            except Exception as e:
                # Best-effort: a failed pair is logged and skipped, so it
                # contributes neither a score nor a hit. .get() keeps the
                # handler itself from raising when the failure was a paper
                # record that lacks "arxiv_id".
                logger.error("Judge failure on %s vs %s: %s", hyp["id"], paper.get("arxiv_id"), e)
                continue
            total_tokens += result.tokens_used
            per_hyp_best = max(per_hyp_best, result.score)
            if result.is_hit:
                per_hyp_hit = True
                matched_paper_ids.add(paper["arxiv_id"])
            judge_results.append(
                {
                    "topic_slug": topic_slug,
                    "hypothesis_id": hyp["id"],
                    "paper_arxiv_id": paper["arxiv_id"],
                    **result.to_dict(),
                }
            )

        best_match_per_hyp.append(per_hyp_best)
        if per_hyp_hit:
            n_hits += 1
        # tokens= is the running total for the topic so far, not per hypothesis.
        logger.info(
            "[%s] %s: best=%.2f, hit=%s, tokens=%d",
            topic_slug, hyp["id"], per_hyp_best, per_hyp_hit, total_tokens,
        )

    return {
        "slug": topic_slug,
        "yield": n_hyp,
        # max(1, ...) avoids division by zero when the topic has no hypotheses.
        "hit_rate": 100.0 * n_hits / max(1, n_hyp),
        "best_match_score": sum(best_match_per_hyp) / max(1, len(best_match_per_hyp)),
        "unique_hit_papers": len(matched_paper_ids),
        "tokens_used": total_tokens,
        "_judge_log": judge_results,  # full per-pair log for auditing; can be removed
    }


def main() -> None:
    """Command-line entry point: re-judge every topic and write the summary.

    Parses the command line, runs ``rejudge_topic`` once per subdirectory of
    ``--hypotheses`` (in sorted order), and writes the list of per-topic
    records to ``--output`` as indented JSON, creating the output's parent
    directories if needed.

    Raises:
        SystemExit: On invalid command-line arguments, including a
            ``--hypotheses`` path that is not a directory or a negative
            ``--top-k`` (reported through ``argparse`` with exit status 2).
    """
    parser = argparse.ArgumentParser(description="Re-judge a system's hypotheses.")
    parser.add_argument("--hypotheses", type=Path, required=True,
                        help="Directory of per-topic hypothesis subdirectories.")
    parser.add_argument("--validation-pool", type=Path, required=True,
                        help="Directory of per-topic validation paper subdirectories.")
    parser.add_argument("--shortlists", type=Path, default=None,
                        help="Optional directory of per-hypothesis candidate shortlists.")
    parser.add_argument("--output", type=Path, required=True,
                        help="Path to write the per-topic summary JSON.")
    parser.add_argument("--top-k", type=int, default=30,
                        help="Top-K candidate papers per hypothesis when no shortlist provided.")
    args = parser.parse_args()

    # Fail fast with a usage message instead of a traceback from iterdir().
    if not args.hypotheses.is_dir():
        parser.error(f"--hypotheses is not a directory: {args.hypotheses}")
    # A negative value would be used as a slice bound and silently drop
    # papers from the END of the pool rather than limit the count.
    if args.top_k < 0:
        parser.error(f"--top-k must be non-negative, got {args.top_k}")

    summary = []
    topic_dirs = sorted(d for d in args.hypotheses.iterdir() if d.is_dir())
    logger.info("Re-judging %d topics", len(topic_dirs))
    for topic_dir in topic_dirs:
        record = rejudge_topic(
            topic_slug=topic_dir.name,
            hyp_dir=topic_dir,
            pool_dir=args.validation_pool,
            shortlist_dir=args.shortlists,
            top_k=args.top_k,
        )
        summary.append(record)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)
    logger.info("Wrote summary to %s (%d topics)", args.output, len(summary))


if __name__ == "__main__":
    main()
