
from chs import modsel, uni_eq
import datagen
import numpy as np
from pathlib import Path
import matplotlib.pyplot as pp
from matplotlib import colormaps as cm
from matplotlib import colorbar as cb
import os
from tqdm import tqdm

# opaque white (RGBA) background used for every figure and axes in this script
bgcol = [1, 1, 1, 1]
pp.style.use('classic')

# base font size for all plot elements; fstxt is used for text drawn inside plots
fs = 21
fstxt = 14

# apply the font sizes to matplotlib's rc settings, one (group, setting) pair at a time
for rc_group, rc_setting in (
    ('font', {'size': fs}),             # default text size
    ('axes', {'titlesize': fs}),        # axes title
    ('axes', {'labelsize': fs}),        # x and y axis labels
    ('xtick', {'labelsize': fs - 2}),   # x tick labels
    ('ytick', {'labelsize': fs - 2}),   # y tick labels
    ('legend', {'fontsize': fs - 2}),   # legend entries
    ('figure', {'titlesize': fs - 2}),  # figure title
):
    pp.rc(rc_group, **rc_setting)

njobs = 1 # number of nodes for parallel computing
rndpoly = True # new random polynomial each run? (if True, data is regenerated even when a data file already exists)
Npoly = 100 # number of random polynomials to go through

mtrue = 2 # true model size (number of non-zero weights in true polynomials)

saveresults = True # write one result file per polynomial
silent = False # if False, diagnostic plots and messages are produced

# artificial data parameters
set_maxdeg_fac = 2 # maximum degree (individual factors)
set_maxdeg_term = 4 # maximum degree (terms)
set_n = 3 # number of features in X
set_m1 = mtrue # = true m = number of non-zero terms in the numerator
set_m2 = 0 # number of non-zero terms in the denominator (if zero, the denominator is set to 1 as in usual linear regression)
set_sig = 0.01 # noise level

# number of data points sampled depends on the number of terms in the true model
_N_BY_M1 = {2: 20, 3: 65, 4: 95}
if set_m1 not in _N_BY_M1:
    # fail here with a clear message instead of a NameError on set_N further down
    raise ValueError('no sample size defined for set_m1=' + str(set_m1)
                     + ' (supported: ' + str(sorted(_N_BY_M1)) + ')')
set_N = _N_BY_M1[set_m1]

# rational (fraction) models are only needed if the denominator has terms
if set_m2 == 0:
    fractions = False
elif set_m2 >= 1:
    fractions = True
else:
    # fail here with a clear message instead of a NameError on fractions further down
    raise ValueError('set_m2 must be 0 or >= 1, got ' + str(set_m2))

### hyper-parameters
# for comprehensive search (CS) approach
nr_rsq = 2 # specifies how often a feature must be selected in Rsq feature rating to be selected for inferred model
nr_deact = 2 # specifies in how many previous feature ratings a feature must not be selected to be deactivated
nr_evi = 2 # specifies how many times in a row sum of model evidence of top models must reduce for model size incrementation to terminate (only referenced by the commented-out evidence stopping rule in the main loop)
cmin_mod = 0.75 # selection threshold above which feature is selected for inferred Rsq model
cmin_deact = 0.0 # selection threshold below which features are deactivated
weightfeat = True # weight features with criterion for feature selection rating? (here for Rsq)
usestep = False # use step detection in criterion for feat2mod and deactfeat (here for Rsq)
nrsq = 500 # how many models to keep in Rsq listings
topfeat = None # specifies how many top Rsq models shall be used for feature ranking 
# (if None, half of the number of basis functions (=p/2) is chosen later if usestep=False; if usestep=True, topfeat is chosen from step detection)
nkeep = 25 # how many models of each Rsq listing are kept for the top listing used in model selection
ntop = None # cut top listing for model evidence to this number of models [if None, it is derived later from p/2]
nRsq_max = 8 # maximum number of terms to be considered; if reached, CS aborts and the best Rsq model of the size selected by model evidence is used
cosel_thresh = 1. # features with correlation larger than this are considered connected, and (if used) co-selected or not co-deselected 
constterm = True # include constant term

# for bfe
maxdeg = 4 # maximum degree (individual factors)
maxdeg_term = 6 # maximum degree (terms)

# for model evidence
heur = True # use known normal parameters (mean and variance) from OLS, assuming normality
h_scalfac = 'set_N' # factor by which distributions are broadened (a string is passed through eval() later, so it may name a variable defined in this script)

