import warnings

# NOTE(review): this silences every FutureWarning process-wide as a side effect of
# importing this module, which also hides such warnings raised in callers' code.
warnings.simplefilter(action='ignore', category=FutureWarning)

from datafold.pcfold import TSCDataFrame
from datafold.dynfold.base import TSCTransformerMixin
from sklearn.base import BaseEstimator
from scipy.integrate import solve_ivp
from sklearn.pipeline import Pipeline


def gmm_likelihood_diagonal_Sigma(samples, mu, sigma):
    """Evaluate an equally weighted isotropic Gaussian mixture at each sample.

    Every mixture component ``t`` has mean ``mu[t]`` and covariance
    ``sigma**2 * I``.

    Parameters
    ----------
    samples
        Evaluation points, shape ``(n_samples, N)``.
    mu
        Component means, shape ``(T, N)``.
    sigma
        Scalar standard deviation shared by all components and dimensions.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_samples,)``: ``mean_t exp(-0.5 * ||x - mu_t||^2 / sigma^2) / sigma^N``.
        The ``(2*pi)^(N/2)`` normalization constant is omitted.

    Notes
    -----
    The computation is vectorized and allocates an ``(n_samples, T, N)``
    intermediate array.
    """
    samples = np.asarray(samples, dtype=np.float64)
    mu = np.asarray(mu, dtype=np.float64)
    T, N = mu.shape

    sqrt_det_Sigma = sigma ** N
    inv_Sigma = 1 / (sigma ** 2)

    # Pairwise differences between every sample and every component mean.
    diff = samples[:, np.newaxis, :] - mu[np.newaxis, :, :]
    E = np.sum(diff ** 2 * inv_Sigma, axis=-1)  # (n_samples, T) Mahalanobis terms
    L = np.exp(-0.5 * E) / sqrt_det_Sigma
    return np.sum(L, axis=1) / T


