from copy import deepcopy
from typing import Dict, Callable, Union, Any
from pysindy import SINDy as _SINDy
from pysindy import PolynomialLibrary, optimizers
import re
import traceback
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import GridSearchCV


class SINDy(_SINDy):
    """pysindy ``SINDy`` model that can render its equations with explicit operators."""

    def _format_equation(self, expr: str) -> str:
        """Rewrite one equation string so that its operators are explicit.

        A ``" * "`` is inserted between a numeric coefficient and the term
        that directly follows it (the constant term ``1`` or a name from
        ``self.feature_names``), and every ``^`` is replaced by ``**``.

        Args:
            expr: One equation string, e.g. an element of ``self.equations()``.

        Returns:
            The rewritten equation string.
        """
        expr = re.sub(r"(\d+\.?\d*) (1)", repl=r"\1 * \2", string=expr)
        # Names are substituted one at a time, in ``feature_names`` order.
        for var_name in self.feature_names:
            # Escape the name so that regex metacharacters in a user-supplied
            # feature name (e.g. "x(t)" or "x.dot") are matched literally.
            pattern = fr"(\d+\.?\d*) ({re.escape(var_name)})"
            expr = re.sub(pattern, repl=r"\1 * \2", string=expr)
        return expr.replace("^", "**")

    def format_equations(self) -> list:
        """Return all model equations, formatted by ``_format_equation``.

        The equations are requested with ``precision=100`` so that the
        coefficients are printed with 100 decimal places.
        """
        return [self._format_equation(eq) for eq in self.equations(precision=100)]
        
        
class LinearSINDy(SINDy):
    """SINDy model whose feature library is fixed to first-degree polynomials.

    The library is ``PolynomialLibrary(degree=1, include_interaction=False,
    include_bias=False)``; all other constructor arguments are forwarded to
    the parent class unchanged.

    Attributes set by this class:
        times: Initialised to ``None`` in ``__init__``.
        observed_trajectory: State data passed to the most recent ``fit``.
        observed_t: Time data passed to the most recent ``fit``.
    """

    def __init__(
        self,
        optimizer=None,
        differentiation_method=None,
        feature_names=None,
        t_default=1,
        discrete_time=False,
    ):
        super().__init__(
            optimizer=optimizer,
            feature_library=PolynomialLibrary(
                degree=1,
                include_interaction=False,
                include_bias=False,
            ),
            feature_names=feature_names,
            differentiation_method=differentiation_method,
            t_default=t_default,
            discrete_time=discrete_time,
        )
        self.times = None

    def get_system_matrix(self) -> np.ndarray:
        """Return the coefficient matrix of the fitted linear model.

        The matrix is read directly from ``self.coefficients()`` rather than
        parsed out of the printed equations, so it does not depend on the
        feature names (custom ``feature_names`` work as well as the default
        ``x0, x1, ...``).

        Returns:
            A float array copy of the model coefficients; per the pysindy
            documentation row ``i`` belongs to the ``i``-th equation and
            column ``j`` to the ``j``-th library feature, which for this
            library is the ``j``-th state variable.
        """
        # np.array(...) copies, so callers cannot mutate the optimizer state.
        return np.array(self.coefficients(), dtype=float)

    def fit(self, x, t: Union[None, np.ndarray] = None) -> "LinearSINDy":
        """Fit the model and remember the data it was fitted on.

        Args:
            x: Observed trajectory. If ``t`` is None, the first column of
                ``x`` is taken as time and the remaining columns as states.
            t: Time points. If None, time must be the first column of ``x``.

        Returns:
            Whatever the parent ``fit`` returns (the fitted model).
        """
        if t is None:
            # Single-array calling convention: split [time | states].
            t, x = x[:, 0], x[:, 1:]
        # Kept so get_optimization_error can re-simulate the same trajectory.
        self.observed_trajectory = x
        self.observed_t = t
        return super().fit(x, t)

    def get_info(self) -> Dict:
        """Return a summary dict with the model name and optimization error.

        The error is ``np.nan`` when it cannot be computed (for example when
        the model has not been fitted or the simulation fails).
        """
        try:
            opt_err = self.get_optimization_error()
        except Exception:
            # Best effort: report NaN instead of failing, but let
            # KeyboardInterrupt / SystemExit propagate.
            opt_err = np.nan
        return {
            "name": "LinearSindy",
            "optimization_error": opt_err,
        }

    def get_optimization_error(self, error_func: Callable = mean_absolute_error):
        """Compare the observed trajectory with a re-simulation of the model.

        The model is simulated from the first observed state over the
        observed time points, and ``error_func(observed, simulated)`` is
        returned.

        Args:
            error_func: Callable taking ``(observed, simulated)``.

        Returns:
            The value returned by ``error_func``.
        """
        simulated = self.simulate(
            x0=self.observed_trajectory[0],
            t=self.observed_t,
        )
        return error_func(self.observed_trajectory, simulated)
        