# data handling (standardisation method)
stdmeth = 'normal' # 'none', 'normal', 'normglob', 'centre', 'centglob', 'unit'

testcase = '-' # tag embedded in the results folder name to tell stored results apart

# results folder; its name encodes the test-case tag, the true model size,
# the deactivation threshold and whether step detection is used
fol = 'polys_m'+'___'+testcase+'___'+str(set_m1)+'_deact'+str(cmin_deact).replace('.','')+usestep*'_steps'+'/'
# exist_ok avoids the check-then-create race of "if not exists: mkdir"
os.makedirs(fol, exist_ok=True)
# file name stem for keeping results (encodes the artificial data parameters)
fil = fol + "{0:d}{1:d}_{2:d}{3:d}{4:d}_{5:d}_{6:.2f}".format(set_maxdeg_fac,set_maxdeg_term,set_n,set_m1,set_m2,set_N,set_sig)
# result file of the first polynomial; its presence decides whether labels are created or reloaded
first_resfile = fil + '__' + str(0) + '__res' + '.npz'
if not Path(first_resfile).is_file(): # if no polynomial results have been saved yet, start from scratch
    # create names for feature ratings, used below in comprehensive search approach
    finselvec = ['LASSO', 'LARS', 'CS-$R^2$', 'CS-$p(M)$', '$p(M)$-$R^2$', 'BSR']#, 'SR3', 'FROLS', 'STLSQ', 'MIOSR'] # LASSO R2sel evi SWR-evi SR3 FROLS STQLS MIOSR
    finsel_own = list(range(6)) # indices into finselvec of the selection criteria embedded in the own algorithm (thereafter SINDy would be used)
    finixR2 = 2 # position of 'CS-$R^2$' in finselvec
    finixpm = 3 # position of 'CS-$p(M)$' in finselvec
else:
    # reuse the labels stored with earlier results so that indices stay consistent;
    # finixR2 and finixpm are read from the per-polynomial result files in the main loop
    with np.load(first_resfile) as dat:
        finselvec = dat['finselvec']
        finsel_own = dat['finsel_own']

# for plots
finselvec2 = finselvec # the ground-truth label is deliberately not appended here: np.array(list(finselvec) + ['true'])
hidden_methods = ('$p(M)$-$R^2$',) # methods left out of the box plots; labs below must be adjusted manually to match
finsel_plot = [name for name in finselvec2 if name not in hidden_methods]
labs = ['$\\mathrm{LASSO}$', '$\\mathrm{LARS}$', '$\\mathrm{CS}$-$R^2$', '$\\mathrm{CS}$-$p(M)$', '$\\mathrm{BSR}$']


# single component only; kept for compatibility with learning nonlinear dynamics
ncompo = 1
compo = 0

# equations and estimated weights per method; one extra row for the true model extracted later
eq_shape = [len(finsel_own)+1, ncompo]
mdleqs = np.empty(eq_shape, dtype=object)
westis = np.empty(eq_shape, dtype=object)

n_methods = len(finselvec)
fndcnt = np.zeros(n_methods) # count how often correct model found in top listing
fnd1st = np.zeros(n_methods) # count how often top model (1st rank) is correct model
fndrnk = np.zeros(n_methods) # add up ranks to have a measure of how far up in final top listing correct model is placed

# per-polynomial numbers of true / extra terms found, one list per polynomial
found_ntrue, found_nextra = [], []

