import torch
from torch import Tensor
import torch.nn as nn


class BCEDiff(nn.Module):
    """Absolute gap between the cross-entropy losses of two protected groups.

    Samples are split by ``protected_attribute`` into group 0 and group 1,
    the mean cross-entropy is computed for each group separately, and the
    absolute difference of the two means is returned.
    """

    def __init__(self) -> None:
        super().__init__()

        self.ce = nn.CrossEntropyLoss()

    def forward(self, input: Tensor, target: Tensor, protected_attribute: Tensor) -> Tensor:
        """Compute ``|CE(group 0) - CE(group 1)|``.

        Args:
            input: Model outputs passed to ``nn.CrossEntropyLoss``; indexed
                along the first dimension by the group masks.
                (presumably logits of shape ``(N, C)`` -- confirm with caller)
            target: Class targets aligned with ``input`` along the first
                dimension.
            protected_attribute: Per-sample group label. Only the values
                ``0`` and ``1`` are used; samples carrying any other value
                are ignored.

        Returns:
            A scalar tensor holding the absolute difference of the two group
            losses, or zero when either group has no samples in the batch.
        """
        in_group_0 = protected_attribute == 0
        in_group_1 = protected_attribute == 1

        # A mean-reduced cross-entropy over an empty selection is undefined
        # (NaN), which would poison any loss this term is added to. With one
        # group missing there is no gap to measure, so return a zero that is
        # still attached to ``input`` so that ``backward()`` keeps working.
        if not (bool(in_group_0.any()) and bool(in_group_1.any())):
            return input.sum() * 0.0

        ce_group_0 = self.ce(input[in_group_0], target[in_group_0])
        ce_group_1 = self.ce(input[in_group_1], target[in_group_1])

        return torch.abs(ce_group_0 - ce_group_1)


class ProtectedAttributeCE(nn.Module):
    """Cross-entropy of the model output against the protected attribute."""

    def __init__(self) -> None:
        super().__init__()

        self.ce = nn.CrossEntropyLoss()

    def forward(self, input: Tensor, target: Tensor, protected_attribute: Tensor):
        """Return ``CrossEntropyLoss(input, protected_attribute)``.

        ``target`` is accepted but not used by this loss; only ``input`` and
        ``protected_attribute`` take part in the computation.
        """
        attribute_loss = self.ce(input, protected_attribute)
        return attribute_loss
    
class NegativeProtectedAttributeCE(nn.Module):
    """Negated cross-entropy of the model output against the protected attribute."""

    def __init__(self) -> None:
        super().__init__()

        self.ce = nn.CrossEntropyLoss()

    def forward(self, input: Tensor, target: Tensor, protected_attribute: Tensor):
        """Return ``-CrossEntropyLoss(input, protected_attribute)``.

        ``target` is accepted but not used by this loss; only ``input`` and
        ``protected_attribute`` take part in the computation.
        """
        attribute_loss = self.ce(input, protected_attribute)
        return torch.neg(attribute_loss)
