"""
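Lite: No-Diff — CKM without knowledge change awareness.

The helpers and the init phase are identical to CKM. The evolution phase
skips the K_before/K_after snapshots and run_trigger_detection entirely;
instead of a comparative knowledge update (and topic discovery) it appends
each window's papers to a single accumulator knowledge file. The hypothesis
engine receives NO information about what changed — no change_type, no
reason, no diff.
Tests the value of "what changed" awareness in hypothesis generation.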
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import asyncio
import argparse
import logging
import re
from datetime import datetime

from config import config
from core.store import FileSystemKnowledgeStore
import json

from core.engines import (
    run_read_engine, run_init_engine,
    run_ckm_hypothesis_engine,
    get_embedding,
)
from tools.arxiv_search import search_arxiv_topic
from tools.arxiv_fulltext import get_paper_content_record

# Module-level logger shared by every phase in this script; the name is the
# fixed string "lite.main" rather than being derived from __name__.
logger = logging.getLogger("lite.main")

# Helpers — identical to CKM
def add_months(date_str, months):
    """Shift a ``YYYY-MM-DD`` date by ``months`` and snap it to day 1.

    The day-of-month of the input is discarded: the result is always the
    first day of the target month, formatted again as ``YYYY-MM-DD``.
    ``months`` may be zero or negative.
    """
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    # Work with a zero-based month index so year carry falls out of divmod.
    year_shift, month_index = divmod(parsed.month - 1 + months, 12)
    shifted = parsed.replace(
        year=parsed.year + year_shift,
        month=month_index + 1,
        day=1,
    )
    return shifted.strftime("%Y-%m-%d")

def build_phase_windows(start_date, end_date, window_months):
    """Split ``[start_date, end_date)`` into consecutive search windows.

    Dates are ``YYYY-MM-DD`` strings and are compared lexicographically.
    Each window's upper bound is the first day of the month that lies
    ``window_months`` after the window's start, clamped to ``end_date``.

    Args:
        start_date: Inclusive start of the overall range.
        end_date: End of the overall range; the last window stops here.
        window_months: Width of each window in months; must be at least 1.

    Returns:
        A list of ``(window_start, window_end)`` tuples, empty when
        ``start_date >= end_date``.

    Raises:
        ValueError: If the range is non-empty and ``window_months`` is less
            than 1. Such a step never moves past the current month, so the
            loop below would never terminate.
    """
    if start_date >= end_date:
        return []
    if window_months < 1:
        raise ValueError(
            f"window_months must be >= 1 to make progress, got {window_months!r}"
        )

    windows = []
    current = start_date
    while current < end_date:
        # Clamp so the final window never runs past the requested range.
        upper = min(add_months(current, window_months), end_date)
        windows.append((current, upper))
        current = upper
    return windows

def allocate_window_budgets(total, n):
    """Spread ``total`` across ``n`` windows as evenly as possible.

    Every window gets ``total // n``; the remainder is handed out one unit
    at a time to the earliest windows. Returns an empty list when either
    argument is zero or negative.
    """
    if total <= 0 or n <= 0:
        return []
    share, leftover = divmod(total, n)
    # Windows with an index below the remainder receive one extra unit.
    return [share + int(idx < leftover) for idx in range(n)]

def dedupe_and_sort_papers(papers):
    """Drop repeated ``arxiv_id`` entries and order the rest by ``published``.

    The first paper seen for a given ``arxiv_id`` is the one that is kept.
    Sorting is ascending on the ``published`` value and stable, so papers
    with equal ``published`` values keep their first-seen order.
    """
    seen_ids = set()
    unique = []
    for paper in papers:
        paper_id = paper["arxiv_id"]
        if paper_id in seen_ids:
            continue
        seen_ids.add(paper_id)
        unique.append(paper)
    unique.sort(key=lambda paper: paper["published"])
    return unique

async def resolve_paper_contents(papers, cache_dir, concurrency, phase_name):
    """Fetch a content record for every paper, with bounded concurrency.

    Each paper is looked up through ``get_paper_content_record`` using its
    ``arxiv_id`` and ``abstract``. Timeout, retry count and retry delay come
    from ``config["experiment"]`` (``fulltext_timeout_s`` default 30,
    ``fulltext_retries`` default 2, ``fulltext_retry_delay_s`` default 2),
    floored at 1, 1 and 0 respectively.

    Args:
        papers: Paper dicts carrying at least ``arxiv_id`` and ``abstract``.
        cache_dir: Passed through to ``get_paper_content_record``.
        concurrency: Maximum simultaneous fetches; values below 1 become 1.
        phase_name: Label used in the summary log line.

    Returns:
        A list, in input order, of copies of the input dicts extended with
        ``content``, ``content_source`` and ``counted_fulltext``.
    """
    exp_cfg = config["experiment"]
    limiter = asyncio.Semaphore(max(1, concurrency))

    async def _fetch_one(paper):
        # Only the fetch itself is throttled; the merge below is cheap.
        async with limiter:
            record = await get_paper_content_record(
                paper["arxiv_id"],
                paper["abstract"],
                cache_dir,
                timeout_s=max(1, exp_cfg.get("fulltext_timeout_s", 30)),
                retries=max(1, exp_cfg.get("fulltext_retries", 2)),
                retry_delay_s=max(0, exp_cfg.get("fulltext_retry_delay_s", 2)),
            )
        # Copy so the caller's paper dict is left untouched.
        enriched = dict(paper)
        enriched.update(
            content=record["content"],
            content_source=record["source"],
            counted_fulltext=record["counted_fulltext"],
        )
        return enriched

    pending = [_fetch_one(paper) for paper in papers]
    resolved = await asyncio.gather(*pending)
    fulltext_total = len([item for item in resolved if item["counted_fulltext"]])
    logger.info("[%s] Resolved: total=%d, fulltext=%d", phase_name, len(resolved), fulltext_total)
    return resolved

def fetch_phase_papers(topic, start_date, end_date, total_budget, window_months, min_papers=0, phase_name="Search"):
    """Search arXiv window by window and return the de-duplicated result.

    The date range is cut into windows of ``window_months``; the larger of
    ``total_budget`` and ``min_papers`` is split across those windows, and
    ``search_arxiv_topic`` is called once per window that received a
    positive budget.

    Returns:
        Papers from all windows, de-duplicated by ``arxiv_id`` and sorted by
        ``published``.
    """
    windows = build_phase_windows(start_date, end_date, window_months)
    effective_budget = max(total_budget, min_papers)
    per_window = allocate_window_budgets(effective_budget, len(windows))

    gathered = []
    for (win_start, win_end), quota in zip(windows, per_window):
        if quota > 0:
            hits = search_arxiv_topic(topic, quota, start_date=win_start, end_date=win_end)
            logger.info("[%s] %s~%s: %d papers (budget=%d)", phase_name, win_start, win_end, len(hits), quota)
            gathered.extend(hits)
    return dedupe_and_sort_papers(gathered)


# Phase 1: Init — identical to CKM
async def phase_init(topic, papers, store, cache_dir):
    """Bootstrap the knowledge store from the init-phase papers.

    Steps, in order: resolve paper contents (concurrency 12), keep only
    papers flagged ``counted_fulltext``, run the read engine on each
    (at most 6 at a time), write ``config.json`` with ``currentDay`` 0, run
    the init engine on the joined extractions, write every returned
    knowledge file, store one embedding per file, then bump ``currentDay``
    to 1.

    Returns:
        ``result["tokens"]`` from the init engine, or 0 when ``papers`` is
        empty.

    Raises:
        RuntimeError: If no paper has counted fulltext, or if the init
            engine returned no file other than ``_index.md``.
    """
    if not papers:
        return 0

    resolved = await resolve_paper_contents(papers, cache_dir, 12, "Init")
    fulltext = [paper for paper in resolved if paper["counted_fulltext"]]
    if not fulltext:
        raise RuntimeError("Init: zero fulltext")

    read_limiter = asyncio.Semaphore(6)

    async def _extract(paper):
        async with read_limiter:
            return await run_read_engine(
                paper["title"], paper["arxiv_id"], paper["published"], paper["content"]
            )

    extractions = await asyncio.gather(*[_extract(paper) for paper in fulltext])

    # Record which papers were ingested before the init engine runs.
    cfg_path = store.base_dir / "config.json"
    initial_cfg = {
        "topic": topic,
        "processed_ids": [paper["arxiv_id"] for paper in fulltext],
        "currentDay": 0,
    }
    cfg_path.write_text(json.dumps(initial_cfg, indent=2), encoding="utf-8")

    joined_extractions = "\n\n---\n\n".join(extractions)
    result = await run_init_engine(topic, joined_extractions, len(fulltext))

    topic_ops = [op for op in result["operations"] if op["fileName"] != "_index.md"]
    if not topic_ops:
        raise RuntimeError(f"Init: no topic files: {result.get('error')}")

    embeddings = store.get_embeddings_dict()
    for op in result["operations"]:
        store.write_knowledge_file(op["fileName"], op["fileContent"])

    vectors = await asyncio.gather(
        *[get_embedding(op["fileContent"]) for op in result["operations"]]
    )
    for op, vector in zip(result["operations"], vectors):
        embeddings[op["fileName"]] = vector
    store.save_embeddings_dict(embeddings)

    # Re-read the config from disk and mark initialisation as finished.
    stored_cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    stored_cfg["currentDay"] = 1
    cfg_path.write_text(json.dumps(stored_cfg, indent=2), encoding="utf-8")

    logger.info("[Init] Complete: %d files, %d tokens", len(result["operations"]), result["tokens"])
    return result["tokens"]


# Phase 2: Evolution — CKM WITHOUT change detection
async def phase_evolution(topic, papers, store, cache_dir, evo_start, evo_end, step_months):
    """CKM evolution but with NO diff awareness — no K_before/K_after, no change detection.

    Walks from ``evo_start`` towards ``evo_end`` in steps of ``step_months``.
    For each window it:

    1. Selects up to ``evolution_window_paper_limit`` papers published in the
       window, newest first; the overflow goes to a ``deferred`` queue. If
       the window has no papers of its own and ``evolution_backfill_on_empty``
       is set, papers are borrowed from that queue instead.
    2. Resolves paper contents and keeps only those flagged
       ``counted_fulltext``; if none qualify and nothing was borrowed yet, one
       backfill attempt from ``deferred`` is made.
    3. Writes an ingest log, appends a truncated entry per paper to the
       accumulator knowledge file, and refreshes that file's embedding.
    4. Calls ``run_ckm_hypothesis_engine`` with ``trigger_type="NONE"`` and an
       empty ``trigger_reason``, then stores the returned hypotheses under
       sequential three-digit ids.

    Windows with no papers, or no fulltext, only add a summary entry.

    Args:
        topic: Research topic string passed to the hypothesis engine.
        papers: Evolution-phase paper dicts (need ``published``, plus the keys
            ``resolve_paper_contents`` reads).
        store: Knowledge store used for logs, knowledge files, embeddings and
            hypotheses.
        cache_dir: Fulltext cache location, passed to ``resolve_paper_contents``.
        evo_start: First window start, ``YYYY-MM-DD``.
        evo_end: Loop stops once the window start reaches this date.
        step_months: Window width in months.

    Returns:
        Sum of ``tokens`` reported by the hypothesis engine, or 0 when
        ``papers`` is empty.

    NOTE(review): ``add_months`` snaps to the first of the month, so a
    ``step_months`` of 0 or less never advances ``current_date`` past the
    current month and this loop would not terminate — confirm callers always
    pass a value >= 1.
    """
    if not papers: return 0

    exp_cfg = config["experiment"]
    # Per-window knobs; numeric limits are floored at 1.
    window_limit = max(1, exp_cfg.get("evolution_window_paper_limit", 8))
    backfill_on_empty = exp_cfg.get("evolution_backfill_on_empty", True)
    hyp_max = max(1, exp_cfg.get("hypothesis_max_per_window", 3))
    total_tokens = 0
    current_date = evo_start
    # Papers that did not fit into their own window; kept newest-first.
    deferred = []
    # One entry per processed window; fed back to the engine as the trajectory.
    window_summaries = []
    window_index = 0

    # Newest-first ordering on the "published" value.
    def sort_rf(items): return sorted(items, key=lambda x: x["published"], reverse=True)
    # Render the per-window summaries as a bullet list for the hypothesis prompt.
    def _trajectory():
        if not window_summaries: return "(first window)"
        return "\n".join(f"- Period {ws['period']}: {ws['summary']}" for ws in window_summaries)
    # Summarise hypotheses already on disk: id plus the first 200 characters of
    # each file's "## Statement" section ("?" when the section is missing).
    def _existing_hyps():
        files = sorted(store.dirs["hypotheses"].glob("*.md"))
        if not files: return "(none yet)"
        out = []
        for fp in files:
            c = fp.read_text(encoding="utf-8")
            # File names ending in "-NNN.md" map to "HNNN"; anything else falls
            # back to the file stem.
            m = re.search(r"-(\d{3})\.md$", fp.name)
            hid = f"H{m.group(1)}" if m else fp.stem
            sm = re.search(r"## Statement\s*([\s\S]*?)(?=##|\Z)", c)
            out.append(f"- {hid}: {sm.group(1).strip()[:200] if sm else '?'}")
        return "\n".join(out)

    while current_date < evo_end:
        next_date = add_months(current_date, step_months)
        # "YYYY-MM" label of the window start, used in logs and file names.
        period = current_date[:7]
        window_index += 1
        # Papers published in [current_date, next_date), compared as strings.
        wp = [p for p in papers if current_date <= p["published"] < next_date]
        wp = sort_rf(wp)
        fresh = wp[:window_limit]
        # Anything beyond the per-window limit is queued for later backfill.
        if wp[window_limit:]:
            deferred.extend(wp[window_limit:]); deferred = sort_rf(deferred)
        borrowed = []
        # Empty window: optionally pull the newest deferred papers instead.
        if not fresh and backfill_on_empty and deferred:
            borrowed = deferred[:window_limit]; deferred = deferred[window_limit:]
        window_papers = fresh or borrowed
        if not window_papers:
            window_summaries.append({"period": period, "summary": "No papers"})
            current_date = next_date; continue
        resolved = await resolve_paper_contents(window_papers, cache_dir, min(8, len(window_papers)), f"Evo {period}")
        counted = [p for p in resolved if p["counted_fulltext"]]
        # The window's own papers yielded no fulltext: try one backfill batch.
        # The papers resolved above are dropped in that case.
        if not counted and backfill_on_empty and deferred and not borrowed:
            borrowed = deferred[:window_limit]; deferred = deferred[window_limit:]
            resolved = await resolve_paper_contents(borrowed, cache_dir, min(8, len(borrowed)), f"Evo {period} bf")
            counted = [p for p in resolved if p["counted_fulltext"]]
        if not counted:
            window_summaries.append({"period": period, "summary": "No fulltext"})
            current_date = next_date; continue

        # Full (untruncated) paper text: logged and passed to the hypothesis engine.
        parts = [f"### {p['title']}\n- arxiv_id: {p['arxiv_id']}\n- published: {p['published']}\n- content_source: {p['content_source']}\n\n{p['content']}" for p in counted]
        papers_text = "\n\n---\n\n".join(parts)
        store.write_log(f"window-{period}-ingest.md", f"# {period} — {len(counted)} papers (NO-DIFF)\n\n{papers_text}")

        # *** ABLATION: NO DIFF — directly append papers to knowledge, no comparative update ***
        # Instead of topic_update_engine (which diffs new vs existing knowledge),
        # we simply append paper summaries to an accumulator file.
        logger.info("[Ablation3] %s: NO-DIFF — directly appending papers to knowledge (no comparative update)", period)

        # Append papers to a single accumulator knowledge file (no structured diff)
        accumulator_file = "topic-accumulated-papers.md"
        existing_content = ""
        kf = store.get_knowledge_files()
        if accumulator_file in kf:
            existing_content = kf[accumulator_file]

        # Build per-paper summaries to append
        new_entries = []
        for p in counted:
            new_entries.append(
                f"\n### [{period}] {p['title']}\n"
                f"- arxiv_id: {p['arxiv_id']}\n"
                f"- published: {p['published']}\n"
                f"- content_source: {p['content_source']}\n\n"
                f"{p['content'][:2000]}\n"  # truncate to avoid unbounded growth
            )
        updated_content = existing_content + "\n".join(new_entries)
        store.write_knowledge_file(accumulator_file, updated_content)

        # Update embedding for the accumulator file
        # NOTE(review): only the first 8000 characters are embedded and new
        # entries are appended at the end, so once the file exceeds 8000
        # characters the embedded text stops changing — confirm this is intended.
        emb = store.get_embeddings_dict()
        emb[accumulator_file] = await get_embedding(updated_content[:8000])
        store.save_embeddings_dict(emb)

        # Skip topic_discover_engine — no structured topic tracking in no-diff mode
        # Skip trigger_detection — no K_before/K_after comparison
        window_summary = f"{len(counted)} papers appended directly (no diff)"

        # Hypothesis generation — CKM engine but with NO change awareness
        # Numbering continues from the hypotheses already in the store.
        si = store.count_hypotheses() + 1
        hr = await run_ckm_hypothesis_engine(topic=topic, knowledge_content=store.get_joined_knowledge_content(),
            papers_text=papers_text, time_period=period, max_hypotheses=hyp_max,
            evolution_trajectory=_trajectory(), existing_hypotheses=_existing_hyps(),
            trigger_type="NONE", trigger_reason="",
            n_windows=window_index)
        total_tokens += hr["tokens"]
        for off, h in enumerate(hr.get("hypotheses", [])):
            hid = str(si + off).zfill(3)
            # Rewrite the leading "# Hypothesis H<n>" heading to carry the assigned id.
            store.write_hypothesis(f"{period}-{hid}", re.sub(r"^# Hypothesis H\d+", f"# Hypothesis H{hid}", h["content"], count=1))
            logger.info("[Evo] %s: H%s (no-diff)", period, hid)

        window_summaries.append({"period": period, "summary": window_summary, "change_type": "NONE", "hypotheses_generated": len(hr.get("hypotheses", []))})
        current_date = next_date

    logger.info("[Evolution] Complete. Tokens: %d", total_tokens)
    return total_tokens


async def phase_validation(all_papers, store, val_start, val_end, topic, ablation="none", cache_dir=None):
    """Run the metrics calculation for the validation period.

    Thin wrapper: forwards every argument to
    ``tools.calculate_metrics.calculate_metrics`` and awaits it. Nothing is
    returned.
    """
    # Imported at call time, inside the function, as in the rest of this phase.
    from tools.calculate_metrics import calculate_metrics as compute_metrics

    await compute_metrics(
        all_papers,
        val_start,
        val_end,
        topic,
        ablation,
        store,
        cache_dir,
    )


async def run():
    """Command-line entry point for the No-Diff pipeline.

    Parses the optional ``topic`` positional and the ``--metabolism_dir`` /
    ``--report_dir`` options, prepares the knowledge store and output
    directories, fetches papers for the init, evolution and validation
    phases, runs the three phases in order, and finally writes
    ``token_usage.json`` into the store's base directory. Returns early,
    after logging an error, when no papers were found at all.
    """
    cli = argparse.ArgumentParser(description="Lite: No-Diff")
    cli.add_argument("topic", nargs="?", default="AI for software engineering")
    cli.add_argument("--metabolism_dir", type=str)
    cli.add_argument("--report_dir", type=str)
    args = cli.parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(levelname)s: %(message)s")

    # Default outputs live under <project>/results/lite_<timestamp>/.
    project_dir = Path(__file__).parent.parent.parent.absolute()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = project_dir / "results" / ("lite_" + stamp)

    if args.metabolism_dir:
        metabolism_dir = Path(args.metabolism_dir).absolute()
    else:
        metabolism_dir = run_dir / "metabolism"
    metabolism_dir.mkdir(parents=True, exist_ok=True)
    store = FileSystemKnowledgeStore(metabolism_dir)
    cache_dir = metabolism_dir / "fulltext_cache"

    if args.report_dir:
        report_dir = Path(args.report_dir).absolute()
    else:
        report_dir = run_dir / "reports"
    config["paths"]["reports"] = report_dir
    exp = config["experiment"]
    topic = args.topic

    banner = "=" * 70
    logger.info(banner)
    logger.info("Lite: No-Diff | Topic: %s", topic)
    logger.info(banner)

    init_papers = fetch_phase_papers(
        topic,
        exp["init_start"],
        exp["init_end"],
        min(exp["init_max_papers"], exp["max_papers"]),
        exp["phase_search_window_months"],
        phase_name="Init",
    )
    evo_papers = fetch_phase_papers(
        topic,
        exp["evolution_start"],
        exp["evolution_end"],
        min(exp["evolution_max_papers"], exp["max_papers"]),
        # Search windows are never wider than one evolution step.
        max(1, min(exp["phase_search_window_months"], exp["evolution_step_months"])),
        phase_name="Evolution",
    )
    val_papers = fetch_phase_papers(
        topic,
        exp["validation_start"],
        exp["validation_end"],
        max(min(exp["validation_max_papers"], exp["max_papers"]), exp["validation_min_papers"]),
        exp["phase_search_window_months"],
        min_papers=exp["validation_min_papers"],
        phase_name="Validation",
    )
    all_papers = dedupe_and_sort_papers(init_papers + evo_papers + val_papers)
    if not all_papers:
        logger.error("No papers, aborting.")
        return

    logger.info("Papers: %d (init=%d, evo=%d, val=%d)", len(all_papers), len(init_papers), len(evo_papers), len(val_papers))
    init_tokens = await phase_init(topic, init_papers, store, cache_dir)
    evo_tokens = await phase_evolution(
        topic,
        evo_papers,
        store,
        cache_dir,
        exp["evolution_start"],
        exp["evolution_end"],
        exp["evolution_step_months"],
    )
    await phase_validation(all_papers, store, exp["validation_start"], exp["validation_end"], topic, "none", cache_dir)
    logger.info("Done. Total tokens: %d", init_tokens + evo_tokens)

    usage = {
        "init_tokens": init_tokens,
        "evolution_tokens": evo_tokens,
        "total_generation_tokens": init_tokens + evo_tokens,
    }
    usage_path = store.base_dir / "token_usage.json"
    usage_path.write_text(json.dumps(usage, indent=2), encoding="utf-8")

# Script entry point: only when executed directly, run the async pipeline to
# completion via asyncio.run().
if __name__ == "__main__":
    asyncio.run(run())
