import matplotlib.pyplot as plt
import numpy as np
import optional_numba as nbu
import math as mt
import utils as ut



@nbu.jtc
def simple_averager(c,S,truegrad, bias_type=0):
    """Monte Carlo average of Gaussian-smoothing gradient estimates, benchmarked against the true gradient.

    Every sample contributes the stochastic estimate ``d * c[i] * S[i]`` (d = number of parameters) and the running
    mean of those estimates is tracked sample by sample. A directional derivative projects onto a unit-norm direction;
    the Gaussian smoothing view instead uses norm-sqrt(d) projections, which average out to the true gradient.

    :param c: 1D array of slopes or approximate slopes of truegrad along S.
    :param S: The directional sample vectors, shape (# samples, # parameters); rows are expected to have unit norm
        (at least on average).
    :param truegrad: The true gradient the running estimate is benchmarked against.
    :param bias_type: 0 - MSE/RMSE minimizing shrinkage k/(k+d-1) after k samples,
        1 - norm matching shrinkage sqrt(k/(k+d-1)), 2 - unadjusted (norm is biased large).
    :return: info_array of shape (# samples, 3); row k is filled by ut.calc_stationary_info from the (optionally
        shrunk) estimate after k+1 samples.
    """
    n_samples = S.shape[0]
    n_dims = S.shape[1]
    d = float(n_dims)
    truegrad_sq = np.dot(truegrad, truegrad)

    running = np.zeros((n_dims,), dtype=np.float64)  # running mean of the stochastic estimates
    scratch = np.zeros((n_dims,), dtype=np.float64)  # reusable buffer, avoids temporaries inside the loop
    info_array = np.empty((n_samples, 3), dtype=np.float64)

    for k in range(n_samples):
        count = k + 1.
        inv_count = 1. / count
        # Incremental mean: mean_k = mean_{k-1} * (1 - 1/k) + (d * c[k] * S[k]) / k.
        # Explicit buffer ops instead of one array expression keep the loop allocation-free under numba.
        running[:] *= (1. - inv_count)
        scratch[:] = S[k]
        scratch[:] *= c[k] * d * inv_count
        running[:] += scratch

        scratch[:] = running
        if bias_type != 2:
            shrink = count / (count + d - 1.)
            if bias_type == 1:
                shrink = mt.sqrt(shrink)
            scratch[:] *= shrink
        ut.calc_stationary_info(scratch, truegrad, truegrad_sq, info_array[k])
    return info_array



@nbu.jtc
def simple_leastchg(v, S, truegrad, bias_type=0):
    """Least-change (LMS) gradient estimation from directional slopes, benchmarked after every sample.

    Each sample moves the estimate along S[i] by the residual ``v[i] - S[i].g_est``. For unit-norm S[i] the updated
    estimate then reproduces the observed slope exactly.

    :param v: 1D array of slopes or approximate slopes of truegrad along S.
    :param S: The directional sample vectors, shape (# samples, # parameters); rows are expected to have unit norm
        (at least on average).
    :param truegrad: The true gradient the estimate is benchmarked against.
    :param bias_type: 0 - MSE/RMSE minimization (no adjustment),
        1 - norm matching adjustment sqrt(1 / (1 - (1 - 1/d)^k)),
        2 - full adjustment 1 / (1 - (1 - 1/d)^k): norm is biased large (to match the Monte Carlo averager).
    :return: info_array of shape (# samples, 3); row k is filled by ut.calc_stationary_info.
    """
    n_samples = S.shape[0]
    n_dims = S.shape[1]
    decay = 1. - 1. / n_dims
    truegrad_sq = np.dot(truegrad, truegrad)

    estimate = np.zeros((n_dims,), dtype=np.float64)
    scratch = np.zeros((n_dims,), dtype=np.float64)
    info_array = np.empty((n_samples, 3), dtype=np.float64)

    for k in range(n_samples):
        residual = v[k] - S[k].dot(estimate)
        # Raw LMS step: deliberately NOT divided by S[k].dot(S[k]). For true directional samples that norm is 1, but
        # samples whose norm is 1 only on average should show the resulting degradation. (A block least-change
        # update with block size 1 gives the normalized comparison.)
        scratch[:] = S[k]
        scratch[:] *= residual
        estimate[:] += scratch

        scratch[:] = estimate
        if bias_type != 0:
            correction = 1. / (1. - decay ** (k + 1.))
            if bias_type == 1:
                correction = mt.sqrt(correction)
            scratch[:] *= correction
        ut.calc_stationary_info(scratch, truegrad, truegrad_sq, info_array[k])
    return info_array



