import numpy as np
import matplotlib.pyplot as plt


def _build_sigma(RR, nx, IdR):
    """Return the nx-by-nx feature covariance selected by ``IdR``.

    IdR == 1: Toeplitz matrix, Sigma[i, j] = RR ** |i - j|.
    IdR == 0: RR * ones + (1 - RR) * I + I.
    IdR == 2: identity.

    Raises:
        ValueError: for any other ``IdR``.
    """
    if IdR == 1:
        idx = np.arange(nx)
        return RR ** np.abs(idx[:, None] - idx[None, :])
    if IdR == 0:
        # NOTE(review): the trailing "+ np.eye(nx)" puts 2 on the diagonal,
        # unlike the other options, which have a unit diagonal. Kept unchanged;
        # confirm it is intended.
        return RR * np.ones((nx, nx)) + (1 - RR) * np.eye(nx) + np.eye(nx)
    if IdR == 2:
        return np.eye(nx)
    raise ValueError("Unsupported IdR")


def _kaczmarz_vanilla(K, g, ttau, EPS, rng):
    """Approximately solve ``K d = -g`` with ``ttau`` randomized Kaczmarz steps.

    Starts from d = 0. Each step picks a row j uniformly at random and
    projects onto {d : K[j, :] @ d = -g[j]}. Denominators smaller than
    ``EPS`` in absolute value are replaced by ``EPS``.
    """
    nx = K.shape[0]
    # diag(K @ K) without forming the full product (O(nx^2) instead of O(nx^3)).
    # For the symmetric K used here this is the squared row norm.
    denom = np.einsum("ij,ji->i", K, K)
    denom = np.where(np.abs(denom) < EPS, EPS, denom)

    d = np.zeros(nx)
    for _ in range(ttau):
        j = rng.randint(nx)  # uniform in {0, ..., nx-1}
        resid = K[j, :] @ d + g[j]
        d = d - (resid / denom[j]) * K[:, j]
    return d


def _nesterov_parameters(B):
    """Return ``(denom, mu, nu)`` for accelerated Kaczmarz on matrix ``B``.

    With uniform row sampling and Z_j = B[:, j] B[j, :] / denom[j]:
        Z  = mean_j Z_j,                 mu = lambda_min(Z)
        M  = mean_j Z_j Z^{-1} Z_j,      nu = lambda_max(Z^{-1/2} M Z^{-1/2})
    where denom = diag(B @ B).
    """
    nx = B.shape[0]
    denom = np.einsum("ij,ji->i", B, B)  # diag(B @ B)

    # sum_j B[:, j] B[j, :] / denom[j] == B @ diag(1 / denom) @ B
    Z = ((B / denom) @ B) / nx

    evals, evecs = np.linalg.eigh(Z)
    mu = float(evals.min())
    invZ = (evecs * (1.0 / evals)) @ evecs.T
    inv_sqrtZ = (evecs * (1.0 / np.sqrt(evals))) @ evecs.T

    # Z_j Z^{-1} Z_j = (q_j / denom[j]^2) B[:, j] B[j, :]
    # with q_j = B[j, :] @ Z^{-1} @ B[:, j], so the sum is again B diag(.) B.
    q = np.sum((B @ invZ) * B.T, axis=1)
    M = ((B * (q / denom ** 2)) @ B) / nx

    nu = float(np.linalg.eigvalsh(inv_sqrtZ @ M @ inv_sqrtZ).max())
    return denom, mu, nu


def _kaczmarz_nesterov(B, g, ttau, denom, mu, nu, rng):
    """Approximately solve ``B d = -g`` with ``ttau`` accelerated Kaczmarz steps."""
    # --- Algorithm-1 parameters ---
    beta_acc = 1.0 - np.sqrt(mu / nu)
    gamma_acc = np.sqrt(1.0 / (mu * nu))
    alpha_acc = 1.0 / (1.0 + gamma_acc * nu)

    nx = B.shape[0]
    dx = np.zeros(nx)
    v = np.zeros(nx)
    for _ in range(ttau):
        dy = alpha_acc * v + (1.0 - alpha_acc) * dx
        j = rng.randint(nx)
        omega = ((B[j, :] @ dy + g[j]) / denom[j]) * B[:, j]
        dx = dy - omega
        v = beta_acc * v + (1.0 - beta_acc) * dy - gamma_acc * omega
    return dx


