#!/usr/bin/env python3
"""
Run the SSD experiments across multiple seeds and log the results to both a JSONL and a CSV file.
"""

from __future__ import annotations

import argparse
import csv
import json
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Sequence

from .verify_ssd import gather_results
from . import exp4_time_scaling


def _timestamp() -> str:
    return datetime.now(timezone.utc).isoformat()


def _ensure_parent(path: Path) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)


def run_and_collect(seeds: Iterable[int], **gather_kwargs: object) -> List[dict]:
    """Run ``gather_results`` once per seed and flatten the results into log records.

    Args:
        seeds: Seeds to run; each one is passed as ``seed=`` to ``gather_results``.
        **gather_kwargs: Extra keyword options (e.g. ``a_list=``,
            ``exp4_T_values=``) forwarded unchanged to every ``gather_results``
            call. Omitting them reproduces the previous default-only behaviour.

    Returns:
        One dict per result with keys ``timestamp``, ``seed``, ``experiment``
        (``res.name``), ``details`` (``res.details``) and ``meta``
        (``res.meta``, or ``{}`` when falsy).
    """
    records: List[dict] = []
    for seed in seeds:
        for res in gather_results(seed=seed, **gather_kwargs):
            records.append(
                {
                    "timestamp": _timestamp(),
                    "seed": seed,
                    "experiment": res.name,
                    "details": res.details,
                    "meta": res.meta or {},
                }
            )
    return records


def write_jsonl(records: List[dict], path: Path) -> None:
    """Serialize ``records`` to ``path`` in JSON Lines format, one object per line.

    Missing parent directories are created; an existing file is overwritten.
    """
    _ensure_parent(path)
    with path.open("w", encoding="utf-8") as handle:
        # Lazy generator: each record is serialized just before it is written.
        handle.writelines(json.dumps(record) + "\n" for record in records)


