import torch
from torch import nn


class _RiskDiff(nn.Module):
    def __init__(self, protect):
        '''
        protect: [0, 1]
        '''
        super().__init__()
        self.protect = -(2 * protect - 1)

    def forward(self, z, s):
        p = s.mean()
        prob = s * p + (1 - s) * (1 - p)
        _s = 2 * s - 1
        _s = self.protect * _s
        return torch.mean(torch.div(self.surrogate(z, _s), prob) - 1)

    def get_01score(self, z, s):
        y_pred = torch.heaviside(z, torch.tensor([0.0]))
        p1 = torch.mean(s * y_pred) / torch.mean(s)
        p2 = torch.mean((1 - s) * y_pred) / torch.mean(1 - s)
        return p1 - p2



class LogisticRiskDiff(_RiskDiff):
    """Risk-difference surrogate loss that uses ``ApproxHuber`` as its surrogate."""

    def __init__(self, protect):
        super(LogisticRiskDiff, self).__init__(protect=protect)
        # Resolved when an instance is built, so ApproxHuber may be defined
        # later in the module than this class.
        surrogate_fn = ApproxHuber
        self.surrogate = surrogate_fn


        
def ApproxHuber(z, _s, threshold=10.0):
    """Softplus surrogate ``log(1 + exp(t))`` of the step 1[t > 0], with ``t = z * _s``.

    Despite the name this is the logistic (softplus) loss, not a Huber loss.
    For ``t >= threshold`` it returns the linear asymptote ``t`` exactly,
    which avoids overflow in ``exp``.

    Args:
        z: score tensor.
        _s: sign tensor broadcastable with ``z`` (callers pass +/-1 values).
        threshold: switch point to the linear branch (default 10.0).

    Returns:
        Elementwise tensor of ``softplus(t)`` (``t`` beyond the threshold).
    """
    t = z * _s
    # Clamp before exp so the discarded branch never overflows to inf;
    # otherwise torch.where would back-propagate 0 * inf = nan through it.
    # log1p keeps precision for very negative t, where 1 + exp(t) rounds to 1.
    smooth = torch.log1p(torch.exp(torch.clamp(t, max=threshold)))
    return torch.where(t < threshold, smooth, t)