import functools

import numpy as np
from datafold.dynfold.base import TSCTransformerMixin
from datafold.pcfold import TSCDataFrame
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline

class TSCSampledNetwork(BaseEstimator, TSCTransformerMixin):
    """
    This is a simple wrapper for sampled neural networks.

    Parameters
    ----------
    nn
        A sklearn pipeline that represents the neural network (containing ``Dense``,
        ``Linear`` layers, etc. from the ``swimnetworks`` Python package). If the last
        layer of the pipeline already has weights, the network is not re-fitted in
        :meth:`fit`; otherwise it is fitted there.

    n_features_in
        Number of input features (also stored as ``n_features_in_``).

    n_features_out
        Number of output features (also stored as ``n_features_out_``).

    feature_names_in_
        Names of the input features. Defaults to ``["0", "1", ...]`` with
        ``n_features_in`` entries.

    Attributes
    ----------
    inverse_nn
        Map from network features back to the original states. It is ``None`` until
        :meth:`fit` has been called.

    References
    ----------
    See :cite:t:`bolager-2023` for the paper on sampled networks and the
    gitlab repository `swimnetworks <https://gitlab.com/felix.dietrich/swimnetworks>`__

    To install the package run

    .. code-block::

        pip install git+https://gitlab.com/felix.dietrich/swimnetworks

    """

    def __init__(
        self,
        nn: Pipeline,
        n_features_in: int,
        n_features_out: int,
        feature_names_in_=None,
    ):
        self.nn = nn
        # ``BaseEstimator.get_params`` (and therefore ``clone``) reads every constructor
        # argument back from an attribute with the *same* name, so the raw parameters are
        # kept in addition to the ``*_`` attributes used by the rest of the class.
        self.n_features_in = n_features_in
        self.n_features_out = n_features_out
        self.n_features_in_ = n_features_in
        self.n_features_out_ = n_features_out
        if feature_names_in_ is None:
            self.feature_names_in_ = [str(i) for i in range(n_features_in)]
        else:
            self.feature_names_in_ = feature_names_in_
        self.inverse_nn = None

    def __repr__(self):
        return "SWIM NETWORK"

    def get_feature_names_out(self, input_features=None):
        """Return ``["w0", "w1", ...]``, one name per column of the last layer's weights.

        ``input_features`` is accepted for sklearn API compatibility and ignored.
        """
        n_features_out = self.nn[-1].weights.shape[1]
        return [f"w{i}" for i in range(n_features_out)]

    @staticmethod
    def _apply_linear_inverse(x, modes):
        """Map features ``x`` back to state space with the least-squares ``modes``."""
        return x @ modes

    def fit(self, X: TSCDataFrame, y=None, **fit_params) -> "TSCSampledNetwork":
        """Fit the network (unless already fitted) and the inverse map.

        Parameters
        ----------
        X
            Time series collection used for fitting.
        y
            Ignored. Present for sklearn API compatibility.
        **fit_params
            inverse_nn
                Optional estimator with ``fit``/``predict``. It is fitted to map the
                network features of ``X`` back to the values of ``X``. If not given, a
                linear map is computed with least squares.

        Returns
        -------
        TSCSampledNetwork
            self
        """
        self._validate_datafold_data(X=X)
        # NOTE: feature-name validation (``_validate_feature_input`` /
        # ``_setup_feature_attrs_fit``) is currently disabled for this transformer.

        inverse_nn = self._read_fit_params(
            [("inverse_nn", None)], fit_params=fit_params
        )

        # Xm is needed by the least-squares inverse below even when the network is
        # already fitted, so the shift matrices are always computed.
        Xm, Xp = X.tsc.shift_matrices(snapshot_orientation="row")

        # A last layer without weights means the network has not been fitted yet.
        if self.nn[-1].weights is None:
            self.nn = self.nn.fit(Xm, Xp)

        if inverse_nn is None:
            # Linear map solving  nn.transform(Xm) @ K_modes ≈ Xm  in the least-squares
            # sense; ``rcond=None`` selects numpy's machine-precision cutoff explicitly.
            K_modes = np.linalg.lstsq(self.nn.transform(Xm), Xm, rcond=None)[0]
            # functools.partial (unlike a lambda) keeps the fitted estimator picklable.
            self.inverse_nn = functools.partial(
                self._apply_linear_inverse, modes=K_modes
            )
        else:
            # NOTE(review): the previous code selected columns with
            # ``X.columns.str.split(":")`` (which yields lists and cannot index) and
            # called the pipeline directly; all columns of X are now used as the target
            # states - confirm this is the intended target.
            X_features = self.nn.transform(X)
            inverse_nn.fit(X_features, X.to_numpy())
            self.inverse_nn = inverse_nn

        return self

    def transform(self, X):
        """Evaluate the network on ``X``.

        Returns a ``TSCDataFrame`` with the index of ``X`` and the columns from
        :meth:`get_feature_names_out`.
        """
        # Feature-input validation is skipped for performance reasons.
        X_return = self.nn.transform(X)
        return TSCDataFrame.from_same_indices_as(
            X, X_return, except_columns=self.get_feature_names_out()
        )

    def fit_transform(self, X: TSCDataFrame, y=None, **fit_params):
        """Fit to ``X`` and return its transform. ``y`` is ignored."""
        # ``transform`` already wraps the result in a TSCDataFrame with X's index and
        # the network feature names, so no second wrapping is needed.
        return self.fit(X, **fit_params).transform(X)

    def inverse_transform(self, X):
        """Map network features ``X`` back to the original state space.

        Uses ``predict`` when ``inverse_nn`` is an estimator (passed as a fit
        parameter); otherwise calls ``inverse_nn`` (the least-squares linear map).

        Raises
        ------
        NotFittedError
            If :meth:`fit` has not been called.
        """
        if self.inverse_nn is None:
            # Local import: only needed on this error path.
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "TSCSampledNetwork is not fitted yet; call 'fit' before "
                "'inverse_transform'."
            )
        if hasattr(self.inverse_nn, "predict"):
            return self.inverse_nn.predict(X)
        return self.inverse_nn(X)
