import numpy as np
from scipy.linalg import lu, inv, norm
from numpy.random import uniform, normal, multivariate_normal
import time
import matplotlib.pyplot as plt


def _design_covariance(nx, RR, IdR):
    """Return the nx-by-nx covariance matrix of the regression features.

    IdR == 0 -> RR on every off-diagonal entry and 2 on the diagonal
                (RR * ones + (1 - RR) * I + I);
    IdR == 1 -> Toeplitz matrix with entries RR ** |i - j|;
    IdR == 2 -> identity.
    Any other IdR raises ValueError.
    """
    if IdR == 1:
        idx = np.arange(nx)
        return RR ** np.abs(idx[:, None] - idx[None, :])
    if IdR == 0:
        return RR * np.ones((nx, nx)) + (1 - RR) * np.eye(nx) + np.eye(nx)
    if IdR == 2:
        return np.eye(nx)
    raise ValueError("Unsupported IdR")


def _gaussian_factor(Sigma):
    """Return F with F.T @ F == Sigma, so z @ F ~ N(0, Sigma) for z ~ N(0, I).

    Sigma is symmetric PSD, so its SVD V S V^T gives F = sqrt(S) V^T. The
    factor is computed once per run instead of re-factorizing Sigma on every
    draw.
    """
    _, s, vh = np.linalg.svd(Sigma)
    return np.sqrt(s)[:, None] * vh


def _kaczmarz(K, g, ttau, denom):
    """Approximate d solving K d = -g with ttau randomized Kaczmarz steps from d = 0.

    denom[j] must equal (K @ K)[j, j]. K is symmetric in StoNewton (an
    average of the identity and outer products), so column j equals row j.
    """
    nx = K.shape[0]
    d = np.zeros(nx)
    for _ in range(ttau):
        j = np.random.randint(nx)
        d = d - (K[j, :] @ d + g[j]) * K[:, j] / denom[j]
    return d


def _accel_params(B, denom):
    """Return (alpha, beta, gamma) for the accelerated Kaczmarz iteration on B.

    With Z_j = B[:, j] B[j, :] / denom[j] and Z = mean_j Z_j:
    mu = lambda_min(Z), nu = lambda_max(Z^{-1/2} M Z^{-1/2}) where
    M = mean_j Z_j Z^{-1} Z_j.
    """
    nx = B.shape[0]
    # sum_j B[:, j] B[j, :] / denom[j]  ==  B @ diag(1 / denom) @ B
    Z = (B / denom) @ B / nx

    evals, evecs = np.linalg.eigh(Z)
    mu = float(evals.min())
    invZ = (evecs * (1.0 / evals)) @ evecs.T
    inv_sqrtZ = (evecs * (1.0 / np.sqrt(evals))) @ evecs.T

    # Z_j is rank one, so Z_j Z^{-1} Z_j = (B[j, :] Z^{-1} B[:, j] / denom[j]) Z_j.
    # This turns the O(nx^4) per-j matrix products into O(nx^3) work.
    s = np.einsum("ji,ij->j", B, invZ @ B)
    M = (B * (s / denom ** 2)) @ B / nx

    nu = float(np.linalg.eigvalsh(inv_sqrtZ @ M @ inv_sqrtZ).max())

    beta = 1.0 - np.sqrt(mu / nu)
    gamma = np.sqrt(1.0 / (mu * nu))
    alpha = 1.0 / (1.0 + gamma * nu)
    return alpha, beta, gamma


def _accel_kaczmarz(B, g, ttau, denom):
    """Approximate d solving B d = -g with ttau accelerated Kaczmarz steps from d = 0.

    denom[j] must equal (B @ B)[j, j].
    """
    alpha, beta, gamma = _accel_params(B, denom)
    nx = B.shape[0]
    dx = np.zeros(nx)
    v = np.zeros(nx)
    for _ in range(ttau):
        dy = alpha * v + (1.0 - alpha) * dx
        j = np.random.randint(nx)
        omega = ((B[j, :] @ dy + g[j]) / denom[j]) * B[:, j]
        dx = dy - omega
        v = beta * v + (1.0 - beta) * dy - gamma * omega
    return dx