from utils import eco_ratio, eco_shrinkage_reset_solution, shrink_gradestimate




@nbu.jtc
def stationary_lms_eco(v,S,truegrad, bias_type=0,
                       alpha=3., b=1., sig=0., partial_reset=False, t_buffer=True,emetrics=True):
    """
    Error Correcting Optimization (ECO) framework for Broyden/LMS-updated gradients, stationary benchmark.
    Uses Marks Ratio (via `eco_ratio`) to detect when the running estimate violates its error bound. On a violation
    the estimate is either zeroed (full reset) or shrunk (`partial_reset=True`, via `eco_shrinkage_reset_solution`).

    The scalar `c` tracks how much of the estimate is still "unexplained". It starts at 1 for an empty estimate and
    is multiplied by sqrt(1 - l/d) for every sample absorbed with LMS gain l.

    :param v: 1D array of slopes (possibly noisy) of truegrad along S.
    :param S: The directional sample vectors, shape (# samples, # parameters).
    :param truegrad: The true gradient used to benchmark the estimate.
    :param bias_type: 0 - MSE/RMSE minimization (no adjustment), 1 - matching norm adjustment sqrt(1/(1-c^2)),
        2 - full 1/(1-c^2) adjustment (matches the initial norm of averaging).
    :param alpha: Bound parameter forwarded to eco_ratio / eco_shrinkage_reset_solution. None or values above
        sqrt(# parameters) are clipped to sqrt(# parameters).
    :param b: Multiplier on sig; the noise variance used for the LMS gain is (b*sig)**2.
    :param sig: Noise standard deviation of v. 0 means noiseless, and then the LMS gain is always 1.
    :param partial_reset: On a violation, shrink the estimate instead of zeroing it.
    :param t_buffer: Forwarded to eco_ratio and eco_shrinkage_reset_solution.
    :param emetrics: Print the largest violation ratio, the number of violations and the violation rate at the end.
    :return: info_array of shape (# samples, 3); row i is filled by ut.calc_stationary_info.
    """
    dms = S.shape[1]
    smps = S.shape[0]
    gn2 = np.dot(truegrad, truegrad)
    # Per-sample decay of c when a sample is absorbed with unit gain (noiseless case or empty estimate).
    cmult0 = (1 - (1 / dms)) ** .5

    g_est = np.zeros((dms,), dtype=np.float64)  # grad update mem
    g_tmp = np.zeros((dms,), dtype=np.float64)
    info_array = np.empty((smps, 3), dtype=np.float64)  # info array

    if alpha is None or alpha > (dms ** .5): alpha = dms ** .5

    ct = 0  # number of bound violations (was initialized to 1, over-counting by one)
    rm = 0  # largest violation ratio seen
    sig2 = (b*sig) ** 2.
    c = 1.
    for i in range(smps):
        ngh2 = g_est.dot(g_est)
        if sig == 0. or c == 1:
            l = 1.
            cmult = cmult0
        else:
            # Gain of the form u / (u + noise variance), with u an estimate of the remaining per-direction signal.
            u = ((c * c) * ngh2) / (dms * (1 - c * c))
            l = u / (u + sig2)
            cmult = (1 - (l / dms)) ** .5
        ngh = ngh2 ** .5
        ugh = S[i].dot(g_est)
        r = eco_ratio(ugh, ngh, v[i], c, dms, alpha, sig, b, t_buffer, False)

        if r > 0:
            if rm < r: rm = r
            ct += 1
            if partial_reset:
                oc = c
                c, s = eco_shrinkage_reset_solution(ugh, ngh, v[i], c, dms, alpha, sig2, -r / 3., t_buffer)
                # oc == 1 means the estimate carries no samples yet (e.g. i == 0); rescaling would be 0/0 there.
                if oc < 1.:
                    sd = mt.sqrt((1 - c ** 2) / (1 - oc ** 2))
                    g_est[:] *= sd
                    ugh *= sd
            else:
                # Full reset: restart exactly like the beginning of the run. The empty estimate absorbs this sample
                # with unit gain and c decays once for it. Leaving c == 1 here made 1/(1-c*c) divide by zero below.
                g_est[:] = 0.
                ugh = 0.
                l = 1.
                c = cmult0
        else:
            c = c * cmult
        er = v[i] - ugh
        g_tmp[:] = S[i]
        g_tmp[:] *= (er * l)
        g_est[:] += g_tmp
        g_tmp[:] = g_est
        if bias_type != 0:
            bc = 1. / (1. - c*c)
            if bias_type == 1: bc = mt.sqrt(bc)
            g_tmp[:] *= bc
        ut.calc_stationary_info(g_tmp, truegrad, gn2, info_array[i])
    if emetrics: #we say false positive rate as this is a stationary benchmark with a single true gradient.
        print('Largest Bound Violation:', rm, 'Total Bound Violations', ct, ', False Positives Rate:', ct / S.shape[0])
    return info_array



