#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import re
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple, Callable, Optional

# Make the project root importable so the absolute `src.*` and `scripts.*`
# imports further down resolve when this file is run directly as a script.
# The project root is taken to be the parent of the directory holding this file.
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parent
if str(PROJECT_ROOT) not in sys.path:
    # Insert at the front of sys.path; skipped when the entry is already present.
    sys.path.insert(0, str(PROJECT_ROOT))

from src.a4s.llm_client import LLMClient
from src.a4s.orchestrator import Orchestrator
from src.a4s.schemas import OrchestratorConfig
from src.baselines.call_limited_client import CallLimitedClient
from src.baselines.single_agent import run_single_agent
from src.baselines.ts_tree_search import run_tree_search
from src.baselines.debate import run_debate
from scripts.evaluate_run import evaluate_run_dir


def ensure_env_loaded() -> None:
    """Best-effort load of environment variables from a ``.env`` file.

    Uses the ``dotenv`` package when it can be imported and passes
    ``override=False`` to ``load_dotenv``. Every failure, including the
    package not being installed, is swallowed so callers can always proceed
    with whatever environment already exists.
    """
    try:
        import dotenv  # type: ignore
    except Exception:
        # The dependency is optional; continue with the current environment.
        return
    try:
        dotenv.load_dotenv(dotenv.find_dotenv(), override=False)
    except Exception:
        # Loading is a convenience only; never let it abort the caller.
        return


def now_hms() -> str:
    """Return the current local time formatted as ``HH:MM:SS``."""
    current = datetime.now()
    return f"{current:%H:%M:%S}"


def log_print(line: str, log_file: Path | None = None) -> None:
    """Print a timestamped message and optionally append it to a log file.

    The message is prefixed with ``[HH:MM:SS]`` (from ``now_hms``) and
    terminated with a newline. It is always written to stdout and flushed
    immediately; when ``log_file`` is given, the same text is appended to it
    as UTF-8.

    Args:
        line: Message text without timestamp or trailing newline.
        log_file: Optional file to append to; its parent directory is created
            if it does not exist.
    """
    stamped = f"[{now_hms()}] {line}\n"
    print(stamped, end="", flush=True)
    if log_file is None:
        return
    # Create the log directory on demand so callers need not prepare it.
    log_file.parent.mkdir(parents=True, exist_ok=True)
    with log_file.open("a", encoding="utf-8") as sink:
        sink.write(stamped)


def start_heartbeat(log_file: Path, interval_sec: int = 30) -> Tuple[threading.Thread, threading.Event]:
    """Start a daemon thread that periodically logs a liveness message.

    Args:
        log_file: Log file passed to ``log_print`` for every heartbeat line.
        interval_sec: Seconds to wait between heartbeat messages.

    Returns:
        ``(thread, stop_event)``. Setting ``stop_event`` ends the heartbeat
        loop; the thread is already started when this function returns.
    """
    halt = threading.Event()

    def _pulse() -> None:
        while True:
            # Event.wait returns True once the event is set, False on timeout.
            if halt.wait(interval_sec):
                return
            try:
                log_print("Heartbeat: job is still running...", log_file)
            except Exception:
                # A failed heartbeat write must not kill the heartbeat thread.
                continue

    worker = threading.Thread(target=_pulse, name="heartbeat", daemon=True)
    worker.start()
    return worker, halt


def slugify_topic(topic: str, max_words: int = 6) -> str:
    """Turn a free-text topic into a short, lowercase, underscore-joined slug.

    Characters other than ASCII letters, digits, whitespace, ``_`` and ``-``
    are removed; the first ``max_words`` whitespace-separated words are
    lowercased and joined with underscores.

    Args:
        topic: Free-form topic text.
        max_words: Number of leading words to keep.

    Returns:
        The slug, or ``"topic"`` when nothing usable remains.
    """
    stripped = re.sub(r"[^A-Za-z0-9\s_-]", "", topic)
    # Maximal runs of non-whitespace are exactly the words of the cleaned text.
    tokens = re.findall(r"\S+", stripped)
    if not tokens:
        return "topic"
    slug = "_".join(tokens[:max_words]).lower()
    return slug if slug else "topic"