def write_csv(records: List[dict], path: Path) -> None:
    """Write ``records`` to ``path`` as CSV with a fixed header row.

    Columns are ``timestamp, seed, experiment, details, meta``.

    - ``meta`` is always JSON-encoded.
    - ``details`` is JSON-encoded when it is a dict, list or tuple. Structured
      data therefore lands in the cell as JSON rather than a Python repr.
      Any other value is written as-is.

    Missing parent directories are created; an existing file is overwritten.

    Raises:
        ValueError: from ``csv.DictWriter`` if a record has keys outside the header.
    """
    _ensure_parent(path)
    fieldnames = ["timestamp", "seed", "experiment", "details", "meta"]
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for rec in records:
            row = dict(rec)
            row["meta"] = json.dumps(rec.get("meta", {}))
            details = rec.get("details")
            # Keep CSV cells machine-readable: str(dict) would give a Python
            # repr with single quotes, which JSON parsers cannot read back.
            if isinstance(details, (dict, list, tuple)):
                row["details"] = json.dumps(details)
            writer.writerow(row)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        The parsed namespace. List-valued options (``--a-list``, ``--exp4-t-list``,
        ...) are returned as raw comma-separated strings, or ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(description="Run SSD experiments and log results.")
    parser.add_argument(
        "--experiment",
        choices=["all", "exp1", "exp2", "exp2b", "exp3", "exp4", "exp5"],
        default="all",
        help="Run all experiments or a single experiment.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=5,
        help="Number of seeds/runs (seeds = 0..repeats-1).",
    )
    parser.add_argument(
        "--a-list",
        type=str,
        default=None,
        help="Comma-separated decay values for Exp1 (e.g., 0.5,0.8,0.9).",
    )
    parser.add_argument(
        "--t-list",
        type=str,
        default=None,
        help="Comma-separated T values for Exp1 (e.g., 20,50).",
    )
    parser.add_argument(
        "--auto-grid",
        action="store_true",
        help=(
            "Use a default grid of a values [0.5,0.8,0.9] and T values [20,50] for Exp1; "
            "also applies the Exp2/Exp2b default grids (see --auto-grid-exp2)."
        ),
    )
    parser.add_argument(
        "--jsonl",
        type=Path,
        default=Path("experiments/logs/ssd_runs.jsonl"),
        help="Path to JSONL log file.",
    )
    parser.add_argument(
        "--csv",
        type=Path,
        default=Path("experiments/logs/ssd_runs.csv"),
        help="Path to CSV log file.",
    )
    parser.add_argument(
        "--exp2-t-list",
        type=str,
        default=None,
        help="Comma-separated T values for Exp2 (e.g., 20,40).",
    )
    parser.add_argument(
        "--exp2-n-list",
        type=str,
        default=None,
        help="Comma-separated N values for Exp2 decays (e.g., 2,3).",
    )
    parser.add_argument(
        "--exp2b-t-list",
        type=str,
        default=None,
        help="Comma-separated T values for Exp2b (time-varying) (e.g., 12,16).",
    )
    parser.add_argument(
        "--exp2b-n-list",
        type=str,
        default=None,
        help="Comma-separated N values for Exp2b (e.g., 3,4).",
    )
    parser.add_argument(
        "--auto-grid-exp2",
        action="store_true",
        help="Use default grids for Exp2 (T=[20,40], N=[2,3]) and Exp2b (T=[12,16], N=[3]).",
    )
    parser.add_argument(
        "--exp3-t-list",
        type=str,
        default=None,
        help="Comma-separated T values for Exp3 (e.g., 15,30).",
    )
    parser.add_argument(
        "--exp3-seeds",
        type=str,
        default=None,
        help="Comma-separated seeds for Exp3 (e.g., 0,1,2).",
    )
    parser.add_argument(
        "--exp3-random-n",
        type=str,
        default=None,
        help="Comma-separated N values to add random decay cases for Exp3 (e.g., 2,3,4).",
    )
    parser.add_argument(
        "--exp4-t-list",
        type=str,
        default=None,
        help="Comma-separated T values for Exp4 (e.g., 150,300,600).",
    )
    parser.add_argument(
        "--exp4-n-list",
        type=str,
        default=None,
        help="Comma-separated N values for Exp4 (e.g., 2,4,8).",
    )
    parser.add_argument(
        "--exp4-d-list",
        type=str,
        default=None,
        help="Comma-separated d values for Exp4 (e.g., 8,16,32).",
    )
    parser.add_argument(
        "--exp4-end-to-end",
        action="store_true",
        help="If set, Exp4 measures full end-to-end timing (kernel build + matmul) instead of core loops only.",
    )
    parser.add_argument(
        "--exp4-plot",
        action="store_true",
        help="If set, save an Exp4 timing plot after runs complete.",
    )
    parser.add_argument(
        "--exp4-plot-path",
        type=Path,
        default=Path("outputs/exp4_time_scaling.png"),
        help="Path to save the Exp4 plot (used when --exp4-plot is set).",
    )
    parser.add_argument(
        "--exp5-t-list",
        type=str,
        default=None,
        help="Comma-separated T values for Exp5 (e.g., 20,40,80,160,320).",
    )
    return parser.parse_args(argv)


# Default grids applied by --auto-grid (Exp1) and --auto-grid-exp2 (Exp2/Exp2b);
# the values mirror the CLI help text in parse_args.
_AUTO_GRID_EXP1_A = (0.5, 0.8, 0.9)
_AUTO_GRID_EXP1_T = (20, 50)
_AUTO_GRID_EXP2_T = (20, 40)
_AUTO_GRID_EXP2_N = (2, 3)
_AUTO_GRID_EXP2B_T = (12, 16)
_AUTO_GRID_EXP2B_N = (3,)


def _parse_list(raw: Optional[str], cast: Callable[[str], object]) -> Optional[list]:
    """Split a comma-separated CLI value and convert each item with ``cast``.

    Surrounding whitespace and empty items (e.g. from a trailing comma in
    ``"20,50,"``) are ignored. Returns ``None`` when ``raw`` is unset or holds no
    items, which is the value passed to ``gather_results`` when an option is omitted.
    """
    if not raw:
        return None
    values = [cast(item.strip()) for item in raw.split(",") if item.strip()]
    return values or None