class LinearSINDyWithHyperparamSearch():
    """Grid search over ``LinearSINDy`` hyperparameters using ``GridSearchCV``.

    The constructor arguments (except ``scoring``) are forwarded to
    ``LinearSINDy``. ``scoring`` is either a metric function from
    ``sklearn.metrics`` with signature ``(y_true, y_pred)`` (the default,
    ``r2_score``) or anything ``GridSearchCV`` accepts as ``scoring``.
    ``GridSearchCV`` keeps the candidate with the highest score, so a metric
    should be greater-is-better.

    Attributes set by ``fit``:
        model: The unfitted ``LinearSINDy`` template handed to the search.
        search: The fitted ``GridSearchCV`` instance.
    """

    def __init__(
        self,
        optimizer=None,
        differentiation_method=None,
        feature_names=None,
        t_default=1,
        discrete_time=False,
        scoring: Callable = r2_score,
    ):
        self.optimizer = optimizer
        self.differentiation_method = differentiation_method
        self.feature_names = feature_names
        self.t_default = t_default
        self.discrete_time = discrete_time
        self.scoring = scoring

    def _score_with_metric(self, estimator, tx, y=None):
        """Scorer adapter: evaluate ``self.scoring`` through ``estimator.score``.

        ``tx`` is the ``[time | states]`` array built in ``fit``; it is split
        again and passed to the estimator's ``score`` method with
        ``metric=self.scoring``. ``y`` is accepted only to match the scorer
        call signature and is ignored.
        """
        return estimator.score(tx[:, 1:], t=tx[:, 0], metric=self.scoring)

    def _resolve_scoring(self):
        """Return a ``scoring`` value that ``GridSearchCV`` can actually use.

        A callable ``scoring`` is invoked by ``GridSearchCV`` as
        ``scorer(estimator, X[, y])``. A plain metric such as ``r2_score``
        expects ``(y_true, y_pred)`` and cannot be called that way, so metric
        functions defined in ``sklearn.metrics`` are wrapped in
        ``_score_with_metric``. Everything else (scorer objects, strings,
        ``None``, user-defined scorers) is passed through unchanged.
        """
        scoring = self.scoring
        module = getattr(scoring, "__module__", None) or ""
        is_sklearn_metric = (
            callable(scoring)
            and module.startswith("sklearn.metrics.")
            # Objects from the scorer module already have the scorer signature.
            and not module.startswith(
                ("sklearn.metrics._scorer", "sklearn.metrics.scorer")
            )
        )
        return self._score_with_metric if is_sklearn_metric else scoring

    def fit(self, x, t):
        """Run the grid search on one trajectory.

        Args:
            x: State trajectory, one row per time point.
            t: Time points, one per row of ``x``.

        Returns:
            self
        """
        self.model = LinearSINDy(
            optimizer=self.optimizer,
            differentiation_method=self.differentiation_method,
            feature_names=self.feature_names,
            t_default=self.t_default,
            discrete_time=self.discrete_time,
        )
        # GridSearchCV only forwards X to the estimator, so time travels as
        # the first column; LinearSINDy.fit splits it off again when t is None.
        t = np.asarray(t)
        tx = np.hstack([t.reshape(-1, 1), x])
        # NOTE(review): these keys assume the optimizer exposes ``threshold``,
        # ``alpha`` and ``max_iter`` and the differentiation method exposes
        # ``order`` -- confirm for non-default optimizer/differentiation args.
        param_grid = {
            "optimizer__threshold": [1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
            "optimizer__alpha": [0.01, 0.05, 0.1],
            "differentiation_method__order": [1, 2, 3],
            "optimizer__max_iter": [20, 100],
        }
        # Single "fold" that trains and scores on the complete trajectory.
        all_rows = np.arange(len(t), dtype=int)
        self.search = GridSearchCV(
            self.model,
            param_grid,
            cv=[(all_rows, all_rows)],
            refit=True,
            return_train_score=True,
            scoring=self._resolve_scoring(),
        )
        self.search.fit(tx)
        return self

    def get_system_matrix(self) -> np.ndarray:
        """Return the system matrix of the best estimator found by ``fit``.

        If the lookup raises ``AttributeError`` (e.g. ``fit`` has not been
        called), the traceback is printed and ``np.nan`` is returned.
        """
        try:
            return self.search.best_estimator_.get_system_matrix()
        except AttributeError:
            traceback.print_exc()
        return np.nan

    def get_info(self) -> Dict:
        """Return the info dict of the best estimator found by ``fit``.

        If the lookup raises ``AttributeError`` (e.g. ``fit`` has not been
        called), the traceback is printed and a placeholder dict with a NaN
        optimization error is returned.
        """
        try:
            return self.search.best_estimator_.get_info()
        except AttributeError:
            traceback.print_exc()
        return {
            "name": "LinearSindy",
            "optimization_error": np.nan,
        }


if __name__ == '__main__':
    # Demo: simulate a known 2-state linear ODE, fit the search wrapper on the
    # trajectory, then print the identified matrix and the info dict.
    import scipy.integrate

    def linear_ode(x, t):
        """Right-hand side of the reference linear system."""
        return np.array([-0.5*x[0] + 0.2*x[1],  -0.1*x[1]])

    t = np.linspace(0, 10, 512, endpoint=False, dtype=np.float32)
    initial_state = np.array([3, 2], dtype=np.float32)
    xs = scipy.integrate.odeint(
        func=linear_ode,
        y0=initial_state,
        t=t,
    ).astype(np.float32)

    model = LinearSINDyWithHyperparamSearch()
    model.fit(xs, t)
    print(model.get_system_matrix())
    print(model.get_info())