def unique_subdir(root: Path, base_name: str) -> Path:
    """Return a path under ``root`` that does not exist yet.

    ``root / base_name`` is returned when it is free; otherwise the suffixes
    ``-2``, ``-3``, ... are tried in order until an unused name is found.
    The directory itself is not created.
    """
    candidate = root / base_name
    suffix = 2
    while candidate.exists():
        candidate = root / f"{base_name}-{suffix}"
        suffix += 1
    return candidate


def _write_fallback_report(out_dir: Path, proposition: str, error: Exception) -> Tuple[str, Dict[str, Any]]:
    out_dir.mkdir(parents=True, exist_ok=True)
    logs_root = out_dir / "logs"

    scenario_text = ""
    try:
        scenario_path = logs_root / "scenario.md"
        if scenario_path.exists():
            scenario_text = scenario_path.read_text(encoding="utf-8")
    except Exception:
        pass

    round_summaries: List[str] = []
    if logs_root.exists():
        try:
            for rd in sorted(logs_root.glob("round_*")):
                s = rd / "summary.md"
                if s.exists():
                    round_summaries.append(s.read_text(encoding="utf-8"))
        except Exception:
            pass

    lines: List[str] = [
        f"# Agents4Sci Partial Report (Fallback)",
        "",
        f"Proposition: {proposition}",
        "",
        "This run did not complete due to an exception. The report below is assembled from available intermediate logs.",
        f"Error: {type(error).__name__}: {error}",
        "",
    ]
    if scenario_text:
        lines.extend(["## Scenario Refinement (partial)", "", scenario_text, ""])
    if round_summaries:
        lines.append("## Round Summaries (partial)")
        for idx, rs in enumerate(round_summaries, start=1):
            lines.extend(["", f"### Round {idx}", "", rs])
    if not scenario_text and not round_summaries:
        lines.append("No intermediate logs were found. Please rerun with a higher call budget or fewer rounds.")

    report_text = "\n".join(lines)
    (out_dir / "report.md").write_text(report_text, encoding="utf-8")

    structured: Dict[str, Any] = {
        "proposition": proposition,
        "status": "failed",
        "error": f"{type(error).__name__}: {error}",
        "rounds_found": len(round_summaries),
    }
    (out_dir / "structured.json").write_text(json.dumps(structured, ensure_ascii=False, indent=2), encoding="utf-8")
    return report_text, structured


def run_agents4sci_with_budget(
    proposition: str,
    out_dir: Path,
    max_calls: int,
    root_log: Path | None = None,
    rounds: Optional[int] = None,
    roles: Optional[List[str]] = None,
) -> Tuple[str, Dict[str, Any]]:
    """Run the Orchestrator on ``proposition`` through a ``CallLimitedClient``.

    The orchestrator module's ``LLMClient`` name is temporarily replaced by a
    factory that returns the wrapped client, and is restored afterwards even
    on failure. If ``Orchestrator.run`` raises, a fallback report is built
    from intermediate logs instead of propagating the error. In both cases
    ``report.md`` and ``structured.json`` are written to ``out_dir``.

    NOTE(review): the patch swaps a module global, so concurrent calls to this
    function from several threads would interfere with each other -- confirm
    callers never run it in parallel with itself.

    Args:
        proposition: Text handed to ``Orchestrator.run``.
        out_dir: Directory receiving the report files; created if missing.
        max_calls: Value passed as ``max_calls`` to ``CallLimitedClient``.
        root_log: Optional log file for progress messages.
        rounds: Override for ``OrchestratorConfig.rounds``; default when None.
        roles: Override for ``OrchestratorConfig.expert_roles``; default when None.

    Returns:
        ``(report, structured)`` from the orchestrator, or from the fallback
        report when the run failed.

    Raises:
        Exception: Errors raised while constructing the config, the clients or
            the ``Orchestrator`` itself are not caught here.
    """
    # Read the defaults from a single instance instead of one per field.
    defaults = OrchestratorConfig()
    cfg = OrchestratorConfig(
        rounds=defaults.rounds if rounds is None else rounds,
        expert_roles=defaults.expert_roles if roles is None else roles,
        dependency_map=defaults.dependency_map,
        model_id=defaults.model_id,
    )

    client = LLMClient(default_model=cfg.model_id)
    limited = CallLimitedClient(client, max_calls=max_calls)

    # Monkey-patch LLMClient inside orchestrator to use the limited client
    from src.a4s import orchestrator as orch_mod

    orig_llmclient_cls = orch_mod.LLMClient
    try:
        # Accept any call signature so the patch does not depend on how the
        # orchestrator spells its LLMClient(...) construction.
        orch_mod.LLMClient = lambda *args, **kwargs: limited  # type: ignore
        if root_log is not None:
            log_print("[a4s] Orchestrator initialized with call budget", root_log)
        orch = Orchestrator(cfg)
        try:
            report, structured = orch.run(proposition, out_dir=out_dir)
        except Exception as e:
            if root_log is not None:
                log_print(f"[a4s] Exception during run: {e}. Writing fallback report...", root_log)
            report, structured = _write_fallback_report(out_dir, proposition, e)
    finally:
        # Always undo the patch, including when Orchestrator(cfg) itself fails.
        orch_mod.LLMClient = orig_llmclient_cls

    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "report.md").write_text(report, encoding="utf-8")
    (out_dir / "structured.json").write_text(json.dumps(structured, ensure_ascii=False, indent=2), encoding="utf-8")
    return report, structured