def StoNewton(c1, c2, Max_Iter, ttau, nx, X_true, RR, IdR, EPS=1e-8, solver="nesterov", rng=None):
    """Run stochastic Newton on synthetic linear regression; return the final iterate.

    Each iteration draws a_t ~ N(0, Sigma) and eps_t ~ N(0, 1) and forms
    barg_t = a_t (a_t . (x - X_true)) - eps_t a_t. This is the gradient of
    0.5 * (a_t . x - b_t)^2 with b_t = a_t . X_true + eps_t. The Hessian
    estimate is the running average (I + sum_s a_s a_s^T) / (t + 1). The
    Newton direction for step t uses the average from *before* sample t.

    Args:
        c1, c2: step size is beta_t = c1 / t ** c2.
        Max_Iter: the loop runs for t = 1, ..., Max_Iter + 1.
        ttau: number of Kaczmarz steps for the inexact Newton solve;
            0 means an exact ``np.linalg.solve``.
        nx: problem dimension.
        X_true: true parameter vector (length nx).
        RR, IdR: covariance parameter and structure selector
            (IdR in {0, 1, 2}; see ``_build_sigma``).
        EPS: floor for the vanilla Kaczmarz denominators.
        solver: "vanilla" or "nesterov" (only used when ttau != 0).
        rng: object with the ``numpy.random.RandomState`` API
            (``multivariate_normal``, ``normal``, ``randint``). ``None`` uses
            the global ``np.random`` state, as before.

    Returns:
        The final iterate X_t, an array of length nx.

    Raises:
        ValueError: on unsupported ``IdR``, or unknown ``solver`` when ttau != 0.
    """
    Sigma = _build_sigma(RR, nx, IdR)
    if ttau != 0 and solver not in ("vanilla", "nesterov"):
        raise ValueError("solver must be 'vanilla' or 'nesterov'")
    if rng is None:
        rng = np.random  # module-level functions share the global RandomState

    zero_mean = np.zeros(nx)
    X_t = np.ones(nx)
    hess_avg = np.eye(nx)

    t = 1
    while t <= Max_Iter + 1:
        K_t = hess_avg  # Hessian estimate from previous samples only
        beta_t = c1 / (t ** c2)

        # Step 1: sample stochastic gradient + Hessian sample (linear regression)
        a_t = rng.multivariate_normal(zero_mean, Sigma)
        eps_t = rng.normal(0.0, 1.0)
        barg_t = a_t * (a_t @ (X_t - X_true)) - eps_t * a_t
        hess_avg = (t / (t + 1.0)) * hess_avg + (1.0 / (t + 1.0)) * np.outer(a_t, a_t)

        # Step 2: (inexact) Newton solve K_t d = -barg_t
        if ttau == 0:
            NewDir_t = np.linalg.solve(K_t + 1e-12 * np.eye(nx), -barg_t)
        elif solver == "vanilla":
            NewDir_t = _kaczmarz_vanilla(K_t, barg_t, ttau, EPS, rng)
        else:
            B = 0.5 * (K_t + K_t.T) + 1e-12 * np.eye(nx)
            denom, mu_t, nu_t_val = _nesterov_parameters(B)
            if t == Max_Iter + 1:
                print(f"mu_t={mu_t:.6g}, nu_t={nu_t_val:.6g}")
            NewDir_t = _kaczmarz_nesterov(B, barg_t, ttau, denom, mu_t, nu_t_val, rng)

        # Step 3: update
        X_t = X_t + beta_t * NewDir_t
        t += 1

    return X_t


def _matched_quantiles(sample, n):
    """Return ``n`` sorted quantiles of ``sample``.

    If the sample already has ``n`` points they are returned exactly
    (sorted); otherwise the sample is linearly interpolated at ``n`` evenly
    spaced probabilities in [0, 1].
    """
    s = np.sort(sample)
    if s.size == n:
        return s
    return np.quantile(s, np.linspace(0.0, 1.0, n))


