"""
Pool Baseline — one-shot batch processing (control group for CKM).

Identical to CKM except Phase 2:
  CKM:  incremental sliding-window evolution with per-window hypothesis generation
  Pool: collect ALL evolution papers at once, build knowledge in one pass,
        then generate hypotheses from that single knowledge state
        (one round per CKM window, so the number of hypothesis calls matches)

Phase 1 (Init) and Phase 3 (Validation) are identical to CKM.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import asyncio
import argparse
import logging
import re
from datetime import datetime

from config import config
from core.store import FileSystemKnowledgeStore
import json

from core.engines import (
    run_read_engine, run_init_engine, run_topic_update_engine, run_topic_discover_engine,
    run_baseline_hypothesis_engine, run_dedup_engine, get_embedding, get_max_input_chars,
)
from tools.arxiv_search import search_arxiv_topic
from tools.arxiv_fulltext import get_paper_content_record

# Module-level logger; level and format are set by logging.basicConfig() in run().
logger = logging.getLogger("pool.main")


# ---------------------------------------------------------------------------
# Helpers — identical to CKM eval_single.py
# ---------------------------------------------------------------------------

def add_months(date_str: str, months: int) -> str:
    """Shift a ``YYYY-MM-DD`` date by *months* and snap it to the 1st of that month.

    The day of the input is discarded: the result is always the first day of
    the target month. A negative *months* moves backwards in time.
    """
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    # Count months zero-based so the year carry (or borrow) falls out of divmod.
    year_delta, month_index = divmod(parsed.month - 1 + months, 12)
    shifted = parsed.replace(
        year=parsed.year + year_delta,
        month=month_index + 1,
        day=1,
    )
    return shifted.strftime("%Y-%m-%d")


def build_phase_windows(start_date: str, end_date: str, window_months: int) -> list[tuple[str, str]]:
    """Split ``[start_date, end_date)`` into consecutive month-based windows.

    Dates are ``YYYY-MM-DD`` strings and are compared as strings, so they
    must be zero-padded for the ordering to be chronological.

    Args:
        start_date: Start of the first window.
        end_date: End of the last window; the final window is clipped to it.
        window_months: Length of each window in months, at least 1.

    Returns:
        A list of ``(window_start, window_end)`` pairs covering the range in
        order. Empty when ``start_date >= end_date``.

    Raises:
        ValueError: If the range is non-empty and ``window_months`` is below 1.
            Such a step never moves ``current`` forward, so the loop would
            otherwise spin (and grow the result list) without terminating.
    """
    if start_date >= end_date:
        return []
    if window_months < 1:
        raise ValueError(
            f"window_months must be >= 1 to make progress, got {window_months!r}"
        )

    windows = []
    current = start_date
    while current < end_date:
        # Clip the last window so it never extends past end_date.
        next_date = min(add_months(current, window_months), end_date)
        windows.append((current, next_date))
        current = next_date
    return windows


def allocate_window_budgets(total_budget: int, n_windows: int) -> list[int]:
    """Spread *total_budget* over *n_windows* as evenly as possible.

    Returns one budget per window, summing to *total_budget*; when the split
    is uneven the earliest windows each receive one extra unit. Returns an
    empty list when either argument is zero or negative.
    """
    if total_budget <= 0 or n_windows <= 0:
        return []
    share, extra = divmod(total_budget, n_windows)
    # The first `extra` windows absorb the remainder, one unit each.
    return [share + 1] * extra + [share] * (n_windows - extra)


def dedupe_and_sort_papers(papers: list) -> list:
    """Drop repeated ``arxiv_id`` entries and order the rest by ``published``.

    When an id occurs more than once, the first occurrence in *papers* wins.
    """
    first_seen = {}
    for entry in papers:
        key = entry["arxiv_id"]
        if key not in first_seen:
            first_seen[key] = entry
    ordered = list(first_seen.values())
    ordered.sort(key=lambda entry: entry["published"])
    return ordered


async def resolve_paper_contents(
    papers: list,
    cache_dir: Path,
    concurrency: int,
    phase_name: str,
) -> list:
    """Attach resolved content to every paper, fetching concurrently.

    Each returned item is a copy of the input paper dict extended with
    ``content``, ``content_source`` and ``counted_fulltext``, taken from the
    record returned by ``get_paper_content_record``. Timeout and retry
    settings come from ``config["experiment"]`` and are clamped to sane
    minimums. Results keep the order of *papers*.

    Args:
        papers: Paper dicts providing at least ``arxiv_id`` and ``abstract``.
        cache_dir: Directory handed to the content fetcher.
        concurrency: Maximum simultaneous fetches (values below 1 become 1).
        phase_name: Label used in the summary log line.
    """
    settings = config["experiment"]
    timeout_s = max(1, settings.get("fulltext_timeout_s", 30))
    retries = max(1, settings.get("fulltext_retries", 2))
    retry_delay_s = max(0, settings.get("fulltext_retry_delay_s", 2))

    gate = asyncio.Semaphore(max(1, concurrency))

    async def resolve_single(paper: dict) -> dict:
        # Only the fetch itself is throttled; merging happens outside the gate.
        async with gate:
            record = await get_paper_content_record(
                paper["arxiv_id"],
                paper["abstract"],
                cache_dir,
                timeout_s=timeout_s,
                retries=retries,
                retry_delay_s=retry_delay_s,
            )
        return {
            **paper,
            "content": record["content"],
            "content_source": record["source"],
            "counted_fulltext": record["counted_fulltext"],
        }

    pending = [resolve_single(paper) for paper in papers]
    resolved_papers = await asyncio.gather(*pending)

    fulltext_total = sum(bool(item["counted_fulltext"]) for item in resolved_papers)
    total = len(resolved_papers)
    logger.info(
        "[%s] Content resolved: total=%d, counted_fulltext=%d, abstract_only=%d",
        phase_name,
        total,
        fulltext_total,
        total - fulltext_total,
    )
    return resolved_papers


def fetch_phase_papers(
    topic: str,
    start_date: str,
    end_date: str,
    total_budget: int,
    window_months: int,
    min_papers: int = 0,
    phase_name: str = "Search",
) -> list:
    """Search arXiv window by window and return the merged, de-duplicated papers.

    The date range is cut into windows of *window_months*; the larger of
    *total_budget* and *min_papers* is split across those windows and each
    window with a positive share is searched separately. A warning is logged
    when *min_papers* is set and fewer unique papers were found.

    Returns:
        Unique papers (by ``arxiv_id``) ordered by ``published``.
    """
    windows = build_phase_windows(start_date, end_date, window_months)
    budgets = allocate_window_budgets(max(total_budget, min_papers), len(windows))

    gathered = []
    for window, budget in zip(windows, budgets):
        if budget > 0:
            window_start, window_end = window
            batch = search_arxiv_topic(
                topic,
                budget,
                start_date=window_start,
                end_date=window_end,
            )
            logger.info(
                "[%s] %s ~ %s: fetched %d papers (budget=%d)",
                phase_name,
                window_start,
                window_end,
                len(batch),
                budget,
            )
            gathered += batch

    unique_papers = dedupe_and_sort_papers(gathered)
    found = len(unique_papers)
    if min_papers and found < min_papers:
        logger.warning(
            "[%s] Only %d papers found in %s ~ %s (minimum desired=%d)",
            phase_name,
            found,
            start_date,
            end_date,
            min_papers,
        )
    return unique_papers


def display_topic_name(file_name: str) -> str:
    """Turn a knowledge file name into a human-readable topic label.

    The extension and a leading ``topic-`` are stripped and hyphens become
    spaces. The index file (stem ``_index``) maps to an empty string.
    """
    base = Path(file_name).stem
    if base == "_index":
        return ""
    return base.removeprefix("topic-").replace("-", " ")


def format_papers_for_prompt(papers: list) -> str:
    """Render papers as a numbered Markdown list for inclusion in a prompt.

    Each entry shows title, arxiv_id, published date and abstract, ends with
    a ``---`` rule, and entries are separated by a blank line. Numbering
    starts at 1.
    """
    entries = [
        f"### {number}. {paper['title']}\n"
        f"- **arxiv_id**: {paper['arxiv_id']}\n"
        f"- **published**: {paper['published']}\n"
        f"- **abstract**: {paper['abstract']}\n---"
        for number, paper in enumerate(papers, start=1)
    ]
    return "\n\n".join(entries)


def concat_abstracts(papers: list) -> str:
    """Return every paper's abstract, one per line, in input order."""
    abstracts = [paper["abstract"] for paper in papers]
    return "\n".join(abstracts)