@nbu.jtc
def gd_trustregion(f_op,grad_op,x0,
                   step=.1, step_min=1e-16,step_max=1e8,step_shrinkm=.5,step_growm=1.1,trust_q=.95,
                   max_iters=5_000,
                   grad_factor=0.,
                   ):
    """
    Gradient Descent with steps dependent on trust region scaling, and (for grad_factor=0) not at all on the gradient norm.

    Each step moves x0 by ``-step * gn**(grad_factor - 1) * g``. Afterwards the realised value change fd is compared
    with the first-order model change md, and `grad_trmod` shrinks or grows ``step``. x0 is updated in place.

    :param f_op: Function operator, a standalone function that receives our parameter vector x as the first required arg.
        Or a tuple - (func, *(all config or dataset args)).
    :param grad_op: Grad operator, like f_op but it returns a gradient vector. To improve performance with numba include
        preallocated gradient memory as one of the operator args. The returned array is scaled in place by this function.
    :param x0: Our initial parameters. Modified in place.
    :param step: Initial optimizer step size. If grad_factor=0 (pure TR), this is the exact euclidean length that our
        parameters will change.
    :param step_min: The smallest our step size can be, machine eps is often enough such that our step won't even reach this.
    :param step_max: Largest step size.
    :param step_shrinkm: Step shrink multiple, when fd/md < trust_q the step becomes max(step*step_shrinkm, step_min).
    :param step_growm: Step grow multiple, when fd/md >= trust_q the step becomes min(step*step_growm, step_max).
    :param trust_q: trust_q $\in [0,1]$ (though technically you can experiment with different ranges if your gradient is biased/approximate).
        md is what our 'gradient model' estimates the value change to be, fd is our actual value change.
        If fd/md = 1 our model agrees perfectly with the value change, if <1 value changes less than expected so shrink TR. If >1 grow TR.
        However given we are seeking solutions to bounded stationary points we can expect fd/md to mostly be < 1, therefore we actually compare
        fd/md < trust_q.
    :param max_iters: Total # iterations (rows of info_array); one extra gradient step is taken before the loop.
        As these are just demos I don't implement premature stopping criteria.
    :param grad_factor: $\in [0,1]$ Controls how much of the step is dependent on the TR range or original gradient norm, see comments below.
        With grad_factor != 1 a zero gradient norm is not handled (0 raised to a negative power).
    :return: info_array of shape (max_iters, 3), row i filled by ut.calc_gradopt_info(val, g, gn, g_prev, gnp, ...).
        Per the author: cosine similarity g and g_prev, g to g_prev norm % deviation, log error of the optimization function.
    """

    #x_prev = np.empty(x0.size, dtype=x0.dtype)
    g_prev = np.empty(x0.size, dtype=x0.dtype)  # separate copy: g from grad_op may be the same array mem every call (it is in the demos).
    info_array = np.empty((max_iters, 3), dtype=np.float64)

    val=nbu.op_call_args(f_op,x0)
    g=nbu.op_call_args(grad_op,x0)

    #x_prev[:]=x0 #in this version we know that the previous step is always g_prev*(step/gnp) so we dont need x_prev or p here. When the step is independent of the steepest direction, as in other optimizers, we do (eg implementing a TR on an adaptive method).
    #therefore md := -np.dot(g_prev,g_prev)*gf = -gnp*gnp*gf
    # A TR for GD is typically independent of the gradient's norm: without any curvature information the gradient is often not a good predictor of the optimal parameter step.
    #This is why we normalize g first and only listen to the TR surface when grad_factor=0., however we can control this through the geometric/multiplicative mixture param grad_factor \in [0,1], which turns it into a dynamic LR and TR mixture method.
    # In general if you use grad_factor = 1 then it is utilizing gradient magnitudes, in that case you want slower trust region adaption. If 0 then you rely on TR adaption, so give shrinkm and growm a wider range.
    gn=mt.sqrt(np.dot(g,g))
    g_prev[:] = g  # copy before g is scaled in place below
    gf = step * (gn ** (-(1.-grad_factor))) if grad_factor != 1. else step
    g[:]*=gf #no intermediate memcpy
    x0[:]-=g
    gnp=gn
    val_prev=val

    for i in range(max_iters):
        val=nbu.op_call_args(f_op, x0)
        #print(val)
        g = nbu.op_call_args(grad_op, x0)
        gn = mt.sqrt(np.dot(g, g))

        ut.calc_gradopt_info(val,g,gn,g_prev,gnp,info_array[i])

        fd = val - val_prev  # negative means improving.
        md = -gnp*gnp*gf # The previous step was exactly -gf * g_prev, so the linear model change is g_prev . (-gf*g_prev); see adan_trustregion when the step is independent of g.
        step = grad_trmod(fd, md, step, step_min, step_max, step_shrinkm, step_growm, trust_q)
        # if i%100==0:
        #     print(i, 'tr ratio', fd / md)
        #     print(i,'nstep',step)

        g_prev[:] = g  # copy before g is scaled in place below
        gf=step * (gn ** (-(1.-grad_factor))) if grad_factor != 1. else step
        g[:]*=gf
        x0[:] -= g

        gnp = gn
        val_prev = val

    return info_array


