#!/usr/bin/env python3
"""
Unified aggregator for ALL experiment results.

Scans both:
- OLD structure: results/<dataset>/<category>/*.json
- NEW structure: runs/<run>/results/*.json

Outputs: results/aggregated/master_summary.json
"""

import argparse
import glob
import json
import os
import sys
from collections import defaultdict
from datetime import datetime

import numpy as np


def safe_get_results(data):
    """Safely extract the results dict from various formats.

    Args:
        data: A decoded JSON value. Two layouts are recognized:
            a dict with a non-empty ``"results"`` list, or anything else.

    Returns:
        The last element of ``data["results"]`` when ``data`` is a dict whose
        ``"results"`` value is a non-empty list; otherwise ``data`` itself,
        unchanged. Non-dict inputs (``None``, numbers, lists, strings) are
        returned as-is instead of raising.
    """
    # Only a dict can carry a "results" key; membership tests on None or
    # numbers raise TypeError, and indexing a list/str with a str key does too.
    if not isinstance(data, dict):
        return data

    results = data.get("results")
    if isinstance(results, list) and results:
        # Multiple runs may be appended to the list; the last one is used.
        return results[-1]
    return data


def _first_present(record, *keys):
    """Return the first value stored under ``keys`` that is not None.

    Unlike chaining ``record.get(a) or record.get(b)``, a legitimate falsy
    value such as an accuracy of ``0.0`` or a clause count of ``0`` is kept.
    """
    for key in keys:
        value = record.get(key)
        if value is not None:
            return value
    return None


def _warn_skip(path, reason):
    """Report on stderr that ``path`` was skipped and why."""
    print(f"WARNING: skipping {path}: {reason}", file=sys.stderr)


def _load_record(path):
    """Load the JSON file at ``path`` and return its results dict, or None."""
    try:
        with open(path, encoding="utf-8") as fp:
            data = json.load(fp)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        _warn_skip(path, exc)
        return None

    if not isinstance(data, dict):
        _warn_skip(path, "top-level JSON value is not an object")
        return None

    record = safe_get_results(data)
    if not isinstance(record, dict):
        _warn_skip(path, "last element of 'results' is not an object")
        return None
    return record


def _old_structure_entry(path, record):
    """Build an entry for a file laid out as results/<dataset>/<category>/..."""
    # glob returns OS-native separators, so split on os.sep rather than "/".
    parts = os.path.normpath(path).split(os.sep)
    dataset, category = parts[1], parts[2]

    entry = {
        "source": path,
        "dataset": dataset,
        "category": category,
    }

    if category.startswith("p"):
        # Categories starting with "p" hold IMLI runs; "weighted" in the
        # category name selects the weighted variant.
        entry["type"] = "imli_weighted" if "weighted" in category else "imli"
        entry["test_acc"] = _first_present(
            record, "compressed_test_accuracy", "compressed_test_acc"
        )
        entry["kept_clauses"] = _first_present(
            record, "kept_clauses", "compressed_clauses"
        )
    elif category == "baseline":
        entry["type"] = "baseline"
        entry["test_acc"] = _first_present(record, "test_accuracy", "test_acc")
        entry["clauses_per_class"] = _first_present(
            record, "n_clauses_per_class", "clauses_per_class"
        )
    else:
        entry["type"] = category
        entry["test_acc"] = _first_present(
            record, "test_accuracy", "test_acc", "compressed_test_accuracy"
        )

    entry["seed"] = record.get("seed")
    return entry


def _new_structure_entry(path, record):
    """Build an entry for a file laid out as runs/<run>/results/<name>.json."""
    run_name = os.path.normpath(path).split(os.sep)[1]
    filename = os.path.basename(path)
    name_parts = filename.split("_")

    # Prefer the dataset recorded in the file; fall back to the second
    # underscore-separated token of the filename.
    dataset = record.get("dataset")
    if not dataset:
        if len(name_parts) < 2:
            _warn_skip(path, "cannot determine dataset from file or filename")
            return None
        dataset = name_parts[1]

    entry = {
        "source": path,
        "run": run_name,
        "dataset": dataset,
    }

    kind = name_parts[0]
    if kind == "imli":
        entry["type"] = "imli_weighted" if "_w_" in filename else "imli"
        entry["test_acc"] = record.get("compressed_test_acc")
        entry["kept_clauses"] = record.get("compressed_clauses")
    elif kind == "baseline":
        entry["type"] = "baseline"
        entry["test_acc"] = record.get("test_acc")
        entry["clauses_per_class"] = record.get("n_clauses_per_class")
    else:
        entry["type"] = kind
        entry["test_acc"] = _first_present(
            record, "test_acc", "compressed_test_acc"
        )

    entry["seed"] = record.get("seed")
    return entry


def _collect(entries, path, build_entry):
    """Load ``path``, build its entry and append it when it has a test_acc."""
    record = _load_record(path)
    if record is None:
        return
    entry = build_entry(path, record)
    if entry is not None and entry.get("test_acc") is not None:
        entries.append(entry)


