
import numpy as np
import scipy.stats as ss
import scipy.integrate as si
import matplotlib.pyplot as pp
from matplotlib import cm
import chs

       

def nonlinreg_impl(maxdeg=3,maxdeg_term=np.inf,n=2,m1=2,m2=2,N=100,sig=0.001, ker='Kgen',prefac='one',seed=None):
    """Draw a random sparse (optionally rational) regression problem.

    Features are sampled, expanded into a kernel matrix by the generator
    ``chs.<ker>``, and a noisy response is built from ``m1`` randomly chosen
    numerator terms and, if ``m2 > 0``, ``m2`` randomly chosen denominator terms.

    Parameters
    ----------
    maxdeg : maximum degree of basis functions (combinations may be of higher
        order); forwarded to the kernel generator as ``maxdeg_fac``.
    maxdeg_term : maximum degree of terms; forwarded as ``maxdeg_term``.
    n : number of features.
    m1 : number of non-zero terms in the numerator.
    m2 : number of non-zero terms in the denominator (0: no fraction, plain
        linear regression in the kernels).
    N : number of datapoints.
    sig : noise strength (standard deviation of the additive normal noise).
    ker : name of the kernel generator looked up on ``chs``
        ('Kgen', 'Kpows', 'Kpowsneg', 'Kpowsneg2', 'Kpowsneg3').
    prefac : forwarded unchanged to the kernel generator.
    seed : seed handed to ``numpy.random.default_rng``; ``None`` (default)
        draws fresh entropy, any fixed value makes the whole draw reproducible.

    Returns
    -------
    K : kernel matrix returned by the generator (``K.shape[1]`` kernels;
        presumably one row per datapoint -- confirm against ``chs``).
    y : response vector.
    w : numerator weights (length ``K.shape[1]``, ``m1`` non-zero entries).
    v : denominator weights (same length; ``v[0] = 1`` only when ``m2 == 0``).
    eqlist : second return value of the kernel generator (requested with
        ``eqstr=True``).
    ix1, ix2 : indices of the non-zero entries of ``w`` and ``v``.
    X : sampled features, shape ``(n, N)``, each row sorted ascending.
    """
    genK = getattr(chs,ker)
    # the seed must reach the generator, otherwise the argument has no effect
    rng = np.random.default_rng(seed) # random number generator instance

    # random and independent data for X features, each feature sampled from only 5% overlapping intervals
    b0 = -20
    b1 = 20
    b = np.unique( rng.uniform(b0,b1,n-1) ) # random points at which interval [b0,b1] is split
    bvec = np.hstack((b0,b,b1)) # edge vector for splitting
    widths = np.diff(bvec)
    means = bvec[:-1] + widths/2 # bin centers shall be means
    stds = widths / 4 # bin widths define 2sigma interval (5% overlaps for normal distributions)
    X = rng.normal(means[:,None],stds[:,None],size=(n,N))
    X.sort(axis=1)

    # basis function combinations (only the kernel matrix and equation strings are used here)
    K, eqlist, _, _, _, _ = genK( X , maxdeg_fac=maxdeg, maxdeg_term=maxdeg_term, eqstr=True, prefac=prefac )
    p = K.shape[1] # total number of kernels

    # Weights are drawn from a mixture on [-4,-1] and [1,4], i.e. with means away
    # from zero, to avoid sampling weights close to the noise level.

    # weight for numerator
    w = np.zeros(p) # init weight vector
    ix1 = rng.choice(np.arange(p),m1,replace=False) # random indices of where non-zero values are put
    bern = rng.binomial(size=m1,n=1,p=0.5) # bernoulli variables choosing positive or negative weight
    w[ix1] = (bern==1)*rng.uniform(1,4,m1) + (bern==0)*rng.uniform(-4,-1,m1)

    # weight for denominator
    v = np.zeros(p) # init weight vector
    ix2 = rng.choice(np.arange(p),m2,replace=False) # random indices of where non-zero values are put
    bern = rng.binomial(size=m2,n=1,p=0.5) # bernoulli variables choosing positive or negative weight
    v[ix2] = (bern==1)*rng.uniform(1,4,m2) + (bern==0)*rng.uniform(-4,-1,m2)

    ## independent noise
    z = rng.normal(0,sig,N)
    if m2>0:
        # rational response; the noise is added to the numerator for convenience
        y = (K @ w + z)  /  (K @ v)
    else:
        # no denominator terms: denominator is 1 (usual linear regression)
        y = K @ w + z
        ix2 = np.array([0]) # true denominator term is the constant term (first term)
        v[0] = 1 # and the weight for that is 1

    return K, y, w, v, eqlist, ix1, ix2, X
    
    

