import torch


def DI(dataloader, model, target_idx, sensitive_idx, gpu=0, batch_size=128):
    """Return the min ratio of class-1 prediction rates between two groups.

    For every batch the model's prediction is ``argmax`` over dim 1 of its
    output. Samples are split by the sensitive attribute ``z`` into the
    ``z == 0`` and ``z == 1`` groups, and for each group the fraction of
    samples predicted as class 1 is computed over the whole dataloader.
    The result is ``min(rate_0 / rate_1, rate_1 / rate_0)``.
    (Presumably "DI" stands for disparate impact -- confirm with callers.)

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` pairs. ``labels``
            is indexed as ``labels[:, sensitive_idx]``, so it must be at
            least 2-D.
        model: Callable module; it is switched to eval mode and called on
            each ``inputs`` batch. It is not moved between devices here.
        target_idx: Unused by this metric; kept for signature compatibility.
        sensitive_idx: Column of ``labels`` holding the sensitive attribute.
        gpu: Device argument passed to ``Tensor.cuda`` when CUDA is
            available. Without CUDA the tensors are left where they are.
        batch_size: Unused; kept for signature compatibility.

    Returns:
        The ratio as a Python float. It is ``nan`` when the ratio is
        undefined: no batches, an empty group, or both rates equal to zero.
    """
    use_cuda = torch.cuda.is_available()

    # Layout: [pred==1 & z==0, z==0, pred==1 & z==1, z==1].
    # Kept as integer counts so accumulation stays exact for any dataset size.
    counts = None

    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            if use_cuda:
                inputs = inputs.cuda(gpu)
                labels = labels.cuda(gpu)

            y_bar = torch.argmax(model(inputs).float(), dim=1)
            z = labels[:, sensitive_idx]

            predicted_pos = y_bar == 1
            group_0 = z == 0
            group_1 = z == 1

            batch_counts = torch.stack([
                (predicted_pos & group_0).sum(),
                group_0.sum(),
                (predicted_pos & group_1).sum(),
                group_1.sum(),
            ])
            counts = batch_counts if counts is None else counts + batch_counts

        if counts is None:
            # No batches at all: report "undefined" the same way an empty
            # group does, instead of failing on an integer 0 / 0.
            return float("nan")

        f_10, f_x0, f_11, f_x1 = counts.float()
        rate_0 = f_10 / f_x0
        rate_1 = f_11 / f_x1
        # Tensor division yields inf/nan rather than raising, so a zero rate
        # in exactly one group gives min(0, inf) == 0.
        fitness = min(rate_0 / rate_1, rate_1 / rate_0)
        return fitness.item()


def DEO(dataloader, model, target_idx, sensitive_idx, gpu=0, batch_size=128):
    """Return the summed between-group gaps in class-1 prediction rate.

    For every batch the model's prediction is ``argmax`` over dim 1 of its
    output. Samples are bucketed by true label ``y`` (0 or 1) and sensitive
    attribute ``z`` (0 or 1). For each value of ``y`` the fraction of samples
    predicted as class 1 is computed per ``z`` group, and the result is::

        |rate(y=0, z=0) - rate(y=0, z=1)| + |rate(y=1, z=0) - rate(y=1, z=1)|

    (Presumably "DEO" stands for difference of equalized odds -- confirm
    with callers.)

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` pairs. ``labels``
            is indexed as ``labels[:, idx]``, so it must be at least 2-D.
        model: Callable module; it is switched to eval mode and called on
            each ``inputs`` batch. It is not moved between devices here.
        target_idx: Column of ``labels`` holding the true label ``y``.
        sensitive_idx: Column of ``labels`` holding the sensitive attribute.
        gpu: Device argument passed to ``Tensor.cuda`` when CUDA is
            available. Without CUDA the tensors are left where they are.
        batch_size: Unused; kept for signature compatibility.

    Returns:
        The metric as a Python float. It is ``nan`` when any of the four
        ``(y, z)`` buckets is empty, including when there are no batches.
    """
    use_cuda = torch.cuda.is_available()

    # Layout: for (y, z) in (0,0), (0,1), (1,0), (1,1) a pair
    # [pred==1 & bucket, bucket]. Kept as integer counts so accumulation
    # stays exact for any dataset size.
    counts = None

    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            if use_cuda:
                inputs = inputs.cuda(gpu)
                labels = labels.cuda(gpu)

            y_bar = torch.argmax(model(inputs).float(), dim=1)
            y = labels[:, target_idx]
            z = labels[:, sensitive_idx]

            predicted_pos = y_bar == 1
            per_bucket = []
            for y_val in (0, 1):
                for z_val in (0, 1):
                    bucket = (y == y_val) & (z == z_val)
                    per_bucket.append((predicted_pos & bucket).sum())
                    per_bucket.append(bucket.sum())

            batch_counts = torch.stack(per_bucket)
            counts = batch_counts if counts is None else counts + batch_counts

        if counts is None:
            # No batches at all: report "undefined" the same way an empty
            # bucket does, instead of failing on an integer 0 / 0.
            return float("nan")

        (f1_100, f1_x00, f1_101, f1_x01,
         f2_110, f2_x10, f2_111, f2_x11) = counts.float()

        # Tensor division yields nan for an empty bucket rather than raising.
        fitness1 = f1_100 / f1_x00 - f1_101 / f1_x01
        fitness2 = f2_110 / f2_x10 - f2_111 / f2_x11
        DEO_fitness = abs(fitness1) + abs(fitness2)

    return DEO_fitness.item()
