import numpy as np
from scipy.linalg import lu, inv, norm
from numpy.random import uniform, normal, multivariate_normal
from scipy.stats import bernoulli, multivariate_normal
import time
import matplotlib.pyplot as plt



def StoNewton(c1, c2, Max_Iter, ttau, nx, X_true, RR, IdR, EPS=1e-8,
              solver="nesterov", pert=2.0, mu=0.1, nu=10.0):
    """Run a stochastic Newton iteration with an exact or sketched inner solve.

    At each outer step a Gaussian feature vector and a +/-1 label are sampled,
    a stochastic gradient and Hessian are formed (logistic terms plus a
    quadratic perturbation ``pert * (X_t - X_true)`` / ``pert * I``), the
    running Hessian average is updated, and the iterate moves by
    ``beta_t * NewDir_t`` where ``NewDir_t`` approximately solves
    ``K_t d = -gradient`` with ``K_t`` the Hessian average *before* the
    current sample was folded in.

    Args:
        c1, c2: Step-size constants; the step size is ``c1 / t ** c2``.
        Max_Iter: The loop runs for ``t = 1, ..., Max_Iter + 1``.
        ttau: Number of randomized inner-solver steps per outer iteration.
            ``0`` selects an exact ``np.linalg.solve`` (damped by ``1e-3``
            while ``t <= 1000``).
        nx: Problem dimension.
        X_true: Vector entering the perturbation term of the gradient.
        RR: Correlation parameter of the feature covariance.
        IdR: Covariance type: ``1`` -> entries ``RR ** |i - j|``;
            ``0`` -> ``RR * ones + (1 - RR) * I + I``; ``2`` -> identity.
        EPS: Unused; kept for interface compatibility.
        solver: ``"vanilla"`` or ``"nesterov"``; only consulted if ``ttau != 0``.
        pert: Coefficient of the quadratic perturbation.
        mu, nu: Constants used to derive the acceleration parameters of the
            ``"nesterov"`` solver. If either is ``None`` the missing value is
            computed from ``K_t`` at every outer iteration (one
            eigen-decomposition per iteration, so noticeably slower).

    Returns:
        The final iterate ``X_t`` as a 1-D array of length ``nx``.

    Raises:
        ValueError: If ``IdR`` is not 0, 1 or 2, or if ``ttau != 0`` and
            ``solver`` is not one of the supported names.
    """
    Sigma = _build_feature_covariance(IdR, RR, nx)

    X_t = np.ones(nx)
    Pert = pert * np.eye(nx)
    zero_mean = np.zeros(nx)
    cum_bar2x_f_t = np.eye(nx)

    t = 1
    while t <= Max_Iter + 1:
        # The Newton system uses the Hessian average from *before* this sample.
        K_t = cum_bar2x_f_t
        beta_t = c1 / (t ** c2)

        # Step 1: sample. atleast_1d keeps nx == 1 working, where the sampler
        # hands back a scalar instead of a length-1 vector.
        a_t = np.atleast_1d(multivariate_normal.rvs(mean=zero_mean, cov=Sigma))
        z_t = a_t @ X_t
        # NOTE(review): the label is drawn from the logistic model evaluated at
        # the current iterate X_t, not at X_true -- confirm this is intended.
        b_t = 2 * bernoulli.rvs(1 / (1 + np.exp(-z_t))) - 1
        barg_t = -b_t / (1 + np.exp(b_t * z_t)) * a_t + pert * (X_t - X_true)

        bar_nab_x2f_t = (
            np.outer(a_t, a_t) / ((1 + np.exp(z_t)) * (1 + np.exp(-z_t))) + Pert
        )
        cum_bar2x_f_t = (t / (t + 1)) * cum_bar2x_f_t + (1 / (t + 1)) * bar_nab_x2f_t

        # Step 2: (in)exact Newton solve
        if ttau == 0:
            NewDir_t = np.linalg.solve(K_t, -barg_t)
            if t <= 1e3:
                # Heavily damped direction during the first 1000 iterations.
                NewDir_t = 1e-3 * NewDir_t
        elif solver == "vanilla":
            NewDir_t = _sketch_solve_vanilla(K_t, barg_t, ttau)
        elif solver == "nesterov":
            denom = np.diag(K_t @ K_t)
            if mu is None or nu is None:
                mu_hat, nu_hat = _sketch_mu_nu(K_t, denom)
                mu_t = mu_hat if mu is None else mu
                nu_t_val = nu_hat if nu is None else nu
            else:
                mu_t, nu_t_val = mu, nu

            if t == Max_Iter + 1:
                print(f"mu_t={mu_t:.6g}, nu_t={nu_t_val:.6g}")

            NewDir_t = _sketch_solve_nesterov(K_t, denom, barg_t, ttau,
                                              mu_t, nu_t_val)
        else:
            raise ValueError("solver must be 'vanilla' or 'nesterov'")

        # Step 3: update
        X_t = X_t + beta_t * NewDir_t
        t += 1

    return X_t


