"""Tensorflow ports of some scikit-learn NMF functionality.

The main advantage of using tensorflow is that it is trivial to
run on a GPU, which can lead to significant speed ups.

My strategy for porting is to get the basics and whatever functionality
that I'm going to use first and then maybe add other things.

Links to scikit-learn code and documentation:
    https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.NMF.html#sklearn.decomposition.NMF.transform
    https://github.com/scikit-learn/scikit-learn/blob/80598905e/sklearn/decomposition/_nmf.py#L1158
"""
import numbers
import time

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from typing import Optional, Union

# Some typedefs.
# Dense inputs accepted by this module: either a TensorFlow tensor or a
# NumPy array.
DenseTensorLike = Union[tf.Tensor, np.ndarray]


def _to_dense_float32_tensor(x: DenseTensorLike) -> tf.Tensor:
    """Return `x` cast to a `tf.float32` tensor."""
    target_dtype = tf.float32
    return tf.cast(x, dtype=target_dtype)


def is_non_negative(x: tf.Tensor) -> bool:
    """Tell whether the smallest entry of `x` is greater than or equal to zero.

    The result is the outcome of comparing `tf.reduce_min(x)` against 0, so
    it is truthy exactly when `x` holds no negative value.
    """
    smallest = tf.reduce_min(x)
    return smallest >= 0


###############################################################################
###############################################################################

_ALLOWED_BETA_LOSS = {"frobenius": 2, "kullback-leibler": 1, "itakura-saito": 0}


def _beta_loss_to_float(beta_loss):
    """Convert string beta_loss to float."""
    if isinstance(beta_loss, str) and beta_loss in _ALLOWED_BETA_LOSS:
        beta_loss = _ALLOWED_BETA_LOSS[beta_loss]

    if not isinstance(beta_loss, numbers.Number):
        raise ValueError(
            "Invalid beta_loss parameter: got %r instead of one of %r, or a float."
            % (beta_loss, _ALLOWED_BETA_LOSS.keys())
        )
    return beta_loss


def _compute_regularization(alpha_W, alpha_H, l1_ratio):
    """Compute L1 and L2 regularization coefficients for W and H."""
    alpha_H = alpha_W if alpha_H == "same" else alpha_H
    l1_reg_W = alpha_W * l1_ratio
    l1_reg_H = alpha_H * l1_ratio
    l2_reg_W = alpha_W * (1.0 - l1_ratio)
    l2_reg_H = alpha_H * (1.0 - l1_ratio)
    return l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H


###############################################################################

@tf.function
def _initialize_nmf_random(X, n_components):
    """Draw random non-negative starting factors scaled to the data.

    Both factors are filled with absolute values of standard-normal samples
    multiplied by `sqrt(mean(X) / n_components)`.

    Args:
        X: 2-D floating-point tensor to be factorized.
        n_components: number of components (columns of W, rows of H).
    Returns:
        `(W, H)` with shapes `(X.shape[0], n_components)` and
        `(n_components, X.shape[1])`, both of dtype `X.dtype`.
    """
    n_samples, n_features = tf.shape(X)[0], tf.shape(X)[1]
    avg = tf.sqrt(tf.reduce_mean(X) / tf.cast(n_components, X.dtype))
    # Sample in X's dtype: tf.random.normal defaults to float32, which would
    # fail to multiply with `avg` for any other floating-point input.
    W = avg * tf.abs(
        tf.random.normal([n_samples, n_components], dtype=X.dtype)
    )
    H = avg * tf.abs(
        tf.random.normal([n_components, n_features], dtype=X.dtype)
    )
    return W, H


def _update_first_axis(x, index, update):
    """Return `x` with row `index` replaced, i.e. `x[index, :] = update`."""
    # One index addressing one whole slice along the first axis.
    row_indices = [[index]]
    row_updates = [update]
    return tf.tensor_scatter_nd_update(x, row_indices, row_updates)


def _update_second_axis(x, index, update):
    """Return `x` with column `index` replaced, i.e. `x[:, index] = update`."""
    n_rows = tf.shape(x)[0]

    # Build the (row, column) coordinate of every entry in the column.
    rows = tf.range(n_rows, dtype=tf.int32)
    cols = tf.cast(index, tf.int32) * tf.ones(n_rows, dtype=tf.int32)
    coordinates = tf.stack([rows, cols], axis=-1)

    return tf.tensor_scatter_nd_update(x, coordinates, update)


