#!/usr/bin/env python3
"""Summarize mean±std metrics from all_seeds_results.json files in a directory tree."""
from __future__ import annotations

import argparse
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set

import summarize_metrics


@dataclass
class MetricSummary:
    """Mean and standard deviation of a single metric, for eval and ood results.

    Each field is None when the corresponding value was absent from the source
    JSON or could not be converted to a float.
    """

    # Statistics taken from the ``eval_<metric>`` block of the JSON file.
    eval_mean: Optional[float] = field(default=None)
    eval_std: Optional[float] = field(default=None)
    # Statistics taken from the ``ood_<metric>`` block of the JSON file.
    ood_mean: Optional[float] = field(default=None)
    ood_std: Optional[float] = field(default=None)


@dataclass
class SeedSummary:
    """Parsed contents of one aggregated results JSON file."""

    # Path of the JSON file the summary was read from.
    source: Path
    # Dataset label: the file's "dataset" entry, falling back to the name of
    # the directory containing the file.
    dataset: str
    # Per-metric statistics keyed by metric name (the JSON key with its
    # "eval_"/"ood_" prefix stripped). Each instance gets its own fresh dict.
    metrics: Dict[str, MetricSummary] = field(default_factory=dict)


def load_single_summary(json_path: Path) -> SeedSummary:
    """Parse one aggregated results JSON file into a SeedSummary.

    Top-level keys of the form ``eval_<metric>`` / ``ood_<metric>`` whose values
    are JSON objects are read for their ``mean`` and ``std`` entries; every other
    key is ignored. A ``mean``/``std`` that is missing or cannot be converted to
    float is stored as None rather than raising. A document whose top level is
    not a JSON object yields a summary with no metrics.

    Args:
        json_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A SeedSummary whose ``dataset`` is the file's ``"dataset"`` entry (as a
        string), or the name of the file's parent directory when that entry is
        missing or falsy.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with json_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    # json.load returns whatever the top-level JSON value is; a list, number or
    # null has no .get/.items. Treat such documents as carrying no metrics (the
    # same leniency applied to malformed mean/std values below) instead of
    # aborting the whole directory scan with an AttributeError.
    if not isinstance(payload, dict):
        payload = {}

    dataset_name = str(payload.get("dataset") or json_path.parent.name)

    def to_float(value) -> Optional[float]:
        """Return float(value), or None when value is None or not convertible."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    metrics: Dict[str, MetricSummary] = {}

    for key, block in payload.items():
        # Only object-valued entries can hold mean/std statistics.
        if not isinstance(block, dict):
            continue
        if key.startswith("eval_"):
            summary = metrics.setdefault(key[len("eval_"):], MetricSummary())
            summary.eval_mean = to_float(block.get("mean"))
            summary.eval_std = to_float(block.get("std"))
        elif key.startswith("ood_"):
            summary = metrics.setdefault(key[len("ood_"):], MetricSummary())
            summary.ood_mean = to_float(block.get("mean"))
            summary.ood_std = to_float(block.get("std"))

    return SeedSummary(source=json_path, dataset=dataset_name, metrics=metrics)


def format_stat(mean: Optional[float], std: Optional[float]) -> str:
    """Render a mean/std pair as ``"<mean> ± <std>"`` with four decimals.

    Returns ``"N/A"`` unless both values are present (not None).
    """
    both_present = mean is not None and std is not None
    if both_present:
        return f"{mean:.4f} ± {std:.4f}"
    return "N/A"


def summarize_tree(root: Path, filename: str) -> List[SeedSummary]:
    """Load every file named *filename* found recursively under *root*.

    Matches are processed in sorted path order; matches that are not regular
    files (e.g. a directory carrying that name) are skipped.
    """
    matches = sorted(root.rglob(filename))
    return [load_single_summary(candidate) for candidate in matches if candidate.is_file()]


