import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

import asyncio
import argparse
import logging
import os
import re
from datetime import datetime

from config import config
from core.store import FileSystemKnowledgeStore
import json

from core.engines import (
    run_read_engine, run_init_engine, run_topic_update_engine, run_topic_discover_engine,
    run_ckm_hypothesis_engine, run_trigger_detection,
    get_embedding,
)
from tools.arxiv_search import search_arxiv_topic
from tools.arxiv_fulltext import get_paper_content_record

# Module-level logger for this entry script; output format/level are set by logging.basicConfig in run().
logger = logging.getLogger("ckm.main")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def add_months(date_str: str, months: int) -> str:
    """Shift ``date_str`` by ``months`` calendar months and snap to day 1.

    The input must be formatted ``YYYY-MM-DD``; the result uses the same format
    and always falls on the first day of the target month. Negative ``months``
    move backwards (floor division keeps the year/month split correct).
    """
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    year_offset, month_index = divmod(parsed.month - 1 + months, 12)
    shifted = parsed.replace(year=parsed.year + year_offset, month=month_index + 1, day=1)
    return shifted.strftime("%Y-%m-%d")


def build_phase_windows(start_date: str, end_date: str, window_months: int) -> list[tuple[str, str]]:
    """Split ``[start_date, end_date)`` into consecutive month-based windows.

    Intermediate boundaries are produced by ``add_months`` and therefore fall
    on the first day of a month; the final window is clamped to ``end_date``.
    Dates are ``YYYY-MM-DD`` strings, so lexicographic comparison matches
    chronological order. An empty or inverted range yields ``[]``.

    Raises:
        ValueError: if ``window_months`` is less than 1. ``add_months`` with a
            non-positive offset never moves past the current boundary, so the
            loop would not terminate for a non-empty range.
    """
    if window_months < 1:
        raise ValueError(f"window_months must be >= 1, got {window_months}")
    windows = []
    current = start_date
    while current < end_date:
        next_date = min(add_months(current, window_months), end_date)
        windows.append((current, next_date))
        current = next_date
    return windows


def allocate_window_budgets(total_budget: int, n_windows: int) -> list[int]:
    """Spread ``total_budget`` as evenly as possible over ``n_windows`` slots.

    Earlier windows absorb the remainder one unit each, so values differ by at
    most one and sum to ``total_budget``. Returns ``[]`` when either argument
    is non-positive.
    """
    if total_budget > 0 and n_windows > 0:
        share, extra = divmod(total_budget, n_windows)
        return [share + 1 if idx < extra else share for idx in range(n_windows)]
    return []


def dedupe_and_sort_papers(papers: list) -> list:
    """Drop duplicate ``arxiv_id`` entries and order the rest by ``published``.

    The first occurrence of each ``arxiv_id`` wins. The sort is stable, so
    papers with equal ``published`` values keep their first-seen order.
    """
    first_seen = {}
    for record in papers:
        key = record["arxiv_id"]
        if key not in first_seen:
            first_seen[key] = record
    ordered = list(first_seen.values())
    ordered.sort(key=lambda item: item["published"])
    return ordered


async def resolve_paper_contents(
    papers: list,
    cache_dir: Path,
    concurrency: int,
    phase_name: str,
) -> list:
    """Attach resolved content to each paper via ``get_paper_content_record``.

    Every paper dict is copied and extended with ``content``, ``content_source``
    and ``counted_fulltext`` from the returned record. At most
    ``max(1, concurrency)`` lookups are in flight at once. Timeout/retry
    settings come from ``config["experiment"]`` (defaults: 30 s timeout,
    2 retries, 2 s retry delay). Output order matches input order.
    """
    experiment = config["experiment"]
    fetch_kwargs = {
        "timeout_s": max(1, experiment.get("fulltext_timeout_s", 30)),
        "retries": max(1, experiment.get("fulltext_retries", 2)),
        "retry_delay_s": max(0, experiment.get("fulltext_retry_delay_s", 2)),
    }
    gate = asyncio.Semaphore(max(1, concurrency))

    async def _resolve(entry: dict) -> dict:
        async with gate:
            record = await get_paper_content_record(
                entry["arxiv_id"],
                entry["abstract"],
                cache_dir,
                **fetch_kwargs,
            )
        return {
            **entry,
            "content": record["content"],
            "content_source": record["source"],
            "counted_fulltext": record["counted_fulltext"],
        }

    resolved = await asyncio.gather(*(_resolve(entry) for entry in papers))
    n_fulltext = sum(bool(item["counted_fulltext"]) for item in resolved)
    logger.info(
        "[%s] Content resolved: total=%d, counted_fulltext=%d, abstract_only=%d",
        phase_name,
        len(resolved),
        n_fulltext,
        len(resolved) - n_fulltext,
    )
    return resolved


