"""
Shared utilities for the structured state-space duality experiments.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional, Sequence

import numpy as np


@dataclass
class ExperimentResult:
    """Plain record holding the outcome of one experiment.

    The class carries data only; all behaviour (``__init__``, ``__repr__``,
    ``__eq__``) is what ``@dataclass`` generates.
    """

    # Label for the result -- presumably the experiment's name; confirm with callers.
    name: str
    # Free-form text attached to the result.
    details: str
    # Optional dictionary of extra data; ``None`` when not supplied.
    meta: Optional[dict] = None


def causal_mask(T: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build a lower-triangular (causal) mask together with its index grids.

    Parameters
    ----------
    T : int
        Number of time steps.

    Returns
    -------
    tuple of np.ndarray
        ``(mask, idx_t, idx_s)`` where ``mask`` is a ``(T, T)`` float64 array
        holding 1.0 wherever ``t >= s`` and 0.0 elsewhere, ``idx_t`` is the
        ``(T, 1)`` column of target indices and ``idx_s`` is the ``(1, T)``
        row of source indices.
    """
    rows = np.expand_dims(np.arange(T), axis=1)
    cols = np.expand_dims(np.arange(T), axis=0)
    # Broadcasting (1, T) against (T, 1) yields the full (T, T) comparison.
    lower_triangular = np.asarray(cols <= rows, dtype=np.float64)
    return lower_triangular, rows, cols


def generator_rank(decays: Sequence[float], T: int) -> int:
    """Return the numerical rank of the matrix of decay powers.

    Column ``j`` of the matrix holds ``decays[j] ** t`` for
    ``t = 0, 1, ..., T - 1``, so the matrix has ``T`` rows and one column per
    decay value.

    Parameters
    ----------
    decays : Sequence[float]
        Decay values; each one generates a column.
    T : int
        Number of time steps (rows).

    Returns
    -------
    int
        Rank as computed by ``np.linalg.matrix_rank``. An empty matrix
        (no decay values, or ``T <= 0``) has rank 0.
    """
    bases = list(decays)
    # ``np.column_stack`` cannot stack an empty list and ``matrix_rank``
    # cannot reduce over zero singular values, so answer the degenerate
    # cases directly.
    if T <= 0 or not bases:
        return 0
    time_axis = np.arange(T, dtype=np.float64)
    vandermonde = np.column_stack([decay ** time_axis for decay in bases])
    return int(np.linalg.matrix_rank(vandermonde))


def run_recurrence(U: np.ndarray, A_vals: np.ndarray) -> float:
    """Time a step-by-step diagonal linear recurrence over the rows of ``U``.

    With ``N = A_vals.size`` and ``d = U.shape[1]``, the state starts at zero
    and is updated once per row of ``U`` as
    ``x <- A_vals * x + B @ U[t]``, where ``B`` is an all-ones ``(N, d)``
    matrix.

    Parameters
    ----------
    U : np.ndarray
        Input sequence; one row per time step.
    A_vals : np.ndarray
        Per-state multipliers applied elementwise to the state.

    Returns
    -------
    float
        Elapsed time of the update loop as reported by ``default_timer``.
        Array allocation is excluded and the final state is discarded.
    """
    state_dim = A_vals.size
    input_dim = U.shape[1]
    input_map = np.ones((state_dim, input_dim), dtype=np.float64)
    state = np.zeros(state_dim, dtype=np.float64)

    tic = default_timer()
    for u_t in U:
        state = A_vals * state + input_map @ u_t
    toc = default_timer()
    return toc - tic


def run_attention(U: np.ndarray, A_vals: np.ndarray) -> float:
    """Time the matrix-form ("attention") application of the decay kernel.

    A ``(T, T)`` kernel is assembled by summing, over every entry ``a`` of
    ``A_vals``, the causally masked powers ``a ** (t - s)``. Only the final
    product ``kernel @ U`` is timed; building the kernel is not.

    Parameters
    ----------
    U : np.ndarray
        Two-dimensional input sequence with ``T`` rows.
    A_vals : np.ndarray
        Decay values, indexed ``A_vals[0] .. A_vals[A_vals.size - 1]``.

    Returns
    -------
    float
        Elapsed time of the single matrix product as reported by
        ``default_timer``. The product itself is discarded.
    """
    seq_len, _ = U.shape
    mask, t_idx, s_idx = causal_mask(seq_len)
    # Above the diagonal (t < s) the lag is negative. Clamp it to zero so the
    # power below stays well-behaved; those entries are zeroed by the mask
    # anyway.
    lag = np.maximum(t_idx - s_idx, 0)
    kernel = np.zeros((seq_len, seq_len), dtype=np.float64)
    for mode in range(A_vals.size):
        decay_powers = A_vals[mode] ** lag
        kernel += decay_powers * mask

    tic = default_timer()
    _ = kernel @ U
    return default_timer() - tic


def stable_softmax(vec: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
    """Compute a softmax with the maximum subtracted for numerical stability.

    Parameters
    ----------
    vec : np.ndarray
        Input scores.
    axis : int, optional
        Axis along which to normalise. The default ``None`` normalises over
        all elements of ``vec`` jointly, so the whole result sums to 1.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``vec`` whose entries along ``axis`` (or
        over the whole array when ``axis`` is ``None``) sum to 1.
    """
    # Shifting by the maximum leaves the result unchanged mathematically but
    # keeps every exponent <= 0, so ``np.exp`` cannot overflow.
    # ``keepdims=True`` makes the reductions broadcast back against ``vec``.
    shifted = vec - np.max(vec, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=axis, keepdims=True)

from time import perf_counter as default_timer