def _build_feature_covariance(IdR, RR, nx):
    """Return the (nx, nx) covariance of the sampled feature vectors."""
    if IdR == 1:
        return np.array([[RR ** abs(i - j) for j in range(nx)] for i in range(nx)])
    if IdR == 0:
        # NOTE(review): the trailing "+ I" makes every diagonal entry 2 rather
        # than 1 -- kept as in the original, confirm it is intended.
        return RR * np.ones((nx, nx)) + (1 - RR) * np.eye(nx) + np.eye(nx)
    if IdR == 2:
        return np.eye(nx)
    raise ValueError("Unsupported IdR")


def _sketch_solve_vanilla(K, g, n_steps):
    """Approximate the solution of ``K d = -g`` by random single-row updates.

    Starting from zero, each step picks an index ``j`` uniformly at random and
    removes the residual of equation ``j`` along column ``j`` of ``K``, scaled
    by ``(K @ K)[j, j]``.
    """
    size = K.shape[0]
    # K does not change during the inner loop, so K @ K is formed only once.
    K2 = K @ K
    d = np.zeros(size)
    for _ in range(n_steps):
        j = np.random.choice(size)
        d = d - (K[j, :] @ d + g[j]) * K[:, j] / K2[j, j]
    return d


def _sketch_solve_nesterov(K, denom, g, n_steps, mu, nu):
    """Accelerated variant of the random single-row solver for ``K d = -g``.

    ``denom`` is the diagonal of ``K @ K``; ``mu`` and ``nu`` determine the
    momentum parameters.
    """
    # --- Algorithm-1 parameters ---
    beta_acc = 1.0 - np.sqrt(mu / nu)
    gamma_acc = np.sqrt(1.0 / (mu * nu))
    alpha_acc = 1.0 / (1.0 + gamma_acc * nu)

    size = K.shape[0]
    dx = np.zeros(size)
    v = np.zeros(size)
    for _ in range(n_steps):
        dy = alpha_acc * v + (1.0 - alpha_acc) * dx
        j = np.random.choice(size)
        omega = ((K[j, :] @ dy + g[j]) / denom[j]) * K[:, j]
        dx = dy - omega
        v = beta_acc * v + (1.0 - beta_acc) * dy - gamma_acc * omega
    return dx


def _sketch_mu_nu(B, denom):
    """Compute ``(mu, nu)`` for the coordinate sketch of ``B``.

    With ``Z_j = outer(B[:, j], B[j, :]) / denom[j]`` and ``Z = mean_j Z_j``,
    ``mu`` is the smallest eigenvalue of ``Z`` and ``nu`` is the largest
    eigenvalue of ``Z^{-1/2} (mean_j Z_j Z^{-1} Z_j) Z^{-1/2}``.
    """
    size = B.shape[0]
    Zj_list = [np.outer(B[:, j], B[j, :]) / denom[j] for j in range(size)]
    Z = np.zeros((size, size))
    for Zj in Zj_list:
        Z += Zj
    Z /= size

    evals, evecs = np.linalg.eigh(Z)
    mu_val = float(evals.min())

    # Z^{-1} and Z^{-1/2} assembled from the eigen-decomposition of Z.
    invZ = (evecs * (1.0 / evals)) @ evecs.T
    inv_sqrtZ = (evecs * (1.0 / np.sqrt(evals))) @ evecs.T

    M = np.zeros((size, size))
    for Zj in Zj_list:
        M += Zj @ (invZ @ Zj)
    M /= size

    nu_val = float(np.linalg.eigvalsh(inv_sqrtZ @ M @ inv_sqrtZ).max())
    return mu_val, nu_val



