
from typing import Callable

import numpy as np

import scipy.special as sp
import math as mt
import math
import importlib.util
from t_model import FastT_V1
import optional_numba as nbu

def powarray(s_val,e_val,pow=2.):
    """Return consecutive powers of ``pow`` lying between ``s_val`` and ``e_val``.

    The exponents start at ``ceil(log_pow(s_val))`` and stop *before*
    ``floor(log_pow(e_val))`` (``np.arange`` excludes its end point).

    :param s_val: Positive lower value; its base-``pow`` logarithm is rounded up.
    :param e_val: Positive upper value; its base-``pow`` logarithm is rounded down.
    :param pow: Base of the powers.
    :return: Array ``pow ** exponents``.
    """
    log_base = mt.log(pow)
    first_exp = np.ceil(mt.log(s_val) / log_base)
    stop_exp = np.floor(mt.log(e_val) / log_base)
    exponents = np.arange(first_exp, stop_exp, 1)
    return pow ** exponents
    

def alpha_to_sigma(alpha: float,side2=True) -> float:
    """
    Convert a significance level α into the matching normal threshold k (in σ units).

    With ``side2=True`` α is the two-sided level, so erf(k/√2) = 1 – α
    (e.g. α=0.05 → k ≈ 1.96). With ``side2=False`` α is doubled before the
    inversion, i.e. erf(k/√2) = 1 – 2α (e.g. α=0.05 → k ≈ 1.645).

    :param alpha: Significance level.
    :param side2: Interpret ``alpha`` as two-sided (True) or as a single tail (False).
    :return: k = √2 · erfinv(1 – m·α) with m = 1 (two-sided) or 2 (one tail).
    """
    tail_multiplier = 1 if side2 else 2
    confidence = 1 - tail_multiplier*alpha
    return math.sqrt(2) * sp.erfinv(confidence)

@nbu.jtic
def eco_reset_factor(cn, ugh, ngh, v, c, d, a=3., sig2=0., t_buffer=True, lb=-2.):
    r"""
    Reset criterion evaluated at a candidate N-RMSE ``cn`` (arithmetic form, the function a root search is run on).

    The value is ``(ugh*sqrt((1-cn^2)/(1-c^2)) - v)^2 - k^2 * ngh^2*cn^2/(d*(1-c^2)) - sig2``, where ``k`` is ``a``
    capped at ``sqrt(d)``. When ``t_buffer`` is set and that value is not below ``-1e-8``, ``k`` is replaced by
    ``FastT_V1(a, s)`` with ``s = max(2*ln(cn)/ln(1-1/d), 2) - 1`` and the result is floored at ``lb``.

    :param cn: The N-RMSE that we optimize/estimate to accommodate the new directional derivative.
    :param ugh: u^T·ĝ, direction vector dot gradient estimator.
    :param ngh: ||ĝ||, estimator norm.
    :param v: True gradient directional scalar; if g were known it would be u^T g. May be noisy when ``sig2 != 0``.
    :param c: Previous estimator expectation \sin \theta, or "Expected Normalized Root Mean Square Error" of ĝ to g.
    :param d: Dimensions or # parameters of the problem, (u/g/x).shape[0].
    :param a: Typical range 2 to 4 z_sigma; 2-sided confidence interval that rejects anomalous v. It is capped at \sqrt{d} here.
    :param sig2: Constant noise variance expected from v, derived for a static white-noise process.
    :param t_buffer: When True, re-evaluate non-negative-ish results with the Student-t widened interval.
    :param lb: Lower bound applied to the t-buffered result only.
    :return: The criterion value; per the surrounding code a positive value is treated as a violation.
    """
    z_cap = mt.sqrt(d)
    a = min(a, z_cap)
    cn_sq = cn**2
    c_sq = c**2
    shrink = mt.sqrt((1 - cn_sq) / (1 - c_sq))
    resid_sq = (ugh * shrink - v)**2
    tol_sq = (ngh*ngh)*cn_sq/(d*(1-c_sq))

    crit = resid_sq - a*a*tol_sq - sig2
    # A clearly negative gaussian-interval result is returned as is; this skips the logs and FastT_V1 call.
    if (not t_buffer) or crit < -1e-8:
        return crit

    # DoF or sample size for the t-dist, derived from cn.
    dof = max(2*mt.log(cn) / mt.log1p(-1./d), 2.) - 1.
    t_a = FastT_V1(a, dof)
    return max(resid_sq - t_a*t_a*tol_sq - sig2, lb)

@nbu.rgic
def marks_v_(ugh, ngh, v, c, d, a=3., sig2=0., t_buffer=True,):
    r"""
    Value calculation shared by the invariant ratio form and the arithmetic value form.

    :param ugh: u^T·ĝ, direction vector dot gradient estimator.
    :param ngh: ||ĝ||, estimator norm.
    :param v: Measured directional scalar.
    :param c: Previous expected N-RMSE (\sin \theta) of ĝ to g.
    :param d: Dimensions or # parameters of the problem.
    :param a: z_sigma width of the acceptance interval; replaced by ``FastT_V1(a, s)`` when ``t_buffer`` is set.
    :param sig2: Noise variance added to the squared interval.
    :param t_buffer: Widen ``a`` with the Student-t buffer, using ``s = max(2*ln(c)/ln(1-1/d), 2) - 1``.
    :return: ``(|ugh - v|, (ngh*a)^2 * c^2 / (d*(1-c^2)) + sig2)``; ``(0., 1.)`` when ``c == 1`` or ``ngh == 0``.
    """
    # First sample vector: there is no estimator yet, and 1 - c*c below would be zero when c == 1.
    if c == 1. or ngh == 0.:
        return 0., 1.
    if t_buffer:
        # In production this is only calculated after a reset occurs, and incremented otherwise.
        s = max(2. * mt.log(c) / mt.log1p(-1 / d), 2.) - 1.
        a = FastT_V1(a, s)
    # The former `if ngh == 0.: ngh = v` fallback was removed: it could never run,
    # because the guard above already returns whenever ngh == 0.
    c2 = c * c
    gamult = ((ngh * a) ** 2.) * c2 / (d * (1 - c2))  # gamma/gradient expectation interval
    return abs(ugh - v), gamult + sig2

