import time

import copy

import numpy as np
from numpy import array, asarray, float64, int32, zeros
from scipy import linalg
from scipy.optimize import _lbfgsb, LbfgsInvHessProduct
from scipy.sparse import linalg as splinalg

from hoag.lbfgs import lbfgs
from hoag.tn import TN
from hoag.cg import CG
from hoag.blfoaa import blfoaa
from hoag.isr import isr
from hoag.blfoa1 import blfoa1
from hoag.minibatch import MinibatchSampler
from hoag.learningrate import LearningRateScheduler


from sklearn.utils.extmath import safe_sparse_dot
from scipy.sparse import csr_matrix 

def variance_reduction(grad, memory, vr_info):
    """Return a SAGA-style variance-reduced estimate of ``grad``.

    ``memory`` keeps one stored value per minibatch in its leading rows and
    a running aggregate in its last row. The returned estimate is
    ``grad - memory[idx] + memory[-1]`` (using the aggregate *before* it is
    refreshed). Afterwards the aggregate row is incremented by
    ``weight * (grad - memory[idx])`` and row ``idx`` is overwritten with
    ``grad``. ``memory`` is modified in place.

    Parameters
    ----------
    grad : ndarray or scalar
        Fresh minibatch quantity; must broadcast against one row of memory.
    memory : ndarray (2-D)
        Per-batch table plus aggregate row; updated in place.
    vr_info : tuple
        ``(idx, weight)``: the current batch's row and the weight of its
        contribution to the aggregate row (presumably the batch's share of
        the data as supplied by MinibatchSampler -- confirm with the caller).

    Returns
    -------
    ndarray
        The variance-reduced estimate, as a new array (not a view of memory).
    """
    batch_idx, weight = vr_info
    # How much this batch's contribution moved since it was last stored.
    delta = grad - memory[batch_idx]
    # Must be formed before the aggregate row is refreshed below.
    estimate = memory[-1] + delta
    memory[-1] += weight * delta
    memory[batch_idx, :] = grad
    return estimate