def lorenz(N=1000, T=20, sig=0.2, cma_step=0, pars={'sigma':10,'rho':28,'beta':8/3}, ics=[[1,1,1]], plotit=False, datafrac=1.0):
    """Simulate noisy observations of the Lorenz system and their finite-difference derivatives.

    Parameters
    ----------
    N : number of observation times, equally spaced on ``[0, T]``.
    T : final time.
    sig : standard deviation of the additive normal observation noise.
    cma_step : not used by this function.
    pars : dict with keys ``'sigma'``, ``'rho'``, ``'beta'`` (only read, never modified).
    ics : non-empty sequence of initial conditions ``[x, y, z]`` (only read).
    plotit : if true, plot the trajectories (3D phase plot and per-coordinate
        time series) and call ``pyplot.show()``.
    datafrac : fraction of the ``N`` samples that is plotted.

    Returns
    -------
    t_mp : ``(len(ics), N-1)`` midpoints of the time grid, one row per initial condition.
    y : ``(3, len(ics), N-1)`` forward differences of the noisy observations divided by the time step.
    wex : object array with the exact coefficients of each equation.
    eqlist : LaTeX strings of the three right-hand sides.
    powsex : object array with the exact monomial exponents ``[x, y, z]`` of each term.
    x_mp : ``(3, len(ics), N-1)`` noisy observations averaged onto the midpoints.
    xraw : ``(3, N)`` noisy observations of the LAST initial condition only.
    t : ``(N,)`` observation times.
    """
    sigma = pars['sigma'] ; rho = pars['rho'] ; beta = pars['beta']

    Mx = 3 # number of state variables

    # define lorenz system
    # xx[0] = x, xx[1]=y, xx[2]=z, f is derivative vector of that
    def f(t,xx):
        return np.array( [
             - sigma*xx[0] + sigma*xx[1], # = xx[0]' = x'
            rho*xx[0] - xx[1] - xx[0]*xx[2], # = xx[1]' = y'
            - beta*xx[2] + xx[0]*xx[1] # = xx[2]' = z'
        ] )

    # exact powers [x,y,z]
    powsex = np.array( [ [ [1,0,0] , [0,1,0] ] , [ [1,0,0] , [0,1,0] , [1,0,1] ] , [ [0,0,1] , [1,1,0] ] ] , dtype='object' )
    wex = np.array([ [ -sigma, sigma ] , [ rho , -1 , -1 ] , [ -beta, 1 ] ] , dtype='object' )
    # x' = sigma*(y - x): no leading minus sign, in agreement with f and wex
    eqlist = np.array( [ "\\sigma (x_2 - x_1)" , "x_1 (\\varrho - x_3) - x_2" , "x_1x_2 - \\beta x_3" ] )

    # Sample the noise from a private, freshly seeded generator. Passing it via
    # the random_state keyword leaves the shared scipy.stats.norm object untouched.
    Nics = len(ics)
    Z = ss.norm.rvs(0, sig, (Mx,Nics,N), random_state=np.random.RandomState())
    t = np.linspace(0,T,N)  # points in time at which observations have been recorded
    dt = np.diff(t) # time steps, identical for every initial condition
    t_mp = np.tile(  (t[:-1]+t[1:])/2  , Nics ).reshape((Nics,N-1)) # midpoint times, same layout as responses
    y = np.empty((Mx,Nics,N-1))
    x_mp = np.empty((Mx,Nics,N-1))
    if plotit:
        ndat = int(datafrac*N) # number of leading samples shown in the plots
        # pyplot.get_cmap and Axes3D.xaxis/yaxis/zaxis replace cm.get_cmap and
        # Axes3D.w_xaxis/w_yaxis/w_zaxis, which recent Matplotlib releases removed
        cmap = pp.get_cmap('Wistia')
        cols = cmap(np.linspace(0,1,len(ics)))
        fig = pp.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.grid(False)
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False
        ax.set_xlabel('$x$')
        ax.set_ylabel('$y$')
        ax.set_zlabel('$z$')
        ax.xaxis.labelpad = -7
        ax.yaxis.labelpad = -7
        ax.zaxis.labelpad = -7
        ax.tick_params(axis='x', which='major', pad=-4)
        ax.tick_params(axis='y', which='major', pad=-3)
        ax.tick_params(axis='z', which='major', pad=-2)
        fig2, ax2 = pp.subplots(3,1)
        ax2[0].set_xlabel(None)
        ax2[0].set_ylabel('$x$')
        ax2[1].set_xlabel(None)
        ax2[1].set_ylabel('$y$')
        ax2[2].set_xlabel('$t$')
        ax2[2].set_ylabel('$z$')
    for ix,z0 in enumerate(ics):
        sol0 = si.solve_ivp(f, [0., T], z0, method='RK45', dense_output=True)
        xraw = sol0.sol(t) + Z[:,ix,:] # observations, with additive noise
        y[:,ix,:] = np.diff( xraw ) / dt  # response
        x_mp[:,ix,:] = (xraw[:,:-1]+xraw[:,1:])/2 # observations converted to midpoints
        if plotit:
            # the 3D trajectory deliberately uses the default color cycle
            ax.plot(xraw[0,:ndat],xraw[1,:ndat],xraw[2,:ndat])
            ax.plot([ics[ix][0]],[ics[ix][1]],[ics[ix][2]],'o',mfc='none',ms=2,color=cols[ix] )
            ax2[0].plot(t[:ndat], xraw[0,:ndat], 'o-',ms=0.5,lw=0.5,color=cols[ix] )
            ax2[0].plot(t[0], ics[ix][0],'o',mfc='none',ms=2,color=cols[ix] )
            ax2[0].axes.xaxis.set_ticklabels([])
            ax2[1].plot(t[:ndat], xraw[1,:ndat], 'o-',ms=0.5,lw=0.5,color=cols[ix] )
            ax2[1].plot(t[0], ics[ix][1],'o',mfc='none',ms=2,color=cols[ix] )
            ax2[1].axes.xaxis.set_ticklabels([])
            ax2[2].plot(t[:ndat], xraw[2,:ndat], 'o-',ms=0.5,lw=0.5,color=cols[ix] )
            ax2[2].plot(t[0], ics[ix][2],'o',mfc='none',ms=2,color=cols[ix] )

    if plotit:
        pp.show()

    return t_mp, y, wex, eqlist, powsex, x_mp, xraw, t



