import os
import sys
import warnings
from contextlib import contextmanager, redirect_stderr, redirect_stdout

import numpy as np
import pysindy as ps
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import TimeSeriesSplit
from sklearn.utils._testing import ignore_warnings
from tqdm import tqdm

@contextmanager
def suppress_stderr():
    """Discard everything written to ``sys.stderr`` inside the ``with`` block.

    Uses ``contextlib.redirect_stderr`` with the null device as the target.
    Only the Python-level ``sys.stderr`` object is replaced; output written
    directly to file descriptor 2 (e.g. by native code or subprocesses) is
    not affected. The previous stream is restored on exit, even if the body
    raises, and the null-device handle is closed afterwards.
    """
    with open(os.devnull, "w") as devnull, redirect_stderr(devnull):
        yield

@contextmanager
def suppress_stdout():
    """Discard everything written to ``sys.stdout`` inside the ``with`` block.

    Uses ``contextlib.redirect_stdout`` with the null device as the target.
    Only the Python-level ``sys.stdout`` object is replaced; output written
    directly to file descriptor 1 (e.g. by native code or subprocesses) is
    not affected. The previous stream is restored on exit, even if the body
    raises, and the null-device handle is closed afterwards.
    """
    with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
        yield


# Cross-validation settings shared by all solvers below.
CVsplit = 5  # TimeSeriesSplit folds (5 is standard in sklearn; up to 10 is also reported as a good general choice)
CVvec0 = np.logspace(-5, 5, num=6)  # coarse start grid for the CV search: 1e-5, 1e-3, ..., 1e5
CViter = 3  # number of search rounds (the first on CVvec0, the remaining ones refined)
CVrefine = 6  # number of values for a refined CV search grid

# Base of the refinement half-width: refined grids span +/- refine_step**n
# decades around the current best value (n = round index), so successive
# rounds narrow the search.
refine_step = 3 / len(CVvec0)


# Sequentially thresholded least squares algorithm (Ridge regression):
# iteratively performing least squares and masking out elements of the weight array w that are below a given threshold
@ignore_warnings(category=ConvergenceWarning)
def STLSQ(x,t,difforder=2,degree=4,threshold=0.5,featnames=["x_1", "x_2", "x_3"], njobs=1,silent=False):
    """Identify a SINDy model with STLSQ, tuning hyperparameters by nested CV.

    The STLSQ ``threshold`` and the finite-difference ``order`` (1 or 2) are
    chosen by a grid search with ``TimeSeriesSplit(CVsplit)``. The search
    starts from ``CVvec0`` and is refined over ``CViter`` rounds on log grids
    of ``CVrefine`` points centred on the current best threshold. The final
    model is then fit on all of ``x``.

    Parameters
    ----------
    x : array-like
        Measurements passed to ``SINDy.fit`` (search and final fit).
    t : array-like
        Time points of ``x``; must be uniformly spaced.
    difforder, threshold :
        Initial values; overwritten by the CV search whenever ``CViter > 0``.
    degree : int
        Degree of the bias-free polynomial feature library, used both for the
        search and for the final fit.
    featnames : list of str or None
        Feature names of the final model. A copy is passed, so the returned
        model never aliases the shared default list.
    njobs : int
        ``n_jobs`` forwarded to ``GridSearchCV``.
    silent : bool
        If True, disable the tqdm progress bar.

    Returns
    -------
    fitrs
        Return value of ``fit(x, t)`` on the final ``ps.SINDy`` model.
    search : GridSearchCV or None
        Grid search of the last round (None if ``CViter == 0``).

    Raises
    ------
    ValueError
        If ``t`` has fewer than two points or is not uniformly spaced.
    """
    # Scope the filters to this call: a bare warnings.filterwarnings() would
    # keep UserWarning/FutureWarning silenced process-wide after returning.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=FutureWarning)

        # The CV model relies on one constant step (t_default), so reject
        # non-uniform sampling explicitly instead of silently using the
        # smallest step. The relative tolerance absorbs float round-off.
        steps = np.diff(np.asarray(t, dtype=float))
        if steps.size == 0:
            raise ValueError("t must contain at least two time points")
        dt = np.unique(np.round(steps, decimals=14))[0]
        if not np.allclose(steps, dt, rtol=1e-6, atol=0.0):
            raise ValueError("t must be uniformly spaced (non-constant time steps found)")

        # Search with the same feature library as the final model below;
        # otherwise the threshold would be tuned for SINDy's default library
        # instead of the one actually fitted.
        cv_model = ps.SINDy(
            feature_library=ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False),
            optimizer=ps.optimizers.STLSQ(),
            t_default=dt,
        )

        multi1 = CVvec0  # start search vector for CV
        search = None

        # repeated nested search
        for n in tqdm(range(CViter), desc='STLSQ', disable=silent):
            param_grid = {
                "optimizer__threshold": multi1,
                "differentiation_method__order": [1, 2],
            }
            search = GridSearchCV(cv_model, param_grid, cv=TimeSeriesSplit(n_splits=CVsplit), n_jobs=njobs)
            with suppress_stderr(), suppress_stdout():
                search.fit(x)
            difforder = search.best_params_['differentiation_method__order']
            threshold = search.best_params_['optimizer__threshold']

            # Next round: +/- refine_step**n decades around the best threshold.
            half_width = refine_step**n
            multi1 = np.logspace(np.log10(threshold) - half_width, np.log10(threshold) + half_width, CVrefine)

        differentiation_method = ps.differentiation.FiniteDifference(order=difforder)
        feature_library = ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False)
        sindy_optimizer = ps.optimizers.STLSQ(threshold=threshold)
        model = ps.SINDy(
            differentiation_method=differentiation_method,
            feature_library=feature_library,
            optimizer=sindy_optimizer,
            feature_names=None if featnames is None else list(featnames),
        )
        fitrs = model.fit(x, t)
    return fitrs, search