def hoag_lbfgs(
    h0,h_func_grad, h_hessian, h_crossed, g_func_grad, x0,Xt,yt,Xh,yh, bounds=None,
    lambda0=0., disp=None, maxcor=10,
    maxiter=100, maxiter_inner=10000, maxiter_backward=100,
    only_fit=False,
    iprint=-1, maxls=20, tolerance_decrease='exponential',
    callback=None, verbose=0, epsilon_tol_init=1e-3, exponential_decrease_factor=0.9,
    projection=None,  shine=False,refine=False, fpn=False, bome=False,bsg1=False,foa=False,sr1=False,
    refine_exp=0.5, pure_python=False,opa=False,f2sa=False,sz=False,foa1=False,saba=False,aidtn=False,aidcg=True,PZOBO=False,amicg=False ,**kwargs):
    """
    Bilevel hyperparameter optimization (HOAG and several variants).

    Alternates between (approximately) minimising the inner objective
    ``h(x, lambdak)`` over ``x`` and updating the hyperparameters
    ``lambdak`` to decrease the outer objective ``g``. Exactly one algorithm
    runs. The flags are tested in this order and the first true one wins:
    ``bome``, ``aidcg``, ``bsg1``, ``saba``, ``f2sa``. Otherwise the
    HOAG / L-BFGS-B path runs, optionally using the ``foa``, ``sr1``,
    ``foa1`` or ``shine`` inner solvers. Every hyperparameter update is
    followed by clipping ``lambdak`` to [-12, 12].

    NOTE(review): ``aidcg`` defaults to True. Unless ``bome=True`` is given,
    a default call therefore takes the ``aidcg`` branch and never reaches
    ``bsg1``, ``saba``, ``f2sa`` or the L-BFGS-B path. Pass ``aidcg=False``
    to select those.

    Parameters
    ----------
    h0 : float
        Forwarded as ``h0`` to the ``foa``, ``sr1`` and ``foa1`` inner
        solvers. It is reset to 0.1 when it equals 1 (``foa``) or 0.01
        (``sr1``).
    h_func_grad : callable
        ``h_func_grad(x, lambdak[, batch]) -> (value, grad_x)`` of the inner
        objective. The optional minibatch argument is used only by ``saba``
        and ``f2sa``.
    h_hessian : callable
        ``h_hessian(x, lambdak[, batch])`` returns a callable used as a
        matrix-vector product ``z -> H z``, presumably with the inner
        Hessian w.r.t. x.
    h_crossed : callable
        ``h_crossed(x, lambdak)``, presumably the cross derivative of h. The
        hypergradient of the L-BFGS-B path is
        ``-h_crossed(x, lambdak).dot(Bxk)``.
    g_func_grad : callable
        ``g_func_grad(x, lambdak[, batch]) -> (value, grad_x)`` of the outer
        objective.
    x0 : array-like
        Initial inner variable. It is flattened and copied to float64.
    Xt, yt, Xh, yh
        Inner and outer data. Only ``Xt.shape`` and ``Xh.shape`` are read
        (by ``saba`` and ``f2sa``). ``yt`` and ``yh`` are unused here.
    bounds : sequence of (low, high), optional
        Box constraints for L-BFGS-B. ``None`` or +-inf means unbounded.
    lambda0 : ndarray
        Initial hyperparameters. Must be a NumPy array, because
        ``.copy()`` and boolean-mask assignment are used. The default ``0.``
        (a Python float) raises AttributeError.
    disp : int, optional
        When given, overrides ``iprint`` (0 maps to -1).
    maxcor : int
        Number of L-BFGS correction pairs.
    maxiter : int
        Outer iterations. Most branches loop over ``range(1, maxiter)``;
        ``saba`` loops over ``range(maxiter)``.
    maxiter_inner : int
        Inner-iteration budget. Its meaning depends on the branch;
        ``bome``, ``aidcg`` and ``saba`` ignore it.
    maxiter_backward : int
        Maximum number of CG iterations for the backward linear system in
        the L-BFGS-B path. A falsy value skips CG.
    iprint, maxls : int
        L-BFGS-B print level and line-search limit. ``maxls`` must be > 0.
    tolerance_decrease : {'exact', 'quadratic', 'cubic', 'exponential'}
        Schedule for the inner/CG tolerance ``epsilon_tol`` in the L-BFGS-B
        path.
    callback : callable, optional
        Called as ``callback(x, lambdak)`` once before the loop and after
        each outer update.
    verbose : int
        Verbosity of some diagnostic prints. Several prints are
        unconditional.
    epsilon_tol_init, exponential_decrease_factor : float
        Initial inner tolerance, and the factor used by ``'exponential'``.
    projection : callable, optional
        Applied to ``lambdak`` after each L-BFGS-B-path update, before
        clipping.
    shine, refine, fpn, foa, sr1, foa1, opa : bool
        Variants of the L-BFGS-B path (see the inline comments).
    bome, aidcg, bsg1, saba, f2sa : bool
        Select a standalone algorithm branch.
    sz : bool
        In the L-BFGS-B path, use an extra decaying step ``0.01 / (it + 1)``.
    pure_python : bool
        If True, skip the SciPy ``_lbfgsb.setulb`` call.
    only_fit, refine_exp, aidtn, amicg, PZOBO, **kwargs
        Accepted but unused. The ``aidtn``, ``amicg`` and ``PZOBO`` branches
        are commented out.

    Returns
    -------
    x : ndarray
        Final inner variable.
    lambdak : ndarray
        Final hyperparameters.

    Raises
    ------
    ValueError
        If ``len(bounds) != len(x0)`` or ``maxls <= 0``.
    NotImplementedError
        If ``tolerance_decrease`` is not recognised (L-BFGS-B path only).
    """
    m = maxcor  # number of L-BFGS correction pairs
    # NOTE(review): lambdak is the caller's lambda0 object, not a copy. The
    # in-place updates in the saba and L-BFGS-B branches (`lambdak -= ...`,
    # `lambdak[mask] = ...`) mutate the caller's array until lambdak is
    # rebound. They also move np.min(lambda0), which the inner tolerance of
    # the L-BFGS-B / shine paths reads.
    # Consider `lambdak = np.array(lambda0, dtype=float)`.
    lambdak = lambda0

    x0 = asarray(x0).ravel()
    n, = x0.shape

    if bounds is None:
        bounds = [(None, None)] * n
    if len(bounds) != n:
        raise ValueError('length of x0 != length of bounds')
    # unbounded variables must use None, not +-inf, for optimizer to work properly
    bounds = [(None if l == -np.inf else l, None if u == np.inf else u) for l, u in bounds]

    # `disp`, when given, overrides `iprint` (0 means silent, i.e. -1).
    if disp is not None:
        if disp == 0:
            iprint = -1
        else:
            iprint = disp

    # Encode the box constraints for L-BFGS-B: nbd[i] is 0 (unbounded),
    # 1 (lower bound only), 2 (both bounds) or 3 (upper bound only).
    nbd = zeros(n, int32)
    low_bnd = zeros(n, float64)
    upper_bnd = zeros(n, float64)
    bounds_map = {(None, None): 0,
                  (1, None): 1,
                  (1, 1): 2,
                  (None, 1): 3}
    for i in range(0, n):
        l, u = bounds[i]
        if l is not None:
            low_bnd[i] = l
            l = 1
        if u is not None:
            upper_bnd[i] = u
            u = 1
        nbd[i] = bounds_map[l, u]

    if not maxls > 0:
        raise ValueError('maxls must be positive.')

    # Working float64 copy of the inner variable, plus the work arrays that
    # SciPy's private `_lbfgsb.setulb` fills in place.
    # NOTE(review): `_lbfgsb` is a private SciPy module. Its `setulb`
    # signature and work-array layout have changed across SciPy releases;
    # confirm these sizes against the installed version.
    x = array(x0, float64)
    wa = zeros(2*m*n + 5*n + 11*m*m + 8*m, float64)
    iwa = zeros(3*n, int32)
    task = zeros(1, 'S60')
    csave = zeros(1, 'S60')
    lsave = zeros(4, int32)
    isave = zeros(44, int32)
    dsave = zeros(29, float64)

    # Tolerance for the inner (and CG) solves. The L-BFGS-B path shrinks it
    # every outer iteration according to `tolerance_decrease`.
    exact_epsilon = 1e-12
    if tolerance_decrease == 'exact':
        epsilon_tol = exact_epsilon
    else:
        epsilon_tol = epsilon_tol_init

    # Bxk: (approximate) solution of H Bxk = grad_x g used in the
    # hypergradient; it also serves as the CG warm start.
    Bxk = None
    L_lambda = None  # inverse outer step size; set from the first hypergradient
    g_func_old = np.inf  # outer objective value at the previous iteration

    if callback is not None:
        callback(x, lambdak)

    # n_eval, F = wrap_function(F, ())
    h_func, h_grad = h_func_grad(x, lambdak)
    norm_init = linalg.norm(h_grad)  # reference scale for the inner tolerance
    old_grads = []  # history of hypergradient norms
    # NOTE(review): fails for a Python-float lambda0 (the default `0.`).
    old_lambdak = lambdak.copy()
    warm_restart_lists = None  # state passed between successive inner-solver calls
    if bome:
        # `bome` variant. xhat approximates the inner solution with 10
        # gradient steps. Then (x, lambdak) moves along
        # grad g + lmbd * grad q, with q = h(x) - h(xhat) and lmbd a
        # non-negative multiplier recomputed every iteration.
        zz=np.zeros(1)  # lambda-component of grad g, taken as zero ("g_lambda=0")
        #xhat=copy.deepcopy(x)
        # NOTE(review): shadows the `sz` parameter inside this branch. This is
        # harmless because the branch returns.
        sz=0.01
        for it in range(1, maxiter):
           xhat=copy.deepcopy(x)
           for itn in range(10):
               xhat_grad=h_func_grad(xhat, lambdak)[1]#check
               xhat=xhat-sz*xhat_grad
            # Prepare gradients
           gx=g_func_grad(x, lambdak)[1]#upper w.r.t LL variable
           # loss = h(x) - h(xhat): value-function gap of the inner problem.
           # NOTE(review): h_func_grad(x, lambdak) is evaluated here and
           # again for hx; one call could be reused.
           loss=h_func_grad(x, lambdak)[0]-h_func_grad(xhat, lambdak)[0]
           hx=h_func_grad(x, lambdak)[1]
           # d(h(x) - h(xhat)) / d lambda. This is hard-coded for an
           # exp(lambda[0]) / 2 * ||x||^2 regulariser, presumably matching
           # h_func_grad; confirm before using another inner objective.
           hl_minus_hl_k=np.exp(lambdak[0])*0.5*(np.dot(x,x)-np.dot(xhat,xhat))

           dg=np.hstack([gx,zz])#hstack or vstack
           dh=np.hstack([hx,hl_minus_hl_k])
           norm_dq=np.linalg.norm(dh)**2  # squared norm of grad q
           dot=safe_sparse_dot(dg,dh)
           # lmbd = max((0.5 * q - <grad g, grad q>) / ||grad q||^2, 0).
           # The 1e-8 guards against a zero denominator.
           lmbd=np.max([(0.5*loss-dot)/(norm_dq+1e-8),0])

           x_grad=gx+lmbd*hx
           l_grad=lmbd*hl_minus_hl_k #g_lambda=0
           
           #sz=0.1/(it+1)
           
           x=x-sz*x_grad
           lambdak=lambdak-0.1*l_grad

            # projection: clip the hyperparameters to [-12, 12]
           
           lambdak[lambdak < -12] = -12
           lambdak[lambdak > 12] = 12
           
            #norm_grad_lambda = linalg.norm(grad_lambda)
           if callback is not None:
              callback(x, lambdak)
        return x, lambdak
    # NOTE(review): the aidtn / aidcg / amicg / PZOBO variants below are
    # commented out. The `aidtn`, `amicg` and `PZOBO` flags are therefore
    # accepted but have no effect; only the live `aidcg` branch further
    # down runs.
