"""
Experiment 4b: runtime scaling with aggregated trials and error bars.
"""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, List, Sequence

import numpy as np

from .common import (
    ExperimentResult,
    causal_mask,
    default_timer,
    run_attention,
    run_recurrence,
)


def _time_single_run(
    U: np.ndarray,
    A_vals: np.ndarray,
    end_to_end: bool,
) -> tuple[float, float]:
    """Time one recurrence pass and one attention-style pass on the same input.

    Args:
        U: Input array; its first axis is stepped over by the recurrence and
            its second axis sets the column count of the input matrix ``B``.
        A_vals: Per-mode decay coefficients (one state entry per value).
        end_to_end: When false, timing is delegated to ``run_recurrence`` and
            ``run_attention`` and their return values are used as-is. When
            true, both computations are carried out inline and the measured
            interval includes setup (allocating ``B``, building the mask and
            the kernel matrix).

    Returns:
        ``(recurrence_time, attention_time)`` as plain floats, in whatever
        units ``default_timer`` reports (the rest of this file labels them
        as seconds).
    """
    if not end_to_end:
        # "Core" mode: the shared helpers own the measurement.
        core_rec = run_recurrence(U, A_vals)
        core_att = run_attention(U, A_vals)
        return float(core_rec), float(core_att)

    # --- Recurrence, timed including allocation of B and the state vector ---
    tic = default_timer()
    n_modes = A_vals.size
    in_matrix = np.ones((n_modes, U.shape[1]), dtype=np.float64)
    state = np.zeros(n_modes, dtype=np.float64)
    for u_t in U:
        # Diagonal state update: elementwise decay plus projected input.
        state = A_vals * state + in_matrix @ u_t
    rec_elapsed = default_timer() - tic

    # --- Attention form, timed including mask and kernel construction ---
    tic = default_timer()
    seq_len = U.shape[0]
    # NOTE(review): causal_mask presumably returns a boolean causal mask plus
    # broadcastable row/column index grids -- confirm in .common.
    mask, row_idx, col_idx = causal_mask(seq_len)
    # Zero the lag where the mask is off so the power below stays well-defined.
    lag = np.where(mask, row_idx - col_idx, 0)
    kernel = np.zeros((seq_len, seq_len), dtype=np.float64)
    for mode in range(A_vals.size):
        kernel += (A_vals[mode] ** lag) * mask
    _ = kernel @ U  # result discarded; only the elapsed time matters
    att_elapsed = default_timer() - tic

    return float(rec_elapsed), float(att_elapsed)


def _summarize_samples(samples: Sequence[float]) -> tuple[float, float, float]:
    """Return ``(mean, sample std, 95% CI half-width)`` for per-trial timings."""
    n_samples = len(samples)
    # Empty input keeps the NaN mean np.mean would produce, minus the warning.
    mean = float(np.mean(samples)) if n_samples > 0 else float("nan")
    # Sample standard deviation (ddof=1) is undefined for a single sample.
    std = float(np.std(samples, ddof=1)) if n_samples > 1 else 0.0
    # Normal-approximation half-width: 1.96 * std / sqrt(n).
    scale = 1.96 / np.sqrt(n_samples) if n_samples > 0 else 0.0
    return mean, std, float(std * scale)