@nbu.rgic
def eco_ratio(ugh, ngh, v, c, d, a=3., sig=0., b=1., t_buffer=True, ratio_form=False):
    r"""
    Extended Marks ratio with a noise factor and an optional Student-t interval buffer for initial eigen value/predictor asymmetry.

    By default the arithmetic form is returned, so > 0 implies an anomaly; with ``ratio_form=True`` the ratio is returned and > 1 implies an anomaly.

    :param ugh: u^T·ĝ, direction vector dot gradient estimator.
    :param ngh: ||ĝ||, estimator norm.
    :param v: True gradient directional scalar; if g were known it would be u^T g. May be noisy when ``sig != 0``.
    :param c: Previous estimator expectation \sin \theta, or "Expected Normalized Root Mean Square Error" of ĝ to g.
    :param d: Dimensions or # parameters of the problem, (u/g/x).shape[0].
    :param a: Typical range 2 to 4 z_sigma; 2-sided confidence interval that rejects anomalous v.
    :param sig: Constant noise std expected from v, derived for a static white-noise process. ``None`` means no noise term.
    :param b: Confidence interval of the independent noise in v; the noise term is ``(b*sig)^2``.
    :param t_buffer: Forwarded to ``marks_v_``.
    :param ratio_form: Select the ratio form instead of the arithmetic form.
    :return: ``dv^2 - sd`` (arithmetic form) or ``dv/sqrt(sd)`` (ratio form), with ``(dv, sd)`` from ``marks_v_``.
    """
    noise_var = 0. if sig is None else b*b*sig*sig  # noise factor/interval
    deviation, sq_bound = marks_v_(ugh, ngh, v, c, d, a, noise_var, t_buffer)

    if ratio_form:
        return deviation/mt.sqrt(sq_bound)  # > 1 implies violation
    return deviation*deviation - sq_bound  # > 0 implies violation



@nbu.jtc
def eco_shrinkage_reset_solution(ugh, ngh, v, c, d, a=3., sig2=0., lb=-2, t_buffer=True, max_iters=12, co_tol=1e-6):
    """ A solution method to the marks ratio with optional t-dist buffering and noise. Utilizes an accelerated bracketed secant method.

    Searches ``eco_reset_factor`` for a root in ``cn`` between ``c`` and 1 with ``signseeking_secant_v2(sign=-1)``.

    NOTE For the secant method:
    co_tol is - abs(x-x_prev)<co_tol, and not abs(f(x))<co_tol because the units of f(x) and change in f(x) can be very extreme.
    However the root finding bounds have a much more limited scope.

    :param ugh: u^T·ĝ, direction vector dot gradient estimator.
    :param ngh: ||ĝ||, estimator norm.
    :param v: Measured directional scalar.
    :param c: Current expected N-RMSE of the estimator; also the left end of the search interval.
    :param d: Dimensions or # parameters of the problem.
    :param a: z_sigma width forwarded to ``eco_reset_factor``.
    :param sig2: Noise variance forwarded to ``eco_reset_factor``.
    :param lb: Lower bound forwarded to ``eco_reset_factor``.
    :param t_buffer: Forwarded to ``eco_reset_factor``.
    :param max_iters: Iteration budget of the secant search.
    :param co_tol: Step-size tolerance of the secant search.
    :return: ``(cn, s)``: the new N-RMSE (at most 1) and ``s = max(2*ln(cn)/ln(1-1/d), 1)``.
        ``(1., 1.)`` is returned when the search reports that no root was found (reason code 2).
    """
    if ngh==0.:
        # Zero-norm estimator: skip the search and hand back the current c with its
        # sample count. The original line built this tuple without returning it, so
        # execution fell through into eco_reset_factor, which divides by 1 - c**2
        # (zero in the first-sample state c == 1).
        return c, max(2 * mt.log(c) / mt.log1p(-1 / d), 2.)
    m_op=(eco_reset_factor, ugh, ngh, v, c, d, a, sig2, t_buffer, lb)
    mnc =1.#max(1. + min(-1 / d + 1e-15, 0.), .9999)
    if (ugh+v)<abs(ugh-v):br_rate=2/3 #Tends to be more well behaved because shrinkage will be more effective so boosted bracketing rate.
    else:br_rate=4/9 #otherwise there is an increased chance of multiple roots so take it slower. In practice 2/3rd seems to work fine for this too.
    #print(f' Solution Args: ', ", ".join(str(i) for i in (ugh, ngh, v, c, d, a, sig, b, t_buffer, -abs(r))))
    cn,lcn,hcn,reasn=signseeking_secant_v2(m_op,c,mnc,br_rate=br_rate,er_tol=co_tol,max_iters=max_iters,sign=-1)
    #print('positive curvature solution. c:',c,'lcn:',lcn,'hcn:',hcn,'reason:',reasn)

    ### Note it turns out this third root scenario is so rare empirically, that I've never seen an example where it actually happens, so I'm commenting out.
    # if t_buffer and (mnc - (lcn + hcn) / 2.) < .01 and reasn < 2:
    #     #We need to check that this didn't land on the 3rd t-limit left root
    #     #print('Launching on t_buffer c before:',cn,c,lcn,'marks ratio',marks_ratio(ugh,ngh,v,hcn,d,a,sig,b,t_buffer),'ugh',ugh,'v',v,'ngh',ngh,)
    #     cn2, lcn2, hcn2, reasn2 = signseeking_secant_v2(m_op, c, lcn, br_rate=2 / 3, er_tol=1e-6, max_iters=max_iters,sign=-1)
    #     if reasn2<2:#then it succeeded
    #         cn=cn2
    #         reasn=reasn2
    #         #print('Second Search succeeded.')
    #     else:pass
    #         #print('Second Search Failed.')
    #     #else: #it didnt succeed we rely on the t-limit result
    #     #print('Launching on t_buffer c after:', cn,'new mr',marks_ratio(ugh,ngh,v,cn,d,a,sig,b,t_buffer))
    lstc,ls=1.,1.
    if reasn == 2: #reason 2 happens because no roots were found between 0 and 1.
        return lstc, ls
    s = max(2 * mt.log(cn) / mt.log1p(-1 / d), 1.,)
    return min(cn,lstc), max(s,ls) #it's barely meaningful to reset as a fraction less than 2 DoF, this also acts as a generic shrinkage smoothing bound for partial resets. In the t_buffered case this bound is basically never hit which confirms the cutoff. At very low dimension count <10, you can experiment with loosing these cutoffs.

