"""Extract per-topic frozen arXiv ID sets from existing CKM evaluation runs.

For full benchmark reproducibility, every topic should ship with a fixed set
of arXiv IDs for each phase (initialization, six evolution windows,
validation pool). This script extracts those sets from per-topic run logs
produced by ``ckm-eval/scripts/lite/eval_single.py`` and emits one JSON per
topic in ``data/arxiv_ids/<slug>.json``.

Sources used per phase (per the actual lite-run log layout):

    init       <- metabolism/config.json -> processed_ids
    evolution  <- metabolism/log/window-YYYY-MM-ingest.md (parsed for arxiv_id lines)
    validation <- metabolism/fulltext_cache/<id>.txt files MINUS the init and
                  evolution IDs (validation papers were also fetched into the
                  same cache during Stage 4 grading)

Usage:
    python -m ckm_benchmark.extract_arxiv_ids \\
        --runs-dir ../ckm-eval/results/lite_20260405_000653/topics \\
        --output-dir data/arxiv_ids
"""

from __future__ import annotations

import argparse
import json
import logging
import re
from pathlib import Path


# Import-time logging setup: ``basicConfig`` requests INFO level and a
# "timestamp level message" format on the root logger, and ``logger`` is the
# named logger used by every function in this module.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("ckm_benchmark.extract_arxiv_ids")


# ``arxiv_id: <id>`` entry in an ingest log, with only whitespace allowed
# between the ID and a line end. Group 1 is the ID; an optional ``vN`` version
# suffix is matched outside the group.
_ARXIV_LINE_RE = re.compile(
    r"arxiv_id:\s*([\w.]+?)(?:v\d+)?\s*$",
    re.MULTILINE,
)

# Cache filename of the form ``NNNN.NNNN.txt`` / ``NNNN.NNNNN.txt`` (no version
# suffix). Group 1 is the arXiv ID.
_ARXIV_FILENAME_RE = re.compile(r"^(\d{4}\.\d{4,5})\.txt$")

# Ingest log filename ``window-YYYY-MM-ingest.md``. Group 1 is ``YYYY-MM``.
_WINDOW_FILENAME_RE = re.compile(r"^window-(\d{4}-\d{2})-ingest\.md$")


def _strip_version(arxiv_id: str) -> str:
    """Return *arxiv_id* without a trailing ``vN`` version marker.

    IDs that differ only by version (``v1``, ``v2``, ...) collapse to the same
    string, which lets callers de-duplicate across versions. Input that has no
    such suffix is returned unchanged.
    """
    version = re.search(r"v\d+$", arxiv_id)
    if version is None:
        return arxiv_id
    # Splice around the match rather than truncating: ``$`` can also match
    # just before a final newline, and that newline must be kept.
    return arxiv_id[: version.start()] + arxiv_id[version.end():]


def _init_ids(metabolism_dir: Path) -> list[str]:
    """Return the initialization-phase arXiv IDs recorded in ``config.json``.

    Reads ``<metabolism_dir>/config.json`` and returns its ``processed_ids``
    entries with any version suffix removed. The result is de-duplicated in
    first-seen order, so ``1234.56789v1`` and ``1234.56789v2`` contribute a
    single ID (the same treatment the evolution windows receive).

    Args:
        metabolism_dir: A topic's ``metabolism`` directory.

    Returns:
        Version-less arXiv IDs in first-seen order. Empty when ``config.json``
        is missing, or when ``processed_ids`` is absent, ``null`` or empty.

    Raises:
        json.JSONDecodeError: If ``config.json`` is not valid JSON.
    """
    cfg = metabolism_dir / "config.json"
    if not cfg.exists():
        # The validation set is computed as "cached minus init/evolution", so
        # an empty init set lets init papers slip into validation. Make that
        # visible instead of failing silently.
        logger.warning(
            "No config.json in %s; init set is empty and cached init papers "
            "may be counted as validation",
            metabolism_dir,
        )
        return []
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(cfg, encoding="utf-8") as fh:
        data = json.load(fh)
    # ``or []`` also covers an explicit ``"processed_ids": null``.
    raw_ids = data.get("processed_ids") or []
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(_strip_version(x) for x in raw_ids))


def _evolution_ids_by_window(metabolism_dir: Path) -> dict[str, list[str]]:
    """Collect the arXiv IDs listed in each evolution window's ingest log.

    Scans ``<metabolism_dir>/log`` for files named
    ``window-YYYY-MM-ingest.md`` (visited in sorted path order) and pulls
    every ``arxiv_id:`` entry out of each one.

    Args:
        metabolism_dir: A topic's ``metabolism`` directory.

    Returns:
        Mapping of window label (``YYYY-MM``) to that window's version-less
        arXiv IDs, de-duplicated within the window in first-seen order.
        Empty when the ``log`` directory does not exist.
    """
    log_dir = metabolism_dir / "log"
    if not log_dir.exists():
        return {}

    per_window: dict[str, list[str]] = {}
    for log_path in sorted(log_dir.glob("window-*-ingest.md")):
        name_match = _WINDOW_FILENAME_RE.match(log_path.name)
        if name_match is None:
            # The glob is looser than the strict YYYY-MM filename pattern.
            continue
        # Undecodable bytes are dropped rather than aborting the scan.
        body = log_path.read_text(encoding="utf-8", errors="ignore")
        found = (
            _strip_version(hit.group(1))
            for hit in _ARXIV_LINE_RE.finditer(body)
        )
        # dict.fromkeys keeps the first occurrence of each ID, in order.
        per_window[name_match.group(1)] = list(dict.fromkeys(found))
    return per_window


