"""Dataset size counting."""

import tensorflow as tf


def count_dataset_size(ds, features_list):
    """Counts the examples in `ds` that match each query in `features_list`.

    features_list = [{'label_0': v_0, 'label_1': v_1, ...}, ...].

    Args:
        ds: A `tf.data.Dataset` whose elements are dicts mapping labels to
            tensors. Elements may be single examples or batches; every
            position where all of a query's comparisons hold is counted.
        features_list: A list of query dicts. An element matches a query when
            `element[label] == value` for every `(label, value)` item of that
            query. Queries do not need to share the same keys or key order.

    Returns:
        A dict mapping `str(query_items)` to a numpy int64 count, where
        `query_items` is the tuple of `(label, value)` pairs of the query in
        its own key order, e.g. "(('label_0', v_0), ('label_1', v_1))".
        Returns an empty dict when `features_list` is empty.
    """
    if not features_list:
        return {}

    # Build each query from its own items so every label stays paired with
    # its own value, even when dicts list their keys in different orders.
    features = [tuple(query.items()) for query in features_list]

    def count(counts, batch):
        # Return a fresh state dict rather than mutating the incoming one.
        new_counts = {}
        for feature in features:
            # feature = (('label_0', v_0), ('label_1', v_1), ...).
            matches = [batch[label] == value for label, value in feature]
            # Stack the per-label comparisons and require all of them to hold.
            mask = tf.reduce_all(tf.stack(matches), axis=0)
            # int64 so large datasets cannot overflow the running count.
            new_counts[feature] = counts[feature] + tf.reduce_sum(
                tf.cast(mask, tf.int64))
        return new_counts

    initial_state = {
        feature: tf.constant(0, dtype=tf.int64) for feature in features
    }
    counts = ds.reduce(initial_state=initial_state, reduce_func=count)
    return {str(feature): value.numpy() for feature, value in counts.items()}
