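R"""Converts a dense NMF decomposition into one involving one or more sparse factors.

Currently only the H factor is sparsified: entries of H below --H_threshold are
dropped and the remaining entries are stored in CSR form. W is kept dense.

Example usage (replace the --output_path placeholder with a real destination):

CUDA_VISIBLE_DEVICES= python scripts1/sparse/sparsify_nmf.py \
    --nmf_path=/fruitbasket/users/m/project_data/extract_merge1/ll1/per_example_fishers/nmf_decomp.c1024_2kIters_65536pe.feather_berts_0.hans_lone.all_vars.5k.262144.h5 \
    --output_path=/path/to/sparse_nmf_decomp.h5 \
    --H_threshold=1e-8

"""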
import os

from absl import app
from absl import flags

import numpy as np

from em.tools.nmf import nmf_common
from em.util.color_util import cu

FLAGS = flags.FLAGS


# Command-line flags. All three are required by main(); they default to None
# so that an omitted flag can be detected.
flags.DEFINE_string(
    "nmf_path",
    None,
    "Path to the saved dense NMF decomposition to sparsify. "
    "A leading '~' is expanded to the user's home directory.",
)

flags.DEFINE_string(
    "output_path",
    None,
    "Path where the resulting sparse NMF decomposition is saved. "
    "A leading '~' is expanded to the user's home directory.",
)

flags.DEFINE_float(
    "H_threshold",
    None,
    "Entries of the (normalized) H factor that are >= this value are kept in "
    "the sparse representation; smaller entries are dropped.",
)
# W is currently kept dense, so there is no W threshold yet.
# flags.DEFINE_float("W_threshold", None, "")


def to_csr(A: np.ndarray, threshold: float):
    """Convert a dense 2-D array to CSR-style arrays, keeping entries >= threshold.

    Entries strictly below ``threshold`` (and NaNs, which fail the comparison)
    are treated as zeros and dropped.

    Args:
        A: Dense 2-D array of shape ``(n_rows, n_cols)``.
        threshold: Minimum value (inclusive) an entry must have to be kept.

    Returns:
        A tuple ``(values, row_ptr, col_inds)`` where
        * ``values`` is a 1-D array (same dtype as ``A``) holding the kept
          entries in row-major order,
        * ``row_ptr`` is an int64 array of length ``n_rows + 1``; the kept
          entries of row ``i`` live at ``values[row_ptr[i]:row_ptr[i + 1]]``,
        * ``col_inds`` is an int32 array, aligned with ``values``, holding the
          column index of each kept entry.

    Raises:
        ValueError: if ``A`` is not 2-D, or if it has too many columns for
            the column indices to be represented as int32.
    """
    # Unpacking into exactly two names also rejects non-2-D input.
    n_rows, n_cols = A.shape

    # Column indices are stored as int32; refuse inputs where that would
    # silently wrap around instead of producing corrupt indices.
    if n_cols - 1 > np.iinfo(np.int32).max:
        raise ValueError(
            f"A has {n_cols} columns; column indices do not fit in int32."
        )

    row_values = []
    row_col_inds = []
    row_ptr = np.zeros(n_rows + 1, dtype=np.int64)

    for i in range(n_rows):
        row = A[i]
        keep = row >= threshold

        kept_values = row[keep]
        row_values.append(kept_values)
        # Narrow to int32 per row so the int64 indices never have to be
        # materialized for the whole matrix at once.
        row_col_inds.append(np.flatnonzero(keep).astype(np.int32))

        row_ptr[i + 1] = row_ptr[i] + kept_values.shape[0]

    if n_rows == 0:
        # np.concatenate raises on an empty list; return a valid empty CSR.
        return (
            np.empty((0,), dtype=A.dtype),
            row_ptr,
            np.empty((0,), dtype=np.int32),
        )

    all_values = np.concatenate(row_values, axis=0)
    all_col_inds = np.concatenate(row_col_inds, axis=0)

    return all_values, row_ptr, all_col_inds


def main(_):
    """Load a dense NMF decomposition, sparsify its H factor, and save it.

    Reads the decomposition at --nmf_path, normalizes its components, converts
    H to CSR form keeping entries >= --H_threshold, and saves the resulting
    sparse decomposition to --output_path. W is passed through unchanged.

    Args:
        _: Unused positional command-line arguments supplied by absl.

    Raises:
        ValueError: if any of --nmf_path, --output_path or --H_threshold is
            not set.
    """
    # Validate every required flag up front (with a real exception rather than
    # an assert, which is stripped under -O) so that a missing --output_path
    # is reported before the expensive load and sparsification, not after.
    missing = [
        name
        for name in ("nmf_path", "output_path", "H_threshold")
        if getattr(FLAGS, name) is None
    ]
    if missing:
        raise ValueError(
            "Missing required flag(s): "
            + ", ".join(f"--{name}" for name in missing)
        )

    nmf = nmf_common.NmfDecomposition.load(os.path.expanduser(FLAGS.nmf_path))
    # Normalization runs before thresholding, so H_threshold is applied to the
    # normalized H. NOTE(review): exact normalization semantics live in
    # nmf_common -- confirm there.
    nmf.normalize_components_to_unit_norm()

    print(cu.hly(f'Dense size: {nmf.H.shape[0] * nmf.H.shape[1]}'))

    H_vals, H_row_infos, H_col_inds = to_csr(nmf.H, FLAGS.H_threshold)

    print(cu.hly(f'Sparse NNZ: {H_vals.shape[0]}'))

    sp_nmf = nmf_common.SparseNmfDecomposition(
        W=nmf.W,
        H_shape=nmf.H.shape,
        H_values=H_vals,
        # Row-pointer array of length n_rows + 1 as produced by to_csr, not
        # one row index per stored value.
        H_row_indices=H_row_infos,
        H_column_indices=H_col_inds,
        reduce_kept_indices=nmf.reduce_kept_indices,
        full_dense_size=nmf.full_dense_size,
    )
    sp_nmf.save(os.path.expanduser(FLAGS.output_path))


# Script entry point: absl's app.run parses the command-line flags and then
# calls main.
if __name__ == "__main__":
    app.run(main)
