import numpy as np
from scipy.linalg import lu, inv, norm
from numpy.random import uniform, normal, multivariate_normal
from scipy.stats import bernoulli, multivariate_normal
import time
import matplotlib.pyplot as plt


def _feature_covariance(IdR, nx, RR):
    """Return the covariance matrix used to draw the feature vectors a_t.

    IdR = 1: Toeplitz matrix with entries RR**|i - j|, plus 5 * I.
    IdR = 0: RR on every off-diagonal entry and 2 on the diagonal
             (RR * 11^T + (1 - RR) * I + I).
    IdR = 2: identity.

    Raises:
        ValueError: if IdR is not 0, 1 or 2.
    """
    if IdR == 1:
        return np.array([[RR ** abs(i - j) for j in range(nx)] for i in range(nx)]) + 5 * np.eye(nx)
    if IdR == 0:
        return RR * np.ones((nx, nx)) + (1 - RR) * np.eye(nx) + np.eye(nx)
    if IdR == 2:
        return np.eye(nx)
    raise ValueError("Unsupported IdR")


def _sketch_newton_vanilla(B, g, ttau, eps):
    """Approximately solve B dx = -g with `ttau` Gaussian sketch-and-project steps.

    Each step draws s ~ N(0, I), sets u = B s and projects the residual
    B dx + g onto the sketched equation:
        dx <- dx - (s^T (B dx + g) / ||B s||^2) * u.
    The denominator ||B s||^2 is floored at `eps`.
    """
    n = B.shape[0]
    dx = np.zeros(n)
    for _ in range(ttau):
        s = np.random.normal(size=n)
        u = B @ s
        denom = max(float(u @ u), eps)
        numer = float(s @ (B @ dx + g))
        dx = dx - (numer / denom) * u
    return dx


def _sketch_newton_nesterov(B, g, ttau, eps, n_mc):
    """Approximately solve B dx = -g with accelerated Gaussian sketch-and-project.

    The acceleration parameters need
        mu = lambda_min(E[Z])  and
        nu = lambda_max(E[Z]^{-1/2} E[Z E[Z]^{-1} Z] E[Z]^{-1/2}),
    where Z = u u^T / ||u||^2 with u = B s, s ~ N(0, I). Both expectations are
    replaced by averages over `n_mc` fresh sketches. The estimate of E[Z] has
    rank <= n_mc, so `n_mc` should exceed the dimension of B. Otherwise mu
    collapses to the 1e-12 floor below.

    Raises:
        ValueError: if n_mc < 1.
    """
    if n_mc < 1:
        raise ValueError("n_mc must be a positive integer")
    n = B.shape[0]

    # Monte Carlo estimate of E[Z]; keep every draw for the second moment below.
    draws = []
    Z = np.zeros((n, n))
    for _ in range(n_mc):
        s = np.random.normal(size=n)
        u = B @ s
        denom = max(float(u @ u), eps)
        Ztil = np.outer(u, u) / denom
        draws.append((u, denom, Ztil))
        Z += Ztil
    Z /= n_mc

    evals, evecs = np.linalg.eigh(Z)
    evals = np.maximum(evals, 1e-12)  # numerical floor for near-zero eigenvalues
    mu = float(evals.min())
    invZ = (evecs * (1.0 / evals)) @ evecs.T
    inv_sqrtZ = (evecs * (1.0 / np.sqrt(evals))) @ evecs.T

    # E[Z E[Z]^{-1} Z] = E[(u^T E[Z]^{-1} u / ||u||^2) * Z]
    M = np.zeros((n, n))
    for u, denom, Ztil in draws:
        M += (float(u @ (invZ @ u)) / denom) * Ztil
    M /= n_mc
    nu = float(np.linalg.eigvalsh(inv_sqrtZ @ M @ inv_sqrtZ).max())

    # Algorithm-1 parameters (accelerated sketch-and-project).
    beta_acc = 1.0 - np.sqrt(mu / nu)
    gamma_acc = np.sqrt(1.0 / (mu * nu))
    alpha_acc = 1.0 / (1.0 + gamma_acc * nu)

    dx = np.zeros(n)
    v = np.zeros(n)
    for _ in range(ttau):
        dy = alpha_acc * v + (1.0 - alpha_acc) * dx
        s = np.random.normal(size=n)
        u = B @ s
        denom = max(float(u @ u), eps)
        omega = (float(s @ (B @ dy + g)) / denom) * u
        dx = dy - omega
        v = beta_acc * v + (1.0 - beta_acc) * dy - gamma_acc * omega
    return dx