#     elif aidtn:
#         # raise
#         alpha = 0.01
#         v = np.zeros(1)
#         for it in range(1, 2):
#         # for it in range(1, maxiter):
#             xhat=copy.deepcopy(x)
#             for itn in range(10):
#                 xhat_grad = h_func_grad(xhat, lambdak)[1]  # check
#                 xhat = xhat - alpha * xhat_grad

#             # Prepare gradients
#             gx = g_func_grad(x, lambdak)[1]  # upper w.r.t LL variable
#             hxx = h_hessian(x, lambdak)
#             # print("hxx:", hxx)
#             v = TN(hxx, gx)
#             h_crossed(x, lambdak)
#             l_grad = -h_crossed(x, lambdak)
#             lambdak = lambdak - 0.1 * l_grad

#             # Projection
#             lambdak[lambdak < -12] = -12
#             lambdak[lambdak > 12] = 12

#             # Callback function
#             if callback is not None:
#                 callback(x, lambdak)

#         return x, lambdak
#     elif aidcg:
#         alpha = 0.01
#         beta = 0.2
#         v = np.zeros(1)
#         print("aidcg")
#         for it in range(1, 2):

#         # for it in range(1, maxiter):
#             xhat=copy.deepcopy(x)
#             for itn in range(10):
#                 xhat_grad = h_func_grad(xhat, lambdak)[1]  # check
#                 xhat = xhat - alpha * xhat_grad