@ignore_warnings(category=ConvergenceWarning)
def SR3(x,t,difforder=2,degree=4,threshold=0.1,nu=1,thresholder='L0',featnames=["x_1", "x_2", "x_3"], njobs=1,silent=False):
    """Identify a SINDy model with the SR3 optimizer, tuning hyperparameters by nested CV.

    ``threshold``, ``nu``, the thresholder ('L0', 'L1', 'L2', 'CAD') and the
    finite-difference ``order`` (1 or 2) are chosen by a grid search with
    ``TimeSeriesSplit(CVsplit)``. The search starts from ``CVvec0`` for both
    ``threshold`` and ``nu`` and is refined over ``CViter`` rounds on log
    grids of ``CVrefine`` points centred on the current best values. The
    final model is then fit on all of ``x``.

    Parameters
    ----------
    x : array-like
        Measurements passed to ``SINDy.fit`` (search and final fit).
    t : array-like
        Time points of ``x``; must be uniformly spaced.
    difforder, threshold, nu, thresholder :
        Initial values; overwritten by the CV search whenever ``CViter > 0``.
    degree : int
        Degree of the bias-free polynomial feature library, used both for the
        search and for the final fit.
    featnames : list of str or None
        Feature names of the final model. A copy is passed, so the returned
        model never aliases the shared default list.
    njobs : int
        ``n_jobs`` forwarded to ``GridSearchCV``.
    silent : bool
        If True, disable the tqdm progress bar.

    Returns
    -------
    fitrs
        Return value of ``fit(x, t)`` on the final ``ps.SINDy`` model.
    search : GridSearchCV or None
        Grid search of the last round (None if ``CViter == 0``).

    Raises
    ------
    ValueError
        If ``t`` has fewer than two points or is not uniformly spaced.
    """
    # Scope the filters to this call: a bare warnings.filterwarnings() would
    # keep UserWarning/FutureWarning silenced process-wide after returning.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=FutureWarning)

        # The CV model relies on one constant step (t_default), so reject
        # non-uniform sampling explicitly instead of silently using the
        # smallest step. The relative tolerance absorbs float round-off.
        steps = np.diff(np.asarray(t, dtype=float))
        if steps.size == 0:
            raise ValueError("t must contain at least two time points")
        dt = np.unique(np.round(steps, decimals=14))[0]
        if not np.allclose(steps, dt, rtol=1e-6, atol=0.0):
            raise ValueError("t must be uniformly spaced (non-constant time steps found)")

        # Search with the same feature library as the final model below;
        # otherwise the hyperparameters would be tuned for SINDy's default
        # library instead of the one actually fitted.
        cv_model = ps.SINDy(
            feature_library=ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False),
            optimizer=ps.optimizers.SR3(),
            t_default=dt,
        )

        # start search vectors for CV
        multi1 = CVvec0  # threshold
        multi2 = CVvec0  # nu
        search = None

        # repeated nested search
        for n in tqdm(range(CViter), desc='SR3', disable=silent):
            param_grid = {
                "optimizer__threshold": multi1,
                "optimizer__nu": multi2,
                "optimizer__thresholder": ['L0', 'L1', 'L2', 'CAD'],
                "differentiation_method__order": [1, 2],
            }
            search = GridSearchCV(cv_model, param_grid, cv=TimeSeriesSplit(n_splits=CVsplit), n_jobs=njobs)
            with suppress_stderr(), suppress_stdout():
                search.fit(x)
            difforder = search.best_params_['differentiation_method__order']
            threshold = search.best_params_['optimizer__threshold']
            nu = search.best_params_['optimizer__nu']
            thresholder = search.best_params_['optimizer__thresholder']

            # Next round: +/- refine_step**n decades around each best value.
            half_width = refine_step**n
            multi1 = np.logspace(np.log10(threshold) - half_width, np.log10(threshold) + half_width, CVrefine)
            multi2 = np.logspace(np.log10(nu) - half_width, np.log10(nu) + half_width, CVrefine)

        differentiation_method = ps.differentiation.FiniteDifference(order=difforder)
        feature_library = ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False)
        sindy_optimizer = ps.optimizers.SR3(threshold=threshold, nu=nu, thresholder=thresholder)
        model = ps.SINDy(
            differentiation_method=differentiation_method,
            feature_library=feature_library,
            optimizer=sindy_optimizer,
            feature_names=None if featnames is None else list(featnames),
        )
        fitrs = model.fit(x, t)
    return fitrs, search