# counters for how the R2 iteration terminated
count_R2stop = count_evistop = count_maxitstop = 0
# create Npoly random polynomials, data from it, and try to learn polynomial from this data, save progress in separate files
# Main loop: for each of the Npoly polynomials either load previously stored results or
# generate data, run all selection methods and store the results; afterwards update the
# success statistics. Note: with disable=not silent the progress bar is shown only when silent is True.
for pix in tqdm(range(Npoly),desc='solving/loading polynomials (true size ' +str(mtrue)+ ')',disable=not silent):
    
    if Path(fil + '__' + str(pix) + '__res' + '.npz').is_file(): # load available results
    
        # the context manager closes the archive once all entries have been read
        with np.load(fil + '__' + str(pix) + '__res' + '.npz',allow_pickle=True) as dat:
            pix0 = dat['pix']
            finselvec = dat['finselvec']
            finsel_own = dat['finsel_own']
            finixR2 = dat['finixR2']
            finixpm = dat['finixpm']
            exfoundix = dat['exfoundix'].item() # stored as 0-d object array, .item() recovers the dict
            mdleqs = dat['mdleqs']
            westis = dat['westis']
            count_R2stop = dat['count_R2stop']
            count_evistop = dat['count_evistop']
            count_maxitstop = dat['count_maxitstop']

    else:
            
        # generate and save new artificial data
        if not Path(fil + '__'+str(pix) + '.npz').is_file() or rndpoly:
            K, y, wex, vex, terms, exix1, exix2, X = datagen.nonlinreg_impl( maxdeg=set_maxdeg_fac, maxdeg_term=set_maxdeg_term, n=set_n, m1=set_m1, m2=set_m2, N=set_N, sig=set_sig, seed=None)
            np.savez(fil + '__'+str(pix),X=X,y=y,terms=terms,exix1=exix1,exix2=exix2,wex=wex,vex=vex)
        
        # use data from file (archive is closed again as soon as the model selector has been set up)
        with np.load(fil + '__'+str(pix)+'.npz') as dat:
            ms = modsel(X=dat['X'], y=dat['y'], njobs=njobs, fractions=fractions, silent=silent, darkmode=False)
            ms.bfe(maxdeg=maxdeg, maxdeg_term=maxdeg_term, known=None,knownstr=None, stdise='none', pex=None, constterm=constterm, verbose=False) # standardisation done below
            ms.telltruth(trueterms=(dat['terms'][dat['exix1']],dat['terms'][dat['exix2']]),w=(dat['wex'][dat['exix1']],dat['vex'][dat['exix2']]))
                
        ms.stdise(verbose=False,method=stdmeth)
        
        if not usestep and topfeat is None:
            topfeat = int(round( 0.5*ms.p )) # this seems to be a good choice if step detection is not used
        if not usestep:
            step_interval = ms.p # step_interval is also used for plot_modsel, put this value in case usestep==False
        
        # set parameters for Bayesian model evidence
        ms.evipars['heur'] = heur # use known normal parameters (mean and variance) from OLS, assuming normality
        if heur:
            ms.evipars['m'] = 1 / ms.y.std(0)**2 # mode of prior for precision of noise (gamma distribution), ms.y.std(0) will be 1 due to standardisation [empirical Bayes choice]
            ms.evipars['v'] = 0.5 # variance of prior for precision of noise (pp.plot(x,sc.stats.gamma.pdf(x,m/v+1,0,v))) [empirically turned out to be a good universal value]
        # NOTE: h_scalfac is a configuration string defined in this script; eval must never be fed external input
        ms.evipars['h_scalfac'] = eval(str(h_scalfac)) # factor by which distributions are broadened
                               
        # set up deselection matrix to prevent deactivation of features correlated with often selected features
        ms.calc_comat(meth='corr') # 'corr' or 'mi' (=normalised mutual information)
        ms.coselmat(threshold=cosel_thresh) # if above threshold, features assumed to be connected, used to prevent deactivation of features connected to active features
        if not silent:
            ms.plot_comat() 
                    
        finix = -1 # index to be incremented as we go through methods, always before new method
        
        # LASSO
        finix += 1
        nam = finselvec[finix] # name for feature rating 
        ms.lasso(boots=1,bootsize=1,nam=nam) # run feature rating without bootstrapping, and save into rating table <nam>
        ms.newtop(nam=nam,ntop=1) # create new top model list
        ms.feat2mod(nam=nam,nr=1,cmin=0) # select model from most selected features in feature rating
        
        # LARS
        finix += 1
        nam = finselvec[finix] # name for feature rating 
        ms.lars(boots=1,bootsize=1,nam=nam) # run feature rating without bootstrapping, and save into rating table <nam>
        ms.newtop(nam=nam,ntop=1) # create new top model list
        ms.feat2mod(nam=nam,nr=1,cmin=0) # select model from most selected features in feature rating
                     
        # deactivate features not connected to LASSO selected features
        # ms.deactfeat(nr=1,cmin=0,featmin=10) ### turned out to be not beneficial for identification accuracy
        
        # initialise empty top model list with one model (ntop=1)
        finix += 1 # = finixR2
        ms.newtop(nam=finselvec[finix],ntop=1)
            
        ### R-squared elimination procedure (comprehensive search)
        # start with 1 term
        nams_R2_fixedsize = [] # init list of fixed-sized R2 top model listings
        nams_R2_fixedsize.append('$R^2$, 1 terms') # add name for top model listing for 1-sized models
        ms.newtop(nam=nams_R2_fixedsize[-1],ntop=nrsq) 
        ms.val_lnrsq(nt=1 ,presel='all') # compute R-squared (log-scale) for all models
        if usestep:
            step_interval = ms.p-2-1 # only look in this interval (top models) for a step
            # (as after that many top models, all wrong features have been selected once, if nt is the correct number of features)
        ms.mod2feat(mod=nams_R2_fixedsize[-1],topnum=topfeat,weight=weightfeat,step=usestep,step_interval=step_interval) # use best topfeat models to create feature rating from that
        # ms.plot_featsel() # plot current feature rating
        if not silent:
            ms.plot_modsel(nt=step_interval) # plot last obtained top models
        
        # also get max of top evis for 1-term models
        feats_copy = ms.feats.copy() # copy active feature list to restore after evi evaluation below
        ms.feats = np.arange(ms.fac_p*ms.p) # make all features active again, otherwise ms.evi throws an error as models in top are always of shape set here
        topmods = ms.top[ms.ranknams[-1]].toarray()
        evis_ = np.full(nkeep,np.nan)
        for mix in range(min(nkeep,topmods.shape[0])): # could also just look at top one in case evi needs to be estimated
            evis_[mix] = ms.evi( topmods[mix] , [] )[0][0]
        evis = np.array([evis_.max()])
        ms.feats = feats_copy.copy() # restore list of active features in case needed later
        
        # now loop through all candidate models with more terms
        nRsq = 1 # initialise number of terms to be considered
        R2_success = False # initialise R2 stepwise regression criterion
        evi_stop = False # initialise evi-stopping criterion (never set to True while the block at the end of the loop is commented out)
        while not R2_success and not evi_stop and nRsq<nRsq_max: # if not successfully found model from last nr Rsq listings, increase number of terms
                    
            if nRsq >= 1+1: # if two or more feature ratings from R-squared have been created, remove features from models that were not selected
                ms.deactfeat(nr=nr_deact,cmin=cmin_deact,featmin=nRsq_max)
                
            nRsq += 1 # increment number of terms in models
            nams_R2_fixedsize.append('$R^2$, ' + str(nRsq) + ' terms') # name of model lists
        
            ms.newtop(nam=nams_R2_fixedsize[-1],ntop=nrsq) # next model list
            ms.val_lnrsq(nt=nRsq ,presel='all') # compute R-squared
            if usestep:
                step_interval = ms.p-nRsq-1 # only look in this interval (top models) for a step
                # (as after that many top models, all wrong features have been selected once, if nt is the correct number of features)
            ms.mod2feat(mod=nams_R2_fixedsize[-1],topnum=topfeat,weight=weightfeat,step=usestep,step_interval=step_interval) # create feature rating from that
            # ms.plot_featsel()
            if not silent:
                ms.plot_modsel(nt=step_interval)
            
            if nRsq >= nr_rsq+1: # if enough feature ratings from R-squared have been created
                ms.feat2mod(nam=finselvec[finixR2],nr=nr_rsq,cmin=cmin_mod) # try and select from last nr Rsq listings
            # created top model listing remains empty if no model could be identified, in which case while loop continues
            
            # determine max model evidence for nkeep top models for this R2 round
            feats_copy = ms.feats.copy() # copy active feature list to restore after evi evaluation below
            ms.feats = np.arange(ms.fac_p*ms.p) # make all features active again, otherwise ms.evi throws an error as models in top are always of shape set here
            topmods = ms.top[ms.ranknams[-1]].toarray()
            evis_ = np.full(nkeep,np.nan)
            for mix in range(min(nkeep,topmods.shape[0])): # could also just look at top one in case evi needs to be estimated
                evis_[mix] = ms.evi( topmods[mix] , [] )[0][0]
            evis = np.append(evis,evis_.max())
            ms.feats = feats_copy.copy() # restore list of active features in case needed later
            
            R2_success = ms.top[finselvec[finixR2]].sum()>0 # if top model list for R2 stepwise regression is empty, then no model has been identified yet, and we keep going to increase model size (while loop)
            if R2_success:
                count_R2stop += 1
            
            # if nRsq >= 2+1: # if three or more rounds of R2 testing have been done, check if evi stopping criterion has been met
            #     evi_stop = (evis[-1] < evis[-nr_evi-1:-1]).all()
            #     count_evistop += 1
            #     if evi_stop:
            #         print('Decline in model evidence caused early stopping.')
            
        # plot combined Rsq results for feature selection
        # ms.plot_featsel()
            
        # combine Rsq for top model listing and plot
        ms.comb_top(topnams=nRsq,topkeeps=nkeep,nam='top $R^2$',sortit=True) # 'topnams=nRsq' makes that the last nRsq top listings are used, which are the fixed-sized R2 listings, could also use 'topnams=nams_R2_fixedsize'
        if not silent:
            ms.plot_modsel()
            # ms.plot_modprop()
                            
        # compute model evidence for combined top model listing
        finix += 1
        ms.newtop(nam=finselvec[finix],ntop=50) # create top model listing (ntop=50)
        ms.val_evi(presel='top $R^2$')  
        if not silent:
            ms.plot_modsel()                   
        
        # if CS-R2 did not find any models (i.e. ms.top[finselvec[finixR2]] has no non-zero elements), get best Rsq model of size selected by evi
        # ... for model size selected by evi all models in top model listing are considered, so the size of the best model in the evi top listing computed above is selected
        if not R2_success: 
            # int() guards against a float-valued sum turning the listing name into e.g. '$R^2$, 2.0 terms'
            evisize = int(ms.top[finselvec[finixpm]][0,:].sum()) # get model size (number of terms) selected by evi
            ms.top[finselvec[finixR2]] = ms.top['$R^2$, ' + str(evisize) + ' terms'][0] # set top Rsq model of that model size as DaC (SCS) result
            count_maxitstop += 1
            if not silent:
                print('Iterative R2 did not identify model size, used model evidence instead.')
                                                
        # pick overall best R^2 model by evi among all best R^2 models for each model size category
        ms.comb_top(topnams=nams_R2_fixedsize,topkeeps=1,nam='top-1 $R^2$',sortit=True) # combine one best model per model size
        finix += 1
        ms.newtop(nam=finselvec[finix],ntop=nRsq) # create new top model listing with defined name in finselvec, ...
        ms.val_evi(presel='top-1 $R^2$') # ..., where the top models from the combined top-1 are evaluated and sorted with evi
        # ms.plot_modsel()  
        
        # bidirectional stepwise regression
        finix += 1
        nam = finselvec[finix] # name for feature rating 
        ms.swsel(crit='evi',nam=nam)
        ms.newtop(nam=nam,ntop=1) # create new top model list
        ms.feat2mod(nam=nam,nr=1,cmin=0) # select model from most selected features in feature rating
                            
        # collect equations and extract true solution
        for selix,finsel in enumerate([finselvec[ix] for ix in finsel_own]):
            mdleqs_, westis_ = ms.gettop(nam=finsel,maxeqs=5,print_plot=False) # get top models (includes true model as first entry)
            mdleqs[selix,compo], westis[selix,compo] = mdleqs_[1:], westis_[1:] # cut out true model, as true model extracted below separately
        mdleqs_, westis_ = ms.gettop(nam=finselvec[0],maxeqs=5,print_plot=False) # (nam is arbitrary as we are only interested in the true model, which is always put out first)
        mdleqs[selix+1,compo], westis[selix+1,compo] = mdleqs_[:1], westis_[:1] # extract true model and add as last entry in mdleqs and westis
        
        # save which rank true model was found for various methods (empty if not found)
        exfoundix = ms.exfound_ix
        
        # save intermediate results
        if saveresults:
            np.savez(fil + '__'+str(pix) + '__res', pix=pix,Npoly=Npoly,finselvec=finselvec,fndcnt=fndcnt,fnd1st=fnd1st,fndrnk=fndrnk,exfoundix=exfoundix,
                     mdleqs=mdleqs,westis=westis,count_R2stop=count_R2stop,count_evistop=count_evistop,count_maxitstop=count_maxitstop,finsel_own=finsel_own,finixR2=finixR2,finixpm=finixpm)

    # determine success rates
    for ix in range(len(finselvec)):
        fnd = exfoundix[finselvec[ix]] # ranks at which the true model was found by this method (empty if not found)
        if len(fnd)>0:
            fndcnt[ix] += 1
            fndrnk[ix] += fnd[0]
            # compare the first rank, not the whole container: a list never equals 0 and a
            # multi-element array would raise on truth-testing
            if fnd[0]==0:
                fnd1st[ix] += 1
    
    # count true and extra terms in the equations found by each method for this polynomial
    foundequs, ntrue, nextra = uni_eq( mdleq=mdleqs[:,compo] , nt_max=5, compl=True)
    found_ntrue.append( list(ntrue) )
    found_nextra.append( list(nextra) )
