# Modified from PySINDy package
# Source: https://github.com/dynamicslab/pysindy
import numpy as np
import warnings
from itertools import combinations
from itertools import combinations_with_replacement as combinations_w_r
from itertools import product as iproduct

from sklearn import __version__
from sklearn.utils.validation import check_is_fitted

from pysindy.differentiation import FiniteDifference
from pysindy.feature_library.base import x_sequence_or_item
from pysindy.feature_library import PDELibrary

from pysindy.utils import AxesArray
from pysindy.utils import comprehend_axes


class ExtendedPDELibrary(PDELibrary):
    """
    A subclass of PDELibrary that applies the specified library_functions to
    the computed derivatives as well as to the original input variables.

    The input variables and their derivatives are pooled into a single list
    of ``n_features * (1 + num_derivatives)`` candidate columns: all input
    variables first, then one block of ``n_features`` columns per derivative
    multiindex. Every library function is evaluated on every combination of
    these columns whose size equals the function's argument count.
    Combinations are taken without replacement when ``interaction_only=True``
    and with replacement otherwise.

    The feature matrix therefore has columns for:

      1) (optional) bias
      2) library_functions(combinations of input variables and derivatives)

    Unlike PDELibrary, pure derivative terms and products of library terms
    with derivatives are NOT appended as separate columns. A derivative
    appears on its own only if some library function returns its argument
    unchanged (e.g. ``lambda x: x``). ``include_interaction`` is forwarded to
    PDELibrary but is not read by the ``fit``, ``transform`` or
    ``get_feature_names`` methods defined in this class.

    Parameters
    ----------
    Same parameters as PDELibrary.
    """

    def __init__(
        self,
        library_functions=[],
        derivative_order=0,
        spatial_grid=None,
        temporal_grid=None,
        interaction_only=True,
        function_names=None,
        include_bias=False,
        include_interaction=True,
        library_ensemble=False,
        ensemble_indices=[0],
        implicit_terms=False,
        multiindices=None,
        differentiation_method=FiniteDifference,
        diff_kwargs={},
        is_uniform=None,
        periodic=None,
    ):
        # The signature is spelled out explicitly (rather than **kwargs) so that
        # scikit-learn style parameter introspection sees the same arguments
        # as PDELibrary.
        super().__init__(
            library_functions=library_functions,
            derivative_order=derivative_order,
            spatial_grid=spatial_grid,
            temporal_grid=temporal_grid,
            interaction_only=interaction_only,
            function_names=function_names,
            include_bias=include_bias,
            include_interaction=include_interaction,
            library_ensemble=library_ensemble,
            ensemble_indices=ensemble_indices,
            implicit_terms=implicit_terms,
            multiindices=multiindices,
            differentiation_method=differentiation_method,
            diff_kwargs=diff_kwargs,
            is_uniform=is_uniform,
            periodic=periodic,
        )

    @staticmethod
    def _combinations(n_features, n_args, interaction_only):
        """
        Get combinations of indices for features, analogous to PDELibrary's approach.

        Uses ``itertools.combinations`` when ``interaction_only`` is true
        (no repeated index) and ``combinations_with_replacement`` otherwise.
        """
        comb = combinations if interaction_only else combinations_w_r
        return comb(range(n_features), n_args)

    def _get_n_features_in(self):
        """Return the number of input features recorded by ``fit``.

        The attribute name depends on the installed scikit-learn version,
        mirroring the branch used in ``fit``.
        """
        if float(__version__[:3]) >= 1.0:
            return self.n_features_in_
        return self.n_input_features_

    def _function_combinations(self, n_features):
        """Yield ``(function_index, function, column_indices)`` for every library term.

        ``column_indices`` index into the pooled columns
        ``[inputs..., derivatives...]`` (``n_features * (1 + num_derivatives)``
        of them). The yield order defines the output column order, so ``fit``,
        ``get_feature_names`` and ``transform`` all go through this generator
        to stay consistent with one another.
        """
        n_columns = n_features * (1 + self.num_derivatives)
        for i, f in enumerate(self.functions):
            n_args = f.__code__.co_argcount
            for c in self._combinations(n_columns, n_args, self.interaction_only):
                yield i, f, c

    @x_sequence_or_item
    def fit(self, x_full, y=None):
        """Compute number of output features.

        Parameters
        ----------
        x : array-like or sequence of array-like
            Measurement data. The number of features is read from the
            ``ax_coord`` axis of the first element.
        y : None
            Ignored.

        Returns
        -------
        self : instance
        """
        n_features = x_full[0].shape[x_full[0].ax_coord]

        if float(__version__[:3]) >= 1.0:
            self.n_features_in_ = n_features
        else:
            self.n_input_features_ = n_features

        # One output column per (library function, column combination) pair.
        n_output_features = sum(1 for _ in self._function_combinations(n_features))

        # If there is a constant term, add 1 to n_output_features
        if self.include_bias:
            n_output_features += 1

        self.n_output_features_ = n_output_features

        # required to generate the function names
        self.get_feature_names()

        return self

    def get_feature_names(self, input_features=None):
        """Return feature names for output features.

        Derivative columns are named ``<input>_<axes>``, where ``<axes>``
        repeats each differentiated axis label once per derivative order.
        Axis labels are 1-based axis numbers, or ``t`` for the time/sample
        axes when ``implicit_terms`` is set.

        Side effect: if ``function_names`` is None, it is filled with default
        names of the form ``f<i>(args)``, one per library function.

        Parameters
        ----------
        input_features : list of string, length n_features, optional
            String names for input features if available. By default,
            "x0", "x1", ... "xn_features" is used.

        Returns
        -------
        output_feature_names : list of string, length n_output_features
        """
        check_is_fitted(self)
        n_features = self._get_n_features_in()

        if input_features is None:
            input_features = ["x%d" % i for i in range(n_features)]
        if self.function_names is None:
            # function_names is indexed by library-function position, so there
            # must be one default name per function (not per input feature).
            self.function_names = list(
                map(
                    lambda i: (lambda *x: "f" + str(i) + "(" + ",".join(x) + ")"),
                    range(len(self.functions)),
                )
            )
        feature_names = []

        # Include constant term
        if self.include_bias:
            feature_names.append("1")

        def derivative_string(multiindex):
            ret = ""
            for axis in range(self.ind_range):
                if self.implicit_terms and (
                    axis
                    in [
                        self.spatiotemporal_grid.ax_time,
                        self.spatiotemporal_grid.ax_sample,
                    ]
                ):
                    str_deriv = "t"
                else:
                    str_deriv = str(axis + 1)
                for _ in range(multiindex[axis]):
                    ret = ret + str_deriv
            return ret

        # Names of the pooled columns: all inputs first, then one block of
        # n_features names per derivative multiindex (same order as transform).
        # list(...) lets tuple/ndarray input_features be extended too.
        all_features = list(input_features) + [
            input_features[j] + "_" + derivative_string(self.multiindices[k])
            for k in range(self.num_derivatives)
            for j in range(n_features)
        ]

        # Library functions applied to combinations of inputs and derivatives
        for i, _, c in self._function_combinations(n_features):
            feature_names.append(
                self.function_names[i](*[all_features[j] for j in c])
            )

        return feature_names

    @x_sequence_or_item
    def transform(self, x_full):
        """Transform data to pde features

        Parameters
        ----------
        x : array-like, shape (n_samples, n_features)
            The data to transform, row by row.

        Returns
        -------
        xp : np.ndarray, shape (n_samples, n_output_features)
            The matrix of features: an optional bias column followed by the
            library_functions applied to combinations of the input variables
            and their derivatives (see the class docstring).

        Raises
        ------
        ValueError
            If the number of features differs from the one seen in ``fit``.
        """
        check_is_fitted(self)
        n_features_fit = self._get_n_features_in()

        xp_full = []
        for x in x_full:
            n_features = x.shape[x.ax_coord]
            if n_features != n_features_fit:
                raise ValueError("x shape does not match training shape")

            # Output and intermediate buffers share every axis with x except
            # the last one, which holds the feature columns.
            shape = np.array(x.shape)
            shape[-1] = self.n_output_features_
            xp = np.empty(shape, dtype=x.dtype)

            # Derivative terms: one block of n_features columns per multiindex.
            shape[-1] = n_features * self.num_derivatives
            library_derivatives = np.empty(shape, dtype=x.dtype)
            library_idx = 0
            for multiindex in self.multiindices:
                derivs = x
                for axis in range(self.ind_range):
                    if multiindex[axis] > 0:
                        # Select the 1-D coordinate array along `axis` from the
                        # spatiotemporal grid (last index picks the coordinate).
                        s = [0 for dim in self.spatiotemporal_grid.shape]
                        s[axis] = slice(self.spatiotemporal_grid.shape[axis])
                        s[-1] = axis

                        derivs = self.differentiation_method(
                            d=multiindex[axis],
                            axis=axis,
                            **self.diff_kwargs,
                        )._differentiate(derivs, self.spatiotemporal_grid[tuple(s)])
                library_derivatives[
                    ..., library_idx : library_idx + n_features
                ] = derivs
                library_idx += n_features

            # Pooled columns that the library functions draw their arguments from.
            x_dx = np.concatenate([x, library_derivatives], axis=-1)

            library_idx = 0

            # constant term
            if self.include_bias:
                xp[..., library_idx] = 1
                library_idx += 1

            # library function terms, in the same order as get_feature_names
            for _, f, c in self._function_combinations(n_features):
                xp[..., library_idx] = f(*[x_dx[..., j] for j in c])
                library_idx += 1

            xp = AxesArray(xp, comprehend_axes(xp))
            xp_full.append(xp)
        if self.library_ensemble:
            xp_full = self._ensemble(xp_full)
        return xp_full
    