def run_topic_all_models(
    topic_idx: int,
    topic: str,
    out_root: Path,
    max_calls: int,
    eval_model: str,
    parallel_models: bool,
    root_log: Path,
) -> None:
    """Run Agents4Sci and the three baselines for one topic, then evaluate.

    A fresh per-topic folder is created under ``out_root`` (named from the
    slugified topic). Each method writes into its own subfolder; progress is
    logged to ``progress.log`` inside the topic folder. A failure in one
    method is logged and does not stop the others. Afterwards ``summary.json``
    is written and ``evaluate_run_dir`` is invoked; an evaluation failure is
    logged, not raised.

    Args:
        topic_idx: 1-based topic number, used only in log prefixes.
        topic: Topic / proposition text given to every method.
        out_root: Root directory under which the topic folder is created.
        max_calls: Call budget passed to each method's ``CallLimitedClient``.
        eval_model: Model name forwarded to ``evaluate_run_dir``.
        parallel_models: Run the four methods in a thread pool when True,
            otherwise one after another.
        root_log: Accepted for interface compatibility; not used in this
            function (all logging goes to the per-topic log).
    """
    folder = unique_subdir(out_root, slugify_topic(topic, max_words=6))
    folder.mkdir(parents=True, exist_ok=True)
    tlog_path = folder / "progress.log"

    def tlog(message: str) -> None:
        log_print(f"[topic {topic_idx}] {message}", tlog_path)

    tlog(f"Topic: {topic}")

    # Align baseline model to Orchestrator default for fair comparison
    cfg_for_baselines = OrchestratorConfig()
    base_client = LLMClient(default_model=cfg_for_baselines.model_id)

    # Report the budget actually in effect rather than a hard-coded number.
    budget_note = f"(budget<={max_calls} calls)"

    def _run_a4s() -> None:
        out_dir = folder / "agents4sci_v2"
        tlog(f"Agents4Sci v2 started {budget_note}...")
        run_agents4sci_with_budget(topic, out_dir, max_calls=max_calls, root_log=tlog_path)
        tlog(f"Agents4Sci v2 finished -> {out_dir}")

    def _run_single() -> None:
        out_dir = folder / "baseline_single"
        tlog(f"Baseline Single-Agent started {budget_note}...")
        run_single_agent(topic, CallLimitedClient(base_client, max_calls=max_calls), out_dir)
        tlog(f"Baseline Single-Agent finished -> {out_dir}")

    def _run_tree() -> None:
        out_dir = folder / "baseline_tree"
        tlog(f"Baseline Tree-Search started {budget_note}...")
        run_tree_search(topic, CallLimitedClient(base_client, max_calls=max_calls), out_dir, breadth=2, depth=2)
        tlog(f"Baseline Tree-Search finished -> {out_dir}")

    def _run_debate() -> None:
        out_dir = folder / "baseline_debate"
        tlog(f"Baseline Debate started {budget_note}...")
        run_debate(topic, CallLimitedClient(base_client, max_calls=max_calls), out_dir, rounds=3)
        tlog(f"Baseline Debate finished -> {out_dir}")

    runs: List[Tuple[str, Callable[[], None]]] = [
        ("agents4sci_v2", _run_a4s),
        ("baseline_single", _run_single),
        ("baseline_tree", _run_tree),
        ("baseline_debate", _run_debate),
    ]

    if parallel_models:
        tlog("Launching all models in parallel...")
        with ThreadPoolExecutor(max_workers=len(runs)) as ex:
            fut_to_name = {ex.submit(fn): name for name, fn in runs}
            for fut in as_completed(fut_to_name):
                name = fut_to_name[fut]
                try:
                    fut.result()
                except Exception as e:
                    tlog(f"ERROR in {name}: {e}")
    else:
        for name, fn in runs:
            try:
                fn()
            except Exception as e:
                tlog(f"ERROR in {name}: {e}")

    # Output subfolders are named after the runs, so derive the paths from them.
    summary = {
        "topic": topic,
        "paths": {name: str(folder / name) for name, _ in runs},
    }
    (folder / "summary.json").write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
    tlog("summary.json written")

    try:
        tlog(f"Evaluation started (model={eval_model})...")
        evaluate_run_dir(folder, LLMClient(), eval_model=eval_model)
        tlog("Evaluation finished -> evaluation.json / evaluation.md")
    except Exception as e:
        tlog(f"Evaluation failed: {e}")