# refresh the plotting labels, finselvec may have been replaced by values loaded in the loop
finselvec2 = finselvec

# NOTE(review): p is hard-coded here rather than taken from the model selector - confirm it matches ms.p
p = 72
if ntop is None:
    ntop = int(round(0.5 * p))

# bar chart: how often the true model was in the top listing / ranked first, per method
fig, ax = pp.subplots()
for patch_owner in (fig, ax):
    patch_owner.set_facecolor(bgcol)
b1 = ax.bar(finselvec2, fndcnt, color=[0.5,0.2,0.1], label='in top '+str(ntop))
b2 = ax.bar(finselvec2, fnd1st, color=[0.3,0.7,0.2], label='ground truth')
ax.axhline(Npoly, color=[0.7,0.7,0.7])
ax.set_ylabel('identification accuracy')
ax.set_ylim([0, Npoly+10])

# write the percentage into each bar; the second label is skipped when both bars coincide
text_floor = 25/Npoly
text_drop = 2*25/Npoly
for ix in range(len(finselvec)):
    bar_cnt, bar_1st = b1[ix], b2[ix]
    top_cnt = bar_cnt.get_height()
    top_1st = bar_1st.get_height()
    x_mid = bar_cnt.get_x() + bar_cnt.get_width()/2
    if top_cnt>0:
        pct_cnt = int(np.round(fndcnt[ix]/Npoly*100))
        pp.text(x_mid, max(text_floor, top_cnt-text_drop), "{0:d}%".format(pct_cnt), ha='center', va='bottom', fontsize=fstxt)
    if top_1st>0 and top_1st!=top_cnt:
        pct_1st = int(np.round(fnd1st[ix]/Npoly*100))
        pp.text(x_mid, max(text_floor, top_1st-text_drop), "{0:d}%".format(pct_1st), ha='center', va='bottom', fontsize=fstxt)

