
from chs import modsel, uni_eq, timeout
import numpy as np
import scipy.integrate as si
import sympy as sy
import datagen
import sindy_solvers as sds
import matplotlib.pyplot as pp
from matplotlib import colormaps as cm
from matplotlib import colorbar as cb
import os
from tqdm import tqdm
import warnings


bgcol = [1, 1, 1, 1]  # RGBA colour used as figure/axes background (opaque white)
pp.style.use('classic')  # must come first: applying a style resets the rc settings below


# Font sizes: fs is the base size, fstxt is intended for free text placed inside plots.
fs = 21
fstxt = 14
pp.rcParams.update({
    'font.size': fs,             # default text size
    'axes.titlesize': fs,        # axes titles
    'axes.labelsize': fs,        # x/y axis labels
    'xtick.labelsize': fs - 2,   # x tick labels
    'ytick.labelsize': fs - 2,   # y tick labels
    'legend.fontsize': fs - 2,   # legend entries
    'figure.titlesize': fs - 2,  # figure (suptitle) title
})

# Run control.
getnewmdl = False   # re-learn model equations even if they are already stored?
savenewmdl = True   # store newly learned models (with getnewmdl=True this overwrites stored ones)
new_ictest = False  # recompute and overwrite stored initial-condition test results?
silent = False      # suppress plots and diagnostic prints

namsys = 'Rabi'  # system tag used in the result folder name

# ---- scenario grid -----------------------------------------------------------------------------
# Each combination (N, dt, sig, ic) whose time span N*dt is at least 5 becomes one scenario.
# Scenarios are keyed '1', '2', ... in creation order (N outermost, ic innermost).
scenarios = {}
Nscen = 0
N_vec = [1000, 2000, 4000, 8000, 16000]  # numbers of samples (this system needs more data / longer spans)
dt_vec = [0.001, 0.003, 0.01, 0.03, 0.1]  # sampling time steps
sig_vec = [0.0001, 0.0005, 0.002, 0.01]  # noise levels; trajectories are ~10x smaller than Lorenz, hence smaller noise
ic_vec = [[[-1.5, 0., 1.0]]]  # initial-condition sets (one IC); the chaotic regime alpha=1.1, gamma=0.87 was noted for ic (-1,0,0.5)
for N in N_vec:
    for dt in dt_vec:
        if N * dt < 5:  # time span too short to expect any method to learn something
            continue
        for sig in sig_vec:
            for ic in ic_vec:
                Nscen += 1
                scenarios[str(Nscen)] = (N, dt, sig, ic)

scenvec = list(scenarios)  # keys of the scenarios to run (all of them)

testcase = '-'  # test-case tag used in the result folder name

# ---- compared methods ---------------------------------------------------------------------------
# The first len(finsel_own) entries are produced by the own selection algorithm, the rest come from SINDy.
finselvec = ['LASSO', 'CS-$R^2$', 'CS-$p(M)$', '$p(M)$-$R^2$', 'BSR', 'SR3', 'FROLS', 'STLSQ', 'MIOSR']
finsel_own = list(range(5))
finixR2 = 1  # position of CS-R^2 in finselvec
finixpm = 2  # position of CS-p(M) in finselvec

# Plot selection: 'true' is appended last; excluded methods need a matching manual edit of labs.
finselvec2 = np.array(list(finselvec) + ['true'])
_excluded_from_plots = ('$p(M)$-$R^2$',)
finsel_plot = [lb for lb in finselvec2 if lb not in _excluded_from_plots]
labs = ['$\\mathrm{LASSO}$', '$\\mathrm{CS}$-$R^2$', '$\\mathrm{CS}$-$p(M)$', '$\\mathrm{BSR}$', '$\\mathrm{SR3}$', '$\\mathrm{FROLS}$', '$\\mathrm{STLSQ}$', '$\\mathrm{MIOSR}$', '$\\mathrm{true}$']

# ---- result accumulators ------------------------------------------------------------------------
failed_nr = []
found_allterms = dict.fromkeys(finselvec, 0)
found_truemodel = dict.fromkeys(finselvec, 0)
found_truemodel_scen = []
found_MAEs = []
found_MAEsms = []  # model size (multiplied by)
found_logMAEs = []
found_MAEranking = []
found_failures = dict.fromkeys(finselvec, 0)
found_blewfail = dict.fromkeys(finselvec, 0)

# One object slot per scenario for every SINDy method's CV search / best parameters.
_sindy_algs = [alg for algix, alg in enumerate(finselvec) if algix not in finsel_own]
sindyCVsearches = {alg: np.empty(len(scenarios), dtype=object) for alg in _sindy_algs}
sindyCVresults = {alg: np.empty(len(scenarios), dtype=object) for alg in _sindy_algs}

            
### hyper-parameters

# --- comprehensive search (CS) over fixed-size R^2 model listings ---------------------------------
nr_rsq = 2          # number of recent R^2 feature ratings a feature must be selected in to enter the inferred R^2 model
nr_deact = 2        # number of previous feature ratings a feature must be unselected in before it is deactivated
nr_evi = 2          # consecutive decreases of summed top-model evidence that would stop model-size growth
                    # (the evidence stopping check is currently commented out in the main loop)
cmin_mod = 0.75     # selection proportion above which a feature joins the inferred R^2 model
cmin_deact = 0.0    # selection proportion below which a feature is deactivated
weightfeat = True   # weight features by the criterion value when building the R^2 feature rating
usestep = False     # use step detection in the criterion for feat2mod and deactfeat (R^2 only)
nrsq = 500          # number of models kept in every fixed-size R^2 listing
topfeat = None      # number of top R^2 models used for the feature rating; None -> set later to round(p/2)
                    # when usestep is False (with usestep=True the step detection decides)
nkeep = 25          # models kept from each R^2 listing for the combined top listing used in model selection
ntop = None         # cut-off for the evidence top listing (None: same as topfeat)
nRsq_max = 8        # largest model size tried; when reached, CS stops and takes the best R^2 model of the size chosen by evidence
cosel_thresh = 1.0  # features correlated above this are treated as connected (co-selected / not co-deactivated)
constterm = True    # include a constant term in the basis

