"""
Experiment 3: State dimension vs. semiseparable generator rank.
"""

from __future__ import annotations

from typing import Iterable, List, Sequence, Optional

import numpy as np

from .common import ExperimentResult, causal_mask, generator_rank


def default_cases() -> list[dict]:
    """Return the built-in decay configurations used by ``run``.

    Each entry is a newly built dict with an ``"a"`` list of per-state decay
    values and a human-readable ``"desc"`` label. A new list is produced on
    every call, so callers may mutate the result freely.
    """
    specs = (
        ([0.9], "N=1 distinct"),
        ([0.5, 0.8], "N=2 distinct"),
        # Two states sharing the same decay value.
        ([0.7, 0.7], "N=2 duplicate a1=a2"),
        ([0.4, 0.6, 0.9], "N=3 distinct"),
    )
    return [{"a": list(decays), "desc": label} for decays, label in specs]


def _resolve_decays(
    case: dict, rng: np.random.Generator
) -> tuple[np.ndarray, int, str]:
    """Return ``(decays, state_dim, label)`` for one case dict.

    ``case["a"] == "random"`` draws ``case.get("N", 2)`` decays uniformly from
    ``[0.3, 0.95)`` with ``rng``; otherwise ``case["a"]`` is taken as the decay
    values (a scalar is treated as a single decay).
    """
    a = case.get("a")
    # isinstance guard: ``ndarray == "random"`` is element-wise in NumPy and
    # yields an array whose truth value is ambiguous inside ``if``.
    if isinstance(a, str) and a == "random":
        N = case.get("N", 2)
        A_vals = rng.uniform(0.3, 0.95, size=N)
        return A_vals, N, case.get("desc", f"N={N} random decays")
    # Raises KeyError when "a" is absent, as before.
    A_vals = np.atleast_1d(np.asarray(case["a"], dtype=np.float64))
    N = A_vals.size
    return A_vals, N, case.get("desc", f"N={N} decays")


def _decay_matrix(
    A_vals: np.ndarray, lag: np.ndarray, mask: np.ndarray, T: int
) -> np.ndarray:
    """Return ``sum_m C[m] * B[m] * a_m ** lag * mask`` as a ``(T, T)`` array.

    ``B`` and ``C`` are fixed to ones, so only the decays shape the matrix.
    """
    B = np.ones(A_vals.size, dtype=np.float64)
    C = np.ones(A_vals.size, dtype=np.float64)
    M = np.zeros((T, T), dtype=np.float64)
    for c_m, b_m, a_m in zip(C, B, A_vals):
        M += c_m * b_m * (a_m ** lag) * mask
    return M


def run(
    T_values: Sequence[int] = (15,),
    cases: Iterable[dict] | None = None,
    seeds: Optional[Sequence[int]] = None,
) -> List[ExperimentResult]:
    """Compare state dimension, generator rank and numerical rank.

    For every sequence length ``T``, case and seed, a matrix
    ``M = sum_m C[m] * B[m] * a_m ** (t - s)`` (with ``B = C = 1``) is
    restricted by the mask returned from ``causal_mask(T)``. Its numerical
    rank (``np.linalg.matrix_rank``) is reported next to
    ``generator_rank(decays, T)``.

    Args:
        T_values: Sequence lengths to evaluate.
        cases: Case dicts. ``"a"`` is a list of decay values or the string
            ``"random"``, in which case ``"N"`` (default 2) decays are drawn
            from ``[0.3, 0.95)`` using the seed. ``"desc"`` is an optional
            label. Defaults to ``default_cases()``. Any iterable is accepted;
            it is materialized once so every ``T`` sees every case.
        seeds: RNG seeds; defaults to ``[0]``. Only ``"random"`` cases use the
            RNG, so fixed-decay cases repeat identically for each seed.

    Returns:
        One ``ExperimentResult`` per (T, case, seed), in that nesting order.
    """
    # Materialize once: a generator passed as ``cases`` would otherwise be
    # exhausted by the first T, silently dropping all later lengths.
    case_list = list(default_cases() if cases is None else cases)
    seed_list = list(seeds) if seeds is not None else [0]
    results: List[ExperimentResult] = []
    for T in T_values:
        mask, t_idx, s_idx = causal_mask(T)
        # Clamp negative lags to 0 before exponentiating. Entries with t < s
        # are presumably zeroed by the causal mask anyway (confirm against
        # causal_mask). Raising a decay to a large negative lag overflows to
        # inf (e.g. 0.3 ** -600), and a = 0 divides by zero; the resulting
        # inf * 0 = nan would break matrix_rank for long sequences.
        lag = np.maximum(t_idx - s_idx, 0)
        for case in case_list:
            for seed in seed_list:
                rng = np.random.default_rng(seed)
                A_vals, N, desc = _resolve_decays(case, rng)
                M = _decay_matrix(A_vals, lag, mask, T)
                gen_rank = generator_rank(A_vals, T)
                matrix_rank = int(np.linalg.matrix_rank(M))
                meta = {
                    "T": T,
                    "state_dim": N,
                    "decays": A_vals.tolist(),
                    "generator_rank": gen_rank,
                    "matrix_rank": matrix_rank,
                    "seed": seed,
                }
                results.append(
                    ExperimentResult(
                        f"Rank check: {desc}",
                        (
                            f"T={T}, state_dim={N}, generator_rank={gen_rank}, "
                            f"matrix_rank={matrix_rank}"
                        ),
                        meta,
                    )
                )
    return results


if __name__ == "__main__":
    # Script entry point: print one "[name] details" line per result.
    for outcome in run():
        line = f"[{outcome.name}] {outcome.details}"
        print(line)