def SIRV(N=100, T=100, sig=0.02, cma_step=0, pars={'v':1,'m':1,'a':1}, ics=[[1.0,0.01,0.0,0.0]], plotit=False, datafrac=1.0):
    """Simulate noisy observations of an SIRV epidemic model and their finite-difference derivatives.

    Parameters
    ----------
    N : number of observation times, equally spaced on ``[0, T]``.
    T : final time.
    sig : standard deviation of the additive normal observation noise.
    cma_step : not used by this function.
    pars : dict with keys ``'v'`` (vaccination rate), ``'m'`` (recovery rate)
        and ``'a'`` (infection rate); only read, never modified.
    ics : non-empty sequence of initial conditions ``[S, I, R, V]`` (only read).
    plotit : if true, plot the midpoint observations over time and call
        ``pyplot.show()``.
    datafrac : fraction of the ``N`` samples that is plotted.

    Returns
    -------
    t_mp : ``(len(ics), N-1)`` midpoints of the time grid, one row per initial condition.
    y : ``(4, len(ics), N-1)`` forward differences of the noisy observations divided by the time step.
    wex : object array with the exact coefficients of each equation.
    eqlist : strings of the four right-hand sides.
    powsex : object array with the exact monomial exponents ``[S, I, R, V]`` of each term.
    x_mp : ``(4, len(ics), N-1)`` noisy observations averaged onto the midpoints.
    xraw : ``(4, N)`` noisy observations of the LAST initial condition only.
    t : ``(N,)`` observation times.
    """
    v = pars['v'] # vaccination rate
    m = pars['m'] # recovery rate
    a = pars['a'] # infection rate

    Mx = 4 # number of state variables

    # define SIRV system
    def f(t,xx):
        return np.array( [
            - a*xx[0]*xx[1] - v*xx[0], # = xx[0]' = S'
            a*xx[0]*xx[1] - m*xx[1], # = xx[1]' = I'
            m*xx[1], # = xx[2]' = R'
            v*xx[0] # = xx[3]' = V'
        ] )

    # exact powers [S,I,R,V]
    powsex = np.array( [ [ [1,1,0,0] , [1,0,0,0] ] , [ [1,1,0,0] , [0,1,0,0] ] , [ [0,1,0,0] ] , [ [1,0,0,0] ] ] , dtype='object' )
    wex = np.array([ [ -a, -v ] , [ a , -m ] , [ m ] , [ v ] ] , dtype='object' )
    eqlist = np.array( [ "- a x_1 x_2 - v x_1" , "a x_1 x_2 - m x_2" , "m x_2" , "v x_1" ] )

    # Sample the noise from a private, freshly seeded generator. Passing it via
    # the random_state keyword leaves the shared scipy.stats.norm object untouched.
    Nics = len(ics)
    Z = ss.norm.rvs(0, sig, (Mx,Nics,N), random_state=np.random.RandomState())
    t = np.linspace(0,T,N)  # points in time at which observations have been recorded
    dt = np.diff(t) # time steps, identical for every initial condition
    t_mp = np.tile(  (t[:-1]+t[1:])/2  , Nics ).reshape((Nics,N-1)) # midpoint times, same layout as responses
    y = np.empty((Mx,Nics,N-1))
    x_mp = np.empty((Mx,Nics,N-1))
    if plotit:
        ndat = int(datafrac*N) # number of leading samples shown in the plot
        # pyplot.get_cmap replaces cm.get_cmap, which recent Matplotlib releases removed
        cmap = pp.get_cmap('tab20c')
        cols = cmap(np.linspace(0,1,4*5)) # 4 compartments x 5 shades (as the color indexing below assumes)
        fig, ax = pp.subplots()
        ax.set_xlabel('$t$')
        ax.set_ylabel('$SIRV$')
    for ix,z0 in enumerate(ics):
        sol0 = si.solve_ivp(f, [0., T], z0, method='RK45', dense_output=True)
        xraw = sol0.sol(t) + Z[:,ix] # observations, add noise here for better comparison to SINDy
        y[:,ix,:] = np.diff( xraw ) / dt
        x_mp[:,ix,:] = (xraw[:,:-1]+xraw[:,1:])/2 # observations converted to midpoints
        if plotit:
            # time axis and data are truncated identically so that datafrac < 1
            # does not produce mismatched lengths; only the first initial
            # condition carries legend labels (empty labels are skipped)
            tt = t_mp[0,:ndat]
            ax.plot(tt, x_mp[0,ix,:ndat], 'o-', ms=0.7, lw=0.5 , color=cols[ix+0], label=(ix==0)*'susceptible')
            ax.plot(tt, x_mp[1,ix,:ndat], 'o-', ms=0.7, lw=0.5 , color=cols[ix+4] , label=(ix==0)*'infectious')
            ax.plot(tt, x_mp[2,ix,:ndat], 'o-', ms=0.7, lw=0.5 , color=cols[ix+8] , label=(ix==0)*'recovered')
            ax.plot(tt, x_mp[3,ix,:ndat], 'o-', ms=0.7, lw=0.5 , color=cols[ix+12] , label=(ix==0)*'vaccinated')

    if plotit:
        ax.legend(framealpha=0)
        pp.show()

    return t_mp, y, wex, eqlist, powsex, x_mp, xraw, t
   



