"""
Experiment 4: Runtime scaling O(T) vs. O(T^2) (recurrence vs. explicit attention).
"""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, List, Sequence

import numpy as np

from .common import (
    ExperimentResult,
    causal_mask,
    default_timer,
    run_attention,
    run_recurrence,
)


def _time_recurrence_end_to_end(U: np.ndarray, A_vals: np.ndarray) -> float:
    """Time one pass of the diagonal recurrence over ``U``, including setup.

    The timed region covers allocating the all-ones input matrix ``B`` and the
    zero initial state as well as the per-step update
    ``x = A_vals * x + B @ U[t]``.

    Returns:
        The difference between two ``default_timer()`` readings.
    """
    start = default_timer()
    B = np.ones((A_vals.size, U.shape[1]), dtype=np.float64)
    x = np.zeros(A_vals.size, dtype=np.float64)
    for u_t in U:  # one row of U per time step
        x = A_vals * x + B @ u_t
    return default_timer() - start


def _time_attention_end_to_end(U: np.ndarray, A_vals: np.ndarray) -> float:
    """Time the explicit (T x T) kernel path over ``U``, including kernel build.

    The timed region covers the ``causal_mask`` call, building the kernel
    ``M = sum_m A_vals[m] ** (t - s)`` restricted to the mask, and the final
    ``M @ U`` product.

    Returns:
        The difference between two ``default_timer()`` readings.
    """
    start = default_timer()
    T = U.shape[0]
    mask, t_idx, s_idx = causal_mask(T)
    # Zero the exponent outside the mask so the power is never taken with a
    # negative lag; those entries are multiplied by the mask below anyway.
    diff = np.where(mask, t_idx - s_idx, 0)
    M = np.zeros((T, T), dtype=np.float64)
    for a in A_vals:
        M += (a ** diff) * mask
    _ = M @ U
    return default_timer() - start


def run(
    T_values: Sequence[int] = (150, 300, 600, 1200, 2400, 4800, 9600),
    N_values: Iterable[int] = (4,),
    d_values: Iterable[int] = (16,),
    seed: int = 0,
    n_repeats: int = 3,
    end_to_end: bool = False,
) -> List[ExperimentResult]:
    """Time the recurrence against explicit attention over a grid of sizes.

    For every combination of ``N`` (state size), ``d`` (input width) and ``T``
    (sequence length) a random ``(T, d)`` input is drawn and both code paths
    are timed ``n_repeats`` times; the mean of each is reported.

    Args:
        T_values: Sequence lengths to benchmark.
        N_values: Numbers of decay coefficients; for each ``N`` the
            coefficients are ``np.linspace(0.5, 0.8, N)``.
        d_values: Input widths to benchmark.
        seed: Seed for the NumPy random generator that draws the inputs.
        n_repeats: Number of timing repetitions per configuration (>= 1).
        end_to_end: If true, time the inline "full" variants that include
            state/mask/kernel construction. Otherwise use ``run_recurrence``
            and ``run_attention`` from ``.common`` (described in this file as
            "core" timing; their return values are treated as elapsed times).

    Returns:
        One ``ExperimentResult`` per ``(N, d, T)`` combination, in loop order.
        Its meta dict holds the configuration, both mean times, and
        ``speedup`` = attention time / recurrence time (denominator floored at
        1e-12), or ``None`` when the attention time is exactly zero.

    Raises:
        ValueError: If ``n_repeats`` is less than 1.
    """
    if n_repeats < 1:
        # With no repetitions the means below would be NaN (mean of an empty list).
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")

    # The inner grids are re-iterated for every outer value, so one-shot
    # iterables (e.g. generators) must be materialised first or they would be
    # exhausted after the first outer iteration.
    T_values = tuple(T_values)
    N_values = tuple(N_values)
    d_values = tuple(d_values)

    rng = np.random.default_rng(seed)
    mode = "full" if end_to_end else "core"
    results: List[ExperimentResult] = []
    for N in N_values:
        A_vals = np.linspace(0.5, 0.8, N, dtype=np.float64)
        for d in d_values:
            for T in T_values:
                U = rng.standard_normal((T, d))
                rec_times = []
                att_times = []
                for _ in range(n_repeats):
                    if end_to_end:
                        # Full timing: include mask/kernel construction and matmul
                        rec_times.append(_time_recurrence_end_to_end(U, A_vals))
                        att_times.append(_time_attention_end_to_end(U, A_vals))
                    else:
                        # Core timing: just the recurrence loop or matmul with pre-built kernel
                        rec_times.append(run_recurrence(U, A_vals))
                        att_times.append(run_attention(U, A_vals))
                t_rec = float(np.mean(rec_times))
                t_att = float(np.mean(att_times))
                details = (
                    f"T={T}, N={N}, d={d}, mode={mode}, recurrence_time={t_rec:.4f}s, "
                    f"attention_time={t_att:.4f}s"
                )
                meta = {
                    "T": T,
                    "N": N,
                    "d": d,
                    "seed": seed,
                    "n_repeats": n_repeats,
                    "end_to_end": end_to_end,
                    "mode": mode,
                    "recurrence_time": t_rec,
                    "attention_time": t_att,
                    "speedup": None if t_att == 0 else t_att / max(t_rec, 1e-12),
                }
                results.append(ExperimentResult("Time scaling O(T) vs O(T^2)", details, meta))
    return results


