from __future__ import annotations

import math
from dataclasses import asdict, dataclass
from typing import Any

import numpy as np


@dataclass(frozen=True)
class ScanConfig:
    hidden_size: int = 256
    batch_size: int = 32
    time_steps: int = 800
    burnin_steps: int = 100
    input_dim: int = 1
    input_mode: str = "gaussian"  # gaussian | poisson
    input_std: float = 0.05
    poisson_rate: float = 5.0

    gext_mode: str = "output_mse"  # output_mse | random
    gext_std: float = 1.0
    gext_ar1_rho: float = 0.0

    w_hh_mode: str = "low_rank"  # iid | diag_dominant | low_rank
    diag_strength: float = 0.0
    low_rank_rank: int = 1
    low_rank_frac: float = 0.95

    output_dim: int = 16
    target_std: float = 1.0
    target_ar1_rho: float = 0.995
    error_lp_rho: float = 0.0
    output_weight_scale: float = 1.0
    output_weight_mode: str = "align_low_rank"  # random | align_low_rank | align_random_basis | align_krylov_basis | align_svd_basis | align_eig_basis
    output_basis_rank: int = 0

    seed: int = 123

    gain_min: float = 0.1
    gain_max: float = 2.0
    gain_points: int = 12

    lam_cap: float = 0.99
    lam_grid_points: int = 401
    denom_floor: float = 1e-3
    fit_burnin_steps: int = 100
    rho_power_iters: int = 60

    oracle_per_t: bool = False
    oracle_global: bool = False
    oracle_neg_cap: float = 100.0

    alpha_rho: float = 0.995
    alpha_clip_min: float = -0.99
    alpha_clip_max: float = 0.99
    alpha_source: str = "h"  # h | g
    lambda_window: int = 50
    eps_lambda: float = 1e-8
    use_safe_cap: bool = True


def split_rngs(seed: int) -> dict[str, np.random.Generator]:
    """Build one separately seeded NumPy generator per simulation component.

    The k-th stream (1-based, in the order inputs, recurrent, output, targets,
    gext, scan) is seeded with ``int(seed) + k``.

    Args:
        seed: Base seed; converted with ``int``.

    Returns:
        Mapping from stream name to its ``np.random.Generator``.
    """
    stream_names = ("inputs", "recurrent", "output", "targets", "gext", "scan")
    root = int(seed)
    return {
        name: np.random.default_rng(root + offset)
        for offset, name in enumerate(stream_names, start=1)
    }


def generate_inputs(rng: np.random.Generator, cfg: ScanConfig) -> np.ndarray:
    """Draw the input sequence described by ``cfg``.

    Args:
        rng: Generator used for all draws.
        cfg: Supplies ``input_mode``, ``time_steps``, ``input_dim``,
            ``batch_size`` and the mode-specific ``input_std`` / ``poisson_rate``.

    Returns:
        float64 array of shape ``(time_steps, input_dim, batch_size)``. In
        ``"gaussian"`` mode the entries are N(0, input_std**2); in ``"poisson"``
        mode they are Poisson(poisson_rate) draws minus ``poisson_rate``.

    Raises:
        ValueError: If ``cfg.input_mode`` is neither ``"gaussian"`` nor ``"poisson"``.
    """
    mode = cfg.input_mode
    if mode not in ("gaussian", "poisson"):
        raise ValueError(f"Unknown input_mode={cfg.input_mode!r}")

    shape = (int(cfg.time_steps), int(cfg.input_dim), int(cfg.batch_size))
    if mode == "gaussian":
        samples = rng.normal(0.0, float(cfg.input_std), size=shape)
        return samples.astype(np.float64)

    # Poisson counts, shifted by the rate so the inputs are mean-centred.
    rate = float(cfg.poisson_rate)
    counts = rng.poisson(rate, size=shape).astype(np.float64)
    return counts - rate


def generate_targets(
    rng: np.random.Generator,
    time_steps: int,
    output_dim: int,
    batch_size: int,
    std: float,
    ar1_rho: float,
) -> np.ndarray:
    """Draw a target sequence, either white noise or an AR(1) process.

    Args:
        rng: Generator used for all draws.
        time_steps: Number of time steps (leading axis); may be 0.
        output_dim: Size of the second axis.
        batch_size: Size of the third axis.
        std: Standard deviation of the first step and, for ``ar1_rho <= 0``,
            of every step.
        ar1_rho: AR(1) coefficient. Values ``<= 0`` give independent Gaussian
            noise; values in ``(0, 1)`` give
            ``x[t] = rho * x[t-1] + N(0, std**2 * (1 - rho**2))``.

    Returns:
        float64 array of shape ``(time_steps, output_dim, batch_size)``.

    Raises:
        ValueError: If ``ar1_rho >= 1``.
    """
    n_steps = int(time_steps)
    shape = (n_steps, output_dim, batch_size)
    step_shape = (output_dim, batch_size)
    rho = float(ar1_rho)
    if rho <= 0.0:
        return rng.normal(0.0, float(std), size=shape).astype(np.float64)
    if rho >= 1.0:
        raise ValueError("target_ar1_rho must be < 1.0")

    targets = np.zeros(shape, dtype=np.float64)
    if n_steps == 0:
        # Nothing to initialise; mirrors the white-noise branch, which also
        # returns an empty array instead of failing on targets[0].
        return targets

    # Innovation scale chosen so that the per-step variance stays at std**2.
    sigma_noise = float(std) * math.sqrt(max(0.0, 1.0 - rho * rho))
    targets[0] = rng.normal(0.0, float(std), size=step_shape).astype(np.float64)
    for t in range(1, n_steps):
        innovation = rng.normal(0.0, sigma_noise, size=step_shape).astype(np.float64)
        targets[t] = rho * targets[t - 1] + innovation
    return targets


