"""General utilities for dealing with sparse tensors/matrices."""
import dataclasses
from typing import List, Sequence, Tuple

import numpy as np
import tensorflow as tf

# Type aliases.
SparseTensor = tf.sparse.SparseTensor


def to_torch_coo_indices(indices: Sequence[np.ndarray]) -> np.ndarray:
    """Build a `[2, nnz]` COO index array from per-row column indices.

    Args:
        indices: One 1-d array per row; `indices[i]` holds the column indices
            of the entries in row `i`. All arrays are assumed to share the
            same dtype (the dtype of the first array is used for the row ids).

    Returns:
        A 2-d array of shape `[2, nnz]`, where `nnz` is the total number of
        entries over all rows. The first row holds the row id of every entry
        and the second row holds the concatenated column indices. For an
        empty `indices` sequence an empty `[2, 0]` int64 array is returned.

    Raises:
        AssertionError: If any element of `indices` is not 1-d.
    """
    # Materialize first so that one-shot iterables are only consumed once.
    indices = list(indices)

    # Each element of the indices list must be 1-d np array.
    for inds in indices:
        assert len(inds.shape) == 1

    if not indices:
        # There is no element to take a dtype from; fall back to int64.
        return np.zeros([2, 0], dtype=np.int64)

    # NOTE: We are assuming that all of the elements of the indices sequence
    # have the same dtype.
    dtype = indices[0].dtype

    # Row id `i` is repeated once for every entry in row `i`.
    row_lengths = [inds.shape[0] for inds in indices]
    first_indices = np.repeat(np.arange(len(indices), dtype=dtype), row_lengths)

    flat_indices = np.concatenate(indices, axis=-1)

    return np.stack([
        first_indices,
        flat_indices,
    ], axis=0)


def to_scipy_coo_indices(indices: Sequence[np.ndarray]) -> np.ndarray:
    """Return the same `[2, nnz]` index array as `to_torch_coo_indices`.

    This is a plain alias: the call is forwarded unchanged.
    """
    coo_indices = to_torch_coo_indices(indices)
    return coo_indices


# def reduce_flat_sparse_indices(indices: np.ndarray, dense_size: int):
#     """Re-indexes a set of sparse indices such that no new index is zero for all in the set.

#     The `indices` argument is assumed to be a 2-d array with the first dimension being the batch
#     dimension.

#     The `dense_size` argument is the size of the second dimension when converted to dense form.

#     This is mainly for the torch-NMF library since it represents W and H as full dense tensors.
#     """
#     batch_size, _ = indices.shape

#     hit_counts = _compute_sparse_hits(indices, dense_size)
#     dense_keep_mask = hit_counts > 0
#     reduction_info, = np.nonzero(dense_keep_mask)

#     og_to_reduced_index = np.cumsum(dense_keep_mask.astype(np.int32)) - 1
#     reduced_size = og_to_reduced_index[-1] + 1

#     reduced_indices = np.zeros_like(indices)
#     for i in range(batch_size):
#         reduced_indices[i, :] = og_to_reduced_index[indices[i]]

#     return reduction_info, reduced_indices, reduced_size


###############################################################################

def _compute_sparse_hits(indices: np.ndarray, dense_size: int) -> np.ndarray:
    """Count, for every dense position, how many rows of `indices` reference it.

    Iterating `indices` yields one collection of positions per row. A position
    that is repeated within a single row still contributes only one hit for
    that row, because the increment is done through fancy-index assignment.

    Returns:
        An int32 array of shape `[dense_size]` with the per-position hit counts.
    """
    hit_counts = np.zeros((dense_size,), dtype=np.int32)
    for row_positions in indices:
        # Read-modify-write through fancy indexing (no accumulation of repeats).
        hit_counts[row_positions] = hit_counts[row_positions] + 1
    return hit_counts


@dataclasses.dataclass
class IndexReduction:
    """Describes a re-indexing of a dense axis that keeps only some positions."""

    # Positions of the original dense axis that were kept. When produced by
    # `remove_always_zero_indices` this is the output of `np.nonzero` on the
    # keep mask, so entry `j` is the original index of reduced index `j`.
    kept_original_indices: np.ndarray

    # Size of the dense axis before the reduction.
    original_size: int
    # Size of the dense axis after the reduction (number of kept positions).
    reduced_size: int


def remove_always_zero_indices(
    values: Sequence[np.ndarray],
    indices: Sequence[np.ndarray],
    dense_size: int,
    threshold: int = 1,
) -> Tuple["IndexReduction", Tuple[List[np.ndarray], List[np.ndarray]]]:
    """Drop dense positions hit by too few rows and re-index the remainder.

    Args:
        values: One 1-d array of entry values per sparse row.
        indices: One 1-d array of dense positions per sparse row, parallel to
            `values`.
        dense_size: Size of the dense axis that `indices` refer to.
        threshold: Minimum number of rows that must reference a position for
            it to be kept. With the default of 1, only positions that no row
            references are removed.

    Returns:
        A pair `(reduction_info, (reduced_values, reduced_indices))`.
        `reduction_info` is an `IndexReduction` describing which original
        positions were kept. `reduced_values` and `reduced_indices` are lists
        with one array per input row, restricted to the kept positions and
        with the indices rewritten into the range `[0, reduced_size)`.

    Raises:
        AssertionError: If `values` and `indices` differ in length, or any of
            their elements is not 1-d.
    """
    assert len(values) == len(indices)
    for vals, inds in zip(values, indices):
        assert len(inds.shape) == 1
        assert len(vals.shape) == 1

    hit_counts = _compute_sparse_hits(indices, dense_size)
    dense_keep_mask = hit_counts >= threshold
    kept_original_indices, = np.nonzero(dense_keep_mask)

    # Running count of kept positions minus one maps an original index to its
    # reduced index. Entries at dropped positions are meaningless and are
    # never read, because rows are masked before the lookup below.
    og_to_reduced_index = np.cumsum(dense_keep_mask.astype(np.int32)) - 1
    # Count the kept positions directly rather than reading the last cumsum
    # entry: this also works for `dense_size == 0` and yields a plain `int`.
    reduced_size = int(kept_original_indices.shape[0])

    reduced_values = []
    reduced_indices = []
    for vals, inds in zip(values, indices):
        mask = dense_keep_mask[inds]
        reduced_values.append(vals[mask])
        reduced_indices.append(og_to_reduced_index[inds[mask]])

    reduction_info = IndexReduction(
        kept_original_indices=kept_original_indices,
        original_size=dense_size,
        reduced_size=reduced_size,
    )

    return reduction_info, (reduced_values, reduced_indices)


