import numpy as np
import matplotlib.pyplot as plt


def _build_sigma(nx, RR, IdR):
    """Return the nx-by-nx covariance matrix of the regression features.

    IdR == 1: Toeplitz, Sigma[i, j] = RR ** |i - j|.
    IdR == 0: RR off the diagonal and 2 on the diagonal.
    IdR == 2: identity.

    Raises:
        ValueError: for any other IdR.
    """
    if IdR == 1:  # Toeplitz (no diagonal shift)
        return np.array([[RR ** abs(i - j) for j in range(nx)] for i in range(nx)])
    if IdR == 0:  # equal off-diagonal entries RR, diagonal shifted up to 2
        return RR * np.ones((nx, nx)) + (1 - RR) * np.eye(nx) + np.eye(nx)
    if IdR == 2:  # identity
        return np.eye(nx)
    raise ValueError("Unsupported IdR")


def _vanilla_sketch_solve(B, g, ttau, EPS):
    """Approximately solve B @ dx = -g with ttau Gaussian sketch-and-project steps.

    Starting from dx = 0, each step draws s ~ N(0, I), sets u = B @ s and moves
    dx along u so that (for symmetric B) the sketched equation
    s^T (B dx + g) = 0 holds afterwards. ||u||^2 is floored at EPS.
    """
    nx = B.shape[0]
    dx = np.zeros(nx)
    for _ in range(ttau):
        s = np.random.normal(size=nx)
        u = B @ s
        denom = max(float(u @ u), EPS)
        numer = float(s @ (B @ dx + g))
        dx = dx - (numer / denom) * u
    return dx


def _nesterov_constants(B, EPS, m_mc=30):
    """Monte Carlo estimates of the sketch constants (mu_t, nu_t) for matrix B.

    With s ~ N(0, I), u = B s and Ztil = u u^T / (u^T u):
        Z    = E[Ztil],                        mu_t = lambda_min(Z),
        M    = E[(u^T Z^{-1} u / u^T u) Ztil], nu_t = lambda_max(Z^{-1/2} M Z^{-1/2}).
    Expectations are replaced by averages over m_mc sketches; eigenvalues of Z
    are floored at 1e-12 and u^T u at EPS.
    """
    nx = B.shape[0]
    S = np.random.normal(size=(m_mc, nx))      # one Gaussian sketch per row
    U = S @ B.T                                # row k equals B @ S[k]
    den = np.maximum(np.einsum("ij,ij->i", U, U), EPS)

    Z = (U.T / den) @ U / m_mc                 # mean of u u^T / (u^T u)
    evals, evecs = np.linalg.eigh(Z)
    evals = np.maximum(evals, 1e-12)
    mu_t = float(evals.min())

    invZ = (evecs / evals) @ evecs.T
    inv_sqrtZ = (evecs / np.sqrt(evals)) @ evecs.T

    # Per-sketch weight (u^T Z^{-1} u / u^T u) / (u^T u) applied to u u^T.
    weights = np.einsum("ij,jk,ik->i", U, invZ, U) / den ** 2
    M = (U.T * weights) @ U / m_mc

    nu_t = float(np.linalg.eigvalsh(inv_sqrtZ @ M @ inv_sqrtZ).max())
    return mu_t, nu_t


def _nesterov_sketch_solve(B, g, ttau, mu_t, nu_t, EPS):
    """Approximately solve B @ dx = -g with ttau accelerated sketch-and-project steps.

    The momentum parameters are derived from (mu_t, nu_t); ||B s||^2 is
    floored at EPS, matching the vanilla solver.
    """
    # --- Algorithm-1 parameters (paper) ---
    beta_acc = 1.0 - np.sqrt(mu_t / nu_t)
    gamma_acc = np.sqrt(1.0 / (mu_t * nu_t))
    alpha_acc = 1.0 / (1.0 + gamma_acc * nu_t)

    nx = B.shape[0]
    dx = np.zeros(nx)
    v = np.zeros(nx)
    for _ in range(ttau):
        dy = alpha_acc * v + (1.0 - alpha_acc) * dx

        s = np.random.normal(size=nx)
        u = B @ s
        denom = max(float(u @ u), EPS)

        numer = float(s @ (B @ dy + g))
        omega = (numer / denom) * u

        dx = dy - omega
        v = beta_acc * v + (1.0 - beta_acc) * dy - gamma_acc * omega
    return dx


