import numpy as np
from scipy import optimize
from scipy.sparse import eye
from sklearn.utils.extmath import safe_sparse_dot

def two_loops(grad_x, m, s_list, y_list, mu_list, h0, B0):
    """
    Apply the L-BFGS two-loop recursion to a vector and negate the result.

    Parameters
    ----------
    grad_x : ndarray, shape (n,)
        vector the recursion is applied to (the gradient at the current
        point when a descent direction is wanted). It is copied first, so
        the caller's array is left untouched.

    m : int
        memory size. Not read by this function: the lists are used at
        whatever length they have.

    s_list : list
        the stored values of s (``blfoaa`` in this file keeps the oldest
        entry first)

    y_list : list
        the stored values of y, aligned with ``s_list``

    mu_list : list
        the stored values of mu, one per (s, y) pair (``blfoaa`` stores
        ``1 / dot(y, s)``)

    h0 : scalar
        factor applied to the intermediate vector between the two passes
        (``blfoaa`` passes a scalar, default 1)

    B0 : object
        not used by this function

    Returns
    -------
    r :  ndarray, shape (n,)
        the L-BFGS direction
    """
    work = grad_x.copy()

    # Backward pass: walk the memory from its last entry to its first and
    # remember each coefficient for the forward pass.
    pending = []
    backward = zip(reversed(s_list), reversed(y_list), reversed(mu_list))
    for step, grad_diff, rho in backward:
        coeff = rho * safe_sparse_dot(step, work)
        pending.append(coeff)
        work -= coeff * grad_diff

    direction = h0 * work

    # Forward pass: walk the memory first-to-last. ``pending`` is consumed
    # as a stack, so each pair meets the coefficient computed for it above.
    for step, grad_diff, rho in zip(s_list, y_list, mu_list):
        correction = rho * safe_sparse_dot(grad_diff, direction)
        direction += (pending.pop() - correction) * step

    return -direction

def _remember_pair(s_hist, y_hist, mu_hist, s, y, m):
    """Store the pair (s, y) in an L-BFGS memory if dot(y, s) is large enough.

    The three lists are modified in place. After the optional append, the
    oldest entry is dropped when the memory holds more than ``m`` pairs.
    """
    curvature = safe_sparse_dot(y, s)
    # The pair is only kept when 1 / curvature is a well-defined positive
    # number; otherwise the memory is left as it is.
    if curvature > 1e-8:
        y_hist.append(y.copy())
        s_hist.append(s.copy())
        mu_hist.append(1 / curvature)
    if len(y_hist) > m:
        y_hist.pop(0)
        s_hist.pop(0)
        mu_hist.pop(0)


def blfoaa(
        x0,
        f,
        f_grad,
        f_hessian,
        max_iter=100,
        m=2,
        tol=1e-6,
        tol_norm=None,
        maxls=10,
        inverse_direction_fun=None,
        warm_restart_lists=None,
        ito=10,
        maxiter_hg=10,
        h0=1
):
    """
    Run a fixed number of L-BFGS-style iterations, then a second phase that
    applies the two-loop recursion to ``inverse_direction_fun(x)``.

    Phase 1 performs exactly ``max_iter`` updates of ``x``: the first one is
    a gradient step scaled by 1e-4, every later one is ``0.1`` times the
    direction returned by ``two_loops``. There is no line search and no
    convergence test.

    Phase 2 runs only when ``inverse_direction_fun`` is given. It repeats
    ``maxiter_hg`` times: ``e = -two_loops(inverse_direction_fun(x), ...)``
    using a separate memory that starts empty, then the pair
    ``(e, f_grad(x + e) - f_grad(x))`` is offered to that memory.
    NOTE(review): ``e`` presumably approximates an inverse-Hessian/vector
    product - confirm against the caller.

    Parameters
    ----------
    x0 : array_like, shape (n,)
        starting point. It is copied (and promoted to float when it is not
        already floating point), so the caller's array is never modified.

    f : callable
        objective. Accepted for interface compatibility; it is not evaluated.

    f_grad : callable
        ``f_grad(x)`` returns the gradient at ``x``.

    f_hessian, tol, tol_norm, maxls, ito
        accepted for interface compatibility; not used.

    max_iter : int
        number of phase-1 updates.

    m : int
        memory size passed to ``two_loops`` and used to trim both memories.

    inverse_direction_fun : callable or None
        ``inverse_direction_fun(x)`` returns the vector used in phase 2.
        When ``None``, phase 2 is skipped.

    warm_restart_lists : sequence of three lists or None
        ``(s_list, y_list, mu_list)`` from a previous call. The lists are
        reused and modified in place.

    maxiter_hg : int
        number of phase-2 iterations.

    h0 : scalar
        scaling forwarded to ``two_loops``.

    Returns
    -------
    all_x_k : ndarray, shape (max_iter + 1, n)
        the starting point followed by every iterate.

    e : ndarray or None
        result of the last phase-2 iteration, or ``None`` when phase 2 did
        not run (no ``inverse_direction_fun`` or ``maxiter_hg < 1``).

    warm_restart_lists : list
        ``[s_list, y_list, mu_list]`` after phase 1.
    """
    # Work on a private floating-point copy: the updates below are in place
    # and must neither leak into the caller's array nor fail on integer input.
    x = np.array(x0)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)

    all_x_k = [x.copy()]

    B0 = eye(len(x))  # forwarded to two_loops

    grad_x = f_grad(x)

    if warm_restart_lists is not None:
        s_list, y_list, mu_list = warm_restart_lists
    else:
        s_list, y_list, mu_list = [], [], []
    s1_list, y1_list, mu1_list = [], [], []

    k = 0  # keeps the final report valid when max_iter < 1
    for k in range(1, max_iter + 1):
        if k == 1:
            # No direction from the memory on the first pass: take a small
            # plain gradient step. This pair is not stored in the memory.
            s = 0.0001 * (-grad_x)
            x += s
            new_grad = f_grad(x)
        else:
            d = two_loops(grad_x, m, s_list, y_list, mu_list, h0, B0)
            s = 0.1 * d
            x += s
            new_grad = f_grad(x)
            _remember_pair(s_list, y_list, mu_list, s, new_grad - grad_x, m)

        all_x_k.append(x.copy())
        grad_x = new_grad

    # Phase 2. grad_x now holds f_grad at the final x.
    e = None
    if inverse_direction_fun is not None:
        for _ in range(maxiter_hg):
            inverse_direction = inverse_direction_fun(x)
            e = -two_loops(
                inverse_direction, m, s1_list, y1_list, mu1_list, h0, B0
            )
            y_tilde1 = f_grad(x + e) - grad_x
            _remember_pair(s1_list, y1_list, mu1_list, e, y_tilde1, m)

    warm_restart_lists = [s_list, y_list, mu_list]
    print(f'{k} iterates')
    return np.array(all_x_k), e, warm_restart_lists