def fetch_phase_papers(
    topic: str,
    start_date: str,
    end_date: str,
    total_budget: int,
    window_months: int,
    min_papers: int = 0,
    phase_name: str = "Search",
) -> list:
    """Query ``search_arxiv_topic`` window-by-window for one phase.

    The effective budget is ``max(total_budget, min_papers)``, split evenly
    across the windows from ``build_phase_windows``; windows whose share is
    zero are not queried. Results are deduplicated by ``arxiv_id`` and sorted
    by ``published``. If fewer than ``min_papers`` unique papers come back, a
    warning is logged and the short list is still returned.
    """
    phase_windows = build_phase_windows(start_date, end_date, window_months)
    per_window = allocate_window_budgets(max(total_budget, min_papers), len(phase_windows))

    gathered = []
    for (win_start, win_end), quota in zip(phase_windows, per_window):
        if quota <= 0:
            continue
        batch = search_arxiv_topic(topic, quota, start_date=win_start, end_date=win_end)
        logger.info(
            "[%s] %s ~ %s: fetched %d papers (budget=%d)",
            phase_name, win_start, win_end, len(batch), quota,
        )
        gathered += batch

    unique_papers = dedupe_and_sort_papers(gathered)
    too_few = bool(min_papers) and len(unique_papers) < min_papers
    if too_few:
        logger.warning(
            "[%s] Only %d papers found in %s ~ %s (minimum desired=%d)",
            phase_name, len(unique_papers), start_date, end_date, min_papers,
        )
    return unique_papers


def display_topic_name(file_name: str) -> str:
    """Turn a knowledge file name into a human-readable topic label.

    ``_index`` files map to an empty string. Otherwise the directory and
    extension are dropped, a single leading ``topic-`` prefix is removed, and
    hyphens become spaces.
    """
    base = Path(file_name).stem
    if base == "_index":
        return ""
    return base.removeprefix("topic-").replace("-", " ")


def format_papers_for_prompt(papers: list) -> str:
    """Render papers as 1-based numbered markdown entries.

    Each entry lists title, arxiv_id, published date and abstract and ends
    with a ``---`` rule; entries are separated by a blank line.
    """
    return "\n\n".join(
        f"### {number}. {paper['title']}\n"
        f"- **arxiv_id**: {paper['arxiv_id']}\n"
        f"- **published**: {paper['published']}\n"
        f"- **abstract**: {paper['abstract']}\n---"
        for number, paper in enumerate(papers, start=1)
    )


def concat_abstracts(papers: list) -> str:
    """Join every paper's ``abstract`` into one newline-separated string."""
    abstracts = [paper["abstract"] for paper in papers]
    return "\n".join(abstracts)


# ---------------------------------------------------------------------------
# Phase 1: Init — build knowledge baseline
# ---------------------------------------------------------------------------

