import argparse
import math
import os
import sys
from collections.abc import Callable

import numpy as np


def set_seed(seed: int) -> None:
    """Seed NumPy's global ``np.random`` state with *seed*.

    Only the legacy global generator is affected; generators created with
    ``np.random.default_rng`` carry their own independent state.
    """
    np.random.seed(seed=seed)


def make_monotone_operator(d: int, mu: float, seed: int) -> tuple[np.ndarray, np.ndarray]:
    """Draw a random affine operator ``x -> a @ x + b`` in dimension *d*.

    The matrix is ``q.T @ q / d + skew + mu * I``: a Gram (positive
    semidefinite) part, a skew-symmetric part, and a ``mu`` multiple of the
    identity.

    Args:
        d: Dimension of the operator.
        mu: Coefficient of the identity added to the matrix.
        seed: Seed for the private ``default_rng`` used for all draws.

    Returns:
        ``(a, b)`` with ``a`` of shape ``(d, d)`` and ``b`` of shape ``(d,)``.
    """
    gen = np.random.default_rng(seed)

    # The draw order (Gram factor, offset, skew factor) fixes which numbers
    # each piece receives from the stream, so it must not be rearranged.
    gram_factor = gen.standard_normal((d, d))
    offset = gen.standard_normal(d) / math.sqrt(d)
    raw = gen.standard_normal((d, d))

    gram = (gram_factor.T @ gram_factor) / d
    antisymmetric = 0.5 * (raw - raw.T)
    matrix = gram + antisymmetric + mu * np.eye(d)
    return matrix, offset


def proj_box(x: np.ndarray, bound: float) -> np.ndarray:
    """Clamp every entry of *x* to the interval ``[-bound, bound]``.

    Returns a new array; *x* itself is left untouched.
    """
    lower, upper = -bound, bound
    raised = np.maximum(x, lower)
    return np.minimum(raised, upper)


def gap_box(vx: np.ndarray, x: np.ndarray, bound: float = 1.0) -> float:
    """Return the gap ``max_{y in X} <vx, x - y>`` over ``X = [-bound, bound]^d``.

    The maximiser is ``y = -bound * sign(vx)``, which gives the closed form
    ``<vx, x> + bound * ||vx||_1``.

    Args:
        vx: Operator value at ``x``.
        x: Point at which the gap is evaluated.
        bound: Half-width of the box. The default of 1.0 is the unit box.

    Returns:
        The gap as a Python float.
    """
    return float(vx @ x + bound * np.sum(np.abs(vx)))


def spectral_norm(a: np.ndarray) -> float:
    """Return ``np.linalg.norm(a, ord=2)`` as a plain Python float.

    For a 2-D input this is the largest singular value of *a*.
    """
    two_norm = np.linalg.norm(a, ord=2)
    return float(two_norm)