def aggregate_across_topics(out_root: Path) -> None:
    """Collect per-topic evaluation scores and write ``aggregate.md``.

    Every sub-directory of ``out_root`` that holds a readable
    ``evaluation.json`` containing a JSON object contributes one row to the
    per-topic table (the ``overall`` score per model) and its numeric metric
    values to the macro averages. Missing or non-numeric values are rendered
    as ``-`` and left out of the averages; a model with no scores for a metric
    also shows ``-`` rather than a made-up ``0.00``. The "Best overall" line
    is emitted only when at least one model has an ``overall`` score.

    Args:
        out_root: Experiment root; ``aggregate.md`` is written directly in it.
    """
    models = ["agents4sci_v2", "baseline_single", "baseline_tree", "baseline_debate"]
    # Updated 5D rubric metrics
    metrics = [
        "rigor_traceability",
        "integration_causality",
        "feasibility_minimality",
        "uncertainty_adaptation",
        "decisionability",
        "overall",
    ]

    def is_score(value: Any) -> bool:
        # bool is a subclass of int; a JSON true/false is not a score.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def avg(xs: List[float]) -> float:
        return sum(xs) / len(xs) if xs else 0.0

    def fmt_avg(xs: List[float]) -> str:
        return f"{avg(xs):.2f}" if xs else "-"

    aggregate: Dict[str, Dict[str, List[float]]] = {m: {k: [] for k in models} for m in metrics}
    per_topic_rows: List[str] = [
        "| Topic | agents4sci_v2 | baseline_single | baseline_tree | baseline_debate |",
        "|---|---:|---:|---:|---:|",
    ]

    for d in sorted(p for p in out_root.iterdir() if p.is_dir()):
        eval_path = d / "evaluation.json"
        if not eval_path.exists():
            continue
        try:
            data = json.loads(eval_path.read_text(encoding="utf-8"))
        except Exception:
            continue
        if not isinstance(data, dict):
            # Valid JSON of the wrong shape is treated like an unreadable file.
            continue
        row = [d.name]
        for model in models:
            mdata = data.get(model)
            if not isinstance(mdata, dict):
                # Absent or malformed entry: treat as "no scores".
                mdata = {}
            overall = mdata.get("overall")
            row.append(f"{overall:.2f}" if is_score(overall) else "-")
            for met in metrics:
                val = mdata.get(met)
                if is_score(val):
                    aggregate[met][model].append(float(val))
        per_topic_rows.append("| " + " | ".join(row) + " |")

    macro_lines: List[str] = [
        "",
        "### Macro Averages (5D rubric, across topics)",
        "",
        "| Metric | agents4sci_v2 | baseline_single | baseline_tree | baseline_debate |",
        "|---|---:|---:|---:|---:|",
    ]
    for met in metrics:
        cells = " | ".join(fmt_avg(aggregate[met][m]) for m in models)
        macro_lines.append("| " + met + " | " + cells + " |")

    # Only models that actually have an overall score compete for "best".
    overall_avgs = {m: avg(scores) for m, scores in aggregate["overall"].items() if scores}
    best_line = ""
    if overall_avgs:
        best_model = max(overall_avgs, key=overall_avgs.__getitem__)
        best_line = f"\n**Best overall (macro average)**: `{best_model}` with {overall_avgs[best_model]:.2f}.\n"

    content = [
        f"## Aggregate Results for {out_root.name}",
        "",
        *per_topic_rows,
        *macro_lines,
        best_line,
        "",
    ]
    (out_root / "aggregate.md").write_text("\n".join(content), encoding="utf-8")


