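"""
Experiment 5: Softmax attention rank growth (negative test).

For each sequence length T, builds a causal (lower-triangular) softmax
attention matrix from random queries and keys and reports its numerical rank.
"""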

from __future__ import annotations

from typing import List, Sequence

import numpy as np

from .common import ExperimentResult, stable_softmax


def run(
    T_values: Sequence[int] = (20, 40, 80, 160, 320, 640, 1280),
    d_k: int = 8,
    seed: int = 0,
) -> List[ExperimentResult]:
    """Measure the numerical rank of causal softmax attention for several T.

    For each sequence length ``T``, draw a query matrix ``Q`` and a key matrix
    ``K`` of shape ``(T, d_k)`` from the standard normal distribution. Then
    build the ``T x T`` attention matrix ``A`` row by row:

    - row ``t`` holds ``stable_softmax`` applied to the unscaled scores
      ``Q[t] @ K[:t+1].T`` (there is no ``1/sqrt(d_k)`` factor);
    - entries beyond column ``t`` stay zero, so ``A`` is lower-triangular.

    A lower-triangular matrix is nonsingular exactly when its diagonal is
    nonzero. The value reported here is the *numerical* rank from
    ``np.linalg.matrix_rank`` with its default SVD-based tolerance. It can
    therefore fall below ``T`` if some diagonal weights are tiny.

    Args:
        T_values: Sequence lengths to evaluate, in order.
        d_k: Dimension of each query/key vector.
        seed: Seed for the single ``np.random.Generator`` shared by the whole
            sweep. Because the generator is shared, the matrices drawn for a
            given ``T`` also depend on the entries of ``T_values`` before it.

    Returns:
        One ``ExperimentResult`` per entry of ``T_values``, in the same order.

    Raises:
        ValueError: If ``d_k`` is negative or any ``T`` is smaller than 1.
            All values are checked before any matrix is built, so a bad entry
            late in the sweep does not waste the earlier computations.
    """
    # Materialise once so validation and the sweep see the same values, even
    # if the caller passed a one-shot iterable.
    T_list = list(T_values)
    if d_k < 0:
        raise ValueError(f"d_k must be non-negative, got {d_k}")
    bad_T = [T for T in T_list if T < 1]
    if bad_T:
        raise ValueError(f"every T in T_values must be >= 1, got {bad_T}")

    rng = np.random.default_rng(seed)
    results: List[ExperimentResult] = []
    for T in T_list:
        Q = rng.standard_normal((T, d_k))
        K = rng.standard_normal((T, d_k))
        # Causal mask: row t only attends to keys 0..t; columns > t stay zero.
        A = np.zeros((T, T), dtype=np.float64)
        for t in range(T):
            scores = Q[t] @ K[: t + 1].T
            weights = stable_softmax(scores)
            A[t, : t + 1] = weights
        rank = int(np.linalg.matrix_rank(A))
        details = f"T={T}, d_k={d_k}, rank(softmax attention)={rank}"
        meta = {"T": T, "d_k": d_k, "seed": seed, "rank": rank}
        results.append(ExperimentResult("Softmax attention rank growth", details, meta))
    return results


if __name__ == "__main__":
    # Script entry point: run the default sweep and print one line per T.
    for result in run():
        line = f"[{result.name}] {result.details}"
        print(line)