def build_schedules(
    d: int,
    L: float,
    D: float,
    b: np.ndarray,
    sigma: float,
    mode: str,
    strong: bool,
    strong_lambda: float,
    strong_c: float,
) -> tuple[Callable[[int], float], Callable[[int], float], float]:
    """Build the per-iteration ``gamma_t`` / ``eta_t`` schedules for one solver.

    Args:
        d: Problem dimension.
        L: Scale of the linear part of the operator; the callers in this file
            pass the spectral norm of the operator matrix.
        D: Diameter of the feasible set; the callers in this file pass
            ``2 * sqrt(d) * bound`` for the box ``[-bound, bound]^d``.
        b: Constant offset of the operator.
        sigma: Per-coordinate standard deviation of the gradient noise.
        mode: ``"rg"`` or ``"rog"``.
        strong: If true, use ``gamma_t == 0`` with a ``1 / t``-type step size.
        strong_lambda: Constant in the denominator of the strong step sizes;
            must be positive when ``strong`` is true.
        strong_c: Numerator constant of the strong ``"rog"`` step size.

    Returns:
        ``(gamma_t, eta_t, g_bound)``. The two schedules map an iteration
        index ``t >= 1`` to a float, and ``g_bound`` equals
        ``sqrt((L * D / 2 + ||b||)^2 + d * sigma^2)``.

    Raises:
        ValueError: If ``mode`` is unknown, if ``strong`` is set with a
            non-positive ``strong_lambda``, or if ``sigma`` is negative for
            the non-strong ``"rog"`` schedule.
    """
    # The feasible set is taken to be centred at the origin, so the largest
    # ||x|| is D / 2. For the unit box this is exactly sqrt(d); deriving it
    # from D keeps g_bound a valid norm bound for other box sizes as well.
    max_x_norm = 0.5 * D
    u_bound = L * max_x_norm + float(np.linalg.norm(b))
    g_bound = math.sqrt(u_bound * u_bound + d * sigma * sigma)

    if strong:
        if strong_lambda <= 0.0:
            raise ValueError("strong_lambda must be positive for strong schedules")
        if mode == "rg":
            numerator, t0 = 1.0, 0
        elif mode == "rog":
            numerator = strong_c
            t0 = math.ceil(6.0 * strong_c * L / strong_lambda)
        else:
            raise ValueError(f"unknown mode: {mode}")

        def gamma_t(t: int) -> float:
            return 0.0

        def eta_t(t: int) -> float:
            return numerator / (strong_lambda * (t + t0))

        return gamma_t, eta_t, g_bound

    if mode == "rg":
        def gamma_t(t: int) -> float:
            return t ** (-4.0 / 5.0)

        def eta_t(t: int) -> float:
            return (D / g_bound) * t ** (-3.0 / 5.0)

    elif mode == "rog":
        if sigma < 0.0:
            # A negative base under the fractional power below would yield a
            # complex number and an obscure TypeError on the first call.
            raise ValueError("sigma must be non-negative")
        if sigma == 0.0:
            def gamma_t(t: int) -> float:
                return 1.0 / math.sqrt(t)

            def eta_t(t: int) -> float:
                return (1.0 / 6.0) * (1.0 / L)
        else:
            def gamma_t(t: int) -> float:
                term = (D * L / (sigma * t * t)) ** (2.0 / 5.0)
                return min(1.0 / math.sqrt(t), term)

            def eta_t(t: int) -> float:
                term = (D ** 4 / (L * sigma ** 4 * t ** 3)) ** (1.0 / 5.0)
                return (1.0 / 6.0) * min(1.0 / L, term)
    else:
        raise ValueError(f"unknown mode: {mode}")

    return gamma_t, eta_t, g_bound


def run_rg(
    a: np.ndarray,
    b: np.ndarray,
    sigma: float,
    t_max: int,
    bound: float,
    seed: int,
    strong: bool,
    strong_lambda: float,
    strong_c: float,
) -> dict:
    """Run the "rg" iteration on the operator ``x -> a @ x + b``.

    Starting from ``x = 0``, each step draws a noisy operator value
    ``ghat = a @ x + b + sigma * N(0, I)`` and sets
    ``x = proj_box((1 - gamma_t) * x - eta_t * ghat, bound)``.

    Args:
        a: Square operator matrix.
        b: Operator offset.
        sigma: Per-coordinate standard deviation of the noise.
        t_max: Number of iterations.
        bound: Half-width of the feasible box ``[-bound, bound]^d``.
        seed: Seed of the noise generator.
        strong: Passed to ``build_schedules`` to pick the schedule family.
        strong_lambda: Passed to ``build_schedules``.
        strong_c: Passed to ``build_schedules``.

    Returns:
        A dict with ``"gaps"`` (array of length ``t_max`` holding the squared
        gap after each step), ``"L"`` (spectral norm of ``a``) and ``"D"``
        (``2 * sqrt(d) * bound``).
    """
    d = a.shape[0]
    L = spectral_norm(a)
    D = 2.0 * math.sqrt(d) * bound
    gamma_t, eta_t, _ = build_schedules(
        d, L, D, b, sigma, "rg", strong, strong_lambda, strong_c
    )
    rng = np.random.default_rng(seed)

    x = np.zeros(d)
    gaps = np.zeros(t_max)

    # The operator value at the current iterate feeds both the gap of this
    # step and the noisy gradient of the next one, so evaluate it once.
    vx = a @ x + b
    for t in range(1, t_max + 1):
        ghat = vx + sigma * rng.standard_normal(d)
        x = proj_box((1.0 - gamma_t(t)) * x - eta_t(t) * ghat, bound)
        vx = a @ x + b
        # Gap over the actual feasible box [-bound, bound]^d:
        # max_y <vx, x - y> = <vx, x> + bound * ||vx||_1.
        gap = float(vx @ x + bound * np.sum(np.abs(vx)))
        gaps[t - 1] = gap ** 2

    return {"gaps": gaps, "L": L, "D": D}