def _all_cached_ids(metabolism_dir: Path) -> list[str]:
    """List the arXiv IDs that have a file in the full-text cache.

    Looks at ``<metabolism_dir>/fulltext_cache/*.txt`` in sorted path order
    and keeps the ID part of every filename that matches the
    ``NNNN.NNNNN.txt`` pattern; other files are ignored.

    Args:
        metabolism_dir: A topic's ``metabolism`` directory.

    Returns:
        The cached arXiv IDs, or an empty list when the cache directory does
        not exist.
    """
    cache_dir = metabolism_dir / "fulltext_cache"
    if not cache_dir.exists():
        return []
    name_matches = (
        _ARXIV_FILENAME_RE.match(entry.name)
        for entry in sorted(cache_dir.glob("*.txt"))
    )
    return [hit.group(1) for hit in name_matches if hit]


def extract_topic(topic_dir: Path) -> dict | None:
    """Build the frozen arXiv ID record for one topic run directory.

    The validation set is derived by subtraction: every cached ID that is
    neither an init ID nor listed in any evolution window.

    Args:
        topic_dir: Directory of a single topic; its name becomes the slug.

    Returns:
        A dict with ``slug``, ``init``, ``evolution`` (window label -> IDs),
        ``validation`` and a ``_provenance`` block of counts, or ``None``
        (after logging a warning) when ``topic_dir`` has no ``metabolism``
        entry.
    """
    metabolism = topic_dir / "metabolism"
    if not metabolism.exists():
        logger.warning("No metabolism directory in %s — skipping", topic_dir)
        return None

    init = _init_ids(metabolism)
    evolution = _evolution_ids_by_window(metabolism)
    cached = _all_cached_ids(metabolism)

    # Everything used for init or evolution is excluded from validation.
    training_ids = set(init).union(*evolution.values())
    validation = [cid for cid in cached if cid not in training_ids]

    provenance = {
        "init_count": len(init),
        "evolution_window_count": len(evolution),
        "evolution_total_papers": sum(map(len, evolution.values())),
        "validation_count": len(validation),
        "cached_total": len(cached),
    }
    return {
        "slug": topic_dir.name,
        "init": init,
        "evolution": evolution,
        "validation": validation,
        "_provenance": provenance,
    }


def main() -> None:
    """CLI entry point: write one frozen arXiv-ID JSON file per topic.

    Parses ``--runs-dir`` and ``--output-dir``, runs :func:`extract_topic` on
    every subdirectory of the runs directory (in sorted order), writes each
    non-``None`` record to ``<output-dir>/<slug>.json`` and logs per-topic and
    aggregate counts.

    Exits through ``argparse``'s usage error (status 2) when ``--runs-dir`` is
    not an existing directory.
    """
    parser = argparse.ArgumentParser(description="Extract frozen arXiv ID sets per topic.")
    parser.add_argument("--runs-dir", type=Path, required=True,
                        help="Path to a CKM run's `topics/` directory (one subdir per topic).")
    parser.add_argument("--output-dir", type=Path, required=True,
                        help="Output directory; one JSON file per topic will be written here.")
    args = parser.parse_args()

    # Validate the input before creating anything on disk, so a mistyped path
    # produces a clear usage error instead of a traceback and a stray
    # output directory.
    if not args.runs_dir.is_dir():
        parser.error(f"--runs-dir is not a directory: {args.runs_dir}")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    topic_dirs = sorted(d for d in args.runs_dir.iterdir() if d.is_dir())
    logger.info("Extracting arXiv ID sets from %d topics in %s", len(topic_dirs), args.runs_dir)

    n_written = 0
    aggregate = {"init": 0, "evolution": 0, "validation": 0}
    for topic_dir in topic_dirs:
        record = extract_topic(topic_dir)
        if record is None:
            # extract_topic already logged why this directory was skipped.
            continue
        out_path = args.output_dir / f"{record['slug']}.json"
        # Explicit encoding keeps the output independent of the locale.
        with open(out_path, "w", encoding="utf-8") as fh:
            json.dump(record, fh, indent=2)
        n_written += 1
        prov = record["_provenance"]
        aggregate["init"] += prov["init_count"]
        aggregate["evolution"] += prov["evolution_total_papers"]
        aggregate["validation"] += prov["validation_count"]
        logger.info(
            "[%s] init=%d, evolution=%d papers across %d windows, validation=%d -> %s",
            record["slug"], prov["init_count"], prov["evolution_total_papers"],
            prov["evolution_window_count"], prov["validation_count"], out_path.name,
        )

    logger.info(
        "Done: wrote %d topic files. Aggregates: init=%d, evolution=%d, validation=%d",
        n_written, aggregate["init"], aggregate["evolution"], aggregate["validation"],
    )


if __name__ == "__main__":
    main()
