from dataclasses import dataclass

import numpy as np
from datafold import EDMD, DMDStandard, DMDControl, TSCDataFrame, TSCTakensEmbedding, TSCPrincipalComponent
from sklearn.base import BaseEstimator
from sklearn.pipeline import Pipeline
from .utils import TSCSampledNetwork


@dataclass
class KIRNN(BaseEstimator):
    """
    Koopman-informed RNN model sampled using SWIM.

    Attributes:
    -----------
    dictionary : sklearn.pipeline.Pipeline
        Sampled hidden layer architecture. Its last step must expose ``layer_width``.
    n_features_in: int
        Number of input features.
    rcond: float = 1e-10
        Cut-off ratio for small singular values of np.linalg.lstsq (also passed to the DMD model).
    n_features_out: int = None
        Number of output features. Defaults to ``dictionary[-1].layer_width``.
    time_delay: int = None
        Number of time delays to embed with datafold.TSCTakensEmbedding.
    pca_components: int = None
        Number of PCA components to use with datafold.TSCPrincipalComponent.
        Only used when ``time_delay`` is set.
    control: bool = False
        Boolean flag that indicates whether inputs for control are passed; False if not.
    feature_names_in: list = None
        Input feature names. Only forwarded to the dictionary when ``time_delay`` is None.
    include_id_state: bool = None
        Boolean flag indicating whether the original states should be appended to the transformed state.
        If None, it is resolved per instance to ``(time_delay is not None) and not control``.
    dict_preserves_id_state: bool = False
        Boolean flag of whether the full state is contained in the transformed state.
    stepwise_transform: bool = True
        Boolean flag indicating if the transformed state should be mapped back to the original state at every step.
        Could lead to slower performance.
    compute_pseudospectrum: bool = False
        Forwarded to the underlying DMD model.
    use_koopman: bool = True
        If False, ``fit`` also computes ``M``. The NumPy prediction path (no time delay,
        no control inputs) then uses ``M`` instead of ``K`` and ``K_modes``.
    K: np.ndarray
        Least-squares map from lifted states to lifted successor states (set by ``fit``).
    K_modes: np.ndarray
        Least-squares map from lifted states back to original states (set by ``fit``).
    M: np.ndarray
        Least-squares map from lifted states directly to successor states.
        Only set by ``fit`` when ``use_koopman`` is False.
    edmd: EDMD
        The datafold EDMD model built in ``__post_init__``.
    """
    dictionary: Pipeline
    n_features_in: int
    rcond: float = 1e-10
    n_features_out: int = None
    time_delay: int = None
    pca_components: int = None
    control: bool = False
    feature_names_in: list = None
    # None means "derive from time_delay/control" in __post_init__. A default expression
    # here would be evaluated once at class creation, with time_delay still None.
    include_id_state: bool = None
    dict_preserves_id_state: bool = False
    stepwise_transform: bool = True
    compute_pseudospectrum: bool = False
    use_koopman: bool = True

    K: np.ndarray = None
    K_modes: np.ndarray = None
    M: np.ndarray = None  # alternative to K*K_modes
    edmd: EDMD = None

    def __post_init__(self):
        """Resolve derived defaults and build the datafold EDMD model."""
        if self.n_features_out is None:
            self.n_features_out = self.dictionary[-1].layer_width

        if self.include_id_state is None:
            self.include_id_state = (self.time_delay is not None) and not self.control

        steps = self._build_dictionary_steps()

        if self.control:
            dmd_model = DMDControl(rcond=self.rcond,
                                   compute_pseudospectrum=self.compute_pseudospectrum)
        else:
            dmd_model = DMDStandard(sys_mode='spectral',
                                    rcond=self.rcond,
                                    diagonalize=True,
                                    compute_pseudospectrum=self.compute_pseudospectrum)

        self.edmd = EDMD(steps,
                         dmd_model=dmd_model,
                         include_id_state=self.include_id_state,
                         dict_preserves_id_state=self.dict_preserves_id_state,
                         stepwise_transform=self.stepwise_transform)

    def _build_dictionary_steps(self):
        """Return the EDMD dictionary steps: optional time-delay (+ optional PCA), then the sampled network."""
        steps = []
        dict_n_features_in = self.n_features_in
        network_kwargs = {}

        if self.time_delay is not None:
            steps.append(("time-delay",
                          TSCTakensEmbedding(delays=self.time_delay)))
            dict_n_features_in = (1 + self.time_delay) * self.n_features_in
            if self.pca_components is not None:
                steps.append(("pca",
                              TSCPrincipalComponent(n_components=self.pca_components)))
                # PCA reduces the delay-embedded features to pca_components columns,
                # so that is the width the sampled network actually receives.
                dict_n_features_in = self.pca_components
        else:
            # Feature names are only forwarded when no embedding step precedes the network.
            network_kwargs["feature_names_in_"] = self.feature_names_in

        steps.append(("kirnn-dict",
                      TSCSampledNetwork(nn=self.dictionary,
                                        n_features_in=dict_n_features_in,
                                        n_features_out=self.dictionary[-1].layer_width,
                                        **network_kwargs)))
        return steps

    def fit(self, X, U=None):
        """Fit the EDMD model and the matrices used by the fast NumPy prediction path.

        Parameters
        ----------
        X : TSCDataFrame
            Training time series.
        U : optional
            Control inputs. Required when ``control`` is True.

        Returns
        -------
        KIRNN
            ``self``, fitted.
        """
        assert isinstance(X, TSCDataFrame), 'Please provide data in TSCDataFrame format.'
        if self.control:
            assert U is not None, "Please provide control values in order to fit a model with control."

        self.edmd.fit(X=X, U=U)

        # Prepare K and K_modes for "fast" predictions without datafold.
        Xm, Xp = X.tsc.shift_matrices(snapshot_orientation="row")
        # The lifted "current" snapshots are shared by every least-squares problem below.
        lifted_m = self.edmd.transform(Xm)
        # With time_delay=None the slice [None:] keeps all rows. Otherwise the first
        # time_delay rows are dropped, presumably to match the rows consumed by the
        # delay embedding (TODO confirm for multi-trajectory data).
        self.K = np.linalg.lstsq(lifted_m, self.edmd.transform(Xp), rcond=self.rcond)[0]
        self.K_modes = np.linalg.lstsq(lifted_m, Xm[self.time_delay:], rcond=self.rcond)[0]
        if not self.use_koopman:
            self.M = np.linalg.lstsq(lifted_m, Xp[self.time_delay:], rcond=self.rcond)[0]
        return self

    def predict(self, X, U=None, time_values=None):
        """Predict trajectories starting from initial states ``X``.

        With control inputs or a time delay this delegates to ``edmd.predict``.
        Otherwise it rolls out with NumPy, producing one block of rows per entry
        in ``time_values``, stacked vertically.

        Raises
        ------
        TypeError
            If the NumPy rollout path is taken and ``time_values`` is None.
        """
        if U is not None:
            assert self.control, "Please set up a KIRNN with control if control parameters are provided."
            return self.edmd.predict(X, U=U, time_values=time_values)

        if self.time_delay is not None:
            return self.edmd.predict(X, time_values=time_values)

        if time_values is None:
            raise TypeError("time_values must be provided: its length sets the number of predicted steps.")

        if not isinstance(X, np.ndarray):
            X = X.to_numpy()

        lift = self.edmd.dict_steps[0][1].nn
        readout = self.K_modes if self.use_koopman else self.M
        lifted = lift.transform(X)
        predictions = []
        # NOTE(review): in the Koopman rollout the first output is psi(x0) @ K_modes, a
        # reconstruction of x0. In the M rollout the first output is psi(x0) @ M, which
        # approximates x1. The two paths are offset by one step; confirm this is intended.
        for _ in range(len(time_values)):
            state = lifted @ readout
            predictions.append(state)
            lifted = lift.transform(state)
            if self.use_koopman:
                # Advance the re-lifted state one step with the Koopman matrix.
                lifted = lifted @ self.K
        return np.vstack(predictions)