def run_rog(
    a: np.ndarray,
    b: np.ndarray,
    sigma: float,
    t_max: int,
    bound: float,
    seed: int,
    strong: bool,
    strong_lambda: float,
    strong_c: float,
) -> dict:
    """Run the "rog" iteration on the operator ``x -> a @ x + b``.

    Two sequences start at zero. Each step draws one noisy operator value
    ``ghat = a @ x + b + sigma * N(0, I)`` and uses it twice:
    ``y = proj_box((1 - gamma_t) * y - eta_t * ghat, bound)`` followed by
    ``x = proj_box((1 - gamma_{t+1}) * y - eta_{t+1} * ghat, bound)``.

    Args:
        a: Square operator matrix.
        b: Operator offset.
        sigma: Per-coordinate standard deviation of the noise.
        t_max: Number of iterations.
        bound: Half-width of the feasible box ``[-bound, bound]^d``.
        seed: Seed of the noise generator.
        strong: Passed to ``build_schedules`` to pick the schedule family.
        strong_lambda: Passed to ``build_schedules``.
        strong_c: Passed to ``build_schedules``.

    Returns:
        A dict with ``"gaps"`` (array of length ``t_max`` holding the squared
        gap at ``x`` after each step), ``"L"`` (spectral norm of ``a``) and
        ``"D"`` (``2 * sqrt(d) * bound``).
    """
    d = a.shape[0]
    L = spectral_norm(a)
    D = 2.0 * math.sqrt(d) * bound
    gamma_t, eta_t, _ = build_schedules(
        d, L, D, b, sigma, "rog", strong, strong_lambda, strong_c
    )
    rng = np.random.default_rng(seed)

    x = np.zeros(d)
    y = np.zeros(d)
    gaps = np.zeros(t_max)

    # The operator value at the current x feeds both the gap of this step and
    # the noisy gradient of the next one, so evaluate it once.
    vx = a @ x + b
    for t in range(1, t_max + 1):
        ghat = vx + sigma * rng.standard_normal(d)
        y = proj_box((1.0 - gamma_t(t)) * y - eta_t(t) * ghat, bound)
        # The second update reuses the same sample with the next step's
        # parameters.
        x = proj_box((1.0 - gamma_t(t + 1)) * y - eta_t(t + 1) * ghat, bound)
        vx = a @ x + b
        # Gap over the actual feasible box [-bound, bound]^d:
        # max_y <vx, x - y> = <vx, x> + bound * ||vx||_1.
        gap = float(vx @ x + bound * np.sum(np.abs(vx)))
        gaps[t - 1] = gap ** 2

    return {"gaps": gaps, "L": L, "D": D}


def average_runs(
    mode: str,
    a: np.ndarray,
    b: np.ndarray,
    sigma: float,
    t_max: int,
    bound: float,
    seeds: list[int],
    strong: bool,
    strong_lambda: float,
    strong_c: float,
) -> np.ndarray:
    """Average the squared-gap curve of one solver over several noise seeds.

    Args:
        mode: ``"rg"`` or ``"rog"``; selects ``run_rg`` or ``run_rog``.
        a: Operator matrix, forwarded to the solver.
        b: Operator offset, forwarded to the solver.
        sigma: Noise standard deviation, forwarded to the solver.
        t_max: Number of iterations per run.
        bound: Half-width of the feasible box.
        seeds: One solver run is performed per seed; must not be empty.
        strong: Forwarded to the solver.
        strong_lambda: Forwarded to the solver.
        strong_c: Forwarded to the solver.

    Returns:
        Array of length ``t_max`` with the per-iteration mean of the
        ``"gaps"`` arrays returned by the individual runs.

    Raises:
        ValueError: If ``mode`` is unknown or ``seeds`` is empty.
    """
    # Validate once, up front: the check must also fire when there are no
    # seeds to iterate over.
    if mode not in ("rg", "rog"):
        raise ValueError(f"unknown mode: {mode}")
    if len(seeds) == 0:
        raise ValueError("seeds must contain at least one seed")

    runner = run_rg if mode == "rg" else run_rog
    per_seed = [
        runner(a, b, sigma, t_max, bound, seed, strong, strong_lambda, strong_c)["gaps"]
        for seed in seeds
    ]
    return np.mean(np.stack(per_seed, axis=0), axis=0)