def _replace_mask_with_constant(x, mask, value):
    """Equivalent to `x[mask] = value`, returning a new tensor.

    Args:
        x: tensor to copy.
        mask: boolean tensor with the same shape as `x`.
        value: scalar written wherever `mask` is true.
    Returns:
        A tensor equal to `x` except at the masked positions.
    """
    return tf.where(mask, tf.cast(value, x.dtype), x)


@tf.function
def _initialize_nmf_nndsvda(X, n_components, eps):
    """NNDSVDa initialization: NNDSVD with zeros filled by the mean of X.

    Args:
        X: 2-D non-negative floating-point tensor to be factorized.
        n_components: number of components; it must not exceed
            `min(X.shape)` because one singular triplet is used per
            component.
        eps: entries of W and H below this threshold are treated as zero
            (and therefore replaced by `mean(X)`).
    Returns:
        `(W, H)` with shapes `(X.shape[0], n_components)` and
        `(n_components, X.shape[1])`.
    """
    n_samples, n_features = tf.shape(X)[0], tf.shape(X)[1]

    # TODO: Randomized SVD can be faster when we only need a few
    # components. That's what scikit-learn does.
    #
    # Note that the V here is the transpose of the V matrix
    # returned by numpy's SVD.
    S, U, V = tf.linalg.svd(X, compute_uv=True)
    S = S[:n_components]
    U = U[:, :n_components]
    V = V[:, :n_components]

    # tf.zeros (not tf.zeros_like) so the shape list is read as a shape.
    W = tf.zeros([n_samples, n_components], dtype=X.dtype)
    H = tf.zeros([n_components, n_features], dtype=X.dtype)

    # TODO: See how fast the below code is and maybe make faster. IDK if the
    # tensor_scatter_nd_update will make a copy per step in this case. I can
    # also probably do the loop in a single step.

    # The leading singular triplet is non-negative
    # so it can be used as is for initialization.
    W = _update_second_axis(W, 0, tf.sqrt(S[0]) * tf.abs(U[:, 0]))
    H = _update_first_axis(H, 0, tf.sqrt(S[0]) * tf.abs(V[:, 0]))

    for j in tf.range(1, n_components):
        x, y = U[:, j], V[:, j]

        # extract positive and negative parts of column vectors
        x_p, y_p = tf.maximum(x, 0), tf.maximum(y, 0)
        x_n, y_n = tf.abs(tf.minimum(x, 0)), tf.abs(tf.minimum(y, 0))

        # and their norms
        x_p_nrm, y_p_nrm = tf.norm(x_p), tf.norm(y_p)
        x_n_nrm, y_n_nrm = tf.norm(x_n), tf.norm(y_n)

        m_p, m_n = x_p_nrm * y_p_nrm, x_n_nrm * y_n_nrm

        # choose update: keep whichever sign pattern carries more weight
        if m_p > m_n:
            u = x_p / x_p_nrm
            v = y_p / y_p_nrm
            sigma = m_p
        else:
            u = x_n / x_n_nrm
            v = y_n / y_n_nrm
            sigma = m_n

        lbd = tf.sqrt(S[j] * sigma)
        W = _update_second_axis(W, j, lbd * u)
        H = _update_first_axis(H, j, lbd * v)

    avg = tf.reduce_mean(X)

    # Clamp tiny entries to exactly zero, then fill every zero with the
    # mean of X (the "a" variant of NNDSVD).
    W = _replace_mask_with_constant(W, W < eps, 0.0)
    W = _replace_mask_with_constant(W, W == 0.0, avg)

    H = _replace_mask_with_constant(H, H < eps, 0.0)
    H = _replace_mask_with_constant(H, H == 0.0, avg)

    return W, H


