"""
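TensorFlow translation of scikit-learn's FastICA implementation (currently only
the parallel algorithm with the logcosh contrast function is supported).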

"""
import numpy as np
import tensorflow as tf

from em.util import hdf5_util


def _varname(v):
    return v.name.split("/")[-1].split(":")[0]


@tf.function
def _logcosh(x):
    """logcosh contrast function for FastICA.

    Returns:
        gx: ``tanh(x)`` element-wise (same shape as ``x``), the derivative of
            ``log(cosh(x))``.
        g_x: Mean over the last axis of the derivative ``1 - tanh(x)**2``,
            with a trailing size-1 axis kept so it broadcasts row-wise.
    """
    gx = tf.math.tanh(x)
    # d/dx tanh(x) = 1 - tanh(x)**2: reuse gx rather than recomputing tanh.
    g_x = tf.reduce_mean(1.0 - gx ** 2, axis=-1, keepdims=True)
    return gx, g_x


@tf.function
def _sym_decorrelation(w):
    """Symmetric decorrelation
    i.e. W <- (W * W.T) ^{-1/2} * W

    For a full-rank square ``w`` the rows of the result are orthonormal.
    """
    # TODO: This can probably be made more efficient. SVD comes to mind.
    s, u = tf.linalg.eigh(tf.einsum("ij,kj->ik", w, w))
    # u holds the eigenvectors and s the eigenvalues of W * W.T, so
    # u * diag(s^{-1/2}) * u.T == (W * W.T)^{-1/2}.
    # W * W.T is PSD, but rounding can yield eigenvalues that are zero or
    # slightly negative, which would make rsqrt return inf/NaN. Clamp them to
    # the smallest positive normal number of the dtype (as scikit-learn does);
    # rsqrt(tiny) is still finite, so this also avoids division by zero.
    tiny = np.finfo(s.dtype.as_numpy_dtype).tiny
    s = tf.maximum(s, tf.cast(tiny, s.dtype))
    return tf.einsum("ij,kj,kl->il", u * tf.math.rsqrt(s), u, w)


class TfFastICA(object):
    """FastICA (independent component analysis) implemented in TensorFlow.

    Follows scikit-learn's FastICA; only ``algorithm="parallel"`` and
    ``fun="logcosh"`` are supported. Fitted state is kept in non-trainable
    ``tf.Variable``s so that it can be persisted with ``save``/``load``:

        mean        (n_features,)                 per-feature mean removed before whitening
        whitening   (n_components, n_features)    PCA whitening matrix K
        components  (n_components, n_features)    unmixing matrix (W @ K when whitening)
        mixing      (n_features, n_components)    pseudo-inverse of ``components``
        w           (n_components, n_components)  ICA rotation; also the next fit's init
        n_iter      ()                            iterations run by the last ``fit``

    Args:
        n_components: Number of independent components to estimate.
        n_features: Dimensionality of each input sample.
        algorithm: FastICA variant; only "parallel" is implemented.
        whiten: If True, ``fit`` centers and PCA-whitens the data first, and
            ``transform``/``inverse_transform`` subtract/add ``mean``.
        fun: Contrast function; only "logcosh" is implemented.
        max_iter: Maximum number of FastICA iterations.
        tol: Convergence tolerance on ``max_i | |<w_new_i, w_i>| - 1 |``.
        dtype: Floating-point dtype of the model variables and computations.
        print_interval: If not None, print the iteration number every
            ``print_interval`` iterations during ``fit``.
        _dont_create_vars: Internal; skip variable creation (used by ``load``,
            which attaches the stored variables itself).

    Raises:
        ValueError: If ``fun`` or ``algorithm`` is not supported.
    """

    def __init__(
        self,
        *,
        n_components,
        n_features,
        algorithm="parallel",
        whiten=True,
        fun="logcosh",
        max_iter=200,
        tol=1e-4,
        dtype=tf.float32,
        print_interval=None,
        _dont_create_vars=False,
    ):
        self.n_components = n_components
        self.n_features = n_features

        self.algorithm = algorithm
        self.whiten = whiten
        self.max_iter = max_iter
        self.tol = tol
        self.dtype = dtype
        self.print_interval = print_interval

        if fun == "logcosh":
            self.g = _logcosh
        else:
            raise ValueError(f"Invalid value of fun: {fun}")

        if self.algorithm == "parallel":
            self.algorithm_fn = self._ica_par
        else:
            raise ValueError(
                f"Got {self.algorithm} as the algorithm. Only supporting parallel at the moment."
            )

        if _dont_create_vars:
            return

        self.n_iter = tf.Variable(0, dtype=tf.int32, name="n_iter", trainable=False)

        self.mean = tf.Variable(
            tf.zeros([n_features], dtype=self.dtype), name="mean", trainable=False
        )
        self.whitening = tf.Variable(
            tf.zeros([n_components, n_features], dtype=self.dtype),
            name="whitening",
            trainable=False,
        )

        self.components = tf.Variable(
            tf.zeros([n_components, n_features], dtype=self.dtype),
            name="components",
            trainable=False,
        )

        self.mixing = tf.Variable(
            tf.zeros([n_features, n_components], dtype=self.dtype),
            name="mixing",
            trainable=False,
        )

        # Random initial rotation; fit() decorrelates it before iterating.
        self.w = tf.Variable(
            tf.random.normal([n_components, n_components], dtype=self.dtype),
            name="w",
            trainable=False,
        )

    @property
    def unmixing(self):
        """Alias for ``components`` (the unmixing matrix)."""
        return self.components

    def _var(self, k, dtype):
        """Wrap the plain attribute ``k`` in a fresh variable named ``k`` for saving."""
        return tf.Variable(getattr(self, k), name=k, trainable=False, dtype=dtype)

    def _create_variables_for_saving(self):
        """Return the variables written by ``save``.

        Scalar hyperparameters are wrapped in new variables; the fitted state
        variables are returned as-is. ``n_iter``, ``algorithm``, ``fun`` and
        ``print_interval`` are not saved.
        """
        return [
            self._var("n_components", tf.int32),
            self._var("n_features", tf.int32),
            self._var("whiten", tf.bool),
            self._var("max_iter", tf.int32),
            self._var("tol", tf.float32),
            self.mean,
            self.whitening,
            self.components,
            self.mixing,
            self.w,
        ]

    def save(self, filepath):
        """Write hyperparameters and fitted state to an HDF5 file at ``filepath``."""
        variables = self._create_variables_for_saving()
        hdf5_util.save_variables_to_hdf5(variables, filepath)

    @classmethod
    def _load_from_variables(cls, variables):
        """Rebuild a model from the variables produced by ``save``."""
        variables = {_varname(v): v for v in variables}
        ica = cls(
            n_components=int(variables["n_components"].numpy()),
            n_features=int(variables["n_features"].numpy()),
            whiten=bool(variables["whiten"].numpy()),
            max_iter=int(variables["max_iter"].numpy()),
            tol=float(variables["tol"].numpy()),
            # Keep the model's compute dtype consistent with the stored
            # arrays; otherwise a float64 model would reload as float32 and
            # the casts in fit/transform would mismatch its variables.
            dtype=variables["mean"].dtype,
            _dont_create_vars=True,
        )

        for k in ("mean", "whitening", "components", "mixing", "w"):
            setattr(ica, k, variables[k])
        # n_iter is not part of the saved state, but fit() assigns to it, so it
        # must exist for a loaded model to be re-fit.
        ica.n_iter = tf.Variable(0, dtype=tf.int32, name="n_iter", trainable=False)
        return ica

    @classmethod
    def load(cls, filepath):
        """Load a model previously written by ``save``."""
        variables = hdf5_util.load_variables_from_hdf5(filepath)
        return cls._load_from_variables(variables)

    @tf.function
    def _whiten_fn(self, x, n_components):
        """Center ``x`` and project it onto its leading principal components.

        Args:
            x: Data laid out as (n_features, n_samples).
            n_components: Number of principal components to keep.

        Returns:
            ``(x_centered, x1, x_mean, K)`` where ``K`` is the
            (n_components, n_features) whitening matrix, ``x_mean`` the
            (n_features,) mean, and ``x1 = K @ x_centered`` scaled by
            sqrt(n_samples) so each row has unit mean square.
        """
        # Centering the columns (ie the variables)
        x_mean = tf.reduce_mean(x, axis=-1)
        x -= x_mean[:, None]

        # Whitening and preprocessing by PCA
        d, u, _ = tf.linalg.svd(x, full_matrices=False)

        # TODO: We can run into trouble when some components have
        # a very small or zero variance.
        K = tf.transpose(u / d)[:n_components]  # see (6.33) p.140
        x1 = tf.matmul(K, x)
        # see (13.6) p.267 Here x1 is white and data
        # in x has been projected onto a subspace by PCA
        x1 *= tf.sqrt(tf.cast(tf.shape(x)[1], self.dtype))

        return x, x1, x_mean, K

    @tf.function
    def _ica_par(self, x):
        """Parallel FastICA main loop (all components are updated together).

        Starts from the current ``self.w`` and updates it in place.

        Args:
            x: Signals as rows, samples as columns, shape (n_rows, n_samples).

        Returns:
            ``(w, n_iter)``: the ``self.w`` variable holding the final
            rotation, and the number of iterations actually performed (0 when
            ``max_iter`` is 0).
        """
        w = self.w
        w.assign(_sym_decorrelation(w))
        p = tf.cast(tf.shape(x)[1], self.dtype)  # number of samples
        # Explicit counter: do not rely on the loop iterate being defined after
        # a (possibly zero-trip) AutoGraph loop.
        n_iter = tf.constant(0, dtype=tf.int32)
        for ii in tf.range(self.max_iter):
            if self.print_interval is not None:
                if tf.equal(tf.math.floormod(ii, self.print_interval), 0):
                    tf.print("FastICA step: ", ii)
            gwtx, g_wtx = self.g(tf.matmul(w, x))

            w1 = _sym_decorrelation(tf.einsum("ij,kj->ik", gwtx, x) / p - g_wtx * w)

            # Converged when every updated row is (anti-)parallel to its
            # previous value: |<w1_i, w_i>| ~= 1. Only the row-wise dot
            # products are needed, not the full w1 @ w.T.
            lim = tf.reduce_max(tf.abs(tf.abs(tf.einsum("ij,ij->i", w1, w)) - 1))
            w.assign(w1)
            n_iter = ii + 1
            if lim < self.tol:
                break

        return w, n_iter

    @tf.function
    def fit(self, x):
        """Estimate the unmixing matrix from ``x`` of shape (n_samples, n_features).

        Updates ``components``, ``mixing``, ``w`` and ``n_iter`` (plus ``mean``
        and ``whitening`` when ``whiten`` is True) in place; returns None.
        """
        n_components = self.n_components
        # Compute in the model's precision (e.g. accept float64 NumPy input
        # for a float32 model instead of failing on a dtype mismatch).
        x = tf.cast(x, self.dtype)

        tf.debugging.assert_less_equal(
            n_components, tf.minimum(tf.shape(x)[0], tf.shape(x)[1])
        )

        # NOTE: The `n_samples, n_features = X.shape` in scikit-learn is wrong.
        # It should be transposed.
        x = tf.transpose(x)

        if self.whiten:
            x, x1, x_mean, K = self._whiten_fn(x, n_components)
        else:
            # NOTE(review): without whitening, w (n_components x n_components)
            # is assigned to components (n_components x n_features) below, so
            # this path presumably requires n_components == n_features.
            x1 = x

        w, n_iter = self.algorithm_fn(x1)

        self.n_iter.assign(n_iter)

        if self.whiten:
            self.components.assign(tf.matmul(w, K))
            self.mean.assign(x_mean)
            self.whitening.assign(K)
        else:
            self.components.assign(w)

        self.mixing.assign(tf.linalg.pinv(self.components))

    @tf.function
    def transform(self, x, no_mean=False):
        """Recover the sources from x (apply the unmixing matrix).

        Args:
            x: Data with trailing axis of size n_features.
            no_mean: If True, do not subtract ``mean`` even when whitening.

        Returns:
            Sources with trailing axis of size n_components.
        """
        x = tf.cast(x, self.dtype)
        if not no_mean and self.whiten:
            x -= self.mean
        return tf.einsum("...j,kj->...k", x, self.components)

    @tf.function
    def inverse_transform(self, x, no_mean=False):
        """Transform the sources back to the mixed data (apply mixing matrix).

        Args:
            x: Sources with trailing axis of size n_components.
            no_mean: If True, do not add ``mean`` back even when whitening.

        Returns:
            Data with trailing axis of size n_features.
        """
        x = tf.cast(x, self.dtype)
        # Explicit output spec: x @ mixing.T over the trailing axis.
        x = tf.einsum("...j,kj->...k", x, self.mixing)
        if not no_mean and self.whiten:
            x += self.mean
        return x
