from jaxtyping import Float
from torch import Tensor, nn

from src.utils.sparse_utils import sparse_softmax


class OutputAdapter(nn.Module):
    """Adapter that converts backbone outputs into per-edge probabilities.

    The wrapped prediction head is run on the backbone output, and whatever it returns is
    handed to ``sparse_softmax`` together with the edge index.
    """

    def __init__(self, prediction_head: nn.Module) -> None:
        """Register the prediction head as a submodule.

        :param prediction_head: Module invoked as ``prediction_head(x, edge_index, edge_attr)``;
            its result is treated as edge logits.
        """
        super().__init__()

        self.prediction_head = prediction_head

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_node"],
        edge_index: Float[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Float[Tensor, "n_edges"]:
        """Run the prediction head and normalise its logits with ``sparse_softmax``.

        NOTE(review): ``edge_index`` is annotated as a float tensor; index tensors are usually
        integer-typed, so this annotation may be inaccurate -- confirm against callers.

        :param x: Output of the backbone (node-level features).
        :param edge_index: Edge index; passed to both the prediction head and ``sparse_softmax``.
        :param edge_attr: Edge features; passed to the prediction head only.
        :return: Predicted edge probabilities, i.e. the result of
            ``sparse_softmax(edge_index, logits, axis=1)``.
        """
        head = self.prediction_head
        edge_logits = head(x, edge_index, edge_attr)
        # The meaning of ``axis=1`` (which edges are grouped for normalisation) is defined by
        # the project helper ``sparse_softmax`` and is not visible from this module.
        return sparse_softmax(edge_index, edge_logits, axis=1)