async def phase_init(topic: str, papers: list, store: FileSystemKnowledgeStore,
                     cache_dir: Path) -> int:
    """Build initial knowledge base from a broad set of papers.

    Steps: resolve content for every paper; extract each counted-fulltext
    paper with ``run_read_engine``; record processed IDs in
    ``<store.base_dir>/config.json`` with ``currentDay`` 0; synthesize topic
    files with ``run_init_engine``; write them and their embeddings; write an
    init log; finally set ``currentDay`` to 1 to mark init as complete.

    Args:
        topic: research topic passed to the init engine and stored in config.
        papers: paper metadata dicts (``arxiv_id``, ``title``, ``published``,
            ``abstract``) for the init time range.
        store: knowledge store that receives files, embeddings and logs.
        cache_dir: fulltext cache directory used during content resolution.

    Returns:
        Tokens reported by the init engine (``result["tokens"]``), or 0 when
        ``papers`` is empty. Per-paper extraction output is consumed as text
        only, so any tokens it spends are not counted here.

    Raises:
        RuntimeError: if no paper resolves to counted fulltext, or the init
            engine returns no topic files (error and raw output are written
            to logs first).
    """
    if not papers:
        logger.warning("[Init] No papers in init time range, skipping.")
        return 0

    logger.info("[Init] Building knowledge baseline from %d metadata papers", len(papers))
    resolved_papers = await resolve_paper_contents(papers, cache_dir, 12, "Init")
    fulltext_papers = [paper for paper in resolved_papers if paper["counted_fulltext"]]
    abstract_only_count = len(resolved_papers) - len(fulltext_papers)
    if not fulltext_papers:
        raise RuntimeError("Init phase resolved zero counted fulltext papers")

    # Step 1: Concurrent per-paper extraction. One constant drives both the
    # semaphore and the log line so the reported limit is the real limit.
    extract_concurrency = 6
    semaphore = asyncio.Semaphore(extract_concurrency)

    async def extract_paper(p):
        async with semaphore:
            return await run_read_engine(p["title"], p["arxiv_id"], p["published"], p["content"])

    logger.info(
        "[Init] Extracting %d counted fulltext papers concurrently (max %d), abstract-only skipped=%d...",
        len(fulltext_papers),
        extract_concurrency,
        abstract_only_count,
    )
    extractions = await asyncio.gather(*[extract_paper(p) for p in fulltext_papers])

    # Step 2: Track processed_ids in config.json; currentDay stays 0 until init finishes.
    config_path = store.base_dir / "config.json"
    config_data = {
        "topic": topic,
        "processed_ids": [p["arxiv_id"] for p in fulltext_papers],
        "currentDay": 0,
    }
    config_path.write_text(json.dumps(config_data, indent=2), encoding="utf-8")
    logger.info("[Init] Tracked %d counted fulltext paper IDs in config.json", len(fulltext_papers))

    # Step 3: Synthesize all extractions into knowledge files
    papers_text = "\n\n---\n\n".join(extractions)
    result = await run_init_engine(topic, papers_text, len(fulltext_papers))
    # A missing/None "operations" key is routed to the no-topic-files error
    # path below instead of surfacing as a bare KeyError/TypeError.
    operations = result.get("operations") or []
    topic_ops = [op for op in operations if op["fileName"] != "_index.md"]
    # One date for every file name/log this phase writes, even across midnight.
    today = datetime.now().strftime("%Y-%m-%d")

    if not topic_ops:
        error_text = result.get("error") or "unknown init failure"
        raw_output = (result.get("raw_content") or "").strip() or "(empty init output)"
        store.write_log(f"{today}-init-error.md", error_text)
        store.write_log(f"{today}-init-raw.md", raw_output)
        raise RuntimeError(f"Init engine returned no topic files: {error_text}")

    # Step 4: Write knowledge files, then compute their embeddings concurrently
    embeddings_dict = store.get_embeddings_dict()
    for op in operations:
        store.write_knowledge_file(op["fileName"], op["fileContent"])
    vecs = await asyncio.gather(*[get_embedding(op["fileContent"]) for op in operations])
    for op, vec in zip(operations, vecs):
        embeddings_dict[op["fileName"]] = vec
        logger.info("[Init] Created knowledge file: %s", op["fileName"])

    store.save_embeddings_dict(embeddings_dict)

    # Step 5: Write init log and mark init complete (currentDay → 1)
    topic_names = ", ".join(
        name for name in (display_topic_name(op["fileName"]) for op in topic_ops) if name
    )
    log_content = (
        f"# Day 0 — Initialization\n\n"
        f"Date: {today}\n"
        f"Papers: {len(fulltext_papers)}\n"
        f"Abstract-only skipped: {abstract_only_count}\n"
        f"Topics: {topic_names}\n"
        f"Status: Baseline construction complete\n"
    )
    store.write_log(f"{today}-init.md", log_content)

    # Update config: currentDay 0 → 1 (signals init is done; evolution requires >= 1).
    # Re-read from disk so anything else written to the file meanwhile is preserved.
    config_data = json.loads(config_path.read_text(encoding="utf-8"))
    config_data["currentDay"] = 1
    config_path.write_text(json.dumps(config_data, indent=2), encoding="utf-8")

    logger.info("[Init] Baseline complete: %d knowledge files, %d tokens",
                len(operations), result["tokens"])
    return result["tokens"]


# ---------------------------------------------------------------------------
# Phase 2: Evolution — incremental metabolism cycles
# ---------------------------------------------------------------------------

