"""Helpers for converting (components of) an NMF's H factor into per-variable Fisher tensors."""
from typing import List, Sequence

import tensorflow as tf

from em.util import flat_pack

# Type aliases used in the function signatures of this module.
SparseTensor = tf.sparse.SparseTensor

##########################################################################


def sparse_vec_to_fishers(h: SparseTensor, variables: Sequence[tf.Tensor]) -> List[tf.Tensor]:
    """Unpack a sparse flat vector into one tensor per model variable.

    The sparse vector ``h`` is densified and then split back into pieces
    whose shapes match ``variables``, using a ``FlatPacker`` built from
    those shapes.

    Args:
        h: Sparse flat vector. Presumably its dense length equals the total
            number of elements across ``variables`` (TODO confirm against
            ``FlatPacker``).
        variables: Tensors/variables whose ``.shape`` values define how the
            flat vector is split.

    Returns:
        Whatever ``FlatPacker.decode_tf`` produces for the dense vector,
        expected to be a list of tensors, one per entry in ``variables``.
    """
    shapes = [var.shape for var in variables]
    # Build the packer before densifying, matching the original call order.
    unpacker = flat_pack.FlatPacker(shapes)
    dense_vec = tf.sparse.to_dense(h)
    return unpacker.decode_tf(dense_vec)


def single_component_to_fishers(nmf, variables, component_index: int) -> List[tf.Tensor]:
    """Turn one NMF component into per-variable tensors.

    Only the requested component is fetched via
    ``nmf.get_single_sparse_h(component_index)``, rather than materialising
    the full sparse H. That component is then unpacked by
    ``sparse_vec_to_fishers``.

    Args:
        nmf: Object exposing ``get_single_sparse_h(index)``. Presumably that
            method returns a ``SparseTensor`` (TODO confirm).
        variables: Tensors/variables whose shapes define how the component
            is split.
        component_index: Index of the component to extract.

    Returns:
        The per-variable tensors produced by ``sparse_vec_to_fishers``.
    """
    return sparse_vec_to_fishers(nmf.get_single_sparse_h(component_index), variables)
