"""
Ablation 2: Abstract-Only — CKM without fulltext reading.

Identical to CKM except: skips fulltext download, all papers use abstract only.
Uses CKM hypothesis engine for strict single-variable ablation.
Tests the value of deep reading (fulltext) in knowledge metabolism.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import asyncio
import argparse
import logging
import re
from datetime import datetime

from config import config
from core.store import FileSystemKnowledgeStore
import json

from core.engines import (
    run_read_engine, run_init_engine, run_topic_update_engine, run_topic_discover_engine,
    run_ckm_hypothesis_engine, run_trigger_detection,
    get_embedding,
)
from tools.arxiv_search import search_arxiv_topic

logger = logging.getLogger("ablation2.main")

# Helpers — identical to CKM
def add_months(date_str, months):
    """Shift a ``YYYY-MM-DD`` date by *months* and snap it to the 1st of the month.

    The day component of the input is discarded: the result is always the
    first day of the target month. Negative offsets move backwards in time.

    Args:
        date_str: Date formatted as ``%Y-%m-%d``.
        months: Number of months to add (may be negative or zero).

    Returns:
        The shifted date, formatted as ``%Y-%m-%d``.
    """
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    # Use a zero-based month index so divmod yields the year carry directly.
    year_carry, month_index = divmod(parsed.month - 1 + months, 12)
    shifted = parsed.replace(
        year=parsed.year + year_carry,
        month=month_index + 1,
        day=1,
    )
    return shifted.strftime("%Y-%m-%d")

def build_phase_windows(start_date, end_date, window_months):
    """Split ``[start_date, end_date)`` into consecutive search windows.

    Each window starts where the previous one ended. A window's end is
    ``add_months(start, window_months)`` (which lands on the first day of a
    month), clamped so the final window never extends past ``end_date``.

    Args:
        start_date: Inclusive start, ``YYYY-MM-DD`` string.
        end_date: Exclusive end, ``YYYY-MM-DD`` string.
        window_months: Length of each window in months; must be positive
            whenever the range is non-empty.

    Returns:
        A list of ``(window_start, window_end)`` string tuples; empty when
        ``start_date`` is not before ``end_date``.

    Raises:
        ValueError: If the range is non-empty and ``window_months`` is not
            positive.
    """
    if start_date >= end_date:
        return []
    # A non-positive step never moves ``current`` forward (add_months with 0
    # returns the first of the same month), so the loop below would never end.
    if window_months <= 0:
        raise ValueError(f"window_months must be positive, got {window_months!r}")

    windows = []
    current = start_date
    while current < end_date:
        window_end = min(add_months(current, window_months), end_date)
        windows.append((current, window_end))
        current = window_end
    return windows

def allocate_window_budgets(total, n):
    """Distribute *total* as evenly as possible across *n* windows.

    Every window receives the floor share; the leftover units go one each to
    the earliest windows, so the result sums to ``total``.

    Returns:
        A list of ``n`` budgets, or an empty list when ``total`` or ``n`` is
        not positive.
    """
    if total <= 0 or n <= 0:
        return []
    share, extra = divmod(total, n)
    budgets = []
    for slot in range(n):
        # The first ``extra`` slots absorb the remainder.
        bonus = 1 if slot < extra else 0
        budgets.append(share + bonus)
    return budgets

def dedupe_and_sort_papers(papers):
    """Drop repeated ``arxiv_id`` entries and order papers by ``published``.

    When an id appears more than once, the first occurrence is kept. The sort
    is ascending and stable, so papers sharing a ``published`` value keep
    their first-seen order.
    """
    first_seen = {}
    for paper in papers:
        paper_id = paper["arxiv_id"]
        if paper_id not in first_seen:
            first_seen[paper_id] = paper
    ordered = list(first_seen.values())
    ordered.sort(key=lambda paper: paper["published"])
    return ordered


async def resolve_paper_contents_abstract_only(papers, phase_name):
    """ABLATION: use each paper's abstract as its content; no fulltext download.

    Returns a new list of shallow copies of the input papers, each extended
    with ``content`` (the abstract), ``content_source`` and
    ``counted_fulltext`` set to ``True`` so every paper counts as valid. The
    input dicts are not modified.
    """
    resolved = [
        {
            **paper,
            "content": paper["abstract"],
            "content_source": "abstract-only-ablation",
            "counted_fulltext": True,
        }
        for paper in papers
    ]
    logger.info("[%s] Abstract-Only: %d papers (no fulltext download)", phase_name, len(resolved))
    return resolved


def fetch_phase_papers(topic, start_date, end_date, total_budget, window_months, min_papers=0, phase_name="Search"):
    """Search arXiv window by window for one phase and return the merged result.

    The phase range is split with ``build_phase_windows``; the larger of
    ``total_budget`` and ``min_papers`` is spread over those windows with
    ``allocate_window_budgets``. Windows with a zero budget are not queried.

    Returns:
        The collected papers after ``dedupe_and_sort_papers``.
    """
    windows = build_phase_windows(start_date, end_date, window_months)
    effective_budget = max(total_budget, min_papers)
    budgets = allocate_window_budgets(effective_budget, len(windows))

    collected = []
    for window, budget in zip(windows, budgets):
        if budget > 0:
            window_start, window_end = window
            found = search_arxiv_topic(topic, budget, start_date=window_start, end_date=window_end)
            logger.info("[%s] %s~%s: %d papers (budget=%d)", phase_name, window_start, window_end, len(found), budget)
            collected.extend(found)
    return dedupe_and_sort_papers(collected)


# Phase 1: Init — abstract only
async def phase_init(topic, papers, store):
    """Build the initial knowledge base from paper abstracts.

    Steps, in order: resolve papers to abstract-only content, run the read
    engine on each (at most 6 concurrently), write ``config.json`` with
    ``currentDay`` 0, run the init engine on the joined extractions, write
    the returned knowledge files and their embeddings, then bump
    ``currentDay`` to 1.

    Returns:
        The init engine's ``tokens`` value, or 0 when ``papers`` is empty.

    Raises:
        RuntimeError: If the init engine returns no file other than
            ``_index.md``.
    """
    if not papers:
        return 0

    resolved = await resolve_paper_contents_abstract_only(papers, "Init")

    limiter = asyncio.Semaphore(6)

    async def read_one(paper):
        async with limiter:
            return await run_read_engine(paper["title"], paper["arxiv_id"], paper["published"], paper["content"])

    extractions = await asyncio.gather(*(read_one(paper) for paper in resolved))

    # Record the processed ids before the init engine runs.
    config_file = store.base_dir / "config.json"
    initial_state = {
        "topic": topic,
        "processed_ids": [paper["arxiv_id"] for paper in resolved],
        "currentDay": 0,
    }
    config_file.write_text(json.dumps(initial_state, indent=2), encoding="utf-8")

    joined_extractions = "\n\n---\n\n".join(extractions)
    result = await run_init_engine(topic, joined_extractions, len(resolved))
    operations = result["operations"]

    topic_ops = [op for op in operations if op["fileName"] != "_index.md"]
    if len(topic_ops) == 0:
        raise RuntimeError(f"Init: no topic files: {result.get('error')}")

    embeddings = store.get_embeddings_dict()
    for op in operations:
        store.write_knowledge_file(op["fileName"], op["fileContent"])
    vectors = await asyncio.gather(*(get_embedding(op["fileContent"]) for op in operations))
    for op, vector in zip(operations, vectors):
        embeddings[op["fileName"]] = vector
    store.save_embeddings_dict(embeddings)

    # Re-read the config from disk and mark initialization as finished.
    state = json.loads(config_file.read_text(encoding="utf-8"))
    state["currentDay"] = 1
    config_file.write_text(json.dumps(state, indent=2), encoding="utf-8")

    logger.info("[Init] Complete (abstract-only): %d files, %d tokens", len(operations), result["tokens"])
    return result["tokens"]


# Phase 2: Evolution — CKM with abstract only
async def phase_evolution(topic, papers, store, evo_start, evo_end, step_months):
    """Full CKM evolution but all papers use abstract only.

    Walks ``[evo_start, evo_end)`` in steps of ``step_months``. For every
    window it selects papers, logs the ingest, updates each existing
    knowledge file, discovers new topics, runs trigger detection and generates
    hypotheses with the CKM hypothesis engine.

    Paper selection per window: the newest ``evolution_window_paper_limit``
    papers published inside the window are used; the overflow is deferred.
    Deferred papers are only consumed when a later window has no papers of its
    own and ``evolution_backfill_on_empty`` is enabled.

    Args:
        topic: Research topic passed to the hypothesis engine.
        papers: Paper dicts; ``published`` is compared as a string against the
            window bounds.
        store: Knowledge store used for knowledge files, embeddings, logs and
            hypotheses.
        evo_start: Inclusive start date, ``YYYY-MM-DD``.
        evo_end: Exclusive end date, ``YYYY-MM-DD``.
        step_months: Window length in months; must be positive.

    Returns:
        Sum of the ``tokens`` values reported by the update, discover, trigger
        and hypothesis engines (0 when ``papers`` is empty).

    Raises:
        ValueError: If the date range is non-empty and ``step_months`` is not
            positive.
    """
    if not papers:
        return 0
    # A non-positive step never advances ``current_date`` (add_months with 0
    # returns the first of the same month), so the window loop would repeat
    # forever instead of finishing.
    if evo_start < evo_end and step_months <= 0:
        raise ValueError(f"step_months must be positive, got {step_months!r}")

    exp_cfg = config["experiment"]
    window_limit = max(1, exp_cfg.get("evolution_window_paper_limit", 8))
    backfill_on_empty = exp_cfg.get("evolution_backfill_on_empty", True)
    hyp_max = max(1, exp_cfg.get("hypothesis_max_per_window", 3))
    total_tokens = 0
    current_date = evo_start
    deferred = []
    window_summaries = []
    window_index = 0

    def sort_rf(items):
        """Return *items* ordered newest-first by ``published``."""
        return sorted(items, key=lambda x: x["published"], reverse=True)

    def _trajectory():
        """Render the per-window summaries collected so far as a bullet list."""
        if not window_summaries:
            return "(first window)"
        return "\n".join(f"- Period {entry['period']}: {entry['summary']}" for entry in window_summaries)

    def _existing_hyps():
        """Summarize stored hypotheses as ``- H###: <statement>`` lines."""
        files = sorted(store.dirs["hypotheses"].glob("*.md"))
        if not files:
            return "(none yet)"
        out = []
        for fp in files:
            text = fp.read_text(encoding="utf-8")
            id_match = re.search(r"-(\d{3})\.md$", fp.name)
            hyp_id = f"H{id_match.group(1)}" if id_match else fp.stem
            statement = re.search(r"## Statement\s*([\s\S]*?)(?=##|\Z)", text)
            out.append(f"- {hyp_id}: {statement.group(1).strip()[:200] if statement else '?'}")
        return "\n".join(out)

    while current_date < evo_end:
        next_date = add_months(current_date, step_months)
        period = current_date[:7]
        window_index += 1

        in_window = sort_rf([p for p in papers if current_date <= p["published"] < next_date])
        fresh = in_window[:window_limit]
        overflow = in_window[window_limit:]
        if overflow:
            deferred.extend(overflow)
            deferred = sort_rf(deferred)

        # Only dip into the deferred backlog when this window has nothing new.
        borrowed = []
        if not fresh and backfill_on_empty and deferred:
            borrowed = deferred[:window_limit]
            deferred = deferred[window_limit:]

        window_papers = fresh or borrowed
        if not window_papers:
            window_summaries.append({"period": period, "summary": "No papers"})
            current_date = next_date
            continue

        # *** ABLATION: abstract only, no fulltext ***
        counted = await resolve_paper_contents_abstract_only(window_papers, f"Evo {period}")

        parts = [
            f"### {p['title']}\n- arxiv_id: {p['arxiv_id']}\n- published: {p['published']}\n- content_source: {p['content_source']}\n\n{p['content']}"
            for p in counted
        ]
        papers_text = "\n\n---\n\n".join(parts)
        store.write_log(
            f"window-{period}-ingest.md",
            f"# {period} — {len(counted)} papers (ABSTRACT-ONLY)\n\n{papers_text}",
        )

        # Snapshot of the knowledge base before this window's updates.
        kb = store.get_joined_knowledge_content()

        # Knowledge update — identical to CKM
        kf = {fn: c for fn, c in store.get_knowledge_files().items() if fn != "_index.md"}
        emb = store.get_embeddings_dict()
        sem = asyncio.Semaphore(6)

        async def upd(fn, fc):
            async with sem:
                return await run_topic_update_engine(fn, fc, papers_text, period)

        res = await asyncio.gather(*[upd(fn, fc) for fn, fc in kf.items()])
        total_tokens += sum(r["tokens"] for r in res)
        for r in res:
            store.write_knowledge_file(r["fileName"], r["fileContent"])
        nvecs = await asyncio.gather(*[get_embedding(r["fileContent"]) for r in res])
        for r, v in zip(res, nvecs):
            emb[r["fileName"]] = v
        store.save_embeddings_dict(emb)

        disc = await run_topic_discover_engine(list(kf.keys()), papers_text, period)
        total_tokens += disc["tokens"]
        new_topics = []
        if disc["operations"]:
            dvecs = await asyncio.gather(*[get_embedding(op["fileContent"]) for op in disc["operations"]])
            for op, v in zip(disc["operations"], dvecs):
                store.write_knowledge_file(op["fileName"], op["fileContent"])
                emb[op["fileName"]] = v
                new_topics.append(op["fileName"])
            store.save_embeddings_dict(emb)

        # Snapshot after updates and discovery; compared against ``kb`` below.
        ka = store.get_joined_knowledge_content()

        # Trigger detection — identical to CKM
        tr = await run_trigger_detection(kb, ka, papers_text, period, 0)
        total_tokens += tr["tokens"]
        ct = tr["change_type"]
        summary = tr.get("reason", "").strip() or f"{len(counted)} papers"
        if new_topics:
            summary += f"; new topics: {', '.join(new_topics)}"
        kc = tr.get("key_changes", [])
        if kc:
            summary += "; " + "; ".join(kc[:3])
        store.write_log(f"change-{period}.md", f"# Change — {period}\nType: {ct}\nReason: {tr.get('reason','')}\n")

        # Hypothesis generation — CKM engine
        # Numbering continues from the hypotheses already stored.
        si = store.count_hypotheses() + 1
        hr = await run_ckm_hypothesis_engine(
            topic=topic,
            knowledge_content=ka,
            papers_text=papers_text,
            time_period=period,
            max_hypotheses=hyp_max,
            evolution_trajectory=_trajectory(),
            existing_hypotheses=_existing_hyps(),
            trigger_type=ct,
            trigger_reason=tr.get("reason", ""),
            n_windows=window_index,
        )
        total_tokens += hr["tokens"]
        generated = hr.get("hypotheses", [])
        for off, h in enumerate(generated):
            hid = str(si + off).zfill(3)
            # Rewrite the engine's heading so it carries the store-wide id.
            renumbered = re.sub(r"^# Hypothesis H\d+", f"# Hypothesis H{hid}", h["content"], count=1)
            store.write_hypothesis(f"{period}-{hid}", renumbered)
            logger.info("[Evo] %s: H%s (abstract-only)", period, hid)

        window_summaries.append({
            "period": period,
            "summary": summary,
            "change_type": ct,
            "hypotheses_generated": len(generated),
        })
        current_date = next_date

    logger.info("[Evolution] Complete. Tokens: %d", total_tokens)
    return total_tokens