async def phase_evolution(topic: str, papers: list, store: FileSystemKnowledgeStore,
                          cache_dir: Path, evolution_start: str, evolution_end: str,
                          step_months: int) -> int:
    """
    Incremental knowledge metabolism with evolution-aware hypothesis generation.

    Walks ``[evolution_start, evolution_end)`` in steps of ``step_months``
    (boundaries after the first snap to day 1 via ``add_months``). For each
    window: read papers → update knowledge → detect change signal →
    generate hypotheses with full evolution context.

    Paper selection: at most ``evolution_window_paper_limit`` of the most
    recent papers in a window are read; the overflow is deferred. When
    ``evolution_backfill_on_empty`` is set, a window with no papers (or no
    counted fulltext) borrows from that deferred backlog instead.

    Args:
        topic: research topic forwarded to the hypothesis engine.
        papers: paper metadata dicts for the evolution range; each needs at
            least ``arxiv_id``, ``title``, ``published`` and ``abstract``.
        store: knowledge store holding topic files, embeddings, logs and
            hypotheses.
        cache_dir: fulltext cache directory used during content resolution.
        evolution_start: first window start, ``YYYY-MM-DD``.
        evolution_end: exclusive end of the walk, ``YYYY-MM-DD``.
        step_months: window length in months.

    Returns:
        Sum of tokens reported by the topic-update, topic-discover, trigger
        and hypothesis engines (0 when ``papers`` is empty).

    Raises:
        ValueError: if ``step_months`` < 1, since the window loop would never
            advance.
    """
    if not papers:
        logger.warning("[Evolution] No papers in evolution time range, skipping.")
        return 0
    if step_months < 1:
        raise ValueError(f"step_months must be >= 1, got {step_months}")

    exp_cfg = config["experiment"]
    window_limit = max(1, exp_cfg.get("evolution_window_paper_limit", 8))
    backfill_on_empty = exp_cfg.get("evolution_backfill_on_empty", True)
    hypothesis_max_per_window = max(1, exp_cfg.get("hypothesis_max_per_window", 3))

    total_tokens = 0
    current_date = evolution_start
    deferred_papers = []           # overflow papers not yet read, most recent first

    # --- Evolution state tracking ---
    window_summaries = []          # one dict per window: period, summary, triggered, ...
    windows_since_trigger = 0      # consecutive windows since one last produced hypotheses
    window_index = 0               # 1-based number of the current window

    def sort_recent_first(items: list) -> list:
        return sorted(items, key=lambda x: x["published"], reverse=True)

    def _build_evolution_trajectory() -> str:
        """Build a narrative of knowledge changes across windows."""
        if not window_summaries:
            return "(first window — no prior evolution history)"
        lines = []
        for ws in window_summaries:
            triggered_marker = " [HYPOTHESIS TRIGGERED]" if ws.get("triggered") else ""
            lines.append(f"- Period {ws['period']}: {ws['summary']}{triggered_marker}")
        return "\n".join(lines)

    def _build_existing_hypotheses_summary() -> str:
        """Summarize all hypotheses so far (ID plus first 200 chars of the statement)."""
        hyp_dir = store.dirs["hypotheses"]
        hyp_files = sorted(hyp_dir.glob("*.md"))
        if not hyp_files:
            return "(none yet)"
        summaries = []
        for fp in hyp_files:
            content = fp.read_text(encoding="utf-8")
            # Trailing numeric ID. IDs are zfill(3)-padded, which does not
            # truncate, so accept 3 or more digits (e.g. "-1000.md").
            id_match = re.search(r"-(\d{3,})\.md$", fp.name)
            hyp_id = f"H{id_match.group(1)}" if id_match else fp.stem
            # Extract statement
            stmt_match = re.search(r"## Statement\s*([\s\S]*?)(?=##|\Z)", content)
            stmt = stmt_match.group(1).strip()[:200] if stmt_match else "(no statement)"
            summaries.append(f"- {hyp_id}: {stmt}")
        return "\n".join(summaries)

    while current_date < evolution_end:
        next_date = add_months(current_date, step_months)
        period_label = current_date[:7]
        window_index += 1

        # --- Paper selection: most recent first, capped at window_limit ---
        current_window_papers = sort_recent_first(
            [p for p in papers if current_date <= p["published"] < next_date]
        )
        fresh_papers = current_window_papers[:window_limit]
        overflow_papers = current_window_papers[window_limit:]
        if overflow_papers:
            deferred_papers.extend(overflow_papers)
            deferred_papers = sort_recent_first(deferred_papers)
            logger.info("[Evolution] %s: deferred %d papers", period_label, len(overflow_papers))

        borrowed_papers = []
        if not fresh_papers and backfill_on_empty and deferred_papers:
            borrowed_papers = deferred_papers[:window_limit]
            deferred_papers = deferred_papers[window_limit:]
            logger.info("[Evolution] %s: backfilling %d papers", period_label, len(borrowed_papers))

        window_papers = fresh_papers or borrowed_papers
        if not window_papers:
            logger.info("[Evolution] %s: no papers, skipping", period_label)
            windows_since_trigger += 1
            window_summaries.append({
                "period": period_label, "summary": "No papers available", "triggered": False,
            })
            current_date = next_date
            continue

        resolved_window_papers = await resolve_paper_contents(
            window_papers, cache_dir, min(8, len(window_papers)), f"Evolution {period_label}",
        )
        counted_window_papers = [p for p in resolved_window_papers if p["counted_fulltext"]]
        abstract_only_skipped = len(resolved_window_papers) - len(counted_window_papers)

        # Current-window papers had no fulltext: fall back to the deferred backlog once.
        if not counted_window_papers and backfill_on_empty and deferred_papers and not borrowed_papers:
            borrowed_papers = deferred_papers[:window_limit]
            deferred_papers = deferred_papers[window_limit:]
            logger.info("[Evolution] %s: backfilling after fulltext miss", period_label)
            resolved_window_papers = await resolve_paper_contents(
                borrowed_papers, cache_dir, min(8, len(borrowed_papers)),
                f"Evolution {period_label} backfill",
            )
            counted_window_papers = [p for p in resolved_window_papers if p["counted_fulltext"]]
            abstract_only_skipped += len(resolved_window_papers) - len(counted_window_papers)

        if not counted_window_papers:
            logger.info("[Evolution] %s: no counted fulltext papers, skipping", period_label)
            windows_since_trigger += 1
            window_summaries.append({
                "period": period_label, "summary": "Papers found but no fulltext available",
                "triggered": False,
            })
            current_date = next_date
            continue

        # All counted papers come from a single pool: the deferred backlog
        # whenever anything was borrowed, otherwise the current window.
        source_mode = "backfill" if borrowed_papers else "current-window"
        logger.info(
            "[Evolution] %s: %d counted fulltext papers (%s), abstract-only skipped=%d",
            period_label, len(counted_window_papers), source_mode, abstract_only_skipped,
        )

        # --- Build papers text ---
        papers_text = "\n\n---\n\n".join(
            f"### {p['title']}\n- arxiv_id: {p['arxiv_id']}\n"
            f"- published: {p['published']}\n"
            f"- source_mode: {source_mode}\n"
            f"- content_source: {p['content_source']}\n\n{p['content']}"
            for p in counted_window_papers
        )
        log_header = [
            f"# Period {period_label} — {len(counted_window_papers)} counted fulltext papers",
            "", f"Window: {window_index}", f"Windows since last trigger: {windows_since_trigger}",
            f"Current-window papers: {len(fresh_papers)}",
            f"Backfilled unread papers: {len(borrowed_papers)}",
            f"Abstract-only skipped: {abstract_only_skipped}",
            f"Deferred unread papers remaining: {len(deferred_papers)}", "",
        ]
        store.write_log(f"window-{period_label}-ingest.md", "\n".join(log_header) + papers_text)

        # --- Snapshot knowledge BEFORE update ---
        knowledge_before = store.get_joined_knowledge_content()

        # --- Update all topic files concurrently ---
        knowledge_files_map = {
            fn: c for fn, c in store.get_knowledge_files().items() if fn != "_index.md"
        }
        embeddings_dict = store.get_embeddings_dict()
        semaphore = asyncio.Semaphore(6)

        async def update_topic(file_name, topic_content):
            async with semaphore:
                return await run_topic_update_engine(file_name, topic_content, papers_text, period_label)

        results = await asyncio.gather(*[
            update_topic(fn, fc) for fn, fc in knowledge_files_map.items()
        ])
        total_tokens += sum(r["tokens"] for r in results)

        for res in results:
            store.write_knowledge_file(res["fileName"], res["fileContent"])
            logger.info("[Evolution] %s: updated %s", period_label, res["fileName"])

        new_vecs = await asyncio.gather(*[get_embedding(r["fileContent"]) for r in results])
        for res, vec in zip(results, new_vecs):
            embeddings_dict[res["fileName"]] = vec
        store.save_embeddings_dict(embeddings_dict)

        # --- Discover new topic files ---
        existing_topic_names = list(knowledge_files_map.keys())
        discover_res = await run_topic_discover_engine(existing_topic_names, papers_text, period_label)
        total_tokens += discover_res["tokens"]
        new_topics_this_window = []
        if discover_res["operations"]:
            discover_vecs = await asyncio.gather(
                *[get_embedding(op["fileContent"]) for op in discover_res["operations"]]
            )
            for op, vec in zip(discover_res["operations"], discover_vecs):
                store.write_knowledge_file(op["fileName"], op["fileContent"])
                embeddings_dict[op["fileName"]] = vec
                new_topics_this_window.append(op["fileName"])
                logger.info("[Evolution] %s: new topic: %s", period_label, op["fileName"])
            store.save_embeddings_dict(embeddings_dict)

        # --- Snapshot knowledge AFTER update ---
        knowledge_after = store.get_joined_knowledge_content()

        # --- Trigger detection: characterize the change in this window ---
        trigger_res = await run_trigger_detection(
            knowledge_before, knowledge_after, papers_text,
            period_label, windows_since_trigger,
        )
        total_tokens += trigger_res["tokens"]
        # Tolerate a None reason/key_changes from the engine.
        trigger_reason = trigger_res.get("reason") or ""

        # Build window summary for trajectory
        window_summary = trigger_reason.strip() or f"{len(counted_window_papers)} papers processed"
        if new_topics_this_window:
            window_summary += f"; new topics: {', '.join(new_topics_this_window)}"
        key_changes = trigger_res.get("key_changes") or []
        if key_changes:
            window_summary += "; changes: " + "; ".join(str(c) for c in key_changes[:3])

        change_type = trigger_res["change_type"]

        # Log change characterization
        store.write_log(
            f"change-{period_label}.md",
            (
                f"# Knowledge Change — {period_label}\n\n"
                f"Window: {window_index}\n"
                f"Change type: {change_type}\n"
                f"Reason: {trigger_reason}\n"
                f"Key changes: {key_changes}\n"
            ),
        )

        logger.info(
            "[Evolution] %s: change_type=%s, generating hypotheses with evolution context",
            period_label, change_type,
        )

        # --- Always generate hypotheses, with full evolution context ---
        evolution_trajectory = _build_evolution_trajectory()
        existing_hypotheses = _build_existing_hypotheses_summary()
        starting_hypothesis_index = store.count_hypotheses() + 1

        hyp_res = await run_ckm_hypothesis_engine(
            topic=topic,
            knowledge_content=knowledge_after,
            papers_text=papers_text,
            time_period=period_label,
            max_hypotheses=hypothesis_max_per_window,
            evolution_trajectory=evolution_trajectory,
            existing_hypotheses=existing_hypotheses,
            trigger_type=change_type,
            trigger_reason=trigger_reason,
            n_windows=window_index,
        )
        total_tokens += hyp_res["tokens"]

        accepted_hypotheses = hyp_res.get("hypotheses") or []
        for offset, hypothesis in enumerate(accepted_hypotheses):
            hyp_id_string = str(starting_hypothesis_index + offset).zfill(3)
            rendered_content = re.sub(
                r"^# Hypothesis H\d+",
                f"# Hypothesis H{hyp_id_string}",
                hypothesis["content"],
                count=1,
            )
            store.write_hypothesis(f"{period_label}-{hyp_id_string}", rendered_content)
            logger.info("[Evolution] %s: hypothesis H%s generated (change=%s)",
                        period_label, hyp_id_string, change_type)

        rejections = hyp_res.get("rejections") or []
        if rejections:
            for rejection in rejections:
                rejection_label = str(rejection.get("candidate_index", 0)).zfill(3)
                rejection_reason = (rejection.get("rejection_reason") or "").strip() or "(unknown)"
                candidate_content = (rejection.get("content") or "").strip() or "(empty)"
                store.write_log(
                    f"hypothesis-rejected-{period_label}-{rejection_label}.md",
                    (
                        f"# Rejected Hypothesis Candidate {rejection_label}\n\n"
                        f"Period: {period_label}\n"
                        f"Change type: {change_type}\n"
                        f"Reason: {rejection_reason}\n\n"
                        f"## Candidate Content\n\n{candidate_content}\n"
                    ),
                )
        elif not accepted_hypotheses:
            engine_error = (hyp_res.get("error") or "").strip()
            if engine_error:
                store.write_log(
                    f"hypothesis-engine-error-{period_label}.md",
                    f"# Hypothesis Engine Error\n\nPeriod: {period_label}\nError: {engine_error}\n",
                )
                logger.warning("[Evolution] %s: hypothesis engine error (%s)", period_label, engine_error)
            else:
                logger.info("[Evolution] %s: no hypothesis met quality bar", period_label)

        window_summaries.append({
            "period": period_label, "summary": window_summary,
            # Drives the [HYPOTHESIS TRIGGERED] marker in the trajectory narrative.
            "triggered": bool(accepted_hypotheses),
            "change_type": change_type,
            "hypotheses_generated": len(accepted_hypotheses),
        })
        # Quiet-window bookkeeping: reset after a window that produced
        # hypotheses, otherwise keep counting (as the skip branches above do).
        windows_since_trigger = 0 if accepted_hypotheses else windows_since_trigger + 1

        current_date = next_date

    logger.info("[Evolution] Complete. Total tokens: %d", total_tokens)
    return total_tokens