def repair_missing_a4s(out_root: Path | str, eval_model: str = "doubao-seed-1-6-250615", max_calls: int = 25) -> None:
    """Re-run Agents4Sci for topic folders that lack ``agents4sci_v2/report.md``.

    Each sub-directory of ``out_root`` is visited in sorted order. Folders
    that already have the report, have no ``summary.json``, or whose summary
    has no ``topic`` are skipped. For the rest, Agents4Sci is re-run with the
    given call budget and the folder is re-evaluated. Progress and failures
    are logged to ``progress.log`` in ``out_root``; a failure on one folder
    does not stop the others.

    Args:
        out_root: Experiment root containing one folder per topic.
        eval_model: Model name forwarded to ``evaluate_run_dir``.
        max_calls: Call budget forwarded to ``run_agents4sci_with_budget``.
    """
    # Ensure environment (.env) is loaded for API keys
    ensure_env_loaded()
    root = Path(out_root).resolve()
    progress_log = root / "progress.log"
    topic_dirs = sorted(entry for entry in root.iterdir() if entry.is_dir())
    for topic_dir in topic_dirs:
        try:
            target = topic_dir / "agents4sci_v2"
            if (target / "report.md").exists():
                continue
            summary_file = topic_dir / "summary.json"
            if not summary_file.exists():
                continue
            recorded = json.loads(summary_file.read_text(encoding="utf-8"))
            topic = recorded.get("topic") or ""
            if not topic:
                continue
            log_print(f"[repair] Rebuilding Agents4Sci for {topic_dir.name} ...", progress_log)
            target.mkdir(parents=True, exist_ok=True)
            run_agents4sci_with_budget(topic, target, max_calls=max_calls, root_log=progress_log)
            evaluate_run_dir(topic_dir, LLMClient(), eval_model=eval_model)
            log_print(f"[repair] Completed for {topic_dir.name}", progress_log)
        except Exception as exc:
            log_print(f"[repair] Failed for {topic_dir.name}: {exc}", progress_log)