pp.show()

if not silent:
    print('\n',fol)

# share of polynomials for which the R2 iteration identified a model itself versus
# those where it hit the maximum model size and model evidence was used instead
n_terminated = count_R2stop + count_maxitstop
# raw string: '\%' is not a valid escape sequence; the printed text is unchanged (\!\%)
pct_suffix = r'\!\%'

print('\n percentage of R2 iteration terminating successfully:')
print( np.round(count_R2stop / n_terminated*100,2),pct_suffix)
print('\n percentage of R2 iteration reached maximum iteration and forced stop:')
print( np.round(count_maxitstop / n_terminated*100,2),pct_suffix)


# one colour per plotted method: reduce the colormap to len(finsel_plot) entries
cmap = cm.get_cmap('gist_rainbow')
cmap.N = len(finsel_plot)
try:
    cmap(0) # warm-up call; the first call after changing N was observed to raise (cause unknown)
except Exception: # deliberately best-effort, but do not swallow KeyboardInterrupt/SystemExit
    pass

# optional axis limits (currently not applied by the plots below) and outlier display for the box plots
ylims1 = None ; ylims2 = None ; showfliers = False
# ylims1 = [ [0,3.5] , [0,7] , [0,4] , [0,2] ] ; ylims2 = [ [0,50] , [0,50] , [0,50] , [0,50] ] ; showfliers = True # MAE

