import torch
from query import query_marginal


def demographic_parity_distance(data, target_feature, protected_feature, dataset, operation='mean'):
    """
    Calculates the demographic parity (DP) distance on the data with respect to a target feature and the given
    protected feature. We allow for two options: mean, if we want the mean DP distance over all pairwise
    configurations, or max, if we want the maximal DP distance, as in the definition. Note that the first option is
    more practical whenever this function is used in training where gradients may be computed, as max produces sparse
    gradients.

    Protected-feature values that carry no probability mass in the data are left out of the comparison. Their
    conditional target distribution is undefined (0/0) and would otherwise turn the result, and every gradient
    flowing through it, into NaN. When all protected values carry mass the result is unaffected.

    :param data: (torch.tensor) The full one hot encoded data. Note that if you want to test a classifier for DP, then you
        first have to concatenate the predicted labels in a correct format to the data.
    :param target_feature: (str) The name of the target (label) feature.
    :param protected_feature: (str) The name of the protected feature.
    :param dataset: (BaseDataset) The instantiated dataset object containing the necessary information for the data.
    :param operation: (str) The operation applied to aggregate the absolute differences in the expected labeling. Available are
        only mean or max. Note that mean is preferable when we have to differentiate through this function.
        With mean, the average is taken over the full G x G matrix of differences between the G populated protected
        values, i.e. including the zero diagonal and both orderings of every pair.
    :return: (torch.Tensor) The aggregated DP distance as a 0-dim tensor. If fewer than two protected values carry any
        mass there is nothing to compare, and zero is returned.
    :raises ValueError: If operation is neither 'mean' nor 'max'.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if operation not in ('mean', 'max'):
        raise ValueError('Only mean and max operations are available')

    target_protected_marginal = query_marginal(
        data, (target_feature, protected_feature), dataset.full_one_hot_index_map, normalize=True, input_torch=True
    )
    # The code below treats rows as target values and columns as protected values.
    # Column sums give the mass of each protected value.
    normalization_constant = target_protected_marginal.sum(0)

    # Keep only protected values that actually occur; dividing by a zero column sum would yield NaN.
    # Boolean indexing (rather than masking after the division) keeps gradients finite for the kept entries.
    populated = normalization_constant > 0
    renormalized_target_protected_marginal = (
        target_protected_marginal[:, populated] / normalization_constant[populated].view((1, -1))
    )
    if renormalized_target_protected_marginal.size(1) == 0:
        # No protected value has any mass: there are no groups to compare.
        return renormalized_target_protected_marginal.new_zeros(())

    # Target values are taken to be their one-hot positions 0..T-1, so this is E[target | protected] per column.
    target_values = torch.arange(
        target_protected_marginal.size(0),
        device=renormalized_target_protected_marginal.device,
        dtype=renormalized_target_protected_marginal.dtype,
    ).view((-1, 1))
    expected_target_per_protected = (target_values * renormalized_target_protected_marginal).sum(0)

    # G x G matrix of pairwise differences between the per-group expectations (via broadcasting).
    all_differences = expected_target_per_protected.view((-1, 1)) - expected_target_per_protected.view((1, -1))
    absolute_differences = all_differences.abs()
    return absolute_differences.mean() if operation == 'mean' else absolute_differences.max()