@nbu.jt
def quasigd_trustregion(f_op, approx_op, x0,
                        step=.1, step_min=1e-8, step_max=1e8, step_shrinkm=.5, step_growm=1.1, trust_q=.95,
                        max_iters=5_000,
                        grad_factor=0.,
                        ):
    """
    Approximate (quasi-) Gradient Descent with trust region. The update rule is the same as `gd_trustregion`, but it is
    driven by a gradient estimate and benchmarked against the true gradient at every iteration.

    :param f_op: Function operator, called as nbu.op_call_args(f_op, x0) to obtain the objective value.
    :param approx_op: Operator called as nbu.op_call_args(approx_op, x0); returns a tuple (g_est, g_true, q_factor):
        the gradient estimate, the true gradient, and the factor applied to the TR model change.
        The g_true buffer is overwritten by this function after it has been recorded.
    :param x0: Initial parameters, modified in place.
    :param step, step_min, step_max, step_shrinkm, step_growm, trust_q: Trust-region controls, as in `gd_trustregion`.
    :param max_iters: Number of iterations (rows of info_array).
    :param grad_factor: The step along g_est is step * gn**(grad_factor - 1); 0 moves exactly `step` in euclidean length.
    :return: info_array of shape (max_iters, 4); row i is filled by ut.calc_approxopt_info.
    """

    #g_prev = np.empty(x0.size, dtype=x0.dtype)  # needed in case g from grad_op is the same array mem, which is true.
    info_array = np.empty((max_iters, 4), dtype=np.float64) # filled by ut.calc_approxopt_info (author's note: cossim with true g, mse with true g, values).
    # NOTE(review): no step has been taken yet, so iteration 0 compares against a fabricated previous step
    # (val_prev=1e8, gnp=gf=1). Unless val > 1e8 this makes the first TR update grow `step`; confirm this is intended.
    val_prev=1e8
    gnp=gf=1.

    for i in range(max_iters):
        val = nbu.op_call_args(f_op, x0)
        g_est, g_true, q_factor = nbu.op_call_args(approx_op, x0)
        gn = mt.sqrt(np.dot(g_est, g_est))

        ut.calc_approxopt_info(val,g_est,gn,g_true,info_array[i])
        fd = val - val_prev
        md = -(gnp*gnp) * gf*q_factor #NOTE: for trust regions to work with quasi-gradients, they have to be the MSE minimizing
        #estimator, otherwise md will scale too large, q_factor is the adjustment used to make it MSE optimal regardless
        #if we use large norm estimates or MSE estimates of the grad for optimal steps.
        # NOTE(review): gnp/gf describe the previous step but q_factor comes from the current approx_op call;
        # naapproxgd/quasiadan use the previous call's q_factor. Identical only if q_factor is constant.

        step = grad_trmod(fd, md, step, step_min, step_max, step_shrinkm, step_growm, trust_q)
        # if i%100==0:
        #     print(i, 'tr ratio', fd / md)
        #     print(i,'nstep',step)

        g_true[:] = g_est #we use g_true as memory because after it's recorded in info_array we don't need it, but g_est may be an accumulated gradient estimator so we leave it alone.
        gf=step * (gn ** (-(1.-grad_factor))) if grad_factor != 1. else step
        g_true[:]*=gf
        x0[:] -= g_true

        gnp = gn
        val_prev = val

    return info_array