class TSCSampledNetwork(BaseEstimator, TSCTransformerMixin):  # pragma: no cover
    """
    This is a simple wrapper for sampled neural networks.

    Parameters
    ----------
    nn
        A sklearn pipeline that represents the neural network (containing ``Dense``, ``Linear``
        layers, etc. from the ``swimnetwork`` Python package). If the last step of the
        pipeline has no weights yet (``nn[-1].weights is None``) the pipeline is fitted in
        :meth:`fit`; otherwise it is used as-is.

    n_features_in
        Number of input features; stored as ``n_features_in_``.

    n_features_out
        Number of output features; stored as ``n_features_out_``.

    feature_names_in_
        Optional input feature names. Defaults to ``["0", "1", ...]``.


    References
    ----------
    See :cite:t:`bolager-2023` for the paper on sampled networks and the
    gitlab repository `swimnetworks <https://gitlab.com/felix.dietrich/swimnetworks>`__

    To install the package run

    .. code-block::

        pip install git+https://gitlab.com/felix.dietrich/swimnetworks

    """

    def __init__(
            self,
            nn: Pipeline,
            n_features_in: int,
            n_features_out: int,
            feature_names_in_=None,
    ):
        self.nn = nn
        self.n_features_in_ = n_features_in
        self.n_features_out_ = n_features_out
        if feature_names_in_ is None:
            self.feature_names_in_ = [str(i) for i in range(n_features_in)]
        else:
            self.feature_names_in_ = feature_names_in_
        # Set in ``fit``: either a callable or an estimator mapping network
        # outputs back to the original state space.
        self.inverse_nn = None

    def __repr__(self):
        return "SWIM NETWORK"

    def get_feature_names_out(self, input_features=None):
        """Return ``["w0", "w1", ...]``, one name per column of the last layer's weights."""
        n_features_out = self.nn[-1].weights.shape[1]
        return [f"w{i}" for i in range(n_features_out)]

    def fit(self, X: TSCDataFrame, **fit_params) -> "TSCSampledNetwork":
        """Fit the network (unless already fitted) and set up the inverse map.

        Parameters
        ----------
        X
            Time series collection used for fitting.
        **fit_params
            ``inverse_nn``: optional estimator with ``fit`` (and ideally ``predict``)
            that maps network outputs back to the original states. If omitted, a
            linear least-squares inverse is fitted on the shift matrix ``Xm``.

        Returns
        -------
        TSCSampledNetwork
            ``self``.
        """
        self._validate_datafold_data(X=X)

        inverse_nn = self._read_fit_params(
            [("inverse_nn", None)], fit_params=fit_params
        )

        # The shift matrices are needed both to fit the network and to build the
        # default linear inverse, so they are computed even if the network is
        # already fitted (previously ``Xm`` was undefined in that case).
        Xm, Xp = X.tsc.shift_matrices(snapshot_orientation="row")

        if self.nn[-1].weights is None:
            self.nn = self.nn.fit(Xm, Xp)

        if inverse_nn is None:
            # Linear map from network features back to the states: solve
            # nn(Xm) @ K_modes ~= Xm in the least-squares sense.
            K_modes = np.linalg.lstsq(self.nn.transform(Xm), Xm, rcond=None)[0]
            self.inverse_nn = lambda x: x @ K_modes
        else:
            self.inverse_nn = inverse_nn

            X_target = self.nn.transform(X)
            # NOTE(review): assumes delay-embedded columns are named "<state>:<suffix>"
            # and the un-delayed "<state>" columns are present in X; the prefix before
            # ":" selects those original-state columns — confirm against callers.
            orig_states = list(dict.fromkeys(col.split(":", 1)[0] for col in X.columns))

            X_np = X.loc[:, orig_states].to_numpy()
            self.inverse_nn.fit(X_target, X_np)

        return self

    def transform(self, X):
        """Apply the network to ``X`` and return a TSCDataFrame with the same index."""
        # self._validate_feature_input(X=X, direction="transform") # skipping for performance reasons

        X_return = self.nn.transform(X)
        X_return = TSCDataFrame.from_same_indices_as(
            X, X_return, except_columns=self.get_feature_names_out()
        )

        return X_return

    def fit_transform(self, X: TSCDataFrame, y=None, **fit_params):
        """Fit on ``X`` and return ``transform(X)``."""
        self.fit(X, **fit_params)
        # ``transform`` already returns a TSCDataFrame with the output feature
        # names, so no second wrapping is needed.
        return self.transform(X)

    def inverse_transform(self, X):
        """Map network outputs back to the original states using ``inverse_nn``."""
        # A user-supplied estimator is fitted in ``fit`` and exposes ``predict``;
        # the default inverse is a plain callable.
        if hasattr(self.inverse_nn, "predict"):
            return self.inverse_nn.predict(X)
        return self.inverse_nn(X)


def solve_VdP(initial_conditions, t_eval, mu=1):
    """Integrate the Van der Pol oscillator for several initial conditions.

    The system is ``x1' = x2`` and ``x2' = mu * (1 - x1**2) * x2 - x1``.

    Parameters
    ----------
    initial_conditions
        Array-like of shape ``(n_ic, 2)``; one initial state per row.
    t_eval
        Time points at which the solution is stored; the integration interval is
        ``[t_eval[0], t_eval[-1]]``.
    mu
        Nonlinearity parameter (default 1, the previously hard-coded value).

    Returns
    -------
    numpy.ndarray
        Shape ``(n_ic, 2, len(t_eval))``.

    Raises
    ------
    RuntimeError
        If the integrator fails for an initial condition (instead of silently
        returning a truncated trajectory).
    """
    initial_conditions = np.asarray(initial_conditions)
    assert initial_conditions.ndim == 2
    assert initial_conditions.shape[1] == 2

    def VdP(t, y):
        """ODE right-hand side."""
        x1, x2 = y
        return np.array([x2, mu * (1 - x1 ** 2) * x2 - x1])

    trajectories = []
    for ic in initial_conditions:
        solution = solve_ivp(
            VdP, t_span=(t_eval[0], t_eval[-1]), y0=ic, t_eval=t_eval, method='DOP853'
        )
        if not solution.success:
            raise RuntimeError(f"Van der Pol integration failed: {solution.message}")
        trajectories.append(solution.y)

    return np.array(trajectories)