def reduce_np_array(
    x: np.ndarray,
    kept_original_indices: np.ndarray,
) -> np.ndarray:
    """Select the kept positions along the last axis of a dense array.

    Args:
        x: Array of shape `[..., original_size]`.
        kept_original_indices: Positions along the last axis to keep.

    Returns:
        `x` indexed along its last axis with `kept_original_indices`; all
        leading axes are left untouched.
    """
    return x[..., kept_original_indices]


def reduce_tf_tensor(
    x: tf.Tensor,
    kept_original_indices: np.ndarray,
) -> tf.Tensor:
    """Gather the kept positions along the last axis of a tensor.

    `x` has shape `[..., original_size]`; the result is `tf.gather` applied to
    `x` with `kept_original_indices` on axis -1.
    """
    gathered = tf.gather(params=x, indices=kept_original_indices, axis=-1)
    return gathered


def reduce_sparse_np_arrays(
    values: Sequence[np.ndarray],
    indices: Sequence[np.ndarray],
    dense_size: int,
    kept_original_indices: np.ndarray,
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Restrict sparse 1-d rows to a set of kept positions and re-index them.

    Args:
        values: One 1-d array of entry values per sparse row.
        indices: One 1-d array of dense positions per sparse row, parallel to
            `values`.
        dense_size: Size of the dense axis that `indices` refer to.
        kept_original_indices: Original dense positions to keep.

    Returns:
        A pair `(reduced_values, reduced_indices)` of lists with one array per
        input row. Entries at positions that are not kept are dropped, and
        each remaining index is replaced by the number of kept positions that
        precede it in the original axis.

    Raises:
        AssertionError: If a row's values or indices are not 1-d, or their
            lengths differ.
    """
    # The `np.bool` alias was removed in NumPy 1.24; the builtin `bool` is the
    # same dtype and works on every NumPy version.
    dense_keep_mask = np.zeros([dense_size], dtype=bool)
    dense_keep_mask[kept_original_indices] = True
    # Running count of kept positions minus one maps an original index to its
    # reduced index. Entries at dropped positions are never read because the
    # rows are masked before the lookup.
    og_to_reduced_index = np.cumsum(dense_keep_mask.astype(np.int32)) - 1

    reduced_values = []
    reduced_indices = []
    for vals, inds in zip(values, indices):
        # Only handling sparse 1-d arrays for now.
        assert len(vals.shape) == len(inds.shape) == 1
        assert vals.shape[0] == inds.shape[0]

        mask = dense_keep_mask[inds]
        reduced_values.append(vals[mask])
        reduced_indices.append(og_to_reduced_index[inds[mask]])

    return reduced_values, reduced_indices


# def reduce_tf_sparse_tensor_as_np(
#     x: SparseTensor,
#     kept_original_indices: np.ndarray,
# ) -> np.ndarray:
#     # x.shape = [..., original_size]

#     # cumsum and gather or something
#     pass

#     # mask0 = tf.scatter_nd(
#     #     kept_original_indices[:, None],
#     #     tf.ones(kept_original_indices.shape, dtype=tf.bool),
#     #     shape=[x.shape[-1]]
#     # )
#     # mask = tf.gather(mask0, x.indices[:, -1])
#     # # x.indices[:, -1]

#     

###############################################################################

def stack_as_rows(sp_vecs: Sequence[SparseTensor]) -> SparseTensor:
    """Stack rank-1 sparse tensors of equal shape into one rank-2 sparse tensor.

    Vector `i` becomes row `i` of the result, whose dense shape is
    `[len(sp_vecs), vector_length]`.

    Raises:
        AssertionError: If the vectors do not all share one shape, or that
            shape is not rank 1.
    """
    # All inputs must share a single shape, and that shape must be rank 1.
    distinct_shapes = {tuple(vec.shape) for vec in sp_vecs}
    assert len(distinct_shapes) == 1
    (common_shape,) = distinct_shapes
    assert len(common_shape) == 1

    stacked_values = tf.concat([vec.values for vec in sp_vecs], axis=0)

    # Drop the trailing axis of each vector's indices to get its column ids,
    # then pair every entry with the position of its vector as the row id.
    column_ids = [tf.squeeze(vec.indices, axis=-1) for vec in sp_vecs]
    row_ids = [row * tf.ones_like(cols) for row, cols in enumerate(column_ids)]

    all_rows = tf.concat(row_ids, axis=0)
    all_cols = tf.concat(column_ids, axis=0)
    stacked_indices = tf.stack([all_rows, all_cols], axis=-1)

    dense_shape = [len(sp_vecs), common_shape[0]]
    return SparseTensor(stacked_indices, stacked_values, dense_shape)
