import torch


class SelectiveNetLoss(torch.nn.Module):
    """Selective risk plus a quadratic penalty for falling short of a target coverage.

    The loss is ``risk + lm * max(0, coverage - empirical_coverage) ** 2`` where
    ``empirical_coverage`` is the mean of the selection outputs and ``risk`` is
    the selection-weighted mean of the base loss divided by that coverage.
    """

    def __init__(self, loss_func, coverage: float, lm: float = 32.0):
        """
        Args:
            loss_func: base loss function. The shape of loss_func(x, target)
                       should be (B), i.e. one loss value per sample.
                       e.g.) torch.nn.CrossEntropyLoss(reduction='none') : classification
            coverage: target coverage, in the interval (0, 1].
            lm: Lagrange multiplier for the coverage constraint (must be > 0).
                The original experiment's value is 32.
        """
        super().__init__()
        assert 0.0 < coverage <= 1.0, "coverage must be in (0, 1]"
        assert 0.0 < lm, "lm must be positive"

        self.loss_func = loss_func
        self.coverage = coverage
        self.lm = lm

    def forward(self, prediction_out, selection_out, target, device=None):
        """Compute the selective loss for one batch.

        Args:
            prediction_out: (B, num_classes), passed unchanged to ``loss_func``.
            selection_out:  (B, 1), flattened to (B) to weight the per-sample loss.
            target: targets, passed unchanged to ``loss_func``.
            device: device on which the target-coverage constant is created.
                Optional; defaults to ``selection_out.device`` so the constant
                always lives with the inputs.

        Returns:
            float32 tensor of shape (1,): empirical selective risk plus the
            coverage penalty.

        Note:
            A zero empirical coverage is not guarded against; the division
            then yields a non-finite value.
        """
        if device is None:
            device = selection_out.device

        empirical_coverage = selection_out.mean()

        # Selection-weighted mean loss, normalised by the fraction selected.
        weighted_loss = self.loss_func(prediction_out, target) * selection_out.view(-1)
        empirical_risk = weighted_loss.mean() / empirical_coverage

        # Plain constant: it must not be a trainable leaf, otherwise autograd
        # tracks (and accumulates) a gradient nobody uses.
        target_coverage = torch.tensor(
            [self.coverage], dtype=torch.float32, device=device
        )
        # Only a coverage shortfall is penalised; exceeding the target is free.
        shortfall = torch.clamp(target_coverage - empirical_coverage, min=0.0)
        penalty = self.lm * shortfall ** 2

        return empirical_risk + penalty