def solve_lorenz_system(initial_conditions, t_eval, sigma=10., beta=8. / 3, rho=28.):
    """Integrate the Lorenz system for several initial conditions.

    The system is ``x' = sigma * (y - x)``, ``y' = x * (rho - z) - y`` and
    ``z' = x * y - beta * z``. The defaults are the previously hard-coded values.

    Parameters
    ----------
    initial_conditions
        Array-like of shape ``(n_ic, 3)``; one initial state per row.
    t_eval
        Time points at which the solution is stored; the integration interval is
        ``[t_eval[0], t_eval[-1]]``.
    sigma, beta, rho
        Lorenz system parameters.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_ic, 3, len(t_eval))``.

    Raises
    ------
    RuntimeError
        If the integrator fails for an initial condition.
    """
    initial_conditions = np.asarray(initial_conditions)
    assert initial_conditions.ndim == 2
    assert initial_conditions.shape[1] == 3

    def lorenz(t, state):
        """ODE right-hand side (``state`` is not shadowed by the ``y`` coordinate)."""
        x, y, z = state
        return np.array([
            sigma * (y - x),
            x * (rho - z) - y,
            x * y - beta * z,
        ])

    trajectories = []
    for ic in initial_conditions:
        solution = solve_ivp(
            lorenz, t_span=(t_eval[0], t_eval[-1]), y0=ic, t_eval=t_eval, method='DOP853'
        )
        if not solution.success:
            raise RuntimeError(f"Lorenz integration failed: {solution.message}")
        trajectories.append(solution.y)

    return np.array(trajectories)

def solve_rossler_system(initial_conditions, t_eval, a=0.15, b=0.2, c=10):
    """Integrate the Rossler system for several initial conditions.

    The system is ``x' = -y - z``, ``y' = x + a * y`` and ``z' = b + z * (x - c)``.
    The defaults are the previously hard-coded values (the original code noted
    ``c = 5`` as an alternative).

    Parameters
    ----------
    initial_conditions
        Array-like of shape ``(n_ic, 3)``; one initial state per row.
    t_eval
        Time points at which the solution is stored; the integration interval is
        ``[t_eval[0], t_eval[-1]]``.
    a, b, c
        Rossler system parameters.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_ic, 3, len(t_eval))``.

    Raises
    ------
    RuntimeError
        If the integrator fails for an initial condition.
    """
    initial_conditions = np.asarray(initial_conditions)
    assert initial_conditions.ndim == 2
    assert initial_conditions.shape[1] == 3

    def rossler(t, state):
        """ODE right-hand side."""
        x, y, z = state
        return np.array([
            -y - z,
            x + a * y,
            b + z * (x - c),
        ])

    trajectories = []
    for ic in initial_conditions:
        solution = solve_ivp(
            rossler, t_span=(t_eval[0], t_eval[-1]), y0=ic, t_eval=t_eval, method='DOP853'
        )
        if not solution.success:
            raise RuntimeError(f"Rossler integration failed: {solution.message}")
        trajectories.append(solution.y)

    return np.array(trajectories)