@nbu.jt
def naapproxgd_trustregion(f_op, approx_op, x0,
                         step=.1, step_min=1e-8, step_max=1e8, step_shrinkm=.5, step_growm=1.1, trust_q=.95,
                         theta=.5, tau=0.,beta=None,
                         max_iters=5_000,
                         grad_factor=1.,
                         verbose=True,
                         ):
    """
    Nesterov Accelerated Approximate Gradient Descent with trust region. All parameters are as in
    `quasigd_trustregion`, plus a momentum factor and a strong convexity constant. Step and theta are rescaled by
    the same TR factor every iteration, i.e. both follow the same implied Lipschitz-constant scaling.

    Iteration (h = effective step, g = approx_op gradient estimate, y_0 = x_1 after one plain gradient step):
        x_{k+1} = y_k - h * g(y_k)
        y_{k+1} = x_{k+1} + beta * (x_{k+1} - x_k)

    :param theta: Initial momentum parameter; rescaled alongside `step` by the TR.
    :param tau: Strong convexity constant, assumed constant, which reduces Nesterov's quadratic for the momentum
        to alpha = sqrt(tau*theta). With tau=0 the adaptive momentum is (1-0)/(1+0) = 1.
    :param beta: None - adaptive momentum beta = (1 - sqrt(tau*theta_f)) / (1 + sqrt(tau*theta_f)).
        A number in [0, 1] - fixed (non-adaptive) momentum, even though the lr stays variable.
    :param verbose: Print the TR ratio, step, theta and beta every 100 iterations (default True, as before).
    :return: info_array of shape (max_iters, 4); row i is filled by ut.calc_approxopt_info.
    """
    y = np.empty(x0.size, dtype=x0.dtype)  # current extrapolated point y_k
    p = np.empty(x0.size, dtype=x0.dtype)  # last x-step x_k - x_{k-1}
    info_array = np.empty((max_iters, 4), dtype=np.float64)  # filled by ut.calc_approxopt_info
    # It's much more convenient to take a single plain gradient step first, so p and the TR model are defined.
    val_prev = nbu.op_call_args(f_op, x0)
    g_est, g_true, q_factor = nbu.op_call_args(approx_op, x0)
    gnp = mt.sqrt(np.dot(g_est, g_est))
    p[:] = -x0

    g_true[:] = g_est
    g_true[:] *= step*(gnp ** (-(1. - grad_factor))) if grad_factor != 1. else step
    x0[:] -= g_true
    y[:] = x0
    p[:] += x0
    bt = beta

    for i in range(max_iters):
        val = nbu.op_call_args(f_op, y)
        fd = val - val_prev  # negative means improving.
        # NOTE(review): fd is measured between consecutive y points, while the model projects the previous gradient
        # estimate onto the x-step p (the author's choice: "TR would use change in x"). Confirm this pairing.
        md = np.dot(g_est, p)*q_factor

        pst = step
        step = grad_trmod(fd, md, step, step_min, step_max, step_shrinkm, step_growm, trust_q)
        theta = theta*step/pst  # scale theta by the same factor as the step

        g_est, g_true, q_factor = nbu.op_call_args(approx_op, y)
        gn = mt.sqrt(np.dot(g_est, g_est))

        ut.calc_approxopt_info(val, g_est, gn, g_true, info_array[i])

        g_true[:] = g_est  # we use g_true as memory because after it's recorded in info_array we don't need it, but g_est may be an accumulated gradient estimator so we leave it alone.
        if grad_factor != 1.:
            nm = (gn ** (-(1. - grad_factor)))
            gf = step * nm
            thetaf = theta * nm
        else:
            gf = step
            thetaf = theta
        # Our convexity estimate doesn't change from tau_0, so we don't need to solve Nesterov's full quadratic.
        alphar = mt.sqrt(tau*thetaf)

        if bt is None:
            # Algebraically equal to (tau - alphar*tau)/(tau + alphar*tau), but defined for tau == 0 (the default),
            # where the old form evaluated 0/0.
            beta = (1. - alphar) / (1. + alphar)

        if verbose and i % 100 == 0:
            print(i, 'tr ratio', fd / md)
            print(i, 'nstep', step,theta,beta)

        # Gradient step taken from the point the gradient was evaluated at: x_{k+1} = y_k - gf * g(y_k).
        # (Previously y was extrapolated first, so the step applied g(y_{k-1}) at y_k: a lagged gradient.)
        g_true[:] *= gf
        p[:] = -x0
        x0[:] = y
        x0[:] -= g_true
        p[:] += x0  # p = x_{k+1} - x_k
        # Momentum extrapolation: y_{k+1} = x_{k+1} + beta * p. p itself is kept unscaled for the next TR model.
        y[:] = p
        y[:] *= beta
        y[:] += x0

        val_prev = val

    return info_array