def shrink_gradestimate(g,cn,c):
    """Scale the estimator ``g`` in place by ``sqrt((1 - cn^2)/(1 - c^2))`` and return it.

    :param g: Array that is modified in place through slice assignment.
    :param cn: New N-RMSE value.
    :param c: Previous N-RMSE value.
    :return: The same ``g`` object, after scaling.
    """
    shrink = mt.sqrt((1.-cn**2)/(1-c**2))
    g[:] *= shrink
    return g

@nbu.jtic
def signseeking_secant_v2(f_op, lo, hi,br_rate=.5, er_tol=1e-8, max_iters=20, sign=1):
    """A bracketed secant method that achieves (empirically) faster convergence by knowing the sign of the function to the left and right of the root.
    It also allows us to select if the slope of our root is positive or negative when there are multiple roots.

    The secant method uses the two most recent points, instead of the updated lo hi brackets, this typically gets the most out of the secant method,
    while still guaranteeing convergence with bisection bracketing. Assume we know only which lo or hi has a positive sign with regards to
    the general problem, if left side is positive we are seeking a negatively sloped root sign:=-1 vice versa for right side and positive slope root.
    Then until the first time sign(value)==-1, we only take a bracketing step; this strategy allows us to converge to a root that has a sign congruent
    slope in a multi root situation. Eg in a convex problem to a -(slope) root this will always be the left root.

    Other Notes: Convergence is only guaranteed when there is a single root with a congruent slope in the bracket.
    However, the likelihood of converging to a congruent root, is still very high due to the initial side rejection strategy explained above,
    by decreasing the bracketing increment to a range that guarantees sampling a basin br_rate <.5, you once more recover guaranteed
    convergence to the signed root.

    Variable calculations are all f64.

    :param f_op: Can be a function or a function operator (tuple) that includes its arguments. It receives a single scalar value for the point estimate.
        It is evaluated through ``nbu.op_call_args(f_op, x)``.
    :param lo: Left end of the search interval.
    :param hi: Right end of the search interval.
    :param br_rate: (0,1). The bracket increment, at .5 it's classic bisection, if you expect roots to be clustered on the right then >.5 might be suitable.
    Left clustered <.5. But a smaller br_rate should always have more definite convergence.
    :param er_tol: The loop stops once two successive estimates differ by less than this (``abs(lamo - lam) < er_tol``).
    :param max_iters: Maximum number of loop iterations (one function evaluation each), on top of the two end-point evaluations.
    :param sign: =1 we expect to have f(lo)<f(root)<f(hi). if -1 we expect f(lo)>f(root)>f(hi). If this expectation is unknown,
    it controls the bracketing bias eg if f is all positives and sign=1, then the bracket will reduce from right to left at (1 - br_rate) until
    reaching hi, if negatives and sign=1 then left to right at br_rate. Note: If both sides are wrong then convergence will not occur in the single root case.
    :return: ``(lam, lo, hi, code)``: the last estimate, the final bracket, and a status code.
        ``code`` is 2 when no negative raw function value was ever seen (neither at the unverified end point nor at any iterate),
        1 when the iteration counter reached zero, and 0 otherwise.

    """
    # Evaluate both end points once; fo/f then act as the "previous" and "current" secant values.
    fo, f = nbu.op_call_args(f_op, lo), nbu.op_call_args(f_op, hi)
    if sign == -1:
        op_bracket = f < 0.
        fo,f=-fo,-f
        fo, f = f, fo
        lamo, lam = hi, lo  #We know lo is positive, so we are more confident in giving it the step 2 interpolation point.
        lrt, hrt = 1 - br_rate, br_rate  #we want eagerness away from known side. so smaller=more conservative.
    else:
        op_bracket = fo < 0.
        lamo, lam = lo, hi
        lrt, hrt = br_rate, 1-br_rate
    #We init op_bracket by checking if the unknown point is negative, for this algo we assume that we know either lo or hi always has a positive sign for the general problem, my choice was due to the typical format of boundary solutions. If we seek a negative sloped root then we know our left side is positive, but the right side may have a basin or multiple roots (positive or negative areas), therefore we check if our right side has a negative bracketing location.
    

    #we flip our problems sign for negative roots so that lo bracket is always -, and hi always +.
    ict = int(max_iters)
    while ict > 0:
        fd = (f - fo)
        fo = f
        if abs(fd)<1e-15: #this is important to have in case monotonic portions.
            # Flat secant: fall back to the weighted bracketing point.
            lamo = lam
            lam = (lo*lrt + hi*hrt)
        else:
            # Secant step from the two most recent points.
            lamn = lam - f * (lam - lamo) / fd
            lamo = lam
            lam = lamn
            lamb=(lo*lrt + hi*hrt)
            # Until a negative value has been observed (op_bracket), the secant step is only accepted on the known side of the bracketing point.
            ll, lh = ((lo, lamb) if sign == -1 else (lamb, hi)) if not op_bracket else (lo, hi)
            # if op_bracket:
            #     ll,lh=lo,hi
            # else:
            #     ll,lh=((lo,lamb) if sign==-1 else (lamb,hi))

            # Reject a secant step that leaves the allowed interval and use the bracketing point instead.
            if not (ll<lam< lh): lam = lamb

        f = nbu.op_call_args(f_op,lam)
        op_bracket= op_bracket or f < 0.
        if sign == -1: f = -f #possible we don't need this and can replace with a single sign branch for the lo hi assignment.
        ict -= 1

        # Shrink the bracket: in the (possibly sign-flipped) problem a positive value belongs to the hi side.
        if f > 0.:
            hi = lam
        else:lo = lam

        if abs(lamo - lam) < er_tol: break

    return lam,lo,hi,2 if not op_bracket else 1 if ict==0 else 0