def StoNewton(c1, c2, Max_Iter, ttau, nx, X_true, RR, IdR, EPS=1e-8, solver="nesterov"):
    """Stochastic Newton method for streaming linear regression with sketched solves.

    At iteration t (t = 1, ..., Max_Iter + 1) a feature vector a_t ~ N(0, Sigma)
    and noise eps_t ~ N(0, 1) are drawn. The stochastic gradient of
    0.5 * (a_t^T x - b_t)^2, with b_t = a_t^T X_true + eps_t, is formed, and

        X_{t+1} = X_t + beta_t * d_t,   beta_t = c1 / t**c2,

    where d_t (approximately) solves K_t d = -gradient. K_t is the running
    average of the Hessian samples a a^T from earlier iterations, seeded with
    the identity.

    Args:
        c1, c2: step-size constants, beta_t = c1 / t**c2.
        Max_Iter: the loop runs Max_Iter + 1 iterations.
        ttau: inner sketch iterations; 0 means an exact solve with K_t + 1e-12 I.
        nx: problem dimension.
        X_true: true regression coefficients (length nx).
        RR: correlation parameter of Sigma.
        IdR: Sigma structure. 1 is Toeplitz RR**|i-j|; 0 is RR off the
            diagonal with 2 on the diagonal; 2 is the identity.
        EPS: lower bound applied to the sketch denominators ||K_t s||^2.
        solver: "vanilla" (randomized sketch-and-project) or "nesterov"
            (accelerated variant whose constants mu_t, nu_t are re-estimated
            by Monte Carlo every iteration). Ignored when ttau == 0.

    Returns:
        The final iterate X_t (length-nx array).

    Raises:
        ValueError: unsupported IdR, or unknown solver when ttau != 0.

    Side effects:
        Consumes the global np.random state. With solver="nesterov" and
        ttau != 0, prints mu_t and nu_t during the last iteration.
    """
    Sigma = _build_sigma(nx, RR, IdR)
    if ttau != 0 and solver not in ("vanilla", "nesterov"):
        raise ValueError("solver must be 'vanilla' or 'nesterov'")

    # Sigma is loop-invariant, so factor it once instead of re-decomposing it
    # on every draw. With Sigma = Vh^T diag(s) Vh (symmetric PSD), the vector
    # z @ (sqrt(s)[:, None] * Vh) with z ~ N(0, I) has covariance Sigma.
    _, svals, vh = np.linalg.svd(Sigma)
    sample_factor = np.sqrt(svals)[:, None] * vh

    X_t = np.ones(nx)
    hess_avg = np.eye(nx)  # running average of Hessian samples, seeded with I

    t = 1
    while t <= Max_Iter + 1:
        # K_t contains only Hessian samples drawn in earlier iterations.
        K_t = hess_avg
        beta_t = c1 / (t ** c2)

        # Step 1: sample stochastic gradient + Hessian sample (linear regression)
        a_t = np.random.standard_normal(nx) @ sample_factor
        eps_t = np.random.normal(0.0, 1.0)
        barg_t = a_t * (a_t @ (X_t - X_true)) - eps_t * a_t
        hess_avg = (t / (t + 1.0)) * hess_avg + (1.0 / (t + 1.0)) * np.outer(a_t, a_t)

        # Step 2: (inexact) Newton direction solving K_t d = -barg_t
        if ttau == 0:
            direction = np.linalg.solve(K_t + 1e-12 * np.eye(nx), -barg_t)
        elif solver == "vanilla":
            direction = _vanilla_sketch_solve(K_t, barg_t, ttau, EPS)
        else:  # "nesterov" (validated above)
            mu_t, nu_t = _nesterov_constants(K_t, EPS)
            if t == Max_Iter + 1:
                print(f"mu_t={mu_t:.6g}, nu_t={nu_t:.6g}")
            direction = _nesterov_sketch_solve(K_t, barg_t, ttau, mu_t, nu_t, EPS)

        # Step 3: damped update
        X_t = X_t + beta_t * direction
        t += 1

    return X_t


