import numpy as np
from scipy import optimize
from scipy.sparse import eye
from sklearn.utils.extmath import safe_sparse_dot

def two_loops(grad_x, m, s_list, y_list, h0, B0, skip_tol=1e-8):
    """Apply the limited-memory SR1 inverse-Hessian approximation to a vector.

    The approximation starts from ``H_0 = h0 * I`` and is updated with the
    stored pairs ``(s_i, y_i)`` in order::

        p_i     = s_i - H_i y_i
        H_{i+1} = H_i + p_i p_i^T / (p_i^T y_i)

    The function returns ``-H grad_x``.

    Parameters
    ----------
    grad_x : ndarray, shape (n,)
        Vector the approximation is applied to (the gradient at the current
        point when computing a search direction).

    m : int
        Memory size. Not used by the computation; the caller is expected to
        have truncated ``s_list`` / ``y_list`` already.

    s_list : list of ndarray, shape (n,)
        Past displacement vectors ``s_i``, oldest first.

    y_list : list of ndarray, shape (n,)
        Past gradient differences ``y_i``, aligned with ``s_list``.

    h0 : float
        Scale of the initial inverse Hessian ``H_0 = h0 * I``.

    B0 : object
        Unused; kept for backward compatibility (``h0`` defines ``H_0``).

    skip_tol : float, default 1e-8
        Standard SR1 safeguard: a pair is skipped when
        ``|p_i^T y_i| <= skip_tol * ||p_i|| * ||y_i||``, which avoids
        dividing by a (near-)zero denominator.

    Returns
    -------
    r : ndarray, shape (n,)
        The SR1 direction ``-H grad_x``.
    """
    q_flat = np.ravel(grad_x)
    r = h0 * grad_x

    # Accepted (p_i, p_i flattened, p_i^T y_i) triples. The denominator is
    # cached so it is computed once instead of once per later pair.
    terms = []
    for s, y in zip(s_list, y_list):
        y_flat = np.ravel(y)
        p = s - h0 * y  # p_i = s_i - H_0 y_i, before the rank-one corrections
        for p_k, p_k_flat, denom_k in terms:
            p = p - (np.dot(p_k_flat, y_flat) / denom_k) * p_k
        p_flat = np.ravel(p)
        denom = np.dot(p_flat, y_flat)
        # Skip near-degenerate pairs; '<=' also catches p == 0 or y == 0.
        if abs(denom) <= skip_tol * np.linalg.norm(p_flat) * np.linalg.norm(y_flat):
            continue
        terms.append((p, p_flat, denom))

    for p, p_flat, denom in terms:
        r = r + (np.dot(p_flat, q_flat) / denom) * p
    return -r

def _update_memory(s_list, y_list, s, y, m, curvature_tol=1e-8):
    """Store copies of ``(s, y)`` when ``y^T s > curvature_tol``; keep at most ``m`` pairs.

    Both lists are modified in place; when too long, the oldest pair is dropped.
    """
    if safe_sparse_dot(y, s) > curvature_tol:
        y_list.append(y.copy())
        s_list.append(s.copy())
    if len(y_list) > m:
        y_list.pop(0)
        s_list.pop(0)


def isr(
        x0,
        f,
        f_grad,
        f_hessian,
        max_iter=100,
        m=2,
        tol=1e-6,
        tol_norm=None,
        sz=0.1,
        inverse_direction_fun=None,
        warm_restart_lists=None,
        ws=4,
        maxiter_hg=10,
        h0=1,
        ex_up=False
    ):
    """Run ``max_iter`` fixed-step limited-memory SR1 iterations from ``x0``.

    Iterations ``1 .. ws-1`` are warm-up steps ``-(1e-4 / k) * grad`` that do
    not touch the memory. Later iterations step ``sz * d``, where ``d`` is the
    SR1 direction from :func:`two_loops`. A pair ``(s, y)`` is stored only
    when ``y^T s > 1e-8``, and at most ``m`` pairs are kept.

    Parameters
    ----------
    x0 : array_like, shape (n,)
        Starting point. It is copied and never modified.
    f : callable
        Objective. Not evaluated by this fixed-step implementation.
    f_grad : callable
        Gradient ``f_grad(x) -> ndarray``.
    f_hessian : callable
        Unused.
    max_iter : int
        Number of iterations.
    m : int
        Memory size of the SR1 model.
    tol, tol_norm :
        Accepted for API compatibility. No convergence test is performed, so
        they do not affect the run.
    sz : float
        Fixed step size after the warm-up phase.
    inverse_direction_fun : callable, optional
        ``inverse_direction_fun(x) -> ndarray``. Required when ``ex_up`` is true.
    warm_restart_lists : (s_list, y_list), optional
        Initial memory, typically the third return value of a previous call.
        The lists are copied, not modified.
    ws : int
        Iteration index at which SR1 steps start.
    maxiter_hg : int
        Number of extra model-building rounds when ``ex_up`` is true.
    h0 : float
        Scale of the initial inverse Hessian ``h0 * I``.
    ex_up : bool
        If true, after the main loop build a second SR1 memory by probing
        ``f_grad`` at ``x + e`` with ``e`` derived from
        ``inverse_direction_fun(x)``.

    Returns
    -------
    If ``ex_up`` is true: ``(iterates, e)`` with ``e`` from the last round.
    Otherwise: ``(iterates, hess_inv, [s_list, y_list])`` where
    ``hess_inv(v)`` applies the final SR1 inverse-Hessian model to ``v``.

    Raises
    ------
    ValueError
        If ``ex_up`` is true and ``inverse_direction_fun`` is None or
        ``maxiter_hg < 1``.
    """
    # Validate up front instead of failing only after the whole run.
    if ex_up and (inverse_direction_fun is None or maxiter_hg < 1):
        raise ValueError(
            "ex_up=True requires inverse_direction_fun and maxiter_hg >= 1"
        )

    # Work on a private copy so the caller's x0 is not updated in place.
    x = np.array(x0, copy=True)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)  # in-place float steps need a float array
    all_x_k = [x.copy()]

    B0 = eye(len(x))  # forwarded to two_loops for interface compatibility
    grad_x = f_grad(x)

    if warm_restart_lists is None:
        s_list, y_list = [], []
    else:
        # Copy so closures returned by an earlier call are not changed.
        prev_s, prev_y = warm_restart_lists
        s_list, y_list = list(prev_s), list(prev_y)

    k = 0  # keeps the final report defined when max_iter < 1
    for k in range(1, max_iter + 1):
        if k < ws:
            # Warm-up: tiny decaying gradient steps; memory untouched.
            s = (0.0001 / k) * -grad_x
            x += s
            new_grad = f_grad(x)
        else:
            d = two_loops(grad_x, m, s_list, y_list, h0, B0)
            s = sz * d
            x += s
            new_grad = f_grad(x)
            _update_memory(s_list, y_list, s, new_grad - grad_x, m)
        all_x_k.append(x.copy())
        grad_x = new_grad

    if ex_up:
        s1_list, y1_list = [], []
        for _ in range(maxiter_hg):
            inverse_direction = inverse_direction_fun(x)
            e = -two_loops(inverse_direction, m, s1_list, y1_list, h0, B0)
            y_tilde = f_grad(x + e) - grad_x
            _update_memory(s1_list, y1_list, e, y_tilde, m)
        print(f'{k} iterates')
        return np.array(all_x_k), e

    def hess_inv(v):
        return -two_loops(v, m, s_list, y_list, h0, B0)

    print(f'{k} iterates')
    return np.array(all_x_k), hess_inv, [s_list, y_list]
    
    
    