def _matches_experiment(name: str, target: str) -> bool:
    """Return True if a result called ``name`` belongs to the ``--experiment`` key ``target``."""
    if target == "exp1":
        return name.startswith("Scalar SSM")
    if target == "exp2":
        # NOTE(review): only matches the N=2 diagonal case. If --exp2-n-list adds
        # other N values and their result names differ, those results are
        # filtered out. Confirm against verify_ssd naming.
        return "Diagonal SSM (N=2)" in name
    if target == "exp2b":
        return "time-varying A_t" in name
    if target == "exp3":
        return name.startswith("Rank check")
    if target == "exp4":
        return name.startswith("Time scaling")
    if target == "exp5":
        return name.startswith("Softmax attention rank growth")
    # "all" keeps every result.
    return True


def _make_record(seed: int, res) -> dict:
    """Build one log record for result ``res`` produced with ``seed``."""
    return {
        "timestamp": _timestamp(),
        "seed": seed,
        "experiment": res.name,
        "details": res.details,
        "meta": res.meta or {},
    }


def main() -> None:
    """Run the SSD experiments for every seed, log the results and optionally plot Exp4.

    Behaviour:
    - Seeds are ``0..--repeats-1``.
    - Every seed runs the full ``gather_results`` suite. ``--experiment`` only
      filters which results are logged and plotted.
    - The kept results are written to both ``--jsonl`` and ``--csv``.
    """
    args = parse_args()
    seeds = range(args.repeats)

    if args.auto_grid:
        a_list = list(_AUTO_GRID_EXP1_A)
        t_list = list(_AUTO_GRID_EXP1_T)
    else:
        a_list = _parse_list(args.a_list, float)
        t_list = _parse_list(args.t_list, int)

    # --auto-grid has always implied the Exp2/Exp2b default grids too; keep that,
    # and also honour --auto-grid-exp2, which used to be parsed but ignored.
    if args.auto_grid or args.auto_grid_exp2:
        exp2_T = list(_AUTO_GRID_EXP2_T)
        exp2_N = list(_AUTO_GRID_EXP2_N)
        exp2b_T = list(_AUTO_GRID_EXP2B_T)
        exp2b_N = list(_AUTO_GRID_EXP2B_N)
    else:
        exp2_T = _parse_list(args.exp2_t_list, int)
        exp2_N = _parse_list(args.exp2_n_list, int)
        exp2b_T = _parse_list(args.exp2b_t_list, int)
        exp2b_N = _parse_list(args.exp2b_n_list, int)

    gather_options = {
        "a_list": a_list,
        "T_list": t_list,
        "exp2_T_list": exp2_T,
        "exp2_N_list": exp2_N,
        "exp2b_T_list": exp2b_T,
        "exp2b_N_list": exp2b_N,
        "exp3_T_values": _parse_list(args.exp3_t_list, int),
        "exp3_seeds": _parse_list(args.exp3_seeds, int),
        "exp3_random_N": _parse_list(args.exp3_random_n, int),
        "exp4_T_values": _parse_list(args.exp4_t_list, int),
        "exp4_N_values": _parse_list(args.exp4_n_list, int),
        "exp4_d_values": _parse_list(args.exp4_d_list, int),
        "exp4_end_to_end": args.exp4_end_to_end,
        "exp5_T_values": _parse_list(args.exp5_t_list, int),
    }

    records: List[dict] = []
    all_results = []
    for seed in seeds:
        for res in gather_results(seed=seed, **gather_options):
            if not _matches_experiment(res.name, args.experiment):
                continue
            all_results.append(res)
            records.append(_make_record(seed, res))

    write_jsonl(records, args.jsonl)
    write_csv(records, args.csv)
    print(f"Logged {len(records)} experiment results")
    print(f"JSONL: {args.jsonl}")
    print(f"CSV:   {args.csv}")
    if args.exp4_plot:
        exp4_results = [r for r in all_results if r.name.startswith("Time scaling")]
        if not exp4_results:
            print("No Exp4 results available for plotting.")
        else:
            # The default plot directory ("outputs/") may not exist yet; creating
            # it here is harmless if plot_results also does so.
            _ensure_parent(args.exp4_plot_path)
            exp4_time_scaling.plot_results(exp4_results, args.exp4_plot_path)

# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
