"""Just a quick wrapper to be able to run multiple scikit-learn NMFs at once.

Note that this naive, sequential approach is likely slower than a method
that factorizes all matrices as one batched computation, especially on GPU
implementations.
"""
import functools
from typing import Tuple, Union

import numpy as np
from sklearn.decomposition import NMF
import tensorflow as tf


def _perform_single_nmf(matrix: np.ndarray, nmf_kwargs: dict) -> Tuple[np.ndarray, np.ndarray]:
    """Factorize one matrix with scikit-learn's NMF.

    Args:
        matrix: 2-D non-negative array to approximate as ``W @ H``.
        nmf_kwargs: Keyword arguments forwarded unchanged to
            ``sklearn.decomposition.NMF``.

    Returns:
        ``(W, H)``: ``W`` is the output of ``fit_transform`` and ``H`` is the
        fitted model's ``components_``.
    """
    model = NMF(**nmf_kwargs)
    W = model.fit_transform(matrix)
    # No progress printing here: this helper runs once per matrix in a batch,
    # so writing to stdout on every call would flood the caller's output.
    return W, model.components_


def perform_nmfs(matrices: np.ndarray, **nmf_kwargs) -> Tuple[np.ndarray, np.ndarray]:
    """Run an independent scikit-learn NMF on every matrix in a batch.

    Args:
        matrices: Batch of non-negative matrices with shape
            ``(batch, rows, cols)``. A ``tf.Tensor`` is converted via
            ``.numpy()`` (so it must be an eager tensor); any other
            array-like is converted via ``np.asarray``. Because the batch is a
            single 3-D array, every matrix must have the same shape.
        **nmf_kwargs: Forwarded unchanged to ``sklearn.decomposition.NMF``
            for every matrix in the batch.

    Returns:
        ``(Ws, Hs)``: the per-matrix ``W`` and ``H`` factors, each stacked
        along a new leading batch axis.

    Raises:
        ValueError: If ``matrices`` is not 3-D or the batch is empty.
    """
    if isinstance(matrices, tf.Tensor):
        matrices = matrices.numpy()
    # Also accept plain nested lists / other array-likes.
    matrices = np.asarray(matrices)

    # Needs a batch dimension followed by two matrix dimensions. Raised
    # explicitly (not asserted) so the check survives `python -O`.
    if matrices.ndim != 3:
        raise ValueError(
            f"Expected a 3-D batch of matrices, got shape {matrices.shape}."
        )
    if matrices.shape[0] == 0:
        # Otherwise unpacking zip(*[]) below fails with an opaque message.
        raise ValueError("Expected at least one matrix in the batch.")

    # A simple sequential loop was found to be much faster than Python
    # multiprocessing for this workload.
    outputs = [_perform_single_nmf(matrix, nmf_kwargs) for matrix in matrices]
    Ws, Hs = zip(*outputs)
    return np.stack(Ws, axis=0), np.stack(Hs, axis=0)