@ignore_warnings(category=ConvergenceWarning)
def FROLS(x,t,difforder=2,degree=4,alpha=0.05,max_iter=6,featnames=["x_1", "x_2", "x_3"], njobs=1,silent=False):
    """Identify a SINDy model with the FROLS optimizer, tuning hyperparameters by nested CV.

    The FROLS ``alpha`` and the finite-difference ``order`` (1 or 2) are
    chosen by a grid search with ``TimeSeriesSplit(CVsplit)``. The search
    starts from ``CVvec0`` and is refined over ``CViter`` rounds on log grids
    of ``CVrefine`` points centred on the current best ``alpha``. The final
    model is then fit on all of ``x``.

    Parameters
    ----------
    x : array-like
        Measurements passed to ``SINDy.fit`` (search and final fit).
    t : array-like
        Time points of ``x``; must be uniformly spaced.
    difforder, alpha :
        Initial values; overwritten by the CV search whenever ``CViter > 0``.
    degree : int
        Degree of the bias-free polynomial feature library, used both for the
        search and for the final fit.
    max_iter : int
        ``max_iter`` of the FROLS optimizer, used both for the search and for
        the final fit.
    featnames : list of str or None
        Feature names of the final model. A copy is passed, so the returned
        model never aliases the shared default list.
    njobs : int
        ``n_jobs`` forwarded to ``GridSearchCV``.
    silent : bool
        If True, disable the tqdm progress bar.

    Returns
    -------
    fitrs
        Return value of ``fit(x, t)`` on the final ``ps.SINDy`` model.
    search : GridSearchCV or None
        Grid search of the last round (None if ``CViter == 0``).

    Raises
    ------
    ValueError
        If ``t`` has fewer than two points or is not uniformly spaced.
    """
    # Scope the filters to this call: a bare warnings.filterwarnings() would
    # keep UserWarning/FutureWarning silenced process-wide after returning.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=FutureWarning)

        # The CV model relies on one constant step (t_default), so reject
        # non-uniform sampling explicitly instead of silently using the
        # smallest step. The relative tolerance absorbs float round-off.
        steps = np.diff(np.asarray(t, dtype=float))
        if steps.size == 0:
            raise ValueError("t must contain at least two time points")
        dt = np.unique(np.round(steps, decimals=14))[0]
        if not np.allclose(steps, dt, rtol=1e-6, atol=0.0):
            raise ValueError("t must be uniformly spaced (non-constant time steps found)")

        # Search with the same feature library and max_iter as the final
        # model below, so the tuned alpha applies to the model actually fitted.
        cv_model = ps.SINDy(
            feature_library=ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False),
            optimizer=ps.optimizers.FROLS(max_iter=max_iter),
            t_default=dt,
        )

        multi1 = CVvec0  # start search vector for CV
        search = None

        # repeated nested search
        for n in tqdm(range(CViter), desc='FROLS', disable=silent):
            param_grid = {
                "optimizer__alpha": multi1,
                "differentiation_method__order": [1, 2],
            }
            search = GridSearchCV(cv_model, param_grid, cv=TimeSeriesSplit(n_splits=CVsplit), n_jobs=njobs)
            with suppress_stderr(), suppress_stdout():
                search.fit(x)
            difforder = search.best_params_['differentiation_method__order']
            alpha = search.best_params_['optimizer__alpha']

            # Next round: +/- refine_step**n decades around the best alpha,
            # with the same grid size (CVrefine) as the other solvers.
            half_width = refine_step**n
            multi1 = np.logspace(np.log10(alpha) - half_width, np.log10(alpha) + half_width, CVrefine)

        differentiation_method = ps.differentiation.FiniteDifference(order=difforder)
        feature_library = ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False)
        sindy_optimizer = ps.optimizers.FROLS(alpha=alpha, max_iter=max_iter)
        model = ps.SINDy(
            differentiation_method=differentiation_method,
            feature_library=feature_library,
            optimizer=sindy_optimizer,
            feature_names=None if featnames is None else list(featnames),
        )
        fitrs = model.fit(x, t)
    return fitrs, search


