import pandas as pd
import numpy as np

def probability_of_not_kprime_given_not_k_and_z(calibration_data):
    """Estimate P(Y' != k' | Y != k, Z = z) and P(Y' != k' | Y != k, Z != z).

    Y' is the predicted label, Y the ground-truth label and Z the sensitive
    attribute. Two arrays of shape |Y'| x |Y| x |Z| are returned. Entry
    ``[k_, k, z]`` of the first is the fraction of rows with ``truth != k`` and
    ``sensitive == z`` whose ``prediction != k_``. The second array is the same,
    except it uses ``sensitive != z``.

    The number of classes and groups is inferred as ``max(...) + 1`` of the
    ``"truth"`` and ``"sensitive"`` columns. A ``ZeroDivisionError`` is raised
    if any conditioning slice is empty.
    """
    truth = calibration_data["truth"]
    prediction = calibration_data["prediction"]
    sensitive = calibration_data["sensitive"]

    n_classes = max(truth) + 1
    n_groups = max(sensitive) + 1
    shape = (n_classes, n_classes, n_groups)

    in_group_probs = np.zeros(shape)
    out_group_probs = np.zeros(shape)

    for group in range(n_groups):
        in_group = sensitive == group
        out_group = sensitive != group
        for true_label in range(n_classes):
            not_true = truth != true_label
            in_rows = not_true & in_group
            out_rows = not_true & out_group
            # Python ints so an empty slice raises ZeroDivisionError (not nan/inf)
            in_total = int(in_rows.sum())
            out_total = int(out_rows.sum())
            for pred_label in range(n_classes):
                not_pred = prediction != pred_label
                in_group_probs[pred_label, true_label, group] = int((not_pred & in_rows).sum()) / in_total
                out_group_probs[pred_label, true_label, group] = int((not_pred & out_rows).sum()) / out_total

    return in_group_probs, out_group_probs


def probability_of_k_given_kprime_and_z(calibration_data):
    """Estimate P(Y = k | Y' = k', Z = z) and P(Y = k | Y' = k', Z != z).

    Y is the ground-truth label, Y' the predicted label and Z the sensitive
    attribute. Two arrays of shape |Y| x |Y'| x |Z| are returned. Entry
    ``[k, k_, z]`` of the first is the fraction of rows with ``prediction == k_``
    and ``sensitive == z`` whose ``truth == k``. The second array is the same,
    except it uses ``sensitive != z``.

    Args:
        calibration_data: DataFrame with ``"prediction"``, ``"truth"`` and
            ``"sensitive"`` columns. The number of classes and of sensitive
            groups is inferred as ``max(...) + 1`` of the ``"truth"`` and
            ``"sensitive"`` columns respectively.

    Returns:
        Tuple ``(probabilities_for_z, probabilities_but_z)`` of float arrays of
        shape ``(num_classes, num_classes, num_sensitives)``.

    Raises:
        ZeroDivisionError: if a conditioning slice is empty. This happens, for
            example, when a class is never predicted within a group, or when
            there is only one group, so ``sensitive != z`` selects nothing.
    """
    num_classes = max(calibration_data["truth"]) + 1
    num_sensitives = max(calibration_data["sensitive"]) + 1

    probabilities_for_z = np.zeros((num_classes, num_classes, num_sensitives))
    probabilities_but_z = np.zeros((num_classes, num_classes, num_sensitives))

    for z in range(num_sensitives):
        for k_ in range(num_classes):
            # The denominators depend only on (k_, z); compute them once per pair.
            denominator_for_z = len(calibration_data.query(f'prediction == {k_} and sensitive == {z}'))
            denominator_but_z = len(calibration_data.query(f'prediction == {k_} and sensitive != {z}'))
            for k in range(num_classes):
                numerator_for_z = calibration_data.query(f'truth == {k} and prediction == {k_} and sensitive == {z}')
                probabilities_for_z[k, k_, z] = len(numerator_for_z) / denominator_for_z

                numerator_but_z = calibration_data.query(f'truth == {k} and prediction == {k_} and sensitive != {z}')
                probabilities_but_z[k, k_, z] = len(numerator_but_z) / denominator_but_z

    # Each conditional distribution over Y (axis 0) must sum to 1 -- check both
    # arrays (previously only probabilities_for_z was checked, twice).
    assert np.allclose(probabilities_for_z.sum(axis=0), 1)
    assert np.allclose(probabilities_but_z.sum(axis=0), 1)

    return probabilities_for_z, probabilities_but_z

