"""
Ablation 1: Shuffled Timestamps — CKM with randomized paper order.

Identical to CKM except: paper timestamps are shuffled before Phase 2,
destroying the temporal signal while keeping everything else intact.
Uses CKM hypothesis engine for strict single-variable ablation.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import asyncio
import argparse
import logging
import random
import re
from datetime import datetime

from config import config
from core.store import FileSystemKnowledgeStore
import json

from core.engines import (
    run_read_engine, run_init_engine, run_topic_update_engine, run_topic_discover_engine,
    run_ckm_hypothesis_engine, run_trigger_detection,
    get_embedding,
)
from tools.arxiv_search import search_arxiv_topic
from tools.arxiv_fulltext import get_paper_content_record

# Module-level logger; run() installs the handler/format via logging.basicConfig.
logger = logging.getLogger("ablation1.main")


# ---------------------------------------------------------------------------
# Helpers — identical to CKM
# ---------------------------------------------------------------------------

def add_months(date_str: str, months: int) -> str:
    """Shift a ``YYYY-MM-DD`` date by ``months`` and snap it to day 1.

    The day of the input is discarded: the result is always the first day of
    the target month, formatted as ``YYYY-MM-DD``. ``months`` may be negative.
    """
    parsed = datetime.strptime(date_str, "%Y-%m-%d")
    # Work with a zero-based month index so divmod yields the year carry
    # and the month-within-year in one step (also for negative offsets).
    year_carry, month_index = divmod(parsed.month - 1 + months, 12)
    shifted = parsed.replace(
        year=parsed.year + year_carry,
        month=month_index + 1,
        day=1,
    )
    return shifted.strftime("%Y-%m-%d")


def build_phase_windows(start_date, end_date, window_months):
    """Split ``[start_date, end_date)`` into consecutive search windows.

    Dates are ``YYYY-MM-DD`` strings and are compared as strings. Each window
    is a ``(window_start, window_end)`` tuple; window ends come from
    ``add_months`` (i.e. the first day of a later month) and the final window
    is clamped to ``end_date``.

    Args:
        start_date: Inclusive start of the range.
        end_date: End of the range; an empty list is returned when
            ``start_date >= end_date``.
        window_months: Width of each window in months; must be at least 1.

    Returns:
        List of ``(start, end)`` string tuples covering the range in order.

    Raises:
        ValueError: if ``window_months`` is less than 1 for a non-empty range.
            With such a step ``add_months`` cannot move past the current
            date, so the loop would never terminate.
    """
    if start_date >= end_date:
        return []
    if window_months < 1:
        raise ValueError(f"window_months must be >= 1, got {window_months!r}")

    windows = []
    current = start_date
    while current < end_date:
        # Clamp the last window so it never extends past the requested range.
        next_date = min(add_months(current, window_months), end_date)
        windows.append((current, next_date))
        current = next_date
    return windows


def allocate_window_budgets(total_budget, n_windows):
    """Spread ``total_budget`` over ``n_windows`` slots as evenly as possible.

    Every slot receives the floor share; the leftover units go one each to
    the earliest slots. Returns an empty list when either argument is not
    positive.
    """
    if total_budget <= 0 or n_windows <= 0:
        return []

    share, leftover = divmod(total_budget, n_windows)
    budgets = []
    for slot in range(n_windows):
        # The first `leftover` slots absorb the remainder, one unit each.
        bonus = 1 if slot < leftover else 0
        budgets.append(share + bonus)
    return budgets


def dedupe_and_sort_papers(papers):
    """Drop repeated ``arxiv_id`` entries and order the rest by ``published``.

    The first occurrence of each ``arxiv_id`` wins. The returned list holds
    the original paper objects (not copies), sorted ascending by their
    ``published`` value; ties keep first-seen order.
    """
    first_seen = {}
    for entry in papers:
        paper_id = entry["arxiv_id"]
        if paper_id not in first_seen:
            first_seen[paper_id] = entry

    ordered = list(first_seen.values())
    ordered.sort(key=lambda entry: entry["published"])
    return ordered


async def resolve_paper_contents(papers, cache_dir, concurrency, phase_name):
    """Fetch a content record for every paper, with bounded concurrency.

    Each paper is looked up through ``get_paper_content_record`` using its
    ``arxiv_id`` and ``abstract``; timeout, retry count and retry delay are
    read from ``config["experiment"]`` (defaults 30, 2 and 2, each clamped
    to a sane minimum). At most ``max(1, concurrency)`` lookups run at once.

    Returns a list of shallow copies of the input papers, in input order,
    each extended with ``content``, ``content_source`` and
    ``counted_fulltext`` taken from the record. A summary line is logged
    under ``phase_name``.
    """
    settings = config["experiment"]
    gate = asyncio.Semaphore(max(1, concurrency))

    async def _resolve_one(entry):
        async with gate:
            fetch_options = {
                "timeout_s": max(1, settings.get("fulltext_timeout_s", 30)),
                "retries": max(1, settings.get("fulltext_retries", 2)),
                "retry_delay_s": max(0, settings.get("fulltext_retry_delay_s", 2)),
            }
            record = await get_paper_content_record(
                entry["arxiv_id"], entry["abstract"], cache_dir, **fetch_options
            )
        # Copy so the caller's paper dict is not modified.
        return {
            **entry,
            "content": record["content"],
            "content_source": record["source"],
            "counted_fulltext": record["counted_fulltext"],
        }

    resolved_papers = await asyncio.gather(*(_resolve_one(entry) for entry in papers))
    total = len(resolved_papers)
    with_fulltext = len([entry for entry in resolved_papers if entry["counted_fulltext"]])
    logger.info("[%s] Resolved: total=%d, fulltext=%d, abstract_only=%d",
                phase_name, total, with_fulltext, total - with_fulltext)
    return resolved_papers


def fetch_phase_papers(topic, start_date, end_date, total_budget, window_months, min_papers=0, phase_name="Search"):
    """Search arXiv window by window and return the de-duplicated, sorted hits.

    The date range is cut into windows of ``window_months`` months, and the
    larger of ``total_budget`` and ``min_papers`` is split across them.
    Windows with a zero budget are skipped. A warning is logged when
    ``min_papers`` is set and fewer unique papers were found; the papers are
    returned either way.
    """
    spans = build_phase_windows(start_date, end_date, window_months)
    quotas = allocate_window_budgets(max(total_budget, min_papers), len(spans))

    gathered = []
    for (span_start, span_end), quota in zip(spans, quotas):
        if quota > 0:
            hits = search_arxiv_topic(topic, quota, start_date=span_start, end_date=span_end)
            logger.info("[%s] %s ~ %s: fetched %d (budget=%d)", phase_name, span_start, span_end, len(hits), quota)
            gathered.extend(hits)

    unique_sorted = dedupe_and_sort_papers(gathered)
    if min_papers and min_papers > len(unique_sorted):
        logger.warning("[%s] Only %d papers (minimum=%d)", phase_name, len(unique_sorted), min_papers)
    return unique_sorted


def display_topic_name(fn):
    """Turn a knowledge file name into a human-readable topic label.

    The extension and any directory part are dropped, every ``topic-``
    occurrence is removed and remaining hyphens become spaces. The index
    file (stem ``_index``) maps to an empty string.
    """
    base = Path(fn).stem
    if base != "_index":
        without_marker = base.replace("topic-", "")
        return " ".join(without_marker.split("-"))
    return ""


# Phase 1: Init — identical to CKM
async def phase_init(topic, papers, store, cache_dir):
    """Bootstrap the knowledge store from the init-phase papers.

    Resolves paper contents, keeps the papers whose record is flagged
    ``counted_fulltext``, runs the read engine on each of them and feeds the
    joined extractions to the init engine. The files the init engine returns
    are written to the store together with their embeddings.

    ``config.json`` in the store's base directory is written with
    ``currentDay`` 0 before the init engine runs and updated to 1 once the
    files and embeddings have been saved.

    Returns:
        The ``tokens`` value reported by the init engine, or 0 when
        ``papers`` is empty.

    Raises:
        RuntimeError: if no paper counted as fulltext, or if the init engine
            produced no file other than ``_index.md``.
    """
    if not papers:
        return 0

    resolved = await resolve_paper_contents(papers, cache_dir, 12, "Init")
    usable = [entry for entry in resolved if entry["counted_fulltext"]]
    if not usable:
        raise RuntimeError("Init: zero fulltext papers")

    # At most six read-engine calls in flight at a time.
    read_gate = asyncio.Semaphore(6)

    async def _read(entry):
        async with read_gate:
            return await run_read_engine(
                entry["title"], entry["arxiv_id"], entry["published"], entry["content"]
            )

    extractions = await asyncio.gather(*(_read(entry) for entry in usable))

    # Record which papers were consumed before the init engine is invoked.
    state_path = store.base_dir / "config.json"
    initial_state = {
        "topic": topic,
        "processed_ids": [entry["arxiv_id"] for entry in usable],
        "currentDay": 0,
    }
    state_path.write_text(json.dumps(initial_state, indent=2), encoding="utf-8")

    joined_extractions = "\n\n---\n\n".join(extractions)
    result = await run_init_engine(topic, joined_extractions, len(usable))
    operations = result["operations"]
    if all(op["fileName"] == "_index.md" for op in operations):
        raise RuntimeError(f"Init: no topic files: {result.get('error')}")

    embeddings = store.get_embeddings_dict()
    for op in operations:
        store.write_knowledge_file(op["fileName"], op["fileContent"])
    vectors = await asyncio.gather(*(get_embedding(op["fileContent"]) for op in operations))
    for op, vector in zip(operations, vectors):
        embeddings[op["fileName"]] = vector
    store.save_embeddings_dict(embeddings)

    # Re-read the state file from disk and mark the init step as done.
    state = json.loads(state_path.read_text(encoding="utf-8"))
    state["currentDay"] = 1
    state_path.write_text(json.dumps(state, indent=2), encoding="utf-8")

    logger.info("[Init] Complete: %d files, %d tokens", len(operations), result["tokens"])
    return result["tokens"]


# Phase 2: Evolution — CKM with SHUFFLED timestamps
async def phase_evolution(topic, papers, store, cache_dir, evo_start, evo_end, step_months):
    """Run the evolution phase window by window on timestamp-shuffled papers.

    The ablation: the ``published`` values of the evolution papers are
    permuted at random across the papers (using the global ``random`` state;
    seed it beforehand for a reproducible run). The shuffle is applied to
    shallow copies, so the caller's list and paper dicts keep their real
    timestamps and can safely be reused afterwards.

    For every window of ``step_months`` months between ``evo_start`` and
    ``evo_end`` the function then:

    1. picks the papers whose (shuffled) date falls in the window, newest
       first, capped at ``evolution_window_paper_limit``; overflow is kept
       in a deferred pool that can backfill windows without usable papers
       when ``evolution_backfill_on_empty`` is enabled;
    2. resolves contents and keeps papers flagged ``counted_fulltext``;
    3. runs the topic update engine on every knowledge file except
       ``_index.md`` and the topic discover engine, saving files and
       embeddings;
    4. runs trigger detection on the knowledge content before/after;
    5. runs the CKM hypothesis engine and stores each hypothesis with a
       zero-padded running number.

    Args:
        topic: Research topic passed to the hypothesis engine.
        papers: Paper dicts with at least ``arxiv_id``, ``title``,
            ``abstract`` and ``published``. Not modified.
        store: Knowledge store used for files, embeddings, logs, hypotheses.
        cache_dir: Directory handed to the content resolver.
        evo_start: Inclusive start date, ``YYYY-MM-DD``.
        evo_end: Exclusive end date, ``YYYY-MM-DD``.
        step_months: Window width in months; must be at least 1.

    Returns:
        Sum of the ``tokens`` values reported by the update, discover,
        trigger and hypothesis engines; 0 when ``papers`` is empty.

    Raises:
        ValueError: if ``step_months`` is less than 1 while
            ``evo_start < evo_end``; the window loop could never advance.
    """
    if not papers:
        return 0
    if evo_start < evo_end and step_months < 1:
        # add_months() would not move current_date forward: infinite loop.
        raise ValueError(f"step_months must be >= 1, got {step_months!r}")

    # *** ABLATION: shuffle timestamps ***
    # Work on copies: the caller's dicts are shared with the paper list that
    # is later used for validation, which must see the real timestamps.
    logger.info("[Ablation1] Shuffling paper timestamps")
    dates = [p["published"] for p in papers]
    random.shuffle(dates)
    papers = sorted(
        ({**p, "published": d} for p, d in zip(papers, dates)),
        key=lambda x: x["published"],
    )

    exp_cfg = config["experiment"]
    window_limit = max(1, exp_cfg.get("evolution_window_paper_limit", 8))
    backfill_on_empty = exp_cfg.get("evolution_backfill_on_empty", True)
    hyp_max = max(1, exp_cfg.get("hypothesis_max_per_window", 3))
    total_tokens = 0
    current_date = evo_start
    deferred = []          # papers that did not fit into their own window
    window_summaries = []  # one entry per processed window, fed to the hypothesis engine
    window_index = 0

    def sort_rf(items):
        # "Recent first": newest published date at the front.
        return sorted(items, key=lambda x: x["published"], reverse=True)

    def _trajectory():
        # Bullet list of the per-window summaries gathered so far.
        if not window_summaries:
            return "(first window)"
        return "\n".join(f"- Period {ws['period']}: {ws['summary']}" for ws in window_summaries)

    def _existing_hyps():
        # One line per stored hypothesis: id plus the first 200 characters of
        # its "## Statement" section.
        files = sorted(store.dirs["hypotheses"].glob("*.md"))
        if not files:
            return "(none yet)"
        out = []
        for fp in files:
            c = fp.read_text(encoding="utf-8")
            m = re.search(r"-(\d{3})\.md$", fp.name)
            hid = f"H{m.group(1)}" if m else fp.stem
            sm = re.search(r"## Statement\s*([\s\S]*?)(?=##|\Z)", c)
            out.append(f"- {hid}: {sm.group(1).strip()[:200] if sm else '?'}")
        return "\n".join(out)

    while current_date < evo_end:
        next_date = add_months(current_date, step_months)
        period = current_date[:7]
        window_index += 1

        # --- Select this window's papers -------------------------------
        wp = sort_rf([p for p in papers if current_date <= p["published"] < next_date])
        fresh = wp[:window_limit]
        overflow = wp[window_limit:]
        if overflow:
            deferred.extend(overflow)
            deferred = sort_rf(deferred)
        borrowed = []
        if not fresh and backfill_on_empty and deferred:
            borrowed = deferred[:window_limit]
            deferred = deferred[window_limit:]
        window_papers = fresh or borrowed
        if not window_papers:
            window_summaries.append({"period": period, "summary": "No papers"})
            current_date = next_date
            continue

        # --- Resolve contents, backfilling once if nothing is usable ---
        resolved = await resolve_paper_contents(window_papers, cache_dir, min(8, len(window_papers)), f"Evo {period}")
        counted = [p for p in resolved if p["counted_fulltext"]]
        if not counted and backfill_on_empty and deferred and not borrowed:
            borrowed = deferred[:window_limit]
            deferred = deferred[window_limit:]
            resolved = await resolve_paper_contents(borrowed, cache_dir, min(8, len(borrowed)), f"Evo {period} bf")
            counted = [p for p in resolved if p["counted_fulltext"]]
        if not counted:
            window_summaries.append({"period": period, "summary": "No fulltext"})
            current_date = next_date
            continue

        parts = [f"### {p['title']}\n- arxiv_id: {p['arxiv_id']}\n- published: {p['published']}\n- content_source: {p['content_source']}\n\n{p['content']}" for p in counted]
        papers_text = "\n\n---\n\n".join(parts)
        store.write_log(f"window-{period}-ingest.md", f"# {period} — {len(counted)} papers (SHUFFLED)\n\n{papers_text}")

        # --- Update existing topic files -------------------------------
        kb = store.get_joined_knowledge_content()  # snapshot before the update
        kf = {fn: c for fn, c in store.get_knowledge_files().items() if fn != "_index.md"}
        emb = store.get_embeddings_dict()
        sem = asyncio.Semaphore(6)

        async def upd(fn, fc):
            async with sem:
                return await run_topic_update_engine(fn, fc, papers_text, period)

        res = await asyncio.gather(*[upd(fn, fc) for fn, fc in kf.items()])
        total_tokens += sum(r["tokens"] for r in res)
        for r in res:
            store.write_knowledge_file(r["fileName"], r["fileContent"])
        nvecs = await asyncio.gather(*[get_embedding(r["fileContent"]) for r in res])
        for r, v in zip(res, nvecs):
            emb[r["fileName"]] = v
        store.save_embeddings_dict(emb)

        # --- Discover new topics ---------------------------------------
        disc = await run_topic_discover_engine(list(kf.keys()), papers_text, period)
        total_tokens += disc["tokens"]
        if disc["operations"]:
            dvecs = await asyncio.gather(*[get_embedding(op["fileContent"]) for op in disc["operations"]])
            for op, v in zip(disc["operations"], dvecs):
                store.write_knowledge_file(op["fileName"], op["fileContent"])
                emb[op["fileName"]] = v
            store.save_embeddings_dict(emb)

        # --- Trigger detection on before/after knowledge ---------------
        ka = store.get_joined_knowledge_content()
        tr = await run_trigger_detection(kb, ka, papers_text, period, 0)
        total_tokens += tr["tokens"]
        ct = tr["change_type"]
        summary = tr.get("reason", "").strip() or f"{len(counted)} papers"
        store.write_log(f"change-{period}.md", f"# Change — {period}\nType: {ct}\n")

        # --- Hypothesis generation -------------------------------------
        si = store.count_hypotheses() + 1  # first free hypothesis number
        hr = await run_ckm_hypothesis_engine(
            topic=topic,
            knowledge_content=ka,
            papers_text=papers_text,
            time_period=period,
            max_hypotheses=hyp_max,
            evolution_trajectory=_trajectory(),
            existing_hypotheses=_existing_hyps(),
            trigger_type=ct,
            trigger_reason=tr.get("reason", ""),
            n_windows=window_index,
        )
        total_tokens += hr["tokens"]
        for off, h in enumerate(hr.get("hypotheses", [])):
            hid = str(si + off).zfill(3)
            # Rewrite the heading so it carries the store-wide running number.
            store.write_hypothesis(f"{period}-{hid}", re.sub(r"^# Hypothesis H\d+", f"# Hypothesis H{hid}", h["content"], count=1))
            logger.info("[Evo] %s: H%s (shuffled)", period, hid)

        window_summaries.append({"period": period, "summary": summary, "change_type": ct, "hypotheses_generated": len(hr.get("hypotheses", []))})
        current_date = next_date

    logger.info("[Evolution] Complete. Tokens: %d", total_tokens)
    return total_tokens


async def phase_validation(all_papers, store, val_start, val_end, topic, ablation="none", cache_dir=None):
    """Delegate the validation phase to ``tools.calculate_metrics``.

    All arguments are forwarded unchanged to ``calculate_metrics``; nothing
    is returned.
    """
    # Imported at call time rather than at module import.
    from tools.calculate_metrics import calculate_metrics

    metric_args = (all_papers, val_start, val_end, topic, ablation, store, cache_dir)
    await calculate_metrics(*metric_args)


async def run():
    """Command-line driver for the shuffled-timestamp ablation.

    Parses the topic and optional output directories, prepares the knowledge
    store and report path, fetches the init / evolution / validation papers,
    runs the three phases in order and finally writes ``token_usage.json``
    into the store's base directory. Returns early (after logging an error)
    when no papers were found at all.
    """
    cli = argparse.ArgumentParser(description="Ablation 1: Shuffled Timestamps")
    cli.add_argument("topic", nargs="?", default="AI for software engineering")
    cli.add_argument("--metabolism_dir", type=str)
    cli.add_argument("--report_dir", type=str)
    args = cli.parse_args()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    # Default output location: <project>/results/ablation1_<timestamp>/
    project_dir = Path(__file__).parent.parent.parent.absolute()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = project_dir / "results" / ("ablation1_" + stamp)

    if args.metabolism_dir:
        metabolism_dir = Path(args.metabolism_dir).absolute()
    else:
        metabolism_dir = run_dir / "metabolism"
    metabolism_dir.mkdir(parents=True, exist_ok=True)
    store = FileSystemKnowledgeStore(metabolism_dir)
    cache_dir = metabolism_dir / "fulltext_cache"

    if args.report_dir:
        report_dir = Path(args.report_dir).absolute()
    else:
        report_dir = run_dir / "reports"
    config["paths"]["reports"] = report_dir

    exp = config["experiment"]
    topic = args.topic

    banner = "=" * 70
    logger.info(banner)
    logger.info("Ablation 1: Shuffled | Topic: %s", topic)
    logger.info(banner)

    init_papers = fetch_phase_papers(
        topic,
        exp["init_start"],
        exp["init_end"],
        min(exp["init_max_papers"], exp["max_papers"]),
        exp["phase_search_window_months"],
        phase_name="Init",
    )
    # Evolution search windows are never wider than one evolution step.
    evo_papers = fetch_phase_papers(
        topic,
        exp["evolution_start"],
        exp["evolution_end"],
        min(exp["evolution_max_papers"], exp["max_papers"]),
        max(1, min(exp["phase_search_window_months"], exp["evolution_step_months"])),
        phase_name="Evolution",
    )
    val_papers = fetch_phase_papers(
        topic,
        exp["validation_start"],
        exp["validation_end"],
        max(min(exp["validation_max_papers"], exp["max_papers"]), exp["validation_min_papers"]),
        exp["phase_search_window_months"],
        min_papers=exp["validation_min_papers"],
        phase_name="Validation",
    )
    all_papers = dedupe_and_sort_papers(init_papers + evo_papers + val_papers)
    if not all_papers:
        logger.error("No papers, aborting.")
        return

    logger.info("Papers: %d (init=%d, evo=%d, val=%d)", len(all_papers), len(init_papers), len(evo_papers), len(val_papers))
    init_tok = await phase_init(topic, init_papers, store, cache_dir)
    evo_tok = await phase_evolution(
        topic,
        evo_papers,
        store,
        cache_dir,
        exp["evolution_start"],
        exp["evolution_end"],
        exp["evolution_step_months"],
    )
    await phase_validation(all_papers, store, exp["validation_start"], exp["validation_end"], topic, "none", cache_dir)
    total_tok = init_tok + evo_tok
    logger.info("Done. Total tokens: %d", total_tok)

    usage = {
        "init_tokens": init_tok,
        "evolution_tokens": evo_tok,
        "total_generation_tokens": total_tok,
    }
    (store.base_dir / "token_usage.json").write_text(
        json.dumps(usage, indent=2), encoding="utf-8"
    )

# Script entry point: execute the whole ablation pipeline via asyncio.run().
if __name__ == "__main__":
    asyncio.run(run())