@ignore_warnings(category=ConvergenceWarning)
def MIOSR(x,t,difforder=2,degree=4,alpha=0.01,target_sparsity=12,group_sparsity=None,featnames=["x_1", "x_2", "x_3"], njobs=1,silent=False):
    """Identify a SINDy model with the MIOSR optimizer, tuning hyperparameters by nested CV.

    The MIOSR ``alpha``, ``target_sparsity`` (1..7) and the finite-difference
    ``order`` (1 or 2) are chosen by a grid search with
    ``TimeSeriesSplit(CVsplit)``. The ``alpha`` grid starts from ``CVvec0``
    and is refined over ``CViter`` rounds on log grids of ``CVrefine`` points
    centred on the current best ``alpha``. The final model is then fit on all
    of ``x``.

    Parameters
    ----------
    x : array-like
        Measurements passed to ``SINDy.fit`` (search and final fit).
    t : array-like
        Time points of ``x``; must be uniformly spaced.
    difforder, alpha, target_sparsity :
        Initial values; overwritten by the CV search whenever ``CViter > 0``.
    degree : int
        Degree of the bias-free polynomial feature library, used both for the
        search and for the final fit.
    group_sparsity :
        Passed unchanged to the MIOSR optimizer, both for the search and for
        the final fit.
    featnames : list of str or None
        Feature names of the final model. A copy is passed, so the returned
        model never aliases the shared default list.
    njobs : int
        ``n_jobs`` forwarded to ``GridSearchCV``.
    silent : bool
        If True, disable the tqdm progress bar.

    Returns
    -------
    fitrs
        Return value of ``fit(x, t)`` on the final ``ps.SINDy`` model.
    search : GridSearchCV or None
        Grid search of the last round (None if ``CViter == 0``).

    Raises
    ------
    ValueError
        If ``t`` has fewer than two points or is not uniformly spaced.
    """
    # Scope the filters to this call: a bare warnings.filterwarnings() would
    # keep UserWarning/FutureWarning silenced process-wide after returning.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=FutureWarning)

        # The CV model relies on one constant step (t_default), so reject
        # non-uniform sampling explicitly instead of silently using the
        # smallest step. The relative tolerance absorbs float round-off.
        steps = np.diff(np.asarray(t, dtype=float))
        if steps.size == 0:
            raise ValueError("t must contain at least two time points")
        dt = np.unique(np.round(steps, decimals=14))[0]
        if not np.allclose(steps, dt, rtol=1e-6, atol=0.0):
            raise ValueError("t must be uniformly spaced (non-constant time steps found)")

        # Search with the same feature library and group_sparsity as the
        # final model below, so the tuned values apply to the model fitted.
        cv_model = ps.SINDy(
            feature_library=ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False),
            optimizer=ps.optimizers.MIOSR(group_sparsity=group_sparsity),
            t_default=dt,
        )

        multi1 = CVvec0  # start search vector for CV
        search = None

        # repeated nested search
        for n in tqdm(range(CViter), desc='MIOSR', disable=silent):
            param_grid = {
                "optimizer__alpha": multi1,
                "optimizer__target_sparsity": [1, 2, 3, 4, 5, 6, 7],
                "differentiation_method__order": [1, 2],
            }
            search = GridSearchCV(cv_model, param_grid, cv=TimeSeriesSplit(n_splits=CVsplit), n_jobs=njobs)
            with suppress_stderr(), suppress_stdout():
                search.fit(x)
            difforder = search.best_params_['differentiation_method__order']
            alpha = search.best_params_['optimizer__alpha']
            target_sparsity = search.best_params_['optimizer__target_sparsity']

            # Next round: +/- refine_step**n decades around the best alpha.
            half_width = refine_step**n
            multi1 = np.logspace(np.log10(alpha) - half_width, np.log10(alpha) + half_width, CVrefine)

        differentiation_method = ps.differentiation.FiniteDifference(order=difforder)
        feature_library = ps.feature_library.PolynomialLibrary(degree=degree, include_bias=False)
        sindy_optimizer = ps.optimizers.MIOSR(alpha=alpha, target_sparsity=target_sparsity, group_sparsity=group_sparsity)
        model = ps.SINDy(
            differentiation_method=differentiation_method,
            feature_library=feature_library,
            optimizer=sindy_optimizer,
            feature_names=None if featnames is None else list(featnames),
        )
        fitrs = model.fit(x, t)
    return fitrs, search