def time_delay_embedding(X, time_delay):
    """Stack each sample with its ``time_delay`` predecessors.

    With ``d = time_delay``, row ``i`` of the result is
    ``[X[i + d], X[i + d - 1], ..., X[i]]`` concatenated along the feature axis.

    Parameters
    ----------
    X
        2D array, or an object with ``to_numpy()`` (e.g. a DataFrame), of shape
        ``(n_samples, n_features)``; rows are consecutive time steps.
    time_delay
        Number of delays (non-negative int). ``0`` returns a copy of ``X``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_samples - time_delay, n_features * (time_delay + 1))`` for
        ``time_delay <= n_samples``.

    Raises
    ------
    ValueError
        If ``time_delay`` is negative.
    """
    if not isinstance(X, np.ndarray):
        X = X.to_numpy() if hasattr(X, "to_numpy") else np.asarray(X)
    if time_delay < 0:
        raise ValueError(f"time_delay must be non-negative, got {time_delay}")

    original_data = X[time_delay:, :]
    if time_delay == 0:
        # There are no delay blocks; np.hstack([]) would raise.
        return original_data.copy()

    # Block ``delay`` holds the samples shifted back by ``delay`` steps; for the
    # largest delay (delay == time_delay) the block starts at row 0.
    delayed_data = np.hstack(
        [
            X[time_delay - delay: -delay, :]
            for delay in range(1, time_delay + 1)
        ]
    )
    return np.hstack([original_data, delayed_data])


def eval_likelihood_gmm_for_diagonal_cov_np(z, mu, std):
    """Evaluate an equally weighted Gaussian mixture with diagonal covariances.

    Parameters
    ----------
    z
        Evaluation points, shape ``(n_points, 1, n_dims)``; a 2D array of shape
        ``(n_points, n_dims)`` is also accepted.
    mu
        Component means, shape ``(T, n_dims)``.
    std
        Per-component standard deviations, shape ``(T, n_dims)``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_points,)``:
        ``mean_t exp(-0.5 * sum(((z - mu_t) / std_t)**2)) / prod(std_t)``.
        The ``(2*pi)^(n_dims/2)`` normalization constant is omitted.
    """
    T = mu.shape[0]
    mu = mu.reshape((1, T, -1))
    z = np.asarray(z)
    if z.ndim == 2:
        z = z[:, np.newaxis, :]

    vec = z - mu  # (n_points, T, n_dims) difference to every component mean
    std = np.asarray(std)
    # Diagonal precision of *each* component. Previously only the first row of
    # ``std`` was used in the exponent while the normalizer used all rows.
    precision = (1 / (std ** 2)).reshape((1, T, -1))

    exponent = np.sum(vec * (precision * vec), axis=-1)
    sqrt_det_of_cov = np.prod(std, axis=1)
    likelihood = np.exp(-0.5 * exponent) / sqrt_det_of_cov
    return likelihood.sum(axis=1) / T


def clean_from_outliers_np(prior, posterior):
    """Drop samples whose prior likelihood is exactly zero.

    Parameters
    ----------
    prior, posterior
        1D arrays of likelihoods evaluated at the same samples.

    Returns
    -------
    tuple
        ``(prior, posterior, outlier_ratio)``. The arrays are filtered only if
        a zero prior exists (otherwise the inputs are returned unchanged), and
        ``outlier_ratio`` is the fraction of zero-prior samples.
    """
    keep = prior != 0
    zero_fraction = (~keep).mean()
    if not all(keep):
        prior, posterior = prior[keep], posterior[keep]
    return prior, posterior, zero_fraction


def calc_kl_mc_np(mu_inf, cov_inf, mu_gen, cov_gen, mc_n=1000):
    """Monte-Carlo estimate of KL(p_inf || p_gen) between two Gaussian mixtures.

    Both mixtures are equally weighted with diagonal covariances: component ``t``
    of ``p_inf`` has mean ``mu_inf[t]`` and variances ``cov_inf[t]`` (likewise
    for ``p_gen``). Samples are drawn from ``p_inf`` using NumPy's global random
    state.

    Parameters
    ----------
    mu_inf, cov_inf
        Means and variances of the sampling mixture, shape ``(T_inf, n_dims)``.
    mu_gen, cov_gen
        Means and variances of the reference mixture, shape ``(T_gen, n_dims)``.
    mc_n
        Number of Monte-Carlo samples (default 1000, the previously fixed value).

    Returns
    -------
    tuple
        ``(kl_mc, outlier_ratio)``, where ``outlier_ratio`` is the fraction of
        samples discarded because their ``p_gen`` likelihood was exactly zero.
    """
    # randint's upper bound is exclusive, so this samples the same range as the
    # deprecated ``random_integers(0, mu_inf.shape[0] - 1)`` used before.
    t = np.random.randint(0, mu_inf.shape[0], size=(mc_n,))

    std_inf = np.sqrt(cov_inf)
    std_gen = np.sqrt(cov_gen)

    # Reparameterized draw from the selected p_inf components, shaped (mc_n, 1, n_dims).
    z_sample = (mu_inf[t] + std_inf[t] * np.random.normal(size=mu_inf[t].shape)).reshape((mc_n, 1, -1))

    prior = eval_likelihood_gmm_for_diagonal_cov_np(z_sample, mu_gen, std_gen)
    posterior = eval_likelihood_gmm_for_diagonal_cov_np(z_sample, mu_inf, std_inf)
    prior, posterior, outlier_ratio = clean_from_outliers_np(prior, posterior)
    kl_mc = np.mean(np.log(posterior) - np.log(prior), axis=0)
    return kl_mc, outlier_ratio