def scan_all_results():
    """Scan all result files and extract standardized metrics.

    Two layouts, both relative to the current working directory, are read:

    - OLD: ``results/<dataset>/<category>/*.json`` (searched recursively);
      paths containing ``aggregated`` or ``status`` are ignored.
    - NEW: ``runs/<run>/results/*.json``; paths containing ``status`` are
      ignored.

    Files that cannot be read or parsed, or that do not contain a results
    object, are skipped with a warning on stderr. Files without a test
    accuracy are skipped silently.

    Returns:
        A list of entry dicts. Every entry has ``source``, ``dataset``,
        ``type``, ``test_acc`` and ``seed``; old-structure entries also have
        ``category``, new-structure entries ``run``; IMLI entries carry
        ``kept_clauses`` and baseline entries ``clauses_per_class``.
    """
    all_entries = []

    # Sorting makes the entry order (and thus downstream output) reproducible
    # regardless of directory listing order.
    for path in sorted(glob.glob("results/*/**/*.json", recursive=True)):
        if "aggregated" in path or "status" in path:
            continue
        _collect(all_entries, path, _old_structure_entry)

    for path in sorted(glob.glob("runs/*/results/*.json")):
        if "status" in path:
            continue
        _collect(all_entries, path, _new_structure_entry)

    return all_entries


def aggregate(entries):
    """Group entries per dataset into IMLI runs and baseline runs.

    Returns a defaultdict keyed by dataset name. Each value is a dict with
    ``"imli"`` (list of IMLI / weighted-IMLI entries) and ``"baselines"``
    (defaultdict mapping clauses-per-class to a list of baseline entries).
    Baselines with a falsy ``clauses_per_class`` and entries of any other
    type are left out.
    """
    grouped = defaultdict(lambda: {"imli": [], "baselines": defaultdict(list)})
    imli_kinds = ("imli", "imli_weighted")

    for entry in entries:
        name = entry["dataset"]
        kind = entry["type"]

        if kind in imli_kinds:
            grouped[name]["imli"].append(entry)
            continue

        if kind != "baseline":
            continue

        per_class = entry.get("clauses_per_class")
        if not per_class:
            continue
        grouped[name]["baselines"][per_class].append(entry)

    return grouped


def compute_summary(by_dataset, *, n_classes=2, full_cpc=100):
    """Compute summary statistics and find matched baselines.

    Args:
        by_dataset: Mapping of dataset name to a dict with ``"imli"`` (list of
            entries) and ``"baselines"`` (mapping of clauses-per-class to a
            list of entries), as produced by ``aggregate``.
        n_classes: Number of classes the kept clauses are spread over; the
            matched baseline targets ``kept_clauses_mean / n_classes`` clauses
            per class.
        full_cpc: Clauses per class of the full (uncompressed) baseline.
            Compression is measured against ``n_classes * full_cpc`` clauses.

    Returns:
        A list of row dicts, one per dataset that has at least one IMLI entry
        with a test accuracy, sorted by ``delta`` in descending order (rows
        without a matched baseline come last). ``compression`` is present only
        when a kept-clause count is known; ``matched_*`` and ``delta`` only
        when a baseline with a test accuracy could be matched;
        ``full_tm_acc_*`` only when a ``full_cpc`` baseline exists.
    """
    total_clauses = n_classes * full_cpc
    summary = []

    for dataset, data in sorted(by_dataset.items()):
        imli_entries = data["imli"]
        baselines = data["baselines"]

        if not imli_entries:
            continue

        # ``is not None`` rather than truthiness: an accuracy of 0.0 or a
        # clause count of 0 is a valid measurement and must be counted.
        accs = [e["test_acc"] for e in imli_entries if e.get("test_acc") is not None]
        clauses = [
            e["kept_clauses"]
            for e in imli_entries
            if e.get("kept_clauses") is not None
        ]

        if not accs:
            continue

        kept_mean = float(np.mean(clauses)) if clauses else None
        row = {
            "dataset": dataset,
            # np.std is the population standard deviation (ddof=0).
            "imli_acc_mean": float(np.mean(accs)),
            "imli_acc_std": float(np.std(accs)),
            "imli_n_seeds": len(accs),
            "kept_clauses_mean": kept_mean,
            "kept_clauses_std": float(np.std(clauses)) if clauses else None,
        }

        if kept_mean is not None:
            row["compression"] = 1 - kept_mean / total_clauses

            if baselines:
                # Closest available clauses-per-class; ties go to the smaller
                # count so the choice does not depend on insertion order.
                target_cpc = int(round(kept_mean / n_classes))
                closest_cpc = min(
                    baselines, key=lambda cpc: (abs(cpc - target_cpc), cpc)
                )
                matched_accs = [
                    e["test_acc"]
                    for e in baselines[closest_cpc]
                    if e.get("test_acc") is not None
                ]

                if matched_accs:
                    row["matched_cpc"] = closest_cpc
                    row["matched_acc_mean"] = float(np.mean(matched_accs))
                    row["matched_acc_std"] = float(np.std(matched_accs))
                    row["matched_n_seeds"] = len(matched_accs)
                    row["delta"] = row["imli_acc_mean"] - row["matched_acc_mean"]

        # Full TM baseline (uncompressed clause budget).
        if full_cpc in baselines:
            full_accs = [
                e["test_acc"]
                for e in baselines[full_cpc]
                if e.get("test_acc") is not None
            ]
            if full_accs:
                row["full_tm_acc_mean"] = float(np.mean(full_accs))
                row["full_tm_acc_std"] = float(np.std(full_accs))

        summary.append(row)

    # Largest improvement first; rows without a delta sink to the bottom.
    summary.sort(key=lambda r: r.get("delta", float("-inf")), reverse=True)
    return summary


