"""Utilities for perturbation experiments.

Common helpers that would otherwise end up copy-pasted into a bunch
of scripts.
"""
import numpy as np
import tensorflow as tf


##########################################################################

class SoftmaxKlDivergenceLoss(tf.keras.losses.KLDivergence):
    """KL-divergence loss that softmax-normalizes both inputs first.

    Subclasses ``tf.keras.losses.KLDivergence``; the only difference is
    that ``call`` applies ``tf.math.softmax`` over the last axis to both
    ``y_true`` and ``y_pred`` before delegating to the parent ``call``.
    """

    def call(self, y_true, y_pred):
        """Return the parent loss evaluated on the softmaxed inputs.

        Args:
            y_true: Target values; softmaxed over the last axis
                (presumably unnormalized scores/logits -- confirm
                against callers).
            y_pred: Predicted values; transformed the same way as
                ``y_true``.

        Returns:
            Whatever the parent ``call`` returns for the two softmaxed
            tensors.
        """
        target_probs = tf.math.softmax(y_true, axis=-1)
        predicted_probs = tf.math.softmax(y_pred, axis=-1)
        return super().call(target_probs, predicted_probs)


##########################################################################

def get_top_example_indices(nmf, component_index: int, n_examples: int):
    """Return the indices of the rows with the largest values in one column.

    Args:
        nmf: Object exposing an array ``W`` that supports
            ``W[:, component_index]`` (presumably examples x components
            -- confirm against the caller).
        component_index: Column of ``nmf.W`` to rank by.
        n_examples: Maximum number of indices to return.

    Returns:
        Array of row indices ordered by descending value in the chosen
        column, truncated to the first ``n_examples`` entries.
    """
    component_weights = nmf.W[:, component_index]
    # Sorting the negated weights in ascending order ranks rows descending.
    descending_order = np.argsort(-component_weights)
    return descending_order[:n_examples]


def get_uniformly_random_example_indices(nmf, n_examples: int, rng=None):
    """Return a random selection of distinct example indices.

    The indices ``0 .. nmf.W.shape[0] - 1`` are randomly permuted and the
    first ``n_examples`` entries of that permutation are returned, so no
    index appears twice.

    Args:
        nmf: Object whose ``W`` attribute has a ``shape``; ``W.shape[0]``
            is taken as the total number of examples.
        n_examples: Maximum number of indices to return. If it exceeds
            the total number of examples, all indices are returned (in
            random order).
        rng: Optional random source providing a ``permutation`` method,
            e.g. ``np.random.Generator`` or ``np.random.RandomState``.
            Pass one to get reproducible draws without touching global
            state. When ``None`` (the default) the global ``np.random``
            state is used.

    Returns:
        Array of at most ``n_examples`` distinct row indices.
    """
    total_examples = nmf.W.shape[0]
    # Fall back to the module-level legacy RNG so existing callers (and any
    # global ``np.random.seed`` they rely on) behave the same as without rng.
    sampler = np.random if rng is None else rng
    shuffled = sampler.permutation(np.arange(total_examples))
    return shuffled[:n_examples]
