import numpy as np
from scipy import optimize
from scipy.sparse import eye
from sklearn.utils.extmath import safe_sparse_dot

def two_loops(grad_x, m, s_list, y_list, mu_list, B0):
    '''
    Apply the L-BFGS two-loop recursion to a vector.

    The stored ``(s, y, mu)`` triples are walked from the newest to the
    oldest, then ``B0`` is applied, then the triples are walked again from
    the oldest to the newest.

    Parameters
    ----------
    grad_x : ndarray, shape (n,)
        vector the recursion is applied to (typically the gradient at the
        current point); it is copied, never modified

    m : int
        memory size; not used by this function, the whole of the supplied
        history is consumed

    s_list : list
        past values of s, ordered from oldest to newest

    y_list : list
        past values of y, ordered like ``s_list``

    mu_list : list
        past values of the scalar mu, ordered like ``s_list``
        (presumably 1 / (y . s) for the matching pair -- confirm with the
        caller)

    B0 : ndarray or sparse matrix, shape (n, n)
        Initial inverse Hessian guess

    Returns
    -------
    r : ndarray, shape (n,)
        the negated result of the recursion, i.e. the L-BFGS direction
    '''
    residual = grad_x.copy()

    # Backward pass: newest pair first. Coefficients are pushed on a stack
    # so the forward pass can pop them back in oldest-first order.
    alpha_stack = []
    newest_first = zip(reversed(s_list), reversed(y_list), reversed(mu_list))
    for step, grad_diff, scale in newest_first:
        coeff = scale * safe_sparse_dot(step, residual)
        alpha_stack.append(coeff)
        residual -= coeff * grad_diff

    direction = safe_sparse_dot(B0, residual)

    # Forward pass: oldest pair first.
    for step, grad_diff, scale in zip(s_list, y_list, mu_list):
        coeff = alpha_stack.pop()
        correction = coeff - scale * safe_sparse_dot(grad_diff, direction)
        direction += correction * step

    return -direction

def blfoa(
        x0,
        f,
        f_grad,
        f_hessian,
        max_iter=100,
        m=2,
        tol=1e-6,
        tol_norm=None,
        maxls=10,
        inverse_direction_fun=None,
        warm_restart_lists=None,
        ito=10,
):
    '''
    Gradient descent with a decaying step that also learns L-BFGS pairs.

    Each iteration ``k`` (starting at 1):

    1. takes the step ``x <- x - 0.01 / sqrt(k + 1) * f_grad(x)``;
    2. applies the current L-BFGS inverse-Hessian approximation (through
       ``two_loops``) to ``inverse_direction_fun(x)`` and normalises the
       result to a unit vector ``e``;
    3. stores the pair ``s = e``, ``y = f_grad(x + e) - f_grad(x)`` together
       with ``mu = 1 / (y . e)``, keeping at most ``m`` pairs;
    4. stops as soon as ``tol_norm(f_grad(x)) < tol``.

    Parameters
    ----------
    x0 : ndarray, shape (n,)
        starting point; it is not modified

    f : callable
        objective function; accepted for interface compatibility, not
        evaluated by this implementation

    f_grad : callable
        gradient of the objective, ``f_grad(x) -> ndarray, shape (n,)``.
        The gradient at each iterate is evaluated once and reused, so it is
        assumed to be deterministic

    f_hessian : callable
        accepted for interface compatibility, not used

    max_iter : int
        maximum number of iterations

    m : int
        memory size, i.e. maximum number of stored (s, y, mu) triples

    tol : float
        threshold on ``tol_norm`` of the gradient used as stopping criterion

    tol_norm : callable or None
        norm applied to the gradient for the stopping test; defaults to the
        largest absolute entry

    maxls : int
        accepted for interface compatibility, not used

    inverse_direction_fun : callable
        ``inverse_direction_fun(x)`` returns the vector the inverse-Hessian
        approximation is applied to. Required despite the ``None`` default

    warm_restart_lists : sequence of three lists or None
        ``[s_list, y_list, mu_list]`` from a previous call. The three lists
        are updated in place and are the same objects that are returned

    ito : int
        accepted for interface compatibility, not used

    Returns
    -------
    all_x_k : ndarray
        the iterates stacked along the first axis, starting with ``x0``

    e : ndarray or None
        last probe direction computed (``None`` when no iteration was run)

    warm_restart_lists : list
        ``[s_list, y_list, mu_list]`` after the run

    Raises
    ------
    TypeError
        if ``inverse_direction_fun`` is not callable
    '''
    if not callable(inverse_direction_fun):
        # Same exception type as calling None would raise inside the loop,
        # but raised up front and with an actionable message.
        raise TypeError(
            'inverse_direction_fun must be a callable mapping x to a vector, '
            f'got {inverse_direction_fun!r}'
        )

    if tol_norm is None:
        def tol_norm(v):
            return np.max(np.abs(v))

    default_step = 0.01

    x = x0
    all_x_k = [x.copy()]

    B0 = eye(len(x))  # initial inverse Hessian approximation

    if warm_restart_lists is None:
        s_list, y_list, mu_list = [], [], []
    else:
        s_list, y_list, mu_list = warm_restart_lists

    memory = max(m, 0)
    grad_x = f_grad(x)

    # Defined up front so that max_iter < 1 still prints and returns cleanly.
    k = 0
    e = None
    for k in range(1, max_iter + 1):
        # Step size decays like 1 / sqrt(k + 1).
        x = x - (default_step / ((k + 1) ** 0.5)) * grad_x
        grad_x = f_grad(x)

        # two_loops returns -H v, so negate to get H v.
        inverse_direction = inverse_direction_fun(x)
        e = -two_loops(inverse_direction, m, s_list, y_list, mu_list, B0)

        e_norm = np.linalg.norm(e)
        if np.isfinite(e_norm) and e_norm > 0:
            e = e / e_norm
            y_tilde = f_grad(x + e) - grad_x
            curvature = safe_sparse_dot(y_tilde, e)
            # A zero or non-finite curvature would store inf/nan in the
            # history and poison every later direction; skip such pairs.
            if np.isfinite(curvature) and curvature != 0:
                y_list.append(y_tilde.copy())
                s_list.append(e.copy())
                mu_list.append(1 / curvature)

                # Drop the oldest entries; also shrinks a warm-start
                # history that came in longer than the memory size.
                excess = len(y_list) - memory
                if excess > 0:
                    del y_list[:excess]
                    del s_list[:excess]
                    del mu_list[:excess]

        all_x_k.append(x.copy())
        if tol_norm(grad_x) < tol:
            break

    warm_restart_lists = [s_list, y_list, mu_list]
    print(f'{k} iterates')
    return np.array(all_x_k), e, warm_restart_lists
