#!/usr/bin/env python3
"""
Aggregated runner for the structured state-space duality experiments.
"""

from __future__ import annotations

from typing import List, Iterable, Optional, Sequence
import numpy as np

from . import (
    exp1_scalar_equivalence,
    exp2_diagonal_equivalence,
    exp2b_timevarying_diagonal,
    exp3_rank_vs_state_dim,
    exp4_time_scaling,
    exp5_softmax_rank_growth,
)
from .common import ExperimentResult


def _drop_none(**kwargs: object) -> dict:
    """Return a copy of ``kwargs`` without the entries whose value is ``None``.

    Lets a caller forward only the options that were explicitly provided, so
    the callee's own parameter defaults apply to everything else.
    """
    return {key: value for key, value in kwargs.items() if value is not None}


def gather_results(
    seed: int = 0,
    a_list: Optional[Iterable[float]] = None,
    T_list: Optional[Iterable[int]] = None,
    exp2_T_list: Optional[Iterable[int]] = None,
    exp2_N_list: Optional[Iterable[int]] = None,
    exp2b_T_list: Optional[Iterable[int]] = None,
    exp2b_N_list: Optional[Iterable[int]] = None,
    exp3_T_values: Optional[Sequence[int]] = None,
    exp3_seeds: Optional[Sequence[int]] = None,
    exp3_random_N: Optional[Sequence[int]] = None,
    exp4_T_values: Optional[Sequence[int]] = None,
    exp4_N_values: Optional[Sequence[int]] = None,
    exp4_d_values: Optional[Sequence[int]] = None,
    exp4_end_to_end: bool = False,
    exp5_T_values: Optional[Sequence[int]] = None,
) -> List[ExperimentResult]:
    """Run every experiment (Exp1 .. Exp5) and collect their results in order.

    Every sweep argument is optional; ``None`` selects a default as described
    below. The Exp1/Exp2/Exp2b grid axes are materialized into lists first, so
    one-shot iterators (e.g. generators) are safe to pass for them.

    Args:
        seed: Seed forwarded to Exp1, Exp2, Exp2b, Exp4 and Exp5. Exp3 is
            seeded only through ``exp3_seeds``.
        a_list: Exp1 values of ``a``. Together with ``T_list`` this forms a
            grid. If both are ``None``, Exp1 runs once with its own defaults;
            if only one is given, the other defaults to ``[0.8]`` / ``[20]``.
        T_list: Exp1 values of ``T`` (see ``a_list``).
        exp2_T_list: Exp2 values of ``T`` (default ``[20]``).
        exp2_N_list: Exp2 values of ``N`` (default ``[2]``). For each ``N`` the
            decays passed to Exp2 are ``N`` evenly spaced values in
            ``[0.5, 0.9]``.
        exp2b_T_list: Exp2b values of ``T`` (default ``[12]``).
        exp2b_N_list: Exp2b values of ``N`` (default ``[3]``).
        exp3_T_values: Exp3 ``T_values`` (default ``(15,)``).
        exp3_seeds: Forwarded unchanged (possibly ``None``) to Exp3.
        exp3_random_N: For each entry ``n``, one extra Exp3 case with
            ``a="random"`` is appended after
            ``exp3_rank_vs_state_dim.default_cases()``.
        exp4_T_values: Exp4 ``T_values``; forwarded only when not ``None``,
            otherwise Exp4's own default applies.
        exp4_N_values: Exp4 ``N_values``; same forwarding rule.
        exp4_d_values: Exp4 ``d_values``; same forwarding rule.
        exp4_end_to_end: Forwarded to Exp4 as ``end_to_end``.
        exp5_T_values: Exp5 ``T_values``; forwarded only when not ``None``.

    Returns:
        The results of Exp1, Exp2, Exp2b, Exp3, Exp4 and Exp5, in that order.
    """
    results: List[ExperimentResult] = []

    # Exp1: with no grid requested, defer entirely to run()'s own defaults.
    if a_list is None and T_list is None:
        results.append(exp1_scalar_equivalence.run(seed=seed))
    else:
        a_vals = list(a_list) if a_list is not None else [0.8]
        T_vals = list(T_list) if T_list is not None else [20]
        for a in a_vals:
            for T in T_vals:
                results.append(exp1_scalar_equivalence.run(T=T, a=a, seed=seed))

    # Exp2: grid over T and N; each N gets N evenly spaced decays in [0.5, 0.9].
    # Both axes are materialized because the inner one is re-iterated for every
    # outer value, which would silently exhaust a one-shot iterator.
    exp2_T_vals = list(exp2_T_list) if exp2_T_list is not None else [20]
    exp2_N_vals = list(exp2_N_list) if exp2_N_list is not None else [2]
    for T2 in exp2_T_vals:
        for N2 in exp2_N_vals:
            decays = np.linspace(0.5, 0.9, N2).tolist()
            results.append(exp2_diagonal_equivalence.run(T=T2, decays=decays, seed=seed))

    # Exp2b: time-varying grid over T and N, same materialization as Exp2.
    exp2b_T_vals = list(exp2b_T_list) if exp2b_T_list is not None else [12]
    exp2b_N_vals = list(exp2b_N_list) if exp2b_N_list is not None else [3]
    for T2b in exp2b_T_vals:
        for N2b in exp2b_N_vals:
            results.append(exp2b_timevarying_diagonal.run(T=T2b, N=N2b, seed=seed))

    # Exp3: copy the default cases so that appending the random cases never
    # mutates a list default_cases() might hand out to other callers.
    cases = list(exp3_rank_vs_state_dim.default_cases())
    if exp3_random_N:
        for n in exp3_random_N:
            cases.append({"a": "random", "N": n, "desc": f"N={n} random"})
    results.extend(
        exp3_rank_vs_state_dim.run(
            T_values=exp3_T_values if exp3_T_values is not None else (15,),
            cases=cases,
            seeds=exp3_seeds,
        )
    )

    # Exp4 / Exp5: forward only the sweeps the caller supplied and let each
    # run() apply its own defaults. Indexing run.__defaults__ instead would
    # depend on positional parameter order and fail for keyword-only params.
    results.extend(
        exp4_time_scaling.run(
            seed=seed,
            end_to_end=exp4_end_to_end,
            **_drop_none(
                T_values=exp4_T_values,
                N_values=exp4_N_values,
                d_values=exp4_d_values,
            ),
        )
    )
    results.extend(
        exp5_softmax_rank_growth.run(
            seed=seed,
            **_drop_none(T_values=exp5_T_values),
        )
    )
    return results


def main() -> None:
    """Print a header, then one ``[name] details`` line per experiment result.

    All experiments are run through :func:`gather_results` with its default
    arguments. The header is written before any experiment starts.
    """
    banner = "Structured State-Space Duality: experiment summaries\n"
    print(banner)
    all_results = gather_results()
    for outcome in all_results:
        title, summary = outcome.name, outcome.details
        print(f"[{title}] {summary}")


if __name__ == "__main__":
    # Executed only when this module is run directly, not on import.
    main()