def probability_of_k(calibration_data):
    """Estimate the marginal P(Y = k) for every ground-truth class k.

    Returns a float vector of length |Y|, where |Y| is ``max(truth) + 1``.
    Entry ``k`` is the fraction of all rows whose ``"truth"`` equals ``k``.
    """
    labels = calibration_data["truth"]
    total = len(calibration_data)
    class_count = max(labels) + 1

    probabilities = np.array(
        [int((labels == label).sum()) / total for label in range(class_count)],
        dtype=float,
    )

    # Sanity check: the marginal distribution over Y must sum to 1
    assert np.allclose(probabilities.sum(axis=0), 1)
    return probabilities



def generate_calibration_constants(fairness_metric, c_votes, c_targets, c_sensitives):
    """Compute per-group calibration constants for the requested fairness metric.

    Args:
        fairness_metric: ``'ErrorParity'`` or ``'EqualityOfOdds'``.
        c_votes: Array whose row-wise argmax (``axis=1``) is taken as the
            predicted label for each calibration sample. Per the original
            note, the prediction is not noised yet at this point.
        c_targets: Ground-truth label per calibration sample.
        c_sensitives: Sensitive-attribute value per calibration sample.

    Returns:
        ``(for_z, but_z)``:
        - ``'ErrorParity'``: arrays of shape ``(num_classes, num_sensitives)``
          with ``for_z[k, z] = sum_k' P(Y=k | Y'=k', Z=z) * (1 - P(Y=k'))``.
          ``but_z`` is the same, conditioned on ``Z != z``.
        - ``'EqualityOfOdds'``: the two ``(num_classes, num_classes,
          num_sensitives)`` arrays from
          ``probability_of_not_kprime_given_not_k_and_z``.

    Raises:
        ValueError: if ``fairness_metric`` is not one of the supported names.
            Previously an unknown name silently returned ``None``.
    """
    if fairness_metric not in ('ErrorParity', 'EqualityOfOdds'):
        raise ValueError(
            f"Unsupported fairness_metric {fairness_metric!r}; "
            "expected 'ErrorParity' or 'EqualityOfOdds'"
        )

    calibration_data = pd.DataFrame(
        np.c_[c_votes.argmax(axis=1), c_targets, c_sensitives], columns=["prediction", "truth", "sensitive"]) # note the prediction is not noised yet
    calibration_data = calibration_data.astype({"prediction": int, "truth": int, "sensitive": int})

    if fairness_metric == 'ErrorParity':
        prob_Y_given_Yhat_and_Z_equal_to_z, prob_Y_given_Yhat_and_Z_not_equal_to_z = probability_of_k_given_kprime_and_z(calibration_data)

        prob_Y = probability_of_k(calibration_data)
        # 1 everywhere except the diagonal, i.e. indicator[k', j] = 1[k' != j]
        indicator = np.ones((prob_Y.shape[0], prob_Y.shape[0])) - np.eye(prob_Y.shape[0])

        # (C, C, Z) indexed [k, k', z] -> (C, Z, C) indexed [k, z, k'].
        # Then @ indicator @ prob_Y contracts k' against sum_{j != k'} P(Y=j),
        # which leaves a (C, Z) result.
        for_z = prob_Y_given_Yhat_and_Z_equal_to_z.transpose((0, 2, 1)) @ indicator @ prob_Y
        but_z = prob_Y_given_Yhat_and_Z_not_equal_to_z.transpose((0, 2, 1)) @ indicator @ prob_Y

        return for_z, but_z

    # fairness_metric == 'EqualityOfOdds'
    prob_Yhat_given_Y_and_Z_equal_to_z, prob_Yhat_given_Y_and_Z_not_equal_to_z = probability_of_not_kprime_given_not_k_and_z(calibration_data)

    for_z = prob_Yhat_given_Y_and_Z_equal_to_z
    but_z = prob_Yhat_given_Y_and_Z_not_equal_to_z

    return for_z, but_z