def D_stsp(X_pred, X_true, cov_pred=1, cov_true=1):
    """State-space divergence: a Monte-Carlo KL(pred || true) estimate.

    Each trajectory is turned into an equally weighted Gaussian mixture with
    one component per time step. The component variances are
    ``cov_* * np.ones_like(X_*)``.

    Parameters
    ----------
    X_pred
        Predicted trajectory; used as the sampling ("inf") mixture.
    X_true
        Ground-truth trajectory; used as the reference ("gen") mixture.
    cov_pred, cov_true
        Variance scaling for the respective mixture components.

    Returns
    -------
    The KL estimate from ``calc_kl_mc_np`` (the outlier ratio is discarded).
    """
    variances_pred = cov_pred * np.ones_like(X_pred)
    variances_true = cov_true * np.ones_like(X_true)
    return calc_kl_mc_np(X_pred, variances_pred, X_true, variances_true)[0]


import numpy as np

'''Calculation of D_H (the Hellinger distance between smoothed power spectra), used to evaluate the temporal similarity between true and simulated data.'''

# Std of the Gaussian kernel (in frequency bins) used by fft_smoothed to smooth
# power spectra; choose depending on data.
SMOOTHING_SIGMA = 2  # choose depending on data
# Number of lowest-frequency spectrum bins compared by power_spectrum_error_per_dim.
FREQUENCY_CUTOFF = 500


def convert_to_decibel(x):
    """Return ``20 * log10`` of ``x``, keeping only the first entry.

    The conversion is applied to all of ``x`` before indexing.
    """
    decibels = np.multiply(20, np.log10(x))
    first_entry = decibels[0]
    return first_entry


def ensure_length_is_even(x):
    """Drop the last element if ``len(x)`` is odd and return a 1D array.

    The (possibly trimmed) input is reshaped to shape ``(len,)``.
    """
    trimmed = x[:-1] if len(x) % 2 else x
    return np.reshape(trimmed, (len(trimmed),))


def fft_smoothed(x):
    """
    Compute a Gaussian-smoothed one-sided power spectrum.

    :param x: input signal (trimmed to even length first)
    :return: power spectrum ``|rfft(x, norm='ortho')|**2 * 2 / len(x)``,
        smoothed with a Gaussian of std ``SMOOTHING_SIGMA`` bins
    """
    signal = ensure_length_is_even(x)
    n_samples = len(signal)
    spectrum = np.fft.rfft(signal, norm='ortho')
    power = np.abs(spectrum) ** 2 * 2 / n_samples
    return kernel_smoothen(power, kernel_sigma=SMOOTHING_SIGMA)


def get_average_spectrum(trajectories):
    '''Average the smoothed power spectra of standardized trajectories.

    Each trajectory is z-scored (mean 0, std 1) before its spectrum is computed,
    and the spectra are averaged with ``np.nanmean``. This prepares them for the
    Hellinger distance.
    '''
    per_trajectory = [
        fft_smoothed((traj - traj.mean()) / traj.std())
        for traj in trajectories
    ]
    return np.nanmean(np.array(per_trajectory), axis=0)