def lowpass_ema(x: np.ndarray, rho: float) -> np.ndarray:
    """Apply a first-order exponential moving average along the leading axis.

    The recursion is ``y[0] = x[0]`` and
    ``y[t] = rho * y[t-1] + (1 - rho) * x[t]`` for ``t >= 1``.

    Args:
        x: Array-like whose first axis is time.
        rho: Smoothing coefficient. Values ``<= 0`` disable filtering.

    Returns:
        float64 array with the same shape as ``x``. For ``rho <= 0`` this is
        ``np.asarray(x, dtype=np.float64)``, which may share memory with ``x``.

    Raises:
        ValueError: If ``rho >= 1``.
    """
    rho_f = float(rho)
    if rho_f <= 0.0:
        return np.asarray(x, dtype=np.float64)
    if rho_f >= 1.0:
        raise ValueError("lowpass rho must be < 1.0")
    x = np.asarray(x, dtype=np.float64)
    y = np.zeros_like(x, dtype=np.float64)
    if x.shape[0] == 0:
        # Empty time axis: there is no first sample to seed the filter with.
        return y
    y[0] = x[0]
    for t in range(1, int(x.shape[0])):
        y[t] = rho_f * y[t - 1] + (1.0 - rho_f) * x[t]
    return y


def init_base_weights(
    rng: np.random.Generator, cfg: ScanConfig
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray | None]:
    """Sample the recurrent matrix, input weights and bias for the RNN.

    Args:
        rng: Generator used for all draws (recurrent noise first, then any
            low-rank vectors, then the input weights).
        cfg: Supplies ``hidden_size``, ``input_dim``, ``w_hh_mode`` and the
            mode-specific ``diag_strength`` / ``low_rank_rank`` / ``low_rank_frac``.

    Returns:
        ``(w_hh_base, w_xh, bias, low_rank_basis)`` where ``w_hh_base`` is
        ``(n, n)``, ``w_xh`` is ``(n, input_dim)``, ``bias`` is a zero vector of
        length ``n`` and ``low_rank_basis`` is ``(rank, n)`` in ``"low_rank"``
        mode and ``None`` otherwise.

    Raises:
        ValueError: If ``low_rank_frac`` is outside ``[0, 1]`` or ``w_hh_mode``
            is not one of ``iid``, ``diag_dominant``, ``low_rank``.
    """
    hidden = int(cfg.hidden_size)
    n_in = int(cfg.input_dim)
    mode = cfg.w_hh_mode

    # The iid Gaussian part is always drawn first, whatever the mode.
    iid_part = rng.normal(0.0, 1.0, size=(hidden, hidden)).astype(np.float64) / math.sqrt(hidden)
    basis: np.ndarray | None = None

    if mode == "iid":
        recurrent = iid_part
    elif mode == "diag_dominant":
        recurrent = iid_part + float(cfg.diag_strength) * np.eye(hidden, dtype=np.float64)
    elif mode == "low_rank":
        n_components = max(1, int(cfg.low_rank_rank))
        mix = float(cfg.low_rank_frac)
        if not (0.0 <= mix <= 1.0):
            raise ValueError("low_rank_frac must be in [0, 1]")
        vectors = [
            rng.normal(0.0, 1.0, size=(hidden,)).astype(np.float64)
            for _ in range(n_components)
        ]
        # Average of the rank-one terms v v^T / n.
        structured = np.zeros((hidden, hidden), dtype=np.float64)
        for vec in vectors:
            structured += np.outer(vec, vec) / float(hidden)
        structured /= float(n_components)
        basis = np.stack(vectors, axis=0)
        recurrent = math.sqrt(1.0 - mix) * iid_part + math.sqrt(mix) * structured
    else:
        raise ValueError(f"Unknown w_hh_mode={cfg.w_hh_mode!r}")

    input_scale = 0.5 / math.sqrt(max(1, n_in))
    w_xh = rng.normal(0.0, 1.0, size=(hidden, n_in)).astype(np.float64) * input_scale
    bias = np.zeros((hidden,), dtype=np.float64)
    return recurrent, w_xh, bias, basis


def build_krylov_basis(w: np.ndarray, v0: np.ndarray, rank: int, eps: float = 1e-12) -> np.ndarray:
    """Orthonormalise the sequence ``v0, w.T v0, w.T (w.T v0), ...``.

    Each new vector is obtained by applying ``w.T`` to the previous
    (orthogonalised but not normalised) vector and then projecting out every
    basis vector collected so far. Construction stops early when the remaining
    vector has a non-finite norm or a norm below ``eps``.

    Args:
        w: Square matrix; only ``w.T`` is applied.
        v0: Start vector with ``w.shape[0]`` entries.
        rank: Maximum number of basis vectors (values below 1 are treated as 1).
        eps: Norm threshold below which construction stops.

    Returns:
        Array of shape ``(n, k)`` with ``1 <= k <= rank`` whose columns are the
        basis vectors. If not even the first vector is usable, the single
        column is the first unit vector.
    """
    mat = np.asarray(w, dtype=np.float64)
    dim = int(mat.shape[0])
    max_vectors = max(1, int(rank))
    candidate = np.array(v0, dtype=np.float64).reshape(dim)

    ortho: list[np.ndarray] = []
    for step in range(max_vectors):
        if step:
            candidate = mat.T @ candidate
        # Sequentially remove the components along the existing basis vectors.
        for direction in ortho:
            candidate = candidate - float(direction @ candidate) * direction
        length = float(np.linalg.norm(candidate))
        if length < eps or not np.isfinite(length):
            break
        ortho.append(candidate / length)

    if not ortho:
        fallback = np.zeros((dim,), dtype=np.float64)
        fallback[0] = 1.0
        ortho = [fallback]
    return np.stack(ortho, axis=1)