# ---------------------------------------------------------------------------
# Phase 3: Validation — evaluate hypotheses against future papers
# ---------------------------------------------------------------------------

async def phase_validation(all_papers: list, store: FileSystemKnowledgeStore,
                           validation_start: str, validation_end: str, topic: str,
                           ablation_mode: str = "none", cache_dir: Path = None) -> None:
    """Evaluate CKM hypotheses against papers from the validation period.

    Thin async wrapper around ``tools.calculate_metrics.calculate_metrics``.
    The import happens inside the function, so that module is only loaded
    when this phase runs. All arguments are forwarded positionally.
    """
    from tools.calculate_metrics import calculate_metrics as compute_metrics

    metric_args = (
        all_papers, validation_start, validation_end,
        topic, ablation_mode, store, cache_dir,
    )
    await compute_metrics(*metric_args)


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------

def _load_token_usage(base_dir: Path) -> dict:
    """Return token counts saved in ``token_usage.json`` by an earlier run, or ``{}``."""
    token_path = base_dir / "token_usage.json"
    if not token_path.exists():
        return {}
    return json.loads(token_path.read_text(encoding="utf-8"))


def _save_token_usage(base_dir: Path, init_tokens: int, evo_tokens: int) -> None:
    """Write ``token_usage.json`` (picked up by batch_run and by resumed runs)."""
    token_usage = {
        "init_tokens": init_tokens,
        "evolution_tokens": evo_tokens,
        "total_generation_tokens": init_tokens + evo_tokens,
    }
    (base_dir / "token_usage.json").write_text(
        json.dumps(token_usage, indent=2), encoding="utf-8"
    )


