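"""
Experiment 2: diagonal SSM (N=2 states by default) vs. a sum of
1-semiseparable (1-SS) masked attentions.

Runs a diagonal linear SSM recurrence on a random input and checks that its
output matches multiplying the input by the sum of per-state masked decay
matrices C[m] * B[m] * A[m] ** (t - s).
"""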

from __future__ import annotations

import numpy as np

from .common import ExperimentResult, causal_mask, generator_rank


def run(
    T: int = 20,
    decays: tuple[float, ...] = (0.5, 0.8),
    B_vec: tuple[float, ...] | None = None,
    C_vec: tuple[float, ...] | None = None,
    seed: int = 0,
) -> ExperimentResult:
    """Compare a diagonal SSM recurrence against its masked-attention form.

    A random input ``u`` of length ``T`` is evaluated two ways:

    1. Recurrently: ``x_t = A * x_{t-1} + B * u_t`` (with ``x`` starting at
       zero) and ``y_t = C . x_t``, where ``A`` is diagonal (``decays``).
    2. As one matrix product ``y_att = M @ u`` with
       ``M = sum_m C[m] * B[m] * A[m] ** (t - s) * mask``. Here ``mask`` comes
       from ``causal_mask(T)`` and is presumably 1 for ``s <= t``, else 0.

    The two outputs should agree up to floating-point error.

    Args:
        T: Sequence length; must be at least 1.
        decays: Diagonal entries of ``A``, one per state. The default uses two
            states, but any number is accepted.
        B_vec: Input projection, one entry per state. Defaults to all ones.
        C_vec: Output projection, one entry per state. Defaults to all ones.
        seed: Seed for the NumPy generator that draws ``u``.

    Returns:
        An ``ExperimentResult`` with a one-line details string and a metadata
        dict recording the inputs, ``max_error`` (max absolute difference
        between the two outputs), ``matrix_rank`` of ``M`` and the value
        returned by ``generator_rank``.

    Raises:
        ValueError: If ``T < 1``, or if ``B_vec``/``C_vec`` do not have exactly
            one entry per decay.
    """
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    A_vals = np.asarray(decays, dtype=np.float64)
    N = A_vals.size
    B = np.ones(N, dtype=np.float64) if B_vec is None else np.asarray(B_vec, dtype=np.float64)
    C = np.ones(N, dtype=np.float64) if C_vec is None else np.asarray(C_vec, dtype=np.float64)
    # A size mismatch would otherwise raise an opaque broadcasting error or,
    # for a length-1 vector, silently broadcast one value to every state.
    for label, vec in (("B_vec", B), ("C_vec", C)):
        if vec.size != N:
            raise ValueError(
                f"{label} must have {N} entries (one per decay), got {vec.size}"
            )

    rng = np.random.default_rng(seed)
    u = rng.standard_normal(T)

    # Recurrent (sequential) evaluation of the SSM.
    x = np.zeros(N, dtype=np.float64)
    y = np.zeros(T, dtype=np.float64)
    for t in range(T):
        x = A_vals * x + B * u[t]
        y[t] = float(C @ x)

    # Masked-attention evaluation: one decay matrix per state, summed.
    mask, t_idx, s_idx = causal_mask(T)
    # Clamp the lag at 0 before exponentiating. Entries with s > t are
    # (presumably) zeroed by ``mask`` anyway. A raw negative exponent would
    # turn a zero decay, or a small decay over a long sequence, into inf,
    # and inf * 0 is nan.
    lag = np.maximum(t_idx - s_idx, 0)
    M = np.zeros((T, T), dtype=np.float64)
    for a_m, b_m, c_m in zip(A_vals, B, C):
        M += c_m * b_m * (a_m ** lag) * mask
    y_att = M @ u

    max_error = float(np.max(np.abs(y - y_att)))
    gen_rank = generator_rank(A_vals, T)
    matrix_rank = int(np.linalg.matrix_rank(M))
    # .tolist() yields plain Python floats. list(ndarray) yields np.float64,
    # which NumPy >= 2 renders as "np.float64(0.5)" inside f-strings.
    decays_list = A_vals.tolist()
    details = (
        f"T={T}, decays={decays_list}, max |y - y_att| = {max_error:.3e}, "
        f"generator_rank={gen_rank}/{N}, matrix_rank={matrix_rank}"
    )
    meta = {
        "T": T,
        "decays": decays_list,
        "B": B.tolist(),
        "C": C.tolist(),
        "seed": seed,
        "matrix_rank": matrix_rank,
        "generator_rank": gen_rank,
        "max_error": max_error,
    }
    return ExperimentResult(f"Diagonal SSM (N={N}) ≡ sum of 1-SS heads", details, meta)


if __name__ == "__main__":
    # Script entry point: run the default configuration, print a summary line.
    result = run()
    print("[{}] {}".format(result.name, result.details))