def qqplot_two_samples(x, y, title, fname, *, xlabel="Vanilla quantiles",
                       ylabel="Nesterov quantiles", show_title=False):
    """Save a two-sample QQ plot of ``y`` against ``x`` to ``fname``.

    Both samples are flattened. With equal lengths, the sorted samples are
    paired directly. With unequal lengths, the shorter sample is sorted. The
    longer one is replaced by its empirical quantiles (NumPy's default linear
    interpolation) at matching probabilities, so both axes cover their whole
    sample. Truncating the longer sample would keep only its smallest values.

    Args:
        x: sample for the horizontal axis.
        y: sample for the vertical axis.
        title: plot title; drawn only when ``show_title`` is True.
        fname: output image path, saved at 200 dpi.
        xlabel: horizontal-axis label.
        ylabel: vertical-axis label.
        show_title: whether to draw ``title``.

    Raises:
        ValueError: if either sample is empty.
    """
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    if x.size == 0 or y.size == 0:
        raise ValueError("qqplot_two_samples needs two non-empty samples")

    n = min(x.size, y.size)
    # Plotting positions k / (n - 1): for an equal-length sample, linear
    # interpolation at these positions reproduces its exact order statistics.
    probs = np.linspace(0.0, 1.0, n) if n > 1 else np.array([0.5])
    xs = np.sort(x) if x.size == n else np.quantile(x, probs)
    ys = np.sort(y) if y.size == n else np.quantile(y, probs)

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.plot(xs, ys, marker='.', linestyle='none')
        lo = min(xs.min(), ys.min())
        hi = max(xs.max(), ys.max())
        ax.plot([lo, hi], [lo, hi], linestyle='--')  # y = x reference line
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if show_title:
            ax.set_title(title)
        ax.grid(True, linestyle="--", linewidth=0.5)
        fig.tight_layout()
        fig.savefig(fname, dpi=200)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main():
    """Run the vanilla-vs-Nesterov sketched stochastic Newton experiment.

    For each inner solver, StoNewton is run Exp_num times. Each run records
    the scaled projected error y = sqrt(1 / beta_T) * w^T (x_T - X_true),
    where beta_T is the final step size. The function prints summary
    statistics, saves both samples to a tab-separated text file and writes a
    QQ plot comparing them. Output file names are built from the config, so
    they stay correct when nx or ttau change.
    """
    # ---- experiment config ----
    c1, c2 = 1.0, 0.501
    T = 100000
    Exp_num = 200

    nx = 10
    X_true = np.linspace(0, 1, nx)
    w = np.ones(nx)

    beta_T = c1 / (T ** c2)
    scale = np.sqrt(1 / beta_T)  # rescale the final error by 1/sqrt(beta_T)

    RR, IdR = 0.4, 0
    ttau = 5

    def run(solver):
        """Return Exp_num scaled projected errors for one inner solver."""
        ys = np.zeros(Exp_num)
        for k in range(Exp_num):
            print(f"[{k+1}/{Exp_num}] {solver} ...")
            x_T = StoNewton(c1, c2, T, ttau, nx, X_true, RR, IdR, solver=solver)
            ys[k] = float(scale * w @ (x_T - X_true))
        return ys

    # ---- run experiments (all vanilla runs first, then all nesterov runs) ----
    y_vanilla = run("vanilla")
    y_nesterov = run("nesterov")

    # ---- sample mean / variance / covariance of the scaled errors y ----
    mean_v = float(np.mean(y_vanilla))
    mean_n = float(np.mean(y_nesterov))
    var_v = float(np.var(y_vanilla, ddof=1))
    var_n = float(np.var(y_nesterov, ddof=1))
    cov_mat = np.cov(np.vstack([y_vanilla, y_nesterov]), ddof=1)

    print("\n=== Summary for y = w^T x_T ===")
    print(f"Vanilla : mean={mean_v:.6g}, var={var_v:.6g}")
    print(f"Nesterov: mean={mean_n:.6g}, var={var_n:.6g}")
    print("\nCovariance matrix of [y_vanilla, y_nesterov]^T across experiments:")
    print(cov_mat)

    # Output-name suffix derived from the config instead of hard-coded.
    tag = f"nx_{nx}_tau_{ttau}_munu_1"

    # ===== save y_vanilla and y_nesterov to txt =====
    out_txt = f"y_vanilla_y_nesterov_{tag}.txt"

    # 2-column file: col0 = y_vanilla, col1 = y_nesterov
    data_out = np.column_stack([y_vanilla, y_nesterov])

    header = (
        "col0: y_vanilla, col1: y_nesterov\n"
        f"T={T}, Exp_num={Exp_num}, nx={nx}, ttau={ttau}, RR={RR}, IdR={IdR}, "
        f"c1={c1}, c2={c2}, beta_T={beta_T}\n"
        "Each row is one experiment."
    )

    np.savetxt(out_txt, data_out, fmt="%.18e", delimiter="\t", header=header)
    print(f"Saved y arrays to: {out_txt}")

    # QQ plot
    qqplot_two_samples(
        y_vanilla, y_nesterov,
        title=r"QQ plot (linear regression using Gaussian sampling)",
        fname=f"qq_linear_Gaussian_{tag}.png",
    )


if __name__ == "__main__":
    # Run the full experiment only when executed as a script, not on import.
    main()


