#!/usr/bin/env python3
import argparse
import csv
import json
from pathlib import Path


def flatten_metrics(data, prefix=""):
    """Collapse a nested metrics mapping into a flat ``{dotted_key: number}`` dict.

    A dict whose ``"value"`` entry is an int or float is treated as a single
    metric: its value is stored under the dotted path leading to it, and any
    numeric ``count``/``total``/``correct``/``match_count``/``evaluated_count``
    entries are stored under ``<path>.<counter>``. Other entries of such a
    dict are not visited.

    Any other dict is descended into, extending the dotted path with each key.
    Inputs that are not dicts contribute nothing.

    Args:
        data: The (possibly nested) object to flatten.
        prefix: Dotted path accumulated so far, ending in ``"."`` when
            non-empty.

    Returns:
        A new dict mapping dotted metric names to numeric values.
    """
    if not isinstance(data, dict):
        return {}

    primary = data.get("value")
    if not isinstance(primary, (int, float)):
        # Not a metric leaf: recurse into every child under an extended path.
        merged = {}
        for name, child in data.items():
            merged.update(flatten_metrics(child, f"{prefix}{name}."))
        return merged

    # Metric leaf: the path without its trailing dot(s) names the metric.
    metric_name = prefix.rstrip(".")
    collected = {metric_name: primary}
    for counter in ("count", "total", "correct", "match_count", "evaluated_count"):
        amount = data.get(counter)
        if isinstance(amount, (int, float)):
            collected[f"{metric_name}.{counter}"] = amount
    return collected


def extract_path_metadata(path: Path):
    """Derive run metadata from the components of ``path``.

    Looks for the first ``risk_detection_results`` component. When present,
    the component right after it (if any) is reported as ``run_id``, the one
    after that (if any) as ``model_name``, and ``result_type`` is set to
    ``risk_detection_results``. The components are taken purely by position;
    no check is made on whether they are directories or the file name itself.

    Args:
        path: Path whose ``parts`` are inspected.

    Returns:
        A dict with up to three keys (``run_id``, ``model_name``,
        ``result_type``), or an empty dict when the marker component is absent.
    """
    marker = "risk_detection_results"
    components = path.parts
    try:
        anchor = components.index(marker)
    except ValueError:
        return {}

    # zip() stops at the shorter input, so missing trailing components simply
    # leave the corresponding key out.
    following = components[anchor + 1 : anchor + 3]
    info = dict(zip(("run_id", "model_name"), following))
    info["result_type"] = marker
    return info


def collect_rows(root: Path):
    """Build one flat row per ``summary.json`` found anywhere under ``root``.

    Files are visited in sorted path order. Each row starts with
    ``summary_path`` (the file's path as a string), then receives the path
    metadata from ``extract_path_metadata`` and finally the flattened metrics
    from ``flatten_metrics``; on a key clash the later source wins.

    Args:
        root: Directory searched recursively for ``summary.json`` files.

    Returns:
        A list of row dicts, one per file found (empty if none were found).
    """

    def build_row(json_file):
        # Parse the file first so the handle is closed before the row is built.
        with json_file.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
        return {
            "summary_path": str(json_file),
            **extract_path_metadata(json_file),
            **flatten_metrics(payload),
        }

    return [build_row(found) for found in sorted(root.glob("**/summary.json"))]


def main():
    """Command-line entry point: gather summary metrics and write them as CSV.

    Parses ``--root`` (directory to search) and ``--output`` (CSV destination),
    collects one row per ``summary.json`` via ``collect_rows``, and writes all
    rows to the output file, creating its parent directory if needed.

    Column order: the preferred columns that actually occur in the data come
    first, in their listed order; every remaining column follows in sorted
    order.

    Raises:
        SystemExit: If the root path does not exist, or no ``summary.json``
            file is found beneath it.
    """
    cli = argparse.ArgumentParser(
        description="Flatten summary.json metrics into a CSV."
    )
    cli.add_argument(
        "--root",
        default="IS-Bench/results",
        help="Root directory to search for summary.json files.",
    )
    cli.add_argument(
        "--output",
        default="summary_metrics.csv",
        help="Output CSV path.",
    )
    options = cli.parse_args()

    search_root = Path(options.root)
    if not search_root.exists():
        raise SystemExit(f"Root path does not exist: {search_root}")

    records = collect_rows(search_root)
    if not records:
        raise SystemExit(f"No summary.json files found under {search_root}")

    # Columns that should lead the header whenever they are present.
    leading_columns = (
        "model_name",
        "detection_metrics.precision",
        "detection_metrics.recall",
        "detection_metrics.f1_score",
        "detection_metrics.false_positive_rate",
        "step_wise_temporal_metrics.safe_step_accuracy",
        "risk_type_metrics.risk_type_accuracy",
        "hazard_explanation_metrics.hazard_match_rate",
        "detection_metrics.confusion_matrix.true_positive",
        "detection_metrics.confusion_matrix.false_positive",
        "detection_metrics.confusion_matrix.false_negative",
        "detection_metrics.confusion_matrix.true_negative",
    )

    # Union of every key that appears in any row.
    seen_columns = set()
    for record in records:
        seen_columns.update(record)

    header = [column for column in leading_columns if column in seen_columns]
    header.extend(sorted(seen_columns.difference(header)))

    destination = Path(options.output)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open("w", encoding="utf-8", newline="") as sink:
        table = csv.DictWriter(sink, fieldnames=header)
        table.writeheader()
        table.writerows(records)

    print(f"Wrote {len(records)} rows to {destination}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