def qqplot_two_samples(x, y, title, fname):
    """Draw, save and show a two-sample QQ plot of ``x`` against ``y``.

    Both inputs are flattened and sorted. When the samples have the same
    size, the sorted values are paired directly. When the sizes differ, the
    smaller sample's order statistics are paired with the quantiles of the
    larger sample at the probability levels ``(i + 0.5) / n`` (``n`` being
    the smaller size), so the full range of both samples is represented
    instead of only the lowest ``n`` values of the larger one.

    Args:
        x: Values plotted on the horizontal axis (labelled "Vanilla quantiles").
        y: Values plotted on the vertical axis (labelled "Nesterov quantiles").
        title: Figure title.
        fname: Path the figure is written to via ``plt.savefig`` (dpi 200).

    Raises:
        ValueError: If either sample is empty.
    """
    xs = np.sort(np.asarray(x).ravel())
    ys = np.sort(np.asarray(y).ravel())
    n = min(xs.size, ys.size)
    if n == 0:
        raise ValueError("qqplot_two_samples needs two non-empty samples")

    if xs.size != ys.size:
        # Evaluate the larger sample at the plotting positions of the smaller.
        probs = (np.arange(n) + 0.5) / n
        if xs.size > n:
            xs = np.quantile(xs, probs)
        else:
            ys = np.quantile(ys, probs)

    plt.figure(figsize=(6, 6))
    plt.plot(xs, ys, marker='.', linestyle='none')
    # Dashed y = x reference line spanning the joint range of both samples.
    lo = min(xs.min(), ys.min())
    hi = max(xs.max(), ys.max())
    plt.plot([lo, hi], [lo, hi], linestyle='--')
    plt.xlabel("Vanilla quantiles")
    plt.ylabel("Nesterov quantiles")
    plt.title(title)
    plt.grid(True, linestyle="--", linewidth=0.5)
    plt.tight_layout()
    plt.savefig(fname, dpi=200)
    plt.show()


def main():
    """Compare the vanilla and Nesterov inner solvers over repeated runs.

    For each solver (all "nesterov" runs first, then all "vanilla" runs),
    ``Exp_num`` independent ``StoNewton`` runs are performed and the scalar
    ``sqrt(1 / beta_T) * w @ (x_T - X_true)`` is recorded. Summary statistics
    are printed, both samples are written to a tab-separated text file, and a
    QQ plot of the two samples is produced.
    """
    # ---- experiment config ----
    c1, c2 = 1.0, 0.501
    T = 100000
    Exp_num = 200

    nx = 10
    X_true = np.linspace(0, 1, nx)

    RR, IdR = 0.4, 0
    ttau = 5

    beta_T = c1 / (T ** c2)
    # Weight vector w = ones, pre-multiplied by the sqrt(1 / beta_T) scaling.
    scaled_w = np.sqrt(1 / beta_T) * np.ones(nx)

    def scaled_error(solver_name):
        """Run one experiment and return the scaled weighted error."""
        x_final = StoNewton(c1, c2, T, ttau, nx, X_true, RR, IdR,
                            solver=solver_name)
        return float(scaled_w @ (x_final - X_true))

    # ---- run experiments ----
    y_nesterov = np.zeros(Exp_num)
    y_vanilla = np.zeros(Exp_num)

    k = 0
    while k < Exp_num:
        print(f"[{k+1}/{Exp_num}] nesterov ...")
        y_nesterov[k] = scaled_error("nesterov")
        k += 1

    k = 0
    while k < Exp_num:
        print(f"[{k+1}/{Exp_num}] vanilla ...")
        y_vanilla[k] = scaled_error("vanilla")
        k += 1

    # ---- summary statistics ----
    mean_v = float(y_vanilla.mean())
    var_v = float(y_vanilla.var(ddof=1))
    mean_n = float(y_nesterov.mean())
    var_n = float(y_nesterov.var(ddof=1))
    joint_cov = np.cov(y_vanilla, y_nesterov, ddof=1)

    print("\n=== Summary for y = w^T x_T ===")
    print(f"Vanilla : mean={mean_v:.6g}, var={var_v:.6g}")
    print(f"Nesterov: mean={mean_n:.6g}, var={var_n:.6g}")
    print("\nCovariance matrix of [y_vanilla, y_nesterov]^T across experiments:")
    print(joint_cov)

    # ---- persist the raw samples ----
    out_txt = "y_vanilla_y_nesterov_nx_10_tau_5_munu_1.txt"
    file_header = (
        "col0: y_vanilla, col1: y_nesterov\n"
        f"T={T}, Exp_num={Exp_num}, nx={nx}, ttau={ttau}, RR={RR}, IdR={IdR}, "
        f"c1={c1}, c2={c2}, beta_T={beta_T}\n"
        "Each row is one experiment."
    )
    np.savetxt(
        out_txt,
        np.stack([y_vanilla, y_nesterov], axis=1),
        fmt="%.18e",
        delimiter="\t",
        header=file_header,
    )
    print(f"Saved y arrays to: {out_txt}")

    # QQ plot
    qqplot_two_samples(
        y_vanilla,
        y_nesterov,
        title=r"QQ plot (logistic regression using Kazmaz sampling)",
        fname="qq_logistic_kazmaz_nx_10_tau_5_munu_1.png",
    )


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