# ---------------------------------------------------------------------------
# Phase 1: Init — identical to CKM
# ---------------------------------------------------------------------------

async def phase_init(topic: str, papers: list, store: FileSystemKnowledgeStore,
                     cache_dir: Path) -> int:
    """Build the initial knowledge base from the init-phase papers.

    Pipeline: resolve paper contents, run the read engine on every paper
    whose resolution record is flagged ``counted_fulltext``, synthesize the
    extractions into knowledge files with the init engine, store one
    embedding per written file, and record progress in ``config.json`` plus
    a dated init log.

    Args:
        topic: Research topic, stored in ``config.json`` and passed to the
            init engine.
        papers: Metadata dicts for the init period.
        store: Knowledge store receiving files, embeddings and logs.
        cache_dir: Directory handed to the content resolver.

    Returns:
        The ``tokens`` value reported by the init engine, or 0 when *papers*
        is empty.

    Raises:
        RuntimeError: If no paper resolved as counted fulltext, or if the
            init engine returned no file other than ``_index.md``.
    """
    if not papers:
        logger.warning("[Init] No papers in init time range, skipping.")
        return 0

    # Content resolution and per-paper extraction are throttled separately.
    resolve_concurrency = 12
    extract_concurrency = 6

    logger.info("[Init] Building knowledge baseline from %d metadata papers", len(papers))
    resolved_papers = await resolve_paper_contents(papers, cache_dir, resolve_concurrency, "Init")
    fulltext_papers = [paper for paper in resolved_papers if paper["counted_fulltext"]]
    abstract_only_count = len(resolved_papers) - len(fulltext_papers)
    if not fulltext_papers:
        raise RuntimeError("Init phase resolved zero counted fulltext papers")

    # Step 1: Concurrent per-paper extraction, bounded by extract_concurrency.
    semaphore = asyncio.Semaphore(extract_concurrency)

    async def extract_paper(p):
        async with semaphore:
            return await run_read_engine(p["title"], p["arxiv_id"], p["published"], p["content"])

    # Report the limit actually enforced by the semaphore (the message used
    # to hard-code "max 12" while only 6 extractions could run at once).
    logger.info(
        "[Init] Extracting %d counted fulltext papers concurrently (max %d), abstract-only skipped=%d...",
        len(fulltext_papers),
        extract_concurrency,
        abstract_only_count,
    )
    extractions = await asyncio.gather(*[extract_paper(p) for p in fulltext_papers])

    # Step 2: Track processed_ids in config.json (currentDay stays 0 until
    # the baseline has been fully written).
    config_path = store.base_dir / "config.json"
    config_data = {
        "topic": topic,
        "processed_ids": [p["arxiv_id"] for p in fulltext_papers],
        "currentDay": 0,
    }
    config_path.write_text(json.dumps(config_data, indent=2), encoding="utf-8")
    logger.info("[Init] Tracked %d counted fulltext paper IDs in config.json", len(fulltext_papers))

    # Step 3: Synthesize all extractions into knowledge files
    papers_text = "\n\n---\n\n".join(extractions)
    result = await run_init_engine(topic, papers_text, len(fulltext_papers))
    topic_ops = [op for op in result["operations"] if op["fileName"] != "_index.md"]

    if not topic_ops:
        # Keep both the error and the raw engine output for post-mortem.
        error_text = result.get("error") or "unknown init failure"
        raw_output = result.get("raw_content", "").strip() or "(empty init output)"
        error_stamp = datetime.now().strftime('%Y-%m-%d')
        store.write_log(f"{error_stamp}-init-error.md", error_text)
        store.write_log(f"{error_stamp}-init-raw.md", raw_output)
        raise RuntimeError(f"Init engine returned no topic files: {error_text}")

    # Step 4: Write knowledge files and compute embeddings concurrently.
    # Every operation is written and embedded, including _index.md.
    embeddings_dict = store.get_embeddings_dict()
    for op in result["operations"]:
        store.write_knowledge_file(op["fileName"], op["fileContent"])
    vecs = await asyncio.gather(*[get_embedding(op["fileContent"]) for op in result["operations"]])
    for op, vec in zip(result["operations"], vecs):
        embeddings_dict[op["fileName"]] = vec
        logger.info("[Init] Created knowledge file: %s", op["fileName"])

    store.save_embeddings_dict(embeddings_dict)

    # Step 5: Write init log and mark init complete (currentDay -> 1)
    topic_names = ", ".join(
        name for name in (display_topic_name(op["fileName"]) for op in topic_ops) if name
    )
    init_stamp = datetime.now().strftime('%Y-%m-%d')
    log_content = (
        f"# Day 0 — Initialization\n\n"
        f"Date: {init_stamp}\n"
        f"Papers: {len(fulltext_papers)}\n"
        f"Abstract-only skipped: {abstract_only_count}\n"
        f"Topics: {topic_names}\n"
        f"Status: Baseline construction complete\n"
    )
    store.write_log(f"{init_stamp}-init.md", log_content)

    # Re-read config.json rather than reusing config_data, so that anything
    # written to the file since Step 2 is preserved.
    config_data = json.loads(config_path.read_text(encoding="utf-8"))
    config_data["currentDay"] = 1
    config_path.write_text(json.dumps(config_data, indent=2), encoding="utf-8")

    logger.info("[Init] Baseline complete: %d knowledge files, %d tokens",
                len(result["operations"]), result["tokens"])
    return result["tokens"]