def StoNewton(c1, c2, Max_Iter, ttau, nx, X_true, RR, IdR, EPS=1e-8,
              solver="nesterov", pert=2.0, buff=30000, n_mc=None,
              sc_lengths=None, coverages=None):
    """Run one stochastic Newton experiment on simulated logistic regression.

    Each iteration t = 1, ..., Max_Iter + 1:
      1. Draws a_t ~ N(0, Sigma) and a label b_t in {-1, +1} with
         P(b_t = +1) = sigmoid(a_t^T X_true). It then forms the per-sample
         gradient of log(1 + exp(-b_t a_t^T x)) + pert/2 * ||x - X_true||^2 and
         the matching per-sample Hessian at the current iterate.
      2. Computes a Newton direction against the running Hessian average from
         previous iterations. The direction is either an exact solve (ttau == 0)
         or `ttau` sketch-and-project iterations (vanilla or accelerated).
      3. Takes a step of size beta_t = c1 / t**c2.
      4. For t > buff, updates running averages used to estimate the covariance
         of the scaled last iterate.

    A 95% interval for the coordinate mean of the last iterate is then formed.
    The function reports whether that interval covers mean(X_true).

    Args:
        c1, c2: step-size constants, beta_t = c1 / t**c2.
        Max_Iter: the loop runs Max_Iter + 1 iterations; must be >= 0.
        ttau: inner sketch iterations; 0 means an exact linear solve. In that
            case the step is scaled by 1e-3 for the first 1000 iterations.
        nx: problem dimension.
        X_true: true parameter (length nx).
        RR, IdR: covariance design for a_t (see `_feature_covariance`).
        EPS: lower bound on the sketch denominators ||B s||^2.
        solver: "vanilla" or "nesterov"; only consulted when ttau != 0.
        pert: weight of the quadratic term pert/2 * ||x - X_true||^2.
        buff: burn-in; iterates with t > buff enter the averages.
        n_mc: Monte Carlo sketches for the Nesterov parameters. The default,
            max(30, 2 * nx), keeps the estimate of E[Z] full rank.
        sc_lengths, coverages: optional lists that receive this run's interval
            length and coverage indicator. When omitted, the module-level lists
            `SClength` / `CoverageSC` are used if they exist; this keeps the
            original driver script working.

    Returns:
        (ErrX, IdSC): ||X_t - X_true|| for every iteration as an array, and
        whether the interval covers mean(X_true).

    Raises:
        ValueError: for Max_Iter < 0, an unknown solver (when ttau != 0), or
            an unsupported IdR.
    """
    from scipy.special import expit  # overflow-free logistic sigmoid

    if Max_Iter < 0:
        raise ValueError("Max_Iter must be non-negative")
    if ttau != 0 and solver not in ("vanilla", "nesterov"):
        raise ValueError("solver must be 'vanilla' or 'nesterov'")
    if n_mc is None:
        n_mc = max(30, 2 * nx)

    # Backward compatibility with the driver script, which reads results from
    # module-level lists instead of the return value.
    if sc_lengths is None:
        sc_lengths = globals().get("SClength")
    if coverages is None:
        coverages = globals().get("CoverageSC")

    Sigma = _feature_covariance(IdR, nx, RR)
    zero_mean = np.zeros(nx)
    Pert = pert * np.eye(nx)

    X_t = np.ones(nx)
    AvgX_t = np.zeros(nx)
    W_t = np.zeros((nx, nx))  # running avg of X X^T / beta
    nu_t = np.zeros(nx)       # running avg of X / beta
    u_t = 0.0                 # running avg of 1 / beta
    cum_bar2x_f_t = np.eye(nx)
    ErrX, ErrAvgX = [], []

    for t in range(1, Max_Iter + 2):
        # Newton matrix: Hessian average accumulated over previous iterations.
        K_t = cum_bar2x_f_t
        beta_t = c1 / (t ** c2)

        # Step 1: sample. scipy returns a scalar when nx == 1, hence atleast_1d.
        a_t = np.atleast_1d(multivariate_normal.rvs(mean=zero_mean, cov=Sigma))
        # Labels follow the logistic model at the TRUE parameter. Drawing them
        # at X_t would make the expected logistic gradient vanish at every X_t.
        b_t = 2 * bernoulli.rvs(expit(a_t @ X_true)) - 1
        z_t = a_t @ X_t
        # -b / (1 + exp(b z)) == -b * sigmoid(-b z)
        barg_t = -b_t * expit(-b_t * z_t) * a_t + pert * (X_t - X_true)
        # 1 / ((1 + e^z)(1 + e^-z)) == sigmoid(z) * sigmoid(-z)
        bar_nab_x2f_t = expit(z_t) * expit(-z_t) * np.outer(a_t, a_t) + Pert
        cum_bar2x_f_t = (t / (t + 1)) * cum_bar2x_f_t + (1 / (t + 1)) * bar_nab_x2f_t

        # Step 2: (inexact) Newton direction for K_t d = -g.
        if ttau == 0:
            NewDir_t = np.linalg.solve(K_t, -barg_t)
            if t <= 1e3:
                NewDir_t = 1e-3 * NewDir_t  # damped warm-up phase
        elif solver == "vanilla":
            NewDir_t = _sketch_newton_vanilla(K_t, barg_t, ttau, EPS)
        else:
            NewDir_t = _sketch_newton_nesterov(K_t, barg_t, ttau, EPS, n_mc)

        # Step 3: update.
        X_t = X_t + beta_t * NewDir_t

        # Step 4: running moments of the iterates after the burn-in.
        if t > buff:
            k = t - buff
            w_old, w_new = k / (k + 1), 1 / (k + 1)
            W_t = w_old * W_t + w_new * np.outer(X_t, X_t) / beta_t
            nu_t = w_old * nu_t + w_new * X_t / beta_t
            u_t = w_old * u_t + w_new / beta_t
            AvgX_t = w_old * AvgX_t + w_new * X_t

        ErrAvgX.append(np.linalg.norm(AvgX_t - X_true))
        ErrX.append(np.linalg.norm(X_t - X_true))

    # Sample covariance avg[(X - AvgX)(X - AvgX)^T / beta], expanded in terms
    # of the running moments above.
    Xi_t_SC = W_t - np.outer(nu_t, AvgX_t) - np.outer(AvgX_t, nu_t) + u_t * np.outer(AvgX_t, AvgX_t)
    # Variance of the coordinate mean (1^T Xi 1 / nx^2); clamp rounding noise
    # so sqrt never sees a negative number.
    COV_value_SC = max(float(np.sum(Xi_t_SC)) / nx ** 2, 0.0)
    radius_SC = 1.96 * np.sqrt(beta_t) * np.sqrt(COV_value_SC)

    mean_true = np.sum(X_true) / nx
    center = np.sum(X_t) / nx
    IdSC = (mean_true >= center - radius_SC) and (mean_true <= center + radius_SC)

    if sc_lengths is not None:
        sc_lengths.append(2 * radius_SC)
    if coverages is not None:
        coverages.append(IdSC)

    print("ErrX = ", ErrX[-1])
    print("ErrAvgX = ", ErrAvgX[-1])
    print("IdSC = ", IdSC, "; length = ", 2*radius_SC)

    return np.array(ErrX), IdSC


