#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import List, Dict, Any

# Directory containing this script and the repository root one level above it.
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parent

# Prepend the repository root (only once) so the project-local `src.*` and
# `scripts.*` imports below resolve when this file is executed directly.
_PROJECT_ROOT_STR = str(PROJECT_ROOT)
if _PROJECT_ROOT_STR not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT_STR)

from src.a4s.llm_client import LLMClient
from src.a4s.schemas import OrchestratorConfig
from src.ablation.runner import (
    AblationOrchestrator,
    preset_two_experts_one_round,
    preset_no_debate,
    preset_no_conflict,
    preset_no_shared_frame,
    preset_no_refinement,
)
from scripts.evaluate_run import evaluate_run_dir
from scripts.run_all_0823 import log_print, start_heartbeat, ensure_env_loaded


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for a single-topic ablation run."""
    parser = argparse.ArgumentParser(description="Run ablations for a single topic under experiments/0825")
    parser.add_argument("--topic", required=True)
    parser.add_argument("--out-root", default="experiments/0825")
    parser.add_argument("--eval-model", default="doubao-seed-1-6-thinking-250715")
    parser.add_argument("--rounds", type=int, default=3)
    parser.add_argument("--roles", default="Physics,ChemistryMaterials,BiologyEcology,Medicine,Sociology,Economics,PoliticsIR,EngineeringInfrastructure,EnvironmentalScience")
    return parser.parse_args()


def _topic_slug(topic: str, max_len: int = 64) -> str:
    """Derive a single path component from *topic* for the per-topic output dir.

    The topic is lower-cased, spaces become underscores, and the result is
    truncated to *max_len* characters. Path separators are also flattened to
    underscores. Without that, a topic such as ``"A/B testing"`` would create
    nested directories, and a leading ``/`` would make ``out_root / slug``
    resolve to an absolute path outside ``out_root``.
    """
    slug = topic.lower().replace(" ", "_").replace("/", "_").replace("\\", "_")[:max_len]
    # "", "." and ".." would make the topic dir alias out_root or its parent.
    return "_" if slug in {"", ".", ".."} else slug


def _to_score(value: Any) -> float:
    """Return *value* as a float, or 0.0 when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _stage_eval_dir(tmp_run: Path, sources: Dict[str, Path]) -> None:
    """Copy ``<src>/report.md`` to ``tmp_run/<key>/report.md`` for each entry.

    Raises:
        FileNotFoundError: if a source directory has no ``report.md``.
    """
    tmp_run.mkdir(parents=True, exist_ok=True)
    for key, src_dir in sources.items():
        dest_dir = tmp_run / key
        dest_dir.mkdir(exist_ok=True)
        text = (src_dir / "report.md").read_text(encoding="utf-8")
        (dest_dir / "report.md").write_text(text, encoding="utf-8")