def rerun_a4s_for_existing(
    out_root: Path | str,
    rounds: int,
    roles: List[str],
    max_calls: int,
    eval_model: str = "doubao-seed-1-6-250615",
) -> None:
    """Re-run only Agents4Sci for every existing topic folder, then re-aggregate.

    Each sub-directory of ``out_root`` with a ``summary.json`` that names a
    ``topic`` gets a new Agents4Sci run (with the given rounds, roles and call
    budget) followed by a re-evaluation. A failure on one folder is logged and
    the loop continues. Afterwards ``aggregate.md`` is regenerated. A heartbeat
    thread logs liveness to ``progress.log`` for the duration of the call.

    Args:
        out_root: Experiment root containing one folder per topic.
        rounds: Round count forwarded to ``run_agents4sci_with_budget``.
        roles: Expert roles forwarded to ``run_agents4sci_with_budget``.
        max_calls: Call budget forwarded to ``run_agents4sci_with_budget``.
        eval_model: Model name forwarded to ``evaluate_run_dir``.
    """
    # Ensure environment (.env) is loaded for API keys
    ensure_env_loaded()
    root = Path(out_root).resolve()
    root_log = root / "progress.log"
    log_print(f"[rerun] Rerunning Agents4Sci only with rounds={rounds}, roles={len(roles)}, max_calls={max_calls}", root_log)
    hb_thread, hb_stop = start_heartbeat(root_log, interval_sec=30)
    try:
        for d in sorted(p for p in root.iterdir() if p.is_dir()):
            try:
                summary_path = d / "summary.json"
                if not summary_path.exists():
                    continue
                info = json.loads(summary_path.read_text(encoding="utf-8"))
                topic = info.get("topic") or ""
                if not topic:
                    continue
                a4s_dir = d / "agents4sci_v2"
                a4s_dir.mkdir(parents=True, exist_ok=True)
                log_print(f"[rerun] Topic {d.name}: running Agents4Sci...", root_log)
                run_agents4sci_with_budget(topic, a4s_dir, max_calls=max_calls, root_log=root_log, rounds=rounds, roles=roles)
                log_print(f"[rerun] Topic {d.name}: evaluating...", root_log)
                evaluate_run_dir(d, LLMClient(), eval_model=eval_model)
            except Exception as e:
                log_print(f"[rerun] Failed on {d.name}: {e}", root_log)
        log_print("[rerun] Re-aggregating results...", root_log)
        aggregate_across_topics(root)
        log_print(f"[rerun] Aggregate written -> {root / 'aggregate.md'}", root_log)
    finally:
        hb_stop.set()
        # Wait for the heartbeat thread to exit rather than sleeping a fixed
        # time; the timeout keeps shutdown bounded if a log write is stuck.
        hb_thread.join(timeout=1.0)


def main() -> None:
    """CLI entry point: run all methods for each topic, then aggregate.

    Parses the command line, creates the output root, and for every
    semicolon-separated topic calls ``run_topic_all_models``. After all
    topics, ``aggregate.md`` is written. A heartbeat thread logs liveness to
    ``progress.log`` in the output root while the work is in progress and is
    stopped on both success and failure.
    """
    ensure_env_loaded()

    parser = argparse.ArgumentParser(description="Run Agents4Sci and baselines for given topics under experiments/0823 with detailed logs and aggregation")
    parser.add_argument("--out", default="experiments/0823", help="Output root directory")
    parser.add_argument("--topics", required=True, help="Semicolon-separated topics")
    parser.add_argument("--max-calls", type=int, default=25, help="Max API calls per model")
    parser.add_argument("--parallel", action="store_true", help="Run all models in parallel per topic")
    parser.add_argument("--eval-model", default="doubao-seed-1-6-250615", help="Evaluator model name")
    args = parser.parse_args()

    out_root = Path(args.out).resolve()
    out_root.mkdir(parents=True, exist_ok=True)
    root_log = out_root / "progress.log"

    hb_thread, hb_stop = start_heartbeat(root_log, interval_sec=30)
    # Everything after the heartbeat starts is inside the try so the thread is
    # always signalled to stop, including when an early log write fails.
    try:
        log_print("Starting 0823 runs (Agents4Sci + baselines)", root_log)

        # Blank entries (e.g. from a trailing ';') are dropped.
        topics = [t.strip() for t in args.topics.split(";") if t.strip()]
        for i, t in enumerate(topics, start=1):
            log_print(f"Topic {i}: {t}", root_log)

        for idx, topic in enumerate(topics, start=1):
            log_print(f"=== Topic {idx}/{len(topics)} started ===", root_log)
            run_topic_all_models(
                topic_idx=idx,
                topic=topic,
                out_root=out_root,
                max_calls=args.max_calls,
                eval_model=args.eval_model,
                parallel_models=args.parallel,
                root_log=root_log,
            )
            log_print(f"=== Topic {idx}/{len(topics)} finished ===", root_log)

        log_print("Aggregating results across topics...", root_log)
        aggregate_across_topics(out_root)
        log_print(f"Aggregate written -> {out_root / 'aggregate.md'}", root_log)
    finally:
        hb_stop.set()
        # Wait for the heartbeat thread to exit rather than sleeping a fixed
        # time; the timeout keeps shutdown bounded if a log write is stuck.
        hb_thread.join(timeout=1.0)

    log_print(f"All done. Results saved under {out_root}", root_log)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