# Experiment configuration.
c1, c2 = 1.0, 0.501
Max_Iter = 100000
Exp_num = 200

ttau_list = [0, 5, 10]
nx_list = [20, 40]
IdR_list = [0, 1, 2]
RR = 0.4


filename = "Logistic_regression_Gaussian_Output.txt"

if __name__ == "__main__":
    # Grid over dimension x covariance design x inner sketch iterations.
    # For every configuration, run Exp_num independent replications and append
    # a summary to `filename`.
    # `CoverageSC` and `SClength` intentionally stay module-level names (an
    # `if` block does not create a scope): StoNewton appends each run's
    # coverage indicator and interval length to them.
    for nx in nx_list:   # nx = 20, 40
        X_true = np.linspace(0, 1, nx)
        for IdR in IdR_list:   # IdR = 0, 1, 2
            for ttau in ttau_list:   # tau = 0, 5, 10
                CoverageSC = []
                SClength = []
                Last_iterate_error = 0

                for i4 in range(Exp_num):
                    print(i4, "st iteration:  ttau = ", ttau, "; nx = ", nx, "; Covariance matrix with IdR = ", IdR, ".")
                    ErrX, IdSC = StoNewton(c1, c2, Max_Iter, ttau, nx, X_true, RR, IdR, solver="nesterov")
                    Last_iterate_error += ErrX[-1]
                    print("------------")

                Last_iterate_error = Last_iterate_error / Exp_num
                coverage = sum(CoverageSC) / len(CoverageSC)
                mean_length = np.mean(SClength)
                var_length = np.var(SClength)

                print("CoverageSC = ", coverage, ", Length = ", mean_length, ", (var = " + str(var_length) + ")" )
                print("Last error = ", Last_iterate_error)

                with open(filename, 'a', encoding="utf-8") as file:
                    file.write("nx = " + str(nx) + "; ttau = " + str(ttau) + "; ")
                    file.write( "IdR = " + str(IdR) + "; \n")
                    file.write("Last_iterate_error = " + str(Last_iterate_error) +"; \n")
                    file.write("1. Coverage_last_sample_covariance = " + str(
                        coverage) + "; Average Length = " + str(
                        mean_length) + "; (var = " + str(var_length) + ")"  + "\n")
                    file.write("------------------------------------------------------------------  \n \n ")