# Phase 3 — identical to CKM (uses fulltext for fair evaluation)
async def phase_validation(all_papers, store, val_start, val_end, topic, ablation="none", cache_dir=None):
    """Run the validation phase by delegating to ``calculate_metrics``.

    The arguments are forwarded positionally in the order
    ``(all_papers, val_start, val_end, topic, ablation, store, cache_dir)``.
    The awaited result is discarded; this coroutine returns ``None``.
    """
    # Imported at call time rather than at module import.
    from tools.calculate_metrics import calculate_metrics
    await calculate_metrics(all_papers, val_start, val_end, topic, ablation, store, cache_dir)


async def run():
    """Command-line entry point for the abstract-only ablation.

    Parses arguments, prepares the output directories and knowledge store,
    fetches papers for the init, evolution and validation phases, runs the
    three phases in order and finally writes ``token_usage.json`` into the
    store's base directory. Aborts early (after logging an error) when no
    papers were found at all.
    """
    parser = argparse.ArgumentParser(description="Ablation 2: Abstract-Only")
    parser.add_argument("topic", nargs="?", default="AI for software engineering")
    for flag in ("--metabolism_dir", "--report_dir"):
        parser.add_argument(flag, type=str)
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(levelname)s: %(message)s")

    # Default outputs live under <project>/results/ablation2_<timestamp>/.
    project_dir = Path(__file__).parent.parent.parent.absolute()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = project_dir / "results" / ("ablation2_" + stamp)

    if args.metabolism_dir:
        metabolism_dir = Path(args.metabolism_dir).absolute()
    else:
        metabolism_dir = run_dir / "metabolism"
    metabolism_dir.mkdir(parents=True, exist_ok=True)
    store = FileSystemKnowledgeStore(metabolism_dir)
    cache_dir = metabolism_dir / "fulltext_cache"

    if args.report_dir:
        report_dir = Path(args.report_dir).absolute()
    else:
        report_dir = run_dir / "reports"
    config["paths"]["reports"] = report_dir

    exp = config["experiment"]
    topic = args.topic

    banner = "=" * 70
    logger.info(banner)
    logger.info("Ablation 2: Abstract-Only | Topic: %s", topic)
    logger.info(banner)

    init_papers = fetch_phase_papers(
        topic,
        exp["init_start"],
        exp["init_end"],
        min(exp["init_max_papers"], exp["max_papers"]),
        exp["phase_search_window_months"],
        phase_name="Init",
    )
    evo_papers = fetch_phase_papers(
        topic,
        exp["evolution_start"],
        exp["evolution_end"],
        min(exp["evolution_max_papers"], exp["max_papers"]),
        # Search windows are never longer than one evolution step.
        max(1, min(exp["phase_search_window_months"], exp["evolution_step_months"])),
        phase_name="Evolution",
    )
    val_papers = fetch_phase_papers(
        topic,
        exp["validation_start"],
        exp["validation_end"],
        max(min(exp["validation_max_papers"], exp["max_papers"]), exp["validation_min_papers"]),
        exp["phase_search_window_months"],
        min_papers=exp["validation_min_papers"],
        phase_name="Validation",
    )

    all_papers = dedupe_and_sort_papers(init_papers + evo_papers + val_papers)
    if not all_papers:
        logger.error("No papers, aborting.")
        return

    logger.info("Papers: %d (init=%d, evo=%d, val=%d)", len(all_papers), len(init_papers), len(evo_papers), len(val_papers))

    logger.info("--- Phase 1: Init (Abstract-Only) ---")
    init_tok = await phase_init(topic, init_papers, store)

    logger.info("--- Phase 2: Evolution (Abstract-Only) ---")
    evo_tok = await phase_evolution(
        topic,
        evo_papers,
        store,
        exp["evolution_start"],
        exp["evolution_end"],
        exp["evolution_step_months"],
    )

    logger.info("--- Phase 3: Validation ---")
    await phase_validation(all_papers, store, exp["validation_start"], exp["validation_end"], topic, "none", cache_dir)

    generation_total = init_tok + evo_tok
    logger.info("Done. Total tokens: %d", generation_total)

    usage = {
        "init_tokens": init_tok,
        "evolution_tokens": evo_tok,
        "total_generation_tokens": generation_total,
    }
    usage_file = store.base_dir / "token_usage.json"
    usage_file.write_text(json.dumps(usage, indent=2), encoding="utf-8")

# Script entry point: only runs the pipeline when executed directly, not on import.
if __name__ == "__main__":
    asyncio.run(run())
