"""
Experiment 3b: Generator rank vs. state dimension with plotting.

This wraps the logic from exp3_rank_vs_state_dim.py, sweeps over multiple
state dimensions and seeds, and produces a plot of the generator rank
(mean ± std) as a function of state dimension N, together with the
reference line y = N.
"""


from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable, List, Sequence

import numpy as np

from . import exp3_rank_vs_state_dim as exp3
from .common import ExperimentResult


def run_for_plot(
    N_values: Sequence[int] = (1, 2, 4, 8, 16),
    T: int | None = None,
    seeds: Sequence[int] = (0, 1, 2, 3, 4),
    min_decay: float = 0.3,
    max_decay: float = 0.85,
    min_spacing: float = 0.05,
) -> List[ExperimentResult]:
    """
    Generate results for plotting.

    For every seed, a fresh ``np.random.default_rng(seed)`` draws one sorted
    decay vector per entry of ``N_values``. The resulting cases are passed to
    ``exp3.run`` together with a single sequence length and that single seed.
    Spaced decays avoid near-duplicate modes; a narrow decay range plus spacing
    reduces ill-conditioning that can tank ranks at larger N.

    Args:
        N_values: State dimensions to evaluate (each is coerced with ``int``).
        T: Sequence length. When ``None``, ``2 * max(N_values)`` is used.
        seeds: Seeds; each one gets its own generator and its own ``exp3.run`` call.
        min_decay: Lower bound of the uniform range decays are drawn from.
        max_decay: Upper bound of the uniform range decays are drawn from.
        min_spacing: Minimum gap required between neighbouring sorted decays
            for a random draw to be accepted.

    Returns:
        The concatenation of everything ``exp3.run`` yields across seeds, in
        seed order. An empty list when ``N_values`` or ``seeds`` is empty.
    """
    N_values = [int(n) for n in N_values]
    seeds = list(seeds)
    if not N_values:
        # Nothing to sweep: avoid `max()` on an empty sequence and avoid
        # calling exp3.run with no cases. Mirrors the empty-`seeds` outcome.
        return []
    T_use = int(T) if T is not None else max(N_values) * 2

    def sample_decays(n: int, rng: np.random.Generator) -> list[float]:
        """Draw decays with minimum spacing; fall back to linspace if needed."""
        low, high = float(min_decay), float(max_decay)
        # Rejection sampling: redraw until neighbouring decays are far enough apart.
        for _ in range(200):
            vals = np.sort(rng.uniform(low, high, size=n))
            if n == 1 or np.min(np.diff(vals)) >= min_spacing:
                return vals.tolist()
        # Fallback after 200 rejected draws: evenly spaced values. These are the
        # same for every seed and are not themselves checked against min_spacing.
        return np.linspace(low, high - 1e-6, n).tolist()

    all_results: list[ExperimentResult] = []
    for seed in seeds:
        # One generator per seed, shared across all N so draws are reproducible.
        rng = np.random.default_rng(seed)
        cases: list[dict] = []
        for n in N_values:
            decays = sample_decays(n, rng)
            cases.append({"a": decays, "desc": f"N={n} spaced decays (seed={seed})"})
        # Run with a single seed to preserve metadata but use our pre-sampled decays.
        all_results.extend(
            exp3.run(
                T_values=(T_use,),
                cases=cases,
                seeds=[seed],
            )
        )

    return all_results


def _aggregate(results: Iterable[ExperimentResult]) -> tuple[list[int], dict[int, list[int]]]:
    """Collect generator ranks keyed by state_dim."""
    gen_by_N: dict[int, list[int]] = {}

    for res in results:
        meta = res.meta or {}
        N = meta.get("state_dim")
        gen_rank = meta.get("generator_rank")
        if N is None or gen_rank is None:
            continue
        N = int(N)
        gen_by_N.setdefault(N, []).append(int(gen_rank))

    Ns = sorted(gen_by_N)
    return Ns, gen_by_N


def _mean_std_by_N(Ns: list[int], values: dict[int, list[int]]) -> tuple[np.ndarray, np.ndarray]:
    means, stds = [], []
    for n in Ns:
        arr = np.asarray(values.get(n, []), dtype=np.float64)
        if arr.size == 0:
            means.append(np.nan)
            stds.append(0.0)
        else:
            means.append(float(arr.mean()))
            stds.append(float(arr.std(ddof=1)) if arr.size > 1 else 0.0)
    return np.asarray(means), np.asarray(stds)


