"""Fit the coefficient matrix W of an NMF given a fixed components matrix H."""
from typing import Sequence, Union

import numpy as np
import scipy.sparse as sps
from sklearn.decomposition import NMF
import tensorflow as tf
from torchnmf.nmf import NMF as TorchNMF

from em.util import sparse_util

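# This import is only used for type annotations. If it ever causes a circular
# import, remove it or guard it with `typing.TYPE_CHECKING` (and quote the
# annotations that use it, since they are evaluated at definition time).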
from . import nmf_common


# Type aliases.
# `tf.SparseTensor` is TensorFlow's top-level name for `tf.sparse.SparseTensor`
# (the same class), kept here under a short module-level name.
SparseTensor = tf.SparseTensor


def _to_dense_np(
    x: Union[np.ndarray, sps.spmatrix, tf.Tensor, SparseTensor],
) -> np.ndarray:
    """Convert ``x`` to a dense NumPy array.

    Supported inputs:
      * ``np.ndarray``: returned as-is (no copy).
      * SciPy sparse matrix/array: densified with ``.toarray()``.
      * ``tf.Tensor``: converted with ``.numpy()``.
      * ``tf.sparse.SparseTensor``: densified with ``tf.sparse.to_dense`` and
        then converted with ``.numpy()``.

    Args:
        x: The array-like to convert.

    Returns:
        A dense ``np.ndarray``.

    Raises:
        TypeError: If ``x`` is not one of the supported types.
    """
    if isinstance(x, np.ndarray):
        return x
    if sps.issparse(x):
        # `.toarray()` (not `.todense()`) so the result is an ndarray rather
        # than an `np.matrix`.
        return x.toarray()
    if isinstance(x, tf.Tensor):
        return x.numpy()
    if isinstance(x, SparseTensor):
        return tf.sparse.to_dense(x).numpy()
    # Format a single message string; passing several positional args to
    # TypeError would render the message as a tuple.
    raise TypeError(f'Invalid type: {type(x)}')


# TODO: Figure out how to do a GPU-accelerated version of this.

def transform(
    nmf_decomp: nmf_common.NmfDecomposition,
    values: Sequence[np.ndarray],
    indices: Sequence[np.ndarray],
    *,
    alpha: float = 0.0,
    beta: float = 1.0,
    l1_ratio: float = 0.0,
    tol: float = 1e-4,
    max_iter: int = 200,
    #
    need_to_reduce: bool = True,
    keep_sparse: bool = True,
):
    """Fit NMF coefficients (W) for new rows while holding ``nmf_decomp.H`` fixed.

    Each input row is a sparse 1-d vector given by the pair
    ``values[i]`` / ``indices[i]``. The rows are stacked into a sparse matrix
    with ``len(values)`` rows and ``len(nmf_decomp.reduce_kept_indices)``
    columns, which is passed to ``transform`` of the scikit-learn NMF built by
    ``to_sklearn_nmf``.

    Args:
        nmf_decomp: Decomposition providing ``H`` plus ``full_dense_size`` and
            ``reduce_kept_indices``.
        values: Per-row nonzero values.
        indices: Per-row column indices matching ``values``.
        alpha, beta, l1_ratio, tol, max_iter: Forwarded to ``to_sklearn_nmf``.
        need_to_reduce: If True, run the rows through
            ``sparse_util.reduce_sparse_np_arrays`` first (presumably mapping
            them from the ``full_dense_size`` index space onto the kept
            columns). If False, the indices must already address the kept
            columns, since the matrix is built with that many columns.
        keep_sparse: If True, pass a SciPy sparse matrix to scikit-learn;
            otherwise pass a dense ``np.ndarray``.

    Returns:
        An array of shape ``(len(values), n_components)`` with the fitted
        coefficients.
    """
    if len(values) == 0:
        # `np.concatenate` rejects an empty sequence and scikit-learn rejects
        # zero-sample inputs, so answer the empty case directly.
        return np.zeros((0, nmf_decomp.H.shape[0]))

    # Only handling sparse 1-d arrays for now.
    if need_to_reduce:
        values, indices = sparse_util.reduce_sparse_np_arrays(
            values,
            indices,
            nmf_decomp.full_dense_size,
            nmf_decomp.reduce_kept_indices,
        )
    nmf = to_sklearn_nmf(
        nmf_decomp,
        alpha=alpha,
        beta=beta,
        l1_ratio=l1_ratio,
        tol=tol,
        max_iter=max_iter,
    )

    coo_values = np.concatenate(values, axis=0)
    # When H is held fixed, scikit-learn >= 1.0 raises a TypeError unless X
    # has the same dtype as H. Match H's dtype when it is one sklearn keeps
    # (float32/float64); otherwise leave the values alone.
    h_dtype = np.asarray(nmf.components_).dtype
    if h_dtype in (np.float32, np.float64):
        coo_values = coo_values.astype(h_dtype, copy=False)
    coo_indices = sparse_util.to_scipy_coo_indices(indices)
    dense_shape = (len(values), len(nmf_decomp.reduce_kept_indices))

    nmf_inputs = sps.coo_matrix((coo_values, coo_indices), shape=dense_shape)

    if not keep_sparse:
        # `.toarray()` returns an ndarray; `.todense()` would return an
        # `np.matrix`, which recent scikit-learn versions reject as input.
        nmf_inputs = nmf_inputs.toarray()

    return nmf.transform(nmf_inputs)


def to_sklearn_nmf(
    nmf_decomp: nmf_common.NmfDecomposition,
    alpha: float = 0.0,
    beta: float = 1.0,
    l1_ratio: float = 0.0,
    tol: float = 1e-4,
    max_iter: int = 200,
) -> NMF:
    """Build a scikit-learn ``NMF`` whose components are fixed to ``nmf_decomp.H``.

    The estimator is never ``fit``. The fitted attributes are set by hand so
    that ``transform`` can solve for W with H held fixed.

    Args:
        nmf_decomp: Decomposition whose ``H`` (2-d, ``n_components`` rows)
            becomes ``components_``.
        alpha: Regularization strength (``alpha`` on old scikit-learn,
            ``alpha_W`` on >= 1.0).
        beta: Passed as ``beta_loss``.
        l1_ratio: Elastic-net mixing parameter.
        tol: Stopping tolerance.
        max_iter: Maximum number of solver iterations.

    Returns:
        An ``NMF`` (multiplicative-update solver) ready for ``transform``.
    """
    import inspect

    # The alpha, beta, and l1_ratio parameters are taken from the pytorch-NMF.
    # I haven't checked if this mapping to sklearn-NMF parameters is correct, or
    # if it even matters for transforming data to coefficients.
    components = np.asarray(nmf_decomp.H)
    if components.dtype not in (np.float32, np.float64):
        # scikit-learn computes in float32/float64 and (>= 1.0) requires the
        # fixed H to share X's dtype, so normalize other dtypes to float64.
        components = components.astype(np.float64)
    n_components, n_features = components.shape

    nmf_kwargs = {
        # Must be passed to the constructor: scikit-learn >= 1.0 derives the
        # expected H shape in `transform` from `n_components`, not from
        # `n_components_`, and would otherwise default it to n_features.
        'n_components': n_components,
        'beta_loss': beta,
        'l1_ratio': l1_ratio,
        'tol': tol,
        'max_iter': max_iter,
        'solver': 'mu',
    }
    if 'alpha_W' in inspect.signature(NMF).parameters:
        # scikit-learn >= 1.0 replaced `alpha` with `alpha_W`/`alpha_H` (and
        # removed `alpha` in 1.2). alpha_H='same' regularizes H like W,
        # mirroring the old default of regularizing both. Note that >= 1.0
        # also rescales the penalty, so values are not exactly equivalent.
        nmf = NMF(alpha_W=alpha, alpha_H='same', **nmf_kwargs)
    else:
        nmf = NMF(alpha=alpha, **nmf_kwargs)

    # Fitted attributes normally set by `fit`: `components_` is the fixed H;
    # `n_components_` is read by old scikit-learn's `transform`; and
    # `n_features_in_` lets newer versions validate X's column count.
    nmf.components_ = components
    nmf.n_components_ = n_components
    nmf.n_features_in_ = n_features

    return nmf


# def _to_sps_coo_H(nmf_decomp: nmf_common.NmfDecomposition) -> sps.coo.coo_matrix:
#     n_components = nmf_decomp.H.shape[0]

#     coo_values = nmf_decomp.H.reshape([-1])
#     coo_indices = sparse_util.to_scipy_coo_indices(n_components * [nmf_decomp.reduce_kept_indices])

#     # Remove zeros in the H matrix.
#     nonzeros = coo_values != 0.0
#     coo_values = coo_values[nonzeros]
#     coo_indices = coo_indices[:, nonzeros]

#     dense_shape = [n_components, nmf_decomp.full_dense_size]

#     return sps.coo_matrix((coo_values, coo_indices), shape=dense_shape)

##########################################################################


def to_torch_nmf(
    nmf_decomp: nmf_common.NmfDecomposition,
    alpha: float = 0.0,
    beta: float = 1.0,
    l1_ratio: float = 0.0,
    tol: float = 1e-4,
    max_iter: int = 200,
) -> TorchNMF:
    """Placeholder for converting ``nmf_decomp`` into a ``torchnmf`` NMF model.

    Not implemented yet. The previous body was ``pass``, which silently
    returned ``None`` despite the ``TorchNMF`` return annotation. It now fails
    loudly instead.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError('to_torch_nmf is not implemented yet.')