import numpy as np
from scipy import optimize
from scipy.sparse import eye
from sklearn.utils.extmath import safe_sparse_dot

def two_loops(grad_x, m, s_list, y_list, mu_list, h0, B0):
    """Return the L-BFGS search direction ``-H grad_x`` via the two-loop recursion.

    The initial inverse-Hessian approximation is the scaled identity
    ``h0 * I``.

    Parameters
    ----------
    grad_x : ndarray
        Gradient at the current point. It is copied and never modified.
    m : int
        Memory size. Unused here; trimming the lists is the caller's job.
    s_list, y_list : list of ndarray
        Stored step vectors and gradient differences, oldest first.
    mu_list : list of float
        One coefficient per stored pair. The caller in this file stores
        ``1 / (y . s)``.
    h0 : scalar
        Scale of the initial inverse-Hessian approximation.
    B0 :
        Unused; kept for interface compatibility.

    Returns
    -------
    ndarray
        The negated result of the recursion, i.e. the descent direction.
    """
    direction = grad_x.copy()
    coefs = []

    # Backward pass: newest pair first.
    backward = zip(reversed(s_list), reversed(y_list), reversed(mu_list))
    for s_vec, y_vec, rho in backward:
        a = rho * safe_sparse_dot(s_vec, direction)
        direction -= a * y_vec
        coefs.append(a)

    # Apply the initial inverse Hessian h0 * I (creates a new array).
    direction = h0 * direction

    # Forward pass: oldest pair first, reusing the coefficients in reverse.
    forward = zip(s_list, y_list, mu_list, reversed(coefs))
    for s_vec, y_vec, rho, a in forward:
        b = rho * safe_sparse_dot(y_vec, direction)
        direction += (a - b) * s_vec

    return -direction

def blfoa1(
        x0,
        f,
        f_grad,
        f_hessian,
        max_iter=100,
        m=2,
        tol=1e-6,
        tol_norm=None,
        maxls=10,
        inverse_direction_fun=None,
        warm_restart_lists=None,
        ito=10,
        maxiter_hg=10,
        h0=1
):
    """Run a fixed-step limited-memory quasi-Newton iteration.

    The first iteration takes a tiny gradient step, ``1e-4 * -grad``. Its
    (s, y) pair is not stored. Every later iteration takes the step
    ``0.1 * d``, where ``d`` is the direction returned by ``two_loops``
    from the stored curvature pairs. No line search is performed.

    Parameters
    ----------
    x0 : array_like
        Starting point. It is copied, so the caller's array is never
        modified. Integer input is promoted to float.
    f : callable
        Objective. Accepted for interface compatibility; not evaluated.
    f_grad : callable
        ``f_grad(x)`` returns the gradient at ``x``. Presumably it has the
        same shape as ``x``; confirm against callers.
    f_hessian, maxls, inverse_direction_fun, ito, maxiter_hg :
        Accepted for interface compatibility; currently unused.
    max_iter : int
        Number of iterations to run. Values below 1 run none.
    m : int
        Maximum number of (s, y) curvature pairs kept in memory.
    tol, tol_norm :
        Gradient-convergence tolerance and norm. The convergence test is
        disabled in this version, so these never stop the iteration early.
    warm_restart_lists : [s_list, y_list, mu_list] or None
        Curvature memory from a previous call. The lists are copied, so
        the caller's lists are not mutated.
    h0 : scalar
        Initial inverse-Hessian scale passed to ``two_loops``.

    Returns
    -------
    all_x_k : ndarray
        The stacked iterates, with ``x0`` first.
    hess_inv : callable
        ``hess_inv(x)`` applies the final L-BFGS inverse-Hessian
        approximation to ``x``.
    warm_restart_lists : list
        ``[s_list, y_list, mu_list]``, usable to warm-start a later call.
    """
    # Work on a private copy so the in-place `x += s` updates below never
    # touch the caller's x0. Promote non-float dtypes so `+=` works.
    x = np.array(x0, copy=True)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)

    all_x_k = [x.copy()]
    # Passed through to two_loops, which scales by h0 and ignores B0.
    B0 = eye(len(x))
    grad_x = f_grad(x)

    if warm_restart_lists is None:
        s_list, y_list, mu_list = [], [], []
    else:
        # Copy so append/pop below do not mutate the caller's lists.
        s_list, y_list, mu_list = (list(lst) for lst in warm_restart_lists)

    # Bound here so the final report works even when the loop runs zero times.
    k = 0
    for k in range(1, max_iter + 1):
        if k < 2:
            # Bootstrap with a tiny gradient step. Its curvature pair is not stored.
            d = -grad_x
            s = 1e-4 * d
            x += s
            new_grad = f_grad(x)
        else:
            d = two_loops(grad_x, m, s_list, y_list, mu_list, h0, B0)
            s = 0.1 * d
            x += s
            new_grad = f_grad(x)

            y = new_grad - grad_x
            ys = safe_sparse_dot(y, s)
            # Keep only pairs with y.s > 1e-10, so mu = 1/(y.s) is positive
            # and bounded.
            if ys > 1e-10:
                y_list.append(y.copy())
                s_list.append(s.copy())
                mu_list.append(1 / ys)
            # Drop the oldest pairs until at most m remain. This also covers
            # oversized warm-restart memory.
            while len(y_list) > m:
                y_list.pop(0)
                s_list.pop(0)
                mu_list.pop(0)

        all_x_k.append(x.copy())
        # Early stopping on tol_norm(new_grad) < tol is disabled in this
        # version. Every call runs exactly max_iter iterations.
        grad_x = new_grad

    def hess_inv(x):
        # two_loops returns -H x; negate to get H x.
        return -two_loops(x, m, s_list, y_list, mu_list, h0, B0)

    print(f'{k} iterates')
    return np.array(all_x_k), hess_inv, [s_list, y_list, mu_list]