def StoNewton(c1, c2, Max_Iter, ttau, nx, X_true, RR, IdR, EPS=1e-8,
              solver="nesterov", Buff=30000, coverage_out=None,
              length_out=None):
    """Run one stochastic (inexact) Newton experiment on Gaussian linear regression.

    Each iteration draws a feature vector a_t ~ N(0, Sigma) and noise
    eps_t ~ N(0, 1), and forms the stochastic gradient
    a_t a_t^T (X_t - X_true) - eps_t a_t. It then steps along a direction d
    that (approximately) solves K_t d = -gradient. K_t is the running
    average of a_s a_s^T, initialised with the identity. The step size is
    c1 / t**c2, the start point is the all-ones vector, and the loop runs
    Max_Iter + 1 iterations.

    After ``Buff`` burn-in iterations, (1/step)-weighted running moments of
    the iterates are accumulated into a sample-covariance estimate. That
    estimate gives an interval of half-width 1.96 * sqrt(beta_t * cov)
    around the mean coordinate of the last iterate. IdSC reports whether
    the interval contains the mean coordinate of X_true.

    Args:
        c1, c2: step-size constants, beta_t = c1 / t**c2.
        Max_Iter: the loop runs Max_Iter + 1 iterations; must be >= 0.
        ttau: 0 for an exact linear solve; otherwise the number of Kaczmarz
            steps used to approximate the Newton direction.
        nx: problem dimension.
        X_true: true parameter vector of length nx.
        RR: correlation parameter of Sigma.
        IdR: Sigma structure selector (0, 1 or 2; see _design_covariance).
        EPS: unused; kept for signature compatibility.
        solver: "vanilla" (randomized Kaczmarz) or "nesterov" (accelerated
            Kaczmarz). Only consulted when ttau != 0.
        Buff: number of burn-in iterations excluded from the covariance
            estimate.
        coverage_out, length_out: lists that receive IdSC and the interval
            length. When omitted, the module-level lists ``CoverageSC`` /
            ``SClength`` are used if they exist; the experiment driver in
            this file relies on that.

    Returns:
        (ErrX, IdSC): ErrX[k] is ||X - X_true|| after iteration k + 1;
        IdSC is the coverage indicator.

    Raises:
        ValueError: unsupported IdR or solver, or negative Max_Iter.
    """
    Sigma = _design_covariance(nx, RR, IdR)
    if ttau != 0 and solver not in ("vanilla", "nesterov"):
        raise ValueError("solver must be 'vanilla' or 'nesterov'")
    if Max_Iter < 0:
        raise ValueError("Max_Iter must be non-negative")
    sample_factor = _gaussian_factor(Sigma)

    X_t = np.ones(nx)
    AvgX_t = np.zeros(nx)
    # Running (1/beta_t)-weighted moments of the post-burn-in iterates.
    W_t = np.zeros((nx, nx))
    nu_t = np.zeros(nx)
    u_t = 0.0

    cum_hess = np.eye(nx)
    ErrX, ErrAvgX = [], []

    for t in range(1, Max_Iter + 2):
        # The direction uses the Hessian average from *before* this sample.
        K_t = cum_hess
        beta_t = c1 / (t ** c2)

        # Step 1: sample a regression pair and form the stochastic gradient.
        a_t = np.random.standard_normal(nx) @ sample_factor
        eps_t = np.random.normal(0.0, 1.0)
        barg_t = a_t * (a_t @ (X_t - X_true)) - eps_t * a_t
        cum_hess = (t / (t + 1)) * cum_hess + (1 / (t + 1)) * np.outer(a_t, a_t)

        # Step 2: (inexact) Newton direction for K_t d = -barg_t.
        if ttau == 0:
            NewDir_t = np.linalg.solve(K_t, -barg_t)
        else:
            # diag(K_t @ K_t) without forming the product.
            denom = np.einsum("ij,ji->i", K_t, K_t)
            if solver == "vanilla":
                NewDir_t = _kaczmarz(K_t, barg_t, ttau, denom)
            else:
                NewDir_t = _accel_kaczmarz(K_t, barg_t, ttau, denom)

        # Step 3: update.
        X_t = X_t + beta_t * NewDir_t

        # Step 4: running moments over iterates after the burn-in. n counts
        # the samples seen so far, so the first window sample gets weight 1
        # (the zero initial values are not averaged in).
        if t > Buff:
            n = t - Buff
            keep, new = (n - 1) / n, 1 / n
            W_t = keep * W_t + new * np.outer(X_t, X_t) / beta_t
            nu_t = keep * nu_t + new * X_t / beta_t
            u_t = keep * u_t + new / beta_t
            AvgX_t = keep * AvgX_t + new * X_t

        ErrAvgX.append(np.linalg.norm(AvgX_t - X_true))
        ErrX.append(np.linalg.norm(X_t - X_true))

    # Sample covariance for the last iterates:
    # mean_i (x_i - a)(x_i - a)^T / beta_i, expanded through the moments.
    Xi_t_SC = (W_t - np.outer(nu_t, AvgX_t) - np.outer(AvgX_t, nu_t)
               + u_t * np.outer(AvgX_t, AvgX_t))
    # PSD in exact arithmetic; clip rounding-level negatives so sqrt stays real.
    COV_value_SC = max(float(np.sum(Xi_t_SC)) / nx ** 2, 0.0)
    radius_SC = 1.96 * np.sqrt(beta_t) * np.sqrt(COV_value_SC)

    mean_true = np.sum(X_true) / nx
    mean_est = np.sum(X_t) / nx
    IdSC = (mean_true >= (mean_est - radius_SC)) and \
        (mean_true <= (mean_est + radius_SC))

    if length_out is None:
        length_out = globals().get("SClength")
    if coverage_out is None:
        coverage_out = globals().get("CoverageSC")
    if length_out is not None:
        length_out.append(2 * radius_SC)
    if coverage_out is not None:
        coverage_out.append(IdSC)

    print("ErrX = ", ErrX[-1])
    print("ErrAvgX = ", ErrAvgX[-1])
    print("IdSC = ", IdSC, "; length = ", 2*radius_SC)

    return np.array(ErrX), IdSC


