"""
Experiment 5b: Softmax attention rank growth with plotting (mean ± std).
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Sequence
import argparse

import numpy as np

from .common import ExperimentResult, stable_softmax


def run(
    T_values: Sequence[int] = (20, 40, 80, 160, 320, 640, 1280),
    d_k: int = 8,
    seed: int = 0,
    seeds: Sequence[int] | None = None,
) -> List[ExperimentResult]:
    """Measure the numerical rank of causal softmax attention for several T.

    For every seed, one RNG stream is created. The seeds are ``seeds`` if
    given; otherwise the single ``seed``, which is ignored when ``seeds`` is
    provided. For each ``T`` in ``T_values``, in order, fresh standard-normal
    ``Q`` and ``K`` of shape ``(T, d_k)`` are drawn from that stream, ``Q``
    first. Because the stream is shared, the draws for a given T depend on
    the T values that precede it.

    Row ``t`` of the ``T x T`` attention matrix is ``stable_softmax`` of the
    scores ``Q[t] @ K[:t + 1].T``. Entries above the diagonal stay 0, so the
    matrix is causal and lower triangular. Its rank is taken with
    ``np.linalg.matrix_rank`` using NumPy's default tolerance.

    Returns:
        One ``ExperimentResult`` per (seed, T) pair, in seed-major order.
        Its meta dict has keys ``T``, ``d_k``, ``seed``, ``rank`` and
        ``variant`` (always ``"b"``).
    """
    seed_list = [seed] if seeds is None else list(seeds)
    collected: List[ExperimentResult] = []

    for current_seed in seed_list:
        gen = np.random.default_rng(current_seed)
        for T in T_values:
            # Draw order matters for reproducibility: queries, then keys.
            queries = gen.standard_normal((T, d_k))
            keys = gen.standard_normal((T, d_k))

            attn = np.zeros((T, T), dtype=np.float64)
            for row, q in enumerate(queries):
                # Causal mask: position `row` only attends to keys 0..row.
                visible_keys = keys[: row + 1]
                attn[row, : row + 1] = stable_softmax(q @ visible_keys.T)

            rank = int(np.linalg.matrix_rank(attn))
            collected.append(
                ExperimentResult(
                    "Softmax attention rank growth",
                    f"T={T}, d_k={d_k}, rank(softmax attention)={rank}",
                    {
                        "T": T,
                        "d_k": d_k,
                        "seed": current_seed,
                        "rank": rank,
                        "variant": "b",
                    },
                )
            )
    return collected


def _aggregate_rank_stats(
    results: List[ExperimentResult],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Group observed ranks by sequence length T across seeds.

    Results are skipped when their ``meta`` is falsy or lacks either the
    ``"T"`` or the ``"rank"`` key.

    Returns:
        ``(xs, mean_rank, std_rank)`` as float64 arrays sorted by T.
        ``std_rank`` is the sample std (``ddof=1``), or 0.0 for a T observed
        only once. All three arrays are empty when nothing usable was found.
    """
    grouped: dict[int, list[int]] = {}
    for r in results:
        meta = r.meta
        if not meta:
            continue
        T = meta.get("T")
        rank = meta.get("rank")
        if T is None or rank is None:
            continue
        grouped.setdefault(int(T), []).append(int(rank))

    ts = sorted(grouped)
    samples = [np.asarray(grouped[t], dtype=np.float64) for t in ts]
    xs = np.array(ts, dtype=np.float64)
    means = np.array([float(a.mean()) for a in samples], dtype=np.float64)
    stds = np.array(
        [float(a.std(ddof=1)) if a.size > 1 else 0.0 for a in samples],
        dtype=np.float64,
    )
    return xs, means, stds


def _plot_mean_band(plt, xs: np.ndarray, mean: np.ndarray, std: np.ndarray, label: str) -> None:
    """Draw ``mean`` vs ``xs`` on the current figure, with a ±1 std band clipped at 0.

    The band is drawn only if some std is positive. When every std is zero
    (a single seed, or identical ranks), a zero-width band would be
    invisible anyway.
    """
    plt.plot(xs, mean, "-o", linewidth=4, alpha=0.9, color="darkred", label=label)
    if np.any(std > 0):
        plt.fill_between(
            xs,
            np.maximum(0, mean - std),
            mean + std,
            color="darkred",
            alpha=0.18,
        )