def _last_evolution_window_start(evolution_start: str, evolution_end: str,
                                 step_months: int) -> str:
    """Return the start date of the last window phase_evolution visits, or "".

    Replays the same ``while current < end: current = add_months(...)`` walk as
    phase_evolution, so ranges that are not a multiple of ``step_months`` are
    handled. Returns "" for an empty range or a non-positive step.
    """
    if step_months < 1:
        return ""
    last_start = ""
    current = evolution_start
    while current < evolution_end:
        last_start = current
        current = add_months(current, step_months)
    return last_start


async def run():
    """CLI entry point: fetch papers, then run the init, evolution and validation phases.

    Outputs go under ``<project>/results/<timestamp>/`` unless
    ``--metabolism_dir`` / ``--report_dir`` override them. Reusing an existing
    metabolism directory resumes the run: phases detected as complete are
    skipped. Token usage is saved to ``token_usage.json`` after each
    generation phase, so a resumed run can recover counts for skipped phases.
    """
    parser = argparse.ArgumentParser(description="CKM Eval: 3-Phase Research Evolution System")
    parser.add_argument("topic", nargs="?", default="AI for software engineering")
    parser.add_argument("--metabolism_dir", type=str)
    parser.add_argument("--report_dir", type=str)
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    # Resolve paths
    project_dir = Path(__file__).parent.parent.absolute()  # ckm-eval/
    run_dir = project_dir / "results" / datetime.now().strftime("%Y%m%d_%H%M%S")
    metabolism_dir = Path(args.metabolism_dir).absolute() if args.metabolism_dir else run_dir / "metabolism"
    metabolism_dir.mkdir(parents=True, exist_ok=True)
    store = FileSystemKnowledgeStore(metabolism_dir)
    cache_dir = metabolism_dir / "fulltext_cache"

    report_dir = Path(args.report_dir).absolute() if args.report_dir else run_dir / "reports"
    config["paths"]["reports"] = report_dir

    # Read time parameters from config
    exp = config["experiment"]
    init_start = exp["init_start"]
    init_end = exp["init_end"]
    evo_start = exp["evolution_start"]
    evo_end = exp["evolution_end"]
    step_months = exp["evolution_step_months"]
    val_start = exp["validation_start"]
    val_end = exp["validation_end"]
    max_papers = exp["max_papers"]
    init_budget = min(exp["init_max_papers"], max_papers)
    evo_budget = min(exp["evolution_max_papers"], max_papers)
    val_min_papers = min(exp["validation_min_papers"], max_papers)
    val_budget = max(min(exp["validation_max_papers"], max_papers), val_min_papers)
    phase_search_window_months = exp["phase_search_window_months"]
    evolution_search_window_months = max(1, min(phase_search_window_months, step_months))
    evolution_window_paper_limit = max(1, exp["evolution_window_paper_limit"])
    evolution_backfill_on_empty = exp["evolution_backfill_on_empty"]
    hypothesis_max_per_window = max(1, exp["hypothesis_max_per_window"])

    topic = args.topic
    logger.info("=" * 70)
    logger.info("CKM Eval | Topic: %s", topic)
    logger.info("  Phase 1 (Init):       %s ~ %s", init_start, init_end)
    logger.info("  Phase 2 (Evolution):  %s ~ %s (step=%d months)", evo_start, evo_end, step_months)
    logger.info("  Evolution read mode:  per-window limit=%d, backfill_on_empty=%s",
                evolution_window_paper_limit, evolution_backfill_on_empty)
    logger.info("  Hypotheses/window:    up to %d accepted hypotheses", hypothesis_max_per_window)
    logger.info("  Phase 3 (Validation): %s ~ %s", val_start, val_end)
    logger.info(
        "  Search budgets:       init=%d, evolution=%d, validation=%d (min=%d, init/val window=%d months, evolution window=%d months)",
        init_budget,
        evo_budget,
        val_budget,
        val_min_papers,
        phase_search_window_months,
        evolution_search_window_months,
    )
    logger.info("=" * 70)

    # Fetch papers per phase so early years cannot consume the entire global cap.
    init_papers = fetch_phase_papers(
        topic, init_start, init_end, init_budget, phase_search_window_months,
        phase_name="Init",
    )
    evo_papers = fetch_phase_papers(
        topic, evo_start, evo_end, evo_budget, evolution_search_window_months,
        phase_name="Evolution",
    )
    val_papers = fetch_phase_papers(
        topic,
        val_start,
        val_end,
        val_budget,
        phase_search_window_months,
        min_papers=val_min_papers,
        phase_name="Validation",
    )
    # dedupe_and_sort_papers already returns the papers ordered by "published".
    all_papers = dedupe_and_sort_papers(init_papers + evo_papers + val_papers)

    if not all_papers:
        logger.error("No papers found for topic '%s', aborting.", topic)
        return

    logger.info("Fetched %d papers (%s ~ %s)",
                len(all_papers), all_papers[0]["published"][:10], all_papers[-1]["published"][:10])

    logger.info("Paper split: init=%d, evolution=%d, validation=%d",
                len(init_papers), len(evo_papers), len(val_papers))

    # --- Resume logic: skip completed phases ---
    config_path = store.base_dir / "config.json"
    current_day = 0
    if config_path.exists():
        current_day = json.loads(config_path.read_text(encoding="utf-8")).get("currentDay", 0)

    # Phase 1 is done when currentDay >= 1 (phase_init sets it as its last step).
    phase1_done = current_day >= 1
    # Phase 2 is done when the final evolution window wrote at least one
    # hypothesis file. The final window is found by replaying phase_evolution's
    # own window walk rather than subtracting step_months from evo_end.
    phase2_done = False
    if phase1_done and store.count_hypotheses() > 0:
        last_window_start = _last_evolution_window_start(evo_start, evo_end, step_months)
        if last_window_start:
            last_period_short = last_window_start[:7]
            phase2_done = any(store.dirs["hypotheses"].glob(f"hyp-{last_period_short}-*.md"))

    # Phase 3 is done when report exists
    report_dir_path = Path(config["paths"]["reports"])
    phase3_done = report_dir_path.exists() and any(
        entry.name.endswith(("_Report.md", "_report.md")) for entry in report_dir_path.iterdir()
    )

    if phase3_done:
        logger.info("All phases already complete, skipping.")
        return

    saved_usage = _load_token_usage(store.base_dir)
    init_tokens = 0
    evo_tokens = 0

    # Phase 1: Init
    if phase1_done:
        logger.info("--- Phase 1: Init --- SKIPPED (already complete)")
        init_tokens = saved_usage.get("init_tokens", 0)
    else:
        logger.info("--- Phase 1: Init ---")
        init_tokens = await phase_init(topic, init_papers, store, cache_dir)
        # Persist immediately: this file is the only source of init_tokens if a
        # later phase fails and the run is resumed.
        _save_token_usage(store.base_dir, init_tokens, evo_tokens)

    # Phase 2: Evolution
    if phase2_done:
        logger.info("--- Phase 2: Evolution --- SKIPPED (already complete)")
        evo_tokens = saved_usage.get("evolution_tokens", 0)
    else:
        logger.info("--- Phase 2: Evolution ---")
        evo_tokens = await phase_evolution(
            topic, evo_papers, store, cache_dir,
            evo_start, evo_end, step_months,
        )
        _save_token_usage(store.base_dir, init_tokens, evo_tokens)

    # Phase 3: Validation
    logger.info("--- Phase 3: Validation ---")
    await phase_validation(all_papers, store, val_start, val_end, topic, "none", cache_dir)

    total_tokens = init_tokens + evo_tokens
    logger.info("All phases complete. Total LLM tokens: %d", total_tokens)

    # Persist token usage for batch_run to pick up
    _save_token_usage(store.base_dir, init_tokens, evo_tokens)


# Script entry point: asyncio.run creates an event loop, runs the pipeline to completion, then closes the loop.
if __name__ == "__main__":
    asyncio.run(run())