def plot_curves(
    t: np.ndarray,
    curves: dict,
    out_path: str,
    slope_refs: list[float] | None,
) -> None:
    """Draw the curves on log-log axes and save the figure to *out_path*.

    Args:
        t: Iteration indices used as the x-axis for every curve.
        curves: Mapping from legend label to an array of y-values. The labels
            ``"RG"`` and ``"ROG"`` get dedicated colours and markers.
        out_path: Destination passed to ``savefig``.
        slope_refs: Exponents ``alpha`` for dashed ``t^(-alpha)`` reference
            lines, anchored at the first value of the first non-empty curve.
            ``None`` or an empty list draws no reference lines.

    Note:
        ``plt.rcParams`` is updated globally. The figure itself is closed
        before returning, so repeated calls do not accumulate open figures.
    """
    import matplotlib.pyplot as plt

    def log_markevery(x: np.ndarray, n_marks: int) -> np.ndarray:
        """Pick about *n_marks* indices of *x* evenly spaced in ``log10(x)``."""
        if x.size == 0 or n_marks <= 0:
            return np.array([], dtype=int)
        lo = float(x[0])
        hi = float(x[-1])
        if lo <= 0 or hi <= 0 or hi <= lo:
            # A log scale is not usable here; fall back to linear spacing.
            return np.linspace(0, x.size - 1, n_marks, dtype=int)
        logx = np.log10(x)
        targets = np.linspace(logx[0], logx[-1], n_marks)
        idx = np.searchsorted(logx, targets, side="left")
        # Clip before de-duplicating so clipping cannot reintroduce repeats.
        return np.unique(np.clip(idx, 0, x.size - 1))

    plt.rcParams.update(
        {
            "font.size": 14,
            "axes.labelsize": 15,
            "legend.fontsize": 13,
            "axes.titlesize": 15,
            "xtick.labelsize": 13,
            "ytick.labelsize": 13,
            "axes.linewidth": 0.9,
        }
    )
    style_map = {
        "RG": {"color": "#1f77b4", "linestyle": "-", "marker": "o"},
        "ROG": {"color": "#ff7f0e", "linestyle": "--", "marker": "s"},
    }

    fig = plt.figure(figsize=(7.4, 4.9))
    try:
        ax = fig.add_subplot(111)
        markevery = log_markevery(t, 10)
        for label, vals in curves.items():
            ax.loglog(
                t,
                vals,
                label=label,
                linewidth=2.2,
                markersize=8,
                markevery=markevery,
                **style_map.get(label, {}),
            )

        if slope_refs:
            # Reference lines start at the first value of the first curve
            # that has any data; with no data there is nothing to anchor to.
            anchor = next(
                (vals[0] for vals in curves.values() if len(vals) > 0), None
            )
            if anchor is not None and len(t) > 0:
                t0 = t[0]
                for slope in slope_refs:
                    ax.loglog(
                        t,
                        anchor * (t / t0) ** (-slope),
                        linestyle=(0, (3, 3)),
                        color="#2ca02c",
                        linewidth=1.4,
                        label=f"t^(-{slope:g})",
                    )

        ax.set_xlabel("t")
        ax.set_ylabel("squared gap")
        ax.grid(True, which="both", linestyle=":", linewidth=0.6, color="#b0b0b0")
        ax.legend(frameon=False)
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def _collect_curves(
    only: str,
    a: np.ndarray,
    b: np.ndarray,
    sigma: float,
    t_max: int,
    bound: float,
    seeds: list[int],
    strong: bool,
    strong_lambda: float,
    strong_c: float,
) -> dict:
    """Run the solvers selected by *only* and return their averaged curves by plot label."""
    curves = {}
    # RG is inserted before ROG so plotting order and the reference-line
    # anchor stay the same regardless of which solvers are selected.
    for mode, label in (("rg", "RG"), ("rog", "ROG")):
        if only in (mode, "both"):
            curves[label] = average_runs(
                mode, a, b, sigma, t_max, bound, seeds, strong, strong_lambda, strong_c
            )
    return curves