# horizontal positions of the boxes, one per plotted method
lab_posvec = np.linspace(-0.4,0.4,len(finsel_plot))
            
def _method_boxplot(values, ylabel, ylabel_kw=None, ref_level=None, rates=None, ylim=None, yticks=None):
    """Draw one styled box (plus mean marker) per plotted method and a colour-bar legend.

    Uses the module-level plot configuration (finsel_plot, finselvec2, cmap,
    lab_posvec, labs, showfliers, bgcol, fs, fstxt).

    Parameters
    ----------
    values : 2-D array-like, rows = polynomials, columns indexed like finselvec2
    ylabel : label of the y axis
    ylabel_kw : optional dict of extra keyword arguments for set_ylabel
    ref_level : if given, a horizontal reference line is drawn at this y value
    rates : optional per-method percentages (indexed like finselvec2) written
        just above the reference line; only used together with ref_level
    ylim, yticks : optional y-axis limits and tick positions

    Returns
    -------
    (fig, ax) as returned by pp.subplots; ax[0] holds the boxes, ax[1] the colour bar.
    """
    data = np.array(values)
    # as array, so that the label lookup also works while finselvec2 is still a plain list
    names = np.asarray(finselvec2)

    fig, ax = pp.subplots(2,1, gridspec_kw={'height_ratios': [10, 1]}, constrained_layout = True)
    fig.set_facecolor(bgcol)
    ax[0].set_facecolor(bgcol)
    if ref_level is not None:
        ax[0].axhline(ref_level,color=[0.7,0.7,0.7],lw=4)

    for ix,finsel in enumerate(finsel_plot):
        fix = np.where(names==finsel)[0].item() # column of this method in values
        col = np.array(cmap(ix))
        col_dark = np.array([0.5,0.5,0.5,1.]) * col # darker shade for the median line
        flierprops = dict(marker='x', markerfacecolor=bgcol, markersize=5, markeredgecolor=col, markeredgewidth=1.1)
        column = data[:,fix]
        mask = np.logical_not( np.isnan(column) ) # boxplot cannot handle NaN entries
        boxpl = ax[0].boxplot( column[mask], positions=[lab_posvec[ix]], widths=[0.08], showfliers=showfliers, flierprops=flierprops, patch_artist=True ) # patch_artist allows control of plot parts like below
        box = boxpl['boxes'][0]
        box.set_facecolor(bgcol)
        box.set_edgecolor(col)
        box.set_linewidth(2)
        for part in ('whiskers', 'caps'):
            for line in boxpl[part][:2]:
                line.set_color(col)
                line.set_linewidth(1.7)
        boxpl['medians'][0].set_color(col_dark)
        boxpl['medians'][0].set_linewidth(4)
        # mean over all polynomials (NaN ignored) as filled circle
        ax[0].plot( lab_posvec[ix], np.nanmean(column), 'o', markersize=11, markeredgewidth=1.4, markeredgecolor=[0,0,0], color=col)
        if rates is not None and ref_level is not None:
            # index with fix (method index), not ix (plot position): they differ once a method is hidden
            ax[0].text(lab_posvec[ix], ref_level+0.2, "{0:d}%".format(int(np.round(rates[fix]))), ha='center', va='bottom', fontsize=fstxt+4)

    ax[0].set_ylabel(ylabel, **(ylabel_kw or {}))
    ax[0].tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off
    ax[0].set_xlim([-0.48,0.48])
    if ylim is not None:
        ax[0].set_ylim(ylim)
    if yticks is not None:
        ax[0].set_yticks(yticks)
    # colour bar below the plot acts as legend, one centred tick label per method
    cbar = cb.ColorbarBase(ax[1], ticks=np.linspace(0,1,len(finsel_plot)+1)[:-1]+1/(len(finsel_plot))/2, cmap=cmap, orientation='horizontal')
    cbar.ax.set_xticklabels(labs, fontsize=fs)
    return fig, ax


# number of true terms found per method
fig, ax = _method_boxplot(found_ntrue, 'true terms found')
pp.show()

   
# number of wrong (extra) terms found per method
fig, ax = _method_boxplot(found_nextra, 'wrong terms found', ylim=[0,20] if set_m1==4 else None)
pp.show()


# N_diff: true terms found minus wrong terms found
found_crit = np.array(found_ntrue)-np.array(found_nextra)

# y-axis limits and ticks depending on the true model size
crit_axes = {4: ([-16,6], [-16,-12,-8,-4,0,4]),
             3: ([-15,5], [-15,-11,-7,-3,0,3]),
             2: ([-14,4], [-14,-11,-7,-4,-2,0,2])}
crit_ylim, crit_yticks = crit_axes.get(set_m1, (None, None))
# reference line at the true model size, annotated with the rate of exact identification
fig, ax = _method_boxplot(found_crit, '$N_{\\mathrm{diff}}$', ylabel_kw={'fontsize': fs+4}, ref_level=set_m1,
                          rates=fnd1st/Npoly*100, ylim=crit_ylim, yticks=crit_yticks)
pp.show()