#             # Prepare gradients
#             gx = g_func_grad(x, lambdak)[1]  # upper w.r.t LL variable
#             hxx = h_hessian(x, lambdak)
#             # print("hxx:", hxx)
#             v = CG(hxx, gx,v, 1)
            
            
#             h_crossed(x, lambdak)
            
#             l_grad = -h_crossed(x, lambdak)
#             lambdak = lambdak - beta * l_grad

#             # Projection
#             lambdak[lambdak < -12] = -12
#             lambdak[lambdak > 12] = 12

#             # Callback function
#             if callback is not None:
#                 callback(x, lambdak)

#         return x, lambdak
    elif aidcg:
        # Single-loop `aidcg` variant. Each outer iteration takes one inner
        # gradient step on x, makes one CG call for v, then takes one
        # hyperparameter step. maxiter_inner is ignored here.
        alpha = 0.01  # inner step size
        beta = 0.02  # hyperparameter step size
        # NOTE(review): v starts with shape (1,) while gx has the shape of x.
        # Whether hoag.cg.CG accepts or broadcasts this is not visible here;
        # confirm.
        v = np.zeros(1)
        print("aidcg")
        for it in range(1, maxiter):
            # NOTE(review): unnecessary copy. The update below rebinds x
            # rather than mutating it.
            x =copy.deepcopy(x)
            for itn in range(1):
                xhat_grad = h_func_grad(x, lambdak)[1]  # check
                x = x - alpha * xhat_grad

            # Prepare gradients
            gx = g_func_grad(x, lambdak)[1]  # upper w.r.t LL variable
            hxx = h_hessian(x, lambdak)
            # print("hxx:", hxx)
            # Presumably one CG iteration on hxx v = gx, warm-started from the
            # previous v (the semantics of hoag.cg.CG are not visible here;
            # confirm).
            v = CG(hxx, gx, v, 1)
            
            
            # NOTE(review): the result is discarded, so this is a redundant
            # evaluation.
            h_crossed(x, lambdak)
            
            # NOTE(review): this uses an elementwise `*` and a positive sign.
            # The L-BFGS-B path uses `- h_crossed(x, lambdak).dot(Bxk)`;
            # confirm which product and sign are intended for h_crossed's
            # shape.
            l_grad = h_crossed(x, lambdak) * v
            lambdak = lambdak - beta * l_grad
            
            # NOTE(review): debug print on every outer iteration.
            print(f"l_grad: {l_grad}, lambdak: {lambdak}, x: {x}")
            # Projection
            lambdak[lambdak < -12] = -12
            lambdak[lambdak > 12] = 12

            # Callback function
            if callback is not None:
                callback(x, lambdak)

        return x, lambdak  
    # NOTE(review): the amicg and PZOBO variants below are commented out;
    # their flags have no effect.
#     elif amicg:
#         alpha = 0.01
#         beta = 0.01
#         v = np.zeros(1)
#         print("aaaa")
#         for it in range(1, 2):

#         # for it in range(1, maxiter):
#             xhat=copy.deepcopy(x)
#             for itn in range(10):
#                 xhat_grad = h_func_grad(xhat, lambdak)[1]  # check
#                 xhat = xhat - alpha * xhat_grad

#             # Prepare gradients
#             gx = g_func_grad(x, lambdak)[1]  # upper w.r.t LL variable
#             hxx = h_hessian(x, lambdak)
#             # print("hxx:", hxx)
#             v = CG(hxx, gx, v, 1)
            
            
#             h_crossed(x, lambdak)
            
#             l_grad = -h_crossed(x, lambdak)
#             lambdak = lambdak - beta * l_grad

#             # Projection
#             lambdak[lambdak < -12] = -12
#             lambdak[lambdak > 12] = 12

#             # Callback function
#             if callback is not None:
#                 callback(x, lambdak)

#         return x, lambdak
    
#     elif PZOBO:
#         alpha, beta, mu, Q, N = 0.01, 0.01, 100, 10, 10
#         v = np.zeros(1)
#         print("PZOBO")
        
#         x_kj_t = x.copy()

#         sum_term = np.zeros_like(x)
        
#         # for it in range(1, maxiter):
#         for it in range(1, 2):

#             xhat=copy.deepcopy(x)
#             for itn in range(N):
#                 xhat_grad = h_func_grad(xhat, lambdak)[1]  # check
#                 xhat = xhat - alpha * xhat_grad
#             # second loop
            
