"""
collect_results.py

Walk a directory tree like

    experiments/<dataset>/<baseline_timestamp>/results.csv

and aggregate the per-sample summaries (means ± std) into one table.

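Output (written to the `scripts/` directory next to the experiments root):
  <experiments parent>/scripts/all_results.csv
  <experiments parent>/scripts/all_results.json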

The CSV columns are

dataset, baseline, sample_size,
mse, mse_std,
aug_mse, aug_mse_std,
delta_mse, delta_mse_std,
p_wilcoxon, p_wilcoxon_std,
should_proceed
"""

from pathlib import Path
import json
import pandas as pd


# ──────────────────────────────────────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────────────────────────────────────
METRICS = ["mse", "aug_mse", "delta_mse", "p_wilcoxon"]  # ← means & std for each

# Columns a results.csv must provide for the long-form lookups below.
_REQUIRED_COLUMNS = ("dataset", "metric", "mean", "std")


def collect_experiment_rows(exp_root: Path) -> list[dict]:
    """Walk `exp_root` and pull one summary row per sample tag from every results.csv.

    Expected layout: ``exp_root/<dataset>/<run_dir>/results.csv``. Each CSV is
    read in long form with the columns ``dataset`` (the sample tag),
    ``metric``, ``mean`` and ``std``.

    Args:
        exp_root: Directory whose immediate sub-directories are datasets.

    Returns:
        A list of dicts with the keys ``dataset``, ``baseline``,
        ``sample_size``, ``<metric>`` and ``<metric>_std`` for every entry of
        ``METRICS`` (``None`` when the metric is absent), and
        ``should_proceed`` (``bool``, or ``None`` when absent). Rows are
        produced in sorted directory order so the result is deterministic.

    Files that cannot be read, lack a required column, or contain a sample
    tag that does not end in an integer are reported on stdout and skipped
    instead of aborting the whole aggregation.
    """
    rows: list[dict] = []

    # level-1: dataset directory (sorted: iterdir() order is filesystem-dependent)
    for dataset_dir in sorted(exp_root.iterdir()):
        if not dataset_dir.is_dir():
            continue
        dataset_name = dataset_dir.name

        # level-2: run directory (name encodes baseline: {baseline}_{timestamp})
        for run_dir in sorted(dataset_dir.iterdir()):
            if not run_dir.is_dir():
                continue

            results_path = run_dir / "results.csv"
            if not results_path.exists():
                continue

            # NOTE: only the text before the first underscore is kept, so a
            # baseline name that itself contains "_" is truncated.
            baseline = run_dir.name.split("_")[0]          # e.g. "xgboost", "mlp"

            # Best-effort boundary: one broken file must not stop the sweep.
            try:
                df = pd.read_csv(results_path)
            except Exception as exc:
                print(f"⚠️  Could not read {results_path}: {exc}")
                continue

            missing = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
            if missing:
                print(f"⚠️  Skipping {results_path}: missing column(s) {missing}")
                continue

            # group the long-form results by sample tag
            for sample_tag, grp in df.groupby("dataset"):
                # str() so that a purely numeric tag column (parsed by pandas
                # as a number) is handled like its textual form.
                try:
                    sample_size = int(str(sample_tag).split("_")[-1])
                except ValueError:
                    print(
                        f"⚠️  Skipping sample tag {sample_tag!r} in {results_path}: "
                        "no trailing integer sample size"
                    )
                    continue

                row: dict[str, object] = {
                    "dataset": dataset_name,
                    "baseline": baseline,
                    "sample_size": sample_size,
                }

                # pull mean & std for each metric (first matching line wins)
                for metric in METRICS:
                    sub = grp.loc[grp["metric"] == metric]
                    if sub.empty:
                        row[metric] = None
                        row[f"{metric}_std"] = None
                    else:
                        row[metric] = sub["mean"].iloc[0]
                        row[f"{metric}_std"] = sub["std"].iloc[0]

                # derive boolean flag: the stored mean is thresholded at 0.5
                sp = grp.loc[grp["metric"] == "should_proceed", "mean"]
                row["should_proceed"] = bool(sp.iloc[0] > 0.5) if not sp.empty else None

                rows.append(row)

    return rows


# ──────────────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────────────
def main(experiments_path: str = "./experiments") -> None:
    """Aggregate every results.csv under `experiments_path` and save the table.

    Prints the long-form table, writes it as ``all_results.csv`` and
    ``all_results.json`` into ``<experiments parent>/scripts/`` (created if it
    does not exist yet), then prints a wide preview of the metric means
    pivoted by baseline.

    Args:
        experiments_path: Root of the experiments tree. ``~`` is expanded and
            a relative path is resolved against the current directory.

    Raises:
        SystemExit: If the experiments root does not exist, or if no rows
            could be collected from it.
    """
    exp_root = Path(experiments_path).expanduser().resolve()
    if not exp_root.exists():
        raise SystemExit(f"Path not found: {exp_root}")

    rows = collect_experiment_rows(exp_root)
    if not rows:
        raise SystemExit("No results.csv files found.")

    # identifiers first, then (mean, std) pairs per metric, then the flag
    col_order = [
        "dataset",
        "baseline",
        "sample_size",
        *(col for metric in METRICS for col in (metric, f"{metric}_std")),
        "should_proceed",
    ]

    df = (
        pd.DataFrame(rows)
        .sort_values(["dataset", "baseline", "sample_size"])
        .loc[:, col_order]
    )
    print(df)

    # ─── Save artifacts ────────────────────────────────────────────────────
    out_dir = exp_root.parent / "scripts"
    # Without this, to_csv fails when the scripts/ directory is absent.
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_out = out_dir / "all_results.csv"
    json_out = out_dir / "all_results.json"

    df.to_csv(csv_out, index=False)
    # pandas serialises the records; json re-dumps them only to pretty-print.
    json_out.write_text(json.dumps(json.loads(df.to_json(orient="records")), indent=2))

    print(f"\n Saved {csv_out}")
    print(f"Saved {json_out}")

    # a compact “wide” preview (means only)
    wide = df.pivot_table(
        index=["dataset", "sample_size"],
        columns="baseline",
        values=list(METRICS),
    )
    print(wide)

# Run the aggregation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()