
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score
import numpy as np

def AUCFairness(predict_arr, labels, sensitive_labels):
    '''
    AUC-based fairness gaps for binary classification, with a binary or
    multi-valued sensitive attribute.

    For every distinct sensitive value ``s`` the samples are split one-vs-rest
    into "group s" and "rest", and three gaps are computed:

    * intraAUCF: |AUC(rest) - AUC(group s)|, each AUC computed within a group.
    * interAUCF: |AUC(rest negatives vs group-s positives)
                  - AUC(group-s negatives vs rest positives)|.
    * xAUCF: |AUC of the score separating group s from the rest - 0.5|.

    Only samples whose label is 0 or 1 take part. Each gap is averaged over
    all distinct sensitive values.

    Args:
        predict_arr: raw scores (logits); a sigmoid is applied first.
            Flattened before use.
        labels: ground-truth labels, same number of elements as predict_arr.
        sensitive_labels: sensitive-attribute value per sample. Any values are
            accepted; they need not be the consecutive integers 0..K-1.

    Returns:
        dict with keys 'intraAUCF', 'interAUCF', 'xAUCF', each mapping to the
        averaged gap formatted as a string with 4 decimals.

    Raises:
        ValueError: propagated from ``roc_auc_score`` when an AUC is undefined,
            e.g. a group has no positives or no negatives, or there is only a
            single sensitive value (the "rest" split is then empty).
    '''
    # detach() so tensors that require grad can still be handed to sklearn.
    prob = torch.sigmoid(torch.as_tensor(predict_arr).detach()).reshape(-1)
    labels = torch.as_tensor(labels).reshape(-1)
    sensitive_labels = torch.as_tensor(sensitive_labels).reshape(-1)

    # Class masks do not depend on the sensitive group; build them once.
    label_1 = labels == 1
    label_0 = labels == 0
    binary = label_0 | label_1

    def _auc(y_true, mask):
        # roc_auc_score is invariant to sample order, so selecting with a
        # boolean mask is equivalent to concatenating the sub-groups.
        return roc_auc_score(y_true[mask].numpy(), prob[mask].numpy())

    intraAUCF, interAUCF, xAUCF = 0, 0, 0
    # Iterate the actual sensitive values rather than range() indices:
    # comparing against indices produced empty groups whenever the values
    # were not exactly 0..K-1.
    unique_sa = torch.unique(sensitive_labels).tolist()

    for value in unique_sa:
        sl_1 = sensitive_labels == value
        sl_0 = sensitive_labels != value

        # IntraAUC: ranking quality within each side of the split.
        intraAUCF += abs(_auc(labels, sl_0 & binary) - _auc(labels, sl_1 & binary))

        # InterAUC: rest negatives vs group-s positives, compared with
        # group-s negatives vs rest positives.
        cross_a = (sl_0 & label_0) | (sl_1 & label_1)
        cross_b = (sl_1 & label_0) | (sl_0 & label_1)
        interAUCF += abs(_auc(labels, cross_a) - _auc(labels, cross_b))

        # XAUC: how well the score alone separates group s from the rest
        # (0.5 means the score carries no group information).
        xAUCF += abs(_auc(sl_1.long(), binary) - 0.5)

    n_groups = len(unique_sa)
    return {'intraAUCF': '{:.4f}'.format(intraAUCF / n_groups),
            'interAUCF': '{:.4f}'.format(interAUCF / n_groups),
            'xAUCF': '{:.4f}'.format(xAUCF / n_groups)}





def DP(predict_arr, sensitive_labels):
    '''
    Demographic-parity gap for binary or multi-valued sensitive labels.

    Predictions are positive when sigmoid(logit) > 0.5. The distinct
    sensitive values are sorted. The result is the absolute difference in
    positive-prediction rate between each pair of adjacent values, averaged
    over all such pairs. For two groups this is simply |rate_a - rate_b|.

    Args:
        predict_arr: raw scores (logits); flattened before use.
        sensitive_labels: sensitive-attribute value per sample. Any values are
            accepted; they need not be the consecutive integers 0..K-1.

    Returns:
        0-d float tensor with the mean adjacent-group gap. Returns 0.0 when
        there are fewer than two sensitive values, since there is no pair to
        compare.
    '''
    y_prob = torch.sigmoid(torch.as_tensor(predict_arr)).reshape(-1)
    sensitive_labels = torch.as_tensor(sensitive_labels).reshape(-1)

    # Threshold once; predictions do not depend on which groups are compared.
    pred = (y_prob > 0.5).float()

    # Compare actual sensitive values (sorted), not range() indices, so
    # encodings like {1, 2} no longer select empty groups and return NaN.
    groups = torch.unique(sensitive_labels).tolist()
    if len(groups) < 2:
        return torch.tensor(0.0)

    rates = [pred[sensitive_labels == g].mean() for g in groups]
    gaps = [abs(a - b) for a, b in zip(rates[:-1], rates[1:])]
    return sum(gaps) / len(gaps)


def EOD_EOP(predict_arr, labels, sensitive_labels):
    '''
    Equalized-odds and equal-opportunity gaps for binary or multi-valued
    sensitive labels.

    Predictions are positive when sigmoid(logit) > 0.5. For each ground-truth
    class ``c`` the per-group rate P(pred == c | y == c, group) is computed;
    for c = 0 this is the TNR and for c = 1 the TPR. The gap for class ``c``
    is the mean absolute difference between adjacent sorted sensitive values,
    which for two groups is simply |rate_a - rate_b|.

    Args:
        predict_arr: raw scores (logits); flattened before use.
        labels: ground-truth labels, same number of elements as predict_arr.
        sensitive_labels: sensitive-attribute value per sample.

    Returns:
        (eod, eop):
        - eod is the class gap averaged over all classes present in labels.
        - eop is the gap for class 1 (the TPR gap), or 0 if no sample has
          label 1.
        A group with no samples of some class yields a NaN rate for that class.
    '''
    y_prob = torch.sigmoid(torch.as_tensor(predict_arr))
    pred_labels = (y_prob > 0.5).long().reshape(-1)
    labels = torch.as_tensor(labels).reshape(-1)
    sensitive_labels = torch.as_tensor(sensitive_labels).reshape(-1)

    classes = torch.unique(labels).tolist()
    # Group values are taken over all samples and computed once, outside the loop.
    groups = torch.unique(sensitive_labels).tolist()

    eod, eop = 0, 0
    for c in classes:
        in_class = labels == c
        class_pred = pred_labels[in_class]
        class_groups = sensitive_labels[in_class]

        # Per-group accuracy among samples whose true label is c.
        rates = [(class_pred[class_groups == g] == c).float().mean() for g in groups]

        # Use every group (the old code silently ignored all but the first two),
        # aggregating adjacent-group gaps the same way DP does.
        gaps = [abs(a - b) for a, b in zip(rates[:-1], rates[1:])]
        gap = sum(gaps) / len(gaps) if gaps else torch.tensor(0.0)

        eod += gap / len(classes)
        if c == 1:
            eop = gap

    return eod, eop



