"""
Experiment 2b: Time-varying diagonal SSM vs. sum of masked attentions.

Generalizes Experiment 2 to allow a different diagonal A_t at every timestep.
"""

from __future__ import annotations

from typing import Optional

import numpy as np

from .common import ExperimentResult


def _build_time_varying_mask(A_vals: np.ndarray) -> np.ndarray:
    """
    Assemble the causal decay mask of a time-varying diagonal SSM.

    A_vals: shape (T, N); A_vals[t, m] is the decay applied to state m at step t.
    Returns an array of shape (T, T, N) with
        mask[t, s, m] = A_vals[s+1, m] * ... * A_vals[t, m]   for s < t,
        mask[t, t, m] = 1                                      (empty product),
        mask[t, s, m] = 0                                      for s > t.
    Row 0 of A_vals is never read.
    """
    n_steps, n_states = A_vals.shape
    out = np.zeros((n_steps, n_steps, n_states), dtype=np.float64)
    for target in range(n_steps):
        out[target, target, :] = 1.0
        if target == 0:
            continue
        # Walk backwards from `target` down to step 1; the j-th running product
        # is the decay from source position target-1-j up to `target`.
        running = np.cumprod(A_vals[target:0:-1], axis=0)
        # Flip so that the products line up with source positions 0..target-1.
        out[target, :target, :] = running[::-1]
    return out


def _expand_per_step(
    vals: Optional[np.ndarray], T: int, N: int, name: str
) -> np.ndarray:
    """Return `vals` as a (T, N) float64 matrix; None means all ones, (N,) is repeated per step."""
    if vals is None:
        return np.ones((T, N), dtype=np.float64)
    mat = np.asarray(vals, dtype=np.float64)
    if mat.ndim == 1:
        if mat.shape[0] != N:
            raise AssertionError(f"{name} length must equal N")
        return np.tile(mat, (T, 1))
    if mat.shape != (T, N):
        raise AssertionError(f"{name} must be shape (T, N) or (N,)")
    return mat


def run(
    T: int = 12,
    N: int = 3,
    seed: int = 0,
    A_vals: Optional[np.ndarray] = None,  # shape (T, N)
    B_vals: Optional[np.ndarray] = None,  # shape (T, N) or (N,)
    C_vals: Optional[np.ndarray] = None,  # shape (T, N) or (N,)
    u: Optional[np.ndarray] = None,  # shape (T,)
) -> ExperimentResult:
    """
    Compare a time-varying diagonal SSM recurrence with its masked-attention form.

    The recurrence x_t = A_t * x_{t-1} + B_t * u_t, y_t = C_t . x_t (zero initial
    state) is evaluated step by step, then again as y = M @ u where
    M[t, s] = sum_m C[t, m] * mask[t, s, m] * B[s, m]. The largest absolute
    difference between the two outputs and the numerical rank of M are reported.

    Parameters
    ----------
    T, N : sequence length and state size. Ignored when `A_vals` is given, in
        which case both are taken from `A_vals.shape`.
    seed : seed for the random generator used for any defaulted `A_vals` / `u`.
    A_vals : per-step decays, shape (T, N). Default: uniform draws in [0.5, 0.9).
        Because the initial state is zero, row 0 has no effect on the output.
    B_vals, C_vals : input / output weights, shape (T, N) or (N,) (the latter is
        repeated for every step). Default: all ones.
    u : scalar input sequence, shape (T,). Default: standard-normal draws.

    Returns
    -------
    ExperimentResult constructed from a title, a human-readable details string
    and a meta dict with keys T, N, seed, matrix_rank, A_shape, max_error.

    Raises
    ------
    ValueError : `A_vals` is not 2-D.
    AssertionError : `B_vals`, `C_vals` or `u` has an incompatible shape. The
        exception type of the original assert-based checks is kept, but it is
        raised explicitly so the checks still run under ``python -O``.
    """
    rng = np.random.default_rng(seed)
    # Draw order (A first, then u) is fixed so a given seed stays reproducible.
    if A_vals is None:
        A_vals = rng.uniform(0.5, 0.9, size=(T, N))
    else:
        A_vals = np.asarray(A_vals, dtype=np.float64)
        if A_vals.ndim != 2:
            raise ValueError(
                f"A_vals must be 2-D with shape (T, N), got shape {A_vals.shape}"
            )
        T, N = A_vals.shape

    B_mat = _expand_per_step(B_vals, T, N, "B_vals")
    C_mat = _expand_per_step(C_vals, T, N, "C_vals")

    if u is None:
        u = rng.standard_normal(T)
    else:
        u = np.asarray(u, dtype=np.float64)
        # Require the exact shape: a 2-D u whose first axis is T would otherwise
        # broadcast through the computation and yield a meaningless max_error.
        if u.shape != (T,):
            raise AssertionError(f"u must be shape (T,) = ({T},), got {u.shape}")

    # Recurrence
    x = np.zeros(N, dtype=np.float64)
    y = np.zeros(T, dtype=np.float64)
    for t in range(T):
        x = A_vals[t] * x + B_mat[t] * u[t]
        y[t] = float(C_mat[t] @ x)

    # Attention-style construction
    mask = _build_time_varying_mask(A_vals)  # shape (T, T, N)
    # Broadcast B over source positions, C over target positions
    contrib = mask * C_mat[:, None, :] * B_mat[None, :, :]
    M = np.sum(contrib, axis=2)  # (T, T)
    y_att = M @ u

    max_error = float(np.max(np.abs(y - y_att)))
    matrix_rank = int(np.linalg.matrix_rank(M))
    details = (
        f"T={T}, N={N} (time-varying A_t), max |y - y_att| = {max_error:.3e}, "
        f"matrix_rank={matrix_rank}"
    )
    meta = {
        "T": T,
        "N": N,
        "seed": seed,
        "matrix_rank": matrix_rank,
        "A_shape": list(A_vals.shape),
        "max_error": max_error,
    }
    return ExperimentResult("Diagonal SSM (time-varying A_t) ≡ sum of 1-SS heads", details, meta)


if __name__ == "__main__":
    # Script entry point: run the experiment with its defaults and print a one-line summary.
    result = run()
    print(f"[{result.name}] {result.details}")
