
import numpy as np


def compute_stratified_class_weights(targets, n_classes=None):
    """Compute per-sample weights for stratified (class-balanced) sampling.

    Every sample receives the weight ``1 / count(label)``, where
    ``count(label)`` is the number of occurrences of that sample's label in
    ``targets``. The weights of each observed class therefore sum to 1, so a
    weighted sampler draws all observed classes with equal probability.

    Parameters
    ----------
    targets : array-like
        1-D vector of classification targets (class labels). The labels do
        not have to be contiguous or start at zero.
    n_classes : int, optional
        Total number of classes. Only used as a sanity check: if given, every
        label must be smaller than ``n_classes``. The weights themselves are
        derived from the labels actually present in ``targets``.

    Example
    -------
        weights = compute_stratified_class_weights(targets)
        sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights), replacement=True)

    Returns
    -------
    sample_weights : np.array
        float32 vector holding the (unnormalised) sampling weight of each
        individual sample, aligned with ``targets``.

    Raises
    ------
    IndexError
        If ``n_classes`` is given and a label in ``targets`` is
        ``>= n_classes``.
    """
    targets = np.asarray(targets)

    # Count label occurrences; `inverse` maps every sample to the position of
    # its label in `labels`, so no lookup table indexed by the raw label value
    # is needed (such a table breaks for non-contiguous label sets).
    labels, inverse, counts = np.unique(
        targets, return_inverse=True, return_counts=True
    )

    # An explicit class count that is too small is reported as IndexError,
    # the same exception type the table-based lookup used to raise.
    if n_classes is not None and labels.size and labels.max() >= n_classes:
        raise IndexError(
            "label %s is out of bounds for n_classes=%s" % (labels.max(), n_classes)
        )

    # One weight per observed class, inversely proportional to its frequency.
    class_weights = (1.0 / counts).astype(np.float32)

    # Broadcast the class weights back onto the individual samples.
    return class_weights[inverse]
