# Modified from PySINDy package
# Source: https://github.com/dynamicslab/pysindy
from pysindy.pysindy import SINDy, _adapt_to_multiple_trajectories, _comprehend_and_validate_inputs
import sklearn
from sklearn.base import TransformerMixin

import warnings
from itertools import product
from typing import Collection
from typing import Sequence

import numpy as np
from scipy.integrate import odeint
from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d
from scipy.linalg import LinAlgWarning
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
from sklearn.utils.validation import check_is_fitted

from pysindy.differentiation import FiniteDifference
from pysindy.feature_library import PolynomialLibrary
from pysindy.optimizers import EnsembleOptimizer
from pysindy.optimizers import SINDyOptimizer

try:  # Waiting on PEP 690 to lazy import CVXPY
    from pysindy.optimizers import SINDyPI
    sindy_pi_flag = True
except ImportError:
    sindy_pi_flag = False
from pysindy.optimizers import STLSQ
from pysindy.utils import AxesArray
from pysindy.utils import comprehend_axes
from pysindy.utils import concat_sample_axis
from pysindy.utils import drop_nan_samples
from pysindy.utils import SampleConcatter
from pysindy.utils import validate_control_variables
from pysindy.utils import validate_input
from pysindy.utils import validate_no_reshape
from pysindy.utils import print_model


class ConstraintTransformation(TransformerMixin):
    """
    Project library features onto a constraint basis ``Q`` (and optionally ``P``).

    For an input ``x`` of shape (n_samples, n_library_features) the feature
    axis is contracted with the second axis of each tensor, and the
    (sample, target) pair is flattened into rows (sample-major):

        XQ[(i, j), r] = sum_k x[i, k] * Q[j, k, r]

    giving an array of shape (n_samples * n_targets, rank). When ``P`` is
    given, the analogous ``XP`` block, multiplied by ``pq_prior_ratio``, is
    appended as extra columns. The transformer is stateless: ``fit`` is a
    no-op and it always reports itself as fitted.
    """

    def __init__(self, Q, P=None, pq_prior_ratio=0.1):
        """
        Parameters
        ----------
        Q : np.ndarray
            Basis tensor of shape (n_targets, n_library_features, rank).
        P : np.ndarray, optional
            Orthogonal-complement tensor of shape
            (n_targets, n_library_features, n_complement); for a full
            complement n_complement is n_targets * n_library_features - rank.
        pq_prior_ratio : float, optional
            Multiplier applied to the ``P`` features. A smaller ratio shrinks
            those features, so matching the data needs larger ``P``
            coefficients, which a coefficient penalty (e.g. l2) discourages
            more strongly. Larger values therefore mean less constraint from Q.
        """
        self.Q = Q
        self.P = P
        self.pq_prior_ratio = pq_prior_ratio

    def fit(self, x, y=None):
        """No-op: the transformation has no learned state. Returns ``self``."""
        return self

    def __sklearn_is_fitted__(self):
        # Stateless transformer, so it is always considered fitted.
        return True

    def transform(self, x):
        """
        Project ``x`` onto the constraint tensor(s).

        Parameters
        ----------
        x : np.ndarray of shape (n_samples, n_library_features)

        Returns
        -------
        np.ndarray
            Shape (n_samples * n_targets, rank) when ``P`` is None, otherwise
            (n_samples * n_targets, rank + n_complement).
        """
        q_rank = self.Q.shape[2]
        xq = np.einsum("ik,jkr->ijr", x, self.Q).reshape(-1, q_rank)
        if self.P is None:
            return xq
        # Out-of-place multiply: an in-place ``*=`` raises a casting error when
        # the einsum result has an integer dtype and the ratio is a float.
        # smaller ratio -> smaller feature magnitude -> larger coefficients
        # -> larger l2 penalty
        xp = np.einsum("ik,jkr->ijr", x, self.P) * self.pq_prior_ratio
        return np.concatenate((xq, xp.reshape(-1, self.P.shape[2])), axis=1)


