"""
Pool Baseline — batch runner (control group for CKM).

Mirrors scripts/batch_run.py exactly, but:
  - Calls scripts/pool/eval_single.py instead of scripts/eval_single.py
  - Stores results under results/pool_<batch_id>/
  - Supports concurrent topic execution via --concurrency N
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import json
import os
import subprocess
import time
import re
import logging
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

from config import config
from topics import TOPICS, slugify_topic_name

logger = logging.getLogger("pool.batch")





def extract_metric(pattern: str, text: str, default=0):
    """Return the first capture group of *pattern* found in *text*.

    The search uses ``re.search``, so the pattern may match anywhere in
    *text*. When nothing matches, *default* is returned unchanged. A
    successful match yields the captured text (a ``str``, or ``None`` if
    group 1 did not participate in the match) — callers convert it to a
    number themselves.
    """
    found = re.search(pattern, text)
    if found is None:
        return default
    return found.group(1)


def is_batch_results_dir(path: Path) -> bool:
    """Tell whether *path* looks like a batch results directory.

    A directory qualifies when it holds a ``batch_summary.json`` entry or
    contains a ``topics`` sub-directory.
    """
    if (path / "batch_summary.json").exists():
        return True
    return (path / "topics").is_dir()


# Thread-safe summary writer: every write of the summary file goes through
# this lock so concurrent topic threads never interleave their dumps.
_summary_lock = threading.Lock()


def _save_summary(summary_file: Path, summary_data: list) -> None:
    """Atomically write *summary_data* as JSON to *summary_file*.

    The data is first dumped to a sibling ``<name>.tmp`` file and then moved
    over the target with ``os.replace``. If the process dies mid-write, the
    previous summary is therefore left intact instead of being truncated,
    which keeps a later resume (which parses this file) working.

    Args:
        summary_file: Destination path of the summary JSON.
        summary_data: JSON-serialisable list of per-topic metric dicts.

    Raises:
        OSError: If the temp file cannot be written or swapped in.
        TypeError: If *summary_data* contains non-serialisable values.
    """
    target = Path(summary_file)
    # Same directory as the target so the final rename stays on one filesystem.
    tmp_file = target.with_name(target.name + ".tmp")
    with _summary_lock:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(summary_data, f, indent=2, ensure_ascii=False)
        # A stale temp file from a failed dump is simply overwritten next time.
        os.replace(tmp_file, target)


def run_single_topic(
    topic: dict,
    scripts_dir: Path,
    current_batch_dir: Path,
    all_reports_dir: Path,
    batch_log_file: Path,
    summary_file: Path,
    summary_data: list,
    ablation: str,
    python: str,
) -> None:
    """Run the evaluation subprocess for one topic and record its metrics.

    Designed to be called from a thread pool. The subprocess output
    (stdout and stderr merged) is streamed into
    ``<batch>/topics/<slug>/execution.log``. On success the metrics parsed
    from the evaluation report, plus token usage if present, are appended to
    *summary_data* and the summary file is rewritten.

    Any ``Exception`` is caught here: it is logged, a one-line error is
    appended to *batch_log_file*, and the topic is NOT added to
    *summary_data*.

    Args:
        topic: Topic record; the ``"slug"`` and ``"name"`` keys are read.
        scripts_dir: Directory containing ``eval_single.py``.
        current_batch_dir: Root directory of the current batch.
        all_reports_dir: Directory receiving a ``<slug>_report.md`` copy.
        batch_log_file: Batch-wide log file; only failures are appended here.
        summary_file: Path of the JSON summary to rewrite after success.
        summary_data: Shared list of metric dicts (mutated under the lock).
        ablation: Ablation mode passed through to the subprocess.
        python: Interpreter executable used to launch the subprocess.
    """
    slug = topic["slug"]
    name = topic["name"]

    topic_dir = current_batch_dir / "topics" / slug
    topic_metabolism = topic_dir / "metabolism"
    topic_reports = topic_dir / "reports"
    topic_metabolism.mkdir(parents=True, exist_ok=True)
    topic_reports.mkdir(parents=True, exist_ok=True)

    # Per-topic log file for concurrent output isolation
    topic_log_file = topic_dir / "execution.log"

    cmd = [
        python, str(scripts_dir / "eval_single.py"),
        name,
        "--metabolism_dir", str(topic_metabolism),
        "--report_dir", str(topic_reports),
        "--ablation", ablation,
    ]

    # (summary key, report regex, converter) — order defines key order in the
    # metrics dict. A metric missing from the report falls back to 0.
    report_metrics = (
        ("yield", r"Hypothesis Yield\s*\|\s*(\d+)", int),
        ("temporal_lead", r"Avg Temporal Lead\s*\|\s*(\d+)", int),
        ("cross_domain", r"Avg Cross-domain Score\s*\|\s*([\d.]+)", float),
        ("novelty", r"Avg Novelty Judge Score\s*\|\s*([\d.]+)", float),
        ("hit_rate", r"Predictive Hit Rate\s*\|\s*([\d.]+)", float),
        ("best_match_score", r"Avg Best Match Score\s*\|\s*([\d.]+)", float),
        ("precision_at_3", r"Precision@3.*?\|\s*([\d.]+)", float),
        ("unique_hit_papers", r"Total Unique Hit Papers\s*\|\s*(\d+)", int),
    )

    logger.info("[START] %s", name)
    start_time = time.time()

    try:
        env = os.environ.copy()
        env["PYTHONIOENCODING"] = "utf-8"
        env["PYTHONUNBUFFERED"] = "1"

        # The context manager closes the pipe and reaps the child on every
        # exit path. errors="replace" keeps a stray non-UTF-8 byte in the
        # child's output from aborting the whole topic.
        with subprocess.Popen(
            cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, encoding="utf-8", errors="replace",
        ) as process:
            try:
                # Stream the child's output into the per-topic log
                with open(topic_log_file, "w", encoding="utf-8") as topic_f:
                    topic_f.write(f">>> Topic: {name}\n")
                    topic_f.write(f">>> Started: {datetime.now().isoformat()}\n\n")
                    for line in process.stdout:
                        topic_f.write(line)
            except BaseException:
                # Logging failed or we were interrupted: do not leave an
                # orphaned evaluation running in the background.
                process.kill()
                raise

        duration = time.time() - start_time

        if process.returncode != 0:
            raise subprocess.CalledProcessError(process.returncode, cmd)

        # Extract metrics from report
        safe_title = name.replace(" ", "_")
        suffix = f"_Ablation_{ablation}" if ablation != "none" else ""
        report_path = topic_reports / f"{safe_title}{suffix}_Evaluation_Report.md"

        metrics = {
            "slug": slug,
            "name": name,
            "duration": duration,
            "timestamp": datetime.now().isoformat(),
        }

        if report_path.exists():
            content = report_path.read_text(encoding="utf-8")
            (all_reports_dir / f"{slug}_report.md").write_text(content, encoding="utf-8")

            for key, pattern, convert in report_metrics:
                metrics[key] = convert(extract_metric(pattern, content))
        else:
            # The topic is still recorded below, just without report metrics.
            logger.warning("Report not found for %s: %s", name, report_path)

        # Read token usage if available
        token_file = topic_metabolism / "token_usage.json"
        if token_file.exists():
            token_data = json.loads(token_file.read_text(encoding="utf-8"))
            metrics["init_tokens"] = token_data.get("init_tokens", 0)
            metrics["evolution_tokens"] = token_data.get("evolution_tokens", 0)
            metrics["total_generation_tokens"] = token_data.get("total_generation_tokens", 0)

        with _summary_lock:
            summary_data.append(metrics)
        _save_summary(summary_file, summary_data)

        logger.info("[DONE] %s — %.1f min, yield=%s, hit_rate=%s",
                    name, duration / 60,
                    metrics.get("yield", "?"), metrics.get("hit_rate", "?"))

    except Exception as e:
        duration = time.time() - start_time
        logger.error("[FAIL] %s — %.1f min: %s", name, duration / 60, e)
        with open(batch_log_file, "a", encoding="utf-8") as log_f:
            log_f.write(f"ERROR [{name}]: {e}\n")


def run_batch():
    """Parse CLI options and run every pending topic of a pool batch.

    The batch directory is ``results/<batch_id>`` (a ``pool_`` prefix is added
    when missing). Without ``--batch_id`` the last ``pool_*`` results
    directory in sorted name order is resumed, or a new timestamped one is
    created. Topics whose slug already appears in ``batch_summary.json`` are
    skipped.

    With ``--concurrency 1`` topics run one after another. Otherwise up to
    ``concurrency`` topics run in a thread pool; after a topic finishes, the
    next one is submitted following a cooldown, unless the finished topic was
    near-instant or there is nothing left to submit.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Pool Baseline Batch Runner")
    parser.add_argument("--limit", type=int, help="Limit the number of topics to run")
    parser.add_argument("--concurrency", type=int, default=1,
                        help="Number of topics to run concurrently (default: 1)")
    parser.add_argument("--ablation", choices=["none", "shuffled"], default="none")
    parser.add_argument("--batch_id", type=str, default=None,
                        help="Batch ID. If omitted, resumes latest pool_ batch or creates new.")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    scripts_dir = Path(__file__).parent.absolute()       # scripts/pool/
    project_dir = scripts_dir.parent.parent              # ckm-eval/
    python = sys.executable
    batch_root = project_dir / "results"

    # Resolve batch directory with pool_ prefix
    if args.batch_id:
        batch_id = args.batch_id if args.batch_id.startswith("pool_") else f"pool_{args.batch_id}"
        current_batch_dir = batch_root / batch_id
    elif batch_root.exists():
        existing = sorted([
            d for d in batch_root.iterdir()
            if d.is_dir() and d.name.startswith("pool_") and is_batch_results_dir(d)
        ])
        if existing:
            # Sorted by path, so the last entry is the latest timestamped batch
            current_batch_dir = existing[-1]
            logger.info("Resuming existing pool batch: %s", current_batch_dir.name)
        else:
            current_batch_dir = batch_root / ("pool_" + datetime.now().strftime("%Y%m%d_%H%M%S"))
    else:
        current_batch_dir = batch_root / ("pool_" + datetime.now().strftime("%Y%m%d_%H%M%S"))

    current_batch_dir.mkdir(parents=True, exist_ok=True)
    all_reports_dir = current_batch_dir / "all_reports"
    all_reports_dir.mkdir(exist_ok=True)

    summary_file = current_batch_dir / "batch_summary.json"
    batch_log_file = current_batch_dir / "batch_execution.log"

    # Load existing progress for resume
    summary_data = []
    if summary_file.exists():
        with open(summary_file, "r", encoding="utf-8") as f:
            summary_data = json.load(f)
    completed_slugs = {item["slug"] for item in summary_data}

    topics_to_run = TOPICS[:args.limit] if args.limit else TOPICS
    pending_topics = [t for t in topics_to_run if t["slug"] not in completed_slugs]

    concurrency = max(1, args.concurrency)

    logger.info("Pool Batch [%s]: %d topics total, %d pending, concurrency=%d, ablation=%s",
                current_batch_dir.name, len(topics_to_run), len(pending_topics),
                concurrency, args.ablation)

    if not pending_topics:
        logger.info("All topics already completed.")
        return

    with open(batch_log_file, "a", encoding="utf-8") as log_f:
        log_f.write(f"\n--- Pool Batch started at {datetime.now().isoformat()} "
                    f"(concurrency={concurrency}) ---\n")

    if concurrency == 1:
        # Sequential mode — topics run one after another in this thread;
        # subprocess output goes to each topic's own execution.log.
        for topic in pending_topics:
            run_single_topic(
                topic, scripts_dir, current_batch_dir, all_reports_dir,
                batch_log_file, summary_file, summary_data,
                args.ablation, python,
            )
    else:
        # Concurrent mode — submit up to `concurrency` topics at a time.
        # When one finishes, wait COOLDOWN before submitting the next,
        # so arXiv requests from different topics stay spread out.
        COOLDOWN_S = 60
        FAST_TOPIC_THRESHOLD_S = 10
        total = len(pending_topics)
        next_index = 0          # position of the next topic to submit
        active_futures = {}     # future -> (position, topic)
        run_durations = {}      # position -> seconds the topic actually ran

        with ThreadPoolExecutor(max_workers=concurrency) as executor:

            def _timed_run(position, topic):
                """Run one topic in a worker and record how long it took."""
                began = time.monotonic()
                try:
                    run_single_topic(
                        topic, scripts_dir, current_batch_dir, all_reports_dir,
                        batch_log_file, summary_file, summary_data,
                        args.ablation, python,
                    )
                finally:
                    # Written before the future completes, so the main thread
                    # always sees it once the future is reported as done.
                    run_durations[position] = time.monotonic() - began

            def _submit_next():
                """Submit the next pending topic, return True if submitted."""
                nonlocal next_index
                if next_index >= total:
                    return False
                position = next_index
                topic = pending_topics[position]
                next_index += 1
                f = executor.submit(_timed_run, position, topic)
                active_futures[f] = (position, topic)
                logger.info("[QUEUED] %s (%d/%d)", topic["name"], position + 1, total)
                return True

            # Seed initial batch
            for _ in range(min(concurrency, total)):
                _submit_next()

            # Process completions and refill slots
            while active_futures:
                # Block until one future is done; handle them one at a time
                finished = next(as_completed(active_futures))
                position, topic = active_futures.pop(finished)
                try:
                    finished.result()
                except Exception as e:
                    logger.error("[FAIL] %s: unhandled: %s", topic["name"], e)

                # Real run time of the topic, unaffected by time this loop
                # spent sleeping in earlier cooldowns.
                elapsed = run_durations.pop(position, 0.0)

                # Cooldown only if another topic is about to start and the
                # finished one actually ran (not a no-paper instant finish).
                if next_index < total and elapsed > FAST_TOPIC_THRESHOLD_S:
                    logger.info("Cooling down %ds before next topic...", COOLDOWN_S)
                    time.sleep(COOLDOWN_S)

                _submit_next()

    logger.info("Pool Batch complete. Summary: %s", summary_file)


# Script entry point: start the batch only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_batch()