def rabinovich(N=1000, T=20, sig=0.2, cma_step=0, pars={'alpha':1.1,'gamma':0.87}, ics=[[-1,0,0.5]], plotit=False, datafrac=1.0):
    """Simulate noisy observations of the Rabinovich-Fabrikant system and their finite-difference derivatives.

    Parameters
    ----------
    N : number of observation times, equally spaced on ``[0, T]``.
    T : final time.
    sig : standard deviation of the additive normal observation noise.
    cma_step : not used by this function.
    pars : dict with keys ``'alpha'`` and ``'gamma'`` (only read, never modified).
    ics : non-empty sequence of initial conditions ``[x, y, z]`` (only read).
    plotit : if true, plot the trajectories (3D phase plot and per-coordinate
        time series) and call ``pyplot.show()``.
    datafrac : fraction of the ``N`` samples that is plotted.

    Returns
    -------
    t_mp : ``(len(ics), N-1)`` midpoints of the time grid, one row per initial condition.
    y : ``(3, len(ics), N-1)`` forward differences of the noisy observations divided by the time step.
    wex : object array with the exact coefficients of each equation.
    eqlist : LaTeX strings of the three right-hand sides.
    powsex : object array with the exact monomial exponents ``[x, y, z]`` of each term.
    x_mp : ``(3, len(ics), N-1)`` noisy observations averaged onto the midpoints.
    xraw : ``(3, N)`` noisy observations of the LAST initial condition only.
    t : ``(N,)`` observation times.
    """
    alpha = pars['alpha'] ; gamma = pars['gamma']

    Mx = 3 # number of state variables

    # define Rabinovich–Fabrikant system
    # xx[0] = x, xx[1]=y, xx[2]=z, f is derivative vector of that
    def f(t,xx):
        return np.array( [
            xx[1]*xx[2] - xx[1] + xx[1]*xx[0]**2 + gamma*xx[0], # = xx[0]' = x'
            3*xx[0]*xx[2] + xx[0] - xx[0]**3 + gamma*xx[1], # = xx[1]' = y'
            -2*alpha*xx[2] - 2*xx[0]*xx[1]*xx[2] # = xx[2]' = z'
        ] )

    # exact powers [x,y,z]
    powsex = np.array( [
        [ [0,1,1] , [0,1,0] , [2,1,0] , [1,0,0] ] ,
        [ [1,0,1] , [1,0,0] , [3,0,0] , [0,1,0] ] ,
        [ [0,0,1] , [1,1,1] ]
    ] , dtype='object' )
    wex = np.array([
        [ 1 ,-1 , 1 , gamma ] ,
        [ 3 , 1 ,-1 , gamma ] ,
        [-2*alpha , -2 ]
    ] , dtype='object' )
    eqlist = np.array( [ "x_2 (x_3 - 1 + x_1^2) + \\gamma x_1" , "x_1 (3 x_3 + 1 - x_1^2) + \\gamma x_2" , "-2 x_3 (\\alpha + x_1 x_2)" ] )

    # Sample the noise from a private, freshly seeded generator. Passing it via
    # the random_state keyword leaves the shared scipy.stats.norm object untouched.
    Nics = len(ics)
    Z = ss.norm.rvs(0, sig, (Mx,Nics,N), random_state=np.random.RandomState())
    t = np.linspace(0,T,N)  # points in time at which observations have been recorded
    dt = np.diff(t) # time steps, identical for every initial condition
    t_mp = np.tile(  (t[:-1]+t[1:])/2  , Nics ).reshape((Nics,N-1)) # midpoint times, same layout as responses
    y = np.empty((Mx,Nics,N-1))
    x_mp = np.empty((Mx,Nics,N-1))
    if plotit:
        ndat = int(datafrac*N) # number of leading samples shown in the plots
        # pyplot.get_cmap and Axes3D.xaxis/yaxis/zaxis replace cm.get_cmap and
        # Axes3D.w_xaxis/w_yaxis/w_zaxis, which recent Matplotlib releases removed
        cmap = pp.get_cmap('Wistia')
        cols = cmap(np.linspace(0,1,len(ics)))
        fig = pp.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.grid(False)
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False
        ax.set_xlabel('$x$')
        ax.set_ylabel('$y$')
        ax.set_zlabel('$z$')
        ax.xaxis.labelpad = -7
        ax.yaxis.labelpad = -7
        ax.zaxis.labelpad = -7
        ax.tick_params(axis='x', which='major', pad=-4)
        ax.tick_params(axis='y', which='major', pad=-3)
        ax.tick_params(axis='z', which='major', pad=-2)
        fig2, ax2 = pp.subplots(3,1)
        ax2[0].set_xlabel(None)
        ax2[0].set_ylabel('$x$')
        ax2[1].set_xlabel(None)
        ax2[1].set_ylabel('$y$')
        ax2[2].set_xlabel('$t$')
        ax2[2].set_ylabel('$z$')
    for ix,z0 in enumerate(ics):
        sol0 = si.solve_ivp(f, [0., T], z0, method='LSODA', dense_output=True)
        xraw = sol0.sol(t) + Z[:,ix,:] # observations, with additive noise
        y[:,ix,:] = np.diff( xraw ) / dt  # response
        x_mp[:,ix,:] = (xraw[:,:-1]+xraw[:,1:])/2 # observations converted to midpoints
        if plotit:
            ax.plot(xraw[0,:ndat],xraw[1,:ndat],xraw[2,:ndat], '-o',ms=0.5,lw=0.3,color=cols[ix] )
            ax.plot([ics[ix][0]],[ics[ix][1]],[ics[ix][2]],'o',mfc='none',ms=2,color=cols[ix] )
            ax2[0].plot(t[:ndat], xraw[0,:ndat], 'o-',ms=0.5,lw=0.5,color=cols[ix] )
            ax2[0].plot(t[0], ics[ix][0],'o',mfc='none',ms=2,color=cols[ix] )
            ax2[0].axes.xaxis.set_ticklabels([])
            ax2[1].plot(t[:ndat], xraw[1,:ndat], 'o-',ms=0.5,lw=0.5,color=cols[ix] )
            ax2[1].plot(t[0], ics[ix][1],'o',mfc='none',ms=2,color=cols[ix] )
            ax2[1].axes.xaxis.set_ticklabels([])
            ax2[2].plot(t[:ndat], xraw[2,:ndat], 'o-',ms=0.5,lw=0.5,color=cols[ix] )
            ax2[2].plot(t[0], ics[ix][2],'o',mfc='none',ms=2,color=cols[ix] )

    if plotit:
        pp.show()

    return t_mp, y, wex, eqlist, powsex, x_mp, xraw, t