def main() -> int:
    """Parse the command line, run the experiments and write the plot(s).

    Without ``--batch`` a single experiment is run with the given ``--mu`` and
    ``--sigma`` (the strong schedules are used when ``--mu`` is positive) and
    saved to ``--out``. With ``--batch`` six preset cases are run and each is
    saved under a name derived from its tag, sigma and ``--t-max``.

    Returns:
        0 on success. Invalid arguments terminate through ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Synthetic VI experiments for RG/ROG.")
    parser.add_argument("--d", type=int, default=10)
    parser.add_argument("--t-max", type=int, default=5000)
    parser.add_argument("--sigma", type=float, default=0.1)
    parser.add_argument("--mu", type=float, default=0.0)
    parser.add_argument("--bound", type=float, default=1.0)
    parser.add_argument("--seeds", type=int, default=5)
    parser.add_argument("--out", type=str, default="synthetic_vi_gap.png")
    parser.add_argument("--only", type=str, default="both", choices=["rg", "rog", "both"])
    parser.add_argument(
        "--slopes",
        type=str,
        default="0.4",
        help="comma-separated reference slopes for t^{-alpha}",
    )
    parser.add_argument("--batch", action="store_true", help="run 6 preset cases")
    parser.add_argument("--mu-strong", type=float, default=0.2)
    parser.add_argument("--strong-c", type=float, default=2.0)
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside the run with an
    # unrelated-looking traceback.
    if args.d < 1:
        parser.error("--d must be at least 1")
    if args.t_max < 1:
        parser.error("--t-max must be at least 1")
    if args.seeds < 1:
        parser.error("--seeds must be at least 1")

    seeds = list(range(args.seeds))
    t = np.arange(1, args.t_max + 1)

    if args.batch:
        # The strong presets need a positive modulus; check before spending
        # time on the monotone presets.
        if args.mu_strong <= 0.0:
            parser.error("--mu-strong must be positive")
        cases = [
            ("monotone", 0.0, 0.1, [0.4], False),
            ("monotone", 0.0, 0.01, [0.4], False),
            ("monotone", 0.0, 0.0, [0.4], False),
            ("strongly-monotone", args.mu_strong, 0.1, [1.0], True),
            ("strongly-monotone", args.mu_strong, 0.01, [1.0], True),
            ("strongly-monotone", args.mu_strong, 0.0, [1.0], True),
        ]
        for tag, mu, sigma, slopes, strong in cases:
            a, b = make_monotone_operator(args.d, mu, seed=1)
            curves = _collect_curves(
                args.only,
                a,
                b,
                sigma,
                args.t_max,
                args.bound,
                seeds,
                strong,
                mu,
                args.strong_c,
            )
            out_name = f"synthetic_vi_{tag}_sigma{sigma:g}_t{args.t_max}.png"
            plot_curves(t, curves, out_name, slopes)
            print(f"saved: {out_name}")
        return 0

    # Parse the slopes before running anything so a typo fails immediately.
    try:
        slope_refs = [float(s) for s in args.slopes.split(",") if s.strip()]
    except ValueError:
        parser.error(f"--slopes must be comma-separated numbers, got {args.slopes!r}")

    set_seed(0)
    a, b = make_monotone_operator(args.d, args.mu, seed=1)
    strong = args.mu > 0.0
    curves = _collect_curves(
        args.only,
        a,
        b,
        args.sigma,
        args.t_max,
        args.bound,
        seeds,
        strong,
        args.mu,
        args.strong_c,
    )
    plot_curves(t, curves, args.out, slope_refs)
    print(f"saved: {args.out}")
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