def _initialize_nmf(X, n_components, eps=1e-6):
    """Compute initial factors W and H for the factorization of X.

    Follows scikit-learn's rule for `init=None`: NNDSVDa is used when
    `n_components <= min(n_samples, n_features)` and random initialization
    otherwise (the SVD cannot supply more triplets than that).

    Args:
        X: 2-D non-negative floating-point data of shape
            `(n_samples, n_features)`.
        n_components: number of components.
        eps: threshold below which NNDSVDa entries are treated as zero.
    Returns:
        `(W, H)` with shapes `(n_samples, n_components)` and
        `(n_components, n_features)`.
    Raises:
        ValueError: if X contains a negative value.
    """
    X = tf.convert_to_tensor(X)
    if not is_non_negative(X):
        raise ValueError('X must be non-negative.')

    # Pass tensors so tf.function does not retrace for every new value.
    n_components = tf.cast(n_components, tf.int32)
    eps = tf.cast(eps, X.dtype)

    n_samples, n_features = X.shape
    if n_components <= min(n_samples, n_features):
        return _initialize_nmf_nndsvda(X, n_components, eps)
    return _initialize_nmf_random(X, n_components)


###############################################################################
###############################################################################

@tf.function
def _update_cdnmf_fast(W, HHt, XHt, permutation):
    """Run one coordinate-descent sweep over the columns of W, in place.

    For a fixed component `t` the update of `W[i, t]` reads only row `i` of
    W, so all rows are independent and the whole column is updated in one
    batched step instead of one scalar update per sample.

    Args:
        W: `tf.Variable` of shape `(n_samples, n_components)`; modified in
            place through `scatter_nd_update`.
        HHt: `(n_components, n_components)` Gram matrix of the fixed factor
            (including any L2 term on its diagonal).
        XHt: `(n_samples, n_components)` product of the data with the fixed
            factor (with any L1 term already subtracted).
        permutation: order in which the components are visited.
    Returns:
        Scalar tensor: the sum of absolute projected gradients accumulated
        over the sweep, used by the caller as convergence measure.
    """
    n_samples = tf.shape(W)[0]
    n_components = tf.shape(W)[1]
    rows = tf.range(n_samples)

    violation = tf.constant(0.0, dtype=W.dtype)
    for s in tf.range(n_components):
        t = permutation[s]

        w_t = W[:, t]
        # grad[i] = sum_r HHt[t, r] * W[i, r] - XHt[i, t]
        grad = tf.reduce_sum(W * HHt[t], axis=1) - XHt[:, t]

        # Projected gradient: at the boundary (W[i, t] == 0) only a
        # negative gradient counts as a violation.
        pg = tf.where(tf.equal(w_t, 0.0), tf.minimum(grad, 0.0), grad)
        violation += tf.reduce_sum(tf.abs(pg))

        # Hessian
        hess = HHt[t, t]

        if hess != 0:
            # Newton step on column t, clipped to stay non-negative.
            cols = tf.zeros_like(rows) + t
            indices = tf.stack([rows, cols], axis=1)
            W.scatter_nd_update(
                indices, tf.maximum(w_t - grad / hess, 0.0)
            )

    return violation


@tf.function
def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle):
    """Update W by one coordinate-descent sweep with Ht held fixed.

    Args:
        X: data matrix.
        W: factor being updated; handed to `_update_cdnmf_fast`.
        Ht: the fixed factor; `X @ Ht` and `Ht^T @ Ht` are formed from it.
        l1_reg: L1 regularization strength for W.
        l2_reg: L2 regularization strength for W.
        shuffle: when truthy, visit the components in random order.
    Returns:
        The violation value returned by `_update_cdnmf_fast`.
    """
    n_components = tf.shape(Ht)[1]

    # L2 regularization corresponds to increase of the diagonal of HHt
    gram = tf.linalg.matmul(Ht, Ht, transpose_a=True)
    HHt = gram + l2_reg * tf.eye(n_components, dtype=gram.dtype)

    # L1 regularization corresponds to decrease of each element of XHt
    XHt = tf.linalg.matmul(X, Ht) - l1_reg

    permutation = tf.range(n_components)
    if shuffle:
        permutation = tf.random.shuffle(permutation)

    return _update_cdnmf_fast(W, HHt, XHt, permutation)