def power_spectrum_error_per_dim(x_gen, x_true):
    """Hellinger distance between the smoothed power spectra, per dimension.

    Parameters
    ----------
    x_gen, x_true
        Generated and true trajectories, each 2D with shape ``(T, n_dims)``;
        the shapes must match.

    Returns
    -------
    list
        One Hellinger distance per dimension, computed on the first
        ``FREQUENCY_CUTOFF`` spectrum bins.
    """
    # Add a leading "trajectory" axis: (1, T, n_dims).
    x_gen = np.array([x_gen])
    x_true = np.array([x_true])

    assert x_true.shape[1] == x_gen.shape[1]
    assert x_true.shape[2] == x_gen.shape[2]

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    dim_x = x_gen.shape[2]
    pse_corrs_per_dim = []
    for dim in range(dim_x):
        spectrum_true = get_average_spectrum(x_true[:, :, dim])[:FREQUENCY_CUTOFF]
        spectrum_gen = get_average_spectrum(x_gen[:, :, dim])[:FREQUENCY_CUTOFF]
        # Bhattacharyya coefficient of the (approximately normalized) spectra.
        BC = trapezoid(np.sqrt(spectrum_true * spectrum_gen))
        # Numerical slop can push BC marginally above 1 for near-identical spectra.
        # Clipping keeps the distance at 0 there instead of returning NaN.
        hellinger_dist = np.sqrt(np.maximum(1 - BC, 0.0))

        pse_corrs_per_dim.append(hellinger_dist)

    return pse_corrs_per_dim


def power_spectrum_error(x_gen, x_true):
    """Mean over dimensions of the per-dimension power-spectrum Hellinger distance."""
    per_dim = np.array(power_spectrum_error_per_dim(x_gen, x_true))
    return per_dim.mean(axis=0)


def kernel_smoothen(data, kernel_sigma=1):
    """
    Smoothen data with a normalized Gaussian kernel.

    @param data: 1D numpy array to smooth; it is not modified
    @param kernel_sigma: standard deviation of the gaussian, kernel_size is adapted to that
    @return: smoothed copy of ``data`` (same length and dtype)
    """
    kernel = get_kernel(kernel_sigma)
    pad = len(kernel) // 2
    # 'full' convolution has len(data) + len(kernel) - 1 samples. Keep the
    # window centred on the input. An explicit stop index (instead of
    # ``[pad:-pad]``) also works for pad == 0, where ``[0:-0]`` would be empty.
    data_conv = np.convolve(data, kernel)
    smoothed = data.copy()
    smoothed[:] = data_conv[pad:pad + len(data)]
    return smoothed


def gauss(x, sigma=1):
    """Normal probability density with mean 0 and std ``sigma``, evaluated at ``x``."""
    normalizer = 1 / np.sqrt(2 * np.pi * sigma ** 2)
    standardized = x / sigma
    return normalizer * np.exp(-1 / 2 * standardized ** 2)


def get_kernel(sigma):
    """Return a normalized, symmetric Gaussian kernel as a list of weights.

    The kernel spans offsets ``-h .. h`` with ``h = round(5 * sigma)``, which
    gives ``10 * sigma + 1`` taps for integer ``sigma``, as before. Non-integer
    ``sigma`` is supported and always yields an odd length. For ``sigma <= 0``
    the identity kernel ``[1.0]`` is returned, meaning no smoothing.

    The Gaussian's normalization constant cancels when the weights are scaled to
    sum to 1, so only the exponential is evaluated.
    """
    if sigma <= 0:
        return [1.0]
    half_width = int(round(5 * sigma))
    offsets = np.arange(-half_width, half_width + 1, dtype=np.float64)
    weights = np.exp(-0.5 * (offsets / sigma) ** 2)
    # Normalize once (previously the sum was recomputed for every element).
    weights /= weights.sum()
    return list(weights)