def print_summaries(root: Path, summaries: Iterable[SeedSummary], filename: str) -> None:
    """Print an aligned table of eval/ood statistics, one row per metric.

    Each summary is labelled by its source directory relative to *root* (or the
    absolute directory when it is not under *root*; ``.`` for *root* itself).
    Every summary lists every metric name seen across all summaries, in sorted
    order; the label is printed only on a summary's first row. Metrics a
    summary lacks are shown as ``N/A``.
    """
    entries = list(summaries)
    if not entries:
        print(f"No {filename} files found under {root}.")
        return

    def label_for(entry: SeedSummary) -> str:
        # Directory holding the JSON file, shown relative to root when possible.
        directory = entry.source.parent
        try:
            relative = directory.relative_to(root)
        except ValueError:
            relative = directory
        if relative == Path(""):
            relative = Path(".")
        return str(relative)

    labels = [label_for(entry) for entry in entries]
    label_width = max(map(len, labels))

    # Union of metric names over all summaries so every block has the same rows.
    all_metrics = sorted(set().union(*(entry.metrics for entry in entries)))
    name_width = max(map(len, all_metrics), default=4)

    for entry, label in zip(entries, labels):
        if not all_metrics:
            print(f"{label:<{label_width}}  eval: N/A  |  ood: N/A")
            continue

        # Label on the first row only; continuation rows get a blank, padded prefix.
        prefixes = [label] + [""] * (len(all_metrics) - 1)
        for prefix, name in zip(prefixes, all_metrics):
            stats = entry.metrics.get(name, MetricSummary())
            print(
                f"{prefix:<{label_width}}  {name:<{name_width}}"
                f"  eval: {format_stat(stats.eval_mean, stats.eval_std)}"
                f"  |  ood: {format_stat(stats.ood_mean, stats.ood_std)}"
            )


def dataset_dirs_with_metrics(root: Path, *, levels_up: int = 2) -> List[Path]:
    """Return the sorted set of dataset directories under *root* holding metric files.

    Every ``eval_results.json`` / ``ood_results.json`` found recursively under
    *root* is mapped to the directory *levels_up* levels above the directory
    that contains it (``metric_path.parents[levels_up]``). With the default of
    2 this presumably corresponds to a ``<dataset>/<seed>/<run>/eval_results.json``
    layout (NOTE(review): layout inferred from the depth; confirm).

    Only directories equal to *root* or inside it are returned. A file sitting
    too close to *root* would otherwise map to a directory above it, and work
    done there falls outside the tree being summarized.

    Args:
        root: Directory to search recursively.
        levels_up: How many levels above a metric file's containing directory
            the dataset directory lives (keyword-only, default 2).

    Returns:
        Existing dataset directories, sorted, without duplicates.

    Raises:
        ValueError: if *levels_up* is negative.
    """
    if levels_up < 0:
        raise ValueError(f"levels_up must be non-negative, got {levels_up}")

    datasets: Set[Path] = set()
    for metric_name in ("eval_results.json", "ood_results.json"):
        for metric_path in root.rglob(metric_name):
            parents = metric_path.parents
            if len(parents) <= levels_up:
                continue
            dataset_dir = parents[levels_up]
            # Reject candidates that resolve to an ancestor of root: summarizing
            # there would write outside the requested tree, and the output would
            # never be found when searching under root.
            if dataset_dir != root and root not in dataset_dir.parents:
                continue
            if dataset_dir.is_dir():
                datasets.add(dataset_dir)
    return sorted(datasets)


def generate_all_seeds(root: Path, filename: str) -> None:
    """Run summarize_metrics.summarize_dataset on each dataset directory under *root*.

    Dataset directories come from dataset_dirs_with_metrics(root), processed in
    its sorted order. What summarize_dataset does with *filename* is defined in
    the summarize_metrics module (not visible here); presumably it (re)writes
    the aggregate file in each dataset directory — confirm.
    """
    targets = dataset_dirs_with_metrics(root)
    for target in targets:
        summarize_metrics.summarize_dataset(target, filename)


def parse_args() -> argparse.Namespace:
    """Parse command-line options from sys.argv.

    Returns a namespace with ``root`` (a Path, not yet expanded or resolved)
    and ``filename`` (the JSON file name to look for).
    """
    cli = argparse.ArgumentParser(description=__doc__)

    root_help = "Top-level directory to search for all_seeds_results.json files."
    cli.add_argument("root", type=Path, help=root_help)

    filename_help = "Name of the JSON files to summarize (default: all_seeds_results.json)."
    cli.add_argument("--filename", default="all_seeds_results.json", help=filename_help)

    return cli.parse_args()


def main() -> None:
    """CLI entry point: regenerate aggregate files, then print them as a table.

    The root argument is user-expanded (``~``) and resolved to an absolute path
    before any directory is searched.
    """
    options = parse_args()
    search_root = options.root.expanduser().resolve()
    target_name = options.filename

    generate_all_seeds(search_root, target_name)
    print_summaries(search_root, summarize_tree(search_root, target_name), target_name)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