# --- basis function expansion (bfe) ---------------------------------------------------------------
maxdeg, maxdeg_term = 4, 6  # maximum degree of a single factor, maximum total degree of a term

# --- model evidence -------------------------------------------------------------------------------
heur = True       # use OLS mean and variance as known normal parameters (assumes normality)
h_scalfac = 'N'   # broadening factor for distributions; kept as a string so it can name a variable defined later (eval'd there)

# --- parallelism ----------------------------------------------------------------------------------
njobs = 1         # parallel jobs passed to modsel
njobs_sindy = 1   # parallel jobs for the SINDy CV search (with more than 1, warnings cannot be suppressed)

# for ivp solvers applied to learned models

# dynamical system parameters (defined before the reference trajectory so it is generated for the same system as the scenarios)
# pars = {'alpha':1.1, 'gamma':0.87} # chaotic for ic (-1,0,0.5)
pars = {'alpha':0.14, 'gamma':0.1} # stable limit cycle

# Long reference trajectory (generated with sig=0) from which test initial conditions are drawn.
# It gets its own name because the scenario loop rebinds `xraw` to every scenario's training data.
# A sampler reading the global `xraw` would silently draw test ICs from that (short) trajectory instead.
# NOTE(review): `pars=pars` is passed so the reference matches the scenario data (which is generated with
# pars=pars); previously datagen's default parameters were used here -- confirm that was unintended.
xattr = datagen.rabinovich(N=int(1e5), T=int(1e3), sig=0, pars=pars, ics=ic_vec[0])[6]
xraw = xattr # keep the original initial binding of the module-level name


def sample_ic():
    """Return a random initial condition near the reference trajectory ``xattr``.

    A random column (time point) of ``xattr`` is picked and independent Gaussian
    jitter with standard deviation 0.1 is added to every component.
    """
    return xattr[:, np.random.randint(xattr.shape[1])] + np.random.normal(0, 0.1, xattr.shape[0])


Tpred = 15 # duration for prediction accuracy
Npred = 100 # number of random ICs from prediction test
Nt = 1000 # number of timesteps for ivp solvers
ivp_meth = np.array(['RK45'],dtype=object) # candidate solvers ('RK45','LSODA','DOP853','Radau','BDF'); the IC test currently uses only the first entry
tol = 1e-4 
atol = tol # absolute error tolerance for solver (SINDy uses 1e-12, can lead to super long computation times for some learned models (mostly LASSO and SINDy))
rtol = tol # same, but relative error tolerance
solivp_TO = 15 # timeout in seconds for one solver call; when exceeded the attempt is aborted and counted as a failure
max_failures = 1 # attempts per test IC: if any method fails, all methods are re-solved from a new random IC (1 = no retry)
# important to note: because of that, the set of ICs might be selected for the advantage of the method causing the failures,
# ... therefore important to check how often method failed (found_failures) 
MAEcut = 10 # MAE above this counts as separate failure

# data handling (standardisation method)
stdmeth = 'normal' # 'none', 'normal', 'normglob', 'centre', 'centglob', 'unit'