def _style_and_save(plt, path: Path, ylabel: str, title: str) -> None:
    """Apply the shared exp5b axis styling to the current figure and save it.

    The parent directory of ``path`` is created if it does not exist.
    """
    plt.xlabel("Sequence length T", fontsize=24)
    plt.ylabel(ylabel, fontsize=24)
    plt.title(title, fontsize=20, fontweight="bold", pad=16)
    plt.tick_params(axis="both", which="major", labelsize=18)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=14)
    plt.tight_layout()
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(path, dpi=225, bbox_inches="tight")


def plot_results(
    results: List[ExperimentResult],
    out_path: Path = Path("outputs/exp5b_rank_growth.png"),
    gap_path: Path | None = None,
) -> None:
    """Plot mean ± std softmax-attention rank vs. T, plus the rank-gap plot.

    Ranks are read from each result's ``meta["T"]`` / ``meta["rank"]`` and
    aggregated across seeds per T. Two figures are written:

    * ``out_path``: mean rank with a ±1 std band, plus the ``rank == T``
      full-rank reference line.
    * ``gap_path``: mean ``T - rank`` with a ±1 std band. It defaults to
      ``<out_path stem>_gap<out_path suffix>`` in the same directory as
      ``out_path``.

    Parent directories of both output paths are created when missing.
    The function only prints a message, and writes nothing, when matplotlib
    is not importable or when no result carries usable ``T``/``rank`` meta.

    Args:
        results: Experiment results, typically the output of ``run``.
        out_path: Destination of the rank plot. A str is also accepted.
        gap_path: Optional destination of the rank-gap plot. A str is also
            accepted.
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError:  # pragma: no cover
        print("matplotlib not available; skipping exp5b plot.")
        return

    if not results:
        print("No results to plot for exp5b.")
        return

    xs, mean_rank, std_rank = _aggregate_rank_stats(results)
    if xs.size == 0:
        # Results exist but none had both T and rank; don't write empty figures.
        print("No results to plot for exp5b.")
        return

    out_path = Path(out_path)
    gap_path = (
        out_path.with_name(out_path.stem + "_gap" + out_path.suffix)
        if gap_path is None
        else Path(gap_path)
    )

    fig = plt.figure(figsize=(8, 6))
    try:
        _plot_mean_band(plt, xs, mean_rank, std_rank, "Softmax attention rank")
        plt.plot(xs, xs, "--", alpha=0.4, color="black", label="Full rank baseline")
        _style_and_save(plt, out_path, "Matrix rank", "Experiment 5: Softmax attention rank growth")
    finally:
        # Close even if saving fails, so figures don't accumulate in pyplot state.
        plt.close(fig)
    print(f"Saved exp5b plot to {out_path}")

    # Gap plot: distance to the full-rank baseline (T - rank). T is constant
    # within each group, so std(T - rank) == std(rank).
    fig = plt.figure(figsize=(8, 6))
    try:
        _plot_mean_band(plt, xs, xs - mean_rank, std_rank, "Full rank - observed rank")
        _style_and_save(
            plt,
            gap_path,
            "Rank gap (T - rank)",
            "Experiment 5: Softmax attention rank gap vs. Sequence length",
        )
    finally:
        plt.close(fig)
    print(f"Saved exp5b gap plot to {gap_path}")


if __name__ == "__main__":

    def _positive_int(text: str) -> int:
        """argparse ``type=`` converter that accepts only integers >= 1.

        Non-positive sequence lengths or key dimensions are meaningless for
        this experiment. Without this check they would surface as opaque
        NumPy errors (negative dimensions) or degenerate empty matrices,
        rather than a clean CLI error.
        """
        value = int(text)
        if value <= 0:
            raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
        return value

    parser = argparse.ArgumentParser(description="Exp5: Softmax rank growth with variance bands.")
    parser.add_argument(
        "--T-values",
        type=_positive_int,
        nargs="+",
        default=[20, 40, 80, 160, 320, 640, 1280, 2560, 5120],
    )
    parser.add_argument("--d-k", type=_positive_int, default=8)
    parser.add_argument("--seeds", type=int, nargs="+", default=[0, 1, 2, 3, 4], help="Seeds to sweep (mean ± std shown).")
    parser.add_argument("--out-path", type=str, default="outputs/exp5b_rank_growth.png")
    parser.add_argument("--out-gap-path", type=str, default=None, help="Optional path for the rank-gap plot.")
    args = parser.parse_args()

    res = run(T_values=args.T_values, d_k=args.d_k, seeds=args.seeds)
    for r in res:
        print(f"[{r.name}] {r.details}")
    plot_results(res, Path(args.out_path), Path(args.out_gap_path) if args.out_gap_path else None)