def build_neumann_sum_vector(w: np.ndarray, v0: np.ndarray, depth: int) -> np.ndarray:
    """Return the truncated series ``sum_{k=0}^{depth-1} (w.T)^k v0``.

    Args:
        w: Square matrix; only ``w.T`` is applied.
        v0: Start vector with ``w.shape[0]`` entries.
        depth: Number of terms (values below 1 are treated as 1).

    Returns:
        float64 vector of length ``w.shape[0]``.
    """
    mat = np.asarray(w, dtype=np.float64)
    dim = int(mat.shape[0])
    n_terms = max(1, int(depth))

    term = np.array(v0, dtype=np.float64).reshape(dim)
    total = np.zeros_like(term)
    for _ in range(n_terms):
        total += term
        term = mat.T @ term
    return total


def _init_output_weights_krylov(
    rng: np.random.Generator,
    w_hh: np.ndarray,
    output_dim: int,
    hidden_size: int,
    basis_rank: int,
    neumann_depth: int,
    v0_mode: str,
    power_iters: int,
    output_weight_scale: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Build readout weights whose rows lie in a Krylov subspace of ``w_hh.T``.

    A start vector is chosen according to ``v0_mode``, summed through
    ``build_neumann_sum_vector`` and expanded by ``build_krylov_basis``; the
    readout is a random Gaussian mixture of the resulting basis vectors.

    Args:
        rng: Generator used for the random start vector (``"random"`` mode)
            and for the mixing coefficients.
        w_hh: Recurrent matrix.
        output_dim: Number of readout rows.
        hidden_size: Number of hidden units.
        basis_rank: Requested basis size, clamped to ``[1, hidden_size]``.
        neumann_depth: Number of terms passed to ``build_neumann_sum_vector``.
        v0_mode: ``"random"``, ``"ones"`` or ``"power"`` (power iteration on
            ``w_hh.T @ w_hh`` starting from the normalised all-ones vector).
        power_iters: Number of power iterations (at least one is run).
        output_weight_scale: Overall scale of the mixing coefficients.

    Returns:
        ``(w_out, b_out)`` with ``w_out`` of shape ``(output_dim, hidden_size)``
        and ``b_out`` a zero vector of length ``output_dim``.

    Raises:
        ValueError: If ``v0_mode`` is not recognised.
    """
    n_hidden = int(hidden_size)
    n_out = int(output_dim)
    n_basis = max(1, min(int(basis_rank), n_hidden))

    start_kind = str(v0_mode)
    if start_kind == "random":
        start = rng.normal(0.0, 1.0, size=(n_hidden,)).astype(np.float64)
    elif start_kind == "ones":
        start = np.ones((n_hidden,), dtype=np.float64)
    elif start_kind == "power":
        start = np.ones((n_hidden,), dtype=np.float64)
        start /= float(np.linalg.norm(start) + 1e-12)
        n_iters = max(1, int(power_iters))
        for _ in range(n_iters):
            start = w_hh.T @ (w_hh @ start)
            start /= float(np.linalg.norm(start) + 1e-12)
    else:
        raise ValueError(f"Unknown v0_mode={v0_mode!r}")

    seed_vec = build_neumann_sum_vector(w_hh, start, int(neumann_depth))
    columns = build_krylov_basis(w_hh, seed_vec, n_basis)  # (N, r_eff)
    rows = columns.T.copy()  # (r_eff, N)
    rows = rows / (np.linalg.norm(rows, axis=1, keepdims=True) + 1e-12)

    n_rows = int(rows.shape[0])
    coeffs = rng.normal(0.0, 1.0, size=(n_out, n_rows)).astype(np.float64)
    coeffs *= float(output_weight_scale) / math.sqrt(max(1, n_rows))
    return coeffs @ rows, np.zeros((n_out,), dtype=np.float64)


def _init_output_weights_svd(
    rng: np.random.Generator,
    w_hh: np.ndarray,
    output_dim: int,
    hidden_size: int,
    basis_rank: int,
    side: str,
    use_abs: bool,
    output_weight_scale: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Build readout weights spanned by leading singular vectors of ``w_hh``.

    Args:
        rng: Generator used for the mixing coefficients.
        w_hh: Recurrent matrix.
        output_dim: Number of readout rows.
        hidden_size: Number of hidden units.
        basis_rank: Number of singular vectors, clamped to ``[1, hidden_size]``.
        side: ``"right"``/``"v"``/``"vt"`` for right singular vectors,
            ``"left"``/``"u"`` for left ones (case-insensitive).
        use_abs: If true, decompose ``abs(w_hh)`` instead of ``w_hh``.
        output_weight_scale: Overall scale of the mixing coefficients.

    Returns:
        ``(w_out, b_out)`` where ``w_out`` is a random Gaussian mixture of the
        selected singular vectors and ``b_out`` is a zero vector.

    Raises:
        ValueError: If ``side`` is not recognised.
    """
    n_hidden = int(hidden_size)
    n_out = int(output_dim)
    keep = max(1, min(int(basis_rank), n_hidden))

    mat = np.asarray(w_hh, dtype=np.float64)
    mat = np.abs(mat) if bool(use_abs) else mat
    left, _sing, right_t = np.linalg.svd(mat, full_matrices=False)

    side_key = str(side).lower()
    if side_key in ("right", "v", "vt"):
        rows = right_t[:keep].copy()
    elif side_key in ("left", "u"):
        rows = left[:, :keep].T.copy()
    else:
        raise ValueError(f"Unknown side={side!r} (expected 'left' or 'right').")

    rows = rows / (np.linalg.norm(rows, axis=1, keepdims=True) + 1e-12)
    n_rows = int(rows.shape[0])
    coeffs = rng.normal(0.0, 1.0, size=(n_out, n_rows)).astype(np.float64)
    coeffs *= float(output_weight_scale) / math.sqrt(max(1, n_rows))
    return coeffs @ rows, np.zeros((n_out,), dtype=np.float64)


def _init_output_weights_eig(
    rng: np.random.Generator,
    w_hh: np.ndarray,
    output_dim: int,
    hidden_size: int,
    basis_rank: int,
    select: str,
    use_abs: bool,
    output_weight_scale: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Build readout weights spanned by selected eigenvectors of ``w_hh.T``.

    Eigenvectors are ranked by ``select``; the real and imaginary parts of
    each (when non-negligible) are collected as real directions,
    orthonormalised with a QR factorisation, and the first ``basis_rank``
    columns are mixed with random Gaussian coefficients.

    Args:
        rng: Generator used for the fallback direction and the mixing
            coefficients.
        w_hh: Recurrent matrix.
        output_dim: Number of readout rows.
        hidden_size: Number of hidden units.
        basis_rank: Requested basis size, clamped to ``[1, hidden_size]``.
        select: ``"closest_one"``, ``"largest_abs"`` or ``"largest_real"``
            (case-insensitive) ordering of the eigenvalues.
        use_abs: If true, decompose ``abs(w_hh)`` instead of ``w_hh``.
        output_weight_scale: Overall scale of the mixing coefficients.

    Returns:
        ``(w_out, b_out)`` with ``b_out`` a zero vector of length ``output_dim``.

    Raises:
        ValueError: If ``select`` is not recognised.
    """
    n_hidden = int(hidden_size)
    n_out = int(output_dim)
    keep = max(1, min(int(basis_rank), n_hidden))

    mat = np.asarray(w_hh, dtype=np.float64)
    mat = np.abs(mat) if bool(use_abs) else mat
    eigvals, eigvecs = np.linalg.eig(mat.T)

    rule = str(select).lower()
    if rule == "closest_one":
        score = np.abs(1.0 - eigvals)
    elif rule == "largest_abs":
        score = -np.abs(eigvals)
    elif rule == "largest_real":
        score = -np.real(eigvals)
    else:
        raise ValueError(f"Unknown select={select!r}")
    ranking = np.argsort(score)

    tol = 1e-10
    # Collect more directions than needed before orthonormalising.
    enough = max(2 * keep, keep + 4)
    directions: list[np.ndarray] = []
    for j in ranking.tolist():
        vec = eigvecs[:, int(j)]
        for part in (np.real(vec), np.imag(vec)):
            if float(np.linalg.norm(part)) > tol:
                directions.append(part.astype(np.float64, copy=False))
        if len(directions) >= enough:
            break

    if not directions:
        directions = [rng.normal(0.0, 1.0, size=(n_hidden,)).astype(np.float64)]

    stacked = np.stack(directions, axis=1)
    ortho, _upper = np.linalg.qr(stacked)
    rows = ortho[:, :keep].T.copy()
    rows = rows / (np.linalg.norm(rows, axis=1, keepdims=True) + 1e-12)

    n_rows = int(rows.shape[0])
    coeffs = rng.normal(0.0, 1.0, size=(n_out, n_rows)).astype(np.float64)
    coeffs *= float(output_weight_scale) / math.sqrt(max(1, n_rows))
    return coeffs @ rows, np.zeros((n_out,), dtype=np.float64)


def init_output_weights(
    rng: np.random.Generator,
    cfg: ScanConfig,
    w_hh: np.ndarray,
    low_rank_basis: np.ndarray | None,
    *,
    krylov_depth: int = 8,
    krylov_v0: str = "random",
    krylov_power_iters: int = 60,
    svd_side: str = "right",
    svd_abs: bool = False,
    eig_select: str = "closest_one",
    eig_abs: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Create the readout weights selected by ``cfg.output_weight_mode``.

    Args:
        rng: Generator used for all draws.
        cfg: Supplies ``output_dim``, ``hidden_size``, ``output_weight_mode``,
            ``output_weight_scale`` and ``output_basis_rank``.
        w_hh: Recurrent matrix; used by the krylov / svd / eig modes.
        low_rank_basis: ``(rank, hidden_size)`` basis required by
            ``"align_low_rank"``; ignored by the other modes.
        krylov_depth: Forwarded to ``_init_output_weights_krylov``.
        krylov_v0: Forwarded to ``_init_output_weights_krylov``.
        krylov_power_iters: Forwarded to ``_init_output_weights_krylov``.
        svd_side: Forwarded to ``_init_output_weights_svd``.
        svd_abs: Forwarded to ``_init_output_weights_svd``.
        eig_select: Forwarded to ``_init_output_weights_eig``.
        eig_abs: Forwarded to ``_init_output_weights_eig``.

    Returns:
        ``(w_out, b_out)`` with ``w_out`` of shape ``(output_dim, hidden_size)``
        and ``b_out`` a zero vector of length ``output_dim``.

    Raises:
        ValueError: If the mode is unknown, or ``"align_low_rank"`` is
            requested without a ``low_rank_basis``.
    """
    n_out = int(cfg.output_dim)
    n_hidden = int(cfg.hidden_size)
    mode = str(cfg.output_weight_mode)

    def requested_rank() -> int:
        # Non-positive configured ranks fall back to min(output_dim, hidden_size).
        configured = int(cfg.output_basis_rank)
        return configured if configured > 0 else int(min(n_out, n_hidden))

    def mix_rows(rows: np.ndarray) -> np.ndarray:
        # Random Gaussian combination of the basis rows, scaled by 1 / sqrt(#rows).
        n_rows = int(rows.shape[0])
        coeffs = rng.normal(0.0, 1.0, size=(n_out, n_rows)).astype(np.float64)
        coeffs *= float(cfg.output_weight_scale) / math.sqrt(max(1, n_rows))
        return coeffs @ rows

    if mode == "random":
        dense = rng.normal(0.0, 1.0, size=(n_out, n_hidden)).astype(np.float64)
        w_out = dense * (float(cfg.output_weight_scale) / math.sqrt(n_hidden))
    elif mode == "align_low_rank":
        if low_rank_basis is None:
            raise ValueError("output_weight_mode=align_low_rank requires w_hh_mode=low_rank")
        w_out = mix_rows(np.asarray(low_rank_basis, dtype=np.float64) / math.sqrt(n_hidden))
    elif mode == "align_random_basis":
        n_basis = max(1, min(requested_rank(), n_hidden))
        rows = rng.normal(0.0, 1.0, size=(n_basis, n_hidden)).astype(np.float64) / math.sqrt(n_hidden)
        rows /= np.linalg.norm(rows, axis=1, keepdims=True) + 1e-12
        w_out = mix_rows(rows)
    elif mode == "align_krylov_basis":
        w_out, _unused_bias = _init_output_weights_krylov(
            rng,
            w_hh,
            n_out,
            n_hidden,
            requested_rank(),
            krylov_depth,
            krylov_v0,
            krylov_power_iters,
            float(cfg.output_weight_scale),
        )
    elif mode == "align_svd_basis":
        w_out, _unused_bias = _init_output_weights_svd(
            rng,
            w_hh,
            n_out,
            n_hidden,
            requested_rank(),
            svd_side,
            svd_abs,
            float(cfg.output_weight_scale),
        )
    elif mode == "align_eig_basis":
        w_out, _unused_bias = _init_output_weights_eig(
            rng,
            w_hh,
            n_out,
            n_hidden,
            requested_rank(),
            eig_select,
            eig_abs,
            float(cfg.output_weight_scale),
        )
    else:
        raise ValueError(f"Unknown output_weight_mode={mode!r}")

    return w_out, np.zeros((n_out,), dtype=np.float64)


def simulate_tanh_rnn(
    w_hh: np.ndarray,
    w_xh: np.ndarray,
    bias: np.ndarray,
    inputs: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run a tanh RNN forward from a zero initial state.

    The update is ``x_t = w_hh @ h_{t-1} + w_xh @ inputs[t] + bias`` and
    ``h_t = tanh(x_t)``.

    Args:
        w_hh: Recurrent weights, ``(hidden, hidden)``.
        w_xh: Input weights, ``(hidden, input_dim)``.
        bias: Bias vector of length ``hidden``; broadcast over the batch.
        inputs: Input sequence of shape ``(time_steps, input_dim, batch_size)``.

    Returns:
        ``(pre, states, u_seq)``, each of shape
        ``(time_steps, hidden, batch_size)``: the pre-activations ``x_t``, the
        hidden states ``h_t`` and the tanh derivatives ``1 - h_t**2``.
    """
    n_steps, _n_inputs, n_batch = inputs.shape
    n_hidden = int(w_hh.shape[0])
    trace_shape = (n_steps, n_hidden, n_batch)

    pre = np.zeros(trace_shape, dtype=np.float64)
    states = np.zeros(trace_shape, dtype=np.float64)
    u_seq = np.zeros(trace_shape, dtype=np.float64)

    hidden = np.zeros((n_hidden, n_batch), dtype=np.float64)
    for t in range(n_steps):
        drive = w_hh @ hidden + w_xh @ inputs[t] + bias[:, None]
        hidden = np.tanh(drive)
        pre[t] = drive
        states[t] = hidden
        u_seq[t] = 1.0 - hidden * hidden
    return pre, states, u_seq


def estimate_lle_benettin(
    w_hh: np.ndarray,
    u_seq: np.ndarray,
    burnin_steps: int,
    rng: np.random.Generator,
    eps: float = 1e-12,
) -> float:
    """Estimate a largest-Lyapunov-style growth rate by repeated renormalisation.

    A random unit vector is propagated through ``v <- u_t * (w_hh @ v)``, where
    ``u_t`` is the batch mean of ``u_seq[t]``, and renormalised after every
    step. The result is the mean log growth factor over the steps at or after
    ``burnin_steps``.

    Args:
        w_hh: Recurrent matrix, ``(hidden, hidden)``.
        u_seq: Array of shape ``(time_steps, hidden, batch)``.
        burnin_steps: Number of leading steps excluded from the average.
        rng: Generator used for the initial direction.
        eps: Added to every norm before dividing or taking the log.

    Returns:
        The mean log growth, or ``nan`` when no step is past the burn-in.
    """
    n_steps, n_hidden, _n_batch = u_seq.shape
    first_counted = int(burnin_steps)

    direction = rng.normal(0.0, 1.0, size=(n_hidden,)).astype(np.float64)
    direction /= float(np.linalg.norm(direction) + eps)

    log_growth: list[float] = []
    for t in range(n_steps):
        gain = u_seq[t].mean(axis=1)
        direction = gain * (w_hh @ direction)
        growth = float(np.linalg.norm(direction) + eps)
        direction /= growth
        if t >= first_counted:
            log_growth.append(math.log(growth))

    return float(np.mean(log_growth)) if log_growth else float("nan")


def estimate_lag1_rho(x: np.ndarray, burnin: int, eps: float = 1e-12) -> float:
    """Estimate the lag-1 regression coefficient of ``x`` along its first axis.

    Computes ``mean(x[t] * x[t+1]) / (mean(x[t] ** 2) + eps)`` over all
    consecutive pairs with ``t >= burnin`` and over all trailing elements. No
    mean is subtracted.

    Args:
        x: Array-like whose first axis is time.
        burnin: Number of leading steps to skip (negative values count as 0).
        eps: Added to the denominator.

    Returns:
        The ratio, or ``nan`` when fewer than two steps remain after burn-in.
    """
    series = np.asarray(x, dtype=np.float64)
    first = max(0, int(burnin))
    n_pairs = series.shape[0] - first - 1
    if n_pairs <= 0:
        return float("nan")

    earlier = series[first:-1].reshape(n_pairs, -1)
    later = series[first + 1 :].reshape(n_pairs, -1)
    cross = float(np.mean(earlier * later))
    power = float(np.mean(earlier * earlier)) + eps
    return cross / power


def safe_denom(denom: np.ndarray, denom_floor: float) -> np.ndarray:
    """Push denominators that are too close to zero out to ``+/- denom_floor``.

    Entries with ``abs(denom) < denom_floor`` are replaced by ``denom_floor``
    carrying the sign of the original entry; exact zeros become
    ``+denom_floor``. All other entries, including NaN, are left unchanged.

    The sign is taken from ``denom`` itself rather than from
    ``sign(denom + eps)``: the latter evaluates to 0 when ``denom == -eps``
    (yielding a zero denominator) and flips the sign of tiny negative values.

    Args:
        denom: Array of denominators.
        denom_floor: Minimum allowed magnitude.

    Returns:
        ``denom`` itself when no entry is below the floor, otherwise a new
        array with the small entries replaced.
    """
    floor = float(denom_floor)
    too_small = np.abs(denom) < floor
    if not np.any(too_small):
        return denom
    signed_floor = np.where(denom < 0.0, -floor, floor)
    return np.where(too_small, signed_floor, denom)


def compute_lambda_used_seq(
    states: np.ndarray,
    u_seq: np.ndarray,
    g_seq: np.ndarray,
    cfg: ScanConfig,
) -> tuple[np.ndarray, dict[str, float]]:
    """Run the online per-unit ``alpha`` / ``lambda`` estimators over a sequence.

    For every time step the function first records the lambda currently in
    force and only afterwards refreshes it, so ``lambda_used_seq[t]`` depends
    on data from steps before ``t`` only (it is zero for the first two steps).

    Args:
        states: Hidden states; its shape is read as
            ``(time_steps, hidden_size, batch_size)``.
        u_seq: Per-step factors indexed like ``states`` and combined
            elementwise with ``g_seq``.
        g_seq: Per-step signals indexed like ``states``.
        cfg: Supplies ``fit_burnin_steps``, ``lambda_window``, ``alpha_rho``,
            ``eps_lambda``, ``alpha_source``, ``alpha_clip_min``,
            ``alpha_clip_max``, ``use_safe_cap``, ``denom_floor`` and ``lam_cap``.

    Returns:
        ``(lambda_used_seq, debug)`` where ``lambda_used_seq`` has shape
        ``(time_steps, hidden_size, 1)`` and ``debug`` holds the means of
        ``alpha_hat`` and of the recorded lambda over all units and all steps
        ``t >= fit_burnin_steps`` (keys ``"alpha_hat_mean"`` and
        ``"lambda_used_mean"``).
    """
    time_steps, hidden_size, batch_size = int(states.shape[0]), int(states.shape[1]), int(states.shape[2])

    fit_start = max(0, int(cfg.fit_burnin_steps))
    # All running statistics are kept per hidden unit as column vectors.
    shape = (hidden_size, 1)

    alpha_num = np.zeros(shape, dtype=np.float64)
    alpha_den = np.zeros(shape, dtype=np.float64)
    alpha_hat = np.zeros(shape, dtype=np.float64)
    S_A2 = np.zeros(shape, dtype=np.float64)
    S_AB = np.zeros(shape, dtype=np.float64)
    lambda_vals = np.zeros(shape, dtype=np.float64)
    lambda_used_seq = np.zeros((time_steps, hidden_size, 1), dtype=np.float64)

    # Exponential-average decay equivalent to a window of `lambda_window` steps.
    lambda_window = max(2, int(cfg.lambda_window))
    lambda_rho = (lambda_window - 1.0) / float(lambda_window)
    alpha_rho = float(cfg.alpha_rho)
    eps_lambda = float(cfg.eps_lambda)

    prev_g: np.ndarray | None = None
    prev_u: np.ndarray | None = None
    h_prev = np.zeros((hidden_size, batch_size), dtype=np.float64)

    # Accumulators for the debug means (only steps t >= fit_start contribute).
    alpha_sum = 0.0
    alpha_count = 0
    lam_sum = 0.0
    lam_count = 0

    for t in range(time_steps):
        h_t = states[t]
        u_t = u_seq[t]
        g_t = g_seq[t]

        # alpha_hat: clipped ratio of running batch-mean lag-1 products, taken
        # from g when alpha_source == "g" and from the hidden states otherwise.
        if str(cfg.alpha_source) == "g":
            if prev_g is None:
                # No previous g at the first step: alpha is reset to zero.
                alpha_hat = np.zeros_like(alpha_hat)
            else:
                gtp_mean = np.mean(g_t * prev_g, axis=1, keepdims=True)
                gpp_mean = np.mean(prev_g**2, axis=1, keepdims=True)
                alpha_num = alpha_rho * alpha_num + (1.0 - alpha_rho) * gtp_mean
                alpha_den = alpha_rho * alpha_den + (1.0 - alpha_rho) * gpp_mean
                raw_alpha = alpha_num / (alpha_den + 1e-12)
                alpha_hat = np.clip(raw_alpha, float(cfg.alpha_clip_min), float(cfg.alpha_clip_max))
        else:
            # h_prev is all zeros at t == 0, so the first update adds nothing.
            hthp_mean = np.mean(h_t * h_prev, axis=1, keepdims=True)
            hphp_mean = np.mean(h_prev**2, axis=1, keepdims=True)
            alpha_num = alpha_rho * alpha_num + (1.0 - alpha_rho) * hthp_mean
            alpha_den = alpha_rho * alpha_den + (1.0 - alpha_rho) * hphp_mean
            raw_alpha = alpha_num / (alpha_den + 1e-12)
            alpha_hat = np.clip(raw_alpha, float(cfg.alpha_clip_min), float(cfg.alpha_clip_max))

        # Record the lambda in force *before* it is refreshed with step-t data.
        lambda_used_seq[t] = lambda_vals

        if t >= fit_start:
            alpha_sum += float(np.sum(alpha_hat))
            alpha_count += int(alpha_hat.size)
            lam_sum += float(np.sum(lambda_vals))
            lam_count += int(lambda_vals.size)

        if prev_g is not None and prev_u is not None:
            # Running ratio S_AB / S_A2 of batch-mean products, i.e. a
            # least-squares style slope of B_s on A_s per hidden unit.
            A_s = prev_u * u_t * (alpha_hat * prev_g - g_t)
            B_s = alpha_hat * prev_u * prev_g - u_t * g_t
            A2_mean = np.mean(A_s**2, axis=1, keepdims=True)
            AB_mean = np.mean(A_s * B_s, axis=1, keepdims=True)
            S_A2 = lambda_rho * S_A2 + (1.0 - lambda_rho) * A2_mean
            S_AB = lambda_rho * S_AB + (1.0 - lambda_rho) * AB_mean

            lambda_unproj = S_AB / (S_A2 + eps_lambda)
            if bool(cfg.use_safe_cap):
                # With this cap |lambda * u_t| <= 1 - denom_floor for the current
                # u_t (for denom_floor < 1); lam_cap still applies on top.
                u_abs_max = np.max(np.abs(u_t), axis=1, keepdims=True) + 1e-12
                safe_cap = (1.0 - float(cfg.denom_floor)) / u_abs_max
                cap = np.minimum(safe_cap, float(cfg.lam_cap))
                lambda_vals = np.clip(lambda_unproj, -cap, cap)
            else:
                lambda_vals = np.clip(lambda_unproj, -float(cfg.lam_cap), float(cfg.lam_cap))

        prev_g = g_t
        prev_u = u_t
        h_prev = h_t

    debug = {
        "alpha_hat_mean": float(alpha_sum / max(alpha_count, 1)),
        "lambda_used_mean": float(lam_sum / max(lam_count, 1)),
    }
    return lambda_used_seq, debug


def compute_delta_true(w_hh: np.ndarray, u_seq: np.ndarray, g_seq: np.ndarray) -> np.ndarray:
    """Evaluate the backward recursion over the full sequence.

    Starting from ``delta_T = 0``, computes for ``t = T-1 .. 0``
    ``delta_t = u_seq[t] * (g_seq[t] + w_hh.T @ delta_{t+1})``.

    Args:
        w_hh: Recurrent matrix, ``(hidden, hidden)``.
        u_seq: Array of shape ``(time_steps, hidden, batch)``.
        g_seq: Array indexed like ``u_seq``.

    Returns:
        float64 array of shape ``(time_steps, hidden, batch)`` holding ``delta_t``.
    """
    n_steps = int(u_seq.shape[0])
    n_hidden = int(u_seq.shape[1])
    n_batch = int(u_seq.shape[2])

    deltas = np.zeros((n_steps, n_hidden, n_batch), dtype=np.float64)
    carry = np.zeros((n_hidden, n_batch), dtype=np.float64)
    w_transposed = w_hh.T
    for t in range(n_steps - 1, -1, -1):
        carry = u_seq[t] * (g_seq[t] + (w_transposed @ carry))
        deltas[t] = carry
    return deltas


def compute_delta_true_tbptt(
    w_hh: np.ndarray,
    u_seq: np.ndarray,
    g_seq: np.ndarray,
    tbptt_k: int,
) -> np.ndarray:
    """Evaluate the backward recursion with truncation every ``tbptt_k`` steps.

    Same recursion as ``delta_t = u_seq[t] * (g_seq[t] + w_hh.T @ delta_{t+1})``,
    but the carried ``delta_{t+1}`` is reset to zero after every ``tbptt_k``
    consecutive steps, counted backwards from the end of the sequence.

    Args:
        w_hh: Recurrent matrix, ``(hidden, hidden)``.
        u_seq: Array of shape ``(time_steps, hidden, batch)``.
        g_seq: Array indexed like ``u_seq``.
        tbptt_k: Truncation length (values below 1 are treated as 1).

    Returns:
        float64 array of shape ``(time_steps, hidden, batch)``.
    """
    n_steps = int(u_seq.shape[0])
    n_hidden = int(u_seq.shape[1])
    n_batch = int(u_seq.shape[2])
    window = max(1, int(tbptt_k))

    deltas = np.zeros((n_steps, n_hidden, n_batch), dtype=np.float64)
    carry = np.zeros((n_hidden, n_batch), dtype=np.float64)
    w_transposed = w_hh.T
    used = 0  # steps propagated since the last reset
    for t in range(n_steps - 1, -1, -1):
        if used >= window:
            carry = np.zeros_like(carry)
            used = 0
        carry = u_seq[t] * (g_seq[t] + (w_transposed @ carry))
        deltas[t] = carry
        used += 1
    return deltas


def cosine_mean_time(
    delta_true: np.ndarray,
    delta_hat: np.ndarray,
    burnin_steps: int,
    eps: float = 1e-12,
) -> float:
    """Average the absolute per-step cosine similarity between two sequences.

    For every step ``t >= burnin_steps`` both arrays are flattened over all
    trailing axes and ``|<a, b>| / ((||a|| + eps) * (||b|| + eps))`` is
    computed; the result is the mean over those steps.

    Args:
        delta_true: Reference sequence, time on the first axis.
        delta_hat: Approximation with the same layout.
        burnin_steps: Number of leading steps to skip (negative counts as 0).
        eps: Added to each norm.

    Returns:
        Mean absolute cosine, or ``nan`` when no step is left after burn-in.
    """
    first = max(0, int(burnin_steps))
    if int(delta_true.shape[0]) <= first:
        return float("nan")

    true_tail = delta_true[first:]
    hat_tail = delta_hat[first:]
    true_rows = true_tail.reshape(true_tail.shape[0], -1)
    hat_rows = hat_tail.reshape(hat_tail.shape[0], -1)

    true_norm = np.linalg.norm(true_rows, axis=1) + eps
    hat_norm = np.linalg.norm(hat_rows, axis=1) + eps
    overlap = np.sum(true_rows * hat_rows, axis=1)
    return float(np.mean(np.abs(overlap) / (true_norm * hat_norm)))


def cosine_mean_windows(
    delta_true: np.ndarray,
    u_seq: np.ndarray,
    g_seq: np.ndarray,
    lambda_used_seq: np.ndarray,
    denom_floor: float,
    burnin_steps: int,
    start_eval: list[int],
) -> np.ndarray:
    """Average the per-step cosine between ``delta_true`` and a closed-form estimate.

    At each step the estimate is
    ``delta_hat = u * g / safe_denom(1 - lambda * u, denom_floor)`` and its
    absolute cosine with ``delta_true[t]`` (flattened over units and batch) is
    taken. Steps before ``burnin_steps`` contribute a cosine of 0. For every
    entry ``st`` of ``start_eval`` the returned value is the sum of the
    per-step cosines from ``st`` to the end divided by ``time_steps - st``.

    Args:
        delta_true: Reference array, ``(time_steps, hidden, batch)``.
        u_seq: Array indexed like ``delta_true``.
        g_seq: Array indexed like ``delta_true``.
        lambda_used_seq: Per-step lambda, ``(time_steps, hidden, 1)``.
        denom_floor: Floor passed to ``safe_denom``.
        burnin_steps: Steps before this index are scored as 0.
        start_eval: Window start indices.

    Returns:
        float64 array with one mean per entry of ``start_eval``.
    """
    n_steps = int(delta_true.shape[0])
    n_hidden = int(delta_true.shape[1])
    n_batch = int(delta_true.shape[2])
    flat_len = n_hidden * n_batch
    first_scored = max(0, int(burnin_steps))

    true_rows = delta_true.reshape(n_steps, flat_len)
    true_norms = np.linalg.norm(true_rows, axis=1) + 1e-12

    per_step = np.zeros((n_steps,), dtype=np.float64)
    for t in range(n_steps):
        u_t = u_seq[t]
        # lambda is (N, 1) and broadcasts over the batch axis of u_t (N, B).
        guarded = safe_denom(1.0 - lambda_used_seq[t] * u_t, denom_floor)
        estimate = ((u_t * g_seq[t]) / guarded).reshape(flat_len)

        est_norm = float(np.linalg.norm(estimate) + 1e-12)
        overlap = float(np.dot(true_rows[t], estimate))
        score = abs(overlap) / (true_norms[t] * est_norm)
        per_step[t] = 0.0 if t < first_scored else score

    # Prefix sums so that every window total is a single subtraction.
    prefix = np.zeros((n_steps + 1,), dtype=np.float64)
    prefix[1:] = np.cumsum(per_step, axis=0)

    window_means = np.zeros((len(start_eval),), dtype=np.float64)
    for slot, begin in enumerate(start_eval):
        span = float(n_steps - begin)
        window_means[slot] = (prefix[n_steps] - prefix[begin]) / max(1e-12, span)
    return window_means


def config_asdict(cfg: ScanConfig) -> dict[str, Any]:
    """Convert a configuration dataclass instance into a plain dictionary.

    Args:
        cfg: Dataclass instance to convert.

    Returns:
        The result of ``dataclasses.asdict(cfg)``: field names mapped to values.
    """
    field_values: dict[str, Any] = asdict(cfg)
    return field_values