@nbu.rgic
def calc_stationary_info(g_est: np.ndarray, g: np.ndarray, gn2:float, infor: np.ndarray) -> None:
    """
    Compute metrics between the true gradient ``g`` and its estimator ``g_est``.
    Results are written into ``infor`` as follows
      0: Cosine similarity between g_est and g
      1: Ratio of norms (||g_est||/||g||)
      2: Normalized RMSE (||g_est - g|| / ||g||)

    NOTE: ``g_est`` is overwritten in place with ``g_est - g``.

    :param g_est: Gradient estimator vector, in a temporary/copied array where memory can be edited.
    :param g: True gradient.
    :param gn2: Precalculated gradient norm; it is used as the squared norm ||g||^2.
    :param infor: Info array of floats; elements 0..2 are written.
    """
    est_sq = np.dot(g_est, g_est)
    infor[0] = np.dot(g_est, g) / mt.sqrt(est_sq * gn2)  # cosine similarity
    infor[1] = mt.sqrt(est_sq / gn2)  # norm ratio
    # Reuse the estimator buffer for the residual so no extra array is allocated.
    g_est[:] -= g
    resid_sq = np.dot(g_est, g_est)
    # Root, so that the order of the error matches the order of the average parameter value.
    infor[2] = mt.sqrt(resid_sq / gn2)
    
@nbu.rgic
def calc_gradopt_info(fval:float,g:np.ndarray,gn,g_prev: np.ndarray, gnp: float, infor: np.ndarray) -> None:
    """
    Compute metrics between the previous and the current gradient and write them into ``infor``.
      0: Cosine similarity of g_prev and g
      1: Relative norm change (||g|| - ||g_prev||)/||g_prev||
      2: Function value, stored unchanged.

    :param fval: Function value to record.
    :param g: Current gradient.
    :param gn: Value used as the norm of ``g``.
    :param g_prev: Previous gradient.
    :param gnp: Value used as the norm of ``g_prev``.
    :param infor: Info array of floats; elements 0..2 are written.
    """
    norm_product = gn*gnp
    infor[0] = np.dot(g, g_prev) / norm_product  # cosine similarity
    norm_delta = gn-gnp
    infor[1] = norm_delta / gnp  # norm change relative to the previous norm
    infor[2] = fval
    
    
@nbu.rgic
def calc_approxopt_info(fval:float,g_est: np.ndarray,gen:float, g: np.ndarray, infor: np.ndarray) -> None:
    """
    Compute metrics between the gradient estimator and the true gradient and write them into ``infor``.
      0: Cosine similarity between g_est and g
      1: Ratio of norms (gen / ||g||)
      2: Normalized RMSE (||g - g_est|| / ||g||)
      3: Function Value

    NOTE: ``g`` is overwritten in place with ``g - g_est``.

    :param fval: Function value to record.
    :param g_est: Gradient estimator vector.
    :param gen: Value used as the norm of ``g_est``.
    :param g: True gradient; its memory is edited by this call.
    :param infor: Info array of floats; elements 0..3 are written.
    """
    true_sq = np.dot(g, g)
    true_norm = mt.sqrt(true_sq)

    infor[0] = np.dot(g_est, g) / (gen * true_norm)  # cosine similarity
    infor[1] = gen/true_norm  # norm ratio
    # Turn g into the residual in place; true_sq was taken before this point.
    g[:] -= g_est
    resid_sq = np.dot(g, g)

    infor[2] = mt.sqrt(resid_sq / true_sq)  # N-RMSE
    infor[3] = fval

    

import random as rd

def set_seed(sd=None):
    """Seed NumPy's global RNG and the ``random`` module with ``sd``, then forward ``sd`` to ``_set_seed``.

    :param sd: Seed value handed unchanged to each seeding call.
    """
    for seed_fn in (np.random.seed, rd.seed):
        seed_fn(sd)
    _set_seed(sd)

