#!/usr/bin/env python3
from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Dict, Any, List

# Make the project root (the directory one level above this script's folder)
# importable when the script is executed directly; it is prepended to
# sys.path only if no identical entry is already present.
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parent
if all(entry != str(PROJECT_ROOT) for entry in sys.path):
    sys.path[:0] = [str(PROJECT_ROOT)]


def load_eval(run_dir: Path) -> Dict[str, Any]:
    """Load ``run_dir / "evaluation.json"`` as a dict, best-effort.

    Args:
        run_dir: Directory of a single topic run.

    Returns:
        The parsed top-level JSON object. An empty dict is returned when the
        file is missing, cannot be read or decoded, is not valid JSON, or its
        top-level value is not a JSON object. Every case except "missing"
        prints a warning to stderr, so a corrupt evaluation is visible instead
        of silently vanishing from the aggregate report.
    """
    p = run_dir / "evaluation.json"
    try:
        text = p.read_text(encoding="utf-8")
    except FileNotFoundError:
        # A topic without an evaluation file is expected; stay quiet.
        return {}
    except (OSError, UnicodeDecodeError) as exc:
        print(f"warning: cannot read {p}: {exc}", file=sys.stderr)
        return {}

    try:
        data = json.loads(text)
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        print(f"warning: invalid JSON in {p}: {exc}", file=sys.stderr)
        return {}

    # Callers index the result by model name; anything other than an object
    # (list, number, null, ...) cannot be used and would crash on `.get`.
    if not isinstance(data, dict):
        print(
            f"warning: expected a JSON object in {p}, got {type(data).__name__}",
            file=sys.stderr,
        )
        return {}
    return data


def main() -> None:
    """Aggregate per-topic evaluations under ``--root`` into ``aggregate.md``.

    Each immediate subdirectory of ``--root`` is treated as one topic run and
    read with ``load_eval``. The report written to ``<root>/aggregate.md``
    contains:

    * a per-topic table of each model's ``overall`` score, and
    * a per-metric table of each model's mean ± population standard deviation
      across topics.

    Missing, non-numeric, boolean, or non-finite (NaN/inf) scores are treated
    as absent and rendered as ``-``. A metric/model pair with no usable values
    at all is also rendered as ``-``, rather than as a misleading
    ``0.00 ± 0.00`` that looks like a real score.
    """
    import argparse
    import math

    parser = argparse.ArgumentParser(description="Aggregate evaluations across all topics into a single report")
    parser.add_argument("--root", required=True, help="Experiments root (e.g., experiments/0821)")
    args = parser.parse_args()

    root = Path(args.root).resolve()
    if not root.is_dir():
        # Report a usage error instead of a FileNotFoundError traceback.
        parser.error(f"--root is not a directory: {root}")
    runs = sorted(p for p in root.iterdir() if p.is_dir())

    models = ["agents4sci_v2", "baseline_single", "baseline_tree", "baseline_debate"]
    # New 5D rubric metrics, plus the combined "overall" score.
    metrics = [
        "rigor_traceability",
        "integration_causality",
        "feasibility_minimality",
        "uncertainty_adaptation",
        "decisionability",
        "overall",
    ]

    def as_score(value: Any) -> "float | None":
        """Return *value* as a float if it is a usable score, else None."""
        # bool is a subclass of int, but JSON true/false is not a score.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        score = float(value)
        # json.loads accepts NaN/Infinity; they would poison the averages.
        return score if math.isfinite(score) else None

    def avg(xs: List[float]) -> float:
        # Callers guarantee xs is non-empty.
        return sum(xs) / len(xs)

    def std(xs: List[float]) -> float:
        # Population standard deviation (divide by n); 0.0 for a single value.
        mean = avg(xs)
        return math.sqrt(sum((x - mean) ** 2 for x in xs) / len(xs))

    def table_header(first_col: str) -> List[str]:
        # Derived from `models` so header columns can never drift from data.
        return [
            "| " + " | ".join([first_col, *models]) + " |",
            "|---|" + "|".join(["---:"] * len(models)) + "|",
        ]

    # metric_aggregates[metric][model] -> usable scores collected across topics
    metric_aggregates: Dict[str, Dict[str, List[float]]] = {
        m: {k: [] for k in models} for m in metrics
    }

    table_lines = table_header("Topic")
    for r in runs:
        eval_data = load_eval(r)
        if not isinstance(eval_data, dict):
            eval_data = {}
        # Escape pipes so a directory name cannot break the markdown table.
        row = [r.name.replace("|", "\\|")]
        for model in models:
            data = eval_data.get(model)
            if not isinstance(data, dict):
                # Model missing or malformed for this topic.
                data = {}
            overall = as_score(data.get("overall"))
            row.append("-" if overall is None else f"{overall:.2f}")
            for m in metrics:
                score = as_score(data.get(m))
                if score is not None:
                    metric_aggregates[m][model].append(score)
        table_lines.append("| " + " | ".join(row) + " |")

    # Macro averages: one row per metric, one "mean ± std" cell per model.
    metric_table = table_header("Metric")
    for m in metrics:
        cells = []
        for model in models:
            vals = metric_aggregates[m][model]
            cells.append(f"{avg(vals):.2f} ± {std(vals):.2f}" if vals else "-")
        metric_table.append("| " + m + " | " + " | ".join(cells) + " |")

    summary_lines = [
        "## Overall Comparison",
        "",
        *table_lines,
        "",
        "### Macro Average by Metric (5D rubric, across topics; mean ± std)",
        *metric_table,
    ]

    out_path = root / "aggregate.md"
    # Trailing newline so the output is a well-formed text file.
    out_path.write_text("\n".join(summary_lines) + "\n", encoding="utf-8")
    print(f"Saved aggregate to {out_path}")


if __name__ == "__main__":
    # Script entry point; nothing runs when this module is merely imported.
    raise SystemExit(main())


