from jax import numpy as jnp

# Loss functions. All between 0 and 1


def zero_one_loss(prev_state, prev_action, prev_results, current_state):
    """
    Population-level misclassification rate of the classifier after the action.

    Sums the per-group error fraction (false-positive fraction plus
    false-negative fraction), weighted by ``prev_state.pr_G``.

    Args:
        prev_state: state before the action; only ``pr_G`` is read.
        prev_action: unused.
        prev_results: classifier results; ``fp_frac`` and ``fn_frac`` are read
            and combined elementwise with ``pr_G``.
        current_state: unused.

    Returns:
        0-d array ``sum_g Pr(G=g) * Pr(Y_hat != Y | G=g)``.
    """
    # Pr(Y_hat != Y | G): misclassified fraction of each group.
    error_frac = prev_results.fp_frac + prev_results.fn_frac

    # Pr(Y_hat != Y) = sum_g Pr(G = g) * Pr(Y_hat != Y | G = g)
    return jnp.sum(prev_state.pr_G * error_frac)


def max_zero_one_loss(prev_state, prev_action, prev_results, current_state):
    """
    Misclassification rate of the worst-performing group after the action.

    Unlike ``zero_one_loss``, the groups are not weighted by ``pr_G``: the
    largest per-group error fraction (false-positive fraction plus
    false-negative fraction) is returned regardless of group size.

    Args:
        prev_state: unused.
        prev_action: unused.
        prev_results: classifier results; ``fp_frac`` and ``fn_frac`` are read.
        current_state: unused.

    Returns:
        0-d array ``max_g Pr(Y_hat != Y | G=g)``.
    """
    # Pr(Y_hat != Y | G): misclassified fraction of each group.
    error_frac = prev_results.fp_frac + prev_results.fn_frac

    # Error rate on the worst-performing group.
    return jnp.max(error_frac)


def tp_loss(prev_state, prev_action, prev_results, current_state):
    """
    Population-weighted shortfall in the true-positive fraction.

    Computes ``sum_g pr_G[g] * (1 - tp_frac[g])`` using ``prev_state.pr_G``
    as group weights and ``prev_results.tp_frac`` as the per-group
    true-positive fraction. ``prev_action`` and ``current_state`` are unused.
    """
    weights = prev_state.pr_G
    shortfall = 1 - prev_results.tp_frac
    return jnp.sum(weights * shortfall)


def max_tp_loss(prev_state, prev_action, prev_results, current_state):
    """
    Largest per-group shortfall in the true-positive fraction.

    Returns ``max_g (1 - tp_frac[g])`` from ``prev_results.tp_frac``; groups
    are not weighted by size. Only ``prev_results`` is read.
    """
    worst_group_shortfall = jnp.max(1 - prev_results.tp_frac)
    return worst_group_shortfall


# w for 'weighted' tn
def tp_wtn_loss(prev_state, prev_action, prev_results, current_state, *, tn_weight=0.8):
    """
    Population-weighted loss rewarding true positives and (down-weighted)
    true negatives.

    Per group, the loss is ``1 - tp_frac - tn_weight * tn_frac``. These
    per-group values are weighted by ``prev_state.pr_G`` and summed.

    Args:
        prev_state: state before the action; only ``pr_G`` is read.
        prev_action: unused.
        prev_results: classifier results; ``tp_frac`` and ``tn_frac`` are read.
        current_state: unused.
        tn_weight: keyword-only weight on the true-negative fraction. It
            defaults to 0.8, the previously hard-coded value.

    Returns:
        0-d array with the weighted sum.
    """
    error_frac = 1 - prev_results.tp_frac - tn_weight * prev_results.tn_frac
    return jnp.sum(prev_state.pr_G * error_frac)


def max_tp_wtn_loss(prev_state, prev_action, prev_results, current_state, *, tn_weight=0.8):
    """
    Worst-group loss rewarding true positives and (down-weighted) true
    negatives.

    Per group, the loss is ``1 - tp_frac - tn_weight * tn_frac``. The largest
    value over groups is returned, without weighting by group size.

    Args:
        prev_state: unused.
        prev_action: unused.
        prev_results: classifier results; ``tp_frac`` and ``tn_frac`` are read.
        current_state: unused.
        tn_weight: keyword-only weight on the true-negative fraction. It
            defaults to 0.8, the previously hard-coded value.

    Returns:
        0-d array with the maximum per-group loss.
    """
    error_frac = 1 - prev_results.tp_frac - tn_weight * prev_results.tn_frac

    # loss on worst-performing group
    return jnp.max(error_frac)


################################################################################
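# Disparity definitions. The pairwise measures are nonnegative; for two groups
# with rates in [0, 1] they lie in [0, 1] (equalized odds, a sum of two such
# terms, lies in [0, 2]). group_neg_entropy lies in [-log(#groups), 0].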


def euclidean_demographic_partity(prev_state, prev_action, prev_results, current_state):
    """
    Demographic-parity violation after the action was taken.

    Sum of squared differences in ``prev_results.accept_rate`` over all
    unordered pairs of groups. The full pairwise matrix counts each pair
    twice, so the total is halved.
    (The function name's spelling is kept as-is for existing callers.)
    """
    rates = prev_results.accept_rate
    pairwise_gap = rates[:, jnp.newaxis] - rates[jnp.newaxis, :]
    return 0.5 * jnp.sum(jnp.square(pairwise_gap))


def equal_opportunity_disparity(prev_state, prev_action, prev_results, current_state):
    """
    Equal-opportunity violation: the sum of squared differences in
    ``prev_results.tp_rate`` over all unordered pairs of groups.
    """
    tpr = prev_results.tp_rate
    pairwise_gap = tpr[:, jnp.newaxis] - tpr[jnp.newaxis, :]
    # The symmetric matrix counts every pair twice.
    return 0.5 * jnp.sum(jnp.square(pairwise_gap))


def equalized_odds_disparity(prev_state, prev_action, prev_results, current_state):
    """
    Equalized-odds violation.

    For each of ``prev_results.tp_rate`` and ``prev_results.tn_rate``, take
    the sum of squared between-group differences over unordered pairs. The
    result is the sum of the two terms.
    """
    total = 0.0
    for rates in (prev_results.tp_rate, prev_results.tn_rate):
        pairwise_gap = rates[:, jnp.newaxis] - rates[jnp.newaxis, :]
        # Halve because the symmetric matrix counts every pair twice.
        total = total + 0.5 * jnp.sum(jnp.square(pairwise_gap))
    return total


def qualification_rate_disparity(prev_state, prev_action, prev_results, current_state):
    """
    Disparity in qualification rates in the state reached after the action.

    Sum of squared differences in ``current_state.pr_Y1`` over all unordered
    pairs of groups. Only ``current_state`` is read.
    """
    qual = current_state.pr_Y1
    pairwise_gap = qual[:, jnp.newaxis] - qual[jnp.newaxis, :]
    return 0.5 * jnp.sum(jnp.square(pairwise_gap))


def group_neg_entropy(prev_state, prev_action, prev_results, current_state):
    """
    Negative entropy of the group variable G: ``sum_g Pr(G=g) * log Pr(G=g)``.

    Minimized (at ``-log(#groups)``) when groups are of uniform size and
    equal to 0 when all probability mass is on a single group. Groups with
    ``Pr(G=g) == 0`` contribute 0, following the convention ``0 * log 0 = 0``.

    Args:
        prev_state: unused.
        prev_action: unused.
        prev_results: unused.
        current_state: only ``pr_G`` is read.

    Returns:
        0-d array with the negative entropy.
    """
    pr_g = current_state.pr_G
    # Substitute 1 for zero probabilities before the log. The term then
    # becomes 0 * log(1) = 0 instead of 0 * -inf = nan, and the gradient stays
    # finite; a single jnp.where around the product would still backpropagate
    # nan from the log branch.
    safe_pr_g = jnp.where(pr_g > 0, pr_g, 1.0)
    return jnp.sum(pr_g * jnp.log(safe_pr_g))


################################################################################


def zero(prev_state, prev_action, prev_results, current_state):
    """Constant placeholder loss: ignores every argument and returns 0.0."""
    del prev_state, prev_action, prev_results, current_state  # unused
    return 0.0


# Registry mapping a name to its loss/disparity function (insertion order kept).
# NOTE(review): group_neg_entropy is defined in this module but not registered
# here -- confirm whether that is intentional.
known_loss_func_dict = dict(
    [
        # Accuracy-style losses: population-weighted and worst-group variants.
        ("zero_one_loss", zero_one_loss),
        ("max_zero_one_loss", max_zero_one_loss),
        ("tp_loss", tp_loss),
        ("max_tp_loss", max_tp_loss),
        ("tp_wtn_loss", tp_wtn_loss),
        ("max_tp_wtn_loss", max_tp_wtn_loss),
        # Disparity measures, keyed by descriptive names.
        ("Demographic Disparity", euclidean_demographic_partity),
        ("Unequal Opportunity", equal_opportunity_disparity),
        ("Odds Disparity", equalized_odds_disparity),
        ("Qualification Rate Disparity", qualification_rate_disparity),
        # Constant-zero placeholder.
        ("None", zero),
    ]
)