@nbu.jtc
def _set_seed(sd=None):
    """Seed ``np.random`` and ``random`` from inside the ``nbu.jtc``-decorated function.

    Does nothing when ``sd`` is None.
    NOTE(review): presumably ``nbu.jtc`` compiles this so the compiled-code RNG state gets seeded — confirm against optional_numba.

    :param sd: Seed value, or None to leave the generators untouched.
    """
    if sd is not None:
        np.random.seed(sd)
        rd.seed(sd)


# TODO: move the plotting helpers below into a separate plotting.py if more are added.

import matplotlib.pyplot as plt

def plot_gradopt_info(info_list, dims, mnlook=None, mxlook=None,
                      choose_plots=(0, 1, 2),
                      pscale_ratios=(1., 1., 1.),
                      plot_names=(r"$\cos \theta$ : $g_{k} \sim g_{k-1}$", r"$(\|g_k\|-\|g_{k-1}\|)/\|g_{k-1}\|$",
                                  r"Log Opt Error"),
                      vstack=False,
                      figsize=(10, 60),
                      minorxaxes=False,
                      minoryaxes=True,
                      eps_b=2**(-52),
                      enablecosbound=False,
                      enablermsebound=False,
                      fontsize=25, lwidths=2.5,
                      gw1=1.9,
                      gw2=1.5,
                      fontcolor="#"+"00"*3,
                      ):
    """Plot gradient-to-previous-gradient diagnostics.

    Thin wrapper around ``plot_gradest_info``: it only supplies different defaults
    (side-by-side layout, expectation bounds disabled, its own panel titles and sizes)
    and forwards every argument unchanged.

    :return: Whatever ``plot_gradest_info`` returns.
    """
    return plot_gradest_info(
        info_list=info_list,
        dims=dims,
        mnlook=mnlook,
        mxlook=mxlook,
        choose_plots=choose_plots,
        pscale_ratios=pscale_ratios,
        plot_names=plot_names,
        vstack=vstack,
        figsize=figsize,
        minorxaxes=minorxaxes,
        minoryaxes=minoryaxes,
        eps_b=eps_b,
        enablecosbound=enablecosbound,
        enablermsebound=enablermsebound,
        fontsize=fontsize,
        lwidths=lwidths,
        gw1=gw1,
        gw2=gw2,
        fontcolor=fontcolor,
    )

def plot_approxopt_info(info_list, dims, mnlook=None, mxlook=None,
                        choose_plots=(0, 1, 2),
                        pscale_ratios=(1., 1., 1.),
                        plot_names=(r"Cos Sim: $\hat{g}_{k} \sim g_{k}$", r"N-RMSE: $\|\hat{g}_{k} - g_{k}\|/\|g_{k}\|$",
                                    r"Log Opt Error"),
                        vstack=True,
                        figsize=(16, 16),
                        minorxaxes=False,
                        minoryaxes=True,
                        eps_b=2**(-52),
                        enablecosbound=False,
                        enablermsebound=False,
                        fontsize=16, lwidths=2.5,
                        gw1=1.,
                        gw2=.5,
                        fontcolor="#" + "00" * 3,
                        ):
    """Plot estimator-versus-gradient diagnostics.

    Thin wrapper around ``plot_gradest_info``: it only supplies different defaults
    (expectation bounds disabled, its own panel titles and font size) and forwards
    every argument unchanged.

    :return: Whatever ``plot_gradest_info`` returns.
    """
    return plot_gradest_info(
        info_list=info_list,
        dims=dims,
        mnlook=mnlook,
        mxlook=mxlook,
        choose_plots=choose_plots,
        pscale_ratios=pscale_ratios,
        plot_names=plot_names,
        vstack=vstack,
        figsize=figsize,
        minorxaxes=minorxaxes,
        minoryaxes=minoryaxes,
        eps_b=eps_b,
        enablecosbound=enablecosbound,
        enablermsebound=enablermsebound,
        fontsize=fontsize,
        lwidths=lwidths,
        gw1=gw1,
        gw2=gw2,
        fontcolor=fontcolor,
    )
   
import matplotlib.ticker as mticker