def print_table(summary):
    """Print a markdown comparison table and tally the record.

    Args:
        summary: List of row dicts as produced by ``compute_summary``. Each
            row needs ``dataset``, ``imli_acc_mean`` and ``imli_acc_std``;
            ``matched_acc_mean``/``matched_acc_std``/``delta``,
            ``compression`` and ``kept_clauses_mean`` are optional.

    Returns:
        A ``(wins, ties, losses, missing)`` tuple. A row is a win when its
        delta exceeds +0.005, a loss when below -0.005, a tie otherwise, and
        missing when it has no matched baseline.
    """
    print("\n" + "=" * 90)
    print("COMPARISON TABLE: MaxSAT vs Matched Compression Baseline")
    print("=" * 90)
    print("\n| Dataset | MaxSAT Acc | Matched TM | Compression | Δ |")
    print("|---------|-----------|------------|-------------|---|")

    wins, ties, losses, missing = 0, 0, 0, 0

    for row in summary:
        ds = row["dataset"]
        maxsat = f"{row['imli_acc_mean']*100:.1f}±{row['imli_acc_std']*100:.1f}%"

        matched_mean = row.get("matched_acc_mean")
        if matched_mean is not None:
            matched_std = row.get("matched_acc_std") or 0
            matched = f"{matched_mean*100:.1f}±{matched_std*100:.1f}%"
            delta = f"{row['delta']*100:+.1f}pp"
            if row["delta"] > 0.005:
                wins += 1
                ds = f"**{ds}**"  # winners are shown in bold
            elif row["delta"] < -0.005:
                losses += 1
            else:
                ties += 1
        else:
            matched = "MISSING"
            delta = "N/A"
            missing += 1

        # ``kept_clauses_mean`` is present but None when no clause counts were
        # recorded; ``dict.get`` defaults do not apply then, and formatting
        # None with ":.0f" would raise TypeError.
        kept_mean = row.get("kept_clauses_mean")
        if kept_mean is None:
            comp = "N/A"
        else:
            compression = row.get("compression")
            if compression is None:
                compression = 0
            comp = f"{compression*100:.1f}% ({kept_mean:.0f} cl)"

        print(f"| {ds} | {maxsat} | {matched} | {comp} | {delta} |")

    print(f"\n**Record: {wins} wins, {ties} ties, {losses} losses, {missing} missing baselines**")
    return wins, ties, losses, missing


def main(argv=None):
    """Aggregate all experiment results, print a report and save a JSON file.

    Scans the result files, groups them by dataset, computes the summary,
    prints the comparison table, writes the summary JSON to ``--output`` and
    finally lists the datasets that still lack a matched baseline.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Aggregate all experiment results")
    parser.add_argument("--output", "-o", default="results/aggregated/master_summary.json")
    args = parser.parse_args(argv)

    print("Scanning all results...")
    entries = scan_all_results()
    print(f"  Found {len(entries)} valid entries")

    print("Aggregating by dataset...")
    by_dataset = aggregate(entries)
    print(f"  Found {len(by_dataset)} datasets")

    print("Computing summary statistics...")
    summary = compute_summary(by_dataset)

    wins, ties, losses, missing = print_table(summary)

    # Save JSON
    output = {
        "generated": datetime.now().isoformat(),
        "total_entries": len(entries),
        "datasets": [r["dataset"] for r in summary],
        "record": {"wins": wins, "ties": ties, "losses": losses, "missing": missing},
        "summary": summary
    }

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)
    print(f"\nSaved to: {args.output}")

    # Print gaps
    print("\n" + "=" * 90)
    print("GAPS: Datasets needing matched baselines")
    print("=" * 90)
    for row in summary:
        if row.get("matched_acc_mean") is None and row.get("kept_clauses_mean"):
            # Kept clauses are split over two classes for the baseline target.
            target_cpc = int(round(row["kept_clauses_mean"] / 2))
            print(f"  {row['dataset']}: need baseline c={target_cpc} ({target_cpc*2} total clauses)")

# Run the aggregation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