def main() -> None:
    """Run the full model and every ablation preset for one topic, then score them.

    Everything is written under ``<out-root>/<topic slug>/``:
      * one sub-directory per variant, passed to ``AblationOrchestrator.run``;
      * ``_eval/`` holding copies of four variants' ``report.md`` under the
        sub-directory names handed to ``evaluate_run_dir``;
      * ``aggregate.md``, a table of each variant's ``overall`` score;
      * ``summary.json``, with the topic and the evaluated report directories.

    Progress lines are appended to ``<out-root>/progress.log``, and a heartbeat
    writes to the same log until this function exits.
    """
    ensure_env_loaded()
    args = _parse_args()

    out_root = Path(args.out_root).resolve()
    out_root.mkdir(parents=True, exist_ok=True)
    root_log = out_root / "progress.log"

    _hb_thread, hb_stop = start_heartbeat(root_log, interval_sec=30)
    # Everything after the heartbeat starts lives inside try, so the heartbeat
    # is always signalled to stop, even if the very first log write fails.
    try:
        log_print(f"[ablation] Topic: {args.topic}", root_log)

        roles = [r.strip() for r in args.roles.split(",") if r.strip()]
        orch = AblationOrchestrator(model_id=OrchestratorConfig().model_id)

        topic_dir = out_root / _topic_slug(args.topic)
        topic_dir.mkdir(parents=True, exist_ok=True)

        # Full model: start from a preset, then switch every ablation flag back on.
        # NOTE(review): assumes these four flags are the only switches that
        # preset_no_debate turns off — confirm against src/ablation/runner.
        cfg_full = preset_no_debate(roles, args.rounds)
        cfg_full.enable_debate = True
        cfg_full.enable_conflict_resolution = True
        cfg_full.use_shared_frame = True
        cfg_full.use_refinement = True
        log_print("[ablation] Running full model (agents4sci_v2)", root_log)
        base_dir = topic_dir / "agents4sci_v2"
        base_dir.mkdir(parents=True, exist_ok=True)
        orch.run(args.topic, base_dir, cfg_full)

        runs = [
            ("ablate_two_experts_one_round", preset_two_experts_one_round()),
            ("ablate_no_debate", preset_no_debate(roles, args.rounds)),
            ("ablate_no_conflict", preset_no_conflict(roles, args.rounds)),
            ("ablate_no_shared_frame", preset_no_shared_frame(roles, args.rounds)),
            ("ablate_no_refinement", preset_no_refinement(roles, args.rounds)),
        ]
        for name, cfg in runs:
            log_print(f"[ablation] Running {name}", root_log)
            out_dir = topic_dir / name
            out_dir.mkdir(parents=True, exist_ok=True)
            orch.run(args.topic, out_dir, cfg)

        # evaluate_run_dir is fed a run dir whose sub-folders use its own
        # four-model naming. Map four variants onto those names.
        eval_key_by_variant = {
            "agents4sci_v2": "agents4sci_v2",
            "ablate_two_experts_one_round": "baseline_single",
            "ablate_no_debate": "baseline_tree",
            "ablate_no_conflict": "baseline_debate",
        }
        mapping = {key: topic_dir / variant for variant, key in eval_key_by_variant.items()}
        tmp_run = topic_dir / "_eval"
        _stage_eval_dir(tmp_run, mapping)

        client = LLMClient()
        results = evaluate_run_dir(tmp_run, client, eval_model=args.eval_model)

        lines = [
            f"## Ablation Results for {args.topic}",
            "",
            "| Variant | overall |",
            "|---|---:|",
        ]
        for variant, key in eval_key_by_variant.items():
            entry = results.get(key)
            overall = entry.get("overall") if hasattr(entry, "get") else None
            lines.append(f"| {variant} | {_to_score(overall):.2f} |")

        # Variants that do not fit the evaluator's four slots are scored one
        # report at a time, reusing the same LLM client.
        from scripts.evaluate_run import score_report

        extras = [name for name, _ in runs if name not in eval_key_by_variant]
        for variant in extras:
            report_path = topic_dir / variant / "report.md"
            if not report_path.exists():
                log_print(f"[ablation] Skipping {variant}: no report.md", root_log)
                continue
            scored = score_report(client, report_path.read_text(encoding="utf-8"), eval_model=args.eval_model)
            overall = scored.get("overall") if hasattr(scored, "get") else None
            lines.append(f"| {variant} | {_to_score(overall):.2f} |")

        (topic_dir / "aggregate.md").write_text("\n".join(lines) + "\n", encoding="utf-8")

        summary = {"topic": args.topic, "paths": {k: str(v) for k, v in mapping.items()}}
        (topic_dir / "summary.json").write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")
        log_print(f"[ablation] Completed -> {topic_dir}", root_log)
    finally:
        hb_stop.set()


if __name__ == "__main__":
    # Script entry point. main() returns None, so a successful run exits with
    # status 0; uncaught exceptions propagate exactly as before.
    raise SystemExit(main())