def quietmean(*args, **kwargs):
    """NaN-ignoring mean that does not emit RuntimeWarnings.

    Thin wrapper around ``np.nanmean`` (all arguments are forwarded unchanged).
    The "Mean of empty slice" warning for all-NaN slices is silenced; such
    slices still evaluate to NaN.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        result = np.nanmean(*args, **kwargs)
    return result

# Result folder: encodes system, test case, step-detection mode and the main hyper-parameters,
# so runs with different settings never share stored models / IC-test results.
folpart = (f"_cminmod{cmin_mod}_cmindeact{cmin_deact}_maxf{max_failures}_TO{solivp_TO}"
           f"_tol{tol}_nkeep{nkeep}_Tpred{Tpred}_{''.join(finselvec)}/")
fol = (f"{namsys}___{testcase}___scenarios_{'with-steps' if usestep else 'no-steps'}"
       + folpart.replace('.', '-'))  # dots only replaced in the hyper-parameter part

if not os.path.exists(fol):
    os.mkdir(fol)
elif getnewmdl and savenewmdl:
    input('Really new?')  # pause before stored models in an existing folder get overwritten

# Tallies over all scenarios/components of how the CS-R^2 size search terminated
# (count_evistop is only incremented in the currently commented-out evidence-stopping check).
count_R2stop = count_evistop = count_maxitstop = 0
for nr,scen in tqdm(scenarios.items(),desc='solving/loading scenarios',disable=not silent):
    if nr in scenvec:
        
        if True:
            N, dt, sig, ics = scen
            T = N*dt
            # T = set_T
            
            # get file identifier, learned model equations will be saved into and loaded from
            filid = "mdleqs_t{0:d}-dt{1:.4f}_N{2:d}_sig{3:.4f}".format(0,dt,N,sig).replace('.','')+'_ic'+"".join([str(ic) for ic in ics[0]])
            
            
            if getnewmdl or not os.path.exists(fol+filid+'.npz'):
            
                # get artificial data
                t_mp, y_mp, wex, eqlist, powsex, x_mp, xraw, traw = datagen.rabinovich(N=N+1, T=T, sig=sig, cma_step=0, pars=pars, ics=ics, plotit=False) 
                # combine trajectories
                X = x_mp.reshape( ( x_mp.shape[0] , x_mp.shape[1]*x_mp.shape[2] ) )
                Y = y_mp.reshape( ( y_mp.shape[0] , y_mp.shape[1]*y_mp.shape[2] ) )
                ncompo = len(wex) # number of components of dynamical systems (number of ODEs)
                    
                # go through components of dynamical system
                mdleqs = np.empty([len(finsel_own)+1,ncompo],dtype=object) # one extra row for true model extracted later
                westis = np.empty([len(finsel_own)+1,ncompo],dtype=object)
                for compo in range(ncompo):
                    
                    finix = -1 # index to be incremented as we go through methods, always before new method
                    
                    # set up model class: feed with artificial data, generate kernel matrix, fix ground truth, and standardise data
                    ms = modsel(X=X, y=Y[compo,:], fractions=False, njobs=njobs, silent=silent, darkmode=False)
                    terms, nex = ms.bfe(maxdeg=maxdeg, maxdeg_term=maxdeg_term, ker='Kgen', pex=(np.array(powsex[compo]),[]), stdise='none', constterm=constterm) # stise done below
                    ms.telltruth(trueterms=(terms[nex[0]],terms[nex[1]]),w=(wex[compo],[]))
                    ms.stdise(verbose=False,method=stdmeth)
                    if not usestep and topfeat is None:
                        topfeat = int(round( 0.5*ms.p )) # this seems to be good choice if step detection is not used
                    if not usestep:
                        step_interval = ms.p # step_interval is also used for plot_modsel, put this value in case usestep==False
                    
                    # set parameters for Bayesian model evidence
                    ms.evipars['heur'] = heur # use known normal parameters (mean and variance) from OLS, assuming normality
                    if heur:
                        ms.evipars['m'] = 1 / ms.y.std(0)**2 # mode of prior for precision of noise (gamma distribution), ms.y.std(0) will be 1 due to standardisation [empirical Bayes choice]
                        ms.evipars['v'] = 0.5 # variance of prior for precision of noise (pp.plot(x,sc.stats.gamma.pdf(x,m/v+1,0,v)))
                    ms.evipars['h_scalfac'] = eval(str(h_scalfac)) # factor by which distributions are broadened
                                           
                    # set up deselection matrix to prevent deactivation of features correlated with often selected features
                    ms.calc_comat(meth='corr') # 'corr' or 'mi' (=normalised mutual information)
                    ms.coselmat(threshold=cosel_thresh) # if above threshold, features assumed to be connected, used to prevent deactivation of features connected to active features
                    if not silent:
                        ms.plot_comat() 
                                
                    # LASSO
                    finix += 1
                    nam = finselvec[finix] # name for feature rating 
                    ms.lasso(boots=1,bootsize=1,nam=nam) # run feature rating without bootstrapping, and save into rating table <nam>
                    ms.newtop(nam=nam,ntop=1) # create new top model list
                    ms.feat2mod(nam=nam,nr=1,cmin=0) # select model from most selected features in feature rating
                                 
                    # deactivate features not connected to LASSO selected features
                    # ms.deactfeat(nr=1,cmin=0,featmin=10) ### turned out to be not beneficial for identification accuracy
                    
                    # initiliase empty top model list with one model (ntop=1)
                    finix += 1 # = finixR2
                    ms.newtop(nam=finselvec[finix],ntop=1)
                        
                    ### R-squared elimination procedure for comprehensive search approach
                    # start with 1 term
                    nams_R2_fixedsize = [] # init list of fixed-sized R2 top model listings
                    nams_R2_fixedsize.append('$R^2$, 1 terms') # add name for top model listing for 1-sized models
                    ms.newtop(nam=nams_R2_fixedsize[-1],ntop=nrsq) 
                    ms.val_lnrsq(nt=1 ,presel='all') # compute R-squared (log-scale) for all models
                    if usestep:
                        step_interval = ms.p-2-1 # only look in this interval (top models) for a step
                        # (as after that many top models, all wrong features have been selected once, if nt is the correct number of features)
                    ms.mod2feat(mod=nams_R2_fixedsize[-1],topnum=topfeat,weight=weightfeat,step=usestep,step_interval=step_interval) # use best topfeat models to create feature rating from that
                    # ms.plot_featsel() # plot current feature rating
                    if not silent:
                        ms.plot_modsel(nt=step_interval) # plot last obtained top models
                    
                    # also get sum of top evis for 1-term models
                    feats_copy = ms.feats.copy() # copy active feature list to restore after evi evaluation below
                    ms.feats = np.arange(ms.fac_p*ms.p) # make all features active again, otherwise ms.evi throws an error as models in top are always of shape set here
                    topmods = ms.top[ms.ranknams[-1]].toarray()
                    evis_ = np.full(nkeep,np.nan)
                    for mix in range(min(nkeep,topmods.shape[0])): # could also just look at top one in case evi needs to be estimated
                        evis_[mix] = ms.evi( topmods[mix] , [] )[0]
                    evis = np.array([evis_.max()])
                    ms.feats = feats_copy.copy() # restore list of active features in case needed later
                    
                    # now loop through all candidate models with more terms
                    nRsq = 1 # initiliase number of terms to be considered
                    R2_success = False # initiliase R2 stepwise regression criterion
                    evi_stop = False # initialise evi-stopping criterion
                    while not R2_success and not evi_stop and nRsq<nRsq_max: # if not successfully found model from last nr Rsq listings, increase number of terms
                                
                        if nRsq >= 1+1: # if two or more feature ratings from R-squared have been created, remove features from models that were not selected
                            ms.deactfeat(nr=nr_deact,cmin=cmin_deact,featmin=nRsq_max)
                            
                        nRsq += 1 # increment number of terms in models
                        nams_R2_fixedsize.append('$R^2$, ' + str(nRsq) + ' terms') # name of model lists
                    
                        ms.newtop(nam=nams_R2_fixedsize[-1],ntop=nrsq) # next model list
                        ms.val_lnrsq(nt=nRsq ,presel='all') # compute R-squared
                        if usestep:
                            step_interval = ms.p-nRsq-1 # only look in this interval (top models) for a step
                            # (as after that many top models, all wrong features have been selected once, if nt is the correct number of features)
                        ms.mod2feat(mod=nams_R2_fixedsize[-1],topnum=topfeat,weight=weightfeat,step=usestep,step_interval=step_interval) # create feature rating from that
                        # ms.plot_featsel()
                        if not silent:
                            ms.plot_modsel(nt=step_interval)
                        
                        if nRsq >= nr_rsq+1: # if two or more feature ratings from R-squared have been created
                            ms.feat2mod(nam=finselvec[finixR2],nr=nr_rsq,cmin=cmin_mod) # try and select from last nr Rsq listings
                        # created top model listing remains empty if no model could be identified, in which case while loop continues
                        
                        # determine max model evidence for nkeep top models for this R2 round
                        feats_copy = ms.feats.copy() # copy active feature list to restore after evi evaluation below
                        ms.feats = np.arange(ms.fac_p*ms.p) # make all features active again, otherwise ms.evi throws an error as models in top are always of shape set here
                        topmods = ms.top[ms.ranknams[-1]].toarray()
                        evis_ = np.full(nkeep,np.nan)
                        for mix in range(min(nkeep,topmods.shape[0])): # could also just look at top one in case evi needs to be estimated
                            evis_[mix] = ms.evi( topmods[mix] , [] )[0][0]
                        evis = np.append(evis,evis_.max())
                        ms.feats = feats_copy.copy() # restore list of active features in case needed later
                        
                        R2_success = ms.top[finselvec[finixR2]].sum()>0 # if top model list for R2 stepwise regression is empty, then no model has been identified yet, and we keep going to increase model size (while loop)
                        if R2_success:
                            count_R2stop += 1
                        
                        # if nRsq >= 2+1: # if three or more rounds of R2 testing have been done, check if evi stopping criterion has been met
                        #     evi_stop = (evis[-1] < evis[-nr_evi-1:-1]).all()
                        #     count_evistop += 1
                        #     if evi_stop:
                        #         print('Decline in model evidence caused early stopping.')
                        
                    # plot combined Rsq results for feature selection
                    # ms.plot_featsel()
                        
                    # combine Rsq for top model listing and plot
                    ms.comb_top(topnams=nRsq,topkeeps=nkeep,nam='top $R^2$',sortit=True) # 'topnams=nRsq' makes that las nRsq top listings are used, which are the fixe-sized R2 listings, could also use 'topnams=nams_R2_fixedsize'
                    if not silent:
                        ms.plot_modsel()
                    # ms.plot_modprop()
                                        
                    # compute model evidence for combined top model listing
                    finix += 1
                    ms.newtop(nam=finselvec[finix],ntop=50) # create top model listing where top 25 are kept
                    ms.val_evi(presel='top $R^2$')
                    if not silent:
                        ms.plot_modsel(nam=finselvec[finix])                   
                    
                    # if CS-R2 did not find any models (i.e. ms.top[finselvec[1]] has no non-zero elements), get best Rsq model of size selected by evi
                    # ... for model size selected by evi all models in top model listing are considered, so the size of the best model in the evi top listing computed above is selected
                    if not R2_success: 
                        evisize = ms.top[finselvec[finixpm]][0,:].sum() # get model size (number of equations) selected by evi
                        ms.top[finselvec[finixR2]] = ms.top['$R^2$, ' + str(evisize) + ' terms'][0] # set top Rsq model of that model size as DaC (SCS) result
                        count_maxitstop += 1
                        if not silent:
                            print('Iterative R2 did not identify model size, used model evidence instead.')
                                                            
                    # pick overall best R^2 model by evi among all best R^2 models for each model size category
                    ms.comb_top(topnams=nams_R2_fixedsize,topkeeps=1,nam='top-1 $R^2$',sortit=True) # combine one best model per model size
                    finix += 1
                    ms.newtop(nam=finselvec[finix],ntop=nRsq) # create new top model listing with defined name in finselvec, ...
                    ms.val_evi(presel='top-1 $R^2$') # ..., where the top models from the combined top-1 are evaluated and sorted with evi
                    # ms.plot_modsel()  
                    
                    # bidirectional stepwise regression
                    finix += 1
                    nam = finselvec[finix] # name for feature rating 
                    ms.swsel(crit='evi',nam=nam)
                    ms.newtop(nam=nam,ntop=1) # create new top model list
                    ms.feat2mod(nam=nam,nr=1,cmin=0) # select model from most selected features in feature rating
                                        
                    # collect equations and extract true solution
                    for selix,finsel in enumerate([finselvec[ix] for ix in finsel_own]):
                        mdleqs_, westis_ = ms.gettop(nam=finsel,maxeqs=5,print_plot=False) # get top models (includes true model as first entry)
                        mdleqs[selix,compo], westis[selix,compo] = mdleqs_[1:], westis_[1:] # cut out true model, as true model extracted below separately
                    mdleqs_, westis_ = ms.gettop(nam=finselvec[0],maxeqs=5,print_plot=False) # (nam is arbitrary is only interested in true model which is always put out first)
                    mdleqs[selix+1,compo], westis[selix+1,compo] = mdleqs_[:1], westis_[:1] # extract true model and add as last entry in mdleqs and westis
            
            
                # add SINDy solutions
                for sdsol in [finsel for fix,finsel in enumerate(finselvec) if not fix in finsel_own]:
                    sindy_sol = getattr(sds, sdsol)
                    sindy_mdl, CVsearch = sindy_sol(x=xraw.T, t=traw, njobs=njobs_sindy, silent=silent) # can change parameters here if passed to sindy_sol (although would be same for all optimisers), so better CV optimised individually in functions
                    # note that njobs>1 in the above triggers warnings that cannot be supressed in any way
                    sindyCVsearches[sdsol][int(nr)-1] = CVsearch.cv_results_
                    sindyCVresults [sdsol][int(nr)-1] = CVsearch.best_params_
                    
                    # add found equations and weights
                    sindy_terms = np.array(sindy_mdl.get_feature_names(),dtype=object)
                    sindy_coefs = sindy_mdl.coefficients()
                    mdleqs = np.vstack(( mdleqs[:-1,:] , np.empty(ncompo,dtype=object) , mdleqs[-1:,:] )) # need to first make space for SINDy
                    westis = np.vstack(( westis[:-1,:] , np.empty(ncompo,dtype=object) , westis[-1:,:] )) 
                    for compo in range(ncompo):
                        sindy_equ = "+".join( ["w_{0:d}*{1:s}".format(ix+1,eq.replace(' ','*')) for ix,eq in enumerate(sindy_terms[np.abs(sindy_coefs[compo,:])>0])] )
                        if sindy_equ=='':
                            sindy_equ='0'
                        mdleqs[-2,compo] = np.array( [ sy.parsing.sympy_parser.parse_expr( sindy_equ.replace('^','**') ) ] , dtype=object ) # old: [ np.array([eq] , dtype=object) for eq in sindy_mdl.equations() ] # otherwise I don't have the 2D object array of 1D object arrays structure
                        westis[-2,compo] = np.array( [ sindy_coefs[compo,np.abs(sindy_coefs[compo,:])>0] ] , dtype=object )
                                
                # save learned equations    
                if savenewmdl:
                    np.savez(fol+filid, mdleqs=mdleqs,westis=westis,T=T,dt=dt,N=N,sig=sig,nr_rsq=nr_rsq,nr_deact=nr_deact,cmin_mod=cmin_mod,usestep=usestep,scenvec=scenvec,scenarios=scenarios,finsel_own=finsel_own,
                             cmin_deact=cmin_deact,weightfeat=weightfeat,nrsq=nrsq,topfeat=topfeat,nkeep=nkeep,ntop=ntop,nRsq_max=nRsq_max,constterm=constterm,sindyCVsearches=sindyCVsearches,sindyCVresults=sindyCVresults,
                             ncompo=ncompo,finselvec=finselvec,finselvec2=finselvec2,finsel_plot=finsel_plot,labs=labs,wex=wex,eqlist=eqlist,powsex=powsex,t_mp=t_mp,y_mp=y_mp,x_mp=x_mp,X=X,Y=Y,xraw=xraw,traw=traw,
                             count_R2stop=count_R2stop,count_evistop=count_evistop,count_maxitstop=count_maxitstop)
            
                # if not getting new model, load
            else:
                dat_mdl = np.load(fol+filid+'.npz',allow_pickle=True)
                mdleqs = dat_mdl['mdleqs']
                westis = dat_mdl['westis']
                finselvec = dat_mdl['finselvec']
                xraw = dat_mdl['xraw']
                traw = dat_mdl['traw']
                T = dat_mdl['T']
                dt = dat_mdl['dt']
                ncompo = dat_mdl['ncompo'].item()
                X = dat_mdl['X']
                t_mp = dat_mdl['t_mp']
                sindyCVsearches = dat_mdl['sindyCVsearches'].item()
                sindyCVresults = dat_mdl['sindyCVresults'].item()
                scenvec = dat_mdl['scenvec']
                scenarios = dat_mdl['scenarios'].item()
                finsel_own = dat_mdl['finsel_own']
                count_R2stop = dat_mdl['count_R2stop']
                count_evistop = dat_mdl['count_evistop']
                count_maxitstop = dat_mdl['count_maxitstop']
            
            
            
            
            ##############################################################################################################################
            ##############################################################################################################################
            ##############################################################################################################################
            # evaluation of learned models
            ##############################################################################################################################
            ##############################################################################################################################
            ##############################################################################################################################
            
            
            # initial value problem integrator methods for solving learned models, to be tried in this order
            compos = ['$x$','$y$','$z$'] 
            
            cmap = cm.get_cmap('gist_rainbow')
            cmap.N = len(finsel_plot)
            try:
                cmap(0) # first call throws error, no idea why
            except:
                pass
            
            # initialise
            featvec = ['x_'+str(i+1) for i in range(ncompo)]
            fpred = {} 
            eqpred = {}
            fpredode = {}    
            ic = {}
            tvec = {}
            
            # get true and inferred components of ode function
            for six,pred in enumerate(finsel_plot):
                fpred[pred] = np.empty(ncompo,dtype=object)
                eqpred[pred] = np.empty(ncompo,dtype=object)
                selix = np.where(pred==finselvec2)[0].item() # index to extract model corresponding to pred
                for ix in range(ncompo):
                    eqpred_ = mdleqs[selix,ix][0] # only use first learned model equations
                    wpred = westis[selix,ix][0] 
                    for wix in range(len(wpred)): # substitute estimated weights w_1, w_2, ... into the symbolic equation
                        eqpred_ = eqpred_.subs('w_'+str(wix+1),wpred[wix])
                    fpred[pred][ix] = sy.lambdify(featvec, eqpred_)
                    eqpred[pred][ix] = eqpred_
                # Right-hand side f(t, x) for solve_ivp. This method's component functions are bound through a default
                # argument. A closure that reads `pred` would be late-binding: it would evaluate whatever method `pred`
                # refers to at call time (after this loop that is the last entry of finsel_plot, i.e. 'true').
                # Every component receives all state entries, so no hand-written 3-component signature is needed.
                fpredode[pred] = lambda tt, xx, fcomp=fpred[pred]: np.array([ f(*xx) for f in fcomp ])
            
            
            ####################################################################################################################################
            ####################################################################################################################################
            
            
            fil_ictest = fol+filid+'_ic-test'+str(Npred)#+'_sindy-opt'+str(sindy_opt_ix)
            if new_ictest or not os.path.exists(fil_ictest+'.npz'):
                
                solpreds = {}
                xpreds = {}
                failures = {s:0 for s in finsel_plot}
                for ix in tqdm(range(Npred),desc='evaluating ICs ('+nr+')',disable=silent):
                    
                    plt = 'prediction'+str(ix)
                    tvec[plt] = np.linspace(0,Tpred,Nt)
                  
                    ic[plt] = sample_ic()
                    solpreds[plt] = {}
                    xpreds[plt] = {}
                    failcount = 0
                    failed = True # still sometimes the random IC gives trouble, in that case it must try again with new random IC
                    while failed and failcount<max_failures:
                        failed = False # set to False once entered while, if one method fails it's set to True, only if no failure for any method it stays False
                        for six,pred in enumerate(finsel_plot):
                            try:

                                xpreds[plt][pred] = ( tvec[plt] , np.full([ncompo,Nt],np.nan) ) # will remain nan if no solution found, otherwise overwritten with solution

                                with timeout(seconds=solivp_TO):
                                    solpreds[plt][pred] = si.solve_ivp(fpredode[pred], [tvec[plt][0], tvec[plt][-1]], ic[plt], method=ivp_meth[0], dense_output=True , atol=atol, rtol=rtol)
                                    xpreds[plt][pred] = ( tvec[plt] , solpreds[plt][pred].sol(tvec[plt]) )
                            
                            except Exception as errmess: # deliberate best effort: any solver error or timeout counts as a failure of this method
                                
                                failed = True
                                failures[pred] += 1
                                if not silent:
                                    print('solving model from',pred,'failed for random IC. Will try all again with new IC, trial',failcount)
                                    print(errmess)
                        
                        if failed:
                            failcount += 1
                            # Draw a replacement IC only if another attempt follows, so the stored ic[plt] always belongs
                            # to the stored xpreds[plt]. Use the same attractor-based sampler as for the first IC.
                            if failcount < max_failures:
                                ic[plt] = sample_ic()

                np.savez(fil_ictest, Tpred=Tpred, Npred=Npred, Nt=Nt, xpreds=xpreds, ic=ic, tvec=tvec, failures=failures)
                 
            else:
                
                with np.load(fil_ictest+'.npz',allow_pickle=True) as dat:
                    Tpred = dat['Tpred'].item()
                    Npred = dat['Npred'].item()
                    Nt = dat['Nt'].item()
                    xpreds = dat['xpreds'].item()
                    ic = dat['ic'].item()
                    tvec = dat['tvec'].item()
                    failures = dat['failures'].item()
                    
            maets = {}
            MAEs = {}
            for prix, pred in enumerate(xpreds['prediction0'].keys()):
                if pred in finsel_plot: ### take out methods to only have selection made by finsel_plot in plots
                    maets_ = np.full([Npred,Nt],np.nan)
                    MAEs_ = np.full([Npred,ncompo],np.nan)
                    for plix,plt in enumerate(xpreds.keys()):
                        truth = xpreds[plt]['true'][1].copy()
                        if pred in finsel_plot: # and not pred=='true':
                            maets_[plix,:] = quietmean( np.abs( truth - xpreds[plt][pred][1] ) , 0 ) # mean absolute error across all components time-resolved
                            MAEs_[plix,:] = quietmean( np.abs( truth - xpreds[plt][pred][1] ) , 1 ) # mean absolute error across time for all components separately
                    maets[pred] = maets_
                    if not pred=='true':
                        MAEs[pred] = MAEs_
            
            # get number of terms
            ntrue = np.empty([ncompo,len(finselvec2)],dtype=int)
            nterms = np.empty([ncompo,len(finselvec2)],dtype=int)
            for ix,cmp in enumerate(range(ncompo)): # component
                _, ntrue[ix,:], nterms[ix,:] = uni_eq( mdleq=mdleqs[:,cmp] , nt_max=5 , translt={'1':'','x':'x_{1}','y':'x_{2}','z':'x_{3}'}, compl=True)
            
            if not silent: # plot MAE bar chart
                fig, ax = pp.subplots()
                fig.set_facecolor(bgcol)
                ax.set_facecolor(bgcol)
                for vix,val in enumerate(MAEs.values()):
                    bot = 0
                    for cmp in range(ncompo):
                        col = (cmp+1)/ncompo * np.array(cmap(vix))
                        col[-1] = 1
                        b = ax.bar(vix, quietmean(val[:,cmp]), bottom=bot, color=col)
                        ax.bar_label(b,labels=[compos[cmp]], label_type='edge',padding=-11)
                        ax.bar_label(b,labels=["$\;\;\;\,{0:d}\,({1:d})$".format(ntrue[cmp,-1]-ntrue[cmp,vix],nterms[cmp,vix])], label_type='edge',padding=-22,fontsize=9)
                        bot += quietmean(val[:,cmp])
                    ax.boxplot(val.sum(1), positions=[vix-0.25], flierprops={'ms':0.4})
                ax.set_xticks( np.arange(len(finsel_plot[:-1])) )
                ax.set_xticklabels( labs[:-1] , fontsize=fs-5 )
                if ax.get_ylim()[-1]>3:
                    ax.set_ylim([0,3])
                ax.set_ylabel('$\\mathrm{MAE}$')
                ax.set_title("$N={0}, dt={1}, \\sigma={2}$".format(N,dt,sig))
                pp.show()
            
            # determine number of true terms missed and number of extra terms added
            foundtruescen = []
            for pix,pred in enumerate(finsel_plot[:-1]):
                foundall = (ntrue[:,-1]==ntrue[:,pix])
                found_allterms[pred] += foundall.all()
                foundtrue = np.logical_and( ntrue[:,-1]==ntrue[:,pix] , nterms[:,pix]==0 )
                found_truemodel[pred] += foundtrue.all()
                foundtruescen.append( foundtrue.sum() )
            
            # get table of true model found per scenario
            found_truemodel_scen.append( foundtruescen )
            
            # get average MAE and average MAE over model size overall
            MAE_mean = [quietmean(v[(v<MAEcut).all(1),:]) for v in MAEs.values()] # only take those MAEs that do not blow MAEcut in any component. if all blow it, a NaN is produces, which is then taken out. all blewn MAEs count as separate failure (blewfail)
            found_MAEs.append( MAE_mean )
            MAEms_mean = [quietmean(v[(v<MAEcut).all(1),:]) * sum( [len(str(sz[0]).split('+')) for sz in mdleqs[vix]] ) for vix,v in enumerate(MAEs.values())]
            found_MAEsms.append( MAEms_mean )
            # also get average log-MAE overall
            logMAE_mean = [quietmean(np.log(v)) for v in MAEs.values()]
            found_logMAEs.append( logMAE_mean )
            
            # get ranking in terms of MAE
            MAEuni, cnt = np.unique( MAE_mean , return_counts=True) # sorted scores in ascending order, together with how often these score occured
            rank = 0 # initiliase rank
            ranks = np.full(len(finsel_plot)-1,np.nan)
            for e,r in zip(MAEuni,cnt): # go through occured MAEs and their counts
                ranks[ np.where(MAE_mean==e)[0] ] = rank # all methods with current MAE e get that rank
                rank += r # increment current rank for MAE e according to counts, that way the lower (better) rank is used for methods with same MAE
            found_MAEranking.append( list(ranks) )
            
            # count failures
            found_failures['true'] = 0
            for pred in finsel_plot:
                found_failures[pred] += failures[pred]
            for pred in finsel_plot :
                if not pred=='true': # do not use true model here and just add below
                    found_blewfail[pred] += (MAEs[pred]>=MAEcut).any(1).sum() # if MAE in any component above MAEcut, count as failure due to blown MAE
            found_blewfail['true'] = 0 # per definition (MAE from distance from simulation of true model)
            
            if not silent: # plot average time evolution of MAE(t) across different ICs
                cmap_ = cmap.copy() # make copy of used colorbar
                cmap_.N = cmap.N-1 # reduce number of colors by one (as don't want to show 'true'), effectivly removes last color
                fig, ax = pp.subplots(2,1, gridspec_kw={'height_ratios': [10, 1]}, constrained_layout = True)
                fig.set_facecolor(bgcol)
                ax[0].set_facecolor(bgcol)
                vix = -1
                for pred,val in maets.items():
                    vix += 1
                    if not pred=='true':
                        col = np.array(cmap(vix))
                        col[-1] = 1.
                        ax[0].plot(tvec[plt],quietmean(val,0),color=col,lw=1.7)
                if ax[0].get_ylim()[-1]>1:
                    ax[0].set_ylim([0,1])
                ax[0].set_ylabel('$t$')
                ax[0].set_ylabel('$\\mathrm{MAE}(t)$')
                ax[0].set_title("$N={0}, dt={1}, \\sigma={2}$".format(N,dt,sig))
                cbar = cb.ColorbarBase(ax[1], ticks=np.linspace(0,1,len(finsel_plot))[:-1]+1/(len(finsel_plot)-1)/2, cmap=cmap_, orientation='horizontal')
                cbar.ax.set_xticklabels(labs[:-1], fontsize=fs-5)
                pp.show()
        
        if not silent: 
            print("Completed Scenario with N={0:d}, T={1:.1f}, sigma={2:.4f}".format(N, T, sig))


    
# --- Text report aggregated over all scenarios ---------------------------------
# Stack the per-scenario rows (one row per scenario, one column per method).
found_MAE_scen = np.array(found_MAEs)
found_MAEms_scen = np.array(found_MAEsms)

# Number of scenarios that produced results; the percentages below are relative to it.
nscen = len(scenvec)-len(failed_nr)

# Percent suffix in LaTeX table syntax (the text \!\%). Written with an escaped
# backslash: the former '\%' is an invalid escape sequence and triggers a warning.
pct = '\\!\\%'

print('\n\nreport from all test scenarios:')
print('\nscenarios failed:',failed_nr)
print('\nscenarios tested:',nscen)
print('\nnumber of times all true terms found (in %):')
for k, v in found_allterms.items():
    print(k, np.round(v/nscen*100,1), pct)
print('\nnumber of times exact true model found (in %)::')
for k, v in found_truemodel.items():
    print(k, np.round(v/nscen*100,1), pct)
# NOTE(review): method labels are taken from finselvec, while the columns follow the
# key order of the per-scenario MAEs dict -- confirm both orders agree.
print('\naverage MAE (per component) of methods (lower=better):')
for k, v in zip(finselvec, quietmean(found_MAE_scen,0)):
    print(k, np.round(v,2))
print('\naverage ranking of methods in terms of MAE (lower=better):')
for k, v in zip(finselvec, quietmean(np.array(found_MAEranking),0)):
    print(k, np.round(v+1,2))  # ranks are stored 0-based; report them 1-based
print('\nnumber of failures solving models with',str(Npred),'random ICs (in %):')
for k, v in found_failures.items():
    print(k, np.round(v/nscen/Npred*100,2), pct)
print('\nnumber of solutions for', str(Npred),'random ICs exceeding average MAE of',str(MAEcut),'(in %):')
for k, v in found_blewfail.items():
    print(k, np.round(v/nscen/Npred*100,2), pct)
if not silent:
    print('\n',fol)

# Share of R2 iterations that stopped on their own vs. those forced to stop at the
# iteration cap. Guarded so that an empty count does not divide by zero and abort
# the script before the summary figures below are drawn.
n_R2runs = count_R2stop + count_maxitstop
print('\n percentage of R2 iteration terminating successfully:')
if n_R2runs > 0:
    print( np.round(count_R2stop / n_R2runs*100,2), pct)
else:
    print('n/a')
print('\n percentage of R2 iteration reached maximum iteration and forced stop:')
if n_R2runs > 0:
    print( np.round(count_maxitstop / n_R2runs*100,2), pct)
else:
    print('n/a')



# --- Setup for the summary figures across scenarios ------------------------------
# Methods for which MAEs were recorded (key order of the last scenario's MAEs dict).
finselfromMAE = np.array(list(MAEs.keys()))
# Methods shown in the summary figures. Methods can be excluded here, but `labs`
# below must then be adjusted by hand so that labels stay aligned with finsel_plot.
finsel_plot = [lb for lb in finselfromMAE if lb not in ['$p(M)$-$R^2$']]
# One discrete colour per plotted method. `resampled` builds a fresh colormap with
# the requested number of entries. Assigning `cmap.N` after construction (as done
# before) left the colormap's under/over/bad indices pointing past the smaller lookup
# table -- most likely why the first call raised and needed a bare try/except.
cmap = cm.get_cmap('gist_rainbow').resampled(len(finsel_plot))
labs = ['$\\mathrm{LASSO}$', '$\\mathrm{CS}$-$R^2$', '$\\mathrm{CS}$-$p(M)$', '$\\mathrm{SR}$', '$\\mathrm{SR3}$', '$\\mathrm{FROLS}$', '$\\mathrm{STLSQ}$', '$\\mathrm{MIOSR}$'] 

# x-axis labels for the scenario parameters N, dt, sigma and IC (same order as the scenario tuples)
scpars = ['$N$','$\\Delta t$','$\\sigma$','']

# optional y-limits per scenario parameter for the MAE (ylims1) and N_model*MAE (ylims2) figures; None = autoscale
ylims1 = None ; ylims2 = None ; showfliers = False
# ylims1 = [ [0,3.5] , [0,7] , [0,4] , [0,2] ] ; ylims2 = [ [0,50] , [0,50] , [0,50] , [0,50] ] ; showfliers = True # MAE

fixedIC = 0 # could also compare ICs (multipl in ic_vec and set to None), but since no real difference, I set 0 here which suppresses comparison for ICs

def _style_boxplot(boxpl, col):
    """Colour a single-box boxplot (the dict returned by ``Axes.boxplot``) with *col*.

    *col* is an RGBA numpy array; the median line is drawn in a darker shade of it.
    """
    col_dark = [0.5,0.5,0.5,1.] * col  # halve RGB, keep alpha
    box = boxpl['boxes'][0]
    box.set_facecolor(bgcol)
    box.set_edgecolor(col)
    box.set_linewidth(2)
    # a single box has exactly two whiskers and two caps
    for line in boxpl['whiskers'] + boxpl['caps']:
        line.set_color(col)
        line.set_linewidth(1.7)
    boxpl['medians'][0].set_color(col_dark)
    boxpl['medians'][0].set_linewidth(4)


def _plot_cut_figure(groups, cut, cutix, ylabel, ylims=None):
    """Draw one figure of per-method boxplots grouped by the values of one scenario parameter.

    groups : list with one entry per value in *cut*; each entry is a list of per-scenario
             rows whose columns follow the order of ``finselfromMAE``.
    cut    : the parameter values (used as x tick labels when there are at least two).
    cutix  : index of the parameter in the scenario tuple; selects the x label from
             ``scpars`` and the y limits from *ylims*.
    ylabel : y-axis label.
    ylims  : optional list of y limits indexed by *cutix*; None keeps autoscaling.

    Uses the module-level plot settings finsel_plot, finselfromMAE, cmap, labs, bgcol,
    fs and showfliers.
    """
    lab_posvec = np.linspace(-0.4,0.4,len(finsel_plot))  # x offset of each method within a group
    fig, ax = pp.subplots(2,1, gridspec_kw={'height_ratios': [10, 1]}, constrained_layout = True)
    fig.set_facecolor(bgcol)
    ax[0].set_facecolor(bgcol)
    for pl, group in enumerate(groups):
        if not group:
            continue  # no scenario has this parameter value; column indexing would fail on an empty array
        data = np.array(group)
        for ix, finsel in enumerate(finsel_plot):
            fix = np.where(finsel==finselfromMAE)[0].item()  # column of this method in data
            col = np.array(cmap(ix))
            flierprops = dict(marker='x', markerfacecolor=bgcol, markersize=5, markeredgecolor=col, markeredgewidth=1.1)
            vals = data[:,fix]
            xpos = pl + lab_posvec[ix]
            boxpl = ax[0].boxplot( vals[~np.isnan(vals)], positions=[xpos], widths=[0.08], showfliers=showfliers, flierprops=flierprops, patch_artist=True ) # patch_artist allows styling the box patch
            _style_boxplot(boxpl, col)
            # mean (NaN-tolerant) as a filled marker on top of the box
            ax[0].plot( xpos, quietmean(vals), 'o', markersize=7, markeredgewidth=1.4, markeredgecolor=[0,0,0], color=col)
    ax[0].set_xlabel(scpars[cutix])
    ax[0].set_ylabel(ylabel)
    if len(cut)>=2:
        ax[0].set_xticks(range(len(cut)))
        ax[0].set_xticklabels(cut)
        for vlinpos in np.arange(0,len(cut)+1):
            ax[0].axvline(vlinpos-0.5,color=[0.5,0.5,0.5],lw=1.4)  # separators between groups
        ax[0].set_xlim([-0.6,len(cut)-1+0.6])
    else:
        # a single group: hide the meaningless x ticks and labels
        ax[0].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        ax[0].set_xlim([-0.45,0.45])
    if ylims is not None:
        ax[0].set_ylim(ylims[cutix])
    # method legend as a discrete horizontal colorbar, one tick centred on each colour
    cbar = cb.ColorbarBase(ax[1], ticks=np.linspace(0,1,len(finsel_plot)+1)[:-1]+1/(len(finsel_plot))/2, cmap=cmap, orientation='horizontal')
    cbar.ax.set_xticklabels(labs, fontsize=fs-5)
    pp.show()


# For each scenario parameter (N, dt, sigma, IC) group the scenarios by its value and
# plot MAE, N_model*MAE and identification accuracy per method.
for cutix, cut in enumerate([N_vec, dt_vec, sig_vec, ic_vec]):

    MAE_cut = [[] for _ in cut]
    MAEms_cut = [[] for _ in cut]
    true_cut = [[] for _ in cut]
    cut_arr = np.array(cut)

    for scix, sc in enumerate(scenarios.values()):
        if fixedIC is None or sc[-1]==ic_vec[fixedIC]:
            # element-wise comparison; extra dimensions (IC entries are nested lists)
            # are collapsed so a value only matches when all its elements match
            idvec = sc[cutix]==cut_arr
            while idvec.ndim > 1:
                idvec = idvec.all(1)
            pos = np.where(idvec)[0].item()
            # NOTE(review): assumes found_* hold exactly one row per scenario, in the order
            # of scenarios.values() -- confirm failed scenarios still append a row.
            MAE_cut[pos].append(found_MAE_scen[scix])
            MAEms_cut[pos].append(found_MAEms_scen[scix])
            true_cut[pos].append(found_truemodel_scen[scix])

    _plot_cut_figure(MAE_cut, cut, cutix, '$\\mathrm{MAE}$', ylims1)
    _plot_cut_figure(MAEms_cut, cut, cutix, '$N_{\\mathrm{model}}\\cdot\\mathrm{MAE}$', ylims2)
    _plot_cut_figure(true_cut, cut, cutix, 'identification accuracy')