class ConstrainedSINDy(SINDy):
    """
    SINDy model with a linear constraint on the coefficients.

    The full coefficient matrix ``Xi`` (one row per derivative target, one
    column per library feature) is not regressed directly. It is
    parameterized as

        Xi[j, k] = sum_r Q[j, k, r] * c[r]
                   + constraint_breaking_factor * sum_s P[j, k, s] * d[s]

    where the second term is present only when an orthogonal-complement
    tensor ``P`` is supplied. The optimizer fits the reduced coefficients
    ``c`` (and ``d``) against the derivative data flattened to one column;
    ``unconstrained_coefficients`` maps them back to ``Xi``.

    Parameters
    ----------
    constraint_tensor : np.ndarray
        Basis tensor ``Q`` of shape (n_targets, n_library_features, rank),
        where ``n_targets`` is the number of derivative columns. Required
        before calling ``fit``.
    orth_comp_tensor : np.ndarray, optional
        Orthogonal-complement tensor ``P`` whose first two axes match
        ``constraint_tensor``. When given, its components are fitted too.
    constraint_breaking_factor : float, optional
        Scale applied to the ``P`` features (and to ``P`` when reconstructing
        coefficients). When None, ``ConstraintTransformation``'s default
        ``pq_prior_ratio`` is used.
    optimizer, feature_library, differentiation_method, feature_names,
    t_default, discrete_time
        Forwarded unchanged to ``SINDy.__init__``.
    """

    def __init__(
        self,
        constraint_tensor=None,
        orth_comp_tensor=None,
        constraint_breaking_factor=None,
        optimizer=None,
        feature_library=None,
        differentiation_method=None,
        feature_names=None,
        t_default=1,
        discrete_time=False,
    ):
        super().__init__(
            optimizer=optimizer,
            feature_library=feature_library,
            differentiation_method=differentiation_method,
            feature_names=feature_names,
            t_default=t_default,
            discrete_time=discrete_time,
        )
        # Stored under the constructor argument names so scikit-learn's
        # get_params / set_params / clone (and the estimator repr) work.
        # ``Q`` and ``P`` are kept as aliases for backward compatibility.
        self.constraint_tensor = constraint_tensor  # (n_targets, n_library_features, rank)
        self.orth_comp_tensor = orth_comp_tensor
        self.constraint_breaking_factor = constraint_breaking_factor

    @property
    def Q(self):
        """Alias for ``constraint_tensor``."""
        return self.constraint_tensor

    @Q.setter
    def Q(self, value):
        self.constraint_tensor = value

    @property
    def P(self):
        """Alias for ``orth_comp_tensor``."""
        return self.orth_comp_tensor

    @P.setter
    def P(self, value):
        self.orth_comp_tensor = value

    def _make_constraint_transformation(self):
        """Validate the constraint tensors and build the projection pipeline step."""
        basis = self.constraint_tensor
        complement = self.orth_comp_tensor
        if basis is None:
            raise ValueError(
                "constraint_tensor must be provided to fit a ConstrainedSINDy model"
            )
        if complement is not None and complement.shape[:2] != basis.shape[:2]:
            raise ValueError(
                f"orth_comp_tensor leading shape {complement.shape[:2]} must match "
                f"constraint_tensor leading shape {basis.shape[:2]}"
            )
        if self.constraint_breaking_factor is None:
            # Forwarding None would break scaling of the P features
            # (None * array); use the transformation's own default instead.
            return ConstraintTransformation(basis, complement)
        return ConstraintTransformation(
            basis, complement, self.constraint_breaking_factor
        )

    def fit(
        self,
        x,
        t=None,
        x_dot=None,
        u=None,
        multiple_trajectories=False,
        unbias=True,
        quiet=False,
        ensemble=False,
        library_ensemble=False,
        replace=True,
        n_candidates_to_drop=1,
        n_subset=None,
        n_models=None,
        ensemble_aggregator=None,
    ):
        """
        Fit the constrained model to measurement data.

        Follows the same input handling as ``SINDy.fit`` (through the same
        pysindy helpers). In addition, the library features are projected onto
        the constraint tensor(s) and the derivative data is flattened to a
        single column, so the optimizer regresses the reduced coefficients.

        Parameters
        ----------
        x, t, x_dot, u, multiple_trajectories, unbias, quiet
            Measurement data and options handled as in ``SINDy.fit``; ``t``
            defaults to ``t_default``.
        ensemble, library_ensemble, n_subset, n_models
            Deprecated ensembling options that wrap the optimizer in an
            ``EnsembleOptimizer``; prefer passing an ``EnsembleOptimizer``.
        replace, n_candidates_to_drop, ensemble_aggregator
            Accepted for signature compatibility; not used by this method.

        Returns
        -------
        self : ConstrainedSINDy

        Raises
        ------
        ValueError
            If ``constraint_tensor`` is missing, ``orth_comp_tensor`` does not
            match its leading axes, the number of derivative columns differs
            from ``constraint_tensor.shape[0]``, or ``n_models``/``n_subset``
            is not positive.
        TypeError
            If ``multiple_trajectories`` is set but inputs are not Sequences.
        """
        # Validate the constraint configuration before any data processing.
        constraint_step = self._make_constraint_transformation()

        if ensemble or library_ensemble:
            # DeprecationWarning are ignored by default...
            warnings.warn(
                "Ensembling arguments are deprecated. "
                "Use the EnsembleOptimizer class instead.",
                UserWarning,
            )
        if t is None:
            t = self.t_default

        if not multiple_trajectories:
            x, t, x_dot, u = _adapt_to_multiple_trajectories(x, t, x_dot, u)
            multiple_trajectories = True
        elif (
            not isinstance(x, Sequence)
            or (not isinstance(x_dot, Sequence) and x_dot is not None)
            or (not isinstance(u, Sequence) and u is not None)
        ):
            raise TypeError(
                "If multiple trajectories set, x and if included, "
                "x_dot and u, must be Sequences"
            )
        x, x_dot, u = _comprehend_and_validate_inputs(
            x, t, x_dot, u, self.feature_library
        )

        if (n_models is not None) and n_models <= 0:
            raise ValueError("n_models must be a positive integer")
        if (n_subset is not None) and n_subset <= 0:
            raise ValueError("n_subset must be a positive integer")

        if u is None:
            self.n_control_features_ = 0
        else:
            u = validate_control_variables(
                x,
                u,
                trim_last_point=(self.discrete_time and x_dot is None),
            )
            self.n_control_features_ = u[0].shape[u[0].ax_coord]
        x, x_dot = self._process_multiple_trajectories(x, t, x_dot)

        # Set ensemble variables
        self.ensemble = ensemble
        self.library_ensemble = library_ensemble

        # Append control variables
        if u is not None:
            x = [np.concatenate((xi, ui), axis=xi.ax_coord) for xi, ui in zip(x, u)]

        if hasattr(self.optimizer, "unbias"):
            unbias = self.optimizer.unbias

        # Backwards compatibility for the deprecated ensemble options.
        # With bagging, n_subset defaults to the sample count of the first
        # trajectory.
        if ensemble and n_subset is None:
            n_subset = x[0].shape[x[0].ax_time]
        if library_ensemble:
            self.feature_library.library_ensemble = False
        if ensemble or library_ensemble:
            ensemble_kwargs = {"n_models": n_models}
            if ensemble:
                ensemble_kwargs.update(bagging=True, n_subset=n_subset)
            if library_ensemble:
                ensemble_kwargs["library_ensemble"] = True
            optimizer = SINDyOptimizer(
                EnsembleOptimizer(self.optimizer, **ensemble_kwargs),
                unbias=unbias,
            )
            self.coef_list = optimizer.optimizer.coef_list
        else:
            optimizer = SINDyOptimizer(self.optimizer, unbias=unbias)

        steps = [
            ("features", self.feature_library),
            ("shaping", SampleConcatter()),
            ("constraint_transform", constraint_step),
            ("model", optimizer),
        ]
        x_dot = concat_sample_axis(x_dot)
        n_targets = self.constraint_tensor.shape[0]
        if x_dot.shape[-1] != n_targets:
            raise ValueError(
                f"x_dot has {x_dot.shape[-1]} columns but constraint_tensor.shape[0] "
                f"is {n_targets} (x_dot shape {x_dot.shape}, constraint_tensor "
                f"shape {self.constraint_tensor.shape})"
            )
        # Flatten (n_samples, n_targets) -> (n_samples * n_targets, 1) in
        # row-major order, matching the sample-major row order produced by
        # ConstraintTransformation.transform.
        x_dot = x_dot.reshape(-1, 1)
        self.model = Pipeline(steps)
        action = "ignore" if quiet else "default"
        with warnings.catch_warnings():
            warnings.filterwarnings(action, category=ConvergenceWarning)
            warnings.filterwarnings(action, category=LinAlgWarning)
            warnings.filterwarnings(action, category=UserWarning)
            self.model.fit(x, x_dot)

        # New version of sklearn changes attribute name
        if float(sklearn.__version__[:3]) >= 1.0:
            self.n_features_in_ = self.model.steps[0][1].n_features_in_
            n_input_features = self.model.steps[0][1].n_features_in_
        else:
            self.n_input_features_ = self.model.steps[0][1].n_input_features_
            n_input_features = self.model.steps[0][1].n_input_features_
        self.n_output_features_ = self.model.steps[0][1].n_output_features_

        if self.feature_names is None:
            feature_names = []
            for i in range(n_input_features - self.n_control_features_):
                feature_names.append("x" + str(i))
            for i in range(self.n_control_features_):
                feature_names.append("u" + str(i))
            self.feature_names = feature_names

        return self

    def equations(self, precision=3):
        """
        Get the right hand sides of the SINDy model equations.

        Parameters
        ----------
        precision: int, optional (default 3)
            Number of decimal points to include for each coefficient in the
            equation.

        Returns
        -------
        equations: list of strings
            List of strings representing the SINDy model equations for each
            input feature, built from the unconstrained coefficients.
        """
        check_is_fitted(self, "model")
        if self.discrete_time:
            base_feature_names = [f + "[k]" for f in self.feature_names]
        else:
            base_feature_names = self.feature_names
        input_features = self.model.steps[0][1].get_feature_names(base_feature_names)
        coef = self.unconstrained_coefficients()
        return [print_model(row, input_features, precision=precision) for row in coef]

    def unconstrained_coefficients(self):
        """
        Map the fitted reduced coefficients back to the full coefficient matrix.

        Uses the tensors and scale factor stored in the fitted
        ``constraint_transform`` pipeline step. The result is therefore
        consistent with what was actually fitted, even if the estimator's
        attributes were changed afterwards.

        Returns
        -------
        coef : np.ndarray of shape (n_targets, n_library_features)

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If the model has not been fitted.
        """
        check_is_fitted(self, "model")
        transform = self.model.named_steps["constraint_transform"]
        basis = transform.Q
        if transform.P is not None:
            # The optimizer saw the P-features scaled by the ratio, so its
            # coefficients are expressed in the scaled complement basis.
            basis = np.concatenate(
                (transform.Q, transform.P * transform.pq_prior_ratio), axis=2
            )
        # One row: fit flattened x_dot to a single target column.
        reduced = self.model.steps[-1][1].coef_
        return np.einsum("jkr,ir->ijk", basis, reduced)[0]
    