import warnings

# NOTE(review): this silences *all* warnings for the whole process as a side
# effect of importing this module (including deprecation warnings raised by
# numpy/datafold); consider scoping it with `warnings.catch_warnings()` only
# around the calls that need it.
warnings.filterwarnings('ignore')

import numpy as np
from datafold import EDMD, DMDStandard, DMDControl, TSCDataFrame, TSCTakensEmbedding, TSCPrincipalComponent
from utils import TSCSampledNetwork, time_delay_embedding
from sklearn.pipeline import Pipeline
from datafold.utils.general import diagmat_dot_mat

class RNN:
    """EDMD model whose dictionary is a sampled neural network.

    The model wraps a datafold ``EDMD`` pipeline. Optionally, the states are
    time-delay embedded (``time_delay``) and then reduced with PCA
    (``pca_components``) before being lifted by the network ``dictionary``.
    With ``control=True`` a ``DMDControl`` model is fitted instead of
    ``DMDStandard``.

    Besides the EDMD object, ``fit`` stores two least-squares matrices used by
    the fast prediction path of ``predict`` (no time delay, no control input):

    * ``K``       -- maps EDMD-transformed snapshots at step ``k`` to those at ``k+1``.
    * ``K_modes`` -- maps EDMD-transformed snapshots back to the state.
    """

    def __init__(self, dictionary, regularization_constant=1e-10, time_delay=None, pca_components=None,
                 control=False, feature_names_in=None) -> None:
        """
        Parameters
        ----------
        dictionary
            Sampled network used as the EDMD dictionary. Its last element must
            expose ``layer_width``, which is used as the number of dictionary
            features.
        regularization_constant
            Passed as ``rcond`` to the DMD model and to ``np.linalg.lstsq``
            (where it acts as the singular-value cutoff).
        time_delay
            Number of delays for ``TSCTakensEmbedding``; ``None`` disables the
            embedding and enables the fast prediction path.
        pca_components
            Number of PCA components applied after the time-delay embedding.
            Only used when ``time_delay`` is set.
        control
            If true, fit a ``DMDControl`` model (requires control inputs ``U``).
        feature_names_in
            Input feature names forwarded to the network step. Only used when
            ``time_delay`` is ``None``.
        """
        self.dictionary = dictionary
        self.rcond = regularization_constant
        self.edmd = None
        self.time_delay = time_delay
        self.pca_components = pca_components
        self.control = control
        self.feature_names_in = feature_names_in
        self.n_features_out_ = None
        self.n_features_in_ = None
        self.K = None
        self.K_modes = None

    def fit(self, X, U=None):
        """Fit the EDMD model and the matrices used for fast prediction.

        Parameters
        ----------
        X : TSCDataFrame
            Training time series.
        U : optional
            Control inputs; required when the model was created with
            ``control=True``.

        Returns
        -------
        RNN
            ``self``, to allow call chaining.
        """
        assert isinstance(X, TSCDataFrame), 'Please provide data in TSCDataFrame format.'
        steps = self._build_dict_steps(X)
        self.n_features_out_ = self.dictionary[-1].layer_width

        if self.control:
            assert U is not None, "Please provide control values in order to fit a model with control."
            self.edmd = EDMD(steps, dmd_model=DMDControl(rcond=self.rcond), dict_preserves_id_state=False,
                             include_id_state=False, stepwise_transform=True)
            self.edmd.fit(X=X, U=U)
        else:
            self.edmd = EDMD(
                steps, dmd_model=DMDStandard(sys_mode='spectral', rcond=self.rcond, diagonalize=True),
                include_id_state=(self.time_delay is not None),
                dict_preserves_id_state=False, stepwise_transform=True)
            self.edmd.fit(X=X)

        self._fit_fast_predictor(X)
        return self

    def _build_dict_steps(self, X):
        """Return the EDMD dictionary pipeline steps for ``X`` and set ``n_features_in_``."""
        n_features_out = self.dictionary[-1].layer_width

        if self.time_delay is None:
            self.n_features_in_ = X.n_features
            return [
                ("swim-dict",
                 TSCSampledNetwork(nn=self.dictionary, n_features_in=X.n_features,
                                   n_features_out=n_features_out, feature_names_in_=self.feature_names_in)),
            ]

        # The delay embedding holds the current state plus `time_delay` delayed copies.
        n_delayed = X.n_features + self.time_delay * X.n_features
        self.n_features_in_ = n_delayed

        steps = [("time-delay", TSCTakensEmbedding(delays=self.time_delay))]
        if self.pca_components is None:
            n_network_in = n_delayed
        else:
            steps.append(("pca", TSCPrincipalComponent(n_components=self.pca_components)))
            # The network sits after PCA, so it receives `pca_components` features
            # rather than the full delay-embedded dimension.
            n_network_in = self.pca_components
        steps.append(
            ("swim-dict",
             TSCSampledNetwork(nn=self.dictionary, n_features_in=n_network_in,
                               n_features_out=n_features_out)))
        return steps

    def _fit_fast_predictor(self, X):
        """Fit ``K`` and ``K_modes`` by least squares on the shift matrices of ``X``."""
        Xm, Xp = X.tsc.shift_matrices(snapshot_orientation="row")
        psi_m = self.edmd.transform(Xm)
        self.K = np.linalg.lstsq(psi_m, self.edmd.transform(Xp), rcond=self.rcond)[0]
        # NOTE(review): with a time delay the transformed snapshots presumably lose
        # the first `time_delay` rows (Takens embedding), so the targets are sliced
        # to match -- confirm this alignment for data with several time series.
        targets = Xm if self.time_delay is None else Xm[self.time_delay:]
        self.K_modes = np.linalg.lstsq(psi_m, targets, rcond=self.rcond)[0]

    def predict(self, X, U=None, time_values=None):
        """Predict trajectories starting from the initial condition(s) ``X``.

        With control inputs or a time delay this delegates to ``EDMD.predict``.
        Otherwise it iterates the fitted ``K``/``K_modes`` matrices directly.
        At each step the state is decoded with ``K_modes``. It is then re-lifted
        with the network dictionary and advanced one step with ``K``.

        Parameters
        ----------
        X
            Initial condition(s): a NumPy array, or an object with ``to_numpy()``.
        U : optional
            Control inputs; only valid for models created with ``control=True``.
        time_values
            Time values to predict. The fast path uses only
            ``len(time_values)``, as the number of steps.

        Returns
        -------
        The result of ``EDMD.predict``. In the fast path, a NumPy array that
        stacks the predicted states of each step row-wise, one block per step.

        Raises
        ------
        ValueError
            If the fast path is taken without ``time_values``.
        """
        assert self.edmd is not None, "Cannot make prediction until model is fit."

        if U is not None:
            assert self.control, "Please set up a system with control if control parameters are provided."
            return self.edmd.predict(X, U=U, time_values=time_values)

        if self.time_delay is not None:
            return self.edmd.predict(X, time_values=time_values)

        if time_values is None:
            raise ValueError("time_values must be provided to determine the number of prediction steps.")
        if len(time_values) == 0:
            return np.empty((0, self.K_modes.shape[1]))

        if not isinstance(X, np.ndarray):
            X = X.to_numpy()
        network = self.edmd.dict_steps[0][1].nn
        features = network.transform(X)
        states = []
        for _ in range(len(time_values)):
            states.append(features @ self.K_modes)
            # Re-lift the decoded state before advancing, rather than iterating K alone.
            features = network.transform(states[-1]) @ self.K
        return np.vstack(states)