c1, c2 = 1.0, 0.501
Max_Iter = 100000
Exp_num = 200

ttau_list = [0, 5, 10]
nx_list = [20, 40]
IdR_list = [0, 1, 2]
RR = 0.4


filename = "Linear_regression_Kazmaz_Output.txt"

# The experiment grid is very long, so only run it when this file is
# executed as a script, not on import. The loop stays at module level:
# CoverageSC / SClength must be module globals because StoNewton appends
# each run's coverage flag and interval length to them.
if __name__ == "__main__":
    for nx in nx_list:                # nx = 20, 40
        X_true = np.linspace(0, 1, nx)
        for IdR in IdR_list:          # IdR = 0, 1, 2
            for ttau in ttau_list:    # tau = 0, 5, 10
                CoverageSC = []
                SClength = []
                Last_iterate_error = 0

                for i4 in range(Exp_num):
                    print(i4, "st iteration:  ttau = ", ttau, "; nx = ", nx, "; Covariance matrix with IdR = ", IdR, ".")
                    ErrX, IdSC = StoNewton(c1, c2, Max_Iter, ttau, nx, X_true, RR, IdR, solver="nesterov")
                    Last_iterate_error += ErrX[-1]
                    print("------------")

                Last_iterate_error = Last_iterate_error / Exp_num
                coverage = sum(CoverageSC) / len(CoverageSC)
                mean_length = np.mean(SClength)
                var_length = np.var(SClength)

                print("CoverageSC = ", coverage, ", Length = ", mean_length, ", (var = " + str(var_length) + ")" )
                print("Last error = ", Last_iterate_error)

                # Same text as before, appended in a single write.
                record = (
                    "nx = " + str(nx) + "; ttau = " + str(ttau) + "; "
                    + "IdR = " + str(IdR) + "; \n"
                    + "Last_iterate_error = " + str(Last_iterate_error) + "; \n"
                    + "1. Coverage_last_sample_covariance = " + str(coverage)
                    + "; Average Length = " + str(mean_length)
                    + "; (var = " + str(var_length) + ")" + "\n"
                    + "------------------------------------------------------------------  \n \n "
                )
                with open(filename, 'a', encoding="utf-8") as file:
                    file.write(record)