def run(
    T_values: Sequence[int] = (150, 300, 600, 1200, 2400, 4800, 9600),
    N_values: Iterable[int] = (4,),
    d_values: Iterable[int] = (16,),
    base_seed: int = 0,
    n_trials: int = 5,
    n_repeats: int = 3,
    end_to_end: bool = False,
    seeds: Sequence[int] | None = None,
) -> List[ExperimentResult]:
    """Benchmark recurrence vs. attention runtime over a grid of sizes.

    For every ``(N, d, T)`` combination, each trial seed draws a fresh
    standard-normal input of shape ``(T, d)``, times both computations
    ``n_repeats`` times, and keeps the per-trial mean. The per-trial means are
    then aggregated into a mean, a sample standard deviation and a 95%
    confidence half-width (normal approximation, ``1.96 * std / sqrt(n)``).

    Args:
        T_values: Sequence lengths to benchmark.
        N_values: State sizes; each yields ``N`` coefficients evenly spaced
            between 0.5 and 0.8.
        d_values: Input dimensions. May be any iterable, including a one-shot
            generator.
        base_seed: First trial seed when ``seeds`` is not given; trial ``i``
            uses ``base_seed + i``.
        n_trials: Number of trials when ``seeds`` is not given.
        n_repeats: Timed repetitions averaged within each trial.
        end_to_end: Forwarded to ``_time_single_run``; also selects the
            ``"full"`` / ``"core"`` mode label stored with each result.
        seeds: Explicit trial seeds; overrides ``base_seed`` / ``n_trials``.

    Returns:
        One ``ExperimentResult`` per ``(N, d, T)`` combination, in loop order,
        carrying a human-readable summary and a ``meta`` dict with the raw
        samples and aggregate statistics.

    Raises:
        ValueError: If ``n_repeats`` is smaller than 1 (every per-trial mean
            would otherwise be NaN).
    """
    if n_repeats < 1:
        raise ValueError("n_repeats must be at least 1.")

    trial_seeds = list(seeds) if seeds is not None else [base_seed + i for i in range(n_trials)]
    # d_values and T_values are re-iterated inside outer loops; materialize
    # them so a one-shot iterable is not exhausted after the first pass.
    d_list = list(d_values)
    T_list = list(T_values)
    mode = "full" if end_to_end else "core"

    results: List[ExperimentResult] = []
    for N in N_values:
        A_vals = np.linspace(0.5, 0.8, N, dtype=np.float64)
        for d in d_list:
            for T in T_list:
                rec_samples: list[float] = []
                att_samples: list[float] = []
                for trial_seed in trial_seeds:
                    # Re-seeding per trial makes each trial's input reproducible.
                    rng = np.random.default_rng(trial_seed)
                    U = rng.standard_normal((T, d))
                    rec_times: list[float] = []
                    att_times: list[float] = []
                    for _ in range(n_repeats):
                        rec_once, att_once = _time_single_run(U, A_vals, end_to_end=end_to_end)
                        rec_times.append(rec_once)
                        att_times.append(att_once)
                    rec_samples.append(float(np.mean(rec_times)))
                    att_samples.append(float(np.mean(att_times)))

                rec_mean, rec_std, rec_ci = _summarize_samples(rec_samples)
                att_mean, att_std, att_ci = _summarize_samples(att_samples)
                n_samples = len(rec_samples)
                details = (
                    f"T={T}, N={N}, d={d}, mode={mode}, trials={len(trial_seeds)}, "
                    f"recurrence_time={rec_mean:.4f}±{rec_ci:.4f}s (95% CI), "
                    f"attention_time={att_mean:.4f}±{att_ci:.4f}s (95% CI)"
                )
                meta = {
                    "T": T,
                    "N": N,
                    "d": d,
                    "mode": mode,
                    "end_to_end": end_to_end,
                    "n_trials": n_samples,
                    "recurrence_mean": rec_mean,
                    "recurrence_std": rec_std,
                    "recurrence_ci": rec_ci,
                    "attention_mean": att_mean,
                    "attention_std": att_std,
                    "attention_ci": att_ci,
                    "recurrence_samples": rec_samples,
                    "attention_samples": att_samples,
                    # Copy so results do not alias one shared, mutable seed list.
                    "trial_seeds": list(trial_seeds),
                }
                results.append(ExperimentResult("Time scaling with error bars", details, meta))
    return results


def plot_results(results: List[ExperimentResult], out_path: Path = Path("outputs/exp4b_time_scaling.png")) -> None:
    """Plot mean runtimes against sequence length with 95% CI error bars.

    Args:
        results: Results whose ``meta`` holds ``T``, ``recurrence_mean`` and
            ``attention_mean`` (and optionally ``recurrence_ci`` /
            ``attention_ci``, defaulting to 0.0). Entries with empty ``meta``
            are left out of the curves.
        out_path: Destination image file; parent directories are created if
            missing. A plain string path is accepted as well.

    Raises:
        ValueError: If the results span more than one ``(N, d, mode)``
            combination.

    Prints a notice and returns without plotting when matplotlib is not
    installed or ``results`` is empty.
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError:  # pragma: no cover
        print("matplotlib not available; skipping exp4b plot.")
        return

    if not results:
        print("No results to plot for exp4b.")
        return

    # Expect a single (N, d, mode) combination for plotting clarity.
    combos = {
        (
            r.meta.get("N") if r.meta else None,
            r.meta.get("d") if r.meta else None,
            r.meta.get("mode") if r.meta else None,
        )
        for r in results
    }
    if len(combos) > 1:
        raise ValueError("Plotting currently supports a single (N, d, mode) combination.")

    # Only results that carry metadata contribute points; order them by T.
    plotted = sorted((r for r in results if r.meta), key=lambda r: r.meta["T"])
    T_vals = [r.meta["T"] for r in plotted]
    rec_means = [r.meta["recurrence_mean"] for r in plotted]
    att_means = [r.meta["attention_mean"] for r in plotted]
    rec_errs = [r.meta.get("recurrence_ci", 0.0) for r in plotted]
    att_errs = [r.meta.get("attention_ci", 0.0) for r in plotted]

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.errorbar(
            T_vals,
            rec_means,
            yerr=rec_errs,
            fmt="-o",
            linewidth=4,
            alpha=0.9,
            color="darkgrey",
            ecolor="darkgrey",
            capsize=6,
            label="recurrence (O(T))",
        )
        ax.errorbar(
            T_vals,
            att_means,
            yerr=att_errs,
            fmt="-s",
            linewidth=4,
            alpha=0.9,
            color="darkred",
            ecolor="darkred",
            capsize=6,
            label="attention (O(T²))",
        )
        ax.set_xlabel("Sequence length T", fontsize=24)
        ax.set_ylabel("Time (s)", fontsize=24)
        ax.set_title("Experiment 4b: runtime scaling (95% CI)", fontsize=20, fontweight="bold", pad=18)
        ax.tick_params(axis="both", which="major", labelsize=20)
        ax.legend(fontsize=16)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_path, dpi=225, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved exp4b plot to {out_path}")


if __name__ == "__main__":
    # Script entry: run the benchmark with its defaults, echo each summary
    # line, then write the plot. Because this file imports `.common`
    # relatively, it has to be launched as a module of its package
    # (`python -m <package>.<module>`) rather than as a bare file path.
    collected = run()
    for entry in collected:
        print(f"[{entry.name}] {entry.details}")
    plot_results(collected)