def qqplot_two_samples(x, y, title, fname, *, xlabel="Vanilla quantiles",
                       ylabel="Nesterov quantiles", show_title=False):
    """Save a two-sample QQ plot of ``y`` against ``x`` to ``fname``.

    Both inputs are flattened. With equal sizes the sorted samples are paired
    directly. With unequal sizes, the longer sample is interpolated to the
    length of the shorter one, so matching quantiles are compared rather than
    the smallest values of the longer sample. A dashed y = x reference line
    spanning both samples is drawn.

    Args:
        x, y: samples (any array-like).
        title: plot title; only drawn when ``show_title`` is true.
        fname: output image path (saved at 200 dpi).
        xlabel, ylabel: axis labels.
        show_title: whether to draw ``title``.

    Raises:
        ValueError: if either sample is empty.
    """
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    n = min(x.size, y.size)
    if n == 0:
        raise ValueError("qqplot_two_samples requires two non-empty samples")

    xs = _matched_quantiles(x, n)
    ys = _matched_quantiles(y, n)
    lo = min(xs.min(), ys.min())
    hi = max(xs.max(), ys.max())

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.plot(xs, ys, marker='.', linestyle='none')
        ax.plot([lo, hi], [lo, hi], linestyle='--')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if show_title:
            ax.set_title(title)
        ax.grid(True, linestyle="--", linewidth=0.5)
        fig.tight_layout()
        fig.savefig(fname, dpi=200)
    finally:
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)


def main(seed=None):
    """Compare vanilla and Nesterov inexact stochastic Newton by Monte Carlo.

    Runs ``Exp_num`` independent ``StoNewton`` experiments per solver (all
    vanilla runs first, then all Nesterov runs) and records
    y = sqrt(1 / beta_T) * w^T (x_T - X_true) for each run. It then prints
    summary statistics, saves both samples to a tab-separated text file, and
    writes a QQ plot.

    Args:
        seed: if not None, passed to ``np.random.seed`` before the runs so the
            experiment is reproducible. The default None leaves the global
            RNG untouched, as before.
    """
    # ---- experiment config ----
    c1, c2 = 1.0, 0.501
    T = 100000
    Exp_num = 200

    nx = 10
    X_true = np.linspace(0, 1, nx)
    w = np.ones(nx)

    beta_T = c1 / (T ** c2)
    scale = np.sqrt(1 / beta_T)

    RR, IdR = 0.4, 0
    ttau = 5

    # Output-file tag derived from the config so names cannot drift from it.
    # NOTE(review): the meaning of "munu_1" is not visible here; kept verbatim.
    tag = f"nx_{nx}_tau_{ttau}_munu_1"

    if seed is not None:
        np.random.seed(seed)

    # ---- run experiments ----
    # Solvers run sequentially (all vanilla, then all nesterov), preserving the
    # order in which random draws are consumed.
    results = {}
    for solver in ("vanilla", "nesterov"):
        y = np.zeros(Exp_num)
        for k in range(Exp_num):
            print(f"[{k+1}/{Exp_num}] {solver} ...")
            xT = StoNewton(c1, c2, T, ttau, nx, X_true, RR, IdR, solver=solver)
            y[k] = float(scale * w @ (xT - X_true))
        results[solver] = y
    y_vanilla = results["vanilla"]
    y_nesterov = results["nesterov"]

    mean_v = float(np.mean(y_vanilla))
    mean_n = float(np.mean(y_nesterov))
    var_v = float(np.var(y_vanilla, ddof=1))
    var_n = float(np.var(y_nesterov, ddof=1))
    cov_mat = np.cov(np.vstack([y_vanilla, y_nesterov]), ddof=1)

    print("\n=== Summary for y = w^T x_T ===")
    print(f"Vanilla : mean={mean_v:.6g}, var={var_v:.6g}")
    print(f"Nesterov: mean={mean_n:.6g}, var={var_n:.6g}")
    print("\nCovariance matrix of [y_vanilla, y_nesterov]^T across experiments:")
    print(cov_mat)

    out_txt = f"y_vanilla_y_nesterov_{tag}.txt"
    data_out = np.column_stack([y_vanilla, y_nesterov])
    header = (
        "col0: y_vanilla, col1: y_nesterov\n"
        f"T={T}, Exp_num={Exp_num}, nx={nx}, ttau={ttau}, RR={RR}, IdR={IdR}, "
        f"c1={c1}, c2={c2}, beta_T={beta_T}\n"
        "Each row is one experiment."
    )
    np.savetxt(out_txt, data_out, fmt="%.18e", delimiter="\t", header=header)
    print(f"Saved y arrays to: {out_txt}")

    # QQ plot
    qqplot_two_samples(
        y_vanilla, y_nesterov,
        title=r"QQ plot (linear regression using Kaczmarz sampling)",
        fname=f"qq_linear_kazmaz_{tag}.png"
    )


if __name__ == "__main__":
    # Run the full experiment only when executed as a script, not on import.
    main()


