from typing import Literal

from torch import Tensor
from torch.nn import BCEWithLogitsLoss
from torch.nn import functional as F
from torch_geometric.data import Data
from typing_extensions import override


class BCEWithLogitsLossGraphTarget(BCEWithLogitsLoss):
    """
    Same as `torch.nn.BCEWithLogitsLoss`, but instead of a target Tensor, this takes a graph and uses `graph.y` as
    target labels.

    The `pos_weight` parameter is calculated dynamically from the target labels on every call as
    ``(number of label-0 entries) / (number of label-1 entries)``, so that both classes contribute comparably to the
    loss. If the targets contain only one class there is nothing to rebalance, and `pos_weight` falls back to 1
    (i.e. plain, unweighted BCE) instead of producing a NaN loss (no positives) or a constant zero loss (no negatives).
    """

    @override
    def __init__(self, reduction: Literal["none", "mean", "sum"] = "mean"):
        """
        Args:
            reduction: Reduction applied to the element-wise losses; forwarded to `BCEWithLogitsLoss`.
        """
        super().__init__(reduction=reduction)

    @override
    def forward(self, input: Tensor, graph: Data) -> Tensor:
        """
        Compute the class-balanced binary cross-entropy between `input` logits and `graph.y`.

        Args:
            input: Raw (pre-sigmoid) logits with the same shape as `graph.y`.
            graph: Object whose `y` attribute holds the binary target labels. Non-floating-point labels
                (e.g. bool or integer) are cast to the dtype of `input`.

        Returns:
            The loss, reduced according to `self.reduction`.
        """
        target = graph.y
        if not target.is_floating_point():
            # binary_cross_entropy_with_logits only accepts floating-point targets.
            target = target.to(dtype=input.dtype)

        # this is the number of labels 0 divided by the number of labels 1
        # https://discuss.pytorch.org/t/bceloss-with-class-weights/196991/4
        # numel() (rather than size(0)) so multi-dimensional targets are counted consistently with sum().
        num_pos = target.sum()
        num_neg = target.numel() - num_pos
        one = num_pos.new_ones(())
        # Guard the single-class cases: zero positives would give pos_weight=inf (and inf * 0 = NaN inside the
        # loss), zero negatives would give pos_weight=0 (zeroing the loss of every element). Tensor-level `where`
        # avoids a host/device synchronisation.
        both_classes_present = (num_pos > 0) & (num_neg > 0)
        positive_class_weight = (num_neg / num_pos.where(num_pos > 0, one)).where(both_classes_present, one)

        return F.binary_cross_entropy_with_logits(
            input,
            target=target,
            reduction=self.reduction,
            pos_weight=positive_class_weight,
        )