# ---------------------------------------------------------------------------
# Phase 2: Pool — one-shot batch processing (replaces CKM evolution)
# ---------------------------------------------------------------------------

async def phase_pool(topic: str, papers: list, store: FileSystemKnowledgeStore,
                     cache_dir: Path, evolution_start: str, evolution_end: str,
                     step_months: int) -> int:
    """
    Pool baseline: process ALL evolution papers in one shot.

    Unlike CKM's incremental evolution, Pool:
    1. Resolves all evolution paper contents at once
    2. Updates every existing topic file once, each with its top-ranked papers
    3. Discovers new topics once
    4. Generates hypotheses in n_windows rounds of up to
       hypothesis_max_per_window each, all against the same knowledge state
    5. Deduplicates the resulting hypothesis files with an LLM judge

    Args:
        topic: Research topic. Not referenced in the body of this function.
        papers: Metadata dicts for the evolution period.
        store: Knowledge store that is read from and written to.
        cache_dir: Directory handed to the content resolver.
        evolution_start: Start of the evolution period (YYYY-MM-DD).
        evolution_end: End of the evolution period (YYYY-MM-DD).
        step_months: CKM window length in months; only used here to derive
            how many hypothesis rounds to run.

    Returns:
        Sum of the "tokens" values reported by the topic-update, topic-discover,
        hypothesis and dedup engines. 0 when there are no papers or none of
        them resolved as counted fulltext.
    """
    if not papers:
        logger.warning("[Pool] No papers in evolution time range, skipping.")
        return 0

    exp_cfg = config["experiment"]
    hypothesis_max_per_window = max(1, exp_cfg.get("hypothesis_max_per_window", 3))

    # Calculate total hypothesis budget: same as CKM would have across all windows
    n_windows = len(build_phase_windows(evolution_start, evolution_end, step_months))
    # Only used for the log line below; the rounds in Step 5 are driven by
    # n_windows and hypothesis_max_per_window directly.
    pool_hypothesis_budget = n_windows * hypothesis_max_per_window

    logger.info(
        "[Pool] One-shot processing: %d papers, hypothesis budget=%d (equivalent to %d windows x %d/window)",
        len(papers), pool_hypothesis_budget, n_windows, hypothesis_max_per_window,
    )

    total_tokens = 0
    # e.g. "2024-01~2025-01": year-month of both ends of the evolution period.
    period_label = f"{evolution_start[:7]}~{evolution_end[:7]}"

    # Step 1: Resolve all paper contents at once
    resolved_papers = await resolve_paper_contents(
        papers, cache_dir, 12, "Pool",
    )
    counted_papers = [p for p in resolved_papers if p["counted_fulltext"]]
    abstract_only_skipped = len(resolved_papers) - len(counted_papers)

    if not counted_papers:
        logger.warning("[Pool] No counted fulltext papers after resolution, skipping.")
        return 0

    logger.info(
        "[Pool] Resolved: %d counted fulltext papers, %d abstract-only skipped",
        len(counted_papers), abstract_only_skipped,
    )

    # Step 2: Build papers text from all counted fulltext papers.
    # No pre-truncation here. NOTE(review): this relies on the engines
    # recovering from context overflow themselves (catch
    # context_length_exceeded, retry at 80%) — that logic is not visible in
    # this file; confirm in core.engines.
    papers_parts = []
    for p in counted_papers:
        papers_parts.append(
            f"### {p['title']}\n- arxiv_id: {p['arxiv_id']}\n"
            f"- published: {p['published']}\n"
            f"- content_source: {p['content_source']}\n\n{p['content']}"
        )
    papers_text = "\n\n---\n\n".join(papers_parts)

    # The ingest log is the header followed directly by the full papers text.
    log_header = [
        f"# Pool One-Shot — {len(counted_papers)} counted fulltext papers",
        "",
        f"Period: {evolution_start} ~ {evolution_end}",
        f"Total papers fetched: {len(papers)}",
        f"Counted fulltext: {len(counted_papers)}",
        f"Abstract-only skipped: {abstract_only_skipped}",
        "",
    ]
    store.write_log("pool-ingest.md", "\n".join(log_header) + papers_text)

    # Step 3: Update all topic files concurrently
    # Pool has too many papers to send all of them to each topic file (causes timeouts).
    # Use embedding similarity to select the most relevant papers per knowledge file.
    MAX_PAPERS_PER_TOPIC_UPDATE = 20

    # Snapshot of the topic files that exist before this phase; _index.md is
    # excluded from updates.
    knowledge_files_map = {
        file_name: content
        for file_name, content in store.get_knowledge_files().items()
        if file_name != "_index.md"
    }
    embeddings_dict = store.get_embeddings_dict()

    # Embed all papers (abstracts) for similarity ranking.
    # paper_vecs[i] corresponds to counted_papers[i].
    paper_abstracts = [p["abstract"] for p in counted_papers]
    paper_vecs = await asyncio.gather(*[get_embedding(a) for a in paper_abstracts])

    # For each knowledge file, rank papers by embedding similarity and take top-K
    from core.engines import cosine_similarity as _cos_sim

    def _select_relevant_papers(file_name: str, topic_content: str) -> str:
        """Return the prompt text of the papers most similar to one topic file.

        ``topic_content`` is accepted but not used; ranking relies on the
        embedding stored under ``file_name``.
        """
        topic_vec = embeddings_dict.get(file_name)
        # NOTE(review): truthiness test assumes the stored embedding is a
        # plain sequence (or None); a numpy array here would raise — confirm.
        if not topic_vec:
            # No embedding available, use all papers (truncated)
            return papers_text[:get_max_input_chars()]

        scored = []
        for i, p in enumerate(counted_papers):
            sim = _cos_sim(topic_vec, paper_vecs[i])
            scored.append((sim, p))
        # Highest similarity first; the key avoids comparing paper dicts on ties.
        scored.sort(key=lambda x: x[0], reverse=True)
        selected = scored[:MAX_PAPERS_PER_TOPIC_UPDATE]

        # Same per-paper layout as papers_text above.
        parts = []
        for _, p in selected:
            parts.append(
                f"### {p['title']}\n- arxiv_id: {p['arxiv_id']}\n"
                f"- published: {p['published']}\n"
                f"- content_source: {p['content_source']}\n\n{p['content']}"
            )
        return "\n\n---\n\n".join(parts)

    # At most 3 topic-update engine calls in flight at once.
    semaphore = asyncio.Semaphore(3)

    async def update_topic(file_name, topic_content):
        """Run the topic-update engine for one knowledge file."""
        # Paper selection happens before acquiring the semaphore; only the
        # engine call is throttled.
        relevant_papers_text = _select_relevant_papers(file_name, topic_content)
        async with semaphore:
            res = await run_topic_update_engine(file_name, topic_content, relevant_papers_text, period_label)
            return res

    logger.info(
        "[Pool] Updating %d knowledge files (max %d relevant papers each)",
        len(knowledge_files_map), MAX_PAPERS_PER_TOPIC_UPDATE,
    )
    results = await asyncio.gather(*[
        update_topic(fn, fc) for fn, fc in knowledge_files_map.items()
    ])

    total_tokens += sum(r["tokens"] for r in results)

    for res in results:
        store.write_knowledge_file(res["fileName"], res["fileContent"])
        logger.info("[Pool] Updated knowledge file: %s", res["fileName"])

    # Refresh the stored embeddings so they match the updated file contents.
    new_vecs = await asyncio.gather(*[get_embedding(r["fileContent"]) for r in results])
    for res, vec in zip(results, new_vecs):
        embeddings_dict[res["fileName"]] = vec
    store.save_embeddings_dict(embeddings_dict)

    # Step 4: Discover new topic files (truncate papers_text to fit context).
    # existing_topic_names is the pre-update snapshot of file names.
    existing_topic_names = list(knowledge_files_map.keys())
    discover_papers_text = papers_text[:get_max_input_chars()]
    discover_res = await run_topic_discover_engine(existing_topic_names, discover_papers_text, period_label)
    total_tokens += discover_res["tokens"]
    if discover_res["operations"]:
        discover_vecs = await asyncio.gather(
            *[get_embedding(op["fileContent"]) for op in discover_res["operations"]]
        )
        for op, vec in zip(discover_res["operations"], discover_vecs):
            store.write_knowledge_file(op["fileName"], op["fileContent"])
            embeddings_dict[op["fileName"]] = vec
            logger.info("[Pool] New topic file created: %s", op["fileName"])
        store.save_embeddings_dict(embeddings_dict)

    # Step 5: Generate hypotheses — call n_windows times (same as CKM),
    # each with budget=hypothesis_max_per_window, but always using the
    # same one-shot knowledge state K_batch. This ensures the number of
    # LLM calls matches CKM; the only difference is the knowledge state
    # driving hypothesis generation.
    all_knowledge = store.get_joined_knowledge_content()

    logger.info(
        "[Pool] Generating hypotheses: %d rounds x %d/round (matching CKM window count)",
        n_windows, hypothesis_max_per_window,
    )

    # Rounds run sequentially: each round numbers its hypotheses from the
    # count already in the store.
    for round_idx in range(n_windows):
        round_label = f"pool-round-{round_idx + 1}"
        starting_hypothesis_index = store.count_hypotheses() + 1

        # The full, untruncated papers_text is passed here (see Step 2 note).
        hyp_res = await run_baseline_hypothesis_engine(
            all_knowledge,
            papers_text,
            round_label,
            hypothesis_max_per_window,
        )
        total_tokens += hyp_res["tokens"]

        accepted_hypotheses = hyp_res.get("hypotheses", [])
        for offset, hypothesis in enumerate(accepted_hypotheses):
            hyp_id_string = str(starting_hypothesis_index + offset).zfill(3)
            # Rewrite the engine's own heading number to the store-wide ID
            # (first occurrence at the very start of the content only).
            rendered_content = re.sub(
                r"^# Hypothesis H\d+",
                f"# Hypothesis H{hyp_id_string}",
                hypothesis["content"],
                count=1,
            )
            store.write_hypothesis(f"{period_label}-{hyp_id_string}", rendered_content)
            logger.info("[Pool] %s: Hypothesis H%s generated", round_label, hyp_id_string)

        rejections = hyp_res.get("rejections", [])
        if rejections:
            # One log file per rejected candidate.
            for rejection in rejections:
                rejection_label = str(rejection.get("candidate_index", 0)).zfill(3)
                store.write_log(
                    f"hypothesis-rejected-{round_label}-{rejection_label}.md",
                    (
                        f"# Rejected Hypothesis Candidate {rejection_label}\n\n"
                        f"Round: {round_label}\n"
                        f"Reason: {rejection.get('rejection_reason', '').strip() or '(unknown)'}\n\n"
                        f"## Candidate Content\n\n{rejection.get('content', '').strip() or '(empty)'}\n"
                    ),
                )
                logger.info(
                    "[Pool] %s: Hypothesis candidate %s rejected (%s)",
                    round_label,
                    rejection_label,
                    rejection.get("rejection_reason", "").strip() or "unknown reason",
                )
        elif not accepted_hypotheses:
            # Nothing accepted and nothing rejected: either the engine failed
            # or it simply produced no hypothesis this round.
            # NOTE(review): .get("error", "") still yields None if the engine
            # sets "error": None explicitly, and .strip() would then raise —
            # confirm the engine's result shape.
            engine_error = hyp_res.get("error", "").strip()
            if engine_error:
                store.write_log(
                    f"hypothesis-engine-error-{round_label}.md",
                    (
                        f"# Hypothesis Engine Error\n\n"
                        f"Round: {round_label}\n"
                        f"Error: {engine_error}\n\n"
                        f"## Raw Output\n\n{hyp_res.get('raw_content', '').strip() or '(empty)'}\n"
                    ),
                )
                logger.warning("[Pool] %s: Hypothesis engine parse error (%s)", round_label, engine_error)
            else:
                logger.info("[Pool] %s: No hypothesis triggered", round_label)

    # Step 6: LLM-based deduplication of hypotheses.
    # This considers every *.md file in the hypotheses directory, not only
    # the ones written by the loop above.
    hyp_dir = store.dirs["hypotheses"]
    hyp_files = sorted(hyp_dir.glob("*.md"))
    total_before_dedup = len(hyp_files)

    if total_before_dedup > 1:
        logger.info("[Pool] Deduplicating %d hypotheses via LLM judge...", total_before_dedup)

        # Extract ID and statement from each hypothesis file
        dedup_input = []
        for fp in hyp_files:
            content = fp.read_text(encoding="utf-8")
            # Extract H-id from filename: hyp-2024-01~2025-01-001.md -> H001.
            # Falls back to the file stem when the name does not end in
            # "-<3 digits>.md".
            id_match = re.search(r"-(\d{3})\.md$", fp.name)
            hyp_id = f"H{id_match.group(1)}" if id_match else fp.stem

            # Extract statement section: text after "## Statement" up to the
            # next "##" or end of file; otherwise the first 500 characters.
            stmt_match = re.search(r"## Statement\s*([\s\S]*?)(?=##|\Z)", content)
            statement = stmt_match.group(1).strip() if stmt_match else content[:500]

            dedup_input.append({
                "id": hyp_id,
                "statement": statement,
                "file_name": fp.name,
                "file_path": fp,
            })

        dedup_res = await run_dedup_engine(dedup_input)
        total_tokens += dedup_res["tokens"]

        # Remove duplicate hypothesis files and log removals
        removed_count = 0
        for item in dedup_input:
            if item["id"] in dedup_res["remove_ids"]:
                item["file_path"].unlink()
                removed_count += 1
                logger.info("[Pool] Dedup: removed %s (%s)", item["file_name"], item["id"])

        # Log dedup details (only written when the engine reported removals)
        if dedup_res["removals"]:
            dedup_log_lines = [
                f"# Hypothesis Deduplication\n",
                f"Before: {total_before_dedup}",
                f"After: {total_before_dedup - removed_count}",
                f"Removed: {removed_count}\n",
                "## Removals\n",
            ]
            for removal in dedup_res["removals"]:
                dedup_log_lines.append(
                    f"- {removal.get('id', '?')} → duplicate of {removal.get('duplicate_of', '?')}: "
                    f"{removal.get('reason', 'no reason')}"
                )
            store.write_log("pool-dedup.md", "\n".join(dedup_log_lines))

        logger.info(
            "[Pool] Dedup complete: %d → %d hypotheses",
            total_before_dedup, total_before_dedup - removed_count,
        )
    else:
        # Zero or one hypothesis file: nothing to compare.
        logger.info("[Pool] Only %d hypothesis, skipping dedup", total_before_dedup)

    final_hyp_count = store.count_hypotheses()
    logger.info(
        "[Pool] Complete. Papers=%d, Hypotheses=%d (before dedup=%d), Tokens=%d",
        len(counted_papers), final_hyp_count, total_before_dedup, total_tokens,
    )
    return total_tokens