def na_approx_htheta(l,n=1/4):
    """Return the pair ``(n / l, n * (2 - n) / l)``.

    :param l: Divisor of both values; presumably a Lipschitz-constant estimate (not confirmed in this file).
    :param n: Fraction used for both values, default 1/4.
    :return: Tuple (step h, theta) = (n / l, n * (2 - n) / l).
    """
    step_h = n / l
    theta = n * (2 - n) / l
    return step_h, theta


@nbu.rgic
def grad_trmod(fd,md,step,step_min,step_max,step_shrinkm,step_growm,trust_q):
    """Trust-region step-size update.

    :param fd: Actual function difference of the last step.
    :param md: Function difference predicted by the (first-order) model; md == 0 is not handled (division by zero).
    :return: If fd/md < trust_q, max(step*step_shrinkm, step_min); otherwise (including a NaN ratio)
        min(step*step_growm, step_max).
    """
    shrink = (fd / md) < trust_q
    new_step = step * (step_shrinkm if shrink else step_growm)
    return max(new_step, step_min) if shrink else min(new_step, step_max)

@nbu.jt
def adan_trustregion(f_op,grad_op,x0:np.ndarray, 
              step=.1, step_min=1e-16,step_max=1e8,step_shrinkm=.5,step_growm=1.1,trust_q=.95,grad_factor=1.,
              b=(0.98, 0.92, 0.99), lep=1e-8, max_iters=5_000,):
    """Adaptive Nesterov Algorithm (Adan), with an added trust region method that scales the learning rate, allowing it
    to converge to machine precision.
    Adan appears to live on the pareto front of performance for vector parameter DL problems.
    Methods like SOAP, Splus, and maybe Muon can outperform it for m*n structured parameters.

    :param f_op: Function operator (see `gd_trustregion`), called as nbu.op_call_args(f_op, x0).
    :param grad_op: Gradient operator (see `gd_trustregion`); its returned buffer may be scaled in place.
    :param x0: 1D initial parameters, modified in place.
    :param step, step_min, step_max, step_shrinkm, step_growm, trust_q: Trust-region controls, handled by `grad_trmod`
        exactly as in `gd_trustregion`. This is the same thing as TR with grad_factor=1 in the gradient methods.
        You can turn the TR off by setting step=step_min=step_max.
    :param grad_factor: The learning rate is step * gn**(grad_factor - 1); 1 (default) ignores the gradient norm.
    :param b: (beta1, beta2, beta3) decay factors in the PyTorch convention (see note below).
    :param lep: Epsilon added to the square root of the second moment.
    :param max_iters: Rows of info_array. Row 0 describes the initialization steps.
    :return: info_array of shape (max_iters, 3), row i filled by ut.calc_gradopt_info.

    https://arxiv.org/pdf/2208.06677

    Their claims quoted from the paper:
    "1) We propose an efficient DNN optimizer, named Adan. Adan develops a Nesterov momentum estimation method to estimate stable and accurate first- and second-order moments of the gradient in adaptive gradient algorithms for acceleration. 2) Moreover, Adan enjoys a provably faster convergence speed than previous adaptive gradient algorithms such as Adam. 3) Empirically, Adan shows superior performance over the SoTA deep optimizers across vision, language, and reinforcement learning (RL) tasks. So it is possible that the effort on trying different optimizers for different deep network architectures can be greatly reduced."

    In their paper they flip the beta notation (mistakenly or intentionally) so that b == 1 - b. eg if you used b1 =.9 for adam, then their notation equivalent would be b1=.1 for adan. In their actual pytorch implementations they get beta side correct so using that here.

    My take:
    Utilizes adaptivity and momentum like other ada's, but also introduces a step size normalizer n_k. Then variance v_k and n_k include gradient t, and grad diff (t, t-1) in its calculation; likely capturing more moment information than Nesterov Momentum or ADAM.
    In general Adan seems to be an especially robust DL-type optimizer. It might improve convergence when sampling directional derivatives as compared to SGD. As this is an adaptive method, I also developed it with a trust region to simulate stationary point convergence. Modifying it with AMSgrad or AdamX-like system will prevent divergence but I haven't seen a need.
    """
    dms=x0.shape[0]
    g_prev = np.empty((dms,), dtype=x0.dtype)  # separate copy: g from grad_op may be the same array mem every call.
    p = np.empty((dms,), dtype=x0.dtype)  # last step x_k - x_{k-1}, for the TR model
    # Rows: 0 - previous gradient, 1 - m (first moment), 2 - v (grad-diff moment), 3 - n (second moment).
    # Fortran ordered, so each column mh[:, i] is contiguous for the element-wise update kernel.
    mh = np.zeros((4, dms)[::-1],dtype=x0.dtype).T
    info_array = np.empty((max_iters, 3), dtype=np.float64)

    g = nbu.op_call_args(grad_op, x0)
    #-- all of this is to match Adan's original initialization: m0 = g0, v0 = 0, n0 = g0^2, v1 = g1 - g0.
    for i in range(dms):
        mh[0,i] = g[i] #g_prev
        mh[1,i] = g[i]
        mh[3,i] = g[i]*g[i]
    # Keep g0 for the row-0 metrics: g is scaled in place below (g_prev used to be left uninitialized here).
    g_prev[:] = g

    gn = mt.sqrt(np.dot(g, g))
    g[:]*=step/gn #first step, we listen to gradient direction + lr so that initial step doesn't go crazy.
    x0[:] -= g
    gnp = gn
    val = nbu.op_call_args(f_op, x0)
    g = nbu.op_call_args(grad_op, x0)
    gn = mt.sqrt(np.dot(g, g))
    # v1 = g1 - g0 must be set while mh[0] still holds g0. Setting it after _adan_update (which stores g1 in mh[0])
    # always produced zero.
    for i in range(dms): mh[2,i]= g[i] - mh[0,i]
    p[:] = -x0
    _adan_update(x0, g, step, b, mh, lep)

    ut.calc_gradopt_info(val, g, gn, g_prev, gnp, info_array[0])

    #--

    for i in range(1,max_iters):
        p[:] += x0  # p = x_k - x_{k-1}
        g_prev[:] = g
        gnp = gn
        val_prev = val

        val = nbu.op_call_args(f_op, x0)
        g = nbu.op_call_args(grad_op, x0)
        gn = mt.sqrt(np.dot(g, g))
        ut.calc_gradopt_info(val, g, gn, g_prev, gnp, info_array[i])

        fd = val - val_prev  # negative means improving.
        md = np.dot(g_prev,p)  # linear model change along the step actually taken
        step = grad_trmod(fd, md, step, step_min, step_max, step_shrinkm, step_growm, trust_q)
        lr=step * (gn ** (-(1.-grad_factor))) if grad_factor != 1. else step
        # if i%100==0:
        #     print(i, 'tr ratio', fd / md,fd,md)
        #     print(i,'lr',lr)
        p[:]=-x0 #we want x_k - x_{k-1} so set with -x_{k-1} first.
        _adan_update(x0,g,lr,b,mh,lep)

    return info_array