def plot_results(
    results: List[ExperimentResult],
    out_path: Path = Path("outputs/exp3b_rank_vs_state_dim.png"),
) -> None:
    """
    Plot mean generator rank (with a ±1 std band) against state dimension.

    The results are grouped with ``_aggregate`` and summarised with
    ``_mean_std_by_N``; a dashed ``y = N`` reference line is drawn alongside.
    The figure is written to ``out_path`` (parent directories are created).

    Nothing is plotted, and a message is printed instead, when matplotlib
    cannot be imported, when ``results`` is empty, or when no result carries
    the metadata needed for aggregation.

    Args:
        results: Experiment results to aggregate.
        out_path: Destination image file; a ``Path`` or a path string.
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError:  # pragma: no cover
        print("matplotlib not available; skipping exp3b plot.")
        return

    if not results:
        print("No results to plot for exp3b.")
        return

    Ns, gen_by_N = _aggregate(results)
    if not Ns:
        print("No valid entries found in exp3b results.")
        return

    gen_mean, gen_std = _mean_std_by_N(Ns, gen_by_N)
    Ns_arr = np.asarray(Ns, dtype=np.float64)

    # Accept plain strings as well as Path objects.
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(8, 6))
    try:
        # Generator rank
        plt.plot(
            Ns_arr,
            gen_mean,
            "-o",
            linewidth=3.5,
            color="crimson",
            alpha=0.9,
            label="Generator rank (mean)",
        )
        # Shaded ±1 std band; the lower edge is clamped at zero.
        plt.fill_between(
            Ns_arr,
            np.maximum(0, gen_mean - gen_std),
            gen_mean + gen_std,
            color="crimson",
            alpha=0.18,
        )

        # Reference line y = N
        plt.plot(
            Ns_arr,
            Ns_arr,
            "--",
            linewidth=2,
            color="gray",
            alpha=0.6,
            label="y = N reference",
        )

        plt.xlabel("State dimension $N$", fontsize=18)
        plt.ylabel("Rank", fontsize=18)
        plt.title(
            "Experiment 3: Generator Rank Growth vs. State Dimension",
            fontsize=16,
            fontweight="bold",
            pad=12,
        )
        plt.grid(True, alpha=0.3)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.ylim(0, max(Ns_arr) + 1)
        plt.legend(fontsize=11)
        plt.tight_layout()
        plt.savefig(out_path, dpi=250, bbox_inches="tight")
    finally:
        # Always release this specific figure, even if plotting or saving fails.
        plt.close(fig)

    print(f"Saved exp3b plot to {out_path}")


def main() -> None:
    """Parse command-line options, run the sweep, print each result, and save the plot."""
    cli = argparse.ArgumentParser(description="Experiment 3b: rank vs. state dimension with plotting.")

    # (flag, add_argument keyword options), registered in this order.
    option_specs = [
        (
            "--N-values",
            {"type": int, "nargs": "+", "default": [1, 2, 4, 8, 16], "help": "State dimensions to evaluate."},
        ),
        (
            "--T",
            {"type": int, "default": None, "help": "Sequence length. Defaults to 2 * max(N)."},
        ),
        (
            "--seeds",
            {"type": int, "nargs": "+", "default": [0, 1, 2, 3, 4], "help": "Seeds for random decays."},
        ),
        (
            "--min-decay",
            {"type": float, "default": 0.3, "help": "Lower bound for sampled decays."},
        ),
        (
            "--max-decay",
            {
                "type": float,
                "default": 0.85,
                "help": "Upper bound for sampled decays (keep <1 to reduce ill-conditioning).",
            },
        ),
        (
            "--min-spacing",
            {
                "type": float,
                "default": 0.05,
                "help": "Minimum spacing between sampled decays to avoid duplicates.",
            },
        ),
        (
            "--out-path",
            {
                "type": Path,
                "default": Path("outputs/exp3b_rank_vs_state_dim.png"),
                "help": "Where to save the plot.",
            },
        ),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    sweep = run_for_plot(
        N_values=opts.N_values,
        T=opts.T,
        seeds=opts.seeds,
        min_decay=opts.min_decay,
        max_decay=opts.max_decay,
        min_spacing=opts.min_spacing,
    )
    # Echo every result before plotting.
    for item in sweep:
        print(f"[{item.name}] {item.details}")
    plot_results(sweep, out_path=opts.out_path)


# Run the CLI only when this file is executed as the entry point, not on import.
# The module uses relative imports, so it must be launched as part of its
# package (e.g. via `python -m <package>.<module>`).
if __name__ == "__main__":
    main()
