#!/usr/bin/env python3
"""
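Check logged SSD experiment results against simple pass/fail thresholds.

Reads the JSONL log written by run_and_log.py, prints every failing check,
and exits with status 1 if there are any.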
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List


def load_jsonl(path: Path) -> List[dict]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    Args:
        path: file to read (decoded as UTF-8).

    Returns:
        The parsed objects, in file order. Blank / whitespace-only lines are
        skipped.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        json.JSONDecodeError: if a line is not valid JSON. The message is
            prefixed with ``path`` and the 1-based line number in the file so
            the bad record can be located; the exception type is unchanged so
            existing ``except json.JSONDecodeError`` handlers still work.
    """
    records: List[dict] = []
    with path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # exc.lineno/colno are relative to this single line, so add the
                # position within the file to make the error actionable.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    return records


def distinct_count(values: List[float], tol: float = 1e-12) -> int:
    """Count the values in ``values`` that are distinct up to an absolute ``tol``.

    Rounding to a fixed number of decimals is not a real tolerance. Two values
    closer than ``tol`` that straddle a rounding boundary (e.g. 1.0000000000004
    and 1.0000000000006 at 12 decimals) would be counted separately. Instead,
    the values are sorted, and a new cluster starts only when a value exceeds
    the first element of the current cluster by more than ``tol``. Each
    cluster therefore spans at most ``tol``.

    Args:
        values: numbers to deduplicate (any order).
        tol: absolute tolerance; defaults to 1e-12, the resolution the
            previous 12-decimal rounding targeted.

    Returns:
        Number of clusters; 0 for an empty input.
    """
    count = 0
    anchor = None  # first (smallest) value of the current cluster
    for v in sorted(values):
        if anchor is None or v - anchor > tol:
            count += 1
            anchor = v
    return count


def _max_error_failures(name: str, rec: dict, meta: Dict, tol: float) -> List[str]:
    """Return ``[message]`` if the record's ``max_error`` check fails, else ``[]``.

    The check passes only when ``meta["max_error"]`` is present and strictly
    below ``tol``. It is written as ``not (err < tol)`` rather than
    ``err >= tol`` because every comparison with NaN is False. With
    ``err >= tol``, a NaN error would silently pass; here it is reported.
    """
    err = meta.get("max_error")
    if err is None or not (err < tol):
        return [f"{name} seed={rec.get('seed')} max_error={err} >= {tol}"]
    return []


def check_records(records: List[dict], tol: float) -> List[str]:
    """Apply the per-experiment pass/fail rules and collect failure messages.

    Args:
        records: parsed log records. Each is read through
            ``rec.get("experiment")`` (the experiment name that selects a rule),
            ``rec.get("meta")`` (a dict of measured values) and
            ``rec.get("seed")`` (only used in messages).
        tol: threshold every ``max_error`` must stay strictly below.

    Returns:
        One message per failed check, in record order; empty if all pass.
        Records whose name matches no rule (e.g. the time-scaling experiment)
        are ignored.
    """
    failures: List[str] = []
    for rec in records:
        # ``or`` also covers explicit JSON nulls; a null name would otherwise
        # crash ``name.startswith`` below.
        name = rec.get("experiment") or ""
        meta: Dict = rec.get("meta") or {}

        if name == "Scalar SSM ≡ 1-SS attention":
            failures += _max_error_failures(name, rec, meta, tol)

        elif name == "Diagonal SSM (N=2) ≡ sum of 1-SS heads":
            failures += _max_error_failures(name, rec, meta, tol)
            decays = meta.get("decays") or []
            gen_rank = meta.get("generator_rank")
            # Same expectation as the "Rank check:" rule below (number of
            # distinct decays), so repeated decays are judged consistently.
            expected_rank = distinct_count(decays)
            if gen_rank != expected_rank:
                failures.append(
                    f"{name} seed={rec.get('seed')} generator_rank={gen_rank} expected={expected_rank}"
                )

        elif name == "Diagonal SSM (time-varying A_t) ≡ sum of 1-SS heads":
            failures += _max_error_failures(name, rec, meta, tol)

        elif name.startswith("Rank check:"):
            decays = meta.get("decays") or []
            gen_rank = meta.get("generator_rank")
            expected = distinct_count(decays)
            if gen_rank != expected:
                failures.append(
                    f"{name} seed={rec.get('seed')} generator_rank={gen_rank} expected={expected}"
                )

        elif name == "Softmax attention rank growth":
            T = meta.get("T")
            rank = meta.get("rank")
            # Require the rank to reach at least half of T (and never less than 2).
            if T is None or rank is None or rank < max(2, int(0.5 * T)):
                failures.append(
                    f"{name} seed={rec.get('seed')} rank={rank} too low for T={T}"
                )

        # Time scaling experiment is informational; no strict pass/fail.

    return failures


def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse them from ``sys.argv``.

    Options:
        --log: JSONL log to check; the default path is relative to the
            current working directory.
        --tol: tolerance passed on to the ``max_error`` checks.
    """
    default_log = Path("experiments") / "logs" / "ssd_runs.jsonl"

    cli = argparse.ArgumentParser(
        description="Check SSD experiment logs for pass/fail."
    )
    cli.add_argument(
        "--log", type=Path, default=default_log,
        help="Path to JSONL log file produced by run_and_log.py",
    )
    cli.add_argument(
        "--tol", type=float, default=1e-10,
        help="Tolerance for max_error checks.",
    )

    namespace = cli.parse_args()
    return namespace


def main() -> None:
    """Run every check on the log named on the command line and report the result.

    Prints each failure and exits with status 1 when any check fails. It also
    exits with status 1 (with a message on stderr) when the log file is
    missing or contains no records. An empty log verifies nothing, and
    reporting "All checks passed" for it would hide a broken logging step.
    Otherwise it prints a one-line success summary.
    """
    args = parse_args()
    try:
        records = load_jsonl(args.log)
    except FileNotFoundError:
        # SystemExit with a string prints it to stderr and exits with status 1,
        # giving a clear message instead of a traceback.
        raise SystemExit(f"Log file not found: {args.log}") from None

    if not records:
        raise SystemExit(f"No records found in {args.log}; nothing to check.")

    failures = check_records(records, args.tol)
    if failures:
        print("FAILURES detected:")
        for msg in failures:
            print(f"- {msg}")
        raise SystemExit(1)

    print(f"All checks passed for {len(records)} records (tol={args.tol}).")


# Run the checker only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