# Element-wise Adan kernel kept from another project for convenience. It is a plain per-coordinate loop, so it is
# fast once numba is installed (and still correct, just slow, in the python interpreter).
@nbu.jtc
def _adan_update(x0,g,lr,b,mh,lep):
    """Apply one in-place Adan step to every coordinate of x0.

    mh rows, all updated in place: 0 - previous gradient, 1 - EMA m of the gradient, 2 - EMA v of the gradient
    difference, 3 - EMA n of (g + beta2 * diff)^2. b = (beta1, beta2, beta3) are the weights kept on the old values.
    Each coordinate then moves by x_j <- x_j - lr / (sqrt(n_j) + lep) * (m_j + beta2 * v_j).
    mh is indexed as mh[row, j]; callers allocate it Fortran ordered so each column is contiguous.
    """
    beta1 = b[0]
    beta2 = b[1]
    beta3 = b[2]
    for j in range(x0.shape[0]):
        g_j = g[j]
        diff = g_j - mh[0, j]
        mh[0, j] = g_j
        m_j = mh[1, j] * beta1 + (1. - beta1) * g_j
        v_j = mh[2, j] * beta2 + (1. - beta2) * diff
        blended = g_j + beta2 * diff
        n_j = mh[3, j] * beta3 + (1. - beta3) * blended * blended
        mh[1, j] = m_j
        mh[2, j] = v_j
        mh[3, j] = n_j
        # Not implementing the l2-regularized variant (which divides by 1 + l2_reg * lr).
        scale = lr / (mt.sqrt(n_j) + lep)
        x0[j] = x0[j] - scale * (m_j + beta2 * v_j)



