#!/usr/bin/env python3
"""Aggregate regression metrics (MAE, RMSE) across seeds for each dataset."""
from __future__ import annotations

import argparse
import json
import statistics
from pathlib import Path
from typing import Dict, Iterable, Optional, Sequence

# Metric names aggregated per dataset. Each name is looked up as "eval_<name>"
# in eval_results.json and as "ood_<name>" in ood_results.json.
METRICS = ("mae", "rmse")


def read_json(path: Path) -> Dict[str, float]:
    """Load and return the JSON document stored at *path*, decoded as UTF-8.

    The return annotation describes the expected metrics mapping; the json
    module does not enforce it, so any JSON value may come back.

    Raises:
        json.JSONDecodeError: if the file content is not valid JSON.
        UnicodeDecodeError: if the file is not valid UTF-8.
        OSError: if the file cannot be opened or read.
    """
    document_text = path.read_text(encoding="utf-8")
    return json.loads(document_text)


def select_run_dir(seed_dir: Path) -> Optional[Path]:
    """
    Pick the run directory under *seed_dir* whose results should be used.

    Candidates are the immediate subdirectories of *seed_dir*, scanned from
    the greatest to the smallest path in sorted (lexicographic) order. The
    first one holding eval_results.json or ood_results.json is returned.
    This equals "most recent" only if run directories have sortable,
    timestamp-like names (presumably the case; confirm against the producer).

    Returns None when no subdirectory contains either file.
    """
    result_files = ("eval_results.json", "ood_results.json")
    newest_first = sorted(
        (entry for entry in seed_dir.iterdir() if entry.is_dir()),
        reverse=True,
    )
    return next(
        (
            run_dir
            for run_dir in newest_first
            if any((run_dir / name).is_file() for name in result_files)
        ),
        None,
    )


def compute_stats(values: Iterable[float]) -> Dict[str, float | None]:
    values = list(values)
    if not values:
        return {"mean": None, "std": None}
    mean = statistics.fmean(values)
    std = statistics.pstdev(values) if len(values) > 1 else 0.0
    return {"mean": mean, "std": std}


def _read_prefixed_metrics(
    json_path: Path,
    prefix: str,
    seed_name: str,
    per_seed: Dict[str, Dict[str, float]],
) -> None:
    """Record ``<prefix>_<metric>`` values from *json_path* into *per_seed*.

    For every metric in METRICS found in the file, stores
    ``per_seed[metric][seed_name] = float(value)``. A missing file is silently
    ignored. Unreadable, malformed or non-object JSON, and values that cannot
    be converted to float, are reported as warnings and skipped. One bad file
    therefore never aborts the whole aggregation.
    """
    if not json_path.is_file():
        return
    try:
        payload = read_json(json_path)
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
        print(f"[warn] Invalid JSON in {json_path}")
        return
    except OSError as exc:
        print(f"[warn] Could not read {json_path}: {exc}")
        return
    if not isinstance(payload, dict):
        # A top-level list or string would pass a `key in payload` test, but
        # `payload[key]` would then raise TypeError; reject it up front.
        print(
            f"[warn] Expected a JSON object in {json_path}, "
            f"got {type(payload).__name__}"
        )
        return
    for metric in METRICS:
        key = f"{prefix}_{metric}"
        if key not in payload:
            continue
        try:
            value = float(payload[key])
        except (TypeError, ValueError):
            print(f"[warn] Invalid value for {key} in {json_path}")
            continue
        per_seed[metric][seed_name] = value


