from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch


class ATTBackbone(nn.Module):
    """ATT based backbone.

    Chains three modules: an input adapter that turns a PyG batch into the
    ATT network's inputs, the ATT network itself, and an output adapter that
    turns the network's outputs into the final per-edge result.
    """

    def __init__(
        self, net: nn.Module, input_adapter: nn.Module, output_adapter: nn.Module
    ) -> None:
        """Initialize backbone.

        :param net: ATT module. It is called as ``net(XA, A, B, t)`` and must
            return a 2-tuple ``(A_pred, B_pred)``. It is stored as ``self.att``.
        :param input_adapter: Input adapter. It is called with the PyG batch
            and must return a 3-tuple ``(XA, A, B)``.
        :param output_adapter: Output adapter. It is called as
            ``output_adapter(data, A_pred, B_pred)``, and its result is
            returned by :meth:`forward`.
        """
        super().__init__()

        # The submodule is registered under ``att`` rather than ``net``.
        # Renaming it would change the ``state_dict`` key prefix and break
        # loading of previously saved weights.
        self.att = net
        self.input_adapter = input_adapter
        self.output_adapter = output_adapter

    def forward(self, data: Batch) -> Float[Tensor, "n_edges"]:
        """Perform a forward pass with the backbone.

        :param data: PyG batch object. Besides whatever the adapters read, it
            must provide a node-level attribute ``t`` and the batch pointer
            ``ptr``.
        :return: Predicted probability for each edge, as produced by the
            output adapter.
        """
        XA, A, B = self.input_adapter(data)
        # ``ptr[:-1]`` holds the index of the first node of each graph in the
        # batch, so this takes one ``t`` value per graph, from its first node.
        # NOTE(review): assumes ``t`` is constant across the nodes of a graph
        # (a per-graph value broadcast to nodes); confirm against the dataset.
        t = data.t[data.ptr[:-1]]
        A_pred, B_pred = self.att(XA, A, B, t)
        return self.output_adapter(data, A_pred, B_pred)