def _fit_coordinate_descent(
    X,
    W,
    Ht,
    tol,
    max_iter,
    l1_reg_W=0,
    l1_reg_H=0,
    l2_reg_W=0,
    l2_reg_H=0,
    shuffle=False,
):
    """Alternate coordinate-descent updates of W and Ht until convergence.

    W and Ht are updated in place by `_update_coordinate_descent`. Iteration
    stops after `max_iter` rounds, or earlier once the violation of a round
    divided by the violation of the first round is at most `tol`.

    Args:
        X: data matrix of shape `(n_samples, n_features)`.
        W: factor of shape `(n_samples, n_components)`, updated in place.
        Ht: transposed factor of shape `(n_features, n_components)`,
            updated in place.
        tol: relative tolerance of the stopping condition.
        max_iter: maximum number of rounds.
        l1_reg_W, l1_reg_H: L1 regularization strengths for W and H.
        l2_reg_W, l2_reg_H: L2 regularization strengths for W and H.
        shuffle: when truthy, visit components in random order.
    Returns:
        The number of rounds actually run (0 when `max_iter <= 0`).
    """
    # Cast to tensors to prevent tf.function from retracing with changes to
    # values of these.
    tol = tf.cast(tol, tf.float32)
    l1_reg_W = tf.cast(l1_reg_W, tf.float32)
    l1_reg_H = tf.cast(l1_reg_H, tf.float32)
    l2_reg_W = tf.cast(l2_reg_W, tf.float32)
    l2_reg_H = tf.cast(l2_reg_H, tf.float32)

    Xt = tf.transpose(X)

    n_iter = 0
    violation_init = None

    # TODO: Make this tqdm nicer, maybe include loss information, more logging.
    # The context manager closes the bar even when the loop exits early.
    with tqdm(range(max_iter)) as progress:
        for step in progress:
            n_iter = step + 1

            # Update W
            start = time.time()
            violation = _update_coordinate_descent(
                X, W, Ht, l1_reg_W, l2_reg_W, shuffle
            )
            w_seconds = time.time() - start

            # Update H
            start = time.time()
            violation += _update_coordinate_descent(
                Xt, Ht, W, l1_reg_H, l2_reg_H, shuffle
            )
            h_seconds = time.time() - start

            # Report per-update wall time on the bar rather than printing
            # bare numbers that break up the progress display.
            progress.set_postfix(w_update_s=w_seconds, h_update_s=h_seconds)

            if step == 0:
                violation_init = violation

            if violation_init == 0:
                break

            if violation / violation_init <= tol:
                break

    return n_iter


###############################################################################
###############################################################################