def _mean_and_spread(samples: list[float]) -> tuple[float, float]:
    """Return the mean and sample standard deviation of ``samples``.

    The standard deviation uses ``ddof=1`` and is reported as 0.0 when there
    is only a single sample (where it would otherwise be undefined).
    """
    mean = float(np.mean(samples))
    spread = float(np.std(samples, ddof=1)) if len(samples) > 1 else 0.0
    return mean, spread


def plot_results(results: List[ExperimentResult], out_path: Path = Path("outputs/exp4_time_scaling.png")) -> None:
    """Plot mean recurrence and attention times against sequence length.

    Results are grouped by ``(T, N, d, mode, end_to_end)`` read from each
    result's ``meta``; results with empty/missing ``meta`` are skipped. One
    figure is produced per ``(N, d, mode, end_to_end)`` configuration, showing
    the mean time per ``T`` with a +/- one standard deviation band.

    Args:
        results: Experiment results whose ``meta`` provides ``T``, ``N``,
            ``d``, ``recurrence_time`` and ``attention_time`` (``mode`` and
            ``end_to_end`` default to ``"core"`` and ``False`` when absent).
        out_path: Target image path. When more than one configuration is
            present, a ``_N{N}_d{d}_{mode}_{full|core}`` suffix is inserted
            before the file extension so each figure gets its own file.

    Returns:
        None. Prints a notice and returns early if matplotlib cannot be
        imported; otherwise creates the parent directory of ``out_path`` and
        prints the path of every saved figure.
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError:  # pragma: no cover
        print("matplotlib not available; skipping exp4 plot.")
        return

    out_path.parent.mkdir(parents=True, exist_ok=True)

    # `results` may contain multiple seeds/runs (e.g. via experiments/run_and_log.py).
    # Aggregate by (T, N, d, mode, end_to_end) and plot mean +/- std bands across runs.
    grouped: dict[tuple[int, int, int, str, bool], dict[str, list[float]]] = {}
    for r in results:
        if not r.meta:
            continue
        key = (
            int(r.meta["T"]),
            int(r.meta["N"]),
            int(r.meta["d"]),
            str(r.meta.get("mode", "core")),
            bool(r.meta.get("end_to_end", False)),
        )
        bucket = grouped.setdefault(key, {"rec": [], "att": []})
        bucket["rec"].append(float(r.meta["recurrence_time"]))
        bucket["att"].append(float(r.meta["attention_time"]))

    # Regroup so that every configuration owns its list of per-T samples.
    by_config: dict[tuple[int, int, str, bool], list[tuple[int, list[float], list[float]]]] = {}
    for (T, N, d, mode, end_to_end), bucket in grouped.items():
        by_config.setdefault((N, d, mode, end_to_end), []).append((T, bucket["rec"], bucket["att"]))

    for (N, d, mode, end_to_end), rows in sorted(by_config.items()):
        rows_sorted = sorted(rows, key=lambda row: row[0])
        T_vals = [T for T, _, _ in rows_sorted]
        rec_stats = [_mean_and_spread(rec) for _, rec, _ in rows_sorted]
        att_stats = [_mean_and_spread(att) for _, _, att in rows_sorted]

        # (legend label, marker, colour, per-T (mean, std) pairs), drawn in this order.
        series = (
            ("Recurrence (O(T))", "o", "darkgrey", rec_stats),
            ("Attention (O(T²))", "s", "darkred", att_stats),
        )

        if len(by_config) == 1:
            out_path_eff = out_path
        else:
            suffix = f"_N{N}_d{d}_{mode}_{'full' if end_to_end else 'core'}"
            out_path_eff = out_path.with_name(out_path.stem + suffix + out_path.suffix)

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            for label, marker, color, stats in series:
                mean = np.array([m for m, _ in stats])
                std = np.array([s for _, s in stats])
                ax.plot(T_vals, mean, marker=marker, linewidth=4, alpha=0.9, color=color, label=label)
                ax.fill_between(
                    T_vals,
                    # Keep the lower edge of the band strictly positive.
                    np.maximum(1e-12, mean - std),
                    mean + std,
                    color=color,
                    alpha=0.2,
                )

            ax.set_xlabel("Sequence Length T", fontsize=24)
            ax.set_ylabel("Time (s)", fontsize=24)
            ax.set_title(
                "Experiment 4: Time Scaling",
                fontsize=20,
                fontweight="bold",
            )
            ax.tick_params(axis="both", which="major", labelsize=20)
            ax.grid(True, alpha=0.3)
            ax.legend(title="Kernel", fontsize=16, title_fontsize=16)
            fig.tight_layout()
            fig.savefig(out_path_eff, dpi=225, bbox_inches="tight")
        finally:
            # Always release the figure, even if layout or saving fails, so
            # repeated calls do not accumulate open figures in pyplot.
            plt.close(fig)
        print(f"Saved exp4 plot to {out_path_eff}")


if __name__ == "__main__":
    # Script entry point: run the benchmark with its default grid, echo one
    # summary line per result, then render the plot to the default location.
    results = run()
    summary_lines = [f"[{res.name}] {res.details}" for res in results]
    for line in summary_lines:
        print(line)
    plot_results(results, out_path=Path("outputs/exp4_time_scaling.png"))