def plot_gradest_info(info_list, dims, mnlook=None, mxlook=None,
              choose_plots=(0, 1, 2),
              pscale_ratios=(1., 1., 1.),
              plot_names=(r"$\|\cos\theta\|$", r"$\|\hat{g}_k\|/\|\nabla f(x)\|$",
                          r"N-RMSE Log10"),
              vstack=True,
              figsize=(16, 16),
              minorxaxes=False,
              minoryaxes=True,
              eps_b=2**(-52),
                enablecosbound=True,
                      enablermsebound=True,
              fontsize=20,lwidths=2.5,
                      gw1=1.,
                      gw2=.5,
              fontcolor="#"+"00"*3,
                hspace=.1,
                      wspace=.1,
              ):
    """
    Plot the selected metric columns of one or more series, one panel per metric.

    info_list: sequence of (series, label, label_op) triples. ``series`` is sliced as
               ``series[mnlook:mxlook]`` and column ``i`` is plotted in panel ``i``;
               by default the columns are [cosine_sim, norm_ratio, norm_rmse].
               ``label_op(S, i)`` is called on the sliced series and its result is
               appended to ``label`` in the legend entry.
    dims:      dimensionality used for expectation bounds
    mnlook: min sample range.
    mxlook:    max sample range (clipped to the length of the first series)
    choose_plots: metric column indices to draw; one panel is created per entry.
    pscale_ratios: relative panel heights (vstack) or widths (otherwise).
    plot_names: y-axis labels, indexed by metric column.
    vstack: stack panels vertically; otherwise side by side with ``figsize`` reversed.
    minorxaxes / minoryaxes: draw minor grid lines along x / y.
    eps_b: floor applied to the expectation bound.
    enablecosbound / enablermsebound: draw the expectation curve on panel 0 / panel 2.
    fontsize, lwidths, gw1, gw2, fontcolor: font size, data line width, major and
               minor grid line widths, and the colour used for text, ticks, edges and grid.
    hspace / wspace: grid spacing handed to ``add_gridspec``.

    Returns ``(fig, axes)`` where ``axes`` maps each chosen metric index to its Axes.
    """

    # ---------- prep ----------
    # The first series defines the number of samples available.
    ss = info_list[0][0].shape[0]
    mnlook = 0 if mnlook is None else mnlook
    mxlook = ss if mxlook is None else min(mxlook, ss)
    rcc={'font.size': fontsize,'axes.linewidth':gw1+gw1,'font.weight': 'normal','axes.edgecolor':fontcolor, 'xtick.color':fontcolor, 'ytick.color':fontcolor,'axes.labelcolor':fontcolor}
    if fontcolor is not None:rcc['text.color']=fontcolor
    with plt.rc_context(rcc):
        plt.rc('grid', linestyle="-",)# color='white')
        if not vstack: figsize = figsize[::-1]
        fig = plt.figure(figsize=figsize)
        cp = choose_plots
        lps = len(cp)
        if vstack:
            gs = fig.add_gridspec(lps, 1, height_ratios=pscale_ratios, hspace=hspace, wspace=wspace)
        else:
            gs = fig.add_gridspec(1, lps, width_ratios=pscale_ratios, hspace=hspace, wspace=wspace)
    
        # Running index of the next free grid cell, advanced by _as().
        c = 0
    
        def _as(sharex=None, sharey=None):
            # Create the subplot in the next grid cell.
            nonlocal c
    
            ax = fig.add_subplot(gs[c, 0] if vstack else gs[0, c], sharex=sharex, sharey=sharey)
            c += 1
            return ax
    
        ax0 = _as()
    
        x_vals = np.arange(mnlook, mxlook)
    
        # ---------- plot raw series ----------
        axes = {choose_plots[0]: ax0}
        for i in choose_plots[1:]:
            ax = _as()#(ax0)
            axes[i] = ax
        for i, ax in axes.items():
            ax.grid(True, which='major', axis='both', linestyle='--', linewidth=gw1, color=fontcolor)
            ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=5))
            
            if minorxaxes:
                ax.minorticks_on()
                ax.grid(True, which='minor', axis='x', linestyle=':', linewidth=gw2, color=fontcolor)
            if minoryaxes:
                ax.minorticks_on()
                ax.grid(True, which='minor', axis='y', linestyle=':', linewidth=gw2, color=fontcolor)
            for series, label, label_op in info_list:
                S = series[mnlook:mxlook]  # raw
                ax.plot(x_vals, S[:, i],linewidth=lwidths,
                        label=f"{label}, {label_op(S, i)}")
            ax.set_ylabel(plot_names[i])
            # Hide x tick labels on every panel; the last panel gets them back below.
            ax.xaxis.set_major_formatter(
                mticker.NullFormatter(),
            )
            #ax.majorticks_off()
        # `ax` is the last panel created above: restore its x tick labels and label the axis.
        ax.xaxis.set_major_formatter(
             mticker.ScalarFormatter(),
        )
        ax.set_xlabel("Samples")
        #ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=5))
    
        if (0 in cp or 2 in cp) and (enablecosbound or enablermsebound):
            # ---------- LMS MSE bound ----------
            # (1 - 1/dims)^(k/2) for k = 0..ss-1, floored at eps_b.
            rmsebound = np.maximum((1 - (1 / dims)) ** (np.arange(ss) / 2), eps_b)
    
        if 0 in cp:
            ax0 = axes[0]
            #ax0.set_ylabel('Range')
            ax0.margins(y=.02)
            ax0.set_ylim(0.001, 1.05)
            if enablecosbound:
                # Cosine counterpart of the bound: sqrt(1 - bound^2).
                cosbound = np.sqrt(1 - rmsebound * rmsebound)
                ax0.plot(x_vals, cosbound[mnlook:mxlook],linewidth=lwidths,linestyle='--',
                         label=r"$\mathbb{E}$(ECO), " + f"F: {cosbound[mxlook - 1]:.2f}",color='cyan')
            # else:
            #     ax0.set_ylim(-1., 1.01)
        if 1 in cp:
            axes[1].margins(y=0.02)
            #axes[1].set_ylabel('Range')
            axes[1].set_ylim(0.001, 1.05)
            # axes[1].set_ylim(0.01,5)
    
        if 2 in cp:
            ax2 = axes[2]
            ax2.set_yscale('log')
            #ax2.set_ylabel("Log10 Range")#, rotation=0, labelpad=30)
            ax2.margins(y=0.02)
            #overriding for now because needed
            ax2.minorticks_on()
            ax2.grid(True, which='minor', axis='y', linestyle=':', linewidth=gw2, color=fontcolor)
            #ax2.tick_params(axis='y', rotation=90)
            ax2.yaxis.set_major_formatter(
                mticker.LogFormatterSciNotation()
                #mticker.FuncFormatter(lambda y, _: f"{y:g}")
            )
            ax2.yaxis.set_minor_formatter(
                mticker.NullFormatter(),
                #mticker.LogFormatterSciNotation()
                #mticker.FuncFormatter(lambda y, _: f"{y:g}")
            )
            if enablermsebound:
                ax2.plot(x_vals, rmsebound[mnlook:mxlook],
                         label=r"$\mathbb{E}$(ECO), " + f"F: {rmsebound[mxlook - 1]:.2f}",linewidth=lwidths,linestyle='--',color='cyan')
    
        # if minoryaxes:
        #     for v in axes.values():
        #         v.grid(True, which='minor', axis='y', linestyle=':', linewidth=0.5)
        # 
        # Legends are only drawn when more than one series is plotted.
        if len(info_list) > 1:
            for v in axes.values():
                l=v.legend()
                l.get_frame().set_alpha(.67)

    return fig, axes

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