class TfNMF:
    """TF Implementation of scikit-learn's NMF.

    Only the coordinate-descent solver with the Frobenius loss and the
    default initialization (`init=None`) is supported so far.

    NOTE: Currently only supports dense tensors. I might create
    a separate class or subclass of this to support sparse tensors.
    """
    def __init__(
        self,
        n_components: int,
        *,
        init=None,
        solver="cd",
        beta_loss="frobenius",
        tol=1e-4,
        max_iter=200,
        # random_state=None,
        alpha_W=0.0,
        alpha_H="same",
        l1_ratio=0.0,
        # verbose=0,
        shuffle=False,
    ):
        """Store the hyper-parameters; no computation happens here.

        Args:
            n_components: number of components of the factorization.
            init: initialization method; only `None` is supported.
            solver: numerical solver; only "cd" is supported.
            beta_loss: divergence to minimize; only "frobenius" is supported.
            tol: relative tolerance of the stopping condition.
            max_iter: maximum number of solver iterations.
            alpha_W: regularization strength for W.
            alpha_H: regularization strength for H, or "same" to reuse
                `alpha_W`.
            l1_ratio: share of the regularization assigned to the L1 term.
            shuffle: whether the solver visits components in random order.
        Raises:
            NotImplementedError: if `init`, `solver` or `beta_loss` is set
                to an unsupported value.
        """
        # TODO: Support more values for these.
        # Explicit raises instead of asserts so the checks survive `-O`.
        if init is not None:
            raise NotImplementedError(
                "Only init=None is supported, got %r." % (init,)
            )
        if solver != 'cd':
            raise NotImplementedError(
                "Only solver='cd' is supported, got %r." % (solver,)
            )
        if beta_loss != 'frobenius':
            raise NotImplementedError(
                "Only beta_loss='frobenius' is supported, got %r."
                % (beta_loss,)
            )

        self.n_components = n_components

        self.init = init
        self.solver = solver
        self.beta_loss = beta_loss
        self.tol = tol
        self.max_iter = max_iter
        # self.random_state = random_state
        self.alpha_W = alpha_W
        self.alpha_H = alpha_H
        self.l1_ratio = l1_ratio
        # self.verbose = verbose
        self.shuffle = shuffle

        # Factors; created as tf.Variables by `_initialize_nmf`.
        self.W = None
        self.Ht = None

    def _check_params(self, X: tf.Tensor):
        """Validate the hyper-parameters and derive the private settings.

        Sets `_n_components`, `_beta_loss` and the four regularization
        coefficients `_l1_reg_W`, `_l1_reg_H`, `_l2_reg_W`, `_l2_reg_H`.

        Raises:
            ValueError: if `beta_loss` is neither a known name nor a number.
            NotImplementedError: if the solver/loss combination is not
                coordinate descent with the Frobenius loss.
        """
        # TODO: Missing a lot of stuff from the original. Original at
        # https://github.com/scikit-learn/scikit-learn/blob/80598905e/sklearn/decomposition/_nmf.py#L1421

        # n_components
        self._n_components = self.n_components

        # beta_loss
        self._beta_loss = _beta_loss_to_float(self.beta_loss)

        # Attributes may have been changed after construction, so re-check.
        if (
            self.solver != 'cd'
            or self._beta_loss != _ALLOWED_BETA_LOSS["frobenius"]
        ):
            raise NotImplementedError(
                "Only solver='cd' with beta_loss='frobenius' is supported, "
                "got solver=%r and beta_loss=%r."
                % (self.solver, self.beta_loss)
            )

        (
            self._l1_reg_W,
            self._l1_reg_H,
            self._l2_reg_W,
            self._l2_reg_H,
        ) = _compute_regularization(
            self.alpha_W, self.alpha_H, self.l1_ratio
        )

    def _initialize_nmf(self, X):
        """Create the factor variables `self.W` and `self.Ht` from X.

        Returns:
            `(self.W, self.Ht)`, two non-trainable `tf.Variable`s; `Ht` holds
            the transpose of the initial H.
        Raises:
            ValueError: if either factor has already been created.
        """
        if self.W is not None or self.Ht is not None:
            raise ValueError('Already initialized.')
        # TODO: Have _initialize_nmf just operate on Ht
        W, H = _initialize_nmf(
            X, self._n_components
        )
        self.W = tf.Variable(W, name='W', trainable=False)
        self.Ht = tf.Variable(tf.transpose(H), name='H', trainable=False)
        return self.W, self.Ht

    def _scale_regularization(self, X):
        """Scale the regularization coefficients by the size of X.

        Coefficients for W are multiplied by `n_features`, those for H by
        `n_samples`.

        Returns:
            The tuple `(l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H)`.
        """
        n_samples, n_features = X.shape
        l1_reg_W = n_features * self._l1_reg_W
        l1_reg_H = n_samples * self._l1_reg_H
        l2_reg_W = n_features * self._l2_reg_W
        l2_reg_H = n_samples * self._l2_reg_H
        return l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H

    def _fit_transform(
        self,
        X: DenseTensorLike
    ):
        """Learn a NMF model for the data X.

        The learned factors are not returned; they are stored on the
        instance as `self.W` and `self.Ht`, which are the variables handed
        to the solver.

        Args:
            X: shape (n_samples, n_features), the data matrix to be decomposed.
        Returns:
            The iteration count reported by `_fit_coordinate_descent`.
        Raises:
            ValueError: if X contains a negative value, if the model has
                already been initialized, or if X contains zeros while
                `beta_loss <= 0`.
            NotImplementedError: if an unsupported solver or loss is set.
        """
        # assert update_H

        X = _to_dense_float32_tensor(X)

        if not is_non_negative(X):
            raise ValueError('X must be non-negative.')

        # Also rejects anything but solver='cd' with the Frobenius loss.
        self._check_params(X)

        if tf.reduce_min(X) == 0 and self._beta_loss <= 0:
            raise ValueError(
                "When beta_loss <= 0 and X contains zeros, "
                "the solver may diverge. Please add small values "
                "to X, or use a positive beta_loss."
            )

        # initialize W and H
        W, Ht = self._initialize_nmf(X)

        # scale the regularization terms
        l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = self._scale_regularization(X)

        n_iter = _fit_coordinate_descent(
            X,
            W,
            Ht,
            self.tol,
            self.max_iter,
            l1_reg_W,
            l1_reg_H,
            l2_reg_W,
            l2_reg_H,
            # update_H=update_H,
            # verbose=self.verbose,
            shuffle=self.shuffle,
            # random_state=self.random_state,
        )

        return n_iter