def summarize_dataset(dataset_dir: Path, output_name: str) -> Optional[Path]:
    """Aggregate per-seed eval/OOD metrics for one dataset into a JSON file.

    Every immediate subdirectory of *dataset_dir* is treated as a seed. For
    each seed, select_run_dir chooses a single run directory, and its
    eval_results.json / ood_results.json are read. Older runs are not
    consulted, even if the chosen run's files are invalid.

    The summary holds the dataset name, the number of seeds that contributed
    at least one valid metric, and, for each ``eval_<metric>`` /
    ``ood_<metric>``, the per-seed values plus their mean and std. It is written
    to ``dataset_dir / output_name``.

    Returns:
        The path written, or None (nothing written) when no seed produced any
        valid eval or OOD metric.
    """
    per_seed_eval: Dict[str, Dict[str, float]] = {metric: {} for metric in METRICS}
    per_seed_ood: Dict[str, Dict[str, float]] = {metric: {} for metric in METRICS}

    for seed_dir in sorted(p for p in dataset_dir.iterdir() if p.is_dir()):
        run_dir = select_run_dir(seed_dir)
        if run_dir is None:
            print(f"[warn] Skipping {seed_dir} (no eval/ood JSON found in any run dir)")
            continue
        # select_run_dir only returns directories containing at least one of
        # these files; a missing one is simply skipped by the helper.
        _read_prefixed_metrics(
            run_dir / "eval_results.json", "eval", seed_dir.name, per_seed_eval
        )
        _read_prefixed_metrics(
            run_dir / "ood_results.json", "ood", seed_dir.name, per_seed_ood
        )

    # If truly nothing valid was found for any eval or ood metric, bail out.
    if not any(per_seed_eval[metric] or per_seed_ood[metric] for metric in METRICS):
        print(f"[warn] No valid eval/ood metrics found in {dataset_dir}")
        return None

    contributing_seeds = set()
    for metric in METRICS:
        contributing_seeds.update(per_seed_eval[metric], per_seed_ood[metric])

    summary: Dict[str, object] = {
        "dataset": dataset_dir.name,
        "num_seeds": len(contributing_seeds),
    }
    for prefix, per_seed in (("eval", per_seed_eval), ("ood", per_seed_ood)):
        for metric in METRICS:
            values = per_seed[metric]
            summary[f"{prefix}_{metric}"] = {
                "per_seed": dict(sorted(values.items())),
                **compute_stats(values.values()),
            }

    output_path = dataset_dir / output_name
    with output_path.open("w", encoding="utf-8") as handle:
        # sort_keys makes the output independent of dict insertion order.
        json.dump(summary, handle, indent=2, sort_keys=True)
        handle.write("\n")

    return output_path


def summarize_all(root: Path, output_name: str) -> None:
    """Run summarize_dataset on every immediate subdirectory of *root*.

    Subdirectories are processed in sorted order, and each successfully
    written summary path is printed.

    Raises:
        ValueError: if *root* contains no subdirectories.
    """
    candidates = [entry for entry in root.iterdir() if entry.is_dir()]
    if not candidates:
        raise ValueError(f"No dataset directories found in {root}")
    candidates.sort()

    for dataset_dir in candidates:
        written = summarize_dataset(dataset_dir, output_name)
        if written is None:
            continue
        print(f"[ok] Wrote {written}")


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for this script.

    Args:
        argv: Arguments to parse, excluding the program name. When None (the
            default), argparse falls back to ``sys.argv[1:]``, so existing
            no-argument callers behave exactly as before.

    Returns:
        A namespace with ``root`` (a Path; argparse applies ``type=Path`` to the
        string default too) and ``output_name`` (a str).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "root",
        nargs="?",
        default="baselines/pretrained_gnns/saved_results",
        type=Path,
        help="Root directory containing dataset subdirectories.",
    )
    parser.add_argument(
        "--output-name",
        default="all_seeds_results.json",
        help="Filename for the aggregated output JSON.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Command-line entry point: summarise every dataset under the chosen root."""
    cli = parse_args()
    # Expand "~" and make the path absolute before walking it, so the paths
    # printed in [warn]/[ok] messages are unambiguous.
    dataset_root = cli.root.expanduser().resolve()
    summarize_all(root=dataset_root, output_name=cli.output_name)


if __name__ == "__main__":
    main()