#             # for itn in range(Q):
#             for itn in range(1):
#                 u_kj = np.random.randn(*lambdak.shape)
#                 for t in range(1, N + 1):
#                     # Prepare gradients
#                     gx = g_func_grad(x_kj_t, lambdak+mu*u_kj)[1]  # upper w.r.t LL variable
#                     x_kj_t = x_kj_t - alpha * gx
                    
#                 delta_j = (x_kj_t - lambdak) /mu
#                 gx = g_func_grad(x, lambdak)[1]
#                 inner = (delta_j.T @ gx)
#                 sum_term += inner * u_kj
                
#             f_lambdak = np.exp(lambdak[0])*(np.dot(x,x))
#             l_grad = f_lambdak + (1/Q)*sum_term

#             lambdak = lambdak - beta * l_grad

#             # Projection
#             lambdak[lambdak < -12] = -12
#             lambdak[lambdak > 12] = 12

#             # Callback function
#             if callback is not None:
#                 callback(x, lambdak)

#         return x, lambdak
    
    

    elif bsg1:
        # `bsg1` variant. Run maxiter_inner - 1 plain gradient steps (step
        # 0.01) on the inner problem, then step lambdak along
        # (<hx, gx> / ||hx||^2) * d h / d lambda.
        for it in range(1, maxiter):
           for itn in range(1,maxiter_inner):
               x_grad=h_func_grad(x, lambdak)[1]#check
               x=x-0.01*x_grad

            #prepare gradients
           gx=g_func_grad(x, lambdak)[1]#upper w.r.t LL variable 
           hx=h_func_grad(x, lambdak)[1]
           # d h / d lambda, hard-coded for an exp(lambda[0]) / 2 * ||x||^2
           # regulariser. This presumably matches h_func_grad; confirm.
           hl=0.5*np.exp(lambdak[0])*(np.dot(x,x))
        
           # NOTE(review): divides by ||hx||^2, which shrinks toward 0 as the
           # inner loop converges. An exactly-zero hx yields inf/nan.
           l_grad=(np.dot(hx,gx)/(np.dot(hx,hx)))*hl#g_lambda=0

           lambdak=lambdak-0.01*l_grad
           
            # projection
           lambdak[lambdak < -12] = -12
           lambdak[lambdak > 12] = 12
           if callback is not None:
             callback(x, lambdak)
        return x, lambdak
    elif saba:
       # `saba` variant: a stochastic single loop. Minibatch estimates of the
       # inner gradient, the Hessian-vector product, the cross term and the
       # outer gradient are passed through `variance_reduction`. Then x, v
       # and lambdak are updated together. The batch size (32) and the step
       # sizes (0.125) are hard-coded.
       batch_size=32
       n_samplesi,n_features= Xt.shape
       n_sampleso,_= Xh.shape
       outer_sampler= MinibatchSampler(n_sampleso,batch_size)
       inner_sampler= MinibatchSampler(n_samplesi,batch_size)
       v = np.zeros_like(x)
       n_batches = (n_samplesi + batch_size - 1) // batch_size
       n_outer=(n_sampleso + batch_size - 1) // batch_size
       # Each table has one row per minibatch plus a last row for the
       # aggregate kept by variance_reduction.
       # NOTE(review): rows are n_features wide, which assumes x has exactly
       # n_features entries; confirm.
       memory_inner_grad = np.zeros((n_batches + 1, n_features))
       memory_hvp = np.zeros((n_batches + 1, n_features))
       memory_grad_in_outer=np.zeros((n_outer + 1, n_features))
       memory_cross_v=np.zeros((n_batches + 1, 1))  # cross_v below is a scalar
       
       inner_step_size=0.125
       outer_step_size =0.125
       
       for k in range(maxiter):
           
           
           slice_inner, vr_inner = inner_sampler.get_batch()
           grad_inner_var=h_func_grad(x, lambdak,slice_inner)[1]
           hv=h_hessian(x, lambdak,slice_inner)
           hvp=hv(v)
           # Cross term applied to v, hard-coded for an exp(lambda[0]) / 2 *
           # ||x||^2 regulariser; confirm it matches h.
           cross_v=np.dot(np.exp(lambdak[0])*x,v)

           slice_outer, vr_outer = outer_sampler.get_batch()
           grad_in_outer=g_func_grad(x, lambdak,slice_outer)[1]
           
           
           grad_inner_var = variance_reduction(
            grad_inner_var, memory_inner_grad, vr_inner)
           hvp = variance_reduction(hvp, memory_hvp, vr_inner)
           cross_v = variance_reduction(
            cross_v, memory_cross_v, vr_inner)
           
           grad_in_outer = variance_reduction(
            grad_in_outer, memory_grad_in_outer, vr_outer )
           
           
           x -= inner_step_size * grad_inner_var
           v -= inner_step_size * (hvp + grad_in_outer)
           # NOTE(review): in-place update. lambdak aliases the caller's
           # lambda0, so this mutates the caller's array.
           lambdak -= outer_step_size * (cross_v)
           # projection
           
           lambdak[lambdak < -12] = -12
           lambdak[lambdak > 12] = 12
           
           if callback is not None:
              callback(x, lambdak)
       return x, lambdak
    elif f2sa:
        # `f2sa` variant: a first-order penalty scheme.
        # - xhat takes stochastic gradient steps on h.
        # - x takes steps on g + lbda * h.
        # - lambdak steps along lbda * d(h(x) - h(xhat)) / d lambda.
        # lbda grows by 1e-4 per outer iteration.
        batch_size=1000
        n_sampleso,_=Xh.shape
        n_samplesi,_= Xt.shape
        outer_sampler= MinibatchSampler(n_sampleso,batch_size)
        inner_sampler= MinibatchSampler(n_samplesi,batch_size)
        
 
        lbda=0.1  # weight of the inner objective in x's update
        sz=0.1  # NOTE(review): shadows the `sz` parameter inside this branch
        for it in range(1, maxiter):
           
           xhat=copy.deepcopy(x)
           for itn in range(1,maxiter_inner):
               slice_inner1, _ = inner_sampler.get_batch()
               # Drawn from the outer sampler; used for g below.
               slice_inner2, _ = outer_sampler.get_batch()
               xhat_grad=h_func_grad(xhat, lambdak,slice_inner1)[1]#check
               xhat=xhat-sz*xhat_grad
               gx=g_func_grad(x, lambdak,slice_inner2)[1]#upper w.r.t LL variable
               hx=h_func_grad(x, lambdak,slice_inner1)[1]
               x_grad=gx+lbda*hx
               x=x-sz*x_grad
                
           # d(h(x) - h(xhat)) / d lambda, hard-coded for an
           # exp(lambda[0]) / 2 * ||x||^2 regulariser; confirm it matches h.
           hl_minus_hl_k=np.exp(lambdak[0])*0.5*(np.dot(x,x)-np.dot(xhat,xhat))
           l_grad=lbda*hl_minus_hl_k #g_lambda=0

           
           lambdak=lambdak-0.1*l_grad
           lbda=lbda+0.0001
            # projection
           
           lambdak[lambdak < -12] = -12
           lambdak[lambdak > 12] = 12
           
            #norm_grad_lambda = linalg.norm(grad_lambda)
           if callback is not None:
              callback(x, lambdak)
        return x, lambdak
    else:
       # HOAG / L-BFGS-B path. Each outer iteration does the following:
       #   1. Solve the inner problem approximately (forward step).
       #   2. Obtain Bxk, which solves or approximates H Bxk = grad_x g
       #      (backward step).
       #   3. Form the hypergradient and take an adaptive step on lambdak.
       # raise
       for it in range(1, maxiter):
          h_func, h_grad = h_func_grad(x, lambdak) 
          n_iterations = 0
          task[:] = 'START'
          # NOTE(review): old_x is never read.
          old_x = x.copy()
          start = time.time()
          # NOTE(review): no `while` loop surrounds `setulb`, so it is called
          # exactly once per outer iteration. Presumably a call with 'START'
          # only asks for an evaluation ('FG'). As a result:
          #   - the `break` statements below exit the OUTER `for it` loop,
          #     ending the whole optimisation;
          #   - `continue` skips the rest of this outer iteration.
          # This looks like a dropped `while True:`; confirm.
          # `task.tostring()` is also deprecated in NumPy; `tobytes()` is
          # the replacement.
          if not pure_python:
                pgtol_lbfgs = 1e-120
                factr = 1e-120  # / np.finfo(float).eps
                _lbfgsb.setulb(
                    m, x, low_bnd, upper_bnd, nbd, h_func, h_grad,
                    factr, pgtol_lbfgs, wa, iwa, task, iprint, csave, lsave,
                    isave, dsave, maxls)
                task_str = task.tostring()
                if task_str.startswith(b'FG'):
                    # minimization routine wants h_func and h_grad at the current x
                    # Overwrite h_func and h_grad:
                    h_func, h_grad = h_func_grad(x, lambdak)
                    # Relative stopping test on ||grad_x h||, scaled by
                    # exp(min(old_lambdak) - min(lambda0)).
                    if linalg.norm(h_grad)  < \
                        epsilon_tol * norm_init * np.exp(np.min(old_lambdak) - np.min(lambda0)):
                        # this one is finished
                        break

                elif task_str.startswith(b'NEW_X'):
                    # new iteration
                    if n_iterations > maxiter_inner:
                        task[:] = 'STOP: TOTAL NO. of ITERATIONS EXCEEDS LIMIT'
                        print('ITERATIONS EXCEEDS LIMIT')
                        continue
                        # break
                    else:
                        n_iterations += 1
                else:
                    if verbose > 1:
                        print('LBFGS decided finish!')
                        print(task_str)
                    break
          # Outer gradient as a function of x. It is forwarded to the
          # quasi-Newton inner solvers below; the shine path drops it
          # unless `opa` is set.
          inverse_direction_fun = lambda x: g_func_grad(x, lambdak)[1]
          # Optional quasi-Newton inner solvers. Each one returns its iterates
          # `xs` (the last becomes x) and either `hess_inv` or, for foa1,
          # `uk`; the backward step reuses these. foa, sr1 and foa1 use the
          # tolerance 1/(1+it); shine uses the epsilon_tol schedule.
          if foa:
                # NOTE(review): h0 is rebound here, so it stays 0.1 from the
                # second iteration on.
                if h0==1:
                    h0=0.1
                xs, hess_inv,  warm_restart_lists=blfoa1(
                x0=x,
                f=lambda beta: h_func_grad(beta, lambdak)[0],
                f_grad=lambda beta: h_func_grad(beta, lambdak)[1],
                f_hessian=None,  # unused
                max_iter=maxiter_inner,
                m=m,
                #tol=epsilon_tol * norm_init * np.exp(np.min(old_lambdak) - np.min(lambda0)),
                tol=1/(1+it),
                tol_norm=linalg.norm,
                maxls=maxls,
                inverse_direction_fun=inverse_direction_fun,
                warm_restart_lists=warm_restart_lists,
                maxiter_hg=3,
                h0=h0
             )
                x= xs[-1]
          elif sr1:
                if h0==0.01:
                    h0=0.1
                xs, hess_inv, warm_restart_lists=isr(
                x0=x,
                f=lambda beta: h_func_grad(beta, lambdak)[0],
                f_grad=lambda beta: h_func_grad(beta, lambdak)[1],
                f_hessian=None,  # unused
                max_iter=maxiter_inner,
                m=m,
                #tol=epsilon_tol * norm_init * np.exp(np.min(old_lambdak) - np.min(lambda0)),
                tol=1/(1+it),
                tol_norm=linalg.norm,
                sz=0.1,
                inverse_direction_fun=inverse_direction_fun,
                warm_restart_lists=warm_restart_lists,
                ws=4,
                maxiter_hg=3,
                h0=h0,
                ex_up=False
             )
                x= xs[-1]
          elif foa1:
                xs, uk, warm_restart_lists=blfoaa(
                x0=x,
                f=lambda beta: h_func_grad(beta, lambdak)[0],
                f_grad=lambda beta: h_func_grad(beta, lambdak)[1],
                f_hessian=None,  # unused
                max_iter=maxiter_inner,
                m=m,
                #tol=epsilon_tol * norm_init * np.exp(np.min(old_lambdak) - np.min(lambda0)),
                tol=1/(1+it),
                tol_norm=linalg.norm,
                maxls=maxls,
                inverse_direction_fun=inverse_direction_fun,
                warm_restart_lists=warm_restart_lists,
                maxiter_hg=3,
                h0=h0
             )
                x= xs[-1]
          elif shine:
                if opa==False:
                   inverse_direction_fun =None
                xs, _, hess_inv, warm_restart_lists = lbfgs(
                x0=x,
                f=lambda beta: h_func_grad(beta, lambdak)[0],
                f_grad=lambda beta: h_func_grad(beta, lambdak)[1],
                f_hessian=None,  # unused
                max_iter=maxiter_inner,
                m=m,
                tol=epsilon_tol * norm_init * np.exp(np.min(old_lambdak) - np.min(lambda0)),
                #tol=1/(1+it),
                tol_norm=linalg.norm,
                maxls=maxls,
                inverse_direction_fun=inverse_direction_fun,
                inverse_secant_freq=maxiter-it,
                warm_restart_lists=warm_restart_lists,
               )
                x = xs[-1] 

          end = time.time()
          print(f'Forward took {end-start} seconds')
          g_func, g_grad= g_func_grad(x, lambdak)
          start = time.time() 
          # Backward step. Bxk comes from the quasi-Newton operator, from foa1's
          # uk, or (fpn) is g_grad itself.
          # NOTE(review): when none of shine/foa/sr1/foa1/fpn is set, Bxk keeps
          # its previous value (None on the first iteration). It is then only
          # used as the CG warm start below.
          if shine or foa or sr1:
              Bxk = hess_inv(g_grad)
          elif foa1 :
              Bxk = uk
          elif fpn:
              Bxk = g_grad

          # Solve h_hessian(x, lambdak)(Bxk) = g_grad with CG, unless a
          # quasi-Newton solver already supplied Bxk (`refine` forces CG).
          # NOTE(review): SciPy deprecated the `tol` keyword of cg in favour
          # of `rtol` (1.12) and removed it in 1.14.
          # NOTE(review): with a falsy maxiter_backward and none of
          # shine/foa/sr1/foa1/fpn set, Bxk stays None on the first iteration,
          # and `.dot(Bxk)` below fails.
          if not (shine or foa or foa1 or sr1) or refine:
               tol_CG = epsilon_tol
               if maxiter_backward:
                  fhs = h_hessian(x, lambdak)
                  B_op = splinalg.LinearOperator(
                      shape=(x.size, x.size),
                        matvec=lambda z: fhs(z))
                  # First use: warm-start CG from a copy of x.
                  if Bxk is None:
                    Bxk = x.copy()
                  Bxk, success = splinalg.cg(
                    B_op,
                    g_grad,
                    x0=Bxk,
                    tol=tol_CG,
                    maxiter=maxiter_backward,
                   )
                  if success != 0:
                    print('CG did not converge to the desired precision')
          end=time.time()
          if verbose > 0:
            print(f'Backward took {end-start} seconds')
          # Shrink the tolerance for the next outer iteration, never below
          # exact_epsilon. For 'exact', the 1e-24 is therefore floored back
          # to 1e-12.
          old_epsilon_tol = epsilon_tol
          if tolerance_decrease == 'quadratic':
            epsilon_tol = epsilon_tol_init / (it ** 2)
          elif tolerance_decrease == 'cubic':
            epsilon_tol = epsilon_tol_init / (it ** 3)
          elif tolerance_decrease == 'exponential':
            epsilon_tol *= exponential_decrease_factor
          elif tolerance_decrease == 'exact':
            epsilon_tol = 1e-24
          else:
            raise NotImplementedError

          epsilon_tol = max(epsilon_tol, exact_epsilon)
        # .. update hyperparameters ..
          # Hypergradient w.r.t. lambdak from the cross term and Bxk.
          grad_lambda = - h_crossed(x, lambdak).dot(Bxk)
          if linalg.norm(grad_lambda) < 1e-12:
            # increase tolerance
            # NOTE(review): nothing here changes epsilon_tol. The iteration is
            # simply skipped (no lambdak update, no callback), and the
            # message is printed twice when verbose > 0.
            if verbose > 0:
                print('too low tolerance %s, moving to next iteration' % epsilon_tol)
            print('too low tolerance %s, moving to next iteration' % epsilon_tol)
            continue
          old_grads.append(linalg.norm(grad_lambda))
        
          # Initial inverse step size: the first hypergradient norm divided by
          # sqrt(len(lambdak)), or 1 if that norm is tiny.
          if L_lambda is None:
            if old_grads[-1] > 1e-3:
                # make sure we are not selecting a step size that is too small
                L_lambda = old_grads[-1] / np.sqrt(len(lambdak))
            else:
                L_lambda = 1
        
          step_size = (1./L_lambda)
          #if foa==True or sr1==True or foa1==True:
          # NOTE(review): `bsg1` is always False here, because a true bsg1
          # returns from its own branch. The `else` therefore always runs, so
          # lambdak takes one step here and a second step in the sz /
          # acceptance logic below. That is up to two hypergradient steps per
          # iteration; the rejection path restores the pre-step old_lambdak.
          # Confirm this is intended.
          if bsg1==True :
            lambdak= lambdak
          else:
            old_lambdak = lambdak.copy()
            lambdak -= step_size * grad_lambda
        # projection
            lambdak[lambdak < -12] = -12
            lambdak[lambdak > 12] = 12
          
          if sz:
             # Extra step with the decaying step size 0.01 / (it + 1).
             step_size = 0.01/(it+1)
             #step_size = 0.1
             old_lambdak = lambdak.copy()
             lambdak -= step_size * grad_lambda
          else:
            # Adaptive step-size control:
            #   - sufficient decrease of g: accept, and enlarge the next step
            #     (L_lambda *= 0.95);
            #   - large increase of g: reject, restore old_lambdak, halve the
            #     next step and the tolerance;
            #   - otherwise: accept unchanged.
            incr = linalg.norm(step_size * grad_lambda)
            C = 0.25
            factor_L_lambda = 1.0
            if g_func <= g_func_old + C * epsilon_tol + \
                old_epsilon_tol * (C + factor_L_lambda) * incr - factor_L_lambda * (L_lambda) * incr * incr:
               L_lambda *= 0.95
               lambdak -= step_size * grad_lambda
            elif g_func >= 1.2 * g_func_old:
            # decrease step size
               L_lambda *= 2
               lambdak = old_lambdak.copy()
               print('!!step size rejected!!', g_func, g_func_old)
               # NOTE(review): this g_func_old is overwritten by
               # `g_func_old = g_func` at the end of the iteration, so the
               # evaluation is unused (and so is g_grad_old).
               g_func_old, g_grad_old= g_func_grad(x, old_lambdak)
            # tighten tolerance
               epsilon_tol *= 0.5
            else:
              old_lambdak = lambdak.copy()
              lambdak -= step_size * grad_lambda

        # projection
          if projection is None:
            pass
          else:
            lambdak = projection(lambdak)

        # projection
          lambdak[lambdak < -12] = -12
          lambdak[lambdak > 12] = 12
        # if g_func - g_func_old > 0:
        #     raise ValueError
          # NOTE(review): norm_grad_lambda is never read.
          norm_grad_lambda = linalg.norm(grad_lambda)
          g_func_old = g_func

          if callback is not None:
             callback(x, lambdak)
    return x, lambdak

   