def plot_gradest_quad(
    info_list,
    figsize=(16, 10),
    fontsize=20,
    fontcolor="#"+"00"*3,
    plot_names=(r"$\|\cos\theta\|$", r"$\|\hat{g}_k\|/\|\nabla f(x)\|$", r"N-RMSE"),
    right_ylabel="Value",
    left_xlabel="Samples",
    right_xlabel="Optimization Steps",
    ylims_cos=(0.001, 1.05),
    ylims_norm=(0.001, 1.05),
    ylims_rmse=(0.001, 1.05),
    lwidths=2.5,
    gw1=1.0,
    gw2=0.5,
    left_nticks=5,
    right_nticks=5,
    left_minorx=False,
    left_minory=True,
    right_minorx=True,
    right_minory=True,
    left_right_ratio=0.67,
    vspace_left=0.05,
    hspace_groups=0.05,
    remove_center=False,
    bound_alpha=0.25,
):
    """
    Draw a two-group figure: stacked metric panels on the left and one log-scaled panel on the right.

    info_list: sequence of (array, name, label_op) triples; entries whose array is None are skipped.
               The common x range is the shortest array length.
               Left panels plot columns 0, 1 and 2 of every 2-D array that has more than one column.
               The right panel plots, per array: the array itself when it is 1-D or single-column;
               columns 3, 4, 5 as (lower, mid, upper) with a shaded band when it has at least 6 columns;
               column 3 when it has at least 4 columns; otherwise nothing.
               ``label_op`` (if callable) produces legend text: ``label_op(S, idx_metric)`` on the left,
               ``label_op(S, 3)`` appended to ``name`` on the right. Exceptions raised by ``label_op``
               are swallowed and treated as "no label text".
    plot_names / ylims_cos / ylims_norm / ylims_rmse: y labels and y limits of the three left panels.
    left_* / right_*: axis labels, tick counts and minor-grid switches of the two groups.
    left_right_ratio: width fraction of the left group; must lie strictly between 0 and 1.
    vspace_left / hspace_groups: spacing between left panels / between the two groups.
    remove_center: drop the middle (metric 1) left panel.
    bound_alpha: opacity of the shaded band on the right panel.
    Remaining arguments control font size, colours and line widths.

    Returns ``(fig, axes_dict)`` with ``axes_dict['left']`` a 3-element list indexed by metric
    (None for a panel that was not created) and ``axes_dict['right']`` the right Axes.
    Raises ValueError when no array is present or ``left_right_ratio`` is out of range.
    """
    rcc = {
        'font.size': fontsize,
        'axes.labelsize': fontsize,
        'xtick.labelsize': fontsize,
        'ytick.labelsize': fontsize,
        'legend.fontsize': fontsize,
        'axes.linewidth': gw1 + gw1,
        'font.weight': 'normal',
        'axes.edgecolor': fontcolor,
        'xtick.color': fontcolor,
        'ytick.color': fontcolor,
        'axes.labelcolor': fontcolor,
    }
    if fontcolor is not None:
        rcc['text.color'] = fontcolor

    # T is the shortest length among the provided arrays; it fixes the shared x range.
    def _lenT(a): return a.shape[0] if a is not None else None
    T = None
    for arr, _, _ in info_list:
        if arr is None: 
            continue
        T = _lenT(arr) if T is None else min(T, _lenT(arr))
    if T is None:
        raise ValueError("No data found in info_list.")
    x_vals = np.arange(T)

    # One colour per info_list position, so a series keeps its colour in every panel.
    from matplotlib import rcParams
    default_cycle = rcParams['axes.prop_cycle'].by_key().get('color', ['C0','C1','C2','C3','C4','C5'])
    series_colors = {i: default_cycle[i % len(default_cycle)] for i in range(len(info_list))}

    with plt.rc_context(rcc):
        if not (0 < left_right_ratio < 1):
            raise ValueError("left_right_ratio must be in (0,1).")

        fig = plt.figure(figsize=figsize)
        gs_main = fig.add_gridspec(1, 2, width_ratios=[left_right_ratio, 1-left_right_ratio], wspace=hspace_groups)

        left_rows = [0,2] if remove_center else [0,1,2]
        n_left = len(left_rows)
        gs_left = gs_main[0,0].subgridspec(n_left, 1, hspace=vspace_left)

        left_axes = [fig.add_subplot(gs_left[i,0]) for i in range(n_left)]
        axr = fig.add_subplot(gs_main[0,1])

        # (metric index, Axes) pairs from top to bottom; metric 1 is omitted when remove_center is set.
        ordered = [(0,left_axes[0])] + ([(1,left_axes[1])] if not remove_center else []) + [(2,left_axes[-1])]

        def style_left(ax, is_bottom, ylabel, ylim):
            # Grid, ticks, labels and limits for one left panel; only the bottom panel shows x tick labels.
            ax.grid(True, which='major', axis='both', linestyle='--', linewidth=gw1, color=fontcolor)
            if left_minorx or left_minory: ax.minorticks_on()
            if left_minorx:
                ax.grid(True, which='minor', axis='x', linestyle=':', linewidth=gw2, color=fontcolor)
                ax.xaxis.set_minor_formatter(mticker.NullFormatter())
            if left_minory:
                ax.grid(True, which='minor', axis='y', linestyle=':', linewidth=gw2, color=fontcolor)
                ax.yaxis.set_minor_formatter(mticker.NullFormatter())
            ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=left_nticks))
            ax.set_ylabel(ylabel)
            ax.margins(y=0.02)
            if ylim is not None: ax.set_ylim(*ylim)
            if not is_bottom:
                ax.xaxis.set_major_formatter(mticker.NullFormatter())
            else:
                ax.xaxis.set_major_formatter(mticker.ScalarFormatter())
                ax.set_xlabel(left_xlabel)

        for idx_metric, ax in ordered:
            if idx_metric == 0:
                ylim, yl = ylims_cos, plot_names[0]
            elif idx_metric == 1:
                ylim, yl = ylims_norm, plot_names[1]
            else:
                ylim, yl = ylims_rmse, plot_names[2]
            style_left(ax, ax is ordered[-1][1], yl, ylim)

            left_lines, left_labels = [], []
            for sidx, (S, name, label_op) in enumerate(info_list):
                if S is None: 
                    continue
                # 1-D / single-column arrays belong to the right panel only.
                if S.ndim == 1 or (S.ndim==2 and S.shape[1]==1):
                    continue
                if idx_metric >= S.shape[1]:
                    continue
                y = S[:, idx_metric]
                line, = ax.plot(x_vals[:len(y)], y[:len(x_vals)], linewidth=lwidths, color=series_colors[sidx])
                fl = None
                if callable(label_op):
                    # Best effort: a failing label_op just means this line gets no legend entry.
                    try:
                        fl = label_op(S, idx_metric)
                    except Exception:
                        fl = None
                    if isinstance(fl, str) and fl.strip() == "":
                        fl = None
                if fl is not None:
                    left_lines.append(line); left_labels.append(fl)
            if left_labels:
                leg = ax.legend(left_lines, left_labels, loc='best', framealpha=0.67)
                leg.get_frame().set_alpha(0.67)

        # right styling
        axr.grid(True, which='major', axis='both', linestyle='--', linewidth=gw1, color=fontcolor)
        if right_minorx or right_minory: axr.minorticks_on()
        if right_minorx:
            axr.grid(True, which='minor', axis='x', linestyle=':', linewidth=gw2, color=fontcolor)
            axr.xaxis.set_minor_formatter(mticker.NullFormatter())
        if right_minory:
            axr.grid(True, which='minor', axis='y', linestyle=':', linewidth=gw2, color=fontcolor)
            axr.yaxis.set_minor_formatter(mticker.NullFormatter())
        axr.xaxis.set_major_locator(mticker.MaxNLocator(nbins=right_nticks))
        axr.set_yscale('log')
        axr.set_xlabel(right_xlabel)
        axr.set_ylabel(right_ylabel)
        # The y label is placed on the right side while the ticks stay on the left.
        axr.yaxis.set_label_position("right")
        axr.yaxis.tick_left()
        axr.yaxis.set_ticks_position('left')

        right_lines, right_labels = [], []
        for sidx, (S, name, label_op) in enumerate(info_list):
            if S is None: continue
            y_lower=y_mid=y_upper=y_f=None
            # Pick what to draw from the array's shape (see the docstring).
            if S.ndim == 1 or (S.ndim==2 and S.shape[1]==1):
                y_f = S if S.ndim==1 else S[:,0]
            elif S.ndim==2 and S.shape[1] >= 6:
                y_lower, y_mid, y_upper = S[:,3], S[:,4], S[:,5]
            elif S.ndim==2 and S.shape[1] >= 4:
                y_f = S[:,3]
            else:
                continue

            color = series_colors[sidx]
            if y_mid is not None:
                line, = axr.plot(x_vals[:len(y_mid)], y_mid[:len(x_vals)], linewidth=lwidths, color=color)
                if y_lower is not None and y_upper is not None:
                    axr.fill_between(x_vals[:min(len(y_lower),len(y_upper))],
                                     y_lower[:len(x_vals)], y_upper[:len(x_vals)],
                                     alpha=bound_alpha, linewidth=0, color=color)
            else:
                line, = axr.plot(x_vals[:len(y_f)], y_f[:len(x_vals)], linewidth=lwidths, color=color)

            add = None
            if callable(label_op):
                # Best effort: a failing label_op falls back to the bare series name.
                try:
                    SS = S if S.ndim==2 else S.reshape(-1,1)
                    add = label_op(SS, 3)
                except Exception:
                    add = None
                if isinstance(add, str) and add.strip() == "":
                    add = ""

            # A blank label string from label_op keeps this series out of the legend.
            if add == "":
                continue
            label = name if add is None else f"{name}, {add}"
            right_lines.append(line); right_labels.append(label)

        if right_labels:
            leg = axr.legend(right_lines, right_labels, loc='best', framealpha=0.67)
            leg.get_frame().set_alpha(0.67)

        fig.tight_layout()
        axes_dict = {'left':[None,None,None], 'right': axr}
        for (idx_metric, ax) in ordered:
            axes_dict['left'][idx_metric] = ax
    return fig, axes_dict