# ---------------------------------------------------------------------------
# Phase 3: Validation — identical to CKM
# ---------------------------------------------------------------------------

async def phase_validation(all_papers: list, store: FileSystemKnowledgeStore,
                           validation_start: str, validation_end: str, topic: str,
                           ablation_mode: str = "none", cache_dir: Path = None) -> None:
    """Evaluate hypotheses against papers published after evolution ended.

    Thin wrapper that forwards every argument to
    ``tools.calculate_metrics.calculate_metrics`` and awaits it; nothing is
    returned.

    Args:
        all_papers: Papers from all phases, as assembled by the caller.
        store: Knowledge store handed to the metrics routine.
        validation_start: Start of the validation period (YYYY-MM-DD).
        validation_end: End of the validation period (YYYY-MM-DD).
        topic: Research topic.
        ablation_mode: Ablation setting; the CLI offers "none" and "shuffled".
        cache_dir: Fulltext cache directory, or None (the default).
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to defer loading metrics dependencies — confirm.
    from tools.calculate_metrics import calculate_metrics
    await calculate_metrics(all_papers, validation_start, validation_end,
                            topic, ablation_mode, store, cache_dir)


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------

async def run():
    """Command-line entry point: fetch papers, then run Init, Pool and Validation.

    Parses the CLI arguments, sets up logging and the output directories,
    reads the experiment schedule and budgets from ``config["experiment"]``,
    fetches the papers for each phase, runs the three phases in order and
    finally writes ``token_usage.json`` into the store's base directory.

    Returns early (after logging an error) when no papers were found at all.

    Raises:
        ValueError: If ``evolution_step_months`` or
            ``phase_search_window_months`` is below 1. Such values would make
            window construction loop without making progress.
    """
    parser = argparse.ArgumentParser(description="Pool Baseline: One-Shot Batch Processing (Control for CKM)")
    parser.add_argument("topic", nargs="?", default="AI for software engineering")
    parser.add_argument("--metabolism_dir", type=str)
    parser.add_argument("--report_dir", type=str)
    parser.add_argument("--ablation", choices=["none", "shuffled"], default="none")

    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    # Resolve paths
    project_dir = Path(__file__).parent.parent.parent.absolute()  # ckm-eval/
    run_dir = project_dir / "results" / ("pool_" + datetime.now().strftime("%Y%m%d_%H%M%S"))
    metabolism_dir = Path(args.metabolism_dir).absolute() if args.metabolism_dir else run_dir / "metabolism"
    metabolism_dir.mkdir(parents=True, exist_ok=True)
    store = FileSystemKnowledgeStore(metabolism_dir)
    cache_dir = metabolism_dir / "fulltext_cache"

    report_dir = Path(args.report_dir).absolute() if args.report_dir else run_dir / "reports"
    # Published through the shared config so downstream code can locate it.
    config["paths"]["reports"] = report_dir

    # Read time parameters from config (identical to CKM)
    exp = config["experiment"]
    init_start = exp["init_start"]
    init_end = exp["init_end"]
    evo_start = exp["evolution_start"]
    evo_end = exp["evolution_end"]
    step_months = exp["evolution_step_months"]
    val_start = exp["validation_start"]
    val_end = exp["validation_end"]
    max_papers = exp["max_papers"]
    # Every per-phase budget is capped by the global max_papers.
    init_budget = min(exp["init_max_papers"], max_papers)
    evo_budget = min(exp["evolution_max_papers"], max_papers)
    val_min_papers = min(exp["validation_min_papers"], max_papers)
    val_budget = max(min(exp["validation_max_papers"], max_papers), val_min_papers)
    phase_search_window_months = exp["phase_search_window_months"]

    # Fail fast on window lengths that cannot advance through a date range.
    if step_months < 1 or phase_search_window_months < 1:
        raise ValueError(
            "evolution_step_months and phase_search_window_months must be >= 1 "
            f"(got {step_months!r} and {phase_search_window_months!r})"
        )
    evolution_search_window_months = max(1, min(phase_search_window_months, step_months))

    topic = args.topic
    n_windows = len(build_phase_windows(evo_start, evo_end, step_months))
    # Same lookup and default as phase_pool, so the budget logged here always
    # matches what Phase 2 uses and a missing key does not abort the run.
    hypothesis_max_per_window = max(1, exp.get("hypothesis_max_per_window", 3))
    pool_hypothesis_budget = n_windows * hypothesis_max_per_window

    logger.info("=" * 70)
    logger.info("Pool Baseline | Topic: %s", topic)
    logger.info("  Phase 1 (Init):       %s ~ %s", init_start, init_end)
    logger.info("  Phase 2 (Pool):       %s ~ %s (ONE-SHOT, no windowed evolution)", evo_start, evo_end)
    logger.info("  Hypothesis budget:    %d (equivalent to %d CKM windows x %d/window)",
                pool_hypothesis_budget, n_windows, hypothesis_max_per_window)
    logger.info("  Phase 3 (Validation): %s ~ %s", val_start, val_end)
    logger.info(
        "  Search budgets:       init=%d, evolution=%d, validation=%d (min=%d)",
        init_budget, evo_budget, val_budget, val_min_papers,
    )
    logger.info("=" * 70)

    # Fetch papers per phase — identical to CKM (same queries, same budgets)
    init_papers = fetch_phase_papers(
        topic, init_start, init_end, init_budget, phase_search_window_months,
        phase_name="Init",
    )
    evo_papers = fetch_phase_papers(
        topic, evo_start, evo_end, evo_budget, evolution_search_window_months,
        phase_name="Pool-Evolution",
    )
    val_papers = fetch_phase_papers(
        topic,
        val_start,
        val_end,
        val_budget,
        phase_search_window_months,
        min_papers=val_min_papers,
        phase_name="Validation",
    )
    # dedupe_and_sort_papers already returns the papers ordered by
    # "published", so no further sort is needed before logging the range.
    all_papers = dedupe_and_sort_papers(init_papers + evo_papers + val_papers)

    if not all_papers:
        logger.error("No papers found for topic '%s', aborting.", topic)
        return

    logger.info("Fetched %d papers (%s ~ %s)",
                len(all_papers), all_papers[0]["published"][:10], all_papers[-1]["published"][:10])

    logger.info("Paper split: init=%d, evolution=%d, validation=%d",
                len(init_papers), len(evo_papers), len(val_papers))

    # Phase 1: Init — identical to CKM
    logger.info("--- Phase 1: Init ---")
    init_tokens = await phase_init(topic, init_papers, store, cache_dir)

    # Phase 2: Pool — one-shot batch processing (replaces CKM evolution)
    logger.info("--- Phase 2: Pool (One-Shot) ---")
    pool_tokens = await phase_pool(
        topic, evo_papers, store, cache_dir,
        evo_start, evo_end, step_months,
    )

    # Phase 3: Validation — identical to CKM
    logger.info("--- Phase 3: Validation ---")
    await phase_validation(all_papers, store, val_start, val_end, topic, args.ablation, cache_dir)

    total_tokens = init_tokens + pool_tokens
    logger.info("All phases complete. Total LLM tokens: %d", total_tokens)

    # Pool tokens are stored under "evolution_tokens".
    # NOTE(review): presumably so the file has the same keys as a CKM run — confirm.
    token_usage = {
        "init_tokens": init_tokens,
        "evolution_tokens": pool_tokens,
        "total_generation_tokens": total_tokens,
    }
    (store.base_dir / "token_usage.json").write_text(
        json.dumps(token_usage, indent=2), encoding="utf-8"
    )


# Script entry point: run the async pipeline to completion when executed directly.
if __name__ == "__main__":
    asyncio.run(run())