@nbu.jt
def quasiadan_trustregion(f_op, approx_op, x0: np.ndarray,
                          step=.005, step_min=1e-16,step_max=1e8,step_shrinkm=.5,step_growm=1.1,trust_q=.95,grad_factor=1.,
                          b=(0.98, 0.92, 0.99), lep=1e-8, max_iters=5_000):
    """Adan with a trust region (as in `adan_trustregion`), adjusted to accommodate approximate gradients like
    `quasigd_trustregion`.

    :param approx_op: Operator called as nbu.op_call_args(approx_op, x0); returns (g_est, g_true, q_factor). The
        estimate drives the update, the true gradient is only recorded, and q_factor scales the TR model change.
    :param x0: 1D initial parameters, modified in place.
    Other parameters are as in `adan_trustregion`.
    :return: info_array of shape (max_iters, 4), row i filled by ut.calc_approxopt_info (same layout as the other
        approximate-gradient optimizers in this module).
    """

    dms = x0.shape[0]
    p = np.empty((dms,), dtype=x0.dtype)  # last step x_k - x_{k-1}, for the TR model
    mh = np.zeros((4, dms)[::-1], dtype=x0.dtype).T #fortran ordered; rows: prev grad, m, v, n
    # 4 columns, like quasigd_trustregion / naapproxgd_trustregion, which feed the same ut.calc_approxopt_info
    # (3 columns risked writing past each row).
    info_array = np.empty((max_iters, 4), dtype=np.float64)

    g_est, g_true, q_factor = nbu.op_call_args(approx_op, x0)
    # -- all of this is to match Adan's original initialization: m0 = g0, v0 = 0, n0 = g0^2, v1 = g1 - g0.
    for i in range(dms):
        mh[0, i] = g_est[i]  # g_prev
        mh[1, i] = g_est[i]
        mh[3, i] = g_est[i] * g_est[i]

    gn = mt.sqrt(np.dot(g_est, g_est))
    g_true[:]=g_est  # g_true used as scratch memory so g_est (possibly an accumulated estimator) is untouched
    g_true[:] *= step / gn  # first step, we listen to gradient direction + lr so that initial step doesn't go crazy.
    x0[:] -= g_true
    val = nbu.op_call_args(f_op, x0)
    g_est, g_true, q_factor = nbu.op_call_args(approx_op, x0)
    gn = mt.sqrt(np.dot(g_est, g_est))
    # v1 = g1 - g0 must be set while mh[0] still holds g0. Setting it after _adan_update (which stores g1 in mh[0])
    # always produced zero.
    for i in range(dms): mh[2, i] = g_est[i] - mh[0, i]
    p[:] = -x0
    _adan_update(x0, g_est, step, b, mh, lep)

    ut.calc_approxopt_info(val, g_est, gn, g_true, info_array[0])

    # --

    for i in range(1, max_iters):
        p[:] += x0  # p = x_k - x_{k-1}
        val_prev = val

        val = nbu.op_call_args(f_op, x0)
        fd = val - val_prev  # negative means improving.
        # Model change of the step just taken, computed before approx_op is called again (it may update g_est in place).
        md = np.dot(g_est, p)*q_factor

        g_est, g_true, q_factor = nbu.op_call_args(approx_op, x0)
        gn = mt.sqrt(np.dot(g_est, g_est))

        # Record this iteration in its own row (previously every iteration overwrote row 0).
        ut.calc_approxopt_info(val, g_est, gn, g_true, info_array[i])

        # g_est is not overwritten by the update kernel, so no scratch copy is needed here.
        step = grad_trmod(fd, md, step, step_min, step_max, step_shrinkm, step_growm, trust_q)
        lr = step * (gn ** (-(1. - grad_factor))) if grad_factor != 1. else step
        p[:] = -x0  # we want x_k - x_{k-1} so set with -x_{k-1} first.
        _adan_update(x0, g_est, lr, b, mh, lep)

    return info_array
